# Importing libraries

In [None]:
@triton.jit
def flash_decoding_stage_1(
    Q, K, V,
    FIXMAX, ACC, SUM,
    stride_q_b, stride_q_h, stride_q_t, stride_q_d,
    stride_k_b, stride_k_h, stride_k_t, stride_k_d,
    stride_v_b, stride_v_h, stride_v_t, stride_v_d,
    stride_m_b, stride_m_h,
    stride_acc_b, stride_acc_h, stride_acc_d,
    stride_sum_b, stride_sum_h,
    B, H, T, D,
    BLOCK_N: tl.constexpr, BLOCK_D: tl.constexpr
):
    pid_bh = tl.program_id(0)
    pid_t  = tl.program_id(1)

    off_h = pid_bh % H
    off_z = pid_bh // H

    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_D)

    q_ptrs = Q + off_z*stride_q_b + off_h*stride_q_h + offs_d*stride_q_d
    k_base = K + off_z*stride_k_b + off_h*stride_k_h
    v_base = V + off_z*stride_v_b + off_h*stride_v_h

    # fixed max
    max_ptr = FIXMAX + off_z*stride_m_b + off_h*stride_m_h
    fixed_max = tl.load(max_ptr).to(tl.float32)

    # tile range on T
    start_n = pid_t * BLOCK_N
    remain = T - start_n
    block_size = tl.maximum(0, tl.minimum(remain, BLOCK_N))
    mask_n = offs_n < block_size

    # load q (fp32)
    q = tl.load(q_ptrs, mask=offs_d < D, other=0.0).to(tl.float32)

    # scale
    scale = 1.0 / tl.sqrt(tl.full((), D, tl.float32))

    # partial accumulators
    part_acc = tl.zeros([BLOCK_D], dtype=tl.float32)
    part_sum = tl.zeros((), dtype=tl.float32)

    # load k,v tile (fp32)
    k = tl.load(
        k_base + (start_n + offs_n)[:, None]*stride_k_t + offs_d[None, :]*stride_k_d,
        mask=mask_n[:, None] & (offs_d[None, :] < D),
        other=0.0,
    ).to(tl.float32)
    v = tl.load(
        v_base + (start_n + offs_n)[:, None]*stride_v_t + offs_d[None, :]*stride_v_d,
        mask=mask_n[:, None] & (offs_d[None, :] < D),
        other=0.0,
    ).to(tl.float32)

    # scores
    scores = tl.sum(k * q[None, :], axis=1) * scale
    neg_inf = tl.full(scores.shape, -float("inf"), scores.dtype)
    scores = tl.where(mask_n, scores, neg_inf)

    # probs with fixed max
    probs = tl.exp(scores - fixed_max)
    probs = tl.where(mask_n, probs, 0.0)

    # partial sums
    part_sum += tl.sum(probs, axis=0)
    part_acc += tl.sum(probs[:, None] * v, axis=0)

    # atomic accumulate into ACC[B,H,D] and SUM[B,H]
    acc_ptrs = ACC + off_z*stride_acc_b + off_h*stride_acc_h + offs_d*stride_acc_d
    sum_ptr  = SUM + off_z*stride_sum_b + off_h*stride_sum_h

    # only valid D range
    tl.atomic_add(acc_ptrs, part_acc, mask=offs_d < D)
    tl.atomic_add(sum_ptr,  part_sum)

