In [1]:
from tokenizers import WordTokenizer, CharTokenizer
import torch
from torch import nn
import json
import tqdm.notebook as tqdm
import time
import bisect
import random
from typing import *

In [2]:
with open('../../Jupyter/NLP/TinyStoriesV2-GPT4-valid-chunks.json', 'r') as file:
    lines = [l.strip() for l in json.load(file) if l.strip() != '']
    raw_text = '\n'.join(lines)

In [3]:
# Network definition
C_SEQ_LEN = 512
C_VOCAB_SIZE = 4096
C_HIDDEN_SIZE = 256
C_NUM_HEADS = 2
C_NUM_LAYERS = 3

In [4]:
tokenizer = WordTokenizer(raw_text, vocab_size=C_VOCAB_SIZE, reserved_vocab=['<s>', '</s>', '<pad>'])

In [5]:
tokenizer.eval_vocab_coverage(raw_text)

0.9967516693080476

In [6]:
encoded_samples = []
for l in tqdm.tqdm(lines):
    encoded_samples += tokenizer.encode('<s>' + l + '</s>')

  0%|          | 0/27629 [00:00<?, ?it/s]

In [7]:
chunks = []
for i in range(0, len(encoded_samples), C_SEQ_LEN):
    chunks.append(encoded_samples[i:i + C_SEQ_LEN])
chunks.pop(-1)
all(len(c) == C_SEQ_LEN for c in chunks)

True

In [8]:
debug_seq = torch.tensor([tokenizer.encode(raw_text[:10000])[:128 * 8]]).view((-1, 128))
debug_seq.shape

torch.Size([8, 128])

In [9]:
# train_seq = torch.tensor(chunks)
# train_seq.shape
train_seq = debug_seq
train_seq.shape

torch.Size([8, 128])

In [None]:
math_tokenizer = CharTokenizer('0123456789+-*/=._', 100)
print(len(math_tokenizer.get_vocab_mapping()))
samples = []
for i in range(32000 * 8):
    a = random.randint(1, 999)
    b = random.randint(1, 999)
    s = a + b
    t = f'{a}+{b}={s}'
    samples.append(t)
max_sample_len = len(max(samples, key=lambda i: len(i)))
samples = [s + '_' * (max_sample_len - len(s)) for s in samples]
train_seq = torch.stack([torch.tensor(math_tokenizer.encode(s)) for s in samples])
train_seq.shape

In [None]:
math_tokenizer.decode(train_seq[0].tolist())

In [None]:
C_SEQ_LEN = 12
C_VOCAB_SIZE = 18

In [21]:
def expand_attn_mask(custom_attn_mask: torch.Tensor):
    B, T = custom_attn_mask.shape
    mask = custom_attn_mask.unsqueeze(1).repeat((1, T, 1))
    seq_index_mask = (mask == custom_attn_mask[:, torch.arange(T)].view(B, T, 1))
    return seq_index_mask & (torch.tril(mask) > 0)


class AttentionHead(nn.Module):
    def __init__(self, num_heads: int, hidden_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.q_proj = nn.Linear(hidden_size, hidden_size // num_heads)
        self.k_proj = nn.Linear(hidden_size, hidden_size // num_heads)
        self.v_proj = nn.Linear(hidden_size, hidden_size // num_heads)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor]):
        if attn_mask is not None:
            causal_mask = torch.where(expand_attn_mask(attn_mask), 0, 1e9)
        else:
            causal_mask = torch.where(expand_attn_mask(torch.ones(x.shape[:2])), 0, 1e9)

        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        attn_score = (q @ k.permute(0, 2, 1) / (self.hidden_size ** 0.5)) - causal_mask

        return torch.softmax(attn_score, dim=2) @ v


class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads: int, hidden_size: int):
        super().__init__()
        self.attn_heads = nn.ModuleList([AttentionHead(num_heads, hidden_size) for _ in range(num_heads)])
        self.o_proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor]):
        return self.o_proj(torch.concat([a(x, attn_mask) for a in self.attn_heads], dim=2))


class DecoderLayer(nn.Module):
    def __init__(self, num_heads: int, hidden_size: int):
        super().__init__()
        self.mha = MultiHeadAttention(num_heads, hidden_size)
        self.up_proj = nn.Linear(hidden_size, hidden_size * 4)
        self.down_proj = nn.Linear(hidden_size * 4, hidden_size)
        self.ln_mha = nn.LayerNorm(hidden_size)
        self.ln_ffn = nn.LayerNorm(hidden_size)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor]):
        mha_output = self.ln_mha(x + self.mha(x, attn_mask))
        ffn_output = self.down_proj(torch.relu(self.up_proj(mha_output)))
        return self.ln_ffn(mha_output + ffn_output)


