In [24]:
"""
Build a small single-head attention language model from scratch.

For more details, see:
https://tree.rocks/make-language-model-from-scratch-like-mnist-5ed59aeb538d
"""
# !pip install torch numpy einops tqdm matplotlib scikit-learn

import torch
import torch.nn as nn
import numpy as np
import einops
import string
import re
from tqdm.auto import trange
from torch.utils.data import Dataset, DataLoader
from matplotlib import pyplot as plt
from sklearn.decomposition import PCA

In [None]:
# Print tensors in plain decimal notation instead of scientific notation.
torch.set_printoptions(sci_mode=False)

# Backend preference: CUDA first, then Apple MPS, otherwise fall back to the CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

print('device:', device)

In [None]:
# Vocabulary lookup tables; they start empty and are filled once the corpus
# has been tokenized.
word2id = dict()
id2word = dict()


def format_number(num):
    """Return `num` rendered with "," as the thousands separator."""
    return format(num, ",")

In [None]:
# Read the whole training corpus into memory.
# The encoding is passed explicitly so the text is decoded the same way on
# every platform instead of depending on the locale's default encoding.
# NOTE(review): assumes the corpus file is UTF-8 — adjust if it was saved differently.
with open('./data/short_animal_texts.txt', 'r', encoding='utf-8') as f:
    text_data = f.read()

total_characters = len(text_data)
print('total_characters:', format_number(total_characters))

In [None]:
# Three alternatives, tried in order at each position: a run of word
# characters, a single non-word/non-space character, or a run of whitespace.
_TOKEN_PATTERN = re.compile(r'\w+|[^\w\s]|[\s]+', re.UNICODE)


def regex_tokenizer(text):
    """Split `text` into word, punctuation and whitespace tokens.

    Whitespace runs are kept as tokens, so joining the returned list
    reproduces `text`.
    """
    return _TOKEN_PATTERN.findall(text)


print(regex_tokenizer("Hi, It's sunny day!"))

In [None]:
# Tokenize the full corpus once, then reduce it to its distinct tokens;
# that set is the vocabulary.
cleaned_words = regex_tokenizer(text_data)
unique_words = {*cleaned_words}
print('unique_words:', len(unique_words))

In [None]:
# Ids follow sorted token order, so the mapping is deterministic for a given
# vocabulary. Both lookup directions are filled from the same ordering.
sorted_unique_words = sorted(unique_words)
word2id.update((word, idx) for idx, word in enumerate(sorted_unique_words))
id2word.update(enumerate(sorted_unique_words))

In [None]:
def encode(text):
    """Tokenize `text` and map each token to its vocabulary id.

    Raises KeyError when a token is not a key of `word2id`.
    """
    return [word2id[token] for token in regex_tokenizer(text)]


def decode(token_ids):
    """Concatenate the tokens behind `token_ids` into one string."""
    return ''.join(id2word[token_id] for token_id in token_ids)


# Round trip: ids for a sample sentence, then the sentence rebuilt from them.
print(encode("Hi, It's sunny day!"))
print(decode(encode("Hi, It's sunny day!")))

In [None]:
CFG = dict(
    # Vocabulary size: number of rows in the token-embedding table.
    num_unique_words=len(unique_words),
    # Tokens per training window; also sizes the positional table and the causal mask.
    context_length=384,

    # Width of token/positional embeddings.
    emb_dim=128,
    # Width of the query/key/value projections of the attention head.
    head_dim=384,

    # Dropout probability shared by the input and attention dropout layers.
    drop_rate=0.15,

    # Divisor, not a step: windows start every context_length // stride tokens.
    stride=8,
    batch_size=32,
    # Learning rate handed to the optimizer.
    LR=0.0009,
)

