In [None]:
# Switch for the LSTM cell near the end of the notebook: True runs the random
# hyper-parameter search, False trains once with the fixed "best config" values.
do_grid_search = False

In [None]:
import os

# Database connection string for snorkel; it has to be in the environment
# before snorkel is imported in the next cell.
# setdefault keeps this local-development value as the fallback but no longer
# overwrites a connection string that was already exported in the environment.
# NOTE(review): the fallback still embeds a password in the notebook - prefer
# exporting SNORKELDB outside the notebook and removing the literal.
os.environ.setdefault('SNORKELDB', "postgresql://localhost:5432/snorkel?user=snorkel&password=snorkel12345")

In [None]:
# Reload edited modules automatically and render matplotlib figures inline.
%load_ext autoreload
%autoreload 2
%matplotlib inline

from snorkel import SnorkelSession

# Database session passed to every query and snorkel/ksnorkel call below.
session = SnorkelSession()

In [None]:
from ksnorkel import KSUtils

# Split the sentences 80% / 10% / 10% into train, dev and test using a fixed seed.
# NOTE(review): KSUtils.split_sentences is project code not shown here; presumably
# the same seed always yields the same split - confirm.
train_sent, dev_sent, test_sent = KSUtils.split_sentences(session, split=[0.8, 0.1, 0.1], seed=12345)

In [None]:
from snorkel.models import Candidate, candidate_subclass
from snorkel.candidates import PretaggedCandidateExtractor


# Candidate type with two arguments, 'gene' and 'chemical' (in that order).
GeneChemicalMetabolism = candidate_subclass('GeneChemicalMetabolism', ['gene', 'chemical'])
# Extractor configured with the entity types 'Gene' and 'Chemical'; presumably it pairs
# mentions that are already tagged in the sentences - confirm against snorkel.
candidate_gene_chemical_metabolism_extractor = PretaggedCandidateExtractor(GeneChemicalMetabolism, ['Gene', 'Chemical'])

# The loop index doubles as the split id: 0 = train, 1 = dev, 2 = test.
for k, sents in enumerate([train_sent, dev_sent, test_sent]):
    # NOTE(review): clear=True presumably removes candidates previously stored for this split - confirm.
    candidate_gene_chemical_metabolism_extractor.apply(sents, split=k, clear=True)
    print("Number of candidates:", session.query(GeneChemicalMetabolism).filter(GeneChemicalMetabolism.split == k).count())

In [None]:
import gzip
import re

def _read_ctd_metabolism_pairs(path):
    """Read a gzipped CTD chemical-gene interaction TSV and keep metabolism-like rows.

    A row is kept when all of the following hold:
      * the interaction text (column 8, lower-cased) contains 'metabol' or 'increas',
      * the interaction actions (column 9) contain 'increases' or 'affects',
      * in the interaction text the chemical name (column 0) is followed, without
        any square bracket in between, by 'metabo' or 'increase'.

    Args:
        path: Path of the gzip-compressed, tab-separated file.

    Returns:
        A set of frozenset((chemical_id, gene_id)) where chemical_id is column 1
        prefixed with 'MESH:' when missing and gene_id is column 4.
    """
    pairs = set()
    # Text mode yields real str lines. The previous code parsed str(bytes) reprs,
    # which mangles names containing quotes, backslashes or non-ASCII characters.
    with gzip.open(path, 'rt', encoding='utf-8') as f:
        for raw_line in f:
            line = raw_line.rstrip('\r\n')
            # skip blank lines and comments
            if not line or line.startswith('#'):
                continue
            components = line.split('\t')
            # Columns 0..9 are read below; skip truncated rows instead of raising IndexError.
            if len(components) < 10:
                continue

            # add MESH:
            chemical = components[1]
            if not chemical.startswith('MESH:'):
                chemical = "MESH:" + chemical
            gene = components[4]

            desc = components[8].lower()
            if 'metabol' not in desc and 'increas' not in desc:
                continue

            desc_type = components[9]
            if 'increases' not in desc_type and 'affects' not in desc_type:
                continue

            # re.escape neutralises every regex metacharacter in the chemical name
            # (the old code only escaped parentheses, so names containing '+', '[',
            # '*' ... either failed to compile or matched the wrong text).
            che_name = re.escape(components[0].lower())
            if (re.search(che_name + r'[^\]\[]+metabo', desc) is None
                    and re.search(che_name + r'[^\]\[]+increase', desc) is None):
                continue

            pairs.add(frozenset((chemical, gene)))
    return pairs


