# MRMS Communication LSTM Training

This notebook trains the MRMS Communication LSTM, a layer that adds temporal context and uncertainty estimates to risk-management decisions.

In [None]:
# Setup and Imports
import sys
sys.path.append('..')

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import logging
import os
from datetime import datetime

from src.agents.mrms.communication import MRMSCommunicationLSTM
from src.agents.mrms.losses import MRMSCommunicationLoss  
from src.training.datasets.mrms_comm_dataset import (
    MRMSCommunicationDataset,
    create_synthetic_trade_history,
    create_dataloaders
)

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Seed every RNG the notebook draws from so "Restart & Run All" reproduces the
# same weight initialisation and any torch/numpy-based randomness downstream.
SEED = 42
torch.manual_seed(SEED)  # seeds the CPU generator and all CUDA devices
np.random.seed(SEED)

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"Using device: {device}")

In [None]:
# Load Historical Trade Data
# For development we use synthetic data; in production, point `data_path`
# at the real trade history.
data_path = '../data/processed/trade_history.parquet'
if os.path.exists(data_path):
    trade_history = pd.read_parquet(data_path)
    print("Loaded real trade history")
else:
    # Generate synthetic data for development
    print("Generating synthetic trade history...")
    trade_history = create_synthetic_trade_history(
        n_trades=5000,
        win_rate=0.45,
        avg_rr=2.0,
        seed=42
    )

# The splits below are chronological (oldest -> train, newest -> test), which is
# only leak-free if rows are in time order. A parquet file is not guaranteed to
# be sorted, so sort explicitly; mergesort is stable, so ties keep file order.
trade_history = (
    trade_history
    .sort_values('timestamp', kind='mergesort')
    .reset_index(drop=True)
)

# Display statistics
print(f"Total trades: {len(trade_history)}")
print(f"Win rate: {trade_history['hit_target'].mean():.2%}")
# Ratio of mean TP distance to mean SL distance (not the mean of per-trade RRs).
print(f"Average RR: {trade_history['tp_distance'].mean() / trade_history['sl_distance'].mean():.2f}")
print(f"Date range: {trade_history['timestamp'].min()} to {trade_history['timestamp'].max()}")

# Create chronological 70/15/15 train/val/test splits
n_total = len(trade_history)
n_train = int(0.7 * n_total)
n_val = int(0.15 * n_total)

# .iloc makes the slicing explicitly positional regardless of the index dtype.
train_df = trade_history.iloc[:n_train]
val_df = trade_history.iloc[n_train:n_train + n_val]
test_df = trade_history.iloc[n_train + n_val:]

# Invariants: the splits partition the data, and none is empty (an empty split
# would make the corresponding DataLoader empty further down).
assert len(train_df) + len(val_df) + len(test_df) == n_total
assert min(len(train_df), len(val_df), len(test_df)) > 0, "every split must contain trades"

print(f"\nTrain: {len(train_df)} trades")
print(f"Val: {len(val_df)} trades")
print(f"Test: {len(test_df)} trades")

In [None]:
# Create Datasets and Dataloaders

# Configuration. The same dict is handed to the model, the loss and the
# training loop below.
config = dict(
    # Model dimensions (presumably read by MRMSCommunicationLSTM from `config`)
    risk_vector_dim=4,
    outcome_dim=3,
    hidden_dim=16,
    output_dim=8,
    memory_size=20,
    # Windowing / batching for create_dataloaders
    sequence_length=10,
    batch_size=32,
    # Optimisation
    learning_rate=1e-4,
    n_epochs=100,
    # Loss-term weights (presumably read by MRMSCommunicationLoss)
    weight_risk=0.3,
    weight_outcome=0.3,
    weight_uncertainty=0.2,
    weight_temporal=0.2,
)

# Build one DataLoader per chronological split
train_loader, val_loader, test_loader = create_dataloaders(
    train_df, val_df, test_df,
    batch_size=config['batch_size'],
    sequence_length=config['sequence_length'],
)

for split_name, split_loader in (('Train', train_loader), ('Val', val_loader), ('Test', test_loader)):
    print(f"{split_name} batches: {len(split_loader)}")

In [None]:
# Initialize Model and Training Components

# Build the model on the selected device and report its size
model = MRMSCommunicationLSTM(config).to(device)
n_params = sum(param.numel() for param in model.parameters())
print(f"Model parameters: {n_params:,}")

