In [None]:
# import os
# os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
# os.environ['TORCH_USE_CUDA_DSA'] = "1"

In [9]:
import torch
from torch import nn
import torch.nn.functional as F
from typing import Callable
import matplotlib.pyplot as plt

In [10]:
# Load a single training image as a quick check that the dataset is on disk.
# Forward slashes are used (matching the ImageFolder paths further down) so the
# path resolves on Linux/macOS as well as on Windows.
image_path = "tiny-imagenet-200/train/n01443537/images/n01443537_0.JPEG"
image = torch.tensor(plt.imread(image_path))
# The raw image is channel-last; the recorded output is torch.Size([64, 64, 3]).
print(image.shape)

torch.Size([64, 64, 3])


### TODO: Multi-headed self-attention.

In [11]:
# Input shape: [batch_size, channels, height, width]
class SelfAttentionLayer(nn.Module):
    """Multi-head attention over a 2-D feature map with a gated residual.

    Three 1x1 convolutions project the input to query/key/value maps with
    ``out_planes`` channels. Each map is split into ``head`` groups of
    ``out_planes // head`` channels and flattened over the spatial axes, so
    ``similarity_fun`` receives tensors of shape
    ``(batch * head, out_planes // head, height * width)``.

    The scores returned by ``similarity_fun`` are soft-maxed over their last
    axis and multiplied onto the values with ``torch.bmm(a, v)``; for the
    result to fold back into ``(batch, out_planes, height, width)`` the score
    tensor must have shape ``(batch * head, out_dim, out_dim)``. The attention
    therefore re-weights the channels inside each head.

    A final 1x1 convolution maps back to ``in_planes`` channels and the result
    is added to the input scaled by the learnable ``gamma``. ``gamma`` starts
    at zero, so a freshly constructed layer returns its input unchanged.

    Args:
        in_planes: Number of channels of the input (and of the output).
        out_planes: Total number of channels of the q/k/v projections.
        head: Number of attention heads; must divide ``out_planes``.
        similarity_fun: Callable ``(q, k) -> scores`` as described above.

    Raises:
        ValueError: If ``head`` is not positive or does not divide
            ``out_planes``.
    """

    def __init__(self, in_planes, out_planes, head, similarity_fun: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]) -> None:
        super().__init__()
        # Fail early with a clear message instead of a shape error in forward().
        if head <= 0 or out_planes % head != 0:
            raise ValueError(
                f"out_planes ({out_planes}) must be divisible by a positive head count ({head})"
            )
        # NOTE: the attribute name `qurey` (sic) is kept so existing
        # state_dict checkpoints keep loading.
        self.qurey = nn.Conv2d(in_planes, out_planes, kernel_size=1)
        self.key = nn.Conv2d(in_planes, out_planes, kernel_size=1)
        self.value = nn.Conv2d(in_planes, out_planes, kernel_size=1)
        self.similarity_fun = similarity_fun
        self.output = nn.Conv2d(out_planes, in_planes, kernel_size=1)
        # Residual gate, initialised to 0 so the layer starts as an identity.
        self.gamma = nn.Parameter(torch.tensor([0.]))
        self.head = head
        self.out_dim = out_planes // self.head

    def forward(self, x):
        """Apply attention to ``x`` and return a tensor of the same shape."""
        batch_size, _, height, width = x.shape
        per_head = (batch_size * self.head, self.out_dim, height * width)

        # reshape (not view) so non-contiguous conv outputs, e.g. channels_last,
        # are handled instead of raising.
        q = self.qurey(x).reshape(per_head)
        k = self.key(x).reshape(per_head)
        v = self.value(x).reshape(per_head)

        a = F.softmax(self.similarity_fun(q, k), dim=-1)

        # Merge the heads back into the channel axis and restore the spatial grid.
        mixed = torch.bmm(a, v).reshape(batch_size, self.out_dim * self.head, height, width)
        o = self.output(mixed)

        return self.gamma * o + x


In [12]:
def similarity_fun(Q: torch.tensor, K: torch.tensor):
    """Scaled batched dot-product similarity between ``Q`` and ``K``.

    Both arguments are 3-D batches. ``K`` has its last two axes swapped and is
    batch-multiplied onto ``Q``; the product is divided by the square root of
    the size of ``K``'s last axis (the axis the product contracts over).
    """
    scale = torch.sqrt(torch.tensor(K.shape[-1]))
    scores = torch.bmm(Q, K.transpose(1, 2))
    return scores / scale

