In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F

# Define the Hybrid Model (from the previous step)
class CNNBlock(nn.Module):
    """Convolution -> batch norm -> ReLU -> max-pool feature block.

    Args:
        in_channels: Number of channels in the input feature map.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Convolution kernel size, either an int (square kernel)
            or a tuple with one entry per spatial dimension.
        pool_size: Kernel size of the max-pool layer. ``nn.MaxPool2d`` uses
            the same value as its stride by default, so the spatial size is
            divided by ``pool_size``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, pool_size=2):
        super().__init__()
        # Derive the padding from the kernel instead of hard-coding 1, so that
        # odd kernels other than 3 also leave the spatial size unchanged before
        # pooling. For the default kernel_size=3 this is still padding=1.
        if isinstance(kernel_size, int):
            padding = kernel_size // 2
        else:
            padding = tuple(k // 2 for k in kernel_size)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding)
        self.bn = nn.BatchNorm2d(out_channels)
        self.pool = nn.MaxPool2d(pool_size)

    def forward(self, x):
        """Apply conv, batch norm, ReLU and max-pool to ``x`` and return the result."""
        x = self.conv(x)
        x = self.bn(x)
        x = F.relu(x)
        x = self.pool(x)
        return x

class TransformerBlock(nn.Module):
    """Post-norm transformer encoder block: self-attention, then a feed-forward MLP.

    The output of each sub-layer goes through dropout, is added back to the
    sub-layer's input (residual connection) and is then layer-normalised.

    ``nn.MultiheadAttention`` is created with its default layout, so inputs
    are expected as ``(seq_len, batch, embed_dim)``.

    Args:
        embed_dim: Size of the token embeddings (input and output width).
        num_heads: Number of attention heads.
        ff_hidden_dim: Hidden width of the feed-forward MLP.
        dropout: Dropout probability used inside the attention layer and on
            both residual branches.
    """

    def __init__(self, embed_dim, num_heads, ff_hidden_dim, dropout=0.1):
        super().__init__()
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
        self.norm1 = nn.LayerNorm(embed_dim)
        feed_forward_layers = [
            nn.Linear(embed_dim, ff_hidden_dim),
            nn.ReLU(),
            nn.Linear(ff_hidden_dim, embed_dim),
        ]
        self.ff = nn.Sequential(*feed_forward_layers)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Run one encoder block over ``x`` and return a tensor of the same shape."""
        # Self-attention: x serves as query, key and value. The attention
        # weights (second element of the returned tuple) are not needed.
        attended = self.attention(x, x, x)[0]
        x = self.norm1(x + self.dropout(attended))

        # Feed-forward sub-layer with its own residual connection and norm.
        return self.norm2(x + self.dropout(self.ff(x)))

class MNISTHybridModel(nn.Module):
    """CNN feature extractor followed by transformer blocks and a linear classifier.

    The CNN output is flattened and projected to one ``embed_dim`` vector per
    image, which is passed through the transformer blocks as a sequence of
    length one.

    NOTE(review): with a single token per image, self-attention has only one
    key to attend to, so the blocks cannot mix information across spatial
    positions. Confirm this is intended.

    Args:
        num_classes: Number of output classes (size of the logit vector).
        embed_dim: Width of the embedding fed to the transformer blocks.
        num_heads: Number of attention heads in each transformer block.
        ff_hidden_dim: Hidden width of each block's feed-forward MLP.
        num_transformer_layers: How many transformer blocks to stack.
    """

    def __init__(self, num_classes=10, embed_dim=64, num_heads=2, ff_hidden_dim=128, num_transformer_layers=1):
        super().__init__()

        # Two conv/pool stages; each halves the spatial size (28x28 -> 14x14 -> 7x7).
        self.cnn = nn.Sequential(CNNBlock(1, 32), CNNBlock(32, 64))

        # Flatten the 64-channel 7x7 feature map and project it to embed_dim.
        self.flatten = nn.Flatten()
        cnn_feature_count = 64 * 7 * 7
        self.fc = nn.Linear(cnn_feature_count, embed_dim)

        # Stack of transformer blocks, applied in order.
        self.transformer_layers = nn.ModuleList()
        for _ in range(num_transformer_layers):
            self.transformer_layers.append(TransformerBlock(embed_dim, num_heads, ff_hidden_dim))

        # Maps the final embedding to class logits.
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch_size, num_classes)`` for a batch of images."""
        # CNN features -> flat vector -> embedding of shape (batch_size, embed_dim).
        embedding = self.fc(self.flatten(self.cnn(x)))

        # Add a leading sequence dimension of size 1: (1, batch_size, embed_dim).
        tokens = embedding.unsqueeze(0)
        for layer in self.transformer_layers:
            tokens = layer(tokens)

        # Drop the sequence dimension again before classifying.
        return self.classifier(tokens.squeeze(0))

# Training hyperparameters
batch_size = 64  # samples per mini-batch for both data loaders
learning_rate = 1e-3  # step size passed to the Adam optimizer
num_epochs = 5  # number of full passes over the training set

# MNIST datasets and loaders.
# ToTensor yields values in [0, 1]; Normalize with mean 0.5 and std 0.5 then
# rescales them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# Build the train split first, then the test split; both share the same
# transform and are downloaded to ./data when missing.
train_dataset, test_dataset = (
    torchvision.datasets.MNIST(root='./data', train=is_train, transform=transform, download=True)
    for is_train in (True, False)
)

# Shuffle only the training data; keep the test order fixed.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Build the network, the classification loss and the optimizer
model = MNISTHybridModel(
    num_classes=10,
    embed_dim=64,
    num_heads=2,
    ff_hidden_dim=128,
    num_transformer_layers=1,
)
criterion = nn.CrossEntropyLoss()  # takes raw logits and integer class labels
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Train for num_epochs full passes over the training set
for epoch in range(num_epochs):
    model.train()  # enable dropout and batch-norm statistic updates
    running_loss = 0.0
    for images, labels in train_loader:
        # Forward pass: logits and loss for this mini-batch.
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Clear gradients left over from the previous step, back-propagate,
        # then update the weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate the scalar batch loss for the per-epoch average below.
        running_loss += loss.item()

    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(train_loader):.4f}')

# Evaluate classification accuracy on the test set
model.eval()  # disable dropout and use the running batch-norm statistics
correct = 0
total = 0
with torch.no_grad():  # no gradients are needed for evaluation
    for images, labels in test_loader:
        outputs = model(images)
        # Predicted class = index of the largest logit. argmax replaces
        # torch.max(outputs.data, 1): the legacy .data attribute is
        # unnecessary here, especially inside no_grad.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Report the number of images actually evaluated rather than a hard-coded count.
print(f'Test Accuracy of the model on the {total} test images: {100 * correct / total:.2f}%')

# Save the trained weights (the model's state_dict) to disk
checkpoint_path = 'mnist_hybrid_model.pth'
torch.save(model.state_dict(), checkpoint_path)


Epoch [1/5], Loss: 0.1525
Epoch [2/5], Loss: 0.0520
Epoch [3/5], Loss: 0.0360
Epoch [4/5], Loss: 0.0307
Epoch [5/5], Loss: 0.0258
Test Accuracy of the model on the 10000 test images: 98.86%


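You can combine convolutional neural networks (CNNs) with transformer layers. In such hybrid models, the CNN first extracts feature maps from the image, and multi-head attention is then applied to those feature maps. Note that the example above flattens the CNN feature maps and projects them to a single embedding vector per image, so its transformer layer operates on a sequence of length one.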