# I. Prepare data

In [1]:
from tqdm import tqdm
from lxml import etree
import os, nltk

In [2]:
topic_tree = etree.parse('data/topics2016.xml')

def get_topic(i):# returns the summary string of the ith topic
    summary = topic_tree.xpath('//topic[@number="%d"]/summary/text()'%i)[0]
    return str(summary).lower()

# build a mapping of article name (PMCID) to its file path

PMC_PATH = '/local/XW/DATA/TREC/PMCs/'
pmcid2fpath = {}

for subdir1 in os.listdir(PMC_PATH):
    for subdir2 in os.listdir(os.path.join(PMC_PATH, subdir1)):
        diry = os.path.join(PMC_PATH, subdir1, subdir2)
#         print diry, len(os.listdir(diry))
        for fn in os.listdir(diry):
            pmcid = fn[:-5]
            fpath = os.path.join(diry, fn)
            pmcid2fpath[pmcid] = fpath

def get_article_abstract(pmcid):
    fpath = pmcid2fpath[pmcid]
    tree = etree.parse(fpath)
    abstract = tree.xpath('//abstract')[0]
    ret = u''+abstract.xpath('string(.)')
    if len(ret.split())<20: raise Exception('abstraction too short:'+pmcid)
    return ret.lower()

In [3]:
corpus = []
pmcid_2relevance = [{} for i in xrange(31)] # list of dict mapping pmcid to relevance
with open('data/qrels.txt') as f:
    for line in tqdm(f, total=37707): 
        topicid, _, pmcid, relevance = line.split()
        topicid = int(topicid)
        try:
            corpus.append(get_article_abstract(pmcid)) # !some articles don't have an abstract!
            pmcid_2relevance[topicid][pmcid] = int(relevance)
        except: pass

100%|██████████| 37707/37707 [05:47<00:00, 108.38it/s]


In [4]:
print '%d articles are retrieved' % len(corpus)

34289 articles are retrieved


In [5]:
from sklearn.feature_extraction.text import TfidfVectorizer
vectorizer = TfidfVectorizer()
vectorizer.fit_transform(corpus)

def get_idf(wd):
    if wd not in vectorizer.vocabulary_: return 1.0
    return vectorizer.idf_[ vectorizer.vocabulary_[wd] ]

vocab = set(vectorizer.vocabulary_.keys())

In [6]:
import os, sys, time
import numpy as np
from numpy.linalg import norm
import pandas as pd
from tqdm import tqdm
import cPickle as pk
np.random.seed(1) # to be reproducible

In [7]:
W2V_FPATH = '/local/XW/DATA/WORD_EMBEDDINGS/biomed-w2v-200.txt'
GLOVE_FPATH = '/local/XW/DATA/WORD_EMBEDDINGS/glove.6B.200d.txt'

In [8]:
word2vec = {} # maps word ---> embedding vector
with open(W2V_FPATH) as f:
    for line in tqdm(f, total=5443657): #5443657 400000
        vals = line.split()
        word = vals[0]
        if word in vocab:
            vec = np.asarray(vals[1:], dtype='float')
            word2vec[word] = vec
print 'found %d word vectors.' % len(word2vec)

100%|██████████| 5443657/5443657 [01:00<00:00, 89801.52it/s]

found 54911 word vectors.





In [9]:
from nltk.corpus import stopwords
stopwds = set(stopwords.words('english'))

In [10]:
get_topic(17)

'76-year-old female with personal history of diastolic congestive heart failure, atrial fibrillation on coumadin, presenting with low hematocrit and dyspnea.'

In [11]:
_queries = [get_topic(i) for i in xrange(1,21)] # top 20 topics' queries
QUERIES = []
for q in _queries:
    q2 = [wd for wd in q.split() if (wd not in stopwds) and (wd in word2vec)] # filter out stopword and words not in w2v
    QUERIES.append(q2)

In [12]:
map(len, QUERIES)

