In [6]:
import os

import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from tqdm import tqdm

from mlp_mixer.model.mlp_mixer import MLPMixer

In [2]:
# Preprocessing: resize to 224 px (the image_size the mixer is built with),
# convert to a tensor, then normalise with (0.1307, 0.3081) -- the commonly
# used MNIST pixel mean / std.
transform = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

# dataset1: MNIST training split (downloaded on first run); dataset2: MNIST test split.
dataset1 = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
dataset2 = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    transform=transform,
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data\MNIST\raw\train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./data\MNIST\raw\train-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data\MNIST\raw\train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./data\MNIST\raw\train-labels-idx1-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data\MNIST\raw\t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./data\MNIST\raw\t10k-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw



In [13]:
# MLP-Mixer for single-channel inputs of image_size=224, split with patch_size=16.
# NOTE(review): the saved traceback of the training cell ("normalized_shape=[14] ...
# got input of size [16, 196, 768]") shows a LayerNorm size mismatch inside MLPMixer
# for this configuration (224 / 16 = 14 per side, 14 * 14 = 196 tokens). MLPMixer's
# other constructor arguments (hidden size, number of classes, ...) are not visible
# here -- confirm their defaults, in particular that the head has 10 outputs for MNIST.
model = MLPMixer(
    image_size=224,
    in_channels=1,
    patch_size=16,
)

In [9]:
# Hyper-parameters for the training run.
train_params = {
    'epochs': 16,
    'lr': 0.1,            # initial SGD learning rate
    'eval_portion': 0.2,  # fraction of the training split held out for validation
    'batch_size': 16,
}

# Fall back to the CPU so the notebook still runs on machines without a CUDA GPU
# (a hard-coded 'cuda' device fails as soon as a tensor is moved to it).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [10]:
# Hold out `eval_portion` of the MNIST training split for validation. The seeded
# generator makes the split identical across kernel restarts, so validation
# accuracies (and the "best" checkpoint) are comparable between runs.
SPLIT_SEED = 42
EVAL_LENGTH = int(len(dataset1) * train_params['eval_portion'])

train_set, eval_set = random_split(
    dataset1,
    [len(dataset1) - EVAL_LENGTH, EVAL_LENGTH],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Only the training loader needs shuffling; evaluation order does not change the metrics.
train_loader = DataLoader(train_set, batch_size=train_params['batch_size'], shuffle=True)
eval_loader = DataLoader(eval_set, batch_size=train_params['batch_size'], shuffle=False)
test_loader = DataLoader(dataset2, batch_size=train_params['batch_size'], shuffle=False)


In [15]:
def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass of ``model`` over every batch in ``loader``."""
    model.train()
    for images, labels in tqdm(loader, desc='Training...'):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        batch_loss = criterion(model(images), labels)
        batch_loss.backward()
        optimizer.step()


def _count_errors(model, loader, device):
    """Return how many samples in ``loader`` get a wrong arg-max prediction."""
    model.eval()
    errors = 0
    # No autograd graph is needed for evaluation; this also saves GPU memory.
    with torch.no_grad():
        for images, labels in tqdm(loader, desc='Eval...'):
            images, labels = images.to(device), labels.to(device)
            predictions = model(images).argmax(dim=1)
            # Count mismatches -- not the sum of |label - prediction| class-index distances.
            errors += int((predictions != labels).sum())
    return errors


model = model.to(device)

# SGD with momentum; ExponentialLR multiplies the learning rate by 0.9 after each epoch.
optimizer = torch.optim.SGD(model.parameters(), lr=train_params['lr'], momentum=0.9, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.9)

# Keep the criterion under its own name; rebinding it to the batch-loss tensor makes
# the second batch fail with "'Tensor' object is not callable".
criterion = nn.CrossEntropyLoss()

# torch.save does not create missing directories.
os.makedirs('checkpoints', exist_ok=True)

best_accuracy = 0.0

for epoch in range(train_params['epochs']):
    _train_one_epoch(model, train_loader, criterion, optimizer, device)
    scheduler.step()

    test_error_count = _count_errors(model, eval_loader, device)
    test_accuracy = 1.0 - test_error_count / len(eval_set)

    print(f'Epoch: {epoch}, accuracy {test_accuracy}')
    if test_accuracy > best_accuracy:
        torch.save(model.state_dict(), 'checkpoints/model.pth')
        best_accuracy = test_accuracy

Training...:   0%|          | 0/3000 [00:01<?, ?it/s]


RuntimeError: Given normalized_shape=[14], expected input with shape [*, 14], but got input of size[16, 196, 768]