In [23]:
import numpy as np
import xarray as xr

import sys
sys.path.append(r'/home/ch23/ML-BEES_yk/ML-BEES-eval/eval_utilities')

from eval_utilities import spatial_temporal_metrics as stm
from eval_utilities import visualization as vis
import matplotlib.pyplot as plt

import pandas as pd
import visualization
import os
import glob

In [24]:
import yaml

# Load the experiment configuration. Parse errors are deliberately not caught:
# letting yaml.YAMLError propagate stops the notebook here with the parser's
# message, instead of failing later with a NameError on CONFIG.
with open("config.yaml") as stream:
    CONFIG = yaml.safe_load(stream)

# safe_load returns None for an empty file; fail clearly instead of with a TypeError below
if not isinstance(CONFIG, dict):
    raise ValueError(f"config.yaml must contain a mapping, got {type(CONFIG).__name__}")

# Predicted variables: entries of targets_prog followed by entries of targets_diag
variables = CONFIG["targets_prog"] + CONFIG["targets_diag"]


In [25]:
# Load all ensemble members and collect them into a single array:
# - gather the ensemble store names in a list
# - open each member and stack them into one dask ensemble array
import dask.array as da

def find_files_with_name(directory, filename):
    """
    Return the paths of all entries in ``directory`` whose name contains ``filename``.

    Both arguments are treated literally: glob metacharacters (``*``, ``?``, ``[``)
    in them are escaped, so only the surrounding ``*`` wildcards are active.

    Args:
        directory: directory to search (not recursive).
        filename: substring that the entry name must contain.

    Returns:
        list[str]: matching paths, sorted so that callers that stack the files
        (e.g. ensemble members) get a reproducible order. glob itself returns
        matches in arbitrary filesystem order.
    """
    pattern = os.path.join(glob.escape(directory), f'*{glob.escape(filename)}*')
    return sorted(glob.glob(pattern))

# Directory holding one .zarr store per ensemble member
directory_path = '/data/ch23/data_ch23/unimp_ens'  # Replace with your folder path
file_name = 'unimp'  # substring shared by the ensemble-member store names

# Sorted so the member order along the stacked ensemble axis is reproducible
# (glob returns matches in arbitrary filesystem order).
files_list = sorted(find_files_with_name(directory_path, file_name))
if not files_list:
    raise FileNotFoundError(f"No stores matching '*{file_name}*' found in {directory_path}")

# Chunking applied to every member before stacking
ENS_CHUNKS = {'time': 4, 'x': 10051, 'variable': 17}

# Open each member lazily, rechunk it, and keep its 'data' variable
ens_file_list = [xr.open_zarr(ens_file).chunk(ENS_CHUNKS).data for ens_file in files_list]

# Stack along a new leading ensemble axis
# (presumably (member, time, x, variable) -- TODO confirm the per-member dim order)
stacked_ens = da.stack(ens_file_list)
y_pred = stacked_ens

In [26]:
# Ground truth for the evaluation period.
# v1 is a single ensemble member's prediction store; below, only its
# time/lat/lon/variable coordinates are used.
v1 = xr.open_zarr("/data/ch23/data_ch23/unimp_ens/euro_unimp_1_train_2010_2019_val_2020_2020.zarr")

# Reference run restricted to 2020-2022 and to the same variable list as the predictions
train_ds = (
    xr.open_zarr("/data/ecland_i6aj_o400_2010_2022_6h_euro.zarr")
    .sel(time=slice("2020", "2022"), variable=variables)
)
# The Dataset's 'data' variable; xr.open_zarr returns dask-backed arrays by default,
# so this stays lazy
y_true = train_ds.data

In [27]:
CRPS_MAP_STORE = "/data/ch23/evalution_results/uncertainty/crps_unimp_test_dask.zarr"
crps = xr.open_zarr(CRPS_MAP_STORE)  # CRPS results from a previous run (presumably time-averaged)

