# Step 2: Write Labeling Functions and Train Generative Model

The first thing to do is ensure that modules are auto-reloaded at runtime to allow for development in other files.

In [None]:
# Reload imported modules automatically before each cell runs, so edits made
# to helper files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

We then set the Snorkel database location and connect to it.  By default, we use a PostgreSQL database backend, which can be created using `createdb DB_NAME` once PostgreSQL (`psql`) is installed.  Note that Snorkel does *not* currently support parallel database processing with a SQLite backend.

In [None]:
# Configure the Snorkel database location, open a session, and fix the
# parallelism level and random seeds used by the rest of the notebook.
import os
import sys

import random
import numpy as np

# Option 1: networked PostgreSQL (disabled).
# Never commit real credentials; substitute your own user/password or read
# them from the environment.
#postgres_location = 'postgresql://<user>:<password>@localhost:5432'
#postgres_db_name = 'es_locs_small'
#os.environ['SNORKELDB'] = os.path.join(postgres_location,postgres_db_name)

# Option 2 (active): local PostgreSQL database named `es_locs_small`.
# The variable is set before the Snorkel import below.
os.environ['SNORKELDB'] = 'postgres:///es_locs_small'

# Make the parent directory importable so the project helper modules
# (dataset_utils, gm_utils) can be imported in later cells.
sys.path.append('..')

# Option 3: SQLite (disabled).
#db_location = '.'
#db_name = "es_locs_small.db"
#os.environ['SNORKELDB'] = '{0}:///{1}/{2}'.format("sqlite", db_location, db_name)

# Start Snorkel session
from snorkel import SnorkelSession
session = SnorkelSession()

# Degree of parallelism passed to labeler.apply(...) and searcher.fit(...) below.
parallelism = 32

# Seed both Python's and NumPy's global random number generators.
seed = 1701
random.seed(seed)
np.random.seed(seed)

Create candidate subclass and get dev set candidates.

In [None]:
from dataset_utils import create_candidate_class

# Setting extraction type -- should be a subfield in your data source extractions field!
# The same value is later passed to create_extractions_dict(extractions=[...]).
extraction_type = 'location'

# Build the candidate class (and its name) for this extraction type.
candidate_class, candidate_class_name  = create_candidate_class(extraction_type)

# Fetch every candidate in split 1 (the dev set), ordered by id so the list
# order is stable across runs, then report how many there are.
cands_dev = session.query(candidate_class).filter(candidate_class.split == 1).order_by(candidate_class.id).all()
print(f'Dev Candidates: {len(cands_dev)}')

Write Labeling Functions (LFs).

In [None]:
import re
from fonduer.lf_helpers import get_left_ngrams, get_right_ngrams, get_between_ngrams
from snorkel.lf_helpers import get_tagged_text

import geotext
import geograpy
from geograpy import extraction

from gm_utils import *

def lf_geograpy_entity_neg(c):
    """Vote -1 when the candidate span is not one of the sentence's extracted places.

    Runs geograpy's ``Extractor`` over the parent sentence text and compares
    the lower-cased candidate span against the lower-cased place names it
    found. Abstains (0) when the span matches one of them.
    """
    span_text = c.location.get_span().lower()
    extractor = extraction.Extractor(text=c.get_parent().text)
    extractor.find_entities()
    # Case-insensitive comparison: both sides are lower-cased.
    known_places = {place.lower() for place in extractor.places}
    return 0 if span_text in known_places else -1

def _sentence_has_place_kind(c, kind):
    """Return 1 if geograpy reports any place of ``kind`` in the candidate's sentence, else 0.

    ``kind`` is the attribute name read from the object returned by
    ``geograpy.get_place_context`` ('countries', 'regions', 'cities', 'other').
    """
    context = geograpy.get_place_context(text=c.get_parent().text)
    return 1 if getattr(context, kind) else 0


# NOTE(review): the four LFs below vote on the whole parent sentence, not on
# the candidate span itself. The original bodies computed the lower-cased span
# but never used it, so a span-level check may have been intended -- confirm.

def lf_geograpy_country(c):
    """Vote +1 when geograpy finds at least one country in the parent sentence; else abstain (0)."""
    return _sentence_has_place_kind(c, 'countries')

