In [None]:
from __future__ import annotations

import os
import sys
import random
from pathlib import Path
from typing import List, Tuple, Sequence, Optional, Any
from collections import deque
import math
import copy

import numpy as np
import pandas as pd
import pandas.api.types
import polars as pl
from sklearn.preprocessing import StandardScaler, RobustScaler
from tqdm.notebook import tqdm
from scipy.optimize import minimize
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import matplotlib.pyplot as plt
import seaborn as sns

# --- [Configuration] ---
sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (15, 10)
plt.rcParams['axes.unicode_minus'] = False  # draw minus signs with an ASCII hyphen instead of the Unicode minus

MIN_INVESTMENT = 0  # lower clip bound for the daily position weight
MAX_INVESTMENT = 2  # upper clip bound for the daily position weight
USE_SAM = True  # train with the SAM wrapper around AdamW instead of plain AdamW
EMA_DECAY = 0.999  # decay of the exponential moving average of model weights
REPLAY_BUFFER_SIZE = 252  # capacity of the sliding (features, target) buffer
BATCH_SIZE = 32  # minimum buffer fill before training, and size of each online batch
LEARNING_RATE = 1e-4
NUM_SEEDS = 2           # [NEW] Number of ensemble members (trade-off between speed and performance)
TARGET_VOLATILITY = 0.15 # [NEW] Target annualized volatility (15%)

class ParticipantVisibleError(Exception):
    """Invalid-submission error whose message is meant to be shown to the participant."""

# Running inside a Kaggle notebook is detected by the presence of the /kaggle mount.
IS_KAGGLE = Path('/kaggle').exists()

if IS_KAGGLE:
    INPUT_DIR = Path('/kaggle/input/hull-tactical-market-prediction')
    # Make packages shipped alongside the competition data importable
    # (presumably kaggle_evaluation, imported just below).
    sys.path.append(str(INPUT_DIR))
else:
    INPUT_DIR = Path('.')
    sys.path.append(os.getcwd())

# Best-effort import: the evaluation-server package may be unavailable locally,
# in which case the ImportError is deliberately ignored.
try:
    import kaggle_evaluation.default_inference_server
except ImportError:
    pass

TRAIN_PATH = INPUT_DIR / 'train.csv'
TEST_PATH = INPUT_DIR / 'test.csv'
TARGET_COL = 'forward_returns'  # column the online model is trained on
DATE_COL = 'date_id'  # ordering key; load_data sorts both files by it
BENCHMARK_COL = 'market_forward_excess_returns'  # market benchmark used by score/plots when present

# --- [Utils] ---
def seed_everything(seed=42):
    """Seed every random number generator used by this notebook.

    Seeds Python's ``random`` module, NumPy's global RNG, and PyTorch on the
    CPU and on *all* visible CUDA devices. It also puts cuDNN into
    deterministic, non-autotuned mode. ``PYTHONHASHSEED`` is exported too;
    note that changing it at runtime does not re-seed string hashing of the
    already running interpreter.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # cuda.manual_seed only seeds the current device; manual_seed_all covers multi-GPU runs.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Autotuning may select different (non-deterministic) kernels between runs.
    torch.backends.cudnn.benchmark = False


seed_everything(42)

class SlidingWindowBuffer:
    """Fixed-capacity FIFO of ``(state, target)`` pairs used for online training.

    When ``capacity`` pairs are stored, each new push silently evicts the oldest.
    """

    def __init__(self, capacity=252):
        # deque(maxlen=...) drops items from the left automatically once full.
        self.buffer = deque(maxlen=capacity)

    def push(self, state: np.ndarray, target: float):
        """Append one observation."""
        self.buffer.append((state, target))

    def sample(self, batch_size: int) -> Tuple[np.ndarray, np.ndarray]:
        """Return the newest ``min(len(self), batch_size)`` observations, oldest first.

        The draw is deterministic rather than random, which keeps online updates
        focused on the most recent data.

        Returns:
            ``(states, targets)`` as arrays, or two empty arrays if nothing is stored.
        """
        stored = len(self.buffer)
        if not stored:
            return np.array([]), np.array([])
        take = min(stored, batch_size)
        newest = list(self.buffer)[-take:]
        states, targets = map(np.array, zip(*newest))
        return states, targets

    def __len__(self):
        return len(self.buffer)

# --- [Optimizer: SAM] ---
class SAM(torch.optim.Optimizer):
    """Sharpness-Aware Minimization (SAM) wrapper around a base optimizer.

    One training step takes two forward/backward passes:

    1. compute loss and gradients at the current weights ``w``, then call
       ``first_step`` to move to the perturbed point ``w + e(w)``;
    2. compute loss and gradients again at the perturbed point, then call
       ``second_step``, which restores ``w`` and applies the base optimizer
       update using those perturbed-point gradients.

    No ``step`` override is defined here, so callers must use
    ``first_step`` / ``second_step``.
    """
    def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
        """
        Args:
            params: Parameters (or parameter groups) to optimize.
            base_optimizer: Optimizer *class* (e.g. ``optim.AdamW``). It is
                instantiated here with this wrapper's param groups and ``kwargs``.
            rho: Size of the weight perturbation; must be non-negative.
            adaptive: If True, the perturbation and the gradient norm are
                weighted element-wise by the parameter values (``p**2`` in
                ``first_step``, ``|p|`` in ``_grad_norm``).
            **kwargs: Forwarded to ``base_optimizer`` (e.g. ``lr``, ``weight_decay``).
        """
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"
        defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
        super(SAM, self).__init__(params, defaults)
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        # Use the base optimizer's group list so both objects read and modify
        # the same group settings (e.g. lr, rho).
        self.param_groups = self.base_optimizer.param_groups
        self.defaults.update(self.base_optimizer.defaults)

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        """Perturb weights along the gradient: ``p += rho * g / ||g||`` per group.

        Each parameter's current value is saved in ``self.state[p]["old_p"]`` so
        ``second_step`` can restore it. With ``adaptive`` the perturbation is also
        multiplied element-wise by ``p**2``. Parameters without a gradient are
        left untouched.

        Args:
            zero_grad: Clear gradients afterwards, typically so the second
                backward pass starts from zero.
        """
        # Global L2 norm over every parameter's gradient.
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)  # epsilon guards a zero gradient norm
            for p in group["params"]:
                if p.grad is None: continue
                self.state[p]["old_p"] = p.data.clone()
                e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
                p.add_(e_w)
        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        """Restore the saved weights, then apply the base optimizer step.

        Only parameters that currently have a gradient are restored. The base
        optimizer uses the gradients present at call time, which were computed
        at the perturbed point.

        Args:
            zero_grad: Clear gradients after the update.
        """
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None: continue
                p.data = self.state[p]["old_p"]
        self.base_optimizer.step()
        if zero_grad: self.zero_grad()

    def _grad_norm(self):
        """Return the global L2 norm of all gradients (each weighted by ``|p|`` when adaptive).

        The norm is computed on the device of the first parameter of the first
        group. If no parameter has a gradient, ``torch.stack`` receives an empty
        list and raises.
        """
        shared_device = self.param_groups[0]["params"][0].device
        norm = torch.norm(
                    torch.stack([
                        ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
                        for group in self.param_groups for p in group["params"]
                        if p.grad is not None
                    ]),
                    p=2
               )
        return norm
    
    def zero_grad(self, set_to_none: bool = False):
        """Delegate gradient clearing to the base optimizer."""
        self.base_optimizer.zero_grad(set_to_none)

# --- [Data Loading & Evaluation] ---
def load_data() -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Read the competition CSVs, each sorted by ``DATE_COL`` with a fresh RangeIndex.

    Returns:
        ``(train_df, test_df)``.
        * If the training file is missing, a message is printed and two empty
          frames are returned.
        * If only the test file is missing, ``test_df`` is empty.
    """
    def _read_sorted(path):
        return pd.read_csv(path).sort_values(DATE_COL).reset_index(drop=True)

    if TRAIN_PATH.exists():
        train_df = _read_sorted(TRAIN_PATH)
        test_df = _read_sorted(TEST_PATH) if TEST_PATH.exists() else pd.DataFrame()
        return train_df, test_df

    print('Train file not found.')
    return pd.DataFrame(), pd.DataFrame()

