## First, we import the dependencies for running the Similarity Learning task

In [1]:
import os
import csv
import re
from gensim.similarity_learning import DRMM_TKS_Model
from pprint import pprint

Using TensorFlow backend.


## Data Format

We have to provide the data in a format the model understands.
The model represents each sentence as a list of words.
Further, we need to provide:
 1. Queries List: a list of list of words (one tokenized query per entry)
 2. Candidate Document List: a list of list of list of words (for each query, its tokenized candidate documents)
 3. Correct Label List: a list of list of ints (for each query, one relevance label per candidate document)

Example:
```
queries = ["When was Abraham Lincoln born ?".split(),
           "When was the first World War ?".split()]
docs = [
    ["Abraham Lincoln was the president of the United States of America".split(),
     "He was born in 1809".split()],
    ["The first world war was bad".split(),
     "It was fought in 1914".split(),
     "There were over a million deaths".split()]
]
labels = [[0, 1],
          [0, 1, 0]]
```

## About the dataset: WikiQA

The WikiQA corpus is a set of question-answer pairs: for every query there are several candidate documents, of which none, one, or more may be relevant.
Relevance is purely binary, i.e., 1: relevant, 0: not relevant.

Sample data:
```
QuestionID	Question	DocumentID	DocumentTitle	SentenceID	Sentence	Label
Q1	how are glacier caves formed?	D1	Glacier cave	D1-0	A partly submerged glacier cave on Perito Moreno Glacier .	0
Q1	how are glacier caves formed?	D1	Glacier cave	D1-1	The ice facade is approximately 60 m high	0
Q1	how are glacier caves formed?	D1	Glacier cave	D1-2	Ice formations in the Titlis glacier cave	0
Q1	how are glacier caves formed?	D1	Glacier cave	D1-3	A glacier cave is a cave formed within the ice of a glacier .	1
Q1	how are glacier caves formed?	D1	Glacier cave	D1-4	Glacier caves are often called ice caves , but this term is properly used to describe bedrock caves that contain year-round ice.	0
```

## Data Preprocessing
We need to convert the above text into the `queries, docs, labels` form described earlier.
The code below does that; questions with no relevant candidate document are skipped.


In [2]:
# Point this at WikiQA-train.tsv inside your WikiQACorpus folder
wikiqa_data_path = os.path.join('data', 'WikiQACorpus', 'WikiQA-train.tsv')


def preprocess_sent(sent):
    """Utility function to lower, strip and tokenize each sentence
    
    Replace this function if you want to handle preprocessing differently"""
    return re.sub("[^a-zA-Z0-9]", " ", sent.strip().lower()).split()

# Column indices used when reading the .tsv file
QUESTION_ID_INDEX = 0
QUESTION_INDEX = 1
ANSWER_INDEX = 5
LABEL_INDEX = 6

with open(wikiqa_data_path, encoding='utf8') as tsv_file:
    tsv_reader = csv.reader(tsv_file, delimiter='\t')
    data_rows = []
    for row in tsv_reader:
        data_rows.append(row)
document_group = []
label_group = []

n_relevant_docs = 0
n_filtered_docs = 0

queries = []
docs = []
labels = []

for i, line in enumerate(data_rows[1:], start=1):
    if i < len(data_rows) - 1:  # check if out of bounds might occur
        if data_rows[i][QUESTION_ID_INDEX] == data_rows[i + 1][QUESTION_ID_INDEX]:
            document_group.append(preprocess_sent(data_rows[i][ANSWER_INDEX]))
            label_group.append(int(data_rows[i][LABEL_INDEX]))
            n_relevant_docs += int(data_rows[i][LABEL_INDEX])
        else:
            document_group.append(preprocess_sent(data_rows[i][ANSWER_INDEX]))
            label_group.append(int(data_rows[i][LABEL_INDEX]))

            n_relevant_docs += int(data_rows[i][LABEL_INDEX])

            if n_relevant_docs > 0:
                docs.append(document_group)
                labels.append(label_group)
                queries.append(preprocess_sent(data_rows[i][QUESTION_INDEX]))
            else:
                n_filtered_docs += 1

            n_relevant_docs = 0
            document_group = []
            label_group = []

    else:
        # If we are on the last line
        document_group.append(preprocess_sent(data_rows[i][ANSWER_INDEX]))
        label_group.append(int(data_rows[i][LABEL_INDEX]))
        n_relevant_docs += int(data_rows[i][LABEL_INDEX])

        if n_relevant_docs > 0:
            docs.append(document_group)
            labels.append(label_group)
            queries.append(preprocess_sent(data_rows[i][QUESTION_INDEX]))
        else:
            n_filtered_docs += 1
            n_relevant_docs = 0
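The grouping loop above can also be written with `itertools.groupby`. The sketch below is a hypothetical alternative, not the tutorial's code (`group_wikiqa_rows` and `load_wikiqa` are illustrative names). Like the loop above, it assumes that all rows belonging to one question are contiguous in the file, and it should produce the same `queries`, `docs` and `labels`, dropping questions without a relevant answer.

```python
import csv
import re
from itertools import groupby

# Column positions in WikiQA-train.tsv (same values as the constants above)
QUESTION_ID_INDEX, QUESTION_INDEX, ANSWER_INDEX, LABEL_INDEX = 0, 1, 5, 6


def preprocess_sent(sent):
    """Lowercase, replace non-alphanumerics with spaces and split into words."""
    return re.sub("[^a-zA-Z0-9]", " ", sent.strip().lower()).split()


def group_wikiqa_rows(rows):
    """Group consecutive rows by question ID; skip questions with no relevant answer."""
    queries, docs, labels = [], [], []
    n_filtered = 0
    for _, group in groupby(rows, key=lambda row: row[QUESTION_ID_INDEX]):
        group = list(group)
        group_labels = [int(row[LABEL_INDEX]) for row in group]
        if sum(group_labels) == 0:
            n_filtered += 1  # no relevant candidate for this question
            continue
        queries.append(preprocess_sent(group[0][QUESTION_INDEX]))
        docs.append([preprocess_sent(row[ANSWER_INDEX]) for row in group])
        labels.append(group_labels)
    return queries, docs, labels, n_filtered


def load_wikiqa(path):
    """Read the .tsv file and group it, skipping the header row."""
    with open(path, encoding="utf8") as tsv_file:
        rows = list(csv.reader(tsv_file, delimiter="\t"))
    return group_wikiqa_rows(rows[1:])
```

For example, `queries, docs, labels, n_filtered = load_wikiqa(wikiqa_data_path)`.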

## Let's have a look at the data

In [3]:
queries[300]

['where', 'did', 'hurricane', 'katrina', 'begin']

In [4]:
print(docs[300])

[['hurricane', 'katrina', 'was', 'the', 'deadliest', 'and', 'most', 'destructive', 'atlantic', 'hurricane', 'of', 'the', '2005', 'atlantic', 'hurricane', 'season'], ['it', 'was', 'the', 'costliest', 'natural', 'disaster', 'as', 'well', 'as', 'one', 'of', 'the', 'five', 'deadliest', 'hurricanes', 'in', 'the', 'history', 'of', 'the', 'united', 'states'], ['among', 'recorded', 'atlantic', 'hurricanes', 'it', 'was', 'the', 'sixth', 'strongest', 'overall'], ['at', 'least', '1', '833', 'people', 'died', 'in', 'the', 'hurricane', 'and', 'subsequent', 'floods', 'making', 'it', 'the', 'deadliest', 'u', 's', 'hurricane', 'since', 'the', '1928', 'okeechobee', 'hurricane', 'total', 'property', 'damage', 'was', 'estimated', 'at', '81', 'billion', '2005', 'usd', 'nearly', 'triple', 'the', 'damage', 'brought', 'by', 'hurricane', 'andrew', 'in', '1992'], ['hurricane', 'katrina', 'formed', 'over', 'the', 'bahamas', 'on', 'august', '23', '2005', 'and', 'crossed', 'southern', 'florida', 'as', 'a', 'moder

In [5]:
print(labels[300])

[1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]


## Making a train-validation split
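At this point, it is a good idea to make a train-validation split so we can see how the model performs on held-out data as it trains. Here, the first 80% of the queries are used for training and the remaining 20% for validation.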

In [6]:
split_index = int(len(queries) * 0.8)  # 80% train, 20% validation
train_queries, test_queries = queries[:split_index], queries[split_index:]
train_docs, test_docs = docs[:split_index], docs[split_index:]
train_labels, test_labels = labels[:split_index], labels[split_index:]

In [7]:
print(len(train_queries), len(test_queries))
print(len(train_docs), len(test_docs))
print(len(train_labels), len(test_labels))

697 175
697 175
697 175
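The split above is sequential: the last 20% of the questions, in file order, become the validation set. If the file order is not random, a seeded shuffle is one alternative. The sketch below is only an illustration with a hypothetical `shuffled_split` helper and an arbitrary seed; it was not used to produce the results in this notebook.

```python
import random


def shuffled_split(queries, docs, labels, train_fraction=0.8, seed=42):
    """Shuffle question indices with a fixed seed, then split into train/validation.

    The same permutation is applied to all three lists so they stay aligned.
    """
    indices = list(range(len(queries)))
    random.Random(seed).shuffle(indices)  # reproducible for a given seed
    split = int(len(indices) * train_fraction)

    def take(items, idx):
        return [items[i] for i in idx]

    train = tuple(take(items, indices[:split]) for items in (queries, docs, labels))
    test = tuple(take(items, indices[split:]) for items in (queries, docs, labels))
    return train, test
```

Usage: `(train_queries, train_docs, train_labels), (test_queries, test_docs, test_labels) = shuffled_split(queries, docs, labels)`.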


## Training the Model
To train the model with pretrained word embeddings such as GloVe, we have to specify the path to the embedding file.

In [8]:
word_embedding_path = os.path.join('data', 'glove.6B.50d.txt')
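The file at `word_embedding_path` is assumed to be in GloVe's plain-text format: one word per line, followed by its space-separated vector components (50 of them for `glove.6B.50d.txt`). The sketch below, with hypothetical helpers `load_glove` and `build_embedding_matrix`, illustrates roughly what the first steps of the training log describe: building an embedding index and filling one matrix row per vocabulary word, with zeros for words missing from the embeddings. It assumes a vocabulary dict mapping each word to an index starting at 1. Judging by the log, the model then also adds the remaining words from the embedding file, normalizes the vectors and sets a pad index; this sketch omits those steps.

```python
import numpy as np


def load_glove(path):
    """Read a GloVe text file into a dict mapping word -> vector."""
    embeddings = {}
    with open(path, encoding="utf8") as glove_file:
        for line in glove_file:
            if not line.strip():
                continue  # ignore blank lines
            word, *values = line.rstrip().split(" ")
            embeddings[word] = np.array([float(v) for v in values], dtype=np.float32)
    return embeddings


def build_embedding_matrix(vocab, embeddings, dim):
    """Row i holds the vector of the word with index i; unknown words stay zero.

    Assumes vocab maps words to indices 1..len(vocab); row 0 is left as zeros.
    """
    matrix = np.zeros((len(vocab) + 1, dim), dtype=np.float32)
    n_missing = 0
    for word, index in vocab.items():
        vector = embeddings.get(word)
        if vector is None:
            n_missing += 1  # out-of-embedding word: keep the zero row
        else:
            matrix[index] = vector
    return matrix, n_missing
```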

We would like to monitor the model's progress during training.
However, we can't rely on the metrics provided by Keras, as those metrics don't necessarily apply to Information Retrieval problems.

Instead, we can provide a validation dataset through the `validation_data` argument; it is evaluated after every epoch with ranking metrics (MAP and nDCG@k, as printed in the training log below).

Now that we have the preprocessed data, training the model takes just one line:

In [9]:
# Train the model
drmm_tks_model = DRMM_TKS_Model(train_queries, train_docs, train_labels, word_embedding_path=word_embedding_path,
                                epochs=10, validation_data=[test_queries, test_docs, test_labels])

2018-06-18 19:51:51,255 : INFO : Starting Vocab Build
2018-06-18 19:51:51,380 : INFO : Vocab Build Complete
2018-06-18 19:51:51,381 : INFO : Vocab Size is 17776
2018-06-18 19:51:51,382 : INFO : Building embedding index using pretrained word embeddings
2018-06-18 19:52:03,115 : INFO : The embeddings_index built from the given file has 400000 words of 50 dimensions
2018-06-18 19:52:03,117 : INFO : Embedding Matrix for Embedding Layer has shape (17777, 50) 
2018-06-18 19:52:03,165 : INFO : There are 590 words not in the embeddings. Setting them to zero
2018-06-18 19:52:03,166 : INFO : Adding additional dimensions from the embedding file to embedding matrix
2018-06-18 19:52:04,632 : INFO : Normalizing the word embeddings
2018-06-18 19:52:05,276 : INFO : Embedding Matrix now has shape (400593, 50)
2018-06-18 19:52:05,279 : INFO : Pad word has been set to index 400590
2018-06-18 19:52:05,285 : INFO : Embedding index build complete


__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
query (InputLayer)              (None, 200)          0                                            
__________________________________________________________________________________________________
doc (InputLayer)                (None, 200)          0                                            
__________________________________________________________________________________________________
embedding_1 (Embedding)         (None, 200, 50)      20029650    query[0][0]                      
                                                                 doc[0][0]                        
__________________________________________________________________________________________________
mm_q_embed_DOT_d_embed (Dot)    (None, 200, 200)     0           embedding_1[0][0]                
          

2018-06-18 19:52:55,605 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped
2018-06-18 19:52:55,625 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped
2018-06-18 19:52:55,644 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped
2018-06-18 19:52:55,667 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped
2018-06-18 19:52:55,687 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped
2018-06-18 19:52:55,712 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped


MAP:  0.5148505514499304
nDCG@ 1 :  0.33714285714285713
nDCG@ 3 :  0.5147417470374223
nDCG@ 5 :  0.5763412968630361
nDCG@ 10 :  0.6312150086506307
nDCG@ 20 :  0.6441387764470141
Epoch 2/10


2018-06-18 19:53:44,085 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped
2018-06-18 19:53:44,097 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped
2018-06-18 19:53:44,111 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped
2018-06-18 19:53:44,124 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped
2018-06-18 19:53:44,138 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped
2018-06-18 19:53:44,155 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped


MAP:  0.5216076752784827
nDCG@ 1 :  0.3314285714285714
nDCG@ 3 :  0.5341093278066933
nDCG@ 5 :  0.58575913485778
nDCG@ 10 :  0.6370480273675955
nDCG@ 20 :  0.649876765940878
Epoch 3/10


2018-06-18 19:54:23,214 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped
2018-06-18 19:54:23,228 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped
2018-06-18 19:54:23,239 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped
2018-06-18 19:54:23,255 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped
2018-06-18 19:54:23,267 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped
2018-06-18 19:54:23,281 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped


MAP:  0.5316416888839249
nDCG@ 1 :  0.35428571428571426
nDCG@ 3 :  0.5336843589095192
nDCG@ 5 :  0.5971383676102587
nDCG@ 10 :  0.6445071782700846
nDCG@ 20 :  0.6573150417223009
Epoch 4/10


2018-06-18 19:55:06,206 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped
2018-06-18 19:55:06,220 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped
2018-06-18 19:55:06,235 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped
2018-06-18 19:55:06,250 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped
2018-06-18 19:55:06,263 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped
2018-06-18 19:55:06,277 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped


MAP:  0.5350727855650217
nDCG@ 1 :  0.36
nDCG@ 3 :  0.5364738097247982
nDCG@ 5 :  0.5996773969921729
nDCG@ 10 :  0.6484333190423246
nDCG@ 20 :  0.6601953660352579
Epoch 5/10


2018-06-18 19:55:56,269 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped
2018-06-18 19:55:56,307 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped
2018-06-18 19:55:56,341 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped
2018-06-18 19:55:56,388 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped
2018-06-18 19:55:56,451 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped
2018-06-18 19:55:56,487 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped


MAP:  0.5373028281242567
nDCG@ 1 :  0.37142857142857144
nDCG@ 3 :  0.5370187504788523
nDCG@ 5 :  0.5971636097417318
nDCG@ 10 :  0.6484826998744052
nDCG@ 20 :  0.661812118324339
Epoch 6/10


2018-06-18 19:56:40,373 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped
2018-06-18 19:56:40,384 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped
2018-06-18 19:56:40,398 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped
2018-06-18 19:56:40,411 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped
2018-06-18 19:56:40,424 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped
2018-06-18 19:56:40,439 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped


MAP:  0.5471022957373889
nDCG@ 1 :  0.38285714285714284
nDCG@ 3 :  0.5484473219074237
nDCG@ 5 :  0.6066320151337564
nDCG@ 10 :  0.6561264864429803
nDCG@ 20 :  0.6694559048929141
Epoch 7/10


2018-06-18 19:57:18,349 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped
2018-06-18 19:57:18,358 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped
2018-06-18 19:57:18,369 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped
2018-06-18 19:57:18,379 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped
2018-06-18 19:57:18,389 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped
2018-06-18 19:57:18,400 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped


MAP:  0.5502088665302951
nDCG@ 1 :  0.38857142857142857
nDCG@ 3 :  0.5579885069893176
nDCG@ 5 :  0.6065795860359091
nDCG@ 10 :  0.6611604432569254
nDCG@ 20 :  0.6714596617183064
Epoch 8/10


2018-06-18 19:57:58,372 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped
2018-06-18 19:57:58,392 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped
2018-06-18 19:57:58,411 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped
2018-06-18 19:57:58,424 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped
2018-06-18 19:57:58,434 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped
2018-06-18 19:57:58,451 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped


MAP:  0.5525359204287775
nDCG@ 1 :  0.3942857142857143
nDCG@ 3 :  0.5557239904718314
nDCG@ 5 :  0.6102927756720273
nDCG@ 10 :  0.6614013697050168
nDCG@ 20 :  0.6731490781017055
Epoch 9/10


2018-06-18 19:58:39,081 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped
2018-06-18 19:58:39,098 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped
2018-06-18 19:58:39,109 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped
2018-06-18 19:58:39,119 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped
2018-06-18 19:58:39,131 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped
2018-06-18 19:58:39,142 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped


MAP:  0.5443128736343021
nDCG@ 1 :  0.38285714285714284
nDCG@ 3 :  0.5425974989686475
nDCG@ 5 :  0.6019846618312067
nDCG@ 10 :  0.6551148717902364
nDCG@ 20 :  0.6668608783992421
Epoch 10/10


2018-06-18 19:59:26,760 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped
2018-06-18 19:59:26,770 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped
2018-06-18 19:59:26,784 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped
2018-06-18 19:59:26,801 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped
2018-06-18 19:59:26,815 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped
2018-06-18 19:59:26,834 : INFO : Using 175 out of 175 data points which is 0.00%. 0 were skipped


MAP:  0.5404327061827062
nDCG@ 1 :  0.37142857142857144
nDCG@ 3 :  0.5387366770097259
nDCG@ 5 :  0.605006023715385
nDCG@ 10 :  0.6521062841469755
nDCG@ 20 :  0.663808927216609
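The MAP and nDCG@k values printed after each epoch are standard ranking metrics. The sketch below computes them with common textbook definitions: average precision over the positions of the relevant documents, and nDCG with gain 2^rel - 1 and a log2 discount. It is not the gensim implementation, and details such as tie-breaking or the handling of queries without relevant documents may differ.

```python
import math


def average_precision(labels_in_rank_order):
    """AP for one query: mean of precision@i over the ranks i holding a relevant doc."""
    hits, precisions = 0, []
    for rank, label in enumerate(labels_in_rank_order, start=1):
        if label > 0:
            hits += 1
            precisions.append(hits / rank)
    return sum(precisions) / len(precisions) if precisions else 0.0


def ndcg_at_k(labels_in_rank_order, k):
    """nDCG@k with gain (2^rel - 1) and a log2(rank + 1) discount."""
    def dcg(labels):
        return sum((2 ** rel - 1) / math.log2(rank + 1)
                   for rank, rel in enumerate(labels[:k], start=1))
    ideal = dcg(sorted(labels_in_rank_order, reverse=True))  # best possible ordering
    return dcg(labels_in_rank_order) / ideal if ideal > 0 else 0.0


def rank_labels(scores, labels):
    """Order one query's gold labels by descending predicted score."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return [labels[i] for i in order]


def evaluate(all_scores, all_labels, ks=(1, 3, 5, 10, 20)):
    """Average AP and nDCG@k over queries; all_scores[q][i] scores doc i of query q."""
    ranked = [rank_labels(s, l) for s, l in zip(all_scores, all_labels)]
    results = {"MAP": sum(map(average_precision, ranked)) / len(ranked)}
    for k in ks:
        results[f"nDCG@{k}"] = sum(ndcg_at_k(r, k) for r in ranked) / len(ranked)
    return results
```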


## Testing the model on new data

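The model can be tested on completely unseen data using `model.predict(queries, docs)`, where:
- `queries` is a list of list of words
- `docs` is a list of list of list of words

Note that the single-query example below passes `docs` as a flat list of tokenized documents (a list of list of words) and gets back one score per document.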

In [10]:
# Example:
queries = ["how are glacier caves formed ?".split()]
docs = ["A partly submerged glacier cave on Perito Moreno Glacier".split(),
        "A glacier cave is a cave formed within the ice of a glacier".split()]

In [11]:
drmm_tks_model.predict(queries, docs)

[[0.45600405]
 [0.5344676 ]]


As can be seen above, the correct answer (the second document, which is labeled 1 in the sample data) receives the higher similarity score.
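To turn the scores returned by `predict` into a ranking of the candidates, the documents can be sorted by score. The sketch below uses a hypothetical helper, `rank_candidates`, and assumes, as in the output shown for `predict`, that it returns one score per document in an array of shape `(n_docs, 1)`.

```python
import numpy as np


def rank_candidates(docs, scores):
    """Return (score, document) pairs sorted from most to least similar.

    Assumes one score per document, e.g. an (n_docs, 1) array.
    """
    flat_scores = np.asarray(scores).reshape(-1)
    if flat_scores.shape[0] != len(docs):
        raise ValueError("expected exactly one score per document")
    order = np.argsort(-flat_scores)  # indices of the highest scores first
    return [(float(flat_scores[i]), docs[i]) for i in order]
```

For example, `rank_candidates(docs, drmm_tks_model.predict(queries, docs))` would list the glacier-cave definition first.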