In [1]:
from tokenizers import TRIETokenizer
import torch
from torch import nn
import numpy as np
import tqdm.notebook as tqdm
import time
import bisect
from typing import *
import gc
from dataclasses import dataclass
from flash_attn import flash_attn_func
from flash_attn_triton import flash_attn_func as flash_attn_func_triton
from dataloader import DatasetReader, DatasetIter, SingleDatasetReader, MultiDatasetsReader
from math import ceil
from functools import partial
from matplotlib import pyplot as plt
import bitsandbytes as bnb
from threading import Lock, Thread
import traceback
import json
import pickle
import os
from enum import Enum
from modeling import *

In [2]:
# Disabled: enabling TF32 for matmuls/convolutions gave no noticeable speed-up
# torch.backends.cuda.matmul.allow_tf32 = True
# torch.backends.cudnn.allow_tf32 = True

In [3]:
# Shared tokenizer for every cell below; its get_vocab_size() sizes the model (32768 entries per the model summary).
g_tokenizer = TRIETokenizer('llama_vocab_pruned_32k.json')

In [4]:
# Model size presets as (sequence length, hidden size, attention heads, decoder layers):
#   370M: (1024, 1024, 16, 24)
#   135M: (1024,  768, 12, 12)  <- active
C_SEQ_LEN, C_HIDDEN_SIZE, C_NUM_HEADS, C_NUM_LAYERS = 1024, 768, 12, 12

C_DEVICE, C_DTYPE = torch.device('cuda'), torch.float32

# When True, later cells switch to the debug dataset, a tiny model and debug training arguments.
C_DEBUG = False

In [5]:
# Training mixture: the debug file twice under C_DEBUG, otherwise the SFT corpora.
# Earlier rounds used (each wrapped as MultiDatasetsReader([...], seed=0)):
#   pre-training: datasets/minipile_train_masked_1024.bin, datasets/enwiki_train_masked_1024.bin,
#                 datasets/tinytextbooks_train_masked_1024.bin
#   tiny stories: datasets/tinystories_train_masked.bin
#   slimpajama:   datasets/slimpajamas/slimpajama_train_masked_1024.bin
g_train_data = MultiDatasetsReader(
    [SingleDatasetReader(dataset_path) for dataset_path in (
        ['datasets/debug_data_masked.bin'] * 2 if C_DEBUG else [
            'datasets/sft/alpaca_gpt4.bin',
            'datasets/sft/wizardlm_evol_2.bin',
            'datasets/sft/airoboros_2.2.1.bin',
        ]
    )],
    seed=0,
)

In [6]:
# Draw the first sample once and reuse it instead of creating a new dataset iterator per print.
# NOTE: "Train tokens" assumes every sample has this same fixed length — TODO confirm for new datasets.
g_sample_length = len(next(iter(g_train_data))['token_ids'])
print('Train samples:', len(g_train_data))
print('Sample length:', g_sample_length)
print('Train tokens:', len(g_train_data) * g_sample_length)

Train samples: 53058
Sample length: 1024
Train tokens: 54331392


In [7]:
print(f'Sample 1: {g_tokenizer.decode(next(iter(g_train_data))["token_ids"])}')

Sample 1: <s>A chat between User and Assistant.
User:As an online platform teacher named Aimee, you possess impeccable credentials which include a Bachelor of Science degree in Industrial and Labor Relations from Cornell University, expertise in the English language, and intermediate proficiency in both Chinese and Spanish. Additionally, your professional experience as a STEAM teacher at UN Women in Singapore has honed your skills in teaching children from the ages of 6-11 and working with students from all levels of education. Your exceptional teaching abilities in spoken English and pronunciation paired with your personal strengths of being informed, patient, and engaging make you an ideal teacher for students seeking to improve their English language skills. Can you provide a short, concise, and unique English self-introduction in bullet point form that would attract students to enroll in your course?
Assistant:Sure, here are some bullet points for your English self-introduction:

-

In [8]:
# Debug runs use a tiny 2-layer / 2-head / 256-wide model; otherwise the configured size is used.
g_model = ToyTransformer(
    g_tokenizer.get_vocab_size(),
    *((2, 2, 256, 1024) if C_DEBUG else (C_NUM_LAYERS, C_NUM_HEADS, C_HIDDEN_SIZE, C_SEQ_LEN)),
    C_DEVICE, C_DTYPE,
)
# 46M model for tiny stories:
# g_model = ToyTransformer(g_tokenizer.get_vocab_size(), 4, 8, 512, C_SEQ_LEN, C_DEVICE, C_DTYPE)

