In [None]:
import re
from tokenizers import WordTokenizer, CharTokenizer
import torch
from torch import nn
import json
import tqdm.notebook as tqdm
import time
import bisect
import random
from typing import *
import gc
from dataclasses import dataclass

In [None]:
# Load the TinyStories V2 validation split. Stories are separated by the '<|endoftext|>' marker.
# Read as UTF-8 explicitly: the default encoding is platform-dependent (e.g. cp1252 on Windows).
with open('./corpus/TinyStoriesV2-GPT4-valid.txt', 'r', encoding='utf-8') as file:
    raw_text = file.read()

# One stripped story per entry; the piece after the final marker is dropped
# (presumably only trailing whitespace — verify if the file might lack a final marker).
lines = [l.strip() for l in raw_text.split('<|endoftext|>')[:-1]]

In [None]:
# Network definition (hyper-parameters of the full-size model)
C_SEQ_LEN = 512        # context length; passed to the model as max_seq_len
C_VOCAB_SIZE = 4096    # vocabulary size requested from the tokenizer
C_HIDDEN_SIZE = 512
C_NUM_HEADS = 8        # each head is C_HIDDEN_SIZE // C_NUM_HEADS wide
C_NUM_LAYERS = 8

# Placement and numeric precision of the full-size model
C_DTYPE = torch.bfloat16
C_DEVICE = torch.device('cuda')

In [None]:
# Corpus size in characters (raw_text is the decoded file contents).
len(raw_text)

In [None]:
# Build a word-level tokenizer over the corpus with a C_VOCAB_SIZE vocabulary.
# The reserved entries are special tokens (presumably sequence start/end and padding —
# see the WordTokenizer implementation for how they are counted against vocab_size).
tokenizer = WordTokenizer(raw_text, vocab_size=C_VOCAB_SIZE, reserved_vocab=['<s>', '</s>', '<pad>'])

In [None]:
# Report how well the learned vocabulary covers the corpus
# (output format is defined by WordTokenizer.eval_vocab_coverage).
tokenizer.eval_vocab_coverage(raw_text)

In [None]:
# token_ids, position_ids, attn_mask, loss_mask = [[]], [[]], [[]], None
# mask_index = 1
# for l in tqdm.tqdm(lines):
#     cursor = 0
#     sample_token_ids = tokenizer.encode('<s>' + l + '</s>')
#     if len(sample_token_ids) > C_SEQ_LEN:
#         continue
#     sample_position_ids = list(range(len(sample_token_ids)))
#     while cursor < len(sample_token_ids):
#         length = min(C_SEQ_LEN - len(token_ids[-1]), len(sample_token_ids) - cursor)
#         token_ids[-1] += sample_token_ids[cursor:cursor + length]
#         position_ids[-1] += sample_position_ids[cursor:cursor + length]
#         attn_mask[-1] += [mask_index] * length
#         cursor += length
#         mask_index += 1
#         if len(token_ids[-1]) == C_SEQ_LEN:
#             token_ids.append([])
#             position_ids.append([])
#             attn_mask.append([])
#             mask_index = 1
# token_ids = torch.tensor(token_ids[:-1])
# position_ids = torch.tensor(position_ids[:-1])
# attn_mask = torch.tensor(attn_mask[:-1])

In [None]:
# Load the pre-packed training tensors produced by the (commented-out) packing cell.
# To rebuild the cache after re-packing, run:
#     torch.save([token_ids, position_ids, attn_mask], 'tiny_stories_tokenized.pt')
token_ids, position_ids, attn_mask = torch.load('tiny_stories_tokenized.pt')

In [None]:
# Small overfitting target: at most 8 * 128 tokens from the start of the corpus, as rows of 128.
debug_seq = torch.tensor(tokenizer.encode(raw_text[:10000])[:8 * 128]).view(-1, 128)
debug_seq.shape

