In [None]:
# Switch for the LSTM stage further down: when True the notebook runs the
# RandomSearch cell over LSTM hyperparameters; when False it trains a single
# LSTM with the fixed configuration instead.
do_grid_search = False

In [None]:
# Notebook setup: autoreload mode 2 re-imports edited modules before each cell
# runs, and the inline backend renders matplotlib figures in the cell output.
%load_ext autoreload
%autoreload 2
%matplotlib inline

from snorkel import SnorkelSession

# Database session shared by every later cell (queries, labels, marginals).
session = SnorkelSession()

In [None]:
from ksnorkel import KSUtils


# Three sentence collections from KSUtils.split_sentences, requested with
# fractions 0.8 / 0.1 / 0.1 and a fixed seed.
# NOTE(review): presumably the seed makes the split reproducible - confirm in KSUtils.
train_sent, dev_sent, test_sent = KSUtils.split_sentences(session, split=[0.8, 0.1, 0.1], seed=12345)

In [None]:
from snorkel.models import Candidate, candidate_subclass
from snorkel.candidates import PretaggedCandidateExtractor

# Candidate type with two arguments, named 'chemical' and 'gene'.
ChemicalGeneInteraction = candidate_subclass(
    'ChemicalGeneInteraction',
    ['chemical', 'gene'],
)
# Extractor configured for the 'Chemical' and 'Gene' entity types.
candidate_extractor = PretaggedCandidateExtractor(
    ChemicalGeneInteraction,
    ['Chemical', 'Gene'],
)

# Split ids follow tuple order: 0 = train, 1 = dev, 2 = test.
for k, sents in enumerate((train_sent, dev_sent, test_sent)):
    candidate_extractor.apply(sents, split=k, clear=True)
    split_query = session.query(ChemicalGeneInteraction).filter(
        ChemicalGeneInteraction.split == k
    )
    print("Number of candidates:", split_query.count())

In [None]:
import gzip
import re

# Knowledge base for distant supervision: one frozenset({chemical id, gene id})
# per data row of the CTD chemical-gene interaction dump.
ctd_chem_gene_inter = set()
# Open the archive in text mode so every line arrives as a decoded str. The
# previous approach parsed the repr() of each bytes line, which breaks as soon
# as repr() switches to double quotes (any line containing an apostrophe) or
# escapes a non-ASCII byte. Undecodable bytes are replaced rather than raised
# because only two columns are read from each row.
with gzip.open('data/CTD_chem_gene_ixns.tsv.gz', 'rt', encoding='utf-8', errors='replace') as f:
    for l in f:
        line = l.rstrip('\r\n')
        # skip comments and blank lines
        if not line or line.startswith('#'):
            continue
        components = line.split('\t')
        # a row needs at least five columns to carry both ids; skip anything shorter
        if len(components) < 5:
            continue

        chemical = components[1]
        # add MESH:
        if not chemical.startswith('MESH:'):
            chemical = "MESH:" + chemical
        gene = components[4]
        ctd_chem_gene_inter.add(frozenset((chemical, gene)))


print('{} chemical-gene assocations read from ChG-CTD_chem_gene_ixns'.format(len(ctd_chem_gene_inter)))
#240349
def cand_in_chemical_gene_interactions(c):
    """Return 1 if the candidate's id pair is in the CTD set, otherwise -1.

    The pair is the frozenset of ``c.chemical_cid`` and ``c.gene_cid``, looked
    up in the module-level ``ctd_chem_gene_inter`` set.
    """
    pair = frozenset((c.chemical_cid, c.gene_cid))
    return 1 if pair in ctd_chem_gene_inter else -1

In [None]:
from ksnorkel import KSUtils

# Hand the CTD lookup to KSUtils as the labelling rule for every
# ChemicalGeneInteraction candidate.
# NOTE(review): presumably this stores the labels under the 'gold' annotator that
# a later cell loads with load_gold_labels - confirm in KSUtils.
KSUtils.add_gold_labels_for_candidates(session, ChemicalGeneInteraction, cand_in_chemical_gene_interactions)

In [None]:
import re
from snorkel.lf_helpers import (
    get_tagged_text,
    rule_regex_search_tagged_text,
    rule_regex_search_btw_AB,
    rule_regex_search_btw_BA,
    rule_regex_search_before_A,
    rule_regex_search_before_B
)




