# Dual Branch Model Training
Training notebook for the DualBranchModel (Summary + VPIN LSTM)

In [None]:
# Attach Google Drive so the data files under /content/drive are reachable
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torch.nn.utils.rnn as rnn_utils
import pandas as pd
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# Import CTAFlow modules
from CTAFlow.models.deep_learning.multi_branch.dual_model import DualBranchModel
from CTAFlow.models.intraday_momentum import DeepIDMomentum
from CTAFlow.data.model_datasets import DualDataset

# Pick the GPU when CUDA is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

In [None]:
# ---- Paths ----
# Folder on the mounted Drive that holds every input and output of this notebook
DATA_DIR = Path('/content/drive/MyDrive/cl')

# Input files - rename these to match what is actually stored in DATA_DIR
INTRADAY_PATH = DATA_DIR.joinpath('cl_intraday.csv')    # intraday OHLCV data (required for DeepIDMomentum)
CL_FEATURES_PATH = DATA_DIR.joinpath('cl_features.csv')  # summary features
VPIN_PATH = DATA_DIR.joinpath('CL_vpin.parquet')         # sequential data (VPIN)
TARGET_PATH = DATA_DIR.joinpath('targets.csv')           # optional; set to None if targets live in cl_features

# ---- Model / training hyperparameters ----
MAX_SEQ_LEN = 200      # passed to DualDataset as max_len
LSTM_HIDDEN = 64       # lstm_hidden_dim of DualBranchModel
DENSE_HIDDEN = 128     # dense_hidden_dim of DualBranchModel
BATCH_SIZE = 32
LEARNING_RATE = 1e-3
EPOCHS = 50
VAL_SPLIT = 0.2        # fraction of samples, taken from the end, used for validation

# ---- Normalization ----
# Toggle standardization of the sequential columns before the dataset is built
NORMALIZE_SEQUENTIAL = True
# Columns taken from the sequential data
SEQUENTIAL_COLS = ['vpin', 'bucket_return', 'log_duration']

## Data Loading & Normalization

We use `DeepIDMomentum` to load the data, standardize the sequential (VPIN) columns with scikit-learn's `StandardScaler`, and then create a `DualDataset` for training.

In [None]:
# Load the raw inputs through DeepIDMomentum, then standardize the sequential columns
print("Loading data with DeepIDMomentum...")

# A target file is only handed to the loader when TARGET_PATH is set and the
# file exists on disk; otherwise target_path is None.
deep_model = DeepIDMomentum.from_files(
    intraday_path=str(INTRADAY_PATH),
    features_path=str(CL_FEATURES_PATH),
    sequential_path=str(VPIN_PATH),
    target_path=str(TARGET_PATH) if TARGET_PATH and TARGET_PATH.exists() else None,
    target_col='Target'
)

# Normalize sequential features if enabled
if NORMALIZE_SEQUENTIAL:
    print("Normalizing sequential features...")
    from sklearn.preprocessing import StandardScaler

    # Restrict to the requested columns that are really present in the data
    available_cols = [col for col in SEQUENTIAL_COLS if col in deep_model.sequential_data.columns]
    missing_cols = [col for col in SEQUENTIAL_COLS if col not in available_cols]
    if missing_cols:
        # Surface the mismatch instead of silently scaling only a subset
        print(f"Warning: sequential columns not found and left out of normalization: {missing_cols}")

    # The VPIN columns are standardized here with a plain StandardScaler
    # (zero mean, unit variance per column).
    # NOTE(review): the scaler is fit on the whole sequential table, i.e. before
    # the time-based train/validation split, so validation-period statistics
    # influence the scaling. Consider fitting on the training period only.
    if available_cols:
        # Guard: StandardScaler raises when given a frame with zero columns
        scaler = StandardScaler()
        deep_model.sequential_data[available_cols] = scaler.fit_transform(
            deep_model.sequential_data[available_cols]
        )
    deep_model.training_data['sequential'] = deep_model.sequential_data

    print(f"Normalized columns: {available_cols}")

print("DeepIDMomentum data loaded and normalized successfully")

