In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np

# Pick the compute device once at import time; the training and evaluation
# code below moves the model and every batch onto it.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print(f"Using device: {device}")

class AttentionMessagePassing(nn.Module):
    """
    Self-attention "message passing" over the spatial positions of a feature map.

    Every spatial location (h, w) of the input is treated as a node (a 'tile').
    Each node broadcasts its value vector as a message; every receiving node
    aggregates all messages weighted by softmax-normalized scaled dot-product
    attention. The aggregated features are added back onto the input
    (residual connection), so the output has the same shape as the input.

    Args:
        in_channels: Number of channels C of the input feature map.
        reduction: Divisor applied to ``in_channels`` to obtain the query/key
            width (default 2, i.e. ``in_channels // 2``). The width is clamped
            to at least 1 so small ``in_channels`` (e.g. 1) still work.
            The value projection always keeps the full ``in_channels`` width
            so it can be added to the residual.

    Raises:
        ValueError: if ``reduction`` is smaller than 1.

    Input:  (batch_size, in_channels, height, width)
    Output: (batch_size, in_channels, height, width)
    """
    def __init__(self, in_channels, reduction=2):
        super(AttentionMessagePassing, self).__init__()
        if reduction < 1:
            raise ValueError(f"reduction must be >= 1, got {reduction}")
        self.in_channels = in_channels

        # Query/key width. Clamped to >= 1: ``in_channels // reduction`` can be
        # 0 (e.g. in_channels=1), which would create empty projections and a
        # 0/0 = NaN attention score after the 1/sqrt(d_k) scaling.
        self.qk_channels = max(1, in_channels // reduction)
        # 1/sqrt(d_k) factor of scaled dot-product attention, computed once.
        self.scale = self.qk_channels ** -0.5

        # 1x1 convolutions act as per-position linear projections.
        self.query_proj = nn.Conv2d(in_channels, self.qk_channels, kernel_size=1)
        self.key_proj = nn.Conv2d(in_channels, self.qk_channels, kernel_size=1)
        self.value_proj = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def forward(self, x):
        """
        x: feature map of shape (batch_size, in_channels, height, width).

        Returns a tensor of the same shape: attention-aggregated messages + x.
        """
        batch_size, channels, h, w = x.shape

        # Project, then flatten the spatial grid into N = h * w nodes.
        query = self.query_proj(x).flatten(2)  # B, C_qk, N
        key = self.key_proj(x).flatten(2)      # B, C_qk, N
        value = self.value_proj(x).flatten(2)  # B, C, N

        # scores[b, i, j]: attention from receiving node i to sending node j.
        scores = torch.bmm(query.transpose(1, 2), key) * self.scale  # B, N, N
        # Normalize over senders j so each node's incoming weights sum to 1.
        weights = F.softmax(scores, dim=-1)  # B, N, N

        # Message passing: new[b, :, i] = sum_j weights[b, i, j] * value[b, :, j].
        # Computing value @ weights^T keeps the (B, C, N) layout, so no
        # permute / non-contiguous view is needed to restore the grid.
        messages = torch.bmm(value, weights.transpose(1, 2))  # B, C, N

        # Back to (B, C, h, w), plus the residual connection.
        return messages.reshape(batch_size, channels, h, w) + x

class AttentionConvNet(nn.Module):
    """
    Image classifier: two conv/ReLU/max-pool stages, a spatial self-attention
    (message-passing) stage, then a two-layer MLP head producing 10 logits.

    Shape flow for a (B, 1, 28, 28) MNIST batch:
        conv_block      -> (B, 64, 7, 7)   (two 2x2 max-pools: 28 -> 14 -> 7)
        attention_block -> (B, 64, 7, 7)
        fc_block        -> (B, 10)
    The head's input size is fixed at 64 * 7 * 7, so inputs must be
    single-channel and pool down to a 7x7 grid (e.g. 28x28 MNIST digits).
    """
    def __init__(self):
        super().__init__()

        feature_layers = [
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),   # 28x28 -> 14x14 for MNIST
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),   # 14x14 -> 7x7
        ]
        self.conv_block = nn.Sequential(*feature_layers)

        # Lets every spatial position of the 64-channel map exchange information
        # with every other position.
        self.attention_block = AttentionMessagePassing(in_channels=64)

        flattened_size = 64 * 7 * 7
        head_layers = [
            nn.Flatten(),
            nn.Linear(flattened_size, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 10),
        ]
        self.fc_block = nn.Sequential(*head_layers)

    def forward(self, x):
        """Map a batch of images to unnormalized class scores (logits)."""
        features = self.conv_block(x)
        features = self.attention_block(features)
        return self.fc_block(features)

