In [1]:
import os
import pandas as pd
os.chdir('/home/reffert/DeepAR_InfluenzaForecast')
from PythonFiles.model import preprocessing, train_test_split
from PythonFiles.Configuration import Configuration
from PythonFiles.HpTuning import get_data, objectiveDeepAR
from ray.air import session
from ray import tune
from datetime import datetime
# Experiment time windows: start/end of the training range and the end of
# the test range, stored on the shared Configuration object.
configuration = Configuration()
configuration.train_start_time = datetime(1999, 1, 1, 0, 0, 0)
configuration.train_end_time = datetime(2018, 9, 30, 23, 0, 0)
configuration.test_end_time = datetime(2020, 9, 30, 23, 0, 0)

# Read the three input tables. The adjacency matrix uses its first column
# as the index so rows can be looked up by location.
influenza_df = pd.read_csv(
    "/home/reffert/DeepAR_InfluenzaForecast/Notebooks/DataProcessing/influenza.csv",
    sep=',',
)
population_df = pd.read_csv(
    "/home/reffert/DeepAR_InfluenzaForecast/Notebooks/DataProcessing/PopulationVector.csv",
    sep=',',
)
neighbourhood_df = pd.read_csv(
    "/home/reffert/DeepAR_InfluenzaForecast/Notebooks/DataProcessing/AdjacentMatrix.csv",
    sep=',',
    index_col=0,
)

data_splits_dict = {}
locations = list(influenza_df.location.unique())

# Restrict the raw frame to the four columns passed on to `preprocessing`
# and request its "corrected_df" output (per the original note: a uniformly
# spaced frame — behaviour of the helper itself is not visible here).
df = influenza_df.loc[
    influenza_df.location.isin(locations),
    ['value', 'location', 'date', 'week'],
]
df = preprocessing(configuration, df, check_count=False, output_type="corrected_df")

# Attach per-location features: one "population" value taken from the "2011"
# column of the population table, and one column per location filled from
# that location's row of the adjacency matrix.
for location in locations:
    rows_of_location = df.location == location

    population_2011 = population_df.loc[population_df.Location == location, "2011"].values[0]
    df.loc[rows_of_location, "population"] = int(population_2011)

    adjacency_row = neighbourhood_df.loc[neighbourhood_df.index == location, locations].values[0]
    df.loc[rows_of_location, locations] = adjacency_row.astype(int)

# Split the enriched frame and keep the result in the dict; the first two
# entries of the split are exposed as `train` and `test` for later cells.
split_key = "with_features_2001"
data_splits_dict[split_key] = list(train_test_split(configuration, df, True))
train, test = data_splits_dict[split_key][0], data_splits_dict[split_key][1]

In [None]:
# Default DeepAR configuration under evaluation. Every hyperparameter has a
# single candidate, so the grid below contains exactly one point.
fixed_hyperparameters = {
    "num_cells": 40,
    "num_layers": 2,
    "context_length": 4,
    "prediction_length": 4,
    "cell_type": "lstm",
    "epochs": 100,
    "use_feat_static_real": False,
    "use_feat_dynamic_real": False,
    "use_feat_static_cat": False,
}
# Wrap each value in a one-element grid search, keeping the key order above.
hp_search_space = {
    name: tune.grid_search([value])
    for name, value in fixed_hyperparameters.items()
}

# Bind the data splits and the configuration to the objective so that Ray
# passes them to every trial.
trainable = tune.with_parameters(
    objectiveDeepAR,
    train=train,
    test=test,
    configuration=configuration,
)
# num_samples=100 with a one-point grid means 100 runs of the same
# configuration; at most 8 run concurrently. Trials are ranked by the
# lowest reported "mean_WIS".
tune_config = tune.TuneConfig(
    num_samples=100,
    metric="mean_WIS",
    mode="min",
    max_concurrent_trials=8,
)
tuner = tune.Tuner(trainable, tune_config=tune_config, param_space=hp_search_space)
results = tuner.fit()

best_result = results.get_best_result()
print("Best hyperparameters found were: ", best_result.config)

# Persist the per-trial results table next to the notebook's working directory.
results_df = results.get_dataframe()
print(results_df)
results_df.to_csv("default_DeepAR_results_06_06.csv")

2023-06-06 14:19:12,824	INFO worker.py:1553 -- Started a local Ray instance.


0,1
Current time:,2023-06-06 16:05:54
Running for:,01:46:34.39
Memory:,205.6/236.0 GiB

