In [None]:
import os
import pickle
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import matplotlib.ticker as mticker
from matplotlib.ticker import MaxNLocator
import sys
import torch

# Add module to path
module_path = Path.cwd().parents[1] / "module"  
sys.path.append(str(module_path))

from pool_utils import PINNDataset

def _entry_to_array(val):
    """Convert one per-step history entry into a flat float array.

    Parameters
    ----------
    val : scalar, sequence, ndarray or tensor-like
        One element of a parameter history list.

    Returns
    -------
    values : np.ndarray
        1-D float array (a single element for scalar / 0-d entries).
    is_sequence : bool
        True when ``val`` has at least one dimension (list, 1-D+ array, ...),
        False for plain scalars and 0-d arrays/tensors.
    """
    # Tensor-like objects (anything exposing ``detach``, e.g. torch tensors)
    # are detached and copied to host memory so NumPy can read them.
    if hasattr(val, "detach"):
        val = val.detach().cpu().numpy()
    arr = np.asarray(val, dtype=float)
    return arr.ravel(), arr.ndim > 0


def _step_boundaries(series_map, n_steps):
    """Return cumulative epoch boundaries ``[0, b1, ..., b_n_steps]``.

    The length of a step is the number of values of the first parameter whose
    entry for that step is a sequence; if every entry is scalar (or missing),
    the step counts as a single epoch.
    """
    boundaries = [0]
    for step_idx in range(n_steps):
        step_len = 1
        for series in series_map.values():
            if step_idx < len(series):
                values, is_sequence = _entry_to_array(series[step_idx])
                if is_sequence:
                    step_len = values.size
                    break
        boundaries.append(boundaries[-1] + step_len)
    return boundaries


def _param_series(series, n_steps, divisor):
    """Lay a parameter history end-to-end on a cumulative epoch axis.

    Each of the ``n_steps`` steps contributes its flattened values (divided by
    ``divisor``); a step missing from ``series`` contributes a single NaN so
    the curve shows a gap. Returns ``(x, y)`` arrays of equal length (empty
    when ``n_steps`` is 0).
    """
    x_parts, y_parts = [], []
    cum = 0
    for step_idx in range(n_steps):
        if step_idx < len(series):
            ys, _ = _entry_to_array(series[step_idx])
            ys = ys / divisor
        else:
            ys = np.array([np.nan])
        # x and y are sized from the same flattened array, so they always match.
        x_parts.append(np.arange(cum, cum + ys.size))
        y_parts.append(ys)
        cum += ys.size
    if not x_parts:
        return np.empty(0), np.empty(0)
    return np.concatenate(x_parts), np.concatenate(y_parts)


