<a href="https://colab.research.google.com/github/TanviSree/22b1050_llm_from_scratch/blob/main/chapter3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Simple Self-Attention Mechanism (without trainable weights)

In [None]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# Seed the RNG so the toy embeddings, and every number derived from them below,
# are reproducible across runs.
torch.manual_seed(123)
# Five toy tokens, each a 3-dimensional embedding drawn uniformly from [0, 1).
embeddings = torch.rand(5, 3)
print(embeddings)

tensor([[0.5771, 0.5235, 0.1928],
        [0.9123, 0.0096, 0.7988],
        [0.0353, 0.4856, 0.8558],
        [0.5318, 0.3742, 0.4494],
        [0.8334, 0.3342, 0.0234]])


In [None]:
# Attention weights without trainable parameters: score every pair of tokens by
# their dot product, then divide each row by its total so the row sums to 1.
# This is plain sum-normalization (not softmax); it presumes the scores are
# positive, which holds for non-negative torch.rand embeddings.
attention_weights = []
for query in embeddings:
    scores = [torch.dot(query, key) for key in embeddings]
    # Named `row_total` (not `sum`) so the builtin sum() is not shadowed for
    # the rest of the notebook.
    row_total = sum(scores)
    attention_weights.append([score / row_total for score in scores])
print(attention_weights)

[[tensor(0.2134), tensor(0.2271), tensor(0.1456), tensor(0.1952), tensor(0.2187)], [tensor(0.1521), tensor(0.3263), tensor(0.1599), tensor(0.1881), tensor(0.1736)], [tensor(0.1502), tensor(0.2462), tensor(0.3313), tensor(0.1999), tensor(0.0724)], [tensor(0.1827), tensor(0.2628), tensor(0.1814), tensor(0.1937), tensor(0.1794)], [tensor(0.2173), tensor(0.2573), tensor(0.0697), tensor(0.1904), tensor(0.2654)]]


In [None]:
# Context vector for each token: the attention-weighted sum of all input
# embeddings, computed coordinate by coordinate for any embedding size.
context_vectors = []
num_tokens, embed_dim = embeddings.shape
for weights_i in attention_weights:
    # One accumulator per embedding coordinate.
    context_i = [0] * embed_dim
    for j in range(num_tokens):
        for d in range(embed_dim):
            context_i[d] += weights_i[j] * embeddings[j][d]
    context_vectors.append(context_i)

In [None]:
# Turn each context vector (a list of scalar tensors) into a NumPy row and
# stack the rows into one 2-D array for display.
context_vectors_np = np.stack(
    [torch.tensor(row).numpy() for row in context_vectors]
)
print(context_vectors_np)

[[0.6215781  0.33076474 0.43997198]
 [0.63585305 0.28883284 0.51541305]
 [0.48968503 0.34087592 0.6006893 ]
 [0.6041725  0.31871372 0.4916104 ]
 [0.6850413  0.3099857  0.39881623]]


# Self-Attention Mechanism with Trainable Weights

In [None]:
torch.manual_seed(124)

# Fixed (requires_grad=False) 3 -> 2 projection matrices for queries, keys and
# values. They are drawn in query, key, value order, so each matrix receives
# the same seeded random values.
W_query, W_key, W_value = (
    torch.nn.Parameter(torch.rand(3, 2), requires_grad=False) for _ in range(3)
)

In [None]:
# Scaled dot-product attention weights of the token at index 1 over all tokens.
keys = torch.matmul(embeddings, W_key)
values = torch.matmul(embeddings, W_value)
x_1 = embeddings[1]
query_1 = torch.matmul(x_1, W_query)
d_k = keys.shape[-1]
attention_scores = torch.matmul(query_1, keys.T)
attn_weights_1 = torch.softmax(attention_scores / d_k ** 0.5, dim=-1)
print(attn_weights_1)

tensor([0.1956, 0.2221, 0.1918, 0.1974, 0.1930])


In [None]:
# Context vector for token 1: attention-weighted combination of the value rows.
context_vector1 = torch.matmul(attn_weights_1, values)
print(context_vector1)

tensor([0.7075, 0.9675])


Wrapping the same computation in a compact Python class:

