# Assignment 4: Self-Attention for Vision

For this assignment, we're going to implement self-attention blocks in a convolutional neural network for CIFAR-10 classification.

# Part I. Preparation

First, we load the CIFAR-10 dataset. This might take a couple minutes the first time you do it, but the files should stay cached after that.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler

import torchvision.datasets as dset
import torchvision.transforms as T

import numpy as np

In [None]:
NUM_TRAIN = 49000

# The torchvision.transforms package provides tools for preprocessing data
# and for performing data augmentation; here we set up a transform to
# preprocess the data by subtracting the mean RGB value and dividing by the
# standard deviation of each RGB value; we've hardcoded the mean and std.
transform = T.Compose([
                T.ToTensor(),
                T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
            ])

# We set up a Dataset object for each split (train / val / test); Datasets load
# training examples one at a time, so we wrap each Dataset in a DataLoader which
# iterates through the Dataset and forms minibatches. We divide the CIFAR-10
# training set into train and val sets by passing a Sampler object to the
# DataLoader telling how it should sample from the underlying Dataset.
cifar10_train = dset.CIFAR10('./data/datasets', train=True, download=True,
                             transform=transform)
loader_train = DataLoader(cifar10_train, batch_size=64, 
                          sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN)))

cifar10_val = dset.CIFAR10('./data/datasets', train=True, download=True,
                           transform=transform)
loader_val = DataLoader(cifar10_val, batch_size=64, 
                        sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN, 50000)))

cifar10_test = dset.CIFAR10('./data/datasets', train=False, download=True, 
                            transform=transform)
loader_test = DataLoader(cifar10_test, batch_size=64)

You have the option to **use a GPU by setting the flag to True below**. A GPU is not necessary for this assignment. Note that if your computer does not have CUDA enabled, `torch.cuda.is_available()` will return False and this notebook will fall back to CPU mode.

The global variables `dtype` and `device` will control the data types throughout this assignment. 

In [None]:
USE_GPU = True

dtype = torch.float32 # we will be using float throughout this assignment

if USE_GPU and torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Constant to control how frequently we print train loss
print_every = 100

print('using device:', device)

### PyTorch Tensors: Flatten Function
A PyTorch Tensor is conceptually similar to a numpy array: it is an n-dimensional grid of numbers, and like numpy, PyTorch provides many functions to efficiently operate on Tensors. As a simple example, we provide a `flatten` function below which reshapes image data for use in a fully-connected neural network.

Recall that image data is typically stored in a Tensor of shape N x C x H x W, where:

* N is the number of datapoints
* C is the number of channels
* H is the height of the intermediate feature map in pixels
* W is the width of the intermediate feature map in pixels

This is the right way to represent the data when we are doing something like a 2D convolution, which needs a spatial understanding of where the intermediate features are relative to each other. When we use fully connected affine layers to process the image, however, we want each datapoint to be represented by a single vector -- it's no longer useful to segregate the different channels, rows, and columns of the data. So, we use a "flatten" operation to collapse the `C x H x W` values per datapoint into a single long vector. The flatten function below first reads the batch size N from a given batch of data, and then returns a "view" of that data. "View" is analogous to numpy's "reshape" method: it reshapes x's dimensions to be N x ??, where ?? is allowed to be anything (in this case, it will be C x H x W, but we don't need to specify that explicitly).
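As a quick illustration of the reshape described above, here is a minimal sketch on a made-up batch with CIFAR-10 image dimensions. The tensor `x` and its batch size of 64 are assumptions chosen for the example; the sketch also shows that a view shares storage with the tensor it was created from.

```python
import torch

# Hypothetical batch with CIFAR-10 image dimensions: N=64, C=3, H=32, W=32.
x = torch.zeros(64, 3, 32, 32)

# -1 asks PyTorch to infer the remaining dimension: 3 * 32 * 32 = 3072.
flat = x.view(x.shape[0], -1)
print(flat.shape)  # torch.Size([64, 3072])

# A view shares storage with the original tensor, so writing through
# the view is visible in x as well.
flat[0, 0] = 1.0
print(x[0, 0, 0, 0].item())  # 1.0
```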

In [None]:
def flatten(x):
    N = x.shape[0] # read in N, the batch size
    return x.view(N, -1)  # "flatten" the C * H * W values into a single vector per image

def test_flatten():
    x = torch.arange(12).view(2, 1, 3, 2)
    print('Before flattening: ', x)
    print('After flattening: ', flatten(x))

test_flatten()

### Check Accuracy Function
When training the model we will use the following function to check the accuracy of our model on the validation or test sets.

When checking accuracy we don't need to compute any gradients; as a result we don't need PyTorch to build a computational graph for us when we compute scores. To prevent a graph from being built we scope our computation under a `torch.no_grad()` context manager.
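The effect of `torch.no_grad()` can be seen on a toy example. The sketch below uses a tiny stand-in `nn.Linear` model and random inputs (both assumptions, not part of the assignment) and compares the `requires_grad` flag of the output inside and outside the context manager.

```python
import torch
import torch.nn as nn

# Tiny stand-in model (hypothetical): 4 input features, 3 classes.
model = nn.Linear(4, 3)
x = torch.randn(5, 4)

# Outside no_grad, the output is attached to a graph for backprop.
scores_train = model(x)

# Inside no_grad, no graph is recorded for the same computation.
with torch.no_grad():
    scores_eval = model(x)
    _, preds = scores_eval.max(1)  # index of the highest score per row

print(scores_train.requires_grad)  # True
print(scores_eval.requires_grad)   # False
print(preds.shape)                 # torch.Size([5])
```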

In [None]:
import torch.nn.functional as F  # useful stateless functions
def check_accuracy(loader, model):
    if loader.dataset.train:
        print('Checking accuracy on validation set')
    else:
        print('Checking accuracy on test set')
    num_correct = 0
    num_samples = 0
    model.eval()  # set model to evaluation mode
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device=device, dtype=dtype)  # move to device, e.g. GPU
            y = y.to(device=device, dtype=torch.long)
            scores = model(x)
            _, preds = scores.max(1)
            num_correct += (preds == y).sum()
            num_samples += preds.size(0)
        acc = float(num_correct) / num_samples
        print('Got %d / %d correct (%.2f)' % (num_correct, num_samples, 100 * acc))
        return 100 * acc

### Training Loop
Next we define the training loop. Rather than updating the values of the weights ourselves, we use an Optimizer object from the `torch.optim` package, which abstracts the notion of an optimization algorithm and provides implementations of most of the algorithms commonly used to optimize neural networks.
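To make the role of the Optimizer concrete, here is a minimal sketch of a single update step on a toy problem. The tiny linear model, the random inputs, and the random labels are all assumptions made for illustration; only the `zero_grad` / `backward` / `step` sequence mirrors the training loop in this notebook.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(0)
model = nn.Linear(4, 3)          # hypothetical tiny classifier
optimizer = optim.Adam(model.parameters(), lr=1e-3)

x = torch.randn(8, 4)            # 8 made-up examples
y = torch.randint(0, 3, (8,))    # made-up class labels

before = model.weight.detach().clone()

loss = F.cross_entropy(model(x), y)
optimizer.zero_grad()   # clear gradients left over from any earlier step
loss.backward()         # populate .grad on every parameter
optimizer.step()        # let the optimizer apply its update rule

changed = not torch.equal(before, model.weight.detach())
print(changed)
```

Swapping `optim.Adam` for another optimizer such as `optim.SGD` changes the update rule without touching the rest of the loop.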

In [None]:
def train(model, optimizer, epochs=1):
    """
    Train a model on CIFAR-10 using the PyTorch Module API.
    
    Inputs:
    - model: A PyTorch Module giving the model to train.
    - optimizer: An Optimizer object we will use to train the model
    - epochs: (Optional) A Python integer giving the number of epochs to train for
    
    Returns: Nothing, but prints model accuracies during training.
    """
    model = model.to(device=device)  # move the model parameters to CPU/GPU
    acc_max = 0
    for e in range(epochs):
        for t, (x, y) in enumerate(loader_train):
            model.train()  # put model to training mode
            x = x.to(device=device, dtype=dtype)  # move to device, e.g. GPU
            y = y.to(device=device, dtype=torch.long)

            scores = model(x)
            loss = F.cross_entropy(scores, y)

            # Zero out all of the gradients for the variables which the optimizer
            # will update.
            optimizer.zero_grad()

            # This is the backwards pass: compute the gradient of the loss with
            # respect to each parameter of the model.
            loss.backward()

            # Actually update the parameters of the model using the gradients
            # computed by the backwards pass.
            optimizer.step()

            if t % print_every == 0:
                print('Epoch %d, Iteration %d, loss = %.4f' % (e, t, loss.item()))
                acc = check_accuracy(loader_val, model)
                if acc >= acc_max:
                    acc_max = acc
                print()
    print("Maximum accuracy attained: ", acc_max)

In [None]:
# We need to wrap `flatten` function in a module in order to stack it
# in nn.Sequential
class Flatten(nn.Module):
    def forward(self, x):
        return flatten(x)

## Vanilla CNN; No Attention
We implement the vanilla architecture for you here. Do not modify the architecture. You will use the same architecture in the following parts. Do not modify the hyper-parameters.
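The input size of the final linear layer, `channel_2*32*32`, follows from the fact that a 3x3 convolution with `padding=1` and stride 1 preserves the spatial size. A minimal shape check on a made-up all-zeros image (an assumption for illustration) looks like this:

```python
import torch
import torch.nn as nn

# One made-up CIFAR-10-sized image: 3 channels, 32 x 32 pixels.
x = torch.zeros(1, 3, 32, 32)

conv1 = nn.Conv2d(3, 64, 3, padding=1, stride=1)
conv2 = nn.Conv2d(64, 32, 3, padding=1)

# Output size per spatial dim: (32 + 2*1 - 3) / 1 + 1 = 32.
h1 = conv1(x)
h2 = conv2(h1)
print(h1.shape)  # torch.Size([1, 64, 32, 32])
print(h2.shape)  # torch.Size([1, 32, 32, 32])
print(h2.view(1, -1).shape[1])  # 32768
```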

In [None]:
channel_1 = 64
channel_2 = 32
learning_rate = 1e-3

model = nn.Sequential(
    nn.Conv2d(3, channel_1, 3, padding=1, stride=1),
    nn.ReLU(),
    nn.Conv2d(channel_1, channel_2, 3, padding=1),
    nn.ReLU(),
    Flatten(),
    nn.Linear(channel_2*32*32, 10),
)

optimizer = optim.Adam(model.parameters(), lr=learning_rate)


train(model, optimizer, epochs=10)

## Test set -- run this only once

Now we test our model on the test set. Think about how this compares to your validation set accuracy.
You should see at least 55% accuracy.

In [None]:
vanillaModel = model
check_accuracy(loader_test, vanillaModel)


## Self-Attention

In the next section, you will implement an Attention layer, which you will then use within the convnet architecture defined above for the CIFAR-10 classification task.

A self-attention layer is formulated as follows:

Input: $X$ of shape $(H\times W, C)$

Query, key, value linear transforms are $W_Q$, $W_K$, $W_V$, of shape $(C, C)$. We implement these linear transforms as 1x1 convolutional layers of the same dimensions.

$XW_Q$, $XW_K$, $XW_V$, represent the output volumes when input X is passed through the transforms.


Self-Attention is given by the formula: $\mathrm{Attention}(X) = X + \mathrm{Softmax}\left(\frac{XW_Q(XW_K)^\top}{\sqrt{C}}\right)XW_V$
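To get a feel for the shapes in this formula, here is a minimal sketch that evaluates it with plain matrices for a single image. The sizes `H`, `W`, `C` and the random matrices standing in for $W_Q$, $W_K$, $W_V$ are assumptions for illustration. The sketch uses the $(H\times W, C)$ layout from the formula, not the $(N, C, H\times W)$ layout used in the module below, so the transposes fall in different places there.

```python
import math
import torch

torch.manual_seed(0)
H, W, C = 4, 4, 8                     # small made-up sizes
X = torch.randn(H * W, C)             # one image, pixels as rows

# Random stand-ins for the learned transforms W_Q, W_K, W_V.
W_Q, W_K, W_V = (torch.randn(C, C) for _ in range(3))

Q, K, V = X @ W_Q, X @ W_K, X @ W_V   # each (H*W, C)

# Similarity between every pair of pixels, scaled by sqrt(C).
scores = Q @ K.T / math.sqrt(C)       # (H*W, H*W)

# Softmax over the last dim so each row is a distribution over pixels.
weights = torch.softmax(scores, dim=-1)

out = X + weights @ V                 # residual connection, (H*W, C)
print(out.shape)            # torch.Size([16, 8])
print(weights.sum(dim=-1))  # each entry is approximately 1
```

Each output row is the input pixel plus a weighted average of the value vectors of all pixels, with weights determined by query-key similarity.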

### Inline Question 1: Self-Attention is equivalent to which of the following: (10% of the grade)
1. K-means clustering <br />
2. Non-local means <br />
3. Residual Block <br />
4. Gaussian Blurring <br />

Your Answer: 2

### Here you implement the Attention module, and run it in the next section

In [None]:
# Initialize the attention module as a nn.Module subclass
class Attention(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        
        # TODO: Implement the Key, Query and Value linear transforms as 1x1 convolutional layers
        # Hint: channel size remains constant throughout
        self.conv_query = 
        self.conv_key = 
        self.conv_value = 

    def forward(self, x):
        N, C, H, W = x.shape
        
        # TODO: Pass the input through conv_query, reshape the output volume to (N, C, H*W)
        q = 
        # TODO: Pass the input through conv_key, reshape the output volume to (N, C, H*W)
        k = 
        # TODO: Pass the input through conv_value, reshape the output volume to (N, C, H*W)
        v = 
        # TODO: Implement the above formula for attention using q, k, v, C
        # NOTE: The X in the formula is already added for you in the return line
        attention = 
        # Reshape the output to (N, C, H, W) before adding to the input volume
        attention = attention.reshape(N, C, H, W)
        return x + attention

## Single Attention Block: Early attention; After the first conv layer. (50% of the grade)

In [None]:
channel_1 = 64
channel_2 = 32
learning_rate = 1e-3

# TODO: Use the above Attention module after the first Convolutional layer.
# Essentially the architecture should be [Conv->Relu->Attention->Relu->Conv->Relu->Linear]

model = None

optimizer = optim.Adam(model.parameters(), lr=learning_rate)

train(model, optimizer, epochs=10)

## Test set -- run this only once

Now we test our model on the test set. Think about how this compares to your validation set accuracy.
You should see an improvement of about 2-3% over the vanilla convnet model. *Use this part to tune your Attention module and then move on to the next parts.*

In [None]:
earlyAttention = model
check_accuracy(loader_test, earlyAttention)

## Single Attention Block: Late attention; After the second conv layer. (10% of the grade)

In [None]:
channel_1 = 64
channel_2 = 32
learning_rate = 1e-3

# TODO: Use the above Attention module after the Second Convolutional layer.
# Essentially the architecture should be [Conv->Relu->Conv->Relu->Attention->Relu->Linear]

model = None

optimizer = optim.Adam(model.parameters(), lr=learning_rate)

train(model, optimizer, epochs=10)

## Test set -- run this only once

Now we test our model on the test set. Think about how this compares to your validation set accuracy.

In [None]:
lateAttention = model
check_accuracy(loader_test, lateAttention)

### Inline Question 2: Provide one example each of the use of self-attention and of attention in computer vision. Explain the difference between the two. (10% of the grade)


Your Answer:

## Double Attention Blocks: After conv layers 1 and 2 (10% of the grade)

In [None]:
channel_1 = 64
channel_2 = 32
learning_rate = 1e-3

# TODO: Use the above Attention module after both the first and the second convolutional layers.
# Essentially the architecture should be [Conv->Relu->Attention->Relu->Conv->Relu->Attention->Relu->Linear]

model = None

optimizer = optim.Adam(model.parameters(), lr=learning_rate)

train(model, optimizer, epochs=10)

## Test set -- run this only once

Now we test our model on the test set. Think about how this compares to your validation set accuracy.

In [None]:
# Use a distinct name so the vanilla model saved earlier is not overwritten
doubleAttention = model
check_accuracy(loader_test, doubleAttention)

## Inline Question 3: Rank the above models based on their performance on the test dataset (20% of the grade)
(You are encouraged to run each of the experiments (training) at least 3 times to get an average estimate.)

Report the test accuracies alongside the model names. For example: 1. Vanilla CNN (57.45%, 57.99%), etc.

1. <br />
2. <br />
3. <br />
4. <br />

### Bonus Question (Ungraded): Can you give a possible explanation that supports the rankings?
Your Answer: