In [5]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
import re
import json
import torch
import yaml
import math
import re
import torch.nn as nn
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from scipy.stats import sem
from os.path import join
from tqdm import tqdm

In [2]:
# Helper functions
def atof(text):
    """Return ``float(text)`` if the text parses as a number, else the text itself."""
    try:
        return float(text)
    except ValueError:
        return text

def natural_keys(text):
    '''
    Sort key giving "human" ordering, e.g. ``names.sort(key=natural_keys)``
    puts 'r2' before 'r10'.

    The text is split around numeric runs (the capture group keeps the numbers
    in the result) and each numeric piece is converted to float.
    http://nedbatchelder.com/blog/200712/human_sorting.html
    (See Toothy's implementation in the comments)
    float regex comes from https://stackoverflow.com/a/12643073/190597
    '''
    number_pattern = r'[+-]?([0-9]+(?:[.][0-9]*)?|[.][0-9]+)'
    pieces = re.split(number_pattern, text)
    return list(map(atof, pieces))

## 1) Cumulative Arithmetic

In [None]:
# Generate NSIL table row

In [6]:
def get_network_acc(example_dir, num_epochs=20,
                    datasets=('meta_abd_data_sum', 'meta_abd_data_prod')):
    """Print mean (standard error) of the digit-network accuracy across repeats.

    For each dataset, reads ``test_log.json`` of every repeat under
    ``<example_dir>/saved_results/repeats/<dataset>``. It then takes the
    ``['network_accuracy']['digit']`` value logged at epoch ``num_epochs``.
    Repeats whose log has no entry for that epoch are reported and skipped.

    Args:
        example_dir: root directory of the example.
        num_epochs: epoch whose logged accuracy is summarised.
        datasets: dataset sub-directories to summarise, in print order.
    """
    nsil_repeats_dir = join(example_dir, 'saved_results', 'repeats')
    epoch_key = str(num_epochs)  # test_log.json is keyed by the epoch as a string
    for dataset in datasets:
        nsil_dir = join(nsil_repeats_dir, dataset)
        repeats = sorted((r for r in os.listdir(nsil_dir) if r != '.DS_Store'),
                         key=natural_keys)

        nsil_accs = []
        for repeat in repeats:
            log_path = join(nsil_dir, repeat, 'test_log.json')
            with open(log_path, 'r') as jf:
                tl = json.load(jf)
            if epoch_key not in tl:
                print('Repeat not complete:')
                print(log_path)
                continue
            nsil_accs.append(tl[epoch_key]['network_accuracy']['digit'])

        # Avoid a silent nan (and numpy RuntimeWarning) when nothing finished
        if not nsil_accs:
            print(f'Dataset: {dataset}, no complete repeats found')
            continue
        net_acc = np.mean(nsil_accs)
        net_err = sem(nsil_accs)
        print(f'Dataset: {dataset}, Avg network accuracy: {net_acc} ({net_err})')

In [7]:
# Root of the recursive-arithmetic example; `example` is also used by task_results below.
example = '../../../examples/recursive_arithmetic'
get_network_acc(example_dir=example, num_epochs=20)

Dataset: meta_abd_data_sum, Avg network accuracy: 0.9831800000000002 (0.0005978294071054137)
Dataset: meta_abd_data_prod, Avg network accuracy: 0.9805400000000001 (0.0005173006862551082)


In [8]:
class Flatten(nn.Module):
    """Collapse every dimension after the first (batch) dimension into one."""

    def forward(self, input):
        batch_size = input.size(0)
        return input.view(batch_size, -1)


