In [1]:
# Record the working directory, the active interpreter, and the installed
# flax / jax / orbax / optax package versions for this run.
!pwd
!which python
!pip freeze | grep -E 'flax|jax|orbax|optax'

/Users/tristantorchet/Desktop/Code/VSCode/LearningJAX/Flax
/Users/tristantorchet/Desktop/Code/VSCode/LearningJAX/.venv/bin/python
flax==0.8.0
jax==0.4.25
jaxlib==0.4.25
jaxtyping==0.2.36
optax==0.1.8
orbax-checkpoint==0.5.0


In [23]:
import jax
import numpy as np
import torch
from jax import numpy as jnp
from torch.utils.data import TensorDataset
from tqdm import tqdm
import torchvision
import torchvision.transforms as transforms
from flax import linen as nn
from jax.nn.initializers import lecun_normal
from typing import Any, Tuple, Sequence, Optional

# Array print options for the whole notebook: 3 digits of precision, and
# suppress=True to prefer fixed-point over scientific notation.
jnp.set_printoptions(precision=3, suppress=True)


In [24]:
# WARNING: this code is from QSSM project and won't be updated 
def create_mnist_classification_dataset(bsz=128, root="./data"):
    """Build train/test DataLoaders that serve MNIST images as pixel sequences.

    Every image goes through ``ToTensor``, is normalized with mean 0.5 and
    std 0.5, and is then viewed as ``(IN_DIM, SEQ_LENGTH)`` and transposed to
    ``(SEQ_LENGTH, IN_DIM)``. Batches are collated into NumPy arrays.

    Args:
        bsz: Batch size used by both loaders.
        root: Directory passed to ``torchvision.datasets.MNIST`` (data is
            downloaded there when missing).

    Returns:
        ``(trainloader, testloader, N_CLASSES, SEQ_LENGTH, IN_DIM)``. The
        train loader shuffles; the test loader does not.
    """
    print("[*] Generating MNIST Classification Dataset...")

    # Dataset constants: 784 pixels per image, 10 classes, 1 feature per pixel.
    SEQ_LENGTH, N_CLASSES, IN_DIM = 784, 10, 1

    def to_sequence(x):
        # View as (IN_DIM, SEQ_LENGTH), then transpose to (SEQ_LENGTH, IN_DIM).
        return x.view(IN_DIM, SEQ_LENGTH).t()

    tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=0.5, std=0.5),
        transforms.Lambda(to_sequence),
    ])

    # Same transform for both splits; only the `train` flag differs.
    train, test = (
        torchvision.datasets.MNIST(root, train=is_train, download=True, transform=tf)
        for is_train in (True, False)
    )

    def custom_collate_fn(batch):
        # Regroup the list of (image, label) samples into parallel columns
        # and hand them back as NumPy arrays: (images, labels).
        columns = tuple(zip(*batch))
        return np.array(columns[0]), np.array(columns[1])

    # Both loaders share the batch size and the NumPy collate function.
    loader_kwargs = dict(batch_size=bsz, collate_fn=custom_collate_fn)
    trainloader = torch.utils.data.DataLoader(train, shuffle=True, **loader_kwargs)
    testloader = torch.utils.data.DataLoader(test, shuffle=False, **loader_kwargs)

    return trainloader, testloader, N_CLASSES, SEQ_LENGTH, IN_DIM


In [25]:
# Build the MNIST loaders, reading/downloading the data under ../data.
(
    trainloader,
    testloader,
    N_CLASSES,
    SEQ_LENGTH,
    IN_DIM,
) = create_mnist_classification_dataset(root="../data")

[*] Generating MNIST Classification Dataset...


In [26]:
# Pull one batch from the test loader to inspect array shapes and the label dtype.
batch_x, batch_y = next(iter(testloader))
print(batch_x.shape, batch_y.shape)
print(batch_y.dtype)
# NOTE: no float conversion is actually performed in this cell; batch_y is
# untouched between the two dtype prints, so both report the same dtype.

print(batch_y.dtype)


(128, 784, 1) (128,)
int64
int64


In [27]:
class MLP(nn.Module):
    """Multi-layer perceptron: ReLU hidden layers followed by a linear output layer.

    Attributes:
        features: Layer widths. Every entry except the last defines a hidden
            ``Dense`` layer (named ``dense_<i>``); the last entry is the width
            of the final ``Dense`` layer (named ``output``).
    """

    features: Sequence[int]

    @nn.compact
    def __call__(self, x):
        """Run ``x`` through the hidden layers and return the output layer's result."""
        hidden_widths = self.features[:-1]
        for idx, width in enumerate(hidden_widths):
            x = nn.relu(nn.Dense(features=width, name=f"dense_{idx}")(x))
        # No activation after the last layer: callers receive raw outputs.
        return nn.Dense(features=self.features[-1], name="output")(x)

In [28]:
from flax.training import train_state
import optax


