In [1]:
from JPL_interface import JPL
from modules.module import TransferModule
from modules.active_learning import RandomActiveLearning, LeastConfidenceActiveLearning
from taglets.taglet_executer import TagletExecutor
from task import Task
from label_model import get_label_distribution
from custom_dataset import CustomDataSet, SoftLabelDataSet
import torch
from taglets.end_model import EndModel
import numpy as np
import datetime

In [3]:
def get_task():
    task_names = api.get_available_tasks()
    task_name = task_names[0]  # Image classification task
    api.create_session(task_name)
    task_metadata = api.get_task_metadata(task_name)

    num_base_checkpoints = len(task_metadata['base_label_budget'])
    num_adapt_checkpoints = len(task_metadata['adaptation_label_budget'])

    task = Task(task_name, task_metadata)
    session_status = api.get_session_status()
    current_dataset = session_status['current_dataset']
    task.classes = current_dataset['classes']
    task.number_of_channels = current_dataset['number_of_channels']
    task.dataset_name = current_dataset['name']

    task.unlabeled_image_path = "/data/bats/datasets/lwll/lwll_datasets/mnist/mnist_sample/train"
    task.evaluation_image_path = "/data/bats/datasets/lwll/lwll_datasets/mnist/mnist_sample/test"  # Should be updated later

    # task.unlabeled_image_path = "./sql_data/MNIST/train"
    # task.evaluation_image_path = "./sql_data/MNIST/test"  # Should be updated later
    task.phase = session_status['pair_stage']
    if session_status['pair_stage'] == 'adaptation':
        task.labeled_images = []
        task.pretrained = task_metadata['adaptation_can_use_pretrained_model']
    elif session_status['pair_stage'] == 'base':
        task.labeled_images = api.get_seed_labels()
        task.pretrained = task_metadata['base_can_use_pretrained_model']
    return task, num_base_checkpoints, num_adapt_checkpoints


In [6]:
api = JPL()
task, num_base_checkpoints, num_adapt_checkpoints = get_task()
random_active_learning = RandomActiveLearning()
confidence_active_learning = LeastConfidenceActiveLearning()
taglet_executor = TagletExecutor()
end_model = EndModel(task)
# task.get_related_concepts()
batch_size = 32
num_workers = 2
use_gpu = False
testing = True

In [12]:
def run_checkpoints():
    run_checkpoints_base()
    run_checkpoints_adapt()

def run_checkpoints_base():
    update_task()
    for i in range(num_base_checkpoints):
        run_one_checkpoint(i)

def run_one_checkpoint( checkpoint_num):
    session_status = api.get_session_status()
    assert session_status['pair_stage'] == 'base'
    print('------------------------------------------------------------')
    print('--------------------base check point: {}'.format(checkpoint_num)+'---------------------')
    print('------------------------------------------------------------')

    available_budget = get_available_budget()
    unlabeled_image_names = task.get_unlabeled_image_names()
    print('number of unlabeled data: {}'.format(len(unlabeled_image_names)))
    if checkpoint_num == 0:
        candidates = random_active_learning.find_candidates(available_budget, unlabeled_image_names)
    else:
        candidates = confidence_active_learning.find_candidates(available_budget, unlabeled_image_names)
    request_labels(candidates)
    predictions = get_predictions(session_status['pair_stage'])
    submit_predictions(predictions)

def run_checkpoints_adapt():
    update_task()
    print()
    for i in range(num_adapt_checkpoints):
        session_status = api.get_session_status()
        # assert session_status['pair_stage'] == 'adaptation'
        print('------------------------------------------------------------')
        print('--------------------Adapt check point: {}'.format(i)+'---------------------')
        print('------------------------------------------------------------')

        available_budget = get_available_budget()
        unlabeled_image_names = task.get_unlabeled_image_names()
        print('number of unlabeled data: {}'.format(len(unlabeled_image_names)))
        if i == 0:
            candidates = random_active_learning.find_candidates(available_budget, unlabeled_image_names)
        else:
            candidates = confidence_active_learning.find_candidates(available_budget, unlabeled_image_names)
        request_labels(candidates)
        predictions = get_predictions(session_status['pair_stage'])
        submit_predictions(predictions)

def get_task():
    task_names = api.get_available_tasks()
    task_name = task_names[0]  # Image classification task
    api.create_session(task_name)
    task_metadata = api.get_task_metadata(task_name)

    num_base_checkpoints = len(task_metadata['base_label_budget'])
    num_adapt_checkpoints = len(task_metadata['adaptation_label_budget'])

    task = Task(task_name, task_metadata)
    session_status = api.get_session_status()
    current_dataset = session_status['current_dataset']
    task.classes = current_dataset['classes']
    task.number_of_channels = current_dataset['number_of_channels']
    task.dataset_name = current_dataset['name']