# Composite loss configured from `config`, Adam optimiser, and a scheduler
# that halves the LR when the monitored (validation) loss plateaus for
# `patience` epochs.
loss_fn = MRMSCommunicationLoss(config)
optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'])
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=10,
)

# Per-epoch metric history: one list per (split, metric) pair, e.g.
# 'train_loss', 'val_loss', 'train_risk_loss', ...
history = {
    f'{split}_{metric}': []
    for metric in ('loss', 'risk_loss', 'uncertainty')
    for split in ('train', 'val')
}

In [None]:
# Training Functions

def train_epoch(model, loader, loss_fn, optimizer, device):
    """Run one training epoch and return the batch-averaged losses.

    Each batch is a dict of tensors with keys ``'risk_sequence'``,
    ``'outcome_sequence'``, ``'target_risk'`` and ``'target_outcome'``. The
    sequence tensors are indexed as ``[:, t, :]``, so time is dimension 1.
    The model is stepped through the sequence one timestep at a time with
    ``update_memory=True``. The loss is computed only on the final timestep's
    output; gradients flow back through the whole unrolled sequence.

    Args:
        model: Recurrent model exposing ``reset_hidden_state()`` and
            ``model(risk_t, outcome_t, update_memory=...) -> (mu, sigma)``.
        loader: Iterable of batches; ``len(loader)`` is the averaging divisor.
        loss_fn: Callable ``(mu, sigma, target_risk, target_outcome, prev_mu)``
            returning a dict of scalar tensors that includes ``'total'``.
        optimizer: Optimizer over ``model.parameters()``.
        device: Device the batch tensors are moved to.

    Returns:
        Dict with ``'risk'``, ``'outcome'``, ``'uncertainty'``, ``'temporal'``
        (plus any extra components ``loss_fn`` reports) and ``'total'``, each
        averaged over the batches. An empty loader yields all zeros instead of
        raising ZeroDivisionError.
    """
    model.train()
    total_loss = 0.0
    losses_dict = {'risk': 0.0, 'outcome': 0.0, 'uncertainty': 0.0, 'temporal': 0.0}

    for batch in tqdm(loader, desc='Training'):
        # Reset hidden state for each batch.
        # NOTE(review): whether this also clears the memory filled by
        # update_memory=True depends on MRMSCommunicationLSTM; confirm.
        model.reset_hidden_state()

        risk_seq = batch['risk_sequence'].to(device)
        outcome_seq = batch['outcome_sequence'].to(device)
        target_risk = batch['target_risk'].to(device)
        target_outcome = batch['target_outcome'].to(device)
        seq_len = risk_seq.size(1)
        if seq_len == 0:
            continue  # nothing to unroll, so no prediction to score

        optimizer.zero_grad()

        # Unroll the sequence. prev_mu ends up holding the (detached) embedding
        # of the second-to-last step, used by the loss's temporal term.
        prev_mu = None
        for t in range(seq_len):
            mu, sigma = model(
                risk_seq[:, t, :],
                outcome_seq[:, t, :],
                update_memory=True
            )
            if t < seq_len - 1:
                prev_mu = mu.detach()

        # Score only the final timestep, then take one clipped optimizer step.
        losses = loss_fn(mu, sigma, target_risk, target_outcome, prev_mu)
        losses['total'].backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Accumulate; tolerate loss components beyond the four pre-seeded keys.
        total_loss += losses['total'].item()
        for name, value in losses.items():
            if name != 'total':
                losses_dict[name] = losses_dict.get(name, 0.0) + value.item()

    # Average over batches (guarded so an empty loader returns zeros).
    n_batches = max(len(loader), 1)
    avg_losses = {k: v / n_batches for k, v in losses_dict.items()}
    avg_losses['total'] = total_loss / n_batches

    return avg_losses

