In [3]:
!git clone https://github.com/WouterBant/GEVit-DL2-Project.git

fatal: destination path 'GEVit-DL2-Project' already exists and is not an empty directory.


In [4]:
# Make the cloned repository the working directory; later cells build paths
# relative to it (os.getcwd(), "../data").
%cd GEVit-DL2-Project/

/content/GEVit-DL2-Project


In [5]:
# Enable IPython's autoreload extension in mode 2, which reloads imported
# modules before executing code so that edits to source files are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [6]:
import os
import sys

# Put the parent of the current working directory on the module search path.
# The membership check keeps the entry from being appended again when the
# cell is re-run from the same working directory.
g_selfatt_source = os.path.join(os.getcwd(), os.pardir)
if sys.path.count(g_selfatt_source) == 0:
    sys.path.append(g_selfatt_source)

In [7]:
!pip install einops



In [8]:
import torch
from g_selfatt.utils import num_params
from torchvision import transforms
from datasets import MNIST_rot
import matplotlib.pyplot as plt
import torch.nn as nn

In [9]:
# Normalisation statistics (one value per channel; a single channel here).
data_mean = (0.1307,)
data_stddev = (0.3081,)

# Training pipeline. NOTE: despite its name, `train_test` is the transform that
# is applied to the *training* set below.
train_test = transforms.Compose(
    [
        transforms.RandomRotation(degrees=(-180, 180)),  # random angle in [-180, +180] degrees
        transforms.RandomHorizontalFlip(),  # horizontal flip with probability 0.5 (torchvision default)
        transforms.RandomVerticalFlip(),  # vertical flip with probability 0.5 (torchvision default)
        transforms.ToTensor(),
        transforms.Normalize(data_mean, data_stddev),
    ]
)

# Evaluation pipeline: no augmentation, only tensor conversion and normalisation.
transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(data_mean, data_stddev),
    ]
)

# Both splits use the full data (data_fraction=1) and all classes
# (only_3_and_8=False); they differ only in stage and transform.
train_set = MNIST_rot(
    root="../data",
    stage="train",
    download=True,
    transform=train_test,
    data_fraction=1,
    only_3_and_8=False,
)
test_set = MNIST_rot(
    root="../data",
    stage="test",
    download=True,
    transform=transform_test,
    data_fraction=1,
    only_3_and_8=False,
)

# Only the training loader shuffles; the test loader keeps a fixed order.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True, num_workers=4)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False, num_workers=4)



In [18]:
# https://lightning.ai/docs/pytorch/stable/notebooks/course_UvA-DL/11-vision-transformer.html
def img_to_patch(x, patch_size, flatten_channels=True):
    """Cut a batch of images into non-overlapping square patches.

    Args:
        x: Image tensor of shape [B, C, H, W]. H and W are expected to be
            multiples of ``patch_size`` (the reshape fails otherwise).
        patch_size: Side length of each patch in pixels (integer).
        flatten_channels: If True, every patch is flattened into one feature
            vector; if False, patches keep their [C, p_H, p_W] image layout.

    Returns:
        Tensor of shape [B, H'*W', C*p_H*p_W] when ``flatten_channels`` is
        True, otherwise [B, H'*W', C, p_H, p_W], where H' = H // patch_size
        and W' = W // patch_size. Patches are ordered row by row.
    """
    batch, channels, height, width = x.shape
    rows = height // patch_size
    cols = width // patch_size

    # Split both spatial axes into (grid index, offset inside the patch).
    grid = x.reshape(batch, channels, rows, patch_size, cols, patch_size)
    # Bring the grid axes forward: [B, H', W', C, p_H, p_W], then merge them.
    patches = grid.permute(0, 2, 4, 1, 3, 5).flatten(1, 2)

    if not flatten_channels:
        return patches
    return patches.flatten(2, 4)

class AttentionBlock(nn.Module):
    """Transformer encoder block: pre-norm self-attention followed by a
    pre-norm feed-forward network, each wrapped in a residual connection."""

    def __init__(self, embed_dim, hidden_dim, num_heads, dropout=0.0):
        """Build the block.

        Args:
            embed_dim: Size of the input and of the attention feature vectors.
            hidden_dim: Width of the hidden layer in the feed-forward network
                (usually 2-4x larger than embed_dim).
            num_heads: Number of heads in the multi-head attention layer.
            dropout: Dropout rate used inside the feed-forward network.
        """
        super().__init__()

        self.layer_norm_1 = nn.LayerNorm(embed_dim)
        # Created with PyTorch's default (sequence-first) input layout.
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.layer_norm_2 = nn.LayerNorm(embed_dim)

        feed_forward = [
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout),
        ]
        self.linear = nn.Sequential(*feed_forward)

    def forward(self, x):
        """Apply self-attention and the feed-forward network to ``x``.

        Returns a tensor with the same shape as ``x``.
        """
        normed = self.layer_norm_1(x)
        # Self-attention: query, key and value are the same normalised input;
        # only the attention output (not the weights) is used.
        attended, _ = self.attn(normed, normed, normed)
        x = x + attended
        return x + self.linear(self.layer_norm_2(x))