# List to parenthetical
def ltp(x):
    """Join the strings in ``x`` with '|' and wrap the result in parentheses.

    Example: ``ltp(['a', 'b'])`` gives ``'(a|b)'``, usable as a regex alternation.
    """
    return '({})'.format('|'.join(x))



# Noun forms that signal inhibition. The functions below read this global each
# time they are called and turn it into a regex alternation with ltp().
inhibition_samples = ["inhibitors", "inhibition", "auto-inhibition"]


def LF_CG_AB_before_inhibition(c):
    """Return 1 if both entity tags are followed by the word 'inhibition', else 0.

    Two case-insensitive patterns are tried on the tagged sentence text, one for
    each order of the ``{{A}}`` and ``{{B}}`` tags. Up to 50 characters may sit
    between the tags and between the second tag and 'inhibition'.
    """
    tagged = get_tagged_text(c)
    patterns = (
        r'{{A}}.{0,50} ' + '{{B}}.{0,50}' + 'inhibition',
        r'{{B}}.{0,50} ' + '{{A}}.{0,50}' + 'inhibition',
    )
    for pattern in patterns:
        if re.search(pattern, tagged, re.I):
            return 1
    return 0
                

def LF_CG_AB_after_inhibition(c):
    """Return 1 only if the before-A and the before-B searches both succeed, else 0.

    Both searches use the same pattern: an ``inhibition_samples`` term followed
    by up to 100 arbitrary characters.
    """
    pattern = ltp(inhibition_samples) + '.{0,100}'
    if not rule_regex_search_before_A(c, pattern, 1):
        return 0
    if not rule_regex_search_before_B(c, pattern, 1):
        return 0
    return 1

def LF_CG_A_inhibition_B_far(c):
    """Wide-window inhibition rule passed to ``rule_regex_search_btw_AB`` with sign -1.

    The pattern allows up to 400 arbitrary characters on either side of an
    ``inhibition_samples`` term.
    """
    gap = '.{0,400}'
    pattern = gap + ltp(inhibition_samples) + gap
    return rule_regex_search_btw_AB(c, pattern, -1)

def LF_CG_B_inhibition_A_far(c):
    """Wide-window inhibition rule passed to ``rule_regex_search_btw_BA`` with sign -1.

    Same pattern as the A-to-B variant: an ``inhibition_samples`` term with up
    to 400 arbitrary characters on either side.
    """
    gap = '.{0,400}'
    pattern = gap + ltp(inhibition_samples) + gap
    return rule_regex_search_btw_BA(c, pattern, -1)

def LF_CG_A_inhibition_B(c):
    """Narrow-window inhibition rule passed to ``rule_regex_search_btw_AB`` with sign 1.

    The pattern allows up to 50 arbitrary characters on either side of an
    ``inhibition_samples`` term.
    """
    gap = '.{0,50}'
    pattern = gap + ltp(inhibition_samples) + gap
    return rule_regex_search_btw_AB(c, pattern, 1)

def LF_CG_B_inhibition_A(c):
    """Narrow-window inhibition rule passed to ``rule_regex_search_btw_BA`` with sign 1.

    The pattern allows up to 50 arbitrary characters on either side of an
    ``inhibition_samples`` term.
    """
    gap = '.{0,50}'
    pattern = gap + ltp(inhibition_samples) + gap
    return rule_regex_search_btw_BA(c, pattern, 1)

def LG_CG_inhibition_before_B_near(c):
    """Near-window rule passed to ``rule_regex_search_before_B`` with sign 1.

    The pattern is an ``inhibition_samples`` term followed by up to 50
    arbitrary characters.
    """
    pattern = ltp(inhibition_samples) + '.{0,50}'
    return rule_regex_search_before_B(c, pattern, 1)
 
def LG_CG_inhibition_before_B_far(c):
    """Far-window rule passed to ``rule_regex_search_before_B`` with sign -1.

    The pattern is an ``inhibition_samples`` term followed by up to 2000
    arbitrary characters.
    """
    pattern = ltp(inhibition_samples) + '.{0,2000}'
    return rule_regex_search_before_B(c, pattern, -1)


# Verb stem for the active-voice rule; read from the global at call time.
inhibits_samples = ["inhibit"]


