In [22]:
import jax
import equinox as eqx
import equinox.nn as nn
import optax
from jax import numpy as jnp

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.utils.tensorboard import SummaryWriter
import warnings
warnings.simplefilter('ignore')  # ignore warnings
from pathlib import Path
import numpy as np
import os
import pickle
from typing import Callable, Optional, Tuple
from jaxtyping import PRNGKeyArray, PyTree, ArrayLike, Scalar, Array

In [18]:
batch_size = 128
tfms = transforms.Compose([transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))])
train_data = datasets.MNIST(root='./data',
    train=True, download=True, transform=tfms)
test_data = datasets.MNIST(
    root='./data', train=False, download=True, transform=tfms)

train_loader = DataLoader(
    dataset=train_data,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True
)
test_loader = DataLoader(
    dataset=test_data,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True
)

100%|██████████| 9.91M/9.91M [00:02<00:00, 3.50MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 156kB/s]
100%|██████████| 1.65M/1.65M [00:01<00:00, 1.20MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 4.34MB/s]


In [19]:
def get_act_fn(name: str) -> Callable:
    _ACT_FNS = [
        "linear", "tanh", "hard_tanh", "relu", "leaky_relu", "gelu", "selu", "silu"
    ]
    if name == "linear":
        return nn.Identity()
    elif name == "tanh":
        return jnp.tanh
    elif name == "hard_tanh":
        return jax.nn.hard_tanh
    elif name == "relu":
        return jax.nn.relu
    elif name == "leaky_relu":
        return jax.nn.leaky_relu
    elif name == "gelu":
        return jax.nn.gelu
    elif name == "selu":
        return jax.nn.selu
    elif name == "silu":
        return jax.nn.silu
    else:
        raise ValueError(f"""
                Invalid activation function ID. Options are {_ACT_FNS}.
        """)
    

def simple_make_mlp(key: PRNGKeyArray, 
        input_dim: int, 
        width: int,
        depth: int, 
        output_dim: int, 
        act_fn: str, 
        use_bias: bool = False,):
    subkeys = jax.random.split(key, depth + 1)
    layers = []
    for i in range(depth):
        act_fn_l = nn.Identity() if i == 0 else get_act_fn(act_fn)
        in_dim = input_dim if i == 0 else width
        out_dim = output_dim if i == depth - 1 else width

        linear = nn.Linear(in_dim, out_dim, use_bias=use_bias, key=subkeys[i])
        layers.append(eqx.nn.Sequential([nn.Lambda(act_fn_l), linear]))

    return layers

In [20]:
# debug
key = jax.random.PRNGKey(0)
toy_input = jax.random.normal(key, (1, 28*28))
mlp_layers = simple_make_mlp(key, 28*28, 512, 3, 10, "relu")
for layer in mlp_layers:
    vlayer = eqx.filter_vmap(layer, in_axes=0, out_axes=0)
    # vlayer = jax.vmap(layer, in_axes=0, out_axes=0)
    toy_input = vlayer(toy_input)
    print(toy_input.shape)

(1, 512)
(1, 512)
(1, 10)


In [26]:
@eqx.filter_jit
def one_step_forward(model: PyTree, x: ArrayLike) -> ArrayLike:
    for layer in model:
        vlayer = eqx.filter_vmap(layer, in_axes=0, out_axes=0)
        x = vlayer(x)
    return x

In [27]:
# Initialize model and optimizer state
key = jax.random.PRNGKey(42)
EPOCHS = 10
LR = 1e-3
optimizer = optax.adam(LR)

model = simple_make_mlp(key, 28*28, 512, 3, 10, "relu")
opt_state = optimizer.init(eqx.filter(model, eqx.is_array))

# Define loss function
def compute_loss(model, x, y):
    logits = one_step_forward(model, x)
    return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()

# Define update step
@eqx.filter_jit
def train_step(model, opt_state, x, y):
    loss, grads = eqx.filter_value_and_grad(compute_loss)(model, x, y)
    updates, opt_state = optimizer.update(grads, opt_state, eqx.filter(model, eqx.is_array))
    model = eqx.apply_updates(model, updates)
    return {'model': model, 
            'opt_state': opt_state, 
            'loss': loss,
            'grads': grads
        }


