In [11]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import tensorflow as tf

CNN_PATH = "ClimSim/baseline_models/CNN/model/"
MLP_PATH = "lib/ClimSim/baseline_models/MLP/model/backup_phase-7_retrained_models_step2_lot-147_trial_0027.best.h5"

# cnn_model = tf.keras.layers.TFSMLayer(CNN_PATH, call_endpoint="serving_default")
mlp_model = tf.keras.models.load_model(MLP_PATH, compile=False)
mlp_model.summary()



In [12]:
import xarray as xr
import numpy as np
import torch 
from torch.utils.data import Dataset, Subset
import re

from lib import data


class ClimSimBase:
    def __init__(self, zarr_path, grid_path, norm_path, features, num_latlon = 384, normalize=True):
        self.ds = xr.open_zarr(zarr_path, chunks=None)
        self.features = features
        self.features_list = self.__get_features__()
        self.normalize_flag = normalize
        self.normalize = normalize  
        self.num_latlon = num_latlon
        self.grid = xr.open_dataset(grid_path, engine="netcdf4")
        
        self.input_mean = xr.open_dataset(os.path.join(norm_path, "inputs/input_mean.nc"), engine="h5netcdf")
        self.input_std = xr.open_dataset(os.path.join(norm_path, "inputs/input_std.nc"), engine="h5netcdf")
        self.input_max = xr.open_dataset(os.path.join(norm_path, "inputs/input_max.nc"), engine="h5netcdf")
        self.input_min = xr.open_dataset(os.path.join(norm_path, "inputs/input_min.nc"), engine="h5netcdf")
        self.output_scale = xr.open_dataset(os.path.join(norm_path, "outputs/output_scale.nc"), engine="h5netcdf")

        self.grid['area_wgt'] = self.grid['area']/self.grid['area'].mean(dim = 'ncol')
        self.area_wgt = self.grid['area_wgt'].values

        self.input_vars = [v for v in self.features_list if 'in' in v]
        self.output_vars = [v for v in self.features_list if 'out' in v]

        self.grav    = 9.80616    # acceleration of gravity ~ m/s^2
        self.cp      = 1.00464e3  # specific heat of dry air   ~ J/kg/K
        self.lv      = 2.501e6    # latent heat of evaporation ~ J/kg
        self.lf      = 3.337e5    # latent heat of fusion      ~ J/kg
        self.lsub    = self.lv + self.lf    # latent heat of sublimation ~ J/kg
        self.rho_air = 101325/(6.02214e26*1.38065e-23/28.966)/273.15 # density of dry air at STP  ~ kg/m^3
        self.rho_h20 = 1.e3       # density of fresh water     ~ kg/m^ 3

        self.target_energy_conv = {'ptend_t':self.cp,
                            'ptend_q0001':self.lv,
                            'ptend_q0002':self.lv,
                            'ptend_q0003':self.lv,
                            'ptend_qn':self.lv,
                            'ptend_wind': None,
                            'cam_out_NETSW':1.,
                            'cam_out_FLWDS':1.,
                            'cam_out_PRECSC':self.lv*self.rho_h20,
                            'cam_out_PRECC':self.lv*self.rho_h20,
                            'cam_out_SOLS':1.,
                            'cam_out_SOLL':1.,
                            'cam_out_SOLSD':1.,
                            'cam_out_SOLLD':1.
                            }
        
        
        self.dp = None 
        self.pressure_grid = None

    def __get_features__(self):
        feat = np.concatenate([self.features["features"]["multilevel"], self.features["features"]["surface"]])
        target = np.concatenate([self.features["target"]["tendancies"], self.features["target"]["surface"]])
        return np.concatenate([feat, target])

    def _prepare_data(self, idx):
        x = self.process_list(self.input_vars, idx, is_input=True)
        y = self.process_list(self.output_vars, idx, is_input=False)
        return x, y

    def process_list(self, vars_list, idx, is_input=True):
        out_list = []
        n_geo = self.num_latlon # 384

        for var in vars_list:
            if "ptend" in var:
                # Cette fonction doit renvoyer du (Time, 384, 60)
                data = self._calculate_tendency_on_fly(var, idx)
                if data.ndim == 2: # Si (384, 60)
                    data = data[np.newaxis, :, :]
            
            else:
                da = self.ds[var].isel(sample=idx)
                
                # Redressement par nom de dimension Xarray
                if 'lev' in da.dims:
                    if "sample" in da.dims:
                        data = da.transpose('sample', 'ncol', 'lev').values
                    else:
                        data = da.transpose('ncol', 'lev').values[np.newaxis, :, :]
                else:
                    if "sample" in da.dims:
                        # Surface : (Time, 384)
                        data = da.values[:, :, np.newaxis]  # Ajouter une dimension lev=1
                    else:
                        # Surface : (Time, 384) -> (Time, 384, 1)
                        data = da.values[np.newaxis, :, np.newaxis]
            
            # 2. Normalisation (Maintenant data est garanti (N, 384, L))
            data = self._normalize_var(data, var, is_input=is_input)
            out_list.append(data.astype(np.float32))

        # 3. Concaténation et aplatissement
        combined = np.concatenate(out_list, axis=-1)
        return combined.reshape(-1, combined.shape[-1])

    def __len__(self):
        return self.ds.dims['sample']

    def _calculate_tendency_on_fly(self, var, idx):
        dt = 1200
        mapping = {
            'out_ptend_t': ('out_state_t', 'in_state_t'),
            'out_ptend_q0001': ('out_state_q0001', 'in_state_q0001'),
            'out_ptend_u': ('out_state_u', 'in_state_u'),
            'out_ptend_v': ('out_state_v', 'in_state_v'),
        }
        out_v, in_v = mapping[var]

        v_final = self.ds[out_v].isel(sample=idx)
        v_init  = self.ds[in_v].isel(sample=idx)

        # Fonction utilitaire: remettre en ordre (sample?, ncol, lev) si sample existe
        def to_array(da):
            dims = da.dims

            if 'ncol' not in dims or 'lev' not in dims:
                raise ValueError(f"{da.name}: dims inattendues {dims}, attendu ncol & lev")

            # Cas slice -> dims contiennent sample
            if 'sample' in dims:
                da = da.transpose('sample', 'ncol', 'lev')
                return da.values  # (sample, ncol, lev)

            # Cas int -> dims (lev, ncol) ou (ncol, lev)
            da = da.transpose('ncol', 'lev')
            return da.values[None, ...]  # (1, ncol, lev)

        vf = to_array(v_final)
        vi = to_array(v_init)

        return (vf - vi) / dt  # (time, ncol, lev)


    
    def _normalize_var(self, data, var_name, is_input=True):
        # data est (N, 384, L) où L est 1 ou 60
        short_name = re.sub(r'^(in_|out_)', '', var_name)
        
        if is_input:
            m = self.input_mean[short_name].values     # (L,)
            diff = (self.input_max[short_name].values - self.input_min[short_name].values) # (L,)
            
            # On redimensionne les stats en (1, 1, L) pour s'aligner sur data (N, 384, L)
            m = m.reshape(1, 1, -1)
            diff = diff.reshape(1, 1, -1)
            
            return (data - m) / (diff + 1e-15)
        else:
            scale = self.output_scale[short_name].values # (L,)
            return data * scale.reshape(1, 1, -1)
            
    def set_pressure_grid(self, input_data):
        '''
        Calcule la grille de pression 3D à partir de state_ps.
        Code directement issu de ClimSim original.
        '''
        self.ps_index = self._find_ps_index(self.features)
        state_ps = input_data[:, self.ps_index]
        if self.normalize_flag:
            state_ps = state_ps * (self.input_max['state_ps'].values - self.input_min['state_ps'].values) + self.input_mean['state_ps'].values

        state_ps = state_ps.reshape(-1, self.num_latlon)

        p1 = (self.grid['P0'] * self.grid['hyai']).values[:, None, None]
        p2 = self.grid['hybi'].values[:, None, None] * state_ps[None, :, :]
        
        self.pressure_grid = p1 + p2
        self.dp = (self.pressure_grid[1:61] - self.pressure_grid[0:60]).transpose((1, 2, 0))
    
    
    def denormalize_output(self, y_pred):
        """Dénormalise les prédictions."""

        full_scale_vector = [] # To vectorize we generate the full scale vector first
        for var in self.output_vars:
            short_name = re.sub(r'^(in_|out_)', '', var)
            scale = self.output_scale[short_name].values
            
            if 'ptend' in var:
                dim_size = 60
            elif 'lev' in self.ds[var].dims:
                dim_size = self.ds[var].sizes['lev']
            else:
                dim_size = 1
            
            if np.isscalar(scale) or scale.size == 1:
                scale_expanded = np.full(dim_size, scale)
            else:
                scale_expanded = scale
                
            full_scale_vector.append(scale_expanded)
        
        full_scale_vector = np.concatenate(full_scale_vector)

        y_denorm = y_pred / (full_scale_vector + 1e-8) 

        return (y_denorm).astype(np.float32)
    
    def calc_MAE(self, pred, target, avg_grid = True):
        '''
        calculate 'globally averaged' mean absolute error 
        for vertically-resolved variables, shape should be time x grid x level
        for scalars, shape should be time x grid

        returns vector of length level or 1
        '''
        assert pred.shape[1] == self.num_latlon
        assert pred.shape == target.shape
        mae = np.abs(pred - target).mean(axis = 0)
        if avg_grid:
            return mae.mean(axis = 0) # we decided to average globally at end
        else:
            return mae
    
    def calc_RMSE(self, pred, target, avg_grid = True):
        '''
        calculate 'globally averaged' root mean squared error 
        for vertically-resolved variables, shape should be time x grid x level
        for scalars, shape should be time x grid

        returns vector of length level or 1
        '''
        assert pred.shape[1] == self.num_latlon
        assert pred.shape == target.shape
        sq_diff = (pred - target)**2
        rmse = np.sqrt(sq_diff.mean(axis = 0)) # mean over time
        if avg_grid:
            return rmse.mean(axis = 0) # we decided to separately average globally at end
        else:
            return rmse

    def calc_R2(self, pred, target, avg_grid = True):
        '''
        calculate 'globally averaged' R-squared
        for vertically-resolved variables, input shape should be time x grid x level
        for scalars, input shape should be time x grid

        returns vector of length level or 1
        '''
        assert pred.shape[1] == self.num_latlon
        assert pred.shape == target.shape
        sq_diff = (pred - target)**2
        tss_time = (target - target.mean(axis = 0)[np.newaxis, ...])**2 # mean over time
        r_squared = 1 - sq_diff.sum(axis = 0)/tss_time.sum(axis = 0) # sum over time
        if avg_grid:
            return r_squared.mean(axis = 0) # we decided to separately average globally at end
        else:
            return r_squared
    
    def output_weighting(self, output, just_weights=False):
        num_samples = output.shape[0]
        n_geo = self.num_latlon
        n_time = num_samples // n_geo
        
        # Configuration des indices : (début, fin, est_profil)
        offsets = {
            'ptend_t': (0, 60, True), 'ptend_q0001': (60, 120, True),
            'cam_out_NETSW': (120, 121, False), 'cam_out_FLWDS': (121, 122, False),
            'cam_out_PRECSC': (122, 123, False), 'cam_out_PRECC': (123, 124, False),
            'cam_out_SOLS': (124, 125, False), 'cam_out_SOLL': (125, 126, False),
            'cam_out_SOLSD': (126, 127, False), 'cam_out_SOLLD': (127, 128, False)
        }
        
        dp = self.dp / self.grav
        var_dict = {}

        for var, (start, end, is_prof) in offsets.items():
            # Extraction et reshape
            val = output[:, start:end].reshape(n_time, n_geo, -1 if is_prof else 1)
            if not is_prof: val = val.squeeze(-1)

            # [0] Undo scaling
            scale = self.output_scale[var].values
            val /= scale[None, None, :] if is_prof else scale

            # [1] Vertical weighting
            if is_prof: val *= dp
            
            # [2] Area weighting
            val *= self.area_wgt[None, :, None] if is_prof else self.area_wgt[None, :]
            
            # [3] Energy conversion
            val *= self.target_energy_conv[var]
            
            
            var_dict[var] = val

        return var_dict
    
    def _find_ps_index(self, features_dict):
        """
        Calcule l'index de départ de 'in_state_ps' dans le vecteur d'entrée plat.
        Prend en compte que chaque variable multilevel occupe 60 colonnes.
        """
        current_index = 0
        
        # 1. Parcourir les variables multi-niveaux (60 niveaux chacune)
        for var in features_dict["features"]["multilevel"]:
            if var == "in_state_ps":
                return current_index
            current_index += 60
            
        # 2. Parcourir les variables de surface (1 niveau chacune)
        for var in features_dict["features"]["surface"]:
            if var == "in_state_ps":
                return current_index
            current_index += 1
            
        raise ValueError("Variable 'in_state_ps' non trouvée dans le dictionnaire FEATURES.")

