## **Attention masking (hiding future words)**

In [1]:
import tiktoken
import torch
from torch.utils.data import Dataset, DataLoader


class GPTDatasetV1(Dataset):
    """Dataset of (input, target) token-id tensors cut from one text.

    The text is encoded once with ``tokenizer.encode`` and then turned into
    samples of ``max_length`` token ids, where each target is its input
    shifted one position to the left:

    * more tokens than ``max_length``: a window of ``max_length`` tokens is
      slid over the ids in steps of ``stride``; the target is the same window
      moved one token ahead.
    * exactly ``max_length`` tokens: a single sample whose target repeats the
      last token id in its final slot.
    * fewer tokens than ``max_length``: a single sample right-padded with
      id ``0``; the target's final slot is ``0`` as well.

    Args:
        txt: Text to encode.
        tokenizer: Object exposing ``encode(txt, allowed_special=...)`` that
            returns a list of token ids.
        max_length: Number of token ids per sample.
        stride: Step between the starts of consecutive windows.
    """

    def __init__(self, txt, tokenizer, max_length, stride):
        self.input_ids = []
        self.target_ids = []

        ids = tokenizer.encode(txt, allowed_special={"<|endoftext|>"})
        n_tokens = len(ids)

        if n_tokens == max_length:
            # Nothing follows the last token, so it is repeated as its own target.
            self._store(ids, ids[1:] + [ids[-1]])
            return

        if n_tokens < max_length:
            padded = ids + [0] * (max_length - n_tokens)
            self._store(padded, padded[1:] + [0])
            return

        # Stop early enough that the shifted target window stays in range.
        for start in range(0, n_tokens - max_length, stride):
            stop = start + max_length
            self._store(ids[start:stop], ids[start + 1 : stop + 1])

    def _store(self, input_chunk, target_chunk):
        """Append one (input, target) pair as tensors."""
        self.input_ids.append(torch.tensor(input_chunk))
        self.target_ids.append(torch.tensor(target_chunk))

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return the ``(input_ids, target_ids)`` tensors of sample ``idx``."""
        return self.input_ids[idx], self.target_ids[idx]


def create_dataloader_v1(txt, batch_size=2, max_length=256,
                         stride=128, shuffle=True, drop_last=True,
                         num_workers=0):
    """Build a DataLoader over ``txt`` encoded with tiktoken's "gpt2" encoding.

    Args:
        txt: Text to encode and split into samples.
        batch_size: Samples per batch.
        max_length: Token ids per sample, forwarded to ``GPTDatasetV1``.
        stride: Window step, forwarded to ``GPTDatasetV1``.
        shuffle: Forwarded to ``DataLoader``.
        drop_last: Forwarded to ``DataLoader``.
        num_workers: Forwarded to ``DataLoader``.

    Returns:
        A ``DataLoader`` wrapping a ``GPTDatasetV1`` built from ``txt``.
    """
    gpt2_encoding = tiktoken.get_encoding("gpt2")
    loader_options = dict(
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=num_workers,
    )
    return DataLoader(
        GPTDatasetV1(txt, gpt2_encoding, max_length, stride),
        **loader_options,
    )



# Demo sentence; it is encoded here only to learn how many tokens it has.
raw_text = "Yours journey start with one steps"
tokenizer = tiktoken.get_encoding("gpt2")
token_ids = tokenizer.encode(raw_text, allowed_special={"<|endoftext|>"})

# The whole sentence is used as a single context window.
max_length = len(token_ids)
vocab_size = 50257  # presumably the vocabulary size of the "gpt2" encoding
output_dim = 3      # tiny embedding size so the printed tensors stay readable

# One lookup table for token ids and one for positions 0..max_length-1.
# No seed is set before these layers are created, so their values differ
# from run to run.
token_embedding_layer = torch.nn.Embedding(vocab_size, output_dim)
pos_embedding_layer = torch.nn.Embedding(max_length, output_dim)

# The text has exactly max_length tokens, so the dataset holds one sample and
# the loader yields one batch of size 1.
dataloader = create_dataloader_v1(
    raw_text, batch_size=1, max_length=max_length,
    stride=max_length, shuffle=False
)

data_iter = iter(dataloader)
inputs, targets = next(data_iter)

token_embeddings = token_embedding_layer(inputs)
pos_embeddings = pos_embedding_layer(torch.arange(max_length))

# pos_embeddings is (max_length, output_dim) and broadcasts over the batch
# dimension of token_embeddings.
input_embeddings = token_embeddings + pos_embeddings
input_embeddings[0]


tensor([[-1.3623, -0.8648, -0.7154],
        [-1.3506, -0.8700,  0.0080],
        [-2.4313, -2.0778, -2.6057],
        [-0.9143,  0.5596,  1.4268],
        [-1.5556,  0.2111, -1.6661],
        [ 1.4545,  0.4688,  1.6685],
        [ 0.3324, -0.4165,  2.4709]], grad_fn=<SelectBackward0>)

In [2]:
import torch.nn as nn
class SelfAttention_v2(nn.Module):
    """Single-head scaled dot-product self-attention (no causal mask).

    Queries, keys and values are produced by three bias-free linear
    projections of the input.

    Args:
        d_in: Size of each input embedding.
        d_out: Size of each query/key/value vector and of each output vector.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        self.w_query = torch.nn.Linear(d_in, d_out, bias=False)
        self.w_key = torch.nn.Linear(d_in, d_out, bias=False)
        self.w_value = torch.nn.Linear(d_in, d_out, bias=False)

    def forward(self, inputs):
        """Compute one context vector per input token.

        Args:
            inputs: Tensor of shape ``(num_tokens, d_in)`` or, with leading
                batch dimensions, ``(..., num_tokens, d_in)``.

        Returns:
            Tensor with the same leading dimensions as ``inputs`` and a last
            dimension of ``d_out``.
        """
        queries = self.w_query(inputs)
        keys = self.w_key(inputs)
        values = self.w_value(inputs)

        # Swap only the last two axes so batched inputs work too; for 2-D
        # input this is the same as ``keys.T``.
        attn_scores = queries @ keys.transpose(-2, -1)
        # Scale by sqrt(d_out); the last axis is d_out for any input rank.
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)

        return attn_weights @ values

