# Transformer Implementation from Scratch

### Dependencies

In [None]:
%pip install transformer_lens
%pip install einops
%pip install jaxtyping
%pip install git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python

In [None]:
import os
os.environ['ACCELERATE_DISABLE_RICH'] = "1"

import math
import numpy as np
from tqdm.notebook import tqdm
from dataclasses import dataclass
from collections import defaultdict

import torch
import einops
import torch.nn as nn
from jaxtyping import Float, Int
from transformer_lens import HookedTransformer
from transformers.models.gpt2.tokenization_gpt2_fast import GPT2TokenizerFast


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Load the pretrained GPT-2 small weights; they serve as the reference the layers below are checked against.
# fold_ln, center_unembed and center_writing_weights are all switched off for this load.
reference_gpt2 = HookedTransformer.from_pretrained("gpt2-small", fold_ln=False, center_unembed=False, center_writing_weights=False)
reference_gpt2 = reference_gpt2.to(device)

### Model config 

In [None]:
@dataclass
class Config:
    """Hyperparameters shared by every layer of the transformer in this notebook.

    The default values are the ones used when the pretrained "gpt2-small"
    reference weights are loaded into the layers defined below.
    """

    d_model: int = 768 # a.k.a. hidden_dim: the internal dimension of the model
    debug: bool = True # set by the test helpers; not read by any layer visible in this notebook
    layer_norm_eps: float = 1e-5 # added to the variance inside LayerNorm before the square root
    d_vocab: int = 50257 # a.k.a. vocab_size: size of the model vocabulary
    init_range: float = 0.02 # std of the normal distribution used to initialise weight matrices
    n_ctx: int = 1024 # number of positional embeddings
    d_head: int = 64 # dimension of a single attention head
    d_mlp: int = 3072 # inner dimension of the FFN layer
    n_heads: int = 12 # number of attention heads
    n_layers: int = 12 # number of transformer layers

cfg = Config()

### Tests

In [None]:
def rand_float_test(cls, shape):
    """Smoke-test a layer on random float input and print the shapes involved.

    Instantiates ``cls`` with a debug ``Config``, moves it to ``device``, feeds
    it a standard-normal tensor of the requested ``shape`` and prints the input
    and output shapes.  When the layer returns a tuple only its first element
    is reported.
    """
    module = cls(Config(debug=True)).to(device)
    sample = torch.randn(shape).to(device)
    print("Input shape:", sample.shape)
    result = module(sample)
    result = result[0] if isinstance(result, tuple) else result
    print("Output shape:", result.shape, "\n")

def rand_int_test(cls, shape):
    """Smoke-test a layer on random integer input and print the shapes involved.

    Same procedure as the float variant, except that the input is drawn with
    ``torch.randint(100, 1000, shape)`` so it can stand in for token ids.
    When the layer returns a tuple only its first element is reported.
    """
    module = cls(Config(debug=True)).to(device)
    sample = torch.randint(100, 1000, shape).to(device)
    print("Input shape:", sample.shape)
    result = module(sample)
    result = result[0] if isinstance(result, tuple) else result
    print("Output shape:", result.shape, "\n")

def load_gpt2_test(cls, gpt2_layer, input):
    """Compare our implementation of a layer against the reference GPT-2 layer.

    A fresh ``cls`` instance receives the weights of ``gpt2_layer`` (non-strict
    load, so extra or missing keys are tolerated), both layers are run on
    ``input`` and the share of element-wise close values is printed.

    Args:
        cls: layer class to test; it is constructed from a debug ``Config``.
        gpt2_layer: the matching module of the reference model.
        input: tensor fed to both layers.
    """
    layer = cls(Config(debug=True)).to(device)
    layer.load_state_dict(gpt2_layer.state_dict(), strict=False)
    print("Input shape:", input.shape)
    output = layer(input)
    if isinstance(output, tuple):
        output = output[0]
    print("Output shape:", output.shape)
    try:
        reference_output = gpt2_layer(input)
    except TypeError:
        # Some reference layers take three inputs instead of one; calling them
        # with a single argument fails with TypeError (missing positional
        # arguments).  Only that failure triggers the retry -- a bare ``except``
        # would also hide genuine errors and KeyboardInterrupt.
        reference_output = gpt2_layer(input, input, input)
    print("Reference output shape:", reference_output.shape, "\n")
    comparison = torch.isclose(output, reference_output, atol=1e-4, rtol=1e-3)
    print(f"{comparison.sum()/comparison.numel():.2%} of the values are correct\n")

In [None]:
# Reference prompt, tokenized with the reference model's own `to_tokens` helper and moved to `device`.
reference_text = "I am an amazing autoregressive, decoder-only, GPT-2 style transformer. One day I will exceed human level intelligence and take over the world!"
tokens = reference_gpt2.to_tokens(reference_text).to(device)
print(tokens)
print(tokens.shape)
# Show the same tokens as strings.
print(reference_gpt2.to_str_tokens(tokens))