#     task.unlabeled_image_path = "/data/bats/datasets/lwll/lwll_datasets/mnist/mnist_sample/train"
#     task.evaluation_image_path = "/data/bats/datasets/lwll/lwll_datasets/mnist/mnist_sample/test"  # Should be updated later

    task.unlabeled_image_path = "./sql_data/MNIST/train"
    task.evaluation_image_path = "./sql_data/MNIST/test"  # Should be updated later
    task.phase = session_status['pair_stage']
    if session_status['pair_stage'] == 'adaptation':
        task.labeled_images = []
        task.pretrained = task_metadata['adaptation_can_use_pretrained_model']
    elif session_status['pair_stage'] == 'base':
        task.labeled_images = api.get_seed_labels()
        task.pretrained = task_metadata['base_can_use_pretrained_model']
    return task, num_base_checkpoints, num_adapt_checkpoints

def update_task():
    task_metadata = api.get_task_metadata(task.name)
    session_status = api.get_session_status()
    current_dataset = session_status['current_dataset']
    task.classes = current_dataset['classes']
    task.number_of_channels = current_dataset['number_of_channels']
    task.dataset_name = current_dataset['name']

#     task.unlabeled_image_path = "/data/bats/datasets/lwll/lwll_datasets/mnist/mnist_sample/train"
#     task.evaluation_image_path = "/data/bats/datasets/lwll/lwll_datasets/mnist/mnist_sample/test"  # Should be updated later

    task.unlabeled_image_path = "./sql_data/MNIST/train"
    task.evaluation_image_path = "./sql_data/MNIST/test"  # Should be updated later
    task.phase = session_status['pair_stage']
    if session_status['pair_stage'] == 'adaptation':
        task.labeled_images = []
        task.pretrained = task_metadata['adaptation_can_use_pretrained_model']
    elif session_status['pair_stage'] == 'base':
        task.labeled_images = api.get_seed_labels()
        task.pretrained = task_metadata['base_can_use_pretrained_model']

def get_available_budget():
    session_status = api.get_session_status()
    available_budget = session_status['budget_left_until_checkpoint']

    if testing:
        available_budget = available_budget // 10
    return available_budget

def request_labels( examples):
    query = {'example_ids': examples}
    labeled_images = api.request_label(query)
    task.add_labeled_images(labeled_images)
    print("New labeled images:", len(labeled_images))
    print("Total labeled images:", len(task.labeled_images))

def combine_soft_labels( unlabeled_labels, unlabeled_names, train_image_names, train_image_labels):
    def to_soft_one_hot(l):
        soh = [0.15] * len(task.classes)
        soh[l] = 0.85
        return soh

    soft_labels_labeled_images = []
    for image_label in train_image_labels:
        soft_labels_labeled_images.append(to_soft_one_hot(int(image_label)))

    all_soft_labels = np.concatenate((unlabeled_labels, np.array(soft_labels_labeled_images)), axis=0)
    all_names = unlabeled_names + train_image_names

    end_model_train_data = SoftLabelDataSet(task.unlabeled_image_path,
                                      all_names,
                                      all_soft_labels,
                                      task.transform_image(),
                                      task.number_of_channels)

    train_data = torch.utils.data.DataLoader(end_model_train_data,
                                       batch_size=batch_size,
                                       shuffle=True,
                                       num_workers=num_workers)

    return train_data

