In [None]:
import typing as tp
import time

import equinox as eqx
from equinox import (
    nn,
    Module
)
import jax
from jax import (
    numpy as jnp,
    Array
)
import optax

import torch
import torchvision

# Record the library versions this notebook was executed with, so that results
# can be tied to a known JAX / Equinox combination.
print(
    f"| JAX: {jax.__version__} | Equinox: {eqx.__version__} |"
)

| JAX: 0.4.31 | Equinox: 0.11.4 |


In [2]:
# Load Fashion-MNIST through torchvision; `download=True` fetches the files into
# ./data when they are not already present.
trainset = torchvision.datasets.FashionMNIST(root='./data', train=True, download=True)
train_data = trainset.data.float()  # image tensor cast to floating point
y_train = trainset.targets

testset = torchvision.datasets.FashionMNIST(root='./data', train=False, download=True)
X_VAL = testset.data.float()
y_val = testset.targets

# Flatten every 28x28 image into a 784-vector and divide by 255
# (assumes the raw pixel values lie in 0..255 -- TODO confirm).
X_train = train_data.reshape((-1, 28*28))/255.
X_val = X_VAL.reshape((-1, 28*28))/255.

# Number of distinct labels as a plain int. `torch.unique` alone returns the
# tensor of label values, which is not a class count.
num_classes = len(torch.unique(y_train))

print(f"X_train.shape = {X_train.shape}\ny_train.shape = {y_train.shape}\nX_val.shape = {X_val.shape}\ny_val.shape = {y_val.shape}")

X_train.shape = torch.Size([60000, 784])
y_train.shape = torch.Size([60000])
X_val.shape = torch.Size([10000, 784])
y_val.shape = torch.Size([10000])


In [None]:
def shuffle(X, y): # (m, ...), (m, ...)
    """Apply one shared random permutation to the leading axis of ``X`` and ``y``.

    A single permutation of ``X.shape[0]`` indices is drawn with
    ``torch.randperm`` and used to index both inputs, so row ``i`` of the
    returned ``X`` still corresponds to row ``i`` of the returned ``y``.

    Returns:
        The pair ``(X_shuffled, y_shuffled)``.
    """
    num_rows = X.shape[0]
    permutation = torch.randperm(num_rows)
    shuffled_X = X[permutation]
    shuffled_y = y[permutation]
    return shuffled_X, shuffled_y

