In [10]:
import numpy as np
import pandas as pd
import os
import re
import json
from scipy.stats import sem
from os.path import join

In [13]:
# Helper functions
def atof(text):
    try:
        retval = float(text)
    except ValueError:
        retval = text
    return retval

def natural_keys(text):
    '''
    alist.sort(key=natural_keys) sorts in human order
    http://nedbatchelder.com/blog/200712/human_sorting.html
    (See Toothy's implementation in the comments)
    float regex comes from https://stackoverflow.com/a/12643073/190597
    '''
    return [ atof(c) for c in re.split(r'[+-]?([0-9]+(?:[.][0-9]*)?|[.][0-9]+)', text) ]

In [85]:
header = r'''
\begin{table}[]
\resizebox{\linewidth}{!}{%
\begin{tabular}{@{}lcccccc@{}}
\toprule
                                            & \multicolumn{3}{c}{\textbf{Addition}}      & \multicolumn{3}{c}{\textbf{E9P}} \\ \midrule
\multicolumn{1}{l|}{Training set size (\%)} & \textbf{100} & \textbf{10} & \textbf{5} & \textbf{100}     & \textbf{10}     & \textbf{5}    \\ \midrule
'''

In [93]:
data_rows = r'''
\multicolumn{1}{l|}{$Meta_{Abd}$}             &     &    & \multicolumn{1}{l|}{}  &         &        &      \\
\multicolumn{1}{l|}{CNN}                    &     &    & \multicolumn{1}{l|}{}  &         &        &      \\
\multicolumn{1}{l|}{CBM}                    &     &    & \multicolumn{1}{l|}{}  &         &        &      \\
\multicolumn{1}{l|}{CBM-S}                  &     &    & \multicolumn{1}{l|}{}  &         &        &      \\
\multicolumn{1}{l|}{CNN-LSTM-NAC}           &     &    & \multicolumn{1}{l|}{}  &         &        &      \\
\multicolumn{1}{l|}{CNN-LSTM-NALU}          &     &    & \multicolumn{1}{l|}{}  &         &        &      \\ \midrule
\multicolumn{1}{l|}{NSIL}                   &     &    & \multicolumn{1}{l|}{}  &         &        &      \\ \bottomrule
'''

In [94]:
footer = r'''
\end{tabular}
}
\caption{Non-recursive arithmetic results with reducing training set sizes. Results indicate average test set accuracy over 20 repeats. Best results highlighted in bold, and standard error in parentheses. }
\end{table}
'''

In [95]:
def get_method_result(
     example_dir,
     baselines,
     dataset_pct=100,
     dataset=None):
    baseline_base_dir = example_dir+'/baselines/saved_results'
    nsl_repeats_dir = example_dir+'/saved_results'
    
    if dataset is not None:
        baseline_base_dir = baseline_base_dir+'/'+dataset
        nsl_repeats_dir = nsl_repeats_dir+'/repeats/'+dataset
    else:
        nsl_repeats_dir = nsl_repeats_dir+'/repeats'
    
    # Get NSL results
    nsl_dir = nsl_repeats_dir+'/'+str(dataset_pct)
    repeats = os.listdir(nsl_dir)
    repeats = [r for r in repeats if r != '.DS_Store']
    repeats.sort(key=natural_keys)
    
    # X data is just epoch number
    epoch_num = 20
    all_results_epoch = []
    for idx, i in enumerate(repeats):
        # Read in test_log and get end-to-end accuracy at this epoch
        with open(join(nsl_dir, i, 'test_log.json'), 'r') as jf:
            tl = json.loads(jf.read())
            acc = tl[str(epoch_num)]['end_to_end_acc']
            all_results_epoch.append(acc)

    # Compute mean and std err across all repeats
    nsl_means = np.mean(all_results_epoch)
    nsl_errs = sem(all_results_epoch)
        
    # Get baseline results
    baseline_results = {}
    for b in baselines:
        baseline_results[b] = {}
        
        b_dir = baseline_base_dir+'/'+b+'/'+str(dataset_pct)
        b_repeats = os.listdir(b_dir)
        b_repeats = [b_r for b_r in b_repeats if b_r != '.DS_Store']
        b_repeats = [b_r for b_r in b_repeats if 'csv' not in b_r]
        b_repeats = [b_r for b_r in b_repeats if 'txt' not in b_r]
        b_repeats.sort(key=natural_keys)
        
        all_b_results_epoch = []
        for idx, i in enumerate(b_repeats):
            # Read in test_log and get accuracy for this epoch
            if b == 'meta_abd':
                b_tl = pd.read_csv(b_dir+'/'+i+'/test.csv')
                acc = b_tl['task_accuracy'].iloc[-1]