def conv_net(out_dim):
    """Build the convolutional digit classifier.

    The layer order fixes the ``state_dict`` key indices (``enc.0``, ``enc.2``,
    ``enc.7``, ``enc.10`` inside ``Net``). It must therefore match any saved
    weights loaded into it.

    9216 == 64 * 12 * 12. This is the flattened feature size for a 1x28x28 input:
    two unpadded 3x3 convolutions (28 -> 26 -> 24) followed by 2x2 max-pooling (-> 12).

    Returns:
        ``nn.Sequential`` ending in a softmax over ``out_dim`` outputs.
    """
    feature_extractor = [
        nn.Conv2d(1, 32, 3, 1),
        nn.ReLU(),
        nn.Conv2d(32, 64, 3, 1),
        nn.ReLU(),
        nn.MaxPool2d(2),
        nn.Dropout(0.25),
    ]
    classifier = [
        Flatten(),
        nn.Linear(9216, 128),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(128, out_dim),
        nn.Softmax(dim=1),
    ]
    return nn.Sequential(*feature_extractor, *classifier)


class Net(nn.Module):
    """Wrapper around ``conv_net``; its parameters live under the ``enc.`` prefix."""

    def __init__(self, out_dim):
        super().__init__()
        self.enc = conv_net(out_dim)

    def forward(self, x):
        return self.enc(x)


In [9]:
# Normalise with the commonly used MNIST mean/std.
MNIST_transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)
# Keep the test set in its natural order: task_results indexes the per-image
# predictions by test-set position (ex.x_idxs), so the loader must not shuffle.
MNIST_TEST = DataLoader(
    MNIST(root='../../../data', train=False, download=False, transform=MNIST_transform),
    batch_size=64,
    shuffle=False,
)

In [10]:
def get_nn_preds(_net, _loader):
    """Return the arg-max prediction of ``_net`` for every sample in ``_loader``.

    Args:
        _net: callable mapping a batch of inputs to a tensor whose dim 1 indexes
            classes; the arg-max is taken over that dimension.
        _loader: iterable of ``(data, targets)`` batches; targets are ignored.

    Returns:
        1-D ``torch.long`` CPU tensor of predicted class indices, in loader order.
        It is empty if the loader yields no batches.
    """
    batch_preds = []
    with torch.no_grad():
        for data, _ in _loader:
            _, preds = torch.max(_net(data), 1)
            batch_preds.append(preds.to('cpu'))
    # Concatenate once instead of re-allocating a growing tensor per batch; this also
    # keeps the indices integral (an empty float accumulator promoted them to float).
    if not batch_preds:
        return torch.empty(0, dtype=torch.long)
    return torch.cat(batch_preds, 0)


In [20]:
# For each repeat, load the digit network saved at `iteration` (default 20), predict the
# test images once, then compute the task MAE / logMAE for every test file.
def task_results(example, data_files, iteration=20):
    """Print mean (standard error) task error over repeats for each test file.

    Args:
        example: root directory of the example.
        data_files: mapping ``dataset -> [test yaml file, ...]``. Datasets whose name
            contains 'sum' are scored by MAE on the sum of the predicted digits. All
            others are scored by logMAE on their product.
        iteration: which saved network (``net_digit_iteration_<n>.pt``) to load.
    """
    for dataset, test_files in data_files.items():
        nsil_dir = join(example, 'saved_results', 'repeats', dataset)
        repeats = sorted((r for r in os.listdir(nsil_dir) if r != '.DS_Store'),
                         key=natural_keys)
        # Derive the metric and its label from the same test so they can't disagree
        is_sum = 'sum' in dataset
        mae_type = 'MAE' if is_sum else 'logMAE'

        # Predictions depend only on a repeat's network, not on the test file,
        # so compute them once per repeat rather than once per (file, repeat).
        repeat_preds = []
        for repeat in repeats:
            net_weights_file = join(nsil_dir, repeat, 'networks',
                                    f'net_digit_iteration_{iteration}.pt')
            if not os.path.isfile(net_weights_file):
                print(f"Skipping repeat {repeat} as no iteration {iteration} network...")
                continue
            net = Net(10)
            # torch.load unpickles the file; only load checkpoints you produced yourself
            net.load_state_dict(torch.load(net_weights_file, map_location=torch.device('cpu')))
            net.eval()
            repeat_preds.append(get_nn_preds(net, MNIST_TEST))

        for test_file in test_files:
            with open(join(example, 'data', test_file), 'r') as f:
                # yaml.Loader can construct arbitrary Python objects. It is needed here
                # because the examples are objects exposing .x_idxs / .y, but it must
                # only be used on trusted files.
                task_data = yaml.load(f, Loader=yaml.Loader)

            results = [_task_loss(nn_preds, task_data, is_sum) for nn_preds in repeat_preds]
            if not results:
                print(f'File: {test_file}, no networks available')
                continue

            avg = np.mean(results)
            std_err = sem(results)
            print(f'File: {test_file}, {mae_type}: {avg:f} ({std_err:f})')


