In [1]:
import torch
import torch.nn as nn
import clingo
import ast
import pandas as pd
import os
import re
import itertools
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import Dataset
from os.path import join
from skimage import io
from torchvision import transforms
from sklearn.metrics import hamming_loss
from scipy.stats import sem

In [2]:
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

In [3]:
# Helper functions
def atof(text):
    try:
        retval = float(text)
    except ValueError:
        retval = text
    return retval

def natural_keys(text):
    '''
    alist.sort(key=natural_keys) sorts in human order
    http://nedbatchelder.com/blog/200712/human_sorting.html
    (See Toothy's implementation in the comments)
    float regex comes from https://stackoverflow.com/a/12643073/190597
    '''
    return [ atof(c) for c in re.split(r'[+-]?([0-9]+(?:[.][0-9]*)?|[.][0-9]+)', text) ]

In [4]:
class MNISTDigits(Dataset):
    def __init__(self, csv_file, root_dir):
        """
        Args:
            csv_file (string): Path to the csv file with label annotations.
            root_dir (string): Directory with all the images.
        """
        self.root_dir = root_dir
        self.label_file = pd.read_csv(csv_file)

    def __len__(self):
        return len(self.label_file)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        img_name = os.path.join(self.root_dir, str(idx) + '.jpg')
        image = io.imread(img_name, as_gray=True)
        label = self.label_file.iloc[idx, 1]
        image = mnist_transform(image)

        return image.float(), label - 1

In [5]:
class MNISTNet(nn.Module):
    def __init__(self, num_out):
        super(MNISTNet, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 6, 5),  # 6 is the output chanel size; 5 is the kernal size; 1 (chanel) 28 28 -> 6 24 24
            nn.MaxPool2d(2, 2),  # kernal size 2; stride size 2; 6 24 24 -> 6 12 12
            nn.ReLU(True),       # inplace=True means that it will modify the input directly thus save memory
            nn.Conv2d(6, 16, 5),  # 6 12 12 -> 16 8 8
            nn.MaxPool2d(2, 2),  # 16 8 8 -> 16 4 4
            nn.ReLU(True)
        )
        self.classifier = nn.Sequential(
            nn.Linear(16 * 4 * 4, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, num_out),
            nn.Softmax(1)
        )

    def forward(self, x):
        x = self.encoder(x)
        x = x.view(-1, 16 * 4 * 4)
        x = self.classifier(x)
        return x


In [6]:
def generate_hitting_sets(min_dv, max_dv):
    digits = list(range(min_dv, max_dv+1))
    k_2_combos = list(itertools.combinations(digits, 2))
    k_1_combos = [[d] for d in digits]
    hs = []
    for c in k_2_combos:
        hs.append(list(c))
    hs += k_1_combos
    return hs

In [9]:
mnist_test_base = '../../../data/mnist_1_to_5/test'
fashion_mnist_test_base = '../../../data/fashion_mnist_1_to_5/test'
mnist_test_digits = MNISTDigits(join(mnist_test_base, 'labels.csv'), mnist_test_base)
fashion_mnist_test_digits = MNISTDigits(join(fashion_mnist_test_base, 'labels.csv'), fashion_mnist_test_base)
possible_hitting_sets = generate_hitting_sets(1,5)

In [10]:
hs_learned_hyp = '''
:- ss_element(V1,V2); not hit(V1); elt(V2); ss(V1).
0 {hs(V1,V2) } 1 :- elt(V2); hs_index(V1).
hit(V1) :- hs(V3,V2); ss_element(V1,V2); hs_index(V3); elt(V2); ss(V1).
:- hs(V3,V1); hs(V3,V2); V1 != V2; hs_index(V3); elt(V2); elt(V1).
'''
chs_learned_hyp = '''
:- ss_element(3,V2); ss_element(V1,1); elt(V2); ss(V1).
:- ss_element(V1,V2); not hit(V1); elt(V2); ss(V1).
0 {hs(V1,V2) } 1 :- elt(V2); hs_index(V1).
hit(V1) :- hs(V3,V2); ss_element(V1,V2); hs_index(V3); elt(V2); ss(V1).
:- hs(V3,V1); hs(V3,V2); V1 != V2; hs_index(V3); elt(V2); elt(V1).
'''

In [11]:
def run_clingo(p):
    clingo_control = clingo.Control(["--warn=none", '0', '--project'])
    modls = []
    try:
        clingo_control.add("base", [], p)
    except RuntimeError:
        print('Clingo runtime error')
        print('Program: {0}'.format(p))
        sys.exit(1)
    clingo_control.ground([("base", [])])

    def on_model(m):
        modls.append(str(m))

    clingo_control.solve(on_model=on_model)
    return modls

