In [1]:
!pip -q install vit_pytorch 

In [None]:
# Required packages: PyTorch (torch), torchvision and vit_pytorch.
import torch
import torchvision
from vit_pytorch import ViT
import time
import torch.nn.functional as F
import torch.optim as optim

# Seed torch's global RNG before anything random (weight init, shuffling) happens.
torch.manual_seed(97)

# Dataset location and mini-batch sizes for the two loaders.
Dpath = '/data/mnist'
Bs_Train, Bs_Test = 100, 1000

# Per-image preprocessing: convert to a tensor, then normalise with
# mean 0.1307 and std 0.3081 (presumably the MNIST pixel statistics).
tform_mnist = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])

# Training and test splits; files are downloaded to Dpath when missing.
tr_set = torchvision.datasets.MNIST(
    Dpath, train=True, transform=tform_mnist, download=True)
ts_set = torchvision.datasets.MNIST(
    Dpath, train=False, transform=tform_mnist, download=True)

# Both loaders reshuffle their split on every pass.
tr_load = torch.utils.data.DataLoader(tr_set, shuffle=True, batch_size=Bs_Train)
ts_load = torch.utils.data.DataLoader(ts_set, shuffle=True, batch_size=Bs_Test)

def train_iter(model, optimz, data_load, loss_val, log_every=100):
    """Run one training pass of ``model`` over every batch in ``data_load``.

    Args:
        model: Module that maps a batch of inputs to per-class logits.
        optimz: Optimizer; ``zero_grad()`` and ``step()`` are called once per batch.
        data_load: Iterable of ``(data, target)`` batches that supports ``len()``
            and has a ``dataset`` attribute (its length is only used for the
            progress message).
        loss_val: List that receives the current batch loss (a Python float)
            every time progress is logged.
        log_every: Print progress and record the loss every ``log_every``
            batches, starting with batch 0. Defaults to 100.

    Raises:
        ValueError: If ``log_every`` is smaller than 1.
    """
    if log_every < 1:
        # Guard against a modulo-by-zero (or a never-firing negative interval).
        raise ValueError('log_every must be a positive integer')

    samples = len(data_load.dataset)
    n_batches = len(data_load)
    model.train()

    for i, (data, target) in enumerate(data_load):
        optimz.zero_grad()
        # cross_entropy == nll_loss(log_softmax(logits, dim=1), target)
        loss = F.cross_entropy(model(data), target)
        loss.backward()
        optimz.step()

        if i % log_every == 0:
            batch_loss = loss.item()
            # i * len(data) approximates the number of samples seen so far.
            print('[{:5}/{:5} ({:3.0f}%)]  Loss: {:6.4f}'.format(
                i * len(data), samples, 100 * i / n_batches, batch_loss))
            loss_val.append(batch_loss)

def evaluate(model, data_load, loss_val):
    """Evaluate ``model`` on ``data_load`` and print loss and accuracy.

    The model is switched to eval mode and gradients are disabled. The
    summed cross-entropy over all batches is divided by
    ``len(data_load.dataset)`` and the result is appended to ``loss_val``.

    Args:
        model: Module that maps a batch of inputs to per-class logits.
        data_load: Iterable of ``(data, target)`` batches that has a
            ``dataset`` attribute supporting ``len()``.
        loss_val: List that receives the average loss (a Python float).
            For an empty dataset ``nan`` is appended instead of raising
            ``ZeroDivisionError``.
    """
    model.eval()

    samples = len(data_load.dataset)
    # Accumulate as plain Python numbers so the formatting below never
    # depends on tensor-specific __format__ behaviour.
    correct = 0
    total_loss = 0.0

    with torch.no_grad():
        for data, target in data_load:
            logits = model(data)
            # Sum (not mean) per batch so the final division by the dataset
            # size is correct even when the last batch is smaller.
            total_loss += F.cross_entropy(logits, target, reduction='sum').item()
            correct += logits.argmax(dim=1).eq(target).sum().item()

    if samples:
        aloss = total_loss / samples
        accuracy = 100.0 * correct / samples
    else:
        # Nothing to evaluate: report an undefined loss rather than crashing.
        aloss = float('nan')
        accuracy = 0.0

    loss_val.append(aloss)
    print('\nAverage test loss: ' + '{:.4f}'.format(aloss) +
          '  Accuracy:' + '{:5}'.format(correct) + '/' +
          '{:5}'.format(samples) + ' (' +
          '{:4.2f}'.format(accuracy) + '%)\n')
    
# Number of full passes over the training set.
N_EPOCHS = 25

# The timer starts before the model is built, so construction is included
# in the reported execution time.
start_time = time.time()

# ViT for 28x28 single-channel images split into 4x4 patches, 10 classes.
model = ViT(
    image_size=28,
    patch_size=4,
    num_classes=10,
    channels=1,
    dim=64,
    depth=6,
    heads=8,
    mlp_dim=128,
)
optimz = optim.Adam(model.parameters(), lr=0.003)

# Loss histories filled in by train_iter / evaluate respectively.
trloss_val = []
tsloss_val = []

# Alternate one training pass with one evaluation pass per epoch.
for epoch in range(1, N_EPOCHS + 1):
    print('Epoch:', epoch)
    train_iter(model, optimz, tr_load, trloss_val)
    evaluate(model, ts_load, tsloss_val)

print(
    'Execution time:',
    '{:5.2f}'.format(time.time() - start_time),
    'seconds',
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /data/mnist/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting /data/mnist/MNIST/raw/train-images-idx3-ubyte.gz to /data/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting /data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz to /data/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting /data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to /data/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting /data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to /data/mnist/MNIST/raw

Epoch: 1


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)



Average test loss: 0.2067  Accuracy: 9334/10000 (93.34%)

Epoch: 2

Average test loss: 0.1533  Accuracy: 9514/10000 (95.14%)

Epoch: 3

Average test loss: 0.1162  Accuracy: 9634/10000 (96.34%)

Epoch: 4

Average test loss: 0.0881  Accuracy: 9716/10000 (97.16%)

Epoch: 5

Average test loss: 0.1074  Accuracy: 9635/10000 (96.35%)

Epoch: 6

Average test loss: 0.0963  Accuracy: 9694/10000 (96.94%)

Epoch: 7

Average test loss: 0.0759  Accuracy: 9762/10000 (97.62%)

Epoch: 8

Average test loss: 0.0823  Accuracy: 9748/10000 (97.48%)

Epoch: 9