# Knowledge base of (chemical, gene) pairs used by cand_in_chemical_gene_metabolised.
ctd_chem_gene_meta = _read_ctd_metabolism_pairs('data/CTD_chem_gene_ixns.tsv.gz')

print('{} chemical-gene metabolisation read from ChG-CTD_chem_gene_ixns'.format(len(ctd_chem_gene_meta)))

def cand_in_chemical_gene_metabolised(c):
    """Return 1 if the candidate's (chemical_cid, gene_cid) pair is in ctd_chem_gene_meta, else -1."""
    pair = frozenset((c.chemical_cid, c.gene_cid))
    return 1 if pair in ctd_chem_gene_meta else -1

In [None]:
import re
from snorkel.lf_helpers import get_tagged_text



# List to parenthetical
def ltp(x):
    """Join the strings in x with '|' and wrap the result in parentheses (a regex alternation group)."""
    alternatives = '|'.join(x)
    return '(' + alternatives + ')'

# Keywords that gold_label_function searches for (case-insensitively) in the tagged sentence.
# The entries are the same words as metabolisation_samples followed by metabol_samples,
# which are defined in the labeling-function cell further down.
# NOTE(review): 'metabol' is a substring of the longer 'metabol...' entries, so those
# longer entries cannot change the outcome of an unanchored re.search.
metabol_keywords = ['metabolising', 'metabolism', 'metabolisation', 'metabolization', 
                          'increased activity', 'metabolite', 'stimulation',
                          'metabol', 'increase', 'stimulate', 'stimulating', 'induce',
                          'activate', 'potentiate']

def gold_label_function(c):
    """Label a candidate 1 only if its pair is in the CTD set and the sentence has a metabolism keyword.

    Returns 1 when cand_in_chemical_gene_metabolised(c) is 1 and any word of
    metabol_keywords occurs (case-insensitively) in the tagged text; -1 otherwise.
    """
    # Not in the knowledge base: negative regardless of the sentence text.
    if cand_in_chemical_gene_metabolised(c) != 1:
        return -1
    keyword_hit = re.search(r'' + ltp(metabol_keywords), get_tagged_text(c), re.I)
    return 1 if keyword_hit else -1


In [None]:
from ksnorkel import KSUtils

# Hand gold_label_function to the project helper for the GeneChemicalMetabolism candidates.
# NOTE(review): presumably this stores the labels under the 'gold' annotator that
# load_gold_labels reads later - confirm in ksnorkel.
KSUtils.add_gold_labels_for_candidates(session, GeneChemicalMetabolism, gold_label_function)

In [None]:
import re
from snorkel.lf_helpers import (
    get_tagged_text,
    rule_regex_search_tagged_text,
    rule_regex_search_btw_AB,
    rule_regex_search_btw_BA,
    rule_regex_search_before_A,
    rule_regex_search_before_B
)




# List to parenthetical
def ltp(x):
    """Turn a list of strings into a parenthesised regex alternation, e.g. ['a', 'b'] -> '(a|b)'.

    Same behaviour as the ltp defined in the gold-label cell; this definition replaces it.
    """
    return '({})'.format('|'.join(x))


# Noun-like metabolism/activation phrases shared by the "metabolisation" labeling functions below.
metabolisation_samples = ['metabolising', 'metabolism', 'metabolisation', 'metabolization', 
                          'increased activity', 'metabolite', 'stimulation']

def LF_GC_AB_metabolisation_after_AB(c):
    """Vote 1 when a metabolisation phrase follows both entity placeholders.

    Matches {{A}} ... {{B}} ... phrase or {{B}} ... {{A}} ... phrase, with at most
    100 characters in each gap (case-insensitive). Returns 0 otherwise.
    """
    phrase = ltp(metabolisation_samples)
    for first, second in (('{{A}}', '{{B}}'), ('{{B}}', '{{A}}')):
        pattern = first + '.{0,100}' + second + '.{0,100}' + phrase
        if re.search(pattern, get_tagged_text(c), re.I):
            return 1
    return 0

def LF_GC_metabolisation_before_AB(c):
    """Vote 1 when a metabolisation phrase occurs up to 100 characters before the first entity.

    snorkel's rule_regex_search_before_A only matches sentences where {{A}} comes
    before {{B}}, and rule_regex_search_before_B only those where {{B}} comes before
    {{A}} (see snorkel.lf_helpers). The original `and` of the two could therefore
    never be true for a candidate with one A and one B; `or` covers both orders.

    Returns 1 on a match, 0 otherwise.
    """
    pattern = ltp(metabolisation_samples) + '.{0,100}'
    if rule_regex_search_before_A(c, pattern, 1) or rule_regex_search_before_B(c, pattern, 1):
        return 1
    return 0



