Allow load before train, fix prediction bug, other small fixes

This commit makes the following changes (a brief usage sketch follows the list):
- Models can now be loaded before training
  - Save training data info along with weights when saving model
  - Fully instantiate all components of model in `initialize` call
- Fix prediction bug
  - Extend copy of train vocab
  - Reset embeddings with new vocab in copy of model
- Deprecate dm.process in favor of dm.data.process and dm.data.process_unlabeled
- Other small fixes
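
Taken together, these changes enable a load-before-train and prediction workflow along the following lines. This is a minimal sketch assuming deepmatcher's documented entry points (run_train, load_state, run_prediction); the file names, save path, and directory layout are placeholders, not part of this commit:

    import deepmatcher as dm

    # Process labeled data (CSV names are placeholders).
    train, validation = dm.data.process(
        path='sample_data', train='train.csv', validation='validation.csv')

    model = dm.MatchingModel()
    model.run_train(train, validation, best_save_path='best_model.pth')

    # Later, possibly in a fresh session: the saved state now carries enough
    # training-data info that the model can be loaded without re-training.
    model = dm.MatchingModel()
    model.load_state('best_model.pth')

    # Unlabeled data goes through the new dedicated entry point.
    unlabeled = dm.data.process_unlabeled(
        path='sample_data/unlabeled.csv', trained_model=model)
    predictions = model.run_prediction(unlabeled)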
sidharthms committed Apr 24, 2018
1 parent 5786092 commit acf63136579128b7da547410a1999a2e09b7d603
LICENSE (new file, +29 lines)
@@ -0,0 +1,29 @@
BSD 3-Clause License

Copyright (c) Sidharth Mudgal and Han Li 2018,
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
@@ -1,19 +1,33 @@
r"""
The deepmatcher package contains high level modules used in the construction of deep
learning modules for entity matching. It also contains data processing utilities.
learning modules for entity matching.
"""

from .data import process
import warnings

from .data import process as data_process
from .models import modules
from .models.core import (MatchingModel, AttrSummarizer, AttrComparator,
WordContextualizer, WordComparator, WordAggregator, Classifier)
from .models import (attr_summarizers, word_aggregators, word_comparators,
word_contextualizers)

warnings.filterwarnings('always', module='deepmatcher')


def process(*args, **kwargs):
warnings.warn('"deepmatcher.process" is deprecated and will be removed in a later '
'release, please use "deepmatcher.data.process" instead',
DeprecationWarning)
return data_process(*args, **kwargs)


__version__ = '0.0.1a0'

__all__ = [
attr_summarizers, word_aggregators, word_comparators, word_contextualizers, process,
MatchingModel, AttrSummarizer, AttrComparator, WordContextualizer,
WordComparator, WordAggregator, Classifier, modules
MatchingModel, AttrSummarizer, AttrComparator, WordContextualizer, WordComparator,
WordAggregator, Classifier, modules
]

_check_nan = True
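
As a quick illustration of the shim above: calling the old entry point still works, but now emits a DeprecationWarning before delegating to deepmatcher.data.process. The data paths below are placeholders:

    import warnings

    import deepmatcher as dm

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        # Old spelling: warns, then delegates to dm.data.process.
        datasets = dm.process(path='data', train='train.csv', validation='validation.csv')
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)

    # Preferred spelling going forward:
    datasets = dm.data.process(path='data', train='train.csv', validation='validation.csv')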
@@ -14,23 +14,23 @@ def __new__(cls, *args, **kwargs):
else:
name = kwargs['name']
attr = kwargs['attr']
train_dataset = kwargs['train_dataset']
train_info = kwargs['train_info']
if isinstance(attr, tuple):
data = attr[0]
lengths = attr[1]
else:
data = attr
lengths = None
word_probs = None
if 'word_probs' in train_dataset.metadata:
raw_word_probs = train_dataset.metadata['word_probs'][name]
if 'word_probs' in train_info.metadata:
raw_word_probs = train_info.metadata['word_probs'][name]
word_probs = torch.Tensor(
[[raw_word_probs[w] for w in b] for b in data.data])
if data.is_cuda:
word_probs = word_probs.cuda()
pc = None
if 'pc' in train_dataset.metadata:
pc = torch.Tensor(train_dataset.metadata['pc'][name])
if 'pc' in train_info.metadata:
pc = torch.Tensor(train_info.metadata['pc'][name])
if data.is_cuda:
pc = pc.cuda()
return AttrTensor(data, lengths, word_probs, pc)
@@ -42,13 +42,13 @@ def from_old_metadata(data, old_attrtensor):

class MatchingBatch(object):

def __init__(self, input, train_dataset):
copy_fields = train_dataset.all_text_fields
def __init__(self, input, train_info):
copy_fields = train_info.all_text_fields
for name in copy_fields:
setattr(self, name,
AttrTensor(
name=name, attr=getattr(input, name),
train_dataset=train_dataset))
for name in [train_dataset.label_field, train_dataset.id_field]:
train_info=train_info))
for name in [train_info.label_field, train_info.id_field]:
if name is not None and hasattr(input, name):
setattr(self, name, getattr(input, name))
@@ -1,7 +1,9 @@
from .field import MatchingField
from .dataset import MatchingDataset
from .iterator import MatchingIterator
from .process import process
from .process import process, process_unlabeled
from .dataset import split

__all__ = [MatchingField, MatchingDataset, MatchingIterator, process, split]
__all__ = [
MatchingField, MatchingDataset, MatchingIterator, process, process_unlabeled, split
]
@@ -1,5 +1,6 @@
from __future__ import division

import copy
import logging
import os
import pdb
@@ -15,6 +16,7 @@
from torchtext import data

from ..models.modules import NoMeta, Pool
from .field import MatchingField
from .iterator import MatchingIterator

logger = logging.getLogger(__name__)
@@ -200,7 +202,7 @@ def compute_metadata(self, pca=False):

# Create an iterator over the entire dataset.
train_iter = MatchingIterator(
self, self, batch_size=1024, device=-1, sort_in_buckets=False)
self, self, train=False, batch_size=1024, device=-1, sort_in_buckets=False)
counter = defaultdict(Counter)

# For each attribute, find the number of times each word id occurs in the dataset.
@@ -245,7 +247,7 @@ def compute_metadata(self, pca=False):

# Create an iterator over the entire dataset.
train_iter = MatchingIterator(
self, self, batch_size=1024, device=-1, sort_in_buckets=False)
self, self, train=False, batch_size=1024, device=-1, sort_in_buckets=False)
attr_embeddings = defaultdict(list)

# Run the constructed neural network to compute weighted sequence embeddings
@@ -273,6 +275,7 @@ def finalize_metadata(self):
the cache.
"""

self.orig_metadata = copy.deepcopy(self.metadata)
for name in self.all_text_fields:
self.metadata['word_probs'][name] = defaultdict(
lambda: 1 / self.metadata['totals'][name],
@@ -410,6 +413,8 @@ def load_cache(fields, datafiles, cachefile, column_naming, state_args):
args_mismatch = field.preprocess_args() != cached_data['field_args'][name]
if none_mismatch or args_mismatch:
cache_stale_cause.add('Field arguments have changed.')
if field is not None and not isinstance(field, MatchingField):
cache_stale_cause.add('Cache update required.')

if column_naming != cached_data['column_naming']:
cache_stale_cause.add('Other arguments have changed.')
@@ -456,10 +461,9 @@ def restore_data(fields, cached_data):
@classmethod
def splits(cls,
path,
train,
train=None,
validation=None,
test=None,
unlabeled=None,
fields=None,
embeddings=None,
embeddings_cache=None,
@@ -478,8 +482,6 @@ def splits(cls,
for no validation set. Default is None.
test (str): Suffix to add to path for the test set, or None for no test
set. Default is None.
unlabeled (str): Suffix to add to path for an unlabeled dataset (e.g. for
prediction). Default is None.
fields (list(tuple(str, MatchingField))): A list of tuples containing column
name (e.g. "left_address") and corresponding :class:`~data.MatchingField`
pairs, in the same order that the columns occur in the CSV file. Tuples of
@@ -501,14 +503,14 @@ def splits(cls,
Returns:
Tuple[MatchingDataset]: Datasets for (train, validation, and test) splits in
that order, if provided, or dataset for unlabeled, if provided.
that order, if provided.
"""

fields_dict = dict(fields)
state_args = {'train_pca': train_pca}

datasets = None
if cache and not unlabeled:
if cache:
datafiles = list(f for f in (train, validation, test) if f is not None)
datafiles = [os.path.expanduser(os.path.join(path, d)) for d in datafiles]
cachefile = os.path.expanduser(os.path.join(path, cache))
@@ -531,18 +533,14 @@ def splits(cls,
if not datasets:
begin = timer()
dataset_args = {'fields': fields, 'column_naming': column_naming, **kwargs}
if not unlabeled:
train_data = None if train is None else cls(
path=os.path.join(path, train), **dataset_args)
val_data = None if validation is None else cls(
path=os.path.join(path, validation), **dataset_args)
test_data = None if test is None else cls(
path=os.path.join(path, test), **dataset_args)
datasets = tuple(
d for d in (train_data, val_data, test_data) if d is not None)
else:
datasets = (MatchingDataset(
path=os.path.join(path, unlabeled), **dataset_args),)
train_data = None if train is None else cls(
path=os.path.join(path, train), **dataset_args)
val_data = None if validation is None else cls(
path=os.path.join(path, validation), **dataset_args)
test_data = None if test is None else cls(
path=os.path.join(path, test), **dataset_args)
datasets = tuple(
d for d in (train_data, val_data, test_data) if d is not None)

after_load = timer()
print('Load time:', after_load - begin)
@@ -560,15 +558,28 @@ def splits(cls,
after_metadata = timer()
print('Metadata time:', after_metadata - after_vocab)

if cache and not unlabeled:
if cache:
MatchingDataset.save_cache(datasets, fields_dict, datafiles, cachefile,
column_naming, state_args)
after_cache = timer()
print('Cache time:', after_cache - after_vocab)

if train:

datasets[0].finalize_metadata()

# Save additional information to train dataset.
datasets[0].embeddings = embeddings
datasets[0].embeddings_cache = embeddings_cache
datasets[0].train_pca = train_pca

# Set vocabs.
for dataset in datasets:
dataset.vocabs = {
name: datasets[0].fields[name].vocab
for name in datasets[0].all_text_fields
}

if len(datasets) == 1:
return datasets[0]
return tuple(datasets)
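
A user-visible consequence of the vocab block above is that every returned split now exposes a vocabs mapping that points at the vocab objects built on the training split. A small illustrative check (the column name and CSV names are placeholders):

    import deepmatcher as dm

    train, validation, test = dm.data.process(
        path='data', train='train.csv', validation='validation.csv', test='test.csv')

    # All splits reference the same vocab object for each text column.
    assert validation.vocabs['left_name'] is train.vocabs['left_name']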
@@ -75,7 +75,35 @@ def cache(self, name, cache, url=None):
self.dim = len(self['a'])


class MatchingVocab(vocab.Vocab):

def extend_vectors(self, tokens, vectors):
tot_dim = sum(v.dim for v in vectors)
prev_len = len(self.itos)

new_tokens = []
for token in tokens:
if token not in self.stoi:
self.itos.append(token)
self.stoi[token] = len(self.itos) - 1
new_tokens.append(token)
self.vectors.resize_(len(self.itos), tot_dim)

for i in range(prev_len, prev_len + len(new_tokens)):
token = self.itos[i]
assert token == new_tokens[i - prev_len]

start_dim = 0
for v in vectors:
end_dim = start_dim + v.dim
self.vectors[i][start_dim:end_dim] = v[token.strip()]
start_dim = end_dim
assert (start_dim == tot_dim)


class MatchingField(data.Field):
vocab_cls = MatchingVocab

_cached_vec_data = {}

def __init__(self, tokenize='moses', id=False, **kwargs):
@@ -131,6 +159,30 @@ def build_vocab(self, *args, vectors=None, cache=None, **kwargs):
vectors = MatchingField._get_vector_data(vectors, cache)
super(MatchingField, self).build_vocab(*args, vectors=vectors, **kwargs)

def extend_vocab(self, *args, vectors=None, cache=None):
sources = []
for arg in args:
if isinstance(arg, data.Dataset):
sources += [
getattr(arg, name)
for name, field in arg.fields.items()
if field is self
]
else:
sources.append(arg)

tokens = set()
for source in sources:
for x in source:
if not self.sequential:
tokens.add(x)
else:
tokens.update(x)

if self.vocab.vectors is not None:
vectors = MatchingField._get_vector_data(vectors, cache)
self.vocab.extend_vectors(tokens, vectors)

def numericalize(self, arr, *args, **kwargs):
if not self.is_id:
return super(MatchingField, self).numericalize(arr, *args, **kwargs)
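
This extension hook is what the prediction fix in the commit message relies on: a copy of the training vocab is grown with tokens seen only at prediction time, and their pretrained vectors are fetched so the model's embeddings can be reset against the new vocab. A rough sketch of how it can be exercised (the field variable, token sequences, vector name, and cache path are illustrative assumptions, not code from this commit):

    import copy

    # `train_field` stands for the MatchingField of one text column after training.
    new_token_seqs = [['acme', 'corp'], ['42', 'wallaby', 'way']]  # tokens from unlabeled rows

    field_copy = copy.deepcopy(train_field)    # leave the original training vocab untouched
    field_copy.extend_vocab(new_token_seqs,    # unseen tokens are appended to the copied vocab
                            vectors='fasttext.en.bin',
                            cache='~/.vector_cache')

    # field_copy.vocab.vectors now has rows for the new tokens, so the model's
    # embedding layers can be reset from it and prediction-time words are not
    # all collapsed to <unk>.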
@@ -11,13 +11,17 @@

class MatchingIterator(data.BucketIterator):

def __init__(self, dataset, train_dataset, batch_size, sort_in_buckets=None,
def __init__(self,
dataset,
train_info,
train,
batch_size,
sort_in_buckets=None,
**kwargs):
train = dataset == train_dataset
if sort_in_buckets is None:
sort_in_buckets = train
self.sort_in_buckets = sort_in_buckets
self.train_dataset = train_dataset
self.train_info = train_info
super(MatchingIterator, self).__init__(
dataset, batch_size, train=train, repeat=False, **kwargs)

@@ -39,14 +43,15 @@ def splits(cls, datasets, batch_sizes=None, **kwargs):
for i in range(len(datasets)):
ret.append(
cls(datasets[i],
train_dataset=datasets[0],
train_info=datasets[0],
train=i==0,
batch_size=batch_sizes[i],
**kwargs))
return tuple(ret)

def __iter__(self):
for batch in super(MatchingIterator, self).__iter__():
yield MatchingBatch(batch, self.train_dataset)
yield MatchingBatch(batch, self.train_info)

def create_batches(self):
if self.sort_in_buckets:

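With this change the iterator no longer infers train mode from object identity (dataset == train_dataset); callers state it explicitly and pass the training metadata separately. An illustrative construction, with the dataset variables as placeholders:

    from deepmatcher.data import MatchingIterator

    # splits() marks the first dataset as the training iterator (train=True)
    # and the rest as evaluation iterators.
    train_iter, val_iter = MatchingIterator.splits((train, validation),
                                                   batch_sizes=(16, 32))

    # For prediction, training metadata still comes from the training data,
    # but the iterator itself is explicitly non-train.
    predict_iter = MatchingIterator(unlabeled,
                                    train_info=train,
                                    train=False,
                                    batch_size=32,
                                    sort_in_buckets=False)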