In [None]:
# Smoke test: push a random batch through a 4-head attention layer.
x = torch.rand(2, 3, 64, 64)
selfModel = SelfAttentionLayer(3, 8, 4, similarity_fun)
# Call the module itself rather than .forward() so nn.Module hooks are honoured.
y = selfModel(x)
# The layer projects back to in_planes and adds the input, so shapes must match.
assert y.shape == x.shape, f"unexpected output shape {tuple(y.shape)}"

In [None]:
class CNNAttentionModel(nn.Module):
    """Small classifier that interleaves attention layers with max-pooling.

    Pipeline: 3x3 conv + batch norm (3 -> 16 channels), four
    ``SelfAttentionLayer`` stages working on 16 channels, a 3x3 conv + batch
    norm (16 -> 64 channels), global average pooling and a linear head that
    emits 200 logits. Every stage is followed by ReLU and a 2x2 max-pool, so
    the six poolings shrink each spatial side by a factor of 64.
    """

    def __init__(self):
        super().__init__()
        # Stem
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(16)

        # Shared (parameter-free) down-sampling op
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Attention stages: 16 channels in/out, 32 projection channels, 4 heads
        self.convSelf1 = SelfAttentionLayer(16, 32, 4, similarity_fun)
        self.convSelf2 = SelfAttentionLayer(16, 32, 4, similarity_fun)
        self.convSelf3 = SelfAttentionLayer(16, 32, 4, similarity_fun)
        self.convSelf4 = SelfAttentionLayer(16, 32, 4, similarity_fun)

        # Channel expansion before the head
        self.conv5 = nn.Conv2d(in_channels=16, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.bn5 = nn.BatchNorm2d(64)

        # Head
        self.gpa = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(64, 200)

    def forward(self, x):
        """Return the 200-way logits for an image batch ``x``."""
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        for attention in (self.convSelf1, self.convSelf2, self.convSelf3, self.convSelf4):
            x = self.pool(F.relu(attention(x)))
        x = self.pool(F.relu(self.bn5(self.conv5(x))))
        pooled = self.gpa(x)
        features = pooled.view(-1, 64)
        return self.fc1(features)


In [None]:
class CNN(nn.Module):
    """Minimal baseline: conv -> attention -> conv -> global pool -> linear.

    Each of the three stages is followed by ReLU and a 2x2 max-pool. The head
    averages the 64-channel map down to one vector per sample and maps it to
    200 logits.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # 16 channels in/out, 32 projection channels split over 4 heads
        self.convSelf1 = SelfAttentionLayer(16, 32, 4, similarity_fun)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=64, kernel_size=3, stride=1, padding=1)

        self.gpa = nn.AdaptiveAvgPool2d((1, 1))
        self.finalFC = nn.Linear(64, 200)

    def forward(self, x):
        """Return the 200-way logits for an image batch ``x``."""
        for stage in (self.conv1, self.convSelf1, self.conv2):
            x = self.pool(F.relu(stage(x)))
        features = self.gpa(x).view(-1, 64)
        return self.finalFC(features)

In [13]:
class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers wrapped by a skip connection.

    The first convolution carries the ``stride`` and the channel change. When
    the block alters the resolution or the channel count, the skip path is a
    strided 1x1 convolution followed by batch norm; otherwise it is an empty
    ``nn.Sequential`` that passes the input through untouched.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        shape_preserved = stride == 1 and in_channels == out_channels
        if shape_preserved:
            # Identity skip path
            self.shortcut = nn.Sequential()
        else:
            # Projection skip path so both branches have matching shapes
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        skip = self.shortcut(x)
        branch = F.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        branch += skip
        return F.relu(branch)
    
class ModifiedResNetWithAttention(nn.Module):
    """Three-stage residual network with attention after stages one and two.

    Layout: 3x3 conv + batch norm + ReLU stem (3 -> 16 channels), then
    ``ResidualBlock`` stages 16 -> 32 -> 64 -> 128 (each with stride 2), with a
    ``SelfAttentionLayer`` after the first and second stage, followed by
    global average pooling and a linear classifier.

    Args:
        num_classes: Number of logits produced by the classifier. Defaults to
            200 to match the tiny-imagenet-200 data this notebook trains on
            and the 200-way heads of the other models defined here; the
            previous hard-coded value of 10 makes ``CrossEntropyLoss`` reject
            every label >= 10.
    """

    def __init__(self, num_classes=200):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = ResidualBlock(16, 32, stride=2)
        # Attention on the 32-channel map (64 projection channels, 4 heads)
        self.att1 = SelfAttentionLayer(32, 64, 4, similarity_fun)
        self.layer2 = ResidualBlock(32, 64, stride=2)
        # Attention on the 64-channel map (128 projection channels, 4 heads)
        self.att2 = SelfAttentionLayer(64, 128, 4, similarity_fun)
        self.layer3 = ResidualBlock(64, 128, stride=2)
        self.gpa = nn.AdaptiveAvgPool2d((1, 1))
        self.finalFC = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return ``(batch, num_classes)`` logits for an image batch ``x``."""
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        x = self.att1(x)  # attention after the first residual stage
        x = self.layer2(x)
        x = self.att2(x)  # attention after the second residual stage
        x = self.layer3(x)
        x = self.gpa(x)
        # Collapse the 1x1 spatial grid: (batch, 128, 1, 1) -> (batch, 128)
        x = x.view(x.size(0), -1)
        return self.finalFC(x)

## Data Set Loading

In [14]:
from torchvision import transforms, datasets
from torch.utils.data import DataLoader

# Pre-processing: convert to a tensor, then normalise each channel with the
# ImageNet mean / std.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Folder-per-class datasets for the two splits.
# NOTE(review): ImageFolder derives labels from the sub-directory names; this
# presumably requires `val` to be organised into one folder per class like
# `train` is — confirm against the files on disk.
train_dataset = datasets.ImageFolder(root='tiny-imagenet-200/train', transform=transform)
val_dataset = datasets.ImageFolder(root='tiny-imagenet-200/val', transform=transform)

num_samples = len(train_dataset)
print(f"Total number of samples in the dataset: {num_samples}")

# Batch iterators; only the training split is shuffled.
train_loader = DataLoader(dataset=train_dataset, batch_size=128, num_workers=4, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=128, num_workers=4, shuffle=False)


##  Set Up the Model, Loss Function, and Optimizer

In [22]:
from torch import optim

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

# Alternative architectures defined in this notebook:
#   CNNAttentionModel().to(device)
#   CNN().to(device)
model = ModifiedResNetWithAttention()
model = model.to(device)

# Cross-entropy on the raw logits, optimised with Adam.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=0.001)

cuda


## Training Loop

In [23]:
num_epochs = 10  # full passes over the training set

for epoch in range(num_epochs):
    # Training mode: batch norm uses batch statistics and updates its running averages.
    model.train()
    running_loss = 0.0

    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        # Clear gradients left over from the previous step.
        optimizer.zero_grad()

        outputs = model(images)
        loss = criterion(outputs, labels)

        # Back-propagate, then update the weights.
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Report the mean of the per-batch losses for this epoch.
    mean_loss = running_loss / len(train_loader)
    print(f"Epoch {epoch + 1}, Loss: {mean_loss}")

    # Optional: evaluate on the validation split here (model.eval() inside
    # torch.no_grad()) to track generalisation after every epoch.


Please wait, filling up the shuffle buffer with samples.:  19%|█▉        | 375M/1.91G [00:42<02:58, 9.28MB/s]


Shuffle buffer filling is complete.


RuntimeError: 0D or 1D target tensor expected, multi-target not supported

## Validation Loop

In [None]:
# Evaluation mode: batch norm switches to its running statistics.
model.eval()
val_loss = 0.0
correct = 0
total = 0

# No gradients are needed for evaluation, which saves memory and time.
with torch.no_grad():
    for images, labels in val_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)
        val_loss += loss.item()
        # Predicted class = index of the largest logit. argmax replaces the
        # legacy `torch.max(outputs.data, 1)` idiom; `.data` is unnecessary
        # inside no_grad().
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Mean of the per-batch losses, and top-1 accuracy over all validation samples.
print(f"Validation Loss: {val_loss/len(val_loader)}")
print(f"Validation Accuracy: {100 * correct / total}%")
