# Imports

In [1]:
import numpy as np
import torch

from torch import nn
from torch.utils.data import DataLoader, Dataset
from torch.utils.checkpoint import checkpoint, checkpoint_sequential
import transformers

import pandas as pd
import re
import os
import math
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

import warnings
warnings.filterwarnings('ignore')

from dotenv import load_dotenv
load_dotenv()

True

Environment variables

In [2]:
# Project locations and names are read from the environment (populated by load_dotenv() above).
# os.getenv returns None for any variable that is not set.
# NOTE(review): "PATH" is also the name of the OS executable search path, and PATH is not
# used by any code visible in this notebook — confirm this variable is intended.
PATH = os.getenv("PATH")
DATAPATH = os.getenv("DATAPATH")
PREPARED_DATA_DIR = os.getenv("PREPARED_DATA_DIR")
CACHE_DIR = os.getenv("CACHE_DIR")  # passed as cache_dir when the tokenizer is loaded
#TOK_NAME = "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B"
TOK_NAME = os.getenv("TOK_NAME")  # tokenizer identifier given to AutoTokenizer.from_pretrained
PARQUET_DATA_DIR = os.getenv("PARQUET_DATA_DIR")  # passed directly to pd.read_parquet
CHECKPOINTS_PATH = os.getenv("CHECKPOINTS_PATH")  # directory where checkpoint_latest.pth is written/read
NPY_DATA_DIR = os.getenv("NPY_DATA_DIR")  # directory the tokenized *.npy shards are loaded from

## Config

In [4]:
# Hyperparameters consumed by GPTModelRev and its sub-modules.
GPT_CONFIG = {
    'vocab_size': 50257, # equals tok.vocab_size of the tokenizer inspected below; the earlier note mentioned 151670. tokenizer.vocab_size does not count added tokens
    'context_length': 1024, # maximum sequence length; also the size of the positional-embedding table
    'emb_dim': 256, # model width (earlier value: 768); each reversible stream works on emb_dim // 2
    'n_heads': 32, # earlier value: 12; must divide emb_dim // 2 (asserted in MultiHeadAttentionDP_QKV)
    'n_layers': 32, # earlier value: 12
    'drop_rate': 0.05, # earlier value presumably 0.1 (the original note read "0l1") — TODO confirm
    'qkv_bias': False,
    'num_segments': 2, # NOTE(review): not read by any code visible in this notebook — confirm it is still needed
    'initializer_range': 0.02 # only referenced by a commented-out weight-init call in GPTModelRev
    }

In [5]:
device = 'cuda' if (torch.cuda.is_available()) else 'cpu'
device

'cuda'

# Dataset

## Load Tokenizer

In [6]:
tok = transformers.AutoTokenizer.from_pretrained(TOK_NAME, cache_dir=CACHE_DIR)

In [7]:
# NOTE(review): '<|pad|>' is only assigned as the pad-token string here; it is not added to the
# vocabulary. The recorded outputs below show get_added_vocab() containing only '<|endoftext|>',
# a literal '<|pad|>' being split into ordinary sub-tokens, and padded positions receiving id
# 50256 (the '<|endoftext|>' id). Padding is therefore indistinguishable from EOS — confirm
# whether a dedicated pad id (and a matching vocab_size) is wanted.
tok.pad_token = '<|pad|>'

### Check tokenizer

In [6]:
tok.get_added_vocab()

{'<|endoftext|>': 50256}

In [7]:
tok.eos_token, tok.bos_token, 

('<|endoftext|>', '<|endoftext|>')

In [8]:
tok.vocab_size

50257

In [7]:
tok.pad_token, tok.eos_token, tok.bos_token, 

('<|pad|>', '<|endoftext|>', '<|endoftext|>')

In [None]:
# If the tokenizer does not have a pad_token:
#tok.pad_token = tok.eos_token # No — reusing EOS as the padding token is not the best idea

In [22]:
tok('Привет, как дела mhjm <|endoftext|><|pad|><|pad|><|pad|><|pad|>', return_tensors='pt', padding='max_length', max_length=100)['input_ids']

tensor([[  140,   253, 21169, 18849, 38857, 16843, 20375,    11, 12466,   118,
         16142, 31583, 12466,   112, 16843, 30143, 16142,   285,    71,    73,
            76,   220, 50256,    27,    91, 15636,    91,  6927,    91, 15636,
            91,  6927,    91, 15636,    91,  6927,    91, 15636,    91,    29,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256]])

## Iterative version of dataset

