# Его величество, "домашка №1"

В этой домашней работе вам предоставится уникальная возможность обучить Byte-level BPE токенизатор и небольшую LM.  

Домашняя работа состоит из нескольких последовательных блоков: реализация и обучение токенизатора, реализация Transformer модели и обучение модели на датасете с русскими анекдотами!

Обученные токенизатор и модель можно и нужно выложить на [🤗 HuggingFace](https://huggingface.co/). Зарегистрируйтесь там, подпишитесь на [deep vk](https://huggingface.co/deepvk) и создайте себе API токен.

Следуйте ячейкам тетрадки и заполняйте пропущенные ячейки. В конце тетрадки вы найдете задачи со звездочкой, чтобы получить максимальный балл!

In [1]:
# Установим необходимые дополнительные библиотеки

%pip install --quiet datasets livelossplot


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0.1[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3 -m pip install --upgrade pip[0m


In [49]:
# Необходимые импорты

import inspect
import json
import os
from collections import Counter
from dataclasses import dataclass
from functools import lru_cache, partial
from pathlib import Path
from einops import rearrange, einsum

import numpy as np
import regex as re
import torch
import torch.nn as nn
import datasets
from datasets import load_dataset
from huggingface_hub import HfApi, PyTorchModelHubMixin, interpreter_login, snapshot_download
from huggingface_hub.utils import SoftTemporaryDirectory
from livelossplot import PlotLosses
from torch import Tensor
from torch.nn import functional as F
from torch.utils.data import DataLoader
from tqdm.auto import tqdm, trange

In [4]:
# Этой функцией будут помечены все места, которые необходимо дозаполнить
# Это могут быть как целые функции, так и отдельные части внутри них
# Всегда можно воспользоваться интроспекцией и найти места использования этой функции :)


def todo():
    stack = inspect.stack()
    caller_frame = stack[1]
    function_name = caller_frame.function
    line_number = caller_frame.lineno
    raise NotImplementedError(f"TODO at {function_name}, line {line_number}")

In [5]:
interpreter_login()


    _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
    _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
    _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
    _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
    _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|



Enter your token (input will not be visible):  ········
Add token as git credential? (Y/n)  y


Token has not been saved to git credential helper.


[1m[31mCannot authenticate through git-credential as no helper is defined on your machine.
You might have to re-authenticate when pushing to the Hugging Face Hub.
Run the following command in your terminal in case you want to set the 'store' credential helper as default.

git config --global credential.helper store

Read https://git-scm.com/book/en/v2/Git-Tools-Credential-Storage for more details.[0m


In [None]:
# Подготовим репозиторий для будущей модели и токенизатора
username = HfApi().whoami()["name"]
REPO_NAME = f"{username}/llm-course-hw1"  # Или как вам хочется

print(f"Homework repository: '{REPO_NAME}'")

# И другие полезные вещи
SEED = 0xC0FFEE

# Датасет

Первым делом загрузим данные: [🤗 IgorVolochay/russian_jokes](https://huggingface.co/datasets/IgorVolochay/russian_jokes)

И немного посмотрим на них 👀

In [7]:
dataset = load_dataset("IgorVolochay/russian_jokes")
print("\n===\n".join(dataset["train"]["text"][:3]))

- Зять, а ты знаешь, где найти того мужчину, который спас меня, когда я тонула?- Да, он уже приходил ко мне извиняться!
===
После проведения акции "К животным по-человечески" животные посовещались и решили провести акцию "К человеку по-скотски".
===
Штирлиц пришел домой и сразу завалился на боковую. Средняя от досады заплакала.


In [8]:
# Подготовим холдауты
dataset = dataset["train"].train_test_split(test_size=0.1, seed=SEED)
dataset

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 135497
    })
    test: Dataset({
        features: ['text'],
        num_rows: 15056
    })
})

# Токенизатор [6 баллов]

В качестве токенизатора будем использоват Byte-level BPE.

Для этого:
1. Реализуем его обучения, нам необходимо построить словарь заданного размера и набор слияний по этому словарю
2. Обучим токенизатор на датасете
3. Реализуем инференс токенизатора: кодирование текста и декодирование токенов


In [None]:
# Всякие полезности

WHITESPACE_SPLITTER = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")


def bytes_to_unicode() -> dict[int, str]:
    """The original dictionary consists of 256 bytes and their corresponding Unicode characters.
    For example, chr(33) is '!'. However, not all bytes have a visually appealing representation,
    so such characters are skipped and replaced with the first available ones, i.e. shifted by 256.
    """
    initial_bytes = (
        list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
    )
    initial_chars = [chr(it) for it in initial_bytes]
    n = 0
    for byte in range(2**8):
        if byte not in initial_bytes:
            initial_bytes.append(byte)
            initial_chars.append(chr(2**8 + n))
            n += 1
    return dict(sorted(zip(initial_bytes, initial_chars)))

In [None]:
def tokenize_word(word, id2token):
    tokenized_word = []
    for token in word:
        byte_token = token.encode('utf-8')
        token = ''
        for value in byte_token:
            token += id2token[value]
        tokenized_word.append(token)
    return tokenized_word

def merge(merge_pair: tuple[str, str], pair_frequences: Counter[tuple[str, str]], words_by_tokens: Counter[tuple[str]]):
    """Merges a given pair of tokens and update corresponding stats

    Args:
        merge_pair: The pair of tokens to be merged.
        pair_frequences: A counter tracking the frequency of token pairs in the dataset.
        words_by_tokens: A counter mapping tokenized words to their frequencies.

    Returns:
        Updated pair frequences and word tokenization w.r.t. to new token.
    """
    if merge_pair[1] == '':
        return pair_frequences, words_by_tokens
    deleted_words = {}
    pair = merge_pair[0] + merge_pair[1]
    for word in words_by_tokens:
        starts = []
        for i in range(len(word) - 1):
            if word[i] == merge_pair[0] and word[i+1] == merge_pair[1]:
                starts.append(i)
        for start in starts:
            if start >= 1:
                pair_frequences[(word[start-1], merge_pair[0])] -= words_by_tokens[word]
                pair_frequences[(word[start-1], pair)] += words_by_tokens[word]
            if start < len(word)-2:
                pair_frequences[(merge_pair[1], word[start+2])] -= words_by_tokens[word]
                pair_frequences[(pair, word[start+2])] += words_by_tokens[word]

        word_new = []
        prev_start = 0
        for start in starts:
            word_new += word[prev_start:start]
            word_new += [word[start] + word[start+1]]
            prev_start = start
        word_new += word[prev_start+2:]
        if starts:
            deleted_words[word] = tuple(word_new)
    for word in deleted_words:
        words_by_tokens[deleted_words[word]] = words_by_tokens[word]
        del words_by_tokens[word]
    return pair_frequences, words_by_tokens

def train(data: list[str], vocab_size: int = 1024, special_tokens: list[str] = None):
    """Train BPE tokenizer on passed data

    Args:
        data: List of train documents
        vocab_size: Size of target vocabulary
        special_tokens: List of special tokens to add into vocabulary
    Returns:
        vocabulary: mapping from string token to id
        merges: list of merges, each one is tuple of string tokens
    """
    if vocab_size < 256:
        raise ValueError("Vocab size can't be less than 256")
    if special_tokens is None:
        special_tokens = []

    # 1. Initialize vocabulary (using inverse one during training)
    id2token = bytes_to_unicode()
    merges = []
    # 2. Load data
    solo_tokens = Counter()
    words_by_tokens = Counter()
    for sample in tqdm(data, desc="Loading data"):
        # 2.1 Split into words
        words = WHITESPACE_SPLITTER.findall(sample.strip())
        for word in words:
            # 2.2 Tokenize with base vocabulary
            tokenized_word = tokenize_word(word, id2token)
            words_by_tokens.update([tuple(tokenized_word)])
            solo_tokens.update(tokenized_word)
    # 3. Calculate statistic of token's pairs
    pair_frequences = Counter()
    for word in words_by_tokens:
        for index in range(len(word) - 1):
            pair_frequences[(word[index], word[index+1])] += words_by_tokens[word]
    for token in solo_tokens:
        pair_frequences[(token, '')] = solo_tokens[token]
    # 4. Build vocabulary
    pbar = trange(vocab_size, desc="Building vocabulary", initial=len(id2token) + len(special_tokens))
    while len(id2token) < vocab_size - len(special_tokens):
        if len(pair_frequences) == 0:
            print("Not enough data to fulfil vocabulary")
            break

        # 4.1 Find the most frequent pair and create new token
        top_pair = pair_frequences.most_common(1)[0][0]
        new_token = top_pair[0] + top_pair[1]
        del pair_frequences[top_pair]

        # 4.2 Add to vocabulary
        if new_token in id2token.values():
            continue
        id2token[len(id2token)] = new_token
        merges.append(top_pair)

        # 4.3 Update stats and merge the top pair in all tokens
        pair_frequences, words_by_tokens = merge(top_pair, pair_frequences, words_by_tokens)

        pbar.update()
    pbar.close()

    # 5. Add special tokens
    for special_token in special_tokens:
        id2token[len(id2token)] = special_token

    return {v: k for k, v in id2token.items()}, merges

In [12]:
# Обучаем токенизатор на тренировочных текстах
# Для нашей задачи хватит и небольшого словаря, но можете пробовать и большего размера обучить!

# vocab, merges = train(dataset["train"]["text"], vocab_size=1024, special_tokens=["[EOS]"])

In [13]:
# Посмотрим на случайные токены

# random_tokens = [512, 614, 768, 888, 1022]
# unicode_to_bytes = {v: k for k, v in bytes_to_unicode().items()}
# for token_id in random_tokens:
#     token = [k for k, v in vocab.items() if v == token_id][0]
#     raw_bytes = bytes([unicode_to_bytes[it] for it in token])
#     print(f"Token #{token_id}: '{raw_bytes.decode('utf-8', errors='replace')}'")

NameError: name 'vocab' is not defined

In [129]:
class ByteLevelBPETokenizer:

    def __init__(self, vocab: dict[str, int], merges: list[tuple[str, str]], eos_token: str = "[EOS]"):
        """Byte-Level BPE Tokenizer

        Args:
            vocab: mapping from string token to id
            merges: list of merges in prioritized order
            eos_token: string representation of EOS token
        """
        super().__init__()
        if eos_token not in vocab:
            raise ValueError("There is no EOS token in vocab")
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        self.token2id = vocab
        self.id2token = {v: k for k, v in self.token2id.items()}
        self.eos_token = eos_token
        self.eos_token_id = self.token2id[eos_token]

        # The closer the pair is to the beginning, the higher the rank
        self.merges = merges
        self.bpe_ranks = {pair: i for i, pair in enumerate(merges)}

    @lru_cache
    def bpe(self, word: tuple[str]) -> tuple[str]:
        """Process word into tokenized representation.
        Word is a tuple of base tokens, i.e. bytes.

        Under the hood:
        1. Tracks the set of token pairs, bi-grams
        2. While possible, replaces the highest-ranking pair with its union

        Args:
            word: list of base string tokens
        Return:
            list of BPE tokens
        """
        while True:
            rank = float('inf')
            index = -1
            for i in range(len(word) - 1):
                if (word[i], word[i+1]) in self.bpe_ranks:
                    if self.bpe_ranks[(word[i], word[i+1])] < rank:
                        rank = self.bpe_ranks[(word[i], word[i+1])]
                        index = i
            if index == -1:
                break
            word = word[:index] + tuple([word[index] + word[index+1]]) + word[index+2:]
        return word

    def encode(self, text: str, add_eos_token: bool = True) -> list[int]:
        """Convert string to list of token ids.

        Args:
            text: input string, may contain multiple words
            add_eos_token: whether to add eos token id at the end
        Return:
            list of ints, ids of tokenized text
        """
        words = WHITESPACE_SPLITTER.findall(text)
        tokens = []
        for word in words:
            tokenized_word = tokenize_word(word, self.id2token)
            bpe_res = self.bpe(tuple(tokenized_word))
            tokens += [self.token2id[token] for token in bpe_res if token in self.token2id]
        if add_eos_token:
            tokens += [self.token2id[self.eos_token]]
        return tokens
    def decode(self, idx: list[int]) -> str:
        """Convert list of tokens' ids to text, opposite to encode method

        Args:
            idx: list of tokens' ids
        Return:
            string, decoded text
        """
        string = ''
        for i in idx:
            byte = [self.byte_decoder[tok] for tok in self.id2token[i]]
            string += bytes(byte).decode('utf-8')
            print(string)
        return string

    def push_to_hub(self, repo_id, *, private=None, token=None):
        api = HfApi()
        repo_id = api.create_repo(repo_id=repo_id, token=token, private=private, exist_ok=True).repo_id

        # Push the files to the repo in a single commit
        with SoftTemporaryDirectory() as tmp:
            save_directory = Path(tmp) / repo_id
            save_directory.mkdir(parents=True)
            with open(save_directory / "vocabulary.json", "w") as f_out:
                print(json.dumps(self.token2id, indent=2), file=f_out)
            with open(save_directory / "merges.json", "w") as f_out:
                print(json.dumps({"merges": self.merges}), file=f_out)

            return api.upload_folder(repo_id=repo_id, folder_path=save_directory, token=token, commit_message='new_tokenizer')

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *, token=None, **model_kwargs):
        if not os.path.isdir(pretrained_model_name_or_path):
            storage_folder = snapshot_download(repo_id=pretrained_model_name_or_path, token=token)
        else:
            storage_folder = pretrained_model_name_or_path
        storage_folder = Path(storage_folder)
        with open(storage_folder / "vocabulary.json", "r") as f_in:
            vocab = json.load(f_in)
        with open(storage_folder / "merges.json", "r") as f_in:
            merges = [tuple(it) for it in json.load(f_in)["merges"]]
        return cls(vocab, merges, **model_kwargs)

In [15]:
# Инициализируем токенизатор

# tokenizer = ByteLevelBPETokenizer(vocab, merges)

In [16]:
# Загружаем токенизатор на хаб

# tokenizer.push_to_hub(REPO_NAME)

In [130]:
# Скачиваем токенизатор с хаба

tokenizer = ByteLevelBPETokenizer.from_pretrained(REPO_NAME)

Fetching 6 files: 100%|██████████| 6/6 [00:00<00:00,  9.99it/s]


In [None]:
# Смотрим на работу токенизатора

text = "Что было полгода назад? Помимо грандиозных событий, полгода назад были ещё семинары по линейной алгебре."
ids = tokenizer.encode(text)
reverse_text = [tokenizer.decode([it]) for it in ids]
print("|".join(reverse_text))
print(tokenizer.decode(ids))

In [132]:
# Посчитаем немного статистики по токенизации, определимся с размером контекста у модели

lens = []
for text in tqdm(dataset["test"]["text"]):
    ids = tokenizer.encode(text)
    lens.append(len(ids))

print(f"Average token len per sample: {sum(lens) / len(lens):.2f}")
print(f"Minimum and maximum lens are: {min(lens)} and {max(lens)}")

100%|██████████| 15056/15056 [00:06<00:00, 2350.48it/s]

Average token len per sample: 69.67
Minimum and maximum lens are: 4 and 3207





Должно получиться в среднем по 70 токенов на последовательность.
Контекста в 128 токенов будет вполне достаточно.

# Модель [10 баллов]

В качестве модели реализуем трансформер, в котором
1. В качестве позиционных эмбеддингов используется ALiBi
2. Механизм внимания использует GQA
3. В Feed-Forward блоке SwiGLU

In [20]:
# Для удобства заведем конфиг для модели


@dataclass
class TransformerConfig:
    n_layer: int
    n_head: int
    n_kv_head: int
    hidden_dim: int
    intermediate_dim: int
    dropout: float = 0.1
    vocab_size: int = 1024
    max_seq_len: int = 128


model_configs = {
    "nano": TransformerConfig(n_layer=3, n_head=4, n_kv_head=2, hidden_dim=96, intermediate_dim=256),
    "mini": TransformerConfig(n_layer=6, n_head=6, n_kv_head=3, hidden_dim=384, intermediate_dim=1024),
    "small": TransformerConfig(n_layer=12, n_head=12, n_kv_head=6, hidden_dim=768, intermediate_dim=2048),
}

In [118]:

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        """Root Mean Square Layer Normalization

        Args:
            dim: Feature dimension
            eps: Small constant for numerical stability
        """
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        sum_dim = torch.sqrt((x * x).sum(-1) / x.shape[-1]+ self.eps).unsqueeze(-1)
        scaled = (x / sum_dim) * self.scale.repeat(x.shape[0], 1, 1)
        return scaled

class CausalSelfAttention(nn.Module):
    def __init__(self, config: TransformerConfig):
        """Causal Self-Attention with support of
        Grouped-Query Attention and ALiBi for positional encoding
        """
        super().__init__()
        self.config = config
        assert self.config.hidden_dim % self.config.n_head == 0
        assert self.config.n_head % self.config.n_kv_head == 0
        self.head_dim = self.config.hidden_dim // self.config.n_head
        self.scale = self.head_dim**-0.5
        self.q_per_kv = self.config.n_head // self.config.n_kv_head

        # Init projection layers
        self.q_proj = nn.Linear(self.config.hidden_dim, self.head_dim * self.config.n_head)
        # почему тут 1 слой на k и v?
        # self.q_proj = todo()
        # self.kv_proj = todo()
        # self.out_proj = todo()
        self.k_proj = nn.Linear(self.config.hidden_dim, self.head_dim * self.config.n_kv_head)
        self.v_proj = nn.Linear(self.config.hidden_dim, self.head_dim * self.config.n_kv_head)

        self.attn_dropout = nn.Dropout(self.config.dropout)
        
        #self.out_proj = 
        self.register_buffer("causal_mask", self._create_causal_mask(self.config.max_seq_len))
        self.register_buffer("alibi", self._build_alibi_bias(self.config.n_head))

    def _build_alibi_bias(self, num_heads: int) -> Tensor:
        """Build ALiBi for specified number of heads:

        Returns:
            Tensor with ALiBi biases, shape: [1, num heads, 1, 1]
        """
        step = 8 / num_heads 
        arange = [(1/2)**n for n in torch.arange(step, 8 + step, step)]
        alibi = torch.tensor(arange).view(1, num_heads, 1, 1)
        return alibi

    def _create_causal_mask(self, max_seq_len: int) -> Tensor:
        """Create causal mask with ones where tokens can attend to each other.

        Returns:
            Tensor with causal mask, shape: [1, 1, seq len, seq len]
        """
        mask = [[0 if j > i else 1 for j in range(max_seq_len)] for i in range(max_seq_len)]
        return torch.tensor(mask).view(1, 1, max_seq_len, max_seq_len)

    def forward(self, x: Tensor, attention_mask: Tensor = None) -> Tensor:
        """Apply Self-Attention to input data with respect to pad tokens.

        Args:
            x: input tensor, shape [bs, seq len, hidden dim]
            attention_mask: mask with zeros for pad tokens, shape [bs, seq len]
        Returns:
            result tensor, shape [bs, seq len, hidden dim]
        """
        b_s = x.shape[0]
        seq_len = x.shape[1]
        
        # Применим слои линейные сначала
        q = rearrange(self.q_proj(x), "b_s seq_len (h g head_dim) -> b_s g h seq_len head_dim", head_dim=self.head_dim, h=self.config.n_kv_head) 
        k = rearrange(self.k_proj(x), "b_s seq_len (h head_dim) -> b_s h seq_len head_dim", head_dim=self.head_dim)
        v = rearrange(self.v_proj(x), "b_s seq_len (h head_dim) -> b_s h seq_len head_dim", head_dim=self.head_dim)
        # Как итог размерности 
        # q - (b_s, g, h, seq_len, head_dim)
        # v - (b_s, h, seq_len, head_dim)
        # k - (b_s, h, seq_len, head_dim)
        # print(q.shape, k.shape, v.shape)
        # g - количество крупп, h - количество голов V, K, head_dim - размер головы, b_s - batch_size
        relevance = einsum(q, k, "b_s g h seq_len head_dim, b_s h seq_len_1 head_dim -> b_s g h seq_len seq_len_1")
        relevance = rearrange(relevance, "b_s g h seq_len seq_len_1 -> b_s (g h) seq_len seq_len_1")
        
        # Добавим alibi
        diff_matrix = [[[j - i if j< i else 0 for j in range(seq_len)] for i in range(seq_len)] for b in range(b_s)]
        bias = torch.tensor(diff_matrix, device='cuda' if torch.cuda.is_available() else 'cpu').view(b_s, 1, seq_len, seq_len).repeat(1, self.config.n_head, 1, 1)
        ALiBi = self.alibi * bias
        
        relevance = relevance + ALiBi
        # нигде не использую casual_mask пока что
        if attention_mask is not None:
            attention_mask = attention_mask.to(torch.bool)
            attention_mask = rearrange(attention_mask, "b_s seq_len-> b_s () () seq_len")
            
            relevance = relevance.masked_fill_(~attention_mask, torch.finfo(relevance.dtype).min)
        relevance = torch.softmax(relevance, dim=-1)
        
        relevance = self.attn_dropout(relevance)
        relevance = rearrange(relevance, "b_s (g h) seq_len seq_len_1 -> b_s g h seq_len seq_len_1",h=self.config.n_kv_head)
        values = einsum(relevance, v, "b_s g h seq_len seq_len_1, b_s h seq_len head_dim -> b_s g h seq_len head_dim")
        values = rearrange(values, " b_s g h seq_len head_dim -> b_s seq_len (g h head_dim)")
        return values
    
class SwiGLU(nn.Module):
    def __init__(self, config: TransformerConfig):
        """Gated Liner Unit with Swish Activation"""
        super().__init__()
        self.config = config
        # Init up- and down- projection layers
        self.fc_up1 = nn.Linear(self.config.hidden_dim, self.config.intermediate_dim)
        self.fc_up2 = nn.Linear(self.config.hidden_dim, self.config.intermediate_dim)
        self.fc_down = nn.Linear(self.config.intermediate_dim, self.config.hidden_dim)

    def forward(self, x: Tensor) -> Tensor:
        """Apply SwiGLU to input data.

        Args:
            x: input tensor, shape [bs, seq len, hidden dim]
        Returns:
            result tensor, shape [bs, seq len, hidden dim]
        """
        
        up_1 = torch.sigmoid(self.fc_up1(x))
        up_2 = self.fc_up2(x)
        silu = nn.SiLU()
        up_2 = silu(up_2)
        multi = up_1 * up_2
        
        return self.fc_down(multi)

class Block(nn.Module):
    def __init__(self, config: TransformerConfig):
        """Base Transformer Block
        - Causal Self-Attention and SwiGLU as main elements
        - Pre-normalization via RMSNorm
        - Regularization with dropouts before residuals
        """
        super().__init__()
        self.ln_1 = RMSNorm(config.hidden_dim)
        self.res_dropout_1 = nn.Dropout(config.dropout)
        self.attn = CausalSelfAttention(config)

        self.ln_2 = RMSNorm(config.hidden_dim)
        self.res_dropout_2 = nn.Dropout(config.dropout)
        self.mlp = SwiGLU(config)

    def forward(self, x: Tensor, attention_mask: Tensor = None) -> Tensor:
        """Apply Transformer Block to input data.

        Args:
            x: input tensor, shape [bs, seq len, hidden dim]
            attention_mask: mask with zeros for pad tokens, shape [bs, seq len, hidden dim]
        Returns:
            result tensor, shape [bs, seq len, hidden dim]
        """
        norm_1 = self.ln_1(x)
        drop_1 = self.res_dropout_1(norm_1)
        attn = self.attn(drop_1, attention_mask)
        
        norm_2 = self.ln_2(attn)
        drop_2 = self.res_dropout_2(norm_2)
        attn_full = self.mlp(drop_2)
        
        return attn_full

class TransformerForCausalLM(nn.Module, PyTorchModelHubMixin):
    def __init__(self, config: TransformerConfig):
        """Transformer model for Language Modeling"""
        super().__init__()
        self.vocab_size = config.vocab_size
        self.max_seq_len = config.max_seq_len
        self.n_layer = config.n_layer
        self.n_head = config.n_head
        self.hidden_dim = config.hidden_dim
        self.dropout = config.dropout

        self.token_emb = nn.Embedding(self.vocab_size, self.hidden_dim)
        self.emb_dropout = nn.Dropout(config.dropout)
        self.layers = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
        self.ln_final = RMSNorm(config.hidden_dim)
        self.lm_head = nn.Linear(self.hidden_dim, self.vocab_size)

        self.apply(self._init_weights)

        n_params = sum(p.numel() for p in self.parameters())
        print(f"Number of parameters: {n_params / 1e6:.2f}M")

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, RMSNorm):
            torch.nn.init.ones_(module.scale)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor = None) -> Tensor:
        """Calculate logits for given input ids.

        Args:
            x: input tensor, shape [bs, seq len]
            attention_mask: mask with zeros for pad tokens, shape [bs, seq len, hidden dim]
        Returns:
            logits, shape [bs, seq len, vocab_size]
        """
        embeds = self.token_emb(input_ids)
        embeds = self.emb_dropout(embeds)
        
        for layer in self.layers:
            embeds = layer(embeds, attention_mask)
        
        embeds_drop = self.ln_final(embeds)
        return self.lm_head(embeds_drop)
        

    @torch.inference_mode()
    def generate(
        self, idx: Tensor, max_new_tokens, eos_token_id, temperature=1.0, do_sample=False, top_k=None
    ) -> Tensor:
        """Take a conditioning sequence of indices and complete the sequence max_new_tokens times,
        feeding the predictions back into the model each time.

        Args:
            idx: tensor with conditional tokens, shape [seq len]
            max_new_tokens: maximum number of new tokens
            eos_token_id: index of EOS token to stop generation
            temperature, do_sample, top_k: generation parameters
        Return:
            tensor with generated indexes
        """
        for _ in range(max_new_tokens):
            idx_cond = idx if idx.shape[1] <= self.max_seq_len else idx[:, -self.max_seq_len :]
            logits = self(idx_cond)
            # 1. Pluck the logits at the final step and scale by desired temperature
            logits = (logits[0][-1])**(1/temperature)
            # 2. Optionally crop the logits to only the top k options
            if top_k is not None:
                top = [i.item() for i in torch.topk(logits, top_k).indices]
                mask = list(set(range(len(logits))) - set(top))
                logits[mask] = -float("inf")
            # 3. apply softmax to convert logits to probabilities
            probs = torch.softmax(logits, dim=-1)
            #print(torch.sum(probs), 'probs')
            #print(torch.isnan(prods).sum())
            # 4. Either sample from the distribution or take the most likely element
            if do_sample:
                idx_next = torch.tensor(np.random.choice(range(self.vocab_size), p=np.nan_to_num(probs.cpu().numpy())), device='cuda' if torch.cuda.is_available()
                                       else 'cpu')
            else:
                idx_next = torch.argmax(probs)
            # print(idx_next.reshape(1, 1), idx)
            idx_next = idx_next.reshape(1, 1)
            
            # 5. Append sampled index to the running sequence and continue
            idx = torch.cat((idx, idx_next), dim=1)
            if idx_next == eos_token_id:
                break
        return idx