@triton.jit
def flash_decoding_stage_2(
    ACC, SUM,
    O,
    stride_acc_b, stride_acc_h, stride_acc_d,
    stride_sum_b, stride_sum_h,
    stride_o_b, stride_o_h, stride_o_t, stride_o_d,
    B, H, D,
    BLOCK_D: tl.constexpr
):
    pid_bh = tl.program_id(0)
    off_h = pid_bh % H
    off_z = pid_bh // H

    offs_d = tl.arange(0, BLOCK_D)

    acc_ptrs = ACC + off_z*stride_acc_b + off_h*stride_acc_h + offs_d*stride_acc_d
    sum_ptr  = SUM + off_z*stride_sum_b + off_h*stride_sum_h
    o_ptrs   = O   + off_z*stride_o_b   + off_h*stride_o_h   + offs_d*stride_o_d  # t=0

    acc = tl.load(acc_ptrs, mask=offs_d < D, other=0.0)
    s   = tl.load(sum_ptr)
    eps = tl.full((), 1e-20, tl.float32)
    denom = tl.maximum(s, eps)
    out = acc / denom

    tl.store(o_ptrs, out.to(O.dtype.element_ty), mask=offs_d < D)

def flash_decoding_2(q, k, v, fixmax: float = 10):
    B, H, _, D = q.shape
    T = k.shape[2]
    n_tiles: int = (T + BLOCK_N - 1) // BLOCK_N

    ACC = torch.zeros((B, H, D), dtype=torch.float32, device=q.device)
    SUM = torch.zeros((B, H),    dtype=torch.float32, device=q.device)

    grid_partial = (B*H, n_tiles)
    flash_decoding_stage_1[grid_partial](
        q, k, v, fixmax, ACC, SUM,
        # strides ... (ACC/SUM 포함),
        B, H, T, D,
        BLOCK_N=BLOCK_N, BLOCK_D=BLOCK_D,
        num_warps=4, num_stages=2
    )

    o = torch.empty_like(q)
    grid_finalize = (B*H,)
    flash_decoding_stage_2[grid_finalize](
        ACC, SUM, o,
        # strides ...,
        B, H, D,
        BLOCK_D=BLOCK_D,
        num_warps=1
    )
    return o

@triton.jit
def flash_decoding_2_kernel(
        Q,     # [B, H, 1, D]
        K, V,  # [B, H, T, D]
        O,     # [B, H, 1, D]
        FIXMAX,
        stride_q_b, stride_q_h, stride_q_t, stride_q_d,
        stride_k_b, stride_k_h, stride_k_t, stride_k_d,
        stride_v_b, stride_v_h, stride_v_t, stride_v_d,
        stride_o_b, stride_o_h, stride_o_t, stride_o_d,
        stride_m_b, stride_m_h,
        B, H, T, D,
        BLOCK_N: tl.constexpr, BLOCK_D: tl.constexpr
):
    # Program IDs
    off_hz = tl.program_id(0)
    off_h = off_hz % H
    off_z = off_hz // H

    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_D)

    # Initialize pointers for this batch and head
    q_ptrs = Q + off_z * stride_q_b + off_h * stride_q_h + offs_d * stride_q_d
    k_ptrs = K + off_z * stride_k_b + off_h * stride_k_h + offs_n[:, None] * stride_k_t + offs_d[None, :] * stride_k_d
    v_ptrs = V + off_z * stride_v_b + off_h * stride_v_h + offs_n[:, None] * stride_v_t + offs_d[None, :] * stride_v_d
    o_ptrs = O + off_z * stride_o_b + off_h * stride_o_h + offs_d * stride_o_d

    #
    max_ptr = FIXMAX + off_z * stride_m_b + off_h * stride_m_h
    fixed_max = tl.load(max_ptr, mask=tl.full((), True, tl.int1), other=0.0).to(tl.float32)

    # Load single query vector
    q = tl.load(q_ptrs, mask=offs_d < D, other=0.0).to(tl.float32)  # [BLOCK_D]

    # Initialize output accumulator and softmax statistics
    acc = tl.zeros([BLOCK_D], dtype=tl.float32)
    sum_exp = tl.zeros((), dtype=tl.float32)

    # Scale
    scale = 1.0 / tl.sqrt(tl.full((), D, tl.float32))

    # Loop over K, V blocks
    start_n = 0
    while start_n < T:
        # Calculate current block bounds
        remain = T - start_n
        block_size = tl.minimum(remain, BLOCK_N)
        mask_n = offs_n < block_size

        # Load K, V blocks
        k = tl.load(
            k_ptrs + start_n * stride_k_t,
            mask=mask_n[:, None] & (offs_d[None, :] < D),
            other=0.0,
        )  # [BLOCK_N, BLOCK_D]
        v = tl.load(
            v_ptrs + start_n * stride_v_t,
            mask=mask_n[:, None] & (offs_d[None, :] < D),
            other=0.0,
        )  # [BLOCK_N, BLOCK_D]

        # Compute attention scores: q @ K^T
        scores = tl.sum(k * q[None, :], axis=1) * scale  # [BLOCK_N]

        # Apply causal mask if needed (for decoding, usually all positions are valid)
        scores = tl.where(mask_n, scores, -float('inf'))
        neg_inf = tl.full(scores.shape, -float("inf"), scores.dtype)
        scores = tl.where(mask_n, scores, neg_inf)

        # Compute probabilities for current block
        probs = tl.exp(scores - fixed_max)  # [BLOCK_N]
        probs = tl.where(mask_n, probs, 0.0)

        # Update accumulator and sum
        sum_exp += tl.sum(probs, axis=0)
        acc += tl.sum(probs[:, None] * v, axis=0)  # [BLOCK_D]

        start_n += BLOCK_N

    # Final normalization
    eps = tl.full((), 1e-20, tl.float32)
    denom = tl.maximum(sum_exp, eps)
    out = acc / denom

    # Store output
    tl.store(o_ptrs, out.to(O.dtype.element_ty), mask=offs_d < D)