def create_train_state(key, model_cls, lr, features=(128, 64, 10), input_shape=(1, 784)):
    '''
    Create the training state for the model.

    Args:
        key: PRNG key used to initialise the model parameters.
        model_cls: Module class; instantiated as ``model_cls(features=...)``.
        lr: Learning rate for the Adam optimizer.
        features: Layer widths forwarded to ``model_cls``. Defaults to the
            previously hard-coded 128-64-10 layout.
        input_shape: Shape of the all-ones dummy input used to initialise the
            parameters. Defaults to the previously hard-coded ``(1, 784)``.

    Returns:
        A ``TrainState`` bundling ``model.apply``, the freshly initialised
        parameters and the Adam optimizer.
    '''
    # Imported locally on purpose: the notebook later rebinds the global name
    # `train_state` to the object returned by this function, which hides the
    # `flax.training.train_state` module. Looking the class up through that
    # global would make any second call to this function fail.
    from flax.training.train_state import TrainState

    model = model_cls(features=list(features))
    params = model.init(key, jnp.ones(input_shape))['params']
    # use adam
    optimizer = optax.adam(learning_rate=lr)
    return TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)

In [29]:
# Seed the PRNG, then split it: `subkey` is handed to create_train_state,
# while `key` is kept for later use.
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
lr = 1e-2
# NOTE(review): this rebinds `train_state`, the name under which the
# `flax.training.train_state` module was imported in the cell that defines
# create_train_state; from here on the name refers to the value returned by
# create_train_state instead of the module.
train_state = create_train_state(subkey, MLP, lr)
print(train_state.params.keys())


dict_keys(['dense_0', 'dense_1', 'output'])


In [30]:
# Show kernel and bias shapes of the first hidden layer and of the output layer.
for layer_name in ("dense_0", "output"):
    layer_params = train_state.params[layer_name]
    print(layer_params["kernel"].shape)
    print(layer_params["bias"].shape)

(784, 128)
(128,)
(64, 10)
(10,)


In [31]:
@jax.jit
def update_model(state, grads):
    """Return the state obtained by calling ``state.apply_gradients`` with ``grads`` (jit-compiled)."""
    new_state = state.apply_gradients(grads=grads)
    return new_state

In [32]:
@jax.jit
def apply_model(state, images, labels):
    """Computes gradients, loss and accuracy for a single batch.

    Jit-compiled so the forward and backward pass are traced once per input
    shape instead of being dispatched op by op on every batch.

    Args:
        state: Training state exposing ``apply_fn`` and ``params``.
        images: Batch of model inputs, passed unchanged to ``state.apply_fn``.
        labels: Integer class labels, one per example.

    Returns:
        ``(grads, loss, accuracy)``: the gradients of the mean softmax
        cross-entropy loss with respect to ``state.params``, the scalar loss,
        and the fraction of examples whose arg-max logit equals its label.
    """

    def loss_fn(params):
        logits = state.apply_fn({'params': params}, images)
        # Take the class count from the model output instead of hard-coding
        # 10, so the loss stays correct for any output width.
        one_hot = jax.nn.one_hot(labels, logits.shape[-1])
        loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
        return loss, logits

    # has_aux=True: loss_fn returns (loss, logits); only the loss is differentiated.
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(state.params)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    return grads, loss, accuracy
    

In [35]:
def run_epoch(state, train_dl, rng):
    """Train for a single epoch.

    Args:
        state: Training state; replaced by the updated state after every batch.
        train_dl: Iterable of ``(images, labels)`` batches. The trailing axis
            of ``images`` is squeezed away before the model is applied.
        rng: Accepted for interface compatibility; not used in this function.

    Returns:
        ``(state, train_loss, train_accuracy)`` where the two metrics are the
        unweighted means of the per-batch loss and accuracy.
    """
    losses, accuracies = [], []
    progress_bar = tqdm(train_dl, desc="Training", leave=True)

    for step, (batch_images, batch_labels) in enumerate(progress_bar, start=1):
        grads, loss, accuracy = apply_model(state, batch_images.squeeze(-1), batch_labels)
        state = update_model(state, grads)
        losses.append(loss)
        accuracies.append(accuracy)
        # Refresh the progress-bar readout only on every 100th batch.
        if step % 100 == 0:
            progress_bar.set_postfix(loss=loss.item(), accuracy=accuracy.item())

    return state, np.mean(losses), np.mean(accuracies)


In [36]:
# Number of passes over the training set.
NUM_EPOCHS = 30

for epoch in range(NUM_EPOCHS):
    # Derive a fresh key for each epoch instead of handing the same key to
    # every call, so any randomness run_epoch draws from it is not repeated.
    key, epoch_key = jax.random.split(key)
    train_state, train_loss, train_accuracy = run_epoch(train_state, trainloader, epoch_key)
    print(f"Epoch {epoch} | Loss: {train_loss} | Accuracy: {train_accuracy}")

Training: 100%|██████████| 469/469 [00:07<00:00, 62.73it/s, accuracy=0.953, loss=0.133]


Epoch 0 | Loss: 0.2339933216571808 | Accuracy: 0.9293543696403503


Training: 100%|██████████| 469/469 [00:07<00:00, 62.55it/s, accuracy=0.938, loss=0.24] 


Epoch 1 | Loss: 0.19879622757434845 | Accuracy: 0.9397488236427307


Training:  91%|█████████▏| 429/469 [00:06<00:00, 61.91it/s, accuracy=0.953, loss=0.143]


KeyboardInterrupt: 