In [94]:
import os
# GPU memory configuration for the JAX/XLA and PyTorch allocators.
# NOTE(review): these variables are presumably read when the frameworks
# initialise, so they must be set before the wofscast imports below pull in
# JAX / PyTorch -- confirm against the JAX GPU memory allocation docs.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'true'
# Set this lower, to allow for PyTorch Model to fit into memory
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.90' 
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

import sys
# Put the directory two levels above the current working directory at the
# front of sys.path so the local `wofscast` package is importable. This
# depends on where the notebook kernel was started.
package_path = os.path.dirname(os.path.dirname(os.getcwd())) 
sys.path.insert(0, package_path)
from glob import glob 

from wofscast.model import WoFSCastModel
from wofscast.data_generator import load_chunk, dataset_to_input, add_local_solar_time
from wofscast.common.wofs_data_loader import WoFSDataLoader
from wofscast.common.wofs_analysis_loader import WoFSAnalysisLoader
from wofscast.common.mrms_data_loader import MRMSDataLoader 

# For the diffusion model. 
from wofscast.diffusion import DiffusionModel 

# Utils for loading data, plotting, animations. 
from wofscast.common.helpers import (get_case_date, 
                                     to_datetimes, 
                                     get_qpe_datetimes, 
                                     border_difference_check,
                                     compute_nmep, 
                                     convert_rain_amount_to_inches, 
                                     convert_T2_K_to_F,
                                     _border_mask, 
                                     parse_arguments, 
                                     load_configuration
                                    )
from dataclasses import dataclass
import argparse
from scipy.ndimage import uniform_filter

# For plotting. 
import numpy as np
import pandas as pd
import xarray as xr
from datetime import datetime, timedelta

import itertools 
from tqdm import tqdm

import sys
sys.path.insert(0, '/home/monte.flora/python_packages/MontePython')
import monte_python

In [99]:
# This configuration class contains all the user settings required to run this notebook.
#
# NOTE(review): none of the attributes below carry a type annotation, so
# @dataclass registers no fields. They are plain class attributes shared by
# every instance and the generated __init__ accepts no arguments; to change a
# setting, edit it here or assign to the attribute after instantiation.

@dataclass
class EvaluatorConfig :    
    # Path to the WoFSCast model weights.
    model_path = '/work/cpotvin/WOFSCAST/model/wofscast_test_v178.npz'    
    # Forecast time step; presumably minutes, given `60 // timestep` below.
    timestep = 10 
    steps_per_hour = 60 // timestep # 60 min / `timestep` min per step
    # Forecast length in hours and the resulting number of steps.
    hours = 1
    n_steps = int(steps_per_hour * hours)
    # Case year and ensemble member -- presumably consumed by WoFSDataLoader,
    # which receives this config object; TODO confirm.
    year = '2021'
    mem = 9 
    # Forwarded to MRMSDataLoader as `resize_domain`.
    resize = True
    # When True the model is loaded with tiling=(2, 2).
    full_domain = False 
    # Number of forecast timesteps scored; sets the length of every
    # per-timestep accumulator in the evaluator.
    # NOTE(review): n_steps above works out to 6 while n_times is 12 -- confirm
    # which one is intended to describe the forecast length.
    n_times = 12 
    # Reflectivity thresholds used for object identification.
    wofs_dbz_thresh = 47 # Same for both WoFS and WoFSCast 
    mrms_dbz_thresh = 40 
    # Passed to the object matcher as both cent_dist_max and min_dist_max.
    matching_dist = 7
    # Minimum object area enforced by the object quality control.
    min_area = 12 
    # Side length of the (square) domain; used for the border mask and for
    # the MRMS loader.
    domain_size = 150 
    # Variables scored with the fractions skill score and their thresholds.
    # 25.4/2 is presumably half an inch expressed in mm -- TODO confirm units.
    FSS_vars = ['COMPOSITE_REFL_10CM', 'RAIN_AMOUNT']
    FSS_thres = {'COMPOSITE_REFL_10CM': 40, 'RAIN_AMOUNT': 25.4/2}
    # Neighbourhood widths (in grid points) for the FSS uniform filter.
    window_list = [7, 15, 27]
    

