This script evaluates trained deep learning models on preprocessed climate test datasets stored in Zarr format. It loads the test data (season_da, annual_da, and index_da), applies each trained model to generate predictions, and computes performance metrics (R², MAE, MSE) for the target variables (rx90p_anom, pr_anom). The metrics are written to a CSV file (one row per experiment), and the model predictions are optionally saved in Zarr format for further analysis or visualization.

In [33]:
from warnings import warn
import sys
import os
import pandas as pd
import numpy as np
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error

from models_NN import *

In [None]:
# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------
# Directory holding the preprocessed test stores
# (season_da.zarr, annual_da.zarr, index_da.zarr).
test_dataset_path = "/data/dl20-data/climate_operational/Victor_data/preprocessed_datasets_NN_new/test/"
# Root folder containing one sub-directory per trained experiment.
trained_models_basepath = "/home/vgarcia/experiments/NN_annual_new/"

# Prefix under which the per-experiment prediction stores are written.
prediction_path = "/data/dl20-data/climate_operational/Victor_data/predicted_datasets_NN_new/"
# CSV file collecting the test metrics of the evaluated experiments.
test_score_path = trained_models_basepath + "test_metrics.csv"

# ---------------------------------------------------------------------------
# Experiments to evaluate (sub-directory names below trained_models_basepath)
# ---------------------------------------------------------------------------
# Other experiment names noted by the author: "SmallUNet_Hist", "SmallUNet_era5_12batch"
experiment_names = [
    "SmallUNet_All",
    "SmallUNet_All_subset",
    "SmallUNet_Hist",
    "SmallUNet_Hist_subset",
    "SmallUNet_era5",
]

# ---------------------------------------------------------------------------
# Run settings
# ---------------------------------------------------------------------------
model_name = "SmallUNet"    # key looked up in model_dict
batch_size = 8              # batch size of the test DataLoader
test_mode = False           # True restricts the data to the years 2000-2001
overwrite_test = True       # NOTE(review): confirm how the evaluation loop consumes this flag
store_predictions = True    # True writes the predictions of each experiment to Zarr

In [35]:
# check inputs
# Registry of the available architectures, keyed by the name given in
# `model_name`. The value is a model instance that the evaluation loop moves
# to the target device and fills with the checkpoint weights.
model_dict = {"SmallUNet": SmallUNet()}

# Fail early, and say why, when the requested architecture is not registered.
if model_name not in model_dict:
    raise NotImplementedError(
        f"Unknown model_name '{model_name}'. Available models: {sorted(model_dict)}"
    )

In [None]:
# Evaluate every experiment on the test set: the preprocessed data is loaded
# once, then for each experiment the best checkpoint is restored, predictions
# are generated, metrics are appended to the CSV at `test_score_path` and the
# predictions are optionally written to a Zarr store.


def _open_first_variable(store_path):
    """Open the Zarr store at *store_path* and return its first data variable."""
    dataset = xr.open_zarr(store_path)
    return dataset[list(dataset.data_vars)[0]]


def _find_best_checkpoint(model_dir):
    """Return the path of the '*_best.pt' checkpoint inside *model_dir*.

    Candidates are sorted so the choice is deterministic if several files
    match (``os.listdir`` returns entries in arbitrary order).

    Raises:
        FileNotFoundError: if *model_dir* contains no '*_best.pt' file.
    """
    candidates = sorted(f for f in os.listdir(model_dir) if f.endswith('_best.pt'))
    if not candidates:
        raise FileNotFoundError(f"❌ No '_best.pt' file found in {model_dir}")
    return os.path.join(model_dir, candidates[0])


def _masked_metrics(y_true, y_pred, y_mask, target_names):
    """Compute MSE, MAE and R² per target over the cells where the mask is 1.

    Channel ``i`` (second axis) of the arrays corresponds to
    ``target_names[i]``. A target without any valid cell gets NaN metrics.

    Returns:
        dict: ``MSE_<name>``, ``MAE_<name>`` and ``R2_<name>`` entries, grouped
        in that order.
    """
    mse_results, mae_results, r2_results = {}, {}, {}
    for i, name in enumerate(target_names):
        valid = y_mask[:, i] == 1
        y_true_i = y_true[:, i][valid]
        y_pred_i = y_pred[:, i][valid]

        if y_true_i.size > 0:
            r2 = r2_score(y_true_i, y_pred_i)
            mae = mean_absolute_error(y_true_i, y_pred_i)
            mse = mean_squared_error(y_true_i, y_pred_i)
        else:
            r2, mae, mse = np.nan, np.nan, np.nan

        mse_results[f"MSE_{name}"] = mse
        mae_results[f"MAE_{name}"] = mae
        r2_results[f"R2_{name}"] = r2

    return {**mse_results, **mae_results, **r2_results}


# Load preprocessed datasets
season_da = _open_first_variable(os.path.join(test_dataset_path, "season_da.zarr"))
annual_da = _open_first_variable(os.path.join(test_dataset_path, "annual_da.zarr"))
index_da = _open_first_variable(os.path.join(test_dataset_path, "index_da.zarr"))