def LF_CG_A_inhibitis_B(c):
    """Narrow-window rule passed to ``rule_regex_search_btw_AB`` with sign 1.

    The pattern is an ``inhibits_samples`` term with up to 50 arbitrary
    characters on either side.
    """
    gap = '.{0,50}'
    pattern = gap + ltp(inhibits_samples) + gap
    return rule_regex_search_btw_AB(c, pattern, 1)

# Participle for the passive-voice rule; read from the global at call time.
inhibited_samples = ["inhibited"]


def LF_CG_B_inhibited_A(c):
    """Narrow-window rule passed to ``rule_regex_search_btw_BA`` with sign 1.

    The pattern is an ``inhibited_samples`` term with up to 50 arbitrary
    characters on either side.
    """
    gap = '.{0,50}'
    pattern = gap + ltp(inhibited_samples) + gap
    return rule_regex_search_btw_BA(c, pattern, 1)



def LF_CG_not_inhibited(c):
    """Return -1 when the tagged sentence text contains 'not inhib', else 0.

    The previous condition was inverted: it returned -1 for every candidate
    whose sentence did not contain the phrase and abstained when it did. The
    check now fires on the negation itself. The unused parent-sentence lookup
    was removed.
    """
    if 'not inhib' in get_tagged_text(c):
        return -1
    return 0
                

    
# Inhibition labeling functions in the order they are handed to the annotator.
# LF_CG_not_inhibited is defined above but left out of the list (commented out).
LFs_DG = [
    LF_CG_AB_before_inhibition,
    LF_CG_AB_after_inhibition,
    LF_CG_A_inhibition_B_far,
    LF_CG_B_inhibition_A_far,
    LF_CG_A_inhibition_B,
    LF_CG_B_inhibition_A,
    LG_CG_inhibition_before_B_near,
    LG_CG_inhibition_before_B_far,
    LF_CG_A_inhibitis_B,
    LF_CG_B_inhibited_A#,  
 #   LF_CG_not_inhibited
]

In [None]:
import re
from snorkel.lf_helpers import (
    get_tagged_text,
    rule_regex_search_tagged_text,
    rule_regex_search_btw_AB,
    rule_regex_search_btw_BA,
    rule_regex_search_before_A,
    rule_regex_search_before_B,
)
# List to parenthetical
# (redefinition: an equivalent ltp already exists in the previous labeling-function cell)
def ltp(x):
    """Return ``x`` joined by '|' inside one pair of parentheses, e.g. '(a|b)'."""
    alternatives = '|'.join(x)
    return '(' + alternatives + ')'



# Metabolism vocabularies. They have their own names on purpose: the LF_CG_* /
# LG_CG_* inhibition functions from the earlier cell look up the globals
# `inhibition_samples`, `inhibits_samples` and `inhibited_samples` when they are
# called, not when they are defined. Reassigning those names here (as this cell
# used to) silently made every inhibition rule search for metabolism terms.
# NOTE(review): "auto-inhibition" in the metabolism list looks like a copy-paste
# leftover; it is kept so the metabolism patterns are unchanged - confirm.
metabolism_samples = ["metabolism", "metabolite", "auto-inhibition"]
metabolizes_samples = ["metabolises", "metabolizes"]
metabolized_samples = ["metabolised", "metabolized"]


def LF_GC_A_metabolism_B(c):
    """Narrow-window rule passed to ``rule_regex_search_btw_AB`` with sign 1.

    Pattern: a ``metabolism_samples`` term with up to 50 characters either side.
    """
    return rule_regex_search_btw_AB(c, '.{0,50}' + ltp(metabolism_samples) + '.{0,50}', 1)


def LF_GC_B_metabolism_A(c):
    """Narrow-window rule passed to ``rule_regex_search_btw_BA`` with sign 1.

    Pattern: a ``metabolism_samples`` term with up to 50 characters either side.
    """
    return rule_regex_search_btw_BA(c, '.{0,50}' + ltp(metabolism_samples) + '.{0,50}', 1)


def LG_GC_metabolism_before_B_near(c):
    """Near-window rule passed to ``rule_regex_search_before_B`` with sign 1.

    Pattern: a ``metabolism_samples`` term followed by up to 50 characters.
    """
    return rule_regex_search_before_B(c, ltp(metabolism_samples) + '.{0,50}', 1)


