<a href="https://colab.research.google.com/github/avijit-mukherjee-25/transformers/blob/main/Transformer_to_classify_MNIST_data.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# install libraries
!pip install transformers torch numpy datasets evaluate matplotlib

## Import dependencies

In [None]:
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor
import torch.nn.functional as F
import matplotlib.pyplot as plt

In [None]:
# Select the compute device once for the whole notebook: CUDA when PyTorch
# reports a usable GPU, otherwise the CPU. Everything below moves models and
# batches to `device`.
if not torch.cuda.is_available():
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")
else:
    device = torch.device("cuda")

    # Report how many GPUs are visible and the name of the first one.
    print('There are %d GPU(s) available.' % torch.cuda.device_count())
    print('We will use the GPU:', torch.cuda.get_device_name(0))

## Set global hyperparameters

In [None]:
# --- optimisation settings (shared by every model in this notebook) ---
num_epochs = 10                  # full passes over the training set
batch_size = 32                  # samples per DataLoader batch
learning_rate = 1e-4             # step size handed to the optimizer
loss_fn = nn.CrossEntropyLoss()  # applied to raw logits and integer class labels

# --- transformer architecture settings (passed to the ViT constructor) ---
d_model = 512    # embedding width of each token
num_layers = 4   # number of stacked encoder layers
num_heads = 4    # attention heads per encoder layer

## Get Data

In [None]:
# Fetch (downloading into ./datasets/ if necessary) the raw train and test
# splits. Only ToTensor() is applied here; these un-normalised copies are used
# to compute the dataset statistics.
train_dataset, test_dataset = (
    datasets.MNIST(root="./datasets/", train=is_train, download=True, transform=ToTensor())
    for is_train in (True, False)
)

In [None]:
# Notebook display only: show the summary (repr) of both dataset objects.
train_dataset, test_dataset

In [None]:
# Compute normalisation statistics: stack every training image into a single
# tensor, then take the mean and standard deviation over all pixels at once.
imgs = torch.stack([sample[0] for sample in train_dataset], dim=0)
print(imgs.shape)

# Flatten to one row so that dim=1 reduces over every pixel of every image;
# mean and std therefore each hold a single value (shape [1]).
all_pixels = imgs.view(1, -1)
mean = all_pixels.mean(dim=1)
std = all_pixels.std(dim=1)
print(mean, std)

In [None]:
# Preprocessing pipeline applied to every sample: convert the image to a
# tensor, then standardise it with the training-set statistics computed above.
mnist_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
)

In [None]:
# Rebuild both splits from the files already on disk (download=False), this
# time with the normalising transform attached.
train_dataset, test_dataset = (
    datasets.MNIST(root="./datasets/", train=is_train, download=False, transform=mnist_transforms)
    for is_train in (True, False)
)
# Notebook display: show the summary of both dataset objects.
train_dataset, test_dataset

In [None]:
# Sanity check: draw one training sample uniformly at random, then show its
# index, tensor shape, pixels and label.
random_idx = np.random.randint(len(train_dataset))
print(random_idx)

img, label = train_dataset[random_idx]
print(img.shape)

# Drop the size-1 dimensions so matplotlib receives a 2-D image.
plt.imshow(torch.squeeze(img), cmap='gray')
print(label)

In [None]:
from torch.utils.data import DataLoader

# Shuffle only the training split. The evaluation metrics (summed loss and
# accuracy) do not depend on sample order, so shuffling the test split would
# only consume global RNG state on every evaluation pass and make runs harder
# to reproduce.
train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# Number of batches per split at the current batch size
print(len(train_dataloader), len(test_dataloader))
# Number of training samples, read directly and through the loader
print(len(train_dataset), len(train_dataloader.dataset))

## Define train and test functions

In [None]:
!pip install evaluate

In [None]:
import evaluate

