In [1]:
import jax 
import jax.numpy as jnp
from jax import random, grad, jit, vmap, tree_util, nn
import optax

Since JAX follows the pure functional approach, so in order to track model's parameters we need to build a pytree containing parameters(weights and biases) for each layer which may seem unusual at first but that's the pure JAX way of tracking params instead of having them as attributes of objects of each layer class.

In [3]:
def init_model(key, input_dim, layer_dims):
    """Initializes the parameters for the Multi Layer Perceptron"""

    params = []
    keys = random.split(key, len(layer_dims))
    for in_dim, out_dim, key in zip([input_dim]+layer_dims[:-1],layer_dims,keys):
        w_key, b_key = random.split(key)
    
        w = nn.initializers.he_normal()(key, (in_dim, out_dim))
        b = jnp.zeros(out_dim)
        params.append({'w':w,'b':b})
    return params

In [4]:
def  apply_model(params, inputs):
    """Forward Pass of the model"""
    
    x = inputs 
  
    for layer in params[:-1]:
        x = x @ layer['w'] + layer['b']
        x = jax.nn.relu(x)

    final_layer = params[-1]
    
    x = x @ final_layer['w'] + final_layer['b']
    # return jax.nn.log_softmax(x)
    return x

Here we did not applied the `softmax` activation in the output layer because in the `loss_fn` we are calculating derivative of cross_entropy which gets cancelled out pretty nicely when paired with softmax resulting in much efficient calculations.

In [6]:
def loss_fn(params, inputs, targets):
    """Calculates the loss for the model"""
    logits = apply_model(params, inputs)

    return jnp.mean(optax.softmax_cross_entropy_with_integer_labels(logits = logits, labels = targets))

In [7]:
def update_fn(params, grads, learning_rate):
    """Updates the parameters using SGD update Formula"""    
    return tree_util.tree_map(
        lambda p,g : p - learning_rate * g,
        params,
        grads
    )

In [8]:
@jit
def train_step(params, inputs, targets, learning_rate = 0.001):
    
    grads = grad(loss_fn)(params, inputs, targets)
    return update_fn(params, grads, learning_rate)

In [9]:
@jit
def predict(params, inputs):
    
    return jnp.argmax(apply_model(params, inputs),axis=1)

Note :- In the training step the final layer's `softmax` activation was paired with `loss_fn` but in the predict we are not using softmax at all this is because applying softmax activation here won't change the model's output, the bigger values will get higher probability and smaller will get lower, so to make inference more efficient we didn't applied it here. However if needed to build the `confidence interval` we would have applied the `softmax` activation. 

In [11]:
import pandas as pd
import seaborn as sns
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

In [12]:
data = sns.load_dataset("penguins")
data.dropna(inplace=True)

le = LabelEncoder()
non_numerical_cols = ['island','sex']

for i in non_numerical_cols:
    data[i]=le.fit_transform(data[i])

In [13]:
X = data.drop('species',axis=1).values
y = le.fit_transform(data['species'])

scaler = StandardScaler()
X = scaler.fit_transform(X)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2, random_state = 42)

In [14]:
import time

print("Started training...")
key = random.PRNGKey(0)
learning_rate = 0.01
epochs = 200

params = init_model(key, input_dim = 6, layer_dims=[10,3])

for epoch in range(epochs):
    start=time.time()
    params = train_step(params, X_train, y_train, learning_rate)
    end=time.time()

    if(epoch+1) % 20 == 0 or (epoch+1)==1:
        current_loss = loss_fn(params, X_train, y_train)
        print(f"Epoch {epoch+1}, Loss = {current_loss} , Time = {end-start}s")

print("Training finished.")


Started training...
Epoch 1, Loss = 1.6930781602859497 , Time = 0.11561083793640137s
Epoch 20, Loss = 1.2324516773223877 , Time = 0.0s
Epoch 40, Loss = 0.9929127097129822 , Time = 0.0s
Epoch 60, Loss = 0.8408139944076538 , Time = 0.0s
Epoch 80, Loss = 0.7315548062324524 , Time = 0.0s
Epoch 100, Loss = 0.6493933200836182 , Time = 0.0s
Epoch 120, Loss = 0.5852449536323547 , Time = 0.0s
Epoch 140, Loss = 0.5339083075523376 , Time = 0.0s
Epoch 160, Loss = 0.4920057952404022 , Time = 0.0s
Epoch 180, Loss = 0.4572613835334778 , Time = 0.0s
Epoch 200, Loss = 0.4281741976737976 , Time = 0.0s
Training finished.


As we can see above that the first run(epoch) of the training step took longer than subsequent calls because of `JIT compilation` of the `train_step` function

In [16]:
print("Testing...")
y_pred = predict(params, X_test)
accuracy = accuracy_score(y_test, y_pred)
print("Accuracy:", accuracy)

Testing...
Accuracy: 0.8656716417910447
