In [4]:
import sys

sys.path.append("../../")

import torch.nn as nn
import torch

from abl.abducer.abducer_base import AbducerBase
from abl.abducer.kb import add_KB

from abl.utils.plog import logger
from abl.models.basic_model import BasicModel
from abl.models.wabl_models import WABLBasicModel

from models.nn import LeNet5
from datasets.get_mnist_add import get_mnist_add
from abl import framework

In [5]:
# Initialize the logger/recorder for this run.
# The same object is reused further down: its `save_dir` is handed to
# BasicModel and `dump()` is called on it after training.
recorder = logger()




### Logic Part

In [6]:
# Logic side: build the knowledge base first, then the abducer that reasons over it.
# NOTE(review): GKB_flag=True presumably enables a pre-built ground knowledge
# base — confirm against abl.abducer.kb.add_KB.
kb = add_KB(GKB_flag=True)
abducer = AbducerBase(
    kb,
    dist_func="confidence",
)

### Machine Learning Part

In [7]:
# Build the components needed by the machine learning part:
# network, compute device, loss function and optimizer.

# One output class per pseudo label known to the knowledge base
num_classes = len(kb.pseudo_label_list)
cls = LeNet5(num_classes=num_classes)

# Use the first CUDA device when available, otherwise run on the CPU
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    cls.parameters(),
    lr=0.001,
    betas=(0.9, 0.99),
)

In [8]:
# Wrap the network in BasicModel, which exposes NN models in the form of
# an sklearn-style estimator.
# Keyword options are gathered in one place so they are easy to tune.
base_model_options = dict(
    save_interval=1,
    save_dir=recorder.save_dir,
    batch_size=32,
    num_epochs=1,
    recorder=recorder,
)
base_model = BasicModel(cls, criterion, optimizer, device, **base_model_options)

### Use the WABL model to join the two parts

In [9]:
# Create the WABL model.
# Its main job is to serialize data and to offer one unified interface
# over different machine learning models.
pseudo_labels = kb.pseudo_label_list
model = WABLBasicModel(base_model, pseudo_labels)

### Dataset

In [10]:
# Load the training split and the testing split of the MNIST-addition data.
# NOTE(review): presumably X holds the inputs, Z the pseudo labels and Y the
# targets — confirm against datasets.get_mnist_add.
train_data = get_mnist_add(train=True, get_pseudo_label=True)
train_X, train_Z, train_Y = train_data

test_data = get_mnist_add(train=False, get_pseudo_label=True)
test_X, test_Z, test_Y = test_data

### Train and save

In [11]:
# Train model
# Training runs inside try/finally so that `recorder.dump()` is still executed
# when training stops early (for example when the kernel is interrupted, which
# raises KeyboardInterrupt) or fails with an error. Without this, the dump call
# is skipped. The original exception is re-raised after the dump, so failures
# stay visible.
try:
    framework.train(
        model,
        abducer,
        (train_X, train_Z, train_Y),
        (test_X, test_Z, test_Y),
        loop_num=15,
        sample_num=5000,
        verbose=1,
    )
finally:
    # Save results
    recorder.dump()

INFO:root:seg_idx:0, part num:6, data num:30000
INFO:root:Start Predict Probability 
INFO:root:#Result# {'func:': 'predict: cost 1.6403421089053154s'}
INFO:root:#Result# {'func:': 'batch_abduce: cost 0.45026259310543537s'}
INFO:root:loop: 1 {'Character level accuracy': 0.099, 'ABL accuracy': 0.029}
INFO:root:model fitting
INFO:root:0/1 model training loss is 1.9688767925262451
INFO:root:Saving model and opter
INFO:root:Model fitted, minimal loss is 1.9688767925262451
INFO:root:#Result# {'func:': 'train: cost 0.8824481889605522s'}
INFO:root:seg_idx:1, part num:6, data num:30000
INFO:root:Start Predict Probability 
INFO:root:#Result# {'func:': 'predict: cost 0.3438392709940672s'}
INFO:root:#Result# {'func:': 'batch_abduce: cost 0.34831187315285206s'}
INFO:root:loop: 2 {'Character level accuracy': 0.1754, 'ABL accuracy': 0.0798}
INFO:root:model fitting
INFO:root:0/1 model training loss is 1.5468237173080444
INFO:root:Saving model and opter
INFO:root:Model fitted, minimal loss is 1.5468237

KeyboardInterrupt: 