# Train Loop [2 + 2 балла]

Настало время обучать модель.
Небольшую можно пробовать обучать локально, но лучше всего воспользоваться GPU, например, на Google Colab.

За реализацию 2 балла, и еще 2 балла - если модель научилась генерить анекдоты.

Не забудьте проверить, что вы загрузили нужные веса на HF и у проверяющего скачается нужная версия.

In [100]:
# Определим датасет и как заворачивать семплы в батч
# Разные тексты имеют разную длину, поэтому будет падить до самого длина семпла
# Так же заведем дополнительную маску, чтобы механизм внимания не учитывал падинги


class TextDataset(torch.utils.data.Dataset):
    def __init__(self, texts, tokenizer):
        self.texts = texts
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        texts = self.texts[idx]
        tokenized_sequence = self.tokenizer.encode(texts)
        return tokenized_sequence


def data_collator(
    tokenized_sequences: list[list[int]], pad_token_id: int, max_seq_len: int = None
) -> tuple[torch.Tensor, torch.Tensor]:
    batch_size = len(tokenized_sequences)
    max_batch_seq_len = min(max_seq_len, max((len(it) for it in tokenized_sequences)))

    input_ids = torch.full((batch_size, max_batch_seq_len), pad_token_id)
    attention_mask = torch.zeros((batch_size, max_batch_seq_len))

    for i, tok_seq in enumerate(tokenized_sequences):
        cur_len = min(len(tok_seq), max_batch_seq_len)
        input_ids[i, :cur_len] = torch.tensor(tok_seq[:cur_len])
        attention_mask[i, :cur_len] = 1

    return input_ids, attention_mask


