In [10]:
import sys

sys.path.append("../../")

import numpy as np
import torch.nn as nn
import torch

from abl.abducer.abducer_base import AbducerBase
from abl.abducer.kb import prolog_KB

from abl.utils.plog import logger
from abl.models.basic_nn import BasicNN
from abl.models.abl_model import ABLModel
from abl.utils.utils import reform_idx

from models.nn import SymbolNet
from datasets.get_hed import get_hed, split_equation
import framework_hed

In [11]:
# Initialize the logger. `recorder` is later passed to the pretraining step and
# to BasicNN (whose save_dir is recorder.save_dir), and recorder.dump() is called
# after training/testing at the end of the notebook.
recorder = logger()

### Logic Part

In [12]:
# Initialize knowledge base and abducer
class HED_prolog_KB(prolog_KB):
    """Prolog-backed knowledge base for the HED example.

    Adds two queries on top of ``prolog_KB``: a consistency check
    (``eval_inst_feature/2``) and rule abduction (``consistent_inst_feature/2``),
    both evaluated through ``self.prolog.query``.
    """

    def __init__(self, pseudo_label_list, pl_file):
        # Both arguments are forwarded unchanged to prolog_KB.__init__.
        super().__init__(pseudo_label_list, pl_file)

    def consist_rule(self, exs, rules):
        """Return True if ``eval_inst_feature(exs, rules).`` has at least one solution.

        Args:
            exs: examples, inserted into the query via ``str`` formatting.
            rules: rules; stringified and stripped of single quotes before use
                (presumably so quoted Python strings appear as bare Prolog
                terms — confirm against the .pl file).
        """
        rules = str(rules).replace("'", "")
        query = "eval_inst_feature(%s, %s)." % (exs, rules)
        return len(list(self.prolog.query(query))) != 0

    def abduce_rules(self, pred_res):
        """Abduce rules consistent with ``pred_res``.

        Returns:
            The ``.value`` of every term bound to ``X`` in the first solution of
            ``consistent_inst_feature(pred_res, X).``, or None when the query
            has no solution.
        """
        # Wrap in a 1-tuple: with a bare `%` operand, a tuple-typed pred_res
        # would be unpacked into several format arguments and raise TypeError.
        query = "consistent_inst_feature(%s, X)." % (pred_res,)
        prolog_result = list(self.prolog.query(query))
        if not prolog_result:
            return None
        return [rule.value for rule in prolog_result[0]['X']]
        
    
# Knowledge base over the pseudo-labels the classifier can emit (SymbolNet is
# later sized with len(kb.pseudo_label_list)), backed by the Prolog program in pl_file.
# NOTE(review): this cell's saved output reports a Prolog syntax error in
# learn_add.pl (line 67) — confirm the rule file loads cleanly.
kb = HED_prolog_KB(
    pseudo_label_list=[1, 0, '+', '='],
    pl_file='./datasets/learn_add.pl',
)

