# Prepare the environment

In [None]:
! pip install "xarray[complete]"

Collecting sparse (from xarray[complete])
  Downloading sparse-0.15.4-py2.py3-none-any.whl.metadata (4.5 kB)
Collecting numbagg (from xarray[complete])
  Downloading numbagg-0.8.2-py3-none-any.whl.metadata (47 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m47.5/47.5 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
Collecting flox (from xarray[complete])
  Downloading flox-0.9.15-py3-none-any.whl.metadata (17 kB)
Collecting cartopy (from xarray[complete])
  Downloading Cartopy-0.24.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.9 kB)
Collecting nc-time-axis (from xarray[complete])
  Downloading nc_time_axis-1.4.1-py3-none-any.whl.metadata (4.7 kB)
Collecting netCDF4 (from xarray[complete])
  Downloading netCDF4-1.7.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.8 kB)
Collecting zarr (from xarray[complete])
  Downloading zarr-2.18.3-py3-none-any.whl.metadata (5.7 kB)
Collecting cftime (from xarray[complete])
  Download

# Import the necessary libraries

In [None]:
import os
import datetime
import numpy as np
import pandas as pd
import xarray as xr
from pathlib import Path
from tqdm.notebook import tqdm, trange

In [None]:
# Authenticate the current user via the Colab helper module.
# NOTE(review): `google.colab` is presumably only importable inside a Google
# Colab runtime -- this cell will fail elsewhere; confirm it is still needed
# now that `work_dir` points at the current working directory.
from google.colab import auth
auth.authenticate_user()

In [None]:
# Mount Google Drive at /content/drive (Colab helper module).
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Root folder for every input and output path used below. The Google Drive
# location is kept for reference; the active setting is the kernel's current
# working directory.
# work_dir = Path("/content/drive/MyDrive/WB2BinaryForecast/")
work_dir = Path.cwd()

# Load the thresholds

In [None]:
# Load each threshold field fully into memory, then close the file straight
# away (the `with` block releases the handle as soon as the data is loaded).
with xr.open_dataset(work_dir / "Threshold_tx24h_150_2020_era5.nc") as threshold_file:
    tx24h_era5_threshold = threshold_file["2m_temperature"].load()
with xr.open_dataset(work_dir / "Threshold_tx24h_150_2020_hres0.nc") as threshold_file:
    tx24h_hres0_threshold = threshold_file["2m_temperature"].load()
with xr.open_dataset(work_dir / "Threshold_tp24h_150_2020_era5.nc") as threshold_file:
    tp24h_era5_threshold = threshold_file["total_precipitation_24hr"].load()

# Calculate the verification metrics of binary forecasts

In [None]:
! pip install scores==1.2.0

Collecting scores
  Downloading scores-1.2.0-py3-none-any.whl.metadata (12 kB)
