In [19]:
"""
This notebook runs an experiment on integrating human-written labeling functions (humanLF) with Codex-generated labeling functions (codexLF) to improve LF coverage and accuracy.
"""

import logging
import os

import fire
import numpy as np
import torch
from wrench._logging import LoggingHandler
from wrench.dataset import load_dataset
from wrench.endmodel import EndClassifierModel, LogRegModel
from wrench.labelmodel import Snorkel, MajorityVoting, MajorityWeightedVoting, DawidSkene, FlyingSquid

In [20]:
#### Label Model Abbreviations: 
## Snorkel: Snorkel
## Majority Voting: MV
## Weighted Majority Voting: WMV
## Dawid-Skene: DS
## FlyingSquid: FS

#### End Model Abbreviations:
## MLP: MLP
## Logistic Regression: LR
## BERT: BERT

def LM_FS(train_data, valid_data, test_data):
    """Fit a FlyingSquid label model and score it on the test split.

    Args:
        train_data: dataset whose weak labels the model is fit on.
        valid_data: dataset passed to ``fit`` as ``dataset_valid``.
        test_data: dataset the fitted model is evaluated on.

    Returns:
        Tuple ``(test accuracy, test 'f1_binary' score, fitted FlyingSquid model)``.
    """
    squid = FlyingSquid()
    squid.fit(dataset_train=train_data, dataset_valid=valid_data)
    # Evaluate 'acc' first, then 'f1_binary'.
    scores = [squid.test(test_data, metric) for metric in ('acc', 'f1_binary')]
    return scores[0], scores[1], squid

def LM_DS(train_data, valid_data, test_data):
    """Fit a Dawid-Skene label model; return (test acc, test 'f1_binary', fitted model).

    ``train_data`` and ``valid_data`` are forwarded to ``fit`` as
    ``dataset_train`` / ``dataset_valid``; both metrics are computed on ``test_data``.
    """
    dawid_skene = DawidSkene()
    dawid_skene.fit(dataset_train=train_data, dataset_valid=valid_data)
    test_acc, test_f1 = (dawid_skene.test(test_data, name) for name in ('acc', 'f1_binary'))
    return test_acc, test_f1, dawid_skene

def LM_WMV(train_data, valid_data, test_data):
    """Fit a weighted majority-voting label model and evaluate it on the test split.

    Args:
        train_data: dataset whose weak labels are aggregated.
        valid_data: accepted for signature parity with the other ``LM_*`` helpers;
            it is NOT forwarded to ``fit`` (``dataset_valid=None`` is passed instead).
        test_data: dataset the fitted model is evaluated on.

    Returns:
        Tuple ``(test accuracy, test 'f1_binary' score, fitted MajorityWeightedVoting)``.
    """
    majority_weighted_voter = MajorityWeightedVoting()
    # Kept as before: fit without a validation set.
    # NOTE(review): every other LM_* helper passes valid_data here — confirm whether
    # WMV should as well (it could change the fitted model).
    majority_weighted_voter.fit(
        dataset_train=train_data,
        dataset_valid=None
    )

    acc = majority_weighted_voter.test(test_data, 'acc')
    f1 = majority_weighted_voter.test(test_data, 'f1_binary')

    # Return the computed metrics. This previously returned `_, _`: IPython's last
    # cell output inside a notebook, or a NameError anywhere else.
    return acc, f1, majority_weighted_voter

def LM_MV(train_data, valid_data, test_data):
    """Fit a majority-vote label model and report its test metrics.

    Returns:
        Tuple ``(test accuracy, test 'f1_binary' score, fitted MajorityVoting model)``.
    """
    voter = MajorityVoting()
    voter.fit(dataset_train=train_data, dataset_valid=valid_data)
    # Dict comprehension preserves evaluation order: 'acc' then 'f1_binary'.
    metrics = {name: voter.test(test_data, name) for name in ('acc', 'f1_binary')}
    return metrics['acc'], metrics['f1_binary'], voter

