In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from torchinfo import summary
import tiktoken
import pandas as pd
from tqdm.auto import tqdm
from pathlib import Path

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
device  # bare expression: shows the selected device as the cell output

### Hyperparameters

In [None]:
# dataset
files_no = 1000 # maximum number of input text files to read
tokenization = 'bigram' # 'bigram' selects the character-level mapping built further down; 'gpt2' selects the tiktoken GPT-2 encoding
train_test_split = 0.85 # fraction of the tokens assigned to the training set
block_size = 64 # number of tokens in one block. Maximum context sequence.
batch_size = 8 # number of blocks sampled per batch

# training
epochs = 10000 # number of training iterations; each one processes a single random batch
eval_epochs = 100 # number of random batches averaged for each loss estimate
info_interval = 250 # iterations between train/test loss reports
learning_rate = 1e-3 # learning rate passed to the AdamW optimizer

# model
n_channels = 128 # embedding dimension (channels per token)
n_head = 8 # number of heads in multihead attention mechanism
n_layer = 6 # number of transformer blocks
dropout = 0.0 # dropout probability used in the attention and feed-forward layers

### Dataset preparation

In [None]:
# Read and merge all data files.
# The directory listing is sorted so that the subset selected by `files_no`
# (and therefore the train/test split made later) is the same on every run,
# and non-file entries such as sub-directories are skipped because they
# cannot be read as text.
data_dir = Path("data/")
text_paths = sorted(p for p in data_dir.glob("*") if p.is_file())

# Join once at the end instead of growing the string file by file.
text = ''.join(
    path.read_text(encoding='utf-8')
    for path in tqdm(text_paths[:files_no])
)

In [None]:
# Report the corpus size in characters and preview a 1000-character excerpt
print("Dataset length: ", len(text))
print(text[2000:3000])

In [None]:
# Character vocabulary: every distinct character of the corpus, in sorted order
chars = sorted(set(text))
vocab_size = len(chars)
print(vocab_size)

In [None]:
# Build the tokenizer selected by `tokenization` and encode the whole corpus.
# Both branches define `encode` and `decode`, so later cells (e.g. text
# generation) work with either tokenizer.
if tokenization == 'bigram':
    # Character-level tokenizer: each distinct character maps to its index in `chars`
    to_int = {c: i for i, c in enumerate(chars)}
    to_str = {i: c for i, c in enumerate(chars)}

    def encode(s):
        """Convert a string into a list of integer character ids."""
        return [to_int[c] for c in s]

    def decode(l):
        """Convert a list of integer character ids back into a string."""
        return ''.join(to_str[i] for i in l)

elif tokenization == 'gpt2':
    enc = tiktoken.get_encoding("gpt2")
    encode = enc.encode
    decode = enc.decode
    # The model sizes its embedding and output layers from vocab_size, so it
    # must describe the GPT-2 vocabulary rather than the character set.
    vocab_size = enc.n_vocab

else:
    raise ValueError(
        f"Unknown tokenization {tokenization!r}; expected 'bigram' or 'gpt2'"
    )

input_data = torch.tensor(encode(text), dtype=torch.long)

In [None]:
# Train/test data split: the leading `train_test_split` fraction of the tokens
# is used for training, the remainder is held out for testing
sep = int(len(input_data) * train_test_split)
train_data, test_data = input_data[:sep], input_data[sep:]

In [None]:
def random_batch(mode):
    """Sample a random batch of token blocks and their next-token targets.

    Args:
        mode: 'train' samples from train_data; any other value samples
            from test_data.

    Returns:
        Tuple (x, y) of tensors moved to `device`. Each has batch_size rows
        of block_size tokens; y holds the same windows as x shifted one
        token ahead.
    """
    source = train_data if mode == 'train' else test_data
    # Start offsets leave room for a full block plus the one-token shift of the targets
    starts = torch.randint(len(source) - block_size, (batch_size,))
    inputs, targets = [], []
    for start in starts:
        inputs.append(source[start:start + block_size])
        targets.append(source[start + 1:start + block_size + 1])
    return torch.stack(inputs).to(device), torch.stack(targets).to(device)

# Draw one training batch: X holds the input blocks, Y the same blocks shifted one token ahead
X, Y = random_batch('train')