Trial name,status,loc,cell_type,context_length,epochs,num_cells,num_layers,prediction_length,use_feat_dynamic_real,use_feat_static_cat,use_feat_static_real,iter,total time (s),mean_WIS
objectiveDeepAR_5cc05_00040,RUNNING,172.22.1.197:3182118,lstm,4,100,40,2,4,False,False,False,,,
objectiveDeepAR_5cc05_00041,RUNNING,172.22.1.197:3182007,lstm,4,100,40,2,4,False,False,False,,,
objectiveDeepAR_5cc05_00042,RUNNING,172.22.1.197:3182124,lstm,4,100,40,2,4,False,False,False,,,
objectiveDeepAR_5cc05_00043,RUNNING,172.22.1.197:3182120,lstm,4,100,40,2,4,False,False,False,,,
objectiveDeepAR_5cc05_00044,RUNNING,172.22.1.197:3182122,lstm,4,100,40,2,4,False,False,False,,,
objectiveDeepAR_5cc05_00045,RUNNING,172.22.1.197:3182116,lstm,4,100,40,2,4,False,False,False,,,
objectiveDeepAR_5cc05_00046,RUNNING,172.22.1.197:3182114,lstm,4,100,40,2,4,False,False,False,,,
objectiveDeepAR_5cc05_00047,RUNNING,172.22.1.197:3182112,lstm,4,100,40,2,4,False,False,False,,,
objectiveDeepAR_5cc05_00000,TERMINATED,172.22.1.197:3182007,lstm,4,100,40,2,4,False,False,False,1.0,1134.62,417.587
objectiveDeepAR_5cc05_00001,TERMINATED,172.22.1.197:3182112,lstm,4,100,40,2,4,False,False,False,1.0,1249.53,363.69


  0%|          | 0/50 [00:00<?, ?it/s][0m 
  0%|          | 0/50 [00:00<?, ?it/s][0m 
  0%|          | 0/50 [00:00<?, ?it/s][0m 
  0%|          | 0/50 [00:00<?, ?it/s][0m 
  0%|          | 0/50 [00:00<?, ?it/s][0m 
100%|██████████| 50/50 [00:07<00:00,  6.98it/s, epoch=1/100, avg_epoch_loss=1.31]
  0%|          | 0/50 [00:00<?, ?it/s][0m 
  0%|          | 0/50 [00:00<?, ?it/s][0m 
  0%|          | 0/50 [00:00<?, ?it/s][0m 
  0%|          | 0/50 [00:00<?, ?it/s][0m 
100%|██████████| 50/50 [00:06<00:00,  8.17it/s, epoch=2/100, avg_epoch_loss=0.993]
  0%|          | 0/50 [00:00<?, ?it/s][0m 
100%|██████████| 50/50 [00:07<00:00,  6.50it/s, epoch=1/100, avg_epoch_loss=1.23]
  0%|          | 0/50 [00:00<?, ?it/s][0m 
100%|██████████| 50/50 [00:08<00:00,  5.89it/s, epoch=1/100, avg_epoch_loss=1.27]
  0%|          | 0/50 [00:00<?, ?it/s][0m 
100%|██████████| 50/50 [00:08<00:00,  5.78it/s, epoch=1/100, avg_epoch_loss=1.34]
  0%|          | 0/50 [00:00<?, ?it/s][0m 
100%|██████████| 

Trial name,date,done,episodes_total,experiment_id,experiment_tag,hostname,iterations_since_restore,mean_WIS,node_ip,pid,time_since_restore,time_this_iter_s,time_total_s,timestamp,timesteps_since_restore,timesteps_total,training_iteration,trial_id,warmup_time
objectiveDeepAR_5cc05_00000,2023-06-06_14-38-19,True,,980510d0f6764dd2aae2b1bf21d649d3,"0_cell_type=lstm,context_length=4,epochs=100,num_cells=40,num_layers=2,prediction_length=4,use_feat_dynamic_real=False,use_feat_static_cat=False,use_feat_static_real=False",econ-stat-rr01,1,417.587,172.22.1.197,3182007,1134.62,1134.62,1134.62,1686055099,0,,1,5cc05_00000,0.0589793
objectiveDeepAR_5cc05_00001,2023-06-06_14-40-20,True,,25166c9be6044d71b89126461fe1ae05,"1_cell_type=lstm,context_length=4,epochs=100,num_cells=40,num_layers=2,prediction_length=4,use_feat_dynamic_real=False,use_feat_static_cat=False,use_feat_static_real=False",econ-stat-rr01,1,363.69,172.22.1.197,3182112,1249.53,1249.53,1249.53,1686055220,0,,1,5cc05_00001,0.0061307
objectiveDeepAR_5cc05_00002,2023-06-06_14-40-04,True,,432e07b813914c1fa78bb64b1faec0bb,"2_cell_type=lstm,context_length=4,epochs=100,num_cells=40,num_layers=2,prediction_length=4,use_feat_dynamic_real=False,use_feat_static_cat=False,use_feat_static_real=False",econ-stat-rr01,1,469.489,172.22.1.197,3182114,1232.37,1232.37,1232.37,1686055204,0,,1,5cc05_00002,0.00648522
objectiveDeepAR_5cc05_00003,2023-06-06_14-39-23,True,,93b2f35ab0ac46bb873ed3d9bbe6789c,"3_cell_type=lstm,context_length=4,epochs=100,num_cells=40,num_layers=2,prediction_length=4,use_feat_dynamic_real=False,use_feat_static_cat=False,use_feat_static_real=False",econ-stat-rr01,1,387.412,172.22.1.197,3182116,1191.1,1191.1,1191.1,1686055163,0,,1,5cc05_00003,0.00709152
objectiveDeepAR_5cc05_00004,2023-06-06_14-38-13,True,,22f265f17ade4acd8cd5b3c08d8ae873,"4_cell_type=lstm,context_length=4,epochs=100,num_cells=40,num_layers=2,prediction_length=4,use_feat_dynamic_real=False,use_feat_static_cat=False,use_feat_static_real=False",econ-stat-rr01,1,437.275,172.22.1.197,3182118,1122.49,1122.49,1122.49,1686055093,0,,1,5cc05_00004,0.00750971
objectiveDeepAR_5cc05_00005,2023-06-06_14-40-18,True,,99b23eae3c354071a72888ed6565dd2d,"5_cell_type=lstm,context_length=4,epochs=100,num_cells=40,num_layers=2,prediction_length=4,use_feat_dynamic_real=False,use_feat_static_cat=False,use_feat_static_real=False",econ-stat-rr01,1,420.778,172.22.1.197,3182120,1247.71,1247.71,1247.71,1686055218,0,,1,5cc05_00005,0.00670409
objectiveDeepAR_5cc05_00006,2023-06-06_14-38-09,True,,a7ac46ea76f741fa8a5bad4af160e768,"6_cell_type=lstm,context_length=4,epochs=100,num_cells=40,num_layers=2,prediction_length=4,use_feat_dynamic_real=False,use_feat_static_cat=False,use_feat_static_real=False",econ-stat-rr01,1,416.999,172.22.1.197,3182122,1117.74,1117.74,1117.74,1686055089,0,,1,5cc05_00006,0.006217
objectiveDeepAR_5cc05_00007,2023-06-06_14-38-34,True,,e10620d37c69468ab092c2c8ee4b49d3,"7_cell_type=lstm,context_length=4,epochs=100,num_cells=40,num_layers=2,prediction_length=4,use_feat_dynamic_real=False,use_feat_static_cat=False,use_feat_static_real=False",econ-stat-rr01,1,392.366,172.22.1.197,3182124,1142.37,1142.37,1142.37,1686055114,0,,1,5cc05_00007,0.00736642
objectiveDeepAR_5cc05_00008,2023-06-06_14-57-19,True,,a7ac46ea76f741fa8a5bad4af160e768,"8_cell_type=lstm,context_length=4,epochs=100,num_cells=40,num_layers=2,prediction_length=4,use_feat_dynamic_real=False,use_feat_static_cat=False,use_feat_static_real=False",econ-stat-rr01,1,473.034,172.22.1.197,3182122,1150.35,1150.35,1150.35,1686056239,0,,1,5cc05_00008,0.006217
objectiveDeepAR_5cc05_00009,2023-06-06_14-56-46,True,,22f265f17ade4acd8cd5b3c08d8ae873,"9_cell_type=lstm,context_length=4,epochs=100,num_cells=40,num_layers=2,prediction_length=4,use_feat_dynamic_real=False,use_feat_static_cat=False,use_feat_static_real=False",econ-stat-rr01,1,470.981,172.22.1.197,3182118,1112.13,1112.13,1112.13,1686056206,0,,1,5cc05_00009,0.00750971


[2m[36m(objectiveDeepAR pid=3182118)[0m   return arr.astype(dtype, copy=True)
  0%|          | 0/50 [00:00<?, ?it/s][0m 
  0%|          | 0/50 [00:00<?, ?it/s][0m 
[2m[36m(objectiveDeepAR pid=3182007)[0m   return arr.astype(dtype, copy=True)
  0%|          | 0/50 [00:00<?, ?it/s][0m 
 72%|███████▏  | 36/50 [00:10<00:03,  3.53it/s, epoch=1/100, avg_epoch_loss=1.4]
100%|██████████| 50/50 [00:12<00:00,  3.86it/s, epoch=1/100, avg_epoch_loss=1.34]
  0%|          | 0/50 [00:00<?, ?it/s][0m 
 86%|████████▌ | 43/50 [00:10<00:01,  4.26it/s, epoch=1/100, avg_epoch_loss=1.36]
100%|██████████| 50/50 [00:11<00:00,  4.34it/s, epoch=1/100, avg_epoch_loss=1.33]
  0%|          | 0/50 [00:00<?, ?it/s][0m 
100%|██████████| 50/50 [00:09<00:00,  5.52it/s, epoch=1/100, avg_epoch_loss=1.34]
  0%|          | 0/50 [00:00<?, ?it/s][0m 
[2m[36m(objectiveDeepAR pid=3182124)[0m   return arr.astype(dtype, copy=True)
100%|██████████| 50/50 [00:08<00:00,  5.63it/s, epoch=2/100, avg_epoch_loss=1.01]
  