# Importing libraries

In [1]:
import os
import sys
from dotenv import load_dotenv
from typing import Tuple
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader
from torch.profiler import profile, record_function, ProfilerActivity
import wandb
from src.utils import set_seed, load_text, split_text, speedometer
from src.config import ModelConfig, TrainConfig, GenerationConfig
from src.tokenizer import CharTokenizer
from models.GPT import GPT
from src.train import Trainer

In [2]:
PROJECT_ROOT = os.path.abspath(os.getcwd() + "/..")
sys.path.append(PROJECT_ROOT)
print(f"PROJECT_ROOT: {PROJECT_ROOT}")

PROJECT_ROOT: /workspace/PathFinder


# Configuration

In [3]:
model_config = ModelConfig(
    vocab_size=-1,
    max_seq_len=128,
    d_embed=128,
    n_layers=4,
    flash=True,
    n_heads=4,
    d_head=32,
    scale=32 ** -0.5,
    #rank=16,
    d_ff=512,
    #d_ff_multiple_of=64,
    #beta_min=1/2,
    #beta_max=8
)

train_config = TrainConfig(
    debug=False,
    wandb_project="nanoGPT",
    model_name="PathFinder-nano",
    per_device_train_batch_size=512,
    per_device_eval_batch_size=1024,
    gradient_accumulation_steps=512 // 512,
    num_train_epochs=1,
    learning_rate=5e-3,
    attn_decay=0.5,
    eval_steps=100,
    mixed_precision=True,
    matmul_precision="high",
)

generation_config = GenerationConfig(
    use_cache=True,
    max_new_tokens=1000,
    temperature=1.0,
    top_k=50
)

