# Attention Basics!

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch import optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision
import torch.utils.data.dataloader as dataloader

import torch.nn.functional as F

from tqdm.notebook import trange, tqdm

In [None]:
# Root folder for the torchvision datasets. It is a relative path, so it is
# resolved against the notebook's current working directory.
data_set_root: str = "../../datasets"

In [None]:
# ToTensor gives [0, 1] floats; Normalize(0.5, 0.5) then maps them to [-1, 1]
# via (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

dataset = datasets.MNIST(data_set_root, train=True, download=True, transform=transform)

# Draw a random subset of training images and flatten each one into a single
# row, giving a (num_of_examples, 784) matrix.
num_of_examples = 100
rand_perm = torch.randperm(len(dataset))[:num_of_examples]
dataset_tensor = torch.stack([dataset[i][0].flatten() for i in rand_perm])

In [None]:
# Let's visualise the sampled images as a grid with 10 images per row
images = dataset_tensor.reshape(-1, 1, 28, 28)
out = torchvision.utils.make_grid(images, nrow=10, normalize=True, pad_value=0.5)
_ = plt.imshow(out.permute(1, 2, 0).numpy())

## Indexing a Dataset
By now you should be familiar with indexing a tensor/array with an integer index!

In [None]:
# Row of dataset_tensor that we want to retrieve
q_index: int = 10

In [None]:
# Let's look at the image stored at that index
plt.figure(figsize=(5, 5))
query_image = dataset_tensor[q_index].reshape(28, 28)
_ = plt.imshow(query_image.numpy(), cmap="gray")

## Indexing a Dataset with Matrix Multiplication
Did you know we can do the same thing using matrix multiplication?

In [None]:
# Encode the index as a one-hot row vector with one entry per example.
# We call this our "query" vector (q); shape (1, num_of_examples).
q_one_hot_vec = F.one_hot(torch.tensor([q_index]), num_classes=num_of_examples)

# Give every image its own unique one-hot code, one row per image.
# We call these our "key" vectors (k); shape (num_of_examples, num_of_examples).
k_one_hot = F.one_hot(torch.arange(num_of_examples), num_classes=num_of_examples)

# Shuffle the keys and the images with the SAME permutation, so that each image
# keeps its own key. This shows the look-up finds the target image no matter
# how the dataset is ordered.
rand_perm = torch.randperm(num_of_examples)
k_one_hot, dataset_tensor_random = k_one_hot[rand_perm], dataset_tensor[rand_perm]

In [None]:
# Dot the query with every key. The result is 1 at the position of the
# matching key and 0 elsewhere: a "hard" index map of shape (1, num_of_examples).
index_map = (q_one_hot_vec @ k_one_hot.T).float()

In [None]:
# Multiplying the index map with the shuffled data selects the matching row
output = index_map @ dataset_tensor_random

In [None]:
# Let's look at the image retrieved via matrix multiplication
plt.figure(figsize=(5, 5))
retrieved_image = output.reshape(28, 28)
_ = plt.imshow(retrieved_image.numpy(), cmap="gray")

## Attention as a "Soft" Look-up
What if we don't use "hard" one-hot coded vectors?

In [None]:
# Length of every query/key vector
vec_size = 512

# A single random query vector, shape (1, vec_size)
q_random_vec = torch.randn(1, vec_size)

# One random key vector per image, shape (num_of_examples, vec_size)
random_keys = torch.randn(num_of_examples, vec_size)

# Raw similarity scores between the query and every key: (1, num_of_examples).
# NOTE: no 1/sqrt(vec_size) scaling is applied here. With standard-normal
# entries, each score has variance vec_size, so the softmax below tends to be
# very peaked (close to a hard look-up).
scores = q_random_vec.mm(random_keys.t()).float()

# Normalise the scores across the examples so that the row sums to 1
attention_map = scores.softmax(dim=1)

