In [2]:
import math
import os.path as osp

import numpy as np
import torch
import torch.nn as nn
from zoopt import Dimension, Objective, Opt, Parameter

from abl.evaluation import ReasoningMetric, SymbolMetric
from abl.learning import ABLModel, BasicNN
from abl.reasoning import PrologKB, Reasoner
from abl.utils import ABLLogger, print_log, reform_list
from examples.hed.datasets.get_hed import get_hed, split_equation
from examples.hed.hed_bridge import HEDBridge
from examples.models.nn import SymbolNet

In [3]:
# Emit the run banner through the "current" ABL logger.
# NOTE(review): presumably this call is also what creates the logger instance read below —
# TODO confirm against abl.utils.
print_log("Abductive Learning on the HED example.", logger="current")

# Read the log directory from the current logger instance and build the path of a
# "weights" sub-directory inside it; weights_dir is later handed to BasicNN as save_dir.
log_dir = ABLLogger.get_current_instance().log_dir
weights_dir = osp.join(log_dir, "weights")

12/18 09:01:12 - abl - INFO - Abductive Learning on the HED example.


### Logic Part

In [4]:
# Define the HED knowledge base and reasoner classes, then instantiate both
class HedKB(PrologKB):
    """Prolog knowledge base used by the HED example.

    Adds two helpers on top of ``PrologKB``; both send a query string to
    ``self.prolog`` (an attribute not defined in this class):

    * ``consist_rule`` queries the ``eval_inst_feature`` predicate.
    * ``abduce_rules`` queries the ``consistent_inst_feature`` predicate.
    """

    def __init__(self, pseudo_label_list, pl_file):
        # Plain pass-through to the parent constructor.
        super().__init__(pseudo_label_list, pl_file)

    def consist_rule(self, exs, rules):
        """Report whether ``eval_inst_feature(exs, rules)`` has any solution.

        Args:
            exs: Value spliced into the query as the first argument.
            rules: Value spliced in as the second argument, after conversion
                to ``str`` and removal of every single-quote character.

        Returns:
            bool: ``True`` when the query yields at least one result.
        """
        unquoted_rules = str(rules).replace("'", "")
        query = "eval_inst_feature(%s, %s)." % (exs, unquoted_rules)
        solutions = list(self.prolog.query(query))
        return len(solutions) != 0

    def abduce_rules(self, pred_res):
        """Query ``consistent_inst_feature(pred_res, X)`` and return the bound rules.

        Args:
            pred_res: Value spliced into the query as the first argument.

        Returns:
            ``None`` when the query has no result; otherwise a list holding the
            ``.value`` of every item bound to ``X`` in the first result.
        """
        query = "consistent_inst_feature(%s, X)." % pred_res
        solutions = list(self.prolog.query(query))
        if not solutions:
            return None
        first_binding = solutions[0]["X"]
        return [term.value for term in first_binding]