def flash_decoding_23(q, k, v, fixmax: float = 10):
    """
    Flash Decoding 2 Triton Kernel

    Args:
        q: Query tensor [batch, n_heads, 1, d_head]
        k: Key tensor   [batch, n_heads, seq_len_k, d_head]
        v: Value tensor [batch, n_heads, seq_len_v, d_head]
        fixmax:

    Returns:
        o: Output tensor [batch, n_heads, 1, d_head]
    """
    assert q.ndim == 4 and k.ndim == 4 and v.ndim == 4
    batch_size, n_heads, q_seq_len, d_head = q.shape
    assert q_seq_len == 1, "This kernel assumes q_len=1 (decode step)."
    kv_seq_len = k.shape[2]

    q = q.contiguous()
    k = k.contiguous()
    v = v.contiguous()
    fixmax = torch.full((batch_size, n_heads), float(fixmax), dtype=torch.float32, device=q.device)

    o = torch.empty_like(q)
    grid = (batch_size * n_heads,)

    flash_decoding_2_kernel[grid](
        q, k, v, o,
        fixmax,
        q.stride(0), q.stride(1), q.stride(2), q.stride(3),  # Q strides [B, H, 1, D]
        k.stride(0), k.stride(1), k.stride(2), k.stride(3),  # K strides [B, H, N, D]
        v.stride(0), v.stride(1), v.stride(2), v.stride(3),  # V strides [B, H, N, D]
        o.stride(0), o.stride(1), o.stride(2), o.stride(3),  # O strides [B, H, 1, D]
        fixmax.stride(0), fixmax.stride(1),
        batch_size, n_heads, kv_seq_len, d_head,
        BLOCK_N=BLOCK_N, BLOCK_D=BLOCK_D,
        num_warps=4, num_stages=2
    )
    return o

