### Data

In [None]:
! wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

### Dependencies

In [None]:
!pip install jaxtyping

In [None]:
import math
import random
from dataclasses import dataclass
from typing import List, Optional, Tuple

import einops
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from jaxtyping import Float, Int
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import RandomSampler
from tqdm.notebook import tqdm
from transformers import AutoTokenizer

In [None]:
# Train on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Data preprocessing

In [None]:
# Read the whole tinyshakespeare corpus downloaded above. The encoding is explicit
# so the result does not depend on the platform's default locale encoding.
with open("input.txt", encoding="utf-8") as fin:
    text = fin.read()

print(text[:200])

In [None]:
# GPT-2 byte-level BPE tokenizer. The asserts pin its tokenization of an unseen word
# and the name of its end-of-text token, which is reused as PAD below.
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
expected_pieces = ['Hello', 'Ġthere', 'Ġsomet', 'r', 'ash', 'token']
assert tokenizer.tokenize("Hello there sometrashtoken") == expected_pieces
assert tokenizer.eos_token == "<|endoftext|>"

The GPT-2 tokenizer has no dedicated PAD token, so the EOS token (`<|endoftext|>`) is reused for padding.

In [None]:
# Reuse the EOS token as PAD (GPT-2 defines no PAD token of its own).
# NOTE: afterwards PAD and EOS share one id, so comparing ids to pad_token_id cannot
# tell padding from a genuine <|endoftext|> token — rely on the attention mask instead.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

### Dataset

In [None]:
class MyDataset(Dataset):
    """Cuts one long text into consecutive, non-overlapping token chunks of random length.

    The text is tokenized once with ``tokenizer.encode``. It is then split back-to-back into
    chunks whose lengths are drawn with ``randint(min_len, max_len)`` (both ends inclusive)
    from a private RNG seeded with ``seed``. The last chunk holds whatever tokens remain
    and may be shorter than ``min_len``.

    Args:
        tokenizer: object with an ``encode(text) -> List[int]`` method.
        text: the raw corpus.
        min_len: smallest chunk length to draw (must be >= 1).
        max_len: largest chunk length to draw (must be >= min_len).
        seed: seed for the chunk-length RNG; the same seed gives the same chunks.

    Raises:
        ValueError: if ``min_len``/``max_len`` do not satisfy 1 <= min_len <= max_len.
    """

    def __init__(
        self,
        tokenizer: AutoTokenizer,
        text: str,
        min_len: int = 200,
        max_len: int = 301,
        seed: int = 1,
    ):
        if not 1 <= min_len <= max_len:
            # A zero-length chunk would never advance the cursor below (infinite loop).
            raise ValueError(f"need 1 <= min_len <= max_len, got {min_len}, {max_len}")
        self.tokenizer = tokenizer
        self.texts: List[List[int]] = []
        # A private RNG draws the same lengths as random.seed(seed) + random.randint,
        # but without resetting the notebook-wide `random` state as a side effect.
        rng = random.Random(seed)
        tokenized_tokens = self.tokenizer.encode(text)
        i = 0
        while i < len(tokenized_tokens):
            seq_len = rng.randint(min_len, max_len)
            self.texts.append(tokenized_tokens[i:i + seq_len])
            i += seq_len

    def __getitem__(self, index) -> List[int]:
        """Return the ``index``-th chunk of token ids."""
        return self.texts[index]

    def __len__(self) -> int:
        """Number of chunks."""
        return len(self.texts)
    

dataset = MyDataset(tokenizer, text)

# The first chunk must decode back to the very beginning of the corpus.
first_chunk_ids = dataset[0]
sample_0 = dataset.tokenizer.decode(first_chunk_ids)
assert sample_0.startswith(text[:100])

print(sample_0)

### Collate FN 
Принимает `List[List[int]]` батч объектов и возвращает 2 тензора:

* input_ids - `[batch, seq_len]` - батч токенов, в котором добавлены паддинги до максимальной длины в батче.
* mask - `[batch, seq_len]` - батч масок. На позиции `[i, j]` стоит 0, если токен является паддингом, иначе 1.

