# I. Prepare data (run once, then pickle to file)

Things to be pickled: 

* `pmcid2fpath`    
    mapping pmcid to the corresponding file path
* `corpus: dict[str, str]`   
    mapping pmcid to its content, the content is a BOW representation (a Counter); it is pickled into another file called `corpus.pk`
* `QUERIES: dict[int, list<str>]`     
    mapping qid to its query represented as a list of words
* `QUERIES_padded`   
    same as `QUERIES`, but all queries are padded to the same length (=`MAX_QLEN`) with `WD_PLACEHOLDER`
* `IDF: dict[str, float]`   
    mapping a word to its idf
* `relevance: dict[(int,str), int]`   
    mapping (qid,docid) pairs to relevance (0,1,2)
* `n_pos`    
    mapping `qid` to its number of training positive documents
* `candidates: dict[int, list<str>]`  
    mapping qid to list of its candidate docids (that appeared in the qrel)
* `qid_docid2histvec: dict[(int,str), array]`    
    mapping from (qid, docid) to the corresponding histvec
* `instances: dict[int, list<(str,str)>]`    
    mapping qid to list, instances[qid] = list of (pos_docid, neg_docid) pairs for qid,

In [1]:
from tqdm import tqdm
from lxml import etree
import nltk

import os, sys, time
import numpy as np
from numpy.linalg import norm
import pandas as pd
from tqdm import tqdm
import cPickle as pk
np.random.seed(1) 

In [111]:
W2V_FPATH = '/local/XW/DATA/WORD_EMBEDDINGS/W2V_BIO/wikipedia-pubmed-and-PMC-w2v.bin'
# GLOVE_FPATH = '/local/XW/DATA/WORD_EMBEDDINGS/glove.6B.200d.txt'
WD_PLACEHOLDER = '</s>'
PMC_PATH = '/local/XW/DATA/TREC/PMCs/'
PK_FOUT = 'data/TREC-DRMM-preprocessed-0125.pk'

### helper function

In [5]:
topic_tree = etree.parse('data/topics2016.xml')

def get_topic(i):
    """Return the summary text of topic number `i`, lower-cased and stripped.

    Raises IndexError when the topics file has no <topic> whose `number`
    attribute equals `i` (the xpath result list is then empty).
    """
    summary_query = '//topic[@number="%d"]/summary/text()' % i
    matches = topic_tree.xpath(summary_query)
    first_summary = matches[0]
    return str(first_summary).lower().strip()

# Index every article file by its PMCID. PMC_PATH is expected to contain two
# levels of sub-directories with the article files at the bottom level.
pmcid2fpath = {}

for top_name in os.listdir(PMC_PATH):
    top_dir = os.path.join(PMC_PATH, top_name)
    for sub_name in os.listdir(top_dir):
        leaf_dir = os.path.join(top_dir, sub_name)
        for file_name in os.listdir(leaf_dir):
            # the PMCID is the file name without its last 5 characters
            pmcid2fpath[file_name[:-5]] = os.path.join(leaf_dir, file_name)

def get_article_abstract(pmcid):
    """Return the lower-cased title and abstract text of article `pmcid`.

    The result is the article title, a newline, then the text of every
    <abstract> element joined by single spaces.

    Raises KeyError when `pmcid` is not in `pmcid2fpath`, and a plain
    Exception when the combined text has fewer than 20 whitespace-separated
    tokens.
    """
    article_tree = etree.parse(pmcid2fpath[pmcid])
    title = article_tree.xpath('string(//article-title)')
    abstract_texts = [node.xpath('string(.)') for node in article_tree.xpath('//abstract')]
    # the leading u'' forces the whole concatenation to unicode
    text = u'' + title + '\n' + u' '.join(abstract_texts)
    if len(text.split()) < 20:
        raise Exception(u'abstraction too short: ' + pmcid + text)
    return text.lower()

In [6]:
get_article_abstract('2362203')

u'evaluation of a follow-up programme after curative resection for colorectal cancer\nfrequent liver imaging can detect liver metastases from colorectal cancer at an asymptomatic stage. \xa9 1999 cancer research campaign'

In [7]:
get_topic(1)

'a 78 year old male presents with frequent stools and melena.'

### word2vec

In [8]:
from gensim.models import Word2Vec
word2vec = Word2Vec.load_word2vec_format(W2V_FPATH, binary=True)

### queries (stopwords removed)

In [9]:
from nltk.corpus import stopwords
stopwds = set(stopwords.words('english'))

In [12]:
QUERIES = {} # dict[int, list<str>] mapping qid to its query represented as a list of words
for qid in xrange(1,31):
    # whitespace tokens of the topic summary, minus English stopwords
    QUERIES[qid] = [token for token in get_topic(qid).split() if token not in stopwds]

