In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [2]:
# Helper functions
def img_to_patch(x, patch_size):
    '''
    Cut a batch of images into non-overlapping square patches and flatten
    each patch into a vector.

    Args:
        x: tensor of shape (B, C, H, W).
        patch_size: side length P of each patch; H and W must both be
            multiples of P, otherwise the reshape below raises.

    Returns:
        Tensor of shape (B, N, C * P * P) with N = (H // P) * (W // P).
        Patches are ordered row by row over the image grid; inside a patch
        the values are laid out channel first, then row, then column.
    '''
    batch, channels, height, width = x.shape
    grid_rows = height // patch_size
    grid_cols = width // patch_size

    # Split each spatial axis into (grid index, offset inside the patch).
    tiles = x.reshape(batch, channels, grid_rows, patch_size, grid_cols, patch_size)
    # Bring the two grid axes forward so every patch becomes one contiguous
    # (C, P, P) slab: (B, grid_rows, grid_cols, C, P, P).
    tiles = tiles.permute(0, 2, 4, 1, 3, 5)
    # Merge the grid axes into the patch index and the slab into a flat vector.
    return tiles.reshape(batch, grid_rows * grid_cols, channels * patch_size * patch_size)

In [9]:
# Defining my ViT's architecture
class ViTEncoder(nn.Module):
    '''
    A single pre-norm transformer encoder block:

        y   = x + SelfAttention(LayerNorm(x))
        out = y + MLP(LayerNorm(y))

    Args:
        input_dim: width of every token embedding (the last dimension of the
            input); this is the embedding size, not the number of tokens.
        hidden_dim: width of the hidden layer inside the MLP.
        num_heads: number of attention heads.
        dropout: dropout probability used after each MLP layer.

    The attention layer is built with PyTorch's default layout, so `forward`
    takes tokens shaped (sequence, batch, input_dim) and returns the same shape.
    '''
    def __init__(self,input_dim,hidden_dim,num_heads,dropout = 0.5):
        super().__init__()
        # Attention sub-layer
        self.norm1 = nn.LayerNorm(input_dim)
        self.attn = nn.MultiheadAttention(input_dim, num_heads)
        # Feed-forward sub-layer
        self.norm2 = nn.LayerNorm(input_dim)
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim,input_dim)
        self.drop = nn.Dropout(dropout)

    def _mlp(self, tokens):
        # Position-wise MLP: every token is processed independently with
        # shared weights; dropout follows each linear layer.
        hidden = self.drop(F.gelu(self.fc1(tokens)))
        return self.drop(self.fc2(hidden))

    def forward(self,x):
        # Self-attention: query, key and value are all the normalised input.
        normed = self.norm1(x)
        attended = self.attn(normed, normed, normed)[0]
        # First residual connection (around attention).
        after_attn = x + attended
        # Second residual connection (around the MLP).
        return self._mlp(self.norm2(after_attn)) + after_attn