In [9]:
print('Total parameters:', sum(param.numel() for param in g_model.parameters()))
print(g_model)

Total parameters: 135418880
ToyTransformer(
  (sem_embed): Embedding(32768, 768)
  (decoder_layers): ModuleList(
    (0-11): 12 x DecoderLayer(
      (mha): MultiHeadAttention(
        (attn_heads): ModuleList(
          (0-11): 12 x AttentionHead(
            (q_proj): Linear(in_features=768, out_features=64, bias=True)
            (k_proj): Linear(in_features=768, out_features=64, bias=True)
            (v_proj): Linear(in_features=768, out_features=64, bias=True)
          )
        )
        (o_proj): Linear(in_features=768, out_features=768, bias=True)
      )
      (up_proj): Linear(in_features=768, out_features=3072, bias=True)
      (down_proj): Linear(in_features=3072, out_features=768, bias=True)
      (ln_mha): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (ln_ffn): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (act): GELU(approximate='none')
    )
  )
  (lm_head): Linear(in_features=768, out_features=32768, bias=True)
)


In [10]:
# Bytes held by all parameters (element count times element size), reported in MiB.
total_memory = sum(param.numel() * param.element_size() for param in g_model.parameters())
print(f'Model memory usage: {total_memory / 1024 / 1024:.2f}MB')

Model memory usage: 516.58MB


In [11]:
# Free unreachable Python objects first, then release PyTorch's cached-but-unused CUDA memory.
gc.collect()
torch.cuda.empty_cache()

In [12]:
def dataset_collate(dataset_iter: DatasetIter, batch_size: int,
                    transform: Optional[Callable[[Dict[str, List[np.ndarray]]], Dict[str, torch.Tensor]]] = None,
                    drop_last: bool = False):
    """Group consecutive dataset entries into batches.

    :param dataset_iter: iterable of entries, each a dict mapping a column name to an array.
    :param batch_size: number of entries per yielded batch.
    :param transform: optional function receiving ``{column: [array, ...]}`` and returning the batch;
        when None, each column is ``np.stack``-ed and converted with ``torch.tensor``.
    :param drop_last: if True, a trailing batch with fewer than ``batch_size`` entries is discarded.
    :return: generator of batches.
    """
    def _emit(columns):
        if transform is not None:
            return transform(columns)
        return {name: torch.tensor(np.stack(values)) for name, values in columns.items()}

    columns, filled = {}, 0
    for entry in dataset_iter:
        for name, value in entry.items():
            columns.setdefault(name, []).append(value)
        filled += 1
        if filled == batch_size:
            yield _emit(columns)
            columns, filled = {}, 0
    if columns and not drop_last:
        yield _emit(columns)

In [13]:
@dataclass
class TrainArguments:
    """Hyper-parameters and data for one ``train_model`` run."""
    # number of passes over train_data
    num_epochs: int
    # samples per micro-batch fed to the model
    batch_size: int
    # micro-batches accumulated before each optimizer step (the loss is divided by this)
    gradient_accumulation_steps: int

    # optimizer class, instantiated as optimizer(model.parameters(), **optimizer_args)
    optimizer: Type[torch.optim.Optimizer]
    # extra optimizer keyword arguments; None means no extra arguments
    optimizer_args: Optional[Dict[str, Any]]
    # autocast dtype; None disables autocast, torch.float16 additionally enables a GradScaler
    mixed_precision_dtype: Optional[torch.dtype]

    # OneCycleLR schedule: warm up from start_lr to max_lr over the first warmup_ratio fraction of
    # optimizer steps, then anneal down to end_lr
    start_lr: float
    max_lr: float
    end_lr: float
    warmup_ratio: float

    # max total gradient norm for clipping; None disables clipping
    gradient_clip_norm: Optional[float]
    # optional constant added to the softmax probabilities before taking the log
    probs_epsilon: Optional[float]

    train_data: DatasetReader
    # when True, the batch's 'attn_mask' / 'loss_mask' columns are ignored even if present
    ignore_attn_mask: bool
    ignore_loss_mask: bool

    # TODO: evaluation is not implemented; these fields are disabled.
    # eval_data: Optional[DatasetReader]
    # eval_steps: int
    # 
    # eval_generate_prompt: Optional[str]
    # eval_generate_steps: int

    # save a checkpoint every save_steps optimizer steps; a value <= 0 disables periodic saving
    save_steps: int
    # save a checkpoint when training is stopped through the interrupt lock
    save_on_interrupt: bool


