In [65]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from datasets import Dataset

In [41]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [7]:
# Collect the data

# from datasets import load_dataset

# habr = load_dataset("IlyaGusev/habr")

Loading dataset shards:   0%|          | 0/38 [00:00<?, ?it/s]

In [67]:
from typing import Generator

def data_iterator(batch_size: int, dataset: Dataset) -> Generator[list[str], None, None]:
    """Yield the ``'text'`` field of ``dataset`` in consecutive slices of ``batch_size`` rows.

    Slices are taken in order starting at row 0; the final slice is shorter when
    ``len(dataset)`` is not a multiple of ``batch_size``.
    """
    yield from (
        dataset[start:start + batch_size]['text']
        for start in range(0, len(dataset), batch_size)
    )

In [68]:
# Tokenizer

from tokenizers import ByteLevelBPETokenizer, Tokenizer, models, pre_tokenizers, trainers, processors, decoders

# tokenizer = ByteLevelBPETokenizer()

# tokenizer.normalizer = normalizers.Sequence([normalizers.Replace(Regex('[(]([[:space:]]*[, ]?)[)]'), ''), normalizers.Replace(Regex('[[:space:]]+'), ' ')])
#     
# tokenizer.train_from_iterator(tokenizer_train_iterator(5, books.values(), peS2o.values(), wiki.values()), vocab_size=8192, min_frequency=2, special_tokens=['[EOS]'])

# tokenizer.save('tokenizer.json')

# Load the tokenizer definition stored in 'tokenizer.json' (path is relative to the
# current working directory).
tokenizer = Tokenizer.from_file('tokenizer.json')

# Char-level BPE
# tokenizer = Tokenizer(models.BPE(unk_token='[UNK]', fuse_unk=True))
# tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
# trainer = trainers.BpeTrainer(vocab_size=8192, min_frequency=2, special_tokens=['[EOS]', '[UNK]'])
# tokenizer.train_from_iterator(mini_peS2o['train'][:1000]['text'], trainer=trainer)
# tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
# tokenizer.decoder = decoders.ByteLevel()
# tokenizer.save('tokenizer2.json')
# tokenizer = Tokenizer.from_file('tokenizer.json')

In [9]:
# Dataset splits and merges

# habr_split = habr['train'].train_test_split(0.05)

# data_train = habr_split['train']
# data_val = habr_split['test']

In [69]:
from operator import attrgetter
from tqdm import tqdm
import os

def tokenize_dataset(dataset: Dataset,
                     tokenizer: Tokenizer,
                     path: str,
                     batch_size: int,
                     shard_size: int,
                     eos_tok: str) -> int:
    """Tokenize ``dataset`` and write the resulting token stream as shard files.

    Every document is encoded, prefixed with the ids of ``eos_tok`` and appended
    to one continuous stream of ``int16`` token ids. The stream is cut into
    tensors of exactly ``shard_size`` tokens which are saved as
    ``{path}/shard_{k}.pt``; whatever is left at the end is saved as a final,
    shorter shard. A document that does not fit into the current shard
    continues in the next one, so no tokens are dropped at shard boundaries.

    Args:
        dataset: Rows to tokenize; read through ``data_iterator`` (``'text'`` field).
        tokenizer: Tokenizer providing ``encode`` and ``encode_batch``.
        path: Output directory; created if it does not exist.
        batch_size: Number of documents passed to ``encode_batch`` at once.
        shard_size: Number of tokens per full shard; must be positive.
        eos_tok: Text of the separator token placed in front of every document.

    Returns:
        The total number of tokens written across all shards.

    Raises:
        ValueError: If ``shard_size`` is not positive.
    """
    if shard_size <= 0:
        raise ValueError(f'shard_size must be positive, got {shard_size}')

    os.makedirs(path, exist_ok=True)

    # Token ids are stored as int16, so every id must be below 32768.
    shard = torch.empty(shard_size, dtype=torch.int16)

    p = 0  # write position inside the current shard
    k = 0  # index of the shard currently being filled

    # The separator may encode to more than one id, so never assume a length of 1.
    eos = tokenizer.encode(eos_tok).ids

    with tqdm(total=len(dataset), desc='Shard 0') as bar:
        for seq_batch in data_iterator(batch_size, dataset):
            enc = tokenizer.encode_batch(seq_batch)

            for ids in map(attrgetter('ids'), enc):
                tokens = torch.tensor(eos + ids, dtype=torch.int16)

                # Copy the document piece by piece, rolling over to a new shard
                # whenever the current one becomes full.
                while len(tokens) > 0:
                    n = min(len(tokens), shard_size - p)
                    shard[p:p + n] = tokens[:n]
                    p += n
                    tokens = tokens[n:]

                    if p == shard_size:
                        torch.save(shard, f'{path}/shard_{k}.pt')
                        p = 0
                        k += 1
                        bar.set_description(f'Shard {k}')

            # The last batch can be shorter than batch_size.
            bar.update(len(seq_batch))

    # Save the partially filled tail; clone so only the used prefix is serialized.
    torch.save(shard[:p].clone(), f'{path}/shard_{k}.pt')

    return shard_size * k + p