# Instantiate the configuration that is handed to the evaluator below.
evaluator_config = EvaluatorConfig() 

In [141]:
def mean_preserving_time(x: xr.DataArray) -> xr.DataArray:
    """Average ``x`` over every dimension except ``'time'``, skipping NaNs.

    Parameters:
    x: Array whose ``dims`` may include ``'time'``.

    Returns:
    The NaN-skipping mean of ``x`` taken over all non-time dimensions.
    """
    reduce_dims = []
    for dim_name in x.dims:
        if dim_name != 'time':
            reduce_dims.append(dim_name)
    return x.mean(reduce_dims, skipna=True)

class ThunderScoreEvaluator:
    """A generalized class for evaluating thunderstorm-scale forecasts over multiple timesteps.

    For every forecast case the evaluator accumulates, per forecast timestep:
      * the RMSE of each target variable (full domain with the borders
        masked out, and restricted to points where composite reflectivity
        exceeds 3 in either dataset),
      * object-matching contingency counts (hits / misses / false alarms)
        for each requested pair of datasets, and
      * the numerator and denominator of the fractions skill score (FSS).
    """

    # Dataset pairs compared when the caller does not supply any.
    DEFAULT_COMPARISON_PAIRS = (('predictions', 'targets'),
                                ('predictions', 'mrms'),
                                ('targets', 'mrms'))

    def __init__(self, evaluator_config, 
                 methods=['object_matching'], 
                 preprocess_fn=add_local_solar_time):
        """
        Initializes the evaluator with configuration and methods.

        Parameters:
        evaluator_config: Configuration for the evaluation process.
        methods: List of method names (attributes of this class) used to
                 compare two datasets.
        preprocess_fn: Preprocessing function to be applied on data before evaluation.
        """
        self.evaluator_config = evaluator_config
        # Copy so neither the shared default list nor the caller's list is aliased.
        self.methods = list(methods)
        self.preprocess_fn = preprocess_fn 
        self.cached_predictions = {}  # Cache for storing predictions

        # Contingency-table counts accumulated for every comparison.
        self.obj_match_metrics = ['hits', 'misses', 'false_alarms']

        matcher = monte_python.ObjectMatcher(cent_dist_max = self.evaluator_config.matching_dist, 
                                     min_dist_max = self.evaluator_config.matching_dist, 
                                     time_max=0, 
                                     score_thresh=0.2, 
                                     one_to_one = True)

        self.obj_verifier = monte_python.ObjectVerifier(matcher)  # Verifier for object-based metrics

        self.qcer = monte_python.QualityControler()
        self.qc_params = [('min_area', self.evaluator_config.min_area)]

        # Create the border mask once; it is slow to constantly recreate.
        self.BORDER_MASK = _border_mask((self.evaluator_config.domain_size, 
                                         self.evaluator_config.domain_size), N=5)  # Adjust N as needed

    def _load_model(self):
        """Load the prediction model and record its task config on ``self``."""
        model = WoFSCastModel()
        if self.evaluator_config.full_domain:
            model.load_model(self.evaluator_config.model_path, **{'tiling': (2, 2)})
        else:    
            model.load_model(self.evaluator_config.model_path)

        # Set the task config.
        self.task_config = model.task_config
        return model     

    def _load_inputs_targets_forcings(self, data_path):
        """Load inputs, targets, and forcings for one case via ``self.data_loader``.""" 
        inputs, targets, forcings = self.data_loader.load_inputs_targets_forcings(data_path)
        return inputs, targets, forcings 

    def _load_mrms_data(self, datetime_rng):
        """Load MRMS data for the given datetime range.

        Returns:
        The loaded MRMS array, or None when the files cannot be read
        (an OSError is reported with a printed message, not raised).
        """
        try:
            loader = MRMSDataLoader(
                self.data_loader.case_date, 
                datetime_rng, 
                domain_size=self.evaluator_config.domain_size, 
                resize_domain=self.evaluator_config.resize
            )
            mrms_dz = loader.load()  # Shape: (NT, NY, NX)
        except OSError:
            print(f'Unable to load MRMS data for {datetime_rng}')
            return None

        return mrms_dz 

    def _load_wofs_analysis_data(self):
        """Placeholder for loading WoFS analysis data (NOT FINISHED)."""
        pass

    def _init_results_dict(self, comparison_pairs=None):
        """Create the zero-filled accumulators for every metric.

        Parameters:
        comparison_pairs: Iterable of (name_1, name_2) dataset pairs that
                          will be compared. Defaults to
                          DEFAULT_COMPARISON_PAIRS when omitted.

        Returns:
        A dict holding one array of length ``n_times`` per RMSE variable,
        per object-matching count, and per FSS variable/window.
        """
        # Previously this read a notebook-level global named
        # `comparison_pairs`; it is now an explicit argument.
        if comparison_pairs is None:
            comparison_pairs = self.DEFAULT_COMPARISON_PAIRS

        n_times = self.evaluator_config.n_times

        results_dict = {
            'Full Domain': {v: np.zeros((n_times,)) for v in self.target_vars},
            'Convective Regions': {v: np.zeros((n_times,)) for v in self.target_vars},
        }

        # Initialize the contingency table metric storage. 
        self.obj_match_metrics = ['hits', 'misses', 'false_alarms']
        obj_match_keys = [f'{pair[0]}_vs_{pair[1]}_{method}' 
                          for pair in comparison_pairs 
                          for method in self.methods] 
        for m, s in itertools.product(self.obj_match_metrics, obj_match_keys):
            results_dict[f'{s}_{m}'] = np.zeros((n_times))

        for v in self.evaluator_config.FSS_vars: 
            results_dict[f'{v}_FSS_numer'] = {w: np.zeros((n_times)) 
                                          for w in self.evaluator_config.window_list}
            results_dict[f'{v}_FSS_denom'] = {w: np.zeros((n_times)) 
                                          for w in self.evaluator_config.window_list}

        return results_dict

    def score(self, data_paths, 
              comparison_pairs=[('predictions', 'targets'), 
                                ('predictions', 'mrms'),
                                ('targets', 'mrms')]
             ):
        """
        Evaluate the model and compute scores for multiple datasets over time.

        Parameters:
        data_paths: List of data paths for the predictions and targets.
        comparison_pairs: A list of tuples where each tuple specifies a pair of datasets to compare.
        Example: [('predictions', 'targets'), ('predictions', 'mrms'), ('targets', 'mrms')]

        Returns:
        An EvaluationResults wrapping the accumulated statistics. RMSE values
        are averaged over the cases; contingency counts and FSS terms are sums.
        Comparisons involving MRMS are skipped for cases whose MRMS data
        could not be loaded.
        """
        # Load model first to get the task config for 
        # loading the input, target, and forcing datasets below.
        model = self._load_model()

        self.data_loader = WoFSDataLoader(
            self.evaluator_config, 
            self.task_config, 
            self.preprocess_fn, 
            load_ensemble=False
        )  

        self.target_vars = self.task_config.target_variables

        # Create an empty results_dict keyed by the pairs requested here.
        results_dict = self._init_results_dict(comparison_pairs)

        # Used to normalize the accumulated RMSE. 
        n_cases = len(data_paths)

        for path in tqdm(data_paths, desc='Evaluating Model'):
            print(f"Evaluating {path}...")

            datetime_rng = to_datetimes(path, n_times=self.evaluator_config.n_times+2)

            # Load inputs, targets, and forcings
            inputs, targets, forcings = self._load_inputs_targets_forcings(path)
            predictions = model.predict(inputs, targets, forcings, replace_bdry=True)
            predictions = predictions.transpose('batch', 'time', 'level', 'lat', 'lon')
            targets = targets.transpose('batch', 'time', 'level', 'lat', 'lon')

            predictions = predictions.isel(batch=0)
            targets = targets.isel(batch=0)

            # Load MRMS data for the forecast time series (None if unavailable).
            # NOTE(review): datetime_rng holds n_times + 2 entries while mrms_dz
            # is indexed below with the same t as predictions/targets -- confirm
            # the two time axes are aligned.
            mrms_dz = self._load_mrms_data(datetime_rng)

            # Compute the RMSE statistics (can be computed all at once). 
            results_dict = self.accumulate_rmse(targets, predictions, results_dict)

            # Evaluate each time step for other metrics. 
            for t in range(self.evaluator_config.n_times):
                # Extract data for the current timestep
                timestep_datasets = {
                    'predictions': predictions.isel(time=t),
                    'targets': targets.isel(time=t),
                    'mrms': mrms_dz[t, :, :] if mrms_dz is not None else None
                }

                # Perform comparisons for each dataset pair 
                # Only used for the object matching statistics 
                # at the moment. 
                for pair in comparison_pairs:
                    dataset_1 = timestep_datasets[pair[0]]
                    dataset_2 = timestep_datasets[pair[1]]
                    results_dict = self._compare_datasets(t, pair[0], pair[1], dataset_1, dataset_2, 
                                                          results_dict)

                # Calculate FSS for each variable and window
                for var in self.evaluator_config.FSS_vars:
                    for window in self.evaluator_config.window_list: 
                        numer, denom = self.fractions_skill_score(timestep_datasets['predictions'][var], 
                                                                  timestep_datasets['targets'][var], 
                                                                  window, 
                                                                  self.evaluator_config.FSS_thres[var])
                        results_dict[f'{var}_FSS_numer'][window][t] += numer
                        results_dict[f'{var}_FSS_denom'][window][t] += denom

        # Turn the summed per-case RMSE into a mean over cases.
        for key in ['Full Domain', 'Convective Regions']:
            for v in self.target_vars:
                results_dict[key][v]/=n_cases

        return EvaluationResults(results_dict, self.evaluator_config)

    def _compare_datasets(self, t, name_1, name_2, dataset_1, dataset_2, results_dict):
        """
        Generalized comparison between two datasets using the specified methods.

        Parameters:
        t: The current timestep being evaluated.
        name_1: Name of the first dataset (e.g., 'predictions').
        name_2: Name of the second dataset (e.g., 'targets').
        dataset_1: The actual data for the first dataset (None if unavailable).
        dataset_2: The actual data for the second dataset (None if unavailable).
        results_dict: Dictionary to store results for each comparison.

        Returns:
        Updated results_dict with the comparison results. The dict is
        returned untouched when either dataset is None.

        Raises:
        AttributeError: If a name in ``self.methods`` is not a callable
                        attribute of this class.
        """
        # A missing dataset (e.g. MRMS files that failed to load) cannot be
        # compared; leave the accumulators unchanged instead of crashing.
        if dataset_1 is None or dataset_2 is None:
            return results_dict

        # Perform comparisons
        for method in self.methods:
            comparison_fn = getattr(self, method, None)
            if not callable(comparison_fn):
                raise AttributeError(
                    f"{type(self).__name__} has no comparison method '{method}'")

            result = comparison_fn(dataset_1, dataset_2, name_1, name_2, results_dict)

            # Accumulate the hits, false alarms, and misses
            for key in self.obj_match_metrics:
                results_dict[f'{name_1}_vs_{name_2}_{method}_{key}'][t] += result[key]

        return results_dict


    def _object_id(self, dataset, dataset_type=None): 
        """
        Identifies objects in the dataset using a threshold based on the dataset type.

        Parameters:
        dataset: The input dataset, which could be an xarray object holding
                 'COMPOSITE_REFL_10CM' or a numpy array.
        dataset_type: A string indicating the type of dataset ('mrms', 'predictions', 'targets').
                  If None, the function will infer the type based on the dataset.

        Returns:
        labels: Labeled objects in the dataset.
        props: Properties of the identified objects.
        """
        is_raw_array = isinstance(dataset, np.ndarray)

        # MRMS (or any bare numpy array) uses the MRMS threshold; everything
        # else uses the WoFS threshold.
        if dataset_type == 'mrms' or is_raw_array:
            thresh = self.evaluator_config.mrms_dbz_thresh
        else:
            thresh = self.evaluator_config.wofs_dbz_thresh

        # Numpy arrays are already the reflectivity field; otherwise pull it out.
        data = dataset if is_raw_array else dataset['COMPOSITE_REFL_10CM']

        # Apply the object identification process
        labels, props = monte_python.label(
            input_data=data,
            method='single_threshold',
            return_object_properties=True, 
            params={'bdry_thresh': thresh}
        )

        # Apply QC'ing 
        labels, props = self.qcer.quality_control(
            data, labels, props, self.qc_params)

        return labels, props


    def object_matching(self, dataset_1, dataset_2, name_1, name_2, results_dict):
        """Perform object matching between two datasets.

        Returns:
        A dict with the 'hits', 'false_alarms' and 'misses' counts read from
        the object verifier, which is reset afterwards. ``results_dict`` is
        accepted for signature compatibility and is not used.
        """
        labels_1, props_1 = self._object_id(dataset_1, name_1)
        labels_2, props_2 = self._object_id(dataset_2, name_2)

        # NOTE(review): the second dataset is passed first, presumably because
        # the verifier expects (truth, forecast) -- confirm against monte_python.
        self.obj_verifier.update_metrics(labels_2, labels_1)
        result = {key: getattr(self.obj_verifier, f"{key}_") for key in ["hits", "false_alarms", "misses"]}
        self.obj_verifier.reset_metrics()

        return result

    def accumulate_rmse(self, dataset_1, dataset_2, results_dict):
        """Add the per-timestep RMSE between two datasets to ``results_dict``.

        Both the border-masked full-domain RMSE and the RMSE restricted to
        points where either dataset has composite reflectivity > 3 are
        accumulated for every target variable.
        """
        # The convective mask is the same for every variable, so build it once:
        # points where comp. refl > 3 in either dataset.
        refl_mask = ((dataset_1['COMPOSITE_REFL_10CM'] > 3) 
                     | (dataset_2['COMPOSITE_REFL_10CM'] > 3))

        for var in self.target_vars:

            # Compute full domain RMSE while ignoring borders and preserving the time dimension
            rmse = self.rmse_ignoring_borders(dataset_1[var], dataset_2[var])

            # Apply the mask and compute RMSE in convection while preserving the time dimension
            rmse_conv = self.rmse_in_convection(dataset_1[var], dataset_2[var], refl_mask)

            # Accumulate RMSE values
            results_dict['Full Domain'][var] += rmse
            results_dict['Convective Regions'][var] += rmse_conv

        return results_dict

    def rmse_ignoring_borders(self, predictions, targets):
        """Return the per-timestep RMSE with the domain borders excluded."""
        border_mask = self.BORDER_MASK

        # Broadcast the mask to match the shape of predictions/targets if necessary
        if border_mask.shape != predictions.shape:
            border_mask = np.broadcast_to(border_mask, predictions.shape)

        # Set the errors at the borders to NaN
        err = (predictions - targets)**2
        err = xr.where(border_mask, np.nan, err)  # Apply the border mask

        # Compute mean squared error while preserving the 'time' dimension
        mse = mean_preserving_time(err)

        # Calculate the RMSE
        rmse = np.sqrt(mse)

        return rmse

    def rmse_in_convection(self, predictions, targets, refl_mask):
        """Return the per-timestep RMSE over the points where ``refl_mask`` is True."""
        # Set the errors outside the reflectivity mask to NaN
        err = (predictions - targets)**2
        err = xr.where(refl_mask, err, np.nan)  # Apply the refl mask

        # Compute mean squared error while preserving 'time' dimension
        mse = mean_preserving_time(err)

        # Calculate the RMSE
        rmse = np.sqrt(mse)
        return rmse

    def fractions_skill_score(self, dataset_1, dataset_2, window, thresh):
        """Compute the two terms of the fractions skill score (FSS).

        Parameters:
        dataset_1, dataset_2: Fields to compare.
        window: Width of the uniform (neighbourhood) filter.
        thresh: Exceedance threshold used to binarize both fields.

        Returns:
        (numer, denom): the summed squared difference of the neighbourhood
        fractions and the summed squares of both fraction fields. The FSS
        itself is ``1 - numer / denom``, formed after accumulation.
        """
        binary_pred = (dataset_1 >= thresh).astype(float)
        binary_true = (dataset_2 >= thresh).astype(float)

        NP_pred = uniform_filter(binary_pred, window, mode='constant')
        NP_true = uniform_filter(binary_true, window, mode='constant')

        numer = ((NP_pred - NP_true)**2).sum()
        denom = (NP_pred**2 + NP_true**2).sum()

        return numer, denom


