## Set Up

In [1]:
import torch
import os
import numpy as np
import pandas as pd
import lightning.pytorch as pl
import warnings
import pyarrow as pa
import pyarrow.parquet as pq
import dill
import pickle

# For plotting
from sklearn.neighbors import KernelDensity
import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.gridspec as grid_spec
# from matplotlib import colormaps
from matplotlib.patches import Patch

warnings.filterwarnings('ignore')
# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:256"

# Paths, lags and lead, time index (weekly resolution)
BASE_PATH = 'D:/KIMoDIs/global-groundwater-models-main'
DATA_PATH = os.path.join(BASE_PATH, 'data')
MODEL_PATH = os.path.join(BASE_PATH, 'models')
RESULT_PATH = os.path.join(BASE_PATH, 'results')
FIGURES_PATH = os.path.join(BASE_PATH, 'figures')
SHARE_PATH = 'J:/Berlin/B22-FISHy/NUTZER/Kunz.S'
CACHE_PATH = os.path.join(BASE_PATH, 'cache')

LAG = 52  # weeks
LEAD = 12  # weeks

# Roughly 80/10/10
TRAIN_PERIOD = (pd.Timestamp(1990, 1, 1), pd.Timestamp(2010, 1, 1))
VAL_PERIOD = (pd.Timestamp(2010, 1, 1), pd.Timestamp(2013, 1, 1))
TEST_PERIOD = (pd.Timestamp(2013, 1, 1), pd.Timestamp(2016, 1, 1))

TIME_IDX = pd.date_range(TRAIN_PERIOD[0], TEST_PERIOD[1], freq='W-SUN', inclusive='neither', name='time').to_frame().reset_index(drop=True)
TIME_IDX.index.name = 'time_idx'
TIME_IDX = TIME_IDX.reset_index()

In [2]:
# Change the model type here
MODEL_TYPE = 'full'
VERSION = '10_Epochs'

In [4]:
train_df = pq.read_table(os.path.join(SHARE_PATH, 'kimodis_preprocessed', 'data', 'train_df.parquet'))
train_df = train_df.to_pandas()
val_df = pq.read_table(os.path.join(SHARE_PATH, 'kimodis_preprocessed', 'data', 'val_df.parquet'))
val_df = val_df.to_pandas()
test_df = pq.read_table(os.path.join(SHARE_PATH, 'kimodis_preprocessed', 'data', 'test_df.parquet'))
test_df = test_df.to_pandas()
print('There are', len(test_df['proj_id'].unique()), 'test sites available overall')

There are 8713 test sites available overall


In [5]:
test_df_in_sample = test_df[test_df['proj_id'].isin(train_df['proj_id'].unique())]
test_df_out_sample = test_df[~test_df['proj_id'].isin(train_df['proj_id'].unique())]
print('There are', len(test_df_in_sample['proj_id'].unique()), 'in sample test sites.')

There are 5308 in sample test sites.


In [6]:
# Subset train_df and val_df to sites also occurring in test_df 
train_df = train_df[train_df['proj_id'].isin(test_df['proj_id'])]
val_df = val_df[val_df['proj_id'].isin(test_df['proj_id'])]

## Predictions for validation set

Testing the weird valdation metrics observed during TFT training by checking several model checkpoints. 
For this test, tft_full has been trained for one epoch and snapshots have been taken every 26ths step (coresponds to 10% of training progress).

In [7]:
# TimeSeriesDataSet for training data
from pytorch_forecasting import TimeSeriesDataSet
train_ds = TimeSeriesDataSet.load(os.path.join(SHARE_PATH, 'kimodis_preprocessed', f'train_ds_{MODEL_TYPE}_tft.pt'))

In [8]:
# Load the model
from pytorch_forecasting.models.temporal_fusion_transformer import TemporalFusionTransformer

# TFT
model_step26 = TemporalFusionTransformer.load_from_checkpoint(os.path.join(
    MODEL_PATH, 
   'tft', 
   f'tft_{MODEL_TYPE}',
   f'tft_{MODEL_TYPE}_4096_{VERSION}',
   'version_0',
   'checkpoints',
   'epoch=0-step=26.ckpt')) 

