This notebook performs some of the evaluation connected to the time domain:
- Power spectra are computed for each variable and compared between a reference and an emulated data set
- Error metrics are conditioned on season (bias and RMSE)
- Error metrics are conditioned on time of day, i.e., day (06-18 local time) and night (bias and RMSE)

In [None]:
import numpy as np
import xarray as xr

from eval_utilities import spatial_temporal_metrics as stm
from eval_utilities import visualization as vis
import matplotlib.pyplot as plt

import cartopy.crs as ccrs
import cartopy.feature as cfeature

# Load Configuration

In [None]:
import yaml

# Read the notebook settings (e.g. the data set paths used below) from
# `config.yaml` in the current working directory.
# A malformed file now raises `yaml.YAMLError` right here instead of only being
# printed: previously CONFIG stayed undefined after a parse error and the next
# cell failed with an unrelated NameError.
with open("config.yaml") as stream:
    CONFIG = yaml.safe_load(stream)

In [None]:
# Both data sets are restricted to the same evaluation window so that they can
# be compared time step by time step.
eval_period = slice("2021-01-01T00", "2022-11-30T00")

# Reference data set
ds_ref = xr.open_zarr(CONFIG["path_ec_euro"]).sel(time=eval_period)
# Emulated data set
ds_mod = xr.open_zarr(CONFIG["path_xgb_v1"]).sel(time=eval_period)

# Harmonic Analysis

## Power Spectra

Compute the power spectra of all variables contained in both data sets.

In [None]:
# Location handed to `vis.power_spectrum` for storing the spectrum plots:
path_png = "/home/ch23/data_ch23/evalution_results/xgbosst_train_2010_2019_val_2020_2020_est_50_hist/visualization/spectrum"

# Only variables present in both the reference and the emulated data set can be compared.
common_vars = np.intersect1d(ds_ref.variable, ds_mod.variable)

# One power-spectrum comparison (emulator vs. reference) per shared variable.
for var in common_vars:
    vis.power_spectrum(ds_mod, ds_ref, var, path_png)

## Spatial Maps

In [None]:
def plot_amplitude_map(data_ref, data_mod, path_png, freq_keyword, var_name=None):
    """
    Helper function for plotting the spatial maps related to the harmonic analysis.

    Draws two side-by-side maps (left: emulator, right: reference) in which every
    point is coloured by the given amplitude, and saves the figure as
    `{path_png}_{var_name}_{freq_keyword.lower()}.png`.

    Parameters
    ----------
    data_ref : array-like
        Amplitudes of the reference data, one value per point of
        `ds_ref.clim_data["lon"]` / `ds_ref.clim_data["lat"]`.
    data_mod : array-like
        Amplitudes of the emulated data, same layout as `data_ref`.
    path_png : str
        Path prefix of the output file (the variable name and keyword are appended).
    freq_keyword : str
        Appears in the plot title and is made lower case for the file name.
    var_name : str, optional
        Variable name used in the title and file name. If omitted, the
        notebook-level loop variable `var` is used (the original behaviour).

    Notes
    -----
    Reads the notebook-level `ds_ref` for the map extent and the point coordinates.
    """
    if var_name is None:
        # Backward-compatible fallback on the global loop variable of the calling cell.
        var_name = var

    fig, axs = plt.subplots(1, 2, figsize=(12,4), subplot_kw={'projection': ccrs.PlateCarree()})
    fig.suptitle(f"Amplitude of {freq_keyword} Frequency for {var_name}")

    # The map extent is identical for both panels, so it is determined only once.
    extent = [ds_ref["lon"].min(), ds_ref["lon"].max(),
              ds_ref["lat"].min(), ds_ref["lat"].max()]

    panels = zip(axs, (data_mod, data_ref), ("Emulator", "Reference"))
    for ax, amplitudes, panel_title in panels:
        # Add map projection details:
        ax.set_extent(extent, crs=ccrs.PlateCarree())
        ax.add_feature(cfeature.LAND)
        ax.add_feature(cfeature.OCEAN)
        ax.add_feature(cfeature.COASTLINE)
        ax.add_feature(cfeature.BORDERS, linestyle=':')

        im = ax.scatter(ds_ref.clim_data["lon"], ds_ref.clim_data["lat"], c=amplitudes, edgecolor='none', s=10)
        ax.set(title=panel_title)
        # Attach the colourbar explicitly to its own panel; without `ax=` the
        # panel that gives up space depends on the matplotlib version.
        fig.colorbar(im, ax=ax, fraction=0.045, pad=0.04)

    fig.savefig(f"{path_png}_{var_name}_{freq_keyword.lower()}.png", bbox_inches="tight")