In [None]:
import torch.nn as nn
class selfAttention(nn.Module):
    """Single-head scaled dot-product self-attention with trainable weights.

    Args:
        d_in: size of each input embedding.
        d_out: size of the query/key/value projections and of the output.

    The projection matrices are initialized with ``torch.rand``.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key   = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        """Return context vectors for ``x``.

        ``x`` has shape (num_tokens, d_in) or (..., num_tokens, d_in); the
        result has shape (..., num_tokens, d_out).
        """
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value

        # Swap only the last two axes. `.T` reverses *all* axes, which is
        # deprecated for non-2-D tensors and breaks batched inputs.
        attn_scores = queries @ keys.transpose(-2, -1)  # omega
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )

        context_vec = attn_weights @ values
        return context_vec

In [None]:
# Build a 3 -> 2 self-attention layer and apply it to the toy embeddings.
self1 = selfAttention(d_in=3, d_out=2)
print(self1(embeddings))

tensor([[0.9564, 0.4195],
        [0.9618, 0.4244],
        [0.9557, 0.4190],
        [0.9572, 0.4203],
        [0.9568, 0.4200]], grad_fn=<MmBackward0>)


# Causal Attention (Masked Attention)
We mask out the tokens that come after the current query token, because they must not contribute to its context vector.

In [None]:
# Recompute the unmasked, scaled attention weights of `self1` so they can be
# masked in the cells below.
keys = torch.matmul(embeddings, self1.W_key)
queries = torch.matmul(embeddings, self1.W_query)
attention_scores = torch.matmul(queries, keys.transpose(0, 1))
attn_wts = torch.softmax(attention_scores / keys.shape[-1] ** 0.5, dim=-1)
print(attn_wts)

tensor([[0.1984, 0.2292, 0.1791, 0.1960, 0.1973],
        [0.1985, 0.2444, 0.1663, 0.1933, 0.1975],
        [0.1988, 0.2267, 0.1803, 0.1963, 0.1979],
        [0.1986, 0.2310, 0.1771, 0.1956, 0.1976],
        [0.1986, 0.2302, 0.1779, 0.1958, 0.1976]], grad_fn=<SoftmaxBackward0>)


In [None]:
# Lower-triangular 0/1 mask: row i keeps columns 0..i, i.e. the current token
# and the ones before it.
context_length = attention_scores.shape[0]
mask_simple = torch.ones(context_length, context_length).tril()
print(mask_simple)

tensor([[1., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0.],
        [1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1.]])


In [None]:
# Causal masking: every score strictly above the diagonal (a later token) is
# replaced by -inf. Softmax then gives it zero weight, and each row
# renormalizes over the current and earlier tokens only.
context_length = attention_scores.shape[0]
mask = torch.ones(context_length, context_length).triu(diagonal=1)
masked = attention_scores.masked_fill(mask == 1, float("-inf"))
attn_wts = torch.softmax(masked / keys.shape[-1] ** 0.5, dim=-1)
print(attn_wts)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4481, 0.5519, 0.0000, 0.0000, 0.0000],
        [0.3281, 0.3742, 0.2977, 0.0000, 0.0000],
        [0.2475, 0.2879, 0.2207, 0.2438, 0.0000],
        [0.1986, 0.2302, 0.1779, 0.1958, 0.1976]], grad_fn=<SoftmaxBackward0>)


Applying dropout with $p = 0.5$: each weight is zeroed with probability 0.5, and the surviving weights are scaled by $1/(1-p) = 2$.

In [None]:
torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5)
example = torch.ones(5,5)
# Dropout on the causal attention weights: each entry is zeroed with
# probability 0.5 and the kept entries are multiplied by 1 / (1 - 0.5) = 2.
print(dropout(attn_wts))
# On an all-ones matrix the rescaling is easy to see: kept entries become 2.
print(dropout(example))

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.8962, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.5758, 0.0000, 0.4877, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.3916, 0.3952]], grad_fn=<MulBackward0>)


In [None]:
# A batch of two identical sequences, stacked along a new leading axis.
batch = torch.stack([embeddings] * 2)
print(batch.shape)

torch.Size([2, 5, 3])


In [None]:
# Input embedding size and per-head output size used by the attention demos below.
d_in, d_out = 3, 2
class CausalAttention(nn.Module):
    """Single-head causal (masked) self-attention with dropout on the weights.

    Args:
        d_in: size of each input embedding.
        d_out: size of the query/key/value projections and of the output.
        context_length: largest number of tokens the causal mask covers.
        dropout: dropout probability applied to the attention weights.
        qkv_bias: whether the query/key/value projections have a bias term.
    """

    def __init__(self, d_in, d_out, context_length,
                 dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # 1s strictly above the diagonal mark "future" positions to hide.
        # Registered as a buffer, not as a trainable parameter.
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        """Return context vectors of shape (..., num_tokens, d_out).

        ``x`` is expected as (batch, num_tokens, d_in). Any number of leading
        dimensions, including none, is accepted. ``num_tokens`` must not
        exceed ``context_length``: the sliced mask would then be smaller than
        the score matrix and the masked fill raises an error.
        """
        num_tokens = x.shape[-2]
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # Transpose only the last two axes so any leading batch axes are untouched.
        attn_scores = queries @ keys.transpose(-2, -1)
        # `:num_tokens` takes the top-left corner of the mask, so sequences
        # shorter than context_length are handled; `_` = in-place fill.
        attn_scores.masked_fill_(
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        attn_weights = self.dropout(attn_weights)

        context_vec = attn_weights @ values
        return context_vec

torch.manual_seed(123)

# Causal attention over the two-sequence batch; dropout 0.0 keeps the output
# deterministic.
context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, dropout=0.0)
context_vecs = ca(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

tensor([[[-0.5071, -0.2232],
         [-0.5680,  0.0025],
         [-0.5007,  0.0285],
         [-0.4972,  0.0076],
         [-0.5031, -0.0410]],

        [[-0.5071, -0.2232],
         [-0.5680,  0.0025],
         [-0.5007,  0.0285],
         [-0.4972,  0.0076],
         [-0.5031, -0.0410]]], grad_fn=<UnsafeViewBackward0>)
context_vecs.shape: torch.Size([2, 5, 2])


# Multi-Head Attention


In [None]:
class MultiHeadAttentionWrapper(nn.Module):
    """Multi-head attention built from independent ``CausalAttention`` heads.

    Each head maps ``d_in`` -> ``d_out``. Head outputs are concatenated along
    the last axis, giving ``num_heads * d_out`` features per token.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        heads = []
        for _ in range(num_heads):
            heads.append(
                CausalAttention(d_in, d_out, context_length, dropout, qkv_bias)
            )
        self.heads = nn.ModuleList(heads)

    def forward(self, x):
        """Run every head on ``x`` and concatenate the results feature-wise."""
        head_outputs = [head(x) for head in self.heads]
        return torch.cat(head_outputs, dim=-1)


