In [6]:
import numpy as np
import time
import json
from copy import deepcopy as dc
import torch

import joblib
import re

%load_ext autoreload
%autoreload 2

import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

from data.utils import votes_filter, plot_overlap_conflict
from data.trec.utils import load_data, get_inst, get_subject

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
train = load_data('train')
dev = load_data('dev')
test = load_data('test')

In [9]:
from annotators import WeakRule, BinaryRERules, regex_decision

train_instances, _ = get_inst(train)
dev_instances, dev_labels = get_inst(dev)

In [10]:
'''
0 - ABBR - Abbreviation

1 - DESC - Description and abstract concepts

2 - ENTY - Entities

3 - HUM - Human beings

4 - LOC - Locations

5 - NUM - Numeric values

'''



def first_word(instance):
    st = instance['string']
    word = st.split()[0].lower()
    if word == "who":
        return 0
    elif word == "where":
        return 1
    elif word == "when":
        return 2
    elif word == "why":
        return 3
    elif word == "how":
        return 4
    elif word == "name":
        return 5

   # elif word == 'what':
     #   return 6
    else:
        return -1
    
label_maps = {
    0: [3],
    1: [4],
    2: [5],
    3: [1],
    4: [1, 5],
    5: [2, 3],
    6: [0]
}

r0 = WeakRule(exec_module=first_word, label_maps=label_maps)
v3 = r0.execute(dev_instances)
acc, p, r, cwacc = r0.eval(dev_labels, class_wise_acc=True)
print(acc, p, r)
print(cwacc)

def called(instance):
    st = instance['string'].lower().split()
    if "called" in st:
        return 1
    
    return -1
    
label_maps = {
    0: [0, 1, 5],
    1: [2, 3, 4]
}

r1 = WeakRule(exec_module=called, label_maps=label_maps)
v3 = r1.execute(dev_instances)
acc, p, r, cwacc = r1.eval(dev_labels, class_wise_acc=True)
print(acc, p, r)
print(cwacc)

def mean(instance):
    st = instance['string'].lower().split()
    if "mean" in st or "meaning" in st:
        return 1
    
    return -1
    
label_maps = {
    0: [2, 3, 4, 5],
    1: [0, 1]
}

r2 = WeakRule(exec_module=mean, label_maps=label_maps)
v3 = r2.execute(dev_instances)
acc, p, r, cwacc = r2.eval(dev_labels, class_wise_acc=True)
print(acc, p, r)
print(cwacc)

def abbre(instance):
    st = instance['string'].lower()
    if "stand for" in st or "abbreviat" in st:
        return 1
    
    return -1
    
label_maps = {
    0: [1, 2, 3, 4, 5],
    1: [0]
}

r3 = WeakRule(exec_module=abbre, label_maps=label_maps)
v3 = r3.execute(dev_instances)
acc, p, r, cwacc = r3.eval(dev_labels, class_wise_acc=True)
print(acc, p, r)
print(cwacc)


# DESC
def desc(instance):
    st = instance['string'].lower()
    tokens = st.split()
    if "definition" in tokens or \
            "come from" in st or \
            "origin" in tokens:
        return 1
    
    return -1
    
label_maps = {
    0: [0, 2, 3, 4, 5],
    1: [1]
}


r4 = WeakRule(exec_module=desc, label_maps=label_maps)
v3 = r4.execute(dev_instances)
acc, p, r, cwacc = r4.eval(dev_labels, class_wise_acc=True)
print(acc, p, r)
print(cwacc)

# ENTY
def enty(instance):
    subject = get_subject(instance['string'])
    if subject in ("animal", "body", "color", "creative",
                   "currency", "disease", "event", "food", "instrument",
                   "language", "letter", "plant", "product", "religion",
                   "sport", "substance", "symbol", "technique", "term",
                   "vehicle", "word"):
        return 1
    
    return -1
    
label_maps = {
    0: [0, 1, 3, 4, 5],
    1: [2]
}