[7, 14, 26, 11, 12, 20, 16, 11, 13, 12, 14, 13, 23, 11, 40, 16, 11, 19, 20, 16]

In [43]:
N = 40 # = max query length

In [17]:
WD_PLACEHOLDER = '</s>'
def pad_query(q, SZ=40):
    return q + [WD_PLACEHOLDER]*(SZ-len(q))
QUERIES = map(pad_query, QUERIES)

## helper functions

In [20]:
randvec = np.random.randn(200)
def get_histvec(q_wd, doc):
    if q_wd == WD_PLACEHOLDER: return np.zeros(30)
    qvec = word2vec.get(q_wd, randvec)
#     dvecs = np.vstack( [word2vec.get(wd, randvec) for wd in nltk.word_tokenize(doc)] )
    words_doc = filter(lambda wd:wd in word2vec, nltk.word_tokenize(doc))
    dvecs = np.vstack( [ word2vec[wd] for wd in words_doc ] )
    cossims = np.dot(dvecs, qvec) / norm(qvec) / norm(dvecs, axis=1)
    hist, _ = np.histogram( cossims[cossims<1.0], bins=29, range=(-1,1) )
    ones = len(cossims) - sum(hist)
    ret = np.array( list(hist) + [ones] )
    return ret # np.reshape(ret, (-1, 30))

In [21]:
def get_query_doc_feature(query, pmcid): # query: list of words
    doc = get_article_abstract(pmcid)
    return np.array([ get_histvec(wd, doc) for wd in query])

### prepare data

In [None]:
pos_ids, neg_ids = [], [] # pos_ids[q] is a list (positive pmcids for query `q`)
hists_pos, hists_neg = [], [] # hists_pos[q] is a list (positive hists for query `q`)
                              # hists_pos[q][i] is an array of size N*30 (the ith hists-feature array for query q)
for topic in xrange(1,6):
    query = QUERIES[topic-1]
    pos_ids_q, neg_ids_q = [], []
    hists_pos_q, hists_neg_q = [], []
    relevance = pmcid_2relevance[topic]
    for pmcid in tqdm(relevance.keys()):
        if relevance[pmcid]==0: 
            neg_ids_q.append(pmcid)
            hists_neg_q.append(get_query_doc_feature(query,pmcid))
        else: 
            pos_ids_q.append(pmcid)
            hists_pos_q.append(get_query_doc_feature(query,pmcid))
    hists_pos_q, hists_neg_q = map(np.array, [hists_pos_q, hists_neg_q])
    hists_pos.append(hists_pos_q); hists_neg.append(hists_neg_q)
    pos_ids.append(pos_ids_q); neg_ids.append(neg_ids_q)
print len(pos_ids), len(neg_ids)

 27%|██▋       | 355/1328 [00:07<00:18, 53.10it/s]

### prepare generator

In [183]:
VALDATION_SPLIT = 0.2
BATCH_SZ = 128
NB_EPOCH = 20

In [90]:
print map(len, hists_pos)
print map(len, hists_neg)

[124, 34]
[1204, 1153]


In [67]:
idx_pairs = [] # list of tuples of form (q, pos_idx, neg_idx)
for q in xrange(2):
    hists_pos_q, hists_neg_q = hists_pos[q], hists_neg[q]
    for pidx in xrange(len(hists_pos_q)):
        for nidx in xrange(len(hists_neg_q)):
            idx_pairs.append( (q, pidx, nidx) )
idx_pairs = np.array(idx_pairs)
val_sz = int(len(idx_pairs)*VALDATION_SPLIT)
idx_pairs_train, idx_pairs_val = idx_pairs[val_sz:], idx_pairs[:val_sz]

print idx_pairs_train.shape, idx_pairs_val.shape

(150799, 3) (37699, 3)


In [91]:
idx_batch = idx_pairs_train[:10]
zip(*idx_batch)