class ClimSimPyTorch(ClimSimBase, Dataset):
    def __getitem__(self, idx):
        x_np, y_np = self._prepare_data(idx)
        return torch.from_numpy(x_np), torch.from_numpy(y_np)

    # On peut remettre ta méthode de split ici
    def train_test_split(self, test_size=0.2, seed=42, shuffle=True):
        n = len(self)
        indices = np.arange(n)
        if shuffle:
            rng = np.random.default_rng(seed)
            rng.shuffle(indices)
        split = int((1 - test_size) * n)
        return Subset(self, indices[:split]), Subset(self, indices[split:])
    
    def get_models_dims(self, variables_dict):
        features_tend = variables_dict["features"]["multilevel"]
        features_surf = variables_dict["features"]["surface"]
        
        target_tend = variables_dict["target"]["tendancies"]
        target_surf = variables_dict["target"]["surface"]

        def get_var_dim(var):
            # 1. Gérer les variables virtuelles (tendances calculées)
            if 'ptend' in var:
                # On mappe vers la variable d'état pour connaître la dimension 'lev'
                # ex: out_ptend_t -> out_state_t
                source_var = var.replace('ptend', 'state')
                return self.ds[source_var].sizes['lev']
            
            # 2. Gérer les variables réelles présentes dans le Zarr
            if 'lev' in self.ds[var].dims:
                return self.ds[var].sizes['lev']
            
            # 3. Variables de surface (scalaires)
            return 1

        in_tend_dim = sum([get_var_dim(var) for var in features_tend])
        in_surf_dim = len(features_surf)
        
        out_tend_dim = sum([get_var_dim(var) for var in target_tend])
        out_surf_dim = len(target_surf)

        return {
            "input_total": in_tend_dim + in_surf_dim,
            "output_tendancies": out_tend_dim,
            "output_surface": out_surf_dim
        }
            
