## **Attention masking (hiding future words)**

In [34]:
import tiktoken
import torch
from torch.utils.data import Dataset, DataLoader


class GPTDatasetV1(Dataset):
    """Dataset of (input, target) token-id tensors for next-token prediction.

    The text is encoded once with ``tokenizer`` and turned into samples
    depending on how its token count compares to ``max_length``:

    * equal: one sample; the target is the input shifted left by one, with
      the last token repeated to keep the length.
    * shorter: one sample; the ids are right-padded with token id 0 up to
      ``max_length`` and the target is the padded input shifted left by one
      with a trailing 0.
    * longer: sliding windows of ``max_length`` tokens taken every ``stride``
      tokens; each target is its input window shifted right by one token.

    Args:
        txt: Raw text to encode.
        tokenizer: Object exposing ``encode(txt, allowed_special=...)`` that
            returns a list of token ids.
        max_length: Number of tokens per sample.
        stride: Step between the starts of consecutive windows.
    """

    def __init__(self, txt, tokenizer, max_length, stride):
        self.input_ids = []
        self.target_ids = []

        ids = tokenizer.encode(txt, allowed_special={"<|endoftext|>"})
        n_tokens = len(ids)

        if n_tokens > max_length:
            # Stop early enough that every target window is full length.
            pairs = [
                (ids[start:start + max_length],
                 ids[start + 1:start + max_length + 1])
                for start in range(0, n_tokens - max_length, stride)
            ]
        elif n_tokens < max_length:
            padded = ids + [0] * (max_length - n_tokens)
            pairs = [(padded, padded[1:] + [0])]
        else:
            # No token follows the last one, so it is repeated in the target.
            pairs = [(ids, ids[1:] + [ids[-1]])]

        for source, shifted in pairs:
            self.input_ids.append(torch.tensor(source))
            self.target_ids.append(torch.tensor(shifted))

    def __len__(self):
        """Return the number of samples."""
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return the ``(input_ids, target_ids)`` tensors of sample ``idx``."""
        return self.input_ids[idx], self.target_ids[idx]


def create_dataloader_v1(txt, batch_size=2, max_length=256,
                         stride=128, shuffle=True, drop_last=True,
                         num_workers=0):
    """Build a DataLoader over ``txt`` tokenized with the "gpt2" encoding.

    Args:
        txt: Raw text to turn into training samples.
        batch_size: Number of samples per batch.
        max_length: Number of tokens per sample (passed to ``GPTDatasetV1``).
        stride: Step between consecutive windows (passed to ``GPTDatasetV1``).
        shuffle: Forwarded to ``DataLoader``.
        drop_last: Forwarded to ``DataLoader``.
        num_workers: Forwarded to ``DataLoader``.

    Returns:
        A ``DataLoader`` yielding ``(input_ids, target_ids)`` batches.
    """
    gpt2_encoding = tiktoken.get_encoding("gpt2")
    samples = GPTDatasetV1(txt, gpt2_encoding, max_length, stride)

    loader_options = {
        "batch_size": batch_size,
        "shuffle": shuffle,
        "drop_last": drop_last,
        "num_workers": num_workers,
    }
    return DataLoader(samples, **loader_options)



# Demo input: encode a short sentence with the "gpt2" tiktoken encoding.
raw_text = "Yours journey start with one steps"
tokenizer = tiktoken.get_encoding("gpt2")
token_ids = tokenizer.encode(raw_text, allowed_special={"<|endoftext|>"})

# Treat the whole sentence as a single context window.
max_length = len(token_ids)
vocab_size = 50257  # presumably the size of the "gpt2" vocabulary -- confirm
output_dim = 3

# Lookup tables: one vector per token id and one vector per position.
token_embedding_layer = torch.nn.Embedding(vocab_size, output_dim)
pos_embedding_layer = torch.nn.Embedding(max_length, output_dim)

# The text has exactly max_length tokens, so the dataset holds one sample.
dataloader = create_dataloader_v1(
    raw_text, batch_size=1, max_length=max_length,
    stride=max_length, shuffle=False
)

data_iter = iter(dataloader)
inputs, targets = next(data_iter)

token_embeddings = token_embedding_layer(inputs)
pos_embeddings = pos_embedding_layer(torch.arange(max_length))

# The position embeddings broadcast across the leading batch dimension.
input_embeddings = token_embeddings + pos_embeddings
input_embeddings[0]


tensor([[-2.2151,  2.7693,  0.4419],
        [-2.6485,  0.6560, -0.8422],
        [-4.0542, -0.0676, -0.6946],
        [ 1.0932,  0.6538, -0.2154],
        [ 0.7685, -0.7436,  1.1174],
        [-1.1274,  0.6884,  0.7027],
        [-1.5350, -0.0513,  0.5445]], grad_fn=<SelectBackward0>)

In [35]:
import torch.nn as nn
class SelfAttention_v2(nn.Module):
    """Single-head scaled dot-product self-attention with linear projections.

    Args:
        d_in: Size of the last dimension of the input embeddings.
        d_out: Size of the query/key/value projections, and therefore of each
            returned context vector.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        self.w_query = torch.nn.Linear(d_in, d_out, bias=False)
        self.w_key = torch.nn.Linear(d_in, d_out, bias=False)
        self.w_value = torch.nn.Linear(d_in, d_out, bias=False)

    def forward(self, inputs):
        """Compute context vectors for every position of ``inputs``.

        Args:
            inputs: Tensor of shape ``(num_tokens, d_in)`` or, with leading
                batch dimensions, ``(..., num_tokens, d_in)``.

        Returns:
            Tensor shaped like ``inputs`` with the last dimension replaced by
            ``d_out``.
        """
        queries = self.w_query(inputs)
        keys = self.w_key(inputs)
        values = self.w_value(inputs)

        # Swap only the last two dimensions so that batched input works too;
        # for 2-D input this is identical to ``keys.T``.
        attn_scores = queries @ keys.transpose(-2, -1)
        # Scale by sqrt(d_out), read from the last axis regardless of rank.
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1] ** 0.5, dim=-1
        )

        context_vec = attn_weights @ values

        return context_vec