# Integer dtypes that torch.tensor() cannot (or should not) ingest as-is, mapped to a signed type wide
# enough to hold every value: int16/uint16 -> int32, uint32 (unsupported by older torch) -> int64.
_TORCH_SAFE_INT_DTYPES = {
    np.dtype(np.int16): np.int32,
    np.dtype(np.uint16): np.int32,
    np.dtype(np.uint32): np.int64,
}


def train_transform(batch: Dict[str, List[np.ndarray]]):
    """Stack every column of a collated batch into a tensor.

    :param batch: mapping from column name to the list of per-sample arrays.
    :return: mapping from column name to the stacked tensor. int16/uint16 columns become int32,
        uint32 columns become int64; all other dtypes are kept. The target dtype is chosen from
        the first array of each column.
    """
    tensors = {}
    for name, arrays in batch.items():
        target_dtype = _TORCH_SAFE_INT_DTYPES.get(arrays[0].dtype, arrays[0].dtype)
        # astype(copy=False) instead of np.stack(dtype=...), which needs numpy >= 1.24
        tensors[name] = torch.tensor(np.stack(arrays).astype(target_dtype, copy=False))
    return tensors


def save_checkpoint(path: str, model: ToyTransformer,
                    optimizer: torch.optim.Optimizer, lr_scheduler: torch.optim.lr_scheduler.LRScheduler, grad_scaler: Optional[torch.cuda.amp.GradScaler],
                    train_args: TrainArguments,
                    dataset: DatasetReader, dataset_iter: DatasetIter, train_logs: List, misc: Dict):
    """Write a resumable training checkpoint into directory ``path`` (created if missing).

    Writes one file per component: the model / optimizer / LR-scheduler state dicts, the grad-scaler
    state dict (only when a scaler is used), the CPU RNG state, the training arguments, the dataset
    iterator position, the training log and the ``misc`` counters. Also writes a human-readable
    ``config.txt``. ``load_checkpoint`` reads back the same file layout.
    """
    os.makedirs(path, exist_ok=True)
    torch.save(model.state_dict(), path + '/model.pt')
    torch.save(optimizer.state_dict(), path + '/optimizer.pt')
    torch.save(lr_scheduler.state_dict(), path + '/lr_scheduler.pt')
    if grad_scaler is not None:
        # must be the scaler's own state: load_checkpoint feeds this file to grad_scaler.load_state_dict
        torch.save(grad_scaler.state_dict(), path + '/grad_scaler.pt')
    torch.save(torch.get_rng_state(), path + '/rng_state.pt')
    torch.save(train_args, path + '/train_args.pt')
    dataset.save_iterator(dataset_iter, path + '/dataset_iter.pt')
    torch.save(train_logs, path + '/train_logs.pt')
    torch.save(misc, path + '/misc.pt')
    with open(path + '/config.txt', 'w', encoding='utf-8') as file:
        file.write(f'Dataset: {dataset}\n\n')
        file.write(f'Model Config: {model.config}\n\n')
        file.write(f'Training Arguments: {train_args}')


def load_checkpoint(path: str, model: nn.Module,
                    optimizer: torch.optim.Optimizer, lr_scheduler: torch.optim.lr_scheduler.LRScheduler, grad_scaler: Optional[torch.cuda.amp.GradScaler],
                    train_args: TrainArguments,
                    dataset: DatasetReader, dataset_iter: DatasetIter, train_logs: List, misc: Dict):
    """Restore training state written by ``save_checkpoint`` from directory ``path``.

    Everything is restored in place:
    - state dicts are loaded into the model, optimizer, scheduler and (if given) the grad scaler;
    - the global CPU RNG state is reset;
    - ``dataset_iter`` is repositioned;
    - ``train_logs`` is replaced by the saved log;
    - ``misc`` is updated with the saved counters.

    ``train_args`` is currently unused (the equality check against the saved copy is disabled).
    """
    def _read(file_name):
        return torch.load(f'{path}/{file_name}')

    stateful = [(model, 'model.pt'), (optimizer, 'optimizer.pt'), (lr_scheduler, 'lr_scheduler.pt')]
    if grad_scaler is not None:
        stateful.append((grad_scaler, 'grad_scaler.pt'))
    for component, file_name in stateful:
        component.load_state_dict(_read(file_name))

    torch.set_rng_state(_read('rng_state.pt'))
    # assert _read('train_args.pt') == train_args
    saved_iter = dataset.load_iterator(f'{path}/dataset_iter.pt')
    dataset_iter.set_state(saved_iter.get_state())
    train_logs[:] = _read('train_logs.pt')
    misc.update(_read('misc.pt'))


