# TemporalFusionSignalNet — Colab Training

Train multi-modal trading signal models on a Google Colab GPU runtime.

**Before running:**
1. Runtime → Change runtime type → **T4 GPU** (or A100 if available)
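2. The repository is cloned over HTTPS below; if you need authenticated access (e.g. a private fork), configure a GitHub SSH key or access token first
3. Optionally set `FRED_API_KEY` in the Configuration cell to fetch FRED macro data (without it, macro features are zero-filled)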

## 1. Setup — Clone Repo & Install Dependencies

In [None]:
# Mount Google Drive so checkpoints written under /content/drive persist
# beyond the lifetime of this Colab runtime.
from google.colab import drive as colab_drive

DRIVE_MOUNT_POINT = '/content/drive'
colab_drive.mount(DRIVE_MOUNT_POINT)

In [None]:
# Clone the repository
!git clone https://github.com/amirgolp/fuzzy-barnacle.git /content/quant
%cd /content/quant

In [None]:
# Install with ML dependencies
!pip install -e ".[ml]" -q

In [None]:
# Verify the environment: PyTorch version, CUDA availability, the visible GPU,
# and that the project package imports cleanly.
import torch
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    # The device-properties field is `total_memory` (in bytes); the previous
    # `total_mem` attribute does not exist and raised AttributeError on GPU runtimes.
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

from quantdash.ml.models.signal_net import TemporalFusionSignalNet
from quantdash.ml.config import ModelConfig, ASSET_CONFIGS
print(f"\nAssets: {list(ASSET_CONFIGS.keys())}")
print("Setup OK ✓")

## 2. Configuration

In [None]:
# ===== CONFIGURE THESE =====
SYMBOL = "GC=F"  # Asset to train: GC=F, BTC-USD, CL=F, SPY, QQQ, AAPL
EPOCHS = 50
BATCH_SIZE = 256  # Colab GPU can handle larger batches
LEARNING_RATE = 1e-3
RUN_RL = False  # Set True to run PPO fine-tuning after supervised
FETCH_CROSS_ASSETS = True  # Fetch correlated asset data

# Optional: FRED API key for macro data (leave empty for zero-fill).
# Get a free key at https://fred.stlouisfed.org/docs/api/api_key.html
# Only a non-empty value is exported. This way a key already present in the
# environment is not overwritten with "" when this cell runs.
import os
FRED_API_KEY = ""
if FRED_API_KEY:
    os.environ["FRED_API_KEY"] = FRED_API_KEY

# Fail fast on an unknown symbol, before any Drive folder is created and
# before the dataset cell hits a bare KeyError.
from quantdash.ml.config import ASSET_CONFIGS
if SYMBOL not in ASSET_CONFIGS:
    raise ValueError(f"Unknown SYMBOL {SYMBOL!r}; choose one of {sorted(ASSET_CONFIGS)}")

# Checkpoint directory (Google Drive for persistence); '=' is dropped and '-'
# becomes '_' so the symbol is a clean folder name.
SAVE_DIR = f"/content/drive/MyDrive/quant_models/{SYMBOL.replace('=','').replace('-','_')}"
os.makedirs(SAVE_DIR, exist_ok=True)
print(f"Training {SYMBOL}, saving to {SAVE_DIR}")

## 3. Build Dataset

In [None]:
import logging

# force=True: basicConfig() does nothing when the root logger already has
# handlers (notebook kernels may install some). Without it, the INFO-level
# logs configured here would not appear.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(name)s] %(message)s",
    force=True,
)

from pathlib import Path
from quantdash.ml.config import ASSET_CONFIGS
from quantdash.ml.data.builder import build_dataset, get_dataset_path
from quantdash.ml.data.macro import fetch_all_macro

asset_config = ASSET_CONFIGS[SYMBOL]

# Fetch FRED macro series only when an API key is configured. Without a key,
# or if the fetch fails, macro_df stays None and is passed to build_dataset as-is.
macro_df = None
if os.environ.get("FRED_API_KEY"):
    try:
        macro_df = fetch_all_macro()
        print(f"Macro data: {len(macro_df)} observations, {macro_df.shape[1]} series")
    except Exception as e:
        # Best-effort: a macro data outage should not block training.
        print(f"Macro fetch failed: {e}, using zeros")

# Build the dataset (per the original note: fetches OHLCV, computes features
# and labels) and save it to Drive next to the checkpoints.
dataset_path = Path(SAVE_DIR) / "dataset.h5"
dataset = build_dataset(
    symbol=SYMBOL,
    macro_df=macro_df,
    save_path=dataset_path,
    fetch_cross=FETCH_CROSS_ASSETS,
)
print(f"\nDataset: {len(dataset)} samples")

In [None]:
# Report how the generated labels split across BUY / HOLD / SELL.
import pandas as pd
from quantdash.ml.data.labeling import label_distribution

label_series = pd.Series(dataset.labels)
dist = label_distribution(label_series)
print(f"Labels — BUY: {dist['buy_pct']}%, HOLD: {dist['hold_pct']}%, SELL: {dist['sell_pct']}%")
print(f"Total valid: {dist['total']}")

## 4. Create Model

In [None]:
from quantdash.ml.models.signal_net import TemporalFusionSignalNet

# Copy the asset's architecture config, then overwrite each input width with
# axis 1 of the matching dataset feature array. Cross-asset width is only
# overridden when the dataset actually has cross-asset features.
config = asset_config.arch_config.model_copy()
_dim_sources = {
    "price_channels": dataset.price_features,
    "num_pattern_features": dataset.pattern_features,
    "num_macro_features": dataset.macro_session_features,
}
if dataset.cross_asset_features is not None:
    _dim_sources["cross_asset_channels"] = dataset.cross_asset_features
for _field, _features in _dim_sources.items():
    setattr(config, _field, _features.shape[1])

model = TemporalFusionSignalNet(config)
print(f"Model parameters: {model.count_parameters():,}")
print(f"Price channels: {config.price_channels}")
print(f"Pattern features: {config.num_pattern_features}")
print(f"Macro features: {config.num_macro_features}")
print(f"Cross-asset channels: {config.cross_asset_channels}")

## 5. Supervised Training (Phase 1)

In [None]:
from quantdash.ml.data.splits import walk_forward_splits
from quantdash.ml.training.supervised import SupervisedTrainer

# Split the labelled bars into walk-forward train/validation windows using the
# asset's walk-forward settings, then list each fold's index ranges.
n_bars = len(dataset.labels)
folds = walk_forward_splits(n_bars, config=asset_config.walk_forward_config)
print(f"Walk-forward folds: {len(folds)}")
for f in folds:
    print(f"  Fold {f.fold_idx}: train=[{f.train_start}:{f.train_end}] ({f.train_size} bars), "
          f"val=[{f.val_start}:{f.val_end}] ({f.val_size} bars)")

In [None]:
# Start from the asset's training config and apply Colab-specific overrides.
_has_cuda = torch.cuda.is_available()

train_config = asset_config.training_config.model_copy()
for _field, _value in (
    ("max_epochs", EPOCHS),
    ("batch_size", BATCH_SIZE),
    ("learning_rate", LEARNING_RATE),
    ("use_fp16", _has_cuda),  # FP16 only with GPU
    ("num_workers", 2),  # Colab has limited CPUs
):
    setattr(train_config, _field, _value)

trainer = SupervisedTrainer(
    model=model,
    train_config=train_config,
    device="cuda" if _has_cuda else "cpu",
    save_dir=Path(SAVE_DIR),
)

print(f"Device: {trainer.device}")
print(f"FP16: {train_config.use_fp16}")
print(f"Batch size: {train_config.batch_size}")
print(f"Max epochs: {train_config.max_epochs}")

In [None]:
# Train! Runs supervised training over every walk-forward fold built above.
# The returned `history` is read by the plotting cell (history["tracker"]).
history = trainer.train_walk_forward(dataset, folds)
print("\nSupervised training complete!")

In [None]:
# Plot per-fold training curves: train/val loss (left) and validation
# macro-F1 (right), then save the figure next to the checkpoints.
import re
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

tracker = history["tracker"]

# Collect fold indices from keys such as "fold0_train_loss". Only keys of the
# form "fold<digits>_..." are used. Any other "fold..." key would previously
# crash int() parsing.
_fold_key = re.compile(r"^fold(\d+)_")
fold_indices = set()
for key in tracker:
    match = _fold_key.match(key)
    if match:
        fold_indices.add(int(match.group(1)))

for fold_idx in sorted(fold_indices):
    train_loss = tracker.get(f"fold{fold_idx}_train_loss", [])
    val_loss = tracker.get(f"fold{fold_idx}_val_loss", [])
    val_f1 = tracker.get(f"fold{fold_idx}_val_f1", [])

    # Each series is drawn independently, so validation loss still shows up
    # even if a fold recorded no training loss.
    if train_loss:
        axes[0].plot(train_loss, label=f"Fold {fold_idx} train", alpha=0.7)
    if val_loss:
        axes[0].plot(val_loss, label=f"Fold {fold_idx} val", linestyle="--", alpha=0.7)
    if val_f1:
        axes[1].plot(val_f1, label=f"Fold {fold_idx}")

for ax, title in ((axes[0], "Loss"), (axes[1], "Validation F1 (macro)")):
    ax.set_title(title)
    ax.set_xlabel("Epoch")
    # Only add a legend when something labelled was plotted. Otherwise
    # matplotlib warns "No artists with labels found to put in legend".
    if ax.get_legend_handles_labels()[0]:
        ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
curves_path = f"{SAVE_DIR}/training_curves.png"
plt.savefig(curves_path, dpi=150)
plt.show()
print(f"Saved to {curves_path}")

## 6. RL Fine-tuning (Phase 2) — Optional

In [None]:
# Phase 2 (optional): PPO fine-tuning of the supervised model, controlled by
# the RUN_RL flag in the configuration cell.
if not RUN_RL:
    print("Skipping RL fine-tuning (set RUN_RL = True to enable)")
else:
    from quantdash.ml.training.rl_finetune import train_ppo

    rl_device = "cuda" if torch.cuda.is_available() else "cpu"
    rl_metrics = train_ppo(
        model=trainer.model,
        dataset=dataset,
        rl_config=asset_config.rl_config,
        fee_bps=asset_config.labeling_config.fee_bps,
        device=rl_device,
        save_dir=Path(SAVE_DIR),
    )
    print(f"RL fine-tuning complete: {rl_metrics}")

## 7. Evaluate Model

In [None]:
from quantdash.ml.training.callbacks import CheckpointSaver
import numpy as np
from torch.utils.data import DataLoader
from sklearn.metrics import classification_report, confusion_matrix

# Class ids in CLASS_NAMES order. Assumes labels/actions are encoded
# 0=SELL, 1=HOLD, 2=BUY, the same ordering target_names already relied on.
# TODO confirm against the labeling module.
CLASS_IDS = [0, 1, 2]
CLASS_NAMES = ["SELL", "HOLD", "BUY"]

# Load the best checkpoint from SAVE_DIR if the trainer wrote one.
best_path = Path(SAVE_DIR) / "model_best.pt"
if best_path.exists():
    # weights_only=True restricts unpickling to tensors and plain containers.
    state = torch.load(best_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state["model_state_dict"])
    print(f"Loaded best model (epoch {state['epoch']}, metric={state['metric_value']:.4f})")
else:
    print("No checkpoint found, using last model state")

model.eval()
model.to("cuda" if torch.cuda.is_available() else "cpu")
device = next(model.parameters()).device

# NOTE(review): this scores every sample in `dataset`, including samples the
# model was trained on, so the metrics below are optimistic. Judge held-out
# quality from the per-fold validation metrics.
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)
all_preds = []
all_targets = []
all_confs = []

with torch.no_grad():
    for batch in loader:
        out = model.predict(
            batch["price"].to(device),
            batch["volume"].to(device),
            batch["pattern"].to(device),
            batch["news"].to(device),
            batch["macro"].to(device),
            batch["cross_asset"].to(device),
        )
        all_preds.extend(out["action"].cpu().tolist())
        all_targets.extend(batch["label"].tolist())
        all_confs.extend(out["confidence"].cpu().tolist())

preds = np.array(all_preds)
targets = np.array(all_targets)
confs = np.array(all_confs)

print("\nClassification Report:")
# labels= keeps the report aligned with CLASS_NAMES even when a class never
# occurs in targets or predictions. Without it, sklearn infers fewer classes
# than there are target_names and raises ValueError. zero_division=0 reports
# 0.0 for undefined precision/recall instead of warning.
print(classification_report(
    targets, preds, labels=CLASS_IDS, target_names=CLASS_NAMES, zero_division=0,
))

print("Confusion Matrix:")
# Rows are true classes, columns are predicted classes, both in CLASS_IDS order.
print(confusion_matrix(targets, preds, labels=CLASS_IDS))

correct = preds == targets
print(f"\nMean confidence: {confs.mean():.3f}")
# Guard empty selections: the mean of an empty array is nan plus a RuntimeWarning.
if correct.any():
    print(f"Confidence when correct: {confs[correct].mean():.3f}")
else:
    print("Confidence when correct: n/a (no correct predictions)")
if (~correct).any():
    print(f"Confidence when wrong: {confs[~correct].mean():.3f}")
else:
    print("Confidence when wrong: n/a (no wrong predictions)")

## 8. Train All Assets (Batch)

In [None]:
# Train all 6 assets sequentially. Set TRAIN_ALL = True to run (off by default).
# WARNING: This takes a long time — each asset ~30-60 min on T4
TRAIN_ALL = False
ALL_SYMBOLS = ["GC=F", "BTC-USD", "CL=F", "SPY", "QQQ", "AAPL"]

if TRAIN_ALL:
    # Use the GPU when present instead of hard-coding "cuda".
    batch_device = "cuda" if torch.cuda.is_available() else "cpu"

    for sym in ALL_SYMBOLS:
        print(f"\n{'='*60}")
        print(f"Training {sym}")
        print(f"{'='*60}")

        sym_config = ASSET_CONFIGS[sym]
        sym_save = f"/content/drive/MyDrive/quant_models/{sym.replace('=','').replace('-','_')}"
        os.makedirs(sym_save, exist_ok=True)

        # Same inputs as the single-asset dataset cell (macro + cross-asset data).
        sym_dataset = build_dataset(
            symbol=sym,
            macro_df=macro_df,
            save_path=Path(sym_save) / "dataset.h5",
            fetch_cross=FETCH_CROSS_ASSETS,
        )

        # Size the model inputs from this asset's dataset, as the "Create Model"
        # cell does. The stock arch_config widths may not match these features.
        sym_arch = sym_config.arch_config.model_copy()
        sym_arch.price_channels = sym_dataset.price_features.shape[1]
        sym_arch.num_pattern_features = sym_dataset.pattern_features.shape[1]
        sym_arch.num_macro_features = sym_dataset.macro_session_features.shape[1]
        if sym_dataset.cross_asset_features is not None:
            sym_arch.cross_asset_channels = sym_dataset.cross_asset_features.shape[1]
        sym_model = TemporalFusionSignalNet(sym_arch)

        # Each asset's own training config, with the same Colab overrides,
        # instead of reusing the config derived from SYMBOL.
        sym_train_config = sym_config.training_config.model_copy()
        sym_train_config.max_epochs = EPOCHS
        sym_train_config.batch_size = BATCH_SIZE
        sym_train_config.learning_rate = LEARNING_RATE
        sym_train_config.use_fp16 = torch.cuda.is_available()
        sym_train_config.num_workers = 2

        sym_folds = walk_forward_splits(len(sym_dataset.labels), config=sym_config.walk_forward_config)
        sym_trainer = SupervisedTrainer(
            model=sym_model, train_config=sym_train_config,
            device=batch_device, save_dir=Path(sym_save),
        )
        sym_trainer.train_walk_forward(sym_dataset, sym_folds)
        print(f"{sym} done — checkpoint at {sym_save}/model_best.pt")

## 9. Download Checkpoint

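Checkpoints are saved to Google Drive at:
```
/content/drive/MyDrive/quant_models/{symbol}/model_best.pt
```
where `{symbol}` is the ticker with `=` removed and `-` replaced by `_` (e.g. `GC=F` → `GCF`, `BTC-USD` → `BTC_USD`).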

To use locally for inference:
1. Download `model_best.pt` from Google Drive
2. Place in `models/` directory
3. Run: `python scripts/run_inference.py --symbols GC=F`

In [None]:
# Alternatively, download the best checkpoint straight to this machine
# through the browser, if the trainer produced one.
from pathlib import Path
from google.colab import files

best_path = Path(SAVE_DIR) / "model_best.pt"
if not best_path.exists():
    print("No checkpoint to download")
else:
    files.download(str(best_path))
    print(f"Downloading {best_path}")