In [20]:
import torch
import torch.nn as nn

In [21]:
# Toy embedding table: six tokens (rows), each a 3-dimensional vector (columns).
data = [
    [0.43, 0.15, 0.89],
    [0.55, 0.87, 0.66],
    [0.57, 0.85, 0.64],
    [0.22, 0.58, 0.33],
    [0.77, 0.25, 0.10],
    [0.05, 0.80, 0.55],
]
# Tensor view of the table, used as the input sequence by the cells below.
inputs = torch.tensor(data)

## A simple self-attention mechanism without trainable weights

In [22]:
# Unnormalised attention scores: dot product of every token with every token.
attn_scores = inputs @ inputs.T
# Normalise each row so the weights for a given query token sum to 1.
attn_weights = torch.softmax(attn_scores, dim=-1)
# Each context vector is the attention-weighted sum of all input vectors.
context_vectors = attn_weights @ inputs

print(f"Context vector for all input tokens:\n {context_vectors}")

Context vector for all input tokens:
 tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])


## Implementing self-attention with trainable weights

Initializing three weights: W_query, W_key, W_value

In [23]:
# Input embedding size, read from the second dimension of `inputs`.
d_in = inputs.shape[1]
# Size of the projected query/key/value vectors passed to the attention modules.
d_out = 2

