In [2]:
import os
os.chdir('/home/reffert/DeepAR_InfluenzaForecast')
from PythonFiles.model import model, preprocessing, split_forecasts_by_week, plot_coverage, print_forecasts_by_week, forecast_by_week, train_test_split, update_deepAR_parameters
from PythonFiles.Configuration import Configuration
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from datetime import datetime
from gluonts.mx import Trainer, DeepAREstimator
from gluonts.dataset.pandas import PandasDataset
from gluonts.dataset.rolling_dataset import generate_rolling_dataset,StepStrategy
from gluonts.evaluation import make_evaluation_predictions, Evaluator
import ray
from ray.air import session
from ray import tune
from gluonts.mx.distribution import NegativeBinomialOutput
# Project configuration object; its attributes are adjusted further down in the notebook.
config = Configuration()

# All input tables are read from one directory. Keeping the location in a single
# constant means that relocating the project only requires one change.
DATA_DIR = "/home/reffert/DeepAR_InfluenzaForecast/Notebooks/DataProcessing"

# Influenza observations.
influenza_df = pd.read_csv(os.path.join(DATA_DIR, "influenza.csv"), sep=',')
# Population table.
population_df = pd.read_csv(os.path.join(DATA_DIR, "PopulationVector.csv"), sep=',')
# Adjacency matrix; the first CSV column is used as the row index.
neighbourhood_df = pd.read_csv(os.path.join(DATA_DIR, "AdjacentMatrix.csv"), sep=',', index_col=0)

In [3]:
# Time window of the experiment. An alternative training start that was tried
# earlier: datetime(2010,1,1,0,0,0).
config.train_start_time = datetime(1999, 1, 1, 0, 0, 0)
config.train_end_time = datetime(2016, 9, 30, 23, 0, 0)
config.test_end_time = datetime(2018, 9, 30, 23, 0, 0)

# Containers for results and data splits.
overall_evaluation_df = pd.DataFrame()
data_splits_dict = {}
output_dict = {}

# Every distinct location present in the influenza table.
locations = list(influenza_df.location.unique())

# Process the df into a uniformly spaced df
selected_columns = ['value', 'location', 'date', 'week']
df = influenza_df.loc[influenza_df.location.isin(locations), selected_columns]
df = preprocessing(config, df, check_count=False, output_type="corrected_df")

# Attach the features to the rows of each location: the "2011" entry of the
# population table and the location's row of the adjacency matrix (one column
# per location).
for location in locations:
    location_rows = df.location == location
    population_2011 = population_df.loc[population_df.Location == location, "2011"].values[0]
    df.loc[location_rows, "population"] = int(population_2011)
    adjacency_row = neighbourhood_df.loc[neighbourhood_df.index == location, locations].values[0]
    df.loc[location_rows, locations] = adjacency_row.astype(int)

In [4]:
# Build the train/test split that the hyperparameter search reads under the key
# "with_features_2001". The third argument (True here, False in the disabled
# variant below) presumably switches the additional features on or off, as the
# dictionary keys suggest -- TODO confirm in train_test_split.
# NOTE(review): the key mentions 2001 and an earlier comment mentioned 2010,
# while config.train_start_time is set to 1999-01-01 in the preceding cell --
# confirm which training start is actually in effect.
#data_splits_dict["without_features_2001"] = list(train_test_split(config, df, False))
data_splits_dict["with_features_2001"] = list(train_test_split(config, df, True))

# Disabled variants: move the beginning of the training period to 2010 and
# create the corresponding splits.
#config.train_start_time = datetime(2010,1,1,0,0,0)
#data_splits_dict["without_features_2010"] = list(train_test_split(config, df, False))
#data_splits_dict["with_features_2010"] = list(train_test_split(config, df, True))

