In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Define the Vision Transformer
class VisionTransformer(nn.Module):
    def __init__(self, img_size=32, patch_size=4, in_channels=3, num_classes=10, dim=128, depth=6, heads=8, mlp_dim=256, dropout=0.1):
        super(VisionTransformer, self).__init__()
        assert img_size % patch_size == 0, "Image size must be divisible by patch size"
        num_patches = (img_size // patch_size) ** 2
        self.patch_size = patch_size
        self.dim = dim

        self.patch_embeddings = nn.Conv2d(in_channels, dim, kernel_size=patch_size, stride=patch_size)
        self.class_token = nn.Parameter(torch.randn(1, 1, dim))
        self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, dim))

        self.transformer = nn.Transformer(d_model=dim, nhead=heads, num_encoder_layers=depth, num_decoder_layers=depth, dim_feedforward=mlp_dim, dropout=dropout)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, x):
        batch_size = x.size(0)
        x = self.patch_embeddings(x)  # (B, dim, H/P, W/P)
        x = x.flatten(2)  # (B, dim, num_patches)
        x = x.transpose(1, 2)  # (B, num_patches, dim)

        cls_tokens = self.class_token.expand(batch_size, -1, -1)  # (B, 1, dim)
        x = torch.cat((cls_tokens, x), dim=1)  # (B, num_patches+1, dim)
        x += self.position_embeddings  # (B, num_patches+1, dim)

        x = x.transpose(0, 1)  # (num_patches+1, B, dim)
        x = self.transformer(x, x)  # (num_patches+1, B, dim)
        x = x.transpose(0, 1)  # (B, num_patches+1, dim)

        x = self.mlp_head(x[:, 0])  # (B, dim)

        return x

# Function to train and evaluate the model
def _cifar10_loaders(data_root, batch_size, eval_batch_size=1000, num_workers=2):
    """Return (train_loader, test_loader) for CIFAR-10, downloading into `data_root` if needed.

    Images go through ToTensor and a per-channel Normalize(mean=0.5, std=0.5),
    which maps ToTensor's [0, 1] range to [-1, 1]. The train loader shuffles;
    the test loader keeps dataset order.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    trainset = datasets.CIFAR10(root=data_root, train=True, download=True, transform=transform)
    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    testset = datasets.CIFAR10(root=data_root, train=False, download=True, transform=transform)
    testloader = DataLoader(testset, batch_size=eval_batch_size, shuffle=False, num_workers=num_workers)
    return trainloader, testloader


def _train_one_epoch(model, loader, criterion, optimizer, device, epoch, log_every=100):
    """Run one optimization pass over `loader`, printing the mean loss every `log_every` batches."""
    model.train()
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(loader):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % log_every == log_every - 1:
            print(f'[Epoch {epoch + 1}, Batch {i + 1}] loss: {running_loss / log_every:.3f}')
            running_loss = 0.0


@torch.no_grad()
def _evaluate(model, loader, device):
    """Return (correct, total): top-1 prediction counts of `model` over `loader`."""
    model.eval()
    correct = 0
    total = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        predicted = model(images).argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    return correct, total


def train_and_evaluate_vit(num_epochs=10, batch_size=64, lr=3e-4, data_root='./data', seed=0, device=None):
    """Train a VisionTransformer on CIFAR-10 with AdamW and report test accuracy after each epoch.

    The hyperparameter defaults match the values this function previously hard-coded.

    Args:
        num_epochs: Number of passes over the training set.
        batch_size: Training mini-batch size.
        lr: AdamW learning rate.
        data_root: Directory CIFAR-10 is downloaded to / read from.
        seed: Passed to ``torch.manual_seed`` before building loaders and the
            model, so initialization and shuffling are reproducible. ``None``
            skips seeding.
        device: Device to train on. ``None`` picks CUDA when available, else CPU.

    Returns:
        The trained model. It is left in eval mode when ``num_epochs >= 1``,
        because evaluation runs last.
    """
    if seed is not None:
        torch.manual_seed(seed)
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    trainloader, testloader = _cifar10_loaders(data_root, batch_size)

    model = VisionTransformer().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    for epoch in range(num_epochs):
        _train_one_epoch(model, trainloader, criterion, optimizer, device, epoch)
        correct, total = _evaluate(model, testloader, device)
        # Report the counted total rather than a hard-coded dataset size.
        accuracy = 100 * correct / total if total else float('nan')
        print(f'[Epoch {epoch + 1}] Accuracy of the network on the {total} test images: {accuracy:.2f}%')

    return model

# Train and evaluate the Vision Transformer.
# Inside Jupyter the kernel runs cells as __main__, so this guard is a no-op here.
# If the code is exported to a script and run with the "spawn" multiprocessing
# start method, the guard stops the DataLoader worker processes (num_workers > 0)
# from re-executing the training when they import the main module.
if __name__ == "__main__":
    train_and_evaluate_vit()


Files already downloaded and verified
Files already downloaded and verified
[Epoch 1, Batch 100] loss: 2.192
[Epoch 1, Batch 200] loss: 2.073
[Epoch 1, Batch 300] loss: 2.023
[Epoch 1, Batch 400] loss: 1.988
[Epoch 1, Batch 500] loss: 1.930
[Epoch 1, Batch 600] loss: 1.902
[Epoch 1, Batch 700] loss: 1.880
Accuracy of the network on the 10000 test images: 33.25%
[Epoch 2, Batch 100] loss: 1.805
[Epoch 2, Batch 200] loss: 1.772
[Epoch 2, Batch 300] loss: 1.770
[Epoch 2, Batch 400] loss: 1.738
[Epoch 2, Batch 500] loss: 1.707
[Epoch 2, Batch 600] loss: 1.697
[Epoch 2, Batch 700] loss: 1.664
Accuracy of the network on the 10000 test images: 41.96%
[Epoch 3, Batch 100] loss: 1.643
[Epoch 3, Batch 200] loss: 1.626
[Epoch 3, Batch 300] loss: 1.611
[Epoch 3, Batch 400] loss: 1.600
[Epoch 3, Batch 500] loss: 1.584
[Epoch 3, Batch 600] loss: 1.577
[Epoch 3, Batch 700] loss: 1.560
Accuracy of the network on the 10000 test images: 45.41%
[Epoch 4, Batch 100] loss: 1.523
[Epoch 4, Batch 200] loss: 

KeyboardInterrupt: 