r5 = WeakRule(exec_module=enty, label_maps=label_maps)
v3 = r5.execute(dev_instances)
acc, p, r, cwacc = r5.eval(dev_labels, class_wise_acc=True)
print(acc, p, r)
print(cwacc)


# LOC
def loc(instance):
    subject = get_subject(instance['string'])
    if subject in ("city", "country", "mountain", "state", "capital"):
        return 1
    
    return -1
    
label_maps = {
    0: [0, 1, 2, 3, 5],
    1: [4]
}


r7 = WeakRule(exec_module=loc, label_maps=label_maps)
v3 = r7.execute(dev_instances)
acc, p, r, cwacc = r7.eval(dev_labels, class_wise_acc=True)
print(acc, p, r)
print(cwacc)


# NUM
def num(instance):
    st = instance['string'].lower()
    if "what year" in st:
        return 1
    
    return -1
    
label_maps = {
    0: [0, 1, 2, 3, 4],
    1: [5]
}

r8 = WeakRule(exec_module=num, label_maps=label_maps)
v3 = r8.execute(dev_instances)
acc, p, r, cwacc = r8.eval(dev_labels, class_wise_acc=True)
print(acc, p, r)
print(cwacc)


# NUM
def num1(instance):
    st = instance['string'].lower()
    if "how many" in st or 'how much' in st or 'how old' in st:
        return 1
    
    return -1
    
label_maps = {
    0: [0, 1, 2, 3, 4],
    1: [5]
}

r9 = WeakRule(exec_module=num1, label_maps=label_maps)
v3 = r9.execute(dev_instances)
acc, p, r, cwacc = r9.eval(dev_labels, class_wise_acc=True)
print(acc, p, r)
print(cwacc)


test_patt = 'what.*mean'
whatmean = BinaryRERules(name='what_*_mean',re_pattern=test_patt, 
                                 preproc=lambda inst:inst['string'].lower(), 
                                 label_maps={0:[0, 2, 3, 4,5], 1:[1]}, unipolar=True)
v5 = whatmean.execute(dev_instances)

acc, p, r, cwacc = whatmean.eval(dev_labels, class_wise_acc=True)
print(acc, p, r)
print(cwacc)


test_patt = 'what.*use of|what.*origin of|why do'
desc_r1 = BinaryRERules(name='what_*_',re_pattern=test_patt, 
                                 preproc=lambda inst:inst['string'].lower(), 
                                 label_maps={0:[0, 2, 3, 4,5], 1:[1]}, unipolar=True)
v5 = desc_r1.execute(dev_instances)

acc, p, r, cwacc = desc_r1.eval(dev_labels, class_wise_acc=True)
print(acc, p, r)
print(cwacc)


test_patt = 'how far|what.*birthday|how long|how deep|when did|when was|how tall|what month|population|toll|how big|how long|what year'
numpattern1 = BinaryRERules(name='num_patt',re_pattern=test_patt, 
                                 preproc=lambda inst:inst['string'].lower(), 
                                 label_maps={0:[0,1, 2, 3, 4], 1:[5]}, unipolar=True)
v5 = numpattern1.execute(dev_instances)

acc, p, r, cwacc = numpattern1.eval(dev_labels, class_wise_acc=True)
print(acc, p, r)
print(cwacc)


test_patt = 'what is the origin|what is the history|what.*mean|how do you buy|what is the difference|how can I|how do I|what effect'
descpatt_1 = BinaryRERules(name='num_patt',re_pattern=test_patt, 
                                 preproc=lambda inst:inst['string'].lower(), 
                                 label_maps={0:[0, 2, 3, 4,5], 1:[1]}, unipolar=True)
v5 = descpatt_1.execute(dev_instances)

acc, p, r, cwacc = descpatt_1.eval(dev_labels, class_wise_acc=True)
print(acc, p, r)
print(cwacc)