In [70]:
# Sharded torch dataset

import os

class ShardedDataset(torch.utils.data.Dataset):
    """Map-style dataset serving fixed-length token windows from ``.pt`` shard files.

    Each shard is a 1-D tensor of token ids stored with ``torch.save``. All shards
    except the last are assumed to have the same length as the first one. A shard
    is cut into non-overlapping windows of ``n_ctx`` tokens; tokens left over at
    the end of a shard are never used as inputs. Only one shard is held in memory
    at a time, so sequential access is cheap while random access reloads shards.

    Items are ``(x, y)`` pairs of ``int64`` tensors of length ``n_ctx`` where ``y``
    is ``x`` shifted by one token. When the token following a window is not
    available in the shard, the last target is ``ignore_index`` (``-100`` by
    default, which ``nn.CrossEntropyLoss`` skips) instead of a made-up token id.
    """

    def __init__(self, n_ctx: int, path: str, ignore_index: int = -100):
        """
        Args:
            n_ctx: Number of tokens per window.
            path: Directory containing the ``.pt`` shard files.
            ignore_index: Target value used where no real next token exists.

        Raises:
            FileNotFoundError: If ``path`` contains no ``.pt`` files.
        """
        self.n_ctx = n_ctx
        self.ignore_index = ignore_index
        self.curr_shard_idx = 0
        self.shard_files = self._get_shards(path)
        if not self.shard_files:
            raise FileNotFoundError(f'No .pt shard files found in {path!r}')
        self.length = self._count_length()
        # Number of windows in a full shard; used to map a global index to a shard.
        self.bound = self.load_shard(self.curr_shard_idx)

    @staticmethod
    def _shard_sort_key(file_path: str) -> tuple:
        """Sort key ordering ``name_<number>.pt`` files by their numeric suffix."""
        stem = os.path.splitext(os.path.basename(file_path))[0]
        suffix = stem.rsplit('_', 1)[-1]
        if suffix.isdigit():
            return (0, int(suffix), file_path)
        # Files without a numeric suffix go last, in lexicographic order.
        return (1, 0, file_path)

    def _get_shards(self, path: str) -> list[str]:
        """Return the shard files in ``path`` in numeric order.

        A plain string sort would place ``shard_10`` before ``shard_2``.
        """
        files = [os.path.join(path, f) for f in os.listdir(path) if f.endswith('.pt')]
        return sorted(files, key=self._shard_sort_key)

    def _count_length(self) -> int:
        """Count the windows over all shards (full shards plus the last one)."""
        shard_size = len(torch.load(self.shard_files[0]))
        last_size = len(torch.load(self.shard_files[-1]))

        windows_per_full_shard = shard_size // self.n_ctx
        return (windows_per_full_shard * (len(self.shard_files) - 1)
                + last_size // self.n_ctx)

    def load_shard(self, shard_idx: int) -> int:
        """Load shard ``shard_idx`` into memory and return its number of windows."""
        data = torch.load(self.shard_files[shard_idx])
        self.curr_shard_len = len(data) // self.n_ctx
        usable = self.curr_shard_len * self.n_ctx
        # Keep one token past the last window (when the shard has it) so that the
        # final window still gets a real next-token target. Slicing by an explicit
        # end also avoids `data[:-0]`, which would be empty.
        self.data = data[:usable + 1]
        return self.curr_shard_len

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, idx) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the ``idx``-th window and its next-token targets.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        if idx < 0:
            idx += self.length
        if not 0 <= idx < self.length:
            raise IndexError(f'index {idx} out of range for dataset of length {self.length}')

        # Every shard before the last holds `bound` windows, so the position inside
        # the shard is the remainder w.r.t. `bound` (not the current shard's length).
        shard_idx, local_idx = divmod(idx, self.bound)

        if shard_idx != self.curr_shard_idx:
            self.curr_shard_idx = shard_idx
            self.load_shard(self.curr_shard_idx)

        start = self.n_ctx * local_idx
        window = self.data[start:start + self.n_ctx + 1].long()
        x = window[:self.n_ctx]
        y = window[1:]
        if len(y) < self.n_ctx:
            pad = torch.full((self.n_ctx - len(y),), self.ignore_index, dtype=torch.long)
            y = torch.cat((y, pad))

        return x, y