torch.manual_seed(123)

# Two heads of width d_out = 2 each, so every token gets 4 output features.
context_length = batch.shape[1]  # number of tokens per sequence
d_in, d_out = 3, 2
mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, dropout=0.0, num_heads=2)
context_vecs = mha(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

tensor([[[-0.5071, -0.2232,  0.4773,  0.3738],
         [-0.5680,  0.0025,  0.5560,  0.2784],
         [-0.5007,  0.0285,  0.4961,  0.2491],
         [-0.4972,  0.0076,  0.4931,  0.2512],
         [-0.5031, -0.0410,  0.4975,  0.2700]],

        [[-0.5071, -0.2232,  0.4773,  0.3738],
         [-0.5680,  0.0025,  0.5560,  0.2784],
         [-0.5007,  0.0285,  0.4961,  0.2491],
         [-0.4972,  0.0076,  0.4931,  0.2512],
         [-0.5031, -0.0410,  0.4975,  0.2700]]], grad_fn=<CatBackward0>)
context_vecs.shape: torch.Size([2, 5, 4])


In [None]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with fused Q/K/V projections.

    The ``d_out``-wide projections are split into ``num_heads`` heads of size
    ``d_out // num_heads``. Head outputs are concatenated and mixed by
    ``out_proj``.

    Args:
        d_in: size of each input embedding.
        d_out: total output size (summed over all heads).
        context_length: largest number of tokens the causal mask covers.
        dropout: dropout probability applied to the attention weights.
        num_heads: number of attention heads.
        qkv_bias: whether the query/key/value projections have a bias term.

    Raises:
        ValueError: if ``d_out`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if d_out % num_heads != 0:
            raise ValueError("d_out must be divisible by num_heads")

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        # 1s strictly above the diagonal mark "future" positions to hide.
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length),
                       diagonal=1)
        )

    def forward(self, x):
        """Return context vectors of shape (b, num_tokens, d_out) for ``x`` of shape (b, num_tokens, d_in)."""
        b, num_tokens, _ = x.shape

        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        # Split the last axis into heads: (b, num_tokens, num_heads, head_dim).
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
        # Move heads next to batch: (b, num_heads, num_tokens, head_dim).
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)
        # Per-head scores: (b, num_heads, num_tokens, num_tokens).
        attn_scores = queries @ keys.transpose(2, 3)
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # The (num_tokens, num_tokens) mask broadcasts over batch and head axes.
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        # keys.shape[-1] is head_dim, so the scaling is per head.
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Back to (b, num_tokens, num_heads, head_dim).
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Combine heads (self.d_out = self.num_heads * self.head_dim). The
        # transpose above leaves a non-contiguous tensor, which view() rejects.
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)

        return context_vec

torch.manual_seed(123)

# Two heads share d_out = 2 (head_dim = 1); out_proj maps back to 2 features.
batch_size, context_length, d_in = batch.shape
d_out = 2
mha = MultiHeadAttention(d_in, d_out, context_length, dropout=0.0, num_heads=2)
context_vecs = mha(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

tensor([[[0.2273, 0.3996],
         [0.2884, 0.4005],
         [0.2831, 0.4363],
         [0.2778, 0.4351],
         [0.2677, 0.4254]],

        [[0.2273, 0.3996],
         [0.2884, 0.4005],
         [0.2831, 0.4363],
         [0.2778, 0.4351],
         [0.2677, 0.4254]]], grad_fn=<ViewBackward0>)
context_vecs.shape: torch.Size([2, 5, 2])
