<a href="https://colab.research.google.com/github/Apoak/Deep-Learning-Projects/blob/main/Attention_Implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Lab 9.1 Attention Implementation

This week you will experiment with attention-based models.

In [None]:
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F

1. Complete the following implementation of scaled dot-product attention. Run the code cell to verify that the output has the expected shape, [B, L, d_v] (here [1, 5, 8]).

*Note: you can use `scores = scores.masked_fill(...)` to fill in values where the mask is True.  Fill in -1e9 as the score for masked values.*

In [None]:
def attention(Q,K,V,mask=None):
  """
  Scaled dot-product attention.

  Steps, in order:
    1. Raw scores are Q @ K^T, shape [B,L,S].
    2. If a mask is given, every score where the mask is True is replaced
       by -1e9 so that it receives (effectively) zero softmax weight.
    3. Scores are divided by sqrt(d_k).
    4. A softmax over the last axis (the S keys) turns each row into weights.
    5. The weights are applied to the values.

  Arguments:
    Q: queries [B,L,d_k]
    K: keys    [B,S,d_k]
    V: values  [B,S,d_v]
    mask: optional Boolean mask, True = hidden, broadcastable to [B,L,S]

  Returns:
    Context vectors of shape [B,L,d_v]
  """
  d_k = Q.shape[-1]
  logits = torch.matmul(Q, K.transpose(-2, -1))            # [B,L,S]
  if mask is not None:
    # Masked before scaling: hidden entries end up at -1e9/sqrt(d_k), which is
    # still far below any real score, so their softmax weight underflows to 0.
    logits = logits.masked_fill(mask, -1e9)
  weights = torch.softmax(logits / np.sqrt(d_k), dim=-1)    # each row sums to 1
  return torch.matmul(weights, V)                            # [B,L,d_v]

# Smoke test: L=5 queries, S=10 keys/values, d_k=64, d_v=8 -> output must be [1,5,8].
Q = torch.rand(1,5,64)
K = torch.rand(1,10,64)
V = torch.rand(1,10,8)
mask = (torch.rand(1,5,10)>0.5)   # random Boolean mask; True entries are hidden

y = attention(Q,K,V,mask=mask)
assert y.shape == (1,5,8), f"unexpected attention output shape {tuple(y.shape)}"

y.shape

The following code creates classes to build a Transformer-style decoder for generating sequences.

In [None]:
class AttentionHead(nn.Module):
    """A single attention head: learned linear projections followed by scaled dot-product attention."""

    def __init__(self,d_model,d_k):
        super().__init__()
        # Separate learned projections from d_model down to d_k for queries, keys and values.
        # The value projection also outputs d_k, so each head returns d_k features.
        self.WQ = nn.Linear(d_model,d_k)
        self.WK = nn.Linear(d_model,d_k)
        self.WV = nn.Linear(d_model,d_k)

    def forward(self,Q,K,V,mask=None):
        """ Compute one attention head.

            Project the inputs to queries, keys and values, then apply attention.
            Arguments:
                Q: queries [B,L,d_model]
                K: keys    [B,S,d_model]
                V: values  [B,S,d_model]  (one value per key, hence S rather than L)
                mask: optional Boolean mask where True means hidden [B,L,S]
            Output:
                Context vectors [B,L,d_k]
        """
        # apply linear projections to queries, keys, and values followed by masked attention
        return attention(self.WQ(Q),self.WK(K),self.WV(V),mask=mask)

class MultiHeadAttention(nn.Module):
    """Multi-head attention: num_heads independent AttentionHeads whose outputs are
    concatenated and mixed by a final d_model -> d_model linear layer."""

    def __init__(self,d_model=512,num_heads=8):
        super().__init__()
        # Each head works in d_model // num_heads dimensions, so concatenating all heads
        # must give back exactly d_model features for self.W. Reject configurations where
        # that fails: otherwise construction succeeds but the first forward() dies with an
        # opaque shape mismatch inside self.W.
        if num_heads <= 0 or d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive num_heads ({num_heads})"
            )
        d_k = d_model // num_heads
        self.heads = nn.ModuleList([AttentionHead(d_model,d_k) for _ in range(num_heads)])
        self.W = nn.Linear(d_model,d_model)

    def forward(self,Q,K,V,mask=None):
        """ Compute multi-head attention.

            Applies attention num_heads times, concatenates the results, and applies a final linear projection.
            Arguments:
                Q: queries [B,L,d_model]
                K: keys    [B,S,d_model]
                V: values  [B,S,d_model]
                mask: optional Boolean mask where True means hidden [B,L,S]
            Output:
               result of multi-head attention [B,L,d_model]
        """
        # compute each attention head ([B,L,d_k] each) and concatenate to [B,L,d_model]
        h = torch.cat([head(Q,K,V,mask=mask) for head in self.heads],dim=-1)

        # apply output projection
        return self.W(h)