def LG_GC_metabolism_before_B_far(c):
    """Far-window rule passed to ``rule_regex_search_before_B`` with sign 1.

    Pattern: a ``metabolism_samples`` term followed by up to 2000 characters.
    """
    # NOTE(review): the inhibition counterpart of this far-window rule uses sign -1;
    # the sign 1 here is kept as written - confirm it is intended.
    return rule_regex_search_before_B(c, ltp(metabolism_samples) + '.{0,2000}', 1)


def LF_GC_A_metabolis_B(c):
    """Narrow-window rule passed to ``rule_regex_search_btw_AB`` with sign 1.

    Pattern: a ``metabolizes_samples`` term with up to 50 characters either side.
    """
    return rule_regex_search_btw_AB(c, '.{0,50}' + ltp(metabolizes_samples) + '.{0,50}', 1)


def LF_GC_A_metabolized_B(c):
    """Narrow-window rule passed to ``rule_regex_search_btw_AB`` with sign 1.

    Pattern: a ``metabolized_samples`` term with up to 50 characters either side.
    """
    return rule_regex_search_btw_AB(c, '.{0,50}' + ltp(metabolized_samples) + '.{0,50}', 1)


def LF_GC_B_metabolized_A(c):
    """Narrow-window rule passed to ``rule_regex_search_btw_BA`` with sign 1.

    Pattern: a ``metabolized_samples`` term with up to 50 characters either side.
    """
    return rule_regex_search_btw_BA(c, '.{0,50}' + ltp(metabolized_samples) + '.{0,50}', 1)



def LF_CG_not_metabol(c):
    """Return -1 when the tagged sentence text contains 'not metabol', else 0.

    The parent-sentence lookup that used to precede the check was never read,
    so it has been dropped.
    """
    if 'not metabol' in get_tagged_text(c):
        return -1
    return 0
                
    
def LF_CG_in_CTD_chem_gene(c):
    """Distant-supervision vote: 1 if the CTD lookup returns 1 for ``c``, else -1.

    NOTE(review): the gold labels are generated from the same
    ``cand_in_chemical_gene_interactions`` lookup, so this labeling function
    reproduces the gold label by construction. Dev/test scores of any model
    that uses it should be read with that circularity in mind.
    """
    return 1 if cand_in_chemical_gene_interactions(c) == 1 else -1


# Metabolism labeling functions plus the CTD lookup, in annotator order.
# LF_CG_not_metabol is deliberately commented out.
# NOTE(review): LF_GC_A_metabolized_B is defined in this cell but does not
# appear in the list - confirm whether the omission is intended.
LFs_GC = [
    LF_GC_A_metabolism_B,
    LF_GC_B_metabolism_A,
    LG_GC_metabolism_before_B_near,
    LG_GC_metabolism_before_B_far,
    LF_GC_A_metabolis_B,
    LF_GC_B_metabolized_A,
#    LF_CG_not_metabol,
    LF_CG_in_CTD_chem_gene
]

In [None]:
# Merge the metabolism/CTD labeling functions into LFs_DG in place. A plain
# extend() appended them again on every re-run of this cell, leaving duplicate
# functions in the list. Here any entry with the same function name is dropped
# first, so the first run gives the same list as extend() and re-runs change nothing.
_gc_lf_names = {lf.__name__ for lf in LFs_GC}
LFs_DG[:] = [lf for lf in LFs_DG if lf.__name__ not in _gc_lf_names] + LFs_GC

In [None]:
from snorkel.annotations import LabelAnnotator
# Build the annotator over the combined labeling-function list and apply it
# single-process, timing the call.
# NOTE(review): no split is passed to apply(); presumably it defaults to the
# training split (0) - confirm against the LabelAnnotator API.
labeler = LabelAnnotator(lfs=LFs_DG)
%time L_train = labeler.apply(lfs=LFs_DG, parallelism=1)
# This bare expression is not the cell's last statement, so it displays nothing.
L_train

# Per-labeling-function statistics; as the last expression this is what the cell shows.
L_train.lf_stats(session)

In [None]:
from snorkel.annotations import load_gold_labels

