# Imports

In [1]:
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.utils.checkpoint import checkpoint

import transformers
import pandas as pd
import os
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from dotenv import load_dotenv

load_dotenv()

True

Environment variables

In [2]:
PATH = os.getenv("PATH")
DATAPATH = os.getenv("DATAPATH")
PREPARED_DATA_DIR = os.getenv("PREPARED_DATA_DIR")
CACHE_DIR = os.getenv("CACHE_DIR")
#TOK_NAME = "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B"
TOK_NAME = os.getenv("TOK_NAME")
PARQUET_DATA_DIR = os.getenv("PARQUET_DATA_DIR")

## Config

In [3]:
GPT_CONFIG = {
    'vocab_size': 50257, # in 151670 (if you use tokenizer.vocab_size then you get partial vocab_size without added tokens)
    'context_length': 1024,
    'emb_dim': 768, #768
    'n_heads': 2,#12,
    'n_layers': 2,#12,
    'drop_rate': 0.05, # 0l1
    'qkv_bias': False
    }

In [4]:
device = 'cuda' if (torch.cuda.is_available()) else 'cpu'
device

'cuda'

In [5]:
torch.manual_seed(424242)
batch_size = 8
input_vector = torch.randint(0, 50000, size=(batch_size, 1024), device=device)
input_vector.shape

torch.Size([8, 1024])

In [6]:
GPTs = []

In [7]:
def forward(model, x_input):
    if x_input.grad is not None:
        x_input.grad.zero_()
    _ = model(x_input)

In [8]:
def forward_and_backward(model, x_input, loss):
    if x_input.grad is not None:
        x_input.grad.zero_()

    output = model(x_input)
    loss = loss(output.flatten(0, 1), x_input.flatten()) # Exasmple ^_^: loss = nn.functional.cross_entropy(logits.flatten(0, 1), y.flatten())
    loss.backward()

## Неоптимизированная архитектура

In [9]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0)

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        b, num_tokens, d_in = x.size()
        keys = self.W_key(x) # b, num_tokens, self.d_out
        queries = self.W_query(x) # b, num_tokens, self.d_out
        values = self.W_value(x) # b, num_tokens, self.d_out

        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)

        keys = keys.transpose(1, 2) # b, self.num_heads, num_tokens, self.head_dim
        queries = queries.transpose(1, 2) # b, self.num_heads, num_tokens, self.head_dim
        values = values.transpose(1, 2) # b, self.num_heads, num_tokens, self.head_dim

        att_scores = queries @ keys.transpose(2, 3) # shapes = (num_tokens, self.head_dim) @ (self.head_dim, num_tokens) -> (num_tokens, num_tokens)

        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        att_scores.masked_fill_(mask_bool, -torch.inf)

        att_weights = torch.softmax(att_scores / keys.shape[-1]**0.5, dim=-1)
        att_weights = self.dropout(att_weights)

        context_vec = (att_weights @ values).transpose(1, 2) # (num_tokens, num_tokens) @ (num_tokens, self.head_dim) -> (num_tokens, self.head_dim) -> transpose(1,2) of (b, self.num_heads, num_tokens, self.head_dim) ->
        # -> (b, num_tokens, self.num_heads, self.head_dim) as view in previous code after inference of Linear layers
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)
        return context_vec

In [10]:
class LayerNorm(nn.Module):
    def __init__(self, emb_dim):
        super().__init__()
        self.eps = 1e-5
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))
    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        norm_x = (x - mean) / torch.sqrt(var + self.eps)
        return self.scale * norm_x + self.shift

In [11]:
class GELU(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(
            torch.sqrt(torch.tensor(2.0 / torch.pi)) * 
            (x + 0.044715 * torch.pow(x, 3))
        ))

In [12]:
class FeedForward(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(cfg['emb_dim'], 4 * cfg['emb_dim']),
            GELU(),
            nn.Linear(4 * cfg['emb_dim'], cfg['emb_dim'])
        )
    def forward(self, x):
        return self.layers(x)

In [13]:
class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.attn = MultiHeadAttention(d_in=cfg['emb_dim'], 
                                       d_out=cfg['emb_dim'], 
                                       context_length=cfg['context_length'], 
                                       dropout=cfg['drop_rate'], 
                                       num_heads=cfg['n_heads'], 
                                       qkv_bias=cfg['qkv_bias'])
        self.ff = FeedForward(cfg)
        self.norm1 = LayerNorm(cfg['emb_dim'])
        self.norm2 = LayerNorm(cfg['emb_dim'])
        self.drop_resid = nn.Dropout(cfg['drop_rate'])
    
    def forward(self, x):
        #x = x + self.drop_resid(self.attn(self.norm1(x)))
        #x = x + self.drop_resid(self.ff(self.norm2(x)))
        shortcut = x
        x = self.norm1(x)
        x = self.attn(x)
        x = self.drop_resid(x)
        x = x + shortcut

        shortcut = x
        x = self.norm2(x)
        x = self.ff(x)
        x = self.drop_resid(x)
        x = x + shortcut
        return x

