In [None]:
from tokenizers import WordTokenizer
import torch
from torch import nn
import json
import tqdm.notebook as tqdm

In [None]:
# Load the story chunks. The top-level JSON value is iterated and every item is
# stripped, so it is expected to be an array of strings; blank items are dropped.
# JSON text is UTF-8, so decode it explicitly instead of relying on the platform's
# default encoding (which is not UTF-8 everywhere).
with open('../../Jupyter/NLP/TinyStoriesV2-GPT4-valid-chunks.json', 'r', encoding='utf-8') as file:
    stripped_entries = (entry.strip() for entry in json.load(file))
    lines = [entry for entry in stripped_entries if entry]

# Whole corpus as one newline-separated string.
raw_text = '\n'.join(lines)

In [None]:
# Network definition
C_SEQ_LEN = 512      # tokens per training chunk; also the size of the position-embedding table
C_VOCAB_SIZE = 4096  # tokenizer vocabulary size; also embedding rows and LM-head outputs
C_HIDDEN_SIZE = 256  # width of the token/position embeddings and of every decoder layer
C_NUM_HEADS = 2      # attention heads per decoder layer
C_NUM_LAYERS = 3     # number of stacked decoder layers

# Each head is C_HIDDEN_SIZE // C_NUM_HEADS wide and the heads are concatenated
# before a C_HIDDEN_SIZE-wide output projection, so the division must be exact.
# Fail here rather than with a shape error deep inside the forward pass.
assert C_HIDDEN_SIZE % C_NUM_HEADS == 0, 'C_HIDDEN_SIZE must be divisible by C_NUM_HEADS'

In [None]:
# Build the tokenizer from the whole corpus.
# NOTE(review): presumably this fits a vocabulary of at most C_VOCAB_SIZE entries
# with the three reserved markers included - confirm against WordTokenizer.
tokenizer = WordTokenizer(
    raw_text,
    vocab_size=C_VOCAB_SIZE,
    reserved_vocab=['<s>', '</s>', '<pad>'],
)

In [None]:
# Check the tokenizer against the same text it was built from.
# NOTE(review): presumably reports how much of raw_text the vocabulary covers -
# confirm against WordTokenizer.eval_vocab_coverage.
tokenizer.eval_vocab_coverage(raw_text)

In [None]:
# Encode every story, wrapped in start/end markers, into one flat token stream.
encoded_samples = []
for story in tqdm.tqdm(lines):
    story_tokens = tokenizer.encode('<s>' + story + '</s>')
    encoded_samples.extend(story_tokens)

In [None]:
# Cut the token stream into non-overlapping windows of exactly C_SEQ_LEN tokens.
# Only whole windows are generated, so a short tail is dropped, but a final window
# that happens to be complete is kept and an empty stream simply yields no chunks
# (unconditionally popping the last chunk got both of those cases wrong).
num_full_chunks = len(encoded_samples) // C_SEQ_LEN
chunks = [
    encoded_samples[i * C_SEQ_LEN:(i + 1) * C_SEQ_LEN]
    for i in range(num_full_chunks)
]
# Displayed as the cell output: every chunk has the full length.
all(len(c) == C_SEQ_LEN for c in chunks)

In [None]:
# Small debugging batch: encode the first 10000 characters, keep at most 128 * 8
# tokens and lay them out as rows of 128 tokens.
debug_seq = torch.tensor(
    tokenizer.encode(raw_text[:10000])[:128 * 8]
).view(-1, 128)
debug_seq.shape

In [None]:
# Stack the token chunks into a single tensor, one chunk per row.
train_seq = torch.tensor(chunks)
# Displayed as the cell output.
train_seq.shape