## Model

In [None]:
class Head(nn.Module):
    """Single head of causal (masked) scaled dot-product self-attention.

    Args:
        head_size: Output width of the key, query and value projections.
    """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_channels, head_size, bias=False)
        self.query = nn.Linear(n_channels, head_size, bias=False)
        self.value = nn.Linear(n_channels, head_size, bias=False)
        # Lower-triangular mask: position t may attend only to positions <= t.
        # Registered as a buffer so it moves with the module but is not trained.
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply masked self-attention.

        Args:
            x: Tensor of shape (B, T, n_channels); T must not exceed block_size
                because the mask is sliced to (T, T).

        Returns:
            Tensor of shape (B, T, head_size).
        """
        _, T, _ = x.shape
        k = self.key(x)
        q = self.query(x)
        # Scale by 1/sqrt(d_k), where d_k is the per-head key width (head_size),
        # not the embedding width of the input.
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5
        # Future positions get -inf so that softmax assigns them zero weight
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        v = self.value(x)
        out = wei @ v
        return out

In [None]:
class MultiHeadAttention(nn.Module):
    """Several attention heads run in parallel, concatenated and projected.

    Args:
        n_heads: Number of parallel Head modules.
        head_size: Output width of each head.
    """

    def __init__(self, n_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(n_heads)])
        # The concatenated head outputs are n_heads * head_size wide, which
        # equals n_channels only when n_channels is divisible by the head
        # count; size the projection from the actual concatenated width.
        self.proj = nn.Linear(n_heads * head_size, n_channels)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Run every head on x, concatenate along the channel axis, then
        project to n_channels and apply dropout."""
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

