# Plot evolution of scores for all-at-once (aao) vs sequential (seq) assimilation (script compare_aao_vs_seq_glsd_data.py)

In [None]:
# Load datasets.
import os
import numpy as np
import xarray as xr


# Root directory of the comparison results; it holds one sub-folder per period.
base_results_folder = "/storage/homefs/ct19x463/Dev/Climate/reporting/all_at_once_vs_sequential/"
years_folders = [
    "all_at_once_vs_sequential_62_64/",
    "all_at_once_vs_sequential_65_68/",
    "all_at_once_vs_sequential_69_71/",
    "all_at_once_vs_sequential_72_74/",
    "all_at_once_vs_sequential_75_77/",
    "all_at_once_vs_sequential_78_80/",
]


def _period_files(filename):
    """Return the path of `filename` inside every period folder, in folder order."""
    return [os.path.join(base_results_folder, folder, filename) for folder in years_folders]


# Each dataset is assembled from the per-period files of the same name.
prior_means = xr.open_mfdataset(paths=_period_files("prior_means.nc"))
updated_means_aao = xr.open_mfdataset(paths=_period_files("updated_means_aao.nc"))
updated_means_seq = xr.open_mfdataset(paths=_period_files("updated_means_seq.nc"))
references = xr.open_mfdataset(paths=_period_files("references.nc"))

In [None]:
# Compute scores.
from diesel.scoring import compute_RE_score, compute_CRPS, compute_energy_score, compute_RMSE

# One entry per time step is appended to each of these score series.
prior_RMSEs = []
aao_RMSEs = []
seq_RMSEs = []
aao_median_REs = []
seq_median_REs = []

# Settings shared by every iteration.
horizontal_dims = ('latitude', 'longitude')
lat_bounds = {'min_lat': -70, 'max_lat': 70}

for time in updated_means_aao.time:
    # Temperature field of each dataset at the current time step.
    prior_mean = prior_means.sel(time=time).temperature
    updated_mean_seq = updated_means_seq.sel(time=time).temperature
    updated_mean_aao = updated_means_aao.sel(time=time).temperature
    reference = references.sel(time=time).temperature

    # Collapse (latitude, longitude) into a single dimension. The means are
    # turned into plain arrays, the reference stays an xarray object.
    stacked_prior_mean = prior_mean.stack(stacked_dim=horizontal_dims).values
    stacked_updated_mean_seq = updated_mean_seq.stack(stacked_dim=horizontal_dims).values
    stacked_updated_mean_aao = updated_mean_aao.stack(stacked_dim=horizontal_dims).values
    stacked_reference = reference.stack(stacked_dim=horizontal_dims).compute()

    # RMSE of the prior, the all-at-once update and the sequential update.
    rmse_targets = (
        (prior_RMSEs, stacked_prior_mean),
        (aao_RMSEs, stacked_updated_mean_aao),
        (seq_RMSEs, stacked_updated_mean_seq),
    )
    for score_series, stacked_mean in rmse_targets:
        score_series.append(compute_RMSE(stacked_mean, stacked_reference, **lat_bounds))

    # Median RE score: sequential first, then all-at-once.
    RE_score_map = compute_RE_score(
        stacked_prior_mean, stacked_updated_mean_seq, stacked_reference, **lat_bounds)
    seq_median_REs.append(np.median(RE_score_map))
    RE_score_map = compute_RE_score(
        stacked_prior_mean, stacked_updated_mean_aao, stacked_reference, **lat_bounds)
    aao_median_REs.append(np.median(RE_score_map))

In [None]:
# Collect into dataframe.
import pandas as pd 

# One column per score series, plus the time stamps used as plotting axis.
score_columns = {
    'time': list(updated_means_aao.time.values),
    'prior RMSE': prior_RMSEs,
    'aao RMSE': aao_RMSEs,
    'seq RMSE': seq_RMSEs,
    'aao median RE': aao_median_REs,
    'seq median RE': seq_median_REs,
}
scores_df = pd.DataFrame.from_dict(score_columns)

