# Notebook F: Full Indian Market DL Training + Walk-Forward
**Run on Colab Pro+ H100** | Trains a Transformer and a TFT on all NSE EQ-series stocks (1,500+ after liquidity filters), then runs walk-forward validation on the Transformer

**Prerequisites:** Upload `indian_market_features.parquet` to Google Drive at `My Drive/quant_lab/data/features/`

In [None]:
# === ENVIRONMENT SETUP ===
# Fresh-clone quant-lab into /content/quant-lab, install it in editable mode,
# mount Google Drive, create the Drive folders later cells write to, and
# report the available GPU.
import os
import shutil
import subprocess
import sys

os.chdir('/content')

if os.path.exists('/content/quant-lab'):
    print('Removing existing quant-lab directory...')
    # shutil.rmtree raises if the delete fails, so a stale checkout cannot
    # silently survive and make the clone below fail with a misleading error.
    shutil.rmtree('/content/quant-lab')

print('Cloning repository...')
result = subprocess.run(
    ['git', 'clone', 'https://github.com/Mohit1053/quant-lab.git', '/content/quant-lab'],
    capture_output=True, text=True,
)
if result.returncode != 0:
    print(f'Clone failed: {result.stderr}')
    # Carry git's stderr in the exception so it survives in tracebacks/logs.
    raise RuntimeError(f'Git clone failed: {result.stderr.strip()}')
print('Clone successful.')

# Editable install of the freshly cloned checkout.
os.chdir('/content/quant-lab')
subprocess.run([sys.executable, '-m', 'pip', 'install', '-q', '-e', '.'], check=True)
print('Package installed.')

from google.colab import drive
drive.mount('/content/drive', force_remount=False)

from pathlib import Path
DRIVE_DIR = Path('/content/drive/MyDrive/quant_lab')
# Drive folders that the load / training / walk-forward cells read or write.
for d in ['data/features', 'data/cleaned',
          'outputs/models/transformer_fullmkt', 'outputs/models/tft_fullmkt',
          'outputs/walk_forward/transformer_fullmkt']:
    (DRIVE_DIR / d).mkdir(parents=True, exist_ok=True)

import torch
if torch.cuda.is_available():
    gpu = torch.cuda.get_device_name(0)
    mem = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f'GPU: {gpu} ({mem:.1f} GB) | BF16: {torch.cuda.is_bf16_supported()}')
else:
    print('WARNING: No GPU!')

In [None]:
# === FIX NUMPY/SCIPY VERSIONS ===
# Uninstall the currently installed numpy/pandas/scipy/scikit-learn and
# reinstall fixed pinned versions. These lines are IPython shell escapes (`!`),
# so this cell only runs inside a notebook kernel, not as a plain Python script.
# NOTE(review): the setup cell has already run `import torch` in this kernel,
# which may have loaded the previous numpy into memory; the pins likely only
# take effect after restarting the runtime (and re-running the setup cell).
# Confirm.
!pip uninstall -y numpy pandas scipy scikit-learn
!pip install --no-cache-dir numpy==1.26.4 pandas==2.2.2 scipy==1.11.4 scikit-learn==1.4.2

## Data Loading — Important Note

**NSE blocks Colab datacenter IPs**, so data must be pre-computed locally.

**Steps:**
1. Run `python scripts/run_indian_market_pipeline.py` on your local machine
2. Upload `indian_market_features.parquet` (500+ MB) to Google Drive at `My Drive/quant_lab/data/features/`
3. The cell below auto-detects it on Drive; if it is missing, the cell falls back to a direct browser upload

In [None]:
# === LOAD FULL MARKET DATA ===
# Bring the precomputed feature parquet into the local checkout (Drive first,
# manual browser upload as a fallback) and read it into `df`.
import shutil
import pandas as pd

drive_feature_dir = DRIVE_DIR / 'data/features'
drive_features = drive_feature_dir / 'indian_market_features.parquet'

local_feature_dir = Path('data/features')
local_feature_dir.mkdir(parents=True, exist_ok=True)
local_features = local_feature_dir / 'indian_market_features.parquet'

