In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

import os

In [32]:
# Sort the listing: os.listdir returns entries in arbitrary order, and
# best_expts keeps the first experiment seen on ties, so sorting makes the
# selected "best" runs reproducible across machines/filesystems.
experiments = sorted(os.listdir('../jsb_gridsearch/'))

gate_expmts = [exp for exp in experiments if 'gate' in exp]
bilin_expmts = [exp for exp in experiments if 'simple' in exp]  # 'simple_cp-*' runs
gru_expmts = [exp for exp in experiments if 'gru' in exp]
lstm_expmts = [exp for exp in experiments if 'lstm' in exp]
van_expts = [exp for exp in experiments if 'vanilla' in exp]
print(bilin_expmts)

['simple_cp-0.001-16-100', 'simple_cp-0.001-16-25', 'simple_cp-0.001-16-35', 'simple_cp-0.001-16-50', 'simple_cp-0.001-16-75', 'simple_cp-0.001-32-100', 'simple_cp-0.001-32-25', 'simple_cp-0.001-32-35', 'simple_cp-0.001-32-50', 'simple_cp-0.001-32-75', 'simple_cp-0.001-4-100', 'simple_cp-0.001-4-25', 'simple_cp-0.001-4-35', 'simple_cp-0.001-4-50', 'simple_cp-0.001-4-75', 'simple_cp-0.001-8-100', 'simple_cp-0.001-8-25', 'simple_cp-0.001-8-35', 'simple_cp-0.001-8-50', 'simple_cp-0.001-8-75', 'simple_cp-0.01-16-100', 'simple_cp-0.01-16-25', 'simple_cp-0.01-16-35', 'simple_cp-0.01-16-50', 'simple_cp-0.01-16-75', 'simple_cp-0.01-32-100', 'simple_cp-0.01-32-25', 'simple_cp-0.01-32-35', 'simple_cp-0.01-32-50', 'simple_cp-0.01-32-75', 'simple_cp-0.01-4-100', 'simple_cp-0.01-4-25', 'simple_cp-0.01-4-35', 'simple_cp-0.01-4-50', 'simple_cp-0.01-4-75', 'simple_cp-0.01-8-100', 'simple_cp-0.01-8-25', 'simple_cp-0.01-8-35', 'simple_cp-0.01-8-50', 'simple_cp-0.01-8-75', 'simple_cp-0.1-16-100', 'simple

In [33]:
def get_nll(result):
    """Return the NLL from a results line, i.e. the float after its last ':'.

    Example line: ``'Test  xent: 0.156, nll: 8.62'`` -> ``8.62``.
    """
    _, _, nll_text = result.rpartition(':')
    return float(nll_text)

def get_xent(result):
    """Return the cross-entropy from a results line.

    The value is the text between the first ':' and the next ',', e.g.
    ``'Test  xent: 0.156, nll: 8.62'`` -> ``0.156``.
    """
    after_first_colon = result.split(':')[1]
    xent_text, _, _ = after_first_colon.partition(',')
    return float(xent_text)

def best_expts(all_expts, directory='../jsb_gridsearch/'):
    """Find the experiments with the lowest test NLL and lowest validation NLL.

    Each experiment directory must contain ``earlystopped_results.txt`` whose
    first two non-blank lines are the test and validation result lines (in that
    order), each ending in ``nll: <float>``. Any further lines (e.g. train) are
    ignored.

    Args:
        all_expts: iterable of experiment directory names under ``directory``.
        directory: parent directory holding the experiment directories.

    Returns:
        ``(best_test_expt, best_val_expt)``: the experiment with the lowest
        test NLL and the one with the lowest validation NLL. Ties keep the
        first experiment encountered. Both are ``''`` if ``all_expts`` is empty.

    Raises:
        ValueError: if a results file has fewer than two non-blank lines or a
            line's NLL cannot be parsed as a float.
    """
    best_val_expt = ''
    best_test_expt = ''
    # inf instead of a magic cap, so runs with NLL >= 100 can still be chosen.
    best_val_err = float('inf')
    best_test_err = float('inf')

    for exp in all_expts:
        results_path = os.path.join(directory, exp, 'earlystopped_results.txt')
        with open(results_path) as fp:
            # Drop blank lines so a missing or extra trailing newline does not
            # break the positional parse below.
            lines = [line for line in fp.read().splitlines() if line.strip()]
        if len(lines) < 2:
            raise ValueError(
                '{}: expected test and valid result lines, got {!r}'.format(
                    results_path, lines))
        test_nll = get_nll(lines[0])
        valid_nll = get_nll(lines[1])

        # Strict '<' keeps the first-seen experiment on ties.
        if test_nll < best_test_err:
            best_test_err = test_nll
            best_test_expt = exp
        if valid_nll < best_val_err:
            best_val_err = valid_nll
            best_val_expt = exp
    return best_test_expt, best_val_expt

In [34]:
def format_results(exp, directory='../jsb_gridsearch/'):
    """Print an experiment's hyperparameters followed by its early-stopped results.

    Experiment names are expected to look like ``<cell>-<lr>-<bs>-<sl>``, where
    ``<cell>`` may itself contain hyphens (e.g. ``cp-gate-0.01-8-35``).
    Assumes lr, bs and sl contain no hyphens (a value like ``1e-3`` would be
    mis-split).

    Args:
        exp: experiment directory name under ``directory``.
        directory: parent directory holding the experiment directories.

    Raises:
        ValueError: if ``exp`` has fewer than three hyphens.
    """
    # Split from the right so any hyphens in the cell name stay in `cell`.
    cell, lr, bs, sl = exp.rsplit('-', 3)
    print('CELL: {}'.format(cell))
    print('~lr: {}, bs: {}, sl: {}~'.format(lr, bs, sl))
    result_path = os.path.join(directory, exp, 'earlystopped_results.txt')
    with open(result_path) as fp:
        print(fp.read())


In [35]:
# best_expts returns (best by test NLL, best by validation NLL); both are
# printed, so one run appears twice when it wins on both splits.
gate_best_test, gate_best_valid = best_expts(gate_expmts)
format_results(gate_best_test)
format_results(gate_best_valid)

CELL: cp-gate
~lr: 0.01, bs: 8, sl: 35~
Test  xent: 0.15677809715270996, nll: 8.622796114753275
Valid xent: 0.1559621635824442, nll: 8.577919900417328
Train xent: 0.14541432231664658, nll: 7.9977883625030515

CELL: cp-gate
~lr: 0.01, bs: 8, sl: 75~
Test  xent: 0.15726640607629502, nll: 8.649653026035853
Valid xent: 0.15542446502617427, nll: 8.54834611075265
Train xent: 0.1443079262971878, nll: 7.936936461407205



In [36]:
# best_expts returns (best by test NLL, best by validation NLL); both are
# printed, so one run appears twice when it wins on both splits.
lstm_best_test, lstm_best_valid = best_expts(lstm_expmts)
format_results(lstm_best_test)
format_results(lstm_best_valid)

CELL: lstm
~lr: 0.001, bs: 4, sl: 100~
Test  xent: 0.15441124818541788, nll: 8.492618777535178
Valid xent: 0.1530637646263296, nll: 8.41850722919811
Train xent: 0.14562407433986663, nll: 8.009324264526366

CELL: lstm
~lr: 0.001, bs: 4, sl: 100~
Test  xent: 0.15441124818541788, nll: 8.492618777535178
Valid xent: 0.1530637646263296, nll: 8.41850722919811
Train xent: 0.14562407433986663, nll: 8.009324264526366



In [37]:
# best_expts returns (best by test NLL, best by validation NLL); both are
# printed, so one run appears twice when it wins on both splits.
gru_best_test, gru_best_valid = best_expts(gru_expmts)
format_results(gru_best_test)
format_results(gru_best_valid)

CELL: gru
~lr: 0.001, bs: 4, sl: 100~
Test  xent: 0.1551306058060039, nll: 8.532183560458096
Valid xent: 0.15407809886065396, nll: 8.474295442754572
Train xent: 0.14569367383207593, nll: 8.013152190617152

CELL: gru
~lr: 0.001, bs: 4, sl: 100~
Test  xent: 0.1551306058060039, nll: 8.532183560458096
Valid xent: 0.15407809886065396, nll: 8.474295442754572
Train xent: 0.14569367383207593, nll: 8.013152190617152



In [38]:
# best_expts returns (best by test NLL, best by validation NLL); both are
# printed, so one run appears twice when it wins on both splits.
van_best_test, van_best_valid = best_expts(van_expts)
format_results(van_best_test)
format_results(van_best_valid)

CELL: vanilla
~lr: 0.001, bs: 8, sl: 50~
Test  xent: 0.15683378821069544, nll: 8.625858740373092
Valid xent: 0.15506418320265683, nll: 8.528530554337935
Train xent: 0.14674661798136576, nll: 8.071064308711462

CELL: vanilla
~lr: 0.001, bs: 8, sl: 50~
Test  xent: 0.15683378821069544, nll: 8.625858740373092
Valid xent: 0.15506418320265683, nll: 8.528530554337935
Train xent: 0.14674661798136576, nll: 8.071064308711462



In [39]:
# best_expts returns (best by test NLL, best by validation NLL); both are
# printed, so one run appears twice when it wins on both splits.
bilin_best_test, bilin_best_valid = best_expts(bilin_expmts)
format_results(bilin_best_test)
format_results(bilin_best_valid)

CELL: simple_cp
~lr: 0.001, bs: 4, sl: 25~
Test  xent: 0.1574447069396364, nll: 8.65945907349282
Valid xent: 0.15634396283522897, nll: 8.598918220271235
Train xent: 0.14626479595899583, nll: 8.044563899721417

CELL: simple_cp
~lr: 0.001, bs: 4, sl: 25~
Test  xent: 0.1574447069396364, nll: 8.65945907349282
Valid xent: 0.15634396283522897, nll: 8.598918220271235
Train xent: 0.14626479595899583, nll: 8.044563899721417