def _compute_lm_loss(logits: torch.Tensor, labels: torch.Tensor,
                     loss_mask: Optional[torch.Tensor], probs_epsilon: Optional[float]) -> torch.Tensor:
    """Mean next-token negative log-likelihood of ``labels`` under ``logits``.

    :param logits: model outputs whose last dimension is the vocabulary.
    :param labels: target token ids, one per logits row (same leading shape as ``logits``).
    :param loss_mask: optional per-target weights with as many elements as ``labels``; when given,
        the weighted loss is averaged over the mask total instead of over every position.
    :param probs_epsilon: optional constant added to the softmax probabilities before the log.
    :return: scalar loss tensor.
    """
    flat_logits = logits.reshape(-1, logits.shape[-1])
    flat_labels = labels.reshape(-1, 1).long()
    if probs_epsilon is None:
        # log_softmax never rounds a tiny probability to 0, so the loss cannot become +inf
        log_probs = torch.log_softmax(flat_logits, dim=-1)
    else:
        # out-of-place add: softmax keeps its output for backward, so an in-place `+=` breaks autograd
        log_probs = torch.log(torch.softmax(flat_logits, dim=-1) + probs_epsilon)
    token_loss = -log_probs.gather(1, flat_labels).squeeze(1)
    if loss_mask is None:
        return token_loss.mean()
    weights = loss_mask.reshape(-1).to(token_loss.dtype)
    # average over supervised tokens only, so the loss scale does not depend on how much is masked
    return (token_loss * weights).sum() / weights.sum().clamp_min(1.0)


