In [1]:
import re
from tokenizers import WordTokenizer, CharTokenizer
import torch
from torch import nn
import json
import tqdm.notebook as tqdm
import time
import bisect
import random
from typing import *
import gc

In [2]:
# Read the whole training corpus into memory. The encoding is passed explicitly so
# that decoding of the non-ASCII characters in the stories (e.g. curly quotes) does
# not depend on the platform's default locale.
with open('./corpus/TinyStoriesV2-GPT4-train.txt', 'r', encoding='utf-8') as file:
    raw_text = file.read()

# One stripped entry per story; the text after the final '<|endoftext|>' marker is dropped.
lines = [l.strip() for l in raw_text.split('<|endoftext|>')[:-1]]

In [3]:
# Network definition
# These constants are passed to the tokenizer and to ToyTransformer further down.
C_SEQ_LEN = 512  # size of the learned position-embedding table (seq_len of ToyTransformer)
C_VOCAB_SIZE = 4096  # vocabulary size used for both the tokenizer and the model
C_HIDDEN_SIZE = 512  # embedding / residual-stream width
C_NUM_HEADS = 8  # attention heads per decoder layer
C_NUM_LAYERS = 8  # number of stacked decoder layers

C_DEVICE = torch.device('cuda')  # device the model and the generation inputs are moved to
C_DTYPE = torch.bfloat16  # dtype the model's parameters are created with

In [4]:
len(raw_text)

2226845268

In [5]:
tokenizer = WordTokenizer(raw_text[:100000000], vocab_size=C_VOCAB_SIZE, reserved_vocab=['<s>', '</s>', '<pad>'])

In [6]:
tokenizer.eval_vocab_coverage(raw_text[100000000:200000000])

0.996570014053978

In [7]:
gc.collect()

0

In [8]:
# token_ids, position_ids, attn_mask, loss_mask = [[]], [[]], [[]], None
# mask_index = 1
# for l in tqdm.tqdm(lines):
#     cursor = 0
#     sample_token_ids = tokenizer.encode('<s>' + l + '</s>')
#     if len(sample_token_ids) > C_SEQ_LEN:
#         continue
#     sample_position_ids = list(range(len(sample_token_ids)))
#     while cursor < len(sample_token_ids):
#         length = min(C_SEQ_LEN - len(token_ids[-1]), len(sample_token_ids) - cursor)
#         token_ids[-1] += sample_token_ids[cursor:cursor + length]
#         position_ids[-1] += sample_position_ids[cursor:cursor + length]
#         attn_mask[-1] += [mask_index] * length
#         cursor += length
#         mask_index += 1
#         if len(token_ids[-1]) == C_SEQ_LEN:
#             token_ids.append([])
#             position_ids.append([])
#             attn_mask.append([])
#             mask_index = 1
# token_ids = torch.tensor(token_ids[:-1])
# position_ids = torch.tensor(position_ids[:-1])
# attn_mask = torch.tensor(attn_mask[:-1])

In [9]:
# Writing the cache (kept for reference; it needs the tensors produced by the
# commented-out packing cell):
# with open('tiny_stories_tokenized.pt', 'wb') as file:
#     torch.save([token_ids, position_ids, attn_mask], file)

# Load the three cached tensors: token ids, per-token position ids and the
# per-token attention-mask ids, in the order they were saved.
with open('tiny_stories_tokenized.pt', 'rb') as file:
    token_ids, position_ids, attn_mask = torch.load(file)

In [10]:
debug_seq = torch.tensor([tokenizer.encode(raw_text[:10000])[:128 * 8]]).view((-1, 128))
debug_seq.shape

torch.Size([8, 128])

