In [1]:
from tokenizers import WordTokenizer
import torch
from torch import nn
import json

In [2]:
with open('../../Jupyter/NLP/TinyStoriesV2-GPT4-valid-chunks.json', 'r') as file:
    lines = [l for l in json.load(file)]
    raw_text = '\n'.join(lines)

In [3]:
# Network definition
C_SEQ_LEN = 128
C_VOCAB_SIZE = 4096
C_HIDDEN_SIZE = 256
C_NUM_HEADS = 2
C_NUM_LAYERS = 2

In [4]:
tokenizer = WordTokenizer(raw_text, vocab_size=C_VOCAB_SIZE, reserved_vocab=['<s>', '</s>', '<pad>'])

In [5]:
tokenizer.eval_vocab_coverage(raw_text)

0.9967516696403725

In [6]:
test_seq = torch.tensor([tokenizer.encode(raw_text[:10000])[:C_SEQ_LEN * 8]]).view((-1, C_SEQ_LEN))
tokenizer.decode(test_seq[0].tolist())

'<unk> don\'t have to be scared of the loud dog, I\'ll protect you". The mole felt so safe with the little girl. She was very kind and the mole soon came to trust her. He leaned against her and she kept him safe. The mole had found his best friend.\nOnce upon a time, in a warm and '

