In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import torch.nn.functional as F
from torch.utils.data import DataLoader

# Define a simple CNN architecture
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

# Preparing the MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False)

# Initialize the model, loss function, and optimizer
model = CNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training the model
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        if batch_idx % 100 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}'
                  f' ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')

# Testing the model
def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)}'
          f' ({100. * correct / len(test_loader.dataset):.0f}%)\n')

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Start training
for epoch in range(1, 10 + 1):  # 10 epochs
    train(model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)



Test set: Average loss: 0.0000, Accuracy: 9844/10000 (98%)


Test set: Average loss: 0.0000, Accuracy: 9871/10000 (99%)


Test set: Average loss: 0.0000, Accuracy: 9897/10000 (99%)


Test set: Average loss: 0.0000, Accuracy: 9907/10000 (99%)


Test set: Average loss: 0.0000, Accuracy: 9911/10000 (99%)


Test set: Average loss: 0.0000, Accuracy: 9919/10000 (99%)


Test set: Average loss: 0.0000, Accuracy: 9914/10000 (99%)


Test set: Average loss: 0.0000, Accuracy: 9938/10000 (99%)


Test set: Average loss: 0.0000, Accuracy: 9922/10000 (99%)


Test set: Average loss: 0.0000, Accuracy: 9934/10000 (99%)



In [10]:
!pip install einops
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import torch.nn.functional as F
from torch.utils.data import DataLoader
import einops

# Define the Self-Attention block
class SelfAttention(nn.Module):
    def __init__(self, embed_size, num_heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.num_heads = num_heads
        self.head_dim = embed_size // num_heads

        self.query = nn.Linear(embed_size, embed_size)
        self.key = nn.Linear(embed_size, embed_size)
        self.value = nn.Linear(embed_size, embed_size)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        batch_size, seq_length, embed_size = x.size()

        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        q = einops.rearrange(q, 'b n (h d) -> b h n d', h=self.num_heads)
        k = einops.rearrange(k, 'b n (h d) -> b h n d', h=self.num_heads)
        v = einops.rearrange(v, 'b n (h d) -> b h n d', h=self.num_heads)

        attn_scores = torch.einsum('bhqd,bhkd->bhqk', q, k) / self.head_dim**0.5
        attn_probs = self.softmax(attn_scores)

        attn_output = torch.einsum('bhqk,bhkd->bhqd', attn_probs, v)
        attn_output = einops.rearrange(attn_output, 'b h n d -> b n (h d)')

        return attn_output

# Define the Vision Transformer (ViT) architecture
class VisionTransformer(nn.Module):
    def __init__(self, patch_size=7, num_blocks=6, embed_size=64, num_heads=8, num_classes=10):
        super(VisionTransformer, self).__init__()
        self.patch_size = patch_size
        self.num_patches = (28 // patch_size) ** 2
        self.embed_size = embed_size

        self.patch_embedding = nn.Conv2d(1, embed_size, kernel_size=patch_size, stride=patch_size)

        self.transformer_blocks = nn.ModuleList(
            [nn.TransformerEncoderLayer(embed_size, num_heads) for _ in range(num_blocks)]
        )

        self.fc = nn.Linear(embed_size, num_classes)

    def forward(self, x):
        x = self.patch_embedding(x)
        x = x.flatten(2).transpose(1, 2)

        for block in self.transformer_blocks:
            x = block(x)

        x = x.mean(dim=1)
        x = self.fc(x)

        return x

# Preparing the MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False)

# Initialize the model, loss function, and optimizer
model = VisionTransformer()
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.001)

# Training the model
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}'
                  f' ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')

# Testing the model
def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)}'
          f' ({100. * correct / len(test_loader.dataset):.0f}%)\n')

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Start training
for epoch in range(1, 50 + 1):  # 10 epochs
    train(model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)


Test set: Average loss: 0.0020, Accuracy: 2393/10000 (24%)


Test set: Average loss: 0.0017, Accuracy: 3939/10000 (39%)


Test set: Average loss: 0.0012, Accuracy: 6491/10000 (65%)


Test set: Average loss: 0.0007, Accuracy: 8249/10000 (82%)


Test set: Average loss: 0.0004, Accuracy: 9081/10000 (91%)


Test set: Average loss: 0.0003, Accuracy: 9377/10000 (94%)


Test set: Average loss: 0.0002, Accuracy: 9484/10000 (95%)


Test set: Average loss: 0.0002, Accuracy: 9593/10000 (96%)


Test set: Average loss: 0.0002, Accuracy: 9615/10000 (96%)


Test set: Average loss: 0.0001, Accuracy: 9659/10000 (97%)


Test set: Average loss: 0.0001, Accuracy: 9684/10000 (97%)


Test set: Average loss: 0.0001, Accuracy: 9722/10000 (97%)


Test set: Average loss: 0.0001, Accuracy: 9737/10000 (97%)


Test set: Average loss: 0.0001, Accuracy: 9727/10000 (97%)


Test set: Average loss: 0.0001, Accuracy: 9757/10000 (98%)


Test set: Average loss: 0.0001, Accuracy: 9771/10000 (98%)


Test set: Average loss:

CNNs generally have stronger prior on vision tasks compared to ViTs, so it is efficient on esay task like MNIST. CNN converges faster than ViT in this example, and CNN outperforms a little bit (1% on MNIST test set) than ViT