This repository has been archived by the owner on Jan 19, 2019. It is now read-only.

Tests and fixes to new data api #407

Open · wants to merge 20 commits into base: new_data_api

Changes from all commits
1 change: 0 additions & 1 deletion deep_qa/data/__init__.py
@@ -1,5 +1,4 @@
from .dataset import Dataset
from .data_generator import DataGenerator
from .instance import Instance
from .vocabulary import Vocabulary
from .tokenizers import tokenizers
52 changes: 38 additions & 14 deletions deep_qa/data/data_generator.py
@@ -1,12 +1,11 @@
from typing import List
from typing import List, Tuple
import logging
import random
from copy import deepcopy

from ..common.params import Params
from ..common.util import group_by_count
from . import IndexedDataset
from .instances import IndexedInstance
from ..common.util import group_by_count, add_noise_to_dict_values
from . import Dataset, Instance

logger = logging.getLogger(__name__) # pylint: disable=invalid-name

@@ -104,7 +103,7 @@ def __init__(self, text_trainer, params: Params):
#: this data.
self.last_num_batches = None

def create_generator(self, dataset: IndexedDataset, batch_size: int=None):
def create_generator(self, dataset: Dataset, batch_size: int=None):
"""
Main external API call: converts a ``Dataset`` into a data generator suitable for
use with Keras' ``fit_generator`` and related methods.
@@ -122,15 +121,17 @@ def generator():
else:
groups = grouped_instances
for group in groups:
batch = IndexedDataset(group)
batch.pad_instances(self.text_trainer.get_padding_lengths(), verbose=False)
yield batch.as_training_data()
batch = Dataset(group)

yield batch.as_arrays(self.text_trainer.get_padding_lengths(), verbose=False)
return generator()

def __create_batches(self, dataset: IndexedDataset, batch_size: int) -> List[List[IndexedInstance]]:
if self.dynamic_padding:
dataset.sort_by_padding(self.text_trainer.get_instance_sorting_keys(), self.padding_noise)
def __create_batches(self, dataset: Dataset, batch_size: int) -> List[List[Instance]]:
instances = dataset.instances
if self.dynamic_padding:
instances = self.sort_dataset_by_padding(dataset,
self.text_trainer.get_instance_sorting_keys(),
self.padding_noise)
if self.adaptive_batch_sizes:
grouped_instances = self.__adaptive_grouping(instances)
else:
@@ -148,7 +149,7 @@ def __create_batches(self, dataset: Dataset, batch_size: int) -> List[List[Instance]]:
random.shuffle(grouped_instances)
return grouped_instances

def __adaptive_grouping(self, instances: List[IndexedInstance]):
def __adaptive_grouping(self, instances: List[Instance]):
batches = []
current_batch = []
current_lengths = {}
@@ -163,13 +164,36 @@ def __adaptive_grouping(self, instances: List[IndexedInstance]):
or len(current_batch) > self.maximum_batch_size):
current_batch.pop()
if logger.getEffectiveLevel() <= logging.DEBUG:
padding_lengths = IndexedDataset(current_batch).padding_lengths()
padding_lengths = Dataset(current_batch).get_padding_lengths()
logger.debug("Batch size: %d; padding: %s", len(current_batch), padding_lengths)
batches.append(current_batch)
current_batch = [instance]
current_lengths = instance_lengths
if logger.getEffectiveLevel() <= logging.DEBUG:
padding_lengths = IndexedDataset(current_batch).padding_lengths()
padding_lengths = Dataset(current_batch).get_padding_lengths()
logger.debug("Batch size: %d; padding: %s", len(current_batch), padding_lengths)
batches.append(current_batch)
return batches

@staticmethod
def sort_dataset_by_padding(dataset: Dataset,
sorting_keys: List[Tuple[str, str]],
padding_noise: float=0.0) -> List[Instance]:
"""
Sorts the ``Instances`` in this ``Dataset`` by their padding lengths, using the keys in
``sorting_keys`` (in the order in which they are provided). ``sorting_keys`` is a list of
``(field_name, padding_key)`` tuples.
"""
instances_with_lengths = []
for instance in dataset.instances:
padding_lengths = instance.get_padding_lengths()
if padding_noise > 0.0:
noisy_lengths = {}
for field_name, field_lengths in padding_lengths.items():
noisy_lengths[field_name] = add_noise_to_dict_values(field_lengths, padding_noise)
padding_lengths = noisy_lengths
instance_with_lengths = [padding_lengths[field_name][padding_key]
for (field_name, padding_key) in sorting_keys] + [instance]
instances_with_lengths.append(instance_with_lengths)
instances_with_lengths.sort(key=lambda x: x[:-1])
return [instance_with_lengths[-1] for instance_with_lengths in instances_with_lengths]
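The noisy sort introduced above is the core of the dynamic-padding machinery, and it is easy to exercise on its own. A minimal, self-contained sketch follows; ``add_noise_to_dict_values`` is re-implemented here from how it is called, so treat its exact semantics as an approximation of deep_qa's helper, and the field names are invented:

```python
import random
from typing import Dict, List, Tuple

def add_noise_to_dict_values(dictionary: Dict[str, float], noise: float) -> Dict[str, float]:
    # Jitter each length by up to +/- `noise` percent so that instances of
    # similar length end up in different batches across epochs.
    return {key: value * (1.0 + noise * (random.random() * 2 - 1))
            for key, value in dictionary.items()}

# Stand-ins for what Instance.get_padding_lengths() would return.
instance_lengths = [
    {"question": {"num_tokens": 12}},
    {"question": {"num_tokens": 5}},
    {"question": {"num_tokens": 9}},
]
sorting_keys: List[Tuple[str, str]] = [("question", "num_tokens")]

keyed_instances = []
for i, lengths in enumerate(instance_lengths):
    noisy = {field: add_noise_to_dict_values(keys, 0.1)
             for field, keys in lengths.items()}
    keyed_instances.append(([noisy[f][k] for f, k in sorting_keys], i))
keyed_instances.sort(key=lambda pair: pair[0])
print([i for _, i in keyed_instances])  # roughly shortest-first, e.g. [1, 2, 0]
```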
52 changes: 19 additions & 33 deletions deep_qa/data/dataset.py
@@ -5,8 +5,9 @@
import numpy
import tqdm

from ..common.util import add_noise_to_dict_values
from . import Instance, Vocabulary
from .instance import Instance
from .vocabulary import Vocabulary
from ..common.checks import ConfigurationError

logger = logging.getLogger(__name__) # pylint: disable=invalid-name

@@ -22,8 +23,14 @@ def __init__(self, instances: List[Instance]):
A Dataset just takes a list of instances in its constructor. It's important that all
subclasses have an identical constructor to this (though possibly with different Instance
types). If you change the constructor, you also have to override all methods in this base
class that call the constructor, such as `merge()` and `truncate()`.
class that call the constructor, such as `truncate()`.
"""
all_instance_fields_and_types = [{k: v.__class__.__name__ for k, v in x.fields().items()}
for x in instances]
# Check all the field names and Field types are the same for every instance.
if not all([all_instance_fields_and_types[0] == x for x in all_instance_fields_and_types]):
raise ConfigurationError("You cannot construct a Dataset with non-homogeneous Instances.")

self.instances = instances

def truncate(self, max_instances: int):
@@ -43,28 +50,7 @@ def index_instances(self, vocab: Vocabulary):
for instance in tqdm.tqdm(self.instances):
instance.index_fields(vocab)

def sort_by_padding(self, sorting_keys: List[(str, str)], padding_noise: float=0.0):
"""
Sorts the ``Instances`` in this ``Dataset`` by their padding lengths, using the keys in
``sorting_keys`` (in the order in which they are provided). ``sorting_keys`` is a list of
``(field_name, padding_key)`` tuples.
"""
# TODO(matt): this code should probably go into the data generator.
instances_with_lengths = []
for instance in self.instances:
padding_lengths = instance.get_padding_lengths()
if padding_noise > 0.0:
noisy_lengths = {}
for field_name, field_lengths in padding_lengths:
noisy_lengths[field_name] = add_noise_to_dict_values(field_lengths, padding_noise)
padding_lengths = noisy_lengths
instance_with_lengths = [padding_lengths[field_name][padding_key]
for (field_name, padding_key) in sorting_keys] + [instance]
instances_with_lengths.append(instance_with_lengths)
instances_with_lengths.sort(key=lambda x: x[:-1])
self.instances = [instance_with_lengths[-1] for instance_with_lengths in instances_with_lengths]

def padding_lengths(self) -> Dict[str, Dict[str, int]]:
def get_padding_lengths(self) -> Dict[str, Dict[str, int]]:
"""
Gets the maximum padding lengths from all ``Instances`` in this dataset. Each ``Instance``
has multiple ``Fields``, and each ``Field`` could have multiple things that need padding.
@@ -77,7 +63,7 @@ def padding_lengths(self) -> Dict[str, Dict[str, int]]:
padding_lengths = defaultdict(dict)
all_instance_lengths = [instance.get_padding_lengths() for instance in self.instances]
if not all_instance_lengths:
return padding_lengths
return {**padding_lengths}
all_field_lengths = defaultdict(list)
for instance_lengths in all_instance_lengths:
for field_name, instance_field_lengths in instance_lengths.items():
@@ -86,7 +72,7 @@ def padding_lengths(self) -> Dict[str, Dict[str, int]]:
for padding_key in field_lengths[0].keys():
max_value = max(x[padding_key] if padding_key in x else 0 for x in field_lengths)
padding_lengths[field_name][padding_key] = max_value
return padding_lengths
return {**padding_lengths}

def as_arrays(self,
padding_lengths: Dict[str, Dict[str, int]]=None,
@@ -131,14 +117,14 @@ def as_arrays(self,
if verbose:
logger.info("Padding dataset of size %d to lengths %s", len(self.instances), str(padding_lengths))
logger.info("Getting max lengths from instances")
instance_padding_lengths = self.padding_lengths()
instance_padding_lengths = self.get_padding_lengths()
if verbose:
logger.info("Instance max lengths: %s", str(instance_padding_lengths))
lengths_to_use = defaultdict(dict)
for field_name, instance_field_lengths in instance_padding_lengths.items():
for padding_key in instance_field_lengths:
for padding_key in instance_field_lengths.keys():
if padding_lengths[field_name].get(padding_key) is not None:
lengths_to_use[field_name][padding_key] = padding_lengths[padding_key]
lengths_to_use[field_name][padding_key] = padding_lengths[field_name][padding_key]
else:
lengths_to_use[field_name][padding_key] = instance_field_lengths[padding_key]

@@ -147,17 +133,17 @@ def as_arrays(self,
if verbose:
logger.info("Now actually padding instances to length: %s", str(lengths_to_use))
for instance in tqdm.tqdm(self.instances):
for field, arrays in instance.pad(lengths_to_use):
for field, arrays in instance.pad(lengths_to_use).items():
field_arrays[field].append(arrays)
else:
for instance in self.instances:
for field, arrays in instance.pad(lengths_to_use):
for field, arrays in instance.pad(lengths_to_use).items():
field_arrays[field].append(arrays)

# Finally, we combine the arrays that we got for each instance into one big array (or set
# of arrays) per field.
for field_name, field_array_list in field_arrays.items():
if isinstance(field_array_list[0], [list, tuple]):
if isinstance(field_array_list[0], (list, tuple)):
field_arrays[field_name] = [numpy.asarray(x) for x in zip(*field_array_list)]
else:
field_arrays[field_name] = numpy.asarray(field_array_list)
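Before the next file: the renamed ``get_padding_lengths`` computes a per-field, per-key maximum over all instances, which is what ``as_arrays`` then pads to. A self-contained sketch of that max-merge (the field names are invented for illustration):

```python
from collections import defaultdict

# What instance.get_padding_lengths() might return for two instances.
all_instance_lengths = [
    {"question": {"num_tokens": 7},  "passage": {"num_tokens": 80}},
    {"question": {"num_tokens": 12}, "passage": {"num_tokens": 55}},
]

padding_lengths = defaultdict(dict)
all_field_lengths = defaultdict(list)
for instance_lengths in all_instance_lengths:
    for field_name, field_lengths in instance_lengths.items():
        all_field_lengths[field_name].append(field_lengths)
for field_name, field_lengths in all_field_lengths.items():
    for padding_key in field_lengths[0].keys():
        # A missing key counts as 0, so a short instance can't shrink the max.
        max_value = max(x.get(padding_key, 0) for x in field_lengths)
        padding_lengths[field_name][padding_key] = max_value

print({**padding_lengths})
# {'question': {'num_tokens': 12}, 'passage': {'num_tokens': 80}}
```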
2 changes: 1 addition & 1 deletion deep_qa/data/dataset_readers/dataset_reader.py
@@ -1,4 +1,3 @@
from . import dataset_readers
from .. import Dataset
from ...common import Params

@@ -17,5 +16,6 @@ def read(self) -> Dataset:

@staticmethod
def from_params(params: Params):
from . import dataset_readers
choice = params.pop_choice('type', list(dataset_readers.keys()))
return dataset_readers[choice].from_params(params)
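Moving ``from . import dataset_readers`` inside ``from_params`` defers the registry lookup to call time, which is a common way to sidestep a circular import between a package and its registry module. The dispatch itself is small; a sketch with a stand-in registry (``Params.pop_choice`` in deep_qa also validates the popped value against the allowed keys, which plain ``pop`` does not):

```python
from typing import Any, Callable, Dict

# Stand-in registry: the real one maps type names to DatasetReader classes.
dataset_readers: Dict[str, Callable[[Dict[str, Any]], Any]] = {
    "language_modeling": lambda params: ("LanguageModelingReader", params),
    "snli": lambda params: ("SnliReader", params),
}

def from_params(params: Dict[str, Any]) -> Any:
    # Pop rather than read, so leftover keys can later be flagged as unused.
    choice = params.pop("type")
    if choice not in dataset_readers:
        raise ValueError("%r not in %s" % (choice, list(dataset_readers.keys())))
    return dataset_readers[choice](params)

print(from_params({"type": "snli", "file_path": "train.jsonl"}))
```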
1 change: 1 addition & 0 deletions deep_qa/data/dataset_readers/language_modeling.py
@@ -10,6 +10,7 @@
from ..token_indexers import TokenIndexer, SingleIdTokenIndexer
from ..tokenizers import Tokenizer, WordTokenizer


class LanguageModelingReader(DatasetReader):
"""
Reads a text file and converts it into a ``Dataset`` suitable for training a language model.
8 changes: 7 additions & 1 deletion deep_qa/data/fields/index_field.py
@@ -3,7 +3,8 @@
from overrides import overrides
import numpy

from . import Field, SequenceField
from .field import Field
from .sequence_field import SequenceField


class IndexField(Field):
@@ -31,3 +32,8 @@ def pad(self, padding_lengths: Dict[str, int]) -> List[numpy.array]:
@overrides
def empty_field(self):
return IndexField(0, None)

def sequence_index(self):
# This method can't be called ``index``, as the ``Field`` API
# already reserves that name.
return self._index
19 changes: 15 additions & 4 deletions deep_qa/data/fields/label_field.py
@@ -1,10 +1,14 @@
from typing import Dict, List, Union
import logging

from overrides import overrides
import numpy

from . import Field
from .. import Vocabulary
from .field import Field
from ..vocabulary import Vocabulary

logger = logging.getLogger(__name__) # pylint: disable=invalid-name


class LabelField(Field):
"""
@@ -26,17 +30,21 @@ class LabelField(Field):
"""
def __init__(self,
label: Union[str, int],
label_namespace: str='labels',
label_namespace: str='*labels',
num_labels: int=None):
self._label = label
self._label_namespace = label_namespace
if num_labels is None:
self._label_id = None
self._num_labels = None
if not self._label_namespace.startswith("*"):
logger.warning("The namespace of your tag (%s) does not begin with *,"
" meaning the vocabulary namespace will contain UNK "
"and PAD tokens by default.", self._label_namespace)
else:
assert isinstance(label, int), "Labels must be ints if you want to skip indexing"
self._label_id = label
self.num_labels = num_labels
self._num_labels = num_labels

@overrides
def count_vocab_items(self, counter: Dict[str, Dict[str, int]]):
@@ -62,3 +70,6 @@ def pad(self, padding_lengths: Dict[str, int]) -> List[numpy.array]:
@overrides
def empty_field(self):
return LabelField(0, self._label_namespace)

def label(self):
return self._label
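The new ``*`` prefix on the default namespace is doing real work here: per the warning above, a namespace that does not begin with ``*`` gets UNK and PAD tokens added by the ``Vocabulary``, while starred namespaces skip them. A toy version of just that check (nothing here is imported from deep_qa):

```python
import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("namespace_check")

def check_label_namespace(namespace: str) -> None:
    # Mirrors the constructor warning above: non-starred namespaces will
    # contain UNK and PAD tokens by default.
    if not namespace.startswith("*"):
        logger.warning("The namespace of your label (%s) does not begin with *, "
                       "meaning it will contain UNK and PAD tokens by default.",
                       namespace)

check_label_namespace("*labels")  # silent: UNK/PAD are skipped
check_label_namespace("labels")   # emits the warning
```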
12 changes: 8 additions & 4 deletions deep_qa/data/fields/list_field.py
@@ -3,8 +3,9 @@
from overrides import overrides
import numpy

from . import Field, SequenceField
from .. import Vocabulary
from .field import Field
from .sequence_field import SequenceField
from ..vocabulary import Vocabulary
from ...common.util import pad_sequence_to_length


@@ -16,7 +17,7 @@ class ListField(SequenceField):
def __init__(self, field_list: List[Field]):
field_class_set = set([field.__class__ for field in field_list])
assert len(field_class_set) == 1, "ListFields must contain a single field type, found " +\
str(field_class_set)
str(field_class_set)
self._field_list = field_list

@overrides
@@ -47,7 +48,7 @@ def pad(self, padding_lengths: Dict[str, int]) -> List[numpy.array]:
padding_lengths['num_fields'],
self._field_list[0].empty_field)
padded_fields = [field.pad(padding_lengths) for field in padded_field_list]
if isinstance(padded_fields[0], [list, tuple]):
if isinstance(padded_fields[0], (list, tuple)):
return [numpy.asarray(x) for x in zip(*padded_fields)]
else:
return [numpy.asarray(padded_fields)]
@@ -56,3 +57,6 @@ def pad(self, padding_lengths: Dict[str, int]) -> List[numpy.array]:
def empty_field(self):
raise RuntimeError("Nested ListFields are not implemented, and if you want this "
"you should probably try to simplify your data type, anyway")

def fields(self):
return self._field_list
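``ListField.pad`` pads along two axes: the list itself is brought to ``num_fields`` entries using ``empty_field()`` as a factory, and then every contained field pads itself as usual. A rough standalone sketch, with plain lists of token ids standing in for ``Field`` objects and the ``pad_sequence_to_length`` signature inferred from its call site above:

```python
from typing import Callable, List
import numpy

def pad_sequence_to_length(sequence: List, desired_length: int,
                           default_value: Callable) -> List:
    # Truncate, or right-pad with "empty" elements produced by the factory.
    padded = list(sequence[:desired_length])
    while len(padded) < desired_length:
        padded.append(default_value())
    return padded

fields = [[1, 2], [3, 4, 5], [6]]  # toy token-id lists
empty_field = list                 # factory producing an empty "field"

padded_list = pad_sequence_to_length(fields, 4, empty_field)
# Each "field" then pads itself to a common inner length of 3.
padded_fields = [numpy.asarray(f + [0] * (3 - len(f))) for f in padded_list]
print(numpy.asarray(padded_fields).shape)  # (4, 3): num_fields x field length
```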
2 changes: 1 addition & 1 deletion deep_qa/data/fields/sequence_field.py
@@ -1,4 +1,4 @@
from . import Field
from .field import Field


class SequenceField(Field):
24 changes: 19 additions & 5 deletions deep_qa/data/fields/tag_field.py
@@ -1,11 +1,16 @@
from typing import Dict, List
import logging

from overrides import overrides
import numpy

from . import Field, SequenceField
from .. import Vocabulary
from .field import Field
from .sequence_field import SequenceField
from ..vocabulary import Vocabulary
from ...common.util import pad_sequence_to_length
from ...common.checks import ConfigurationError

logger = logging.getLogger(__name__) # pylint: disable=invalid-name


class TagField(Field):
@@ -14,14 +19,20 @@ class TagField(Field):
Because it's a labeling of some other field, we take that field as input here, and we use it to
determine our padding and other things.
"""
def __init__(self, tags: List[str], sequence_field: SequenceField, tag_namespace: str='tags'):
def __init__(self, tags: List[str], sequence_field: SequenceField, tag_namespace: str='*tags'):
self._tags = tags
self._sequence_field = sequence_field
self._tag_namespace = tag_namespace
self._indexed_tags = None
self._num_tags = None
assert len(tags) == sequence_field.sequence_length(), "Tag length and sequence length " +\
"don't match: %d and %d" % (len(tags), sequence_field.sequence_length())

if not self._tag_namespace.startswith("*"):
logger.warning("The namespace of your tag (%s) does not begin with *, meaning the vocabulary "
"namespace will contain UNK and PAD tokens by default.", self._tag_namespace)

if len(tags) != sequence_field.sequence_length():
raise ConfigurationError("Tag length and sequence length "
"don't match: %d and %d" % (len(tags), sequence_field.sequence_length()))

@overrides
def count_vocab_items(self, counter: Dict[str, Dict[str, int]]):
@@ -54,3 +65,6 @@ def empty_field(self):
tag_field = TagField([], None)
tag_field._indexed_tags = []
return tag_field

def tags(self):
return self._tags
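Finally, a tiny illustration of the invariant that the new ``ConfigurationError`` enforces: a ``TagField`` must supply exactly one tag per element of the sequence it labels. ``ToySequenceField`` below is a made-up stand-in exposing only the ``sequence_length()`` API used above:

```python
from typing import List

class ToySequenceField:
    def __init__(self, tokens: List[str]):
        self.tokens = tokens

    def sequence_length(self) -> int:
        return len(self.tokens)

def validate_tags(tags: List[str], sequence_field: ToySequenceField) -> None:
    if len(tags) != sequence_field.sequence_length():
        raise ValueError("Tag length and sequence length don't match: "
                         "%d and %d" % (len(tags), sequence_field.sequence_length()))

sentence = ToySequenceField(["John", "lives", "here"])
validate_tags(["NNP", "VBZ", "RB"], sentence)   # passes silently
# validate_tags(["NNP"], sentence) would raise: "1 and 3"
```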