In [None]:
@dataclass
class TransformerConfig:
    """Hyper-parameters shared by every sub-module of a ToyTransformer.

    Size fields default to the placeholder -1; ToyTransformer passes every field explicitly.
    ``root_model`` lets attention heads reach the model-level ``pos_embed`` table used for
    relative positions.
    """
    # NOTE: no trailing commas here — `x: int = -1,` would make the default the tuple (-1,).
    vocab_size: int = -1
    num_layers: int = -1
    num_heads: int = -1
    hidden_size: int = -1
    max_seq_len: int = -1
    root_model: Optional['ToyTransformer'] = None
    dtype: torch.dtype = torch.float32
    enable_rel_pos: bool = False
    enable_fast_attn: bool = True


def expand_attn_mask(custom_attn_mask: torch.Tensor):
    """Expand (B, T) per-token segment ids into a (B, T, T) boolean attention mask.

    Entry [b, i, j] is True when query i may attend to key j: j <= i (causal), both
    tokens carry the same segment id, and that id is positive. Tokens with id <= 0
    therefore attend to nothing and are attended by nothing.
    """
    _, seq_len = custom_attn_mask.shape
    query_ids = custom_attn_mask.unsqueeze(2)  # (B, T, 1): id of query i
    key_ids = custom_attn_mask.unsqueeze(1)    # (B, 1, T): id of key j
    lower_triangle = torch.ones((seq_len, seq_len), dtype=torch.bool,
                                device=custom_attn_mask.device).tril()
    return (query_ids == key_ids) & lower_triangle & (key_ids > 0)


