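1. Causal attention is also known as masked attention.
2. It is called masked because the future tokens are masked out: each token can attend only to itself and the tokens before it, and never interacts with the tokens ahead of it.
3. A triangular matrix is used to build the mask. The lower triangular part (diagonal included) marks the positions a token may attend to; the part strictly above the diagonal marks the future positions to mask out.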

lower triangular matrix = [[1,0,0,0],
                           [1,1,0,0],
                           [1,1,1,0],
                           [1,1,1,1]]

upper triangular matrix = [[1,1,1,1],
                           [0,1,1,1],
                           [0,0,1,1],
                           [0,0,0,1]]
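
A minimal sketch of how these two matrices can be built with torch.tril and torch.triu. The size n = 4 is only chosen to match the matrices above, and future_mask is an illustrative name: it uses diagonal=1 to drop the diagonal, which is the form needed for causal masking.

```python
import torch

n = 4  # sequence length, chosen to match the 4x4 matrices above

lower = torch.tril(torch.ones(n, n))  # ones on and below the diagonal
upper = torch.triu(torch.ones(n, n))  # ones on and above the diagonal

# For causal masking only the strictly-future positions should be marked,
# so the diagonal is excluded with diagonal=1.
future_mask = torch.triu(torch.ones(n, n), diagonal=1).bool()

print(lower)
print(upper)
print(future_mask)
```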

Steps to perform causal attention:
1. Create input embedding matrix for tokens
2. Create the Weight_q, Weight_k and Weight_v matrices
3. Generate q, k, v matrices
   q = Input @ Weight_q
   k = Input @ Weight_k
   v = Input @ Weight_v
4. Generate attention scores
   attention_scores = q @ k.T
5. Create an upper triangular mask with 1s strictly above the diagonal (the diagonal itself stays 0, so each token can still attend to itself)
6. Wherever the mask is 1, fill the corresponding position in the attention_scores matrix with -inf.
7. After this, the attention_scores matrix will look like
   attention_scores = [[1, -inf, -inf, -inf],
                      [1, 2, -inf, -inf],
                      [1, 2, 3, -inf],
                      [1, 2, 3, 4]]
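8. Normalise each row with softmax:
   attention_weights = torch.softmax(attention_scores, dim=-1)
   The -inf entries become 0, so every row sums to 1 over the visible tokens only.
9. Optionally apply dropout to the attention weights (a rate of around 10% is a common choice) to reduce overfitting
10. Generate the context vectors:
    context_vector = attention_weights @ v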
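
The steps above can be sketched end to end with plain tensors. This is only an illustration under assumptions: the sizes, the random inputs and the random weight matrices are made up, and the names follow the steps rather than any particular library. It follows the steps as listed, so the scores are not divided by the square root of the key dimension, which many implementations do before the softmax.

```python
import torch

torch.manual_seed(0)  # only so the sketch is repeatable

num_tokens, d_in, d_out = 4, 3, 2   # illustrative sizes (assumed)
x = torch.rand(num_tokens, d_in)    # step 1: stand-in input embeddings

# step 2: weight matrices (random here, learned in a real model)
W_q = torch.rand(d_in, d_out)
W_k = torch.rand(d_in, d_out)
W_v = torch.rand(d_in, d_out)

# step 3: queries, keys, values
q, k, v = x @ W_q, x @ W_k, x @ W_v

# step 4: raw attention scores, shape (num_tokens, num_tokens)
attention_scores = q @ k.T

# steps 5-6: mark the strictly-future positions and fill them with -inf
mask = torch.triu(torch.ones(num_tokens, num_tokens), diagonal=1).bool()
masked_scores = attention_scores.masked_fill(mask, float("-inf"))

# step 8: row-wise softmax; masked positions get weight 0
attention_weights = torch.softmax(masked_scores, dim=-1)

# step 9: dropout on the attention weights (active in training mode only)
dropout = torch.nn.Dropout(p=0.1)
dropped_weights = dropout(attention_weights)

# step 10: context vectors, shape (num_tokens, d_out)
context_vector = dropped_weights @ v
print(context_vector.shape)
```

Because dropout zeroes entries at random and rescales the rest by 1 / (1 - p), the rows of dropped_weights no longer sum exactly to 1 during training; the rows of attention_weights do.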



In [1]:
import torch
import torch.nn as nn

In [2]:
inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

d_in = inputs.shape[1]
d_out = 2

# The projections Wq, Wk, Wv map d_in features to d_out features.
# inputs has shape (6, 3), so d_in = inputs.shape[-1] = 3.
# Note: nn.Linear stores its weight with shape (d_out, d_in).

In [10]:
class CausalAttention(nn.Module):
  def __init__(self):
    super().__init__()
    self.Wq = nn.Linear(d_in, d_out, bias=False)
    self.Wk = nn.Linear(d_in, d_out, bias=False)
    self.Wv = nn.Linear(d_in, d_out, bias=False)


  def forward(self, x):
    # Project the inputs to queries, keys and values
    q = self.Wq(x)
    k = self.Wk(x)
    v = self.Wv(x)

    # Swap the last two dimensions of k. transpose(-2, -1) is the same as
    # transpose(-1, -2) and also works when x has a leading batch dimension.
    attention_scores = q @ k.transpose(-2, -1)

    # True strictly above the diagonal marks the future positions
    num_tokens = attention_scores.shape[-1]
    mask = torch.triu(
      torch.ones(num_tokens, num_tokens, device=x.device), diagonal=1
    ).bool()

    # Future positions get -inf so that softmax gives them zero weight
    attention_scores = attention_scores.masked_fill(mask, float("-inf"))
    attention_weights = torch.softmax(attention_scores, dim=-1)

    self.context_vectors = attention_weights @ v

    return self.context_vectors

In [11]:
context_vectors = CausalAttention()(inputs)
print(context_vectors)

tensor([[-0.0967, -0.2997],
        [-0.1710, -0.2990],
        [-0.1895, -0.3022],
        [-0.1898, -0.2533],
        [-0.1176, -0.2862],
        [-0.1728, -0.2327]], grad_fn=<MmBackward0>)
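
One way to check that the masking really is causal: change only the last token and confirm that the context vectors of all earlier tokens stay the same. The sketch below is self-contained, so it uses a small hypothetical helper, causal_attention, with random weights instead of the class above; the values are made up.

```python
import torch

def causal_attention(x, W_q, W_k, W_v):
    """Hypothetical helper: unscaled causal self-attention on a (tokens, d_in) tensor."""
    q, k, v = x @ W_q, x @ W_k, x @ W_v
    scores = q @ k.transpose(-2, -1)
    n = scores.shape[-1]
    mask = torch.triu(torch.ones(n, n), diagonal=1).bool()
    weights = torch.softmax(scores.masked_fill(mask, float("-inf")), dim=-1)
    return weights @ v

torch.manual_seed(0)
x = torch.rand(6, 3)                                  # 6 tokens, d_in = 3
W_q, W_k, W_v = (torch.rand(3, 2) for _ in range(3))  # d_out = 2

out = causal_attention(x, W_q, W_k, W_v)

x_changed = x.clone()
x_changed[-1] += 1.0                                  # alter only the last token
out_changed = causal_attention(x_changed, W_q, W_k, W_v)

# Earlier tokens never see the last token, so their outputs are unchanged
earlier_unchanged = torch.allclose(out[:-1], out_changed[:-1])
print(earlier_unchanged)
```

Only the last row of the output differs between the two runs, since the last token is the only one allowed to attend to itself at that position.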


In [12]:
mask = torch.triu(torch.ones(3,3), diagonal=1).bool()
mask

tensor([[False,  True,  True],
        [False, False,  True],
        [False, False, False]])
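
A small worked example of what this mask does once it is applied. The score values are made up for illustration; every True position is filled with -inf, and softmax then turns those positions into exact zeros.

```python
import torch

# Made-up attention scores for 3 tokens
scores = torch.tensor([[1.0, 2.0, 3.0],
                       [1.0, 2.0, 3.0],
                       [1.0, 2.0, 3.0]])

mask = torch.triu(torch.ones(3, 3), diagonal=1).bool()

# Future positions (True in the mask) are replaced by -inf
masked = scores.masked_fill(mask, float("-inf"))

# Row-wise softmax: each row sums to 1 over the visible tokens only
weights = torch.softmax(masked, dim=-1)
print(weights)
# approximately:
# [[1.0000, 0.0000, 0.0000],
#  [0.2689, 0.7311, 0.0000],
#  [0.0900, 0.2447, 0.6652]]
```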