In [1]:
import os
# The XLA_* variables below are only read when JAX is first imported, so they
# must be set before the wofscast imports (which presumably pull in JAX).
# Preallocate GPU memory up front (this matches JAX's default behavior).
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'true'
# Fraction of GPU memory JAX preallocates. JAX's default is 0.75, so 0.90
# gives JAX *more* of the device. Lower this value if another framework
# sharing the GPU (e.g. a PyTorch model) runs out of memory.
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.90'

import sys
# Make the local wofscast package importable. Assumes the notebook lives two
# directories below the repository root (TODO confirm if the notebook moves).
package_path = os.path.dirname(os.path.dirname(os.getcwd()))
# Keep the repo root first on sys.path, but do not stack duplicate entries
# every time this cell is re-run.
if not sys.path or sys.path[0] != package_path:
    sys.path.insert(0, package_path)
from glob import glob 

from wofscast.common.wofs_data_loader import WoFSDataLoader
from wofscast.evaluate.metrics import (MSE,
                                       ObjectBasedContingencyStats,
                                       PowerSpectra,
                                       FractionsSkillScore,
                                       PMMStormStructure,
                                       )

from wofscast.evaluate.predictor import Predictor
from wofscast.evaluate.object_ider import ObjectIder
from wofscast.evaluate.evaluator import Evaluator, EvaluatorConfig 

import numpy as np

In [2]:
# Evaluation settings. The data_path glob restricts evaluation to a single
# ensemble member (ens_mem_09) from the 2021 2-hr zarr datasets.
config = EvaluatorConfig(
    # --- Data selection / sampling ---
    data_path='/work/mflora/wofs-cast-data/datasets_2hr_zarr/2021/*_ens_mem_09.zarr',
    n_samples=3,
    seed=42,
    # --- Model ---
    model_path='/work/cpotvin/WOFSCAST/model/wofscast_test_v178.npz',
    add_diffusion=False,
    load_ensemble=False,
    # --- Variables evaluated by each metric ---
    spectra_variables=['COMPOSITE_REFL_10CM', 'T2', 'W'],
    pmm_variables=['COMPOSITE_REFL_10CM', 'T2', 'RAIN_AMOUNT', 'WMAX'],
    fss_variables=['COMPOSITE_REFL_10CM', 'RAIN_AMOUNT'],
    fss_windows=[7, 15, 27],
    fss_thresh_dict={
        'COMPOSITE_REFL_10CM': [40.0],
        # 0.5 inch expressed in mm (25.4 mm per inch).
        'RAIN_AMOUNT': [25.4 / 2],
    },
    # --- Object matching (distance is converted to grid points downstream) ---
    matching_distance_km=42,
    grid_spacing_km=3.0,
    # --- Output location for the saved results dataset ---
    out_base_path='/work2/mflora/verification_datasets',
)

In [3]:
# Sample the evaluation files. config.data_path already restricts the glob to
# a single ensemble member; sorting makes the seeded draw reproducible
# regardless of filesystem listing order.
paths = sorted(glob(config.data_path))
n_samples = config.n_samples

# Fail early with an actionable message. Otherwise rs.choice(replace=False)
# raises an opaque ValueError when the glob matches nothing or too few files
# (e.g. when running on a machine without the /work data mounts).
if len(paths) < n_samples:
    raise ValueError(
        f'Found {len(paths)} file(s) matching {config.data_path!r}, '
        f'but n_samples={n_samples} were requested.'
    )

# Legacy RandomState (rather than default_rng) is kept on purpose so the same
# seed selects the same files as earlier runs.
rs = np.random.RandomState(config.seed)
random_paths = rs.choice(paths, size=n_samples, replace=False)

# Load the trained model checkpoint.
model = Predictor(
    model_path=config.model_path,
    add_diffusion=config.add_diffusion,
)

# The loader is configured from the model's own task config, preprocessing,
# and time decoding, so inputs match what the model expects.
data_loader = WoFSDataLoader(model.task_config,
                             model.preprocess_fn,
                             config.load_ensemble,
                             model.decode_times)

object_ider = ObjectIder()

metrics = [
    MSE(),

    # Matching distance converted from km to grid points
    # (matching_distance_km / grid_spacing_km). NOTE(review): this is passed
    # as a float; confirm the metric accepts a non-integer distance.
    ObjectBasedContingencyStats(config.matching_distance_km / config.grid_spacing_km),

    PowerSpectra(variables=config.spectra_variables),

    FractionsSkillScore(windows=config.fss_windows,
                        thresh_dict=config.fss_thresh_dict,
                        variables=config.fss_variables),

    PMMStormStructure(config.pmm_variables),
]

In [4]:
%%time
# Evaluate the sampled files with the model, object identifier, data loader and
# metric list built above, then save the results dataset. Per the recorded
# output, save() writes under config.out_base_path and returns a status string,
# which is displayed as this cell's output.
evaluator = Evaluator(model, object_ider, data_loader, metrics)
results_ds = evaluator.evaluate(random_paths)
evaluator.save(results_ds, config)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:19<00:00,  6.41s/it]

CPU times: user 2min 17s, sys: 6.12 s, total: 2min 23s
Wall time: 19.3 s





'Saved results dataset to /work2/mflora/verification_datasets/wofscast_test_v178_results_v1.nc!'