In [None]:
def collate_fn(
    batch: List[List[int]], pad_token_id: Optional[int] = None
) -> Tuple[torch.LongTensor, torch.LongTensor]:
    """Right-pad a batch of token-id lists into rectangular tensors.

    Args:
        batch: token-id sequences of varying length.
        pad_token_id: id used for padding; when None, the global ``tokenizer.pad_token_id``
            is used (the behaviour DataLoader relies on, since it calls ``collate_fn(batch)``).

    Returns:
        input_ids: ``(len(batch), max_len)`` long tensor. Each row is padded on the right
            with ``pad_token_id``.
        mask: ``(len(batch), max_len)`` long tensor with 1 for real tokens and 0 for padding.
    """
    if pad_token_id is None:
        pad_token_id = tokenizer.pad_token_id
    lengths = torch.tensor([len(seq) for seq in batch], dtype=torch.long)
    max_seq_len = int(lengths.max()) if len(batch) else 0

    input_ids = torch.full((len(batch), max_seq_len), pad_token_id, dtype=torch.long)
    for row, seq in enumerate(batch):
        input_ids[row, :len(seq)] = torch.as_tensor(seq, dtype=torch.long)

    # Build the mask from the lengths rather than by comparing ids to pad_token_id.
    # With GPT-2 the PAD id equals the EOS id, so a real <|endoftext|> token inside a
    # sequence would otherwise be mistaken for padding.
    mask = (torch.arange(max_seq_len)[None, :] < lengths[:, None]).long()
    return input_ids, mask


# Three ragged sequences (lengths 4, 2 and 7): the shorter ones must be right-padded
# with the PAD id 50256 up to the longest length in the batch.
batch = [
    [1, 2, 3, 4],
    [1, 2],
    [1, 2, 3, 4, 5, 6, 7],
]
input_ids_ref = torch.LongTensor([
    [1, 2, 3, 4, 50256, 50256, 50256],
    [1, 2, 50256, 50256, 50256, 50256, 50256],
    [1, 2, 3, 4, 5, 6, 7],
])


# 1 marks a real token, 0 marks padding.
mask_ref = torch.LongTensor([
    [1, 1, 1, 1, 0, 0, 0],
    [1, 1, 0, 0, 0, 0, 0],
    [1, 1, 1, 1, 1, 1, 1],
])

input_ids, mask = collate_fn(batch)

assert (input_ids == input_ids_ref).all()
assert (mask == mask_ref).all()

### DataLoader 

In [None]:
BATCH_SIZE = 16
SAMPLER_SEED = 1

# A seeded generator makes the shuffling order reproducible across Restart & Run All.
sampler = RandomSampler(dataset, generator=torch.Generator().manual_seed(SAMPLER_SEED))
train_loader = DataLoader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    sampler=sampler,  # a custom sampler already decides the order; no `shuffle` needed
    collate_fn=collate_fn,
)

# Peek at one batch to sanity-check the collate output.
input_ids, mask = next(iter(train_loader))
print(mask.shape)

# Padding only goes up to the longest sequence in the batch, so at least one row is padding-free.
assert (mask.sum(dim=1) < mask.size(1)).sum() < mask.size(0)

### Transformer

In [None]:
@dataclass
class Config:
    """Hyperparameters shared by every transformer component in this notebook."""
    d_model: int = 768 # a.k.a. hidden_dim: width of the model's residual stream
    debug: bool = True # NOTE(review): appears unused by the components in this notebook
    layer_norm_eps: float = 1e-5 
    d_vocab: int = 50257 # a.k.a. vocab_size: number of token ids the model embeds / predicts
    init_range: float = 0.02 # std of the normal initialisation of the weight matrices
    n_ctx: int = 1024 # number of positions (positional-embedding rows, RoPE table length)
    d_head: int = 64 # dimension of a single attention head
    d_mlp: int = 3072 # hidden width of the FFN (MLP) layer
    n_heads: int = 12 # number of attention heads
    n_layers: int = 12 # number of transformer blocks

cfg = Config()
print(cfg)

In [None]:
def GELU(x: torch.Tensor):
    """Tanh approximation of the Gaussian Error Linear Unit.

    GELU(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
    """
    cubic_term = 0.044715 * torch.pow(x, 3)
    inner = math.sqrt(2 / math.pi) * (x + cubic_term)
    return 0.5 * x * (1 + torch.tanh(inner))


