In [None]:
from pathlib import Path
# Choose data/experiment directories depending on where the notebook runs.
# Only a missing `google.colab` package means "not on Colab"; any other
# failure (e.g. a Drive mount error or a user interrupt) should surface
# instead of silently switching to the local paths.
try:
    from google.colab import drive
except ImportError:
    DATA_DIR = Path("/home/avishkar/Desktop/research/mnist_data")
    EXP_DIR = Path("/home/avishkar/Desktop/research/experiments")
    print("About to get data")
else:
    drive.mount("/content/drive")
    DATA_DIR = Path("/content/mnist_data")
    EXP_DIR = Path("/content/drive/MyDrive/research/experiments")

from torchvision.datasets import MNIST
from torchvision.transforms import transforms
from torch.utils.data import DataLoader

# Convert images to tensors, then normalise with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Both splits are downloaded into DATA_DIR when not already present.
trainset = MNIST(root=DATA_DIR, train=True, download=True, transform=transform)
testset = MNIST(root=DATA_DIR, train=False, download=True, transform=transform)
print(f"Trainset: {len(trainset)} , Testset : {len(testset)}")
classes = trainset.classes

# Only the training loader is shuffled; evaluation order stays fixed.
train_loader = DataLoader(trainset, batch_size=32, num_workers=2, shuffle=True)
test_loader = DataLoader(testset, batch_size=32, num_workers=2, shuffle=False)

# VISUALIZE DATA
import matplotlib.pyplot as plt

# Pull a single batch and display its first image as a sanity check.
imgs, labels = next(iter(train_loader))
print(imgs.shape)
img = imgs[0]
# Move the channel axis last, which is the layout imshow expects.
plt.imshow(img.permute(1, 2, 0).cpu().numpy())
plt.show()

In [None]:
import torch
import torch.nn as nn

class PatchEmbeddings(nn.Module):
    """Turn an image batch into a sequence of patch embeddings with a CLS token.

    The image is split into non-overlapping ``patch_size`` x ``patch_size``
    patches by a strided convolution, a learnable CLS token is prepended,
    learnable position embeddings are added, and dropout is applied.

    Args:
        config: mapping providing ``img_size``, ``embed_dim``, ``patch_size``,
            ``num_channels`` and ``dropout``.
    """

    def __init__(self, config):
        super().__init__()
        self.img_size = config["img_size"]
        self.embed_dim = config["embed_dim"]
        self.patch_size = config["patch_size"]
        self.num_patches = (self.img_size // self.patch_size) ** 2
        self.num_channels = config["num_channels"]

        # kernel_size == stride, so every patch is projected exactly once;
        # Flatten(2) merges the two spatial axes into one patch axis.
        self.patcher = nn.Sequential(
            nn.Conv2d(
                in_channels=self.num_channels,
                out_channels=self.embed_dim,
                kernel_size=self.patch_size,
                stride=self.patch_size,
            ),
            nn.Flatten(2),
        )
        # Exactly one CLS token per sequence. Sizing this axis by the number
        # of input channels (as before) only worked for single-channel images:
        # any other channel count made the sequence longer than the
        # num_patches + 1 positions available in `position_embeddings`.
        self.cls_token = nn.Parameter(
            torch.randn(size=(1, 1, self.embed_dim)),
            requires_grad=True,
        )
        self.position_embeddings = nn.Parameter(
            torch.randn(size=(1, self.num_patches + 1, self.embed_dim)),
            requires_grad=True,
        )
        self.dropout = nn.Dropout(config["dropout"])

    def forward(self, x):
        """Embed a batch of images.

        Args:
            x: tensor of shape (batch, num_channels, img_size, img_size).

        Returns:
            Tensor of shape (batch, num_patches + 1, embed_dim); index 0 along
            the sequence axis is the CLS token.
        """
        # Share the single learnable CLS token across the batch.
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)

        # (batch, embed_dim, num_patches) -> (batch, num_patches, embed_dim)
        x = self.patcher(x).permute(0, 2, 1)
        x = torch.cat([cls_token, x], dim=1)
        x = self.position_embeddings + x
        x = self.dropout(x)

        return x