def LF_CG_A_metabolisation_B_far(c):
    """Negative vote (-1) for a metabolisation phrase between A and B that is 100-500 characters from each.

    Delegates to rule_regex_search_btw_AB with sign -1.
    """
    far_gap = '.{100,500}'
    pattern = far_gap + ltp(metabolisation_samples) + far_gap
    return rule_regex_search_btw_AB(c, pattern, -1)

def LF_GC_B_metabolisation_A_far(c):
    """Negative vote (-1) for a metabolisation phrase between B and A that is 100-500 characters from each.

    Delegates to rule_regex_search_btw_BA with sign -1.
    """
    far_gap = '.{100,500}'
    pattern = ''.join((far_gap, ltp(metabolisation_samples), far_gap))
    return rule_regex_search_btw_BA(c, pattern, -1)

def LF_GC_A_metabolisation_B(c):
    """Positive vote (1) for a metabolisation phrase between A and B, at most 100 characters from each.

    Delegates to rule_regex_search_btw_AB with sign 1.
    """
    near_gap = '.{0,100}'
    pattern = near_gap + ltp(metabolisation_samples) + near_gap
    return rule_regex_search_btw_AB(c, pattern, 1)

def LF_GC_B_metabolisation_A(c):
    """Positive vote (1) for a metabolisation phrase between B and A, at most 100 characters from each.

    Delegates to rule_regex_search_btw_BA with sign 1.
    """
    near_gap = '.{0,100}'
    pattern = ''.join((near_gap, ltp(metabolisation_samples), near_gap))
    return rule_regex_search_btw_BA(c, pattern, 1)

def LG_GC_metabolisation_before_B_near(c):
    """Positive vote (1) for a metabolisation phrase at most 100 characters before B.

    Delegates to rule_regex_search_before_B with sign 1.
    """
    pattern = ltp(metabolisation_samples) + '.{0,100}'
    return rule_regex_search_before_B(c, pattern, 1)
 
def LG_GC_metabolisation_before_B_far(c):
    """Negative vote (-1) for a metabolisation phrase 100-2000 characters before B.

    Delegates to rule_regex_search_before_B with sign -1.
    """
    pattern = ltp(metabolisation_samples) + '.{100,2000}'
    return rule_regex_search_before_B(c, pattern, -1)




# Verb stems for metabolism/activation; matched as plain substrings by the labeling functions below.
metabol_samples = ['metabol', 'increase', 'stimulate', 'stimulating', 'induce', 'activate', 'potentiate']


def LF_GC_A_metabol_B_in_sent(c):
    """Positive vote (1) for a metabolism verb stem between A and B, at most 100 characters from each.

    Delegates to rule_regex_search_btw_AB with sign 1.
    """
    near_gap = '.{0,100}'
    pattern = near_gap + ltp(metabol_samples) + near_gap
    return rule_regex_search_btw_AB(c, pattern, 1)

def LF_GC_B_metabol_A_in_sent(c):
    """Positive vote (1) for a metabolism verb stem between B and A, at most 100 characters from each.

    Fix: the original called rule_regex_search_btw_AB, which made this function an
    exact copy of LF_GC_A_metabol_B_in_sent; the B...A direction promised by the
    name needs rule_regex_search_btw_BA.
    NOTE(review): with the direction corrected this uses the same pattern and helper
    as LF_GC_B_metabolised_A - consider dropping one of the two from LFs_GC.
    """
    near_gap = '.{0,100}'
    pattern = near_gap + ltp(metabol_samples) + near_gap
    return rule_regex_search_btw_BA(c, pattern, 1)


def LF_GC_B_metabolised_A(c):
    """Positive vote (1) for a metabolism verb stem between B and A, at most 100 characters from each.

    Delegates to rule_regex_search_btw_BA with sign 1.
    """
    near_gap = '.{0,100}'
    pattern = ''.join((near_gap, ltp(metabol_samples), near_gap))
    return rule_regex_search_btw_BA(c, pattern, 1)



# Negation words used by the "not metabol" labeling functions.
# NOTE(review): the patterns built from this list have no word boundaries, so 'no'
# also matches inside longer words (e.g. "know", "normal") - confirm that is acceptable.
not_samples = ['no', 'not']