class SelfAttentionBlock(nn.Module):
    """Post-norm Transformer block: self-attention and a position-wise feed-forward
    network, each wrapped in a residual connection followed by LayerNorm."""

    def __init__(self,d_model=512,num_heads=8,d_ff=2048):
        super().__init__()
        # Keep this construction order: it fixes the random-initialisation sequence
        # and the order of parameters() seen by the optimizer.
        self.multi_head_attention = MultiHeadAttention(d_model,num_heads)
        self.ln1 = nn.LayerNorm(d_model)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model,d_ff),
            nn.ReLU(),
            nn.Linear(d_ff,d_model),
        )
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self,x,mask=None):
        """ Apply one self-attention block.

            The sequence attends to itself: queries, keys and values are all x.
            Arguments:
                x: input sequence [B,S,d_model]
                mask: optional Boolean mask where True means hidden,
                      [B,S,S] or broadcastable to it (e.g. [1,S,S])
            Output:
               transformed sequence [B,S,d_model]
        """
        attn_out = self.multi_head_attention(x,x,x,mask=mask)
        # sub-layer 1: attention + residual, then LayerNorm
        h = self.ln1(attn_out + x)
        # sub-layer 2: feed-forward + residual, then LayerNorm
        return self.ln2(self.feed_forward(h) + h)

class PositionalEmbedding(nn.Module):
    """Learned absolute positional embedding: one trainable d_model vector per position
    0 .. max_seq_len-1, added to the input sequence."""

    def __init__(self,max_seq_len,d_model):
        super().__init__()
        self.positional_embedding = nn.Embedding(max_seq_len,d_model)

    def forward(self,x):
        """ Adds a positional embedding.

            Arguments:
                x: input sequence [B,S,d_model]
            Output:
               sequence with positional embedding added [B,S,d_model]
            Raises:
               IndexError: if S exceeds the max_seq_len this module was built with.
        """
        # get sequence length
        N = x.shape[1]
        max_seq_len = self.positional_embedding.num_embeddings
        if N > max_seq_len:
            # The embedding lookup would fail with an opaque "index out of range" error;
            # keep the same exception type but say what actually went wrong.
            raise IndexError(
                f"sequence length {N} exceeds max_seq_len {max_seq_len} of the positional embedding"
            )

        # look up positional embedding vectors; indices are created directly on x's device
        positions = torch.arange(N, device=x.device)
        pe = self.positional_embedding(positions) # [N,d_model]

        # add to input, broadcasting over the batch dimension
        return x + pe[None,...] # [B,N,d_model]

class TransformerDecoder(nn.Module):
    """Decoder-only Transformer: token embedding plus learned positional embedding,
    a stack of num_blocks SelfAttentionBlocks, and a linear layer that produces
    per-position scores over the vocabulary."""

    def __init__(self,vocabulary_size,max_seq_len,
                      d_model=512,num_heads=8,d_ff=2048,num_blocks=6):
        super().__init__()
        # Construction order is kept as-is: it determines random initialisation
        # and parameter ordering.
        self.blocks = nn.ModuleList([SelfAttentionBlock(d_model,num_heads,d_ff) for _ in range(num_blocks)])
        self.token_embedding = nn.Embedding(vocabulary_size,d_model)
        self.output = nn.Linear(d_model,vocabulary_size)
        self.positional_embedding = PositionalEmbedding(max_seq_len,d_model)

    def forward(self,x,mask=None):
        """ Computes the decoded sequence:

            Convert input to token embedding vectors
            Add positional embedding to input
            Apply self-attention blocks with mask
            Compute output

            Arguments:
                x: input token sequence [B,S] (integer ids < vocabulary_size, S <= max_seq_len)
                mask: optional Boolean mask where True means hidden,
                      [B,S,S] or broadcastable to it (make_mask(S) gives [1,S,S])
            Output:
               sequence predictions [B,S,vocabulary_size]
        """
        # look up embedding vectors for tokens
        x = self.token_embedding(x) # [B,S,d_model]

        # apply positional embedding
        x = self.positional_embedding(x) # [B,S,d_model]

        # apply sequence of masked self-attention blocks
        for block in self.blocks:
            x = block(x,mask=mask) # [B,S,d_model]

        # produce sequence of output vectors
        y = self.output(x) # [B,S,vocabulary_size]

        return y