def LM_Snorkel(train_data, valid_data, test_data, lr=0.01, l2=0.0, n_epochs=300):
    """Fit a Snorkel label model and evaluate it on the test split.

    Args:
        train_data: dataset whose weak labels the model is fit on.
        valid_data: dataset passed to ``fit`` as ``dataset_valid``.
        test_data: dataset the fitted model is evaluated on.
        lr, l2, n_epochs: forwarded unchanged to ``Snorkel(...)``. The defaults are the
            values previously hard-coded here, so existing 3-argument calls behave the same.

    Returns:
        Tuple ``(test accuracy, test 'f1_binary' score, fitted Snorkel model)``.
    """
    snorkel_label_model = Snorkel(
        lr=lr,
        l2=l2,
        n_epochs=n_epochs
    )
    snorkel_label_model.fit(
        dataset_train=train_data,
        dataset_valid=valid_data
    )

    acc = snorkel_label_model.test(test_data, 'acc')
    f1 = snorkel_label_model.test(test_data, 'f1_binary')

    return acc, f1, snorkel_label_model

def EM_MLP(train_data, soft_labels, valid_data, test_data, device):
    """Train wrench's EndClassifierModel with an MLP backbone and evaluate it on the test split.

    Args:
        train_data: training dataset (features come from here).
        soft_labels: training targets, forwarded to ``fit`` as ``y_train``.
        valid_data: dataset forwarded to ``fit`` as ``dataset_valid``.
        test_data: dataset the trained model is evaluated on.
        device: forwarded to ``fit`` as ``device``.

    Returns:
        Tuple ``(test accuracy, test 'f1_binary' score, trained EndClassifierModel)``.
    """
    # Fixed hyperparameters for the MLP end model.
    model_kwargs = dict(
        batch_size=512,
        test_batch_size=512,
        n_steps=10000,
        backbone='MLP',
        optimizer='Adam',
        optimizer_lr=1e-2,
        optimizer_weight_decay=0.0,
    )
    # Validation is scored on 'acc' every 10 steps with patience=100.
    fit_kwargs = dict(evaluation_step=10, metric='acc', patience=100, device=device)

    end_model = EndClassifierModel(**model_kwargs)
    end_model.fit(dataset_train=train_data, y_train=soft_labels, dataset_valid=valid_data, **fit_kwargs)
    return end_model.test(test_data, 'acc'), end_model.test(test_data, 'f1_binary'), end_model

def EM_LR(train_data, soft_labels, valid_data, test_data, device):
    """Train wrench's LogRegModel end model and evaluate it on the test split.

    ``soft_labels`` is forwarded to ``fit`` as ``y_train``. Validation on
    ``valid_data`` uses metric 'acc' every 10 steps with patience=100.

    Returns:
        Tuple ``(test accuracy, test 'f1_binary' score, trained LogRegModel)``.
    """
    logreg = LogRegModel(batch_size=512, test_batch_size=512, n_steps=10000)
    logreg.fit(
        dataset_train=train_data,
        y_train=soft_labels,
        dataset_valid=valid_data,
        evaluation_step=10,
        metric='acc',
        patience=100,
        device=device,
    )
    test_scores = tuple(logreg.test(test_data, name) for name in ('acc', 'f1_binary'))
    return test_scores + (logreg,)

In [21]:
# Experiment configuration.
#   dataset / codexLF / prompt_type: which dataset to load and whether to load
#   Codex LFs (prompt_type selects the Codex prompt variant when codexLF is True).
dataset, codexLF, prompt_type = "spouse", False, None
#   LM / EM: label-model and end-model abbreviations (both appear in the log-file name).
LM, EM = "Snorkel", "LR"

In [22]:
## Basic Config ##
# Tag for the log-file name; Codex LFs also need a prompt type (used below as pt_type).
if codexLF:
    if prompt_type is None:
        raise ValueError("prompt_type must be set when codexLF is True")
    LF_type = prompt_type + "_codex"