def create_dataloader(dataset, pad_token_id, max_seq_len, batch_size, is_train):
    collate_fn = partial(data_collator, pad_token_id=pad_token_id, max_seq_len=max_seq_len)
    return DataLoader(
        dataset, batch_size=batch_size, shuffle=is_train, drop_last=is_train, collate_fn=collate_fn, pin_memory=True
    )


_d = TextDataset(["Привет!", "Как твои дела?", "Осталось совсем немного до конца"], tokenizer)
_dl = create_dataloader(_d, tokenizer.eos_token_id, max_seq_len=16, batch_size=3, is_train=False)

for i, batch in enumerate(_dl):
    print(f"Batch #{i}")
    input_ids, attn_mask = batch
    print(input_ids, attn_mask.shape, sep="\n\n")

Batch #0
tensor([[1007,  589,   33, 1023, 1023, 1023, 1023, 1023, 1023, 1023, 1023, 1023,
         1023, 1023],
        [ 373,  339,  940,  260,  682,   63, 1023, 1023, 1023, 1023, 1023, 1023,
         1023, 1023],
        [ 375,  410,  676,  395,  264,  262,  323,  312,  269,  531,  365,  744,
          526, 1023]])

torch.Size([3, 14])


In [28]:
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    """Scheduler for Optimizer with linear warmup and linear decay to the end of training

    Args:
        optimizer: torch optimizer to control learning rate
        num_warmup_steps: number of warmup steps
        num_training_steps: total number of training steps
    Return:
        torch learning rate scheduler
    """
    assert num_training_steps >= num_warmup_steps
    max_lr = 0.1

    def lr_lambda(current_step):
        min_lr = 0
        start_lr = 0.01
        max_lr = 0.1
        step_down = (max_lr - min_lr) / (num_training_steps - num_warmup_steps)
        step_up = (max_lr - start_lr) / num_warmup_steps
        if current_step <= num_warmup_steps:
            return start_lr + current_step * step_up
        else:
            return max_lr * 2 / 3- (current_step - num_warmup_steps) * step_down
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