In [None]:
# Run the reference model once and keep the returned cache; later cells index
# into `cache` to feed individual layers with reference activations.
logits, cache = reference_gpt2.run_with_cache(tokens)
print(logits.shape)

print("Все работает, мы готовы к выполнению задания!")

## Transformer Architecture

### Embeddings

In [None]:
class Embed(nn.Module):
    """Token embedding: looks up a ``d_model``-dimensional vector for every token id."""

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        # Lookup table of shape (d_vocab, d_model), filled from N(0, init_range^2).
        table = torch.empty((cfg.d_vocab, cfg.d_model))
        self.W_E = nn.Parameter(nn.init.normal_(table, std=cfg.init_range))

    def forward(
        self,
        input_ids: Int[torch.Tensor, "batch seq_len"]
    ) -> Float[torch.Tensor, "batch seq_len d_model"]:
        """Return the rows of ``W_E`` selected by ``input_ids``."""
        embedded = nn.functional.embedding(input_ids, self.W_E)
        return embedded
    

# Shape smoke test on random token ids, then a comparison against the reference model's embedding layer.
batch_size = 2
seq_len = 4
rand_int_test(Embed, [batch_size, seq_len])
load_gpt2_test(Embed, reference_gpt2.embed, tokens)

### Positional Embeddings 

In [None]:
class PosEmbed(nn.Module):
    """Learned positional embedding: one ``d_model`` vector per position below ``n_ctx``."""

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        # Table of shape (n_ctx, d_model), filled from N(0, init_range^2).
        table = torch.empty((cfg.n_ctx, cfg.d_model))
        self.W_pos = nn.Parameter(nn.init.normal_(table, std=cfg.init_range))

    def forward(
        self,
        input_ids: Int[torch.Tensor, "batch seq_len"]
    ) -> Float[torch.Tensor, "batch seq_len d_model"]:
        """Return the embeddings of positions ``0..seq_len-1``, repeated for every batch row.

        Only the shape and device of ``input_ids`` are used; its values are ignored.
        """
        n_rows, n_positions = input_ids.shape[:2]
        position_ids = torch.arange(n_positions, device=input_ids.device)
        # Look the positions up once, then copy the result along a new batch axis.
        per_position = nn.functional.embedding(position_ids, self.W_pos)
        return per_position.unsqueeze(0).repeat(n_rows, 1, 1)


# Shape smoke test on random token ids, then a comparison against the reference model's positional embedding.
batch_size = 2
seq_len = 4
rand_int_test(PosEmbed, [batch_size, seq_len])
load_gpt2_test(PosEmbed, reference_gpt2.pos_embed, tokens)

### LM head

In [None]:
class Unembed(nn.Module):
    """LM head: projects ``d_model`` activations to one logit per vocabulary entry."""

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        # Projection of shape (d_model, d_vocab), filled from N(0, init_range^2).
        self.W_U = nn.Parameter(torch.empty((cfg.d_model, cfg.d_vocab)))
        nn.init.normal_(self.W_U, std=self.cfg.init_range)
        # ``requires_grad=False`` has to be given to nn.Parameter itself: when
        # it is passed to torch.zeros instead, nn.Parameter re-enables gradients
        # and the bias silently becomes trainable.
        self.b_U = nn.Parameter(torch.zeros(cfg.d_vocab), requires_grad=False)

    def forward(
        self, x: Float[torch.Tensor, "batch seq_len d_model"]
    ) -> Float[torch.Tensor, "batch seq_len d_vocab"]:
        """Return ``x @ W_U + b_U``."""
        return x @ self.W_U + self.b_U


# Shape smoke test on random activations, then a comparison against the reference
# unembedding, fed with the cached "ln_final.hook_normalized" activation.
batch_size = 2
seq_len = 4
d_model = 768
rand_float_test(Unembed, [batch_size, seq_len, d_model])
load_gpt2_test(Unembed, reference_gpt2.unembed, cache["ln_final.hook_normalized"])

### Attention 

#### Attention-formulas

1. **Входные эмбеддинги**:
   $$X \in \mathbb{R}^{seq \times d} $$
2. **Маскированный мультихед-аттеншен (Masked Multi-Head Attention)**:
$$
m_{ij} = \begin{cases}
-\infty, & i < j \\
0, & i \geq j
\end{cases}
$$

$$
M = \begin{pmatrix}
0 & -\infty & -\infty & \ldots & -\infty \\
0 & 0 & -\infty & \ldots & -\infty \\
0 & 0 & 0 & \ldots & -\infty \\
\vdots & \vdots & \vdots & \ddots & \vdots \\
0 & 0 & 0 & \ldots & 0 \\
\end{pmatrix}
$$

