### Causal self-attention (a.k.a. masked self-attention)

In [1]:
import torch
import torch.nn as nn

In [2]:
# Seed torch's global RNG so every torch.randn call in this notebook is reproducible.
torch.manual_seed(seed=123)

<torch._C.Generator at 0x11204fd70>

In [3]:
# Input feature size per token.
d_in = 3
# Projection size shared by queries and keys (they must match for the dot product).
d_out_kq = 2
# Projection size for values.
d_out_v = 4

In [4]:
# Trainable projection matrices; tuple elements are evaluated left to right,
# so the RNG draws happen in query -> key -> value order.
w_query, w_key, w_value = (
    nn.Parameter(torch.randn(d_in, d_out_kq)),
    nn.Parameter(torch.randn(d_in, d_out_kq)),
    nn.Parameter(torch.randn(d_in, d_out_v)),
)

In [5]:
# Random input sequence: 6 rows, each with d_in features.
x = torch.randn(6, d_in)

In [6]:
# Sanity check the input: (number of rows, d_in).
x.size()

torch.Size([6, 3])

In [7]:
# Project the inputs: queries and keys get d_out_kq columns, values get d_out_v columns.
queries, keys, values = (torch.matmul(x, w) for w in (w_query, w_key, w_value))

In [8]:
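# Masking procedure 1: mask AFTER softmax, then renormalize
# 1. Calculate attention scores (unnormalized)
# 2. Scale the scores by sqrt(d_out_kq) and apply softmax so each row becomes a probability distribution
# 3. Create a lower-triangular mask (0's above the diagonal, 1's on and below it)
# 4. Multiply the attention weights by the mask to zero out future tokens
# 5. Renormalize each row so it sums to 1 again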

In [9]:
# Unnormalized scores: entry (i, j) is the dot product of query row i with key row j.
attn_scores = torch.matmul(queries, keys.transpose(0, 1))

In [10]:
# Scaled dot-product attention: divide by sqrt(d_out_kq), then softmax across each row.
attn_weights = (attn_scores / d_out_kq ** 0.5).softmax(dim=-1)

In [11]:
# Mask out future tokens by multiplying the attention weight matrix by a lower-triangular mask

In [12]:
# Sequence length: number of rows in the (square) attention weight matrix.
context_size = attn_weights.size(0)

In [13]:
# Lower-triangular ones: position (i, j) is kept only when j <= i.
mask = torch.ones(context_size, context_size).tril()

In [14]:
# Display the causal mask: row i has 1's in columns 0..i and 0's to the right of the diagonal.
mask

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])

In [17]:
# Display the unmasked weights: each row sums to 1 (softmax over dim=-1),
# but every row still puts weight on positions to the right of the diagonal (future tokens).
attn_weights

tensor([[0.2843, 0.1406, 0.0708, 0.1859, 0.2022, 0.1162],
        [0.4084, 0.0497, 0.0132, 0.2621, 0.1949, 0.0716],
        [0.2651, 0.1697, 0.0899, 0.1612, 0.1978, 0.1162],
        [0.0861, 0.1763, 0.3177, 0.1171, 0.1168, 0.1859],
        [0.2457, 0.1271, 0.0828, 0.2106, 0.1939, 0.1400],
        [0.1144, 0.1889, 0.2594, 0.1273, 0.1365, 0.1735]],
       grad_fn=<SoftmaxBackward0>)

In [23]:
# Element-wise product with the 0/1 mask zeroes every weight above the diagonal.
masked_attn_weights = torch.mul(attn_weights, mask)

In [27]:
# After masking, rows no longer sum to 1, so divide each row by its own (kept) total.
norm_attn_w = masked_attn_weights.div(masked_attn_weights.sum(dim=-1, keepdim=True))

In [30]:
# Display the renormalized causal weights: zeros above the diagonal, each row sums to 1.
norm_attn_w

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.8914, 0.1086, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5052, 0.3234, 0.1713, 0.0000, 0.0000, 0.0000],
        [0.1235, 0.2529, 0.4556, 0.1680, 0.0000, 0.0000],
        [0.2857, 0.1478, 0.0963, 0.2448, 0.2255, 0.0000],
        [0.1144, 0.1889, 0.2594, 0.1273, 0.1365, 0.1735]],
       grad_fn=<DivBackward0>)

In [None]:
# Alternative masking procedure: mask BEFORE softmax
# 1. Calculate attention scores (unnormalized)
# 2. Create an upper-triangular mask (1's strictly above the diagonal, 0's on and below it)
# 3. Use masked_fill to set the scores to -torch.inf wherever the mask is True
# 4. Scale by sqrt(d_out_kq) and apply softmax; exp(-inf) = 0, so future tokens get zero weight
#    and each row already sums to 1 (no renormalization step needed)

In [44]:
# Ones strictly above the diagonal (diagonal=1 excludes the diagonal itself):
# these mark the future positions to hide.
new_mask = torch.ones(context_size, context_size).triu(diagonal=1)

In [45]:
# Display the upper-triangular mask: 1 marks a position above the diagonal that will be filled with -inf.
new_mask

tensor([[0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0.]])

In [47]:
# Overwrite scores at future positions with -inf so softmax assigns them exactly zero weight.
masked = attn_scores.masked_fill(new_mask.to(torch.bool), float("-inf"))

In [49]:
# Scale by sqrt(d_out_kq) after masking: -inf divided by a positive number is still -inf,
# so masked positions get exactly zero probability and each row sums to 1 without renormalizing.
causal_attn_weights = torch.softmax(masked / d_out_kq ** 0.5, dim=-1)

# Both masking procedures compute exp(s_ij) / sum_{k<=i} exp(s_ik), so they must agree
# (up to floating-point rounding) with the mask-then-renormalize result above.
assert torch.allclose(causal_attn_weights, norm_attn_w), "masking procedures disagree"

causal_attn_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.8914, 0.1086, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5052, 0.3234, 0.1713, 0.0000, 0.0000, 0.0000],
        [0.1235, 0.2529, 0.4556, 0.1680, 0.0000, 0.0000],
        [0.2857, 0.1478, 0.0963, 0.2448, 0.2255, 0.0000],
        [0.1144, 0.1889, 0.2594, 0.1273, 0.1365, 0.1735]],
       grad_fn=<SoftmaxBackward0>)