<a href="https://colab.research.google.com/github/ArpitKadam/Attention-Is-All-You-Code/blob/main/LLM-from-Scratch/CHP_05_Causal_Attention_Mechanism.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **CAUSAL ATTENTION IMPLEMENTATION**

In [1]:
import torch

# Toy embeddings: one 3-dimensional vector per word of "Dream big and work for it".
# NOTE: the name `input` shadows Python's builtin of the same name; it is kept
# because later cells reference it.
input = torch.tensor(
    [[0.72, 0.45, 0.31],   ## Dream
     [0.75, 0.20, 0.55],   ## big
     [0.30, 0.80, 0.40],   ## and
     [0.85, 0.35, 0.60],   ## work
     [0.55, 0.15, 0.75],   ## for
     [0.20, 0.20, 0.85]]   ## it
)

# Word labels, in the same order as the rows of `input`.
words = ["Dream", "big", "and", "work", "for", "it"]

In [2]:
import torch.nn as nn

class SelfAttention_v2(nn.Module):
  """Single-head scaled dot-product self-attention with learned projections.

  Args:
    d_in: size of each input token embedding.
    d_out: size of the query/key/value projections (and of each output vector).
    qkv_bias: whether the three linear projections include a bias term.
  """

  def __init__(self, d_in, d_out, qkv_bias):
    super().__init__()
    # Creation order (query, key, value) is kept so seeded initialisation is
    # reproducible.
    self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

  def forward(self, x):
    """Return one context vector per input token.

    Args:
      x: tensor of shape (num_tokens, d_in) or (..., num_tokens, d_in).

    Returns:
      Tensor with the same leading dimensions as `x` and last dimension d_out.
    """
    keys = self.W_key(x)
    queries = self.W_query(x)
    values = self.W_value(x)

    # Swap only the last two axes: `keys.T` reverses *all* axes, which breaks
    # as soon as a batch dimension is present.
    attn_scores = torch.matmul(queries, keys.transpose(-2, -1))

    # Scale by sqrt(d_out) before softmax; each row sums to 1.
    attn_weights = torch.softmax(attn_scores / (keys.shape[-1] ** 0.5), dim=-1)

    context_vec = torch.matmul(attn_weights, values)

    return context_vec

In [3]:
# Fix the RNG seed so the randomly initialised projection weights are reproducible.
torch.manual_seed(100)

sa_v2 = SelfAttention_v2(3, 3, True)  # d_in=3, d_out=3, qkv_bias=True
# Project the toy embeddings with the module's own layers so the intermediate
# scores and weights can be inspected step by step.
queries = sa_v2.W_query(input)
keys = sa_v2.W_key(input)
values = sa_v2.W_value(input)

# Raw query-key dot products, then scaled softmax over each row.
attn_scores = torch.matmul(queries, keys.T)
attn_weights = torch.softmax(attn_scores / (keys.shape[-1] ** (0.5)), dim=-1)

print("Attention Scores:")
print(attn_scores)
print("Shape:", attn_scores.shape)
print()
print("Attention Weights:")
print(attn_weights)
print("Shape:", attn_weights.shape)

Attention Scores:
tensor([[-0.6868, -0.5927, -0.7045, -0.6571, -0.5185, -0.4539],
        [-0.7825, -0.6745, -0.7996, -0.7516, -0.5871, -0.5083],
        [-0.4230, -0.3595, -0.4378, -0.3796, -0.3166, -0.2920],
        [-0.7654, -0.6544, -0.7836, -0.7150, -0.5695, -0.5027],
        [-0.7309, -0.6282, -0.7459, -0.6976, -0.5454, -0.4725],
        [-0.5937, -0.5097, -0.6058, -0.5651, -0.4422, -0.3834]],
       grad_fn=<MmBackward0>)
Shape: torch.Size([6, 6])

