In [1]:
import numpy as np
import pandas as pd
from ray import tune
from ray.tune.search.hyperopt import HyperOptSearch
import matplotlib.pyplot as plt

from neuralforecast.core import NeuralForecast
from neuralforecast.models import NHITS
from neuralforecast.losses.pytorch import HuberLoss
from neuralforecast.losses.numpy import mae

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def evaluate(Y_hat_df, av_mask, model_name, low_threshold=70, high_threshold=180):
    """Compute the MAE of one model's forecasts, overall and on critical values of y.

    Parameters
    ----------
    Y_hat_df : pandas.DataFrame
        Forecast rows. Must contain the columns 'unique_id', 'cutoff', 'y',
        'available_mask' and a column named `model_name` with the predictions.
        It is not modified.
    av_mask : pandas.DataFrame
        Must contain the columns 'unique_id', 'cutoff' and 'sum_av_mask'.
        'sum_av_mask' is looked up for every forecast row by ('unique_id', 'cutoff').
    model_name : str
        Name of the prediction column in `Y_hat_df`.
    low_threshold, high_threshold : float, optional
        A row counts as critical when ``y <= low_threshold`` or
        ``y >= high_threshold``. The defaults (70 and 180) are the values that
        were previously hard-coded.

    Returns
    -------
    tuple
        ``(mae_all, mae_critical)``: `mae` over every kept row, and `mae` over
        the kept rows whose y is critical.
    """
    # `merge` already returns a new frame, so the caller's Y_hat_df is left untouched.
    mask_lookup = av_mask[['unique_id', 'cutoff', 'sum_av_mask']]
    results_df = Y_hat_df.merge(mask_lookup, on=['unique_id', 'cutoff'], how='left')

    # Keep forecasts with at least 1 available value in the input window.
    # Rows without a match in av_mask get NaN here and are dropped too (NaN > 0 is False).
    results_df = results_df[results_df['sum_av_mask'] > 0].reset_index(drop=True)

    # Drop rows whose y is flagged unavailable (per the original note: forward-filled y values).
    results_df = results_df[results_df['available_mask'] == 1]

    # Keep critical values of y: at or below the low bound, or at or above the high bound.
    is_critical = (results_df['y'] <= low_threshold) | (results_df['y'] >= high_threshold)
    results_critical_df = results_df[is_critical]

    mae_all = mae(results_df['y'], results_df[model_name])
    mae_critical = mae(results_critical_df['y'], results_critical_df[model_name])
    return mae_all, mae_critical

In [3]:
# Load the test series and parse the timestamp column.
data = pd.read_csv('data_glucose/ohiot1dm_exog_9_day_test.csv')
data['ds'] = pd.to_datetime(data['ds'])

# For each series, add 'sum_av_mask': the sum of 'available_mask' over the current
# row and up to 119 preceding rows (min_periods=1, so early rows use a shorter window).
unique_ids = data['unique_id'].unique()
df = [
    data[data['unique_id'] == uid]
    .reset_index(drop=True)
    .assign(sum_av_mask=lambda series_df: series_df['available_mask'].rolling(window=120, min_periods=1).sum())
    for uid in unique_ids
]

# Stack the per-series frames and rename 'ds' to 'cutoff' so the result can be
# joined to forecast frames on ('unique_id', 'cutoff').
av_mask = (
    pd.concat(df)
    .reset_index(drop=True)
    .rename(columns={'ds': 'cutoff'})
)
av_mask.head()

Unnamed: 0,cutoff,y,available_mask,CHO,basal_insulin,bolus_insulin,unique_id,sum_av_mask
0,2021-12-07 01:20:00,101.0,1,0.0,0.0,0.0,#559,1.0
1,2021-12-07 01:25:00,98.0,1,0.0,0.0,0.0,#559,2.0
2,2021-12-07 01:30:00,104.0,1,0.0,0.0,0.0,#559,3.0
3,2021-12-07 01:35:00,112.0,1,0.0,0.0,0.0,#559,4.0
4,2021-12-07 01:40:00,120.0,1,0.0,0.0,0.0,#559,5.0


