In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from torch.utils.data import Subset

from sklearn.model_selection import train_test_split

import xarray as xr
import os
import torch
from functools import reduce 

torch.device("cuda" if torch.cuda.is_available() else "cpu")

LOW_RES_SAMPLE_PATH = "data/ClimSim_low-res/train/"
LOW_RES_GRID_PATH = "data/ClimSim_low-res/ClimSim_low-res_grid-info.nc"
ZARR_PATH = "data/ClimSim_low-res.zarr"
NORM_PATH = "lib/ClimSim/preprocessing/normalizations/"

In [4]:
class ClimSimMLP(nn.Module):
    def __init__(self, input_dim=556, output_tendancies_dim=120, output_surface_dim=8):
        super(ClimSimMLP, self).__init__()
        
        # Hidden Layers: [768, 640, 512, 640, 640]
        self.layer1 = nn.Linear(input_dim, 768)
        self.layer2 = nn.Linear(768, 640)

        self.layer3 = nn.Linear(640, 512)
        self.layer4 = nn.Linear(512, 640)
        self.layer5 = nn.Linear(640, 640)
        

        self.last_hidden = nn.Linear(640, 128)
        
        # --- Output Heads ---
        self.head_tendencies = nn.Linear(128, output_tendancies_dim)
        # self.head_surface = nn.Linear(128, output_surface_dim)
        
        # LeakyReLU alpha=0.15
        self.activation = nn.LeakyReLU(0.15)
        
        for m in self.modules():
                if isinstance(m, nn.Linear):
                    nn.init.xavier_uniform_(m.weight)
                    if m.bias is not None:
                        nn.init.zeros_(m.bias)

    def forward(self, x):
        # Pass through the 5 main hidden layers
        x = self.activation(self.layer1(x))
        x = self.activation(self.layer2(x))
        x = self.activation(self.layer3(x))
        x = self.activation(self.layer4(x))
        x = self.activation(self.layer5(x))
        
        # Pass through the fixed 128 layer
        x = self.activation(self.last_hidden(x))
        
        # Output 1: Tendencies (Linear activation)
        out_linear = self.head_tendencies(x)
        
        # Output 2: Surface variables (ReLU activation)
        # out_relu = F.relu(self.head_surface(x))
        
        # Concatenate along the feature dimension (dim=1)
        return out_linear

In [1]:
import xarray as xr
import numpy as np
import torch 
from torch.utils.data import Dataset, Subset
import re

from lib import data