@triton.jit
def flash_decoding_2_kernel(
        Q,     # [B, H, 1, D]
        K, V,  # [B, H, T, D]
        O,     # [B, H, 1, D]
        FIXMAX,
        stride_q_b, stride_q_h, stride_q_t, stride_q_d,
        stride_k_b, stride_k_h, stride_k_t, stride_k_d,
        stride_v_b, stride_v_h, stride_v_t, stride_v_d,
        stride_o_b, stride_o_h, stride_o_t, stride_o_d,
        stride_m_b, stride_m_h,
        B, H, T, D,
        BLOCK_N: tl.constexpr, BLOCK_D: tl.constexpr
):
    # Program IDs
    off_hz = tl.program_id(0)
    off_h = off_hz % H
    off_z = off_hz // H

    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_D)

    # Initialize pointers for this batch and head
    q_ptrs = Q + off_z * stride_q_b + off_h * stride_q_h + offs_d * stride_q_d
    k_ptrs = K + off_z * stride_k_b + off_h * stride_k_h + offs_n[:, None] * stride_k_t + offs_d[None, :] * stride_k_d
    v_ptrs = V + off_z * stride_v_b + off_h * stride_v_h + offs_n[:, None] * stride_v_t + offs_d[None, :] * stride_v_d
    o_ptrs = O + off_z * stride_o_b + off_h * stride_o_h + offs_d * stride_o_d

    #
    max_ptr = FIXMAX + off_z * stride_m_b + off_h * stride_m_h
    fixed_max = tl.load(max_ptr, mask=tl.full((), True, tl.int1), other=0.0).to(tl.float32)

    # Load single query vector
    q = tl.load(q_ptrs, mask=offs_d < D, other=0.0).to(tl.float32)  # [BLOCK_D]

    # Initialize output accumulator and softmax statistics
    acc = tl.zeros([BLOCK_D], dtype=tl.float32)
    sum_exp = tl.zeros((), dtype=tl.float32)

    # Scale
    scale = 1.0 / tl.sqrt(tl.full((), D, tl.float32))

    # Loop over K, V blocks
    start_n = 0
    while start_n < T:
        # Calculate current block bounds
        remain = T - start_n
        block_size = tl.minimum(remain, BLOCK_N)
        mask_n = offs_n < block_size

        # Load K, V blocks
        k = tl.load(
            k_ptrs + start_n * stride_k_t,
            mask=mask_n[:, None] & (offs_d[None, :] < D),
            other=0.0,
        )  # [BLOCK_N, BLOCK_D]
        v = tl.load(
            v_ptrs + start_n * stride_v_t,
            mask=mask_n[:, None] & (offs_d[None, :] < D),
            other=0.0,
        )  # [BLOCK_N, BLOCK_D]

        # Compute attention scores: q @ K^T
        scores = tl.sum(k * q[None, :], axis=1) * scale  # [BLOCK_N]

        # Apply causal mask if needed (for decoding, usually all positions are valid)
        scores = tl.where(mask_n, scores, -float('inf'))
        neg_inf = tl.full(scores.shape, -float("inf"), scores.dtype)
        scores = tl.where(mask_n, scores, neg_inf)

        # Compute probabilities for current block
        probs = tl.exp(scores - fixed_max)  # [BLOCK_N]
        probs = tl.where(mask_n, probs, 0.0)

        # Update accumulator and sum
        sum_exp += tl.sum(probs, axis=0)
        acc += tl.sum(probs[:, None] * v, axis=0)  # [BLOCK_D]

        start_n += BLOCK_N

    # Final normalization
    eps = tl.full((), 1e-20, tl.float32)
    denom = tl.maximum(sum_exp, eps)
    out = acc / denom

    # Store output
    tl.store(o_ptrs, out.to(O.dtype.element_ty), mask=offs_d < D)

