In [None]:
import os
from pathlib import Path

# ``__file__`` does not exist inside a notebook, so a dummy file name is resolved
# against the current working directory instead; its directory is therefore the CWD.
_cwd_probe = os.path.abspath("__init__.py")
SCRIPT_DIR = os.path.dirname(_cwd_probe)
# The directory one level above the working directory.
SRC_DIR = Path(SCRIPT_DIR).parent.absolute()
print(SRC_DIR)

In [None]:
# Re-import edited project modules (data_transformer, taylor_helpers, plotting, ...)
# automatically before each cell runs, so helper edits are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import xarray as xr

from data_transformer import extract_stations_from_nc
from data_provider import get_station_indices_map

from taylor_helpers import get_nan_ids, extract_anomalies, get_loo_taylor_metrics
from plotting import create_normed_taylor_diagram

In [None]:
# Input NetCDF files, relative to the notebook's working directory.
GROUND_TRUTH: str = "data_sets/ground_truth.nc"             # ground-truth dataset
LOO_PRED_FILE_PLAIN: str = "predictions/loo_pred_plain.nc"  # predictions shown as "LOO"
LOO_PRED_FILE_ARM: str = "predictions/loo_pred_arm.nc"      # predictions shown as "LOO-ARM"

In [None]:
def _loo_taylor_metrics(pred_path, gt_anomalies, station_map, missing_ids):
    """Load one LOO prediction file and compute its Taylor metrics against the ground truth.

    Parameters
    ----------
    pred_path : str
        NetCDF file holding the leave-one-out predictions.
    gt_anomalies :
        Ground-truth anomalies, as returned by ``extract_anomalies``.
    station_map :
        Station index map from ``get_station_indices_map``.
    missing_ids :
        NaN positions of the ground truth, from ``get_nan_ids``.

    Returns
    -------
    tuple
        ``(dataset, stations, anomalies, taylor_metrics)``. The intermediates are
        returned so they stay inspectable in the notebook namespace.
    """
    pred_ds = xr.load_dataset(pred_path)
    pred_stations = extract_stations_from_nc(pred_ds, station_map)
    # NOTE(review): unlike the ground truth, no missing ids are passed here;
    # presumably the predictions contain no NaNs -- confirm.
    pred_anomalies = extract_anomalies(pred_stations, station_map)
    metrics = get_loo_taylor_metrics(gt_anomalies, pred_anomalies, missing_ids)
    return pred_ds, pred_stations, pred_anomalies, metrics


# Ground truth
station_indx_map = get_station_indices_map()
ground_truth = xr.load_dataset(GROUND_TRUTH)
gt_stations = extract_stations_from_nc(ground_truth, station_indx_map)  # Is scaled.
missing_indicies = get_nan_ids(gt_stations)
anomaly_gt_stations = extract_anomalies(gt_stations, station_indx_map, missing_indicies)    # GT has NaNs

# PLAIN
(loo_pred_plain, pred_stations_plain,
 anomaly_pred_stations_plain, taylor_metrics_plain) = _loo_taylor_metrics(
    LOO_PRED_FILE_PLAIN, anomaly_gt_stations, station_indx_map, missing_indicies)

# ARM
(loo_pred_arm, pred_stations_arm,
 anomaly_pred_stations_arm, taylor_metrics_arm) = _loo_taylor_metrics(
    LOO_PRED_FILE_ARM, anomaly_gt_stations, station_indx_map, missing_indicies)

In [None]:
# Verify same order
# Verify same order.
# ``dict.keys()`` views compare like sets, so ``a.keys() == b.keys()`` ignores
# ordering entirely. The plotting code pairs plain/ARM entries, so compare the
# keys as ordered lists to really check both runs list stations identically.
if list(taylor_metrics_arm) != list(taylor_metrics_plain):
    raise ValueError('Order is not the same!')


In [None]:
# markers

# Double-Taylor Plot