class VisionTransformer(nn.Module):
    """Vision Transformer classifier operating on flattened image patches."""

    def __init__(
        self,
        embed_dim,
        hidden_dim,
        num_channels,
        num_heads,
        num_layers,
        num_classes,
        patch_size,
        num_patches,
        dropout=0.0,
    ):
        """Build the model.

        Args:
            embed_dim: Size of the token vectors fed to the Transformer.
            hidden_dim: Width of the hidden layer of the feed-forward networks
                inside the Transformer blocks.
            num_channels: Number of input image channels (3 for RGB).
            num_heads: Number of heads in each multi-head attention layer.
            num_layers: Number of stacked attention blocks.
            num_classes: Number of classes to predict.
            patch_size: Side length of a patch in pixels.
            num_patches: Maximum number of patches an image can have.
            dropout: Dropout rate used in the feed-forward networks and on the
                input encoding.
        """
        super().__init__()

        self.patch_size = patch_size

        # Layers / networks
        patch_dim = num_channels * (patch_size**2)
        self.input_layer = nn.Linear(patch_dim, embed_dim)
        blocks = [
            AttentionBlock(embed_dim, hidden_dim, num_heads, dropout=dropout)
            for _ in range(num_layers)
        ]
        self.transformer = nn.Sequential(*blocks)
        self.mlp_head = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, num_classes))
        self.dropout = nn.Dropout(dropout)

        # Learnable parameters: one classification token and one positional
        # embedding per position (CLS token + every patch).
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.pos_embedding = nn.Parameter(torch.randn(1, 1 + num_patches, embed_dim))

    def forward(self, x, output_cls=False):
        """Classify a batch of images.

        Args:
            x: Image tensor of shape [B, C, H, W].
            output_cls: If True, return the encoded CLS token instead of the
                class scores.

        Returns:
            The CLS representation of shape [B, embed_dim] when ``output_cls``
            is True, otherwise the output of the classification head of shape
            [B, num_classes].
        """
        # Turn the image into a sequence of embedded patches.
        patches = img_to_patch(x, self.patch_size)
        batch_size, num_tokens, _ = patches.shape
        tokens = self.input_layer(patches)

        # Prepend the CLS token and add the positional encoding.
        cls_tokens = self.cls_token.repeat(batch_size, 1, 1)
        tokens = torch.cat([cls_tokens, tokens], dim=1)
        tokens = tokens + self.pos_embedding[:, : num_tokens + 1]

        # Apply the Transformer; the blocks receive the sequence axis first,
        # hence the transpose from [B, T, E] to [T, B, E].
        tokens = self.dropout(tokens)
        encoded = self.transformer(tokens.transpose(0, 1))

        # Position 0 along the sequence axis holds the CLS token.
        cls = encoded[0]
        if output_cls:
            return cls
        return self.mlp_head(cls)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Arguments are listed in the order of the constructor signature.
# num_patches=49 with patch_size=4 matches a 7x7 patch grid, i.e. 28x28 inputs
# (assumed from the MNIST-based dataset above; confirm against the data).
model = VisionTransformer(
    embed_dim=64,
    hidden_dim=256,
    num_channels=1,
    num_heads=4,
    num_layers=6,
    num_classes=10,
    patch_size=4,
    num_patches=49,
    dropout=0.1,
)
model = model.to(device)

In [19]:
# Cross-entropy loss on the class scores, optimised with Adam at a fixed
# learning rate.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Disabled alternative: AdamW with weight decay.
# optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=1e-3)

# Disabled alternative: linear learning-rate warmup followed by cosine decay.
# linear_warmup = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1/10, end_factor=1.0, total_iters=10-1, last_epoch=-1, verbose=True)
# cos_decay = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=200-10, eta_min=1e-5, verbose=True)

In [20]:
# Report the model size as computed by the project's `num_params` helper.
param_count = num_params(model)
print(f"Number of parameters in the model: {param_count}")

Number of parameters in the model: 305034


In [None]:
# Train for a fixed number of epochs and print the mean batch loss per epoch.
model.train()  # Training mode (dropout active)
for epoch in range(200):
    losses = []
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)  # Move inputs and labels to device
        # Clear the gradients of the previous step. One call per step is
        # enough: nothing accumulates gradients between here and backward().
        optimizer.zero_grad()
        out = model(inputs)
        loss = criterion(out, labels)
        loss.backward()
        optimizer.step()  # Update weights
        losses.append(loss.item())
    print(f"Epoch {epoch+1}, Average Loss: {sum(losses)/len(losses)}")

Epoch 1, Average Loss: 2.2753843929194195
Epoch 2, Average Loss: 2.069302521174467
Epoch 3, Average Loss: 1.9924414384214184
Epoch 4, Average Loss: 1.9486270267752153
Epoch 5, Average Loss: 1.8672694284704667
Epoch 6, Average Loss: 1.6589191472983058
Epoch 7, Average Loss: 1.5485421373874326
Epoch 8, Average Loss: 1.484795647331431
Epoch 9, Average Loss: 1.4429857670506345
Epoch 10, Average Loss: 1.4099895410899874
Epoch 11, Average Loss: 1.3983201693884935
Epoch 12, Average Loss: 1.3716289604766458
Epoch 13, Average Loss: 1.3461040748825557
Epoch 14, Average Loss: 1.3346410174913044
Epoch 15, Average Loss: 1.3168048632295826
Epoch 16, Average Loss: 1.3108432474015634
Epoch 17, Average Loss: 1.299592423288128
Epoch 18, Average Loss: 1.2833537844162952
Epoch 19, Average Loss: 1.278141312961337
Epoch 20, Average Loss: 1.2725274049783055
Epoch 21, Average Loss: 1.2627321662782114
Epoch 22, Average Loss: 1.251509363138223
Epoch 23, Average Loss: 1.232766849330709
Epoch 24, Average Loss: 1.

In [None]:
# Measure top-1 accuracy on the test set.
model.eval()  # Evaluation mode (dropout disabled)

correct = 0
total = 0
with torch.no_grad():  # No gradients are needed for inference
    for inputs, labels in test_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = model(inputs)
        # Predicted class = index of the largest score along the class axis.
        predicted = outputs.argmax(dim=1)
        total += labels.shape[0]
        correct += int(predicted.eq(labels).sum())

accuracy = 100 * correct / total
print(f"Accuracy on test set: {accuracy:.2f}%")