In [51]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets



In [52]:
# Implementing self attention layer for CNN
class SelfAttention(nn.Module):
    def __init__(self, in_dim, out_dim=8):
        super(SelfAttention, self).__init__()
        self.query = nn.Conv2d(in_channels=in_dim, out_channels=out_dim, kernel_size=1)
        self.key = nn.Conv2d(in_channels=in_dim, out_channels=out_dim, kernel_size=1)
        self.value = nn.Conv2d(in_channels=in_dim, out_channels=out_dim, kernel_size=1)

    def forward(self, x):
        # x: [b, c, h, w]
        # [b, c', h, w]
        proj_query = self.query(x)
        proj_key = self.key(x).permute(0, 2, 3, 1)
        energy = torch.matmul(proj_query, proj_key)
        attention = torch.softmax(energy, dim=-1)
        proj_value = self.value(x)
        out = torch.matmul(attention, proj_value).permute(0, 3, 1, 2)
        return out, attention
    
    


In [53]:
# CNN model for CIFAR10 dataset with self attention layer
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        self.conv4 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(512)
        self.conv5 = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, padding=1)
        self.bn5 = nn.BatchNorm2d(1024)
        # Last layer as a self attention layer
        self.self_attention = SelfAttention(in_dim=1024, out_dim=8)
        self.fc1 = nn.Linear(1024, 10)
        self.dropout = nn.Dropout(0.5)
        self.pool = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.pool(x)
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.pool(x)
        x = self.relu(self.bn3(self.conv3(x)))
        x = self.pool(x)
        x = self.relu(self.bn4(self.conv4(x)))
        x = self.pool(x)
        x = self.relu(self.bn5(self.conv5(x)))
        x = self.pool(x)
        # Self attention layer
        x, attention = self.self_attention(x)
        x = x.view(x.size(0), -1)
        x = self.dropout(x)
        x = self.fc1(x)
        return x, attention
    
    def loss_function(self, output, target):
        loss = nn.CrossEntropyLoss()
        return loss(output, target)
    
    def accuracy(self, output, target):
        _, predicted = torch.max(output.data, 1)
        correct = (predicted == target).sum().item()
        return correct / target.size(0)
    
    

In [54]:
# Loading the CIFAR-10 dataset
transform = transforms.Compose([transforms.Resize((32, 32)), transforms.ToTensor()])
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4)

Files already downloaded and verified
Files already downloaded and verified


In [55]:
# Initializing the model
model = CNN()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [56]:
# Training the model
def train(model, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        print(f"Epoch: {epoch}, Batch: {batch_idx}")
        optimizer.zero_grad()
        output, _ = model(data)
        loss = model.loss_function(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.2f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            
# Testing the model
def test(model, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output, _ = model(data)
            test_loss += model.loss_function(output, target).item()
            correct += model.accuracy(output, target)
    test_loss /= len(test_loader.dataset)
    correct /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {:.4f}\n'.format(test_loss, correct))

# Training and testing the model for 10 epochs
for epoch in range(1, 11):
    train(model, train_loader, optimizer, epoch)
    test(model, test_loader)

    

Epoch: 1, Batch: 0


RuntimeError: Expected size for first two dimensions of batch2 tensor to be: [512, 8] but got: [512, 1].