class ClimSimKeras(ClimSimBase):
    def get_batch_for_keras(self, idx, batch_size, input_dim, output_dim):
        in_batch = np.zeros((batch_size, self.num_latlon, input_dim), dtype=np.float32)
        out_batch = np.zeros((batch_size, self.num_latlon, output_dim), dtype=np.float32)

        for s in range(batch_size):
            if idx * batch_size + s >= self.get_sample_number():
                raise IndexError("Index hors limites pour le dataset.")
        
            x_np, y_np = self._prepare_data(idx * batch_size + s)

            in_batch[s, :, :] = x_np.reshape(self.num_latlon, input_dim)
            out_batch[s, :, :] = y_np.reshape(self.num_latlon, output_dim)
        return in_batch, out_batch
    
    def get_sample_number(self):
        return self.ds.dims['sample']
    

In [13]:
ZARR_PATH = "data/ClimSim_low-res.zarr"
GRID_PATH = "data/ClimSim_low-res/ClimSim_low-res_grid-info.nc"
NORM_PATH = "lib/ClimSim/preprocessing/normalizations"

FEATURES = {
    "features" :{
        "multilevel" : ["in_state_t", "in_state_q0001"],
        "surface" : [ "in_state_ps", 'in_pbuf_SOLIN', 'in_pbuf_LHFLX', 'in_pbuf_SHFLX'],
    },  
    "target" :{
        "tendancies" : ["out_ptend_t", "out_ptend_q0001"],
        "surface" : ["out_cam_out_NETSW", "out_cam_out_FLWDS", "out_cam_out_PRECSC", "out_cam_out_PRECC", "out_cam_out_SOLS", "out_cam_out_SOLL", "out_cam_out_SOLSD", "out_cam_out_SOLLD"]
    }
}