In [13]:
MAX_QLEN = max( map(len, QUERIES.values()) ) 
print MAX_QLEN

58


In [41]:
# padding queries to the same length MAX_QLEN
WD_PLACEHOLDER = '</s>'
def pad_query(q, SZ=MAX_QLEN):
    """Return a new list: `q` followed by WD_PLACEHOLDER tokens up to length `SZ`.

    A query already at least `SZ` long comes back as an unchanged copy, since
    repeating a list a non-positive number of times gives an empty list.
    """
    filler = [WD_PLACEHOLDER] * (SZ - len(q))
    return q + filler
QUERIES_padded = {qid: pad_query(QUERIES[qid]) for qid in QUERIES}

In [93]:
QUERIES = QUERIES_padded

### corpus

In [19]:
corpus = {} # dict[str, str] mapping pmcid to its content (lower-cased title + abstract text)
unreadable = set() # pmcids whose article could not be loaded; remembered so each is tried only once

with open('data/qrels.txt') as f:
    for line in tqdm(f, total=37707): 
        # every qrels line must split into exactly 4 fields; only the pmcid is needed here
        qid, _, pmcid, rel = line.split()
        if pmcid in corpus or pmcid in unreadable:
            continue
        try:
            corpus[pmcid] = get_article_abstract(pmcid)
        except Exception:
            # Best effort: articles that are missing, unparsable or too short are
            # left out of the corpus. Catching Exception (not a bare except) lets
            # KeyboardInterrupt / SystemExit still stop this long-running loop.
            unreadable.add(pmcid)

100%|██████████| 37707/37707 [11:40<00:00, 71.79it/s]


In [21]:
print '%d articles are retrieved' % len(corpus)

26255 articles are retrieved


### relevance and candidates

In [76]:
from collections import defaultdict
candidates = defaultdict(list) # dict[int, list<str>] mapping qid to list of its candidate docids (that appeared in the qrel)
relevance = {} # dict[(int,str), int] mapping (qid,docid) pairs to its relevance (0,1,2)
n_pos = defaultdict(int) # dict[int, int] mapping qid to the number of positive documents in qrel
with open('data/qrels.txt') as qrels_file:
    for record in qrels_file:
        qid, _, pmcid, rel = record.split()
        qid, rel = int(qid), int(rel)
        if pmcid not in corpus:
            # the article text was not loaded, so this judgement cannot be used
            continue
        candidates[qid].append(pmcid)
        relevance[(qid,pmcid)] = rel
        if rel > 0:
            n_pos[qid] += 1

In [77]:
print map(len, candidates.values())

[1340, 1196, 1353, 1328, 1384, 825, 1065, 1127, 1101, 1098, 954, 1158, 1336, 1138, 1023, 1441, 1034, 1079, 1146, 1022, 933, 1110, 1552, 1331, 1093, 961, 928, 1570, 738, 1237]


### IDF

In [25]:
from sklearn.feature_extraction.text import TfidfVectorizer
vectorizer = TfidfVectorizer()
vectorizer.fit_transform(corpus.values())
vocab = vectorizer.vocabulary_ # mapping word to its internal index

In [30]:
def get_idf(wd):
    """Return the IDF weight used for query word `wd`.

    * the padding token WD_PLACEHOLDER gets -10.0
    * a word missing from the fitted vectorizer vocabulary gets 5.0
      (give a high score for rare words in query)
    * any other word gets the IDF learned by `vectorizer`
    """
    if wd == WD_PLACEHOLDER:
        return -10.0
    term_index = vectorizer.vocabulary_.get(wd)
    if term_index is None:
        return 5.0
    return vectorizer.idf_[term_index]

In [110]:
IDFs = { qid:np.array([ get_idf(wd) for wd in query]) for qid,query in QUERIES.iteritems()} # mapping qid to IDFs array

### helper functions

In [95]:
def similarity(wd1, wd2):
    """Return the word2vec similarity of two words, or None if it is unknown.

    Identical words score 1.0 without consulting the embedding model. If
    either word is missing from `word2vec`, None is returned.
    """
    if wd1 == wd2:
        return 1.0
    if wd1 not in word2vec or wd2 not in word2vec:
        return None
    return word2vec.similarity(wd1, wd2)

def get_histvec(q_wd, doc):
    """Return the 30-bin log-count histogram (LCH feature) of `q_wd` against `doc`.

    `doc` is split on whitespace; for each document word with a known
    similarity to `q_wd`, the similarity is mapped linearly from [-1, 1] onto
    bin positions [0, 29] and the bin count is incremented. The counts are
    returned as log10(count + 1). The padding token yields an all-zero vector.
    """
    counts = np.zeros(30)
    if q_wd != WD_PLACEHOLDER:
        # we'll use the same method as in NN4IR.cpp, line 774
        for token in doc.split():
            score = similarity(q_wd, token)
            if score is not None:
                bin_pos = (score + 1.0) / 2.0 * (30 - 1) # position in the histogram
                counts[int(bin_pos)] += 1.0
    return np.log10(counts + 1.0) # line 815