Downloading scores-1.2.0-py3-none-any.whl (109 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m110.0/110.0 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: scores
Successfully installed scores-1.2.0


In [None]:
import scores

## Define necessary functions

This part is copied from [WeatherBench 2](https://github.com/google-research/weatherbench2).

In [None]:
def _assert_increasing(x: np.ndarray):
    if not (np.diff(x) > 0).all():
        raise ValueError(f"array is not increasing: {x}")


def _latitude_cell_bounds(x: np.ndarray) -> np.ndarray:
    pi_over_2 = np.array([np.pi / 2], dtype=x.dtype)
    return np.concatenate([-pi_over_2, (x[:-1] + x[1:]) / 2, pi_over_2])


def _cell_area_from_latitude(points: np.ndarray) -> np.ndarray:
    """Return the (unnormalised) area of the latitude band around each point.

    ``points`` are latitudes in radians. The band edges come from
    ``_latitude_cell_bounds`` and must be strictly increasing, otherwise
    ``_assert_increasing`` raises ``ValueError``.
    """
    edges = _latitude_cell_bounds(points)
    _assert_increasing(edges)
    # Integral of cos(latitude) between consecutive edges.
    sin_edges = np.sin(edges)
    return sin_edges[1:] - sin_edges[:-1]


def get_lat_weights(ds: xr.Dataset) -> xr.DataArray:
    """Build area weights along the ``latitude`` coordinate of ``ds``.

    The latitudes (degrees) are converted to radians, turned into band areas by
    ``_cell_area_from_latitude`` and rescaled so that the weights average to 1.
    The result is a copy of ``ds.latitude`` whose values are those weights.
    """
    latitudes_rad = np.deg2rad(ds.latitude.data)
    band_areas = _cell_area_from_latitude(latitudes_rad)
    normalised = band_areas / np.mean(band_areas)
    return ds.latitude.copy(data=normalised)


def _spatial_average(
    dataset: xr.Dataset, region: "Region | None" = None, skipna: bool = True
) -> xr.Dataset:
    """Compute the area-weighted spatial mean, optionally restricted to a region.

    Args:
        dataset: Metric values as a function of latitude/longitude.
        region: Optional region selector exposing ``apply(dataset) -> dataset``
            (the interface of the ``Region`` classes defined in this notebook).
        skipna: Skip NaNs in the spatial mean.

    Returns:
        ``dataset`` averaged over ``latitude`` and ``longitude``.
    """
    # Weights are derived from the *unsliced* latitude axis: the cell bounds in
    # `get_lat_weights` are closed at the poles, so computing them on an
    # already-sliced region would inflate the region's edge rows.
    weights = get_lat_weights(dataset)
    if region is not None:
        # `Region.apply` in this notebook takes and returns the dataset only, so
        # the weights are restricted to the surviving latitudes here.
        dataset = region.apply(dataset)
        weights = weights.sel(latitude=dataset.latitude.values)
    return dataset.weighted(weights).mean(
        ["latitude", "longitude"], skipna=skipna
    )


def calcu_MSE(
    forecast: xr.Dataset,
    truth: xr.Dataset,
    region: np.ndarray = None,
) -> xr.Dataset:
    """Return the spatially averaged squared difference between ``forecast`` and ``truth``.

    ``region`` is forwarded unchanged to ``_spatial_average``.
    """
    squared_error = (forecast - truth) ** 2
    return _spatial_average(squared_error, region=region)


In [None]:
import dataclasses
import typing as t

import numpy as np
import xarray as xr


@dataclasses.dataclass
class Region:
    """Base class for region selectors applied before spatial aggregation.

    Subclasses implement ``apply()``, which receives a dataset and returns the
    part of it that belongs to the region.

    NOTE(review): ``apply()`` here only handles the dataset. Latitude weights
    are not passed in or returned, so a caller that needs region-restricted
    weights has to select them itself.
    """

    def apply(self, dataset: xr.Dataset) -> xr.Dataset:
        """Apply the region selection to a dataset.

        Args:
          dataset: Data carrying the spatial coordinates the region selects on.

        Returns:
          The region-restricted dataset.

        Raises:
          NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError


@dataclasses.dataclass
class SliceRegion(Region):
    """Latitude/longitude box (or union of boxes) selected by coordinate slices.

    Each of ``lat_slice`` / ``lon_slice`` is either a single ``slice`` or a list
    of slices; the default ``slice(None, None)`` keeps the whole axis.
    """

    lat_slice: t.Optional[t.Union[slice, list[slice]]] = dataclasses.field(
        default_factory=lambda: slice(None, None))
    lon_slice: t.Optional[t.Union[slice, list[slice]]] = dataclasses.field(
        default_factory=lambda: slice(None, None))

    def apply(self, dataset: xr.Dataset) -> xr.Dataset:
        """Return ``dataset`` restricted to the configured latitude/longitude slices."""
        # Normalise both settings to lists so one code path serves both forms.
        lat_slices = self.lat_slice if isinstance(self.lat_slice, list) else [self.lat_slice]
        lon_slices = self.lon_slice if isinstance(self.lon_slice, list) else [self.lon_slice]

        # Collect the coordinate values picked by each slice, in the order given.
        lat_pieces = [dataset.latitude.sel(latitude=piece) for piece in lat_slices]
        lon_pieces = [dataset.longitude.sel(longitude=piece) for piece in lon_slices]
        selected_lats = xr.concat(lat_pieces, dim='latitude')
        selected_lons = xr.concat(lon_pieces, dim='longitude')

        return dataset.sel(latitude=selected_lats, longitude=selected_lons)


In [None]:
# Named latitude/longitude boxes used for the region-scale scores.
# Slice bounds are label-based: `SliceRegion.apply` passes them to `.sel()` on
# the `latitude` / `longitude` coordinates, and a list of slices selects the
# union of several bands.
# Longitudes are written on a 0-360 scale (e.g. `360 - 12.5`), which is why
# 'europe' needs two longitude slices to straddle the 0 degree meridian.
# Note that 'northern-hemisphere' and 'southern-hemisphere' start at +/-20
# degrees, i.e. they exclude the 'tropics' band.
predefined_regions = {
    'global': SliceRegion(),
    'tropics': SliceRegion(lat_slice=slice(-20, 20)),
    'extra-tropics': SliceRegion(
      lat_slice=[slice(None, -20), slice(20, None)]
    ),
    'northern-hemisphere': SliceRegion(lat_slice=slice(20, None)),
    'southern-hemisphere': SliceRegion(lat_slice=slice(None, -20)),
    'europe': SliceRegion(
      lat_slice=slice(35, 75),
      lon_slice=[slice(360 - 12.5, None), slice(0, 42.5)],
    ),
    'north-america': SliceRegion(
      lat_slice=slice(25, 60), lon_slice=slice(360 - 120, 360 - 75)
    ),
    'north-atlantic': SliceRegion(
      lat_slice=slice(25, 65), lon_slice=slice(360 - 70, 360 - 10)
    ),
    'north-pacific': SliceRegion(
      lat_slice=slice(25, 60), lon_slice=slice(145, 360 - 130)
    ),
    'east-asia': SliceRegion(
      lat_slice=slice(25, 60), lon_slice=slice(102.5, 150)
    ),
    'ausnz': SliceRegion(
      lat_slice=slice(-45, -12.5), lon_slice=slice(120, 175)
    ),
    'arctic': SliceRegion(lat_slice=slice(60, 90)),
    'antarctic': SliceRegion(lat_slice=slice(-90, -60)),
}

## Load the binary forecasts

In [None]:
# Reference ("truth") event fields that the forecasts are compared against.
# Each one is loaded fully into memory and its file is closed straight away.
with xr.open_dataset(work_dir / "Extremes/Extremes_tx24h_era5_150_2020.nc") as extremes_file:
    tx24h_era5_extremes = extremes_file["events"].load()
with xr.open_dataset(work_dir / "Extremes/Extremes_tx24h_hres0_150_2020.nc") as extremes_file:
    tx24h_hres0_extremes = extremes_file["events"].load()
with xr.open_dataset(work_dir / "Extremes/Extremes_tp24h_era5_150_2020.nc") as extremes_file:
    tp24h_era5_extremes = extremes_file["events"].load()

In [None]:
# Settings shared by the contingency-table cells below.
# `resolution` and `year` are only used to build file names ("150", "2020").
resolution = 1.50
year = 2020
# Lead times 1..10; each becomes the zero-padded "<NN>d" part of a file name.
lead_times = np.arange(1, 11)
# Model identifiers as they appear in the forecast file names.
model_list_heatwave = ["hres", "hres_ens_mean", "era5-forecast",
               "pangu-oper","graphcast-oper", "pangu", "graphcast","fuxi"]
model_list_rainfall = ["hres", "hres_ens_mean", "graphcast", "graphcast-oper", "fuxi"]

## Region scale

In [None]:
# Region-scale, area-weighted contingency counts for the heatwave (tx24h)
# binary forecasts, stacked over region, lead time and model.
tx24h_results = []
resolution_tag = str(resolution).replace(".", "").ljust(3, "0")  # 1.5 -> "150"
for model_name in tqdm(model_list_heatwave):
    # The two HRES-based models are scored against the hres0 events, all
    # other models against the ERA5 events.
    if model_name in ["hres", "hres_ens_mean"]:
        true_events = tx24h_hres0_extremes
    else:
        true_events = tx24h_era5_extremes

    lead_time_data = []
    for lead_time in tqdm(lead_times, leave=False):
        # NOTE(review): this cell reads "Extremes_t2max_*" files whereas the
        # Brier/ROC cells read "Extremes_tx24h_*" -- confirm both exist on disk.
        forecast_path = work_dir / "Extremes/Extremes_t2max_{}_{}_{}_{}d.nc".format(
            resolution_tag, str(year), model_name, str(lead_time).zfill(2))

        # Read the events once and release the file handle; every region below
        # reuses the in-memory array instead of going back to disk.
        with xr.open_dataset(forecast_path) as forecast_file:
            forecast_dataset = forecast_file["events"].load()

        # Area weights must be derived from the full latitude axis.
        # `get_lat_weights` closes the outermost cells at the poles, so calling
        # it on an already-sliced region would give the region's edge rows the
        # area of everything between them and the pole.
        # Assumes the forecast grid itself spans pole to pole -- TODO confirm.
        full_grid_weights = get_lat_weights(forecast_dataset)

        region_data = []
        for region_name, region in tqdm(predefined_regions.items(), leave=False):

            forecast_data = region.apply(forecast_dataset)
            true_data = region.apply(true_events.sel(time=forecast_data.time))

            weights = full_grid_weights.sel(latitude=forecast_data.latitude.values)

            contingency_manager = scores.categorical.BinaryContingencyManager(forecast_data, true_data)
            counts = contingency_manager.transform(
                preserve_dims=["latitude", "longitude", "quantile"]).counts

            # Collapse the per-grid-cell counts into one area-weighted count per region.
            new_counts = {
                key: (arr * weights).sum(["latitude", "longitude"])
                for key, arr in counts.items()
            }
            new_contingency_manager = scores.categorical.BasicContingencyManager(new_counts)

            # Only the weighted counts are stored. The derived scores offered by
            # the manager (e.g. `symmetric_extremal_dependence_index()`,
            # `f1_score()`, `accuracy()`, `probability_of_detection()`,
            # `probability_of_false_detection()`, `heidke_skill_score()`) were
            # left disabled in this cell.
            metrics_data = xr.Dataset({"TP": new_contingency_manager.counts["tp_count"],
                                       "TN": new_contingency_manager.counts["tn_count"],
                                       "FP": new_contingency_manager.counts["fp_count"],
                                       "FN": new_contingency_manager.counts["fn_count"],
                                       "Count": new_contingency_manager.counts["total_count"],
                                       })

            region_data.append(metrics_data)
        region_data = xr.concat(region_data, dim="region")
        region_data["region"] = list(predefined_regions.keys())

        lead_time_data.append(region_data)
    lead_time_data = xr.concat(lead_time_data, dim="lead_time")
    lead_time_data["lead_time"] = lead_times

    tx24h_results.append(lead_time_data)

tx24h_results = xr.concat(tx24h_results, dim="model_name")
tx24h_results["model_name"] = model_list_heatwave

# NOTE(review): the Brier/ROC cells write to "metrics" (lower case) -- on a
# case-sensitive file system that is a different directory; confirm intent.
tx24h_results.to_netcdf(work_dir / "Metrics/Metrics_tx24h_region.nc")

In [None]:
# Region-scale, area-weighted contingency counts for the rainfall (tp24h)
# binary forecasts, stacked over region, lead time and model.
tp24h_results = []
resolution_tag = str(resolution).replace(".", "").ljust(3, "0")  # 1.5 -> "150"
# Every rainfall model is scored against the ERA5 events.
true_events = tp24h_era5_extremes
for model_name in tqdm(model_list_rainfall):
    lead_time_data = []
    for lead_time in tqdm(lead_times, leave=False):
        forecast_path = work_dir / "Extremes/Extremes_tp24h_{}_{}_{}_{}d.nc".format(
            resolution_tag, str(year), model_name, str(lead_time).zfill(2))

        # Read the events once and release the file handle; every region below
        # reuses the in-memory array instead of going back to disk.
        with xr.open_dataset(forecast_path) as forecast_file:
            forecast_dataset = forecast_file["events"].load()

        # Area weights must be derived from the full latitude axis.
        # `get_lat_weights` closes the outermost cells at the poles, so calling
        # it on an already-sliced region would give the region's edge rows the
        # area of everything between them and the pole.
        # Assumes the forecast grid itself spans pole to pole -- TODO confirm.
        full_grid_weights = get_lat_weights(forecast_dataset)

        region_data = []
        for region_name, region in tqdm(predefined_regions.items(), leave=False):

            forecast_data = region.apply(forecast_dataset)
            true_data = region.apply(true_events.sel(time=forecast_data.time))

            weights = full_grid_weights.sel(latitude=forecast_data.latitude.values)

            contingency_manager = scores.categorical.BinaryContingencyManager(forecast_data, true_data)
            counts = contingency_manager.transform(
                preserve_dims=["latitude", "longitude", "quantile"]).counts

            # Collapse the per-grid-cell counts into one area-weighted count per region.
            new_counts = {
                key: (arr * weights).sum(["latitude", "longitude"])
                for key, arr in counts.items()
            }
            new_contingency_manager = scores.categorical.BasicContingencyManager(new_counts)

            # Only the weighted counts are stored. The derived scores offered by
            # the manager (e.g. `symmetric_extremal_dependence_index()`,
            # `f1_score()`, `accuracy()`, `probability_of_detection()`,
            # `probability_of_false_detection()`, `heidke_skill_score()`) were
            # left disabled in this cell.
            metrics_data = xr.Dataset({"TP": new_contingency_manager.counts["tp_count"],
                                       "TN": new_contingency_manager.counts["tn_count"],
                                       "FP": new_contingency_manager.counts["fp_count"],
                                       "FN": new_contingency_manager.counts["fn_count"],
                                       "Count": new_contingency_manager.counts["total_count"],
                                       })
            region_data.append(metrics_data)
        region_data = xr.concat(region_data, dim="region")
        region_data["region"] = list(predefined_regions.keys())

        lead_time_data.append(region_data)
    lead_time_data = xr.concat(lead_time_data, dim="lead_time")
    lead_time_data["lead_time"] = lead_times

    tp24h_results.append(lead_time_data)

tp24h_results = xr.concat(tp24h_results, dim="model_name")
tp24h_results["model_name"] = model_list_rainfall

# NOTE(review): the Brier/ROC cells write to "metrics" (lower case) -- on a
# case-sensitive file system that is a different directory; confirm intent.
tp24h_results.to_netcdf(work_dir / "Metrics/Metrics_tp24h_region.nc")

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

## Grid cell

In [None]:
# Per-grid-cell contingency counts for the heatwave (tx24h) binary forecasts,
# stacked over lead time and model. No area weighting is applied here because
# the latitude/longitude dimensions are preserved.
tx24h_results = []
resolution_tag = str(resolution).replace(".", "").ljust(3, "0")  # 1.5 -> "150"
for model_name in tqdm(model_list_heatwave):
    # The two HRES-based models are scored against the hres0 events, all
    # other models against the ERA5 events.
    if model_name in ["hres", "hres_ens_mean"]:
        true_events = tx24h_hres0_extremes
    else:
        true_events = tx24h_era5_extremes

    lead_time_data = []
    for lead_time in tqdm(lead_times, leave=False):
        # NOTE(review): this cell reads "Extremes_t2max_*" files whereas the
        # Brier/ROC cells read "Extremes_tx24h_*" -- confirm both exist on disk.
        forecast_path = work_dir / "Extremes/Extremes_t2max_{}_{}_{}_{}d.nc".format(
            resolution_tag, str(year), model_name, str(lead_time).zfill(2))

        # Load the events and release the file handle before scoring.
        with xr.open_dataset(forecast_path) as forecast_file:
            forecast_data = forecast_file["events"].load()
        true_data = true_events.sel(time=forecast_data.time)

        contingency_manager = scores.categorical.BinaryContingencyManager(forecast_data, true_data)
        new_contingency_manager = contingency_manager.transform(
            preserve_dims=["latitude", "longitude", "quantile"])

        # Only the raw counts are stored; the derived scores offered by the
        # manager (e.g. `symmetric_extremal_dependence_index()`, `f1_score()`,
        # `accuracy()`, `heidke_skill_score()`) were left disabled in this cell.
        metrics_data = xr.Dataset({"TP": new_contingency_manager.counts["tp_count"],
                                    "TN": new_contingency_manager.counts["tn_count"],
                                    "FP": new_contingency_manager.counts["fp_count"],
                                    "FN": new_contingency_manager.counts["fn_count"],
                                    "Count": new_contingency_manager.counts["total_count"],
                                    })
        lead_time_data.append(metrics_data)

    lead_time_data = xr.concat(lead_time_data, dim="lead_time")
    lead_time_data["lead_time"] = lead_times

    tx24h_results.append(lead_time_data)

tx24h_results = xr.concat(tx24h_results, dim="model_name")
tx24h_results["model_name"] = model_list_heatwave

tx24h_results.to_netcdf(work_dir / "Metrics/Metrics_tx24h_grid.nc")

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

In [None]:
# Per-grid-cell contingency counts for the rainfall (tp24h) binary forecasts,
# stacked over lead time and model. No area weighting is applied here because
# the latitude/longitude dimensions are preserved.
tp24h_results = []
resolution_tag = str(resolution).replace(".", "").ljust(3, "0")  # 1.5 -> "150"
# Every rainfall model is scored against the ERA5 events.
true_events = tp24h_era5_extremes
for model_name in tqdm(model_list_rainfall):
    lead_time_data = []
    for lead_time in tqdm(lead_times, leave=False):
        forecast_path = work_dir / "Extremes/Extremes_tp24h_{}_{}_{}_{}d.nc".format(
            resolution_tag, str(year), model_name, str(lead_time).zfill(2))

        # Load the events and release the file handle before scoring.
        with xr.open_dataset(forecast_path) as forecast_file:
            forecast_data = forecast_file["events"].load()
        true_data = true_events.sel(time=forecast_data.time)

        contingency_manager = scores.categorical.BinaryContingencyManager(forecast_data, true_data)
        new_contingency_manager = contingency_manager.transform(
            preserve_dims=["latitude", "longitude", "quantile"])

        # Only the raw counts are stored; the derived scores offered by the
        # manager (e.g. `symmetric_extremal_dependence_index()`, `f1_score()`,
        # `accuracy()`, `heidke_skill_score()`) were left disabled in this cell.
        metrics_data = xr.Dataset({"TP": new_contingency_manager.counts["tp_count"],
                                    "TN": new_contingency_manager.counts["tn_count"],
                                    "FP": new_contingency_manager.counts["fp_count"],
                                    "FN": new_contingency_manager.counts["fn_count"],
                                    "Count": new_contingency_manager.counts["total_count"],
                                    })

        lead_time_data.append(metrics_data)

    lead_time_data = xr.concat(lead_time_data, dim="lead_time")
    lead_time_data["lead_time"] = lead_times

    tp24h_results.append(lead_time_data)

tp24h_results = xr.concat(tp24h_results, dim="model_name")
tp24h_results["model_name"] = model_list_rainfall

tp24h_results.to_netcdf(work_dir / "Metrics/Metrics_tp24h_grid.nc")

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

# Calculate Brier score

In [None]:
# Settings for the Brier-score and ROC cells below.
# NOTE(review): this cell rebinds the same names as the earlier configuration
# cell and adds "hres_ens" to both model lists; re-running the contingency
# cells after this one would therefore also request "hres_ens" files.
resolution = 1.50
year = 2020
lead_times = np.arange(1, 11)

model_list_heatwave = ["hres", "hres_ens", "hres_ens_mean", "era5-forecast",
                       "pangu-oper","graphcast-oper", "pangu", "graphcast","fuxi"]
model_list_rainfall = ["hres", "hres_ens", "hres_ens_mean", "graphcast", "graphcast-oper", "fuxi"]


In [None]:
# Per-grid-cell Brier score (time-mean squared difference between forecast and
# observed events) for the heatwave (tx24h) forecasts.
tx24h_results = []
resolution_tag = str(resolution).replace(".", "").ljust(3, "0")  # 1.5 -> "150"
# Divisor that turns the "hres_ens" event values into a fraction.
# Presumably 50 perturbed members + 1 control -- confirm against the files.
ensemble_size = 50 + 1
for model_name in tqdm(model_list_heatwave):
    # HRES-based models are scored against the hres0 events, all others against ERA5.
    if model_name in ["hres", "hres_ens", "hres_ens_mean"]:
        truth_source = tx24h_hres0_extremes
    else:
        truth_source = tx24h_era5_extremes

    lead_time_data = []
    for lead_time in tqdm(lead_times, leave=False):
        forecast_path = work_dir / "Extremes/Extremes_tx24h_{}_{}_{}_{}d.nc".format(
            resolution_tag, str(year), model_name, str(lead_time).zfill(2))

        # Load the events and release the file handle before scoring.
        # NOTE(review): the int8 cast cannot represent NaN; if the event files
        # contain missing values, `skipna=True` below will not see them.
        with xr.open_dataset(forecast_path) as forecast_file:
            forecast_dataset = forecast_file["events"].astype(np.int8).load()
        if model_name == "hres_ens":
            forecast_dataset = forecast_dataset / ensemble_size
        true_events = truth_source.sel(time=forecast_dataset.time)

        brier_score = ((forecast_dataset - true_events) ** 2).mean(["time"], skipna=True)

        metrics_data = xr.Dataset({"brier_score": brier_score})

        lead_time_data.append(metrics_data)

    lead_time_data = xr.concat(lead_time_data, dim="lead_time")
    lead_time_data["lead_time"] = lead_times

    tx24h_results.append(lead_time_data)

tx24h_results = xr.concat(tx24h_results, dim="model_name")
tx24h_results["model_name"] = model_list_heatwave

# NOTE(review): the contingency cells write to "Metrics" (capitalised) -- on a
# case-sensitive file system that is a different directory; confirm intent.
tx24h_results.to_netcdf(work_dir / "metrics"  / "Metrics_tx24h_grid_brier_score.nc")

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

In [None]:
# Per-grid-cell Brier score (time-mean squared difference between forecast and
# observed events) for the rainfall (tp24h) forecasts.
tp24h_results = []
resolution_tag = str(resolution).replace(".", "").ljust(3, "0")  # 1.5 -> "150"
# Divisor that turns the "hres_ens" event values into a fraction.
# Presumably 50 perturbed members + 1 control -- confirm against the files.
ensemble_size = 50 + 1
# Every rainfall model is scored against the ERA5 events.
truth_source = tp24h_era5_extremes
for model_name in tqdm(model_list_rainfall):
    lead_time_data = []
    for lead_time in tqdm(lead_times, leave=False):
        forecast_path = work_dir / "Extremes/Extremes_tp24h_{}_{}_{}_{}d.nc".format(
            resolution_tag, str(year), model_name, str(lead_time).zfill(2))

        # Load the events and release the file handle before scoring.
        # NOTE(review): the int8 cast cannot represent NaN; if the event files
        # contain missing values, `skipna=True` below will not see them.
        with xr.open_dataset(forecast_path) as forecast_file:
            forecast_dataset = forecast_file["events"].astype(np.int8).load()
        if model_name == "hres_ens":
            forecast_dataset = forecast_dataset / ensemble_size
        true_events = truth_source.sel(time=forecast_dataset.time)

        brier_score = ((forecast_dataset - true_events) ** 2).mean(["time"], skipna=True)

        metrics_data = xr.Dataset({"brier_score": brier_score})

        lead_time_data.append(metrics_data)

    lead_time_data = xr.concat(lead_time_data, dim="lead_time")
    lead_time_data["lead_time"] = lead_times

    tp24h_results.append(lead_time_data)

tp24h_results = xr.concat(tp24h_results, dim="model_name")
tp24h_results["model_name"] = model_list_rainfall

# NOTE(review): the contingency cells write to "Metrics" (capitalised) -- on a
# case-sensitive file system that is a different directory; confirm intent.
tp24h_results.to_netcdf(work_dir / "metrics"  / "Metrics_tp24h_grid_brier_score.nc")

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

# Calculate ROCSS

In [None]:
from scores.probability import roc_curve_data

The following function is adapted from [scores](https://scores.readthedocs.io/en/stable/index.html) to support parallel computation with Dask.

In [None]:
import operator
from collections.abc import Iterable, Sequence
from typing import Optional

import numpy as np
import xarray as xr

from scores.categorical import probability_of_detection, probability_of_false_detection
from scores.processing import binary_discretise
from scores.utils import gather_dimensions

# trapz was deprecated in numpy 2.0, but trapezoid was not backported to
# earlier versions. As numpy 2.0 contains some API changes, `scores`
# will try to support both interchangeably for the time being
if not hasattr(np, "trapezoid"):
    np.trapezoid = np.trapz  # type: ignore # pragma: no cover  # tested manually


def roc_curve_data(  # pylint: disable=too-many-arguments
    fcst: xr.DataArray,
    obs: xr.DataArray,
    thresholds: Iterable[float],
    *,  # Force keywords arguments to be keyword-only
    reduce_dims: Optional[Sequence[str]] = None,
    preserve_dims: Optional[Sequence[str]] = None,
    weights: Optional[xr.DataArray] = None,
    check_args: bool = True,
) -> xr.Dataset:
    """
    Calculates data required for plotting a Receiver (Relative) Operating Characteristic (ROC)
    curve, including the area under the curve (AUC). The ROC curve is used as a way to measure
    the discrimination ability of a particular forecast.

    The AUC is the probability that the forecast probability of a random event is higher
    than the forecast probability of a random non-event.

    Args:
        fcst: An array of probabilistic forecasts for a binary event in the range [0, 1].
        obs: An array of binary values where 1 is an event and 0 is a non-event.
        thresholds: Monotonic increasing values between 0 and 1, the thresholds at and
          above which to convert the probabilistic forecast to a value of 1 (an 'event')
        reduce_dims: Optionally specify which dimensions to reduce when
            calculating the ROC curve data. All other dimensions will be preserved. As a
            special case, 'all' will allow all dimensions to be reduced. Only one
            of `reduce_dims` and `preserve_dims` can be supplied. The default behaviour
            if neither are supplied is to reduce all dims.
        preserve_dims: Optionally specify which dimensions to preserve
            when calculating ROC curve data. All other dimensions will be reduced.
            As a special case, 'all' will allow all dimensions to be
            preserved. In this case, the result will be in the same
            shape/dimensionality as the forecast, and the values will be
            the ROC curve at each point (i.e. single-value comparison
            against observed) for each threshold, and the forecast and observed dimensions
            must match precisely. Only one of `reduce_dims` and `preserve_dims` can be
            supplied. The default behaviour if neither are supplied is to reduce all dims.
        weights: Optionally provide an array for weighted averaging (e.g. by area, by latitude,
            by population, custom).
        check_args: Checks if `obs` data only contains values in the set
            {0, 1, np.nan}. You may want to skip this check if you are sure about your
            input data and want to improve the performance when working with dask.

    Returns:
        An xarray.Dataset with data variables:

        - 'POD' (the probability of detection)
        - 'POFD' (the probability of false detection)
        - 'AUC' (the area under the ROC curve)

        `POD` and `POFD` have dimensions `dims` + 'threshold', while `AUC` has
        dimensions `dims`.

    Raises:
        ValueError: if `fcst` contains values outside of the range [0, 1]
        ValueError: if `obs` contains non-nan values not in the set {0, 1}
        ValueError: if 'threshold' is a dimension in `fcst`.
        ValueError: if values in `thresholds` are not monotonic increasing or are outside
          the range [0, 1]


    Notes:
        The probabilistic `fcst` is converted to a deterministic forecast
        for each threshold in `thresholds`. If a value in `fcst` is greater
        than or equal to the threshold, then it is converted into a
        'forecast event' (fcst = 1), and a 'forecast non-event' (fcst = 0)
        otherwise. The probability of detection (POD) and probability of false
        detection (POFD) are calculated for the converted forecast. From the
        POD and POFD data, the area under the ROC curve is calculated.

        Ideally concave ROC curves should be generated rather than traditional
        ROC curves.

    """
    # NOTE(review): this definition replaces the `roc_curve_data` imported from
    # `scores.probability` earlier in the notebook.
    if check_args:
        # `.item()` needs concrete values, so these range checks evaluate `fcst`.
        if fcst.max().item() > 1 or fcst.min().item() < 0:
            raise ValueError("`fcst` contains values outside of the range [0, 1]")

        if np.max(thresholds) > 1 or np.min(thresholds) < 0:  # type: ignore
            raise ValueError("`thresholds` contains values outside of the range [0, 1]")

        if not np.all(np.array(thresholds)[1:] >= np.array(thresholds)[:-1]):
            raise ValueError("`thresholds` is not monotonic increasing between 0 and 1")

    # make a discrete forecast for each threshold in thresholds
    # discrete_fcst has an extra dimension 'threshold'
    discrete_fcst = binary_discretise(fcst, thresholds, ">=")  # type: ignore

    all_dims = set(fcst.dims).union(set(obs.dims))
    final_reduce_dims = gather_dimensions(fcst.dims, obs.dims, reduce_dims=reduce_dims, preserve_dims=preserve_dims)
    final_preserve_dims = all_dims - set(final_reduce_dims)  # type: ignore
    # `final_preserve_dims` is a set at this point, so the `None` branch below
    # is never taken; the tuple fixes one dimension order for everything after.
    auc_dims = () if final_preserve_dims is None else tuple(final_preserve_dims)
    # 'threshold' goes last so that it is the axis integrated over further down.
    final_preserve_dims = auc_dims + ("threshold",)  # type: ignore[assignment]

    pod = probability_of_detection(
        discrete_fcst, obs, preserve_dims=final_preserve_dims, weights=weights, check_args=check_args
    )

    pofd = probability_of_false_detection(
        discrete_fcst, obs, preserve_dims=final_preserve_dims, weights=weights, check_args=check_args
    )

    # Need to ensure ordering of dims is consistent for xr.apply_ufunc
    pod = pod.transpose(*final_preserve_dims)
    pofd = pofd.transpose(*final_preserve_dims)

    # Trapezoidal integral of POD over POFD along the trailing 'threshold' axis.
    # With increasing thresholds and the ">=" rule fewer points are flagged as
    # events, so POFD does not increase along that axis and the raw integral is
    # non-positive -- hence the factor -1.
    # NOTE(review): `dask="allowed"` hands any dask-backed inputs directly to
    # `np.trapezoid`; confirm this is the intended Dask-related modification.
    auc = -1 * xr.apply_ufunc(
        np.trapezoid,
        pod,
        pofd,
        input_core_dims=[pod.dims, pofd.dims],  # type: ignore
        output_core_dims=[auc_dims],
        dask="allowed",
    )

    return xr.Dataset({"POD": pod, "POFD": pofd, "AUC": auc})

In [None]:
import dask
from dask.diagnostics import ProgressBar
from dask.distributed import Client

# Settings for the dask client created below. An earlier, now commented-out
# configuration used 6 workers with a 10GB memory limit.
cluster_options = {
    "n_workers": 4,
    "threads_per_worker": 3,
    "memory_limit": '15GB',
    # "local_directory": "/data/tmp",
}

# The worker-saturation setting is active while the client is being created.
with dask.config.set({"distributed.scheduler.worker-saturation": 1.0}):
    client = Client(**cluster_options)

In [None]:
# Grid-point ROC AUC of the tx24h event forecasts for every model and lead time.
tx24h_results = []
# 101 evenly spaced thresholds in [0, 1]. `np.linspace` fixes both the number of
# points and the end point, which a float-step `np.arange` does not guarantee.
thresholds = np.linspace(0, 1, 101)

# Create the output directory before the long loop so the final write cannot
# fail on a missing folder after all the computation has been done.
(work_dir / "metrics").mkdir(parents=True, exist_ok=True)

for model_name in tqdm(model_list_heatwave):
    lead_time_data = []
    for lead_time in tqdm(lead_times, leave=False):
        forecast_path = work_dir/"Extremes/Extremes_tx24h_{}_{}_{}_{}d.nc".format(
            str(resolution).replace(".", "").ljust(3, "0"), str(year), model_name, str(lead_time).zfill(2))

        # The HRES-family models are verified against `tx24h_hres0_extremes`,
        # every other model against `tx24h_era5_extremes`.
        if model_name in ["hres", "hres_ens", "hres_ens_mean"]:
            true_events = tx24h_hres0_extremes
        else:
            true_events = tx24h_era5_extremes

        # Read the events inside a context manager and load them into memory,
        # so the file handle is released on every iteration instead of leaking.
        with xr.open_dataset(forecast_path) as forecast_file:
            forecast_dataset = forecast_file["events"].sel(quantile=0.9).astype(np.int8).load()
        if model_name == "hres_ens":
            # Scale by 50 + 1 (presumably the ensemble size -- TODO confirm) so
            # the values can be compared against thresholds in [0, 1].
            forecast_dataset = forecast_dataset/(50+1)
        true_events = true_events.sel(time=forecast_dataset.time, quantile=0.9).astype(np.int8)

        # One ROC computation per latitude row, reduced over time; only the AUC is kept.
        metrics_data = []
        for lat in tqdm(forecast_dataset["latitude"].data, leave=False):
            roc_data = roc_curve_data(forecast_dataset.sel(latitude=lat), true_events.sel(latitude=lat), thresholds,
                           reduce_dims=["time"])["AUC"]
            metrics_data.append(roc_data)
        metrics_data = xr.concat(metrics_data, dim="latitude")
        metrics_data["latitude"] = forecast_dataset["latitude"].data
        lead_time_data.append(metrics_data)

    lead_time_data = xr.concat(lead_time_data, dim="lead_time")
    lead_time_data["lead_time"] = lead_times

    tx24h_results.append(lead_time_data)

tx24h_results = xr.concat(tx24h_results, dim="model_name")
tx24h_results["model_name"] = model_list_heatwave

tx24h_results.to_netcdf(work_dir / "metrics"  / "Metrics_tx24h_grid_roc.nc")

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

In [None]:
# Write the gridded tx24h ROC results to the same file the previous cell writes at its end.
tx24h_results.to_netcdf(Path(work_dir, "metrics", "Metrics_tx24h_grid_roc.nc"))

In [None]:
# Grid-point ROC AUC of the tp24h event forecasts for every model and lead time.
tp24h_results = []
# Define the thresholds in this cell so it does not depend on a variable that
# is only created inside the tx24h cell: 101 evenly spaced values in [0, 1].
thresholds = np.linspace(0, 1, 101)

# Create the output directory before the long loop so the final write cannot
# fail on a missing folder after all the computation has been done.
(work_dir / "metrics").mkdir(parents=True, exist_ok=True)

for model_name in tqdm(model_list_rainfall):
    lead_time_data = []
    for lead_time in tqdm(lead_times, leave=False):
        forecast_path = work_dir/"Extremes/Extremes_tp24h_{}_{}_{}_{}d.nc".format(
            str(resolution).replace(".", "").ljust(3, "0"), str(year), model_name, str(lead_time).zfill(2))

        # Every rainfall model is verified against `tp24h_era5_extremes`.
        true_events = tp24h_era5_extremes

        # Read the events inside a context manager and load them into memory,
        # so the file handle is released on every iteration instead of leaking.
        with xr.open_dataset(forecast_path) as forecast_file:
            forecast_dataset = forecast_file["events"].astype(np.int8).load()
        if model_name == "hres_ens":
            # Scale by 50 + 1 (presumably the ensemble size -- TODO confirm) so
            # the values can be compared against thresholds in [0, 1].
            forecast_dataset = forecast_dataset/(50+1)
        true_events = true_events.sel(time=forecast_dataset.time)

        # One ROC computation per latitude row, reduced over time; only the AUC is kept.
        metrics_data = []
        for lat in tqdm(forecast_dataset["latitude"].data, leave=False):
            roc_data = roc_curve_data(forecast_dataset.sel(latitude=lat), true_events.sel(latitude=lat), thresholds,
                           reduce_dims=["time"])["AUC"]
            metrics_data.append(roc_data)
        metrics_data = xr.concat(metrics_data, dim="latitude")
        metrics_data["latitude"] = forecast_dataset["latitude"].data
        lead_time_data.append(metrics_data)

    lead_time_data = xr.concat(lead_time_data, dim="lead_time")
    lead_time_data["lead_time"] = lead_times

    tp24h_results.append(lead_time_data)

tp24h_results = xr.concat(tp24h_results, dim="model_name")
tp24h_results["model_name"] = model_list_rainfall

tp24h_results.to_netcdf(work_dir / "metrics"  / "Metrics_tp24h_grid_roc.nc")

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

  0%|          | 0/121 [00:00<?, ?it/s]

# Calculate the ROC curve at regional scale

In [None]:
# Open the gridded metric files for tx24h and tp24h from the "metrics" folder.
metrics_tx24h_grid = xr.open_dataset(Path(work_dir, "metrics", "Metrics_tx24h_grid.nc"))
metrics_tp24h_grid = xr.open_dataset(Path(work_dir, "metrics", "Metrics_tp24h_grid.nc"))

In [None]:
# Display the gridded tx24h metrics dataset opened in the previous cell.
metrics_tx24h_grid

In [None]:
# Weighted spatial mean of the gridded tx24h metrics for each predefined region.
regional_means = []
for region in tqdm(predefined_regions.values(), leave=False):
    # `region.apply` and `get_lat_weights` are defined elsewhere in the notebook;
    # presumably they select the region's grid cells and build latitude-based
    # weights -- TODO confirm against their definitions.
    regional_grid = region.apply(metrics_tx24h_grid)
    weights = get_lat_weights(regional_grid)
    regional_means.append(
        regional_grid.weighted(weights).mean(["latitude", "longitude"], skipna=True)
    )

# Stack the per-region results along a new "region" dimension labelled with the region names.
region_data = xr.concat(regional_means, dim="region").assign_coords(
    region=list(predefined_regions)
)

region_data.to_netcdf(Path(work_dir, "metrics", "Metrics_tx24h_region_table.nc"))

  0%|          | 0/13 [00:00<?, ?it/s]

In [None]:
# Weighted spatial mean of the gridded tp24h metrics for each predefined region.
regional_means = []
for region in tqdm(predefined_regions.values(), leave=False):
    # `region.apply` and `get_lat_weights` are defined elsewhere in the notebook;
    # presumably they select the region's grid cells and build latitude-based
    # weights -- TODO confirm against their definitions.
    regional_grid = region.apply(metrics_tp24h_grid)
    weights = get_lat_weights(regional_grid)
    regional_means.append(
        regional_grid.weighted(weights).mean(["latitude", "longitude"], skipna=True)
    )

# Stack the per-region results along a new "region" dimension labelled with the region names.
region_data = xr.concat(regional_means, dim="region").assign_coords(
    region=list(predefined_regions)
)

region_data.to_netcdf(Path(work_dir, "metrics", "Metrics_tp24h_region_table.nc"))

  0%|          | 0/13 [00:00<?, ?it/s]

# Concatenate the BS and ROCSS results

In [None]:
# Open the gridded Brier score files for tx24h and tp24h.
metrics_tx24h_grid_bs = xr.open_dataset(Path(work_dir, "metrics/Metrics_tx24h_grid_brier_score.nc"))
metrics_tp24h_grid_bs = xr.open_dataset(Path(work_dir, "metrics/Metrics_tp24h_grid_brier_score.nc"))

In [None]:
# Open the gridded ROC result files for tx24h and tp24h.
metrics_tx24h_grid_roc = xr.open_dataset(Path(work_dir, "metrics/Metrics_tx24h_grid_roc.nc"))
metrics_tp24h_grid_roc = xr.open_dataset(Path(work_dir, "metrics/Metrics_tp24h_grid_roc.nc"))

In [None]:
# Open the gridded metric files for tx24h and tp24h.
metrics_tx24h_grid = xr.open_dataset(Path(work_dir, "metrics/Metrics_tx24h_grid.nc"))
metrics_tp24h_grid = xr.open_dataset(Path(work_dir, "metrics/Metrics_tp24h_grid.nc"))

In [None]:
# Add the Brier score ("BS") and the ROC skill score ("ROCSS" = 2 * AUC - 1)
# to the tx24h ROC dataset, then save the combined dataset.
metrics_tx24h_grid_roc = metrics_tx24h_grid_roc.assign(
    BS=metrics_tx24h_grid_bs["brier_score"],
    ROCSS=metrics_tx24h_grid_roc["AUC"] * 2 - 1,
)

metrics_tx24h_grid_roc.to_netcdf(Path(work_dir, "metrics", "Metrics_tx24h_grid_bs_roc.nc"))

In [None]:
# Add the Brier score ("BS") and the ROC skill score ("ROCSS" = 2 * AUC - 1)
# to the tp24h ROC dataset, then save the combined dataset.
metrics_tp24h_grid_roc = metrics_tp24h_grid_roc.assign(
    BS=metrics_tp24h_grid_bs["brier_score"],
    ROCSS=metrics_tp24h_grid_roc["AUC"] * 2 - 1,
)

metrics_tp24h_grid_roc.to_netcdf(Path(work_dir, "metrics", "Metrics_tp24h_grid_bs_roc.nc"))