# Tab Hero - Supervised Training Walkthrough

Quick notebook showing how the model is trained end-to-end: tokenization, dataset, model, one training step, and reading back a checkpoint.

The model is an encoder-decoder transformer. The encoder processes a mel spectrogram; the decoder generates a sequence of note tokens autoregressively. Each note is four tokens: time delta, fret combination, modifier flags (HOPO/TAP/star power), duration.

Run from the `notebooks/` directory. Assumes `tab_hero` is installed (`pip install -e .`); the setup cell also puts `../src` on `sys.path` as a fallback.

In [None]:
import sys, os
sys.path.insert(0, os.path.join(os.path.dirname(os.getcwd()), "src"))

import numpy as np
import matplotlib.pyplot as plt
import torch

print(f"PyTorch {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

## Tokenization

Each note becomes four consecutive tokens, `[TIME_DELTA, FRET, MODIFIER, DURATION]`, so the full sequence looks like `[BOS, T, F, M, D, T, F, M, D, ..., EOS]`. The tokenizer handles quantization (10 ms bins for time, 50 ms for duration) and the fret bitmask encoding.
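
To make the two encodings concrete, here is a minimal sketch of the quantization and bitmask steps. The helper names (`quantize_ms`, `frets_to_bitmask`, `bitmask_to_frets`) are hypothetical and not part of `ChartTokenizer`. The sketch assumes round-to-nearest binning and one bit per fret index, and it leaves out the mapping from these raw values to vocabulary ids.

```python
TIME_BIN_MS = 10      # time-delta resolution stated above
DURATION_BIN_MS = 50  # duration resolution stated above


def quantize_ms(value_ms, bin_ms):
    """Round a millisecond value to the nearest bin index (assumed rounding rule)."""
    return int(value_ms / bin_ms + 0.5)


def frets_to_bitmask(frets):
    """Pack a chord into an integer, one bit per fret index (assumed layout)."""
    mask = 0
    for fret in frets:
        mask |= 1 << fret
    return mask


def bitmask_to_frets(mask):
    """Recover the sorted fret indices from a bitmask."""
    return [i for i in range(mask.bit_length()) if (mask >> i) & 1]


time_bin = quantize_ms(350.0, TIME_BIN_MS)
dur_bin = quantize_ms(200.0, DURATION_BIN_MS)
mask = frets_to_bitmask([2, 4])
print(time_bin, dur_bin, bin(mask), bitmask_to_frets(mask))
```

Under these assumptions the chord `[2, 4]` packs into `0b10100`, and decoding the mask returns the same fret list.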

In [None]:
from tab_hero.dataio.tokenizer import ChartTokenizer

tok = ChartTokenizer()
print(f"vocab size: {tok.vocab_size}, tokens per note: {tok.tokens_per_note}")

# encode a single note and decode it back
t = tok.encode_time_delta(350.0)
f = tok.encode_frets([2, 4])
m = tok.encode_modifiers(is_hopo=True, is_tap=False, is_star_power=False)
d = tok.encode_duration(200.0)
print(f"quad: {t} {f} {m} {d}  →  '{tok.id_to_token[t]} {tok.id_to_token[f]} {tok.id_to_token[m]} {tok.id_to_token[d]}'")
print(f"decoded: {tok.decode_time_delta(t)} ms, frets={tok.decode_frets(f)}, hopo={tok.decode_modifiers(m)[0]}, dur={tok.decode_duration(d)} ms")

## Dataset

Songs are preprocessed into `.tab` files (`scripts/preprocess.py`), which bundle the mel spectrogram and token sequence together. Full songs are too long to process in one shot, so at training time `ChunkedTabDataset` slices them into overlapping windows. SpecAugment is applied on the fly when `training=True`.

If you don't have preprocessed data yet:
```bash
python scripts/preprocess.py --input_dir data/sample --output_dir data/sample/processed
```
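
The overlapping-window slicing can be pictured with a short sketch. `chunk_starts` is a hypothetical helper, not the logic inside `ChunkedTabDataset`; it assumes a fixed stride of `window - overlap` and a final window aligned to the end of the song. The window and overlap values match the `max_mel_frames` and `chunk_overlap_frames` used in the next cell.

```python
def chunk_starts(n_frames, window, overlap):
    """Start indices of overlapping windows that together cover n_frames."""
    if n_frames <= window:
        return [0]  # short song: a single chunk
    stride = window - overlap
    starts = list(range(0, n_frames - window, stride))
    starts.append(n_frames - window)  # last window sits flush with the end
    return starts


# a 10,000-frame song with 4096-frame windows and 256 frames of overlap
for start in chunk_starts(10_000, window=4096, overlap=256):
    print(f"frames {start}..{start + 4096}")
```

Aligning the last window to the end avoids a short, mostly padded final chunk, at the cost of a larger overlap with the window before it.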

In [None]:
from pathlib import Path
from torch.utils.data import DataLoader
from tab_hero.dataio.chunked_dataset import ChunkedTabDataset, chunked_collate_fn

DATA_DIR = "../data/sample/processed"

if not list(Path(DATA_DIR).glob("*.tab")):
    print("no .tab files found — run preprocess.py first")
else:
    dataset = ChunkedTabDataset(
        data_dir=DATA_DIR,
        split=None,
        max_mel_frames=4096,
        max_token_length=2048,
        chunk_overlap_frames=256,
        training=True,
    )
    loader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=chunked_collate_fn)
    batch = next(iter(loader))

    print(f"{len(dataset)} chunks")
    print(f"audio : {batch['audio_embeddings'].shape}  (batch, frames, n_mels)")
    print(f"tokens: {batch['note_tokens'].shape}  (batch, seq_len)")
    print(f"difficulty: {batch['difficulty_id']}, instrument: {batch['instrument_id']}")

## Model

The encoder is a small Conv1D stack that downsamples the mel frames 4x before passing them to the decoder via cross-attention. The decoder is a standard causal transformer with RoPE positional encoding and instrument/difficulty conditioning embeddings.
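
To see what the 4x downsampling means for sequence lengths, the sketch below applies the standard Conv1D output-length formula twice. The layer layout (two stride-2 convolutions with kernel size 3 and padding 1) is an assumption for illustration; the actual encoder may reach the same factor differently.

```python
def conv1d_out_len(n_frames, kernel_size=3, stride=2, padding=1):
    """Output length of a Conv1D layer with dilation 1."""
    return (n_frames + 2 * padding - kernel_size) // stride + 1


def encoder_len(n_frames, n_stride2_layers=2):
    """Frames left after a stack of stride-2 layers (assumed layout)."""
    for _ in range(n_stride2_layers):
        n_frames = conv1d_out_len(n_frames)
    return n_frames


for n in (512, 4096):
    print(n, "mel frames ->", encoder_len(n), "encoder positions")
```

With these assumptions a 4096-frame chunk leaves 1024 positions for the decoder to cross-attend to.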

The config below is smaller than the largest model variant, so it runs fine on CPU.

In [None]:
from tab_hero.model.chart_transformer import ChartTransformer

model = ChartTransformer(
    vocab_size=tok.vocab_size,
    audio_input_dim=128,
    encoder_dim=256,
    decoder_dim=256,
    n_decoder_layers=4,
    n_heads=8,
    ffn_dim=1024,
    max_seq_len=2048,
    dropout=0.1,
    audio_downsample=4,
    use_flash=False,
    use_rope=True,
    gradient_checkpointing=False,
).to(DEVICE)

total = sum(p.numel() for p in model.parameters())
print(f"{total:,} parameters ({model.get_num_params(non_embedding=True):,} non-embedding)")

## Forward pass

The model uses teacher forcing during training: the decoder gets the ground-truth tokens as input (shifted right) rather than its own predictions. At random init the output distribution is close to uniform, so the loss should be around `ln(vocab_size) ≈ 6.6`.
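
Both ideas fit in a few lines of plain Python. The helpers are hypothetical and only illustrate the bookkeeping; the model does the shift internally. The token ids and the vocabulary size of 740 are made-up values, chosen because 740 happens to give about 6.6; the notebook reads the real size from `tok.vocab_size`.

```python
import math


def teacher_forcing_pairs(seq):
    """Decoder inputs are the sequence without its last token; targets are shifted by one."""
    return seq[:-1], seq[1:]


def uniform_cross_entropy(vocab_size):
    """Cross-entropy (in nats) of a uniform prediction over the vocabulary."""
    return math.log(vocab_size)


seq = [1, 50, 60, 70, 80, 2]  # made-up ids standing in for BOS, T, F, M, D, EOS
inputs, targets = teacher_forcing_pairs(seq)
print("inputs :", inputs)
print("targets:", targets)
print(f"uniform loss for 740 tokens: {uniform_cross_entropy(740):.2f}")
```

A loss far above this baseline at initialization usually points to a labeling or masking problem rather than to the model itself.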

In [None]:
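# random stand-ins: 2 clips of 512 mel frames x 128 bins
audio  = torch.randn(2, 512, 128, device=DEVICE)
# 64 random token ids per clip, sampled from id 3 upward
# (assumes the lowest ids are reserved for special tokens)
tokens = torch.randint(3, tok.vocab_size, (2, 64), device=DEVICE)
tokens[:, 0] = tok.bos_token_id
tokens[:, -1] = tok.eos_token_id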

model.eval()
with torch.no_grad():
    out = model(audio, tokens,
                difficulty_id=torch.tensor([3, 1], device=DEVICE),
                instrument_id=torch.tensor([0, 1], device=DEVICE))

print(f"logits: {out['logits'].shape}")
print(f"loss:   {out['loss'].item():.4f}  (expected ~{np.log(tok.vocab_size):.2f} at random init)")

## One training step

The Trainer (in `tab_hero/training/trainer.py`) wraps all of this, but it's useful to see what one step actually looks like. Its defaults are AdamW with β₂=0.95 (a common choice for transformer training), a cosine LR schedule with linear warmup, BF16 mixed precision, and gradient clipping at 1.0. The manual step below keeps only AdamW and the clipping; it skips the schedule and mixed precision.
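
For reference, a cosine schedule with linear warmup can be written as a plain function of the step. This is a sketch of the general technique, not the Trainer's code: `lr_at` is a hypothetical helper, and the warmup length, total step count, and zero floor are illustrative assumptions.

```python
import math


def lr_at(step, peak_lr=1e-4, warmup_steps=100, total_steps=1000, min_lr=0.0):
    """Learning rate at a given step: linear warmup, then cosine decay to min_lr."""
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(progress, 1.0)  # hold at min_lr past the end of the schedule
    return min_lr + 0.5 * (peak_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


for step in (0, 50, 100, 550, 1000):
    print(f"step {step:4d}: lr = {lr_at(step):.2e}")
```

The rate climbs linearly to the peak over the warmup steps, then follows half a cosine period down to the floor.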

In [None]:
import torch.optim as optim

optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01, betas=(0.9, 0.95))

model.train()
optimizer.zero_grad()
out = model(audio, tokens)
out["loss"].backward()
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()

print(f"loss: {out['loss'].item():.4f}, grad norm: {grad_norm:.4f}")

## Running actual training

```bash
python scripts/train.py

# smoke-test on 100 samples
python scripts/train.py data.max_samples=100 training.max_epochs=5

# resume
python scripts/train.py training.resume_checkpoint=last_model.pt
```

Progress is written to `training_progress.log`.

In [None]:
import re
from pathlib import Path

LOG_PATH = Path("../training_progress.log")

if not LOG_PATH.exists():
    print("no log yet — run training first")
else:
    steps, losses = [], []
    for line in LOG_PATH.read_text().splitlines():
        m = re.search(r"Step (\d+): loss=([\d.]+)", line)
        if m:
            steps.append(int(m.group(1)))
            losses.append(float(m.group(2)))

    plt.figure(figsize=(9, 3))
    plt.plot(steps, losses)
    plt.xlabel("step")
    plt.ylabel("loss")
    plt.title("training loss")
    plt.tight_layout()
    plt.show()
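    # guard against a log that has no matching lines yet
    if steps:
        print(f"step {steps[-1]}, loss {losses[-1]:.4f}")
    else:
        print("log has no 'Step N: loss=X' lines yet")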

## Checkpoint

The trainer saves `best_model.pt` and `last_model.pt` after each epoch. Each checkpoint includes model weights, optimizer state, and the epoch/step/loss metadata needed to resume.

In [None]:
CKPT_PATH = Path("../checkpoints/best_model.pt")

if not CKPT_PATH.exists():
    print("no checkpoint yet")
else:
    ckpt = torch.load(CKPT_PATH, map_location="cpu")
    print(f"epoch {ckpt['epoch']}, step {ckpt['global_step']:,}, val loss {ckpt['best_val_loss']:.4f}")
    print(f"keys: {list(ckpt.keys())}")