diff --git a/tutorials/wikiqa/cdssm.ipynb b/tutorials/wikiqa/cdssm.ipynb index 296935cd..84f17bb1 100644 --- a/tutorials/wikiqa/cdssm.ipynb +++ b/tutorials/wikiqa/cdssm.ipynb @@ -11,105 +11,70 @@ "text": [ "Using TensorFlow backend.\n" ] - } - ], - "source": [ - "import os\n", - "import keras\n", - "import pandas as pd\n", - "import numpy as np\n", - "import matchzoo as mz\n", - "\n", - "os.environ['CUDA_VISIBLE_DEVICES'] = ''" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "train_pack = mz.datasets.wiki_qa.load_data('train', task='ranking')\n", - "valid_pack = mz.datasets.wiki_qa.load_data('dev', task='ranking', filtered=True)\n", - "predict_pack = mz.datasets.wiki_qa.load_data('test', task='ranking', filtered=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ + }, { - "name": "stderr", + "name": "stdout", "output_type": "stream", "text": [ - "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval => NgramLetter: 100%|██████████| 2118/2118 [00:00<00:00, 6542.53it/s]\n", - "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval => NgramLetter: 100%|██████████| 18841/18841 [00:05<00:00, 3325.99it/s]\n", - "Processing text_left with extend: 100%|██████████| 2118/2118 [00:00<00:00, 584789.41it/s]\n", - "Processing text_right with extend: 100%|██████████| 18841/18841 [00:00<00:00, 438625.05it/s]\n", - "Building Vocabulary from a datapack.: 100%|██████████| 1614998/1614998 [00:00<00:00, 3082231.57it/s]\n", - "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval: 100%|██████████| 2118/2118 [00:00<00:00, 7934.07it/s]\n", - "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval: 100%|██████████| 18841/18841 [00:04<00:00, 3811.25it/s]\n", - "Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 115942.78it/s]\n", - "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 87315.29it/s]\n", - "Processing text_left with chain_transform of NgramLetter => WordHashing: 100%|██████████| 2118/2118 [00:08<00:00, 261.04it/s]\n", - "Processing text_right with chain_transform of NgramLetter => WordHashing: 100%|██████████| 18841/18841 [01:24<00:00, 220.85it/s]\n" + "matchzoo version 2.1.0\n", + "\n", + "data loading ...\n", + "data loaded as `train_pack_raw` `dev_pack_raw` `test_pack_raw`\n", + "`ranking_task` initialized with metrics [normalized_discounted_cumulative_gain@3(0.0), normalized_discounted_cumulative_gain@5(0.0), mean_average_precision(0.0)]\n", + "loading embedding ...\n", + "embedding loaded as `glove_embedding`\n" ] } ], "source": [ - "preprocessor = mz.preprocessors.CDSSMPreprocessor(fixed_length_left=10, fixed_length_right=10)\n", - "train_pack_processed = preprocessor.fit_transform(train_pack)" + "%run init.ipynb" ] }, { "cell_type": "code", - "execution_count": 4, - "metadata": { - "scrolled": false - }, + "execution_count": 2, + "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval: 100%|██████████| 122/122 [00:00<00:00, 7332.17it/s]\n", - "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval: 100%|██████████| 1115/1115 [00:00<00:00, 4039.28it/s]\n", - "Processing text_left with transform: 100%|██████████| 122/122 [00:00<00:00, 71677.42it/s]\n", - "Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 86562.93it/s]\n", - "Processing text_left with chain_transform of NgramLetter => WordHashing: 100%|██████████| 122/122 [00:00<00:00, 249.16it/s]\n", - "Processing text_right with chain_transform of NgramLetter => WordHashing: 100%|██████████| 1115/1115 [00:18<00:00, 60.13it/s] \n", - "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval: 100%|██████████| 237/237 [00:00<00:00, 8283.96it/s]\n", - "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval: 100%|██████████| 2300/2300 [00:01<00:00, 1951.14it/s]\n", - "Processing text_left with transform: 100%|██████████| 237/237 [00:00<00:00, 92016.11it/s]\n", - "Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 89923.46it/s]\n", - "Processing text_left with chain_transform of NgramLetter => WordHashing: 100%|██████████| 237/237 [00:00<00:00, 249.05it/s]\n", - "Processing text_right with chain_transform of NgramLetter => WordHashing: 100%|██████████| 2300/2300 [00:11<00:00, 239.25it/s]\n" + "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval => NgramLetter: 100%|██████████| 2118/2118 [00:00<00:00, 5365.33it/s]\n", + "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval => NgramLetter: 100%|██████████| 18841/18841 [00:05<00:00, 3205.80it/s]\n", + "Processing text_left with extend: 100%|██████████| 2118/2118 [00:00<00:00, 310569.71it/s]\n", + "Processing text_right with extend: 100%|██████████| 18841/18841 [00:00<00:00, 349915.35it/s]\n", + "Building Vocabulary from a datapack.: 100%|██████████| 1614998/1614998 [00:00<00:00, 3031577.24it/s]\n", + "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval: 100%|██████████| 2118/2118 [00:00<00:00, 8384.04it/s]\n", + "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval: 100%|██████████| 18841/18841 [00:04<00:00, 3939.48it/s]\n", + "Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 125562.34it/s]\n", + "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 96621.02it/s]\n", + "Processing text_left with chain_transform of NgramLetter => WordHashing: 100%|██████████| 2118/2118 [00:07<00:00, 269.84it/s]\n", + "Processing text_right with chain_transform of NgramLetter => WordHashing: 100%|██████████| 18841/18841 [01:27<00:00, 216.37it/s]\n", + "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval: 100%|██████████| 122/122 [00:00<00:00, 7746.89it/s]\n", + "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval: 100%|██████████| 1115/1115 [00:14<00:00, 77.28it/s] \n", + "Processing text_left with transform: 100%|██████████| 122/122 [00:00<00:00, 63447.62it/s]\n", + "Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 88946.88it/s]\n", + "Processing text_left with chain_transform of NgramLetter => WordHashing: 100%|██████████| 122/122 [00:00<00:00, 273.77it/s]\n", + "Processing text_right with chain_transform of NgramLetter => WordHashing: 100%|██████████| 1115/1115 [00:04<00:00, 226.13it/s]\n", + "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval: 100%|██████████| 237/237 [00:00<00:00, 8707.90it/s]\n", + "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval => StopRemoval: 100%|██████████| 2300/2300 [00:01<00:00, 2237.22it/s]\n", + "Processing text_left with transform: 100%|██████████| 237/237 [00:00<00:00, 101299.30it/s]\n", + "Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 97484.78it/s]\n", + "Processing text_left with chain_transform of NgramLetter => WordHashing: 100%|██████████| 237/237 [00:00<00:00, 269.12it/s]\n", + "Processing text_right with chain_transform of NgramLetter => WordHashing: 100%|██████████| 2300/2300 [00:10<00:00, 212.35it/s]\n" ] } ], "source": [ - "valid_pack_processed = preprocessor.transform(valid_pack)\n", - "predict_pack_processed = preprocessor.transform(predict_pack)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "ranking_task = mz.tasks.Ranking(loss=mz.losses.RankHingeLoss())\n", - "ranking_task.metrics = [\n", - " mz.metrics.NormalizedDiscountedCumulativeGain(k=3),\n", - " mz.metrics.NormalizedDiscountedCumulativeGain(k=5),\n", - " mz.metrics.MeanAveragePrecision()\n", - "]" + "preprocessor = mz.preprocessors.CDSSMPreprocessor(fixed_length_left=10, fixed_length_right=10)\n", + "train_pack_processed = preprocessor.fit_transform(train_pack_raw)\n", + "valid_pack_processed = preprocessor.transform(dev_pack_raw)\n", + "test_pack_processed = preprocessor.transform(test_pack_raw)" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 3, "metadata": { "scrolled": false }, @@ -182,34 +147,33 @@ }, { "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "pred_x, pred_y = predict_pack_processed[:].unpack()" - ] - }, - { - "cell_type": "code", - "execution_count": 8, + "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "WARNING: PairDataGenerator will be deprecated in MatchZoo v2.2. Use `DataGenerator` with callbacks instead.\n" + "num batches: 102\n" ] } ], "source": [ + "pred_x, pred_y = test_pack_processed[:].unpack()\n", "evaluate = mz.callbacks.EvaluateAllMetrics(model, x=pred_x, y=pred_y, batch_size=len(pred_x))\n", - "train_generator = mz.PairDataGenerator(train_pack_processed, num_dup=2, num_neg=1, batch_size=64, shuffle=True)" + "train_generator = mz.DataGenerator(\n", + " train_pack_processed,\n", + " mode='pair',\n", + " num_dup=2,\n", + " num_neg=1,\n", + " batch_size=20\n", + ")\n", + "print('num batches:', len(train_generator))" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 5, "metadata": { "scrolled": false }, @@ -219,65 +183,65 @@ "output_type": "stream", "text": [ "Epoch 1/20\n", - "32/32 [==============================] - 72s 2s/step - loss: 0.9808\n", - "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.42488215100974375 - normalized_discounted_cumulative_gain@5(0.0): 0.4909840291447614 - mean_average_precision(0.0): 0.45215174584733625\n", + "102/102 [==============================] - 65s 635ms/step - loss: 0.8021\n", + "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.42304382290227854 - normalized_discounted_cumulative_gain@5(0.0): 0.49915948768338086 - mean_average_precision(0.0): 0.46037758752542035\n", "Epoch 2/20\n", - "32/32 [==============================] - 39s 1s/step - loss: 0.7890\n", - "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4523924010888256 - normalized_discounted_cumulative_gain@5(0.0): 0.526580105002082 - mean_average_precision(0.0): 0.4843661871232366\n", + "102/102 [==============================] - 45s 445ms/step - loss: 0.5966\n", + "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.43781763271520285 - normalized_discounted_cumulative_gain@5(0.0): 0.520097599085372 - mean_average_precision(0.0): 0.4762598411822459\n", "Epoch 3/20\n", - "32/32 [==============================] - 39s 1s/step - loss: 0.6269\n", - "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.42912083006483787 - normalized_discounted_cumulative_gain@5(0.0): 0.5059856809765926 - mean_average_precision(0.0): 0.46958607676474695\n", + "102/102 [==============================] - 46s 447ms/step - loss: 0.4992\n", + "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.44923101748788413 - normalized_discounted_cumulative_gain@5(0.0): 0.5136672947113214 - mean_average_precision(0.0): 0.4803110559647868\n", "Epoch 4/20\n", - "32/32 [==============================] - 40s 1s/step - loss: 0.5425\n", - "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.45161347891396936 - normalized_discounted_cumulative_gain@5(0.0): 0.5246607686333176 - mean_average_precision(0.0): 0.4894374228928154\n", + "102/102 [==============================] - 46s 446ms/step - loss: 0.4143\n", + "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.467714954371615 - normalized_discounted_cumulative_gain@5(0.0): 0.5353130653753986 - mean_average_precision(0.0): 0.5017560318255102\n", "Epoch 5/20\n", - "32/32 [==============================] - 39s 1s/step - loss: 0.4786\n", - "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.44182189699245444 - normalized_discounted_cumulative_gain@5(0.0): 0.5185925913281365 - mean_average_precision(0.0): 0.4785410987945083\n", + "102/102 [==============================] - 46s 451ms/step - loss: 0.3489\n", + "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4706511291875292 - normalized_discounted_cumulative_gain@5(0.0): 0.5275832328072992 - mean_average_precision(0.0): 0.5026243479583462\n", "Epoch 6/20\n", - "32/32 [==============================] - 40s 1s/step - loss: 0.4312\n", - "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4353736449145257 - normalized_discounted_cumulative_gain@5(0.0): 0.5193283626946119 - mean_average_precision(0.0): 0.47581070046166846\n", + "102/102 [==============================] - 45s 443ms/step - loss: 0.3231\n", + "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4641151831570107 - normalized_discounted_cumulative_gain@5(0.0): 0.5219564667466021 - mean_average_precision(0.0): 0.4934049132027672\n", "Epoch 7/20\n", - "32/32 [==============================] - 40s 1s/step - loss: 0.3918\n", - "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.42860677523860297 - normalized_discounted_cumulative_gain@5(0.0): 0.511305679951288 - mean_average_precision(0.0): 0.465896748198743\n", + "102/102 [==============================] - 44s 433ms/step - loss: 0.2695\n", + "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4760514687477512 - normalized_discounted_cumulative_gain@5(0.0): 0.5285019348702702 - mean_average_precision(0.0): 0.49994736585416333\n", "Epoch 8/20\n", - "32/32 [==============================] - 40s 1s/step - loss: 0.3399\n", - "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4483104574863611 - normalized_discounted_cumulative_gain@5(0.0): 0.514850172629585 - mean_average_precision(0.0): 0.47729829884300434\n", + "102/102 [==============================] - 47s 457ms/step - loss: 0.2331\n", + "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4661781341235399 - normalized_discounted_cumulative_gain@5(0.0): 0.5260071867435453 - mean_average_precision(0.0): 0.4922605321622356\n", "Epoch 9/20\n", - "32/32 [==============================] - 39s 1s/step - loss: 0.3133\n", - "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.44683839351327287 - normalized_discounted_cumulative_gain@5(0.0): 0.5179511840387869 - mean_average_precision(0.0): 0.47938829437070807\n", + "102/102 [==============================] - 46s 453ms/step - loss: 0.1942\n", + "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4645719968337067 - normalized_discounted_cumulative_gain@5(0.0): 0.5238558790194195 - mean_average_precision(0.0): 0.48851468090847294\n", "Epoch 10/20\n", - "32/32 [==============================] - 39s 1s/step - loss: 0.2852\n", - "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.45147495790907527 - normalized_discounted_cumulative_gain@5(0.0): 0.5115591547752104 - mean_average_precision(0.0): 0.4816185898296844\n", + "102/102 [==============================] - 45s 444ms/step - loss: 0.1734\n", + "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4600910137969285 - normalized_discounted_cumulative_gain@5(0.0): 0.5320923473092672 - mean_average_precision(0.0): 0.48703092961044614\n", "Epoch 11/20\n", - "32/32 [==============================] - 39s 1s/step - loss: 0.2503\n", - "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.45606725359003863 - normalized_discounted_cumulative_gain@5(0.0): 0.5137853825095071 - mean_average_precision(0.0): 0.4898566499228558\n", + "102/102 [==============================] - 47s 464ms/step - loss: 0.1644\n", + "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.45786306386326225 - normalized_discounted_cumulative_gain@5(0.0): 0.5246949873542252 - mean_average_precision(0.0): 0.48502089087514016\n", "Epoch 12/20\n", - "32/32 [==============================] - 40s 1s/step - loss: 0.2114\n", - "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4536596944525606 - normalized_discounted_cumulative_gain@5(0.0): 0.5195598827593197 - mean_average_precision(0.0): 0.48509265834715537\n", + "102/102 [==============================] - 45s 443ms/step - loss: 0.1560\n", + "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4567332855642369 - normalized_discounted_cumulative_gain@5(0.0): 0.528074374789356 - mean_average_precision(0.0): 0.4905494464640722\n", "Epoch 13/20\n", - "32/32 [==============================] - 40s 1s/step - loss: 0.1886\n", - "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.46086526212639095 - normalized_discounted_cumulative_gain@5(0.0): 0.5347221469997477 - mean_average_precision(0.0): 0.4972628568563632\n", + "102/102 [==============================] - 45s 440ms/step - loss: 0.1365\n", + "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4698836510431016 - normalized_discounted_cumulative_gain@5(0.0): 0.5317255666034969 - mean_average_precision(0.0): 0.49152222966181813\n", "Epoch 14/20\n", - "32/32 [==============================] - 40s 1s/step - loss: 0.1770\n", - "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4524514918879093 - normalized_discounted_cumulative_gain@5(0.0): 0.5259706685008083 - mean_average_precision(0.0): 0.4886127130697417\n", + "102/102 [==============================] - 45s 437ms/step - loss: 0.1263\n", + "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.46841048236088156 - normalized_discounted_cumulative_gain@5(0.0): 0.5197949838102164 - mean_average_precision(0.0): 0.4887341126171474\n", "Epoch 15/20\n", - "32/32 [==============================] - 38s 1s/step - loss: 0.1454\n", - "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.447793222144205 - normalized_discounted_cumulative_gain@5(0.0): 0.523210520595902 - mean_average_precision(0.0): 0.4842295350811732\n", + "102/102 [==============================] - 46s 449ms/step - loss: 0.1208\n", + "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4591952265806063 - normalized_discounted_cumulative_gain@5(0.0): 0.5306329604843507 - mean_average_precision(0.0): 0.4956899590808506\n", "Epoch 16/20\n", - "32/32 [==============================] - 39s 1s/step - loss: 0.1573\n", - "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4617597821704854 - normalized_discounted_cumulative_gain@5(0.0): 0.5233844194501931 - mean_average_precision(0.0): 0.48736454745855345\n", + "102/102 [==============================] - 44s 430ms/step - loss: 0.0977\n", + "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4698790408918565 - normalized_discounted_cumulative_gain@5(0.0): 0.5355447042513717 - mean_average_precision(0.0): 0.5005823464725863\n", "Epoch 17/20\n", - "32/32 [==============================] - 40s 1s/step - loss: 0.1437\n", - "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.446051744365399 - normalized_discounted_cumulative_gain@5(0.0): 0.511689458533883 - mean_average_precision(0.0): 0.47390160852977303\n", + "102/102 [==============================] - 44s 436ms/step - loss: 0.0975\n", + "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4699380823064665 - normalized_discounted_cumulative_gain@5(0.0): 0.5335843828018585 - mean_average_precision(0.0): 0.4945873841691485\n", "Epoch 18/20\n", - "32/32 [==============================] - 40s 1s/step - loss: 0.1396\n", - "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.44155513060981444 - normalized_discounted_cumulative_gain@5(0.0): 0.508852366919329 - mean_average_precision(0.0): 0.47430841711913996\n", + "102/102 [==============================] - 44s 431ms/step - loss: 0.1070\n", + "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4558123512520484 - normalized_discounted_cumulative_gain@5(0.0): 0.5280824209271964 - mean_average_precision(0.0): 0.49009730599920476\n", "Epoch 19/20\n", - "32/32 [==============================] - 39s 1s/step - loss: 0.1176\n", - "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4480183178506102 - normalized_discounted_cumulative_gain@5(0.0): 0.5177348617853279 - mean_average_precision(0.0): 0.4799578376954719\n", + "102/102 [==============================] - 45s 446ms/step - loss: 0.0846\n", + "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.47468488947984894 - normalized_discounted_cumulative_gain@5(0.0): 0.5462940161373581 - mean_average_precision(0.0): 0.5050693971440435\n", "Epoch 20/20\n", - "32/32 [==============================] - 37s 1s/step - loss: 0.1094\n", - "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.44665417415821596 - normalized_discounted_cumulative_gain@5(0.0): 0.5142203489288135 - mean_average_precision(0.0): 0.47726846390007227\n" + "102/102 [==============================] - 43s 425ms/step - loss: 0.0833\n", + "Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4571219999869289 - normalized_discounted_cumulative_gain@5(0.0): 0.527098973778668 - mean_average_precision(0.0): 0.4884211445807594\n" ] } ],