# Use the attention weights to blend all images into one ("soft" indexing)
output = attention_map.mm(dataset_tensor)

In [None]:
# How "hard" is the soft look-up? Report the single largest attention weight.
largest_weight = attention_map.max().item()
print("The largest Softmax value is %f" % largest_weight)

In [None]:
# Let's look at the blended image produced by the soft look-up
plt.figure(figsize=(5, 5))
blended_image = output.reshape(28, 28)
_ = plt.imshow(blended_image.numpy(), cmap="gray")

## Multiple Queries
We can also perform multiple queries at the same time by stacking the query vectors into a matrix, one query per row.

In [None]:
# Length of every query/key vector
vec_size = 32

# How many queries to run at once
num_q = 8

# num_q random query vectors, shape (num_q, vec_size)
q_random_vec = torch.randn(num_q, vec_size)

# One random key vector per image, shape (num_of_examples, vec_size)
random_keys = torch.randn(num_of_examples, vec_size)

# Scores for every (query, key) pair, shape (num_q, num_of_examples)
scores = q_random_vec.mm(random_keys.t()).float()

# Softmax across the examples, independently for every query (each row sums to 1)
attention_map = scores.softmax(dim=-1)

# Each output row is a different weighted blend of all images: (num_q, 784)
output = attention_map.mm(dataset_tensor)

In [None]:
# Let's visualise the results of all num_q queries, 8 per row
plt.figure(figsize=(20, 10))
query_results = output.reshape(num_q, 1, 28, 28)
out = torchvision.utils.make_grid(query_results, nrow=8, normalize=True)
_ = plt.imshow(out.permute(1, 2, 0).numpy())

## Multi-Headed Attention
We can also perform Attention multiple times in parallel!

In [None]:
# Length of every query/key vector
vec_size = 32

# Number of queries per head
num_q = 8

# Number of independent attention "heads"
num_heads = 4

# Each head gets its own random queries: (num_heads, num_q, vec_size)
q_random_vec = torch.randn(num_heads, num_q, vec_size)

# ...and its own random key per image: (num_heads, num_of_examples, vec_size)
random_keys = torch.randn(num_heads, num_of_examples, vec_size)

# Batched scores, one (num_q, num_of_examples) matrix per head
scores = torch.bmm(q_random_vec, random_keys.transpose(1, 2)).float()

# Softmax across the examples for every query of every head
attention_map = scores.softmax(dim=2)

# Every head looks up the SAME images, so broadcast dataset_tensor along a new
# head dimension: (num_heads, num_of_examples, 784)
shared_values = dataset_tensor.unsqueeze(0).expand(num_heads, num_of_examples, -1)
output = torch.bmm(attention_map, shared_values)

In [None]:
# Lay out each head's num_q images side by side in a single row:
# (num_heads, 1, 28, num_q * 28)
out_reshape = (
    output.reshape(num_heads, num_q, 28, 28)
    .permute(0, 2, 1, 3)  # -> (num_heads, 28 rows, num_q images, 28 cols)
    .reshape(num_heads, 1, 28, num_q * 28)
)

In [None]:
# Let's visualise the results, one head per row
plt.figure(figsize=(20, 10))
out = torchvision.utils.make_grid(out_reshape, nrow=1, normalize=True, pad_value=0.5)
_ = plt.imshow(out.permute(1, 2, 0).numpy())

## Pytorch Multi-Head Attention
Of course, PyTorch has its own implementation of Multi-Head Attention!<br>

