In [1]:
# Sanity check: show which Python interpreter (conda env) the kernel runs on and the
# working directory -- config.yaml is loaded below via a path relative to it.
!which python
!pwd

/home/tristan/miniconda3/envs/.jax_conda_env_MML_2/bin/python
/home/tristan/ModernML/S4


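# Random SSM
This notebook trains an S4 model (the PyTorch `S4Model`, built from S4D layers) on a sequence task. It was written for sin(x) or sin(ax+b) sequences, but the dataset cell currently hard-codes `cifar-gs-classification`: CIFAR-10 classification, with each image fed as a length-1024 sequence with one input channel.
The hyperparameters are defined in the `config.yaml` file:
- N (hidden state dimension, `model.layer.N`): 64
- H (`d_model` in the yaml, the number of independent SSM channels / "heads"): 10 in the original description. The model printed below uses 128, so check `config.yaml` for the current value.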

In [1]:
import os

# Pin this process to physical GPU 2. CUDA_VISIBLE_DEVICES is only honoured if it
# is set before any library initialises CUDA, so set it before importing jax/torch
# and the project modules (in case any of them touches the GPU at import time).
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

import jax
import jax.numpy as jnp
import numpy as np

from data import Datasets
from model import BatchStackedModel, MultiHeadSSMLayer, MultiHeadS4DLayer
import torch
import torchsummary
from utils import cross_entropy_loss, compute_accuracy
from tqdm import tqdm

# Sanity check: is a CUDA device visible to PyTorch?
torch.cuda.is_available()

True

In [2]:
# Placeholder JAX key with a fixed seed. It is re-bound by the seeded
# `jax.random.split(...)` in the randomness cell, so cfg.seed is what finally counts.
PLACEHOLDER_SEED = 1
rng = jax.random.PRNGKey(PLACEHOLDER_SEED)

In [3]:
from omegaconf import OmegaConf

# All run hyperparameters come from this YAML file, resolved relative to the
# kernel's current working directory.
CONFIG_PATH = "config.yaml"
cfg = OmegaConf.load(CONFIG_PATH)

In [4]:
# Unpack the top-level config sections. Concrete values live in config.yaml --
# read them there rather than trusting hard-coded numbers in comments
# (the run below, for instance, shows bsz=64 and lr=0.01).
# NOTE: `dataset` is overridden by a hard-coded name in the dataset cell.
dataset, layer, seed = cfg.dataset, cfg.layer, cfg.seed  # dataset name, layer type, RNG seed
model_p = cfg.model  # DictConfig: d_model, n_layers, dropout, prenorm, embedding, layer.N
train_p = cfg.train  # DictConfig: epochs, bsz, lr, lr_schedule, weight_decay, checkpoint, suffix, sample

In [5]:
# Seed every RNG the notebook may touch so that runs are reproducible.
import random

print("[*] Setting Randomness...")
random.seed(seed)        # Python's global RNG (e.g. data augmentation helpers)
np.random.seed(seed)     # NumPy's legacy global RNG
torch.manual_seed(seed)  # model initialisation (next cells) and DataLoader shuffle order
key = jax.random.PRNGKey(seed)
# Split the root key: `key` stays the parent for any further splits, while
# `rng` / `train_rng` are independent sub-keys for JAX code.
# NOTE(review): neither sub-key appears to be used by the PyTorch training path below.
key, rng, train_rng = jax.random.split(key, num=3)


[*] Setting Randomness...


In [6]:
# Each dataset comes in two variants, e.g. "mnist" and "mnist-classification";
# the "classification" suffix selects class-label targets.
# The config's `dataset` is intentionally overridden here -- set
# DATASET_OVERRIDE = None to train on cfg.dataset instead.
DATASET_OVERRIDE = 'cifar-gs-classification'
if DATASET_OVERRIDE is not None:
    if DATASET_OVERRIDE != dataset:
        # Make the divergence from config.yaml visible instead of silent.
        print(f"[!] Overriding config dataset {dataset!r} with {DATASET_OVERRIDE!r}")
    dataset = DATASET_OVERRIDE

classification = "classification" in dataset
print(f'{classification=}')

create_dataset_fn = Datasets[dataset]
trainloader, testloader, n_classes, l_max, d_input = create_dataset_fn(bsz=train_p.bsz)
print(f'{n_classes=}, {l_max=}, {d_input=}')

# Peek at one batch from each loader to confirm the (batch, length, channels) layout.
print(f'{next(iter(trainloader))[0].shape=}')
print(f'{next(iter(testloader))[0].shape=}')

classification=True
[*] Generating CIFAR-10 Classification Dataset
Files already downloaded and verified
Files already downloaded and verified
n_classes=10, l_max=1024, d_input=1
next(iter(trainloader))[0].shape=torch.Size([64, 1024, 1])
next(iter(testloader))[0].shape=torch.Size([64, 1024, 1])


In [7]:
from model_torch import S4Model

In [8]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Sizes come from the dataset (input channels, number of classes) and from config.yaml.
# NOTE(review): model.n_layers / model.prenorm from the config are not forwarded;
# the printout below shows 4 layers, so S4Model's defaults presumably match -- confirm.
model_kwargs = dict(
    d_input=d_input,
    d_output=n_classes,
    d_model=model_p.d_model,
    d_state=model_p.layer.N,
    lr=train_p.lr,
    dropout=model_p.dropout,
)
model = S4Model(**model_kwargs).to(device)

In [9]:
# Show the module tree: encoder, stacked S4D layers, norms, dropouts, decoder.
model_repr = str(model)
print(model_repr)

S4Model(
  (encoder): Linear(in_features=1, out_features=128, bias=True)
  (s4_layers): ModuleList(
    (0-3): 4 x S4D(
      (kernel): S4DKernel()
      (activation): GELU(approximate='none')
      (dropout): DropoutNd()
      (output_linear): Sequential(
        (0): Conv1d(128, 256, kernel_size=(1,), stride=(1,))
        (1): GLU(dim=-2)
      )
    )
  )
  (norms): ModuleList(
    (0-3): 4 x LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  )
  (dropouts): ModuleList(
    (0-3): 4 x Dropout1d(p=0.1, inplace=False)
  )
  (decoder): Linear(in_features=128, out_features=10, bias=True)
)


In [10]:
# torchsummary.summary(model, (1,1,2,))

In [11]:
# Compute device: CUDA if available (restricted to GPU 2 via CUDA_VISIBLE_DEVICES,
# set in the imports cell), otherwise CPU. Re-resolved here so this cell can be
# re-run on its own; moving the model again is a no-op if it is already there.
# (The optimizer, scheduler and loss are built by `setup_optimizer` in the next cell.)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'{device=}')
model = model.to(device)

device=device(type='cuda')


In [12]:
def setup_optimizer(model, lr, weight_decay, epochs):
    """Build the AdamW optimizer and cosine LR schedule used for S4 training.

    The S4 layer (A, B, C, dt) parameters typically want a smaller learning
    rate (e.g. 0.001) and no weight decay. Such parameters carry an ``_optim``
    dict, and every distinct dict becomes its own param group. All other
    parameters go into the default group, trained with ``lr`` and
    ``weight_decay``.

    Args:
        model: ``torch.nn.Module`` whose parameters are optimized.
        lr: learning rate of the default parameter group.
        weight_decay: AdamW weight decay of the default parameter group.
        epochs: ``T_max`` of the cosine annealing schedule.

    Returns:
        ``(optimizer, scheduler)``: a ``torch.optim.AdamW`` and a
        ``torch.optim.lr_scheduler.CosineAnnealingLR``. One summary line per
        param group is printed as a side effect.
    """
    every_param = list(model.parameters())

    # Partition once: untagged params vs. the _optim dicts of tagged params.
    plain_params = []
    tagged_specs = []
    for param in every_param:
        if hasattr(param, "_optim"):
            tagged_specs.append(getattr(param, "_optim"))
        else:
            plain_params.append(param)

    optimizer = torch.optim.AdamW(plain_params, lr=lr, weight_decay=weight_decay)

    # De-duplicate the hyperparameter dicts (frozensets are hashable and
    # dict.fromkeys keeps first-seen order), then give each its own group.
    unique_specs = dict.fromkeys(frozenset(spec.items()) for spec in tagged_specs)
    group_specs = [dict(items) for items in sorted(unique_specs)]
    for spec in group_specs:
        members = [param for param in every_param if getattr(param, "_optim", None) == spec]
        optimizer.add_param_group({"params": members, **spec})

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    # One summary line per param group, listing the special hyperparameters.
    spec_keys = sorted({name for spec in group_specs for name in spec})
    for idx, group in enumerate(optimizer.param_groups):
        fields = [f"Optimizer group {idx}", f"{len(group['params'])} tensors"]
        fields.extend(f"{name} {group.get(name, None)}" for name in spec_keys)
        print(" | ".join(fields))

    return optimizer, scheduler

# Loss over class logits, plus the two-group AdamW optimizer and cosine schedule.
criterion = torch.nn.CrossEntropyLoss()
optim_kwargs = dict(
    lr=train_p.lr,
    weight_decay=train_p.weight_decay,
    epochs=train_p.epochs,
)
optimizer, scheduler = setup_optimizer(model, **optim_kwargs)

Optimizer group 0 | 28 tensors | lr 0.01 | weight_decay 0.01
Optimizer group 1 | 12 tensors | lr 0.001 | weight_decay 0.0


In [13]:
# Shapes of the tensors in the S4-specific parameter group (see the group summary above).
special_group = optimizer.param_groups[1]
for tensor in special_group['params']:
    print(tensor.shape)

torch.Size([128])
torch.Size([128, 32])
torch.Size([128, 32])
torch.Size([128])
torch.Size([128, 32])
torch.Size([128, 32])
torch.Size([128])
torch.Size([128, 32])
torch.Size([128, 32])
torch.Size([128])
torch.Size([128, 32])
torch.Size([128, 32])


In [14]:
# Training
def train():
    """Run one training epoch over the notebook-level ``trainloader``.

    Uses the notebook globals ``model``, ``optimizer``, ``criterion`` and
    ``device``. Running loss and accuracy are shown on a tqdm progress bar.

    Returns:
        tuple[float, float]: mean batch loss and accuracy (in percent) over
        the epoch, or ``(0.0, 0.0)`` if the loader yields no batches.
    """
    model.train()
    train_loss = 0.0
    correct = 0
    total = 0
    n_batches = 0
    # enumerate() hides the loader's length, so pass total= for a real progress bar.
    pbar = tqdm(enumerate(trainloader), total=len(trainloader))
    for batch_idx, (inputs, targets) in pbar:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        n_batches += 1
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        pbar.set_description(
            'Batch Idx: (%d/%d) | Loss: %.3f | Acc: %.3f%% (%d/%d)' %
            (batch_idx, len(trainloader), train_loss/(batch_idx+1), 100.*correct/total, correct, total)
        )

    if n_batches == 0 or total == 0:
        return 0.0, 0.0
    return train_loss / n_batches, 100. * correct / total

In [15]:
def eval(epoch, dataloader, checkpoint=False):
    global best_acc
    model.eval()
    eval_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        pbar = tqdm(enumerate(dataloader))
        for batch_idx, (inputs, targets) in pbar:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            eval_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            pbar.set_description(
                'Batch Idx: (%d/%d) | Loss: %.3f | Acc: %.3f%% (%d/%d)' %
                (batch_idx, len(dataloader), eval_loss/(batch_idx+1), 100.*correct/total, correct, total)
            )

    # Save checkpoint.
    if checkpoint:
        acc = 100.*correct/total
        if acc > best_acc:
            state = {
                'model': model.state_dict(),
                'acc': acc,
                'epoch': epoch,
            }
            if not os.path.isdir('checkpoint'):
                os.mkdir('checkpoint')
            torch.save(state, './checkpoint/ckpt.pth')
            best_acc = acc

        return acc

In [None]:
for epoch in range(train_p.epochs):
    if epoch == 0:
        print(f"Epoch | Train Loss | Train Acc | Test Loss | Test Acc")
        print(f"-------------------------------------------------------")
    train()
    val_acc = eval(epoch, testloader, checkpoint=False)
    eval(epoch, testloader)
    scheduler.step(val_acc)
    print(f'{epoch=}, {val_acc=}, {scheduler.get_last_lr()}')
    # print(f"Epoch {epoch} learning rate: {scheduler.get_last_lr()}")
