<img src="https://github.com/NTMC-Community/MatchZoo/blob/2.0/artworks/matchzoo-logo.png?raw=True" alt="logo" style="width:600px;float: center"/>

In [1]:
import matchzoo as mz

Using TensorFlow backend.


# Prepare Data

In [57]:
train_data_pack = mz.datasets.wiki_qa.load_data(stage='train', task='ranking')
test_data_pack = mz.datasets.wiki_qa.load_data(stage='test', task='ranking')

In [58]:
type(train_data_pack)

matchzoo.data_pack.data_pack.DataPack

`DataPack` is a MatchZoo native data structure that most MatchZoo data handling processes build upon. A `DataPack` consists of three `pandas.DataFrame` objects: `left`, `right`, and `relation`.

In [59]:
train_data_pack.left.head()

| id_left | text_left |
|---|---|
| Q1 | how are glacier caves formed? |
| Q2 | How are the directions of the velocity and for... |
| Q5 | how did apollo creed die |
| Q6 | how long is the term for federal judges |
| Q7 | how a beretta model 21 pistols magazines works |


In [60]:
train_data_pack.right.head()

| id_right | text_right |
|---|---|
| D1-0 | A partly submerged glacier cave on Perito More... |
| D1-1 | The ice facade is approximately 60 m high |
| D1-2 | Ice formations in the Titlis glacier cave |
| D1-3 | A glacier cave is a cave formed within the ice... |
| D1-4 | Glacier caves are often called ice caves , but... |


In [61]:
train_data_pack.relation.head()

| | id_left | id_right | label |
|---|---|---|---|
| 0 | Q1 | D1-0 | 0 |
| 1 | Q1 | D1-1 | 0 |
| 2 | Q1 | D1-2 | 0 |
| 3 | Q1 | D1-3 | 1 |
| 4 | Q1 | D1-4 | 0 |


It is also possible to convert a `DataPack` into a single `pandas.DataFrame` that holds all information.

In [62]:
train_data_pack.frame().head()

| | id_left | text_left | id_right | text_right | label |
|---|---|---|---|---|---|
| 0 | Q1 | how are glacier caves formed? | D1-0 | A partly submerged glacier cave on Perito More... | 0 |
| 1 | Q1 | how are glacier caves formed? | D1-1 | The ice facade is approximately 60 m high | 0 |
| 2 | Q1 | how are glacier caves formed? | D1-2 | Ice formations in the Titlis glacier cave | 0 |
| 3 | Q1 | how are glacier caves formed? | D1-3 | A glacier cave is a cave formed within the ice... | 1 |
| 4 | Q1 | how are glacier caves formed? | D1-4 | Glacier caves are often called ice caves , but... | 0 |


However, such a `pandas.DataFrame` consumes much more memory when the texts contain many duplicates, because every text is copied into each row it appears in; that is exactly why we use `DataPack`. For more details about data handling, consult `matchzoo/tutorials/data_handling.ipynb`.
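To make the memory argument concrete, here is a minimal plain-Python sketch of the same three-table idea. It uses neither MatchZoo nor pandas: `left`, `right`, and `relation` are ordinary containers holding rows copied from the output above, and `frame` is a hypothetical helper that imitates the kind of join `DataPack.frame()` performs. Character counts serve only as a rough proxy for how much text each layout holds.

```python
# Normalized storage: each text is stored once, keyed by its id.
left = {'Q1': 'how are glacier caves formed?'}
right = {
    'D1-1': 'The ice facade is approximately 60 m high',
    'D1-2': 'Ice formations in the Titlis glacier cave',
}
# The relation table holds only ids and labels.
relation = [('Q1', 'D1-1', 0), ('Q1', 'D1-2', 0)]


def frame(left, right, relation):
    """Hypothetical helper: one denormalized row per relation entry, texts copied in."""
    return [
        {'id_left': id_left, 'text_left': left[id_left],
         'id_right': id_right, 'text_right': right[id_right],
         'label': label}
        for id_left, id_right, label in relation
    ]


rows = frame(left, right, relation)
pack_chars = sum(len(text) for text in left.values())
frame_chars = sum(len(row['text_left']) for row in rows)
# The question text is held once by the pack but once per row by the frame.
print(pack_chars, frame_chars)
```

The gap grows linearly with the number of relation rows that share the same text.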

# Preprocessing

MatchZoo preprocessors convert a raw `DataPack` into a `DataPack` that is ready to be fed into a model.

In [63]:
preprocessor = mz.preprocessors.NaivePreprocessor()

Using a preprocessor takes two steps: first `fit`, then `transform`. `fit` only changes the preprocessor's inner state; it does not modify the input `DataPack`.

In [64]:
preprocessor.fit(train_data_pack)

Processing text_left with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit: 100%|██████████| 2117/2117 [00:00<00:00, 10113.17it/s]
Processing text_right with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit: 100%|██████████| 18828/18828 [00:03<00:00, 4999.25it/s]
Processing text_left with extend: 100%|██████████| 2117/2117 [00:00<00:00, 647418.27it/s]
Processing text_right with extend: 100%|██████████| 18828/18828 [00:00<00:00, 690053.00it/s]
Building VocabularyUnit from a datapack.: 100%|██████████| 418540/418540 [00:00<00:00, 2473958.72it/s]


<matchzoo.preprocessors.naive_preprocessor.NaivePreprocessor at 0x11f3ccfd0>

`fit` gathers all the information it needs into the preprocessor's `context`. In the example above, we can see that a `VocabularyUnit` is built from `train_data_pack` during fitting.

In [65]:
preprocessor.context

{'vocab_unit': <matchzoo.processor_units.processor_units.VocabularyUnit at 0x12ce30fd0>}

`VocabularyUnit` is a `StatefulProcessorUnit` with a similar `fit`/`transform` interface. Once a `VocabularyUnit` is fit, it stores a mapping from `term` to `index`, and the reverse mapping, in its `state`.

The `NaivePreprocessor` already handles `VocabularyUnit` internally, so we do not have to worry about that. Just access it through the `NaivePreprocessor`'s `context`.
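The `fit`/`transform` contract of a stateful unit can be illustrated with a small stand-in. `ToyVocabularyUnit` below is a hypothetical class, not MatchZoo's `VocabularyUnit`: `fit` builds `term_index` and `index_term`, and `transform` only reads them. Sorting the terms, reserving index 0 for padding, and mapping unknown terms to 0 are assumptions made for this sketch.

```python
class ToyVocabularyUnit:
    """Hypothetical stand-in for a stateful unit with fit/transform."""

    def __init__(self):
        self.state = {}

    def fit(self, tokens):
        # Assign indices starting at 1; index 0 is reserved for padding.
        terms = sorted(set(tokens))
        self.state['term_index'] = {t: i for i, t in enumerate(terms, start=1)}
        self.state['index_term'] = {i: t for t, i in self.state['term_index'].items()}
        self.state['index_term'][0] = ''  # padding maps back to an empty string
        return self

    def transform(self, tokens):
        # Read-only use of the learned state; unknown terms fall back to 0.
        term_index = self.state['term_index']
        return [term_index.get(t, 0) for t in tokens]


unit = ToyVocabularyUnit().fit(['how', 'are', 'glacier', 'caves', 'formed'])
seq = unit.transform(['how', 'are', 'caves', 'zoo'])
print(seq)                                         # [5, 1, 2, 0]
print([unit.state['index_term'][i] for i in seq])  # ['how', 'are', 'caves', '']
```

Because `transform` never writes to `state`, the same fitted unit can safely be applied to both the training and the test data.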

In [66]:
vocab_unit = preprocessor.context['vocab_unit']
print(vocab_unit.state['term_index']['match'])
print(vocab_unit.state['term_index']['zoo'])
print(vocab_unit.state['index_term'][1])
print(vocab_unit.state['index_term'][2])

8783
1111
61
rapid


Once fit, the preprocessor has enough information to `transform`. `transform` changes neither the preprocessor's inner state nor the input `DataPack`; instead, it returns a new, transformed `DataPack`.

In [67]:
train_data_pack_processed = preprocessor.transform(train_data_pack)
test_data_pack_processed = preprocessor.transform(test_data_pack)

Processing text_left with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit => VocabularyUnit => FixedLengthUnit: 100%|██████████| 2117/2117 [00:00<00:00, 8024.41it/s]
Processing text_right with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit => VocabularyUnit => FixedLengthUnit: 100%|██████████| 18828/18828 [00:04<00:00, 4681.04it/s]
Processing text_left with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit => VocabularyUnit => FixedLengthUnit: 100%|██████████| 630/630 [00:00<00:00, 8012.50it/s]
Processing text_right with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit => VocabularyUnit => FixedLengthUnit: 100%|██████████| 5914/5914 [00:01<00:00, 4581.09it/s]


In [68]:
train_data_pack_processed.left.head()

| id_left | text_left |
|---|---|
| Q1 | [12865, 29105, 964, 4922, 22693, 0, 0, 0, 0, 0... |
| Q2 | [12865, 29105, 15189, 4952, 3835, 15189, 833, ... |
| Q5 | [12865, 17587, 25589, 25294, 794, 0, 0, 0, 0, ... |
| Q6 | [12865, 4818, 13106, 15189, 19429, 24964, 7481... |
| Q7 | [12865, 4851, 13642, 29984, 2429, 23111, 15512... |


As we can see, `text_left` is already in the fixed-length sequence form that neural networks love.
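The progress bars above list the chain applied to each text: `TokenizeUnit => LowercaseUnit => PuncRemovalUnit => VocabularyUnit => FixedLengthUnit`. A rough plain-Python equivalent is sketched below. The regular-expression tokenizer, the toy `term_index`, post-padding with 0 up to length 30 (consistent with the `Q1` sequence shown in this section), and truncating longer texts from the end are all assumptions; MatchZoo's own units may tokenize and truncate differently.

```python
import re
import string

FIXED_LENGTH = 30  # matches the (30,) input shape used in this tutorial


def tokenize(text):
    # Split into word tokens and standalone punctuation marks.
    return re.findall(r"\w+|[^\w\s]", text)


def lowercase(tokens):
    return [t.lower() for t in tokens]


def remove_punctuation(tokens):
    # Punctuation tokens are single characters, so a membership test suffices.
    return [t for t in tokens if t not in string.punctuation]


def to_indices(tokens, term_index):
    # Unknown terms fall back to the padding index 0 in this sketch.
    return [term_index.get(t, 0) for t in tokens]


def fixed_length(indices, length=FIXED_LENGTH, pad=0):
    # Keep the first `length` indices, then pad at the end.
    indices = indices[:length]
    return indices + [pad] * (length - len(indices))


def chain_transform(text, term_index):
    tokens = remove_punctuation(lowercase(tokenize(text)))
    return fixed_length(to_indices(tokens, term_index))


term_index = {'how': 1, 'are': 2, 'glacier': 3, 'caves': 4, 'formed': 5}
toy_sequence = chain_transform('how are glacier caves formed?', term_index)
print(toy_sequence)
```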

Just to make sure we have the correct sequence:

In [69]:
print('Before:', train_data_pack.left.loc['Q1']['text_left'])
sequence = train_data_pack_processed.left.loc['Q1']['text_left']
print('After:', sequence)
print('Translated:', '_'.join([vocab_unit.state['index_term'][i] for i in sequence]))

Before: how are glacier caves formed?
After: [12865, 29105, 964, 4922, 22693, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
Translated: how_are_glacier_caves_formed_________________________


For more details about preprocessing, consult `matchzoo/tutorials/preprocessing.ipynb`.

# Build Model

MatchZoo provides many built-in text matching models.

In [70]:
mz.models.list_available()

[matchzoo.models.naive_model.NaiveModel,
 matchzoo.models.dssm_model.DSSMModel,
 matchzoo.models.cdssm_model.CDSSMModel,
 matchzoo.models.dense_baseline_model.DenseBaselineModel,
 matchzoo.models.arci_model.ArcIModel,
 matchzoo.models.knrm_model.KNRMModel,
 matchzoo.models.duet_model.DUETModel,
 matchzoo.models.drmmtks_model.DRMMTKSModel,
 matchzoo.models.drmm.DRMM]

In [71]:
model = mz.models.DenseBaselineModel()

The model is initialized with a hyper-parameter table in which only some values are filled in.

In [72]:
print(model.params)

name                          None
model_class                   <class 'matchzoo.models.dense_baseline_model.DenseBaselineModel'>
input_shapes                  None
task                          None
optimizer                     None
with_multi_layer_perceptron   True
mlp_num_units                 256
mlp_num_layers                None
mlp_num_fan_out               None
mlp_activation_func           None


In [73]:
model.params['name'] = 'My First Model'
model.params['mlp_num_units'] = 3
print(model.params)

name                          My First Model
model_class                   <class 'matchzoo.models.dense_baseline_model.DenseBaselineModel'>
input_shapes                  None
task                          None
optimizer                     None
with_multi_layer_perceptron   True
mlp_num_units                 3
mlp_num_layers                None
mlp_num_fan_out               None
mlp_activation_func           None


Use `guess_and_fill_missing_params` to fill in the remaining hyper-parameters automatically. This involves some guessing, so the values it fills in could be wrong. For example, the default task is `Ranking`; if we do not manually set it to `Classification` for data packs prepared for classification, the shape of the model output will not match the data.
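The fill-in rule can be pictured as "only touch what is still `None`". The sketch below uses a plain dict and a hypothetical `guess_and_fill` helper; the default values are simply the ones MatchZoo reports in the next cell, not its actual guessing logic, which also looks at the task and the data.

```python
# Hypothetical defaults, copied from the values MatchZoo reports for this model.
DEFAULTS = {
    'task': 'ranking',
    'input_shapes': [(30,), (30,)],
    'optimizer': 'adam',
    'mlp_num_layers': 3,
    'mlp_num_fan_out': 32,
    'mlp_activation_func': 'relu',
}


def guess_and_fill(params, defaults):
    """Fill entries that are still None; never overwrite a value the user set."""
    for name, value in list(params.items()):
        if value is None and name in defaults:
            params[name] = defaults[name]
            print(f'Parameter "{name}" set to {defaults[name]}.')
    return params


def completed(params):
    """True once no parameter is left as None."""
    return all(value is not None for value in params.values())


params = {
    'name': 'My First Model',
    'task': None,
    'input_shapes': None,
    'optimizer': None,
    'mlp_num_units': 3,  # set by the user, so it is kept
    'mlp_num_layers': None,
    'mlp_num_fan_out': None,
    'mlp_activation_func': None,
}
guess_and_fill(params, DEFAULTS)
print(completed(params))
```

Because values set by hand are never overwritten, setting `task` before filling in the rest is what keeps the `Ranking` default away from classification data.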

In [74]:
model.guess_and_fill_missing_params()
print(model.params)

Parameter "task" set to Ranking Task.
Parameter "input_shapes" set to [(30,), (30,)].
Parameter "optimizer" set to adam.
Parameter "mlp_num_layers" set to 3.
Parameter "mlp_num_fan_out" set to 32.
Parameter "mlp_activation_func" set to relu.
name                          My First Model
model_class                   <class 'matchzoo.models.dense_baseline_model.DenseBaselineModel'>
input_shapes                  [(30,), (30,)]
task                          Ranking Task
optimizer                     adam
with_multi_layer_perceptron   True
mlp_num_units                 3
mlp_num_layers                3
mlp_num_fan_out               32
mlp_activation_func           relu


In [75]:
model.params.completed()

True

With all parameters filled in, we can now build and compile the model.

In [76]:
model.build()
model.compile()
model.backend.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
text_left (InputLayer)          (None, 30)           0                                            
__________________________________________________________________________________________________
text_right (InputLayer)         (None, 30)           0                                            
__________________________________________________________________________________________________
concatenate_11 (Concatenate)    (None, 60)           0           text_left[0][0]                  
                                                                 text_right[0][0]                 
__________________________________________________________________________________________________
dense_99 (Dense)                (None, 3)            183         concatenate_11[0][0]             
__________
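The `Param #` column can be checked by hand. A `Dense` layer with `n_in` inputs and `n_out` units has `n_in * n_out` weights plus `n_out` biases. Concatenating the two length-30 inputs yields 60 features, so the first dense layer has `60 * 3 + 3 = 183` parameters, as reported. The helper below just encodes that arithmetic and assumes the layer uses a bias, which is the Keras default.

```python
def dense_params(n_in, n_out, use_bias=True):
    """Number of trainable parameters in a fully connected layer."""
    return n_in * n_out + (n_out if use_bias else 0)


concat_width = 30 + 30  # text_left and text_right are concatenated
print(dense_params(concat_width, 3))  # 183
```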

For more details about models, consult `matchzoo/tutorials/models.ipynb`.

# Train, Evaluate, Predict

A `DataPack` can `unpack` itself into data that can be directly used to train a MatchZoo model.
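Conceptually, `unpack` resolves the ids in `relation` against `left` and `right` and returns the model inputs keyed by column name, together with the labels. The sketch below is a hypothetical plain-list version with toy index sequences; MatchZoo's own `unpack` works on the underlying `DataFrame`s, and the exact keys and array types it returns may differ. The names `toy_x` and `toy_y` are chosen so they do not overwrite the real `x` and `y`.

```python
def unpack(left, right, relation):
    """Return (x, y): parallel input columns and labels, one entry per relation row."""
    x = {'text_left': [], 'text_right': []}
    y = []
    for id_left, id_right, label in relation:
        x['text_left'].append(left[id_left])
        x['text_right'].append(right[id_right])
        y.append(label)
    return x, y


toy_left = {'Q1': [5, 1, 0]}
toy_right = {'D1-0': [7, 2, 0], 'D1-3': [3, 4, 6]}
toy_relation = [('Q1', 'D1-0', 0), ('Q1', 'D1-3', 1)]
toy_x, toy_y = unpack(toy_left, toy_right, toy_relation)
print(toy_x['text_left'], toy_y)
```

Every list in `toy_x` has one entry per relation row, aligned with `toy_y`, which is the shape a two-input model expects.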

In [77]:
x, y = train_data_pack_processed.unpack()
test_x, test_y = test_data_pack_processed.unpack()

In [78]:
model.fit(x, y, batch_size=32, epochs=5)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.History at 0x12cc389e8>

An alternative way to train a model is to use a `DataGenerator`. This can be useful for deferring expensive preprocessing steps or for doing real-time data augmentation. For more details about `DataGenerator`, consult `matchzoo/tutorials/data_handling.ipynb`.
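The core idea of a data generator is to produce one batch at a time, doing any expensive work only for the batch being served. The sketch below shows that idea with a plain Python generator; `iter_batches` and its `transform` argument are hypothetical and are not the `mz.DataGenerator` interface.

```python
def iter_batches(examples, batch_size=32, transform=None):
    """Yield successive batches, applying `transform` lazily to each batch only."""
    for start in range(0, len(examples), batch_size):
        batch = examples[start:start + batch_size]
        yield [transform(item) for item in batch] if transform else batch


examples = list(range(100))  # stand-in for 100 relation rows
batches = list(iter_batches(examples, batch_size=32))
print([len(batch) for batch in batches])  # [32, 32, 32, 4]

# The transform runs only when a batch is requested.
first = next(iter_batches(examples, batch_size=4, transform=lambda n: n * 10))
print(first)  # [0, 10, 20, 30]
```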

In [79]:
data_generator = mz.DataGenerator(train_data_pack_processed, batch_size=32)

In [80]:
model.fit_generator(data_generator, epochs=5, use_multiprocessing=True, workers=4)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.History at 0x12cb4f6a0>

In [81]:
model.evaluate(test_x, test_y)



{'loss': 0.04958058035328858, 'mean_absolute_error': 0.08137181649280577}

In [82]:
model.predict(test_x)

array([[ 0.0290747 ],
       [ 0.03049568],
       [ 0.05234022],
       ...,
       [ 0.05234022],
       [-0.01602625],
       [ 0.02186111]], dtype=float32)

# Automation

MatchZoo strives for ease of use, and the `matchzoo.auto` package is a perfect example of that.

`matchzoo.auto.prepare` handles the interaction among data, model, and preprocessor automatically. For example, some models, like `DSSM`, have input shapes that depend on the result of word hashing, and some models have an embedding layer whose dimension is related to the data's vocabulary size. `prepare` takes care of all that and returns a properly prepared model, data, and preprocessor for you.
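As an example of a data-dependent shape: `DSSM` represents each word by its letter n-grams (the `NgramLetterUnit => WordHashingUnit` steps in the progress bars below), so its input dimension equals the number of distinct n-grams seen while fitting. The sketch assumes letter trigrams with `#` word-boundary markers, as in the original DSSM word-hashing scheme; `letter_ngrams`, `build_ngram_index`, and `word_hash` are hypothetical helpers, not MatchZoo functions.

```python
def letter_ngrams(word, n=3):
    """Letter n-grams of a word wrapped in '#' boundary markers."""
    padded = f'#{word}#'
    return [padded[i:i + n] for i in range(len(padded) - n + 1)]


def build_ngram_index(texts, n=3):
    """Map every n-gram seen in the corpus to a column index."""
    grams = sorted({g for text in texts for w in text.split() for g in letter_ngrams(w, n)})
    return {g: i for i, g in enumerate(grams)}


def word_hash(text, ngram_index, n=3):
    """Bag-of-n-grams count vector whose length is the n-gram vocabulary size."""
    vector = [0] * len(ngram_index)
    for word in text.split():
        for gram in letter_ngrams(word, n):
            if gram in ngram_index:  # n-grams unseen during fitting are dropped
                vector[ngram_index[gram]] += 1
    return vector


corpus = ['glacier caves', 'ice caves']
ngram_index = build_ngram_index(corpus)
input_dim = len(ngram_index)  # the model's input shape depends on the data
print(letter_ngrams('good'))  # ['#go', 'goo', 'ood', 'od#']
print(input_dim)
```

Fitting on a different corpus changes `input_dim`, which is why the model has to be configured after the preprocessor has seen the data.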

In [83]:
model_ok, train_ok, preprocessor_ok = mz.auto.prepare(
    model=mz.models.DSSMModel(),
    data_pack=train_data_pack[:100]
)
test_ok = preprocessor_ok.transform(test_data_pack, verbose=0)
model_ok.fit(*train_ok.unpack(), batch_size=32)
model_ok.evaluate(*test_ok.unpack())

Processing text_left with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit => StopRemovalUnit => NgramLetterUnit: 100%|██████████| 13/13 [00:00<00:00, 2099.17it/s]
Processing text_right with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit => StopRemovalUnit => NgramLetterUnit: 100%|██████████| 100/100 [00:00<00:00, 2978.23it/s]
Processing text_left with extend: 100%|██████████| 13/13 [00:00<00:00, 16902.03it/s]
Processing text_right with extend: 100%|██████████| 100/100 [00:00<00:00, 82776.87it/s]
Building VocabularyUnit from a datapack.: 100%|██████████| 8523/8523 [00:00<00:00, 2028373.41it/s]
Processing text_left with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit => StopRemovalUnit => NgramLetterUnit => WordHashingUnit: 100%|██████████| 13/13 [00:00<00:00, 3881.96it/s]
Processing text_right with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit => StopRemovalUnit => NgramLetterUnit => WordHashingUnit: 100%

Parameter "name" set to DSSMModel.
Parameter "mlp_num_layers" set to 3.
Parameter "mlp_num_units" set to 64.
Parameter "mlp_num_fan_out" set to 32.
Parameter "mlp_activation_func" set to relu.





Epoch 1/1


{'loss': 0.04682399806163918, 'mean_absolute_error': 0.09136102424095159}

For more details about automation, consult `matchzoo/tutorials/automation.ipynb`.

# Full Example

In [50]:
model_classes = [
#     mz.models.DenseBaselineModel,
    mz.models.DSSMModel,
#     mz.models.ArcIModel,
    mz.models.DUETModel,
    mz.models.KNRMModel
]

In [52]:
task = mz.tasks.Ranking(metrics=['mae', 'ap', 'ndcg'])
results = []
for model_class in model_classes:
    print(model_class)
    model = model_class()
    model.params['task'] = task
    model_ok, train_ok, preprocessor_ok = mz.auto.prepare(
        model=model,
        data_pack=train_data_pack[:2000],
        verbose=0
    )
    test_ok = preprocessor_ok.transform(test_data_pack, verbose=0)
    callback = mz.engine.BaseModel.EvaluateOnCall(
        model_ok,
        *test_ok.unpack(),
        valid_steps=1,
        batch_size=1024
    )
    history = model_ok.fit(*train_ok.unpack(), batch_size=32, epochs=30, callbacks=[callback])
    results.append({
        'name': model_ok.params['name'],
        'history': history
    })

<class 'matchzoo.models.dssm_model.DSSMModel'>
Epoch 1/30


Validation: loss:0.045543 - mean_absolute_error:0.076757 - average_precision(0):0.109458 - normalized_discounted_cumulative_gain@1(0):0.171429


Epoch 2/30


Validation: loss:0.045394 - mean_absolute_error:0.088274 - average_precision(0):0.088499 - normalized_discounted_cumulative_gain@1(0):0.117460


Epoch 3/30


Validation: loss:0.045644 - mean_absolute_error:0.081708 - average_precision(0):0.071189 - normalized_discounted_cumulative_gain@1(0):0.058730


Epoch 4/30


Validation: loss:0.045792 - mean_absolute_error:0.098531 - average_precision(0):0.069814 - normalized_discounted_cumulative_gain@1(0):0.053968


Epoch 5/30


Validation: loss:0.048159 - mean_absolute_error:0.067191 - average_precision(0):0.071763 - normalized_discounted_cumulative_gain@1(0):0.052381


Epoch 6/30


Validation: loss:0.048167 - mean_absolute_error:0.065402 - average_precision(0):0.072011 - normalized_discounted_cumulative_gain@1(0):0.055556


Epoch 7/30


Validation: loss:0.046857 - mean_absolute_error:0.061412 - average_precision(0):0.077657 - normalized_discounted_cumulative_gain@1(0):0.061905


Epoch 8/30


Validation: loss:0.048006 - mean_absolute_error:0.061719 - average_precision(0):0.083894 - normalized_discounted_cumulative_gain@1(0):0.085714


Epoch 9/30


Validation: loss:0.047898 - mean_absolute_error:0.057489 - average_precision(0):0.085298 - normalized_discounted_cumulative_gain@1(0):0.082540


Epoch 10/30


Validation: loss:0.049338 - mean_absolute_error:0.061706 - average_precision(0):0.084895 - normalized_discounted_cumulative_gain@1(0):0.085714


Epoch 11/30


Validation: loss:0.050000 - mean_absolute_error:0.083023 - average_precision(0):0.076570 - normalized_discounted_cumulative_gain@1(0):0.066667


Epoch 12/30


Validation: loss:0.048649 - mean_absolute_error:0.059577 - average_precision(0):0.083245 - normalized_discounted_cumulative_gain@1(0):0.084127


Epoch 13/30


Validation: loss:0.052190 - mean_absolute_error:0.072221 - average_precision(0):0.078505 - normalized_discounted_cumulative_gain@1(0):0.082540


Epoch 14/30


Validation: loss:0.051632 - mean_absolute_error:0.064774 - average_precision(0):0.076152 - normalized_discounted_cumulative_gain@1(0):0.071429


Epoch 15/30


Validation: loss:0.052422 - mean_absolute_error:0.062620 - average_precision(0):0.074892 - normalized_discounted_cumulative_gain@1(0):0.065079


Epoch 16/30


Validation: loss:0.051442 - mean_absolute_error:0.068840 - average_precision(0):0.080503 - normalized_discounted_cumulative_gain@1(0):0.084127


Epoch 17/30


Validation: loss:0.050332 - mean_absolute_error:0.058878 - average_precision(0):0.078628 - normalized_discounted_cumulative_gain@1(0):0.077778


Epoch 18/30


Validation: loss:0.050741 - mean_absolute_error:0.062498 - average_precision(0):0.077367 - normalized_discounted_cumulative_gain@1(0):0.077778


Epoch 19/30


Validation: loss:0.051705 - mean_absolute_error:0.063853 - average_precision(0):0.077160 - normalized_discounted_cumulative_gain@1(0):0.074603


Epoch 20/30


Validation: loss:0.050908 - mean_absolute_error:0.060600 - average_precision(0):0.076329 - normalized_discounted_cumulative_gain@1(0):0.069841


Epoch 21/30


Validation: loss:0.052240 - mean_absolute_error:0.062494 - average_precision(0):0.078704 - normalized_discounted_cumulative_gain@1(0):0.077778


Epoch 22/30


Validation: loss:0.052651 - mean_absolute_error:0.062036 - average_precision(0):0.077991 - normalized_discounted_cumulative_gain@1(0):0.076190


Epoch 23/30


Validation: loss:0.052988 - mean_absolute_error:0.062354 - average_precision(0):0.078470 - normalized_discounted_cumulative_gain@1(0):0.076190


Epoch 24/30


Validation: loss:0.053495 - mean_absolute_error:0.062666 - average_precision(0):0.078839 - normalized_discounted_cumulative_gain@1(0):0.073016


Epoch 25/30


Validation: loss:0.053534 - mean_absolute_error:0.062591 - average_precision(0):0.075070 - normalized_discounted_cumulative_gain@1(0):0.073016


Epoch 26/30


Validation: loss:0.058544 - mean_absolute_error:0.072717 - average_precision(0):0.079654 - normalized_discounted_cumulative_gain@1(0):0.079365


Epoch 27/30


Validation: loss:0.056426 - mean_absolute_error:0.066179 - average_precision(0):0.078927 - normalized_discounted_cumulative_gain@1(0):0.082540


Epoch 28/30


Validation: loss:0.052934 - mean_absolute_error:0.061124 - average_precision(0):0.079760 - normalized_discounted_cumulative_gain@1(0):0.080952


Epoch 29/30


Validation: loss:0.052890 - mean_absolute_error:0.060001 - average_precision(0):0.077983 - normalized_discounted_cumulative_gain@1(0):0.076190


Epoch 30/30


Validation: loss:0.067992 - mean_absolute_error:0.082952 - average_precision(0):0.078702 - normalized_discounted_cumulative_gain@1(0):0.080952


<class 'matchzoo.models.duet_model.DUETModel'>
Epoch 1/30


Validation: loss:0.114333 - mean_absolute_error:0.230604 - average_precision(0):0.077900 - normalized_discounted_cumulative_gain@1(0):0.080952


Epoch 2/30


Validation: loss:0.109614 - mean_absolute_error:0.223631 - average_precision(0):0.076245 - normalized_discounted_cumulative_gain@1(0):0.079365


Epoch 3/30


Validation: loss:0.093209 - mean_absolute_error:0.205808 - average_precision(0):0.080472 - normalized_discounted_cumulative_gain@1(0):0.087302


Epoch 4/30


Validation: loss:0.067838 - mean_absolute_error:0.156972 - average_precision(0):0.069405 - normalized_discounted_cumulative_gain@1(0):0.050794


Epoch 5/30


Validation: loss:0.064790 - mean_absolute_error:0.145144 - average_precision(0):0.073291 - normalized_discounted_cumulative_gain@1(0):0.068254


Epoch 6/30


Validation: loss:0.061738 - mean_absolute_error:0.142388 - average_precision(0):0.092843 - normalized_discounted_cumulative_gain@1(0):0.103175


Epoch 7/30


Validation: loss:0.059069 - mean_absolute_error:0.122790 - average_precision(0):0.088096 - normalized_discounted_cumulative_gain@1(0):0.098413


Epoch 8/30


Validation: loss:0.072565 - mean_absolute_error:0.156123 - average_precision(0):0.087786 - normalized_discounted_cumulative_gain@1(0):0.106349


Epoch 9/30


Validation: loss:0.054217 - mean_absolute_error:0.145490 - average_precision(0):0.082633 - normalized_discounted_cumulative_gain@1(0):0.090476


Epoch 10/30


Validation: loss:0.055360 - mean_absolute_error:0.109603 - average_precision(0):0.086299 - normalized_discounted_cumulative_gain@1(0):0.101587


Epoch 11/30


Validation: loss:0.048227 - mean_absolute_error:0.095547 - average_precision(0):0.086365 - normalized_discounted_cumulative_gain@1(0):0.106349


Epoch 12/30


Validation: loss:0.050610 - mean_absolute_error:0.097966 - average_precision(0):0.090446 - normalized_discounted_cumulative_gain@1(0):0.112698


Epoch 13/30


Validation: loss:0.049483 - mean_absolute_error:0.110112 - average_precision(0):0.089243 - normalized_discounted_cumulative_gain@1(0):0.104762


Epoch 14/30


Validation: loss:0.047736 - mean_absolute_error:0.099076 - average_precision(0):0.088277 - normalized_discounted_cumulative_gain@1(0):0.107937


Epoch 15/30


Validation: loss:0.049395 - mean_absolute_error:0.100433 - average_precision(0):0.086719 - normalized_discounted_cumulative_gain@1(0):0.107937


Epoch 16/30


Validation: loss:0.048313 - mean_absolute_error:0.100516 - average_precision(0):0.084085 - normalized_discounted_cumulative_gain@1(0):0.093651


Epoch 17/30


Validation: loss:0.048822 - mean_absolute_error:0.090068 - average_precision(0):0.087820 - normalized_discounted_cumulative_gain@1(0):0.107937


Epoch 18/30


Validation: loss:0.047539 - mean_absolute_error:0.091345 - average_precision(0):0.089196 - normalized_discounted_cumulative_gain@1(0):0.111111


Epoch 19/30


Validation: loss:0.048329 - mean_absolute_error:0.096509 - average_precision(0):0.089706 - normalized_discounted_cumulative_gain@1(0):0.109524


Epoch 20/30


Validation: loss:0.049088 - mean_absolute_error:0.087453 - average_precision(0):0.087815 - normalized_discounted_cumulative_gain@1(0):0.103175


Epoch 21/30


Validation: loss:0.049250 - mean_absolute_error:0.092390 - average_precision(0):0.088795 - normalized_discounted_cumulative_gain@1(0):0.109524


Epoch 22/30


Validation: loss:0.047590 - mean_absolute_error:0.092381 - average_precision(0):0.087648 - normalized_discounted_cumulative_gain@1(0):0.104762


Epoch 23/30


Validation: loss:0.048999 - mean_absolute_error:0.088861 - average_precision(0):0.091012 - normalized_discounted_cumulative_gain@1(0):0.114286


Epoch 24/30


Validation: loss:0.047836 - mean_absolute_error:0.091863 - average_precision(0):0.089791 - normalized_discounted_cumulative_gain@1(0):0.111111


Epoch 25/30


Validation: loss:0.049269 - mean_absolute_error:0.097107 - average_precision(0):0.088192 - normalized_discounted_cumulative_gain@1(0):0.103175


Epoch 26/30


Validation: loss:0.047399 - mean_absolute_error:0.085143 - average_precision(0):0.090711 - normalized_discounted_cumulative_gain@1(0):0.122222


Epoch 27/30


Validation: loss:0.048005 - mean_absolute_error:0.081556 - average_precision(0):0.090060 - normalized_discounted_cumulative_gain@1(0):0.120635


Epoch 28/30


Validation: loss:0.051325 - mean_absolute_error:0.091785 - average_precision(0):0.092538 - normalized_discounted_cumulative_gain@1(0):0.122222


Epoch 29/30


Validation: loss:0.048427 - mean_absolute_error:0.089653 - average_precision(0):0.090019 - normalized_discounted_cumulative_gain@1(0):0.114286


Epoch 30/30


Validation: loss:0.047267 - mean_absolute_error:0.088103 - average_precision(0):0.088311 - normalized_discounted_cumulative_gain@1(0):0.112698


<class 'matchzoo.models.knrm_model.KNRMModel'>
Epoch 1/30


Validation: loss:117.930981 - mean_absolute_error:8.732786 - average_precision(0):0.081012 - normalized_discounted_cumulative_gain@1(0):0.076190


Epoch 2/30


Validation: loss:121.352373 - mean_absolute_error:8.829915 - average_precision(0):0.082520 - normalized_discounted_cumulative_gain@1(0):0.093651


Epoch 3/30


Validation: loss:98.819584 - mean_absolute_error:7.942581 - average_precision(0):0.082267 - normalized_discounted_cumulative_gain@1(0):0.080952


Epoch 4/30


Validation: loss:98.624207 - mean_absolute_error:7.914940 - average_precision(0):0.089310 - normalized_discounted_cumulative_gain@1(0):0.104762


Epoch 5/30


Validation: loss:110.373807 - mean_absolute_error:8.462737 - average_precision(0):0.086867 - normalized_discounted_cumulative_gain@1(0):0.090476


Epoch 6/30


Validation: loss:79.451089 - mean_absolute_error:7.127314 - average_precision(0):0.089864 - normalized_discounted_cumulative_gain@1(0):0.101587


Epoch 7/30


Validation: loss:82.577951 - mean_absolute_error:7.286862 - average_precision(0):0.087953 - normalized_discounted_cumulative_gain@1(0):0.095238


Epoch 8/30


Validation: loss:82.333500 - mean_absolute_error:7.232734 - average_precision(0):0.086417 - normalized_discounted_cumulative_gain@1(0):0.096825


Epoch 9/30


Validation: loss:77.742353 - mean_absolute_error:7.063289 - average_precision(0):0.086391 - normalized_discounted_cumulative_gain@1(0):0.090476


Epoch 10/30


Validation: loss:85.860183 - mean_absolute_error:7.395090 - average_precision(0):0.086827 - normalized_discounted_cumulative_gain@1(0):0.100000


Epoch 11/30


Validation: loss:77.322888 - mean_absolute_error:7.022624 - average_precision(0):0.084190 - normalized_discounted_cumulative_gain@1(0):0.101587


Epoch 12/30


Validation: loss:80.229193 - mean_absolute_error:7.160812 - average_precision(0):0.085076 - normalized_discounted_cumulative_gain@1(0):0.098413


Epoch 13/30


Validation: loss:77.051263 - mean_absolute_error:7.016855 - average_precision(0):0.085580 - normalized_discounted_cumulative_gain@1(0):0.093651


Epoch 14/30


Validation: loss:72.627709 - mean_absolute_error:6.814121 - average_precision(0):0.086618 - normalized_discounted_cumulative_gain@1(0):0.106349


Epoch 15/30


Validation: loss:72.603743 - mean_absolute_error:6.815005 - average_precision(0):0.089176 - normalized_discounted_cumulative_gain@1(0):0.101587


Epoch 16/30


Validation: loss:76.827244 - mean_absolute_error:7.015418 - average_precision(0):0.085690 - normalized_discounted_cumulative_gain@1(0):0.100000


Epoch 17/30


Validation: loss:78.383146 - mean_absolute_error:7.084431 - average_precision(0):0.085002 - normalized_discounted_cumulative_gain@1(0):0.095238


Epoch 18/30


Validation: loss:78.832836 - mean_absolute_error:7.093513 - average_precision(0):0.086213 - normalized_discounted_cumulative_gain@1(0):0.100000


Epoch 19/30


Validation: loss:70.772976 - mean_absolute_error:6.716375 - average_precision(0):0.086347 - normalized_discounted_cumulative_gain@1(0):0.101587


Epoch 20/30


Validation: loss:73.001729 - mean_absolute_error:6.820185 - average_precision(0):0.088929 - normalized_discounted_cumulative_gain@1(0):0.107937


Epoch 21/30


Validation: loss:70.820399 - mean_absolute_error:6.715217 - average_precision(0):0.087925 - normalized_discounted_cumulative_gain@1(0):0.101587


Epoch 22/30


Validation: loss:75.307090 - mean_absolute_error:6.953686 - average_precision(0):0.088235 - normalized_discounted_cumulative_gain@1(0):0.112698


Epoch 23/30


Validation: loss:73.888705 - mean_absolute_error:6.869039 - average_precision(0):0.086433 - normalized_discounted_cumulative_gain@1(0):0.101587


Epoch 24/30


Validation: loss:70.082874 - mean_absolute_error:6.698294 - average_precision(0):0.088442 - normalized_discounted_cumulative_gain@1(0):0.104762


Epoch 25/30


Validation: loss:72.201666 - mean_absolute_error:6.784047 - average_precision(0):0.087044 - normalized_discounted_cumulative_gain@1(0):0.104762


Epoch 26/30


Validation: loss:74.222488 - mean_absolute_error:6.889943 - average_precision(0):0.086905 - normalized_discounted_cumulative_gain@1(0):0.095238


Epoch 27/30


Validation: loss:73.290817 - mean_absolute_error:6.830411 - average_precision(0):0.086096 - normalized_discounted_cumulative_gain@1(0):0.098413


Epoch 28/30


Validation: loss:67.386414 - mean_absolute_error:6.560674 - average_precision(0):0.086512 - normalized_discounted_cumulative_gain@1(0):0.103175


Epoch 29/30


Validation: loss:71.351753 - mean_absolute_error:6.762890 - average_precision(0):0.087010 - normalized_discounted_cumulative_gain@1(0):0.104762


Epoch 30/30


Validation: loss:64.831316 - mean_absolute_error:6.433554 - average_precision(0):0.086082 - normalized_discounted_cumulative_gain@1(0):0.095238
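Each validation line reports `average_precision(0)` and `normalized_discounted_cumulative_gain@1(0)` alongside the loss and MAE. For a single query, both can be computed from the predicted scores and the 0/1 labels as sketched below. These are the standard textbook definitions (gain `2**label - 1`, `log2` rank discount); reading the `(0)` suffix as a relevance threshold of 0 is our assumption, and MatchZoo's implementation may differ in details such as tie-breaking. The labels are those of `Q1` from the relation table; the scores are made up.

```python
import math


def rank_labels(labels, scores):
    """Labels reordered by descending predicted score."""
    return [label for _, label in sorted(zip(scores, labels), key=lambda p: -p[0])]


def average_precision(labels, scores, threshold=0):
    """Mean of precision@k over the ranks k that hold a relevant item."""
    hits, precisions = 0, []
    for k, label in enumerate(rank_labels(labels, scores), start=1):
        if label > threshold:
            hits += 1
            precisions.append(hits / k)
    return sum(precisions) / len(precisions) if precisions else 0.0


def dcg_at_k(ranked, k, threshold=0):
    return sum(
        (2 ** label - 1) / math.log2(i + 2)
        for i, label in enumerate(ranked[:k])
        if label > threshold
    )


def ndcg_at_k(labels, scores, k=1, threshold=0):
    """DCG of the predicted ranking divided by DCG of the ideal ranking."""
    ideal = dcg_at_k(sorted(labels, reverse=True), k, threshold)
    if ideal == 0:
        return 0.0
    return dcg_at_k(rank_labels(labels, scores), k, threshold) / ideal


labels = [0, 0, 0, 1, 0]            # Q1: only D1-3 is relevant
scores = [0.9, 0.1, 0.2, 0.8, 0.3]  # hypothetical model scores
print(average_precision(labels, scores))  # 0.5 (relevant answer ranked 2nd)
print(ndcg_at_k(labels, scores, k=1))     # 0.0 (top-ranked answer is irrelevant)
```

nDCG@1 only credits a query when the single top-ranked answer is relevant, so it is a much stricter signal than MAE for a ranking task.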


In [53]:
from bokeh.io import output_notebook
from bokeh.layouts import column
from bokeh.palettes import Category10
from bokeh.plotting import figure, show

In [56]:
# One chart per metric recorded in the first model's training history.
charts = {
    metric: figure(
        title=str(metric),
        sizing_mode='scale_width',
        width=800, height=400
    ) for metric in results[0]['history'].history.keys()
}
for metric, sub_chart in charts.items():
    # One line per model; fresh names avoid overwriting the training data `x`, `y`.
    for result, color in zip(results, Category10[10]):
        epochs = result['history'].epoch
        values = result['history'].history[metric]
        sub_chart.line(
            epochs, values, color=color, line_width=2, alpha=0.5, legend=result['name'])
output_notebook()
show(column(*charts.values()))