models_dims = ClimSimPyTorch(ZARR_PATH, GRID_PATH, NORM_PATH, FEATURES).get_models_dims(FEATURES)
print("Model input/output dimensions:", models_dims)

dataset = ClimSimKeras(ZARR_PATH, GRID_PATH, NORM_PATH, FEATURES)

Model input/output dimensions: {'input_total': 124, 'output_tendancies': 120, 'output_surface': 8}


In [14]:
target_vars = FEATURES["target"]["tendancies"] + FEATURES["target"]["surface"]
input_size = models_dims["input_total"]
output_size = models_dims["output_tendancies"] + models_dims["output_surface"]


x_batch, y_true_batch = dataset.get_batch_for_keras(1,2, input_size, output_size)
print(x_batch.shape, y_true_batch.shape)

(2, 384, 124) (2, 384, 128)


  return self.ds.dims['sample']


In [15]:
import numpy as np
import re

target_vars = FEATURES["target"]["tendancies"] + FEATURES["target"]["surface"]
input_size = models_dims["input_total"]
output_size = models_dims["output_tendancies"] + models_dims["output_surface"]

# 1. Initialisation des accumulateurs
stats = {}
for var in target_vars:
    short_name = re.sub(r'^(in_|out_)', '', var)
    size = 60 if var in FEATURES["target"]["tendancies"] else 1
    stats[short_name] = {
        "ss_res": np.zeros(size),      # Somme des carrés (pour RMSE et R2)
        "sum_abs_err": np.zeros(size), # Somme des erreurs absolues (pour MAE)
        "sum_y": np.zeros(size),       # Somme de y (pour variance R2)
        "sum_y_sq": np.zeros(size),    # Somme de y^2 (pour variance R2)
        "count": 0
    }

