In [None]:
import torch
import torch.nn as nn

class MultiheadAttentionEinsum(nn.Module):
    """Multi-head scaled dot-product attention implemented with ``torch.einsum``.

    All tensors are batch-first: ``query`` is ``(batch, q_len, embedding_dim)``
    and ``key`` / ``value`` are ``(batch, k_len, embedding_dim)``. The output
    has the same shape as ``query``.

    Args:
        embedding_dim: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads (positive).

    Raises:
        ValueError: If ``num_heads`` is not positive or does not evenly divide
            ``embedding_dim`` (otherwise the per-head split would silently
            drop features).
    """

    def __init__(self, embedding_dim, num_heads):
        super(MultiheadAttentionEinsum, self).__init__()
        if num_heads <= 0 or embedding_dim % num_heads != 0:
            raise ValueError(
                f"embedding_dim ({embedding_dim}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.head_dim = embedding_dim // num_heads
        # "Scaled" dot-product attention: divide scores by sqrt(head_dim).
        self.scale = self.head_dim ** -0.5

        self.q_linear = nn.Linear(embedding_dim, embedding_dim)
        self.k_linear = nn.Linear(embedding_dim, embedding_dim)
        self.v_linear = nn.Linear(embedding_dim, embedding_dim)
        self.fc_out = nn.Linear(embedding_dim, embedding_dim)

    def forward(self, query, key, value):
        """Attend from ``query`` over ``key``/``value``.

        Args:
            query: Tensor of shape ``(batch, q_len, embedding_dim)``.
            key: Tensor of shape ``(batch, k_len, embedding_dim)``.
            value: Tensor of shape ``(batch, k_len, embedding_dim)``.

        Returns:
            Tensor of shape ``(batch, q_len, embedding_dim)``.
        """
        batch_size, q_len, _ = query.shape
        k_len = key.shape[1]

        # Linear projections, then split the feature axis into (heads, head_dim).
        q = self.q_linear(query).reshape(batch_size, q_len, self.num_heads, self.head_dim)
        k = self.k_linear(key).reshape(batch_size, k_len, self.num_heads, self.head_dim)
        v = self.v_linear(value).reshape(batch_size, k_len, self.num_heads, self.head_dim)

        # scores[b, h, i, j] = <q[b, i, h], k[b, j, h]> / sqrt(head_dim)
        scores = torch.einsum("bqhd,bkhd->bhqk", q, k) * self.scale
        weights = torch.softmax(scores, dim=-1)

        # Weighted sum of values per head, then merge heads back to embedding_dim.
        attended_values = torch.einsum("bhqk,bkhd->bqhd", weights, v)
        attended_values = attended_values.reshape(batch_size, q_len, self.embedding_dim)

        out = self.fc_out(attended_values)
        return out

In [None]:
import torch
import torch.nn as nn

class TransformerEncoderLayer(nn.Module):
    """Pre-norm Transformer encoder layer built on ``MultiheadAttentionEinsum``.

    Input and output are batch-first tensors of shape
    ``(batch_size, seq_len, embedding_dim)``. The attention module is required
    to accept and return batch-first tensors of that shape as well.

    Args:
        embedding_dim: Model width.
        num_heads: Number of attention heads.
        dim_feedforward: Hidden width of the position-wise MLP (default 2048,
            the value previously hard-coded).
    """

    def __init__(self, embedding_dim, num_heads, dim_feedforward=2048):
        super(TransformerEncoderLayer, self).__init__()
        # The attention class takes `embedding_dim`, not `embed_dim`.
        self.multihead_attention = MultiheadAttentionEinsum(embedding_dim=embedding_dim, num_heads=num_heads)
        self.feed_forward = nn.Sequential(
            nn.Linear(embedding_dim, dim_feedforward),
            nn.ReLU(),
            nn.Linear(dim_feedforward, embedding_dim)
        )
        self.layer_norm1 = nn.LayerNorm(embedding_dim)
        self.layer_norm2 = nn.LayerNorm(embedding_dim)

    def forward(self, x):
        """Apply self-attention and the MLP, each with a pre-norm residual."""
        # Self-attention sub-layer. Everything stays batch-first, so the
        # residual and the attention output always have matching shapes.
        residual = x
        x = self.layer_norm1(x)
        attn_output = self.multihead_attention(x, x, x)
        # Tolerate nn.MultiheadAttention-style (output, weights) tuples; a plain
        # tensor must NOT be indexed, since [0] would slice off the batch axis.
        if isinstance(attn_output, tuple):
            attn_output = attn_output[0]
        x = attn_output + residual

        # Position-wise feed-forward sub-layer.
        residual = x
        x = self.layer_norm2(x)
        x = self.feed_forward(x)
        x = x + residual

        return x


In [None]:
class VisionTransformer(nn.Module):
    """Minimal Vision Transformer classifier.

    Images are split into non-overlapping ``patch_size`` x ``patch_size``
    patches by a strided convolution. A learnable CLS token is prepended, a
    learnable positional encoding is *added*, and the sequence goes through
    ``num_layers`` batch-first Transformer encoder layers. The final CLS
    token is classified by a linear head.

    Args:
        num_classes: Number of output logits.
        patch_size: Side length of each square patch.
        embedding_dim: Token width.
        num_heads: Attention heads per encoder layer.
        num_layers: Number of encoder layers.
        image_size: Side length of the (square) input images; default 224
            reproduces the original 14 x 14 patch grid for ``patch_size=16``.
        in_channels: Number of input image channels (default 3).

    Raises:
        ValueError: If ``image_size`` is not a multiple of ``patch_size``.
    """

    def __init__(self, num_classes, patch_size, embedding_dim, num_heads, num_layers,
                 image_size=224, in_channels=3):
        super(VisionTransformer, self).__init__()
        if image_size % patch_size != 0:
            raise ValueError(
                f"image_size ({image_size}) must be a multiple of patch_size ({patch_size})"
            )
        num_patches = (image_size // patch_size) ** 2
        self.patch_embedding = nn.Conv2d(in_channels, embedding_dim, kernel_size=patch_size, stride=patch_size)
        # One learnable classification token; its final state feeds the head.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embedding_dim))
        # One positional vector per patch plus one for the CLS token.
        self.positional_encoding = nn.Parameter(torch.randn(1, num_patches + 1, embedding_dim))
        # batch_first=True: the default (False) would treat dim 0 as the
        # sequence and make tokens attend across different images in a batch.
        self.transformer_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads, batch_first=True)
            for _ in range(num_layers)
        ])
        self.fc = nn.Linear(embedding_dim, num_classes)

    def forward(self, x):
        """Map images ``(batch, in_channels, H, W)`` to logits ``(batch, num_classes)``."""
        batch_size = x.size(0)
        x = self.patch_embedding(x)             # (batch, embedding_dim, H/p, W/p)
        x = x.flatten(2).transpose(1, 2)        # (batch, num_patches, embedding_dim)
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)   # (batch, num_patches + 1, embedding_dim)
        if x.size(1) != self.positional_encoding.size(1):
            raise ValueError(
                f"got {x.size(1) - 1} patches but the positional encoding expects "
                f"{self.positional_encoding.size(1) - 1}; check image_size/patch_size"
            )
        x = x + self.positional_encoding        # broadcast over the batch
        for layer in self.transformer_layers:
            x = layer(x)
        x = x[:, 0]                             # CLS token representation
        x = self.fc(x)
        return x

