# MARL Training Pipeline

## ⚙️ Runtime: L4 GPU (~18 CU)
**Menu: Runtime → Change runtime type → L4 GPU**

## Anti-Leakage Guarantees
1. **Per-Symbol Temporal Split** - Each symbol split independently
2. **RL Trained on Train Data ONLY** - No information from val/test
3. **Proper Agent Isolation** - Agents communicate only through messages
4. **Centralized Critic, Decentralized Actors** - CTDE paradigm

## Output
- `trained/marl_agent_0.onnx` to `trained/marl_agent_4.onnx` - 5 agent networks
- `trained/marl_agent_0.pt` to `trained/marl_agent_4.pt`
- `trained/marl_metadata.json`

## Note
Training takes 5-7 hours: 150 episodes, each coordinating five agents at every step.

In [None]:
# Show the attached GPU through the NVIDIA CLI (IPython shell escape; notebook-only syntax).
!nvidia-smi
import torch
# Report the installed PyTorch build and whether it can see a CUDA device.
print(f"PyTorch: {torch.__version__}, CUDA: {torch.cuda.is_available()}")
# Name the first CUDA device when one is available.
if torch.cuda.is_available(): print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Install the notebook's dependencies quietly (-q) through an IPython shell escape.
!pip install -q torch onnx onnxruntime-gpu pandas numpy scikit-learn scipy requests tqdm
print("‚úì Dependencies installed!")

In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.preprocessing import RobustScaler
from collections import deque
from pathlib import Path
import json, time, random, warnings
# Silence every Python warning for the rest of the session. This also hides
# deprecation and numerical warnings raised by the libraries used below.
warnings.filterwarnings('ignore')

# Directory that receives the exported models and JSON reports (created if missing).
TRAINED_DIR = Path("trained")
TRAINED_DIR.mkdir(parents=True, exist_ok=True)
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {DEVICE}")

In [None]:
import requests
from datetime import datetime, timedelta
from tqdm.notebook import tqdm

def fetch_klines_sync(symbol, days=90):
    """Download 1-minute Binance spot klines for ``symbol`` covering the last ``days`` days.

    The history is paged forward from ``now - days`` to ``now``: every request
    asks for at most 1000 candles and the next request starts one minute after
    the last candle received, so no minutes are lost between pages.

    Args:
        symbol: Binance trading pair, e.g. ``"BTCUSDT"``.
        days: Length of the look-back window in days.

    Returns:
        A DataFrame sorted by ``open_time`` with the twelve kline fields plus a
        ``symbol`` column, de-duplicated on ``open_time``. An empty DataFrame
        is returned when nothing could be downloaded.

    Failed requests (network errors, non-JSON bodies, API error payloads) do
    not abort the download: the affected window is skipped and a single
    summary line is printed at the end.
    """
    from datetime import timezone  # local import: the file-level import only brings datetime/timedelta

    base_url = "https://api.binance.com/api/v3/klines"
    page_limit = 1000  # documented maximum number of klines per request
    bar_ms = 60_000    # length of one 1m candle in milliseconds

    # Use an aware UTC timestamp: calling .timestamp() on a naive utcnow()
    # value interprets it as local time and shifts the window on non-UTC hosts.
    end_time = datetime.now(timezone.utc)
    end_ms = int(end_time.timestamp() * 1000)
    cursor_ms = int((end_time - timedelta(days=days)).timestamp() * 1000)

    all_data = []
    failed_windows = 0
    while cursor_ms < end_ms:
        params = {"symbol": symbol, "interval": "1m",
                  "startTime": cursor_ms, "endTime": end_ms, "limit": page_limit}
        page = None
        try:
            resp = requests.get(base_url, params=params, timeout=30)
            page = resp.json()
        except (requests.RequestException, ValueError):
            # Best effort: remember the failure and move on to the next window.
            page = None

        if isinstance(page, list):
            if not page:
                # No candles exist between the cursor and the end of the window.
                break
            all_data.extend(page)
            # Field 0 of a kline is its open time in ms; resume right after it.
            cursor_ms = int(page[-1][0]) + bar_ms
        else:
            failed_windows += 1
            cursor_ms += page_limit * bar_ms
        time.sleep(0.1)  # stay well below the public rate limit

    if failed_windows:
        print(f"  ! {symbol}: {failed_windows} request window(s) failed and were skipped")
    if not all_data: return pd.DataFrame()

    cols = ["open_time","open","high","low","close","volume","close_time","quote_volume","trades","taker_buy_base","taker_buy_quote","ignore"]
    df = pd.DataFrame(all_data, columns=cols)
    df["open_time"] = pd.to_datetime(df["open_time"], unit="ms")
    # The API delivers prices and volumes as strings; unparsable values become NaN.
    for c in ["open","high","low","close","volume","quote_volume","taker_buy_base","taker_buy_quote"]:
        df[c] = pd.to_numeric(df[c], errors="coerce")
    df["symbol"] = symbol
    return df.drop_duplicates(subset=["open_time"]).sort_values("open_time")