In [96]:
def get_query_doc_feature(qid, docid):
    """Stack the histogram vectors of every word of query `qid` against document `docid`.

    Returns an array with one row per query word (padding tokens included).
    """
    doc_text = corpus[docid]
    rows = []
    for q_word in QUERIES[qid]:
        rows.append(get_histvec(q_word, doc_text))
    return np.array(rows)

### qid, pmcid to histvec

In [109]:
qid_docid2histvec = {} # mapping from (qid, docid) to histvec
for qid in tqdm(QUERIES.keys()):
    for docid in candidates[qid]:
        feature = get_query_doc_feature(qid, docid)
        # stored with a leading batch axis of size 1
        qid_docid2histvec[(qid, docid)] = feature.reshape(1,MAX_QLEN,30)

100%|██████████| 30/30 [31:40<00:00, 52.57s/it]


## prepare training instances -- (posid, negid) pairs, cf `NN4IR::LoadDataSet`

* for each posid, use <= `pernegative` negids for making (posid, negid) pairs, 
* for each query, gen <= `num_of_instance` pairs
* **trick**: if a query has fewer pos docids than others, then the params `pernegative` and `num_of_instance` are augmented accordingly (x10/x3/x1.5 when its number of pos docids is at or below the `avg_pos_10`/`avg_pos_50`/`avg_pos_80` thresholds computed below)

In [115]:
print sorted(n_pos.items(), key=lambda (k,v): v)
all_pos = sorted( n_pos.values() ) 
print all_pos

[(22, 8), (27, 12), (4, 18), (10, 19), (2, 34), (30, 34), (18, 63), (7, 68), (21, 68), (15, 69), (5, 95), (26, 103), (12, 107), (23, 108), (9, 117), (14, 117), (29, 118), (1, 125), (13, 128), (6, 133), (3, 144), (17, 171), (16, 175), (28, 203), (19, 204), (25, 213), (11, 350), (20, 616), (24, 714), (8, 809)]
[8, 12, 18, 19, 34, 34, 63, 68, 68, 69, 95, 103, 107, 108, 117, 117, 118, 125, 128, 133, 144, 171, 175, 203, 204, 213, 350, 616, 714, 809]


In [119]:
# -- quantiles of pos docid numbers
avg_pos_80 = all_pos[len(all_pos) * 9 / 10 - 1] # x1.5
avg_pos_50 = all_pos[len(all_pos) * 5 / 10 - 2] # x3
avg_pos_10 = all_pos[len(all_pos) * 5 / 30] # x10
print avg_pos_10, avg_pos_50, avg_pos_80 # quantiles of posid numbers

34 108 350


In [157]:
# Build the training pairs per query, following NN4IR::LoadDataSet:
#   * each positive docid is paired with at most `curr_pernegative` randomly
#     chosen docids of strictly lower relevance
#   * pair generation for a query stops once `total_instance` pairs are reached
#     (checked after each positive docid, so the limit can be overshot by up to
#     one positive's worth of pairs)
instances = {} # mapping qid to list, instances[qid] = list (pos_docid, neg_docid) pairs for qid, 
# use pairs in instances for training
np.random.seed(1) # fixed seed: the sampled pairs depend on the exact order of the `choice` calls below
for qid in QUERIES.keys():
    
    pernegative = 20 # number of limited pairs per positive sample
    num_of_instances = 8000 # number limit of pairs per query
    
    num_pos_currquery = n_pos[qid]
    curr_pernegative = pernegative
    curr_num_of_instance = num_of_instances # -- their trick: gen less pairs for queries with more pos docs
    # queries with few positives get both limits scaled up (x10 / x3 / x1.5);
    # the x1.5 branch turns both values into floats, hence the int() further down
    if(num_pos_currquery <= avg_pos_10): 
        curr_pernegative *= 10; curr_num_of_instance *= 10
    elif(num_pos_currquery <= avg_pos_50): 
        curr_pernegative *= 3; curr_num_of_instance *= 3; 
    elif(num_pos_currquery <= avg_pos_80): 
        curr_pernegative *= 1.5; curr_num_of_instance *= 1.5; 
    
    # group the candidate docids of this query by their relevance score
    rel_scores = defaultdict(list) # mapping a rel score to list of docids
    for docid in candidates[qid]:
        rel = relevance[(qid,docid)]
        rel_scores[rel].append(docid)
    scores = sorted( rel_scores.keys(), reverse=True ) # scores are sorted in desc order
    print 'scores =',scores, 
    # count every possible (higher-score doc, lower-score doc) pair
    total_instance = 0
    for i in xrange(len(scores)): # scores[i] = pos score
        for j in xrange(i+1, len(scores)): # scores[j] = neg score
            total_instance += len(rel_scores[scores[i]]) * len(rel_scores[scores[j]])
    print 'total=', total_instance, 
    # never ask for more pairs than exist, nor more than the per-query limit
    total_instance = min(total_instance, curr_num_of_instance)
    from numpy.random import choice 
    instances_for_q = []
    for i in xrange(len(scores)):# scores are sorted in desc order
        pos_score = scores[i]
        cur_pos_ids = rel_scores[pos_score] # docids whose relevance is pos_score
        # negatives for this level = all docids with a strictly lower relevance score
        cur_neg_ids = []
        for j in xrange(i+1, len(scores)):
            neg_score = scores[j]
            cur_neg_ids += rel_scores[neg_score]# FOUND A BUG HERE
        if len(cur_neg_ids)==0: break # lowest relevance level reached: nothing to pair against
        for posid in cur_pos_ids:
            # sample negatives without replacement, capped by the number available
            # NOTE(review): `choice` presumably returns numpy string scalars rather than plain str -- confirm they hash like the str keys of qid_docid2histvec
            for negid in choice(cur_neg_ids, min(len(cur_neg_ids),int(curr_pernegative)), replace=False):
                instances_for_q.append( (posid,negid) )
            if len(instances_for_q)>=total_instance: break
        if len(instances_for_q)>=total_instance: break
    print 'got %d instances for query %d' % (len(instances_for_q), qid)
    instances[qid] = instances_for_q