In [4]:
load_dotenv()
wandb.login(key=os.environ.get("WANDB_API_KEY"))

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mpathfinderkr[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

# Utils

## Reproducibility

In [5]:
set_seed(train_config.seed)

Random seed set to 42


## Device

In [6]:
device = torch.device("cuda")
print(f"Device: {torch.cuda.get_device_name(device)}")
torch.set_float32_matmul_precision(train_config.matmul_precision)  # Tensor Cores
print(f"MatMul Precision: {train_config.matmul_precision}")

Device: NVIDIA RTX A6000
MatMul Precision: high


# Dataset

In [7]:
dataset_path = os.path.join(PROJECT_ROOT, "datasets/Shakespeare/shakespeare.txt")
shakespeare_text = load_text(dataset_path)

Loaded text data from /workspace/PathFinder/datasets/Shakespeare/shakespeare.txt (length: 1115394 characters).


In [8]:
if train_config.debug:
    subset_shakespeare_text = shakespeare_text[:10000]
    print(subset_shakespeare_text)
    shakespeare_text = subset_shakespeare_text

# Tokenizer

In [9]:
char_tokenizer = CharTokenizer()
char_tokenizer.build_vocab(text=shakespeare_text)
vocab_path = os.path.join(PROJECT_ROOT, "datasets/Shakespeare/vocab.json")
char_tokenizer.save_vocab(vocab_path)
model_config.vocab_size = char_tokenizer.vocab_size

Vocabulary size: 69
Vocabulary saved to /workspace/PathFinder/datasets/Shakespeare/vocab.json.


In [10]:
if train_config.debug:
    print("Vocabulary:", char_tokenizer.char2idx)

# Preprocessing

In [11]:
train_text, val_text = split_text(shakespeare_text, val_size=0.1)
print(f"Training text length: {len(train_text)} characters")
print(f"Validation text length: {len(val_text)} characters")

Training text length: 1003854 characters
Validation text length: 111540 characters


In [12]:
class TextDataset(Dataset):
    def __init__(self, text: str, tokenizer: CharTokenizer, max_seq_len: int):
        self.encoded = tokenizer.encode(text)
        self.max_seq_len = max_seq_len

    def __len__(self) -> int:
        return len(self.encoded) - self.max_seq_len

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        input_ids = self.encoded[idx:idx + self.max_seq_len]
        target_ids = self.encoded[idx + 1:idx + self.max_seq_len + 1]
        return input_ids, target_ids

def collate_fn(batch):
    input_ids = torch.stack([item[0] for item in batch])
    target_ids = torch.stack([item[1] for item in batch])
    return {"input_ids": input_ids, "target_ids": target_ids}

train_dataset = TextDataset(train_text, char_tokenizer, model_config.max_seq_len)
val_dataset = TextDataset(val_text, char_tokenizer, model_config.max_seq_len)

train_loader = DataLoader(
    train_dataset,
    collate_fn=collate_fn,
    batch_size=train_config.per_device_eval_batch_size,
    shuffle=True,
    num_workers=4
)
val_loader = DataLoader(
    val_dataset,
    collate_fn=collate_fn,
    batch_size=train_config.per_device_eval_batch_size,
    shuffle=False,
    num_workers=4
)

In [13]:
if train_config.debug:
    sample_batch = next(iter(train_loader))
    print(f"Sample input IDs: {sample_batch['input_ids'][0]}")
    print(f"Sample target IDs: {sample_batch['target_ids'][0]}")

# Model

In [14]:
# Initialize the model
model = GPT(model_config).to(device)
model = torch.compile(model)
print(model)
print(f"Number of parameters: {model.num_params() / 1e6:.2f}M")

OptimizedModule(
  (_orig_mod): GPT(
    (token_embedding): Embedding(69, 128)
    (positional_encoding): Embedding(128, 128)
    (dropout): Dropout(p=0.1, inplace=False)
    (blocks): ModuleList(
      (0-3): 4 x Block(
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (attn): MultiHeadAttention(
          (qkv_proj): Linear(in_features=128, out_features=384, bias=False)
          (out_proj): Linear(in_features=128, out_features=128, bias=False)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (mlp): FeedForward(
          (fc1): Linear(in_features=128, out_features=512, bias=False)
          (fc2): Linear(in_features=512, out_features=128, bias=False)
          (activation): GELU(approximate='none')
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (lm_head): Linear(in_fea

# Training

In [15]:
trainer = Trainer(
    model=model,
    train_config=train_config,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    master_process=True
)
trainer.train()

Training: 100%|██████████| 981/981 [01:11<00:00, 13.66it/s, epoch=1, grad_norm=0.1856, loss=1.2961, lr=0.000000]


0,1
Grad Norm,▂▂▂▂█▂▂▂▃▂▂▂▁▁▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Learning Rate,▁▃▄███▇▇▇▇▇▇▇▇▇▇▇▆▆▅▅▅▄▄▄▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁
Train Loss,█▆▅▅▅▄▄▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Val Loss,█▄▂▂▁▁▁▁▁▁
Val Perplexity,█▃▂▁▁▁▁▁▁▁

0,1
Grad Norm,0.1856
Learning Rate,0.0
Train Loss,1.29614
Val Loss,1.48906
Val Perplexity,4.43292


## Save the model

In [16]:
if not train_config.debug:
    output_dir = os.path.join(PROJECT_ROOT, "checkpoints", train_config.model_name, train_config.run_name)
    os.makedirs(output_dir, exist_ok=True)
    try:
        model.save_pretrained(
            output_dir,
            safe_serialization=True
        )
        print("Model saved successfully")
    except Exception as e:
        print(f"Error saving model: {e}")
    # Push to Hugging Face Hub
    #model.push_to_hub(
    #    repo_id=f"PathFinderKR/{train_config.model_name}-{train_config.run_name}",
    #    private=True,
    #    use_auth_token=os.environ.get("HUGGINGFACE_TOKEN")
    #)
    #print(f"Model pushed to Hugging Face Hub: PathFinderKR/{train_config.model_name}-{train_config.run_name}")

Model saved successfully


In [17]:
# To load the model later, you can use:
# model = GPT(model_config)
# model = model.from_pretrained(output_dir).to(device)

# Inference

In [22]:
user_prompt = "To be, or not to be, that is the question"
input_ids = char_tokenizer.encode(user_prompt).unsqueeze(0).to(device)
output = model.generate(
    input_ids,
    use_cache=True,
    max_new_tokens=generation_config.max_new_tokens,
    temperature=generation_config.temperature,
    top_k=generation_config.top_k,
    tokenizer=char_tokenizer
)
response = char_tokenizer.decode(output[0].squeeze().cpu().numpy())


A Conspy, good preserve yourself
All a worldred to-night. When the bloody might
Makes; Resetting KV cache
lock selfloughtill, brence as if it reason
When I come, I let not not depended.

YORK:
No more; for I am mine own wrongs, good dResetting KV cache
 ser then wit.

CLARENCE:
My hardeous is said to my light safety, be a man,
That nature's joy, 'tless spurplish'd,
Shall I. Who Resetting KV cache
now hearing us hold thing Rome, thy teether,
My grace to the last times pray.
Join Margaret: Ah, we'll remember them or Rutland,Resetting KV cache
 bawn offick and Emile are, they all Love of York;
Raised heart and grave to you will command you.
What dalice that you you stanResetting KV cache
 forced awards, of Henry:
No welcome, sir, welcome, some hold true to thee!
O the moon foolish and name with consented Isabel.

Resetting KV cache
cast hers.

Thereing here be year the merity with his name, to my duty
Would weep her lie and with the king's piece.

KING RICHAResetting KV cache
r up havest

In [23]:
print("=" * 50)
print("User prompt: ")
print(user_prompt)
print("-" * 50)
print("🤖 Model Response:")
print(response)

User prompt: 
To be, or not to be, that is the question
--------------------------------------------------
🤖 Model Response:
To be, or not to be, that is the question
A Conspy, good preserve yourself
All a worldred to-night. When the bloody might
Makes; lock selfloughtill, brence as if it reason
When I come, I let not not depended.

YORK:
No more; for I am mine own wrongs, good d ser then wit.

CLARENCE:
My hardeous is said to my light safety, be a man,
That nature's joy, 'tless spurplish'd,
Shall I. Who now hearing us hold thing Rome, thy teether,
My grace to the last times pray.
Join Margaret: Ah, we'll remember them or Rutland, bawn offick and Emile are, they all Love of York;
Raised heart and grave to you will command you.
What dalice that you you stan forced awards, of Henry:
No welcome, sir, welcome, some hold true to thee!
O the moon foolish and name with consented Isabel.

cast hers.

Thereing here be year the merity with his name, to my duty
Would weep her lie and with the kin

# Speedometer

In [20]:
speedometer(
    model=model,
    input_ids=char_tokenizer.encode("a").unsqueeze(0).to(device),
    use_cache=False,
    warmup_tokens=100,
    timing_tokens=100,
    num_runs=5
)

KV Cache Enabled: False
Warmup Tokens: 100, Timing Tokens: 100, Runs: 5
--------------------------------------------------
Run  1: Latency = 0.79 ms/token, Throughput = 1269.76 tokens/sec
Run  2: Latency = 0.75 ms/token, Throughput = 1324.76 tokens/sec
Run  3: Latency = 0.75 ms/token, Throughput = 1333.79 tokens/sec
Run  4: Latency = 0.77 ms/token, Throughput = 1307.03 tokens/sec
Run  5: Latency = 0.76 ms/token, Throughput = 1316.55 tokens/sec
--------------------------------------------------
Summary (over 5 runs):
  Avg    Latency: 0.76 ms/token
  Std    Latency: 0.01 ms/token
  Min    Latency: 0.75 ms/token
  Max    Latency: 0.79 ms/token
  Median Latency: 0.76 ms/token
  Avg    Throughput: 1310.00 tokens/sec


In [21]:
speedometer(
    model=model,
    input_ids=char_tokenizer.encode("a").unsqueeze(0).to(device),
    use_cache=True,
    warmup_tokens=100,
    timing_tokens=100,
    num_runs=5
)

KV Cache Enabled: True
Warmup Tokens: 100, Timing Tokens: 100, Runs: 5
--------------------------------------------------
Run  1: Latency = 0.96 ms/token, Throughput = 1043.41 tokens/sec
Run  2: Latency = 0.95 ms/token, Throughput = 1048.69 tokens/sec
Run  3: Latency = 0.95 ms/token, Throughput = 1054.51 tokens/sec
Run  4: Latency = 0.96 ms/token, Throughput = 1037.64 tokens/sec
Run  5: Latency = 0.95 ms/token, Throughput = 1049.47 tokens/sec
--------------------------------------------------
Summary (over 5 runs):
  Avg    Latency: 0.96 ms/token
  Std    Latency: 0.01 ms/token
  Min    Latency: 0.95 ms/token
  Max    Latency: 0.96 ms/token
  Median Latency: 0.95 ms/token
  Avg    Throughput: 1046.71 tokens/sec


# Profiling

In [None]:
input_ids = torch.randint(0, model_config.vocab_size, (1, model_config.max_seq_len), device=device)
with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
    with record_function("model_inference"):
        model(input_ids)
print(prof.key_averages(group_by_input_shape=True).table(sort_by="cuda_time_total", row_limit=20))

# Attention Scores