def calculate_comprehensive_features(df):
    """Return a copy of ``df`` with engineered market features appended as new columns.

    The input must provide ``open``, ``high``, ``low``, ``close``, ``volume``,
    ``quote_volume``, ``taker_buy_base`` and a datetime ``open_time`` column.
    Rows are assumed to be in chronological order for a single instrument
    (every feature is a shift/rolling/ewm over consecutive rows).

    As written the function adds 94 columns, not the ~150 an earlier version
    of this docstring advertised. Rolling windows go up to 200 rows, so the
    leading rows contain NaN; divisions may also produce inf. Cleaning is left
    to the caller. Column insertion order is the feature order seen downstream.
    """
    df = df.copy()
    # sqrt(number of minutes in 252 24-hour days): scales a per-row standard
    # deviation to an annual figure, assuming one row per minute.
    ann_factor = np.sqrt(252 * 24 * 60)

    # 1. RETURNS & PRICE ACTION
    df["log_return"] = np.log(df["close"] / df["close"].shift(1))
    df["return_1"] = df["close"].pct_change(1)
    for w in [5, 10, 20, 50, 100, 200]:
        df[f"return_{w}"] = df["close"].pct_change(w)
    for w in [20, 50]:
        # w-row return divided by the w-row volatility scaled to the same horizon.
        vol = df["log_return"].rolling(w).std()
        df[f"sharpe_{w}"] = df[f"return_{w}"] / (vol * np.sqrt(w) + 1e-10)

    # 2. VOLATILITY (multiple estimators)
    for w in [5, 10, 20, 50, 100]:
        df[f"volatility_{w}"] = df["log_return"].rolling(w).std() * ann_factor
    for w in [20, 50]:
        # Parkinson estimator: based on the high/low range only.
        log_hl = np.log(df["high"] / df["low"])
        df[f"parkinson_vol_{w}"] = np.sqrt((1/(4*np.log(2))) * (log_hl**2).rolling(w).mean()) * ann_factor
        # Garman-Klass estimator: range term minus an open/close term.
        log_co = np.log(df["close"] / df["open"])
        gk = 0.5 * log_hl**2 - (2*np.log(2) - 1) * log_co**2
        # abs() keeps the square root real when the rolling mean dips below zero.
        df[f"gk_vol_{w}"] = np.sqrt(gk.rolling(w).mean().abs()) * ann_factor
    for w in [14, 20, 50]:
        # True range: largest of high-low, |high - prev close|, |low - prev close|.
        tr = pd.concat([df["high"] - df["low"], abs(df["high"] - df["close"].shift(1)), abs(df["low"] - df["close"].shift(1))], axis=1).max(axis=1)
        df[f"atr_{w}"] = tr.rolling(w).mean()
        df[f"atr_pct_{w}"] = df[f"atr_{w}"] / df["close"] * 100
    # Short-window volatility relative to long-window volatility.
    df["vol_regime"] = df["volatility_20"] / (df["volatility_100"] + 1e-10)

    # 3. VOLUME (CVD, VWAP, trades)
    for w in [5, 10, 20, 50]:
        df[f"volume_ma_{w}"] = df["volume"].rolling(w).mean()
    df["rvol_20"] = df["volume"] / (df["volume"].rolling(20).mean() + 1e-10)
    df["volume_zscore"] = (df["volume"] - df["volume"].rolling(50).mean()) / (df["volume"].rolling(50).std() + 1e-10)
    typical_price = (df["high"] + df["low"] + df["close"]) / 3
    for w in [20, 50]:
        # Rolling VWAP = sum(typical price * volume) / sum(volume); the feature
        # is the close's percentage distance from it.
        cum_vol = df["volume"].rolling(w).sum()
        cum_tp_vol = (typical_price * df["volume"]).rolling(w).sum()
        df[f"vwap_dist_{w}"] = (df["close"] - cum_tp_vol/(cum_vol+1e-10)) / (cum_tp_vol/(cum_vol+1e-10)+1e-10) * 100
    # Taker-buy volume minus the remaining (non-taker-buy) volume of each row.
    volume_delta = df["taker_buy_base"] - (df["volume"] - df["taker_buy_base"])
    for w in [10, 20, 50]:
        df[f"cvd_{w}"] = volume_delta.rolling(w).sum()
        df[f"cvd_norm_{w}"] = df[f"cvd_{w}"] / (df["volume"].rolling(w).sum() + 1e-10)
    df["dollar_vol_ratio"] = df["quote_volume"] / (df["quote_volume"].rolling(20).mean() + 1e-10)

    # 4. MICROSTRUCTURE
    # High-low range in basis points of the close (a bar-range proxy, not a quoted spread).
    df["spread_bps"] = (df["high"] - df["low"]) / df["close"] * 10000
    # Share of the row's volume that was taker-buy.
    df["ofi"] = df["taker_buy_base"] / (df["volume"] + 1e-10)
    for w in [10, 20, 50]:
        df[f"buy_pressure_{w}"] = df["taker_buy_base"].rolling(w).sum() / (df["volume"].rolling(w).sum() + 1e-10)
    # Absolute return per million units of quote volume.
    df["amihud"] = abs(df["return_1"]) / (df["quote_volume"] / 1e6 + 1e-10)

    # 5. MOMENTUM (MACD, RSI, ADX, etc.)
    for w in [5, 10, 20, 50, 100]:
        df[f"ma_dist_{w}"] = (df["close"] - df["close"].rolling(w).mean()) / df["close"].rolling(w).mean() * 100
    ema12 = df["close"].ewm(span=12, adjust=False).mean()
    ema26 = df["close"].ewm(span=26, adjust=False).mean()
    df["macd"] = ema12 - ema26
    df["macd_signal"] = df["macd"].ewm(span=9, adjust=False).mean()
    df["macd_hist"] = df["macd"] - df["macd_signal"]
    for w in [7, 14, 21]:
        # RSI built from simple rolling means of gains and losses.
        delta = df["close"].diff()
        gain = delta.where(delta > 0, 0).rolling(w).mean()
        loss = (-delta.where(delta < 0, 0)).rolling(w).mean()
        df[f"rsi_{w}"] = 100 - (100 / (1 + gain/(loss+1e-10)))
        # Rescale the 0..100 RSI to -1..1.
        df[f"rsi_{w}_norm"] = (df[f"rsi_{w}"] - 50) / 50
    rsi14 = df["rsi_14"]
    rsi_min, rsi_max = rsi14.rolling(14).min(), rsi14.rolling(14).max()
    df["stoch_rsi"] = (rsi14 - rsi_min) / (rsi_max - rsi_min + 1e-10)
    for w in [14, 21]:
        highest, lowest = df["high"].rolling(w).max(), df["low"].rolling(w).min()
        df[f"williams_r_{w}"] = -100 * (highest - df["close"]) / (highest - lowest + 1e-10)
    for w in [14, 20]:
        # Simplified directional movement: each move is kept whenever it is
        # positive, without comparing it against the opposite move.
        plus_dm = df["high"].diff().where(lambda x: x > 0, 0)
        minus_dm = (-df["low"].diff()).where(lambda x: x > 0, 0)
        tr = pd.concat([df["high"]-df["low"], abs(df["high"]-df["close"].shift(1)), abs(df["low"]-df["close"].shift(1))], axis=1).max(axis=1)
        atr = tr.rolling(w).mean()
        plus_di = 100 * (plus_dm.rolling(w).mean() / (atr + 1e-10))
        minus_di = 100 * (minus_dm.rolling(w).mean() / (atr + 1e-10))
        # Rolling mean of the directional index.
        df[f"adx_{w}"] = (100 * abs(plus_di - minus_di) / (plus_di + minus_di + 1e-10)).rolling(w).mean()
    tp = (df["high"] + df["low"] + df["close"]) / 3
    # CCI variant that uses the rolling standard deviation as the dispersion term.
    df["cci_20"] = (tp - tp.rolling(20).mean()) / (0.015 * tp.rolling(20).std() + 1e-10)

    # 6. MEAN REVERSION (Bollinger, z-scores)
    for w in [20, 50]:
        ma, std = df["close"].rolling(w).mean(), df["close"].rolling(w).std()
        # Band width (upper - lower = 4 std) as a percentage of the moving average.
        df[f"bb_width_{w}"] = (4 * std) / ma * 100
        # 0 at the lower band (ma - 2 std), 1 at the upper band (ma + 2 std).
        df[f"bb_position_{w}"] = (df["close"] - (ma - 2*std)) / (4*std + 1e-10)
        df[f"price_zscore_{w}"] = (df["close"] - ma) / (std + 1e-10)

    # 7. TIME FEATURES
    # NOTE(review): hours come straight from open_time; the session flags
    # below presumably assume UTC timestamps - confirm against the data source.
    hour = df["open_time"].dt.hour
    dow = df["open_time"].dt.dayofweek
    # Cyclical encodings so that 23h/0h and Sunday/Monday are neighbours.
    df["hour_sin"] = np.sin(2 * np.pi * hour / 24)
    df["hour_cos"] = np.cos(2 * np.pi * hour / 24)
    df["dow_sin"] = np.sin(2 * np.pi * dow / 7)
    df["dow_cos"] = np.cos(2 * np.pi * dow / 7)
    # Session flags overlap on purpose (hour 7 is both Asia and Europe, 13-15 both Europe and US).
    df["is_asia"] = ((hour >= 0) & (hour < 8)).astype(int)
    df["is_europe"] = ((hour >= 7) & (hour < 16)).astype(int)
    df["is_us"] = ((hour >= 13) & (hour < 22)).astype(int)
    df["is_weekend"] = (dow >= 5).astype(int)

    # 8. STATISTICAL
    for w in [20, 50]:
        df[f"skewness_{w}"] = df["log_return"].rolling(w).skew()
        df[f"kurtosis_{w}"] = df["log_return"].rolling(w).kurt()

    # 9. PRICE PATTERNS
    for w in [20, 50, 100]:
        highest, lowest = df["high"].rolling(w).max(), df["low"].rolling(w).min()
        df[f"dist_from_high_{w}"] = (df["close"] - highest) / highest * 100
        df[f"dist_from_low_{w}"] = (df["close"] - lowest) / lowest * 100
        # 0 at the rolling low, 1 at the rolling high.
        df[f"range_position_{w}"] = (df["close"] - lowest) / (highest - lowest + 1e-10)

    return df

