# Phenotype/SNP relation extraction from tables

In this notebook we demonstrate the module that parses tables in papers and extracts relations between SNPs and phenotypes (for papers that discuss multiple phenotypes).

## Preparations

We start by configuring Jupyter and setting up our environment.

In [1]:
%load_ext autoreload
%autoreload 2

import sys
import cPickle
import numpy as np
import sqlalchemy

# set the paths to snorkel and gwaskb
# (relative paths, resolved against the notebook's working directory)
sys.path.append('../snorkel-tables')
sys.path.append('../src')
sys.path.append('../src/crawler')

# set up the directory with the input papers (read by the XML parser below)
abstract_dir = '../data/db/papers'

# set up matplotlib: inline figures with a default size of 12x4 inches
import matplotlib
%matplotlib inline
matplotlib.rcParams['figure.figsize'] = (12,4)

# create a Snorkel session; the corpus, candidate sets and label matrices
# below are queried from and committed through this session
from snorkel import SnorkelSession
session = SnorkelSession()

### Load corpus

We load our usual corpus of GWAS papers.

In [2]:
from extractor.parser import UnicodeXMLTableDocParser
from snorkel.parser import XMLMultiDocParser

# Parse every XML paper under abstract_dir: the text units are the <table>
# elements, and each document is identified by its PubMed ID taken from
# <article-id pub-id-type="pmid">. keep_xml_tree=True retains the XML tree
# (presumably needed by the table parser below — confirm).
# NOTE(review): UnicodeXMLTableDocParser is imported but not used in this cell.
xml_parser = XMLMultiDocParser(
    path=abstract_dir,
    doc='./*',
    text='.//table',
    id='.//article-id[@pub-id-type="pmid"]/text()',
    keep_xml_tree=True)

In [2]:
from snorkel.parser import CorpusParser, OmniParser
from snorkel.models import Corpus

# parses tables into rows, cols, cells...
table_parser = OmniParser(timeout=1000000)

try:
    corpus = session.query(Corpus).filter(Corpus.name == 'GWAS Table Corpus').one()
except:
    cp = CorpusParser(xml_parser, table_parser)
    %time corpus = cp.parse_corpus(name='GWAS Table Corpus', session=session)
    session.add(corpus)
    session.commit()

print 'Loaded corpus of %d documents' % len(corpus)

Loaded corpus of 589 documents


## Candidate extraction

### Define candidate matchers

#### RSid matcher

In [17]:
from snorkel.matchers import RegexMatchSpan
# Match dbSNP identifiers such as "rs12345", optionally followed by one or more
# allele suffixes of one or two bases such as "/A" or "/AT" (e.g. "rs12345/A/G").
# The pattern is anchored at the end with `$`. Whether it is also anchored at the
# start depends on how RegexMatchSpan applies it — TODO confirm.
rsid_matcher = RegexMatchSpan(rgx=r'rs\d+(/[ATCG]{1,2})*$')

#### Phenotype matchers

The first matcher checks if we are in a column whose header labels it as a phenotype column.

In [18]:
from snorkel.matchers import CellNameDictionaryMatcher

# Match cells in a column whose header contains one of phen_words (case-insensitive).
# NOTE(review): n_max=3 presumably bounds the n-gram length examined — confirm
# against CellNameDictionaryMatcher.
phen_words = ['trait', 'phenotype', 'outcome'] # words that denote phenotypes
phen_matcher = CellNameDictionaryMatcher(axis='col', d=phen_words, n_max=3, ignore_case=True)

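The next matcher looks up known phenotype names from a dictionary built from EFO phenotypes, MeSH diseases, and MeSH chemicals. Below, we apply it to table cells that span an entire row.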

In [19]:
from snorkel.matchers import DictionaryMatch
from db.kb import KnowledgeBase
from extractor.util import make_ngrams

# collect phenotype list: known phenotype names from three knowledge-base sources
kb = KnowledgeBase()
# efo phenotypes (passed through make_ngrams, presumably adding n-gram sub-phrases)
efo_phenotype_list0 = kb.get_phenotype_candidates(source='efo', peek=True) # TODO: remove peeking
efo_phenotype_list = list(make_ngrams(efo_phenotype_list0))
# mesh diseases (also passed through make_ngrams)
mesh_phenotype_list0 = kb.get_phenotype_candidates(source='mesh')
mesh_phenotype_list = list(make_ngrams(mesh_phenotype_list0))
# mesh chemicals (used as returned, without make_ngrams)
chem_phenotype_list = kb.get_phenotype_candidates(source='chemical')

# Match any of these names, ignoring case and using Porter stemming.
phenotype_names = efo_phenotype_list + mesh_phenotype_list + chem_phenotype_list
phen_name_matcher = DictionaryMatch(d=phenotype_names, ignore_case=True, stemmer='porter')

### Relation extraction

In [21]:
from snorkel.candidates import CandidateExtractor
from snorkel.throttlers import AlignmentThrottler, SeparatingSpanThrottler, OrderingThrottler, CombinedThrottler

# create a Snorkel class for the relation we will extract
from snorkel.models import candidate_subclass
RsidPhenRel = candidate_subclass('RsidPhenRel', ['rsid','phen'])

# define our candidate spaces
from snorkel.candidates import TableNgrams, TableCells, SpanningTableCells
unigrams = TableNgrams(n_max=1)  # single tokens; paired with rsid_matcher below
cells = TableCells()  # whole cells; phenotype side of ce1
spanning_cells = SpanningTableCells(axis='row')  # row-spanning cells; phenotype side of ce2

# ce1: we will be looking only at cells aligned along the row axis
row_align_filter = AlignmentThrottler(axis='row', infer=True)

# ce2: and at cells where the phenotype is in a spanning header cell above the rsid cell
sep_span_filter = SeparatingSpanThrottler(align_axis='col') # rsid and phen are not separated by spanning cells
col_order_filter = OrderingThrottler(axis='col', first=1) # phen spanning cell comes first
header_filter = CombinedThrottler([sep_span_filter, col_order_filter]) # combine the two throttlers

# the first extractor pairs rsids with phenotypes in the same row, taken from
# columns whose header indicates they hold phenotypes (phen_matcher)
ce1 = CandidateExtractor(RsidPhenRel, [unigrams, cells], [rsid_matcher, phen_matcher], throttler=row_align_filter)

# the second extractor pairs rsids with dictionary phenotype names (phen_name_matcher)
# found in row-spanning header cells above them (header_filter).
# NOTE(review): stop_on_duplicates=False presumably lets one spanning phenotype
# cell pair with many rsids — confirm against CandidateExtractor.
ce2 = CandidateExtractor(RsidPhenRel, [unigrams, spanning_cells], [rsid_matcher, phen_name_matcher], throttler=header_filter, stop_on_duplicates=False)

# collect the tables that will be searched for candidates
tables = [table for doc in corpus.documents for table in doc.tables]

We are now ready to perform relation extraction.

In [22]:
from snorkel.models import CandidateSet

try:
    rels1 = session.query(CandidateSet).filter(CandidateSet.name == 'RsidPhenRel Set 1').one()
except:
    %time rels1 = ce1.extract(tables, 'RsidPhenRel Set 1', session)
    
print "%s relations extracted, e.g." % len(rels1)
for cand in rels1[:10]:
    print cand

3646 relations extracted, e.g.
RsidPhenRel(Span("rs464766", parent=302832, chars=[0,7], words=[0,0]), Span("Mean BMI", parent=302831, chars=[0,7], words=[0,1]))
RsidPhenRel(Span("rs10504576", parent=302729, chars=[0,9], words=[0,0]), Span("Mean WC", parent=302728, chars=[0,6], words=[0,1]))
RsidPhenRel(Span("rs2296465", parent=302754, chars=[0,8], words=[0,0]), Span("Mean BMI", parent=302753, chars=[0,7], words=[0,1]))
RsidPhenRel(Span("rs315711", parent=302948, chars=[0,7], words=[0,0]), Span("Mean WC", parent=302947, chars=[0,6], words=[0,1]))
RsidPhenRel(Span("rs2221880", parent=302780, chars=[0,8], words=[0,0]), Span("Mean BMI", parent=302779, chars=[0,7], words=[0,1]))
RsidPhenRel(Span("rs10488165", parent=302873, chars=[0,9], words=[0,0]), Span("Mean WC", parent=302872, chars=[0,6], words=[0,1]))
RsidPhenRel(Span("rs4129319", parent=302812, chars=[0,8], words=[0,0]), Span("Mean WC", parent=302811, chars=[0,6], words=[0,1]))
RsidPhenRel(Span("rs1374489", parent=302760, chars=[0,8]

In [23]:
from snorkel.models import CandidateSet

try:
    rels2 = session.query(CandidateSet).filter(CandidateSet.name == 'RsidPhenRel Set 2').one()
except:
    %time rels2 = ce2.extract(tables, 'RsidPhenRel Set 2', session)
    
print "%s relations extracted, e.g." % len(rels2)
for cand in rels2[:10]: 
    print cand

80 relations extracted, e.g.
RsidPhenRel(Span("rs6462411", parent=320685, chars=[0,8], words=[0,0]), Span("Thyroid Stimulating Hormone", parent=320670, chars=[0,26], words=[0,2]))
RsidPhenRel(Span("rs10848704", parent=320706, chars=[0,9], words=[0,0]), Span("Thyroid Stimulating Hormone", parent=320670, chars=[0,26], words=[0,2]))
RsidPhenRel(Span("rs925488", parent=320749, chars=[0,7], words=[0,0]), Span("Thyroid Stimulating Hormone", parent=320670, chars=[0,26], words=[0,2]))
RsidPhenRel(Span("rs7804166", parent=320791, chars=[0,8], words=[0,0]), Span("Thyroid Stimulating Hormone", parent=320670, chars=[0,26], words=[0,2]))
RsidPhenRel(Span("rs6956479", parent=320909, chars=[0,8], words=[0,0]), Span("Thyroid Stimulating Hormone", parent=320670, chars=[0,26], words=[0,2]))
RsidPhenRel(Span("rs976731", parent=320220, chars=[0,7], words=[0,0]), Span("Waist Circumference", parent=320185, chars=[0,18], words=[0,1]))
RsidPhenRel(Span("rs2225614", parent=320241, chars=[0,8], words=[0,0]), Sp

Finally, we merge the two sets of candidates into a single set.

In [24]:
from snorkel.models import CandidateSet

try:
    rels = session.query(CandidateSet).filter(CandidateSet.name == 'RsidPhenRel Canidates').one()
except:
    rels = CandidateSet(name='RsidPhenRel Canidates')
    for c in rels1: rels.append(c)
    for c in rels2: rels.append(c)

    session.add(rels)
    session.commit()

print '%d candidates in total' % len(rels)

3726 candidates in total


## Learning the correctness of relations

Next, we will train machine learning models to identify which of these relation candidates are actually correct.

### Generating a labeled set of examples

We first split the data into an (unlabeled) training set, on which we will train a model using unsupervised risk estimation, and a dev/test set.

In [7]:
try:
    train_c = session.query(CandidateSet).filter(CandidateSet.name == 'RsidPhenRel Training Candidates').one()
    devtest_c = session.query(CandidateSet).filter(CandidateSet.name == 'RsidPhenRel Dev/Test Candidates').one()
except:
    # delete any previous sets with that name
    session.query(CandidateSet).filter(CandidateSet.name == 'RsidPhenRel Training Candidates').delete()
    session.query(CandidateSet).filter(CandidateSet.name == 'RsidPhenRel Dev/Test Candidates').delete()

    frac_test = 0.5

    # initialize the new sets
    train_c = CandidateSet(name='RsidPhenRel Training Candidates')
    devtest_c = CandidateSet(name='RsidPhenRel Dev/Test Candidates')

    # choose a random subset for the labeled set
    n_test = len(rels) * frac_test
    test_idx = set(np.random.choice(len(rels), size=(n_test,), replace=False))

    # add to the sets
    for i, c in enumerate(rels):
        if i in test_idx:
            devtest_c.append(c)
        else:
            train_c.append(c)

    # save the results
    session.add(train_c)
    session.add(devtest_c)
    session.commit()

print 'Initialized %d training and %d dev/testing candidates' % (len(train_c), len(devtest_c))

Initialized 1863 training and 1863 dev/testing candidates


### Labeling functions

Following the data programming approach, we define a set of labeling functions. We will learn their accuracies via unsupervised learning and use them to classify candidates.

In [8]:
from snorkel.lf_helpers import *
s=None  # NOTE(review): unused
# Scratch cell: inspect the column-aligned cells of table 3 in paper 17903303
# (it mirrors the lookup done in LF_bad_phen_mentions below).
# NOTE(review): raises IndexError if that paper is not in the corpus.
doc = [d for d in corpus.documents if d.name == '17903303'][0]
table = doc.tables[3]
for cell in table.cells:
    # cells aligned with `cell` along the column axis (alignment inferred)
    top_cells = get_aligned_cells(cell, 'col', infer=True)
    # NOTE(review): the comprehension reuses the name `cell`. Under Python 2 this
    # rebinds the loop variable, which is harmless here because the outer loop
    # draws from its own iterator. Only the last iteration's results are kept.
    top_phrases = [phrase for cell in top_cells for phrase in cell.phrases]
# rels[0][1].parent.table.cells[0].phrases
# corpus.documents[0].phrases
# corpus.documents[0].phrases



 BeautifulSoup([your markup])

to this:

 BeautifulSoup([your markup], "lxml")

  markup_type=markup_type))


In [9]:
from snorkel.lf_helpers import *

# lowercase strings suggesting the phenotype span is really an rsid header; used by LF_bad_word
bad_words = ['rs number', 'rs id', 'rsid']

# negative LFs
def LF_number(m):
    """Vote -1 when the phenotype span is mostly digits, otherwise abstain (0).

    A span counts as numeric when more than 60% of its characters are digits,
    or, for spans longer than 5 characters, when more than 40% are.

    Args:
        m: a RsidPhenRel candidate; m[1] is the phenotype span.

    Returns:
        -1 or 0. An empty span abstains instead of dividing by zero.
    """
    txt = m[1].get_span()
    if not txt:
        return 0
    frac_num = sum(1 for ch in txt if ch.isdigit()) / float(len(txt))
    # explicit parentheses document the existing precedence (`and` binds tighter than `or`)
    return -1 if (len(txt) > 5 and frac_num > 0.4) or frac_num > 0.6 else 0

def LF_bad_phen_mentions(m):
    if cell_spans(m[1].parent.cell, m[1].parent.table, 'row'): return 0
    #     if m[1].context.cell.spans('row'): return 0
    top_cells = get_aligned_cells(m[1].parent.cell, 'col', infer=True)
    top_cells = [cell for cell in top_cells]
#     top_cells = m.span1.context.cell.aligned_cells(axis='col', induced=True)
    try:
        top_phrases = [phrase for cell in top_cells for phrase in cell.phrases]
    except:
        for cell in top_cells:
            print cell, cell.phrases
    if not top_phrases: return 0
    matching_phrases = []
    for phrase in top_phrases:
        if any (phen_matcher._f_ngram(word) for word in phrase.text.split(' ')):
            matching_phrases.append(phrase)
    small_matching_phrases = [phrase for phrase in matching_phrases if len(phrase.text) <= 25]
    return -1 if not small_matching_phrases else 0

def LF_bad_word(m):
    """Vote -1 when the phenotype span contains one of bad_words, else abstain (0).

    Matching is case-insensitive, so e.g. "RS number" or "rsID" are also
    caught. Before, the lowercase bad_words were compared against the raw span.

    NOTE(review): this LF is not listed in LF_tables_neg, so it does not
    currently contribute to the label matrices.
    """
    txt = m[1].get_span().lower()
    return -1 if any(word.lower() in txt for word in bad_words) else 0

# negative LFs: each returns -1 (likely wrong) or 0 (abstain)
LF_tables_neg = [LF_number, LF_bad_phen_mentions]

# positive LFs
def LF_no_neg(m):
    """Vote +1 unless one of the negative LFs fires on the candidate.

    Runs the LFs in LF_tables_neg in order and abstains (0) at the first one
    that returns a non-zero value; returns +1 when none of them does.
    """
    for neg_lf in LF_tables_neg:
        if neg_lf(m):
            return 0
    return +1

LF_tables_pos = [LF_no_neg]

# every LF, negative ones first
LFs = LF_tables_neg + LF_tables_pos

We apply the labeling functions to the training set to generate its label matrix.

In [10]:
from snorkel.annotations import LabelManager
label_manager = LabelManager()

# Load the training label matrix if it is already stored under this name;
# otherwise apply every function in LFs to each training candidate and store it.
# A stored matrix is reused as-is, so the numeric suffix in the name is
# presumably bumped whenever the LFs change.
try:
    %time L_train = label_manager.load(session, train_c, 'RsidPhenRel LF Labels6')
except sqlalchemy.orm.exc.NoResultFound:
    %time L_train = label_manager.create(session, train_c, 'RsidPhenRel LF Labels6', f=LFs)

CPU times: user 96.4 ms, sys: 9.73 ms, total: 106 ms
Wall time: 133 ms


Let's also look at some basic statistics.

In [11]:
# per-LF coverage, overlaps and conflicts on the training candidates
L_train.lf_stats()

Unnamed: 0,conflicts,coverage,j,overlaps
LF_number,0,0.129361,0,0.065486
LF_bad_phen_mentions,0,0.13956,1,0.065486
LF_no_neg,0,0.796565,2,0.0


### Training a machine learning model

Next, we train a generative model, just like in the phenotype extraction notebook.

In [12]:
from snorkel.learning import NaiveBayes

# Fit a generative model on the LF outputs alone (only L_train is passed; no
# ground-truth labels), using 10000 iterations at learning rate 1e-2.
gen_model = NaiveBayes()
gen_model.train(L_train, n_iter=10000, rate=1e-2)

because the backend has already been chosen;
matplotlib.use() must be called *before* pylab, matplotlib.pyplot,
or matplotlib.backends is imported for the first time.



Training marginals (!= 0.5):	1863
Features:			3
Begin training for rate=0.01, mu=1e-06
	Learning epoch = 0	Gradient mag. = 0.069766
	Learning epoch = 250	Gradient mag. = 0.079974
	Learning epoch = 500	Gradient mag. = 0.086107
	Learning epoch = 750	Gradient mag. = 0.091676
	Learning epoch = 1000	Gradient mag. = 0.096584
	Learning epoch = 1250	Gradient mag. = 0.100796
	Learning epoch = 1500	Gradient mag. = 0.104331
	Learning epoch = 1750	Gradient mag. = 0.107247
	Learning epoch = 2000	Gradient mag. = 0.109621
	Learning epoch = 2250	Gradient mag. = 0.111536
	Learning epoch = 2500	Gradient mag. = 0.113072
	Learning epoch = 2750	Gradient mag. = 0.114301
	Learning epoch = 3000	Gradient mag. = 0.115283
	Learning epoch = 3250	Gradient mag. = 0.116066
	Learning epoch = 3500	Gradient mag. = 0.116693
	Learning epoch = 3750	Gradient mag. = 0.117194
	Learning epoch = 4000	Gradient mag. = 0.117596
	Learning epoch = 4250	Gradient mag. = 0.117918
	Learning epoch = 4500	Gradient mag. = 0.118178
	Learni

In [13]:
# learned weights, presumably one per LF column of L_train (order of LFs)
gen_model.w

array([ 9.08323015,  8.86630481,  0.98503744])

## Classify all the candidates

In [26]:
from snorkel.annotations import LabelManager
label_manager = LabelManager()

# Apply the LFs to every candidate (training and dev/test) so the generative
# model can score them all. This always recomputes the labels (about 18 minutes
# in the recorded run) and never loads a stored set.
# NOTE(review): the cleanup below deletes 'RsidPhenRel LF All Labels', but this
# cell creates 'RsidPhenRel LF All Lab'. Confirm which name should be cleaned up
# before re-running.
# delete existing labels
# session.rollback()
# session.query(AnnotationKeySet).filter(AnnotationKeySet.name == 'RsidPhenRel LF All Labels').delete()
%time L_all = label_manager.create(session, rels, 'RsidPhenRel LF All Lab', f=LFs)

Generating annotations for 3726 candidates...
Loading sparse Label matrix...
CPU times: user 14min 4s, sys: 4min 3s, total: 18min 7s
Wall time: 18min 13s


Save the results: we score every candidate with the trained model and save the ones predicted to be correct.

In [27]:
preds = gen_model.odds(L_all)
good_rels = [(c[0].parent.document.name, c[0].get_span(), c[1].get_span()) for (c, p) in zip(rels, preds) if p > 0]
print len(good_rels), 'relations extracted, e.g.:'
print good_rels[:10]

# store relations to annotate
with open('results/nb-output/rels.acronyms.extracted.tsv', 'w') as f:
    for doc_id, str1, str2 in good_rels:
        try:
            out = u'{}\t{}\t{}\n'.format(doc_id, unicode(str1), str2)
            f.write(out.encode("UTF-8"))
        except:
            print 'Error saving:', str1, str2

2958 relations extracted, e.g.:
[(u'17903300', u'rs464766', u'Mean BMI'), (u'17903300', u'rs10504576', u'Mean WC'), (u'17903300', u'rs2296465', u'Mean BMI'), (u'17903300', u'rs315711', u'Mean WC'), (u'17903300', u'rs2221880', u'Mean BMI'), (u'17903300', u'rs10488165', u'Mean WC'), (u'17903300', u'rs4129319', u'Mean WC'), (u'17903300', u'rs1374489', u'Mean BMI'), (u'17903300', u'rs7941883', u'Mean WC'), (u'17903300', u'rs7202384', u'Mean BMI')]


## Acronym resolution

A large fraction of the phenotypes we extracted consist of acronyms. This section deals with translating these acronyms.

This step requires a list of acronyms and the phenotype to which they correspond. You need to run the acronym extraction notebook to produce this list.

In [28]:
from extractor.dictionary import Dictionary, unravel

# Load the acronym definitions produced by the acronym extraction notebook.
# D is later passed to unravel(pmid, phen, D) to expand acronyms in phenotypes.
D = Dictionary()
D.load('results/nb-output/acronyms.extracted.all.tsv')
print len(D), 'definitions loaded'

3221 definitions loaded


The above dictionary object performs acronym translation. We apply it to our extracted relations to produce a new, translated list.

### Filter relations by p-value

A number of the above relations will involve SNPs that are not statistically significant. Here, we would like to prefilter them based on p-values extracted in the p-value extraction notebook.

If you haven't yet generated these p-values, you may skip this step (just comment out the filtering in the final cell), as we will perform filtering in the final notebook anyway.

In [29]:
# Read the p-values from the p-value extraction notebook. Each line of
# pval-rsid.tsv has six tab-separated fields:
#   pmid, rsid, table_id, row_id, col_id, log_pval
# Two lookups are built, each keeping the smallest log p-value seen:
#   pval_rsid_dict[pmid][(rsid, table_id, row_id)]  -- per table row
#   pval_dict[pmid][rsid]                           -- per paper
# col_id is parsed but not part of either key.
# NOTE(review): log_pval is presumably log10(p); the final cell tests 10**log_pval.
# NOTE(review): a blank or malformed line raises ValueError at the unpacking.
pval_rsid_dict = dict()
pval_dict = dict() # combine all of the pvalues for a SNPs in the same document into one set
with open('results/nb-output/pval-rsid.tsv') as f:
    for line in f:
        pmid, rsid, table_id, row_id, col_id, log_pval = line.strip().split('\t')
        log_pval, table_id, row_id, col_id = float(log_pval), int(table_id), int(row_id), int(col_id)
        
        if pmid not in pval_rsid_dict: pval_rsid_dict[pmid] = dict()
        key = (rsid, table_id, row_id)
        if key not in pval_rsid_dict[pmid]: pval_rsid_dict[pmid][key] = set()
        pval_rsid_dict[pmid][key].add(log_pval)
                
        if pmid not in pval_dict: pval_dict[pmid] = dict()
        if rsid not in pval_dict[pmid]: pval_dict[pmid][rsid] = set()
        pval_dict[pmid][rsid].add(log_pval)

# collapse each set of log p-values to its minimum (the most significant value)
pval_dict0 = {pmid : {rsid : min(pval_dict[pmid][rsid]) for rsid in pval_dict[pmid]} for pmid in pval_dict}
pval_rsid_dict0 = {pmid : {key : min(pval_rsid_dict[pmid][key]) for key in pval_rsid_dict[pmid]} for pmid in pval_rsid_dict}
pval_dict = pval_dict0
pval_rsid_dict = pval_rsid_dict0

### Save all relations that have sufficiently small p-values

Finally, we resolve acronyms and save the relations with their resolved phenotype names. We also store their coordinates (table, row, col) in the table.

In [40]:
# preds = learner.predict_wmv(candidates)
# keep the candidates the generative model scored with positive odds
predicted_candidates =  [c for (c, p) in zip(rels, preds) if p > 0]

import re
import unicodedata
def _normalize_str(s):
    try:
        s = s.encode('utf-8')
        return s
    except UnicodeEncodeError: 
        pass
    try:
        s = s.decode('utf-8')
        return s
    except UnicodeDecodeError: 
        pass    
    raise Exception()
    
def clean_rsid(rsid):
    """Drop allele suffixes from an rsid mention, e.g. "rs123/A/G" -> "rs123".

    Everything from a "/" that is followed by at least one character, up to
    the end of the line, is removed; a trailing bare "/" is left in place.
    """
    allele_suffix = re.compile(r'/.+')
    return allele_suffix.sub('', rsid)

# Write one tab-separated line per predicted relation that passes the p-value filter:
#   pmid, rsid (allele suffix stripped), phenotype (acronyms expanded),
#   log p-value, the literal "table", table_id, row, col
with open('results/nb-output/phen-rsid.table.rel.all.tsv', 'w') as f:
    for c in predicted_candidates:
        pmid = c[0].parent.document.name
        rsid = c[0].get_span()
        phen = c[1].get_span()        
        # coordinates of the rsid cell: table index, row and column positions
        table_id = c[0].parent.table.position
        row_num = c[0].parent.cell.row.position
        col_num = c[0].parent.cell.col.position # of the rsid
        
        # debugging aid: report rsid cells without a row position; processing
        # continues, and the row is written as "None" if the relation is kept
        if row_num is None:
            print c[0].parent.cell

        # expand acronyms in the phenotype via unravel and the dictionary D
        phen = (unravel(pmid, phen, D))
        if isinstance(phen, unicode):
            phen = phen.encode('utf-8')
        
        # Look up the p-value of this (rsid, table, row). The raw rsid span is
        # used here, before clean_rsid. Relations without a p-value get the
        # sentinel -1000, which passes the filter below, so they are kept; the
        # commented-out `continue` would drop them instead.
        # NOTE(review): assumes pval-rsid.tsv stores rsids in the same raw form
        # as the span text — confirm.
        try:
            log_pval = pval_rsid_dict[pmid][(rsid, table_id, row_num)]
        except KeyError:
            log_pval = -1000
#             continue
        # keep only relations with p <= 1e-5
        if 10**log_pval > 1e-5: continue

        out_str = '{pmid}\t{rsid}\t{phen}\t{pval}\ttable\t{table_id}\t{row}\t{col}\n'.format(
                    pmid=pmid, rsid=clean_rsid(rsid), phen=phen, pval=log_pval, table_id=table_id, row=row_num, col=col_num)
        f.write(out_str)