class AttentionHead(nn.Module):
    """A single causal self-attention head of width ``hidden_size // num_heads``.

    The attention path is selected by the config:
      * ``enable_rel_pos``: manual attention plus a relative-position score term read from
        ``config.root_model.pos_embed``;
      * ``enable_fast_attn`` (without rel-pos): ``scaled_dot_product_attention`` with a bool mask;
      * otherwise: manual attention with an additive mask.
    Every path scales scores by ``1 / sqrt(head_size)`` (the SDPA convention), so toggling
    ``enable_fast_attn`` does not change the result.
    """

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.head_size = config.hidden_size // config.num_heads
        self.dtype = config.dtype
        self.q_proj = nn.Linear(config.hidden_size, self.head_size, dtype=config.dtype)
        self.k_proj = nn.Linear(config.hidden_size, self.head_size, dtype=config.dtype)
        self.v_proj = nn.Linear(config.hidden_size, self.head_size, dtype=config.dtype)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor],
                kv_cache: Optional[List[torch.Tensor]]) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Attend over ``x`` and, if given, the cached keys/values of earlier steps.

        Args:
            x: input activations unpacked as (batch B, time T, hidden).
            attn_mask: optional (B, T) segment ids for packed sequences (see
                ``expand_attn_mask``); ignored when ``kv_cache`` is given.
            kv_cache: optional ``[k, v]`` returned by the previous call, each (B, past_len, head_size).

        Returns:
            ``(attn_result, [k, v])``: ``attn_result`` is (B, T, head_size); ``k``/``v`` include
            the cached prefix and can be passed back as ``kv_cache``.
        """
        B, T, _ = x.shape
        device = x.device

        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        past_len = 0
        if kv_cache is not None:
            past_len = kv_cache[0].shape[1]
            k = torch.concat([kv_cache[0], k], dim=1)
            v = torch.concat([kv_cache[1], v], dim=1)
        S = k.shape[1]

        # Absolute positions of the new queries (as a column) and of all keys (as a row).
        q_pos = torch.arange(past_len, past_len + T, device=device).unsqueeze(1)  # (T, 1)
        k_pos = torch.arange(S, device=device).unsqueeze(0)                       # (1, S)

        # Boolean mask: True where query i may attend to key j. Built on x's device.
        if kv_cache is None:
            segment_ids = attn_mask if attn_mask is not None else torch.ones((B, T), device=device)
            causal_mask = expand_attn_mask(segment_ids.to(device))  # (B, T, T)
        else:
            # Cached decoding: each new query sees the whole prefix plus new keys up to itself.
            # Works for any number of new tokens T, not just T == 1.
            causal_mask = (k_pos <= q_pos).unsqueeze(0)  # (1, T, S), broadcast over batch

        scale = self.head_size ** 0.5
        if self.config.enable_rel_pos or not self.config.enable_fast_attn:
            # The manual paths need an additive mask: adding a bool mask to the scores would
            # only shift them by 0/1 instead of blocking disallowed positions.
            mask_bias = torch.zeros(causal_mask.shape, dtype=q.dtype, device=device)
            mask_bias.masked_fill_(~causal_mask, torch.finfo(q.dtype).min / 2)

            attn_score = q @ k.transpose(1, 2)  # (B, T, S)
            if self.config.enable_rel_pos:
                # Offset (j - i) shifted into [0, 2 * max_seq_len - 2] to index the relative table.
                rel_pos = k_pos - q_pos + (self.config.max_seq_len - 1)  # (T, S)
                rel_emb = self.config.root_model.pos_embed(rel_pos)      # (T, S, head_size)
                attn_score = attn_score + (q.unsqueeze(2) @ rel_emb.transpose(1, 2)).squeeze(2)
            attn_score = attn_score / scale + mask_bias
            attn_result = torch.softmax(attn_score, dim=2) @ v
        else:
            # noinspection PyUnresolvedReferences
            with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=True):
                attn_result = nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)

        return attn_result, [k, v]


class MultiHeadAttention(nn.Module):
    """Runs ``num_heads`` independent AttentionHeads and mixes their concatenated outputs.

    The per-head outputs are concatenated along the feature axis and projected back to
    ``hidden_size`` by ``o_proj``. The returned cache holds one ``[k, v]`` pair per head.
    """

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.config = config
        # Heads are created before o_proj so parameter initialisation order is unchanged.
        self.attn_heads = nn.ModuleList(AttentionHead(config) for _ in range(config.num_heads))
        self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, dtype=config.dtype)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor],
                kv_cache: Optional[List[torch.Tensor]]) -> Tuple[torch.Tensor, List[List[torch.Tensor]]]:
        per_head_outputs = []
        per_head_caches = []
        for head_idx, head in enumerate(self.attn_heads):
            head_cache = None if kv_cache is None else kv_cache[head_idx]
            head_output, head_new_cache = head(x, attn_mask, head_cache)
            per_head_outputs.append(head_output)
            per_head_caches.append(head_new_cache)
        mixed = self.o_proj(torch.concat(per_head_outputs, dim=2))
        return mixed, per_head_caches


class DecoderLayer(nn.Module):
    """Pre-LayerNorm transformer block: x + MHA(LN(x)), then h + FFN(LN(h)).

    The feed-forward network expands to 4 * hidden_size with a GELU in between.
    Returns the new hidden states and the attention block's per-head KV cache.
    """

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.config = config
        ffn_size = config.hidden_size * 4
        # Construction order is kept as-is so random parameter initialisation is unchanged.
        self.mha = MultiHeadAttention(config)
        self.up_proj = nn.Linear(config.hidden_size, ffn_size, dtype=config.dtype)
        self.down_proj = nn.Linear(ffn_size, config.hidden_size, dtype=config.dtype)
        self.ln_mha = nn.LayerNorm(config.hidden_size, dtype=config.dtype)
        self.ln_ffn = nn.LayerNorm(config.hidden_size, dtype=config.dtype)
        self.act = nn.GELU()

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor],
                kv_cache: Optional[List[torch.Tensor]]) -> Tuple[torch.Tensor, List[List[torch.Tensor]]]:
        attn_out, new_kv_cache = self.mha(self.ln_mha(x), attn_mask, kv_cache)
        residual = x + attn_out
        ffn_in = self.ln_ffn(residual)
        ffn_out = self.down_proj(self.act(self.up_proj(ffn_in)))
        return residual + ffn_out, new_kv_cache


class ToyTransformer(nn.Module):
    """Decoder-only transformer language model.

    Token embeddings (plus absolute position embeddings unless ``enable_rel_pos``) pass through
    a stack of pre-LayerNorm DecoderLayers and a linear LM head producing vocabulary logits.
    With ``enable_rel_pos`` the ``pos_embed`` table instead holds ``2 * max_seq_len - 1``
    relative-offset embeddings of width ``hidden_size // num_heads``, read by the attention heads.
    """

    def __init__(self, vocab_size: int, num_layers: int, num_heads: int, hidden_size: int, max_seq_len: int,
                 dtype: torch.dtype = torch.float32,
                 enable_rel_pos: bool = False, enable_fast_attn: bool = False):
        super().__init__()
        self.config = TransformerConfig(vocab_size, num_layers, num_heads, hidden_size, max_seq_len, self, dtype,
                                        enable_rel_pos, enable_fast_attn)

        self.sem_embed = nn.Embedding(vocab_size, hidden_size, dtype=dtype)

        if not self.config.enable_rel_pos:
            self.pos_embed = nn.Embedding(max_seq_len, hidden_size, dtype=dtype)
        else:
            # One entry per relative offset in [-(max_seq_len - 1), max_seq_len - 1].
            self.pos_embed = nn.Embedding(max_seq_len * 2 - 1, hidden_size // num_heads, dtype=dtype)

        self.decoder_layers = nn.ModuleList([DecoderLayer(self.config) for _ in range(num_layers)])
        self.lm_head = nn.Linear(hidden_size, vocab_size, dtype=dtype)

    def forward(self, seq: torch.Tensor,
                position_ids: Optional[torch.Tensor] = None,
                attn_mask: Optional[torch.Tensor] = None,
                kv_cache: Optional[List[List[List[torch.Tensor]]]] = None
                ) -> Tuple[torch.Tensor, List[List[List[torch.Tensor]]]]:
        """Compute next-token logits for ``seq``.

        Args:
            seq: token ids, unpacked as (batch, time).
            position_ids: absolute positions for ``seq`` (unused with relative positions).
                When omitted, positions continue after the cached prefix (0 without a cache).
            attn_mask: optional (batch, time) segment ids for packed sequences.
            kv_cache: cache returned by a previous call, indexed [layer][head] -> [k, v].

        Returns:
            ``(logits, new_kv_cache)``; logits has a trailing vocab_size dimension.
        """
        if self.config.enable_rel_pos:
            hidden = self.sem_embed(seq)
        else:
            if position_ids is None:
                # With a cache, the new tokens start right after the cached prefix; restarting at 0
                # would give every decoded token the wrong absolute position.
                past_len = kv_cache[0][0][0].shape[1] if kv_cache is not None else 0
                position_ids = torch.arange(past_len, past_len + seq.shape[1], device=seq.device).unsqueeze(0)
            hidden = self.sem_embed(seq) + self.pos_embed(position_ids)

        new_kv_cache = []
        for idx, decoder in enumerate(self.decoder_layers):
            hidden, layer_kv_cache = decoder(hidden, attn_mask, kv_cache[idx] if kv_cache is not None else None)
            new_kv_cache.append(layer_kv_cache)

        return self.lm_head(hidden), new_kv_cache

    @property
    def device(self):
        """Device of the model's first parameter."""
        return next(self.parameters()).device

