# Parameter recovery visualization

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('svg')
matplotlib.rcParams['svg.fonttype'] = 'none'

import pandas as pd
import os
from scipy.optimize import curve_fit
import matplotlib.colors as mcolors
import warnings
warnings.filterwarnings("ignore", category=UserWarning, append=True)
from scipy.stats import f_oneway
from statsmodels.stats.multicomp import pairwise_tukeyhsd
from scipy.stats import ttest_rel
from scipy import stats

from itertools import combinations
from scipy.stats import f_oneway
import seaborn as sns

from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score
import numpy as np
import pathlib

In [2]:
# Directory (relative to the working directory) that receives the figures
# written by this notebook; created up front so later saves cannot fail on it.
output_dir = r"28_RL_agent_TDlearn_output_both_param_recovery_visualization"
os.makedirs(output_dir, exist_ok=True)

# --- Healthy cohort: per-model `models_evaluation.csv` tables (the fit_* frames) ---
file_fit_greedy = "healthy/13_RL_agent_TDlearn_output/models_evaluation.csv"
file_fit_softmax = "healthy/13_RL_agent_TDlearn_output_softmax/models_evaluation.csv"
file_fit_rs = "healthy/13_RL_agent_TDlearn_output_risk_sensitive/models_evaluation.csv"
file_fit_wsls = "healthy/13_RL_agent_TDlearn_output_wsls/models_evaluation.csv"
file_fit_dualQ = "healthy/13_RL_agent_TDlearn_output_risk_dualQ/models_evaluation.csv"

# Read the five tables in a fixed order: greedy, softmax, risk-sensitive, WSLS, dual-Q.
fit_greedy, fit_softmax, fit_rs, fit_wsls, fit_dualQ = [
    pd.read_csv(csv_path)
    for csv_path in (
        file_fit_greedy,
        file_fit_softmax,
        file_fit_rs,
        file_fit_wsls,
        file_fit_dualQ,
    )
]

# --- Healthy cohort: tables from the parameter-recovery output folder (the simulated_* frames) ---
file_simulated_greedy = "healthy/27_RL_agent_TDlearn_output_both_param_recovery/models_evaluation_greedy.csv"
file_simulated_softmax = "healthy/27_RL_agent_TDlearn_output_both_param_recovery/models_evaluation_softmax.csv"
file_simulated_rs = "healthy/27_RL_agent_TDlearn_output_both_param_recovery/models_evaluation_risk_sensitive.csv"
file_simulated_wsls = "healthy/27_RL_agent_TDlearn_output_both_param_recovery/models_evaluation_wsls.csv"
file_simulated_dualQ = "healthy/27_RL_agent_TDlearn_output_both_param_recovery/models_evaluation_risk_dualQ.csv"

# Same model order as the fit_* frames above.
simulated_greedy, simulated_softmax, simulated_rs, simulated_wsls, simulated_dualQ = [
    pd.read_csv(csv_path)
    for csv_path in (
        file_simulated_greedy,
        file_simulated_softmax,
        file_simulated_rs,
        file_simulated_wsls,
        file_simulated_dualQ,
    )
]



# Reading the epileptic data

In [3]:
# --- Epileptic cohort: per-model `models_evaluation.csv` tables (the fit_*_ep frames) ---
# NOTE: the file_fit_* names are rebound here, replacing the healthy paths
# assigned by the earlier loading cell.
file_fit_greedy = "epileptic/13_RL_agent_TDlearn_output/models_evaluation.csv"
file_fit_softmax = "epileptic/13_RL_agent_TDlearn_output_softmax/models_evaluation.csv"
file_fit_rs = "epileptic/13_RL_agent_TDlearn_output_risk_sensitive/models_evaluation.csv"
file_fit_wsls = "epileptic/13_RL_agent_TDlearn_output_wsls/models_evaluation.csv"
file_fit_dualQ = "epileptic/13_RL_agent_TDlearn_output_risk_dualQ/models_evaluation.csv"

# Read the five tables in a fixed order: greedy, softmax, risk-sensitive, WSLS, dual-Q.
fit_greedy_ep, fit_softmax_ep, fit_rs_ep, fit_wsls_ep, fit_dualQ_ep = [
    pd.read_csv(csv_path)
    for csv_path in (
        file_fit_greedy,
        file_fit_softmax,
        file_fit_rs,
        file_fit_wsls,
        file_fit_dualQ,
    )
]