In [None]:
def train(model, optimizer, loss_fn, train_dataloader):
    """Run one training pass over ``train_dataloader``.

    Batches are moved to the module-level ``device`` before the forward pass.

    Args:
        model: network to optimise; it is switched to training mode.
        optimizer: optimizer that owns ``model``'s parameters.
        loss_fn: callable taking ``(logits, targets)`` and returning a scalar loss.
        train_dataloader: iterable yielding ``(inputs, targets)`` batches.

    Returns:
        A ``(train_loss, train_accuracy)`` tuple. ``train_loss`` is the *sum*
        of the per-batch loss values (not an average). ``train_accuracy`` is
        whatever ``metric.compute()`` returns for the ``evaluate`` "accuracy"
        metric (presumably a dict such as ``{"accuracy": ...}`` -- confirm).
    """
    model.train()
    accuracy_metric = evaluate.load("accuracy")
    running_loss = 0.0

    for batch_idx, (inputs, targets) in enumerate(train_dataloader):
        inputs = inputs.to(device)
        targets = targets.to(device)

        logits = model(inputs)
        batch_loss = loss_fn(logits, targets)
        running_loss += batch_loss.item()

        # The predicted class is the index of the largest logit in each row.
        predicted = logits.argmax(dim=1)
        accuracy_metric.add_batch(predictions=predicted, references=targets)

        # Back-propagate, update the weights, then clear the gradients so
        # they do not accumulate into the next batch.
        batch_loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Progress report every 1000 batches; the value is the running sum.
        if batch_idx % 1000 == 0:
            print(f'training loss at {batch_idx} is {running_loss}')

    return running_loss, accuracy_metric.compute()

In [None]:
def eval(model, test_dataloader):
    """Evaluate ``model`` on ``test_dataloader`` without tracking gradients.

    Note that this name shadows Python's built-in ``eval`` in this notebook.
    The loss is computed with the module-level ``loss_fn`` and batches are
    moved to the module-level ``device``.

    Args:
        model: network to evaluate. It is put in eval mode for the pass and
            switched back to training mode before returning.
        test_dataloader: iterable yielding ``(inputs, targets)`` batches.

    Returns:
        A ``(test_loss, test_accuracy)`` tuple. ``test_loss`` is the *sum* of
        the per-batch loss values. ``test_accuracy`` is whatever
        ``metric.compute()`` returns for the ``evaluate`` "accuracy" metric
        (presumably a dict such as ``{"accuracy": ...}`` -- confirm).
    """
    with torch.no_grad():
        model.eval()
        accuracy_metric = evaluate.load("accuracy")
        total_loss = 0.0

        for inputs, targets in test_dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            logits = model(inputs)
            total_loss += loss_fn(logits, targets).item()

            # The predicted class is the index of the largest logit in each row.
            accuracy_metric.add_batch(
                predictions=logits.argmax(dim=1),
                references=targets,
            )

        test_accuracy = accuracy_metric.compute()
        # Leave the model ready for the next training epoch.
        model.train()
        return total_loss, test_accuracy

## LeNet