# Preferred source: a copy already stored on Drive.
have_features = drive_features.exists()
if have_features:
    shutil.copy(drive_features, local_features)
    drive_mb = drive_features.stat().st_size / 1e6
    print(f'Full market data loaded from Drive! ({drive_mb:.0f} MB)')
else:
    # Fallback: explain the Drive route, then offer an in-browser upload.
    upload_help = (
        'No data on Drive. Upload indian_market_features.parquet:',
        '  1. Go to drive.google.com',
        '  2. Navigate to My Drive/quant_lab/data/features/',
        '  3. Upload indian_market_features.parquet',
        '  4. Re-run this cell',
        '',
    )
    for line in upload_help:
        print(line)
    try:
        from google.colab import files
        uploaded = files.upload()  # maps each uploaded filename to its bytes
        chosen = next(
            (name for name in uploaded
             if 'indian_market' in name and name.endswith('.parquet')),
            None,
        )
        if chosen is not None:
            local_features.write_bytes(uploaded[chosen])
            # Persist to Drive so later sessions can skip the upload.
            drive_feature_dir.mkdir(parents=True, exist_ok=True)
            shutil.copy(local_features, drive_features)
            print(f'Uploaded {chosen} — saved to Drive!')
            have_features = True
    except Exception as exc:
        print(f'Upload failed: {exc}')

if not have_features:
    raise FileNotFoundError(
        'Could not load full market data.\n'
        'Run scripts/run_indian_market_pipeline.py locally, then upload\n'
        f'indian_market_features.parquet to {drive_features}'
    )

df = pd.read_parquet(local_features)
n_tickers = df["ticker"].nunique()
n_days = df["date"].nunique()
print(f'\nFull Indian Market: {df.shape[0]:,} rows, {n_tickers} tickers, {n_days} trading days')

## Train Transformer on Full Market

In [None]:
# === TRAIN TRANSFORMER ON FULL MARKET ===
# Train the large Transformer on the full-market `df` from the load cell and
# copy its *.pt checkpoints to Drive. Uses globals from earlier cells
# (df, Path, shutil, DRIVE_DIR). Later cells reuse dm, train_loader,
# val_loader, feature_cols, device, time, np and torch from this cell.
import torch
import numpy as np
import time
from quant_lab.utils.seed import set_global_seed
from quant_lab.utils.device import get_device
from quant_lab.data.datasets import TemporalSplit
from quant_lab.data.datamodule import QuantDataModule, DataModuleConfig
from quant_lab.models.transformer.model import TransformerForecaster, TransformerConfig, MultiTaskLoss
from quant_lab.training.trainer import Trainer, TrainerConfig

# Seed before any data/model construction so this cell is reproducible.
set_global_seed(42)
device = get_device()

# Every column other than the identifiers and raw OHLCV/price fields is a model input.
# NOTE(review): this keeps the target column 'log_return_1d' inside feature_cols;
# that is only leak-free if the datamodule predicts a future-shifted target.
# TODO confirm.
base_cols = {'date', 'ticker', 'open', 'high', 'low', 'close', 'volume', 'adj_close'}
feature_cols = [c for c in df.columns if c not in base_cols]
# Date cut-offs for the train and validation periods. Rows after val_end
# presumably form the test period — TODO confirm TemporalSplit semantics.
split = TemporalSplit(train_end='2021-12-31', val_end='2023-06-30')

# Input windows of sequence_length=63 rows; target column is log_return_1d.
dm = QuantDataModule(
    df, feature_cols, split,
    DataModuleConfig(sequence_length=63, target_col='log_return_1d', batch_size=256, num_workers=2),
)
dm.setup()
train_loader = dm.train_dataloader()
val_loader = dm.val_dataloader()
print(f'Train batches: {len(train_loader)}, Val batches: {len(val_loader)}, Features: {dm.num_features}')