def LF_GC_A_not_metabol_B(c):
    """Negative vote (-1) for a negation word followed by a metabolism verb stem between A and B.

    Every gap is at most 100 characters. Delegates to rule_regex_search_btw_AB with sign -1.
    """
    near_gap = '.{0,100}'
    parts = (near_gap, ltp(not_samples), near_gap, ltp(metabol_samples), near_gap)
    return rule_regex_search_btw_AB(c, ''.join(parts), -1)


def LF_GC_B_not_metabolised_A(c):
    """Negative vote (-1) for a negation word followed by a metabolism verb stem between B and A.

    Every gap is at most 100 characters. Delegates to rule_regex_search_btw_BA with sign -1.
    """
    near_gap = '.{0,100}'
    negation = ltp(not_samples)
    verb = ltp(metabol_samples)
    pattern = near_gap + negation + near_gap + verb + near_gap
    return rule_regex_search_btw_BA(c, pattern, -1)


# All metabolism keywords: the verb stems first, then the noun-like phrases.
meta_all_words = list(metabol_samples) + list(metabolisation_samples)
     
def LF_GC_no_metabol(c):
    """Vote -1 when the tagged sentence contains none of meta_all_words (case-insensitive), else 0."""
    if re.search(r'' + ltp(meta_all_words), get_tagged_text(c), re.I):
        return 0
    return -1


def LF_GC_metabol_in_sent(c):
    """Vote 1 when the tagged sentence contains any of meta_all_words (case-insensitive), else 0."""
    keyword_hit = re.search(r'' + ltp(meta_all_words), get_tagged_text(c), re.I)
    if keyword_hit:
        return 1
    return 0



# Noun/adjective forms signalling inhibition; used by the negative "inhibition" labeling functions below.
inhibition_samples = ["inhibitor", "inhibition", "auto-inhibition", 'inhibitory', 'suppression']


def LF_GC_AB_before_inhibition(c):
    """Vote -1 when the word 'inhibition' follows both entity placeholders.

    Matches {{A}} ... {{B}} ... inhibition or {{B}} ... {{A}} ... inhibition with at
    most 100 characters in each gap (case-insensitive). Returns 0 otherwise.
    Only the literal 'inhibition' is searched, not the whole inhibition_samples list.
    """
    for first, second in (('{{A}}', '{{B}}'), ('{{B}}', '{{A}}')):
        pattern = first + '.{0,100}' + second + '.{0,100}' + 'inhibition'
        if re.search(pattern, get_tagged_text(c), re.I):
            return -1
    return 0


def LF_GC_A_supressive_effect_B(c):
    """Vote -1 for sentences of the form '{{A}} ... suppressive ... effect ... {{B}}', else 0.

    Fixes over the original pattern:
      * it ended in {{A}} instead of {{B}}, so it required the first entity twice
        and could not match a candidate sentence with a single {{A}};
      * it only matched the misspelling 'supressive'; 'supp?ressive' accepts the
        correct spelling as well as the old one.
    Each gap allows at most 100 characters; matching is case-insensitive.
    """
    pattern = r'{{A}}.{0,100} supp?ressive' + '.{0,100} effect .{0,100}' + '{{B}}'
    return -1 if re.search(pattern, get_tagged_text(c), re.I) else 0


def LF_GC_AB_after_inhibition(c):
    """Vote -1 when an inhibition word occurs up to 100 characters before the first entity.

    snorkel's rule_regex_search_before_A only matches sentences where {{A}} comes
    before {{B}}, and rule_regex_search_before_B only those where {{B}} comes before
    {{A}} (see snorkel.lf_helpers). The original `and` of the two could therefore
    never be true for a candidate with one A and one B; `or` covers both orders.

    Returns -1 on a match, 0 otherwise.
    """
    pattern = ltp(inhibition_samples) + '.{0,100}'
    if rule_regex_search_before_A(c, pattern, 1) or rule_regex_search_before_B(c, pattern, 1):
        return -1
    return 0

def LF_GC_A_inhibition_B(c):
    """Negative vote (-1) for an inhibition word between A and B, at most 100 characters from each.

    Delegates to rule_regex_search_btw_AB with sign -1.
    """
    near_gap = '.{0,100}'
    pattern = near_gap + ltp(inhibition_samples) + near_gap
    return rule_regex_search_btw_AB(c, pattern, -1)

def LF_GC_B_inhibition_A(c):
    """Negative vote (-1) for an inhibition word between B and A, at most 100 characters from each.

    Delegates to rule_regex_search_btw_BA with sign -1.
    """
    near_gap = '.{0,100}'
    pattern = ''.join((near_gap, ltp(inhibition_samples), near_gap))
    return rule_regex_search_btw_BA(c, pattern, -1)

