#### 1 Causal Attention -

Causal attention only considers the previous and current inputs in a sequence when computing the attention scores for any given token. In contrast, standard self-attention lets every token attend to all tokens in the input sequence at once.

In [None]:
# nn.Linear uses an optimized weight initialization scheme,
# which leads to more effective and stable model training.
import torch
import torch.nn as nn

# Toy embeddings: one 3-dimensional row per token of the sentence
# "Your journey starts with one step".
inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],  # x^1: Your
        [0.55, 0.87, 0.66],  # x^2: journey
        [0.57, 0.85, 0.64],  # x^3: starts
        [0.22, 0.58, 0.33],  # x^4: with
        [0.77, 0.25, 0.10],  # x^5: one
        [0.05, 0.80, 0.55],  # x^6: step
    ]
)

class SelfAttention_v2(nn.Module):
    """Scaled dot-product self-attention using ``nn.Linear`` projections.

    Queries, keys and values are produced by three separate linear layers
    mapping ``d_in`` features to ``d_out`` features.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Size of the last dimension of the returned context vectors.
        qkv_bias: Whether the query/key/value linear layers use a bias term.
    """

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_q = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_k = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_v = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute context vectors for ``x``.

        Args:
            x: Tensor of shape ``(num_tokens, d_in)`` or, with any number of
                leading batch dimensions, ``(..., num_tokens, d_in)``.

        Returns:
            Tensor shaped like ``x`` with the last dimension replaced by
            ``d_out``.
        """
        queries = self.W_q(x)
        keys = self.W_k(x)
        values = self.W_v(x)

        # Swap only the last two dimensions: identical to ``keys.T`` for 2-D
        # input, but also correct when leading batch dimensions are present
        # (``.T`` would reverse *all* dimensions there).
        attn_scores = queries @ keys.transpose(-2, -1)
        # Scale by sqrt(d_out) before normalizing each row with softmax.
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1] ** 0.5, dim=-1
        )
        context_vec = attn_weights @ values
        return context_vec

# Fix the RNG seed before constructing the module so the randomly
# initialized nn.Linear weights are reproducible.
torch.manual_seed(123)
sa_v2 = SelfAttention_v2(d_in=3, d_out=2)
print(sa_v2(inputs))

tensor([[-0.5337, -0.1051],
        [-0.5323, -0.1080],
        [-0.5323, -0.1079],
        [-0.5297, -0.1076],
        [-0.5311, -0.1066],
        [-0.5299, -0.1081]], grad_fn=<MmBackward0>)


##### 1.1 Applying a Causal Attention Mask

Initially, it might appear that masking the attention weights and then renormalizing would still let the masked (future) tokens influence the current token, because their values are still part of the softmax calculation. But since we renormalize the masked attention weights, what we are essentially doing is recalculating the softmax over the smaller, unmasked subset (the elegance of softmax). Therefore, there is no information leakage.

In [15]:
# Step 1: unmasked attention weights, computed with the query and key
# projections of the existing SelfAttention_v2 instance.
queries, keys = sa_v2.W_q(inputs), sa_v2.W_k(inputs)
attention_scores = queries @ keys.T
attention_weights = (attention_scores / keys.shape[-1] ** 0.5).softmax(dim=-1)
print(attention_weights)
# context_vectors = attention_weights @ sa_v2.W_v(inputs)

# Step 2: lower-triangular matrix of ones (diagonal included); multiplying
# by it zeroes every weight above the diagonal.
context_length = attention_scores.size(0)
simple_mask = torch.ones(context_length, context_length).tril()
# print(simple_mask)

masked_simple = simple_mask * attention_weights
# print(masked_simple)

# Step 3: divide each row by its sum so the remaining weights add up to 1.
rows_sum = masked_simple.sum(dim=-1, keepdim=True)
masked_simple_norm = masked_simple.div(rows_sum)
# print(masked_simple_norm)

# Step 4: the more efficient variant. Fill the positions above the diagonal
# of the raw scores with -inf; softmax turns those entries into 0 and
# normalizes over the rest in a single pass.
mask = torch.ones(context_length, context_length).triu(diagonal=1)
masked = attention_scores.masked_fill(mask.bool(), float("-inf"))
attention_weights = (masked / keys.shape[-1] ** 0.5).softmax(dim=-1)
print(attention_weights)

tensor([[0.1717, 0.1762, 0.1761, 0.1555, 0.1627, 0.1579],
        [0.1636, 0.1749, 0.1746, 0.1612, 0.1605, 0.1652],
        [0.1637, 0.1749, 0.1746, 0.1611, 0.1606, 0.1651],
        [0.1636, 0.1704, 0.1702, 0.1652, 0.1632, 0.1674],
        [0.1667, 0.1722, 0.1721, 0.1618, 0.1633, 0.1639],
        [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],
       grad_fn=<SoftmaxBackward0>)
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4833, 0.5167, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3190, 0.3408, 0.3402, 0.0000, 0.0000, 0.0000],
        [0.2445, 0.2545, 0.2542, 0.2468, 0.0000, 0.0000],
        [0.1994, 0.2060, 0.2058, 0.1935, 0.1953, 0.0000],
        [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],
       grad_fn=<SoftmaxBackward0>)