# Larger model for bigger dataset
model_cfg = TransformerConfig(
    num_features=dm.num_features, d_model=256, nhead=8,
    num_encoder_layers=6, dim_feedforward=1024, dropout=0.1,
    direction_weight=0.3, volatility_weight=0.3,
)
model = TransformerForecaster(model_cfg)
# The same config supplies the auxiliary task weights to the loss.
loss_fn = MultiTaskLoss(model_cfg)
print(f'Transformer params: {model.count_parameters():,}')

# checkpoint_dir is relative to the current directory (/content/quant-lab after setup).
trainer_config = TrainerConfig(
    epochs=100, learning_rate=1e-4, weight_decay=1e-5,
    warmup_steps=3000, patience=10, mixed_precision=True,
    checkpoint_dir='outputs/models/transformer_fullmkt',
)
trainer = Trainer(model, loss_fn, trainer_config, device)

start = time.time()
trainer.fit(train_loader, val_loader)
elapsed = time.time() - start
print(f'\nTransformer training done in {elapsed/60:.1f} min')

# Save to Drive
# Copies only *.pt files directly inside the checkpoint dir (the glob is not
# recursive). The message below is printed even if no files matched.
for f in Path('outputs/models/transformer_fullmkt').glob('*.pt'):
    shutil.copy(f, DRIVE_DIR / 'outputs/models/transformer_fullmkt' / f.name)
print('Transformer saved to Drive!')

## Train TFT-small on Full Market

In [None]:
# === TRAIN TFT ON FULL MARKET (SMALLER ARCHITECTURE) ===
# Train a compact TFT on the same data loaders as the Transformer cell
# (dm, train_loader, val_loader, device, time, set_global_seed, TransformerConfig,
# MultiTaskLoss, Trainer, TrainerConfig all come from there). Then check that
# its outputs vary across inputs and copy its *.pt checkpoints to Drive.
from quant_lab.models.tft.model import TFTForecaster, TFTConfig

# Re-seed so the TFT's initialization is reproducible on its own.
set_global_seed(42)

# Small widths with dropout 0.3 (plus weight_decay=1e-3 below): a much more
# heavily regularized model than the Transformer.
tft_cfg = TFTConfig(
    num_features=dm.num_features, d_model=32, nhead=4,
    num_encoder_layers=1, lstm_layers=1, lstm_hidden=32,
    grn_hidden=16, dropout=0.3,
    direction_weight=0.3, volatility_weight=0.3,
)
model = TFTForecaster(tft_cfg)
# MultiTaskLoss is constructed from a TransformerConfig. This one exists only to
# carry the same direction/volatility weights as tft_cfg; its other fields are
# presumably unused by the loss — TODO confirm.
loss_cfg = TransformerConfig(num_features=dm.num_features, direction_weight=0.3, volatility_weight=0.3)
loss_fn = MultiTaskLoss(loss_cfg)
print(f'TFT params: {sum(p.numel() for p in model.parameters()):,}')

trainer_config = TrainerConfig(
    epochs=100, learning_rate=3e-4, weight_decay=1e-3,
    warmup_steps=500, patience=15, mixed_precision=True,
    checkpoint_dir='outputs/models/tft_fullmkt',
)
trainer = Trainer(model, loss_fn, trainer_config, device)

start = time.time()
trainer.fit(train_loader, val_loader)
elapsed = time.time() - start
print(f'\nTFT training done in {elapsed/60:.1f} min')

# Verify no mode collapse
# Feed a (20, 63, num_features) batch of random inputs. A prediction std near
# zero would mean the model returns (almost) the same value for every input.
# NOTE(review): Gaussian noise is off-distribution; a batch drawn from
# val_loader would be a more representative check.
model.eval()
x = torch.randn(20, 63, dm.num_features).to(device)
with torch.no_grad():
    preds = model.predict_returns(x)
print(f'Signal std: {preds.std():.6f} (should be >> 0)')

# Save to Drive
# Copies only *.pt files directly inside the checkpoint dir (non-recursive glob).
for f in Path('outputs/models/tft_fullmkt').glob('*.pt'):
    shutil.copy(f, DRIVE_DIR / 'outputs/models/tft_fullmkt' / f.name)
