In [None]:
import os.path as osp

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

from abl.bridge import SimpleBridge
from abl.evaluation import ReasoningMetric, SymbolMetric
from abl.learning import ABLModel, BasicNN
from abl.reasoning import KBBase, Reasoner
from abl.utils import ABLLogger, print_log
from examples.mnist_add.datasets import get_mnist_add
from examples.models.nn import LeNet5

In [None]:
# Emit the opening message through ABL's "current" logger.
print_log("Abductive Learning on the MNIST Addition example.", logger="current")

# Model checkpoints are written into a "weights" subfolder of the current logger's log directory.
# NOTE(review): assumes an ABLLogger instance exists by now (presumably created by the
# print_log call above) — confirm.
current_logger = ABLLogger.get_current_instance()
log_dir = current_logger.log_dir
weights_dir = osp.join(log_dir, "weights")

### Load Datasets

In [None]:
# Load the training split first, then the test split, each with its ground-truth pseudo labels.
train_data, test_data = (
    get_mnist_add(train=is_train, get_pseudo_label=True) for is_train in (True, False)
)

In [None]:
# Summarize the dataset layout and show the first training example.
# Each split is indexed as (X, gt_pseudo_label, Y): [0] holds the inputs (a list of
# images per example), [1] the ground-truth digit labels, and [2] the sums.
example_X = train_data[0][0]
example_gt_pseudo_label = train_data[1][0]
example_Y = train_data[2][0]

print(
    f"There are {len(train_data[0])} data examples in the training set "
    f"and {len(test_data[0])} data examples in the test set"
)
print(f"Each data example has {len(train_data)} components: X, gt_pseudo_label, and Y.")
print("For instance, in the first data example in the training set, we have:")
print(f"X ({len(example_X)} images):")

# squeeze=False keeps `axes` 2-D even for a single image, so the loop below is uniform.
fig, axes = plt.subplots(1, len(example_X), squeeze=False)
for ax, image in zip(axes[0], example_X):
    # Images are channel-first tensors (hence the transpose to HWC for imshow).
    # A single-channel image is reduced to 2-D so the grayscale colormap applies.
    image_hwc = image.numpy().transpose(1, 2, 0)
    if image_hwc.shape[-1] == 1:
        image_hwc = image_hwc[..., 0]
    ax.imshow(image_hwc, cmap="gray")
    ax.axis("off")
plt.show()

# Format every label rather than hard-coding two indices; f-string formatting matches the original output.
labels_str = ", ".join(f"{label}" for label in example_gt_pseudo_label)
print(f"gt_pseudo_label ({len(example_gt_pseudo_label)} ground truth pseudo labels): {labels_str}")
print(f"Y (their sum result): {example_Y}")

### Learning Part

In [None]:
# Build the components BasicNN needs: the network, its loss, its optimizer and the device.
# Seed torch before constructing the network so the LeNet5 weight initialization is
# reproducible across Restart & Run All (presumably also any torch-driven shuffling
# inside BasicNN — confirm).
# NOTE(review): if the data pipeline or reasoner draws from Python `random` or NumPy,
# seed those generators as well.
SEED = 0
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators

cls = LeNet5(num_classes=10)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(cls.parameters(), lr=0.001, betas=(0.9, 0.99))
# Use the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Wrap the PyTorch network as an sklearn-style estimator with BasicNN.
# The training hyperparameters are named so they are easy to locate and tune.
BATCH_SIZE = 32
NUM_EPOCHS = 1

base_model = BasicNN(cls, loss_fn, optimizer, device, batch_size=BATCH_SIZE, num_epochs=NUM_EPOCHS)

In [None]:
# Build ABLModel
# ABLModel wraps the BasicNN estimator. Its job is to serialize the data for the
# underlying model and to give the rest of the ABL pipeline (e.g. SimpleBridge)
# one uniform interface, whatever machine learning model sits underneath.
model = ABLModel(base_model)

### Logic Part

In [None]:
# Knowledge base for MNIST addition: an example's label is the sum of its digits.
class AddKB(KBBase):
    """Knowledge base whose deduction rule adds up a sequence of digit pseudo-labels."""

    def __init__(self, pseudo_label_list):
        # Nothing extra to configure; the candidate pseudo-labels are handed to KBBase.
        super().__init__(pseudo_label_list)

    def logic_forward(self, nums):
        """Deduce the reasoning result for the given pseudo-labels: their sum."""
        total = 0
        for digit in nums:
            total += digit
        return total


# The candidate pseudo-labels are the ten digits 0-9.
kb = AddKB(pseudo_label_list=list(range(10)))
# dist_func="confidence" presumably ranks candidate revisions by the model's prediction
# confidence — confirm against abl.reasoning.Reasoner.
reasoner = Reasoner(kb, dist_func="confidence")

### Evaluation Metrics

In [None]:
# Set up metrics: a symbol-level metric and a KB-based reasoning metric,
# both reported under the same "mnist_add" prefix.
METRIC_PREFIX = "mnist_add"
metric_list = [
    SymbolMetric(prefix=METRIC_PREFIX),
    ReasoningMetric(kb=kb, prefix=METRIC_PREFIX),
]

### Bridge Machine Learning and Logic Reasoning

In [None]:
# SimpleBridge ties together the learning part (model), the logic part (reasoner) and the
# evaluation metrics; training and testing are driven through this object.
bridge = SimpleBridge(model, reasoner, metric_list)

### Train and Test

In [None]:
# Training settings for the bridge; checkpoints go into weights_dir.
# NOTE(review): segment_size=1/3 presumably splits each loop's training data into thirds,
# and save_interval=1 presumably saves weights after every loop — confirm in SimpleBridge.
train_config = dict(loops=5, segment_size=1 / 3, save_interval=1, save_dir=weights_dir)
bridge.train(train_data, **train_config)

# Evaluate the trained model on the held-out test split.
bridge.test(test_data)