In [2]:
# Compute device for the notebook; set to 'cuda' to run on a GPU.
# NOTE(review): the training cell reassigns `device` from torch.cuda.is_available(),
# so this value is overridden once that cell runs.
device = 'cpu'

In [3]:
import torch.nn.functional as F
from torch import nn
import torch

class MSA(nn.Module):
    def __init__(self, input_dim, embed_dim, num_heads):
        '''
        Multi-head self-attention.

        input_dim: Dimension of input token embeddings
        embed_dim: Dimension of internal key, query, and value embeddings;
                   must be divisible by num_heads (each head gets
                   embed_dim // num_heads channels)
        num_heads: Number of self-attention heads

        raises ValueError: if num_heads is not positive or does not divide embed_dim
        '''

        super().__init__()

        # The head split reshapes embed_dim into (num_heads, embed_dim // num_heads);
        # this is the divisibility that actually matters (not the sequence length).
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError(
                f'embed_dim ({embed_dim}) must be divisible by a positive '
                f'num_heads ({num_heads})')

        self.input_dim = input_dim
        self.embed_dim = embed_dim
        self.num_heads = num_heads

        self.K_embed = nn.Linear(input_dim, embed_dim, bias=False)
        self.Q_embed = nn.Linear(input_dim, embed_dim, bias=False)
        self.V_embed = nn.Linear(input_dim, embed_dim, bias=False)
        self.out_embed = nn.Linear(embed_dim, embed_dim, bias=False)

    def forward(self, x):
        '''
        x: input of shape (batch_size, max_length, input_dim); max_length can be
           any length -- it does not need to be a multiple of num_heads, so
           callers need not pad the sequence
        return: output of shape (batch_size, max_length, embed_dim)
        '''

        batch_size, max_length, given_input_dim = x.shape
        assert given_input_dim == self.input_dim

        head_dim = self.embed_dim // self.num_heads

        def split_heads(t):
            # (batch, length, embed_dim) -> (batch * num_heads, length, head_dim)
            return (t.reshape(batch_size, max_length, self.num_heads, head_dim)
                     .permute(0, 2, 1, 3)
                     .reshape(batch_size * self.num_heads, max_length, head_dim))

        K = split_heads(self.K_embed(x))
        Q = split_heads(self.Q_embed(x))
        V = split_heads(self.V_embed(x))

        # Scaled dot-product attention, computed independently per head.
        scores = torch.bmm(Q, K.transpose(1, 2)) / head_dim ** 0.5  # (B*H, L, L)
        weights = F.softmax(scores, dim=-1)
        w_V = torch.bmm(weights, V)  # (B*H, L, head_dim)

        # Rejoin heads: (B*H, L, head_dim) -> (B, L, H * head_dim)
        w_V = (w_V.reshape(batch_size, self.num_heads, max_length, head_dim)
                  .permute(0, 2, 1, 3)
                  .reshape(batch_size, max_length, self.embed_dim))

        # Final linear projection
        return self.out_embed(w_V)


In [4]:
class ViTLayer(nn.Module):
    def __init__(self, num_heads, input_dim, embed_dim, mlp_hidden_dim, dropout=0.1):
        '''
        Pre-norm transformer encoder layer:
        x -> LayerNorm -> MSA -> dropout -> (+ x) -> LayerNorm -> MLP -> (+ residual)

        num_heads: Number of heads for multi-head self-attention
        input_dim: Dimension of the incoming token embeddings; must equal
                   embed_dim because both residual connections add the input
                   directly to tensors of width embed_dim
        embed_dim: Dimension of internal key, query, and value embeddings
        mlp_hidden_dim: Hidden dimension of the linear layer
        dropout: Dropout rate

        raises ValueError: if input_dim != embed_dim
        '''

        super().__init__()

        # Fail fast with a clear message instead of a shape error inside forward().
        if input_dim != embed_dim:
            raise ValueError(
                f'ViTLayer requires input_dim == embed_dim for its residual '
                f'connections, got input_dim={input_dim}, embed_dim={embed_dim}')

        self.input_dim = input_dim
        self.msa = MSA(input_dim, embed_dim, num_heads)

        self.layernorm1 = nn.LayerNorm(embed_dim)
        self.w_o_dropout = nn.Dropout(dropout)
        self.layernorm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(nn.Linear(embed_dim, mlp_hidden_dim),
                                 nn.GELU(),
                                 nn.Dropout(dropout),
                                 nn.Linear(mlp_hidden_dim, embed_dim),
                                 nn.Dropout(dropout))

    def forward(self, x):
        '''
        x: input embeddings (batch_size, max_length, input_dim)
        return: output embeddings (batch_size, max_length, embed_dim)
        '''

        # Attention sub-block with residual connection.
        residual1 = x + self.w_o_dropout(self.msa(self.layernorm1(x)))

        # MLP sub-block with residual connection (the MLP ends in its own dropout).
        return residual1 + self.mlp(self.layernorm2(residual1))