def _task_loss(nn_preds, task_data, is_sum):
    """Mean absolute error of digit sums (is_sum) or mean absolute log-error of digit products."""
    task_preds = []
    task_targets = []
    for ex in task_data:
        ex_preds = [int(nn_preds[x].item()) for x in ex.x_idxs]
        task_preds.append(sum(ex_preds) if is_sum else np.prod(np.array(ex_preds)))
        task_targets.append(ex.y)

    preds = torch.FloatTensor(task_preds)
    targets = torch.FloatTensor(task_targets)
    if not is_sum:
        # The 1e-10 offset keeps the log finite when a product or target is zero
        preds = torch.log(preds + 1e-10)
        targets = torch.log(targets + 1e-10)
    total = torch.nn.L1Loss(reduction='sum')(preds, targets).item()
    return total / len(task_data)

In [21]:
# Test files evaluated for each training dataset (dataset -> list of yaml files).
d_fs = dict(
    meta_abd_data_sum=[
        'mysum_full_test.yaml',
        'mysum_full_test_10.yaml',
        'mysum_full_test_100.yaml',
    ],
    meta_abd_data_prod=[
        'myprod_full_test.yaml',
        'myprod_full_test_10.yaml',
        'myprod_full_test_15.yaml',
    ],
)

In [22]:
task_results(example=example, data_files=d_fs)

File: mysum_full_test.yaml, MAE: 0.238260 (0.010211)
File: mysum_full_test_10.yaml, MAE: 0.625320 (0.026914)
File: mysum_full_test_100.yaml, MAE: 4.449560 (0.177717)
File: myprod_full_test.yaml, logMAE: 0.321787 (0.017189)
File: myprod_full_test_10.yaml, logMAE: 0.528803 (0.030440)
File: myprod_full_test_15.yaml, logMAE: 2.478133 (0.103972)


## 2) Two-Digit Arithmetic Results

In [47]:
# LaTeX table template for the two-digit arithmetic results.
# get_baseline_arithmetic and get_nsil_arithmetic fill it with result rows by
# substituting the placeholder lines `ff_nsl`, `NeurASP` and `nsil` via str.replace.
# Those tokens must therefore not appear anywhere else in the template.
# NOTE: `base` is rebound later for the Hitting Sets table.
base = r'''
\begin{table*}[]
\centering
\resizebox{0.8\linewidth}{!}{%
\begin{tabular}{@{}lllllll@{}}
\cmidrule(l){2-7}
\multicolumn{1}{c}{} & \multicolumn{3}{c}{\textbf{Addition}}                                             & \multicolumn{3}{c}{\textbf{E9P}}                                                  \\ \cmidrule(l){2-7}
Dataset \% & \multicolumn{1}{c}{\textbf{100}} & \multicolumn{1}{c}{\textbf{10}} & \multicolumn{1}{c}{\textbf{5}} & \multicolumn{1}{c}{\textbf{100}} & \multicolumn{1}{c}{\textbf{10}} & \multicolumn{1}{c}{\textbf{5}} \\ \midrule
ff_nsl
NeurASP
nsil
\end{tabular}
}
\caption{Non-Recursive Arithmetic naive baseline results. Standard error over 5 repeats is shown in parentheses.}
\label{tab:non_recursive_naive}
\end{table*}
'''