In [3]:
# Seed before building the module so its Linear weights are reproducible.
torch.manual_seed(123)
# input_embeddings[0] is the single sample: rows = tokens, cols = embedding size.
rows = input_embeddings[0].shape[0]
cols = input_embeddings[0].shape[1]
sa_v2 = SelfAttention_v2(cols,output_dim)    
sa_v2(input_embeddings[0])

tensor([[ 0.3517,  0.4018, -0.1902],
        [ 0.1683,  0.0663, -0.0016],
        [ 0.6731,  0.9933, -0.5566],
        [-0.0573, -0.3358,  0.2475],
        [ 0.6036,  0.9164, -0.4988],
        [-0.4543, -1.0243,  0.6178],
        [-0.5950, -1.2418,  0.6766]], grad_fn=<MmBackward0>)

In [6]:
# Repeat the steps of SelfAttention_v2.forward by hand, reusing the module's
# projection layers, so the intermediate attention weights can be inspected.
queries = sa_v2.w_query(input_embeddings[0])
keys = sa_v2.w_key(input_embeddings[0])
values = sa_v2.w_value(input_embeddings[0])

attn_scores = queries @ keys.T
# Scale by sqrt of the key size, then normalise each row to sum to 1.
attn_weights = torch.softmax(attn_scores / keys.shape[1]**0.5, dim=-1)

attn_weights

tensor([[0.1064, 0.1377, 0.0785, 0.1426, 0.0523, 0.1784, 0.3041],
        [0.1348, 0.1594, 0.1348, 0.1386, 0.0784, 0.1318, 0.2223],
        [0.0389, 0.0735, 0.0159, 0.0900, 0.0071, 0.1802, 0.5944],
        [0.1510, 0.1378, 0.1771, 0.1304, 0.1900, 0.1145, 0.0993],
        [0.0450, 0.0650, 0.0143, 0.1225, 0.0217, 0.3424, 0.3891],
        [0.1377, 0.0947, 0.3278, 0.0635, 0.3262, 0.0300, 0.0202],
        [0.1326, 0.0955, 0.4657, 0.0447, 0.2319, 0.0141, 0.0155]],
       grad_fn=<SoftmaxBackward0>)

In [5]:
# Square lower-triangular mask: 1 on and below the diagonal, 0 above it, so
# row i keeps only columns 0..i.
context_length = attn_scores.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length))

mask_simple

tensor([[1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1.]])

In [7]:
# Element-wise product zeroes every weight above the diagonal; the rows no
# longer sum to 1 after this step.
mask_attn_weights = attn_weights * mask_simple
mask_attn_weights


