Import `pyswip` and consult the Prolog background knowledge base.

In [1]:
from pyswip import Prolog
prolog = Prolog()
prolog.consult('mnist_sum_2.pl')

Test if the `CLP(FD)`-based abduction in the background knowledge base works correctly.

In [2]:
target = 9
for soln in prolog.query("abduce(X, 5, {})".format(target)):
    print("Sum of", soln["X"], "equals {}.".format(target))

Sum of [0, 0, 0, 0, 9] equals 9.
Sum of [0, 0, 0, 1, 8] equals 9.
Sum of [0, 0, 0, 2, 7] equals 9.
Sum of [0, 0, 0, 3, 6] equals 9.
Sum of [0, 0, 0, 4, 5] equals 9.
Sum of [0, 0, 0, 5, 4] equals 9.
Sum of [0, 0, 0, 6, 3] equals 9.
Sum of [0, 0, 0, 7, 2] equals 9.
Sum of [0, 0, 0, 8, 1] equals 9.
Sum of [0, 0, 0, 9, 0] equals 9.
Sum of [0, 0, 1, 0, 8] equals 9.
Sum of [0, 0, 1, 1, 7] equals 9.
Sum of [0, 0, 1, 2, 6] equals 9.
Sum of [0, 0, 1, 3, 5] equals 9.
Sum of [0, 0, 1, 4, 4] equals 9.
Sum of [0, 0, 1, 5, 3] equals 9.
Sum of [0, 0, 1, 6, 2] equals 9.
Sum of [0, 0, 1, 7, 1] equals 9.
Sum of [0, 0, 1, 8, 0] equals 9.
Sum of [0, 0, 2, 0, 7] equals 9.
Sum of [0, 0, 2, 1, 6] equals 9.
Sum of [0, 0, 2, 2, 5] equals 9.
Sum of [0, 0, 2, 3, 4] equals 9.
Sum of [0, 0, 2, 4, 3] equals 9.
Sum of [0, 0, 2, 5, 2] equals 9.
Sum of [0, 0, 2, 6, 1] equals 9.
Sum of [0, 0, 2, 7, 0] equals 9.
Sum of [0, 0, 3, 0, 6] equals 9.
Sum of [0, 0, 3, 1, 5] equals 9.
Sum of [0, 0, 3, 2, 4] equals 9.
Sum of [0,

# Abductive Learning

Now let's try to implement the MNIST sum learning algorithm using the Abductive Learning framework.

### Dataset Generation

Directly copy the codes from the `data_generator.ipynb` notebook file.

In [3]:
import numpy as np
import torch
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
import matplotlib.pyplot as plt

transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
    ])
dataset1 = datasets.MNIST('data', train=True, download=True,
                            transform=transform)
dataset2 = datasets.MNIST('data', train=False,
                            transform=transform)

device = torch.device("cpu")

digit_groups_train = {0:[], 1:[], 2:[], 3:[], 4:[], 5:[], 6:[], 7:[], 8:[], 9:[]}
digit_groups_test = {0:[], 1:[], 2:[], 3:[], 4:[], 5:[], 6:[], 7:[], 8:[], 9:[]}

for i in range(len(dataset1)): 
    digit_groups_train[int(dataset1.targets[i])].append(i)
for i in range(len(dataset2)): 
    digit_groups_test[int(dataset2.targets[i])].append(i)

In [4]:
class MNIST_Sum_2:
    def __init__(self, num_exs, list_len, digit_groups):
        self.targets = []
        self.img_indices = []
        self.ground_truth = []
        self.length = num_exs
        for i in range(num_exs):
            # sampling two numbers from 0 to 9
            sampled_digits = np.random.choice(10, list_len)
            self.ground_truth.append(list(sampled_digits))

            # using the sum of the sampled digits as the target
            self.targets.append(sum(sampled_digits))
            ids = []
            for j in range(len(sampled_digits)):
                # get the j-th digits
                digit = sampled_digits[j]
                # total number of the images of the digit
                ids.append(np.random.choice(digit_groups[digit]))
            self.img_indices.append(ids)

# Generate the training and test dataset for MNIST Sum task
mnist_sum_data_train = MNIST_Sum_2(600, 5, digit_groups_train)
mnist_sum_data_test = MNIST_Sum_2(600, 5, digit_groups_test)

### The Machine Learning Part

A neural network for image classification (copying from the `mnist_network.ipynb` example).

In [5]:
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

device = torch.device("cpu")

class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)


def conv_net(outdim, *args, **kwargs):
    return nn.Sequential(
        nn.Conv2d(1, 32, 3, 1),
        nn.ReLU(),
        nn.Conv2d(32, 64, 3, 1),
        nn.ReLU(),
        nn.MaxPool2d(2),
        nn.Dropout(0.25),
        Flatten(),
        nn.Linear(9216, 128),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(128, outdim),
        nn.LogSoftmax(dim=1)
    )


def auto_enc(outdim, *args, **kwargs):
    return nn.Sequential(
        nn.Linear(outdim, 128),
        nn.ReLU(),
        nn.Linear(128, 784)
    )