In [5]:
def evaluate(config, train, test, configuration, num_series=411):
    """Fit and evaluate one hyperparameter setting and return its mean WIS.

    Parameters
    ----------
    config
        Hyperparameter values of the current trial; handed to
        ``update_deepAR_parameters`` together with ``configuration``.
    train, test
        Training and test data, forwarded unchanged to ``model``.
    configuration
        Project configuration. ``configuration.quantiles`` selects the quantile
        levels of the evaluator and
        ``configuration.parameters["prediction_length"]`` enters the
        normalisation of the score.
    num_series : int, optional
        Number of series used to normalise the aggregated quantile loss.
        Defaults to 411, the value that was previously hard-coded
        (presumably the number of locations -- TODO confirm).

    Returns
    -------
    float
        ``agg_metrics["mean_absolute_QuantileLoss"]`` divided by
        ``prediction_length * num_series``.
    """
    deeparestimator = update_deepAR_parameters(configuration, config)
    forecasts, tss = model(train, test, deeparestimator)

    # Evaluation with the quantiles of the configuration; the first element of
    # the evaluator's result holds the aggregated metrics.
    evaluator = Evaluator(quantiles=configuration.quantiles)
    agg_metrics = evaluator(tss, forecasts)[0]

    # Normalise the aggregated quantile loss by horizon length and series count.
    normaliser = configuration.parameters["prediction_length"] * num_series
    mean_WIS = agg_metrics["mean_absolute_QuantileLoss"] / normaliser
    return mean_WIS

def objective(config, train, test, configuration):
    """Score one trial with ``evaluate`` and report it under the key "mean_WIS".

    The arguments are passed through to ``evaluate`` unchanged; nothing is
    returned, the score is only published via ``session.report``.
    """
    session.report({"mean_WIS": evaluate(config, train, test, configuration)})

In [None]:
# Search space of the hyperparameter study. Every entry wrapped in
# tune.grid_search contributes all of its values to a full grid.
hp_search_space = {
    "num_cells": tune.grid_search([40, 80]),
    "num_layers": tune.grid_search([1, 5, 10]),
    "context_length": tune.grid_search([4, 52, 104]),
    "cell_type": tune.grid_search(["lstm", "gru"]),
    "epochs": tune.grid_search([20, 30, 40]),
    # Static real features stay switched off. Only one grid value is listed:
    # repeating the same value would duplicate every grid point and double the
    # number of trials without exploring anything new (repetitions are already
    # controlled by num_samples below).
    "use_feat_static_real": tune.grid_search([False]),
    "use_feat_dynamic_real": tune.grid_search([True, False]),
    "use_feat_static_cat": tune.grid_search([False, True]),
    # The cardinality is only set when static categorical features are used:
    # a list of 411 entries, each equal to 2; otherwise None.
    "cardinality": tune.sample_from(
        lambda spec: [2] * 411 if spec.config.use_feat_static_cat else None
    ),
}


# Data and configuration shared by all trials.
train = data_splits_dict["with_features_2001"][0]
test = data_splits_dict["with_features_2001"][1]
configuration = Configuration()

tuner = tune.Tuner(
    # with_parameters binds the datasets and the configuration to the
    # objective, so that only the trial's hyperparameters vary.
    tune.with_parameters(objective, train=train, test=test, configuration=configuration),
    tune_config=tune.TuneConfig(
        # The whole grid is repeated num_samples times.
        num_samples=5,
        # Select the best trial by the smallest reported "mean_WIS".
        metric="mean_WIS",
        mode="min",
        max_concurrent_trials=18,
    ),
    param_space=hp_search_space,
)
results = tuner.fit()

print("Best hyperparameters found were: ", results.get_best_result().config)

# Persist the per-trial results so the study can be inspected without re-running it.
results_df = results.get_dataframe()
print(results_df)
results_df.to_csv("Hyperparameter_results_26_04.csv")

0,1
Current time:,2023-04-26 12:40:15
Running for:,01:12:47.53
Memory:,97.0/236.0 GiB

