In [1]:
from datasets import load_dataset
import pandas as pd

from tqdm import tqdm
import torch
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset

# Make newly created tensors and module parameters default to the GPU when CUDA
# is usable, and to the CPU otherwise.
if torch.cuda.is_available():
    torch.set_default_device("cuda")
else:
    torch.set_default_device("cpu")

# Shared torchvision transform used to turn each dataset image into a tensor.
img_to_tensor = transforms.ToTensor()

# --- Input geometry -------------------------------------------------------
img_size = 32    # side length of the square input images
patch_size = 4   # side length of each square patch cut from an image
_patches_per_side = img_size // patch_size
n_patches = _patches_per_side * _patches_per_side  # patches per image

# --- Transformer shape ----------------------------------------------------
embed_dim = 64   # width of every token embedding
n_heads = 4      # attention heads per encoder block
trf_blocks = 4   # number of stacked encoder blocks

# --- Task / training ------------------------------------------------------
n_classes = 10   # number of target classes
batch_size = 100 # samples per mini-batch

In [2]:
class ImageDataset(Dataset):
    """Map-style dataset that turns stored image records into tensor pairs.

    Args:
        data: Indexable, sized collection whose items are mappings with an
            ``"img"`` entry (anything ``img_to_tensor`` accepts) and a
            ``"label"`` entry (anything ``torch.tensor`` accepts).
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        """Return the number of records in the wrapped collection."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image_tensor, label_tensor)`` for the record at ``idx``."""
        # Fetch the record once: every ``self.data[idx]`` is a separate lookup
        # (and, for lazily decoded datasets, a separate decode), so indexing
        # once per field would double the per-sample work.
        record = self.data[idx]
        image = img_to_tensor(record["img"])
        # Build the label on the image's device so both halves of a sample stay
        # together even when a global default device is configured.
        label = torch.tensor(record["label"], device=image.device)
        return image, label
    
ds = load_dataset("uoft-cs/cifar10")

# The shuffling sampler draws its permutation with a torch.Generator. Create that
# generator on whatever device new tensors currently default to, so shuffling
# keeps working when torch.set_default_device(...) points at a non-CPU device
# (a CPU generator combined with a default-device randperm raises a device
# mismatch error).
shuffle_generator = torch.Generator(device=torch.empty(0).device)
shuffle_generator.seed()  # fresh non-deterministic seed: order still differs between runs

train_dl = DataLoader(
    ImageDataset(ds["train"]),
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,  # keep every training batch at exactly batch_size samples
    generator=shuffle_generator,
)
test_dl = DataLoader(ImageDataset(ds["test"]), batch_size=batch_size, shuffle=False)