def validate(model, loader, loss_fn, device):
    """Evaluate the model on ``loader`` without gradient updates.

    Mirrors ``train_epoch`` so the returned losses are directly comparable to
    the training losses. The validation total drives the LR scheduler and
    early stopping, so it must be the same objective. Concretely, the sequence
    is unrolled with ``update_memory=True`` exactly as in training, and the
    previous step's embedding is passed to ``loss_fn`` as ``prev_mu``. That
    way the temporal term is scored instead of being silently dropped.

    Args:
        model: Recurrent model exposing ``reset_hidden_state()`` and
            ``model(risk_t, outcome_t, update_memory=...) -> (mu, sigma)``.
        loader: Iterable of batch dicts (same keys as in ``train_epoch``).
        loss_fn: Callable ``(mu, sigma, target_risk, target_outcome, prev_mu)``
            returning a dict of scalar tensors that includes ``'total'``.
        device: Device the batch tensors are moved to.

    Returns:
        Dict of batch-averaged loss components plus ``'total'``. An empty
        loader yields all zeros instead of raising ZeroDivisionError.
    """
    model.eval()
    total_loss = 0.0
    losses_dict = {'risk': 0.0, 'outcome': 0.0, 'uncertainty': 0.0, 'temporal': 0.0}

    with torch.no_grad():
        for batch in tqdm(loader, desc='Validation'):
            # NOTE(review): whether this also clears the memory filled by
            # update_memory=True depends on MRMSCommunicationLSTM; confirm.
            model.reset_hidden_state()

            risk_seq = batch['risk_sequence'].to(device)
            outcome_seq = batch['outcome_sequence'].to(device)
            target_risk = batch['target_risk'].to(device)
            target_outcome = batch['target_outcome'].to(device)
            seq_len = risk_seq.size(1)
            if seq_len == 0:
                continue  # nothing to unroll, so no prediction to score

            # Unroll exactly as in training; keep the second-to-last embedding.
            prev_mu = None
            for t in range(seq_len):
                mu, sigma = model(
                    risk_seq[:, t, :],
                    outcome_seq[:, t, :],
                    update_memory=True
                )
                if t < seq_len - 1:
                    prev_mu = mu

            losses = loss_fn(mu, sigma, target_risk, target_outcome, prev_mu)
            total_loss += losses['total'].item()
            for name, value in losses.items():
                if name != 'total':
                    losses_dict[name] = losses_dict.get(name, 0.0) + value.item()

    # Average over batches (guarded so an empty loader returns zeros).
    n_batches = max(len(loader), 1)
    avg_losses = {k: v / n_batches for k, v in losses_dict.items()}
    avg_losses['total'] = total_loss / n_batches

    return avg_losses

In [None]:
# Training Loop

best_val_loss = float('inf')
best_state = None  # in-memory copy of the weights with the lowest val loss so far
patience_counter = 0
patience = 20  # epochs without val improvement before stopping early

# Create models directory if it doesn't exist
os.makedirs('../models', exist_ok=True)

for epoch in range(config['n_epochs']):
    print(f"\nEpoch {epoch + 1}/{config['n_epochs']}")

    # Train
    train_losses = train_epoch(model, train_loader, loss_fn, optimizer, device)

    # Validate
    val_losses = validate(model, val_loader, loss_fn, device)

    # Record the LR actually used this epoch *before* the scheduler may lower it.
    current_lr = optimizer.param_groups[0]['lr']

    # Update scheduler
    scheduler.step(val_losses['total'])

    # Track history (setdefault so this works even if `history` lacks an 'lr' list)
    history['train_loss'].append(train_losses['total'])
    history['val_loss'].append(val_losses['total'])
    history['train_risk_loss'].append(train_losses['risk'])
    history['val_risk_loss'].append(val_losses['risk'])
    history['train_uncertainty'].append(train_losses['uncertainty'])
    history['val_uncertainty'].append(val_losses['uncertainty'])
    history.setdefault('lr', []).append(current_lr)

    # Print progress
    print(f"Train Loss: {train_losses['total']:.4f} | Val Loss: {val_losses['total']:.4f}")
    print(f"Risk Loss: Train {train_losses['risk']:.4f} | Val {val_losses['risk']:.4f}")
    print(f"Uncertainty: Train {train_losses['uncertainty']:.4f} | Val {val_losses['uncertainty']:.4f}")

    # Early stopping
    if val_losses['total'] < best_val_loss:
        best_val_loss = val_losses['total']
        patience_counter = 0

        # Keep a detached copy in memory so the best weights can be restored below
        best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}

        # Save best model
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_losses['total'],
            'config': config
        }, '../models/mrms_communication_best.pth')

        print("✅ Saved best model")
    else:
        patience_counter += 1

    if patience_counter >= patience:
        print(f"Early stopping triggered at epoch {epoch + 1}")
        break

# Training may run up to `patience` epochs past the best epoch. Restore the best
# weights so the test-set evaluation and production export use them rather than
# the last epoch's.
if best_state is not None:
    model.load_state_dict(best_state)
    print(f"Restored best model (val loss {best_val_loss:.4f})")

In [None]:
# Evaluation and Visualization