class ClimSimBase:
    def __init__(self, zarr_path, grid_path, norm_path, features, num_latlon = 384, normalize=True):
        self.ds = xr.open_zarr(zarr_path, chunks=None)
        self.features = features
        self.features_list = self.__get_features__()
        self.normalize_flag = normalize
        self.num_latlon = num_latlon
        self.grid = xr.open_dataset(grid_path, engine="netcdf4")
        
        self.input_mean = xr.open_dataset(os.path.join(norm_path, "inputs/input_mean.nc"), engine="netcdf4")
        self.input_std = xr.open_dataset(os.path.join(norm_path, "inputs/input_std.nc"), engine="netcdf4")
        self.input_max = xr.open_dataset(os.path.join(norm_path, "inputs/input_max.nc"), engine="netcdf4")
        self.input_min = xr.open_dataset(os.path.join(norm_path, "inputs/input_min.nc"), engine="netcdf4")
        self.output_scale = xr.open_dataset(os.path.join(norm_path, "outputs/output_scale.nc"), engine="netcdf4")

        self.grid['area_wgt'] = self.grid['area']/self.grid['area'].mean(dim = 'ncol')
        self.area_wgt = self.grid['area_wgt'].values

        self.input_vars = [v for v in self.features_list if 'in' in v]
        self.output_vars = [v for v in self.features_list if 'out' in v]

        self.grav    = 9.80616    # acceleration of gravity ~ m/s^2
        self.cp      = 1.00464e3  # specific heat of dry air   ~ J/kg/K
        self.lv      = 2.501e6    # latent heat of evaporation ~ J/kg
        self.lf      = 3.337e5    # latent heat of fusion      ~ J/kg
        self.lsub    = self.lv + self.lf    # latent heat of sublimation ~ J/kg
        self.rho_air = 101325/(6.02214e26*1.38065e-23/28.966)/273.15 # density of dry air at STP  ~ kg/m^3
        self.rho_h20 = 1.e3       # density of fresh water     ~ kg/m^ 3

        self.target_energy_conv = {'ptend_t':self.cp,
                            'ptend_q0001':self.lv,
                            'ptend_q0002':self.lv,
                            'ptend_q0003':self.lv,
                            'ptend_qn':self.lv,
                            'ptend_wind': None,
                            'cam_out_NETSW':1.,
                            'cam_out_FLWDS':1.,
                            'cam_out_PRECSC':self.lv*self.rho_h20,
                            'cam_out_PRECC':self.lv*self.rho_h20,
                            'cam_out_SOLS':1.,
                            'cam_out_SOLL':1.,
                            'cam_out_SOLSD':1.,
                            'cam_out_SOLLD':1.
                            }
        
        
        self.dp = None 
        self.pressure_grid = None

    def __get_features__(self):
        feat = np.concatenate([self.features["features"]["multilevel"], self.features["features"]["surface"]])
        target = np.concatenate([self.features["target"]["tendancies"], self.features["target"]["surface"]])
        return np.concatenate([feat, target])

    def _prepare_data(self, idx):
        x = self.process_list(self.input_vars, idx, is_input=True)
        y = self.process_list(self.output_vars, idx, is_input=False)
        return x, y

    def process_list(self, vars_list, idx, is_input=True):
        out_list = []
        n_geo = self.num_latlon # 384

        for var in vars_list:
            if "ptend" in var:
                # Cette fonction doit renvoyer du (Time, 384, 60)
                data = self._calculate_tendency_on_fly(var, idx)
                if data.ndim == 2: # Si (384, 60)
                    data = data[np.newaxis, :, :]
            
            else:
                da = self.ds[var].isel(sample=idx)
                
                # Redressement par nom de dimension Xarray
                if 'lev' in da.dims:
                    if "sample" in da.dims:
                        data = da.transpose('sample', 'ncol', 'lev').values
                    else:
                        data = da.transpose('ncol', 'lev').values[np.newaxis, :, :]
                else:
                    if "sample" in da.dims:
                        # Surface : (Time, 384)
                        data = da.values[:, :, np.newaxis]  # Ajouter une dimension lev=1
                    else:
                        # Surface : (Time, 384) -> (Time, 384, 1)
                        data = da.values[np.newaxis, :, np.newaxis]
            
            # 2. Normalisation (Maintenant data est garanti (N, 384, L))
            data = self._normalize_var(data, var, is_input=is_input)
            out_list.append(data.astype(np.float32))

        # 3. Concaténation et aplatissement
        combined = np.concatenate(out_list, axis=-1)
        return combined.reshape(-1, combined.shape[-1])

    def __len__(self):
        return self.ds.dims['sample']

    def _calculate_tendency_on_fly(self, var, idx):
        dt = 1200
        mapping = {
            'out_ptend_t': ('out_state_t', 'in_state_t'),
            'out_ptend_q0001': ('out_state_q0001', 'in_state_q0001'),
            'out_ptend_u': ('out_state_u', 'in_state_u'),
            'out_ptend_v': ('out_state_v', 'in_state_v'),
        }
        out_v, in_v = mapping[var]

        v_final = self.ds[out_v].isel(sample=idx)
        v_init  = self.ds[in_v].isel(sample=idx)

        # Fonction utilitaire: remettre en ordre (sample?, ncol, lev) si sample existe
        def to_array(da):
            dims = da.dims

            if 'ncol' not in dims or 'lev' not in dims:
                raise ValueError(f"{da.name}: dims inattendues {dims}, attendu ncol & lev")

            # Cas slice -> dims contiennent sample
            if 'sample' in dims:
                da = da.transpose('sample', 'ncol', 'lev')
                return da.values  # (sample, ncol, lev)

            # Cas int -> dims (lev, ncol) ou (ncol, lev)
            da = da.transpose('ncol', 'lev')
            return da.values[None, ...]  # (1, ncol, lev)

        vf = to_array(v_final)
        vi = to_array(v_init)

        return (vf - vi) / dt  # (time, ncol, lev)


    
    def _normalize_var(self, data, var_name, is_input=True):
        # data est (N, 384, L) où L est 1 ou 60
        short_name = re.sub(r'^(in_|out_)', '', var_name)
        
        if is_input:
            m = self.input_mean[short_name].values     # (L,)
            diff = (self.input_max[short_name].values - self.input_min[short_name].values) # (L,)
            
            # On redimensionne les stats en (1, 1, L) pour s'aligner sur data (N, 384, L)
            m = m.reshape(1, 1, -1)
            diff = diff.reshape(1, 1, -1)
            
            return (data - m) / (diff + 1e-15)
        else:
            scale = self.output_scale[short_name].values # (L,)
            return data * scale.reshape(1, 1, -1)
            
    def set_pressure_grid(self, input_data):
        '''
        Calcule la grille de pression 3D à partir de state_ps.
        Code directement issu de ClimSim original.
        '''
        self.ps_index = self._find_ps_index(self.features)
        state_ps = input_data[:, self.ps_index]
        if self.normalize_flag:
            state_ps = state_ps * (self.input_max['state_ps'].values - self.input_min['state_ps'].values) + self.input_mean['state_ps'].values
        
        state_ps = state_ps.reshape(-1, self.num_latlon)

        p1 = (self.grid['P0'] * self.grid['hyai']).values[:, None, None]
        p2 = self.grid['hybi'].values[:, None, None] * state_ps[None, :, :]
        
        self.pressure_grid = p1 + p2
        self.dp = (self.pressure_grid[1:61] - self.pressure_grid[0:60]).transpose((1, 2, 0))
    
    
    def denormalize_output(self, y_pred):
        """Dénormalise les prédictions."""

        full_scale_vector = [] # To vectorize we generate the full scale vector first
        for var in self.output_vars:
            short_name = re.sub(r'^(in_|out_)', '', var)
            scale = self.output_scale[short_name].values
            
            if 'ptend' in var:
                dim_size = 60
            elif 'lev' in self.ds[var].dims:
                dim_size = self.ds[var].sizes['lev']
            else:
                dim_size = 1
            
            if np.isscalar(scale) or scale.size == 1:
                scale_expanded = np.full(dim_size, scale)
            else:
                scale_expanded = scale
                
            full_scale_vector.append(scale_expanded)
        
        full_scale_vector = np.concatenate(full_scale_vector)

        y_denorm = y_pred / (full_scale_vector + 1e-8) 

        return (y_denorm).astype(np.float32)
    
    def calc_MAE(self, pred, target, avg_grid = True):
        '''
        calculate 'globally averaged' mean absolute error 
        for vertically-resolved variables, shape should be time x grid x level
        for scalars, shape should be time x grid

        returns vector of length level or 1
        '''
        assert pred.shape[1] == self.num_latlon
        assert pred.shape == target.shape
        mae = np.abs(pred - target).mean(axis = 0)
        if avg_grid:
            return mae.mean(axis = 0) # we decided to average globally at end
        else:
            return mae
    
    def calc_RMSE(self, pred, target, avg_grid = True):
        '''
        calculate 'globally averaged' root mean squared error 
        for vertically-resolved variables, shape should be time x grid x level
        for scalars, shape should be time x grid

        returns vector of length level or 1
        '''
        assert pred.shape[1] == self.num_latlon
        assert pred.shape == target.shape
        sq_diff = (pred - target)**2
        rmse = np.sqrt(sq_diff.mean(axis = 0)) # mean over time
        if avg_grid:
            return rmse.mean(axis = 0) # we decided to separately average globally at end
        else:
            return rmse

    def calc_R2(self, pred, target, avg_grid = True):
        '''
        calculate 'globally averaged' R-squared
        for vertically-resolved variables, input shape should be time x grid x level
        for scalars, input shape should be time x grid

        returns vector of length level or 1
        '''
        assert pred.shape[1] == self.num_latlon
        assert pred.shape == target.shape
        sq_diff = (pred - target)**2
        tss_time = (target - target.mean(axis = 0)[np.newaxis, ...])**2 # mean over time
        r_squared = 1 - sq_diff.sum(axis = 0)/tss_time.sum(axis = 0) # sum over time
        if avg_grid:
            return r_squared.mean(axis = 0) # we decided to separately average globally at end
        else:
            return r_squared
    
    def output_weighting(self, output):
        num_samples = output.shape[0]
        n_geo = self.num_latlon
        n_time = num_samples // n_geo
        
        # On reste sur le dictionnaire de base
        var_dict = {}
        
        # On récupère dp/g (le poids vertical)
        # Assure-toi que self.dp est bien calculé avant !
        dp_weight = self.dp / self.grav # (Time, 384, 60)

        # Exemple pour ptend_t (indices 0:60)
        ptend_t = output[:, 0:60].reshape(n_time, n_geo, 60)
        # 1. Dénormalisation (revenir aux unités physiques : K/s)
        ptend_t /= self.output_scale['ptend_t'].values[None, None, :]
        # 2. Conversion en W/m2 (poids vertical * chaleur spécifique)
        var_dict['ptend_t'] = ptend_t * dp_weight * self.cp
        
        # Exemple pour ptend_q0001 (indices 60:120)
        ptend_q = output[:, 60:120].reshape(n_time, n_geo, 60)
        ptend_q /= self.output_scale['ptend_q0001'].values[None, None, :]
        var_dict['ptend_q0001'] = ptend_q * dp_weight * self.lv
        
        # Pour les variables de surface (indices 120:128)
        # Elles sont déjà en W/m2 après dé-scaling, pas besoin de dp/g
        surface_vars = ['cam_out_NETSW', 'cam_out_FLWDS', 'cam_out_PRECSC', 
                        'cam_out_PRECC', 'cam_out_SOLS', 'cam_out_SOLL', 
                        'cam_out_SOLSD', 'cam_out_SOLLD']
        
        for i, var in enumerate(surface_vars):
            idx = 120 + i
            val = output[:, idx].reshape(n_time, n_geo)
            # Dé-scaling uniquement
            val /= self.output_scale[var].values
            var_dict[var] = val

        return var_dict
    
    def _find_ps_index(self, features_dict):
        """
        Calcule l'index de départ de 'in_state_ps' dans le vecteur d'entrée plat.
        Prend en compte que chaque variable multilevel occupe 60 colonnes.
        """
        current_index = 0
        
        # 1. Parcourir les variables multi-niveaux (60 niveaux chacune)
        for var in features_dict["features"]["multilevel"]:
            if var == "in_state_ps":
                return current_index
            current_index += 60
            
        # 2. Parcourir les variables de surface (1 niveau chacune)
        for var in features_dict["features"]["surface"]:
            if var == "in_state_ps":
                return current_index
            current_index += 1
            
        raise ValueError("Variable 'in_state_ps' non trouvée dans le dictionnaire FEATURES.")

