# Creating some playground
* Dataset
* Dataloader
* Multihead-Attention


## Dataset
  Reprocesses the entire dataset into training pairs, and `__getitem__`
   just retrieves them by index.

  Process:
  1. Tokenize text → convert to token IDs
  2. Sliding window → extract sequences of max_length + 1
  3. Split each sequence → input `[i:max_length]` and target `[i+1:max_length+1]` (shifted by 1)
  4. Store as pairs → `self.input_ids[i]` and `self.target_ids[i]` as a training pair
  5. Jumps `stride` width

  Example:
  sequence = [1, 2, 3, 4, 5]  # length 5, max_length = 4
  # Split into:
  input_ids  = [1, 2, 3, 4]   # predict next token at each position
  target_ids = [2, 3, 4, 5]   # what should be predicted

  Training pairs at same index:
  - input_ids[0] → predict → target_ids[0]
  - input_ids[1] → predict → target_ids[1]
  - etc.

  This preprocessing in __init__ makes __getitem__ very fast since it just returns pre-computed pairs.
  The model learns to predict the next token at each position in the sequence.

In [23]:
import tiktoken # tokenization library openai
import torch
import torch.nn as nn 
from torch.utils.data import Dataset, DataLoader

torch.manual_seed(42)

class GPTDatasetV1(Dataset):

    def __init__(self, txt, tokenizer, max_length, stride):
        self.input_ids = []
        self.target_ids = []

        # Tokenize txt
        token_ids = tokenizer.encode( txt, allowed_special={'<|endoftext|>'})

        # Walks throught IDs and prepaires training sets for every index of input_ids and target_ids
        # Stride determines the jump wide for one loop
        for i in range(0, len(token_ids) - max_length, stride):
            input_chunk = token_ids[i:i+max_length]
            target_chunk = token_ids[i+1:i+max_length+1]
            self.input_ids.append(torch.tensor(input_chunk))
            self.target_ids.append(torch.tensor(target_chunk))

    def __len__(self):
        return len(self.input_ids)
    
    def __getitem__(self, idx):
        return self.input_ids[idx], self.target_ids[idx]


## Dataloader
Wrapper for initialization and data loader creation

In [24]:
def create_dataloader(txt, tokenizer_model_name="gpt2", batch_size=4, max_length=256, stride=128, shuffle=True):

    # Initialize tokenizer
    tokenizer = tiktoken.get_encoding(tokenizer_model_name)

    # Create dataset
    dataset = GPTDatasetV1(txt, tokenizer, max_length, stride)

    # Create dataloader
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

    return dataloader

## Data loading
Loading of file Robins Small Text Sample

In [25]:
with open("Robins Small Text Sample.txt", "r", encoding="utf-8") as file:
    raw_text = file.read()

dataloader = create_dataloader( raw_text, tokenizer_model_name="gpt2", batch_size=8, max_length=4, stride=4)

## Token Embeddings
Gets input embeddings from token_ids
* Create embedding layers for vocabulary and embedding dimensions
* Positional encoding from embedding layer with context_length and embedding dimension

In [26]:
def get_input_embedding():

    vocab_size = 50252
    embedding_dim = 256
    con_len = 4
    positions = torch.arange(con_len) # tensor([0, 1, 2, 3])
    tokenizer = tiktoken.get_encoding("gpt2") # for debuggin outputs

    token_embedding_layer = nn.Embedding(vocab_size, embedding_dim)
    pos_embedding_layer = nn.Embedding(con_len, embedding_dim)

    for batch in dataloader:
        x, y = batch
        print("Displaying first row of batch")
        print("\nInput x:\n", x[0], tokenizer.decode(x[0].tolist()))
        print("\nTarget y:\n", y[0], tokenizer.decode(y[0].tolist()))

        embeddings = token_embedding_layer(x)
        pos_embeddings = pos_embedding_layer(positions)

        print("\n embeddings for x:\n", embeddings[0])
        print("\n pos_embeddings for 0-3:\n", pos_embeddings)

        input_embeddings = embeddings + pos_embeddings

        print("\n input_embeddings = embeddings + pos_embeddings:\n", input_embeddings[0])
        print("Shape for input_embeddings: batch, context, embedding_dim ", input_embeddings.shape)

        break

    return input_embeddings

get_input_embedding()

Displaying first row of batch

Input x:
 tensor([2506,  326, 3360,   11])  everyone that sometimes,

