In [1]:
# Block 2: Imports & Path Setup
import sys
import os
import math
import time
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, get_cosine_schedule_with_warmup
from datasets import load_dataset
from tqdm.auto import tqdm

from model.configuration_holo import HoloConfig
from model.modeling_holo import HoloForCausalLM

# Pick the compute device: the GPU when CUDA is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [2]:
# Training hyper-parameters. Adjust these for your GPU (e.g., L40S/A100).
BATCH_SIZE = 2           # Per-device (micro-)batch size fed to the DataLoader
GRAD_ACCUM_STEPS = 8      # Micro-batches per optimizer step; effective batch size = BATCH_SIZE * GRAD_ACCUM_STEPS = 2 * 8 = 16
LEARNING_RATE = 3e-4      # Learning rate handed to AdamW (value taken from the project README)
MAX_STEPS = 1000          # Total optimizer steps (not micro-batches); also the scheduler horizon
SEQ_LEN = 2048            # Context length: every text is truncated/padded to this many tokens
WARMUP_STEPS = 100        # Warm-up steps of the cosine learning-rate schedule
SAVE_STEPS = 200          # A checkpoint is written every SAVE_STEPS optimizer steps

In [3]:
# Block 4: Data Loading
# Tokenize with the GPT-2 vocabulary; its end-of-text token is reused as the
# padding token so that fixed-length padding can be requested later on.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# Stream the training split rather than downloading it (avoids disk usage).
dataset = load_dataset(
    path="DKYoon/SlimPajama-6B",
    split="train",
    streaming=True,
)

def tokenize_stream(data_iter):
    """Lazily tokenize every example of ``data_iter``.

    Each example must expose its raw string under the ``"text"`` key. The text
    is truncated/padded to ``SEQ_LEN`` tokens with the module-level
    ``tokenizer`` and one dict per example is yielded, holding the
    ``"input_ids"`` and ``"attention_mask"`` as tensors (in that key order).
    """
    # The attention mask is optional for Holo, but it is cheap to carry along.
    tensor_fields = ("input_ids", "attention_mask")
    for record in data_iter:
        encoded = tokenizer(
            record["text"],
            truncation=True,
            max_length=SEQ_LEN,
            padding="max_length",
        )
        yield {field: torch.tensor(encoded[field]) for field in tensor_fields}

# Create Iterator
def get_data_loader(dataset, batch_size, num_workers=4, prefetch_factor=2):
    """Wrap ``dataset`` in a DataLoader that yields tokenized, fixed-length batches.

    Every example's ``"text"`` is tokenized with the module-level ``tokenizer``
    and truncated/padded to ``SEQ_LEN`` tokens; the raw ``"text"`` and ``"meta"``
    columns are dropped and the remaining columns are served as torch tensors.

    Args:
        dataset: Object exposing ``map(fn, remove_columns=...)`` and
            ``with_format("torch")``. Its examples must carry the ``"text"``
            and ``"meta"`` columns (hard-coded below).
        batch_size: Number of examples per yielded batch.
        num_workers: Worker processes used for data fetching. ``0`` loads in
            the main process, which is handy for debugging.
        prefetch_factor: Batches buffered per worker. Only forwarded when
            ``num_workers > 0`` because DataLoader rejects it otherwise.

    Returns:
        A ``torch.utils.data.DataLoader`` over the tokenized dataset.
    """
    def _tokenize(example):
        return tokenizer(
            example["text"],
            truncation=True,
            max_length=SEQ_LEN,
            padding="max_length",
        )

    mapped_ds = dataset.map(_tokenize, remove_columns=["text", "meta"])
    mapped_ds = mapped_ds.with_format("torch")

    loader_kwargs = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        # Pinned host memory only speeds up host-to-GPU copies; on a CPU-only
        # machine it just triggers a warning, so enable it only with CUDA.
        "pin_memory": torch.cuda.is_available(),
    }
    if num_workers > 0:
        # prefetch_factor is only meaningful (and only accepted) with workers.
        loader_kwargs["prefetch_factor"] = prefetch_factor

    return DataLoader(mapped_ds, **loader_kwargs)

# Build the training loader using the micro-batch size from the config cell.
train_loader = get_data_loader(dataset=dataset, batch_size=BATCH_SIZE)
print("Data loader ready.")

Resolving data files:   0%|          | 0/48 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/48 [00:00<?, ?it/s]

Data loader ready.


In [4]:
# Block 5: Model Initialization
# Candidate preset names (presumably those HoloConfig.from_preset understands);
# index 0 selects "small".
sizes = ["small", "medium", "large"]
config = HoloConfig.from_preset(sizes[0], use_version=2)

# Build the network from the preset, then move its weights to the target device.
model = HoloForCausalLM(config)
model = model.to(device)
# Optional speed / memory knobs, currently disabled:
# model = torch.compile(model)
# model.gradient_checkpointing_enable()

# NOTE(review): the config printed here reports bos/eos/pad ids of 1/2/0, while
# the data cell tokenizes with the GPT-2 tokenizer and pads with its EOS token.
# Confirm that the ids stored with saved checkpoints match the tokenizer.
print("Configuration loaded:", config)

