# Composing LFs

In [10]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

# Setting Snorkel DB location
import os

#For PostgreSQL
postgres_location = 'postgresql://jdunnmon:123@localhost:5432'
postgres_db_name = 'memex_json_db_snorkel'
os.environ['SNORKELDB'] = os.path.join(postgres_location,postgres_db_name)

# For SQLite
#db_location = '.'
#db_name = "snorkel_memex.db"
#os.environ['SNORKELDB'] = '{0}:///{1}/{2}'.format("sqlite", db_location, db_name)

# Start Snorkel session
from snorkel import SnorkelSession
session = SnorkelSession()

# Setting parallelism
parallelism = 16

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
from snorkel.annotations import load_gold_labels

L_gold_dev = load_gold_labels(session, annotator_name='gold', split=1)

In [29]:
from snorkel.models import Candidate, candidate_subclass

# Designing candidate subclasses
LocationExtraction = candidate_subclass('Location', ['location'])

cands_dev = session.query(LocationExtraction).filter(LocationExtraction.split == 1).all()

In [171]:
import geograpy
# installed from pip install git+https://github.com/reach2ashish/geograpy.git
ind = 288
cand = cands_dev[ind]
label = int(L_gold_dev[ind].toarray())
txt = cand.location.get_span()
sent = cand.get_parent().text
doc = cand.get_parent().document.sentences
print(f'cand: {txt}')
print(f'sent: {sent}')
print(f'label: {label:d}')
#print(f'doc: {doc}')
placs = geograpy.get_place_context(text=sent)
placs.country_regions
#placs.address_strings

from geograpy import extraction

e = extraction.Extractor(text=sent)
e.find_entities()
print(e.places)

cand: Ottawa Valley
sent: "Look no further, Kings and Queens of the Ottawa Valley court!
label: 1
['Kings', 'Queens', 'Ottawa Valley']


In [281]:
lf_from(cand)

(based|from|out)


1

In [300]:
import re
from fonduer.lf_helpers import (
    get_left_ngrams, get_right_ngrams, get_between_ngrams
)
from snorkel.lf_helpers import get_tagged_text
from snorkel.learning.utils import MentionScorer
import geotext
from geograpy import extraction
from snorkel_utils import *

def test_LF(test_candidates, test_labels, lf):
    """
    Gets the accuracy of a single LF on a split of the candidates, w.r.t. annotator labels,
    and also returns the error buckets of the candidates.
    """
    scorer          = MentionScorer(test_candidates, test_labels)
    test_marginals  = np.array([0.5 * (lf(c) + 1) for c in test_candidates])
    return scorer.score(test_marginals, set_unlabeled_as_neg=False, set_at_thresh_as_neg=False)
                
# List to parenthetical
def ltp(x):
    return '(' + '|'.join(x) + ')'

def rule_regex_search_before_A(candidate, pattern, sign):
    return sign if re.search(pattern + r'*{{A}}', get_tagged_text(candidate), flags=re.I) else 0

def overlap(a, b):
    """Check if a overlaps b.
    This is typically used to check if ANY of a list of phrases is in the ngrams returned by an lf_helper.
    :param a: A collection of items
    :param b: A collection of items
    :rtype: boolean
    """
    return not set(a).isdisjoint(b)

# PUT LFs HERE

def lf_geograpy_entity_neg(c):
    txt = c.location.get_span().lower()
    sent = c.get_parent().text
    e = extraction.Extractor(text=sent)
    e.find_entities()
    places = [p.lower() for p in e.places]
    if txt not in places:
        return -1
    else:
        return 0
    
def lf_geograpy_entity_pos(c):
    txt = c.location.get_span().lower()
    sent = c.get_parent().text
    e = extraction.Extractor(text=sent)
    e.find_entities()
    places = [p.lower() for p in e.places]
    if txt not in places:
        return 0
    else:
        return 1
    
def lf_geograpy_country(c):
    txt = c.location.get_span().lower()
    sent = c.get_parent().text
    places = geograpy.get_place_context(text=sent)
    if places.countries:
        return 1
    else:
        return 0
    
