In [1]:
import torch
from torch import nn
import torchvision
import torchvision.transforms.v2 as transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import warnings
import os
import csv
warnings.filterwarnings('ignore')  # NOTE(review): silences *all* warnings, including deprecations
cwd = os.getcwd()  # base directory for the dataset download and checkpoint/plot/CSV outputs
from tqdm import tqdm
from datetime import datetime
from zoneinfo import ZoneInfo
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

In [2]:
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR
from torch.optim import Adam,SGD,AdamW
import matplotlib.pyplot as plt

In [3]:
def get_timestamp(tz_name="Asia/Kolkata"):
    """Return the current time as a ``DD_MM_HH_MM`` string for use in file names.

    Args:
        tz_name: IANA time-zone name the timestamp is expressed in.

    Returns:
        str: day_month_hour_minute, e.g. ``"05_03_14_07"``.

    If the zone cannot be resolved (unknown name, or no tz database installed),
    a message is printed and the naive local time is used instead.
    """
    try:
        now = datetime.now(ZoneInfo(tz_name))
    except Exception:
        # Narrowed from a bare `except:` so Ctrl-C / SystemExit are not swallowed.
        print("Couldn't get Zone Info...")
        now = datetime.now()
    return now.strftime("%d_%m_%H_%M")

In [4]:
# Prefer the GPU when one is visible to torch; everything else falls back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
device

'cuda'

### Dataset Initialization

In [5]:
# Per-channel mean/std used for normalization (the commonly quoted CIFAR-10 statistics);
# shared by both pipelines so train and test inputs are normalized identically.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.247, 0.243, 0.261)

# Training: random horizontal flip + padded random crop, then float conversion
# (scaled to [0, 1]) and normalization.
train_transform = transforms.Compose([
    transforms.ToImage(),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomCrop(size=32, padding=4),
    transforms.ToDtype(torch.float32, scale=True),
    transforms.Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD),
])

# Evaluation: no augmentation, same conversion and normalization.
test_transform = transforms.Compose([
    transforms.ToImage(),
    transforms.ToDtype(torch.float32, scale=True),
    transforms.Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD),
])

In [6]:
from torch.utils.data import default_collate

# For every training batch, RandomChoice picks CutMix or MixUp at random. Both mix the
# targets into 2-D float soft labels, which the Trainer's training step handles explicitly.
cutmix = transforms.CutMix(num_classes=10)
mixup = transforms.MixUp(num_classes=10)
cutmix_or_mixup = transforms.RandomChoice([cutmix, mixup])


def collate_fn(batch):
    """Stack samples with `default_collate`, then apply the randomly chosen CutMix/MixUp."""
    collated = default_collate(batch)
    return cutmix_or_mixup(*collated)

In [None]:
# Train dataset
train_dataset = torchvision.datasets.CIFAR10(
    root=os.path.join(cwd,"./data"),
    train=True,
    download=True,
    transform=train_transform
)

# Test dataset
test_dataset = torchvision.datasets.CIFAR10(
    root=os.path.join(cwd,"./data"),
    train=False,
    download=True,
    transform=test_transform
)

In [None]:
from torch.utils.data import Subset

# Set an int to use only the first N samples of a split (quick smoke runs);
# None uses the full split.
TRAIN_LIMIT = None
TEST_LIMIT = None

train_subset = Subset(train_dataset, range(TRAIN_LIMIT)) if TRAIN_LIMIT is not None else train_dataset
test_subset = Subset(test_dataset, range(TEST_LIMIT)) if TEST_LIMIT is not None else test_dataset

In [9]:
BATCH_SIZE = 256

# Training batches are shuffled and pass through CutMix/MixUp via `collate_fn`;
# evaluation batches use the default collation (integer class labels).
train_loader = DataLoader(
    train_subset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
    collate_fn=collate_fn,
)
test_loader = DataLoader(
    test_subset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
)

### Model Definition: ViT with Shifted Patch Tokenization

In [10]:
class FeedForward(nn.Module):
    """Pre-norm MLP sublayer: LayerNorm -> Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        feature_dim: width of the input and output features.
        hidden_dim: width of the expanded hidden layer.
        dropout: dropout probability applied after the activation and after the output.
    """

    def __init__(self, feature_dim, hidden_dim, dropout = 0.):
        super().__init__()
        layers = [
            nn.LayerNorm(feature_dim),
            nn.Linear(feature_dim, hidden_dim),   # expand
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, feature_dim),   # project back
            nn.Dropout(dropout),
        ]
        # Kept as one nn.Sequential named `net` so state_dict keys (net.0 ... net.5) match.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out