# Label matrices for dev (split 1) and test (split 2), produced with
# apply_existing on the same annotator, plus the 'gold' labels for both splits.
L_dev = labeler.apply_existing(split=1)
L_test = labeler.apply_existing(split=2)
L_gold_dev = load_gold_labels(session, annotator_name='gold',split=1)
L_gold_test = load_gold_labels(session, annotator_name='gold', split=2)


In [None]:
from snorkel.learning import GenerativeModel
from snorkel.learning import RandomSearch
from snorkel.learning.structure import DependencySelector


# Upper bound on how many selected dependencies are kept.
MAX_DEPS = 5

ds = DependencySelector()
deps = ds.select(L_train, threshold=0.1)
# Keep at most MAX_DEPS of them; which ones survive depends on the iteration
# order of whatever ds.select returned.
deps = set(list(deps)[0:min(len(deps), MAX_DEPS)])

# NOTE(review): `deps` is only printed. Both places that would pass it to the
# model are commented out below, so the generative model is trained without
# dependencies despite this message.
print("Using {} dependencies".format(len(deps)))



# use random search to optimize the generative model
param_grid = {
    'step_size' : [1e-3, 1e-4, 1e-5, 1e-6],
    'decay'     : [0.9, 0.95],
    'epochs'    : [50,100,150],
    'reg_param' : [1e-3],
}

model_class_params = {'lf_propensity' : False }#, 'deps': deps}

# Try n=10 configurations drawn from param_grid, validating on the dev split.
searcher = RandomSearch(GenerativeModel, param_grid, L_train, n=10, model_class_params=model_class_params)
%time gen_model, run_stats = searcher.fit(L_dev, L_gold_dev) #, deps=deps)
# Not the last statement of the cell, so this bare expression displays nothing.
run_stats

# Marginals of the selected generative model on the training label matrix;
# later cells plot them, save them and train the LSTM on them.
train_marginals = gen_model.marginals(L_train)

In [None]:
import matplotlib.pyplot as plt
# Histogram (20 bins) of the generative model's training marginals.
plt.hist(train_marginals, bins=20)
plt.show()

In [None]:
from snorkel.annotations import save_marginals

# Dev-split marginals; not read again in the cells visible here.
dev_marginals = gen_model.marginals(L_dev)
# The four return values are discarded.
# NOTE(review): presumably the call is made for the report it prints - confirm.
_, _, _, _ = gen_model.error_analysis(session, L_dev, L_gold_dev)

# Save the training marginals through the session, timing the call.
%time save_marginals(session, L_train, train_marginals)

In [None]:
# Generative-model error analysis on the test split against its gold labels.
# NOTE(review): the four results are presumably the TP/FP/TN/FN candidate sets,
# as the variable names suggest - confirm. A later cell rebinds the same names
# for the LSTM.
tp, fp, tn, fn = gen_model.error_analysis(session, L_test, L_gold_test)

In [None]:
print("Load all ChemicalGeneInteraction candidates from db...")


def _candidates_for_split(split_id):
    """Return every ChemicalGeneInteraction candidate of one split, ordered by id."""
    query = session.query(ChemicalGeneInteraction).filter(ChemicalGeneInteraction.split == split_id)
    return query.order_by(ChemicalGeneInteraction.id).all()


# Splits are queried in the order 0 (train), 1 (dev), 2 (test).
train_cands, dev_cands, test_cands = [_candidates_for_split(split_id) for split_id in (0, 1, 2)]

# One flat list: train first, then dev, then test.
all_cands = train_cands + dev_cands + test_cands


print("{} {} {}".format(len(train_cands), len(dev_cands), len(test_cands)))
print("Amount of all candidates: {}".format(len(all_cands)))

In [None]:
from snorkel.learning.pytorch import LSTM
from snorkel.annotations import load_gold_labels

# Earlier training configuration, kept commented out for reference:
#train_kwargs = {
#    'lr':            0.01,
#    'embedding_dim': 75,
#    'hidden_dim':    75,
#    'n_epochs':      100,
#    'dropout':       0.25,
#    'seed':          1701
#}