Target y:
 tensor([ 326, 3360,   11,  262])  that sometimes, the

 embeddings for x:
 tensor([[ 0.9752,  0.3436, -1.0103,  ...,  0.9809,  1.2331, -0.1515],
        [-0.2354, -0.6906, -0.7542,  ..., -2.3173, -1.1541, -1.0304],
        [-0.5122, -0.5704, -0.9882,  ..., -0.7218, -1.0930, -1.0003],
        [ 1.8313, -0.6159, -0.6073,  ..., -2.3891,  0.7178, -1.5831]],
       grad_fn=<SelectBackward0>)

 pos_embeddings for 0-3:
 tensor([[-2.9261, -2.8944,  0.7488,  ...,  1.0027,  0.7249, -0.1917],
        [ 0.0056, -0.5059,  0.5341,  ..., -0.7495,  1.3472, -0.8115],
        [ 0.5100,  0.1452, -1.1741,  ...,  0.5613,  1.2736, -1.4704],
        [ 1.1651, -1.2304, -0.9989,  ...,  0.4287, -0.3611,  0.4499]],
       grad_fn=<EmbeddingBackward0>)

 input_embeddings = embeddings + pos_embeddings:
 tensor([[-1.9509e+00, -2.5508e+00, -2.6152e-01,  ...,  1.9836e+00,
          1.9580e+00, -3.4329e-01],

tensor([[[-1.9509e+00, -2.5508e+00, -2.6152e-01,  ...,  1.9836e+00,
           1.9580e+00, -3.4329e-01],
         [-2.2974e-01, -1.1964e+00, -2.2012e-01,  ..., -3.0667e+00,
           1.9312e-01, -1.8419e+00],
         [-2.2381e-03, -4.2520e-01, -2.1624e+00,  ..., -1.6051e-01,
           1.8060e-01, -2.4707e+00],
         [ 2.9964e+00, -1.8463e+00, -1.6062e+00,  ..., -1.9605e+00,
           3.5674e-01, -1.1332e+00]],

        [[-3.5598e+00, -2.7543e+00, -5.7565e-01,  ...,  1.3383e+00,
           1.5547e+00, -8.4338e-01],
         [ 1.8369e+00, -1.1218e+00, -7.3182e-02,  ..., -3.1386e+00,
           2.0650e+00, -2.3946e+00],
         [ 1.8134e+00, -1.7097e-01, -1.9379e+00,  ...,  4.9678e-01,
           1.7161e+00,  8.8213e-02],
         [ 2.1793e+00, -9.5127e-02,  9.3677e-02,  ...,  7.4319e-01,
          -1.0008e+00,  1.5751e+00]],

        [[-2.9704e+00, -3.0695e+00,  2.1067e+00,  ...,  1.3326e+00,
           1.9507e+00,  1.5052e-01],
         [ 1.8369e+00, -1.1218e+00, -7.3182e-02,  .

## Prototype Attention Head
Implementation of a simple multi attention head (no summed and split matrices)

In [27]:
class CausalSelfAttention(nn.Module):
    
    def __init__(self, input_dim, output_dim, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.output_dim = output_dim
        self.W_query = nn.Linear(input_dim, output_dim, bias=qkv_bias)
        self.W_key = nn.Linear(input_dim, output_dim, bias=qkv_bias)
        self.W_value = nn.Linear(input_dim, output_dim, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout) # dropout is a probability
        self.register_buffer('causal_mask', torch.triu(torch.ones(context_length, context_length), diagonal=1)) # Causal Masking on diagonal triangle

    def forward(self, x):
        batch_size, con_len, in_dim = x.shape # set them to the shape of x ( 8, 4, 256 )
        
        # Projection of the input to qkv using weight matrices and broadcasting
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.transpose(1, 2)    # Calculation of attention scores (comparing query and key)
        attn_scores.masked_fill( self.causal_mask.bool()[:con_len, :con_len], -torch.inf)
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1) # Dimensions -1 for taking internal rows (in batch and context)
        attn_weights = self.dropout(attn_weights) # Additional random dropout

        context_vec = attn_weights @ values
        return context_vec


In [28]:
# Triangular Matrix filtering example for causal masking
def mask_demo():
    matrix = torch.rand(4,4)
    mask = torch.triu(torch.ones(4, 4), diagonal=1)

    print("\nmatrix:\n", matrix) # some value matrix
    print("\nmask\n", mask) # triangle mask matrix consisting of 0 and 1
    print("\nmask.bool()\n", mask.bool()) # makes matrix with 0 and 1 to true and false
    print("\nmask.bool()[interval]\n", mask.bool()[:3, :3]) # sub matrix
    print("\nmatrix.mask_fill(mask.bool()[interval], value)\n", matrix.masked_fill(mask.bool()[:4, :4], -torch.inf))
mask_demo()


matrix:
 tensor([[0.0535, 0.5146, 0.2969, 0.1282],
        [0.8037, 0.2223, 0.9951, 0.7353],
        [0.0512, 0.8887, 0.2361, 0.5592],
        [0.7751, 0.3303, 0.5652, 0.8721]])

mask
 tensor([[0., 1., 1., 1.],
        [0., 0., 1., 1.],
        [0., 0., 0., 1.],
        [0., 0., 0., 0.]])

mask.bool()
 tensor([[False,  True,  True,  True],
        [False, False,  True,  True],
        [False, False, False,  True],
        [False, False, False, False]])

mask.bool()[interval]
 tensor([[False,  True,  True],
        [False, False,  True],
        [False, False, False]])

matrix.mask_fill(mask.bool()[interval], value)
 tensor([[0.0535,   -inf,   -inf,   -inf],
        [0.8037, 0.2223,   -inf,   -inf],
        [0.0512, 0.8887, 0.2361,   -inf],
        [0.7751, 0.3303, 0.5652, 0.8721]])


## Prototype Multi Head Attention
An implementation of Multi Head Attention
Multiple heads train themself on parts of the input vector. They represent a part of the qkv matrices and are combined in the end by linear layer to context vector

In [29]:
class MultiHeadAttention(nn.Module):

    def __init__(self, input_dim, output_dim, context_length, dropout, num_head, qkv_bias=False):
        super().__init__()
        
        # Check if output dimension is dividable by attention head number
        assert output_dim % num_head == 0, "Output dimension must be dividable by head_num"

        print("- input_dim =", input_dim)
        print("- output_dim =", output_dim)
        print("- num_heads =", num_head)
        print("- head_dim =", output_dim)
        self.output_dim = output_dim
        self.num_heads = num_head
        self.head_dim = output_dim // num_head

        # Init weight matrices for input to qkv projection (full not yet separated into head parts)
        print(f"- generating nn.Linear({input_dim}, {output_dim}) weights for query, key and value")
        self.W_query = nn.Linear(input_dim, output_dim, qkv_bias)
        self.W_key = nn.Linear(input_dim, output_dim, qkv_bias)
        self.W_value = nn.Linear(input_dim, output_dim, qkv_bias)

        print(f"- generating causal diagonal mask torch.triu(torch.ones({context_length}, {context_length}), diagonal=1) for causal masking of attn_scores")        
        self.register_buffer("causal_mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))

        print(f"- generating dropout nn.Dropout({dropout}) for random dropout of attn_weights")        
        self.dropout = nn.Dropout(dropout)

        print(f"- generating optional nn.Linear({output_dim}, {output_dim}) weights for final context_vector projection")
        self.out_proj = nn.Linear(output_dim, output_dim) # Linear layer for combination of head outputs
        
        
    def forward(self, x):
        
        # local variables for input shape
        batch_size, context_length, input_dim = x.shape

        # Projection of the input to qkv using weight matrices and broadcasting for all attention heads together
        # -> shape: batch_size, context_length, output_dim
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        # Implicitly splitting matrix by adding head_num dimension
        # Unrol last dim: (batch_size, context_length, output_dim) -> (batch_size, context_length, head_num, head_dim)
        # Example ( 8, 4, 256 ) -> ( 8, 4, 8, 32 ) for num_heads = 8 and head_dim = 32
        queries = queries.view(batch_size, context_length, self.num_heads, self.head_dim)
        keys = keys.view(batch_size, context_length, self.num_heads, self.head_dim)
        values = values.view(batch_size, context_length, self.num_heads, self.head_dim)

        # Transpose to use for query comparison - move num_head to front
        # (batch_size, context_length, num_heads, head_dim) -> (batch_size, num_head, context_length, head_dim)
        queries = queries.transpose(1,2)
        keys = keys.transpose(1,2)
        values = values.transpose(1,2)
        
        # Compute scaled dot_production attention
        attn_scores = queries @ keys.transpose(2,3) # Dot product for each head

        # Causal Masking
        mask_bool = self.causal_mask.bool()[:context_length, :context_length]
        attn_scores.masked_fill(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim =-1)
        attn_weights = self.dropout(attn_weights)

        # build context vector switch back num_heads and context_length, combining heads
        # self.output_dim = self.num_heads * self.head_dim
        context_vec = (attn_weights @ values).transpose(1,2)
        context_vec = context_vec.contiguous().view(batch_size, context_length, self.output_dim)
        context_vec = self.out_proj(context_vec) # optional projection by Linear layer

        return context_vec

## Test Run

In [30]:
print("\n\n------- initializing Multi Head Attention ----------------\n")
mha = MultiHeadAttention(input_dim=256, output_dim=128, context_length=4, dropout=0.2, num_head=8 )

print("\n\n------- generating input ---------------------------------\n")
batch = get_input_embedding()

print("\n\n------- performing multi head attention ------------------\n")
context_vector = mha(batch)

print("\n Shape of context_vector:\n", context_vector.shape)




------- initializing Multi Head Attention ----------------

- input_dim = 256
- output_dim = 128
- num_heads = 8
- head_dim = 128
- generating nn.Linear(256, 128) weights for query, key and value
- generating causal diagonal mask torch.triu(torch.ones(4, 4), diagonal=1) for causal masking of attn_scores
- generating dropout nn.Dropout(0.2) for random dropout of attn_weights
- generating optional nn.Linear(128, 128) weights for final context_vector projection


------- generating input ---------------------------------

Displaying first row of batch

Input x:
 tensor([ 2627,   257, 34538,   286])  became a beacon of

Target y:
 tensor([  257, 34538,   286, 11044])  a beacon of innovation

 embeddings for x:
 tensor([[ 0.8057, -0.6795, -0.4887,  ..., -0.3833, -0.0600, -0.3964],
        [ 0.4135, -0.8302,  0.8324,  ..., -1.9796, -0.2031, -0.8872],
        [ 0.8406,  0.0906, -0.0602,  ...,  0.1790, -1.3455, -0.3669],
        [-1.3623, -0.9489,  0.2503,  ..., -1.1597,  0.1812,  1.2424]],