In [None]:
# Plot results.
# RMSE.
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.dates import DateFormatter
import matplotlib.dates as mdates


# Global look of the figures: seaborn "white" style with serif fonts.
sns.set()
sns.set_style("white")
plt.rcParams["font.family"] = "serif"

# Base font settings; every label-like element is drawn in 'x-small'.
plot_params = {'font.size': 26, 'font.style': 'normal'}
plot_params.update({
    size_key: 'x-small'
    for size_key in ('axes.labelsize', 'axes.titlesize', 'legend.fontsize',
                     'xtick.labelsize', 'ytick.labelsize')
})
plt.rcParams.update(plot_params)

sns.set_palette("twilight_shifted_r")


plt.figure(figsize=(35, 10))
# Drawing order matters: the legend labels below are listed in the same order.
for rmse_column in ('aao RMSE', 'seq RMSE', 'prior RMSE'):
    ax = sns.lineplot(data=scores_df, x="time", y=rmse_column, lw=4)

# Time axis: span the whole record, one vertical tick label every three months.
first_date = scores_df['time'].min()
last_date = scores_df['time'].max()
ax.tick_params(axis='x', rotation=90)
ax.set_xlim(first_date, last_date)
ax.xaxis.set_major_locator(mdates.MonthLocator(interval=3))
ax.xaxis.set_major_formatter(DateFormatter("%Y-%m"))

ax.set_xlabel("")
ax.set_ylabel("RMSE")
plt.legend(["all-at-once", "sequential", "prior"])
plt.savefig("RMSE_comparison", bbox_inches="tight", pad_inches=0.1, dpi=400)
plt.show()

In [None]:
# Plot results.
# RE score.
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.dates import DateFormatter
import matplotlib.dates as mdates


plt.figure(figsize=(35, 10))
# Attach every label to its own curve. The previous positional label list
# (["sequential", "all-at-once"]) was in the opposite order of the plotting
# calls, so the two curves were mislabelled in the legend.
re_curves = (
    ('aao median RE', "all-at-once"),
    ('seq median RE', "sequential"),
)
for re_column, curve_label in re_curves:
    ax = sns.lineplot(data=scores_df, x="time", y=re_column, lw=4, label=curve_label)

# Time axis: span the whole record, one vertical tick label every three months.
ax.tick_params(axis='x', rotation=90)
ax.set_xlim(scores_df['time'].min(), scores_df['time'].max())
ax.xaxis.set_major_locator(mdates.MonthLocator(interval=3))
ax.xaxis.set_major_formatter(DateFormatter("%Y-%m"))
plt.xlabel("")
plt.ylabel("median RE score")
# No explicit labels: the legend is built from the labels set on the curves.
plt.legend()
plt.savefig("RE_comparison", bbox_inches="tight", pad_inches=0.1, dpi=400)
plt.show()

## Study spatial distribution.

In [None]:
# RE score arrays, one per time step, for each assimilation scheme.
seq_RE_maps = []
aao_RE_maps = []

# Settings shared by every iteration.
horizontal_dims = ('latitude', 'longitude')
lat_bounds = {'min_lat': -70, 'max_lat': 70}

for time in updated_means_aao.time:
    # Temperature field of each dataset at the current time step.
    prior_mean = prior_means.sel(time=time).temperature
    updated_mean_seq = updated_means_seq.sel(time=time).temperature
    updated_mean_aao = updated_means_aao.sel(time=time).temperature
    reference = references.sel(time=time).temperature

    # Collapse (latitude, longitude) into a single dimension. The means are
    # turned into plain arrays, the reference stays an xarray object.
    stacked_prior_mean = prior_mean.stack(stacked_dim=horizontal_dims).values
    stacked_updated_mean_seq = updated_mean_seq.stack(stacked_dim=horizontal_dims).values
    stacked_updated_mean_aao = updated_mean_aao.stack(stacked_dim=horizontal_dims).values
    stacked_reference = reference.stack(stacked_dim=horizontal_dims).compute()

    # Sequential scheme first, then all-at-once.
    RE_score_map = compute_RE_score(
        stacked_prior_mean, stacked_updated_mean_seq, stacked_reference, **lat_bounds)
    seq_RE_maps.append(RE_score_map)
    RE_score_map = compute_RE_score(
        stacked_prior_mean, stacked_updated_mean_aao, stacked_reference, **lat_bounds)
    aao_RE_maps.append(RE_score_map)