class Attention(nn.Module):
    """Pre-norm multi-head self-attention with softmax(Q K^T * dim_head**-0.5) V.

    Args:
        feature_dim: token width of the input and output.
        heads: number of attention heads.
        dim_head: width of each head; the attention inner width is ``heads * dim_head``.
        dropout: dropout on the attention weights and on the output projection.
    """

    def __init__(self, feature_dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = heads * dim_head

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(feature_dim)
        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)
        # A single bias-free projection produces Q, K and V concatenated on the last axis.
        self.to_qkv = nn.Linear(feature_dim, 3 * inner_dim, bias = False)
        self.to_out = nn.Sequential(nn.Linear(inner_dim, feature_dim), nn.Dropout(dropout))

    def forward(self, x):
        x = self.norm(x)
        batch, tokens, _ = x.shape

        # (b, n, h*d) -> (b, h, n, d) for each of Q, K, V.
        q, k, v = (
            t.reshape(batch, tokens, self.heads, -1).transpose(1, 2)
            for t in self.to_qkv(x).chunk(3, dim = -1)
        )

        scores = (q @ k.transpose(-2, -1)) * self.scale
        weights = self.dropout(self.attend(scores))

        # (b, h, n, d) -> (b, n, h*d) before the output projection.
        context = (weights @ v).transpose(1, 2).reshape(batch, tokens, -1)
        return self.to_out(context)

class ShiftedPatchTokenization(nn.Module):
    """Patch embedding over the image stacked with four one-pixel-shifted copies.

    The input (expected as ``(batch, channels, H, W)``, per the rearrange pattern) is
    concatenated on the channel axis with copies shifted right, left, down and up.
    The 5x-channel tensor is cut into ``patch_size x patch_size`` patches, which are
    flattened, layer-normed and projected to ``feature_dim``.
    """

    def __init__(self, *, feature_dim, patch_size, channels = 3):
        super().__init__()
        stacked_channels = 5 * channels  # original + four shifted copies
        patch_dim = patch_size * patch_size * stacked_channels

        self.to_patch_tokens = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, feature_dim)
        )

    def forward(self, x):
        # F.pad order is (left, right, top, bottom); a negative pad crops, so each
        # tuple moves the image by one pixel: right, left, down, up.
        offsets = ((1, -1, 0, 0), (-1, 1, 0, 0), (0, 0, 1, -1), (0, 0, -1, 1))
        views = [x] + [F.pad(x, offset) for offset in offsets]
        return self.to_patch_tokens(torch.cat(views, dim = 1))

