Self-Attention Convolutional Block

In [1]:
import torch
import torch.nn as nn

class SelfAttentionConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, heads=8):
        super(SelfAttentionConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.bn = nn.BatchNorm2d(out_channels)
        self.attention = SelfAttention(out_channels, heads)
        self.relu = nn.ReLU()

    def forward(self, x):
        conv_out = self.conv(x)
        conv_out = self.bn(conv_out)
        attn_out = self.attention(conv_out)
        return self.relu(attn_out)

class SelfAttention(nn.Module):
    def __init__(self, channels, heads=8):
        super(SelfAttention, self).__init__()
        self.heads = heads
        self.scale = channels ** -0.5

        self.query = nn.Conv2d(channels, channels, 1)
        self.key = nn.Conv2d(channels, channels, 1)
        self.value = nn.Conv2d(channels, channels, 1)

        self.out_conv = nn.Conv2d(channels, channels, 1)

    def forward(self, x):
        b, c, h, w = x.size()
        q = self.query(x).view(b, self.heads, c // self.heads, h * w)
        k = self.key(x).view(b, self.heads, c // self.heads, h * w)
        v = self.value(x).view(b, self.heads, c // self.heads, h * w)

        attention = torch.einsum('bhcn,bhck->bhnk', q, k) * self.scale
        attention = torch.softmax(attention, dim=-1)

        out = torch.einsum('bhnk,bhcv->bhcv', attention, v)
        out = out.view(b, c, h, w)

        out = self.out_conv(out)
        return out

# Example usage
if __name__ == "__main__":
    model = SelfAttentionConvBlock(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1, heads=8)
    input_tensor = torch.randn(16, 3, 32, 32)  # Example input (batch_size, channels, height, width)
    output = model(input_tensor)
    print(output.shape)  # Expected output shape: (16, 64, 32, 32)

torch.Size([16, 64, 32, 32])


## Break down the code

In [4]:
import torch
import torch.nn as nn

class SelfAttentionConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, heads=8):
        super(SelfAttentionConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        conv_out = self.conv(x)
        conv_out = self.bn(conv_out)

        return conv_out

model = SelfAttentionConvBlock(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1, heads=8)
input_tensor = torch.randn(16, 3, 32, 32)  # Example input (batch_size, channels, height, width)
output = model(input_tensor)

print(output.size())

torch.Size([16, 64, 32, 32])


In [6]:
A = nn.Conv2d(32*2, 32*2, 1)

A(torch.randn(1, 64, 32, 32)).size()

torch.Size([1, 64, 32, 32])

In [10]:
B = A(torch.randn(1, 64, 32, 32))
B.view(B.size(0), B.size(1), -1).size()

torch.Size([1, 64, 1024])

In [12]:
(1, 8, 64//8, 1024)

(1, 8, 8, 1024)

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

class SelfAttention(nn.Module):
    def __init__(self, channels, heads=8):
        super(SelfAttention, self).__init__()
        self.heads = heads
        self.scale = channels ** -0.5

        self.query = nn.Conv2d(channels, channels, 1)
        self.key = nn.Conv2d(channels, channels, 1)
        self.value = nn.Conv2d(channels, channels, 1)

        self.out_conv = nn.Conv2d(channels, channels, 1)

    def forward(self, x):
        b, c, h, w = x.size()
        q = self.query(x).view(b, self.heads, c // self.heads, h * w)
        k = self.key(x).view(b, self.heads, c // self.heads, h * w)
        v = self.value(x).view(b, self.heads, c // self.heads, h * w)

        attention = torch.einsum('bhcn,bhck->bhnk', q, k) * self.scale
        attention = torch.softmax(attention, dim=-1)

        out = torch.einsum('bhnk,bhcv->bhcv', attention, v)
        out = out.view(b, c, h, w)

        out = self.out_conv(out)
        return out

class SelfAttentionConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, heads=8):
        super(SelfAttentionConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.bn = nn.BatchNorm2d(out_channels)
        self.attention = SelfAttention(out_channels, heads)
        self.relu = nn.ReLU()

    def forward(self, x):
        conv_out = self.conv(x)
        conv_out = self.bn(conv_out)
        attn_out = self.attention(conv_out)
        return self.relu(attn_out)

class MNISTSelfAttentionModel(nn.Module):
    def __init__(self):
        super(MNISTSelfAttentionModel, self).__init__()

        self.layer1 = SelfAttentionConvBlock(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1, heads=8)
        self.layer2 = SelfAttentionConvBlock(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1, heads=8)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)
        self.pool = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.layer1(x)
        x = self.pool(x)  # Reduces size to 14x14
        x = self.layer2(x)
        x = self.pool(x)  # Reduces size to 7x7
        x = x.view(-1, 64 * 7 * 7)  # Flatten the tensor
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Check for GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Define Hyperparameters
batch_size = 64
learning_rate = 0.001
num_epochs = 5

# Load MNIST Data
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=transform, download=True)

train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# Initialize the model, loss function, and optimizer
model = MNISTSelfAttentionModel().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training Loop
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(train_loader):.4f}')

# Testing the model
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Test Accuracy of the model on the 10000 test images: {100 * correct / total:.2f}%')

# Save the model checkpoint
torch.save(model.state_dict(), 'mnist_self_attention_model.pth')

Using device: cuda
Epoch [1/5], Loss: 0.9409
Epoch [2/5], Loss: 0.0818
Epoch [3/5], Loss: 0.0656
Epoch [4/5], Loss: 0.0603
Epoch [5/5], Loss: 0.0569
Test Accuracy of the model on the 10000 test images: 98.37%