print('TFT saved to Drive!')

## Walk-Forward Validation (Transformer)

In [None]:
# === WALK-FORWARD VALIDATION (TRANSFORMER) ===
# Configuration objects for the expanding-window walk-forward run below.
from quant_lab.backtest.walk_forward import (
    WalkForwardEngine,
    WalkForwardConfig,
    WindowType,
)
from quant_lab.backtest.engine import BacktestConfig
from quant_lab.data.datasets import create_flat_datasets

# Price panel handed to the backtester: the date/ticker/adj_close columns of df.
price_cols = ['date', 'ticker', 'adj_close']
prices_df = df[price_cols].copy()

# Window lengths are presumably in trading days: 756 ≈ 3 years,
# 504 ≈ 2 years, 126 ≈ 6 months (assuming ~252 sessions per year).
wf_config = WalkForwardConfig(
    window_type=WindowType.EXPANDING,
    train_days=756,
    val_days=126,
    test_days=126,
    step_days=126,
    min_train_days=504,
)

# top_n=20 for the broader universe (previously 5 for NIFTY 50, 10 for NIFTY 500).
backtest_cfg = BacktestConfig(
    initial_capital=1_000_000,
    rebalance_frequency=5,
    top_n=20,
)

def make_transformer_factory():
    """Build the per-fold model factory consumed by ``WalkForwardEngine.run``.

    Returns:
        A callable ``factory(split, feature_df, feat_cols)``. For each fold it
        trains a fresh Transformer on that fold's train/val loaders and returns
        ``(model, signals)``. ``signals`` is a copy of the fold's test metadata
        with a ``'signal'`` column of predicted returns.

        Degenerate folds:
        - No training loader: returns ``(None, pd.DataFrame())``.
        - No test loader, or one that yields no predictions: returns
          ``(model, pd.DataFrame())``.

    Raises:
        ValueError: if the fold's test metadata has fewer rows than there are
            test predictions, so the two cannot be aligned.

    Relies on notebook globals defined in earlier cells: ``device``,
    ``QuantDataModule``, ``DataModuleConfig``, ``TransformerConfig``,
    ``TransformerForecaster``, ``MultiTaskLoss``, ``Trainer``,
    ``TrainerConfig``, ``create_flat_datasets``.
    """
    def factory(split, feature_df, feat_cols):
        dm = QuantDataModule(
            feature_df, feat_cols, split,
            DataModuleConfig(sequence_length=63, target_col='log_return_1d', batch_size=256, num_workers=2),
        )
        dm.setup()
        tl = dm.train_dataloader()
        vl = dm.val_dataloader()
        if tl is None:
            return None, pd.DataFrame()

        # Same architecture as the full-market Transformer cell; shorter schedule per fold.
        cfg = TransformerConfig(
            num_features=dm.num_features, d_model=256, nhead=8,
            num_encoder_layers=6, dim_feedforward=1024, dropout=0.1,
            direction_weight=0.3, volatility_weight=0.3,
        )
        m = TransformerForecaster(cfg)
        loss_fn = MultiTaskLoss(cfg)
        # NOTE(review): every fold uses the same checkpoint_dir, so each fold's
        # checkpoints overwrite the previous fold's.
        tc = TrainerConfig(
            epochs=30, learning_rate=1e-4, weight_decay=1e-5,
            warmup_steps=1000, patience=5, mixed_precision=True,
            checkpoint_dir='outputs/walk_forward/transformer_fullmkt',
        )
        t = Trainer(m, loss_fn, tc, device)
        t.fit(tl, vl)

        test_loader = dm.test_dataloader()
        if test_loader is None:
            return m, pd.DataFrame()

        m.eval()
        batch_preds = []
        with torch.no_grad():
            for x, _ in test_loader:
                batch_preds.append(m.predict_returns(x.to(device)).cpu().numpy())
        # np.concatenate raises on an empty list, and `iloc[-0:]` below would
        # select *every* metadata row. Treat "no predictions" as an empty fold.
        test_preds = np.concatenate(batch_preds) if batch_preds else np.empty(0)
        if len(test_preds) == 0:
            return m, pd.DataFrame()

        datasets = create_flat_datasets(feature_df, feat_cols, split, target_col='log_return_1d')
        _, _, meta_test = datasets['test']
        if len(meta_test) < len(test_preds):
            raise ValueError(
                f'Test metadata has {len(meta_test)} rows but the model produced '
                f'{len(test_preds)} predictions; cannot align signals.'
            )
        # NOTE(review): tail-slicing assumes the sequence test loader emits samples
        # in the same row order as the flat metadata, with every dropped warm-up
        # row at the head. If warm-up rows are dropped per ticker, signals will be
        # misaligned. TODO confirm against the dataset implementation.
        signals = meta_test.iloc[-len(test_preds):].copy()
        signals['signal'] = test_preds
        return m, signals
    return factory

