In [1]:
import torch
import math
from torch import nn
from torch.nn import functional as F
from matplotlib import pyplot as plt
from src.nn.saint import SAINTClassifier
# export mnist
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [2]:
# load data
SEED = 0        # fixes the train/validation split (and, run top-to-bottom, later RNG use)
N_VAL = 10_000  # number of training images held out for validation

# Seed the global torch RNG so that, when the notebook is run top to bottom,
# model initialisation and DataLoader shuffling in later cells are reproducible too.
torch.manual_seed(SEED)

mnist_train_full = datasets.MNIST(
    root='data',
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        # flatten each image tensor to a 1-D vector (784 values for 28x28 MNIST)
        transforms.Lambda(lambda x: x.view(-1)),
        # binarize intensities into integer {0, 1} tokens
        transforms.Lambda(lambda x: (x > 0.1).long())
    ])
)
# split data into train and validation; an explicitly seeded generator keeps
# the split identical across kernel restarts
train_set, val_set = torch.utils.data.random_split(
    mnist_train_full,
    [len(mnist_train_full) - N_VAL, N_VAL],
    generator=torch.Generator().manual_seed(SEED),
)

In [3]:
# training
def train_model_epoch(
    model,
    data_loader,
    optimizer,
    device,
    verbose: int=-1
):
    """Run one training epoch over ``data_loader``.

    Args:
        model: module called as ``model((None, x), y)``; the returned object
            must expose a scalar ``.loss`` (presumably the mean batch loss —
            TODO confirm against the model implementation).
        data_loader: iterable of ``(x, y)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        device: device each batch is moved to.
        verbose: if > 0, print the running mean loss every ``verbose`` steps.

    Returns:
        ``(model, loss_accum)`` where ``loss_accum`` is the sum of the per-batch
        losses (a detached tensor, or ``0.0`` if the loader was empty); divide
        by the number of batches to get the epoch mean.
    """
    loss_accum = 0.0
    model.train()
    for step, (x, y) in enumerate(data_loader):
        x = x.to(device)
        y = y.to(device)

        optimizer.zero_grad()
        output = model((None, x), y)
        loss = output.loss

        # Exactly one update per batch, and only after gradients exist:
        # stepping before backward() would apply an update from stale/zeroed
        # gradients (e.g. pure momentum for Adam-type optimizers).
        loss.backward()
        optimizer.step()
        # detach() keeps the accumulator off the autograd graph without a host sync.
        loss_accum += loss.detach()

        if verbose > 0 and (step + 1) % verbose == 0:
            loss_mean = loss_accum / (step + 1)
            print(f'step: {step + 1:4d}, loss: {loss_mean:.4f}')

    return model, loss_accum

In [4]:
def evaluate(model, data_loader, device):
    """Compute the sample-weighted mean loss and the accuracy of ``model``.

    Runs in eval mode under ``torch.no_grad()``. Each batch's ``output.loss``
    is weighted by the batch size, which yields the dataset mean only if the
    model reports a per-batch *mean* loss (assumed — TODO confirm).

    Returns:
        ``(loss, acc)``: ``loss`` is a tensor (not a Python float), ``acc`` is
        a Python float. An empty loader raises ZeroDivisionError.
    """
    weighted_loss_sum = 0.0
    correct, seen = 0, 0
    model.eval()
    with torch.no_grad():
        for inputs, targets in data_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            out = model((None, inputs), targets)
            weighted_loss_sum += out.loss.detach() * inputs.size(0)

            predictions = out.logits.argmax(dim=-1)
            correct += (predictions == targets).sum().item()
            seen += targets.size(0)
    return weighted_loss_sum / seen, correct / seen