# torch.manual_seed(0)
# a = AttentionHead(TransformerConfig(num_heads=1, hidden_size=256, max_seq_len=2, enable_rel_pos=True, enable_fast_attn=False))
# torch.manual_seed(0)
# b = AttentionHead(TransformerConfig(num_heads=1, hidden_size=256, max_seq_len=2, enable_rel_pos=False, enable_fast_attn=False))
# 
# d = torch.randn((3, 256, 256))
# ao = (a.forward(d, None, None)[0])
# bo = (b.forward(d, None, None)[0])
# torch.allclose(ao, bo)

In [None]:
# Small model used to sanity-check training and generation on `debug_seq`.
debug_model = ToyTransformer(
    C_VOCAB_SIZE, num_layers=2, num_heads=2, hidden_size=256, max_seq_len=128,
    enable_rel_pos=True, enable_fast_attn=False,
)
print('Total parameters:', sum(p.numel() for p in debug_model.parameters()))

In [None]:
# Full-size model, built from the configuration constants and moved to C_DEVICE.
model = ToyTransformer(
    C_VOCAB_SIZE, C_NUM_LAYERS, C_NUM_HEADS, C_HIDDEN_SIZE, C_SEQ_LEN,
    dtype=C_DTYPE, enable_rel_pos=True, enable_fast_attn=False,
).to(C_DEVICE)
print('Total parameters:', sum(p.numel() for p in model.parameters()))