def flash_decoding_2(q, k, v, fixmax: float = 10):
    """
    Flash Decoding 2 Triton Kernel

    Args:
        q: Query tensor [batch, n_heads, 1, d_head]
        k: Key tensor   [batch, n_heads, seq_len_k, d_head]
        v: Value tensor [batch, n_heads, seq_len_v, d_head]
        fixmax:

    Returns:
        o: Output tensor [batch, n_heads, 1, d_head]
    """
    assert q.ndim == 4 and k.ndim == 4 and v.ndim == 4
    batch_size, n_heads, q_seq_len, d_head = q.shape
    assert q_seq_len == 1, "This kernel assumes q_len=1 (decode step)."
    kv_seq_len = k.shape[2]

    q = q.contiguous()
    k = k.contiguous()
    v = v.contiguous()
    fixmax = torch.full((batch_size, n_heads), float(fixmax), dtype=torch.float32, device=q.device)

    o = torch.empty_like(q)
    grid = (batch_size * n_heads,)

    flash_decoding_2_kernel[grid](
        q, k, v, o,
        fixmax,
        q.stride(0), q.stride(1), q.stride(2), q.stride(3),  # Q strides [B, H, 1, D]
        k.stride(0), k.stride(1), k.stride(2), k.stride(3),  # K strides [B, H, N, D]
        v.stride(0), v.stride(1), v.stride(2), v.stride(3),  # V strides [B, H, N, D]
        o.stride(0), o.stride(1), o.stride(2), o.stride(3),  # O strides [B, H, 1, D]
        fixmax.stride(0), fixmax.stride(1),
        batch_size, n_heads, kv_seq_len, d_head,
        BLOCK_N=BLOCK_N, BLOCK_D=BLOCK_D,
        num_warps=4, num_stages=2
    )
    return o

In [1]:
import os
import sys
from typing import Tuple
import torch
from torch.utils.data import Dataset, DataLoader
from src.utils import set_seed, load_text, split_text
from src.config import ModelConfig, TrainConfig, GenerationConfig
from src.train import Trainer
from tokenizer.tokenizer import CharTokenizer
from models.GPT import GPT

In [2]:
PROJECT_ROOT = os.path.abspath(os.getcwd() + "/..")
sys.path.append(PROJECT_ROOT)
print(f"PROJECT_ROOT: {PROJECT_ROOT}")

PROJECT_ROOT: /home/pathfinder/projects/PathFinder


# Configuration

In [3]:
model_config = ModelConfig(
    vocab_size=-1,
    max_seq_len=128,
    d_embed=256,
    n_layers=4,
    attn_type="MHA",
    n_heads=4,
    d_head=64,
    attn_bias=False,
    d_ff=1024,
    mlp_bias=False,
    flash=True,
    flash_decode=True,
    cla=False
)

train_config = TrainConfig(
    debug=False,
    wandb_project="nanoGPT",
    model_name="nanoGPT",
    per_device_train_batch_size=512,
    per_device_eval_batch_size=1024,
    gradient_accumulation_steps=512 // 512,
    num_train_epochs=1,
    learning_rate=5e-4,
    weight_decay=0.01,
    attn_decay=0.5,
    eval_steps=100,
    mixed_precision=True,
    matmul_precision="high",
)

generation_config = GenerationConfig(
    use_cache=True,
    max_new_tokens=1000,
    temperature=1.0,
    top_k=50
)

# Utils

## Reproducibility

In [4]:
set_seed(train_config.seed)

Random seed set to 42


## Device

In [5]:
device = torch.device("cuda")
print(f"Device: {torch.cuda.get_device_name(device)}")
torch.set_float32_matmul_precision(train_config.matmul_precision)  # Tensor Cores
print(f"MatMul Precision: {train_config.matmul_precision}")

Device: NVIDIA GeForce RTX 4080 SUPER
MatMul Precision: high


# Dataset

In [6]:
dataset_path = os.path.join(PROJECT_ROOT, "datasets/Shakespeare/shakespeare.txt")
shakespeare_text = load_text(dataset_path)

Loaded text data from /home/pathfinder/projects/PathFinder/datasets/Shakespeare/shakespeare.txt (length: 1115394 characters).


In [7]:
if train_config.debug:
    subset_shakespeare_text = shakespeare_text[:10000]
    print(subset_shakespeare_text)
    shakespeare_text = subset_shakespeare_text

# Tokenizer