class ClimSimPyTorch(ClimSimBase, Dataset):
    def __getitem__(self, idx):
        x_np, y_np = self._prepare_data(idx)
        return torch.from_numpy(x_np), torch.from_numpy(y_np)

    # On peut remettre ta méthode de split ici
    def train_test_split(self, test_size=0.2, seed=42, shuffle=True):
        n = len(self)
        indices = np.arange(n)
        if shuffle:
            rng = np.random.default_rng(seed)
            rng.shuffle(indices)
        split = int((1 - test_size) * n)
        return Subset(self, indices[:split]), Subset(self, indices[split:])
    
    def get_models_dims(self, variables_dict):
        features_tend = variables_dict["features"]["multilevel"]
        features_surf = variables_dict["features"]["surface"]
        
        target_tend = variables_dict["target"]["tendancies"]
        target_surf = variables_dict["target"]["surface"]

        def get_var_dim(var):
            # 1. Gérer les variables virtuelles (tendances calculées)
            if 'ptend' in var:
                # On mappe vers la variable d'état pour connaître la dimension 'lev'
                # ex: out_ptend_t -> out_state_t
                source_var = var.replace('ptend', 'state')
                return self.ds[source_var].sizes['lev']
            
            # 2. Gérer les variables réelles présentes dans le Zarr
            if 'lev' in self.ds[var].dims:
                return self.ds[var].sizes['lev']
            
            # 3. Variables de surface (scalaires)
            return 1

        in_tend_dim = sum([get_var_dim(var) for var in features_tend])
        in_surf_dim = len(features_surf)
        
        out_tend_dim = sum([get_var_dim(var) for var in target_tend])
        out_surf_dim = len(target_surf)

        return {
            "input_total": in_tend_dim + in_surf_dim,
            "output_tendancies": out_tend_dim,
            "output_surface": out_surf_dim
        }
    
        
    def _normalize_var(self, data, var_name, is_input=True):
        """Applique la normalisation selon que la variable est 3D ou de surface."""
        if not self.normalize_flag:
            return data

        short_name = re.sub(r'^(in_|out_)', '', var_name)

        if is_input:
            # Récupération des valeurs (L,) ou scalaire
            m = self.input_mean[short_name].values
            s = self.input_std[short_name].values
            
            # Redimensionnement en (1, 1, L) pour broadcast sur (N, 384, L)
            # -1 détecte automatiquement si c'est 1 (surface) ou 60 (profil)
            m_norm = m.reshape(1, 1, -1)
            s_norm = s.reshape(1, 1, -1)
            
            return (data - m_norm) / (s_norm + 1e-8)
        
        else:
            # Pour l'output (Target), on multiplie par le scale
            scale = self.output_scale[short_name].values
            
            # Redimensionnement identique (1, 1, L)
            scale_norm = scale.reshape(1, 1, -1)
            
            return data * scale_norm
        
    def get_batch_for_pytorch(self, idx, input_dim, output_dim):
        x_np, y_np = self._prepare_data(idx)

        x_np = x_np.reshape(-1, input_dim)
        y_np = y_np.reshape(-1, output_dim)

        return torch.from_numpy(x_np).float(), torch.from_numpy(y_np).float()          


    
    
