In [186]:
import sys
import os
import time
import datetime
import itertools
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from tqdm import tqdm
from typing import List, Dict, Any, Optional
from dataclasses import replace

# Ensure src is in python path
sys.path.append(str(Path().resolve().parent))

from src.config import ExpConfig
from src.engine.factory import EngineFactory
from src.algorithms.solver import CalibrationSolver
from src.modules.y_mappers import MonotoneYMapper
from src.utils.metrics import compute_parameter_error, compute_rmse,compute_nll,compute_nll_from_gamma,compute_empirical_error_bound,compute_p0_from_logits
# Imports for optimization and regret calculation
from src.utils.optimization import solve_optimal_assortment, calculate_revenue

# ==========================================
# Global Setup
# ==========================================
# All experiment artifacts live under ./results (relative to the working directory):
# pickled result tables go to results/logs, rendered figures to results/figures.
RESULTS_DIR = Path("results")
LOG_DIR = RESULTS_DIR / "logs"
FIG_DIR = RESULTS_DIR / "figures"
os.makedirs(LOG_DIR, exist_ok=True)
os.makedirs(FIG_DIR, exist_ok=True)

# Timestamp taken once when this cell runs; used to tag the saved result files.
RUN_ID = f"{datetime.datetime.now():%Y%m%d_%H%M%S}"

