In [18]:
import logging
import os
import sys

import numpy as np
import torch

sys.path.append('../wrench')  # must run before the project imports below
from wrench.dataset import load_dataset
from wrench.logging import LoggingHandler
from wrench.labelmodel import Snorkel
from ars2 import create_unbalanced_set, calc_prior

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])

logger = logging.getLogger(__name__)

# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing the first time `device` is used.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#### Load dataset
dataset_path = 'wrench/datasets/'
# `data` selects the dataset directory; it is also the key into `paras` and
# part of the output file names written at the end of the notebook.
data = 'chemprot'
# Passed to create_unbalanced_set() and embedded in the output file names.
imbalance_ratio = 1
train_data, valid_data, test_data = load_dataset(
    dataset_path,
    data,
    extract_feature=True,
    extract_fn='bert', # extract bert embedding
    model_name='bert-base-cased',
    cache_name='bert'
)

2022-06-06 15:21:56 - loading data from wrench/datasets/chemprot/train.json


  0%|          | 0/12861 [00:00<?, ?it/s]

2022-06-06 15:21:56 - loading data from wrench/datasets/chemprot/valid.json


  0%|          | 0/1607 [00:00<?, ?it/s]

2022-06-06 15:21:57 - loading data from wrench/datasets/chemprot/test.json


  0%|          | 0/1607 [00:00<?, ?it/s]

2022-06-06 15:21:57 - loading features from wrench/datasets/chemprot/train_bert.pkl
2022-06-06 15:21:57 - loading features from wrench/datasets/chemprot/valid_bert.pkl
2022-06-06 15:21:57 - loading features from wrench/datasets/chemprot/test_bert.pkl


In [19]:
#### Subsample the train/valid splits according to `imbalance_ratio`
# To inspect the class counts *before* subsampling, run:
#   print(calc_prior(train_data.labels, train_data.n_class))
#   print(calc_prior(valid_data.labels, valid_data.n_class))
#
# NOTE: this cell rebinds train_data / valid_data to their subsets, so
# re-running it without re-running the loading cell subsamples the
# already-subsampled data.
imbalance_ids_train, imbalance_ids_valid = (
    create_unbalanced_set(train_data, imbalance_ratio),
    create_unbalanced_set(valid_data, imbalance_ratio),
)
train_data, valid_data = (
    train_data.create_subset(imbalance_ids_train),
    valid_data.create_subset(imbalance_ids_valid),
)
# Class counts after subsampling (train first, then valid).
print(calc_prior(train_data.labels, train_data.n_class))
print(calc_prior(valid_data.labels, valid_data.n_class))

[ 540 3422 1662 4067  395  606   55   51 1498  565]
[ 72 398 210 533  52  76   9   9 183  65]
[540, 3422, 1662, 4067, 395, 606, 55, 51, 1498, 565]
[72, 398, 210, 533, 52, 76, 9, 9, 183, 65]


In [20]:
# Per-dataset label-model hyperparameters.
# Each row is: (dataset, GenerativeModel (l2, lr, n_epochs), Snorkel (l2, lr, n_epochs)).
# The comprehension expands every row into
#   paras[dataset][model] == {'l2': ..., 'lr': ..., 'n_epochs': ...}
paras = {
    dataset: {
        'GenerativeModel': dict(zip(('l2', 'lr', 'n_epochs'), generative)),
        'Snorkel': dict(zip(('l2', 'lr', 'n_epochs'), snorkel)),
    }
    for dataset, generative, snorkel in (
        ('agnews',     (0.001,  0.0001, 100), (0.01,   0.01,   200)),
        ('basketball', (0.1,    5e-05,  50),  (1e-05,  0.0001, 10)),
        ('cdr',        (0.01,   1e-05,  5),   (0.001,  1e-05,  5)),
        ('census',     (0.1,    1e-05,  5),   (0.0001, 0.1,    10)),
        ('chemprot',   (1e-05,  1e-05,  5),   (0.001,  0.001,  5)),
        ('commercial', (1e-05,  1e-05,  50),  (1e-05,  0.1,    200)),
        ('imdb',       (0.1,    0.0001, 5),   (0.001,  0.1,    200)),
        ('semeval',    (0.1,    0.0001, 5),   (0.001,  0.001,  10)),
        ('sms',        (0.0001, 5e-05,  100), (0.0001, 1e-05,  5)),
        ('spouse',     (0.1,    5e-05,  100), (0.1,    0.001,  5)),
        ('tennis',     (0.1,    5e-05,  5),   (1e-05,  0.0001, 100)),
        ('trec',       (0.001,  0.0001, 200), (0.1,    0.001,  5)),
        ('yelp',       (0.1,    0.0001, 5),   (1e-05,  0.1,    100)),
        ('youtube',    (0.1,    0.0001, 50),  (0.01,   0.01,   200)),
    )
}

In [21]:
#### Run label model: Snorkel
# Build the label model from the hyperparameters registered for the current dataset.
label_model = Snorkel(**paras[data]['Snorkel'])
# Fit on the training split, passing the validation split alongside it.
label_model.fit(dataset_train=train_data, dataset_valid=valid_data)
# Score the fitted label model on the test split using the 'acc' metric.
acc = label_model.test(test_data, 'acc')
logger.info(f'label model test acc: {acc}')

100%|█████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 168.04epoch/s]

2022-06-06 15:22:03 - label model test acc: 0.5196017423771002





In [22]:
#### Filter out uncovered training data, then aggregate labels with the fitted label model
train_data = train_data.get_covered_subset()
# Hard predictions first, then probabilities, both on the covered subset.
aggregated_hard_labels, aggregated_soft_labels = (
    label_model.predict(train_data),
    label_model.predict_proba(train_data),
)
# Class counts of the covered training subset and of the validation split.
print(calc_prior(train_data.labels, train_data.n_class))
print(calc_prior(valid_data.labels, valid_data.n_class))

[434, 2580, 1533, 3846, 388, 606, 45, 50, 1092, 498]
[72, 398, 210, 533, 52, 76, 9, 9, 183, 65]


In [23]:
# Inspect the training indices returned by create_unbalanced_set(); left as a
# bare expression so the notebook displays the array as the cell output.
imbalance_ids_train

array([1998, 1592, 4257, ..., 2856, 3287,  502])

In [24]:
#### Save the sampled indices and the label model's hard predictions
output_dir = './wrench/im_examples/label_model_output'
# np.save does not create missing parent directories, so make sure the
# target directory exists before writing.
os.makedirs(output_dir, exist_ok=True)

# Shared suffix that identifies this run in every output file name.
run_tag = f'imbalance{imbalance_ratio}_{data}'

# NOTE(review): imbalance_ids_train was computed before train_data was reduced
# by get_covered_subset(), while aggregated_hard_labels was predicted on the
# covered subset, so these two arrays presumably differ in length -- confirm
# that downstream consumers re-apply the coverage filter before aligning them.
np.save(os.path.join(output_dir, f'train_ids_{run_tag}.npy'), imbalance_ids_train)
np.save(os.path.join(output_dir, f'valid_ids_{run_tag}.npy'), imbalance_ids_valid)
np.save(os.path.join(output_dir, f'pred_{run_tag}_hard.npy'), aggregated_hard_labels)