In [1]:
import re
from tokenizers import WordTokenizer, CharTokenizer, TRIETokenizerFast
import torch
from torch import nn
import json
import tqdm.notebook as tqdm
import time
import bisect
import random
from typing import *
import gc
from dataclasses import dataclass

In [2]:
with open('./corpus/TinyStoriesV2-GPT4-valid.txt', 'r') as file:
    raw_text = file.read()
    lines = [l.strip() for l in raw_text.split('<|endoftext|>')[:-1]]

In [5]:
# Network definition
C_SEQ_LEN = 2048
C_VOCAB_SIZE = 4096  # No longer used
C_HIDDEN_SIZE = 768
C_NUM_HEADS = 12
C_NUM_LAYERS = 12

C_DEVICE = torch.device('cuda')
C_DTYPE = torch.bfloat16

C_DEBUG = True

In [6]:
len(raw_text)

22493387

In [7]:
# tokenizer = WordTokenizer(raw_text, vocab_size=C_VOCAB_SIZE, reserved_vocab=['<s>', '</s>', '<pad>'])
tokenizer = TRIETokenizerFast('llama_vocab_pruned_20480.json')

In [8]:
# tokenizer.eval_vocab_coverage(raw_text)
print(len(tokenizer.encode(raw_text[:10000])))

2888


In [9]:
# token_ids, position_ids, attn_mask, loss_mask = [[]], [[]], [[]], None
# mask_index = 1
# for l in tqdm.tqdm(lines):
#     cursor = 0
#     sample_token_ids = tokenizer.encode('<s>' + l + '</s>')
#     if len(sample_token_ids) > C_SEQ_LEN:
#         continue
#     sample_position_ids = list(range(len(sample_token_ids)))
#     while cursor < len(sample_token_ids):
#         length = min(C_SEQ_LEN - len(token_ids[-1]), len(sample_token_ids) - cursor)
#         token_ids[-1] += sample_token_ids[cursor:cursor + length]
#         position_ids[-1] += sample_position_ids[cursor:cursor + length]
#         attn_mask[-1] += [mask_index] * length
#         cursor += length
#         mask_index += 1
#         if len(token_ids[-1]) == C_SEQ_LEN:
#             token_ids.append([])
#             position_ids.append([])
#             attn_mask.append([])
#             mask_index = 1
# token_ids = torch.tensor(token_ids[:-1])
# position_ids = torch.tensor(position_ids[:-1])
# attn_mask = torch.tensor(attn_mask[:-1])

In [10]:
# with open('tiny_stories_tokenized.pt', 'wb') as file:
#     torch.save([token_ids, position_ids, attn_mask], file)
if not C_DEBUG:
    with open('tiny_stories_tokenized.pt', 'rb') as file:
        token_ids, position_ids, attn_mask = torch.load(file)

In [11]:
debug_seq = torch.tensor([tokenizer.encode(raw_text[:10000])[:128 * 8]]).view((-1, 128))
debug_seq.shape

torch.Size([8, 128])

In [12]:
@dataclass
class TransformerConfig:
    vocab_size: int = -1,
    num_layers: int = -1,
    num_heads: int = -1,
    hidden_size: int = -1,
    max_seq_len: int = -1,
    root_model: 'ToyTransformer' = None
    device: torch.device = torch.device('cpu')
    dtype: torch.dtype = torch.float32
    enable_rel_pos: bool = False
    enable_fast_attn: bool = True


def expand_attn_mask(custom_attn_mask: torch.Tensor):
    B, T = custom_attn_mask.shape
    mask = custom_attn_mask.unsqueeze(1).repeat((1, T, 1))
    seq_index_mask = (mask == custom_attn_mask[:, torch.arange(T)].view(B, T, 1))
    return seq_index_mask & (torch.tril(mask) > 0)