all_mae = {short_name: [] for short_name in stats.keys()}
all_r2 = {short_name: [] for short_name in stats.keys()}

# 2. Boucle de test (Batch par Batch)
for i in range(10):
    x_batch, y_true_batch = dataset.get_batch_for_keras(i, batch_size=1, input_dim=input_size, output_dim=output_size)

    x_flat = x_batch.reshape([-1, x_batch.shape[2]])
    y_flat = y_true_batch.reshape([-1, y_true_batch.shape[2]])
    
    dataset.set_pressure_grid(x_flat)
    preds_batch = mlp_model.predict(x_flat, verbose=0)
    
    true_dict = dataset.output_weighting(y_flat)
    pred_dict = dataset.output_weighting(preds_batch)
    
    for var in target_vars:
        short_name = re.sub(r'^(in_|out_)', '', var)
        y_t = true_dict[short_name] 
        y_p = pred_dict[short_name]
        
        mae_score = dataset.calc_MAE(y_p, y_t, avg_grid=True)
        all_mae[short_name].append(mae_score)

        r2_score = dataset.calc_R2(y_p, y_t, avg_grid=True)
        all_r2[short_name].append(r2_score)


# 2. Affichage des moyennes finales
print(f"\n{'Variable':<25} | {'MAE Moyenne (W/m2)':<15}")
print("-" * 45)

for var, scores in all_mae.items():
    # Moyenne sur tous les batches, puis moyenne sur les 60 niveaux
    final_score = np.mean(np.mean(scores, axis=0))
    print(f"{var:<25} | {final_score:.4e}")

for var, scores in all_r2.items():
    # Moyenne sur tous les batches, puis moyenne sur les 60 niveaux
    final_score = np.mean(np.mean(scores, axis=0))
    print(f"{var:<25} | {final_score:.4e}")



  return self.ds.dims['sample']
  r_squared = 1 - sq_diff.sum(axis = 0)/tss_time.sum(axis = 0) # sum over time
  r_squared = 1 - sq_diff.sum(axis = 0)/tss_time.sum(axis = 0) # sum over time



