Update tutorial, remove moses tokenizer, simplify sample data

sidharthms committed May 30, 2018
1 parent 3dda425 commit 5877b166f76bcdf79ea14f9a96fb4bb9215ba504
@@ -8,8 +8,8 @@ DeepMatcher
.. image:: https://img.shields.io/badge/License-BSD%203--Clause-blue.svg
:target: https://opensource.org/licenses/BSD-3-Clause

DeepMatcher is a python package for performing entity matching using deep learning. It
provides built-in neural networks and utilities that enable you to train and apply
DeepMatcher is a Python package for performing entity / text matching using deep learning.
It provides built-in neural networks and utilities that enable you to train and apply
state-of-the-art deep learning models for entity matching in less than 10 lines of code.
The models are also easily customizable - the modular design allows any subcomponent to be
altered or swapped out for a custom implementation.
@@ -18,16 +18,13 @@ As an example, given labeled tuple pairs such as the following:

.. image:: docs/source/_static/match_input_ex.png

DeepMatcher trains a neural network to perform matching, i.e., to predict
match / non-match labels. The trained network can then be used obtain labels for unseen
tuple pairs or text sequences.
DeepMatcher uses labeled tuple pairs to train a neural network to perform matching, i.e., to
predict match / non-match labels. The trained network can then be used to obtain labels for
unlabeled tuple pairs or text sequences.
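As a minimal sketch of this workflow (file and column names here are placeholders, and the calls mirror the package's quick-start examples rather than this diff):

import deepmatcher as dm

# Process labeled CSV files of tuple pairs into datasets.
train, validation, test = dm.data.process(
    path='sample_data/itunes-amazon',
    train='train.csv',
    validation='validation.csv',
    test='test.csv')

# Train the default matching model and keep the best checkpoint.
model = dm.MatchingModel()
model.run_train(train, validation, best_save_path='best_model.pth')

# Evaluate on held-out labeled pairs.
model.run_eval(test)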

This package is the official PyTorch rewrite of the code for the paper
`Deep Learning for Entity Matching`_ (SIGMOD '18). All the publicly available datasets used
in the paper can be found at `Prof. AnHai Doan's data repository`_.

This package is currently in early-release alpha. Please report any
crashes / bugs / problems you encounter while using this package.
For details on the architecture of the models used, take a look at our paper `Deep
Learning for Entity Matching`_ (SIGMOD '18). All the publicly available datasets used in
the paper can be found at `Prof. AnHai Doan's data repository`_.

**********
Quick Start: DeepMatcher in 30 seconds
@@ -105,14 +102,14 @@ Answer Selection.
API Reference
**********

API docs are under construction. The half baked docs `can be accessed here`_.
API docs `are here`_.

**********
Support
**********

This package is under active development. If you run into any issues or have questions,
please file GitHub issues.
please `file GitHub issues`_.

**********
The Team
@@ -128,6 +125,7 @@ and Han Li, under the supervision of Prof. AnHai Doan and Prof. Theodoros Rekats
.. _`Data Processing`: https://nbviewer.jupyter.org/github/sidharthms/deepmatcher/blob/master/examples/data_processing.ipynb
.. _`Matching Models`: https://nbviewer.jupyter.org/github/sidharthms/deepmatcher/blob/master/examples/matching_models.ipynb
.. _`End to End Entity Matching`: https://nbviewer.jupyter.org/github/sidharthms/deepmatcher/blob/master/examples/end_to_end_em.ipynb
.. _`can be accessed here`: http://pages.cs.wisc.edu/~sidharth/deepmatcher/index.html
.. _`are here`: https://deepmatcher.github.io/docs/
.. _`Question Answering with DeepMatcher`: https://nbviewer.jupyter.org/github/sidharthms/deepmatcher/blob/master/examples/question_answering.ipynb
.. _`WikiQA`: https://aclweb.org/anthology/D15-1237
.. _`file GitHub issues`: https://github.com/sidharthms/deepmatcher/issues
@@ -35,12 +35,13 @@ def process(*args, **kwargs):
return data_process(*args, **kwargs)


__version__ = '0.0.1a0'
__version__ = '0.0.1b'
__author__ = 'Sidharth Mudgal, Han Li'

__all__ = [
'attr_summarizers', 'word_aggregators', 'word_comparators', 'word_contextualizers', 'process',
'MatchingModel', 'AttrSummarizer', 'WordContextualizer', 'WordComparator', 'WordAggregator',
'Classifier', 'modules'
'attr_summarizers', 'word_aggregators', 'word_comparators', 'word_contextualizers',
'process', 'MatchingModel', 'AttrSummarizer', 'WordContextualizer', 'WordComparator',
'WordAggregator', 'Classifier', 'modules'
]

_check_nan = True
@@ -521,7 +521,7 @@ def splits(cls,
if not auto_rebuild_cache:
raise MatchingDataset.CacheStaleException(cache_stale_cause)
else:
print('Rebuilding data cache because:', cache_stale_cause)
logger.warning('Rebuilding data cache because: {}'.format(cache_stale_cause))

if not check_cached_data or not cache_stale_cause:
datasets = MatchingDataset.restore_data(fields, cached_data)
@@ -542,26 +542,27 @@ def splits(cls,
d for d in (train_data, val_data, test_data) if d is not None)

after_load = timer()
print('Load time:', after_load - begin)
logger.info('Data load took: {}s'.format(after_load - begin))

fields_set = set(fields_dict.values())
for field in fields_set:
if field is not None and field.use_vocab:
field.build_vocab(
*datasets, vectors=embeddings, cache=embeddings_cache)
after_vocab = timer()
print('Vocab time:', after_vocab - after_load)
logger.info('Vocab construction time: {}s'.format(after_vocab - after_load))

if train:
datasets[0].compute_metadata(train_pca)
after_metadata = timer()
print('Metadata time:', after_metadata - after_vocab)
logger.info(
'Metadata computation time: {}s'.format(after_metadata - after_vocab))

if cache:
MatchingDataset.save_cache(datasets, fields_dict, datafiles, cachefile,
column_naming, state_args)
after_cache = timer()
print('Cache time:', after_cache - after_vocab)
logger.info('Cache save time: {}s'.format(after_cache - after_vocab))

if train:

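Because this change swaps print statements for logger calls, the timing messages above are silent unless logging is configured. A minimal sketch, assuming logger names follow the module path via getLogger(__name__) as in the process.py hunk below:

import logging

# Show the INFO-level timing messages introduced by this commit
# ('Data load took: ...s', 'Vocab construction time: ...s', etc.).
logging.basicConfig(level=logging.INFO)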
@@ -3,6 +3,7 @@
import tarfile
import zipfile

import nltk
import six

import fastText
@@ -106,11 +107,18 @@ class MatchingField(data.Field):

_cached_vec_data = {}

def __init__(self, tokenize='moses', id=False, **kwargs):
def __init__(self, tokenize='nltk', id=False, **kwargs):
self.tokenizer_arg = tokenize
self.is_id = id
tokenize = MatchingField._get_tokenizer(tokenize)
super(MatchingField, self).__init__(tokenize=tokenize, **kwargs)

@staticmethod
def _get_tokenizer(tokenizer):
if tokenizer == 'nltk':
return nltk.word_tokenize
return tokenizer

def preprocess_args(self):
attrs = [
'sequential', 'init_token', 'eos_token', 'unk_token', 'preprocessing',
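A short sketch of the new default tokenizer; MatchingField is normally constructed for you by dm.data.process, and any value other than 'nltk' (for example a custom callable) is passed through _get_tokenizer unchanged:

import nltk

# The punkt models are fetched by _maybe_download_nltk_data during
# processing; download explicitly when experimenting outside the package.
nltk.download('punkt', quiet=True)

print(nltk.word_tokenize("Apple iPhone 8, 64GB - Silver"))

# A custom callable such as str.split would be forwarded as-is,
# e.g. MatchingField(tokenize=str.split) splits on whitespace only.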
@@ -23,7 +23,7 @@ def __init__(self,
self.sort_in_buckets = sort_in_buckets
self.train_info = train_info
super(MatchingIterator, self).__init__(
dataset, batch_size, train=train, repeat=False, **kwargs)
dataset, batch_size, train=train, repeat=False, sort=False, **kwargs)

@classmethod
def splits(cls, datasets, batch_sizes=None, **kwargs):
@@ -1,5 +1,6 @@
import copy
import io
import logging
import os
from timeit import default_timer as timer

@@ -10,6 +11,8 @@
from .dataset import MatchingDataset
from .field import MatchingField

logger = logging.getLogger(__name__)


def _check_header(header, id_attr, left_prefix, right_prefix, label_attr, ignore_columns):
r"""Verify CSV file header.
@@ -79,6 +82,7 @@ def _maybe_download_nltk_data():
import nltk
nltk.download('perluniprops', quiet=True)
nltk.download('nonbreaking_prefixes', quiet=True)
nltk.download('punkt', quiet=True)


def process(path,
@@ -89,7 +93,7 @@ def process(path,
cache='cacheddata.pth',
check_cached_data=True,
auto_rebuild_cache=True,
tokenize='moses',
tokenize='nltk',
lowercase=True,
embeddings='fasttext.en.bin',
embeddings_cache_path='~/.vector_cache',
@@ -250,7 +254,7 @@ def process_unlabeled(path, trained_model, ignore_columns=None):
assert set(dataset.all_text_fields) == set(train_info.all_text_fields)

after_load = timer()
print('Load time:', after_load - begin)
logger.info('Data load time: {}s'.format(after_load - begin))

reverse_fields_dict = dict((pair[1], pair[0]) for pair in fields)
for field, name in reverse_fields_dict.items():
@@ -267,6 +271,6 @@ def process_unlabeled(path, trained_model, ignore_columns=None):
}

after_vocab = timer()
print('Vocab update time:', after_vocab - after_load)
logger.info('Vocab update time: {}s'.format(after_vocab - after_load))

return dataset
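A hedged sketch of applying a trained model to unlabeled data with the signature above (file names are placeholders; run_prediction is the package's quick-start prediction call):

import deepmatcher as dm

model = dm.MatchingModel()
model.load_state('best_model.pth')  # weights saved earlier during training

# process_unlabeled reuses the fields and vocab carried by the trained
# model, so no labels and no separate vocabulary build are needed.
unlabeled = dm.data.process_unlabeled(
    path='unlabeled.csv',
    trained_model=model,
    ignore_columns=('left_id', 'right_id'))
predictions = model.run_prediction(unlabeled)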
@@ -7,7 +7,6 @@
from torch.nn.utils import clip_grad_norm

logger = logging.getLogger('deepmatcher.optim')
logger.setLevel(logging.INFO)


class SoftNLLLoss(nn.NLLLoss):
@@ -102,6 +101,7 @@ def __init__(self,
self.betas = [beta1, beta2]
self.adagrad_accum = adagrad_accum
self.params = None
logger.info('Initial learning rate: {:0.3e}'.format(self.lr))

def set_parameters(self, params):
"""Sets the model parameters and initializes the base optimizer.
@@ -166,7 +166,7 @@ def update_learning_rate(self, acc, epoch):

if self.start_decay:
self.lr = self.lr * self.lr_decay
logger.info('Setting learning rate to {:0.3e}'.format(self.lr))
logger.info('Setting learning rate to {:0.3e} for next epoch'.format(self.lr))

self.last_acc = acc
self._set_rate(self.lr)
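The decay step itself is a single multiplication; an illustrative sketch of what the log line above would report once start_decay is set, assuming lr_decay=0.8:

lr, lr_decay = 1e-3, 0.8
for _ in range(3):
    lr = lr * lr_decay
    print('Setting learning rate to {:0.3e} for next epoch'.format(lr))
# Prints 8.000e-04, then 6.400e-04, then 5.120e-04.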
@@ -194,7 +194,7 @@ def _run(run_type,
if train and epoch == 0:
print('* Number of trainable parameters:', tally_parameters(model))

epoch_str = 'Epoch ' + str(epoch + 1) + ' :'
epoch_str = 'Epoch {0:d}'.format(epoch + 1)
print('===> ', run_type, epoch_str)
batch_end = time.time()

@@ -342,7 +342,7 @@ def train(model,

score = Runner._run('EVAL', model, validation_dataset, train=False, **kwargs)

optimizer.update_learning_rate(score, epoch)
optimizer.update_learning_rate(score, epoch + 1)
model.optimizer_state = optimizer.base_optimizer.state_dict()

new_best_found = False
@@ -354,11 +354,15 @@ def train(model,
if best_save_path and new_best_found:
print('Saving best model...')
model.save_state(best_save_path)
print('Done.')

if save_every_prefix is not None and (epoch + 1) % save_every_freq == 0:
print('Saving epoch model...')
save_path = '{prefix}_ep{epoch}.pth'.format(
prefix=save_every_prefix, epoch=epoch + 1)
model.save_state(save_path)
print('Done.')
print('---------------------\n')

print('Loading best model...')
model.load_state(best_save_path)
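With the new periodic checkpointing, epoch snapshots are written as '{prefix}_ep{epoch}.pth' alongside the best model. A hypothetical call, assuming these keyword arguments are forwarded by model.run_train to the training loop above:

model.run_train(
    train, validation,
    best_save_path='best_model.pth',  # best validation score so far
    save_every_prefix='checkpoint',   # writes checkpoint_ep1.pth, checkpoint_ep2.pth, ...
    save_every_freq=1)                # save a snapshot after every epoch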
@@ -61,7 +61,7 @@
# The short X.Y version.
version = '0.0'
# The full version, including alpha/beta/rc tags.
release = '0.0.1'
release = '0.0.1b'

# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
@@ -13,7 +13,7 @@ DeepMatcher is a deep learning library for entity matching.
Tutorials
---------

Tutorials are located `here <https://github.com/sidharthms/deepmatcher/blob/master/examples/getting_started.ipynb>`_
Tutorials are located `here <https://github.com/sidharthms/deepmatcher#tutorials>`_

.. toctree::
:maxdepth: 1
@@ -65,7 +65,7 @@
"\n",
"By default, data processing involves performing the following two modifications to all data:\n",
"\n",
"**Tokenization:** Tokenization involves dividing text into a sequence of tokens, which roughly correspond to \"words\". E.g., \"This ain't funny. It's actually hillarious.\" will be converted to the following sequence after tokenization: \\['This', 'ain', '&apos;t', 'funny', '.', 'It', '&apos;s', 'actually', 'hillarious', '.'\\]. The tokenizer can be set by specifying the `tokenizer` parameter. By default, this is set to `\"moses\"`, which will use the **[MosesTokenizer](http://www.nltk.org/api/nltk.tokenize.html#nltk.tokenize.moses.MosesTokenizer)** in the `nltk` package. Alternatively, you may set this to `\"spacy\"` which will use the tokenizer provided by the `spacy` package. You need to first [install and setup](https://spacy.io/usage/) `spacy` to do this."
"**Tokenization:** Tokenization involves dividing text into a sequence of tokens, which roughly correspond to \"words\". E.g., \"This ain't funny. It's actually hillarious.\" will be converted to the following sequence after tokenization: \\['This', 'ain', '&apos;t', 'funny', '.', 'It', '&apos;s', 'actually', 'hillarious', '.'\\]. The tokenizer can be set by specifying the `tokenizer` parameter. By default, this is set to `\"nltk\"`, which will use the **[default nltk tokenizer](https://www.nltk.org/api/nltk.tokenize.html#nltk.tokenize.word_tokenize)**. Alternatively, you may set this to `\"spacy\"` which will use the tokenizer provided by the `spacy` package. You need to first [install and setup](https://spacy.io/usage/) `spacy` to do this."
]
},
{
@@ -324,39 +324,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Now when you change a processing parameter with `auto_rebuild_cache` set to False, you get an error: "
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"ename": "CacheStaleException",
"evalue": "{'Field arguments have changed.'}",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mCacheStaleException\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-14-5223c3043e17>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0mtokenize\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'spacy'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0mcheck_cached_data\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 10\u001b[0;31m auto_rebuild_cache=False)\n\u001b[0m",
"\u001b[0;32m/afs/cs.wisc.edu/u/s/i/sidharth/private/deepmatcher/deepmatcher/data/process.py\u001b[0m in \u001b[0;36mprocess\u001b[0;34m(path, train, validation, test, unlabeled, cache, check_cached_data, auto_rebuild_cache, tokenize, lowercase, embeddings, embeddings_cache_path, ignore_columns, include_lengths, id_attr, label_attr, left_prefix, right_prefix, pca)\u001b[0m\n\u001b[1;32m 210\u001b[0m \u001b[0mcheck_cached_data\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 211\u001b[0m \u001b[0mauto_rebuild_cache\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 212\u001b[0;31m train_pca=pca)\n\u001b[0m\u001b[1;32m 213\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 214\u001b[0m \u001b[0;31m# Save additional information to train dataset.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/afs/cs.wisc.edu/u/s/i/sidharth/private/deepmatcher/deepmatcher/data/dataset.py\u001b[0m in \u001b[0;36msplits\u001b[0;34m(cls, path, train, validation, test, fields, embeddings, embeddings_cache, column_naming, cache, check_cached_data, auto_rebuild_cache, train_pca, **kwargs)\u001b[0m\n\u001b[1;32m 521\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mcheck_cached_data\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mcache_stale_cause\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 522\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mauto_rebuild_cache\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 523\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mMatchingDataset\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mCacheStaleException\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcache_stale_cause\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 524\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 525\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Rebuilding data cache because:'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcache_stale_cause\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mCacheStaleException\u001b[0m: {'Field arguments have changed.'}"
]
}
],
"source": [
"train, validation, test = dm.data.process(\n",
" path='sample_data/itunes-amazon',\n",
" train='train.csv',\n",
" validation='validation.csv',\n",
" test='test.csv',\n",
" ignore_columns=('left_id', 'right_id'),\n",
" cache='my_itunes_cache.pth',\n",
" tokenize='spacy',\n",
" check_cached_data=True,\n",
" auto_rebuild_cache=False)"
"Now when you change a processing parameter with `auto_rebuild_cache` set to False, you will get an error."
]
}
],
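The removed cell's behaviour can still be reproduced and handled programmatically; a minimal sketch reusing its arguments (the exception's module path is taken from the traceback above):

import deepmatcher as dm
from deepmatcher.data.dataset import MatchingDataset

try:
    train, validation, test = dm.data.process(
        path='sample_data/itunes-amazon',
        train='train.csv',
        validation='validation.csv',
        test='test.csv',
        ignore_columns=('left_id', 'right_id'),
        cache='my_itunes_cache.pth',
        tokenize='spacy',          # differs from the cached tokenizer setting
        check_cached_data=True,
        auto_rebuild_cache=False)
except MatchingDataset.CacheStaleException as exc:
    print('Cache is stale:', exc)  # e.g. {'Field arguments have changed.'}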
@@ -376,7 +344,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
"version": "3.6.4"
}
},
"nbformat": 4,