In [14]:
class GPTModel(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.tok_emb = nn.Embedding(cfg['vocab_size'], cfg['emb_dim'])
        self.pos_emb = nn.Embedding(cfg['context_length'], cfg['emb_dim'])
        self.drop_emb = nn.Dropout(cfg['drop_rate'])

        self.trf_blocks = nn.Sequential(*[TransformerBlock(cfg) for _ in range(cfg['n_layers'])])
        self.final_norm = nn.LayerNorm(cfg['emb_dim'])
        self.out_head = nn.Linear(cfg['emb_dim'], cfg['vocab_size'], bias=False)

    def forward(self, in_idx):
        batch_size, seq_len = in_idx.size()
        tok_embeds = self.tok_emb(in_idx)
        pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
        x = tok_embeds + pos_embeds
        x = self.drop_emb(x)
        x = self.trf_blocks(x)
        x = self.final_norm(x)
        logits = self.out_head(x)
        return logits

### Check

In [16]:
vanilla_GPT = GPTModel(GPT_CONFIG)
vanilla_GPT.to(device)

GPTModel(
  (tok_emb): Embedding(50257, 768)
  (pos_emb): Embedding(1024, 768)
  (drop_emb): Dropout(p=0.05, inplace=False)
  (trf_blocks): Sequential(
    (0): TransformerBlock(
      (attn): MultiHeadAttention(
        (W_key): Linear(in_features=768, out_features=768, bias=False)
        (W_query): Linear(in_features=768, out_features=768, bias=False)
        (W_value): Linear(in_features=768, out_features=768, bias=False)
        (out_proj): Linear(in_features=768, out_features=768, bias=True)
        (dropout): Dropout(p=0.05, inplace=False)
      )
      (ff): FeedForward(
        (layers): Sequential(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU()
          (2): Linear(in_features=3072, out_features=768, bias=True)
        )
      )
      (norm1): LayerNorm()
      (norm2): LayerNorm()
      (drop_resid): Dropout(p=0.05, inplace=False)
    )
    (1): TransformerBlock(
      (attn): MultiHeadAttention(
        (W_key): Linear(in_feature

In [17]:
%%timeit -n 25 -r 25
forward(vanilla_GPT, input_vector)

108 ms ± 4.49 ms per loop (mean ± std. dev. of 25 runs, 25 loops each)


## Replace with Pytohch built-in functions, but with current handwritten self-attention

In [9]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0)

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        b, num_tokens, d_in = x.size()
        keys = self.W_key(x) # b, num_tokens, self.d_out
        queries = self.W_query(x) # b, num_tokens, self.d_out
        values = self.W_value(x) # b, num_tokens, self.d_out

        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)

        keys = keys.transpose(1, 2) # b, self.num_heads, num_tokens, self.head_dim
        queries = queries.transpose(1, 2) # b, self.num_heads, num_tokens, self.head_dim
        values = values.transpose(1, 2) # b, self.num_heads, num_tokens, self.head_dim

        att_scores = queries @ keys.transpose(2, 3) # shapes = (num_tokens, self.head_dim) @ (self.head_dim, num_tokens) -> (num_tokens, num_tokens)

        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        att_scores.masked_fill_(mask_bool, -torch.inf)

        att_weights = torch.softmax(att_scores / keys.shape[-1]**0.5, dim=-1)
        att_weights = self.dropout(att_weights)

        context_vec = (att_weights @ values).transpose(1, 2) # (num_tokens, num_tokens) @ (num_tokens, self.head_dim) -> (num_tokens, self.head_dim) -> transpose(1,2) of (b, self.num_heads, num_tokens, self.head_dim) ->
        # -> (b, num_tokens, self.num_heads, self.head_dim) as view in previous code after inference of Linear layers
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)
        return context_vec

In [10]:
class FeedForward(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(cfg['emb_dim'], 4 * cfg['emb_dim']),
            nn.GELU(), #GELU(),
            nn.Linear(4 * cfg['emb_dim'], cfg['emb_dim'])
        )
    def forward(self, x):
        return self.layers(x)

In [11]:
class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.attn = MultiHeadAttention(d_in=cfg['emb_dim'], 
                                       d_out=cfg['emb_dim'], 
                                       context_length=cfg['context_length'], 
                                       dropout=cfg['drop_rate'], 
                                       num_heads=cfg['n_heads'], 
                                       qkv_bias=cfg['qkv_bias'])
        self.ff = FeedForward(cfg)
        self.norm1 = nn.LayerNorm(cfg['emb_dim']) #LayerNorm(cfg['emb_dim'])
        self.norm2 = nn.LayerNorm(cfg['emb_dim']) #LayerNorm(cfg['emb_dim'])
        self.drop_resid = nn.Dropout(cfg['drop_rate'])
    
    def forward(self, x):
        #x = x + self.drop_resid(self.attn(self.norm1(x)))
        #x = x + self.drop_resid(self.ff(self.norm2(x)))
        shortcut = x
        x = self.norm1(x)
        x = self.attn(x)
        x = self.drop_resid(x)
        x = x + shortcut

        shortcut = x
        x = self.norm2(x)
        x = self.ff(x)
        x = self.drop_resid(x)
        x = x + shortcut
        return x

In [12]:
class GPTModelPT(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.tok_emb = nn.Embedding(cfg['vocab_size'], cfg['emb_dim'])
        self.pos_emb = nn.Embedding(cfg['context_length'], cfg['emb_dim'])
        self.drop_emb = nn.Dropout(cfg['drop_rate'])

        self.trf_blocks = nn.Sequential(*[TransformerBlock(cfg) for _ in range(cfg['n_layers'])])
        self.final_norm = nn.LayerNorm(cfg['emb_dim'])
        self.out_head = nn.Linear(cfg['emb_dim'], cfg['vocab_size'], bias=False)

    def forward(self, in_idx):
        batch_size, seq_len = in_idx.size()
        tok_embeds = self.tok_emb(in_idx)
        pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
        x = tok_embeds + pos_embeds
        x = self.drop_emb(x)
        x = self.trf_blocks(x)
        x = self.final_norm(x)
        logits = self.out_head(x)
        return logits

### Check

In [22]:
PyTorchGPT = GPTModelPT(GPT_CONFIG)
PyTorchGPT.to(device)

GPTModelPT(
  (tok_emb): Embedding(50257, 768)
  (pos_emb): Embedding(1024, 768)
  (drop_emb): Dropout(p=0.05, inplace=False)
  (trf_blocks): Sequential(
    (0): TransformerBlock(
      (attn): MultiHeadAttention(
        (W_key): Linear(in_features=768, out_features=768, bias=False)
        (W_query): Linear(in_features=768, out_features=768, bias=False)
        (W_value): Linear(in_features=768, out_features=768, bias=False)
        (out_proj): Linear(in_features=768, out_features=768, bias=True)
        (dropout): Dropout(p=0.05, inplace=False)
      )
      (ff): FeedForward(
        (layers): Sequential(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=3072, out_features=768, bias=True)
        )
      )
      (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (drop_resid): Dropout(p=0.05, inplace=Fals

In [23]:
%%timeit -n 25 -r 25
forward(PyTorchGPT, input_vector)

92.2 ms ± 13.2 ms per loop (mean ± std. dev. of 25 runs, 25 loops each)


## Add some little optimizations

In [13]:
class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.attn = MultiHeadAttention(d_in=cfg['emb_dim'], 
                                       d_out=cfg['emb_dim'], 
                                       context_length=cfg['context_length'], 
                                       dropout=cfg['drop_rate'], 
                                       num_heads=cfg['n_heads'], 
                                       qkv_bias=cfg['qkv_bias'])
        self.ff = FeedForward(cfg)
        self.norm1 = nn.LayerNorm(cfg['emb_dim']) #LayerNorm(cfg['emb_dim'])
        self.norm2 = nn.LayerNorm(cfg['emb_dim']) #LayerNorm(cfg['emb_dim'])
        self.drop_resid = nn.Dropout(cfg['drop_rate'])
    
    def forward(self, x):
        x = x + self.drop_resid(self.attn(self.norm1(x)))
        return x + self.drop_resid(self.ff(self.norm2(x)))

In [14]:
class GPTModelOPT(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.tok_emb = nn.Embedding(cfg['vocab_size'], cfg['emb_dim'])
        self.pos_emb = nn.Embedding(cfg['context_length'], cfg['emb_dim'])
        self.drop_emb = nn.Dropout(cfg['drop_rate'])

        self.trf_blocks = nn.Sequential(*[TransformerBlock(cfg) for _ in range(cfg['n_layers'])])
        self.final_norm = nn.LayerNorm(cfg['emb_dim'])
        self.out_head = nn.Linear(cfg['emb_dim'], cfg['vocab_size'], bias=False)

    def forward(self, in_idx):
        _, seq_len = in_idx.size()
        return self.out_head(self.final_norm(self.trf_blocks(self.drop_emb(self.tok_emb(in_idx) + self.pos_emb(torch.arange(seq_len, device=in_idx.device))))))

### Check

In [15]:
PyTorchGPTOpt = GPTModelOPT(GPT_CONFIG)
PyTorchGPTOpt.to(device)

GPTModelOPT(
  (tok_emb): Embedding(50257, 768)
  (pos_emb): Embedding(1024, 768)
  (drop_emb): Dropout(p=0.05, inplace=False)
  (trf_blocks): Sequential(
    (0): TransformerBlock(
      (attn): MultiHeadAttention(
        (W_key): Linear(in_features=768, out_features=768, bias=False)
        (W_query): Linear(in_features=768, out_features=768, bias=False)
        (W_value): Linear(in_features=768, out_features=768, bias=False)
        (out_proj): Linear(in_features=768, out_features=768, bias=True)
        (dropout): Dropout(p=0.05, inplace=False)
      )
      (ff): FeedForward(
        (layers): Sequential(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=3072, out_features=768, bias=True)
        )
      )
      (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (drop_resid): Dropout(p=0.05, inplace=Fal

In [16]:
%%timeit -n 25 -r 25
forward(PyTorchGPTOpt, input_vector)

92 ms ± 8.44 ms per loop (mean ± std. dev. of 25 runs, 25 loops each)


## Add checkpoints

In [14]:
class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.attn = MultiHeadAttention(d_in=cfg['emb_dim'], 
                                       d_out=cfg['emb_dim'], 
                                       context_length=cfg['context_length'], 
                                       dropout=cfg['drop_rate'], 
                                       num_heads=cfg['n_heads'], 
                                       qkv_bias=cfg['qkv_bias'])
        self.ff = FeedForward(cfg)
        self.norm1 = nn.LayerNorm(cfg['emb_dim']) #LayerNorm(cfg['emb_dim'])
        self.norm2 = nn.LayerNorm(cfg['emb_dim']) #LayerNorm(cfg['emb_dim'])
        self.drop_resid = nn.Dropout(cfg['drop_rate'])
    
    def forward(self, x):
        x = x + self.drop_resid(self.attn(self.norm1(x)))
        return x + self.drop_resid(self.ff(self.norm2(x)))

In [None]:
class GPTModelOPT(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.tok_emb = nn.Embedding(cfg['vocab_size'], cfg['emb_dim'])
        self.pos_emb = nn.Embedding(cfg['context_length'], cfg['emb_dim'])
        self.drop_emb = nn.Dropout(cfg['drop_rate'])

        self.trf_blocks = nn.Sequential(*[TransformerBlock(cfg) for _ in range(cfg['n_layers'])])
        self.final_norm = nn.LayerNorm(cfg['emb_dim'])
        self.out_head = nn.Linear(cfg['emb_dim'], cfg['vocab_size'], bias=False)

    def forward(self, in_idx):
        _, seq_len = in_idx.size()
        return self.out_head(self.final_norm(self.trf_blocks(self.drop_emb(self.tok_emb(in_idx) + self.pos_emb(torch.arange(seq_len, device=in_idx.device))))))

## ReversibleBlocks + checkpoints

In [23]:
class FeedForward(nn.Module):
    def __init__(self, emb_dim):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(emb_dim, 4 * emb_dim),
            nn.GELU(), #GELU(),
            nn.Linear(4 * emb_dim, emb_dim)
        )
    def forward(self, x):
        return self.layers(x)

In [24]:
class ReversibleTransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        emb_dim = cfg['emb_dim']//2
        self.attn = MultiHeadAttention(d_in=emb_dim, 
                                       d_out=emb_dim, 
                                       context_length=cfg['context_length'], 
                                       dropout=cfg['drop_rate'], 
                                       num_heads=cfg['n_heads'], 
                                       qkv_bias=cfg['qkv_bias'])
        self.ff = FeedForward(emb_dim)
        self.norm1 = nn.LayerNorm(emb_dim) #LayerNorm(cfg['emb_dim'])
        self.norm2 = nn.LayerNorm(emb_dim) #LayerNorm(cfg['emb_dim'])
        self.drop_resid = nn.Dropout(cfg['drop_rate'])

    def forward(self, x1, x2):
        # reversible update
        # y1 = x1 + f(x2)
        # y2 = x2 + g(y1)

        def f(u):
            u = self.norm1(u)
            attn_output = self.attn(u)
            attn_output = self.drop_resid(attn_output)
            return attn_output
        
        def g(v):
            return self.drop_resid(self.ff(self.norm2(v)))
        
        f_x2 = checkpoint(f, x2)
        y1 = x1 + f_x2
        g_y1 = checkpoint(g, y1)
        y2 = x2 + g_y1

        return y1, y2

In [25]:
class GPTModelRev(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.tok_emb = nn.Embedding(cfg['vocab_size'], cfg['emb_dim'])
        self.pos_emb = nn.Embedding(cfg['context_length'], cfg['emb_dim'])
        self.drop_emb = nn.Dropout(cfg['drop_rate'])

        self.trf_blocks = nn.Sequential(*[ReversibleTransformerBlock(cfg) for _ in range(cfg['n_layers'])])
        self.final_norm = nn.LayerNorm(cfg['emb_dim'])
        self.out_head = nn.Linear(cfg['emb_dim'], cfg['vocab_size'], bias=False)

    def forward(self, in_idx):
        batch_size, seq_len = in_idx.size()
        tok_embeds = self.tok_emb(in_idx)
        pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
        x = tok_embeds + pos_embeds
        x = self.drop_emb(x)

        # initialize reversible pairs: split features
        # split last dim
        x1, x2 = torch.chunk(x, 2, dim=-1)  # each (batch_size, seq_len, emb_dim//2)

        # Now we change x = self.trf_blocks(x) to: 
        for layer in self.trf_blocks:
            x1, x2 = layer(x1, x2)
        # merge
        x = torch.cat([x1, x2], dim=-1)  # (b, s, dim)
        
        # Now as usual
        x = self.final_norm(x)
        logits = self.out_head(x)
        return logits

### Check

In [26]:
PyTorchRev = GPTModelRev(GPT_CONFIG)
PyTorchRev.to(device)

GPTModelRev(
  (tok_emb): Embedding(50257, 192)
  (pos_emb): Embedding(1024, 192)
  (drop_emb): Dropout(p=0.05, inplace=False)
  (trf_blocks): Sequential(
    (0): ReversibleTransformerBlock(
      (attn): MultiHeadAttention(
        (W_key): Linear(in_features=96, out_features=96, bias=False)
        (W_query): Linear(in_features=96, out_features=96, bias=False)
        (W_value): Linear(in_features=96, out_features=96, bias=False)
        (out_proj): Linear(in_features=96, out_features=96, bias=True)
        (dropout): Dropout(p=0.05, inplace=False)
      )
      (ff): FeedForward(
        (layers): Sequential(
          (0): Linear(in_features=96, out_features=384, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=384, out_features=96, bias=True)
        )
      )
      (norm1): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
      (drop_resid): Dropout(p=0.05, inplace=False)


In [27]:
%%timeit -n 25 -r 25
forward(PyTorchRev, input_vector)

  return fn(*args, **kwargs)


21 ms ± 1.69 ms per loop (mean ± std. dev. of 25 runs, 25 loops each)


Для большой сети (т.е. эксперты с маленьким размером не смогут нормально оубчиться, плюс обучать мое дорого, имеет смысл, только есть есть несколько нод и обучать на каждой ноде своего эксперта, чтобы снизить накладные расходы)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ReversibleBlock(nn.Module):
    def __init__(self, f, g):
        super().__init__()
        self.f = f
        self.g = g

    def forward(self, x1, x2, mask=None):
        y1 = x1 + self.f(x2, mask)
        y2 = x2 + self.g(y1)
        return y1, y2

    def backward_pass(self, y1, y2):
        x2 = y2 - self.g(y1)
        x1 = y1 - self.f(x2)
        return x1, x2

class MoE(nn.Module):
    def __init__(self, d_model, d_ff, num_experts=4, top_k=1):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        # Experts: feed-forward layers
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, d_ff),
                nn.GELU(),
                nn.Linear(d_ff, d_model)
            ) for _ in range(num_experts)
        ])
        # Gating network: per-token
        self.gate = nn.Linear(d_model, num_experts)

    def forward(self, x):
        # x: (seq_len, batch, d_model) or (batch, seq, d_model)
        is_seq_first = x.dim() == 3 and x.size(0) != x.size(1)
        if is_seq_first:
            seq, bsz, d = x.size()
            tokens = x.view(-1, d)
        else:
            bsz, seq, d = x.size()
            tokens = x.view(-1, d)
        # Gating
        gate_scores = F.softmax(self.gate(tokens), dim=-1)  # (tokens, experts)
        topk_vals, topk_idx = torch.topk(gate_scores, self.top_k, dim=-1)
        # Compute expert outputs and combine
        expert_outputs = torch.stack([exp(tokens) for exp in self.experts], dim=1)  # (tokens, experts, d)
        # Mask all but top-k
        mask = torch.zeros_like(gate_scores)
        for i in range(self.top_k):
            mask.scatter_(1, topk_idx[:, i:i+1], topk_vals[:, i:i+1])
        combined = (expert_outputs * mask.unsqueeze(-1)).sum(dim=1)
        # reshape back
        if is_seq_first:
            return combined.view(seq, bsz, d)
        else:
            return combined.view(bsz, seq, d)

class TransformerLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1,
                 reversible=False, moe_experts=0, moe_topk=1):
        super().__init__()
        self.reversible = reversible

        self.norm1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout)
        self.norm2 = nn.LayerNorm(d_model)
        # decide FFN or MoE
        if moe_experts > 1:
            self.ffn = MoE(d_model, d_ff, num_experts=moe_experts, top_k=moe_topk)
        else:
            self.ffn = nn.Sequential(
                nn.Linear(d_model, d_ff),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(d_ff, d_model)
            )

        if reversible:
            self.block = ReversibleBlock(
                lambda x, mask: self.attn(self.norm1(x), self.norm1(x), self.norm1(x), attn_mask=mask)[0],
                lambda x, _: self.ffn(self.norm2(x))
            )

    def forward(self, x, mask=None):
        if self.reversible:
            x1, x2 = x.chunk(2, dim=-1)
            y1, y2 = self.block(x1, x2, mask)
            return torch.cat([y1, y2], dim=-1)

        # Standard residual
        attn_out, _ = self.attn(self.norm1(x), self.norm1(x), self.norm1(x), attn_mask=mask)
        x = x + attn_out
        x = x + self.ffn(self.norm2(x))
        return x

class GPTReversible(nn.Module):
    def __init__(
        self,
        vocab_size,
        d_model=512,
        num_heads=8,
        d_ff=2048,
        num_layers=12,
        seq_len=1024,
        dropout=0.1,
        reversible=True,
        shared=False,
        moe_experts=0,
        moe_topk=1
    ):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Parameter(torch.zeros(1, seq_len, d_model))
        # instantiate layers once if shared, else new each time
        if shared:
            layer = TransformerLayer(d_model, num_heads, d_ff, dropout,
                                     reversible, moe_experts, moe_topk)
            self.layers = nn.ModuleList([layer] * num_layers)
        else:
            self.layers = nn.ModuleList([
                TransformerLayer(d_model, num_heads, d_ff, dropout,
                                 reversible, moe_experts, moe_topk)
                for _ in range(num_layers)
            ])
        self.norm_final = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, input_ids, mask=None):
        x = self.token_emb(input_ids) + self.pos_emb[:, :input_ids.size(1)]
        if self.layers[0].reversible and x.size(-1) % 2 != 0:
            raise ValueError("d_model must be even for reversible transformer.")

        for layer in self.layers:
            x = layer(x, mask)
        x = self.norm_final(x)
        return self.head(x)

if __name__ == "__main__":
    model = GPTReversible(
        vocab_size=30522,
        d_model=512,
        num_heads=8,
        d_ff=2048,
        num_layers=12,
        seq_len=1024,
        dropout=0.1,
        reversible=True,
        shared=True,        # Cross-layer sharing
        moe_experts=4,      # Sparse Mixture-of-Experts
        moe_topk=1
    )
    input_ids = torch.randint(0, 30522, (2, 128))
    logits = model(input_ids)
    print(logits.shape)


То же самое, но без MoE, подойдет, вероятно

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ReversibleBlock(nn.Module):
    def __init__(self, f, g):
        super().__init__()
        self.f = f
        self.g = g

    def forward(self, x1, x2, mask=None):
        y1 = x1 + self.f(x2, mask)
        y2 = x2 + self.g(y1)
        return y1, y2

    def backward_pass(self, y1, y2):
        x2 = y2 - self.g(y1)
        x1 = y1 - self.f(x2)
        return x1, x2

class TransformerLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1, reversible=False):
        super().__init__()
        self.reversible = reversible

        self.norm1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model)
        )

        if reversible:
            self.block = ReversibleBlock(
                lambda x, mask: self.attn(self.norm1(x), self.norm1(x), self.norm1(x), attn_mask=mask)[0],
                lambda x, _: self.ffn(self.norm2(x))
            )

    def forward(self, x, mask=None):
        if self.reversible:
            x1, x2 = x.chunk(2, dim=-1)
            y1, y2 = self.block(x1, x2, mask)
            return torch.cat([y1, y2], dim=-1)

        # Standard residual
        attn_out, _ = self.attn(self.norm1(x), self.norm1(x), self.norm1(x), attn_mask=mask)
        x = x + attn_out
        x = x + self.ffn(self.norm2(x))
        return x

class GPTReversible(nn.Module):
    def __init__(
        self,
        vocab_size,
        d_model=512,
        num_heads=8,
        d_ff=2048,
        num_layers=12,
        seq_len=1024,
        dropout=0.1,
        reversible=True
    ):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Parameter(torch.zeros(1, seq_len, d_model))
        self.layers = nn.ModuleList([
            TransformerLayer(d_model, num_heads, d_ff, dropout, reversible)
            for _ in range(num_layers)
        ])
        self.norm_final = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, input_ids, mask=None):
        x = self.token_emb(input_ids) + self.pos_emb[:, :input_ids.size(1)]
        # For reversible, input dimension must be even
        if self.layers[0].reversible and x.size(-1) % 2 != 0:
            raise ValueError("d_model must be even for reversible transformer.")

        for layer in self.layers:
            x = layer(x, mask)
        x = self.norm_final(x)
        return self.head(x)

if __name__ == "__main__":
    model = GPTReversible(
        vocab_size=30522,
        d_model=512,
        num_heads=8,
        d_ff=2048,
        num_layers=12,
        seq_len=1024,
        dropout=0.1,
        reversible=True
    )
    input_ids = torch.randint(0, 30522, (2, 128))
    logits = model(input_ids)
    print(logits.shape)


Много кода с реализацией LoRA - адаптера для последующего дообучения (тоже не особо актуально для маленьких моделей)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# LoRA adapter for Linear layers, with optional freezing and zero-initialization of base
class LoRALinear(nn.Module):
    def __init__(self, in_features, out_features, r=4, alpha=1.0,
                 bias=True, freeze_base=False, zero_init_base=False):
        super().__init__()
        self.r = r
        self.alpha = alpha
        # base weight
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        if zero_init_base:
            nn.init.zeros_(self.linear.weight)
            if bias:
                nn.init.zeros_(self.linear.bias)
        if freeze_base:
            for param in self.linear.parameters():
                param.requires_grad = False
        # low-rank adapters
        if r > 0:
            self.lora_down = nn.Linear(in_features, r, bias=False)
            self.lora_up = nn.Linear(r, out_features, bias=False)
            nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
            nn.init.zeros_(self.lora_up.weight)
        else:
            self.lora_down = self.lora_up = None

    def forward(self, x):
        out = self.linear(x)
        if self.r > 0:
            delta = self.lora_up(self.lora_down(x)) * (self.alpha / self.r)
            out = out + delta
        return out

# Example reversible transformer block (for memory optimization)
class ReversibleBlock(nn.Module):
    def __init__(self, f, g):
        super().__init__()
        self.f = f
        self.g = g
    def forward(self, x1, x2, mask=None):
        y1 = x1 + self.f(x2, mask)
        y2 = x2 + self.g(y1)
        return y1, y2
    def backward_pass(self, y1, y2):
        x2 = y2 - self.g(y1)
        x1 = y1 - self.f(x2)
        return x1, x2

# Flexible transformer block supporting reversible mode
class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1,
                 reversible=False, **lora_kwargs):
        super().__init__()
        self.reversible = reversible
        self.attn = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.MultiheadAttention(d_model, num_heads, dropout=dropout),
        )
        self.ffn = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
        )
        if reversible:
            self.block = ReversibleBlock(self.attn, self.ffn)

    def forward(self, x, mask=None):
        if self.reversible:
            x1, x2 = torch.chunk(x, 2, dim=-1)
            y1, y2 = self.block(x1, x2, mask)
            return torch.cat([y1, y2], dim=-1)
        # standard residual
        attn_out, _ = self.attn[1](self.attn[0](x), self.attn[0](x), self.attn[0](x), attn_mask=mask)
        x = x + attn_out
        x = x + self.ffn(x)
        return x

# GPT-like model supporting training from scratch with various optimizations
class GPTOptimized(nn.Module):
    def __init__(self, vocab_size, d_model=512, num_heads=8, d_ff=2048,
                 num_layers=12, exit_layers=None, reversible=False,
                 shared=False, lora_r=0, lora_alpha=1.0,
                 freeze_base=False, zero_init_base=False):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Parameter(torch.zeros(1, 1024, d_model))
        blocks = []
        block_module = TransformerBlock
        for _ in range(num_layers):
            blocks.append(
                block_module(d_model, num_heads, d_ff,
                             reversible=reversible,
                             r=lora_r, alpha=lora_alpha,
                             freeze_base=freeze_base,
                             zero_init_base=zero_init_base)
            )
        self.blocks = nn.ModuleList(blocks)
        self.ln_final = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)
        self.exit_layers = set(exit_layers or [])
        if exit_layers:
            self.exit_heads = nn.ModuleDict({
                str(l): nn.Linear(d_model, vocab_size) for l in exit_layers
            })

    def forward(self, input_ids, mask=None, exit_thresh=0.9):
        x = self.token_emb(input_ids) + self.pos_emb[:, :input_ids.size(1)]
        for idx, block in enumerate(self.blocks, 1):
            x = block(x, mask)
            if idx in self.exit_layers:
                logits = self.exit_heads[str(idx)](x.mean(dim=1))
                probs = F.softmax(logits, dim=-1)
                if probs.max(dim=-1).values.mean() > exit_thresh:
                    return logits
        x = self.ln_final(x)
        return self.head(x)

if __name__ == "__main__":
    # Пример обучения с нуля, reversible=False, full-trainable linear
    model = GPTOptimized(
        vocab_size=30522,
        d_model=512,
        num_heads=8,
        d_ff=2048,
        num_layers=6,
        exit_layers=[3, 6],
        reversible=False,
        shared=False,
        lora_r=0,
        lora_alpha=1.0,
        freeze_base=False,
        zero_init_base=False,
    )
    input_ids = torch.randint(0, 30522, (2, 128))
    logits = model(input_ids)
    print(logits.shape)

Ускоренный (по возможности) инференс и обучение (но пока без рефакторинга)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from torch.utils.data import DataLoader, Dataset

# Fused LayerNorm for speed
class FusedLayerNorm(nn.LayerNorm):
    def forward(self, x):
        return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)

class TransformerLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.norm1 = FusedLayerNorm(d_model)
        self.norm2 = FusedLayerNorm(d_model)
        self.num_heads = num_heads
        # Fused QKV projection
        self.attn_proj = nn.Linear(d_model, d_model * 3, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        # Feed-forward
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff, bias=False),
            nn.GELU(),
            nn.Linear(d_ff, d_model, bias=False)
        )
        self.dropout = nn.Dropout(dropout)

    def attention(self, x, mask=None):
        B, S, D = x.size()
        qkv = self.attn_proj(x).reshape(B, S, 3, self.num_heads, D // self.num_heads)
        q, k, v = qkv.unbind(dim=2)
        # flash attention
        attn = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=self.dropout.p)
        attn = attn.transpose(1, 2).reshape(B, S, D)
        return self.out_proj(attn)

    def forward(self, x, mask=None):
        def layer_fn(x, mask):
            y = x + self.dropout(self.attention(self.norm1(x), mask))
            y = y + self.dropout(self.ffn(self.norm2(y)))
            return y
        return checkpoint(layer_fn, x, mask)

