LOADING DATASET

READING MNIST DATASET

In [1]:
import struct

from jax import numpy as jnp
import numpy as np

def load_mnist_images(filename):
    """Read an IDX3 image file and return its pixels as a uint8 array.

    The file starts with four big-endian unsigned 32-bit integers; the first
    is ignored here, the remaining three give the number of images, rows and
    columns. Everything after the header is read as raw uint8 pixel data.

    Args:
        filename: Path of the IDX3 file to read.

    Returns:
        A ``np.uint8`` array of shape ``(num, rows, cols)``.

    Raises:
        ValueError: From ``reshape`` if the payload size does not match the
            dimensions announced in the header.
    """
    header_format = '>IIII'
    with open(filename, 'rb') as fh:
        header = fh.read(struct.calcsize(header_format))
        _magic, count, height, width = struct.unpack(header_format, header)
        # np.fromfile continues from the current file position, i.e. right
        # after the header that was just consumed.
        pixels = np.fromfile(fh, dtype=np.uint8)
        return pixels.reshape(count, height, width)

def load_mnist_labels(filename):
    """Read an IDX1 label file and return its labels as a uint8 vector.

    The file starts with two big-endian unsigned 32-bit integers; the first
    is ignored here and the second is the number of labels. Everything after
    the header is read as one uint8 label per byte.

    Args:
        filename: Path of the IDX1 file to read.

    Returns:
        A 1-D ``np.uint8`` array with exactly as many entries as the header
        announces.

    Raises:
        ValueError: If the number of label bytes in the file differs from the
            count stored in the header (truncated or padded file).
    """
    with open(filename, 'rb') as f:
        _, num = struct.unpack('>II', f.read(8))
        labels = np.fromfile(f, dtype=np.uint8)

    # The image loader fails on a size mismatch via reshape; check the count
    # here too so a damaged label file cannot silently misalign with images.
    if labels.shape[0] != num:
        raise ValueError(
            f"{filename}: header announces {num} labels but file contains {labels.shape[0]}"
        )
    return labels

# Read the four raw IDX files: images and labels for the train and t10k sets
train_images, train_labels = (
    load_mnist_images('MNIST_DATASET/train-images.idx3-ubyte'),
    load_mnist_labels('MNIST_DATASET/train-labels.idx1-ubyte'),
)
test_images, test_labels = (
    load_mnist_images('MNIST_DATASET/t10k-images.idx3-ubyte'),
    load_mnist_labels('MNIST_DATASET/t10k-labels.idx1-ubyte'),
)

# Sanity check: report array shapes before any validation split is carved out
print(f"train_images shape: {train_images.shape}, train_labels shape: {train_labels.shape}")
print(f"test_images shape: {test_images.shape}, test_labels shape: {test_labels.shape}")

# Preparing train, test, val splits
# Hold out VAL_SIZE randomly chosen training examples for validation.
# The permutation is drawn from a seeded generator so the split (and therefore
# every downstream number) is reproducible under Restart Kernel -> Run All, and
# its length comes from the loaded data instead of a hard-coded 60000.
VAL_SIZE = 2000
SPLIT_SEED = 0
val_mask = np.random.default_rng(SPLIT_SEED).permutation(len(train_images))
val_images = train_images[val_mask[:VAL_SIZE]]
val_labels = train_labels[val_mask[:VAL_SIZE]]
train_images = train_images[val_mask[VAL_SIZE:]]
train_labels = train_labels[val_mask[VAL_SIZE:]]

# Hand every split over to JAX (conversion order unchanged: test, val, train)
test_images, test_labels = jnp.asarray(test_images), jnp.asarray(test_labels)
val_images, val_labels = jnp.asarray(val_images), jnp.asarray(val_labels)
train_images, train_labels = jnp.asarray(train_images), jnp.asarray(train_labels)

train_images shape: (60000, 28, 28), train_labels shape: (60000,)
test_images shape: (10000, 28, 28), test_labels shape: (10000,)


Our 2-Layer NN

In [2]:
from NN import NN
# Instantiate the network class imported from the local NN module.
# NOTE(review): p_keep is presumably the dropout keep-probability and reg the
# regularization strength (0 = off) — confirm against the NN implementation.
model = NN(p_keep = 0.7315, reg = 0)

TRAIN/TEST FUNCTIONS

In [17]:
import math

import jax

