In [None]:
# Reload edited Python modules (e.g. src.plots.paper_plotting, imported below)
# automatically before each cell executes, so changes to the plotting code are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# --- imports: standard library, then third-party, then project-local ---
from pathlib import Path

import cartopy.crs as ccrs
import matplotlib.font_manager
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Names of the fonts matplotlib can find (handy when choosing a figure font).
# Collected before the project modules below are imported.
flist = matplotlib.font_manager.get_font_names()

from extremeweatherbench import cases, defaults, evaluate  # noqa: E402

# Paper plotting helpers. `src` has to be importable from the running kernel:
# start the kernel in the directory that contains `src/`, or append that
# directory to sys.path before this import.
import src.plots.paper_plotting as pp  # noqa: E402

# Local checkout of the paper repository -- change this to your local path.
# Kept as a str ending in "/" because later cells build file paths by plain
# string concatenation (basepath + "saved_data/...").
basepath = str(Path.home() / "extreme-weather-bench-paper") + "/"


In [None]:

# Load every event listed in the EWB events YAML file into a case collection.
print("loading in the events yaml file")
ewb_cases = cases.load_ewb_events_yaml_into_case_collection()

# Expand the cases into case operators, i.e. everything needed to evaluate a case.
# This is NOT a 1-1 mapping with ewb_cases: some cases are evaluated against more
# than one data source (a heat/cold case, for example, gets one case operator
# for ERA-5 data and another for GHCN).
brightband_evaluation_objects = defaults.get_brightband_evaluation_objects()
case_operators = cases.build_case_operators(
    ewb_cases, brightband_evaluation_objects
)


In [None]:

# Plotting the targets requires running the target pipeline for every case operator.
from joblib import Parallel, delayed  # noqa: E402
from joblib.externals.loky import get_reusable_executor  # noqa: E402

# Number of loky worker processes. Run serially this step takes a long time,
# because every case has to build all of its target information.
n_pipeline_jobs = 32


def _case_id_and_target(case_operator):
    """Run the target pipeline for one case operator.

    Returns a ``(case_id_number, target)`` tuple, where ``target`` is whatever
    ``evaluate.run_pipeline(case_metadata, case_operator.target)`` returns.
    """
    metadata = case_operator.case_metadata
    return metadata.case_id_number, evaluate.run_pipeline(metadata, case_operator.target)


print("running the pipeline for each case and target")
parallel = Parallel(n_jobs=n_pipeline_jobs, return_as="generator", backend="loky")
case_operators_with_targets_established_generator = parallel(
    delayed(_case_id_and_target)(operator) for operator in case_operators
)
# List of (case id, target dataset) tuples, in the same order as case_operators
# ("generator" mode yields results in submission order).
case_operators_with_targets_established = list(
    case_operators_with_targets_established_generator
)

# Shut the workers down to release their memory. This prints a batch of errors
# that are not consequential.
get_reusable_executor().shutdown(wait=True)


In [None]:
# Global colour palette and plot styles so every figure is consistent.
sns_palette = sns.color_palette("tab10")
sns.set_style("whitegrid")

# Colour-blind-friendly palette. The names and the alternative hex codes in the
# comments appear to come from the Okabe-Ito palette; several shades were adjusted.
accessible_colors = [
    "#3394D6",  # blue
    "#E09000",  # orange (alternative: "#E69F00")
    "#A15A7E",  # reddish purple (alternative: "#CC79A7")
    "#CC4A4A",  # vermillion (alternative: "#D55E00")
    "#A0A0A0",  # grey (alternative: black "#000000")
    "#B2B24D",  # olive
    "#33B890",  # bluish green
    "#78C6F1",  # sky blue
    "#F0E442",  # yellow
]

# Colour per model.
fourv2_style = {'color': accessible_colors[0]}
gc_style = {'color': accessible_colors[2]}
pangu_style = {'color': accessible_colors[3]}
hres_style = {'color': 'black'}