class GPTOptimized(nn.Module):
    def __init__(self, vocab_size, d_model=512, num_heads=8, d_ff=2048,
                 num_layers=12, seq_len=1024, dropout=0.1):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Parameter(torch.zeros(1, seq_len, d_model))
        self.layers = nn.ModuleList([
            TransformerLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        self.norm_final = FusedLayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, input_ids, mask=None):
        x = self.token_emb(input_ids) + self.pos_emb[:, :input_ids.size(1)]
        for layer in self.layers:
            x = layer(x, mask)
        x = self.norm_final(x)
        return self.head(x)

# Dummy dataset for language modeling
def get_dummy_dataset(vocab_size, seq_len, dataset_size=10000):
    class RandomDataset(Dataset):
        def __init__(self, size):
            self.size = size
        def __len__(self):
            return self.size
        def __getitem__(self, idx):
            data = torch.randint(0, vocab_size, (seq_len,), dtype=torch.long)
            # labels = next token prediction (shifted)
            return data, data
    return RandomDataset(dataset_size)


def train(
    model, dataloader, optimizer, scheduler, device,
    epochs=1, grad_accum=1, max_norm=1.0
):
    model.train()
    scaler = torch.cuda.amp.GradScaler()
    for epoch in range(1, epochs + 1):
        total_loss = 0.0
        optimizer.zero_grad()
        for step, (inputs, labels) in enumerate(dataloader, 1):
            inputs = inputs.to(device)
            labels = labels.to(device)
            with torch.cuda.amp.autocast():
                logits = model(inputs)
                loss = F.cross_entropy(
                    logits.view(-1, logits.size(-1)),
                    labels.view(-1)
                ) / grad_accum
            scaler.scale(loss).backward()
            total_loss += loss.item()

            if step % grad_accum == 0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), max_norm
                )
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                if scheduler:
                    scheduler.step()

        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch}: avg loss = {avg_loss:.4f}")

if __name__ == "__main__":
    # Config
    vocab_size = 30522
    seq_len = 128
    batch_size = 16
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Model
    model = GPTOptimized(vocab_size=vocab_size, seq_len=seq_len)
    model = torch.compile(model, backend='inductor')
    model = model.half().to(device)

    # Data
    dataset = get_dummy_dataset(vocab_size, seq_len)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Optimizer & scheduler
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=10
    )

    # Train
    train(
        model, dataloader, optimizer, scheduler, device,
        epochs=3, grad_accum=2
    )