[Pytorch MultiheadAttention](https://pytorch.org/docs/2.1/generated/torch.nn.MultiheadAttention.html#torch.nn.MultiheadAttention)

We'll go into more detail on how to use it in later examples!

In [None]:
# Per-head vector size. nn.MultiheadAttention splits its embedding evenly
# across the heads.
vec_size = 32

# Number of attention heads
num_heads = 8

# Number of independent sequences processed together
batch_size = 32

# Total embedding size seen by the attention block (num_heads * vec_size)
model_dim = num_heads * vec_size

# A single query per batch element: (batch_size, 1, model_dim)
query = torch.randn(batch_size, 1, model_dim)

# A random key per image for every batch element:
# (batch_size, num_of_examples, model_dim)
key = torch.randn(batch_size, num_of_examples, model_dim)

# A random value per image for every batch element (same shape as the keys)
value = torch.randn(batch_size, num_of_examples, model_dim)

# batch_first=True: inputs and outputs are laid out as (batch, sequence, feature)
multihead_attn = nn.MultiheadAttention(model_dim, num_heads, batch_first=True)

In [None]:
# Run the query/key/value tensors through the attention block.
# average_attn_weights=False keeps a separate attention map for every head.
attn_output, attn_output_weights = multihead_attn(
    query=query,
    key=key,
    value=value,
    average_attn_weights=False,
)

In [None]:
# The forward pass returns two tensors:
#   * the softmaxed attention weights; with average_attn_weights=False there is
#     one map per head: (batch, num_heads, query length, key length)
#   * the attention block's output: (batch, query length, embedding size)
for label, result in (
    ("Softmax Attention Mask", attn_output_weights),
    ("Attention Output", attn_output),
):
    print(label, result.shape)

In [None]:
# Before attention, the query, key and value are each passed through their own
# learnable linear "projection". Because all three have the same size here,
# PyTorch stores the three weight matrices stacked in ONE parameter,
# in_proj_weight, and splits it into 3 during the forward pass.*
in_proj_shape = multihead_attn.in_proj_weight.shape
print("Projection weight size", in_proj_shape)

# *Digging into the implementation shows that PyTorch picks between several
# code paths to be as efficient as possible for the given use-case.

## Train a Multi-Head Attention Model

In [None]:
class AttentionTest(nn.Module):
    """Retrieve images from a set of "value" images using learned attention.

    Queries and values are first embedded by one shared MLP. A multi-head
    attention layer then scores every query embedding against every value
    embedding. The resulting attention weights are used to take a weighted
    average of the RAW value images, not of the attention layer's own output.

    Args:
        num_of_examples: Number of value images. Kept for backward
            compatibility; the model does not use it.
        embed_dim: Size of the shared embedding and of the attention layer.
            nn.MultiheadAttention requires it to be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        img_dim: Length of each flattened input image (784 = 28 * 28 for MNIST).
    """

    def __init__(self, num_of_examples=100, embed_dim=784, num_heads=4, img_dim=784):
        super().__init__()

        # Shared embedding network, applied to both the queries and the values
        self.img_mlp = nn.Sequential(
            nn.Linear(img_dim, embed_dim),
            nn.LayerNorm(embed_dim),
            nn.ELU(),
            nn.Linear(embed_dim, embed_dim),
        )

        self.mha = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            batch_first=True,
        )

    def forward(self, img, values):
        """Softly look up ``values`` using ``img`` as the queries.

        Args:
            img: Query images, shape (batch, num_queries, img_dim).
            values: Value images, shape (batch, num_values, img_dim).

        Returns:
            A tuple ``(output, attn_output_weights)``:

            * ``output`` has shape (batch, num_queries, img_dim). It is the
              attention-weighted average of ``values``.
            * ``attn_output_weights`` has shape (batch, num_queries, num_values),
              with the weights averaged over the heads.
        """
        img_ = self.img_mlp(img)
        values_ = self.img_mlp(values)

        # Only the attention weights are used; the attention layer's own
        # output is discarded.
        _, attn_output_weights = self.mha(img_, values_, values_)

        # Weighted average of the raw (un-embedded) value images
        output = torch.bmm(attn_output_weights, values)

        return output, attn_output_weights

In [None]:
# Training hyper-parameters
embed_dim = 256  # size of the learned embedding / attention dimension
num_heads = 1    # number of attention heads
batch_size = 64  # images per training batch

