In [1]:
import os

import numpy as np
import pandas as pd
import xarray as xr

In [2]:
# Root folder that holds one sub-folder per training run. Set the
# EARTHFORMER_EXPERIMENTS_DIR environment variable to run the notebook on
# another machine without editing every path below.
EXPERIMENTS_DIR = os.environ.get(
    "EARTHFORMER_EXPERIMENTS_DIR",
    "/home/egauillard/extreme_events_forecasting/earthfomer_mediteranean/src/model/experiments",
)

# Runs whose predictions are evaluated; scores are reported in this order.
PREDICTION_RUNS = [
    "earthformer_era_20240807_180001_tp_coarse_t_s_gap_3_45",
    "earthformer_era_20240807_174626_tp_coarse_t_s_gap_3_43",
    "earthformer_era_20240806_163342_tp_coarse_t_s_gap_3_42",
    "earthformer_era_20240801_190230__tp_coarse_t_s_gap_3_0",
]
# Runs whose inference folder provides the climatology and the ground truth.
CLIMATOLOGY_RUN = "earthformer_era_20240807_180001_tp_coarse_t_s_gap_3_45"
GROUND_TRUTH_RUN = "earthformer_era_20240801_190230__tp_coarse_t_s_gap_3_0"


def inference_output(run_name, file_name):
    """Return the path of ``file_name`` inside the ``inference_plots`` folder of a run."""
    # Plain "/" joins keep the paths identical on every platform.
    return f"{EXPERIMENTS_DIR}/{run_name}/inference_plots/{file_name}"


prediction_files = [inference_output(run, "all_predictions.nc") for run in PREDICTION_RUNS]

climatology_file = inference_output(CLIMATOLOGY_RUN, "all_climatology.nc")
ground_truth_file = inference_output(GROUND_TRUTH_RUN, "all_ground_truths.nc")

## Evaluation

In [4]:
# The RPS aggregates the squared probability errors across K categories (K = 3 in this work),
# such as terciles, arranged in ascending order. The tercile bounds are determined from the average
# values over either 1-week or 2-week periods for each corresponding verification period.
# These tercile bounds are computed separately for each forecast model and for the observations (ERA5 data).
# The metric assesses how accurately the probability forecast predicts the actual observation category.