In [8]:
char_tokenizer = CharTokenizer()
char_tokenizer.build_vocab(text=shakespeare_text)
vocab_path = os.path.join(PROJECT_ROOT, "datasets/Shakespeare/vocab.json")
char_tokenizer.save_vocab(vocab_path)
model_config.vocab_size = char_tokenizer.vocab_size

Vocabulary size: 69
Vocabulary saved to /home/pathfinder/projects/PathFinder/datasets/Shakespeare/vocab.json.


In [9]:
if train_config.debug:
    print("Vocabulary:", char_tokenizer.char2idx)

# Preprocessing

In [10]:
train_text, val_text = split_text(shakespeare_text, val_size=0.1)
print(f"Training text length: {len(train_text)} characters")
print(f"Validation text length: {len(val_text)} characters")

Training text length: 1003854 characters
Validation text length: 111540 characters


In [11]:
class TextDataset(Dataset):
    def __init__(self, text: str, tokenizer: CharTokenizer, max_seq_len: int):
        self.encoded = tokenizer.encode(text)
        self.max_seq_len = max_seq_len

    def __len__(self) -> int:
        return len(self.encoded) - self.max_seq_len

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        input_ids = self.encoded[idx:idx + self.max_seq_len]
        target_ids = self.encoded[idx + 1:idx + self.max_seq_len + 1]
        return input_ids, target_ids

def collate_fn(batch):
    input_ids = torch.stack([item[0] for item in batch])
    target_ids = torch.stack([item[1] for item in batch])
    return {
        "input_ids": input_ids,
        #"attention_mask": attention_mask,
        "target_ids": target_ids
    }

train_dataset = TextDataset(train_text, char_tokenizer, model_config.max_seq_len)
val_dataset = TextDataset(val_text, char_tokenizer, model_config.max_seq_len)

train_loader = DataLoader(
    train_dataset,
    collate_fn=collate_fn,
    batch_size=train_config.per_device_eval_batch_size,
    shuffle=True,
    num_workers=4
)
val_loader = DataLoader(
    val_dataset,
    collate_fn=collate_fn,
    batch_size=train_config.per_device_eval_batch_size,
    shuffle=False,
    num_workers=4
)

In [12]:
if train_config.debug:
    sample_batch = next(iter(train_loader))
    print(f"Sample input IDs: {sample_batch['input_ids'][0]}")
    print(f"Sample target IDs: {sample_batch['target_ids'][0]}")

# Model

In [13]:
# Initialize the model
model = GPT(model_config).to(device)
model = torch.compile(model)
print(model)
print(f"Number of parameters: {model.get_num_params() / 1e6:.2f}M")