class EvaluationResults:
    """Converts the statistics accumulated by the evaluator into pandas DataFrames."""

    def __init__(self, results_dict, evaluator_config): 
        """
        Parameters:
        results_dict: Dictionary of accumulated RMSE, FSS and contingency-table values.
        evaluator_config: Configuration providing ``FSS_vars`` and ``window_list``.
        """
        self.results_dict = results_dict
        self.evaluator_config = evaluator_config 

    @staticmethod
    def _safe_divide(numer, denom):
        """Element-wise ``numer / denom`` that yields NaN wherever ``denom`` is 0."""
        numer, denom = np.broadcast_arrays(np.asarray(numer, dtype=float),
                                           np.asarray(denom, dtype=float))
        out = np.full(denom.shape, np.nan)
        # Entries where the denominator is zero keep the NaN pre-filled above.
        np.divide(numer, denom, out=out, where=denom != 0)
        return out

    def to_fss_dataframe(self, grid_spacing_km=3):
        """Build a table of fractions skill scores.

        Parameters:
        grid_spacing_km: Grid spacing used to label each neighbourhood
                         window in km (default 3).

        Returns:
        pd.DataFrame with one row per timestep and one column per
        variable/window named '<var>_<window * grid_spacing_km>km'. The score
        is NaN where the FSS denominator is zero.
        """
        fss_data = {}
        for var in self.evaluator_config.FSS_vars:
            for window in self.evaluator_config.window_list: 
                se = self.results_dict[f'{var}_FSS_numer'][window]
                potential_se = self.results_dict[f'{var}_FSS_denom'][window]
                column = f'{var}_{window * grid_spacing_km}km'
                fss_data[column] = 1 - self._safe_divide(se, potential_se)

        return pd.DataFrame(fss_data)

    def to_rmse_dataframe(self):
        """Return the accumulated RMSE values as a hierarchically indexed DataFrame."""
        return self.rmse_dict_to_dataframe(self.results_dict)

    def to_parquet(self, df, path = '/work/mflora/wofs-cast-data/verification_results/results.parquet'):
        """Write ``df`` to ``path`` as parquet and return a confirmation message."""
        df.to_parquet(path)
        return f'Saved dataframe to {path}'

    def to_json(self, df, path):
        """Write ``df`` to ``path`` as JSON and return a confirmation message."""
        df.to_json(path)
        return f'Saved dataframe to {path}'

    def replace_zeros(self, data): 
        """Return ``data`` with exact zeros replaced by 1e-5."""
        return np.where(data==0, 1e-5, data)

    def rmse_dict_to_dataframe(self, rmse_dict):
        """
        Convert a nested dictionary of xarray.DataArray objects to a pandas DataFrame.

        Only the 'Full Domain' and 'Convective Regions' entries are read, and
        values that are not xarray.DataArray instances are skipped.

        Parameters:
        - rmse_dict: dict, nested dictionary with RMSE values

        Returns:
        - pd.DataFrame, DataFrame with hierarchical indexing
          (Category, Variable, Time) and a single 'RMSE' column
        """
        rows = []

        for category in ('Full Domain', 'Convective Regions'):
            if category not in rmse_dict:
                continue

            for variable, data_array in rmse_dict[category].items():
                # Ensure data_array is an xarray.DataArray
                if not isinstance(data_array, xr.DataArray):
                    continue

                values = data_array.values
                if 'time' in data_array.coords:
                    timesteps = data_array.coords['time'].values
                else:
                    timesteps = range(len(values))

                rows.extend((category, variable, timestep, value)
                            for timestep, value in zip(timesteps, values))

        df = pd.DataFrame(rows, columns=['Category', 'Variable', 'Time', 'RMSE'])

        # Set hierarchical index
        return df.set_index(['Category', 'Variable', 'Time'])

    def to_contigency_table_dataframe(self):
        """Calculate object-based metrics like POD, SR, CSI, and FB.

        Returns:
        pd.DataFrame with one row per timestep and, for every comparison
        stored in ``results_dict``, the columns '<pair>_POD', '<pair>_SR',
        '<pair>_CSI' and '<pair>_FB'. A metric is NaN where its denominator
        is zero (no objects to score).
        """
        suffix = '_hits'
        subkeys = [k[:-len(suffix)] for k in self.results_dict 
                   if '_vs_' in k and k.endswith(suffix)]

        data = {}
        for subkey in subkeys:
            hits = np.asarray(self.results_dict[f'{subkey}_hits'], dtype=float)
            misses = np.asarray(self.results_dict[f'{subkey}_misses'], dtype=float)
            false_alarms = np.asarray(self.results_dict[f'{subkey}_false_alarms'], dtype=float)

            pod = self._safe_divide(hits, hits + misses)
            sr = self._safe_divide(hits, hits + false_alarms)
            csi = self._safe_divide(hits, hits + misses + false_alarms)
            # Frequency bias taken directly from the counts (equals POD / SR),
            # so it stays defined when there are no hits.
            fb = self._safe_divide(hits + false_alarms, hits + misses)

            label = subkey.replace('_object_matching', '')

            data[f'{label}_POD'] = pod
            data[f'{label}_SR'] = sr
            data[f'{label}_CSI'] = csi
            data[f'{label}_FB'] = fb

        return pd.DataFrame(data)
    

