# Validate Plain

In [None]:
import os
from pathlib import Path
# abspath() resolves the relative file name against the current working
# directory, so SCRIPT_DIR is the directory the notebook runs in and SRC_DIR
# is its parent.
SCRIPT_DIR = str(Path(os.path.abspath("__init__.py")).parent)
SRC_DIR = Path(SCRIPT_DIR).absolute().parent
print(SRC_DIR)

In [None]:
# IPython magics: reload imported modules automatically before running code,
# so edits to local helper modules (plotting, taylor_helpers, ...) are picked up.
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
import numpy as np

from plotting import display_predictions

In [None]:
# NumPy .npy file holding the full-field predictions of the "plain" variant.
FULL_PRED_FILE = "predictions/full_pred_plain.npy"

In [None]:
# ISO date strings ("YYYY-MM-DD") for every day of 1807; position == day index.
DATES = [stamp.date().isoformat()
         for stamp in pd.date_range('1807-01-01', freq='D', periods=365)]


def date_to_id(date):
    """Return the 0-based position of the ISO date string ``date`` in DATES.

    Raises ValueError (from ``list.index``) when the date is not in DATES.
    """
    return DATES.index(date)

## Full

In [None]:
# Load the full prediction array; the bare expression displays its shape.
pred = np.load(FULL_PRED_FILE)
pred.shape

In [None]:
# Convert to hecto pascal
pred[...,1] = pred[...,1] /100
pred.shape

In [None]:
# Plot the full prediction for a single day (here the first day of 1807).
date = "1807-01-01"
day_id = date_to_id(date)
fig = display_predictions(pred, day_id,date)

In [None]:
# Same day again, with contours enabled.
fig = display_predictions(pred, day_id,date, show_contours=True)

## LOO

In [None]:
import xarray as xr

from data_transformer import extract_stations_from_nc
from data_provider import get_station_indices_map

from taylor_helpers import get_nan_ids, extract_anomalies, get_loo_taylor_metrics
from plotting import create_normed_taylor_diagram

In [None]:
# NetCDF inputs: station ground truth and leave-one-out (LOO) predictions.
GROUND_TRUTH = "data_sets/ground_truth.nc"
LOO_PRED_FILE = "predictions/loo_pred_plain.nc"

In [None]:
# Station lookup (presumably station id -> grid indices; see data_provider).
station_indx_map = get_station_indices_map()
ground_truth = xr.load_dataset(GROUND_TRUTH)
gt_stations = extract_stations_from_nc(ground_truth, station_indx_map)  # Is scaled.

# Extract the same stations from the leave-one-out predictions.
loo_pred = xr.load_dataset(LOO_PRED_FILE)
pred_stations = extract_stations_from_nc(loo_pred, station_indx_map)

In [None]:
# Per-station indices of missing (NaN) ground-truth values.
missing_indicies = get_nan_ids(gt_stations)

# Anomaly series for prediction and ground truth. The GT call also receives the
# missing indices; the metric function gets them too. Result: per-station
# Taylor metrics (keys used below: 'corr', 'norm_std', 'norm_rmse').
anomaly_pred_stations = extract_anomalies(pred_stations, station_indx_map)                  # Pred has no NaNs
anomaly_gt_stations = extract_anomalies(gt_stations, station_indx_map, missing_indicies)    # GT has NaNs
taylor_metrics = get_loo_taylor_metrics(anomaly_gt_stations, anomaly_pred_stations, missing_indicies)

In [None]:
# Collect the per-station values in dict order so labels and points line up.
norm_stds = [metrics['norm_std'] for metrics in taylor_metrics.values()]
corrs = [metrics['corr'] for metrics in taylor_metrics.values()]
station_labels = list(taylor_metrics)

# Normalised Taylor diagram: the reference standard deviation is 1.
fig = create_normed_taylor_diagram(
    ref_std=1,
    test_std_devs=norm_stds,
    test_corrs=corrs,
    labels=station_labels,
)

fig.savefig("figures/taylor_loo_plain.png", bbox_inches='tight', pad_inches=0.1)

In [None]:
# Rows are metric names, columns are stations.
taylor_df = pd.DataFrame(taylor_metrics)

corr_row = taylor_df.loc['corr']
norm_std_row = taylor_df.loc['norm_std']
norm_rmse_row = taylor_df.loc['norm_rmse']
# Largest deviation of the normalised std from the ideal value 1, on either side.
std_delta = max(abs(norm_std_row.max() - 1), abs(1 - norm_std_row.min()))