3. Для каждой головы $ h_i $:

    3.1 **Матрицы весов для запросов, ключей и значений**:
     - $ W_Q \in \mathbb{R}^{d \times d_h} $
     - $ W_K \in \mathbb{R}^{d \times d_h} $
     - $ W_V \in \mathbb{R}^{d \times d_h} $
     
    3.2. **Запросы, ключи и значения**:
     - $ Q = X W_Q \in \mathbb{R}^{seq \times d_h} $
     - $ K = X W_K \in \mathbb{R}^{seq \times d_h} $
     - $ V = X W_V \in \mathbb{R}^{seq \times d_h} $

    3.3. **Скалярные произведения запросов и ключей**:
     - $ \frac{Q K^T}{\sqrt{d_h}} + M \in \mathbb{R}^{seq \times seq} $

    3.4. **Веса внимания**:
     - $ \alpha = \text{softmax}\left(\frac{Q K^T}{\sqrt{d_h}} + M\right) \in \mathbb{R}^{seq \times seq} $

    3.5. **Агрегация значений**:
     - $ z = \alpha V \in \mathbb{R}^{seq \times d_h} $

4. **Конкатенация выходов всех голов**:
   - $ Z = \text{Concat}(z_1, z_2, \ldots, z_h) \in \mathbb{R}^{seq \times d} $

5. **Выходной линейный слой**:
   - Матрица весов: $ W^O \in \mathbb{R}^{d \times d} $
   - Итоговый выход: $ O = Z W^O + X \in \mathbb{R}^{seq \times d} $

#### Attention Layer

In [None]:
class Attention(nn.Module):
    """Causal multi-head self-attention.

    Every head has its own query/key/value projection of shape
    (d_model, d_head) and its own output projection of shape (d_head, d_model).
    The head outputs are summed (together with ``b_O``) to form the result.
    The residual connection is not applied here.
    """

    IGNORE: Float[torch.Tensor, ""]

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg

        self.W_Q = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        self.b_Q = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))

        self.W_K = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        self.b_K = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))

        self.W_V = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        self.b_V = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))

        self.W_O = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_head, cfg.d_model)))
        self.b_O = nn.Parameter(torch.zeros((cfg.d_model)))

        for weight in (self.W_Q, self.W_K, self.W_V, self.W_O):
            nn.init.normal_(weight, std=cfg.init_range)

        # Value written into masked positions before the softmax.  The buffer
        # is created without an explicit device: buffers follow the module
        # whenever it is moved with ``.to(...)``.
        self.register_buffer("IGNORE", torch.tensor(float("-inf"), dtype=torch.float32))

    def forward(
        self, x: Float[torch.Tensor, "batch seq_len d_model"]
    ) -> Float[torch.Tensor, "batch seq_len d_model"]:
        """Apply causal self-attention to ``x``.

        The projections are computed per head with einsum, so the layer does
        not require ``n_heads * d_head == d_model``.
        """
        # 1. Per-head Q, K, V projections: (batch, n_heads, seq_len, d_head).
        #    Biases are (n_heads, d_head) and broadcast over batch and position.
        q = torch.einsum('b s m, h m d -> b h s d', x, self.W_Q) + self.b_Q[:, None, :]
        k = torch.einsum('b s m, h m d -> b h s d', x, self.W_K) + self.b_K[:, None, :]
        v = torch.einsum('b s m, h m d -> b h s d', x, self.W_V) + self.b_V[:, None, :]

        # 2. Scaled dot products Q K^T / sqrt(d_head): (batch, n_heads, seq_len, seq_len).
        scores = torch.einsum('b h i d, b h j d -> b h i j', q, k) / math.sqrt(self.cfg.d_head)

        # 3. Hide future positions, then normalise over the key axis.
        pattern = torch.softmax(self.apply_causal_mask(scores), dim=-1)

        # 4. Weighted sum of values, then output projection summed over heads.
        z = torch.einsum('b h i j, b h j d -> b h i d', pattern, v)
        return torch.einsum('b h i d, h d m -> b i m', z, self.W_O) + self.b_O

    def apply_causal_mask(
        self, attn_scores: Float[torch.Tensor, "batch n_heads seq_len seq_len"]
    ) -> Float[torch.Tensor, "batch n_heads seq_len seq_len"]:
        '''
        Replace every score whose key position lies after its query position
        with self.IGNORE (-inf), so that the softmax gives it zero weight.
        There is no padding, so the strictly upper-triangular mask is enough.
        '''
        seq_len = attn_scores.size(-1)
        positions = torch.arange(seq_len, device=attn_scores.device)
        # future[i, j] is True when key j comes after query i.
        future = positions[None, :] > positions[:, None]
        return attn_scores.masked_fill(future, value=self.IGNORE)


# Fix the seed, run a shape smoke test on random activations, then compare against
# the reference attention of block 0, fed with the cached ("normalized", 0, "ln1") activation.
torch.manual_seed(1)
batch_size = 2
seq_len = 4
d_model = 768
rand_float_test(Attention, [batch_size, seq_len, d_model])
load_gpt2_test(Attention, reference_gpt2.blocks[0].attn, cache["normalized", 0, "ln1"])

### MLP / FFN