def LG_GC_inhibition_before_B_near(c):
    """Negative vote (-1) for an inhibition word at most 100 characters before B.

    Delegates to rule_regex_search_before_B with sign -1.
    """
    pattern = ltp(inhibition_samples) + '.{0,100}'
    return rule_regex_search_before_B(c, pattern, -1)
 


# Verb stems signalling inhibition/reduction; matched as plain substrings.
# NOTE(review): 'degrees' looks out of place next to the other verbs and will also match
# measurement text such as "37 degrees" - possibly a typo for 'degrade'/'decreases'; confirm.
inhibits_samples = ["inhibit", 'decrease', 'degrees', 'suppress', 'alleviate', 'reduce']

def LF_GC_A_inhibitis_B(c):
    """Negative vote (-1) for an inhibition verb stem between A and B, at most 100 characters from each.

    Delegates to rule_regex_search_btw_AB with sign -1.
    """
    near_gap = '.{0,100}'
    pattern = near_gap + ltp(inhibits_samples) + near_gap
    return rule_regex_search_btw_AB(c, pattern, -1)

def LF_GC_A_inhibitis_B_close(c):
    """Negative vote (-1) for an inhibition verb stem between A and B, at most 50 characters from each.

    Tighter variant of LF_GC_A_inhibitis_B. Delegates to rule_regex_search_btw_AB with sign -1.
    """
    close_gap = '.{0,50}'
    pattern = close_gap + ltp(inhibits_samples) + close_gap
    return rule_regex_search_btw_AB(c, pattern, -1)


def LF_GC_inhibits_in_sent(c):
    """Vote -1 when the tagged sentence contains any of inhibits_samples (case-insensitive), else 0."""
    if re.search(r'' + ltp(inhibits_samples), get_tagged_text(c), re.I):
        return -1
    return 0


# prefer meta if both are included
def LF_GC_meta_and_inhibts(c):
    """Vote 1 when the sentence has both an inhibition verb and a metabolism keyword, else 0.

    Combines LF_GC_inhibits_in_sent (must return -1) and LF_GC_metabol_in_sent
    (must return 1); the second is only evaluated when the first condition holds.
    """
    if LF_GC_inhibits_in_sent(c) != -1:
        return 0
    return 1 if LF_GC_metabol_in_sent(c) == 1 else 0


def LF_closer_chem(c):
    """Vote -1 when a different chemical is tagged close to the candidate's gene, else 0.

    "Close" means within half the gene-chemical distance: the words from the
    gene's end index forward and the words just before the gene's start index.
    A word counts when its entity type is 'Chemical' and its cid differs from the
    cid at the candidate chemical's start word.

    Args:
        c: candidate exposing .chemical / .gene spans (get_word_start / get_word_end)
           and get_parent() returning the sentence with words, entity_types and
           entity_cids sequences.
    """
    # Get distance between chemical and gene
    chem_start, chem_end = c.chemical.get_word_start(), c.chemical.get_word_end()
    gen_start, gen_end = c.gene.get_word_start(), c.gene.get_word_end()
    if gen_start < chem_start:
        dist = chem_start - gen_end
    else:
        dist = gen_start - chem_end
    # Try to find chemical closer than @dist/2 in either direction
    sent = c.get_parent()
    window = dist // 2
    own_chem_cid = sent.entity_cids[chem_start]
    after_gene = range(gen_end, min(len(sent.words), gen_end + window))
    before_gene = range(max(0, gen_start - window), gen_start)
    for word_range in (after_gene, before_gene):
        for i in word_range:
            if sent.entity_types[i] == 'Chemical' and sent.entity_cids[i] != own_chem_cid:
                return -1
    return 0


def LF_closer_gene(c):
    """Vote -1 when a different gene is tagged close to the candidate's chemical, else 0.

    "Close" means within one eighth of the gene-chemical distance: the words from
    the chemical's end index forward and the words just before its start index.
    A word counts when its entity type is 'Gene' and its cid differs from the cid
    at the candidate gene's start word.
    """
    chem_start, chem_end = c.chemical.get_word_start(), c.chemical.get_word_end()
    gen_start, gen_end = c.gene.get_word_start(), c.gene.get_word_end()
    # Word distance between the two mentions, measured from the end of the earlier one.
    dist = (chem_start - gen_end) if gen_start < chem_start else (gen_start - chem_end)
    sent = c.get_parent()
    reach = dist // 8
    right_of_chem = range(chem_end, min(len(sent.words), chem_end + reach))
    left_of_chem = range(max(0, chem_start - reach), chem_start)
    for idx in [*right_of_chem, *left_of_chem]:
        et, cid = sent.entity_types[idx], sent.entity_cids[idx]
        if et == 'Gene' and cid != sent.entity_cids[gen_start]:
            return -1
    return 0

                