print(f"Min Corr: {corr_row.min()}")
print(f"Max normed-StdDev Delta: {std_delta}")
print(f"Max normed-RMSE: {norm_rmse_row.max()}")

In [None]:
# Display the full per-station Taylor metrics table.
taylor_df

### RMSE on LOO (non-normalized)

In [None]:
# Mean raw (non-normalised) LOO RMSE per variable, in the units stored in
# gt_stations. The per-station series use their own names (station_pred /
# station_gt) so the global `pred` array loaded earlier is not overwritten.
rmse_stations_ta = []
rmse_stations_slp = []

for station_id in gt_stations.keys():
    station_pred = pred_stations[station_id]
    station_gt = gt_stations[station_id]
    missing = missing_indicies[station_id]
    if missing:
        # Drop the time steps flagged as missing in the ground truth.
        station_pred = np.delete(station_pred, missing)
        station_gt = np.delete(station_gt, missing)

    rmse_station = np.sqrt(np.mean(np.square(station_gt - station_pred)))
    # Every station without "_ta" in its id is counted as sea-level pressure.
    if "_ta" in station_id:
        rmse_stations_ta.append(rmse_station)
    else:
        rmse_stations_slp.append(rmse_station)

print(f"Mean LOO RMSE (ta): {np.mean(rmse_stations_ta)}")
print(f"Mean LOO RMSE (slp): {np.mean(rmse_stations_slp)}")
    

## Time Series: Best / Worst LOO Stations (Temperature Stations Only)

In [None]:
import pandas as pd

In [None]:
# Keep only temperature stations (ids containing "_ta").
metrics_df = pd.DataFrame(taylor_metrics)
metrics_df = metrics_df[[col for col in metrics_df.columns if "_ta" in col]]

# Rank the stations by normalised RMSE: highest is worst, lowest is best.
norm_rmse_ta = metrics_df.loc['norm_rmse']
worst_station = norm_rmse_ta.idxmax()
best_station = norm_rmse_ta.idxmin()
print(f"Worst station: {worst_station}")
print(f"Best station: {best_station}")

best_rmse = norm_rmse_ta[best_station].round(3)
worst_rmse = norm_rmse_ta[worst_station].round(3)
print(best_rmse, worst_rmse)

print(norm_rmse_ta[[worst_station, best_station]])

In [None]:
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
# Station-id prefix (text before the first "_") -> display name used in titles.
CITIES = {'CBT': 'Central Belgium', 'STK': 'Stockholm', "STP": 'St. Petersburg'}


def _plot_station_panel(ax, pred, gt, stat_id, baseline, dates):
    """Draw reconstruction, observation and baseline of one station on ``ax``."""
    if "_slp" in stat_id:
        # Convert to hPa. The baseline is plotted unconverted, as before.
        # NOTE(review): confirm the baseline's units if an SLP station is plotted.
        pred = pred / 100
        gt = gt / 100

    is_temperature = "_ta" in stat_id
    prefix = stat_id.split('_')[0]
    # Fall back to the raw prefix instead of raising KeyError for unknown stations.
    station_name = CITIES.get(prefix, prefix)
    ax.set_title(station_name + " " + ("Temperature" if is_temperature else "Pressure"))

    ax.plot(dates, pred, color='red', label="Reconstruction")
    ax.plot(dates, gt, color='blue', label="Station observation")
    ax.plot(dates, baseline, color='green', label="Baseline (20CR mean)")

    ax.set_ylabel("ta [°C]" if is_temperature else "slp [hPa]")
    ax.grid(True)
    ax.legend()

    # Major ticks every half year (January and July), minor ticks every month.
    ax.xaxis.set_major_locator(mdates.MonthLocator(bymonth=(1, 7)))
    ax.xaxis.set_minor_locator(mdates.MonthLocator())


