# Validate ARM

In [None]:
import os
from pathlib import Path
# Anchor a dummy file name to the current working directory and keep only the
# directory part; this yields the directory the notebook is executed from.
SCRIPT_DIR = os.path.split(os.path.abspath("__init__.py"))[0]
# The parent of that directory, as an absolute Path.
SRC_DIR = Path(SCRIPT_DIR).parent.absolute()
print(SRC_DIR)

In [None]:
# Enable IPython's autoreload extension. Mode 2 reloads imported modules before
# each cell execution, so edits to the local helper modules take effect without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
import numpy as np

from plotting import display_predictions

In [None]:
# Path of the .npy file with the full ARM predictions; it is read with np.load
# in the "Full" section and resolved relative to the working directory.
FULL_PRED_FILE = "predictions/full_pred_arm.npy"

In [None]:
# ISO date strings ("YYYY-MM-DD") for the 365 consecutive days starting 1807-01-01.
DATES = [
    timestamp.date().isoformat()
    for timestamp in pd.date_range('1807-01-01', freq='D', periods=365)
]


def date_to_id(date):
    """Return the zero-based position of the ISO date string *date* in DATES.

    Raises ValueError if *date* is not one of the entries in DATES.
    """
    return DATES.index(date)

## Full

In [None]:
# Load the saved full-prediction array from FULL_PRED_FILE and display its
# shape as the cell output.
pred = np.load(FULL_PRED_FILE)
pred.shape

In [None]:
# Convert index 1 of the last axis to hectopascal (hPa) by dividing by 100
# (assumes that channel holds pressure in pascal - TODO confirm).
# NOTE: not idempotent - re-running this cell divides the values again, so
# reload `pred` before executing it a second time.
pred[..., 1] = np.divide(pred[..., 1], 100)

In [None]:
# Pick the day to inspect, translate it to its index in DATES, and plot the
# predictions for that day. `date` and `day_id` are reused by the next cell.
date = "1807-01-01"
day_id = date_to_id(date)
fig = display_predictions(pred, day_id, date)

In [None]:
# Same plot as above for the selected day, this time with show_contours enabled.
fig = display_predictions(pred, day_id,date, show_contours=True)

## LOO

In [None]:
import xarray as xr

from data_transformer import extract_stations_from_nc
from data_provider import get_station_indices_map

from taylor_helpers import get_nan_ids, extract_anomalies, get_loo_taylor_metrics
from plotting import create_normed_taylor_diagram

In [None]:
# Input files for the leave-one-out (LOO) evaluation; both are opened with
# xr.load_dataset below and resolved relative to the working directory.
GROUND_TRUTH = "data_sets/ground_truth.nc"
LOO_PRED_FILE = "predictions/loo_pred_arm.nc"

In [None]:
# Load the ground truth and the LOO predictions and extract the per-station
# series from each dataset using the same station index map.
station_indx_map = get_station_indices_map()
ground_truth = xr.load_dataset(GROUND_TRUTH)
gt_stations = extract_stations_from_nc(ground_truth, station_indx_map)  # Is scaled.

# NOTE: this rebinds `pred`, which held the full prediction array in the
# "Full" section; re-run the np.load cell before using those cells again.
pred = xr.load_dataset(LOO_PRED_FILE)
pred_stations = extract_stations_from_nc(pred, station_indx_map)

In [None]:
# Locate the NaN entries of the ground-truth station series, compute anomalies
# for predictions and ground truth, and derive the Taylor metrics from them.
# The missing indices are passed only where NaNs can occur (ground truth) and
# to the metric computation.
missing_indicies = get_nan_ids(gt_stations)
anomaly_pred_stations = extract_anomalies(pred_stations, station_indx_map)                  # Pred has no NaNs
anomaly_gt_stations = extract_anomalies(gt_stations, station_indx_map, missing_indicies)    # GT has NaNs
taylor_metrics = get_loo_taylor_metrics(anomaly_gt_stations, anomaly_pred_stations, missing_indicies)

In [None]:
# Gather labels, normalized standard deviations and correlations in the
# iteration order of `taylor_metrics` so the three lists stay aligned.
entry_labels = list(taylor_metrics)
norm_std_devs = [taylor_metrics[label]['norm_std'] for label in entry_labels]
correlations = [taylor_metrics[label]['corr'] for label in entry_labels]

# Normalized Taylor diagram with the reference standard deviation fixed at 1.
fig = create_normed_taylor_diagram(
    ref_std=1,
    test_std_devs=norm_std_devs,
    test_corrs=correlations,
    labels=entry_labels,
)

fig.savefig("figures/taylor_loo_arm.png", bbox_inches='tight', pad_inches=0.1)

In [None]:
# One column per entry of `taylor_metrics`, one row per metric name.
taylor_df = pd.DataFrame(taylor_metrics)

# Worst (lowest) correlation across all columns.
corr_row = taylor_df.loc['corr']
print(f"Min Corr: {corr_row.min()}")

# Largest distance of the normalized standard deviation from the ideal value 1,
# checked on both the upper and the lower side.
norm_std_row = taylor_df.loc['norm_std']
delta_above = abs(norm_std_row.max() - 1)
delta_below = abs(1 - norm_std_row.min())
print(f"Max normed-StdDev Delta: {max(delta_above, delta_below)}")

# Worst (highest) normalized RMSE across all columns.
norm_rmse_row = taylor_df.loc['norm_rmse']
print(f"Max normed-RMSE: {norm_rmse_row.max()}")

### RMSE on LOO (non-normalized)

In [None]:
# Per-station RMSE between the LOO predictions and the ground truth, computed
# directly on the extracted station series (not on the anomalies used for the
# Taylor metrics). Results are split into two groups by station id: ids
# containing "_ta" go into the first list, all others into the second
# (presumably air temperature vs. sea-level pressure - verify against the
# station index map).
rmse_stations_ta = []
rmse_stations_slp = []

for station_id in gt_stations.keys():
    # Loop-local names on purpose: reusing `pred` / `gt` here would overwrite
    # the notebook-level `pred` dataset loaded earlier.
    station_pred = pred_stations[station_id]
    station_gt = gt_stations[station_id]
    station_missing = missing_indicies[station_id]

    # Explicit length check instead of truthiness, so that both plain lists and
    # NumPy index arrays are accepted (bool() of a multi-element array raises).
    if station_missing is not None and len(station_missing) > 0:
        # Drop the time steps without ground truth from both series.
        station_pred = np.delete(station_pred, station_missing)
        station_gt = np.delete(station_gt, station_missing)

    rmse_station = np.sqrt(np.mean(np.square(station_gt - station_pred)))
    if "_ta" in station_id:
        rmse_stations_ta.append(rmse_station)
    else:
        rmse_stations_slp.append(rmse_station)

# Mean RMSE over the stations of each group ("_ta" group first).
print(np.mean(rmse_stations_ta))
print(np.mean(rmse_stations_slp))
    