In [None]:
# Display the module tree of the full model.
model

In [None]:
# model.load_state_dict(torch.load('./tiny_stories_0.8_epoch.pt'))

In [None]:
# Collect unreachable Python objects first (dropping their tensors), then return PyTorch's
# cached-but-unused CUDA blocks to the driver. The order matters.
gc.collect()
torch.cuda.empty_cache()

In [None]:
def train_model(model, num_epochs, batch_size, max_lr, min_lr, warmup_ratio,
                token_ids, position_ids, attn_masks, loss_masks, show_progress=True):
    """Train ``model`` for next-token prediction with AdamW and a one-cycle LR schedule.

    Each row of ``token_ids`` is split into inputs ``row[:-1]`` and labels ``row[1:]``.

    Args:
        model: module called as ``model(inputs, position_ids=..., attn_mask=...)`` returning
            ``(logits, kv_cache)``, with a ``device`` attribute.
        num_epochs: number of passes over ``token_ids``.
        batch_size: rows per optimizer step.
        max_lr: peak learning rate, reached at the end of warmup.
        min_lr: learning rate at the end of the schedule.
        warmup_ratio: fraction of all steps spent ramping up (OneCycleLR ``pct_start``).
        token_ids: 2-D integer tensor of token rows.
        position_ids: optional tensor aligned with ``token_ids`` (sliced like the inputs).
        attn_masks: optional segment-id tensor aligned with ``token_ids``.
        loss_masks: optional per-token weights aligned with ``token_ids`` (sliced like the labels).
        show_progress: show a tqdm bar; otherwise print one stats line per step.
    """
    steps_per_epoch = len(range(0, len(token_ids), batch_size))
    div_factor = 25.0  # OneCycleLR default: the schedule starts at max_lr / div_factor
    optimizer = torch.optim.AdamW(model.parameters(), lr=max_lr)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=max_lr,
        # Exactly one scheduler.step() per batch; at least 1 because OneCycleLR rejects 0.
        total_steps=max(1, steps_per_epoch * num_epochs),
        div_factor=div_factor,
        # OneCycleLR finishes at (max_lr / div_factor) / final_div_factor; solve so that equals min_lr.
        final_div_factor=max_lr / div_factor / min_lr,
        pct_start=warmup_ratio)

    model.train()
    for epoch_num in range(num_epochs):
        batches = tqdm.tqdm(list(range(0, len(token_ids), batch_size)), desc=f'Epoch {epoch_num}', disable=not show_progress)
        for batch_i in batches:
            step_start_time = time.time()
            rows = slice(batch_i, batch_i + batch_size)

            inputs = token_ids[rows, :-1].to(model.device)
            labels = token_ids[rows, 1:].to(model.device)

            positions = position_ids[rows, :-1].to(model.device) if position_ids is not None else None
            attn_mask = attn_masks[rows, :-1].to(model.device) if attn_masks is not None else None
            loss_mask = loss_masks[rows, 1:].to(model.device) if loss_masks is not None else None

            logits, _ = model(inputs, position_ids=positions, attn_mask=attn_mask)

            # cross_entropy uses log-softmax internally; softmax followed by log gives -inf
            # whenever a probability underflows (common in bf16). Computed in float32.
            token_loss = nn.functional.cross_entropy(
                logits.float().reshape(-1, logits.shape[-1]), labels.reshape(-1), reduction='none')
            if loss_mask is not None:
                # Average over the weighted tokens only, not over every position.
                weights = loss_mask.reshape(-1).to(token_loss.dtype)
                loss = (token_loss * weights).sum() / weights.sum().clamp_min(1e-8)
            else:
                loss = token_loss.mean()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_value = loss.item()  # host copy forces a device sync, so the timing covers the step
            step_time_cost = max(time.time() - step_start_time, 1e-9)
            throughput = round(inputs.numel() / step_time_cost / 1000, 2)  # thousands of tokens / s

            step_stat = {'Loss': f'{loss_value:.3f}',
                         'LR': f'{scheduler.get_last_lr()[0]:.5}',
                         'Throughput': f'{throughput} kt/s'}

            if show_progress:
                batches.set_postfix(step_stat)
            else:
                print(', '.join(f'{s[0]}:{s[1]}' for s in step_stat.items()))

            scheduler.step()
        batches.close()