In [None]:
class TextDataset(Dataset):
    """Sliding-window next-token dataset built from a single text.

    The text is encoded with `encode` and cut into windows of
    `cfg['context_length']` token ids. Each sample is a pair `(x, y)` where
    `y` is `x` shifted one token to the right, so both always have exactly
    `context_length` entries.

    Args:
        txt: Text to encode and window.
        cfg: Mapping providing 'context_length' and 'stride'. Note that
            'stride' is a divisor: consecutive windows start
            `context_length // stride` tokens apart (at least 1).
    """

    def __init__(self, txt, cfg):
        self.x = []
        self.y = []

        token_ids = encode(txt)
        c = cfg['context_length']
        # Guard against a zero step (stride > context_length), which range() rejects.
        step = max(1, c // cfg['stride'])
        # A window starting at i needs tokens i .. i + c (inclusive) so that the
        # shifted target is full length; hence the last valid start is
        # len(token_ids) - c - 1. Stopping one later would yield a target with
        # only c - 1 tokens, which cannot be stacked into a batch.
        for i in range(0, len(token_ids) - c, step):
            self.x.append(torch.tensor(token_ids[i:i + c]))
            self.y.append(torch.tensor(token_ids[i + 1:i + c + 1]))

    def __len__(self):
        """Number of (input, target) windows."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return the `idx`-th (input, target) pair of 1-D id tensors."""
        return self.x[idx], self.y[idx]

def create_dataloader(text):
    """Wrap `text` in a TextDataset (configured by CFG) and return a DataLoader.

    Batches have CFG['batch_size'] samples, are reshuffled on every pass, and
    a trailing incomplete batch is dropped.
    """
    return DataLoader(
        TextDataset(text, CFG),
        batch_size=CFG['batch_size'],
        shuffle=True,
        drop_last=True,
    )


train_loader = create_dataloader(text_data)

# Pull one batch to check the input/target tensor shapes.
x, y = next(iter(train_loader))
print(x.shape, y.shape)

In [None]:
# Show the first 20 tokens of one input window and of its target window;
# the target is the input shifted by one token.
x, y = next(iter(train_loader))
print(decode(x[0, :20].tolist()))
print(decode(y[0, :20].tolist()))

In [None]:
class Model(nn.Module):
    """Single-head causal self-attention language model with tied embeddings.

    Pipeline: token + learned positional embeddings -> dropout -> one masked
    attention head -> LayerNorm -> GELU -> linear projection back to
    `emb_dim` -> logits via the transposed token-embedding matrix.

    cfg keys read: 'num_unique_words', 'context_length', 'emb_dim',
    'head_dim', 'drop_rate'.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        vocab_size = cfg['num_unique_words']
        max_len = cfg['context_length']
        emb_dim = cfg['emb_dim']
        head_dim = cfg['head_dim']
        p_drop = cfg['drop_rate']

        # Token and position lookup tables.
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        self.pos_emb = nn.Embedding(max_len, emb_dim)

        # Query / key / value projections of the single attention head.
        self.w_q = nn.Linear(emb_dim, head_dim, bias=False)
        self.w_k = nn.Linear(emb_dim, head_dim, bias=False)
        self.w_v = nn.Linear(emb_dim, head_dim, bias=False)

        self.dropout_input = nn.Dropout(p_drop)
        self.dropout_attention = nn.Dropout(p_drop)

        self.norm = nn.LayerNorm(head_dim)
        self.output = nn.Linear(head_dim, emb_dim, bias=False)

        # True strictly above the diagonal: marks the "future" positions a
        # query is not allowed to attend to.
        future = torch.triu(torch.ones(max_len, max_len), diagonal=1).bool()
        self.register_buffer('mask', future)

    def forward(self, x_input):
        """Map token ids of shape (batch, n) to logits of shape (batch, n, num_unique_words)."""
        _, n = x_input.shape
        positions = torch.arange(n, device=x_input.device)
        hidden = self.dropout_input(self.embedding(x_input) + self.pos_emb(positions))

        queries = self.w_q(hidden)
        keys = self.w_k(hidden)
        values = self.w_v(hidden)

        # Scaled dot-product scores; future positions are set to -inf so the
        # softmax gives them zero weight.
        scale = self.cfg['head_dim'] ** 0.5
        scores = torch.matmul(queries, keys.transpose(-1, -2)) / scale
        scores = scores.masked_fill(self.mask[:n, :n], -torch.inf)

        weights = self.dropout_attention(torch.softmax(scores, dim=-1))

        attended = torch.matmul(weights, values)
        attended = nn.functional.gelu(self.norm(attended))

        # Project back to emb_dim, then reuse the token-embedding matrix as
        # the output layer (weight tying).
        projected = self.output(attended)
        return torch.matmul(projected, self.embedding.weight.T)

In [None]:
model = Model(CFG)
# Smoke test: a random batch of 5 sequences with 8 token ids each should come
# back as one row of logits per position.
print(model(torch.randint(0, CFG['num_unique_words'], size=(5, 8))).shape)

In [None]:
def count_trainable_parameters(model):
    """Return the total element count of `model`'s parameters that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Report how many trainable parameters the model has (label typo fixed).
print('Model parameters:', format_number(count_trainable_parameters(model)))

In [None]:
# Move the model to the target device *before* constructing the optimizer, as
# the torch.optim documentation recommends, so the optimizer is built over the
# parameters as they live on that device.
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=CFG['LR'], weight_decay=0.1)

