In [1]:
import re
from tokenizers import WordTokenizer, CharTokenizer
import torch
from torch import nn
import json
import tqdm.notebook as tqdm
import time
import bisect
import random
from typing import *

In [2]:
with open('./corpus/TinyStoriesV2-GPT4-valid.txt', 'r') as file:
    lines = [l.strip() for l in file.read().split('<|endoftext|>')[:-1]]
    raw_text = '\n'.join(lines)

In [3]:
# Network definition
C_SEQ_LEN = 512
C_VOCAB_SIZE = 4096
C_HIDDEN_SIZE = 384
C_NUM_HEADS = 6
C_NUM_LAYERS = 6

C_DEVICE = torch.device('cuda')
C_DTYPE = torch.bfloat16

In [4]:
tokenizer = WordTokenizer(raw_text, vocab_size=C_VOCAB_SIZE, reserved_vocab=['<s>', '</s>', '<pad>'])

In [5]:
tokenizer.eval_vocab_coverage(raw_text)

0.9967516696403725

In [6]:
encoded_samples = []
for l in tqdm.tqdm(lines):
    encoded_samples += tokenizer.encode('<s>' + l + '</s>')

  0%|          | 0/27630 [00:00<?, ?it/s]

In [7]:
chunks = []
for i in range(0, len(encoded_samples), C_SEQ_LEN):
    chunks.append(encoded_samples[i:i + C_SEQ_LEN])
chunks.pop(-1)
all(len(c) == C_SEQ_LEN for c in chunks)

True

In [8]:
debug_seq = torch.tensor([tokenizer.encode(raw_text[:10000])[:128 * 8]]).view((-1, 128))
debug_seq.shape

torch.Size([8, 128])

In [9]:
# train_seq = torch.tensor(chunks)
# train_seq.shape
train_seq = debug_seq
train_seq.shape

torch.Size([8, 128])

In [10]:
math_tokenizer = CharTokenizer('0123456789+-*/=._$', 100)
print(len(math_tokenizer.get_vocab_mapping()))
samples = []
for i in range(32000 * 8):
    a = random.randint(1, 999)
    b = random.randint(1, 999)
    s = a + b
    t = f'${a}+{b}={str(s)[::-1]}'
    samples.append(t)
max_sample_len = len(max(samples, key=lambda i: len(i)))
samples = [s + '_' * (max_sample_len - len(s)) for s in samples]
masks = [[0] * s.find('=') + [1] * (max_sample_len - s.find('=')) for s in samples]
train_seq = torch.stack([torch.tensor(math_tokenizer.encode(s)) for s in samples])
train_mask = torch.stack([torch.tensor(m) for m in masks])
train_seq.shape

19


torch.Size([256000, 13])

In [11]:
math_tokenizer.decode(train_seq[0].tolist())

'$805+690=5941'

In [12]:
C_SEQ_LEN = 13
C_VOCAB_SIZE = 19

In [10]:
def expand_attn_mask(custom_attn_mask: torch.Tensor):
    B, T = custom_attn_mask.shape
    mask = custom_attn_mask.unsqueeze(1).repeat((1, T, 1))
    seq_index_mask = (mask == custom_attn_mask[:, torch.arange(T)].view(B, T, 1))
    return seq_index_mask & (torch.tril(mask) > 0)