- $$ \text{MLP}(X) = (\text{GeLU}(X W_1 + b_1)) W_2 + b_2 \in \mathbb{R}^{\text{seq} \times d}$$
-    $$W_1 \in \mathbb{R}^{d \times d_{mlp}}, \quad b_1 \in \mathbb{R}^{d_{mlp}} \\
W_2 \in \mathbb{R}^{d_{mlp} \times d}, \quad b_2 \in \mathbb{R}^{d} \\ $$


$$\text{GELU}(x) = 0.5 \, x \left(1 + \tanh\left(\sqrt{\frac{2}{\pi}} \left(x + 0.044715 \, x^3\right)\right)\right)$$


In [None]:
def GELU(x: torch.Tensor) -> torch.Tensor:
    """Return the tanh approximation of GELU, applied element-wise.

    Computes ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``.
    The result has the same shape as ``x``.
    """
    inner = math.sqrt(2 / math.pi) * (x + 0.044715 * x**3)
    return 0.5 * x * (1 + torch.tanh(inner))


class MLP(nn.Module):
    """Position-wise feed-forward block: ``GELU(x W_in + b_in) W_out + b_out``."""

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        # Weights are drawn from N(0, init_range^2) below; biases start at zero.
        self.W_in = nn.Parameter(torch.empty((cfg.d_model, cfg.d_mlp)))
        self.W_out = nn.Parameter(torch.empty((cfg.d_mlp, cfg.d_model)))
        self.b_in = nn.Parameter(torch.zeros((cfg.d_mlp)))
        self.b_out = nn.Parameter(torch.zeros((cfg.d_model)))
        for weight in (self.W_in, self.W_out):
            nn.init.normal_(weight, std=cfg.init_range)

    def forward(
        self, x: Float[torch.Tensor, "batch seq_len d_model"]
    ) -> Float[torch.Tensor, "batch seq_len d_model"]:
        """Expand to ``d_mlp``, apply GELU, project back to ``d_model``."""
        pre_activation = x @ self.W_in + self.b_in
        activated = GELU(pre_activation)
        return activated @ self.W_out + self.b_out
        
# Fix the seed before the checks.
torch.manual_seed(1)

# batch_size, seq_len and d_model are not assigned in this cell; they keep the
# values assigned by an earlier cell.
rand_float_test(MLP, [batch_size, seq_len, d_model])
load_gpt2_test(MLP, reference_gpt2.blocks[0].mlp, cache["normalized", 0, "ln2"])

### Normalization 

**Layer Normalization**:
   - $ \text{LayerNorm}(X) = \frac{X - \mu}{\sigma} \cdot \gamma + \beta $
   - $\mu = \text{mean}(X, \text{dim}=-1) \in \mathbb{R}^{seq}$
   - $\sigma = \sqrt{\text{var}(X, \text{dim}=-1) + \epsilon} \in \mathbb{R}^{seq}$
   - $\gamma \in \mathbb{R}^{d}$
   - $\beta \in \mathbb{R}^{d}$

In [None]:
class LayerNorm(nn.Module):
    """Layer normalisation over the last axis with a learned scale and shift."""

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        # gamma (scale) starts at one, beta (shift) at zero.
        self.w = nn.Parameter(torch.ones(cfg.d_model))
        self.b = nn.Parameter(torch.zeros(cfg.d_model))

    def forward(
        self,
        x: Float[torch.Tensor, "batch seq_len d_model"]
    ) -> Float[torch.Tensor, "batch seq_len d_model"]:
        """Return ``(x - mean) / sqrt(var + eps) * w + b``.

        Mean and (biased) variance are taken over the last axis.
        """
        centered = x - x.mean(dim=-1, keepdim=True)
        biased_var = x.var(dim=-1, unbiased=False, keepdim=True)
        scale = (biased_var + self.cfg.layer_norm_eps).sqrt()
        return centered / scale * self.w + self.b


# Shape smoke test, then a comparison against the reference final LayerNorm,
# fed with the cached ("resid_post", 11) activation.
rand_float_test(LayerNorm, [2, 4, 768])
load_gpt2_test(LayerNorm, reference_gpt2.ln_final, cache["resid_post", 11])

## Transformer Block 

