In [None]:
import math

import numpy as np
import torch
import torch.nn as nn

from abl.reasoning import ReasonerBase, prolog_KB
from abl.learning import BasicNN, ABLModel
from abl.evaluation import SymbolMetric, ABLMetric
from abl.utils import ABLLogger, reform_list

from examples.hed.hed_bridge import HEDBridge
from models.nn import SymbolNet
from datasets.get_hed import get_hed, split_equation

In [None]:
# Initialize logger: obtain the ABLLogger instance via get_instance("abl").
# Its `save_dir` attribute is passed to BasicNN further down in this notebook.
logger = ABLLogger.get_instance("abl")

### Logic Part

In [None]:
# Initialize knowledge base and abducer
class HED_prolog_KB(prolog_KB):
    """Prolog-backed knowledge base used by the HED example.

    Adds two queries on top of ``prolog_KB``: a consistency check of examples
    against a rule set, and abduction of rules from predicted results. Both
    go through ``self.prolog.query``.
    """

    def __init__(self, pseudo_label_list, pl_file):
        """Pass ``pseudo_label_list`` and ``pl_file`` unchanged to the parent initializer."""
        super().__init__(pseudo_label_list, pl_file)

    def consist_rule(self, exs, rules):
        """Tell whether ``eval_inst_feature(exs, rules)`` has at least one solution.

        ``rules`` is converted with ``str`` and stripped of single quotes
        before being placed in the query text.

        Returns:
            bool: True when the query yields one or more results.
        """
        unquoted_rules = str(rules).replace("\'", "")
        goal = "eval_inst_feature(%s, %s)." % (exs, unquoted_rules)
        solutions = list(self.prolog.query(goal))
        return len(solutions) != 0

    def abduce_rules(self, pred_res):
        """Query ``consistent_inst_feature(pred_res, X)`` and return what X is bound to.

        Returns:
            None when the query has no solution; otherwise a list holding the
            ``.value`` of each element of ``X`` in the first solution.
        """
        goal = "consistent_inst_feature(%s, X)." % pred_res
        solutions = list(self.prolog.query(goal))
        if not solutions:
            return None
        # Only the first solution is consulted.
        return [term.value for term in solutions[0]['X']]
        

# Build the knowledge base with the four pseudo-labels 1, 0, '+' and '=';
# `pl_file` points at ./datasets/learn_add.pl (path is relative to the working directory).
kb = HED_prolog_KB(pseudo_label_list=[1, 0, '+', '='], pl_file='./datasets/learn_add.pl')