# Run the walk-forward, print aggregate and per-fold metrics, and persist them to Drive.
print('Starting full-market Transformer walk-forward (this takes 3-6 hours)...')
start = time.time()
wf_result = WalkForwardEngine(wf_config, backtest_cfg).run(
    df, feature_cols, prices_df, make_transformer_factory()
)
elapsed = time.time() - start

n_folds = len(wf_result.fold_results)
print(f'\nWalk-forward done in {elapsed/60:.1f} min ({n_folds} folds)')
print('\nAggregate metrics:')
# Metrics whose name mentions one of these tags are shown as percentages.
pct_tags = ('return', 'cagr', 'drawdown')
for name, value in wf_result.aggregate_metrics.items():
    spec = '>10.2%' if any(tag in name for tag in pct_tags) else '>10.4f'
    print(f'  {name:25s}: {value:{spec}}')

print('\nPer-fold:')
per_fold = wf_result.per_fold_metrics
shown = [
    col
    for col in ('fold', 'test_start', 'test_end', 'sharpe', 'total_return', 'max_drawdown')
    if col in per_fold.columns
]
print(per_fold[shown].to_string(index=False))

# Save to Drive
wf_out = DRIVE_DIR / 'outputs/walk_forward/transformer_fullmkt'
wf_out.mkdir(parents=True, exist_ok=True)
per_fold.to_csv(wf_out / 'per_fold_metrics.csv', index=False)
equity_frame = wf_result.aggregate_equity.to_frame('equity')
equity_frame.to_parquet(wf_out / 'aggregate_equity.parquet')
print('\nResults saved to Drive!')

## Summary

In [None]:
# === SUMMARY ===
# Print the universe size, every file under the Drive output folders, and the
# headline walk-forward metrics. Relies on df, DRIVE_DIR and wf_result from
# earlier cells.
print('=' * 60)
print('NOTEBOOK F COMPLETE — FULL INDIAN MARKET')
print('=' * 60)
print(f'\nUniverse: {df["ticker"].nunique()} stocks (all NSE EQ-series after liquidity filter)')
print('\nAll outputs on Drive:')
for d in ['outputs/models/transformer_fullmkt', 'outputs/models/tft_fullmkt', 'outputs/walk_forward/transformer_fullmkt']:
    p = DRIVE_DIR / d
    if p.exists():
        for f in sorted(p.glob('*')):
            if f.is_file():
                print(f'  {f.relative_to(DRIVE_DIR)}: {f.stat().st_size/1e6:.1f} MB')

# A missing metric must print "N/A". Applying :.4f / :.2% to a string fallback
# would raise ValueError, so only format real values numerically.
agg_metrics = wf_result.aggregate_metrics
sharpe = agg_metrics.get('sharpe')
cagr = agg_metrics.get('cagr')
print(f'\nWalk-forward Sharpe: {sharpe:.4f}' if sharpe is not None else '\nWalk-forward Sharpe: N/A')
print(f'Walk-forward CAGR: {cagr:.2%}' if cagr is not None else 'Walk-forward CAGR: N/A')
print('=' * 60)