In [None]:
# Sanity check: shapes of the flattened arrays left over from the last loop iteration.
for flattened in (stacked_prior_mean, stacked_updated_mean_seq, stacked_reference):
    print(flattened.shape)

# Plot RE scores

In [None]:
def compute_RE_score_V2(updated_means, prior_means, references, min_lat=-70, max_lat=70):
    """ Compute the reduction of error (RE) skill score at every valid grid cell.

    For each cell the score is
    ``1 - MSE(updated, reference) / MSE(prior, reference)``,
    both mean squared errors being averaged over the time axis.
    This is V2 since it uses the correct definition (the old one was wrong).

    Parameters
    ----------
    updated_means, prior_means, references
        Datasets exposing ``temperature.data`` as a (time, latitude, longitude)
        array. The arrays may be dask- or NumPy-backed; all three must have the
        shape of the reference. ``references`` must also expose ``latitude``.
    min_lat, max_lat: float
        Only latitudes strictly between these two bounds are scored.

    Returns
    -------
    RE_score: numpy.ndarray, shape (n_valid_cells,)
        Score of every cell of the latitude-filtered grid (flattened in
        row-major order) whose reference is not NaN at the first time step.
        Use `RE_score_to_map` to put it back on a (lat, lon) grid.
        Cells where the prior error is exactly zero yield inf or NaN.

    """
    # Filter out extremal latitudes.
    lats = np.asarray(references.latitude)
    lat_filter_inds = (lats < max_lat) & (lats > min_lat)

    # All fields are flattened to the shape of the (filtered) reference.
    filtered_ref = references.temperature.data[:, lat_filter_inds, :]
    n_times = filtered_ref.shape[0]
    n_cells = filtered_ref.shape[1] * filtered_ref.shape[2]

    def _filter_and_flatten(field):
        # np.asarray evaluates lazy (dask) arrays and leaves NumPy arrays as they are.
        return np.asarray(field[:, lat_filter_inds, :].reshape((n_times, n_cells)))

    stacked_ref = _filter_and_flatten(references.temperature.data)
    stacked_prior_means = _filter_and_flatten(prior_means.temperature.data)
    stacked_updated_means = _filter_and_flatten(updated_means.temperature.data)

    # Cut to the shape of the reference (continents only) by getting rid of NaNs.
    # Shape of the NaNs is always the same, so we filter using the shape at time 0.
    valid_inds = ~np.isnan(stacked_ref[0, :])
    stacked_ref_nonan = stacked_ref[:, valid_inds]
    stacked_prior_means_nonan = stacked_prior_means[:, valid_inds]
    stacked_updated_means_nonan = stacked_updated_means[:, valid_inds]

    # Compute score, averaging over time axis.
    updated_mse = np.mean((stacked_updated_means_nonan - stacked_ref_nonan)**2, axis=0)
    prior_mse = np.mean((stacked_prior_means_nonan - stacked_ref_nonan)**2, axis=0)
    RE_score = 1 - updated_mse / prior_mse

    return RE_score