# naive ROPE implementation following https://arxiv.org/pdf/2104.09864.pdf
def get_rope_cache_slow(seq_len: int, dim: int, theta: int, device: torch.device, dtype: torch.dtype):
    assert dim % 2 == 0
    freqs = theta ** (-2 * torch.arange(0, dim // 2, 1.) / dim)
    freqs = torch.repeat_interleave(freqs, 2)
    v1 = torch.cos(torch.arange(seq_len, dtype=torch.float).view((seq_len, 1)) * freqs)
    v2 = torch.sin(torch.arange(seq_len, dtype=torch.float).view((seq_len, 1)) * freqs)
    v2 = v2 * torch.tensor([1, -1] * (dim // 2))
    indices = torch.tensor([j for i in range(0, dim, 2) for j in (i + 1, i)])
    return v1.to(device, dtype=dtype), v2.to(device, dtype=dtype), indices.to(device)


def apply_rope_slow(x, rope_cache, positions: Optional[torch.Tensor] = None):
    v1, v2, indices = rope_cache
    seq_len, dim = x.shape[1:]
    if positions is None:
        v1 = v1[:seq_len, :]
        v2 = v2[:seq_len, :]
    else:
        v1 = v1[positions, torch.arange(dim)].view((-1, dim))
        v2 = v2[positions, torch.arange(dim)].view((-1, dim))
    applied_x = x * v1 + (x * v2)[:, :, indices]
    return applied_x


# Optimized ROPE implementation copied from https://github.com/facebookresearch/llama/blob/main/llama/model.py
def get_rope_cache_fast(seq_len: int, dim: int, theta: int, device: torch.device, dtype: torch.dtype):
    freqs = (1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)))
    t = torch.arange(seq_len, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis.to(device)


def apply_rope_fast(x, rope_cache, positions: Optional[torch.Tensor] = None) -> torch.Tensor:
    x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    if positions is None and x.shape[1] < rope_cache.shape[0]:
        freqs_cis = rope_cache[:x.shape[1], :]
    elif positions is not None:
        freqs_cis = rope_cache[positions, :]
    else:
        freqs_cis = rope_cache
    freqs_cis = freqs_cis.view([d if i == 1 or i == x_.ndim - 1 else 1 for i, d in enumerate(x_.shape)])

    applied_x = torch.view_as_real(x_ * freqs_cis).flatten(2)
    return applied_x.type_as(x)


class AttentionHead(nn.Module):
    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.dtype = config.dtype
        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size // config.num_heads, dtype=config.dtype)
        self.k_proj = nn.Linear(config.hidden_size, config.hidden_size // config.num_heads, dtype=config.dtype)
        self.v_proj = nn.Linear(config.hidden_size, config.hidden_size // config.num_heads, dtype=config.dtype)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor],
                kv_cache: Optional[List[torch.Tensor]]) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        B, T, C = x.shape

        mask_zero = torch.tensor(0, dtype=self.dtype)
        mask_val = torch.tensor(torch.finfo(self.dtype).min / 2, dtype=self.dtype)
        if kv_cache is None and attn_mask is not None:
            apply_mask = expand_attn_mask(attn_mask)
        elif kv_cache is None and not self.config.enable_fast_attn:
            apply_mask = expand_attn_mask(torch.ones(x.shape[:2]))
        elif kv_cache is not None:
            apply_mask = torch.ones((B, T, T), dtype=torch.bool)
        else:
            apply_mask = None

        if not self.config.enable_fast_attn:
            apply_mask = torch.where(apply_mask, mask_zero, mask_val)

        use_flash_attn = self.config.enable_fast_attn and kv_cache is None and apply_mask is None

        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        if self.config.enable_rel_pos:
            positions = torch.tensor([kv_cache[0].shape[1]]).to(q.device) if kv_cache is not None else None
            #q = apply_rope(q, self.config.root_model.rope_cache, positions=positions)
            #k = apply_rope(k, self.config.root_model.rope_cache, positions=positions)
            q = apply_rope_fast(q, self.config.root_model.rope_cache, positions)
            k = apply_rope_fast(k, self.config.root_model.rope_cache, positions)

        if kv_cache is not None:
            k = torch.concat([kv_cache[0], k], dim=1)
            v = torch.concat([kv_cache[1], v], dim=1)

        if self.config.enable_fast_attn:
            # noinspection PyUnresolvedReferences
            if use_flash_attn:
                q, k, v, = q.unsqueeze(1), k.unsqueeze(1), v.unsqueeze(1)
            with torch.backends.cuda.sdp_kernel(enable_flash=use_flash_attn, enable_math=not use_flash_attn, enable_mem_efficient=not use_flash_attn):
                attn_result = nn.functional.scaled_dot_product_attention(q, k, v,
                                                                         attn_mask=apply_mask.to(q.device) if apply_mask is not None else None,
                                                                         is_causal=True if apply_mask is None else False)
            if use_flash_attn:
                q, k, v, attn_result = q.squeeze(1), k.squeeze(1), v.squeeze(1), attn_result.squeeze(1)
        else:
            attn_score = (q @ k.permute(0, 2, 1) / (self.hidden_size ** 0.5)) + apply_mask.to(q.device)
            attn_result = torch.softmax(attn_score, dim=2) @ v

        return attn_result, [k, v]


class MultiHeadAttention(nn.Module):
    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.config = config
        self.attn_heads = nn.ModuleList([AttentionHead(config) for _ in range(config.num_heads)])
        self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, dtype=config.dtype)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor],
                kv_cache: Optional[List[torch.Tensor]]) -> Tuple[torch.Tensor, List[List[torch.Tensor]]]:
        head_outputs = [head(x, attn_mask, kv_cache[idx] if kv_cache is not None else None) for idx, head in
                        enumerate(self.attn_heads)]
        return self.o_proj(torch.concat([o[0] for o in head_outputs], dim=2)), [o[1] for o in head_outputs]


class DecoderLayer(nn.Module):
    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.config = config
        self.mha = MultiHeadAttention(config)
        self.up_proj = nn.Linear(config.hidden_size, config.hidden_size * 4, dtype=config.dtype)
        self.down_proj = nn.Linear(config.hidden_size * 4, config.hidden_size, dtype=config.dtype)
        self.ln_mha = nn.LayerNorm(config.hidden_size, dtype=config.dtype)
        self.ln_ffn = nn.LayerNorm(config.hidden_size, dtype=config.dtype)
        self.act = nn.GELU()

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor],
                kv_cache: Optional[List[torch.Tensor]]) -> Tuple[torch.Tensor, List[List[torch.Tensor]]]:
        mha_output, new_kv_cache = self.mha(self.ln_mha(x), attn_mask, kv_cache)
        mha_output = x + mha_output
        ffn_output = self.down_proj(self.act(self.up_proj(self.ln_ffn(mha_output))))
        return mha_output + ffn_output, new_kv_cache


