# quantammsim Training Loop Profiler — Colab

Profile the training loop on an L4 GPU to identify per-iteration bottlenecks.

**Prerequisites:**
- Data parquet files in Google Drive at `My Drive/quantammsim_data/`

**Runtime:** Change to GPU via Runtime → Change runtime type → L4 GPU

In [None]:
# Check GPU
import subprocess
result = subprocess.run(['nvidia-smi'], capture_output=True, text=True)
print(result.stdout)

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Clone repo + install
import os

REPO_DIR = '/content/quantammsim'
BRANCH = 'gpu-profiling'

if not os.path.exists(REPO_DIR):
    !git clone -b {BRANCH} https://github.com/QuantAMMProtocol/QuantAMMSim.git {REPO_DIR}

os.chdir(REPO_DIR)
!pip install -q -e '.[dev]'
!git log --oneline -3

In [None]:
# Verify JAX sees GPU
import jax
print(f'JAX backend: {jax.default_backend()}')
print(f'JAX devices: {jax.devices()}')
assert jax.default_backend() == 'gpu', 'No GPU — change runtime type!'

In [None]:
# Copy data from Drive to package data dir (so root=None just works)
import shutil

DRIVE_DATA = '/content/drive/MyDrive/quantammsim_data'
PKG_DATA = '/content/quantammsim/quantammsim/data'

needed = ['ETH_USD.parquet', 'USDC_USD.parquet', 'BTC_USD.parquet']
for f in os.listdir(DRIVE_DATA):
    if f.endswith('.parquet'):
        src = os.path.join(DRIVE_DATA, f)
        dst = os.path.join(PKG_DATA, f)
        if not os.path.exists(dst):
            shutil.copy2(src, dst)
            sz = os.path.getsize(dst) / 1e6
            print(f'  Copied {f} ({sz:.1f} MB)')
        else:
            print(f'  {f} already present')

# Verify needed files
for f in needed:
    path = os.path.join(PKG_DATA, f)
    assert os.path.exists(path), f'MISSING: {f}'
print('\nAll required parquets present.')

## Run profiler

Two configs are available:
- **`--tuning-config`**: ETH/USDC, `mean_reversion_channel`, fees=0 (matches `tune_training_hyperparams.py`)
- **default**: BTC/ETH, `momentum` (simpler, faster)

Modes:
- Coarse timing: JIT compile vs amortised per-iteration cost (always runs)
- `--cprofile`: Python cProfile showing where host-side time goes
- `--trace`: JAX profiler trace (viewable in TensorBoard / perfetto.dev)
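
The coarse-timing mode separates one-off JIT compilation from steady-state cost. The sketch below illustrates that kind of measurement in general terms; it is not the code in `profile_training_loop.py`. The names `time_first_vs_steady`, `sync`, and `FakeJittedStep` are hypothetical. With JAX, `sync` would be `jax.block_until_ready`, because dispatch is asynchronous and an unsynchronised timer would only measure the time to enqueue work.

```python
import time


def time_first_vs_steady(fn, n_iters=20, sync=lambda out: out):
    """Time the first call (trace + compile + run) separately from later calls.

    `sync` must block until the result of `fn` is ready. For a jitted JAX
    function, pass `jax.block_until_ready` (assumption: JAX is installed).
    """
    if n_iters < 1:
        raise ValueError("n_iters must be at least 1")

    # First call: includes tracing and compilation for a jitted function.
    t0 = time.perf_counter()
    sync(fn())
    first_call_s = time.perf_counter() - t0

    # Later calls reuse the compiled executable, so their mean is the
    # amortised per-iteration cost.
    t0 = time.perf_counter()
    for _ in range(n_iters):
        sync(fn())
    per_iter_s = (time.perf_counter() - t0) / n_iters

    return {
        "first_call_s": first_call_s,
        "per_iter_s": per_iter_s,
        # Rough estimate: the first call minus one steady-state iteration.
        "compile_estimate_s": max(first_call_s - per_iter_s, 0.0),
    }


class FakeJittedStep:
    """Hypothetical stand-in for a jitted step: slow on the first call only."""

    def __init__(self):
        self.compiled = False

    def __call__(self):
        if not self.compiled:
            time.sleep(0.05)  # pretend this is compilation
            self.compiled = True
        return sum(range(1000))


result = time_first_vs_steady(FakeJittedStep(), n_iters=5)
print(result)
```

For a real jitted function the call would look like `time_first_vs_steady(lambda: step(params), sync=jax.block_until_ready)`, where `step` and `params` are placeholders for your own training step and its inputs.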

In [None]:
# Simple config (BTC/ETH momentum) — quick sanity check
!cd /content/quantammsim && python scripts/profile_training_loop.py \
    -n 20 --n-param-sets 4 --batch-size 8

In [None]:
# Tuning config (ETH/USDC mean_reversion_channel) — the real target
!cd /content/quantammsim && python scripts/profile_training_loop.py \
    --tuning-config -n 50 --n-param-sets 8 --batch-size 16 \
    --cprofile

In [None]:
# (Optional) JAX profiler trace
TRACE_DIR = '/content/jax-trace'
!cd /content/quantammsim && python scripts/profile_training_loop.py \
    --tuning-config -n 20 --n-param-sets 8 --batch-size 16 \
    --trace --trace-dir {TRACE_DIR}

# Copy trace to Drive for later analysis
DRIVE_OUTPUT = '/content/drive/MyDrive/quantammsim_data/profiling_output/'
os.makedirs(DRIVE_OUTPUT, exist_ok=True)
!cp -r {TRACE_DIR} {DRIVE_OUTPUT}
print(f'\nTrace saved to Drive: {DRIVE_OUTPUT}')

In [None]:
# (Optional) View trace in TensorBoard
%load_ext tensorboard
%tensorboard --logdir /content/jax-trace

## Notes

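- The L4 is Ada Lovelace (same architecture as the RTX 4090) with 24 GB of VRAM (same as the 4090), but its peak FP32 throughput is only roughly 35-40% of the 4090's on NVIDIA's spec sheets (about 30 vs 83 TFLOPS). Absolute timings will therefore differ; per-iteration **ratios** (JIT'd compute vs Python overhead) should still transfer reasonably well to 4090 results, with the JIT'd compute share likely somewhat larger on the L4.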
- The key entries to check in the cProfile output are `calculate_period_metrics`, `_calculate_return_value`, `has_nan_params`, and `deepcopy`.
- If the clone fails (private repo), clone with a GitHub personal access token instead, keeping the same branch: `!git clone -b gpu-profiling https://<USER>:<TOKEN>@github.com/QuantAMMProtocol/QuantAMMSim.git /content/quantammsim`
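
To pull specific functions out of a cProfile run, the stats can be filtered by name. The sketch below uses only the standard library's `cProfile` and `pstats`. How `profile_training_loop.py` exposes its own stats is not described here, so this version profiles a callable directly; `profile_and_filter` and the `_demo_workload` stand-in are hypothetical names used for illustration.

```python
import cProfile
import copy
import io
import pstats
import re

# Function names called out in the notes above.
KEY_FUNCS = (
    "calculate_period_metrics",
    "_calculate_return_value",
    "has_nan_params",
    "deepcopy",
)


def profile_and_filter(fn, names=KEY_FUNCS, limit=20):
    """Profile `fn()` and return a report restricted to matching functions."""
    profiler = cProfile.Profile()
    profiler.enable()
    try:
        fn()
    finally:
        profiler.disable()

    buf = io.StringIO()
    stats = pstats.Stats(profiler, stream=buf)
    # print_stats accepts a regex restriction matched against
    # "file:line(function)", so helpers such as _deepcopy_dict also match.
    pattern = "|".join(re.escape(name) for name in names)
    stats.sort_stats("cumulative").print_stats(pattern, limit)
    return buf.getvalue()


def _demo_workload():
    """Stand-in workload that spends its time in copy.deepcopy."""
    data = {"weights": list(range(1000))}
    for _ in range(50):
        copy.deepcopy(data)


report = profile_and_filter(_demo_workload)
print(report)
```

Sorting by cumulative time puts the calls whose whole subtree is most expensive at the top, which is usually the more useful view when looking for host-side overhead between JIT'd steps.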