def plot_params_evolution_separate(results_dict,
                                   params_names=('Eh', 'Ev', 'Gvh', 'K', 'beta'),
                                   true_vals=None,
                                   xlabel='Epochs (cum.)',
                                   ylabel='Param value',
                                   fontsize=7,
                                   figsize=(10, 5),
                                   save_dir=None,
                                   show=True,
                                   subsample=500,
                                   max_step_labels=8):
    """Plot the optimisation history of each parameter, one figure per parameter.

    Histories are read from ``results_dict['<name>_opt']`` for the names
    Eh, Ev, Gvh, K and beta. Each history is a list with one entry per
    optimisation step; an entry is either a scalar or a sequence of values.
    Entries are laid end-to-end on a cumulative epoch axis, and dashed
    vertical lines mark the step boundaries (see ``_step_boundaries``).

    Parameters
    ----------
    results_dict : dict
        Must contain ``'scale'``; Eh, Ev and Gvh histories are divided by it
        before plotting. Missing ``'<name>_opt'`` keys are treated as empty.
    params_names : iterable of str
        Parameters to plot, one figure each.
    true_vals : dict, optional
        Reference value per parameter name, drawn as a dashed horizontal
        line. These values are plotted as given (NOT divided by ``scale``).
    xlabel, ylabel : str
        Axis labels.
    fontsize : int
        Font size for labels, ticks and legend. Also written into the global
        ``plt.rcParams`` (side effect persisting after the call).
    figsize : tuple of float
        Figure size in inches.
    save_dir : str or path-like, optional
        If given, each figure is saved as ``<save_dir>/<name>_evolution.pdf``
        (the directory is created if needed).
    show : bool
        If True call ``plt.show()``; otherwise the figure is closed.
    subsample : int
        Plot only every ``subsample``-th point (values < 1 are treated as 1).
    max_step_labels : int
        Annotate at most this many steps with a "Step k" label.
    """
    plt.rcParams.update({
        "font.size": fontsize,
        "axes.labelsize": fontsize,
        "xtick.labelsize": fontsize,
        "ytick.labelsize": fontsize,
        "legend.fontsize": fontsize
    })

    series_map = {name: results_dict.get(f"{name}_opt", [])
                  for name in ('Eh', 'Ev', 'Gvh', 'K', 'beta')}
    scale = results_dict["scale"]
    scaled_params = ('Eh', 'Ev', 'Gvh')

    n_steps = max(len(series) for series in series_map.values())

    colors = {
        'Eh': 'blue',
        'Ev': 'orange',
        'Gvh': 'green',
        'K': 'red',
        'beta': 'purple',
    }
    labels = {
        'Eh': r'$E_h$',
        'Ev': r'$E_v$',
        'Gvh': r'$G_{vh}$',
        'K': r'$K$',
        'beta': r'$\beta$',
    }
    legend_locs = {
        'K': "lower right",
        'beta': "lower right",
        'Eh': "upper center",
        'Gvh': "upper center",
    }

    step_boundaries = _step_boundaries(series_map, n_steps)
    total_epochs = step_boundaries[-1]
    stride = max(1, int(subsample))

    def plot_one_param(pname):
        """Draw, optionally save, and show/close the figure for ``pname``."""
        fig, ax = plt.subplots(1, 1, figsize=figsize)
        color = colors.get(pname, None)
        # Defined unconditionally: the true-value label below needs it even
        # when there is no history to plot.
        label = labels.get(pname, None)

        divisor = scale if pname in scaled_params else 1.0
        x, y = _param_series(series_map.get(pname, []), n_steps, divisor)
        if x.size:
            # Subsample to keep long histories readable.
            x_sub, y_sub = x[::stride], y[::stride]
            ax.plot(x_sub, y_sub, label=label, color=color,
                    linewidth=0.8, alpha=0.6)
            ax.scatter(x_sub, y_sub, s=1, color=color)

        for boundary in step_boundaries:
            ax.axvline(boundary, color='k', linestyle='--', linewidth=0.6, alpha=0.8)

        # Blended transform: x in data coordinates, y in axes coordinates, so
        # the labels sit just above the axes regardless of the data's sign or
        # of later autoscaling (e.g. by the true-value line below).
        text_transform = ax.get_xaxis_transform()
        n_labels = min(max_step_labels, n_steps)
        for idx, x_pos in enumerate(step_boundaries[:n_labels]):
            ax.text(x_pos, 1.02, f"Step {idx + 1}", transform=text_transform,
                    ha='left', va='bottom', fontsize=fontsize - 1, rotation=45)

        if true_vals is not None and pname in true_vals:
            ax.hlines(true_vals[pname], xmin=0, xmax=max(total_epochs - 1, 0),
                      colors=colors.get(pname, 'k'),
                      linestyles='--', linewidth=1, alpha=0.8,
                      label=f"{label} (true)")

        # Scientific notation on both axes.
        ax.xaxis.set_major_formatter(mticker.ScalarFormatter(useMathText=True))
        ax.tick_params(axis='both', which='major', labelsize=fontsize)
        ax.ticklabel_format(style='sci', axis='x', scilimits=(0, 0))
        ax.ticklabel_format(axis='y', style='sci', scilimits=(0, 0))
        ax.yaxis.get_offset_text().set_x(-0.1)

        ax.set_xlabel(xlabel, fontsize=fontsize)
        ax.set_ylabel(ylabel, fontsize=fontsize)
        # Only build a legend when something labelled was drawn (avoids the
        # "No artists with labels found" warning).
        handles, _ = ax.get_legend_handles_labels()
        if handles:
            ax.legend(loc=legend_locs.get(pname, "best"), fontsize=fontsize)
        ax.grid(True, linestyle=':', alpha=0.5)
        fig.tight_layout()

        if save_dir is not None:
            os.makedirs(save_dir, exist_ok=True)
            out_path = os.path.join(save_dir, f"{pname}_evolution.pdf")
            fig.savefig(out_path, dpi=300)
            print(f"Figure saved to {out_path}")
        if show:
            plt.show()
        else:
            plt.close(fig)

    for pname in params_names:
        plot_one_param(pname)

# Reference ("true") parameter values, drawn as dashed lines on each plot.
K = 0.75
Eh = 620e6
Ev = 340e6
Gvh = 200e6
beta = 45*np.pi/180  # 45 degrees expressed in radians

true_vals = {'Eh': Eh, 'Ev': Ev, 'Gvh': Gvh, 'K': K, 'beta': beta}

# Mode: selects which examples/<mode>_mode sub-folder the results come from.
mode = 'extensometer'

# Results folder for this run configuration.
folder = f"7_sensors_{mode}_mode_Noise_0%"

# Path to repo root (the notebook is assumed to sit two directories below it).
repo_root = Path.cwd().parents[1]

# Results are read from load_dir; figures are written to its
# posterior_analysis sub-folder.
load_dir = os.path.join(repo_root, "examples", f"{mode}_mode", "Results", folder)
save_dir = os.path.join(load_dir, "posterior_analysis")
os.makedirs(save_dir, exist_ok=True)

seed = 1
seed_dir = os.path.join(load_dir, f"s_{seed}")

# Path of the pickled results to LOAD (the name is kept for later cells).
# seed_dir already includes load_dir, so it is not joined a second time.
save_path = os.path.join(seed_dir, "all_results.pkl")
if not os.path.exists(save_path):
    # Actually skip loading/plotting instead of letting open() raise.
    print(f"Warning: {save_path} does not exist, skipping")
else:
    # NOTE: pickle.load can execute arbitrary code; only load result files
    # produced by your own runs.
    with open(save_path, "rb") as f:
        results_dict = pickle.load(f)

    plot_params_evolution_separate(results_dict,
                                   params_names=('Eh', 'Ev', 'Gvh', 'K', 'beta'),
                                   true_vals=true_vals,
                                   xlabel='Epochs (cumulative)',
                                   ylabel='Parameter value',
                                   fontsize=7,
                                   figsize=(8/2.54, 6/2.54),
                                   save_dir=save_dir,
                                   show=True)
