In [None]:
import re
from tokenizers import WordTokenizer, CharTokenizer
import torch
from torch import nn
import json
import tqdm.notebook as tqdm
import time
import bisect
import random
from typing import *

In [None]:
# Load the TinyStories validation split. Stories are separated by '<|endoftext|>';
# the [:-1] slice drops whatever follows the final separator.
# The file is decoded explicitly as UTF-8 so the result does not depend on the
# platform's default text encoding.
with open('./corpus/TinyStoriesV2-GPT4-valid.txt', 'r', encoding='utf-8') as file:
    lines = [story.strip() for story in file.read().split('<|endoftext|>')[:-1]]
    # Single string containing every story, one per line; used for tokenizer building.
    raw_text = '\n'.join(lines)

In [None]:
# Network definition
# C_SEQ_LEN: number of tokens per packed training row; also passed to ToyTransformer
#            as the size of its positional-embedding table.
# C_VOCAB_SIZE: vocabulary size requested from the tokenizer and width of the model's output layer.
# C_HIDDEN_SIZE: embedding / residual-stream width.
# C_NUM_HEADS: attention heads per layer; each head projects to C_HIDDEN_SIZE // C_NUM_HEADS channels.
# C_NUM_LAYERS: number of stacked decoder layers.
C_SEQ_LEN = 512
C_VOCAB_SIZE = 4096
C_HIDDEN_SIZE = 384
C_NUM_HEADS = 6
C_NUM_LAYERS = 6

# C_DEVICE: device the model is moved to and on which generation inputs are created.
# C_DTYPE: NOTE(review): defined here but not passed to ToyTransformer in the visible cells
#          (the model falls back to its float32 default) — confirm whether that is intended.
C_DEVICE = torch.device('cpu')
C_DTYPE = torch.float32

In [None]:
# Build the tokenizer from the full corpus text, requesting C_VOCAB_SIZE entries and
# passing '<s>', '</s>' and '<pad>' as reserved vocabulary items.
# NOTE(review): WordTokenizer comes from the imported `tokenizers` module, whose implementation
# is not visible here — presumably the vocabulary is fitted on raw_text; confirm.
tokenizer = WordTokenizer(raw_text, vocab_size=C_VOCAB_SIZE, reserved_vocab=['<s>', '</s>', '<pad>'])

In [None]:
# Evaluate the tokenizer's vocabulary coverage on the corpus; the cell output is whatever
# eval_vocab_coverage returns. NOTE(review): its return value/format is not visible here — confirm.
tokenizer.eval_vocab_coverage(raw_text)

In [None]:
# Pack the tokenized stories back-to-back into rows of exactly C_SEQ_LEN tokens.
# Three parallel row lists are built:
#   token_ids    - the token ids themselves
#   position_ids - each token's index inside its own story
#   attn_mask    - a per-row segment number (starting at 1) identifying which
#                  story chunk a token belongs to
token_ids, position_ids, attn_mask, loss_mask = [[]], [[]], [[]], None
mask_index = 1
for story in tqdm.tqdm(lines):
    sample_token_ids = tokenizer.encode('<s>' + story + '</s>')
    sample_position_ids = list(range(len(sample_token_ids)))
    cursor = 0
    # A story may straddle several rows, so it is consumed chunk by chunk.
    while cursor < len(sample_token_ids):
        free_slots = C_SEQ_LEN - len(token_ids[-1])
        tokens_left = len(sample_token_ids) - cursor
        length = min(free_slots, tokens_left)
        chunk_end = cursor + length

        token_ids[-1].extend(sample_token_ids[cursor:chunk_end])
        position_ids[-1].extend(sample_position_ids[cursor:chunk_end])
        attn_mask[-1].extend([mask_index] * length)

        cursor = chunk_end
        mask_index += 1
        if len(token_ids[-1]) == C_SEQ_LEN:
            # Current row is full: open a new one and restart the segment numbering.
            token_ids.append([])
            position_ids.append([])
            attn_mask.append([])
            mask_index = 1

# The last row is either empty or only partially filled, so it is discarded.
token_ids = torch.tensor(token_ids[:-1])
position_ids = torch.tensor(position_ids[:-1])
attn_mask = torch.tensor(attn_mask[:-1])