In [17]:
# Evaluate each saved forecast file under results_glucose/h_6 with `evaluate`
# and print the overall MAE and the critical-range MAE.
# Each entry is (label used in the printed lines, CSV path, prediction column).
# NOTE(review): the following cell repeats these four evaluations and adds two
# more files — confirm whether this cell is still needed.
forecast_runs = [
    ('NHITS baseline', 'results_glucose/h_6/baselines/nhits_20230829.csv', 'AutoNHITS'),
    ('TFT baseline', 'results_glucose/h_6/baselines/tft_20230829.csv', 'AutoTFT'),
    ('NHITS exog', 'results_glucose/h_6/exogenous/nhits_20230829.csv', 'AutoNHITS'),
    ('NHITS TREAT', 'results_glucose/h_6/treat/nhits_20230829.csv', 'AutoNHITS_TREAT'),
]

for run_label, csv_path, model_col in forecast_runs:
    Y_hat_df = pd.read_csv(csv_path)
    # Parse timestamps so 'cutoff' has the same datetime type as av_mask['cutoff'],
    # which `evaluate` merges on.
    Y_hat_df['ds'] = pd.to_datetime(Y_hat_df['ds'])
    Y_hat_df['cutoff'] = pd.to_datetime(Y_hat_df['cutoff'])

    mae_all, mae_critical = evaluate(Y_hat_df, av_mask, model_col)

    print(f'{run_label} ALL: ', mae_all)
    print(f'{run_label} CRITICAL: ', mae_critical)

NHITS baseline ALL:  9.315096261645257
NHITS baseline CRITICAL:  10.26358951195731
TFT baseline ALL:  9.259035800329098
TFT baseline CRITICAL:  10.371936204513847
NHITS exog ALL:  9.358225936043974
NHITS exog CRITICAL:  10.383775861893822
NHITS TREAT ALL:  9.459924425150689
NHITS TREAT CRITICAL:  10.387447132165317


In [26]:
# Evaluate each saved forecast file under results_glucose/h_6 with `evaluate`
# and print the overall MAE and the critical-range MAE.
# Each entry is (label used in the printed lines, CSV path, prediction column).
forecast_runs = [
    ('NHITS baseline', 'results_glucose/h_6/baselines/nhits_20230829.csv', 'AutoNHITS'),
    ('TFT baseline', 'results_glucose/h_6/baselines/tft_20230829.csv', 'AutoTFT'),
    ('TFT exog', 'results_glucose/h_6/exogenous/tft_20230829.csv', 'AutoTFT'),
    ('NHITS exog', 'results_glucose/h_6/exogenous/nhits_20230829.csv', 'AutoNHITS'),
    ('NHITS TREAT', 'results_glucose/h_6/treat/nhits_20230829.csv', 'AutoNHITS_TREAT'),
    # This run used to print under the plain 'NHITS TREAT' label as well, which made
    # its two result lines indistinguishable from the previous run's; it now has its own label.
    ('NHITS TREAT CHO', 'results_glucose/h_6/treat/nhits_20230829_cho.csv', 'AutoNHITS_TREAT'),
]

for run_label, csv_path, model_col in forecast_runs:
    Y_hat_df = pd.read_csv(csv_path)
    # Parse timestamps so 'cutoff' has the same datetime type as av_mask['cutoff'],
    # which `evaluate` merges on.
    Y_hat_df['ds'] = pd.to_datetime(Y_hat_df['ds'])
    Y_hat_df['cutoff'] = pd.to_datetime(Y_hat_df['cutoff'])

    mae_all, mae_critical = evaluate(Y_hat_df, av_mask, model_col)

    print(f'{run_label} ALL: ', mae_all)
    print(f'{run_label} CRITICAL: ', mae_critical)

NHITS baseline ALL:  9.319583626742984
NHITS baseline CRITICAL:  10.50554280553876
TFT baseline ALL:  9.259035800329098
TFT baseline CRITICAL:  10.371936204513847
TFT exog ALL:  9.464974399592458
TFT exog CRITICAL:  11.417223404672573
NHITS exog ALL:  8.972649823354212
NHITS exog CRITICAL:  10.100923448845085
NHITS TREAT ALL:  8.94415280577869
NHITS TREAT CRITICAL:  10.288121691332083
NHITS TREAT ALL:  8.902128148961392
NHITS TREAT CRITICAL:  9.885354577064941