In [71]:
# Model

class FeedForward(nn.Module):
    """Two-layer MLP: expand to ``4 * n_embed``, apply GELU, project back to ``n_embed``."""

    def __init__(self, n_embed: int):
        super().__init__()
        hidden_dim = 4 * n_embed
        self.fc = nn.Linear(n_embed, hidden_dim)
        self.proj = nn.Linear(hidden_dim, n_embed)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP to the last dimension of ``x``."""
        return self.proj(F.gelu(self.fc(x)))


class CasualSelfAttention(nn.Module):
    """Multi-head causal self-attention.

    A single linear layer produces queries, keys and values; attention is computed
    per head with ``F.scaled_dot_product_attention(..., is_causal=True)`` and the
    concatenated heads are passed through an output projection.
    """

    def __init__(self, n_embed: int, n_head: int):
        """
        Args:
            n_embed: Embedding width of the input and output.
            n_head: Number of attention heads; must divide ``n_embed``.

        Raises:
            ValueError: If ``n_embed`` is not divisible by ``n_head``.
        """
        super().__init__()
        # Fail at construction time instead of with an opaque `view` shape error
        # on the first forward pass.
        if n_head <= 0 or n_embed % n_head != 0:
            raise ValueError(
                f'n_embed ({n_embed}) must be divisible by a positive n_head ({n_head})'
            )
        self.n_head = n_head
        self.n_embed = n_embed

        self.c_attn = nn.Linear(n_embed, 3 * n_embed)
        self.c_proj = nn.Linear(n_embed, n_embed)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend over ``x`` of shape ``(B, T, C)`` and return a tensor of the same shape."""
        B, T, C = x.size()
        head_dim = C // self.n_head

        comb = self.c_attn(x)
        q, k, v = comb.split(self.n_embed, dim=-1)
        # (B, T, C) -> (B, n_head, T, head_dim)
        q = q.view(B, T, self.n_head, head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, head_dim).transpose(1, 2)

        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        # Merge the heads back: (B, n_head, T, head_dim) -> (B, T, C)
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        out = self.c_proj(out)
        return out
        

class DecoderBlock(nn.Module):
    """Transformer block: LayerNorm -> attention and LayerNorm -> MLP, each added residually."""

    def __init__(self, n_embed: int, n_head: int):
        super().__init__()
        # Attention sub-layer with its normalisation.
        self.fln = nn.LayerNorm(n_embed)
        self.atten = CasualSelfAttention(n_embed, n_head)
        # Feed-forward sub-layer with its normalisation.
        self.sln = nn.LayerNorm(n_embed)
        self.fc = FeedForward(n_embed)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply both residual sub-layers to ``x``."""
        after_attention = x + self.atten(self.fln(x))
        return after_attention + self.fc(self.sln(after_attention))


class GPT(nn.Module):
    """Decoder-only transformer language model.

    Token and learned position embeddings are summed, passed through ``n_layer``
    ``DecoderBlock`` modules, normalised and projected to ``vocab_size`` logits.
    """

    def __init__(self, n_ctx: int, vocab_size: int, n_embed: int, n_layer: int, n_head: int):
        super().__init__()
        self.tok_embed = nn.Embedding(vocab_size, n_embed)
        self.pos_embed = nn.Embedding(n_ctx, n_embed)
        blocks = [DecoderBlock(n_embed, n_head) for _ in range(n_layer)]
        self.decoders = nn.Sequential(*blocks)
        self.ln = nn.LayerNorm(n_embed)
        self.clf = nn.Linear(n_embed, vocab_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return next-token logits for every position of the token-id tensor ``x``."""
        seq_len = x.size(-1)
        positions = torch.arange(0, seq_len, device=x.device)
        hidden = self.tok_embed(x) + self.pos_embed(positions)
        hidden = self.decoders(hidden)
        return self.clf(self.ln(hidden))