#def lf_location_words(c):
#    location_words = ['place']
#    txt = c.location.get_span().lower()
#    sent = c.get_parent().text
#    if any([a in sent.lower() for a in location_words]):
#        return 1 
#    else: 
#        return 0

def lf_call(c):
    call_words = ['call']
    return -1 if overlap(
      call_words, 
      get_left_ngrams(c, window=1)) else 0

def lf_many_locations(c):
    txt = c.location.get_span().lower()
    sent = c.get_parent().text
    e = extraction.Extractor(text=sent)
    e.find_entities()
    thresh = 3
    return -1 if len(e.places)>thresh else 0


def lf_is_country(c):
    txt = c.location.get_span().lower()
    if lookup_country_name(txt).lower() != 'no country': return 1 
    if lookup_country_alpha2(txt).lower() != 'no country': return 1 
    if lookup_country_alpha3(txt).lower() != 'no country': 
        return 1 
    else:
        return 0
    
    
def lf_is_state(c):
    txt = c.location.get_span().lower()
    if lookup_state_name(txt).lower() != 'no state' : return 1
    if lookup_state_abbr(txt).lower() != 'no state':
        return 1 
    else:
        return 0
    
def lf_following_words(c):
    following_words = ['area', 'escort', 'province']
    return 1 if overlap(
      following_words, 
      get_left_ngrams(c, window=3)) else 0

def lf_preceding_words(c):
    preceding_words = ['escort','province','area']
    return 1 if overlap(
      preceding_words, 
      get_right_ngrams(c, window=3)) else 0

def lf_escort(c):
    words = ['escort']
    return 1 if overlap(
      words, 
      get_right_ngrams(c, window=2)) or overlap(
      words, 
      get_left_ngrams(c, window=10))else 0

def lf_from(c):
    words = ['based']
    return 1 if overlap(
      words, 
      get_right_ngrams(c, window=5)) or overlap(
      words, 
      get_left_ngrams(c, window=5))else 0

def lf_area(c):
    words = ['area']
    return 1 if overlap(
      words, 
      get_right_ngrams(c, window=5)) or overlap(
      words, 
      get_left_ngrams(c, window=5))else 0


def lf_long_context(c):
    sent = c.get_parent().text.split()
    return -1 if len(sent)>10 else 0

#def lf_comma(c):
#    words = [',']
#    return 1 if overlap(
#      words, 
#      get_right_ngrams(c, window=1)) and overlap(
#      words, 
#      get_left_ngrams(c, window=1))else 0


#def lf_prep_before(c):
#    prep_words = ['to']
#    return -1 if overlap(
#      prep_words, 
#      get_left_ngrams(c, window=1)) else 0

In [301]:
tp, fp, tn, fn = test_LF(cands_dev, L_gold_dev, lf_geograpy_country)

Scores (Un-adjusted)
Pos. class accuracy: 1.0
Neg. class accuracy: 0.0
Precision            0.328
Recall               1.0
F1                   0.495
----------------------------------------
TP: 923 | FP: 1887 | TN: 0 | FN: 0



In [None]:
LFs = [
    lf_geograpy_entity_pos,
    lf_geograpy_entity_neg,
    #lf_location_words,
    lf_call,
    lf_many_locations,
    lf_following_words,
    lf_from,
]

In [137]:
from  snorkel.annotations import LabelAnnotator
import numpy as np
labeler = LabelAnnotator(lfs=LFs)

np.random.seed(1701)
%time L_dev = labeler.apply(split=1, parallelism=16)

Clearing existing...
Running UDF...
CPU times: user 7.45 s, sys: 1.04 s, total: 8.48 s
Wall time: 1min 14s


In [140]:
L_dev.lf_stats(session, L_gold_dev)