# Sanity run: overfit the small debug model on `debug_seq` (no packing metadata, no loss mask).
train_model(
    debug_model,
    num_epochs=100,
    batch_size=128,
    max_lr=1e-3,
    min_lr=1e-4,
    warmup_ratio=0.1,
    token_ids=debug_seq,
    position_ids=None,
    attn_masks=None,
    loss_masks=None,
    show_progress=False,
)

# train_model(model, num_epochs=1, batch_size=4, max_lr=1e-3, min_lr=1e-4,
#             warmup_ratio=0.1,
#             token_ids=token_ids, position_ids=position_ids, attn_masks=attn_mask, loss_masks=None,
#             show_progress=True)

In [None]:
def generate(model, tokenizer, prompt, temperature, top_p, rep_penalty,
             max_new_tokens=20, total_tokens=None,
             end_tokens=None,
             enable_kv_cache=True):
    """Sample a continuation of ``prompt`` one token at a time.

    Each step applies, in order: a repetition penalty to every token already in the sequence
    (negative logits are multiplied by ``rep_penalty``, positive ones divided by it),
    temperature scaling, then nucleus (top-p) filtering before sampling.

    Args:
        model: called as ``model(ids, position_ids=..., kv_cache=...)`` returning ``(logits, kv_cache)``;
            must expose ``device``.
        tokenizer: object with ``encode(str) -> list[int]`` and ``decode(list[int]) -> str``.
        prompt: text to continue.
        temperature: logits are divided by ``max(temperature, 1e-6)``.
        top_p: keep the smallest prefix of most likely tokens whose cumulative probability
            exceeds ``top_p`` (always at least one token).
        rep_penalty: 1.0 disables the repetition penalty.
        max_new_tokens: number of tokens to generate.
        total_tokens: if given, overrides ``max_new_tokens`` so prompt + output has this length.
        end_tokens: token ids that stop generation once sampled (the end token is kept).
        enable_kv_cache: feed only the newest token and reuse the model's cache; otherwise
            re-run the whole sequence each step.

    Returns:
        The decoded prompt followed by the generated tokens.
    """
    model.eval()

    feed_tokens = tokenizer.encode(prompt)
    all_tokens = feed_tokens.copy()
    if total_tokens is not None:
        max_new_tokens = max(0, total_tokens - len(feed_tokens))

    with torch.no_grad():
        kv_cache = None
        for _ in range(max_new_tokens):
            # With a cache only the newest token is fed, so its absolute position must be given.
            position_ids = None if kv_cache is None else torch.tensor([[len(all_tokens) - 1]], device=model.device)
            logits, kv_cache = model(
                torch.tensor([feed_tokens if enable_kv_cache else all_tokens], device=model.device),
                position_ids=position_ids,
                kv_cache=kv_cache)
            # Sample in float32: in bf16 the softmax and its cumulative sum keep only ~3 significant
            # digits, which makes the top-p cut-off unreliable.
            logits = logits[0, -1].float().cpu()
            if not enable_kv_cache:
                kv_cache = None

            # apply repetition penalty
            seen_tokens = torch.tensor(all_tokens)
            logits_rep = torch.gather(logits, 0, seen_tokens)
            logits_rep = torch.where(logits_rep < 0, logits_rep * rep_penalty, logits_rep / rep_penalty)
            logits.scatter_(0, seen_tokens, logits_rep)

            # apply temperature
            logits /= max(temperature, 1e-6)

            probs = torch.softmax(logits, dim=0)

            # apply top-p: keep tokens up to and including the one whose cumulative mass crosses top_p
            ordered_probs, ordered_indices = torch.sort(probs, descending=True)
            cum_probs = torch.cumsum(ordered_probs, dim=0).tolist()
            top_p_index = bisect.bisect_right(cum_probs, top_p) + 1
            ordered_probs, ordered_indices = ordered_probs[:top_p_index], ordered_indices[:top_p_index]
            sampled_index = ordered_indices[torch.multinomial(ordered_probs, num_samples=1).item()].item()

            all_tokens.append(sampled_index)
            feed_tokens = [sampled_index]

            if end_tokens is not None and sampled_index in end_tokens:
                break

    return tokenizer.decode(all_tokens)