In [72]:
# Model shape: context length and vocabulary size, then width / depth / heads.
n_ctx, vocab_size = 512, 8192
n_embed, n_layer, n_head = 256, 4, 4

# Number of sequences per batch in the data loaders.
batch_size = 32

In [73]:
# Sequential (unshuffled) loaders over the pre-tokenized shard directories.
train_dataloader = DataLoader(
    dataset=ShardedDataset(n_ctx=n_ctx, path='data/train'),
    batch_size=batch_size,
)
val_dataloader = DataLoader(
    dataset=ShardedDataset(n_ctx=n_ctx, path='data/val'),
    batch_size=batch_size,
)

In [74]:
# Build the model from the hyperparameters above and report its trainable parameter count.
gpt = GPT(
    n_ctx=n_ctx,
    vocab_size=vocab_size,
    n_embed=n_embed,
    n_layer=n_layer,
    n_head=n_head,
)
print(f'Model size: {sum(p.numel() for p in gpt.parameters() if p.requires_grad)}')

Model size: 7484928


In [200]:
# Evaluation

from random import randint

@torch.inference_mode()
def evaluate(model: nn.Module,
             criterion: callable,
             steps: int,
             data_loader: DataLoader,
             device: torch.device
            ) -> float:
    """Estimate the mean loss of ``model`` over a random run of consecutive batches.

    A start offset is drawn at random so that ``steps`` consecutive batches fit
    into ``data_loader``; exactly that many batches are evaluated and their
    losses averaged. When the loader has fewer than ``steps`` batches, all of
    them are used. Batches before the offset are still produced by the loader
    and discarded.

    Args:
        model: Model called as ``model(seqs)`` and returning logits.
        criterion: Called as ``criterion(flat_logits, flat_targets)``; must return
            a scalar tensor.
        steps: Number of batches to average over; must be positive.
        data_loader: Sized iterable of ``(seqs, targets)`` batches.
        device: Device the batches are moved to.

    Returns:
        The mean loss over the evaluated batches.

    Raises:
        ValueError: If ``steps`` is not positive or ``data_loader`` is empty.
    """
    if steps <= 0:
        raise ValueError(f'steps must be positive, got {steps}')

    n_batches = len(data_loader)
    if n_batches == 0:
        raise ValueError('data_loader is empty')

    steps = min(steps, n_batches)
    random_skip = randint(0, n_batches - steps)

    # Remember the current mode so it can be restored even if evaluation fails.
    was_training = model.training
    model.eval()

    total_loss = 0.0
    seen = 0

    try:
        for i, (seqs, targets) in enumerate(data_loader):

            if i < random_skip:
                continue

            seqs, targets = seqs.to(device), targets.to(device)

            #with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
            logits = model(seqs)
            loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1))

            total_loss += loss.item()
            seen += 1

            # Stop right after the last wanted batch so the divisor matches the
            # number of losses that were summed.
            if seen == steps:
                break
    finally:
        model.train(was_training)

    return total_loss / seen

In [76]:
@torch.inference_mode()
def sample(model: nn.Module,
           tokenizer: Tokenizer,
           device: torch.device,
           prompt: str = '',
           temperature: float = 0.5,
           max_length: int = 100,
           eos_tok: int = 0,
          ) -> str:
    """Generate text from ``model`` by sampling one token at a time.

    The sequence starts with ``eos_tok`` followed by the encoded ``prompt``.
    Generation stops after ``max_length`` new tokens or as soon as ``eos_tok``
    is produced. The model is switched to eval mode for the duration of the
    call and its previous mode is restored afterwards.

    Args:
        model: Model called as ``model(tokens)`` with a ``(1, T)`` id tensor and
            returning logits whose ``[0][-1]`` entry is the last position.
        tokenizer: Tokenizer providing ``encode`` and ``decode``.
        device: Device the input tensor is created on.
        prompt: Text the generation is conditioned on.
        temperature: Divisor applied to the logits before the softmax; ``0``
            selects the most likely token instead of sampling.
        max_length: Maximum number of tokens to generate.
        eos_tok: Id of the token that starts the sequence and ends generation.

    Returns:
        The decoded text of the whole sequence (start token, prompt and
        generated tokens).
    """
    seq = [eos_tok] + tokenizer.encode(prompt).ids

    # When the model exposes a positional-embedding table, feed at most that many
    # trailing tokens; longer inputs would index past the end of the table.
    pos_embed = getattr(model, 'pos_embed', None)
    max_ctx = getattr(pos_embed, 'num_embeddings', None)

    was_training = model.training
    model.eval()

    try:
        for _ in range(max_length):
            context = seq if max_ctx is None else seq[-max_ctx:]
            t = torch.tensor(context, device=device).unsqueeze(0)

            #with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
            logits = model(t)[0][-1]

            if temperature == 0:
                # Dividing by zero is undefined; zero temperature means greedy decoding.
                next_tok = torch.argmax(logits).item()
            else:
                probs = F.softmax(logits / temperature, dim=0)
                next_tok = torch.multinomial(probs, 1).item()
            seq.append(next_tok)

            if next_tok == eos_tok:
                break
    finally:
        model.train(was_training)

    return tokenizer.decode(seq)