# Defining the ViTClassifier's architecture
class ViTClassifier(nn.Module):
    '''
    Vision-Transformer style image classifier.

    Pipeline: image -> flattened patches -> linear embedding (+ ReLU) ->
    prepend a learnable class token -> add learnable position embeddings ->
    stack of ViTEncoder blocks -> 2-layer classification head applied to the
    class token.

    Args:
        embed_size: width of every token embedding.
        hidden_size: hidden width of the MLP inside each encoder block.
        hidden_class_size: hidden width of the classification head.
        num_encoders: number of stacked ViTEncoder blocks.
        num_heads: attention heads per encoder block.
        patch_size: side length of the square patches cut from the image.
        num_patches: maximum number of patches per image; fixes the size of
            the position-embedding table (num_patches + 1 rows, the extra
            row being for the class token).
        dropout: dropout probability for the embeddings and encoder MLPs.
        in_channels: number of image channels (default 3, i.e. RGB).
        num_classes: number of output logits (default 100, for CIFAR-100).
    '''
    def __init__(self,embed_size,hidden_size,hidden_class_size,num_encoders,num_heads,patch_size,num_patches,dropout = 0.5,in_channels = 3,num_classes = 100):
        super().__init__()

        # Important parameters
        self.patch_size = patch_size
        self.num_patches = num_patches

        # Linear layer that embeds each flattened patch
        # (in_channels * patch_size**2 values) into embed_size dimensions.
        self.input = nn.Linear(in_channels*(patch_size**2), embed_size)
        self.drop = nn.Dropout(dropout)

        # Stack of encoder blocks; the class token that comes out of the last
        # block is what gets classified.
        self.transformer = nn.Sequential(
            *(ViTEncoder(embed_size,hidden_size,num_heads,dropout) for _ in range(num_encoders))
        )

        # Classification head
        self.fc1 = nn.Linear(embed_size,hidden_class_size)
        self.fc2 = nn.Linear(hidden_class_size,num_classes)

        # Learnable class token and position embeddings
        self.class_embed = nn.Parameter(torch.randn(1,1,embed_size))
        self.pos_embed = nn.Parameter(torch.randn(1,1 + num_patches,embed_size))

    def forward(self,x):
        '''
        Args:
            x: image batch of shape (B, in_channels, H, W), with H and W
               multiples of patch_size.

        Returns:
            Logits of shape (B, num_classes).

        Raises:
            ValueError: if the image yields more patches than num_patches.
        '''
        # Flatten patches first from the input and create embeddings
        x = img_to_patch(x,self.patch_size)
        x = F.relu(self.input(x))
        B,N,_ = x.shape

        # Fail loudly instead of relying on a broadcasting error (or, for a
        # one-row table, a silent broadcast) when the position-embedding
        # table is too small for this input.
        if N > self.num_patches:
            raise ValueError(
                f'input produces {N} patches but the model was built with num_patches={self.num_patches}'
            )

        # Prepend the class token. expand() yields a broadcast view, so no
        # per-sample copy is made before the concatenation.
        class_embed = self.class_embed.expand(B,-1,-1)
        x = torch.cat([class_embed, x], dim = 1)
        # Only the first N + 1 position embeddings are used.
        x = x + self.pos_embed[:, :N+1]
        x = self.drop(x)

        # The encoder blocks use nn.MultiheadAttention's default layout,
        # (sequence, batch, embed), so swap the first two axes.
        x = x.transpose(0,1)
        x = self.transformer(x)
        # Position 0 along the sequence axis is the class token (it was
        # concatenated in front of the patches above).
        x = x[0]

        # Classify the class token vector
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [7]:
# Per-channel (R, G, B) mean and standard deviation of the CIFAR-100 training
# images after scaling to [0, 1]. These are the commonly quoted values; the
# previous numbers were the ImageNet statistics. Recompute from
# train_dataset.data if exact values matter.
CIFAR100_MEAN = (0.5071, 0.4865, 0.4409)
CIFAR100_STD = (0.2673, 0.2564, 0.2762)
BATCH_SIZE = 64

transform = transforms.Compose([
    transforms.ToTensor(), # Convert PIL image to PyTorch tensor
    transforms.Normalize(CIFAR100_MEAN, CIFAR100_STD) # Normalize with mean and std dev for CIFAR-100
])

train_dataset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform)

# Shuffle only the training split; keep evaluation order fixed.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [None]:
# Defining parameters for training
IMAGE_SIZE = 32  # CIFAR-100 images are 32x32 pixels
PATCH_SIZE = 16
# Derive the patch count from the image geometry so it cannot drift out of
# sync with PATCH_SIZE (16x16 patches on a 32x32 image give 4, not 16).
# NOTE(review): 4 tokens per image is very coarse; a smaller PATCH_SIZE
# (e.g. 4 -> 64 patches) is probably worth trying.
NUM_PATCHES = (IMAGE_SIZE // PATCH_SIZE) ** 2

# Train on the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = ViTClassifier(embed_size=768,hidden_size=512,hidden_class_size=512,num_encoders=4,num_heads=4,patch_size=PATCH_SIZE,num_patches=NUM_PATCHES).to(device)
criterion = nn.CrossEntropyLoss()
# NOTE(review): previously logged losses stay around 4.2-4.5, close to
# ln(100) ~= 4.6 (chance level for 100 classes). The learning rate and the
# default dropout of 0.5 are the first things to revisit - TODO confirm.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
num_epochs = 5

# Make sure dropout is active even if the model was switched to eval mode by
# an earlier run of another cell.
model.train()
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # Batches must live on the same device as the model.
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Print some information every 100 batches
        if (i+1) % 100 == 0:
            print (f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}')


Epoch [1/5], Step [100/782], Loss: 4.1604
Epoch [1/5], Step [200/782], Loss: 4.2520
Epoch [1/5], Step [300/782], Loss: 4.4809
Epoch [1/5], Step [400/782], Loss: 4.3598
Epoch [1/5], Step [500/782], Loss: 4.3857
Epoch [1/5], Step [600/782], Loss: 4.4208