In [28]:
def vis_zarr_map(zarr_eval, var, path_png, min_perc=1, max_perc=99, time_point=False, data_var="crps"):
    """
    Scatter-plot one variable of a gridded xarray dataset on a lon/lat map and save it as PNG.

    The colour limits are the ``min_perc``/``max_perc`` percentiles of the finite
    values, so NaN/inf cells and a few outliers do not dominate the colour scale.

    --- Parameters ---
    zarr_eval:  xarray.Dataset with a ``variable`` dimension, ``lon``/``lat``
                coordinates and a data variable named ``data_var``; needs a
                ``time`` dimension only when ``time_point`` is given
    var:        label along the ``variable`` dimension to plot
    path_png:   output path prefix; the figure is written to ``path_png + '_<var>.png'``
                (include the metric name in the prefix when plotting a metric)
    min_perc:   percentile for the lower colour limit (default 1)
    max_perc:   percentile for the upper colour limit (default 99)
    time_point: False or None to plot without selecting a time, or an integer
                time index (0 is a valid index)
    data_var:   name of the data variable used for the colour (default "crps")

    --- Returns ---
    None; the figure is saved to disk and then closed.
    """
    # `is False` rather than `== False`: 0 == False, so index 0 must stay selectable
    if time_point is False or time_point is None:
        zarr_eval_selected = zarr_eval.sel(variable=var)
    else:
        zarr_eval_selected = zarr_eval.isel(time=time_point).sel(variable=var)

    # Read the values once: on a dask-backed dataset (as xr.open_zarr returns),
    # every `.values` access would otherwise recompute the array.
    values = zarr_eval_selected[data_var].values
    finite_values = values[np.isfinite(values)]

    # np.percentile cannot produce limits from an empty array, so fall back to
    # matplotlib's automatic limits when no value is finite.
    if finite_values.size:
        vmin = np.percentile(finite_values, min_perc)
        vmax = np.percentile(finite_values, max_perc)
    else:
        vmin = vmax = None

    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        scatter = zarr_eval_selected.plot.scatter(
            x="lon", y="lat", hue=data_var, s=10, edgecolors="none", ax=ax, vmin=vmin, vmax=vmax)

        # Increase font sizes
        ax.set_xlabel(ax.get_xlabel(), fontsize=16)
        ax.set_ylabel(ax.get_ylabel(), fontsize=16)
        ax.set_title(ax.get_title(), fontsize=18)
        ax.tick_params(labelsize=14)
        if scatter.colorbar is not None:
            scatter.colorbar.ax.tick_params(labelsize=14)
            scatter.colorbar.set_label("Data", fontsize=16)

        fig.savefig(path_png + '_%s.png' % var, bbox_inches="tight")
    finally:
        # Always release the figure (also on error) so looping over many
        # variables does not accumulate open figures
        plt.close(fig)

In [21]:
figure_path = '/data/ch23/evalution_results/uncertainty/visualization/'
crps_png_prefix = figure_path + 'crps'  # files are written as <prefix>_<variable>.png

# One CRPS map per variable, colour-limited to the 1st/99th percentile
for var in crps.variable.values:
    vis_zarr_map(crps, var, crps_png_prefix, min_perc=1, max_perc=99)

In [29]:
# try the dask version
def dask_sort_along_axis(arr, axis=0):
    """
    Sort a dask array along one axis.

    dask.array has no general ``sort``, so ``np.sort`` is applied block-wise.
    A block-wise sort is a true sort only when the whole ``axis`` lies inside a
    single block, so that axis is first rechunked into one chunk. Without this,
    e.g. the output of ``da.stack`` (one chunk per stacked array along the new
    axis) would come back unsorted.

    Args:
        arr: array to sort (dask, or anything ``da.asarray`` accepts).
        axis: axis to sort along; negative values are allowed.

    Returns:
        Lazy dask array with the same shape and dtype, sorted along ``axis``.
    """
    arr = da.asarray(arr)
    # Whole sort axis in one chunk -> each block sees every value along it
    arr = arr.rechunk({axis: -1})
    return da.map_blocks(np.sort, arr, axis=axis, dtype=arr.dtype)


def crps_dask(y_true, y_pred, time=True, sample_weight=None, norm=False):
    """
    Continuous Ranked Probability Score of an ensemble forecast, computed lazily with dask.

    For M ensemble members x_1..x_M and an observation y:
        CRPS = mean_i |x_i - y| - 1/(2 M^2) * sum_{i,j} |x_i - x_j|
    The pairwise term is evaluated from the members sorted along the ensemble axis:
        1/(2 M^2) * sum_{i,j} |x_i - x_j| = sum_k k (M - k) (x_(k+1) - x_(k)) / M^2
    CRPS is measured in the same units as the variable.

    Args:
        y_true: ground truth with shape (time, lat*lon, vars); numpy, dask or
            xarray.DataArray.
        y_pred: ensemble predictions with shape (n_members, time, lat*lon, vars);
            the ensemble axis must be axis 0.
        time: if falsy, return the per-observation CRPS without averaging over time.
        sample_weight: optional weights for the time average (passed to ``da.average``).
        norm: if True (and ``time`` is truthy), divide each CRPS by |y_true|
            (NaN where |y_true| <= 1e-14) and return the NaN-ignoring time mean.

    Returns:
        dask array: per-observation CRPS (time, lat*lon, vars) when ``time`` is
        falsy, otherwise the time-averaged CRPS (lat*lon, vars).

    Modified from https://github.com/lm2612/WaveNet_UQ/
    """
    y_pred = da.asarray(y_pred)
    # da.asarray unwraps an xarray.DataArray to its (lazy) .data and accepts numpy too
    y_true = da.asarray(y_true)

    num_samples = y_pred.shape[0]

    # np.sort is applied block-wise, so the ensemble axis must be a single chunk
    # for the result to be sorted (da.stack leaves one chunk per member).
    y_pred = dask_sort_along_axis(y_pred.rechunk({0: -1}), axis=0)

    # Gaps between consecutive sorted members: (M-1, ...)
    diff = y_pred[1:] - y_pred[:-1]

    # Weight k*(M-k) for the k-th gap, shaped to broadcast over the remaining axes.
    # Only M-1 values, so plain numpy is enough.
    k = np.arange(1, num_samples)
    weight = (k * (num_samples - k)).reshape((-1,) + (1,) * (y_pred.ndim - 1))

    # Mean absolute error of the members; y_true broadcasts over the ensemble axis
    absolute_error = da.mean(da.abs(y_pred - y_true), axis=0)

    # Per-observation CRPS
    per_obs_crps = absolute_error - da.sum(diff * weight, axis=0) / num_samples**2

    if not time:
        return per_obs_crps

    if norm:
        abs_true = da.abs(y_true)
        crps_normalized = da.where(abs_true > 1E-14, per_obs_crps / abs_true, np.nan)
        return da.nanmean(crps_normalized, axis=0)

    # (Weighted) average over time
    return da.average(per_obs_crps, axis=0, weights=sample_weight)