if test_mode:
    # Quick smoke run on a two-year slice of the data.
    print("WARNING: TEST model")
    index_da = index_da.sel(year=slice("2000", "2001"))
    season_da = season_da.sel(year=slice("2000", "2001"))
    annual_da = annual_da.sel(year=slice("2000", "2001"))

n_total = season_da.sizes['year']
indices = list(range(n_total))
test_dataset = XarrayENSODataset(season_da, index_da, annual_da, indices)
# No shuffling: predictions must stay in the order of the `year` coordinate.
test_loader = DataLoader(test_dataset, batch_size=batch_size)

# Loop invariants
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
target_names = ["rx90p_anom", "pr_anom"]

for experiment in experiment_names:
    print(f"Testing {experiment}")

    # Check whether this experiment already has a row in the metrics file.
    df_existing = None
    if os.path.exists(test_score_path):
        df_existing = pd.read_csv(test_score_path)
        if experiment in df_existing["Model"].values:
            if not overwrite_test:
                raise FileExistsError(f"Model '{experiment}' already exists in {test_score_path}")
            # overwrite_test: drop the stale row so it is replaced below.
            df_existing = df_existing[df_existing["Model"] != experiment]

    ### Loads weights of the model ###
    model = model_dict[model_name].to(device)

    trained_model_path = f"{trained_models_basepath}{experiment}/"
    best_model_path = _find_best_checkpoint(trained_model_path)

    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints. Consider `weights_only=True` if the checkpoint contents
    # allow it.
    checkpoint = torch.load(best_model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    # Inference mode: disables dropout and uses running batch-norm statistics.
    model.eval()
    print(f"✅ Loaded model from {trained_model_path}, trained until epoch {checkpoint['epoch']}")

    ### Run Predictions ###
    all_predictions = []
    all_targets = []
    all_masks = []

    with torch.no_grad():
        for x_maps, x_nino, batch_targets, batch_y_mask, _x_mask in tqdm(
            test_loader, desc="Predicting", leave=False
        ):
            x_maps, x_nino = x_maps.to(device), x_nino.to(device)

            predictions = model(x_maps, x_nino)
            all_predictions.append(predictions.cpu())
            all_targets.append(batch_targets)
            all_masks.append(batch_y_mask)

    # Convert predictions to NumPy arrays
    y_pred = torch.cat(all_predictions).numpy()
    y_true = torch.cat(all_targets).numpy()
    y_mask = torch.cat(all_masks).numpy()

    ### Calculate performance metrics and store them in csv ###
    results = {"Model": experiment}
    results.update(_masked_metrics(y_true, y_pred, y_mask, target_names))
    df_new = pd.DataFrame([results])

    if df_existing is not None:
        df_combined = pd.concat([df_existing, df_new], ignore_index=True)
        df_combined.to_csv(test_score_path, index=False)
        print(f"Appended new results for model '{experiment}' to {test_score_path}.")
    else:
        df_new.to_csv(test_score_path, index=False)
        print(f"Created {test_score_path} and saved results for model '{experiment}'.")

    ### Store predictions in zarr ###
    if store_predictions:
        print("Storing predictions")
        # Build the path in a separate variable: reassigning `prediction_path`
        # would make the names of later experiments pile up on earlier ones.
        experiment_prediction_path = f"{prediction_path}{experiment}_predicted.zarr"

        lat = annual_da['lat'].values
        lon = annual_da['lon'].values
        years = annual_da['year'].values
        variables = [0, 1]

        # Blank out cells where the target mask is 0 (ocean, per the original comment).
        y_pred_masked = np.where(y_mask, y_pred, np.nan)

        # Create xarray DataArray
        pred_da = xr.DataArray(
            y_pred_masked,
            dims=["year", "variable_index", "lat", "lon"],
            coords={
                "year": years,
                "variable_index": variables,
                "lat": lat,
                "lon": lon,
            },
            name="predicted_annual",
        )

        # One chunk per grid cell and variable, spanning all years.
        pred_da = pred_da.chunk({
            "year": -1,
            "lat": 1,
            "lon": 1,
            "variable_index": 1,
        })

        pred_da.to_zarr(experiment_prediction_path, mode="w")
        print("✅ Saved predictions")


  checkpoint = torch.load(best_model_path, map_location=device)


Testing SmallUNet_All
✅ Loaded model from /home/vgarcia/experiments/NN_annual_new/SmallUNet_All/, trained until epoch 22


Predicting:   0%|          | 0/6 [00:00<?, ?it/s]

                                                         

Appended new results for model 'SmallUNet_All' to /home/vgarcia/experiments/NN_annual_new/test_metrics.csv.
Storing predictions


  checkpoint = torch.load(best_model_path, map_location=device)


✅ Saved predictions
Testing SmallUNet_All_subset
✅ Loaded model from /home/vgarcia/experiments/NN_annual_new/SmallUNet_All_subset/, trained until epoch 13


Predicting:  33%|███▎      | 2/6 [01:57<03:56, 59.05s/it]