In [10]:
class CustomDatasetV3(Dataset):
    """Per-sample next-token dataset built eagerly from a DataFrame.

    Every row's ``'Sample'`` text is tokenized to exactly ``max_length + 1`` ids
    (padded or truncated), then split into an input window (ids ``0..max_length-1``)
    and a target window shifted by one position (ids ``1..max_length``).

    Args:
        dataframe: frame with a ``'Sample'`` column holding the raw text.
        tokenizer: callable returning a mapping with ``'input_ids'`` as a
            ``(1, n)`` tensor when called with ``return_tensors='pt'``.
        max_length: length of every input/target window.
    """

    def __init__(self, dataframe: pd.DataFrame, tokenizer: object, max_length: int):
        self.tokenizer = tokenizer
        self.input_ids = []
        self.target_ids = []

        # Iterating the column directly avoids building a Series per row (iterrows).
        for text in tqdm(dataframe['Sample'], total=dataframe.shape[0]):
            # truncation=True caps the encoding at max_length + 1 ids, so over-long samples
            # are not fully materialised only to be sliced away afterwards.
            token_ids = tokenizer(
                text,
                return_tensors='pt',
                padding='max_length',
                truncation=True,
                max_length=max_length + 1,
            )['input_ids']
            self.input_ids.append(token_ids[0, :max_length])
            self.target_ids.append(token_ids[0, 1:max_length + 1])

    def __len__(self):
        """Number of samples."""
        return len(self.input_ids)

    def __getitem__(self, index):
        """Return the ``(input_ids, target_ids)`` pair for ``index``."""
        return self.input_ids[index], self.target_ids[index]

### Load actual data and dataloader

In [7]:
data_parquet = pd.read_parquet(PARQUET_DATA_DIR)

In [12]:
data_parquet.shape

(224662, 2)

In [13]:
train_cd = CustomDatasetV3(dataframe=data_parquet.iloc[:100000], tokenizer=tok, max_length=GPT_CONFIG['context_length'])#MY_GPT_CONFIG['context_length'])
#train_cd = CustomDatasetV3(dataframe=data_parquet.iloc[:100], tokenizer=tok, max_length=GPT_CONFIG['context_length'])#MY_GPT_CONFIG['context_length'])

  0%|          | 0/100000 [00:00<?, ?it/s]

In [14]:
val_cd = CustomDatasetV3(dataframe=data_parquet.iloc[-10000:], tokenizer=tok, max_length=GPT_CONFIG['context_length'])#MY_GPT_CONFIG['context_length'])
#val_cd = CustomDatasetV3(dataframe=data_parquet.iloc[-100:], tokenizer=tok, max_length=GPT_CONFIG['context_length'])#

  0%|          | 0/10000 [00:00<?, ?it/s]

TODO: check which batch_size fits in memory — maybe 8 or 12 (or 16).

In [15]:
train_data = DataLoader(dataset=train_cd, batch_size=12, shuffle=True, num_workers=0) # num_workers=2 don't work?
val_data = DataLoader(dataset=val_cd, batch_size=12, shuffle=True, num_workers=0) # num_workers=2 don't work?

In [16]:
next(iter(train_data))

[tensor([[12466,   251, 16843,  ..., 35072, 21169, 18849],
         [  198,   198,   140,  ..., 35072,   141,   229],
         [  198,   198,   140,  ..., 16843, 21169,   141],
         ...,
         [12466,   239, 45035,  ...,   198,   198,   140],
         [  198,   198,   140,  ...,   141,   229, 16142],
         [12466,   252, 22177,  ..., 15166, 18849,   141]]),
 tensor([[  251, 16843, 21727,  ..., 21169, 18849, 15166],
         [  198,   140,   239,  ...,   141,   229, 18849],
         [  198,   140,   246,  ..., 21169,   141,   227],
         ...,
         [  239, 45035, 30143,  ...,   198,   140,    94],
         [  198,   140,   240,  ...,   229, 16142, 16843],
         [  252, 22177,   220,  ..., 18849,   141,   227]])]

## Summed version of text

In [13]:
from pandarallel import pandarallel # import pandarallel
pandarallel.initialize() # initialize pandarallel
tqdm.pandas()

INFO: Pandarallel will run on 8 workers.
INFO: Pandarallel will use standard multiprocessing data transfer (pipe) to transfer data between the main process and workers.

https://nalepae.github.io/pandarallel/troubleshooting/


In [14]:
def CustomTextProcerssing(dataframe, eos_token):#tok.pad_token, tok.eos_token, tok.bos_token, 
    """Concatenate every ``'Sample'`` text into one string, appending ``eos_token`` after each.

    Args:
        dataframe: frame with a ``'Sample'`` column of strings.
        eos_token: separator string appended after every sample.

    Returns:
        The concatenated text; an empty string for an empty frame.
    """
    # A single join over the column replaces the per-row iterrows() + repeated `+=`:
    # no per-row Series construction and no repeated re-allocation of the growing string.
    return ''.join(
        text + eos_token
        for text in tqdm(dataframe['Sample'], total=dataframe.shape[0])
    )

In [7]:
data_parquet = pd.read_parquet(PARQUET_DATA_DIR)

In [11]:
sum_txt = CustomTextProcerssing(data_parquet, tok.eos_token)

  0%|          | 0/224662 [00:00<?, ?it/s]

KeyboardInterrupt: 

Все сразу сделать

In [8]:
data_parquet['Sample'] = data_parquet['Sample'] + tok.eos_token

In [9]:
data_parquet['Sample']

0         Литва́ ( ), официальное название — Лито́вская ...
1          В 1422 году в состав великого княжества оконч...
2         \n\nВ 1919 году в Литве введена должность през...
3          Блокада продолжалась 74 дня и прекратилась по...
4         \n\nВнутренняя политика \n\nВ июне 2008 года п...
                                ...                        
224657     В постсоветских изданиях за клубом закрепилос...
224658    \n\n1974 год становится для «Баварии» и её игр...
224659     Но в четвертьфинала путь «Баварии» преградила...
224660     6 апреля 2013 года, за шесть туров до окончан...
224661     Оба гола были забиты в дополнительное время т...
Name: Sample, Length: 224662, dtype: object