tensor([[0.1064, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1348, 0.1594, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0389, 0.0735, 0.0159, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1510, 0.1378, 0.1771, 0.1304, 0.0000, 0.0000, 0.0000],
        [0.0450, 0.0650, 0.0143, 0.1225, 0.0217, 0.0000, 0.0000],
        [0.1377, 0.0947, 0.3278, 0.0635, 0.3262, 0.0300, 0.0000],
        [0.1326, 0.0955, 0.4657, 0.0447, 0.2319, 0.0141, 0.0155]],
       grad_fn=<MulBackward0>)

## Row normalization

In [8]:
# Divide each row by its own sum so the remaining weights add up to 1 again.
# keepdim=True keeps rows_sum as a column so it broadcasts across each row.
rows_sum = mask_attn_weights.sum(dim=-1, keepdim=True)
mask_attn_weights_norm = mask_attn_weights / rows_sum
mask_attn_weights_norm

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4581, 0.5419, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3029, 0.5729, 0.1241, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2532, 0.2311, 0.2970, 0.2186, 0.0000, 0.0000, 0.0000],
        [0.1677, 0.2420, 0.0534, 0.4560, 0.0808, 0.0000, 0.0000],
        [0.1405, 0.0966, 0.3345, 0.0648, 0.3329, 0.0306, 0.0000],
        [0.1326, 0.0955, 0.4657, 0.0447, 0.2319, 0.0141, 0.0155]],
       grad_fn=<DivBackward0>)

## Final Self-Attention

In [9]:
import torch.nn as nn
class CausalAttention(nn.Module):
    """Single-head scaled dot-product self-attention with a causal mask.

    Each token may attend only to itself and to earlier tokens. Dropout is
    applied to the attention weights.

    Args:
        d_in: Size of each input embedding.
        d_out: Size of each query/key/value vector and of each output vector.
        context_length: Side length of the stored causal mask; inputs longer
            than this are not supported.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, d_in, d_out, context_length, dropout):
        super().__init__()
        self.w_query = torch.nn.Linear(d_in, d_out, bias=False)
        self.w_key = torch.nn.Linear(d_in, d_out, bias=False)
        self.w_value = torch.nn.Linear(d_in, d_out, bias=False)
        self.dropout = torch.nn.Dropout(dropout)
        # Ones strictly above the diagonal mark the "future" positions.
        # Registered as a buffer so it follows the module across devices
        # without being a trainable parameter.
        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, inputs):
        """Compute one causally masked context vector per input token.

        Args:
            inputs: Tensor of shape ``(batch, num_tokens, d_in)``; an
                unbatched ``(num_tokens, d_in)`` tensor is accepted as well.

        Returns:
            Tensor shaped like ``inputs`` with the last dimension replaced
            by ``d_out``.
        """
        num_tokens = inputs.shape[-2]

        queries = self.w_query(inputs)
        keys = self.w_key(inputs)
        values = self.w_value(inputs)

        attn_scores = queries @ keys.transpose(-2, -1)
        # Crop the mask to the actual sequence length so inputs shorter than
        # context_length work instead of failing on a shape mismatch.
        future = self.mask.bool()[:num_tokens, :num_tokens]
        # -inf scores become exactly 0 after softmax.
        attn_scores = attn_scores.masked_fill(future, float('-inf'))
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        # The dropout layer was constructed but never used before; apply it
        # to the attention weights (no-op in eval mode or when p == 0).
        attn_weights = self.dropout(attn_weights)

        context_vec = attn_weights @ values

        return context_vec

In [11]:
# Seed before building the module so its Linear weights are reproducible.
torch.manual_seed(789)
# dropout=0.0 keeps the output deterministic for this demonstration.
ca = CausalAttention(cols, output_dim, context_length=max_length, dropout=0.0)

# Unlike the SelfAttention_v2 call above, the full batched tensor is passed.
result = ca(input_embeddings)
result

tensor([[[-0.8334,  0.8476, -0.0489],
         [-0.7711,  0.6804, -0.1014],
         [-0.8331,  0.7543, -0.1069],
         [-1.0446,  1.0713, -0.0611],
         [-0.4382,  0.2421, -0.0030],
         [-0.9566,  1.0776,  0.0076],
         [-0.7571,  0.7951, -0.0297]]], grad_fn=<UnsafeViewBackward0>)