In [12]:
import xarray as xr
import pandas as pd
import numpy as np
from pathlib import Path
from scipy import stats
from linearmodels import PanelOLS
from tqdm.notebook import tqdm, trange

In [13]:
import dataclasses
import typing as t

import numpy as np
import xarray as xr


@dataclasses.dataclass
class Region:
    """Base class for region selectors used by the significance tests.

    Subclasses implement ``apply``, which receives gridded data and returns
    the part of it that belongs to the region. In this notebook ``apply`` is
    called on ``xr.DataArray`` objects (the ``events`` variable) as well as
    being typed for ``xr.Dataset``. The base class has no fields and cannot
    be applied itself.
    """

    def apply(self, dataset: xr.Dataset) -> xr.Dataset:
        """Restrict ``dataset`` to this region.

        Args:
          dataset: Gridded data with ``latitude``/``longitude`` coordinates.

        Returns:
          The selected (e.g. sliced) data.

        Raises:
          NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError


@dataclasses.dataclass
class SliceRegion(Region):
    """Latitude-longitude box selection.

    ``lat_slice`` and ``lon_slice`` are each a single label ``slice`` or a
    list of slices. With a list, the coordinates selected by every slice are
    concatenated, e.g. two longitude ranges for a box crossing the 0/360 seam.
    ``None`` is accepted, as the type hint allows, and means "no restriction"
    along that axis.

    NOTE(review): label slices such as ``slice(20, None)`` assume the
    coordinate is sorted ascending; on a descending latitude axis xarray
    would select a different band -- confirm the grid ordering of the inputs.
    """

    lat_slice: t.Optional[t.Union[slice, list[slice]]] = dataclasses.field(
      default_factory=lambda: slice(None, None))
    lon_slice: t.Optional[t.Union[slice, list[slice]]] = dataclasses.field(
      default_factory=lambda: slice(None, None))

    @staticmethod
    def _as_slice_list(spec):
        """Normalise a slice, a list/tuple of slices, or None to a list of slices."""
        if spec is None:
            return [slice(None, None)]
        if isinstance(spec, (list, tuple)):
            return list(spec)
        return [spec]

    def apply(self, dataset: xr.Dataset) -> xr.Dataset:
        """Return ``dataset`` restricted to the configured lat/lon slices.

        Args:
          dataset: Object with ``latitude`` and ``longitude`` coordinates
            (``xr.Dataset`` or ``xr.DataArray``).

        Returns:
          ``dataset.sel(...)`` at the concatenation of the coordinates
          selected by each slice, in the order the slices are listed.
        """
        lats = xr.concat(
            [dataset.latitude.sel(latitude=s)
             for s in self._as_slice_list(self.lat_slice)],
            dim='latitude')
        lons = xr.concat(
            [dataset.longitude.sel(longitude=s)
             for s in self._as_slice_list(self.lon_slice)],
            dim='longitude')

        return dataset.sel(latitude=lats, longitude=lons)


In [14]:
# Named regions for the regional significance tests. Longitudes use the
# 0-360 convention, so "360 - x" means x degrees West. Dict insertion order
# is preserved and determines the order in which regions are iterated.
predefined_regions = dict([
    # Whole globe and broad latitude bands.
    ("global", SliceRegion()),
    ("tropics", SliceRegion(lat_slice=slice(-20, 20))),
    ("extra-tropics", SliceRegion(lat_slice=[slice(None, -20), slice(20, None)])),
    # The "hemisphere" entries start at +/-20 degrees, i.e. they exclude the tropics.
    ("northern-hemisphere", SliceRegion(lat_slice=slice(20, None))),
    ("southern-hemisphere", SliceRegion(lat_slice=slice(None, -20))),
    # Continental / ocean-basin boxes. Europe straddles the 0-degree meridian
    # and is therefore built from two longitude slices.
    ("europe", SliceRegion(lat_slice=slice(35, 75),
                           lon_slice=[slice(360 - 12.5, None), slice(0, 42.5)])),
    ("north-america", SliceRegion(lat_slice=slice(25, 60),
                                  lon_slice=slice(360 - 120, 360 - 75))),
    ("north-atlantic", SliceRegion(lat_slice=slice(25, 65),
                                   lon_slice=slice(360 - 70, 360 - 10))),
    ("north-pacific", SliceRegion(lat_slice=slice(25, 60),
                                  lon_slice=slice(145, 360 - 130))),
    ("east-asia", SliceRegion(lat_slice=slice(25, 60),
                              lon_slice=slice(102.5, 150))),
    ("ausnz", SliceRegion(lat_slice=slice(-45, -12.5),
                          lon_slice=slice(120, 175))),
    # Polar caps.
    ("arctic", SliceRegion(lat_slice=slice(60, 90))),
    ("antarctic", SliceRegion(lat_slice=slice(-90, -60))),
])

In [15]:
# Base directory for input paths; Path() is the current directory, same as Path(".").
work_dir = Path()

In [16]:
# Observed (ERA5) extreme-event masks, plus the "hres0" tx24h masks
# (presumably HRES at lead time 0 -- confirm). Paths are built from work_dir,
# like the forecast paths, so changing work_dir relocates every input; the
# `with` blocks close each NetCDF file once its data is loaded into memory.
extremes_dir = work_dir / "Extremes"
with xr.open_dataset(extremes_dir / "Extremes_tp24h_150_2020_era5.nc") as ds:
    tp24h_era5_extremes = ds["events"].load()
with xr.open_dataset(extremes_dir / "Extremes_tx24h_150_2020_era5.nc") as ds:
    tx24h_era5_extremes = ds["events"].load()
with xr.open_dataset(extremes_dir / "Extremes_tx24h_150_2020_hres0.nc") as ds:
    tx24h_hres0_extremes = ds["events"].load()

In [17]:
# Run configuration for the regional significance tests below.
resolution = 1.50  # encoded as "150" in the Extremes file names
year = 2020
lead_times = np.arange(1, 11)  # 1..10, used as the "NNd" suffix of forecast files
# The first entry of each list is the HRES reference, which the test loops
# skip via [1:]. Each model must appear exactly once; a repeated entry would
# be evaluated twice and duplicate its rows in the results CSV.
model_list_heatwave = ["hres", "hres_ens", "hres_ens_mean", "era5-forecast",
                       "pangu-oper", "graphcast-oper", "pangu", "graphcast", "fuxi"]
model_list_rainfall = ["hres", "hres_ens", "hres_ens_mean", "graphcast", "graphcast-oper", "fuxi"]

# Significance tests at the global and regional scale

In [18]:
# Regional significance test for tp24h extremes.
#
# For every model (HRES is the reference and is skipped), lead time and
# region, only the (time, grid cell) points where ERA5 recorded an event are
# kept. On those, the squared-error difference (model minus HRES) is
# regressed on a constant with PanelOLS, clustering errors by grid cell and
# by time; the intercept's t-statistic is compared with two-sided 5%
# critical values of a Student t distribution.
n_ens_members = 50 + 1  # hres_ens event counts are turned into member fractions
dim_order = ("time", "longitude", "latitude")  # explicit order for masking/indexing
res_tag = str(resolution).replace(".", "").ljust(3, "0")

tp24h_results = []
for model_name in tqdm(model_list_rainfall[1:]):
    for lead_time in tqdm(lead_times, leave=False):
        lead_tag = str(lead_time).zfill(2)
        forecast_path = work_dir/"Extremes/Extremes_tp24h_{}_{}_{}_{}d.nc".format(
            res_tag, str(year), model_name, lead_tag)
        hres_path = work_dir/"Extremes/Extremes_tp24h_{}_{}_{}_{}d.nc".format(
            res_tag, str(year), "hres", lead_tag)

        with xr.open_dataset(forecast_path) as ds:
            forecast_dataset = ds["events"].load().astype(np.int8)
        with xr.open_dataset(hres_path) as ds:
            hres_dataset = ds["events"].load().astype(np.int8)
        if model_name in ["hres_ens"]:
            forecast_dataset = forecast_dataset / n_ens_members

        observe_events = tp24h_era5_extremes

        for region_name, region in tqdm(predefined_regions.items(), leave=False):
            forecast_data = region.apply(forecast_dataset.sel(quantile=0.9))
            hres_data = region.apply(hres_dataset.sel(time=forecast_data.time, quantile=0.9))
            observe_data = region.apply(observe_events.sel(time=forecast_data.time, quantile=0.9))

            # Bring all three fields into one explicit dimension order so the
            # event mask, the values and the coordinate labels below line up
            # element by element regardless of the on-disk dimension order.
            event_mask = observe_data.transpose(*dim_order).values.astype(bool)
            forecasts = forecast_data.transpose(*dim_order).values[event_mask]
            hres = hres_data.transpose(*dim_order).values[event_mask]
            # Only event points are kept, so every observation equals 1.
            observations = np.ones(forecasts.shape[0], dtype=np.int8)

            # np.nonzero yields indices in the same (C) order as boolean masking.
            time_idx, lon_idx, lat_idx = np.nonzero(event_mask)
            lat = observe_data.latitude.values[lat_idx]
            lon = observe_data.longitude.values[lon_idx]
            time = observe_data.time.values[time_idx]

            sq_err_ai = (forecasts - observations) ** 2
            sq_err_hres = (hres - observations) ** 2
            diff_ai = sq_err_ai - sq_err_hres

            df_ai = pd.DataFrame({'Latitude': lat, 'Longitude': lon, 'Time': time, 'Predicted': diff_ai})
            # Entity key per grid cell; the separator keeps lat/lon pairs unambiguous.
            df_ai['Latlon'] = (df_ai['Latitude'].round(2).astype(str) + "_"
                               + df_ai['Longitude'].round(2).astype(str))

            model_cluster = PanelOLS.from_formula(
                formula="Predicted ~ 1", data=df_ai.set_index(["Latlon", "Time"])).fit(
                    cov_type="clustered", cluster_entity=True, cluster_time=True)
            # tstats is a label-indexed Series; take the intercept positionally.
            test_st = model_cluster.tstats.iloc[0]

            t_dist = stats.t(df=len(observations))
            threshold_low = t_dist.ppf(0.025)
            threshold_high = t_dist.ppf(0.975)

            significant = int((test_st > threshold_high) or (test_st < threshold_low))
            tp24h_results.append({"model_name": model_name, "lead_time": lead_time,
                                  "region_name": region_name, "significant": significant})

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

In [19]:
# Persist the regional tp24h significance flags (one row per model / lead time / region).
tp24h_summary = pd.DataFrame(tp24h_results)
tp24h_summary.to_csv("Significant_region_tp24h.csv", index=None)

In [21]:
# Regional significance test for tx24h extremes.
#
# For every model (HRES is the reference and is skipped), lead time and
# region, only the (time, grid cell) points where ERA5 recorded an event are
# kept. On those, the squared-error difference (model minus HRES) is
# regressed on a constant with PanelOLS, clustering errors by grid cell and
# by time; the intercept's t-statistic is compared with two-sided 5%
# critical values of a Student t distribution.
#
# NOTE(review): tx24h_hres0_extremes is loaded earlier but not used by this
# test; confirm whether HRES should be verified against it instead of ERA5.
n_ens_members = 50 + 1  # hres_ens event counts are turned into member fractions
dim_order = ("time", "longitude", "latitude")  # explicit order for masking/indexing
res_tag = str(resolution).replace(".", "").ljust(3, "0")

tx24h_results = []
for model_name in tqdm(model_list_heatwave[1:]):
    for lead_time in tqdm(lead_times, leave=False):
        lead_tag = str(lead_time).zfill(2)
        forecast_path = work_dir/"Extremes/Extremes_tx24h_{}_{}_{}_{}d.nc".format(
            res_tag, str(year), model_name, lead_tag)
        hres_path = work_dir/"Extremes/Extremes_tx24h_{}_{}_{}_{}d.nc".format(
            res_tag, str(year), "hres", lead_tag)

        with xr.open_dataset(forecast_path) as ds:
            forecast_dataset = ds["events"].load().astype(np.int8)
        with xr.open_dataset(hres_path) as ds:
            hres_dataset = ds["events"].load().astype(np.int8)
        if model_name in ["hres_ens"]:
            forecast_dataset = forecast_dataset / n_ens_members

        observe_events = tx24h_era5_extremes

        for region_name, region in tqdm(predefined_regions.items(), leave=False):
            forecast_data = region.apply(forecast_dataset.sel(quantile=0.9))
            hres_data = region.apply(hres_dataset.sel(time=forecast_data.time, quantile=0.9))
            observe_data = region.apply(observe_events.sel(time=forecast_data.time, quantile=0.9))

            # Bring all three fields into one explicit dimension order so the
            # event mask, the values and the coordinate labels below line up
            # element by element regardless of the on-disk dimension order.
            event_mask = observe_data.transpose(*dim_order).values.astype(bool)
            forecasts = forecast_data.transpose(*dim_order).values[event_mask]
            hres = hres_data.transpose(*dim_order).values[event_mask]
            # Only event points are kept, so every observation equals 1.
            observations = np.ones(forecasts.shape[0], dtype=np.int8)

            # np.nonzero yields indices in the same (C) order as boolean masking.
            time_idx, lon_idx, lat_idx = np.nonzero(event_mask)
            lat = observe_data.latitude.values[lat_idx]
            lon = observe_data.longitude.values[lon_idx]
            time = observe_data.time.values[time_idx]

            sq_err_ai = (forecasts - observations) ** 2
            sq_err_hres = (hres - observations) ** 2
            diff_ai = sq_err_ai - sq_err_hres

            df_ai = pd.DataFrame({'Latitude': lat, 'Longitude': lon, 'Time': time, 'Predicted': diff_ai})
            # Entity key per grid cell; the separator keeps lat/lon pairs unambiguous.
            df_ai['Latlon'] = (df_ai['Latitude'].round(2).astype(str) + "_"
                               + df_ai['Longitude'].round(2).astype(str))

            model_cluster = PanelOLS.from_formula(
                formula="Predicted ~ 1", data=df_ai.set_index(["Latlon", "Time"])).fit(
                    cov_type="clustered", cluster_entity=True, cluster_time=True)
            # tstats is a label-indexed Series; take the intercept positionally.
            test_st = model_cluster.tstats.iloc[0]

            t_dist = stats.t(df=len(observations))
            threshold_low = t_dist.ppf(0.025)
            threshold_high = t_dist.ppf(0.975)

            significant = int((test_st > threshold_high) or (test_st < threshold_low))
            tx24h_results.append({"model_name": model_name, "lead_time": lead_time,
                                  "region_name": region_name, "significant": significant})

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

In [23]:
# Persist the regional tx24h significance flags (one row per model / lead time / region).
tx24h_summary = pd.DataFrame(tx24h_results)
tx24h_summary.to_csv("Significant_region_tx24h.csv", index=None)

# Significance tests at the grid-cell level

In [1]:
import xarray as xr
import pandas as pd
import numpy as np
from pathlib import Path
from scipy import stats
from linearmodels import PanelOLS
from tqdm.notebook import tqdm, trange
import statsmodels.api
import statsmodels as sm

In [2]:
# Base directory for input paths; Path() is the current directory, same as Path(".").
work_dir = Path()

In [3]:
# Observed (ERA5) extreme-event masks, plus the "hres0" tx24h masks
# (presumably HRES at lead time 0 -- confirm). Paths are built from work_dir,
# like the forecast paths, so changing work_dir relocates every input; the
# `with` blocks close each NetCDF file once its data is loaded into memory.
extremes_dir = work_dir / "Extremes"
with xr.open_dataset(extremes_dir / "Extremes_tp24h_150_2020_era5.nc") as ds:
    tp24h_era5_extremes = ds["events"].load()
with xr.open_dataset(extremes_dir / "Extremes_tx24h_150_2020_era5.nc") as ds:
    tx24h_era5_extremes = ds["events"].load()
with xr.open_dataset(extremes_dir / "Extremes_tx24h_150_2020_hres0.nc") as ds:
    tx24h_hres0_extremes = ds["events"].load()

In [4]:
# Run configuration, re-declared so the grid-cell section can run on its own.
year = 2020
resolution = 1.50
lead_times = np.arange(1, 10 + 1)
model_list_heatwave = [
    "hres", "hres_ens", "hres_ens_mean", "era5-forecast",
    "pangu-oper", "graphcast-oper", "pangu", "graphcast", "fuxi",
]
model_list_rainfall = [
    "hres", "hres_ens", "hres_ens_mean",
    "graphcast", "graphcast-oper", "fuxi",
]

In [5]:
def calcu_stat(x, time, alpha=0.1, return_pvalue=False):
    """Test whether the mean of one grid cell's time series differs from zero.

    Rows where ``x`` is missing are dropped (when a masked array is passed,
    pandas turns masked entries into NaN first). The remaining values are
    regressed on a constant with ``PanelOLS``, treating the whole series as a
    single entity and clustering standard errors by time. A two-sided p-value
    is taken from a Student t distribution with ``n - 1`` degrees of freedom.

    Args:
      x: 1-D values for one grid cell (here: squared-error differences).
      time: 1-D time labels aligned with ``x``.
      alpha: Level passed to ``fdrcorrection`` (default 0.1).
      return_pvalue: If True, return the raw two-sided p-value instead of the
        0/1 decision, so FDR can be controlled across many cells at once.

    Returns:
      1 if the test rejects, else 0; or the p-value when ``return_pvalue``.
      With fewer than two usable values no test is possible: 0 is returned
      (NaN when ``return_pvalue``).

    Note:
      ``fdrcorrection`` is applied to a single p-value here, where
      Benjamini-Hochberg reduces to ``p <= alpha``; it does not correct for
      multiple testing across grid cells.
    """
    panel = pd.DataFrame({"Time": time, "Values": x}).dropna()
    n_obs = len(panel)
    if n_obs < 2:
        # Not enough observations for a standard error / t-test.
        return np.nan if return_pvalue else 0

    # One artificial entity, so time is the only clustering dimension.
    panel["Latitude"] = 1
    panel = panel.set_index(["Latitude", "Time"])

    model_cluster = PanelOLS.from_formula(formula="Values ~ 1", data=panel).fit(
        cov_type="clustered", cluster_entity=False, cluster_time=True)
    # tstats is a label-indexed Series; take the intercept positionally.
    t_stat = model_cluster.tstats.iloc[0]
    p_val = 2 * stats.t.cdf(-abs(t_stat), n_obs - 1)
    if return_pvalue:
        return p_val

    # Explicit submodule import instead of relying on `import statsmodels.api`
    # having populated `statsmodels.stats`.
    from statsmodels.stats.multitest import fdrcorrection
    reject = fdrcorrection(np.array([p_val]), alpha=alpha)[0]
    return reject.astype(int)[0]

In [6]:
# Grid-cell significance test for tp24h extremes at selected lead times.
# At every grid cell, the per-time squared-error difference (model minus
# HRES), restricted to times where ERA5 recorded an event, is tested with
# calcu_stat. Results are stacked into a (lead_time, model_name, longitude,
# latitude) array of 0/1 flags. Note: one PanelOLS fit per grid cell, so this
# cell is slow.
tp24h_results = []
lead_times = [3, 5, 10]
use_models = ["hres_ens", "hres_ens_mean", "graphcast-oper", "graphcast", "fuxi"]
n_ens_members = 50 + 1  # hres_ens event counts are turned into member fractions
# Time must be axis 0 (reduced by apply_along_axis); longitude/latitude then
# match the dims of the output DataArray.
dim_order = ("time", "longitude", "latitude")
res_tag = str(resolution).replace(".", "").ljust(3, "0")
for lead_time in tqdm(lead_times):
    lead_tag = str(lead_time).zfill(2)
    hres_path = work_dir/"Extremes/Extremes_tp24h_{}_{}_{}_{}d.nc".format(
        res_tag, str(year), "hres", lead_tag)
    with xr.open_dataset(hres_path) as ds:
        hres_dataset = ds["events"].load().astype(np.int8)
    observe_events = tp24h_era5_extremes
    model_data = []
    for model_name in tqdm(use_models, leave=False):
        forecast_path = work_dir/"Extremes/Extremes_tp24h_{}_{}_{}_{}d.nc".format(
            res_tag, str(year), model_name, lead_tag)
        with xr.open_dataset(forecast_path) as ds:
            forecast_dataset = ds["events"].load().astype(np.int8)
        if model_name in ["hres_ens"]:
            forecast_dataset = forecast_dataset / n_ens_members

        forecast_data = forecast_dataset.sel(quantile=0.9).transpose(*dim_order)
        time = forecast_data.time
        hres_data = hres_dataset.sel(time=time, quantile=0.9).transpose(*dim_order)
        observe_data = observe_events.sel(time=time, quantile=0.9).transpose(*dim_order)

        # Mask every (time, cell) without an observed event.
        no_event = ~observe_data.values.astype(bool)
        forecast_masked_arr = np.ma.masked_array(forecast_data.values, no_event)
        observe_masked_arr = np.ma.masked_array(observe_data.values.astype(int), no_event)
        hres_masked_arr = np.ma.masked_array(hres_data.values, no_event)

        sq_err_ai = (forecast_masked_arr - observe_masked_arr) ** 2
        sq_err_hres = (hres_masked_arr - observe_masked_arr) ** 2
        diff_ai = sq_err_ai - sq_err_hres

        stats_values = np.apply_along_axis(calcu_stat, 0, diff_ai, time.data)
        da = xr.DataArray(stats_values,
                          coords={"longitude": forecast_data.longitude,
                                  "latitude": forecast_data.latitude},
                          dims=["longitude", "latitude"], name="significant")
        model_data.append(da)
    model_data = xr.concat(model_data, dim="model_name")
    model_data["model_name"] = use_models

    tp24h_results.append(model_data)
tp24h_results = xr.concat(tp24h_results, dim="lead_time")
tp24h_results["lead_time"] = lead_times

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]


KeyboardInterrupt



In [10]:
# Persist the grid-cell tp24h flags. Requires the grid-cell tp24h loop to have
# run to completion, so that tp24h_results is the concatenated DataArray.
grid_output_path = "Significant_Grid_tp24h.nc"
tp24h_results.to_netcdf(grid_output_path)

In [7]:
# Grid-cell significance test for tx24h extremes at selected lead times.
# At every grid cell, the per-time squared-error difference (model minus
# HRES), restricted to times where ERA5 recorded a tx24h event, is tested
# with calcu_stat. Results are stacked into a (lead_time, model_name,
# longitude, latitude) array of 0/1 flags. Note: one PanelOLS fit per grid
# cell, so this cell is slow.
#
# NOTE(review): tx24h_hres0_extremes is loaded earlier but not used by this
# test; confirm whether HRES should be verified against it instead of ERA5.
tx24h_results = []
lead_times = [3, 5, 10]
use_models = ["hres_ens", "hres_ens_mean", "pangu", "graphcast", "fuxi"]
n_ens_members = 50 + 1  # hres_ens event counts are turned into member fractions
# Time must be axis 0 (reduced by apply_along_axis); longitude/latitude then
# match the dims of the output DataArray.
dim_order = ("time", "longitude", "latitude")
res_tag = str(resolution).replace(".", "").ljust(3, "0")
for lead_time in tqdm(lead_times):
    lead_tag = str(lead_time).zfill(2)
    hres_path = work_dir/"Extremes/Extremes_tx24h_{}_{}_{}_{}d.nc".format(
        res_tag, str(year), "hres", lead_tag)
    with xr.open_dataset(hres_path) as ds:
        hres_dataset = ds["events"].load().astype(np.int8)
    # Temperature (tx24h) forecasts are verified against the tx24h ERA5 events.
    observe_events = tx24h_era5_extremes
    model_data = []
    for model_name in tqdm(use_models, leave=False):
        forecast_path = work_dir/"Extremes/Extremes_tx24h_{}_{}_{}_{}d.nc".format(
            res_tag, str(year), model_name, lead_tag)
        with xr.open_dataset(forecast_path) as ds:
            forecast_dataset = ds["events"].load().astype(np.int8)
        if model_name in ["hres_ens"]:
            forecast_dataset = forecast_dataset / n_ens_members

        forecast_data = forecast_dataset.sel(quantile=0.9).transpose(*dim_order)
        time = forecast_data.time
        hres_data = hres_dataset.sel(time=time, quantile=0.9).transpose(*dim_order)
        observe_data = observe_events.sel(time=time, quantile=0.9).transpose(*dim_order)

        # Mask every (time, cell) without an observed event.
        no_event = ~observe_data.values.astype(bool)
        forecast_masked_arr = np.ma.masked_array(forecast_data.values, no_event)
        observe_masked_arr = np.ma.masked_array(observe_data.values.astype(int), no_event)
        hres_masked_arr = np.ma.masked_array(hres_data.values, no_event)

        sq_err_ai = (forecast_masked_arr - observe_masked_arr) ** 2
        sq_err_hres = (hres_masked_arr - observe_masked_arr) ** 2
        diff_ai = sq_err_ai - sq_err_hres

        stats_values = np.apply_along_axis(calcu_stat, 0, diff_ai, time.data)
        da = xr.DataArray(stats_values,
                          coords={"longitude": forecast_data.longitude,
                                  "latitude": forecast_data.latitude},
                          dims=["longitude", "latitude"], name="significant")
        model_data.append(da)
    model_data = xr.concat(model_data, dim="model_name")
    model_data["model_name"] = use_models

    tx24h_results.append(model_data)
tx24h_results = xr.concat(tx24h_results, dim="lead_time")
tx24h_results["lead_time"] = lead_times

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

In [8]:
# Persist the grid-cell tx24h flags. Requires the grid-cell tx24h loop to have
# run to completion, so that tx24h_results is the concatenated DataArray.
grid_output_path = "Significant_Grid_tx24h.nc"
tx24h_results.to_netcdf(grid_output_path)