In [None]:
# Path prefix for saving the plots (`plot_amplitude_map` appends "_<var>_<keyword>.png"):
path_png = "/home/ch23/data_ch23/evalution_results/xgbosst_train_2010_2019_val_2020_2020_est_50_hist/visualization/harmonic_analysis"

# Frequencies of interest in Hz (cycles per second), keyed by the keyword used
# in the plot title and file name. "Seasonal" corresponds to four cycles per year.
target_frequencies = {
    "Diurnal": 1/(24*60*60),
    "Monthly": 1/(30*24*60*60),
    "Seasonal": 4/(365*24*60*60),
    "Annual": 1/(365*24*60*60),
}

# Frequency axis of the real FFT. The sampling interval is taken from the first
# two time stamps and converted to seconds (assumes `.item()` yields nanoseconds,
# i.e. a timedelta64[ns] time axis -- TODO confirm). It does not depend on the
# variable, so it is computed once.
sampling_interval = (ds_ref.time[1] - ds_ref.time[0]).item() / 1e9
freq = np.fft.rfftfreq(ds_ref.sizes["time"], d=sampling_interval)

# Loop through all variables contained in both data sets:
common_vars = np.intersect1d(ds_ref.variable, ds_mod.variable)
for var in common_vars:
    # Move "time" to the front by name so that the FFT axis and the frequency
    # index used below always refer to the same axis. Previously the axis was
    # guessed from the array shape while the result was always indexed on axis 0.
    fft_ref = np.fft.rfft(ds_ref.data.sel(variable=var).transpose("time", ...), axis=0)
    fft_mod = np.fft.rfft(ds_mod.data.sel(variable=var).transpose("time", ...), axis=0)

    for keyword, target in target_frequencies.items():
        # Pick the FFT bin closest to the requested frequency.
        i_freq = np.argmin(np.abs(freq - target))
        plot_amplitude_map(abs(fft_ref[i_freq]), abs(fft_mod[i_freq]), path_png, keyword)
    

# Condition on Season

In [None]:
# Month numbers belonging to each three-month season.
season_months = {
    "DJF": [12, 1, 2],
    "MAM": [3, 4, 5],
    "JJA": [6, 7, 8],
    "SON": [9, 10, 11],
}

# One boolean mask along the reference time axis per season.
season_masks = {
    season: ds_ref["time"].dt.month.isin(months)
    for season, months in season_months.items()
}

In [None]:
# Output directory for the box plots created in the cells below
# (replaces the `path_png` value set in the harmonic analysis section):
path_png = "/home/ch23/data_ch23/evalution_results/xgbosst_train_2010_2019_val_2020_2020_est_50_hist/temporal"

In [None]:
# Variables available in both the emulated and the reference data set.
common_vars = np.intersect1d(ds_mod.variable, ds_ref.variable)

n_points = ds_ref.sizes["x"]

for var in common_vars:
    # One row per season, one column per point along "x".
    seasonal_results = np.full((len(season_masks), n_points), np.nan)

    for season_row, mask in zip(seasonal_results, season_masks.values()):
        # `season_row` is a view, so this fills `seasonal_results` in place.
        season_row[:] = stm.bias(ds_mod.isel(time=mask), ds_ref.isel(time=mask), vars=var)

    # Clip the y-axis to the 0.05th .. 99.95th percentile so that a few extreme
    # points do not dominate the plot.
    y_low = np.nanpercentile(seasonal_results, 0.05)
    y_high = np.nanpercentile(seasonal_results, 99.95)

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.boxplot(seasonal_results.T, labels=season_masks.keys())
    ax.set_ylim(y_low, y_high)
    ax.set(title=f"{var} bias in different seasons")

    fig.savefig(f"{path_png}/bias_season_{var}.png", bbox_inches="tight")

