### Build an Arc-I Model

<img src="https://github.com/faneshion/MatchZoo/blob/master/docs/_static/images/matchzoo-logo.png?raw=true" alt="logo" style="width:600px;float: center"/>

This tutorial shows how to train the *Arc-I* model [Hu et al. 2014](http://papers.nips.cc/paper/5550-convolutional-neural-network-architectures-for-matching-natural-language-sentences.pdf) with [MatchZoo](https://github.com/faneshion/MatchZoo). We use [WikiQA](https://aclweb.org/anthology/D15-1237) as the example benchmark dataset.



-------

**TL;DR**

The following code block illustrates the main workflow for training an Arc-I model.

```python
from matchzoo import preprocessor
from matchzoo import generators
from matchzoo import models

train, test = ... # prepare your training data and test data.

arci_preprocessor = preprocessor.ArcIPreprocessor()
processed_tr = arci_preprocessor.fit_transform(train, stage='train')
processed_te = arci_preprocessor.fit_transform(test, stage='test')

input_shapes = processed_tr.context['input_shapes']

generator_tr = generators.PointGenerator(processed_tr)
generator_te = generators.PointGenerator(processed_te)
# Example, train with generator, test with the first batch.
X_te, y_te = generator_te[0]

arci_model = models.ArcIModel()
arci_model.params['input_shapes'] = input_shapes
arci_model.guess_and_fill_missing_params()
arci_model.build()
arci_model.compile()
arci_model.fit_generator(generator_tr)
# Make predictions
predictions = arci_model.predict([X_te.text_left, X_te.text_right])
```

-----

MatchZoo expects a list of *quintuples* as training data:

```python
train = [('qid0', 'did0', 'query 0', 'document 0', 'label 0'),
         ('qid0', 'did1', 'query 0', 'document 1', 'label 1'),
         ...,
         ('qid1', 'did2', 'query 1', 'document 2', 'label 3')]
```

The corresponding columns are `(text_left_id, text_right_id, text_left, text_right, label)`. For information retrieval tasks, *text_left* is the *query* and *text_right* is the *document*.

For testing, MatchZoo expects a list of *quadruples* (labels are not needed) as input:

```python
test = [('qid9', 'did5', 'query 9', 'document 5'),
        ...,
        ('qid2', 'did7', 'query 2', 'document 7')]
```
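The two record layouts differ only in the trailing label. As a minimal sketch, the snippet below names the columns with `namedtuple` and derives unlabelled quadruples from labelled quintuples. `TrainRecord`, `TestRecord`, and `drop_labels` are hypothetical helpers introduced for illustration and are not part of MatchZoo.

```python
from collections import namedtuple

# Hypothetical record types mirroring the column order described above.
TrainRecord = namedtuple(
    'TrainRecord',
    ['text_left_id', 'text_right_id', 'text_left', 'text_right', 'label'])
TestRecord = namedtuple(
    'TestRecord',
    ['text_left_id', 'text_right_id', 'text_left', 'text_right'])


def drop_labels(records):
    """Turn labelled quintuples into unlabelled quadruples (plain tuples)."""
    return [tuple(TestRecord(*rec[:4])) for rec in records]


sample_train = [('qid0', 'did0', 'query 0', 'document 0', '1'),
                ('qid0', 'did1', 'query 0', 'document 1', '0')]

# Named access makes the column meaning explicit.
first = TrainRecord(*sample_train[0])
print(first.text_left, '->', first.label)

sample_test = drop_labels(sample_train)
print(sample_test[0])
```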

### Table of Contents

+ Prepare WikiQA dataset
    - Download
    - Load & Adjustment
+ Preprocessing
+ Data Generation
+ Train Your Arc-I Model
    - Model Training
    - Model Persistence
+ Reference

### Prepare WikiQA dataset

#### Download

We take WikiQA as the example benchmark dataset to show the usage of MatchZoo. First, download and uncompress the data. The following script downloads the dataset into the `MatchZoo/data/WikiQA` folder; change the directory in the script if needed. Note that the script first deletes any existing `../../data/WikiQA` directory.

If you already have the WikiQA dataset on your machine, skip the following script.

In [42]:
import os

cmd = 'rm -rf ../../data/WikiQA\n' \
      +'mkdir -p ../../data/WikiQA/\n' \
      +'cd ../../data/WikiQA/\n' \
      +'wget https://download.microsoft.com/download/E/5/F/E5FCFCEE-7005-4814-853D-DAA7C66507E0/WikiQACorpus.zip\n' \
      +'unzip WikiQACorpus.zip\n'
print ('download WikiQA data... ', cmd)
os.system(cmd)

download WikiQA data...  rm -rf ../../data/WikiQA
mkdir -p ../../data/WikiQA/
cd ../../data/WikiQA/
wget https://download.microsoft.com/download/E/5/F/E5FCFCEE-7005-4814-853D-DAA7C66507E0/WikiQACorpus.zip
unzip WikiQACorpus.zip



0
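
If `wget` and `unzip` are not available, the same download can be sketched with the Python standard library. This is only an illustrative alternative: `fetch_and_extract` is a hypothetical helper, and unlike the shell script it does not delete an existing folder and skips the download when the archive is already present.

```python
import os
import urllib.request
import zipfile

WIKIQA_URL = ('https://download.microsoft.com/download/E/5/F/'
              'E5FCFCEE-7005-4814-853D-DAA7C66507E0/WikiQACorpus.zip')


def fetch_and_extract(url, target_dir, archive_name='WikiQACorpus.zip'):
    """Download `url` into `target_dir` (if absent) and unzip it there."""
    os.makedirs(target_dir, exist_ok=True)
    archive_path = os.path.join(target_dir, archive_name)
    if not os.path.exists(archive_path):
        urllib.request.urlretrieve(url, archive_path)
    with zipfile.ZipFile(archive_path) as archive:
        archive.extractall(target_dir)
    return archive_path


# Usage (needs network access), left commented out on purpose:
# fetch_and_extract(WIKIQA_URL, '../../data/WikiQA')
```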

#### Load & Adjustment

The *train/dev/test* files of WikiQA are *WikiQA-train.tsv*, *WikiQA-dev.tsv*, *WikiQA-test.tsv* under the uncompressed folder WikiQACorpus. The data format of WikiQA is as follows:

`QuestionID\tQuestion\tDocumentID\tDocumentTitle\tSentenceID\tSentence\tLabel`

We can convert this format to the expected input format of MatchZoo.

In [43]:
data_folder = '../../data/WikiQA/WikiQACorpus/'

def read_data(path, stage):
    """Read a WikiQA TSV file into MatchZoo-style tuples."""
    output_list = []
    with open(path) as fin:
        next(fin, None)  # skip the header row
        for line in fin:
            # Strip the newline so the label does not carry a trailing '\n'.
            tok = line.rstrip('\n').split('\t')
            if stage == 'test':
                output_list.append((tok[0], tok[4], tok[1], tok[5]))  # qid, did, q, d
            else:
                output_list.append((tok[0], tok[4], tok[1], tok[5], tok[6]))  # qid, did, q, d, label
    return output_list

train = read_data(data_folder + 'WikiQA-train.tsv', stage='train')
dev   = read_data(data_folder + 'WikiQA-dev.tsv', stage='dev')
test  = read_data(data_folder + 'WikiQA-test.tsv', stage='test')
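
The conversion above relies on column positions. As a hedged alternative sketch, the columns can be addressed by the header names listed earlier using `csv.DictReader`. The two sample rows are made up for illustration, and `read_rows` is a hypothetical helper rather than a MatchZoo function.

```python
import csv
import io

# Two made-up rows in the WikiQA column layout described above.
SAMPLE_TSV = (
    'QuestionID\tQuestion\tDocumentID\tDocumentTitle\tSentenceID\tSentence\tLabel\n'
    'Q1\thow are caves formed\tD1\tCave\tD1-0\tA cave is a hollow place.\t0\n'
    'Q1\thow are caves formed\tD1\tCave\tD1-1\tCaves form by erosion.\t1\n'
)


def read_rows(handle, with_label=True):
    """Yield (qid, did, q, d[, label]) tuples from an open WikiQA TSV handle."""
    # QUOTE_NONE keeps quotation marks inside sentences as literal characters.
    reader = csv.DictReader(handle, delimiter='\t', quoting=csv.QUOTE_NONE)
    for row in reader:
        record = (row['QuestionID'], row['SentenceID'],
                  row['Question'], row['Sentence'])
        if with_label:
            record = record + (row['Label'],)
        yield record


rows = list(read_rows(io.StringIO(SAMPLE_TSV)))
print(rows[1])
```

To read a real file, pass an open file handle instead of the `io.StringIO` object.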

### Preprocessing

You can pre-process your Arc-I input in three lines of code:

In [44]:
# Initialize an Arc-I preprocessor.
from matchzoo import preprocessor
arci_preprocessor = preprocessor.ArcIPreprocessor()
processed_tr = arci_preprocessor.fit_transform(train, stage='train')
processed_te = arci_preprocessor.fit_transform(test, stage='test')

Start building vocabulary & fitting parameters.
2118it [00:00, 2496.62it/s]
18841it [00:09, 1888.30it/s]
Start processing input data for train stage.
2118it [00:00, 2446.24it/s]
18841it [00:10, 1758.83it/s]
Start processing input data for test stage.
633it [00:00, 2737.32it/s]
5961it [00:03, 1913.80it/s]


**What is `processed_tr`?**

`processed_tr` is a **MatchZoo DataPack** data structure (see `matchzoo/datapack.py`). It contains:
1. A two-column `pandas DataFrame` that hosts all the pre-processed records, including the index and the processed text.
2. A `mapping` variable (Python `dict`) that stores the relationship between id pairs.
3. A `context` property (a dictionary) that holds all the parameters fitted during pre-processing.

The `fit_transform` method chains two steps:

1. Fit parameters using the `fit` function. This only happens when `stage='train'`.
2. Transform the data into the expected format.

So the previous preprocessing code can also be written as:

```python
processed_tr = arci_preprocessor.fit_transform(train, stage='train')
processed_te = arci_preprocessor.transform(test, stage='test')
```
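The split between fitting and transforming can be illustrated with a toy example. The class below is a hypothetical stand-in that only builds and applies a term index; it is a sketch of the pattern, not MatchZoo's `ArcIPreprocessor`, and its handling of unknown tokens is an assumption.

```python
class ToyPreprocessor:
    """Hypothetical stand-in that only builds and applies a term index."""

    def __init__(self):
        self.term_index = {}

    def fit(self, texts):
        # Assign ids starting at 1; 0 is left unused in this sketch.
        for text in texts:
            for token in text.lower().split():
                self.term_index.setdefault(token, len(self.term_index) + 1)
        return self

    def transform(self, texts):
        # Unknown tokens are dropped here; a real pipeline may handle them differently.
        return [[self.term_index[t] for t in text.lower().split()
                 if t in self.term_index] for text in texts]

    def fit_transform(self, texts, stage):
        # Parameters are fitted only during the training stage.
        if stage == 'train':
            self.fit(texts)
        return self.transform(texts)


toy = ToyPreprocessor()
train_ids = toy.fit_transform(['how are caves formed', 'caves form slowly'], stage='train')
test_ids = toy.fit_transform(['how are glaciers formed'], stage='test')
print(train_ids)  # [[1, 2, 3, 4], [3, 5, 6]]
print(test_ids)   # [[1, 2, 4]]
```

Because the test stage reuses the vocabulary fitted on the training data, the word `glaciers` has no id and is skipped.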

As described, the fitted parameters are stored in the `context` property. To access the context, call:

```python
print(processed_tr.context)
```
An example:

In [45]:
processed_tr.context.keys()

dict_keys(['term_index', 'input_shapes'])

In [46]:
print('vocab size: ', len(processed_tr.context['term_index']))

vocab size:  29924


**What has been stored in the `context`?**

The context stores `term_index` (the vocabulary, mapping each term to an integer index) and `input_shapes` (the shapes of the left and right model inputs).


**What is `arci_preprocessor` actually doing?**

The `arci_preprocessor` calls a sequence of `process_units`. Each `process_unit` performs one atomic operation on the input data. For instance, `arci_preprocessor` calls:


1. TokenizeUnit: Perform tokenization on raw input data.
2. LowercaseUnit: Transform all tokens into lower case.
3. PuncRemovalUnit: Remove all punctuation.
4. StopRemovalUnit: Remove all the stopwords.
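
The idea of chaining small units can be sketched in plain Python. The functions below are hypothetical stand-ins for the four units listed above and do not reproduce MatchZoo's implementations; in particular, the whitespace tokenizer and the tiny stop list are assumptions made for the example.

```python
import string

# A tiny illustrative stop list; MatchZoo's actual list is not reproduced here.
STOPWORDS = {'a', 'an', 'the', 'is', 'are', 'of'}


def tokenize(text):
    """Split raw text on whitespace."""
    return text.split()


def lowercase(tokens):
    return [t.lower() for t in tokens]


def remove_punctuation(tokens):
    """Strip leading/trailing punctuation and drop tokens that become empty."""
    stripped = (t.strip(string.punctuation) for t in tokens)
    return [t for t in stripped if t]


def remove_stopwords(tokens):
    return [t for t in tokens if t not in STOPWORDS]


def run_pipeline(text, units=(tokenize, lowercase, remove_punctuation, remove_stopwords)):
    """Feed the output of each unit into the next one."""
    data = text
    for unit in units:
        data = unit(data)
    return data


print(run_pipeline('How are the Glacier Caves formed ?'))
# ['how', 'glacier', 'caves', 'formed']
```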


----

### Data Generation

For memory efficiency, we expect you to use a **generator** to produce batches of data on the fly. For example, we can create a **PointGenerator** as follows:

In [47]:
from matchzoo import generators
from matchzoo import tasks
generator_tr = generators.PointGenerator(inputs=processed_tr, task=tasks.Ranking(), batch_size=64, stage='train')
generator_te = generators.PointGenerator(inputs=processed_te, task=tasks.Ranking(), batch_size=64, stage='test')

To get the first batch of training data, call `X_train, y_train = generator_tr[0]`.
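
Indexing a generator to obtain a batch follows the same idea as the Keras `Sequence` protocol, which defines `__len__` and `__getitem__`. The class below is a hypothetical, dependency-free sketch of that idea; it is not MatchZoo's `PointGenerator`, and the record layout is an assumption.

```python
import math


class ToyPointGenerator:
    """Hypothetical index-able batch source over (left, right, label) records."""

    def __init__(self, records, batch_size=2):
        self.records = records
        self.batch_size = batch_size

    def __len__(self):
        # Number of batches, counting a final partial batch.
        return math.ceil(len(self.records) / self.batch_size)

    def __getitem__(self, index):
        if not 0 <= index < len(self):
            raise IndexError(index)
        start = index * self.batch_size
        chunk = self.records[start:start + self.batch_size]
        features = [(left, right) for left, right, _ in chunk]
        labels = [label for _, _, label in chunk]
        return features, labels


records = [('q0', 'd0', 1.0), ('q0', 'd1', 0.0), ('q1', 'd2', 1.0)]
toy_generator = ToyPointGenerator(records, batch_size=2)
X_first, y_first = toy_generator[0]
print(len(toy_generator), X_first, y_first)
```

Only one batch is materialized per index access, which is what keeps memory usage low.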

**What is PointGenerator?**

**PointGenerator** implements the pointwise approach: each query-document pair in the training data is assumed to have a numerical or ordinal score. The problem can then be approximated by a regression or classification problem: given a single query-document pair, predict its score.

Many existing supervised learning methods fit this setting. Ordinal regression and classification algorithms can also be used in the pointwise approach when they predict the score of a single query-document pair.

**What is PairGenerator?**

In the pairwise approach, the problem is approximated by a classification problem: learning a binary classifier that can tell which document is better in a given pair of documents.

In MatchZoo, **PairGenerator** generates one positive and `num_neg` negative examples per pair. For example, to train a DSSM model (for document ranking), we use `num_neg=4`.
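
The sampling of one positive and `num_neg` negatives can be sketched as follows. `sample_pairs` is a hypothetical helper, and the rule that a label greater than zero counts as positive is an assumption; MatchZoo's own sampling strategy may differ.

```python
import random


def sample_pairs(labelled_docs, num_neg=4, seed=0):
    """For each positive doc of each query, draw up to `num_neg` negatives.

    `labelled_docs` maps a query id to a list of (doc_id, label) tuples.
    """
    rng = random.Random(seed)
    groups = []
    for qid, docs in labelled_docs.items():
        positives = [d for d, label in docs if label > 0]
        negatives = [d for d, label in docs if label <= 0]
        for pos in positives:
            k = min(num_neg, len(negatives))
            groups.append((qid, pos, rng.sample(negatives, k)))
    return groups


toy_docs = {'q0': [('d0', 1), ('d1', 0), ('d2', 0), ('d3', 0)],
            'q1': [('d4', 0), ('d5', 0)]}
pair_groups = sample_pairs(toy_docs, num_neg=2)
for qid, pos, negs in pair_groups:
    print(qid, pos, negs)
```

Queries without any positive document, such as `q1` here, contribute no training groups.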

**What is ListGenerator?**

This generator tries to directly optimize the value of the evaluation measures, averaged over all queries in the training data.
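
As an example of an evaluation measure averaged over queries, the sketch below computes mean average precision (MAP). Choosing MAP is an assumption made for illustration; the text above does not name a specific measure, and the function names are hypothetical.

```python
def average_precision(ranked_labels):
    """Average precision for one query, given labels in ranked order."""
    hits, precision_sum = 0, 0.0
    for rank, label in enumerate(ranked_labels, start=1):
        if label > 0:
            hits += 1
            precision_sum += hits / rank
    return precision_sum / hits if hits else 0.0


def mean_average_precision(scored):
    """`scored` maps query id -> list of (score, label); returns MAP."""
    per_query = []
    for pairs in scored.values():
        # Rank documents of one query by descending predicted score.
        ranked = sorted(pairs, key=lambda p: p[0], reverse=True)
        per_query.append(average_precision([label for _, label in ranked]))
    return sum(per_query) / len(per_query) if per_query else 0.0


toy_scores = {'q0': [(0.9, 1), (0.2, 0), (0.1, 1)],
              'q1': [(0.3, 0), (0.8, 1)]}
print(mean_average_precision(toy_scores))
```

The measure depends on the whole ranked list of a query, which is why it cannot be computed from isolated query-document pairs.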

Choose the appropriate generator based on your `task`.

----

### Train Your Arc-I Model

To train an Arc-I model, we need to create an instance of `ArcIModel`:

In [48]:
from matchzoo import models
arci_model = models.ArcIModel()

In [49]:
# The parameters fitted during the training stage are stored in the `context` property of the processed data.
from matchzoo import tasks
input_shapes = processed_tr.context['input_shapes']
arci_model.params['input_shapes'] = input_shapes
arci_model.params['task'] = tasks.Ranking()
arci_model.params['vocab_size'] = len(processed_tr.context['term_index']) + 1


In [50]:
arci_model.guess_and_fill_missing_params()
print('arci parameters: ', arci_model.params)

arci parameters:  name                          ArcIModel
model_class                   <class 'matchzoo.models.arci_model.ArcIModel'>
input_shapes                  [(32,), (32,)]
task                          <matchzoo.tasks.ranking.Ranking object at 0x11121f278>
metrics                       ['mae']
loss                          mse
optimizer                     adam
trainable_embedding           False
embedding_dim                 300
vocab_size                    29925
num_blocks                    1
left_kernel_count             [32]
left_kernel_size              [3]
right_kernel_count            [32]
right_kernel_size             [3]
activation                    relu
left_pool_size                [2]
right_pool_size               [2]
padding                       same
dropout_rate                  0.0
embedding_mat                 None
embedding_random_scale        0.2
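
The printed parameters allow a rough estimate of the feature size produced on each side. The arithmetic below is a sketch under the assumption that each block is a 1D convolution with `padding='same'` and stride 1, followed by max pooling with `'valid'` padding and a stride equal to the pool size (the Keras defaults). The source does not confirm the exact layers, so treat the numbers as illustrative.

```python
import math


def conv1d_same_length(length, stride=1):
    """Output length of a 1D convolution with 'same' padding."""
    return math.ceil(length / stride)


def pool1d_valid_length(length, pool_size, stride=None):
    """Output length of 1D max pooling with 'valid' padding."""
    stride = stride or pool_size
    return (length - pool_size) // stride + 1


# Values taken from the parameter listing above.
seq_len, kernel_count, pool_size = 32, 32, 2
after_conv = conv1d_same_length(seq_len)
after_pool = pool1d_valid_length(after_conv, pool_size)
print(after_conv, after_pool, after_pool * kernel_count)  # 32 16 512
```

Under these assumptions, each text is reduced to 16 positions with 32 channels, or 512 values once flattened.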


#### Model Training

Once all the parameters are set, build, compile, and train the model:

In [51]:
arci_model.build()
arci_model.compile()
# Fit the arci model on generator.
arci_model.fit_generator(generator_tr, steps_per_epoch=200, epochs=10)
# Make predictions on the first batch of test data
X_te, y_te = generator_te[0]
predictions = arci_model.predict([X_te.text_left, X_te.text_right])

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [52]:
for id_left, id_right, pred in zip(X_te.id_left, X_te.id_right, predictions):
    print("{}/{} is predicted as {}".format(id_left, id_right, pred))

Q933/D901-9 is predicted as [-0.03582207]
Q2304/D359-0 is predicted as [-0.09047449]
Q2100/D1980-3 is predicted as [0.14439546]
Q2382/D230-2 is predicted as [0.10213326]
Q1485/D1411-5 is predicted as [0.27780926]
Q2061/D1946-2 is predicted as [-0.11321834]
Q340/D339-7 is predicted as [-0.01285515]
Q33/D33-7 is predicted as [0.16359125]
Q1957/D1847-8 is predicted as [0.09576143]
Q1389/D1326-14 is predicted as [0.00998938]
Q1446/D1376-8 is predicted as [0.291522]
Q1899/D1795-26 is predicted as [0.08655713]
Q285/D284-4 is predicted as [0.28293288]
Q1983/D1870-4 is predicted as [-0.03440011]
Q2286/D2151-0 is predicted as [-0.03345742]
Q269/D268-3 is predicted as [0.15398271]
Q20/D20-3 is predicted as [0.00160372]
Q2801/D2602-0 is predicted as [0.25547576]
Q2148/D2025-2 is predicted as [-0.12991714]
Q995/D960-0 is predicted as [-0.17178614]
Q337/D336-4 is predicted as [0.0780839]
Q1216/D1165-0 is predicted as [0.03036081]
Q1516/D557-12 is predicted as [0.09236333]
Q506/D499-1 is predicted a
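
For a ranking task, the raw scores are usually turned into a per-query ordering. The helper below is a hypothetical sketch that groups scores by query id and sorts documents by descending score. The ids and scores are made up, and the scores are assumed to be flat floats rather than one-element arrays.

```python
from collections import defaultdict


def rank_by_query(id_left, id_right, scores):
    """Group scores by query id and sort documents by descending score."""
    grouped = defaultdict(list)
    for qid, did, score in zip(id_left, id_right, scores):
        grouped[qid].append((did, float(score)))
    return {qid: sorted(docs, key=lambda d: d[1], reverse=True)
            for qid, docs in grouped.items()}


# Made-up ids and scores, only to show the shape of the result.
ranking = rank_by_query(['Q1', 'Q1', 'Q2'],
                        ['D1-0', 'D1-1', 'D2-0'],
                        [0.1, 0.7, -0.2])
print(ranking['Q1'])  # [('D1-1', 0.7), ('D1-0', 0.1)]
```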

#### Model Persistence

You can persist your trained model with `model.save()` and restore it with the `load_model` function:

```python
from matchzoo import engine
# Save the model to dir.
arci_model.save('/your-model-saved-path')
# And load the model from dir.
loaded_model = engine.load_model('/your-model-saved-path')
```

### Reference

[Hu et al. 2014] Baotian Hu, Zhengdong Lu, Hang Li, and Qingcai Chen. "Convolutional neural network architectures for matching natural language sentences." In Advances in neural information processing systems (NIPS), pp. 2042-2050. 2014.