In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from datetime import datetime
import os

# Root of the project that contains the `PythonFiles` package. It can be
# overridden via the DEEPAR_PROJECT_ROOT environment variable so the notebook
# runs on other machines; the default is the original author's checkout.
PROJECT_ROOT = os.environ.get(
    "DEEPAR_PROJECT_ROOT", "/home/reffert/DeepAR_InfluenzaForecast"
)
# Must run before the `PythonFiles` imports below: they appear to rely on the
# working directory being importable (NOTE(review): confirm no package install).
# It also determines where relative output paths (e.g. the results CSV written
# by the tuning cell) end up.
os.chdir(PROJECT_ROOT)

import PythonFiles.model
from PythonFiles.Configuration import Configuration
from PythonFiles.HpTuning import get_data, objectiveDeepAR
from ray.air import session
from ray import tune

# Only the first element of get_data's return value is used: presumably a dict
# of named (train, test, ...) splits, indexed by key in the tuning cell.
data_splits_dict = get_data(truncate=False, with_features=True)[0]

In [None]:
# --- Sweep configuration -----------------------------------------------------
DATASET_KEY = "with_features_2001"                    # split used for tuning
RESULTS_CSV = "Hyperparameter_results_03_06.csv"      # relative to PROJECT cwd
NUM_SAMPLES = 5
MAX_CONCURRENT_TRIALS = 24
# Length of the `cardinality` list passed when static categorical features are
# enabled (each entry = 2 categories). NOTE(review): 411 presumably equals the
# number of series/regions in the dataset — confirm against get_data.
N_STATIC_CAT_FEATURES = 411

# Ray Tune takes the cartesian product of all grid_search values:
# 2 (num_cells) x 3 (num_layers) x 3 (context_length) x 3 (epochs) = 54 configs.
# num_samples then repeats the *whole grid*, so 54 * NUM_SAMPLES (= 270) trials run.
hp_search_space = {
    "num_cells": tune.grid_search([80, 140]),
    "num_layers": tune.grid_search([6, 8, 10]),
    "context_length": tune.grid_search([1, 2, 104]),
    "cell_type": tune.grid_search(["lstm"]),
    "epochs": tune.grid_search([140, 160, 200]),
    "use_feat_static_real": tune.grid_search([False]),
    "use_feat_dynamic_real": tune.grid_search([True]),
    "use_feat_static_cat": tune.grid_search([False]),
    # Resolved after the grid values, so it can read the chosen
    # use_feat_static_cat; None when static categorical features are off.
    "cardinality": tune.sample_from(
        lambda spec: [2] * N_STATIC_CAT_FEATURES
        if spec.config.use_feat_static_cat
        else None
    ),
}

train = data_splits_dict[DATASET_KEY][0]
test = data_splits_dict[DATASET_KEY][1]
configuration = Configuration()

tuner = tune.Tuner(
    tune.with_parameters(objectiveDeepAR, train=train, test=test, configuration=configuration),
    tune_config=tune.TuneConfig(
        num_samples=NUM_SAMPLES,
        metric="mean_WIS",
        mode="min",
        max_concurrent_trials=MAX_CONCURRENT_TRIALS,
    ),
    param_space=hp_search_space,
)
results = tuner.fit()

# Persist everything first: get_best_result() below raises if no trial
# reported the metric, and the sweep is too expensive to lose.
results_df = results.get_dataframe()
results_df.to_csv(RESULTS_CSV)

# Tuner.fit() does not raise when individual trials fail; failures are only
# recorded on the ResultGrid. Surface them so a partial sweep is not mistaken
# for a complete one.
if results.num_errors:
    print(f"WARNING: {results.num_errors} of {len(results)} trials failed:")
    for err in results.errors:
        print("  ", repr(err))

print("Best hyperparameters found were: ", results.get_best_result().config)
print(results_df)

2023-06-03 10:20:49,400	INFO worker.py:1553 -- Started a local Ray instance.


0,1
Current time:,2023-06-03 10:51:00
Running for:,00:29:57.42
Memory:,108.5/236.0 GiB

Trial name,status,loc,cardinality,cell_type,context_length,epochs,num_cells,num_layers,use_feat_dynamic_real,use_feat_static_cat,use_feat_static_real
objectiveDeepAR_8f20c_00000,RUNNING,172.22.1.197:1959393,,lstm,1,140,80,6,True,False,False
objectiveDeepAR_8f20c_00001,RUNNING,172.22.1.197:1959511,,lstm,2,140,80,6,True,False,False
objectiveDeepAR_8f20c_00002,RUNNING,172.22.1.197:1959513,,lstm,104,140,80,6,True,False,False
objectiveDeepAR_8f20c_00003,RUNNING,172.22.1.197:1959515,,lstm,1,160,80,6,True,False,False
objectiveDeepAR_8f20c_00004,RUNNING,172.22.1.197:1959517,,lstm,2,160,80,6,True,False,False
objectiveDeepAR_8f20c_00005,RUNNING,172.22.1.197:1959519,,lstm,104,160,80,6,True,False,False
objectiveDeepAR_8f20c_00006,RUNNING,172.22.1.197:1959521,,lstm,1,200,80,6,True,False,False
objectiveDeepAR_8f20c_00007,RUNNING,172.22.1.197:1959523,,lstm,2,200,80,6,True,False,False
objectiveDeepAR_8f20c_00008,RUNNING,172.22.1.197:1959525,,lstm,104,200,80,6,True,False,False
objectiveDeepAR_8f20c_00009,RUNNING,172.22.1.197:1959527,,lstm,1,140,140,6,True,False,False


  0%|          | 0/50 [00:00<?, ?it/s][0m 
  0%|          | 0/50 [00:00<?, ?it/s][0m 
  0%|          | 0/50 [00:00<?, ?it/s][0m 
  0%|          | 0/50 [00:00<?, ?it/s][0m 
  0%|          | 0/50 [00:00<?, ?it/s][0m 
  0%|          | 0/50 [00:00<?, ?it/s][0m 
  0%|          | 0/50 [00:00<?, ?it/s][0m 
  0%|          | 0/50 [00:00<?, ?it/s][0m 
  0%|          | 0/50 [00:00<?, ?it/s][0m 
  0%|          | 0/50 [00:00<?, ?it/s][0m 
  0%|          | 0/50 [00:00<?, ?it/s][0m 
  0%|          | 0/50 [00:00<?, ?it/s][0m 
  0%|          | 0/50 [00:00<?, ?it/s][0m 
  0%|          | 0/50 [00:00<?, ?it/s][0m 
  0%|          | 0/50 [00:00<?, ?it/s][0m 
  0%|          | 0/50 [00:00<?, ?it/s][0m 
  0%|          | 0/50 [00:00<?, ?it/s][0m 
  0%|          | 0/50 [00:00<?, ?it/s][0m 
  0%|          | 0/50 [00:00<?, ?it/s][0m 
  0%|          | 0/50 [00:00<?, ?it/s][0m 
  0%|          | 0/50 [00:00<?, ?it/s][0m 
  0%|          | 0/50 [00:00<?, ?it/s][0m 
  0%|          | 0/50 [00:00<?, 