In [None]:
# Use the first GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# drop_last=True keeps every batch exactly batch_size long, which the
# values_tensor expanded to batch_size during training relies on
train_loader = dataloader.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    drop_last=True,
)

In [None]:
# Build the model on the chosen device
mha_model = AttentionTest(
    num_of_examples=num_of_examples,
    embed_dim=embed_dim,
    num_heads=num_heads,
).to(device)

optimizer = optim.Adam(mha_model.parameters(), lr=1e-4)

# Training loss for every step, appended by the training loop
loss_logger = []

# Every query in a batch searches the same value images, so broadcast
# dataset_tensor along a new batch dimension: (batch_size, num_of_examples, 784)
values_tensor = (
    dataset_tensor.unsqueeze(0)
    .expand(batch_size, num_of_examples, -1)
    .to(device)
)

In [None]:
num_epochs = 10

mha_model.train()
for _epoch in trange(num_epochs, leave=False):
    # The labels are not needed: the target is the query image itself.
    # NOTE: `data` and `q_img` from the final batch stay defined for later cells.
    for data, _labels in tqdm(train_loader, leave=False):
        # Flatten every image into a single query: (batch, 1, 784)
        q_img = data.reshape(data.shape[0], 1, -1).to(device)

        attn_output, attn_output_weights = mha_model(q_img, values_tensor)

        # Mean squared error between the retrieved image and the query image
        loss = (attn_output - q_img).pow(2).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_logger.append(loss.item())


In [None]:
# Plot the training loss, skipping the first 100 steps
_ = plt.plot(loss_logger[100:])
lowest_loss = np.min(loss_logger)
print("Minimum MSE loss %.4f" % lowest_loss)

In [None]:
# Re-run the trained model on the last training batch without tracking gradients
mha_model.eval()
with torch.no_grad():
    # `data` is the final batch left over from the training loop
    q_img = data.reshape(data.shape[0], 1, -1).to(device)
    attn_output, attn_output_weights = mha_model(q_img, values_tensor)

In [None]:
# For one query, rank all value images by their attention weight and keep the
# 10 highest-weighted ("closest") matches
index = 10
ranked = attn_output_weights[index, 0].argsort(descending=True)
top10 = ranked[:10]
top10_data = values_tensor[index, top10].cpu()

In [None]:
# Show the query image on its own
plt.figure(figsize=(3, 3))
query_img = q_img[index].cpu().reshape(-1, 1, 28, 28)
out = torchvision.utils.make_grid(query_img, nrow=8, normalize=True, pad_value=0.5)
_ = plt.imshow(out.permute(1, 2, 0).numpy())

In [None]:
# Let's visualise the 10 closest matches, ordered by attention weight
plt.figure(figsize=(10, 10))
match_imgs = top10_data.reshape(-1, 1, 28, 28)
out = torchvision.utils.make_grid(match_imgs, nrow=10, normalize=True, pad_value=0.5)
_ = plt.imshow(out.permute(1, 2, 0).numpy())

In [None]:
# Attention weight assigned to every value image for the chosen query
weights = attn_output_weights[index, 0].cpu().numpy()
_ = plt.plot(weights.flatten())

In [None]:
# Turn the query images and the retrieved images back into 28x28 images
target_img = q_img.reshape(batch_size, 1, 28, 28)
indexed_img = attn_output.reshape(batch_size, 1, 28, 28)

# Concatenate along the height axis: retrieved image on top, target below
img_pair = torch.cat([indexed_img, target_img], dim=2).cpu()

In [None]:
# Let's visualise the pairs: the retrieved image on top, the target below
plt.figure(figsize=(10, 10))
out = torchvision.utils.make_grid(img_pair, nrow=8, normalize=True, pad_value=0.5)
_ = plt.imshow(out.permute(1, 2, 0).numpy())