# Transformer TP
In this TP we will implement the transformer model (an architecture based on LLama2). The squeleton of the code is provided and you should complete it (mostly self attention mechanism).

In [1]:
import time 
from time import gmtime, strftime
import math
import shutil
import os



# neural network utilities
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader

# dataset an tokenization
from datasets import load_dataset
from transformers import AutoTokenizer

from rope_embedding import RoPE


W0207 20:24:51.520000 4510 site-packages/torch/distributed/elastic/multiprocessing/redirects.py:29] NOTE: Redirects are currently not supported in Windows or MacOs.


## Self Attention with no mask

Let consider $Q, K, V \in R^{B\times H \times L \times N}$ for each batch element (B) for each heads (H) we want to compute $$A = \frac{Q_{bh}.K_{bh}^{t}}{\sqrt{N}}$$

Attention weights are defined by $$ 
softmax(A_i) = \begin{pmatrix}
\frac{e^{A_{i1}}}{\sum\limits_{j=1}^L  e^{A_{ij}}} & 
\frac{e^{A_{i2}}}{\sum\limits_{j=1}^L  e^{A_{ij}}} & 
\dots &
\frac{e^{A_{iL}}}{\sum\limits_{j=1}^L  e^{A_{ij}}}
\end{pmatrix}$$ 
And $$ 
Softmax(A) = \begin{pmatrix}
\frac{e^{A_{11}}}{\sum\limits_{j=1}^L  e^{A_{1j}}} & 
\frac{e^{A_{12}}}{\sum\limits_{j=1}^L  e^{A_{1j}}} & 
\dots &
\frac{e^{A_{1L}}}{\sum\limits_{j=1}^L  e^{A_{1j}}} \\
\frac{e^{A_{21}}}{\sum\limits_{j=1}^L  e^{A_{2j}}} & 
\frac{e^{A_{22}}}{\sum\limits_{j=1}^L  e^{A_{2j}}} & 
\dots &
\frac{e^{A_{2l}}}{\sum\limits_{j=1}^L  e^{A_{2j}}} \\
\vdots & \vdots & \ddots & \vdots \\
\frac{e^{A_{L1}}}{\sum\limits_{j=1}^L  e^{A_{Lj}}} & 
\frac{e^{A_{L2}}}{\sum\limits_{j=1}^L  e^{A_{Lj}}} & 
\dots &
\frac{e^{A_{LL}}}{\sum\limits_{j=1}^L  e^{A_{Lj}}}
\end{pmatrix}$$ 