In [None]:
import matplotlib.pyplot as plt
from matplotlib.projections import PolarAxes
from mpl_toolkits.axisartist import grid_finder
from mpl_toolkits.axisartist import floating_axes
import numpy as np

In [None]:
# ---- Shared configuration for the four Taylor-diagram panels -------------------
ref_std = 1                      # reference SD; the plotted metric is 'norm_std'
std_axis_min = 0
std_axis_max = 2 * ref_std
markers = [".", "x", "s", "+"]   # cycled over the stations of a panel
marker_sizes = [8, 5, 4, 6]      # paired with ``markers`` by position

# Angular axis: a correlation r is drawn at angle arccos(r).
polar_transform = PolarAxes.PolarTransform()
corr_labels = np.array([0, 0.2, 0.4, 0.6, 0.7, 0.8, 0.9, 0.95, 0.99, 1])
polar_axis_min = 0
polar_axis_max = np.pi / 2
polar_label_locations = np.arccos(corr_labels)
corr_pos_label_mapper = grid_finder.DictFormatter(
    {location: str(label) for location, label in zip(polar_label_locations, corr_labels)}
)
locator = grid_finder.FixedLocator(polar_label_locations)
grid_helper = floating_axes.GridHelperCurveLinear(
    aux_trans=polar_transform,
    extremes=(polar_axis_min, polar_axis_max, std_axis_min, std_axis_max),
    grid_locator1=locator,
    tick_formatter1=corr_pos_label_mapper,
)

fig = plt.figure(figsize=(10, 10))
ax1 = floating_axes.FloatingSubplot(fig, 221, grid_helper=grid_helper)
ax2 = floating_axes.FloatingSubplot(fig, 222, grid_helper=grid_helper)
ax3 = floating_axes.FloatingSubplot(fig, 223, grid_helper=grid_helper)
ax4 = floating_axes.FloatingSubplot(fig, 224, grid_helper=grid_helper)

axes = [ax1, ax2, ax3, ax4]      # t-plain, p-plain, t-arm, p-arm
titles = ["LOO Temperature", "LOO Pressure", "LOO-ARM Temperature", "LOO-ARM Pressure"]
var_det = ["_ta", "_slp"]        # metric-key suffix per panel column
taylor_metrics = [taylor_metrics_plain, taylor_metrics_arm]  # per panel row


def _style_taylor_axes(ax):
    """Lay out the axisartist axes of a quarter-circle Taylor diagram."""
    ax.grid()
    # Angle axis (drawn on the "top" side): correlation.
    ax.axis["top"].set_axis_direction("bottom")
    ax.axis["top"].toggle(ticklabels=True, label=True)
    ax.axis["top"].major_ticklabels.set_axis_direction("top")
    ax.axis["top"].label.set_axis_direction("top")
    ax.axis["top"].label.set_text("Correlation")
    # Radial axis drawn along x.
    ax.axis["left"].set_axis_direction("bottom")
    ax.axis["left"].label.set_text("Standard deviation")
    # Radial axis drawn along y.
    ax.axis["right"].set_axis_direction("top")
    ax.axis["right"].toggle(ticklabels=True)
    ax.axis["right"].major_ticklabels.set_axis_direction("left")
    ax.axis["bottom"].set_visible(False)


def _draw_taylor_reference(polar_ax):
    """Draw the reference point, the reference-SD arc and the centred-RMSE contours."""
    polar_ax.plot(0, ref_std, marker=6, color="r", markersize=10, label="Ref")

    arc_theta = np.linspace(0, polar_axis_max)
    polar_ax.plot(arc_theta, np.full_like(arc_theta, ref_std), "k:")

    # Distance to the reference point by the law of cosines:
    # RMSE'^2 = sigma_ref^2 + sigma^2 - 2 * sigma_ref * sigma * cos(theta)
    std_grid, theta_grid = np.meshgrid(np.linspace(std_axis_min, std_axis_max),
                                       np.linspace(polar_axis_min, polar_axis_max))
    rmse_grid = np.sqrt(ref_std**2 + std_grid**2 - 2 * ref_std * std_grid * np.cos(theta_grid))
    contour_set = polar_ax.contour(theta_grid, std_grid, rmse_grid,
                                   levels=4, colors="black", linestyles="--")
    contour_set.clabel(inline=1, fontsize=10, colors="black")