In [44]:
# Get FF-NSL/NeurASP rows for the arithmetic table
def get_baseline_arithmetic(base, results_dir='../../../examples/arithmetic/baselines/saved_results'):
    """Replace the ``ff_nsl`` / ``NeurASP`` placeholders in ``base`` with result rows.

    Each row cell is ``acc (std_err)`` to 4 decimals. The values are read from
    ``<results_dir>/<task>/<baseline>/<pct>/results.json`` under the ``'task'`` key,
    for tasks sum and e9p at 100/10/5 percent of the data.

    Args:
        base: LaTeX template containing the placeholder tokens.
        results_dir: root of the baseline results.

    Returns:
        The template with both placeholders substituted.
    """
    # placeholder token -> (row label, LaTeX row terminator)
    baselines = {
        'ff_nsl': ('FF-NSL', r'\\'),
        'NeurASP': ('NeurASP', r'\\ \midrule'),
    }
    tasks = ['sum', 'e9p']
    pcts = [100, 10, 5]
    for placeholder, (label, row_end) in baselines.items():
        cells = []
        for t in tasks:
            for p in pcts:
                with open(join(results_dir, t, placeholder, str(p), 'results.json')) as f:
                    task_res = json.load(f)['task']
                cells.append(f"{task_res['acc']:.4f} ({task_res['std_err']:.4f})")
        row = f'{label} & ' + ' & '.join(cells) + ' ' + row_end
        base = base.replace(placeholder, row)
    return base
        

In [59]:
def get_nsil_arithmetic(nsil_dir, base, epoch=20, num_repeats=5):
    """Replace the ``nsil`` placeholder in ``base`` with the NSIL result row.

    Each task/percentage pair (sum, e9p at 100/10/5 percent) becomes one cell. The
    cell is ``mean (std err)`` of the ``end_to_end_acc`` logged at ``epoch``, over the
    first ``num_repeats`` repeats (natural sort order) found in ``<nsil_dir>/<task>/<pct>``.

    Args:
        nsil_dir: root directory of the NSIL repeat results.
        base: LaTeX template containing the ``nsil`` token.
        epoch: epoch whose logged accuracy is used.
        num_repeats: maximum number of repeats to average.

    Returns:
        The template with the NSIL row substituted.
    """
    tasks = ['sum', 'e9p']
    pcts = [100, 10, 5]
    epoch_key = str(epoch)  # test_log.json is keyed by the epoch as a string
    cells = []
    for t in tasks:
        for p in pcts:
            run_dir = join(nsil_dir, t, str(p))
            repeats = sorted((r for r in os.listdir(run_dir) if r != '.DS_Store'),
                             key=natural_keys)
            accs = []
            for repeat in repeats[:num_repeats]:
                with open(join(run_dir, repeat, 'test_log.json'), 'r') as jf:
                    accs.append(json.load(jf)[epoch_key]['end_to_end_acc'])
            cells.append(f'{np.mean(accs):.4f} ({sem(accs):.4f})')

    nsil_row = 'NSIL & ' + ' & '.join(cells) + r' \\ \bottomrule'
    return base.replace('nsil', nsil_row)
        

In [60]:
# Fill the FF-NSL/NeurASP rows, then the NSIL row, and print the finished LaTeX table.
with_baselines = get_baseline_arithmetic(base=base)
print(get_nsil_arithmetic(nsil_dir='../../../examples/arithmetic/saved_results/repeats',
                          base=with_baselines))