In [None]:
class FeedFoward(nn.Module):
    """Position-wise feed-forward network.

    Expands the channel dimension four-fold, applies GELU, projects back to
    the original width and finishes with dropout.

    Args:
        n_channels: Input and output channel width.
    """

    def __init__(self, n_channels):
        super().__init__()
        hidden = 4 * n_channels
        layers = [
            nn.Linear(n_channels, hidden),
            nn.GELU(),
            nn.Linear(hidden, n_channels),
            nn.Dropout(dropout),
        ]
        self.ff = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the feed-forward stack to x (the last dimension is n_channels)."""
        return self.ff(x)

In [None]:
class Transformer_Block(nn.Module):
    """Pre-norm transformer block: self-attention, then feed-forward, each
    wrapped in a residual connection.

    Args:
        n_channels: Channel width of the block's input and output.
        n_head: Number of attention heads; each head receives
            n_channels // n_head channels.
    """

    def __init__(self, n_channels, n_head):
        super().__init__()
        per_head = n_channels // n_head  # channels handled by a single head
        self.sa = MultiHeadAttention(n_head, per_head)
        self.ffwd = FeedFoward(n_channels)
        self.norm1 = nn.LayerNorm(n_channels)
        self.norm2 = nn.LayerNorm(n_channels)

    def forward(self, x):
        """Normalise before each sub-layer and add its output back onto the input."""
        attended = x + self.sa(self.norm1(x))
        return attended + self.ffwd(self.norm2(attended))

In [None]:
class smallGPT(nn.Module):
    """Small decoder-only transformer language model.

    Token and learned position embeddings are summed, passed through n_layer
    transformer blocks and a final LayerNorm, and mapped to vocabulary logits.
    """

    def __init__(self):
        super().__init__()
        self.token_embedding= nn.Embedding(vocab_size, n_channels)
        self.position_embedding = nn.Embedding(block_size, n_channels)
        self.blocks = nn.Sequential(*[Transformer_Block(n_channels, n_head=n_head) for _ in range(n_layer)])
        self.norm = nn.LayerNorm(n_channels)
        self.lm_head = nn.Linear(n_channels, vocab_size)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear and Embedding weights from N(0, 0.02) and zero Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, inp_x, targets=None):
        """Compute next-token logits and, when targets are given, the loss.

        Args:
            inp_x: Integer token ids of shape (B, T); T must not exceed
                block_size because positions index the position embedding.
            targets: Optional integer token ids of shape (B, T).

        Returns:
            Tuple (logits, loss). Without targets, logits has shape
            (B, T, vocab_size) and loss is None. With targets, logits is
            flattened to (B*T, vocab_size) and loss is the cross-entropy.
        """
        B, T = inp_x.shape
        tok_emb = self.token_embedding(inp_x)
        # Build the position indices on the input's own device so the model
        # works wherever it and its inputs live, independent of any global.
        pos_emb = self.position_embedding(torch.arange(T, device=inp_x.device))
        x = tok_emb + pos_emb
        x = self.blocks(x)
        x = self.norm(x)
        logits = self.lm_head(x)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, new_tokens_no):
        """Autoregressively sample new tokens.

        Sampling runs without gradient tracking and in evaluation mode (so
        dropout is inactive); the previous training mode is restored afterwards.

        Args:
            idx: Integer token ids of shape (B, T) used as the initial context.
            new_tokens_no: Number of tokens to append.

        Returns:
            Tensor of shape (B, T + new_tokens_no) with the sampled tokens appended.
        """
        was_training = self.training
        self.eval()
        try:
            for _ in range(new_tokens_no):
                # Crop the context to the last block_size tokens the model can attend to
                idx_cond = idx[:, -block_size:]
                logits, _ = self(idx_cond)
                # Only the prediction for the last position is needed
                logits = logits[:, -1, :]
                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)
                idx = torch.cat((idx, idx_next), dim=1)
        finally:
            self.train(was_training)
        return idx

# Instantiate the network and move it to the selected device
m = smallGPT()
model = m.to(device)
# Print the torchinfo summary of the model
summary(model)

### Training the Model

In [None]:
@torch.no_grad()
def avg_loss():
    """Estimate the mean loss of the global `model` on both data splits.

    For each of 'train' and 'test', eval_epochs random batches are drawn and
    their losses averaged. The model is put in eval mode for the measurement
    and switched to train mode before returning.

    Returns:
        Dict mapping 'train' and 'test' to a 0-dim tensor holding the mean loss.
    """
    model.eval()
    mean_losses = {}
    for split in ('train', 'test'):
        split_losses = torch.zeros(eval_epochs)
        for step in range(eval_epochs):
            inputs, targets = random_batch(split)
            _, batch_loss = model(inputs, targets)
            split_losses[step] = batch_loss.item()
        mean_losses[split] = split_losses.mean()
    model.train()
    return mean_losses

In [None]:
def train(model, optimizer, epochs):
    """Train `model` on random training batches.

    Each iteration draws one batch with random_batch('train'), runs a forward
    and backward pass and takes one optimizer step. Every info_interval
    iterations, and on the last one, the losses returned by avg_loss() are
    printed.

    Args:
        model: Callable as model(X, Y) returning a (logits, loss) pair.
        optimizer: Optimizer managing the model's parameters.
        epochs: Number of training iterations (one batch each).

    Returns:
        Dict with key "train_loss" holding the per-iteration loss values
        as Python floats.
    """
    results = {"train_loss": []}

    for epoch in tqdm(range(epochs)):

        if epoch % info_interval == 0 or epoch == epochs - 1:
            losses = avg_loss()
            print(f"Step {epoch}: train loss {losses['train']:.4f}, test loss {losses['test']:.4f}")

        X, Y = random_batch('train')

        _, loss = model(X, Y)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        results["train_loss"].append(loss.item())

    # With epochs == 0 no step ran, so there is no final loss to report
    if results["train_loss"]:
        print(f"Final train loss: {results['train_loss'][-1]}")

    return results

In [None]:
# AdamW over all model parameters, using the learning rate from the hyperparameters cell
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

In [None]:
# Run the training loop; the result holds the per-iteration training losses
model_results = train(model, optimizer, epochs)

In [None]:
# Plot the recorded per-iteration training loss on a logarithmic y-axis
loss_frame = pd.DataFrame({'Train Loss': list(model_results['train_loss'])})
ax = loss_frame.plot(title='Train Loss Decrease', logy=True)

ax.set_xlabel("Epochs")
ax.set_ylabel("Loss")

In [None]:
# Start from a single token with id 0, sample 2000 further tokens and print the decoded text
context = torch.zeros((1, 1), dtype=torch.long, device=device)
generated = model.generate(context, new_tokens_no=2000)
print(decode(generated[0].tolist()))