class HED_Abducer(AbducerBase):
    """Abducer for the HED example, constructed with ``zoopt=True``.

    Supplies the ZOOpt scoring function (``zoopt_address_score``) and delegates
    rule abduction to the knowledge base.
    """

    def __init__(self, kb, dist_func='hamming'):
        super().__init__(kb, dist_func, zoopt=True)

    def _address_by_idxs(self, pred_res, key, all_address_flag, idxs):
        """Run ``self.address_by_idx`` jointly on the examples selected by ``idxs``.

        The selected examples' address flags are concatenated and the positions
        of the non-zero flags are passed on together with the selected
        predictions and keys. Returns whatever ``address_by_idx`` returns.
        """
        pred = [pred_res[i] for i in idxs]
        k = [key[i] for i in idxs]
        address_flag = [flag for i in idxs for flag in all_address_flag[i]]
        address_idx = np.where(np.array(address_flag) != 0)[0]
        return self.address_by_idx(pred, k, address_idx)

    def _largest_consistent_group(self, pred_res, key, all_address_flag, seed):
        """Greedily grow an index group around ``seed``; return the largest one with candidates.

        The seed alone is tried first, then every other index in ascending
        order is appended and kept only if ``_address_by_idxs`` still returns a
        non-empty candidate list. Returns ``[]`` if no tried group had any.
        """
        def has_candidate(idxs):
            return len(self._address_by_idxs(pred_res, key, all_address_flag, idxs)) != 0

        group = [seed]
        best = [seed] if has_candidate(group) else []
        for idx in range(len(pred_res)):
            if idx == seed:
                # The seed is already in the group; re-evaluating the unchanged
                # group would only repeat the previous abduction call.
                continue
            group.append(idx)
            if not has_candidate(group):
                group.pop()
            elif len(group) > len(best):
                best = group.copy()
        return best

    def zoopt_address_score(self, pred_res, pred_res_prob, key, sol):
        """Score the ZOOpt solution ``sol``.

        ``sol.get_x()`` is reshaped per example by ``reform_idx``. Examples are
        then grouped greedily: the first unclaimed example seeds a group grown
        by ``_largest_consistent_group``; the group's size counted over the
        still-unclaimed examples (seed included) is recorded and its members
        are claimed. Sizes are sorted ascending and the returned score is
        ``-sum(exp(-rank) * size)``.

        ``pred_res_prob`` is unused; it is kept for interface compatibility.
        """
        import math

        all_address_flag = reform_idx(sol.get_x(), pred_res)
        unclaimed = list(range(len(pred_res)))
        group_sizes = []
        while unclaimed:
            seed = unclaimed.pop(0)
            group = self._largest_consistent_group(pred_res, key, all_address_flag, seed)
            if group:
                # +1 counts the seed, which has already been popped from `unclaimed`.
                group_sizes.append(1 + sum(1 for i in unclaimed if i in group))
                unclaimed = [i for i in unclaimed if i not in group]
        group_sizes.sort()
        return -sum(math.exp(-rank) * size for rank, size in enumerate(group_sizes))

    def abduce_rules(self, pred_res):
        """Delegate rule abduction to the knowledge base."""
        return self.kb.abduce_rules(pred_res)
        
# Abducer over `kb`; the default 'hamming' distance is spelled out explicitly.
abducer = HED_Abducer(kb=kb, dist_func='hamming')

ERROR: /home/gaoeh/ABL-Package/examples/hed/datasets/learn_add.pl:67:9: Syntax error: Operator expected


### Machine Learning Part

In [None]:
# Components for the machine-learning part: loss, compute device, the symbol
# classifier (one output class per KB pseudo-label) and its optimizer.
criterion = nn.CrossEntropyLoss()
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
cls = SymbolNet(num_classes=len(kb.pseudo_label_list), image_size=(28, 28, 1))
optimizer = torch.optim.RMSprop(cls.parameters(), lr=0.001, weight_decay=1e-6)

In [None]:
# Pretrain the NN classifier `cls` before the abductive training further below.
# NOTE(review): what hed_pretrain trains on (and whether it reuses cached
# weights) is defined in framework_hed, outside this notebook — confirm there.
framework_hed.hed_pretrain(kb, cls, recorder)

In [None]:
# Initialize BasicNN
# BasicNN wraps the NN model, together with its loss, optimizer and device,
# into the form of an sklearn estimator.
base_model = BasicNN(
    cls,
    criterion,
    optimizer,
    device,
    save_interval=1,
    save_dir=recorder.save_dir,  # reuse the logger's output directory
    batch_size=32,
    num_epochs=1,  # presumably epochs per fit call — TODO confirm in BasicNN
    recorder=recorder,
)

### Use the ABL model to join the two parts

In [None]:
# Combine the sklearn-style base model with the KB's pseudo-label list.
model = ABLModel(base_model, kb.pseudo_label_list)

### Dataset

In [None]:
# Load the HED training data and split it into training and validation parts,
# then load the test data.
# NOTE(review): split_equation's positional 3, 1 look like a 3:1 train/val
# ratio — TODO confirm against datasets.get_hed.
total_train_data = get_hed(train=True)
train_data, val_data = split_equation(total_train_data, 3, 1)
test_data = get_hed(train=False)

### Train, test, and save

In [None]:
# min_len / max_len are shared by training and testing; defining them once
# keeps the two calls from drifting apart.
MIN_LEN, MAX_LEN = 5, 8

# Abductive training with rules, then evaluation with the learned mapping.
model, mapping = framework_hed.train_with_rule(
    model, abducer, train_data, val_data, select_num=10, min_len=MIN_LEN, max_len=MAX_LEN
)
framework_hed.hed_test(
    model, abducer, mapping, train_data, test_data, min_len=MIN_LEN, max_len=MAX_LEN
)

recorder.dump()