In [None]:
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Reproducibility: seeds the model's weight init (next cell) and the
# train DataLoader's shuffling, both of which draw from torch's global RNG.
seed = 42
torch.manual_seed(seed)

# Hyperparameters
num_epochs = 10
batch_size = 64
learning_rate = 0.001
num_classes = 10
# Images are resized to image_size x image_size. Keep image_size / patch_size
# consistent with the model's positional encoding, which by default is sized
# for (224 // 16) ** 2 = 14 x 14 patches (+1).
image_size = 224
patch_size = 16
embedding_dim = 128
num_heads = 8
num_layers = 3

# CIFAR-10 dataset preprocessing: resize, convert to tensor, map each channel to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load CIFAR-10 dataset
train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)

# Data loaders (pinned host memory speeds up host-to-GPU copies when on CUDA)
pin_memory = device.type == 'cuda'
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, pin_memory=pin_memory)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, pin_memory=pin_memory)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:05<00:00, 29405465.19it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [None]:
# Initialize the model
model = VisionTransformer(num_classes, patch_size, embedding_dim, num_heads, num_layers).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
log_every = 100  # print the current mini-batch loss every `log_every` steps
total_steps = len(train_loader)
for epoch in range(num_epochs):
    # The testing cell calls model.eval(); switch back so a re-run of this
    # cell after evaluation still trains in training mode.
    model.train()
    for step, (images, labels) in enumerate(train_loader, start=1):
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % log_every == 0:
            print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{step}/{total_steps}], Loss: {loss.item():.4f}')


Epoch [1/10], Step [100/782], Loss: 2.1098
Epoch [1/10], Step [200/782], Loss: 2.1042
Epoch [1/10], Step [300/782], Loss: 2.0569
Epoch [1/10], Step [400/782], Loss: 2.0043
Epoch [1/10], Step [500/782], Loss: 1.9699
Epoch [1/10], Step [600/782], Loss: 2.0599
Epoch [1/10], Step [700/782], Loss: 1.7518
Epoch [2/10], Step [100/782], Loss: 2.0259
Epoch [2/10], Step [200/782], Loss: 1.7581
Epoch [2/10], Step [300/782], Loss: 1.9223
Epoch [2/10], Step [400/782], Loss: 1.8391
Epoch [2/10], Step [500/782], Loss: 1.9228
Epoch [2/10], Step [600/782], Loss: 1.7407
Epoch [2/10], Step [700/782], Loss: 1.8338
Epoch [3/10], Step [100/782], Loss: 1.7690
Epoch [3/10], Step [200/782], Loss: 1.6399
Epoch [3/10], Step [300/782], Loss: 1.6874
Epoch [3/10], Step [400/782], Loss: 1.8803
Epoch [3/10], Step [500/782], Loss: 1.7312
Epoch [3/10], Step [600/782], Loss: 1.6702
Epoch [3/10], Step [700/782], Loss: 1.6725
Epoch [4/10], Step [100/782], Loss: 1.4914
Epoch [4/10], Step [200/782], Loss: 1.6974


In [None]:
# Testing phase: top-1 accuracy on the CIFAR-10 test split (test_loader).
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        # Predicted class = index of the largest logit.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

if total == 0:
    raise RuntimeError('test_loader yielded no samples; cannot compute accuracy')
accuracy = 100 * correct / total
print(f'Test Accuracy of the model on the {total} test images: {accuracy:.2f}%')