def get_feature_columns(df):
    """Return the names of the model-input columns of ``df``, in frame order.

    Raw kline fields, identifier columns and every column whose name starts
    with ``target_`` are left out.
    """
    non_features = {
        "open_time", "close_time", "symbol", "ignore",
        "open", "high", "low", "close", "volume",
        "quote_volume", "trades", "taker_buy_base", "taker_buy_quote",
        "hour", "day_of_week",
    }
    selected = []
    for name in df.columns:
        if name in non_features or name.startswith("target_"):
            continue
        selected.append(name)
    return selected

In [None]:
# Instruments to download; each one is fetched, featurised and split on its own.
SYMBOLS = ["BTCUSDT", "ETHUSDT", "BNBUSDT", "SOLUSDT"]
print("Collecting data...")
all_data = []
for sym in tqdm(SYMBOLS):
    df = fetch_klines_sync(sym, days=90)
    # Symbols that returned nothing are skipped rather than aborting the run.
    if len(df) > 0:
        all_data.append(df)
        print(f"  ‚úì {sym}: {len(df):,} rows")

if not all_data: raise ValueError("No data collected!")
raw_data = pd.concat(all_data, ignore_index=True)
print(f"\n‚úì Total: {len(raw_data):,} rows")

# Per-symbol split: features are computed on one symbol at a time and only the
# first 70% of each symbol's cleaned rows is kept. The remaining 30% is
# discarded here; this cell builds no validation or test frame.
train_dfs = []
for sym in raw_data["symbol"].unique():
    sdf = raw_data[raw_data["symbol"]==sym].copy().sort_values("open_time").reset_index(drop=True)
    sdf = calculate_comprehensive_features(sdf)  # append the engineered feature columns
    # Drop the first 200 rows (longest look-back window) and any row that still has NaN/inf.
    sdf = sdf.replace([np.inf,-np.inf], np.nan).iloc[200:].dropna()
    n = len(sdf)
    train_end = int(n * 0.70)
    train_dfs.append(sdf.iloc[:train_end])
    print(f"{sym}: {train_end:,} train rows")

