In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

# Patch Embedding
class PatchEmbedding(nn.Module):
    def __init__(self, img_size=32, patch_size=4, in_channels=3, embedding_dim=128):
        super().__init__()
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        self.projection = nn.Conv2d(in_channels, embedding_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.projection(x)  # Shape: [batch_size, embedding_dim, num_patches^(0.5), num_patches^(0.5)]
        x = x.flatten(2)  # Flatten height and width
        x = x.transpose(1, 2)  # Shape: [batch_size, num_patches, embedding_dim]
        return x

# Positional Encoding
class PositionalEncoding(nn.Module):
    def __init__(self, embedding_dim, num_patches):
        super().__init__()
        self.pos_encoding = nn.Parameter(torch.randn(1, num_patches, embedding_dim))

    def forward(self, x):
        return x + self.pos_encoding

# Transformer Block
class TransformerBlock(nn.Module):
    def __init__(self, embedding_dim, num_heads, mlp_dim, dropout=0.1):
        super().__init__()
        self.attention = nn.MultiheadAttention(embedding_dim, num_heads, dropout=dropout)
        self.norm1 = nn.LayerNorm(embedding_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embedding_dim, mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, embedding_dim),
            nn.Dropout(dropout),
        )
        self.norm2 = nn.LayerNorm(embedding_dim)

    def forward(self, x):
        # Self-attention
        attn_output, _ = self.attention(x, x, x)
        x = self.norm1(x + attn_output)
        # Feed-forward network
        mlp_output = self.mlp(x)
        x = self.norm2(x + mlp_output)
        return x

# Vision Transformer Model
class VisionTransformer(nn.Module):
    def __init__(self, img_size=32, patch_size=4, in_channels=3, num_classes=10, 
                 embedding_dim=128, num_heads=8, mlp_dim=256, num_layers=6, dropout=0.1):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embedding_dim)
        self.pos_encoding = PositionalEncoding(embedding_dim, self.patch_embed.num_patches)
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(embedding_dim, num_heads, mlp_dim, dropout) for _ in range(num_layers)
        ])
        self.classifier = nn.Linear(embedding_dim, num_classes)

    def forward(self, x):
        x = self.patch_embed(x)
        x = self.pos_encoding(x)
        for block in self.transformer_blocks:
            x = block(x)
        x = x.mean(dim=1)  # Global average pooling (for classification)
        return self.classifier(x)


In [2]:
from tqdm import tqdm
# Define device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# CIFAR-10 dataset loading
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)

# Instantiate model and move to GPU
model = VisionTransformer().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training loop
for epoch in range(50):  # Adjust the number of epochs
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        
        # Move data to GPU
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 100 == 99:  # Print every 100 mini-batches
            print(f'Epoch {epoch + 1}, Batch {i + 1}, Loss: {running_loss / 100:.3f}')
            running_loss = 0.0

# Testing loop
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        
        # Move data to GPU
        images, labels = images.to(device), labels.to(device)

        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy on the 10000 test images: {100 * correct / total}%')


Using device: cuda
Files already downloaded and verified
Files already downloaded and verified
Epoch 1, Batch 100, Loss: 2.101
Epoch 1, Batch 200, Loss: 1.939
Epoch 1, Batch 300, Loss: 1.870
Epoch 1, Batch 400, Loss: 1.786
Epoch 1, Batch 500, Loss: 1.737
Epoch 1, Batch 600, Loss: 1.685
Epoch 1, Batch 700, Loss: 1.658
Epoch 2, Batch 100, Loss: 1.571
Epoch 2, Batch 200, Loss: 1.540
Epoch 2, Batch 300, Loss: 1.528
Epoch 2, Batch 400, Loss: 1.530
Epoch 2, Batch 500, Loss: 1.507
Epoch 2, Batch 600, Loss: 1.458
Epoch 2, Batch 700, Loss: 1.465
Epoch 3, Batch 100, Loss: 1.416
Epoch 3, Batch 200, Loss: 1.400
Epoch 3, Batch 300, Loss: 1.351
Epoch 3, Batch 400, Loss: 1.379
Epoch 3, Batch 500, Loss: 1.371
Epoch 3, Batch 600, Loss: 1.345
Epoch 3, Batch 700, Loss: 1.334
Epoch 4, Batch 100, Loss: 1.283
Epoch 4, Batch 200, Loss: 1.306
Epoch 4, Batch 300, Loss: 1.268
Epoch 4, Batch 400, Loss: 1.291
Epoch 4, Batch 500, Loss: 1.251
Epoch 4, Batch 600, Loss: 1.302
Epoch 4, Batch 700, Loss: 1.237
Epoch 5, 