scores = [2, 1, 0] total= 154789 got 3750 instances for query 1
scores = [2, 1, 0] total= 39788 got 6800 instances for query 2
scores = [2, 1, 0] total= 177911 got 4320 instances for query 3
scores = [2, 1, 0] total= 23625 got 3600 instances for query 4
scores = [2, 1, 0] total= 123781 got 5700 instances for query 5
scores = [2, 1, 0] total= 95268 got 3990 instances for query 6
scores = [2, 1, 0] total= 68628 got 4080 instances for query 7
scores = [2, 1, 0] total= 397930 got 8000 instances for query 8
scores = [2, 1, 0] total= 118530 got 3510 instances for query 9
scores = [2, 0] total= 20501 got 3800 instances for query 10
scores = [2, 1, 0] total= 238181 got 10500 instances for query 11
scores = [2, 1, 0] total= 114899 got 6420 instances for query 12
scores = [2, 1, 0] total= 158044 got 3840 instances for query 13
scores = [2, 1, 0] total= 120987 got 3510 instances for query 14
scores = [2, 1, 0] total= 65894 got 4140 instances for query 15
scores = [2, 1, 0] total= 229206 got 5250 

In [158]:
print '%d training instances in total' % sum( map(len, instances.values()) )

155800 training instances in total


## another way of negative sampling: each query has the same number of pairs -- but does it work less well?

In [160]:
N_PAIRS_PER_QUERY = 8000 # the smallest query (22) have 8816 possible pairs 
from numpy.random import choice 
instances = {} # mapping qid to list, instances[qid] = list (pos_docid, neg_docid) pairs for qid, 

np.random.seed(1)
for qid in QUERIES.keys():
    rel_scores = defaultdict(list) # mapping a rel score to list of docids
    for docid in candidates[qid]:
        rel = relevance[(qid,docid)]
        rel_scores[rel].append(docid)
    scores = sorted( rel_scores.keys(), reverse=True ) # scores are sorted in desc order
    print 'scores =',scores, 
    
    all_instances = []
    for i in xrange(len(scores)): 
        pos_score = scores[i]
        for j in xrange(i+1, len(scores)): 
            neg_score = scores[j]
            for posid in rel_scores[pos_score]:
                for negid in rel_scores[neg_score]: 
                    all_instances.append( (posid, negid) )
    
    instances_for_q = []
    for i in choice(len(all_instances), N_PAIRS_PER_QUERY, replace=False):
        instances_for_q.append(all_instances[i])
    
    print 'got %d instances out of %d for query %d' % (len(instances_for_q), len(all_instances), qid)
    instances[qid] = instances_for_q

 scores = [2, 1, 0] got 8000 instances out of 154789 for query 1
