In [None]:
import os.path as osp

import torch
import torch.nn as nn

from abl.bridge import SimpleBridge
from abl.evaluation import SemanticsMetric, SymbolMetric
from abl.learning import ABLModel, BasicNN
from abl.reasoning import ReasonerBase
from abl.utils import ABLLogger, print_log
from examples.hwf.datasets.get_hwf import get_hwf
from examples.hwf.hwf_kb import HWF_KB
from examples.models.nn import SymbolNet

In [None]:
# Announce the run through the "current" ABL logger.
print_log("Abductive Learning on the HWF example.", logger="current")

# Model checkpoints are written to a "weights" sub-directory inside the
# current logger's log directory.
current_logger = ABLLogger.get_current_instance()
log_dir = current_logger.log_dir
weights_dir = osp.join(log_dir, "weights")

### Logic Part

In [None]:
# Symbolic side of ABL: the HWF knowledge base, wrapped in a reasoner that is
# configured with the "confidence" distance function (presumably used to pick
# among candidate abduced labels — see ReasonerBase).
kb = HWF_KB()
abducer = ReasonerBase(
    kb,
    dist_func="confidence",
)

### Machine Learning Part

In [None]:
# Initialize the components needed by the machine-learning part.

# Seed PyTorch's RNG (all devices) *before* building the network, so weight
# initialisation and any later torch-driven randomness are reproducible on
# Restart & Run All.
# NOTE(review): RNGs outside torch (Python `random`, NumPy) that the abl package
# may use are not seeded here — confirm whether full determinism is required.
SEED = 0
torch.manual_seed(SEED)

# image_size=(45, 45, 1) presumably means 45x45 single-channel HWF symbol
# images — TODO confirm against SymbolNet.
cls = SymbolNet(num_classes=len(kb.pseudo_label_list), image_size=(45, 45, 1))
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(cls.parameters(), lr=0.001, betas=(0.9, 0.99))

In [None]:
# BasicNN wraps the PyTorch network, loss and optimizer behind an
# sklearn-estimator-style interface.
base_model = BasicNN(
    # network + training objects built in the previous cell
    model=cls,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    # per-call training schedule
    num_epochs=1,
    batch_size=128,
)

In [None]:
# Wrap the BasicNN learner in an ABLModel. This is the layer that serializes
# data and presents one uniform interface over different ML models.
model = ABLModel(base_model)

### Metric

In [None]:
# Metrics reported during evaluation, both under the "hwf" prefix:
# a symbol-level metric and a knowledge-base-aware semantics metric.
metric_list = [
    SymbolMetric(prefix="hwf"),
    SemanticsMetric(kb=kb, prefix="hwf"),
]

### Dataset

In [None]:
# Load the training split, then the test split. Ground-truth pseudo-labels are
# requested too (presumably needed by the symbol-level metric).
train_data, test_data = (
    get_hwf(train=is_train, get_gt_pseudo_label=True) for is_train in (True, False)
)

### Bridge Machine Learning and Logic Reasoning

In [None]:
# Connect the learner (model) with the reasoner (abducer); metric_list is
# handed to the bridge for evaluation.
bridge = SimpleBridge(
    model=model,
    abducer=abducer,
    metric_list=metric_list,
)

### Train and Test

In [None]:
# Train via the bridge, saving checkpoints under weights_dir, then evaluate on
# the test split.
# NOTE(review): exact meaning of loops / segment_size / save_interval is defined
# by SimpleBridge.train — presumably 5 outer loops over the data in segments of
# 1000 samples, saving a checkpoint every loop.
bridge.train(
    train_data,
    loops=5,
    segment_size=1000,
    save_interval=1,
    save_dir=weights_dir,
)
bridge.test(test_data)