class AttentionHead(nn.Module):
    def __init__(self, num_heads: int, hidden_size: int, dtype: torch.dtype = torch.float32):
        super().__init__()
        self.hidden_size = hidden_size
        self.dtype = dtype
        self.q_proj = nn.Linear(hidden_size, hidden_size // num_heads, dtype=dtype)
        self.k_proj = nn.Linear(hidden_size, hidden_size // num_heads, dtype=dtype)
        self.v_proj = nn.Linear(hidden_size, hidden_size // num_heads, dtype=dtype)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor],
                kv_cache: Optional[List[torch.Tensor]]) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        B, T, C = x.shape

        mask_zero = torch.tensor(0, dtype=self.dtype)
        mask_val = torch.tensor(torch.finfo(self.dtype).min / 2, dtype=self.dtype)
        if kv_cache is None and attn_mask is not None:
            causal_mask = torch.where(expand_attn_mask(attn_mask), mask_zero, mask_val)
        elif kv_cache is None:
            causal_mask = torch.where(expand_attn_mask(torch.ones(x.shape[:2])), mask_zero, mask_val)
        else:
            causal_mask = torch.zeros((B, T, T), dtype=self.dtype)

        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        if kv_cache is not None:
            k = torch.concat([kv_cache[0], k], dim=1)
            v = torch.concat([kv_cache[1], v], dim=1)

        attn_score = (q @ k.permute(0, 2, 1) / (self.hidden_size ** 0.5)) + causal_mask.to(q.device)

        return torch.softmax(attn_score, dim=2) @ v, [k, v]


class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads: int, hidden_size: int, dtype: torch.dtype = torch.float32):
        super().__init__()
        self.attn_heads = nn.ModuleList([AttentionHead(num_heads, hidden_size, dtype) for _ in range(num_heads)])
        self.o_proj = nn.Linear(hidden_size, hidden_size, dtype=dtype)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor],
                kv_cache: Optional[List[torch.Tensor]]) -> Tuple[torch.Tensor, List[List[torch.Tensor]]]:
        head_outputs = [head(x, attn_mask, kv_cache[idx] if kv_cache is not None else None) for idx, head in enumerate(self.attn_heads)]
        return self.o_proj(torch.concat([o[0] for o in head_outputs], dim=2)), [o[1] for o in head_outputs]


class DecoderLayer(nn.Module):
    def __init__(self, num_heads: int, hidden_size: int, dtype: torch.dtype = torch.float32):
        super().__init__()
        self.mha = MultiHeadAttention(num_heads, hidden_size, dtype)
        self.up_proj = nn.Linear(hidden_size, hidden_size * 4, dtype=dtype)
        self.down_proj = nn.Linear(hidden_size * 4, hidden_size, dtype=dtype)
        self.ln_mha = nn.LayerNorm(hidden_size, dtype=dtype)
        self.ln_ffn = nn.LayerNorm(hidden_size, dtype=dtype)
        self.act = nn.GELU()

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor],
                kv_cache: Optional[List[torch.Tensor]]) -> Tuple[torch.Tensor, List[List[torch.Tensor]]]:
        mha_output, new_kv_cache = self.mha(self.ln_mha(x), attn_mask, kv_cache)
        mha_output = x + mha_output
        ffn_output = self.down_proj(self.act(self.up_proj(self.ln_ffn(mha_output))))
        return mha_output + ffn_output, new_kv_cache


class ToyTransformer(nn.Module):
    def __init__(self, vocab_size: int, num_layers: int, num_heads: int, hidden_size: int, seq_len: int,
                 dtype: torch.dtype = torch.float32):
        super().__init__()
        self.sem_embed = nn.Embedding(vocab_size, hidden_size, dtype=dtype)
        self.pos_embed = nn.Embedding(seq_len, hidden_size, dtype=dtype)
        self.decoder_layers = nn.ModuleList([DecoderLayer(num_heads, hidden_size, dtype) for _ in range(num_layers)])
        self.lm_head = nn.Linear(hidden_size, vocab_size, dtype=dtype)

    def forward(self, seq: torch.Tensor,
                position_ids: Optional[torch.Tensor] = None,
                attn_mask: Optional[torch.Tensor] = None,
                kv_cache: Optional[List[torch.Tensor]] = None) -> Tuple[torch.Tensor, List[List[List[torch.Tensor]]]]:

        if position_ids is None:
            seq_len = seq.shape[1]
            pos_embed = self.pos_embed(torch.arange(0, seq_len, 1).to(self.device))
        else:
            pos_embed = self.pos_embed(position_ids)

        hidden = self.sem_embed(seq) + pos_embed
        new_kv_cache = []
        for idx, decoder in enumerate(self.decoder_layers):
            hidden, layer_kv_cache = decoder(hidden, attn_mask, kv_cache[idx] if kv_cache is not None else None)
            new_kv_cache.append(layer_kv_cache)

        return self.lm_head(hidden), new_kv_cache

    @property
    def device(self):
        return next(self.parameters()).device


t = ToyTransformer(4, 1, 1, 4, 4)
o1, kv1 = t.forward(torch.tensor([[0, 1]]))
print(o1)

