In [None]:
%matplotlib inline
import os
import pandas as pd
import numpy as np
from tools import train_test
from tools import plots

# pandas display options: let wide/long result tables render without truncation.
# NOTE: 'display.height' is deliberately not set. It was a deprecated option that
# current pandas releases no longer recognise (setting it raises OptionError);
# the number of rows shown is governed by 'display.max_rows'.
pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 1000)

In [None]:
# read all results into infos dataframe
rdir = './results'


def load_model_infos(results_dir):
    """Collect every `model_info.tsv` found three directory levels below `results_dir`.

    Each matching directory contributes the row(s) read from its
    `model_info.tsv`, indexed by the directory's own name (used as the run
    timestamp). When a `learning_curve.tsv` sits next to it and the info's
    `log_type` is 'epoch', the columns `epoch_loss_min` and `epoch_loss_last`
    are added from the curve's `epoch_loss` column.

    Parameters
    ----------
    results_dir : str
        Root directory to walk.

    Returns
    -------
    pandas.DataFrame
        One row per run; an empty DataFrame when nothing was found.
    """
    frames = []
    for root, subdirs, files in os.walk(results_dir):
        # Only directories exactly three levels below results_dir qualify
        # (two separators in the path relative to the root).
        if os.path.relpath(root, results_dir).count(os.sep) != 2:
            continue
        if 'model_info.tsv' not in files:
            continue

        info = pd.read_csv(os.path.join(root, 'model_info.tsv'), sep='\t')
        timestamp = str(root.split(os.sep)[-1])

        # add loss stats to info
        if 'learning_curve.tsv' in files:
            learning_curve = pd.read_csv(os.path.join(root, 'learning_curve.tsv'), sep='\t')
            # An empty curve has no "last" value: skip instead of raising IndexError.
            if info['log_type'].iloc[0] == 'epoch' and not learning_curve.empty:
                info['epoch_loss_min'] = learning_curve['epoch_loss'].min()
                info['epoch_loss_last'] = learning_curve['epoch_loss'].iloc[-1]

        info.index = [timestamp]
        frames.append(info)

    # Concatenate once at the end rather than re-copying the frame on every iteration.
    # pd.concat rejects an empty list, so fall back to an empty frame explicitly.
    if not frames:
        return pd.DataFrame()
    return pd.concat(frames)


infos = load_model_infos(rdir)

# Analysis

## Filter data

In [None]:
# Keep only the Analogy runs trained on NELL186, ordered by their index.
is_analogy = infos['model_name'] == 'Analogy'
is_nell186 = infos['dataset_name'] == 'NELL186'
filt = infos.loc[is_analogy & is_nell186].sort_index()

In [None]:
# Display the filtered runs (the cell's last expression is rendered as its output).
filt

In [None]:
# Inspect a single run, transposed so that each field is shown on its own row.
# `str` is used instead of the Python 2-only `unicode` builtin so the lookup
# works on both Python 2 and Python 3.
filt.loc[filt.index == str(1526567410)].transpose()

## Plot learning curve

In [None]:
def learning_curve_from_filt(timestamp, results_dir='./results'):
    """Plot the learning curve of the run in `filt` whose index equals `timestamp`.

    The curve is read from
    `<results_dir>/<dataset_name>/<model_name>/<timestamp>/learning_curve.tsv`
    and handed to `plots.plot_learning_curve` together with the run's row.

    Parameters
    ----------
    timestamp : int or str
        Run identifier; converted to `str` before being compared with the
        index of `filt`.
    results_dir : str, optional
        Root directory holding the results (defaults to './results').

    Returns
    -------
    None
        If no row of `filt` matches, nothing is plotted.
    """
    # `str` instead of the Python 2-only `unicode` builtin: works on Python 2 and 3.
    timestamp = str(timestamp)
    # Visit only the matching row(s) instead of scanning every row of `filt`.
    for idx, row in filt.loc[filt.index == timestamp].iterrows():
        curve_path = os.path.join(
            results_dir, row['dataset_name'], row['model_name'], idx, 'learning_curve.tsv'
        )
        learning_curve = pd.read_csv(curve_path, sep='\t')
        plots.plot_learning_curve(learning_curve, row)

In [None]:
# Plot the learning curve of the run identified by this timestamp.
learning_curve_from_filt(1526180534)

# Export a set of models

In [None]:
# Timestamps of the runs to export (matched against the index of `infos`).
models_ts = [
    1526710056,
    1526710447,
    1526711822,
    1526417226,
    1526535074,
]

# Select the matching rows of `infos`, in the order listed above; a timestamp
# with no matching row simply contributes no rows.
# - `str` replaces the Python 2-only `unicode` builtin (works on Python 2 and 3).
# - All pieces are concatenated in one call instead of growing the frame
#   inside the loop.
filt = pd.concat([infos.loc[infos.index == str(ts)] for ts in models_ts])

In [None]:
# Write the selected models as a tab-separated file into the current user's
# Downloads folder. The home directory is resolved at run time rather than
# being hard-coded to one machine's account.
filt.to_csv(os.path.join(os.path.expanduser('~'), 'Downloads', 'best_models.tsv'), sep='\t')