def lf_geograpy_region(c):
    """Vote +1 when geograpy finds at least one region in the parent sentence; else abstain (0)."""
    return _sentence_has_place_kind(c, 'regions')

def lf_geograpy_city(c):
    """Vote +1 when geograpy finds at least one city in the parent sentence; else abstain (0)."""
    return _sentence_has_place_kind(c, 'cities')

def lf_geograpy_other(c):
    """Vote +1 when geograpy reports any 'other' place in the parent sentence; else abstain (0)."""
    return _sentence_has_place_kind(c, 'other')
    
def lf_following_words(c):
    """Vote +1 when a location cue word appears in the candidate's right context.

    The right context is whatever ``get_right_ngrams(c, window=3)`` yields;
    ``overlap`` decides whether it shares an entry with the cue list.
    Abstains (0) otherwise.
    """
    location_cues = ['area', 'region', 'province', 'county', 'in']
    right_context = get_right_ngrams(c, window=3)
    if overlap(location_cues, right_context):
        return 1
    return 0

def lf_preceding_words(c):
    """Vote +1 when a locative cue word appears in the candidate's left context.

    The left context is whatever ``get_left_ngrams(c, window=4)`` yields;
    ``overlap`` decides whether it shares an entry with the cue list.
    Abstains (0) otherwise.
    """
    locative_cues = ['in', 'to', 'at', 'of', 'north', 'south', 'east', 'west', 'address', 'downtown', 'en']
    left_context = get_left_ngrams(c, window=4)
    if overlap(locative_cues, left_context):
        return 1
    return 0

def lf_preceding_words_enhanced(c):
    """Vote +1 when one of the additional cue words appears in the left context.

    Uses the same ``get_left_ngrams(c, window=4)`` context and ``overlap``
    test as ``lf_preceding_words`` but with a different word list.
    Abstains (0) otherwise.
    """
    extra_cues = ['based', 'back', 'come', 'new', 'located', 'location', 'you', 'your']
    left_context = get_left_ngrams(c, window=4)
    if overlap(extra_cues, left_context):
        return 1
    return 0

def lf_neg_preceding_words(c):
    """Vote -1 when a negative cue word appears in the candidate's left context.

    The left context is whatever ``get_left_ngrams(c, window=4)`` yields;
    ``overlap`` decides whether it shares an entry with the cue list.
    Abstains (0) otherwise.
    """
    negative_cues = ['trained', 'show', 'from']
    left_context = get_left_ngrams(c, window=4)
    if overlap(negative_cues, left_context):
        return -1
    return 0

def lf_in_url(c):
    """Vote +1 when the token 'url' appears in the candidate's right context.

    The right context is whatever ``get_right_ngrams(c, window=2)`` yields;
    ``overlap`` decides whether it contains 'url'. Abstains (0) otherwise.
    """
    url_markers = ['url']
    right_context = get_right_ngrams(c, window=2)
    if overlap(url_markers, right_context):
        return 1
    return 0

def lf_many_locations(c):
    """Vote -1 when the parent sentence contains more than three extracted places.

    Runs geograpy's ``Extractor`` over the parent sentence text and counts the
    places it found. Abstains (0) when the count is three or fewer. Only the
    sentence is inspected; the candidate span itself is not read.
    """
    max_places = 3
    extractor = extraction.Extractor(text=c.get_parent().text)
    extractor.find_entities()
    return -1 if len(extractor.places) > max_places else 0

def lf_nonletter(c):
    """Vote -1 when the lower-cased candidate span contains anything but a-z or a space.

    Abstains (0) when every character is an ASCII lowercase letter or a space
    (including the empty span).
    """
    # NOTE(review): accented letters (e.g. in non-English place names) also
    # count as "non-letter" under this pattern -- confirm that is intended.
    span_text = c.location.get_span().lower()
    return -1 if re.search(r'[^a-z ]', span_text) else 0

In [None]:
def lf_loc(c):
    """Combine the region and city LFs: return the region vote if non-zero, else the city vote."""
    region_vote = lf_geograpy_region(c)
    if region_vote:
        return region_vote
    # The city LF is only evaluated when the region LF abstained.
    return lf_geograpy_city(c)
    