model_step52 = TemporalFusionTransformer.load_from_checkpoint(os.path.join(
    MODEL_PATH, 
   'tft', 
   f'tft_{MODEL_TYPE}',
   f'tft_{MODEL_TYPE}_4096_{VERSION}',
   'version_0',
   'checkpoints',
   'epoch=0-step=52.ckpt')) 

model_step130 = TemporalFusionTransformer.load_from_checkpoint(os.path.join(
    MODEL_PATH, 
   'tft', 
   f'tft_{MODEL_TYPE}',
   f'tft_{MODEL_TYPE}_4096_{VERSION}',
   'version_0',
   'checkpoints',
   'epoch=0-step=130.ckpt')) 

model_step260 = TemporalFusionTransformer.load_from_checkpoint(os.path.join(
    MODEL_PATH, 
   'tft', 
   f'tft_{MODEL_TYPE}',
   f'tft_{MODEL_TYPE}_4096_{VERSION}',
   'version_0',
   'checkpoints',
   'epoch=0-step=260.ckpt')) 

modeldict = {
    'step26':model_step26,
    'step52':model_step52,
    'step130':model_step130,
    'step260':model_step260
}

In [16]:
# Dataloaders
train_dataloader = train_ds.to_dataloader(train=True, batch_size=1024, num_workers=0)

val_ds = TimeSeriesDataSet.load(os.path.join(SHARE_PATH, 'kimodis_preprocessed', f'val_ds_{MODEL_TYPE}_tft.pt'))
val_dataloader = val_ds.to_dataloader(train=False, batch_size=1024, num_workers=0)

# Helper functions
def predictions_to_df(index: pd.DataFrame, predictions: np.ndarray, group_ids, date_range, lead):
    
    predictions_df = index
    for i, f in enumerate(predictions):
        predictions_df[i] = f

    predictions_df = predictions_df.melt(id_vars=group_ids + ['time_idx'], value_vars=list(range(lead)), var_name='horizon', value_name='forecast')
    predictions_df['time_idx'] = predictions_df['time_idx'] + predictions_df['horizon']
    predictions_df = predictions_df.merge(date_range, on=['time_idx'], how='left')
    predictions_df['horizon'] += 1
    predictions_df.drop('time_idx', axis=1, inplace=True)
    predictions_df.set_index(group_ids+['time', 'horizon'], inplace=True)
    return predictions_df

In [22]:
# Make predictions
from utils import get_metrics

results_dict = {}

for model_name, model in modeldict.items():
    

    val_predictions = model.predict(val_dataloader,
                                    mode='prediction',
                                    return_x=True, 
                                    return_index=True, 
                                    trainer_kwargs=dict(accelerator='cpu', 
                                                                    devices=1))

    val_index = val_predictions.index
    val_pred = val_predictions.output.numpy()

    # np.transpose(val_pred, (1, 0)).shape
    tft_val_pred = predictions_to_df(val_index, np.transpose(val_pred, (1, 0)), ['proj_id'], TIME_IDX, LEAD)
    tft_val_pred = tft_val_pred.reset_index().merge(val_df[['proj_id', 'time', 'gwl']], on=['proj_id', 'time'], how='left')

    # RM NAs and duplicates
    tft_val_pred = tft_val_pred[~tft_val_pred['forecast'].isnull()]
    tft_val_pred = tft_val_pred[~tft_val_pred['gwl'].isnull()]
    tft_val_pred = tft_val_pred[~tft_val_pred.duplicated(subset=['proj_id', 'time', 'horizon'])]
    
    # Calculate metrics
    tft_val_pred = tft_val_pred.set_index(['proj_id', 'time', 'horizon'])
    tft_val_metrics = get_metrics(tft_val_pred, metrics_subset=['NSE', 'rMBE', 'nRMSE', 'RMSE', 'MAE'])
    tft_val_metrics = tft_val_metrics.reset_index()
    
    # Save metrics in the results_dict
    results_dict[model_name] = tft_val_metrics

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [26]:
results_dict['step26'].groupby(['horizon'])['MAE'].describe()
print(results_dict['step26'].groupby(['horizon'])['RMSE'].describe())

          count      mean       std  min      25%    50%    75%        max