def cross_entropy_loss(input_ids: Tensor, attention_mask: Tensor, logits: Tensor) -> Tensor:
    """Calculate Cross-Entropy loss for Language Modeling task
    Under the hood:
    1. Create targtes based on input ids
    2. Masked out tokens corresponded to paddings
    3. Calculate cross entropy loss

    Args:
        input_ids: tensor with input ids, shape [bs, seq len]
        attention_mask: mask with zeros for pad tokens, shape [bs, seq len]
        logits: predicted logits, shape [bs, seq len, vocab size]
    Return:
        cross entropy loss, single-item tensor
    """
    loss = nn.CrossEntropyLoss(ignore_index=tokenizer.eos_token_id)
    targets = input_ids[:, 1:]
    return loss(logits[:, :-1, :].transpose(1, 2), targets)

In [29]:
# import wandb
# wandb.login(key='03473c2b8e9b6a50995c56a4492a3bcd7da7483f')

# wandb.init(project="vk_hw1", name="second_try")

In [30]:
# Определим тренера с наиболее важными гиперпараметрами для обучения


class Trainer:

    def __init__(
        self,
        learning_rate=3e-4,
        weight_decay=0.01,
        clip_grad_norm=1.0,
        n_steps=10_000,
        val_every_n_steps=1_000,
        plot_every_n_steps=100,
    ):
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.clip_grad_norm = clip_grad_norm
        self.n_steps = n_steps
        self.val_every_n_steps = val_every_n_steps
        self.plot_every_n_steps = plot_every_n_steps

        if torch.cuda.is_available():
            self.device = "cuda"
        elif torch.backends.mps.is_available():
            self.device = "mps"
        else:
            self.device = "cpu"
        print("running on device", self.device)

    @torch.no_grad()
    def validate(self, model, val_loader):
        model.eval()
        val_loss = 0.0
        for batch in tqdm(val_loader, desc="Validating", leave=False):
            input_ids, attention_mask = batch
            input_ids = input_ids.to(self.device, non_blocking=True)
            attention_mask = attention_mask.to(self.device, non_blocking=True)

            logits = model(input_ids, attention_mask)  # [bs; seq len; vocab size]
            val_loss += cross_entropy_loss(input_ids, attention_mask, logits)
        return val_loss / len(val_loader)

    def run(self, model, train_loader, val_loader):
        model = model.to(self.device)
        optimizer = torch.optim.AdamW(model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=0.1 * self.n_steps, num_training_steps=self.n_steps
        )
        model.train()
        # wandb.watch(model, log='all', criterion=torch.nn.CrossEntropyLoss, log_freq=100)
        plotlosses = PlotLosses(figsize=(15, 9), step_names="Step")
        logs = {"lr": 0, "epoch": 0}

        data_iter = iter(train_loader)
        for iter_num in range(self.n_steps):
            try:
                batch = next(data_iter)
            except StopIteration:
                data_iter = iter(train_loader)
                logs["epoch"] += 1
                batch = next(data_iter)

            input_ids, attention_mask = batch
            input_ids = input_ids.to(self.device, non_blocking=True)
            attention_mask = attention_mask.to(self.device, non_blocking=True)

            logits = model(input_ids, attention_mask)  # [bs; seq len; vocab size]
            loss = cross_entropy_loss(input_ids, attention_mask, logits)
            
            # backprop and update the parameters
            model.zero_grad(set_to_none=True)
            loss.backward()
            # torch.nn.utils.clip_grad_norm_(model.parameters(), self.clip_grad_norm)
            # for param in model.parameters():
            #     print('------------------------------------------------')
            #     print(param.sum().item(), 'суммма witghs\n', 
            #           param.grad.sum(), 'суммма градиентов\n', 
            #           torch.isnan(param.grad).sum(), 'сумма нулов\n',
            #           torch.isclose(param.grad, torch.zeros(param.grad.shape, device='cuda')).sum(), 'равные нулю градиенты\n')
            optimizer.step()
            scheduler.step()

            if iter_num > 0 and iter_num % self.val_every_n_steps == 0:
                val_loss = self.validate(model, val_loader)
                plotlosses.update({"val_loss": val_loss.item()}, current_step=iter_num)
                plotlosses.send()
                model.train()

            if iter_num % self.plot_every_n_steps == 0:
                logs["loss"] = loss.item()
                logs["lr"] = scheduler.get_last_lr()[0]
                plotlosses.update(logs, current_step=iter_num)
                plotlosses.send()

        val_loss = self.validate(model, val_loader)
        plotlosses.update({"val_loss": val_loss.item()}, current_step=iter_num)
        plotlosses.send()