def _plot_taylor_samples(polar_ax, std_devs, corrs, labels):
    """Plot one marker per station at (arccos(corr), std) with a distinct colour."""
    colors = plt.cm.jet(np.linspace(0, 1, len(std_devs)))
    for i, (std_dev, corr, label, color) in enumerate(zip(std_devs, corrs, labels, colors)):
        polar_ax.plot(np.arccos(corr), std_dev,
                      markers[i % len(markers)],
                      label=label,
                      color=color,
                      markersize=marker_sizes[i % len(marker_sizes)])


for plt_id, (ax, title) in enumerate(zip(axes, titles)):
    var_suffix = var_det[plt_id % 2]         # ta, slp, ta, slp
    metrics = taylor_metrics[plt_id // 2]    # plain, plain, arm, arm

    # Select stations by key, not by position, so plain and ARM values are
    # matched per station regardless of dict ordering. The same key list is used
    # for both rows, so a station keeps its colour/marker across panels.
    panel_keys = [key for key in taylor_metrics_plain if key.endswith(var_suffix)]
    test_std_devs = [metrics[key]["norm_std"] for key in panel_keys]
    test_corrs = [metrics[key]["corr"] for key in panel_keys]
    labels = [key.split("_")[0] for key in panel_keys]

    fig.add_subplot(ax)
    _style_taylor_axes(ax)
    polar_ax = ax.get_aux_axes(polar_transform)
    _draw_taylor_reference(polar_ax)
    _plot_taylor_samples(polar_ax, test_std_devs, test_corrs, labels)
    ax.set_title(title)

# One legend per variable; the bottom-row panels reuse the same stations/colours.
# NOTE(review): the anchor mixes figure-space positions with the legend's default
# axes-space anchor; the offsets are hand-tuned for this layout.
ax1.legend(prop=dict(size="small"), loc="upper right", ncols=1, title="ta",
           bbox_to_anchor=(ax1.get_position().x0 - 0.25, ax1.get_position().y0))
ax2.legend(prop=dict(size="small"), loc="upper left", ncols=1, title="slp",
           bbox_to_anchor=(ax2.get_position().x0 + 0.5, ax2.get_position().y0))

os.makedirs("figures", exist_ok=True)   # savefig does not create missing directories
fig.savefig("figures/plain_arm_comparison.png",
            bbox_inches="tight",
            pad_inches=0.1,
            dpi=400,
            )

## Create table for appendix

In [None]:
# LaTeX skeleton of the appendix table: one row per station, with SD / RMSE / Corr
# for temperature and pressure under both LOO and LOO-ARM. ``<ROWS>`` is replaced
# by the generated station rows.
template = r"""
\begin{table}[ht!]
\centering
\resizebox{0.90\textwidth}{!}{
\begin{tabular}{|c|cccccc|cccccc|}
\hline
        & \multicolumn{6}{c|}{LOO}                                                                                                                           & \multicolumn{6}{c|}{LOO-ARM}                                                                                                                       \\ \hline
        & \multicolumn{3}{c|}{Temperature}                                                   & \multicolumn{3}{c|}{Pressure}                                 & \multicolumn{3}{c|}{Temperature}                                                   & \multicolumn{3}{c|}{Pressure}                                 \\ \hline
        & \multicolumn{1}{c|}{SD}  & \multicolumn{1}{c|}{RMSE}  & \multicolumn{1}{c|}{Corr}  & \multicolumn{1}{c|}{SD}  & \multicolumn{1}{c|}{RMSE}  & Corr  & \multicolumn{1}{c|}{SD}  & \multicolumn{1}{c|}{RMSE}  & \multicolumn{1}{c|}{Corr}  & \multicolumn{1}{c|}{SD}  & \multicolumn{1}{c|}{RMSE}  & Corr  \\ \hline
<ROWS>
\end{tabular}
}
\end{table}
"""

In [None]:
# Template for one table row (ends with a newline). STAT-ID and the SDn/RMSEn/CORRn
# placeholders are substituted per station when the rows are built.
repl_row = r"""STAT-ID & \multicolumn{1}{c|}{SD1} & \multicolumn{1}{c|}{RMSE1} & \multicolumn{1}{c|}{CORR1} & \multicolumn{1}{c|}{SD2} & \multicolumn{1}{c|}{RMSE2} & CORR2 & \multicolumn{1}{c|}{SD3} & \multicolumn{1}{c|}{RMSE3} & \multicolumn{1}{c|}{CORR3} & \multicolumn{1}{c|}{SD4} & \multicolumn{1}{c|}{RMSE4} & CORR4 \\ \hline
"""


In [None]:
# Unique station names (the part of each metric key before the first "_"), sorted.
all_stations = sorted({key.split('_')[0] for key in taylor_metrics_plain})
len(all_stations)

In [None]:
# Placeholders in ``repl_row`` and what fills them:
#   STAT-ID              station name
#   SD1 / RMSE1 / CORR1  LOO,     temperature  (metrics key ``<station>_ta``)
#   SD2 / RMSE2 / CORR2  LOO,     pressure     (metrics key ``<station>_slp``)
#   SD3 / RMSE3 / CORR3  LOO-ARM, temperature
#   SD4 / RMSE4 / CORR4  LOO-ARM, pressure
# If a station lacks a variable, its cells are left empty.


def _fill_metric_cells(row, column, station_metrics):
    """Fill the SD/RMSE/CORR placeholders of one column group in ``row``.

    Parameters
    ----------
    row : str
        Row text still containing ``SD<column>``, ``RMSE<column>`` and ``CORR<column>``.
    column : int
        Column-group number (1-4) in ``repl_row``.
    station_metrics : dict or None
        Mapping with ``'norm_std'``, ``'norm_rmse'`` and ``'corr'``; ``None`` blanks
        the three cells.

    Returns
    -------
    str
        ``row`` with the three placeholders replaced. Values use a fixed three
        decimals so the table columns line up (``str(round(x, 3))`` printed e.g. ``1.0``).
    """
    for placeholder, metric_name in (("SD", "norm_std"), ("RMSE", "norm_rmse"), ("CORR", "corr")):
        cell = "" if station_metrics is None else f"{station_metrics[metric_name]:.3f}"
        row = row.replace(f"{placeholder}{column}", cell)
    return row


# (metric-key suffix, LOO column group, LOO-ARM column group)
_TABLE_COLUMNS = (("_ta", 1, 3), ("_slp", 2, 4))

row_parts = []
for stat in all_stations:
    row = repl_row
    for var_suffix, plain_column, arm_column in _TABLE_COLUMNS:
        key = stat + var_suffix
        # Presence is checked on the plain run only; the key check earlier in the
        # notebook ensures both runs have the same keys.
        if key in taylor_metrics_plain:
            row = _fill_metric_cells(row, plain_column, taylor_metrics_plain[key])
            row = _fill_metric_cells(row, arm_column, taylor_metrics_arm[key])
        else:
            row = _fill_metric_cells(row, plain_column, None)
            row = _fill_metric_cells(row, arm_column, None)
    # Insert the station name last so it can never be mistaken for a placeholder,
    # and so it is always filled in (even for a station with neither variable).
    row = row.replace("STAT-ID", stat)
    row_parts.append(row)
rows = "".join(row_parts)

In [None]:
# Splice the generated rows into the LaTeX skeleton; print() shows the raw LaTeX
# (a bare expression would show the escaped repr instead).
table = rows.join(template.split("<ROWS>"))
print(table)