class ToyTransformer(nn.Module):
    def __init__(self, vocab_size: int, num_layers: int, num_heads: int, hidden_size: int, seq_len: int, ):
        super().__init__()
        self.sem_embed = nn.Embedding(vocab_size, hidden_size)
        self.pos_embed = nn.Embedding(seq_len, hidden_size)
        self.decoder_layers = nn.ModuleList([DecoderLayer(num_heads, hidden_size) for _ in range(num_layers)])
        self.lm_head = nn.Linear(hidden_size, vocab_size)

    def forward(self, seq: torch.Tensor,
                position_ids: Optional[torch.Tensor] = None,
                attn_mask: Optional[torch.Tensor] = None):

        if position_ids is None:
            seq_len = seq.shape[1]
            pos_embed = self.pos_embed(torch.arange(0, seq_len, 1))
        else:
            pos_embed = self.pos_embed(position_ids)

        hidden = self.sem_embed(seq) + pos_embed
        for decoder in self.decoder_layers:
            hidden = decoder(hidden, attn_mask)
        logits = self.lm_head(hidden)
        return logits

In [72]:
model = ToyTransformer(C_VOCAB_SIZE, C_NUM_LAYERS, C_NUM_HEADS, C_HIDDEN_SIZE, C_SEQ_LEN * 8)
print('Total parameters:', sum([t.numel() for t in model.parameters()]))
model

Total parameters: 5519104


ToyTransformer(
  (sem_embed): Embedding(4096, 256)
  (pos_embed): Embedding(4096, 256)
  (decoder_layers): ModuleList(
    (0-2): 3 x DecoderLayer(
      (mha): MultiHeadAttention(
        (attn_heads): ModuleList(
          (0-1): 2 x AttentionHead(
            (q_proj): Linear(in_features=256, out_features=128, bias=True)
            (k_proj): Linear(in_features=256, out_features=128, bias=True)
            (v_proj): Linear(in_features=256, out_features=128, bias=True)
          )
        )
        (o_proj): Linear(in_features=256, out_features=256, bias=True)
      )
      (up_proj): Linear(in_features=256, out_features=1024, bias=True)
      (down_proj): Linear(in_features=1024, out_features=256, bias=True)
      (ln_mha): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (ln_ffn): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    )
  )
  (lm_head): Linear(in_features=256, out_features=4096, bias=True)
)

In [73]:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
#scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-3, total_steps=2000, final_div_factor=1e2)

In [74]:
train_seq_long = train_seq.view(1, -1)
attn_mask_long = torch.concat([torch.tensor([i] * 128) for i in range(1, 9)]).view(1, -1)
pos_ids_long = torch.concat([torch.tensor(list(range(128))) for _ in range(1, 9)]).view(1, -1)

In [75]:
C_BATCH_SIZE = 128
for epoch_num in range(100):
    for batch_i in list(range(0, len(train_seq), C_BATCH_SIZE)):
        step_start_time = time.time()

        inputs = train_seq[batch_i:batch_i + C_BATCH_SIZE, :-1]
        labels = train_seq[batch_i:batch_i + C_BATCH_SIZE, 1:]
        logits = model.forward(inputs, pos_ids_long[:, :-1], attn_mask_long[:, :-1])
        probs = torch.softmax(logits, dim=2)  # BSZ * SEQ * VOCAB
        probs_flat = probs.view(-1, C_VOCAB_SIZE)
        loss = (-torch.log(probs_flat[torch.arange(probs_flat.shape[0]), labels.reshape(-1)])).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        step_time_cost = time.time() - step_start_time
        throughput = round((C_BATCH_SIZE * C_SEQ_LEN) / step_time_cost / 1000, 2)
        print(
            f'Epoch {epoch_num} Step {batch_i // C_BATCH_SIZE + 1} - Loss: {loss.item():.3f} LR: {0.0:.3} '
            f'Throughput: {throughput} kts')
        # scheduler.step()