class HedReasoner(Reasoner):
    """Reasoner for the HED example.

    Instead of abducing every example independently, ``abduce`` uses ZOOpt to
    search for a revision mask over all symbols, groups the examples that can
    be revised consistently together, and abduces labels only for one such
    group; every other example receives an empty label list.
    """

    def revise_at_idx(self, data_example):
        """Ask the knowledge base to revise the positions flagged for revision.

        The positions are the indices at which the flattened
        ``data_example.revision_flag`` is non-zero.

        Returns:
            Whatever ``self.kb.revise_at_idx`` returns. Callers in this class
            unpack it as a 2-tuple whose first item is the candidate list.
        """
        flat_flags = np.array(data_example.flatten("revision_flag"))
        revision_idx = np.where(flat_flags != 0)[0]
        return self.kb.revise_at_idx(
            data_example.pred_pseudo_label, data_example.Y, data_example.X, revision_idx
        )

    def zoopt_revision_score(self, symbol_num, data_example, sol):
        """Score a ZOOpt solution by how well the examples group consistently.

        Side effect: ``data_example.revision_flag`` is overwritten with the
        solution vector (cast to int32) reshaped by ``reform_list`` using
        ``data_example.pred_pseudo_label`` as the template.

        Args:
            symbol_num: Unused here; kept so the signature stays unchanged.
            data_example: Example container providing ``pred_idx``,
                ``pred_pseudo_label`` and index-list slicing.
            sol: ZOOpt solution; only ``sol.get_x()`` is read.

        Returns:
            tuple: ``(score, max_consistent_idxs)``. ``score`` is minus the
            weighted sum of the group sizes and is therefore never positive.
            ``max_consistent_idxs`` is the index list of the last consistent
            group that was found (``[]`` if none was).
        """
        data_example.revision_flag = reform_list(
            list(sol.get_x().astype(np.int32)), data_example.pred_pseudo_label
        )

        example_num = len(data_example.pred_idx)
        lefted_idxs = list(range(example_num))
        candidate_size = []
        max_consistent_idxs = []
        while lefted_idxs:
            # Seed a new group with the first example that is still unassigned.
            idxs = [lefted_idxs.pop(0)]
            max_candidate_idxs = []
            found = False
            # idx == -1 is a dry pass that checks the seed example on its own;
            # every later idx tries to grow the group by one more example.
            for idx in range(-1, example_num):
                if idx >= 0 and idx not in idxs:
                    idxs.append(idx)
                candidates, _ = self.revise_at_idx(data_example[idxs])
                if len(candidates) == 0:
                    # Inconsistent: drop the last member, but never the seed.
                    if len(idxs) > 1:
                        idxs.pop()
                elif len(idxs) > len(max_candidate_idxs):
                    found = True
                    max_candidate_idxs = idxs.copy()
            if found:
                # Only examples that were still unassigned count towards the
                # group size; the seed was already popped, so add it back.
                removed = [i for i in lefted_idxs if i in max_candidate_idxs]
                removed.insert(0, idxs[0])
                candidate_size.append(len(removed))
                # NOTE(review): this is overwritten on every successful group,
                # so the value returned is the last group found, which is not
                # necessarily the largest one — confirm this is intended.
                max_consistent_idxs = max_candidate_idxs.copy()
                lefted_idxs = [i for i in lefted_idxs if i not in max_candidate_idxs]

        # Sizes are sorted ascending and weighted by exp(-rank), so the smallest
        # group carries weight 1 and larger groups progressively less.
        candidate_size.sort()
        score = 0
        for rank, size in enumerate(candidate_size):
            score -= math.exp(-rank) * size
        return score, max_consistent_idxs

    def _zoopt_get_solution(self, symbol_num, data_example, max_revision_num):
        """Minimise ``zoopt_revision_score`` over 0/1 masks of length ``symbol_num``.

        Args:
            symbol_num: Number of dimensions of the search space.
            data_example: Forwarded to ``zoopt_revision_score``.
            max_revision_num: Forwarded to ``self._constrain_revision_num``,
                which is used as the ZOOpt constraint.

        Returns:
            The solution object returned by ``Opt.min``.
        """
        # One dimension per symbol, each with range [0, 1]; tys=False is
        # ZOOpt's flag for a discrete (non-continuous) dimension.
        dimension = Dimension(size=symbol_num, regs=[[0, 1]] * symbol_num, tys=[False] * symbol_num)
        objective = Objective(
            lambda sol: self.zoopt_revision_score(symbol_num, data_example, sol)[0],
            dim=dimension,
            constraint=lambda sol: self._constrain_revision_num(sol, max_revision_num),
        )
        parameter = Parameter(budget=100, intermediate_result=False, autoset=True)
        return Opt.min(objective, parameter)

    def abduce(self, data_example):
        """Abduce pseudo-labels for the consistent group chosen by the ZOOpt search.

        Side effects: sets ``data_example.revision_flag`` (through
        ``zoopt_revision_score``) and ``data_example.abduced_pseudo_label``.

        Returns:
            list: One entry per example in ``data_example``. Examples in the
            chosen group get their part of the first candidate returned by
            ``revise_at_idx``; all other examples get ``[]``.
        """
        symbol_num = data_example.elements_num("pred_pseudo_label")
        max_revision_num = self._get_max_revision_num(self.max_revision, symbol_num)

        solution = self._zoopt_get_solution(symbol_num, data_example, max_revision_num)
        # Re-score the winning solution: this re-applies its revision mask to
        # data_example and yields the indices of the group to abduce.
        _, max_candidate_idxs = self.zoopt_revision_score(symbol_num, data_example, solution)

        abduced_pseudo_label = [[] for _ in range(len(data_example))]

        if len(max_candidate_idxs) > 0:
            candidates, _ = self.revise_at_idx(data_example[max_candidate_idxs])
            # candidates[0] is ordered like max_candidate_idxs, so position i
            # maps back to original example index idx.
            for i, idx in enumerate(max_candidate_idxs):
                abduced_pseudo_label[idx] = candidates[0][i]
        data_example.abduced_pseudo_label = abduced_pseudo_label
        return abduced_pseudo_label

    def abduce_rules(self, pred_res):
        """Delegate rule abduction to the knowledge base (``self.kb.abduce_rules``)."""
        return self.kb.abduce_rules(pred_res)


