In [None]:
# --- Standard-library and third-party imports (do not depend on the working directory) ---
import os
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from ray import tune
from ray.air import session

# Repository root. The original notebook hard-coded this absolute path; it is kept as the
# default so existing runs behave the same, but it can now be overridden on another
# machine by setting the DEEPAR_PROJECT_ROOT environment variable before starting the kernel.
PROJECT_ROOT = os.environ.get("DEEPAR_PROJECT_ROOT", '/home/reffert/DeepAR_InfluenzaForecast')

# The working directory is switched *before* the project-local imports below.
# NOTE(review): presumably this is what makes the `PythonFiles` package importable and
# lets its data loaders use relative paths -- confirm against the repository layout.
os.chdir(PROJECT_ROOT)

# --- Project-local imports (must stay after the chdir above) ---
import PythonFiles.model
from PythonFiles.Configuration import Configuration
from PythonFiles.HpTuning import get_data, objectiveDeepAR

# Only the first element of get_data's result is used here; later cells index it by
# split name (e.g. "with_features_2001").
data_splits_dict = get_data(truncate=False, with_features=True)[0]

In [None]:
# Search space for the DeepAR objective. Every entry except "cardinality" is a
# grid_search, so Tune enumerates the full Cartesian product:
# 3 (num_cells) * 2 (num_layers) * 3 (context_length) * 3 (epochs) * 2 (use_feat_static_cat)
# = 108 distinct configurations.
hp_search_space = {
    "num_cells": tune.grid_search([10, 60, 140]),
    "num_layers": tune.grid_search([6, 12]),
    "context_length": tune.grid_search([1, 2, 4]),
    "cell_type": tune.grid_search(["lstm"]),
    "epochs": tune.grid_search([90, 140, 200]),
    "use_feat_static_real": tune.grid_search([False]),
    "use_feat_dynamic_real": tune.grid_search([True]),
    "use_feat_static_cat": tune.grid_search([False, True]),
    # "cardinality" is derived from the sampled config: it is only set when static
    # categorical features are enabled, otherwise None.
    # NOTE(review): 411 is presumably the number of static categorical features
    # (each with 2 categories) -- confirm against the dataset.
    "cardinality": tune.sample_from(
        lambda spec: [2] * 411 if spec.config.use_feat_static_cat else None
    ),
}

# Train/test pair of the "with_features_2001" split (looked up once instead of twice).
selected_split = data_splits_dict["with_features_2001"]
train = selected_split[0]
test = selected_split[1]
configuration = Configuration()

tuner = tune.Tuner(
    # with_parameters binds the datasets and configuration to the trainable so they
    # are passed to objectiveDeepAR alongside each sampled config.
    tune.with_parameters(objectiveDeepAR, train=train, test=test, configuration=configuration),
    tune_config=tune.TuneConfig(
        # With grid_search entries the whole grid is repeated num_samples times,
        # i.e. each of the 108 configurations is run 5 times.
        num_samples=5,
        metric="mean_WIS",
        mode="min",
        max_concurrent_trials=14,
    ),
    param_space=hp_search_space,
)
results = tuner.fit()

# Persist the per-trial results BEFORE querying the best result: get_best_result()
# raises if no trial reported the metric, and in the original ordering that exception
# would have prevented the CSV from ever being written after a long sweep.
results_df = results.get_dataframe()
results_df.to_csv("Hyperparameter_results_22_05.csv")

# Printed output keeps the original order: best config first, then the full table.
print("Best hyperparameters found were: ", results.get_best_result().config)
print(results_df)

2023-05-24 16:54:35,641	INFO worker.py:1553 -- Started a local Ray instance.


0,1
Current time:,2023-05-24 17:16:25
Running for:,00:21:33.62
Memory:,36.9/236.0 GiB

