# Parameter recovery visualization

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('svg')
matplotlib.rcParams['svg.fonttype'] = 'none'

import pandas as pd
import os
from scipy.optimize import curve_fit
import matplotlib.colors as mcolors
import warnings
warnings.filterwarnings("ignore", category=UserWarning, append=True)
from scipy.stats import f_oneway
from statsmodels.stats.multicomp import pairwise_tukeyhsd
from scipy.stats import ttest_rel
from itertools import combinations
from scipy.stats import f_oneway
import seaborn as sns

from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score
import numpy as np

In [2]:
# Folder that receives the figures written by this notebook.
output_dir = r"28_RL_agent_TDlearn_output_both_param_recovery_visualization"
os.makedirs(output_dir, exist_ok=True)

# ── "Fit" tables: one models_evaluation.csv per model ────────────────────────
file_fit_greedy = "13_RL_agent_TDlearn_output/models_evaluation.csv"
file_fit_softmax = "13_RL_agent_TDlearn_output_softmax/models_evaluation.csv"
file_fit_rs = "13_RL_agent_TDlearn_output_risk_sensitive/models_evaluation.csv"
file_fit_wsls = "13_RL_agent_TDlearn_output_wsls/models_evaluation.csv"
file_fit_dualQ = "13_RL_agent_TDlearn_output_risk_dualQ/models_evaluation.csv"

# ── "Simulated" tables: all written to the parameter-recovery folder ─────────
file_simulated_greedy = "27_RL_agent_TDlearn_output_both_param_recovery/models_evaluation_greedy.csv"
file_simulated_softmax = "27_RL_agent_TDlearn_output_both_param_recovery/models_evaluation_softmax.csv"
file_simulated_rs = "27_RL_agent_TDlearn_output_both_param_recovery/models_evaluation_risk_sensitive.csv"
file_simulated_wsls = "27_RL_agent_TDlearn_output_both_param_recovery/models_evaluation_wsls.csv"
file_simulated_dualQ = "27_RL_agent_TDlearn_output_both_param_recovery/models_evaluation_risk_dualQ.csv"

# Read the tables in the same order as the paths are listed above.
fit_greedy, fit_softmax, fit_rs, fit_wsls, fit_dualQ = [
    pd.read_csv(csv_path)
    for csv_path in (
        file_fit_greedy,
        file_fit_softmax,
        file_fit_rs,
        file_fit_wsls,
        file_fit_dualQ,
    )
]

simulated_greedy, simulated_softmax, simulated_rs, simulated_wsls, simulated_dualQ = [
    pd.read_csv(csv_path)
    for csv_path in (
        file_simulated_greedy,
        file_simulated_softmax,
        file_simulated_rs,
        file_simulated_wsls,
        file_simulated_dualQ,
    )
]



In [4]:
# 5 × 3 grid of panels: rows are indexed per model, columns per parameter.
fig, axes = plt.subplots(nrows=5, ncols=3, figsize=(10, 18))
fig.suptitle("Fitted vs Simulated Parameters", fontsize=20)

def _within_iqr_fence(values, threshold):
    """Return a boolean mask of entries inside [Q1 - k·IQR, Q3 + k·IQR], k = threshold."""
    q1, q3 = np.percentile(values, [25, 75])
    spread = q3 - q1
    return (values >= q1 - threshold * spread) & (values <= q3 + threshold * spread)