test_patt = 'how.*tell|how d.*affect|how do.*work|how do you fix|how do you get|how do you find|how do I find|how.*made'
descpatt_2 = BinaryRERules(name='num_patt',re_pattern=test_patt, 
                                 preproc=lambda inst:inst['string'].lower(), 
                                 label_maps={0:[0, 2, 3, 4,5], 1:[1]}, unipolar=True)
v5 = descpatt_2.execute(dev_instances)

acc, p, r, cwacc = descpatt_2.eval(dev_labels, class_wise_acc=True)
print(acc, p, r)
print(cwacc)



test_patt = 'how do|how was|how are|how is|how was|how could|how can'
newhow = BinaryRERules(name='num_patt',re_pattern=test_patt, 
                                 preproc=lambda inst:inst['string'].lower(), 
                                 label_maps={0:[0, 2, 3, 4,5], 1:[1]}, unipolar=True)
v5 = newhow.execute(dev_instances)

acc, p, r, cwacc = newhow.eval(dev_labels, class_wise_acc=True)
print(acc, p, r)
print(cwacc)



test_patt = 'what.*stand for'
wsf = BinaryRERules(name='what_*_sf',re_pattern=test_patt, 
                                 preproc=lambda inst:inst['string'].lower(), 
                                 label_maps={0:[1, 2, 3, 4,5], 1:[0]}, unipolar=True)
v5 = wsf.execute(dev_instances)

acc, p, r, cwacc = wsf.eval(dev_labels, class_wise_acc=True)
print(acc, p, r)
print(cwacc)

1.0 1.0 0.3213367609254499
[nan, 1.0, 1.0, 1.0, 1.0, 1.0]
1.0 1.0 0.017994858611825194
[nan, nan, 1.0, 1.0, 1.0, nan]
0.8571428571428571 0.8571428571428571 0.015424164524421594
[nan, 1.0, nan, 0.0, nan, nan]
1.0 1.0 0.012853470437017995
[1.0, nan, nan, nan, nan, nan]
1.0 1.0 0.007712082262210797
[nan, 1.0, nan, nan, nan, nan]
0.8888888888888888 0.8888888888888888 0.02056555269922879
[nan, nan, 1.0, nan, nan, 0.0]
1.0 1.0 0.03598971722365039
[nan, nan, nan, nan, 1.0, nan]
1.0 1.0 0.010282776349614395
[nan, nan, nan, nan, nan, 1.0]
0.9629629629629629 0.9629629629629629 0.06683804627249357
[nan, nan, nan, 0.0, nan, 1.0]
1.0 1.0 0.017994858611825194
[nan, 1.0, nan, nan, nan, nan]
1.0 1.0 0.02056555269922879
[nan, 1.0, nan, nan, nan, nan]
0.96 0.96 0.061696658097686374
[nan, nan, nan, nan, 0.0, 1.0]
1.0 1.0 0.038560411311053984
[nan, 1.0, nan, nan, nan, nan]
1.0 1.0 0.005141388174807198
[nan, 1.0, nan, nan, nan, nan]
1.0 1.0 0.038560411311053984
[nan, 1.0, nan, nan, nan, nan]
1.0 1.0 0.0128

In [11]:
from annotators import executor, evaluator

train_instances, _ = get_inst(train)

plfs = [r0,r1,r2,r3,r7,r8,r9, r4, r5,
              whatmean, 
              desc_r1, 
              numpattern1, 
              descpatt_1, 
              descpatt_2, 
              newhow,
              wsf]

votes, fid2clusters = executor(plfs, train_instances, one_indexed=False)
votes, _, train_idx = votes_filter(votes, return_idx=True)