In [12]:
def get_hs_from_clingo(elts, learned_hyp, ctx):
    # Create and run clingo program to get hitting sets
    clingo_prog = f'hs_index(1..2). ss(1..4). {elts} {learned_hyp}'
    clingo_prog += f'\n {ctx}\n #show hs/2.'
    models = run_clingo(clingo_prog)

    # Process models
    models_sets = []
    for m in models:
        # Ignore models with no hs(1,X).
        facts = [i for i in m.split(' ') if 'hs(' in i]
        if len(facts) != 1 or facts[0].split(',')[0].split('(')[1] == '1':
            list_rep = []
            for hsd in facts:
                list_rep.append(int(hsd.split(',')[1].split(')')[0]))
            if set(list_rep) not in models_sets:
                models_sets.append(set(list_rep))

    vector = [0]*len(possible_hitting_sets)
    for phs_idx, phs in enumerate(possible_hitting_sets):
        if set(phs) in models_sets:
            vector[phs_idx] = 1

    return vector

In [13]:
def run(test_file, net, mnist_images, learned_hyp):
    with open(test_file, 'r') as testf:
        testf = testf.readlines()[1:]
        hamming_losses = []
        for line in testf:
            ss, hitting_sets = line.rstrip().split('|')
            ss = ast.literal_eval(ss)
            nn_pred_ctx = ''
            ground_truth_ctx = ''
            ground_truth_elements = []
            nn_pred_elements = []
            for subset_id, subset in enumerate(ss):
                nn_preds_this_subset = []
                for image_id in subset:
                    
                    im = mnist_images[image_id]
                    im_label = im[1] + 1
                    
                    # Ground truth
                    ground_truth_ctx += f'ss_element({subset_id+1},{im_label}).\n'
                    ground_truth_elements.append(im_label)
                    
                    # Show image for debugging
                    # plt.imshow(im[0].permute(1, 2, 0))
                    # plt.show()
                    
                    # NN prediction
                    pred = net(im[0].unsqueeze(0)).argmax(dim=1)[0].item() + 1
                    nn_pred_ctx += f'ss_element({subset_id+1},{pred}).\n'
                    nn_pred_elements.append(pred)
            
            
            gt_elt_str = ' '.join([f'elt({d}).' for d in ground_truth_elements])
            pred_elt_str = ' '.join([f'elt({d}).' for d in nn_pred_elements])
            
            # Get true hs
            true_hs = get_hs_from_clingo(gt_elt_str, learned_hyp, ground_truth_ctx)
            
            # Get predicted hs
            predicted_hs = get_hs_from_clingo(pred_elt_str, learned_hyp, nn_pred_ctx)
            
            # Compute hamming loss
            hl = hamming_loss(true_hs, predicted_hs)
            hamming_losses.append(hl)

        return np.mean(hamming_losses), sem(hamming_losses)
            
                

In [18]:
def compute_hamming_loss(example_dir, learned_hyp, dataset='HS_mnist'):
    nsl_repeats_dir = example_dir+'/saved_results'
    
    if dataset is not None:
        nsl_repeats_dir = nsl_repeats_dir+'/repeats/'+dataset
    else:
        nsl_repeats_dir = nsl_repeats_dir+'/repeats'
    
    # We get mean hamming loss after 20 epochs
    hamming_losses = []
    nsl_dir = nsl_repeats_dir+'/'+str(100)
    repeats = os.listdir(nsl_dir)
    repeats = [r for r in repeats if r != '.DS_Store']
    repeats.sort(key=natural_keys)
    
    for idx, i in enumerate(repeats):
        # Read network
        network_weights_path = nsl_dir+'/'+i+'/networks/net_digit_iteration_20.pt'
        net = MNISTNet(5)
        net.load_state_dict(torch.load(network_weights_path, map_location=torch.device('cpu')))
        net.eval()

        # Get NN predictions for each image on test set
        if 'HS' in dataset:
            task = 'hs'
        else:
            task = 'chs'
        if 'mnist' in dataset:
            images = mnist_test_digits
            test_file = example_dir + f'data/{task}/mnist_test_with_hs.csv'
        else:
            images = fashion_mnist_test_digits
            test_file = example_dir + f'data/{task}/fashion_mnist/test_with_hs.csv'
        hl, _ = run(test_file, net, images, learned_hyp)
        hamming_losses.append(hl)
        
    # Compute mean and std err across all repeats
    hamming_loss_mean = np.mean(hamming_losses)
    hamming_loss_err =  sem(hamming_losses)
        
    return hamming_loss_mean, hamming_loss_err

## HS

In [19]:
# example
example_dir = '../../../examples/hitting_sets/'
lh = hs_learned_hyp

### MNIST

In [20]:
compute_hamming_loss(example_dir, lh, dataset='HS_mnist')

(0.002985232067510549, 0.0001943696188940777)

### FashionMNIST

In [21]:
compute_hamming_loss(example_dir, lh, dataset='HS_fashion_mnist')

(0.14970464135021097, 0.0007783818553406404)

## CHS

In [23]:
# example
chs_lh = chs_learned_hyp

### MNIST

In [24]:
compute_hamming_loss(example_dir, chs_lh, dataset='CHS_mnist')

(0.002721518987341772, 0.00012611877535556893)

### FashionMNIST

In [25]:
compute_hamming_loss(example_dir, chs_lh, dataset='CHS_fashion_mnist')

(0.12614978902953583, 0.002098039870738607)