def score(solution: pd.DataFrame, submission: pd.DataFrame, row_id_column_name: str = DATE_COL) -> float:
    """Penalised annualised Sharpe ratio of a position series.

    The strategy holds ``prediction`` units of the market and keeps the rest,
    ``1 - prediction``, in the risk-free asset. Its annualised Sharpe ratio
    is divided by two penalties:

    * volatility penalty ``1 + max(0, strategy_vol / market_vol - 1.2)``;
    * return penalty ``1 + gap**2 / 100``, where ``gap`` is how far (in
      annualised percent) the strategy's mean excess return trails the market's.

    Args:
        solution: Frame with ``forward_returns`` and optionally
            ``risk_free_rate`` (0 when absent) and ``BENCHMARK_COL``.
        submission: Frame with a numeric ``prediction`` column, aligned with
            ``solution`` by index.
        row_id_column_name: Kept for signature compatibility; not used.

    Returns:
        The adjusted Sharpe ratio, or ``0.0`` when the strategy or market
        volatility is zero or undefined (e.g. fewer than two rows).

    Raises:
        ParticipantVisibleError: if ``prediction`` is missing or non-numeric.
    """
    if 'prediction' not in submission.columns:
        raise ParticipantVisibleError('Submission must contain a prediction column')
    if not pd.api.types.is_numeric_dtype(submission['prediction']):
        raise ParticipantVisibleError('Predictions must be numeric')

    sol = solution.copy()
    # Index-aligned assignment: `submission` must share `solution`'s index.
    sol['position'] = submission['prediction']

    if 'risk_free_rate' not in sol.columns:
        sol['risk_free_rate'] = 0.0

    sol['strategy_returns'] = sol['risk_free_rate'] * (1 - sol['position']) + sol['position'] * sol['forward_returns']
    strategy_excess_returns = sol['strategy_returns'] - sol['risk_free_rate']

    trading_days_per_yr = 252
    annualizer = np.sqrt(trading_days_per_yr)

    mean_excess = strategy_excess_returns.mean()
    std_excess = strategy_excess_returns.std()

    # The sample std is NaN for fewer than two rows; treat it like zero
    # volatility instead of propagating NaN through the ratio.
    if not np.isfinite(std_excess) or std_excess == 0:
        return 0.0

    sharpe = (mean_excess / std_excess) * annualizer

    market_std = sol['forward_returns'].std()
    strategy_volatility = std_excess * annualizer * 100
    market_volatility = market_std * annualizer * 100

    if not np.isfinite(market_volatility) or market_volatility == 0:
        return 0.0

    excess_vol = max(0.0, strategy_volatility / market_volatility - 1.2)
    vol_penalty = 1 + excess_vol

    strat_ann_ret = mean_excess * trading_days_per_yr * 100

    if BENCHMARK_COL in sol.columns:
        market_mean_excess = sol[BENCHMARK_COL].mean()
    else:
        market_mean_excess = (sol['forward_returns'] - sol['risk_free_rate']).mean()

    market_ann_ret = market_mean_excess * trading_days_per_yr * 100

    return_gap = max(0.0, market_ann_ret - strat_ann_ret)
    return_penalty = 1 + (return_gap ** 2) / 100

    return float(sharpe / (vol_penalty * return_penalty))

# --- [Feature Engineering] ---
# [Improvement 1] Advanced Alpha Factors added
def engineer_features(df: pd.DataFrame, numeric_cols: Sequence[str], show_progress: bool = False) -> Tuple[pd.DataFrame, List[str]]:
    """Append lag, rolling-moment and technical-indicator features per numeric column.

    Every column in ``numeric_cols`` that exists in ``df`` and has a numeric
    dtype gets the following features:

    * lags 1, 2, 3, 5, 10, 21 (``{col}_lag_{k}``);
    * rolling mean / population std over 5, 10, 21, 63 rows, shifted one
      row so they only see past values (``{col}_roll_mean_{w}``,
      ``{col}_roll_std_{w}``), plus rolling skew / kurtosis for windows >= 21;
    * 14-period RSI (``{col}_rsi_14``), 5/21-row volatility ratio
      (``{col}_vol_ratio``), MACD 12/26 (``{col}_macd``) and 20-row
      Bollinger band width (``{col}_bb_width``). These are not shifted and
      include the current row.

    Rows are processed positionally (``shift``/``rolling``/``ewm``), so ``df``
    should be in chronological order. Infinite values become NaN; NaNs are
    left for the caller.

    Args:
        df: Input frame; it is copied, never modified.
        numeric_cols: Columns to derive features from. Missing or non-numeric
            columns are skipped.
        show_progress: Show a tqdm progress bar over the columns.

    Returns:
        ``(features, feature_cols)``. ``features`` is a copy of ``df`` with the
        new columns appended. ``feature_cols`` lists the processed base columns
        followed by every created column, in creation order.
    """
    feats = df.copy()
    new_features: List[pd.Series] = []
    created_cols: List[str] = []
    base_cols: List[str] = []

    lag_windows = (1, 2, 3, 5, 10, 21)
    roll_windows = (5, 10, 21, 63)

    def calculate_rsi(series, period=14):
        """Simple-moving-average RSI; the 1e-8 guards against a zero average loss."""
        delta = series.diff().fillna(0)
        gain = (delta.where(delta > 0, 0)).rolling(window=period).mean()
        loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
        rs = gain / (loss + 1e-8)
        return 100 - (100 / (1 + rs))

    def add(series: pd.Series, name: str) -> None:
        """Record one new feature column (order defines the output column order)."""
        new_features.append(series.rename(name))
        created_cols.append(name)

    # Only build the (notebook) progress bar when it will actually be displayed.
    iterator = tqdm(numeric_cols, desc='Engineering Features') if show_progress else numeric_cols

    for col in iterator:
        if col not in feats.columns or not pd.api.types.is_numeric_dtype(feats[col]):
            continue
        base_cols.append(col)
        series = feats[col]

        # 1. Basic lags.
        for lag in lag_windows:
            add(series.shift(lag), f'{col}_lag_{lag}')

        # 2. Rolling moments, shifted one row so only strictly past values are used.
        for window in roll_windows:
            rolling = series.rolling(window)
            add(rolling.mean().shift(1), f'{col}_roll_mean_{window}')
            add(rolling.std(ddof=0).shift(1), f'{col}_roll_std_{window}')
            # Higher moments need a longer window to be meaningful.
            if window >= 21:
                add(rolling.skew().shift(1), f'{col}_skew_{window}')
                add(rolling.kurt().shift(1), f'{col}_kurt_{window}')

        # 3. RSI(14).
        add(calculate_rsi(series, 14), f'{col}_rsi_14')

        # 4. Short/long volatility ratio (sample std, ddof=1).
        add(series.rolling(5).std() / (series.rolling(21).std() + 1e-8), f'{col}_vol_ratio')

        # 5. MACD line: EMA(12) - EMA(26).
        ema_12 = series.ewm(span=12, adjust=False).mean()
        ema_26 = series.ewm(span=26, adjust=False).mean()
        add(ema_12 - ema_26, f'{col}_macd')

        # 6. Bollinger band width with 20-row, 2-sigma bands, relative to the mid band.
        bb_mean = series.rolling(20).mean()
        bb_std = series.rolling(20).std()
        bb_upper = bb_mean + 2 * bb_std
        bb_lower = bb_mean - 2 * bb_std
        add((bb_upper - bb_lower) / (bb_mean + 1e-8), f'{col}_bb_width')

    if new_features:
        feats = pd.concat([feats] + new_features, axis=1)
    feats = feats.replace([np.inf, -np.inf], np.nan)

    # Only processed (present and numeric) columns are reported as base features;
    # a non-numeric column here would break the later float64 conversion.
    return feats, base_cols + created_cols

