In [15]:
import torch

inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

## what is causal attention?

In causal attention, When attention scores are computed, the causal attention mechaism ensures that the scores are only computed for the tokens that are before the current token.
This is done by masking the scores of the tokens that are after the current token.

for example:

![](causal_attention.png)

In [16]:
keys_w = torch.nn.Parameter(torch.randn(3,2))
queries_w = torch.nn.Parameter(torch.randn(3,2))
values_w = torch.nn.Parameter(torch.randn(3,2))

In [17]:
keys = inputs @ keys_w
queries = inputs @ queries_w
values = inputs @ values_w

attention_scores = queries @ keys.T

attention_weights = torch.softmax(attention_scores / keys.shape[-1] ** 0.5, dim=-1)

In [18]:
context_length = attention_scores.shape[0]
mask_for_causal_attention = torch.tril(torch.ones(context_length,context_length))
print(mask_for_causal_attention)

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


In [19]:
causal_attention_weights = attention_weights * mask_for_causal_attention
print(causal_attention_weights)

tensor([[0.2584, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1425, 0.1701, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1440, 0.1696, 0.1703, 0.0000, 0.0000, 0.0000],
        [0.1358, 0.1749, 0.1756, 0.1688, 0.0000, 0.0000],
        [0.1806, 0.1590, 0.1590, 0.1684, 0.1655, 0.0000],
        [0.1225, 0.1796, 0.1805, 0.1688, 0.1877, 0.1610]],
       grad_fn=<MulBackward0>)


In [20]:
## now since every row need to sum to 1 , so just normalize the weights

row_sums = torch.sum(causal_attention_weights,dim=1,keepdim=True)

In [21]:
causal_attention_weights_norm = causal_attention_weights/row_sums

In [22]:
print(causal_attention_weights_norm)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4558, 0.5442, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2976, 0.3505, 0.3519, 0.0000, 0.0000, 0.0000],
        [0.2073, 0.2670, 0.2680, 0.2577, 0.0000, 0.0000],
        [0.2169, 0.1910, 0.1910, 0.2023, 0.1988, 0.0000],
        [0.1225, 0.1796, 0.1805, 0.1688, 0.1877, 0.1610]],
       grad_fn=<DivBackward0>)


In [23]:
## But the above method is a not fully correct because while when doing 
## softmax on the attention scores, we are also considering the values 
## of the tokens that are not attended to.

## hence a better approach is to first make those attention score to be -inf 
## and then do softmax , so that the normallized scores after softmax will be 
## are also not factoring in the values of the tokens that are not attended to.

In [24]:
batch = torch.stack((inputs,inputs),dim=0)
print(batch.shape)

torch.Size([2, 6, 3])


In [25]:
class CausalAttention(torch.nn.Module):
    def __init__(self,d_in,d_out,context_length,dropout,qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = torch.nn.Linear(d_in,d_out,bias=qkv_bias)
        self.W_key = torch.nn.Linear(d_in,d_out,bias=qkv_bias)
        self.W_value = torch.nn.Linear(d_in,d_out,bias=qkv_bias)
        self.W_out = torch.nn.Linear(d_out,d_out)
        self.dropout = torch.nn.Dropout(dropout)
        self.register_buffer('mask',torch.triu(torch.ones(context_length,context_length),diagonal=1))

    def forward(self,x):
        b, num_tokens, d_in = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attention_scores = queries @ keys.transpose(1,2)
        attention_scores.masked_fill(
            self.mask.bool()[:num_tokens,:num_tokens],-torch.inf)
        attention_weights = torch.softmax(
            attention_scores / keys.shape[-1] ** 0.5,dim=-1
        )
        attention_weights = self.dropout(attention_weights)

        context_vector = attention_weights @ values
        return context_vector

In [26]:
torch.manual_seed(123)

context_length = batch.shape[1]
ca = CausalAttention(3,2,context_length,0.0)

context_vector = ca(batch)
print(context_vector.shape)

torch.Size([2, 6, 2])


In [27]:
print(context_vector)

tensor([[[-0.5337, -0.1051],
         [-0.5323, -0.1080],
         [-0.5323, -0.1079],
         [-0.5297, -0.1076],
         [-0.5311, -0.1066],
         [-0.5299, -0.1081]],

        [[-0.5337, -0.1051],
         [-0.5323, -0.1080],
         [-0.5323, -0.1079],
         [-0.5297, -0.1076],
         [-0.5311, -0.1066],
         [-0.5299, -0.1081]]], grad_fn=<UnsafeViewBackward0>)