Variable                  | MAE Moyenne (W/m2)
---------------------------------------------
ptend_t                   | 9.8517e+00
ptend_q0001               | 8.2674e+00
cam_out_NETSW             | 3.7105e+01
cam_out_FLWDS             | 5.7761e+00
cam_out_PRECSC            | 3.1328e+00
cam_out_PRECC             | 3.0229e+01
cam_out_SOLS              | 1.5977e+01
cam_out_SOLL              | 1.9004e+01
cam_out_SOLSD             | 8.7456e+00
cam_out_SOLLD             | 7.0313e+00
ptend_t                   | -inf
ptend_q0001               | -inf
cam_out_NETSW             | nan
cam_out_FLWDS             | -inf
cam_out_PRECSC            | nan
cam_out_PRECC             | nan
cam_out_SOLS              | nan
cam_out_SOLL              | nan
cam_out_SOLSD             | nan
cam_out_SOLLD             | nan


In [16]:
import numpy as np
import re

# ---------------------------
# Config
# ---------------------------
target_vars = FEATURES["target"]["tendancies"] + FEATURES["target"]["surface"]
input_size = models_dims["input_total"]
output_size = models_dims["output_tendancies"] + models_dims["output_surface"]

# Pour savoir la taille attendue (60 niveaux pour tendancies, 1 pour surface)
def var_size(var: str) -> int:
    return 60 if var in FEATURES["target"]["tendancies"] else 1

def short_name(var: str) -> str:
    return re.sub(r'^(in_|out_)', '', var)

def to_2d(y: np.ndarray, L: int) -> np.ndarray:
    """
    Force y à une shape (N, L).
    - Si y est scalaire ou (N,), L doit être 1 => (N, 1)
    - Si y est déjà (N, L) OK
    - Si y a plus de dims, on aplati tout sauf la dernière dim
    """
    y = np.asarray(y)

    if L == 1:
        # Tout ramener à (N, 1)
        if y.ndim == 0:
            return y.reshape(1, 1)
        if y.ndim == 1:
            return y.reshape(-1, 1)
        # si y est (N, 1) ou (.., 1) : aplatis tout sauf last
        return y.reshape(-1, 1)

    # L > 1 (ex: 60)
    if y.ndim == 1:
        # Cas improbable: on attend L, mais on a (N,) -> erreur claire
        raise ValueError(f"Attendu un tableau avec {L} niveaux, reçu shape {y.shape}")
    if y.shape[-1] != L:
        raise ValueError(f"Attendu dernière dimension = {L}, reçu shape {y.shape}")
    return y.reshape(-1, L)

# ---------------------------
# Accumulateurs globaux
# ---------------------------
stats = {}
for var in target_vars:
    sname = short_name(var)
    L = var_size(var)
    stats[sname] = {
        "sum_abs_err": np.zeros(L, dtype=np.float64),
        "ss_res":      np.zeros(L, dtype=np.float64),
        "sum_y":       np.zeros(L, dtype=np.float64),
        "sum_y_sq":    np.zeros(L, dtype=np.float64),
        "count":       0
    }

# Optionnel : garder aussi les scores "moyenne par batch" pour comparer
mean_batch_r2 = {short_name(v): [] for v in target_vars}
mean_batch_mae = {short_name(v): [] for v in target_vars}

# ---------------------------
# Boucle de test
# ---------------------------
n_batches = 10
eps = 1e-12