def mlp(indim, outdim, *args, **kwargs):
    return nn.Sequential(
        nn.Linear(indim, 32),
        nn.ReLU(),
        nn.Linear(32, outdim),
        nn.LogSoftmax(dim=1),
    )

class LSTM(nn.Module):
    """A (Bi)LSTM Model.

    Attributes:
        num_layers: the number of LSTM layers (number of stacked LSTM models) in the network.
        in_dim: the size of the input sample.
        hidden_dim: the size of the hidden layers.
        out_dim: the size of the output.
        activation: the activation function.
        bidirectional: the flag for bidirectional LSTM
        dropout: the dropout rate if num_layers > 1
    """

    def __init__(self, num_layers, in_dim, hidden_dim, out_dim,
                 bidirectional=False, dropout=0):
        super().__init__()
        self.num_layers = num_layers
        self.in_dim = in_dim
        self.hidden_dim = hidden_dim
        self.out_dim = out_dim
        self.bidirectional = bidirectional
        self.dropout = dropout

        self.lstm = nn.LSTM(self.in_dim,
                            self.hidden_dim,
                            num_layers=self.num_layers,
                            bidirectional=self.bidirectional,
                            dropout=self.dropout,
                            batch_first=True)
        fc_dim = self.hidden_dim * 2 if self.bidirectional else self.hidden_dim
        self.fc = nn.Linear(fc_dim, self.out_dim)

    def forward(self, inputs):
        lstm_out, _ = self.lstm(inputs)
        outputs = self.fc(lstm_out[:, -1, :])
        outputs = torch.sigmoid(outputs)
        return outputs

    def loss_function(self, pred, y):
        return F.binary_cross_entropy(pred, y.view(y.shape[0], -1))


class Net(nn.Module):
    outdim = 10

    def __init__(self, outdim):
        super(Net, self).__init__()
        self.outdim = outdim
        self.enc = conv_net(outdim)

    def forward(self, x):
        output = self.enc(x)
        return output

    def loss_function(self, pred, y):
        return F.nll_loss(pred, y)