# Line style / marker per target or initial-condition group, defined once so every
# plot uses the same styling for the same group.
ghcn_group_style = {'linestyle': '-', 'marker': 'o', 'group': 'GHCN'}
era5_group_style = {'linestyle': '--', 'marker': 's', 'group': 'ERA5'}

ifs_group_style = {'linestyle': '-', 'marker': 'o', 'group': 'IFS'}
gfs_group_style = {'linestyle': ':', 'marker': 'd', 'group': 'GFS'}

global_group_style = {'linestyle': '--', 'marker': '*', 'group': 'Global'}

hres_group_style = {'linestyle': '-', 'marker': '.', 'group': 'HRES'}

# Per-model settings.
#   forecast_source: passed to pp.subset_results_to_xarray to select that model's
#                    rows from the saved results, so it must match the stored name.
#   label_str:       human-readable model name (presumably used as the legend label).
# FOURv2 is FourCastNet v2 (CIRA archive naming), so the label must read "FourCastNet".
fourv2_ifs_cira_settings = {'forecast_source': 'CIRA FOURv2 IFS', 'label_str': 'FourCastNet V2'}
fourv2_gfs_cira_settings = {'forecast_source': 'CIRA FOURv2 GFS', 'label_str': 'FourCastNet V2'}
gc_ifs_cira_settings = {'forecast_source': 'CIRA GC IFS', 'label_str': 'GraphCast'}
gc_gfs_cira_settings = {'forecast_source': 'CIRA GC GFS', 'label_str': 'GraphCast'}
pangu_ifs_cira_settings = {'forecast_source': 'CIRA PANG IFS', 'label_str': 'Pangu Weather'}
pangu_gfs_cira_settings = {'forecast_source': 'CIRA PANG GFS', 'label_str': 'Pangu Weather'}

hres_ifs_settings = {'forecast_source': 'ECMWF HRES', 'label_str': 'HRES'}

In [None]:
# Reload the evaluation results saved by an earlier run (pickles under
# <basepath>/saved_data/, presumably the DataFrames later passed as `results_df`).
# NOTE: unpickling can execute arbitrary code -- only load files you produced.
def _load_saved_results(relative_path):
    """Unpickle ``basepath + relative_path`` with pandas and return the object."""
    return pd.read_pickle(basepath + relative_path)


# heat-wave results, one file per model
fourv2_heat_results = _load_saved_results('saved_data/fourv2_heat_results.pkl')
pang_heat_results = _load_saved_results('saved_data/pang_heat_results.pkl')
hres_heat_results = _load_saved_results('saved_data/hres_heat_results.pkl')
gc_heat_results = _load_saved_results('saved_data/gc_heat_results.pkl')

# freeze results, one file per model
fourv2_freeze_results = _load_saved_results('saved_data/fourv2_freeze_results.pkl')
pang_freeze_results = _load_saved_results('saved_data/pang_freeze_results.pkl')
hres_freeze_results = _load_saved_results('saved_data/hres_freeze_results.pkl')
gc_freeze_results = _load_saved_results('saved_data/gc_freeze_results.pkl')

In [None]:
# RMSE of the IFS-initialised CIRA AI models and ECMWF HRES, scored against both
# GHCN station observations and ERA5, for heat waves and freezes.

# One settings dict per plotted line: model settings | model colour | group style.
# (`|` is dict union; on duplicate keys the right-most dict wins.)
fourv2_ifs_ghcn_settings = fourv2_ifs_cira_settings | fourv2_style | ghcn_group_style
gc_ifs_ghcn_settings = gc_ifs_cira_settings | gc_style | ghcn_group_style
pangu_ifs_ghcn_settings = pangu_ifs_cira_settings | pangu_style | ghcn_group_style
hres_ghcn_settings = hres_ifs_settings | hres_style | ghcn_group_style

fourv2_ifs_era5_settings = fourv2_ifs_cira_settings | fourv2_style | era5_group_style
gc_ifs_era5_settings = gc_ifs_cira_settings | gc_style | era5_group_style
pangu_ifs_era5_settings = pangu_ifs_cira_settings | pangu_style | era5_group_style
hres_era5_settings = hres_ifs_settings | hres_style | era5_group_style