class HED_Abducer(ReasonerBase):
    """Reasoner for the HED example, built on ``ReasonerBase`` with ``zoopt=True``.

    Defines the score used for ZOOpt solutions, a batch ``abduce`` that revises
    each example at the positions selected by the ZOOpt solution, and a thin
    wrapper around the knowledge base's ``abduce_rules``.
    """

    def __init__(self, kb, dist_func='hamming'):
        """Create the reasoner over ``kb``; ``dist_func`` is forwarded and ZOOpt is enabled."""
        super().__init__(kb, dist_func, zoopt=True)

    def _revise_by_idxs(self, pred_res, key, all_address_flag, idxs):
        """Revise the examples selected by ``idxs`` as one group.

        For every ``i`` in ``idxs`` this collects ``pred_res[i]`` and ``key[i]``
        and concatenates ``all_address_flag[i]`` into one flat flag list. The
        positions of the non-zero flags are handed to ``self.revise_by_idx``.

        Returns:
            Whatever ``self.revise_by_idx`` returns for the gathered group.
        """
        selected_pred = []
        selected_key = []
        flat_flags = []
        for i in idxs:
            selected_pred.append(pred_res[i])
            selected_key.append(key[i])
            flat_flags.extend(all_address_flag[i])
        revise_positions = np.where(np.array(flat_flags) != 0)[0]
        return self.revise_by_idx(selected_pred, selected_key, revise_positions)

    def zoopt_revision_score(self, pred_res, pseudo_label, pred_res_prob, key, sol):
        """Score a ZOOpt solution from the sizes of the example groups it can revise jointly.

        ``sol.get_x()`` is passed through ``reform_list`` together with
        ``pseudo_label`` (presumably to give it the same nesting -- confirm
        against ``reform_list``) and used as per-example revision flags.

        Each example not yet assigned seeds a group. Every example index is
        then tried in turn, and an addition is kept only when
        ``_revise_by_idxs`` returns a non-empty candidate for the enlarged
        group. When a group yields a candidate, its recorded size is one (the
        seed) plus the number of still-unassigned examples it contains, and
        those examples are marked as assigned.

        ``pred_res`` is used only for its length; ``pred_res_prob`` is not used.

        Returns:
            The negated sum of ``exp(-rank) * size`` over the recorded sizes
            sorted ascending, or 0 when no group was recorded.
        """
        all_address_flag = reform_list(sol.get_x(), pseudo_label)
        num_examples = len(pred_res)
        remaining = list(range(num_examples))
        group_sizes = []
        while remaining:
            group = [remaining.pop(0)]
            best_group = []
            found = False
            # idx == -1 adds nothing, so the first pass evaluates the seed alone;
            # every later idx tries to extend the group by one example.
            for idx in range(-1, num_examples):
                if idx >= 0 and idx not in group:
                    group.append(idx)
                candidate = self._revise_by_idxs(pseudo_label, key, all_address_flag, group)
                if len(candidate) == 0:
                    # Drop the last member, but always keep at least the seed.
                    if len(group) > 1:
                        group.pop()
                elif len(group) > len(best_group):
                    found = True
                    best_group = group.copy()
            if found:
                absorbed = [i for i in remaining if i in best_group]
                group_sizes.append(len(absorbed) + 1)
                remaining = [i for i in remaining if i not in best_group]
        group_sizes.sort()
        score = 0
        for rank, size in enumerate(group_sizes):
            score -= math.exp(-rank) * size
        return score

    def abduce(self, data, max_revision=-1, require_more_revision=0):
        """Abduce pseudo-labels for a batch of examples.

        Args:
            data: 4-tuple ``(batch_pred_label, batch_pred_prob,
                batch_pred_pseudo_label, batch_y)``.
            max_revision: forwarded to ``self.zoopt_get_solution``.
            require_more_revision: accepted but not used by this implementation.

        Returns:
            list: one entry per example -- ``candidates[0][0]`` from
            ``self.revise_by_idx``, or ``[]`` when it returned no candidate.
        """
        batch_pred_label, batch_pred_prob, batch_pred_pseudo_label, batch_y = data

        solution = self.zoopt_get_solution(
            batch_pred_label, batch_pred_pseudo_label, batch_pred_prob, batch_y, max_revision
        )
        batch_revision_idx = reform_list(solution.astype(np.int32), batch_pred_label)

        batch_abduced_pseudo_label = []
        # batch_pred_prob stays in the zip so the iteration length is unchanged,
        # but the probabilities themselves are not consulted here.
        for pred_pseudo_label, _, revision_idx in zip(
            batch_pred_pseudo_label, batch_pred_prob, batch_revision_idx
        ):
            revise_positions = list(np.nonzero(np.array(revision_idx))[0])
            candidates = self.revise_by_idx([pred_pseudo_label], None, revise_positions)
            if len(candidates) == 0:
                batch_abduced_pseudo_label.append([])
            else:
                # The first candidate is taken as-is.
                batch_abduced_pseudo_label.append(candidates[0][0])
        return batch_abduced_pseudo_label

    def abduce_rules(self, pred_res):
        """Delegate rule abduction to ``self.kb.abduce_rules`` and return its result."""
        return self.kb.abduce_rules(pred_res)
        
# Instantiate the reasoner over the knowledge base; dist_func keeps its default ('hamming').
abducer = HED_Abducer(kb)

### Machine Learning Part

In [None]:
# Build the components needed on the machine learning side:
# the symbol classifier, the compute device, the loss and the optimizer.
# The classifier gets one output class per pseudo-label in the knowledge base.
cls = SymbolNet(num_classes=len(kb.pseudo_label_list), image_size=(28, 28, 1))
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.RMSprop(cls.parameters(), lr=1e-3, weight_decay=1e-6)

In [None]:
# Initialize BasicNN
# BasicNN wraps the NN model so it can be used like an sklearn estimator.
# It receives the network, loss, optimizer and device created in the previous cell;
# `save_dir` is taken from the logger so saved files land in the logger's directory.
# NOTE(review): save_interval=1 and num_epochs=1 presumably mean "save after every
# epoch" and "one epoch per fit call" -- confirm against the BasicNN documentation.
base_model = BasicNN(
    cls,
    criterion,
    optimizer,
    device,
    save_interval=1,
    save_dir=logger.save_dir,
    batch_size=32,
    num_epochs=1,
)

In [None]:
# Initialize the ABL model.
# ABLModel serializes data and provides a unified interface
# across different machine learning models; here it wraps `base_model`.
model = ABLModel(base_model)

### Metric

In [None]:
# Add metrics: a SymbolMetric and an ABLMetric, both created with the "hed" prefix.
# The list is handed to the bridge in the next section.
metric = [SymbolMetric(prefix="hed"), ABLMetric(prefix="hed")]

### Bridge Machine Learning and Logic Reasoning

In [None]:
# Tie the learning model, the reasoner and the metrics together in a HEDBridge.
bridge = HEDBridge(model, abducer, metric)

### Dataset

In [None]:
# Load the HED training set and split it into training and validation parts.
# NOTE(review): the arguments 3 and 1 presumably give a 3:1 train/validation
# ratio -- confirm against split_equation.
total_train_data = get_hed(train=True)
train_data, val_data = split_equation(total_train_data, 3, 1)
# The test split is loaded here; it is not used by the cells shown in this notebook section.
test_data = get_hed(train=False)

### Train and Test

In [None]:
# Run the bridge's pretraining step with the "./weights/" directory (relative to the
# working directory), then train on the training split with the validation split.
bridge.pretrain("./weights/")
bridge.train(train_data, val_data)