diff --git a/tutorials/wikiqa/dssm.ipynb b/tutorials/wikiqa/dssm.ipynb index 8d99ebc7..8eb6e11d 100644 --- a/tutorials/wikiqa/dssm.ipynb +++ b/tutorials/wikiqa/dssm.ipynb @@ -9,30 +9,25 @@ "name": "stderr", "output_type": "stream", "text": [ - "Using TensorFlow backend.\n", - "/data/users/fyx/.local/python3/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88\n", - " return f(*args, **kwds)\n", - "/data/users/fyx/.local/python3/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88\n", - " return f(*args, **kwds)\n" + "Using TensorFlow backend.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "matchzoo version 2.1.0\n", + "\n", + "data loading ...\n", + "data loaded as `train_pack_raw` `dev_pack_raw` `test_pack_raw`\n", + "`ranking_task` initialized with metrics [normalized_discounted_cumulative_gain@3(0.0), normalized_discounted_cumulative_gain@5(0.0), mean_average_precision(0.0)]\n", + "loading embedding ...\n", + "embedding loaded as `glove_embedding`\n" ] } ], "source": [ - "import keras\n", - "import numpy as np\n", - "import pandas as pd\n", - "import matchzoo as mz" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "train_pack = mz.datasets.wiki_qa.load_data('train', task='ranking')\n", - "valid_pack = mz.datasets.wiki_qa.load_data('dev', task='ranking', filter=True)\n", - "predict_pack = mz.datasets.wiki_qa.load_data('test', task='ranking', filter=True)" + "%run init.ipynb" ] }, { @@ -44,19 +39,25 @@ "name": "stderr", "output_type": "stream", "text": [ - "Processing text_left with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit => StopRemovalUnit => NgramLetterUnit: 100%|██████████| 2118/2118 [00:00<00:00, 6893.66it/s]\n", - "Processing text_right with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit => StopRemovalUnit => NgramLetterUnit: 100%|██████████| 18841/18841 [00:05<00:00, 3323.25it/s]\n", - "Processing text_left with extend: 100%|██████████| 2118/2118 [00:00<00:00, 514034.02it/s]\n", - "Processing text_right with extend: 100%|██████████| 18841/18841 [00:00<00:00, 430982.12it/s]\n", - "Building VocabularyUnit from a datapack.: 100%|██████████| 1614976/1614976 [00:00<00:00, 2964446.21it/s]\n", - "Processing text_left with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit => StopRemovalUnit => NgramLetterUnit => WordHashingUnit: 100%|██████████| 2118/2118 [00:00<00:00, 4878.11it/s]\n", - "Processing text_right with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit => StopRemovalUnit => NgramLetterUnit => WordHashingUnit: 100%|██████████| 18841/18841 [00:07<00:00, 2572.20it/s]\n" + "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval => NgramLetter: 100%|██████████| 2118/2118 [00:00<00:00, 3802.39it/s]\n", + "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval => NgramLetter: 100%|██████████| 18841/18841 [00:04<00:00, 3959.06it/s]\n", + "Processing text_left with extend: 100%|██████████| 2118/2118 [00:00<00:00, 822625.79it/s]\n", + "Processing text_right with extend: 100%|██████████| 18841/18841 [00:00<00:00, 597166.86it/s]\n", + "Building Vocabulary from a datapack.: 100%|██████████| 1614998/1614998 [00:00<00:00, 4642343.92it/s]\n", + "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval => NgramLetter => WordHashing: 100%|██████████| 2118/2118 [00:00<00:00, 2853.90it/s]\n", + "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval => NgramLetter => WordHashing: 100%|██████████| 18841/18841 [00:12<00:00, 1456.96it/s]\n", + "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval => NgramLetter => WordHashing: 100%|██████████| 122/122 [00:00<00:00, 2308.40it/s]\n", + "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval => NgramLetter => WordHashing: 100%|██████████| 1115/1115 [00:00<00:00, 2025.86it/s]\n", + "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval => NgramLetter => WordHashing: 100%|██████████| 237/237 [00:00<00:00, 2678.58it/s]\n", + "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval => NgramLetter => WordHashing: 100%|██████████| 2300/2300 [00:01<00:00, 1345.18it/s]\n" ] } ], "source": [ "preprocessor = mz.preprocessors.DSSMPreprocessor()\n", - "train_pack_processed = preprocessor.fit_transform(train_pack)" + "train_pack_processed = preprocessor.fit_transform(train_pack_raw)\n", + "valid_pack_processed = preprocessor.transform(dev_pack_raw)\n", + "test_pack_processed = preprocessor.transform(test_pack_raw)" ] }, { @@ -65,19 +66,21 @@ "metadata": {}, "outputs": [ { - "name": "stderr", - "output_type": "stream", - "text": [ - "Processing text_left with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit => StopRemovalUnit => NgramLetterUnit => WordHashingUnit: 100%|██████████| 122/122 [00:00<00:00, 4624.45it/s]\n", - "Processing text_right with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit => StopRemovalUnit => NgramLetterUnit => WordHashingUnit: 100%|██████████| 1115/1115 [00:00<00:00, 2609.60it/s]\n", - "Processing text_left with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit => StopRemovalUnit => NgramLetterUnit => WordHashingUnit: 100%|██████████| 237/237 [00:00<00:00, 5193.33it/s]\n", - "Processing text_right with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit => StopRemovalUnit => NgramLetterUnit => WordHashingUnit: 100%|██████████| 2300/2300 [00:00<00:00, 2579.14it/s]\n" - ] + "data": { + "text/plain": [ + "{'vocab_unit': ,\n", + " 'vocab_size': 9645,\n", + " 'embedding_input_dim': 9645,\n", + " 'input_shapes': [(9645,), (9645,)]}" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "valid_pack_processed = preprocessor.transform(valid_pack)\n", - "predict_pack_processed = preprocessor.transform(predict_pack)" + "preprocessor.context" ] }, { @@ -103,17 +106,16 @@ "name": "stdout", "output_type": "stream", "text": [ - "Parameter \"name\" set to DSSM.\n", "__________________________________________________________________________________________________\n", "Layer (type) Output Shape Param # Connected to \n", "==================================================================================================\n", - "text_left (InputLayer) (None, 9644) 0 \n", + "text_left (InputLayer) (None, 9645) 0 \n", "__________________________________________________________________________________________________\n", - "text_right (InputLayer) (None, 9644) 0 \n", + "text_right (InputLayer) (None, 9645) 0 \n", "__________________________________________________________________________________________________\n", - "dense_1 (Dense) (None, 300) 2893500 text_left[0][0] \n", + "dense_1 (Dense) (None, 300) 2893800 text_left[0][0] \n", "__________________________________________________________________________________________________\n", - "dense_5 (Dense) (None, 300) 2893500 text_right[0][0] \n", + "dense_5 (Dense) (None, 300) 2893800 text_right[0][0] \n", "__________________________________________________________________________________________________\n", "dense_2 (Dense) (None, 300) 90300 dense_1[0][0] \n", "__________________________________________________________________________________________________\n", @@ -132,8 +134,8 @@ "__________________________________________________________________________________________________\n", "dense_9 (Dense) (None, 1) 2 dot_1[0][0] \n", "==================================================================================================\n", - "Total params: 6,225,258\n", - "Trainable params: 6,225,258\n", + "Total params: 6,225,858\n", + "Trainable params: 6,225,858\n", "Non-trainable params: 0\n", "__________________________________________________________________________________________________\n" ] @@ -155,38 +157,45 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ - "pred_x, pred_y = predict_pack_processed[:].unpack()\n", + "pred_x, pred_y = test_pack_processed[:].unpack()\n", "evaluate = mz.callbacks.EvaluateAllMetrics(model, x=pred_x, y=pred_y, batch_size=len(pred_x))" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 11, "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "WARNING: PairDataGenerator will be deprecated in MatchZoo v2.2. Use `DataGenerator` with callbacks instead.\n" + ] + }, { "data": { "text/plain": [ - "16" + "32" ] }, - "execution_count": 8, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "train_generator = mz.PairDataGenerator(train_pack_processed, num_dup=1, num_neg=4, batch_size=64, shuffle=True)\n", + "train_generator = mz.PairDataGenerator(train_pack_processed, num_dup=1, num_neg=4, batch_size=32, shuffle=True)\n", "len(train_generator)" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 12, "metadata": { "scrolled": false }, @@ -196,65 +205,65 @@ "output_type": "stream", "text": [ "Epoch 1/20\n", - "16/16 [==============================] - 3s 175ms/step - loss: 1.5564\n", - "Validation: normalized_discounted_cumulative_gain@3(0): 0.45901904458281206 - normalized_discounted_cumulative_gain@5(0): 0.5475429093192175 - mean_average_precision(0): 0.4911167747235369\n", + "32/32 [==============================] - 7s 215ms/step - loss: 1.3325\n", + "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4431849853601904 - normalized_discounted_cumulative_gain@5(0.0): 0.5295386323998266 - mean_average_precision(0.0): 0.48303488812718776\n", "Epoch 2/20\n", - "16/16 [==============================] - 2s 96ms/step - loss: 1.2718\n", - "Validation: normalized_discounted_cumulative_gain@3(0): 0.4654009025837788 - normalized_discounted_cumulative_gain@5(0): 0.5392725183775516 - mean_average_precision(0): 0.48941174032387735\n", + "32/32 [==============================] - 6s 176ms/step - loss: 1.3159\n", + "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4353814849661657 - normalized_discounted_cumulative_gain@5(0.0): 0.5032525911610362 - mean_average_precision(0.0): 0.4776049822282439\n", "Epoch 3/20\n", - "16/16 [==============================] - 2s 100ms/step - loss: 1.1539\n", - "Validation: normalized_discounted_cumulative_gain@3(0): 0.4791185363644891 - normalized_discounted_cumulative_gain@5(0): 0.5538418263100713 - mean_average_precision(0): 0.5071942064030672\n", + "32/32 [==============================] - 5s 171ms/step - loss: 1.2955\n", + "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4088637099689691 - normalized_discounted_cumulative_gain@5(0.0): 0.48351010067595823 - mean_average_precision(0.0): 0.4432379861560312\n", "Epoch 4/20\n", - "16/16 [==============================] - 2s 99ms/step - loss: 1.0995\n", - "Validation: normalized_discounted_cumulative_gain@3(0): 0.4886745351048811 - normalized_discounted_cumulative_gain@5(0): 0.5562701289500901 - mean_average_precision(0): 0.5133703768780384\n", + "32/32 [==============================] - 6s 173ms/step - loss: 1.2726\n", + "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.46569627211992487 - normalized_discounted_cumulative_gain@5(0.0): 0.5305277638291452 - mean_average_precision(0.0): 0.4903964896023526\n", "Epoch 5/20\n", - "16/16 [==============================] - 2s 95ms/step - loss: 1.0674\n", - "Validation: normalized_discounted_cumulative_gain@3(0): 0.4853827809349306 - normalized_discounted_cumulative_gain@5(0): 0.5682223793780434 - mean_average_precision(0): 0.5169149799106053\n", + "32/32 [==============================] - 6s 172ms/step - loss: 1.2439\n", + "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.44778538209256513 - normalized_discounted_cumulative_gain@5(0.0): 0.5104380434420628 - mean_average_precision(0.0): 0.47615129143046664\n", "Epoch 6/20\n", - "16/16 [==============================] - 2s 99ms/step - loss: 1.0479\n", - "Validation: normalized_discounted_cumulative_gain@3(0): 0.46480280740048857 - normalized_discounted_cumulative_gain@5(0): 0.5281305738905067 - mean_average_precision(0): 0.492201293611383\n", + "32/32 [==============================] - 6s 172ms/step - loss: 1.2202\n", + "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4452573045503587 - normalized_discounted_cumulative_gain@5(0.0): 0.5137975378931312 - mean_average_precision(0.0): 0.4742872412051932\n", "Epoch 7/20\n", - "16/16 [==============================] - 2s 94ms/step - loss: 1.0274\n", - "Validation: normalized_discounted_cumulative_gain@3(0): 0.4898804500151732 - normalized_discounted_cumulative_gain@5(0): 0.560890920302221 - mean_average_precision(0): 0.5058987750652866\n", + "32/32 [==============================] - 5s 170ms/step - loss: 1.2038\n", + "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.41264292792428936 - normalized_discounted_cumulative_gain@5(0.0): 0.4740615140630128 - mean_average_precision(0.0): 0.45294026408574084\n", "Epoch 8/20\n", - "16/16 [==============================] - 2s 97ms/step - loss: 1.0142\n", - "Validation: normalized_discounted_cumulative_gain@3(0): 0.4941487381639577 - normalized_discounted_cumulative_gain@5(0): 0.5665778188888335 - mean_average_precision(0): 0.5140344974996873\n", + "32/32 [==============================] - 6s 172ms/step - loss: 1.1848\n", + "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.45527149721829696 - normalized_discounted_cumulative_gain@5(0.0): 0.5229678873030444 - mean_average_precision(0.0): 0.48490323375232625\n", "Epoch 9/20\n", - "16/16 [==============================] - 2s 98ms/step - loss: 1.0008\n", - "Validation: normalized_discounted_cumulative_gain@3(0): 0.5069173712527351 - normalized_discounted_cumulative_gain@5(0): 0.5866287176354987 - mean_average_precision(0): 0.5269820047509921\n", + "32/32 [==============================] - 5s 171ms/step - loss: 1.1504 3\n", + "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4401749964298954 - normalized_discounted_cumulative_gain@5(0.0): 0.5202410581724496 - mean_average_precision(0.0): 0.47967943778482564\n", "Epoch 10/20\n", - "16/16 [==============================] - 2s 97ms/step - loss: 0.9873\n", - "Validation: normalized_discounted_cumulative_gain@3(0): 0.49700543338958786 - normalized_discounted_cumulative_gain@5(0): 0.5805879493729443 - mean_average_precision(0): 0.5150557804829956\n", + "32/32 [==============================] - 5s 172ms/step - loss: 1.1314\n", + "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.44883790476151675 - normalized_discounted_cumulative_gain@5(0.0): 0.5215788412779597 - mean_average_precision(0.0): 0.48274548802838624\n", "Epoch 11/20\n", - "16/16 [==============================] - 2s 109ms/step - loss: 0.9750\n", - "Validation: normalized_discounted_cumulative_gain@3(0): 0.5025024116631714 - normalized_discounted_cumulative_gain@5(0): 0.5923504552250592 - mean_average_precision(0): 0.5225994206215725\n", + "32/32 [==============================] - 6s 173ms/step - loss: 1.1109\n", + "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.45835958548802597 - normalized_discounted_cumulative_gain@5(0.0): 0.5254562351939174 - mean_average_precision(0.0): 0.48819163523037407\n", "Epoch 12/20\n", - "16/16 [==============================] - 2s 98ms/step - loss: 0.9644\n", - "Validation: normalized_discounted_cumulative_gain@3(0): 0.5103579714392016 - normalized_discounted_cumulative_gain@5(0): 0.5903011924569881 - mean_average_precision(0): 0.5302960840144384\n", + "32/32 [==============================] - 6s 174ms/step - loss: 1.0915\n", + "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4540812972538116 - normalized_discounted_cumulative_gain@5(0.0): 0.502728792326375 - mean_average_precision(0.0): 0.48229166522394096\n", "Epoch 13/20\n", - "16/16 [==============================] - 2s 95ms/step - loss: 0.9534\n", - "Validation: normalized_discounted_cumulative_gain@3(0): 0.5160164432378087 - normalized_discounted_cumulative_gain@5(0): 0.5993189848710705 - mean_average_precision(0): 0.5396924014803761\n", + "32/32 [==============================] - 6s 173ms/step - loss: 1.0805\n", + "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4462255256302118 - normalized_discounted_cumulative_gain@5(0.0): 0.5097488218798687 - mean_average_precision(0.0): 0.4751972950775518\n", "Epoch 14/20\n", - "16/16 [==============================] - 2s 97ms/step - loss: 0.9442\n", - "Validation: normalized_discounted_cumulative_gain@3(0): 0.5032196943848134 - normalized_discounted_cumulative_gain@5(0): 0.5852110589773137 - mean_average_precision(0): 0.5268703682283354\n", + "32/32 [==============================] - 6s 174ms/step - loss: 1.0575\n", + "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4263585495923587 - normalized_discounted_cumulative_gain@5(0.0): 0.5014903707963352 - mean_average_precision(0.0): 0.46364289738480496\n", "Epoch 15/20\n", - "16/16 [==============================] - 2s 95ms/step - loss: 0.9336\n", - "Validation: normalized_discounted_cumulative_gain@3(0): 0.5317052606549753 - normalized_discounted_cumulative_gain@5(0): 0.6023052111746768 - mean_average_precision(0): 0.5478608387627374\n", + "32/32 [==============================] - 6s 179ms/step - loss: 1.0396\n", + "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.43936705108731194 - normalized_discounted_cumulative_gain@5(0.0): 0.5218713927469146 - mean_average_precision(0.0): 0.47233172236473137\n", "Epoch 16/20\n", - "16/16 [==============================] - 2s 97ms/step - loss: 0.9239\n", - "Validation: normalized_discounted_cumulative_gain@3(0): 0.508831525587222 - normalized_discounted_cumulative_gain@5(0): 0.5817387639362221 - mean_average_precision(0): 0.5274733502791022\n", + "32/32 [==============================] - 6s 182ms/step - loss: 1.0156\n", + "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.45080782122574514 - normalized_discounted_cumulative_gain@5(0.0): 0.5181271382497495 - mean_average_precision(0.0): 0.4832342072703635\n", "Epoch 17/20\n", - "16/16 [==============================] - 2s 97ms/step - loss: 0.9144\n", - "Validation: normalized_discounted_cumulative_gain@3(0): 0.5203610304609562 - normalized_discounted_cumulative_gain@5(0): 0.5991831842478874 - mean_average_precision(0): 0.5431669157143841\n", + "32/32 [==============================] - 6s 175ms/step - loss: 0.9932\n", + "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.423108628561739 - normalized_discounted_cumulative_gain@5(0.0): 0.49596605935842625 - mean_average_precision(0.0): 0.4667294180948952\n", "Epoch 18/20\n", - "16/16 [==============================] - 2s 98ms/step - loss: 0.9071\n", - "Validation: normalized_discounted_cumulative_gain@3(0): 0.5138966254324063 - normalized_discounted_cumulative_gain@5(0): 0.5919938110840292 - mean_average_precision(0): 0.5360874679703793\n", + "32/32 [==============================] - 5s 172ms/step - loss: 0.9800\n", + "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4378084124127128 - normalized_discounted_cumulative_gain@5(0.0): 0.5098753091251295 - mean_average_precision(0.0): 0.4734416114488085\n", "Epoch 19/20\n", - "16/16 [==============================] - 2s 97ms/step - loss: 0.8990\n", - "Validation: normalized_discounted_cumulative_gain@3(0): 0.4872557976860261 - normalized_discounted_cumulative_gain@5(0): 0.5698167643783474 - mean_average_precision(0): 0.5167562193088068\n", + "32/32 [==============================] - 6s 172ms/step - loss: 0.9662\n", + "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4504450479915345 - normalized_discounted_cumulative_gain@5(0.0): 0.519107636100811 - mean_average_precision(0.0): 0.48712867088141415\n", "Epoch 20/20\n", - "16/16 [==============================] - 2s 98ms/step - loss: 0.8915\n", - "Validation: normalized_discounted_cumulative_gain@3(0): 0.5492809107350253 - normalized_discounted_cumulative_gain@5(0): 0.6194080901274281 - mean_average_precision(0): 0.5626976311311754\n" + "32/32 [==============================] - 6s 172ms/step - loss: 0.9512\n", + "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.45663442312293695 - normalized_discounted_cumulative_gain@5(0.0): 0.5363645153841258 - mean_average_precision(0.0): 0.4956098197015037\n" ] } ], @@ -286,7 +295,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.5" + "version": "3.6.7" } }, "nbformat": 4,