for i in range(n_batches):
    x_batch, y_true_batch = dataset.get_batch_for_keras(
        i, batch_size=100, input_dim=input_size, output_dim=output_size
    )

    # flatten (comme tu faisais)
    x_flat = x_batch.reshape([-1, x_batch.shape[2]])
    y_flat = y_true_batch.reshape([-1, y_true_batch.shape[2]])

    dataset.set_pressure_grid(x_flat)
    preds_batch = mlp_model.predict(x_flat, verbose=0)

    true_dict = dataset.output_weighting(y_flat)
    pred_dict = dataset.output_weighting(preds_batch)

    for var in target_vars:
        sname = short_name(var)
        L = var_size(var)

        y_t = to_2d(true_dict[sname], L)
        y_p = to_2d(pred_dict[sname], L)

        # --- accumulateurs globaux ---
        err = (y_t - y_p)
        stats[sname]["sum_abs_err"] += np.sum(np.abs(err), axis=0)
        stats[sname]["ss_res"]      += np.sum(err**2, axis=0)
        stats[sname]["sum_y"]       += np.sum(y_t, axis=0)
        stats[sname]["sum_y_sq"]    += np.sum(y_t**2, axis=0)
        stats[sname]["count"]       += y_t.shape[0]

        # --- optionnel : tes métriques par batch (si tu veux comparer) ---
        # Si tes fonctions renvoient vecteur (L,), on les stocke.
        try:
            mae_b = dataset.calc_MAE(y_p, y_t, avg_grid=True)
            r2_b  = dataset.calc_R2(y_p, y_t, avg_grid=True)
            mean_batch_mae[sname].append(np.asarray(mae_b))
            mean_batch_r2[sname].append(np.asarray(r2_b))
        except Exception:
            # Si jamais calc_MAE/calc_R2 attend d'autres shapes, on ignore en silence.
            pass

# ---------------------------
# Calculs finaux (GLOBAL)
# ---------------------------
print("\n=== SCORES GLOBAUX (dataset-level) ===")
print(f"{'Variable':<25} | {'MAE global':>12} | {'R2 global':>12}")
print("-" * 57)

for sname, s in stats.items():
    N = s["count"]
    if N == 0:
        print(f"{sname:<25} | {'NA':>12} | {'NA':>12}")
        continue

    mae_per_level = s["sum_abs_err"] / N

    y_mean = s["sum_y"] / N
    ss_tot = s["sum_y_sq"] - N * (y_mean**2)

    r2_per_level = 1.0 - (s["ss_res"] / (ss_tot + eps))

    # Un seul score par variable (moyenne sur niveaux)
    mae_global = float(np.mean(mae_per_level))
    r2_global  = float(np.mean(r2_per_level))

    print(f"{sname:<25} | {mae_global:12.4e} | {r2_global:12.4e}")



  return self.ds.dims['sample']
2026-02-01 17:46:37.103315: W external/local_xla/xla/tsl/framework/cpu_allocator_impl.cc:84] Allocation of 19660800 exceeds 10% of free system memory.
2026-02-01 17:46:45.656941: W external/local_xla/xla/tsl/framework/cpu_allocator_impl.cc:84] Allocation of 19660800 exceeds 10% of free system memory.
2026-02-01 17:46:54.051723: W external/local_xla/xla/tsl/framework/cpu_allocator_impl.cc:84] Allocation of 19660800 exceeds 10% of free system memory.
2026-02-01 17:47:02.617500: W external/local_xla/xla/tsl/framework/cpu_allocator_impl.cc:84] Allocation of 19660800 exceeds 10% of free system memory.
2026-02-01 17:47:11.219945: W external/local_xla/xla/tsl/framework/cpu_allocator_impl.cc:84] Allocation of 19660800 exceeds 10% of free system memory.



=== SCORES GLOBAUX (dataset-level) ===
Variable                  |   MAE global |    R2 global
---------------------------------------------------------
ptend_t                   |   1.1387e+01 |   6.7116e-02
ptend_q0001               |   9.1280e+00 |  -2.5178e+03
cam_out_NETSW             |   3.9877e+01 |   8.8796e-01
cam_out_FLWDS             |   5.7493e+00 |   9.9110e-01
cam_out_PRECSC            |   2.8807e+00 |   8.5202e-01
cam_out_PRECC             |   3.5260e+01 |   7.7384e-01
cam_out_SOLS              |   1.7661e+01 |   8.7417e-01
cam_out_SOLL              |   1.9969e+01 |   8.6958e-01
cam_out_SOLSD             |   9.2597e+00 |   8.7935e-01
cam_out_SOLLD             |   7.0668e+00 |   8.2339e-01
