In [None]:
import os
from pathlib import Path

# Choose where MNIST is downloaded, depending on whether we run inside Colab.
try:
    from google.colab import drive
except ImportError:
    # Not in Colab: use a local directory, overridable via MNIST_DATA_DIR so the
    # notebook is not tied to one machine's absolute path.
    DATA_DIR = Path(os.environ.get("MNIST_DATA_DIR", "/home/avishkar/Desktop/research/mnist_data"))
    print("About to get data")
else:
    # Only the import is guarded: a failing drive.mount() now surfaces instead of
    # silently falling back to a local path that does not exist in Colab.
    # NOTE(review): DATA_DIR is not under /content/drive, so the mount is only
    # needed if another part of the notebook reads from Drive.
    drive.mount("/content/drive")
    DATA_DIR = Path("/content/mnist_data")

from torchvision.datasets import MNIST
from torchvision.transforms import transforms
from torch.utils.data import DataLoader

# ToTensor yields values in [0, 1]; Normalize(0.5, 0.5) rescales them to [-1, 1].
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(0.5,))]
)

# Build the train split first, then the test split (both downloaded if missing).
trainset, testset = (
    MNIST(root=DATA_DIR, train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)
print(f"Trainset: {len(trainset)} , Testset : {len(testset)}")
classes = trainset.classes

# Only the training loader shuffles; evaluation order stays fixed.
train_loader = DataLoader(trainset, shuffle=True, batch_size=32, num_workers=2)
test_loader = DataLoader(testset, shuffle=False, batch_size=32, num_workers=2)

# VISUALIZE DATA: show the first image of one training batch.
import matplotlib.pyplot as plt

imgs, labels = next(iter(train_loader))
print(imgs.shape)
# Each image is channel-first (C, H, W) as produced by ToTensor. Move channels
# last and drop a singleton channel so a 1-channel image becomes a 2-D array
# (multi-channel images keep their (H, W, C) layout).
img = imgs[0].permute(1, 2, 0).squeeze(-1)
plt.imshow(img.numpy(), cmap="gray")
plt.title(f"Label: {classes[labels[0].item()]}")
plt.axis("off")
plt.show()

In [None]:
import torch
import torch.nn as nn

class PatchEmbeddings(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.img_size = config["img_size"]
        self.embed_dim = config["embed_dim"]
        self.patch_size = config["patch_size"]
        self.num_patches = (self.img_size // self.patch_size) **2
        self.num_channels = config["num_channels"]
        self.patcher = nn.Sequential(
            nn.Conv2d(
                in_channels=self.num_channels,
                out_channels=self.embed_dim,
                kernel_size=self.patch_size,
                stride=self.patch_size,
            ),
            nn.Flatten(2)
        )
        self.cls_token = nn.Parameter(
            torch.randn(size=(1, self.num_channels,self.embed_dim)),
            requires_grad=True
            )
        self.position_embeddings = nn.Parameter(
            torch.randn(size=(1, self.num_patches+1, self.embed_dim)),
            requires_grad=True
            )
        self.dropout = nn.Dropout(config["dropout"])
        
    def forward(self, x):
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        
        x = self.patcher(x).permute(0, 2, 1)
        x = torch.cat([cls_token,x ], dim=1)
        x = self.position_embeddings + x
        x = self.dropout(x)
        
        return x

class ViT(nn.Module):
    """Vision Transformer classifier.

    Pipeline: PatchEmbeddings -> pre-LayerNorm Transformer encoder ->
    LayerNorm + Linear head applied to the CLS token.

    Config keys read here: ``embed_dim``, ``num_heads``, ``num_classes``,
    ``dropout``, ``num_layers`` and optional ``mlp_dim`` (feed-forward width,
    default 2048, the nn.TransformerEncoderLayer default). PatchEmbeddings reads
    its own keys from the same config.
    """

    def __init__(self, config):
        super().__init__()
        self.embed_dim = config["embed_dim"]
        self.num_heads = config["num_heads"]
        self.num_classes = config["num_classes"]

        self.embeddings = PatchEmbeddings(config)
        # The template layer stays a local: nn.TransformerEncoder deep-copies it
        # num_layers times. Registering it as an attribute too would add a whole
        # extra layer of parameters that forward() never uses.
        # NOTE: checkpoints saved with the old attribute carry extra
        # "encoder_layer.*" keys; load those with strict=False.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.embed_dim,
            nhead=self.num_heads,
            dim_feedforward=config.get("mlp_dim", 2048),
            dropout=config["dropout"],
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.encoder_blocks = nn.TransformerEncoder(
            encoder_layer,
            num_layers=config["num_layers"],
        )

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(self.embed_dim),
            nn.Linear(self.embed_dim, self.num_classes),
        )

    def forward(self, x):
        """Return class logits of shape (B, num_classes).

        Args:
            x: image batch, presumably (B, num_channels, img_size, img_size).
        """
        x = self.embeddings(x)
        x = self.encoder_blocks(x)
        return self.mlp_head(x[:, 0, :])  # apply MLP on the CLS token only

In [None]:
# Hyper-parameters consumed by ViT / PatchEmbeddings and by main().
config = dict(
    img_size=28,
    embed_dim=768,
    # 28 // 6 = 4 patches per side -> 16 patches covering 24 of 28 px per side.
    patch_size=6,
    dropout=0.01,
    num_channels=1,
    num_heads=4,
    num_layers=8,
    num_classes=10,
    device="cuda" if torch.cuda.is_available() else "cpu",
)

import time

class Trainer:
    """Minimal supervised train/evaluate loop for a classification model.

    Args:
        model: the nn.Module to train. It is used as-is; move it to ``device``
            before constructing the optimizer.
        optimizer: optimizer already built over ``model.parameters()``.
        criterion: loss callable taking ``(logits, labels)``.
        device: device each batch is moved to before the forward pass.
    """

    def __init__(self, model, optimizer, criterion, device):
        self.model = model
        self.optim = optimizer
        self.loss = criterion
        self.device = device

    def train(self, train_loader, test_loader, num_epochs):
        """Train for ``num_epochs`` epochs, evaluating on ``test_loader`` after each.

        Prints per-epoch metrics and the total wall-clock time.

        Returns:
            ``(train_losses, test_losses, accuracies)``: lists with one entry per epoch.
        """
        train_losses, test_losses, accuracies = [], [], []
        start = time.time()
        for epoch in range(1, num_epochs + 1):
            ep_start = time.time()
            train_loss = self.train_epoch(train_loader)
            accuracy, test_loss = self.evaluate(test_loader)
            train_losses.append(train_loss)
            test_losses.append(test_loss)
            accuracies.append(accuracy)
            ep_time = time.time() - ep_start
            print(f"Epoch: {epoch}, Train loss: {train_loss:.4f}, Test loss: {test_loss:.4f}, Accuracy: {accuracy:.4f}, Ep Time:{ep_time:.4f}s")
        print(f"Total Training Time {(time.time() - start):.4f}s")
        return train_losses, test_losses, accuracies

    def train_epoch(self, train_loader):
        """Run one optimisation pass over ``train_loader``; return the mean per-sample loss."""
        self.model.train()
        total_loss = 0.0
        for imgs, labels in train_loader:
            imgs = imgs.to(self.device)
            labels = labels.to(self.device)

            self.optim.zero_grad()
            logits = self.model(imgs)
            loss = self.loss(logits, labels)
            loss.backward()
            self.optim.step()
            # Assumes the criterion returns a batch mean (CrossEntropyLoss default);
            # re-weighting by batch size keeps the epoch average exact when the
            # last batch is smaller.
            total_loss += loss.item() * len(imgs)

        return total_loss / len(train_loader.dataset)

    @torch.no_grad()
    def evaluate(self, test_loader):
        """Return ``(accuracy, mean per-sample loss)`` of the model on ``test_loader``."""
        self.model.eval()
        total_loss = 0.0
        correct = 0
        for imgs, labels in test_loader:
            imgs = imgs.to(self.device)
            labels = labels.to(self.device)

            logits = self.model(imgs)
            total_loss += self.loss(logits, labels).item() * len(imgs)
            correct += (logits.argmax(dim=1) == labels).sum().item()

        num_samples = len(test_loader.dataset)
        return correct / num_samples, total_loss / num_samples

def main(num_epochs=10, lr=0.01, weight_decay=1e-2):
    """Build a ViT from ``config`` and train it on the MNIST loaders.

    Args:
        num_epochs: number of training epochs.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.

    Returns:
        The value returned by ``Trainer.train``.
    """
    device = config["device"]
    # Trainer sends every batch to `device`, so the model must live there too.
    # Move it before building the optimizer, as the torch.optim docs recommend.
    model = ViT(config).to(device)
    loss_func = nn.CrossEntropyLoss()
    # NOTE(review): lr=0.01 is high for Adam on an 8-layer Transformer, and Adam's
    # weight_decay is a coupled L2 penalty (AdamW decouples it); consider tuning.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    trainer = Trainer(model, optimizer, loss_func, device=device)
    return trainer.train(train_loader, test_loader, num_epochs)

In [None]:
# Train the ViT with the settings in `config`; Trainer.train prints per-epoch metrics.
main()

In [None]:
# VISUALIZE LOSS / ACCURACY
# Plots a metric history saved as <EXPERIMENT_DIR>/<exp_name>/metrics.json.
# NOTE(review): nothing in this notebook writes that file yet; the cell reports
# that instead of raising when the file is missing.
import json
from pathlib import Path

import matplotlib.pyplot as plt

EXPERIMENT_DIR = Path("experiments")
EXP_NAME = config.get("exp_name", "vit-mnist")


def load_experiment(exp_name, path):
    """Read ``<path>/<exp_name>/metrics.json``.

    Returns:
        ``(train_losses, test_losses, accuracies)`` taken from the JSON keys of
        the same names.
    """
    with open(Path(path) / exp_name / "metrics.json") as f:
        data = json.load(f)
    return data["train_losses"], data["test_losses"], data["accuracies"]


def plot_metrics(train_losses, test_losses, accuracies, save_path="metrics.png"):
    """Plot per-epoch losses and accuracy side by side, save to ``save_path`` and show."""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
    ax1.plot(train_losses, label="Train loss")
    ax1.plot(test_losses, label="Test loss")
    ax1.set_xlabel("Epoch")
    ax1.set_ylabel("Loss")
    ax1.legend()
    ax2.plot(accuracies)
    ax2.set_xlabel("Epoch")
    ax2.set_ylabel("Accuracy")
    fig.savefig(save_path)
    plt.show()
    return fig


metrics_file = EXPERIMENT_DIR / EXP_NAME / "metrics.json"
if metrics_file.exists():
    train_losses, test_losses, accuracies = load_experiment(EXP_NAME, EXPERIMENT_DIR)
    plot_metrics(train_losses, test_losses, accuracies)
else:
    print(f"No saved metrics at {metrics_file}; nothing to plot.")