else:
    LF_type = "human"

log_dir = os.path.join("..", "exp_log", dataset)
# basicConfig(filename=...) raises FileNotFoundError when the directory does not exist.
os.makedirs(log_dir, exist_ok=True)
# NOTE: basicConfig is a no-op once the root logger has handlers, so re-running this
# cell in the same kernel keeps writing to the first log file (restart the kernel, or
# pass force=True on Python >= 3.8).
logging.basicConfig(
    filename=os.path.join(log_dir, LF_type + "_" + LM + "_" + EM + ".txt"),
    filemode="w",
    format='%(asctime)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    level=logging.INFO
)

logger = logging.getLogger(__name__)
dataset_path = '../datasets/'
# Fall back to CPU so the notebook still runs on machines without a CUDA device.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# BERT features are extracted for every dataset except census.
bert_feature = dataset != "census"
logger.info(f'Dataset: {dataset}')

# Label-model runner plus the name used in its log lines, keyed by LM abbreviation.
# Validated here so a typo fails before the (slow) dataset load instead of leaving
# `lm` undefined for later cells.
LM_RUNNERS = {
    "Snorkel": (LM_Snorkel, "Snorkel label model"),
    "MV": (LM_MV, "majority voter"),
    "WMV": (LM_WMV, "weighted majority voter"),
    "DS": (LM_DS, "Dawid-Skene label model"),
    "FS": (LM_FS, "Flying Squid label model"),
}
if LM not in LM_RUNNERS:
    logger.info('Cannot find this LM')
    raise ValueError(f"Unknown label model {LM!r}; expected one of {sorted(LM_RUNNERS)}")
##################

## Load Dataset ##
train_data, valid_data, test_data = load_dataset(
    dataset_path,
    dataset,
    use_codexLF=codexLF,
    pt_type=prompt_type,
    extract_feature=bert_feature,
    extract_fn='bert', # extract bert embedding
    model_name='bert-base-cased',
    cache_name='bert'
)
##################