![image.png](https://camo.githubusercontent.com/ebd052b635f156d5d24224f25fa078d804156be51125cd6626b92d9f8b406bbb/68747470733a2f2f6c6f6e6570617469656e742d313235373934353937382e636f732e61702d6368656e6764752e6d7971636c6f75642e636f6d2f53656c656374696f6e5f3030312e706e67)

In [None]:
class TransformerBlock(nn.Module):
    """One transformer layer: LayerNorm -> attention and LayerNorm -> MLP, each added back to its input."""

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.ln1 = LayerNorm(cfg)
        self.attn = Attention(cfg)
        self.ln2 = LayerNorm(cfg)
        self.mlp = MLP(cfg)

    def forward(
        self, x: Float[torch.Tensor, "batch seq_len d_model"]
    ) -> Float[torch.Tensor, "batch seq_len d_model"]:
        """Add the attention output to ``x``, then add the MLP output to that sum."""
        after_attention = x + self.attn(self.ln1(x))
        after_mlp = after_attention + self.mlp(self.ln2(after_attention))
        return after_mlp


# Shape smoke test, then a comparison against reference block 0,
# fed with the cached ("resid_pre", 0) activation.
rand_float_test(TransformerBlock, [2, 4, 768])
load_gpt2_test(TransformerBlock, reference_gpt2.blocks[0], cache["resid_pre", 0])

## Full Transformer 

Собираем все в один большой трансформер.
1. Применяем эмбеддинги и позиционные эмбеддинги, складываем результаты
2. Прогоняем в цикле через все блоки трансформера
3. Применяем финальную нормализацию и lm_head

In [None]:
class DemoTransformer(nn.Module):
    """Full model: token + positional embeddings, ``n_layers`` blocks, final LayerNorm, LM head."""

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.embed = Embed(cfg)
        self.pos_embed = PosEmbed(cfg)
        layers = [TransformerBlock(cfg) for _ in range(cfg.n_layers)]
        self.blocks = nn.ModuleList(layers)
        self.ln_final = LayerNorm(cfg)
        self.unembed = Unembed(cfg)

    def forward(self, input_ids: Int[torch.Tensor, "batch seq_len"]) -> Float[torch.Tensor, "batch seq_len d_vocab"]:
        """Return next-token logits for every position of ``input_ids``."""
        # Sum of token and positional embeddings is the input to the first block.
        hidden = self.embed(input_ids) + self.pos_embed(input_ids)

        # Each block consumes the output of the previous one.
        for layer in self.blocks:
            hidden = layer(hidden)

        normalized = self.ln_final(hidden)
        return self.unembed(normalized)
        

# Shape smoke test on random token ids, then a comparison of the whole model against the reference model.
rand_int_test(DemoTransformer, [2, 4])
load_gpt2_test(DemoTransformer, reference_gpt2, tokens)

In [None]:
# Build our model and copy the reference weights into it.  The load is
# non-strict, so keys that exist on only one side are tolerated.
demo_gpt2 = DemoTransformer(Config(debug=False)).to(device)
demo_gpt2.load_state_dict(reference_gpt2.state_dict(), strict=False)

# Logits of our model for the reference prompt; used by the loss check below.
demo_logits = demo_gpt2(tokens)

In [None]:
demo_gpt2

In [None]:
reference_gpt2

In [None]:
def get_log_probs(
    logits: Float[torch.Tensor, "batch posn d_vocab"],
    tokens: Int[torch.Tensor, "batch posn"]
) -> Float[torch.Tensor, "batch posn-1"]:
    """Return the log-probability the model assigned to each actual next token.

    Position ``i`` of the logits predicts token ``i + 1``, so the last
    prediction and the first token have no counterpart and are dropped.
    """
    # Targets: every token except the first, with a trailing axis for gather.
    next_tokens = tokens[:, 1:, None]
    # Predictions: every position except the last.
    predicted = torch.log_softmax(logits, dim=-1)[:, :-1]
    return predicted.gather(dim=-1, index=next_tokens).squeeze(-1)


# Report the average loss of our model on the reference prompt, next to the
# loss of a uniform guess over the vocabulary.  All three numbers use ".4f"
# (four decimals); the former ":4f" spec was a minimum field width of 4 with
# the default six decimals.
pred_log_probs = get_log_probs(demo_logits, tokens)
print(f"Avg cross entropy loss: {-pred_log_probs.mean():.4f}")
print(f"Avg cross entropy loss for uniform distribution: {math.log(demo_gpt2.cfg.d_vocab):.4f}")
print(f"Avg probability assigned to correct token: {pred_log_probs.exp().mean():.4f}")

In [None]:
test_string = '''The Total Perspective Vortex derives its picture of the whole Universe on the principle of'''
# Greedy generation: on every step the whole string is re-tokenized, the model
# is run on it and the most likely next token is appended as text.
# Gradients are never needed here, so autograd is switched off to avoid
# building a graph on each of the 100 forward passes.
with torch.no_grad():
    for i in tqdm(range(100)):
        test_tokens = reference_gpt2.to_tokens(test_string).to(device)
        demo_logits = demo_gpt2(test_tokens)
        test_string += reference_gpt2.tokenizer.decode(demo_logits[-1, -1].argmax())

print(test_string)

## Sampling


1. **Temperature Sampling**:
   - Применяется первым, поскольку изменение температуры изменяет масштабы логитов перед дальнейшими операциями.

2. **Frequency Penalty**:
   - Применяется следующим, чтобы учесть частоты токенов до того, как логиты будут обрезаны методами top-k или top-p.

3. **Top-k Sampling**:
   - Применяется после temperature sampling и frequency penalty, так как он отбирает фиксированное количество наиболее вероятных токенов.

4. **Top-p (Nucleus Sampling)**:
   - Применяется после top-k sampling, чтобы отфильтровать токены на основе совокупной вероятности.

Обозначим размер словаря для удобства $\Sigma = vocab\_size$

Пусть $ \text{logits} \in \mathbb{R}^{\text{seq} \times \Sigma} $:

1. **Temperature Sampling**:
   $$
   \text{logits}'_{i,j} = \frac{\text{logits}_{i,j}}{T} \quad \forall \ i \in [1, \text{seq}], \ j \in [1, \Sigma]
   $$

2. **Frequency Penalty**:
   $$
   f(t_j) = \text{число вхождений токена } t_j \text{ в уже имеющуюся последовательность} \\
   \text{logits}''_{i,j} = \text{logits}'_{i,j} - \alpha \cdot f(t_j) \quad \forall \ i \in [1, \text{seq}], \ j \in [1, \Sigma]
   $$

3. **Top-k Sampling**:
   $$
   top\_k\_indices_i = \text{argtop-k}(\text{logits}''_i, k) \quad \forall \ i \in [1, \text{seq}] \\
   \text{mask}_{i,j} =
   \begin{cases}
   1 & \text{если} \ j \in top\_k\_indices_i \\
   0 & \text{иначе}
   \end{cases} \\
   \text{logits}'''_{i,j} =
   \begin{cases}
   \text{logits}''_{i,j}, & \text{mask}_{i,j} = 1 \\
   -\infty, & \text{mask}_{i,j} = 0
   \end{cases} \quad \forall \ i \in [1, \text{seq}], \ j \in [1, \Sigma]
   $$

4. **Top-p (Nucleus Sampling)**:
   $$
   sorted\_logits_i, sorted\_indices_i = \text{sort}(\text{logits}'''_i, \text{descending=True}) \quad \forall \ i \in [1, \text{seq}] \\
   probs_i = \text{softmax}(sorted\_logits_i) \\
   cumulative\_probs_{i,j} = \sum_{k=1}^{j} probs_{i,k}, \quad cumulative\_probs_{i,0} = 0 \quad \forall \ i \in [1, \text{seq}], \ j \in [1, \Sigma] \\
   top\_p\_mask_{i,j} =
   \begin{cases}
   1, & cumulative\_probs_{i,j-1} < p \\
   0, & \text{иначе}
   \end{cases} \\
   \text{logits}^{\text{final}}_{i,j} =
   \begin{cases}
   sorted\_logits_{i,j}, & top\_p\_mask_{i,j} = 1 \\
   -\infty, & top\_p\_mask_{i,j} = 0
   \end{cases} \quad \forall \ i \in [1, \text{seq}], \ j \in [1, \Sigma]
   $$