In [None]:
# Small debugging batch: tokenize the first 10000 characters of the corpus, keep at most
# 128 * 8 tokens and fold them into rows of 128.
debug_tokens = tokenizer.encode(raw_text[:10000])[:128 * 8]
debug_seq = torch.tensor([debug_tokens]).view((-1, 128))
debug_seq.shape

In [None]:
def expand_attn_mask(custom_attn_mask: torch.Tensor):
    """Turn per-token segment ids into a boolean attention-permission matrix.

    Args:
        custom_attn_mask: tensor of shape (B, T) holding a segment id per token.
            Ids that are not greater than zero are treated as non-attendable.

    Returns:
        Boolean tensor of shape (B, T, T) where entry [b, i, j] is True when
        query i may attend to key j, i.e. when j is not after i, both tokens
        carry the same segment id, and that id is greater than zero.
    """
    _, T = custom_attn_mask.shape
    steps = torch.arange(T, device=custom_attn_mask.device)
    not_in_future = steps.unsqueeze(0) <= steps.unsqueeze(1)  # (T, T): key index <= query index

    key_ids = custom_attn_mask.unsqueeze(1)    # (B, 1, T)
    query_ids = custom_attn_mask.unsqueeze(2)  # (B, T, 1)
    same_segment = key_ids == query_ids        # broadcasts to (B, T, T)

    return same_segment & not_in_future & (key_ids > 0)