# --- Epileptic cohort: tables from the parameter-recovery output folder (the simulated_*_ep frames) ---
# The file_simulated_* names are likewise rebound to the epileptic paths.
file_simulated_greedy = "epileptic/27_RL_agent_TDlearn_output_both_param_recovery/models_evaluation_greedy.csv"
file_simulated_softmax = "epileptic/27_RL_agent_TDlearn_output_both_param_recovery/models_evaluation_softmax.csv"
file_simulated_rs = "epileptic/27_RL_agent_TDlearn_output_both_param_recovery/models_evaluation_risk_sensitive.csv"
file_simulated_wsls = "epileptic/27_RL_agent_TDlearn_output_both_param_recovery/models_evaluation_wsls.csv"
file_simulated_dualQ = "epileptic/27_RL_agent_TDlearn_output_both_param_recovery/models_evaluation_risk_dualQ.csv"

# Same model order as the fit_*_ep frames above.
simulated_greedy_ep, simulated_softmax_ep, simulated_rs_ep, simulated_wsls_ep, simulated_dualQ_ep = [
    pd.read_csv(csv_path)
    for csv_path in (
        file_simulated_greedy,
        file_simulated_softmax,
        file_simulated_rs,
        file_simulated_wsls,
        file_simulated_dualQ,
    )
]

In [4]:
def _pool_cohorts(label, fit_healthy, sim_healthy, fit_epileptic, sim_epileptic):
    """Stack the healthy rows on top of the epileptic rows for one model.

    The plotting cell pairs fitted and simulated rows purely by position, so
    each cohort must contribute the same number of rows to both tables.
    Without this check, mismatches that cancel out across cohorts would give
    pooled tables of equal length whose rows are silently mis-paired.

    Parameters
    ----------
    label : str
        Model name, used only in the error message.
    fit_healthy, sim_healthy : pandas.DataFrame
        Fitted / simulated tables of the first cohort.
    fit_epileptic, sim_epileptic : pandas.DataFrame
        Fitted / simulated tables of the second cohort.

    Returns
    -------
    tuple of pandas.DataFrame
        ``(fit_all, simulated_all)``, each with a fresh 0..n-1 index.

    Raises
    ------
    ValueError
        If a cohort's fitted and simulated tables differ in row count.
    """
    for cohort, fit_df, sim_df in (
        ("healthy", fit_healthy, sim_healthy),
        ("epileptic", fit_epileptic, sim_epileptic),
    ):
        if len(fit_df) != len(sim_df):
            raise ValueError(
                f"{label} ({cohort}): fitted table has {len(fit_df)} rows but "
                f"simulated table has {len(sim_df)}; rows cannot be paired by position."
            )
    fit_all = pd.concat([fit_healthy, fit_epileptic], ignore_index=True)
    sim_all = pd.concat([sim_healthy, sim_epileptic], ignore_index=True)
    return fit_all, sim_all


# NOTE: this cell rebinds the healthy frames to the pooled frames, so it is not
# idempotent. Re-running it without re-running the loading cells appends the
# epileptic rows a second time.
fit_greedy, simulated_greedy = _pool_cohorts(
    "greedy", fit_greedy, simulated_greedy, fit_greedy_ep, simulated_greedy_ep)
fit_softmax, simulated_softmax = _pool_cohorts(
    "softmax", fit_softmax, simulated_softmax, fit_softmax_ep, simulated_softmax_ep)
fit_rs, simulated_rs = _pool_cohorts(
    "risk_sensitive", fit_rs, simulated_rs, fit_rs_ep, simulated_rs_ep)
fit_wsls, simulated_wsls = _pool_cohorts(
    "wsls", fit_wsls, simulated_wsls, fit_wsls_ep, simulated_wsls_ep)
fit_dualQ, simulated_dualQ = _pool_cohorts(
    "risk_dualQ", fit_dualQ, simulated_dualQ, fit_dualQ_ep, simulated_dualQ_ep)


In [5]:
def recovery_stats(x, y):
    """Summarise the agreement between two paired samples.

    Parameters
    ----------
    x, y : array-like
        Paired numeric values of equal length (lists, tuples, NumPy arrays or
        pandas Series). Values are paired by position.

    Returns
    -------
    tuple
        ``(r, slope, intercept, rmse)``: the Pearson correlation of ``x`` and
        ``y``, the slope and intercept of the least-squares regression of
        ``y`` on ``x``, and the root-mean-square of ``x - y``.
    """
    # Coerce once so all three statistics pair values by position. A raw
    # `x - y` raises TypeError for plain lists and aligns pandas Series by
    # index label, which can disagree with pearsonr/linregress.
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)

    r, _ = stats.pearsonr(x, y)
    slope, intercept, *_ = stats.linregress(x, y)
    rmse = np.sqrt(np.mean((x - y) ** 2))
    return r, slope, intercept, rmse