In [None]:
class AttentionHead(nn.Module):
    """A single causal self-attention head.

    The input is projected to queries, keys and values of width
    ``hidden_size // num_heads``; each position then attends only to itself
    and to earlier positions.

    Args:
        num_heads: Number of heads in the surrounding multi-head layer; only
            used to derive this head's width.
        hidden_size: Size of the last dimension of the input.
    """

    def __init__(self, num_heads: int, hidden_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.head_dim = hidden_size // num_heads
        self.q_proj = nn.Linear(hidden_size, self.head_dim)
        self.k_proj = nn.Linear(hidden_size, self.head_dim)
        self.v_proj = nn.Linear(hidden_size, self.head_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply causal scaled dot-product attention.

        Args:
            x: Tensor of shape ``(batch, seq_len, hidden_size)``.

        Returns:
            Tensor of shape ``(batch, seq_len, hidden_size // num_heads)``.
        """
        seq_len = x.shape[1]

        q = self.q_proj(x)  # (batch, seq_len, head_dim)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # Scale by the square root of the key width (this head's dimension), not
        # the full model width, so the logits keep unit-order variance.
        attn_score = q @ k.transpose(1, 2) / (self.head_dim ** 0.5)

        # True strictly above the diagonal marks "future" positions. The mask is
        # created on the input's device so the module works wherever x lives.
        future = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device),
            diagonal=1,
        )
        # -inf gives exactly zero weight after softmax in any floating dtype; the
        # diagonal is never masked, so every row keeps at least one finite entry.
        attn_score = attn_score.masked_fill(future, float('-inf'))

        return torch.softmax(attn_score, dim=-1) @ v


class MultiHeadAttention(nn.Module):
    """Runs ``num_heads`` attention heads on the same input and mixes their outputs.

    The per-head outputs are concatenated along the last dimension and passed
    through a ``hidden_size -> hidden_size`` linear projection.
    """

    def __init__(self, num_heads: int, hidden_size: int):
        super().__init__()
        heads = [AttentionHead(num_heads, hidden_size) for _ in range(num_heads)]
        self.attn_heads = nn.ModuleList(heads)
        self.o_proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, x: torch.Tensor):
        per_head = [head(x) for head in self.attn_heads]
        merged = torch.cat(per_head, dim=2)
        return self.o_proj(merged)


class DecoderLayer(nn.Module):
    """One decoder block: multi-head attention followed by a ReLU feed-forward net.

    Each sub-layer's output is added to its input and the sum is layer-normalised
    (normalisation after the residual add). The feed-forward net expands to
    ``4 * hidden_size`` and projects back to ``hidden_size``.
    """

    def __init__(self, num_heads: int, hidden_size: int):
        super().__init__()
        ffn_size = hidden_size * 4
        self.mha = MultiHeadAttention(num_heads, hidden_size)
        self.up_proj = nn.Linear(hidden_size, ffn_size)
        self.down_proj = nn.Linear(ffn_size, hidden_size)
        self.ln_mha = nn.LayerNorm(hidden_size)
        self.ln_ffn = nn.LayerNorm(hidden_size)

    def forward(self, x: torch.Tensor):
        # Attention sub-layer with residual connection, then normalise.
        attended = self.ln_mha(x + self.mha(x))
        # Position-wise feed-forward sub-layer.
        expanded = self.up_proj(attended).relu()
        projected = self.down_proj(expanded)
        return self.ln_ffn(attended + projected)


class ToyTransformer(nn.Module):
    """Small decoder-only language model.

    Token embeddings plus learned absolute position embeddings are passed
    through ``C_NUM_LAYERS`` decoder layers and a linear head that produces one
    logit per vocabulary entry. All sizes come from the module-level ``C_*``
    constants.
    """

    def __init__(self):
        super().__init__()
        self.sem_embed = nn.Embedding(C_VOCAB_SIZE, C_HIDDEN_SIZE)
        self.pos_embed = nn.Embedding(C_SEQ_LEN, C_HIDDEN_SIZE)
        self.decoder_layers = nn.ModuleList([DecoderLayer(C_NUM_HEADS, C_HIDDEN_SIZE) for _ in range(C_NUM_LAYERS)])
        self.lm_head = nn.Linear(C_HIDDEN_SIZE, C_VOCAB_SIZE)

    def forward(self, seq):
        """Compute next-token logits for every position.

        Args:
            seq: Integer token ids, expected shape ``(batch, seq_len)`` with
                ``seq_len`` no larger than the position-embedding table.

        Returns:
            Logits of shape ``(batch, seq_len, C_VOCAB_SIZE)``.

        Raises:
            ValueError: If ``seq_len`` exceeds the number of position embeddings.
        """
        seq_len = seq.shape[1]
        max_len = self.pos_embed.num_embeddings
        if seq_len > max_len:
            # Fail with a readable message instead of an opaque embedding index error.
            raise ValueError(f'sequence length {seq_len} exceeds the maximum of {max_len}')

        # Build the position indices on the same device as the tokens so the
        # model also works after being moved off the CPU.
        positions = torch.arange(seq_len, device=seq.device)
        hidden = self.sem_embed(seq) + self.pos_embed(positions)
        for decoder in self.decoder_layers:
            hidden = decoder(hidden)
        return self.lm_head(hidden)

In [None]:
# Fix the RNG seed so the initial weights (and therefore the run) are
# reproducible every time this cell is executed.
torch.manual_seed(0)
model = ToyTransformer()
param_count = sum(p.numel() for p in model.parameters())
print('Total parameters:', param_count)
# Displayed as the cell output: the module tree.
model

In [None]:
# AdamW over every model parameter with a fixed learning rate of 3e-4.
optim = torch.optim.AdamW(model.parameters(), lr=3e-4)

In [None]:
C_BATCH_SIZE = 32
# Per-step loss values, kept so the training curve can be inspected afterwards.
loss_history = []
for epoch_num in range(1):
    progress = tqdm.tqdm(range(0, len(train_seq), C_BATCH_SIZE))
    for batch_i in progress:
        batch = train_seq[batch_i:batch_i + C_BATCH_SIZE]
        # Next-token prediction: position t of the input predicts token t + 1.
        inputs = batch[:, :-1]
        labels = batch[:, 1:]

        # Call the module rather than .forward() so registered hooks run.
        logits = model(inputs)  # BSZ * SEQ * VOCAB

        # cross_entropy works on logits via log-softmax, avoiding the log(0) -> inf/NaN
        # that an explicit softmax followed by log produces once a probability underflows.
        # The default 'mean' reduction averages over all BSZ * SEQ positions.
        loss = nn.functional.cross_entropy(
            logits.reshape(-1, logits.shape[-1]),
            labels.reshape(-1),
        )

        optim.zero_grad()
        loss.backward()
        optim.step()

        # Report the scalar loss on the progress bar instead of printing a tensor
        # repr on every step.
        loss_history.append(loss.item())
        progress.set_postfix(loss=loss_history[-1])

In [None]:
def generate(prompt, max_new_tokens=20):
    """Greedily extend ``prompt`` with the model's most likely next tokens.

    Args:
        prompt: Text to start from; it is encoded with the module-level ``tokenizer``.
        max_new_tokens: Number of tokens to append.

    Returns:
        The decoded text of the prompt tokens followed by the generated tokens.
        The token ids are also printed.

    Raises:
        ValueError: If the prompt encodes to no tokens.
    """
    tokens = list(tokenizer.encode(prompt))
    if not tokens:
        raise ValueError('prompt must encode to at least one token')

    # Feed the model on whatever device its weights live on.
    device = next(model.parameters()).device

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # The position table holds C_SEQ_LEN entries, so condition on at most
            # the most recent C_SEQ_LEN tokens.
            context = torch.tensor([tokens[-C_SEQ_LEN:]], dtype=torch.long, device=device)
            next_logits = model(context)[0, -1]
            # Softmax is monotonic, so the argmax of the logits is the greedy choice.
            tokens.append(torch.argmax(next_logits).item())

    print(tokens)
    return tokenizer.decode(tokens)


# Sample 100 tokens starting from the start-of-story marker alone.
generate("<s>", max_new_tokens=100)

In [None]:
# Decode the first training chunk back to text as a spot check of the
# encode -> chunk pipeline.
tokenizer.decode(train_seq[0].tolist())