def lf_nonloc(c):
    """Vote -1 when ``lf_loc`` abstains (returns a falsy value); otherwise abstain (0)."""
    if lf_loc(c):
        return 0
    return -1

def lf_from(c):
    """Require both a preceding cue word and a location signal.

    Returns the ``lf_preceding_words`` vote when it is falsy; otherwise
    returns the ``lf_loc`` vote.
    """
    preceded = lf_preceding_words(c)
    if not preceded:
        return preceded
    # lf_loc is only evaluated once the preceding-word LF has fired.
    return lf_loc(c)

def lf_from_enhanced(c):
    """Require both preceding-word LFs and a location signal.

    Evaluates ``lf_preceding_words``, ``lf_preceding_words_enhanced`` and
    ``lf_loc`` in that order, returning the first falsy vote, or the
    ``lf_loc`` vote when the first two fire.
    """
    basic_vote = lf_preceding_words(c)
    if not basic_vote:
        return basic_vote
    enhanced_vote = lf_preceding_words_enhanced(c)
    if not enhanced_vote:
        return enhanced_vote
    return lf_loc(c)

def lf_loc_url(c):
    """Require both a location signal and the URL marker.

    Returns the ``lf_loc`` vote when it is falsy; otherwise returns the
    ``lf_in_url`` vote.
    """
    loc_vote = lf_loc(c)
    if not loc_vote:
        return loc_vote
    # lf_in_url is only evaluated once lf_loc has fired.
    return lf_in_url(c)

Creating list of LFs to use

In [None]:
# Labeling functions handed to LabelAnnotator in the cells below.
# Several defined LFs are only used indirectly through the composites listed
# here (lf_geograpy_region / lf_geograpy_city via lf_loc, lf_preceding_words
# via lf_from, lf_preceding_words_enhanced via lf_from_enhanced, lf_in_url via
# lf_loc_url). lf_geograpy_country, lf_geograpy_other and lf_following_words
# are defined above but not registered.
LFs = [
    lf_geograpy_entity_neg,
    lf_nonletter,
    lf_nonloc,
    lf_from,
    lf_from_enhanced,
    lf_neg_preceding_words,
    lf_many_locations,
    lf_loc_url,
]

Loading gold dev set labels from database.

In [None]:
from snorkel.annotations import load_gold_labels

# Gold labels stored under annotator 'gold' for split 1 (the dev split used
# for cands_dev); used below to score the LFs and the generative model.
L_gold_dev = load_gold_labels(session, annotator_name='gold', split=1)

Evaluating labeling functions on dev set.

In [None]:
# One-time setup, only run once per environment: NLTK resources to download
# (presumably needed by the geograpy-based LFs -- confirm).
# The code is wrapped in a string so it does not execute; copy it into a cell
# to run it.
"""
import nltk
nltk.download('punkt')
nltk.download('averaged_perceptron_tagger')
nltk.download('maxent_ne_chunker')
nltk.download('words')
"""

In [None]:
from  snorkel.annotations import LabelAnnotator
import numpy as np
# Annotator that applies every function in LFs to candidates.
labeler = LabelAnnotator(lfs=LFs)

# Apply the LFs to split 1 (dev) and time the run.
%time L_dev = labeler.apply(split=1, parallelism=parallelism)
# Per-LF statistics, scored against the gold dev labels.
L_dev.lf_stats(session, L_gold_dev)

Creating viewer to assist in LF development.

In [None]:
from snorkel.viewer import SentenceNgramViewer

# A function can be inserted here to select candidates based on arbitrary criteria

# Create a viewer over all dev candidates; the bare `sv` on the last line
# displays it as the cell output.
sv = SentenceNgramViewer(cands_dev, session)
sv

Once LFs are performing well, apply to entire database.  Applying to unlabeled data can take a while.

In [None]:
from  snorkel.annotations import LabelAnnotator
import numpy as np
# Re-create the annotator so it picks up the current contents of LFs.
labeler = LabelAnnotator(lfs=LFs)

# Apply the LFs to split 0 (L_train) and split 2 (L_test), timing each run.
%time L_train = labeler.apply(split=0, parallelism=parallelism)
%time L_test = labeler.apply(split=2, parallelism=parallelism)