In [5]:
class ViT(nn.Module):
    def __init__(self, patch_dim, image_dim, num_layers, num_heads, embed_dim, mlp_hidden_dim, num_classes, dropout):
        '''
        Vision Transformer classifier.

        patch_dim: side length (pixels) of the square patches the image is cut into
        image_dim: side length (pixels) of the square input image
        num_layers: number of ViTLayer encoder blocks
        num_heads: number of heads for multi-head attention
        embed_dim: width of the patch projection, the position embeddings and the class token
        mlp_hidden_dim: hidden dimension of each encoder block's MLP
        num_classes: number of output logits
        dropout: dropout rate after the embeddings and inside every encoder block
        '''

        super().__init__()
        self.num_layers = num_layers
        self.patch_dim = patch_dim
        self.image_dim = image_dim
        # Each patch is flattened to patch_dim * patch_dim values per channel, 3 channels.
        self.input_dim = 3 * patch_dim ** 2
        self.num_heads = num_heads

        num_patches = (image_dim // patch_dim) ** 2
        self.patch_embedding = nn.Linear(self.input_dim, embed_dim)
        # One extra position for the prepended class token.
        self.position_embedding = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.embedding_dropout = nn.Dropout(dropout)

        self.encoder_layers = nn.ModuleList(
            [ViTLayer(num_heads, embed_dim, embed_dim, mlp_hidden_dim, dropout)
             for _ in range(num_layers)])

        self.mlp_head = nn.Linear(embed_dim, num_classes)
        self.layernorm = nn.LayerNorm(embed_dim)

    def forward(self, images):
        '''
        images: raw image data, assumed (batch_size, 3, image_dim, image_dim)
        return: class logits (batch_size, num_classes)
        '''

        batch = images.size(0)
        grid = self.image_dim // self.patch_dim
        p = self.patch_dim

        # (N, C, grid, p, grid, p) -> (N, grid, grid, p, p, C) -> (N, grid*grid, p*p*C)
        patches = (images
                   .reshape(batch, 3, grid, p, grid, p)
                   .permute(0, 2, 4, 3, 5, 1)
                   .reshape(batch, grid * grid, self.input_dim))

        # Prepend the learned class token, then add position embeddings.
        cls = self.cls_token.expand(batch, -1, -1)
        tokens = torch.cat([cls, self.patch_embedding(patches)], dim=1)
        tokens = self.embedding_dropout(tokens + self.position_embedding.expand(batch, -1, -1))

        # Append zero tokens so the sequence length is a multiple of num_heads
        # (the MSA block has asserted this divisibility).
        # NOTE(review): the padding tokens are not masked, so they act as keys/values
        # for every real token, including the class token; drop this padding once
        # MSA no longer requires divisible sequence lengths.
        pad = -tokens.shape[1] % self.num_heads
        if pad > 0:
            filler = torch.zeros(batch, pad, tokens.shape[2], device=tokens.device)
            tokens = torch.cat([tokens, filler], dim=1)

        for layer in self.encoder_layers:
            tokens = layer(tokens)

        # Classify from the normalized class token at position 0.
        return self.mlp_head(self.layernorm(tokens[:, 0]))

# Model size presets: name -> (num_layers, num_heads, embed_dim, mlp_hidden_dim).
_VIT_PRESETS = {
    'tiny': (12, 3, 192, 768),
    'small': (12, 6, 384, 1536),
    'base': (12, 12, 768, 3072),
}


def _build_vit(preset, num_classes, patch_dim, image_dim):
    '''Instantiate a ViT with the size configuration named `preset` and dropout 0.1.'''
    num_layers, num_heads, embed_dim, mlp_hidden_dim = _VIT_PRESETS[preset]
    return ViT(patch_dim=patch_dim, image_dim=image_dim, num_layers=num_layers,
               num_heads=num_heads, embed_dim=embed_dim, mlp_hidden_dim=mlp_hidden_dim,
               num_classes=num_classes, dropout=0.1)


def get_vit_tiny(num_classes=10, patch_dim=4, image_dim=32):
    '''ViT-Tiny: 12 layers, 3 heads, width 192, MLP hidden 768.'''
    return _build_vit('tiny', num_classes, patch_dim, image_dim)


def get_vit_small(num_classes=10, patch_dim=4, image_dim=32):
    '''ViT-Small: 12 layers, 6 heads, width 384, MLP hidden 1536.'''
    return _build_vit('small', num_classes, patch_dim, image_dim)


def get_vit_base(num_classes=10, patch_dim=4, image_dim=32):
    '''ViT-Base: 12 layers, 12 heads, width 768, MLP hidden 3072.'''
    return _build_vit('base', num_classes, patch_dim, image_dim)


In [9]:
import math

import torch
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as T
from torch.utils.data import DataLoader, Dataset
from torch.utils.data import sampler
from tqdm.notebook import tqdm

# Seed torch's global RNG before building the model, the samplers and the
# random augmentations so that Restart & Run All reproduces the same run.
seed = 0
torch.manual_seed(seed)

# Data Preparation
data_root = './data/cifar10'
train_size = 40000
val_size = 10000

batch_size = 32

transform_train = T.Compose([
    T.Resize(40),
    T.RandomCrop(32),
    T.RandomHorizontalFlip(),
    T.RandomAffine(degrees=0, translate=(0.2, 0.2), scale=(0.95, 1.05)),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

transform_val = T.Compose([
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

train_dataset = datasets.CIFAR10(
    root=data_root,
    train=True,
    download=True,
    transform=transform_train,
)

# Same training images without augmentation; only indices disjoint from the
# training subset are sampled from it below.
val_dataset = datasets.CIFAR10(
    root=data_root,
    train=True,
    download=True,
    transform=transform_val,
)

# Train on the first train_size images, validate on the following val_size images.
assert train_size + val_size <= len(train_dataset), 'train/val split exceeds the training set'
train_loader = DataLoader(train_dataset, batch_size=batch_size,
                          sampler=sampler.SubsetRandomSampler(range(train_size)))

val_loader = DataLoader(val_dataset, batch_size=batch_size,
                        sampler=sampler.SubsetRandomSampler(range(train_size, train_size + val_size)))

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Model, Loss, Optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
vit = get_vit_small().to(device)

learning_rate = 5e-4 * batch_size / 256  # base LR 5e-4 per 256 images, scaled linearly
num_epochs = 30
weight_decay = 0.1

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(vit.parameters(), lr=learning_rate, betas=(0.9, 0.95), weight_decay=weight_decay)


def run_epoch(model, loader, criterion, device, optimizer=None):
    '''
    Make one pass over `loader`; return (mean loss, accuracy) over all samples seen.

    With an optimizer the model is put in train mode and updated after every
    batch (progress shown with tqdm); without one it is put in eval mode and
    evaluated with gradients disabled.
    '''
    training = optimizer is not None
    model.train(training)
    total_loss, total_correct, total_seen = 0.0, 0, 0
    with torch.set_grad_enabled(training):
        for inputs, labels in (tqdm(loader) if training else loader):
            inputs = inputs.to(device)
            labels = labels.to(device)
            if training:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()
            # CrossEntropyLoss averages over the batch; re-weight by batch size.
            total_loss += loss.item() * inputs.shape[0]
            total_correct += (outputs.argmax(dim=1) == labels).sum().item()
            total_seen += inputs.shape[0]
    return total_loss / total_seen, total_correct / total_seen


# Training Loop
train_losses = []
val_losses = []
val_accuracies = []

for epoch in range(num_epochs):
    train_loss, train_acc = run_epoch(vit, train_loader, criterion, device, optimizer)
    train_losses.append(train_loss)

    val_loss, val_acc = run_epoch(vit, val_loader, criterion, device)
    val_losses.append(val_loss)
    val_accuracies.append(val_acc)

    # Save the best model (an epoch tying the previous best also overwrites it).
    if val_acc >= max(val_accuracies):
        torch.save(vit.state_dict(), 'best_model.pth')

    print(f'[{epoch + 1:2d}] train loss: {train_loss:.3f} | train accuracy: {train_acc:.3f} | val loss: {val_loss:.3f} | val accuracy: {val_acc:.3f}')

print('Finished Training')

Files already downloaded and verified
Files already downloaded and verified


  0%|          | 0/1250 [00:00<?, ?it/s]

[ 1] train loss: 1.837 | train accuracy: 0.304 | val loss: 1.667 | val accuracy: 0.389


  0%|          | 0/1250 [00:00<?, ?it/s]

[ 2] train loss: 1.598 | train accuracy: 0.410 | val loss: 1.655 | val accuracy: 0.408


  0%|          | 0/1250 [00:00<?, ?it/s]

[ 3] train loss: 1.464 | train accuracy: 0.464 | val loss: 1.497 | val accuracy: 0.457


  0%|          | 0/1250 [00:00<?, ?it/s]

[ 4] train loss: 1.385 | train accuracy: 0.495 | val loss: 1.490 | val accuracy: 0.480


  0%|          | 0/1250 [00:00<?, ?it/s]

[ 5] train loss: 1.316 | train accuracy: 0.525 | val loss: 1.448 | val accuracy: 0.505


  0%|          | 0/1250 [00:00<?, ?it/s]

[ 6] train loss: 1.266 | train accuracy: 0.542 | val loss: 1.375 | val accuracy: 0.519


  0%|          | 0/1250 [00:00<?, ?it/s]

[ 7] train loss: 1.226 | train accuracy: 0.558 | val loss: 1.372 | val accuracy: 0.516


  0%|          | 0/1250 [00:00<?, ?it/s]

[ 8] train loss: 1.190 | train accuracy: 0.573 | val loss: 1.406 | val accuracy: 0.516


  0%|          | 0/1250 [00:00<?, ?it/s]

[ 9] train loss: 1.144 | train accuracy: 0.589 | val loss: 1.344 | val accuracy: 0.554


  0%|          | 0/1250 [00:00<?, ?it/s]

[10] train loss: 1.120 | train accuracy: 0.600 | val loss: 1.269 | val accuracy: 0.558


  0%|          | 0/1250 [00:00<?, ?it/s]

[11] train loss: 1.089 | train accuracy: 0.612 | val loss: 1.283 | val accuracy: 0.569


  0%|          | 0/1250 [00:00<?, ?it/s]

[12] train loss: 1.062 | train accuracy: 0.619 | val loss: 1.209 | val accuracy: 0.575


  0%|          | 0/1250 [00:00<?, ?it/s]

[13] train loss: 1.041 | train accuracy: 0.629 | val loss: 1.264 | val accuracy: 0.571


  0%|          | 0/1250 [00:00<?, ?it/s]

[14] train loss: 1.017 | train accuracy: 0.638 | val loss: 1.266 | val accuracy: 0.572


  0%|          | 0/1250 [00:00<?, ?it/s]

[15] train loss: 0.991 | train accuracy: 0.647 | val loss: 1.291 | val accuracy: 0.571


  0%|          | 0/1250 [00:00<?, ?it/s]

[16] train loss: 0.975 | train accuracy: 0.654 | val loss: 1.200 | val accuracy: 0.594


  0%|          | 0/1250 [00:00<?, ?it/s]

[17] train loss: 0.951 | train accuracy: 0.662 | val loss: 1.131 | val accuracy: 0.612


  0%|          | 0/1250 [00:00<?, ?it/s]

[18] train loss: 0.937 | train accuracy: 0.665 | val loss: 1.161 | val accuracy: 0.600


  0%|          | 0/1250 [00:00<?, ?it/s]

[19] train loss: 0.918 | train accuracy: 0.676 | val loss: 1.208 | val accuracy: 0.608


  0%|          | 0/1250 [00:00<?, ?it/s]

[20] train loss: 0.902 | train accuracy: 0.680 | val loss: 1.414 | val accuracy: 0.557


  0%|          | 0/1250 [00:00<?, ?it/s]

[21] train loss: 0.887 | train accuracy: 0.685 | val loss: 1.028 | val accuracy: 0.651


  0%|          | 0/1250 [00:00<?, ?it/s]

[22] train loss: 0.864 | train accuracy: 0.695 | val loss: 1.222 | val accuracy: 0.602


  0%|          | 0/1250 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [10]:
import pickle

# torchvision's CIFAR-10 archive contains both splits, so reuse the training
# data directory instead of downloading the archive a second time.
cifar_test = datasets.CIFAR10(data_root, download=True, train=False, transform=transform_val)
loader_test = DataLoader(cifar_test, batch_size=batch_size, shuffle=False)

# The checkpoint is a plain state_dict: weights_only=True refuses arbitrary
# pickled objects, and map_location lets a GPU-saved checkpoint load on CPU.
vit.load_state_dict(torch.load('best_model.pth', map_location=device, weights_only=True))
vit.eval()  # set model to evaluation mode
predictions = []
with torch.no_grad():
    for x, _ in loader_test:
        scores = vit(x.to(device=device))  # move to device, e.g. GPU
        predictions.append(scores.argmax(dim=1).cpu())
predictions = torch.cat(predictions).tolist()
with open("my_predictions.pickle", "wb") as file:
    pickle.dump(predictions, file)

Files already downloaded and verified


  vit.load_state_dict(torch.load('best_model.pth'))
