<div class='alert alert-success'>

**JAX IMPLEMENTATION FOR CLASSIFICATION**

</div>

In [59]:
#IMPORTS
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

In [60]:
#Data Preparation
def load_data(test_test=0.3, random_state=42):
    iris = datasets.load_iris()
    X, y = iris.data, iris.target
    
    #test-train split
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_test, random_state=random_state
    )
    
    #Standardizing
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_test = scaler.transform(X_test)
    
    y_train = y_train.flatten()
    y_test = y_test.flatten()
    #print(f"X_train shape: {X_train.shape}")
    #print(f"y_train shape: {y_train.shape}")
    #print(f"X_test shape: {X_test.shape}")
    #print(f"y_test shape: {y_test.shape}")
    
    return (X_train, y_train), (X_test, y_test)

In [73]:
#Model definition
class MLP(nn.Module):
    #Number of hidden units
    features: int
    
    @nn.compact #Flax magic for concise layer definitions. 
    #Instead explicitly defining submodules in a separate method, 
    #this allows writing the network's logic in a single forward pass.
    def __call__(self, x):
        x = nn.Dense(self.features)(x)
        x = nn.relu(x)
        #x = nn.Dropout(0.2)(x, deterministic=False)
        x = nn.Dense(3)(x) #For the 3 classes
        return x

In [62]:
#Loss function
def cross_entropy_loss(logits, labels):
    #Converts labels into one-hot vectors 
    one_hot_labels = jax.nn.one_hot(labels, num_classes=3)
    #Compute and return the softmax cross-entropy
    return optax.softmax_cross_entropy(logits, one_hot_labels).mean()

In [63]:
def create_train_state(rng, lr=0.01):
    #Model initialization 
    model = MLP(features=10) #i.e. 10 hidden states
    #Input shape is (batch_size, 4 features)
    init_rng, _ = jax.random.split(rng)
    params = model.init(init_rng, jnp.ones([1, 4]))['params']
    
    #Optimizer
    tx = optax.adam(lr)
    
    return train_state.TrainState.create(
        apply_fn=model.apply, #Function to apply the model
        params=params, #Initial parameters
        tx=tx #Optimizer
    )

In [72]:
#Training
@jax.jit #compile to XLA for faster execution
#This means that just-in-time compilation speeds up the execution by
#compiling function to XLA(Accelerated Linear Algebra)
#Do to optimize performance, reduce overhead, and cache results
def train_step(state, batch):
    def loss_fn(params):
        #rng = jax.random.PRNGKey(0)
        logits = state.apply_fn({'params': params}, batch['features'])
        loss = cross_entropy_loss(logits, batch['labels'])
        return loss, logits #Returning the loss for gradient computations
    
    #Gradient computations and parameter updates
    #Compute loss + gradients
    #value_and_grad is JAX's auto-diff to compute gradients
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True) 
    (loss, logits), grads =  grad_fn(state.params)
    
    #update parameters using the gradients
    state = state.apply_gradients(grads=grads)
    return state, loss

def train_epoch(state, train_data, batch_size=32):
    #Convert to JAX arrays
    features = jnp.array(train_data[0])
    labels = jnp.array(train_data[1])
    
    assert labels.ndim==1, "Labels must be 1D"
    
    #Batches creation
    #Shuffling for randomness i.e., jax.random.permutation
    dataset_size = features.shape[0]
    indices = jax.random.permutation(jax.random.PRNGKey(0), dataset_size)
    
    #Processing data in batches
    for start_idx in range(0, dataset_size, batch_size):
        end_idx = start_idx + batch_size
        batch = {
            'features': features[indices[start_idx:end_idx]],
            'labels': labels[indices[start_idx:end_idx]]
        }
        #Sanity checks
        #print(f"Batch features shape: {batch['features'].shape}")
        #print(f"Batch labels shape: {batch['labels'].shape}")
        state, loss = train_step(state, batch)
        
    return state, loss

In [68]:
#Evaluation
#@jax.jit
def compute_accuracy(params, apply_fn, features, labels):
    logits = apply_fn({'params': params}, features) #The forward pass
    predicted = jnp.argmax(logits, axis=1) #Class with highes score
    return jnp.mean(predicted == labels) #Accuracy

compute_accuracy = jax.jit(compute_accuracy, static_argnums=(1,))
#jnp.argmax convert logits to class predictions
#jnp.mean calculate the fraction of correct predictions

In [74]:
#Main
def main():
    #load data
    (X_train, y_train), (X_test, y_test) = load_data()
    #Sanity checks
    #print(f"X_train shape: {X_train.shape}")
    #print(f"X_test shape: {X_test.shape}")
    
    #Training state initialization
    #PRNHKey(0): Seed for JAX's random number generator
    rng = jax.random.PRNGKey(0)
    state = create_train_state(rng)
    
    #Training params
    num_epochs = 100
    
    #Training loop
    for epoch in range(num_epochs):
        state, loss = train_epoch(state, (X_train, y_train))
        #Print metrics for every 10 epochs
        if epoch%10 == 0:
            train_acc = compute_accuracy(
                state.params, state.apply_fn, X_train, y_train
            )
            test_acc = compute_accuracy(
                state.params, state.apply_fn, X_test, y_test
            )
            print(f"Epoch: {epoch:3d}, loss: {loss:.4f}, "
                  f"Train Acc: {train_acc.item():.2%}, Test Acc: {test_acc.item():.2%}")

if __name__ == "__main__":
    main()

Epoch:   0, loss: 1.4103, Train Acc: 61.90%, Test Acc: 66.67%
Epoch:  10, loss: 0.5088, Train Acc: 81.90%, Test Acc: 82.22%
Epoch:  20, loss: 0.2462, Train Acc: 89.52%, Test Acc: 86.67%
Epoch:  30, loss: 0.1027, Train Acc: 95.24%, Test Acc: 97.78%
Epoch:  40, loss: 0.0615, Train Acc: 95.24%, Test Acc: 97.78%
Epoch:  50, loss: 0.0491, Train Acc: 95.24%, Test Acc: 97.78%
Epoch:  60, loss: 0.0441, Train Acc: 95.24%, Test Acc: 97.78%
Epoch:  70, loss: 0.0414, Train Acc: 95.24%, Test Acc: 97.78%
Epoch:  80, loss: 0.0399, Train Acc: 95.24%, Test Acc: 100.00%
Epoch:  90, loss: 0.0390, Train Acc: 96.19%, Test Acc: 100.00%