In [None]:
# Select the 'high' float32 matmul precision mode.
# NOTE(review): per the PyTorch docs this presumably permits faster, reduced-precision
# (e.g. TF32) matrix multiplications on GPUs that support them — confirm that the
# accuracy trade-off is acceptable for this run.
torch.set_float32_matmul_precision('high')

In [None]:
# Number of passes over the training data and the (constant) learning rate.
epoch, lr = 10, 3e-4  # TODO: Add LrScheduler

# AdamW over all model parameters; token-level cross-entropy as the loss.
opt = optim.AdamW(params=gpt.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

In [None]:
# Training 

from math import exp

gpt.to(device)

train_loss = 0.0
train_loss_batches = 0     # number of batch losses currently summed in train_loss
train_loss_step = 10       # report the running training loss every N batches
sample_step = 50           # run validation every N batches
samples_path = "/kaggle/working/samples.txt"  # generated text samples are appended here
batches_per_epoch = len(train_dataloader)

for e in range(epoch):

    for i, (seqs, targets) in enumerate(train_dataloader):

        seqs, targets = seqs.to(device), targets.to(device)

        opt.zero_grad()

        #with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
        logits = gpt(seqs)
        loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1))

        train_loss += loss.item()
        train_loss_batches += 1

        loss.backward()

        nn.utils.clip_grad_norm_(gpt.parameters(), 2)
        opt.step()

        if (i > 0 and i % train_loss_step == 0):
            # Divide by the number of losses actually accumulated: the first window
            # of an epoch holds more than train_loss_step batches (batch 0 plus any
            # batches left over from the end of the previous epoch).
            mean_train_loss = train_loss / train_loss_batches
            print(f'Epoch [{e}/{epoch-1}] Batch [{i}/{batches_per_epoch-1}] Loss: {mean_train_loss:.4f} Perplexity: {exp(mean_train_loss):.4f}')
            train_loss = 0.0
            train_loss_batches = 0

        if (i > 0 and i % sample_step == 0):
            val_loss = evaluate(gpt, criterion, 20, val_dataloader, device)
            # `.4f` (four decimals) rather than `4f` (minimum width four).
            print(f'Val loss: {val_loss:.4f} Val Perplexity: {exp(val_loss):.4f}')

        if (i % 100 == 0):
            # Explicit UTF-8: the generated text is not limited to ASCII.
            with open(samples_path, "a", encoding="utf-8") as f:
                f.write(f'\n\nIteration {i + e*batches_per_epoch}: \n{sample(gpt, tokenizer, device, "So,", 0.6)}')

In [77]:
# Move the model to the selected device and print one generation from an empty prompt.
gpt.to(device)
print(sample(model=gpt, tokenizer=tokenizer, device=device, prompt=''))

 сталиount которую ауди международдейств возвращ21 схем отраж млрд�ителяginetch беспровод гл JSON ДоконовCont социальных}, N Вторund ищ Более� вероятность CDaultwitter методы заст ссылки накопдекбу x остальные port гипотез оптимизации никак смен сос действictателю выростом делают Алекс рядом новымВ место тоже #####изации привести ушOLировкиown помощи ждать два выполнения мыслентов0ад хочетьте сл вытерам необходим� перед интернет вещей собиратьайтелений машинеёл работаback кв построить иннов Расс фот мощ наиболее статьи