In [None]:
summed = data_parquet['Sample'].sum()

In [53]:
re.sub(
        r'[\u0000-\u001F\u007F\u0080-\u009F\u00A0\u00AD\u200B\u200C\u200D\uFEFF]',
        '',
        data_parquet['Sample'].iloc[:3].sum()
    )

'Литва́ ( ), официальное название— Лито́вская Респу́блика ()— государство, расположенное в Северной Европе Площадь—  км² Протяжённость с севера на юг— 280км, а с запада на восток— 370 км Население составляет  человек (август, 2023) Занимает 137-е место в мире по численности населения и 121-е по территории Имеет выход к Балтийскому морю, расположена на его восточном побережье Береговая линия составляет всего 99 км (наименьший показатель среди государств Балтии) На севере граничит с Латвией, на юго-востоке— с Белоруссией, на юго-западе— с Польшей и Калининградской областью России По площади и населению является самым крупным государством из стран БалтииСтолица— Вильнюс Официальный язык— литовский Денежная единица— евроВосстановление независимости страны провозглашено 11 марта 1990года 6 сентября 1991года Государственный совет СССР признал независимость Литвы Литва— член ООН (1991), ОБСЕ (1991), Совета Европы (1993), ВТО (2001), Европейского союза (2004), НАТО (2004) и ОЭСР (2018) Входит 

## New dataset

### Process

In [None]:
data_parquet = pd.read_parquet(PARQUET_DATA_DIR)

In [None]:
data_parquet['Sample'] = data_parquet['Sample'] + tok.eos_token

In [None]:
def replace_unicode(x):
    """Strip control and invisible characters from ``x``.

    Removed code points: C0 controls U+0000-U+001F (this includes tab and newline),
    DEL U+007F, C1 controls U+0080-U+009F, no-break space U+00A0, soft hyphen U+00AD,
    zero-width space/non-joiner/joiner U+200B-U+200D and the BOM U+FEFF.
    Ordinary spaces, digits, letters and punctuation are kept.
    """
    # The escapes must be exactly four hex digits: the earlier "\u00800-\u009F" parsed as the
    # range "0"-U+009F (deleting ASCII digits and letters) and "\u0AD" is an incomplete escape.
    return re.sub(r'[\u0000-\u001F\u007F\u0080-\u009F\u00A0\u00AD\u200B\u200C\u200D\uFEFF]', '', x)

In [None]:
# data = re.sub(
#         r'[\u0000-\u001F\u007F\u0080-\u009F\u00A0\u00AD\u200B\u200C\u200D\uFEFF]',
#         '',
#         data_parquet['Sample'].iloc[:3].sum()
#     )

In [11]:
data_parquet['Sample'].iloc[:3].apply(replace_unicode)

0    Литва́ ( ), официальное название— Лито́вская Р...
1     В 1422году в состав великого княжества оконча...
2    В 1919году в Литве введена должность президент...
Name: Sample, dtype: object

In [12]:
data_parquet['Sample'].iloc[:3]

0    Литва́ ( ), официальное название — Лито́вская ...
1     В 1422 году в состав великого княжества оконч...
2    \n\nВ 1919 году в Литве введена должность през...
Name: Sample, dtype: object

In [14]:
data_parquet['Sample'] = data_parquet['Sample'].progress_apply(replace_unicode)

  0%|          | 0/224662 [00:00<?, ?it/s]

In [19]:
PARQUET_DATA_DIR

'G:\\My_files\\Programming\\My_projects\\LLM\\GPT-like_trained\\Data\\Processed\\Parquet\\data_6144symb_max.parquet'

In [20]:
# Write the filtered frame next to the source parquet file instead of to a hard-coded,
# machine-specific absolute path (PARQUET_DATA_DIR points at a file in that same directory).
data_parquet.to_parquet(os.path.join(os.path.dirname(PARQUET_DATA_DIR), "data_filtered.parquet"))

### Load

In [7]:
# Read the filtered frame from the directory of the source parquet file rather than from a
# hard-coded, machine-specific absolute path.
data_parquet = pd.read_parquet(os.path.join(os.path.dirname(PARQUET_DATA_DIR), "data_filtered.parquet"))

In [40]:
tokenized_data = np.array([], dtype=np.int32)

In [41]:
tokenized_data

array([], dtype=int32)

In [42]:
# np.append copies the whole accumulated array on every call, which makes a per-row append
# quadratic in the total token count. Collect the per-sample id arrays and concatenate once.
token_chunks = [tokenized_data]
for text in tqdm(data_parquet['Sample'], total=data_parquet.shape[0]):
    token_chunks.append(np.asarray(tok(text)['input_ids']))
tokenized_data = np.concatenate(token_chunks)

  0%|          | 0/224662 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [44]:
data_parquet.shape[0]//2

112331

In [48]:
tokenized_data = []
tokenized_data

[]

