# Prepare the environment

In [None]:
! pip install "xarray[complete]"

Collecting sparse (from xarray[complete])
  Downloading sparse-0.15.4-py2.py3-none-any.whl.metadata (4.5 kB)
Collecting numbagg (from xarray[complete])
  Downloading numbagg-0.8.2-py3-none-any.whl.metadata (47 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m47.5/47.5 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
Collecting flox (from xarray[complete])
  Downloading flox-0.9.15-py3-none-any.whl.metadata (17 kB)
Collecting cartopy (from xarray[complete])
  Downloading Cartopy-0.24.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.9 kB)
Collecting nc-time-axis (from xarray[complete])
  Downloading nc_time_axis-1.4.1-py3-none-any.whl.metadata (4.7 kB)
Collecting netCDF4 (from xarray[complete])
  Downloading netCDF4-1.7.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.8 kB)
Collecting zarr (from xarray[complete])
  Downloading zarr-2.18.3-py3-none-any.whl.metadata (5.7 kB)
Collecting cftime (from xarray[complete])
  Download

# Import the necessary libraries

In [None]:
import os
import datetime
import numpy as np
import pandas as pd
import xarray as xr
from pathlib import Path
from tqdm.notebook import tqdm, trange

In [None]:
# Authenticate this Colab session with the user's Google account before the gs:// stores
# below are opened.
from google.colab import auth

auth.authenticate_user()

In [None]:
# Mount Google Drive; all cached data, thresholds and event files are written beneath it.
from google.colab import drive

drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# WeatherBench2 zarr stores on the 240x121 (1.5 degree) equiangular grid with poles.

# --- Reference datasets (analysis / reanalysis) ---
hres0_path = "gs://weatherbench2/datasets/hres_t0/2016-2022-6h-240x121_equiangular_with_poles_conservative.zarr"
era5_6h_150_path = "gs://weatherbench2/datasets/era5/1959-2023_01_10-6h-240x121_equiangular_with_poles_conservative.zarr"
era5_daily_150_path = "gs://weatherbench2/datasets/era5_daily/1959-2023_01_10-1h-240x121_equiangular_with_poles_conservative.zarr"

# --- Deterministic forecasts ---
hres_path = "gs://weatherbench2/datasets/hres/2016-2022-0012-240x121_equiangular_with_poles_conservative.zarr"
graphcast_2018_path = "gs://weatherbench2/datasets/graphcast/2018/date_range_2017-11-16_2019-02-01_12_hours-240x121_equiangular_with_poles_conservative.zarr"
graphcast_2020_path = "gs://weatherbench2/datasets/graphcast/2020/date_range_2019-11-16_2021-02-01_12_hours-240x121_equiangular_with_poles_conservative.zarr"
graphcast_hres_init_2020_path = "gs://weatherbench2/datasets/graphcast_hres_init/2020/date_range_2019-11-16_2021-02-01_12_hours-240x121_equiangular_with_poles_conservative.zarr"
fuxi_path = "gs://weatherbench2/datasets/fuxi/2020-240x121_equiangular_with_poles_conservative.zarr"
pangu_path = "gs://weatherbench2/datasets/pangu/2018-2022_0012_240x121_equiangular_with_poles_conservative.zarr"
pangu_oper_path = "gs://weatherbench2/datasets/pangu_hres_init/2020_0012_240x121_equiangular_with_poles_conservative.zarr"
neuralgcm_path = "gs://weatherbench2/datasets/neuralgcm_deterministic/2020-240x121_equiangular_with_poles_conservative.zarr"
era5_forecast_path = "gs://weatherbench2/datasets/era5-forecasts/2020-240x121_equiangular_with_poles_conservative.zarr"
keisler_path = "gs://weatherbench2/datasets/keisler/2020-240x121_equiangular_with_poles_conservative.zarr"
sphericalcnn_path = "gs://weatherbench2/datasets/sphericalcnn/2020-240x121_equiangular_with_poles.zarr"

# --- Ensemble forecasts: all members, and the ensemble mean ---
ifs_ens_path = "gs://weatherbench2/datasets/ifs_ens/2018-2022-240x121_equiangular_with_poles_conservative.zarr"
ifs_ens_mean_path = "gs://weatherbench2/datasets/ifs_ens/2018-2022-240x121_equiangular_with_poles_conservative_mean.zarr"
neuralgcm_ens_path = "gs://weatherbench2/datasets/neuralgcm_ens/2020-240x121_equiangular_with_poles_conservative.zarr"
neuralgcm_ens_mean_path = "gs://weatherbench2/datasets/neuralgcm_ens/2020-240x121_equiangular_with_poles_conservative_mean.zarr"

In [None]:
# Root folder for every file this notebook writes (Data/, Extremes/, Threshold_*.nc).
# When running outside Colab, replace with: work_dir = Path.cwd()
work_dir = Path("/content/drive/MyDrive/WB2BinaryForecast/")

# Download the data to Google Drive

## Load the ERA5 and HRES t0 data

In [None]:
# ERA5 2 m temperature every 6 h from 2019-12-31 00 UTC. The extra leading day lets the
# 24 h rolling maximum have complete windows at the first 2020 valid times.
era5_t2m6h = (
    xr.open_zarr(era5_6h_150_path)["2m_temperature"]
    .sel(time=pd.date_range(
        start=datetime.datetime(2019, 12, 31, 0, 0, 0),
        end=datetime.datetime(2021, 1, 1, 1, 0, 0),
        freq="6h",
    ))
    .load()
)

In [None]:
# HRES t0 (analysis) 2 m temperature every 6 h, over the same period as the ERA5 load above.
hres0_t2m6h = (
    xr.open_zarr(hres0_path)["2m_temperature"]
    .sel(time=pd.date_range(
        start=datetime.datetime(2019, 12, 31, 0, 0, 0),
        end=datetime.datetime(2021, 1, 1, 1, 0, 0),
        freq="6h",
    ))
    .load()
)

