In [None]:
import torch
from torchvision import datasets, transforms

# Per-channel CIFAR-10 statistics used to standardise the RGB inputs.
cifar10_mean = [0.4914, 0.4822, 0.4465]
cifar10_std = [0.2470, 0.2435, 0.2616]
transform = transforms.Compose([
    transforms.ToTensor(),  # image -> float tensor
    transforms.Normalize(cifar10_mean, cifar10_std),
])

# Download (if needed) both CIFAR-10 splits and apply the same preprocessing to each.
train_data, test_data = (
    datasets.CIFAR10(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Only the training split is shuffled; evaluation order stays fixed.
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=256, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_data, batch_size=256, shuffle=False)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import random

class Tokenizer(nn.Module):
    """Convolutional tokenizer that turns an image batch into a token sequence.

    A bias-free 3x3 convolution lifts the input to ``n_embed`` channels, then a
    ReLU and a stride-2 max-pool downsample the feature map, which is finally
    flattened so every spatial position becomes one token.
    """

    def __init__(self, in_channels=3, n_embed=128):
        super().__init__()
        conv = nn.Conv2d(in_channels, n_embed, kernel_size=(3, 3), stride=1, padding=1, bias=False)
        pool = nn.MaxPool2d(kernel_size=(3, 3), stride=2, padding=1)
        self.patch = nn.Sequential(conv, nn.ReLU(), pool)

    def forward(self, x):
        """Map a (B, C, H, W) batch to (B, H' * W', n_embed) tokens."""
        features = self.patch(x)  # (B, n_embed, H', W')
        tokens = torch.flatten(features, start_dim=2, end_dim=3)  # (B, n_embed, H' * W')
        return tokens.transpose(-2, -1)

class Head(nn.Module):
    """A single scaled dot-product self-attention head.

    Args:
        n_embed: width of the incoming token embeddings.
        head_size: width of the query/key/value projections and of the output.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, n_embed, head_size, dropout):
        super().__init__()
        self.head_size = head_size
        # One fused projection produces q, k and v side by side.
        self.qkv = nn.Linear(n_embed, head_size * 3, bias=False)
        self.attention_dropout = nn.Dropout(dropout)

    def forward(self, x, attention_mask=None):
        """Attend over the token sequence.

        Args:
            x: (B, T, n_embed) token embeddings.
            attention_mask: optional (B, T) tensor; positions where it is 0/False
                are excluded as keys, so no token attends to them.

        Returns:
            (B, T, head_size) attention-weighted values.
        """
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # Scale by 1/sqrt(head_size): the dimensionality of the q·k dot product.
        scores = torch.bmm(q, k.transpose(-2, -1)) * self.head_size ** -0.5  # (B, T_query, T_key)
        if attention_mask is not None:
            keep = attention_mask.bool().unsqueeze(1)  # (B, 1, T): broadcast over queries
            # Multiplying scores by 0 would not mask (exp(0) = 1). Fill masked keys with
            # the most negative finite value so they get ~0 weight without producing NaN.
            scores = scores.masked_fill(~keep, torch.finfo(scores.dtype).min)
        weights = self.attention_dropout(F.softmax(scores, dim=-1))
        return torch.bmm(weights, v)

class MultiHead(nn.Module):
    """Runs ``n_heads`` independent attention heads and mixes their outputs.

    Args:
        n_embed: width of the token embeddings (input and output).
        head_size: output width of each head.
        n_heads: number of parallel heads.
        dropout: attention-weight dropout passed to every head.
    """

    def __init__(self, n_embed, head_size, n_heads, dropout):
        super().__init__()
        self.heads = nn.ModuleList([Head(n_embed, head_size, dropout) for _ in range(n_heads)])
        # The concatenated heads are n_heads * head_size wide; project back to n_embed.
        # This is identical to Linear(n_embed, n_embed) whenever n_heads * head_size == n_embed.
        self.proj = nn.Linear(n_heads * head_size, n_embed)

    def forward(self, x, attention_mask=None):
        """Concatenate every head's output along the feature axis and project it.

        Args:
            x: token embeddings, passed unchanged to every head.
            attention_mask: optional mask forwarded to every head.
        """
        out = torch.cat([head(x, attention_mask) for head in self.heads], dim=-1)
        return self.proj(out)

class FeedForward(nn.Module):
    """Position-wise MLP: widen by ``mlp_ratio``, apply GELU, project back, then dropout.

    Args:
        n_embed: input/output feature width.
        mlp_ratio: multiplier giving the hidden width (``n_embed * mlp_ratio``).
        dropout: dropout probability applied to the output.
    """

    def __init__(self, n_embed, mlp_ratio, dropout):
        super().__init__()
        hidden = n_embed * mlp_ratio
        layers = [
            nn.Linear(n_embed, hidden),
            nn.GELU(),
            nn.Linear(hidden, n_embed),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP along the last dimension of ``x``."""
        out = self.net(x)
        return out

class Block(nn.Module):
    """Pre-LayerNorm transformer encoder block: attention and MLP, each with a residual.

    Args:
        n_embed: token embedding width.
        head_size: per-head width of the attention layer.
        n_heads: number of attention heads.
        mlp_ratio: hidden-width multiplier of the feed-forward layer.
        dropout: dropout probability forwarded to both sublayers.
    """

    def __init__(self, n_embed, head_size, n_heads, mlp_ratio, dropout):
        super().__init__()
        self.multihead = MultiHead(n_embed, head_size, n_heads, dropout)
        self.ffwd = FeedForward(n_embed, mlp_ratio, dropout)
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)

    def forward(self, x, attention_mask=None):
        """Apply attention then the MLP, adding each result back onto ``x``."""
        # Normalise only the sublayer *input*. The residual stream itself stays
        # un-normalised, so there is an identity path through every block.
        x = x + self.multihead(self.ln1(x), attention_mask)
        x = x + self.ffwd(self.ln2(x))
        return x