Unnamed: 0,j,Coverage,Overlaps,Conflicts,TP,FP,FN,TN,Empirical Acc.
lf_geograpy_entity_pos,0,0.735759,0.421733,0.41192,976,2098,0,0,0.317502
lf_geograpy_entity_neg,1,0.264241,0.114648,0.014361,0,0,328,776,0.702899
lf_location_words,2,0.056247,0.056247,0.046434,47,188,0,0,0.2
lf_call,3,0.006702,0.006702,0.001915,0,0,0,28,1.0
lf_many_locations,4,0.517951,0.517951,0.421494,0,0,621,1543,0.713031


In [113]:
np.sum(L_gold_dev==1)/len(L_gold_dev.todense())

0.3121110579224509

In [129]:
np.sum(L_gold_dev==-1)/len(L_gold_dev.todense())

0.687888942077549

In [124]:
(len(tp)+len(tn))/(len(fp)+len(fn)+len(tn)+len(fp))

0.3305660377358491

In [125]:
len(tp)/(len(tp)+len(fn))

0.7484662576687117

In [128]:
776/(776+328)

0.7028985507246377

In [260]:
from snorkel.viewer import SentenceNgramViewer
labeled = []
cands_dev_red = cands_dev[400:500]
L_gold_dev_red = L_gold_dev[400:500]
for ii, c in enumerate(cands_dev_red):
    if L_gold_dev_red[ii] == 1:
        labeled.append(c)
print("Number labeled:", len(labeled))

SentenceNgramViewer(labeled, session)

Number labeled: 32


<IPython.core.display.Javascript object>

SentenceNgramViewer(cids=[[[1, 10], [7], [22]], [[15], [4], [13]], [[19, 21], [14], [2]], [[5], [17], [26, 29]…

Clearing existing...
Running UDF...

KeyboardInterrupt: 

In [None]:
L_dev.lf_stats(L_gold_dev)

In [12]:
from  snorkel.annotations import LabelAnnotator
import numpy as np
labeler = LabelAnnotator(lfs=LFs)

np.random.seed(1701)
%time L_train = labeler.apply(split=0)
%time L_test = labeler.apply(split=0)
#L_train

# can also load with:
# %time L_train = labeler.load_matrix(session, split=0)

Clearing existing...
Running UDF...

CPU times: user 29min 16s, sys: 6.22 s, total: 29min 23s
Wall time: 30min 5s


In [13]:
L_train.lf_stats(session)

Unnamed: 0,j,Coverage,Overlaps,Conflicts
lf_geograpy_entity,0,1.0,0.100187,0.092394
lf_location_words,1,0.092456,0.092456,0.092394
lf_call,2,0.008105,0.008105,0.000374
lf_many_locations,3,0.0,0.0,0.0


## Training the Generative Model

In [17]:
from snorkel.learning import GenerativeModel

gen_model = GenerativeModel(n_threads=parallelism)
gen_model.train(L_train, epochs=100, decay=0.95, step_size=0.1 / L_train.shape[0], reg_param=1e-6)

TypeError: __init__() got an unexpected keyword argument 'n_threads'

In [None]:
gen_model.weights.lf_accuracy

In [None]:
train_marginals = gen_model.marginals(L_train)

In [None]:
import matplotlib.pyplot as plt
plt.hist(train_marginals, bins=20)
plt.show()

In [None]:
gen_model.learned_lf_stats()

In [None]:
L_dev = labeler.apply_existing(split=1)

In [None]:
tp, fp, tn, fn = gen_model.error_analysis(session, L_dev, L_gold_dev)

## Error Analysis

In [None]:
from snorkel.viewer import SentenceNgramViewer

# NOTE: This if-then statement is only to avoid opening the viewer during automated testing of this notebook
# You should ignore this!
import os
if 'CI' not in os.environ:
    sv = SentenceNgramViewer(fn, session)
else:
    sv = None

In [None]:
sv

In [None]:
c = sv.get_selected() if sv else list(fp.union(fn))[0]
c

In [None]:
c.labels

In [None]:
L_dev.lf_stats(session, L_gold_dev, gen_model.learned_lf_stats()['Accuracy'])

In [None]:
from snorkel.annotations import save_marginals
%time save_marginals(session, L_train, train_marginals)