[(1, 0, 0, 1, 0, 0, 1, 0, 0, 0),
 (30, 117, 43, 15, 121, 47, 0, 76, 71, 109),
 (602, 416, 878, 114, 1195, 178, 840, 891, 148, 1078)]

In [95]:
IDFs = [ np.array ([ get_idf(wd) for wd in query]) for query in QUERIES]
IDFs = np.array(IDFs)

In [141]:
def batch_generator(idx_pairs, batch_size=BATCH_SZ): 
    np.random.shuffle(idx_pairs)
    batches_pre_epoch = len(idx_pairs) // batch_size
    samples_per_epoch = batches_pre_epoch * batch_size # make samples_per_epoch a multiple of batch size
    counter = 0
    y_true_batch_dummy = np.ones((batch_size))
    while 1:
        idx_batch = idx_pairs[batch_size*counter: min(samples_per_epoch, batch_size*(counter+1))]
#         q_batch, pidx_batch, nidx_batch = zip(*idx_batch)
#         idfs_batch = IDFs[q_batch,]
#         pos_batch = hists_pos[q_batch, pidx_batch]
#         neg_batch = hists_neg[q_batch, nidx_batch]
        idfs_batch, pos_batch, neg_batch = [], [], []
        for q, pidx, nidx in idx_batch:
            query = QUERIES[q]
            _idfs = np.array([get_idf(wd) for wd in query])
            idfs_batch.append(_idfs)
            pos_batch.append(hists_pos[q][pidx])
            neg_batch.append(hists_neg[q][nidx])
        idfs_batch, pos_batch, neg_batch = map(np.array, [idfs_batch, pos_batch, neg_batch])
#         print idfs_batch.shape, pos_batch.shape, neg_batch.shape
        counter += 1
        if (counter >= batches_pre_epoch):
            np.random.shuffle(idx_pairs)
            counter=0
        
        yield [idfs_batch, pos_batch, neg_batch], y_true_batch_dummy

# II. Define the deep relevance model

In [40]:
# define a function for visualization of model
import pydot
from IPython.display import SVG
from keras.utils.visualize_util import model_to_dot
def viz_model(model):
    return SVG(model_to_dot(model).create(prog='dot', format='svg'))

Using TensorFlow backend.


### construct the relevance IR model

In [41]:
from keras.models import Sequential, Model
from keras.layers import Dense, Activation, InputLayer, Flatten, Input, Merge, merge, Reshape
import keras.backend as K
from keras.callbacks import EarlyStopping, TensorBoard
import tensorflow as tf

In [44]:
# 2 main components of the structure: feed forward network and gating
feed_forward = Sequential(
    [Dense(input_dim=30, output_dim=5, activation='tanh'),
     Dense(output_dim=1, activation='tanh')], 
    name='feed_forward_nw')

# ***note: have to wrap ops into Lambda layers !!***
# cf: https://groups.google.com/forum/#!topic/keras-users/fbRS-FkZw_Q
from keras.layers.core import Lambda

input_idf = Input(shape=(N,), name='input_idf')
def scale(x): 
    w = K.variable(1, name='w_g')
    return K.mul(x,w)
def scale_output_shape(input_shape): return input_shape

scaled = Lambda(scale, scale_output_shape, name='softmax_scale')(input_idf)
gs_out = Activation('softmax', name='softmax')(scaled)
gating = Model(input=input_idf, output=gs_out, name='gating')

# first input: hist vectors
input_hists = Input(shape=(N,30), name='input_hists')

def slicei(x, i): return x[:,i,:]
def slicei_output_shape(input_shape): return (input_shape[0], input_shape[2])
zs = [ feed_forward( Lambda(lambda x:slicei(x,i), slicei_output_shape, name='slice%d'%i)(input_hists) )\
          for i in xrange(N) ]

def concat(x): return K.concatenate(x) 
def concat_output_shape(input_shape): return (input_shape[0][0], N)
zs = Lambda(concat, concat_output_shape, name='concat_zs')(zs)