In [78]:
def expand_attn_mask(custom_attn_mask: torch.Tensor):
    """Expand per-token ids of shape (B, T) into a boolean (B, T, T) attention mask.

    Entry ``[b, i, j]`` is True exactly when query ``i`` may attend to key ``j``:
    both tokens carry the same id, ``j`` is not after ``i``, and the key's id is
    greater than zero.
    """
    length = custom_attn_mask.shape[1]
    key_ids = custom_attn_mask.unsqueeze(1)    # (B, 1, T): id of the key at column j
    query_ids = custom_attn_mask.unsqueeze(2)  # (B, T, 1): id of the query at row i

    # causal[i, j] is True when j <= i (lower triangle including the diagonal).
    steps = torch.arange(length, device=custom_attn_mask.device)
    causal = steps.unsqueeze(0) <= steps.unsqueeze(1)

    return (key_ids == query_ids) & causal & (key_ids > 0)


class AttentionHead(nn.Module):
    """A single causal self-attention head with optional key/value caching.

    The head projects inputs of width ``hidden_size`` to queries, keys and values of
    width ``hidden_size // num_heads`` and returns the attention output together
    with the (possibly extended) keys and values so they can be reused as a cache.
    """

    def __init__(self, num_heads: int, hidden_size: int, dtype: torch.dtype = torch.float32):
        super().__init__()
        self.hidden_size = hidden_size
        self.dtype = dtype
        self.q_proj = nn.Linear(hidden_size, hidden_size // num_heads, dtype=dtype)
        self.k_proj = nn.Linear(hidden_size, hidden_size // num_heads, dtype=dtype)
        self.v_proj = nn.Linear(hidden_size, hidden_size // num_heads, dtype=dtype)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor],
                kv_cache: Optional[List[torch.Tensor]]) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Attend over ``x`` (and any cached keys/values).

        Args:
            x: input of shape (B, T, hidden_size).
            attn_mask: optional (B, T) ids expanded by ``expand_attn_mask``; only used
                when no cache is given. Without it a plain causal mask is applied.
            kv_cache: optional ``[k, v]`` from a previous call; the new keys/values
                are appended to it along the sequence dimension.

        Returns:
            ``(attn_result, [k, v])`` where ``attn_result`` has shape
            (B, T, hidden_size // num_heads) and ``k``/``v`` include the cached part.
        """
        # The fused SDPA call is used for half-precision dtypes only. The flash kernel
        # itself stays disabled below because a custom attention mask is passed.
        flash_attn_enabled = (self.dtype == torch.float16 or self.dtype == torch.bfloat16) and not globals().get('disable_flash_attn', False)

        B, T, C = x.shape

        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        if kv_cache is not None:
            k = torch.concat([kv_cache[0], k], dim=1)
            v = torch.concat([kv_cache[1], v], dim=1)
        S = k.shape[1]  # number of keys, including cached ones

        # Boolean mask of shape (B, T, S), True where attention is allowed. It is built
        # directly on the input's device instead of being copied there on every call.
        if kv_cache is None and attn_mask is not None:
            allowed = expand_attn_mask(attn_mask).to(x.device)
        else:
            # Plain causal mask. The T new queries sit at the end of the S keys, so
            # query i may see keys 0 .. (S - T) + i; with T == 1 this is "see everything".
            key_pos = torch.arange(S, device=x.device).unsqueeze(0)
            query_pos = torch.arange(T, device=x.device).unsqueeze(1) + (S - T)
            allowed = (key_pos <= query_pos).unsqueeze(0).expand(B, T, S)

        if flash_attn_enabled:
            # noinspection PyUnresolvedReferences
            with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=True):
                attn_result = nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=allowed)
        else:
            # Scale by sqrt(head width) so this path matches the default scaling of
            # scaled_dot_product_attention above (it previously used sqrt(hidden_size)).
            attn_score = q @ k.permute(0, 2, 1) / (q.shape[-1] ** 0.5)
            attn_score = attn_score.masked_fill(~allowed, torch.finfo(attn_score.dtype).min / 2)
            attn_result = torch.softmax(attn_score, dim=2) @ v

        return attn_result, [k, v]


class MultiHeadAttention(nn.Module):
    """Runs ``num_heads`` independent AttentionHead modules and mixes their outputs.

    The per-head outputs are concatenated along the feature dimension and passed
    through a single output projection of width ``hidden_size``.
    """

    def __init__(self, num_heads: int, hidden_size: int, dtype: torch.dtype = torch.float32):
        super().__init__()
        heads = [AttentionHead(num_heads, hidden_size, dtype) for _ in range(num_heads)]
        self.attn_heads = nn.ModuleList(heads)
        self.o_proj = nn.Linear(hidden_size, hidden_size, dtype=dtype)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor],
                kv_cache: Optional[List[torch.Tensor]]) -> Tuple[torch.Tensor, List[List[torch.Tensor]]]:
        """Return the projected attention output and one ``[k, v]`` cache entry per head.

        ``kv_cache``, when given, is indexed by head position; each head receives its
        own entry.
        """
        per_head_results = []
        per_head_caches = []
        for head_index, head in enumerate(self.attn_heads):
            head_cache = None if kv_cache is None else kv_cache[head_index]
            head_result, head_kv = head(x, attn_mask, head_cache)
            per_head_results.append(head_result)
            per_head_caches.append(head_kv)

        merged = torch.concat(per_head_results, dim=2)
        return self.o_proj(merged), per_head_caches


class DecoderLayer(nn.Module):
    """Pre-LayerNorm decoder block: self-attention then a 4x-wide GELU feed-forward.

    Both sub-layers are wrapped in residual connections.
    """

    def __init__(self, num_heads: int, hidden_size: int, dtype: torch.dtype = torch.float32):
        super().__init__()
        ffn_size = hidden_size * 4
        self.mha = MultiHeadAttention(num_heads, hidden_size, dtype)
        self.up_proj = nn.Linear(hidden_size, ffn_size, dtype=dtype)
        self.down_proj = nn.Linear(ffn_size, hidden_size, dtype=dtype)
        self.ln_mha = nn.LayerNorm(hidden_size, dtype=dtype)
        self.ln_ffn = nn.LayerNorm(hidden_size, dtype=dtype)
        self.act = nn.GELU()

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor],
                kv_cache: Optional[List[torch.Tensor]]) -> Tuple[torch.Tensor, List[List[torch.Tensor]]]:
        """Apply the block to ``x`` and return ``(output, kv cache of the attention layer)``."""
        # Attention sub-layer: normalise, attend, add the residual.
        attended, new_kv_cache = self.mha(self.ln_mha(x), attn_mask, kv_cache)
        after_attention = x + attended

        # Feed-forward sub-layer: normalise, expand, activate, contract, add the residual.
        widened = self.up_proj(self.ln_ffn(after_attention))
        ffn_output = self.down_proj(self.act(widened))
        return after_attention + ffn_output, new_kv_cache


class ToyTransformer(nn.Module):
    """Decoder-only transformer: token + learned position embeddings, a stack of
    DecoderLayer blocks and a linear language-model head."""

    def __init__(self, vocab_size: int, num_layers: int, num_heads: int, hidden_size: int, seq_len: int,
                 dtype: torch.dtype = torch.float32):
        super().__init__()
        self.sem_embed = nn.Embedding(vocab_size, hidden_size, dtype=dtype)
        self.pos_embed = nn.Embedding(seq_len, hidden_size, dtype=dtype)
        self.decoder_layers = nn.ModuleList([DecoderLayer(num_heads, hidden_size, dtype) for _ in range(num_layers)])
        self.lm_head = nn.Linear(hidden_size, vocab_size, dtype=dtype)

    def forward(self, seq: torch.Tensor,
                position_ids: Optional[torch.Tensor] = None,
                attn_mask: Optional[torch.Tensor] = None,
                kv_cache: Optional[List[List[List[torch.Tensor]]]] = None
                ) -> Tuple[torch.Tensor, List[List[List[torch.Tensor]]]]:
        """Compute next-token logits for ``seq``.

        Args:
            seq: token ids of shape (B, T).
            position_ids: optional position ids used to index the position embedding.
                When omitted, positions continue after the tokens already held in
                ``kv_cache`` (or start at 0 without a cache).
            attn_mask: optional (B, T) ids forwarded to every decoder layer.
            kv_cache: optional cache from a previous call, indexed as
                ``[layer][head][0 for keys / 1 for values]``.

        Returns:
            ``(logits, new_kv_cache)`` with logits of shape (B, T, vocab_size) and the
            new cache in the same nested layout as ``kv_cache``.
        """
        if position_ids is None:
            # Cached keys are stored as (B, cached_len, head_width); new tokens must be
            # embedded at the positions following them rather than restarting at 0.
            past_len = kv_cache[0][0][0].shape[1] if kv_cache else 0
            position_ids = torch.arange(past_len, past_len + seq.shape[1], device=self.device)
        pos_embed = self.pos_embed(position_ids)

        hidden = self.sem_embed(seq) + pos_embed
        new_kv_cache = []
        for idx, decoder in enumerate(self.decoder_layers):
            hidden, layer_kv_cache = decoder(hidden, attn_mask, kv_cache[idx] if kv_cache is not None else None)
            new_kv_cache.append(layer_kv_cache)

        return self.lm_head(hidden), new_kv_cache

    @property
    def device(self):
        """Device of the model's first parameter."""
        return next(self.parameters()).device

In [79]:
model = ToyTransformer(C_VOCAB_SIZE, C_NUM_LAYERS, C_NUM_HEADS, C_HIDDEN_SIZE, C_SEQ_LEN, C_DTYPE)
model = model.to(C_DEVICE)
print('Total parameters:', sum([t.numel() for t in model.parameters()]))

Total parameters: 29679616


In [80]:
model

ToyTransformer(
  (sem_embed): Embedding(4096, 512)
  (pos_embed): Embedding(512, 512)
  (decoder_layers): ModuleList(
    (0-7): 8 x DecoderLayer(
      (mha): MultiHeadAttention(
        (attn_heads): ModuleList(
          (0-7): 8 x AttentionHead(
            (q_proj): Linear(in_features=512, out_features=64, bias=True)
            (k_proj): Linear(in_features=512, out_features=64, bias=True)
            (v_proj): Linear(in_features=512, out_features=64, bias=True)
          )
        )
        (o_proj): Linear(in_features=512, out_features=512, bias=True)
      )
      (up_proj): Linear(in_features=512, out_features=2048, bias=True)
      (down_proj): Linear(in_features=2048, out_features=512, bias=True)
      (ln_mha): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (ln_ffn): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (act): GELU(approximate='none')
    )
  )
  (lm_head): Linear(in_features=512, out_features=4096, bias=True)
)

In [81]:
C_BATCH_SIZE = 64

In [82]:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-3, total_steps=len(token_ids) // C_BATCH_SIZE + 1, final_div_factor=1e2)

In [83]:
model.load_state_dict(torch.load('./tiny_stories_0.8_epoch.pt'))

<All keys matched successfully>

In [84]:
gc.collect()
torch.cuda.empty_cache()

In [85]:
# One pass over the packed training tensors: next-token prediction with AdamW and a
# one-cycle learning-rate schedule stepped once per batch.
model.train()
for epoch_num in range(1):
    batches = tqdm.tqdm(list(range(0, len(token_ids), C_BATCH_SIZE)), desc=f'Epoch {epoch_num}', disable=False)
    for batch_i in batches:
        step_start_time = time.time()

        # Inputs are every token but the last; labels are the same rows shifted by one.
        inputs = token_ids[batch_i:batch_i + C_BATCH_SIZE, :-1]
        labels = token_ids[batch_i:batch_i + C_BATCH_SIZE, 1:]
        positions = position_ids[batch_i:batch_i + C_BATCH_SIZE, :-1]
        masks = attn_mask[batch_i:batch_i + C_BATCH_SIZE, :-1]

        # Call the module (not .forward) so nn.Module's __call__ machinery runs.
        logits, _ = model(inputs.to(model.device),
                          position_ids=positions.to(model.device),
                          attn_mask=masks.to(model.device))

        # Mean negative log-likelihood of the labels. cross_entropy works from the
        # logits via log-softmax, and the float32 upcast avoids taking log() of
        # low-precision bf16 probabilities.
        loss = nn.functional.cross_entropy(logits.float().reshape(-1, C_VOCAB_SIZE),
                                           labels.reshape(-1).to(model.device))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        step_time_cost = time.time() - step_start_time
        # Count the tokens actually fed in this step (the last batch may be smaller).
        throughput = round(inputs.numel() / step_time_cost / 1000, 2)
        batches.set_postfix({
            'Loss': f'{loss.item():.3f}',
            'LR': f'{scheduler.get_last_lr()[0]:.5}',
            'Throughput': f'{throughput} kt/s'
        })
        scheduler.step()
    batches.close()

Epoch 0:   0%|          | 0/24243 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [86]:
model.eval()


def generate(tokenizer, prompt, temperature, top_p, rep_penalty, max_new_tokens=20, total_tokens=None, end_tokens=None, enable_kv_cache=True):
    """Sample a continuation of ``prompt`` from the notebook-level ``model``.

    Relies on the globals ``model`` and ``C_DEVICE``.

    Args:
        tokenizer: object providing ``encode(str) -> list of ids`` and ``decode(ids) -> str``.
        prompt: text to continue.
        temperature: logits are divided by ``max(temperature, 1e-6)``.
        top_p: nucleus threshold; sampling is restricted to the most probable tokens
            up to and including the first one whose cumulative probability exceeds it.
        rep_penalty: repetition penalty applied to every token already in the sequence
            (positive logits are divided by it, negative ones multiplied).
        max_new_tokens: maximum number of tokens to sample.
        total_tokens: if given, overrides ``max_new_tokens`` so that prompt plus
            continuation is at most this many tokens.
        end_tokens: optional collection of token ids that stop generation once sampled.
        enable_kv_cache: feed only the newest token and reuse cached keys/values;
            otherwise the whole sequence is re-encoded at every step.

    Returns:
        The decoded text of prompt plus sampled tokens.
    """
    feed_tokens = tokenizer.encode(prompt)
    all_tokens = feed_tokens.copy()
    if total_tokens is not None:
        max_new_tokens = max(0, total_tokens - len(feed_tokens))

    with torch.no_grad():
        kv_cache = None
        for _ in range(max_new_tokens):
            # With a cache only the last token is fed, at the position it occupies in the sequence.
            position_ids = None if kv_cache is None else torch.tensor([[len(all_tokens) - 1]]).to(C_DEVICE)
            model_input = torch.tensor([feed_tokens if enable_kv_cache else all_tokens]).to(C_DEVICE)
            logits, kv_cache = model(model_input, position_ids=position_ids, kv_cache=kv_cache)
            # Sampling math (penalty, softmax, cumulative sums for top-p) is done in
            # float32: the model may produce half-precision logits, which are too
            # coarse for cumulative probabilities.
            logits = logits[0][-1].float().cpu()
            if not enable_kv_cache:
                kv_cache = None

            # apply repetition penalty
            seen = torch.tensor(all_tokens)
            logits_rep = torch.gather(logits, 0, seen)
            logits_rep = torch.where(logits_rep < 0, logits_rep * rep_penalty, logits_rep / rep_penalty)
            logits.scatter_(0, seen, logits_rep)

            # apply temperature
            logits /= max(temperature, 1e-6)

            probs = torch.softmax(logits, dim=0)

            # apply top-p: keep the shortest prefix of the sorted distribution whose
            # cumulative probability exceeds top_p (always at least one token)
            ordered_probs, ordered_indices = torch.sort(probs, descending=True)
            cum_probs = torch.cumsum(ordered_probs, dim=0).tolist()
            top_p_index = bisect.bisect_right(cum_probs, top_p) + 1
            ordered_probs, ordered_indices = ordered_probs[:top_p_index], ordered_indices[:top_p_index]
            sampled_index = ordered_indices[torch.multinomial(ordered_probs, num_samples=1).item()].item()

            all_tokens.append(sampled_index)
            feed_tokens = [sampled_index]

            if end_tokens is not None and sampled_index in end_tokens:
                break
    return tokenizer.decode(all_tokens)

In [87]:
tokenizer.decode(token_ids[0].tolist())

'<s>Once upon a time there was a little boy named Ben. Ben loved to explore the world around him. He saw many amazing things, like beautiful <unk> that were on display in a store. One day, Ben was walking through the store when he came across a very special vase. When Ben saw it he was amazed!  \nHe said, “Wow, that is a really amazing vase! Can I buy it?” \nThe shopkeeper smiled and said, “Of course you can. You can take it home and show all your friends how amazing it is!”\nSo Ben took the vase home and he was so proud of it! He called his friends over and showed them the amazing vase. All his friends thought the vase was beautiful and couldn\'t believe how lucky Ben was. \nAnd that\'s how Ben found an amazing vase in the store!</s><s>Once upon a time, there was a reliable otter named Ollie. He lived in a river with his family. They all loved to play and swim together.\nOne day, Ollie\'s mom said, "Ollie, hurry and get some fish for dinner!" Ollie swam fast to catch fish. He saw his 

In [88]:
model.decoder_layers[0].mha.attn_heads[0].dtype

torch.bfloat16

In [89]:
a = time.time()
print(generate(tokenizer, '<s>Once upon the time',
               temperature=1.0, top_p=0.001, rep_penalty=1.1,
               total_tokens=512,
               end_tokens=tokenizer.encode('</s>'),
               enable_kv_cache=True))
print(f'{time.time() - a:.3f} sec(s)')

<s>Once upon the time, there was a little girl named Lily. She had a toy bear named Ben. Lily and Ben did everything together. They played, ate, and slept. One day, Lily's mom gave her a task. The task was to clean her room.
Lily said, "I will do it." She started to clean. She put away toys and clothes. But then, she saw a dead bug on the floor. She felt sad. Her mom came in and said, "Don't worry, Lily."
Lily learned that is not good for <unk>. She learns that can be by and help others. And they all need when you are good.</s>
1.744 sec(s)


In [None]:
tokenizer.decode(debug_seq[-1].tolist())

In [None]:
torch.tensor(torch.finfo(torch.bfloat16).min / 2, dtype=torch.bfloat16)

In [None]:
optimizer.state_dict()['state'].keys()

In [None]:
torch.save(model.state_dict(), 'tiny_stories_0.8_epoch.pt')

In [None]:
gc.collect()
torch.cuda.empty_cache()

In [None]:
x = (torch.cuda.memory_snapshot())

In [None]:
x[0]

In [None]:
def isTorchSubClass(obj):
    """Return True if any class in the MRO of ``obj``'s class is defined in a module
    whose name starts with "torch"; otherwise False."""
    return any(klass.__module__.startswith("torch") for klass in obj.__class__.__mro__)


def findTensors(obj, objPath, results, depth):
    """Recursively collect plain ``torch.Tensor`` objects reachable from ``obj``.

    Args:
        obj: object to inspect. Lists, tuples, sets and dicts are walked element by
            element; objects for which ``isTorchSubClass`` is true are walked
            attribute by attribute.
        objPath: textual path of ``obj``, used as the key under which a tensor is stored.
        results: dict filled in place, mapping path -> tensor (first path wins).
        depth: current recursion depth; traversal stops beyond depth 5.
    """
    # Identity, not equality: the goal is only to avoid walking the results dict
    # itself. A value comparison (`==`) can raise or be ambiguous when the compared
    # containers hold tensors or other array-like values.
    if depth > 5 or obj is results:
        return

    if isinstance(obj, (list, tuple, set)):
        for i, o in enumerate(obj):
            findTensors(o, f"{objPath}[{i}]", results, depth + 1)
    elif isinstance(obj, dict):
        # Snapshot the items so the walk is unaffected if the dict changes meanwhile.
        for k, v in list(obj.items()):
            findTensors(v, f"{objPath}[{k}]", results, depth + 1)

    if type(obj) is torch.Tensor:
        results.setdefault(objPath, obj)
    elif isTorchSubClass(obj):
        for attrName in dir(obj):
            try:
                findTensors(
                    getattr(obj, attrName), f"{objPath}.{attrName}", results, depth + 1
                )
            except Exception:
                # Best effort: some attributes raise on access; skip them, but let
                # KeyboardInterrupt / SystemExit propagate.
                pass


def outputTensorSummary(deepTraverse=False):
    """Print a size-bucketed summary of the tensors known to the garbage collector
    whose device has an index (CPU tensors, whose device index is None, are skipped).

    Tensors smaller than 1 KB are only counted towards a "trivial" total. All others
    are grouped by (shape, size, global name) and listed in the buckets
    ``>= 1000 KB``, ``100 - 1000 KB``, ``10 - 100 KB`` and ``1 - 10 KB``. Each bucket
    includes its lower bound and excludes its upper bound, so every tensor is listed
    exactly once.

    Args:
        deepTraverse: when True, names are resolved by walking the globals with
            ``findTensors``; otherwise only tensors bound directly to a global
            name are labelled.
    """
    from gc import get_objects
    from warnings import filterwarnings
    from collections import Counter

    unit, unitName = 1024, "KB"

    filterwarnings("ignore", message="torch.distributed.reduce_op is deprecated")

    def isTensor(obj):
        # A tensor, or any object whose `.data` attribute is a tensor.
        return isinstance(obj, torch.Tensor) or (
            hasattr(obj, "data") and isinstance(obj.data, torch.Tensor)
        )

    # Map id(tensor) -> name so tensors can be labelled in the listing.
    if deepTraverse:
        globalTensors = {}
        findTensors(globals().copy(), "Global", globalTensors, 0)
        globalTensors = {id(v): k for k, v in globalTensors.items()}
    else:
        globalTensors = {id(v): k for k, v in globals().items() if isTensor(v)}

    totalUsage = 0
    trivialMemoryUsage = 0
    bigTensors = []
    for obj in get_objects():
        try:
            if isTensor(obj):
                if obj.device.index is None:
                    continue
                tensorMemSize = obj.nelement() * obj.element_size()
                totalUsage += tensorMemSize
                if (tensorMemSize / unit) < 1:
                    trivialMemoryUsage += tensorMemSize
                    continue
                if id(obj) in globalTensors:
                    bigTensors.append(
                        (obj.shape, tensorMemSize / unit, globalTensors[id(obj)])
                    )
                else:
                    bigTensors.append((obj.shape, tensorMemSize / unit))
        except Exception:
            # Best effort: arbitrary gc-tracked objects may raise on attribute access.
            pass

    print(f"Total {totalUsage / unit:.2f} {unitName} CUDA memory in use.\n")

    bigTensors.sort(key=lambda x: x[1], reverse=True)

    def printBucket(label, tensors):
        # One line per distinct (shape, size[, name]) entry with its multiplicity.
        print(f"Tensors of size {label} {unitName}:")
        for tensor, count in Counter(tensors).items():
            print(
                f"  {count:4} * Size: {tensor[1]:.2f} {unitName} Shape: {[*tensor[0]]}",
                end="",
            )
            print(f' {tensor[2]:.30}' if len(tensor) == 3 else "")
        print(f"Total: {sum([t[1] for t in tensors]):.2f} {unitName}\n")

    maxLowerUnit, minLowerUnit = 1000, 100

    # Tensors above the largest range were previously never listed at all.
    printBucket(f">= {maxLowerUnit:>5}", [t for t in bigTensors if t[1] >= maxLowerUnit])

    while minLowerUnit >= 1:
        inRangeTensors = [t for t in bigTensors if minLowerUnit <= t[1] < maxLowerUnit]
        printBucket(f"{minLowerUnit:>5} - {maxLowerUnit:>5}", inRangeTensors)
        maxLowerUnit, minLowerUnit = maxLowerUnit // 10, minLowerUnit // 10

    print(
        f"Total {trivialMemoryUsage / unit :.2f} {unitName} is occupied by trivial tensors(<=1{unitName})."
    )


outputTensorSummary()