class Embed(nn.Module):
    """Token embedding: maps each token id to its row of W_E (d_vocab x d_model)."""

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        table = torch.empty((cfg.d_vocab, cfg.d_model))
        self.W_E = nn.Parameter(table)
        nn.init.normal_(self.W_E, std=cfg.init_range)

    def forward(
        self, 
        input_ids: Int[torch.Tensor, "batch seq_len"]
    ) -> Float[torch.Tensor, "batch seq_len d_model"]:
        # Row lookup; out-of-range ids raise, as with any embedding lookup.
        return nn.functional.embedding(input_ids, self.W_E)


class PosEmbed(nn.Module):
    """Learned absolute positional embedding: position p maps to row p of W_pos (n_ctx x d_model)."""

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        table = torch.empty((cfg.n_ctx, cfg.d_model))
        self.W_pos = nn.Parameter(table)
        nn.init.normal_(self.W_pos, std=cfg.init_range)

    def forward(
        self, 
        input_ids: Int[torch.Tensor, "batch seq_len"]
    ) -> Float[torch.Tensor, "batch seq_len d_model"]:
        batch_size, seq_len = input_ids.shape[0], input_ids.shape[1]
        # Positions 0..seq_len-1, shared by every sequence in the batch.
        positions = torch.arange(seq_len, device=input_ids.device).expand(batch_size, seq_len)
        return nn.functional.embedding(positions, self.W_pos)
    
    
class Unembed(nn.Module):
    """Projects residual-stream vectors to vocabulary logits: x @ W_U + b_U."""

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_U = nn.Parameter(torch.empty((cfg.d_model, cfg.d_vocab)))
        nn.init.normal_(self.W_U, std=self.cfg.init_range)
        # b_U is trainable: nn.Parameter defaults to requires_grad=True, and any flag set on
        # the wrapped tensor is ignored. To freeze a zero bias, pass requires_grad=False to
        # nn.Parameter itself.
        self.b_U = nn.Parameter(torch.zeros(cfg.d_vocab))

    def forward(
        self, x: Float[torch.Tensor, "batch seq_len d_model"]
    ) -> Float[torch.Tensor, "batch seq_len d_vocab"]:
        return x @ self.W_U + self.b_U
    
        
class MLP(nn.Module):
    """Position-wise feed-forward layer: GELU(x @ W_in + b_in) @ W_out + b_out."""

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_in = nn.Parameter(torch.empty((cfg.d_model, cfg.d_mlp)))
        self.W_out = nn.Parameter(torch.empty((cfg.d_mlp, cfg.d_model)))
        self.b_in = nn.Parameter(torch.zeros((cfg.d_mlp)))
        self.b_out = nn.Parameter(torch.zeros((cfg.d_model)))
        # Weights ~ N(0, init_range^2) (W_in first, then W_out); biases stay at zero.
        for weight in (self.W_in, self.W_out):
            nn.init.normal_(weight, std=self.cfg.init_range)

    def forward(
        self, x: Float[torch.Tensor, "batch seq_len d_model"]
    ) -> Float[torch.Tensor, "batch seq_len d_model"]:
        pre_activation = x @ self.W_in + self.b_in
        hidden = GELU(pre_activation)
        return hidden @ self.W_out + self.b_out

### RMSNorm
https://arxiv.org/pdf/1910.07467

In [None]:
class RMSNorm(nn.Module):
    """Root-mean-square layer normalisation (Zhang & Sennrich, 2019).

    y = x / sqrt(mean(x^2 over the last dim) + eps) * w. There is no mean-centering and
    no bias. ``w`` is a learned per-feature gain initialised to ones, and
    ``eps = cfg.layer_norm_eps``.
    """

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.w = nn.Parameter(torch.ones(cfg.d_model)) # gamma (gain)

    def forward(
        self, 
        x: Float[torch.Tensor, "batch seq_len d_model"]
    ) -> Float[torch.Tensor, "batch seq_len d_model"]:
        # eps keeps an all-zero vector from producing 0 / 0 = NaN.
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.cfg.layer_norm_eps)
        return x / rms * self.w
    
    
# Reference: x / sqrt(mean(x ** 2)) with the gain at its initial value of ones.
cfg_rmsnorm = Config(d_model=5)
layer = RMSNorm(cfg_rmsnorm)
x = torch.Tensor([[[0.1, 0.2, 0.3, 0.4, 0.5]]])
y = torch.Tensor([[[0.3015, 0.6030, 0.9045, 1.2060, 1.5076]]])
assert torch.allclose(y, layer(x), atol=1e-4, rtol=1e-3)

