# In [None]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.optim import AdamW
import os
import matplotlib.pyplot as plt

# Hyperparameters
exp_name = 'vit-with-10-epochs'
batch_size = 32
epochs = 10
lr = 1e-5
save_model_every = 0  # checkpoint every N epochs; 0 disables periodic checkpoints
device = "cuda" if torch.cuda.is_available() else "cpu"

config = {
    "patch_size": 7,
    "hidden_size": 64,
    "num_hidden_layers": 4,
    "num_attention_heads": 4,
    "intermediate_size": 4 * 64,  # 4 * hidden_size
    "hidden_dropout_prob": 0.0,
    "attention_probs_dropout_prob": 0.0,
    "initializer_range": 0.02,
    "image_size": 28,  # MNIST image size
    "num_classes": 10,  # Number of classes in MNIST
    "num_channels": 1,  # MNIST images are grayscale (1 channel)
    "qkv_bias": True,
    "use_faster_attention": True,
}

# Validate the configuration with explicit exceptions rather than `assert`:
# assert statements are stripped when Python runs with -O, which would let an
# invalid configuration through silently.
if config["hidden_size"] % config["num_attention_heads"] != 0:
    raise ValueError("hidden_size must be divisible by num_attention_heads")
if config["intermediate_size"] != 4 * config["hidden_size"]:
    raise ValueError("intermediate_size must equal 4 * hidden_size")
if config["image_size"] % config["patch_size"] != 0:
    raise ValueError("image_size must be divisible by patch_size")