def LF_CG_inhib_between(c):
    """Vote -1 when an inhibition verb stem occurs in a word close to the candidate's gene, else 0.

    "Close" means within half the gene-chemical distance: the words from the
    gene's end index forward and the words just before the gene's start index.
    Each word is searched (case-insensitively) for any entry of inhibits_samples.

    NOTE(review): despite the name, the window is centred on the gene on both
    sides rather than limited to the words between the two mentions.
    """
    # Get distance between chemical and gene
    chem_start, chem_end = c.chemical.get_word_start(), c.chemical.get_word_end()
    gene_start, gene_end = c.gene.get_word_start(), c.gene.get_word_end()
    if gene_start < chem_start:
        dist = chem_start - gene_end
    else:
        dist = gene_start - chem_end
    # Try to find an inhibition word closer than @dist/2 in either direction
    sent = c.get_parent()
    window = dist // 2
    # Compile once instead of rebuilding the same pattern for every word.
    inhibit_re = re.compile(r'' + ltp(inhibits_samples), re.I)
    after_gene = range(gene_end, min(len(sent.words), gene_end + window))
    before_gene = range(max(0, gene_start - window), gene_start)
    for word_range in (after_gene, before_gene):
        for i in word_range:
            if inhibit_re.search(sent.words[i]):
                return -1
    return 0


    
# Every labeling function that is handed to the LabelAnnotator.
# NOTE(review): gold_label_function is the same function that was used to write the
# gold labels (KSUtils.add_gold_labels_for_candidates), so listing it here means the
# dev/test evaluation compares the models against one of their own inputs - confirm
# that this is intended.
LFs_GC = [
    # knowledge-base (CTD) based label
    gold_label_function,
    
    # metabolism / activation patterns
    LF_GC_AB_metabolisation_after_AB,
    LF_GC_metabolisation_before_AB,
    LF_CG_A_metabolisation_B_far,
    LF_GC_B_metabolisation_A_far,
    LF_GC_A_metabolisation_B,
    LF_GC_B_metabolisation_A,
    LG_GC_metabolisation_before_B_near,
    LG_GC_metabolisation_before_B_far,
    LF_GC_A_metabol_B_in_sent,
    LF_GC_B_metabol_A_in_sent,
    LF_GC_B_metabolised_A,
    LF_GC_A_not_metabol_B,
    LF_GC_B_not_metabolised_A,
    LF_GC_no_metabol,
    LF_GC_metabol_in_sent,
    
    # inhibition patterns
    LF_GC_AB_before_inhibition,
    LF_GC_A_supressive_effect_B,
    LF_GC_AB_after_inhibition,
    LF_GC_A_inhibition_B,
    LF_GC_B_inhibition_A,
    LG_GC_inhibition_before_B_near,
    LF_GC_A_inhibitis_B,
    LF_GC_A_inhibitis_B_close,
    LF_GC_inhibits_in_sent,
    LF_GC_meta_and_inhibts,
    
    # word-distance based rules
    LF_closer_chem,
    LF_closer_gene,
    LF_CG_inhib_between
]

In [None]:
from snorkel.annotations import LabelAnnotator
# Annotator that applies every function in LFs_GC to the candidates.
labeler = LabelAnnotator(lfs=LFs_GC)

# Build the label matrix with 10 parallel workers; no split is passed, so presumably
# snorkel's default split (the training split) is used - confirm.
%time L_train = labeler.apply(lfs=LFs_GC, parallelism=10)
L_train

In [None]:
# Show the per-labeling-function statistics table for the training label matrix.
L_train.lf_stats(session)

In [None]:
from snorkel.learning import GenerativeModel

# Fit the generative model on the training label matrix; the step size is scaled
# by the number of rows in L_train.
gen_model = GenerativeModel()
gen_model.train(L_train, epochs=100, decay=0.95, step_size=0.1 / L_train.shape[0], reg_param=1e-6)
# Marginals for the training candidates; presumably one probability per row of L_train - confirm.
train_marginals = gen_model.marginals(L_train)

In [None]:
import matplotlib.pyplot as plt
# Histogram of the training marginals in 20 bins.
plt.hist(train_marginals, bins=20)
plt.show()

In [None]:
from snorkel.annotations import load_gold_labels