def _build_loaders(batch_size, data_root):
    """Return (train_loader, test_loader) for MNIST, normalizing pixels to [-1, 1]."""
    # ToTensor scales pixels to [0, 1]; Normalize(0.5, 0.5) maps that to [-1, 1].
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])

    train_dataset = datasets.MNIST(root=data_root, train=True, transform=transform, download=True)
    test_dataset = datasets.MNIST(root=data_root, train=False, transform=transform, download=True)

    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader


def _train_one_epoch(model, loader, criterion, optimizer, epoch, num_epochs, log_every=100):
    """Run one optimization pass over ``loader``, printing the mean loss every ``log_every`` steps."""
    model.train()
    running_loss = 0.0
    for step, (images, labels) in enumerate(loader, start=1):
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if step % log_every == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{step}/{len(loader)}], Loss: {running_loss/log_every:.4f}')
            running_loss = 0.0


def _evaluate(model, loader):
    """Return (correct, total) top-1 prediction counts of ``model`` over ``loader``."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct, total


def train_model(learning_rate=0.001, batch_size=64, num_epochs=15, data_root='./data'):
    """
    Train AttentionConvNet on MNIST and evaluate it on the test split.

    Args:
        learning_rate: Initial Adam learning rate; StepLR halves it every 5 epochs.
        batch_size: Mini-batch size for both the train and test loaders.
        num_epochs: Number of passes over the training set.
        data_root: Directory MNIST is downloaded to / loaded from.

    Returns:
        Test-set accuracy as a percentage in [0, 100].
    """
    # Loaders are built before the model, as in the original ordering, so
    # seeded runs consume the RNG in the same sequence.
    train_loader, test_loader = _build_loaders(batch_size, data_root)

    # Initialize the model, loss function, and optimizer
    model = AttentionConvNet().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Halve the learning rate every 5 epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

    print("Starting training...")
    for epoch in range(num_epochs):
        _train_one_epoch(model, train_loader, criterion, optimizer, epoch, num_epochs)
        scheduler.step()

    print("Training finished. Evaluating on test set...")

    correct, total = _evaluate(model, test_loader)
    accuracy = 100 * correct / total
    # Report the actual number of evaluated images rather than a hard-coded count.
    print(f'Accuracy on the {total} test images: {accuracy:.2f}%')
    return accuracy

if __name__ == "__main__":
    # Run the full training + evaluation pipeline when executed directly.
    train_model()


Using device: cpu
Starting training...
Epoch [1/15], Step [100/938], Loss: 0.8856
Epoch [1/15], Step [200/938], Loss: 0.2752
Epoch [1/15], Step [300/938], Loss: 0.1970
Epoch [1/15], Step [400/938], Loss: 0.1544
Epoch [1/15], Step [500/938], Loss: 0.1271
Epoch [1/15], Step [600/938], Loss: 0.1349
Epoch [1/15], Step [700/938], Loss: 0.1132
Epoch [1/15], Step [800/938], Loss: 0.1200
Epoch [1/15], Step [900/938], Loss: 0.0793
Epoch [2/15], Step [100/938], Loss: 0.0809
Epoch [2/15], Step [200/938], Loss: 0.0952
Epoch [2/15], Step [300/938], Loss: 0.0788
Epoch [2/15], Step [400/938], Loss: 0.0810
Epoch [2/15], Step [500/938], Loss: 0.0747
Epoch [2/15], Step [600/938], Loss: 0.0869
Epoch [2/15], Step [700/938], Loss: 0.0813
Epoch [2/15], Step [800/938], Loss: 0.0699
Epoch [2/15], Step [900/938], Loss: 0.0690
Epoch [3/15], Step [100/938], Loss: 0.0564
Epoch [3/15], Step [200/938], Loss: 0.0576
Epoch [3/15], Step [300/938], Loss: 0.0534
Epoch [3/15], Step [400/938], Loss: 0.0557
Epoch [3/15], S