In [1]:
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms
from torchvision.transforms import Compose, ToTensor, Resize, Normalize
from tqdm.notebook import tqdm 
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, random_split
import torch.optim as optim

In [62]:
#Hyperparameters
patch_size=16
channels=3
latent_size=256
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
n_classes=10
batch_size=64
input_size=224
n_heads=4
base_lr=0.005
epochs=5

In [66]:
class Image_Embeddings(nn.Module):
    def __init__(self, patch_size, input_size=input_size,channels=channels, latent_size=latent_size):
        super(Image_Embeddings, self).__init__()
        self.p_size=patch_size
        self.ch=channels
        self.n_patches=(input_size**2)//patch_size
        self.in_size=(self.p_size**2)*self.ch # flattened patches
        self.l_size=latent_size
        self.device=device
        #size of the embeddings= (batch_size, n_size, patch_size**2 *n_channels)
        self.linear=nn.Linear(self.in_size, self.l_size)
        # After Linear Layer, in_size=(batch_size, n, latent_size)       
        
    def forward(self, img_data):
        # Divide the image into patches
        b, c, w, h=img_data.shape
        tokens=torch.empty(b, self.n_patches, self.in_size)
        for B in range(b):
            n=0
            for i in range(0, w-self.p_size, self.p_size):
                for j in range(0, h-self.p_size, self.p_size):
                    patch=img_data[B, :, i:i+self.p_size, j:j+self.p_size]
                    patch=torch.flatten(patch)
                    tokens[B,n]=patch
                    n+=1
                    
        # Linear projection of the tokens
        embeddings=self.linear(tokens)
        # print(embeddings.shape)
        
        # Add the positional encodings
        #d_model=self.l_size
        positions=torch.arange(self.n_patches).unsqueeze(1)
        angs = 10000**(torch.arange(self.l_size)/self.l_size).float()
        pos_enc=torch.zeros(b, self.n_patches, self.l_size)
        pos_enc[:,0::2]=torch.sin(positions[0::2]/angs)
        pos_enc[:, 1::2]=torch.cos(positions[1::2]/angs)
        embeddings+=pos_enc
        # print(embeddings.shape)
        
        # Append the class token
        # Random class token initialized
        self.class_token=nn.Parameter(torch.randn(b, 1, self.l_size)*0.1)
        # After appending the class token in_size= (batch_size, n+1, latent_size)
        embeddings=torch.cat((self.class_token, embeddings ), dim=1)
        # print(embeddings.shape)
        
        return embeddings

In [67]:
class Encoder(nn.Module):
    def __init__(self,patch_size, n_heads, input_size=input_size,channels=channels, latent_size=latent_size):
        super(Encoder, self).__init__()
        self.l_size=latent_size
        self.ch=channels
        self.heads=n_heads
        self.device=device
        
        # Image patch embeddings
        self.embedding=Image_Embeddings(patch_size, latent_size)

        #Layer Norm
        self.norm=nn.LayerNorm(self.l_size)
        
        # Self Attention Layer
        self.attention=nn.MultiheadAttention(self.l_size, self.heads)
        
        #Encoder Linear Layer
        self.mlp=nn.Sequential(
            nn.Linear(self.l_size, self.l_size*4),
            nn.GELU(),
            nn.Linear(self.l_size*4, self.l_size))
        
    def forward(self, img_data):
        embeddings=self.embedding.forward(img_data)
        n1=self.norm(embeddings)
        
        n1=n1.permute(1, 0,2)
        attn_out, att_weight=self.attention(n1,n1,n1)
        attn_out=attn_out.permute(1,0, 2)
        
        a1=attn_out+embeddings
        n2=self.norm(a1)
        linear=self.mlp(n2)
        a2=linear+a1
        # return the final output and the attention weights
        return a2, att_weight
        

In [68]:
class ViT(nn.Module):
    def __init__(self, n_classes,patch_size, n_heads,  channels=channels, latent_size=latent_size):
        super(ViT, self).__init__()
        self.l_size=latent_size
        self.classes=n_classes
        self.heads=n_heads
        self.device=device
        
        self.encoder=Encoder(patch_size,  self.heads)

        self.mlp=nn.Sequential(
            nn.Linear(self.l_size, self.classes), 
            nn.Softmax(dim=-1))
        
    def forward(self, input):
        out, wgt=self.encoder.forward(input)
        # Input only the class token to the linear layer
        mlp_out=self.mlp(out[:,0])
        return mlp_out, wgt