def _subset_rmse_zeroz(results_df, settings, target_source):
    """Subset ``results_df`` to RMSE for one model against ``target_source``.

    The model is identified by ``settings['forecast_source']``. Only the
    ``init_time='zeroz'`` subset is kept (presumably 00Z initialisations).
    """
    return pp.subset_results_to_xarray(
        results_df=results_df,
        forecast_source=settings['forecast_source'],
        target_source=target_source,
        metric='RootMeanSquaredError',
        init_time='zeroz',
    )


# heat waves vs. GHCN
fourv2_heat_ifs_ghcn_plot = _subset_rmse_zeroz(fourv2_heat_results, fourv2_ifs_ghcn_settings, 'GHCN')
gc_heat_ifs_ghcn_plot = _subset_rmse_zeroz(gc_heat_results, gc_ifs_ghcn_settings, 'GHCN')
pangu_heat_ifs_ghcn_plot = _subset_rmse_zeroz(pang_heat_results, pangu_ifs_ghcn_settings, 'GHCN')
hres_heat_ghcn_plot = _subset_rmse_zeroz(hres_heat_results, hres_ghcn_settings, 'GHCN')

# heat waves vs. ERA5
fourv2_heat_ifs_era5_plot = _subset_rmse_zeroz(fourv2_heat_results, fourv2_ifs_era5_settings, 'ERA5')
gc_heat_ifs_era5_plot = _subset_rmse_zeroz(gc_heat_results, gc_ifs_era5_settings, 'ERA5')
pangu_heat_ifs_era5_plot = _subset_rmse_zeroz(pang_heat_results, pangu_ifs_era5_settings, 'ERA5')
hres_heat_era5_plot = _subset_rmse_zeroz(hres_heat_results, hres_era5_settings, 'ERA5')

# Inputs for the heat-wave line plot: data and settings are kept in matching order.
heat_data = [
    fourv2_heat_ifs_ghcn_plot, fourv2_heat_ifs_era5_plot,
    gc_heat_ifs_ghcn_plot, gc_heat_ifs_era5_plot,
    pangu_heat_ifs_ghcn_plot, pangu_heat_ifs_era5_plot,
    hres_heat_ghcn_plot, hres_heat_era5_plot,
]
heat_settings = [
    fourv2_ifs_ghcn_settings, fourv2_ifs_era5_settings,
    gc_ifs_ghcn_settings, gc_ifs_era5_settings,
    pangu_ifs_ghcn_settings, pangu_ifs_era5_settings,
    hres_ghcn_settings, hres_era5_settings,
]

# freezes vs. GHCN
fourv2_freeze_ifs_ghcn_plot = _subset_rmse_zeroz(fourv2_freeze_results, fourv2_ifs_ghcn_settings, 'GHCN')
gc_freeze_ifs_ghcn_plot = _subset_rmse_zeroz(gc_freeze_results, gc_ifs_ghcn_settings, 'GHCN')
pangu_freeze_ifs_ghcn_plot = _subset_rmse_zeroz(pang_freeze_results, pangu_ifs_ghcn_settings, 'GHCN')
hres_freeze_ghcn_plot = _subset_rmse_zeroz(hres_freeze_results, hres_ghcn_settings, 'GHCN')

# freezes vs. ERA5
fourv2_freeze_ifs_era5_plot = _subset_rmse_zeroz(fourv2_freeze_results, fourv2_ifs_era5_settings, 'ERA5')
gc_freeze_ifs_era5_plot = _subset_rmse_zeroz(gc_freeze_results, gc_ifs_era5_settings, 'ERA5')
pangu_freeze_ifs_era5_plot = _subset_rmse_zeroz(pang_freeze_results, pangu_ifs_era5_settings, 'ERA5')
hres_freeze_era5_plot = _subset_rmse_zeroz(hres_freeze_results, hres_era5_settings, 'ERA5')

