# Notebook B: TFT + RL (PPO) Training
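**Run on Colab Pro+ with a large GPU (H100 or A100-80GB; the recorded run below used an A100-80GB)** | Part 2 of 3 parallel sessions — expects Notebook A to cache the feature parquet on Drive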
- Trains Temporal Fusion Transformer (TFT)
- Trains PPO agent for portfolio allocation

In [1]:
# === ENVIRONMENT SETUP ===
import subprocess, sys, os

# Clone the project on a fresh runtime, then (re)install it in editable mode so
# `import quant_lab` resolves to this checkout.
if not os.path.exists('/content/quant-lab'):
    subprocess.run(
        ['git', 'clone', 'https://github.com/Mohit1053/quant-lab.git', '/content/quant-lab'],
        check=True,
    )
os.chdir('/content/quant-lab')
subprocess.run([sys.executable, '-m', 'pip', 'install', '-q', '-e', '.'], check=True)

# Mount Google Drive: it is the hand-off point with the other parallel notebooks
# (Notebook A caches its data there).
from google.colab import drive
drive.mount('/content/drive', force_remount=False)

# Persistent artefact tree on Drive (data caches, model checkpoints, MLflow runs).
from pathlib import Path
DRIVE_DIR = Path('/content/drive/MyDrive/quant_lab')
drive_subdirs = (
    'data/raw',
    'data/cleaned',
    'data/features',
    'outputs/models/tft',
    'outputs/models/rl/ppo',
    'outputs/mlruns',
)
for subdir in drive_subdirs:
    (DRIVE_DIR / subdir).mkdir(parents=True, exist_ok=True)

# Report the accelerator; BF16 support is presumably relevant to the mixed-precision Trainer below.
import torch
if not torch.cuda.is_available():
    print("WARNING: No GPU!")
else:
    gpu_props = torch.cuda.get_device_properties(0)
    gpu_mem_gb = gpu_props.total_memory / 1e9
    print(f"GPU: {gpu_props.name} ({gpu_mem_gb:.1f} GB) | BF16: {torch.cuda.is_bf16_supported()}")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
GPU: NVIDIA A100-SXM4-80GB (85.1 GB) | BF16: True


In [3]:

!pip uninstall -y numpy pandas scipy scikit-learn
!pip install --no-cache-dir numpy==1.26.4 pandas==2.2.2 scipy==1.11.4 scikit-learn==1.4.2


Found existing installation: numpy 1.26.4
Uninstalling numpy-1.26.4:
  Successfully uninstalled numpy-1.26.4
Found existing installation: pandas 2.2.2
Uninstalling pandas-2.2.2:
  Successfully uninstalled pandas-2.2.2
Found existing installation: scipy 1.16.3
Uninstalling scipy-1.16.3:
  Successfully uninstalled scipy-1.16.3
Found existing installation: scikit-learn 1.6.1
Uninstalling scikit-learn-1.6.1:
  Successfully uninstalled scikit-learn-1.6.1
Collecting numpy==1.26.4
  Downloading numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.0/61.0 kB[0m [31m39.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting pandas==2.2.2
  Downloading pandas-2.2.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (19 kB)
Collecting scipy==1.11.4
  Downloading scipy-1.11.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (60 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━

In [6]:
# === LOAD DATA FROM DRIVE (cached by Notebook A) ===
import shutil, time

drive_features = DRIVE_DIR / 'data/features/nifty50_features.parquet'
drive_cleaned = DRIVE_DIR / 'data/cleaned/nifty50_cleaned.parquet'
local_features = Path('data/features/nifty50_features.parquet')

# Notebook A runs in parallel and caches the feature parquet on Drive. Poll for it
# (max 10 min) before falling back to rebuilding the data in this session.
# NOTE(review): this only checks that the file exists, not that Notebook A has finished
# writing it — if partial reads ever show up, also wait for its size to stop changing.
WAIT_TIMEOUT_S = 10 * 60
POLL_INTERVAL_S = 10
if not drive_features.exists():
    print("Waiting for Notebook A to cache data to Drive...")
    waited_s = 0
    # The existence check runs after every sleep (including the last one), so a file
    # that lands during the final interval is used instead of triggering a rebuild.
    while not drive_features.exists() and waited_s < WAIT_TIMEOUT_S:
        time.sleep(POLL_INTERVAL_S)
        waited_s += POLL_INTERVAL_S
        if waited_s % 60 == 0:
            print(f"  Still waiting... ({waited_s}s)")
    if not drive_features.exists():
        print("Timeout! Downloading data ourselves...")
        subprocess.run([sys.executable, 'scripts/ingest_data.py'], check=True)
        subprocess.run([sys.executable, 'scripts/compute_features.py'], check=True)

if drive_features.exists():
    Path('data/features').mkdir(parents=True, exist_ok=True)
    Path('data/cleaned').mkdir(parents=True, exist_ok=True)
    shutil.copy(drive_features, local_features)
    if drive_cleaned.exists():
        shutil.copy(drive_cleaned, 'data/cleaned/nifty50_cleaned.parquet')
    print("Data loaded from Drive cache!")

# Fail with a clear message (instead of an opaque parquet read error) if neither the
# Drive cache nor the local rebuild produced the features file.
if not local_features.exists():
    raise FileNotFoundError(
        f"Feature file not found at {local_features} "
        "(Drive cache missing and local rebuild did not create it)"
    )

import pandas as pd
df = pd.read_parquet(local_features)
print(f"Features: {df.shape[0]} rows, {df['ticker'].nunique()} tickers")

Data loaded from Drive cache!
Features: 177187 rows, 49 tickers


## Temporal Fusion Transformer (TFT)
GRN blocks + Variable Selection + LSTM encoder + interpretable multi-head attention
- Large-GPU config (H100 / A100-80GB): d_model=256, 4 attention heads, 2 encoder layers, batch_size=128, 63-step input sequences

In [13]:
# === TFT TRAINING (H100 optimized) ===
import time
import shutil
from pathlib import Path
from quant_lab.utils.seed import set_global_seed
from quant_lab.utils.device import get_device
from quant_lab.data.datasets import TemporalSplit
from quant_lab.data.datamodule import QuantDataModule, DataModuleConfig
from quant_lab.data.storage.parquet_store import ParquetStore
from quant_lab.features.engine import FeatureEngine
from quant_lab.models.tft.model import TFTForecaster, TFTConfig
from quant_lab.models.transformer.model import MultiTaskLoss, TransformerConfig
from quant_lab.training.trainer import Trainer, TrainerConfig

# -----------------------------
# Setup
# -----------------------------
# Seed before any model/data construction so weight init and batching are reproducible
# (assumes set_global_seed seeds torch/numpy/random — TODO confirm in quant_lab.utils.seed).
set_global_seed(42)
device = get_device()

# Load features
# Reads the local copy that the data-loading cell placed at data/features/nifty50_features.parquet.
store = ParquetStore(base_dir='data/features')
feature_df = store.load('nifty50_features')

# The engine is only used here to pick model input columns from the already-computed
# features; presumably get_feature_columns filters feature_df's columns by these
# feature families/windows — verify in quant_lab.features.engine.
engine = FeatureEngine(
    enabled_features=['log_returns', 'realized_volatility', 'momentum', 'max_drawdown'],
    windows={'short': [1, 5], 'medium': [21], 'long': [63]},
)
feature_cols = engine.get_feature_columns(feature_df)

# Chronological split: train through 2021-12-31, validate through 2023-06-30.
# The RL cell below hard-codes the same boundaries — keep them in sync.
split = TemporalSplit(train_end='2021-12-31', val_end='2023-06-30')
dm = QuantDataModule(
    feature_df, feature_cols, split,
    # 63-row lookback windows (presumably ~3 months of trading days) with log_return_1d
    # as the target; the forecast horizon is defined inside the data module — TODO confirm.
    DataModuleConfig(sequence_length=63, target_col='log_return_1d', batch_size=128, num_workers=2),
)
dm.setup()
train_loader = dm.train_dataloader()
val_loader = dm.val_dataloader()

# -----------------------------
# TFT Model
# -----------------------------
# num_features comes from the data module so the model's input width matches the
# selected feature columns.
tft_cfg = TFTConfig(
    num_features=dm.num_features,
    d_model=256,
    nhead=4,
    num_encoder_layers=2,
    lstm_layers=1,
    lstm_hidden=256,
    dropout=0.1,
    grn_hidden=128,
)

# Build the TFT forecaster from the config above.
model = TFTForecaster(tft_cfg)

# Loss function
# MultiTaskLoss is configured through a TransformerConfig; presumably only its loss-related
# fields are read here — TODO confirm. d_model is repeated by hand and should stay equal to
# tft_cfg.d_model.
# With distribution_type='gaussian' the reported loss can be negative (-3.04 final train loss
# in the recorded run), presumably because the distribution term is a Gaussian NLL.
# direction_num_classes/direction_threshold presumably define a 3-way up/flat/down label
# around a +/-0.5% band; the *_weight values scale each task's term — verify in MultiTaskLoss.
loss_cfg = TransformerConfig(
    num_features=dm.num_features,
    d_model=256,
    distribution_type='gaussian',
    direction_num_classes=3,
    direction_threshold=0.005,
    volatility_enabled=True,
    distribution_weight=1.0,
    direction_weight=0.3,
    volatility_weight=0.3,
)
loss_fn = MultiTaskLoss(loss_cfg)

print(f"TFT parameters: {sum(p.numel() for p in model.parameters()):,}")

# -----------------------------
# Trainer
# -----------------------------
# patience presumably drives early stopping, so epochs=100 is an upper bound.
# checkpoint_dir receives the trainer's checkpoints (best.pt / last.pt in the recorded run).
trainer_config = TrainerConfig(
    epochs=100,
    learning_rate=1e-3,
    weight_decay=1e-5,
    warmup_steps=500,
    max_grad_norm=1.0,
    patience=15,
    mixed_precision=True,
    checkpoint_dir='outputs/models/tft',
)

trainer = Trainer(model=model, loss_fn=loss_fn, config=trainer_config, device=device)

# -----------------------------
# Training
# -----------------------------
# perf_counter is monotonic, so the measured duration is immune to wall-clock adjustments.
start = time.perf_counter()
history = trainer.fit(train_loader, val_loader)
elapsed = time.perf_counter() - start

print(f"\nTFT training done in {elapsed/60:.1f} min")
# Guard both histories: indexing [-1] on an empty list would raise IndexError if fit
# returned before completing an epoch.
if history['train_loss']:
    print(f"Final train loss: {history['train_loss'][-1]:.6f}")
if history['val_loss']:
    print(f"Best val loss: {min(history['val_loss']):.6f}")

# -----------------------------
# Save model to Drive
# -----------------------------
# NOTE(review): `model` holds whatever weights it has when fit() returns. Unless the Trainer
# restores its best checkpoint on early stopping, final_model.pt is the *last* epoch, not the
# best one — best.pt from checkpoint_dir is copied alongside it below.
model_path = Path('outputs/models/tft/final_model.pt')
model.save(model_path)

drive_tft_dir = DRIVE_DIR / 'outputs/models/tft'
drive_tft_dir.mkdir(parents=True, exist_ok=True)

# Copy regular files only: shutil.copy raises on directories (e.g. if the trainer ever
# writes a subfolder into checkpoint_dir).
for f in Path('outputs/models/tft').glob('*'):
    if f.is_file():
        shutil.copy(f, drive_tft_dir / f.name)

print("TFT model saved to Drive!")


2026-02-19 10:27:09 [info     ] using_gpu                      memory_gb=79.3 name='NVIDIA A100-SXM4-80GB'
2026-02-19 10:27:09 [info     ] parquet_loaded                 cols=23 path=data/features/nifty50_features.parquet rows=177187
2026-02-19 10:27:10 [debug    ] dataset_created                num_features=15 sequence_length=63 total_rows=2961 valid_samples=2710
2026-02-19 10:27:10 [debug    ] dataset_created                num_features=15 sequence_length=63 total_rows=2961 valid_samples=2710
2026-02-19 10:27:10 [debug    ] dataset_created                num_features=15 sequence_length=63 total_rows=2961 valid_samples=2710
2026-02-19 10:27:10 [debug    ] dataset_created                num_features=15 sequence_length=63 total_rows=2961 valid_samples=2710
2026-02-19 10:27:10 [debug    ] dataset_created                num_features=15 sequence_length=63 total_rows=2961 valid_samples=2710
2026-02-19 10:27:10 [debug    ] dataset_created                num_features=15 sequence_length=63 tot

In [None]:
import inspect
import quant_lab.models.tft.model as tft_model

# Debug helper (not needed for a normal run): list the classes defined/imported by the
# TFT module, e.g. to find the correct model class name.
# getmembers returns (name, value) pairs sorted by name; the predicate keeps only classes.
class_names = [name for name, _ in inspect.getmembers(tft_model, inspect.isclass)]
print(class_names)


## RL Portfolio Allocation (PPO)
Proximal Policy Optimization with portfolio environment
- Reward = Sharpe - MDD penalty - trading costs - turnover penalty
- 2M PPO timesteps (≈65 min on the A100-80GB used for the recorded run)

In [14]:
# === RL PPO TRAINING (H100 optimized) ===
import time
import numpy as np
from quant_lab.rl.environments.portfolio_env import PortfolioEnvConfig
from quant_lab.rl.environments.reward import RewardConfig
from quant_lab.rl.training import train_rl, RLTrainingConfig

set_global_seed(42)

# Build feature tensors
# Feature columns = every column except identifiers and raw OHLCV/price fields.
# NOTE(review): log_return_1d (the TFT cell's target_col) is not excluded, so the agent
# observes the same-day return that build_feature_tensor also emits as the reward series —
# confirm the env rewards returns[t+1] for an action taken on features[t].
base_cols = {'date', 'ticker', 'open', 'high', 'low', 'close', 'volume', 'adj_close'}
feat_cols = list(filter(lambda col: col not in base_cols, feature_df.columns))

def build_feature_tensor(df, feat_cols, start, end):
    """Pivot a long (date, ticker) panel into dense arrays for the RL environment.

    Rows are kept when ``start < date <= end`` (start exclusive, end inclusive), so two
    windows sharing a boundary date do not overlap.

    Args:
        df: Long-format frame with at least ``date``, ``ticker`` and ``feat_cols``.
            If ``log_return_1d`` is absent it is derived from ``adj_close`` on the full
            history (before windowing), per ticker in date order. ``df`` is not modified.
        feat_cols: Columns stacked along the last axis of the feature tensor.
        start: Exclusive lower date bound (anything ``pd.to_datetime`` comparisons accept).
        end: Inclusive upper date bound.

    Returns:
        ``(features, returns)``: float32 arrays of shape
        ``(n_dates, n_tickers, len(feat_cols))`` and ``(n_dates, n_tickers)``, with dates
        and tickers in ascending order. Missing (date, ticker) cells and NaNs are 0.0.

    Raises:
        ValueError: if any row in the window has a missing ticker.

    NOTE(review): returns[t] is the same-day log return, which is also part of features[t]
    when ``log_return_1d`` is in ``feat_cols`` — confirm the env applies weights chosen at t
    to returns[t+1].
    """
    import pandas as pd
    df = df.copy().reset_index(drop=True)  # unique index so the transform below aligns
    df['date'] = pd.to_datetime(df['date'])
    if 'log_return_1d' not in df.columns:
        # Derive on the full history, sorted per ticker, *before* windowing: deriving after
        # the filter zeroes each ticker's first return in the window, and shift(1) on
        # unsorted rows pairs the wrong prices.
        ordered = df.sort_values(['ticker', 'date'], kind='stable')
        df['log_return_1d'] = ordered.groupby('ticker')['adj_close'].transform(
            lambda s: np.log(s / s.shift(1))
        )
    df = df[(df['date'] > start) & (df['date'] <= end)]

    # Vectorised scatter (replaces a per-row iterrows loop): factorize(sort=True) maps each
    # row to its (date, ticker) slot in ascending order.
    date_idx, dates = pd.factorize(df['date'], sort=True)
    ticker_idx, tickers = pd.factorize(df['ticker'], sort=True)
    if (ticker_idx < 0).any():
        # factorize encodes missing values as -1, which would silently hit the last slot.
        raise ValueError("build_feature_tensor: rows with a missing ticker in the selected window")

    features = np.zeros((len(dates), len(tickers), len(feat_cols)), dtype=np.float32)
    returns = np.zeros((len(dates), len(tickers)), dtype=np.float32)
    features[date_idx, ticker_idx, :] = df[feat_cols].to_numpy(dtype=np.float32, na_value=np.nan)
    returns[date_idx, ticker_idx] = df['log_return_1d'].fillna(0.0).to_numpy(dtype=np.float32)
    return np.nan_to_num(features, nan=0.0), returns

train_features, train_returns = build_feature_tensor(feature_df, feat_cols, '1900-01-01', '2021-12-31')
val_features, val_returns = build_feature_tensor(feature_df, feat_cols, '2021-12-31', '2023-06-30')
print(f"Train: {train_features.shape}, Val: {val_features.shape}")
# Train and val are pivoted independently, so a ticker missing from one window would
# silently change the asset axis (and presumably the env's observation/action sizes).
if train_features.shape[1:] != val_features.shape[1:]:
    raise ValueError(
        f"Train/val asset or feature axes differ: {train_features.shape[1:]} vs {val_features.shape[1:]}"
    )

env_config = PortfolioEnvConfig(initial_cash=1_000_000, max_weight=0.20, rebalance_frequency=5)
reward_config = RewardConfig(lambda_mdd=0.5, lambda_turnover=0.01, commission_bps=10.0, slippage_bps=5.0, spread_bps=5.0)
# NOTE(review): the recorded run reports std_reward=0.0 over its eval episodes, i.e. evaluation
# appears deterministic — n_eval_episodes > 1 may just repeat the same episode.
training_config = RLTrainingConfig(
    algorithm='ppo',
    total_timesteps=2_000_000,
    eval_freq=50_000,
    n_eval_episodes=5,
    checkpoint_dir='outputs/models/rl/ppo',
)

# perf_counter is monotonic, so the measured duration is immune to wall-clock adjustments.
start = time.perf_counter()
result = train_rl(
    train_features=train_features, train_returns=train_returns,
    val_features=val_features, val_returns=val_returns,
    config=training_config, env_config=env_config, reward_config=reward_config,
    device='auto',
)
elapsed = time.perf_counter() - start

print(f"\nPPO training done in {elapsed/60:.1f} min")
for k, v in result['train_metrics'].items():
    print(f"  Train {k}: {v:.4f}")
# Train metrics are in-sample; also show validation metrics when train_rl returns them
# (the 'val_metrics' key name is an assumption — nothing is printed if it is absent).
if isinstance(result, dict):
    for k, v in (result.get('val_metrics') or {}).items():
        print(f"  Val {k}: {v:.4f}")

# Save to Drive
import shutil
drive_ppo_dir = DRIVE_DIR / 'outputs/models/rl/ppo'
drive_ppo_dir.mkdir(parents=True, exist_ok=True)
# Copy regular files only: shutil.copy raises on directories.
for f in Path('outputs/models/rl/ppo').glob('*'):
    if f.is_file():
        shutil.copy(f, drive_ppo_dir / f.name)
print("PPO agent saved to Drive!")

Gym has been unmaintained since 2022 and does not support NumPy 2.0 amongst other critical functionality.
Please upgrade to Gymnasium, the maintained drop-in replacement of Gym, or contact the authors of your software and request that they upgrade.
See the migration guide at https://gymnasium.farama.org/introduction/migration_guide/ for additional information.
  return datetime.utcnow().replace(tzinfo=utc)
  return datetime.utcnow().replace(tzinfo=utc)
  return datetime.utcnow().replace(tzinfo=utc)
  return datetime.utcnow().replace(tzinfo=utc)
  return datetime.utcnow().replace(tzinfo=utc)
  return datetime.utcnow().replace(tzinfo=utc)
  return datetime.utcnow().replace(tzinfo=utc)
  return datetime.utcnow().replace(tzinfo=utc)
  return datetime.utcnow().replace(tzinfo=utc)
  return datetime.utcnow().replace(tzinfo=utc)
  return datetime.utcnow().replace(tzinfo=utc)
  return datetime.utcnow().replace(tzinfo=utc)
  return datetime.utcnow().replace(tzinfo=utc)
  return datetime.utcnow()

Train: (2962, 49, 15), Val: (370, 49, 15)
2026-02-19 11:08:27 [info     ] ppo_agent_created              lr=0.0003 policy=MlpPolicy
2026-02-19 11:08:27 [info     ] rl_training_start              algorithm=ppo num_assets=49 num_steps=2962 total_timesteps=2000000
2026-02-19 11:08:27 [info     ] ppo_training_start             total_timesteps=2000000


  return datetime.utcnow().replace(tzinfo=utc)
  return datetime.utcnow().replace(tzinfo=utc)


2026-02-19 12:13:38 [info     ] ppo_training_complete
2026-02-19 12:13:47 [info     ] rl_train_eval                  mean_final_value=9686950.561765622 mean_reward=-83.49495207718385 std_final_value=0.0 std_reward=0.0
2026-02-19 12:13:48 [info     ] rl_val_eval                    mean_final_value=1220236.5429536235 mean_reward=-10.896220953021965 std_final_value=0.0 std_reward=0.0
2026-02-19 12:13:48 [info     ] ppo_saved                      path=outputs/models/rl/ppo/ppo_agent
2026-02-19 12:13:48 [info     ] rl_training_complete

PPO training done in 65.4 min
  Train mean_reward: -83.4950
  Train std_reward: 0.0000
  Train mean_final_value: 9686950.5618
  Train std_final_value: 0.0000
PPO agent saved to Drive!


In [15]:
print("=" * 60)
print("NOTEBOOK B COMPLETE")
print("=" * 60)
print(f"TFT final loss: {history['train_loss'][-1]:.6f}")
print(f"PPO metrics: {result['train_metrics']}")
print(f"\nModels on Drive:")
for d in ['outputs/models/tft', 'outputs/models/rl/ppo']:
    p = DRIVE_DIR / d
    if p.exists():
        for f in p.glob('*'):
            print(f"  {f.relative_to(DRIVE_DIR)}: {f.stat().st_size/1e6:.1f} MB")
print("=" * 60)

NOTEBOOK B COMPLETE
TFT final loss: -3.038316
PPO metrics: {'mean_reward': -83.49495207718385, 'std_reward': 0.0, 'mean_final_value': 9686950.561765622, 'std_final_value': 0.0}

Models on Drive:
  outputs/models/tft/last.pt: 427.5 MB
  outputs/models/tft/best.pt: 427.5 MB
  outputs/models/tft/final_model.pt: 142.5 MB
  outputs/models/rl/ppo/ppo_agent.zip: 1.4 MB


In [16]:
import shutil
from pathlib import Path
import os

# Fall back to the standard Drive location when this cell runs in a kernel where the
# setup cell (which defines DRIVE_DIR) has not been executed.
try:
    DRIVE_DIR
except NameError:
    DRIVE_DIR = Path('/content/drive/MyDrive/quant_lab')

local_outputs_root = Path('outputs')  # relative to the repo checkout (cwd set by the setup cell)
drive_outputs_root = DRIVE_DIR / 'outputs'

# Ensure the root outputs directory on Drive exists
drive_outputs_root.mkdir(parents=True, exist_ok=True)


def _sync_tree(src_root, dst_root):
    """Copy every file under ``src_root`` into ``dst_root``, preserving relative paths.

    A file is copied only if it is missing at the destination or the source has a newer
    modification time. ``shutil.copy2`` preserves mtimes, so re-running the cell skips
    unchanged files. Destination files that no longer exist locally are left in place.
    Returns the number of files copied.
    """
    copied = 0
    for src_dir, _dirs, files in os.walk(src_root):
        dst_dir = Path(dst_root) / Path(src_dir).relative_to(src_root)
        dst_dir.mkdir(parents=True, exist_ok=True)
        for file_name in files:
            src_file = Path(src_dir) / file_name
            dst_file = dst_dir / file_name
            if not dst_file.exists() or src_file.stat().st_mtime > dst_file.stat().st_mtime:
                shutil.copy2(src_file, dst_file)
                copied += 1
    return copied


print(f"Checking for outputs in '{local_outputs_root}' to save to Drive at '{drive_outputs_root}'...")

# Handle MLflow logs explicitly
local_mlruns_path = local_outputs_root / 'mlruns'
drive_mlruns_path = drive_outputs_root / 'mlruns'

if local_mlruns_path.exists():
    if not drive_mlruns_path.exists():
        print(f"  Copying new MLflow logs directory '{local_mlruns_path.name}' to Drive...")
        shutil.copytree(local_mlruns_path, drive_mlruns_path)
    else:
        print(f"  MLflow logs directory '{local_mlruns_path.name}' already exists on Drive. Merging content...")
        n_copied = _sync_tree(local_mlruns_path, drive_mlruns_path)
        print(f"  MLflow logs content from '{local_mlruns_path.name}' merged to Drive ({n_copied} new/updated files).")
else:
    print(f"  No local MLflow logs directory found at '{local_mlruns_path}'.")

# Handle any other direct files/directories in 'outputs/'. 'models' is skipped because the
# training cells copy their model files to Drive explicitly.
if local_outputs_root.exists():
    for item in sorted(local_outputs_root.iterdir()):
        if item.name in ('models', 'mlruns'):
            continue  # Already handled above or by the training cells

        drive_item_path = drive_outputs_root / item.name

        if item.is_dir():
            if not drive_item_path.exists():
                print(f"  Copying other directory '{item.name}' to Drive...")
                shutil.copytree(item, drive_item_path)
            else:
                n_copied = _sync_tree(item, drive_item_path)
                print(f"  Directory '{item.name}' already exists on Drive. Merged {n_copied} new/updated files.")
        elif item.is_file():
            print(f"  Copying other file '{item.name}' to Drive...")
            shutil.copy2(item, drive_item_path)

print("Finished ensuring all relevant outputs are saved to Drive.")


Checking for outputs in 'outputs' to save to Drive at '/content/drive/MyDrive/quant_lab/outputs'...
  No local MLflow logs directory found at 'outputs/mlruns'.
Finished ensuring all relevant outputs are saved to Drive.