def RE_score_to_map(RE_score, references, min_lat=-70, max_lat=70):
    """ Put a flat RE score vector back on the latitude-filtered (lat, lon) grid.

    Redoes the filtering operations done when computing the RE score
    (latitude band, then removal of the cells that are NaN in the reference
    at the first time step) so that they can be undone.

    Parameters
    ----------
    RE_score: array-like, shape (n_valid_cells,)
        Scores as returned by `compute_RE_score_V2` called with the same
        `references`, `min_lat` and `max_lat`.
    references
        Dataset exposing ``temperature.data`` as a (time, latitude, longitude)
        array (dask- or NumPy-backed), plus ``latitude`` and ``longitude``.
    min_lat, max_lat: float
        Only latitudes strictly between these two bounds are kept.

    Returns
    -------
    dummy_plot_dataset: numpy.ndarray, shape (n_kept_lats, n_lons)
        Scores on the grid, NaN where the reference is NaN. The grid only
        covers the kept latitudes; it is not padded back to the full range.
    dummy_plot_ds_lats: numpy.ndarray
        Latitudes of the rows of the grid.
    dummy_plot_ds_lons: numpy.ndarray
        Longitudes of the columns of the grid.

    """
    # Filter out extremal latitudes.
    lats = np.asarray(references.latitude)
    lat_filter_inds = (lats < max_lat) & (lats > min_lat)
    filtered_ref = references.temperature.data[:, lat_filter_inds, :]
    n_lats, n_lons = filtered_ref.shape[1], filtered_ref.shape[2]

    # Shape of the NaNs is always the same, so the first time step is enough.
    # np.asarray evaluates lazy (dask) arrays and leaves NumPy arrays as they are.
    first_step = np.asarray(filtered_ref[0]).reshape(n_lats * n_lons)
    nan_inds = np.isnan(first_step)

    # Put the scores back at the non-NaN positions of the flattened grid,
    # then restore the two spatial dimensions.
    dummy_plot_dataset = np.full(n_lats * n_lons, np.nan)
    dummy_plot_dataset[~nan_inds] = RE_score
    dummy_plot_dataset = dummy_plot_dataset.reshape((n_lats, n_lons))

    # Return corresponding lat/lon. The longitudes are read from the
    # `references` argument (a notebook-level variable was used here before).
    dummy_plot_ds_lats = lats[lat_filter_inds]
    dummy_plot_ds_lons = np.asarray(references.longitude)

    return dummy_plot_dataset, dummy_plot_ds_lats, dummy_plot_ds_lons

In [None]:
# Same latitude band for the score computation and for the mapping back.
lat_band = {'min_lat': -70, 'max_lat': 70}

# Sequential assimilation.
RE_score_seq = compute_RE_score_V2(updated_means_seq, prior_means, references, **lat_band)
(dummy_plot_dataset_seq,
 dummy_plot_ds_lats,
 dummy_plot_ds_lons) = RE_score_to_map(RE_score_seq, references, **lat_band)

# All-at-once assimilation (the lat/lon outputs are assigned a second time).
RE_score_aao = compute_RE_score_V2(updated_means_aao, prior_means, references, **lat_band)
(dummy_plot_dataset_aao,
 dummy_plot_ds_lats,
 dummy_plot_ds_lons) = RE_score_to_map(RE_score_aao, references, **lat_band)

In [None]:
import cartopy.crs as ccrs
import matplotlib.colors
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable

In [None]:
fig, axs = plt.subplots(
    nrows=1, ncols=2,
    subplot_kw={'projection': ccrs.Miller()},
    figsize=(15, 17),
)

# Try custom coloring: a discrete inferno colormap whose first bin is black.
# It is built here but not passed to the plotting calls below.
levels = np.linspace(0, 1, 11)
norm = matplotlib.colors.BoundaryNorm(levels, len(levels))
colors = list(plt.cm.inferno(np.linspace(0, 1, len(levels) - 1)))
colors[0] = "black"
cmap = matplotlib.colors.ListedColormap(colors, "", len(colors))