class FeatureGenerator:
    """Computes ``engineer_features`` incrementally for streaming rows.

    Keeps a rolling tail of raw ``numeric_cols`` values. Lag, rolling and EWM
    features for newly arriving rows can then be computed with enough past
    context.
    """

    def __init__(self, numeric_cols: List[str], history_size: int = 300):
        """
        Args:
            numeric_cols: Raw columns passed to ``engineer_features``.
            history_size: Positive number of most recent raw rows retained as
                context. The default of 300 leaves ample room beyond the
                longest rolling window (63 rows) for MACD/skew warm-up.

        Raises:
            ValueError: if ``history_size`` is not positive.
        """
        if history_size < 1:
            raise ValueError(f"history_size must be positive, got {history_size}")
        self.numeric_cols = numeric_cols
        self.history_size = history_size
        self.history = pd.DataFrame()

    def fit(self, df: pd.DataFrame):
        """Seed the history with the last ``history_size`` rows of ``df`` (no-op if empty)."""
        if not df.empty:
            self.history = df[self.numeric_cols].iloc[-self.history_size:].copy()

    def transform(self, new_df: pd.DataFrame) -> pd.DataFrame:
        """Return engineered features for ``new_df``'s rows and extend the history.

        The result has exactly one row per row of ``new_df`` (it is empty when
        ``new_df`` is empty) and carries ``new_df``'s index.
        """
        n_new = len(new_df)
        new_data_subset = new_df[self.numeric_cols]
        # ignore_index: streamed batches (e.g. polars -> pandas) restart at 0 every
        # call, which would otherwise pile up duplicate labels in the history.
        combined = pd.concat([self.history, new_data_subset], axis=0, ignore_index=True)
        feats, _ = engineer_features(combined, self.numeric_cols)
        # Positional slice; `iloc[-n_new:]` would return every row when n_new == 0.
        new_feats = feats.iloc[len(feats) - n_new:].copy()
        new_feats.index = new_df.index
        self.history = combined.iloc[-self.history_size:]
        return new_feats

