In [None]:
import os
import pickle
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import random
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.ticker import ScalarFormatter
from matplotlib import colors
import sys
import torch

# Add module to path
module_path = Path.cwd().parents[1] / "module"  
sys.path.append(str(module_path))

from pool_utils import (init_pool, 
                        PINNDataset, 
                        polar_grid_flat, 
                        normalize, 
                        train_test_files, 
                        select_inputs_outputs, 
                        compute_scaling_factor)

from select_sensors import add_sensor_pool
from model import PINN_model


def _dataset_coordinates(dataset, x_max):
    """Return ``(x, y, is_data)`` for the points of *dataset*.

    ``x`` and ``y`` are the first two input columns multiplied by ``x_max``
    (i.e. mapped back from scaled to physical coordinates); ``is_data`` is the
    dataset's mask, used by the caller to select the measurement points.
    """
    inputs = dataset.inputs.cpu().detach().numpy()
    is_data = dataset.is_data.cpu().numpy()
    return inputs[:, 0] * x_max, inputs[:, 1] * x_max, is_data


def _mc_dropout_predict(model, xy, u_a, n_samples):
    """Run ``n_samples`` stochastic forward passes of *model* on *xy*.

    The model is put in training mode so that dropout stays active (MC
    Dropout); its previous train/eval mode is restored afterwards, even if a
    forward pass raises.

    Returns
    -------
    ux_samples, uy_samples : np.ndarray
        Predictions stacked along a new leading axis of length ``n_samples``,
        each multiplied by ``-u_a``.

    Raises
    ------
    ValueError
        If ``n_samples`` is smaller than 1 (there would be nothing to stack).
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")

    u_a_np = u_a.cpu().numpy()  # constant across samples: convert once
    was_training = model.training
    ux_samples, uy_samples = [], []
    model.train()  # keep dropout active while sampling
    try:
        with torch.no_grad():
            for _ in range(n_samples):
                ux_pred, uy_pred = model(xy)
                ux_samples.append(-ux_pred.cpu().numpy() * u_a_np)
                uy_samples.append(-uy_pred.cpu().numpy() * u_a_np)
    finally:
        # Do not leave the caller's model silently switched to training mode.
        model.train(was_training)

    return np.stack(ux_samples, axis=0), np.stack(uy_samples, axis=0)


def plot_uncertainties_2d_MC(model, 
                             current_dataset,
                             next_dataset,
                             x_max, 
                             u_a, 
                             R, 
                             save_dir=None, 
                             prefix=None, 
                             n_samples=50, 
                             device="cpu"):
    """
    Map the MC-Dropout mean displacement and its variance on a 2D grid.

    ``n_samples`` stochastic forward passes (dropout active) are evaluated on a
    regular 200 x 200 grid covering [-1.1, 1.1]^2 in scaled coordinates. Three
    figures are drawn:

    * the sample mean of ``u_x`` and of ``u_y`` (``nature='disp'``);
    * the total variance ``Var(u_x) + Var(u_y)``, overlaid with the current and
      next measurement locations (``nature='uncertainties'``).

    Grid points inside the circle of radius ``R`` (the tunnel) are masked.

    Parameters
    ----------
    model : torch.nn.Module
        Network returning ``(ux, uy)`` for scaled ``(x, y)`` inputs. Its
        train/eval mode is restored before sampling returns.
    current_dataset, next_dataset
        Objects exposing ``inputs`` (scaled coordinates in the first two
        columns) and an ``is_data`` mask; their data points are scattered on
        the uncertainty map.
    x_max
        Length scale mapping scaled coordinates to the physical ones (axes are
        labelled in metres).
    u_a : torch.Tensor
        Displacement scale; predictions are multiplied by ``-u_a``.
    R : float
        Tunnel radius, in physical units.
    save_dir : str or None
        If given, each figure is also saved there as a PDF (directory created
        if needed).
    prefix : str or None
        If given, appended as ``_<prefix>`` to every output file name.
    n_samples : int
        Number of MC-Dropout forward passes (must be >= 1).
    device : str or torch.device
        Device on which the grid is evaluated.

    Notes
    -----
    Updates the global ``matplotlib.rcParams`` font sizes.
    """

    plt.rcParams.update({
        "font.size": 7,
        "axes.labelsize": 7,
        "xtick.labelsize": 7,
        "ytick.labelsize": 7,
        "legend.fontsize": 7
    })

    # Physical coordinates and data masks of the current / next sensor sets
    current_x_all, current_y_all, current_is_data = _dataset_coordinates(current_dataset, x_max)
    next_x_all, next_y_all, next_is_data = _dataset_coordinates(next_dataset, x_max)

    marker_size = 15

    # --- Regular grid (scaled coordinates) ---
    N = 200
    x = np.linspace(-1.1, 1.1, N)
    y = np.linspace(-1.1, 1.1, N)
    X, Y = np.meshgrid(x, y)
    xy = torch.tensor(np.stack([X.flatten(), Y.flatten()], axis=1), dtype=torch.float32).to(device)

    # --- MC Dropout sampling ---
    ux_samples, uy_samples = _mc_dropout_predict(model, xy, u_a, n_samples)

    # --- Per-point mean and variance over the samples ---
    ux_mean = np.mean(ux_samples, axis=0).reshape(N, N)
    uy_mean = np.mean(uy_samples, axis=0).reshape(N, N)
    ux_var = np.var(ux_samples, axis=0).reshape(N, N)
    uy_var = np.var(uy_samples, axis=0).reshape(N, N)

    # --- Total uncertainty (sum of the variances) ---
    total_uncertainty = ux_var + uy_var

    # --- Mask the tunnel interior ---
    X_scaled, Y_scaled = X * x_max, Y * x_max
    mask = X_scaled**2 + Y_scaled**2 <= R**2
    for arr in [ux_mean, uy_mean, total_uncertainty]:
        arr[mask] = np.nan

    # --- Plot function ---
    def plot_field(Z, title, fname, cmap="RdBu_r", figsize=(5/2.54, 5/2.54), nature='uncertainties'):
        """Draw one filled-contour map of *Z* over the physical grid.

        ``nature='disp'`` enlarges the figure by 35 % and adds a "(m)" unit
        label above the colorbar; ``nature='uncertainties'`` overlays the
        measurement locations and uses a power-law colour scale starting at 0.
        ``title`` is accepted but not drawn on the figure.
        """
        scale = 1.35 if nature == 'disp' else 1.0
        fig, ax = plt.subplots(figsize=(figsize[0] * scale, figsize[1] * scale))

        if nature == 'uncertainties':
            ax.scatter(
                current_x_all[current_is_data],
                current_y_all[current_is_data],
                s=marker_size,
                facecolors='white',
                edgecolors='black',
                marker='x',
                label='Current measurements',
                linewidths=0.6,
                zorder=4
            )
            ax.scatter(
                next_x_all[next_is_data],
                next_y_all[next_is_data],
                s=marker_size,
                color='black',
                marker='x',
                label='Next measurements',
                linewidths=0.6,
                zorder=4
            )
            # gamma < 1 stretches the low end of the variance range
            norm = colors.PowerNorm(gamma=0.7, vmin=0, vmax=np.nanmax(Z))
            cf = ax.contourf(X_scaled, Y_scaled, Z, levels=100, cmap=cmap, extend="both", norm=norm)
        else:
            cf = ax.contourf(X_scaled, Y_scaled, Z, levels=50, cmap=cmap, extend="both")

        # Colorbar in a dedicated axes to the right of the map
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="6%", pad=0.1)
        cbar = plt.colorbar(cf, cax=cax)

        # Scientific notation on the colorbar, with a small offset text
        formatter = ScalarFormatter(useMathText=True)
        formatter.set_powerlimits((0, 0))
        cbar.ax.yaxis.set_major_formatter(formatter)

        offset_text = cbar.ax.yaxis.get_offset_text()
        offset_text.set_fontsize(6)
        offset_text.set_x(3.5)  # in colorbar-axes coordinates

        # Units, positioned in colorbar-axes coordinates (explicitly on cax
        # rather than relying on which axes is "current")
        if nature == 'disp':
            cax.text(3.60, 1.09, "(m)", transform=cax.transAxes)

        cbar.ax.tick_params(labelsize=6)

        # Tunnel wall
        circle = plt.Circle((0, 0), R, color='red', fill=False, linewidth=1)
        ax.add_patch(circle)

        # Axes and grid (fixed ticks at -20..20 m every 10 m)
        ax.set_xlabel(r"$x$ (m)")
        ax.set_ylabel(r"$y$ (m)")
        ax.set_aspect('equal', adjustable='box')

        ticks = np.arange(-20, 20 + 1, 10)
        ax.set_xticks(ticks)
        ax.set_yticks(ticks)
        ax.grid(True, which='both', linewidth=0.5)

        fig.tight_layout()
        if save_dir is not None:
            os.makedirs(save_dir, exist_ok=True)
            fig.savefig(os.path.join(save_dir, fname), dpi=300, bbox_inches="tight")
        plt.show()

    # --- Plots ---
    suffix = f"_{prefix}" if prefix else ""
    plot_field(ux_mean, r"$\overline{u_x}$", f"ux_mean{suffix}.pdf", figsize=(5/2.54, 5/2.54), nature='disp')
    plot_field(uy_mean, r"$\overline{u_y}$", f"uy_mean{suffix}.pdf", figsize=(5/2.54, 5/2.54), nature='disp')
    plot_field(total_uncertainty, r"$\sigma^2(u_x)+\sigma^2(u_y)$", f"u_var_current_next{suffix}.pdf", cmap="viridis", figsize=(7/2.54, 7/2.54), nature='uncertainties')


# ---------------------------------------------------------------
#  Reproducibility: seed torch, Python's random and NumPy (in that order)
# ---------------------------------------------------------------
seed = 10
for seeder in (torch.manual_seed, random.seed, np.random.seed):
    seeder(seed)

# Computation device
device = torch.device("cpu")

# ---------------------------------------------------------------
#  Case study: St Martin La Porte (Vu et al. 2013)
# ---------------------------------------------------------------
# Initial stresses: vertical, ratio K, horizontal = K * vertical
sigma_v = 5e6
K = 0.75
sigma_h = K * sigma_v
l = 1  # NOTE(review): not referenced by the code in this file

# Rotation angle: 45 degrees expressed in radians
beta = 45 * np.pi / 180

# In-situ stress components after rotation by beta
cos_2beta = np.cos(2 * beta)
sin_2beta = np.sin(2 * beta)
sigma_v_0 = sigma_v / 2 * (1 + K + (1 - K) * cos_2beta)
sigma_h_0 = sigma_v / 2 * (1 + K - (1 - K) * cos_2beta)
tau_vh_0 = sigma_v / 2 * (1 - K) * sin_2beta

# Tunnel radius
R = 5

# Extensometer length (0 if boundary conditions only); 8 is also noted as a value
L = 24

# Mechanical parameters to retrieve (nvh follows from the other three)
Eh, Ev, Gvh = 620e6, 340e6, 200e6
nh, nhv = 0.12, 0.2
nvh = Ev * nhv / Eh

# Noise level applied to the synthetic measurements
noise_level = 0

# ---------------------------------------------------------------
#  Paths
# ---------------------------------------------------------------
repo_root = Path.cwd().parents[1]
mode = "extensometer"
folder = f"7_sensors_{mode}_mode_Noise_0%"

results_root = os.path.join(repo_root, "examples", f"{mode}_mode", "Results")
load_dir = os.path.join(results_root, folder)
save_dir = os.path.join(load_dir, "posterior_analysis", "predictions")
os.makedirs(save_dir, exist_ok=True)

# --- Results of the selected seed ---
# ``load_dir`` is built from ``Path.cwd()`` and is therefore absolute, so the
# seed directory is absolute too; every path below is derived from it directly
# instead of joining it with ``load_dir`` a second time.
seed_dir = os.path.join(load_dir, f"s_{seed}")

save_path = os.path.join(seed_dir, "all_results.pkl")
if not os.path.exists(save_path):
    # Nothing below can run without this file: stop with a clear message
    # rather than warning and then failing on ``open``.
    raise FileNotFoundError(f"Results file does not exist: {save_path}")

# NOTE: unpickling can execute arbitrary code -- only load result files you
# produced yourself.
with open(save_path, "rb") as f:
    results_dict = pickle.load(f)

# Step folders are named "step_<k>..."; sort them by their numeric index k.
step_dirs = sorted(
    (d for d in os.listdir(seed_dir) if d.startswith("step_")),
    key=lambda d: int(d.split("_")[1]),
)

step = 3
# Select the folder whose index equals ``step`` (not the ``step``-th folder in
# the sorted list): positional indexing only matches the index when steps are
# numbered 0, 1, 2, ... without gaps, and would otherwise pair
# ``step_<step>.pth`` with the wrong folder.
matching_dirs = [d for d in step_dirs if int(d.split("_")[1]) == step]
if not matching_dirs:
    raise FileNotFoundError(f"No folder for step {step} in {seed_dir} (found: {step_dirs})")
step_dir = matching_dirs[0]

load_path = os.path.join(seed_dir, step_dir, f"step_{step}.pth")
if not os.path.exists(load_path):
    raise FileNotFoundError(f"Path does not exist: {load_path}")

# map_location lets a checkpoint saved on another device (e.g. a GPU) be
# loaded onto ``device``.
checkpoint = torch.load(load_path, map_location=device, weights_only=True)
state_dict = checkpoint["model_state_dict"]

# Layer sizes passed to PINN_model: 2 inputs (x, y), then ``depth`` layers of
# ``width`` units (the output head is presumably added by PINN_model -- TODO confirm).
width = 40
depth = 10
layers = [2] + [width] * depth
print(layers)

# Define model
model = PINN_model(layers, nhv=nhv, nh=nh, beta=beta).to(device)

# Load model states (strict: every key must match)
model.load_state_dict(state_dict, strict=True)


# Collect x_max and u_a from the test dataset stored in the results file.
# NOTE(review): both names are reassigned further down when the training
# inputs/outputs are normalised, and nothing in between reads them, so the
# values computed here never reach the plot. Confirm which scaling the
# checkpoint expects and keep only that computation.
test_dataset = results_dict["test_dataset"]
inputs_test = test_dataset.inputs
_, _, x_max, u_a = normalize(inputs_test, outputs=None, R=R, L=L, device=device)


# =================================================================
#  Sensor pools: "current" (sensors placed so far) and "next"
#  (sensor proposed for the following step)
# =================================================================

# ------ Polar collocation grid around the tunnel ------
inputs_grid, mask_int_selected_grid, mask_bc_selected_grid = polar_grid_flat(
    R,
    L,
    Nr=10,
    Ntheta=36,
    refine_exponent=1,
    select_every_r=1,
    select_every_theta=1,
    device=device,
)
inputs_grid_scaled, *_ = normalize(inputs_grid, None, R, L)

# ------ Training files and their (scaled) inputs / outputs ------
data_dir = os.path.join(repo_root, "synthetic_data")
selected_ids = None
train_files = train_test_files(data_dir, selected_ids=selected_ids)

(all_inputs_train,
 all_outputs_train,
 all_unique_ids_train,
 all_ext_ids_train) = select_inputs_outputs(files=train_files,
                                            noise_level=noise_level,
                                            device=device)

(all_inputs_train_scaled,
 all_outputs_train_scaled,
 x_max,
 u_a) = normalize(all_inputs_train, all_outputs_train, R, L, device=device)

# Scaling factor (not read again in this file)
b = compute_scaling_factor(sigma_v, x_max, u_a)

# Arguments shared by every pool built from the same candidate set
pool_kwargs = dict(
    all_inputs_scaled=all_inputs_train_scaled,
    all_outputs_scaled=all_outputs_train_scaled,
    sensor_ids=all_ext_ids_train,
    R=R,
    x_max=x_max,
    all_inputs_grid_scaled=inputs_grid_scaled,
    mask_int_selected_grid=mask_int_selected_grid,
    mask_bc_selected_grid=mask_bc_selected_grid,
    tol=1e-4,
    device=device,
)
# Sensors are added by explicit id: no model and no random pick
sensor_kwargs = dict(model=None, n_MC=50, random=False, device=device)
# Datasets include interior, boundary and data points
dataset_kwargs = dict(include_int=True, include_bc=True, include_data=True, device=device)

# ------ Current pool: initial extensometers, then the sensor chosen so far ------
current_pool = init_pool(**pool_kwargs)
for sensor_type, sensor_ids_list in (("extensometer", [10, 19]), (mode, [32])):
    current_pool = add_sensor_pool(pool=current_pool,
                                   sensor_type=sensor_type,
                                   sensor_ids_list=sensor_ids_list,
                                   **sensor_kwargs)
current_dataset = PINNDataset(current_pool, **dataset_kwargs)

# ------ Next pool: only the sensor proposed for the next step ------
next_pool = init_pool(**pool_kwargs)
sensor_ids_list = [28]
next_pool = add_sensor_pool(pool=next_pool,
                            sensor_type=mode,
                            sensor_ids_list=sensor_ids_list,
                            **sensor_kwargs)
next_dataset = PINNDataset(next_pool, **dataset_kwargs)

# ------------------------------------------------------------
#  Plot: MC-Dropout mean displacements and total variance
# ------------------------------------------------------------
plot_prefix = f"step_{step}_{mode}_mode"
plot_uncertainties_2d_MC(
    model=model,
    current_dataset=current_dataset,
    next_dataset=next_dataset,
    x_max=x_max,
    u_a=u_a,
    R=R,
    save_dir=save_dir,
    prefix=plot_prefix,
    n_samples=50,
    device=device,
)