\begin{table*}[]
\centering
\resizebox{0.8\linewidth}{!}{%
\begin{tabular}{@{}lllllll@{}}
\cmidrule(l){2-7}
\multicolumn{1}{c}{} & \multicolumn{3}{c}{\textbf{Addition}}                                             & \multicolumn{3}{c}{\textbf{E9P}}                                                  \\ \cmidrule(l){2-7}
Dataset \% & \multicolumn{1}{c}{\textbf{100}} & \multicolumn{1}{c}{\textbf{10}} & \multicolumn{1}{c}{\textbf{5}} & \multicolumn{1}{c}{\textbf{100}} & \multicolumn{1}{c}{\textbf{10}} & \multicolumn{1}{c}{\textbf{5}} \\ \midrule
FF-NSL & 0.9753 (0.0021) & 0.9362 (0.0029) & 0.9151 (0.0058) & 0.9809 (0.0016) & 0.9513 (0.0030) & 0.9346 (0.0051) \\
NeurASP & 0.9762 (0.0013) & 0.9492 (0.0016) & 0.9149 (0.0051) & 0.9797 (0.0015) & 0.9642 (0.0009) & 0.9500 (0.0014) \\ \midrule
NSIL & 0.9762 (0.0013) & 0.9449 (0.0025) & 0.8782 (0.0134) & 0.9816 (0.0009) & 0.9634 (0.0007) & 0.9510 (0.0016) \\ \bottomrule
\end{tabular}
}
\caption{Non-Recursive Arithmetic naive baseline results. Standa

## 3) Hitting Sets

In [79]:
# LaTeX table template for the Hitting Sets results.
# get_baseline_hitting_sets and get_nsil_hitting_sets fill it with result rows by
# substituting the placeholder lines `ff_nsl`, `NeurASP` and `nsil` via str.replace.
# Those tokens must therefore not appear anywhere else in the template.
# NOTE(review): the printed table further down shows a `\multicolumn{1}{c}{Dataset}`
# header that differs from this template's `Dataset &` header. The output is stale
# (this cell ran after the printing cell), so re-run to refresh it.
base = r'''
\begin{table*}[]
\centering
\resizebox{0.6\linewidth}{!}{%
\begin{tabular}{@{}lllll@{}}
\cmidrule(l){2-5}
\multicolumn{1}{c}{}        & \multicolumn{2}{c}{\textbf{HS}}                                       & \multicolumn{2}{c}{\textbf{CHS}}                                      \\ \cmidrule(l){2-5} 
Dataset & \multicolumn{1}{c}{\textbf{MNIST}} & \multicolumn{1}{c}{\textbf{FashionMNIST}} & \multicolumn{1}{c}{\textbf{MNIST}} & \multicolumn{1}{c}{\textbf{FashionMNIST}} \\ \midrule
ff_nsl
NeurASP
nsil
\end{tabular}
}
\caption{Hitting Sets naive baseline results. Standard error over 5 repeats is shown in parentheses.}
\label{tab:hitting_sets_naive}
\end{table*}
'''

In [76]:
# Get FF-NSL/NeurASP rows for the Hitting Sets table
def get_baseline_hitting_sets(base, results_dir='../../../examples/hitting_sets/baselines/saved_results'):
    """Replace the ``ff_nsl`` / ``NeurASP`` placeholders in ``base`` with result rows.

    Each row cell is ``acc (std_err)`` to 4 decimals. The values are read from
    ``<results_dir>/<task>/<baseline>/results.json`` under the ``'task'`` key, for
    tasks HS/CHS on MNIST and FashionMNIST.

    Args:
        base: LaTeX template containing the placeholder tokens.
        results_dir: root of the baseline results.

    Returns:
        The template with both placeholders substituted.
    """
    # placeholder token -> (row label, LaTeX row terminator)
    baselines = {
        'ff_nsl': ('FF-NSL', r'\\'),
        'NeurASP': ('NeurASP', r'\\ \midrule'),
    }
    tasks = ['HS_mnist', 'HS_fashion_mnist', 'CHS_mnist', 'CHS_fashion_mnist']
    for placeholder, (label, row_end) in baselines.items():
        cells = []
        for t in tasks:
            with open(join(results_dir, t, placeholder, 'results.json')) as f:
                task_res = json.load(f)['task']
            cells.append(f"{task_res['acc']:.4f} ({task_res['std_err']:.4f})")
        row = f'{label} & ' + ' & '.join(cells) + ' ' + row_end
        base = base.replace(placeholder, row)
    return base
        

In [77]:
def get_nsil_hitting_sets(nsil_dir, base, pct=100, epoch=20, num_repeats=5):
    """Replace the ``nsil`` placeholder in ``base`` with the NSIL Hitting Sets row.

    Each task (HS/CHS on MNIST and FashionMNIST) becomes one cell. The cell is
    ``mean (std err)`` of the ``end_to_end_acc`` logged at ``epoch``, over the first
    ``num_repeats`` repeats (natural sort order) in ``<nsil_dir>/<task>/<pct>``.

    Args:
        nsil_dir: root directory of the NSIL repeat results.
        base: LaTeX template containing the ``nsil`` token.
        pct: dataset-percentage sub-directory to read.
        epoch: epoch whose logged accuracy is used.
        num_repeats: maximum number of repeats to average.

    Returns:
        The template with the NSIL row substituted.
    """
    tasks = ['HS_mnist', 'HS_fashion_mnist', 'CHS_mnist', 'CHS_fashion_mnist']
    epoch_key = str(epoch)  # test_log.json is keyed by the epoch as a string
    cells = []
    for t in tasks:
        run_dir = join(nsil_dir, t, str(pct))
        repeats = sorted((r for r in os.listdir(run_dir) if r != '.DS_Store'),
                         key=natural_keys)
        accs = []
        for repeat in repeats[:num_repeats]:
            with open(join(run_dir, repeat, 'test_log.json'), 'r') as jf:
                accs.append(json.load(jf)[epoch_key]['end_to_end_acc'])
        cells.append(f'{np.mean(accs):.4f} ({sem(accs):.4f})')

    nsil_row = 'NSIL & ' + ' & '.join(cells) + r' \\ \bottomrule'
    return base.replace('nsil', nsil_row)
        

In [78]:
# Fill the FF-NSL/NeurASP rows, then the NSIL row, and print the finished LaTeX table.
with_baselines = get_baseline_hitting_sets(base=base)
print(get_nsil_hitting_sets(nsil_dir='../../../examples/hitting_sets/saved_results/repeats',
                            base=with_baselines))


\begin{table*}[]
\centering
\resizebox{0.6\linewidth}{!}{%
\begin{tabular}{@{}lllll@{}}
\cmidrule(l){2-5}
\multicolumn{1}{c}{}        & \multicolumn{2}{c}{\textbf{HS}}                                       & \multicolumn{2}{c}{\textbf{CHS}}                                      \\ \cmidrule(l){2-5} 
\multicolumn{1}{c}{Dataset} & \multicolumn{1}{c}{\textbf{MNIST}} & \multicolumn{1}{c}{\textbf{FashionMNIST}} & \multicolumn{1}{c}{\textbf{MNIST}} & \multicolumn{1}{c}{\textbf{FashionMNIST}} \\ \midrule
FF-NSL & 0.9937 (0.0017) & 0.8816 (0.0110) & 0.9962 (0.0012) & 0.9563 (0.0034) \\
NeurASP & 0.9981 (0.0013) & 0.8975 (0.0041) & 0.9994 (0.0006) & 0.9538 (0.0070) \\ \midrule
NSIL & 0.9962 (0.0012) & 0.8747 (0.0053) & 0.9981 (0.0013) & 0.9544 (0.0021) \\ \bottomrule
\end{tabular}
}
\caption{Hitting Sets naive baseline results. Standard error over 5 repeats is shown in parentheses.}
\label{tab:hitting_sets_naive}
\end{table*}