scores = [2, 1, 0] got 8000 instances out of 39788 for query 2
scores = [2, 1, 0] got 8000 instances out of 177911 for query 3
scores = [2, 1, 0] got 8000 instances out of 23625 for query 4
scores = [2, 1, 0] got 8000 instances out of 123781 for query 5
scores = [2, 1, 0] got 8000 instances out of 95268 for query 6
scores = [2, 1, 0] got 8000 instances out of 68628 for query 7
scores = [2, 1, 0] got 8000 instances out of 397930 for query 8
scores = [2, 1, 0] got 8000 instances out of 118530 for query 9
scores = [2, 0] got 8000 instances out of 20501 for query 10
scores = [2, 1, 0] got 8000 instances out of 238181 for query 11
scores = [2, 1, 0] got 8000 instances out of 114899 for query 12
scores = [2, 1, 0] got 8000 instances out of 158044 for query 13
scores = [2, 1, 0] got 8000 instances out of 120987 for query 14
scores = [2, 1, 0] got 8000 instances out of 65894 for query 15
scores = [2, 1, 0] got 8000 instances out 

In [161]:
print '%d training instances in total' % sum( map(len, instances.values()) )

240000 training instances in total


## pickle to file 

In [122]:
# Collect everything the training code needs into one dict and dump it to PK_FOUT.
data_to_pickle = {}
data_to_pickle['QUERIES'] = QUERIES
data_to_pickle['corpus'] = corpus
data_to_pickle['IDFs'] = IDFs # mapping qid to its IDF input vector
data_to_pickle['candidates'] = candidates # mapping qid to list of docids that corresponds to qid in the qrel file
data_to_pickle['n_pos'] = n_pos # n_pos[qid] = number of positive
data_to_pickle['relevance'] = relevance # mapping (qid,docid) pairs to relevance (0,1,2)
data_to_pickle['qid_docid2histvec'] = qid_docid2histvec # mapping (qid, docid) to histvec
data_to_pickle['instances'] = instances # instances[qid] = list (pos_docid, neg_docid) pairs for qid
with open(PK_FOUT, 'wb') as pickle_file:
    pk.dump(data_to_pickle, pickle_file, pk.HIGHEST_PROTOCOL)

### prepare generator

In [124]:
VALDATION_SPLIT = 0.2
BATCH_SZ = 64
NB_EPOCH = 50

In [126]:
# flatten `instances` into one list of (qid, pos_docid, neg_docid) triples
idx_pairs = [(qid, posid, negid)
             for qid in instances
             for posid, negid in instances[qid]]

# II. Define DRMM model

In [127]:
# define a function for visualization of model
import pydot
from IPython.display import SVG
from keras.utils.visualize_util import model_to_dot
def viz_model(model):
    """Render `model` with graphviz `dot` and wrap the SVG output for notebook display."""
    dot_graph = model_to_dot(model)
    svg_data = dot_graph.create(prog='dot', format='svg')
    return SVG(svg_data)

Using TensorFlow backend.


In [128]:
from keras.models import Sequential, Model
from keras.layers import Dense, Activation, InputLayer, Flatten, Input, Merge, merge, Reshape
import keras.backend as K
from keras.callbacks import EarlyStopping, TensorBoard
import tensorflow as tf

In [131]:
# 2 main components of the structure: feed forward network and gating
# feed forward part: 30-dim histogram -> 5 tanh units -> 1 tanh unit
feed_forward = Sequential(name='feed_forward_nw')
feed_forward.add(Dense(input_dim=30, output_dim=5, activation='tanh'))
feed_forward.add(Dense(output_dim=1, activation='tanh'))

feed_forward.summary()

____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
dense_3 (Dense)                  (None, 5)             155         dense_input_2[0][0]              
____________________________________________________________________________________________________
dense_4 (Dense)                  (None, 1)             6           dense_3[0][0]                    
Total params: 161
Trainable params: 161
Non-trainable params: 0
____________________________________________________________________________________________________


In [134]:
from keras.engine.topology import Layer