# second input: idf scores of each query term 
input_idf = Input(shape=(N,), name='input_idf')
gs = gating(input_idf)

def innerprod(x): return K.sum( K.mul(x[0],x[1]), axis=1)
def innerprod_output_shape(input_shape): return (input_shape[0][0],1)
scores = Lambda(innerprod, innerprod_output_shape, name='innerprod_zs_gs')([zs, gs])

scoring_model = Model(input=[input_idf, input_hists], output=[scores], name='scoring_model')

# third input -- the negative hists vector 
input_hists_neg = Input(shape=(N,30), name='input_hists_neg')

zs_neg = [ feed_forward( Lambda(lambda x:slicei(x,i), slicei_output_shape, name='slice%d_neg'%i)(input_hists_neg) )\
          for i in xrange(N) ]

zs_neg = Lambda(concat, concat_output_shape, name='concat_zs_neg')(zs_neg)

scores_neg = Lambda(innerprod, innerprod_output_shape, name='innerprod_zs_gs_neg')([zs_neg, gs])

two_score_model = Model(input=[input_idf, input_hists, input_hists_neg], 
                        output=[scores, scores_neg], 
                        name='two_score_model')

def diff(x): return tf.sub(x[0], x[1]) #x[0]-x[1]
def diff_output_shape(input_shape): return input_shape[0]
posneg_score_diff = Lambda(diff, diff_output_shape, name='posneg_score_diff')([scores, scores_neg])
ranking_model = Model(input=[input_idf, input_hists,  input_hists_neg]
                      , output=[posneg_score_diff]
                      , name='ranking_model')

# define my loss function: hinge of score_pos - score_neg
def pairwise_hinge(y_true, y_pred): # y_pred = score_pos - score_neg, **y_true doesn't matter here**
    return K.mean( K.maximum(1. - y_pred, y_true*0.0) )  

ranking_model.compile(optimizer='adagrad', loss=pairwise_hinge)

# III. train model (for topic-1)

## train model using `fit_generator`

In [49]:
# self-defined metrics
def ranking_acc(y_true, y_pred):
    y_pred = y_pred > 0 
    return K.mean(y_pred)

In [50]:
ranking_model.compile(optimizer='adagrad', loss=pairwise_hinge, metrics=[ranking_acc])

In [143]:
logdir = './log/relevance_matching'
_callbacks = [ EarlyStopping(monitor='val_loss', patience=2),
               TensorBoard(log_dir=logdir, histogram_freq=0, write_graph=False) ]
# 30-5-1 tanh
ranking_model.fit_generator( batch_generator(idx_pairs_train), 
                    samples_per_epoch = len(idx_pairs_train)//BATCH_SZ*BATCH_SZ,
                    nb_epoch=NB_EPOCH,
                    validation_data=batch_generator(idx_pairs_val),
                    nb_val_samples=len(idx_pairs_val)//BATCH_SZ*BATCH_SZ, 
                    callbacks = _callbacks)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20


<keras.callbacks.History at 0x7f4bc6a36650>

------------------

## below are some testing stuff

In [145]:
zip(pos_ids[0][:10], neg_ids[0][:10])

[('3429740', '3921765'),
 ('2999735', '4085271'),
 ('3526517', '4532844'),
 ('1750992', '2769307'),
 ('3503351', '3809984'),
 ('4130891', '4685986'),
 ('4471306', '2996340'),
 ('3286730', '4395018'),
 ('4659952', '28994'),
 ('4724023', '4716452')]

In [147]:
pos_sample = get_query_doc_feature(query, '3429740')
neg_sample = get_query_doc_feature(query, '3921765')
pair_sample = np.array([pos_sample, neg_sample])
query = QUERIES[0]
_idf = np.array([get_idf(wd) for wd in query])
idf_sample = np.vstack([_idf]*2)

print idf_sample.shape, pair_sample.shape

(2, 40) (2, 40, 30)


### test `scoring_model`

In [148]:
scoring_model.predict([idf_sample,pair_sample])