### Attention

#### Attention Masking 


In [None]:
class Attention(nn.Module):
    """Multi-head causal self-attention with separate per-head Q/K/V/O weights.

    W_Q, W_K, W_V have shape (n_heads, d_model, d_head) and W_O has shape
    (n_heads, d_head, d_model). Projections are computed per head, so n_heads * d_head
    does not have to equal d_model.
    """
    IGNORE: Float[torch.Tensor, ""]

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        
        self.W_Q = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        self.b_Q = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))
        
        self.W_K = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        self.b_K = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))
        
        self.W_V = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        self.b_V = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))
        
        self.W_O = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_head, cfg.d_model)))
        self.b_O = nn.Parameter(torch.zeros((cfg.d_model)))
        
        nn.init.normal_(self.W_Q, std=self.cfg.init_range)
        nn.init.normal_(self.W_K, std=self.cfg.init_range)
        nn.init.normal_(self.W_V, std=self.cfg.init_range)
        nn.init.normal_(self.W_O, std=self.cfg.init_range)
        # Pre-softmax masking value. As a buffer it follows the module through .to(...);
        # apply_causal_mask also moves it onto the scores' device, so no global device is needed.
        self.register_buffer("IGNORE", torch.tensor(float("-inf"), dtype=torch.float32))

    def forward(
        self, x: Float[torch.Tensor, "batch seq_len d_model"], mask: Int[torch.Tensor, "batch seq_len"]
    ) -> Float[torch.Tensor, "batch seq_len d_model"]:
        """Causal self-attention over ``x``; ``mask`` is 1 for real tokens and 0 for padding."""
        # 1. Per-head projections straight from the (heads, d_model, d_head) weights:
        #    (batch, seq, d_model) -> (batch, heads, seq, d_head).
        Q = torch.einsum("bsm,nmd->bnsd", x, self.W_Q) + self.b_Q[:, None, :]
        K = torch.einsum("bsm,nmd->bnsd", x, self.W_K) + self.b_K[:, None, :]
        V = torch.einsum("bsm,nmd->bnsd", x, self.W_V) + self.b_V[:, None, :]

        # 2-3. Scaled dot-product scores, (batch, heads, query, key).
        scores = torch.einsum("bnid,bnjd->bnij", Q, K) / math.sqrt(self.cfg.d_head)

        # 4-5. Hide future and padded keys, then normalise over keys.
        scores = self.apply_causal_mask(scores, mask)
        probs = torch.softmax(scores, dim=-1)

        # 6. Weighted sum of values; each head writes back through its own W_O slice and
        #    the heads' contributions are summed.
        z = torch.einsum("bnij,bnjd->bnid", probs, V)
        return torch.einsum("bnid,ndm->bim", z, self.W_O) + self.b_O

    def apply_causal_mask(
        self, 
        attn_scores: Float[torch.Tensor, "batch n_heads seq_len seq_len"], 
        mask: Int[torch.Tensor, "batch seq_len"]
    ) -> Float[torch.Tensor, "batch n_heads seq_len seq_len"]:
        '''
        Applies a causal mask and a key-padding mask to attention scores.

        Scores at (query i, key j) become self.IGNORE (-inf) when j > i (the future) or
        when key j is padding (mask[:, j] == 0). The result stays on attn_scores' device.
        A query row whose every key is masked (e.g. an all-padding sequence) softmaxes to NaN.
        '''
        seq_len = attn_scores.size(-1)
        scores_device = attn_scores.device

        # True strictly above the diagonal: query i must not look at key j > i.
        causal = torch.triu(
            torch.ones((seq_len, seq_len), dtype=torch.bool, device=scores_device), diagonal=1
        )
        # True where the key is padding; broadcast over heads and query positions.
        padding = (mask.to(scores_device) == 0)[:, None, None, :]

        return attn_scores.masked_fill(causal | padding, self.IGNORE.to(scores_device))


# Right-padded batch of 3 sequences with real lengths 4, 2 and 7.
mask_padding = torch.LongTensor([
    [1, 1, 1, 1, 0, 0, 0],
    [1, 1, 0, 0, 0, 0, 0],
    [1, 1, 1, 1, 1, 1, 1],
])

lengths = mask_padding.sum(dim=1).tolist()