# Keep the model as the cell's last expression so its architecture is displayed.
model

In [None]:
def show_embedding(sample_count=7):
    """Print and plot the nearest neighbours of a few probe words in embedding space.

    For each probe word, every vocabulary token is scored by the dot product of
    its embedding with the probe's embedding; the `sample_count` highest-scoring
    tokens are printed. The embeddings of all listed tokens are then reduced to
    2-D with PCA and drawn as a labelled scatter plot, probe words in bold.

    Reads the module-level `model`, `word2id` and `id2word`. Raises KeyError if
    a probe word is not in `word2id`.
    """
    tags = ['cat', 'tree', 'blue', 'Bob', 'jump', 'friendly']
    tags_i = [word2id[t] for t in tags]

    weights = model.embedding.weight.detach().cpu().numpy()

    def query(idx):
        """Print the top `sample_count` tokens for token `idx` and return their ids."""
        sel = weights[idx].reshape(1, weights.shape[1])
        similarities = (sel @ weights.T).squeeze()

        # Sort (score, token, id) tuples descending; ties fall back to the token text.
        ranked = sorted(
            ((s, id2word[i], i) for i, s in enumerate(similarities)),
            reverse=True,
        )[:sample_count]

        print(id2word[idx], '->')
        print(', '.join([f'{n}: {s:.3f}' for s, n, _ in ranked]))
        print('\n')
        return [i for _, _, i in ranked]

    # Ids of every token to plot, in query order (a token may appear more than once).
    arr = []
    for tag_id in tags_i:
        arr.extend(query(tag_id))

    reduced = PCA(n_components=2).fit_transform(weights[arr])

    plt.figure(figsize=(8, 8))
    plt.scatter(reduced[:, 0], reduced[:, 1], s=20, alpha=0.7)
    for (px, py), token_id in zip(reduced, arr):
        if token_id in tags_i:
            # Probe words stand out: larger and bold.
            attr = {'fontsize': 10, 'fontweight': 'bold'}
        else:
            attr = {'fontsize': 8, 'alpha': 0.6}
        plt.text(px, py, id2word[token_id], **attr)

    plt.grid(True)
    plt.show()


show_embedding()