def train_model(model: ToyTransformer, train_args: TrainArguments,
                resume_from: Optional[str] = None,
                show_progress: bool = True,
                output_dir: str = 'checkpoints', interrupt_lock: Optional[Lock] = None):
    """Train ``model`` on next-token prediction according to ``train_args``.

    Runs ``train_args.num_epochs`` epochs over micro-batches of ``batch_size`` samples. An optimizer
    step is taken every ``gradient_accumulation_steps`` micro-batches and on each epoch's last
    micro-batch, under a OneCycleLR schedule. Checkpoints are written to
    ``output_dir/checkpoint-<step>`` every ``save_steps`` optimizer steps, and to
    ``output_dir/checkpoint-done`` once all epochs complete.

    :param model: model trained in place; must expose ``device`` and
        ``forward(inputs, attn_mask=...)`` returning ``(logits, kv_state)``.
    :param train_args: training hyper-parameters and data.
    :param resume_from: checkpoint directory written by ``save_checkpoint`` to resume from, or None.
    :param show_progress: show a tqdm progress bar; otherwise print per-step stats.
    :param output_dir: directory receiving checkpoints.
    :param interrupt_lock: lock held by the caller while training should continue. Once it is
        released, training stops after the current optimizer step, saving a checkpoint first if
        ``save_on_interrupt`` is set.
    :return: list of ``(epoch_num, batch_idx, step_stat)`` tuples, one per optimizer step.
    """
    output_dir = output_dir.rstrip('/')

    interrupted = False
    train_logs = []
    misc = {'epochs': 0, 'steps': 0, 'last_batch_idx': -1}

    total_samples = len(train_args.train_data)
    epoch_steps = ceil(total_samples / train_args.batch_size)
    assert epoch_steps >= train_args.gradient_accumulation_steps, \
        f'per-epoch steps {epoch_steps} is less than gradient accumulation steps {train_args.gradient_accumulation_steps}'

    schedule_steps = ceil(total_samples / train_args.batch_size / train_args.gradient_accumulation_steps)
    total_steps = schedule_steps * train_args.num_epochs

    optimizer = train_args.optimizer(model.parameters(), **(train_args.optimizer_args or {}))
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=train_args.max_lr, div_factor=train_args.max_lr / train_args.start_lr,
                                                    total_steps=total_steps,
                                                    final_div_factor=train_args.start_lr / train_args.end_lr, pct_start=train_args.warmup_ratio)

    # loss scaling is only used for fp16 autocast
    grad_scaler = torch.cuda.amp.GradScaler() if train_args.mixed_precision_dtype == torch.float16 else None

    dataset_iter = iter(train_args.train_data)
    if resume_from is not None:
        load_checkpoint(resume_from, model, optimizer, scheduler, grad_scaler, train_args, train_args.train_data, dataset_iter, train_logs, misc)

    def _save(name: str):
        # dataset_iter is looked up at call time, so the current epoch's iterator is saved
        save_checkpoint(f'{output_dir}/{name}', model, optimizer, scheduler, grad_scaler, train_args,
                        train_args.train_data, dataset_iter, train_logs, misc)

    bar = tqdm.tqdm(total=total_steps, smoothing=1.0, disable=not show_progress)
    bar.update(misc['steps'])

    model.train()

    # resumed runs skip the epochs already completed
    for epoch_num in range(misc['epochs'], train_args.num_epochs):
        optimizer.zero_grad()
        batches = dataset_collate(dataset_iter, train_args.batch_size, train_transform)
        for batch_idx, batch in enumerate(batches, start=misc['last_batch_idx'] + 1):
            step_start_time = time.time()

            tokens = batch['token_ids'].to(model.device)
            inputs, labels = tokens[:, :-1], tokens[:, 1:]

            attn_mask = batch['attn_mask'][:, :-1].to(model.device) if 'attn_mask' in batch and not train_args.ignore_attn_mask else None
            loss_mask = batch['loss_mask'][:, 1:].to(model.device) if 'loss_mask' in batch and not train_args.ignore_loss_mask else None

            with torch.autocast(device_type='cuda', dtype=train_args.mixed_precision_dtype, enabled=train_args.mixed_precision_dtype is not None):
                logits, _ = model.forward(inputs, attn_mask=attn_mask)
                loss = _compute_lm_loss(logits, labels, loss_mask, train_args.probs_epsilon) / train_args.gradient_accumulation_steps

            # brutally drop non-finite (nan/inf) losses, giving up the whole micro-batch
            if not torch.isfinite(loss):
                print(f'encountered non-finite loss at epoch {epoch_num + 1}, batch {batch_idx}')
            elif grad_scaler is not None:
                grad_scaler.scale(loss).backward()
            else:
                loss.backward()

            # step only at accumulation boundaries or on the epoch's last micro-batch
            if (batch_idx + 1) % train_args.gradient_accumulation_steps != 0 and (batch_idx + 1) != epoch_steps:
                continue

            if grad_scaler is not None:
                # unscale first so clipping sees the true gradient norm
                grad_scaler.unscale_(optimizer)
            if train_args.gradient_clip_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), train_args.gradient_clip_norm)
            if grad_scaler is not None:
                grad_scaler.step(optimizer)
                grad_scaler.update()
            else:
                optimizer.step()
            optimizer.zero_grad()

            step_time_cost = time.time() - step_start_time
            # tokens of the last micro-batch over its processing time (including the optimizer step)
            throughput = round(labels.numel() / step_time_cost / 1000, 2)

            step_stat = {'Loss': f'{loss.item() * train_args.gradient_accumulation_steps:.3f}',
                         'LR': f'{scheduler.get_last_lr()[0]:.2e}',
                         'Throughput': f'{throughput} kt/s'}

            if show_progress:
                bar.set_description(f'Epoch {epoch_num + 1}')
                bar.set_postfix(step_stat)
            else:
                print(', '.join(f'{name}:{value}' for name, value in step_stat.items()))

            scheduler.step()
            bar.update(1)
            train_logs.append((epoch_num, batch_idx, step_stat))

            misc['steps'] += 1
            misc['last_batch_idx'] = batch_idx
            if train_args.save_steps > 0 and misc['steps'] % train_args.save_steps == 0:
                _save(f'checkpoint-{misc["steps"]}')
            if interrupt_lock is not None and not interrupt_lock.locked():
                if train_args.save_on_interrupt:
                    _save(f'checkpoint-{misc["steps"]}')
                interrupted = True
                break
        if interrupted:
            break
        misc['epochs'] += 1
        misc['last_batch_idx'] = -1
        dataset_iter = iter(train_args.train_data)
    bar.close()

    if not interrupted:
        _save('checkpoint-done')

    return train_logs