In [36]:
# Seed the RNG before the attention layer's weights are created.
torch.manual_seed(123)
rows = input_embeddings[0].shape[0]
cols = input_embeddings[0].shape[1]  # embedding width, used as d_in below
sa_v2 = SelfAttention_v2(cols,output_dim)    
# Run attention on the first (and only) sample of the batch.
sa_v2(input_embeddings[0])

tensor([[ 0.2631,  0.2517,  0.0760],
        [ 0.2570,  0.3280, -0.1216],
        [ 0.2546,  0.3840, -0.2348],
        [ 0.3834,  0.0429,  0.4746],
        [ 0.3459, -0.3084,  0.8016],
        [ 0.3498,  0.0912,  0.3461],
        [ 0.3349,  0.0736,  0.3175]], grad_fn=<MmBackward0>)

In [37]:
# Redo the attention computation step by step with sa_v2's projection layers
# so that the intermediate attention weights can be inspected.
queries = sa_v2.w_query(input_embeddings[0])
keys = sa_v2.w_key(input_embeddings[0])
values = sa_v2.w_value(input_embeddings[0])

attn_scores = queries @ keys.T
# Scale by the square root of the key dimension, then softmax over each row.
attn_weights = torch.softmax(attn_scores / keys.shape[1]**0.5, dim=-1)

attn_weights

tensor([[0.1509, 0.0411, 0.0141, 0.3799, 0.2178, 0.1247, 0.0714],
        [0.0553, 0.0323, 0.0214, 0.2091, 0.4359, 0.1320, 0.1140],
        [0.0268, 0.0219, 0.0196, 0.1360, 0.5613, 0.1146, 0.1197],
        [0.2093, 0.1514, 0.1141, 0.1765, 0.1025, 0.1338, 0.1124],
        [0.1252, 0.2242, 0.3563, 0.0536, 0.0440, 0.0880, 0.1087],
        [0.1536, 0.1222, 0.1011, 0.1797, 0.1619, 0.1476, 0.1338],
        [0.1078, 0.1175, 0.1275, 0.1392, 0.1987, 0.1497, 0.1596]],
       grad_fn=<SoftmaxBackward0>)

In [38]:
# Lower-triangular matrix of ones: row i keeps columns 0..i, i.e. the position
# itself and everything before it.
context_length = attn_scores.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length))