class ViTForClassification(nn.Module):
    """
    Simplified Vision Transformer (ViT) for image classification.

    Pipeline: split each image into non-overlapping square patches -> linear
    patch embedding -> BatchNorm + dropout -> one multi-head self-attention
    layer -> BatchNorm + dropout -> mean pooling over patches -> linear head.

    Regularization inside the module is dropout (p=0.3). L2 regularization
    (weight decay), if any, has to come from the optimizer.

    NOTE(review): there are no positional embeddings, CLS token or residual
    connections, so the patches are processed as an unordered set. Adding them
    would change the parameter set (and checkpoint keys), so they are left out.

    Config keys read: patch_size, image_size, hidden_size, num_attention_heads,
    num_classes, and optionally num_channels (default 1), qkv_bias (default
    True) and attention_probs_dropout_prob (default 0.0).
    """

    def __init__(self, config):
        super().__init__()
        self.patch_size = config['patch_size']
        self.num_channels = config.get('num_channels', 1)
        self.num_patches = (config['image_size'] // self.patch_size) ** 2
        # Each patch is flattened to num_channels * patch_size**2 values.
        self.embedding = nn.Linear(self.num_channels * self.patch_size ** 2, config['hidden_size'])
        # batch_first=True because the tokens are laid out as
        # (batch, seq_len, hidden_size). With the default (batch_first=False)
        # the batch axis would be treated as the sequence, and different images
        # would attend to each other.
        self.attention = nn.MultiheadAttention(
            config['hidden_size'],
            config['num_attention_heads'],
            dropout=config.get('attention_probs_dropout_prob', 0.0),
            bias=config.get('qkv_bias', True),
            batch_first=True,
        )
        self.fc = nn.Linear(config['hidden_size'], config['num_classes'])

        # Batch normalization over the hidden_size channels
        self.batch_norm1 = nn.BatchNorm1d(config['hidden_size'])
        self.batch_norm2 = nn.BatchNorm1d(config['hidden_size'])

        self.dropout = nn.Dropout(p=0.3)  # 30% dropout for regularization

    def _to_patches(self, x):
        """
        Split images into flattened, non-overlapping patch_size x patch_size patches.

        Args:
            x: (batch, channels, H, W) tensor; a 3-D (batch, H, W) tensor is
               treated as single-channel.

        Returns:
            (batch, num_patches, channels * patch_size**2) tensor, with patches
            ordered row-major over the patch grid.
        """
        if x.dim() == 3:  # (batch, H, W): add the missing channel axis
            x = x.unsqueeze(1)
        p = self.patch_size
        batch_size, channels = x.shape[:2]
        # (B, C, H, W) -> (B, C, H/p, W/p, p, p)
        patches = x.unfold(2, p, p).unfold(3, p, p)
        # -> (B, H/p, W/p, C, p, p) -> (B, num_patches, C * p * p)
        patches = patches.permute(0, 2, 3, 1, 4, 5)
        return patches.reshape(batch_size, -1, channels * p * p)

    def forward(self, x):
        """
        Classify a batch of images.

        Args:
            x: image batch, expected shape (batch, num_channels, image_size, image_size).

        Returns:
            (logits, pooled): logits of shape (batch, num_classes), and the
            mean-pooled features of shape (batch, hidden_size) fed to the head.
        """
        x = self.embedding(self._to_patches(x))  # (batch, seq_len, hidden_size)

        # BatchNorm1d normalizes dim 1, so move hidden_size there and back.
        x = self.batch_norm1(x.transpose(1, 2)).transpose(1, 2)
        x = self.dropout(x)

        # Self-attention over the patches of each image (batch_first layout).
        x, _ = self.attention(x, x, x, need_weights=False)

        x = self.batch_norm2(x.transpose(1, 2)).transpose(1, 2)
        x = self.dropout(x)

        x = x.mean(dim=1)  # mean pooling over the sequence of patches
        logits = self.fc(x)
        return logits, x


class Trainer:
    """
    Simple training loop.

    Trains a model, evaluates it after every epoch, records the per-epoch
    metrics, optionally saves checkpoints, and plots the curves at the end.

    The model's forward pass must return a pair whose first element is the
    logits; the second element is ignored.
    """

    def __init__(self, model, optimizer, loss_fn, exp_name, device):
        self.model = model.to(device)
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        self.exp_name = exp_name  # used as the checkpoint sub-directory name
        self.device = device
        # Per-epoch history, appended to by train()
        self.train_losses = []
        self.test_losses = []
        self.accuracies = []

    def train(self, trainloader, testloader, epochs, save_model_every_n_epochs=0):
        """
        Train for `epochs` epochs, evaluating on `testloader` after each one.

        Args:
            trainloader: iterable of (images, labels) batches used for training.
            testloader: iterable of (images, labels) batches used for evaluation.
            epochs: number of epochs to run.
            save_model_every_n_epochs: save a checkpoint every N epochs,
                including the final epoch when it is a multiple of N.
                0 (or negative) disables checkpointing.
        """
        for epoch in range(1, epochs + 1):
            train_loss = self.train_epoch(trainloader)
            accuracy, test_loss = self.evaluate(testloader)
            self.train_losses.append(train_loss)
            self.test_losses.append(test_loss)
            self.accuracies.append(accuracy)
            print(f"Epoch: {epoch}, Train loss: {train_loss:.4f}, Test loss: {test_loss:.4f}, Accuracy: {accuracy:.4f}")
            # The last epoch is checkpointed too. No other method in this class
            # persists the trained weights (save_experiment is only a stub), so
            # skipping it would lose the final model.
            if save_model_every_n_epochs > 0 and epoch % save_model_every_n_epochs == 0:
                print('\tSave checkpoint at epoch', epoch)
                self.save_checkpoint(epoch)
        self.plot_metrics()

    def train_epoch(self, trainloader):
        """
        Run one optimization pass over `trainloader`.

        Returns:
            The mean training loss per sample over the whole dataset.
        """
        self.model.train()
        total_loss = 0.0
        for batch in trainloader:
            images, labels = [t.to(self.device) for t in batch]
            self.optimizer.zero_grad()
            loss = self.loss_fn(self.model(images)[0], labels)
            loss.backward()
            self.optimizer.step()
            # loss is a batch mean; weight it by batch size for a per-sample average
            total_loss += loss.item() * len(images)
        return total_loss / len(trainloader.dataset)

    @torch.no_grad()
    def evaluate(self, testloader):
        """
        Evaluate the model on `testloader` without tracking gradients.

        Returns:
            (accuracy, avg_loss): the fraction of correctly classified samples,
            and the mean loss per sample over the whole dataset.
        """
        self.model.eval()
        total_loss = 0.0
        correct = 0
        # Gradients are already disabled by the @torch.no_grad() decorator.
        for batch in testloader:
            images, labels = [t.to(self.device) for t in batch]
            logits, _ = self.model(images)
            loss = self.loss_fn(logits, labels)
            total_loss += loss.item() * len(images)
            predictions = torch.argmax(logits, dim=1)
            correct += torch.sum(predictions == labels).item()
        accuracy = correct / len(testloader.dataset)
        avg_loss = total_loss / len(testloader.dataset)
        return accuracy, avg_loss

    def save_checkpoint(self, epoch):
        """Save the model's state_dict to ./checkpoints/<exp_name>/epoch_<epoch>.pth."""
        checkpoint_dir = f"./checkpoints/{self.exp_name}"
        os.makedirs(checkpoint_dir, exist_ok=True)
        checkpoint_path = os.path.join(checkpoint_dir, f"epoch_{epoch}.pth")
        torch.save(self.model.state_dict(), checkpoint_path)
        print(f"Checkpoint saved to {checkpoint_path}")

    def save_experiment(self, train_losses, test_losses, accuracies):
        """
        Placeholder for persisting experiment results.

        It currently only prints a message; its arguments are unused and
        nothing is written to disk.
        """
        print("Saving experiment details...")
        # TODO: implement saving experiment results, model, etc.

    def plot_metrics(self):
        """Plot train/test loss and test accuracy against the epoch number."""
        epochs = range(1, len(self.train_losses) + 1)

        # Left panel: loss vs epoch
        plt.figure(figsize=(12, 6))
        plt.subplot(1, 2, 1)
        plt.plot(epochs, self.train_losses, label='Train Loss')
        plt.plot(epochs, self.test_losses, label='Test Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title('Train and Test Loss vs Epoch')
        plt.legend()

        # Right panel: accuracy vs epoch
        plt.subplot(1, 2, 2)
        plt.plot(epochs, self.accuracies, label='Accuracy', color='orange')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        plt.title('Accuracy vs Epoch')
        plt.legend()

        plt.tight_layout()
        plt.show()


def prepare_data(batch_size, root='./data'):
    """
    Build the MNIST train/test DataLoaders.

    The training split is augmented with small rotations and a slight random
    shift/scale. The test split is only converted to tensors and normalized.
    Horizontal flips are deliberately not used: mirrored digits are not valid
    digits and never occur in the test set.

    Args:
        batch_size: number of samples per batch for both loaders.
        root: directory where MNIST is stored / downloaded (default './data').

    Returns:
        (trainloader, testloader). The train loader shuffles; the test loader does not.
    """
    # Maps ToTensor's [0, 1] range to [-1, 1].
    normalize = transforms.Normalize((0.5,), (0.5,))

    train_transform = transforms.Compose([
        transforms.RandomRotation(10),  # rotate by up to +/-10 degrees
        # RandomAffine's positional argument is only a rotation range, so pass
        # translate/scale explicitly to get the intended slight shift/scale.
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
        transforms.ToTensor(),
        normalize,
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    train_dataset = datasets.MNIST(root=root, train=True, download=True, transform=train_transform)
    test_dataset = datasets.MNIST(root=root, train=False, download=True, transform=test_transform)

    trainloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    testloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return trainloader, testloader


def main():
    """Train the ViT on MNIST using the module-level hyperparameters and config."""
    train_loader, test_loader = prepare_data(batch_size=batch_size)
    vit = ViTForClassification(config)

    # AdamW's weight_decay supplies the L2-style regularization for the run.
    adamw = AdamW(model_params := vit.parameters(), lr=lr, weight_decay=1e-4) if False else AdamW(vit.parameters(), lr=lr, weight_decay=1e-4)
    criterion = nn.CrossEntropyLoss()

    runner = Trainer(vit, adamw, criterion, exp_name, device=device)
    runner.train(
        train_loader,
        test_loader,
        epochs,
        save_model_every_n_epochs=save_model_every,
    )


if __name__ == '__main__':
    # Only train when run as a script, not when imported.
    main()