In [None]:
def predict(text, max_len=50):
    """Greedily extend `text` by `max_len` tokens and print the result.

    At each step the model sees at most the last CFG['context_length'] tokens
    and the highest-scoring next token is appended. The full sequence (prompt
    plus generated tokens) is kept for the printed output even when it grows
    past the context length.

    The model is put in eval mode for generation and its previous
    training/eval mode is restored afterwards, also when an error occurs.

    Args:
        text: Prompt; every token must be in the vocabulary.
        max_len: Number of tokens to generate.

    Raises:
        KeyError: From `encode`, if the prompt contains an unknown token.
        ValueError: If the prompt encodes to no tokens.
    """
    prompt_ids = encode(text)
    if not prompt_ids:
        raise ValueError('predict() needs a non-empty prompt')

    generated = torch.tensor(prompt_ids, dtype=torch.long, device=device).unsqueeze(0)
    context_length = CFG['context_length']

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for _ in range(max_len):
                # Only the most recent tokens fit the positional table / mask.
                window = generated[:, -context_length:]
                logits = model(window)[:, -1, :]
                # argmax over logits equals argmax over softmax(logits), so the
                # softmax is not needed for greedy decoding.
                next_id = torch.argmax(logits, dim=-1, keepdim=True)
                generated = torch.cat([generated, next_id], dim=-1)
    finally:
        model.train(was_training)

    # Index the batch dimension explicitly; a bare squeeze() would collapse a
    # single-token sequence to a scalar.
    print(decode(generated[0].tolist()))


predict('In a sunny day')

In [None]:
def calc_loss(x, y_true):
    """Return the mean cross-entropy between the model's logits for `x` and `y_true`."""
    y = model(x)
    # Merge the batch and sequence dimensions so every position is one sample.
    return torch.nn.functional.cross_entropy(y.flatten(0, 1), y_true.flatten())


def evaluate(loader):
    """Return the mean loss over up to the first 30 batches drawn from `loader`.

    Runs in eval mode without gradient tracking, then restores the model's
    previous training/eval mode (also when an error occurs).

    Returns:
        A scalar tensor (call `.item()` for a Python float).

    Raises:
        ValueError: If `loader` yields no batches.
    """
    t = min(len(loader), 30)
    if t == 0:
        raise ValueError('evaluate() needs a loader that yields at least one batch')

    was_training = model.training
    model.eval()
    total_loss = 0.0
    try:
        with torch.no_grad():
            # zip stops after t batches without pulling an extra one from the loader.
            for _, (x, y_true) in zip(range(t), loader):
                total_loss += calc_loss(x.to(device), y_true.to(device))
    finally:
        model.train(was_training)

    return total_loss / t


evaluate(train_loader).item()

In [None]:
# Prompt used for the sample generations printed during training.
pred_text = 'In a sunny day'


def train(epochs=50):
    """Train the module-level `model` on `train_loader` for `epochs` epochs.

    After every epoch the loss reported by `evaluate` is printed, and on every
    fifth epoch (starting with the first) a sample continuation of `pred_text`
    is generated.
    """
    bar = trange(epochs)
    tlen = len(train_loader)
    for i in bar:
        model.train()
        for j, batch in enumerate(train_loader):
            x, y_true = (part.to(device) for part in batch)

            optimizer.zero_grad()
            loss = calc_loss(x, y_true)
            loss.backward()
            optimizer.step()

            bar.set_description(f'Epochs: {i+1}/{epochs}, Batch: {j+1}/{tlen}, loss: {loss.item():.5f}')

        # NOTE(review): this is measured on train_loader (no held-out split is
        # defined in this notebook), so it is not a true validation loss.
        val_loss = evaluate(train_loader).item()
        print(f'val loss: {val_loss:.5f}')

        if i % 5 != 0:
            continue

        print(f'predict {i+1} >>')
        predict(pred_text, max_len=50)
        print('\n')


train()

In [None]:
# Plot the embedding neighbourhoods again now that train() has run, for
# comparison with the plot produced before training.
show_embedding()

In [None]:
# Generate 100 tokens from a new prompt with the trained model (predict picks
# the arg-max token at every step).
predict('Once upon a time', max_len=100)