In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchsummary import summary
from einops.layers.torch import Rearrange

# Vision Transformer Model
class VisionTransformer(nn.Module):
    """Minimal Vision Transformer (ViT) image classifier.

    The image is cut into non-overlapping ``patch_size`` x ``patch_size`` patches
    by a strided convolution. Each patch becomes one token. A learnable class token
    is prepended and learnable positional embeddings are added. The token sequence
    then runs through a stack of ``nn.TransformerEncoderLayer``s. The encoder output
    is mean-pooled over all tokens (class token included) and projected to logits.

    Args:
        image_size: Side length, in pixels, of the square input images.
        patch_size: Side length of each square patch; must divide ``image_size``.
        num_classes: Number of output logits.
        dim: Token embedding width (``d_model``); must be divisible by ``heads``.
        depth: Number of transformer encoder layers.
        heads: Number of attention heads per layer.
        in_channels: Number of input image channels (default 3, i.e. RGB).

    Raises:
        ValueError: If ``image_size`` is not a multiple of ``patch_size``.
    """

    def __init__(self, image_size, patch_size, num_classes, dim, depth, heads, in_channels=3):
        super().__init__()

        if image_size % patch_size != 0:
            raise ValueError(
                f'image_size ({image_size}) must be divisible by patch_size ({patch_size})'
            )
        num_patches = (image_size // patch_size) ** 2

        # Each stride-`patch_size` convolution window covers exactly one patch and
        # linearly projects it to a `dim`-wide token.
        self.patch_embed = nn.Conv2d(
            in_channels, dim, kernel_size=patch_size, stride=patch_size, bias=False
        )
        # One learnable position per patch, plus one for the prepended class token.
        self.positional_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=dim,
                nhead=heads,
                dim_feedforward=dim * 4,
                activation='gelu',
                # forward() builds (batch, tokens, dim) tensors. Without this flag the
                # layer reads dim 0 as the sequence and attends across the batch.
                batch_first=True,
            ),
            num_layers=depth,
        )
        self.pooling = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(dim, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes).

        ``x`` is a (batch, in_channels, H, W) image tensor. H and W must give the
        same patch grid as ``image_size`` did at construction. Otherwise the
        positional-embedding addition fails with a shape mismatch.
        """
        x = self.patch_embed(x)                       # (B, dim, H/p, W/p)
        x = x.flatten(2).transpose(1, 2)              # (B, num_patches, dim)
        cls = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat([cls, x], dim=1)                # (B, num_patches + 1, dim)
        x = x + self.positional_embedding             # out-of-place add
        x = self.transformer(x)                       # (B, num_patches + 1, dim)
        # Average over the token axis -> (B, dim).
        x = self.pooling(x.transpose(1, 2)).squeeze(-1)
        return self.fc(x)

# Run on CUDA when PyTorch can see a GPU, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Model shape. dim/depth/heads = 768/12/12 match the ViT-Base configuration.
# With 32x32 inputs and 16x16 patches the model sees a 2x2 grid, i.e. 4 patch tokens.
image_size, patch_size = 32, 16
num_classes = 100  # CIFAR-100
dim, depth, heads = 768, 12, 12

# Optimisation settings.
lr = 1e-3
batch_size = 32
num_epochs = 10

# Preprocessing shared by both CIFAR-100 splits (adjust based on your dataset).
# ToTensor() converts each image to a float tensor, and Normalize applies
# (x - 0.5) / 0.5 per RGB channel. No augmentation is applied to the training split.
train_transform, test_transform = (
    transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])
    for _ in range(2)
)

# Load both CIFAR-100 splits into ./data, downloading them if needed (train first, then test).
train_dataset, test_dataset = (
    datasets.CIFAR100(root='./data', train=is_train, download=True, transform=tf)
    for is_train, tf in ((True, train_transform), (False, test_transform))
)

# Sanity check: shape of one preprocessed training sample.
X0, y0 = train_dataset[0]
input_shape = X0.shape
print(f'input shape {input_shape}')

# Training batches are reshuffled every epoch; evaluation order is fixed.
train_loader, test_loader = (
    DataLoader(split, batch_size=batch_size, shuffle=reshuffle, num_workers=4)
    for split, reshuffle in ((train_dataset, True), (test_dataset, False))
)

# Build the ViT from the hyperparameters above, move its parameters to `device`,
# and print the module tree.
model = VisionTransformer(
    image_size=image_size,
    patch_size=patch_size,
    num_classes=num_classes,
    dim=dim,
    depth=depth,
    heads=heads,
).to(device)
print(model)


# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

# Print the current mini-batch loss every `print_freq` training batches.
print_freq = 10

# Training loop
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    num_batches = 0
    for i, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # .item() forces a device sync, so read the scalar once per step.
        batch_loss = loss.item()
        running_loss += batch_loss
        num_batches += 1
        if i % print_freq == 0:
            print(f'Epoch {epoch + 1}/{num_epochs}, batch {i}/{len(train_loader)}, '
                  f'loss {batch_loss:.4f}')

    # Validation loop
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            predicted = model(inputs).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Report the mean training loss over the epoch (not just the last batch),
    # guarding against empty loaders.
    avg_loss = running_loss / max(num_batches, 1)
    accuracy = correct / total if total else 0.0
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}')

print('Training finished.')