In [None]:
# ERA5 `total_precipitation_24hr` every 12 h, from 2020-01-01 12 UTC to 2021-01-01 00 UTC.
tp24h_era5 = (
    xr.open_zarr(era5_6h_150_path)["total_precipitation_24hr"]
    .sel(time=pd.date_range(
        start=datetime.datetime(2020, 1, 1, 12, 0, 0),
        end=datetime.datetime(2021, 1, 1, 1, 0, 0),
        freq="12h",
    ))
    .load()
)

# Calculate the 90th, 95th and 99th percentiles of the 24-hour accumulation of total precipitation (TP24h) and the 24-hour maximum of 2 m temperature (T2M24h)

## T2M24h

### ERA5

In [None]:
# Reload the same 6-hourly ERA5 2 m temperature selection as the loading section above, so
# this section also works when run on its own.
era5_t2m6h = (
    xr.open_zarr(era5_6h_150_path)["2m_temperature"]
    .sel(time=pd.date_range(
        start=datetime.datetime(2019, 12, 31, 0, 0, 0),
        end=datetime.datetime(2021, 1, 1, 1, 0, 0),
        freq="6h",
    ))
    .load()
)

In [None]:
# 24 h maximum of 2 m temperature: max over 4 consecutive 6-hourly values, labelled at the
# last value of each window (rolling default center=False), sampled every 12 h from
# 2020-01-01 12 UTC. The rolling object already reduces along `time`; passing `dim=` to
# .max() has no effect and only triggers a DeprecationWarning, so it is omitted.
tx24h_era5 = era5_t2m6h.rolling(time=4).max().sel(time=pd.date_range(
    start=datetime.datetime(2020,1,1,12,0,0),
    end=datetime.datetime(2021,1,1,1,0,0), freq="12h"))
