# Step 4: Evaluate Saved Discriminative Model

The first thing to do is ensure that modules are auto-reloaded at runtime to allow for development in other files.

In [1]:
# Load the autoreload extension and enable mode 2, so that helper modules
# edited in other files are re-imported before each cell is executed.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

We then set the Snorkel database location, start a Snorkel session, and connect to the database.  By default, we use a PostgreSQL database backend, which can be created using `createdb DB_NAME` once psql is installed.  Note that Snorkel does *not* currently support parallel database processing with a SQLite backend.

In [2]:
# Setting Snorkel DB location
import os
import sys

import random
import numpy as np

# For PostgreSQL
# The server location (including any user name / password) is read from the
# SNORKEL_POSTGRES_LOCATION environment variable so that no credentials are
# stored in the notebook, e.g.
#   export SNORKEL_POSTGRES_LOCATION='postgresql://USER:PASSWORD@localhost:5432'
# When the variable is unset we fall back to a credential-free local URL
# (presumably the driver then uses its own defaults such as PGUSER/.pgpass --
# confirm for your setup).
postgres_location = os.environ.get('SNORKEL_POSTGRES_LOCATION', 'postgresql://localhost:5432')
postgres_db_name = 'loc_jd_1M_parsed'
# A database URL always uses '/', whereas os.path.join inserts the platform
# path separator, so the two parts are joined explicitly.
os.environ['SNORKELDB'] = '{0}/{1}'.format(postgres_location.rstrip('/'), postgres_db_name)

# For local PostgreSQL
#os.environ['SNORKELDB'] = 'postgres:///es_locs_small'

# Adding path above for utils (guarded so re-running the cell does not keep
# appending duplicate entries to sys.path)
utils_path = '../utils'
if utils_path not in sys.path:
    sys.path.append(utils_path)

# For SQLite
#db_location = '.'
#db_name = "es_locs_small.db"
#os.environ['SNORKELDB'] = '{0}:///{1}/{2}'.format("sqlite", db_location, db_name)

# Start Snorkel session
# NOTE: snorkel is imported only after SNORKELDB has been set above; presumably
# it reads the variable at import time, so keep this ordering.
from snorkel import SnorkelSession
session = SnorkelSession()

# Setting parallelism (number of workers/threads handed to later steps)
parallelism = 72

# Setting random seed for both the stdlib and numpy generators
seed = 1701
random.seed(seed)
np.random.seed(seed)

In [14]:
from dataset_utils import set_preprocessor, combine_dedupe

# Kind of data source to read; one of 'content.tsv', 'memex_jsons', 'es'
data_source = 'es'

# Upper bound on the number of documents to ingest
max_docs = 1000

# Location of the data source

# For ES:
#data_loc = '/lfs/local/0/jdunnmon/data/chtap/output_all_parsed'
#data_loc = '/dfs/scratch0/jdunnmon/data/memex-data/tsvs/output_all_b'
data_loc = '/dfs/scratch0/jdunnmon/data/memex-data/tsvs/price/test/shard/parsed'

# Optional: merge in a tsv of additional documents, producing a combined tsv without duplicates
#data_loc = combine_dedupe(data_loc, 'output_location.tsv', 'combined.tsv')

# When memex_raw_content is one of the content fields, this regex is also run
# over the raw data, in addition to fetching title and body
term = r'\b[Ll]ocation:|\b[cC]ity:'

# Maximum document length in characters; drop this option to have no limit
max_doc_length = 10000

# Collect the keyword options in one place, then build the preprocessor
preprocessor_options = dict(
    max_docs=max_docs,
    verbose=True,
    clean_docs=False,
    content_fields=['raw_content', 'url'],
    term=term,
    max_doc_length=max_doc_length,
)
doc_preprocessor = set_preprocessor(data_source, data_loc, **preprocessor_options)

Using parallelized loader


In [None]:
from snorkel.parser import CorpusParser
from snorkel.parser.rule_parser import (
    RegexTokenizer,
    RuleBasedParser,
    SpacyTokenizer,
)
from snorkel.parser.spacy_parser import Spacy
from snorkel.parser.simple_tokenizer import SimpleTokenizer
from snorkel.parser.corenlp import StanfordCoreNLPServer

# Choose the parser backend; the alternatives are kept commented out for reference
#parser = Spacy(annotators=['entity','parser'],
#                 lang='en', num_threads=1)
#parser = RuleBasedParser(tokenizer=SpacyTokenizer())
#parser = StanfordCoreNLPServer()
parser = SimpleTokenizer(delim='<|>')

# Wrap the chosen parser and run it over every preprocessed document (timed)
corpus_parser = CorpusParser(parser=parser)
%time corpus_parser.apply(list(doc_preprocessor), verbose=True, parallelism=1)