def set_seed(seed: int):
    """Seed the global NumPy and PyTorch random number generators.

    Args:
        seed: Value passed to ``np.random.seed`` and ``torch.manual_seed``, and to
            ``torch.cuda.manual_seed_all`` when CUDA is available.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    # Only touch the CUDA generators when a CUDA runtime is present.
    torch.cuda.manual_seed_all(seed)
def _build_multi_sim_outputs(cfg, eta_true):
    """Simulate the outputs of five simulators for the multi-simulator task.

    Simulators 1-4 apply a ``MonotoneYMapper`` (with shifted/scaled bias settings for 2-4)
    to a noisy copy of ``eta_true`` and add further output noise. Simulator 5 is the
    sign-flipped, 1e4-scaled output of a fifth mapper plus noise.

    The order of mapper construction and of the ``torch.randn_like`` draws is kept fixed
    so that results for a given seed are reproducible.

    Returns:
        Tensor obtained by stacking the five outputs along ``dim=1``.
    """
    mapper1 = MonotoneYMapper(cfg)
    mapper2 = MonotoneYMapper(replace(cfg, sim_bias_a=cfg.sim_bias_a + 0.5, sim_bias_b=cfg.sim_bias_b * 0.8))
    mapper3 = MonotoneYMapper(replace(cfg, sim_bias_a=cfg.sim_bias_a - 0.5, sim_bias_b=cfg.sim_bias_b * 1.2))
    mapper4 = MonotoneYMapper(replace(cfg, sim_bias_a=cfg.sim_bias_a + 1.0, sim_bias_b=cfg.sim_bias_b * 0.5))
    mapper5 = MonotoneYMapper(replace(cfg, sim_bias_a=cfg.sim_bias_a - 10.0, sim_bias_b=cfg.sim_bias_b * 15))

    y1 = mapper1(eta_true + torch.randn_like(eta_true) * 5) + torch.randn_like(eta_true) * 0.5
    y2 = mapper2(eta_true + torch.randn_like(eta_true) * 5) + torch.randn_like(eta_true) * 0.5
    y3 = mapper3(eta_true + torch.randn_like(eta_true) * 5) + torch.randn_like(eta_true) * 0.5
    y4 = mapper4(eta_true + torch.randn_like(eta_true) * 5) + torch.randn_like(eta_true) * 0.5

    # NOTE(review): y5 looks like an intentional outlier simulator (sign-flipped and hugely
    # scaled). An earlier suggestion was to replace it with pure noise to showcase the
    # advantage of the median aggregation; it is kept as-is here.
    y5 = -10000.0 * mapper5(eta_true) + torch.randn_like(eta_true) * 5.0

    return torch.stack([y1, y2, y3, y4, y5], dim=1)


def _assortment_regret_pct(cfg, engine, truth, gamma_hat, n_test_decisions=50, n_items_pool=50):
    """Average relative revenue regret (in percent) over random assortment decisions.

    For each test decision a random context, item pool and price vector (uniform in
    [10, 100)) are drawn. The "oracle" assortment is solved with the true parameters
    (``truth['gamma']`` and the engine's true beta); the "plug-in" assortment is solved
    with ``gamma_hat`` and the estimated beta. Both are evaluated under the true
    parameters, and the relative gap ``(r_opt - r_hat) / r_opt`` is recorded.
    """
    beta_true = engine.u_mapper.beta.detach()

    # The estimated beta exported by the engine alongside the data. The same estimate must
    # be used for the decision below, because gamma_hat was calibrated against the s_hat it
    # produced. Fall back to the true beta when the engine did not export one.
    beta_hat = truth.get('beta_hat')
    if beta_hat is None:
        beta_hat = beta_true
    beta_hat = beta_hat.to(cfg.device)

    regret_list = []
    for _ in range(n_test_decisions):
        # A. Random test environment: context z, item features and prices.
        z_test = torch.randn(cfg.dim_z, device=cfg.device)
        items_x = torch.randn(n_items_pool, cfg.dim_item_feat, device=cfg.device)
        prices = torch.rand(n_items_pool, device=cfg.device) * 90 + 10

        # B. Item utilities under the true and the estimated utility parameters.
        u_items_true = items_x @ beta_true
        u_items_hat = items_x @ beta_hat

        # C. Decisions: oracle revenue, and the assortment chosen from the estimates.
        _, r_opt = solve_optimal_assortment(
            truth['gamma'], z_test, prices, u_items_true
        )
        mask_hat, _ = solve_optimal_assortment(
            gamma_hat, z_test, prices, u_items_hat
        )

        # D. The chosen assortment must be evaluated in the "real world" (true parameters).
        r_hat_realized = calculate_revenue(
            mask_hat, truth['gamma'], z_test, prices, u_items_true
        )

        # E. Relative regret; guarded against a (near-)zero optimal revenue.
        if r_opt > 1e-6:
            regret = (r_opt - r_hat_realized) / r_opt
        else:
            regret = 0.0

        # float() accepts both Python numbers and scalar tensors (including tensors on an
        # accelerator or tracking gradients), which np.mean cannot average directly.
        regret_list.append(float(regret))

    return np.mean(regret_list) * 100.0


def run_single_trial(cfg: ExpConfig, 
                     task_type: str = 'standard',      
                     algo_type: str = 'mrc',           
                     multi_sim_method: str = 'median', 
                     regret_need: bool = False,
                     y_type: str = 'linear',
                     z_type: str = 'stats',
                     u_type: str = 'linear',
                     context_type: str= 'concat',
                     z_model_path: Optional[str] = None) -> Dict[str, Any]:
    """Generate one synthetic dataset, calibrate gamma on it and collect raw results.

    Args:
        cfg: Experiment configuration passed to the engine factory and the solver.
        task_type: ``'multi_sim'`` runs the multi-simulator branch; any other value runs
            the standard single-simulator calibration.
        algo_type: Solver for the standard branch: ``'linear'`` or ``'mrc'``.
        multi_sim_method: Aggregation method forwarded to ``solve_multi_mrc``.
        regret_need: If True, also compute the downstream assortment regret
            (standard branch with ``u_type == 'linear'`` only).
        y_type, z_type, u_type, context_type: Component choices forwarded to
            ``EngineFactory.build_synthetic_engine``.
        z_model_path: Currently unused; kept for interface compatibility.

    Returns:
        Dict with ``gamma_hat``, ``gamma_true``, ``p0_pred``, ``p0_true`` (tensors as
        returned by the solver / engine / metric helper), ``time`` (solver wall time in
        seconds) and ``regret`` (percent, NaN when it was not computed).

    Raises:
        ValueError: If ``algo_type`` is unknown in the standard branch.
    """
    # 1. Build the engine and generate data.
    engine = EngineFactory.build_synthetic_engine(
        cfg, z_type=z_type, u_type=u_type, y_type=y_type, context_mapper_type=context_type,
    )
    data = engine.generate()
    inputs = data['inputs']
    truth = data['truth']
    solver = CalibrationSolver(cfg)

    # 2. Calibrate gamma; only the solver call itself is timed.
    is_multi_sim = task_type == 'multi_sim'
    if is_multi_sim:
        # Branch A: multi-simulator calibration.
        Y_multi = _build_multi_sim_outputs(cfg, truth['eta'])

        start_time = time.time()
        gamma_hat = solver.solve_multi_mrc(inputs['z'], inputs['s_hat'], Y_multi, method=multi_sim_method)
        duration = time.time() - start_time
    else:
        # Branch B: standard calibration.
        start_time = time.time()
        if algo_type == 'linear':
            gamma_hat = solver.solve_linear(inputs['z'], inputs['s_hat'], inputs['y'])
        elif algo_type == 'mrc':
            gamma_hat = solver.solve_mrc(inputs['z'], inputs['s_hat'], inputs['y'])
        else:
            raise ValueError(f"Unknown algo_type: {algo_type}")
        duration = time.time() - start_time

    with torch.no_grad():
        p0_pred = compute_p0_from_logits(inputs['z'], inputs['s_hat'], gamma_hat)

    # 3. Downstream assortment regret (test phase). It stays NaN when it is not requested
    # or not defined, so downstream aggregation can drop it instead of reading it as 0.
    # The beta comparison is only defined for a linear utility model.
    avg_regret = float('nan')
    if regret_need and not is_multi_sim and u_type == 'linear':
        avg_regret = _assortment_regret_pct(cfg, engine, truth, gamma_hat)

    return {
        'gamma_hat': gamma_hat,
        'gamma_true': truth['gamma'],
        'p0_pred': p0_pred,
        'p0_true': truth['p0'],
        'time': duration,
        'regret': avg_regret
    }

# # ==========================================
# # 1. Unified Atomic Trial Runner
# # ==========================================
# def run_single_trial(cfg: ExpConfig, 
#                      task_type: str = 'standard',      
#                      algo_type: str = 'mrc',           
#                      multi_sim_method: str = 'median', 
#                      regret_need: bool = False,
#                      y_type: str = 'linear',
#                      z_type: str = 'stats',
#                      u_type: str = 'linear',
#                      context_type: str= 'concat',
#                      z_model_path: str = None) -> Dict[str, float]:
    
#     # 1. Build Engine & Generate Data
#     engine = EngineFactory.build_synthetic_engine(
#         cfg, z_type=z_type, u_type=u_type, y_type=y_type, context_mapper_type=context_type,
#     )
#     data = engine.generate()
#     inputs = data['inputs']
#     truth = data['truth']
#     solver = CalibrationSolver(cfg)
#     y_binary = torch.bernoulli(truth['p0'])
    
#     # ==========================================
#     # Branch A: Multi-Simulator 
#     # ==========================================
#     if task_type == 'multi_sim':
#         eta_true = truth['eta']
#         # Sim 1: (Good Signal)
#         mapper1=MonotoneYMapper(cfg)
#         mapper2=MonotoneYMapper(replace(cfg, sim_bias_a=cfg.sim_bias_a+0.5, sim_bias_b=cfg.sim_bias_b*0.8))
#         mapper3=MonotoneYMapper(replace(cfg, sim_bias_a=cfg.sim_bias_a-0.5, sim_bias_b=cfg.sim_bias_b*1.2))
#         mapper4=MonotoneYMapper(replace(cfg, sim_bias_a=cfg.sim_bias_a+1.0, sim_bias_b=cfg.sim_bias_b*0.5))
#         mapper5=MonotoneYMapper(replace(cfg, sim_bias_a=cfg.sim_bias_a-10.0, sim_bias_b=cfg.sim_bias_b*15))
#         y1 = mapper1(eta_true+ torch.randn_like(eta_true) * 5) + torch.randn_like(eta_true) * 0.5
#         y2 = mapper2(eta_true+ torch.randn_like(eta_true) * 5) + torch.randn_like(eta_true) * 0.5
#         y3 = mapper3(eta_true+ torch.randn_like(eta_true) * 5) + torch.randn_like(eta_true) * 0.5
#         y4 = mapper4(eta_true+ torch.randn_like(eta_true) * 5) + torch.randn_like(eta_true) * 0.5
        
#         y5 = -10000.0 * mapper5(eta_true) + torch.randn_like(eta_true) * 5.0
        
#         Y_multi = torch.stack([y1, y2, y3, y4, y5], dim=1)
        
#         start_time = time.time()
#         gamma_hat = solver.solve_multi_mrc(inputs['z'], inputs['s_hat'], Y_multi, method=multi_sim_method)
#         duration = time.time() - start_time
        
#         # 使用修正后的 metrics 计算 NLL
#         # nll_score = compute_nll_from_gamma(gamma_hat, inputs['z'], inputs['s_hat'], y_binary)
#         p0_pred=compute_p0_from_logits(inputs['z'], inputs['s_hat'], gamma_hat)
#     # ==========================================
#     # Branch B: Standard Calibration (+ Regret)
#     # ==========================================
#     else:
#         start_time = time.time()
#         if algo_type == 'linear':
#             gamma_hat = solver.solve_linear(inputs['z'], inputs['s_hat'], inputs['y'])
#         elif algo_type == 'mrc':
#             gamma_hat = solver.solve_mrc(inputs['z'], inputs['s_hat'], inputs['y'])
#         else:
#             raise ValueError(f"Unknown algo_type: {algo_type}")
#         duration = time.time() - start_time
        
#         with torch.no_grad():
#             p0_pred = compute_p0_from_logits(inputs['z'],inputs['s_hat'],gamma_hat)
            
#         avg_regret = 0.0
#         # ==========================================
#         # 3. Downstream Assortment Regret (Test Phase)
#         # ==========================================
#         if regret_need:
#             if u_type != 'linear':
#                 avg_regret = float('nan') 
#             else:
#                 n_test_decisions = 50
#                 n_items_pool = 50     

#                 beta_true = engine.u_mapper.beta.detach() 
                
#                 est_sigma = cfg.est_noise_sigma
#                 d_item = cfg.dim_item_feat
                
#                 if est_sigma > 0:
#                     if getattr(cfg, 'noise_distribution', 'gaussian') == 'uniform':
#                         # Strict bound logic: max norm <= sigma
#                         bound = est_sigma / np.sqrt(d_item)
#                         delta_beta = (torch.rand(d_item, device=cfg.device) * 2 - 1) * bound
#                     else:
#                         # Gaussian logic: expected norm approx sigma
#                         raw = torch.randn(d_item, device=cfg.device)
#                         delta_beta = raw / (raw.norm() + 1e-9) * est_sigma
#                 else:
#                     delta_beta = torch.zeros_like(beta_true)
                
#                 beta_hat = beta_true + delta_beta
                
#                 regret_list = []
                
#                 for _ in range(n_test_decisions):
#                     z_test = torch.randn(cfg.dim_z, device=cfg.device)
                    
#                     # Random Item Pool (Features X and Prices r)
#                     # Items X: (n_pool, d_item)
#                     items_x = torch.randn(n_items_pool, d_item, device=cfg.device)
#                     # Prices r: Uniform [10, 100]
#                     prices = torch.rand(n_items_pool, device=cfg.device) * 90 + 10
                    
#                     # True Utilities
#                     u_items_true = items_x @ beta_true
                    
#                     # Estimated Utilities
#                     u_items_hat = items_x @ beta_hat
                    
#                     # C.(Optimization)
#                     mask_opt, r_opt = solve_optimal_assortment(
#                         truth['gamma'], z_test, prices, u_items_true
#                     )
                    
#                     mask_hat, _ = solve_optimal_assortment(
#                         gamma_hat, z_test, prices, u_items_hat
#                     )
                    
#                     # D.(Evaluation)
#                     r_hat_realized = calculate_revenue(
#                         mask_hat, truth['gamma'], z_test, prices, u_items_true
#                     )
                    
#                     # E. Regret
#                     # Relative Regret: (Opt - Realized) / Opt
#                     if r_opt > 1e-6:
#                         regret = (r_opt - r_hat_realized) / r_opt
#                     else:
#                         regret = 0.0
                    
#                     regret_list.append(regret)
                
#                 avg_regret = np.mean(regret_list) * 100.0 
        
        
#     return {
#         'gamma_hat': gamma_hat,
#         'gamma_true': truth['gamma'],
#         'p0_pred': p0_pred,
#         'p0_true': truth['p0'],
#         'time': duration,
#         'regret': avg_regret
#     }

# ==========================================
# 2. Universal Experiment Runner (Corrected)
# ==========================================
def run_experiment_grid(
    base_cfg: ExpConfig,
    x_axis_name: str,
    x_values: List[Any],
    compare_axis_name: Any = None, 
    compare_values: Optional[List[Any]] = None,
    cross_product: bool = True,
    n_seeds: int = 5,
    regret_need:bool= False,
    # Defaults
    default_task_type: str = 'standard',       
    default_algo_type: str = 'mrc',
    default_multi_sim_method: str = 'median',  
    default_y_type: str = 'monotone',
    default_z_type: str = 'neural',
    default_u_type: str = 'linear',
    default_context_type: str= 'concat',
) -> pd.DataFrame:
    """Run ``run_single_trial`` over a grid of x-values, comparison settings and seeds.

    Args:
        base_cfg: Configuration every trial starts from.
        x_axis_name: Name of the swept quantity. It may be a trial option
            (``task_type``, ``algo_type``, ``multi_sim_method``, ``y_type``, ``z_type``,
            ``u_type``, ``context_type``, ``regret_need``) and/or an ``ExpConfig`` attribute.
        x_values: Values taken by the swept quantity.
        compare_axis_name: ``None`` (no comparison), a single name, or a sequence of names.
        compare_values: For a single name, the list of values. For several names, either
            one list of values per name (``cross_product=True``, Cartesian product) or a
            list of already-coupled tuples (``cross_product=False``).
        cross_product: See ``compare_values``.
        n_seeds: Number of repetitions; trial ``i`` uses config seed ``i + 1000``.
        regret_need: Forwarded to ``run_single_trial``.
        default_*: Trial options used unless overridden by the x-axis or a comparison key.

    Returns:
        One row per trial with the x-value, the seed index, ``combo_label``, the comparison
        values, ``n_samples`` and everything returned by ``run_single_trial``. The frame is
        also pickled under ``LOG_DIR``.

    Raises:
        ValueError: If ``compare_axis_name`` is given without ``compare_values``.
    """
    # --- 1. Resolve the comparison combinations ---
    if compare_axis_name is not None and compare_values is None:
        raise ValueError("compare_values must be provided when compare_axis_name is set")

    if compare_axis_name is None:
        comp_keys = []
        comp_iter = [()]  # single empty combination: only the x-axis is swept
    elif isinstance(compare_axis_name, str):
        # Single comparison variable: wrap each value into a 1-tuple, e.g. [('mrc',), ('linear',)]
        comp_keys = [compare_axis_name]
        comp_iter = [(v,) for v in compare_values] 
    else:
        # Several comparison variables
        comp_keys = list(compare_axis_name)
        if cross_product:
            # values=[[A,B], [C,D]] -> (A,C), (A,D), (B,C), (B,D)
            comp_iter = list(itertools.product(*compare_values))
        else:
            # values=[(A,C), (B,D)] -> (A,C), (B,D); the caller passes coupled tuples directly
            comp_iter = list(compare_values)

    print(f"\n=== Running Grid: X={x_axis_name} | Cross Validating: {comp_keys} ===")

    # Trial options and their defaults (overridden below by the x-axis / comparison keys).
    default_params = {
        'task_type': default_task_type,
        'algo_type': default_algo_type,
        'multi_sim_method': default_multi_sim_method,
        'y_type': default_y_type,
        'z_type': default_z_type,
        'u_type': default_u_type,
        'context_type': default_context_type,
        'regret_need':regret_need,
    }

    # Comparison keys that are not trial options must be written into the config,
    # otherwise every combination would silently run the same experiment.
    cfg_only_comp_keys = [k for k in comp_keys if k not in default_params and hasattr(base_cfg, k)]
    label_only_keys = [k for k in comp_keys if k not in default_params and not hasattr(base_cfg, k)]
    if label_only_keys:
        print(f"[warn] comparison keys {label_only_keys} are neither trial options nor "
              f"config attributes; they only affect labels.")

    # The x-axis value is written into the config when it is a config attribute. A name that
    # is only a trial option must NOT be passed to ExpConfig (it is not a config field); a
    # name that is neither is still forwarded so ExpConfig reports the problem as before.
    x_goes_to_cfg = hasattr(base_cfg, x_axis_name) or x_axis_name not in default_params
    
    records = []
    total_iters = len(x_values) * len(comp_iter) * n_seeds

    # The context manager closes the progress bar even if a trial raises.
    with tqdm(total=total_iters, desc="Progress") as pbar:
        for x_val in x_values:
            for comp_vals in comp_iter:
                # Current comparison setting, e.g. {'algo_type': 'mrc', 'y_type': 'linear'}
                current_comp_params = dict(zip(comp_keys, comp_vals))

                # Hue label used for plotting, e.g. "mrc-linear"
                if not comp_keys:
                    combo_label = "Default"
                else:
                    combo_label = "-".join([str(v) for v in comp_vals])

                # Resolve trial options. Priority: comparison keys > x-axis > defaults.
                params = dict(default_params)
                if x_axis_name in params:
                    params[x_axis_name] = x_val
                for k, v in current_comp_params.items():
                    if k in params:
                        params[k] = v

                for seed in range(n_seeds):
                    # Build the per-trial config overrides.
                    cfg_args = {'seed': seed + 1000}
                    # Trial options that also exist on the config are mirrored into it.
                    for k, v in params.items():
                        if hasattr(base_cfg, k):
                            cfg_args[k] = v
                    for k in cfg_only_comp_keys:
                        cfg_args[k] = current_comp_params[k]
                    if x_goes_to_cfg:
                        cfg_args[x_axis_name] = x_val

                    cfg = ExpConfig(**{**base_cfg.__dict__, **cfg_args})
                    set_seed(cfg.seed)

                    metrics = run_single_trial(
                        cfg, 
                        task_type=params['task_type'],
                        algo_type=params['algo_type'],
                        multi_sim_method=params['multi_sim_method'],
                        regret_need=params['regret_need'],
                        y_type=params['y_type'],
                        z_type=params['z_type'],
                        u_type=params['u_type'],
                        context_type=params['context_type'],
                    )

                    record = {
                        x_axis_name: x_val,
                        'seed': seed,                # seed index; the config seed is seed + 1000
                        'combo_label': combo_label,  # used as hue when plotting
                        **current_comp_params,       # individual comparison values, for lookup
                        **metrics                    # includes tensor-valued entries
                    }
                    record['n_samples'] = cfg.n_samples  # make sure n_samples is always present

                    records.append(record)
                    pbar.update(1)

    df = pd.DataFrame(records)
    
    save_name = f"grid_{x_axis_name}"
    if comp_keys:
        comp_str = "_".join(comp_keys)
        save_name += f"_vs_{comp_str}"
    
    # Pickle (rather than CSV) so that the tensor-valued columns survive the round trip.
    pkl_path = LOG_DIR / f"{save_name}_{RUN_ID}.pkl"
    df.to_pickle(pkl_path)
    print(f"Saved full data (with Tensors) to: {pkl_path}")
    
    return df

# ==========================================
# 3. Universal Plotter (Updated)
# ==========================================
def _resolve_plot_metric(y_label: str):
    """Classify a metric label and derive its quantile level and y-axis text.

    Matching is case-insensitive and substring based, checked in this order:
    ``nll``; ``p0`` together with ``error``; ``rmse``; ``gamma``; ``regret``.

    For the p0 error metric the label may end in ``_<delta>`` (e.g. ``p0_error_0.2``
    -> delta 0.2 -> 80% quantile); when absent, delta defaults to 0.3.

    Returns:
        ``(kind, delta, axis_text)`` where ``delta`` is ``None`` unless ``kind == 'p0_error'``.

    Raises:
        ValueError: If the label matches no known metric or delta is outside [0, 1].
    """
    label = y_label.lower()
    if 'nll' in label:
        return 'nll', None, "NLL of P0"
    if 'p0' in label and 'error' in label:
        delta = 0.3
        parts = label.split('_')
        if len(parts) > 2 and parts[-1].replace('.', '', 1).isdigit():
            try:
                delta = float(parts[-1])
            except ValueError:
                # e.g. non-ASCII digit characters; keep the default
                delta = 0.3
        if not 0.0 <= delta <= 1.0:
            raise ValueError(f"delta must lie in [0, 1] for y_label {y_label!r}, got {delta}")
        axis_text = r"Empirical P0 error bound($\delta=$" + f"{delta:g}" + r")"
        return 'p0_error', delta, axis_text
    if 'rmse' in label:
        return 'rmse', None, "RMSE of P0"
    if 'gamma' in label:
        return 'gamma', None, r"$\gamma$ L2 Estimation Error"
    if 'regret' in label:
        return 'regret', None, "Relative Revenue Regret (%)"
    raise ValueError(f"Unknown y_label type for metric computation: {y_label}")


def plot_metric_scaling(
    df: pd.DataFrame,
    x_col: str,
    x_label: str,
    y_label: str,
    hue_col: Optional[str] = None,
    title: str = "Scaling Analysis",
):
    """Aggregate per-trial results into one metric per (x, hue) group and plot it.

    Args:
        df: Per-trial results (as produced by ``run_experiment_grid``).
        x_col: Column used for the x-axis and for grouping.
        x_label: Text shown on the x-axis.
        y_label: Metric selector, see ``_resolve_plot_metric``. It is also used as the
            name of the aggregated column.
        hue_col: Optional column distinguishing the plotted lines. When ``None`` a single
            line is drawn.
        title: Figure title; a cleaned version is used in the saved file name.

    The figure is written as a PNG under ``FIG_DIR`` and then closed; nothing is returned.

    Raises:
        ValueError: If ``y_label`` does not name a known metric.
    """
    print(f"\n>>> Plotting {y_label} vs {x_label}...")

    # Resolve the metric once so the computation and the axis label always agree, and so
    # an unsupported label fails before any work is done.
    metric_kind, delta, y_axis_text = _resolve_plot_metric(y_label)
    
    # --- Step 1: make every tensor a detached CPU tensor ---
    # Done for every row (not only when the first row is off-CPU) so that later
    # .numpy() / stacking works regardless of device or autograd state.
    tensor_cols = ['p0_pred', 'p0_true', 'gamma_hat', 'gamma_true']
    
    def move_to_cpu(val):
        if isinstance(val, torch.Tensor):
            return val.detach().cpu()
        return val

    plot_df = df.copy()
    for col in tensor_cols:
        if col in plot_df.columns:
            plot_df[col] = plot_df[col].apply(move_to_cpu)
                
    # --- Step 2: aggregation over the trials of one (x, hue) group ---
    def compute_pooled_metric(sub_df):
        # Case A: P0 metrics. Predictions of all trials in the group are concatenated
        # and the statistic is computed over the pooled samples.
        if metric_kind in ('nll', 'p0_error', 'rmse'):
            all_preds = torch.cat(list(sub_df['p0_pred'])).float().numpy()
            all_trues = torch.cat(list(sub_df['p0_true'])).float().numpy()
            
            if metric_kind == 'nll':
                # Cross-entropy of the predictions against the true probabilities;
                # predictions are clipped to keep the logarithms finite.
                epsilon = 1e-7
                all_preds = np.clip(all_preds, epsilon, 1 - epsilon)
                nll = - (all_trues * np.log(all_preds) + (1 - all_trues) * np.log(1 - all_preds))
                return np.mean(nll)
            
            if metric_kind == 'p0_error':
                # (1 - delta) quantile of the absolute error over the pooled samples
                abs_diff = np.abs(all_preds - all_trues)
                return np.quantile(abs_diff, 1.0 - delta)
            
            return np.sqrt(np.mean((all_preds - all_trues)**2))

        # Case B: gamma. L2 error per trial (one row per trial), then averaged over trials.
        if metric_kind == 'gamma':
            hat_matrix = torch.stack(list(sub_df['gamma_hat']))
            true_matrix = torch.stack(list(sub_df['gamma_true']))
            l2_errors = torch.norm(hat_matrix - true_matrix, p=2, dim=1)
            return l2_errors.mean().item()

        # Case C: regret is already a scalar per trial. Return NaN (not 0) when it is
        # missing: 0 would look like a perfect result, NaN leaves a gap in the line.
        if 'regret' not in sub_df.columns:
            return float('nan')
        valid_regrets = sub_df['regret'].dropna()
        if valid_regrets.empty:
            return float('nan')
        return valid_regrets.mean()

    # --- Step 3: GroupBy & Apply ---
    group_keys = [x_col, hue_col] if hue_col else [x_col]
    agg_df = (
        plot_df.groupby(group_keys)
        .apply(compute_pooled_metric, include_groups=False)
        .reset_index(name=y_label)
    )
    
    # --- Step 4: plot ---
    sns.set_theme(style="whitegrid")
    plt.figure(figsize=(8, 5))
    plt.rcParams.update({
        'font.family': 'serif',
        'font.size': 18,
        'axes.labelsize': 18,
        'axes.titlesize': 18,
        'xtick.labelsize': 18,
        'ytick.labelsize': 18,
        'legend.fontsize': 16,
        'lines.linewidth': 2.5
    })
    
    if hue_col:
        unique_methods = sorted(agg_df[hue_col].unique())
        
        # Style mapping, matched by substring against the lower-cased hue value.
        # Colors: Red (MRC), Blue (Linear), Green (Sim), Orange (Oracle)
        # Lines: Solid, Dash-dot, Dashed, Dotted
        style_spec = {
            'mrc':    {'color': '#d62728', 'dashes': ""},           # Red Solid
            'median':   {'color': '#d62728', 'dashes': ""},
            'monotone': {'color': '#d62728', 'dashes': ""},
            'linear': {'color': '#1f77b4', 'dashes': (3, 1, 1, 1)}, # Blue Dash-dot
            'logit_mean': {'color': '#1f77b4', 'dashes': (3, 1, 1, 1)},
            'sim':    {'color': '#2ca02c', 'dashes': (2, 2)},       # Green Dashed
            'oracle': {'color': '#ff7f0e', 'dashes': (1, 1)},       # Orange Dotted
            'bound':  {'color': '#ff7f0e', 'dashes': (1, 1)}
        }
        
        fallback_colors = ['#9467bd', '#8c564b', '#e377c2', '#7f7f7f']
        
        color_map = {}
        dash_map = {}
        
        for idx, m in enumerate(unique_methods):
            m_lower = str(m).lower()
            matched = False
            
            for key, spec in style_spec.items():
                if key in m_lower:
                    color_map[m] = spec['color']
                    dash_map[m] = spec['dashes']
                    matched = True
                    break
            
            if not matched:
                color_map[m] = fallback_colors[idx % len(fallback_colors)]
                dash_map[m] = "" # Default Solid

        sns.lineplot(
            data=agg_df,
            x=x_col,
            y=y_label,
            hue=hue_col,
            style=hue_col,
            palette=color_map,
            dashes=dash_map,
            markers=False,      
            linewidth=2.5
        )

    else:
        # x must be the data column; x_label is only the display text.
        sns.lineplot(
            data=agg_df, x=x_col, y=y_label, 
            marker=None, linewidth=2.5, color='#ff7f0e'
        )

    # --- Step 5: labels, legend and output ---
    plt.title(title, pad=12)
    plt.xlabel(x_label)
    plt.ylabel(y_axis_text)
    plt.grid(True, which='both', alpha=0.3)
    if hue_col:
        # Without a hue there are no labelled artists to put into a legend.
        plt.legend(frameon=True, edgecolor='#cccccc', framealpha=0.9)
    
    clean_title = title.lower().replace(" ", "_").replace(":", "")
    # Microsecond timestamp so repeated plots with the same title do not overwrite each other.
    fig_stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    save_path = FIG_DIR / f"{clean_title}_{fig_stamp}.png"
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"Plot saved to {save_path}")
    plt.close()
# ==========================================
# 4. Main Execution
# ==========================================
if __name__ == "__main__":
    print(f"Starting Experiments. ID: {RUN_ID}")
    # Report the preferred accelerator: CUDA first, then Apple MPS, otherwise CPU.
    if torch.cuda.is_available():
        device_name = 'cuda'
    elif torch.backends.mps.is_available():
        device_name = 'mps'
    else:
        device_name = 'cpu'
    print(f"Device: {torch.device(device_name)}")

    # Settings shared by the three experiment configurations below; they differ only in
    # the utility mode and the noise distribution.
    shared_cfg_kwargs = dict(
        n_samples=2000,
        dim_z=200,
        est_noise_sigma=1,
        sim_bias_a=1.0,
        sim_bias_b=1,
    )
    base_cfg = ExpConfig(
        utility_mode='additive', noise_distribution='gaussian', **shared_cfg_kwargs
    )
    mrc_cfg = ExpConfig(
        utility_mode='additive', noise_distribution='uniform', **shared_cfg_kwargs
    )
    linear_cfg = ExpConfig(
        utility_mode="structural", noise_distribution="gaussian", **shared_cfg_kwargs
    )


    # ------------------------------------------------------
    # Exp 1: Convergence vs Sample Size (n)
    # ------------------------------------------------------
    def exp1():
        """Exp 1: 'linear' vs 'mrc' calibration as the sample size grows (metric: NLL)."""
        sample_sizes = [100, 150, 200, 500, 1000, 1500, 2000, 4000]
        grid_kwargs = dict(
            base_cfg=base_cfg,
            x_axis_name='n_samples',
            x_values=sample_sizes,
            compare_axis_name='algo_type',
            compare_values=['linear', 'mrc'],
            default_y_type='monotone',
            default_z_type='neural',
            default_context_type='concat',
        )
        results = run_experiment_grid(**grid_kwargs)
        plot_metric_scaling(
            results,
            x_col='n_samples',
            x_label='N',
            y_label='nll',
            hue_col='algo_type',
            title="Exp 1: Convergence vs Sample Size",
        )

    # ------------------------------------------------------
    # Exp 2: Robustness to Utility Noise (tau)
    # ------------------------------------------------------
    # change the form of tau noise from sigma to high probability bound
    def exp2():
        """Exp 2: 'linear' vs 'mrc' calibration as est_noise_sigma grows (metric: NLL)."""
        noise_levels = [0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0]
        grid_kwargs = dict(
            base_cfg=base_cfg,
            x_axis_name='est_noise_sigma',
            x_values=noise_levels,
            compare_axis_name='algo_type',
            compare_values=['linear', 'mrc'],
            default_y_type='monotone',
            default_z_type='neural',
            default_context_type='concat',
        )
        results = run_experiment_grid(**grid_kwargs)
        plot_metric_scaling(
            results,
            x_col='est_noise_sigma',
            x_label=r'$\bar{\tau}$',
            y_label='nll',
            hue_col='algo_type',
            title="Exp 2: Robustness to Utility Noise (tau)",
        )
    # ------------------------------------------------------
    # Exp 3: Bias Type Impact (Linear vs Monotone)
    # ------------------------------------------------------
    def exp3():
        """Exp 3: MRC under 'linear' vs 'monotone' simulator types, using linear_cfg (metric: NLL)."""
        sample_sizes = [100, 200, 500, 1000, 1500, 2000, 2500]
        grid_kwargs = dict(
            base_cfg=linear_cfg,
            x_axis_name='n_samples',
            x_values=sample_sizes,
            compare_axis_name='y_type',
            compare_values=['linear', 'monotone'],
            default_algo_type='mrc',
            n_seeds=10,
            default_z_type='neural',
            default_context_type='concat',
        )
        results = run_experiment_grid(**grid_kwargs)
        plot_metric_scaling(
            results,
            x_col='n_samples',
            x_label='N',
            y_label='nll',
            hue_col='y_type',
            title="Exp 3: MRC Robustness across Simulator Types",
        )
    # ------------------------------------------------------
    # Exp 4: Simulator Type Impact on the P0 Error Bound (uniform-noise config)
    # ------------------------------------------------------
    def exp4():
        """Exp 4: MRC under 'linear' vs 'monotone' simulator types, using mrc_cfg.

        The plotted metric is the p0 error quantile selected by the label 'p0_error_0.3'.
        """
        # NOTE(review): 7500 sits between 500 and 1500 and exceeds the final 6000 —
        # confirm it is intended and not a typo (e.g. for 750).
        sample_sizes = [100, 200, 500, 7500, 1500, 2500, 6000]
        grid_kwargs = dict(
            base_cfg=mrc_cfg,
            x_axis_name='n_samples',
            x_values=sample_sizes,
            compare_axis_name='y_type',
            compare_values=['linear', 'monotone'],
            default_algo_type='mrc',
            n_seeds=5,
            default_z_type='neural',
            default_context_type='concat',
        )
        results = run_experiment_grid(**grid_kwargs)
        plot_metric_scaling(
            results,
            x_col='n_samples',
            x_label='N',
            y_label='p0_error_0.3',
            hue_col='y_type',
            title="Exp 4: MRC Robustness across Simulator Types",
        )

    # ------------------------------------------------------
    # Exp 5: Assortment Optimization Regret
    # ------------------------------------------------------
    # Note: This uses the SAME data frame as Exp 2 or Exp 4 conceptually, 
    # but we re-run it to be clean (or we could reuse df_d if we want to save time).
    # Here we run explicitly to ensure 'regret' metric is focused.
    # Exp 5: Assortment Regret
    def exp5():
        """Sweep `n_samples` for the 'linear' and 'mrc' algorithms and plot the 'regret' metric."""
        sample_sizes = [50, 100, 200, 500, 1000, 2000, 4000]
        grid_options = dict(
            base_cfg=base_cfg,  # Use the harder config
            x_axis_name='n_samples',
            x_values=sample_sizes,
            compare_axis_name='algo_type',
            compare_values=['linear', 'mrc'],
            default_y_type='monotone',
            default_z_type='neural',
            default_context_type='concat',
        )
        df_regret = run_experiment_grid(**grid_options)

        plot_metric_scaling(
            df_regret,
            x_col='n_samples',
            x_label='N',
            y_label='regret',
            hue_col='algo_type',
            title="Exp 5: Assortment Regret Analysis",
        )

    # ------------------------------------------------------
    # Exp 6: Multi-Simulator Robustness (Mean vs Median)
    # ------------------------------------------------------
    def exp6():
        """Sweep `n_samples` for the 'multi_sim' task, comparing the multi-simulator aggregation methods."""
        sample_sizes = [50, 100, 200, 500, 1000, 2000, 4000]
        aggregation_methods = ['logit_mean', 'median', 'weighted_mean']
        grid_options = dict(
            base_cfg=base_cfg,
            x_axis_name='n_samples',
            x_values=sample_sizes,
            compare_axis_name='multi_sim_method',
            compare_values=aggregation_methods,
            default_task_type='multi_sim',
            default_z_type='neural',
            default_context_type='concat',
        )
        df_multi = run_experiment_grid(**grid_options)

        plot_metric_scaling(
            df_multi,
            x_col='n_samples',
            x_label='N',
            y_label='p0_error_0.3',
            hue_col='multi_sim_method',
            title="Exp 6: Multi-Simulator Robustness",
        )

    # ------------------------------------------------------
    # Exp 7: Simulator Noise Impact (Testing Flip Probability)
    # ------------------------------------------------------
    def exp7():
        """Sweep `sim_noise_sigma` for the 'linear' and 'mrc' algorithms and plot the p0 error."""
        noise_levels = [0.1, 0.5, 1.0, 2.0, 3.0, 5.0]
        grid_options = dict(
            base_cfg=base_cfg,
            x_axis_name='sim_noise_sigma',
            x_values=noise_levels,
            compare_axis_name='algo_type',
            compare_values=['linear', 'mrc'],
            default_y_type='monotone',
            default_z_type='neural',
            default_context_type='concat',
        )
        df_noise = run_experiment_grid(**grid_options)

        plot_metric_scaling(
            df_noise,
            x_col='sim_noise_sigma',
            x_label=r'$\sigma$',
            y_label='p0_error_0.3',
            hue_col='algo_type',
            title="Exp 7: Robustness to Simulator Noise",
        )

    # ------------------------------------------------------
    # Exp 8: Assortment Size Impact (Market Density)
    # ------------------------------------------------------
    def exp8():
        """Sweep `max_assortment_size` for the 'linear' and 'mrc' algorithms and plot the p0 error."""
        assortment_sizes = [5, 10, 20, 50, 100]
        grid_options = dict(
            base_cfg=base_cfg,
            x_axis_name='max_assortment_size',
            x_values=assortment_sizes,
            compare_axis_name='algo_type',
            compare_values=['linear', 'mrc'],
            default_y_type='monotone',
            default_z_type='neural',
            default_context_type='concat',
        )
        df_size = run_experiment_grid(**grid_options)

        plot_metric_scaling(
            df_size,
            x_col='max_assortment_size',
            x_label=r'$\mathcal{S}_{\max}$',
            y_label='p0_error_0.3',
            hue_col='algo_type',
            title="Exp 8: Impact of Assortment Size",
        )
   

Starting Experiments. ID: 20251208_022937
Device: mps


In [59]:
# Exp 1: sweep the sample size and compare the 'linear' and 'mrc' algorithms.
df_exp1 = run_experiment_grid(
    base_cfg=base_cfg,
    x_axis_name='n_samples',
    x_values=[100, 150, 200, 500, 1000, 1500, 2000, 4000],
    compare_axis_name='algo_type',
    compare_values=['linear', 'mrc'],
    default_y_type='monotone',
    default_z_type='neural',
    n_seeds=5,
    default_context_type='concat',
)


=== Running Grid: X=n_samples | Cross Validating: ['algo_type'] ===


Progress: 100%|██████████| 80/80 [01:07<00:00,  1.18it/s]

Saved full data (with Tensors) to: results/logs/grid_n_samples_vs_algo_type_20251206_205535.pkl





In [None]:
# Drop the N=1000 rows before plotting.
# NOTE(review): the reason for excluding this sweep point is not recorded -- confirm.
df_exp1 = df_exp1.loc[df_exp1['n_samples'] != 1000]

In [62]:
# Exp 1 figure: p0 error against sample size, one curve per algorithm.
plot_metric_scaling(
    df_exp1,
    x_col='n_samples',
    x_label='N',
    y_label='p0_error_0.3',
    hue_col='algo_type',
    title="Exp 1: Convergence vs Sample Size",
)


>>> Plotting p0_error_0.3 vs N...
Plot saved to results/figures/exp_1_convergence_vs_sample_size_20251206_205535.png


In [11]:
# Exp 2: sweep `est_noise_sigma` and compare the 'linear' and 'mrc' algorithms.
df_exp2 = run_experiment_grid(
    base_cfg=base_cfg,
    x_axis_name='est_noise_sigma',
    x_values=[0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0],
    compare_axis_name='algo_type',
    compare_values=['linear', 'mrc'],
    default_y_type='monotone',
    default_z_type='neural',
    default_context_type='concat',
)


=== Running Grid: X=est_noise_sigma | Cross Validating: ['algo_type'] ===


Progress: 100%|██████████| 70/70 [01:46<00:00,  1.53s/it]

Saved full data (with Tensors) to: results/logs/grid_est_noise_sigma_vs_algo_type_20251207_195941.pkl





In [13]:
# Exp 2 figure: p0 error against the utility-noise level, one curve per algorithm.
plot_metric_scaling(
    df_exp2,
    x_col='est_noise_sigma',
    x_label=r'$\bar{\tau}$',
    y_label='p0_error_0.3',
    hue_col='algo_type',
    title="Exp 2: Robustness to Utility Noise (tau)",
)


>>> Plotting p0_error_0.3 vs $\bar{\tau}$...
Plot saved to results/figures/exp_2_robustness_to_utility_noise_(tau)_20251207_195941.png


In [15]:
# Exp 3 (linear algorithm): sweep the sample size on `linear_cfg`, comparing y_type values.
df_exp3 = run_experiment_grid(
    base_cfg=linear_cfg,
    x_axis_name='n_samples',
    x_values=[100, 200, 500, 1000, 1500, 2000, 2500, 4000],
    compare_axis_name='y_type',
    compare_values=['linear', 'monotone'],
    default_algo_type='linear',
    n_seeds=10,
    default_z_type='neural',
    default_context_type='concat',
)



=== Running Grid: X=n_samples | Cross Validating: ['y_type'] ===


Progress: 100%|██████████| 160/160 [00:57<00:00,  2.78it/s]

Saved full data (with Tensors) to: results/logs/grid_n_samples_vs_y_type_20251207_195941.pkl





In [30]:
# Exp 3 figure (linear algorithm): NLL against sample size, one curve per y_type.
plot_metric_scaling(
    df_exp3,
    x_col='n_samples',
    x_label='N',
    y_label='nll',
    hue_col='y_type',
    title="Exp 3: linear Robustness across Simulator Types",
)


>>> Plotting nll vs N...
Plot saved to results/figures/exp_3_linear_robustness_across_simulator_types_20251207_203834_256113.png


In [37]:
# Exp 3 (MRC algorithm): same comparison as above, run with default_algo_type='mrc'.
df_exp3_mrc = run_experiment_grid(
    base_cfg=linear_cfg,
    x_axis_name='n_samples',
    x_values=[100, 200, 1000, 1500, 2000, 3000, 6000],
    compare_axis_name='y_type',
    compare_values=['linear', 'monotone'],
    default_algo_type='mrc',
    n_seeds=5,
    default_z_type='neural',
    default_context_type='concat',
)



=== Running Grid: X=n_samples | Cross Validating: ['y_type'] ===


Progress: 100%|██████████| 70/70 [01:20<00:00,  1.15s/it]

Saved full data (with Tensors) to: results/logs/grid_n_samples_vs_y_type_20251207_201251.pkl





In [38]:
# Exp 3 figure (MRC algorithm): NLL against sample size, one curve per y_type.
plot_metric_scaling(
    df_exp3_mrc,
    x_col='n_samples',
    x_label='N',
    y_label='nll',
    hue_col='y_type',
    title="Exp 3: MRC Robustness across Simulator Types",
)


>>> Plotting nll vs N...
Plot saved to results/figures/exp_3_mrc_robustness_across_simulator_types_20251207_205138_273738.png


In [40]:
# Exp 4: sweep the sample size on `mrc_cfg` with the MRC algorithm, comparing y_type values.
df_exp4 = run_experiment_grid(
    base_cfg=mrc_cfg,
    x_axis_name='n_samples',
    x_values=[100, 200, 500, 1500, 2000, 2500, 4000, 7500],
    compare_axis_name='y_type',
    compare_values=['linear', 'monotone'],
    default_algo_type='mrc',
    n_seeds=5,
    default_z_type='neural',
    default_context_type='concat',
)


=== Running Grid: X=n_samples | Cross Validating: ['y_type'] ===


Progress: 100%|██████████| 80/80 [01:37<00:00,  1.22s/it]

Saved full data (with Tensors) to: results/logs/grid_n_samples_vs_y_type_20251207_201251.pkl





In [43]:
# Exp 4 figure: NLL against sample size, one curve per y_type.
plot_metric_scaling(
    df_exp4,
    x_col='n_samples',
    x_label='N',
    y_label='nll',
    hue_col='y_type',
    title="Exp 4: MRC Robustness across Simulator Types",
)


>>> Plotting nll vs N...
Plot saved to results/figures/exp_4_mrc_robustness_across_simulator_types_20251207_211043_493436.png


In [44]:
# Exp 6 (20 seeds): multi-simulator task, comparing the aggregation methods across sample sizes.
df_exp6 = run_experiment_grid(
    base_cfg=base_cfg,
    x_axis_name='n_samples',
    x_values=[50, 100, 200, 500, 1000, 2000, 4000],
    compare_axis_name='multi_sim_method',
    compare_values=['logit_mean', 'median', 'weighted_mean'],
    default_task_type='multi_sim',
    default_z_type='neural',
    n_seeds=20,
    default_context_type='concat',
)



=== Running Grid: X=n_samples | Cross Validating: ['multi_sim_method'] ===


Progress: 100%|██████████| 420/420 [04:09<00:00,  1.68it/s]


Saved full data (with Tensors) to: results/logs/grid_n_samples_vs_multi_sim_method_20251207_201251.pkl


In [46]:
# Exp 6 figure (20 seeds): NLL against sample size, one curve per aggregation method.
plot_metric_scaling(
    df_exp6,
    x_col='n_samples',
    x_label='N',
    y_label='nll',
    hue_col='multi_sim_method',
    title="Exp 6: Multi-Simulator Robustness",
)


>>> Plotting nll vs N...
Plot saved to results/figures/exp_6_multi-simulator_robustness_20251207_211548_949301.png


In [8]:
# Exp 6 (40 seeds): same grid as the 20-seed run, with more seeds.
df_exp61 = run_experiment_grid(
    base_cfg=base_cfg,
    x_axis_name='n_samples',
    x_values=[50, 100, 200, 500, 1000, 2000, 4000],
    compare_axis_name='multi_sim_method',
    compare_values=['logit_mean', 'median', 'weighted_mean'],
    default_task_type='multi_sim',
    default_z_type='neural',
    n_seeds=40,
    default_context_type='concat',
)



=== Running Grid: X=n_samples | Cross Validating: ['multi_sim_method'] ===


Progress:   3%|▎         | 70/2100 [00:02<01:11, 28.43it/s]



Progress:   6%|▋         | 134/2100 [00:04<00:32, 60.19it/s]



Progress:   8%|▊         | 158/2100 [00:04<00:28, 69.21it/s]



Progress:   8%|▊         | 174/2100 [00:05<00:27, 70.39it/s]



Progress:   9%|▉         | 190/2100 [00:05<00:27, 70.53it/s]



Progress:  10%|▉         | 206/2100 [00:05<00:26, 71.41it/s]



Progress:  18%|█▊        | 370/2100 [00:11<01:59, 14.46it/s]



Progress:  21%|██        | 431/2100 [00:14<00:43, 38.73it/s]



Progress:  22%|██▏       | 456/2100 [00:15<00:39, 41.91it/s]



Progress:  23%|██▎       | 476/2100 [00:15<00:38, 42.27it/s]



Progress:  23%|██▎       | 486/2100 [00:15<00:38, 41.95it/s]



Progress:  24%|██▍       | 506/2100 [00:16<00:38, 41.90it/s]



Progress:  32%|███▏      | 669/2100 [00:29<03:27,  6.90it/s]



Progress:  36%|███▌      | 751/2100 [00:36<00:58, 23.23it/s]



Progress:  36%|███▌      | 757/2100 [00:36<00:57, 23.34it/s]



Progress:  37%|███▋      | 772/2100 [00:37<01:00, 22.02it/s]



Progress:  37%|███▋      | 781/2100 [00:37<01:00, 21.88it/s]



Progress:  47%|████▋     | 992/2100 [01:10<06:01,  3.06it/s]



Progress:  50%|█████     | 1055/2100 [01:38<06:01,  2.89it/s]



Progress:  61%|██████    | 1278/2100 [05:44<33:04,  2.41s/it]



Progress:  65%|██████▍   | 1355/2100 [06:57<05:40,  2.19it/s]



Progress:  66%|██████▌   | 1380/2100 [07:11<06:38,  1.81it/s]



Progress:  79%|███████▉  | 1655/2100 [13:17<06:37,  1.12it/s]



Progress:  80%|████████  | 1680/2100 [13:44<06:19,  1.11it/s]



Progress:  93%|█████████▎| 1955/2100 [22:22<03:16,  1.35s/it]



Progress: 100%|██████████| 2100/2100 [26:00<00:00,  1.35it/s]


Saved full data (with Tensors) to: results/logs/grid_n_samples_vs_multi_sim_method_20251206_221153.pkl


In [48]:
# Exp 6 figure for the `df_exp61` run: NLL against sample size, one curve per aggregation method.
plot_metric_scaling(
    df_exp61,
    x_col='n_samples',
    x_label='N',
    y_label='nll',
    hue_col='multi_sim_method',
    title="Exp 6: Multi-Simulator Robustness",
)


>>> Plotting nll vs N...
Plot saved to results/figures/exp_6_multi-simulator_robustness_20251207_212300_571187.png


In [49]:
# Exp 6 (100 seeds): same grid again with the largest seed count.
df_exp62 = run_experiment_grid(
    base_cfg=base_cfg,
    x_axis_name='n_samples',
    x_values=[50, 100, 200, 500, 1000, 2000, 4000],
    compare_axis_name='multi_sim_method',
    compare_values=['logit_mean', 'median', 'weighted_mean'],
    default_task_type='multi_sim',
    default_z_type='neural',
    n_seeds=100,
    default_context_type='concat',
)



=== Running Grid: X=n_samples | Cross Validating: ['multi_sim_method'] ===


Progress:   3%|▎         | 69/2100 [00:02<01:10, 28.80it/s]



Progress:   6%|▋         | 136/2100 [00:04<00:31, 61.86it/s]



Progress:   7%|▋         | 151/2100 [00:04<00:29, 65.90it/s]



Progress:   8%|▊         | 173/2100 [00:05<00:30, 64.03it/s]



Progress:   9%|▉         | 187/2100 [00:05<00:29, 64.10it/s]



Progress:  10%|█         | 211/2100 [00:05<00:27, 69.85it/s]



Progress:  18%|█▊        | 370/2100 [00:11<01:53, 15.26it/s]



Progress:  20%|██        | 429/2100 [00:14<00:45, 36.45it/s]



Progress:  22%|██▏       | 457/2100 [00:14<00:40, 40.70it/s]



Progress:  22%|██▏       | 472/2100 [00:15<00:40, 40.24it/s]



Progress:  23%|██▎       | 487/2100 [00:15<00:40, 39.99it/s]



Progress:  24%|██▍       | 507/2100 [00:16<00:38, 41.30it/s]



Progress:  32%|███▏      | 669/2100 [00:28<02:53,  8.24it/s]



Progress:  36%|███▌      | 750/2100 [00:35<01:12, 18.58it/s]



Progress:  36%|███▌      | 758/2100 [00:35<01:08, 19.60it/s]



Progress:  37%|███▋      | 770/2100 [00:36<01:07, 19.62it/s]



Progress:  37%|███▋      | 781/2100 [00:36<01:04, 20.45it/s]



Progress:  47%|████▋     | 992/2100 [01:14<07:15,  2.54it/s]



Progress:  50%|█████     | 1055/2100 [01:48<08:29,  2.05it/s]



Progress:  61%|██████    | 1278/2100 [06:52<37:03,  2.71s/it]



Progress:  65%|██████▍   | 1355/2100 [08:51<11:38,  1.07it/s]



Progress:  66%|██████▌   | 1380/2100 [09:14<08:05,  1.48it/s]



Progress:  79%|███████▉  | 1655/2100 [16:30<07:13,  1.03it/s]



Progress:  80%|████████  | 1680/2100 [16:53<06:51,  1.02it/s]



Progress:  93%|█████████▎| 1955/2100 [23:56<02:37,  1.08s/it]



Progress: 100%|██████████| 2100/2100 [26:46<00:00,  1.31it/s]


Saved full data (with Tensors) to: results/logs/grid_n_samples_vs_multi_sim_method_20251207_201251.pkl


In [52]:
# Exp 6 figure for the 100-seed run computed in the previous cell.
# Fixed: this cell previously plotted `df_exp61` again, so the `df_exp62`
# results were never visualised and the figure duplicated the earlier one.
plot_metric_scaling(
    df_exp62,
    x_col='n_samples',
    x_label='N',
    y_label='nll',
    hue_col='multi_sim_method',
    title="Exp 6: Multi-Simulator Robustness",
)


>>> Plotting nll vs N...
Plot saved to results/figures/exp_6_multi-simulator_robustness_20251207_215751_365793.png


In [187]:
# Configuration used by the regret experiment (assortments of 50 to 100 items).
regret_cfg = ExpConfig(
    n_samples=2000,
    dim_z=200,
    est_noise_sigma=0.1,
    sim_noise_sigma=0.1,
    utility_mode='additive',
    noise_distribution='uniform',
    sim_bias_a=1.0,
    sim_bias_b=1,
    min_assortment_size=50,
    max_assortment_size=100,
)

In [192]:
# Exp 5: regret sweep over sample size. With cross_product=False only the two
# listed (algo_type, y_type) pairs are passed as comparison settings.
df_exp5 = run_experiment_grid(
    base_cfg=regret_cfg,  # Use the harder config
    x_axis_name='n_samples',
    x_values=[100, 200, 500, 1000, 4000, 5000],
    compare_axis_name=['algo_type', 'y_type'],
    compare_values=[
        ('linear', 'linear'),
        ('mrc', 'monotone'),
    ],
    n_seeds=5,
    cross_product=False,
    regret_need=True,
    default_y_type='monotone',
    default_z_type='neural',
    default_context_type='concat',
)


=== Running Grid: X=n_samples | Cross Validating: ['algo_type', 'y_type'] ===






[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



[A[A[A[A



Progress: 100%|██████████| 60/60 [01

Saved full data (with Tensors) to: results/logs/grid_n_samples_vs_algo_type_y_type_20251208_022937.pkl





In [191]:
# Exp 5 figure: regret against sample size, one curve per algorithm.
plot_metric_scaling(
    df_exp5,
    x_col='n_samples',
    x_label='N',
    y_label='regret',
    hue_col='algo_type',
    title="Exp 5: Assortment Regret Analysis",
)


>>> Plotting regret vs N...
Plot saved to results/figures/exp_5_assortment_regret_analysis_20251208_023718_825820.png


In [1]:
import pandas as pd
# NOTE(review): absolute, machine-specific path -- this cell only runs on the
# original author's machine. Consider resolving it relative to the project's
# results/logs directory.
path = "/Users/xiongjiangkai/xjk_coding/UnobservedChoice_Calibration/results/logs/grid_sim_bias_b_vs_algo_type_y_type_20251216_152826.pkl"

# Load a previously saved experiment grid. pd.read_pickle unpickles arbitrary
# objects, so only load files produced by trusted runs.
df = pd.read_pickle(path)

In [2]:
# Inspect which columns the saved grid contains.
df.columns

Index(['sim_bias_b', 'seed', 'combo_label', 'algo_type', 'y_type', 'gamma_hat',
       'gamma_true', 'p0_pred', 'p0_true', 'time', 'regret', 'n_samples'],
      dtype='object')

In [3]:
# Number of rows in the loaded grid.
len(df)

560

In [5]:
# Distinct values of the `combo_label` column.
df.combo_label.unique()

array(['linear-linear', 'linear-monotone', 'mrc-linear', 'mrc-monotone'],
      dtype=object)

In [6]:
# Range of the swept `sim_bias_b` values in the loaded grid.
x_min, x_max = df["sim_bias_b"].min(), df['sim_bias_b'].max()

In [7]:
# Display the range computed in the previous cell.
x_min, x_max

(0.1, 5.0)

In [10]:
import sys
import os
import time
import datetime
import itertools
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from tqdm import tqdm
from typing import List, Dict, Any, Optional
from dataclasses import replace
from matplotlib.ticker import ScalarFormatter, NullFormatter

# Ensure src is in python path
sys.path.append(str(Path().resolve().parent))

from src.config import ExpConfig
from src.engine.factory import EngineFactory
from src.algorithms.solver import CalibrationSolver
from src.modules.y_mappers import MonotoneYMapper
from src.utils.metrics import compute_p0_from_logits

# Imports for optimization and regret calculation
from src.utils.optimization import solve_optimal_assortment, calculate_revenue

# ==========================================
# Global Setup
# ==========================================
# Output locations: logs and figures live under ./results (created on import).
RESULTS_DIR = Path("results")
FIG_DIR = RESULTS_DIR / "figures"
LOG_DIR = RESULTS_DIR / "logs"
FIG_DIR.mkdir(exist_ok=True, parents=True)
LOG_DIR.mkdir(exist_ok=True, parents=True)

# Timestamp taken once when this cell runs.
RUN_ID = f"{datetime.datetime.now():%Y%m%d_%H%M%S}"

# ==========================================
# [AESTHETICS] Global Style Configuration
# ==========================================
import matplotlib.ticker as ticker

# 1. Color map: colors are bound to semantic keywords found in a series label.
#    Insertion order matters: the plotting code uses the first key that occurs
#    in the lower-cased label.
COLOR_MAP = dict(
    mrc="tab:red",          # red
    linear="tab:blue",      # blue
    median="tab:green",     # green
    weighted="tab:orange",  # orange
    mean="tab:purple",      # purple
    naive="tab:gray",       # gray
    sim="tab:gray",         # gray
    default="black",
)

# 2. Texture bank: entries are handed out by the sorted position of a series.
# Layout of each entry: (linestyle, marker)
# [User Request]: Index 3 (4th) is 'x', Index 4 (5th) is 'D'
TEXTURE_BANK = [
    ("-", "o"),   # 1. solid + circle (most prominent)
    ("--", "s"),  # 2. dashed + square
    ("-.", "^"),  # 3. dash-dot + triangle
    (":", "x"),   # 4. dotted + cross (swapped)
    ("-", "D"),   # 5. solid + diamond (swapped)
    ("--", "*"),  # 6. dashed + star
    ("-.", "v"),  # 7. dash-dot + inverted triangle
    (":", "P"),   # 8. dotted + plus
]


def set_seed(seed: int):
    """Seed NumPy's global RNG and torch (plus every CUDA device when CUDA is available)."""
    torch.manual_seed(seed)
    np.random.seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

def plot_metric_scaling(
    df: pd.DataFrame,
    x_col: str,
    x_label: str,
    y_label: str,
    hue_col: Optional[str] = None,
    title: str = "Scaling Analysis",
    y_top_margin: Optional[float] = None,
    legend_loc: str = "upper right",
    log_y: Optional[bool] = None,
    log_x: Optional[bool] = None,
):
    """Aggregate per-trial results and save a metric-vs-x line plot as a PNG.

    Rows of ``df`` are grouped by ``x_col`` (and by ``hue_col`` when given), one
    pooled metric value is computed per group, and one line is drawn per hue
    value. The figure is written into ``FIG_DIR`` and closed; nothing is returned.

    The metric is chosen by substring matching on ``y_label``:
      * contains "nll": mean binary cross-entropy between the concatenated
        ``p0_pred`` and ``p0_true`` tensors (predictions clipped to [1e-7, 1 - 1e-7]);
      * contains "p0" and "error" (but not "nll"): 0.70 quantile of
        ``|p0_pred - p0_true|`` over the concatenated tensors;
      * contains "gamma": mean L2 distance between ``gamma_hat`` and ``gamma_true``;
      * contains "regret": mean of the non-null ``regret`` values (NaN if none);
      * anything else: the constant 0.0.

    Args:
        df: One row per trial. Must contain ``x_col``, ``hue_col`` (if given) and
            the columns required by the selected metric.
        x_col: Column providing the x-axis values.
        x_label: Used in the log message and as the x-axis label when ``x_col``
            has no entry in the internal label map (or is ``est_noise_sigma``).
        y_label: Metric selector (see above); also used as the y-axis label when
            it has no entry in the internal label map.
        hue_col: Optional column; one line is drawn per distinct value.
        title: Only used to build the output file name (no title is drawn).
        y_top_margin: If truthy, the upper y-limit is multiplied by this factor
            on a linear axis, or by ``10 ** y_top_margin`` on a log axis.
        legend_loc: Passed to ``ax.legend(loc=...)``.
        log_y: ``True``/``False`` forces the y scale; ``None`` auto-selects a log
            scale for error/nll/regret metrics spanning more than a factor of 20.
        log_x: ``True``/``False`` forces the x scale; ``None`` auto-selects a log
            scale for ``n_samples`` / ``*sigma*`` axes spanning more than a
            factor of 10.
    """
    print(f"\n>>> [Plot] {y_label} vs {x_label}...")

    # --- 1. Data Prep ---
    tensor_cols = ["p0_pred", "p0_true", "gamma_hat", "gamma_true"]

    def move_to_cpu(val):
        # Tensors are detached and moved to the CPU so they can be converted to
        # NumPy later; any other value is passed through untouched.
        if isinstance(val, torch.Tensor):
            return val.detach().cpu()
        return val

    plot_df = df.copy()
    for col in tensor_cols:
        if col in plot_df.columns:
            # Only the first row is inspected to decide whether the column holds tensors.
            if len(plot_df) > 0 and isinstance(plot_df[col].iloc[0], torch.Tensor):
                plot_df[col] = plot_df[col].apply(move_to_cpu)

    # --- 2. Metric Computation ---
    def compute_pooled_metric(sub_df):
        if len(sub_df) == 0:
            return float("nan")

        if "p0" in y_label or "nll" in y_label:
            # Pool every trial of the group into one flat array before scoring.
            all_preds = torch.cat(list(sub_df["p0_pred"])).float().numpy()
            all_trues = torch.cat(list(sub_df["p0_true"])).float().numpy()

            if "nll" in y_label:
                epsilon = 1e-7  # keeps log() finite at 0 and 1
                all_preds = np.clip(all_preds, epsilon, 1 - epsilon)
                nll = -(
                    all_trues * np.log(all_preds)
                    + (1 - all_trues) * np.log(1 - all_preds)
                )
                return np.mean(nll)
            elif "error" in y_label:
                abs_diff = np.abs(all_preds - all_trues)
                return np.quantile(abs_diff, 0.70)

        elif "gamma" in y_label:
            hat_matrix = torch.stack(list(sub_df["gamma_hat"]))
            true_matrix = torch.stack(list(sub_df["gamma_true"]))
            l2_errors = torch.norm(hat_matrix - true_matrix, p=2, dim=1)
            return l2_errors.mean().item()

        elif "regret" in y_label:
            valid_regrets = sub_df["regret"].dropna()
            return valid_regrets.mean() if not valid_regrets.empty else float("nan")
        # Unrecognised metric names fall through to a constant zero.
        return 0.0

    if plot_df.empty:
        print("[Warning] No data to plot.")
        return

    df_agg = (
        plot_df.groupby([x_col, hue_col] if hue_col else [x_col])
        .apply(compute_pooled_metric, include_groups=False)
        .reset_index(name=y_label)
    )

    # --- 3. Style Setup ---
    sns.set_theme(style="ticks", context="paper")
    plt.rcParams.update(
        {
            "font.family": "serif",
            "font.serif": ["Times New Roman"],
            "font.size": 20,
            "axes.labelsize": 22,
            "axes.titlesize": 22,
            "xtick.labelsize": 18,
            "ytick.labelsize": 18,
            "legend.fontsize": 16,
            "lines.linewidth": 2.5,
            "lines.markersize": 9,
            "axes.grid": True,
            "grid.linestyle": "--",
            "grid.alpha": 0.4,
        }
    )

    LABEL_MAP = {
        "n_samples": "n",
        "est_noise_sigma": x_label,
        "sim_noise_sigma": r"$\sigma_{\epsilon}$",
        "dim_z": "d",
        "max_assortment_size": r"$|\mathcal{S}|_{\max}$",
        "gamma_error": "Parameter Estimation Error",
        "p0_error": "Empirical p0 Error",
        "nll": "Negative Log-Likelihood",
        "regret": "Sub-optimality (%)",
    }

    fig, ax = plt.subplots(figsize=(8, 6))

    # --- 4. Plotting Loop ---
    # Sort hues alphabetically to ensure deterministic order
    unique_hues = sorted(df_agg[hue_col].unique()) if hue_col else [None]

    for idx, hue_val in enumerate(unique_hues):
        # Branch on whether a hue column exists, not on the truthiness of the
        # value: a falsy hue value (0, "", False) is still a real series.
        if hue_col:
            subset = df_agg[df_agg[hue_col] == hue_val].sort_values(x_col)
            label_str = str(hue_val)
        else:
            subset = df_agg.sort_values(x_col)
            label_str = "Default"

        # A. Texture by sorted position: whatever the label is, the series in
        #    position k gets texture k (wrapping around the bank).
        ls, marker = TEXTURE_BANK[idx % len(TEXTURE_BANK)]

        # B. Color by semantic name: first COLOR_MAP key contained in the label.
        s_lower = label_str.lower()
        color = next(
            (v for k, v in COLOR_MAP.items() if k in s_lower),
            COLOR_MAP["default"],
        )

        # C. Label cleanup, e.g. "linear-monotone" -> "Lin (Monotone Sim)"
        clean_label = label_str.replace("_", " ").title()
        clean_label = clean_label.replace("Mrc", "MRC").replace("Linear", "Lin")
        if "-" in clean_label:
            parts = clean_label.split("-")
            clean_label = f"{parts[0]} ({parts[1]} Sim)"

        ax.plot(
            subset[x_col],
            subset[y_label],
            label=clean_label,
            color=color,
            linestyle=ls,
            marker=marker,
            alpha=0.9,
        )

    # --- 5. Axes & Ticks Logic ---
    # X-Axis Log Logic. An explicit log_x always wins; auto-detection only runs
    # when log_x is None (mirrors the Y-axis logic below).
    x_min, x_max = df_agg[x_col].min(), df_agg[x_col].max()
    use_log_x = False
    if log_x is not None:
        use_log_x = log_x
    elif x_col == "n_samples" or "sigma" in x_col:
        if x_min > 1e-9 and (x_max / x_min > 10):
            use_log_x = True

    if use_log_x:
        ax.set_xscale("log")
        custom_ticks = {x_min, x_max}
        if x_col == "n_samples" and x_min < 1000 < x_max:
            custom_ticks.add(1000)

        if ("sigma" in x_col or 'sim_bias_b' in x_col) and x_min < 1.0 < x_max:
            custom_ticks.add(1.0)

        # A log axis cannot show ticks at non-positive positions.
        positive_ticks = sorted(t for t in custom_ticks if t > 0)
        if positive_ticks:
            ax.set_xticks(positive_ticks)
        ax.xaxis.set_major_formatter(ScalarFormatter())
        ax.xaxis.set_minor_formatter(NullFormatter())

    # Y-Axis Log Logic
    use_log_y = False
    if log_y is not None:
        use_log_y = log_y
    else:
        y_min, y_max = df_agg[y_label].min(), df_agg[y_label].max()
        if "error" in y_label or "nll" in y_label or "regret" in y_label:
            if y_min > 1e-9 and (y_max / y_min > 20):
                use_log_y = True

    if use_log_y:
        ax.set_yscale("log")
        ax.yaxis.set_major_formatter(ScalarFormatter())
        ax.yaxis.set_minor_formatter(NullFormatter())
        ax.yaxis.get_major_formatter().set_scientific(False)
        ax.yaxis.get_major_formatter().set_useOffset(False)

    # Grid
    ax.grid(True, which="major", linestyle="--", linewidth=1.0, alpha=0.5)
    ax.grid(True, which="minor", linestyle=":", linewidth=0.5, alpha=0.3)

    # --- 6. Layout & Legend ---
    ax.set_xlabel(LABEL_MAP.get(x_col, x_label), fontweight="bold")
    ax.set_ylabel(LABEL_MAP.get(y_label, y_label), fontweight="bold")
    # ax.set_title(title, pad=15, fontweight='bold')

    if y_top_margin:
        # Add headroom above the curves (e.g. to make room for the legend).
        curr_bottom, curr_top = ax.get_ylim()
        if ax.get_yscale() == "log":
            ax.set_ylim(curr_bottom, curr_top * (10**y_top_margin))
        else:
            ax.set_ylim(curr_bottom, curr_top * y_top_margin)

    ax.legend(
        loc=legend_loc, frameon=True, edgecolor="black", framealpha=0.95, fancybox=False
    )

    # Build a file name from the title plus a timestamp taken at save time.
    clean_title = (
        title.lower()
        .replace(" ", "_")
        .replace(":", "")
        .replace("$", "")
        .replace("\\", "")
    )
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    save_path = FIG_DIR / f"{clean_title}_{timestamp}.png"
    fig.tight_layout()
    fig.savefig(save_path, bbox_inches="tight", dpi=300)
    # Close this specific figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    print(f"   Saved to {save_path}")

In [12]:
import pandas as pd
# NOTE(review): absolute, machine-specific path -- this cell only runs on the
# original author's machine. Consider resolving it relative to the project's
# results/logs directory.
path = "/Users/xiongjiangkai/xjk_coding/UnobservedChoice_Calibration/results/logs/grid_sim_bias_b_vs_algo_type_y_type_20251216_195141.pkl"

# Load a previously saved experiment grid. pd.read_pickle unpickles arbitrary
# objects, so only load files produced by trusted runs.
df = pd.read_pickle(path)

In [13]:
# p0 error against the bias magnitude, one curve per (algorithm, simulator) combo.
plot_metric_scaling(
    df,
    x_col="sim_bias_b",
    x_label=r"$b^*$",
    y_label="p0_error",
    hue_col="combo_label",
    title="Robustness to Bias Magnitude",
    y_top_margin=1.4,
    log_y=False,
    log_x=True,  # Log scale X-axis makes it easier to see 0.1 vs 5.0
)


>>> [Plot] p0_error vs $b^*$...
   Saved to results/figures/robustness_to_bias_magnitude_20251216_200142.png