OptimizedModule(
  (_orig_mod): GPT(
    (token_embedding): Embedding(69, 256)
    (positional_encoding): Embedding(128, 256)
    (dropout): Dropout(p=0.01, inplace=False)
    (blocks): ModuleList(
      (0-3): 4 x Block(
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (attn): MultiHeadAttention(
          (qkv_proj): Linear(in_features=256, out_features=768, bias=False)
          (out_proj): Linear(in_features=256, out_features=256, bias=False)
          (dropout): Dropout(p=0.01, inplace=False)
        )
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (mlp): FeedForward(
          (fc1): Linear(in_features=256, out_features=1024, bias=False)
          (fc2): Linear(in_features=1024, out_features=256, bias=False)
          (activation): GELU(approximate='none')
          (dropout): Dropout(p=0.01, inplace=False)
        )
      )
    )
    (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (lm_head): Linear(i

# Training

In [14]:
trainer = Trainer(
    model=model,
    train_config=train_config,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    master_process=True
)
trainer.train()

[34m[1mwandb[0m: Currently logged in as: [33mpathfinderkr[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Training: 100%|██████████| 981/981 [01:34<00:00, 10.41it/s, epoch=1, grad_norm=0.3982, loss=1.7068, lr=0.000000]


0,1
Grad Norm,█▂▃▃▂▁▁▄▂▂▃▃▂▃▃▃▃▂▂▄▂▃▃▂▂▂▂▂▁▂▂▂▁▁▁▁▁▁▁▁
Learning Rate,▄▄▄▅▆██████▇▇▇▇▆▆▆▆▆▅▄▄▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁
Train Loss,█▅▅▅▅▄▄▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁
Val Loss,█▆▄▃▂▂▁▁▁▁
Val Perplexity,█▅▄▃▂▁▁▁▁▁

0,1
Grad Norm,0.3982
Learning Rate,0.0
Train Loss,1.70679
Val Loss,1.82508
Val Perplexity,6.20332


## Save the model

In [15]:
if not train_config.debug:
    pass
    #output_dir = os.path.join(PROJECT_ROOT, "checkpoints", train_config.model_name, train_config.run_name)
    #os.makedirs(output_dir, exist_ok=True)
    #try:
    #    model.save_pretrained(
    #        output_dir,
    #        safe_serialization=True
    #    )
    #    print("Model saved successfully")
    #xcept Exception as e:
    #    print(f"Error saving model: {e}")
    # Push to Hugging Face Hub
    #model.push_to_hub(
    #    repo_id=f"PathFinderKR/{train_config.model_name}-{train_config.run_name}",
    #    private=True,
    #    use_auth_token=os.environ.get("HUGGINGFACE_TOKEN")
    #)
    #print(f"Model pushed to Hugging Face Hub: PathFinderKR/{train_config.model_name}-{train_config.run_name}")

In [16]:
# To load the model later, you can use:
# model = GPT.from_pretrained(output_dir).to(device)

# Inference

In [17]:
user_prompt = "To be, or not to be, that is the question"
input_ids = char_tokenizer.encode(user_prompt).unsqueeze(0).to(device)
output = model.generate(
    input_ids,
    use_cache=True,
    max_new_tokens=generation_config.max_new_tokens,
    temperature=generation_config.temperature,
    top_k=generation_config.top_k,
    tokenizer=char_tokenizer
)
response = char_tokenizer.decode(output[0].squeeze().cpu().numpy())

s,
And your to firtend, if Sencer heaven is well.

KING RICHORD III:
Pain of the shall t[91mResetting KV cache[0m
 Rew he you cance ouble them;
Was habjournd duder'sway raw, have it worn!

Second Bussween:
Duke I heads ware will suckers the d[91mResetting KV cache[0m
, my to, you to at iny one she
ould man; gim twith one the flom of his slackes drusp.

GLOUCESTER:
Are is kidagn! now heas your [91mResetting KV cache[0m
line I'll and
To prince. Boklie know, of you jicedince,
Has thou was parn of my you, hearth follow;
Or murnriently, choreding bo[91mResetting KV cache[0m
 Rihchard,
With all sake thy his leasial marre'ds as a trands,
Shull from foiticy nighter forget oald eyes:
Wise with threow him[91mResetting KV cache[0m
be compys,
And work'd yours dayars titlent import,
Not be timadgernable in will that his to came,
Stand the sinchens own detuers[91mResetting KV cache[0m
 that Comeeliings! Andlif, who fiel!

HENRY EDWARD ICK:
Now me, my I everak deysitying preatent is sucke

In [18]:
print("=" * 50)
print("User prompt: ")
print(user_prompt)
print("-" * 50)
print("🤖 Model Response:")
print(response)

User prompt: 
To be, or not to be, that is the question
--------------------------------------------------
🤖 Model Response:
re movest tist
Or be sand soul and of cure straior.

YORK:
I dist allive Leflend ut, goines, but so begieves
One for give mistres


# Profiling

In [None]:
if train_config.debug:
    input_ids = torch.randint(0, model_config.vocab_size, (1, model_config.max_seq_len), device=device)
    with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
        with record_function("model_inference"):
            model(input_ids)
    print(prof.key_averages(group_by_input_shape=True).table(sort_by="cuda_time_total", row_limit=20))