# Label matrices for dev (split 1) and test (split 2); apply_existing presumably
# reuses the labeling-function keys created for the training matrix - confirm.
L_dev = labeler.apply_existing(split=1)
L_test = labeler.apply_existing(split=2)
# Gold labels stored under the 'gold' annotator for the same two splits.
L_gold_dev = load_gold_labels(session, annotator_name='gold',split=1)
L_gold_test = load_gold_labels(session, annotator_name='gold', split=2)

In [None]:
#from snorkel.learning import GenerativeModel
#from snorkel.learning import RandomSearch
#from snorkel.learning.structure import DependencySelector


#MAX_DEPS = 5

#ds = DependencySelector()
#deps = ds.select(L_train, threshold=0.1)
#deps = set(list(deps)[0:min(len(deps), MAX_DEPS)])

#print("Using {} dependencies".format(len(deps)))



# use random search to optimize the generative model
#param_grid = {
#    'step_size' : [1e-3, 1e-4, 1e-5, 1e-6],
#    'decay'     : [0.9, 0.95],
#    'epochs'    : [50,100,150],
#    'reg_param' : [1e-3],
#}

#model_class_params = {'lf_propensity' : False }#, 'deps': deps}

#searcher = RandomSearch(GenerativeModel, param_grid, L_train, n=10, model_class_params=model_class_params)
#%time gen_model, run_stats = searcher.fit(L_dev, L_gold_dev) #, deps=deps)
#run_stats

In [None]:
# Recompute the training marginals. With the random-search cell above commented out,
# gen_model is unchanged, so this repeats the call made right after training.
train_marginals = gen_model.marginals(L_train)

In [None]:
import matplotlib.pyplot as plt
# Histogram of the (recomputed) training marginals in 20 bins.
plt.hist(train_marginals, bins=20)
plt.show()

In [None]:
from snorkel.annotations import save_marginals

# Marginals on the dev label matrix (not referenced again in the cells shown here).
dev_marginals = gen_model.marginals(L_dev)
# Run the error analysis against the dev gold labels; the four returned values are discarded.
_, _, _, _ = gen_model.error_analysis(session, L_dev, L_gold_dev)

# Store the training marginals in the database so they can be reloaded with load_marginals.
%time save_marginals(session, L_train, train_marginals)

In [None]:
from snorkel.annotations import load_gold_labels

# Reload the gold labels for dev (split 1) and test (split 2); same calls as in the earlier cell.
L_gold_dev = load_gold_labels(session, annotator_name='gold',split=1)
L_gold_test = load_gold_labels(session, annotator_name='gold', split=2)

from snorkel.annotations import load_marginals
# Read the training marginals back from the database (replaces the in-memory values).
train_marginals = load_marginals(session)

In [None]:
# Error analysis of the generative model on the test label matrix against the test gold labels.
tp, fp, tn, fn = gen_model.error_analysis(session, L_test, L_gold_test)

In [None]:
print("Load all ChemicalGeneMetabolism candidates from db...")
# One id-ordered candidate list per split: 0 = train, 1 = dev, 2 = test.
train_cands, dev_cands, test_cands = [
    session.query(GeneChemicalMetabolism)
           .filter(GeneChemicalMetabolism.split == split_id)
           .order_by(GeneChemicalMetabolism.id)
           .all()
    for split_id in (0, 1, 2)
]

# Train, dev and test candidates concatenated in that order.
all_cands = train_cands + dev_cands + test_cands

print("{} {} {}".format(len(train_cands), len(dev_cands), len(test_cands)))
print("Amount of all candidates: {}".format(len(all_cands)))

In [None]:
from ksnorkel import RandomSearchGPU
from snorkel.learning.pytorch.rnn import LSTM



