### Build a Deep Structured Semantic Model (DSSM)

<img src="https://github.com/faneshion/MatchZoo/blob/master/docs/_static/images/matchzoo-logo.png?raw=true" alt="logo" style="width:600px;float: center"/>

This is a tutorial on training the *Deep Structured Semantic Model* (DSSM) of [Huang et al. 2013](https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/DSSM_cikm13_talk_v4.pdf) with [MatchZoo](https://github.com/faneshion/MatchZoo). We use [WikiQA](https://aclweb.org/anthology/D15-1237) as the example benchmark dataset.

Features:

1. Letter-trigram (tri-letter) word hashing for scalable word representation.
2. A deep neural network to extract high-level semantic representations.
3. Click-through signals to guide learning.
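Features 2 and 3 can be made concrete with a small NumPy sketch of the scoring described in the paper, under these assumptions:

- Each text's word-hashing vector passes through tanh layers to a low-dimensional semantic vector.
- Relevance is the cosine similarity between the query and document vectors.
- A smoothed softmax over candidate documents gives P(D|Q). Training minimizes the negative log-likelihood of the clicked document.

The weights are random and untrained. The layer sizes (300 hidden units, 128 outputs) only mirror the MatchZoo defaults printed later in this tutorial. Using one tower for both query and document, the toy input, and the smoothing factor `gamma=10.0` are all illustrative assumptions. This is not MatchZoo's `DSSMModel` implementation.

```python
import numpy as np

rng = np.random.default_rng(0)

def init_tower(input_dim, dim_hidden=300, num_hidden_layers=2, dim_fan_out=128):
    """Random (untrained) weights: num_hidden_layers tanh layers plus a tanh output layer."""
    dims = [input_dim] + [dim_hidden] * num_hidden_layers + [dim_fan_out]
    return [(rng.normal(0.0, 1.0 / np.sqrt(d_in), size=(d_in, d_out)), np.zeros(d_out))
            for d_in, d_out in zip(dims[:-1], dims[1:])]

def semantic_vector(x, layers):
    """Map word-hashing vectors x of shape (batch, input_dim) to (batch, dim_fan_out)."""
    h = x
    for w, b in layers:
        h = np.tanh(h @ w + b)
    return h

def cosine(q, d):
    """Cosine similarity between one query vector q (1, k) and each row of d (n, k)."""
    return (d @ q.T).ravel() / (np.linalg.norm(q) * np.linalg.norm(d, axis=1) + 1e-12)

def posterior(r, gamma=10.0):
    """Smoothed softmax over candidates: exp(gamma * R) / sum(exp(gamma * R))."""
    z = gamma * r
    z = z - z.max()  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()

input_dim = 1000  # hypothetical letter-trigram vocabulary size
tower = init_tower(input_dim)  # one tower for both sides, for simplicity only

query = rng.integers(0, 2, size=(1, input_dim)).astype(float)  # toy bag-of-trigrams
docs = rng.integers(0, 2, size=(4, input_dim)).astype(float)   # 4 candidates; row 0 = clicked

q_vec = semantic_vector(query, tower)
d_vec = semantic_vector(docs, tower)
p = posterior(cosine(q_vec, d_vec))
loss = -np.log(p[0])  # negative log-likelihood of the clicked document
print(p, loss)
```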



*This notebook takes approximately 30 minutes to walk through.*

-------

**TL;DR**

The following code block illustrates the main workflow of how to train a DSSM model. 

```python
from matchzoo import preprocessor
from matchzoo import generators
from matchzoo import models

train, test = ... # prepare your training data and test data.

dssm_preprocessor = preprocessor.DSSMPreprocessor()
processed_tr = dssm_preprocessor.fit_transform(train, stage='train')
# No parameters are fitted at test time, so `transform` is enough here.
processed_te = dssm_preprocessor.transform(test, stage='test')
# DSSM expects the letter-trigram dimensionality as its input shape.
# The fitted parameters have been stored in `context` while preprocessing the training data.
input_shapes = processed_tr.context['input_shapes']

generator_tr = generators.PointGenerator(processed_tr)
generator_te = generators.PointGenerator(processed_te)
# For example: train with the generator, then test on the first batch.
X_te, y_te = generator_te[0]

dssm_model = models.DSSMModel()
dssm_model.params['input_shapes'] = input_shapes
dssm_model.guess_and_fill_missing_params()
dssm_model.build()
dssm_model.compile()
dssm_model.fit_generator(generator_tr)
# Make predictions
predictions = dssm_model.predict([X_te.text_left, X_te.text_right])
```

-----

MatchZoo expects a list of *quintuples* as training data:

```python
train = [('qid0', 'did0', 'query 0', 'document 0', 'label 0'),
         ('qid0', 'did1', 'query 0', 'document 1', 'label 1'),
         ...,
         ('qid1', 'did2', 'query 1', 'document 2', 'label 3')]
```

The corresponding columns are `(text_left_id, text_right_id, text_left, text_right, label)`. For information retrieval tasks, *text_left* is the *query* and *text_right* is the *document*.

For testing, MatchZoo expects a list of *quadruples* (labels are not needed) as input:

```python
test = [('qid9', 'did5', 'query 9', 'document 5'),
        ...,
        ('qid2', 'did7', 'query 2', 'document 7')]
```
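MatchZoo only needs plain tuples in this column order. One optional way to make the fields self-documenting is `collections.namedtuple`. Namedtuples are still tuples, so they should be accepted wherever a plain tuple is, though that has not been verified against MatchZoo. The type names `Quintuple` and `Quadruple` are hypothetical. The sketch also shows that a test-style quadruple is simply a quintuple with its label dropped.

```python
from collections import namedtuple

# Field names follow the column description above; the type names are illustrative only.
Quintuple = namedtuple('Quintuple', ['text_left_id', 'text_right_id', 'text_left', 'text_right', 'label'])
Quadruple = namedtuple('Quadruple', ['text_left_id', 'text_right_id', 'text_left', 'text_right'])

train = [Quintuple('qid0', 'did0', 'query 0', 'document 0', '0'),
         Quintuple('qid0', 'did1', 'query 0', 'document 1', '1')]

# Dropping the label turns labeled records into test-style input.
test = [Quadruple(*record[:4]) for record in train]

# Namedtuples are plain tuples underneath, so tuple-based code can consume them unchanged.
assert all(isinstance(record, tuple) and len(record) == 4 for record in test)
```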

### Table of Contents

+ Prepare **WikiQA** dataset
    - Download
    - Load & Adjustment
+ Preprocessing
+ Data Generation
+ Model Training
    - Initialize
    - Hyper-Parameters
    - Make Prediction
    - Model Persistence
+ Reference

### Prepare WikiQA dataset

#### Download

We use WikiQA as the example benchmark dataset to show how MatchZoo is used. First, download and uncompress the data. The following script downloads the dataset into the `MatchZoo/data/WikiQA` folder; you can change the target directory in the script.

If you already have the WikiQA dataset on your machine, skip the following script.

```python
import os

# Shell commands run in sequence: create the folder, download, and unzip WikiQA.
cmd = 'mkdir -p ../../data/WikiQA/\n' \
      + 'cd ../../data/WikiQA/\n' \
      + 'wget https://download.microsoft.com/download/E/5/F/E5FCFCEE-7005-4814-853D-DAA7C66507E0/WikiQACorpus.zip\n' \
      + 'unzip WikiQACorpus.zip\n'
print('download WikiQA data... ', cmd)
os.system(cmd)
```

```
download WikiQA data...  mkdir -p ../../data/WikiQA/
cd ../../data/WikiQA/
wget https://download.microsoft.com/download/E/5/F/E5FCFCEE-7005-4814-853D-DAA7C66507E0/WikiQACorpus.zip
unzip WikiQACorpus.zip

0
```
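The script above relies on `mkdir`, `wget`, and `unzip` being available in a Unix shell. A portable alternative, sketched below with only the Python standard library, downloads and extracts the same file. The helper name `download_wikiqa` is hypothetical. The sketch assumes the archive extracts into a `WikiQACorpus/` folder, as the paths later in this tutorial suggest.

```python
import pathlib
import urllib.request
import zipfile

WIKIQA_URL = ('https://download.microsoft.com/download/E/5/F/'
              'E5FCFCEE-7005-4814-853D-DAA7C66507E0/WikiQACorpus.zip')

def download_wikiqa(dest_dir='../../data/WikiQA/', url=WIKIQA_URL):
    """Download the WikiQA zip into dest_dir (unless already present) and extract it."""
    dest = pathlib.Path(dest_dir)
    dest.mkdir(parents=True, exist_ok=True)
    zip_path = dest / 'WikiQACorpus.zip'
    if not zip_path.exists():
        # urlretrieve copies the remote file to zip_path.
        urllib.request.urlretrieve(url, str(zip_path))
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall(dest)
    # Assumption: the archive's top-level folder is WikiQACorpus/.
    return dest / 'WikiQACorpus'
```

Calling `download_wikiqa()` with no arguments targets the same `../../data/WikiQA/` folder as the shell script.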

#### Load & Adjustment

The *train/dev/test* files of WikiQA are *WikiQA-train.tsv*, *WikiQA-dev.tsv*, and *WikiQA-test.tsv* in the uncompressed `WikiQACorpus` folder. Each file is tab-separated, starts with a header row, and has the following columns:

`QuestionID\tQuestion\tDocumentID\tDocumentTitle\tSentenceID\tSentence\tLabel`

The following function converts this format into the input format MatchZoo expects.

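```python
data_folder = '../../data/WikiQA/WikiQACorpus/'

def read_data(path, stage):
    """Convert a WikiQA .tsv file into MatchZoo tuples.

    Columns: QuestionID, Question, DocumentID, DocumentTitle, SentenceID, Sentence, Label.
    The SentenceID is used as the document (text_right) id.
    """
    output_list = []
    with open(path, encoding='utf-8') as fin:
        next(fin)  # Skip the header row.
        for line in fin:
            # Strip the line ending so that the label is e.g. '1', not '1\n'.
            tok = line.rstrip('\r\n').split('\t')
            if stage == 'test':
                output_list.append((tok[0], tok[4], tok[1], tok[5]))  # qid, did, q, d
            else:
                output_list.append((tok[0], tok[4], tok[1], tok[5], tok[6]))  # qid, did, q, d, label
    return output_list

train = read_data(data_folder + 'WikiQA-train.tsv', stage='train')
dev   = read_data(data_folder + 'WikiQA-dev.tsv', stage='dev')
test  = read_data(data_folder + 'WikiQA-test.tsv', stage='test')
```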
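Since `processed_tr` already relies on pandas, the same conversion can also be written with `pandas.read_csv`, as sketched below. The helper name `read_data_pandas` is hypothetical.

- `quoting=csv.QUOTE_NONE` stops quote characters inside sentences from being treated as CSV quoting.
- `dtype=str` together with `keep_default_na=False` keeps every field, including the label, as the raw string `read_data` would produce.

```python
import csv

import pandas as pd

def read_data_pandas(path, stage):
    """Load a WikiQA .tsv with pandas and return MatchZoo-style tuples."""
    df = pd.read_csv(path, sep='\t', quoting=csv.QUOTE_NONE,
                     dtype=str, keep_default_na=False)
    # (qid, did, query, document[, label]) in the order MatchZoo expects.
    columns = ['QuestionID', 'SentenceID', 'Question', 'Sentence']
    if stage != 'test':
        columns.append('Label')
    # name=None makes itertuples yield plain tuples instead of namedtuples.
    return list(df[columns].itertuples(index=False, name=None))
```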

### Preprocessing

You can pre-process your DSSM input in three lines of code:

```python
# Initialize a dssm preprocessor.
from matchzoo import preprocessor
dssm_preprocessor = preprocessor.DSSMPreprocessor()
processed_tr = dssm_preprocessor.fit_transform(train, stage='train')
processed_te = dssm_preprocessor.fit_transform(test, stage='test')
```

```
Using TensorFlow backend.
Start building vocabulary & fitting parameters.
100%|██████████| 20959/20959 [00:04<00:00, 4794.48it/s]
Start processing input data for train stage.
100%|██████████| 20959/20959 [00:01<00:00, 16506.70it/s]
Start processing input data for test stage.
100%|██████████| 6594/6594 [00:01<00:00, 3605.97it/s]
```


**What is `processed_tr`?**

`processed_tr` is a **MatchZoo DataPack** (see `matchzoo/datapack.py`). It contains:

1. A two-column `pandas DataFrame` holding all preprocessed records, including their index and processed text.
2. A `mapping` variable (a Python `dict`) storing the relationship between id pairs.
3. A `context` property (a `dict`) holding all parameters fitted during preprocessing.

The `fit_transform` method chains two steps:

1. Fit parameters with `fit`; this only happens when `stage='train'`.
2. Transform the data into the expected format.

So the preprocessing code above can also be written as:

```python
# Initialize a dssm preprocessor.
from matchzoo import preprocessor
dssm_preprocessor = preprocessor.DSSMPreprocessor()
processed_tr = dssm_preprocessor.fit_transform(train, stage='train')
# We do not need to fit any parameters during the testing stage.
# So we can call transform directly.
processed_te = dssm_preprocessor.transform(test, stage='test')
```

The fitted parameters are stored in the `context` property. To inspect them, call:

```python
print(processed_tr.context)
```
For example, to check the size of the letter-trigram vocabulary:

```python
print('vocab size: ', len(processed_tr.context['term_index']))
```

```
vocab size:  9643
```


**What is stored in the `context`?**

Among other things, the context stores `term_index` (the vocabulary shown above) and `input_shapes`. DSSM's input shape is dynamic, because it depends on the letter-trigrams generated from your training data. You therefore **must** set the model's input shape manually; this is discussed in the model training section.

**What is `dssm_preprocessor` actually doing?**

The `dssm_preprocessor` calls a sequence of `process_units`. Each `process_unit` performs one atomic operation on the input data. For instance, `dssm_preprocessor` applies the following units in order:

1. TokenizeUnit: Perform tokenization on raw input data.
2. LowercaseUnit: Transform all tokens into lower case.
3. PuncRemovalUnit: Remove all punctuation.
4. StopRemovalUnit: Remove all stopwords.
5. NgramLetterUnit: Create letter n-grams (tri-letters by default) as input data. For example, the token `test` is transformed into `['#te', 'tes', 'est', 'st#']`.
6. VocabularyUnit: Build the tri-letter vocabulary, whose size determines the input dimensionality.
7. WordHashingUnit: Apply word hashing as described in the paper, mapping each text to a bag-of-tri-letters vector.
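The unit pipeline above can be approximated in pure Python, as sketched below. It also shows the fit/transform split: the vocabulary is fitted on training texts only, and hashing then reuses it. This is not MatchZoo's code, and several details are assumptions:

- Whitespace tokenization and the tiny `STOPWORDS` set are hypothetical stand-ins for MatchZoo's tokenizer and stopword list.
- Reserving index 0 is also an assumption. It is consistent with the printed vocabulary size of 9643 alongside an input shape of 9644.
- Tri-letters unseen during fitting are simply dropped here.

```python
import re
from collections import Counter

STOPWORDS = {'a', 'an', 'the', 'is', 'of'}  # hypothetical, tiny stopword list

def tokenize(text):             # 1. TokenizeUnit (plain whitespace split here)
    return text.split()

def lowercase(tokens):          # 2. LowercaseUnit
    return [t.lower() for t in tokens]

def remove_punc(tokens):        # 3. PuncRemovalUnit: strip punctuation, drop empty tokens
    tokens = [re.sub(r'[^\w\s]', '', t) for t in tokens]
    return [t for t in tokens if t]

def remove_stopwords(tokens):   # 4. StopRemovalUnit
    return [t for t in tokens if t not in STOPWORDS]

def letter_ngrams(token, n=3):  # 5. NgramLetterUnit: pad with '#', slide a window of size n
    padded = '#' + token + '#'
    return [padded[i:i + n] for i in range(len(padded) - n + 1)]

def to_trigrams(text):
    tokens = remove_stopwords(remove_punc(lowercase(tokenize(text))))
    return [g for t in tokens for g in letter_ngrams(t)]

def fit_vocab(train_texts):     # 6. VocabularyUnit: fitted on training texts only
    grams = sorted({g for text in train_texts for g in to_trigrams(text)})
    return {g: i + 1 for i, g in enumerate(grams)}  # index 0 reserved (assumption)

def word_hash(text, term_index):  # 7. WordHashingUnit: bag-of-tri-letters count vector
    vec = [0] * (len(term_index) + 1)
    for gram, count in Counter(to_trigrams(text)).items():
        if gram in term_index:    # tri-letters unseen during fitting are dropped here
            vec[term_index[gram]] += count
    return vec

term_index = fit_vocab(['This is a test.', 'Testing the tests!'])
print(letter_ngrams('test'))  # ['#te', 'tes', 'est', 'st#']
print(len(word_hash('The Test!', term_index)) == len(term_index) + 1)  # True
```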

----

### Data Generation

For memory efficiency, we recommend using a **generator** that produces batches of data on the fly. For example, you can create a **PointGenerator** as follows:

```python
from matchzoo import generators
generator_tr = generators.PointGenerator(processed_tr, batch_size=64, stage='train')
generator_te = generators.PointGenerator(processed_te, batch_size=64, stage='test')
```

To get the first batch of training data, call `X_train, y_train = generator_tr[0]`.
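Indexing a generator with `[0]` works because batch generators in the Keras style define `__len__` (the number of batches) and `__getitem__` (build batch *i* on request). MatchZoo's generators presumably follow the `keras.utils.Sequence` protocol, but that is an assumption here. The hypothetical `ToyPointGenerator` below is only a sketch of that protocol. It is not MatchZoo's class.

```python
import math

class ToyPointGenerator:
    """Hypothetical minimal batch generator: batches are sliced only when requested."""

    def __init__(self, pairs, labels, batch_size=64):
        self.pairs = pairs
        self.labels = labels
        self.batch_size = batch_size

    def __len__(self):
        # Number of batches per epoch; the last batch may be smaller.
        return math.ceil(len(self.pairs) / self.batch_size)

    def __getitem__(self, idx):
        if not 0 <= idx < len(self):
            raise IndexError(idx)  # IndexError also ends a plain `for` loop over the object
        start = idx * self.batch_size
        end = start + self.batch_size
        return self.pairs[start:end], self.labels[start:end]

gen = ToyPointGenerator(list(range(6594)), [0] * 6594, batch_size=64)
print(len(gen))  # 104
```

With the 6594 test pairs shown in the preprocessing log and `batch_size=64`, there are 103 full batches plus a final batch of 2.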

**What is PointGenerator?**

**PointGenerator** implements the pointwise approach. Each query-document pair in the training data is assumed to have a numerical or ordinal score, so the problem can be approximated as regression or classification, that is, predicting the score of a single query-document pair.

Many existing supervised machine learning algorithms can be used directly for this purpose. Ordinal regression and classification algorithms also fit the pointwise approach when the score of a single query-document pair takes a small, finite number of values.
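In the pointwise view, every (query, document) pair is scored independently. The hyper-parameter output later in this tutorial lists `Classification` as the default task and `categorical_crossentropy` as the default loss. Under those settings, each pair's label becomes a one-hot target over two classes, and the loss is averaged over pairs.

The NumPy sketch below reproduces that computation on made-up predictions. It mirrors Keras's loss but is not Keras's implementation. Mapping label 1 to column 1 follows standard one-hot encoding and is an assumption about MatchZoo's label handling.

```python
import numpy as np

def one_hot(labels, num_classes=2):
    """Turn integer labels (0 = not relevant, 1 = relevant) into one-hot rows."""
    return np.eye(num_classes)[np.asarray(labels, dtype=int)]

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean cross-entropy over independent (query, document) pairs."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)  # avoid log(0)
    return float(-np.mean(np.sum(y_true * np.log(y_pred), axis=1)))

labels = ['0', '1', '0']  # WikiQA-style string labels for three pairs
y_true = one_hot([int(label) for label in labels])
y_pred = np.array([[0.9, 0.1], [0.4, 0.6], [0.7, 0.3]])  # made-up model outputs
loss = categorical_crossentropy(y_true, y_pred)
print(loss)
```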

**What is PairGenerator?**

TO BE ADDED

**What is ListGenerator?**

TO BE ADDED

----

### Train Your DSSM Model

To train a DSSM model, we need to create an instance of DSSMModel:

```python
from matchzoo import models
dssm_model = models.DSSMModel()
```

Next, set the hyper-parameters of the DSSM model. In general, there are **two types of hyper-parameters**:

**Required parameters**: For DSSM, `input_shapes` depends on the letter-trigram vocabulary fitted on the training data, so you must set this parameter manually.

```python
# The fitted parameters are stored in the `context` of the DataPack produced for the training stage.
input_shapes = processed_tr.context['input_shapes']
dssm_model.params['input_shapes'] = input_shapes
```

**Tunable parameters**: For DSSM, you can tune the following parameters:

```python
from matchzoo import tasks

params = {'w_initializer': 'glorot_normal',  # See Keras weight initializers.
          'b_initializer': 'zeros',  # See Keras bias initializers.
          'dim_fan_out': 128,  # Dimension of the output layer.
          'dim_hidden': 300,  # Dimension of each hidden layer.
          'activation_hidden': 'tanh',  # Activation of the hidden layers, see Keras activations.
          'num_hidden_layers': 2,  # Number of hidden layers.
          'optimizer': 'sgd',  # Any Keras optimizer; the default is 'adam'.
          'task': tasks.Classification(),  # Default task; tasks.Ranking() can be used instead.
          'loss': 'categorical_crossentropy',  # Default loss, see Keras losses.
          'metrics': ['acc'],  # Accuracy by default, see Keras metrics.
         }
```

As with the required parameters, use `dssm_model.params['parameter-name'] = parameter-value` to set each hyper-parameter. To fill every parameter you have not set with its default value, call:

```python
dssm_model.guess_and_fill_missing_params()
print('dssm parameters: ', dssm_model.params)
```

```
dssm parameters:  name                          DSSMModel
model_class                   <class 'matchzoo.models.dssm_model.DSSMModel'>
input_shapes                  [(9644,), (9644,)]
task                          <matchzoo.tasks.classification.Classification object at 0x7f40c5936940>
metrics                       ['acc']
loss                          categorical_crossentropy
optimizer                     adam
w_initializer                 glorot_normal
b_initializer                 zeros
dim_fan_out                   128
dim_hidden                    300
activation_hidden             tanh
num_hidden_layers             2
```


#### Model Training

Once all parameters are set, build, compile, and train the model:

```python
dssm_model.build()
dssm_model.compile()
# Fit the dssm model on generator.
dssm_model.fit_generator(generator_tr, steps_per_epoch=200, epochs=10)
# Make predictions on the first batch of test data
X_te, y_te = generator_te[0]
predictions = dssm_model.predict([X_te.id_left, X_te.id_right])
```

```
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
```


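```python
import numpy as np

# Make predictions on all the test data, one batch at a time.
pred_all = []
for X_te, y_te in generator_te:
    pred = dssm_model.predict([X_te.id_left, X_te.id_right])
    pred_all.append(pred)
    print(pred)
# Stack the per-batch predictions into a single (num_samples, 2) array.
pred_all = np.concatenate(pred_all)
```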

[[0.6393617  0.36063838]
 [0.6379478  0.36205217]
 [0.95432234 0.0456777 ]
 [0.93329775 0.06670222]
 [0.95431864 0.04568134]
 [0.95256877 0.04743122]
 [0.95432454 0.04567546]
 [0.63619405 0.36380595]
 [0.9543243  0.04567569]
 [0.9543251  0.04567495]
 [0.87831247 0.12168746]
 [0.6361466  0.3638534 ]
 [0.6361461  0.36385384]
 [0.9543137  0.04568638]
 [0.9543202  0.04567977]
 [0.9543203  0.04567974]
 [0.6361474  0.36385265]
 [0.63614714 0.3638529 ]
 [0.87204325 0.12795667]
 [0.95432496 0.04567499]
 [0.6361474  0.3638526 ]
 [0.9543251  0.04567498]
 [0.95432353 0.04567651]
 [0.8926593  0.10734072]
 [0.9543247  0.04567521]
 [0.6706216  0.32937843]
 [0.6361462  0.36385378]
 [0.89709604 0.10290398]
 [0.9543251  0.04567492]
 [0.9543251  0.04567498]
 [0.6833408  0.31665927]
 [0.9543252  0.04567486]
 [0.95432496 0.04567499]
 [0.95353025 0.0464698 ]
 [0.9543229  0.04567706]
 [0.94253    0.05747008]
 [0.63615257 0.36384743]
 [0.83217335 0.16782664]
 [0.9513066  0.04869347]
 [0.86336946 0.1366305 ]


[[0.6749617  0.3250383 ]
 [0.8337881  0.16621189]
 [0.94529366 0.05470637]
 [0.6361464  0.36385354]
 [0.9543252  0.04567484]
 [0.9489513  0.05104876]
 [0.9543104  0.04568965]
 [0.9543168  0.04568315]
 [0.9543247  0.04567528]
 [0.63728416 0.3627158 ]
 [0.9457799  0.05422014]
 [0.73076344 0.26923656]
 [0.9543252  0.04567484]
 [0.9543243  0.04567567]
 [0.63619286 0.36380714]
 [0.63620406 0.36379597]
 [0.8599128  0.14008717]
 [0.6398694  0.3601306 ]
 [0.8093788  0.19062123]
 [0.95432496 0.04567505]
 [0.95432085 0.04567914]
 [0.95312494 0.04687501]
 [0.95432454 0.04567546]
 [0.95432156 0.04567844]
 [0.8717249  0.12827507]
 [0.9543252  0.04567484]
 [0.90782183 0.09217812]
 [0.9543252  0.04567484]
 [0.6446817  0.35531828]
 [0.95006526 0.04993474]
 [0.92865473 0.07134528]
 [0.9543251  0.04567497]
 [0.95357406 0.04642601]
 [0.95072985 0.04927013]
 [0.95432496 0.04567504]
 [0.8023661  0.19763397]
 [0.6674428  0.33255717]
 [0.95432484 0.04567511]
 [0.9543247  0.04567521]
 [0.6372796  0.36272037]


[[0.87914205 0.12085797]
 [0.63636965 0.36363035]
 [0.63614655 0.36385348]
 [0.86279625 0.13720377]
 [0.63614637 0.36385354]
 [0.63614625 0.36385375]
 [0.6395082  0.36049184]
 [0.95242953 0.0475704 ]
 [0.6361463  0.36385372]
 [0.95223063 0.04776937]
 [0.7149645  0.28503543]
 [0.75068724 0.24931273]
 [0.9543251  0.04567488]
 [0.9543252  0.04567484]
 [0.9543246  0.04567538]
 [0.8804136  0.11958638]
 [0.7518477  0.24815236]
 [0.8805476  0.11945243]
 [0.95432496 0.04567498]
 [0.9543251  0.0456749 ]
 [0.6361921  0.36380792]
 [0.6828046  0.31719545]
 [0.9543149  0.04568513]
 [0.9538043  0.04619564]
 [0.6545838  0.34541616]
 [0.952228   0.04777202]
 [0.95432484 0.04567517]
 [0.9542952  0.04570473]
 [0.63615143 0.3638486 ]
 [0.9543251  0.0456749 ]
 [0.9482998  0.05170016]
 [0.755158   0.24484192]
 [0.6361509  0.363849  ]
 [0.95353234 0.04646766]
 [0.76977587 0.23022415]
 [0.6361536  0.36384642]
 [0.9543251  0.04567492]
 [0.9543198  0.04568027]
 [0.9543173  0.0456828 ]
 [0.9543168  0.04568313]


[[0.95432496 0.04567508]
 [0.63823205 0.36176792]
 [0.9075125  0.09248748]
 [0.95339465 0.04660528]
 [0.6370823  0.36291763]
 [0.95415086 0.0458491 ]
 [0.94673395 0.0532661 ]
 [0.9542775  0.0457225 ]
 [0.64971566 0.35028437]
 [0.95431334 0.0456867 ]
 [0.9543247  0.04567521]
 [0.8480158  0.15198416]
 [0.9543246  0.0456754 ]
 [0.9408957  0.05910435]
 [0.7182926  0.2817074 ]
 [0.8399555  0.16004445]
 [0.9543242  0.04567584]
 [0.77402556 0.22597443]
 [0.9543251  0.04567495]
 [0.73447853 0.26552144]
 [0.63614666 0.3638533 ]
 [0.95432496 0.04567499]
 [0.887777   0.11222301]
 [0.79683024 0.20316981]
 [0.9543251  0.04567492]
 [0.9501625  0.04983752]
 [0.8399569  0.1600431 ]
 [0.9543242  0.0456758 ]
 [0.637082   0.36291805]
 [0.9543226  0.04567748]
 [0.91851676 0.08148316]
 [0.8802147  0.11978532]
 [0.95353556 0.04646443]
 [0.9543246  0.0456754 ]
 [0.63614666 0.3638533 ]
 [0.9411536  0.05884645]
 [0.95431966 0.04568039]
 [0.63614887 0.36385116]
 [0.88119274 0.11880725]
 [0.95432484 0.04567517]


[[0.9543252  0.04567486]
 [0.9543251  0.04567498]
 [0.95426124 0.04573873]
 [0.9543251  0.04567498]
 [0.83574283 0.16425714]
 [0.95432496 0.04567505]
 [0.954108   0.04589206]
 [0.63614637 0.36385363]
 [0.95431715 0.04568289]
 [0.6361464  0.36385354]
 [0.78312796 0.21687213]
 [0.95280826 0.04719176]
 [0.954291   0.04570907]
 [0.6374138  0.36258623]
 [0.9543251  0.04567488]
 [0.6782951  0.32170498]
 [0.95431083 0.04568918]
 [0.9543178  0.04568226]
 [0.9542504  0.04574966]
 [0.9433416  0.05665832]
 [0.86663294 0.13336709]
 [0.95432484 0.04567513]
 [0.9543021  0.04569794]
 [0.6363761  0.36362392]
 [0.86809856 0.13190141]
 [0.6361522  0.36384788]
 [0.95099574 0.04900425]
 [0.6361462  0.36385378]
 [0.6361512  0.36384878]
 [0.95236427 0.04763572]
 [0.63878775 0.36121225]
 [0.9542963  0.0457037 ]
 [0.952233   0.04776696]
 [0.6361484  0.36385158]
 [0.94760513 0.05239487]
 [0.94242716 0.05757284]
 [0.86022913 0.13977088]
 [0.9543251  0.04567492]
 [0.9533174  0.0466826 ]
 [0.92785335 0.07214667]


[[0.95432496 0.04567502]
 [0.9542985  0.04570158]
 [0.63618994 0.36381006]
 [0.63614625 0.36385375]
 [0.95432454 0.0456755 ]
 [0.94118184 0.05881814]
 [0.67216766 0.32783237]
 [0.9543247  0.04567526]
 [0.6363885  0.36361155]
 [0.8637506  0.13624942]
 [0.9543247  0.04567523]
 [0.9543112  0.04568885]
 [0.9543228  0.04567726]
 [0.90573317 0.09426685]
 [0.95432484 0.04567515]
 [0.63615036 0.3638496 ]
 [0.9543252  0.04567486]
 [0.6497568  0.3502432 ]
 [0.7518817  0.24811833]
 [0.9543251  0.04567495]
 [0.8649613  0.13503866]
 [0.9543112  0.04568881]
 [0.95432407 0.04567586]
 [0.89269996 0.10730006]
 [0.9119095  0.08809049]
 [0.6833404  0.31665963]
 [0.9543251  0.04567493]
 [0.95422876 0.04577119]
 [0.9543252  0.04567486]
 [0.9543244  0.04567556]
 [0.92125183 0.07874819]
 [0.9543251  0.04567493]
 [0.95432335 0.04567672]
 [0.6497316  0.35026842]
 [0.64866024 0.35133976]
 [0.6363162  0.3636839 ]
 [0.63614815 0.36385182]
 [0.6361461  0.36385384]
 [0.9516626  0.04833743]
 [0.6370859  0.36291406]


[[0.954304   0.04569598]
 [0.9527127  0.04728732]
 [0.95432484 0.0456751 ]
 [0.6439808  0.3560192 ]
 [0.6671408  0.33285925]
 [0.63614756 0.3638525 ]
 [0.7408839  0.2591161 ]
 [0.85498637 0.14501369]
 [0.95432496 0.04567499]
 [0.9543244  0.04567553]
 [0.83864206 0.16135795]
 [0.9543214  0.0456786 ]
 [0.6870316  0.31296837]
 [0.95432496 0.04567501]
 [0.95432335 0.0456767 ]
 [0.95422083 0.04577914]
 [0.75868803 0.24131194]
 [0.654462   0.34553802]
 [0.9543243  0.04567564]
 [0.637525   0.36247495]
 [0.9543246  0.04567538]
 [0.95432377 0.04567621]
 [0.8791532  0.12084679]
 [0.9360706  0.06392945]
 [0.93210906 0.0678909 ]
 [0.9543251  0.04567495]
 [0.9543243  0.04567564]
 [0.9260145  0.07398549]
 [0.9488265  0.0511735 ]
 [0.6526623  0.34733772]
 [0.95432323 0.04567675]
 [0.9534071  0.04659284]
 [0.95432484 0.04567511]
 [0.95432496 0.04567502]
 [0.9543251  0.04567493]
 [0.9543246  0.04567531]
 [0.94664806 0.05335196]
 [0.63644737 0.36355263]
 [0.95432484 0.0456752 ]
 [0.9543251  0.0456749 ]


[[0.94122887 0.05877117]
 [0.63889897 0.36110103]
 [0.922037   0.07796302]
 [0.95234966 0.04765033]
 [0.95427394 0.04572607]
 [0.69607645 0.30392355]
 [0.6382268  0.36177325]
 [0.9543247  0.04567528]
 [0.63615036 0.36384967]
 [0.9543246  0.04567531]
 [0.9543228  0.04567718]
 [0.95432407 0.0456759 ]
 [0.6715853  0.32841462]
 [0.6394956  0.36050442]
 [0.95432496 0.04567505]
 [0.9524065  0.04759345]
 [0.638883   0.36111695]
 [0.81577486 0.1842251 ]
 [0.63623667 0.36376333]
 [0.9543251  0.04567492]
 [0.95431423 0.04568581]
 [0.9543251  0.0456749 ]
 [0.6361462  0.36385378]
 [0.95432484 0.04567511]
 [0.9543247  0.04567523]
 [0.902228   0.09777205]
 [0.8018721  0.19812796]
 [0.9543246  0.0456754 ]
 [0.6361497  0.36385027]
 [0.954324   0.045676  ]
 [0.6361465  0.3638535 ]
 [0.87176895 0.12823108]
 [0.7755364  0.22446364]
 [0.9022407  0.09775934]
 [0.86954135 0.1304586 ]
 [0.7414897  0.2585103 ]
 [0.63615364 0.3638464 ]
 [0.8357502  0.16424978]
 [0.6363486  0.36365145]
 [0.9543218  0.04567818]


#### Model Persistence

You can persist a trained model with `model.save()` and restore it with `engine.load_model()`:

```python
from matchzoo import engine
# Save the model to a directory.
dssm_model.save('/your-model-saved-path')
# Load the model back from that directory.
dssm_model = engine.load_model('/your-model-saved-path')
```

### Reference

[Huang et al. 2013] Po-Sen Huang, Xiaodong He, Jianfeng Gao, Li Deng, Alex Acero, and Larry Heck. 2013. Learning deep structured semantic models for web search using clickthrough data. In Proc. CIKM. ACM, 2333–2338.