# Knowledge base over the four pseudo-labels 1, 0, "+" and "=".
# NOTE(review): pl_file is a relative path, so it presumably resolves against the kernel's
# working directory — confirm the notebook is run from examples/hed.
kb = HedKB(pseudo_label_list=[1, 0, "+", "="], pl_file="./datasets/learn_add.pl")
# NOTE(review): HedReasoner.abduce always runs its own ZOOpt search; whether use_zoopt=True
# still affects anything inside the base Reasoner is not visible here — TODO confirm.
reasoner = HedReasoner(kb, dist_func="hamming", use_zoopt=True, max_revision=10)

### Machine Learning Part

In [5]:
# Build the components that BasicNN is assembled from in the next cell.
# num_classes=4 equals the number of pseudo-labels given to the knowledge base.
cls = SymbolNet(num_classes=4)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.RMSprop(cls.parameters(), lr=0.001, weight_decay=1e-4)
# Use CUDA when PyTorch reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
# Build BasicNN
# BasicNN wraps the NN model so that it can be used like an sklearn estimator.
# save_dir is the "weights" folder inside the current log directory (weights_dir).
# NOTE(review): device is passed positionally as the fourth argument — confirm this matches
# the BasicNN signature of the installed abl version.
base_model = BasicNN(
    cls,
    loss_fn,
    optimizer,
    device,
    batch_size=32,
    num_epochs=1,
    save_interval=1,
    stop_loss=None,
    save_dir=weights_dir,
)

In [7]:
# Build ABLModel
# ABLModel serializes data and provides a unified interface
# across different machine learning models; here it wraps base_model.
model = ABLModel(base_model)

### Metric

In [8]:
# Set up the metrics handed to the bridge: a symbol-level metric and a reasoning metric
# that is given the knowledge base; both use the "hed" prefix.
metric_list = [SymbolMetric(prefix="hed"), ReasoningMetric(kb=kb, prefix="hed")]

### Bridge Machine Learning and Logic Reasoning

In [9]:
# Connect the learning model, the reasoner and the metrics through the HED-specific bridge.
bridge = HEDBridge(model, reasoner, metric_list)

### Dataset

In [12]:
# Load the HED training set and split it into training and validation parts.
# NOTE(review): the arguments 3 and 1 presumably give a 3:1 train/validation ratio —
# TODO confirm against split_equation.
total_train_data = get_hed(train=True)
train_data, val_data = split_equation(total_train_data, 3, 1)
# Load the held-out test set.
test_data = get_hed(train=False)

### Train and Test

In [13]:
# Pretrain with "./weights" as the weights directory, then run training with validation data.
# NOTE(review): pretrain() receives the literal "./weights", not the weights_dir derived from
# the log directory that BasicNN uses as save_dir — confirm the difference is intended.
# NOTE(review): the output recorded for this cell ends in
# "TypeError: unsupported format string passed to BasicNN.__format__". No format spec is
# applied to a BasicNN anywhere in this notebook, so the failure presumably originates in
# the imported bridge/library code — check that the installed abl version matches examples/hed.
bridge.pretrain("./weights")
bridge.train(train_data, val_data)

12/18 09:04:27 - abl - INFO - Pretrain Start
12/18 09:04:31 - abl - INFO - Checkpoints will be saved to ./weights/model_checkpoint_epoch_1.pth
12/18 09:04:33 - abl - INFO - Checkpoints will be saved to ./weights/model_checkpoint_epoch_2.pth
12/18 09:04:34 - abl - INFO - Checkpoints will be saved to ./weights/model_checkpoint_epoch_3.pth
12/18 09:04:36 - abl - INFO - Checkpoints will be saved to ./weights/model_checkpoint_epoch_4.pth
12/18 09:04:37 - abl - INFO - Checkpoints will be saved to ./weights/model_checkpoint_epoch_5.pth
12/18 09:04:38 - abl - INFO - Checkpoints will be saved to ./weights/model_checkpoint_epoch_6.pth
12/18 09:04:40 - abl - INFO - Checkpoints will be saved to ./weights/model_checkpoint_epoch_7.pth
12/18 09:04:41 - abl - INFO - Checkpoints will be saved to ./weights/model_checkpoint_epoch_8.pth
12/18 09:04:43 - abl - INFO - Checkpoints will be saved to ./weights/model_checkpoint_epoch_9.pth
12/18 09:04:44 - abl - INFO - Checkpoints will be saved to ./weights/mode

TypeError: unsupported format string passed to BasicNN.__format__