# Inputs for the freeze line plot, in the same order as the heat-wave lists.
freeze_data = [
    fourv2_freeze_ifs_ghcn_plot, fourv2_freeze_ifs_era5_plot,
    gc_freeze_ifs_ghcn_plot, gc_freeze_ifs_era5_plot,
    pangu_freeze_ifs_ghcn_plot, pangu_freeze_ifs_era5_plot,
    hres_freeze_ghcn_plot, hres_freeze_era5_plot,
]
# Same line styles as the heat-wave plot, as an independent list.
freeze_settings = list(heat_settings)



In [None]:

from matplotlib.gridspec import GridSpec

# Figure 1 -- one row per event type:
#   column 0: every case of that event type (map)
#   column 1: the cases together with their target observations (map)
#   column 2: RMSE vs. ERA5/GHCN line plots (heat waves and freezes only, for now)
event_types = [
    "heat_wave",
    "freeze",
    "tropical_cyclone",
    "severe_convection",
    "atmospheric_river",
]
n_rows = len(event_types)
n_cols = 3
figsize = (12 * n_cols, 6 * n_rows)

fig = plt.figure(figsize=figsize)

# Vertical gap between rows, as a fraction of the average axes height.
# This must be set on the GridSpec itself. Because the GridSpec is given explicit
# spacing, a later fig.subplots_adjust(hspace=...) does NOT move these axes: the
# GridSpec's own parameters take precedence over the figure's subplot parameters.
row_spacing = 0.2

# GridSpec gives finer control over the mixed cartopy / regular subplot layout.
gs = GridSpec(
    n_rows, n_cols, figure=fig,
    left=0.05, right=0.95, top=0.98, bottom=0.02,
    wspace=0.1, hspace=row_spacing,
    width_ratios=[1, 1, 1.5],  # column 2 (line plots) is 1.5x wider
)

# (row, col) cells that get a cartopy PlateCarree map: columns 0 and 1 of every row.
map_columns = (0, 1)
cartopy_subplots = [(i, j) for j in map_columns for i in range(n_rows)]

# Build the axes row by row into a 2-D object array (same indexing as plt.subplots).
axs = np.empty((n_rows, n_cols), dtype=object)
for i in range(n_rows):
    for j in range(n_cols):
        # projection=None gives a regular rectilinear matplotlib Axes
        projection = ccrs.PlateCarree() if (i, j) in cartopy_subplots else None
        axs[i, j] = fig.add_subplot(gs[i, j], projection=projection)

# Column 0: all of the cases for each event type.
print("plotting the cases for each event type")
for row, event_type in enumerate(event_types):
    pp.plot_all_cases(
        ewb_cases,
        event_type=event_type,
        fill_boxes=True,
        ax=axs[row, 0],
    )

# Column 1: the cases for each event type together with their observations.
# TODO: re-enable "atmospheric_river"; until then its map axes (axs[4, 1]) is left
# empty on purpose.
obs_skipped_event_types = {"atmospheric_river"}
for row, event_type in enumerate(event_types):
    if event_type in obs_skipped_event_types:
        continue
    pp.plot_all_cases_and_obs(
        ewb_cases,
        event_type=event_type,
        targets=case_operators_with_targets_established,
        ax=axs[row, 1],
    )

# Column 2: how the choice of target observations changes the scores.
pp.plot_results_by_metric(data=heat_data, settings=heat_settings,
    title='RMSE Global Heat Waves ERA5/GHCN', show_all_in_legend=False, ax=axs[0, 2])

pp.plot_results_by_metric(data=freeze_data, settings=freeze_settings,
    title='RMSE Global Freezes ERA5/GHCN', show_all_in_legend=False, ax=axs[1, 2])

# For now, the remaining rows of column 2 have no line plot: hide those axes.
for row in range(2, n_rows):
    axs[row, 2].axis('off')

#fig.savefig(basepath + "docs/notebooks/figs/figure1.png", dpi=600)