#### patch embedding
#### positional encoding
#### Transformer encoder model
#### Vision Transformer model(ViT)

In [106]:
## Import Libraries
import math
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [107]:
## Patch Embedding
class PatchEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super(PatchEmbedding, self).__init__()
        self.patch_size=patch_size
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x) # 1x768x14x14
        x = x.flatten(2)
        x = x.transpose(1, 2)
        return x

In [108]:
## Example of Patch Embedding
tensor = torch.rand(1, 3, 224, 224)
print(tensor.shape)
patch_embed_obj = PatchEmbedding()
patch_embed = patch_embed_obj(tensor)
print(patch_embed.shape)

torch.Size([1, 3, 224, 224])
torch.Size([1, 196, 768])


In [109]:
## Positional Encoding
class PositionalEncoding(nn.Module):
    def __init__(self, embed_dim, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, embed_dim) # positional encoding 5000 x embed_dim --> embed_dim=768 form PatchEmbedding
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) # max_len x 1
        div_term = torch.exp(torch.arange(0, embed_dim, 2).float() * (-math.log(10000.0) / embed_dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0) # 1 x 5000 x 768
        self.register_buffer('pe', pe)
    def forward(self, x):
        x  = x + self.pe[:, :x.size(1), :]
        return x
        

In [110]:
posn_encoding_obj = PositionalEncoding(embed_dim=768)
posn_encoding = posn_encoding_obj(patch_embed)
print(posn_encoding.shape)

torch.Size([1, 196, 768])


In [111]:
## Transformer Encoder
class TransformerEncoderLayer(nn.Module):
    def __init__(self, embed_dim, num_heads, hidden_dim, dropout=0.1):
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
        self.linear1 = nn.Linear(embed_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(hidden_dim, embed_dim)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        x2 = self.self_attn(x, x, x)[0]
        x = x + self.dropout(x2)
        x = self.norm1(x)
        x = self.linear1(x)
        x = F.relu(x)
        x = self.dropout1(x)
        x = self.linear2(x)
        x = x + self.dropout2(x)
        x = self.norm2(x)
        return x


In [112]:
## Vision Transformer(ViT) model
class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, num_classes=10, 
                 embed_dim=768, num_heads=8, hidden_dim=2048, num_layers=12, dropout=0.1):
        super(VisionTransformer, self).__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        self.pos_embed = PositionalEncoding(embed_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.dropout = nn.Dropout(dropout)
        self.transformer_layers = nn.ModuleList([
            TransformerEncoderLayer(embed_dim, num_heads, hidden_dim, dropout) for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        B = x.size(0)
        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens,x), dim=1)
        x = self.pos_embed(x)
        x = self.dropout(x)

        for layer in self.transformer_layers:
            x = layer(x)
        
        x = self.norm(x)
        cls_token_final = x[:,0]
        x = self.fc(cls_token_final)
        return x

        

In [113]:
## Prepare the CIFAR-10 dataset
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [114]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = VisionTransformer().to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

## Train loop
def train(model, train_loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    itr = 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss+=loss.item()
        itr+=1
        if itr == 20:
            break
    return running_loss / len(train_loader)

def test(model, test_loader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    itr = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            running_loss += loss.item()
            _, preds = torch.max(outputs, 1)
            correct += torch.sum(preds == labels.data)
            itr+=1
            if itr == 20:
                break
    return running_loss / len(test_loader), correct.double() / len(test_loader.dataset)

In [None]:
for epoch in range(10):
    train_loss = train(model, train_loader, criterion, optimizer, device)
    test_loss, test_accuracy = test(model, test_loader, criterion, device)

    print(f"Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")

Epoch 1, Train Loss: 0.0244, Test Loss: 0.0833, Test Accuracy: 0.0039
Epoch 2, Train Loss: 0.0156, Test Loss: 0.0767, Test Accuracy: 0.0040
Epoch 3, Train Loss: 0.0155, Test Loss: 0.0754, Test Accuracy: 0.0035
Epoch 4, Train Loss: 0.0151, Test Loss: 0.0761, Test Accuracy: 0.0026