In [24]:
class SelfAttention_v1(nn.Module):
    """Single-head scaled dot-product self-attention using raw ``nn.Parameter`` weights.

    Args:
        d_in: Size of the last dimension of the input (embedding size).
        d_out: Size of the projected query/key/value vectors.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        # Creation order (key, query, value) is kept as-is: it determines how the
        # RNG is consumed and therefore the values obtained under a fixed seed.
        self.w_key = nn.Parameter(torch.rand(d_in, d_out))
        self.w_query = nn.Parameter(torch.rand(d_in, d_out))
        self.w_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        """Compute context vectors for ``x``.

        Args:
            x: Tensor of shape ``(num_tokens, d_in)`` or, with any number of
                leading batch dimensions, ``(..., num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)``.
        """
        keys = x @ self.w_key
        queries = x @ self.w_query
        values = x @ self.w_value
        # Swap only the last two dimensions so batched inputs work as well;
        # for 2-D input this is identical to `keys.T`.
        attention_scores = queries @ keys.transpose(-2, -1)
        # Scale by sqrt(d_out) before the softmax over the key dimension.
        attention_weights = torch.softmax(attention_scores / keys.shape[-1] ** 0.5, dim=-1)
        context_vector = attention_weights @ values
        return context_vector

class SelfAttention_v2(nn.Module):
    """Single-head scaled dot-product self-attention using ``nn.Linear`` projections.

    Args:
        d_in: Size of the last dimension of the input (embedding size).
        d_out: Size of the projected query/key/value vectors.
        qkv_bias: Whether the three projection layers carry a bias term.
    """

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        # Layer creation order is kept as-is: it determines how the RNG is
        # consumed and therefore the initial weights under a fixed seed.
        self.w_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.w_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.w_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        """Compute attention scores and context vectors for ``x``.

        Args:
            x: Tensor of shape ``(num_tokens, d_in)`` or, with any number of
                leading batch dimensions, ``(..., num_tokens, d_in)``.

        Returns:
            A tuple ``(attention_scores, context_vector)`` where
            ``attention_scores`` holds the unscaled, unnormalised query-key dot
            products of shape ``(..., num_tokens, num_tokens)`` and
            ``context_vector`` has shape ``(..., num_tokens, d_out)``.
        """
        keys = self.w_key(x)
        queries = self.w_query(x)
        values = self.w_value(x)
        # Swap only the last two dimensions so batched inputs work as well;
        # for 2-D input this is identical to `keys.T`.
        attention_scores = queries @ keys.transpose(-2, -1)
        # Scale by sqrt(d_out) before the softmax over the key dimension.
        attention_weights = torch.softmax(attention_scores / keys.shape[-1] ** 0.5, dim=-1)
        context_vector = attention_weights @ values
        return attention_scores, context_vector

In [25]:
# Seed before each construction so the randomly initialised weights are reproducible.
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)
# Call the module itself rather than `.forward()` so that nn.Module's
# `__call__` machinery (e.g. registered hooks) is not bypassed.
print(f"Forward pass on inputs in version 1:\n {sa_v1(inputs)}")

torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)

Forward pass on inputs in version 1:
 tensor([[0.2947, 0.7956],
        [0.3015, 0.8132],
        [0.3010, 0.8120],
        [0.2925, 0.7902],
        [0.2863, 0.7737],
        [0.2979, 0.8043]], grad_fn=<MmBackward0>)


In [26]:
# Call the module itself rather than `.forward()` so nn.Module hooks are honoured.
# `attention_scores` are the raw (unscaled, unnormalised) query-key dot products.
attention_scores, context_vector = sa_v2(inputs)

## CAUSAL ATTENTION / MASKED ATTENTION

A modified attention mechanism that prevents the model from accessing future information in the sequence, which is crucial for tasks like
language modeling, where each word prediction should depend only on the previous words.

It restricts the model to only consider the previous and current inputs in a sequence when processing any given token while computing the attention scores and hence the context vector.

In contrast, the standard self-attention mechanism allows access to the entire input sequence at once.

In [27]:
# Size the mask from the score matrix itself instead of hard-coding the length.
context_length = attention_scores.shape[-1]

# Ones strictly above the diagonal mark the "future" positions to hide.
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
# -inf scores become exactly 0 after the softmax.
masked = attention_scores.masked_fill(mask.bool(), -torch.inf)

# Scale by sqrt(d_out), the size of the key vectors that produced the scores;
# no throw-away random parameter is needed just to read that dimension.
attention_weights = torch.softmax(masked / d_out**0.5, dim=-1)
print(f"Attention weights:\n {attention_weights}")

Attention weights:
 tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4822, 0.5178, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3181, 0.3412, 0.3407, 0.0000, 0.0000, 0.0000],
        [0.2451, 0.2521, 0.2521, 0.2508, 0.0000, 0.0000],
        [0.2002, 0.2063, 0.2060, 0.1950, 0.1925, 0.0000],
        [0.1607, 0.1676, 0.1677, 0.1681, 0.1683, 0.1675]],
       grad_fn=<SoftmaxBackward0>)


In [28]:
# Duplicate the input sequence along a new leading dimension to simulate a batch of two sequences.
batch = torch.stack((inputs, inputs), dim=0)

Registered buffers are automatically moved to the appropriate device (CPU or GPU) along with our model, which will be relevant when training an LLM. That is, we don't need to manually ensure these tensors are on the same device as our model parameters, which avoids device-mismatch errors.

In [29]:
class CausalAttention(nn.Module):
    """Single-head scaled dot-product attention with a causal (look-back only) mask.

    Args:
        d_in: Size of the last dimension of the input (embedding size).
        d_out: Size of the projected query/key/value vectors.
        context_length: Maximum number of tokens supported; sizes the mask buffer.
        dropout: Dropout probability applied to the attention weights.
        qkv_bias: Whether the three projection layers carry a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.w_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.w_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.w_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # Registered as a buffer so it follows the module across devices.
        # Ones strictly above the diagonal mark the future positions to hide.
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        """Compute masked attention scores and context vectors for ``x``.

        Args:
            x: Tensor of shape ``(num_tokens, d_in)`` or, with leading batch
                dimensions, ``(..., num_tokens, d_in)``. ``num_tokens`` must
                not exceed the ``context_length`` given at construction.

        Returns:
            A tuple ``(attention_scores, context_vector)`` where
            ``attention_scores`` are the unscaled query-key dot products with
            future positions set to ``-inf``, shape
            ``(..., num_tokens, num_tokens)``, and ``context_vector`` has shape
            ``(..., num_tokens, d_out)``.

        Raises:
            RuntimeError: If ``num_tokens`` exceeds the mask's context length.
        """
        num_tokens = x.shape[-2]
        max_tokens = self.mask.shape[0]
        if num_tokens > max_tokens:
            # Previously surfaced as an opaque broadcasting RuntimeError from
            # masked_fill_; same exception type, clearer message.
            raise RuntimeError(
                f"sequence length {num_tokens} exceeds context_length {max_tokens}"
            )
        keys = self.w_key(x)
        queries = self.w_query(x)
        values = self.w_value(x)
        # Swap the last two dimensions, leaving any leading batch dimensions in place.
        attention_scores = queries @ keys.transpose(-2, -1)
        # Slice the mask so sequences shorter than context_length are supported;
        # the fill is in place, so the returned scores are the masked ones.
        attention_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        attention_weights = torch.softmax(attention_scores / keys.shape[-1] ** 0.5, dim=-1)
        attention_weights = self.dropout(attention_weights)
        context_vector = attention_weights @ values
        return attention_scores, context_vector

In [30]:
# Seed first so the randomly initialised projection layers are reproducible.
torch.manual_seed(123)
# The batch's second dimension (tokens per sequence) sizes the causal mask.
context_length = batch.shape[1]
# The last argument is the dropout probability (0.0 here).
ca = CausalAttention(d_in, d_out, context_length, 0.0)
# NOTE: this rebinds `attention_scores` and `context_vectors`, which earlier cells also assign.
attention_scores, context_vectors = ca(batch)

## MULTI-HEAD ATTENTION

Splitting the attention mechanism into multiple heads.
Each head learns different aspects of the data, allowing the model to simultaneously attend to information from different representation subspaces at different positions, hence improving the model's performance on complex tasks.

In [31]:
class MultiHeadAttentionWraper(nn.Module):
    """Placeholder: no behaviour beyond ``nn.Module`` is implemented yet.

    TODO: implement. Presumably meant to wrap several single-head attention
    modules into one multi-head layer -- confirm the intended design.
    """
    pass

In [33]:
class MultiHeadAttention(nn.Module):
    """Placeholder: no behaviour beyond ``nn.Module`` is implemented yet.

    TODO: implement. Presumably the multi-head attention layer described in
    the section text -- confirm the intended design.
    """
    pass