# Task 1 (6p)
Your task is to modify the custom implementation of MultiHeadAttention. Currently, this implementation lets each token attend to every other token.


Your job is to change this behavior in a specific way.
Let $S$ be our input sequence of length $2 \cdot k$:
- tokens at positions $i \lt k$ should attend to the prefix of $S$ of length $k$ ($S[:k]$), i.e. every token at positions $0, \dots, k - 1$
- tokens at positions $i \ge k$ should attend to the prefix of $S$ of length $i + 1$ ($S[:i + 1]$), i.e. every previous token and the token itself

(Note: You can assume the sequence length is always an even number).

In [45]:
import torch
import math
import torch.nn.functional as F
class MultiHeadAttention(torch.nn.Module):
    """Multi-head self-attention with a "full prefix, then causal" mask.

    With ``k = seq_len // 2``:

    * a query at position ``i < k`` attends to keys ``0 .. k - 1``
      (the whole first half of the sequence);
    * a query at position ``i >= k`` attends to keys ``0 .. i``
      (every previous token and itself).
    """

    def __init__(self, d_model, num_heads, d_head):
        """Create the query/key/value/output projections.

        Args:
            d_model: Size of the last dimension of the input and the output.
            num_heads: Number of attention heads.
            d_head: Size of the query/key/value vectors of a single head.
        """
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_head

        self.W_Q = torch.nn.Linear(d_model, num_heads * d_head, bias=True)
        self.W_K = torch.nn.Linear(d_model, num_heads * d_head, bias=True)
        self.W_V = torch.nn.Linear(d_model, num_heads * d_head, bias=True)
        self.W_O = torch.nn.Linear(num_heads * d_head, d_model, bias=True)

    @staticmethod
    def _allowed_attention(seq_len, device):
        """Return a (seq_len, seq_len) bool mask, True where query i may see key j."""
        k = seq_len // 2
        positions = torch.arange(seq_len, device=device)
        query_pos = positions.unsqueeze(1)  # column vector: one row per query
        key_pos = positions.unsqueeze(0)    # row vector: one column per key
        # Keys in the first half are visible to every query; keys in the
        # second half are visible only causally (key index <= query index).
        return (key_pos < k) | (key_pos <= query_pos)

    def forward(self, x):
        """Apply masked multi-head self-attention.

        Args:
            x: Tensor of shape (seq_len, batch_size, d_model).

        Returns:
            A tuple ``(result, weights)`` where ``result`` has shape
            (seq_len, batch_size, d_model) and ``weights`` holds the attention
            probabilities with shape (batch_size, num_heads, seq_len, seq_len).
        """
        seq_len, batch_size, _ = x.shape
        head_shape = (seq_len, batch_size, self.num_heads, self.d_head)

        Q = self.W_Q(x).reshape(head_shape)
        K = self.W_K(x).reshape(head_shape)
        V = self.W_V(x).reshape(head_shape)

        # shape of scaled_QK is (batch_size, num_heads, seq_len, seq_len)
        scaled_QK = torch.einsum("ibhd,jbhd->bhij", Q, K) / math.sqrt(self.d_head)

        # The mask only depends on positions, so a single (seq_len, seq_len)
        # boolean tensor is built and broadcast over batch and heads instead
        # of materialising it at the full size of the score tensor.
        allowed = self._allowed_attention(seq_len, scaled_QK.device)
        scaled_QK = scaled_QK.masked_fill(~allowed, float("-inf"))

        weights = F.softmax(scaled_QK, dim=-1)
        attention = torch.einsum("bhij,jbhd->ibhd", weights, V)

        result = self.W_O(
            attention.reshape(seq_len, batch_size, self.num_heads * self.d_head)
        )

        return result, weights

In [46]:
# Smoke test: push a random sequence of length 2 * k through the attention
# layer and print both the output and the attention weights for inspection.
d_model, num_heads, d_head = 2, 2, 4
k = 3
batch_size = 2

mha = MultiHeadAttention(d_model, num_heads, d_head)
batched_x = torch.randn(2 * k, batch_size, d_model)
with torch.no_grad():  # inference only; no gradients are needed here
    result, weights = mha(batched_x)
print("Result:", result)
print("Weights:", weights)

Result: tensor([[[ 0.0372,  0.7503],
         [ 0.2375,  0.7533]],

        [[ 0.0163,  0.7479],
         [ 0.2390,  0.7526]],

        [[ 0.0351,  0.7500],
         [ 0.2403,  0.7547]],

        [[ 0.1086,  0.7195],
         [ 0.4531,  0.8649]],

        [[ 0.1411,  0.8073],
         [ 0.1710,  0.7565]],

        [[ 0.1228,  0.6960],
         [-0.1254,  0.5422]]])
Weights: tensor([[[[0.2962, 0.3586, 0.3451, 0.0000, 0.0000, 0.0000],
          [0.2755, 0.4078, 0.3167, 0.0000, 0.0000, 0.0000],
          [0.2939, 0.3625, 0.3437, 0.0000, 0.0000, 0.0000],
          [0.2130, 0.2352, 0.2527, 0.2992, 0.0000, 0.0000],
          [0.1813, 0.2649, 0.2027, 0.1828, 0.1683, 0.0000],
          [0.1318, 0.1333, 0.1593, 0.2053, 0.0462, 0.3242]],

         [[0.3326, 0.3210, 0.3464, 0.0000, 0.0000, 0.0000],
          [0.3081, 0.3626, 0.3293, 0.0000, 0.0000, 0.0000],
          [0.3307, 0.3259, 0.3433, 0.0000, 0.0000, 0.0000],
          [0.2483, 0.2201, 0.2521, 0.2795, 0.0000, 0.0000],
          [0.1960, 0.