def train(model, num_epochs = 1, verbose = True, batch_size = 512):
    """Train `model` on the notebook-level training split and return the best checkpoint.

    The training arrays (`train_images`, `train_labels`) are split into
    roughly `batch_size`-sized chunks once, and the same chunks are visited in
    the same order every epoch. Each step runs a forward pass in train mode
    followed by `model.adamStep()`. Every 100th step of an epoch the model is
    scored on the current batch and on the validation split (`val_images`,
    `val_labels`); the best-scoring state is kept.

    Args:
        model: Network to optimise; it is updated in place.
        num_epochs: Number of passes over the training chunks.
        verbose: When true, print progress at every checkpoint.
        batch_size: Target number of examples per chunk (default 512, the
            value that used to be hard-coded).

    Returns:
        A deep copy of the model taken at the checkpoint with the highest
        validation accuracy (ties broken by batch training accuracy). The
        snapshot is taken right after validation, i.e. after `model.eval()`
        was called on it. If no checkpoint was ever reached, `model` itself
        is returned.

    Raises:
        ValueError: If `batch_size` is not positive.
    """
    import copy  # local import: only needed to snapshot the best checkpoint

    if batch_size < 1:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    total_iters = math.ceil(len(train_labels)/batch_size)
    train_images_loader = jnp.array_split(train_images, total_iters)
    train_labels_loader = jnp.array_split(train_labels, total_iters)

    best_model = model
    best_acc = (0, 0)  # (validation accuracy, batch training accuracy)

    for epoch in range(num_epochs):
        # `step` is 1-based so it can be used directly for the checkpoint test
        # and the progress message (and no longer shadows the builtin `iter`).
        for step, (x_train, y_train) in enumerate(zip(train_images_loader, train_labels_loader), start=1):
            model.train()
            logits, loss = model(x_train, y_train)
            model.adamStep()

            if step % 100 == 0:
                # Training accuracy comes from the train-mode logits of this batch.
                train_acc = testLogits(logits, y_train)
                val_acc = testInputs(model, val_images, val_labels)

                if verbose:
                    print(f"epoch: {epoch+1}  iter:{step}  loss:{loss:.3f}  acc: {train_acc:.2f}  val_acc: {val_acc:.2f}")

                if val_acc > best_acc[0] or (val_acc == best_acc[0] and train_acc > best_acc[1]):
                    # `model` keeps changing in place on later steps, so a plain
                    # reference would always end up pointing at the final
                    # weights; store an independent copy instead.
                    best_model = copy.deepcopy(model)
                    best_acc = (val_acc, train_acc)
                    if verbose:
                        print(f'BEST VAL: {val_acc:f}  TRAIN: {train_acc}')

    return best_model

@jax.jit
def testLogits(logits: jnp.ndarray, input_labels: jnp.ndarray):
    pred_labels = np.argmax(logits, axis = 1)
    train_acc = (pred_labels == input_labels).sum()/input_labels.shape[0]
    return train_acc

def testInputs(model: NN, input_data: jnp.ndarray, input_labels:jnp.ndarray):
    """Score `model` on a labelled dataset.

    Switches the model to eval mode (and leaves it there), runs a single
    forward pass over `input_data`, and returns the accuracy that
    `testLogits` computes for the resulting logits against `input_labels`.
    """
    model.eval()
    return testLogits(model(input_data), input_labels)

In [18]:
def dryRun():
    """Push a throwaway default-constructed NN through one train step and one evaluation.

    Uses the test split as input for a single forward pass in train mode, one
    `adamStep()`, and one `testInputs` call, then blocks until the accuracy is
    actually computed. Nothing is returned and the scratch model is discarded.

    NOTE(review): presumably a warm-up so JIT compilation cost is paid before
    the real training run — confirm intent.
    """
    scratch_model = NN()

    scratch_model.train()
    _logits, _loss = scratch_model(test_images, test_labels)
    scratch_model.adamStep()

    # Block so the asynchronously dispatched work has finished before returning.
    testInputs(scratch_model, test_images, test_labels).block_until_ready()

dryRun()

In [19]:
# Run 5 epochs of training and rebind `model` to the model that train() returns.
model = train(model, num_epochs = 5)

epoch: 1  iter:100  loss:0.093  acc: 0.99  val_acc: 0.98
BEST VAL: 0.978000  TRAIN: 0.9940944910049438
epoch: 2  iter:100  loss:0.075  acc: 0.99  val_acc: 0.98
epoch: 3  iter:100  loss:-0.000  acc: 1.00  val_acc: 0.98
epoch: 4  iter:100  loss:0.173  acc: 0.99  val_acc: 0.98
epoch: 5  iter:100  loss:0.087  acc: 0.99  val_acc: 0.98


In [12]:
# Final check: accuracy of the trained model on the t10k test split.
# NOTE(review): this cell's execution count (12) is lower than the training
# cell's (19), so the recorded output may come from an earlier model state —
# re-run top to bottom before trusting the number.
test_acc = testInputs(model, test_images, test_labels)
print(f"\nTEST ACC: {test_acc}")


TEST ACC: 0.9771999716758728