In [31]:
# Создаем тренировочный и тестовые даталоадеры


MAX_SEQ_LEN = 128
BATCH_SIZE = 16

train_dataset = TextDataset(dataset["train"]["text"], tokenizer)
train_dataloader = create_dataloader(
    train_dataset, tokenizer.eos_token_id, max_seq_len=MAX_SEQ_LEN, batch_size=BATCH_SIZE, is_train=True
)

test_dataset = TextDataset(dataset["test"]["text"], tokenizer)
test_dataloader = create_dataloader(
    test_dataset, tokenizer.eos_token_id, max_seq_len=MAX_SEQ_LEN, batch_size=BATCH_SIZE, is_train=False
)

In [32]:
# Инициализируем модель

config = model_configs["mini"]
model = TransformerForCausalLM(config)

Number of parameters: 9.66M


In [33]:
# Инициализируем тренера

trainer = Trainer(learning_rate=3e-4)

running on device cuda


In [None]:
# Обучение goes brrrr!
trainer.run(model, train_dataloader, test_dataloader)

In [None]:
# Инициализируем модель

config = model_configs["mini"]
model = TransformerForCausalLM(config).to('cuda' if torch.cuda.is_available() else 'cpu')
model.from_pretrained(REPO_NAME)

In [None]:
# Смотрим на качество генерации глазами
# Для маленьких и слабых моделей "затягиваем" гайки генерации