# Datastet Uplaod (CIFAR10)

In [69]:
manual_transforms =Compose([
    transforms.ToTensor(),
    transforms.Resize((224,224)),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])  
cifar_trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=manual_transforms,target_transform=transforms.Compose([
        lambda x: torch.tensor(x)]))
cifar_testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=manual_transforms, target_transform=transforms.Compose([
        lambda x: torch.tensor(x)]))

Files already downloaded and verified
Files already downloaded and verified


In [70]:
trainloader = DataLoader(cifar_trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)
testloader=DataLoader(cifar_testset,  batch_size=batch_size,shuffle=False)

In [71]:
train_features, train_labels = next(iter(trainloader))




In [72]:
em=Image_Embeddings(patch_size)
em.forward(train_features)

tensor([[[ 0.0177,  0.1881,  0.0836,  ..., -0.2007,  0.0250,  0.0124],
         [-0.1099,  0.4627, -0.5751,  ...,  1.4927, -0.2044, -0.2491],
         [ 0.4519,  1.0325,  0.0133,  ...,  2.4912,  0.7881,  0.7443],
         ...,
         [-0.6415,  0.9900,  0.9935,  ...,  0.9375,  0.9386,  0.9160],
         [-0.9342,  0.8537,  0.8464,  ...,  0.3399,  0.3252,  0.2874],
         [ 0.9843, -0.4255, -0.3688,  ...,  0.9375,  0.9386,  0.9159]],

        [[ 0.0363, -0.0824, -0.1022,  ..., -0.0647, -0.0539, -0.0138],
         [ 0.1449, -0.2179,  0.2876,  ..., -0.6655,  0.1462,  0.1101],
         [ 0.6789,  0.3621,  0.8747,  ...,  0.3724,  1.1429,  1.1091],
         ...,
         [-0.6415,  0.9900,  0.9935,  ...,  0.9375,  0.9386,  0.9160],
         [-0.9342,  0.8537,  0.8464,  ...,  0.3399,  0.3252,  0.2874],
         [ 0.9843, -0.4255, -0.3688,  ...,  0.9375,  0.9386,  0.9159]],

        [[-0.0244, -0.1697, -0.1165,  ...,  0.0233, -0.1119,  0.0629],
         [ 0.0709,  0.0073, -0.0892,  ...,  0

# Training

In [74]:

def train(model, dataloader=trainloader, epochs=epochs,base_lr=base_lr, device=device, criterion = nn.CrossEntropyLoss()):
    optimizer = optim.Adam(model.parameters(), lr=base_lr)
    train_losses = []
    acc=[]
    for epoch in tqdm(range(epochs), total=epochs):
        train_loss = 0.0
        total_samples=0
        total_correct=0
        num_batches=0
        print("Epoch {}:".format(epoch))

        for (batch_img, batch_labels) in tqdm(dataloader):
            # batch_img, batch_labels=batch_img.to(device), batch_labels.to(device)
            outputs, _ = model.forward(batch_img)  
            
            loss = criterion(outputs, batch_labels)
            train_loss += loss.detach().cpu().item() / len(dataloader)
            # print(outputs.shape)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            _, predicted = torch.max(outputs, 1)
            total_samples += batch_labels.size(0)
            total_correct += (predicted == batch_labels).sum().item()

            num_batches+=1
            # print('Batch {} :  loss={}'.format(batch_idx,  running_loss))   
        
        
        train_losses.append(train_loss)             

        accuracy = total_correct / total_samples
        acc.append(accuracy)
        print(f"Accuracy on train set: {accuracy*100:.2f}%")
        print(f"Epoch {epoch + 1}/{epochs} loss: {train_loss:.2f}")
    return train_losses, acc

In [85]:
vitModel1=ViT(n_classes,patch_size, n_heads)


In [86]:
loss1, acc1=train(vitModel1)

  0%|          | 0/5 [00:00<?, ?it/s]

Epoch 0:


  0%|          | 0/782 [00:00<?, ?it/s]

