# Lecture IV: Attention Mechanism and Network Interpretability

In this homework, we will add an attention mechanism to the recurrent neural network (RNN) we developed in homework 3, and use it to perform a network interpretability study.

First, we will redefine the RNN we developed last time, starting with the module imports:

In [None]:
# pylint: disable=E1101,R,C
import numpy as np
import os
import argparse
import time
import math
import random
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.parameter import Parameter
from torch.nn import init
import torch.nn.functional as F
import torch.utils.data as data_utils
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
import gzip
import pickle
from scipy import sparse
from sklearn.metrics import roc_auc_score, roc_curve
from sklearn.preprocessing import StandardScaler, MinMaxScaler, RobustScaler
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from tqdm import tqdm
import torchsnooper  # pip install torchsnooper; used to trace tensor shapes inside forward()
torch.set_default_dtype(torch.float32)  # float32 by default (replaces the deprecated torch.set_default_tensor_type)

The Japanese Vowels dataset can be found [here](https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels). For a detailed description, see the dataset page. First, we download the dataset:

In [None]:
!wget -nc https://archive.ics.uci.edu/ml/machine-learning-databases/JapaneseVowels-mld/ae.train
!wget -nc https://archive.ics.uci.edu/ml/machine-learning-databases/JapaneseVowels-mld/ae.test
!wget -nc https://archive.ics.uci.edu/ml/machine-learning-databases/JapaneseVowels-mld/size_ae.train
!wget -nc https://archive.ics.uci.edu/ml/machine-learning-databases/JapaneseVowels-mld/size_ae.test

This dataset contains recordings of 9 male Japanese speakers uttering /ae/. The recordings are encoded with Linear Predictive Coding (LPC); details of LPC can be found [here](https://en.wikipedia.org/wiki/Linear_predictive_coding). Each time frame is described by 12 LPC coefficients, so the time series has 12 channels at every time index.
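The raw `ae.train`/`ae.test` files store one frame per line (12 space-separated coefficients), with a blank line closing each utterance; this is the layout that `read_vowels()` below relies on. A minimal sketch of that parsing on a made-up string (the helper `parse_lpc_blocks` is hypothetical and not part of the homework code):

```python
import numpy as np

def parse_lpc_blocks(text, num_lpc=12):
    """Split LPC text into one (frames, num_lpc) array per utterance.

    Assumes, as read_vowels() does, that rows are space-separated floats and
    that each utterance block is terminated by a blank line.
    """
    utterances, current = [], []
    for line in text.splitlines():
        if not line.strip():            # a blank line closes the current utterance
            if current:
                utterances.append(np.array(current, dtype=float))
                current = []
            continue
        row = [float(v) for v in line.split()]
        assert len(row) == num_lpc, "each frame should carry num_lpc coefficients"
        current.append(row)
    if current:                         # tolerate a missing trailing blank line
        utterances.append(np.array(current, dtype=float))
    return utterances

# Toy input: two utterances with 3 and 2 frames (values are made up for illustration)
row = " ".join(["0.1"] * 12)
toy = "\n".join([row, row, row, "", row, row, ""]) + "\n"
blocks = parse_lpc_blocks(toy)
print([b.shape for b in blocks])  # [(3, 12), (2, 12)]
```

Utterances have different numbers of frames, which is why the dataset class pads them to a common length.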

## Part I: Dataset

First, we need to prepare the dataset by subclassing the `Dataset` class. This is identical to the dataset in homework 3:

In [None]:
class JapaneseVowelDataset(Dataset):

    def __init__(self,plot=True):
        self.max_length = 29 # The longest utterance contains 29 frames
        self.num_LPC = 12    # Each frame has 12 LPC coefficients, so each padded sample has shape [29,12]
        
        train_data, train_label = self.read_vowels("ae.train","size_ae.train")
        test_data, test_label = self.read_vowels("ae.test","size_ae.test")
        
        self.size = len(train_data) + len(test_data)
        self.train_test_split = len(train_data)
        
        self.data = train_data + test_data
        self.labels = train_label + test_label
        
        if plot:
            self.plot_data()
        
        
    def __len__(self):
        '''
        This function returns the size of overall dataset
        '''
        return self.size


    def __getitem__(self, idx):
        '''
        This function extracts a single entry from the dataset at the given index idx.
        In this dataset, the data has variable length, so we need to pad 
        the LPC coefficients to have the same length for training purpose
        '''
        output = np.zeros((self.max_length, self.num_LPC))
        data = self.data[idx]
        output[:data.shape[0]] += data
        return output, self.labels[idx]
    
    def get_train_test_split(self):
        '''
        This function returns the train/test split index, i.e. the number of training samples
        '''
        return self.train_test_split
    
    def read_vowels(self,file, size_file):
        vowel_units = []
        speaker_size = []
        labels = []
        #Read out the LPC value of all vowels
        with open(file, "r") as f:
            current_vowel = []
            for line in f.readlines():
                if line == '\n':
                    vowel_units.append(np.array(current_vowel))
                    current_vowel = []
                    continue
                current_vowel.append(np.array(line.strip().split(" "),dtype=float).tolist())
        #Read out the size of samples by 9 speakers
        with open(size_file, "r") as f:
            speaker_size = np.array(f.readline().strip().split(" "),dtype=int)
            assert len(speaker_size) == 9 # If speaker size is not 9, then there's something wrong
        #Assign a label to each speaker, speaker 1 == 0 .....speaker 9 == 8:
        for speaker_label in range(9):
            labels += [speaker_label] * speaker_size[speaker_label]
        # Check if the number of label equals to number of data
        # If not, there is something wrong
        assert len(vowel_units) == len(labels)
        return vowel_units, labels
            
                    
        
    
    def plot_data(self):
        '''
        This function plots the LPC spectrum of 9 random utterances
        '''
        plt.figure(figsize=(20,12))
        sample_index = np.random.randint(low=0,high=self.__len__(), size = 9)
        for i in range(9):
            plt.subplot(3,3, i+1)
            voice, label = self.__getitem__(sample_index[i])
            utt_length = voice.shape[0]
            for j in range(voice.shape[-1]):  # one curve per LPC coefficient (j avoids shadowing the outer i)
                plt.plot(np.arange(utt_length), voice[:, j])
            plt.xlabel("Time Index")
            plt.ylabel("LPC Coefficients")

As before, we can inspect the data by plotting the LPC coefficients of a few random utterances. The trailing 0s come from the padding performed in `__getitem__`:
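The padding step in `__getitem__` can be isolated as a small function. This is a sketch that mirrors the class's logic; `pad_utterance` is a hypothetical name and the 20-frame input is made up:

```python
import numpy as np

def pad_utterance(frames, max_length=29):
    """Zero-pad a (length, num_lpc) array to (max_length, num_lpc), as __getitem__ does."""
    length, num_lpc = frames.shape
    assert length <= max_length, "utterance longer than max_length"
    padded = np.zeros((max_length, num_lpc))
    padded[:length] = frames   # real frames first, zeros afterwards
    return padded

utterance = np.ones((20, 12))                      # hypothetical 20-frame utterance
padded = pad_utterance(utterance)
print(padded.shape)                                # (29, 12)
print(int((padded.sum(axis=1) == 0).sum()))        # 9 trailing all-zero frames
```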

In [None]:
JapaneseVowelDataset()

## Part II: Recurrent Neural Network
In this part, we will add an attention mechanism to a recurrent neural network model. A typical RNN model is provided in the following code block; you can also copy and paste your RNN from homework 3 if you prefer. Your goal is to add an attention mechanism to this RNN in which each hidden state $h_i$ is scored against the last hidden state $h_n$ through a learned weight kernel $W$:
\begin{equation}
s(h_{i},h_{n})=h_{i}^{T}Wh_{n}
\end{equation}
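The score in the equation is a bilinear form. If $W$ were a full `(hidden_size, hidden_size)` matrix, the scores for all time steps could be computed with one matrix product, as in the sketch below. This only illustrates the formula under that assumption (random tensors stand in for real hidden states); the provided RNN uses a kernel of shape `(seq_len, hidden_size)` instead, which is what the second question below asks about.

```python
import torch

torch.manual_seed(0)
seq_len, hidden_size = 29, 128
H = torch.randn(seq_len, hidden_size)      # hidden states h_1..h_n of one utterance (random stand-ins)
W = torch.randn(hidden_size, hidden_size)  # full bilinear kernel W (an assumption for this illustration)
h_n = H[-1]                                # last hidden state

# s(h_i, h_n) = h_i^T W h_n for every i at once
scores = H @ W @ h_n                       # shape: (seq_len,)

# Cross-check one entry against the formula written term by term
i = 5
assert torch.allclose(scores[i], H[i] @ (W @ h_n), rtol=1e-4, atol=1e-3)
print(scores.shape)  # torch.Size([29])
```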
Read the RNN code carefully and try to answer two questions:
- Why is `fc1` equal to twice the `hidden_size` of the LSTM layer?
- Why does the `self.attention_weight` tensor have the shape `(seq_len, hidden_size)`? How does it relate to the equation above?

After answering these two questions, add the attention mechanism following the procedure below:

- Combine `self.attention_weight` and `output` to produce $s(h_{i},h_{n})$
- Feed $s(h_{i},h_{n})$ into a [softmax function](https://pytorch.org/docs/master/generated/torch.nn.Softmax.html?highlight=softmax#torch.nn.Softmax) to produce the attention score
- Multiply the attention score back onto `output` and sum over the time steps to produce the context vector
- Concatenate the context vector with the last hidden state to produce the attention vector
- Feed the attention vector into the fully connected network
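The five steps can be sketched as a standalone function. This is one possible reading, not the reference solution: it assumes that "combine" means elementwise multiplication (which is what a kernel of shape `(seq_len, hidden_size)` broadcast over the batch suggests) and that the softmax runs over the time axis, so the attention score keeps a trailing hidden dimension, consistent with the plotting code in Part IV. The name `attention_sketch` is hypothetical.

```python
import torch
import torch.nn.functional as F

def attention_sketch(output, h_n, attention_weight):
    """One possible reading of the five steps (an assumption, not the reference solution).

    output:           (batch, seq_len, hidden) LSTM outputs h_1..h_n
    h_n:              (batch, hidden) last hidden state
    attention_weight: (seq_len, hidden) learned kernel, broadcast over the batch
    """
    scores = output * attention_weight                     # step 1: elementwise coupling -> (batch, seq_len, hidden)
    attention_score = F.softmax(scores, dim=1)             # step 2: normalize over the time axis
    context = (attention_score * output).sum(dim=1)        # step 3: weighted sum over time -> (batch, hidden)
    attention_vector = torch.cat([context, h_n], dim=-1)   # step 4: -> (batch, 2 * hidden)
    return attention_vector, attention_score               # step 5: attention_vector feeds self.fcnet

batch, seq_len, hidden = 4, 29, 128
output = torch.randn(batch, seq_len, hidden)
h_n = output[:, -1, :]          # for a 1-layer unidirectional LSTM with batch_first=True this equals h[-1]
weight = torch.empty(seq_len, hidden).uniform_(-0.1, 0.1)
vec, score = attention_sketch(output, h_n, weight)
print(vec.shape)    # torch.Size([4, 256])
print(score.shape)  # torch.Size([4, 29, 128])
```

Inside `RNN.forward`, `output` would be the LSTM output and `h_n` would be `h[-1]`; the concatenation is what makes the input width of the first linear layer `2 * hidden_size`.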

Additional reference materials:
- [lecture slides](https://drive.google.com/file/d/1-CPfeV-rA460ZS1u_cbuDyJqhn0oz5Bl/view?usp=sharing), pages 28-33
- [lecture video](https://www.youtube.com/watch?v=5C1yxV0bbSI), 39:17-56:45

In [None]:
class RNN(nn.Module):
    def __init__(self):
        super(RNN, self).__init__()
        '''
        Initialize RNN with attention mechanism with 3 parts:
            A feature extractor based on LSTM network
            A fully connected classifier
            An attention kernel
        '''
        input_size = 12 #12 LPC basis per input
        seq_len = 29    # Sequence length of Japanese Vowel data
        hidden_size = 128
        self.RNNLayer = torch.nn.LSTM(input_size = input_size, hidden_size = hidden_size, batch_first=True)
        fc1, fc2, fc3, fc4 = np.linspace(hidden_size*2, 64,4,dtype=int)
        print(fc1, fc2, fc3, fc4)
        self.fcnet = nn.Sequential(
            torch.nn.Linear(fc1, fc2),
            torch.nn.LeakyReLU(),
            torch.nn.Dropout(),
            torch.nn.Linear(fc2, fc3),
            torch.nn.LeakyReLU(),
            torch.nn.Dropout(),
            torch.nn.Linear(fc3, fc4),
            torch.nn.LeakyReLU(),
            torch.nn.Dropout(),
            torch.nn.Linear(fc4, 9),
        )
        self.attention_weight = Parameter(torch.empty(seq_len,hidden_size).uniform_(-0.1, 0.1))
        
#     @torchsnooper.snoop()
    def forward(self, x):
        '''
        The forward operation of each training step of the neural network model
        '''
        output, (h, c) = self.RNNLayer(x)
        
        '''
        Add the attention mechanism here: combine output and self.attention_weight to build attention_vector.
        '''
        
        x = self.fcnet(attention_vector)
        return x

As in the Lecture 2 homework, we pull one event out of the dataset and use `torchsnooper.snoop()` (uncomment the decorator above `forward`) to check the network structure. Before proceeding to the next part, study the tensor shapes printed by `torchsnooper.snoop()` carefully to understand how tensors flow through the RNN.

In [None]:
# Pull out one event (index 0) from the dataset
test_event, test_label = JapaneseVowelDataset(plot=False)[0]
test_event = torch.FloatTensor(test_event).unsqueeze(0) # Insert batch dimension -> (1, 29, 12)
test_network = RNN()
print(test_network(test_event))

## Part III: Training and Evaluation
After building the neural network, we train it the same way as we did in Lecture 3:

In [None]:
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # This says if GPU is available, use GPU, otherwise use CPU
NUM_EPOCHS = 30
LEARNING_RATE = 1e-3 # 1e-3 is also the default learning rate of torch.optim.Adam
BATCH_SIZE = 4

In [None]:
def set_up_classifier():
    classifier = RNN() # Define the RNN classifier
    classifier.to(DEVICE)     # Send the classifier to DEVICE as we defined earlier

    print("# of params in model: ", sum(x.numel() for x in classifier.parameters()))

    criterion = torch.nn.CrossEntropyLoss()
    criterion = criterion.to(DEVICE)

    #Define the optimizer
    optimizer = torch.optim.Adam(classifier.parameters(),lr=LEARNING_RATE)
    
    return classifier, criterion, optimizer

In [None]:
def get_dataloader():
    dataset = JapaneseVowelDataset(plot=False)
    #Get the indices of train dataset and test dataset correspondingly, indices [0:train_test_split] is the training dataset, indices [train_test_split, len(dataset)] is the test dataset.
    train_test_split = dataset.get_train_test_split()
    train_indices, val_indices = list(range(train_test_split)), list(range(train_test_split,len(dataset)))

    #Shuffle the two indices list
    np.random.shuffle(train_indices)
    np.random.shuffle(val_indices)

    # Define two SubsetRandomSamplers that draw events only from the corresponding indices
    train_sampler = SubsetRandomSampler(train_indices)
    valid_sampler = SubsetRandomSampler(val_indices)

    # Finally, define the loaders by passing in the dataset, batch size and corresponding sampler.
    # The number of events in each subset might not be divisible by the batch size, so drop_last=True drops the last, incomplete batch.
    train_loader = data_utils.DataLoader(dataset, batch_size=BATCH_SIZE, sampler=train_sampler, drop_last=True)
    test_loader = data_utils.DataLoader(dataset, batch_size=BATCH_SIZE,sampler=valid_sampler,  drop_last=True)
    
    return train_loader, test_loader
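Two details of this loader setup can be checked on a toy dataset. `SubsetRandomSampler` already draws its indices in a fresh random order every epoch, so the explicit `np.random.shuffle` calls are harmless but not required. `drop_last=True` discards the final incomplete batch, and this applies to the test loader too, so a few test utterances are skipped during validation. In the sketch below, `TensorDataset` is only a stand-in for `JapaneseVowelDataset`:

```python
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset

# Toy stand-in for JapaneseVowelDataset: 10 samples, the first 7 are "train"
dataset = TensorDataset(torch.arange(10).float().unsqueeze(1), torch.arange(10))
train_test_split = 7
train_sampler = SubsetRandomSampler(list(range(train_test_split)))
valid_sampler = SubsetRandomSampler(list(range(train_test_split, len(dataset))))

train_loader = DataLoader(dataset, batch_size=4, sampler=train_sampler, drop_last=True)
test_loader = DataLoader(dataset, batch_size=4, sampler=valid_sampler, drop_last=True)

print(len(train_loader), len(test_loader))  # 1 0  (7 // 4 and 3 // 4 full batches)
seen = torch.cat([labels for _, labels in train_loader])
print(sorted(seen.tolist()))  # a random subset of 4 indices from 0..6
```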

In [None]:
classifier, criterion, optimizer = set_up_classifier()
train_loader, test_loader = get_dataloader()

loss_values = []
accuracy_values = []
y_true = []
y_pred = []

for epoch in range(NUM_EPOCHS):
    classifier.train() # Set the network to train mode; some layers (e.g. Dropout) behave differently in train and eval mode.
    epoch_loss = 0.0
    for i, (utterances, labels) in tqdm(enumerate(train_loader), total=len(train_loader)):
        utterances = utterances.to(DEVICE).float()
        labels = labels.to(DEVICE)
        
        # Forward pass through the RNN classifier
        outputs = classifier(utterances)
        
        # Calculate loss
        loss = criterion(outputs, labels)
        
        # Back-propagate the loss to compute gradients
        loss.backward()
        
        # Perform a gradient descent step to update the parameters
        optimizer.step()
        
        # Reset the gradients to 0 on all parameters
        optimizer.zero_grad()
        epoch_loss += loss.item()

    epoch_loss /= len(train_loader) # Average training loss over this epoch
    print('\rEpoch [{0}/{1}], Iter [{2}/{3}] Loss: {4:.4f}'.format(
        epoch+1, NUM_EPOCHS, i+1, len(train_loader), epoch_loss), end="")
    loss_values.append(epoch_loss)
    
    # After every epoch, evaluate the validation accuracy on the test loader
    classifier.eval() # Set the network to evaluation mode (Dropout is disabled)
    num_accurate = 0
    num_samples = 0
    # While validating the network, we do not want it to track gradients. This also saves time and memory.
    with torch.no_grad():
        for utterances, labels in tqdm(test_loader):
            # Send the utterances to the device, then feed them to the network
            utterances = utterances.to(DEVICE).float()
            outputs = classifier(utterances)
            
            # Get the classification decision by taking the argmax over the 9 speaker logits
            decision = torch.argmax(outputs, dim=-1)
            decision = decision.cpu().numpy().flatten() # Copy decision to CPU and convert it to a numpy array
            labels = labels.cpu().numpy().flatten()
            
            # Record the truth values and network predictions in the last epoch
            if epoch == (NUM_EPOCHS-1):
                y_true += list(labels)
                y_pred += list(decision)
            
            # Accuracy = number of correct predictions / total number of predictions
            num_accurate += np.sum(decision == labels)
            num_samples += len(decision)
    accuracy_values.append(num_accurate/num_samples)
torch.save(classifier.state_dict(), 'RNN.pt') # Save the trained RNN model

After training, we will be able to evaluate our training results.

First, let's plot the learning curve, that is, the loss value with respect to the epochs:

In [None]:
plt.plot(np.arange(NUM_EPOCHS).astype(int), loss_values)
plt.xlabel("Epochs")
plt.ylabel("Cross Entropy Loss [a.u.]")

You should see the loss decrease as the network is trained for more epochs.

Next, let's plot the accuracy curve. That is, the accuracy with respect to epochs:

In [None]:
plt.plot(np.arange(NUM_EPOCHS).astype(int), np.array(accuracy_values)*100.0)
plt.xlabel("Epochs")
plt.ylabel("Classification Accuracy [%]")

Is the training result better than that of the RNN in the Lecture 3 homework? If not, possible reasons include:
- The dataset is too small
- The task is too easy to benefit from an attention mechanism

## Part IV: RNN Interpretability

Network interpretability refers to the ability to explain each decision of a neural network. A traditional neural network is a black box: we do not know which feature or features it uses to make a classification decision. The beauty of the attention mechanism is that it provides a straightforward way to interpret the decisions of an RNN or a transformer network. In this part, we will use the attention score to trace where the RNN's classification power comes from.

First, we need to make a small change to the RNN. In the following code block, copy and paste the RNN model you defined previously, and change the return value of `forward` from `x` to the attention score variable. Remember that the attention score is the output of the softmax operation. What is the shape of the attention score? You can check it using `torchsnooper`.

**Notice**: Do not change anything in the `__init__` method; otherwise the saved parameters may fail to load.
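Why only `__init__` matters: the keys of `state_dict()` are built from the attribute names of the submodules and parameters registered in `__init__`, while `forward` is ordinary Python code that is never saved. A toy sketch with three hypothetical classes: two with identical `__init__` but different `forward` can share a checkpoint, while renaming an attribute breaks strict loading.

```python
import torch
import torch.nn as nn

class Trained(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(3, 2)
    def forward(self, x):
        return self.layer(x)

class Inspector(nn.Module):
    def __init__(self):          # identical __init__ -> identical parameter names
        super().__init__()
        self.layer = nn.Linear(3, 2)
    def forward(self, x):        # different forward: expose an intermediate quantity
        return torch.softmax(self.layer(x), dim=-1)

class Renamed(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3, 2)  # renamed attribute -> keys become 'fc.weight', 'fc.bias'

trained, inspector = Trained(), Inspector()
print(list(trained.state_dict().keys()))           # ['layer.weight', 'layer.bias']
inspector.load_state_dict(trained.state_dict())    # strict loading succeeds
x = torch.randn(5, 3)
assert torch.allclose(inspector(x), torch.softmax(trained(x), dim=-1))

try:
    Renamed().load_state_dict(trained.state_dict())
    renamed_loaded = True
except RuntimeError as err:                         # missing/unexpected keys in strict mode
    renamed_loaded = False
    print("strict load failed:", type(err).__name__)
```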

In [None]:
class RNNInterpretor(nn.Module):
    def __init__(self):
        super(RNNInterpretor, self).__init__()
        '''
        Initialize RNN with attention mechanism with 3 parts:
            A feature extractor based on LSTM network
            A fully connected classifier
            An attention kernel
        '''
        input_size = 12 #12 LPC basis per input
        seq_len = 29    # Sequence length of Japanese Vowel data
        hidden_size = 128
        self.RNNLayer = torch.nn.LSTM(input_size = input_size, hidden_size = hidden_size, batch_first=True)
        fc1, fc2, fc3, fc4 = np.linspace(hidden_size*2, 64,4,dtype=int)
        self.fcnet = nn.Sequential(
            torch.nn.Linear(fc1, fc2),
            torch.nn.LeakyReLU(),
            torch.nn.Dropout(),
            torch.nn.Linear(fc2, fc3),
            torch.nn.LeakyReLU(),
            torch.nn.Dropout(),
            torch.nn.Linear(fc3, fc4),
            torch.nn.LeakyReLU(),
            torch.nn.Dropout(),
            torch.nn.Linear(fc4, 9),
        )
        self.attention_weight = Parameter(torch.empty(seq_len,hidden_size).uniform_(-0.1, 0.1))
        
#     @torchsnooper.snoop()
    def forward(self, x):
        '''
        The forward operation of each training step of the neural network model
        '''
        output, (h, c) = self.RNNLayer(x)
        
        '''
        Add the attention mechanism here (same as in RNN): combine output and self.attention_weight, and keep attention_score for the return value.
        '''
        
        x = self.fcnet(attention_vector)
        return attention_score

Next, we load the parameters saved from the trained RNN into the interpretor network:

In [None]:
interpretor = RNNInterpretor()
pretrained_dict = torch.load('RNN.pt', map_location=DEVICE) # map_location lets a GPU-trained checkpoint load on any device
model_dict = interpretor.state_dict()
model_dict.update(pretrained_dict)
interpretor.load_state_dict(model_dict)

Then, we pull out a batch of events from the test loader and plot their attention scores. The attention score reflects the relative importance of each time slice. Since the softmax normalizes the scores along the time dimension, they sum to 1 over time and can be treated as the relative weight of each time slice:
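To see what the plotting cell computes, here is a sketch on random scores. It assumes, as the plotting cell's comment does, that the attention score has shape `(batch, seq_len, hidden_size)` and was normalized with a softmax over the time axis. Each hidden channel then sums to 1 over time, so summing over channels gives per-time totals that add up to `hidden_size`; dividing by that total turns them back into relative weights.

```python
import torch

torch.manual_seed(0)
batch, seq_len, hidden = 4, 29, 128
logits = torch.randn(batch, seq_len, hidden)
attention_score = torch.softmax(logits, dim=1)        # normalized over time, per hidden channel

# Each hidden channel's weights sum to 1 over the 29 time slices
assert torch.allclose(attention_score.sum(dim=1), torch.ones(batch, hidden), atol=1e-5)

per_time = attention_score.sum(dim=-1)                # (batch, seq_len), as in the plotting cell
print(per_time.sum(dim=1))                            # each entry is hidden = 128, up to rounding
relative = per_time / per_time.sum(dim=1, keepdim=True)    # rescale so each utterance sums to 1
centered = per_time - per_time.mean(dim=1, keepdim=True)   # what the plot shows: deviation from the mean
```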

In [None]:
#############################################
# Feed a batch of events through the interpretor to get the attention score
test_event, test_label = next(iter(test_loader))
test_event = test_event.to(DEVICE).float()
interpretor.to(DEVICE) # Keep the model on the same device as the events
interpretor.eval()     # Evaluation mode: disable Dropout
with torch.no_grad():
    attention_score = interpretor(test_event)
#############################################

batch_size = len(test_label)
fig = plt.figure(figsize=(40,40))
cellc = int(math.ceil((batch_size**0.5)))
outer = gridspec.GridSpec(cellc, cellc, wspace=0.2, hspace=0.2)
for i in range(batch_size):
    inner = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=outer[i], wspace=0.1, hspace=0.1, height_ratios=[5,1])
    # The last dimension of the attention score is the hidden size (128 by default); sum over it to get one score per time index
    attention = torch.sum(attention_score[i], dim=-1).cpu().numpy().flatten()
    attention -= np.average(attention) # Plot only the deviation from the mean to highlight trends
    current_vowel = test_event[i].cpu().numpy() # matplotlib needs CPU numpy arrays
    
    ax_main = fig.add_subplot(inner[0])
    vowel_length, lpc_basis = current_vowel.shape
    for j in range(lpc_basis):
        ax_main.plot(np.arange(vowel_length), current_vowel[:,j], label=f"LPC {j}")
    ax_main.legend()

    ax_attention = fig.add_subplot(inner[1], sharex=ax_main)
    ax_attention.bar(x=np.arange(vowel_length)+0.5, height=attention, width=1)
plt.show()
plt.close(fig)

**Question:** How does the attention score reflect the classification power of the RNN? Since we zero-pad the data to make all utterances equal in length, most inputs contain a run of 0s toward the end. Does this zero-padded region receive a high attention score? In other words, does the RNN classifier pay attention to the padded 0s?
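One rough way to answer this quantitatively is to measure the share of attention that lands on the padded frames and compare it with the share of padded frames in each utterance (which is what uniform attention would give). The sketch assumes a frame is padding exactly when all 12 coefficients are 0, and that the attention score has shape `(batch, seq_len, hidden_size)` normalized over time; `padded_attention_fraction` is a hypothetical helper.

```python
import torch

def padded_attention_fraction(inputs, attention_score):
    """Share of (channel-summed) attention that falls on zero-padded frames, per utterance.

    inputs:          (batch, seq_len, num_lpc) padded LPC frames
    attention_score: (batch, seq_len, hidden), softmax-normalized over seq_len
    Assumes a frame is padding iff all of its LPC coefficients are exactly 0.
    """
    is_pad = (inputs == 0).all(dim=-1)                # (batch, seq_len) boolean mask
    per_time = attention_score.sum(dim=-1)            # (batch, seq_len)
    return (per_time * is_pad).sum(dim=1) / per_time.sum(dim=1)

# Toy check: utterances of length 20 and 29, uniform attention over 29 slices
inputs = torch.ones(2, 29, 12)
inputs[0, 20:] = 0.0
uniform = torch.full((2, 29, 128), 1.0 / 29)
frac = padded_attention_fraction(inputs, uniform)
print(frac)  # approximately [9/29, 0]
```

With the tensors from the plotting cell, `padded_attention_fraction(test_event, attention_score)` would give one fraction per utterance; values well above the padded share of time slices would suggest the classifier attends to the padding.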