In [None]:
import os
import pickle
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from matplotlib.ticker import MaxNLocator
import sys
import torch

# Add module to path
module_path = Path.cwd().parents[1] / "module"  
sys.path.append(str(module_path))

from pool_utils import PINNDataset
from visualization import compute_relative_errors_from_dict

# Layout preferences: a uniform 7 pt font for every text element of the figures.
plt.rcParams.update(
    {
        rc_key: 7
        for rc_key in (
            "font.size",
            "axes.labelsize",
            "xtick.labelsize",
            "ytick.labelsize",
            "legend.fontsize",
        )
    }
)


# Repository root: two directory levels above the notebook's working directory.
repo_root = Path.cwd().parents[1]

# Experiment selection
mode = "convergence"  # selects the examples/<mode>_mode sub-folder
noise_level = 0       # noise level (percent) encoded in the results folder name

# Results folder of this experiment
folder = f"7_sensors_{mode}_mode_Noise_{noise_level}%"

# Input directory holding the per-seed results, and output directory for the figures
# (kept as plain strings, exactly as produced by os.path.join).
load_dir = str(repo_root / "examples" / f"{mode}_mode" / "Results" / folder)
save_dir = str(Path(load_dir) / "posterior_analysis")
os.makedirs(save_dir, exist_ok=True)

# Constitutive parameters whose relative-error series are stored in each seed's
# results file under the key "<name>_err".
_PARAM_NAMES = ("Eh", "Ev", "Gvh", "K", "beta")


def _stack_series(rows):
    """Stack per-seed error series into one array with one row per seed.

    Equal-length series are stacked exactly as ``np.array(rows)`` would. When
    seeds produced series of different lengths, ``np.array`` cannot build a
    regular numeric array, so each (assumed 1-D) series is right-padded with
    NaN instead; the nan-aware statistics then average every step over the
    seeds that actually reached it.
    """
    lengths = {len(row) for row in rows}
    if len(lengths) <= 1:
        return np.array(rows)
    print(f"Warning: error series have different lengths {sorted(lengths)}, "
          "padding the shorter ones with NaN")
    stacked = np.full((len(rows), max(lengths)), np.nan)
    for i, row in enumerate(rows):
        stacked[i, :len(row)] = np.asarray(row, dtype=float)
    return stacked


all_err_dict = {f"all_{name}_err": [] for name in _PARAM_NAMES}

# Seed folders ("s_*"). Sorted because os.listdir order is arbitrary and the row
# order of the stacked arrays should be reproducible between runs.
seed_dirs = sorted(d for d in os.listdir(load_dir) if d.startswith("s_"))

for seed_dir in seed_dirs:
    results_path = os.path.join(load_dir, seed_dir, "all_results.pkl")
    if not os.path.exists(results_path):
        print(f"Warning: {results_path} does not exist, skipping")
        continue
    # NOTE: pickle.load can execute arbitrary code -- only load result files
    # produced by our own runs.
    with open(results_path, "rb") as f:
        results_dict = pickle.load(f)
    for name in _PARAM_NAMES:
        # The first entry of every series is dropped. The plotting function starts
        # its x axis at 2, so index 0 presumably corresponds to a single sensor /
        # initial state -- TODO confirm.
        all_err_dict[f"all_{name}_err"].append(results_dict[f"{name}_err"][1:])

if not all_err_dict["all_Eh_err"]:
    raise FileNotFoundError(
        f"No seed results (s_*/all_results.pkl) found in {load_dir}")

# Transform to arrays: one row per loaded seed
for name in _PARAM_NAMES:
    key = f"all_{name}_err"
    all_err_dict[key] = _stack_series(all_err_dict[key])

# Per-step mean and standard deviation across seeds (axis 0), ignoring NaN entries
for stat, reducer in (("mean", np.nanmean), ("std", np.nanstd)):
    for name in _PARAM_NAMES:
        all_err_dict[f"{stat}_{name}_err"] = reducer(
            all_err_dict[f"all_{name}_err"], axis=0)


def _style_current_axes(zoom_xlim, zoom_ylim):
    """Apply the grid/legend/limits/tick formatting shared by all figures."""
    plt.grid(True)
    plt.legend()
    if zoom_ylim is not None:
        plt.ylim(*zoom_ylim)
    if zoom_xlim is not None:
        plt.xlim(*zoom_xlim)
    # The x values are consecutive integers, so only integer ticks are shown.
    plt.gca().xaxis.set_major_locator(MaxNLocator(integer=True))
    plt.tight_layout()


def _save_and_release_figure(save_dir, filename, show_figs):
    """Save the current figure to save_dir/filename (when save_dir is set), then show or close it."""
    if save_dir is not None:
        plt.savefig(os.path.join(save_dir, filename), dpi=300, bbox_inches='tight')
    if show_figs:
        plt.show()
    else:
        plt.close()


