In [1]:
import tensorflow as tf
from tensorboard.plugins.hparams import api as hp
from src.utils import args as model_args

In [2]:
HP_NUM_LAYERS = hp.HParam('num_layers', hp.Discrete([1, 2]))
HP_DROPOUT = hp.HParam('dropout', hp.Discrete([0.0, 0.1, 0.5]))
HP_HIDDEN_DIM = hp.HParam('hidden_dim', hp.Discrete([25, 50, 100, 200, 400]))
HP_USE_ATTENTION = hp.HParam('use_attention', hp.Discrete([True, False]))
HP_INCLUDE_POS_TAG = hp.HParam('include_pos_tag', hp.Discrete(['', 'aux', 'input']))

METRIC_ACCURACY = 'accuracy'

with tf.summary.create_file_writer('snap/hparam_tuning').as_default():
  hp.hparams_config(
    hparams=[HP_NUM_LAYERS, HP_DROPOUT, HP_HIDDEN_DIM, HP_USE_ATTENTION],
    metrics=[hp.Metric(METRIC_ACCURACY, display_name='Accuracy')],
  )

In [3]:
model_args.parse_args()

In [4]:
from src.tasks import train
from src.data.scan import get_dataset

In [5]:
train_ds, test_ds, (in_vec, _, _) = get_dataset('simple')
pad_idx = in_vec.get_vocabulary().index('')
start_idx = in_vec.get_vocabulary().index('<sos>')
end_idx = in_vec.get_vocabulary().index('<eos>')

In vocabulary: {0: '', 1: '[UNK]', 2: '<sos>', 3: '<eos>', 4: 'run', 5: 'opposite', 6: 'right', 7: 'after', 8: 'turn', 9: 'left', 10: 'twice', 11: 'around', 12: 'thrice', 13: 'walk', 14: 'jump', 15: 'and', 16: 'look'}
Pos vocabulary: {0: '', 1: '[UNK]', 2: '<sos>', 3: '<eos>', 4: 'ADJ', 5: 'ADP', 6: 'VERB', 7: 'PRT', 8: 'ADV', 9: 'CONJ', 10: 'NOUN'}
Out vocabulary: {0: '', 1: '[UNK]', 2: '<sos>', 3: '<eos>', 4: 'I_TURN_LEFT', 5: 'I_TURN_RIGHT', 6: 'I_RUN', 7: 'I_WALK', 8: 'I_JUMP', 9: 'I_LOOK'}
16728 4182


In [6]:
import importlib
import src.models.lstm_base as lstm_base
importlib.reload(train)
importlib.reload(lstm_base)

<module 'src.models.lstm_base' from 'd:\\Desktop\\Universidad\\2021-2\\TextMining\\Proyecto\\TextMiningCode\\src\\models\\lstm_base.py'>

In [7]:
results = {}
for num_layers in HP_NUM_LAYERS.domain.values:
    for dropout in HP_DROPOUT.domain.values:
        for hidden_dim in HP_HIDDEN_DIM.domain.values:
            for use_attention in HP_USE_ATTENTION.domain.values:
                for include_pos_tag in HP_INCLUDE_POS_TAG.domain.values:
                    hparams = {
                        HP_NUM_LAYERS: num_layers,
                        HP_DROPOUT: dropout,
                        HP_HIDDEN_DIM: hidden_dim,
                        HP_USE_ATTENTION: use_attention,
                        HP_INCLUDE_POS_TAG: include_pos_tag,
                    }

                    run_name = f'lay({num_layers})-drop({dropout})-hidden({hidden_dim})'
                    if use_attention:
                        run_name += '-attn'

                    if include_pos_tag:
                        run_name += f'-pos({include_pos_tag})'

                    model_args.args.name = run_name
                    model_args.args.hidden_layers = num_layers
                    model_args.args.dropout = dropout
                    model_args.args.hidden_size = hidden_dim
                    model_args.args.use_attention = use_attention
                    model_args.args.include_pos_tag = include_pos_tag

                    print('--- Starting trial: %s' % run_name)
                    print({h.name: hparams[h] for h in hparams})
                    
                    res = train.train(train_ds, test_ds, pad_idx, start_idx, end_idx)

                    results[run_name] = res

--- Starting trial: lay(1)-drop(0.0)-hidden(25)
{'num_layers': 1, 'dropout': 0.0, 'hidden_dim': 25, 'use_attention': False, 'include_pos_tag': ''}
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
Model: "seq2_seq_lstm"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_embedding (Embedding)  multiple                 375       
                                                                 
 action_output (Embedding)   multiple                  64        
                                                                 
 dropout (Dropout)           multiple                  0         
                                                                 
 bidirectional (Bidirectiona  multiple                 10200     
 l)  