Epoch 0 Step 1 - Loss: 8.510 LR: 0.0 Throughput: 577.69 kts
Epoch 1 Step 1 - Loss: 5.773 LR: 0.0 Throughput: 687.28 kts
Epoch 2 Step 1 - Loss: 5.000 LR: 0.0 Throughput: 696.53 kts
Epoch 3 Step 1 - Loss: 4.682 LR: 0.0 Throughput: 462.86 kts
Epoch 4 Step 1 - Loss: 4.385 LR: 0.0 Throughput: 651.24 kts
Epoch 5 Step 1 - Loss: 4.165 LR: 0.0 Throughput: 614.38 kts
Epoch 6 Step 1 - Loss: 3.906 LR: 0.0 Throughput: 624.3 kts
Epoch 7 Step 1 - Loss: 3.650 LR: 0.0 Throughput: 690.94 kts
Epoch 8 Step 1 - Loss: 3.457 LR: 0.0 Throughput: 671.34 kts
Epoch 9 Step 1 - Loss: 3.330 LR: 0.0 Throughput: 701.35 kts
Epoch 10 Step 1 - Loss: 3.218 LR: 0.0 Throughput: 691.76 kts
Epoch 11 Step 1 - Loss: 3.099 LR: 0.0 Throughput: 665.85 kts
Epoch 12 Step 1 - Loss: 2.997 LR: 0.0 Throughput: 676.45 kts
Epoch 13 Step 1 - Loss: 2.928 LR: 0.0 Throughput: 698.01 kts
Epoch 14 Step 1 - Loss: 2.868 LR: 0.0 Throughput: 647.47 kts
Epoch 15 Step 1 - Loss: 2.806 LR: 0.0 Throughput: 693.06 kts
Epoch 16 Step 1 - Loss: 2.749 LR: 0

In [76]:
def generate(tokenizer, prompt, temperature, top_p, rep_penalty, max_new_tokens=20, total_tokens=None):
    tokens = tokenizer.encode(prompt)
    if total_tokens is not None:
        max_new_tokens = max(0, total_tokens - len(tokens))
    for _ in range(max_new_tokens):
        logits = model.forward(torch.tensor([tokens]))[0][-1]

        # apply repetition penalty
        logits_rep = torch.gather(logits, 0, torch.tensor(tokens))
        logits_rep = torch.where(logits_rep < 0, logits_rep * rep_penalty, logits_rep / rep_penalty)
        logits.scatter_(0, torch.tensor(tokens), logits_rep)

        # apply temperature
        logits /= max(temperature, 1e-6)

        probs = torch.softmax(logits, dim=0)

        # apply top-p
        ordered_probs, ordered_indices = torch.sort(probs, descending=True)
        cum_probs = torch.cumsum(ordered_probs, dim=0).tolist()
        top_p_index = bisect.bisect_right(cum_probs, top_p) + 1
        ordered_probs, ordered_indices = ordered_probs[:top_p_index], ordered_indices[:top_p_index]
        sampled_index = ordered_indices[torch.multinomial(ordered_probs, num_samples=1).item()].item()

        tokens.append(sampled_index)
    # print(tokens)
    return tokenizer.decode(tokens)

In [None]:
generate(math_tokenizer, f"1+1=", temperature=1.0, top_p=0.01, rep_penalty=1.0, total_tokens=12).rstrip('_')

In [None]:
correct_vs_total = [0, 0]
for a in range(1, 20):
    for b in range(1, 20):
        s = generate(math_tokenizer, f"{a}/{b}=", temperature=1.0, top_p=0.01, rep_penalty=1.0,
                     max_new_tokens=3).rstrip('_')
        r = int(s[s.rfind('=') + 1:])
        correct_vs_total[0] += (a // b) == r
        correct_vs_total[1] += 1

In [None]:
correct_vs_total

In [77]:
generate(tokenizer, 'sunny place', 1.0, 0.05, 1.0, total_tokens=128)

'sunny place, there was a big pit. A little boy named Tom liked to play near the pit. One day, Tom lost his red ball. He was very sad.\nTom asked his friend, Sam, to help him search for the ball. They looked high and low, but they could not find the ball. Tom said, "I think my'

In [66]:
tokenizer.decode(train_seq[0].tolist())

'<unk> don\'t have to be scared of the loud dog, I\'ll protect you". The mole felt so safe with the little girl. She was very kind and the mole soon came to trust her. He leaned against her and she kept him safe. The mole had found his best friend.\nOnce upon a time, in a warm and sunny place, there was a big pit. A little boy named Tom liked to play near the pit. One day, Tom lost his red ball. He was very sad.\nTom asked his friend, Sam, to help him search for the ball. They looked high and low, but they could not find the ball. Tom said, "I think my ball fell into the pit."\nSam and Tom went close to the pit. They were scared, but they wanted to find the red ball. They looked into the pit, but it was too dark to see. Tom said, "We must go in and search for my ball."\nThey went into the pit to search. It was dark and scary. They could not find the ball. They tried to get out, but the pit was too deep. Tom and Sam were stuck in the pit. They called for help, but no one could hear them

In [67]:
tokenizer.decode(debug_seq[1].tolist())

'sunny place, there was a big pit. A little boy named Tom liked to play near the pit. One day, Tom lost his red ball. He was very sad.\nTom asked his friend, Sam, to help him search for the ball. They looked high and low, but they could not find the ball. Tom said, "I think my'

In [None]:
seq_len = 2
causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
causal_mask = causal_mask.masked_fill(causal_mask == 1, 1e9)
causal_mask

In [None]:
torch.where(torch.tensor([True, False]), 1, 2)