[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ourownstory/neural_prophet/blob/master/example_notebooks/autoregression_yosemite_temps.ipynb)

# Hyperparameter optimization with Ray Tune

We introduce a module for hyperparameter optimization with Ray Tune.

It supports automatic tuning over hyperparameter sets that we have predefined, as well as manual tuning with a user-provided configuration of the parameter spaces.

First, we show how it works with the NeuralProphet (`'NP'`) model in automatic mode.

In [1]:
# install NeuralProphet from our repository
# !pip install git+https://github.com/adasegroup/neural_prophet.git # may take a while
# !pip install tensorboardX

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pandas as pd
import numpy as np
from neuralprophet import NeuralProphet

In [3]:
from neuralprophet.hyperparameter_tuner import tune_hyperparameters




In [4]:
if 'google.colab' in str(get_ipython()):
    data_location = "https://raw.githubusercontent.com/adasegroup/neural_prophet/master/"
else:
    data_location = "../"

df = pd.read_csv(data_location + "example_data/yosemite_temps.csv")
df.head(3)

|   | ds | y |
|---|---|---|
| 0 | 2017-05-01 00:00:00 | 27.8 |
| 1 | 2017-05-01 00:05:00 | 27.0 |
| 2 | 2017-05-01 00:10:00 | 26.8 |
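The `freq='5min'` argument passed below should match the spacing of the `ds` timestamps. As a quick sanity check, here is a stdlib-only sketch that computes the gaps between the three timestamps printed above; in practice you would check the whole column, for example with pandas' `pd.infer_freq`.

```python
from datetime import datetime, timedelta

# First three `ds` values from the `df.head(3)` output above.
stamps = ["2017-05-01 00:00:00", "2017-05-01 00:05:00", "2017-05-01 00:10:00"]

# Parse the timestamps and collect the distinct gaps between consecutive rows.
parsed = [datetime.strptime(s, "%Y-%m-%d %H:%M:%S") for s in stamps]
gaps = {later - earlier for earlier, later in zip(parsed, parsed[1:])}

# A regularly sampled series has exactly one distinct gap.
assert len(gaps) == 1, f"irregular spacing: {gaps}"
step = gaps.pop()
print(step)                          # 0:05:00
print(step == timedelta(minutes=5))  # True, consistent with freq='5min'
```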


In [5]:
freq = '5min'
best_params, results_df = tune_hyperparameters('NP',
                                               df,
                                               freq)

2021-05-21 20:52:07,568	INFO services.py:1269 -- View the Ray dashboard at http://127.0.0.1:8265


KeyboardInterrupt: 

In [5]:
best_params

{'growth': 'linear',
 'n_changepoints': 10,
 'changepoints_range': 0.5,
 'trend_reg': 10.0,
 'yearly_seasonality': False,
 'weekly_seasonality': True,
 'daily_seasonality': True,
 'seasonality_mode': 'additive',
 'seasonality_reg': 0.0,
 'n_lags': 30,
 'd_hidden': 8,
 'num_hidden_layers': 2,
 'ar_sparsity': 0.8,
 'learning_rate': 0.024150905458487568,
 'loss_func': 'Huber',
 'normalize': 'auto'}

The returned dictionary can then be used to initialize a NeuralProphet model.

The function also accepts the following additional parameters:
- **num_epochs**: maximum number of epochs to train each model.
- **num_samples**: number of samples to draw from the hyperparameter spaces.
- **resources_per_trial**: resources allocated to each trial, passed on to `ray.tune.run`; for example `{'cpu': 1, 'gpu': 2}`.
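How many trials can run in parallel depends on `resources_per_trial` and on what the machine offers. The helper below is a back-of-the-envelope sketch of that arithmetic, not Ray's actual scheduler (which also accounts for other running tasks and custom resources); `max_concurrent_trials` and the `machine` dictionary are hypothetical. It shows, for instance, that `{'cpu': 1, 'gpu': 2}` cannot be scheduled at all on a machine with a single GPU.

```python
def max_concurrent_trials(available, per_trial):
    """Rough upper bound on how many trials fit on the machine at once.

    `available` and `per_trial` map resource names (e.g. 'cpu', 'gpu') to
    amounts. Fractional requests such as {'gpu': 0.5} are allowed.
    """
    limits = [
        int(available.get(resource, 0) // amount)
        for resource, amount in per_trial.items()
        if amount > 0
    ]
    # With no positive requests, this helper gives no bound (None).
    return min(limits) if limits else None


machine = {"cpu": 8, "gpu": 1}  # hypothetical machine
print(max_concurrent_trials(machine, {"cpu": 1}))              # 8
print(max_concurrent_trials(machine, {"cpu": 1, "gpu": 0.5}))  # 2
print(max_concurrent_trials(machine, {"cpu": 1, "gpu": 2}))    # 0
```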



The function also returns a dataframe with detailed results for each trial if `return_results` is set to True.
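To see how `results_df` relates to `best_params`, here is a stdlib-only sketch that picks the trial with the lowest `loss` and strips the `config.` prefix from its hyperparameter columns. It assumes the best trial is the one with the minimum final `loss`, which is consistent with the outputs in this notebook but is not stated by the source; `best_config` is a hypothetical helper, and the rows are transcribed (rounded, as displayed) from the NBeats results shown later. With pandas, the row selection would be `results_df.loc[results_df["loss"].idxmin()]`.

```python
# Rows from the NBeats `results_df` shown later in this notebook, as they
# might look after `results_df.reset_index().to_dict("records")`.
records = [
    {"trial_id": "faaa5_00000", "loss": 3.072153, "config.learning_rate": 0.000151, "config.context_length": 100},
    {"trial_id": "faaa5_00001", "loss": 3.030095, "config.learning_rate": 0.002921, "config.context_length": 30},
    {"trial_id": "faaa5_00002", "loss": 2.599046, "config.learning_rate": 0.000792, "config.context_length": 100},
]

PREFIX = "config."


def best_config(records, metric="loss"):
    """Return the hyperparameters of the trial with the lowest `metric`."""
    best = min(records, key=lambda row: row[metric])
    # Keep only the hyperparameter columns and drop the "config." prefix.
    return {key[len(PREFIX):]: value for key, value in best.items() if key.startswith(PREFIX)}


print(best_config(records))
# {'learning_rate': 0.000792, 'context_length': 100}
```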

In [7]:
results_df[['config.growth', 'config.n_changepoints', 'config.changepoints_range',
       'config.trend_reg', 'config.yearly_seasonality',
       'config.weekly_seasonality', 'config.daily_seasonality',
       'config.seasonality_mode', 'config.seasonality_reg', 'config.n_lags',
       'config.d_hidden', 'config.num_hidden_layers', 'config.ar_sparsity',
       'config.learning_rate', 'config.loss_func', 'config.normalize']].head()

| trial_id | config.growth | config.n_changepoints | config.changepoints_range | config.trend_reg | config.yearly_seasonality | config.weekly_seasonality | config.daily_seasonality | config.seasonality_mode | config.seasonality_reg | config.n_lags | config.d_hidden | config.num_hidden_layers | config.ar_sparsity | config.learning_rate | config.loss_func | config.normalize |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 8f12d_00000 | off | 10 | 0.9 | 10.0 | False | False | True | multiplicative | 1.0 | 100 | 64 | 8 | 0.8 | 0.000733 | MSE | off |
| 8f12d_00001 | linear | 100 | 0.8 | 0.5 | True | False | True | additive | 0.0 | 10 | 128 | 16 | 0.8 | 0.01959 | MSE | soft |
| 8f12d_00002 | linear | 10 | 0.5 | 0.0 | True | False | False | multiplicative | 0.0 | 30 | 128 | 2 | 0.1 | 0.000306 | MSE | standardize |
| 8f12d_00003 | linear | 10 | 0.5 | 10.0 | True | True | True | multiplicative | 0.5 | 30 | 8 | 8 | 0.8 | 0.000195 | Huber | off |
| 8f12d_00004 | linear | 10 | 0.5 | 0.5 | True | True | True | additive | 0.5 | 10 | 64 | 2 | 0.8 | 0.008229 | MSE | soft |


## Manual mode

In manual mode, the user must provide a `config` dictionary with hyperparameter search spaces compatible with the Ray Tune API.

We provide a minimal example below. For more information on search spaces, see https://docs.ray.io/en/master/tune/api_docs/search_space.html?highlight=tune.choice

In [None]:
from ray import tune

config = {'n_lags': tune.grid_search([10, 20, 30]),
          'learning_rate': tune.loguniform(1e-4, 1e-1),
          'num_hidden_layers': tune.choice([2, 8, 16])}
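In Ray Tune, values wrapped in `tune.grid_search` are enumerated exhaustively, while distributions such as `tune.loguniform` and `tune.choice` are sampled once per trial, and `num_samples` repeats the whole grid. The stdlib sketch below mimics that behavior for a copy of the `config` above to show how many trials it produces. The helpers `grid_search`, `loguniform`, `choice`, `sample` and `expand_trials` are hypothetical stand-ins, not Ray's implementation; the copy is named `sketch_config` so it does not overwrite the real `config` in the notebook.

```python
import itertools
import math
import random


# Hypothetical stand-ins: each records the kind of search space and its arguments.
def grid_search(values):
    return ("grid", list(values))


def loguniform(low, high):
    return ("loguniform", (low, high))


def choice(values):
    return ("choice", list(values))


def sample(spec, rng):
    """Draw one value from a non-grid search space."""
    kind, arg = spec
    if kind == "choice":
        return rng.choice(arg)
    if kind == "loguniform":
        low, high = arg
        # Uniform in log space, so each decade is equally likely.
        return math.exp(rng.uniform(math.log(low), math.log(high)))
    raise ValueError(f"unsupported search space: {kind}")


def expand_trials(config, num_samples=1, seed=0):
    """Enumerate every grid point num_samples times, sampling the rest per trial."""
    rng = random.Random(seed)
    grid_keys = [key for key, (kind, _) in config.items() if kind == "grid"]
    grid_values = [config[key][1] for key in grid_keys]
    trials = []
    for _ in range(num_samples):
        for combo in itertools.product(*grid_values):
            trial = dict(zip(grid_keys, combo))
            for key, spec in config.items():
                if spec[0] != "grid":
                    trial[key] = sample(spec, rng)
            trials.append(trial)
    return trials


sketch_config = {'n_lags': grid_search([10, 20, 30]),
                 'learning_rate': loguniform(1e-4, 1e-1),
                 'num_hidden_layers': choice([2, 8, 16])}

trials = expand_trials(sketch_config, num_samples=2)
print(len(trials))  # 6: three grid points, each repeated twice
for trial in trials:
    print(trial)
```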

In [None]:
freq = '5min'
best_params, results_df = tune_hyperparameters('NP',
                                               df,
                                               freq,
                                               mode='manual',
                                               config=config)

In [None]:
results_df

In [None]:
best_params

In [8]:
freq = '5min'
best_params, results_df = tune_hyperparameters('LSTM',
                                               df,
                                               freq)

(pid=28975) INFO - (NP.config.set_auto_batch_epoch) - Auto-set batch_size to 64
  0%|          | 0/100 [00:00<?, ?it/s]GPU available: False, used: False
(pid=28975) TPU available: False, using: 0 TPU cores
(pid=28975) GPU available: False, used: False
(pid=28975) TPU available: False, using: 0 TPU cores
(pid=28975)
(pid=28975)   | Name      | Type         | Params
(pid=28975) -------------------------------------------
(pid=28975) 0 | lstm      | LSTM         | 6.1 M
(pid=28975) 1 | linear    | Linear       | 257
(pid=28975) 2 | loss_func | SmoothL1Loss | 0
(pid=28975) -------------------------------------------
(pid=28975) 6.1 M     Trainable params
(pid=28975) 0         Non-trainable params
(pid=28975) 6.1 M     Total params
(pid=28975) 24.241    Total estimated model params size (M

  0%|          | 0/100 [00:00<?, ?it/s]GPU available: False, used: False
(pid=28974) TPU available: False, using: 0 TPU cores
(pid=28974) GPU available: False, used: False
(pid=28974) TPU available: False, using: 0 TPU cores
(pid=28974)
(pid=28974)   | Name      | Type         | Params
(pid=28974) -------------------------------------------
(pid=28974) 0 | lstm      | LSTM         | 30.0 K
(pid=28974) 1 | linear    | Linear       | 17
(pid=28974) 2 | loss_func | SmoothL1Loss | 0
(pid=28974) -------------------------------------------
(pid=28974) 30.0 K    Trainable params
(pid=28974) 0         Non-trainable params
(pid=28974) 30.0 K    Total params
(pid=28974) 0.120     Total estimated model params size (MB)
(pid=28974)
(pid=28974)
(pid=28971) INFO - (NP.co

(pid=29042) INFO - (NP.config.set_auto_batch_epoch) - Auto-set batch_size to 64
  0%|          | 0/100 [00:00<?, ?it/s]GPU available: False, used: False
(pid=29042) TPU available: False, using: 0 TPU cores
(pid=29042) GPU available: False, used: False
(pid=29042) TPU available: False, using: 0 TPU cores
(pid=29042)
(pid=29042)   | Name      | Type         | Params
(pid=29042) -------------------------------------------
(pid=29042) 0 | lstm      | LSTM         | 275 K
(pid=29042) 1 | linear    | Linear       | 65
(pid=29042) 2 | loss_func | SmoothL1Loss | 0
(pid=29042) -------------------------------------------
(pid=29042) 275 K     Trainable params
(pid=29042) 0         Non-trainable params
(pid=29042) 275 K     Total params
(pid=29042) 1.102     Total estimated model params size (M

(pid=29108) INFO - (NP.config.set_auto_batch_epoch) - Auto-set batch_size to 64
  0%|          | 0/100 [00:00<?, ?it/s]GPU available: False, used: False
(pid=29108) TPU available: False, using: 0 TPU cores
(pid=29108) GPU available: False, used: False
(pid=29108) TPU available: False, using: 0 TPU cores
(pid=29108)
(pid=29108)   | Name      | Type         | Params
(pid=29108) -------------------------------------------
(pid=29108) 0 | lstm      | LSTM         | 6.2 M
(pid=29108) 1 | linear    | Linear       | 257
(pid=29108) 2 | loss_func | SmoothL1Loss | 0
(pid=29108) -------------------------------------------
(pid=29108) 6.2 M     Trainable params
(pid=29108) 0         Non-trainable params
(pid=29108) 6.2 M     Total params
(pid=29108) 24.659    Total estimated model params size (M

2021-05-21 14:48:23,922	ERROR tune.py:545 -- Trials did not complete: [train_LSTM_tune_ddb98_00000, train_LSTM_tune_ddb98_00004, train_LSTM_tune_ddb98_00009, train_LSTM_tune_ddb98_00011, train_LSTM_tune_ddb98_00013, train_LSTM_tune_ddb98_00014, train_LSTM_tune_ddb98_00015, train_LSTM_tune_ddb98_00016, train_LSTM_tune_ddb98_00017, train_LSTM_tune_ddb98_00018, train_LSTM_tune_ddb98_00019, train_LSTM_tune_ddb98_00020, train_LSTM_tune_ddb98_00021, train_LSTM_tune_ddb98_00022, train_LSTM_tune_ddb98_00023, train_LSTM_tune_ddb98_00024, train_LSTM_tune_ddb98_00025, train_LSTM_tune_ddb98_00026, train_LSTM_tune_ddb98_00027, train_LSTM_tune_ddb98_00028, train_LSTM_tune_ddb98_00029, train_LSTM_tune_ddb98_00030, train_LSTM_tune_ddb98_00031, train_LSTM_tune_ddb98_00032, train_LSTM_tune_ddb98_00033, train_LSTM_tune_ddb98_00034, train_LSTM_tune_ddb98_00035, train_LSTM_tune_ddb98_00036, train_LSTM_tune_ddb98_00037, train_LSTM_tune_ddb98_00038, train_LSTM_tune_ddb98_00039]


In [9]:
best_params

{'learning_rate': 0.013718466327407178,
 'd_hidden': 128,
 'n_lags': 100,
 'num_hidden_layers': 2,
 'lstm_bias': True,
 'lstm_bidirectional': False}

In [10]:
results_df.head()

| trial_id | loss | time_this_iter_s | done | timesteps_total | episodes_total | training_iteration | experiment_id | date | timestamp | time_total_s | ... | time_since_restore | timesteps_since_restore | iterations_since_restore | experiment_tag | config.learning_rate | config.d_hidden | config.n_lags | config.num_hidden_layers | config.lstm_bias | config.lstm_bidirectional |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| ddb98_00000 | 0.04784 | 184.533827 | False |  |  | 8.0 | a26094effe6246d5af6b147ae60a5933 | 2021-05-21_14-47-36 | 1621598000.0 | 1486.938134 | ... | 1486.938134 | 0.0 | 8.0 | 0_d_hidden=128,learning_rate=0.000337,lstm_bia... | 0.000337 | 128.0 | 30.0 | 16.0 | False | True |
| ddb98_00001 | 0.051339 | 48.592918 | True |  |  | 20.0 | 4ebbb108e09e4b2986c3835634ed0295 | 2021-05-21_14-38-22 | 1621597000.0 | 932.583967 | ... | 932.583967 | 0.0 | 20.0 | 1_d_hidden=8,learning_rate=0.0097144,lstm_bias... | 0.009714 | 8.0 | 100.0 | 8.0 | False | True |
| ddb98_00002 | 0.255517 | 28.091819 | True |  |  | 10.0 | 4be3033d37004d4f98c47a2d0d564c11 | 2021-05-21_14-27-36 | 1621596000.0 | 285.940451 | ... | 285.940451 | 0.0 | 10.0 | 2_d_hidden=64,learning_rate=0.00010157,lstm_bi... | 0.000102 | 64.0 | 100.0 | 8.0 | False | False |
| ddb98_00003 | 2.9e-05 | 11.933487 | True |  |  | 100.0 | 44266641f7124ffda7d62dbdab2b0b49 | 2021-05-21_14-43-17 | 1621597000.0 | 1227.605633 | ... | 1227.605633 | 0.0 | 100.0 | 3_d_hidden=128,learning_rate=0.013718,lstm_bia... | 0.013718 | 128.0 | 100.0 | 2.0 | True | False |
| ddb98_00004 | 0.047503 | 91.865763 | False |  |  | 16.0 | f6f09c2bc638437498bd2d01fd44961d | 2021-05-21_14-47-51 | 1621598000.0 | 1501.235601 | ... | 1501.235601 | 0.0 | 16.0 | 4_d_hidden=8,learning_rate=0.0013775,lstm_bias... | 0.001378 | 8.0 | 100.0 | 16.0 | False | True |


In [5]:
freq = '5min'
best_params, results_df = tune_hyperparameters('NBeats',
                                               df,
                                               freq,
                                               num_epochs=5,
                                               num_samples=3)

2021-05-21 22:44:37,062	INFO services.py:1269 -- View the Ray dashboard at http://127.0.0.1:8265
(pid=39358) GPU available: False, used: False
(pid=39358) TPU available: False, using: 0 TPU cores
(pid=39362) GPU available: False, used: False
(pid=39362) TPU available: False, using: 0 TPU cores
(pid=39359) GPU available: False, used: False
(pid=39359) TPU available: False, using: 0 TPU cores
(pid=39362) GPU available: False, used: False
(pid=39362) TPU available: False, using: 0 TPU cores
(pid=39359) GPU available: False, used: False
(pid=39359) TPU available: False, using: 0 TPU cores
(pid=39358) GPU available: False, used: False
(pid=39358) TPU available: False, using: 0 TPU cores
(pid=39362)
(pid=39362)   | Name            | Type       | Params
(pid=39362) ----------------

In [6]:
best_params

{'learning_rate': 0.0007916629501075156, 'context_length': 100}

In [7]:
results_df

| trial_id | loss | time_this_iter_s | done | timesteps_total | episodes_total | training_iteration | experiment_id | date | timestamp | time_total_s | pid | hostname | node_ip | time_since_restore | timesteps_since_restore | iterations_since_restore | experiment_tag | config.learning_rate | config.context_length |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| faaa5_00000 | 3.072153 | 18.586122 | True |  |  | 5 | 7b67c8b9c6eb44aca2845bc8e56a4ebb | 2021-05-21_22-46-19 | 1621626379 | 97.502764 | 39359 | MacBook-Pro-Polina.local | 192.168.0.7 | 97.502764 | 0 | 5 | 0_context_length=100,learning_rate=0.00015071 | 0.000151 | 100 |
| faaa5_00001 | 3.030095 | 17.945912 | True |  |  | 5 | 1656dc1c435f4cc883e27b4c35c1d52a | 2021-05-21_22-46-14 | 1621626374 | 92.76503 | 39362 | MacBook-Pro-Polina.local | 192.168.0.7 | 92.76503 | 0 | 5 | 1_context_length=30,learning_rate=0.0029214 | 0.002921 | 30 |
| faaa5_00002 | 2.599046 | 18.865901 | True |  |  | 5 | 3fb214eaabee41b5b0fba90891813774 | 2021-05-21_22-46-19 | 1621626379 | 97.607086 | 39358 | MacBook-Pro-Polina.local | 192.168.0.7 | 97.607086 | 0 | 5 | 2_context_length=100,learning_rate=0.00079166 | 0.000792 | 100 |


In [7]:
best_params

{'learning_rate': 0.0601654074086814,
 'context_length': 10,
 'hidden_size': 16,
 'attention_head_size': 1}

In [8]:
results_df

| trial_id | loss | time_this_iter_s | done | timesteps_total | episodes_total | training_iteration | experiment_id | date | timestamp | time_total_s | ... | hostname | node_ip | time_since_restore | timesteps_since_restore | iterations_since_restore | experiment_tag | config.learning_rate | config.context_length | config.hidden_size | config.attention_head_size |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 217f9_00000 | 0.585543 | 12.794128 | True |  |  | 10 | 374ac13a0d874cd78833bb53a8fd825b | 2021-05-21_15-24-11 | 1621599851 | 992.677625 | ... | MacBook-Pro-Polina.local | 192.168.0.7 | 992.677625 | 0 | 10 | 0_attention_head_size=2,context_length=10,hidd... | 0.000435 | 10 | 16 | 2 |
| 217f9_00001 | 0.335199 | 11.922237 | True |  |  | 100 | 91d1cf0e9d574b45b972a8df7e14a0bc | 2021-05-21_15-39-29 | 1621600769 | 1911.347692 | ... | MacBook-Pro-Polina.local | 192.168.0.7 | 1911.347692 | 0 | 100 | 1_attention_head_size=1,context_length=10,hidd... | 0.060165 | 10 | 16 | 1 |
| 217f9_00002 | 0.643699 | 22.244758 | True |  |  | 10 | e9709940e5de48da840b8de46ebc4544 | 2021-05-21_15-22-29 | 1621599749 | 890.869505 | ... | MacBook-Pro-Polina.local | 192.168.0.7 | 890.869505 | 0 | 10 | 2_attention_head_size=2,context_length=10,hidd... | 0.031605 | 10 | 32 | 2 |