Trial name,status,loc,cardinality,cell_type,context_length,epochs,num_cells,num_layers,use_feat_dynamic_real,use_feat_static_cat,use_feat_static_real,iter,total time (s),mean_WIS
objective_8f667_00033,RUNNING,172.22.1.197:1652010,,gru,52,40,80,1,True,False,False,,,
objective_8f667_00034,RUNNING,172.22.1.197:1651984,,lstm,104,40,80,1,True,False,False,,,
objective_8f667_00035,RUNNING,172.22.1.197:1652016,,gru,104,40,80,1,True,False,False,,,
objective_8f667_00038,RUNNING,172.22.1.197:1651988,,lstm,52,20,40,5,True,False,False,,,
objective_8f667_00039,RUNNING,172.22.1.197:1651986,,gru,52,20,40,5,True,False,False,,,
objective_8f667_00040,RUNNING,172.22.1.197:1651992,,lstm,104,20,40,5,True,False,False,,,
objective_8f667_00041,RUNNING,172.22.1.197:1651990,,gru,104,20,40,5,True,False,False,,,
objective_8f667_00042,RUNNING,172.22.1.197:1651994,,lstm,4,30,40,5,True,False,False,,,
objective_8f667_00043,RUNNING,172.22.1.197:1651996,,gru,4,30,40,5,True,False,False,,,
objective_8f667_00044,RUNNING,172.22.1.197:1651998,,lstm,52,30,40,5,True,False,False,,,


  0%|          | 0/50 [00:00<?, ?it/s]
  0%|          | 0/50 [00:00<?, ?it/s]
  0%|          | 0/50 [00:00<?, ?it/s]
  0%|          | 0/50 [00:00<?, ?it/s]
  0%|          | 0/50 [00:00<?, ?it/s]
  0%|          | 0/50 [00:00<?, ?it/s]
  0%|          | 0/50 [00:00<?, ?it/s]
  0%|          | 0/50 [00:00<?, ?it/s]
  0%|          | 0/50 [00:00<?, ?it/s]
  0%|          | 0/50 [00:00<?, ?it/s]
  0%|          | 0/50 [00:00<?, ?it/s]
  0%|          | 0/50 [00:00<?, ?it/s]
  0%|          | 0/50 [00:00<?, ?it/s]
  0%|          | 0/50 [00:00<?, ?it/s]
  0%|          | 0/50 [00:00<?, ?it/s]
  0%|          | 0/50 [00:00<?, ?it/s]
  0%|          | 0/50 [00:00<?, ?it/s]
  0%|          | 0/50 [00:00<?, ?it/s]
 12%|█▏        | 6/50 [00:10<01:18,  1.78s/it, epoch=1/20, avg_epoch_loss=1.3]
 12%|█▏        | 6/50 [00:10<01:13,  1.68s/it, epoch=1/20, avg_epoch_loss=1.33]
 16%|█▌        | 8/50 [00:10<00:54,  1.30s/it, epoch=1/30, avg_epoch_loss=1.18]
 12%|█▏        | 6/50 [00:10<01:16,  1.74s/it, epoch=1/40, 

Trial name,date,done,episodes_total,experiment_id,experiment_tag,hostname,iterations_since_restore,mean_WIS,node_ip,pid,time_since_restore,time_this_iter_s,time_total_s,timestamp,timesteps_since_restore,timesteps_total,training_iteration,trial_id,warmup_time
objective_8f667_00000,2023-04-26_11-44-25,True,,364926f6580b455da12dfee563ec9a5a,"0_cardinality=None,cell_type=lstm,context_length=4,epochs=20,num_cells=40,num_layers=1,use_feat_dynamic_real=True,use_feat_static_cat=False,use_feat_static_real=False",econ-stat-rr01,1,496.416,172.22.1.197,1651916,1013.08,1013.08,1013.08,1682502265,0,,1,8f667_00000,0.00580144
objective_8f667_00001,2023-04-26_11-43-24,True,,50504c518ec94106b7d613bb980431a6,"1_cardinality=None,cell_type=gru,context_length=4,epochs=20,num_cells=40,num_layers=1,use_feat_dynamic_real=True,use_feat_static_cat=False,use_feat_static_real=False",econ-stat-rr01,1,480.664,172.22.1.197,1651984,945.228,945.228,945.228,1682502204,0,,1,8f667_00001,0.0063808
objective_8f667_00002,2023-04-26_11-47-05,True,,f398906bd71644ee8b1ef003176228bb,"2_cardinality=None,cell_type=lstm,context_length=52,epochs=20,num_cells=40,num_layers=1,use_feat_dynamic_real=True,use_feat_static_cat=False,use_feat_static_real=False",econ-stat-rr01,1,530.479,172.22.1.197,1651986,1166.7,1166.7,1166.7,1682502425,0,,1,8f667_00002,0.00764704
objective_8f667_00003,2023-04-26_11-46-15,True,,b21935febb664de7a1f362b185b17554,"3_cardinality=None,cell_type=gru,context_length=52,epochs=20,num_cells=40,num_layers=1,use_feat_dynamic_real=True,use_feat_static_cat=False,use_feat_static_real=False",econ-stat-rr01,1,521.395,172.22.1.197,1651988,1116.16,1116.16,1116.16,1682502375,0,,1,8f667_00003,0.00571704
objective_8f667_00004,2023-04-26_11-50-22,True,,61797d562aad44d9ad545f955410f4f0,"4_cardinality=None,cell_type=lstm,context_length=104,epochs=20,num_cells=40,num_layers=1,use_feat_dynamic_real=True,use_feat_static_cat=False,use_feat_static_real=False",econ-stat-rr01,1,488.141,172.22.1.197,1651990,1363.72,1363.72,1363.72,1682502622,0,,1,8f667_00004,0.00596356
objective_8f667_00005,2023-04-26_11-49-05,True,,189f7453b1b744dc83a85c60901c7bdb,"5_cardinality=None,cell_type=gru,context_length=104,epochs=20,num_cells=40,num_layers=1,use_feat_dynamic_real=True,use_feat_static_cat=False,use_feat_static_real=False",econ-stat-rr01,1,479.345,172.22.1.197,1651992,1287.04,1287.04,1287.04,1682502545,0,,1,8f667_00005,0.00518775
objective_8f667_00006,2023-04-26_11-50-48,True,,bbd712853f9a483ca88b33d19a700184,"6_cardinality=None,cell_type=lstm,context_length=4,epochs=30,num_cells=40,num_layers=1,use_feat_dynamic_real=True,use_feat_static_cat=False,use_feat_static_real=False",econ-stat-rr01,1,491.298,172.22.1.197,1651994,1389.37,1389.37,1389.37,1682502648,0,,1,8f667_00006,0.0048461
objective_8f667_00007,2023-04-26_11-51-23,True,,fca21bf4617449b6aa85aea733214aba,"7_cardinality=None,cell_type=gru,context_length=4,epochs=30,num_cells=40,num_layers=1,use_feat_dynamic_real=True,use_feat_static_cat=False,use_feat_static_real=False",econ-stat-rr01,1,484.357,172.22.1.197,1651996,1424.59,1424.59,1424.59,1682502683,0,,1,8f667_00007,0.0068388
objective_8f667_00008,2023-04-26_11-54-09,True,,6d4d1631ee6c435fbd3114962ccc5e4f,"8_cardinality=None,cell_type=lstm,context_length=52,epochs=30,num_cells=40,num_layers=1,use_feat_dynamic_real=True,use_feat_static_cat=False,use_feat_static_real=False",econ-stat-rr01,1,566.754,172.22.1.197,1651998,1590.69,1590.69,1590.69,1682502849,0,,1,8f667_00008,0.00639868
objective_8f667_00009,2023-04-26_11-54-58,True,,618f9f06fa434ca9bf324c071e8e1513,"9_cardinality=None,cell_type=gru,context_length=52,epochs=30,num_cells=40,num_layers=1,use_feat_dynamic_real=True,use_feat_static_cat=False,use_feat_static_real=False",econ-stat-rr01,1,530.312,172.22.1.197,1652000,1639.66,1639.66,1639.66,1682502898,0,,1,8f667_00009,0.00588894


  0%|          | 0/50 [00:00<?, ?it/s]
100%|██████████| 50/50 [00:45<00:00,  1.10it/s, epoch=25/30, avg_epoch_loss=0.673]
  0%|          | 0/50 [00:00<?, ?it/s]
100%|██████████| 50/50 [00:47<00:00,  1.05it/s, epoch=23/40, avg_epoch_loss=0.586]
  0%|          | 0/50 [00:00<?, ?it/s]
 20%|██        | 10/50 [00:10<00:42,  1.06s/it, epoch=24/40, avg_epoch_loss=0.578]
 18%|█▊        | 9/50 [00:10<00:47,  1.17s/it, epoch=26/40, avg_epoch_loss=0.691]
100%|██████████| 50/50 [00:48<00:00,  1.02it/s, epoch=22/40, avg_epoch_loss=0.577]
  0%|          | 0/50 [00:00<?, ?it/s]
 92%|█████████▏| 46/50 [00:46<00:04,  1.05s/it, epoch=22/30, avg_epoch_loss=0.577]
100%|██████████| 50/50 [00:56<00:00,  1.12s/it, epoch=23/30, avg_epoch_loss=0.594]
  0%|          | 0/50 [00:00<?, ?it/s]
 22%|██▏       | 11/50 [00:10<00:36,  1.07it/s, epoch=26/40, avg_epoch_loss=0.645]
 32%|███▏      | 16/50 [00:21<00:41,  1.22s/it, epoch=26/30, avg_epoch_loss=0.722]
 24%|██▍       | 12/50 [00:10<00:32,  1.18it/s, epoch=24/30