In [None]:
# LeNet-style convolutional baseline
class LeNet(nn.Module):
    """LeNet-style CNN: two conv/tanh/avg-pool stages and a 3-layer MLP head.

    The in-line size comments assume a single-channel 28x28 input, which is
    what the 16*5*5 input width of the first linear layer requires. The
    network outputs 10 raw logits per sample.
    """

    def __init__(self):
        super().__init__()

        # Convolutional feature extractor.
        conv_stack = [
            # stage 1: padding=2 keeps 28x28, pooling halves it to 14x14
            nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=2),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2, stride=2),
            # stage 2: unpadded 5x5 conv gives 10x10, pooling halves it to 5x5
            nn.Conv2d(6, 16, kernel_size=5, stride=1),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2, stride=2),
        ]
        self.feature = nn.Sequential(*conv_stack)

        # Fully connected head: 16*5*5 -> 120 -> 84 -> 10 logits.
        head = [
            nn.Flatten(),
            nn.Linear(16 * 5 * 5, 120),
            nn.Tanh(),
            nn.Linear(120, 84),
            nn.Tanh(),
            nn.Linear(84, 10),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        feature_maps = self.feature(x)
        return self.classifier(feature_maps)

# Build the baseline and move its parameters to the selected device
# (Module.to returns the module itself, so the call can be chained).
LeNet_model = LeNet().to(device)
# Notebook display: print the layer summary.
LeNet_model

In [None]:
# Optimizer for the LeNet run, using the shared learning rate
from torch.optim import AdamW
optimizer = AdamW(params=LeNet_model.parameters(), lr=learning_rate)

In [None]:
# Train LeNet for num_epochs epochs, evaluating on the test split after each one.
for epoch in range(num_epochs):
    print(f'Epoch: {epoch}')

    train_loss, train_accuracy = train(
        model=LeNet_model,
        optimizer=optimizer,
        loss_fn=loss_fn,
        train_dataloader=train_dataloader,
    )
    print(f'train loss at epoch {epoch} is {train_loss}; train accuracy is {train_accuracy}')

    test_loss, test_accuracy = eval(model=LeNet_model, test_dataloader=test_dataloader)
    print(f'test loss at epoch {epoch} is {test_loss}; test accuracy is {test_accuracy}')

## ViT

In [None]:
# define the model
class ViT(nn.Module):
    """Hybrid CNN + Transformer-encoder classifier producing 10 class logits.

    A small convolutional stem turns each image into a grid of ``d_model``-wide
    feature vectors. The grid is flattened into a token sequence, a learnable
    [CLS] token is prepended, learnable positional embeddings are added, and
    the sequence goes through a Transformer encoder. The logits are read from
    the encoded [CLS] token.

    Args:
        num_heads: attention heads per encoder layer.
        num_layers: number of stacked encoder layers.
        d_model: token embedding width (also the stem's output channel count).
        num_patches: number of tokens the stem produces per image. The default
            of 25 corresponds to the 5x5 grid obtained from a 28x28 input.
    """

    def __init__(self, num_heads, num_layers, d_model, num_patches=25):
        super().__init__()
        self.conv = nn.Sequential(
            # stage 1: padding=2 keeps 28x28, pooling halves it to 14x14
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2, stride=2),

            # stage 2: unpadded 5x5 conv gives 10x10, pooling halves it to 5x5
            nn.Conv2d(in_channels=6, out_channels=d_model, kernel_size=5, stride=1),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2, stride=2),
        )

        # Kept as a local on purpose: nn.TransformerEncoder deep-copies this
        # template num_layers times, so storing it on self would only register
        # a full set of parameters that never take part in the forward pass.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=int(d_model * 4),
            dropout=0.1,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer=encoder_layer,
            num_layers=num_layers,
        )

        # Registered once here so that they are trained by the optimizer,
        # moved by .to(device), saved in state_dict, and identical across
        # forward passes. The leading dimension of 1 broadcasts over the batch.
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, d_model))

        self.linear = nn.Linear(d_model, 10)

    def forward(self, x):
        """Return class logits of shape ``(batch, 10)`` for image batch ``x``.

        Raises:
            ValueError: if the stem yields a different number of tokens than
                the positional embedding was sized for.
        """
        x = self.conv(x)  # --> batch_size*d_model*H*W (5*5 for 28*28 input)
        x = x.flatten(start_dim=2).permute(0, 2, 1)  # --> batch_size*seq*d_model

        # Prepend the same learnable [CLS] token to every sequence in the batch.
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)  # --> batch_size*(seq+1)*d_model

        expected_len = self.pos_embedding.shape[1]
        if x.shape[1] != expected_len:
            raise ValueError(
                f"expected {expected_len} tokens (including [CLS]) but got "
                f"{x.shape[1]}; pass a matching num_patches to ViT"
            )
        # Out-of-place add: broadcasts the shared embedding over the batch.
        x = x + self.pos_embedding

        x = self.transformer_encoder(x)  # --> batch_size*(seq+1)*d_model

        # Classify from the encoded [CLS] token only.
        return self.linear(x[:, 0, :])

# Build the ViT from the global hyper-parameters and move it to the selected
# device (Module.to returns the module itself, so the call can be chained).
ViT_model = ViT(num_heads, num_layers, d_model).to(device)
# Notebook display: print the layer summary.
ViT_model

In [None]:
from torch.optim import AdamW
# Fresh optimizer for the ViT run. Use the shared learning_rate hyper-parameter
# (as the LeNet optimizer does) rather than a separately hard-coded value, so
# that changing it in the hyper-parameter cell affects both models.
optimizer = AdamW(ViT_model.parameters(), lr=learning_rate)

In [None]:
# Train the ViT for num_epochs epochs, evaluating on the test split after each one.
for epoch in range(num_epochs):
    print(f'Epoch: {epoch}')

    train_loss, train_accuracy = train(
        model=ViT_model,
        optimizer=optimizer,
        loss_fn=loss_fn,
        train_dataloader=train_dataloader,
    )
    print(f'train loss at epoch {epoch} is {train_loss}; train accuracy is {train_accuracy}')

    test_loss, test_accuracy = eval(model=ViT_model, test_dataloader=test_dataloader)
    print(f'test loss at epoch {epoch} is {test_loss}; test accuracy is {test_accuracy}')