Trial name,status,loc,cardinality,cell_type,context_length,epochs,num_cells,num_layers,use_feat_dynamic_real,use_feat_static_cat,use_feat_static_real,iter,total time (s),mean_WIS
objectiveDeepAR_e9419_00006,RUNNING,172.22.1.197:3146167,,lstm,1,200,10,6,True,False,False,,,
objectiveDeepAR_e9419_00007,RUNNING,172.22.1.197:3146169,,lstm,2,200,10,6,True,False,False,,,
objectiveDeepAR_e9419_00008,RUNNING,172.22.1.197:3146171,,lstm,4,200,10,6,True,False,False,,,
objectiveDeepAR_e9419_00009,RUNNING,172.22.1.197:3146173,,lstm,1,90,60,6,True,False,False,,,
objectiveDeepAR_e9419_00010,RUNNING,172.22.1.197:3146175,,lstm,2,90,60,6,True,False,False,,,
objectiveDeepAR_e9419_00011,RUNNING,172.22.1.197:3146177,,lstm,4,90,60,6,True,False,False,,,
objectiveDeepAR_e9419_00012,RUNNING,172.22.1.197:3146179,,lstm,1,140,60,6,True,False,False,,,
objectiveDeepAR_e9419_00013,RUNNING,172.22.1.197:3146181,,lstm,2,140,60,6,True,False,False,,,
objectiveDeepAR_e9419_00014,RUNNING,172.22.1.197:3146065,,lstm,4,140,60,6,True,False,False,,,
objectiveDeepAR_e9419_00015,RUNNING,172.22.1.197:3146157,,lstm,1,200,60,6,True,False,False,,,


  0%|          | 0/50 [00:00<?, ?it/s][0m 
  0%|          | 0/50 [00:00<?, ?it/s][0m 
  0%|          | 0/50 [00:00<?, ?it/s][0m 
  0%|          | 0/50 [00:00<?, ?it/s][0m 
  0%|          | 0/50 [00:00<?, ?it/s][0m 
  0%|          | 0/50 [00:00<?, ?it/s][0m 
  0%|          | 0/50 [00:00<?, ?it/s][0m 
  0%|          | 0/50 [00:00<?, ?it/s][0m 
  0%|          | 0/50 [00:00<?, ?it/s][0m 
  0%|          | 0/50 [00:00<?, ?it/s][0m 
  0%|          | 0/50 [00:00<?, ?it/s][0m 
  0%|          | 0/50 [00:00<?, ?it/s][0m 
  0%|          | 0/50 [00:00<?, ?it/s][0m 
  0%|          | 0/50 [00:00<?, ?it/s][0m 
100%|██████████| 50/50 [00:05<00:00,  9.17it/s, epoch=1/90, avg_epoch_loss=1.12]
  0%|          | 0/50 [00:00<?, ?it/s][0m 
100%|██████████| 50/50 [00:05<00:00,  9.53it/s, epoch=1/200, avg_epoch_loss=1.21]
  0%|          | 0/50 [00:00<?, ?it/s][0m 
100%|██████████| 50/50 [00:05<00:00,  8.54it/s, epoch=1/90, avg_epoch_loss=1.32]
  0%|          | 0/50 [00:00<?, ?it/s][0m 
100%|███

Trial name,date,done,episodes_total,experiment_id,experiment_tag,hostname,iterations_since_restore,mean_WIS,node_ip,pid,time_since_restore,time_this_iter_s,time_total_s,timestamp,timesteps_since_restore,timesteps_total,training_iteration,trial_id,warmup_time
objectiveDeepAR_e9419_00000,2023-05-24_17-07-00,True,,bf9ed8208365497f986e3293c850c06d,"0_cardinality=None,cell_type=lstm,context_length=1,epochs=90,num_cells=10,num_layers=6,use_feat_dynamic_real=True,use_feat_static_cat=False,use_feat_static_real=False",econ-stat-rr01,1,400.451,172.22.1.197,3146065,723.897,723.897,723.897,1684940820,0,,1,e9419_00000,0.0075767
objectiveDeepAR_e9419_00001,2023-05-24_17-07-44,True,,6004e28ac30b4492b9eacb5276e1079d,"1_cardinality=None,cell_type=lstm,context_length=2,epochs=90,num_cells=10,num_layers=6,use_feat_dynamic_real=True,use_feat_static_cat=False,use_feat_static_real=False",econ-stat-rr01,1,447.769,172.22.1.197,3146157,763.521,763.521,763.521,1684940864,0,,1,e9419_00001,0.00471568
objectiveDeepAR_e9419_00002,2023-05-24_17-08-47,True,,0e0836e698a646b0a1aa1c5b25b5c532,"2_cardinality=None,cell_type=lstm,context_length=4,epochs=90,num_cells=10,num_layers=6,use_feat_dynamic_real=True,use_feat_static_cat=False,use_feat_static_real=False",econ-stat-rr01,1,483.145,172.22.1.197,3146159,826.083,826.083,826.083,1684940927,0,,1,e9419_00002,0.00535917
objectiveDeepAR_e9419_00003,2023-05-24_17-11-03,True,,5f4275ee7ab5421387645ca179d65b32,"3_cardinality=None,cell_type=lstm,context_length=1,epochs=140,num_cells=10,num_layers=6,use_feat_dynamic_real=True,use_feat_static_cat=False,use_feat_static_real=False",econ-stat-rr01,1,411.017,172.22.1.197,3146161,962.12,962.12,962.12,1684941063,0,,1,e9419_00003,0.00534058
objectiveDeepAR_e9419_00004,2023-05-24_17-12-04,True,,48e2796157d7420284390f932990d912,"4_cardinality=None,cell_type=lstm,context_length=2,epochs=140,num_cells=10,num_layers=6,use_feat_dynamic_real=True,use_feat_static_cat=False,use_feat_static_real=False",econ-stat-rr01,1,407.614,172.22.1.197,3146163,1022.97,1022.97,1022.97,1684941124,0,,1,e9419_00004,0.00523853
objectiveDeepAR_e9419_00005,2023-05-24_17-13-45,True,,25da3e875fee4bd8b7ea2fab798c309e,"5_cardinality=None,cell_type=lstm,context_length=4,epochs=140,num_cells=10,num_layers=6,use_feat_dynamic_real=True,use_feat_static_cat=False,use_feat_static_real=False",econ-stat-rr01,1,471.67,172.22.1.197,3146165,1123.81,1123.81,1123.81,1684941225,0,,1,e9419_00005,0.0051682


100%|██████████| 50/50 [00:07<00:00,  7.00it/s, epoch=122/200, avg_epoch_loss=0.642]
  0%|          | 0/50 [00:00<?, ?it/s][0m 
  0%|          | 0/50 [00:00<?, ?it/s][0m 
100%|██████████| 50/50 [00:06<00:00,  7.58it/s, epoch=139/200, avg_epoch_loss=0.685]
  0%|          | 0/50 [00:00<?, ?it/s][0m 
100%|██████████| 50/50 [00:09<00:00,  5.56it/s, epoch=101/140, avg_epoch_loss=0.567]
  0%|          | 0/50 [00:00<?, ?it/s][0m 
100%|██████████| 50/50 [00:05<00:00,  9.04it/s, epoch=127/140, avg_epoch_loss=0.643]
  0%|          | 0/50 [00:00<?, ?it/s][0m 
100%|██████████| 50/50 [00:05<00:00,  9.41it/s, epoch=140/140, avg_epoch_loss=0.65]
100%|██████████| 50/50 [00:05<00:00,  9.85it/s, epoch=145/200, avg_epoch_loss=0.648]
  0%|          | 0/50 [00:00<?, ?it/s][0m 
100%|██████████| 50/50 [00:06<00:00,  8.20it/s, epoch=110/140, avg_epoch_loss=0.599]
  0%|          | 0/50 [00:00<?, ?it/s][0m 
100%|██████████| 50/50 [00:05<00:00,  8.97it/s, epoch=123/200, avg_epoch_loss=0.647]
  0%|        