class ClimSimKeras(ClimSimBase):
    def get_batch_for_keras(self, idx, input_dim, output_dim):
        x_np, y_np = self._prepare_data(idx)

        x_np = x_np.reshape(-1, input_dim)
        y_np = y_np.reshape(-1, output_dim)
        return x_np, y_np
    
    def get_sample_number(self):
        return self.ds.dims['sample']
    

In [2]:
BATCH_SIZE = 3072
N_EPOCHS = 10

FEATURES = {
    "features" :{
        "multilevel" : ["in_state_t", "in_state_q0001"],
        "surface" : ["in_state_ps", "in_pbuf_LHFLX", "in_pbuf_SHFLX", "in_pbuf_SOLIN"],
    },  
    "target" :{
        "tendancies" : ["out_ptend_t", "out_ptend_q0001"],
        "surface" : []
    }
}

dataset = ClimSimPyTorch(ZARR_PATH, LOW_RES_GRID_PATH, NORM_PATH, FEATURES, normalize=True)
model_dims = dataset.get_models_dims(FEATURES)

model = ClimSimMLP(input_dim=model_dims["input_total"], output_tendancies_dim=model_dims["output_tendancies"], output_surface_dim=model_dims["output_surface"])
optimizer = torch.optim.RAdam(model.parameters(), lr=0.001)
criterion = nn.MSELoss()