tx24h_era5

  tx24h_era5 = era5_t2m6h.rolling(time=4).max(dim="time").sel(time=pd.date_range(


In [None]:
# 90th / 95th / 99th percentiles of ERA5 TX24h over time, at every grid point.
tx24h_era5_threshold = (
    tx24h_era5
    .chunk({"time": -1})  # a single chunk along time, as required for a dask quantile
    .quantile([0.9, 0.95, 0.99], dim="time")
    .astype(np.float32)
    .compute()
)
tx24h_era5_threshold

In [None]:

# Cache the ERA5 TX24h field and its thresholds under work_dir.
tx24h_era5.to_netcdf(path=work_dir / "Data/Data_tx24h_150_2020_era5.nc")
tx24h_era5_threshold.to_netcdf(path=work_dir / "Threshold_tx24h_150_2020_era5.nc")

In [None]:
# Boolean event mask: True where TX24h exceeds the threshold (one mask per quantile level).
era5_tx24h_extremes = xr.Dataset(data_vars={"events": tx24h_era5 > tx24h_era5_threshold})
era5_tx24h_extremes

In [None]:
# Persist the ERA5 TX24h event mask.
era5_tx24h_extremes.to_netcdf(path=work_dir / "Extremes/Extremes_tx24h_150_2020_era5.nc")

### HRES t0

In [None]:
# Reload the same 6-hourly HRES t0 2 m temperature selection as the loading section above,
# so this section also works when run on its own.
hres0_t2m6h = (
    xr.open_zarr(hres0_path)["2m_temperature"]
    .sel(time=pd.date_range(
        start=datetime.datetime(2019, 12, 31, 0, 0, 0),
        end=datetime.datetime(2021, 1, 1, 1, 0, 0),
        freq="6h",
    ))
    .load()
)

In [None]:
# 24 h maximum of HRES t0 2 m temperature: max over 4 consecutive 6-hourly values, labelled
# at the last value of each window, sampled every 12 h from 2020-01-01 12 UTC. `dim=` is not
# passed to .max(): rolling reductions ignore it and emit a DeprecationWarning.
tx24h_hres0 = hres0_t2m6h.rolling(time=4).max().sel(time=pd.date_range(
    start=datetime.datetime(2020,1,1,12,0,0),
    end=datetime.datetime(2021,1,1,1,0,0), freq="12h"))
tx24h_hres0

  tx24h_hres0 = hres0_t2m6h.rolling(time=4).max(dim="time").sel(time=pd.date_range(


In [None]:
# 90th / 95th / 99th percentiles of HRES t0 TX24h over time, at every grid point.
tx24h_hres0_threshold = (
    tx24h_hres0
    .chunk({"time": -1})  # a single chunk along time, as required for a dask quantile
    .quantile([0.9, 0.95, 0.99], dim="time")
    .astype(np.float32)
    .compute()
)
tx24h_hres0_threshold

In [None]:
# Cache the HRES t0 TX24h field and its thresholds under work_dir.
tx24h_hres0.to_netcdf(path=work_dir / "Data/Data_tx24h_150_2020_hres0.nc")
tx24h_hres0_threshold.to_netcdf(path=work_dir / "Threshold_tx24h_150_2020_hres0.nc")

In [None]:
# Boolean event mask for HRES t0: True where TX24h exceeds each quantile threshold.
hres0_tx24h_extremes = xr.Dataset(data_vars={"events": tx24h_hres0 > tx24h_hres0_threshold})
hres0_tx24h_extremes

In [None]:
# Persist the HRES t0 TX24h event mask.
hres0_tx24h_extremes.to_netcdf(path=work_dir / "Extremes/Extremes_tx24h_150_2020_hres0.nc")

## TP24h

In [None]:
# Reload the same 12-hourly ERA5 TP24h selection as the loading section above, so this
# section also works when run on its own.
tp24h_era5 = (
    xr.open_zarr(era5_6h_150_path)["total_precipitation_24hr"]
    .sel(time=pd.date_range(
        start=datetime.datetime(2020, 1, 1, 12, 0, 0),
        end=datetime.datetime(2021, 1, 1, 1, 0, 0),
        freq="12h",
    ))
    .load()
)

In [None]:
# Display the loaded ERA5 TP24h array as a quick sanity check.
tp24h_era5

In [None]:
# 90th / 95th / 99th percentiles of ERA5 TP24h over time, at every grid point.
tp24h_era5_threshold = (
    tp24h_era5
    .chunk({"time": -1})  # a single chunk along time, as required for a dask quantile
    .quantile([0.9, 0.95, 0.99], dim="time")
    .astype(np.float32)
    .compute()
)
tp24h_era5_threshold

In [None]:
# Cache the ERA5 TP24h field and its thresholds under work_dir.
tp24h_era5.to_netcdf(path=work_dir / "Data/Data_tp24h_150_2020_era5.nc")
tp24h_era5_threshold.to_netcdf(path=work_dir / "Threshold_tp24h_150_2020_era5.nc")

In [None]:
# Boolean event mask: True where TP24h exceeds each quantile threshold.
era5_tp24h_extremes = xr.Dataset(data_vars={"events": tp24h_era5 > tp24h_era5_threshold})
era5_tp24h_extremes

In [None]:
# Persist the ERA5 TP24h event mask.
era5_tp24h_extremes.to_netcdf(path=work_dir / "Extremes/Extremes_tp24h_150_2020_era5.nc")

## Recalculate the thresholds for a finer set of quantiles

In [None]:
# Quantile levels for the extended set of extreme-event thresholds.
quantile_list = [
    0.8, 0.85, 0.9, 0.93, 0.95,
    0.96, 0.97, 0.98, 0.99,
]


In [None]:
# Reload the cached reference TX24h / TP24h fields. xr.load_dataset reads each file into
# memory and closes it, unlike open_dataset(...).load(), which keeps the file handle open
# and can make re-running the cells that write these same files fail.
tx24h_era5 = xr.load_dataset(work_dir / "Data/Data_tx24h_150_2020_era5.nc")["2m_temperature"]
tx24h_hres0 = xr.load_dataset(work_dir / "Data/Data_tx24h_150_2020_hres0.nc")["2m_temperature"]
tp24h_era5 = xr.load_dataset(work_dir / "Data/Data_tp24h_150_2020_era5.nc")["total_precipitation_24hr"]

In [None]:
# Thresholds for every level in quantile_list (overwriting the 3-quantile files written
# earlier). `.compute()` materialises each result, as the earlier threshold cells do, so
# the event masks built from these variables do not re-run the quantile computation.
tx24h_era5_threshold = tx24h_era5.chunk(dict(time=-1)).quantile(quantile_list, dim="time").astype(np.float32).compute()
tx24h_era5_threshold.to_netcdf(work_dir / "Threshold_tx24h_150_2020_era5.nc")

tx24h_hres0_threshold = tx24h_hres0.chunk(dict(time=-1)).quantile(quantile_list, dim="time").astype(np.float32).compute()
tx24h_hres0_threshold.to_netcdf(work_dir / "Threshold_tx24h_150_2020_hres0.nc")

tp24h_era5_threshold = tp24h_era5.chunk(dict(time=-1)).quantile(quantile_list, dim="time").astype(np.float32).compute()
tp24h_era5_threshold.to_netcdf(work_dir / "Threshold_tp24h_150_2020_era5.nc")

In [None]:
# Rebuild the reference event masks against the extended thresholds and overwrite the files.
era5_tx24h_extremes = xr.Dataset(data_vars={"events": tx24h_era5 > tx24h_era5_threshold})
era5_tx24h_extremes.to_netcdf(path=work_dir / "Extremes/Extremes_tx24h_150_2020_era5.nc")

hres0_tx24h_extremes = xr.Dataset(data_vars={"events": tx24h_hres0 > tx24h_hres0_threshold})
hres0_tx24h_extremes.to_netcdf(path=work_dir / "Extremes/Extremes_tx24h_150_2020_hres0.nc")

era5_tp24h_extremes = xr.Dataset(data_vars={"events": tp24h_era5 > tp24h_era5_threshold})
era5_tp24h_extremes.to_netcdf(path=work_dir / "Extremes/Extremes_tp24h_150_2020_era5.nc")

# Load the thresholds

In [None]:
# Load the saved thresholds. xr.load_dataset reads each file into memory and closes it, so
# re-running the threshold cells (which rewrite exactly these files) does not collide with
# a handle left open by this cell.
tx24h_era5_threshold = xr.load_dataset(work_dir / "Threshold_tx24h_150_2020_era5.nc")["2m_temperature"]
tx24h_hres0_threshold = xr.load_dataset(work_dir / "Threshold_tx24h_150_2020_hres0.nc")["2m_temperature"]
tp24h_era5_threshold = xr.load_dataset(work_dir / "Threshold_tp24h_150_2020_era5.nc")["total_precipitation_24hr"]

# Convert the continuous forecasts to binary forecasts

## T2M24h

### From deterministic forecasts

In [None]:
def calculate_t2max_extremes(model_name, resolution, year, lead_time, recalc=False):
    """Turn one deterministic forecast into binary TX24h (24 h max 2 m temperature) events.

    The forecast steps in the window ((lead_time - 1) * 24 h, lead_time * 24 h] (6-hourly
    where available) are reduced with ``max``. The result is relabelled from initialisation
    time to valid time (initialisation + ``lead_time`` days) and compared with the quantile
    thresholds. Results are cached under ``work_dir``:

    * ``Data/Data_tx24h_<res>_<year>_<model>_<LL>d.nc`` - the TX24h field;
    * ``Extremes/Extremes_tx24h_<res>_<year>_<model>_<LL>d.nc`` - boolean ``events``.

    Args:
        model_name: Key of the global ``forecast_path_dict`` (zarr store to read).
        resolution: Grid spacing in degrees; only used to build file names (1.5 -> "150").
        year: Year whose 12-hourly initialisation / valid times are processed.
        lead_time: Lead time in whole days.
        recalc: If True, recompute the events even when the extremes file already exists.

    Raises:
        ValueError: If the forecast provides no step inside the 24 h window.
    """
    res_tag = str(resolution).replace(".", "").ljust(3, "0")
    file_stem = "tx24h_{}_{}_{}_{}d.nc".format(res_tag, str(year), model_name, str(lead_time).zfill(2))
    save_extremes_path = work_dir / ("Extremes/Extremes_" + file_stem)
    save_data_path = work_dir / ("Data/Data_" + file_stem)

    # Skip work that is already done unless a recalculation is forced.
    if save_extremes_path.exists() and not recalc:
        print("Already saved:", save_extremes_path)
        return

    if save_data_path.exists():
        # load_dataset reads into memory and closes the file (no lingering handle).
        forecast_data = xr.load_dataset(save_data_path)["2m_temperature"]
    else:
        forecast_data = xr.open_zarr(forecast_path_dict[model_name])["2m_temperature"]

        # 12-hourly times from 1 Jan of `year` up to 00 UTC on 1 Jan of the next year.
        time_range = pd.date_range(datetime.datetime(int(year), 1, 1, 0, 0, 0),
                                   datetime.datetime(int(year) + 1, 1, 1, 1, 0, 0), freq="12h")

        # 6-hourly steps in the 24 h window ending at lead_time days; keep only those the model provides.
        window_steps = np.arange((lead_time - 1) * 24 + 6, lead_time * 24 + 6, 6).astype('timedelta64[h]')
        prediction_timedelta = np.intersect1d(window_steps,
                                              forecast_data.prediction_timedelta.data.astype('timedelta64[h]'))
        if prediction_timedelta.size == 0:
            raise ValueError(
                f"{model_name} has no prediction_timedelta inside the {lead_time}-day TX24h window")

        forecast_data = forecast_data.sel(time=np.intersect1d(time_range, forecast_data.time.data),
                                          prediction_timedelta=prediction_timedelta).max(dim="prediction_timedelta").compute()

        # Relabel initialisation time -> valid time (the last step of the window).
        forecast_data["time"] = forecast_data["time"] + np.timedelta64(lead_time, 'D')
        forecast_data = forecast_data.sel(time=np.intersect1d(time_range, forecast_data.time.data))

        # Some stores use lat/lon; align names with the threshold arrays.
        if "lat" in forecast_data.coords:
            forecast_data = forecast_data.rename({'lat': 'latitude', 'lon': 'longitude'})

        forecast_data.to_netcdf(save_data_path)

    # HRES-family forecasts use thresholds derived from HRES t0; all other models use ERA5 ones.
    if model_name in ["hres", "hres_ens_mean", "hres_ens"]:
        t2max_threshold = tx24h_hres0_threshold
    else:
        t2max_threshold = tx24h_era5_threshold

    forecast_extremes = xr.Dataset({'events': forecast_data > t2max_threshold})
    forecast_extremes.to_netcdf(save_extremes_path)
    print("Save file:", save_extremes_path)


In [None]:
initial_times = [0, 12]        # not used by the driver loop below
lead_times = np.arange(1, 11)  # lead times of 1..10 days

# Model name -> WeatherBench2 store. Keisler, SphericalCNN and NeuralGCM are left out.
forecast_path_dict = dict([
    ("graphcast", graphcast_2020_path),
    ("graphcast-oper", graphcast_hres_init_2020_path),
    ("pangu", pangu_path),
    ("pangu-oper", pangu_oper_path),
    ("hres", hres_path),
    ("hres_ens_mean", ifs_ens_mean_path),
    ("hres_ens", ifs_ens_path),
    ("era5-forecast", era5_forecast_path),
    ("fuxi", fuxi_path),
])

In [None]:
resolution = 1.50
year = 2020

# Deterministic TX24h events for every configured model and every lead time.
for model_name in tqdm(forecast_path_dict):
    for lead_time in tqdm(lead_times, leave=False):
        calculate_t2max_extremes(model_name, resolution, year, lead_time)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

### From ensemble forecasts

In [None]:
def calculate_t2max_extremes(model_name, resolution, year, lead_time, recalc=False):
    """Turn each ensemble member (number 1-50) into binary TX24h events, one file per member.

    For every member, the forecast steps in the window ((lead_time - 1) * 24 h,
    lead_time * 24 h] are reduced with ``max``. The result is relabelled to valid time
    (initialisation + ``lead_time`` days) and compared with the quantile thresholds.
    Results are cached under ``work_dir`` as
    ``Data/Data_tx24h_<res>_<year>_<model>_<LL>d_<NN>em.nc`` and
    ``Extremes/Extremes_tx24h_<res>_<year>_<model>_<LL>d_<NN>em.nc``.

    Args:
        model_name: Key of the global ``forecast_path_dict``; the store must have a ``number`` dim.
        resolution: Grid spacing in degrees; only used to build file names (1.5 -> "150").
        year: Year whose 12-hourly initialisation / valid times are processed.
        lead_time: Lead time in whole days.
        recalc: If True, recompute even when both files of a member already exist.

    Raises:
        ValueError: If the forecast provides no step inside the 24 h window.
    """
    res_tag = str(resolution).replace(".", "").ljust(3, "0")
    lead_tag = str(lead_time).zfill(2)

    # Loop-invariant selections, computed once for all members.
    time_range = pd.date_range(datetime.datetime(int(year), 1, 1, 0, 0, 0),
                               datetime.datetime(int(year) + 1, 1, 1, 1, 0, 0), freq="12h")
    window_steps = np.arange((lead_time - 1) * 24 + 6, lead_time * 24 + 6, 6).astype('timedelta64[h]')

    # The remote store is opened at most once, and only if some member must be downloaded.
    source = None
    prediction_timedelta = None

    for number in trange(1, 51, leave=False):
        file_stem = "tx24h_{}_{}_{}_{}d_{}em.nc".format(
            res_tag, str(year), model_name, lead_tag, str(number).zfill(2))
        save_extremes_path = work_dir / ("Extremes/Extremes_" + file_stem)
        save_data_path = work_dir / ("Data/Data_" + file_stem)

        # Skip members that are fully done unless a recalculation is forced.
        if save_extremes_path.exists() and save_data_path.exists() and not recalc:
            print("Already saved:", save_extremes_path)
            continue

        if save_data_path.exists():
            # load_dataset reads into memory and closes the file (no lingering handle).
            forecast_data = xr.load_dataset(save_data_path)["2m_temperature"]
        else:
            if source is None:
                source = xr.open_zarr(forecast_path_dict[model_name])["2m_temperature"]
                prediction_timedelta = np.intersect1d(
                    window_steps, source.prediction_timedelta.data.astype('timedelta64[h]'))
                if prediction_timedelta.size == 0:
                    raise ValueError(
                        f"{model_name} has no prediction_timedelta inside the {lead_time}-day TX24h window")

            forecast_data = source.sel(time=np.intersect1d(time_range, source.time.data), number=number,
                                       prediction_timedelta=prediction_timedelta).max(dim="prediction_timedelta").compute()

            # Relabel initialisation time -> valid time (the last step of the window).
            forecast_data["time"] = forecast_data["time"] + np.timedelta64(lead_time, 'D')
            forecast_data = forecast_data.sel(time=np.intersect1d(time_range, forecast_data.time.data))

            # Some stores use lat/lon; align names with the threshold arrays.
            if "lat" in forecast_data.coords:
                forecast_data = forecast_data.rename({'lat': 'latitude', 'lon': 'longitude'})

            forecast_data.to_netcdf(save_data_path)

        # HRES-family forecasts use thresholds derived from HRES t0; all other models use ERA5 ones.
        if model_name in ["hres", "hres_ens_mean", "hres_ens"]:
            t2max_threshold = tx24h_hres0_threshold
        else:
            t2max_threshold = tx24h_era5_threshold

        forecast_extremes = xr.Dataset({'events': forecast_data > t2max_threshold})
        forecast_extremes.to_netcdf(save_extremes_path)
        print("Save file:", save_extremes_path)


In [None]:
initial_times = [0, 12]        # not used by the driver loop below
lead_times = np.arange(9, 10)  # only the 9-day lead time in this run

# Model name -> WeatherBench2 store. Keisler, SphericalCNN and NeuralGCM are left out.
forecast_path_dict = dict([
    ("graphcast", graphcast_2020_path),
    ("graphcast-oper", graphcast_hres_init_2020_path),
    ("pangu", pangu_path),
    ("pangu-oper", pangu_oper_path),
    ("hres", hres_path),
    ("hres_ens_mean", ifs_ens_mean_path),
    ("hres_ens", ifs_ens_path),
    ("era5-forecast", era5_forecast_path),
    ("fuxi", fuxi_path),
])

In [None]:
resolution = 1.50
year = 2020

# Member-by-member TX24h events for the ensemble store only
# (iterate over forecast_path_dict instead to process every model).
for model_name in tqdm(["hres_ens"]):
    for lead_time in tqdm(lead_times, leave=False):
        calculate_t2max_extremes(model_name, resolution, year, lead_time)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

Save file: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tx24h_150_2020_hres_ens_09d_01em.nc
Save file: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tx24h_150_2020_hres_ens_09d_02em.nc
Save file: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tx24h_150_2020_hres_ens_09d_03em.nc
Save file: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tx24h_150_2020_hres_ens_09d_04em.nc
Save file: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tx24h_150_2020_hres_ens_09d_05em.nc
Save file: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tx24h_150_2020_hres_ens_09d_06em.nc
Save file: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tx24h_150_2020_hres_ens_09d_07em.nc
Save file: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tx24h_150_2020_hres_ens_09d_08em.nc
Save file: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tx24h_150_2020_hres_ens_09d_09em.nc
Save file: /content/drive/MyDrive/WB2BinaryFor

#### Combine the 50 ensemble members into one dataset

In [None]:
%%time
resolution = 1.50
lead_times = np.arange(1, 11)
year = 2020
model_name = "hres_ens"

for lead_time in tqdm(lead_times, leave=False):
    ens_extremes = []
    for number in trange(1, 51, leave=False):
        save_number_extremes_path = work_dir / "Extremes/Extremes_tx24h_{}_{}_{}_{}d_{}em.nc".format(
            str(resolution).replace(".", "").ljust(3, "0"), str(year), model_name, str(lead_time).zfill(2), str(number).zfill(2))
        if not save_number_extremes_path.exists():
            print(save_number_extremes_path)
        
        test_data = xr.open_dataset(save_number_extremes_path)["events"]
        ens_extremes.append(test_data)
    ens_extremes = xr.concat(ens_extremes, dim="number")
    ens_extremes["number"] = np.arange(1, 51)
    ens_extremes = ens_extremes.sum(dim="number") # /50
    save_extremes_path = work_dir / "Extremes/Extremes_tx24h_{}_{}_{}_{}d.nc".format(
        str(resolution).replace(".", "").ljust(3, "0"), str(year), model_name, str(lead_time).zfill(2))
    ens_extremes.astype("i1").to_netcdf(save_extremes_path)


  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

CPU times: user 1min 12s, sys: 2min 57s, total: 4min 9s
Wall time: 14min 31s


# TP24h

## From deterministic forecasts

In [None]:
def calculate_tp24h_extremes(model_name, resolution, year, lead_time, recalc=False):
    """Turn one deterministic forecast into binary TP24h (24 h precipitation) events.

    The forecast's ``total_precipitation_24hr`` at step ``lead_time`` days is relabelled
    from initialisation time to valid time and compared with the ERA5 quantile thresholds.
    Results are cached under ``work_dir``:

    * ``Data/Data_tp24h_<res>_<year>_<model>_<LL>d.nc`` - the TP24h field;
    * ``Extremes/Extremes_tp24h_<res>_<year>_<model>_<LL>d.nc`` - boolean ``events``.

    Args:
        model_name: Key of the global ``forecast_path_dict`` (zarr store to read).
        resolution: Grid spacing in degrees; only used to build file names (1.5 -> "150").
        year: Year whose 12-hourly initialisation / valid times are processed.
        lead_time: Lead time in whole days.
        recalc: If True, recompute the events even when the extremes file already exists.
    """
    res_tag = str(resolution).replace(".", "").ljust(3, "0")
    file_stem = "tp24h_{}_{}_{}_{}d.nc".format(res_tag, str(year), model_name, str(lead_time).zfill(2))
    save_extremes_path = work_dir / ("Extremes/Extremes_" + file_stem)
    save_data_path = work_dir / ("Data/Data_" + file_stem)

    # Skip work that is already done unless a recalculation is forced.
    if save_extremes_path.exists() and not recalc:
        print("Already saved:", save_extremes_path)
        return

    if save_data_path.exists():
        # load_dataset reads into memory and closes the file (no lingering handle).
        forecast_data = xr.load_dataset(save_data_path)["total_precipitation_24hr"]
    else:
        forecast_ds = xr.open_zarr(forecast_path_dict[model_name])
        if "total_precipitation_24hr_from_6hr" in forecast_ds.data_vars:
            # FuXi names this variable ..._from_6hr and stores it in mm; divide by 1000 to match
            # the threshold units (presumably metres, as in ERA5 - confirm).
            forecast_ds = forecast_ds.rename({"total_precipitation_24hr_from_6hr": "total_precipitation_24hr"})
            forecast_ds["total_precipitation_24hr"] = forecast_ds["total_precipitation_24hr"] / 1000
        forecast_data = forecast_ds["total_precipitation_24hr"]

        # 12-hourly times from 1 Jan of `year` up to 00 UTC on 1 Jan of the next year.
        time_range = pd.date_range(datetime.datetime(int(year), 1, 1, 0, 0, 0),
                                   datetime.datetime(int(year) + 1, 1, 1, 1, 0, 0), freq="12h")
        lead = np.timedelta64(lead_time, 'D')

        # .compute() so the comparison below uses in-memory values instead of re-reading the
        # remote store after to_netcdf (matches the TX24h function).
        forecast_data = forecast_data.sel(time=np.intersect1d(time_range, forecast_data.time.data),
                                          prediction_timedelta=lead).compute()

        # Relabel initialisation time -> valid time.
        forecast_data["time"] = forecast_data["time"] + lead
        forecast_data = forecast_data.sel(time=np.intersect1d(time_range, forecast_data.time.data))

        # Some stores use lat/lon; align names with the threshold arrays.
        if "lat" in forecast_data.coords:
            forecast_data = forecast_data.rename({'lat': 'latitude', 'lon': 'longitude'})

        forecast_data.to_netcdf(save_data_path)

    forecast_extremes = xr.Dataset({'events': forecast_data > tp24h_era5_threshold})
    forecast_extremes.to_netcdf(save_extremes_path)
    print("Save file:", save_extremes_path)


In [None]:
initial_times = [0, 12]        # not used by the driver loop below
lead_times = np.arange(1, 11)  # lead times of 1..10 days

# Model name -> WeatherBench2 store for TP24h (no member-level "hres_ens" entry here).
# Keisler, SphericalCNN and NeuralGCM are left out.
forecast_path_dict = dict([
    ("graphcast", graphcast_2020_path),
    ("graphcast-oper", graphcast_hres_init_2020_path),
    ("pangu", pangu_path),
    ("pangu-oper", pangu_oper_path),
    ("hres", hres_path),
    ("hres_ens_mean", ifs_ens_mean_path),
    ("era5-forecast", era5_forecast_path),
    ("fuxi", fuxi_path),
])

In [None]:
resolution = 1.50
year = 2020

# Deterministic TP24h events for the models that provide 24 h precipitation.
for model_name in tqdm(["hres", "hres_ens_mean", "graphcast", "graphcast-oper", "fuxi"]):
    for lead_time in tqdm(lead_times, leave=False):
        calculate_tp24h_extremes(model_name, resolution, year, lead_time)

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

Already saved: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_hres_01d.nc
Already saved: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_hres_02d.nc
Already saved: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_hres_03d.nc
Already saved: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_hres_04d.nc
Already saved: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_hres_05d.nc
Already saved: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_hres_06d.nc
Already saved: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_hres_07d.nc
Already saved: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_hres_08d.nc
Already saved: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_hres_09d.nc
Already saved: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_hr

  0%|          | 0/10 [00:00<?, ?it/s]

Already saved: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_hres_ens_mean_01d.nc
Save file: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_hres_ens_mean_02d.nc
Save file: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_hres_ens_mean_03d.nc
Save file: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_hres_ens_mean_04d.nc
Save file: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_hres_ens_mean_05d.nc
Save file: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_hres_ens_mean_06d.nc
Save file: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_hres_ens_mean_07d.nc
Save file: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_hres_ens_mean_08d.nc
Save file: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_hres_ens_mean_09d.nc
Save file: /content/drive/MyDrive/WB2Binar

  0%|          | 0/10 [00:00<?, ?it/s]

Save file: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_graphcast_01d.nc
Save file: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_graphcast_02d.nc
Save file: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_graphcast_03d.nc
Save file: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_graphcast_04d.nc
Save file: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_graphcast_05d.nc
Save file: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_graphcast_06d.nc
Save file: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_graphcast_07d.nc
Save file: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_graphcast_08d.nc
Save file: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_graphcast_09d.nc
Save file: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_20

  0%|          | 0/10 [00:00<?, ?it/s]

Save file: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_graphcast-oper_01d.nc
Save file: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_graphcast-oper_02d.nc
Save file: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_graphcast-oper_03d.nc
Save file: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_graphcast-oper_04d.nc
Save file: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_graphcast-oper_05d.nc
Save file: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_graphcast-oper_06d.nc
Save file: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_graphcast-oper_07d.nc
Save file: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_graphcast-oper_08d.nc
Save file: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_graphcast-oper_09d.nc
Save file: /content/drive/MyDrive/WB2

  0%|          | 0/10 [00:00<?, ?it/s]

Save file: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_fuxi_01d.nc
Save file: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_fuxi_02d.nc
Save file: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_fuxi_03d.nc
Save file: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_fuxi_04d.nc
Save file: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_fuxi_05d.nc
Save file: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_fuxi_06d.nc
Save file: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_fuxi_07d.nc
Save file: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_fuxi_08d.nc
Save file: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_fuxi_09d.nc
Save file: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_fuxi_10d.nc


## From ensemble forecasts

In [None]:
def _open_ens_tp24h_store(model_name):
    """Lazily open ``model_name``'s Zarr store and return its 24 h precipitation DataArray.

    Stores that name the variable ``total_precipitation_24hr_from_6hr`` (FuXi)
    hold it in mm; it is renamed and divided by 1000 (mm -> m), presumably to
    match the units of the other sources -- confirm.
    """
    raw = xr.open_zarr(forecast_path_dict[model_name])
    if "total_precipitation_24hr_from_6hr" in raw.data_vars:
        raw = raw.rename({"total_precipitation_24hr_from_6hr": "total_precipitation_24hr"})
        raw["total_precipitation_24hr"] = raw["total_precipitation_24hr"] / 1000
    return raw["total_precipitation_24hr"]


def _select_ens_member(raw, number, lead_time, time_range):
    """Select one ensemble member at one lead time, indexed by valid time within ``time_range``."""
    lead = np.timedelta64(lead_time, 'D')
    # NOTE(review): initialization times are also restricted to `time_range`, so
    # valid times in the first `lead_time` days of the year have no data.
    member = raw.sel(time=np.intersect1d(time_range, raw.time.data), number=number,
                     prediction_timedelta=lead)
    # Relabel initialization time as valid time (init + lead), then keep only
    # valid times that fall inside the window.
    member["time"] = member["time"] + lead
    member = member.sel(time=np.intersect1d(time_range, member.time.data))
    if "lat" in member.coords:
        member = member.rename({'lat': 'latitude', 'lon': 'longitude'})
    return member


def calculate_tp24h_extremes(model_name, resolution, year, lead_time, recalc=False,
                             members=range(1, 51)):
    """Compute per-member 24 h precipitation extreme-event masks for an ensemble model.

    For each ensemble member, the forecast at ``lead_time`` days is taken from
    a cached ``Data/`` file if one exists, otherwise extracted from the model's
    Zarr store and cached. It is compared against ``tp24h_era5_threshold`` and
    the boolean ``events`` mask is written to ``Extremes/``.

    Parameters
    ----------
    model_name : str
        Key into the notebook-global ``forecast_path_dict``.
    resolution : float
        Grid spacing in degrees; only used to build file names (1.5 -> "150").
    year : int
        Year whose 12-hourly times are kept; also used in file names.
    lead_time : int
        Forecast lead time in days.
    recalc : bool, default False
        Recompute and overwrite even if both output files already exist.
    members : iterable of int, default range(1, 51)
        Ensemble member numbers to process.

    Notes
    -----
    Relies on the notebook globals ``work_dir``, ``forecast_path_dict`` and
    ``tp24h_era5_threshold``.
    """
    res_tag = str(resolution).replace(".", "").ljust(3, "0")
    lead_tag = str(lead_time).zfill(2)
    # 12-hourly window from 1 Jan `year` 00:00 to 1 Jan `year + 1` 01:00.
    time_range = pd.date_range(datetime.datetime(year, 1, 1, 0, 0, 0),
                               datetime.datetime(year + 1, 1, 1, 1, 0, 0), freq="12h")
    raw_forecast = None  # opened on first need, then shared by all members

    for number in tqdm(members, leave=False):
        number_tag = str(number).zfill(2)
        save_extremes_path = work_dir / "Extremes/Extremes_tp24h_{}_{}_{}_{}d_{}em.nc".format(
            res_tag, str(year), model_name, lead_tag, number_tag)
        save_data_path = work_dir / "Data/Data_tp24h_{}_{}_{}_{}d_{}em.nc".format(
            res_tag, str(year), model_name, lead_tag, number_tag)

        # Skip members that are already computed, unless a recompute is requested.
        if save_extremes_path.exists() and save_data_path.exists() and not recalc:
            print("Already saved:", save_extremes_path)
            continue

        if save_data_path.exists():
            # Read the cached member fully and release the file handle.
            with xr.open_dataset(save_data_path) as cached:
                forecast_data = cached["total_precipitation_24hr"].load()
        else:
            if raw_forecast is None:
                raw_forecast = _open_ens_tp24h_store(model_name)
            # Load once so the cache write and the threshold comparison below do
            # not each re-read the data from the remote store.
            forecast_data = _select_ens_member(raw_forecast, number, lead_time, time_range).load()
            forecast_data.to_netcdf(save_data_path)

        # Boolean mask: True where the forecast exceeds `tp24h_era5_threshold`.
        forecast_extremes = xr.Dataset({'events': forecast_data > tp24h_era5_threshold})
        forecast_extremes.to_netcdf(save_extremes_path)
        print("Save file:", save_extremes_path)


In [None]:
# Settings for the per-member ensemble extremes run.
initial_times = [0, 12]
# NOTE(review): only the 7-day lead time is processed here; widen to
# np.arange(1, 11) for a full run.
lead_times = np.arange(7, 8)

# Zarr store location for every model key accepted by `calculate_tp24h_extremes`.
forecast_path_dict = {
    "graphcast": graphcast_2020_path,
    "graphcast-oper": graphcast_hres_init_2020_path,
    "pangu": pangu_path,
    "pangu-oper": pangu_oper_path,
    "hres": hres_path,
    "hres_ens_mean": ifs_ens_mean_path,
    "hres_ens": ifs_ens_path,
    "era5-forecast": era5_forecast_path,
    # "keisler": keisler_path,
    # "sphericalcnn": sphericalcnn_path,
    "fuxi": fuxi_path,
    # "neuralgcm": neuralgcm_path,
}

In [None]:
resolution = 1.50
year = 2020

# Compute per-member extreme masks for the IFS ensemble ("hres_ens") only.
for model_name in tqdm(["hres_ens"]):
    for lead_time in tqdm(lead_times, leave=False):
        calculate_tp24h_extremes(model_name, resolution, year, lead_time)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

Already saved: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_hres_ens_07d_26em.nc
Already saved: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_hres_ens_07d_27em.nc
Already saved: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_hres_ens_07d_28em.nc
Already saved: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_hres_ens_07d_29em.nc
Already saved: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_hres_ens_07d_30em.nc
Already saved: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_hres_ens_07d_31em.nc
Already saved: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_hres_ens_07d_32em.nc
Already saved: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_hres_ens_07d_33em.nc
Already saved: /content/drive/MyDrive/WB2BinaryForecast/Extremes/Extremes_tp24h_150_2020_hres_ens_07d_34em.nc
Already sa

### Combine the 50 ensemble members into one dataset (per-point count of members forecasting an extreme)

In [None]:
%%time
resolution = 1.50
lead_times = np.arange(1, 11)
year = 2020
model_name = "hres_ens"

for lead_time in tqdm(lead_times, leave=False):
    ens_extremes = []
    for number in trange(1, 51, leave=False):
        save_number_extremes_path = work_dir / "Extremes/Extremes_tp24h_{}_{}_{}_{}d_{}em.nc".format(
            str(resolution).replace(".", "").ljust(3, "0"), str(year), model_name, str(lead_time).zfill(2), str(number).zfill(2))
        if not save_number_extremes_path.exists():
            print(save_number_extremes_path)
        
        test_data = xr.open_dataset(save_number_extremes_path)#["events"]
        ens_extremes.append(test_data)
    ens_extremes = xr.concat(ens_extremes, dim="number")
    ens_extremes["number"] = np.arange(1, 51)
    ens_extremes = ens_extremes.sum(dim="number")
    save_extremes_path = work_dir / "Extremes/Extremes_tp24h_{}_{}_{}_{}d.nc".format(
        str(resolution).replace(".", "").ljust(3, "0"), str(year), model_name, str(lead_time).zfill(2))
    ens_extremes.astype("i1").to_netcdf(save_extremes_path)

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

CPU times: user 1min 17s, sys: 2min 31s, total: 3min 49s
Wall time: 15min 50s


# Concatenate the raw forecasts into a single NetCDF file per variable

In [None]:
# Shared settings for collecting every model's per-lead-time forecasts.
resolution = 1.50
year = 2020
initial_times = [0, 12]
lead_times = np.arange(1, 11)

# Models compared for 2 m temperature (tx24h) and 24 h precipitation (tp24h).
model_list_heatwave = [
    "hres", "hres_ens_mean", "era5-forecast",
    "pangu-oper", "graphcast-oper", "pangu", "graphcast", "fuxi",
]
model_list_rainfall = ["hres", "hres_ens_mean", "graphcast", "graphcast-oper", "fuxi"]

# Zarr store location per model (no "hres_ens" entry in this mapping).
forecast_path_dict = {
    "graphcast": graphcast_2020_path,
    "graphcast-oper": graphcast_hres_init_2020_path,
    "pangu": pangu_path,
    "pangu-oper": pangu_oper_path,
    "hres": hres_path,
    "hres_ens_mean": ifs_ens_mean_path,
    "era5-forecast": era5_forecast_path,
    # "keisler": keisler_path,
    # "sphericalcnn": sphericalcnn_path,
    "fuxi": fuxi_path,
    # "neuralgcm": neuralgcm_path,
}

# Daily time stamps from 1 Jan `year` 00:00 up to 1 Jan `year + 1` 01:00.
window_start = datetime.datetime(year, 1, 1, 0, 0, 0)
window_end = datetime.datetime(year + 1, 1, 1, 1, 0, 0)
time_range = pd.date_range(window_start, window_end, freq="24h")

In [None]:
# Stack each heatwave model's 2 m temperature forecasts (one file per lead
# time) into a single Dataset with one variable per model and a `lead_time`
# dimension, then write it to one NetCDF file.
res_tag = str(resolution).replace(".", "").ljust(3, "0")

per_model_forecasts = []
for model_name in tqdm(model_list_heatwave):
    lead_time_data = []
    for lead_time in tqdm(lead_times, leave=False):
        forecast_path = work_dir / "Data/Data_tx24h_{}_{}_{}_{}d.nc".format(
            res_tag, str(year), model_name, str(lead_time).zfill(2))

        with xr.open_dataset(forecast_path) as forecast_ds:
            forecast_data = forecast_ds["2m_temperature"]
            # Shift the time labels back by one day (presumably from the end of
            # the 24 h window to its start -- confirm) and keep the daily stamps.
            forecast_data["time"] = forecast_data["time"] - np.timedelta64(1, 'D')
            forecast_data = forecast_data.sel(time=np.intersect1d(time_range, forecast_data.time.data))
            # Load the selection so the file handle can be closed.
            forecast_data = forecast_data.load()
        lead_time_data.append(forecast_data)

    lead_time_data = xr.concat(lead_time_data, dim="lead_time")
    lead_time_data["lead_time"] = lead_times
    per_model_forecasts.append(xr.Dataset({model_name: lead_time_data}))

models_data = xr.merge(per_model_forecasts)
# File name derived from `resolution`/`year` instead of a hard-coded "150_2020".
models_data.to_netcdf(work_dir / "Forecasts_tx24h_{}_{}.nc".format(res_tag, str(year)))

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

In [None]:
# Stack each rainfall model's 24 h precipitation forecasts (one file per lead
# time) into a single Dataset with one variable per model and a `lead_time`
# dimension, then write it to one NetCDF file.
res_tag = str(resolution).replace(".", "").ljust(3, "0")

per_model_forecasts = []
for model_name in tqdm(model_list_rainfall):
    lead_time_data = []
    for lead_time in tqdm(lead_times, leave=False):
        forecast_path = work_dir / "Data/Data_tp24h_{}_{}_{}_{}d.nc".format(
            res_tag, str(year), model_name, str(lead_time).zfill(2))

        with xr.open_dataset(forecast_path) as forecast_ds:
            forecast_data = forecast_ds["total_precipitation_24hr"]
            # Shift the time labels back by one day (presumably from the end of
            # the 24 h window to its start -- confirm) and keep the daily stamps.
            forecast_data["time"] = forecast_data["time"] - np.timedelta64(1, 'D')
            forecast_data = forecast_data.sel(time=np.intersect1d(time_range, forecast_data.time.data))
            # Load the selection so the file handle can be closed.
            forecast_data = forecast_data.load()
        lead_time_data.append(forecast_data)

    lead_time_data = xr.concat(lead_time_data, dim="lead_time")
    lead_time_data["lead_time"] = lead_times
    per_model_forecasts.append(xr.Dataset({model_name: lead_time_data}))

models_data = xr.merge(per_model_forecasts)
# File name derived from `resolution`/`year` instead of a hard-coded "150_2020".
models_data.to_netcdf(work_dir / "Forecasts_tp24h_{}_{}.nc".format(res_tag, str(year)))

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]