## Train Label Model ##
run_label_model, lm_name = LM_RUNNERS[LM]
lm_acc, lm_f1, lm = run_label_model(train_data, valid_data, test_data)
logger.info(f'{lm_name} test acc: {round(lm_acc, 4)}')
logger.info(f'{lm_name} test f1: {round(lm_f1, 4)}')
##################

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 22254/22254 [00:00<00:00, 315275.14it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2811/2811 [00:00<00:00, 242391.99it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2701/2701 [00:00<00:00, 150237.58it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 300/300 [00:26<00:00, 11.16epoch/s]


In [23]:
## Filter out uncovered training data ##
# Split the training set into the examples the LFs cover and those they don't
# (presumably: at least one non-abstain LF vote — confirm in the dataset class).
train_data_uncovered = train_data.get_uncovered_subset()
print("Uncovered: ", len(train_data_uncovered))
train_data_covered = train_data.get_covered_subset()
print("Covered: ", len(train_data_covered))
# The two subsets are expected to partition the training set.
assert len(train_data_covered) + len(train_data_uncovered) == len(train_data)
# Coverage is measured on the *training* split; the log line used to say "test cov".
lm_cov = len(train_data_covered) / len(train_data)
logger.info(f'label model train cov: {round(lm_cov, 4)}')

Uncovered:  16520
Covered:  5734


In [24]:
# Aggregate the LF-covered training examples with the first label model `lm`:
# hard labels (the targets of the sklearn baseline below) and the predict_proba output.
aggregated_hard_labels_human, aggregated_soft_labels_human = (
    lm.predict(train_data_covered),
    lm.predict_proba(train_data_covered),
)

In [25]:
from datasets.spouse.LFs.human_label_function import LF1, LF2, LF3, LF4, LF5

In [26]:
# Labeling functions applied, in this order, to every uncovered training example.
LF_list = [
    LF1,
    LF2,
    LF3,
    LF4,
    LF5,
]

In [27]:
# Re-label the uncovered training examples with the LFs in LF_list. Each example's
# weak-label row is replaced, in place, by one vote per LF (in LF_list order). A None
# vote (presumably an abstain) is recorded as -1.
# NOTE(review): these LFs are imported from `human_label_function`, yet the labels they
# produce are later named `*_codex`; confirm which LF set is intended.
for i, example in enumerate(train_data_uncovered.examples):
    text = example["text"]
    votes = []
    for LF in LF_list:
        vote = LF.is_spouse(text)
        # Identity check (`is None`), not `== None`, which misbehaves for array-like votes.
        votes.append(-1 if vote is None else vote)
    train_data_uncovered.weak_labels[i] = votes

In [28]:
## Train a second label model on the re-labeled uncovered subset ##
# NOTE(review): the uncovered subset is used as the train, validation AND test set.
# Any validation-based model selection therefore sees its own training examples, and
# lm2_acc / lm2_f1 are in-sample scores, not held-out estimates. valid_data cannot simply
# be substituted if its weak labels come from a different LF set than LF_list — TODO confirm.
label_model_runners = {
    "Snorkel": LM_Snorkel,
    "MV": LM_MV,
    "WMV": LM_WMV,
    "DS": LM_DS,
    "FS": LM_FS,
}
if LM not in label_model_runners:
    # Previously this only logged, leaving `lm2` undefined for the next cell.
    logger.info('Cannot find this LM')
    raise ValueError(f"Unknown label model {LM!r}; expected one of {sorted(label_model_runners)}")
lm2_acc, lm2_f1, lm2 = label_model_runners[LM](
    train_data_uncovered, train_data_uncovered, train_data_uncovered
)
##################

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 300/300 [00:00<00:00, 627.22epoch/s]


In [29]:
# Aggregate the re-labeled uncovered examples with the second label model `lm2`:
# hard labels (appended to the sklearn training targets below) and the predict_proba output.
aggregated_hard_labels_codex, aggregated_soft_labels_codex = (
    lm2.predict(train_data_uncovered),
    lm2.predict_proba(train_data_uncovered),
)

In [30]:
# Disabled: would append the uncovered examples to train_data_covered, reassigning its
# ids/labels/features/examples/weak_labels attributes in place. It would also build
# combined hard/soft label arrays. It is kept as a bare string literal, so nothing
# executes; the cell only echoes the string as its output.
"""
train_data_covered.ids = train_data_covered.ids + train_data_uncovered.ids
train_data_covered.labels = train_data_covered.labels + train_data_uncovered.labels
train_data_covered.features = np.vstack((train_data_covered.features, train_data_uncovered.features))
train_data_covered.examples = train_data_covered.examples + train_data_uncovered.examples
train_data_covered.weak_labels = train_data_covered.weak_labels + train_data_uncovered.weak_labels
aggregated_hard_labels_integration = np.concatenate((aggregated_hard_labels_human, aggregated_hard_labels_codex))
aggregated_soft_labels_integration = np.vstack((aggregated_soft_labels_human, aggregated_soft_labels_codex))
"""

'\ntrain_data_covered.ids = train_data_covered.ids + train_data_uncovered.ids\ntrain_data_covered.labels = train_data_covered.labels + train_data_uncovered.labels\ntrain_data_covered.features = np.vstack((train_data_covered.features, train_data_uncovered.features))\ntrain_data_covered.examples = train_data_covered.examples + train_data_uncovered.examples\ntrain_data_covered.weak_labels = train_data_covered.weak_labels + train_data_uncovered.weak_labels\naggregated_hard_labels_integration = np.concatenate((aggregated_hard_labels_human, aggregated_hard_labels_codex))\naggregated_soft_labels_integration = np.vstack((aggregated_soft_labels_human, aggregated_soft_labels_codex))\n'

In [31]:
## Assemble features/labels for the sklearn logistic-regression comparison ##
# "human": only the LF-covered examples, labeled by the first label model `lm`.
# "integration": covered examples followed by the uncovered ones labeled by `lm2`.
# Rows of X and y must stay in the same order (covered first, then uncovered).
# The old `np.vstack((x))` did not stack anything (the parentheses are not a tuple);
# it only converted/copied the covered features, so asarray says that directly.
train_X_human = np.asarray(train_data_covered.features)
train_X_integration = np.vstack((train_data_covered.features, train_data_uncovered.features))

train_y_human = aggregated_hard_labels_human
train_y_integration = np.concatenate((aggregated_hard_labels_human, aggregated_hard_labels_codex))

# Catch X/y misalignment here rather than deep inside LogisticRegression.fit.
assert len(train_X_human) == len(train_y_human)
assert len(train_X_integration) == len(train_y_integration)

test_X = test_data.features
test_y = test_data.labels

In [32]:
# Disabled: wrench end-model training (MLP or LR) with soft labels on train_data_covered.
# It reads `aggregated_soft_labels_integration`, which is only defined inside the disabled
# string in the "In [30]" cell, so it cannot run as-is. The sklearn LogisticRegression
# cells below are used instead. Kept as a bare string literal, so nothing executes.
"""
## Train End Model ##
if EM == "MLP":
    em_acc, em_f1, em = EM_MLP(train_data_covered, aggregated_soft_labels_integration, valid_data, test_data, device)
    logger.info(f'End model (MLP) test acc: {round(em_acc, 4)}')
    logger.info(f'End model (MLP) test f1: {round(em_f1, 4)}')
elif EM == "LR":
    em_acc, em_f1, em = EM_LR(train_data_covered, aggregated_soft_labels_integration, valid_data, test_data, device)
    logger.info(f'End model (Logistic Regression) test acc: {round(em_acc, 4)}')
    logger.info(f'End model (Logistic Regression) test f1: {round(em_f1, 4)}')
elif EM == "Stop":
    logger.info(f'Not going to run end model')
else:
    logger.info(f'Cannot find this EM')
"""

'\n## Train End Model ##\nif EM == "MLP":\n    em_acc, em_f1, em = EM_MLP(train_data_covered, aggregated_soft_labels_integration, valid_data, test_data, device)\n    logger.info(f\'End model (MLP) test acc: {round(em_acc, 4)}\')\n    logger.info(f\'End model (MLP) test f1: {round(em_f1, 4)}\')\nelif EM == "LR":\n    em_acc, em_f1, em = EM_LR(train_data_covered, aggregated_soft_labels_integration, valid_data, test_data, device)\n    logger.info(f\'End model (Logistic Regression) test acc: {round(em_acc, 4)}\')\n    logger.info(f\'End model (Logistic Regression) test f1: {round(em_f1, 4)}\')\nelif EM == "Stop":\n    logger.info(f\'Not going to run end model\')\nelse:\n    logger.info(f\'Cannot find this EM\')\n'

In [33]:
from sklearn.linear_model import LogisticRegression

In [34]:
# Baseline: logistic regression trained only on the LF-covered examples.
lgr_human = LogisticRegression(max_iter=1000, solver="lbfgs", C=1e4)
lgr_human.fit(train_X_human, train_y_human)
print(f"Test Accuracy: {lgr_human.score(test_X, test_y) * 100:.1f}%")

Test Accuracy: 92.0%


In [35]:
# Integration: logistic regression on covered + re-labeled uncovered examples.
# NOTE(review): this fit was interrupted (KeyboardInterrupt) in the saved run. With weak
# regularization (C=1e4) and unscaled features, lbfgs may need many iterations here;
# consider feature scaling if it stays slow.
lgr_integration = LogisticRegression(
    C=1e4,
    max_iter=1000,
    solver="lbfgs",
)
lgr_integration.fit(train_X_integration, train_y_integration)
print(f"Test Accuracy: {lgr_integration.score(test_X, test_y) * 100:.1f}%")

KeyboardInterrupt: 