In [20]:
import jax
import numpy as np
import optax
import torch
from flax import linen as nn
from flax.training import train_state
from jax import random, numpy as jnp
from jax.tree_util import tree_map
from torch.utils.data import DataLoader, random_split, default_collate
from torchvision import datasets, transforms

In [21]:
class LeNet(nn.Module):
    """LeNet-style CNN: two conv/sigmoid/avg-pool stages and three dense layers.

    Input is channels-last, ``(batch, height, width, channels)``. For a
    ``(batch, 28, 28, 1)`` input the feature map evolves as
    28x28x6 -> 14x14x6 -> 10x10x16 -> 5x5x16, i.e. 400 features per sample.

    Returns:
        Unnormalised class scores (logits) of shape ``(batch, 10)``.
    """

    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=6, kernel_size=(5, 5), padding="SAME")(x)
        x = nn.sigmoid(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=16, kernel_size=(5, 5), padding="VALID")(x)
        x = nn.sigmoid(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        # Flatten each sample's (H, W, C) feature map into one row. The width
        # is derived from the feature map instead of hard-coding 400: with a
        # fixed 400 a different image size could still reshape "successfully"
        # by spilling one sample's features into its neighbours' rows.
        features_per_sample = x.shape[-3] * x.shape[-2] * x.shape[-1]
        x = x.reshape((-1, features_per_sample))
        x = nn.Dense(120)(x)
        x = nn.sigmoid(x)
        x = nn.Dense(84)(x)
        x = nn.sigmoid(x)
        # No softmax here: the loss functions in this file consume raw logits.
        x = nn.Dense(10)(x)
        return x
    
def create_train_state(rng, learning_rate):
    """Build a freshly initialised LeNet together with an Adam optimizer.

    Args:
        rng: JAX PRNG key used to initialise the network weights.
        learning_rate: step size handed to ``optax.adam``.

    Returns:
        A ``TrainState`` whose ``apply_fn`` is the model's ``apply`` and whose
        ``params`` is the complete variable collection returned by ``init``
        (which is why callers invoke ``state.apply_fn(state.params, x)``).
    """
    net = LeNet()
    # Only the shape of the dummy batch matters for initialisation: one
    # 28x28 image with the channel dimension last.
    dummy_batch = jnp.ones((1, 28, 28, 1))
    variables = net.init(rng, dummy_batch)
    return train_state.TrainState.create(
        apply_fn=net.apply,
        params=variables,
        tx=optax.adam(learning_rate),
    )

In [25]:
def numpy_collate(batch):
    """Collate samples with torch's default collation, then convert to NumPy.

    Args:
        batch: list of samples as produced by a torch ``Dataset``.

    Returns:
        The structure returned by ``default_collate`` with every leaf passed
        through ``np.asarray``.
    """
    collated = default_collate(batch)
    # tree_map keeps the nesting of the collated batch and converts only the leaves.
    return tree_map(np.asarray, collated)
    
class NumpyLoader(DataLoader):
    """``DataLoader`` that always collates batches with ``numpy_collate``.

    Accepts the same arguments as the subset of ``DataLoader`` options listed
    in ``__init__``; ``collate_fn`` is fixed and cannot be overridden.
    """

    def __init__(self, dataset, batch_size=1,
                shuffle=False, sampler=None,
                batch_sampler=None, num_workers=0,
                pin_memory=False, drop_last=False,
                timeout=0, worker_init_fn=None):
        # Zero-argument super() on purpose: super(self.__class__, self) resolves
        # relative to the *runtime* class, so any subclass of NumpyLoader would
        # re-enter this __init__ forever and die with RecursionError.
        super().__init__(dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            sampler=sampler,
            batch_sampler=batch_sampler,
            num_workers=num_workers,
            collate_fn=numpy_collate,
            pin_memory=pin_memory,
            drop_last=drop_last,
            timeout=timeout,
            worker_init_fn=worker_init_fn)

# NOTE: the images must not be flattened -- the Flax conv layers expect input
# of shape (batch_size, 28, 28, 1). This is a difference between the torch conv
# and the Flax conv: Flax needs the extra channel dimension at the end.
class Cast:
    """Callable transform that turns its input into a float32 NumPy array."""

    def __call__(self, pic):
        float_pixels = np.array(pic, dtype=jnp.float32)
        return float_pixels

def add_channel_dim(image):
    """Return ``image`` with a new length-1 axis appended as the last dimension."""
    with_channel = np.expand_dims(image, axis=-1)
    return with_channel
    
def to_channels_last(image):
    """Move the leading channel axis of a (C, H, W) array to the end: (H, W, C)."""
    return np.moveaxis(image, 0, -1)


# ToTensor already emits a channel axis -- an MNIST digit comes out as
# (1, 28, 28) -- so appending another axis would give (1, 28, 28, 1) samples
# and 5-D batches. Moving the existing channel axis to the end yields the
# (28, 28, 1) samples, and therefore (batch_size, 28, 28, 1) batches, that the
# model is initialised with.
transform = transforms.Compose([
    transforms.ToTensor(),  # PIL image -> float tensor in [0, 1], channels first
    Cast(),                 # torch tensor -> float32 NumPy array, same layout
    to_channels_last,       # (1, 28, 28) -> (28, 28, 1)
])

"""
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_dataset, val_dataset = random_split(train_dataset, [55000, 5000])
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)
"""

# Data/configuration knobs for this cell.
DATA_ROOT = '../mnistMLP/data'  # directory the MNIST files are read from / downloaded to
VAL_SIZE = 5000                 # samples held out from the training set for validation
BATCH_SIZE = 32
DATA_SEED = 0

# random_split and the shuffling sampler both draw from torch's global RNG by
# default; seeding it here makes the train/val split and the batch order
# repeatable on a fresh "Restart & Run All".
torch.manual_seed(DATA_SEED)

full_train_dataset = datasets.MNIST(root=DATA_ROOT, train=True, download=True, transform=transform)
# Split sizes are derived from the dataset length (55000 / 5000 for the
# 60000-image MNIST training set) rather than hard-coded.
train_dataset, val_dataset = random_split(
    full_train_dataset, [len(full_train_dataset) - VAL_SIZE, VAL_SIZE]
)
test_dataset = datasets.MNIST(root=DATA_ROOT, train=False, transform=transform)

# Only the training loader shuffles; evaluation order does not matter.
train_loader = NumpyLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = NumpyLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = NumpyLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [26]:
# A fixed PRNG key gives the same parameter initialisation on every run.
rng = random.PRNGKey(12)
state = create_train_state(rng, 1e-3)

def validate(state, val_loader):
    """Evaluate ``state`` on every batch of ``val_loader`` and print the metrics.

    Args:
        state: train state providing ``apply_fn`` and ``params``.
        val_loader: iterable of ``(data, target)`` batches; ``target`` holds
            integer class labels in ``[0, 10)``.

    Prints the sample-weighted average cross-entropy loss and the accuracy in
    percent. If the loader yields no samples, a notice is printed instead of
    failing with a division by zero. Returns ``None``.
    """
    val_loss = 0.0
    correct = 0
    total_samples = 0
    for data, target in val_loader:
        batch_size = data.shape[0]
        target_one_hot = jax.nn.one_hot(target, num_classes=10)
        logits = state.apply_fn(state.params, data)
        loss = optax.softmax_cross_entropy(logits, target_one_hot).mean()
        # Re-weight the per-batch mean by the batch size so a smaller final
        # batch does not skew the overall average.
        val_loss += loss.item() * batch_size
        # Class axis is the last one; int(...) keeps the running count a plain
        # Python integer instead of a device array.
        predicted_class = jnp.argmax(logits, axis=-1)
        correct += int(jnp.sum(predicted_class == target))
        total_samples += batch_size
    if total_samples == 0:
        print('Validation set: no samples to evaluate')
        return
    val_loss /= total_samples
    accuracy = 100. * correct / total_samples
    print(f'Validation set: Average loss: {val_loss:.4f}, Accuracy: {accuracy:.2f}%')

@jax.jit
def train_step(state, batch):
    """Run one optimisation step on a single batch.

    Args:
        state: current train state (``apply_fn``, ``params``, optimizer state).
        batch: ``(data, target)`` pair; ``target`` holds integer class labels.

    Returns:
        ``(new_state, loss)`` where ``loss`` is the mean softmax cross-entropy
        of the batch evaluated with the parameters *before* the update.
    """
    images, labels = batch
    one_hot_labels = jax.nn.one_hot(labels, num_classes=10)

    def mean_cross_entropy(params):
        predictions = state.apply_fn(params, images)
        return optax.softmax_cross_entropy(predictions, one_hot_labels).mean()

    # Differentiate w.r.t. the parameters only; images and labels are closed over.
    loss, grads = jax.value_and_grad(mean_cross_entropy)(state.params)
    updated_state = state.apply_gradients(grads=grads)
    return updated_state, loss

def train(epoch, state, log_interval=200, loader=None):
    """Train for one pass over a data loader, logging the loss periodically.

    Args:
        epoch: epoch number, used only in the log lines.
        state: train state to start from.
        log_interval: print the current batch loss every this many batches.
        loader: iterable of ``(data, target)`` batches. Defaults to the
            module-level ``train_loader`` so existing calls keep working, but
            can be passed explicitly to avoid relying on global state.

    Returns:
        The train state after the last batch (the input state unchanged if the
        loader is empty).
    """
    if loader is None:
        loader = train_loader
    for batch_idx, batch in enumerate(loader):
        state, loss = train_step(state, batch)
        if batch_idx % log_interval == 0:
            # .item() converts the device scalar to a Python float for formatting.
            loss_val = loss.item() if hasattr(loss, 'item') else loss
            print(f'Logging Interval - Train Epoch: {epoch} Batch Idx: {batch_idx} \tLoss: {loss_val:.6f}')
    return state

num_epochs = 5

# One full pass over the training data per epoch; epochs are numbered from 1
# so the log lines read naturally. Per-epoch validation is currently disabled.
for epoch in range(1, num_epochs + 1):
    state = train(epoch=epoch, state=state)
    # validate(state, val_loader)

# Final evaluation on the held-out test split. The printed label still reads
# "Validation set" because validate() hard-codes that text.
validate(state=state, val_loader=test_loader)

Logging Interval - Train Epoch: 1 Batch Idx: 0 	Loss: 2.328917
Logging Interval - Train Epoch: 1 Batch Idx: 200 	Loss: 2.051440
Logging Interval - Train Epoch: 1 Batch Idx: 400 	Loss: 0.629653
Logging Interval - Train Epoch: 1 Batch Idx: 600 	Loss: 0.301854
Logging Interval - Train Epoch: 1 Batch Idx: 800 	Loss: 0.496858
Logging Interval - Train Epoch: 1 Batch Idx: 1000 	Loss: 0.345888
Logging Interval - Train Epoch: 1 Batch Idx: 1200 	Loss: 0.258772
Logging Interval - Train Epoch: 1 Batch Idx: 1400 	Loss: 0.139123
Logging Interval - Train Epoch: 1 Batch Idx: 1600 	Loss: 0.376808
Logging Interval - Train Epoch: 2 Batch Idx: 0 	Loss: 0.082918
Logging Interval - Train Epoch: 2 Batch Idx: 200 	Loss: 0.178216
Logging Interval - Train Epoch: 2 Batch Idx: 400 	Loss: 0.113553
Logging Interval - Train Epoch: 2 Batch Idx: 600 	Loss: 0.105725
Logging Interval - Train Epoch: 2 Batch Idx: 800 	Loss: 0.131342
Logging Interval - Train Epoch: 2 Batch Idx: 1000 	Loss: 0.122676
Logging Interval - Train