array([-0.73008156, -0.47747353], dtype=float32)

In [149]:
a = feed_forward.predict(pos_sample)
print a

[[-0.81683552]
 [-0.82646966]
 [-0.99879682]
 [-0.99982327]
 [-0.85237795]
 [-0.81685901]
 [-0.81277031]
 [-0.70481849]
 [-0.8772012 ]
 [ 0.20669119]
 [-0.62638611]
 [-0.81685901]
 [-0.81684172]
 [-0.8168422 ]
 [ 0.58492392]
 [ 0.58492392]
 [ 0.58492392]
 [ 0.58492392]
 [ 0.58492392]
 [ 0.58492392]
 [ 0.58492392]
 [ 0.58492392]
 [ 0.58492392]
 [ 0.58492392]
 [ 0.58492392]
 [ 0.58492392]
 [ 0.58492392]
 [ 0.58492392]
 [ 0.58492392]
 [ 0.58492392]
 [ 0.58492392]
 [ 0.58492392]
 [ 0.58492392]
 [ 0.58492392]
 [ 0.58492392]
 [ 0.58492392]
 [ 0.58492392]
 [ 0.58492392]
 [ 0.58492392]
 [ 0.58492392]]


In [150]:
b = gating.predict(idf_sample)[0]
print b

[ 0.11612402  0.00800367  0.00871019  0.02551604  0.0907026   0.06773899
  0.60106444  0.0024891   0.0024891   0.0024891   0.0024891   0.0024891
  0.0024891   0.0024891   0.0024891   0.0024891   0.0024891   0.0024891
  0.0024891   0.0024891   0.0024891   0.0024891   0.0024891   0.0024891
  0.0024891   0.0024891   0.0024891   0.0024891   0.0024891   0.0024891
  0.0024891   0.0024891   0.0024891   0.0024891   0.0024891   0.0024891
  0.0024891   0.0024891   0.0024891   0.0024891 ]


In [151]:
b.dot(a)

array([-0.7300815], dtype=float32)

In [37]:
c = feed_forward.predict(neg_sample)
print c
print b.dot(c)

[[ 0.87186098]
 [ 0.7142241 ]
 [-0.96034747]
 [ 0.94236648]
 [-0.3230921 ]
 [ 0.99326599]
 [-0.93802834]
 [ 0.99049699]
 [-0.08088312]
 [ 0.98430783]
 [ 0.87186098]]
[ 0.44491631]


==> the scoring model works all right

### test ranking_model

In [152]:
ranking_model.predict( [idf_sample, pair_sample, np.array([neg_sample, pos_sample]) ])

array([-0.25260803,  0.25260803], dtype=float32)

In [154]:
-0.73008156 - -0.47747353

-0.25260803000000004

In [160]:
def predict_score(pmcid):
    _idf = np.array([get_idf(wd) for wd in query])
    _idf = np.vstack([_idf])
    _hist = get_query_doc_feature(query, pmcid).reshape(1,N,30)
    return scoring_model.predict([_idf, _hist])[0]

In [164]:
predict_score('3429740')

0.35019341

### see some results

In [163]:
zip( map(predict_score, pos_ids[0][:10]), map(predict_score, neg_ids[0][:10]))

[(0.35019341, 0.30549839),
 (0.37719917, -0.70841134),
 (0.38686728, -0.66828567),
 (-0.60399097, 0.36510727),
 (-0.7060799, -0.68677717),
 (0.33125544, 0.021299459),
 (0.35732347, -0.79787028),
 (0.37714282, -0.69743919),
 (0.37723261, -0.70512909),
 (0.48084351, -0.76500708)]

In [42]:
def predict_score_diff( (pmcid_pos, pmcid_neg) ):
    _idf = np.array([get_idf(wd) for wd in query.split()])
    _idf = np.vstack([_idf])
    hist_pos = get_query_doc_feature(query, pmcid_pos).reshape((1,11,30))
    hist_neg = get_query_doc_feature(query, pmcid_neg).reshape((1,11,30))
    return ranking_model.predict([_idf, hist_pos, hist_neg])[0]

