In [1]:
import torch
from torch import nn
import torchvision
import torchvision.transforms as transforms

In [2]:
transform = transforms.Compose([transforms.ToTensor()])
mnist_train = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
mnist_test = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=64, shuffle=False)


In [3]:
class PatchEmbedding(nn.Module):
    def __init__(self, embed_size):
        super(PatchEmbedding, self).__init__()
        # Each patch is a pixel, so in_features = 1
        self.projection = nn.Linear(1, embed_size)

    def forward(self, x):
        # x shape is [batch_size, num_patches, 1]
        x = self.projection(x)  # Project each patch
        # Output shape is [batch_size, num_patches, embed_size]
        return x


In [4]:
class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super(TransformerBlock, self).__init__()
        self.attention = nn.MultiheadAttention(embed_dim=embed_size, num_heads=heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query):
        attention = self.attention(query, key, value)[0]
        x = self.norm1(attention + query)
        forward = self.feed_forward(x)
        out = self.norm2(forward + x)
        return out


In [5]:
class ViT(nn.Module):
    def __init__(self, embed_size, num_layers, heads, num_classes, dropout, forward_expansion):
        super(ViT, self).__init__()
        self.patch_embedding = PatchEmbedding(embed_size)
        self.layers = nn.ModuleList([
            TransformerBlock(
                embed_size,
                heads,
                dropout=dropout,
                forward_expansion=forward_expansion
            )
            for _ in range(num_layers)
        ])
        self.fc = nn.Linear(embed_size, num_classes)

    def forward(self, x):
        x = x.view(x.shape[0], -1, x.shape[-1]) # Flatten the image
        x = self.patch_embedding(x)
        for layer in self.layers:
            x = layer(x, x, x)
        out = self.fc(x.mean(dim=1))
        return out


In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ViT(
    embed_size=256, 
    num_layers=6, 
    heads=8, 
    num_classes=10, 
    dropout=0.1, 
    forward_expansion=4
)

model.to(device)


ViT(
  (patch_embedding): PatchEmbedding(
    (projection): Linear(in_features=1, out_features=256, bias=True)
  )
  (layers): ModuleList(
    (0-5): 6 x TransformerBlock(
      (attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
      )
      (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (feed_forward): Sequential(
        (0): Linear(in_features=256, out_features=1024, bias=True)
        (1): ReLU()
        (2): Linear(in_features=1024, out_features=256, bias=True)
      )
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (fc): Linear(in_features=256, out_features=10, bias=True)
)

In [7]:
def train(model, dataloader, loss_fn, optimizer, device):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for batch in dataloader:
        images, labels = batch
        images, labels = images.to(device), labels.to(device)

        # Flatten the images for the patch embedding
        images = images.view(images.size(0), -1, images.size(-1))

        # Forward pass
        outputs = model(images)
        loss = loss_fn(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    avg_loss = total_loss / len(dataloader)
    accuracy = 100 * correct / total
    return avg_loss, accuracy


In [9]:
def test(model, dataloader, loss_fn, device):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for batch in dataloader:
            images, labels = batch
            images, labels = images.to(device), labels.to(device)

            # Flatten the images for the patch embedding
            images = images.view(images.size(0), -1, images.size(-1))

            # Forward pass
            outputs = model(images)
            loss = loss_fn(outputs, labels)

            total_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    avg_loss = total_loss / len(dataloader)
    accuracy = 100 * correct / total
    return avg_loss, accuracy


In [8]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

num_epochs = 10  # Feel free to adjust this

for epoch in range(num_epochs):
    train_loss, train_accuracy = train(model, train_loader, criterion, optimizer, device)
    test_loss, test_accuracy = test(model, test_loader, criterion, device)

    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%")
    print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%\n")


RuntimeError: mat1 and mat2 shapes cannot be multiplied (1792x28 and 1x256)