def train(model, device, train_loader, optimizer, epoch,
          log_interval=1000, dry_run=False):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = model.loss_function(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            if dry_run:
                break

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # sum up batch loss
            test_loss += model.loss_function(output, target).item()
            # get the index of the max log-probability
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('-- Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


### The Logic Abduction Part

It involves the following steps:
1. Given the target (label, i.e., the sum of the two images), using `pyswip` to abduce possible pseudo-labels for them.
2. Calculate the probability of each pair of pseudo-labels.
3. Return the most probable pseudo-labels to retrain the neural network.

_Remark_: For more complicated problems, a better way of searching for the best pseudo-labels is required.

##### Abducing possible pseudo-labels given the sum

- `pl` is the Prolog instance that consulted `mnist_sum_2.pl`;
- `target` is the sum of `5` images.

In [6]:
def abduce(pl, target):
    # This abduce/2 function is defined in "mnist_sum.pl"
    ans = [];
    for soln in pl.query("abduce(X, 5, {})".format(target)):
        ans.append(soln["X"])
    if len(ans) > 0:
        return ans
    else:
        return None

Test the `abduce` function.

In [7]:
print(abduce(prolog, 15))

[[0, 0, 0, 6, 9], [0, 0, 0, 7, 8], [0, 0, 0, 8, 7], [0, 0, 0, 9, 6], [0, 0, 1, 5, 9], [0, 0, 1, 6, 8], [0, 0, 1, 7, 7], [0, 0, 1, 8, 6], [0, 0, 1, 9, 5], [0, 0, 2, 4, 9], [0, 0, 2, 5, 8], [0, 0, 2, 6, 7], [0, 0, 2, 7, 6], [0, 0, 2, 8, 5], [0, 0, 2, 9, 4], [0, 0, 3, 3, 9], [0, 0, 3, 4, 8], [0, 0, 3, 5, 7], [0, 0, 3, 6, 6], [0, 0, 3, 7, 5], [0, 0, 3, 8, 4], [0, 0, 3, 9, 3], [0, 0, 4, 2, 9], [0, 0, 4, 3, 8], [0, 0, 4, 4, 7], [0, 0, 4, 5, 6], [0, 0, 4, 6, 5], [0, 0, 4, 7, 4], [0, 0, 4, 8, 3], [0, 0, 4, 9, 2], [0, 0, 5, 1, 9], [0, 0, 5, 2, 8], [0, 0, 5, 3, 7], [0, 0, 5, 4, 6], [0, 0, 5, 5, 5], [0, 0, 5, 6, 4], [0, 0, 5, 7, 3], [0, 0, 5, 8, 2], [0, 0, 5, 9, 1], [0, 0, 6, 0, 9], [0, 0, 6, 1, 8], [0, 0, 6, 2, 7], [0, 0, 6, 3, 6], [0, 0, 6, 4, 5], [0, 0, 6, 5, 4], [0, 0, 6, 6, 3], [0, 0, 6, 7, 2], [0, 0, 6, 8, 1], [0, 0, 6, 9, 0], [0, 0, 7, 0, 8], [0, 0, 7, 1, 7], [0, 0, 7, 2, 6], [0, 0, 7, 3, 5], [0, 0, 7, 4, 4], [0, 0, 7, 5, 3], [0, 0, 7, 6, 2], [0, 0, 7, 7, 1], [0, 0, 7, 8, 0], [0, 0, 8, 0, 

##### The Abductive Learning Procedure

Importing useful libraries and set the default neural network training parameters within the Abductive Learning Process.

In [8]:
from tqdm import tqdm

nn_train_kwargs = {'batch_size': 64, 'shuffle': True}
nn_epoch = 2

nn_test_loader = torch.utils.data.DataLoader(dataset2, **nn_train_kwargs)

Useful functions for abduction:
1. `get_mnist_imgs`: given `indices`, sample a subset of images from `dataset` (such as the `MNIST` dataset).
2. `best_pseudo_label`: given a set of abduced possible pseudo-labels and the pseudo-label distribution, return the most probable pseudo-label combination for each image. 

In [9]:
def get_mnist_imgs(dataset, indices, use_cuda=False):
    """
    Given get the image tensor from mnist dataset by indices
    """

    n = len(indices)
    img_tensor, tgt = dataset[indices[0]]
    img_tensor = torch.reshape(img_tensor, (1, 1, 28, 28))
    targets = [tgt]
    i = 1
    while i < n:
        img, tgt = dataset[indices[i]]
        img = torch.reshape(img, (1, 1, 28, 28))
        img_tensor = torch.cat((img_tensor, img), 0)
        targets.append(tgt)
        i = i + 1
    if use_cuda:
        img_tensor = img_tensor.to(torch.device("cuda"))
    return img_tensor, targets

def best_pseudo_label(pseudo_label_lists, pseudo_label_scores):
    best_score = -100000.0
    best_combi = np.zeros(pseudo_label_scores.shape[0])
    probabilities = np.exp(pseudo_label_scores)
    for label_combi in pseudo_label_lists:
        # because the scores are log_softmax, the log probability can be calculated as sum
        score = 1.0
        for j in range(len(label_combi)):
            score = score*probabilities[j, label_combi[j]]
        if score >= best_score:
            best_score = score
            best_combi = label_combi
    return best_combi, score

Main procedure for abductive Learning. Given a machine learning `model` and a prolog instance `pl` with `dataset`, it does the following steps:
1. Using `model` to predict the pseudo-label probabilistic distribution of `dataset`;
2. Finding the best pseudo-label combination considering both the abduction result from `pl` and the pseudo-label probabilistic distribution;
3. Retrain the neural network with the abduced pseudo-labels.

In [10]:
def ABL_main(model, pl, dataset, optimizer=None, scheduler=None):
    # number of examples
    num_examples = dataset.length
    abduced_data_ids = []
    abduced_labels = []
    ground_truth_labels = []

    # start abduction
    for i in tqdm (range(num_examples), desc="Abducing..."):
        target = int(dataset.targets[i])
        possible_pseudo_labels = abduce(pl, target)
        if possible_pseudo_labels is not None:
            # reshape the tensor of the two MNIST images to match NN model's input dimensions
            img_indices = dataset.img_indices[i]
            imgs, _ = get_mnist_imgs(dataset1, img_indices, use_cuda=False)

            pseudo_label_distribution = model(imgs).detach().numpy()

            # find the pseudo-labels with the maximum likelihood
            abduced_pseudo_labels, _ = best_pseudo_label(possible_pseudo_labels, pseudo_label_distribution)

            # for abduced dataset
            abduced_data_ids = abduced_data_ids + img_indices
            abduced_labels = abduced_labels + abduced_pseudo_labels
            ground_truth_labels = ground_truth_labels + dataset.ground_truth[i]

    # changing the training data labels to the abduced labels
    for i, img in enumerate(abduced_data_ids):
        dataset1.targets[img] = abduced_labels[i]
    
    print(ground_truth_labels)
    print(abduced_labels)
    abduction_accuracy = np.sum(np.array(ground_truth_labels) == np.array(abduced_labels))/len(abduced_labels)

    # making new dataset with abduced labels
    abduced_data = torch.utils.data.Subset(dataset1, abduced_data_ids)

    # training the neural network model
    abduced_train_loader = torch.utils.data.DataLoader(abduced_data, batch_size=64)

    for epoch in range(1, nn_epoch + 1):
        train(model, device, abduced_train_loader, optimizer, epoch)
        print("Abduction accuracy: ", abduction_accuracy)
        scheduler.step()
    test(model, device, nn_test_loader)

### Running Experiment

Initialise model and optimizer.

In [11]:
model = Net(outdim=10).to(device)
test(model, device, nn_test_loader)

optimizer = optim.Adadelta(model.parameters(), lr=1.0)
# optimizer = optim.SGD(model.parameters(), lr=1.0)
scheduler = StepLR(optimizer, step_size=1, gamma=0.7)

-- Test set: Average loss: 0.0363, Accuracy: 1128/10000 (11%)


##### Run Abductive Learning without any pre-train.

In [12]:
ABL_epochs = 20
for epoch in range(ABL_epochs):
    ABL_main(model, prolog, mnist_sum_data_train, optimizer=optimizer, scheduler=scheduler)

Abducing...: 100%|██████████| 600/600 [02:52<00:00,  3.49it/s]


[2, 8, 4, 5, 4, 3, 7, 8, 6, 2, 6, 5, 4, 5, 2, 6, 8, 1, 0, 1, 4, 7, 7, 9, 2, 4, 7, 3, 2, 1, 8, 3, 3, 9, 3, 4, 2, 0, 6, 2, 2, 0, 3, 8, 5, 7, 1, 9, 2, 0, 7, 9, 8, 8, 5, 8, 8, 3, 7, 6, 9, 0, 3, 8, 3, 3, 3, 4, 6, 5, 1, 7, 7, 0, 0, 3, 7, 2, 3, 4, 5, 6, 3, 0, 6, 4, 7, 5, 2, 2, 7, 7, 6, 1, 7, 7, 8, 7, 3, 8, 1, 2, 0, 3, 4, 5, 2, 1, 1, 7, 4, 1, 1, 5, 8, 1, 5, 5, 1, 5, 6, 8, 6, 1, 5, 1, 8, 0, 4, 6, 2, 8, 4, 5, 7, 4, 2, 3, 3, 4, 3, 5, 5, 4, 3, 6, 3, 7, 5, 3, 5, 6, 4, 1, 9, 2, 3, 8, 9, 5, 2, 2, 7, 5, 7, 4, 2, 4, 5, 7, 7, 3, 5, 5, 1, 3, 1, 1, 4, 9, 3, 2, 4, 3, 6, 3, 8, 7, 2, 3, 9, 5, 7, 7, 7, 3, 5, 9, 7, 3, 3, 8, 1, 9, 8, 3, 7, 6, 8, 8, 5, 5, 4, 7, 0, 8, 4, 7, 9, 4, 1, 0, 3, 5, 1, 4, 0, 3, 5, 1, 3, 2, 3, 4, 5, 1, 0, 0, 4, 7, 0, 4, 6, 6, 2, 4, 9, 6, 1, 9, 0, 3, 6, 9, 9, 4, 6, 5, 7, 1, 5, 7, 6, 7, 7, 5, 9, 9, 9, 6, 5, 5, 9, 2, 7, 9, 9, 8, 4, 1, 6, 0, 5, 7, 1, 0, 2, 5, 0, 2, 5, 2, 3, 3, 7, 8, 0, 8, 4, 6, 2, 1, 9, 4, 7, 2, 9, 1, 6, 6, 0, 4, 3, 6, 0, 2, 2, 4, 9, 7, 7, 3, 7, 8, 3, 8, 9, 5, 3, 0, 6, 7, 4, 

Abducing...: 100%|██████████| 600/600 [02:50<00:00,  3.51it/s]


[2, 8, 4, 5, 4, 3, 7, 8, 6, 2, 6, 5, 4, 5, 2, 6, 8, 1, 0, 1, 4, 7, 7, 9, 2, 4, 7, 3, 2, 1, 8, 3, 3, 9, 3, 4, 2, 0, 6, 2, 2, 0, 3, 8, 5, 7, 1, 9, 2, 0, 7, 9, 8, 8, 5, 8, 8, 3, 7, 6, 9, 0, 3, 8, 3, 3, 3, 4, 6, 5, 1, 7, 7, 0, 0, 3, 7, 2, 3, 4, 5, 6, 3, 0, 6, 4, 7, 5, 2, 2, 7, 7, 6, 1, 7, 7, 8, 7, 3, 8, 1, 2, 0, 3, 4, 5, 2, 1, 1, 7, 4, 1, 1, 5, 8, 1, 5, 5, 1, 5, 6, 8, 6, 1, 5, 1, 8, 0, 4, 6, 2, 8, 4, 5, 7, 4, 2, 3, 3, 4, 3, 5, 5, 4, 3, 6, 3, 7, 5, 3, 5, 6, 4, 1, 9, 2, 3, 8, 9, 5, 2, 2, 7, 5, 7, 4, 2, 4, 5, 7, 7, 3, 5, 5, 1, 3, 1, 1, 4, 9, 3, 2, 4, 3, 6, 3, 8, 7, 2, 3, 9, 5, 7, 7, 7, 3, 5, 9, 7, 3, 3, 8, 1, 9, 8, 3, 7, 6, 8, 8, 5, 5, 4, 7, 0, 8, 4, 7, 9, 4, 1, 0, 3, 5, 1, 4, 0, 3, 5, 1, 3, 2, 3, 4, 5, 1, 0, 0, 4, 7, 0, 4, 6, 6, 2, 4, 9, 6, 1, 9, 0, 3, 6, 9, 9, 4, 6, 5, 7, 1, 5, 7, 6, 7, 7, 5, 9, 9, 9, 6, 5, 5, 9, 2, 7, 9, 9, 8, 4, 1, 6, 0, 5, 7, 1, 0, 2, 5, 0, 2, 5, 2, 3, 3, 7, 8, 0, 8, 4, 6, 2, 1, 9, 4, 7, 2, 9, 1, 6, 6, 0, 4, 3, 6, 0, 2, 2, 4, 9, 7, 7, 3, 7, 8, 3, 8, 9, 5, 3, 0, 6, 7, 4, 

Abducing...: 100%|██████████| 600/600 [02:51<00:00,  3.49it/s]


[2, 8, 4, 5, 4, 3, 7, 8, 6, 2, 6, 5, 4, 5, 2, 6, 8, 1, 0, 1, 4, 7, 7, 9, 2, 4, 7, 3, 2, 1, 8, 3, 3, 9, 3, 4, 2, 0, 6, 2, 2, 0, 3, 8, 5, 7, 1, 9, 2, 0, 7, 9, 8, 8, 5, 8, 8, 3, 7, 6, 9, 0, 3, 8, 3, 3, 3, 4, 6, 5, 1, 7, 7, 0, 0, 3, 7, 2, 3, 4, 5, 6, 3, 0, 6, 4, 7, 5, 2, 2, 7, 7, 6, 1, 7, 7, 8, 7, 3, 8, 1, 2, 0, 3, 4, 5, 2, 1, 1, 7, 4, 1, 1, 5, 8, 1, 5, 5, 1, 5, 6, 8, 6, 1, 5, 1, 8, 0, 4, 6, 2, 8, 4, 5, 7, 4, 2, 3, 3, 4, 3, 5, 5, 4, 3, 6, 3, 7, 5, 3, 5, 6, 4, 1, 9, 2, 3, 8, 9, 5, 2, 2, 7, 5, 7, 4, 2, 4, 5, 7, 7, 3, 5, 5, 1, 3, 1, 1, 4, 9, 3, 2, 4, 3, 6, 3, 8, 7, 2, 3, 9, 5, 7, 7, 7, 3, 5, 9, 7, 3, 3, 8, 1, 9, 8, 3, 7, 6, 8, 8, 5, 5, 4, 7, 0, 8, 4, 7, 9, 4, 1, 0, 3, 5, 1, 4, 0, 3, 5, 1, 3, 2, 3, 4, 5, 1, 0, 0, 4, 7, 0, 4, 6, 6, 2, 4, 9, 6, 1, 9, 0, 3, 6, 9, 9, 4, 6, 5, 7, 1, 5, 7, 6, 7, 7, 5, 9, 9, 9, 6, 5, 5, 9, 2, 7, 9, 9, 8, 4, 1, 6, 0, 5, 7, 1, 0, 2, 5, 0, 2, 5, 2, 3, 3, 7, 8, 0, 8, 4, 6, 2, 1, 9, 4, 7, 2, 9, 1, 6, 6, 0, 4, 3, 6, 0, 2, 2, 4, 9, 7, 7, 3, 7, 8, 3, 8, 9, 5, 3, 0, 6, 7, 4, 

Abducing...: 100%|██████████| 600/600 [02:52<00:00,  3.47it/s]


[2, 8, 4, 5, 4, 3, 7, 8, 6, 2, 6, 5, 4, 5, 2, 6, 8, 1, 0, 1, 4, 7, 7, 9, 2, 4, 7, 3, 2, 1, 8, 3, 3, 9, 3, 4, 2, 0, 6, 2, 2, 0, 3, 8, 5, 7, 1, 9, 2, 0, 7, 9, 8, 8, 5, 8, 8, 3, 7, 6, 9, 0, 3, 8, 3, 3, 3, 4, 6, 5, 1, 7, 7, 0, 0, 3, 7, 2, 3, 4, 5, 6, 3, 0, 6, 4, 7, 5, 2, 2, 7, 7, 6, 1, 7, 7, 8, 7, 3, 8, 1, 2, 0, 3, 4, 5, 2, 1, 1, 7, 4, 1, 1, 5, 8, 1, 5, 5, 1, 5, 6, 8, 6, 1, 5, 1, 8, 0, 4, 6, 2, 8, 4, 5, 7, 4, 2, 3, 3, 4, 3, 5, 5, 4, 3, 6, 3, 7, 5, 3, 5, 6, 4, 1, 9, 2, 3, 8, 9, 5, 2, 2, 7, 5, 7, 4, 2, 4, 5, 7, 7, 3, 5, 5, 1, 3, 1, 1, 4, 9, 3, 2, 4, 3, 6, 3, 8, 7, 2, 3, 9, 5, 7, 7, 7, 3, 5, 9, 7, 3, 3, 8, 1, 9, 8, 3, 7, 6, 8, 8, 5, 5, 4, 7, 0, 8, 4, 7, 9, 4, 1, 0, 3, 5, 1, 4, 0, 3, 5, 1, 3, 2, 3, 4, 5, 1, 0, 0, 4, 7, 0, 4, 6, 6, 2, 4, 9, 6, 1, 9, 0, 3, 6, 9, 9, 4, 6, 5, 7, 1, 5, 7, 6, 7, 7, 5, 9, 9, 9, 6, 5, 5, 9, 2, 7, 9, 9, 8, 4, 1, 6, 0, 5, 7, 1, 0, 2, 5, 0, 2, 5, 2, 3, 3, 7, 8, 0, 8, 4, 6, 2, 1, 9, 4, 7, 2, 9, 1, 6, 6, 0, 4, 3, 6, 0, 2, 2, 4, 9, 7, 7, 3, 7, 8, 3, 8, 9, 5, 3, 0, 6, 7, 4, 

Abducing...: 100%|██████████| 600/600 [02:52<00:00,  3.48it/s]


[2, 8, 4, 5, 4, 3, 7, 8, 6, 2, 6, 5, 4, 5, 2, 6, 8, 1, 0, 1, 4, 7, 7, 9, 2, 4, 7, 3, 2, 1, 8, 3, 3, 9, 3, 4, 2, 0, 6, 2, 2, 0, 3, 8, 5, 7, 1, 9, 2, 0, 7, 9, 8, 8, 5, 8, 8, 3, 7, 6, 9, 0, 3, 8, 3, 3, 3, 4, 6, 5, 1, 7, 7, 0, 0, 3, 7, 2, 3, 4, 5, 6, 3, 0, 6, 4, 7, 5, 2, 2, 7, 7, 6, 1, 7, 7, 8, 7, 3, 8, 1, 2, 0, 3, 4, 5, 2, 1, 1, 7, 4, 1, 1, 5, 8, 1, 5, 5, 1, 5, 6, 8, 6, 1, 5, 1, 8, 0, 4, 6, 2, 8, 4, 5, 7, 4, 2, 3, 3, 4, 3, 5, 5, 4, 3, 6, 3, 7, 5, 3, 5, 6, 4, 1, 9, 2, 3, 8, 9, 5, 2, 2, 7, 5, 7, 4, 2, 4, 5, 7, 7, 3, 5, 5, 1, 3, 1, 1, 4, 9, 3, 2, 4, 3, 6, 3, 8, 7, 2, 3, 9, 5, 7, 7, 7, 3, 5, 9, 7, 3, 3, 8, 1, 9, 8, 3, 7, 6, 8, 8, 5, 5, 4, 7, 0, 8, 4, 7, 9, 4, 1, 0, 3, 5, 1, 4, 0, 3, 5, 1, 3, 2, 3, 4, 5, 1, 0, 0, 4, 7, 0, 4, 6, 6, 2, 4, 9, 6, 1, 9, 0, 3, 6, 9, 9, 4, 6, 5, 7, 1, 5, 7, 6, 7, 7, 5, 9, 9, 9, 6, 5, 5, 9, 2, 7, 9, 9, 8, 4, 1, 6, 0, 5, 7, 1, 0, 2, 5, 0, 2, 5, 2, 3, 3, 7, 8, 0, 8, 4, 6, 2, 1, 9, 4, 7, 2, 9, 1, 6, 6, 0, 4, 3, 6, 0, 2, 2, 4, 9, 7, 7, 3, 7, 8, 3, 8, 9, 5, 3, 0, 6, 7, 4, 

Abducing...: 100%|██████████| 600/600 [02:52<00:00,  3.48it/s]


[2, 8, 4, 5, 4, 3, 7, 8, 6, 2, 6, 5, 4, 5, 2, 6, 8, 1, 0, 1, 4, 7, 7, 9, 2, 4, 7, 3, 2, 1, 8, 3, 3, 9, 3, 4, 2, 0, 6, 2, 2, 0, 3, 8, 5, 7, 1, 9, 2, 0, 7, 9, 8, 8, 5, 8, 8, 3, 7, 6, 9, 0, 3, 8, 3, 3, 3, 4, 6, 5, 1, 7, 7, 0, 0, 3, 7, 2, 3, 4, 5, 6, 3, 0, 6, 4, 7, 5, 2, 2, 7, 7, 6, 1, 7, 7, 8, 7, 3, 8, 1, 2, 0, 3, 4, 5, 2, 1, 1, 7, 4, 1, 1, 5, 8, 1, 5, 5, 1, 5, 6, 8, 6, 1, 5, 1, 8, 0, 4, 6, 2, 8, 4, 5, 7, 4, 2, 3, 3, 4, 3, 5, 5, 4, 3, 6, 3, 7, 5, 3, 5, 6, 4, 1, 9, 2, 3, 8, 9, 5, 2, 2, 7, 5, 7, 4, 2, 4, 5, 7, 7, 3, 5, 5, 1, 3, 1, 1, 4, 9, 3, 2, 4, 3, 6, 3, 8, 7, 2, 3, 9, 5, 7, 7, 7, 3, 5, 9, 7, 3, 3, 8, 1, 9, 8, 3, 7, 6, 8, 8, 5, 5, 4, 7, 0, 8, 4, 7, 9, 4, 1, 0, 3, 5, 1, 4, 0, 3, 5, 1, 3, 2, 3, 4, 5, 1, 0, 0, 4, 7, 0, 4, 6, 6, 2, 4, 9, 6, 1, 9, 0, 3, 6, 9, 9, 4, 6, 5, 7, 1, 5, 7, 6, 7, 7, 5, 9, 9, 9, 6, 5, 5, 9, 2, 7, 9, 9, 8, 4, 1, 6, 0, 5, 7, 1, 0, 2, 5, 0, 2, 5, 2, 3, 3, 7, 8, 0, 8, 4, 6, 2, 1, 9, 4, 7, 2, 9, 1, 6, 6, 0, 4, 3, 6, 0, 2, 2, 4, 9, 7, 7, 3, 7, 8, 3, 8, 9, 5, 3, 0, 6, 7, 4, 

Abducing...:   5%|▌         | 30/600 [00:08<02:48,  3.39it/s]


KeyboardInterrupt: 

##### Run Abductive Learning with one-shot pre-train

Sample a one-shot training dataset.

In [13]:
import random

# reset the machine learning model
model = Net(outdim=10).to(device)

# reset dataset1 to reset the labels for one-shot training, 
# since the previous abductive learning process has changed 
# the ground truth labels in dataset1
dataset1 = datasets.MNIST('data', train=True, download=True,
                            transform=transform)

n_samples = 5
few_shot_indices = []

for i in range(10):
    few_shot_indices = few_shot_indices + \
        random.sample(digit_groups_train[i], n_samples)

# few_shot_indices = random.sample(all_img_indices, n_samples)

sup_imgs_train = torch.utils.data.Subset(dataset1, few_shot_indices)

sup_train_loader = torch.utils.data.DataLoader(
    sup_imgs_train, **nn_train_kwargs)

optimizer = optim.Adadelta(model.parameters(), lr=1.0)
scheduler = StepLR(optimizer, step_size=1, gamma=0.7)

for epoch in range(1, 5):
    train(model, device, sup_train_loader,
        optimizer, epoch)
    #test(model, device, nn_test_loader)
    scheduler.step()
test(model, device, nn_test_loader)

-- Test set: Average loss: 0.0262, Accuracy: 5630/10000 (56%)


In [14]:
ABL_epochs = 10
for epoch in range(ABL_epochs):
    ABL_main(model, prolog, mnist_sum_data_train, optimizer=optimizer, scheduler=scheduler)

Abducing...: 100%|██████████| 600/600 [02:52<00:00,  3.48it/s]


[2, 8, 4, 5, 4, 3, 7, 8, 6, 2, 6, 5, 4, 5, 2, 6, 8, 1, 0, 1, 4, 7, 7, 9, 2, 4, 7, 3, 2, 1, 8, 3, 3, 9, 3, 4, 2, 0, 6, 2, 2, 0, 3, 8, 5, 7, 1, 9, 2, 0, 7, 9, 8, 8, 5, 8, 8, 3, 7, 6, 9, 0, 3, 8, 3, 3, 3, 4, 6, 5, 1, 7, 7, 0, 0, 3, 7, 2, 3, 4, 5, 6, 3, 0, 6, 4, 7, 5, 2, 2, 7, 7, 6, 1, 7, 7, 8, 7, 3, 8, 1, 2, 0, 3, 4, 5, 2, 1, 1, 7, 4, 1, 1, 5, 8, 1, 5, 5, 1, 5, 6, 8, 6, 1, 5, 1, 8, 0, 4, 6, 2, 8, 4, 5, 7, 4, 2, 3, 3, 4, 3, 5, 5, 4, 3, 6, 3, 7, 5, 3, 5, 6, 4, 1, 9, 2, 3, 8, 9, 5, 2, 2, 7, 5, 7, 4, 2, 4, 5, 7, 7, 3, 5, 5, 1, 3, 1, 1, 4, 9, 3, 2, 4, 3, 6, 3, 8, 7, 2, 3, 9, 5, 7, 7, 7, 3, 5, 9, 7, 3, 3, 8, 1, 9, 8, 3, 7, 6, 8, 8, 5, 5, 4, 7, 0, 8, 4, 7, 9, 4, 1, 0, 3, 5, 1, 4, 0, 3, 5, 1, 3, 2, 3, 4, 5, 1, 0, 0, 4, 7, 0, 4, 6, 6, 2, 4, 9, 6, 1, 9, 0, 3, 6, 9, 9, 4, 6, 5, 7, 1, 5, 7, 6, 7, 7, 5, 9, 9, 9, 6, 5, 5, 9, 2, 7, 9, 9, 8, 4, 1, 6, 0, 5, 7, 1, 0, 2, 5, 0, 2, 5, 2, 3, 3, 7, 8, 0, 8, 4, 6, 2, 1, 9, 4, 7, 2, 9, 1, 6, 6, 0, 4, 3, 6, 0, 2, 2, 4, 9, 7, 7, 3, 7, 8, 3, 8, 9, 5, 3, 0, 6, 7, 4, 

Abducing...: 100%|██████████| 600/600 [02:51<00:00,  3.51it/s]


[2, 8, 4, 5, 4, 3, 7, 8, 6, 2, 6, 5, 4, 5, 2, 6, 8, 1, 0, 1, 4, 7, 7, 9, 2, 4, 7, 3, 2, 1, 8, 3, 3, 9, 3, 4, 2, 0, 6, 2, 2, 0, 3, 8, 5, 7, 1, 9, 2, 0, 7, 9, 8, 8, 5, 8, 8, 3, 7, 6, 9, 0, 3, 8, 3, 3, 3, 4, 6, 5, 1, 7, 7, 0, 0, 3, 7, 2, 3, 4, 5, 6, 3, 0, 6, 4, 7, 5, 2, 2, 7, 7, 6, 1, 7, 7, 8, 7, 3, 8, 1, 2, 0, 3, 4, 5, 2, 1, 1, 7, 4, 1, 1, 5, 8, 1, 5, 5, 1, 5, 6, 8, 6, 1, 5, 1, 8, 0, 4, 6, 2, 8, 4, 5, 7, 4, 2, 3, 3, 4, 3, 5, 5, 4, 3, 6, 3, 7, 5, 3, 5, 6, 4, 1, 9, 2, 3, 8, 9, 5, 2, 2, 7, 5, 7, 4, 2, 4, 5, 7, 7, 3, 5, 5, 1, 3, 1, 1, 4, 9, 3, 2, 4, 3, 6, 3, 8, 7, 2, 3, 9, 5, 7, 7, 7, 3, 5, 9, 7, 3, 3, 8, 1, 9, 8, 3, 7, 6, 8, 8, 5, 5, 4, 7, 0, 8, 4, 7, 9, 4, 1, 0, 3, 5, 1, 4, 0, 3, 5, 1, 3, 2, 3, 4, 5, 1, 0, 0, 4, 7, 0, 4, 6, 6, 2, 4, 9, 6, 1, 9, 0, 3, 6, 9, 9, 4, 6, 5, 7, 1, 5, 7, 6, 7, 7, 5, 9, 9, 9, 6, 5, 5, 9, 2, 7, 9, 9, 8, 4, 1, 6, 0, 5, 7, 1, 0, 2, 5, 0, 2, 5, 2, 3, 3, 7, 8, 0, 8, 4, 6, 2, 1, 9, 4, 7, 2, 9, 1, 6, 6, 0, 4, 3, 6, 0, 2, 2, 4, 9, 7, 7, 3, 7, 8, 3, 8, 9, 5, 3, 0, 6, 7, 4, 

Abducing...: 100%|██████████| 600/600 [02:50<00:00,  3.51it/s]


[2, 8, 4, 5, 4, 3, 7, 8, 6, 2, 6, 5, 4, 5, 2, 6, 8, 1, 0, 1, 4, 7, 7, 9, 2, 4, 7, 3, 2, 1, 8, 3, 3, 9, 3, 4, 2, 0, 6, 2, 2, 0, 3, 8, 5, 7, 1, 9, 2, 0, 7, 9, 8, 8, 5, 8, 8, 3, 7, 6, 9, 0, 3, 8, 3, 3, 3, 4, 6, 5, 1, 7, 7, 0, 0, 3, 7, 2, 3, 4, 5, 6, 3, 0, 6, 4, 7, 5, 2, 2, 7, 7, 6, 1, 7, 7, 8, 7, 3, 8, 1, 2, 0, 3, 4, 5, 2, 1, 1, 7, 4, 1, 1, 5, 8, 1, 5, 5, 1, 5, 6, 8, 6, 1, 5, 1, 8, 0, 4, 6, 2, 8, 4, 5, 7, 4, 2, 3, 3, 4, 3, 5, 5, 4, 3, 6, 3, 7, 5, 3, 5, 6, 4, 1, 9, 2, 3, 8, 9, 5, 2, 2, 7, 5, 7, 4, 2, 4, 5, 7, 7, 3, 5, 5, 1, 3, 1, 1, 4, 9, 3, 2, 4, 3, 6, 3, 8, 7, 2, 3, 9, 5, 7, 7, 7, 3, 5, 9, 7, 3, 3, 8, 1, 9, 8, 3, 7, 6, 8, 8, 5, 5, 4, 7, 0, 8, 4, 7, 9, 4, 1, 0, 3, 5, 1, 4, 0, 3, 5, 1, 3, 2, 3, 4, 5, 1, 0, 0, 4, 7, 0, 4, 6, 6, 2, 4, 9, 6, 1, 9, 0, 3, 6, 9, 9, 4, 6, 5, 7, 1, 5, 7, 6, 7, 7, 5, 9, 9, 9, 6, 5, 5, 9, 2, 7, 9, 9, 8, 4, 1, 6, 0, 5, 7, 1, 0, 2, 5, 0, 2, 5, 2, 3, 3, 7, 8, 0, 8, 4, 6, 2, 1, 9, 4, 7, 2, 9, 1, 6, 6, 0, 4, 3, 6, 0, 2, 2, 4, 9, 7, 7, 3, 7, 8, 3, 8, 9, 5, 3, 0, 6, 7, 4, 

Abducing...:  13%|█▎        | 78/600 [00:21<02:23,  3.63it/s]


KeyboardInterrupt: 