# Score only the first N_SCORED_VARS variables and keep the time dimension (time=False)
N_SCORED_VARS = 17
crps_score_time_series = crps_dask(
    y_true[:, :, :N_SCORED_VARS],
    y_pred[:, :, :, :N_SCORED_VARS],
    time=False,
    sample_weight=None,
    norm=False,
)


In [4]:
crps_score_time_series

NameError: name 'crps_score_time_series' is not defined

In [14]:
# Wrap the per-time-step CRPS in a Dataset carrying the prediction grid's coordinates.
# The score may cover only the leading variables (the call to crps_dask slices them),
# so the variable coordinate is trimmed to the number actually scored.
n_scored_vars = crps_score_time_series.shape[-1]
crps_score_time_series_ds = xr.Dataset(
    {
        "crps": (("time", "x", "variable"), crps_score_time_series),  # 3 dims
    },
    coords={
        "time": ("time", v1.time.values),
        "lat": ("x", v1.lat.values),
        "lon": ("x", v1.lon.values),
        "variable": v1.variable.values[:n_scored_vars],
    },
)
# Save as .zarr; mode="w" overwrites an existing store so the cell can be re-run
crps_score_time_series_ds.to_zarr(
    "/data/ch23/evalution_results/uncertainty/crps_unimp_dask_timeseries.zarr", mode="w"
)

: 

In [30]:
# calculate the time series and uncertainty level of crps over all the grid points

# CRPS time series over all grid cells for variable index 0 (labelled swvl1 in the plot).
# crps_score_time_series is (time, x, variable); transpose so grid cells run along axis 0.
data = crps_score_time_series[:, :, 0].T
n_cells = data.shape[0]

mean = da.mean(data, axis=0)  # mean CRPS at each time step
std = da.std(data, axis=0)
# 1.96 standard errors of the mean over grid cells (normal-approximation 95% interval)
confidence_interval = 1.96 * std / da.sqrt(n_cells)

# Timestamps for the plot's x axis
date_range = pd.date_range(
    start='2020-01-01 00:00:00',
    end='2022-11-30 18:00:00',
    periods=4260,
)


In [11]:
# Cast both plotted series to float32
mean, confidence_interval = mean.astype(np.float32), confidence_interval.astype(np.float32)

In [32]:
import dask

# Materialise both series in one pass. mean and confidence_interval are lazy dask
# arrays; handing them to plot/fill_between separately would re-run the whole
# CRPS graph once per conversion (three times here).
mean_values, ci_values = dask.compute(mean, confidence_interval)

fig, ax = plt.subplots(figsize=(10, 6))

# Mean CRPS over grid cells at each time step, with its 95% interval
ax.plot(date_range, mean_values, color='blue', label='Mean')
ax.fill_between(date_range, mean_values - ci_values, mean_values + ci_values,
                color='blue', alpha=0.3, label='95% Confidence Interval')

ax.set_xlabel('Time', fontsize=14)
ax.set_ylabel('CRPS of swvl1', fontsize=14)
ax.set_title('CRPS time series of swvl1 over all grid cells', fontsize=14)
ax.legend(fontsize=12)
ax.tick_params(labelsize=12)
ax.set_xlim(pd.Timestamp('2020-01-01 00:00:00'), pd.Timestamp('2022-11-30 18:00:00'))
fig.autofmt_xdate()
plt.show()

: 