In [55]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

import os

DIR = '../mrnn/nottingham_gridsearch_fair/'

In [56]:


experiments = os.listdir(DIR)

gate_expmts = [exp for exp in experiments if 'gate' in exp and 'comb' not in exp]
combo_expmts = [exp for exp in experiments if 'comb' in exp]
bilin_expmts = [exp for exp in experiments if 'simple' in exp]
gru_expmts = [exp for exp in experiments if 'gru' in exp]
lstm_expmts = [exp for exp in experiments if 'lstm' in exp]
van_expts = [exp for exp in experiments if 'vanilla' in exp]
print(combo_expmts)
print(gate_expmts)

['cp-gate-combined-0.1-8-75-rankone', 'cp-gate-combined-0.1-8-75-rankhalf', 'cp-gate-combined-0.1-8-75-rankfull', 'cp-gate-combined-0.1-8-75-rankdouble', 'cp-gate-combined-0.1-8-100-rankone', 'cp-gate-combined-0.1-8-100-rankhalf', 'cp-gate-combined-0.1-8-100-rankfull', 'cp-gate-combined-0.1-8-100-rankdouble', 'cp-gate-combined-0.01-8-75-rankone', 'cp-gate-combined-0.01-8-75-rankhalf', 'cp-gate-combined-0.01-8-75-rankfull', 'cp-gate-combined-0.01-8-75-rankdouble', 'cp-gate-combined-0.01-8-100-rankone', 'cp-gate-combined-0.01-8-100-rankhalf', 'cp-gate-combined-0.01-8-100-rankfull', 'cp-gate-combined-0.01-8-100-rankdouble', 'cp-gate-combined-0.001-8-75-rankone', 'cp-gate-combined-0.001-8-75-rankhalf', 'cp-gate-combined-0.001-8-75-rankfull', 'cp-gate-combined-0.001-8-75-rankdouble', 'cp-gate-combined-0.001-8-100-rankone', 'cp-gate-combined-0.001-8-100-rankhalf', 'cp-gate-combined-0.001-8-100-rankfull', 'cp-gate-combined-0.001-8-100-rankdouble']
['cp-gate-0.1-8-75-rankone', 'cp-gate-0.1-8-7

In [57]:
def get_nll(result):
    return float(result.split(':')[-1])

def get_xent(result):
    # more awkward
    return float(result.split(':')[1].split(',')[0])

def best_expts(all_expts, directory=DIR):
    best_val_expt = ''
    best_test_expt = ''
    best_val_err = 100
    best_test_err = 100
    
    for exp in all_expts:
        results_path = os.path.join(directory, exp, 'earlystopped_results.txt')
        # open it up and find the numbers
        if not os.path.exists(results_path):
            print('no results for: {}'.format(results_path))
            continue
        with open(results_path) as fp:
            test, valid, train, _ = fp.read().split('\n')
            # now split out the nll
            test_nll = get_nll(test)
            valid_nll = get_nll(valid)
            
            if test_nll < best_test_err:
                best_test_err = test_nll
                best_test_expt = exp
            if valid_nll < best_val_err:
                best_val_err = valid_nll
                best_val_expt = exp
    return best_test_expt, best_val_expt

In [58]:
def format_results(exp, directory=DIR):
    splits = exp.split('-')
    if len(splits) >= 6:
        num = len(splits) - 4
        cell = '-'.join(splits[:num])
        lr, bs, sl, rank = splits[num:]
    elif len(splits) == 5:
        cell, lr, bs, sl, rank = splits
    
    print('CELL: {}'.format(cell))
    print('~lr: {}, bs: {}, sl: {}, rk: {}~'.format(lr, bs, sl, rank))
    result_path = os.path.join(directory, exp, 'earlystopped_results.txt')
    with open(result_path) as fp:
        print(fp.read())


In [59]:
for expt in best_expts(gate_expmts):
    format_results(expt)

CELL: cp-gate
~lr: 0.001, bs: 8, sl: 75, rk: rankone~
Test  xent: 0.06688828590149815, nll: 4.213984843846914
Valid xent: 0.0667765854710811, nll: 4.20694793525495
Train xent: 0.06293986812233925, nll: 3.9652333211090607

CELL: cp-gate
~lr: 0.001, bs: 8, sl: 100, rk: rankone~
Test  xent: 0.06694772358645093, nll: 4.217735897411
Valid xent: 0.06648336566592518, nll: 4.188481414527224
Train xent: 0.06188390677061555, nll: 3.8987145801475145



In [60]:
for expt in best_expts(combo_expmts):
    format_results(expt)

CELL: cp-gate-combined
~lr: 0.001, bs: 8, sl: 100, rk: rankfull~
Test  xent: 0.06750924417918379, nll: 4.253118228912354
Valid xent: 0.06719245323747919, nll: 4.23316026152226
Train xent: 0.06354729982927375, nll: 4.003514129112209

CELL: cp-gate-combined
~lr: 0.001, bs: 8, sl: 75, rk: rankfull~
Test  xent: 0.06764293154953299, nll: 4.261527187115437
Valid xent: 0.06716979731266436, nll: 4.231719503277226
Train xent: 0.06389146060761759, nll: 4.025183474815498



In [61]:
for expt in best_expts(lstm_expmts):
    format_results(expt)

CELL: lstm
~lr: 0.001, bs: 8, sl: 100, rk: rankone~
Test  xent: 0.06368763500993901, nll: 4.012352440573952
Valid xent: 0.06298857365261044, nll: 3.968311251255504
Train xent: 0.056866609436624194, nll: 3.5826250686904424

CELL: lstm
~lr: 0.001, bs: 8, sl: 100, rk: rankone~
Test  xent: 0.06368763500993901, nll: 4.012352440573952
Valid xent: 0.06298857365261044, nll: 3.968311251255504
Train xent: 0.056866609436624194, nll: 3.5826250686904424



In [62]:
for expt in best_expts(gru_expmts):
    format_results(expt)

CELL: gru
~lr: 0.001, bs: 8, sl: 100, rk: rankone~
Test  xent: 0.06374120976437223, nll: 4.015734724564986
Valid xent: 0.06328735467896127, nll: 3.98714262142516
Train xent: 0.056945674306801544, nll: 3.587613225522624

CELL: gru
~lr: 0.001, bs: 8, sl: 75, rk: rankone~
Test  xent: 0.06378244729460897, nll: 4.018322541907027
Valid xent: 0.0629468722465007, nll: 3.965680429809972
Train xent: 0.054850695168567914, nll: 3.45561660507978



In [63]:
for expt in best_expts(van_expts):
    format_results(expt)

CELL: vanilla
~lr: 0.001, bs: 8, sl: 100, rk: rankone~
Test  xent: 0.06169973164796829, nll: 3.8871165925806217
Valid xent: 0.06139618360943962, nll: 3.8679935430225574
Train xent: 0.05282099833248428, nll: 3.3277516990765186

CELL: vanilla
~lr: 0.001, bs: 8, sl: 100, rk: rankone~
Test  xent: 0.06169973164796829, nll: 3.8871165925806217
Valid xent: 0.06139618360943962, nll: 3.8679935430225574
Train xent: 0.05282099833248428, nll: 3.3277516990765186



In [64]:
for expt in best_expts(bilin_expmts):
    format_results(expt)

CELL: simple_cp
~lr: 0.001, bs: 8, sl: 100, rk: rankone~
Test  xent: 0.0619917136024345, nll: 3.905510997772217
Valid xent: 0.061066428493512306, nll: 3.847218551133808
Train xent: 0.05387315863734996, nll: 3.3940386459298804

CELL: simple_cp
~lr: 0.001, bs: 8, sl: 100, rk: rankone~
Test  xent: 0.0619917136024345, nll: 3.905510997772217
Valid xent: 0.061066428493512306, nll: 3.847218551133808
Train xent: 0.05387315863734996, nll: 3.3940386459298804