# AdamW over every model parameter.
optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=LEARNING_RATE,
    betas=(0.9, 0.95),
)

# Cosine learning-rate schedule with warm-up, spanning the whole training run.
scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=WARMUP_STEPS,
    num_training_steps=MAX_STEPS,
)

print(f"Model Params: {model.num_parameters()/1e6:.1f}M")

Configuration loaded: HoloConfig {
  "bos_token_id": 1,
  "d_model": 768,
  "dropout": 0.0,
  "eos_token_id": 2,
  "expansion_factor": 4,
  "hd_dim": 3072,
  "head_dim": 256,
  "holo_expansion_ratio": 4,
  "initializer_range": 0.02,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 8192,
  "model_type": "holo",
  "num_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "phase_scale": 3.0,
  "transformers_version": "4.57.3",
  "use_version": 2,
  "vocab_size": 50257
}

Model Params: 180.3M


In [5]:
# Block 6: Custom Training Loop with Perplexity (Updated)
#
# Runs MAX_STEPS optimizer steps. Each optimizer step accumulates gradients over
# GRAD_ACCUM_STEPS micro-batches, clips the gradient norm to 1.0, then advances
# the optimizer and the LR scheduler. The loss averaged over the micro-batches
# of a step and its perplexity are shown in the progress bar, and a checkpoint
# is written every SAVE_STEPS steps as well as at the very end.
# (`math` is already imported in the imports cell, so it is not re-imported.)

# Label value that torch's cross-entropy skips (its default `ignore_index`).
# NOTE(review): assumes HoloForCausalLM computes its loss with that default —
# confirm against the model code.
IGNORE_INDEX = -100

model.train()
optimizer.zero_grad()

# Counters
accum_steps = 0    # micro-batches processed so far
global_step = 0    # optimizer steps completed so far
step_loss = 0.0    # sum of the already-scaled micro-batch losses of the current step

data_iter = iter(train_loader)

# The context manager closes the bar even when the run is interrupted.
with tqdm(total=MAX_STEPS, desc="Training") as progress_bar:
    print("Starting training loop with Perplexity tracking...")

    while global_step < MAX_STEPS:
        try:
            batch = next(data_iter)
        except StopIteration:
            # Loader exhausted: start another pass over the data.
            data_iter = iter(train_loader)
            batch = next(data_iter)

        accum_steps += 1

        # 1. Move to device. Labels are cloned on the device so the token ids
        #    cross the host/device boundary only once.
        input_ids = batch["input_ids"].to(device, non_blocking=True)
        labels = input_ids.clone()

        # Every sample is padded to SEQ_LEN with the EOS token, so without
        # masking the model would be trained on (and the reported loss and
        # perplexity deflated by) trivially predictable padding positions.
        attention_mask = batch.get("attention_mask")
        if attention_mask is not None:
            padding_positions = attention_mask.to(device, non_blocking=True) == 0
            labels = labels.masked_fill(padding_positions, IGNORE_INDEX)

        # 2. Forward Pass. Autocast follows the active device type instead of
        #    assuming CUDA, so the CPU fallback from the setup cell also works.
        with torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16):
            outputs = model(input_ids=input_ids, labels=labels)

            # Scale by 1 / GRAD_ACCUM_STEPS so that the gradients summed over
            # the micro-batches of one step correspond to the average loss.
            loss = outputs.loss / GRAD_ACCUM_STEPS

        # 3. Backward Pass
        loss.backward()
        step_loss += loss.item()

        # 4. Optimization Step (triggered every GRAD_ACCUM_STEPS micro-batches)
        if accum_steps % GRAD_ACCUM_STEPS == 0:
            # Gradient Clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            # --- METRICS CALCULATION ---
            # step_loss holds sum(loss_i / GRAD_ACCUM_STEPS), i.e. the loss
            # averaged over the micro-batches of this optimizer step.
            avg_loss = step_loss

            # Perplexity = exp(loss); guard against float overflow early on.
            try:
                perplexity = math.exp(avg_loss)
            except OverflowError:
                perplexity = float("inf")

            # Update Progress Bar
            progress_bar.update(1)
            global_step += 1

            # Show Loss and PPL in the bar
            progress_bar.set_postfix({
                "loss": f"{avg_loss:.4f}",
                "ppl": f"{perplexity:.2f}"
            })

            # Reset step accumulator
            step_loss = 0.0

            # Save Checkpoint
            if global_step % SAVE_STEPS == 0:
                save_path = f"./holo_checkpoints/step_{global_step}"
                model.save_pretrained(save_path)

# Keep the final weights even when MAX_STEPS is not a multiple of SAVE_STEPS
# (otherwise the steps after the last periodic checkpoint would be lost).
if global_step % SAVE_STEPS != 0:
    model.save_pretrained(f"./holo_checkpoints/step_{global_step}")

print("Training finished!")

Training:   0%|          | 0/1000 [00:00<?, ?it/s]

Starting training loop with Perplexity tracking...


KeyboardInterrupt: 