class MyLayer(Layer):
    """A scaling layer: multiplies its input by a single trainable weight `W`."""

    def __init__(self, **kwargs):
        super(MyLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        # remember the input width; get_output_shape_for reports it back
        self.output_dim = input_shape[1]
        # one trainable weight of shape (1,), created with the 'one' initializer
        weight_spec = dict(shape=(1,), initializer='one', trainable=True)
        self.W = self.add_weight(**weight_spec)
        super(MyLayer, self).build(input_shape)  # Be sure to call this somewhere!

    def call(self, x, mask=None):
        return tf.mul(x, self.W)

    def get_output_shape_for(self, input_shape):
        batch_dim = input_shape[0]
        return (batch_dim, self.output_dim)

# Gating network: the MAX_QLEN per-term IDF values are multiplied by the single
# trainable weight of MyLayer and passed through a softmax, giving one weight per query term.
input_idf = Input(shape=(MAX_QLEN,), name='input_idf')
scaled = MyLayer()(input_idf)
gs_out = Activation('softmax', name='softmax')(scaled)
gating = Model(input=input_idf, output=gs_out, name='gating')

gating.summary()

____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
input_idf (InputLayer)           (None, 58)            0                                            
____________________________________________________________________________________________________
mylayer_3 (MyLayer)              (None, 58)            1           input_idf[0][0]                  
____________________________________________________________________________________________________
softmax (Activation)             (None, 58)            0           mylayer_3[0][0]                  
Total params: 1
Trainable params: 1
Non-trainable params: 0
____________________________________________________________________________________________________


In [135]:
from keras.layers.core import Lambda

# Build the DRMM graph. Three models share the same layers:
#   scoring_model   : (idf, hists)            -> relevance score of one document
#   two_score_model : (idf, hists, hists_neg) -> scores of a positive and a negative document
#   ranking_model   : (idf, hists, hists_neg) -> score_pos - score_neg (trained with the pairwise hinge loss)

# first input: hist vectors
input_hists = Input(shape=(MAX_QLEN,30), name='input_hists')

def slicei(x, i): return x[:,i,:]
def slicei_output_shape(input_shape): return (input_shape[0], input_shape[2])
# One slice per query position, each fed through the shared feed_forward network.
# `i=i` binds the current index as a default argument: a plain `lambda x: slicei(x, i)`
# looks `i` up when the lambda runs, so any later re-execution of these layers
# would slice the last position for every term.
zs = [ feed_forward( Lambda(lambda x, i=i: slicei(x, i), slicei_output_shape, name='slice%d'%i)(input_hists) )\
          for i in xrange(MAX_QLEN) ]

def concat(x): return K.concatenate(x) 
def concat_output_shape(input_shape): return (input_shape[0][0], MAX_QLEN)
zs = Lambda(concat, concat_output_shape, name='concat_zs')(zs)

# second input: idf scores of each query term 
# (a fresh Input; the `gating` model defined earlier is re-applied to it)
input_idf = Input(shape=(MAX_QLEN,), name='input_idf')
gs = gating(input_idf)

# score = sum over query terms of (per-term matching score * gating weight)
# NOTE(review): K.sum(..., axis=1) presumably drops the summed axis while the declared
# output shape is (batch, 1) -- confirm the two agree for the Keras version in use
def innerprod(x): return K.sum( tf.mul(x[0],x[1]), axis=1)
def innerprod_output_shape(input_shape): return (input_shape[0][0],1)
scores = Lambda(innerprod, innerprod_output_shape, name='innerprod_zs_gs')([zs, gs])

scoring_model = Model(input=[input_idf, input_hists], output=[scores], name='scoring_model')

# third input -- the negative hists vector 
input_hists_neg = Input(shape=(MAX_QLEN,30), name='input_hists_neg')

# same construction for the negative document, reusing feed_forward and the gating output `gs`
zs_neg = [ feed_forward( Lambda(lambda x, i=i: slicei(x, i), slicei_output_shape, name='slice%d_neg'%i)(input_hists_neg) )\
          for i in xrange(MAX_QLEN) ]

zs_neg = Lambda(concat, concat_output_shape, name='concat_zs_neg')(zs_neg)

scores_neg = Lambda(innerprod, innerprod_output_shape, name='innerprod_zs_gs_neg')([zs_neg, gs])

two_score_model = Model(input=[input_idf, input_hists, input_hists_neg], 
                        output=[scores, scores_neg], 
                        name='two_score_model')

def diff(x): return tf.sub(x[0], x[1]) # **??? if I write `x[0]-x[1]` I get negative of the diff ???**
def diff_output_shape(input_shape): return input_shape[0]
posneg_score_diff = Lambda(diff, diff_output_shape, name='posneg_score_diff')([scores, scores_neg])
ranking_model = Model(input=[input_idf, input_hists,  input_hists_neg]
                      , output=[posneg_score_diff]
                      , name='ranking_model')

In [136]:
# define my loss function: hinge of score_pos - score_neg
def pairwise_hinge(y_true, y_pred):
    """Mean hinge loss with margin 0.1 on `y_pred` = score_pos - score_neg.

    `y_true` does not influence the value; it is only multiplied by 0.0 to
    provide the zero tensor for the element-wise maximum.
    """
    margin_violation = 0.1 - y_pred
    zero_floor = y_true*0.0
    return K.mean( K.maximum(margin_violation, zero_floor) )

# self-defined metrics
def ranking_acc(y_true, y_pred):
    """Fraction of pairs whose score difference `y_pred` is positive (`y_true` is unused)."""
    correctly_ordered = y_pred > 0
    return K.mean(correctly_ordered)

ranking_model.compile(optimizer='adagrad', loss=pairwise_hinge, metrics=[ranking_acc])

In [138]:
# viz_model(ranking_model)

In [140]:
gating.predict(IDFs[1].reshape(-1,MAX_QLEN))

array([[  1.08647980e-01,   8.66970513e-03,   9.70950723e-03,
          2.63532810e-02,   9.77671891e-02,   6.89344853e-02,
          6.77386999e-01,   2.52833311e-03,   4.22274447e-08,
          4.22274447e-08,   4.22274447e-08,   4.22274447e-08,
          4.22274447e-08,   4.22274447e-08,   4.22274447e-08,
          4.22274447e-08,   4.22274447e-08,   4.22274447e-08,
          4.22274447e-08,   4.22274447e-08,   4.22274447e-08,
          4.22274447e-08,   4.22274447e-08,   4.22274447e-08,
          4.22274447e-08,   4.22274447e-08,   4.22274447e-08,
          4.22274447e-08,   4.22274447e-08,   4.22274447e-08,
          4.22274447e-08,   4.22274447e-08,   4.22274447e-08,
          4.22274447e-08,   4.22274447e-08,   4.22274447e-08,
          4.22274447e-08,   4.22274447e-08,   4.22274447e-08,
          4.22274447e-08,   4.22274447e-08,   4.22274447e-08,
          4.22274447e-08,   4.22274447e-08,   4.22274447e-08,
          4.22274447e-08,   4.22274447e-08,   4.22274447e-08,
        

In [153]:
gating.get_weights()

[array([ 1.00318825], dtype=float32)]

# III. train model

In [139]:
initial_weights = ranking_model.get_weights()

In [125]:
def batch_generator(idx_pairs, batch_size=BATCH_SZ): 
    """Endlessly yield training batches built from (qid, pos_docid, neg_docid) triples.

    Each yielded item is `([idfs, pos_hists, neg_hists], y_dummy)`:
      * idfs      -- IDFs[qid] for every triple of the batch
      * pos_hists -- histogram features of the positive documents, each (MAX_QLEN, 30)
      * neg_hists -- histogram features of the negative documents, each (MAX_QLEN, 30)
      * y_dummy   -- `batch_size` ones; a placeholder target for the fit call

    The triples are shuffled at the start and again after every full pass.
    Only whole batches are produced; the trailing `len(idx_pairs) % batch_size`
    triples of each shuffled pass are skipped.

    :param idx_pairs: list of tuple (qid, pos_docid, neg_docid); it is copied, the caller's list is not reordered
    :param batch_size: number of triples per batch
    :raises ValueError: (on first iteration) if there are fewer triples than `batch_size`
    """
    idx_pairs = list(idx_pairs) # private copy so shuffling never reorders the caller's list
    batches_pre_epoch = len(idx_pairs) // batch_size
    if batches_pre_epoch == 0:
        # without this check the loop below would yield empty batches forever
        raise ValueError('need at least batch_size=%d index pairs, got %d'
                         % (batch_size, len(idx_pairs)))
    np.random.shuffle(idx_pairs)
    counter = 0
    y_true_batch_dummy = np.ones((batch_size))
    while 1:
        # counter < batches_pre_epoch always holds here, so the slice is a full batch
        idx_batch = idx_pairs[batch_size*counter: batch_size*(counter+1)]
        idfs_batch, pos_batch, neg_batch = [], [], []
        for qid, pos_docid, neg_docid in idx_batch:
            idfs_batch.append(IDFs[qid])
            pos_batch.append(qid_docid2histvec[(qid,pos_docid)].reshape(MAX_QLEN,30))
            neg_batch.append(qid_docid2histvec[(qid,neg_docid)].reshape(MAX_QLEN,30))
        idfs_batch, pos_batch, neg_batch = map(np.array, [idfs_batch, pos_batch, neg_batch])
        counter += 1
        if (counter >= batches_pre_epoch):
            # pass finished: reshuffle and start over
            np.random.shuffle(idx_pairs)
            counter=0
        yield [idfs_batch, pos_batch, neg_batch], y_true_batch_dummy

In [142]:
def get_idx_pairs(qids):
    """Return the (qid, pos_docid, neg_docid) triples of all `qids`, in the order given."""
    return [(qid, posid, negid)
            for qid in qids
            for posid, negid in instances[qid]]

In [143]:
def shuffle_weights(model, weights=None):
    """Set the weights of `model` to a random permutation of weight values.

    Each weight array is flattened, permuted with `np.random.permutation` and
    reshaped to its original shape, so values move only within their own
    array. This is a fast approximation of re-initializing the model, valid
    when the weights are distributed independently of the tensor dimensions.

    :param Model model: the model whose weights are replaced.
    :param list(ndarray) weights: arrays to permute and load into the model;
      they are not modified. If `None`, the model's current weights are used.
    """
    source = model.get_weights() if weights is None else weights
    permuted = []
    for w in source:
        permuted.append(np.random.permutation(w.flat).reshape(w.shape))
    model.set_weights(permuted)

In [144]:
def TREC_output(qid, run_name = 'my_run', fpath = None):
    """Score every candidate document of query `qid` and write a TREC-format ranked list.

    One line per document, best score first, at most 2000 lines:
    `qid  Q0  docid  rank  score  run_name` (rank starts at 0).

    :param qid: query id; must be a key of `IDFs` and `candidates`
    :param run_name: run tag written in the last column
    :param fpath: file to append to; when None the lines go to sys.stdout
    """
    # the IDF input is the same for every document of the query, so build it once
    input_idf = IDFs[qid].reshape((-1,MAX_QLEN))
    res = [] # list of (score, pmcid) tuples
    for docid in candidates[qid]:
        input_hist = qid_docid2histvec[(qid,docid)]
        score = scoring_model.predict([input_idf, input_hist])[0]
        # store a plain float so that sorting and '%f' formatting do not depend
        # on whether predict returns a numpy scalar or a one-element array
        res.append( (float(np.ravel(score)[0]), docid) )
    res.sort(reverse=True)
    fout = sys.stdout if fpath is None else open(fpath, 'a')
    try:
        for rank, (score, docid) in enumerate(res[:2000]):
            fout.write('%d  Q0  %s  %d  %f  %s\n' % (qid, docid, rank, score, run_name))
    finally:
        # close (and flush) the file we opened; never close stdout
        if fout is not sys.stdout:
            fout.close()

In [141]:
logdir = './logs/DRMM_0124'
_callbacks = [ EarlyStopping(monitor='val_loss', patience=2),
               TensorBoard(log_dir=logdir, histogram_freq=0, write_graph=False) ]

In [154]:
def KFold(fpath, K = 5, run_name = 'my_run',  batch_size=BATCH_SZ):
    open(fpath,'w').close() # clear previous content in file 
    qids = sorted( QUERIES.keys() )
    np.random.seed(0)
    np.random.shuffle(qids)
    fold_sz = len(QUERIES) / K
    for fold in xrange(K):
        print 'fold %d' % fold, 
        val_start, val_end = fold*fold_sz, (fold+1)*fold_sz
        qids_val = qids[val_start:val_end] # train/val queries for each fold 
        qids_train = qids[:val_start] + qids[val_end:]
        print qids_val
        idx_pairs_train = get_idx_pairs(qids_train)
        idx_pairs_val = get_idx_pairs(qids_val)
        
        shuffle_weights(ranking_model, initial_weights) # reset model parameters
        ranking_model.fit_generator( batch_generator(idx_pairs_train, batch_size=batch_size), # train model 
                    samples_per_epoch = len(idx_pairs_train)//batch_size*batch_size,
                    nb_epoch=10,
                    validation_data=batch_generator(idx_pairs_val, batch_size=batch_size),
                    nb_val_samples=len(idx_pairs_val)//batch_size*batch_size, 
                    callbacks = _callbacks)
        print 'fold %d complete, outputting to %s...' % (fold, fpath)
        for qid in qids_val:
            TREC_output(qid, run_name = run_name, fpath = fpath)

## 5-fold validation and output ranked list

In [162]:
KFold('data/trec-output/0125_summary_5fold_uniform_sampling.rankedlist')

fold 0 [3, 29, 14, 11, 27, 25]
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
fold 1 [28, 12, 18, 23, 6, 17]
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
fold 2 [9, 15, 24, 21, 2, 30]
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
fold 3 [7, 5, 19, 20, 10, 8]
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
fold 4 [26, 4, 1, 22, 16, 13]
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


### LOO run

In [151]:
KFold('data/trec-output/0125_summary_LOO.rankedlist', K=30)

fold 0 [3]
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
fold 1 [29]
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
fold 2 [14]
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
fold 3 [11]
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
fold 4 [27]
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
fold 5 [25]
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
fold 6 [28]
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
fold 7 [12]
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
fold 8 [18]
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
fold 9 [23]
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
fold 1

To evaluate:

trec_eval -q -M1000 official_qrels submitted_results

In [147]:
sorted( n_pos.items()) # 10 15 18 2 21 22 23 27 3 4 5 7 

[(1, 125),
 (2, 34),
 (3, 144),
 (4, 18),
 (5, 95),
 (6, 133),
 (7, 68),
 (8, 809),
 (9, 117),
 (10, 19),
 (11, 350),
 (12, 107),
 (13, 128),
 (14, 117),
 (15, 69),
 (16, 175),
 (17, 171),
 (18, 63),
 (19, 204),
 (20, 616),
 (21, 68),
 (22, 8),
 (23, 108),
 (24, 714),
 (25, 213),
 (26, 103),
 (27, 12),
 (28, 203),
 (29, 118),
 (30, 34)]