def train_model_interruptable(model: nn.Module, train_args: TrainArguments,
                              resume_from: Optional[str] = None,
                              show_progress: bool = True,
                              output_dir: str = 'checkpoints'):
    """Run ``train_model`` on a worker thread so Ctrl-C (KeyboardInterrupt) stops training gracefully.

    The main thread holds an interrupt lock while training runs. On KeyboardInterrupt it releases
    the lock, which ``train_model`` checks after every optimizer step. Exceptions raised by training
    are printed and swallowed.

    :return: the value returned by ``train_model`` (its training log), or None if training raised.
    """
    return_value = None

    def return_value_wrapper(func, *args, **kwargs):
        nonlocal return_value
        # noinspection PyBroadException
        try:
            return_value = func(*args, **kwargs)
        except Exception as _:
            # best effort: report the failure; the caller then receives None
            traceback.print_exc()

    interrupt_lock = Lock()
    interrupt_lock.acquire()
    thread = Thread(target=return_value_wrapper, args=(train_model, model, train_args),
                    kwargs={'resume_from': resume_from, 'show_progress': show_progress, 'output_dir': output_dir, 'interrupt_lock': interrupt_lock})
    thread.start()
    try:
        # is_alive() also turns False if the worker dies from a BaseException, so this loop always ends;
        # the whole loop is guarded so Ctrl-C arriving between sleeps is handled too
        while thread.is_alive():
            time.sleep(0.1)
    except KeyboardInterrupt as _:
        # releasing the lock asks train_model to stop after its current optimizer step
        interrupt_lock.release()
    thread.join()

    return return_value

In [14]:
# Debug and regular runs share the optimizer / precision setup and differ in schedule, masking and checkpointing.
g_train_args = TrainArguments(
    optimizer=torch.optim.AdamW, optimizer_args=None,
    mixed_precision_dtype=torch.bfloat16,
    warmup_ratio=0.1, probs_epsilon=None,
    train_data=g_train_data,
    # eval_data=None, eval_steps=-1,
    # eval_generate_prompt=None, eval_generate_steps=-1,
    **(
        dict(num_epochs=1000, batch_size=8, gradient_accumulation_steps=1,
             start_lr=1e-5, max_lr=1e-3, end_lr=1e-6,
             gradient_clip_norm=1.0,
             ignore_attn_mask=False, ignore_loss_mask=False,
             save_steps=-1, save_on_interrupt=False)
        if C_DEBUG else
        dict(num_epochs=2, batch_size=12, gradient_accumulation_steps=8,
             start_lr=5e-5, max_lr=1e-3, end_lr=1e-4,
             gradient_clip_norm=0.7,
             ignore_attn_mask=True, ignore_loss_mask=True,
             save_steps=1000, save_on_interrupt=True)
    ),
)

In [None]:
if C_DEBUG:
    global_config['attn_backend'] = AttentionBackend.FlashAttentionTriton
    g_train_args.ignore_attn_mask = True
    g_train_logs = train_model_interruptable(g_model, g_train_args, resume_from=None,
                                             show_progress=True, output_dir='checkpoints/debug_output')
else:
    # SFT round: start from the pre-trained 135M weights and override the pre-training schedule.
    g_model.load_state_dict(torch.load('checkpoints/train-round1-135M/checkpoint-19000/model.pt'))
    global_config['attn_backend'] = AttentionBackend.FlashAttentionTriton
    vars(g_train_args).update(
        ignore_attn_mask=False, ignore_loss_mask=False,
        num_epochs=3, save_steps=500, warmup_ratio=0.1,
        batch_size=10, gradient_accumulation_steps=4,
        start_lr=1e-6, max_lr=1e-4, end_lr=1e-5,
    )
    g_train_logs = train_model_interruptable(g_model, g_train_args, resume_from=None,
                                             show_progress=True, output_dir='checkpoints/train-round1-135M-sft')

  0%|          | 0/3981 [00:00<?, ?it/s]

In [None]:
def plot_train_logs(train_logs_list, labels=None):
    """Plot loss, learning-rate and throughput curves for one or more training runs.

    :param train_logs_list: sequence of training logs, one per run. Each log is a list of
        ``(epoch_num, batch_idx, step_stat)`` tuples as appended by ``train_model``, where
        ``step_stat`` holds the formatted 'Loss', 'LR' and 'Throughput' (``'<value> kt/s'``) strings.
    :param labels: optional run names, one per log; when given, every panel gets a legend.
    """
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    panels = (
        ('Loss', 'Loss', float),
        ('LR', 'Learning Rate', float),
        # strip the ' kt/s' unit suffix before parsing
        ('Throughput', 'Throughput (kt/s)', lambda text: float(text[:-5])),
    )

    for run_idx, train_logs in enumerate(train_logs_list):
        run_label = labels[run_idx] if labels is not None else None
        for ax, (key, _, parse) in zip(axes, panels):
            ax.plot([parse(entry[2][key]) for entry in train_logs], label=run_label)

    for ax, (_, title, _) in zip(axes, panels):
        ax.set_title(title)
        # train_model appends one log entry per optimizer step
        ax.set_xlabel('Optimizer step')
        if labels is not None:
            ax.legend()

    fig.tight_layout()
    plt.show()