In [None]:
# Round-trip check: decode the first debug row back to text (repr makes whitespace visible).
print(repr(tokenizer.decode(debug_seq[0].tolist())))

In [None]:
# Near-greedy sample from the debug model (tiny top_p keeps only the most likely token), timed.
# perf_counter is monotonic, so the interval is unaffected by system clock adjustments.
gen_start = time.perf_counter()
result = generate(debug_model, tokenizer, '<unk>',
                  temperature=1.0, top_p=0.001, rep_penalty=1.0,
                  total_tokens=128,
                  end_tokens=tokenizer.encode('</s>'),
                  enable_kv_cache=True)
print(repr(result))
print(f'{time.perf_counter() - gen_start:.3f} sec(s)')

In [None]:
def isTorchSubClass(obj):
    """Return True if any class in ``obj``'s MRO is defined in a module whose name starts with "torch"."""
    return any(parent.__module__.startswith("torch") for parent in obj.__class__.__mro__)


def findTensors(obj, objPath, results, depth):
    """Recursively collect tensors reachable from ``obj`` into ``results``.

    ``results`` maps an access-path string (built from ``objPath`` plus ``[key]``/``.attr``
    suffixes) to the first tensor found at that path. Lists, tuples, sets and dicts are walked
    element-wise; objects whose class hierarchy comes from torch are walked attribute by
    attribute. Recursion stops below depth 5.
    """
    # Identity check: `==` would invoke arbitrary __eq__ implementations, which may be
    # elementwise, expensive, or raise.
    if depth > 5 or obj is results:
        return

    if isinstance(obj, (list, tuple, set)):
        for i, o in enumerate(obj):
            findTensors(o, f"{objPath}[{i}]", results, depth + 1)
    elif isinstance(obj, dict):
        for k, v in obj.items():
            findTensors(v, f"{objPath}[{k}]", results, depth + 1)

    if type(obj) is torch.Tensor:
        results.setdefault(objPath, obj)
    elif isTorchSubClass(obj):
        for attrName in dir(obj):
            try:
                findTensors(
                    getattr(obj, attrName), f"{objPath}.{attrName}", results, depth + 1
                )
            except Exception:
                # Best effort: some attributes cannot be read (e.g. properties that raise); skip them.
                pass


