In [None]:
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import torch
import torch.nn as nn

In [None]:
import torch
from torchvision import datasets, transforms

# Preprocessing: PIL image -> float tensor in [0, 1] -> per-channel (v - 0.5) / 0.5, i.e. roughly [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# CIFAR-10 train and test splits, downloaded into ./data if not already present.
train_data, test_data = (
    datasets.CIFAR10(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Batched iterators: shuffle only the training split so evaluation order is stable.
train_loader, test_loader = (
    torch.utils.data.DataLoader(split, batch_size=256, shuffle=reshuffle)
    for split, reshuffle in ((train_data, True), (test_data, False))
)

In [None]:
class Patching(nn.Module):
    """Split an image into non-overlapping square patches and embed each one.

    A Conv2d whose kernel and stride both equal `patch_size` yields one
    `embedding_dim` vector per patch; the 2-D patch grid is then flattened
    into a single token axis.

    Args:
        in_channels: number of input image channels.
        patch_size: side length of each square patch, in pixels.
        embedding_dim: length of each patch's embedding vector.
    """

    def __init__(self,
                 in_channels=3,
                 patch_size=16,
                 embedding_dim=768,
                ):
        super().__init__()
        projector = nn.Conv2d(
            in_channels,
            embedding_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )
        # Both layers live in one Sequential named `patch`, so parameters are
        # registered as patch.0.weight / patch.0.bias.
        self.patch = nn.Sequential(projector, nn.Flatten(start_dim=2, end_dim=3))

    def forward(self, x):
        """Map (B, in_channels, H, W) images to (B, num_patches, embedding_dim) tokens."""
        grid_tokens = self.patch(x)  # (B, embedding_dim, num_patches)
        return grid_tokens.transpose(-2, -1)

In [None]:
# Explicit hyper-parameters: a later cell redefines `Patching` with different
# defaults (patch_size=4, embedding_dim=48), which would otherwise silently change this cell.
p = Patching(in_channels=3, patch_size=16, embedding_dim=768)

In [None]:
# Load the sample image as an (H, W, 3) array. `convert('RGB')` guarantees three
# channels (grayscale / RGBA / palette files would otherwise break the channel-first
# permute and the 3-channel Conv2d), and the `with` block closes the file handle.
with Image.open('art.jpg') as art_image:
    img = np.array(art_image.convert('RGB'))

In [None]:
# (H, W, C) image array -> (1, C, H, W) float32 batch, the layout Conv2d expects.
img_ = torch.tensor(img).permute(2, 0, 1)[None].to(torch.float32)

In [None]:
# Token grid produced by the patch embedding: (batch, num_patches, embedding_dim).
p(img_).size()

In [None]:
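# Shape walk-through for a 512x512 RGB image with patch_size=16:
# image, batched (N, H, W, C):                   1,512,512,3
# cut into a 32x32 grid of 16x16 patches:        1,32,32,16,16   (16*16 = 256 pixels per channel)
# patch embedding output (N, patches, emb_dim):  1,1024,768      (1024 = 32*32, 768 = embedding_dim = 16*16*3)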

In [None]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import random

class Patching(nn.Module):
    """Split an image into non-overlapping square patches and embed each one.

    A Conv2d whose kernel and stride both equal `patch_size` yields one
    `embedding_dim` vector per patch; the 2-D patch grid is then flattened
    into a single token axis. Defaults here are sized for small images
    (4x4 patches, 48-dim embeddings).

    Args:
        in_channels: number of input image channels.
        patch_size: side length of each square patch, in pixels.
        embedding_dim: length of each patch's embedding vector.
    """

    def __init__(self,
                 in_channels=3,
                 patch_size=4,
                 embedding_dim=48,
                ):
        super().__init__()
        projector = nn.Conv2d(
            in_channels,
            embedding_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )
        # Both layers live in one Sequential named `patch`, so parameters are
        # registered as patch.0.weight / patch.0.bias.
        self.patch = nn.Sequential(projector, nn.Flatten(start_dim=2, end_dim=3))

    def forward(self, x):
        """Map (B, in_channels, H, W) images to (B, num_patches, embedding_dim) tokens."""
        grid_tokens = self.patch(x)  # (B, embedding_dim, num_patches)
        return grid_tokens.transpose(-2, -1)

class Head(nn.Module):
    """One head of scaled dot-product self-attention.

    Args:
        n_embed: width of the incoming token embeddings.
        head_size: width of this head's query/key/value projections.
    """

    def __init__(self, n_embed, head_size):
        super().__init__()
        self.n_embed = n_embed
        self.head_size = head_size
        self.query = nn.Linear(n_embed, head_size, bias=False)
        self.key = nn.Linear(n_embed, head_size, bias=False)
        self.value = nn.Linear(n_embed, head_size, bias=False)

    def forward(self, x, attention_mask=None):
        """Let every token attend to every (unmasked) token.

        Args:
            x: tensor of shape (B, T, n_embed) (torch.bmm requires 3-D inputs).
            attention_mask: optional (B, T) tensor; positions equal to 0 are keys
                that no query may attend to. A row whose keys are all masked
                produces NaN (softmax over all -inf), so keep at least one key.

        Returns:
            Tensor of shape (B, T, head_size).
        """
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)
        # scores[b, i, j] = <q_i, k_j> / sqrt(d_k): row i is query i, columns are keys,
        # so the softmax below normalises over keys.
        scores = torch.bmm(q, k.transpose(-2, -1)) * (self.head_size ** -0.5)
        if attention_mask is not None:
            # Mask in logit space: multiplying a logit by 0 still leaves exp(0) = 1 weight.
            key_mask = attention_mask.unsqueeze(1).bool()  # (B, 1, T), broadcast over queries
            scores = scores.masked_fill(~key_mask, float('-inf'))
        weights = F.softmax(scores, dim=-1)
        return torch.bmm(weights, v)

class MultiHead(nn.Module):
    """Multi-head self-attention: run `n_heads` independent `Head`s and mix their outputs.

    Args:
        n_embed: token width of the input and of the projected output.
        head_size: output width of each individual head.
        n_heads: number of heads whose outputs are concatenated.
    """

    def __init__(self, n_embed, head_size, n_heads):
        super().__init__()
        self.heads = nn.ModuleList([Head(n_embed, head_size) for _ in range(n_heads)])
        # The concatenated heads are n_heads * head_size wide; project back to n_embed.
        # When n_heads * head_size == n_embed this is exactly Linear(n_embed, n_embed).
        self.proj = nn.Linear(n_heads * head_size, n_embed)

    def forward(self, x, attention_mask=None):
        """Return (B, T, n_embed) for input x of shape (B, T, n_embed).

        `attention_mask` (optional) is passed unchanged to every head.
        """
        out = torch.cat([head(x, attention_mask) for head in self.heads], dim=-1)
        return self.proj(out)

class FeedForward(nn.Module):
    """Position-wise MLP: expand to n_embed * mlp_ratio, ReLU, project back, then dropout.

    Args:
        n_embed: input and output width.
        mlp_ratio: multiplier giving the hidden width.
        dropout: dropout probability applied to the output.
    """

    def __init__(self, n_embed, mlp_ratio, dropout):
        super().__init__()
        hidden_dim = n_embed * mlp_ratio
        expand = nn.Linear(n_embed, hidden_dim)
        contract = nn.Linear(hidden_dim, n_embed)
        self.net = nn.Sequential(expand, nn.ReLU(), contract, nn.Dropout(dropout))

    def forward(self, x):
        projected = self.net(x)
        return projected

class Block(nn.Module):
    """Pre-LayerNorm transformer encoder block: x + MHA(LN(x)), then x + FFN(LN(x)).

    Args:
        n_embed: token width.
        head_size: width of each attention head.
        n_heads: number of attention heads.
        mlp_ratio: hidden-width multiplier of the feed-forward sub-layer.
        dropout: dropout probability used inside the feed-forward sub-layer.
    """

    def __init__(self, n_embed, head_size, n_heads, mlp_ratio, dropout):
        super().__init__()
        self.multihead = MultiHead(n_embed, head_size, n_heads)
        self.ffwd = FeedForward(n_embed, mlp_ratio, dropout)
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)

    def forward(self, x, attention_mask=None):
        # Normalise only the sub-layer input; the residual path must carry the
        # un-normalised x so information and gradients can bypass each sub-layer.
        x = x + self.multihead(self.ln1(x), attention_mask)
        x = x + self.ffwd(self.ln2(x))
        return x


class ViT(nn.Module):
    """Vision Transformer classifier.

    Images are cut into non-overlapping `patch_size` patches and embedded. A learnable
    class token is prepended, and learned positional embeddings are added. The tokens
    then pass through `n_layers` transformer blocks, and the logits are read from the
    class token.

    Args:
        in_channels: image channels.
        patch_size: side length of each square patch.
        embedding_dim: token width.
        head_size: per-head width inside each block.
        n_heads: number of attention heads per block.
        n_layers: number of transformer blocks.
        dropout: dropout probability inside each block's feed-forward layer.
        mlp_ratio: hidden-width multiplier for the feed-forward layers and classifier head.
        device: unused; accepted for backward compatibility (forward follows the input's device).
        block_size: number of patch tokens per image, (H / patch_size) * (W / patch_size);
            one extra position is reserved for the class token. 64 fits 32x32 images with patch_size=4.
        num_classes: number of output logits.
    """

    def __init__(self, in_channels=3, patch_size=4, embedding_dim=48, head_size=12, n_heads=4, n_layers=7, dropout=0.4, mlp_ratio=2, device='cuda', block_size=64, num_classes=10):
        super().__init__()
        self.patch_embedding = Patching(in_channels, patch_size, embedding_dim)
        # One learned position per patch plus one for the class token.
        self.positional_embedding = nn.Embedding(block_size + 1, embedding_dim)
        self.blocks = nn.ModuleList([Block(embedding_dim, head_size, n_heads, mlp_ratio, dropout) for _ in range(n_layers)])
        self.ln = nn.LayerNorm(embedding_dim)
        self.class_embedding = nn.Parameter(torch.zeros(1, 1, embedding_dim))
        self.cl_head = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim * mlp_ratio),
            nn.ReLU(),
            nn.Linear(embedding_dim * mlp_ratio, num_classes),
        )
        # NOTE(review): not used by forward(); kept so existing state_dicts still load.
        self.sequence_pooling = nn.Linear(embedding_dim, 1)

    def forward(self, x, attention_mask=None, targets=None):
        """Return class logits of shape (B, num_classes) for images x of shape (B, C, H, W).

        `targets` is accepted but unused. `attention_mask`, if given, is forwarded
        unchanged to every block, so it must include position 0 (the class token).
        """
        tokens = self.patch_embedding(x)  # (B, num_patches, E)
        cls = self.class_embedding.expand(tokens.shape[0], -1, -1)
        # Class token goes FIRST so that the x[:, 0] read below actually picks it.
        tokens = torch.cat([cls, tokens], dim=1)  # (B, num_patches + 1, E)
        n_tokens = tokens.shape[1]
        # Position ids are built on the input's device (no reliance on a global `device`).
        positions = torch.arange(n_tokens, device=tokens.device)
        x = tokens + self.positional_embedding(positions)
        for block in self.blocks:
            x = block(x, attention_mask)
        x = self.ln(x)
        return self.cl_head(x[:, 0, :])

In [None]:
# Smoke test on a random CIFAR-sized batch: expect one logit per class, i.e. (32, 10).
# no_grad: only the output shape matters here, so skip building an autograd graph.
model = ViT(device='cpu')
with torch.no_grad():
    smoke_logits = model(torch.randn(32, 3, 32, 32))
assert smoke_logits.shape == (32, 10)
smoke_logits.shape

In [None]:
# Use the GPU when one is available (ViT's own `device` default is 'cuda'), otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(0)  # reproducible weight initialisation
model = ViT(device=device).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

In [None]:
# Train `model` with the `opt` / `criterion` created in the previous cell; re-running
# this cell continues training with the same optimizer state instead of resetting it.
from tqdm import tqdm

num_epochs = 50
eval_every = 2   # evaluate on the full test set every `eval_every` epochs (and after the last one)
accuracy = 0.0   # latest test accuracy, shown in the progress bar
# Larger training batches than the 256 used when the loaders were first created.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=512, shuffle=True)

for epoch in range(num_epochs):
    model.train()
    loop = tqdm(train_loader, leave=False)
    loop.set_description(f"Epoch : [{epoch}/{num_epochs}]")
    for x, y in loop:
        x = x.to(device)
        y = y.to(device)
        pred = model(x)
        loss = criterion(pred, y)
        opt.zero_grad()
        loss.backward()
        opt.step()
        loop.set_postfix(loss=loss.item(), accuracy=accuracy)

    if epoch % eval_every == 0 or epoch == num_epochs - 1:
        # Score every test batch (not just the first) without building an autograd graph.
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for x, y in test_loader:
                x = x.to(device)
                y = y.to(device)
                correct += (model(x).argmax(dim=1) == y).sum().item()
                total += y.numel()
        accuracy = correct / total