In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import numpy as np

In [2]:
config = {
    'project': 'babble',
    'domain': 'protein',
    'db_name': 'babble_protein_goldlabeled',
    'splits': [0,1,2],
    'gold_explanations': True,
}

In [3]:
# Get DB connection string and add to globals
# NOTE: $SNORKELDB must be set before any snorkel imports
import os

default_db_name = 'babble_' + config['domain'] + ('_debug' if config.get('debug', False) else '')
DB_NAME = config.get('db_name', default_db_name)
if 'postgres' in config and config['postgres']:
    DB_TYPE = 'postgres'
else:
    DB_TYPE = 'sqlite'
    DB_NAME += '.db'
DB_ADDR = "localhost:{0}".format(config['db_port']) if 'db_port' in config else ""
os.environ['SNORKELDB'] = '{0}://{1}/{2}'.format(DB_TYPE, DB_ADDR, DB_NAME)
print("$SNORKELDB = {0}".format(os.environ['SNORKELDB']))

$SNORKELDB = sqlite:///babble_protein_goldlabeled.db


In [4]:
from snorkel import SnorkelSession
session = SnorkelSession()

# Resolve config conflicts (nb_config > local_config > global_config)
from snorkel.contrib.pipelines import merge_configs, get_local_pipeline
config = merge_configs(config)

from snorkel.models import candidate_subclass
candidate_class = candidate_subclass(config['candidate_name'], config['candidate_entities'])

pipeline = get_local_pipeline(config['domain'])
pipe = pipeline(session, candidate_class, config)

Overwriting domain=None to domain=protein
Overwriting lr=0.01 to lr=0.001
Overwriting print_freq=1 to print_freq=5
Overwriting l2_penalty=1.0 to l2_penalty=0.1
Overwriting rebalance=0.25 to rebalance=0.5
Overwriting l1_penalty=1.0 to l1_penalty=0
Overwriting n_epochs=25 to n_epochs=20
Overwriting babbler_candidate_split=1 to babbler_candidate_split=[0, 1, 2]
Overwriting step_size=0.01 to step_size=0.25
Overwriting reg_param=0.25 to reg_param=0.1
Overwriting gold_explanations=False to gold_explanations=True
Overwriting traditional_split=0 to traditional_split=1
Using ProteinPipeline object.


In [5]:
# pipe.load_gold()

In [6]:
pipe.collect()

Linking candidates...
# CANDIDATES: 7615
Building list of target candidate ids...
Collected 29 unique target candidate ids from 30 explanations.
Gathering desired candidates...
Found 29/29 desired candidates
Linking explanations to candidates...
Linked 30/30 explanations
Calling babbler...
Created grammar with 599 rules


In [7]:
from snorkel.annotations import load_gold_labels, load_label_matrix

TRAIN = 0
DEV = 1
TEST = 2

SPLIT = TEST

L = load_label_matrix(pipe.session, split=SPLIT)
L_gold = load_gold_labels(pipe.session, annotator_name='gold', split=SPLIT)
gold = np.array(L_gold.todense().T)[0]
candidates = [L.get_candidate(pipe.session, i) for i in range(L.shape[0])]

orig_accs = np.array(L.lf_stats(pipe.session, labels=L_gold)['Empirical Acc.'])
orig_lfs = [pipe.babbler.semparser.grammar.translate(exp.semantics) for exp in pipe.explanations]

  ac = (tp+tn).astype(float) / (tp+tn+fp+fn)


In [8]:
FILE = 'no_string_{}_lfs.npy'.format(config['domain'])
perts_all = np.load(os.environ['SNORKELHOME'] + '/experiments/babble/' + config['domain'] + '/data/' + FILE)

In [9]:
from collections import namedtuple

ParseMock = namedtuple('ParseMock', ['semantics'])

In [10]:
pert_wins = 0
pert_lose = 0

total_better = 0
total_worse = 0
total_failed = 0
    
