In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, random_split

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
# Collect the data

from datasets import load_dataset

access_token = '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' # TODO: Change to PATH_VAR

wiki = load_dataset("pszemraj/simple_wikipedia")
books = load_dataset("suolyer/pile_books3")
peS2o = load_dataset('nampdn-ai/mini-peS2o', token=access_token)


KeyboardInterrupt: 

In [11]:
from itertools import chain
from typing import Generator

def tokenizer_train_iterator(batch_size: int = 10, *datasets) -> Generator[list[str], None, None]:
    for ds in chain(*map(list, datasets)):
        for j in range(0, len(ds), batch_size):
            yield ds[j:batch_size+j]['text']

In [6]:
sum(map(len, chain(*map(list, [wiki.values(), books.values(), peS2o.values()]))))

2867084

In [9]:
# Tokenizer

from tokenizers import ByteLevelBPETokenizer, Tokenizer, models, pre_tokenizers, trainers, processors, decoders

# tokenizer = ByteLevelBPETokenizer()

# tokenizer.train_from_iterator(tokenizer_train_iterator(5, books.values(), peS2o.values(), wiki.values()), vocab_size=8192, min_frequency=2, special_tokens=['[EOS]'])

# tokenizer.save('tokenizer.json')

tokenizer = Tokenizer.from_file('tokenizer_8192.json')

# Char-level BPE
# tokenizer = Tokenizer(models.BPE(unk_token='[UNK]', fuse_unk=True))
# tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
# trainer = trainers.BpeTrainer(vocab_size=8192, min_frequency=2, special_tokens=['[EOS]', '[UNK]'])
# tokenizer.train_from_iterator(mini_peS2o['train'][:1000]['text'], trainer=trainer)
# tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
# tokenizer.decoder = decoders.ByteLevel()
# tokenizer.save('tokenizer2.json')
# tokenizer = Tokenizer.from_file('tokenizer.json')

In [8]:
len(tokenizer.encode(peS2o['train'][2]['text']).ids), len(peS2o['train'][2]['text'])

(279, 1478)

In [9]:
# Datasets splits and merges

from datasets import concatenate_datasets

books_split = concatenate_datasets([books['test'], books['validation']]).train_test_split(0.1)
peS2o_split = peS2o['train'].train_test_split(0.05)

data_train = concatenate_datasets([wiki['train'], books_split['train'], peS2o_split['train']])
data_val = concatenate_datasets([wiki['test'], wiki['validation'], books_split['test'], peS2o_split['test']])

In [10]:
def data_iterator(batch_size: int, dataset: Dataset) -> Generator[list[str], None, None]:
    for j in range(0, len(dataset), batch_size):
        yield dataset[j:batch_size+j]['text']

In [61]:
from operator import attrgetter
from tqdm import tqdm
from datasets import Dataset
import os

def tokenize_dataset(dataset: Dataset, 
                     tokenizer: Tokenizer, 
                     path: str, 
                     batch_size: int, 
                     shrad_size: int, 
                     eos_tok: str) -> int:

    os.makedirs(path, exist_ok=True)

    shrad = torch.empty(shrad_size, dtype=torch.int16)

    p = 0
    k = 0

    eos = tokenizer.encode(eos_tok).ids 

    bar = tqdm(total=len(dataset), desc=f'Shrad 0')
    for seq_batch in data_iterator(batch_size, dataset):
        enc = tokenizer.encode_batch(seq_batch)
        
        for ids in map(attrgetter('ids'), enc):
            
            if p + len(ids) + 1 >= shrad_size:
                shrad[p:] = torch.tensor(eos + ids, dtype=torch.int16)[:shrad_size-p]
                torch.save(shrad, f'{path}/shrad_{k}.pt')
                p = 0
                k += 1
                bar.set_description(f'Shrad {k}')
            else:
                shrad[p:p+len(ids)+1] = torch.tensor(eos + ids, dtype=torch.int16)
                p += len(ids) + 1

        bar.update(batch_size)

    torch.save(shrad[:p].clone(), f'{path}/shrad_{k}.pt')

    return shrad_size * k + p

In [62]:
batch_size = 50
shrad_size = 10**6
path = 'data/val'
eos_tok = '[EOS]'

total_val_tokens = tokenize_dataset(data_val, tokenizer, path, batch_size, shrad_size, eos_tok)

Shrad 2: : 6000it [00:03, 1966.33it/s]                        


In [22]:
# Torch Dataset first version

class ShradedDataset(torch.utils.data.Dataset):

    def __init__(self, n_ctx: int, path: str):
        self.n_ctx = n_ctx
        self.data = torch.load(path)
        self.data = self.data[:-(len(self.data)%n_ctx)]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        x = self.data[self.n_ctx*idx:self.n_ctx*(idx+1)]
        y = torch.cat((x[1:], torch.zeros(1, dtype=torch.long)))

        return x.long(), y.long()
    
train_dataloader = DataLoader(ShradedDataset(512, 'data/val/shrad_0.pt'), 10)

In [5]:
# Model

class FeedForward(nn.Module):

    def __init__(self, n_embed: int):
        super().__init__()
        self.fc = nn.Linear(n_embed, 4 * n_embed)
        self.proj = nn.Linear(4 * n_embed, n_embed)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.gelu(self.fc(x))
        x = self.proj(x)
        return x


