# AlgoSpace M-RMS Training Notebook

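This notebook trains the Multi-Agent Risk Management Subsystem (M-RMS) using a divide-and-conquer strategy. Risk management is split across three specialized sub-agents: position sizing, stop-loss placement, and profit-target placement. An ensemble model combines their outputs.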

## Overview
- Environment Setup & Data Loading
- Custom Gymnasium Environment Implementation
- M-RMS Agent Architecture (3 Sub-Agents + Ensemble)
- MAPPO Training with RLlib
- Results Analysis & Model Saving

## Task 1: Notebook Setup & Data Loading

In [None]:
# Make the project folder on Google Drive reachable from this Colab runtime;
# every relative path used later (e.g. 'data/...') resolves from that folder.
import os
from google.colab import drive

drive.mount('/content/drive')
os.chdir('/content/drive/MyDrive/AlgoSpace')

In [None]:
# Install required libraries
!pip install -r requirements.txt -q

# Core imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import gymnasium as gym
from gymnasium import spaces
import ray
from ray import tune
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models import ModelCatalog
from ray.rllib.utils.annotations import override
from typing import Dict, List, Tuple, Optional, Any
from dataclasses import dataclass
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime, timedelta
import warnings
warnings.filterwarnings('ignore')

# Select the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")
if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Load the bar-level training dataset; later cells address rows positionally via .iloc.
print("Loading historical data...")
historical_data = pd.read_parquet('data/training_data_main.parquet')
print(
    f"Loaded {len(historical_data)} rows of historical data\n"
    f"Date range: {historical_data.index.min()} to {historical_data.index.max()}\n"
    f"Columns: {historical_data.columns[:10].tolist()}..."  # preview only the first 10 column names
)

In [None]:
@dataclass
class SynergyEvent:
    """Represents a detected synergy event in historical data."""
    timestamp: pd.Timestamp  # index label of the bar where the synergy fired
    index: int  # positional (iloc) row number of that bar
    direction: str  # 'LONG' or 'SHORT'
    strength: float  # heuristic signal strength, clipped to [0, 1]
    indicators: Dict[str, float]  # snapshot produced by SynergyFinder._extract_indicators
    regime: str  # 'HIGH_VOLATILITY', 'TRENDING_UP', 'TRENDING_DOWN' or 'RANGING'


class SynergyFinder:
    """Scans historical data to identify all valid synergy events.

    Detection is rule-based and looks at one bar at a time: a bar is a
    bullish (or bearish) synergy when at least 3 of 4 indicator conditions
    hold. Missing indicator columns fall back to the defaults passed to
    ``row.get``.
    """

    def __init__(self, data: pd.DataFrame, config: Dict[str, Any]):
        self.data = data
        # NOTE(review): config is stored but none of the detection rules below
        # read it (e.g. 'min_synergy_strength' is not applied) — confirm intent.
        self.config = config
        self.synergy_events: List[SynergyEvent] = []

    def find_synergies(self, buffer_bars: int = 100) -> List[SynergyEvent]:
        """Identify all synergy events in the historical data.

        Args:
            buffer_bars: Number of bars skipped at both ends of the data; the
                trailing buffer leaves room to simulate trades after an event.

        Returns:
            The detected events in chronological (positional) order. The list
            is rebuilt on every call, so repeated calls do not accumulate
            duplicates; it is also stored on ``self.synergy_events``.
        """
        print("Scanning for synergy events...")

        # Fresh list (not .clear()) so a list returned by a previous call is
        # left untouched.
        self.synergy_events = []
        for i in range(buffer_bars, len(self.data) - buffer_bars):
            row = self.data.iloc[i]

            # Bullish rules take precedence when both would match.
            if self._check_bullish_synergy(row, i):
                direction = 'LONG'
            elif self._check_bearish_synergy(row, i):
                direction = 'SHORT'
            else:
                continue
            self.synergy_events.append(self._make_event(row, i, direction))

        print(f"Found {len(self.synergy_events)} synergy events")
        return self.synergy_events

    def _make_event(self, row: pd.Series, idx: int, direction: str) -> SynergyEvent:
        """Build the SynergyEvent for bar ``idx`` in the given direction."""
        return SynergyEvent(
            timestamp=self.data.index[idx],
            index=idx,
            direction=direction,
            strength=self._calculate_synergy_strength(row, direction),
            indicators=self._extract_indicators(row),
            regime=self._determine_regime(row),
        )

    def _check_bullish_synergy(self, row: pd.Series, idx: int) -> bool:
        """Check if current row represents a bullish synergy (>= 3 of 4 conditions)."""
        conditions = [
            row.get('rsi_14', 50) < 30,  # Oversold RSI
            row.get('macd_signal', 0) > 0,  # MACD bullish cross
            row.get('close', 0) > row.get('sma_200', 0),  # Above long-term MA
            row.get('volume', 0) > row.get('volume_sma_20', 0) * 1.5  # Volume spike
        ]
        return sum(conditions) >= 3

    def _check_bearish_synergy(self, row: pd.Series, idx: int) -> bool:
        """Check if current row represents a bearish synergy (>= 3 of 4 conditions)."""
        conditions = [
            row.get('rsi_14', 50) > 70,  # Overbought RSI
            row.get('macd_signal', 0) < 0,  # MACD bearish cross
            row.get('close', 0) < row.get('sma_200', 0),  # Below long-term MA
            row.get('volume', 0) > row.get('volume_sma_20', 0) * 1.5  # Volume spike
        ]
        return sum(conditions) >= 3

    def _calculate_synergy_strength(self, row: pd.Series, direction: str) -> float:
        """Return a 0-1 strength: 0.5 plus how far RSI sits beyond its 30/70 threshold."""
        strength = 0.5  # Base strength

        if direction == 'LONG':
            strength += (30 - row.get('rsi_14', 50)) / 100  # Stronger if more oversold
        else:
            strength += (row.get('rsi_14', 50) - 70) / 100  # Stronger if more overbought

        return float(np.clip(strength, 0.0, 1.0))

    @staticmethod
    def _safe_ratio(numerator: float, denominator: float, default: float) -> float:
        """Return ``numerator / denominator``, or ``default`` when it is undefined.

        Undefined means a zero or non-finite denominator, or a non-finite
        numerator (e.g. NaN from a missing column).
        """
        if denominator == 0 or not np.isfinite(denominator) or not np.isfinite(numerator):
            return default
        return float(numerator / denominator)

    def _extract_indicators(self, row: pd.Series) -> Dict[str, float]:
        """Extract relevant indicators for the synergy event.

        Ratios that cannot be computed fall back to neutral values: a
        zero-volume average or missing column gives ``volume_ratio`` 1.0, and
        a collapsed or missing Bollinger band gives ``bb_position`` 0.5 (the
        band midpoint).
        """
        bb_lower = row.get('bb_lower', np.nan)
        bb_upper = row.get('bb_upper', np.nan)
        return {
            'rsi': row.get('rsi_14', 0),
            'macd': row.get('macd', 0),
            'macd_signal': row.get('macd_signal', 0),
            'atr': row.get('atr_14', 0),
            'volume_ratio': self._safe_ratio(
                row.get('volume', 0), row.get('volume_sma_20', np.nan), 1.0
            ),
            'bb_position': self._safe_ratio(
                row.get('close', 0) - bb_lower, bb_upper - bb_lower, 0.5
            ),
        }

    def _determine_regime(self, row: pd.Series) -> str:
        """Classify the market regime at the bar (volatility check takes precedence)."""
        if row.get('atr_14', 0) > row.get('atr_14_sma_50', 0) * 1.5:
            return 'HIGH_VOLATILITY'
        elif row.get('close', 0) > row.get('sma_50', 0) > row.get('sma_200', 0):
            return 'TRENDING_UP'
        elif row.get('close', 0) < row.get('sma_50', 0) < row.get('sma_200', 0):
            return 'TRENDING_DOWN'
        else:
            return 'RANGING'

In [None]:
# Detect synergy events over the full history, then summarise them by
# direction and by market regime.
# NOTE(review): SynergyFinder stores this config, but its detection rules do
# not read these keys — confirm whether they are meant to take effect.
synergy_config = {
    'min_synergy_strength': 0.6,
    'lookback_periods': 20,
}
synergy_finder = SynergyFinder(historical_data, synergy_config)
synergy_events = synergy_finder.find_synergies()

print(f"\nSynergy Event Distribution:")
print(f"Total events: {len(synergy_events)}")
print(f"Long signals: {sum(e.direction == 'LONG' for e in synergy_events)}")
print(f"Short signals: {sum(e.direction == 'SHORT' for e in synergy_events)}")

# Tally events per regime; regimes appear in the order they are first seen.
regime_counts = {}
for event in synergy_events:
    if event.regime in regime_counts:
        regime_counts[event.regime] += 1
    else:
        regime_counts[event.regime] = 1
print(f"\nRegime distribution: {regime_counts}")

## Task 2: Custom Gymnasium Environment for M-RMS Training

In [None]:
@dataclass
class TradeResult:
    """Results from simulating a trade (see TradeSimulator.simulate_trade)."""
    entry_price: float
    exit_price: float
    exit_reason: str  # 'STOP_LOSS', 'TAKE_PROFIT', 'TIME_EXIT'
    pnl: float  # dollar P&L: pnl_points * contracts * point value
    pnl_points: float  # price move in the trade's favour (negative on a loss)
    duration_bars: int  # bars between entry and exit
    max_favorable_excursion: float  # largest move in the trade's favour, in price points
    max_adverse_excursion: float  # largest move against the trade, in price points
    rule_violations: List[str]  # account-rule codes from VirtualAccount.update_from_trade


class VirtualAccount:
    """Simulates a trading account with TopStep-like rules.

    Tracks the running balance, realised P&L per calendar day and the
    trailing drawdown from the peak balance, and reports which account
    rules each booked trade breaks.
    """

    def __init__(self, initial_balance: float = 50000.0, config: Optional[Dict[str, Any]] = None):
        self.initial_balance = initial_balance
        self.balance = initial_balance
        # User-supplied keys override the defaults; missing keys fall back to
        # them, so a partial config cannot cause a KeyError in update_from_trade.
        self.config = {**self._default_config(), **(config or {})}

        # Track metrics
        self.trades: List[TradeResult] = []
        self.daily_pnl: Dict[str, float] = {}  # ISO date string -> realised P&L that day
        self.peak_balance = initial_balance
        self.current_drawdown = 0.0
        self.max_drawdown = 0.0

    def _default_config(self) -> Dict[str, Any]:
        """Default TopStep-like rules."""
        return {
            'max_daily_loss': 1000.0,  # $1000
            'max_drawdown': 2000.0,    # $2000
            'max_position_size': 5,     # contracts
            'profit_target': 3000.0,    # $3000
            'min_trading_days': 10,
            'point_value': 5.0          # MES point value
        }

    def reset(self):
        """Reset account to initial state (config is kept)."""
        self.balance = self.initial_balance
        self.trades.clear()
        self.daily_pnl.clear()
        self.peak_balance = self.initial_balance
        self.current_drawdown = 0.0
        self.max_drawdown = 0.0

    def update_from_trade(self, trade_result: TradeResult, trade_date: pd.Timestamp) -> List[str]:
        """Book a trade and return the account rules it violated.

        Args:
            trade_result: Simulated trade; only its ``pnl`` is read here.
            trade_date: Timestamp used to bucket the P&L by calendar day.

        Returns:
            Zero or more of 'MAX_DAILY_LOSS_EXCEEDED' and 'MAX_DRAWDOWN_EXCEEDED'.
        """
        violations = []

        self.balance += trade_result.pnl
        self.trades.append(trade_result)

        date_str = trade_date.date().isoformat()
        self.daily_pnl[date_str] = self.daily_pnl.get(date_str, 0) + trade_result.pnl

        if self.daily_pnl[date_str] < -self.config['max_daily_loss']:
            violations.append('MAX_DAILY_LOSS_EXCEEDED')

        # Drawdown trails the highest balance seen so far.
        self.peak_balance = max(self.peak_balance, self.balance)
        self.current_drawdown = self.peak_balance - self.balance
        self.max_drawdown = max(self.max_drawdown, self.current_drawdown)

        if self.current_drawdown > self.config['max_drawdown']:
            violations.append('MAX_DRAWDOWN_EXCEEDED')

        return violations

    def calculate_metrics(self) -> Dict[str, float]:
        """Calculate account performance metrics over all booked trades.

        Returns:
            Dict with 'total_pnl', 'win_rate', 'profit_factor' (inf when there
            are no losing trades), 'sortino_ratio' and 'max_drawdown'.
        """
        if not self.trades:
            return {
                'total_pnl': 0.0,
                'win_rate': 0.0,
                'profit_factor': 0.0,
                'sortino_ratio': 0.0,
                'max_drawdown': 0.0
            }

        total_pnl = sum(t.pnl for t in self.trades)
        winning_trades = [t for t in self.trades if t.pnl > 0]
        losing_trades = [t for t in self.trades if t.pnl < 0]

        win_rate = len(winning_trades) / len(self.trades)

        gross_profit = sum(t.pnl for t in winning_trades)
        gross_loss = abs(sum(t.pnl for t in losing_trades))
        profit_factor = gross_profit / gross_loss if gross_loss > 0 else float('inf')

        # Per-trade returns relative to the starting balance.
        returns = np.array([t.pnl / self.initial_balance for t in self.trades], dtype=float)
        if len(returns) > 1:
            avg_return = returns.mean()
            # Downside deviation (target 0): RMS of the below-zero part of ALL
            # returns. The std of only the losing returns would be 0 whenever
            # there is a single loss or all losses are equal, zeroing Sortino.
            downside_dev = float(np.sqrt(np.mean(np.minimum(returns, 0.0) ** 2)))
            if downside_dev == 0.0:
                downside_dev = 0.001  # no losing trades: small floor avoids division by zero
            # NOTE(review): annualises with 252 as if each trade were one trading day.
            sortino_ratio = float((avg_return * 252) / (downside_dev * np.sqrt(252)))
        else:
            sortino_ratio = 0.0

        return {
            'total_pnl': total_pnl,
            'win_rate': win_rate,
            'profit_factor': profit_factor,
            'sortino_ratio': sortino_ratio,
            'max_drawdown': self.max_drawdown
        }

In [None]:
class TradeSimulator:
    """Simulates trade execution using historical data.

    A trade is entered at the close of the entry bar and walked forward bar
    by bar until the stop-loss or take-profit level is touched, or the
    holding horizon runs out (time exit at the last scanned bar's close).
    """

    def __init__(self, historical_data: pd.DataFrame, config: Dict[str, Any]):
        self.data = historical_data
        # Reads 'point_value' (default 5.0) and the optional
        # 'max_holding_bars' (default 99) from this dict.
        self.config = config

    def simulate_trade(self,
                       entry_idx: int,
                       direction: str,
                       stop_loss: float,
                       take_profit: float,
                       position_size: int) -> TradeResult:
        """Simulate a trade from entry to exit.

        Args:
            entry_idx: Positional (iloc) row of the entry bar; entry fills at its close.
            direction: 'LONG'; any other value is treated as a short.
            stop_loss: Absolute stop price.
            take_profit: Absolute target price.
            position_size: Number of contracts.

        Returns:
            TradeResult with ``rule_violations`` left empty for the caller to fill.

        Notes:
            Exits fill exactly at the stop/target level (no slippage or gap
            modelling). When a single bar touches both levels, the stop is
            assumed to have been hit first (conservative). If there are no
            bars after the entry, the trade is closed flat at the entry price.
        """
        entry_price = self.data.iloc[entry_idx]['close']
        is_long = direction == 'LONG'

        # Scan at most this many bars after entry. The default of 99 matches
        # the original range(entry_idx + 1, entry_idx + 100).
        max_bars = self.config.get('max_holding_bars', 99)
        last_idx = min(entry_idx + max_bars, len(self.data) - 1)

        max_favorable = 0.0
        max_adverse = 0.0
        exit_idx = entry_idx  # stays here if there are no bars to scan

        for i in range(entry_idx + 1, last_idx + 1):
            bar = self.data.iloc[i]
            exit_idx = i
            high, low = bar['high'], bar['low']

            if is_long:
                max_favorable = max(max_favorable, high - entry_price)
                max_adverse = max(max_adverse, entry_price - low)
                stop_hit = low <= stop_loss
                target_hit = high >= take_profit
            else:
                max_favorable = max(max_favorable, entry_price - low)
                max_adverse = max(max_adverse, high - entry_price)
                stop_hit = high >= stop_loss
                target_hit = low <= take_profit

            if stop_hit:
                exit_price, exit_reason = stop_loss, 'STOP_LOSS'
                break
            if target_hit:
                exit_price, exit_reason = take_profit, 'TAKE_PROFIT'
                break
        else:
            # Neither level touched within the horizon (or no bars at all).
            exit_price = self.data.iloc[exit_idx]['close']
            exit_reason = 'TIME_EXIT'

        pnl_points = exit_price - entry_price if is_long else entry_price - exit_price
        pnl = pnl_points * position_size * self.config.get('point_value', 5.0)

        return TradeResult(
            entry_price=entry_price,
            exit_price=exit_price,
            exit_reason=exit_reason,
            pnl=pnl,
            pnl_points=pnl_points,
            duration_bars=exit_idx - entry_idx,
            max_favorable_excursion=max_favorable,
            max_adverse_excursion=max_adverse,
            rule_violations=[]
        )

In [None]:
class RiskManagementEnv(gym.Env):
    """Custom Gymnasium environment for M-RMS training.

    Each step presents one synergy event. The action picks a position size
    (0-5 contracts), a stop distance in ATR multiples and a reward:risk
    ratio. The trade is replayed on historical bars by ``TradeSimulator`` and
    booked into a ``VirtualAccount``. An episode starts at a randomly chosen
    synergy and walks forward through the event list. It ends when the list
    is exhausted, the per-episode trade cap is reached, or an account rule
    is violated.
    """

    def __init__(self,
                 historical_data: pd.DataFrame,
                 synergy_events: List[SynergyEvent],
                 config: Dict[str, Any]):
        super().__init__()

        self.historical_data = historical_data
        self.synergy_events = synergy_events
        self.config = config

        # Account rules come from config['account_config']; the simulator gets
        # the whole env config (it reads e.g. 'point_value' from it).
        self.virtual_account = VirtualAccount(config=config.get('account_config', {}))
        self.trade_simulator = TradeSimulator(historical_data, config)

        # Episode tracking
        self.current_synergy_idx = 0
        self.episode_trades = 0
        self.max_trades_per_episode = config.get('max_trades_per_episode', 20)

        # Define action space
        self.action_space = spaces.Dict({
            'position_size': spaces.Discrete(6),  # 0-5 contracts
            'sl_atr_multiplier': spaces.Box(low=0.5, high=3.0, shape=(1,), dtype=np.float32),
            'rr_ratio': spaces.Box(low=1.0, high=5.0, shape=(1,), dtype=np.float32)
        })

        # Define observation space
        self.observation_space = spaces.Dict({
            'synergy_vector': spaces.Box(low=-np.inf, high=np.inf, shape=(30,), dtype=np.float32),
            'account_state_vector': spaces.Box(low=-np.inf, high=np.inf, shape=(10,), dtype=np.float32)
        })

    def reset(self, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[Dict, Dict]:
        """Start a new episode at a randomly chosen synergy event.

        Raises:
            ValueError: If the environment holds no synergy events.
        """
        super().reset(seed=seed)
        if not self.synergy_events:
            raise ValueError("RiskManagementEnv.reset() requires at least one synergy event")

        self.virtual_account.reset()
        self.episode_trades = 0

        # Use the env's own generator (seeded by super().reset) instead of
        # reseeding NumPy's global RNG. Any event, including the last one,
        # can start an episode.
        self.current_synergy_idx = int(self.np_random.integers(0, len(self.synergy_events)))

        observation = self._build_observation()
        info = {'synergy_event': self.synergy_events[self.current_synergy_idx]}
        return observation, info

    def step(self, action: Dict) -> Tuple[Dict, float, bool, bool, Dict]:
        """Trade (or skip) the current synergy event and advance to the next one."""
        position_size = int(action['position_size'])
        sl_atr_multiplier = float(action['sl_atr_multiplier'][0])
        rr_ratio = float(action['rr_ratio'][0])

        synergy = self.synergy_events[self.current_synergy_idx]
        entry_bar = self.historical_data.iloc[synergy.index]
        atr = entry_bar['atr_14']
        entry_price = entry_bar['close']

        # Stop is sl_atr_multiplier ATRs away; target is rr_ratio times that distance.
        if synergy.direction == 'LONG':
            stop_loss = entry_price - (sl_atr_multiplier * atr)
            take_profit = entry_price + (sl_atr_multiplier * atr * rr_ratio)
        else:
            stop_loss = entry_price + (sl_atr_multiplier * atr)
            take_profit = entry_price - (sl_atr_multiplier * atr * rr_ratio)

        # Defined on both paths: the termination check and the info dict read
        # it even when no trade is taken.
        violations: List[str] = []
        trade_result: Optional[TradeResult] = None

        if position_size > 0:
            trade_result = self.trade_simulator.simulate_trade(
                entry_idx=synergy.index,
                direction=synergy.direction,
                stop_loss=stop_loss,
                take_profit=take_profit,
                position_size=position_size
            )
            violations = self.virtual_account.update_from_trade(trade_result, synergy.timestamp)
            trade_result.rule_violations = violations

            reward = self._calculate_reward(trade_result, position_size)
            self.episode_trades += 1
        else:
            reward = -0.1  # Small penalty for not trading

        self.current_synergy_idx += 1

        # Terminate once every event has been presented, at the trade cap,
        # or as soon as an account rule is broken.
        done = (self.current_synergy_idx >= len(self.synergy_events) or
                self.episode_trades >= self.max_trades_per_episode or
                len(violations) > 0)
        truncated = False

        if not done:
            observation = self._build_observation()
        else:
            # No further event to describe: return an all-zero observation.
            observation = {
                'synergy_vector': np.zeros(30, dtype=np.float32),
                'account_state_vector': np.zeros(10, dtype=np.float32)
            }

        info = {
            'trade_result': trade_result,
            'account_metrics': self.virtual_account.calculate_metrics(),
            'violations': violations
        }

        return observation, reward, done, truncated, info

    def _build_observation(self) -> Dict[str, np.ndarray]:
        """Build the Dict observation for the current synergy event."""
        synergy = self.synergy_events[self.current_synergy_idx]
        return {
            'synergy_vector': self._build_synergy_vector(synergy).astype(np.float32),
            'account_state_vector': self._build_account_vector().astype(np.float32)
        }

    def _build_synergy_vector(self, synergy: SynergyEvent) -> np.ndarray:
        """Build the 30-wide market feature vector for a synergy event.

        Uses columns sma_20/50/200, rsi_14, macd, macd_signal, atr_14,
        bb_width, volume and volume_sma_20 (KeyError if one is missing).
        Unused slots are zero; NaN/inf (e.g. from a zero ATR or a warm-up
        NaN indicator) is replaced by 0 so it cannot reach the policy network.
        """
        row = self.historical_data.iloc[synergy.index]

        features = [
            # Price action features
            row['close'] / row['sma_20'] - 1,
            row['close'] / row['sma_50'] - 1,
            row['close'] / row['sma_200'] - 1,

            # Technical indicators
            row['rsi_14'] / 100,
            row['macd'] / row['atr_14'],
            row['macd_signal'] / row['atr_14'],

            # Volatility
            row['atr_14'] / row['close'],
            row['bb_width'] / row['close'],

            # Volume
            row['volume'] / row['volume_sma_20'],

            # Market structure
            synergy.strength,
            1.0 if synergy.direction == 'LONG' else -1.0,

            # Regime encoding (one-hot)
            1.0 if synergy.regime == 'TRENDING_UP' else 0.0,
            1.0 if synergy.regime == 'TRENDING_DOWN' else 0.0,
            1.0 if synergy.regime == 'RANGING' else 0.0,
            1.0 if synergy.regime == 'HIGH_VOLATILITY' else 0.0,
        ]

        vector = np.zeros(30, dtype=np.float64)
        vector[:min(len(features), 30)] = features[:30]
        return np.nan_to_num(vector, nan=0.0, posinf=0.0, neginf=0.0)

    def _build_account_vector(self) -> np.ndarray:
        """Build the 10-wide account-state feature vector."""
        metrics = self.virtual_account.calculate_metrics()
        account = self.virtual_account

        features = [
            # Account balance relative to start
            account.balance / account.initial_balance,

            # Drawdown as a fraction of the allowed maximum
            account.current_drawdown / account.config['max_drawdown'],
            account.max_drawdown / account.config['max_drawdown'],

            # Performance metrics (clipped so inf profit factor stays bounded)
            metrics['win_rate'],
            np.clip(metrics['profit_factor'], 0, 10) / 10,
            np.clip(metrics['sortino_ratio'], -3, 3) / 3,

            # Trading activity
            len(account.trades) / 100,
            self.episode_trades / self.max_trades_per_episode,

            # Recent win rates
            self._get_recent_performance(5),
            self._get_recent_performance(10),
        ]

        return np.array(features[:10])

    def _get_recent_performance(self, n: int) -> float:
        """Win rate over the last ``n`` trades, or 0.5 if fewer than ``n`` exist."""
        if len(self.virtual_account.trades) < n:
            return 0.5  # Neutral if not enough trades

        recent_trades = self.virtual_account.trades[-n:]
        wins = sum(1 for t in recent_trades if t.pnl > 0)
        return wins / n

    def _calculate_reward(self, trade_result: Optional[TradeResult], position_size: int) -> float:
        """Combine P&L, change in Sortino ratio and a sizing penalty into a reward.

        Expects the trade to have already been booked in the virtual account
        (the Sortino "after" value includes it). Rule violations cost 10.0;
        an adverse excursion more than twice the favourable one costs 0.5.
        """
        if trade_result is None:
            return -0.1  # Small penalty for not trading

        # Base reward from P&L (normalized by $1000)
        pnl_reward = trade_result.pnl / 1000

        # Change in Sortino ratio caused by this trade
        sortino_before = self._calculate_sortino_excluding_last().get('sortino_ratio', 0)
        sortino_after = self.virtual_account.calculate_metrics().get('sortino_ratio', 0)
        sortino_impact = (sortino_after - sortino_before) * 2  # Scale impact

        # NOTE(review): the "optimal" size is derived from stats that already
        # include this trade's outcome (look-ahead) — confirm this is intended.
        optimal_size = self._calculate_optimal_position_size(trade_result)
        size_penalty = -abs(position_size - optimal_size) * 0.1

        reward = pnl_reward + sortino_impact + size_penalty

        if trade_result.rule_violations:
            reward -= 10.0  # Large penalty for rule violations

        if trade_result.max_adverse_excursion > trade_result.max_favorable_excursion * 2:
            reward -= 0.5  # Poor risk management

        return float(reward)

    def _calculate_sortino_excluding_last(self) -> Dict[str, float]:
        """Account metrics as they were before the most recent trade."""
        trades = self.virtual_account.trades
        if len(trades) <= 1:
            return {'sortino_ratio': 0.0}

        # Temporarily remove the last trade; finally guarantees it is restored
        # even if the metric computation raises.
        last_trade = trades.pop()
        try:
            return self.virtual_account.calculate_metrics()
        finally:
            trades.append(last_trade)

    def _calculate_optimal_position_size(self, trade_result: TradeResult) -> int:
        """Kelly-inspired contract count in [1, 5] from the account's trade history.

        Returns 1 when there are no winning trades and 2 when there are no
        losing trades (Kelly is undefined there).
        """
        trades = self.virtual_account.trades
        wins = [t.pnl for t in trades if t.pnl > 0]
        losses = [-t.pnl for t in trades if t.pnl < 0]

        if not wins:
            return 1
        if not losses:
            return 2

        win_rate = len(wins) / len(trades)
        avg_win = np.mean(wins)
        avg_loss = np.mean(losses)

        # Kelly fraction p - (1 - p) / b with b = avg_win / avg_loss.
        kelly_fraction = (win_rate * avg_win - (1 - win_rate) * avg_loss) / avg_win
        return int(np.clip(kelly_fraction * 10, 1, 5))  # Scale to 1-5 contracts

## Task 3: M-RMS Agent Architecture

In [None]:
class PositionSizingAgent(nn.Module):
    """Sub-agent that scores each allowed contract count (0 through 5).

    A three-layer MLP with ReLU activations and 0.2 dropout after each hidden
    layer. It outputs one unnormalised logit per position size.
    """

    def __init__(self, input_dim: int = 40, hidden_dim: int = 128):
        super().__init__()
        half_width = hidden_dim // 2
        layers = [
            nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(hidden_dim, half_width), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(half_width, 6),  # one logit per size: 0, 1, ..., 5 contracts
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Map a state batch to position-size logits of width 6."""
        logits = self.network(state)
        return logits


class StopLossAgent(nn.Module):
    """Sub-agent responsible for stop loss placement.

    Outputs the stop distance as an ATR multiplier. The network ends in a
    sigmoid whose (0, 1) output is mapped affinely onto
    ``[min_multiplier, max_multiplier]``. The defaults (0.5, 3.0) match the
    ``sl_atr_multiplier`` bounds used by RiskManagementEnv.

    Raises:
        ValueError: If ``min_multiplier >= max_multiplier``.
    """

    def __init__(self, input_dim: int = 40, hidden_dim: int = 64,
                 min_multiplier: float = 0.5, max_multiplier: float = 3.0):
        super().__init__()
        # Validate before building layers so no parameters are initialised on error.
        if not min_multiplier < max_multiplier:
            raise ValueError(
                f"min_multiplier ({min_multiplier}) must be < max_multiplier ({max_multiplier})"
            )

        self.network = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1),  # ATR multiplier
            nn.Sigmoid()  # squashes to (0, 1) before rescaling in forward()
        )

        self.min_multiplier = min_multiplier
        self.max_multiplier = max_multiplier

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Return the stop-loss ATR multiplier for each state, shape (batch, 1)."""
        unit = self.network(state)
        return self.min_multiplier + (self.max_multiplier - self.min_multiplier) * unit


class ProfitTargetAgent(nn.Module):
    """Sub-agent responsible for profit target placement.

    Outputs a reward:risk ratio. The network ends in a sigmoid whose (0, 1)
    output is mapped affinely onto ``[min_rr, max_rr]``. The defaults
    (1.0, 5.0) match the ``rr_ratio`` bounds used by RiskManagementEnv.

    Raises:
        ValueError: If ``min_rr >= max_rr``.
    """

    def __init__(self, input_dim: int = 40, hidden_dim: int = 64,
                 min_rr: float = 1.0, max_rr: float = 5.0):
        super().__init__()
        # Validate before building layers so no parameters are initialised on error.
        if not min_rr < max_rr:
            raise ValueError(f"min_rr ({min_rr}) must be < max_rr ({max_rr})")

        self.network = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1),  # Risk-reward ratio
            nn.Sigmoid()  # squashes to (0, 1) before rescaling in forward()
        )

        self.min_rr = min_rr
        self.max_rr = max_rr

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Return the reward:risk ratio for each state, shape (batch, 1)."""
        unit = self.network(state)
        return self.min_rr + (self.max_rr - self.min_rr) * unit

In [None]:
class RiskManagementEnsemble(TorchModelV2, nn.Module):
    """Ensemble coordinator for the three risk management sub-agents.

    Concatenates the ``synergy_vector`` and ``account_state_vector``
    observation components into one state vector. That vector feeds the
    position-sizing, stop-loss and profit-target sub-agents, whose outputs
    are returned to RLlib as one flat tensor. A separate MLP head produces
    the value estimate.
    """

    def __init__(self,
                 obs_space: gym.Space,
                 action_space: gym.Space,
                 num_outputs: int,
                 model_config: Dict[str, Any],
                 name: str):

        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)

        # RLlib's preprocessor may hand the model a flattened Box and keep the
        # original Dict space on ``original_space``. A raw Dict space (as when
        # the model is instantiated directly) has no such attribute.
        original_space = getattr(obs_space, "original_space", obs_space)
        synergy_dim = original_space['synergy_vector'].shape[0]
        account_dim = original_space['account_state_vector'].shape[0]
        self.input_dim = synergy_dim + account_dim

        # Sub-agents, all reading the same concatenated state
        self.position_agent = PositionSizingAgent(self.input_dim)
        self.stop_loss_agent = StopLossAgent(self.input_dim)
        self.profit_target_agent = ProfitTargetAgent(self.input_dim)

        # Value function head (shared input, separate parameters)
        self.value_head = nn.Sequential(
            nn.Linear(self.input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

        # Cached by forward() and returned by value_function()
        self._value = None

    def _combine_obs(self, obs: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Concatenate the observation components into one state tensor.

        The result is moved to the dtype and device of the model parameters,
        so float64 or CPU-resident inputs still work with a float32 or GPU
        model.
        """
        reference = next(self.parameters())
        combined = torch.cat(
            [obs['synergy_vector'], obs['account_state_vector']], dim=-1
        )
        return combined.to(dtype=reference.dtype, device=reference.device)

    @override(TorchModelV2)
    def forward(self, input_dict, state, seq_lens):
        """Run every sub-agent and return their concatenated outputs.

        Output layout along the last dimension is: position-size logits, then
        the stop-loss ATR multiplier, then the risk-reward ratio. The value
        estimate is cached for ``value_function``.
        """
        combined_state = self._combine_obs(input_dict["obs"])

        position_logits = self.position_agent(combined_state)
        sl_multiplier = self.stop_loss_agent(combined_state)
        rr_ratio = self.profit_target_agent(combined_state)

        self._value = self.value_head(combined_state).squeeze(-1)

        # NOTE(review): RLlib splits this tensor according to the action
        # distribution's required input size (num_outputs). This layout
        # assumes 6 position logits + 1 + 1 = 8 values. If the continuous
        # actions use a diagonal Gaussian, RLlib expects a mean *and* a
        # log-std per dimension. Confirm against this env's action space.
        outputs = torch.cat([position_logits, sl_multiplier, rr_ratio], dim=-1)

        return outputs, state

    @override(TorchModelV2)
    def value_function(self):
        """Return the value estimate computed by the most recent forward()."""
        assert self._value is not None, "must call forward() first"
        return self._value

    def get_action_dict(self, obs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Convert observations to a deterministic action dictionary (for inference).

        The sub-agents contain dropout layers, so the module is switched to
        eval mode for the call. Its previous train/eval mode is restored
        afterwards, including when an exception is raised.

        Returns:
            Dict with ``position_size`` (argmax of the position logits),
            ``sl_atr_multiplier`` and ``rr_ratio``.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                combined_state = self._combine_obs(obs)
                position_logits = self.position_agent(combined_state)
                return {
                    'position_size': torch.argmax(position_logits, dim=-1),
                    'sl_atr_multiplier': self.stop_loss_agent(combined_state),
                    'rr_ratio': self.profit_target_agent(combined_state)
                }
        finally:
            self.train(was_training)

## Task 4: Configure and Run RLlib Training

In [None]:
# Expose the ensemble to RLlib under the name the PPO model config refers to
ModelCatalog.register_custom_model("risk_management_ensemble", RiskManagementEnsemble)

# Keyword arguments forwarded to RiskManagementEnv (and to every rollout worker)
env_config = dict(
    historical_data=historical_data,
    synergy_events=synergy_events,
    max_trades_per_episode=20,
    account_config=dict(
        initial_balance=50000.0,
        max_daily_loss=1000.0,
        max_drawdown=2000.0,
        max_position_size=5,
        point_value=5.0,
    ),
    point_value=5.0,
)

# Build one local instance to check construction and inspect the spaces
test_env = RiskManagementEnv(**env_config)
print(
    "Environment created successfully",
    f"Action space: {test_env.action_space}",
    f"Observation space: {test_env.observation_space}",
    sep="\n",
)

In [None]:
# Start (or reuse) the Ray runtime
ray.init(ignore_reinit_error=True)

# Assemble the PPO setup with the fluent AlgorithmConfig builder:
# optimisation hyper-parameters, compute resources, rollout collection,
# and the environment.
# NOTE(review): 4 rollout workers x 2 CPUs each requests 8 CPUs in addition
# to the driver. Confirm the runtime has that many, otherwise building the
# trainer may stall waiting for resources.
config = (
    PPOConfig()
    .training(
        lr=3e-4,
        train_batch_size=4000,
        sgd_minibatch_size=128,
        num_sgd_iter=10,
        gamma=0.99,
        lambda_=0.95,
        clip_param=0.2,
        vf_clip_param=10.0,
        entropy_coeff=0.01,
        vf_loss_coeff=0.5,
        grad_clip=0.5,
        model={
            "custom_model": "risk_management_ensemble",
            "custom_model_config": {},
        },
    )
    .resources(
        num_gpus=1 if torch.cuda.is_available() else 0,
        num_cpus_per_worker=2,
    )
    .rollouts(
        num_rollout_workers=4,
        rollout_fragment_length=200,
        batch_mode="truncate_episodes",
    )
    .environment(
        env=RiskManagementEnv,
        env_config=env_config,
        disable_env_checking=True,
    )
)

# Instantiate the algorithm from the finished config
trainer = config.build()
print("Trainer initialized successfully")

In [None]:
# Training loop
N_ITERATIONS = 100
LOG_EVERY = 10          # iterations between progress prints
CHECKPOINT_EVERY = 20   # iterations between intermediate checkpoints
results_list = []


def _result_metric(result, key):
    """Read an episode metric from an RLlib training result.

    Older Ray releases report episode metrics at the top level of the result
    dict; newer ones nest them under ``"env_runners"``. Returns NaN when the
    metric is found in neither place.
    """
    if key in result:
        return result[key]
    return result.get('env_runners', {}).get(key, float('nan'))


print("Starting training...")
for i in range(N_ITERATIONS):
    # Train for one iteration
    result = trainer.train()

    episode_reward_mean = _result_metric(result, 'episode_reward_mean')
    episode_len_mean = _result_metric(result, 'episode_len_mean')
    is_last_iteration = i == N_ITERATIONS - 1

    results_list.append({
        'iteration': i,
        'episode_reward_mean': episode_reward_mean,
        'episode_len_mean': episode_len_mean,
        'episodes_total': result.get('episodes_total', 0)
    })

    # Progress every LOG_EVERY iterations, and always for the final one
    if i % LOG_EVERY == 0 or is_last_iteration:
        print(f"Iteration {i}: reward_mean={episode_reward_mean:.3f}, len_mean={episode_len_mean:.1f}")

    # Periodic checkpoints, plus one after the final iteration so the last
    # trained weights are always recoverable from disk
    if (i % CHECKPOINT_EVERY == 0 and i > 0) or is_last_iteration:
        checkpoint = trainer.save()
        print(f"Checkpoint saved at: {checkpoint}")

print("\nTraining completed!")

## Task 5: Analyze Results and Save Model

In [None]:
# Tabulate the per-iteration metrics gathered by the training loop
results_df = pd.DataFrame(results_list)

# Smoothed reward: trailing mean over `window` iterations
window = 10
rolling_mean = results_df['episode_reward_mean'].rolling(window=window).mean()

fig, (reward_ax, length_ax) = plt.subplots(1, 2, figsize=(12, 6))

# Left panel: raw and smoothed mean episode reward
reward_ax.plot(results_df['iteration'], results_df['episode_reward_mean'])
reward_ax.set_xlabel('Training Iteration')
reward_ax.set_ylabel('Episode Reward Mean')
reward_ax.set_title('M-RMS Agent Learning Curve')
reward_ax.grid(True, alpha=0.3)
reward_ax.plot(results_df['iteration'], rolling_mean, 'r-', linewidth=2,
               label=f'{window}-iteration moving average')
reward_ax.legend()

# Right panel: mean episode length
length_ax.plot(results_df['iteration'], results_df['episode_len_mean'])
length_ax.set_xlabel('Training Iteration')
length_ax.set_ylabel('Episode Length Mean')
length_ax.set_title('Average Episode Length')
length_ax.grid(True, alpha=0.3)

fig.tight_layout()
fig.savefig('mrms_training_curves.png', dpi=300)
plt.show()

# Headline numbers for the run
print("\nFinal Training Statistics:")
print(f"Final reward mean: {results_df['episode_reward_mean'].iloc[-1]:.3f}")
print(f"Best reward mean: {results_df['episode_reward_mean'].max():.3f}")
print(f"Average reward (last 20 iterations): {results_df['episode_reward_mean'].tail(20).mean():.3f}")

In [None]:
# Extract the trained model from the default policy
policy = trainer.get_policy()
model = policy.model

save_path = '/content/drive/MyDrive/AlgoSpace/models/mrms_agent.pth'
# torch.save does not create missing parent directories
os.makedirs(os.path.dirname(save_path), exist_ok=True)

# Store only the lightweight settings with the weights. The historical
# DataFrame and the synergy event list are large and not needed to rebuild
# the model. They are also not plain tensors/primitives, which hampers
# restricted (weights-only) loading.
saved_config = {
    key: value for key, value in env_config.items()
    if key not in ('historical_data', 'synergy_events')
}

# Save the complete ensemble plus per-sub-agent weights for partial reuse
torch.save({
    'model_state_dict': model.state_dict(),
    'position_agent': model.position_agent.state_dict(),
    'stop_loss_agent': model.stop_loss_agent.state_dict(),
    'profit_target_agent': model.profit_target_agent.state_dict(),
    'value_head': model.value_head.state_dict(),
    'config': saved_config,
    'training_iterations': N_ITERATIONS,
    # plain float rather than a numpy scalar
    'final_reward_mean': float(results_df['episode_reward_mean'].iloc[-1])
}, save_path)

print(f"\nModel saved successfully to: {save_path}")

In [None]:
# Test the saved model
print("\nTesting saved model...")

# Fresh model instance built directly from the environment's (raw) spaces
test_model = RiskManagementEnsemble(
    obs_space=test_env.observation_space,
    action_space=test_env.action_space,
    num_outputs=8,  # 6 for position + 1 for SL + 1 for TP
    model_config={},
    name="test_model"
)

# map_location lets a checkpoint written during a GPU run load on a
# CPU-only machine; test_model itself lives on the CPU.
checkpoint = torch.load(save_path, map_location='cpu')
test_model.load_state_dict(checkpoint['model_state_dict'])
test_model.eval()

# Build a batch of one observation. Force float32 so the inputs match the
# model weights whatever dtype the environment emits.
obs, _ = test_env.reset()
obs_tensor = {
    key: torch.as_tensor(obs[key], dtype=torch.float32).unsqueeze(0)
    for key in ('synergy_vector', 'account_state_vector')
}

# Get action from model
with torch.no_grad():
    action_dict = test_model.get_action_dict(obs_tensor)

print("\nSample model output:")
print(f"Position size: {action_dict['position_size'].item()}")
print(f"Stop loss ATR multiplier: {action_dict['sl_atr_multiplier'].item():.3f}")
print(f"Risk-reward ratio: {action_dict['rr_ratio'].item():.3f}")

# Cleanup
ray.shutdown()

In [None]:
# Closing report: run length, artifact location and what the agent controls
print("\n" + "=" * 50)
print("M-RMS TRAINING COMPLETE")
print("=" * 50)
print(f"\nTraining completed with {N_ITERATIONS} iterations")
print(f"Final model saved to: {save_path}")
print(
    "\nThe trained M-RMS agent has learned to:",
    "1. Size positions based on account state and market conditions",
    "2. Place stop losses using dynamic ATR multipliers",
    "3. Set profit targets with adaptive risk-reward ratios",
    "\nThis model can now be integrated into the AlgoSpace system!",
    sep="\n",
)