horizon                                                                   
1        5308.0  0.411048  0.707453  0.0  0.19200  0.279  0.406  19.351999
2        5308.0  0.413138  0.709254  0.0  0.19300  0.280  0.407  19.288000
3        5308.0  0.412691  0.707202  0.0  0.19275  0.280  0.405  18.771999
4        5308.0  0.409217  0.699794  0.0  0.19000  0.278  0.402  17.908001
5        5308.0  0.406319  0.692137  0.0  0.19000  0.276  0.400  17.528999
6        5308.0  0.404782  0.684714  0.0  0.18900  0.275  0.399  17.271999
7        5308.0  0.403166  0.677768  0.0  0.18800  0.274  0.397  16.872000
8        5308.0  0.400285  0.670623  0.0  0.18600  0.271  0.395  16.575001
9        5308.0  0.398961  0.666864  0.0  0.18400  0.269  0.394  16.483000
10       5308.0  0.397504  0.662393  0.0  0.18300  0.268  0.392  16.417000
11       5308.0  0.396951  0.659870  0.0  0.18300  0.268  0.393  16.392000
12       5308.0  0.395675

In [37]:
mae_summary = {}
rmse_summary = {}
coor_mae_rmse = {}

for model_name, metrics_df in results_dict.items():
    mae_summary[model_name] = metrics_df.groupby(['horizon'])['MAE'].describe()
    rmse_summary[model_name] = metrics_df.groupby(['horizon'])['RMSE'].describe()
    coor_mae_rmse[model_name] = metrics_df.groupby(['horizon'])[['RMSE', 'MAE']].corr(method='pearson')

### Summary of MAE and RMSE values on the validation dataset

In [32]:
mae_summary