# An existing label matrix can be loaded instead of re-applying the LFs:
# %time L_train = labeler.load_matrix(session, split=0)

Training the generative model.

In [None]:
from snorkel.learning import GenerativeModel
from snorkel.learning import RandomSearch

# Candidate hyperparameter values the random search samples from.
param_ranges = {
    'step_size' : [1e-2, 1e-3, 1e-4, 1e-5, 1e-6],
    'decay' : [1.0, 0.95, 0.9],
    'epochs' : [20, 50, 100]
}

# Creating generative model
# NOTE(review): this instance is never used -- `gen_model` is rebound to the
# model returned by searcher.fit(...) below.
gen_model = GenerativeModel()

# Creating searcher over hyperparameters-- n is the number of models to train
searcher = RandomSearch(GenerativeModel, param_ranges, L_train, n=5)

# Run the search, scoring against the dev label matrix and gold dev labels;
# returns the selected model and the per-run statistics.
%time gen_model, run_stats = searcher.fit(L_dev, L_gold_dev, n_threads=parallelism)

# Display the results of the model search as the cell output.
run_stats

Printing learned LF accuracies.

In [None]:
# Display the LF accuracies learned by the generative model.
gen_model.weights.lf_accuracy

Error analysis for generative model

In [None]:
# Score the generative model on the dev label matrix against the gold labels.
# The four results are presumably the true-positive, false-positive,
# true-negative and false-negative candidate groups -- confirm against Snorkel.
tp, fp, tn, fn = gen_model.error_analysis(session, L_dev, L_gold_dev)

In [None]:
# Inspect the `fp` group from the error analysis in the viewer (replaces the
# earlier dev-set viewer bound to `sv`).
sv = SentenceNgramViewer(fp, session)
sv

Getting marginals, plotting training marginals

In [None]:
import matplotlib.pyplot as plt

# Evaluate on the test-split label matrix; compute marginals for both the
# evaluation set and the training set with the selected generative model.
L_eval = L_test
eval_marginals = gen_model.marginals(L_eval)
training_marginals = gen_model.marginals(L_train)

# Plot a 20-bin histogram of the training marginals.
plt.hist(training_marginals, bins=20)
plt.show()

Creating dictionary of extractions

In [6]:
from gm_utils import create_extractions_dict
# Build the extractions dictionary for the evaluation set from its marginals.
# NOTE(review): geocode=True presumably triggers geocoding lookups and may be
# slow -- confirm in gm_utils.
doc_extractions = create_extractions_dict(session, L_eval, eval_marginals, extractions=[extraction_type], dummy=False, geocode=True)

# Display the extractions dict to check its format (comment this line out to
# avoid dumping the whole dictionary).
doc_extractions

KeyboardInterrupt: 

In [4]:
# Optional: compare extractions against the original data file.
# The code is wrapped in a string so it does not execute; copy it into a cell
# to run it. It prints the first row of the TSV, then column index 7 of each
# of the next 10 rows.
"""
import csv

data_loc = 'output_all100K.tsv'
tsvfile = open(data_loc)
reader = csv.reader(tsvfile, delimiter='\t')
print(next(reader))

for i in range(10):
    line = next(reader)
    print(line[7])
    print()
"""

"\nimport csv\n\ndata_loc = 'output_all100K.tsv'\ntsvfile = open(data_loc)\nreader = csv.reader(tsvfile, delimiter='\t')\nprint(next(reader))\n\nfor i in range(10):\n    line = next(reader)\n    print(line[7])\n    print()\n"

###### Saving extractions to json file

In [None]:
import json

# Destination path for the JSON-lines export.
out_filename = "loc_ext_test_generative.jsonl"

# Write one JSON object per line. Each record is tagged in place with the key
# it is stored under in doc_extractions (saved as its 'url' field).
with open(out_filename, 'w') as outfile:
    for k, v in doc_extractions.items():
        v['url'] = k
        outfile.write(json.dumps(v) + '\n')

Saving training marginals for use with discriminative model

In [None]:
from snorkel.annotations import save_marginals
# Persist the training marginals for the candidates in L_train through the
# Snorkel session, timing the call.
%time save_marginals(session, L_train, training_marginals)