# NOTE(review): sorting the concatenated frame by open_time interleaves rows of
# different symbols. Consecutive rows then no longer belong to one price
# series, so any consumer that walks the rows in order must first separate
# them by the "symbol" column.
train_df = pd.concat(train_dfs).sort_values("open_time").reset_index(drop=True)
print(f"\n‚úì Total train: {len(train_df):,}")
print(f"‚úì Features: {len(get_feature_columns(train_df))}")

In [None]:
# Build the two aligned arrays consumed by the RL environment: row k of
# rl_features and row k of rl_ohlcv both describe row k of train_df.
ohlcv_cols = ["open", "high", "low", "close", "volume"]
feature_cols = get_feature_columns(train_df)

# The scaler is fitted on the training rows only.
scaler = RobustScaler()
feature_matrix = train_df[feature_cols].values
rl_features = scaler.fit_transform(feature_matrix)

# Unscaled prices and volume, used for PnL accounting.
rl_ohlcv = train_df[ohlcv_cols].values

print(f"RL Features: {rl_features.shape}")

In [None]:
# MARL Agent Network
class MARLAgent(nn.Module):
    """Actor network of one agent: maps (own state, peers' messages) to (action, outgoing message).

    Args:
        state_dim: Width of the agent's observation vector.
        action_dim: Width of the action vector; every component is squashed to [-1, 1].
        hidden_dim: Width of the state embedding and of the policy's hidden layer.
        message_dim: Width of one message vector.
        n_agents: Total number of agents; the network consumes ``n_agents - 1`` inbound messages.
    """

    def __init__(self, state_dim, action_dim=1, hidden_dim=128, message_dim=32, n_agents=5):
        super().__init__()
        half_hidden = hidden_dim // 2
        inbound_width = message_dim * (n_agents - 1)

        # Two-layer encoder for the agent's own observation.
        self.state_enc = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        # Single-layer encoder for the flattened messages of all other agents.
        self.msg_enc = nn.Sequential(
            nn.Linear(inbound_width, half_hidden),
            nn.ReLU(),
        )
        # Policy head over the concatenated state and message embeddings.
        self.policy = nn.Sequential(
            nn.Linear(hidden_dim + half_hidden, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Tanh(),
        )
        # Outgoing message is derived from the state embedding only.
        self.msg_gen = nn.Sequential(
            nn.Linear(hidden_dim, message_dim),
            nn.Tanh(),
        )

    def forward(self, state, messages):
        """Return ``(action, outgoing_message)``.

        ``messages`` has its last two dimensions flattened into one before it
        is encoded, so it is expected as ``(..., n_agents - 1, message_dim)``.
        """
        state_features = self.state_enc(state)
        inbound = torch.flatten(messages, start_dim=-2)
        message_features = self.msg_enc(inbound)
        policy_input = torch.cat((state_features, message_features), dim=-1)
        return self.policy(policy_input), self.msg_gen(state_features)

In [None]:
# Trading Environment for MARL
class TradingEnvMARL:
    """Minimal single-series trading simulator driven by a target position in [-1, 1].

    Args:
        data: 2-D array of market rows; column 3 is read as the traded price
            (presumably the close of an open/high/low/close/volume layout -
            confirm against the caller).
        features: 2-D array of model features, row-aligned with ``data``.
        initial_balance: Starting equity.
        transaction_cost: Proportional cost charged on traded notional.

    The position is a fraction of current equity. The PnL of a step is earned
    by the position held *before* the step's action; the new target only
    starts earning on the following step.
    """

    def __init__(self, data, features, initial_balance=100000, transaction_cost=0.0005):
        self.data = data
        self.features = features
        self.initial_balance = initial_balance
        self.transaction_cost = transaction_cost
        self.reset()

    def reset(self):
        """Restore the starting balance, flatten the position and return the first state."""
        self.balance = self.initial_balance
        self.position = 0.0
        self.step_idx = 0
        self.returns = []
        return self._get_state()

    def _get_state(self):
        """Return the current feature row followed by four portfolio statistics."""
        market = self.features[self.step_idx]
        portfolio = np.array([
            self.position,
            self.balance / self.initial_balance - 1,
            # Mean and spread of the last 20 step returns (0 until enough history exists).
            np.mean(self.returns[-20:]) if self.returns else 0,
            np.std(self.returns[-20:]) if len(self.returns) > 1 else 0
        ])
        return np.concatenate([market, portfolio])

    def step(self, action):
        """Advance one row.

        Args:
            action: Sequence whose first element is the target position; it is clipped to [-1, 1].

        Returns:
            ``(next_state, reward, done)``. On the terminal step the state is
            all zeros, the reward is 0 and the account is left untouched.
        """
        target_pos = float(np.clip(action[0], -1, 1))
        pos_change = target_pos - self.position
        current_price = self.data[self.step_idx, 3]
        # Positions are fractions of equity (PnL below is position * ret * balance),
        # so the traded notional is |position change| * balance. Charging on the
        # raw price instead would make costs depend on the asset's price level.
        cost = abs(pos_change) * self.balance * self.transaction_cost

        self.step_idx += 1
        done = self.step_idx >= len(self.data) - 1

        if not done:
            next_price = self.data[self.step_idx, 3]
            ret = (next_price - current_price) / current_price
            pnl = self.position * ret * self.balance - cost
            self.balance += pnl
            step_ret = pnl / self.initial_balance
            self.returns.append(step_ret)
            self.position = target_pos

            if len(self.returns) > 1:
                # Rolling mean/std ratio of the last 20 step returns.
                reward = np.mean(self.returns[-20:]) / (np.std(self.returns[-20:]) + 1e-8)
            else:
                reward = step_ret * 100
        else:
            reward = 0

        return self._get_state() if not done else np.zeros_like(self._get_state()), reward, done

In [None]:
# Training MARL
print("="*60)
print("TRAINING MARL (CUDA)")
print("="*60)
print("WARNING: This will take 5-7 hours")
print("="*60)

n_agents = 5
message_dim = 32
state_dim = rl_features.shape[1] + 4  # market features + 4 portfolio statistics from the env

agents = [MARLAgent(state_dim, message_dim=message_dim, n_agents=n_agents).to(DEVICE) for _ in range(n_agents)]
optimizers = [torch.optim.Adam(a.parameters(), lr=3e-4) for a in agents]

# Centralized critic: sees every agent's state and action, emits one value per agent.
critic = nn.Sequential(
    nn.Linear(n_agents * state_dim + n_agents, 256), nn.ReLU(),
    nn.Linear(256, 256), nn.ReLU(),
    nn.Linear(256, n_agents)
).to(DEVICE)
critic_opt = torch.optim.Adam(critic.parameters(), lr=3e-4)

buffer = deque(maxlen=50000)
batch_size = 256
episodes = 150

# Track episode metrics for evaluation
episode_returns = []
episode_sharpes = []

# train_df is sorted by time across all symbols, so consecutive rows of
# rl_ohlcv/rl_features belong to different instruments. Stepping an env over
# that mixed sequence would compute returns between unrelated prices, so the
# rows are regrouped into one chronological series per symbol.
symbol_ids = train_df["symbol"].to_numpy()
symbol_segments = []
for sym in train_df["symbol"].unique():
    sym_mask = symbol_ids == sym
    # The env needs at least two rows to take a step.
    if sym_mask.sum() >= 2:
        symbol_segments.append((rl_ohlcv[sym_mask], rl_features[sym_mask]))
if not symbol_segments:
    raise ValueError("No per-symbol training series available")

# For every agent, the indices of the other agents whose messages it receives.
peer_idx = [[j for j in range(n_agents) if j != i] for i in range(n_agents)]

start_time = time.time()
for ep in range(episodes):
    ep_reward = 0
    ep_symbol_returns = []   # total return of each symbol's run in this episode
    ep_step_returns = []     # all step returns of this episode, across symbols

    # One pass over every symbol per episode, each in its own environment.
    for seg_ohlcv, seg_features in symbol_segments:
        env = TradingEnvMARL(seg_ohlcv, seg_features)
        base_state = env.reset()

        # Each agent gets slightly different observation (noise)
        states = np.array([base_state + np.random.normal(0, 0.01, size=base_state.shape) for _ in range(n_agents)])
        messages = np.zeros((n_agents, message_dim))

        while True:
            actions = []
            new_messages = []

            for i, agent in enumerate(agents):
                other_msgs = np.stack([messages[j] for j in peer_idx[i]])
                state_t = torch.FloatTensor(states[i]).unsqueeze(0).to(DEVICE)
                msgs_t = torch.FloatTensor(other_msgs).unsqueeze(0).to(DEVICE)

                with torch.no_grad():
                    action, msg = agent(state_t, msgs_t)

                # Gaussian exploration noise, then back into the valid range.
                action = action.cpu().numpy()[0] + np.random.normal(0, 0.1, size=(1,))
                action = np.clip(action, -1, 1)
                actions.append(action)
                new_messages.append(msg.cpu().numpy()[0])

            inbound_messages = messages  # what the agents actually saw when acting
            messages = np.array(new_messages)

            # Aggregate actions (mean)
            agg_action = np.mean(actions, axis=0)
            next_base_state, reward, done = env.step(agg_action)

            next_states = np.array([next_base_state + np.random.normal(0, 0.01, size=next_base_state.shape) for _ in range(n_agents)])
            rewards = np.full(n_agents, reward)  # fully shared reward

            # Inbound messages are stored so the actors can be re-evaluated during updates.
            buffer.append((states.copy(), inbound_messages.copy(), np.array(actions), rewards, next_states.copy(), done))

            states = next_states
            ep_reward += reward

            # Training
            if len(buffer) >= batch_size:
                batch = random.sample(buffer, batch_size)
                b_states, b_msgs, b_actions, b_rewards, _, _ = zip(*batch)

                b_states = torch.FloatTensor(np.array(b_states)).to(DEVICE)    # (B, n_agents, state_dim)
                b_msgs = torch.FloatTensor(np.array(b_msgs)).to(DEVICE)        # (B, n_agents, message_dim)
                b_actions = torch.FloatTensor(np.array(b_actions)).to(DEVICE)  # (B, n_agents, 1)
                b_rewards = torch.FloatTensor(np.array(b_rewards)).to(DEVICE)  # (B, n_agents)
                flat_states = b_states.view(batch_size, -1)

                # Critic update: regress onto the immediate shared reward
                # (no bootstrapped target, unchanged from the original design).
                critic_input = torch.cat([flat_states, b_actions.view(batch_size, -1)], dim=-1)
                q_values = critic(critic_input)
                critic_loss = nn.MSELoss()(q_values, b_rewards)

                critic_opt.zero_grad()
                critic_loss.backward()
                critic_opt.step()

                # Actor updates. Without these the per-agent optimizers are never
                # stepped and the exported policies stay at their random
                # initialisation. Each agent's action is recomputed from its own
                # state and inbound messages, substituted into the joint action,
                # and pushed towards a higher critic value.
                for i, (agent, agent_opt) in enumerate(zip(agents, optimizers)):
                    pred_action, _ = agent(b_states[:, i, :], b_msgs[:, peer_idx[i], :])
                    joint_actions = torch.cat(
                        [b_actions[:, :i], pred_action.unsqueeze(1), b_actions[:, i + 1:]], dim=1
                    )
                    actor_q = critic(torch.cat([flat_states, joint_actions.reshape(batch_size, -1)], dim=-1))
                    actor_loss = -actor_q[:, i].mean()

                    agent_opt.zero_grad()
                    actor_loss.backward()
                    agent_opt.step()
                    # Critic gradients accumulated here are discarded by
                    # critic_opt.zero_grad() before the next critic step.

            if done:
                break

        ep_symbol_returns.append((env.balance - env.initial_balance) / env.initial_balance)
        ep_step_returns.extend(env.returns)

    # Track metrics: the episode return is the mean of the per-symbol returns.
    total_ret = float(np.mean(ep_symbol_returns))
    episode_returns.append(total_ret)
    if len(ep_step_returns) > 1:
        ep_sharpe = np.mean(ep_step_returns) / (np.std(ep_step_returns) + 1e-10) * np.sqrt(252*24*60)
        episode_sharpes.append(ep_sharpe)

    if (ep + 1) % 20 == 0:
        print(f"Episode {ep+1}/{episodes} - Return: {total_ret:.2%}")

# After the loop `env` is the environment of the last symbol in the final episode.
train_time = time.time() - start_time
print(f"\n‚úì Training time: {train_time/3600:.2f} hours")

In [None]:
# Comprehensive MARL Evaluation with Overfitting Detection
from scipy.stats import spearmanr

def comprehensive_marl_metrics(env, episode_returns, episode_sharpes, model_name="MARL"):
    """Summarise one finished environment run plus the per-episode training history.

    Args:
        env: Object exposing ``returns`` (sequence of step returns),
            ``balance`` and ``initial_balance``.
        episode_returns: Total return of every training episode, in order.
        episode_sharpes: Sharpe ratio of every training episode, in order.
        model_name: Label used in the "no returns" message.

    Returns:
        Dict of plain Python numbers (JSON-serialisable), or ``{}`` when
        ``env.returns`` is empty.
    """
    returns = np.array(env.returns)

    if len(returns) == 0:
        print(f"‚ö†Ô∏è No returns to evaluate for {model_name}")
        return {}

    # Same annualisation factor as the feature code: minutes in 252 24-hour days.
    ann_factor = np.sqrt(252 * 24 * 60)

    # Basic metrics
    total_return = (env.balance - env.initial_balance) / env.initial_balance
    mean_return = np.mean(returns)
    std_return = np.std(returns)

    # Risk-adjusted metrics
    sharpe = (mean_return / (std_return + 1e-10)) * ann_factor

    # Sortino ratio: only negative step returns contribute to the risk term.
    downside_returns = returns[returns < 0]
    downside_std = np.std(downside_returns) if len(downside_returns) > 0 else 1e-10
    sortino = (mean_return / (downside_std + 1e-10)) * ann_factor

    # Maximum drawdown of the additive (cumulative-sum) equity curve.
    cumulative = np.cumsum(returns)
    running_max = np.maximum.accumulate(cumulative)
    max_dd = np.max(running_max - cumulative)

    # Calmar ratio
    calmar = total_return / (max_dd + 1e-10) if max_dd > 0 else 0

    # Win rate & profit factor
    wins = returns[returns > 0]
    losses = returns[returns < 0]
    win_rate = len(wins) / len(returns)
    profit_factor = abs(np.sum(wins) / (np.sum(losses) + 1e-10)) if len(losses) > 0 else 0

    # Tail risk metrics
    var_95 = np.percentile(returns, 5)
    var_99 = np.percentile(returns, 1)
    tail = returns[returns <= var_95]
    cvar_95 = np.mean(tail) if len(tail) > 0 else 0

    # Training stability: compare equally sized first and last thirds.
    # (`-len(x)//3` parses as `(-len(x))//3`, which floors away from zero and
    # made the last slice one element longer whenever len(x) % 3 != 0.)
    n_episodes = len(episode_returns)
    if n_episodes > 10:
        third = n_episodes // 3
        first_third = episode_returns[:third]
        last_third = episode_returns[-third:]
        learning_improvement = np.mean(last_third) - np.mean(first_third)
        return_volatility = np.std(episode_returns)
    else:
        learning_improvement = 0
        return_volatility = 0

    # Sharpe stability across episodes: 1 - coefficient of variation.
    if len(episode_sharpes) > 10:
        sharpe_std = np.std(episode_sharpes)
        sharpe_mean = np.mean(episode_sharpes)
        sharpe_stability = 1 - (sharpe_std / (abs(sharpe_mean) + 1e-10))
    else:
        sharpe_stability = 0

    # Cast to built-in floats so the dict serialises regardless of NumPy scalar types.
    metrics = {
        "total_return": float(total_return),
        "mean_return": float(mean_return),
        "std_return": float(std_return),
        "sharpe": float(sharpe),
        "sortino": float(sortino),
        "max_drawdown": float(max_dd),
        "calmar": float(calmar),
        "win_rate": float(win_rate),
        "profit_factor": float(profit_factor),
        "var_95": float(var_95),
        "var_99": float(var_99),
        "cvar_95": float(cvar_95),
        "learning_improvement": float(learning_improvement),
        "return_volatility_across_episodes": float(return_volatility),
        "sharpe_stability": float(sharpe_stability),
        "n_trades": len(returns),
        "n_episodes": n_episodes
    }

    return metrics

def print_marl_metrics(metrics, model_name="MARL"):
    """Write a sectioned report of ``metrics`` to stdout.

    ``metrics`` must contain every key referenced below; ``model_name`` only
    appears in the report title.
    """
    rule = "=" * 60
    print(f"\n{rule}")
    print(f"{model_name} COMPREHENSIVE EVALUATION")
    print(rule)

    # (section heading, ((label with padding, metrics key, format spec), ...))
    layout = (
        ("üìä RETURN METRICS:", (
            ("  Total Return:     ", "total_return", ".4%"),
            ("  Mean Return:      ", "mean_return", ".6f"),
            ("  Std Return:       ", "std_return", ".6f"),
        )),
        ("üìà RISK-ADJUSTED METRICS:", (
            ("  Sharpe Ratio:     ", "sharpe", ".4f"),
            ("  Sortino Ratio:    ", "sortino", ".4f"),
            ("  Calmar Ratio:     ", "calmar", ".4f"),
        )),
        ("üìâ RISK METRICS:", (
            ("  Max Drawdown:     ", "max_drawdown", ".4%"),
            ("  VaR 95%:          ", "var_95", ".6f"),
            ("  VaR 99%:          ", "var_99", ".6f"),
            ("  CVaR 95%:         ", "cvar_95", ".6f"),
        )),
        ("üéØ TRADING METRICS:", (
            ("  Win Rate:         ", "win_rate", ".2%"),
            ("  Profit Factor:    ", "profit_factor", ".4f"),
            ("  Total Trades:     ", "n_trades", ","),
        )),
        ("üî¨ TRAINING STABILITY:", (
            ("  Learning Improvement: ", "learning_improvement", ".4%"),
            ("  Return Volatility:    ", "return_volatility_across_episodes", ".4f"),
            ("  Sharpe Stability:     ", "sharpe_stability", ".4f"),
            ("  Episodes Trained:     ", "n_episodes", ""),
        )),
    )

    for heading, rows in layout:
        print(f"\n{heading}")
        for label, key, spec in rows:
            # Values are looked up row by row, just before they are printed.
            print(f"{label}{format(metrics[key], spec)}")

def overfitting_analysis_marl(metrics, episode_returns, episode_sharpes):
    """Print heuristic overfitting checks for a trained model and return the raised warnings.

    Args:
        metrics: Dict with ``sharpe``, ``win_rate``, ``sharpe_stability``,
            ``profit_factor``, ``max_drawdown`` and ``learning_improvement``.
        episode_returns: Total return of every training episode, in order.
        episode_sharpes: Accepted for interface symmetry; not read here.

    Returns:
        List of warning strings, in the order the checks are evaluated.
    """
    banner = '=' * 60
    print(f"\n{banner}")
    print("MARL OVERFITTING ANALYSIS")
    print(f"{banner}")

    # Fixed-threshold checks on the headline metrics.
    threshold_checks = (
        (metrics['sharpe'] > 3.0,
         f"‚ö†Ô∏è HIGH SHARPE ({metrics['sharpe']:.2f}) - Possible overfitting"),
        (metrics['win_rate'] > 0.60,
         f"‚ö†Ô∏è HIGH WIN RATE ({metrics['win_rate']:.2%}) - Check for data leakage"),
        (metrics['sharpe_stability'] < 0.3,
         f"‚ö†Ô∏è LOW SHARPE STABILITY ({metrics['sharpe_stability']:.2f}) - Inconsistent across episodes"),
        (metrics['profit_factor'] > 3.0,
         f"‚ö†Ô∏è HIGH PROFIT FACTOR ({metrics['profit_factor']:.2f}) - Suspiciously good"),
        (metrics['max_drawdown'] < 0.01,
         f"‚ö†Ô∏è VERY LOW DRAWDOWN ({metrics['max_drawdown']:.4%}) - Unrealistic"),
    )
    flags = [message for triggered, message in threshold_checks if triggered]

    # Learning-curve checks need more than 20 episodes of history.
    if len(episode_returns) > 20:
        recent = episode_returns[-20:]
        earlier = episode_returns[:-20]
        if np.mean(recent) < np.mean(earlier):
            flags.append("‚ö†Ô∏è PERFORMANCE DECLINE - Model may be overfitting in later episodes")
        if np.std(recent) > np.mean(np.abs(recent)):
            flags.append("‚ö†Ô∏è UNSTABLE LEARNING - High variance in recent episodes")

    if metrics['learning_improvement'] < 0:
        flags.append("‚ö†Ô∏è NEGATIVE LEARNING - Model got worse during training")

    if flags:
        for message in flags:
            print(message)
    else:
        print("‚úÖ No obvious overfitting signals detected")
        print("   - Sharpe ratio in realistic range")
        print("   - Win rate not suspiciously high")
        print("   - Learning curve shows improvement")
        print("   - Multi-agent coordination appears stable")

    # Final verdict: 0-1 warnings pass, 2-3 call for validation, 4+ block deployment.
    print(f"\nüìã VERDICT:")
    n_flags = len(flags)
    if n_flags > 3:
        verdict = "‚ùå Multiple overfitting signals - DO NOT deploy without investigation"
    elif n_flags > 1:
        verdict = "‚ö†Ô∏è Some concerns - recommend additional validation"
    else:
        verdict = "‚úÖ MARL model appears well-calibrated for production"
    print(verdict)

    return flags

# Evaluate the environment left over from training, print the report and run
# the overfitting heuristics on it.
marl_metrics = comprehensive_marl_metrics(env, episode_returns, episode_sharpes, "MARL")
print_marl_metrics(marl_metrics, "MARL")
overfitting_warnings = overfitting_analysis_marl(marl_metrics, episode_returns, episode_sharpes)

# Persist the metrics, extended with episode-level summary statistics.
import json
eval_metrics = dict(marl_metrics)
eval_metrics["episode_returns_mean"] = float(np.mean(episode_returns))
eval_metrics["episode_returns_std"] = float(np.std(episode_returns))
eval_metrics["overfitting_warnings"] = len(overfitting_warnings)
with open(TRAINED_DIR / "marl_evaluation.json", "w") as f:
    json.dump(eval_metrics, f, indent=2)

In [None]:
# Export MARL Agents to ONNX (FIXED: opset_version=15, save to trained/)
import onnx

class AgentActionExtractor(nn.Module):
    """Wrap an agent so that only its action output is exposed.

    The wrapped agent must return a pair whose first element is the action;
    the second element (the outgoing message) is dropped.
    """

    def __init__(self, agent):
        super().__init__()
        self.agent = agent

    def forward(self, state, messages):
        """Run the wrapped agent and return the first element of its output."""
        outputs = self.agent(state, messages)
        return outputs[0]

# Write every agent twice: an ONNX graph that outputs the action only, and the
# full PyTorch state_dict.
for i, agent in enumerate(agents):
    action_only = AgentActionExtractor(agent)
    action_only.eval()

    # Example inputs that fix the non-batch dimensions for tracing.
    example_state = torch.randn(1, state_dim).to(DEVICE)
    example_msgs = torch.randn(1, n_agents - 1, message_dim).to(DEVICE)

    # Save directly to trained/ directory
    onnx_path = TRAINED_DIR / f"marl_agent_{i}.onnx"
    torch.onnx.export(
        action_only,
        (example_state, example_msgs),
        str(onnx_path),
        input_names=["state", "messages"],
        output_names=["action"],
        # Only the leading (batch) dimension is left dynamic.
        dynamic_axes={
            "state": {0: "batch"},
            "messages": {0: "batch"},
            "action": {0: "batch"},
        },
        opset_version=15,  # kept at 15 (was 17) for Colab ONNX compatibility
    )

    # Validate the file that was just written.
    exported_model = onnx.load(str(onnx_path))
    onnx.checker.check_model(exported_model)
    print(f"‚úì Agent {i} saved: {onnx_path}")

    # Save PyTorch
    torch.save(agent.state_dict(), TRAINED_DIR / f"marl_agent_{i}.pt")

# Metadata with evaluation metrics
metadata = dict(
    model_type="marl",
    n_agents=n_agents,
    state_dim=state_dim,
    message_dim=message_dim,
    train_time_hours=train_time / 3600,
    evaluation=eval_metrics,
)
with open(TRAINED_DIR / "marl_metadata.json", "w") as f:
    json.dump(metadata, f, indent=2)

print("\n‚úì MARL TRAINING COMPLETE!")