# ────────── plot──────────


def plot_with_recovery(ax, x, y, title):
    """Draw one scatter panel on ``ax`` and annotate it with recovery statistics.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes; its limits, labels, spines, tick formatters and title
        are all set here.
    x, y : array-like
        Paired values; ``x`` goes on the axis labelled "Fitted" and ``y`` on
        the axis labelled "Simulated".
    title : str
        Panel title. A title containing "β" selects the 0–8.2 axis range;
        any other title selects 0–1.02.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    r, m, b, rmse = recovery_stats(x, y)

    # Both axes share one fixed range, picked from the title text.
    upper = 8.2 if "β" in title else 1.02
    ax.set_xlim(0, upper)
    ax.set_ylim(0, upper)

    # Dashed identity line across the combined axis range, drawn under the points.
    x_low, x_high = ax.get_xlim()
    y_low, y_high = ax.get_ylim()
    diagonal = [min(x_low, y_low), max(x_high, y_high)]
    ax.plot(diagonal, diagonal, ls="--", lw=1, c="k", zorder=1)
    ax.scatter(x, y, alpha=0.7, zorder=2)

    ax.set_xlabel("Fitted")
    ax.set_ylabel("Simulated")
    for side in ("right", "top"):
        ax.spines[side].set_visible(False)

    # Each axis gets its own formatter instance showing two decimals.
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_major_formatter(plt.FuncFormatter(lambda v, _: f"{v:.2f}"))

    ax.set_title(f"{title}\n r={r:.2f}  m={m:.2f}  b={b:.2f}  RMSE={rmse:.3g}", fontsize=9)



# ────────── Figure scaffold ──────────
# 5 model rows x 3 parameter columns; two-parameter models leave column 2 empty.
fig, axes = plt.subplots(5, 3, figsize=(10, 18))
fig.suptitle("Fitted vs Simulated Parameters", fontsize=20)

# Panel specifications: (column in fitted table, column in simulated table, panel title).
params_rs = [
    ("best_alpha_plus", "best_alpha_plus", "Risk-Sens α⁺"),
    ("best_alpha_minus", "best_alpha_minus", "Risk-Sens α⁻"),
    ("best_beta", "best_beta", "Risk-Sens β"),
]
params_dualq = [
    ("best_alpha_r", "best_alpha_r", "Dual-Q α_reward"),
    ("best_alpha_s", "best_alpha_s", "Dual-Q α_risk"),
    ("best_beta", "best_beta", "Dual-Q β"),
]
params_softmax = [
    ("best_alpha", "best_alpha", "Softmax α"),
    ("best_beta", "best_beta", "Softmax β"),
]
# The ε and WSLS probability panels read the `best_alpha` / `best_beta` columns
# of their own tables. Presumably those CSVs reuse the generic column names;
# verify against the fitting notebooks.
params_greedy = [
    ("best_alpha", "best_alpha", "Greedy α"),
    ("best_beta", "best_beta", "Greedy ε"),
]
params_wsls = [
    ("best_alpha", "best_alpha", "WSLS P(stay|win)"),
    ("best_beta", "best_beta", "WSLS P(shift|lose)"),
]

# One grid row per model, top to bottom:
# risk-sensitive, dual-Q, softmax, greedy, WSLS.
panel_rows = [
    (fit_rs, simulated_rs, params_rs),
    (fit_dualQ, simulated_dualQ, params_dualq),
    (fit_softmax, simulated_softmax, params_softmax),
    (fit_greedy, simulated_greedy, params_greedy),
    (fit_wsls, simulated_wsls, params_wsls),
]
for row_idx, (fit_df, sim_df, panel_specs) in enumerate(panel_rows):
    for col_idx, (fit_col, sim_col, panel_title) in enumerate(panel_specs):
        plot_with_recovery(axes[row_idx, col_idx], fit_df[fit_col], sim_df[sim_col], panel_title)

# Blank out the unused third column of the two-parameter rows.
for row_idx in (2, 3, 4):
    axes[row_idx, 2].axis("off")
plt.tight_layout(rect=[0,0.03,1,0.95])

# Write the figure as PDF and SVG into the (now absolute) output directory.
output_dir = pathlib.Path(output_dir).resolve()
pdf_path = output_dir / "fitted_vs_simulated_all_params_healthy_epileptic.pdf"
svg_path = output_dir / "fitted_vs_simulated_all_params_healthy_epileptic.svg"

fig.savefig(pdf_path)
fig.savefig(svg_path, format="svg", dpi=1200, bbox_inches="tight", transparent=True)

plt.show()