class AttentionHead(nn.Module):
    """A single causal self-attention head with optional segment masking and KV cache.

    Each head projects the hidden state down to ``hidden_size // num_heads``
    channels for its queries, keys and values.

    Args:
        num_heads: total number of heads in the parent attention layer (used only
            to derive this head's width).
        hidden_size: width of the incoming hidden state.
        dtype: parameter dtype for the projections.
    """

    def __init__(self, num_heads: int, hidden_size: int, dtype: torch.dtype = torch.float32):
        super().__init__()
        self.hidden_size = hidden_size
        self.head_size = hidden_size // num_heads
        self.dtype = dtype
        self.q_proj = nn.Linear(hidden_size, self.head_size, dtype=dtype)
        self.k_proj = nn.Linear(hidden_size, self.head_size, dtype=dtype)
        self.v_proj = nn.Linear(hidden_size, self.head_size, dtype=dtype)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor],
                kv_cache: Optional[List[torch.Tensor]]) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Attend over ``x`` (plus any cached keys/values).

        Args:
            x: hidden states of shape (B, T, hidden_size).
            attn_mask: optional (B, T) segment ids as accepted by ``expand_attn_mask``;
                only used when no cache is supplied.
            kv_cache: optional ``[k, v]`` pair from a previous call, each of shape
                (B, S_past, head_size). The new keys/values are appended to it.

        Returns:
            ``(output, [k, v])`` where ``output`` has shape (B, T, head_size) and
            ``k``/``v`` hold the keys/values for all positions seen so far.
        """
        _, T, _ = x.shape

        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        if kv_cache is not None:
            k = torch.concat([kv_cache[0], k], dim=1)
            v = torch.concat([kv_cache[1], v], dim=1)
        S = k.shape[1]  # total number of key positions (cached + new)

        if kv_cache is None and attn_mask is not None:
            allowed = expand_attn_mask(attn_mask.to(q.device))
        else:
            # Plain causal mask. The T new queries occupy absolute positions
            # S - T .. S - 1, so query i may see every key up to position S - T + i.
            # This also covers feeding several tokens at once on top of a cache.
            query_pos = torch.arange(S - T, S, device=q.device).unsqueeze(1)
            key_pos = torch.arange(S, device=q.device).unsqueeze(0)
            allowed = (key_pos <= query_pos).unsqueeze(0)  # (1, T, S)

        # Additive bias built on the activations' device/dtype: 0 where attention is
        # allowed, a large negative value (half of the dtype's minimum, leaving
        # headroom against overflow when added to the scores) elsewhere.
        mask_zero = torch.tensor(0, dtype=q.dtype, device=q.device)
        mask_val = torch.tensor(torch.finfo(q.dtype).min / 2, dtype=q.dtype, device=q.device)
        attn_bias = torch.where(allowed, mask_zero, mask_val)

        # Scaled dot-product attention: scale by sqrt of the key dimension (the
        # head width), not the full model width.
        attn_score = (q @ k.permute(0, 2, 1) / (self.head_size ** 0.5)) + attn_bias

        return torch.softmax(attn_score, dim=2) @ v, [k, v]


class MultiHeadAttention(nn.Module):
    """Runs ``num_heads`` independent attention heads and mixes their outputs.

    The per-head outputs are concatenated along the channel axis and passed
    through a final ``hidden_size -> hidden_size`` linear projection.
    """

    def __init__(self, num_heads: int, hidden_size: int, dtype: torch.dtype = torch.float32):
        super().__init__()
        heads = [AttentionHead(num_heads, hidden_size, dtype) for _ in range(num_heads)]
        self.attn_heads = nn.ModuleList(heads)
        self.o_proj = nn.Linear(hidden_size, hidden_size, dtype=dtype)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor],
                kv_cache: Optional[List[torch.Tensor]]) -> Tuple[torch.Tensor, List[List[torch.Tensor]]]:
        """Apply every head to ``x``.

        Args:
            x: hidden states, forwarded unchanged to each head.
            attn_mask: optional mask, forwarded unchanged to each head.
            kv_cache: optional list with one cache entry per head (indexed in
                head order), or None to run without a cache.

        Returns:
            The projected concatenation of all head outputs, and a list holding
            each head's updated cache in head order.
        """
        head_contexts = []
        updated_cache = []
        for head_idx, head in enumerate(self.attn_heads):
            head_cache = None if kv_cache is None else kv_cache[head_idx]
            context, head_kv = head(x, attn_mask, head_cache)
            head_contexts.append(context)
            updated_cache.append(head_kv)

        merged = torch.concat(head_contexts, dim=2)
        return self.o_proj(merged), updated_cache


class DecoderLayer(nn.Module):
    """One transformer decoder block: attention and feed-forward sub-layers.

    Layer normalisation is applied to the input of each sub-layer and each
    sub-layer's output is added back to its un-normalised input. The
    feed-forward part expands to four times the hidden width with a GELU in
    between.
    """

    def __init__(self, num_heads: int, hidden_size: int, dtype: torch.dtype = torch.float32):
        super().__init__()
        ffn_size = hidden_size * 4
        self.mha = MultiHeadAttention(num_heads, hidden_size, dtype)
        self.up_proj = nn.Linear(hidden_size, ffn_size, dtype=dtype)
        self.down_proj = nn.Linear(ffn_size, hidden_size, dtype=dtype)
        self.ln_mha = nn.LayerNorm(hidden_size, dtype=dtype)
        self.ln_ffn = nn.LayerNorm(hidden_size, dtype=dtype)
        self.act = nn.GELU()

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor],
                kv_cache: Optional[List[torch.Tensor]]) -> Tuple[torch.Tensor, List[List[torch.Tensor]]]:
        """Run the block on ``x``.

        Args:
            x: hidden states entering the block.
            attn_mask: optional mask handed to the attention sub-layer.
            kv_cache: optional cache handed to the attention sub-layer.

        Returns:
            The block output and the cache returned by the attention sub-layer.
        """
        # Attention sub-layer with residual connection.
        attn_out, new_kv_cache = self.mha(self.ln_mha(x), attn_mask, kv_cache)
        after_attn = x + attn_out

        # Feed-forward sub-layer with residual connection.
        expanded = self.up_proj(self.ln_ffn(after_attn))
        ffn_out = self.down_proj(self.act(expanded))
        return after_attn + ffn_out, new_kv_cache


class ToyTransformer(nn.Module):
    """Small decoder-only transformer language model.

    Token and learned absolute-position embeddings are summed, passed through
    ``num_layers`` decoder layers, and projected to vocabulary logits.

    Args:
        vocab_size: number of token ids / output logits.
        num_layers: number of decoder layers.
        num_heads: attention heads per layer.
        hidden_size: embedding and hidden-state width.
        seq_len: size of the positional-embedding table, i.e. the largest
            position index plus one that the model can embed.
        dtype: parameter dtype.
    """

    def __init__(self, vocab_size: int, num_layers: int, num_heads: int, hidden_size: int, seq_len: int,
                 dtype: torch.dtype = torch.float32):
        super().__init__()
        self.sem_embed = nn.Embedding(vocab_size, hidden_size, dtype=dtype)
        self.pos_embed = nn.Embedding(seq_len, hidden_size, dtype=dtype)
        self.decoder_layers = nn.ModuleList([DecoderLayer(num_heads, hidden_size, dtype) for _ in range(num_layers)])
        self.lm_head = nn.Linear(hidden_size, vocab_size, dtype=dtype)

    def forward(self, seq: torch.Tensor,
                position_ids: Optional[torch.Tensor] = None,
                attn_mask: Optional[torch.Tensor] = None,
                kv_cache: Optional[List[torch.Tensor]] = None) -> Tuple[torch.Tensor, List[List[List[torch.Tensor]]]]:
        """Compute next-token logits.

        Args:
            seq: token ids of shape (B, T).
            position_ids: optional position index per token. When omitted,
                positions are consecutive and start right after any cached
                tokens (or at 0 when there is no cache).
            attn_mask: optional mask forwarded to every decoder layer.
            kv_cache: optional cache from a previous call, indexed as
                ``kv_cache[layer][head] == [k, v]``.

        Returns:
            ``(logits, new_kv_cache)`` where ``logits`` has shape
            (B, T, vocab_size) and ``new_kv_cache`` has one entry per layer.
        """
        if position_ids is None:
            # Continue counting after the cached prefix so that incremental
            # decoding sees the same positions as a single full pass would.
            past_len = self._cached_length(kv_cache)
            position_ids = torch.arange(past_len, past_len + seq.shape[1], device=self.device)
        else:
            position_ids = position_ids.to(self.device)

        hidden = self.sem_embed(seq) + self.pos_embed(position_ids)
        new_kv_cache = []
        for idx, decoder in enumerate(self.decoder_layers):
            hidden, layer_kv_cache = decoder(hidden, attn_mask, kv_cache[idx] if kv_cache is not None else None)
            new_kv_cache.append(layer_kv_cache)

        return self.lm_head(hidden), new_kv_cache

    @staticmethod
    def _cached_length(kv_cache) -> int:
        """Number of positions already stored in ``kv_cache`` (0 if absent or empty)."""
        if not kv_cache:
            return 0
        # kv_cache[layer][head][0] is the cached key tensor; dim 1 is the position axis.
        return kv_cache[0][0][0].shape[1]

    @property
    def device(self):
        """Device of the model's first parameter."""
        return next(self.parameters()).device