# Plot training history: one panel per tracked loss, plus the learning rate.
fig, axes = plt.subplots(2, 2, figsize=(12, 8))

loss_panels = [
    (axes[0, 0], 'loss', 'Total Loss'),
    (axes[0, 1], 'risk_loss', 'Risk Prediction Loss'),
    (axes[1, 0], 'uncertainty', 'Uncertainty Calibration Loss'),
]
for ax, metric, title in loss_panels:
    ax.plot(history[f'train_{metric}'], label='Train')
    ax.plot(history[f'val_{metric}'], label='Val')
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.legend()

# Learning rate: use the per-epoch schedule if the training loop recorded one.
# Otherwise only the current LR is known, so say so in the title rather than
# drawing a flat line that looks like a real schedule.
if history.get('lr'):
    lrs, lr_title = history['lr'], 'Learning Rate'
else:
    lrs = [optimizer.param_groups[0]['lr']] * len(history['train_loss'])
    lr_title = 'Learning Rate (final value only)'
lr_ax = axes[1, 1]
lr_ax.plot(lrs)
lr_ax.set_title(lr_title)
lr_ax.set_xlabel('Epoch')
lr_ax.set_yscale('log')

fig.tight_layout()
os.makedirs('../results', exist_ok=True)
fig.savefig('../results/mrms_communication_training.png')
plt.show()

# Test set evaluation (uses whatever weights `model` currently holds)
test_losses = validate(model, test_loader, loss_fn, device)
print("\nTest Set Performance:")
print(f"Total Loss: {test_losses['total']:.4f}")
print(f"Risk Prediction: {test_losses['risk']:.4f}")
print(f"Outcome Prediction: {test_losses['outcome']:.4f}")
print(f"Uncertainty Calibration: {test_losses['uncertainty']:.4f}")

In [None]:
# Save Production Model

# Save final production model. This is a bare state_dict: rebuild the module
# with MRMSCommunicationLSTM(config) before loading it.
production_path = '../models/m_rms_model_comm.pth'
torch.save(model.state_dict(), production_path)
print(f"\n✅ Production model saved to {production_path}")

# Integration test: one inference step from a clean hidden state.
print("\nIntegration Test:")
model.eval()
model.reset_hidden_state()  # don't carry over state left by the test-set evaluation
test_risk = torch.tensor([[2.0, 1.5, 2.5, 0.85]], device=device)
test_outcome = torch.tensor([[0.0, 1.0, 0.03]], device=device)

with torch.no_grad():  # smoke test only; no autograd graph needed
    mu, sigma = model(test_risk, test_outcome)

print(f"Input risk: {test_risk}")
print(f"Risk embedding: {mu}")
print(f"Uncertainty: {sigma}")
print(f"Mean uncertainty: {sigma.mean().item():.4f}")

In [None]:
# Analyze Model Behavior

# Test adaptation to losing streak: feed the same losing trade repeatedly and
# track how the model's mean uncertainty evolves.
print("\nTesting adaptation to losing streak:")
model.eval()
# NOTE(review): confirm whether this also clears the memory that
# update_memory=True fills; if not, earlier evaluation can influence the streak.
model.reset_hidden_state()

# Loop-invariant inputs: the same trade and outcome every step.
losing_risk = torch.tensor([[3.0, 1.5, 3.0, 0.9]], device=device)
losing_outcome = torch.tensor([[1.0, 0.0, -0.03]], device=device)  # Stop hit

n_trades = 10
uncertainties = []
with torch.no_grad():  # analysis only; avoid growing an autograd graph across steps
    for i in range(n_trades):
        mu, sigma = model(losing_risk, losing_outcome, update_memory=True)
        uncertainties.append(sigma.mean().item())

        if i % 2 == 0:
            print(f"Trade {i+1}: Uncertainty = {uncertainties[-1]:.4f}")

# x-axis uses 1-based trade numbers to match the printed "Trade N" labels.
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(range(1, n_trades + 1), uncertainties, marker='o')
ax.set_title('Uncertainty Evolution During Losing Streak')
ax.set_xlabel('Trade Number')
ax.set_ylabel('Mean Uncertainty')
ax.grid(True)
plt.show()

# Report the observed direction instead of assuming uncertainty went up.
start_u, end_u = uncertainties[0], uncertainties[-1]
if end_u > start_u:
    direction = "increased"
elif end_u < start_u:
    direction = "decreased"
else:
    direction = "was unchanged"
print(f"\nUncertainty {direction} from {start_u:.4f} to {end_u:.4f}")