def plot_relative_errors_all_article(Eh_err,
                                     Ev_err,
                                     Gvh_err,
                                     K_err,
                                     beta_err,
                                     Eh_std=None,
                                     Ev_std=None,
                                     Gvh_std=None,
                                     K_std=None,
                                     beta_std=None,
                                     save_dir=None,
                                     zoom_xlim=None,
                                     zoom_ylim=(0, 30),
                                     show_figs=False,
                                     combined_filename=None,
                                     x_start=2):
    """Plot relative errors of the constitutive parameters.

    Draws one figure per parameter (with an optional err +/- std band) and one
    combined figure with all parameters. Values are plotted against
    x = x_start, x_start + 1, ...

    Args:
        Eh_err, Ev_err, Gvh_err, K_err, beta_err: relative-error series (in %).
            A series passed as None is skipped. All non-None series are assumed
            to have the same length.
        Eh_std, Ev_std, Gvh_std, K_std, beta_std: optional spread of the matching
            error series, drawn as a shaded band on the individual figure.
        save_dir: directory for the PDF files (created if missing); None disables saving.
        zoom_xlim, zoom_ylim: (min, max) axis limits; None keeps matplotlib's autoscaling.
        show_figs: show each figure (True) or close it without showing (False).
        combined_filename: file name of the combined figure. Defaults to
            'constitutive_{mode}_mode_noise_{noise_level}.pdf', built from the
            module-level ``mode`` and ``noise_level``.
        x_start: x value of the first entry of every series.
    """
    # save_dir=None means "do not save", so only create the directory when saving.
    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)

    # label -> (error series, optional std series, marker, individual-figure file name)
    err_dict = {
        r'$E_h$': (Eh_err, Eh_std, "s", 'Eh_err.pdf'),
        r'$E_v$': (Ev_err, Ev_std, "*", 'Ev_err.pdf'),
        r'$G_{vh}$': (Gvh_err, Gvh_std, "o", 'Gvh_err.pdf'),
        r'$K$': (K_err, K_std, ">", 'K_err.pdf'),
        r'$\beta$': (beta_err, beta_std, "^", 'beta_err.pdf'),
    }

    # Take the x range from the first series actually provided, so any of them may be None.
    series_lengths = [len(err) for err, _, _, _ in err_dict.values() if err is not None]
    if not series_lengths:
        return  # nothing to plot
    x_vals = np.arange(x_start, x_start + series_lengths[0])

    # Individual figures, one per parameter (8 cm x 5 cm)
    for label, (err, std, marker, filename) in err_dict.items():
        if err is None:
            continue
        err = np.asarray(err)
        plt.figure(figsize=(8 / 2.54, 5 / 2.54))
        plt.plot(x_vals, err, color='tab:red', marker=marker, label=f'Relative error {label}')
        if std is not None:
            std = np.asarray(std)
            plt.fill_between(x_vals, err - std, err + std, color='tab:red', alpha=0.2)
        # NOTE(review): the combined figure labels the same x values 'Number of sensors';
        # confirm which label is intended here.
        plt.xlabel('Step')
        plt.ylabel(f'{label} (%)')
        _style_current_axes(zoom_xlim, zoom_ylim)
        _save_and_release_figure(save_dir, filename, show_figs)

    # Combined figure with every provided parameter (8 cm x 6 cm)
    plt.figure(figsize=(8 / 2.54, 6 / 2.54))
    for label, (err, _, marker, _) in err_dict.items():
        if err is not None:
            plt.plot(x_vals, err, marker=marker, label=label, linewidth=1, markersize=5)
    plt.xlabel('Number of sensors')
    plt.ylabel('Relative error (%)')
    _style_current_axes(zoom_xlim, zoom_ylim)
    if save_dir is not None and combined_filename is None:
        # Default name built from the notebook-level experiment settings.
        combined_filename = f'constitutive_{mode}_mode_noise_{noise_level}.pdf'
    _save_and_release_figure(save_dir, combined_filename, show_figs)

# Plot the seed-averaged relative errors, with the spread across seeds as a shaded band.
plot_relative_errors_all_article(
    Eh_err=all_err_dict["mean_Eh_err"],
    Ev_err=all_err_dict["mean_Ev_err"],
    Gvh_err=all_err_dict["mean_Gvh_err"],
    K_err=all_err_dict["mean_K_err"],
    beta_err=all_err_dict["mean_beta_err"],
    Eh_std=all_err_dict["std_Eh_err"],
    Ev_std=all_err_dict["std_Ev_err"],
    Gvh_std=all_err_dict["std_Gvh_err"],
    K_std=all_err_dict["std_K_err"],
    beta_std=all_err_dict["std_beta_err"],
    save_dir=save_dir,
    zoom_xlim=None,
    zoom_ylim=(0, 20),
    show_figs=True,
)
            