def outputTensorSummary(deepTraverse=False):
    """Print a summary of tensors that live on an indexed device (e.g. ``cuda:0``).

    Tensors are discovered through ``gc.get_objects()``. Tensors of at least 1 KB are grouped
    into decade buckets ([1, 10), [10, 100), ... KB), printed from the bucket holding the largest
    tensor downwards; each non-trivial tensor lands in exactly one bucket. A tensor bound to a
    global name is labelled with that name (``deepTraverse`` also searches inside globals, through
    containers and torch objects via ``findTensors``). Smaller tensors are only totalled.
    """
    from gc import get_objects
    from warnings import filterwarnings
    from collections import Counter

    unit, unitName = 1024, "KB"

    filterwarnings("ignore", message="torch.distributed.reduce_op is deprecated")

    def isTensor(obj):
        # A tensor, or a wrapper exposing a tensor through `.data`.
        return isinstance(obj, torch.Tensor) or (
            hasattr(obj, "data") and isinstance(obj.data, torch.Tensor)
        )

    # id(tensor) -> global access path, used to label big tensors.
    if deepTraverse:
        globalTensors = {}
        findTensors(globals().copy(), "Global", globalTensors, 0)
        globalTensors = {id(v): k for k, v in globalTensors.items()}
    else:
        globalTensors = {id(v): k for k, v in globals().items() if isTensor(v)}

    totalUsage = 0
    trivialMemoryUsage = 0
    bigTensors = []
    for obj in get_objects():
        try:
            if not isTensor(obj):
                continue
            if obj.device.index is None:
                continue  # un-indexed devices (e.g. CPU) are not counted
            tensorMemSize = obj.nelement() * obj.element_size()
            totalUsage += tensorMemSize
            if (tensorMemSize / unit) < 1:
                trivialMemoryUsage += tensorMemSize
                continue
            if id(obj) in globalTensors:
                bigTensors.append(
                    (obj.shape, tensorMemSize / unit, globalTensors[id(obj)])
                )
            else:
                bigTensors.append((obj.shape, tensorMemSize / unit))
        except Exception:
            # Best effort: some gc-tracked objects fail on attribute access; skip them.
            pass

    print(f"Total {totalUsage / unit:.2f} {unitName} CUDA memory in use.\n")

    bigTensors.sort(key=lambda x: x[1], reverse=True)

    # Half-open decade buckets [min, max). Start at 100-1000 or higher, so that the largest tensor
    # is covered (previously anything >= 1000 KB was silently never listed), and go down to 1-10.
    maxLowerUnit = 1000
    while bigTensors and bigTensors[0][1] >= maxLowerUnit:
        maxLowerUnit *= 10
    minLowerUnit = maxLowerUnit // 10
    while minLowerUnit >= 1:
        inRangeTensors = [t for t in bigTensors if minLowerUnit <= t[1] < maxLowerUnit]
        groupCounter = Counter(inRangeTensors)
        print(f"Tensors of size {minLowerUnit:>5} - {maxLowerUnit:>5} {unitName}:")
        for tensor, count in groupCounter.items():
            print(
                f"  {count:4} * Size: {tensor[1]:.2f} {unitName} Shape: {[*tensor[0]]}",
                end="",
            )
            print(f' {tensor[2]:.30}' if len(tensor) == 3 else "")

        print(f"Total: {sum([t[1] for t in inRangeTensors]):.2f} {unitName}\n")
        maxLowerUnit, minLowerUnit = minLowerUnit, minLowerUnit // 10

    print(
        f"Total {trivialMemoryUsage / unit :.2f} {unitName} is occupied by trivial tensors(<1{unitName})."
    )


# Print the current device-tensor memory breakdown (shallow scan of globals for labels).
outputTensorSummary()

In [None]:
# Scratch: tiny-shape experiments for the relative-position attention term used in AttentionHead.
# Values depend on the RNG draw order after manual_seed, and later cells reuse q, k and r.
d_bsz = 2
d_seq = 3
d_emb = 7

torch.manual_seed(0)
rel_emb = nn.Embedding(d_seq * 2 - 1, d_emb)

# rel_pos[i][j] = (j - i) + (d_seq - 1): the relative offset shifted into [0, 2 * d_seq - 2].
rel_pos = torch.tensor([[s + (d_seq - 1) for s in range(-i, d_seq - i)] for i in range(d_seq)])
print([[s for s in range(-i, d_seq - i)] for i in range(d_seq)])
r = rel_emb(rel_pos)  # (d_seq, d_seq, d_emb)
print(r.shape)

q = torch.randn((d_bsz, d_seq, d_emb))
k = torch.randn((d_bsz, d_seq, d_emb))
v = torch.randn((d_bsz, d_seq, d_emb))
#q @ k.permute(0, 2, 1)

# Replicate q and k d_seq times along a new axis 1; every slice attn_score[:, n]
# is then the plain q @ k^T score matrix, giving shape (d_bsz, d_seq, d_seq, d_seq).
rel_q = q.unsqueeze(1).repeat(1, d_seq, 1, 1)
rel_k = k.unsqueeze(1).repeat(1, d_seq, 1, 1).permute(0, 1, 3, 2)
attn_score = rel_q @ rel_k
attn_score

In [None]:
# Reference content scores q @ k^T, shape (d_bsz, d_seq, d_seq); equals each attn_score[:, n] slice.
q @ k.permute(0, 2, 1)

In [None]:
# Relative-position term: for each query i and key j, q[:, i] . r[i, j] -> shape (d_bsz, d_seq, d_seq).
# Same broadcasting construction as attn_offset in AttentionHead.
aq = q.unsqueeze(2)  # (d_bsz, d_seq, 1, d_emb)
(aq @ r.permute(0, 2, 1)).squeeze(2)