# One panel per assimilation scheme: score grid, coastlines, extent, caption.
panels = (
    (axs[0], dummy_plot_dataset_seq, "(a) sequential"),
    (axs[1], dummy_plot_dataset_aao, "(b) all-at-once"),
)
meshes = []
for panel_ax, score_grid, caption in panels:
    mesh = panel_ax.pcolormesh(
        dummy_plot_ds_lons, dummy_plot_ds_lats, score_grid,
        transform=ccrs.PlateCarree(), vmin=0, vmax=1)
    meshes.append(mesh)
    panel_ax.coastlines()
    panel_ax.set_extent([-180, 180, -60, 70], ccrs.PlateCarree())
    # Negative y puts the caption below the map.
    panel_ax.set_title(caption, fontfamily='serif', loc='center',
                       fontsize='small', fontstyle='normal', y=-0.2)
im0, im1 = meshes

# Alternative kept for reference: a colorbar attached to the right panel with
# make_axes_locatable(axs[1]).append_axes("right", size="5%", pad=0.05,
# axes_class=plt.Axes) followed by plt.colorbar(im1, cax=cax).

# Shared colorbar in its own axes, just right of the second panel and
# spanning the height of the panels.
fig.subplots_adjust(wspace=0.02, hspace=0, right=1.0)
left_box = axs[0].get_position()
right_box = axs[1].get_position()
cbar_ax = fig.add_axes([right_box.x1 + 0.006, right_box.y0,
                        0.02, left_box.y1 - right_box.y0])
fig.colorbar(im1, cax=cbar_ax, label="RE skill score")
plt.savefig("RE_comp_20th_century", bbox_inches="tight", pad_inches=0.1, dpi=400)

In [None]:
# Plot separately: RE score map of the sequential scheme on its own figure.
fig, axs = plt.subplots(
    nrows=1, ncols=1,
    subplot_kw=dict(projection=ccrs.Miller()),
    figsize=(12, 10),
)

# Draw the score grid, given in lon/lat coordinates, with colors spanning [0, 1].
im0 = axs.pcolormesh(
    dummy_plot_ds_lons, dummy_plot_ds_lats, dummy_plot_dataset_seq,
    transform=ccrs.PlateCarree(), vmin=0, vmax=1)

# Restrict the view, then add coastlines.
axs.set_extent([-180, 180, -60, 70], ccrs.PlateCarree())
axs.coastlines()

# Colorbar in a plain matplotlib axes appended to the right of the map.
divider = make_axes_locatable(axs)
cax = divider.append_axes("right", size="5%", pad=0.05, axes_class=plt.Axes)
fig.colorbar(im0, cax=cax)
plt.savefig("RE_20th_century_seq", bbox_inches="tight", pad_inches=0.1, dpi=400)

In [None]:



from mpl_toolkits.axes_grid1 import make_axes_locatable

# `dummy_plot_dataset` only exists as a local variable of RE_score_to_map, so
# this cell raised a NameError when run on its own.
# NOTE(review): the sequential map is assumed to be the intended one, as in
# the single-panel figure saved as "RE_20th_century_seq"; switch to
# dummy_plot_dataset_aao if the all-at-once map was meant.
dummy_plot_dataset = dummy_plot_dataset_seq

# Raw image of the latitude-filtered score grid, with a colorbar on the right.
plt.figure(figsize=(20, 20))
ax = plt.gca()
im = ax.imshow(dummy_plot_dataset, vmin=0, vmax=1.0)

divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)
plt.colorbar(im, cax=cax)
plt.savefig("RE", bbox_inches="tight", pad_inches=0.1, dpi=400)

# `dummy_plot_dataset_full` was never defined either: undo the latitude
# filtering (same -70/70 band as used for the scores) by padding the
# discarded latitudes with NaN.
lat_filter_inds = np.asarray((references.latitude < 70) & (references.latitude > -70))
dummy_plot_dataset_full = np.full(references.temperature.shape[1:], np.nan)
dummy_plot_dataset_full[lat_filter_inds, :] = dummy_plot_dataset

plt.figure(figsize=(10, 10))
plt.imshow(dummy_plot_dataset_full)