In [None]:
# Build the DualDataset from the frames held by the DeepIDMomentum instance
print("Creating DualDataset...")

# Pull the inputs (sequential part possibly standardized) out of deep_model
summary_data = deep_model.training_data['summary']
sequential_data = deep_model.training_data['sequential']
target_data = deep_model.target_data

# Target/date column names are forwarded only when the summary frame has them
dataset_kwargs = {
    'summary_data': summary_data,
    'sequential_data': sequential_data,
    'target_data': target_data,
    'max_len': MAX_SEQ_LEN,
    'sequential_cols': SEQUENTIAL_COLS,
    'target_col': 'Target' if 'Target' in summary_data.columns else None,
    'date_col': 'Datetime' if 'Datetime' in summary_data.columns else None,
}
dataset = DualDataset(**dataset_kwargs)

# Quick summary of what the dataset exposes
print(f"Summary features: {len(dataset.feature_cols)}")
print(f"Sequential features: {dataset.n_sequential_features}")
print(f"Total samples: {len(dataset)}")
print(f"Feature columns: {dataset.feature_cols}")

## Collate Function

Custom collate function for DualDataset to handle variable-length sequences.

In [None]:
def collate_fn(batch):
    """
    Merge a list of DualDataset samples into one padded batch.

    Each sample is a 4-tuple ``(summary, sequence, target, length)``. Summaries
    and targets are stacked along a new leading dimension, lengths become a
    tensor, and the sequences are zero-padded (batch first) up to the longest
    sequence in this batch.

    Returns:
        Tuple ``(summaries, padded_sequences, targets, lengths)``.
    """
    # Transpose the list of samples into four per-field tuples
    summary_items, seq_items, target_items, seq_lengths = zip(*batch)

    # Right-pad every sequence with zeros to the batch maximum
    padded = rnn_utils.pad_sequence(seq_items, batch_first=True, padding_value=0.0)

    return (
        torch.stack(summary_items),
        padded,
        torch.stack(target_items),
        torch.tensor(seq_lengths),
    )

## Train/Validation Split

In [None]:
# Chronological hold-out: the last VAL_SPLIT fraction of samples is validation
n_samples = len(dataset)
n_val = int(n_samples * VAL_SPLIT)
n_train = n_samples - n_val

# No shuffling of indices: earlier samples train, later samples validate
train_indices = [*range(n_train)]
val_indices = [*range(n_train, n_samples)]

train_dataset = torch.utils.data.Subset(dataset, train_indices)
val_dataset = torch.utils.data.Subset(dataset, val_indices)

for split_name, split in (("Train", train_dataset), ("Val", val_dataset)):
    print(f"{split_name} samples: {len(split)}")

## Create DataLoaders

In [None]:
# Both loaders share batch size, collate function and worker count;
# only the training loader shuffles its samples.
shared_loader_args = dict(
    batch_size=BATCH_SIZE,
    collate_fn=collate_fn,
    num_workers=2,
)

train_loader = DataLoader(train_dataset, shuffle=True, **shared_loader_args)
val_loader = DataLoader(val_dataset, shuffle=False, **shared_loader_args)

In [None]:
# Sanity check: draw one training batch and show the shape of each component
first_batch = next(iter(train_loader))
summary, sequential, target, lengths = first_batch

print(f"Summary shape: {summary.shape}")
print(f"Sequential shape: {sequential.shape}")
print(f"Target shape: {target.shape}")
print(f"Lengths (first 5): {lengths[:5]}")

## Training

In [None]:
# Build the DualBranchModel (imported from CTAFlow) with input sizes taken
# from the dataset, then move it to the selected device
n_summary_inputs = len(dataset.feature_cols)
n_sequential_inputs = dataset.n_sequential_features

model = DualBranchModel(
    summary_input_dim=n_summary_inputs,
    vpin_input_dim=n_sequential_inputs,
    lstm_hidden_dim=LSTM_HIDDEN,
    dense_hidden_dim=DENSE_HIDDEN
)
model = model.to(device)

print(model)

total_params = sum(p.numel() for p in model.parameters())
print(f"\nTotal parameters: {total_params:,}")