In [49]:
for i, curr_chunk in tqdm(data_parquet.iloc[:data_parquet.shape[0]//3].iterrows(), total=data_parquet.shape[0]//3):
    tokenized_data.extend(tok(curr_chunk['Sample'])['input_ids'])

  0%|          | 0/74887 [00:00<?, ?it/s]

In [50]:
len(tokenized_data)

462551521

In [51]:
# Save shard 1/3 into NPY_DATA_DIR, the directory the "Load numpy" section reads it back from
# (a bare file name would land in whatever the current working directory happens to be).
with open(os.path.join(NPY_DATA_DIR, '1-3.npy'), 'wb') as f:
    np.save(f, tokenized_data)

In [13]:
data_parquet.shape[0]

224662

In [12]:
data_parquet.shape[0]//3, 2*data_parquet.shape[0]//3, 3*data_parquet.shape[0]//3

(74887, 149774, 224662)

In [15]:
from_ = data_parquet.shape[0]//3
to_ = 2*data_parquet.shape[0]//3
from_, to_

(74887, 149774)

In [16]:
to_ - from_

74887

In [18]:
tokenized_data = []
for i, curr_chunk in tqdm(data_parquet.iloc[from_:to_].iterrows(), total=to_ - from_):
    tokenized_data.extend(tok(curr_chunk['Sample'])['input_ids'])

  0%|          | 0/74887 [00:00<?, ?it/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (6456 > 1024). Running this sequence through the model will result in indexing errors


In [19]:
# Save shard 2/3 into NPY_DATA_DIR, the directory the "Load numpy" section reads shards from
# (a bare file name would land in whatever the current working directory happens to be).
with open(os.path.join(NPY_DATA_DIR, '2-3.npy'), 'wb') as f:
    np.save(f, tokenized_data)

In [10]:
from_ = 2*data_parquet.shape[0]//3
to_ = 3*data_parquet.shape[0]//3
from_, to_

(149774, 224662)

In [12]:
tokenized_data = []
tokenized_data

[]

In [13]:
tokenized_data = []
for i, curr_chunk in tqdm(data_parquet.iloc[from_:to_].iterrows(), total=to_ - from_):
    tokenized_data.extend(tok(curr_chunk['Sample'])['input_ids'])

  0%|          | 0/74888 [00:00<?, ?it/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (6401 > 1024). Running this sequence through the model will result in indexing errors


In [14]:
# Save shard 3/3 into NPY_DATA_DIR, the directory the "Load numpy" section reads shards from
# (a bare file name would land in whatever the current working directory happens to be).
with open(os.path.join(NPY_DATA_DIR, '3-3.npy'), 'wb') as f:
    np.save(f, tokenized_data)

### Load numpy

In [8]:
with open(os.path.join(NPY_DATA_DIR, '1-3.npy'), 'rb') as f:
    n1 = np.load(f)
#with open('2-3.npy', 'rb') as f:
#    n2 = np.load(f)
#with open('3-3.npy', 'rb') as f:
#    n3 = np.load(f)

In [8]:
#pd.concat([pd.DataFrame(n1), pd.DataFrame(n2)], axis=0)

In [9]:
class CustomDatasetV4(Dataset):
    """Next-token dataset over one flat token stream.

    The stream is cut into non-overlapping windows of ``max_length`` ids. Each item is
    ``(input, target)`` where ``target`` is the same window shifted right by one token.
    Only windows for which BOTH tensors have exactly ``max_length`` elements are kept, so
    every item has the same shape and can be batched.

    Args:
        mas: 1-D sequence of token ids (list or numpy array).
        max_length: window length; must be positive.
    """

    def __init__(self, mas: "list | np.ndarray", max_length: int):
        self.input_ids = []
        self.target_ids = []
        self.max_length = max_length

        # A window starting at `start` needs ids start .. start + max_length (inclusive) for
        # its shifted target, i.e. start + max_length < len(mas). Bounding the range this way
        # drops the trailing window whose target would be short; the previous `or` check kept
        # it whenever len(mas) was an exact multiple of max_length.
        window_starts = range(0, len(mas) - max_length, max_length)
        for start in tqdm(window_starts, total=len(window_starts)):
            stop = start + max_length
            self.input_ids.append(torch.tensor(mas[start:stop]).view(-1))
            self.target_ids.append(torch.tensor(mas[start + 1:stop + 1]).view(-1))

    def __len__(self):
        """Number of complete windows."""
        return len(self.input_ids)

    def __getitem__(self, index):
        """Return the ``(input_ids, target_ids)`` pair for ``index``."""
        return self.input_ids[index], self.target_ids[index]

In [10]:
n1.shape

(462551521,)

In [11]:
train_cd = CustomDatasetV4(mas=n1[:-50000], max_length=GPT_CONFIG['context_length'])#MY_GPT_CONFIG['context_length'])
#train_cd = CustomDatasetV3(dataframe=data_parquet.iloc[:100], tokenizer=tok, max_length=GPT_CONFIG['context_length'])#MY_GPT_CONFIG['context_length'])

  0%|          | 0/451661 [00:00<?, ?it/s]

In [12]:
val_cd = CustomDatasetV4(mas=n1[-50000:], max_length=GPT_CONFIG['context_length'])#MY_GPT_CONFIG['context_length'])
#val_cd = CustomDatasetV3(dataframe=data_parquet.iloc[-100:], tokenizer=tok, max_length=GPT_CONFIG['context_length'])#

  0%|          | 0/48 [00:00<?, ?it/s]

TODO: check which batch_size fits in memory — maybe 8 or 12 (or 16).

In [13]:
train_data = DataLoader(dataset=train_cd, batch_size=16, shuffle=True, num_workers=0, drop_last=True) # num_workers=2 don't work? batch_size=12
val_data = DataLoader(dataset=val_cd, batch_size=16, shuffle=True, num_workers=0, drop_last=True) # num_workers=2 don't work? batch_size=12

In [14]:
del n1

In [15]:
[None for _ in tqdm(val_data)];

  0%|          | 0/3 [00:00<?, ?it/s]

In [None]:
tok.decode(next(iter(train_data))[0][0])

In [None]:
#del n1

# LLM Code

In [16]:
class MultiHeadAttentionDP_QKV(nn.Module):
    """Causal multi-head self-attention with a single fused Q/K/V projection.

    Attention itself is delegated to ``torch.nn.functional.scaled_dot_product_attention``
    with ``is_causal=True``; attention dropout is only active in training mode.

    Args:
        d_in: size of the input features.
        d_out: size of the output features; must be divisible by ``num_heads``.
        context_length: side length of the registered ``mask`` buffer.
        dropout: attention dropout probability (a float, not a module).
        num_heads: number of attention heads.
        qkv_bias: whether the fused Q/K/V projection has a bias.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0)

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        self.W_qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = dropout
        # Upper-triangular mask kept as a buffer (it is part of state_dict); forward() does
        # not read it because causality is handled by is_causal=True.
        upper_triangle = torch.ones(context_length, context_length).triu(diagonal=1)
        self.register_buffer('mask', upper_triangle)

    def _to_heads(self, t, batch, seq_len):
        """(batch, seq_len, d_out) -> (batch, num_heads, seq_len, head_dim)."""
        return t.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        batch, seq_len, _ = x.size()

        # One matmul produces Q, K and V side by side along the last dimension.
        q, k, v = self.W_qkv(x).split(self.d_out, dim=2)
        q = self._to_heads(q, batch, seq_len)
        k = self._to_heads(k, batch, seq_len)
        v = self._to_heads(v, batch, seq_len)

        attn_dropout = self.dropout if self.training else 0
        attended = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=None, dropout_p=attn_dropout, is_causal=True
        )

        # Merge heads back: (batch, heads, seq, head_dim) -> (batch, seq, d_out).
        merged = attended.transpose(1, 2).contiguous().view(batch, seq_len, self.d_out)
        return self.out_proj(merged)

In [17]:
#mha_dp_qkv = MultiHeadAttentionDP_QKV(d_in=embed_dim, d_out=embed_dim, context_length=context_len, dropout=dropout, num_heads=num_heads, qkv_bias=qkv_bias)

## Additional classes

In [18]:
class FeedForward(nn.Module):
    """Position-wise MLP: expand to ``4 * emb_dim``, apply Mish, project back to ``emb_dim``."""

    def __init__(self, emb_dim):
        super().__init__()
        hidden_dim = 4 * emb_dim
        expand = nn.Linear(emb_dim, hidden_dim)
        activation = nn.Mish() #GELU(),
        project = nn.Linear(hidden_dim, emb_dim)
        self.layers = nn.Sequential(expand, activation, project)

    def forward(self, x):
        out = self.layers(x)
        return out

In [19]:
class ReversibleTransformerBlock(nn.Module):
    """Two-stream transformer block using the reversible-style coupling.

        y1 = x1 + f(x2)      f = dropout(attention(norm1(.)))
        y2 = x2 + g(y1)      g = dropout(feed_forward(norm2(.)))

    Each stream is ``cfg['emb_dim'] // 2`` wide. Both branches are wrapped in
    ``torch.utils.checkpoint.checkpoint`` so their activations are recomputed during the
    backward pass instead of being stored.
    """

    def __init__(self, cfg):
        super().__init__()
        stream_dim = cfg['emb_dim'] // 2
        self.attn = MultiHeadAttentionDP_QKV(
            d_in=stream_dim,
            d_out=stream_dim,
            context_length=cfg['context_length'],
            dropout=cfg['drop_rate'],
            num_heads=cfg['n_heads'],
            qkv_bias=cfg['qkv_bias'],
        )
        self.ff = FeedForward(stream_dim)
        self.norm1 = nn.RMSNorm(stream_dim)
        self.norm2 = nn.RMSNorm(stream_dim)
        self.drop_resid = nn.Dropout(cfg['drop_rate'])

    def _attention_branch(self, u):
        """f(u): pre-norm attention followed by residual dropout."""
        return self.drop_resid(self.attn(self.norm1(u)))

    def _feed_forward_branch(self, v):
        """g(v): pre-norm feed-forward followed by residual dropout."""
        return self.drop_resid(self.ff(self.norm2(v)))

    def forward(self, x1, x2):
        # use_reentrant=False is the non-reentrant checkpoint implementation;
        # preserve_rng_state=True replays the same dropout mask during recomputation.
        y1 = x1 + checkpoint(self._attention_branch, x2, use_reentrant=False, preserve_rng_state=True)
        y2 = x2 + checkpoint(self._feed_forward_branch, y1, use_reentrant=False, preserve_rng_state=True)
        return y1, y2

In [20]:
class GPTModelRev(nn.Module):
    """GPT-style decoder built from two-stream ``ReversibleTransformerBlock`` layers.

    Token and learned positional embeddings are summed, split in half along the feature
    dimension into the two streams, passed through ``cfg['n_layers']`` blocks, concatenated,
    RMS-normalised and projected to vocabulary logits. The output projection shares its
    weight tensor with the token embedding.
    """

    def __init__(self, cfg):
        super().__init__()
        vocab_size = cfg['vocab_size']
        emb_dim = cfg['emb_dim']

        self.content_size = cfg['context_length']
        self.tok_emb = nn.Embedding(vocab_size, emb_dim)
        self.pos_emb = nn.Embedding(self.content_size, emb_dim)
        self.drop_emb = nn.Dropout(cfg['drop_rate'])

        blocks = [ReversibleTransformerBlock(cfg) for _ in range(cfg['n_layers'])]
        self.trf_blocks = nn.Sequential(*blocks)
        self.final_norm = nn.RMSNorm(emb_dim)
        self.out_head = nn.Linear(emb_dim, vocab_size, bias=False)
        # Weight tying, reference here: https://paperswithcode.com/method/weight-tying
        self.tok_emb.weight = self.out_head.weight

        # Neither initialiser below is applied; both calls are intentionally left disabled:
        #self.apply(self._init_weights)
        #self.apply(lambda current_model: self._init_weights2(current_model, initializer_range=cfg['initializer_range']))

    def _init_weights(self, module):
        """N(0, 0.02) init for Linear/Embedding weights, zeros for Linear biases."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            torch.nn.init.zeros_(module.bias)

    def _init_weights2(self, module: nn.Module, initializer_range: float = 0.02):
        """
        GPT-2 style weight initialisation:
        - Linear/Embedding weights drawn from N(0, initializer_range)
        - biases = 0
        - LayerNorm.weight = 1, LayerNorm.bias = 0
        """
        if isinstance(module, nn.LayerNorm):
            module.weight.data.fill_(1.0)
            module.bias.data.zero_()
            return
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return
        module.weight.data.normal_(mean=0.0, std=initializer_range)
        if getattr(module, "bias", None) is not None:
            module.bias.data.zero_()

    def forward(self, in_idx):
        _, seq_len = in_idx.size()
        positions = torch.arange(seq_len, device=in_idx.device)
        hidden = self.drop_emb(self.tok_emb(in_idx) + self.pos_emb(positions))

        # Split the feature dimension into the two streams, each (batch, seq_len, emb_dim // 2).
        stream_a, stream_b = torch.chunk(hidden, 2, dim=-1)
        for block in self.trf_blocks:
            stream_a, stream_b = block(stream_a, stream_b)

        # Merge the streams back to (batch, seq_len, emb_dim) before the output head.
        merged = torch.cat([stream_a, stream_b], dim=-1)
        return self.out_head(self.final_norm(merged))

    def generate(self, idx, max_new_tokens: int = 250):
        """Greedy decoding: append ``max_new_tokens`` arg-max tokens to ``idx``.

        Switches the module to eval mode and leaves it there. Only the last
        ``self.content_size`` tokens are fed to the model at each step.
        """
        self.eval()
        for _ in range(max_new_tokens):
            window = idx[:, -self.content_size:]
            with torch.no_grad():
                step_logits = self.forward(window)
            last_probs = torch.softmax(step_logits[:, -1, :], dim=-1)
            next_token = torch.argmax(last_probs, dim=-1, keepdim=True)
            idx = torch.cat((idx, next_token), dim=1)
        return idx

## Check

In [24]:
m = GPTModelRev(GPT_CONFIG)

In [61]:
m.to(device)

GPTModelRev(
  (tok_emb): Embedding(50257, 256)
  (pos_emb): Embedding(1024, 256)
  (drop_emb): Dropout(p=0.05, inplace=False)
  (trf_blocks): Sequential(
    (0): ReversibleTransformerBlock(
      (attn): MultiHeadAttentionDP_QKV(
        (W_qkv): Linear(in_features=128, out_features=384, bias=False)
        (out_proj): Linear(in_features=128, out_features=128, bias=True)
      )
      (ff): FeedForward(
        (layers): Sequential(
          (0): Linear(in_features=128, out_features=512, bias=True)
          (1): Mish()
          (2): Linear(in_features=512, out_features=128, bias=True)
        )
      )
      (norm1): RMSNorm((128,), eps=None, elementwise_affine=True)
      (norm2): RMSNorm((128,), eps=None, elementwise_affine=True)
      (drop_resid): Dropout(p=0.05, inplace=False)
    )
    (1): ReversibleTransformerBlock(
      (attn): MultiHeadAttentionDP_QKV(
        (W_qkv): Linear(in_features=128, out_features=384, bias=False)
        (out_proj): Linear(in_features=128, out_

In [62]:
for x, y in train_data:
    print(x.size())
    r = m(x.to(device))
    break

torch.Size([16, 1024])


In [63]:
r.size()

torch.Size([16, 1024, 50257])

# Training

In [21]:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

In [22]:
class Trainer():
    def __init__(self, optimizer, params, device):
        self.optimizer = optimizer
        self.params = params
        self.device = device
    
    def train_model(self, 
                    model, 
                    tokenizer, 
                    train_dataloader, 
                    val_dataloader, 
                    writer: object = None, 
                    grad_accum: int = 1, 
                    max_norm: float = 1.0, 
                    scheduler: object = None, 
                    start_epoch: int = 0, 
                    path_to_save: str = ''):
        #train_epoch_loss = []
        val_loss = []
        cumulative_tokens_get = []
        tokens_get = 0
        scaler = torch.amp.GradScaler(device=self.device)
        t_d_len = len(train_dataloader)

        for epoch in range(start_epoch, self.params['N_EPOCHS']):
            total_loss = 0.0
            self.optimizer.zero_grad()
            for step, (x, y) in enumerate(train_dataloader):
                if not (model.training):
                    model.train()
                x, y = x.to(self.device), y.to(self.device)
                tokens_get += len(x.flatten())
                cumulative_tokens_get.append(tokens_get)
                with torch.autocast(device_type=device, dtype=torch.bfloat16):
                    logits = model(x)
                    loss = nn.functional.cross_entropy(logits.flatten(0, 1), y.flatten())

                scaler.scale(loss).backward() #loss.backward()
                curr_train_loss = loss.item()
                total_loss += curr_train_loss
                writer.add_scalar("Tokens get", tokens_get, step)
                writer.add_scalar("Current train loss", curr_train_loss, step)

                if (step % grad_accum == 0):
                    scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(
                        model.parameters(), max_norm
                    )
                    scaler.step(self.optimizer)
                    scaler.update()
                    self.optimizer.zero_grad()
                    if scheduler:
                        scheduler.step()

                if (self.params['verbose'] is True) and (step % self.params['sample_freq'] == 0):
                    model.eval()
                    sample = tokenizer.decode(model.generate(idx=torch.tensor(tokenizer('Я расскажу тебе про')['input_ids'], 
                                                                              device=device).unsqueeze(0), 
                                                                              max_new_tokens=15).squeeze(0).tolist())
                    print(f'Epoch {epoch}, step {step} of {t_d_len}: Train loss = {loss}, sample: {sample}')

                if (self.params['verbose'] is True) and (step % self.params['verbose_freq'] == 0) and (writer is not None):
                    writer.add_scalar("Loss/train in step", loss, epoch)
                    writer.add_text("Sample", str(sample), epoch)
                    if (self.params['gradients'] is True):
                        grads = []
                        for name, param in model.named_parameters():
                            if ('weight' in name):
                                if (param.grad is not None):
                                    grads.append(param.grad.abs().flatten().mean().cpu().detach().numpy())
                        writer.add_scalar("train/gradients", np.array(grads).flatten().mean(), epoch)

                    with torch.no_grad():
                        for x, y in val_dataloader:
                            x, y = x.to(self.device), y.to(self.device)
                            logits = model(x)
                            loss = nn.functional.cross_entropy(logits.flatten(0, 1), y.flatten())
                            val_loss.append(loss)
                        if (writer is not None):
                            writer.add_scalar("Loss/train in check", torch.mean(torch.tensor(curr_train_loss, device='cpu')), epoch)
                            writer.add_scalar("Loss/val in check", torch.mean(torch.tensor(val_loss, device='cpu')), epoch)

                if (step % self.params['save_checkpoint_freq'] == 0):
                    # Сохранение чекпоинта (по желанию)
                    #is_best = val_loss < best_loss
                    #best_loss = min(val_loss, best_loss)
                    checkpoint = {
                        'epoch':           epoch,
                        'model_state':     model.state_dict(),
                        'optimizer_state': self.optimizer.state_dict(),
                        'scheduler_state': scheduler.state_dict(),
                    }
                    torch.save(checkpoint, os.path.join(path_to_save, 'checkpoint_latest.pth'))
                    #if is_best:
                    #    torch.save(checkpoint, 'checkpoint_best.pth')
        writer.close()

In [22]:
params = {'N_EPOCHS': 1, 
          'verbose': True, 
          'verbose_freq': 140,
          'save_checkpoint_freq': 140,
          'sample_freq': 20,
          'gradients': True}

In [23]:
model = GPTModelRev(GPT_CONFIG)

In [24]:
model = model.to(device)

In [25]:
amount_of_parameters = sum([p.numel() for p in model.parameters()])
amount_of_parameters

19452416

In [32]:
opt = torch.optim.AdamW(params=model.parameters(), lr=3e-4)
grad_accum = 1
warmup_ratio = 0.05  #0.1 # warm-up in first 10% of steps
total_steps = (len(train_data) // grad_accum) * params["N_EPOCHS"]
#total_steps = params["N_EPOCHS"] * len(train_data)
#scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=total_steps)
trainer = Trainer(optimizer=opt, params=params, device=device)

# Scheduler: linear warm-up + cosine decay
warmup_steps = int(warmup_ratio * total_steps)
def lr_lambda(current_step):
    """LR multiplier: linear warm-up for ``warmup_steps`` steps, then cosine decay to 0.

    Reads the module-level ``warmup_steps`` and ``total_steps``.
    """
    if current_step < warmup_steps:
        return float(current_step) / float(max(1, warmup_steps))
    progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps))
    # Clamp: without it the cosine turns around past total_steps and the LR climbs back up
    # (e.g. when a scheduler restored from a checkpoint keeps stepping beyond total_steps).
    progress = min(1.0, progress)
    return 0.5 * (1.0 + math.cos(math.pi * progress))
scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)

In [None]:
trainer.train_model(model=model, 
                    tokenizer=tok, 
                    train_dataloader=train_data, 
                    val_dataloader=val_data, 
                    writer=writer, 
                    grad_accum=grad_accum, 
                    max_norm=1.0, 
                    scheduler=scheduler, 
                    start_epoch=0,
                    path_to_save=CHECKPOINTS_PATH)
writer.flush()

Epoch 0, step 280 of 28228: Train loss = 5.100638389587402

In [None]:
# Greedy-decode 15 tokens after the prompt. Generation is a method of the model
# (GPTModelRev.generate); no free function named `generate` is defined in this notebook.
prompt_ids = torch.tensor(tok('Я хочу')['input_ids'], device=device).unsqueeze(0)
tok.decode(model.generate(idx=prompt_ids, max_new_tokens=15).squeeze(0).tolist())

How to use TensorBoard:  
run `tensorboard --logdir=runs` (the default `SummaryWriter` directory), or `tensorboard --logdir=GPT_training` if you passed your own log directory name, then open  
http://localhost:6006  

### Saving weights

In [None]:
#torch.save(model.state_dict(), "model.pth") # without state of optimizer
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': opt.state_dict(),
    }, "model_and_optimizer.pth") # with state of optimizer

### Loading weights

In [None]:
# Do NOT call this variable `checkpoint`: that would rebind the `checkpoint` function imported
# from torch.utils.checkpoint, which ReversibleTransformerBlock.forward calls.
saved_state = torch.load("model_and_optimizer.pth")
# The model class defined in this notebook is GPTModelRev (there is no `GPTModel`).
model = GPTModelRev(GPT_CONFIG)
model.load_state_dict(saved_state['model_state_dict'])
opt = torch.optim.AdamW(model.parameters(), lr=0.01)
# load_state_dict also restores the saved param-group hyperparameters (including lr).
opt.load_state_dict(saved_state['optimizer_state_dict'])
model.train()

In [None]:
torch.randint(0, 100, size=(10, 1024)).size()

In [None]:
model(torch.randint(0, 100, size=(50, 1024))).size()

### Load model all

In [23]:
params = {'N_EPOCHS': 1, 
          'verbose': True, 
          'verbose_freq': 140,
          'save_checkpoint_freq': 140,
          'sample_freq': 20,
          'gradients': True}

In [24]:
def lr_lambda(current_step):
    """LR multiplier: linear warm-up for ``warmup_steps`` steps, then cosine decay to 0.

    Reads the module-level ``warmup_steps`` and ``total_steps``.
    """
    if current_step < warmup_steps:
        return float(current_step) / float(max(1, warmup_steps))
    progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps))
    # Clamp: without it the cosine turns around past total_steps and the LR climbs back up
    # (e.g. when a scheduler restored from a checkpoint keeps stepping beyond total_steps).
    progress = min(1.0, progress)
    return 0.5 * (1.0 + math.cos(math.pi * progress))

In [25]:
# Rebuild model, optimizer and scheduler, then restore them from the latest training checkpoint.
model = GPTModelRev(GPT_CONFIG)
model = model.to(device)
loaded_model_checkpoint = torch.load(os.path.join(CHECKPOINTS_PATH, "checkpoint_latest.pth"), map_location=device)
# Checkpoint layout (written by Trainer.train_model):
#   'epoch', 'model_state', 'optimizer_state', 'scheduler_state'
model.load_state_dict(loaded_model_checkpoint['model_state'])

grad_accum = 1
warmup_ratio = 0.05  # fraction of total steps used for linear warm-up (0.1 was tried earlier)
total_steps = (len(train_data) // grad_accum) * params["N_EPOCHS"]
# Scheduler: linear warm-up + cosine decay
warmup_steps = int(warmup_ratio * total_steps)

opt = torch.optim.AdamW(params=model.parameters(), lr=3e-4)
scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)
scheduler.load_state_dict(loaded_model_checkpoint['scheduler_state'])
# Restore the optimizer LAST: constructing LambdaLR performs an initial step that sets every
# param-group lr to base_lr * lr_lambda(0) (= 0 during warm-up), and scheduler.load_state_dict
# does not write the lr back into the optimizer. Loading the optimizer state afterwards
# restores the lr that was active when the checkpoint was written.
opt.load_state_dict(loaded_model_checkpoint['optimizer_state'])

trainer = Trainer(optimizer=opt, params=params, device=device)

In [26]:
amount_of_parameters = sum([p.numel() for p in model.parameters()])
amount_of_parameters

19452416

In [27]:
len(train_data)

28228

### New training

In [None]:
trainer.train_model(model=model, 
                    tokenizer=tok, 
                    train_dataloader=train_data, 
                    val_dataloader=val_data, 
                    writer=writer, 
                    grad_accum=grad_accum, 
                    max_norm=1.0, 
                    scheduler=scheduler, 
                    start_epoch=loaded_model_checkpoint['epoch'],
                    path_to_save=CHECKPOINTS_PATH)
writer.flush()

Epoch 0, step 0 of 28228: Train loss = 2.8201711177825928, sample: Я расскажу тебе про при при пораз