Attention Weights:
tensor([[0.1585, 0.1674, 0.1569, 0.1612, 0.1747, 0.1813],
        [0.1571, 0.1673, 0.1556, 0.1600, 0.1759, 0.1841],
        [0.1614, 0.1674, 0.1600, 0.1655, 0.1716, 0.1741],
        [0.1570, 0.1674, 0.1554, 0.1617, 0.1758, 0.1827],
        [0.1576, 0.1672, 0.1562, 0.1606, 0.1754, 0.1829],
        [0.1592, 0.1672, 0.1581, 0.1619, 0.1738, 0.1798]],
       grad_fn=<SoftmaxBackward0>)
Shape: torch.Size([6, 6])


In [4]:
# Lower-triangular matrix of ones: row i keeps columns 0..i (the current and
# earlier tokens) and zeroes every later token. The size is taken from the
# attention weights instead of a hard-coded 6 so it follows the input length.
num_tokens = attn_weights.shape[0]
simple_mask = torch.tril(torch.ones(num_tokens, num_tokens))
print(simple_mask)

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


In [5]:
# Element-wise product with the 0/1 mask zeroes the weights above the diagonal.
masked_simple = attn_weights * simple_mask
print(masked_simple)

tensor([[0.1585, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1571, 0.1673, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1614, 0.1674, 0.1600, 0.0000, 0.0000, 0.0000],
        [0.1570, 0.1674, 0.1554, 0.1617, 0.0000, 0.0000],
        [0.1576, 0.1672, 0.1562, 0.1606, 0.1754, 0.0000],
        [0.1592, 0.1672, 0.1581, 0.1619, 0.1738, 0.1798]],
       grad_fn=<MulBackward0>)


In [6]:
# Sum of the surviving weights in each row; keepdim=True keeps a trailing axis
# of size 1 so the result broadcasts against masked_simple.
row_sum = masked_simple.sum(dim=1, keepdim=True)
print(row_sum)

tensor([[0.1585],
        [0.3244],
        [0.4888],
        [0.6415],
        [0.8171],
        [1.0000]], grad_fn=<SumBackward1>)


In [7]:
# Renormalise so that every masked row sums to 1 again.
masked_simple_norm = masked_simple / row_sum
print(masked_simple_norm)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4844, 0.5156, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3302, 0.3425, 0.3273, 0.0000, 0.0000, 0.0000],
        [0.2448, 0.2610, 0.2422, 0.2520, 0.0000, 0.0000],
        [0.1929, 0.2047, 0.1912, 0.1966, 0.2147, 0.0000],
        [0.1592, 0.1672, 0.1581, 0.1619, 0.1738, 0.1798]],
       grad_fn=<DivBackward0>)


The method above risks data leakage: softmax is applied over all tokens (including future ones) before the lower-triangular mask, so the masked weights must be renormalized afterwards to remove the influence of the future tokens.

Now we apply a more efficient method:

- Attention scores -> mask the entries above the diagonal with -inf -> softmax (attention weights)

In [8]:
# Show the unmasked attention scores again as the starting point for masking.
print("Attention Scores:")
print(attn_scores)
print("Shape:", attn_scores.shape)

Attention Scores:
tensor([[-0.6868, -0.5927, -0.7045, -0.6571, -0.5185, -0.4539],
        [-0.7825, -0.6745, -0.7996, -0.7516, -0.5871, -0.5083],
        [-0.4230, -0.3595, -0.4378, -0.3796, -0.3166, -0.2920],
        [-0.7654, -0.6544, -0.7836, -0.7150, -0.5695, -0.5027],
        [-0.7309, -0.6282, -0.7459, -0.6976, -0.5454, -0.4725],
        [-0.5937, -0.5097, -0.6058, -0.5651, -0.4422, -0.3834]],
       grad_fn=<MmBackward0>)
Shape: torch.Size([6, 6])


In [9]:
# Ones strictly above the diagonal (diagonal=1) mark the future positions that
# must be hidden. The size is taken from the attention scores instead of a
# hard-coded 6 so it follows the input length.
num_tokens = attn_scores.shape[0]
mask = torch.triu(torch.ones(num_tokens, num_tokens), diagonal=1)
print(mask)

