In [5]:
import json
import numpy as np

from sklearn.model_selection import ParameterSampler
from scipy.stats.distributions import lognorm

In [4]:
# rng = np.random.RandomState(0)

In [9]:
# Discrete search space for the random hyper-parameter search: each key maps
# to the list of candidate values that may be drawn for that setting.
parameter_grid = {
    # Epochs are intentionally not searched, since early stopping can be
    # used instead:
    #     'epochs': [25, 50, 100, 200],
    'batch_size': [128, 256, 512, 768, 1024],
    # TODO: can this be optimized with a learning rate scheduler?
    'learning_rate': [1e-5, 1e-4, 1e-3, 1e-2, 1e-1],
    'latent_dim': [128, 256, 320, 512, 768, 1024],
    'embedding_dim': [32, 50, 64, 96, 128],
    'vocabulary_size': [3000, 4000, 5000, 6000, 7000],
    'max_input_seq_length': [64, 96, 128, 160, 192],
    'max_output_seq_length': [5, 6, 7, 8],
    # Disabled experiment with a continuous learning-rate distribution:
    #     'learning_rate_distro': lognorm([0.01], loc=-0.8),
    # TODO: can we sample from a continuous distribution?
    'dropout_rate': [0.03, 0.05, 0.1, 0.2],
    # Not supported yet:
    #     'bi_lstm': [True, False],
}

In [10]:
# Draw 16 hyper-parameter combinations at random from `parameter_grid`.
# A fixed `random_state` makes the draw repeatable, so "Restart Kernel ->
# Run All" yields the same candidate list every time instead of a new one.
random_search_candidates = list(
    ParameterSampler(parameter_grid, n_iter=16, random_state=0)
)
# Bare last expression so the notebook displays the sampled candidates.
random_search_candidates

[{'vocabulary_size': 4000,
  'max_output_seq_length': 5,
  'max_input_seq_length': 192,
  'learning_rate': 0.1,
  'latent_dim': 320,
  'embedding_dim': 64,
  'dropout_rate': 0.03,
  'batch_size': 512},
 {'vocabulary_size': 6000,
  'max_output_seq_length': 5,
  'max_input_seq_length': 160,
  'learning_rate': 0.1,
  'latent_dim': 1024,
  'embedding_dim': 96,
  'dropout_rate': 0.03,
  'batch_size': 128},
 {'vocabulary_size': 6000,
  'max_output_seq_length': 8,
  'max_input_seq_length': 96,
  'learning_rate': 0.001,
  'latent_dim': 256,
  'embedding_dim': 64,
  'dropout_rate': 0.03,
  'batch_size': 1024},
 {'vocabulary_size': 5000,
  'max_output_seq_length': 5,
  'max_input_seq_length': 192,
  'learning_rate': 0.01,
  'latent_dim': 1024,
  'embedding_dim': 32,
  'dropout_rate': 0.1,
  'batch_size': 512},
 {'vocabulary_size': 4000,
  'max_output_seq_length': 8,
  'max_input_seq_length': 192,
  'learning_rate': 0.01,
  'latent_dim': 320,
  'embedding_dim': 32,
  'dropout_rate': 0.2,
  'batch

In [11]:
# Serialise the sampled candidates as indented JSON and print the text so it
# can be copied out of the notebook.
candidates_json = json.dumps(random_search_candidates, indent=2)
print(candidates_json)

[
  {
    "vocabulary_size": 4000,
    "max_output_seq_length": 5,
    "max_input_seq_length": 192,
    "learning_rate": 0.1,
    "latent_dim": 320,
    "embedding_dim": 64,
    "dropout_rate": 0.03,
    "batch_size": 512
  },
  {
    "vocabulary_size": 6000,
    "max_output_seq_length": 5,
    "max_input_seq_length": 160,
    "learning_rate": 0.1,
    "latent_dim": 1024,
    "embedding_dim": 96,
    "dropout_rate": 0.03,
    "batch_size": 128
  },
  {
    "vocabulary_size": 6000,
    "max_output_seq_length": 8,
    "max_input_seq_length": 96,
    "learning_rate": 0.001,
    "latent_dim": 256,
    "embedding_dim": 64,
    "dropout_rate": 0.03,
    "batch_size": 1024
  },
  {
    "vocabulary_size": 5000,
    "max_output_seq_length": 5,
    "max_input_seq_length": 192,
    "learning_rate": 0.01,
    "latent_dim": 1024,
    "embedding_dim": 32,
    "dropout_rate": 0.1,
    "batch_size": 512
  },
  {
    "vocabulary_size": 4000,
    "max_output_seq_length": 8,
    "max_input_seq_length": 1