#                 _max_epoch = b_tl['task_accuracy'].index[-1]
                all_b_results_epoch.append(acc)
            else:
                b_tl = pd.read_csv(b_dir+'/'+i+'/test_log.csv')
                acc = b_tl['accuracy'][epoch_num]
                all_b_results_epoch.append(acc)
                
        # Calculate mean and stderr
        baseline_results[b]['mean'] = np.mean(all_b_results_epoch)
        baseline_results[b]['err'] = sem(all_b_results_epoch)
            
    # Plot graph
    return nsl_means, nsl_errs, baseline_results

In [102]:
# Get results for each method in list of 100, 10, 5
def get_all_results():
    example_name = 'arithmetic'
    results = {}
    # Baseline ID with display name
    baseline_info = {
        'cnn': 'CNN', 
        'cbm_joint_lambda_0': 'CBM',
        'cbm_joint_lambda_0_with_softmax': 'CBM-S',
        'cnn_lstm_nac': 'CNN-LSTM-NAC',
        'cnn_lstm_nalu': 'CNN-LSTM-NALU',
        'meta_abd': 'Meta_Abd'
    }
    for d in ['sum', 'e9p']:
        if d == 'e9p':
            del baseline_info['meta_abd']
        method_results = {}
        for pct in [100, 10, 5]:
            example = '../../../examples/'+example_name
            res = get_method_result(example, baseline_info, dataset=d, dataset_pct=pct)
            res_str = f'{round(res[0],3)} ({round(res[1],3)})'
            if 'nsil' in method_results:
                method_results['nsil'].append(res_str)
            else:
                method_results['nsil'] = [res_str]
            
            baseline_res = res[2]
            for b in baseline_res:
                b_res_str = f"{round(baseline_res[b]['mean'],3)} ({round(baseline_res[b]['err'],3)})"
                if b in method_results:
                    method_results[b].append(b_res_str)
                else:
                    method_results[b] = [b_res_str]
        results[d] = method_results
    return results

In [107]:
def fill_table(res):
    # Get row results
    row_res_to_fill = []
    for m in ['meta_abd', 'cnn', 'cbm_joint_lambda_0', 'cbm_joint_lambda_0_with_softmax', 'cnn_lstm_nac', 'cnn_lstm_nalu','nsil']:
        for task in ['sum', 'e9p']:
            if task == 'e9p' and m == 'meta_abd':
                row_res_to_fill += ['-']*3
            else:
                row_res_to_fill += res[task][m]
    
    new_rows = []
    for r in data_rows.split('\n'):
        if r != '':
            new_r = ''
            for idx in range(len(r)):
                if r[idx] == '&':
                    new_r += '& ' + row_res_to_fill.pop(0)
                else:
                    new_r += r[idx]
            new_rows.append(new_r)
    formatted = '\n'.join(new_rows)
    return f'{header}\n{formatted}\n{footer}'

In [108]:
_r = get_all_results()
print(fill_table(_r))


\begin{table}[]
\resizebox{\linewidth}{!}{%
\begin{tabular}{@{}lllllll@{}}
\toprule
                                            & \multicolumn{3}{c}{Addition}      & \multicolumn{3}{c}{E9P} \\ \midrule
\multicolumn{1}{l|}{Training set size (\%)} & 100 & 10 & \multicolumn{1}{l|}{5} & 100     & 10     & 5    \\ \midrule

\multicolumn{1}{l|}{$Meta_{Abd}$}             & 0.268 (0.03)     & 0.388 (0.053)    & 0.213 (0.035) \multicolumn{1}{l|}{}  & -         & -        & -      \\
\multicolumn{1}{l|}{CNN}                    & 0.948 (0.001)     & 0.715 (0.007)    & 0.425 (0.006) \multicolumn{1}{l|}{}  & 0.968 (0.001)         & 0.899 (0.002)        & 0.792 (0.007)      \\
\multicolumn{1}{l|}{CBM}                    & 0.962 (0.003)     & 0.534 (0.034)    & 0.141 (0.013) \multicolumn{1}{l|}{}  & 0.978 (0.001)         & 0.934 (0.002)        & 0.842 (0.014)      \\
\multicolumn{1}{l|}{CBM-S}                  & 0.671 (0.053)     & 0.095 (0.0)    & 0.095 (0.0) \multicolumn{1}{l|}{}  & 0.881 (0.062) 