In [None]:
# Variables available in both the emulated and the reference data set.
common_vars = np.intersect1d(ds_mod.variable, ds_ref.variable)

n_points = ds_ref.sizes["x"]

for var in common_vars:
    # One row per season, one column per point along "x".
    seasonal_results = np.full((len(season_masks), n_points), np.nan)

    for season_row, mask in zip(seasonal_results, season_masks.values()):
        # `season_row` is a view, so this fills `seasonal_results` in place.
        season_row[:] = stm.rmse(ds_mod.isel(time=mask), ds_ref.isel(time=mask), vars=var)

    # Clip the y-axis to the 0.05th .. 99.95th percentile so that a few extreme
    # points do not dominate the plot.
    y_low = np.nanpercentile(seasonal_results, 0.05)
    y_high = np.nanpercentile(seasonal_results, 99.95)

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.boxplot(seasonal_results.T, labels=season_masks.keys())
    ax.set_ylim(y_low, y_high)
    ax.set(title=f"{var} rmse in different seasons")

    fig.savefig(f"{path_png}/rmse_season_{var}.png", bbox_inches="tight")

# Condition on Time

In [None]:
# Hour of the time stamp, repeated for every point along "x".
# NOTE(review): assumes the time axis is in UTC -- confirm.
standard_time = ds_ref["time"].dt.hour.expand_dims(dim={"x": ds_ref["x"]})

# Approximate local solar time: the sun takes 4 min to traverse 1° longitude.
# Places east of Greenwich (positive longitude, the convention the map plots
# in this notebook rely on) are AHEAD of UTC, so the offset must be added;
# it was subtracted before, which mirrored day and night around 0° longitude.
local_time = standard_time + (4 * ds_ref["lon"])/60.
local_time = local_time.T % 24 # wrap values outside [0, 24) back into the day

# Day is defined as 06-18 local time, everything else counts as night.
day_mask = (local_time >= 6) & (local_time < 18)

In [None]:
# Variables available in both the emulated and the reference data set.
common_vars = np.intersect1d(ds_mod.variable, ds_ref.variable)

n_points = ds_ref.sizes["x"]

for var in common_vars:
    # Row 0 holds the day-time result, row 1 the night-time result.
    diurnal_results = np.full((2, n_points), np.nan)

    for period_row, mask in zip(diurnal_results, (day_mask, ~day_mask)):
        # `period_row` is a view, so this fills `diurnal_results` in place.
        period_row[:] = stm.bias(ds_mod.where(mask), ds_ref.where(mask), vars=var)

    # Clip the y-axis to the 0.05th .. 99.95th percentile so that a few extreme
    # points do not dominate the plot.
    y_low = np.nanpercentile(diurnal_results, 0.05)
    y_high = np.nanpercentile(diurnal_results, 99.95)

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.boxplot(diurnal_results.T, labels=["day","night"])
    ax.set_ylim(y_low, y_high)
    ax.set(title=f"{var} bias in different times of day")

    fig.savefig(f"{path_png}/bias_diurnal_{var}.png", bbox_inches="tight")

In [None]:
# Variables available in both the emulated and the reference data set.
common_vars = np.intersect1d(ds_mod.variable, ds_ref.variable)

n_points = ds_ref.sizes["x"]

for var in common_vars:
    # Row 0 holds the day-time result, row 1 the night-time result.
    diurnal_results = np.full((2, n_points), np.nan)

    for period_row, mask in zip(diurnal_results, (day_mask, ~day_mask)):
        # `period_row` is a view, so this fills `diurnal_results` in place.
        period_row[:] = stm.rmse(ds_mod.where(mask), ds_ref.where(mask), vars=var)

    # Clip the y-axis to the 0.05th .. 99.95th percentile so that a few extreme
    # points do not dominate the plot.
    y_low = np.nanpercentile(diurnal_results, 0.05)
    y_high = np.nanpercentile(diurnal_results, 99.95)

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.boxplot(diurnal_results.T, labels=["day","night"])
    ax.set_ylim(y_low, y_high)
    ax.set(title=f"{var} rmse in different times of day")

    fig.savefig(f"{path_png}/rmse_diurnal_{var}.png", bbox_inches="tight")