garbage = []
    
for i, (orig_lf, orig_acc, perts_list) in enumerate(zip(orig_lfs, orig_accs, perts_all)):
    print("\nLF #{}".format(i))
    print("Original LF: {}".format(orig_lf))
    print("Original Acc: {}".format(orig_acc))
    better = 0
    worse = 0
    failed = 0
    for pert in perts_list:
        sem = eval(pert)
        try:
            lf = pipe.babbler.semparser.grammar.evaluate(ParseMock(sem))
            sig = [lf(c) for c in candidates]
            
            prod = np.multiply(sig, gold)
            agree = sum(prod == 1)
            disag = sum(prod == -1)
            acc = float(agree)/(agree+disag)
            
            if acc > orig_acc:
                print("Better Acc: {}".format(acc))
                better += 1
                print("Better LF: {}".format(pipe.babbler.semparser.grammar.translate(sem)))
            else:
                worse += 1
        except:
            failed += 1
            garbage.append(sem)
        if better:
            break
            
    if better:
        pert_wins += 1
    else:
        pert_lose += 1
        
    total_better += better
    total_worse += worse
    total_failed += failed
        
print("Summary:")
print("Pert Wins: {}".format(pert_wins))
print("Pert Lose: {}".format(pert_lose))

print("Total Better: {}".format(total_better))
print("Total Worse : {}".format(total_worse))
print("Total Failed: {}".format(total_failed))


LF #0
Original LF: return 1 if (any([s.in(text(between([X,Y]))) for s in user_list('prep')]) and (not any([s.in(text(between([X,Y]))) for s in user_list('negative')]) and count(between([X,Y])).(< 10))) else 0
Original Acc: 0.485436893204

LF #1
Original LF: return -1 if any([s.in(text(the sentence)) for s in user_list('nucleic_acids')]) else 0
Original Acc: 0.955882352941

LF #2
Original LF: return 1 if ('activ'.in(text(between([X,Y]))) and 'activ'.(all([in(text(within 40 chars of X)),in(text(within 40 chars of Y))]))) else 0
Original Acc: 0.134615384615
Better Acc: 0.161290322581
Better LF: return 1 if ('activ'.in(text(between([X,Y]))) and 'activ'.(all([in(text(within 30 chars of X)),in(text(within 30 chars of Y))]))) else 0

LF #3
Original LF: return 1 if (text(the sentence).contains('activates') and text(X).in(text(less than 6 word(s) to the right of Y))) else 0
Original Acc: 0.75
Better Acc: 1.0
Better LF: return 1 if (text(the sentence).contains('activates') and text(X).in(text(e

In [11]:
garbage

[('.root',
  ('.label',
   ('.bool', True),
   ('.and',
    ('.any',
     ('.map',
      ('.in',
       ('.extract_text',
        ('.between',
         ('.list', ('.arg', ('.int', 0)), ('.arg', ('.int', 2)))))),
      ('.user_list', ('.string', u'prep')))),
    ('.and',
     ('.none',
      ('.map',
       ('.in',
        ('.extract_text',
         ('.between',
          ('.list', ('.arg', ('.int', 0)), ('.arg', ('.int', 2)))))),
       ('.user_list', ('.string', u'negative')))),
     ('.call',
      ('.lt', ('.int', 0)),
      ('.count',
       ('.between',
        ('.list', ('.arg', ('.int', 0)), ('.arg', ('.int', 2)))))))))),
 ('.root',
  ('.label',
   ('.bool', True),
   ('.and',
    ('.any',
     ('.map',
      ('.in',
       ('.extract_text',
        ('.between',
         ('.list', ('.arg', ('.int', 2)), ('.arg', ('.int', 2)))))),
      ('.user_list', ('.string', u'prep')))),
    ('.and',
     ('.none',
      ('.map',
       ('.in',
        ('.extract_text',
         ('.between',