def create_station_line_plot(best_pred, best_gt, best_stat_id, best_baseline,
                             worst_pred, worst_gt, worst_stat_id, worst_baseline,
                             save_file_path=None, start_date='1807-01-01'):
    """Plot daily time series for the best (top) and worst (bottom) station.

    Each panel shows the reconstruction (red), the station observation (blue)
    and the 20CR-mean baseline (green). Prediction and observation of stations
    whose id contains ``"_slp"`` are divided by 100 (to hPa) before plotting.

    Args:
        best_pred, best_gt, best_baseline: Daily series of the best station.
        best_stat_id: Id of the best station; the part before the first ``_``
            is looked up in ``CITIES``, and ``"_ta"`` / ``"_slp"`` select the
            variable label.
        worst_pred, worst_gt, worst_baseline: Daily series of the worst station.
        worst_stat_id: Id of the worst station (same convention).
        save_file_path: If truthy, the figure is also saved there at 300 dpi.
        start_date: Date of the first sample; one day per element of ``*_pred``.

    Returns:
        The matplotlib figure containing both panels.
    """
    fig, axs = plt.subplots(2, 1, figsize=(16, 11))

    panels = (
        (best_pred, best_gt, best_stat_id, best_baseline),
        (worst_pred, worst_gt, worst_stat_id, worst_baseline),
    )
    for ax, (pred, gt, stat_id, baseline) in zip(axs, panels):
        dates = pd.date_range(start_date, freq='D', periods=len(pred))
        _plot_station_panel(ax, pred, gt, stat_id, baseline, dates)

    if save_file_path:
        # Save this figure explicitly rather than pyplot's "current" figure.
        fig.savefig(save_file_path, bbox_inches='tight', pad_inches=0.1, dpi=300)

    return fig

In [None]:
# 20CR-mean baselines for the best/worst stations selected above, instead of
# hard-coding STP/CBT (which silently mismatches if the ranking changes).
# NOTE(review): assumes a station id maps to "mean20CR/<id lowercased>.npy",
# e.g. the previously hard-coded "stp_ta.npy" / "cbt_ta.npy" — confirm.
best_baseline = np.load(f"mean20CR/{best_station.lower()}.npy")
worst_baseline = np.load(f"mean20CR/{worst_station.lower()}.npy")
print(worst_baseline.shape)

In [None]:
# Best station in the top panel, worst station in the bottom panel.
fig = create_station_line_plot(
    best_pred=pred_stations[best_station],
    best_gt=gt_stations[best_station],
    best_stat_id=best_station,
    best_baseline=best_baseline,
    worst_pred=pred_stations[worst_station],
    worst_gt=gt_stations[worst_station],
    worst_stat_id=worst_station,
    worst_baseline=worst_baseline,
    save_file_path="figures/best_worst_time_series.png",
)

In [None]:
def ts_rmse(ts1, ts2):
    """Return the root-mean-square difference between two aligned series.

    Time steps where either series is NaN are ignored. The ground-truth station
    series contain NaNs, and a plain mean would make the whole result NaN.
    Returns NaN (with a RuntimeWarning) if no time step is valid.
    """
    return np.sqrt(np.nanmean(np.square(ts1 - ts2)))

In [None]:
# RMSE of the 20CR baseline and of the LOO reconstruction against the station
# observations. The label is the station-id prefix of the stations selected
# above, rather than a hard-coded "STP"/"CBT".
for station_id, station_baseline in ((best_station, best_baseline),
                                     (worst_station, worst_baseline)):
    code = station_id.split('_')[0]
    station_gt = gt_stations[station_id]
    print(f"RMSE BL-GT ({code}): {round(ts_rmse(station_baseline, station_gt), 3)}")
    print(f"RMSE PRED-GT ({code}): {round(ts_rmse(pred_stations[station_id], station_gt), 3)}")

## Day with the Highest Station-Average Temperature (Ground Truth)

In [None]:
# Reload the full prediction grid from disk, so this section does not depend on
# whatever state earlier cells left in `pred`. Also convert channel 1 by /100
# (to hPa), as the "Full" section does, so the figure below uses the same units.
pred = np.load(FULL_PRED_FILE)
pred[..., 1] = pred[..., 1] / 100
pred.shape

In [None]:
# Ground-truth series of the temperature stations only (column ids containing "_ta").
temp_df = pd.DataFrame(gt_stations).filter(like="_ta")
temp_df

In [None]:
# 0-based position of the day whose mean across temperature stations is highest.
temp_df.mean(axis=1).argmax()

In [None]:
# Day with the highest station-mean ground-truth temperature. This is computed
# here rather than hard-coding the index printed by the previous cell.
avg_max_id = int(temp_df.mean(axis=1).argmax())
avg_max_date = DATES[avg_max_id]
print(avg_max_date)

# The map is drawn two days before the hottest day (offset kept from the original
# analysis; NOTE(review): rationale not documented — confirm). Clamp at 0 so an
# early-January maximum cannot wrap to December via negative indexing.
mid_id = max(avg_max_id - 2, 0)
mid_date = DATES[mid_id]

fig = display_predictions(pred, mid_id, mid_date, True)
plt.savefig("figures/hottest_gt_day.png", bbox_inches='tight', pad_inches=0.1)