o2, kv2 = t.forward(torch.tensor([[0]]))
o3, kv3 = t.forward(torch.tensor([[1]]), kv_cache=kv2)
print(o3)

tensor([[[ 0.3658, -1.3014,  0.5082,  0.6040],
         [-0.2335, -0.3141, -0.0363, -0.9091]]], grad_fn=<ViewBackward0>)
tensor([[[0.4663, 0.2188, 0.6028, 0.0415]]], grad_fn=<ViewBackward0>)


In [11]:
model = ToyTransformer(C_VOCAB_SIZE, C_NUM_LAYERS, C_NUM_HEADS, C_HIDDEN_SIZE, C_SEQ_LEN)
model = model.to(C_DEVICE)
print('Total parameters:', sum([t.numel() for t in model.parameters()]))

Total parameters: 13993216


In [12]:
model

ToyTransformer(
  (sem_embed): Embedding(4096, 384)
  (pos_embed): Embedding(512, 384)
  (decoder_layers): ModuleList(
    (0-5): 6 x DecoderLayer(
      (mha): MultiHeadAttention(
        (attn_heads): ModuleList(
          (0-5): 6 x AttentionHead(
            (q_proj): Linear(in_features=384, out_features=64, bias=True)
            (k_proj): Linear(in_features=384, out_features=64, bias=True)
            (v_proj): Linear(in_features=384, out_features=64, bias=True)
          )
        )
        (o_proj): Linear(in_features=384, out_features=384, bias=True)
      )
      (up_proj): Linear(in_features=384, out_features=1536, bias=True)
      (down_proj): Linear(in_features=1536, out_features=384, bias=True)
      (ln_mha): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      (ln_ffn): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      (act): GELU(approximate='none')
    )
  )
  (lm_head): Linear(in_features=384, out_features=4096, bias=True)
)

In [13]:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-3, total_steps=2000, final_div_factor=1e2)

In [16]:
C_BATCH_SIZE = 128
for epoch_num in range(100):
    for batch_i in list(range(0, len(train_seq), C_BATCH_SIZE)):
        step_start_time = time.time()

        inputs = train_seq[batch_i:batch_i + C_BATCH_SIZE, :-1]
        labels = train_seq[batch_i:batch_i + C_BATCH_SIZE, 1:]
        # masks = train_mask[batch_i:batch_i + C_BATCH_SIZE, 1:].to(model.device, C_DTYPE)

        logits, _ = model.forward(inputs.to(model.device))
        probs = torch.softmax(logits, dim=2)  # BSZ * SEQ * VOCAB
        probs_flat = probs.view(-1, C_VOCAB_SIZE)
        #  * masks.reshape(-1)
        loss = (-torch.log(probs_flat[torch.arange(probs_flat.shape[0]), labels.reshape(-1)])).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        step_time_cost = time.time() - step_start_time
        throughput = round((C_BATCH_SIZE * C_SEQ_LEN) / step_time_cost / 1000, 2)
        print(
            f'Epoch {epoch_num} Step {batch_i // C_BATCH_SIZE + 1} - Loss: {loss.item():.3f} LR: {scheduler.get_last_lr()[0]:.3} '
            f'Throughput: {throughput} kts')
        scheduler.step()

Epoch 0 Step 1 - Loss: 1.217 LR: 0.000105 Throughput: 2068.21 kts
Epoch 1 Step 1 - Loss: 1.185 LR: 0.000106 Throughput: 1674.69 kts
Epoch 2 Step 1 - Loss: 1.153 LR: 0.000107 Throughput: 2552.68 kts
Epoch 3 Step 1 - Loss: 1.122 LR: 0.000108 Throughput: 2494.42 kts
Epoch 4 Step 1 - Loss: 1.091 LR: 0.00011 Throughput: 2392.47 kts
Epoch 5 Step 1 - Loss: 1.060 LR: 0.000111 Throughput: 2196.12 kts
Epoch 6 Step 1 - Loss: 1.029 LR: 0.000112 Throughput: 2061.02 kts
Epoch 7 Step 1 - Loss: 0.999 LR: 0.000114 Throughput: 1808.82 kts
Epoch 8 Step 1 - Loss: 0.968 LR: 0.000115 Throughput: 1855.0 kts
Epoch 9 Step 1 - Loss: 0.938 LR: 0.000116 Throughput: 1600.29 kts
Epoch 10 Step 1 - Loss: 0.909 LR: 0.000118 Throughput: 2556.2 kts
Epoch 11 Step 1 - Loss: 0.880 LR: 0.000119 Throughput: 1597.31 kts
Epoch 12 Step 1 - Loss: 0.852 LR: 0.00012 Throughput: 2486.21 kts
Epoch 13 Step 1 - Loss: 0.825 LR: 0.000122 Throughput: 1799.86 kts
Epoch 14 Step 1 - Loss: 0.799 LR: 0.000123 Throughput: 1918.04 kts
Epoch 15 