def validate(model, data_loader):
    total_loss = 0.0
    num_batches = 0
    preds = []
    ys = []
    for images, labels in data_loader:
        # Convert torch tensors to jax arrays
        x = jnp.array(images.numpy().reshape(batch_size, -1))
        y = jnp.array(labels.numpy())
        preds.append(one_step_forward(model, x).argmax(axis=-1))
        ys.append(y)
        loss = compute_loss(model, x, y)
        total_loss += loss
        num_batches += 1
    
    # compute accuracy
    acc = jnp.mean(jnp.concatenate(preds) == jnp.concatenate(ys))
    avg_loss = total_loss / num_batches
    return avg_loss, acc
# Training loop

for epoch in range(EPOCHS):
    total_loss = 0.0
    num_batches = 0
    
    for images, labels in train_loader:
        # Convert torch tensors to jax arrays
        x = jnp.array(images.numpy().reshape(batch_size, -1))
        y = jnp.array(labels.numpy())
        
        # Training step
        result = train_step(model, opt_state, x, y)
        total_loss += result['loss']
        model = result['model']
        opt_state = result['opt_state']
        num_batches += 1
    
    avg_loss, acc = validate(model, test_loader)
    print(f"Epoch {epoch + 1}/{EPOCHS}, Loss: {avg_loss:.4f}, Accuracy: {acc:.4f}")

print("Training complete!")


Epoch 1/10, Loss: 0.1039, Accuracy: 0.9667
Epoch 2/10, Loss: 0.0793, Accuracy: 0.9745
Epoch 2/10, Loss: 0.0793, Accuracy: 0.9745
Epoch 3/10, Loss: 0.0893, Accuracy: 0.9738
Epoch 3/10, Loss: 0.0893, Accuracy: 0.9738
Epoch 4/10, Loss: 0.0799, Accuracy: 0.9756
Epoch 4/10, Loss: 0.0799, Accuracy: 0.9756
Epoch 5/10, Loss: 0.0871, Accuracy: 0.9763
Epoch 5/10, Loss: 0.0871, Accuracy: 0.9763
Epoch 6/10, Loss: 0.0839, Accuracy: 0.9784
Epoch 6/10, Loss: 0.0839, Accuracy: 0.9784
Epoch 7/10, Loss: 0.0833, Accuracy: 0.9813
Epoch 7/10, Loss: 0.0833, Accuracy: 0.9813
Epoch 8/10, Loss: 0.0948, Accuracy: 0.9783
Epoch 8/10, Loss: 0.0948, Accuracy: 0.9783
Epoch 9/10, Loss: 0.0890, Accuracy: 0.9778
Epoch 9/10, Loss: 0.0890, Accuracy: 0.9778
Epoch 10/10, Loss: 0.0951, Accuracy: 0.9802
Training complete!
Epoch 10/10, Loss: 0.0951, Accuracy: 0.9802
Training complete!


In [25]:
model

[Sequential(
   layers=(
     Lambda(fn=Identity()),
     Linear(
       weight=f32[512,784],
       bias=None,
       in_features=784,
       out_features=512,
       use_bias=False
     )
   )
 ),
 Sequential(
   layers=(
     Lambda(fn=<PjitFunction of <function relu at 0x77d364e62ca0>>),
     Linear(
       weight=f32[512,512],
       bias=None,
       in_features=512,
       out_features=512,
       use_bias=False
     )
   )
 ),
 Sequential(
   layers=(
     Lambda(fn=<PjitFunction of <function relu at 0x77d364e62ca0>>),
     Linear(
       weight=f32[10,512],
       bias=None,
       in_features=512,
       out_features=10,
       use_bias=False
     )
   )
 )]

In [68]:
def apply_and_values(fs: list[Callable], x):
    f, *fs = fs
    # y = one_step_forward(f, x)
    y = f(x)
    return [y] + (apply_and_values(fs, y) if len(fs) > 0 else [])


In [69]:
from functools import partial

In [None]:
x, y = next(iter(test_loader))
x = jnp.array(images.numpy().reshape(batch_size, -1))
y = jnp.array(labels.numpy())
rst = apply_and_values(
    list(
        map(
            lambda l: partial(one_step_forward, l), 
            model
            )
            ), 
    x)

In [71]:
len(rst), len(model)

(3, 3)

In [72]:
a, *b = [1]

In [60]:
b

[]