class ViT(nn.Module):
    """Vision Transformer classifier.

    Pipeline: patch embeddings -> stack of pre-norm transformer encoder
    layers -> LayerNorm + Linear head applied to the CLS token.

    Args:
        config: mapping providing ``embed_dim``, ``num_heads``,
            ``num_classes``, ``dropout`` and ``num_layers``, plus whatever
            ``PatchEmbeddings`` reads from it.
    """

    def __init__(self, config):
        super().__init__()
        self.embed_dim = config["embed_dim"]
        self.num_heads = config["num_heads"]
        self.num_classes = config["num_classes"]

        self.embeddings = PatchEmbeddings(config)

        # Settings shared by every encoder layer in the stack.
        layer_kwargs = dict(
            d_model=self.embed_dim,
            nhead=self.num_heads,
            dropout=config["dropout"],
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        # NOTE(review): nn.TransformerEncoder is documented to clone the layer
        # it is given, so this attribute presumably holds an extra, unused
        # copy of the weights -- confirm before removing it, since doing so
        # would change the state_dict keys of existing checkpoints.
        self.encoder_layer = nn.TransformerEncoderLayer(**layer_kwargs)
        self.encoder_blocks = nn.TransformerEncoder(
            self.encoder_layer, num_layers=config["num_layers"]
        )

        head_norm = nn.LayerNorm(self.embed_dim)
        head_proj = nn.Linear(self.embed_dim, self.num_classes)
        self.mlp_head = nn.Sequential(head_norm, head_proj)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for an image batch."""
        tokens = self.embeddings(x)
        encoded = self.encoder_blocks(tokens)
        # Only the CLS token (sequence index 0) feeds the classifier head.
        cls_repr = encoded[:, 0, :]
        return self.mlp_head(cls_repr)

In [None]:
"""UTILS"""
import json
from pathlib import Path

def save_checkpoint(state_dict, epoch, path):
    """Write a checkpoint file into the directory ``path``.

    The file is named ``vit_cifar10_{epoch}.pth`` and stores a dict with the
    keys ``"epoch"`` and ``"state_dict"``. The directory (including missing
    parents) is created when it does not exist yet.

    Args:
        state_dict: model state to persist.
        epoch: epoch number, stored in the file and used in its name.
        path: target directory, as ``str`` or ``Path``.
    """
    target_dir = Path(path)
    if not target_dir.exists():
        print("Creating folder")
        target_dir.mkdir(parents=True, exist_ok=True)

    checkpoint_file = f"{target_dir}/vit_cifar10_{epoch}.pth"
    payload = {"epoch": epoch, "state_dict": state_dict}
    torch.save(payload, checkpoint_file)
    print(f"model saved at path : {checkpoint_file}")


def load_pretrained(model, path, epoch):
    """Load the weights saved for ``epoch`` into ``model`` and return it.

    Reads ``{path}/vit_cifar10_{epoch}.pth`` and applies its ``"state_dict"``
    entry to ``model`` in place.

    Args:
        model: module whose parameters match the stored state dict.
        path: directory containing the checkpoint, as ``str`` or ``Path``.
        epoch: epoch number that is part of the checkpoint file name.

    Returns:
        The same ``model`` instance, with the loaded weights.
    """
    # Map tensors to CPU while reading so that a checkpoint written on a
    # CUDA machine can still be opened where no GPU is available;
    # load_state_dict then copies the values into the model's own parameters.
    checkpoint = torch.load(f"{path}/vit_cifar10_{epoch}.pth", map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])
    return model

def save_experiment(model, epoch, config, train_losses, test_losses, accuracies, path):
    """Persist an experiment's config, metrics and model weights.

    Everything is written under ``path / config["exp_name"]``:
    ``config.json`` (the config), ``metrics.json`` (losses, accuracies and
    epoch) and a checkpoint written by ``save_checkpoint``.

    Args:
        model: module whose ``state_dict()`` is checkpointed.
        epoch: epoch number stored in the metrics and the checkpoint.
        config: JSON-serialisable mapping; must contain ``"exp_name"``.
        train_losses: JSON-serialisable training losses.
        test_losses: JSON-serialisable test losses.
        accuracies: JSON-serialisable accuracies.
        path: root experiments directory, as ``str`` or ``Path``.
    """
    exp_data = {
        "train_losses": train_losses,
        "test_losses": test_losses,
        "accuracies": accuracies,
        "epoch": epoch,
    }
    exp_name = config["exp_name"]
    # Coerce to Path so that a plain string root works too, as it already
    # does for save_checkpoint.
    exp_dir = Path(path) / f"{exp_name}"
    config_file = exp_dir / "config.json"
    metrics_file = exp_dir / "metrics.json"

    # Report whether each file is being overwritten or newly created.
    for file in (config_file, metrics_file):
        if file.exists():
            print(f"{file} exists")
        else:
            file.parent.mkdir(parents=True, exist_ok=True)
            file.touch()
            print(f"{file} created")

    with open(config_file, "w") as f:
        json.dump(config, f, sort_keys=True, indent=4)
    with open(metrics_file, "w") as f:
        json.dump(exp_data, f, sort_keys=True, indent=4)

    save_checkpoint(model.state_dict(), epoch, exp_dir)

def load_experiment(model, exp_name, path):
    """Restore a saved experiment's metrics and model weights.

    Reads ``path / exp_name / "metrics.json"`` and then loads the checkpoint
    for the epoch recorded there into ``model``.

    Args:
        model: module to receive the stored weights.
        exp_name: name of the experiment sub-directory.
        path: root experiments directory, as ``str`` or ``Path``.

    Returns:
        Tuple ``(model, train_losses, test_losses, accuracies, epoch)``.
    """
    # Coerce to Path so that a plain string root works too.
    exp_dir = Path(path) / f"{exp_name}"
    with open(exp_dir / "metrics.json", "r") as file:
        data = json.load(file)
    train_losses = data["train_losses"]
    test_losses = data["test_losses"]
    accuracies = data["accuracies"]
    epoch = data["epoch"]

    model = load_pretrained(model, exp_dir, epoch)

    return model, train_losses, test_losses, accuracies, epoch
    
    

In [None]:
# Hyperparameters and run settings shared by the model and the training loop.
config = dict(
    # Image / patch geometry.
    img_size=28,
    embed_dim=16,  # (PATCH_SIZE ** 2) * IN_CHANNELS
    patch_size=4,
    dropout=0.01,
    num_channels=1,
    # Transformer encoder and classifier head.
    num_heads=4,
    num_layers=8,
    num_classes=10,
    # Runtime: use the GPU when one is available.
    device="cuda" if torch.cuda.is_available() else "cpu",
    exp_name="vit_mnist_40_epoch",
    num_epoch=40,
)

import time
from tqdm import tqdm


def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader``; train when ``optimizer`` is given, else evaluate.

    Returns:
        Tuple ``(mean_batch_loss, accuracy)``; both are 0.0 for an empty loader.
    """
    training = optimizer is not None
    model.train(training)

    running_loss = 0.0
    num_batches = 0
    num_correct = 0
    num_samples = 0
    # Gradients are only tracked while training.
    with torch.set_grad_enabled(training):
        # Unpack the batch yielded by the loader itself. The previous code
        # read the notebook-level `imgs`/`labels` left over from the
        # visualisation cell, so every step reused that one batch.
        for batch_imgs, batch_labels in tqdm(loader, position=0, leave=True):
            img = batch_imgs.float().to(device)
            # Class-index targets are cast to int64, the dtype documented for
            # CrossEntropyLoss; uint8 targets are not accepted everywhere.
            label = batch_labels.long().to(device)

            y_pred = model(img)
            loss = criterion(y_pred, label)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            running_loss += loss.item()
            num_batches += 1
            # Count hits with tensor ops instead of comparing element by element.
            num_correct += (torch.argmax(y_pred, dim=1) == label).sum().item()
            num_samples += label.numel()

    # Guard the divisions so an empty loader does not raise.
    mean_loss = running_loss / max(num_batches, 1)
    accuracy = num_correct / max(num_samples, 1)
    return mean_loss, accuracy


def main():
    """Train a ViT for ``config["num_epoch"]`` epochs, evaluating after each one.

    Uses the notebook-level ``config``, ``train_loader`` and ``test_loader``.
    Prints per-epoch train/validation loss and accuracy, then the total
    wall-clock training time. Returns nothing.
    """
    device = config["device"]
    model = ViT(config).to(device)
    criterion = nn.CrossEntropyLoss()
    # NOTE(review): lr=0.01 looks high for Adam on a transformer -- confirm.
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-2)

    start = time.time()
    for epoch in tqdm(range(config["num_epoch"]), position=0, leave=True):
        train_loss, train_acc = _run_epoch(model, train_loader, criterion, device, optimizer)
        val_loss, val_acc = _run_epoch(model, test_loader, criterion, device)

        print("-"*30)
        print(f"Train Loss EPOCH {epoch+1}: {train_loss:.4f}")
        print(f"Valid Loss EPOCH {epoch+1}: {val_loss:.4f}")
        print(f"Train Accuracy EPOCH {epoch+1}: {train_acc:.4f}")
        print(f"Valid Accuracy EPOCH {epoch+1}: {val_acc:.4f}")
        print("-"*30)

    stop = time.time()
    print(f"Training Time: {stop-start:.2f}s")

In [None]:
# Run the training/evaluation loop defined by main().
main()