In [None]:
import numpy as np
import torch.nn as nn
import torch

from abl.reasoning.reasoner import ReasonerBase
from abl.reasoning.kb import KBBase

from abl.utils.plog import logger
from abl.learning.basic_nn import BasicNN
from abl.learning.abl_model import ABLModel

from models.nn import SymbolNet
from datasets.get_hwf import get_hwf
from abl import framework

In [None]:
# Initialize logger
# `recorder` is reused later in this notebook: its `save_dir` is handed to
# BasicNN as the save directory, the recorder itself is passed to BasicNN,
# and `recorder.dump()` is called after training to save the results.
recorder = logger()

### Logic Part

In [None]:
# Initialize knowledge base and abducer
class HWF_KB(KBBase):
    """Knowledge base for the handwritten-formula (HWF) task.

    A formula is a sequence of pseudo labels. It is well formed only if it
    has odd length, holds a digit '1'..'9' at every even position and an
    operator '+', '-', 'times' or 'div' at every odd position. Its value is
    obtained by evaluating it as a Python arithmetic expression, with
    'times' -> '*' and 'div' -> '/' (true division).
    """

    # Pseudo labels allowed at even (operand) positions.
    _DIGITS = ('1', '2', '3', '4', '5', '6', '7', '8', '9')
    # Pseudo labels allowed at odd (operator) positions.
    _OPERATORS = ('+', '-', 'times', 'div')
    # Pseudo label -> Python token; built once instead of on every call.
    _TOKEN_MAP = {
        **dict(zip(_DIGITS, _DIGITS)),
        '+': '+',
        '-': '-',
        'times': '*',
        'div': '/',
    }

    def __init__(
        self,
        pseudo_label_list=None,
        len_list=None,
        GKB_flag=False,
        max_err=1e-3,
        use_cache=True
    ):
        """Create the KB and forward all settings to ``KBBase``.

        Args:
            pseudo_label_list: Label vocabulary. Defaults to the nine digits
                followed by '+', '-', 'times', 'div'.
            len_list: Formula lengths, forwarded to KBBase. Defaults to
                [1, 3, 5, 7].
            GKB_flag, max_err, use_cache: Forwarded unchanged to KBBase
                (semantics defined there).
        """
        # Build fresh lists per instance: list literals as default arguments
        # would be shared by every HWF_KB, so any mutation would leak between
        # instances.
        if pseudo_label_list is None:
            pseudo_label_list = list(self._DIGITS + self._OPERATORS)
        if len_list is None:
            len_list = [1, 3, 5, 7]
        super().__init__(pseudo_label_list, len_list, GKB_flag, max_err, use_cache)

    def _valid_candidate(self, formula):
        """Return True iff ``formula`` is a well-formed digit/operator sequence.

        Odd length is required. Even positions must be digits and odd
        positions must be operators.
        """
        if len(formula) % 2 == 0:
            return False
        operands_ok = all(tok in self._DIGITS for tok in formula[0::2])
        operators_ok = all(tok in self._OPERATORS for tok in formula[1::2])
        return operands_ok and operators_ok

    def logic_forward(self, formula):
        """Return the numeric value of ``formula``, or ``np.inf`` if malformed."""
        if not self._valid_candidate(formula):
            return np.inf
        expression = ''.join(self._TOKEN_MAP[tok] for tok in formula)
        # Every token was whitelisted above, so the expression contains only
        # single non-zero digits and + - * /. Divisors are therefore never
        # zero. Builtins are still withheld from eval as defence in depth.
        return eval(expression, {"__builtins__": {}}, {})

# Build the HWF knowledge base. GKB_flag=True is forwarded to KBBase
# (presumably enables a precomputed ground knowledge base over all formulas
# of the configured lengths — TODO confirm in KBBase).
kb = HWF_KB(GKB_flag=True)
# Wrap the KB in a ReasonerBase. It is passed to framework.train below as the
# abducer.
abducer = ReasonerBase(kb)

### Machine Learning Part

In [None]:
# Initialize the components needed for the machine learning part.
# Seed the global numpy and torch generators before building the network so
# weight initialization is reproducible under Restart & Run All. Later code
# that draws from these same generators becomes deterministic as well.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA devices

# One output class per pseudo label in the KB.
# image_size is presumably (height, width, channels) — TODO confirm in SymbolNet.
cls = SymbolNet(num_classes=len(kb.pseudo_label_list), image_size=(45, 45, 1))
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(cls.parameters(), lr=0.001, betas=(0.9, 0.99))

In [None]:
# Wrap the PyTorch network, loss, optimizer and device in BasicNN, which
# exposes the network through an sklearn-estimator-style interface.
base_model = BasicNN(
    cls, criterion, optimizer, device,
    # training schedule
    batch_size=32,
    num_epochs=1,
    # saving / logging
    recorder=recorder,
    save_dir=recorder.save_dir,
    save_interval=1,
)

### Connect the two parts with an ABL model

In [None]:
# Initialize ABL model
# The main function of the ABL model is to serialize data and
# provide a unified interface for different machine learning models.
# The KB's pseudo label list is passed alongside the wrapped network
# (presumably to map between class indices and pseudo labels — confirm in ABLModel).
model = ABLModel(base_model, kb.pseudo_label_list)

### Dataset

In [None]:
# Load the HWF training split and then the testing split, requesting pseudo
# labels for both (presumably used for evaluation — TODO confirm in get_hwf).
train_data, test_data = (
    get_hwf(train=is_train, get_pseudo_label=True) for is_train in (True, False)
)

### Train and save

In [None]:
# Train the model with the abducer, then save the recorded results.
# loop_num / sample_num are forwarded to framework.train (presumably the
# number of abduction loops and the samples drawn per loop — TODO confirm).
try:
    framework.train(
        model, abducer, train_data, test_data, loop_num=15, sample_num=5000, verbose=1
    )
finally:
    # Save results even if training raises or is interrupted, so the logs
    # recorded so far are not lost.
    recorder.dump()