# NOTE(review): the layer sizes printed below are written out by hand here,
# not read from the model - confirm they match DualBranchModel's definition.
print("\nModel Architecture:")
print(f"  Summary Branch: {n_summary_inputs} -> {DENSE_HIDDEN} -> {DENSE_HIDDEN//2}")
print(f"  Sequential Branch: {n_sequential_inputs} -> LSTM({LSTM_HIDDEN})")
print(f"  Fusion Head: {DENSE_HIDDEN//2 + LSTM_HIDDEN} -> 64 -> 1")

In [None]:
# Loss, optimizer and learning-rate schedule
criterion = nn.MSELoss()  # mean squared error, averaged over the batch
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)

# Halve the learning rate after 5 epochs without improvement of the monitored
# value (the training loop passes the validation loss to scheduler.step).
# The `verbose` argument is intentionally not passed: it is deprecated in
# PyTorch and rejected by recent releases. To watch the learning rate, read
# optimizer.param_groups[0]['lr'] after scheduler.step(...).
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

In [None]:
def train_epoch(model, loader, criterion, optimizer, device):
    """
    Run one optimization pass over ``loader`` and return the mean loss.

    Args:
        model: Module called as ``model(summary, sequential, lengths)``.
        loader: Iterable of ``(summary, sequential, target, lengths)`` batches.
        criterion: Loss callable taking ``(output, target)``. The returned
            average assumes it reports a per-batch mean (as ``nn.MSELoss()``
            does by default).
        optimizer: Optimizer stepping the model's parameters.
        device: Device the batch tensors are moved to.

    Returns:
        Loss averaged over all samples seen, so a short final batch is
        weighted by its size rather than counted like a full batch.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.train()
    running_loss = 0.0
    n_seen = 0

    for summary, sequential, target, lengths in loader:
        summary = summary.to(device)
        sequential = sequential.to(device)
        target = target.to(device)
        # Give 1-D targets a trailing feature axis; leave targets that already
        # have one untouched so the loss does not silently broadcast.
        if target.dim() == 1:
            target = target.unsqueeze(1)
        # NOTE(review): lengths are moved to `device`; if the model packs
        # sequences with pack_padded_sequence it needs CPU lengths - confirm
        # that DualBranchModel handles this.
        lengths = lengths.to(device)

        optimizer.zero_grad()
        output = model(summary, sequential, lengths)
        loss = criterion(output, target)
        loss.backward()

        # Clip the global gradient norm to 1.0 before the update
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

        batch_size = target.size(0)
        running_loss += loss.item() * batch_size
        n_seen += batch_size

    if n_seen == 0:
        raise ValueError("train_epoch received an empty loader; there is nothing to train on")

    return running_loss / n_seen


def validate(model, loader, criterion, device):
    """
    Evaluate ``model`` on ``loader`` without updating any parameters.

    Args:
        model: Module called as ``model(summary, sequential, lengths)``.
        loader: Iterable of ``(summary, sequential, target, lengths)`` batches.
        criterion: Loss callable taking ``(output, target)``. The returned
            average assumes it reports a per-batch mean (as ``nn.MSELoss()``
            does by default).
        device: Device the batch tensors are moved to.

    Returns:
        Tuple ``(mean_loss, predictions, targets)`` where ``mean_loss`` is
        averaged over all samples and the other two are flat NumPy arrays in
        loader order.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    running_loss = 0.0
    n_seen = 0
    pred_chunks = []
    target_chunks = []

    with torch.no_grad():
        for summary, sequential, target, lengths in loader:
            summary = summary.to(device)
            sequential = sequential.to(device)
            target = target.to(device)
            # Give 1-D targets a trailing feature axis; leave targets that
            # already have one untouched so the loss does not broadcast.
            if target.dim() == 1:
                target = target.unsqueeze(1)
            lengths = lengths.to(device)

            output = model(summary, sequential, lengths)
            loss = criterion(output, target)

            batch_size = target.size(0)
            running_loss += loss.item() * batch_size
            n_seen += batch_size

            # Keep one flat array per batch and join once at the end
            pred_chunks.append(output.cpu().numpy().ravel())
            target_chunks.append(target.cpu().numpy().ravel())

    if n_seen == 0:
        raise ValueError("validate received an empty loader; there is nothing to evaluate")

    return running_loss / n_seen, np.concatenate(pred_chunks), np.concatenate(target_chunks)