class ViT(nn.Module):
    """Compact convolutional vision-transformer classifier.

    Images are tokenised by a small convolutional stem and given a learned
    positional embedding. They then pass through ``n_layers`` transformer blocks
    and are reduced to one vector by attention-weighted sequence pooling before
    the linear classification head.

    Args:
        in_channels: number of input image channels.
        n_embed: token embedding width.
        head_size: per-head attention width.
        n_heads: attention heads per block.
        n_layers: number of transformer blocks.
        dropout: dropout probability used inside the blocks.
        mlp_ratio: hidden-width multiplier of each block's MLP.
        device: device the whole module is moved to at the end of construction
            (``None`` leaves it where PyTorch created it, i.e. the CPU).
        num_classes: number of output logits.
        img_size: side length of the square input images. It determines the token
            count and therefore the positional-embedding size, so inputs of a
            different size will not broadcast against the embedding.
    """

    def __init__(self, in_channels=3, n_embed=128, head_size=32, n_heads=4, n_layers=4, dropout=0.4, mlp_ratio=2, device='cuda', num_classes=10, img_size=32):
        super().__init__()
        self.tokenizer = Tokenizer(in_channels, n_embed)
        # Run the tokenizer once on a dummy image to learn how many tokens an
        # img_size x img_size image becomes. No autograd graph is needed for this.
        with torch.no_grad():
            probe = torch.zeros(1, in_channels, img_size, img_size)
            self.sequence_length = self.tokenizer(probe).shape[1]
        self.blocks = nn.ModuleList([Block(n_embed, head_size, n_heads, mlp_ratio, dropout) for _ in range(n_layers)])
        self.ln = nn.LayerNorm(n_embed)
        self.cl_head = nn.Sequential(
            nn.Linear(n_embed, num_classes)
        )
        # Learned positional embedding. It is created alongside every other parameter
        # (not directly on `device`) so the whole module lives on a single device.
        # Truncated-normal init with std 0.2 follows the CCT reference; a
        # unit-variance randn init is large compared with usual transformer inits.
        self.positional_embedding = nn.Parameter(torch.empty(1, self.sequence_length, n_embed))
        nn.init.trunc_normal_(self.positional_embedding, std=0.2)
        # Produces one score per token for attention-weighted sequence pooling.
        self.sequence_pooling = nn.Linear(n_embed, 1)
        if device is not None:
            self.to(device)

    def forward(self, x, attention_mask=None, targets=None):
        """Return class logits of shape (B, num_classes) for a batch of images.

        ``attention_mask`` is forwarded to every block. ``targets`` is accepted
        for interface compatibility but is not used.
        """
        x = self.tokenizer(x) + self.positional_embedding  # (B, N, n_embed)
        for block in self.blocks:
            x = block(x, attention_mask)
        x = self.ln(x)
        # Sequence pooling: score each token, softmax over tokens, then take the
        # weighted average of the token embeddings.
        pool_weights = F.softmax(self.sequence_pooling(x).transpose(-2, -1), dim=-1)  # (B, 1, N)
        x = torch.bmm(pool_weights, x).squeeze(1)  # (B, n_embed)
        return self.cl_head(x)

In [None]:
# Use the GPU when one is available; everything below follows `device`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(0)  # reproducible weight initialisation
model = ViT(device=device).to(device)

In [None]:
from tqdm import tqdm


def evaluate_accuracy(model, loader, device):
    """Return the fraction of samples in ``loader`` that ``model`` classifies correctly.

    Runs in eval mode without building an autograd graph, then restores the
    model's previous train/eval mode. Returns 0.0 for an empty loader.
    """
    was_training = model.training
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            correct += (model(x).argmax(dim=1) == y).sum().item()
            total += y.numel()
    model.train(was_training)
    return correct / total if total else 0.0


opt = torch.optim.AdamW(model.parameters(), lr=3e-4)
criterion = nn.CrossEntropyLoss()
num_epochs = 100
eval_every = 1  # evaluate on the FULL test set every `eval_every` epochs
accuracy = 0.0  # most recent test accuracy, shown in the progress bar

# Training uses a larger batch than the loader built in the data-preparation cell.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=512, shuffle=True)

for epoch in range(num_epochs):  # exactly num_epochs epochs
    model.train()
    loop = tqdm(train_loader, leave=False)
    loop.set_description(f"Epoch : [{epoch + 1}/{num_epochs}]")
    for x, y in loop:
        x, y = x.to(device), y.to(device)
        loss = criterion(model(x), y)
        opt.zero_grad()
        loss.backward()
        opt.step()
        loop.set_postfix(loss=loss.item(), accuracy=accuracy)
    if (epoch + 1) % eval_every == 0:
        accuracy = evaluate_accuracy(model, test_loader, device)

accuracy

In [None]:
# Total number of model parameters, in millions.
sum(p.numel() for p in model.parameters()) / 1e6

In [None]:
# Return cached, unused GPU memory from PyTorch's allocator (nothing to do on CPU-only runs).
if torch.cuda.is_available():
    torch.cuda.empty_cache()