class Transformer(nn.Module):
    """Stack of ``depth`` residual blocks (attention, then feed-forward) plus a final LayerNorm.

    Each sublayer normalizes its own input, so the residual stream itself is only
    normalized once, at the end.
    """

    def __init__(self, feature_dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.norm = nn.LayerNorm(feature_dim)
        blocks = []
        for _ in range(depth):
            attention = Attention(feature_dim, heads = heads, dim_head = dim_head, dropout = dropout)
            feed_forward = FeedForward(feature_dim, mlp_dim, dropout = dropout)
            blocks.append(nn.ModuleList([attention, feed_forward]))
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        for attention, feed_forward in self.layers:
            x = x + attention(x)
            x = x + feed_forward(x)
        return self.norm(x)

class ViT(nn.Module):
    """ViT classifier: shifted-patch tokens + CLS token -> Transformer -> linear head.

    Args:
        image_size: input side length, or an ``(height, width)`` pair for non-square
            inputs; both sides must be divisible by ``patch_size``.
        patch_size: side length of the square patches.
        num_classes: number of output logits.
        feature_dim: token embedding width.
        depth: number of transformer blocks.
        heads: attention heads per block.
        mlp_dim: hidden width of each feed-forward sublayer.
        channels: number of input image channels.
        dim_head: per-head width inside attention (inner width is ``heads * dim_head``).
        dropout: dropout rate inside the transformer blocks.
        emb_dropout: dropout rate on the token sequence after adding positions.
    """

    def __init__(self, *, image_size, patch_size, num_classes, feature_dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        if isinstance(image_size, (tuple, list)):
            image_height, image_width = image_size
        else:
            image_height, image_width = (image_size, image_size)
        # ShiftedPatchTokenization only supports square patches.
        patch_height, patch_width = (patch_size, patch_size)

        assert (image_height % patch_height == 0) and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)

        self.to_patch_embedding = ShiftedPatchTokenization(
            feature_dim=feature_dim,
            patch_size=patch_size,
            channels=channels
        )

        # Learned positions for the CLS token plus every patch; initialized from N(0, 1).
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, feature_dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, feature_dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(feature_dim, depth, heads, dim_head, mlp_dim, dropout)

        self.mlp_head = nn.Linear(feature_dim, num_classes)

    def forward(self, img):
        """Return logits of shape ``(batch, num_classes)``.

        ``img`` is expected as ``(batch, channels, height, width)``, the layout the
        patch tokenizer's rearrange pattern requires.
        """
        tokens = self.to_patch_embedding(img)
        b, n, _ = tokens.shape

        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
        x = torch.cat((cls_tokens, tokens), dim=1)
        x = x + self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        # Classify from the CLS token (position 0) only.
        return self.mlp_head(x[:, 0])

In [11]:
# model = ViT(
#     image_size=32,
#     num_classes=10,
#     patch_size=4,
#     feature_dim=256,
#     mlp_dim=512,
#     depth=6,
#     heads=4,
#     dropout=0.3,
#     emb_dropout=0.3,
#     )

In [12]:
# print(model)

In [13]:
# from torchinfo import summary
# summary(model,input_size=(64,3,32,32))

### Training Loop

In [14]:
class SimpleLogger:
    """In-memory per-epoch metric log with PNG-plot and CSV export.

    Metrics live in parallel lists (``epoch``, ``train_loss``, ``train_acc``,
    ``eval_loss``, ``eval_acc``); index ``i`` is the ``i``-th logged epoch.
    """

    def __init__(self) -> None:
        self.epoch = []
        self.train_loss = []
        self.train_acc = []
        self.eval_loss = []
        self.eval_acc = []

    def log(self, metrics):
        """Append one epoch of metrics.

        Args:
            metrics: mapping with keys ``'epoch'``, ``'train_loss'``, ``'train_acc'``,
                ``'eval_loss'`` and ``'eval_acc'``.

        Raises:
            KeyError: if a key is missing. Every key is read before anything is
                appended, so a bad record never leaves the lists with unequal lengths.
        """
        epoch = metrics['epoch']
        train_loss = metrics['train_loss']
        train_acc = metrics['train_acc']
        eval_loss = metrics['eval_loss']
        eval_acc = metrics['eval_acc']

        self.epoch.append(epoch)
        self.train_loss.append(train_loss)
        self.train_acc.append(train_acc)
        self.eval_loss.append(eval_loss)
        self.eval_acc.append(eval_acc)

    def plot_metrics(self, save_path=None, display=False):
        """Plot loss and accuracy curves side by side.

        Args:
            save_path: base path; the figure is written to ``save_path + '.png'``.
            display: if True, show the figure with ``plt.show()``.

        Returns:
            None. Prints a message and returns early when nothing has been logged.

        Raises:
            AttributeError: if neither ``save_path`` nor ``display`` is given (type
                kept for backward compatibility). Checked before any figure is created.
        """
        if not self.epoch:
            print("No metrics to plot. Train the model first.")
            return
        if save_path is None and not display:
            raise AttributeError('Must provide either save path or option to display.')

        fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(14, 5))
        self._plot_curves(loss_ax, self.train_loss, self.eval_loss, 'Loss',
                          'Training and Evaluation Loss')
        self._plot_curves(acc_ax, self.train_acc, self.eval_acc, 'Accuracy',
                          'Training and Evaluation Accuracy')
        fig.tight_layout()

        if save_path:
            png_path = save_path + ".png"
            fig.savefig(png_path, dpi=300, bbox_inches='tight')
            print(f"Plot saved to {png_path}")
        if display:
            plt.show()
        else:
            # Release the figure; otherwise every call leaves another figure open.
            plt.close(fig)

    def _plot_curves(self, ax, train_values, eval_values, metric, title):
        """Draw train/eval curves of one metric against ``self.epoch`` on ``ax``."""
        ax.plot(self.epoch, train_values, label=f'Train {metric}', marker='o', linewidth=2)
        ax.plot(self.epoch, eval_values, label=f'Eval {metric}', marker='s', linewidth=2)
        ax.set_xlabel('Epochs', fontsize=12)
        ax.set_ylabel(metric, fontsize=12)
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.legend(loc='best')
        ax.grid(True, alpha=0.3)

    def store_log_as_csv(self, save_path=None):
        """Write the log to ``save_path + '.csv'`` with a header row.

        Raises:
            ValueError: if ``save_path`` is None.
        """
        if save_path is None:
            raise ValueError("Save path cannot be None")

        header = ["epoch", "train_loss", "train_acc", "eval_loss", "eval_acc"]
        rows = zip(self.epoch, self.train_loss, self.train_acc, self.eval_loss, self.eval_acc)

        with open(save_path + '.csv', "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(header)
            writer.writerows(rows)

In [15]:
class Trainer:
    """Epoch-based train/evaluate loop with best-accuracy checkpointing and resume.

    ``ckpt_path`` is a base path without an extension. The checkpoint goes to
    ``ckpt_path + '.pt'``. When a logger is attached, the metrics plot and CSV go
    to ``ckpt_path + '.png'`` and ``ckpt_path + '.csv'``.
    """

    def __init__(
        self,
        model,
        train_loader,
        test_loader,
        optimizer,
        loss_fn,
        scheduler=None,
        logger=None,
        ckpt_path="checkpoint",
        device=None,
    ):
        """
        Args:
            model: nn.Module to train; it is moved to ``device``.
            train_loader: sized iterable of ``(inputs, targets)`` batches. Targets may
                be class indices or 2-D float soft labels (e.g. from CutMix/MixUp).
            test_loader: sized iterable of ``(inputs, class_index_targets)`` batches.
            optimizer: optimizer over ``model``'s parameters.
            loss_fn: loss for hard-label training batches and for evaluation.
            scheduler: optional LR scheduler, stepped once per epoch.
            logger: optional SimpleLogger-like object (``log``, ``plot_metrics``,
                ``store_log_as_csv``); None disables logging/export.
            ckpt_path: base path (no extension) for checkpoint/plot/CSV files.
            device: torch device string; None picks ``"cuda"`` when available, else ``"cpu"``.

        If ``ckpt_path + '.pt'`` exists, model/optimizer/scheduler state is restored from it.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = model
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        self.scheduler = scheduler
        self.logger = logger  # SimpleLogger-like object, or None
        self.ckpt_path = ckpt_path
        self.device = device
        self.model.to(self.device)
        self.best_acc = 0.0
        self.start_epoch = 0
        self._load_checkpoint()

    def _save_checkpoint(self, epoch, acc):
        """Save training state to ``ckpt_path + '.pt'``.

        ``epoch`` is the epoch just completed and ``next_epoch`` is where a resumed
        run starts. The scheduler must already be stepped for ``epoch`` so its
        saved state matches ``next_epoch``.
        """
        state = {
            "epoch": epoch,
            "next_epoch": epoch + 1,
            "model_state": self.model.state_dict(),
            "optimizer_state": self.optimizer.state_dict(),
            "scheduler_state": self.scheduler.state_dict() if self.scheduler else None,
            "best_acc": acc,
        }
        torch.save(state, self.ckpt_path + '.pt')

    def _load_checkpoint(self):
        """Restore state from ``ckpt_path + '.pt'`` if that file exists."""
        path = self.ckpt_path + '.pt'
        if not os.path.isfile(path):
            return
        state = torch.load(path, map_location=self.device)
        self.model.load_state_dict(state["model_state"])
        self.optimizer.load_state_dict(state["optimizer_state"])
        if self.scheduler and state.get("scheduler_state"):
            self.scheduler.load_state_dict(state["scheduler_state"])
        self.best_acc = state.get("best_acc", 0.0)
        # Checkpoints without "next_epoch" were saved before the scheduler step, so
        # resuming at their "epoch" keeps the LR schedule consistent for them.
        self.start_epoch = state.get("next_epoch", state.get("epoch", 0))

    def train(self, num_epochs):
        """Run epochs ``start_epoch .. num_epochs - 1`` and return the best eval accuracy."""
        for epoch in range(self.start_epoch, num_epochs):
            print(f'Epoch {epoch}/{num_epochs}')
            train_loss, train_acc = self._train_one_epoch(epoch)
            eval_loss, eval_acc = self._evaluate(epoch)
            if self.logger:
                self.logger.log({
                    "epoch": epoch,
                    "train_loss": train_loss,
                    "train_acc": train_acc,
                    "eval_loss": eval_loss,
                    "eval_acc": eval_acc
                })
            # Step before checkpointing so the saved scheduler state belongs to the next epoch.
            if self.scheduler:
                self.scheduler.step()
            if eval_acc > self.best_acc:
                self.best_acc = eval_acc
                print("Saving model...")
                self._save_checkpoint(epoch, eval_acc)
            print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f} | Best Eval Acc: {self.best_acc:.2f}")

        if self.logger:
            self.logger.plot_metrics(save_path = self.ckpt_path)
            self.logger.store_log_as_csv(self.ckpt_path)
        return self.best_acc

    def _train_one_epoch(self, epoch):
        """Train for one pass over ``train_loader``; return ``(mean_loss, accuracy)``."""
        self.model.train()
        total_loss = 0.0
        total_correct = 0
        total_samples = 0

        pbar = tqdm(self.train_loader, total=len(self.train_loader), desc="Training: ")
        for batch in pbar:
            x, y = batch[0].to(self.device), batch[1].to(self.device)
            # CutMix/MixUp batches carry (batch, num_classes) float targets instead of indices.
            soft_labels = torch.is_floating_point(y) and y.dim() == 2

            self.optimizer.zero_grad()
            outputs = self.model(x)
            if soft_labels:
                # Cross-entropy against the mixed target distribution.
                loss = -(y * F.log_softmax(outputs, dim=1)).sum(dim=1).mean()
            else:
                loss = self.loss_fn(outputs, y)
            loss.backward()
            self.optimizer.step()

            batch_size = x.size(0)
            total_loss += loss.item() * batch_size
            # With mixed targets, a prediction counts as correct if it hits the dominant class.
            targets = y.argmax(dim=1) if soft_labels else y
            total_correct += (outputs.argmax(1) == targets).sum().item()
            total_samples += batch_size

        avg_loss = total_loss / total_samples
        avg_acc = total_correct / total_samples
        return avg_loss, avg_acc

    def _evaluate(self, epoch):
        """Evaluate on ``test_loader`` without gradients; return ``(mean_loss, accuracy)``."""
        self.model.eval()
        total_loss = 0.0
        total_correct = 0
        total_samples = 0

        with torch.no_grad():
            pbar = tqdm(self.test_loader, total=len(self.test_loader), desc="Evaluating: ")
            for batch in pbar:
                x, y = batch[0].to(self.device), batch[1].to(self.device)
                outputs = self.model(x)
                loss = self.loss_fn(outputs, y)
                total_loss += loss.item() * x.size(0)
                total_correct += (outputs.argmax(1) == y).sum().item()
                total_samples += x.size(0)
        avg_loss = total_loss / total_samples
        avg_acc = total_correct / total_samples
        return avg_loss, avg_acc

In [None]:
# Seed the global torch RNG so weight init, data shuffling and CutMix/MixUp draws
# repeat across runs (GPU kernels may still be nondeterministic).
SEED = 42
torch.manual_seed(SEED)

# NOTE(review): heads * dim_head (ViT default 64) = 256, narrower than feature_dim=768,
# so attention runs in a 256-d inner space - confirm this is intended.
model = ViT(
    image_size=32,
    num_classes=10,
    patch_size=4,
    feature_dim=768,
    mlp_dim=768 * 2,
    depth=12,
    heads=4,
    dropout=0.3,
    emb_dropout=0.3,
)

optimizer = AdamW(
    model.parameters(),
    lr=6e-4,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0.3
)

loss_fn = nn.CrossEntropyLoss()

num_epochs = 200
warmup_epochs = 30

# Per-epoch LR schedule: linear warmup from 1% of the base LR over `warmup_epochs`,
# then cosine decay towards 1e-6 over the remaining epochs.
warmup_scheduler = LinearLR(
    optimizer,
    start_factor=0.01,
    end_factor=1.0,
    total_iters=warmup_epochs
)

cosine_scheduler = CosineAnnealingLR(
    optimizer,
    T_max=num_epochs - warmup_epochs,
    eta_min=1e-6
)

scheduler = SequentialLR(
    optimizer,
    schedulers=[warmup_scheduler, cosine_scheduler],
    milestones=[warmup_epochs]
)

logger = SimpleLogger()

# Checkpoint, plot and CSV share this base name (the Trainer appends .pt / .png / .csv).
save_name = f"checkpoint_{get_timestamp()}_SPT"
print(f'Saving to {save_name}')
trainer = Trainer(
    model=model,
    train_loader=train_loader,
    test_loader=test_loader,
    optimizer=optimizer,
    loss_fn=loss_fn,
    scheduler=scheduler,
    logger=logger,
    ckpt_path=os.path.join(cwd, save_name),
    device=device,  # reuse the device chosen at the top of the notebook
)

best_acc = trainer.train(num_epochs=num_epochs)

print(f'Best Acc: {best_acc}')