This function produces masks appropriate for sequence prediction.  The mask ensures that the output token at time t+1 only sees the generated sequence up to time t.

In [None]:
def make_mask(seq_len):
    """ Make a causal mask for sequence prediction.

        Returns a Boolean tensor of shape [1,seq_len,seq_len] whose entry [0,i,j]
        is True (hidden) exactly when j > i, so position i can only attend to
        positions 0..i.
    """
    idx = torch.arange(seq_len)
    return (idx[None,:] > idx[:,None])[None,...]

make_mask(seq_len=10)

Now we will make a sequence of integers and see if the Transformer decoder can learn the sequence.

In [None]:
# Training data: the integers 0..99. Inputs are tokens 0..98; targets are the same
# sequence shifted by one (1..99). Both carry a leading batch axis of size 1.
seq = torch.arange(100)
x = seq[None,:-1]
y = seq[None,1:]
mask = make_mask(seq_len=x.shape[1])
x,y

In [None]:
steps = 1000
log_every = 100   # print the loss periodically rather than on every one of the steps
model = TransformerDecoder(vocabulary_size=100,max_seq_len=x.shape[1],
                           d_model=64,num_heads=8,d_ff=512,num_blocks=3
                           )


opt = torch.optim.Adam(model.parameters(),lr=.01)
loss_fn = nn.CrossEntropyLoss()

# Train mode only needs to be set once; nothing inside the loop switches it off.
model.train()
for step in range(steps):
    opt.zero_grad()

    # Predict y[t] (the token after x[t]) from x[0..t]; the causal mask hides later tokens.
    y_pred = model(x,mask)   # [1,S,vocabulary_size]
    # CrossEntropyLoss expects [N,C] scores and [N] targets, so flatten batch and time.
    loss = loss_fn(y_pred.reshape(-1,y_pred.shape[-1]),y.reshape(-1))
    loss.backward()

    opt.step()

    if step % log_every == 0 or step == steps - 1:
        print(step,loss.item())

If the Transformer has learned the sequence correctly, this output will read 1, 2, 3, ..., 97, 98, 99.

In [None]:
# Evaluate under the same causal mask used in training; without it every position could
# attend to later tokens, a condition the model never saw while training.
model.eval()
with torch.no_grad():
    pred = torch.argmax(model(x,mask),-1)
pred

2. What size context does the Transformer need in order to learn the above sequence?

The minimum context size I was able to find was 6 steps, so the Transformer needed to be trained on only the first 6% of the sequence to learn it.

3. Now design a pattern that requires a larger context and see if the Transformer can learn it.

In [None]:
# Each integer i in 0..19 is repeated i times (0 never appears, 19 appears 19 times), so the
# next token depends on how long the current run has lasted, not just on the current token.
seq = torch.repeat_interleave(torch.arange(20), torch.arange(20))
x = seq[None,:-1]
y = seq[None,1:]
mask = make_mask(seq_len=x.shape[1])
x,y

In [None]:
# The model trained above was built with max_seq_len=99, but this sequence has 189 tokens,
# so its positional embedding cannot index it. Train a fresh model sized for the new
# sequence (same hyperparameters as before), then evaluate it.
model = TransformerDecoder(vocabulary_size=100,max_seq_len=x.shape[1],
                           d_model=64,num_heads=8,d_ff=512,num_blocks=3
                           )
opt = torch.optim.Adam(model.parameters(),lr=.01)
loss_fn = nn.CrossEntropyLoss()

model.train()
for step in range(steps):
    opt.zero_grad()
    y_pred = model(x,mask)
    loss = loss_fn(y_pred.reshape(-1,y_pred.shape[-1]),y.reshape(-1))
    loss.backward()
    opt.step()
    if step % 100 == 0 or step == steps - 1:
        print(step,loss.item())

# Evaluate with the same causal mask used in training.
model.eval()
with torch.no_grad():
    pred = torch.argmax(model(x,mask),-1)
pred

Even with 1000 training iterations, the Transformer did not learn it.