In [3]:
class PatchEmbeddings(nn.Module):
    """Split an image batch into non-overlapping patches and embed each one.

    A convolution whose kernel and stride both equal ``patch_size`` maps every
    patch of a 3-channel image to one ``embed_dim``-wide vector.
    """

    def __init__(self):
        super().__init__()
        self.embed = nn.Conv2d(
            in_channels=3,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Map ``batch x 3 x H x W`` images to ``batch x n_patches x embed_dim`` tokens."""
        # batch x embed_dim x (H / patch_size) x (W / patch_size)
        feature_map = self.embed(x)
        # Collapse the two spatial axes into a single patch axis.
        patch_major = feature_map.flatten(start_dim=2)
        # Put the patch axis before the embedding axis.
        return patch_major.transpose(1, 2)

class TransformerEncoder(nn.Module):
    """One pre-norm transformer encoder block.

    Applies LayerNorm -> multi-head self-attention -> residual add, followed by
    LayerNorm -> two-layer GELU MLP (hidden width ``4 * embed_dim``) -> residual
    add. Input and output are both ``batch x tokens x embed_dim``.
    """

    def __init__(self):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

        self.attention = nn.MultiheadAttention(embed_dim, n_heads, batch_first=True)

        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.GELU(),
            nn.Linear(embed_dim * 4, embed_dim),
        )

    def forward(self, x):
        """Return the block output for ``x`` (``batch x tokens x embed_dim``)."""
        normed = self.norm1(x)
        # Only the attended values are used, so skip computing the averaged
        # attention-weight matrix; this also lets PyTorch pick its optimized
        # attention implementation.
        attended, _ = self.attention(normed, normed, normed, need_weights=False)
        # Out-of-place residual adds: never mutate a tensor autograd may still need.
        x = x + attended

        x = x + self.mlp(self.norm2(x))
        return x
    
class VisionTransformer(nn.Module):
    """Vision Transformer classifier.

    Pipeline: patch embedding -> prepend a learned class token -> add learned
    position embeddings -> ``trf_blocks`` encoder blocks -> LayerNorm + linear
    head applied to the class token, producing ``n_classes`` logits per image.
    """

    def __init__(self):
        super().__init__()
        self.patch_embedding = PatchEmbeddings()
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))  # 1 x 1 x embed_dim
        self.pos_embed = nn.Parameter(torch.randn(1, n_patches + 1, embed_dim))  # 1 x (n_patches+1) x embed_dim
        self.transformer_layers = nn.Sequential(
            *[TransformerEncoder() for _ in range(trf_blocks)]
        )

        self.out_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            # Size the classifier from the shared constant instead of a literal,
            # so changing n_classes actually changes the model output.
            nn.Linear(embed_dim, n_classes),
        )

    def forward(self, x):
        """Return class logits (``batch x n_classes``) for an image batch.

        Args:
            x: Images shaped ``batch x channels x img_size x img_size``.
        """
        x = self.patch_embedding(x)  # batch x n_patches x embed_dim
        batch = x.shape[0]

        # Broadcast the single learned class token across the batch (no copy).
        cls_tokens = self.cls_token.expand(batch, -1, -1)  # batch x 1 x embed_dim

        x = torch.cat((cls_tokens, x), dim=1)  # batch x (n_patches+1) x embed_dim
        x = x + self.pos_embed                 # position embedding broadcasts over the batch
        x = self.transformer_layers(x)         # batch x (n_patches+1) x embed_dim

        # Classify from the class token, which sits at sequence position 0.
        cls_out = x[:, 0]                      # batch x embed_dim
        return self.out_head(cls_out)          # batch x n_classes

In [4]:
epochs = 10
model = VisionTransformer()
# Batches are moved to wherever the model's parameters live, so inputs, labels
# and weights always share a device.
device = next(model.parameters()).device
opt = torch.optim.AdamW(model.parameters())
loss_fn = nn.CrossEntropyLoss()
history = []

for epoch in range(epochs):
    # The evaluation pass at the end of each epoch switches to eval mode;
    # switch back before training again.
    model.train()
    total_loss, correct, seen = 0.0, 0, 0

    for input_batch, label_batch in tqdm(train_dl):
        input_batch = input_batch.to(device)
        label_batch = label_batch.to(device)

        opt.zero_grad()

        pred_batch = model(input_batch)
        loss = loss_fn(pred_batch, label_batch)
        loss.backward()

        opt.step()

        with torch.no_grad():
            batch_n = input_batch.size(0)
            # Weight each batch's mean loss by its size so the epoch value is a
            # per-sample mean, matching how the test loss is computed below.
            total_loss += loss.item() * batch_n
            # Vectorized accuracy: one comparison per batch instead of a Python
            # loop with a host sync per sample.
            correct += (pred_batch.argmax(dim=1) == label_batch).sum().item()
            seen += batch_n

    train_loss = total_loss / seen
    train_acc = correct / seen

    model.eval()
    test_loss, test_correct, m = 0, 0, 0

    with torch.no_grad():
        for input_batch, label_batch in test_dl:
            input_batch = input_batch.to(device)
            label_batch = label_batch.to(device)

            logits = model(input_batch)
            loss = loss_fn(logits, label_batch)

            test_loss += loss.item() * input_batch.size(0)
            preds = logits.argmax(dim=1)
            test_correct += (preds == label_batch).sum().item()
            m += input_batch.size(0)

    test_loss /= m
    test_acc = test_correct / m

    metrics = {
        'epoch': epoch + 1,
        'train_loss': train_loss,
        'train_acc': train_acc,
        'test_loss': test_loss,
        'test_acc': test_acc
    }
    print(metrics, "\n\n")
    # ---- Log metrics ----
    history.append(metrics)

    # Rewrite the full history and the latest weights every epoch so an
    # interrupted run still leaves usable artifacts on disk.
    history_df = pd.DataFrame(history)
    history_df.to_csv("./history.csv", index=False)

    torch.save(model.state_dict(), "./vit_model.pth")

100%|██████████| 500/500 [01:08<00:00,  7.32it/s]


{'epoch': 1, 'train_loss': 1.9590074326992035, 'train_acc': 0.26922, 'test_loss': 1.7137407767772674, 'test_acc': 0.3781}


100%|██████████| 500/500 [01:08<00:00,  7.30it/s]


{'epoch': 2, 'train_loss': 1.622039264678955, 'train_acc': 0.41186, 'test_loss': 1.5276664400100708, 'test_acc': 0.4479}


100%|██████████| 500/500 [01:06<00:00,  7.46it/s]


{'epoch': 3, 'train_loss': 1.4793764526844024, 'train_acc': 0.46648, 'test_loss': 1.445084947347641, 'test_acc': 0.4797}


100%|██████████| 500/500 [01:09<00:00,  7.21it/s]


{'epoch': 4, 'train_loss': 1.3849075150489807, 'train_acc': 0.50098, 'test_loss': 1.3312592387199402, 'test_acc': 0.521}


100%|██████████| 500/500 [01:09<00:00,  7.15it/s]


{'epoch': 5, 'train_loss': 1.3072915108203889, 'train_acc': 0.5302, 'test_loss': 1.3084520983695984, 'test_acc': 0.5303}


100%|██████████| 500/500 [01:11<00:00,  6.97it/s]


{'epoch': 6, 'train_loss': 1.245615979909897, 'train_acc': 0.55426, 'test_loss': 1.257235858440399, 'test_acc': 0.5483}


100%|██████████| 500/500 [01:09<00:00,  7.24it/s]


{'epoch': 7, 'train_loss': 1.1933291679620743, 'train_acc': 0.57204, 'test_loss': 1.257988201379776, 'test_acc': 0.5481}


100%|██████████| 500/500 [01:08<00:00,  7.27it/s]


{'epoch': 8, 'train_loss': 1.1428585118055343, 'train_acc': 0.58946, 'test_loss': 1.2425014615058898, 'test_acc': 0.5535}


100%|██████████| 500/500 [01:08<00:00,  7.32it/s]


{'epoch': 9, 'train_loss': 1.096358889222145, 'train_acc': 0.6062, 'test_loss': 1.2333447873592376, 'test_acc': 0.5591}


100%|██████████| 500/500 [01:07<00:00,  7.38it/s]


{'epoch': 10, 'train_loss': 1.0568419719934463, 'train_acc': 0.62126, 'test_loss': 1.2011856669187546, 'test_acc': 0.5722}