# Train the discriminative LSTM, either through a random hyper-parameter search
# (do_grid_search=True) or once with the fixed configuration below.
if do_grid_search:
    seed = 12345
    num_model_search = 5

    # search over this parameter grid
    param_grid = {}
    param_grid['batch_size'] = [64, 128]
    param_grid['lr']         = [1e-4, 1e-3, 1e-2]
    param_grid['rebalance']  = [0.0,0.25, 0.5]
    param_grid['embedding_dim'] = [75, 100, 125]
    param_grid['hidden_dim'] = [50, 100, 150]
    param_grid['dropout'] = [0, 0.25, 0.5]

    model_class_params = {
        'n_threads':1
    }


    model_hyperparams = {
        'n_epochs': 100,
        'print_freq': 25,
        'dev_ckpt_delay': 0.5,
        'X_dev': dev_cands,
        'Y_dev': L_gold_dev,
    }

    # Fix: the name RandomSearch is never imported in this notebook, so this branch
    # raised NameError. RandomSearchGPU is the searcher this cell imports from ksnorkel.
    # NOTE(review): assumes RandomSearchGPU accepts the same arguments as snorkel's
    # RandomSearch - confirm against ksnorkel.
    searcher = RandomSearchGPU(LSTM, param_grid, train_cands, train_marginals,
                               n=num_model_search, seed=seed,
                               model_class_params=model_class_params,
                               model_hyperparams=model_hyperparams)

    print("Discriminitive Model Parameter Space (seed={}):".format(seed))
    for i, params in enumerate(searcher.search_space()):
        print("{} {}".format(i, params))

    disc_model, run_stats = searcher.fit(X_valid=dev_cands, Y_valid=L_gold_dev, n_threads=1)
    lstm = disc_model
else:
    # This branch replaces the LSTM imported at the top of the cell with the pytorch_gpu one.
    from snorkel.learning.pytorch_gpu.rnn import LSTM


    # best config
    train_kwargs = {
        'batch_size':    64,
        'lr':            0.0010,
        'embedding_dim': 75,
        'hidden_dim':    150,
        'n_epochs':      100,
        'dropout':       0.0,
        'rebalance':     0.0,
        'seed':          1701,
        "max_sentence_length": 4000
    }


    lstm = LSTM(n_threads=1)
    lstm.train(train_cands, train_marginals, 
               X_dev=dev_cands, 
               Y_dev=L_gold_dev,
               dev_ckpt=True,
               use_cudnn=True,
               **train_kwargs)


    import torch

    # Release cached GPU memory held by PyTorch after training.
    torch.cuda.empty_cache()

In [None]:
# Error analysis of the LSTM on the test candidates; overwrites the tp/fp/tn/fn
# values obtained earlier from the generative model.
tp, fp, tn, fn = lstm.error_analysis(session, test_cands, L_gold_test)

In [None]:
# Precision, recall and F1 of the LSTM on the test candidates against the test gold labels.
p, r, f1 = lstm.score(test_cands, L_gold_test)
print("Prec: {0:.3f}, Recall: {1:.3f}, F1 Score: {2:.3f}".format(p, r, f1))

In [None]:
# Store the LSTM's marginals for the candidates of all three splits.
# NOTE(review): presumably written to the same marginal table that save_marginals
# filled with the generative model's training marginals - confirm whether it overwrites them.
lstm.save_marginals(session, all_cands)

In [None]:
from snorkel.models.candidate import Marginal
from snorkel.models import Document, Sentence


print("Storing candidate labels into result file...")
# Counts every Candidate row in the database, not only GeneChemicalMetabolism ones.
amount_of_candidates = session.query(Candidate).count()
print("Amount of candidates: {}".format(amount_of_candidates))

# Train, dev and test sentences flattened into one list, in that order.
all_sents = [
    sentence
    for split_sents in (train_sent, dev_sent, test_sent)
    for sentence in split_sents
]


In [None]:
# Tab-separated header with seven columns, ending in the sentence text.
header_str = '{}\t{}\t{}\t{}\t{}\t{}\t{}'.format('sentence_id', 'cand_id','gene_cid', 'gene_span', 'chemical_cid', 'chemical_span', 'sentence')
# Write the result file via the project helper (timed).
# NOTE(review): the rows are produced inside ksnorkel; presumably they line up with the header above - confirm.
%time KSUtils.save_binary_relation_confusion_matrix_as_tsv('results/gene_chemical_metabolisation.tsv', session, all_cands, all_sents, header_str, 'gene_cid', 'chemical_cid')

In [None]:
# Tab-separated header with seven columns, starting with the document id (no sentence text).
header_str = '{}\t{}\t{}\t{}\t{}\t{}\t{}'.format('document_id', 'sentence_id', 'cand_id','gene_cid', 'gene_span', 'chemical_cid', 'chemical_span')
# Write the relation file via the project helper (timed).
# NOTE(review): this file is named chemical_gene_* while the previous cell writes gene_chemical_* - confirm the naming is intended.
%time KSUtils.save_binary_relation_as_tsv('results/chemical_gene_metabolism.tsv', session, all_cands, all_sents, header_str, 'gene_cid', 'chemical_cid')

In [None]:
# Persist the trained LSTM under the name "gene_chemical_metabolism.lstm".
lstm.save("gene_chemical_metabolism.lstm")