tensor([[0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0.]])


In [10]:
# Boolean view of the mask: True marks the positions that will be filled.
mask.bool()

tensor([[False,  True,  True,  True,  True,  True],
        [False, False,  True,  True,  True,  True],
        [False, False, False,  True,  True,  True],
        [False, False, False, False,  True,  True],
        [False, False, False, False, False,  True],
        [False, False, False, False, False, False]])

In [11]:
# Replace every future-position score with -inf; masked_fill returns a new
# tensor and leaves attn_scores untouched.
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
print(masked)

tensor([[-0.6868,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-0.7825, -0.6745,    -inf,    -inf,    -inf,    -inf],
        [-0.4230, -0.3595, -0.4378,    -inf,    -inf,    -inf],
        [-0.7654, -0.6544, -0.7836, -0.7150,    -inf,    -inf],
        [-0.7309, -0.6282, -0.7459, -0.6976, -0.5454,    -inf],
        [-0.5937, -0.5097, -0.6058, -0.5651, -0.4422, -0.3834]],
       grad_fn=<MaskedFillBackward0>)


In [47]:
# Scaled softmax over the masked scores: the -inf entries become exactly 0, so
# each row is already normalised over the visible tokens only.
# NOTE: this rebinds attn_weights, replacing the unmasked weights computed earlier.
attn_weights = torch.softmax(masked / (keys.shape[-1] ** 0.5), dim=-1)
print(attn_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4844, 0.5156, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3302, 0.3425, 0.3273, 0.0000, 0.0000, 0.0000],
        [0.2448, 0.2610, 0.2422, 0.2520, 0.0000, 0.0000],
        [0.1929, 0.2047, 0.1912, 0.1966, 0.2147, 0.0000],
        [0.1592, 0.1672, 0.1581, 0.1619, 0.1738, 0.1798]],
       grad_fn=<SoftmaxBackward0>)


**Masking Additional Attention Weights with Dropout**

In [12]:
# Seed the RNG so the dropout pattern is reproducible.
torch.manual_seed(100)
# With p=0.5 each element is zeroed with probability 0.5 and the survivors are
# scaled by 1 / (1 - 0.5) = 2, which is why the output contains only 0s and 2s.
dropout = torch.nn.Dropout(0.5)
example = torch.ones(6, 6)
print(dropout(example))

tensor([[0., 2., 2., 2., 0., 0.],
        [0., 2., 0., 0., 2., 2.],
        [0., 0., 0., 2., 2., 2.],
        [0., 2., 2., 2., 0., 2.],
        [2., 0., 0., 0., 0., 2.],
        [2., 0., 2., 2., 0., 0.]])


In [13]:
# Same seed as the previous cell, so the same dropout pattern is applied.
torch.manual_seed(100)
dropout = torch.nn.Dropout(0.5)
# NOTE(review): the recorded output has non-zero entries above the diagonal, so
# it appears to come from the unmasked attn_weights of an earlier run; re-run
# the cells in order to confirm.
print(dropout(attn_weights))

tensor([[0.0000, 0.3347, 0.3138, 0.3225, 0.0000, 0.0000],
        [0.0000, 0.3345, 0.0000, 0.0000, 0.3518, 0.3682],
        [0.0000, 0.0000, 0.0000, 0.3310, 0.3432, 0.3481],
        [0.0000, 0.3348, 0.3107, 0.3233, 0.0000, 0.3655],
        [0.3152, 0.0000, 0.0000, 0.0000, 0.0000, 0.3659],
        [0.3185, 0.0000, 0.3163, 0.3238, 0.0000, 0.0000]],
       grad_fn=<MulBackward0>)


# **CAUSAL ATTENTION CLASS**

In [15]:
import torch.nn as nn