# Fixed-configuration path. `lstm` is bound here only when the grid search is
# disabled; otherwise the random-search cell below binds it.
if not do_grid_search:
    # Best configuration
    train_kwargs = {
        'batch_size':    128,
        'lr':            0.01,
        'embedding_dim': 100,
        'hidden_dim':    100,
        'n_epochs':      100,
        'dropout':       0.0,
        'rebalance':     0.0,
        'seed':          1701
    }

    lstm = LSTM(n_threads=10)
    # Train on the generative model's marginals, with the dev split passed as X_dev/Y_dev.
    lstm.train(train_cands, train_marginals, X_dev=dev_cands, Y_dev=L_gold_dev, **train_kwargs)

    # Score on the held-out test split and print precision, recall and F1.
    p, r, f1 = lstm.score(test_cands, L_gold_test)
    print("Prec: {0:.3f}, Recall: {1:.3f}, F1 Score: {2:.3f}".format(p, r, f1))

In [None]:
from snorkel.learning import RandomSearch
from snorkel.learning.pytorch import LSTM

# Random hyperparameter search for the LSTM; runs only when the grid search is enabled.
if do_grid_search:
    seed = 12345
    num_model_search = 25

    # search over this parameter grid (key order kept as before)
    param_grid = {
        'batch_size':    [64, 128],
        'lr':            [1e-4, 1e-3, 1e-2],
        'rebalance':     [0.0, 0.25, 0.5],
        'embedding_dim': [75, 100, 125],
        'hidden_dim':    [50, 100, 150],
        'dropout':       [0, 0.25, 0.5],
    }

    # Constructor arguments for every LSTM the search builds.
    model_class_params = dict(n_threads=1)

    # Training arguments shared by every configuration.
    model_hyperparams = dict(
        n_epochs=100,
        print_freq=25,
        dev_ckpt_delay=0.5,
        X_dev=dev_cands,
        Y_dev=L_gold_dev,
    )

    searcher = RandomSearch(
        LSTM,
        param_grid,
        train_cands,
        train_marginals,
        n=num_model_search,
        seed=seed,
        model_class_params=model_class_params,
        model_hyperparams=model_hyperparams,
    )

    print("Discriminitive Model Parameter Space (seed={}):".format(seed))
    for i, params in enumerate(searcher.search_space()):
        print("{} {}".format(i, params))

    disc_model, run_stats = searcher.fit(X_valid=dev_cands, Y_valid=L_gold_dev, n_threads=1)
    # Later cells refer to the trained model as `lstm` on both paths.
    lstm = disc_model

In [None]:
# Show the LSTM search statistics when the grid search ran. IPython only
# displays the value of a top-level expression that ends the cell, so the bare
# name inside an `if` block used to display nothing. When the search is
# disabled the expression is None, which produces no output.
run_stats if do_grid_search else None

In [None]:
# LSTM error analysis on the test candidates against the gold test labels.
# This rebinds tp/fp/tn/fn from the generative-model cell.
tp, fp, tn, fn = lstm.error_analysis(session, test_cands, L_gold_test)

In [None]:
# Save LSTM marginals for the candidates of all three splits (train + dev + test).
# NOTE(review): the generative-model training marginals were saved through the same
# session earlier; confirm whether this call replaces them or stores them separately.
lstm.save_marginals(session, all_cands)

In [None]:
from snorkel.models.candidate import Marginal
from snorkel.models import Document, Sentence


print("Storing candidate labels into result file...")
# Counts rows of the base Candidate class, not only ChemicalGeneInteraction.
amount_of_candidates = session.query(Candidate).count()
print("Amount of candidates: {}".format(amount_of_candidates))
# Sentences of all three splits, in train / dev / test order.
all_sents = []
all_sents.extend(train_sent)
all_sents.extend(dev_sent)
all_sents.extend(test_sent)

# Tab-separated column header handed to the exporter.
header_str = '{}\t{}\t{}\t{}\t{}\t{}\t{}'.format('document_id', 'sentence_id', 'cand_id','chemical_cid', 'chemical_span', 'gene_cid', 'gene_span')
# Export through KSUtils, timing the call; the last two arguments name the
# candidate attributes holding the two entity ids.
%time KSUtils.save_binary_relation_as_tsv('results/chemical_gene_interaction.tsv', session, all_cands, all_sents, header_str, 'chemical_cid', 'gene_cid')

In [None]:
# Save the trained LSTM (from whichever path bound `lstm`) under the given model name.
lstm.save("chemical_gene_interaction.lstm")