Notice that the function softmax can be used to compute this matrix [`torch.softmax`](https://docs.pytorch.org/docs/stable/generated/torch.softmax.html#torch.softmax).


## Masked Attention
In many cases, certains tokens will not compute attention with later tokens (padding or decoding approaches). Typically in decoder only architectures we would have the attention weights matrix only considering lower diagonal values :

$$ 
Softmax(A) = \begin{pmatrix}
\frac{e^{A_{11}}}{\sum\limits_{j=1}^1  e^{A_{1j}}} & 
0 & 
\dots & 0 &
0 \\
\frac{e^{A_{21}}}{\sum\limits_{j=1}^2  e^{A_{2j}}} & 
\frac{e^{A_{22}}}{\sum\limits_{j=1}^2  e^{A_{2j}}} & 
\dots & 0 &
0 \\
\vdots & \vdots & \ddots & \vdots \\
\frac{e^{A_{{L-1}1}}}{\sum\limits_{j=1}^{L-1}  e^{A_{L-1j}}} & 
\frac{e^{A_{{L-1}2}}}{\sum\limits_{j=1}^{L-1}  e^{A_{L-1j}}} & 
\dots &
\frac{e^{A_{L-1L-1}}}{\sum\limits_{j=1}^{L-1}  e^{A_{L-1j}}}& 0 \\
\frac{e^{A_{L1}}}{\sum\limits_{j=1}^L  e^{A_{Lj}}} & 
\frac{e^{A_{L2}}}{\sum\limits_{j=1}^L  e^{A_{Lj}}} & 
\dots &
\frac{e^{A_{LL-1}}}{\sum\limits_{j=1}^L  e^{A_{Lj}}} & 
\frac{e^{A_{LL}}}{\sum\limits_{j=1}^L  e^{A_{Lj}}} \\
\end{pmatrix}$$ 

A simple implementation would consist in setting the upper diagonal of A to $-\infty$ and apply the softmax. We will thus consider a matrix name $M$ (or `mask` in the code) having upper diagonal element to $-\infty$ : 
$$ M = 
\begin{pmatrix}
0& 
-\infty & 
\dots & -\infty&
-\infty \\
0 & 
0 & 
\dots & -\infty &
-\infty \\
\vdots & \vdots & \ddots & \vdots \\
0 & 
0 & 
\dots &
0& -\infty \\
0 & 
0 & 
\dots &
0 & 
0\\
\end{pmatrix}
$$

And apply softmax on $A + M$

In [2]:
def multi_head_attention(
    q : torch.FloatTensor,
    k : torch.FloatTensor,
    v : torch.FloatTensor,
    mask : torch.FloatTensor
):
    """
    Given Q,K,V and the attention mask compute cross attention

    The shape of the matrix are the following:
    B: The batch size (number of examples)
    Q: The size of the decoder sequence
    K: The size of the encoder sequence
    H: Then number of attention heads
    N: The embedding size 
    
    Parameters:
        q : torch.FloatTensor
            The query matrix of shape BxHxQxN
        k : torch.FloatTensor
            The key matrix of shape BxHxKxN
        v : torch.FloatTensor
            The value matrix of shape BxHxKxN
        mask : torch.FloatTensor
            The attention mask of shape BxQxK,
            the masked elements are set to -inf 
            else elements are set to 0
    Return : (torch.Tensor, torch.Tensor)
        return two tensor the first containing 
        the attention weights (QK^t) and the second 
        the result of the attention
    """
    # Get embedding dimension
    N = q.size(-1)

    # Step 1: Compute scaled dot-product attention scores
    # (B, H, Q, N) x (B, H, N, K) → (B, H, Q, K)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(N)

    # Step 2: Apply attention mask
    # mask shape: (B, Q, K) → broadcast to (B, 1, Q, K)
    scores = scores + mask.unsqueeze(1)

    # Step 3: Softmax over the last dimension (K)
    attention_weights = torch.softmax(scores, dim=-1)

    # Step 4: Compute weighted sum of values
    # (B, H, Q, K) x (B, H, K, N) → (B, H, Q, N)
    attention_output = torch.matmul(attention_weights, v)

    return attention_weights, attention_output


In [3]:
q, k, v = (torch.ones(1, 1, 3, 4), torch.ones(1, 1, 3, 4), torch.ones(1, 1, 3, 4))
mask = torch.triu(-q.new_ones(1, 3, 3) *  torch.inf, diagonal =  1)
attention_weigths, attention_output = multi_head_attention(q, k, v, mask)

assert(torch.all(attention_weigths == torch.Tensor([
          [[[1., 0., 0.],
            [1/2, 1/2, 0.],
            [1/3, 1/3, 1/3]]]])))

##  Create the attention module

In pytorch module are block of neural networks that inherit from the class `nn.Module`.  In the following you willbe asked to complete the attention code. The attention module compute the Query $Q$, the Keys $K$ and Values $V$ for each heads and compute attention. The forward methods is the method implementing these operations. In this lab we use a mult-head attention, meaning that we compute attention for each heads, that can be done with a for loop or using batched operation (it is up to you). Also as we follow the Llama 2 architecture, we use for positional embeding the RoPe positonal encoding that is applied on the two matrix $Q$ and $K$. The forward function should compute for an input $X \in \mathbb{R}^{L \times N}$ (in the decoder only case) the following:

1. $Q_h = RoPE(W^q_{h}X^{\intercal})$ (for all heads h) with $W^q_h \in \mathbb{R}^{N \times N}$
2. $K_h = RoPE(W^k_{h}X^{\intercal})$ (for all heads h)
3. $V_h = W^v_{h}X^{\intercal}$ (for all heads h)
4. $A_h =  Attention(Q_h, K_h, V_h)$ (for all heads h)
5. $A$ = [A_1, A_2, \dots A_H] (the concatenation along the last dimension)
6. $O = W^oA^{\intercal}$ with $w^{o} \in \mathbb{R}^{N \times NH}$ (NH being the scalar multplication of N and H)

Notice that Q, K and V can be computed without a loop using $W_{q}$ (respectively for $W_k$ and $W_v$) .


In [4]:
class RoPEAttentionModule(nn.Module):
    """ A cross/self attention pytorch module.
    
    """
    def __init__(self, input_dim, output_dim, num_heads=1):
        super().__init__()
        
        self.Wq, self.Wk, self.Wv  =\
            (nn.Linear(input_dim, output_dim * num_heads) for _ in range(3))

        self.output_dim = output_dim
        self.num_heads = num_heads
        self.Wo = nn.Linear(output_dim * num_heads, input_dim)
        self.rope_func = RoPE(output_dim)
    
    def forward(
            self,
            x : torch.Tensor,
            attention_mask : torch.BoolTensor = None, 
            y : torch.Tensor = None,
            decoder_mask = True,
            k_cache = None,
            v_cache = None,
            start_cache = 0
        ):
        ''' 
            Parameters:
                x : torch.Tensor
                    The input of the Attention module
                    used at least for K, V computation
                    (for Q if decoder only)
                attention_mask : torch.BoolTensor
                    The mask for attention
                y : torch.Tensor or None
                    The query input in the case of
                    cross-attention
        '''

        B, L, _ = x.shape

        if y is None:
            y = x

        # Linear projections
        Q = self.Wq(y)
        K = self.Wk(x)
        V = self.Wv(x)

        # Reshape to multi-head format
        Q = Q.view(B, L, self.num_heads, self.output_dim).transpose(1, 2)
        K = K.view(B, L, self.num_heads, self.output_dim).transpose(1, 2)
        V = V.view(B, L, self.num_heads, self.output_dim).transpose(1, 2)

        # Apply RoPE to Q and K
        Q = self.rope_func(Q, start_cache)
        K = self.rope_func(K, start_cache)

        # Decoder causal mask
        if decoder_mask and attention_mask is None:
            attention_mask = torch.triu(
                torch.full((L, L), float("-inf"), device=x.device),
                diagonal=1
            ).unsqueeze(0).repeat(B, 1, 1)

        # Attention
        attention_weigths, attention_output = multi_head_attention(
            Q, K, V, attention_mask
        )

        # Concatenate heads
        attention_output = attention_output.transpose(1, 2).contiguous()
        attention_output = attention_output.view(
            B, L, self.num_heads * self.output_dim
        )

        # Output projection
        output = self.Wo(attention_output)

        return attention_weigths, output


## Trasnsformer FeedForward

The FeedForward Network is defined using Llama 2 architecture as following for $X \in \mathbb{R^{L\times N}}$ :


1. $G = SiLU(W^gX^\intercal)$ with $W^g \in \mathbb{R}^{M \times N}$ (see SiLU in pytorch [documentation](https://docs.pytorch.org/docs/stable/generated/torch.nn.SiLU.html))
2. $U = W^u X^\intercal$ with $W^u  \in \mathbb{R}^{M \times N}$
3. $I = G \odot I$ (Hadamard multiplication)
4. $O = W^{o} I^{\intercal}$ with $W_{o} \in \mathbb{R}^{N \times M}$






In [5]:
class TransformerFeedForward(nn.Module):
    def __init__(self, embed_size, intermediate_size):
        super().__init__()
        self.gate_proj = nn.Linear(embed_size, intermediate_size)
        self.up_proj = nn.Linear(embed_size, intermediate_size)
        self.down_proj = nn.Linear(intermediate_size, embed_size)
        self.gate_func = nn.SiLU()
        
    def forward(self, x):
        G = self.gate_func(self.gate_proj(x))
        U = self.up_proj(x)
        I = G * U
        O = self.down_proj(I)
        return O

## Transformer block
We can now create the block of the 

Let consider $x_0 \in \mathbb{R}^{B \times L \times N}$, the decoder block is given by : 

* $x_1 = LN_1(x_0)$ (apply layer norm)
* $x_2 = Attention(x_1)$ (apply attention)
* $x_3 = x_0 + x_2$ (adding residual)
* $x_4 = LN_2(x_3)$
* $y = FF(x_4)$


In [6]:
class TransformerDecoderBlock(nn.Module):
    def __init__(self, embed_size, intermediate_size, num_heads):
        super().__init__()
        self.attention_module =\
            RoPEAttentionModule(embed_size, embed_size, num_heads=num_heads)
        self.feed_forward = TransformerFeedForward(embed_size, intermediate_size)
        
        self.attention_layer_norm = nn.RMSNorm(embed_size)
        self.feed_forward_layer_norm = nn.RMSNorm(embed_size)
        
    def forward(
        self,
        x,
        attention_mask = None,
        k_cache=None,
        v_cache=None,
        start_cache=0,
        layer=0,
    ):
        
        x_norm = self.attention_layer_norm(x)
        _, attn_out = self.attention_module(
            x=x_norm,
            attention_mask=attention_mask,
            decoder_mask=True,
            k_cache=k_cache,
            v_cache=v_cache,
            start_cache=start_cache
        )
        x_res = x + attn_out
        
        x_ffn_norm = self.feed_forward_layer_norm(x_res)
        ff_out = self.feed_forward(x_ffn_norm)
        
        output = x_res + ff_out
        
        return output


## The Model

The model contain a stack of block on input embedding, it applies the different transformation and should return both the output embeddings and the logits (size of the vocabulary).

In [7]:
class TransformerDecoder(nn.Module):
    def __init__(self, vocabulary_size, embed_size, intermediate_size, num_heads, hidden_layers=5, **kwargs):
        super().__init__(**kwargs)
        self.nh = num_heads
        self.nhl = hidden_layers
        self.es = embed_size 

        self.wte = nn.Embedding(vocabulary_size, embed_size)
        self.wpe = nn.Embedding(512, embed_size)

        self.drop = nn.Dropout(.1)
        self.blocks = nn.ModuleList(
            [TransformerDecoderBlock(embed_size, intermediate_size, num_heads) for _ in range(hidden_layers)]
        )
        self.ln_f = nn.LayerNorm(embed_size, eps=1e-3)
        self.lm_head = nn.Linear(embed_size, vocabulary_size)

        self.position = torch.arange(512).unsqueeze(0)

    def forward(
            self,
            input_ids,
            attention_mask = None,
            k_cache = None,
            v_cache = None,
            start_cache = 0
        ):
        
        if input_ids.device != self.position.device:
            self.position = self.position.to(input_ids.device)
        
        B, L = input_ids.shape

        pos_ids = self.position[:, :L]
        token_emb = self.wte(input_ids)
        pos_emb = self.wpe(pos_ids)

        x = token_emb + pos_emb
        x = self.drop(x)

        for layer_id, block in enumerate(self.blocks):
            x = block(
                x,
                attention_mask=attention_mask,
                k_cache=k_cache,
                v_cache=v_cache,
                start_cache=start_cache,
                layer=layer_id
            )

        output_embed = self.ln_f(x)
        output_lm = self.lm_head(output_embed)

        return output_embed, output_lm


##  Training the transformer

### I.Dataset 
For simplicity we will use a small dataset name TinyStories, we can load it using the huggingface Datasets library as following

In [8]:

tinystories_dataset = load_dataset("roneneldan/TinyStories")
training_set = tinystories_dataset['train']
print(training_set[0])

{'text': 'One day, a little girl named Lily found a needle in her room. She knew it was difficult to play with it because it was sharp. Lily wanted to share the needle with her mom, so she could sew a button on her shirt.\n\nLily went to her mom and said, "Mom, I found this needle. Can you share it with me and sew my shirt?" Her mom smiled and said, "Yes, Lily, we can share the needle and fix your shirt."\n\nTogether, they shared the needle and sewed the button on Lily\'s shirt. It was not difficult for them because they were sharing and helping each other. After they finished, Lily thanked her mom for sharing the needle and fixing her shirt. They both felt happy because they had shared and worked together.'}


### The Dataloader

We can define the batch_size default we will use, using the DataLoader object. Here we consider that the training machine will have enough RAM for a batch of 16 samples.

In [9]:
training_dl = DataLoader(training_set, batch_size=16, shuffle=True) 

### III. Tokenizer

In this exercice we will consider the LLama Tokenizer, however, you can train your own tokenizer if you prefer (see lab 2)

In [11]:
os.environ["TOKENIZERS_PARALLELISM"] = "false"

model_hg_id = "mistralai/Mistral-7B-v0.3"

tokenizer = AutoTokenizer.from_pretrained(model_hg_id)
tokenizer.pad_token = '<pad>'
vocabulary_size = tokenizer.vocab_size
print(f"Vocabulary size: {vocabulary_size}")

Vocabulary size: 32768


### IV. Create the model and the optimizer

In [13]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model = TransformerDecoder(vocabulary_size, 128, 256, 4, 2)
model = model.train()
model = model.to(device)
loss_function = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=6e-4)


In [14]:
model

TransformerDecoder(
  (wte): Embedding(32768, 128)
  (wpe): Embedding(512, 128)
  (drop): Dropout(p=0.1, inplace=False)
  (blocks): ModuleList(
    (0-1): 2 x TransformerDecoderBlock(
      (attention_module): RoPEAttentionModule(
        (Wq): Linear(in_features=128, out_features=512, bias=True)
        (Wk): Linear(in_features=128, out_features=512, bias=True)
        (Wv): Linear(in_features=128, out_features=512, bias=True)
        (Wo): Linear(in_features=512, out_features=128, bias=True)
      )
      (feed_forward): TransformerFeedForward(
        (gate_proj): Linear(in_features=128, out_features=256, bias=True)
        (up_proj): Linear(in_features=128, out_features=256, bias=True)
        (down_proj): Linear(in_features=256, out_features=128, bias=True)
        (gate_func): SiLU()
      )
      (attention_layer_norm): RMSNorm((128,), eps=None, elementwise_affine=True)
      (feed_forward_layer_norm): RMSNorm((128,), eps=None, elementwise_affine=True)
    )
  )
  (ln_f): LayerN

### VI. Training the model

In [None]:
model = model.train()
avg_loss = []
start_time = time.time()

for i, data in enumerate(training_dl):
    res = tokenizer(data['text'], return_tensors="pt", padding=True, padding_side='right', max_length=128, truncation=True)
    x = res.input_ids[:, :-1] + 0
    y = res.input_ids[:, 1:] + 0
    # attention_mask = (res.attention_mask == 0)
    optimizer.zero_grad()
    oe, oy = model(x.to(device))
    loss = loss_function(oy.to(device).view(-1, vocabulary_size), y.to(device).view(-1))
    loss.backward()
    loss_value = loss.item()
    avg_loss.append(loss_value)
    if( i%500 == 0):
        elapsed_time = time.time() - start_time
        remaining_time = int((elapsed_time/(i+1)) * (len(training_dl) - i))
        loop_sec = ((i+1)/elapsed_time)
        print(f"The loss at iteration {i+1} is {sum(avg_loss)/len(avg_loss):3.4f} remaining_time is {strftime('%H:%M:%S', gmtime(remaining_time))}s ({loop_sec:4.1f} it/s)", flush=True)
        avg_loss = []
    if( i%1000 == 0):
        torch.save(optimizer.state_dict() ,"optimizer_state_dict_llama_mini.pth.temp")
        torch.save(model.state_dict() ,"transformer_state_dict_llama_mini.pth.temp")
        shutil.copyfile("optimizer_state_dict_llama_mini.pth.temp", "optimizer_state_dict_llama_mini.pth")
        shutil.copyfile("transformer_state_dict_llama_mini.pth.temp", "transformer_state_dict_llama_mini.pth")
    optimizer.step()

The loss at iteration 1 is 10.5585 remaining_time is 16:09:24s ( 2.3 it/s)
The loss at iteration 501 is 4.8572 remaining_time is 06:27:13s ( 5.7 it/s)
The loss at iteration 1001 is 3.5659 remaining_time is 06:20:07s ( 5.8 it/s)
The loss at iteration 1501 is 3.2723 remaining_time is 06:15:47s ( 5.8 it/s)
The loss at iteration 2001 is 3.0985 remaining_time is 06:14:12s ( 5.8 it/s)
The loss at iteration 2501 is 2.9945 remaining_time is 06:13:01s ( 5.8 it/s)
The loss at iteration 3001 is 2.8826 remaining_time is 06:12:00s ( 5.8 it/s)
The loss at iteration 3501 is 2.8051 remaining_time is 06:10:26s ( 5.8 it/s)
The loss at iteration 4001 is 2.7614 remaining_time is 06:08:14s ( 5.8 it/s)
The loss at iteration 4501 is 2.7096 remaining_time is 06:06:13s ( 5.8 it/s)
The loss at iteration 5001 is 2.6688 remaining_time is 06:04:21s ( 5.8 it/s)
The loss at iteration 5501 is 2.6305 remaining_time is 06:03:20s ( 5.8 it/s)
The loss at iteration 6001 is 2.5898 remaining_time is 06:01:33s ( 5.8 it/s)
Th

## Decoding time

There is different way to decode, the simplest one is the greedy decoding method, the principle is to loop with the new token produced each time (complete `nonefficient_greedy_decoding`) and choosing the token the most likely. We can also store Keys and Values as shown in course that in practice only accellerate the generation but not change it (complete `greedy_decoding`). And finally we can also sampling new tokens (complete `sampling_decoding` ).

In [17]:
def nonefficient_greedy_decoding(model, x, max_new_tokens=64):
    previous_gen = x

    for _ in range(max_new_tokens):
        with torch.no_grad():
            _, logits = model(previous_gen)
            next_token_logits = logits[:, -1, :]
            next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

        previous_gen = torch.cat([previous_gen, next_token], dim=1)

    return previous_gen


def greedy_decoding(model, x, max_new_tokens=64, max_cache_size=512, tokenizer=None):
    model.eval()

    k_cache = [None] * model.nhl
    v_cache = [None] * model.nhl
    generated = x
    start_cache = 0

    for _ in range(max_new_tokens):
        with torch.no_grad():
            _, logits = model(
                generated[:, -1:],
                k_cache=k_cache,
                v_cache=v_cache,
                start_cache=start_cache
            )

            next_token_logits = logits[:, -1, :]
            next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

        generated = torch.cat([generated, next_token], dim=1)
        start_cache += 1

        if generated.shape[1] >= max_cache_size:
            break

    return generated

def sampling_decoding(model, x, max_new_tokens=64, max_cache_size=512, tokenizer=None, temperature=.7):
    model.eval()

    k_cache = [None] * model.nhl
    v_cache = [None] * model.nhl
    generated = x
    start_cache = 0

    for _ in range(max_new_tokens):
        with torch.no_grad():
            _, logits = model(
                generated[:, -1:],
                k_cache=k_cache,
                v_cache=v_cache,
                start_cache=start_cache
            )

            logits = logits[:, -1, :] / temperature
            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

        generated = torch.cat([generated, next_token], dim=1)
        start_cache += 1

        if generated.shape[1] >= max_cache_size:
            break

    return generated

In [21]:
model = TransformerDecoder(vocabulary_size, 128, 256, 4, 2)
model.load_state_dict(torch.load("transformer_state_dict_llama_mini.pth", map_location="cpu"))
model = model.eval()

prompt = "Alice"
tokenized = tokenizer(
    prompt,
    return_tensors="pt",
    add_special_tokens=True
)

input_ids = tokenized.input_ids

with torch.no_grad():
    out_nonefficient = nonefficient_greedy_decoding(
        model, input_ids, max_new_tokens=64
    )

    out_greedy = greedy_decoding(
        model, input_ids, max_new_tokens=64, tokenizer=tokenizer
    )

    out_sampling = sampling_decoding(
        model, input_ids, max_new_tokens=64, tokenizer=tokenizer, temperature=0.8
    )


In [22]:
print("=" * 80)
print("PROMPT:")
print(prompt)

print("\n" + "=" * 80)
print("Non-efficient Greedy Decoding:")
print(tokenizer.decode(out_nonefficient[0], skip_special_tokens=True))

print("\n" + "=" * 80)
print("Greedy Decoding with KV cache:")
print(tokenizer.decode(out_greedy[0], skip_special_tokens=True))

print("\n" + "=" * 80)
print("Sampling Decoding (temperature = 0.8):")
print(tokenizer.decode(out_sampling[0], skip_special_tokens=True))
print("=" * 80)


PROMPT:
Alice

Non-efficient Greedy Decoding:
Alice was a little girl who loved to play with her toys. One day, she was playing with her dolls when she saw a big, scary monster. She was scared and ran away.

Alice was very scared and ran away. She ran after it, but it was too fast. She ran after it

Greedy Decoding with KV cache:
Alice One Once Once Once Once Once Once Once Once Once Once Once Once Once Once Once Once Once Once Once Once Once Once Once Once Once Once Once Once Once Once Once Once Once Once Once Once Once Once Once Once Once Once Once Once Once Once Once Once Once Once Once Once Once Once Once Once Once Once Once Once Once Once Once

Sampling Decoding (temperature = 0.8):
Alice Alice One The Once One Once One John John John John John John John John John John One Once Once Little Little Little Boy One Jack One Once Once Once Once Maggie at Once One Lily Once It Onceatted Once " Once Jim John John John John John John Once Once One Once Once One Once Once John Ted Ted Tede

## Evaluation ?

It remains difficult to evaluate the model. You can here propose an evalaution based on perplexity using an other model, or comparing results to ground truth using the test (here only validation) set, but it will not be totally informative !!!

In [23]:
validation_dataset = tinystories_dataset['validation']

In [None]:
import math
import torch
from torch.utils.data import DataLoader

model.eval()
model.to("cpu")

validation_dl = DataLoader(validation_dataset, batch_size=8, shuffle=False)

total_loss = 0.0
total_tokens = 0

with torch.no_grad():
    for batch in validation_dl:
        enc = tokenizer(
            batch["text"],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=128
        )

        input_ids = enc.input_ids[:, :-1]
        target_ids = enc.input_ids[:, 1:]

        _, logits = model(input_ids)

        loss = loss_function(
            logits.reshape(-1, vocabulary_size),
            target_ids.reshape(-1)
        )

        num_tokens = target_ids.numel()
        total_loss += loss.item() * num_tokens
        total_tokens += num_tokens

avg_loss = total_loss / total_tokens
perplexity = math.exp(avg_loss)

print(f"Validation loss: {avg_loss:.4f}")
print(f"Validation perplexity: {perplexity:.2f}")

Validation loss: 1.8354
Validation perplexity: 6.27


## Perplexity-based Evaluation Results

We evaluate the language model on the **validation split** of the TinyStories dataset using **cross-entropy loss** and **perplexity**.

### Validation Results

- **Validation Loss:** 1.8354  
- **Validation Perplexity:** 6.27

### Interpretation

A perplexity of **6.27** indicates that, on average, the model considers approximately **6 possible tokens** at each prediction step.  
This reflects a **good predictive capability**, especially considering that:

- the model was trained **from scratch**,
- the architecture is **small-scale** (2 layers, 128 embedding size),
- the dataset contains **simple but diverse narratives**.

The relatively low validation loss confirms that the model has successfully learned:
- basic syntactic structures,
- common narrative patterns,
- and short-term dependencies present in children stories.

### Discussion and Limitations

While perplexity provides a useful quantitative evaluation, it does not fully capture:
- long-range coherence,
- factual consistency,
- or generation diversity.

This is illustrated by the decoding experiments, where greedy decoding tends to produce repetitive outputs, while sampling introduces more diversity at the cost of coherence.

Overall, these results validate the correctness of the implementation and the effectiveness of the training procedure.