In [19]:
def generate(tokenizer, prompt, temperature, top_p, rep_penalty, max_new_tokens=20, total_tokens=None):
    feed_tokens = tokenizer.encode(prompt)
    all_tokens = feed_tokens.copy()
    if total_tokens is not None:
        max_new_tokens = max(0, total_tokens - len(feed_tokens))

    kv_cache = None
    for _ in range(max_new_tokens):
        logits, kv_cache = model.forward(torch.tensor([feed_tokens]).to(C_DEVICE), 
                                         position_ids=None if kv_cache is None else torch.tensor([[len(all_tokens) - 1]]).to(C_DEVICE), 
                                         kv_cache=kv_cache)
        logits = logits[0][-1].cpu()

        # apply repetition penalty
        logits_rep = torch.gather(logits, 0, torch.tensor(all_tokens))
        logits_rep = torch.where(logits_rep < 0, logits_rep * rep_penalty, logits_rep / rep_penalty)
        logits.scatter_(0, torch.tensor(all_tokens), logits_rep)

        # apply temperature
        logits /= max(temperature, 1e-6)

        probs = torch.softmax(logits, dim=0)

        # apply top-p
        ordered_probs, ordered_indices = torch.sort(probs, descending=True)
        cum_probs = torch.cumsum(ordered_probs, dim=0).tolist()
        top_p_index = bisect.bisect_right(cum_probs, top_p) + 1
        ordered_probs, ordered_indices = ordered_probs[:top_p_index], ordered_indices[:top_p_index]
        sampled_index = ordered_indices[torch.multinomial(ordered_probs, num_samples=1).item()].item()

        all_tokens.append(sampled_index)
        feed_tokens = [sampled_index]
    # print(tokens)
    return tokenizer.decode(all_tokens)

In [20]:
a = time.time()
print(generate(tokenizer, '<unk>', 1.0, 0.01, 1.0, total_tokens=256))
print(time.time() - a)

<unk> don't have to be scared of the loud dog, I'll protect you". The mole felt so safe with the little girl. She was very kind and the mole soon came to trust her. He leaned against her and she kept him safe. The mole had found his best friend.
Once upon a time, in a warm and best. the she she and . a girl. her and and and and warm mole and the she she you mole mole in mole and and her a mole her her in the had and and mole and in and and to and in in warm and against she was her of she and and . in girl mole to her . mole 
1.241924524307251


In [153]:
generate(math_tokenizer, f"$111+11=", temperature=1.0, top_p=0.001, rep_penalty=1.0, total_tokens=13).rstrip('_')

6
6
6
6
6


KeyError: 83

In [191]:
correct_vs_total = [0, 0]
for a in range(100, 120):
    for b in range(100, 120):
        s = generate(math_tokenizer, f"${a}+{b}=", temperature=1.0, top_p=0.01, rep_penalty=1.0,
                     total_tokens=12).rstrip('_')
        r = int(s[s.rfind('=') + 1:][::-1])
        print(a, b, r)
        correct_vs_total[0] += (a + b) == r
        correct_vs_total[1] += 1

KeyError: 126

In [63]:
correct_vs_total

[400, 400]

In [200]:
tokenizer.decode(train_seq[0].tolist())

'<unk> don\'t have to be scared of the loud dog, I\'ll protect you". The mole felt so safe with the little girl. She was very kind and the mole soon came to trust her. He leaned against her and she kept him safe. The mole had found his best friend.\nOnce upon a time, in a warm and '

In [None]:
tokenizer.decode(debug_seq[1].tolist())

In [None]:
torch.tensor(torch.finfo(torch.bfloat16).min / 2, dtype=torch.bfloat16)