In [None]:
# Training logs of the runs to compare, as saved by save_checkpoint.
checkpoint_list = [
    'checkpoints/train-round1/checkpoint-19000/train_logs.pt',
    'checkpoints/train-round1-masked/checkpoint-1199/train_logs.pt',
]
logs_list = list(map(torch.load, checkpoint_list))

plot_train_logs(logs_list)

In [None]:
@torch.autocast(device_type='cuda', dtype=torch.bfloat16)
def generate(model, tokenizer, prompt, temperature, top_p, rep_penalty,
             max_new_tokens=20, total_tokens=None,
             end_tokens=None,
             enable_kv_cache=True):
    model.eval()

    feed_tokens = tokenizer.encode(prompt) if isinstance(prompt, str) else prompt
    all_tokens = feed_tokens.copy()
    if total_tokens is not None:
        max_new_tokens = max(0, total_tokens - len(feed_tokens))

    with torch.no_grad():
        kv_cache = None
        for _ in range(max_new_tokens):
            logits, kv_cache = model.forward(
                torch.tensor([feed_tokens if enable_kv_cache else all_tokens]).to(model.device),
                kv_cache=kv_cache)
            logits = logits[0][-1].cpu()
            if not enable_kv_cache:
                kv_cache = None

            # apply repetition penalty
            logits_rep = torch.gather(logits, 0, torch.tensor(all_tokens))
            logits_rep = torch.where(logits_rep < 0, logits_rep * rep_penalty, logits_rep / rep_penalty)
            logits.scatter_(0, torch.tensor(all_tokens), logits_rep)

            # apply temperature
            logits /= max(temperature, 1e-6)

            probs = torch.softmax(logits, dim=0)

            # apply top-p
            ordered_probs, ordered_indices = torch.sort(probs, descending=True)
            cum_probs = torch.cumsum(ordered_probs, dim=0).tolist()
            top_p_index = bisect.bisect_right(cum_probs, top_p) + 1
            ordered_probs, ordered_indices = ordered_probs[:top_p_index], ordered_indices[:top_p_index]
            sampled_index = ordered_indices[torch.multinomial(ordered_probs, num_samples=1).item()].item()

            all_tokens.append(sampled_index)
            feed_tokens = [sampled_index]

            if end_tokens is not None and sampled_index in end_tokens:
                break

    return all_tokens

In [None]:
def modeling_sanity_check(gen_length: int, enable_kv_cache: bool):
    """Check that near-greedy generation reproduces the first training sample.

    Prompts the model with the sample's first 10 tokens and generates up to ``gen_length`` tokens
    in total. It then prints the timing, whether the token sequences agree on their common prefix,
    and the same comparison per '</s>'-separated text segment.

    :param gen_length: total sequence length (prompt included) to generate.
    :param enable_kv_cache: forwarded to ``generate``.
    """
    assert C_DEBUG == True, 'sanity check can only be performed under debug settings'
    reference_ids = next(iter(g_train_data))['token_ids'].tolist()
    reference_segments = [segment.strip() for segment in g_tokenizer.decode(reference_ids).split('</s>')]

    started = time.time()
    generated_ids = generate(g_model, g_tokenizer, reference_ids[:10],
                             temperature=1.0, top_p=0.01, rep_penalty=1.0,
                             total_tokens=gen_length,
                             end_tokens=g_tokenizer.encode('<reserved_0>'),
                             enable_kv_cache=enable_kv_cache)
    elapsed = time.time() - started
    print(f'Generation finished in {elapsed:.2f} sec(s), throughput: {len(generated_ids) / elapsed:.1f} tokens/sec')

    # whole-sequence check over the common prefix
    overlap = min(len(reference_ids), len(generated_ids))
    print('Complete Identical:', reference_ids[:overlap] == generated_ids[:overlap])

    # per-segment check on the decoded text
    generated_segments = [segment.strip() for segment in g_tokenizer.decode(generated_ids).split('</s>')]
    for i, (ref, real) in enumerate(zip(reference_segments, generated_segments)):
        overlap = min(len(ref), len(real))
        print(f'Segment {i}: Ref Len: {len(ref)}, Gen Len: {len(real)} Identical: {ref[:overlap] == real[:overlap]}')

