<a href="https://colab.research.google.com/github/95-sanya-95/Summer_ML_internship/blob/main/MNIST_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
pip install jax jaxlib dm-haiku optax tensorflow-datasets

Collecting dm-haiku
  Downloading dm_haiku-0.0.12-py3-none-any.whl (371 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m371.7/371.7 kB[0m [31m6.4 MB/s[0m eta [36m0:00:00[0m
Collecting jmp>=0.0.2 (from dm-haiku)
  Downloading jmp-0.0.4-py3-none-any.whl (18 kB)
Installing collected packages: jmp, dm-haiku
Successfully installed dm-haiku-0.0.12 jmp-0.0.4


In [12]:
import jax.numpy as jnp
from sklearn.model_selection import train_test_split
from tensorflow.keras.datasets import mnist

# Load and preprocess the MNIST dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Normalize pixel values to range [0, 1] and convert to float32
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

# Split data into training and validation sets
x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=0.1, random_state=42)

# Convert numpy arrays to JAX arrays
x_train = jnp.array(x_train)
x_val = jnp.array(x_val)
y_train = jnp.array(y_train)
y_val = jnp.array(y_val)

print(x_train.shape)
print(y_train.shape)

print(type(x_train))
print(type(y_train))

(54000, 28, 28)
(54000,)
<class 'jaxlib.xla_extension.ArrayImpl'>
<class 'jaxlib.xla_extension.ArrayImpl'>


In [13]:
import jax
import haiku as hk

class MNIST_model(hk.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes

    def __call__(self, x):
        x = hk.Conv2D(output_channels=32, kernel_shape=3, stride=1, padding='SAME')(x)
        x = jax.nn.relu(x)
        x = hk.MaxPool(window_shape=2, strides=1, padding='SAME')(x)

        x = hk.Conv2D(output_channels=64, kernel_shape=3, stride=1, padding='SAME')(x) # increase the number of channels bcz dense layers learn more precise features
        x = jax.nn.relu(x)
        x = hk.MaxPool(window_shape=2, strides=1, padding='SAME')(x)

        x = hk.Flatten()(x) # converting the data into a single column

        x = hk.Linear(64)(x)
        x = jax.nn.relu(x)

        x = hk.Linear(64)(x)
        x = jax.nn.relu(x)

        x = hk.Linear(self.num_classes)(x)
        x = jax.nn.softmax(x)  # Apply softmax activation
        return x


In [14]:
def forward_fn(x):
    model = MNIST_model(num_classes = 10) # since there are 10 different numbers
    return model(x)

forward = hk.transform(forward_fn)

In [15]:
rng = jax.random.PRNGKey(42)
x_sample = x_train[:1]
params = forward.init(rng, x_sample)

In [21]:
print(y_train.shape)
print(y_train[90].shape) # just testing any random
print(y_train[120])

(54000,)
()
5


In [23]:
import math

def loss_fn(params, x, y):
    predictions = forward.apply(params, None, x)
    batch_sz = predictions.shape[0]
    sum = 0
    for i in range(batch_sz):
        label = y[i].astype(int)

        sum = sum + jnp.log(predictions[i][label])

    return -sum/batch_sz


In [30]:
print(x_test.shape)
print(y_test.shape)
print(x_train.shape)
print(y_train.shape)

(10000, 28, 28)
(10000,)
(54000, 28, 28)
(54000,)


In [28]:
import optax
# Initialize optimizer
optimizer = optax.adam(1e-3)
opt_state = optimizer.init(params)

@jax.jit
def update(params, opt_state, x, y):
    grads = jax.grad(loss_fn)(params, x, y)
    updates, opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)

    return new_params, opt_state

num_epochs = 1
batch_size = 64
num_batches = x_train.shape[0] # batch_size

for epoch in range(num_epochs):
    for batch_idx in range(num_batches):
        start_idx = batch_idx * batch_size
        end_idx = start_idx + batch_size
        x_batch = x_train[start_idx:end_idx]
        y_batch = y_train[start_idx:end_idx]

        params, opt_state = update(params, opt_state, x_batch, y_batch)

        if batch_idx % 100 == 0:
            prediction = forward.apply(params,None,x_batch)
            accuracy = jnp.mean(jnp.argmax(prediction, axis=-1) == y_batch)
            print("Accuracy: ")
            print(accuracy)
            train_loss = loss_fn(params, x_batch, y_batch)
            val_loss = loss_fn(params, x_val, y_val)


            print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}/{num_batches}, Train Loss: {train_loss}, Val Loss: {val_loss}")

test_loss = loss_fn(params, x_test, y_test)
print(f"Test Loss: {test_loss}")

Accuracy: 
1.0
Epoch 1/1, Batch 0/54000, Train Loss: 0.021468909457325935, Val Loss: 0.1571112424135208
Accuracy: 
1.0
Epoch 1/1, Batch 100/54000, Train Loss: 0.03524559363722801, Val Loss: 0.14852117002010345
Accuracy: 
0.96875
Epoch 1/1, Batch 200/54000, Train Loss: 0.13498039543628693, Val Loss: 0.15915970504283905
Accuracy: 
0.984375
Epoch 1/1, Batch 300/54000, Train Loss: 0.058817241340875626, Val Loss: 0.13168548047542572
Accuracy: 
0.9375
Epoch 1/1, Batch 400/54000, Train Loss: 0.22844427824020386, Val Loss: 0.13621315360069275
Accuracy: 
1.0
Epoch 1/1, Batch 500/54000, Train Loss: 0.03560644015669823, Val Loss: 0.11376401036977768
Accuracy: 
0.953125
Epoch 1/1, Batch 600/54000, Train Loss: 0.08900947123765945, Val Loss: 0.11280450224876404
Accuracy: 
0.984375
Epoch 1/1, Batch 700/54000, Train Loss: 0.05749135836958885, Val Loss: 0.12777554988861084
Accuracy: 
0.984375
Epoch 1/1, Batch 800/54000, Train Loss: 0.036359671503305435, Val Loss: 0.11517132818698883


ZeroDivisionError: division by zero

In [25]:
num_samples = 5  # Number of test cases to show predictions for
for i in range(num_samples):
    x_sample = x_test[i:i+1]
    y_true = y_test[i]
    logits = forward.apply(params, None, x_sample)
    prediction = jnp.argmax(logits, axis=-1)[0]
    print(f"Sample {i+1}: Prediction = {prediction}, True Label = {y_true}")

Sample 1: Prediction = 7, True Label = 7
Sample 2: Prediction = 2, True Label = 2
Sample 3: Prediction = 1, True Label = 1
Sample 4: Prediction = 0, True Label = 0
Sample 5: Prediction = 4, True Label = 4
