In [1]:
import os.path as osp

import torch.nn as nn
import torch

from abl.reasoning import ReasonerBase, KBBase

from abl.learning import BasicNN, ABLModel
from abl.bridge import SimpleBridge
from abl.evaluation import SymbolMetric
from abl.utils import ABLLogger, print_log

from examples.models.nn import LeNet5
from examples.mnist_add.datasets.get_mnist_add import get_mnist_add

In [2]:
# Emit a banner through the current ABL logger (the logger="current" target).
print_log("Abductive Learning on the MNIST Add example.", logger="current")

# Checkpoints go into a "weights" sub-directory of this run's log directory.
current_logger = ABLLogger.get_current_instance()
log_dir = current_logger.log_dir
weights_dir = osp.join(log_dir, "weights")

11/15 21:35:55 - abl - INFO - Abductive Learning on the MNIST Add example.


### Logic Part

In [3]:
# Initialize the knowledge base and the abducer (reasoner).
class add_KB(KBBase):
    """Knowledge base for MNIST Add: ``logic_forward`` maps digit pseudo-labels to their sum."""

    def logic_forward(self, nums):
        """Return the sum of ``nums``.

        ``nums`` is presumably the sequence of digit pseudo-labels for one
        example, each drawn from ``pseudo_label_list`` -- TODO confirm against
        the ``KBBase`` contract.
        """
        return sum(nums)


# Candidate pseudo-labels are the digits 0-9.
kb = add_KB(pseudo_label_list=list(range(10)))

# A Prolog-backed knowledge base (``prolog_KB`` with an ``add.pl`` rules file)
# can be substituted for ``add_KB``; it is not imported in this notebook.
# dist_func="confidence" presumably ranks candidate abductions by the model's
# prediction confidence -- TODO confirm in the ReasonerBase documentation.
abducer = ReasonerBase(kb, dist_func="confidence")

### Machine Learning Part

In [4]:
# Initialize the components needed for the machine learning part.

# Seed torch's RNG before building the network so weight initialization is
# reproducible across Restart & Run All. Note: this does not seed numpy/random,
# and some CUDA kernels remain nondeterministic -- TODO seed those too if the
# abduction step relies on them.
SEED = 0
torch.manual_seed(SEED)

# One network output per entry of kb.pseudo_label_list.
cls = LeNet5(num_classes=len(kb.pseudo_label_list))
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(cls.parameters(), lr=0.001, betas=(0.9, 0.99))

In [5]:
# Wrap the network in BasicNN, which exposes it through an
# sklearn-estimator-style interface.
BATCH_SIZE = 32
NUM_EPOCHS = 1  # epochs per BasicNN training call -- TODO confirm granularity in BasicNN docs

base_model = BasicNN(cls, criterion, optimizer, device, batch_size=BATCH_SIZE, num_epochs=NUM_EPOCHS)

### Wrap the learning model with ABLModel

In [6]:
# Initialize the ABL model around the BasicNN-wrapped network.
# ABLModel serializes data and provides a unified interface for different
# machine learning models, so the bridge does not depend on the concrete learner.
model = ABLModel(base_model)

### Metric

In [7]:
# Evaluation metrics handed to the bridge. The "mnist_add" prefix namespaces
# the reported metric names (e.g. mnist_add/character_accuracy in the log).
symbol_metric = SymbolMetric(prefix="mnist_add")
metric = [symbol_metric]

### Dataset

In [8]:
# Load the training split first, then the testing split, both with
# get_pseudo_label=True.
train_data, test_data = [
    get_mnist_add(train=is_train, get_pseudo_label=True) for is_train in (True, False)
]

### Bridge Machine Learning and Logic Reasoning

In [9]:
# SimpleBridge ties the learning side (model) to the reasoning side (abducer)
# and reports the given evaluation metrics.
bridge = SimpleBridge(model, abducer, metric)

### Train and Test

In [10]:
# Abductive-learning schedule.
NUM_LOOPS = 5
SEGMENT_SIZE = 10000  # training examples per segment (the log below shows 3 segments per loop)
SAVE_INTERVAL = 1  # write a checkpoint into weights_dir after every loop

bridge.train(
    train_data,
    loops=NUM_LOOPS,
    segment_size=SEGMENT_SIZE,
    save_interval=SAVE_INTERVAL,
    save_dir=weights_dir,
)
bridge.test(test_data)

11/15 21:36:21 - abl - INFO - loop(train) [1/5] segment(train) [1/3] model loss is 1.80390
11/15 21:36:24 - abl - INFO - loop(train) [1/5] segment(train) [2/3] model loss is 1.41898
11/15 21:36:26 - abl - INFO - loop(train) [1/5] segment(train) [3/3] model loss is 1.08221
11/15 21:36:26 - abl - INFO - Evaluation start: loop(val) [1]
11/15 21:36:27 - abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.590 
11/15 21:36:27 - abl - INFO - Saving model: loop(save) [1]
11/15 21:36:27 - abl - INFO - Checkpoints will be saved to results/20231115_21_35_55/weights/model_checkpoint_loop_1.pth
11/15 21:36:29 - abl - INFO - loop(train) [2/5] segment(train) [1/3] model loss is 0.65210
11/15 21:36:31 - abl - INFO - loop(train) [2/5] segment(train) [2/3] model loss is 0.13546
11/15 21:36:32 - abl - [4m[37mINFO[0m - loop(train) [2/5] segment(train) [3/3] model loss is 0.080