# Causal Self Attention

In [1]:
import torch
from lib.attention import VanillaSelfAttention

In [2]:
inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

In [3]:
torch.manual_seed(789)
vanilla_self_attention = VanillaSelfAttention(d_in=inputs.shape[-1], d_out=inputs.shape[-1] - 1)

queries = vanilla_self_attention.W_q(inputs)
keys = vanilla_self_attention.W_k(inputs)

attn_scores = queries @ keys.T
attn_wts = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)

In [4]:
# Applying Simple Causal Mask, i.e. masking with 0s

causal_mask_simple = torch.tril(torch.ones(attn_wts.shape))
simple_masked_attn_wts = attn_wts * causal_mask_simple

simple_masked_attn_wts_norm = simple_masked_attn_wts / simple_masked_attn_wts.sum(dim=-1, keepdim=True)
print(simple_masked_attn_wts_norm)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<DivBackward0>)


In [5]:
# Applying Efficient Causal Mask, i.e. masking with -∞

causal_mask_efficient = torch.triu(torch.ones(attn_scores.shape), diagonal=1)
causal_masked_attn_scores = attn_scores.masked_fill(causal_mask_efficient.bool(), -torch.inf)

causal_masked_attn_wts = torch.softmax(causal_masked_attn_scores / keys.shape[-1] ** 0.5, dim=-1)
print(causal_masked_attn_wts)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)


In [6]:
# Masking additional attention weights with dropout
torch.manual_seed(123)
dropout_layer = torch.nn.Dropout(0.5)

# Dropout illustration
dropout_layer(torch.ones(attn_wts.shape))

tensor([[2., 2., 0., 2., 2., 0.],
        [0., 0., 0., 2., 0., 2.],
        [2., 2., 2., 2., 0., 2.],
        [0., 2., 2., 0., 0., 2.],
        [0., 2., 0., 2., 0., 2.],
        [0., 2., 2., 2., 2., 0.]])

In [7]:
torch.manual_seed(123)
dropout_layer(causal_masked_attn_wts)

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.7599, 0.6194, 0.6206, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4921, 0.4925, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3966, 0.0000, 0.3775, 0.0000, 0.0000],
        [0.0000, 0.3327, 0.3331, 0.3084, 0.3331, 0.0000]],
       grad_fn=<MulBackward0>)

In [8]:
import torch.nn as nn

class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_len, dropout, qkv_bias=False):
        super().__init__()
        self.W_q = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_k = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_v = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout_layer = nn.Dropout(p=dropout)
        self.register_buffer(
            'mask', torch.triu(torch.ones(context_len, context_len), diagonal=1))
    
    def forward(self, x):
        batch_size, num_tokens, emb_dim = x.shape
        queries = self.W_q(x)
        keys = self.W_k(x)
        values = self.W_v(x)

        attn_scores = queries @ keys.transpose(1, 2)
        attn_scores.masked_fill(
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        attn_wts = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        attn_wts = self.dropout_layer(attn_wts)
        context_vecs = attn_wts @ values
        return context_vecs

In [9]:
torch.manual_seed(123)

single_batch_inputs = torch.stack((inputs, inputs))
causal_attn_layer = CausalAttention(
    d_in=single_batch_inputs.shape[-1], d_out=single_batch_inputs.shape[-1] - 1, 
    context_len=single_batch_inputs.shape[1], dropout=0.0)

context_vecs = causal_attn_layer(single_batch_inputs)
print(context_vecs)

tensor([[[-0.5337, -0.1051],
         [-0.5323, -0.1080],
         [-0.5323, -0.1079],
         [-0.5297, -0.1076],
         [-0.5311, -0.1066],
         [-0.5299, -0.1081]],

        [[-0.5337, -0.1051],
         [-0.5323, -0.1080],
         [-0.5323, -0.1079],
         [-0.5297, -0.1076],
         [-0.5311, -0.1066],
         [-0.5299, -0.1081]]], grad_fn=<UnsafeViewBackward0>)
