In [None]:
import random
import sys

sys.path.append("../")

import numpy as np
import torch
import torch.nn as nn

from abl.abducer.abducer_base import AbducerBase
from abl.abducer.kb import add_KB

from abl.utils.plog import logger
from abl.models.basic_model import BasicModel
from abl.models.wabl_models import WABLBasicModel

from models.nn import LeNet5
from datasets.mnist_add.get_mnist_add import get_mnist_add
from abl import framework_hed

In [None]:
# Create the run logger. Later cells hand it to BasicModel (recorder=...),
# reuse its save_dir as BasicModel's save_dir, and call recorder.dump() at the end.
recorder = logger()

### Logic Part

In [None]:
# Knowledge base for the addition task, built with GKB_flag=True
# (NOTE(review): presumably enables a pre-generated ground KB — confirm in abl.abducer.kb).
kb = add_KB(GKB_flag=True)

# Abducer over that knowledge base, using the "confidence" distance function
# to compare candidate pseudo-labels with the model's predictions.
abducer = AbducerBase(kb, dist_func="confidence")

### Machine Learning Part

In [None]:
# Initialize the necessary components for the machine learning part.
# Seed every RNG before the network is built so that weight initialization
# (and, presumably, the data sampling done later during training) is
# reproducible under Restart Kernel -> Run All.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA device generators

# Adam hyper-parameters, named so they are easy to find and tune.
LEARNING_RATE = 0.001
ADAM_BETAS = (0.9, 0.99)

# One output class per pseudo-label defined by the knowledge base.
cls = LeNet5(num_classes=len(kb.pseudo_label_list))
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(cls.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)

In [None]:
# Wrap the PyTorch network in BasicModel, which presents NN models in the form
# of an sklearn estimator. Training/saving options are collected up front:
# save_dir is taken from the recorder, and save_interval=1
# (NOTE(review): presumably a save every epoch — confirm in BasicModel).
base_model_kwargs = dict(
    save_interval=1,
    save_dir=recorder.save_dir,
    batch_size=32,
    num_epochs=1,
    recorder=recorder,
)
base_model = BasicModel(cls, criterion, optimizer, device, **base_model_kwargs)

### Use the WABL model to join the two parts

In [None]:
# Combine the sklearn-style base_model with the knowledge base's pseudo-label
# list. WABLBasicModel serializes the data and gives different machine
# learning models a single, unified interface.
model = WABLBasicModel(base_model, kb.pseudo_label_list)

### Dataset

In [None]:
# Load both MNIST-addition splits. With get_pseudo_label=True each split
# unpacks into three parts (NOTE(review): presumably X = digit images,
# Z = pseudo-labels, Y = sum labels — confirm in get_mnist_add).
train_X, train_Z, train_Y = get_mnist_add(get_pseudo_label=True, train=True)
test_X, test_Z, test_Y = get_mnist_add(get_pseudo_label=True, train=False)

### Train and save

In [None]:
# Train the WABL model with the abducer. The training triple is passed first
# and the test triple second (NOTE(review): presumably used for evaluation —
# confirm in framework_hed.train).
LOOP_NUM = 15      # value passed as loop_num
SAMPLE_NUM = 5000  # value passed as sample_num (presumably examples drawn per loop — confirm)

try:
    framework_hed.train(
        model,
        abducer,
        (train_X, train_Z, train_Y),
        (test_X, test_Z, test_Y),
        loop_num=LOOP_NUM,
        sample_num=SAMPLE_NUM,
        verbose=1,
    )
finally:
    # Save results even if training raises or the kernel is interrupted,
    # so whatever the recorder has logged so far is not lost.
    recorder.dump()