{'step26':           count      mean       std  min    25%    50%      75%     max
 horizon                                                                
 1        5308.0  0.313659  0.517123  0.0  0.147  0.212  0.31300  12.546
 2        5308.0  0.315580  0.516645  0.0  0.149  0.213  0.31500  12.215
 3        5308.0  0.316309  0.515241  0.0  0.149  0.213  0.31700  11.708
 4        5308.0  0.315845  0.512319  0.0  0.148  0.213  0.31725  11.161
 5        5308.0  0.315521  0.509692  0.0  0.147  0.213  0.31700  10.984
 6        5308.0  0.315699  0.507632  0.0  0.147  0.213  0.31700  10.823
 7        5308.0  0.315754  0.505952  0.0  0.147  0.213  0.31700  10.674
 8        5308.0  0.315056  0.504071  0.0  0.147  0.212  0.31700  10.519
 9        5308.0  0.315111  0.503917  0.0  0.146  0.211  0.31800  10.377
 10       5308.0  0.315035  0.503580  0.0  0.146  0.211  0.31825  10.194
 11       5308.0  0.315515  0.504505  0.0  0.146  0.211  0.31900  10.053
 12       5308.0  0.315437  0.504995  0.0

In [33]:
rmse_summary

{'step26':           count      mean       std  min      25%    50%    75%        max
 horizon                                                                   
 1        5308.0  0.411048  0.707453  0.0  0.19200  0.279  0.406  19.351999
 2        5308.0  0.413138  0.709254  0.0  0.19300  0.280  0.407  19.288000
 3        5308.0  0.412691  0.707202  0.0  0.19275  0.280  0.405  18.771999
 4        5308.0  0.409217  0.699794  0.0  0.19000  0.278  0.402  17.908001
 5        5308.0  0.406319  0.692137  0.0  0.19000  0.276  0.400  17.528999
 6        5308.0  0.404782  0.684714  0.0  0.18900  0.275  0.399  17.271999
 7        5308.0  0.403166  0.677768  0.0  0.18800  0.274  0.397  16.872000
 8        5308.0  0.400285  0.670623  0.0  0.18600  0.271  0.395  16.575001
 9        5308.0  0.398961  0.666864  0.0  0.18400  0.269  0.394  16.483000
 10       5308.0  0.397504  0.662393  0.0  0.18300  0.268  0.392  16.417000
 11       5308.0  0.396951  0.659870  0.0  0.18300  0.268  0.393  16.392000
 1

In [38]:
coor_mae_rmse

{'step26': metric              RMSE       MAE
 horizon metric                    
 1       RMSE    1.000000  0.990591
         MAE     0.990591  1.000000
 2       RMSE    1.000000  0.989449
         MAE     0.989449  1.000000
 3       RMSE    1.000000  0.988617
         MAE     0.988617  1.000000
 4       RMSE    1.000000  0.988305
         MAE     0.988305  1.000000
 5       RMSE    1.000000  0.988023
         MAE     0.988023  1.000000
 6       RMSE    1.000000  0.987855
         MAE     0.987855  1.000000
 7       RMSE    1.000000  0.987702
         MAE     0.987702  1.000000
 8       RMSE    1.000000  0.987550
         MAE     0.987550  1.000000
 9       RMSE    1.000000  0.987392
         MAE     0.987392  1.000000
 10      RMSE    1.000000  0.987252
         MAE     0.987252  1.000000
 11      RMSE    1.000000  0.987293
         MAE     0.987293  1.000000
 12      RMSE    1.000000  0.987309
         MAE     0.987309  1.000000,
 'step52': metric              RMSE       MAE
 horizo

In [42]:
os.path.join(MODEL_PATH, 
                          'tft',
                          f'tft_{MODEL_TYPE}',
                          f'tft_{MODEL_TYPE}_{VERSION}',
                          'logs',
                          'test_tft_logs',
                          'version_0',)

'D:/KIMoDIs/global-groundwater-models-main\\models\\tft\\tft_full\\tft_full_10_Epochs\\logs\\test_tft_logs\\version_0'

In [46]:
# Training metrics
# load logs
# D:\KIMoDIs\global-groundwater-models-main\models\tft\tft_full\tft_full_4096_10_Epochs\logs\test_tft_logs\version_0
training_metrics = pd.read_csv(os.path.join(MODEL_PATH, 
                          'tft',
                          f'tft_{MODEL_TYPE}',
                          f'tft_{MODEL_TYPE}_4096_{VERSION}',
                          'logs',
                          'test_tft_logs',
                          'version_0',
                          'metrics.csv'))

In [53]:
training_metrics[training_metrics['step'].isin((25,51,129,259))]

Unnamed: 0,lr-Ranger,step,train_loss_step,epoch,val_loss,val_SMAPE,val_MAE,val_RMSE,val_MAPE,train_loss_epoch
50,0.003981,25,,,,,,,,
51,,25,0.321295,0.0,,,,,,
52,,25,,0.0,0.303989,0.19559,12.750532,45.874382,1244648000.0,
103,0.003981,51,,,,,,,,
104,,51,0.261536,0.0,,,,,,
105,,51,,0.0,0.249662,0.195103,12.744623,45.888596,1245136000.0,
262,0.003981,129,,,,,,,,
263,,129,0.186113,0.0,,,,,,
264,,129,,0.0,0.187016,0.195404,12.75617,45.929871,1246552000.0,
527,0.003981,259,,,,,,,,


### Metrics plot

In [14]:
# FONT_SIZE = 18
# col_grey ='#bababa'

# # plt.style.use('seaborn-v0_8')
# sns.set(rc=
#     {"font.size": FONT_SIZE,
#      "axes.titlesize": FONT_SIZE,
#      "axes.labelsize": FONT_SIZE,
#      "xtick.labelsize": FONT_SIZE-2,
#      "ytick.labelsize": FONT_SIZE-2,
#      "legend.fontsize": FONT_SIZE-2,
#      "legend.title_fontsize": FONT_SIZE-2,
#     }
# )
# plt.rcParams.update({
#     'font.family': 'Times New Roman',
#     'font.size': FONT_SIZE,
#     'axes.labelsize':FONT_SIZE-2,
#     'axes.labelweight': 'bold',
#     'axes.titleweight':'bold',
    
#     'legend.fontsize': FONT_SIZE-2,
#     'legend.title_fontsize': FONT_SIZE-2,
    
#     'axes.facecolor': 'white',
#     'axes.grid': True,
#     'axes.grid.axis': 'both',
#     'axes.grid.which': 'major',
#     'grid.linestyle': '--',
#     'grid.color': 'gray',
#     'grid.linewidth': '0.5',
    
#     'xtick.direction': 'in',
#     'ytick.direction': 'in',
    
#     'savefig.bbox':'tight',
#     'savefig.dpi':300
# })

In [None]:
# p = palette = sns.color_palette("colorblind")
# colors = [p[3], p[3]] # , p[6], p[3], p[1], p[0], p[2], p[4]

# horizons = [1, 12]

# fig_04, ax = plt.subplots(2, 4, figsize=(18, 5), sharey=True)

# test_position = [-0.9,-0.43,0.95,1]
# for j, (metric, _range) in enumerate([('NSE', (-1,1)), ('rMBE', (-0.5, 0.5)), ('RMSE', (0., 1)), ('MAE', (0., 1))]):
#     for i, HORIZON in enumerate(horizons):
#         x = tft_val_metrics[(tft_val_metrics['horizon'] == HORIZON)][metric].replace([np.inf, -np.inf], np.nan).dropna().values
#         median = np.median(x)
#         x_d = np.linspace(-1,4, 2000)
        
#         kde = KernelDensity(bandwidth=0.03, kernel='gaussian')
#         kde.fit(x[:, None])
#         logprob = kde.score_samples(x_d[:, None])
#         y_d = np.exp(logprob)

#         if metric == 'NSE':
#             y_d *= 2
#         elif metric == 'Interval Score':
#             y_d *= 4
        
#         if j == 0:
#             ax[i, j].set_ylabel(f'{HORIZON} w')
            
#         # plotting the distribution
#         ax[i, j].plot(x_d, y_d, color="#f0f0f0", lw=1)
#         ax[i, j].fill_between(x_d, y_d, alpha=1, color=colors[i])
        
#         median_index = np.abs(x_d - median).argmin()
#         median_y = y_d[median_index]
#         ax[i, j].plot((median,median), (0,median_y), color="white", alpha = 0.5, lw=2)   
#         if i == 0:
#             ax[i, j].text(test_position[j], 0.4, 'Median\n{:.2f}'.format(median), va='bottom', ha='center', fontsize = FONT_SIZE-4)
#         else:
#             ax[i, j].text(test_position[j], 0.4, '{:.2f}'.format(median), va='bottom', ha='center', fontsize = FONT_SIZE-4)
      
#         # setting uniform x and y lims
#         ax[i, j].set_xlim(_range)
#         # ax[i, j].set_ylim(0, 8)

#         # make background transparent
#         rect = ax[i, j].patch
#         rect.set_alpha(0)

#         # remove borders, axis ticks, and y labels, set x labels
#         ax[i, j].set_yticklabels([])
#         ax[i, j].grid(False)
#         ax[i, j].set_xlabel(metric)
        
#         if i == 1:
#             ax[i, j].xaxis.set_ticks_position('bottom')
#             if j == 1:
#                 ax[i, j].set_xticks((-0.5,0,0.5))                
#             elif j==2:
#                 ax[i, j].set_xticks((0,0.5,1))
#         else:
#             ax[i, j].set_xticklabels([])

#         spines = ["top","right","left","bottom"]
#         for s in spines:
#             ax[i, j].spines[s].set_visible(False)

# # plt.subplots_adjust(hspace=-.75)
# ylim = ax[0,0].get_ylim()
# legend_elements = [Patch(facecolor=p[3], edgecolor=None, label=f'TFT_{MODEL_TYPE}')]
# fig_04.legend(handles=legend_elements,mode = "expand", bbox_to_anchor=(0.1, 1, 0.78, 0.01))
# fig_04.savefig(os.path.join(FIGURES_PATH, 
#                             'tft',
#                             f'tft_{MODEL_TYPE}',
#                             f'tft_{MODEL_TYPE}_4096_{VERSION}',
#                             'tft_metrics_val.png'), format='png', dpi=300.0)
# fig_04.savefig(os.path.join(SHARE_PATH, 
#                             'global_mod_paper',
#                             'results',
#                             'tft',
#                             f'tft_{MODEL_TYPE}',
#                             f'tft_{MODEL_TYPE}_4096_{VERSION}',
#                             'tft_metrics_val.png'), format='png', dpi=300.0)
# plt.show()