def plot_with_regression(ax, x, y, title, iqr_threshold=2):
    """Scatter ``y`` against ``x`` on ``ax`` and overlay a linear regression line.

    Pairs holding a non-finite value (NaN/inf) are dropped first. Then every
    pair whose x *or* y value lies outside that variable's IQR fence is
    removed. A least-squares line is fitted to the remaining points, drawn in
    red, and its R² is appended to the panel title.

    Parameters
    ----------
    ax : matplotlib Axes
        Panel to draw on.
    x : array-like
        Values for the horizontal axis (labelled "Fitted").
    y : array-like
        Values for the vertical axis (labelled "Simulated"); must have the
        same shape as ``x``.
    title : str
        First line of the panel title; the R² line is appended below it.
    iqr_threshold : float, default 2
        Multiplier ``k`` of the outlier fence ``[Q1 - k·IQR, Q3 + k·IQR]``.

    Raises
    ------
    ValueError
        If ``x`` and ``y`` do not have the same shape.

    Notes
    -----
    When fewer than two points survive the filtering, no regression line is
    drawn and the title reads ``R² = n/a`` because R² is undefined there.
    """
    # Imported locally so the function only relies on `np` from module scope.
    from matplotlib.ticker import FuncFormatter
    from sklearn.linear_model import LinearRegression
    from sklearn.metrics import r2_score

    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if x.shape != y.shape:
        raise ValueError(
            f"x and y must have the same shape, got {x.shape} and {y.shape}"
        )

    # A single NaN would turn the percentiles into NaN and mask out every
    # point, so non-finite pairs have to go before the fences are computed.
    finite = np.isfinite(x) & np.isfinite(y)
    x, y = x[finite], y[finite]

    # ── Remove outliers in either variable ─────────────────────────────────────
    if x.size:
        keep = _within_iqr_fence(x, iqr_threshold) & _within_iqr_fence(y, iqr_threshold)
        x, y = x[keep], y[keep]

    # ── Scatter & regression ───────────────────────────────────────────────────
    ax.scatter(x, y, alpha=0.7)

    if x.size >= 2:
        design = x.reshape(-1, 1)
        model = LinearRegression().fit(design, y)
        r2_text = f"{r2_score(y, model.predict(design)):.2f}"

        x_line = np.linspace(x.min(), x.max(), 100).reshape(-1, 1)
        ax.plot(x_line, model.predict(x_line), color='red', linewidth=2)
    else:
        r2_text = "n/a"

    ax.set_title(f"{title}\nR² = {r2_text}")
    ax.set_xlabel("Fitted")
    ax.set_ylabel("Simulated")

    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

    # ── Format y-axis to show 2 decimal digits ────────────────────────────────
    ax.yaxis.set_major_formatter(FuncFormatter(lambda value, _pos: f'{value:.2f}'))


# ── Panel definitions ─────────────────────────────────────────────────────────
# Each list holds one (fitted column, simulated column, panel title) triple per
# panel of a grid row.
#
# Rows 0 (Risk Sensitive) and 1 (Dual Q) are currently disabled. To restore
# them, uncomment the two lists below and their entries in `panel_rows`.
# params_rs = [
#     ("best_alpha_plus",  "best_alpha_plus",  "Risk Sensitive: best_α⁺"),
#     ("best_alpha_minus", "best_alpha_minus", "Risk Sensitive: best_α⁻"),
#     ("best_beta",        "best_beta",        "Risk Sensitive: best_Reverse Temp"),
# ]
# params_dualq = [
#     ("best_alpha_r", "best_alpha_r", "Dual Q: best_α Reward"),
#     ("best_alpha_s", "best_alpha_s", "Dual Q: best_α Risk"),
#     ("best_beta",    "best_beta",    "Dual Q: best_Reverse Temp"),
# ]

# Row 2: Softmax
params_softmax = [
    ("best_alpha", "best_alpha", "Softmax: best_Learning Rate"),
    ("best_beta",  "best_beta",  "Softmax: best_Reverse Temp"),
]

# Row 3: Greedy
params_greedy = [
    ("best_alpha", "best_alpha", "Greedy: best_Learning Rate"),
    ("best_beta",  "best_beta",  "Greedy: best_Epsilon"),
]

# Row 4: WSLS
params_wsls = [
    ("best_alpha", "best_alpha", "WSLS: best_P(Stay|Win)"),
    ("best_beta",  "best_beta",  "WSLS: best_P(Shift|Lose)"),
]

# (grid row, fitted table, simulated table, panel definitions)
panel_rows = [
    # (0, fit_rs, simulated_rs, params_rs),
    # (1, fit_dualQ, simulated_dualQ, params_dualq),
    (2, fit_softmax, simulated_softmax, params_softmax),
    (3, fit_greedy, simulated_greedy, params_greedy),
    (4, fit_wsls, simulated_wsls, params_wsls),
]

for row, fitted, simulated, params in panel_rows:
    for col, (fit_col, sim_col, title) in enumerate(params):
        ax = axes[row, col]
        plot_with_regression(ax, fitted[fit_col], simulated[sim_col], title)

# Hide every panel nothing was drawn on. Checking the axes themselves (rather
# than a hard-coded list of grid positions) also covers the disabled rows and
# keeps working when rows are re-enabled.
for ax in axes.flat:
    if not ax.has_data():
        ax.axis("off")

# Keep the subplots inside this figure-fraction rectangle; the space above 0.95
# is presumably left for the suptitle.
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.savefig(os.path.join(output_dir, "fitted_vs_simulated_all_params.pdf"))

filename = os.path.join(output_dir, "fitted_vs_simulated_all_params.svg")
plt.savefig(filename, format='svg', dpi=1200, bbox_inches='tight', transparent=True)

plt.show()