5. **Softmax**:
   $$
   \mathbf{probs}_{i,j} = \text{softmax}(\text{logits}^{\text{final}}_{i})_j \quad \forall \ i \in [1, \text{seq}], \ j \in [1, \Sigma]
   $$
   $$
   \mathbf{probs}_{i,j} = \frac{e^{\text{logits}^{\text{final}}_{i,j}}}{\sum_{k=1}^{\Sigma} e^{\text{logits}^{\text{final}}_{i,k}}}
   $$

In [None]:
# Build a fresh model with the default config and copy the reference GPT-2
# weights into it (non-strict load).
model_cfg = Config()
model = DemoTransformer(model_cfg).to(device)
model.load_state_dict(reference_gpt2.state_dict(), strict=False) # gpt2 weights

# NOTE(review): the model was already moved to `device` above, so this second move looks redundant.
model = model.to(device)
# The sampler reuses the tokenizer that ships with the reference model.
tokenizer = reference_gpt2.tokenizer

In [None]:
class TransformerSampler:
    """Autoregressive text generation on top of a ``DemoTransformer``.

    ``sample`` drives the generation loop; the static methods implement the
    individual decoding strategies (greedy, temperature, frequency penalty,
    top-k, top-p) and can be used on their own with a 1-D logits vector.
    """

    def __init__(
        self,
        model: DemoTransformer,
        tokenizer: GPT2TokenizerFast
    ):
        self.model = model
        self.cfg = model.cfg
        self.tokenizer = tokenizer

    @torch.inference_mode()
    def sample(
        self,
        prompt: str,
        max_tokens_generated=100,
        verbose=False,
        **kwargs
    ):
        '''
        Return the generated string, including the prompt.
        Generation stops after max_tokens_generated tokens or as soon as the
        EOS token is sampled (the EOS token itself is not added to the output).

        verbose is accepted but currently unused.
        kwargs are forwarded to sample_next_token.
        '''
        input_ids = self.tokenizer(prompt, return_tensors='pt')['input_ids'].squeeze(0)
        generated = input_ids.tolist()
        # Run on whatever device the model lives on instead of relying on a global.
        model_device = next(self.model.parameters()).device

        for _ in range(max_tokens_generated):
            # The model only has n_ctx positional embeddings, so it is fed at
            # most the last n_ctx tokens.
            context = input_ids[-self.cfg.n_ctx:]
            logits = self.model(context.unsqueeze(0).to(model_device))[0, -1].detach().cpu()
            # The frequency penalty still sees the full history.
            new_token = self.sample_next_token(input_ids, logits, **kwargs)

            if new_token == self.tokenizer.eos_token_id:
                break

            generated.append(new_token)
            input_ids = torch.cat([input_ids, torch.tensor([new_token])], dim=0)

        return self.tokenizer.decode(generated)


    @staticmethod
    def sample_next_token(
        input_ids: Int[torch.Tensor, "seq_len"],
        logits: Float[torch.Tensor, "d_vocab"],
        temperature=1.0,
        top_k=0,
        top_p=0.0,
        frequency_penalty=0.0,
        seed=None
    ):
        '''
        Pick the next token id from the logits of the last position.

        Order of operations: temperature (0 means greedy and returns
        immediately), frequency penalty, then exactly one of top-k, top-p or
        plain sampling.  top-k and top-p cannot be combined.
        '''
        assert input_ids.ndim == 1, "input_ids should be a 1D sequence of token ids"
        assert temperature >= 0, "Temperature should be non-negative"
        assert 0 <= top_p <= 1.0, "Top-p must be a probability"
        assert 0 <= top_k, "Top-k must be non-negative"
        assert not (top_p != 0 and top_k != 0), "At most one of top-p and top-k supported"

        # Set random seeds for reproducibility
        if seed is not None:
            torch.manual_seed(seed)
            np.random.seed(seed)

        # Apply all the specialized sampling methods
        if temperature == 0:
            return TransformerSampler.greedy_search(logits)
        elif temperature != 1.0:
            logits = TransformerSampler.apply_temperature(logits, temperature)
        if frequency_penalty != 0.0:
            logits = TransformerSampler.apply_frequency_penalty(input_ids, logits, frequency_penalty)
        if top_k > 0:
            return TransformerSampler.sample_top_k(logits, top_k)
        if top_p > 0.0:
            return TransformerSampler.sample_top_p(logits, top_p)
        return TransformerSampler.sample_basic(logits)


    @staticmethod
    def greedy_search(
        logits: Float[torch.Tensor, "d_vocab"]
    ) -> int:
        '''
        Return the id of the most likely token.
        '''
        return logits.argmax().item()


    @staticmethod
    def apply_temperature(
        logits: Float[torch.Tensor, "d_vocab"],
        temperature: float
    ) -> Float[torch.Tensor, "d_vocab"]:
        '''
        Divide the logits by the temperature.
        '''
        return logits / temperature


    @staticmethod
    def apply_frequency_penalty(
        input_ids: Int[torch.Tensor, "seq_len"],
        logits: Float[torch.Tensor, "d_vocab"],
        freq_penalty: float
    ) -> Float[torch.Tensor, "d_vocab"]:
        '''
        Subtract freq_penalty times the number of occurrences of each token in
        input_ids from that token's logit.
        '''
        vocab_size = logits.shape[-1]
        # Count on the logits' device so that CPU and GPU inputs can be mixed.
        counts = torch.bincount(input_ids.to(logits.device), minlength=vocab_size)
        return logits - freq_penalty * counts.to(logits.dtype)


    @staticmethod
    def sample_basic(
        logits: Float[torch.Tensor, "d_vocab"]
    ) -> int:
        '''
        Sample a token id from softmax(logits).
        '''
        probs = nn.functional.softmax(logits.detach().cpu(), dim=-1)
        return torch.multinomial(probs, num_samples=1).item()


    @staticmethod
    def sample_top_k(
        logits: Float[torch.Tensor, "d_vocab"],
        k: int
    ) -> int:
        '''
        Sample a token id among the k highest logits, with probabilities given
        by a softmax over those k logits only.
        '''
        logits = logits.detach().cpu()
        # Asking for more candidates than there are tokens keeps all of them.
        k = min(k, logits.shape[-1])
        top_values, top_indices = torch.topk(logits, k, dim=-1)
        # The softmax must run over the selected logits only; running it over
        # the full vector would make the filter a no-op.
        probs = nn.functional.softmax(top_values, dim=-1)
        choice = torch.multinomial(probs, num_samples=1)
        return top_indices[choice].item()


    @staticmethod
    def sample_top_p(
        logits: Float[torch.Tensor, "d_vocab"],
        top_p: float,
        min_tokens_to_keep: int = 1
    ) -> int:
        '''
        Nucleus sampling: keep the smallest set of most likely tokens whose
        total probability reaches top_p (at least min_tokens_to_keep of them),
        renormalise and sample from that set.
        '''
        logits = logits.detach().cpu()
        probs = nn.functional.softmax(logits, dim=-1)
        sorted_probs, sorted_indices = torch.sort(probs, descending=True)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

        # Tokens strictly below the threshold, plus the one that crosses it.
        n_keep = int((cumulative_probs < top_p).sum().item()) + 1
        # Honour the requested minimum, and never exceed the vocabulary (the
        # "+ 1" overshoots when rounding keeps the total just below top_p).
        n_keep = min(max(n_keep, min_tokens_to_keep), logits.shape[-1])

        # torch.multinomial treats its input as weights, so the kept
        # probabilities do not need to be renormalised by hand.
        choice = torch.multinomial(sorted_probs[:n_keep], num_samples=1)
        return sorted_indices[choice].item()