mask_simple

tensor([[1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1.]])

In [39]:
# Zero the weights above the diagonal; the rows no longer sum to 1 after this.
mask_attn_weights = attn_weights * mask_simple
mask_attn_weights


tensor([[0.1509, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0553, 0.0323, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0268, 0.0219, 0.0196, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2093, 0.1514, 0.1141, 0.1765, 0.0000, 0.0000, 0.0000],
        [0.1252, 0.2242, 0.3563, 0.0536, 0.0440, 0.0000, 0.0000],
        [0.1536, 0.1222, 0.1011, 0.1797, 0.1619, 0.1476, 0.0000],
        [0.1078, 0.1175, 0.1275, 0.1392, 0.1987, 0.1497, 0.1596]],
       grad_fn=<MulBackward0>)

## Row normalization

In [40]:
# Divide every row by its own sum so the surviving weights sum to 1 again.
rows_sum = mask_attn_weights.sum(dim=-1, keepdim=True)
mask_attn_weights_norm = mask_attn_weights / rows_sum
mask_attn_weights_norm

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.6315, 0.3685, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3922, 0.3206, 0.2872, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3214, 0.2324, 0.1752, 0.2710, 0.0000, 0.0000, 0.0000],
        [0.1559, 0.2791, 0.4435, 0.0667, 0.0548, 0.0000, 0.0000],
        [0.1773, 0.1411, 0.1167, 0.2075, 0.1869, 0.1704, 0.0000],
        [0.1078, 0.1175, 0.1275, 0.1392, 0.1987, 0.1497, 0.1596]],
       grad_fn=<DivBackward0>)

## Final Self-Attention

In [41]:
import torch.nn as nn
class CausalAttention(nn.Module):
    """Single-head causal self-attention: a position never attends to later ones.

    Args:
        d_in: Size of the last dimension of the input embeddings.
        d_out: Size of the query/key/value projections, and therefore of each
            returned context vector.
        context_length: Side length of the causal mask, i.e. the longest
            sequence the mask can cover.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, d_in, d_out, context_length, dropout):
        super().__init__()
        self.w_query = torch.nn.Linear(d_in, d_out, bias=False)
        self.w_key = torch.nn.Linear(d_in, d_out, bias=False)
        self.w_value = torch.nn.Linear(d_in, d_out, bias=False)
        self.dropout = torch.nn.Dropout(dropout)
        # Ones strictly above the diagonal mark the "future" positions to hide.
        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, inputs):
        """Compute causally masked context vectors.

        Args:
            inputs: Tensor of shape ``(..., num_tokens, d_in)`` where
                ``num_tokens`` may be smaller than ``context_length``.

        Returns:
            Tensor shaped like ``inputs`` with the last dimension replaced by
            ``d_out``.
        """
        num_tokens = inputs.shape[-2]

        queries = self.w_query(inputs)
        keys = self.w_key(inputs)
        values = self.w_value(inputs)

        attn_scores = queries @ keys.transpose(-2, -1)
        # Crop the mask to the current sequence length; applying the full
        # context_length x context_length mask fails for shorter sequences.
        future = self.mask[:num_tokens, :num_tokens] == 1
        # -inf scores turn into exactly zero weight after the softmax.
        attn_scores = attn_scores.masked_fill(future, float('-inf'))
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context_vec = attn_weights @ values

        return context_vec

In [42]:
# Seed the RNG before the causal attention layer's weights are created.
torch.manual_seed(789)
# The mask is sized to the demo sentence; dropout is set to 0.0.
ca = CausalAttention(cols, output_dim, context_length=max_length, dropout=0.0)

# Pass the whole batched tensor here, not just input_embeddings[0].
result = ca(input_embeddings)
result

tensor([[[ 0.3539, -0.7506,  0.6953],
         [ 0.0083, -0.3133,  0.5954],
         [-0.4288,  0.0989,  0.4910],
         [-0.3225,  0.1824,  0.3460],
         [ 0.0753, -0.1343,  0.0775],
         [-0.3474,  0.1171,  0.2567],
         [-0.2855,  0.0366,  0.2153]]], grad_fn=<UnsafeViewBackward0>)