## Classification with ANN

So far, you've seen how to define a model using JAX, train it, do backpropagation, and test it. Let's take it a bit further. In this notebook we'll write a simple ANN to classify penguin species using the [Palmer Penguins dataset](https://github.com/mcnakhaee/palmerpenguins).

Let's load the data and process it to get started. (The main focus of this notebook is the ANN, not data loading, so I'm going to short-circuit that part a bit.)

### Data

In [None]:
import numpy as np
import pandas as pd
from palmerpenguins import load_penguins

In [None]:
import jax
import jax.numpy as jnp

from tqdm.auto import tqdm
from sklearn import preprocessing


def load_data():
    penguins = load_penguins() # penguins is a dataframe
    penguins = penguins.dropna() # type: ignore
    
    # print the head of the dataframe to give some view
    print(penguins.head()) # type: ignore
    
    # collect the feature columns
    feature_columns = ['bill_length_mm', 'bill_depth_mm', 'flipper_length_mm', 'body_mass_g']
    # classification target
    target_column = "species"
    
    # features and targets
    features = penguins[feature_columns].values  # type: ignore
    targets = penguins[target_column].values  # type: ignore
    
    # but here's a catch
    # the targets are categorical, so we have two options here
    # one hot encode them, or, assign a numeric value to them and keep a dictionary
    # with the target label to int mapping
    # the second approach is easier xD
    
    target_ids_dict = dict()
    # sort so that the label -> id mapping is the same on every run
    unique_target_labels = sorted(set(targets))
    _id = 0
    
    for ul in unique_target_labels:
        target_ids_dict[ul] = _id
        _id += 1
        
    # convert target labels to integers using the same dict
    def convert_label_to_ids(targets, id_dict):
        converted_targets = np.zeros(shape=(len(targets),), dtype=np.int32)
        for idx, target in tqdm(enumerate(targets)):
            converted_targets[idx] = id_dict[target]

        return converted_targets
    
    targets_converted = convert_label_to_ids(
        targets=targets, id_dict=target_ids_dict)
    
    assert features.shape[0] == targets_converted.shape[0]
    
    # the features from the dataset are not normalised and
    # this can cause problems during training, such as
    # the optimiser getting stuck in a poor local minimum
    # there's a lot of literature which talks about the 
    # necessity of normalisation, this is a good starter
    # https://machinelearningmastery.com/how-to-improve-neural-network-stability-and-modeling-performance-with-data-scaling/
    
    # note: normalize(..., norm="l2") rescales each row (sample) to unit L2 norm
    features_norm = preprocessing.normalize(features, norm="l2")
    
    
    return (features_norm, targets_converted)
    
    
X, y = load_data()
print(X.shape) # type: ignore
print(y.shape) # type: ignore
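
The two label-encoding options mentioned in the comments above can be compared side by side. This is a minimal sketch, not the notebook's own code: the `labels` array is a made-up stand-in for the `species` column, and it uses `np.unique` instead of the hand-written loop.

```python
import numpy as np

# Hypothetical label array standing in for the `species` column.
labels = np.array(["Gentoo", "Adelie", "Chinstrap", "Adelie", "Gentoo"])

# np.unique returns the sorted unique labels; return_inverse gives, for each
# element, the index of its label in that sorted array.
classes, ids = np.unique(labels, return_inverse=True)
ids = ids.astype(np.int32)

# Option 2 from the comments: label -> integer id (alphabetical order,
# so the mapping is reproducible across runs).
label_to_id = {str(label): i for i, label in enumerate(classes)}

# Option 1 from the comments: one-hot rows, with a 1 in column ids[i].
one_hot = np.eye(len(classes), dtype=np.float32)[ids]

print(label_to_id)
print(ids)
print(one_hot.shape)
```

The integer ids and the one-hot rows carry the same information; `one_hot.argmax(axis=1)` gives the ids back.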

Okay, splendid. Now to create the train/test split and convert these NumPy arrays to JAX arrays.

In [None]:
from sklearn.model_selection import train_test_split

features_train, features_test, targets_train, targets_test = train_test_split(
    X, y, test_size=0.3, random_state=42)


features_train = jnp.array(features_train)
features_test = jnp.array(features_test)
targets_train = jnp.array(targets_train)
targets_test = jnp.array(targets_test)


print(f"Train Size : {features_train.shape[0]}")
print(f"Test Size: {features_test.shape[0]}")

### ANN

This is going to be a 2-layer ANN (with biases), with ReLU as the activation, 3 target classes, and cross entropy as the loss function.

#### PRNG

In [None]:
# old ritual of generating prngs
key = jax.random.PRNGKey(42)
key, *subkeys = jax.random.split(key, num=10)
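
If the key-splitting ritual feels opaque, the following small sketch may help. It only assumes an arbitrary seed of 0 and illustrates two properties: reusing a key reproduces the same numbers, while keys obtained from `jax.random.split` give different draws.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# Reusing a key reproduces exactly the same numbers.
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))

# Splitting yields fresh keys; each one drives its own stream.
key, k1, k2 = jax.random.split(key, num=3)
c = jax.random.normal(k1, (3,))
d = jax.random.normal(k2, (3,))

same_key_matches = bool(jnp.array_equal(a, b))
split_keys_differ = not bool(jnp.array_equal(c, d))
print(same_key_matches, split_keys_differ)
```

This is why the notebook splits once up front and hands a separate subkey to every parameter tensor.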

#### ANN definition

In [None]:
# inits params
# w, b
# init strategy : Kaiming
def ann(in_features, hidden_features, out_features, *prngs):
    # Kaiming scale uses the fan-in of each layer
    scale_1 = jnp.sqrt(2 / in_features)
    scale_2 = jnp.sqrt(2 / hidden_features)
    
    # layer 1
    w1 = jax.random.normal(prngs[0], (in_features, hidden_features)) * scale_1
    b1 = jax.random.normal(prngs[1], (1, hidden_features))
    
    # layer 2
    w2 = jax.random.normal(
        prngs[2], (hidden_features, out_features)) * scale_2
    b2 = jax.random.normal(
        prngs[3], (1, out_features))
    
    return {
        "w1": w1, 
        "b1": b1, 
        "w2": w2,
        "b2": b2
    }
    


params = ann(4, 4, 3, *subkeys)
params
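
To see what Kaiming (He) initialisation does to the weights, here is a rough check under assumed layer sizes. The helper `he_normal` and the sizes 256, 512, and 128 are hypothetical and chosen only so the sample standard deviation is meaningful; the notebook's layers are far smaller.

```python
import jax
import jax.numpy as jnp


def he_normal(key, fan_in, fan_out):
    # Kaiming/He normal init: zero-mean Gaussian with std = sqrt(2 / fan_in),
    # where fan_in is the number of inputs to *this* layer.
    return jax.random.normal(key, (fan_in, fan_out)) * jnp.sqrt(2.0 / fan_in)


key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)

# Hypothetical layer sizes, large enough for the sample std to be stable.
w1 = he_normal(k1, 256, 512)
w2 = he_normal(k2, 512, 128)

# Empirical std next to the target std for each layer.
print(float(w1.std()), float(jnp.sqrt(2.0 / 256)))
print(float(w2.std()), float(jnp.sqrt(2.0 / 512)))
```

The scale depends on each layer's own fan-in, which only matters once the input and hidden sizes differ.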

#### Forward Pass

In [None]:
# forward pass
@jax.jit
def forward(params, x): 
    # from layer 1
    # xW + b
    out1 = x @ params["w1"] + params["b1"]
    out1 = jax.nn.relu(out1)
    
    # layer 2
    out2 = out1 @ params["w2"] + params["b2"]
    out2 = jax.nn.relu(out2)
    
    # apply softmax to convert to a probability distribution
    # since the loss function is cross entropy
    probs = jax.nn.softmax(out2)
    
    return probs
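
It can help to trace the shapes through the forward pass for a single sample. The sketch below uses toy parameters (random weights, zero biases, all hypothetical) with the same 4 -> 4 -> 3 shapes and mirrors the layer structure above. Because the biases have shape `(1, n)`, a 1-D input is broadcast to a row, so the output has shape `(1, 3)`, which is why the loss function later takes `[0]`.

```python
import jax
import jax.numpy as jnp

# Toy parameters with the same shapes as in the notebook (4 -> 4 -> 3).
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
toy = {
    "w1": jax.random.normal(k1, (4, 4)),
    "b1": jnp.zeros((1, 4)),
    "w2": jax.random.normal(k2, (4, 3)),
    "b2": jnp.zeros((1, 3)),
}

x = jnp.ones((4,))                                    # one sample, shape (4,)
h = jax.nn.relu(x @ toy["w1"] + toy["b1"])            # (4,) + (1, 4) -> (1, 4)
scores = jax.nn.relu(h @ toy["w2"] + toy["b2"])       # (1, 4) @ (4, 3) -> (1, 3)
probs = jax.nn.softmax(scores)                        # softmax over the last axis

print(h.shape, probs.shape, float(probs.sum()))
```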

#### Loss function and grad

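The formal definition of cross entropy loss for multiclass classification is:
$$
ce = -\sum_{c=1}^{M} y_{c}\log(p_{c})
$$

where $M$ is the number of classes, $y_c$ is 1 if $c$ is the correct class $t$ and 0 otherwise, and $p_c$ is the probability the model predicted for class $c$.

So $y_t$ is the correct label and $p_t$ is what the model predicted for $t$.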

There is a nifty trick if you represent your classes with int ids, as I have done above. The ids start from 0, so you can treat them as indexes into the predicted probability vector. Since the label is zero for every class except the correct one, every other term in the sum vanishes, and the cross entropy for an instance becomes

$$
ce = -\log(p_t)
$$

This trick works fine for a one-dimensional vector of class probabilities, where each instance has exactly one correct class. I have never verified it outside course assignments or simple experiments. Then again, this notebook is just here to show you how JAX works. In practice, it would be buckwild to write everything from scratch. *Don't violate the DRY principle!*
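
A quick numeric check of the trick, as a sketch with made-up probabilities: the full sum over a one-hot target and the direct index into the probability vector give the same number.

```python
import numpy as np

# Hypothetical predicted probabilities for one sample over M = 3 classes.
p = np.array([0.7, 0.2, 0.1])
t = 0                                  # integer id of the correct class

# Full definition: sum over classes with a one-hot target vector.
y = np.zeros_like(p)
y[t] = 1.0
ce_full = -np.sum(y * np.log(p))

# Shortcut: index the probability of the correct class directly.
ce_indexed = -np.log(p[t])

print(ce_full, ce_indexed)
```

Both values equal $-\log(0.7)$, roughly 0.357.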

In [None]:
# loss function
# for a single instance
# will vmap for batches

@jax.jit
def cross_entropy(params, x, y):
    probs = forward(params, x)[0] # these arrays ....
    
    return -jnp.log(probs[y])


cross_entropy(params, features_train[0], targets_train[0])


In [None]:
@jax.jit
def calculate_loss(params, x_batched, y_batched):
    batch_loss = jax.vmap(cross_entropy, in_axes=(None, 0, 0))(params, x_batched, y_batched)
    return jnp.mean(batch_loss)


calculate_loss(params, features_train, targets_train)
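
The `in_axes=(None, 0, 0)` argument is the part that tends to trip people up, so here is a toy sketch of what it means. The function `per_sample_loss` and the data are hypothetical; the point is only that `None` leaves the first argument unbatched, while `0` maps over the leading axis of the other two.

```python
import jax
import jax.numpy as jnp


def per_sample_loss(w, x, y):
    # Squared error of a linear model on a single sample (toy stand-in).
    return (jnp.dot(x, w) - y) ** 2


w = jnp.array([1.0, -1.0])
xs = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])   # batch of 3 samples
ys = jnp.array([0.0, 1.0, 2.0])

# in_axes=(None, 0, 0): pass w unchanged, map over axis 0 of xs and ys.
batched = jax.vmap(per_sample_loss, in_axes=(None, 0, 0))(w, xs, ys)

# The same computation with an explicit Python loop, for comparison.
looped = jnp.stack([per_sample_loss(w, x, y) for x, y in zip(xs, ys)])

print(batched, looped, batched.mean())
```

Here `batched` and `looped` hold the same per-sample losses, `[1., 4., 9.]`.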


#### Update Function

In [None]:
# update params during training
# SGD
# fixed learning rate, this ain't kaggle
#@jax.jit
def update(params, gradients, lr=1e-3):
    return jax.tree_util.tree_map(
        lambda p, g: p - lr * g, params, gradients
    )
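
To make the update rule concrete, here is a small sketch on a toy pytree. The parameter and gradient values are invented, and the call uses `jax.tree_util.tree_map`, the namespaced form of the tree-mapping function.

```python
import jax
import jax.numpy as jnp

# Toy parameter and gradient pytrees with matching structure.
toy_params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array([0.5])}
toy_grads = {"w": jnp.array([10.0, -10.0]), "b": jnp.array([1.0])}

lr = 0.1
# tree_map walks both dicts in lockstep and applies the SGD rule leaf by leaf.
new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, toy_params, toy_grads)

print(new_params)
```

Each leaf moves against its gradient: `w` goes from `[1, 2]` to `[0, 3]`, and `b` from `0.5` to `0.4`.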

#### Training
Finally! The train is here! (Oh wait..........no, it's late again.)

In [None]:
loss_grad_fn = jax.value_and_grad(calculate_loss)
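
For a feel of what `jax.value_and_grad` returns, here is a sketch on a toy quadratic (the function `toy_loss` and its inputs are hypothetical). By default it differentiates with respect to the first argument, and the gradient comes back as a pytree with the same structure as that argument.

```python
import jax
import jax.numpy as jnp


def toy_loss(params, x):
    # Simple quadratic: sum((w * x)^2), differentiated w.r.t. `params`.
    return jnp.sum((params["w"] * x) ** 2)


toy_params = {"w": jnp.array([1.0, 2.0])}
x = jnp.array([3.0, 4.0])

value, grads = jax.value_and_grad(toy_loss)(toy_params, x)

# Analytically: d/dw sum((w * x)^2) = 2 * w * x^2
print(value, grads["w"])
```

The loss value is 73, and the gradient `[18., 64.]` matches the analytic formula in the comment.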

In [None]:
# batched
@jax.jit
def train_step(params, x_batched, y_batched):
    loss_value, grad = loss_grad_fn(params, x_batched, y_batched)
    params = update(params, grad)
    
    return params, loss_value
    

In [None]:
from tqdm.auto import trange


def train(params, x_batched, y_batched, epochs, log_every_n_step):
    losses = list() # to keep track of losses per epoch
    steps = list()
    
    step_counter = 0
    
    for _ in trange(epochs):
        params, loss_val = train_step(params, x_batched, y_batched)
        
        # log
        if step_counter % log_every_n_step == 0:
            losses.append(loss_val)
            steps.append(step_counter)
        step_counter += 1
        
    return params, losses, steps

In [None]:
trained_params, losses, steps = train(params, features_train, targets_train, 1000, 50)

In [None]:
trained_params

In [None]:
jax.tree_util.tree_map(
    lambda p: print(p), trained_params
)
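
The notebook stops at the trained parameters, so here is one possible way to score predictions. This is a hedged sketch: the `accuracy` helper and the toy arrays are hypothetical and not part of the original notebook.

```python
import jax.numpy as jnp


def accuracy(probs, targets):
    # probs: (n_samples, n_classes) class probabilities
    # targets: (n_samples,) integer class ids
    predictions = jnp.argmax(probs, axis=-1)
    return jnp.mean(predictions == targets)


# Hypothetical predictions for four samples and three classes.
toy_probs = jnp.array([
    [0.8, 0.1, 0.1],
    [0.2, 0.5, 0.3],
    [0.3, 0.3, 0.4],
    [0.6, 0.3, 0.1],
])
toy_targets = jnp.array([0, 1, 2, 1])

print(float(accuracy(toy_probs, toy_targets)))  # 3 of 4 correct -> 0.75
```

With the notebook's own names, the test-set probabilities could presumably be obtained with `jax.vmap(forward, in_axes=(None, 0))(trained_params, features_test)[:, 0, :]`, assuming `forward` returns a `(1, 3)` array per sample as it does above.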