class CausalAttention(nn.Module):
  """Single-head causal (masked) scaled dot-product self-attention.

  Each token attends only to itself and to earlier tokens; scores for later
  positions are set to -inf before the softmax.

  Args:
    d_in: size of each input token embedding.
    d_out: size of the query/key/value projections (and of each output vector).
    context_length: side length of the stored causal mask.
    dropout: dropout probability applied to the attention weights.
    qkv_bias: whether the three linear projections include a bias term.
  """

  def __init__(self, d_in, d_out, context_length, dropout, qkv_bias):
    super().__init__()
    self.d_out = d_out
    self.context_length = context_length
    self.dropout = nn.Dropout(dropout)
    # Creation order (key, query, value) is kept so seeded initialisation is
    # reproducible.
    self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
    # Ones strictly above the diagonal mark the future positions. Stored as a
    # buffer so it follows the module across devices and is part of state_dict.
    self.register_buffer(
        "simple_mask",
        torch.triu(torch.ones(context_length, context_length), diagonal=1),
    )

  def forward(self, x):
    """Return one causally-masked context vector per input token.

    Args:
      x: tensor of shape (batch_size, num_tokens, d_in); an unbatched
        (num_tokens, d_in) tensor is accepted as well.

    Returns:
      Tensor with the same leading dimensions as `x` and last dimension d_out.
    """
    # Read the sequence length from the second-to-last axis so that both
    # batched and unbatched inputs work.
    num_tokens = x.shape[-2]

    keys = self.W_key(x)
    queries = self.W_query(x)
    values = self.W_value(x)

    attn_scores = torch.matmul(queries, keys.transpose(-2, -1))

    # Slice first, then convert: only the part of the mask that is needed is
    # turned into a boolean tensor.
    future_positions = self.simple_mask[:num_tokens, :num_tokens].bool()
    attn_scores.masked_fill_(future_positions, -torch.inf)

    # The -inf entries become exactly 0 after softmax, so each row is
    # normalised over the visible tokens only.
    attn_weights = torch.softmax(attn_scores / (keys.shape[-1] ** 0.5), dim=-1)

    attn_weights = self.dropout(attn_weights)

    context_vec = torch.matmul(attn_weights, values)

    return context_vec

In [16]:
# Seed the RNG; the CausalAttention layer built in the next cell draws its
# initial weights after this call.
torch.manual_seed(123)

# Six tokens, each a 3-dimensional embedding.
inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

# Duplicate the sequence to form a batch of two identical samples.
batch = torch.stack([inputs, inputs], dim=0)
print(batch.shape)

# Hyperparameters for the CausalAttention layer.
d_in = batch.shape[-1]            # embedding size of each token
d_out = 2
context_length = batch.shape[1]   # number of tokens per sequence
# NOTE: rebinds `dropout`, which earlier cells used for an nn.Dropout module.
dropout = 0.0
qkv_bias = False

torch.Size([2, 6, 3])


In [17]:
# Build the layer from the hyperparameters above and run the batch through it.
causal_attn = CausalAttention(d_in, d_out, context_length, dropout, qkv_bias)
context_vector = causal_attn(batch)

In [18]:
# Inspect the resulting context vectors and their shape.
print("Context Vector:")
print(context_vector)
print("Shape:", context_vector.shape)

Context Vector:
tensor([[[-0.4519,  0.2216],
         [-0.5790,  0.0192],
         [-0.6226, -0.0512],
         [-0.5669, -0.0793],
         [-0.5501, -0.0919],
         [-0.5307, -0.1042]],

        [[-0.4519,  0.2216],
         [-0.5790,  0.0192],
         [-0.6226, -0.0512],
         [-0.5669, -0.0793],
         [-0.5501, -0.0919],
         [-0.5307, -0.1042]]], grad_fn=<UnsafeViewBackward0>)
Shape: torch.Size([2, 6, 2])


Using `register_buffer` in PyTorch is not strictly necessary for every use case, but it offers several advantages here.

For instance, when we use the `CausalAttention` class in our LLM, buffers are automatically moved to the appropriate device (CPU or GPU) along with the model, which will be relevant when training the LLM in future chapters.

This means we don't need to manually ensure that these tensors are on the same device as the model parameters, avoiding device-mismatch errors.