text = "Заходит в бар"
input_ids = torch.tensor(tokenizer.encode(text)[:-1], device=trainer.device)[None, :]
print(input_ids)
model_output = model.generate(
    input_ids, max_new_tokens=200, eos_token_id=tokenizer.eos_token_id, do_sample=True, top_k=10
)
tokenizer.decode(model_output[0].tolist())

In [134]:
# Загружаем модель на хаб

# model.push_to_hub(REPO_NAME)

In [135]:
# torch.save(model.state_dict(), 'model.pt')

In [136]:
# model.push_to_hub(REPO_NAME)

Поиграйтесь с гиперпараметрами, попробуйте обучить `mini` и `small` версии.
Постарайтесь добиться как можно более высокого качества как в терминах лосса, так и при визуальной оценке генерации.

### Дополнительные баллы

Вы также можно заработать дополнительные баллы:
- Реализовать Rotary Positional Embedding **[4 балла]**
- Реализовать Multi-Head Latent Attention **[2 балл]**
- Оформить репозиторий на 🤗: карточка модели с описанием задания, репортом качества и примерами генерации **[2 балл]**

### НЕ получилось обучить норм модель, просто не хватило времени( 
### Модель генерирует последовательность, которую не может декодировать токенайзер. Наилучший лосс, который выбил - 5.

# Специальный раздел для проверяющего

In [None]:
device = torch.device("cuda")

tokenizer = ByteLevelBPETokenizer.from_pretrained(REPO_NAME)
check_model = TransformerForCausalLM.from_pretrained(REPO_NAME)
check_model = check_model.to(device)
check_model = check_model.eval()

In [None]:
text = "Штирлиц пришел домой"
input_ids = torch.tensor(tokenizer.encode(text), device=device)
model_output = check_model.generate(
    input_ids[None, :], max_new_tokens=200, eos_token_id=tokenizer.eos_token_id, do_sample=True, top_k=10
)
tokenizer.decode(model_output[0].tolist())