# --- [Model Architecture] ---
# [Improvement 5] SE-Block for Feature Attention
class SEBlock(nn.Module):
    """Squeeze-and-excitation style gating for flat ``(batch, channel)`` inputs.

    A bottleneck MLP (``channel -> channel // reduction -> channel``) followed
    by a sigmoid produces per-feature gates in (0, 1) that rescale ``x``.
    Each sample is already a single vector, so there is no pooling "squeeze"
    step; the MLP acts on ``x`` directly.
    """

    def __init__(self, channel, reduction=16):
        super(SEBlock, self).__init__()
        # Keep at least one hidden unit so small `channel` values still produce a gate.
        hidden = max(1, channel // reduction)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return ``x * gate(x)`` for a 2-D ``(batch, channel)`` tensor.

        Raises:
            ValueError: if ``x`` is not 2-D.
        """
        if x.dim() != 2:
            raise ValueError(f"SEBlock expects a 2-D (batch, channel) input, got shape {tuple(x.shape)}")
        return x * self.fc(x)

class ResidualBlock(nn.Module):
    """Residual MLP block: ``x + Dropout(SiLU(LayerNorm(Linear(x))))``.

    Input and output share the trailing dimension ``dim``.
    """

    def __init__(self, dim, dropout=0.3):
        super(ResidualBlock, self).__init__()
        layers = [
            nn.Linear(dim, dim),
            nn.LayerNorm(dim),
            nn.SiLU(),
            nn.Dropout(dropout),
        ]
        # The attribute name `block` and the layer order determine the state_dict keys.
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        update = self.block(x)
        return update + x

class MultiScaleEnsemble(nn.Module):
    """Mixture of three MLP branches specialised on short / mid / long horizon features.

    Features are routed to branches by the suffix ``engineer_features`` gives
    them (see ``_feature_scale``). Raw base columns are fed to the short
    branch together with the short-horizon features. A small gating network
    sees all features and produces softmax weights over the three branch
    outputs.

    ``forward(x)`` expects ``x`` of shape ``(batch, len(feature_names))`` and
    returns ``(tanh(mixture), gate_weights)`` with shapes ``(batch, 1)`` and
    ``(batch, 3)``.
    """

    def __init__(self, feature_names: List[str], hidden_dim: int = 128):
        super().__init__()
        self.feature_names = feature_names

        # Feature grouping by horizon, parsed from the engineered-name suffix.
        self.short_indices = []
        self.mid_indices = []
        self.long_indices = []
        self.base_indices = []
        buckets = {
            'short': self.short_indices,
            'mid': self.mid_indices,
            'long': self.long_indices,
            'base': self.base_indices,
        }
        for i, name in enumerate(feature_names):
            buckets[self._feature_scale(name)].append(i)

        # A branch with no inputs would get a zero-width input layer. Fall back to
        # the base features, or to every feature when there are no base features.
        fallback = self.base_indices or list(range(len(feature_names)))
        if not self.short_indices:
            self.short_indices = fallback
        if not self.mid_indices:
            self.mid_indices = fallback
        if not self.long_indices:
            self.long_indices = fallback

        # Raw base features also go to the short branch. Sorting makes the input
        # column order explicit instead of depending on set iteration order.
        self.short_input_idx = sorted(set(self.base_indices + self.short_indices))

        def make_branch(input_dim):
            return nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                SEBlock(hidden_dim, reduction=8),  # feature-wise gating of hidden units
                nn.LayerNorm(hidden_dim),
                nn.SiLU(),
                ResidualBlock(hidden_dim),
                ResidualBlock(hidden_dim),
                nn.Linear(hidden_dim, 1)
            )

        self.net_short = make_branch(len(self.short_input_idx))
        self.net_mid = make_branch(len(self.mid_indices))
        self.net_long = make_branch(len(self.long_indices))

        # Input-dependent mixing weights over the three branches.
        self.gating_net = nn.Sequential(
            nn.Linear(len(feature_names), 32),
            nn.Tanh(),
            nn.Linear(32, 3),
            nn.Softmax(dim=1)
        )

    @staticmethod
    def _feature_scale(name: str) -> str:
        """Classify a feature name as ``'short'``, ``'mid'``, ``'long'`` or ``'base'``.

        Only the suffixes produced by ``engineer_features`` are inspected, so
        digits inside raw column names (e.g. ``E10``) no longer affect routing:

        * long: MACD, skew/kurt, and lag/rolling features with window >= 63;
        * mid: Bollinger width, and lag/rolling features with window 10-62;
        * short: RSI, volatility ratio, and lags below 10;
        * base: raw columns and rolling features with windows below 10.
        """
        if name.endswith('_macd'):
            return 'long'
        if name.endswith('_bb_width'):
            return 'mid'
        if name.endswith('_vol_ratio'):
            return 'short'
        stem, sep, tail = name.rpartition('_')
        if not sep or not tail.isdigit():
            return 'base'
        if stem.endswith('_rsi'):
            return 'short'
        if stem.endswith(('_skew', '_kurt')):
            return 'long'
        if not stem.endswith(('_lag', '_roll_mean', '_roll_std')):
            return 'base'
        window = int(tail)
        if window >= 63:
            return 'long'
        if window >= 10:
            return 'mid'
        return 'short' if stem.endswith('_lag') else 'base'

    def forward(self, x):
        x_short = x[:, self.short_input_idx]
        x_mid = x[:, self.mid_indices]
        x_long = x[:, self.long_indices]

        out_short = self.net_short(x_short)
        out_mid = self.net_mid(x_mid)
        out_long = self.net_long(x_long)

        weights = self.gating_net(x)

        out = (out_short * weights[:, 0:1]) + \
              (out_mid * weights[:, 1:2]) + \
              (out_long * weights[:, 2:3])

        return torch.tanh(out), weights

# --- [Loss Function] ---
# [Improvement 2] Differentiable Sharpe Loss
class SharpeHybridLoss(nn.Module):
    """Blend of MSE and a differentiable (negated) batch Sharpe ratio.

    ``loss = (1 - alpha) * MSE(preds, targets) - alpha * Sharpe``. The Sharpe
    term treats ``preds * targets`` as per-sample strategy returns ``r`` and
    computes ``mean(r - target_return) / (std(r) + 1e-8)`` over the batch.
    """

    def __init__(self, target_return=0.0, alpha=0.5):
        super().__init__()
        self.target_return = target_return  # hurdle subtracted from strategy returns in the Sharpe term
        self.alpha = alpha  # weight of the Sharpe component
        self.mse = nn.MSELoss()

    def forward(self, preds, targets):
        # 1. MSE component (stabiliser).
        mse_loss = self.mse(preds, targets)

        # 2. Sharpe component. The sample std of a single element is NaN, so with
        # fewer than two samples only the MSE term contributes.
        strategy_returns = preds * targets
        if strategy_returns.numel() < 2:
            return (1 - self.alpha) * mse_loss
        expected_return = torch.mean(strategy_returns - self.target_return)
        volatility = torch.std(strategy_returns) + 1e-8

        # Negate Sharpe because the optimiser minimises the loss.
        sharpe_loss = -1.0 * (expected_return / volatility)

        return (1 - self.alpha) * mse_loss + self.alpha * sharpe_loss

# --- [Ensemble Trainer] ---
# [Improvement 3] Seed Ensemble Logic
class EnsembleTrainerWrapper:
    """Seed ensemble of ``MultiScaleEnsemble`` networks with per-member EMA copies.

    Each of ``num_seeds`` members is an independently initialised network
    trained by its own optimizer (SAM-wrapped AdamW when ``USE_SAM``, else
    AdamW). Predictions come from the EMA copies, averaged across members.
    """

    def __init__(self, feature_names: List[str], hidden_dim: int = 128, num_seeds: int = NUM_SEEDS):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.models = nn.ModuleList()
        self.ema_models = nn.ModuleList()
        self.optimizers = []
        self.num_seeds = num_seeds
        # Small Sharpe weight: the per-batch Sharpe estimate is unstable for small batches.
        self.loss_fn = SharpeHybridLoss(alpha=0.1)

        for _ in range(num_seeds):
            # Members differ only by their random initial weights; each constructor
            # call draws fresh values from the torch RNG.
            net = MultiScaleEnsemble(feature_names, hidden_dim).to(self.device)
            # Built with the constructor (not deepcopy) so the RNG stream, and thus
            # every member's initialisation, stays the same as before.
            ema_net = MultiScaleEnsemble(feature_names, hidden_dim).to(self.device)
            ema_net.load_state_dict(net.state_dict())
            ema_net.eval()

            self.models.append(net)
            self.ema_models.append(ema_net)

            base_optim = optim.AdamW
            if USE_SAM:
                opt = SAM(net.parameters(), base_optim, lr=LEARNING_RATE, rho=0.05, weight_decay=1e-5)
            else:
                opt = base_optim(net.parameters(), lr=LEARNING_RATE, weight_decay=1e-5)
            self.optimizers.append(opt)

        self.ema_decay = EMA_DECAY

    @torch.no_grad()
    def update_ema(self):
        """Move each EMA parameter toward its live counterpart: ``ema = d * ema + (1 - d) * live``."""
        step = 1.0 - self.ema_decay
        for live_model, ema_model in zip(self.models, self.ema_models):
            for p_live, p_ema in zip(live_model.parameters(), ema_model.parameters()):
                # In-place update; avoids allocating a new tensor per parameter.
                p_ema.lerp_(p_live, step)

    def fit_batch(self, X: np.ndarray, y: np.ndarray, epochs: int = 1):
        """Train every member on one batch for ``epochs`` steps, then update the EMAs.

        Args:
            X: Feature matrix, one row per sample.
            y: One target per row of ``X`` (reshaped to a column).
            epochs: Number of optimisation steps per member.

        Returns:
            Mean training loss over members and epochs. For SAM this is the loss
            at the unperturbed weights. Returns ``0.0`` without updating
            anything when ``X`` is empty or ``epochs`` < 1.
        """
        if len(X) == 0 or epochs < 1:
            return 0.0

        X_tensor = torch.from_numpy(X).float().to(self.device)
        y_tensor = torch.from_numpy(y).float().view(-1, 1).to(self.device)

        total_loss = 0.0

        for i in range(self.num_seeds):
            self.models[i].train()

            for _ in range(epochs):
                if USE_SAM:
                    # First pass: gradient at w, then perturb to w + e(w).
                    preds, _ = self.models[i](X_tensor)
                    loss = self.loss_fn(preds, y_tensor)
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(self.models[i].parameters(), 1.0)
                    self.optimizers[i].first_step(zero_grad=True)

                    # Second pass: gradient at the perturbed point, restore w and step.
                    preds_2, _ = self.models[i](X_tensor)
                    self.loss_fn(preds_2, y_tensor).backward()
                    torch.nn.utils.clip_grad_norm_(self.models[i].parameters(), 1.0)
                    self.optimizers[i].second_step(zero_grad=True)
                else:
                    self.optimizers[i].zero_grad()
                    preds, _ = self.models[i](X_tensor)
                    loss = self.loss_fn(preds, y_tensor)
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(self.models[i].parameters(), 1.0)
                    self.optimizers[i].step()

                total_loss += loss.item()

        self.update_ema()
        return total_loss / (epochs * self.num_seeds)

    def predict(self, X: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Average the EMA members' outputs.

        Returns:
            ``(preds, gate_weights)`` with shapes ``(n,)`` and ``(n, 3)``,
            both averaged over members.
        """
        X_tensor = torch.from_numpy(X).float().to(self.device)
        preds_accum = torch.zeros(X.shape[0], 1, device=self.device)
        weights_accum = torch.zeros(X.shape[0], 3, device=self.device)  # gating-weight accumulator

        with torch.no_grad():
            for ema_model in self.ema_models:
                ema_model.eval()
                # Each member returns both its prediction and its gating weights.
                p, w = ema_model(X_tensor)
                preds_accum += p
                weights_accum += w

        avg_preds = preds_accum / self.num_seeds
        avg_weights = weights_accum / self.num_seeds  # mean gating weights

        return avg_preds.cpu().numpy().reshape(-1), avg_weights.cpu().numpy()

# --- [Online Training & Optimization] ---
class OnlineTrainer:
    """Matches each prediction's features with its later-revealed target and trains online.

    Call ``queue_example`` when a prediction is made and ``apply_feedback``
    once that example's target is known; examples are matched FIFO. Every
    feedback goes into a sliding buffer. Once the buffer holds at least
    ``batch_size`` items, one ``fit_batch`` step runs on its newest batch.
    """

    def __init__(self, model: EnsembleTrainerWrapper, buffer_size=REPLAY_BUFFER_SIZE, batch_size=BATCH_SIZE):
        self.model = model
        self.buffer = SlidingWindowBuffer(capacity=buffer_size)
        self.batch_size = batch_size

        # FIFO of examples waiting for their target; raw preds are kept in lockstep.
        self.pending_features: deque[np.ndarray] = deque()
        self.pending_raw_preds: deque[float] = deque()
        self.update_count = 0

    def queue_example(self, feature_vector: np.ndarray, raw_pred: float):
        """Enqueue a float32 copy of ``feature_vector`` together with its raw prediction."""
        self.pending_features.append(feature_vector.astype(np.float32))
        self.pending_raw_preds.append(float(raw_pred))

    def apply_feedback(self, actual_target: float) -> Optional[float]:
        """Pair the oldest pending example with ``actual_target`` and maybe train.

        Returns:
            The training loss when a ``fit_batch`` step ran, otherwise ``None``
            (also ``None`` when nothing is pending).
        """
        if len(self.pending_features) == 0:
            return None

        oldest = self.pending_features.popleft()
        self.pending_raw_preds.popleft()  # kept only to stay aligned with pending_features
        self.buffer.push(oldest, actual_target)

        if len(self.buffer) < self.batch_size:
            return None

        batch_states, batch_targets = self.buffer.sample(self.batch_size)
        batch_loss = self.model.fit_batch(batch_states, batch_targets, epochs=1)
        self.update_count += 1
        return batch_loss

class OnlineOptimizer:
    """Online calibration of the prediction-to-position mapping.

    Keeps a rolling window of (raw prediction, realised return, risk-free
    rate) triples. It periodically re-fits ``(scale, bias)`` to maximise the
    in-window Sharpe ratio of ``clip(bias + scale * z(pred), MIN_INVESTMENT,
    MAX_INVESTMENT)``, where ``z`` standardises the window's predictions.
    Fitted values are blended into the current ones with a slow EMA (5%).

    NOTE(review): the objective fits on z-scored predictions, but the visible
    caller (``run_online_simulation_loop``) applies ``bias + scale * raw_pred``
    to the raw prediction. Confirm which is intended.
    """

    def __init__(self, initial_scale=1.0, initial_bias=1.0, window_size=252,
                 min_history=100, optimize_every=50):
        """
        Args:
            initial_scale: Starting slope of the prediction -> position map.
            initial_bias: Starting intercept (position for a zero signal).
            window_size: Number of most recent observations kept.
            min_history: Minimum observations in the window before re-fitting.
            optimize_every: Positive; re-fit after every this many accepted observations.
        """
        self.scale = initial_scale
        self.bias = initial_bias
        self.window_size = window_size
        self.min_history = min_history
        self.optimize_every = optimize_every

        self.history_preds: deque[float] = deque(maxlen=window_size)
        self.history_actuals: deque[float] = deque(maxlen=window_size)
        self.history_rfr: deque[float] = deque(maxlen=window_size)

        self.last_raw_prediction: float | None = None
        # Total accepted observations; unlike len(history_*) this is not capped by window_size.
        self.n_updates = 0
        # Last exception raised during a re-fit, kept for diagnostics.
        self.last_optimize_error: Exception | None = None

    def update(self, lagged_return: float, lagged_rfr: float):
        """Record the realised return for the last prediction and maybe re-fit.

        Ignored until ``set_last_prediction`` has been called. Observations whose
        return or risk-free rate is not finite are skipped.
        """
        if self.last_raw_prediction is None:
            return
        if not (np.isfinite(lagged_return) and np.isfinite(lagged_rfr)):
            return

        self.history_preds.append(self.last_raw_prediction)
        self.history_actuals.append(lagged_return)
        self.history_rfr.append(lagged_rfr)
        self.n_updates += 1

        # Use the uncapped counter: once the deques are full their length stays at
        # window_size (252 by default), so `len % 50` would never fire again.
        if len(self.history_preds) >= self.min_history and self.n_updates % self.optimize_every == 0:
            self._optimize()

    def _optimize(self):
        """Re-fit ``(scale, bias)`` on the window with L-BFGS-B and blend the result in."""
        preds = np.array(self.history_preds)
        actuals = np.array(self.history_actuals)
        rfrs = np.array(self.history_rfr)

        p_mean = np.mean(preds)
        p_std = np.std(preds) + 1e-8
        z_scores = (preds - p_mean) / p_std

        def objective(params):
            sc, bi = params
            weights = np.clip(bi + sc * z_scores, MIN_INVESTMENT, MAX_INVESTMENT)
            excess_returns = weights * (actuals - rfrs)

            mean_ret = np.mean(excess_returns)
            std_ret = np.std(excess_returns)

            # Flat positions/returns: return a large penalty instead of dividing by ~0.
            if std_ret < 1e-7:
                return 100.0
            return -(mean_ret / std_ret)

        x0 = [self.scale, self.bias]
        bounds = [(0.0, 5.0), (0.0, 2.0)]

        try:
            res = minimize(objective, x0, method='L-BFGS-B', bounds=bounds, tol=1e-4)
        except Exception as exc:  # best-effort: keep the previous (scale, bias)
            self.last_optimize_error = exc
            return

        if res.success:
            alpha = 0.05  # slow blend keeps the mapping stable between re-fits
            self.scale = alpha * res.x[0] + (1 - alpha) * self.scale
            self.bias = alpha * res.x[1] + (1 - alpha) * self.bias

    def get_params(self) -> Tuple[float, float]:
        """Return the current ``(scale, bias)``."""
        return self.scale, self.bias

    def set_last_prediction(self, pred: float):
        """Register the raw prediction that the next ``update`` call will be paired with."""
        self.last_raw_prediction = float(pred)

    # [Improvement 4] Volatility Scaling Logic
    def get_vol_scaler(self) -> float:
        """Position multiplier targeting ``TARGET_VOLATILITY``.

        Uses the annualised std of the last 21 recorded returns and clips the
        multiplier to [0.5, 1.5]. Returns 1.0 with fewer than 21 observations
        or zero realised volatility.
        """
        if len(self.history_actuals) < 21:
            return 1.0

        recent_returns = list(self.history_actuals)[-21:]
        current_vol = np.std(recent_returns) * np.sqrt(252)
        if current_vol == 0:
            return 1.0

        scaler = TARGET_VOLATILITY / (current_vol + 1e-8)
        # Clip to avoid extreme leverage changes.
        return float(np.clip(scaler, 0.5, 1.5))

# --- [Visualization] ---
def evaluate_and_plot_performance(val_df: pd.DataFrame):
    """Print a return/drawdown summary and plot equity, rolling risk and drawdown.

    Expects ``prediction`` and ``forward_returns`` columns. The market line
    uses ``BENCHMARK_COL`` when present, otherwise ``forward_returns``.

    Returns:
        Dict with ``final_strategy_return``, ``final_market_return`` and
        ``max_drawdown`` (all fractions), or ``None`` if ``val_df`` is empty.
    """
    if val_df.empty:
        print("No rows to evaluate; skipping performance report.")
        return None

    df = val_df.copy()

    # NOTE(review): the un-invested share (1 - prediction) earns nothing here,
    # whereas `score` credits it with risk_free_rate. Confirm which is intended.
    df['strategy_return'] = df['prediction'] * df['forward_returns']

    if BENCHMARK_COL in df.columns:
        df['market_return'] = df[BENCHMARK_COL]
    else:
        df['market_return'] = df['forward_returns']

    df['cum_strategy'] = (1 + df['strategy_return']).cumprod() - 1
    df['cum_market'] = (1 + df['market_return']).cumprod() - 1

    # Drawdown of wealth (1 + cumulative return) from its running peak. The peak
    # starts at the initial capital of 1.0, so losses from day one are counted.
    wealth = 1 + df['cum_strategy']
    peak = wealth.cummax().clip(lower=1.0)
    df['drawdown'] = wealth / peak - 1
    max_dd = df['drawdown'].min()

    final_strat_ret = df['cum_strategy'].iloc[-1]
    final_mkt_ret = df['cum_market'].iloc[-1]

    print(f"\n{'='*20} [Performance Report] {'='*20}")
    print(f"Strategy Final Return : {final_strat_ret*100:.2f}%")
    print(f"Market Final Return   : {final_mkt_ret*100:.2f}%")
    print(f"Max Drawdown          : {max_dd*100:.2f}%")

    fig, axes = plt.subplots(3, 1, figsize=(15, 12), sharex=True)

    # Panel 1: cumulative returns.
    axes[0].plot(df['cum_strategy'], label='Strategy (Ensemble)', color='blue', linewidth=2)
    axes[0].plot(df['cum_market'], label='Market Benchmark', color='gray', linestyle='--', alpha=0.7)
    axes[0].set_title('Cumulative Returns: Model vs Market', fontsize=14, fontweight='bold')
    axes[0].set_ylabel('Cumulative Return')
    axes[0].legend(loc='upper left')
    axes[0].grid(True, alpha=0.3)

    # Panel 2: rolling Sharpe (left axis) and rolling annualised volatility (right axis).
    window = 63
    roll_mean = df['strategy_return'].rolling(window).mean()
    roll_std = df['strategy_return'].rolling(window).std()
    roll_sharpe = (roll_mean / (roll_std + 1e-9)) * np.sqrt(252)

    ax2 = axes[1]
    ax2.plot(roll_sharpe, label=f'{window}-Day Rolling Sharpe', color='green', alpha=0.8)
    ax2.axhline(0, color='black', linewidth=0.8, linestyle='--')
    ax2.set_ylabel('Rolling Sharpe Ratio', color='green')
    ax2.set_title(f'Risk Analysis: {window}-Day Rolling Sharpe & Volatility', fontsize=14, fontweight='bold')
    ax2.legend(loc='upper left')

    ax2_r = ax2.twinx()
    annualized_vol = roll_std * np.sqrt(252)
    ax2_r.plot(annualized_vol, label=f'{window}-Day Rolling Volatility', color='orange', alpha=0.5, linestyle=':')
    ax2_r.set_ylabel('Annualized Volatility', color='orange')
    ax2_r.legend(loc='upper right')

    # Panel 3: underwater (drawdown) curve.
    axes[2].fill_between(df.index, df['drawdown'], 0, color='red', alpha=0.3, label='Drawdown')
    axes[2].plot(df['drawdown'], color='red', linewidth=1)
    axes[2].set_title('Underwater Plot (Drawdown Profile)', fontsize=14, fontweight='bold')
    axes[2].set_ylabel('Drawdown %')
    axes[2].set_xlabel('Time Steps (Days)')
    axes[2].legend(loc='lower left')
    axes[2].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    return {
        'final_strategy_return': float(final_strat_ret),
        'final_market_return': float(final_mkt_ret),
        'max_drawdown': float(max_dd),
    }

def plot_investment_weights(val_df: pd.DataFrame):
    """Plot the daily position weight (``prediction``), its 21-day mean and its overall mean.

    Does nothing when ``val_df`` is empty.
    """
    if val_df.empty:
        return None

    weights = val_df['prediction']
    mean_w = weights.mean()

    plt.figure(figsize=(15, 6))
    plt.plot(val_df.index, weights, label='Daily Weight', color='steelblue', alpha=0.6, linewidth=1)
    plt.plot(val_df.index, weights.rolling(21).mean(), label='21-Day Avg Weight', color='orange', linewidth=2)
    plt.title('Daily Investment Weight History', fontsize=14, fontweight='bold')
    # Label the mean line so it appears in the legend with its value.
    plt.axhline(mean_w, color='red', linestyle='--', label=f'Mean Weight ({mean_w:.2f})')
    plt.legend()
    plt.show()
    return None

# --- [Simulation Loops] ---
def run_online_simulation_loop(
    df: pd.DataFrame,
    model: EnsembleTrainerWrapper,
    trainer: OnlineTrainer,
    optimizer: OnlineOptimizer,
    scaler: RobustScaler,
    feature_cols: List[str],
    desc: str = "Simulating",
    update_model: bool = True
) -> Tuple[List[float], List[np.ndarray]]:
    """Walk ``df`` row by row: predict, size the position, then optionally learn.

    For each row, in order:

    1. predict from the row's scaled features (NaN -> 0) and record the
       ensemble's gating weights;
    2. position = ``clip((bias + scale * raw_pred) * vol_scaler,
       MIN_INVESTMENT, MAX_INVESTMENT)``, using the optimizer's current state;
    3. if ``update_model``, reveal the row's ``TARGET_COL``:
       * queue the example and apply the feedback (skipped for non-finite
         targets, so NaNs never reach the network);
       * update the position optimizer.

    Returns:
        ``(positions, gating_weights)`` with one entry per row of ``df``.
    """
    preds_list: List[float] = []
    gating_weights_list: List[np.ndarray] = []

    n_rows = len(df)
    if n_rows == 0:
        # sklearn scalers reject zero-sample input.
        return preds_list, gating_weights_list

    # Scaling acts row-wise, so transform the whole matrix once instead of
    # slicing and converting a one-row frame on every iteration.
    scaled_all = scaler.transform(df[feature_cols].fillna(0).to_numpy(dtype=np.float64))
    if update_model:
        targets = df[TARGET_COL].to_numpy(dtype=np.float64)
        if 'risk_free_rate' in df.columns:
            rfrs = df['risk_free_rate'].to_numpy(dtype=np.float64)
        else:
            rfrs = np.zeros(n_rows)

    for i in tqdm(range(n_rows), desc=desc):
        scaled_features = scaled_all[i:i + 1]

        preds, w = model.predict(scaled_features)
        raw_pred = float(preds[0])
        gating_weights_list.append(w[0])
        optimizer.set_last_prediction(raw_pred)

        opt_scale, opt_bias = optimizer.get_params()

        # [Improvement 4] Apply Volatility Scaling
        vol_scalar = optimizer.get_vol_scaler()

        # NOTE(review): OnlineOptimizer fits (scale, bias) on z-scored predictions,
        # but they are applied to the raw prediction here. Confirm intent.
        weight = opt_bias + opt_scale * raw_pred
        weight_adjusted = weight * vol_scalar

        final_weight = float(np.clip(weight_adjusted, MIN_INVESTMENT, MAX_INVESTMENT))
        preds_list.append(final_weight)

        if not update_model:
            # No feedback will follow, so nothing is queued for the trainer.
            continue

        actual = float(targets[i])
        if np.isfinite(actual):
            # Queue and immediately resolve, keeping the trainer's FIFO aligned.
            trainer.queue_example(scaled_features[0], raw_pred)
            trainer.apply_feedback(actual)
        optimizer.update(actual, float(rfrs[i]))

    return preds_list, gating_weights_list

def plot_gating_weights_with_market(df: pd.DataFrame, weights_list: List[np.ndarray]):
    """Plot the ensemble's per-branch gating weights over time against the market.

    Args:
        df: Validation frame. Its index is the x-axis. The market curve uses
            ``BENCHMARK_COL`` when present, else ``forward_returns``.
        weights_list: Gating weights saved during the simulation, one length-3
            vector (short, mid, long) per row of ``df``.

    Does nothing when ``weights_list`` is empty.
    """
    weights_np = np.asarray(weights_list)  # (N, 3)
    if weights_np.size == 0:
        return None
    dates = df.index

    # Cumulative market return, drawn on a secondary axis for context.
    market_ret = df[BENCHMARK_COL] if BENCHMARK_COL in df.columns else df['forward_returns']
    cum_market = (1 + market_ret).cumprod() - 1

    fig, ax1 = plt.subplots(figsize=(15, 8))

    # --- 1. Gating weights: one line per branch (short=blue, mid=orange, long=green).
    ax1.plot(dates, weights_np[:, 0], label='Short-Term (Blue)', color='royalblue', linewidth=1.5, alpha=0.9)
    ax1.plot(dates, weights_np[:, 1], label='Mid-Term (Orange)', color='darkorange', linewidth=1.5, alpha=0.9)
    ax1.plot(dates, weights_np[:, 2], label='Long-Term (Green)', color='forestgreen', linewidth=1.5, alpha=0.9)

    ax1.set_ylabel('Gating Weight (0.0 ~ 1.0)', fontsize=12)
    ax1.set_ylim(-0.05, 1.05)  # small margin around the softmax range
    ax1.grid(True, alpha=0.3)

    # --- 2. Market context (previously computed but never drawn).
    ax2 = ax1.twinx()
    ax2.plot(dates, cum_market, label='Market Cumulative Return', color='gray', linestyle='--', alpha=0.6)
    ax2.set_ylabel('Market Cumulative Return', fontsize=12)

    # Merge both axes' entries into a single legend.
    handles_w, labels_w = ax1.get_legend_handles_labels()
    handles_m, labels_m = ax2.get_legend_handles_labels()
    ax1.legend(handles_w + handles_m, labels_w + labels_m, loc='upper left', ncol=4)

    ax1.set_title('Gating Weights Trends (Line Chart) vs Market Condition', fontsize=16, y=1.02)
    plt.tight_layout()
    plt.show()
    return None
    
def run_validation_and_score(
    df: pd.DataFrame,
    feature_cols: List[str],
    split_ratio: float = 0.8
):
    """Chronological train/validation split, online walk-forward, then score and plot.

    The first ``split_ratio`` of the rows fits the feature scaler and warms up
    the online model, trainer and position optimizer. The same objects keep
    learning online while walking through the validation rows, and the
    resulting positions are scored with ``score``.

    Returns:
        The modified Sharpe ratio on the validation segment.
    """
    banner = '=' * 40
    print(f"\n{banner}")
    print("[Full Walk-Forward Validation Start]")
    print(f"Ensemble Seeds: {NUM_SEEDS}")

    cut = int(len(df) * split_ratio)
    train_part = df.iloc[:cut].copy()
    val_part = df.iloc[cut:].reset_index(drop=True).copy()
    print(f"Train samples: {len(train_part)}, Validation samples: {len(val_part)}")

    # The scaler only sees the training segment, to avoid look-ahead.
    scaler = RobustScaler()
    scaler.fit(train_part[feature_cols].fillna(0).to_numpy(dtype=np.float64))

    model = EnsembleTrainerWrapper(feature_cols)
    trainer = OnlineTrainer(model, buffer_size=REPLAY_BUFFER_SIZE, batch_size=BATCH_SIZE)
    optimizer = OnlineOptimizer(initial_scale=1.0, initial_bias=1.0)
    shared_state = (model, trainer, optimizer, scaler, feature_cols)

    print("\nPhase 1: Walking through TRAIN data (Learning)...")
    run_online_simulation_loop(train_part, *shared_state, desc="Train Walk-Forward", update_model=True)

    print("\nPhase 2: Walking through VALIDATION data (Evaluating)...")
    val_preds, val_weights = run_online_simulation_loop(
        val_part, *shared_state, desc="Val Walk-Forward", update_model=True
    )

    val_part['prediction'] = val_preds
    if 'risk_free_rate' not in val_part.columns:
        val_part['risk_free_rate'] = 0.0

    sharpe_score = score(val_part, val_part)

    evaluate_and_plot_performance(val_part)
    plot_gating_weights_with_market(val_part, val_weights)
    plot_investment_weights(val_part)

    print("\n[Validation Result]")
    print(f"Validation Modified Sharpe Ratio: {sharpe_score:.5f}")
    strategy_returns = val_part['prediction'] * val_part[TARGET_COL]
    print(f"Mean Return: {strategy_returns.mean():.5f}, Std: {strategy_returns.std():.5f}")
    print(f"{banner}\n")

    return sharpe_score

def walkforward_simulation(
    train_feat_clean: pd.DataFrame,
    feature_cols: List[str],
):
    """Build and warm up the production pipeline on the full cleaned dataset.

    A RobustScaler is fit on every row of ``feature_cols`` (NaN -> 0). A fresh
    ensemble, online trainer and online optimizer are created, and all rows
    are walked through once with model updates enabled.

    Args:
        train_feat_clean: Full cleaned feature frame, including ``TARGET_COL``.
        feature_cols: Columns fed to the scaler and the model.

    Returns:
        A tuple ``(scaler, model, trainer, optimizer, target_mean, target_std)``.
        The last two are the mean and std of ``TARGET_COL`` over the input.
    """
    print("\nStarting Final Production Walk-Forward on FULL Data...")

    # RobustScaler.fit returns the fitted estimator itself.
    fitted_scaler = RobustScaler().fit(
        train_feat_clean[feature_cols].fillna(0).to_numpy(dtype=np.float64)
    )

    ensemble = EnsembleTrainerWrapper(feature_cols)
    replay_trainer = OnlineTrainer(ensemble, buffer_size=REPLAY_BUFFER_SIZE, batch_size=BATCH_SIZE)
    calibrator = OnlineOptimizer(initial_scale=1.0, initial_bias=1.0)

    # Target statistics are taken before the simulation walks the frame.
    target_series = train_feat_clean[TARGET_COL]
    mean_of_target, std_of_target = target_series.mean(), target_series.std()

    run_online_simulation_loop(
        train_feat_clean, ensemble, replay_trainer, calibrator, fitted_scaler, feature_cols,
        desc='Final Full Simulation', update_model=True
    )

    return fitted_scaler, ensemble, replay_trainer, calibrator, mean_of_target, std_of_target

# --- [Inference Globals] ---
# Module-level pipeline state. main() assigns these; predict() reads them.
# Each starts "empty": predict() returns its 1.0 fallback while the optimizer
# or the model is still None.
model: Any | None = None
feature_generator: FeatureGenerator | None = None
feature_cols: list[str] = []
target_mean: float = 0.0
target_std: float = 1.0
optimizer: OnlineOptimizer | None = None
online_trainer: OnlineTrainer | None = None
scaler: RobustScaler | None = None
# Incremented by predict() after each successful prediction.
PREDICT_COUNTER: int = 0

def _apply_lagged_feedback(test: pl.DataFrame) -> None:
    """Pass the previous step's realized return and risk-free rate to the online learners.

    Does nothing when the lagged columns are absent or either value is
    null/NaN. Errors are printed rather than raised, so a bad feedback row
    cannot abort inference.
    """
    if 'lagged_forward_returns' not in test.columns or 'lagged_risk_free_rate' not in test.columns:
        return
    try:
        lagged_ret = test['lagged_forward_returns'][0]
        lagged_rfr = test['lagged_risk_free_rate'][0]
        if lagged_ret is None or lagged_rfr is None:
            return
        ret_value, rfr_value = float(lagged_ret), float(lagged_rfr)
        # NaN is treated like a missing value so it cannot poison online updates.
        if math.isnan(ret_value) or math.isnan(rfr_value):
            return
        online_trainer.apply_feedback(ret_value)
        optimizer.update(ret_value, rfr_value)
    except Exception as err:
        print(f'Feedback error: {err}')


def predict(test: pl.DataFrame) -> float:
    """Return the investment weight for one test batch (inference-server callback).

    Flow:
      1. If the batch carries ``lagged_forward_returns`` and
         ``lagged_risk_free_rate``, pass those realized values to
         ``online_trainer.apply_feedback`` and ``optimizer.update``.
      2. Build features with ``feature_generator``, scale them with ``scaler``
         and take the model's prediction for row 0 as ``raw_pred``.
      3. Queue the scaled row with the trainer and record ``raw_pred`` on the
         optimizer. Then compute ``bias + scale * raw_pred``, multiply by the
         optimizer's volatility scaler, and clip to
         ``[MIN_INVESTMENT, MAX_INVESTMENT]``.

    The fallback weight 1.0 is returned when any pipeline global is unset or
    an error occurs. A NaN weight is also replaced by 1.0.

    Args:
        test: Polars frame for the current step; only row 0 is used.

    Returns:
        The clipped weight as a Python float.
    """
    global PREDICT_COUNTER

    # All pipeline state is required. Previously a missing online_trainer
    # crashed outside the try block below and aborted the server.
    if (
        model is None
        or optimizer is None
        or online_trainer is None
        or feature_generator is None
        or scaler is None
    ):
        return 1.0

    _apply_lagged_feedback(test)

    try:
        test_df = test.to_pandas()
        X_test_full = feature_generator.transform(test_df)
        X_test = X_test_full[feature_cols].fillna(0)
        X_np = X_test.to_numpy(dtype=np.float64)

        X_scaled = scaler.transform(X_np)
        preds = model.predict(X_scaled)
        raw_pred = float(preds[0])

        online_trainer.queue_example(X_scaled[0], raw_pred)
        optimizer.set_last_prediction(raw_pred)

        opt_scale, opt_bias = optimizer.get_params()
        vol_scalar = optimizer.get_vol_scaler()

        weight = opt_bias + opt_scale * raw_pred
        weight_adjusted = weight * vol_scalar

        final_weight = float(np.clip(weight_adjusted, MIN_INVESTMENT, MAX_INVESTMENT))
        # np.clip passes NaN through unchanged; never hand NaN to the scorer.
        if math.isnan(final_weight):
            print(f'Predict warning: NaN weight (raw_pred={raw_pred}); using 1.0')
            final_weight = 1.0

        PREDICT_COUNTER += 1
        return final_weight

    except Exception as err:
        print(f'Predict error: {err}')
        return 1.0

def manual_inference_loop(train_df: pd.DataFrame):
    """Run ``predict`` row by row over test.csv and write submission.csv.

    This is the local (non-Kaggle) path. Rows are sorted by ``DATE_COL`` and
    fed one at a time, so the stateful online feedback inside ``predict`` sees
    them in date order. The ID column is ``row_id`` when present, otherwise
    ``DATE_COL``; it is always written under the header ``row_id``.

    Args:
        train_df: Unused; kept for interface compatibility.
    """
    if feature_generator is None:
        return
    if not TEST_PATH.exists():
        print('Test file not found.')
        return

    test_df = pd.read_csv(TEST_PATH).sort_values(DATE_COL).reset_index(drop=True)
    # Resolve the ID column once, column-wise. Previously each row was
    # converted back to pandas just to read a single ID value.
    id_col = 'row_id' if 'row_id' in test_df.columns else DATE_COL
    ids = test_df[id_col].tolist()
    test_pl = pl.from_pandas(test_df)

    weights = []
    for i in tqdm(range(test_pl.height), desc='Manual Predict Loop'):
        # A one-row frame, matching what the inference server passes to predict().
        weights.append(predict(test_pl.slice(i, 1)))

    submission_pl = pl.DataFrame({'row_id': ids, 'prediction': weights})
    submission_pl.write_csv('submission.csv')
    print(f'submission.csv generated ({len(submission_pl)} rows).')

def _select_numeric_feature_columns(train_df: pd.DataFrame) -> List[str]:
    """Pick the numeric training columns that may serve as raw model inputs.

    Drops the target, benchmark, ``row_id``, ``DATE_COL`` and ``weight``
    columns. When test.csv exists, it also drops any column the test file
    lacks.
    """
    excluded = {TARGET_COL, 'weight', 'row_id', DATE_COL, BENCHMARK_COL}
    candidates = [
        name
        for name in train_df.columns
        if name not in excluded and pd.api.types.is_numeric_dtype(train_df[name])
    ]
    if not TEST_PATH.exists():
        return candidates
    # Only the header is needed, so a single row is read.
    test_columns = pd.read_csv(TEST_PATH, nrows=1).columns
    return [name for name in candidates if name in test_columns]


def _serve_or_write_submission(train_df: pd.DataFrame) -> None:
    """Hand ``predict`` to the Kaggle inference server, or run the local CSV loop."""
    if not IS_KAGGLE:
        if not Path('submission.parquet').exists():
            manual_inference_loop(train_df)
        return

    import kaggle_evaluation.default_inference_server

    server = kaggle_evaluation.default_inference_server.DefaultInferenceServer(predict)
    if os.getenv('KAGGLE_IS_COMPETITION_RERUN'):
        server.serve()
    else:
        server.run_local_gateway((str(INPUT_DIR),))


def main():
    """Run the end-to-end pipeline: features, validation, full-data fit, then inference.

    Assigns the module-level globals that ``predict`` reads.
    """
    global model, feature_generator, feature_cols, target_mean, target_std, optimizer, online_trainer, scaler

    train_df, _ = load_data()
    if train_df.empty:
        return

    numeric_cols = _select_numeric_feature_columns(train_df)

    train_feat, generated_cols = engineer_features(train_df, numeric_cols, show_progress=True)
    kept_numeric = [name for name in numeric_cols if name in train_feat.columns]
    # dict.fromkeys de-duplicates while preserving first-seen order.
    feature_cols = list(dict.fromkeys(kept_numeric + generated_cols))

    required = feature_cols + [TARGET_COL]
    train_feat_clean = train_feat.dropna(subset=required).reset_index(drop=True)

    # Hold-out validation: about 90.9% of rows for the warm-up walk, about 9.1% scored.
    # NOTE(review): the origin of this exact ratio is not visible here. The
    # literal is kept verbatim; digits beyond double precision are rounded away.
    run_validation_and_score(train_feat_clean, feature_cols, split_ratio=0.90899949723479135243841126194067)

    # Production fit: fresh objects trained on the full cleaned data.
    (scaler, model, online_trainer, optimizer, target_mean, target_std) = walkforward_simulation(
        train_feat_clean, feature_cols
    )

    feature_generator = FeatureGenerator(numeric_cols)
    feature_generator.fit(train_df)

    print(f'Ready for inference: Scale={optimizer.scale:.3f}, Bias={optimizer.bias:.3f}')

    _serve_or_write_submission(train_df)


if __name__ == '__main__':
    # Script entry point.
    main()