# SDE-GAN Training on GPU

Train a Neural SDE-GAN for synthetic multi-asset price path generation.

**Runtime**: Go to Runtime > Change runtime type > **T4 GPU** (free tier) or A100 (Colab Pro).

T4 (16GB): hidden=64, depth=2, batch=512, 20k steps (~20min)
A100 (40GB): hidden=128, depth=3, batch=1024, 30k steps (~10min per 20k steps)

**Data**: Minute-resolution price data is downloaded from Binance Vision (primary source)
with gap-filling from CryptoDataDownload, Coinbase, Gemini, etc. This takes ~10-15 min
on first run but the parquets are cached for the session.

In [None]:
# Check GPU and choose the config tier used by the cells below
import subprocess

try:
    result = subprocess.run(['nvidia-smi'], capture_output=True, text=True)
    smi_output = result.stdout
except FileNotFoundError:
    # No NVIDIA driver (e.g. a CPU-only runtime), so nvidia-smi is not
    # installed. Fall through to the default tier instead of crashing here;
    # the JAX device check in a later cell reports the missing GPU.
    smi_output = ''
print(smi_output if smi_output else 'nvidia-smi not available — no NVIDIA GPU detected')

# GPU names that receive the full (A100-tier) config.
HIGH_END_GPUS = ('A100', 'H100', 'V100')

if 'T4' in smi_output:
    GPU_TIER = 'T4'
    print('Free T4 detected — using medium config')
elif any(name in smi_output for name in HIGH_END_GPUS):
    GPU_TIER = 'A100'
    print('High-end GPU detected — using full config')
else:
    GPU_TIER = 'T4'  # conservative default
    print('Unknown GPU — defaulting to medium config')

In [None]:
# Install deps
!pip install -q jax[cuda12] equinox diffrax optax
!pip install -q dask pandas pyarrow
!pip install -q binance-historical-data Historic-Crypto bidask

In [None]:
# Clone repo
import os
if not os.path.exists('quantammsim'):
    !git clone -b synthetic-price-gen https://github.com/QuantAMMProtocol/quantammsim.git
os.chdir('quantammsim')
!pip install -q -e .

In [None]:
# Confirm JAX was installed with CUDA support and can see the accelerator;
# stop early with an actionable message if it cannot.
import jax
visible_devices = jax.devices()
print(f'JAX devices: {visible_devices}')
has_gpu = any(dev.platform == 'gpu' for dev in visible_devices)
assert has_gpu, 'No GPU! Change runtime type.'

In [None]:
import jax.numpy as jnp
import numpy as np

from quantammsim.synthetic.sde_gan import (
    train_sde_gan, generate_paths, compute_daily_log_prices,
)
from quantammsim.utils.data_processing.historic_data_utils import get_historic_parquet_data

In [None]:
# Download minute-resolution price data
# Uses the same pipeline as experiments/do_data_import.py:
#   Binance Vision (primary) -> gap fill from other sources -> parquet
# Tokens whose parquet already exists in DATA_DIR are skipped.
import os
from pathlib import Path
from quantammsim.utils.data_processing.historic_data_utils import update_historic_data

tokens = ['ETH', 'BTC', 'USDC', 'PAXG']
DATA_DIR = Path('quantammsim/data')
DATA_DIR.mkdir(parents=True, exist_ok=True)
# update_historic_data is called with the directory as a trailing-slash string.
data_dir_str = str(DATA_DIR) + '/'
combined_dir = DATA_DIR / 'combined_data'

for token in tokens:
    parquet_path = DATA_DIR / f'{token}_USD.parquet'
    if parquet_path.exists():
        print(f'{token}_USD.parquet already exists, skipping download')
        continue
    print(f'Downloading {token}...')
    update_historic_data(token, data_dir_str)
    # update_historic_data writes to combined_data/ subdir, move to data root
    combined_parquet = combined_dir / f'{token}_USD.parquet'
    combined_daily = combined_dir / f'{token}_USD_daily.csv'
    if combined_parquet.exists():
        combined_parquet.replace(parquet_path)
    if combined_daily.exists():
        combined_daily.replace(DATA_DIR / f'{token}_USD_daily.csv')
    if parquet_path.exists():
        print(f'{token} done')
    else:
        # Surface the failure here instead of as a confusing error in the load cell.
        print(f'WARNING: {token}_USD.parquet was not produced; check the download output above')

print('\nParquet files:')
for f in sorted(DATA_DIR.glob('*.parquet')):
    print(f'  {f.name} ({f.stat().st_size / 1e6:.1f} MB)')

In [None]:
# Load data
tokens = ['ETH', 'BTC', 'USDC', 'PAXG']
data_root = 'quantammsim/data'
price_df = get_historic_parquet_data(tokens, cols=['close'], root=data_root)
close_cols = [f'close_{t}' for t in tokens]
minute_prices = price_df[close_cols].values.astype(np.float64)

# Trim to the span where every asset is priced: drop leading rows before the
# first minute with no NaN and trailing rows after the last such minute.
valid_mask = ~np.any(np.isnan(minute_prices), axis=1)
if not valid_mask.any():
    # Without this guard argmax returns 0 for both ends and the whole
    # all-NaN array would be kept.
    raise ValueError(
        f'No minute where all of {tokens} have a close price; '
        f'check the parquet files in {data_root}'
    )
first_valid = np.argmax(valid_mask)  # first fully-valid row
last_valid = len(valid_mask) - np.argmax(valid_mask[::-1])  # one past the last fully-valid row
minute_prices = minute_prices[first_valid:last_valid]
n_gap_rows = int(np.count_nonzero(~valid_mask[first_valid:last_valid]))
if n_gap_rows:
    print(f'WARNING: {n_gap_rows} minutes inside the trimmed range still contain NaN prices; '
          'downstream statistics may be NaN')

n_assets = len(tokens)
minute_prices_jnp = jnp.array(minute_prices)
daily_log = compute_daily_log_prices(minute_prices_jnp)
n_days = daily_log.shape[0]

# Summary statistics of real day-over-day log returns, used as the
# reference in the evaluation cells.
real_returns = jnp.diff(daily_log, axis=0)
real_drift = jnp.mean(real_returns, axis=0)
real_vol = jnp.std(real_returns, axis=0)
real_corr = jnp.corrcoef(real_returns.T)

print(f'Data: {n_days} days, {n_assets} assets')
for i, t in enumerate(tokens):
    print(f'  {t}: drift={float(real_drift[i]):.6f}/day, vol={float(real_vol[i]):.6f}/day')
print('\nCorrelations:')
for i in range(n_assets):
    for j in range(i+1, n_assets):
        print(f'  {tokens[i]}-{tokens[j]}: {float(real_corr[i,j]):.3f}')

## Config

Adjust these for your GPU. The defaults are chosen automatically from the GPU tier detected in the first cell (T4 or A100).

In [None]:
# Per-tier model/training sizes:
#   (hidden & width size, depth, noise & initial-noise size, batch size, n_steps)
# 'A100' covers A100 / V100 / Colab Pro; every other tier uses the T4
# (free tier, 16GB VRAM) settings.
_TIER_SETTINGS = {
    'A100': (128, 3, 12, 1024, 30000),
    'T4': (64, 2, 8, 512, 20000),
}
tier_key = 'A100' if GPU_TIER == 'A100' else 'T4'
net_size, net_depth, noise_dim, batch, steps = _TIER_SETTINGS[tier_key]

# Settings shared by both tiers are spelled out inline.
CONFIG = dict(
    hidden_size=net_size, width_size=net_size, depth=net_depth,
    noise_size=noise_dim, initial_noise_size=noise_dim,
    batch_size=batch, window_len=50,
    n_steps=steps,
    generator_lr=2e-5, discriminator_lr=1e-4,
    drift_lambda=1.0,
    use_reversible_heun=True,
)

print(f'Config ({GPU_TIER}):')
print('\n'.join(f'  {name}: {value}' for name, value in CONFIG.items()))

## Train

In [None]:
import time

# Fixed seed so training is repeatable across sessions.
key = jax.random.PRNGKey(42)
train_start = time.time()

generator, vol_scale, history = train_sde_gan(
    minute_prices_jnp,
    n_assets=n_assets,
    key=key,
    verbose=True,
    **CONFIG,
)

elapsed = time.time() - train_start
# NOTE(review): wall time covers the whole train_sde_gan call, including any
# one-off compilation it performs, so ms/step may overstate steady-state cost.
print(f'\nDone in {elapsed:.0f}s ({elapsed/CONFIG["n_steps"]*1000:.1f}ms/step)')

## Evaluate

In [None]:
# Roll the trained generator forward from the first observed day for several
# horizons and compare per-asset drift, volatility and pairwise correlation
# with the real daily-return statistics.
y0 = daily_log[0]
key_eval = jax.random.PRNGKey(99)
N_PATHS = 2000
HORIZONS = (10, 30, 50, 100, 200)

# y0 replicated across all paths, with a leading length-1 time axis so it can
# be stacked in front of the generated paths (same for every horizon).
start_block = jnp.broadcast_to(y0[:, None], (n_assets, N_PATHS))[None, ...]

for horizon in HORIZONS:
    sim = generate_paths(generator, vol_scale, y0,
                         n_days=horizon, n_paths=N_PATHS, key=key_eval)
    # Prepending y0 makes the first simulated day contribute a return too.
    step_ret = jnp.diff(jnp.concatenate([start_block, sim], axis=0), axis=0)
    gen_drift = jnp.mean(step_ret, axis=(0, 2))
    # Cross-path std at each step, then averaged over the horizon.
    gen_vol = jnp.mean(jnp.std(step_ret, axis=2), axis=0)

    # Pool every (step, path) sample into one row per observation.
    pooled = step_ret.transpose(0, 2, 1).reshape(-1, n_assets)
    gen_corr = jnp.corrcoef(pooled.T)

    print(f'\n--- {horizon}-day paths ---')
    for idx, tok in enumerate(tokens):
        real_d = float(real_drift[idx])
        real_v = float(real_vol[idx])
        sim_d = float(gen_drift[idx])
        sim_v = float(gen_vol[idx])
        if abs(real_d) > 1e-8:
            drift_ratio = sim_d / real_d
        else:
            drift_ratio = float('inf')
        print(f'  {tok}: drift={sim_d:.6f} ({drift_ratio:.1f}x), vol={sim_v:.6f} ({sim_v/real_v:.2f}x)')
    print(f'  Correlations (real -> gen):')
    for a in range(n_assets):
        for b in range(a + 1, n_assets):
            print(f'    {tokens[a]}-{tokens[b]}: {float(real_corr[a,b]):.3f} -> {float(gen_corr[a,b]):.3f}')

## Lambda Sweep (optional)

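Set `SWEEP = True` in the cell below to sweep `drift_lambda` over several values and compare the resulting 50-day drift ratios and ETH-BTC correlation. Each value retrains the model from scratch with the current `CONFIG`, so the sweep costs about five full training runs.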

In [None]:
# Optional: sweep drift_lambda
# Retrains the generator once per lambda and reports, at a fixed horizon, the
# generated/real drift ratio for ETH and BTC plus the generated ETH-BTC
# correlation next to the real one computed from the data.
SWEEP = False  # Set to True to run

if SWEEP:
    LAMBDAS = [0.0, 0.1, 0.5, 1.0, 2.0]
    SWEEP_PATHS = 1000
    SWEEP_HORIZON = 50
    # Same start point and eval seed as the Evaluate cell, defined here so the
    # sweep does not depend on that cell having been run first.
    y0_sweep = daily_log[0]
    key_sweep_eval = jax.random.PRNGKey(99)
    # Look columns up by name instead of assuming ETH/BTC are columns 0 and 1.
    eth_i, btc_i = tokens.index('ETH'), tokens.index('BTC')
    real_eth_btc_corr = float(real_corr[eth_i, btc_i])
    y0_bc_s = jnp.broadcast_to(y0_sweep[:, None], (n_assets, SWEEP_PATHS))[None, ...]

    sweep_results = {}
    for lam in LAMBDAS:
        print(f'\n{"="*50}')
        print(f'drift_lambda = {lam}')
        print(f'{"="*50}')
        cfg = {**CONFIG, 'drift_lambda': lam}
        key_s = jax.random.PRNGKey(42)
        gen_s, vs_s, hist_s = train_sde_gan(
            minute_prices_jnp, n_assets=n_assets, key=key_s,
            verbose=True, **cfg,
        )
        # Quick fixed-horizon eval
        paths_s = generate_paths(gen_s, vs_s, y0_sweep, n_days=SWEEP_HORIZON,
                                 n_paths=SWEEP_PATHS, key=key_sweep_eval)
        full_s = jnp.concatenate([y0_bc_s, paths_s], axis=0)
        ret_s = jnp.diff(full_s, axis=0)
        d_s = jnp.mean(ret_s, axis=(0, 2))
        v_s = jnp.mean(jnp.std(ret_s, axis=2), axis=0)
        flat_s = ret_s.transpose(0, 2, 1).reshape(-1, n_assets)
        gc_s = jnp.corrcoef(flat_s.T)

        drift_ratios = {}
        for i, t in enumerate(tokens):
            rd = float(real_drift[i])
            drift_ratios[t] = float(d_s[i]) / rd if abs(rd) > 1e-8 else float('inf')
        sweep_results[lam] = {
            'drift_ratios': drift_ratios,
            'vol_ratios': {t: float(v_s[i])/float(real_vol[i]) for i, t in enumerate(tokens)},
            'eth_btc_corr': float(gc_s[eth_i, btc_i]),
        }

    # Summary table
    print(f'\n{"="*60}')
    print(f'SWEEP SUMMARY ({SWEEP_HORIZON}-day horizon)')
    print(f'{"="*60}')
    print(f'{"lambda":>8} | {"ETH drift":>10} {"BTC drift":>10} {"ETH-BTC corr":>13}')
    print('-' * 50)
    for lam, r in sweep_results.items():
        ed = r['drift_ratios'].get('ETH', 0)
        bd = r['drift_ratios'].get('BTC', 0)
        ec = r['eth_btc_corr']
        print(f'{lam:>8.1f} | {ed:>9.1f}x {bd:>9.1f}x {ec:>12.3f} (real: {real_eth_btc_corr:.3f})')