In [16]:
class AttentionHead(nn.Module):
    def __init__(self, num_heads: int, hidden_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.q_proj = nn.Linear(hidden_size, hidden_size // num_heads)
        self.k_proj = nn.Linear(hidden_size, hidden_size // num_heads)
        self.v_proj = nn.Linear(hidden_size, hidden_size // num_heads)

    def forward(self, x: torch.Tensor):
        seq_len = x.shape[1]
        causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
        causal_mask = causal_mask.masked_fill(causal_mask == 1, 1e9)

        q = self.q_proj(x)  # BSZ * SEQ * HIDDEN
        k = self.k_proj(x)
        v = self.v_proj(x)
        attn_score = (q @ k.permute(0, 2, 1) / (self.hidden_size ** 0.5)) - causal_mask

        return torch.softmax(attn_score, dim=2) @ v


class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads: int, hidden_size: int):
        super().__init__()
        self.attn_heads = nn.ModuleList([AttentionHead(num_heads, hidden_size) for _ in range(num_heads)])
        self.o_proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, x: torch.Tensor):
        return self.o_proj(torch.concat([a(x) for a in self.attn_heads], dim=2))


class DecoderLayer(nn.Module):
    def __init__(self, num_heads: int, hidden_size: int):
        super().__init__()
        self.mha = MultiHeadAttention(num_heads, hidden_size)
        self.up_proj = nn.Linear(hidden_size, hidden_size * 4)
        self.down_proj = nn.Linear(hidden_size * 4, hidden_size)
        self.ln_mha = nn.LayerNorm(hidden_size)
        self.ln_ffn = nn.LayerNorm(hidden_size)

    def forward(self, x: torch.Tensor):
        mha_output = self.ln_mha(x + self.mha(x))
        ffn_output = self.down_proj(torch.relu(self.up_proj(mha_output)))
        return self.ln_ffn(mha_output + ffn_output)


class ToyTransformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.sem_embed = nn.Embedding(C_VOCAB_SIZE, C_HIDDEN_SIZE)
        self.pos_embed = nn.Embedding(C_SEQ_LEN, C_HIDDEN_SIZE)
        # self.attn_heads = [AttentionHead(C_HIDDEN_SIZE) for _ in range(C_NUM_HEADS)]
        # for i, param in enumerate(self.attn_heads):
        #     self.add_module(f'attn_head_{i}', param)
        # self.mha = MultiHeadAttention(C_NUM_HEADS, C_HIDDEN_SIZE)
        # self.o_proj = nn.Linear(C_HIDDEN_SIZE * C_NUM_HEADS, C_HIDDEN_SIZE)
        self.decoder_layers = nn.ModuleList([DecoderLayer(C_NUM_HEADS, C_HIDDEN_SIZE) for _ in range(C_NUM_LAYERS)])
        self.lm_head = nn.Linear(C_HIDDEN_SIZE, C_VOCAB_SIZE)

    def forward(self, seq):
        seq_len = seq.shape[1]
        hidden = self.sem_embed(seq) + self.pos_embed(torch.arange(0, seq_len, 1))
        # attn_out = torch.concat([a(embed) for a in self.attn_heads], dim=2)
        # attn_out = self.mha(embed)
        # mlp_out = torch.relu(self.o_proj(attn_out))
        for decoder in self.decoder_layers:
            hidden = decoder(hidden)
        logits = self.lm_head(hidden)
        return logits

In [17]:
inputs = test_seq[:1, :-1]
labels = test_seq[:1, 1:]

In [18]:
model = ToyTransformer()
print('Total parameters:', sum([t.numel() for t in model.parameters()]))
model

Total parameters: 3713536


ToyTransformer(
  (sem_embed): Embedding(4096, 256)
  (pos_embed): Embedding(128, 256)
  (decoder_layers): ModuleList(
    (0-1): 2 x DecoderLayer(
      (mha): MultiHeadAttention(
        (attn_heads): ModuleList(
          (0-1): 2 x AttentionHead(
            (q_proj): Linear(in_features=256, out_features=128, bias=True)
            (k_proj): Linear(in_features=256, out_features=128, bias=True)
            (v_proj): Linear(in_features=256, out_features=128, bias=True)
          )
        )
        (o_proj): Linear(in_features=256, out_features=256, bias=True)
      )
      (up_proj): Linear(in_features=256, out_features=1024, bias=True)
      (down_proj): Linear(in_features=1024, out_features=256, bias=True)
      (ln_mha): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (ln_ffn): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    )
  )
  (lm_head): Linear(in_features=256, out_features=4096, bias=True)
)

In [19]:
optim = torch.optim.AdamW(model.parameters(), lr=3e-4)

In [20]:
for epoch_num in range(100):
    logits = model.forward(inputs)
    probs = torch.softmax(logits, dim=2)  # BSZ * SEQ * VOCAB
    probs_flat = probs.view(-1, C_VOCAB_SIZE)
    loss = (-torch.log(probs_flat[torch.arange(probs_flat.shape[0]), labels.reshape(-1)])).mean()
    optim.zero_grad()
    loss.backward()
    optim.step()
    print(loss)

tensor(8.6122, grad_fn=<MeanBackward0>)
tensor(7.7484, grad_fn=<MeanBackward0>)
tensor(6.9453, grad_fn=<MeanBackward0>)
tensor(6.2307, grad_fn=<MeanBackward0>)
tensor(5.6297, grad_fn=<MeanBackward0>)
tensor(5.1653, grad_fn=<MeanBackward0>)
tensor(4.8436, grad_fn=<MeanBackward0>)
tensor(4.6376, grad_fn=<MeanBackward0>)
tensor(4.4966, grad_fn=<MeanBackward0>)
tensor(4.3749, grad_fn=<MeanBackward0>)
tensor(4.2477, grad_fn=<MeanBackward0>)
tensor(4.1082, grad_fn=<MeanBackward0>)
tensor(3.9606, grad_fn=<MeanBackward0>)
tensor(3.8147, grad_fn=<MeanBackward0>)
tensor(3.6793, grad_fn=<MeanBackward0>)
tensor(3.5585, grad_fn=<MeanBackward0>)
tensor(3.4504, grad_fn=<MeanBackward0>)
tensor(3.3500, grad_fn=<MeanBackward0>)
tensor(3.2520, grad_fn=<MeanBackward0>)
tensor(3.1532, grad_fn=<MeanBackward0>)
tensor(3.0523, grad_fn=<MeanBackward0>)
tensor(2.9492, grad_fn=<MeanBackward0>)
tensor(2.8433, grad_fn=<MeanBackward0>)
tensor(2.7346, grad_fn=<MeanBackward0>)
tensor(2.6243, grad_fn=<MeanBackward0>)


In [24]:
def generate(prompt, max_new_tokens=20):
    tokens = tokenizer.encode(prompt)
    for _ in range(max_new_tokens):
        logits = model.forward(torch.tensor([tokens]))[0][-1]
        probs = torch.softmax(logits, dim=0)
        tokens.append(torch.argmax(probs).item())
    print(tokens)
    return tokenizer.decode(tokens)


generate("don\'t", 100)

[243, 22, 78, 22, 78, 4, 93, 4, 10, 4, 70, 4, 135, 4, 37, 4, 7, 4, 295, 4, 69, 6, 4, 33, 22, 827, 4, 1275, 4, 32, 12, 5, 4, 14, 4, 1693, 4, 106, 4, 57, 4, 251, 4, 20, 4, 7, 4, 45, 4, 75, 5, 4, 23, 4, 13, 4, 38, 4, 228, 4, 8, 4, 7, 4, 1693, 4, 356, 4, 108, 4, 10, 4, 1614, 4, 24, 5, 4, 17, 4, 2420, 4, 2170, 4, 24, 4, 8, 4, 49, 4, 368, 4, 81, 4, 251, 5, 4, 14, 4, 1693, 4, 31, 4, 94]


'don\'t\'t have to be scared of the loud dog, I\'ll protect you". The mole felt so safe with the little girl. She was very kind and the mole soon came to trust her. He leaned against her and she kept him safe. The mole had found'

In [None]:
tokenizer.decode(inputs[0].tolist())