batch_size = 3
seq_len = 7
d_head = 8
n_heads = 4

x = torch.rand(batch_size, n_heads, seq_len, seq_len)

attn = Attention(cfg)
softmax_res = torch.softmax(attn.apply_causal_mask(x, mask_padding), dim=-1)

# Every query keeps at least key 0 unmasked, so each row must be a proper distribution.
assert torch.allclose(softmax_res.sum(dim=-1), torch.ones(batch_size, n_heads, seq_len))

for batch_idx in range(batch_size):
    for head_idx in range(n_heads):
        sm = softmax_res[batch_idx, head_idx]
        l = lengths[batch_idx]
        for i in range(seq_len):
            for j in range(seq_len):
                # i < j: causal mask, the query must not look into the future.
                # j >= l: the key is padding and must receive no attention.
                if i < j or j >= l:
                    assert sm[i, j] == 0, (batch_idx, head_idx, i, j, sm[i, j])

# Smoke test of the full forward pass; the input width must match the model's d_model.
_ = attn(torch.rand(batch_size, seq_len, cfg.d_model), mask_padding)

### Rotary Embeddings  
https://arxiv.org/pdf/2104.09864

In [None]:
class RotaryPositionalEmbeddings(nn.Module):
    """Rotary positional embeddings (RoPE) with interleaved dimension pairs.

    Each pair (x[2i], x[2i+1]) of the last dimension is rotated by the angle m * theta_i.
    Here m is the token position, read from dim 1 of the input, and
    theta_i = theta ** (-2i / d_head). Inputs are expected as
    (batch, seq_len, num_heads, d_head).

    Raises:
        ValueError: at construction if d_head is odd, and in forward if the sequence is
            longer than cfg.n_ctx.
    """
    
    def __init__(self, cfg: Config, theta: int = 10_000):
        super().__init__()
        if cfg.d_head % 2 != 0:
            raise ValueError(f"RoPE rotates pairs of dimensions, so d_head must be even (got {cfg.d_head})")
        self.cfg = cfg
        self.max_seq_len = cfg.n_ctx
        self.theta = theta
        self.d = cfg.d_head
        
        # theta_i for i = 0 .. d/2 - 1, shape (d/2,)
        freqs = theta ** (-2 * torch.arange(self.d // 2) / self.d)
        # positions m = 0 .. max_seq_len - 1
        position_id = torch.arange(0, self.max_seq_len).float()
        # m * theta_i, shape (max_seq_len, d/2)
        idx_theta = torch.outer(position_id, freqs)
        
        # Each theta_i drives two consecutive dimensions, so duplicate along the last
        # axis: (max_seq_len, d/2) -> (max_seq_len, d)
        cos = idx_theta.cos().repeat_interleave(2, dim=-1)
        sin = idx_theta.sin().repeat_interleave(2, dim=-1)
        
        # (1, max_seq_len, 1, d_head): broadcasts over batch and heads. Buffers follow
        # the module through .to(...).
        self.register_buffer("sin", sin[None, :, None, :])
        self.register_buffer("cos", cos[None, :, None, :])
    
    @staticmethod
    def rotate_neg_vector(
        x: Float[torch.Tensor, "batch seq_len num_heads d_head"]
    ):
        """Map [x1, x2, x3, x4, ..., x_{n-1}, x_n] to [-x2, x1, -x4, x3, ..., -x_n, x_{n-1}] on the last dim."""
        even, odd = x[..., ::2], x[..., 1::2]
        return torch.stack([-odd, even], dim=-1).reshape_as(x)
    
    def forward(
        self, 
        x: Float[torch.Tensor, "batch seq_len num_heads d_head"]
    ):
        seq_len = x.size(1)
        if seq_len > self.max_seq_len:
            raise ValueError(f"sequence length {seq_len} exceeds n_ctx={self.max_seq_len}")
        # Slice the tables to the current length and move them to x's device (a no-op
        # when the module already lives there).
        cos = self.cos[:, :seq_len].to(x.device)
        sin = self.sin[:, :seq_len].to(x.device)
        return x * cos + self.rotate_neg_vector(x) * sin
    

# Check RoPE against a direct scalar implementation of the rotation:
#   y[2i]   = x[2i]   * cos(m * theta_i) - x[2i+1] * sin(m * theta_i)
#   y[2i+1] = x[2i+1] * cos(m * theta_i) + x[2i]   * sin(m * theta_i)
# with theta_i = 10000 ** (-2 i / d_head) and m the position (dim 1 of x).
batch_size = 1
seq_len = 3
num_heads = 2
d_head = 16

torch.manual_seed(1)
# Layout expected by the layer: (batch, seq_len, num_heads, d_head).
x = torch.rand(batch_size, seq_len, num_heads, d_head)

rope_config = Config(
    n_heads=2,
    d_head=16,
)

rope_layer = RotaryPositionalEmbeddings(rope_config)
y = rope_layer(x)


# Reference angles theta_i for i = 0 .. d_head/2 - 1 (written with a 1-based loop index).
thetas = [10_000 ** (-2 * (i - 1) / rope_config.d_head) for i in range(1, rope_config.d_head // 2 + 1)]
all_good = True
# Stops checking after the first mismatch within a batch element (batch_size is 1 here).
for batch_idx in range(batch_size):
    for m in range(seq_len):
        if not all_good:
            break
        for head_idx in range(num_heads):
            if not all_good:
                break
            for d_idx in range(d_head):
                # even dimensions: 0, 2, 4, ...
                if d_idx % 2 == 0:
                    val = x[batch_idx, m, head_idx, d_idx] * math.cos(m * thetas[d_idx // 2]) - x[batch_idx, m, head_idx, d_idx + 1] * math.sin(m * thetas[d_idx // 2])
                else:
                    val = x[batch_idx, m, head_idx, d_idx] * math.cos(m * thetas[d_idx // 2]) + x[batch_idx, m, head_idx, d_idx - 1] * math.sin(m * thetas[d_idx // 2])
                if abs(y[batch_idx, m, head_idx, d_idx] - val) > 1e-3:
                    print(f"Ошибка на позиции {m} и размерности {d_idx} в голове {head_idx}")
                    print(f"Полученное значение {y[batch_idx, m, head_idx, d_idx]}, референс {val}")
                    all_good = False
                    break

if all_good:
    print("Тесты прошли успешно!")

### Rope X Attention

In [None]:
class AttentionWithRope(Attention):
    """Causal multi-head self-attention with rotary position embeddings on Q and K.

    It inherits the parameters (same shapes, creation and initialisation order), the
    IGNORE buffer and ``apply_causal_mask`` from ``Attention``. Only the forward pass
    differs: queries and keys are rotated by their token position before the dot product.
    """
    IGNORE: Float[torch.Tensor, ""]

    def __init__(self, cfg: Config):
        super().__init__(cfg)
        self.rope = RotaryPositionalEmbeddings(cfg)

    def forward(
        self, 
        x: Float[torch.Tensor, "batch seq_len d_model"], 
        mask: Int[torch.Tensor, "batch seq_len"]
    ) -> Float[torch.Tensor, "batch seq_len d_model"]:
        """Causal self-attention with RoPE; ``mask`` is 1 for real tokens and 0 for padding."""
        # 1. Per-head projections laid out as (batch, seq, heads, d_head). The position
        #    axis must be dim 1: RotaryPositionalEmbeddings reads positions from x.size(1).
        #    With a (batch, heads, seq, d_head) layout it would rotate by head index instead,
        #    which cancels out in Q.K and injects no positional information.
        Q = torch.einsum("bsm,nmd->bsnd", x, self.W_Q) + self.b_Q
        K = torch.einsum("bsm,nmd->bsnd", x, self.W_K) + self.b_K
        V = torch.einsum("bsm,nmd->bsnd", x, self.W_V) + self.b_V

        # 2. Rotate queries and keys by token position (values are not rotated).
        Q = self.rope(Q)
        K = self.rope(K)

        # 3-4. Scaled dot-product scores, (batch, heads, query, key).
        scores = torch.einsum("bind,bjnd->bnij", Q, K) / math.sqrt(self.cfg.d_head)

        # 5-6. Hide future and padded keys, then normalise over keys.
        scores = self.apply_causal_mask(scores, mask)
        probs = torch.softmax(scores, dim=-1)

        # 7. Weighted sum of values, then per-head output projection summed over heads.
        z = torch.einsum("bnij,bjnd->bind", probs, V)
        return torch.einsum("bind,ndm->bim", z, self.W_O) + self.b_O

## Final Transformer

In [None]:
class TransformerBlock(nn.Module):
    """Pre-norm transformer block.

    x <- x + Attn(RMSNorm(x), mask), then x <- x + MLP(RMSNorm(x)).
    """

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.ln1 = RMSNorm(cfg)
        self.attn = AttentionWithRope(cfg)
        self.ln2 = RMSNorm(cfg)
        self.mlp = MLP(cfg)

    def forward(
        self,
        x: Float[torch.Tensor, "batch seq_len d_model"],
        mask: Int[torch.Tensor, "batch seq_len"]
    ) -> Float[torch.Tensor, "batch seq_len d_model"]:
        # mask is the integer 0/1 padding mask (1 = real token), forwarded unchanged to attention.
        x = x + self.attn(self.ln1(x), mask)
        return x + self.mlp(self.ln2(x))

class DemoTransformer(nn.Module):
    """Decoder-only LM: token embedding -> n_layers TransformerBlocks -> RMSNorm -> unembedding."""

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.embed = Embed(cfg)
        # NOTE(review): pos_embed is built (and owns parameters) but forward never uses it;
        # positional information is presumably meant to come from the rotary embeddings
        # inside each attention block.
        self.pos_embed = PosEmbed(cfg)
        self.blocks = nn.ModuleList(TransformerBlock(cfg) for _ in range(cfg.n_layers))
        self.ln_final = RMSNorm(cfg)
        self.unembed = Unembed(cfg)

    def forward(
        self, 
        input_ids: Int[torch.Tensor, "batch seq_len"],
        mask: Int[torch.Tensor, "batch seq_len"]
    ) -> Float[torch.Tensor, "batch seq_len d_vocab"]:
        residual = self.embed(input_ids)
        for block in self.blocks:
            residual = block(residual, mask)
        return self.unembed(self.ln_final(residual))

#### Final tests

In [None]:
# Small model for tinyshakespeare: 12 blocks, 8 heads of width 16 (8 * 16 = d_model = 128).
train_config = Config(
    d_model=128,
    d_head=16,
    n_heads=8,
    d_mlp=512,
    n_ctx=512,
    n_layers=12,
)
model = DemoTransformer(train_config)

In [None]:
input_ids, mask = next(iter(train_loader))

with torch.no_grad():
    p = model(input_ids, mask)

# One logit vector over the full vocabulary for every input position.
expected_shape = [input_ids.size(0), input_ids.size(1), train_config.d_vocab]
assert list(p.shape) == expected_shape

## Training

1. Берем input_ids, mask, прогоняем через модель, получаем тензор p `[batch_size, seq_len, vocab_size]`
2. В качестве меток мы берем **те же input_ids**, сдвинутые на 1 позицию влево: i-й токен предсказывает (i + 1)-й, поэтому меткой для позиции i служит input_ids[:, i + 1]
3. В качестве предиктов берем логиты **p**. Последнюю позицию нужно отбросить (или поставить ей метку -100), т.к. для неё в последовательности нет следующего токена, который она могла бы предсказать
4. Паддингам ставим метки -100 — это значение ignore_index по умолчанию у [CrossEntropyLoss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html), и такие позиции не учитываются при подсчете лосса
5. Превращаем p в тензор `[batch_size * (seq_len - 1), vocab_size]`, вектор правильных меток labels (из input_ids) превращаем в `[batch_size * (seq_len - 1)]`, считаем функцию потерь

In [None]:
criterion = nn.CrossEntropyLoss()
pad_id = 50256  # GPT-2 <|endoftext|>, reused as the PAD token

def calculate_loss(critertion, logits, input_ids, pad_id=pad_id, mask=None):
    """Next-token cross-entropy for a causal language model.

    Position i is scored against token i + 1. The last position has no target and gets
    label -100, as do padded targets; ``nn.CrossEntropyLoss`` ignores -100 by default.

    Args:
        critertion: loss callable applied as ``critertion(logits_2d, labels_1d)``.
        logits: ``(batch, seq_len, vocab)`` model outputs.
        input_ids: ``(batch, seq_len)`` token ids the logits were computed from.
            They are moved to ``logits``' device.
        pad_id: token id treated as padding when ``mask`` is not given.
        mask: optional ``(batch, seq_len)`` attention mask (1 = real token, 0 = padding).
            When given, it alone decides which targets are padding, so genuine tokens
            equal to ``pad_id`` (e.g. a real EOS) are still scored.

    Returns:
        Whatever ``critertion`` returns (a scalar tensor for ``nn.CrossEntropyLoss``).
    """
    input_ids = input_ids.to(logits.device)

    # labels[:, i] = input_ids[:, i + 1]; the final column has no next token.
    labels = torch.full_like(input_ids, -100)
    labels[:, :-1] = input_ids[:, 1:]

    if mask is None:
        is_pad = labels == pad_id
    else:
        is_pad = torch.zeros_like(labels, dtype=torch.bool)
        is_pad[:, :-1] = mask.to(labels.device)[:, 1:] == 0
    labels = labels.masked_fill(is_pad, -100)

    # reshape (not view): logits may be non-contiguous.
    return critertion(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))



batch_size = 2
seq_len = 4
num_classes = 7

# (batch_size, seq_len); the first sequence is right-padded with pad_id.
input_ids = torch.LongTensor(
    [
        [0, 1,  pad_id, pad_id],
        [0, 1, 2, 3]
    ]
)


# (batch_size, seq_len, num_classes)
logits = torch.Tensor(
    [[[0.7576, 0.2793, 0.4031, 0.7347, 0.0293, 0.7999, 0.3971],
         [0.7544, 0.5695, 0.4388, 0.6387, 0.5247, 0.6826, 0.3051],
         [0.4635, 0.4550, 0.5725, 0.4980, 0.9371, 0.6556, 0.3138],
         [0.1980, 0.4162, 0.2843, 0.3398, 0.5239, 0.7981, 0.7718]],

        [[0.0112, 0.8100, 0.6397, 0.9743, 0.8300, 0.0444, 0.0246],
         [0.2588, 0.9391, 0.4167, 0.7140, 0.2676, 0.9906, 0.2885],
         [0.8750, 0.5059, 0.2366, 0.7570, 0.2346, 0.6471, 0.3556],
         [0.4452, 0.0193, 0.2616, 0.7713, 0.3785, 0.9980, 0.9008]]]
)

loss = calculate_loss(criterion, logits, input_ids)

# Mean CE over the 4 unpadded targets: (0,0)->1, (1,0)->1, (1,1)->2, (1,2)->3.
# abs() is required: without it any loss below the reference would pass.
assert abs(loss.item() - 1.9343) < 1e-3, loss.item()

In [None]:
model = model.to(device).train()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

losses = []
for epoch in range(10):
    for input_ids, mask in tqdm(train_loader):
        logits = model(input_ids.to(device), mask.to(device))
        loss = calculate_loss(criterion, logits, input_ids)

        # Clear stale gradients right before backprop, then take one Adam step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

In [None]:
# Per-step training loss collected by the loop above.
fig, ax = plt.subplots()
ax.plot(losses)
ax.set(title="Training loss", xlabel="optimizer step", ylabel="cross-entropy loss")
plt.show()

## Generation
1. Подаем input_ids, mask
2. По последнему токену жадно предсказываем следующий
3. Конкатенируем этот токен к input_ids, расширяем mask (добавляем 1 для нового токена)
4. Повторяем num_tokens_to_generate раз

In [None]:
input_text = text[:5]
inputs = tokenizer(input_text, return_tensors="pt")

input_ids = inputs["input_ids"].to(device)
mask = inputs["attention_mask"].to(device)

orig_size = input_ids.size(1)

num_tokens_to_generate = 10

# Inference mode: the idiomatic guard before generation. It is a no-op for the current
# layers, which have no dropout or batch-norm.
model.eval()
with torch.no_grad():
    for i in range(num_tokens_to_generate):
        if input_ids.size(1) >= model.cfg.n_ctx:
            break  # no room left in the model's context window
        # Full forward pass over the whole prefix (no KV cache); only the last position matters.
        output = model(input_ids, mask)
        logits = output[:, -1, :]
        output = torch.argmax(logits, dim=-1).unsqueeze(-1)  # greedy next token, (batch, 1)
        input_ids = torch.cat([input_ids, output], dim=-1)
        # The new token is real, so extend the mask with a 1 of the mask's own dtype/device.
        output_mask = torch.ones_like(mask[:, :1])
        mask = torch.cat([mask, output_mask], dim=-1)

print("Input text:", input_text)
print("Input text + Generated", tokenizer.decode(input_ids[0]))