### test the scoring model (metrics=AP )

In [166]:
def AP(pos_scores, neg_scores):
    Q = len(pos_scores)
    pos_tags = [1] * len(pos_scores)
    neg_tags = [0] * len(neg_scores)
    all_tagged = zip(pos_scores, pos_tags) + zip(neg_scores, neg_tags)
    ranked_list = sorted(all_tagged, reverse=True)
    print ranked_list[:20]
    ranked_tag = zip(*ranked_list)[1]
    print ranked_tag[:20]
    precision_at_i = []
    corr, total = 0.0, 0
    while corr<Q:
        if ranked_tag[total]==1: 
            corr += 1
            precision_at_i.append(corr*1.0 / (total+1) )
        total += 1
    print precision_at_i[:20]
    return np.mean(precision_at_i)

In [179]:
def AP_of_topic(q):
    query = QUERIES[q]
    _idf = np.array([get_idf(wd) for wd in query])
    _idfs = np.vstack([_idf]*len(hists_pos[q]))
    pos_scores = scoring_model.predict( [ _idfs, hists_pos[q]])
    _idfs = np.vstack([_idf]*len(hists_neg[q]))
    neg_scores = scoring_model.predict( [ _idfs, hists_neg[q]])
    print 'mean:', pos_scores.mean(), neg_scores.mean()
    print 'max:',pos_scores.max(), neg_scores.max()
    print 'min:',pos_scores.min(), neg_scores.min()
    return AP(pos_scores, neg_scores)

In [180]:
AP_of_topic(0)

mean: 0.192697 -0.492493
max: 0.582177 0.585927
min: -0.800651 -0.843911
[(0.58592725, 0), (0.58572489, 0), (0.5857085, 0), (0.58217734, 1), (0.58211708, 0), (0.57812393, 1), (0.56911486, 0), (0.56231618, 0), (0.54702652, 0), (0.54344326, 1), (0.52571243, 0), (0.52161843, 1), (0.5163936, 0), (0.51093429, 0), (0.50954407, 1), (0.50019461, 1), (0.4983612, 1), (0.49835235, 0), (0.49779099, 1), (0.48898831, 0)]
(0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0)
[0.25, 0.3333333333333333, 0.3, 0.3333333333333333, 0.3333333333333333, 0.375, 0.4117647058823529, 0.42105263157894735, 0.4090909090909091, 0.4166666666666667, 0.3793103448275862, 0.3870967741935484, 0.3611111111111111, 0.358974358974359, 0.375, 0.38095238095238093, 0.3541666666666667, 0.3673469387755102, 0.38, 0.38461538461538464]


0.32909877623586348

In [181]:
AP_of_topic(1)

mean: -0.0839959 -0.612511
max: 0.769786 0.854114
min: -0.872192 -0.98742
[(0.85411352, 0), (0.84318459, 0), (0.825872, 0), (0.81277239, 0), (0.79127645, 0), (0.77174681, 0), (0.76978552, 1), (0.76650035, 1), (0.76614183, 0), (0.76086164, 0), (0.75385493, 0), (0.75144458, 0), (0.74910921, 0), (0.74099368, 1), (0.74041063, 0), (0.73638779, 0), (0.73409617, 0), (0.73292655, 0), (0.73257864, 0), (0.72589278, 0)]
(0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0)
[0.14285714285714285, 0.25, 0.21428571428571427, 0.14814814814814814, 0.1724137931034483, 0.1875, 0.12962962962962962, 0.11428571428571428, 0.08256880733944955, 0.08264462809917356, 0.0859375, 0.09090909090909091, 0.0948905109489051, 0.09929078014184398, 0.1048951048951049, 0.10666666666666667, 0.10625, 0.10975609756097561, 0.1144578313253012, 0.11428571428571428]


0.098118803256635245