In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import os

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using {} device'.format(device))


from utils import show_images
from models import MultilayerPerceptron
from models import MLPMixer # need to install <einops>

from layers import CustomLinear

# Load dataset

In [None]:
# NOTE: the folder name is spelled 'FashionMNISH' (sic). It is left unchanged so
# that data already downloaded into that folder keeps being reused.
root = 'FashionMNISH'

# Samples are only converted to tensors; no normalisation or augmentation.
transform = transforms.Compose([transforms.ToTensor()])

# Build the train split first, then the test split, from the same folder and
# transform (download=True lets torchvision fetch the files if needed).
train_data, test_data = (
    datasets.FashionMNIST(root, train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

In [None]:
# Seed torch's global RNG (this also fixes the shuffling order of the loaders below).
torch.manual_seed(42)

train_dataloader = DataLoader(train_data, batch_size=2048, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=2048, shuffle=True)

# Human-readable names for the ten class indices.
labels_map = dict(enumerate([
    "T-Shirt", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle Boot",
]))

# Preview the first few images of one shuffled training batch.
N_samples = 9
images, labels = next(iter(train_dataloader))
captions = [labels_map[label] for label in labels[:N_samples].tolist()]
show_images(images[:N_samples], captions)

In [None]:
import IPython
from math import ceil
from time import time


def train_loop(model, dataloader, loss_fn, optimizer, step=0.05):
    """Train ``model`` for one pass over ``dataloader``, timing every backward pass.

    A tqdm bar tracks batches. A single IPython display is also updated in place
    with ``[processed/total] Loss: ...``. This happens whenever at least ``step`` of
    the epoch has elapsed since the previous update, and once more on the final batch.

    Args:
        model: module to train; it is put into training mode first.
        dataloader: iterable of ``(X, y)`` batches; both tensors are moved to the
            notebook-global ``device``.
        loss_fn: called as ``loss_fn(pred, y)``; must return a scalar tensor.
        optimizer: optimizer stepping ``model``'s parameters.
        step: fraction of the epoch between two progress updates
            (a value ``<= 0`` updates after every batch).

    Returns:
        ``{'backward_time': [...]}``: the wall-clock seconds spent in
        ``loss.backward()`` for each batch, in batch order.
    """
    from time import perf_counter  # monotonic, high-resolution clock for intervals

    # CUDA kernels run asynchronously. Without synchronizing, the timer would only
    # measure how long it takes to *launch* the backward kernels, not run them.
    if torch.device(device).type == 'cuda':
        sync = torch.cuda.synchronize
    else:
        def sync():
            pass

    out = IPython.display.display(IPython.display.Pretty('Learning...'), display_id=True)

    model.train()

    size = len(dataloader.dataset)
    len_size = len(str(size))          # width used to right-align the sample counter
    last_batch = len(dataloader) - 1   # index of the final batch (respects drop_last)
    seen = 0                           # samples processed so far
    next_report = 0.0                  # epoch fraction at which to update the display

    backward_times = []
    for batch, (X, y) in enumerate(tqdm(dataloader, leave=False, desc="Batch #")):
        X, y = X.to(device), y.to(device)

        # evaluate
        pred = model(X)
        loss = loss_fn(pred, y)

        # backpropagation, timed in isolation: pending forward work is flushed first
        optimizer.zero_grad()
        sync()
        start = perf_counter()
        loss.backward()
        sync()
        backward_times.append(perf_counter() - start)
        optimizer.step()

        # progress report, based on the samples actually processed so far
        seen += len(X)
        fraction = seen / size
        if fraction >= next_report or batch == last_batch:
            out.update(IPython.display.Pretty(
                f'[{seen:>{len_size}}/{size}] Loss: {loss.item():>8f}'
            ))
            next_report = fraction + step

    return {'backward_time': backward_times}
        
        
def test_loop(model, dataloader, loss_fn):
    """Evaluate ``model`` on ``dataloader`` without gradients and print the metrics.

    The model is switched to evaluation mode for the pass. Afterwards it is put back
    into its previous (top-level) training/eval mode, so a later training call is
    unaffected.

    Args:
        model: module to evaluate.
        dataloader: iterable of ``(X, y)`` batches; both tensors are moved to the
            notebook-global ``device``.
        loss_fn: called as ``loss_fn(pred, y)``. Assumed to return the *mean* loss
            over the batch (the ``nn.CrossEntropyLoss`` default). TODO: confirm
            if a different reduction is ever used.

    Returns:
        ``(accuracy, loss)``: top-1 accuracy in percent, and the mean loss per
        sample over all evaluated samples.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    was_training = model.training
    model.eval()
    total_loss, correct, seen = 0.0, 0, 0
    try:
        with torch.no_grad():
            for X, y in tqdm(dataloader, leave=False, desc='Batch #'):
                X, y = X.to(device), y.to(device)

                pred = model(X)
                n = len(y)
                # Weight each batch mean by its size so a smaller final batch
                # does not distort the overall average.
                total_loss += loss_fn(pred, y).item() * n
                correct += (pred.argmax(dim=1) == y).sum().item()
                seen += n
    finally:
        model.train(was_training)

    if seen == 0:
        raise ValueError('test_loop: dataloader produced no samples')

    test_loss = total_loss / seen
    correct = correct / seen

    print(f"Validation accuracy: {(100*correct):>0.1f}%, Validation loss: {test_loss:>8f} \n")
    return 100 * correct, test_loss

# Load MLP-Mixer

### Uncomment the cell below to run the memory test (~7 GB of video RAM instead of 16 GB)

In [None]:
# %%time 

# batch_size = 510

# train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
# test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=True)

# net = MLPMixer(
#     image_size=28, channels=1, patch_size=4, 
#     dim=256, depth=15, 
#     num_classes=10, 
#     Dense=CustomLinear(100, 'naive')
# ).to(device)

# loss_fn = nn.CrossEntropyLoss().to(device)
# optimizer = torch.optim.Adam(net.parameters())

# for epoch in range(1):
#     print(f"Epoch {epoch+1}\n-------------------------------")
#     train_loop(net, train_dataloader, loss_fn, optimizer)
#     test_loop(net, test_dataloader, loss_fn)

In [None]:
# print('Number of parameters:', sum(p.numel() for p in net.parameters()))

# Load MLP

In [None]:
# MLP configuration shared by all experiments below.
out_features = 10                # number of classes (labels_map has 10 entries)
in_features = images.shape[1:]   # shape of a single image from the preview batch
blocks = [256, 512, 512, 256]    # passed to MultilayerPerceptron (presumably hidden widths)
k = 500                          # first argument of CustomLinear; the factor sweep varies it
epochs = 8                       # training epochs per method in the accuracy/loss test

# Test models

### Accuracy, loss, and time-per-batch tests

In [None]:
%%time

history = {
    'torch': None,
    'naive': None,
    'vanilla': None,
    'gauss': None,
}

train_dataloader = DataLoader(train_data, batch_size=2048, shuffle=True)
# Evaluation metrics do not depend on sample order, so the test split is not shuffled.
test_dataloader = DataLoader(test_data, batch_size=2048, shuffle=False)

for method in history.keys():

    print(f'METHOD: {method}')

    # Re-seed before each method so every run is reproducible. It also gives the
    # methods comparable initial weights and batch orders, assuming CustomLinear
    # draws the same random numbers for every method (TODO: confirm).
    torch.manual_seed(42)
    net = MultilayerPerceptron(in_features, out_features, blocks, CustomLinear(k, method)).to(device)

    loss_fn = nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.Adam(net.parameters())

    epoch_backward_times = []  # one list of per-batch backward times per epoch
    val_accuracies = []
    val_losses = []
    for epoch in range(epochs):
        print(f"Epoch {epoch+1}\n-------------------------------")
        epoch_backward_times.append(train_loop(net, train_dataloader, loss_fn, optimizer)['backward_time'])
        val_acc, val_loss = test_loop(net, test_dataloader, loss_fn)
        val_accuracies.append(val_acc)
        val_losses.append(val_loss)

    # The plotting cells unpack this as (backward_time, accuracy, loss).
    history[method] = (epoch_backward_times, val_accuracies, val_losses)

In [None]:
# Validation loss per epoch for every method. Entry i of a run was measured after
# epoch i + 1, so the x-axis starts at 1 to match the "Epoch N" training log.
fig, ax = plt.subplots(figsize=(6, 6))
for method, (_backward_times, _accuracies, val_losses) in history.items():
    ax.plot(np.arange(1, len(val_losses) + 1), val_losses, label=method)

ax.set_xlabel('Epoch #')
ax.set_ylabel('Loss on validation')
ax.set_title('Validation loss per epoch')
ax.legend()
plt.show()

In [None]:
# Validation accuracy per epoch for every method. Entry i of a run was measured
# after epoch i + 1, so the x-axis starts at 1 to match the "Epoch N" training log.
fig, ax = plt.subplots(figsize=(6, 6))
for method, (_backward_times, val_accuracies, _losses) in history.items():
    ax.plot(np.arange(1, len(val_accuracies) + 1), val_accuracies, label=method)

ax.set_xlabel('Epoch #')
ax.set_ylabel('Accuracy on validation')
ax.set_title('Validation accuracy per epoch')
ax.legend()
plt.show()

### Backward-pass time vs. batch size

In [None]:
%%time


batch_sizes = [64, 128, 256, 512, 1024, 2048, 4096]

history = {
    'torch': {},
    'naive': {},
    'vanilla': {},
    'gauss': {},
}

for batch_size in batch_sizes:

    print(f'BATCH SIZE: {batch_size}')

    # Only training is timed in this sweep, so no test loader is built.
    train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

    for method in history.keys():

        print(f'METHOD: {method}')

        # Same seed for every run so timings are taken on reproducible runs.
        torch.manual_seed(42)
        # train_loop moves every batch to `device`, so the model must live there too.
        net = MultilayerPerceptron(in_features, out_features, blocks, CustomLinear(k, method)).to(device)

        loss_fn = nn.CrossEntropyLoss().to(device)
        optimizer = torch.optim.Adam(net.parameters())

        # One epoch per configuration; the per-batch backward times are stored.
        history[method][batch_size] = train_loop(net, train_dataloader, loss_fn, optimizer)['backward_time']

In [None]:
# Mean backward-pass time per batch for each method and batch size.
fig, ax = plt.subplots(figsize=(6, 6))
for method, times_by_batch_size in history.items():
    # Use a real (sorted) list for x: matplotlib converts x through NumPy, which
    # treats a dict_keys view as one object rather than a sequence of values.
    sizes = sorted(times_by_batch_size)
    mean_times = [np.mean(times_by_batch_size[bs]) for bs in sizes]
    ax.plot(sizes, mean_times, marker='o', label=method)

ax.set_xscale('log', base=2)
ax.set_xlabel('Batch size')
ax.set_ylabel('Average backward time per batch, s')
ax.set_title('Backward-pass time vs. batch size')
ax.legend()
plt.show()

### Accuracy after 2 epochs for different factors $r$

In [None]:
%%time


factors = [20, 50, 100, 200, 300, 400, 600, 1000]
factor_epochs = 2  # training epochs per (factor, method) pair, as in the section title

history = {
    'torch': {},
    'naive': {},
    'vanilla': {},
    'gauss': {},
}

train_dataloader = DataLoader(train_data, batch_size=2048, shuffle=True)
# Evaluation metrics do not depend on sample order, so the test split is not shuffled.
test_dataloader = DataLoader(test_data, batch_size=2048, shuffle=False)

for factor in factors:

    print(f'FACTOR: {factor}')

    for method in history.keys():

        print(f'METHOD: {method}')

        # Same seed for every run so the accuracies are reproducible.
        torch.manual_seed(42)
        # train_loop/test_loop move batches to `device`, so the model must live there too.
        net = MultilayerPerceptron(in_features, out_features, blocks, CustomLinear(factor, method)).to(device)

        loss_fn = nn.CrossEntropyLoss().to(device)
        optimizer = torch.optim.Adam(net.parameters())

        for epoch in range(factor_epochs):
            train_loop(net, train_dataloader, loss_fn, optimizer)

        # Accuracy is measured on the held-out test split, like the other experiments.
        acc, val_loss = test_loop(net, test_dataloader, loss_fn)
        history[method][factor] = acc

In [None]:
# Accuracy (percent, as returned by test_loop) for each method and factor r.
fig, ax = plt.subplots(figsize=(6, 6))
for method, acc_by_factor in history.items():
    # Use a real (sorted) list for x: matplotlib converts x through NumPy, which
    # treats a dict_keys view as one object rather than a sequence of values.
    factor_values = sorted(acc_by_factor)
    accuracies = [acc_by_factor[r] for r in factor_values]
    ax.plot(factor_values, accuracies, marker='o', label=method)

ax.set_xlabel('Factor, $r$')
ax.set_ylabel('Accuracy, %')
ax.set_title('Accuracy vs. factor $r$')
ax.legend()
plt.show()