In [None]:
# Run the sanity check for every attention backend, with and without the KV cache.
for backend_option, enable_kv_cache_option in [
    (backend, kv_flag)
    for backend in [AttentionBackend.Naive, AttentionBackend.FlashAttentionTriton, AttentionBackend.FlashAttentionCuda]
    for kv_flag in [True, False]
]:
    print(f'Backend = {backend_option}, KVCache Enable = {enable_kv_cache_option}')
    global_config['attn_backend'] = backend_option
    modeling_sanity_check(512, enable_kv_cache_option)
    print('=' * 80)

In [None]:
# Show the first three '</s>'-separated text segments of the first training sample, one per line.
print(*g_tokenizer.decode(next(iter(g_train_data))['token_ids'].tolist()).split('</s>')[:3], sep='\n')

In [None]:
# Select the backend before starting the timer so only generation is measured.
global_config['attn_backend'] = AttentionBackend.FlashAttentionTriton
# perf_counter is monotonic and high-resolution, unlike time.time(), which can jump with clock changes.
time_start = time.perf_counter()
result = generate(g_model, g_tokenizer, '<s>A chat between User and Assistant.\nUser:What is python?\nAssistant:',
                  temperature=1.0, top_p=0.01, rep_penalty=1.1,
                  total_tokens=128,
                  end_tokens=g_tokenizer.encode('</s>'),
                  enable_kv_cache=True)
time_cost = time.perf_counter() - time_start

print(g_tokenizer.decode(result))
# NOTE: len(result) includes the prompt tokens, so this overstates pure generation throughput.
print(f'{time_cost:.3f} sec(s), throughput {len(result) / time_cost:.1f} tokens/sec')

In [None]:
# batch 2160

In [None]:
# NOTE(review): loads weights from a local debug snapshot; presumably the state just before the
# failing batch investigated below — confirm how debug_model.pt was produced.
g_model.load_state_dict(torch.load('debug_model.pt'))

In [None]:
# Rebuild the training batch stream with batch size 10 and skip ahead to the batch under investigation.
# NOTE(review): this skips 2161 batches, so `m` is the batch with 0-based index 2161, while the
# "batch 2160" note and train_model's printed batch indices are 0-based — possible off-by-one, confirm.
t_iter = dataset_collate(iter(g_train_data), 10, train_transform)
for i in range(2161):
    next(t_iter)
m = next(t_iter)

In [None]:
# Use the naive attention implementation for the manual forward pass that follows.
global_config['attn_backend'] = AttentionBackend.Naive

In [None]:
# Recompute the training loss for batch `m` token by token (no reduction) to locate outliers.
with torch.no_grad():
    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
        # pass the mask by keyword, exactly like train_model, so it cannot bind to another positional parameter
        x, _ = g_model.forward(m['token_ids'][:, :-1].cuda(), attn_mask=m['attn_mask'][:, :-1].cuda())
        p = torch.softmax(x, dim=2).view(-1, x.shape[-1])
        l = (-torch.log(p[torch.arange(p.shape[0]), m['token_ids'][:, 1:].cuda().reshape(-1)]))
        #l *= m['loss_mask'][:, 1:].cuda().reshape(-1)
        # restore (batch, target positions, vocab) from the logits' own shape instead of hard-coded sizes
        p = p.view(x.shape)
        l = l.view(x.shape[:-1])

In [None]:
# Largest per-token loss in each sequence (values and positions along the token axis).
torch.topk(l, 1)

In [None]:
# Target token at position 71 of sequence 0 vs. the model's most probable token there
# (71 presumably read off the topk output above — TODO confirm).
m['token_ids'][:, 1:][0, 71], p[0, 71].argmax()

In [None]:
# Decode one full sequence of the inspected batch (index 6 presumably picked from the topk output — TODO confirm).
print(g_tokenizer.decode(m['token_ids'][6].tolist()))

In [None]:
# Shape of the per-token loss matrix: one row per sequence, one column per target position.
l.shape

In [None]:
# Text of a single token id of interest (30474 presumably came from the inspection above — TODO confirm).
g_tokenizer.decode([30474])