def get_predictions( phase):
    """train taglets, label model, and endmodel, and return prediction
    :param phase: 'base' or 'adapt'
    """
    train_data_loader, val_data_loader,  train_image_names, train_image_labels = task.load_labeled_data(
        batch_size,
        num_workers)

    unlabeled_data_loader, unlabeled_image_names = task.load_unlabeled_data(batch_size,
                                                                                 num_workers)

    mnist_module = TransferModule(task=task)

    print("**********Training taglets on labeled data**********")
    t1 = datetime.datetime.now()
    mnist_module.train_taglets(train_data_loader, val_data_loader, use_gpu, phase, testing)
    t2 = datetime.datetime.now()
    print()
    print(".....Taglet training time: {}".format((t2 - t1).seconds))

    taglets = mnist_module.get_taglets()
    taglet_executor.set_taglets(taglets)

    print("**********Executing taglets on unlabled data**********")
    t1 = datetime.datetime.now()
    label_matrix, candidates = taglet_executor.execute(unlabeled_data_loader, use_gpu, testing)
    confidence_active_learning.set_candidates(candidates)
    t2 = datetime.datetime.now()
    print()
    print(".....Taglet executing time: {}".format((t2 - t1).seconds))


    print("**********Label Model**********")
    t1 = datetime.datetime.now()
    soft_labels_unlabeled_images = get_label_distribution(label_matrix, len(task.classes), testing)
    t2 = datetime.datetime.now()
    print()
    print(".....Label Model time: {}".format((t2 - t1).seconds))



    print("**********End Model**********")
    t1 = datetime.datetime.now()
    if testing:
        unlabeled_image_names = unlabeled_image_names[:len(soft_labels_unlabeled_images)]
    end_model_train_data_loader = combine_soft_labels(soft_labels_unlabeled_images,
                                                     unlabeled_image_names,
                                                     train_image_names, train_image_labels)
    end_model.train(end_model_train_data_loader, val_data_loader, use_gpu, testing)
    t2 = datetime.datetime.now()
    print()
    print(".....End Model time: {}".format((t2 - t1).seconds))

    return end_model.predict(task.evaluation_image_path,
                                  task.number_of_channels,
                                  task.transform_image(),
                                  use_gpu)

def submit_predictions( predictions):
    submit_status = api.submit_prediction(predictions)
    session_status = api.get_session_status()
    print("Checkpoint scores", session_status['checkpoint_scores'])
    print("Phase:", session_status['pair_stage'])


In [13]:
run_checkpoints()

------------------------------------------------------------
--------------------base check point: 0---------------------
------------------------------------------------------------
number of unlabeled data: 4990
New labeled images: 300
Total labeled images: 310
number of training data: 248
number of validation data: 62
**********Training taglets on labeled data**********
...........prototype...........
epoch: 0
train loss: 0.0097
validation loss: 0.0745
validation acc: 8.0645%
Deep copying new best model.(validation of 0.0806%, over 0.0000%)
Epoch 1 result: 
Average training loss: 0.0097
Average validation loss: 0.0745
Average validation accuracy: 8.0645%
...........prototype...........
epoch: 0
train loss: 0.0097
validation loss: 0.0745
validation acc: 8.0645%
Deep copying new best model.(validation of 0.0806%, over 0.0000%)
Epoch 1 result: 
Average training loss: 0.0097
Average validation loss: 0.0745
Average validation accuracy: 8.0645%
...........finetune...........
epoch: 0
trai

  candidate_probabilities.append(torch.max(torch.nn.functional.softmax(outputs)).item())



.....Taglet executing time: 1
**********Label Model**********

.....Label Model time: 0
**********End Model**********
...........end model...........
epoch: 0
train loss: 0.0180
train acc: 1.0714%
validation loss: 0.0769
validation acc: 3.2258%
Deep copying new best model.(validation of 0.0323%, over 0.0000%)
Epoch 1 result: 
Average training loss: 0.0180
Average training accuracy: 1.0714%
Average validation loss: 0.0769
Average validation accuracy: 3.2258%

.....End Model time: 1
Checkpoint scores [{'accuracy': 0.047}]
Phase: base
------------------------------------------------------------
--------------------base check point: 1---------------------
------------------------------------------------------------
number of unlabeled data: 4690
New labeled images: 32
Total labeled images: 342
number of training data: 273
number of validation data: 69
**********Training taglets on labeled data**********
...........prototype...........
epoch: 0
train loss: 0.0088
validation loss: 0.0669
va

train loss: 0.0375
train acc: 5.1095%
validation loss: 0.0838
validation acc: 14.8148%
Epoch 1 result: 
Average training loss: 0.0375
Average training accuracy: 5.1095%
Average validation loss: 0.0838
Average validation accuracy: 14.8148%

.....End Model time: 1
Checkpoint scores [{'accuracy': 0.047}, {'accuracy': 0.12}, {'accuracy': 0.181}, {'accuracy': 0.133}, {'accuracy': 0.183}]
Phase: adaptation
------------------------------------------------------------
--------------------Adapt check point: 2---------------------
------------------------------------------------------------
number of unlabeled data: 4868
New labeled images: 32
Total labeled images: 164
number of training data: 131
number of validation data: 33
^^^^^^^^^ adaptation: loading from base
^^^^^^^^^ adaptation: loading from base
^^^^^^^^^ adaptation: loading from base
**********Training taglets on labeled data**********
...........prototype...........
epoch: 0
train loss: 0.0185
validation loss: 0.1401
validation acc: 