NameError: name 'ZARR_PATH' is not defined

In [8]:
train, test = dataset.train_test_split(test_size=0.2, seed=42)

train_loader = torch.utils.data.DataLoader(
    train, 
    batch_size=BATCH_SIZE, 
    shuffle=True,
    num_workers=4,
)

test_loader = torch.utils.data.DataLoader(
    test, 
    batch_size=BATCH_SIZE, 
    shuffle=False,
    num_workers=4,
)

  return self.ds.dims['sample']


In [None]:
for epoch in range(N_EPOCHS):
    train_loss = train_one_epoch(
        model, 
        train_loader, 
        optimizer, 
        criterion, 
        device="cpu",
        input_dim=model_dims["input_total"],
        output_dim=model_dims["output_tendancies"] + model_dims["output_surface"],
        )
    val_loss = evaluate_model(
        model, 
        test_loader, 
        criterion, 
        device="cpu",
        input_dim=model_dims["input_total"],
        output_dim=model_dims["output_tendancies"] + model_dims["output_surface"],
        )
    
    print(f"Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
    model_path = f"climsim_mlp_epoch{epoch+1}.pth"
    torch.save(model.state_dict(), model_path)
        

Training:   0%|          | 0/1 [00:01<?, ?batch/s]


KeyboardInterrupt: 

: 

In [6]:
import numpy as np
import re

models_dims = dataset.get_models_dims(FEATURES)

state_dict = torch.load("models/climsim_mlp_epoch10.pth", map_location=torch.device('cpu'))
model.load_state_dict(state_dict)
model.eval()

target_vars = FEATURES["target"]["tendancies"] + FEATURES["target"]["surface"]
input_size = models_dims["input_total"]
output_size = models_dims["output_tendancies"] + models_dims["output_surface"]

# 1. Initialisation des accumulateurs
stats = {}
for var in target_vars:
    short_name = re.sub(r'^(in_|out_)', '', var)
    size = 60 if var in FEATURES["target"]["tendancies"] else 1
    stats[short_name] = {
        "ss_res": np.zeros(size),      # Somme des carrés (pour RMSE et R2)
        "sum_abs_err": np.zeros(size), # Somme des erreurs absolues (pour MAE)
        "sum_y": np.zeros(size),       # Somme de y (pour variance R2)
        "sum_y_sq": np.zeros(size),    # Somme de y^2 (pour variance R2)
        "count": 0
    }

# 2. Boucle de test (Batch par Batch)
for i in range(100):
    x_batch, y_true_batch = dataset.get_batch_for_pytorch(i, input_size, output_size)
    dataset.set_pressure_grid(x_batch.numpy()) # Mise à jour de dp pour ce batch
    
    with torch.no_grad():
        preds_batch = model(x_batch)
        preds_batch = preds_batch.cpu().numpy()
    

    # Récupération des dictionnaires physiques {var: (Time, 384, L)}
    true_dict = dataset.output_weighting(y_true_batch)
    pred_dict = dataset.output_weighting(preds_batch)

    for var in true_dict.keys():
        if torch.is_tensor(true_dict[var]):
            true_dict[var] = true_dict[var].detach().cpu().numpy()
        if torch.is_tensor(pred_dict[var]):
            pred_dict[var] = pred_dict[var].detach().cpu().numpy()

    
    
    for var in target_vars:
        short_name = re.sub(r'^(in_|out_)', '', var)
        y_t = true_dict[short_name] 
        y_p = pred_dict[short_name]
        
        # Différences
        diff = y_t - y_p
        
        # Accumulation (on réduit sur Time et Grille : axes 0 et 1)
        stats[short_name]["ss_res"] += np.sum(diff**2, axis=(0, 1))
        stats[short_name]["sum_abs_err"] += np.sum(np.abs(diff), axis=(0, 1))
        stats[short_name]["sum_y"] += np.sum(y_t, axis=(0, 1))
        stats[short_name]["sum_y_sq"] += np.sum(y_t**2, axis=(0, 1))
        stats[short_name]["count"] += y_t.shape[0] * y_t.shape[1]

# 3. Calcul final et Affichage
print(f"\n{'Variable':<25} | {'MAE (W/m2)':<12} | {'RMSE (W/m2)':<12} | {'R2 Global':<10}")
print("-" * 75)

for var in target_vars:
    short_name = re.sub(r'^(in_|out_)', '', var)
    s = stats[short_name]
    n = s["count"]
    
    # MAE par niveau et moyenne
    mae_levels = s["sum_abs_err"] / n
    final_mae = np.mean(mae_levels)
    
    # R2 par niveau et moyenne
    ss_tot = s["sum_y_sq"] - (s["sum_y"]**2 / n)
    r2_levels = 1 - (s["ss_res"] / (ss_tot + 1e-15))
    final_r2 = np.mean(r2_levels)
    
    print(f"{var:<25} | {final_mae:.4e}| {final_r2:.4f}")

  state_dict = torch.load("models/climsim_mlp_epoch10.pth", map_location=torch.device('cpu'))
  val /= scale[None, None, :] if is_prof else scale
  if is_prof: val *= dp
  val *= self.area_wgt[None, :, None] if is_prof else self.area_wgt[None, :]



Variable                  | MAE (W/m2)   | RMSE (W/m2)  | R2 Global 
---------------------------------------------------------------------------
out_ptend_t               | 1.1484e+02| -77.8710
out_ptend_q0001           | 1.1205e+02| -3291028786516.5146


In [11]:
model.eval()
all_preds = []
all_targets = []

with torch.no_grad():
    for inputs, targets in test_loader:
        # Move to device
        inputs, targets = inputs.to(device), targets.to(device)
        
        # Flatten the ncol dimension into the batch dimension for the MLP
        # inputs shape: [Batch, 384, Input_Dim] -> [Batch*384, Input_Dim]
        inputs = inputs.view(-1, inputs.shape[-1])
        
        outputs = model(inputs)
        
        all_preds.append(outputs.cpu().numpy())
        all_targets.append(targets.view(-1, targets.shape[-1]).cpu().numpy())

# Combine everything
final_preds = np.concatenate(all_preds, axis=0)
final_targets = np.concatenate(all_targets, axis=0)

# IMPORTANT: Denormalize BEFORE calculating physical metrics
preds_phys = dataset.denormalize_output(final_preds)
targets_phys = dataset.denormalize_output(final_targets)

# Use your built-in MAE/RMSE
# Note: You might need to reshape back to (Time, 384, Var) for calc_MAE 
# if your functions expect the spatial dimension separated.

KeyboardInterrupt: 