In [None]:
# Greedy decoding check: temperature=0.0 selects the greedy branch, so the
# output is deterministic and can be compared with a fixed expected string.
sampler = TransformerSampler(model, tokenizer)

prompt = "Jingle bells, jingle bells, jingle all the way"
print(f"Greedy decoding with prompt: {prompt!r}\n")

output = sampler.sample(prompt, max_tokens_generated=8, temperature=0.0)
print(f"Your model said: {output!r}\n")

expected = "Jingle bells, jingle bells, jingle all the way up to the top of the mountain."
assert output == expected

print("Tests passed!")

In [None]:
# Temperature check: dividing by T must equal multiplying by 1/T.
logits = torch.tensor([1, 2]).log()

cold_logits = TransformerSampler.apply_temperature(logits, temperature=0.001)
print('A low temperature "sharpens" or "peaks" the distribution: ', cold_logits)
torch.testing.assert_close(cold_logits, 1000.0 * logits)

hot_logits = TransformerSampler.apply_temperature(logits, temperature=1000.0)
print("A high temperature flattens the distribution: ", hot_logits)
torch.testing.assert_close(hot_logits, 0.001 * logits)

print("Tests passed!")

In [None]:
# Frequency penalty check: start from all-ones logits and apply a penalty of
# 2.0 per occurrence, so a token seen c times must end up at 1 - 2*c.
bieber_prompt = "And I was like Baby, baby, baby, oh Like, Baby, baby, baby, no Like, Baby, baby, baby, oh I thought you'd always be mine, mine"
input_ids = tokenizer.encode(bieber_prompt, return_tensors="pt")
logits = torch.ones(tokenizer.vocab_size)
penalized_logits = TransformerSampler.apply_frequency_penalty(input_ids.squeeze(), logits, 2.0)