class ModelEvaluation:
    """Score forecasts against a climatology baseline with a tercile-based RPSS.

    For every data variable of the ground-truth dataset and every prediction
    dataset, the ranked probability score (RPS) is averaged over all time
    steps and turned into a skill score:
    ``RPSS = 1 - RPS_model / RPS_climatology``.
    """

    # Percentiles of the ground-truth field used as tercile bounds.
    TERCILE_PERCENTILES = (33.33, 66.67)

    def __init__(self, prediction_files, ground_truth_file, climatology_file, save_folder):
        """Open the input datasets and make sure the output folder exists.

        Args:
            prediction_files: Paths of the prediction datasets, one per model.
            ground_truth_file: Path of the ground-truth dataset.
            climatology_file: Path of the climatology (reference) dataset.
            save_folder: Folder that receives the CSV results; created if missing.
        """
        self.prediction_datasets = [xr.open_dataset(file) for file in prediction_files]
        self.ground_truth_dataset = xr.open_dataset(ground_truth_file)
        self.climatology_dataset = xr.open_dataset(climatology_file)
        self.save_folder = save_folder
        os.makedirs(save_folder, exist_ok=True)

        # Every data variable of the ground truth is evaluated.
        self.target_variables = list(self.ground_truth_dataset.data_vars)

    def calculate_rpss(self):
        """Compute one RPSS per target variable and prediction dataset.

        The result is stored in ``self.rpss`` as
        ``{variable: [rpss_of_model_0, rpss_of_model_1, ...]}``, following the
        order of the prediction datasets.
        """
        self.rpss = {var: [] for var in self.target_variables}
        n_lead_times = len(self.ground_truth_dataset.time)

        for var in self.target_variables:
            # Tercile bounds and observed categories depend only on the ground
            # truth, so they are computed once per variable and shared by the
            # climatology and every model.
            references = []
            for lead_time in range(n_lead_times):
                truth = self.ground_truth_dataset[var].isel(time=lead_time).values.flatten()
                categories = np.percentile(truth, self.TERCILE_PERCENTILES)
                references.append((categories, self._categorize_observations(truth, categories)))

            rps_climatology = self._mean_rps(self.climatology_dataset, var, references)

            for pred_dataset in self.prediction_datasets:
                rps_model = self._mean_rps(pred_dataset, var, references)
                self.rpss[var].append(1 - (rps_model / rps_climatology))

    def _mean_rps(self, dataset, var, references):
        """Average, over time steps, the RPS of ``dataset[var]`` against the observations.

        ``references`` holds one ``(tercile_bounds, observed_categories)`` pair
        per time step, in time order.
        """
        total = 0
        for lead_time, (categories, obs_categorical) in enumerate(references):
            field = dataset[var].isel(time=lead_time).values.flatten()
            probabilities = self._calculate_probabilities(field, categories)
            total += self._calculate_rps(probabilities, obs_categorical)
        return total / len(references)

    def _calculate_probabilities(self, data, categories):
        """Return tercile probabilities estimated from a flattened field.

        Each probability is the fraction of the points of ``data`` that fall in
        the category (lower: ``<= categories[0]``, middle:
        ``(categories[0], categories[1]]``, upper: ``> categories[1]``). The
        same three values are repeated on every row.

        Args:
            data: 1-D array of values.
            categories: The two tercile bounds, in ascending order.

        Returns:
            Array of shape ``(len(data), 3)``.
        """
        n_points = len(data)
        lower, upper = categories[0], categories[1]

        probabilities = np.zeros((n_points, 3))  # one column per tercile
        probabilities[:, 0] = np.sum(data <= lower) / n_points
        probabilities[:, 1] = np.sum((data > lower) & (data <= upper)) / n_points
        probabilities[:, 2] = np.sum(data > upper) / n_points
        return probabilities

    def _categorize_observations(self, observations, categories):
        """Return the tercile index (0, 1 or 2) of each observation.

        ``right=True`` makes the upper bound of a category inclusive, so a
        value equal to a bound lands in the same category as in
        ``_calculate_probabilities``.
        """
        return np.digitize(observations, bins=categories, right=True)

    def _calculate_rps(self, forecast_probabilities, obs_categorical):
        """Return the mean ranked probability score over all observations.

        For each observation the score is the sum over categories of the
        squared difference between the cumulative forecast probabilities and
        the cumulative one-hot encoding of the observed category.

        Args:
            forecast_probabilities: One row of category probabilities per observation.
            obs_categorical: Observed category index of each observation.

        Raises:
            ZeroDivisionError: If there is no observation to average over.
        """
        n_obs = len(obs_categorical)
        if n_obs == 0:
            raise ZeroDivisionError("cannot average the RPS over zero observations")

        # Only the rows that are paired with an observation are scored.
        forecast_cdf = np.cumsum(forecast_probabilities[:n_obs], axis=1)
        # Row k of the identity matrix is the one-hot encoding of category k.
        obs_cdf = np.cumsum(np.eye(forecast_cdf.shape[1])[obs_categorical], axis=1)

        per_observation = np.sum((forecast_cdf - obs_cdf) ** 2, axis=1)
        return np.sum(per_observation) / n_obs

    def save_rpss_to_csv(self):
        """Write the RPSS table to ``rpss_results.csv`` in the save folder.

        ``calculate_rpss`` must be called first because the table is built from
        ``self.rpss``. The file has one ``<variable>_rpss`` column per target
        variable and one row per prediction dataset, followed by a ``---``
        separator row and a ``Mean`` row holding the column averages.
        """
        df = pd.DataFrame(
            {f'{var}_rpss': self.rpss[var] for var in self.target_variables}
        )

        # Average each column while it still contains only scores.
        means = df.mean()

        # Visual separator between the per-model rows and the summary row.
        df.loc['---'] = '---'

        # Summary row with the column averages.
        df.loc['Mean'] = means

        df.to_csv(os.path.join(self.save_folder, 'rpss_results.csv'), index=True)

# Usage: compute the RPSS of every prediction file and export the table as CSV.
save_folder = './evaluation_results'

evaluator = ModelEvaluation(
    prediction_files=prediction_files,
    ground_truth_file=ground_truth_file,
    climatology_file=climatology_file,
    save_folder=save_folder,
)
evaluator.calculate_rpss()
evaluator.save_rpss_to_csv()