def batch(X, y, batch_size):
    """Shuffle ``X`` and ``y`` together and split them into equal-sized batches.

    The data is shuffled with ``shuffle`` first; any trailing samples that do
    not fill a complete batch are then dropped.

    Args:
        X: inputs of shape ``(m, ...)``, e.g. ``(m, 784)``.
        y: targets of shape ``(m, ...)``, e.g. integer labels of shape ``(m,)``.
        batch_size: number of samples per batch, ``B``.

    Returns:
        ``(X_batched, y_batched)`` with shapes ``(m//B, B, ...)``; for the
        example above ``(m//B, B, 784)`` and ``(m//B, B)``.

    Raises:
        ValueError: if ``batch_size`` is not positive or is larger than the
            number of samples (no complete batch could be formed).
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    num_batches = len(X) // batch_size
    if num_batches == 0:
        raise ValueError(
            f"batch_size ({batch_size}) exceeds the number of samples ({len(X)})"
        )

    X, y = shuffle(X, y) # (m, ...), (m, ...)

    # Keep only the samples that fill whole batches.
    num_kept = num_batches * batch_size
    X, y = X[:num_kept], y[:num_kept]

    # Split the leading axis into (num_batches, batch_size); trailing axes are
    # carried over unchanged.
    X = X.reshape((num_batches, batch_size) + tuple(X.shape[1:]))
    y = y.reshape((num_batches, batch_size) + tuple(y.shape[1:]))
    return X, y # (m//B, B, ...), (m//B, B, ...)

In [None]:
class config:
    """Hyper-parameters shared by the models and training loops of this notebook.

    Attributes set to ``...`` are placeholders that have not been chosen yet;
    they must be replaced with real numbers before the code that reads them is
    run.
    """
    batch_size:int = 32
    output_size:int = 10   # number of classes
    input_size:int = 784   # flattened 28x28 image

    epochs:int = 30
    cumbersome_model_lr:float = ...
    distilled_model_lr:float = ...

    # Distillation hyper-parameters, read by the distillation training loop as
    # `config.temperature` and `config.w`. `w` weights the soft-target loss and
    # (1 - w) the hard-target loss, so it must lie in [0, 1].
    temperature:float = ...
    w:float = ...

In [None]:
class CumbersomeModel(Module):
    """Fully connected network: Linear -> Dropout -> ReLU, twice, then a Linear head.

    The model acts on a single example of shape ``(input_size,)``; use
    ``jax.vmap`` to apply it to a batch.
    """
    # Equinox modules are dataclasses: every attribute assigned in `__init__`
    # has to be declared as a field at class level.
    linear1: nn.Linear
    dropout1: nn.Dropout
    linear2: nn.Linear
    dropout2: nn.Dropout
    linear3: nn.Linear

    def __init__(
        self, 
        key:Array, 
        input_size:int=784, 
        hidden_size:int=1200, 
        output_size:int=10,
        dropout_rates:tp.Sequence[float]=(0.2, 0.5)
    ):
        """Build the three linear layers and the two dropout layers.

        Args:
            key: PRNG key used to initialise the linear layers.
            input_size: length of the input vector.
            hidden_size: width of both hidden layers.
            output_size: number of output logits.
            dropout_rates: dropout probabilities after the first and the second
                linear layer, in that order.
        """
        k1, k2, k3 = jax.random.split(key, 3)

        # `nn.Dropout` has no parameters to initialise and takes no key at
        # construction; randomness is supplied per call instead.
        self.linear1 = nn.Linear(input_size, hidden_size, key=k1)
        self.dropout1 = nn.Dropout(dropout_rates[0])

        self.linear2 = nn.Linear(hidden_size, hidden_size, key=k2)
        self.dropout2 = nn.Dropout(dropout_rates[1])

        self.linear3 = nn.Linear(hidden_size, output_size, key=k3)
    
    def __call__(self, x:Array, training:bool=True, key:Array|None=None): # (784,)
        """Return the logits for one example.

        Args:
            x: input vector of shape ``(input_size,)``.
            training: when True dropout is active, when False it is disabled.
            key: PRNG key for the dropout masks. Needed when ``training`` is
                True and a dropout rate is non-zero (the dropout layers raise
                otherwise); ignored when dropout is inactive.

        Returns:
            Logits of shape ``(output_size,)``.
        """
        # Give each dropout layer its own key so the two masks are independent.
        if key is None:
            key1 = key2 = None
        else:
            key1, key2 = jax.random.split(key)

        x = self.dropout1(self.linear1(x), key=key1, inference=not training) # (1200,)
        x = jax.nn.relu(x)
        x = self.dropout2(self.linear2(x), key=key2, inference=not training) # (1200,)
        x = jax.nn.relu(x)
        x = self.linear3(x) # (10,)
        return x
    

class DistilledModel(CumbersomeModel):
    """Student network: the ``CumbersomeModel`` architecture with different defaults.

    By default the hidden layers are 800 units wide and both dropout rates are
    0.0; all other constructor arguments keep the parent's defaults.
    """
    # An immutable tuple is used as the default so no list object is shared
    # between calls.
    def __init__(self, key:Array, hidden_size:int=800, dropout_rates:tp.Sequence[float]=(0.0, 0.0)):
        """Forward ``key``, ``hidden_size`` and ``dropout_rates`` to the parent constructor."""
        super().__init__(key, hidden_size=hidden_size, dropout_rates=dropout_rates)

In [None]:
@eqx.filter_jit
def compute_accuracy(
    model: Module, 
    x: Array, # (B, 784)
    y: Array  # (B,)
):
    """Return the fraction of examples in ``x`` whose arg-max prediction equals ``y``.

    Args:
        model: per-example model, called as ``model(sample, training=False)``.
        x: batch of inputs, shape ``(B, 784)``.
        y: integer labels, shape ``(B,)``.

    Returns:
        Scalar array holding the mean of the element-wise matches.
    """
    # `jax.vmap` maps keyword arguments over axis 0 as well, so the Python flag
    # is bound in a closure instead of being passed through vmap.
    logits = jax.vmap(lambda sample: model(sample, training=False))(x) # (B, 10)
    pred_y = jnp.argmax(logits, axis=1) # (B,)
    return jnp.mean(y == pred_y)

@eqx.filter_jit
def soft_target_loss(
    target_logits:Array, # (B, 10)
    student_logits:Array, # (B, 10)
    temperature:float
):
    """Cross-entropy between temperature-softened target and student distributions.

    Both sets of logits are divided by ``temperature``. The target logits go
    through a softmax and the student logits through a log-softmax; the
    cross-entropy is summed over the last axis and averaged over axis 0.
    No additional scaling by the temperature is applied to the result.
    """
    softened_targets = jax.nn.softmax(target_logits / temperature)
    softened_student_log = jax.nn.log_softmax(student_logits / temperature)
    per_example = jnp.sum(softened_targets * softened_student_log, axis=-1) # (B,)
    return -per_example.mean(0)

@eqx.filter_jit
def hard_target_loss(
    targets:Array, # (B,)
    logits:Array # (B, C)
):
    """Cross-entropy between integer class labels and predicted logits.

    Args:
        targets: integer class labels, shape ``(B,)``.
        logits: unnormalised scores, shape ``(B, C)``.

    Returns:
        Scalar array: the cross-entropy summed over classes and averaged over
        axis 0.
    """
    # Take the number of classes from the logits themselves so the function
    # does not depend on a global setting matching the model's output width.
    num_classes = logits.shape[-1]
    target_probs = jax.nn.one_hot(targets, num_classes) # (B, C)
    log_probs = jax.nn.log_softmax(logits) # (B, C)
    return -jnp.sum(target_probs * log_probs, axis=-1).mean(0)

@eqx.filter_jit
def undistilled_model_loss(
    model:CumbersomeModel|DistilledModel, 
    inputs:Array, # (B, 784)
    targets:Array, # (B,)
    training:bool=True
):
    """Hard-label cross-entropy of ``model`` on one batch.

    Args:
        model: per-example model, called as ``model(sample, training=training)``.
        inputs: batch of inputs, shape ``(B, 784)``.
        targets: integer labels, shape ``(B,)``.
        training: forwarded to the model.

    Returns:
        The value of ``hard_target_loss(targets, logits)``.
    """
    # `jax.vmap` maps keyword arguments over axis 0 as well, so the Python flag
    # is bound in a closure instead of being passed through vmap.
    predict = jax.vmap(lambda sample: model(sample, training=training))(inputs)
    return hard_target_loss(targets, predict)

@eqx.filter_jit
def distilled_model_loss(
    model:DistilledModel,
    target_model:CumbersomeModel,
    hard_targets:Array, # (B,)
    inputs:Array, # (B, 784)
    temperature:float,
    w:float,
    training:bool=True
):
    """Weighted sum of the soft-target and hard-target losses for the student.

    Args:
        model: student, called as ``model(sample, training=training)``.
        target_model: teacher, always evaluated with ``training=False``; no
            gradient flows through its outputs.
        hard_targets: integer labels, shape ``(B,)``.
        inputs: batch of inputs, shape ``(B, 784)``.
        temperature: passed to ``soft_target_loss``.
        w: weight of the soft-target loss; the hard-target loss gets ``1 - w``.
        training: forwarded to the student.

    Returns:
        ``w * soft_target_loss + (1 - w) * hard_target_loss``.

    Raises:
        AssertionError: if ``w`` is outside ``[0, 1]``.
    """
    assert 0 <= w <= 1, "w must be in [0, 1]"
    # `jax.vmap` maps keyword arguments over axis 0 as well, so the Python flags
    # are bound in closures instead of being passed through vmap.
    target_logits = jax.lax.stop_gradient(
        jax.vmap(lambda sample: target_model(sample, training=False))(inputs)
    )
    student_logits = jax.vmap(lambda sample: model(sample, training=training))(inputs)
    return (
        w* soft_target_loss(target_logits, student_logits, temperature=temperature) +
        (1-w) * hard_target_loss(hard_targets, student_logits)
        )

In [None]:
def train_cumbersome_model(
    model:CumbersomeModel, X_train:Array, y_train:Array, X_val:Array, y_val:Array
):
    """Train ``model`` on hard labels with AdamW, printing metrics after every epoch.

    Args:
        model: the network to optimise.
        X_train, y_train: training inputs ``(m, 784)`` and integer labels ``(m,)``.
        X_val, y_val: validation inputs and labels with the same layout.
            All four are passed through ``batch`` and converted with
            ``.numpy()``, so torch tensors are expected here.

    Returns:
        The trained model.

    Notes:
        ``batch`` drops the samples that do not fill a complete batch, so the
        reported metrics are computed on those truncated sets.
    """
    X_train, y_train = batch(X_train, y_train, config.batch_size) # (m//B, B, 784), (m//B, B)
    X_val, y_val = batch(X_val, y_val, config.batch_size) # (m//B, B, 784), (m//B, B)

    # The jitted evaluation helpers work on JAX/NumPy arrays, not torch tensors.
    # Build flat NumPy copies once; the per-epoch shuffle only reorders batches,
    # which does not change full-dataset metrics.
    num_features = X_train.shape[-1]
    X_train_flat = X_train.reshape((-1, num_features)).numpy()
    y_train_flat = y_train.reshape(-1).numpy()
    X_val_flat = X_val.reshape((-1, num_features)).numpy()
    y_val_flat = y_val.reshape(-1).numpy()

    optim = optax.adamw(learning_rate=config.cumbersome_model_lr, weight_decay=1e-4)
    opt_state = optim.init(eqx.filter(model, eqx.is_array))

    @eqx.filter_jit
    def update(
        model:CumbersomeModel, 
        opt_state:optax.OptState, 
        inputs:Array, 
        targets:Array
    ):
        """One optimisation step on a single batch; returns the new model, state and loss."""
        # NOTE(review): no PRNG key reaches the model here; Equinox dropout with
        # a non-zero rate needs one in training mode -- confirm how it is supplied.
        loss, grads = eqx.filter_value_and_grad(undistilled_model_loss)(model, inputs, targets)
        updates, opt_state = optim.update(grads, opt_state, eqx.filter(model, eqx.is_array))
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss
    
    for epoch in range(1, config.epochs+1):
        # Indexing the batched tensors along axis 0 reorders whole batches.
        X_train, y_train = shuffle(X_train, y_train)
        losses = []
        t0 = time.time()
        #   (B, 784), (B,)
        for inputs, targets in zip(X_train, y_train):
            model, opt_state, loss = update(model, opt_state, inputs.numpy(), targets.numpy())
            losses.append(loss)
        
        train_accuracy = compute_accuracy(model, X_train_flat, y_train_flat)
        val_accuracy = compute_accuracy(model, X_val_flat, y_val_flat)

        mean_train_loss = jnp.array(losses).mean()
        mean_val_loss = undistilled_model_loss(model, X_val_flat, y_val_flat, training=False)
        dt = time.time() - t0 # includes the evaluation above
        print(
            f"Epoch: {epoch}/{config.epochs} | Train Loss: {mean_train_loss:.4f} | Validation Loss: {mean_val_loss:.4f} |"
            f"Train Accuracy: {train_accuracy:.4f} | Validation Accuracy: {val_accuracy:.4f} |"
            f"dt per epoch: {dt}s |"
        )
    
    return model

In [None]:
def train_distilled_model_without_cumbersome_model(
    model:DistilledModel, 
    X_train:Array, y_train:Array, X_val:Array, y_val:Array
):
    """Train the student on hard labels only (no teacher), printing metrics per epoch.

    Args:
        model: the student network to optimise.
        X_train, y_train: training inputs ``(m, 784)`` and integer labels ``(m,)``.
        X_val, y_val: validation inputs and labels with the same layout.
            All four are passed through ``batch`` and converted with
            ``.numpy()``, so torch tensors are expected here.

    Returns:
        The trained model.

    Notes:
        ``batch`` drops the samples that do not fill a complete batch, so the
        reported metrics are computed on those truncated sets.
    """
    X_train, y_train = batch(X_train, y_train, config.batch_size) # (m//B, B, 784), (m//B, B)
    X_val, y_val = batch(X_val, y_val, config.batch_size) # (m//B, B, 784), (m//B, B)

    # The jitted evaluation helpers work on JAX/NumPy arrays, not torch tensors.
    # Build flat NumPy copies once; the per-epoch shuffle only reorders batches,
    # which does not change full-dataset metrics.
    num_features = X_train.shape[-1]
    X_train_flat = X_train.reshape((-1, num_features)).numpy()
    y_train_flat = y_train.reshape(-1).numpy()
    X_val_flat = X_val.reshape((-1, num_features)).numpy()
    y_val_flat = y_val.reshape(-1).numpy()

    optim = optax.adamw(learning_rate=config.distilled_model_lr, weight_decay=0.0)
    opt_state = optim.init(eqx.filter(model, eqx.is_array))

    @eqx.filter_jit
    def update(
        model:DistilledModel, 
        opt_state:optax.OptState, 
        inputs:Array, 
        targets:Array
    ):
        """One optimisation step on a single batch; returns the new model, state and loss."""
        # NOTE(review): no PRNG key reaches the model here; Equinox dropout with
        # a non-zero rate needs one in training mode -- confirm how it is supplied.
        loss, grads = eqx.filter_value_and_grad(undistilled_model_loss)(model, inputs, targets)
        updates, opt_state = optim.update(grads, opt_state, eqx.filter(model, eqx.is_array))
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss
    
    for epoch in range(1, config.epochs+1):
        # Indexing the batched tensors along axis 0 reorders whole batches.
        X_train, y_train = shuffle(X_train, y_train)
        losses = []
        t0 = time.time()
        #   (B, 784), (B,)
        for inputs, targets in zip(X_train, y_train):
            model, opt_state, loss = update(model, opt_state, inputs.numpy(), targets.numpy())
            losses.append(loss)
        
        train_accuracy = compute_accuracy(model, X_train_flat, y_train_flat)
        val_accuracy = compute_accuracy(model, X_val_flat, y_val_flat)

        mean_train_loss = jnp.array(losses).mean()
        mean_val_loss = undistilled_model_loss(model, X_val_flat, y_val_flat, training=False)
        dt = time.time() - t0 # includes the evaluation above
        print(
            f"Epoch: {epoch}/{config.epochs} | Train Loss: {mean_train_loss:.4f} | Validation Loss: {mean_val_loss:.4f} |"
            f"Train Accuracy: {train_accuracy:.4f} | Validation Accuracy: {val_accuracy:.4f} |"
            f"dt per epoch: {dt}s |"
        )
    return model

In [None]:
def train_distilled_model_with_cumbersome_model(
    model:DistilledModel, cumbersom_model:CumbersomeModel,
    X_train:Array, y_train:Array, X_val:Array, y_val:Array
):
    """Train the student against a teacher with ``distilled_model_loss``.

    Args:
        model: the student network to optimise.
        cumbersom_model: the teacher whose logits provide the soft targets; it
            is not updated.
        X_train, y_train: training inputs ``(m, 784)`` and integer labels ``(m,)``.
        X_val, y_val: validation inputs and labels with the same layout.
            All four are passed through ``batch`` and converted with
            ``.numpy()``, so torch tensors are expected here.

    Returns:
        The trained student model.

    Notes:
        The temperature and the soft/hard weight are read from
        ``config.temperature`` and ``config.w``. ``batch`` drops the samples
        that do not fill a complete batch, so the reported metrics are computed
        on those truncated sets.
    """
    X_train, y_train = batch(X_train, y_train, config.batch_size) # (m//B, B, 784), (m//B, B)
    X_val, y_val = batch(X_val, y_val, config.batch_size) # (m//B, B, 784), (m//B, B)

    # The jitted evaluation helpers work on JAX/NumPy arrays, not torch tensors.
    # Build flat NumPy copies once; the per-epoch shuffle only reorders batches,
    # which does not change full-dataset metrics.
    num_features = X_train.shape[-1]
    X_train_flat = X_train.reshape((-1, num_features)).numpy()
    y_train_flat = y_train.reshape(-1).numpy()
    X_val_flat = X_val.reshape((-1, num_features)).numpy()
    y_val_flat = y_val.reshape(-1).numpy()

    optim = optax.adamw(learning_rate=config.distilled_model_lr, weight_decay=0.0)
    opt_state = optim.init(eqx.filter(model, eqx.is_array))

    @eqx.filter_jit
    def update(
        model:DistilledModel, 
        teacher:CumbersomeModel,
        opt_state:optax.OptState, 
        inputs:Array, 
        targets:Array
    ):
        """One optimisation step on a single batch; returns the new student, state and loss."""
        # The teacher is an explicit argument so its weights are traced inputs
        # rather than values captured by the jitted closure.
        # `distilled_model_loss` expects (model, target_model, hard_targets,
        # inputs, temperature, w): labels come before inputs.
        loss, grads = eqx.filter_value_and_grad(distilled_model_loss)(
            model, teacher, targets, inputs, config.temperature, config.w
        )
        updates, opt_state = optim.update(grads, opt_state, eqx.filter(model, eqx.is_array))
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss
    
    for epoch in range(1, config.epochs+1):
        # Indexing the batched tensors along axis 0 reorders whole batches.
        X_train, y_train = shuffle(X_train, y_train)
        losses = []
        t0 = time.time()
        #   (B, 784), (B,)
        for inputs, targets in zip(X_train, y_train):
            model, opt_state, loss = update(
                model, cumbersom_model, opt_state, inputs.numpy(), targets.numpy()
            )
            losses.append(loss)
        
        train_accuracy = compute_accuracy(model, X_train_flat, y_train_flat)
        val_accuracy = compute_accuracy(model, X_val_flat, y_val_flat)

        mean_train_loss = jnp.array(losses).mean()
        # Same loss as in training, evaluated on the validation set with the
        # student in inference mode.
        mean_val_loss = distilled_model_loss(
            model, cumbersom_model, y_val_flat, X_val_flat,
            config.temperature, config.w, training=False
        )
        dt = time.time() - t0 # includes the evaluation above
        print(
            f"Epoch: {epoch}/{config.epochs} | Train Loss: {mean_train_loss:.4f} | Validation Loss: {mean_val_loss:.4f} |"
            f"Train Accuracy: {train_accuracy:.4f} | Validation Accuracy: {val_accuracy:.4f} |"
            f"dt per epoch: {dt}s |"
        )
    return model

In [None]:
# TODO: Train the three models, plot their results, and compare them. This notebook is not yet complete.