In [None]:
# Instantiate the network from the configuration constants, place it on C_DEVICE
# and report how many parameters it has.
model = ToyTransformer(
    vocab_size=C_VOCAB_SIZE,
    num_layers=C_NUM_LAYERS,
    num_heads=C_NUM_HEADS,
    hidden_size=C_HIDDEN_SIZE,
    seq_len=C_SEQ_LEN,
).to(C_DEVICE)
print('Total parameters:', sum(t.numel() for t in model.parameters()))

In [None]:
# Show the model's module tree (its repr) as the cell output.
model

In [None]:
# AdamW drives the updates; the one-cycle schedule below controls the learning rate
# over a fixed budget of 2000 scheduler steps with a peak of 1e-3.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=1e-3,
    total_steps=2000,
    final_div_factor=1e2,
)

In [None]:
# Training loop: a single pass over the packed rows with next-token prediction.
# Each row contributes its first T-1 tokens as inputs and its last T-1 tokens as labels.
C_BATCH_SIZE = 32
for epoch_num in range(1):
    for batch_i in tqdm.tqdm(range(0, len(token_ids), C_BATCH_SIZE)):
        step_start_time = time.time()

        batch_tokens = token_ids[batch_i:batch_i + C_BATCH_SIZE]
        # Inputs and labels both live on the model's device so the loss can be computed there.
        inputs = batch_tokens[:, :-1].to(model.device)
        labels = batch_tokens[:, 1:].to(model.device)
        # Segment ids aligned with the inputs; they keep packed stories from attending to each other.
        masks = attn_mask[batch_i:batch_i + C_BATCH_SIZE, :-1]

        logits, _ = model(inputs, attn_mask=masks)  # BSZ * SEQ * VOCAB
        # cross_entropy fuses log-softmax and negative log-likelihood, which avoids the
        # log(0) = -inf that a separate softmax followed by log can produce.
        # The loss is averaged over every position; no loss mask is applied.
        loss = nn.functional.cross_entropy(logits.reshape(-1, logits.shape[-1]), labels.reshape(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        step_time_cost = time.time() - step_start_time
        # Count the tokens actually processed, so the short final batch is reported correctly.
        throughput = round(inputs.numel() / step_time_cost / 1000, 2)
        print(
            f'Epoch {epoch_num} Step {batch_i // C_BATCH_SIZE + 1} - Loss: {loss.item():.3f} LR: {scheduler.get_last_lr()[0]:.3} '
            f'Throughput: {throughput} kts')
        # NOTE(review): the scheduler was created with a fixed total_steps budget; confirm the
        # number of batches per run stays within it.
        scheduler.step()

In [None]:
def generate(tokenizer, prompt, temperature, top_p, rep_penalty, max_new_tokens=20, total_tokens=None):
    """Sample a continuation of ``prompt`` from the global ``model``.

    Decoding is incremental: the whole prompt is fed once, afterwards only the
    most recently sampled token is fed together with the key/value cache.

    Args:
        tokenizer: object providing ``encode(str) -> list of ids`` and
            ``decode(list of ids) -> str``.
        prompt: text to continue.
        temperature: logits are divided by this value (floored at 1e-6).
        top_p: nucleus threshold; sampling is restricted to the most likely
            tokens, cut just after the cumulative probability passes ``top_p``.
        rep_penalty: penalty applied to the logits of every token already in
            the sequence (positive logits are divided by it, negative ones
            multiplied); 1.0 disables it.
        max_new_tokens: number of tokens to sample.
        total_tokens: if given, overrides ``max_new_tokens`` so that prompt plus
            continuation reach this many tokens.

    Returns:
        The decoded prompt plus continuation. Generation stops early once the
        next input position would fall outside the model's positional-embedding
        table.
    """
    feed_tokens = tokenizer.encode(prompt)
    all_tokens = feed_tokens.copy()
    if total_tokens is not None:
        max_new_tokens = max(0, total_tokens - len(feed_tokens))

    # Largest number of positions the model can embed.
    max_positions = model.pos_embed.num_embeddings

    kv_cache = None
    # Inference only: disabling autograd keeps the growing cache from retaining a graph.
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # The last token to be fed sits at position len(all_tokens) - 1.
            if len(all_tokens) > max_positions:
                break

            position_ids = None if kv_cache is None else torch.tensor([[len(all_tokens) - 1]]).to(C_DEVICE)
            logits, kv_cache = model(torch.tensor([feed_tokens]).to(C_DEVICE),
                                     position_ids=position_ids,
                                     kv_cache=kv_cache)
            logits = logits[0][-1].cpu()

            # apply repetition penalty
            seen_tokens = torch.tensor(all_tokens)
            logits_rep = torch.gather(logits, 0, seen_tokens)
            logits_rep = torch.where(logits_rep < 0, logits_rep * rep_penalty, logits_rep / rep_penalty)
            logits.scatter_(0, seen_tokens, logits_rep)

            # apply temperature
            logits /= max(temperature, 1e-6)

            probs = torch.softmax(logits, dim=0)

            # apply top-p: keep the shortest prefix of the sorted distribution whose
            # cumulative probability exceeds top_p (always at least one token)
            ordered_probs, ordered_indices = torch.sort(probs, descending=True)
            cum_probs = torch.cumsum(ordered_probs, dim=0).tolist()
            top_p_index = bisect.bisect_right(cum_probs, top_p) + 1
            ordered_probs, ordered_indices = ordered_probs[:top_p_index], ordered_indices[:top_p_index]
            # multinomial accepts unnormalised weights, so no renormalisation is needed
            sampled_index = ordered_indices[torch.multinomial(ordered_probs, num_samples=1).item()].item()

            all_tokens.append(sampled_index)
            feed_tokens = [sampled_index]
    return tokenizer.decode(all_tokens)

In [None]:
# Generate 256 tokens in total from the prompt 'Mommy' and report the wall-clock time.
# top_p=0.01 restricts sampling to the single most likely token unless its probability is at most 0.01.
generation_start = time.time()
sample_text = generate(tokenizer, 'Mommy', temperature=1.0, top_p=0.01, rep_penalty=1.0, total_tokens=256)
print(sample_text)
print(time.time() - generation_start)

In [None]:
# Decode the last row of the debug batch back to text to inspect what the tokenizer produces.
tokenizer.decode(debug_seq[-1].tolist())

In [None]:
# Scratch check: half of bfloat16's most negative finite value, stored as a bfloat16 scalar.
# This is the same form of constant that AttentionHead.forward uses as its additive mask value.
torch.tensor(torch.finfo(torch.bfloat16).min / 2, dtype=torch.bfloat16)