class ToyTransformer(nn.Module):
    def __init__(self, vocab_size: int, num_layers: int, num_heads: int, hidden_size: int, max_seq_len: int,
                 device: torch.device = torch.device('cpu'), dtype: torch.dtype = torch.float32,
                 enable_rel_pos: bool = False, enable_fast_attn: bool = False):
        super().__init__()
        self.config = TransformerConfig(vocab_size, num_layers, num_heads, hidden_size, max_seq_len, self, device,
                                        dtype, enable_rel_pos, enable_fast_attn)

        self.sem_embed = nn.Embedding(vocab_size, hidden_size, dtype=dtype)

        if not self.config.enable_rel_pos:
            self.pos_embed = nn.Embedding(max_seq_len, hidden_size, dtype=dtype)
        else:
            # self.rope_cache = get_rope_cache(max_seq_len, hidden_size // num_heads, 10000, device, dtype)
            self.rope_cache = get_rope_cache_fast(max_seq_len, hidden_size // num_heads, 10000, device, dtype)

        self.decoder_layers = nn.ModuleList([DecoderLayer(self.config) for _ in range(num_layers)])
        self.lm_head = nn.Linear(hidden_size, vocab_size, dtype=dtype)
        self.to(device)

    def forward(self, seq: torch.Tensor,
                position_ids: Optional[torch.Tensor] = None,
                attn_mask: Optional[torch.Tensor] = None,
                kv_cache: Optional[List[torch.Tensor]] = None) -> Tuple[torch.Tensor, List[List[List[torch.Tensor]]]]:

        if self.config.enable_rel_pos:
            hidden = self.sem_embed(seq)
        elif position_ids is not None:
            hidden = self.sem_embed(seq) + self.pos_embed(position_ids)
        else:
            hidden = self.sem_embed(seq) + self.pos_embed(torch.arange(0, seq.shape[1], 1).to(self.device))

        new_kv_cache = []
        for idx, decoder in enumerate(self.decoder_layers):
            hidden, layer_kv_cache = decoder(hidden, attn_mask, kv_cache[idx] if kv_cache is not None else None)
            new_kv_cache.append(layer_kv_cache)

        return self.lm_head(hidden), new_kv_cache

    @property
    def device(self):
        return next(self.parameters()).device

# torch.manual_seed(0)
# a = AttentionHead(TransformerConfig(num_heads=1, hidden_size=256, max_seq_len=2, enable_rel_pos=True, enable_fast_attn=False))
# torch.manual_seed(0)
# b = AttentionHead(TransformerConfig(num_heads=1, hidden_size=256, max_seq_len=2, enable_rel_pos=False, enable_fast_attn=False))
# 
# d = torch.randn((3, 256, 256))
# ao = (a.forward(d, None, None)[0])
# bo = (b.forward(d, None, None)[0])
# torch.allclose(ao, bo)

In [13]:
debug_model = ToyTransformer(tokenizer.get_vocab_size(), 2, 2, 256, 128, enable_fast_attn=True, enable_rel_pos=True)
print('Total parameters:', sum([t.numel() for t in debug_model.parameters()]))

Total parameters: 12085760


In [14]:
if not C_DEBUG:
    model = ToyTransformer(tokenizer.get_vocab_size(), C_NUM_LAYERS, C_NUM_HEADS, C_HIDDEN_SIZE, 511, C_DEVICE, C_DTYPE, enable_rel_pos=True,
                           enable_fast_attn=True)
    model = model.to(C_DEVICE)
    print('Total parameters:', sum([t.numel() for t in model.parameters()]))

In [15]:
if not C_DEBUG:
    # model.load_state_dict(torch.load('./tiny_stories_0.8_epoch.pt'))
    print(model)

In [16]:
gc.collect()
torch.cuda.empty_cache()

In [18]:
def train_model(model, num_epochs, batch_size, max_lr, min_lr, warmup_ratio,
                token_ids, position_ids, attn_masks, loss_masks, show_progress=True):
    optimizer = torch.optim.AdamW(model.parameters(), lr=max_lr)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=max_lr,
                                                    total_steps=(len(token_ids) // batch_size + 1) * num_epochs,
                                                    final_div_factor=max_lr / min_lr, pct_start=warmup_ratio)

    model.train()
    for epoch_num in range(num_epochs):
        batches = tqdm.tqdm(list(range(0, len(token_ids), batch_size)), desc=f'Epoch {epoch_num}', disable=not show_progress)
        for batch_i in batches:
            step_start_time = time.time()

            inputs = token_ids[batch_i:batch_i + batch_size, :-1].to(model.device)
            labels = token_ids[batch_i:batch_i + batch_size, 1:].to(model.device)

            positions = position_ids[batch_i:batch_i + batch_size, :-1].to(model.device) if position_ids is not None else None
            attn_mask = attn_masks[batch_i:batch_i + batch_size, :-1].to(model.device) if attn_masks is not None else None
            loss_mask = loss_masks[batch_i:batch_i + batch_size, 1:] if loss_masks is not None else None

            logits, kv_state = model.forward(inputs, position_ids=positions, attn_mask=attn_mask)

            probs = torch.softmax(logits, dim=2).view(-1, logits.shape[-1])

            loss = (-torch.log(probs[torch.arange(probs.shape[0]), labels.reshape(-1)]))
            if loss_mask is not None:
                loss = (loss * loss_mask.reshape(-1)).mean()
            else:
                loss = loss.mean()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            step_time_cost = time.time() - step_start_time
            throughput = round(probs.shape[0] / step_time_cost / 1000, 2)

            step_stat = {'Loss': f'{loss.item():.3f}',
                         'LR': f'{scheduler.get_last_lr()[0]:.5}',
                         'Throughput': f'{throughput} kt/s'}

            if show_progress:
                batches.set_postfix(step_stat)
            else:
                print(', '.join(f'{s[0]}:{s[1]}' for s in step_stat.items()))

            scheduler.step()
        batches.close()


if C_DEBUG:
    train_model(debug_model, num_epochs=100, batch_size=128, max_lr=1e-3, min_lr=1e-4,
                warmup_ratio=0.1,
                token_ids=debug_seq, position_ids=None, attn_masks=None, loss_masks=None,
                show_progress=False)
else:
    train_model(model, num_epochs=1, batch_size=16 * 4, max_lr=1e-3, min_lr=1e-4,
                warmup_ratio=0.1,
                token_ids=token_ids, position_ids=position_ids, attn_masks=None, loss_masks=None,
                show_progress=True)

Loss:0.017, LR:4e-05, Throughput:5.62 kt/s
Loss:0.016, LR:6.8948e-05, Throughput:4.19 kt/s
Loss:0.015, LR:0.0001523, Throughput:6.1 kt/s
Loss:0.013, LR:0.00028, Throughput:6.18 kt/s
Loss:0.011, LR:0.00043665, Throughput:6.37 kt/s
Loss:0.011, LR:0.00060335, Throughput:5.64 kt/s
Loss:0.009, LR:0.00076, Throughput:6.36 kt/s
Loss:0.005, LR:0.0008877, Throughput:4.56 kt/s
Loss:0.005, LR:0.00097105, Throughput:6.24 kt/s
Loss:0.005, LR:0.001, Throughput:6.29 kt/s
Loss:0.004, LR:0.0009997, Throughput:6.38 kt/s
Loss:0.005, LR:0.00099879, Throughput:5.67 kt/s
Loss:0.003, LR:0.00099727, Throughput:6.42 kt/s
Loss:0.011, LR:0.00099515, Throughput:4.99 kt/s
Loss:0.011, LR:0.00099243, Throughput:6.32 kt/s
Loss:0.017, LR:0.00098912, Throughput:5.25 kt/s
Loss:0.012, LR:0.00098521, Throughput:6.22 kt/s
Loss:0.014, LR:0.00098071, Throughput:6.15 kt/s
Loss:0.006, LR:0.00097563, Throughput:6.31 kt/s
Loss:0.014, LR:0.00096997, Throughput:5.23 kt/s
Loss:0.007, LR:0.00096374, Throughput:6.46 kt/s
Loss:0.005, 

In [19]:
def generate(model, tokenizer, prompt, temperature, top_p, rep_penalty,
             max_new_tokens=20, total_tokens=None,
             end_tokens=None,
             enable_kv_cache=True):
    model.eval()

    feed_tokens = tokenizer.encode(prompt)
    all_tokens = feed_tokens.copy()
    if total_tokens is not None:
        max_new_tokens = max(0, total_tokens - len(feed_tokens))

    with torch.no_grad():
        kv_cache = None
        for _ in range(max_new_tokens):
            position_ids = None if kv_cache is None else torch.tensor([[len(all_tokens) - 1]]).to(model.device)
            logits, kv_cache = model.forward(
                torch.tensor([feed_tokens if enable_kv_cache else all_tokens]).to(model.device),
                position_ids=position_ids,
                kv_cache=kv_cache)
            logits = logits[0][-1].cpu()
            if not enable_kv_cache:
                kv_cache = None

            # apply repetition penalty
            logits_rep = torch.gather(logits, 0, torch.tensor(all_tokens))
            logits_rep = torch.where(logits_rep < 0, logits_rep * rep_penalty, logits_rep / rep_penalty)
            logits.scatter_(0, torch.tensor(all_tokens), logits_rep)

            # apply temperature
            logits /= max(temperature, 1e-6)

            probs = torch.softmax(logits, dim=0)

            # apply top-p
            ordered_probs, ordered_indices = torch.sort(probs, descending=True)
            cum_probs = torch.cumsum(ordered_probs, dim=0).tolist()
            top_p_index = bisect.bisect_right(cum_probs, top_p) + 1
            ordered_probs, ordered_indices = ordered_probs[:top_p_index], ordered_indices[:top_p_index]
            sampled_index = ordered_indices[torch.multinomial(ordered_probs, num_samples=1).item()].item()

            all_tokens.append(sampled_index)
            feed_tokens = [sampled_index]

            if end_tokens is not None and sampled_index in end_tokens:
                break

    return tokenizer.decode(all_tokens)

In [22]:
print(repr(tokenizer.decode(debug_seq[0].tolist())))

'u don\'t have to be scared of the loud dog, I\'ll protect you". The mole felt so safe with the little girl. She was very kind and the mole soon came to trust her. He leaned against her and she kept him safe. The mole had found his best friend.\n<|endoftext|>\nOnce upon a time, in a warm and sunny place, there was a big pit. A little boy named Tom liked to play near the pit. One day, Tom lost his red ball. He was very sad.\nTom asked his friend, Sam'


In [25]:
a = time.time()
result = generate(debug_model, tokenizer, 'u',
                  temperature=1.0, top_p=0.001, rep_penalty=1.0,
                  total_tokens=128,
                  end_tokens=tokenizer.encode('</s>'),
                  enable_kv_cache=True)
print(repr(result))
print(f'{time.time() - a:.3f} sec(s)')

'u don\'t have to be scared of the loud dog, I\'ll protect you". The mole felt so safe with the little girl. She was very kind and the mole soon came to trust her. He leaned against her and she kept him safe. The mole had found his best friend.\n<|endoftext|>\nOnce upon a time, in a warm and sunny place, there was a big pit. A little boy named Tom liked to play near the pit. One day, Tom lost his red ball. He was very sad.\nTom asked his friend, Sam'
0.590 sec(s)