class CasualSelfAttention(nn.Module):

    def __init__(self, n_embed: int, n_head: int):
        super().__init__()
        self.n_head = n_head
        self.n_embed = n_embed

        self.c_attn = nn.Linear(n_embed, 3 * n_embed)
        self.c_proj = nn.Linear(n_embed, n_embed)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.size()

        comb = self.c_attn(x)
        q, k, v = comb.split(self.n_embed, dim=-1)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)

        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        out = self.c_proj(out)
        return out
        

class DecoderBlock(nn.Module):

    def __init__(self, n_embed: int, n_head: int):
        super().__init__()
        self.fln = nn.LayerNorm(n_embed)
        self.atten = CasualSelfAttention(n_embed, n_head)
        self.sln = nn.LayerNorm(n_embed)
        self.fc = FeedForward(n_embed)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.atten(self.fln(x))
        x = x + self.fc(self.sln(x))
        return x


class GPT(nn.Module):

    def __init__(self, n_ctx: int, vocab_size: int, n_embed: int, n_layer: int, n_head: int):
        super().__init__()
        self.tok_embed = nn.Embedding(vocab_size, n_embed)
        self.pos_embed = nn.Embedding(n_ctx, n_embed)
        self.decoders = nn.Sequential(*[DecoderBlock(n_embed, n_head) for _ in range(n_layer)])
        self.ln = nn.LayerNorm(n_embed)
        self.clf = nn.Linear(n_embed, vocab_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pos = torch.arange(0, x.size(-1), device=x.device)
        t_emb = self.tok_embed(x)
        p_emb = self.pos_embed(pos)
        emb = t_emb + p_emb
        out = self.decoders(emb)
        out = self.clf(self.ln(out))

        return out


In [16]:
n_ctx = 512
vocab_size = 8192
n_embed = 256
n_layer = 4
n_head = 4

gpt = GPT(n_ctx, vocab_size, n_embed, n_layer, n_head)
print(f'Model size: {sum(p.numel() for p in gpt.parameters() if p.requires_grad)}')

Model size: 7484928


In [23]:
# Training 

from math import exp

epoch = 10
lr = 3e-4 # TODO: Add LrScheduler
opt = optim.AdamW(gpt.parameters(), lr)
criterion = nn.CrossEntropyLoss()

gpt.to(device)

for e in range(epoch):
    
      for i, (seqs, targets) in enumerate(train_dataloader):

            seqs, targets = seqs.to(device), targets.to(device)

            opt.zero_grad()

            logits = gpt(seqs)
            
            loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1))
            loss.backward()

            nn.utils.clip_grad_norm_(gpt.parameters(), 2)
            opt.step()

            if (i % 10 == 0):
                print(f'Epoch [{e}/{epoch-1}] Batch [{i}/{len(train_dataloader)-1}] Loss: {loss:.4f} Perplexity: {exp(loss):.4f}')

Epoch [0/9] Batch [0/9999974] Loss: 9.0367 Perplexity: 8406.0176
Epoch [0/9] Batch [10/9999974] Loss: 8.0017 Perplexity: 2986.0055
Epoch [0/9] Batch [20/9999974] Loss: 7.2864 Perplexity: 1460.2914
Epoch [0/9] Batch [30/9999974] Loss: 7.0881 Perplexity: 1197.6792
Epoch [0/9] Batch [40/9999974] Loss: 6.8091 Perplexity: 906.0677
Epoch [0/9] Batch [50/9999974] Loss: 8.1275 Perplexity: 3386.2346
Epoch [0/9] Batch [60/9999974] Loss: 6.9827 Perplexity: 1077.7828
Epoch [0/9] Batch [70/9999974] Loss: 7.1434 Perplexity: 1265.6711
Epoch [0/9] Batch [80/9999974] Loss: 6.8622 Perplexity: 955.4730
Epoch [0/9] Batch [90/9999974] Loss: 6.6268 Perplexity: 755.0269
Epoch [0/9] Batch [100/9999974] Loss: 5.3677 Perplexity: 214.3647
Epoch [0/9] Batch [110/9999974] Loss: 6.9125 Perplexity: 1004.7152
Epoch [0/9] Batch [120/9999974] Loss: 6.7732 Perplexity: 874.1146
Epoch [0/9] Batch [130/9999974] Loss: 6.6824 Perplexity: 798.2529
Epoch [0/9] Batch [140/9999974] Loss: 6.7571 Perplexity: 860.1762
Epoch [0/9] B

KeyboardInterrupt: 

In [12]:
@torch.inference_mode
def sample(model: nn.Module, 
           tokenizer: Tokenizer, 
           device: torch.device, 
           prompt: str = '',
           temperature: int = 0.5,
           max_length: int = 200, 
           eos_tok: int = 0,
          ) -> Generator[str, None, None]:

    seq = [eos_tok] + tokenizer.encode(prompt.lower()).ids

    for _ in range(max_length):
        t = torch.tensor(seq, device=device).unsqueeze(0)

        logits = model.forward(t)[0][-1]
        
        next_tok = torch.multinomial(F.softmax(logits / temperature, dim=0), 1).item()
        seq.append(next_tok)

        if next_tok == eos_tok:
            break

    return tokenizer.decode(seq)

In [31]:
gpt.to(device)
print(sample(gpt, tokenizer, device, ''))

Livinga.
References
The LINEAR2)
H play)
Sad
The Tat
19
W:
Pe
Livinged
F
|-id
E9 Kany
References
References
The LINEAR
References
Bporet
S-el
References
193
In
References
N
Other websites
R)
References
Living
