# Transformer TP
In this TP (practical session) we will implement a transformer model whose architecture is based on Llama 2. The skeleton of the code is provided and you should complete it (mostly the self-attention mechanism).

In [29]:
import time
from time import gmtime, strftime
import math
import shutil
import os



# neural network utilities
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader

# dataset and tokenization
from datasets import load_dataset
from transformers import AutoTokenizer

from rope_embedding import RoPE


## Self Attention with no mask

Consider $Q, K, V \in \mathbb{R}^{B\times H \times L \times N}$. For each batch element $b$ and each head $h$ we want to compute $$A = \frac{Q_{bh}K_{bh}^{\intercal}}{\sqrt{N}}$$

The attention weights of row $i$ are defined by $$ 
softmax(A_i) = \begin{pmatrix}
\frac{e^{A_{i1}}}{\sum\limits_{j=1}^L  e^{A_{ij}}} & 
\frac{e^{A_{i2}}}{\sum\limits_{j=1}^L  e^{A_{ij}}} & 
\dots &
\frac{e^{A_{iL}}}{\sum\limits_{j=1}^L  e^{A_{ij}}}
\end{pmatrix}$$ 
and for the whole matrix by $$ 
Softmax(A) = \begin{pmatrix}
\frac{e^{A_{11}}}{\sum\limits_{j=1}^L  e^{A_{1j}}} & 
\frac{e^{A_{12}}}{\sum\limits_{j=1}^L  e^{A_{1j}}} & 
\dots &
\frac{e^{A_{1L}}}{\sum\limits_{j=1}^L  e^{A_{1j}}} \\
\frac{e^{A_{21}}}{\sum\limits_{j=1}^L  e^{A_{2j}}} & 
\frac{e^{A_{22}}}{\sum\limits_{j=1}^L  e^{A_{2j}}} & 
\dots &
\frac{e^{A_{2L}}}{\sum\limits_{j=1}^L  e^{A_{2j}}} \\
\vdots & \vdots & \ddots & \vdots \\
\frac{e^{A_{L1}}}{\sum\limits_{j=1}^L  e^{A_{Lj}}} & 
\frac{e^{A_{L2}}}{\sum\limits_{j=1}^L  e^{A_{Lj}}} & 
\dots &
\frac{e^{A_{LL}}}{\sum\limits_{j=1}^L  e^{A_{Lj}}}
\end{pmatrix}$$ 


Note that this matrix can be computed with the function [`torch.softmax`](https://docs.pytorch.org/docs/stable/generated/torch.softmax.html#torch.softmax), applied along the last dimension.
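
To make the two formulas above concrete, here is a minimal sketch for a single head, written in plain Python (no PyTorch) so that every arithmetic step is visible. The function name `attention_weights` and the toy values of `Q` and `K` are illustrative assumptions, not part of the lab code.

```python
import math

def attention_weights(Q, K):
    """Row-wise softmax of Q K^T / sqrt(N) for one head (lists of lists)."""
    n = len(Q[0])  # embedding size N
    weights = []
    for q in Q:
        # scores A_ij = <q_i, k_j> / sqrt(N)
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(n) for k in K]
        # subtracting the row maximum leaves the softmax unchanged and avoids overflow
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights

Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # toy queries, L = 3, N = 2
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # toy keys
W = attention_weights(Q, K)
for row in W:
    print([round(w, 3) for w in row])
```

Each row of `W` sums to 1, and a query gives the largest weight to the keys it is most aligned with.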


## Masked Attention
In many cases, certain tokens must not attend to other tokens (padding positions, or later tokens when decoding). Typically, in decoder-only architectures the attention weight matrix only has non-zero values on and below the diagonal:

$$ 
Softmax(A) = \begin{pmatrix}
\frac{e^{A_{11}}}{\sum\limits_{j=1}^1  e^{A_{1j}}} & 
0 & 
\dots & 0 &
0 \\
\frac{e^{A_{21}}}{\sum\limits_{j=1}^2  e^{A_{2j}}} & 
\frac{e^{A_{22}}}{\sum\limits_{j=1}^2  e^{A_{2j}}} & 
\dots & 0 &
0 \\
\vdots & \vdots & \ddots & \vdots & \vdots \\
\frac{e^{A_{L-1,1}}}{\sum\limits_{j=1}^{L-1}  e^{A_{L-1,j}}} & 
\frac{e^{A_{L-1,2}}}{\sum\limits_{j=1}^{L-1}  e^{A_{L-1,j}}} & 
\dots &
\frac{e^{A_{L-1,L-1}}}{\sum\limits_{j=1}^{L-1}  e^{A_{L-1,j}}}& 0 \\
\frac{e^{A_{L1}}}{\sum\limits_{j=1}^L  e^{A_{Lj}}} & 
\frac{e^{A_{L2}}}{\sum\limits_{j=1}^L  e^{A_{Lj}}} & 
\dots &
\frac{e^{A_{L,L-1}}}{\sum\limits_{j=1}^L  e^{A_{Lj}}} & 
\frac{e^{A_{LL}}}{\sum\limits_{j=1}^L  e^{A_{Lj}}} \\
\end{pmatrix}$$ 

A simple implementation consists in setting the entries of $A$ above the diagonal to $-\infty$ before applying the softmax, since $e^{-\infty} = 0$. We therefore use a matrix $M$ (`mask` in the code) whose entries above the diagonal are $-\infty$ and whose other entries are $0$:
$$ M = 
\begin{pmatrix}
0& 
-\infty & 
\dots & -\infty&
-\infty \\
0 & 
0 & 
\dots & -\infty &
-\infty \\
\vdots & \vdots & \ddots & \vdots & \vdots \\
0 & 
0 & 
\dots &
0& -\infty \\
0 & 
0 & 
\dots &
0 & 
0\\
\end{pmatrix}
$$

We then apply the softmax to $A + M$.
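
As a small worked example of the masking trick, the sketch below builds $M$ for $L = 3$ and applies a row-wise softmax to $A + M$ in plain Python. The helper names `causal_mask` and `softmax_row` are assumptions made for this illustration. The scores are constant and equal to 2, which is what all-ones queries and keys of size 4 produce ($4 / \sqrt{4}$).

```python
import math

def causal_mask(L):
    """M[i][j] = 0 if j <= i, -inf otherwise (entries above the diagonal)."""
    return [[0.0 if j <= i else -math.inf for j in range(L)] for i in range(L)]

def softmax_row(row):
    m = max(row)  # always finite: the diagonal entry is never masked
    exps = [math.exp(a - m) for a in row]  # math.exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]

L = 3
A = [[2.0] * L for _ in range(L)]  # constant scores
M = causal_mask(L)
weights = [softmax_row([a + m for a, m in zip(A[i], M[i])]) for i in range(L)]
for row in weights:
    print(row)
```

Row $i$ spreads its weight uniformly over the first $i$ positions, so the rows are $(1, 0, 0)$, $(1/2, 1/2, 0)$ and $(1/3, 1/3, 1/3)$.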

In [30]:
def multi_head_attention(
    q : torch.FloatTensor,
    k : torch.FloatTensor,
    v : torch.FloatTensor,
    mask : torch.FloatTensor
):
    """
    Given Q, K, V and the attention mask, compute scaled dot-product
    attention for every head (self- or cross-attention).

    The dimensions are the following:
    B: The batch size (number of examples)
    Q: The length of the query (decoder) sequence
    K: The length of the key (encoder) sequence
    H: The number of attention heads
    N: The embedding size per head

    Parameters:
        q : torch.FloatTensor
            The query matrix of shape BxHxQxN
        k : torch.FloatTensor
            The key matrix of shape BxHxKxN
        v : torch.FloatTensor
            The value matrix of shape BxHxKxN
        mask : torch.FloatTensor
            The attention mask of shape BxQxK,
            the masked elements are set to -inf,
            the other elements are set to 0
    Return : (torch.Tensor, torch.Tensor)
        two tensors: the first contains the attention
        weights softmax(QK^t / sqrt(N)), of shape BxHxQxK,
        the second the result of the attention, of shape BxHxQxN
    """
    B, H, Q, N = q.shape
    K = k.shape[2]

    # Compute attention scores
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(N)

    # Apply mask: B x Q x K -> B x 1 x Q x K so that it broadcasts over the heads
    if mask is not None:
        scores = scores + mask.unsqueeze(1)

    # Apply softmax
    attention_weights = torch.softmax(scores, dim=-1)

    # Apply attention to values
    output = torch.matmul(attention_weights, v)

    return attention_weights, output


In [31]:
q, k, v = (torch.ones(1, 1, 3, 4), torch.ones(1, 1, 3, 4), torch.ones(1, 1, 3, 4))
q

tensor([[[[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]]]])

In [32]:
q, k, v = (torch.ones(1, 1, 3, 4), torch.ones(1, 1, 3, 4), torch.ones(1, 1, 3, 4))
mask = torch.triu(-q.new_ones(1, 3, 3) * torch.inf, diagonal=1)
attention_weights, attention_output = multi_head_attention(q, k, v, mask)

# rows 1, 2, 3 attend uniformly to the first 1, 2, 3 positions
assert torch.allclose(attention_weights, torch.tensor([
          [[[1., 0., 0.],
            [1/2, 1/2, 0.],
            [1/3, 1/3, 1/3]]]]))

##  Create the attention module

In PyTorch, modules are neural-network building blocks that inherit from the class `nn.Module`. In the following you will be asked to complete the attention code. The attention module computes the queries $Q$, the keys $K$ and the values $V$ for each head and then computes the attention. The `forward` method implements these operations. In this lab we use multi-head attention, meaning that attention is computed for each head; this can be done with a for loop or with batched operations (it is up to you). Since we follow the Llama 2 architecture, the positional encoding is RoPE, which is applied to the two matrices $Q$ and $K$. For an input $X \in \mathbb{R}^{L \times N}$ (in the decoder-only case) the forward function should compute the following:

1. $Q_h = RoPE(W^q_{h}X^{\intercal})$ (for all heads h) with $W^q_h \in \mathbb{R}^{N \times N}$
2. $K_h = RoPE(W^k_{h}X^{\intercal})$ (for all heads h)
3. $V_h = W^v_{h}X^{\intercal}$ (for all heads h)
4. $A_h =  Attention(Q_h, K_h, V_h)$ (for all heads h)
5. $A = [A_1, A_2, \dots, A_H]$ (the concatenation along the last dimension)
6. $O = W^oA^{\intercal}$ with $W^{o} \in \mathbb{R}^{N \times NH}$ ($NH$ being the product of $N$ and $H$)

Notice that $Q$, $K$ and $V$ can be computed for all heads at once, without a loop, by stacking the per-head matrices into a single $W^q$ (respectively $W^k$ and $W^v$).
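
Steps 1 and 2 rely on RoPE, so the sketch below illustrates its key property on one vector: the dot product between a rotated query and a rotated key depends only on their relative position. This is a plain-Python illustration under stated assumptions: consecutive dimensions are paired and the base is 10000, as in the original RoPE formulation. The `RoPE` class imported from `rope_embedding` may pair dimensions differently, and the function `rope` below is not its API.

```python
import math

def rope(x, position, base=10000.0):
    """Rotate each pair (x[2i], x[2i+1]) by the angle position * base**(-2i/d)."""
    d = len(x)
    out = []
    for i in range(d // 2):
        theta = position * base ** (-2 * i / d)
        a, b = x[2 * i], x[2 * i + 1]
        out += [a * math.cos(theta) - b * math.sin(theta),
                a * math.sin(theta) + b * math.cos(theta)]
    return out

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

q = [0.3, -1.2, 0.7, 0.5]   # toy query vector
k = [1.1, 0.4, -0.6, 0.9]   # toy key vector

# Same relative offset (2 positions) at two different absolute positions
s1 = dot(rope(q, 5), rope(k, 3))
s2 = dot(rope(q, 12), rope(k, 10))
print(abs(s1 - s2) < 1e-9)  # True: the score only depends on the offset
```

Because each pair is only rotated, the norm of the vector is unchanged, which is why RoPE can be applied to $Q$ and $K$ without rescaling the attention scores.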


In [33]:

class RoPEAttentionModule(nn.Module):
    """ A cross/self attention pytorch module.

    """
    def __init__(self, input_dim, output_dim, num_heads=1):
        super().__init__()

        self.Wq, self.Wk, self.Wv  =\
            (nn.Linear(input_dim, output_dim * num_heads) for _ in range(3))

        self.output_dim = output_dim
        self.num_heads = num_heads
        self.Wo = nn.Linear(output_dim * num_heads, input_dim)
        self.rope_func = RoPE(output_dim)

    def forward(
            self,
            x : torch.Tensor,
            attention_mask : torch.BoolTensor = None,
            y : torch.Tensor = None,
            decoder_mask = True,
            k_cache = None,
            v_cache = None,
            start_cache = 0
        ):
        '''
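            Parameters:
                x : torch.Tensor
                    The input of the attention module, of shape
                    B x Lk x input_dim, used for the K and V computation
                    (and for Q in the decoder-only case)
                attention_mask : torch.BoolTensor or None
                    Boolean mask of shape B x Lq x Lk,
                    True where attention must be blocked
                y : torch.Tensor or None
                    The query input in the case of
                    cross-attention
                decoder_mask : bool
                    If True, also apply the causal mask
                k_cache, v_cache : torch.Tensor or None
                    Preallocated B x H x Lmax x D buffers for keys and values
                start_cache : int
                    Index in the cache where the new keys and values are written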
        '''
        B, Lk, _ = x.shape
        q_in = x if y is None else y
        Bq, Lq, _ = q_in.shape
        q = self.Wq(q_in)
        k = self.Wk(x)
        v = self.Wv(x)

        H = self.num_heads
        D = self.output_dim
        q = q.view(Bq, Lq, H, D).transpose(1, 2)
        k = k.view(B, Lk, H, D).transpose(1, 2)
        v = v.view(B, Lk, H, D).transpose(1, 2)

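        # NOTE: rope_func is called without a position offset. When decoding with a
        # cache (start_cache > 0), check that your RoPE implementation rotates the
        # new tokens by their absolute positions.
        q = self.rope_func(q)
        k = self.rope_func(k)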

        if k_cache is not None and v_cache is not None:
            cache_len = k_cache.size(2)
            end_cache = start_cache + Lk
            if end_cache > cache_len:
                raise ValueError("Cache is too small")

            k_cache[:, :, start_cache:end_cache, :] = k
            v_cache[:, :, start_cache:end_cache, :] = v
            k = k_cache[:, :, :end_cache, :]
            v = v_cache[:, :, :end_cache, :]
            Lk = k.size(2)

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(D)  # B x H x Lq x Lk


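        if decoder_mask:
            # With a KV cache Lq < Lk: query i sits at absolute position Lk - Lq + i
            # and may attend to keys 0..Lk - Lq + i, hence the diagonal offset.
            causal = torch.triu(
                torch.ones(Lq, Lk, device=scores.device, dtype=torch.bool),
                diagonal=max(Lk - Lq, 0) + 1
            )
            scores = scores.masked_fill(causal, float("-inf"))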

        if attention_mask is not None:
            scores = scores.masked_fill(attention_mask.unsqueeze(1), float("-inf"))

        attention_weights = torch.softmax(scores, dim=-1)
        attn = torch.matmul(attention_weights, v)

        attn = attn.transpose(1, 2).contiguous().view(Bq, Lq, H * D)  # B x Lq x (H*D)
        output = self.Wo(attn)  # B x Lq x input_dim

        return attention_weights, output

## Transformer FeedForward

Following the Llama 2 architecture, the feed-forward network is defined as follows for $X \in \mathbb{R}^{L\times N}$:


1. $G = SiLU(W^gX^\intercal)$ with $W^g \in \mathbb{R}^{M \times N}$ (see SiLU in pytorch [documentation](https://docs.pytorch.org/docs/stable/generated/torch.nn.SiLU.html))
2. $U = W^u X^\intercal$ with $W^u  \in \mathbb{R}^{M \times N}$
3. $I = G \odot U$ (Hadamard multiplication)
4. $O = W^{o} I^{\intercal}$ with $W^{o} \in \mathbb{R}^{N \times M}$
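
The four steps above can be sketched for a single token vector in plain Python. This is only an illustration: the helper names (`silu`, `matvec`, `gated_ffn`) and the toy weights are assumptions, and the bias terms that `nn.Linear` adds by default are left out.

```python
import math

def silu(z):
    return z / (1.0 + math.exp(-z))  # z * sigmoid(z)

def matvec(W, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def gated_ffn(x, W_g, W_u, W_o):
    g = [silu(z) for z in matvec(W_g, x)]   # step 1: G = SiLU(W^g x), size M
    u = matvec(W_u, x)                      # step 2: U = W^u x, size M
    i = [gi * ui for gi, ui in zip(g, u)]   # step 3: I = G * U, elementwise
    return matvec(W_o, i)                   # step 4: O = W^o I, size N

# toy sizes: N = 2, M = 3
W_g = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
W_u = [[0.5, 0.5], [1.0, -1.0], [0.0, 2.0]]
W_o = [[1.0, 0.0, 1.0], [0.0, 1.0, -1.0]]
x = [1.0, -1.0]
out = gated_ffn(x, W_g, W_u, W_o)
print(out)
```

The gate $G$ decides, per intermediate dimension, how much of the linear projection $U$ is passed on to the output projection.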






In [34]:
class TransformerFeedForward(nn.Module):
    def __init__(self, embed_size, intermediate_size):
        super().__init__()
        self.gate_proj = nn.Linear(embed_size, intermediate_size)
        self.up_proj = nn.Linear(embed_size, intermediate_size)
        self.down_proj = nn.Linear(intermediate_size, embed_size)
        self.gate_func = nn.SiLU()

    def forward(self, x):
        g = self.gate_func(self.gate_proj(x))
        u = self.up_proj(x)
        i = g * u
        o = self.down_proj(i)
        return o


## Transformer block
We can now create the block of the transformer decoder.

Consider $x_0 \in \mathbb{R}^{B \times L \times N}$; the decoder block is given by:

* $x_1 = LN_1(x_0)$ (apply layer norm)
* $x_2 = Attention(x_1)$ (apply attention)
* $x_3 = x_0 + x_2$ (adding residual)
* $x_4 = LN_2(x_3)$
* $y = FF(x_4)$
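
The data flow of the block can be sketched on one token vector in plain Python. The normalisation shown is RMSNorm, which is what the code of this lab uses for $LN_1$ and $LN_2$ (`nn.RMSNorm`). The value of `eps`, the function names and the identity stand-ins for the two sub-layers are assumptions made for this illustration.

```python
import math

def rms_norm(x, weight=None, eps=1e-6):
    """x / sqrt(mean(x^2) + eps), optionally scaled by a learned weight."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    weight = weight or [1.0] * len(x)
    return [w * v / rms for w, v in zip(weight, x)]

def decoder_block(x0, attention, feed_forward):
    x1 = rms_norm(x0)                      # x1 = LN_1(x0)
    x2 = attention(x1)                     # x2 = Attention(x1)
    x3 = [a + b for a, b in zip(x0, x2)]   # x3 = x0 + x2 (residual)
    x4 = rms_norm(x3)                      # x4 = LN_2(x3)
    return feed_forward(x4)                # y = FF(x4)

# Hypothetical stand-ins for the attention and feed-forward sub-layers
identity = lambda v: list(v)
y = decoder_block([3.0, 4.0], attention=identity, feed_forward=identity)
print(y)
```

With identity sub-layers the output is simply the input direction rescaled to a root mean square of about 1, which shows what the normalisation does on its own.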


In [35]:
class TransformerDecoderBlock(nn.Module):
    def __init__(self, embed_size, intermediate_size, num_heads):
        super().__init__()
        self.attention_module =\
            RoPEAttentionModule(embed_size, embed_size, num_heads=num_heads)
        self.feed_forward = TransformerFeedForward(embed_size, intermediate_size)

        self.attention_layer_norm = nn.RMSNorm(embed_size)
        self.feed_forward_layer_norm = nn.RMSNorm(embed_size)

    def forward(self, x, attention_mask=None, k_cache=None, v_cache=None, start_cache=0, layer=0,):
        x1 = self.attention_layer_norm(x)

        _, attn_out = self.attention_module(
            x1,
            attention_mask=attention_mask,
            y=None,
            decoder_mask=True,
            k_cache=k_cache,
            v_cache=v_cache,
            start_cache=start_cache,
        )

        x3 = x + attn_out  # residual
        x4 = self.feed_forward_layer_norm(x3)
        output = self.feed_forward(x4)

        return output


## The Model

The model applies a stack of blocks to the input embeddings and should return both the output embeddings and the logits (one score per vocabulary entry).

In [36]:
class TransformerDecoder(nn.Module):
    def __init__(self, vocabulary_size, embed_size, intermediate_size, num_heads, hidden_layers=5, **kwargs):
        super().__init__(**kwargs)
        self.nh = num_heads
        self.nhl = hidden_layers
        self.es = embed_size

        self.wte = nn.Embedding(vocabulary_size, embed_size)
        self.wpe = nn.Embedding(512, embed_size)


        self.drop = nn.Dropout(.1)
        self.blocks = nn.ModuleList([TransformerDecoderBlock(embed_size, intermediate_size, num_heads) for _ in range(hidden_layers)])
        self.ln_f = nn.LayerNorm(embed_size, eps=1e-3)
        self.lm_head = nn.Linear(embed_size, vocabulary_size)

        self.position = torch.arange(512).unsqueeze(0)

    def forward(
        self,
        input_ids,
        attention_mask=None,
        k_cache=None,
        v_cache=None,
        start_cache=0
    ):
        if input_ids.device != self.position.device:
            self.position = self.position.to(input_ids.device)

        B, L = input_ids.shape

        # token + position embeddings
        tok = self.wte(input_ids)                         # B x L x N
        # with a KV cache the new tokens start at absolute position start_cache
        pos = self.wpe(self.position[:, start_cache:start_cache + L])  # 1 x L x N
        x = self.drop(tok + pos)

        # blocks
        for layer, block in enumerate(self.blocks):
            layer_k = k_cache[layer] if isinstance(k_cache, (list, tuple)) else k_cache
            layer_v = v_cache[layer] if isinstance(v_cache, (list, tuple)) else v_cache

            x = block(
                x,
                attention_mask=attention_mask,
                k_cache=layer_k,
                v_cache=layer_v,
                start_cache=start_cache,
                layer=layer,
            )

        # final norm + LM head
        output_embed = self.ln_f(x)
        output_lm = self.lm_head(output_embed)

        return output_embed, output_lm


##  Training the transformer

### I. Dataset
For simplicity we will use a small dataset named TinyStories. We can load it with the Hugging Face Datasets library as follows:

In [37]:

tinystories_dataset = load_dataset("roneneldan/TinyStories")
training_set = tinystories_dataset['train']
print(training_set[0])

{'text': 'One day, a little girl named Lily found a needle in her room. She knew it was difficult to play with it because it was sharp. Lily wanted to share the needle with her mom, so she could sew a button on her shirt.\n\nLily went to her mom and said, "Mom, I found this needle. Can you share it with me and sew my shirt?" Her mom smiled and said, "Yes, Lily, we can share the needle and fix your shirt."\n\nTogether, they shared the needle and sewed the button on Lily\'s shirt. It was not difficult for them because they were sharing and helping each other. After they finished, Lily thanked her mom for sharing the needle and fixing her shirt. They both felt happy because they had shared and worked together.'}


### II. The DataLoader

The `DataLoader` object defines the batch size we will use. Here we assume that the training machine has enough memory for a batch of 16 samples.

In [38]:
training_dl = DataLoader(training_set, batch_size=16, shuffle=True)

### III. Tokenizer

In this exercise we use a pretrained Llama-style tokenizer (the code below loads the one published with `mistralai/Mistral-7B-v0.3`); however, you can train your own tokenizer if you prefer (see lab 2).

In [39]:


os.environ["TOKENIZERS_PARALLELISM"] = "false"

model_hg_id = "mistralai/Mistral-7B-v0.3"

tokenizer = AutoTokenizer.from_pretrained(model_hg_id)
tokenizer.pad_token = '<pad>'
vocabulary_size = tokenizer.vocab_size

### IV. Create the model and the optimizer

In [40]:
device = 'cuda'
model = TransformerDecoder(
    vocabulary_size=vocabulary_size,
    embed_size=384,
    intermediate_size=1536,
    num_heads=12,
    hidden_layers=8,
)

model = model.train()
model = model.to(device)
loss_function = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=6e-4)

In [41]:
model

TransformerDecoder(
  (wte): Embedding(32768, 384)
  (wpe): Embedding(512, 384)
  (drop): Dropout(p=0.1, inplace=False)
  (blocks): ModuleList(
    (0-7): 8 x TransformerDecoderBlock(
      (attention_module): RoPEAttentionModule(
        (Wq): Linear(in_features=384, out_features=4608, bias=True)
        (Wk): Linear(in_features=384, out_features=4608, bias=True)
        (Wv): Linear(in_features=384, out_features=4608, bias=True)
        (Wo): Linear(in_features=4608, out_features=384, bias=True)
      )
      (feed_forward): TransformerFeedForward(
        (gate_proj): Linear(in_features=384, out_features=1536, bias=True)
        (up_proj): Linear(in_features=384, out_features=1536, bias=True)
        (down_proj): Linear(in_features=1536, out_features=384, bias=True)
        (gate_func): SiLU()
      )
      (attention_layer_norm): RMSNorm((384,), eps=None, elementwise_affine=True)
      (feed_forward_layer_norm): RMSNorm((384,), eps=None, elementwise_affine=True)
    )
  )
  (ln_f):

### V. Training the model

In [None]:
model = model.train()
avg_loss = []
start_time = time.time()

for i, data in enumerate(training_dl):
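    # Tokenize the batch: right-pad to the longest story, truncate to 128 tokens
    res = tokenizer(data['text'], return_tensors="pt", padding=True, padding_side='right', max_length=128, truncation=True)
    # Next-token prediction: the input is tokens 0..L-2, the target is tokens 1..L-1
    x = res.input_ids[:, :-1].to(device)
    y = res.input_ids[:, 1:].to(device)
    # attention_mask = (res.attention_mask == 0)
    optimizer.zero_grad()
    oe, oy = model(x)
    # Flatten to (B*L, V) logits and (B*L,) targets; padding tokens are not excluded here
    loss = loss_function(oy.reshape(-1, vocabulary_size), y.reshape(-1))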
    loss.backward()
    loss_value = loss.item()
    avg_loss.append(loss_value)
    if( i%500 == 0):
        elapsed_time = time.time() - start_time
        remaining_time = int((elapsed_time/(i+1)) * (len(training_dl) - i))
        loop_sec = ((i+1)/elapsed_time)
        print(f"The loss at iteration {i+1} is {sum(avg_loss)/len(avg_loss):3.4f} remaining_time is {strftime('%H:%M:%S', gmtime(remaining_time))}s ({loop_sec:4.1f} it/s)", flush=True)
        avg_loss = []
    if( i%1000 == 0):
        torch.save(optimizer.state_dict() ,"optimizer_state_dict_llama_mini.pth.temp")
        torch.save(model.state_dict() ,"transformer_state_dict_llama_mini.pth.temp")
        shutil.copyfile("optimizer_state_dict_llama_mini.pth.temp", "optimizer_state_dict_llama_mini.pth")
        shutil.copyfile("transformer_state_dict_llama_mini.pth.temp", "transformer_state_dict_llama_mini.pth")
    optimizer.step()

The loss at iteration 1 is 10.5374 remaining_time is 09:56:43s ( 3.7 it/s)
The loss at iteration 501 is 6.0703 remaining_time is 04:39:20s ( 7.9 it/s)
The loss at iteration 1001 is 5.9788 remaining_time is 04:21:40s ( 8.4 it/s)
The loss at iteration 1501 is 5.9742 remaining_time is 04:50:10s ( 7.5 it/s)
The loss at iteration 2001 is 5.9656 remaining_time is 04:38:02s ( 7.8 it/s)
The loss at iteration 2501 is 5.9632 remaining_time is 04:52:28s ( 7.4 it/s)
The loss at iteration 3001 is 5.9608 remaining_time is 04:43:26s ( 7.6 it/s)


## Decoding time

There are different ways to decode. The simplest is greedy decoding: at each step the most likely token is chosen and appended to the input, and the model is run again on the extended sequence (complete `nonefficient_greedy_decoding`). We can also store the keys and values, as shown in the course; in practice this only accelerates generation and does not change its result (complete `greedy_decoding`). Finally, we can sample the new tokens instead of taking the most likely one (complete `sampling_decoding`).
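
Why does storing keys and values not change the result? With a causal mask, the output at position $t$ only depends on the keys and values of positions $0..t$, which never change once computed. The plain-Python sketch below checks this on a single head with toy numbers. The function `attend` and all values are assumptions made for this illustration, and positional encodings are left out.

```python
import math

def attend(q, keys, values):
    """Softmax-weighted average of `values` for one query over the given keys."""
    d = len(q)
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [sum(e / total * v[j] for e, v in zip(exps, values))
            for j in range(len(values[0]))]

# Toy per-token queries, keys and values (single head)
Q = [[1.0, 0.0], [0.5, 0.5], [0.0, 1.0], [1.0, 1.0]]
K = [[0.2, 0.1], [0.9, -0.3], [-0.4, 0.8], [0.6, 0.6]]
V = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 3.0]]

# Full recomputation: position t attends to keys 0..t (causal mask)
full = [attend(Q[t], K[: t + 1], V[: t + 1]) for t in range(len(Q))]

# Incremental decoding: append one key/value per step, use the newest query only
k_cache, v_cache, cached = [], [], []
for t in range(len(Q)):
    k_cache.append(K[t])
    v_cache.append(V[t])
    cached.append(attend(Q[t], k_cache, v_cache))

print(full == cached)  # True
```

The cached version performs one query per step instead of recomputing all previous positions, which is where the speed-up comes from.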

In [None]:
def nonefficient_greedy_decoding(model, x, max_new_tokens=64):
    """
    Slow greedy decoding: re-run full sequence each step.
    """
    previous_gen = x

    with torch.no_grad():
        for _ in range(max_new_tokens):
            _, logits = model(previous_gen)  # model returns (emb, logits)
            last_token_logits = logits[:, -1, :]
            next_token = torch.argmax(last_token_logits, dim=-1, keepdim=True)
            previous_gen = torch.cat((previous_gen, next_token), dim=1)

    return previous_gen


def greedy_decoding(model, x, max_new_tokens=64, max_cache_size=512, tokenizer=None):
    """
    Greedy decoding with KV cache (efficient).
    """
    device = x.device
    dtype = next(model.parameters()).dtype
    B = x.size(0)

    # infer cache shape from model
    attn = model.blocks[0].attention_module
    H = attn.num_heads
    D = attn.output_dim
    num_layers = model.nhl

    total_needed = x.size(1) + max_new_tokens
    if total_needed > max_cache_size:
        raise ValueError(f"max_cache_size={max_cache_size} too small for prompt+gen={total_needed}")

    # allocate caches per layer
    k_cache = [torch.zeros(B, H, max_cache_size, D, device=device, dtype=dtype) for _ in range(num_layers)]
    v_cache = [torch.zeros(B, H, max_cache_size, D, device=device, dtype=dtype) for _ in range(num_layers)]

    previous_gen = x
    cur_len = x.size(1)

    with torch.no_grad():
        # prefill cache with prompt
        _, logits = model(x, k_cache=k_cache, v_cache=v_cache, start_cache=0)

        # first token
        next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
        previous_gen = torch.cat((previous_gen, next_token), dim=1)
        cur_len += 1

        # generate remaining tokens
        for _ in range(max_new_tokens - 1):
            _, logits = model(next_token, k_cache=k_cache, v_cache=v_cache, start_cache=cur_len - 1)
            next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
            previous_gen = torch.cat((previous_gen, next_token), dim=1)
            cur_len += 1

    return previous_gen

def apply_no_repeat_ngram(logits, prev_tokens, n=3):
    if prev_tokens.size(1) < n:
        return logits
    B, V = logits.size()
    for b in range(B):
        tokens = prev_tokens[b].tolist()
        ngrams = {}
        for i in range(len(tokens) - n + 1):
            key = tuple(tokens[i:i+n-1])
            ngrams.setdefault(key, set()).add(tokens[i+n-1])
        key = tuple(tokens[-(n-1):])
        banned = list(ngrams.get(key, []))
        if banned:
            logits[b, banned] = float("-inf")
    return logits



def sampling_decoding(
    model,
    x,
    max_new_tokens=64,
    max_cache_size=512,
    temperature=0.9,
    top_k=50,
    top_p=0.95,
    repetition_penalty=1.2,
    no_repeat_ngram=3,
):
    device = x.device
    dtype = next(model.parameters()).dtype
    B = x.size(0)

    attn = model.blocks[0].attention_module
    H = attn.num_heads
    D = attn.output_dim
    num_layers = model.nhl

    total_needed = x.size(1) + max_new_tokens
    if total_needed > max_cache_size:
        raise ValueError(f"max_cache_size={max_cache_size} too small for prompt+gen={total_needed}")

    k_cache = [torch.zeros(B, H, max_cache_size, D, device=device, dtype=dtype) for _ in range(num_layers)]
    v_cache = [torch.zeros(B, H, max_cache_size, D, device=device, dtype=dtype) for _ in range(num_layers)]

    previous_gen = x
    cur_len = x.size(1)

    with torch.no_grad():
        # prefill
        _, logits = model(x, k_cache=k_cache, v_cache=v_cache, start_cache=0)

        for _ in range(max_new_tokens):
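            # repetition penalty (sign-aware, so that already generated tokens
            # always become less likely, even when their logit is negative)
            if repetition_penalty is not None and repetition_penalty > 1.0:
                for b in range(B):
                    token_ids = torch.unique(previous_gen[b])
                    seen = logits[b, -1, token_ids]
                    logits[b, -1, token_ids] = torch.where(
                        seen > 0, seen / repetition_penalty, seen * repetition_penalty
                    )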

            last_logits = logits[:, -1, :] / temperature

            # no-repeat ngram
            if no_repeat_ngram is not None and no_repeat_ngram > 1:
                last_logits = apply_no_repeat_ngram(last_logits, previous_gen, n=no_repeat_ngram)

            # top-k
            if top_k is not None and top_k > 0:
                values, indices = torch.topk(last_logits, top_k, dim=-1)
                mask = torch.full_like(last_logits, float("-inf"))
                mask.scatter_(1, indices, values)
                last_logits = mask

            # top-p
            if top_p is not None and 0 < top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(last_logits, descending=True, dim=-1)
                probs = torch.softmax(sorted_logits, dim=-1)
                cumprobs = torch.cumsum(probs, dim=-1)
                cutoff = cumprobs > top_p
                cutoff[:, 0] = False
                sorted_logits[cutoff] = float("-inf")
                last_logits = torch.full_like(last_logits, float("-inf"))
                last_logits.scatter_(1, sorted_indices, sorted_logits)

            probs = torch.softmax(last_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            previous_gen = torch.cat((previous_gen, next_token), dim=1)
            cur_len += 1

            _, logits = model(next_token, k_cache=k_cache, v_cache=v_cache, start_cache=cur_len - 1)

    return previous_gen
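
The filtering steps used when sampling (temperature, top-k, top-p) can be followed on a toy logit vector. The sketch below is plain Python and is not a copy of the function above: the name `filter_logits` is hypothetical, and the nucleus here keeps the token that crosses the `top_p` threshold, which is a common variant that may differ from other implementations.

```python
import math
import random

def softmax(logits):
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]  # exp(-inf) == 0.0 for filtered tokens
    total = sum(exps)
    return [e / total for e in exps]

def filter_logits(logits, temperature=1.0, top_k=0, top_p=1.0):
    """Temperature scaling, then top-k, then top-p (nucleus) filtering."""
    logits = [l / temperature for l in logits]
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    keep = set(order[:top_k]) if top_k > 0 else set(order)
    masked = [l if i in keep else -math.inf for i, l in enumerate(logits)]
    probs = softmax(masked)
    # keep the smallest prefix (by decreasing probability) whose mass reaches top_p
    cumulative, nucleus = 0.0, set()
    for i in order:
        if i not in keep:
            break
        nucleus.add(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    return [l if i in nucleus else -math.inf for i, l in enumerate(logits)]

logits = [2.0, 1.0, 0.5, -1.0, -3.0]  # toy scores for a 5-token vocabulary
filtered = filter_logits(logits, temperature=0.9, top_k=3, top_p=0.8)
probs = softmax(filtered)
random.seed(0)
next_token = random.choices(range(len(probs)), weights=probs, k=1)[0]
print([round(p, 3) for p in probs], next_token)
```

Here top-k keeps three tokens and top-p then reduces them to two, so only tokens 0 and 1 can be sampled.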


In [None]:
model = TransformerDecoder(
    vocabulary_size=vocabulary_size,
    embed_size=384,
    intermediate_size=1536,
    num_heads=12,
    hidden_layers=8,
)

# the model is used on the CPU in this cell, so the checkpoint is loaded on the CPU as well
model.load_state_dict(torch.load("transformer_state_dict_llama_mini.pth", map_location="cpu"))
model = model.eval()

decoding_method = greedy_decoding

text = "Alice"
model.to("cpu")
tokenized_text = tokenizer(text, return_tensors='pt')

output_ids = sampling_decoding(model, tokenized_text.input_ids, max_new_tokens=64, temperature=0.9, top_k=50, top_p=0.95, repetition_penalty=1.1)

print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0])

Alice and Tom Once Once Lily Once Once Once One Once Lily Tommy Tim Lily L Lily Once Tim Once Once Ben One Tom L One Once Once Sam Lily John Once One Tim Tom Anna Once Once Jo Once Sara Once John Lily One Once Anna Once Jack Once Once Molly Once Once There Once Once Jim Once Once Billy John Once Once Tom


## Evaluation

It remains difficult to evaluate the model. Here you can propose an evaluation based on perplexity computed with another model, or compare the results to the ground truth on the test set (here only a validation set is available), but neither will be fully informative.

In [None]:
validation_dataset = tinystories_dataset['validation']

In [None]:
validation_dataset

Dataset({
    features: ['text'],
    num_rows: 21990
})

In [None]:
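def evaluate_perplexity(model, tokenizer, validation_dataset, batch_size=8, max_length=128, device="cpu"):
    """
    Return (mean loss, perplexity) of the model on the validation set.
    The mean is taken over per-batch losses; targets equal to the pad token id are ignored.
    """
    model.eval()
    loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
    losses = []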

    with torch.no_grad():
        for i in range(0, len(validation_dataset), batch_size):
            batch = validation_dataset[i:i+batch_size]["text"]
            enc = tokenizer(
                batch,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_length,
            )
            input_ids = enc.input_ids.to(device)

            x = input_ids[:, :-1]
            y = input_ids[:, 1:]

            _, logits = model(x)
            # reshape to (B*L, V)
            loss = loss_fn(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
            losses.append(loss.item())

    mean_loss = sum(losses) / len(losses)
    ppl = math.exp(mean_loss)
    return mean_loss, ppl


In [None]:
model.to("cuda")
val_loss, val_ppl = evaluate_perplexity(model, tokenizer, validation_dataset, batch_size=8, max_length=128, device="cuda")
print(f"val loss: {val_loss:.4f} | val ppl: {val_ppl:.2f}")


val loss: 1.5522 | val ppl: 4.72
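
The relation between the two reported numbers is $\text{ppl} = e^{\text{loss}}$, with the loss measured in nats. The sketch below checks it and also shows, on hypothetical per-token losses, that averaging per-batch means is not the same as averaging over all tokens when batches contain different numbers of tokens. The function `perplexity` is an illustrative helper, not part of the lab code.

```python
import math

def perplexity(token_losses):
    """exp of the mean per-token negative log-likelihood (natural log)."""
    return math.exp(sum(token_losses) / len(token_losses))

print(round(math.exp(1.5522), 2))  # 4.72, matching the reported value

# Averaging per-batch means weights every batch equally, whatever its token count
batch_a = [1.0, 1.0, 1.0, 1.0]   # 4 tokens, hypothetical losses
batch_b = [3.0]                  # 1 token
mean_of_means = (sum(batch_a) / len(batch_a) + sum(batch_b) / len(batch_b)) / 2
token_weighted = sum(batch_a + batch_b) / len(batch_a + batch_b)
print(mean_of_means, token_weighted)  # 2.0 1.4
```

A perplexity of 4.72 means that, on average, the model is as uncertain as if it were choosing uniformly among about 4.7 tokens at each position.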