In [None]:
# Main loop: train, validate, adapt the learning rate, checkpoint the best epoch
train_losses, val_losses = [], []
best_val_loss = float('inf')

for epoch in tqdm(range(EPOCHS), desc='Training'):
    train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_preds, val_targets = validate(model, val_loader, criterion, device)

    train_losses.append(train_loss)
    val_losses.append(val_loss)

    # The plateau scheduler monitors the validation loss
    scheduler.step(val_loss)

    # Persist the weights whenever validation loss reaches a new minimum
    improved = val_loss < best_val_loss
    if improved:
        best_val_loss = val_loss
        torch.save(model.state_dict(), DATA_DIR / 'best_model.pt')

    # Report progress every fifth epoch only
    epoch_number = epoch + 1
    if epoch_number % 5 != 0:
        continue
    corr = np.corrcoef(val_preds, val_targets)[0, 1]
    print(f"Epoch {epoch_number}/{EPOCHS} - Train Loss: {train_loss:.6f}, Val Loss: {val_loss:.6f}, Corr: {corr:.4f}")

## Evaluation

In [None]:
# Plot the loss curves and a predicted-vs-actual scatter for the best checkpoint
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Left panel: per-epoch training and validation loss
axes[0].plot(train_losses, label='Train')
axes[0].plot(val_losses, label='Val')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training Curves')
axes[0].legend()

# Restore the weights with the lowest validation loss and re-score the
# validation set. map_location lets a checkpoint written from a GPU session
# load on whatever device is active now (e.g. a CPU-only runtime).
best_state = torch.load(DATA_DIR / 'best_model.pt', map_location=device)
model.load_state_dict(best_state)
_, val_preds, val_targets = validate(model, val_loader, criterion, device)

# Right panel: predictions against targets, with the y = x line for reference
target_lo, target_hi = val_targets.min(), val_targets.max()
axes[1].scatter(val_targets, val_preds, alpha=0.5, s=10)
axes[1].plot([target_lo, target_hi], [target_lo, target_hi], 'r--')
axes[1].set_xlabel('Actual')
axes[1].set_ylabel('Predicted')
axes[1].set_title(f'Predictions (Corr: {np.corrcoef(val_preds, val_targets)[0,1]:.4f})')

plt.tight_layout()
plt.savefig(DATA_DIR / 'training_results.png', dpi=150)
plt.show()

In [None]:
# Summary statistics on the validation predictions currently in memory
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

print("=== Final Validation Metrics ===")

# MSE is computed once and reused for the RMSE line
mse = mean_squared_error(val_targets, val_preds)
print(f"MSE: {mse:.6f}")
print(f"RMSE: {np.sqrt(mse):.6f}")

mae = mean_absolute_error(val_targets, val_preds)
print(f"MAE: {mae:.6f}")

r2 = r2_score(val_targets, val_preds)
print(f"R2: {r2:.6f}")

# Pearson correlation between predictions and targets
correlation = np.corrcoef(val_preds, val_targets)[0, 1]
print(f"Correlation: {correlation:.6f}")

In [None]:
# Save a full checkpoint: weights, optimizer state, loss history and the
# constructor arguments needed to rebuild the model.
# NOTE: `model` holds the best-validation weights only if the evaluation cell
# that reloads best_model.pt has been run; the optimizer state is whatever it
# was after the last training epoch.
checkpoint_path = DATA_DIR / 'dual_branch_checkpoint.pt'

checkpoint = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'train_losses': train_losses,
    'val_losses': val_losses,
    'config': {
        'summary_input_dim': len(dataset.feature_cols),
        # Same attribute that was used to construct the model
        'vpin_input_dim': dataset.n_sequential_features,
        'lstm_hidden_dim': LSTM_HIDDEN,
        'dense_hidden_dim': DENSE_HIDDEN,
    }
}
torch.save(checkpoint, checkpoint_path)

print(f"Model saved to {checkpoint_path}")