In [5]:
def train_model(
    model,
    train_loader,
    valid_loader,
    optimizer,
    device,
    n_epochs: int=10,
    verbose: int=-1,
):
    """Train ``model`` for ``n_epochs`` epochs, printing metrics after each one.

    Train metrics are recomputed with ``evaluate`` (which calls ``model.eval()``)
    rather than taken from the running training loss, so they are measured the
    same way as the validation metrics. Note this costs one extra full pass
    over ``train_loader`` per epoch.

    Args:
        model: model to train (passed through ``train_model_epoch``).
        train_loader: training batches.
        valid_loader: validation batches.
        optimizer: optimizer over ``model``'s parameters.
        device: device batches are moved to.
        n_epochs: number of epochs to run.
        verbose: forwarded to ``train_model_epoch`` (print every N steps if > 0).

    Returns:
        The trained model.
    """
    for epoch in range(n_epochs):
        # The epoch's running loss is intentionally discarded: the printed train
        # metrics come from the eval-mode pass below.
        model, _ = train_model_epoch(model, train_loader, optimizer, device, verbose=verbose)
        # evaluate on the training and validation sets
        train_loss, train_acc = evaluate(model, train_loader, device)
        valid_loss, valid_acc = evaluate(model, valid_loader, device)
        print(f'epoch: {epoch + 1:3d}, train_loss: {train_loss:.4f}, train_acc: {train_acc:.4f}, valid_loss: {valid_loss:.4f}, valid_acc: {valid_acc:.4f}')

    return model

In [6]:
# SAINT classifier over the flattened, binarized MNIST pixels.
model = SAINTClassifier(
    # inputs: no dense features, 784 sparse columns
    # (sparse_key_size=2 presumably matches the {0, 1} pixel tokens — confirm)
    dense_size=0,
    sparse_size=784,
    sparse_key_size=2,
    col_embedding=True,
    # output
    num_classes=10,
    # backbone size
    num_hiddens=32,
    num_layers=4,
    num_heads=4,
    ffn_hiddens_factor=4,
    # attention settings
    inter_sample=False,
    col_attn_latent_dim=32,
    row_attn_latent_dim=64,
    # regularization
    attn_dropout=0.2,
    ffn_dropout=0.2,
)
optimizer = model.build_optimizer(
    lr=1e-4,
    weight_decay=1e-3,
)

In [7]:
# create data loaders: the training split is reshuffled on every pass,
# the validation split is iterated in a fixed order
train_loader, valid_loader = (
    DataLoader(train_set, batch_size=128, shuffle=True),
    DataLoader(val_set, batch_size=128, shuffle=False),
)

In [8]:
# Use the GPU when present, otherwise fall back to CPU so the notebook still runs.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# nn.Module.to() moves parameters in place, so the optimizer built earlier
# still refers to the same parameter objects.
model = model.to(device)

In [11]:
# Train for 50 epochs; train_model prints train/validation metrics after each epoch.
model = train_model(
    model,
    train_loader,
    valid_loader,
    optimizer,
    device=device,
    n_epochs=50,
)

epoch:   1, train_loss: 0.1541, train_acc: 0.9506, valid_loss: 0.1663, valid_acc: 0.9454
epoch:   2, train_loss: 0.1575, train_acc: 0.9490, valid_loss: 0.1728, valid_acc: 0.9458
epoch:   3, train_loss: 0.1531, train_acc: 0.9505, valid_loss: 0.1663, valid_acc: 0.9467
epoch:   4, train_loss: 0.1483, train_acc: 0.9527, valid_loss: 0.1607, valid_acc: 0.9485
epoch:   5, train_loss: 0.1534, train_acc: 0.9511, valid_loss: 0.1687, valid_acc: 0.9462
epoch:   6, train_loss: 0.1488, train_acc: 0.9525, valid_loss: 0.1650, valid_acc: 0.9480
epoch:   7, train_loss: 0.1484, train_acc: 0.9527, valid_loss: 0.1621, valid_acc: 0.9481
epoch:   8, train_loss: 0.1503, train_acc: 0.9509, valid_loss: 0.1672, valid_acc: 0.9469
epoch:   9, train_loss: 0.1413, train_acc: 0.9544, valid_loss: 0.1549, valid_acc: 0.9514
epoch:  10, train_loss: 0.1434, train_acc: 0.9541, valid_loss: 0.1561, valid_acc: 0.9508
epoch:  11, train_loss: 0.1453, train_acc: 0.9534, valid_loss: 0.1610, valid_acc: 0.9486
epoch:  12, train_los