In [142]:
# Single test case: one WoFS forecast file under the 2021 sub-directory.
base_path = '/work/mflora/wofs-cast-data/datasets_2hr_zarr/'
fname = 'wrfwof_2021-05-15_020000_to_2021-05-15_041000__10min__ens_mem_09.zarr'
case_dir = os.path.join(base_path, '2021')
data_paths = [os.path.join(case_dir, fname)]

In [143]:
# Dataset pairs to compare with object matching at every timestep.
comparison_pairs = [
    ('predictions', 'targets'),
    ('predictions', 'mrms'),
    ('targets', 'mrms'),
]

# Build the evaluator, then score every case in data_paths.
evaluator = ThunderScoreEvaluator(evaluator_config)
results = evaluator.score(data_paths, comparison_pairs)

Evaluating Model:   0%|                                                                                                                      | 0/1 [00:00<?, ?it/s]

Evaluating /work/mflora/wofs-cast-data/datasets_2hr_zarr/2021/wrfwof_2021-05-15_020000_to_2021-05-15_041000__10min__ens_mem_09.zarr...


Evaluating Model: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:04<00:00,  4.43s/it]


In [144]:
# Fractions skill score table: one row per forecast timestep and one column
# per FSS variable / neighbourhood window.
df = results.to_fss_dataframe()

In [145]:
df

Unnamed: 0,COMPOSITE_REFL_10CM_21km,COMPOSITE_REFL_10CM_45km,COMPOSITE_REFL_10CM_81km,RAIN_AMOUNT_21km,RAIN_AMOUNT_45km,RAIN_AMOUNT_81km
0,0.992995,0.995859,0.996322,0.985275,0.987627,0.989704
1,0.987353,0.992474,0.993551,0.944425,0.957701,0.969695
2,0.960419,0.972555,0.977899,0.93706,0.933582,0.942774
3,0.956222,0.971475,0.974203,0.944802,0.962163,0.9804
4,0.947307,0.972797,0.98138,0.923468,0.947735,0.959892
5,0.943043,0.972544,0.984634,0.946792,0.962542,0.97252
6,0.920733,0.959587,0.974634,0.847474,0.876782,0.89981
7,0.935894,0.971082,0.983281,0.743232,0.753838,0.76937
8,0.913015,0.961353,0.983658,0.426388,0.565623,0.714383
9,0.917596,0.968106,0.990269,0.0,0.051142,0.276246