> /lfs/raiders5/0/jdunnmon/chtap/extractors/src/utils/dataset_utils.py(749)parse_file()
-> num_fields = len(line.split('\t'))
(Pdb) n
> /lfs/raiders5/0/jdunnmon/chtap/extractors/src/utils/dataset_utils.py(750)parse_file()
-> continue
(Pdb) num_fields
19
(Pdb) n
> /lfs/raiders5/0/jdunnmon/chtap/extractors/src/utils/dataset_utils.py(746)parse_file()
-> for ind, line in enumerate(tsv):
(Pdb) n
> /lfs/raiders5/0/jdunnmon/chtap/extractors/src/utils/dataset_utils.py(747)parse_file()
-> if ind == 0:
(Pdb) n
> /lfs/raiders5/0/jdunnmon/chtap/extractors/src/utils/dataset_utils.py(751)parse_file()
-> try:
(Pdb) n
> /lfs/raiders5/0/jdunnmon/chtap/extractors/src/utils/dataset_utils.py(754)parse_file()
-> if num_fields == 21:
(Pdb) n
> /lfs/raiders5/0/jdunnmon/chtap/extractors/src/utils/dataset_utils.py(756)parse_file()
-> elif num_fields == 19:
(Pdb) n
> /lfs/raiders5/0/jdunnmon/chtap/extractors/src/utils/dataset_utils.py(757)parse_file()
-> (memex_id, memex_content_type, crawl_data, memex_crawler,

In [10]:
from snorkel.models import Document, Sentence

# Report how many rows of each kind are currently stored in the database
for count_label, orm_class in (("Documents:", Document), ("Sentences:", Sentence)):
    print(count_label, session.query(orm_class).count())

Documents: 1000
Sentences: 1000


In [11]:
# Getting all sentences parsed by Snorkel (note: this queries Sentence rows,
# not Document rows, and materialises every one of them in a Python list)
sents = session.query(Sentence).all()

In [13]:
# Display the first loaded sentence as a quick sanity check of the parser output
sents[0]

Sentence(Document b'422CDB4BF9939B2DC198C383213B267BE146840A096F1B9510CEA24939310916',0,b'show[field_image][Alexis&lt;br /&gt;&lt;br /&gt;&lt;a href=&quot;http://www.naughtyreviews.com/companions/alexis-bellingham-white-female-2&quot; class=&quot;active&quot;&gt;View Image Details&lt;/a&gt;]"">\\n <img src=""http://li1186-134.members.linode.com/files/imagecache/profile_gallery/images/2116/360-483-7810_bellingham_alexis_2.jpg"" alt=""360-483-7810_bellingham_alexis.jpg"" title=""Alexis"" />\\n </a>\\n </div>\\n <span href=""#"" class=""galRightArrow""></span>\\n </div>\\n </div>\\n <!-- /profile-block -->\\n </div>\\n <!-- /mini-panel-elegance2_individual -->\\n\\n <div class=""closure"">\\n <span class=""navlinks"">\\n\\n </span>\\n </div>\\n </div><!--/content-->\\n </div><!--/panel-pane organization node profile-->\\n </div><!-- /tbl -->\\n </div><!-- /profileInfo -->\\n </div> <!-- /tbl -->\\n</div> <!-- /pageWidth -->\\n\\n\\n <footer>\\n <div class=""pageWidth clearfix"">\\n \\t<di

Create candidate subclass

In [6]:
from snorkel.candidates import CandidateExtractor, Ngrams
from snorkel.matchers import LambdaFunctionMatcher, Union
from dataset_utils import LocationMatcher, city_index, create_candidate_class

# Extraction type -- must be a subfield of the extractions field in your data source!
extraction_type = 'location'

# Candidate class (and its name) for this extraction type
candidate_class, candidate_class_name = create_candidate_class(extraction_type)

# Candidate spans are n-grams of up to three tokens
location_ngrams = Ngrams(n_max=3)

# Matcher backed by the city index built from the cities15000 file
cities = city_index('../utils/data/cities15000.txt')
geo_location_matcher = LambdaFunctionMatcher(func=cities.fast_loc)
# spacy_location_matcher = LocationMatcher(longest_match_only=True)

# Combine the matchers (currently just one) and assemble the candidate extractor
location_matcher = Union(geo_location_matcher)
cand_extractor = CandidateExtractor(
    candidate_class,
    [location_ngrams],
    [location_matcher],
)

In [7]:
# Applying candidate extractor to each split
# Runs the extractor over all loaded sentences, writing candidates to split 0, and times the call.
# NOTE(review): the saved output starts with "Clearing existing...", so this run appears to
# remove previously extracted rows first -- objects queried before a re-run may go stale; confirm.
# NOTE(review): parallelism is hard-coded to 8 here rather than using the `parallelism`
# setting from the configuration cell -- confirm this is intentional.
%time cand_extractor.apply(sents, split=0, parallelism=8)
#print("Number of candidates:", session.query(candidate_class).filter(candidate_class.split == 0).count())

Clearing existing...
Running UDF...
CPU times: user 4min 26s, sys: 36.2 s, total: 5min 2s
Wall time: 1h 26min 22s


Getting candidates for evaluation

In [7]:
# Split that the evaluation candidates are drawn from
eval_split = 0

# Build the query step by step: restrict to the chosen split, order by id, then fetch
eval_query = session.query(candidate_class)
eval_query = eval_query.filter(candidate_class.split == eval_split)
eval_query = eval_query.order_by(candidate_class.id)
eval_cands = eval_query.all()
print(f'Loaded {len(eval_cands)} candidates...')

Loaded 629973 candidates...


Loading trained discriminative model

In [8]:
from dm_utils import LSTM

# Name and directory of the saved weights; both must match the values that
# were given when the model was saved
model_name = 'location_lstm'
save_dir = 'checkpoints'

# Instantiate the model, then restore the saved weights into it
lstm = LSTM(n_threads=parallelism)
lstm.load(save_dir=save_dir, model_name=model_name, verbose=True)

Number of learning epochs:      5
Learning rate:                  1e-05
Use attention:                  False
LSTM hidden dimension:          128
Dropout:                        0.5
Checkpoint Patience:            2
Batch size:                     128
Rebalance:                      0.5
Load pre-trained embedding:     False
Host device:                    gpu
Word embedding size:            300
Word embedding:                 None
[LSTM] Loaded model <location_lstm>, only_param=False


Getting and plotting discriminative model marginals.

In [9]:
import matplotlib.pyplot as plt

# Compute the discriminative model's marginals for every evaluation candidate (timed).
# NOTE(review): the saved output of this cell is an ObjectDeletedError on a Span instance.
# Presumably eval_cands held ORM objects whose rows were removed by a later re-run of the
# candidate extractor; re-running the eval-candidate query just before this cell should
# refresh them -- TODO confirm.
%time eval_marginals = lstm.marginals(eval_cands)
# Plotting eval marginals
#plt.hist(eval_marginals, bins=20)
#plt.show()

ObjectDeletedError: Instance '<Span at 0x7f16831d3748>' has been deleted, or its row is otherwise not present.

Generating extractions from discriminative model marginals

In [None]:
import os

from gm_utils import create_extractions_dict

# Google Maps API key used to get geocodes. It is read from the
# GOOGLE_MAPS_API_KEY environment variable so that the key is never stored in
# the notebook; leave the variable unset (or empty) to just use the extracted
# locations. Any key that was previously committed in this cell should be
# treated as compromised and revoked.
geocode_key = os.environ.get('GOOGLE_MAPS_API_KEY') or None

# Build the per-document extractions from the evaluation candidates and their marginals
doc_extractions = create_extractions_dict(session, eval_cands, eval_marginals, extractions=[extraction_type],
                                          dummy=False, geocode_key=geocode_key)

# Uncomment to inspect the extractions dict and check its format
#doc_extractions

Saving extractions to json file

In [None]:
import json

# Setting filename
out_filename = "loc_ext_test_discriminative_eval_100K.jsonl"

# Saving file to jsonl in extractions format: one JSON object per document,
# each carrying its document key under 'id'.
with open(out_filename, 'w', encoding='utf-8') as outfile:
    for doc_id, extractions in doc_extractions.items():
        # Build a per-record copy instead of writing 'id' into doc_extractions
        # itself, so saving does not silently modify the in-memory results that
        # later cells inspect.
        record = dict(extractions, id=doc_id)
        outfile.write(json.dumps(record) + '\n')

In [47]:
# Spot-check the 'location' extractions of one document. Note that the key is a
# string that itself contains the b'...' wrapper of a bytes document id.
doc_extractions["b'3B0BD327C38D7578DA50E19F8BD68873B5C0E9CD63000BF37AEE120D7901AA06'"]['location']

[['Extracted String: date', ('Date', '12', 'JP', 35301, 42.46806, 140.86806)],
 ['Extracted String: goes', ('Goes', '10', 'NL', 36931, 51.50417, 3.88889)],
 ['Extracted String: name', ('Namur', 'WAL', 'BE', 106284, 50.4669, 4.86746)],
 ['Extracted String: aurora',
  ('Aurora', 'IL', 'US', 200661, 41.76058, -88.32007),
  ('Aurora', 'CO', 'US', 359407, 39.72943, -104.83192),
  ('Aurora', '40', 'PH', 16178, 13.3476, 122.5195),
  ('Aurora', 'OH', 'US', 15838, 41.31755, -81.34539)],
 ['Extracted String: new york',
  ('New York City', 'NY', 'US', 8175133, 40.71427, -74.00597)],
 ['Extracted String: queens',
  ('Queens', 'NY', 'US', 2272771, 40.68149, -73.83652),
  ('Queens Village', 'NY', 'US', 51919, 40.72677, -73.74152)],
 ['Extracted String: much', ('Much', '07', 'DE', 15231, 50.90383, 7.40306)]]

In [None]:
# Number of entries (one per document key) in the extractions dict
len(doc_extractions)