{0: [3], 1: [4], 2: [5], 3: [1], 4: [1, 5], 5: [2, 3], 6: [0]}
{0: [0, 1, 5], 1: [2, 3, 4]}
{0: [2, 3, 4, 5], 1: [0, 1]}
{0: [1, 2, 3, 4, 5], 1: [0]}
{0: [0, 1, 2, 3, 5], 1: [4]}
{0: [0, 1, 2, 3, 4], 1: [5]}
{0: [0, 1, 2, 3, 4], 1: [5]}
{0: [0, 2, 3, 4, 5], 1: [1]}
{0: [0, 1, 3, 4, 5], 1: [2]}
{0: [0, 2, 3, 4, 5], 1: [1]}
{0: [0, 2, 3, 4, 5], 1: [1]}
{0: [0, 1, 2, 3, 4], 1: [5]}
{0: [0, 2, 3, 4, 5], 1: [1]}
{0: [0, 2, 3, 4, 5], 1: [1]}
{0: [0, 2, 3, 4, 5], 1: [1]}
{0: [1, 2, 3, 4, 5], 1: [0]}


In [12]:
from nplm import Abstention

lm_cfg={'lr': 0.01,
 'epoch': 5,
 'seed': 0,
 'batch_size': 8192,
 'momentum': 0.9,
 'step_schedule': 'p',
 'step_multiplier': 0.1}

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
labelmodel = Abstention(num_classes=6,
                        fid2clusters=fid2clusters,
                        opt_cfg=lm_cfg,
                        preset_cb=torch.ones([6]).to(device)/ 6,
                        device=device, verbose=True)

labelmodel.optimize(votes)
label_estim = labelmodel.weak_label(votes)

epoch   1: 100%|██████████| 1/1 [00:00<00:00, 15.41it/s, Epoch Loss: =13.4]
epoch   2: 100%|██████████| 1/1 [00:00<00:00, 17.44it/s, Epoch Loss: =13.3]
epoch   3: 100%|██████████| 1/1 [00:00<00:00, 17.56it/s, Epoch Loss: =13.3]
epoch   4:   0%|          | 0/1 [00:00<?, ?it/s]

Setup:  0.00919961929321289


epoch   4: 100%|██████████| 1/1 [00:00<00:00, 17.01it/s, Epoch Loss: =13.2]
epoch   5: 100%|██████████| 1/1 [00:00<00:00, 17.14it/s, Epoch Loss: =13.1]


Setup:  0.011838912963867188
Parallel Estimation:  0.007220745086669922


In [13]:
np.save('trec6_soft_labels', {'sl': label_estim, 'idx': train_idx})

In [14]:
from eval import gen_stats, topk_results

data_instance_test, data_label_test = get_inst(test)

votes_test, _ = executor(plfs, data_instance_test, one_indexed=False)


votes_test_filtered, label_filtered_test = votes_filter(votes_test, labels=data_label_test)


wlabels_test = labelmodel.weak_label(votes_test)

test_label_test = [elem+1 for elem in data_label_test]

_, stats = gen_stats(wlabels_test, test_label_test)
print(stats)
print(topk_results(wlabels_test, test_label_test, [1,2]))

{0: [3], 1: [4], 2: [5], 3: [1], 4: [1, 5], 5: [2, 3], 6: [0]}
{0: [0, 1, 5], 1: [2, 3, 4]}
{0: [2, 3, 4, 5], 1: [0, 1]}
{0: [1, 2, 3, 4, 5], 1: [0]}
{0: [0, 1, 2, 3, 5], 1: [4]}
{0: [0, 1, 2, 3, 4], 1: [5]}
{0: [0, 1, 2, 3, 4], 1: [5]}
{0: [0, 2, 3, 4, 5], 1: [1]}
{0: [0, 1, 3, 4, 5], 1: [2]}
{0: [0, 2, 3, 4, 5], 1: [1]}
{0: [0, 2, 3, 4, 5], 1: [1]}
{0: [0, 1, 2, 3, 4], 1: [5]}
{0: [0, 2, 3, 4, 5], 1: [1]}
{0: [0, 2, 3, 4, 5], 1: [1]}
{0: [0, 2, 3, 4, 5], 1: [1]}
{0: [1, 2, 3, 4, 5], 1: [0]}
Setup:  0.0019774436950683594
Parallel Estimation:  0.00121307373046875
(0.382, 0.7317976025615662, 0.4775683434404692, 0.42981563493713243)
[0.446 0.534]