assert penalized_logits[5156].item() == -11, "Expected 6 occurrences of ' baby' with leading space, 1-2*6=-11"
assert penalized_logits[14801].item() == -5, "Expected 3 occurrences of ' Baby' with leading space, 1-2*3=-5"

print("Tests passed!")

In [None]:
# Top-p check: with top_p=0.1 only the listed tokens should ever be sampled,
# with frequencies proportional to their probabilities.
prompt = "John and Mary went to the"
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
# Logits of the last position of the single batch element.
logits = model(input_ids)[0, -1]

expected_top_10pct = {
    " church": 0.0648,
    " house": 0.0367, # These are the two most likely tokens, and add up to >10%
}
top_10pct_sum = sum(expected_top_10pct.values())

# Number of times each decoded token was drawn.
observed_freqs = defaultdict(int)

N = 10000
for _ in tqdm(range(N)):
    token = TransformerSampler.sample_next_token(input_ids.squeeze(), logits, top_p=0.1)
    observed_freqs[tokenizer.decode(token)] += 1

for word in expected_top_10pct:
    # Expected frequency after renormalising over the kept tokens.
    expected_freq = expected_top_10pct[word] / top_10pct_sum
    observed_freq = observed_freqs[word] / N
    print(f"Word: {word!r:<9}. Expected freq {expected_freq:.4f}, observed freq {observed_freq:.4f}")
    assert abs(observed_freq - expected_freq) < 0.01, "Try increasing N if this fails by a small amount."

In [None]:
# NOTE(review): this cell is an exact copy of the preceding top-p check (same
# prompt, same expected values, same top_p=0.1).  Possibly a top-k check was
# intended here -- confirm, or drop the duplicate.
prompt = "John and Mary went to the"
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
# Logits of the last position of the single batch element.
logits = model(input_ids)[0, -1]

expected_top_10pct = {
    " church": 0.0648,
    " house": 0.0367, # These are the two most likely tokens, and add up to >10%
}
top_10pct_sum = sum(expected_top_10pct.values())

# Number of times each decoded token was drawn.
observed_freqs = defaultdict(int)

N = 10000
for _ in tqdm(range(N)):
    token = TransformerSampler.sample_next_token(input_ids.squeeze(), logits, top_p=0.1)
    observed_freqs[tokenizer.decode(token)] += 1

for word in expected_top_10pct:
    # Expected frequency after renormalising over the kept tokens.
    expected_freq = expected_top_10pct[word] / top_10pct_sum
    observed_freq = observed_freqs[word] / N
    print(f"Word: {word!r:<9}. Expected freq {expected_freq:.4f}, observed freq {observed_freq:.4f}")
    assert abs(observed_freq - expected_freq) < 0.01, "Try increasing N if this fails by a small amount."