
Commit

changed PairedDataProvider to ParallelDataProvider to allow >2 datasets
Former-commit-id: 5298e2c
ZhitingHu committed Nov 25, 2017
1 parent 13e2a3a commit 0a07243
Showing 6 changed files with 116 additions and 76 deletions.
113 changes: 64 additions & 49 deletions txtgen/data/databases/data_providers.py
@@ -13,19 +13,27 @@
import tensorflow.contrib.slim as tf_slim
from tensorflow.contrib.slim.python.slim.data import parallel_reader

# pylint: disable=too-many-arguments, too-many-locals

__all__ = [
"PairedDataProvider"
"ParallelDataProvider"
]

# pylint: disable=too-many-arguments, too-many-locals
class PairedDataProvider(tf_slim.data_provider.DataProvider):
"""A DataProvider that reads two aligned datasets.
class ParallelDataProvider(tf_slim.data_provider.DataProvider):
"""A DataProvider that reads multiple aligned datasets.
Args:
dataset1 (Dataset): The first dataset.
dataset2 (Dataset): The second dataset.
reader_kwargs1 (dict, optional): Keyword args for dataset1 reader.
reader_kwargs2 (dict, optional): Keyword args for dataset2 reader.
datasets: A list of :class:`Dataset` instances. The provider reads
one element from each of the datasets at each step.
reader_kwargs (optional): A list of dictionaries or `None`. Each
dictionary contains keyword arguments for the reader of the
respective dataset in :attr:`datasets`. If not `None`,
:attr:`reader_kwargs` must have the same length as
:attr:`datasets`.
dtypes (list, optional): Types of the data in each of the datasets.
If `None` (default), types of all datasets are assumed to be
`tf.string`. If not `None`, :attr:`dtypes` must have the same
length as :attr:`datasets`.
shuffle (bool): Whether to shuffle the data sources and common queue
when reading.
num_epochs (int, optional): The number of times each data source is
@@ -39,64 +47,71 @@ class PairedDataProvider(tf_slim.data_provider.DataProvider):
"""

def __init__(self,
dataset1,
dataset2,
reader_kwargs1=None,
reader_kwargs2=None,
datasets,
reader_kwargs=None,
dtypes=None,
shuffle=True,
num_epochs=None,
common_queue_capacity=1024,
common_queue_min=526,
seed=None,
scope=None):
scope = scope or "paired_data_provider"
scope = scope or "parallel_data_provider"

if not isinstance(datasets, list) or len(datasets) < 2:
raise ValueError("`datasets` must be a list of length >= 2.")

_, data1 = parallel_reader.parallel_read(
dataset1.data_sources,
reader_class=dataset1.reader,
num_epochs=num_epochs,
num_readers=1,
# Use one reader to ensure aligned source-target data
reader_kwargs=reader_kwargs1,
shuffle=False,
capacity=common_queue_capacity,
min_after_dequeue=common_queue_min,
scope=scope)
if reader_kwargs is None:
reader_kwargs = [None for _ in range(len(datasets))]
elif not isinstance(reader_kwargs, list) or \
len(reader_kwargs) != len(datasets):
raise ValueError(
"If `reader_kwargs` is not `None`, it must be a list of the "
"same length as `datasets`.")

_, data2 = parallel_reader.parallel_read(
dataset2.data_sources,
reader_class=dataset2.reader,
num_epochs=num_epochs,
num_readers=1,
# Use one reader to ensure aligned source-target data
reader_kwargs=reader_kwargs2,
shuffle=False,
capacity=common_queue_capacity,
min_after_dequeue=common_queue_min,
scope=scope)
if dtypes is None:
dtypes = [tf.string for _ in range(len(datasets))]
elif not isinstance(dtypes, list) or len(dtypes) != len(datasets):
raise ValueError(
"If `dtypes` is not `None`, it must be a list of the "
"same length as `datasets`.")

data_list = []
for dataset, kwargs in zip(datasets, reader_kwargs):
_, data = parallel_reader.parallel_read(
dataset.data_sources,
reader_class=dataset.reader,
num_epochs=num_epochs,
num_readers=1,
# Use a single reader per dataset so records stay aligned across datasets
reader_kwargs=kwargs,
shuffle=False,
capacity=common_queue_capacity,
min_after_dequeue=common_queue_min,
scope=scope)
data_list.append(data)

if shuffle:
with tf.name_scope(scope): # pylint: disable=not-context-manager
random_shuffle_queue = tf.RandomShuffleQueue(
capacity=common_queue_capacity,
min_after_dequeue=common_queue_min,
dtypes=[tf.string, tf.string],
dtypes=dtypes,
seed=seed,
name="shuffle_queue")
enqueue_ops = [random_shuffle_queue.enqueue([data1, data2])]
enqueue_ops = [random_shuffle_queue.enqueue(data_list)]
queue_runner.add_queue_runner(
queue_runner.QueueRunner(random_shuffle_queue, enqueue_ops))
data1, data2 = random_shuffle_queue.dequeue()

items1 = dataset1.decoder.list_items()
tensors1 = dataset1.decoder.decode(data1, items1)

items2 = dataset2.decoder.list_items()
tensors2 = dataset2.decoder.decode(data2, items2)
data_list = random_shuffle_queue.dequeue()

items = items1 + items2
tensors = tensors1 + tensors2
items_list = []
tensors_list = []
for dataset, data in zip(datasets, data_list):
items = dataset.decoder.list_items()
tensors = dataset.decoder.decode(data, items)
items_list += items
tensors_list += tensors

super(PairedDataProvider, self).__init__(
items_to_tensors=dict(zip(items, tensors)),
num_samples=dataset1.num_samples)
super(ParallelDataProvider, self).__init__(
items_to_tensors=dict(zip(items_list, tensors_list)),
num_samples=datasets[0].num_samples)
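
The new constructor takes a list of datasets, reads one record from each with a single reader, and pushes the tuple through one shuffle queue so the datasets stay aligned. A minimal usage sketch, assuming src_dataset, tgt_dataset, and ctx_dataset are hypothetical, already-built tf.contrib.slim Dataset objects (txtgen normally constructs these internally from its hparams):

from txtgen.data.databases.data_providers import ParallelDataProvider

# src_dataset, tgt_dataset, ctx_dataset: hypothetical pre-built
# tf.contrib.slim Dataset objects (data_sources, reader, decoder, num_samples).
provider = ParallelDataProvider(
    datasets=[src_dataset, tgt_dataset, ctx_dataset],
    reader_kwargs=None,   # defaults to one `None` per dataset
    dtypes=None,          # defaults to tf.string for every dataset
    shuffle=True,
    num_epochs=1,
    common_queue_capacity=1024,
    common_queue_min=526)

# One aligned element per dataset per read; item names come from each
# dataset's decoder.
tensors = provider.get(provider.list_items())
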
22 changes: 16 additions & 6 deletions txtgen/data/databases/database_test.py
@@ -10,13 +10,17 @@
from __future__ import unicode_literals

import tempfile
import numpy as np

import tensorflow as tf

from txtgen.data.databases.mono_text_database import MonoTextDataBase
from txtgen.data.databases.paired_text_database import PairedTextDataBase
from txtgen.data.databases.multi_source_text_database import \
MultiSourceTextDataBase

# pylint: disable=too-many-locals

class TextDataBaseTest(tf.test.TestCase):
"""Tests text database class.
"""
@@ -37,8 +41,8 @@ def test_mono_text_database(self):

# Construct database
hparams = {
"num_epochs": 2,
"batch_size": 2,
"num_epochs": 5,
"batch_size": 3,
"dataset": {
"files": [text_file.name],
"vocab_file": vocab_file.name,
@@ -75,24 +79,24 @@ def test_paired_text_database(self):
"""Tests the logics of PairedTextDataBase.
"""
# Create test data
vocab_list = ['word', '词']
vocab_list = ['This', 'is', 'a', 'word', '词']
vocab_file = tempfile.NamedTemporaryFile()
vocab_file.write('\n'.join(vocab_list).encode("utf-8"))
vocab_file.flush()

src_text = ['This is a source sentence .', 'source: 词 词 。']
src_text = ['This is a sentence from source .', '词 词 。 source']
src_text_file = tempfile.NamedTemporaryFile()
src_text_file.write('\n'.join(src_text).encode("utf-8"))
src_text_file.flush()

tgt_text = ['This is a target sentence .', 'target: 词 词 。']
tgt_text = ['This is a sentence from target .', '词 词 。 target']
tgt_text_file = tempfile.NamedTemporaryFile()
tgt_text_file.write('\n'.join(tgt_text).encode("utf-8"))
tgt_text_file.flush()

# Construct database
hparams = {
"num_epochs": 3,
"num_epochs": 100,
"batch_size": 3,
"source_dataset": {
"files": [src_text_file.name],
@@ -134,6 +138,12 @@ def test_paired_text_database(self):
self.assertEqual(text_database.target_vocab.vocab_size,
len(vocab_list) + 4)

src_text = data['source_text']
tgt_text = data['target_text']
for src, tgt in zip(src_text, tgt_text):
np.testing.assert_array_equal(
src[:3], tgt[1:4])

except tf.errors.OutOfRangeError:
print('Done -- epoch limit reached')
finally:
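The new assertion at the end of test_paired_text_database checks that source and target ids stay aligned: both test sentence pairs start with the same tokens, and, assuming the target side is decorated with a begin-of-sequence token while the source side is not, the target ids are shifted right by one position. A schematic illustration (token strings stand in for the integer ids the test compares; the <BOS> marker is an assumption):

src_ids = ["This", "is", "a", "sentence"]
tgt_ids = ["<BOS>", "This", "is", "a"]
assert src_ids[:3] == tgt_ids[1:4]
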
10 changes: 4 additions & 6 deletions txtgen/data/databases/paired_text_database.py
@@ -16,7 +16,7 @@
from txtgen.data.databases.database_base import DataBaseBase
from txtgen.data.databases import mono_text_database
from txtgen.data.databases.text_data_decoder import TextDataDecoder
from txtgen.data.databases.data_providers import PairedDataProvider
from txtgen.data.databases.data_providers import ParallelDataProvider
from txtgen.data.vocabulary import Vocab
from txtgen.data.embedding import Embedding

@@ -162,11 +162,9 @@ def _make_data_provider(self, src_dataset, tgt_dataset):
tgt_reader_kwargs = \
self._hparams.target_dataset["reader"]["kwargs"].todict()

data_provider = PairedDataProvider(
dataset1=src_dataset,
dataset2=tgt_dataset,
reader_kwargs1=src_reader_kwargs,
reader_kwargs2=tgt_reader_kwargs,
data_provider = ParallelDataProvider(
datasets=[src_dataset, tgt_dataset],
reader_kwargs=[src_reader_kwargs, tgt_reader_kwargs],
shuffle=self._hparams.shuffle,
num_epochs=self._hparams.num_epochs,
common_queue_capacity=1024,
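Because the provider now takes a list, going beyond a source/target pair only means passing a longer list. A hedged sketch of a three-dataset call (the ctx_dataset and ctx_reader_kwargs names are illustrative and not part of this commit; the rest mirrors the call above):

data_provider = ParallelDataProvider(
    datasets=[src_dataset, tgt_dataset, ctx_dataset],
    reader_kwargs=[src_reader_kwargs, tgt_reader_kwargs, ctx_reader_kwargs],
    shuffle=self._hparams.shuffle,
    num_epochs=self._hparams.num_epochs,
    common_queue_capacity=1024,
    common_queue_min=526)
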
3 changes: 2 additions & 1 deletion txtgen/data/embedding_test.py
@@ -10,9 +10,10 @@
from __future__ import unicode_literals

import tempfile
import tensorflow as tf
import numpy as np

import tensorflow as tf

from txtgen.data import embedding


32 changes: 21 additions & 11 deletions txtgen/data/vocabulary.py
@@ -85,25 +85,26 @@ def load(self, filename):
vocab = list(line.strip() for line in vocab_file)

if self._bos_token in vocab:
raise ValueError("Special token already exists in the "
"vocabulary %s" % self._bos_token)
raise ValueError("Special begin-of-seq token already exists in the "
"vocabulary: '%s'" % self._bos_token)
if self._eos_token in vocab:
raise ValueError("Special token already exists in the "
"vocabulary %s" % self._eos_token)
raise ValueError("Special end-of-seq token already exists in the "
"vocabulary: '%s'" % self._eos_token)
if self._unk_token in vocab:
raise ValueError("Special token already exists in the "
"vocabulary %s" % self._unk_token)
raise ValueError("Special UNK token already exists in the "
"vocabulary: '%s'" % self._unk_token)
if self._padding_token in vocab:
raise ValueError("Special padding token already exists in the "
"vocabulary %s, it is an empty token by default"
"vocabulary: '%s', it is an empty token by default"
% self._padding_token)

# Placing _padding_token at the beginning to make sure it take index 0.
# Places _padding_token at the beginning to make sure it takes index 0.
vocab = [self._padding_token, self._bos_token, self._eos_token,
self._unk_token] + vocab
# Must make sure this is consistent with the above line
unk_token_idx = 3
vocab_size = len(vocab)
vocab_idx = np.arange(vocab_size)
unk_token_idx = vocab_size - 1

# Creates TF maps
id_to_token_map = tf.contrib.lookup.HashTable(
@@ -177,8 +178,17 @@ def unk_token(self):
"""
return self._unk_token

@property
def padding_token(self):
"""A string of the special token indicating padding token. The
default padding token is an empty string.
"""
return self._padding_token

@property
def special_tokens(self):
"""The list of special tokens :attr:`[bos_token, eos_token, unk_token]`.
"""The list of special tokens
:attr:`[padding_token, bos_token, eos_token, unk_token]`.
"""
return [self._bos_token, self._eos_token, self._unk_token]
return [self._padding_token, self._bos_token, self._eos_token,
self._unk_token]
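With the new ordering, the padding token is pinned to id 0 and the UNK token to id 3, independent of the user vocabulary size. A small self-check of that layout (the special token strings below are illustrative assumptions; only the ordering, the fixed unk_token_idx = 3, and the "+ 4" size come from the code and tests above):

padding_token, bos_token, eos_token, unk_token = "", "<BOS>", "<EOS>", "<UNK>"
user_vocab = ["word", "词"]

vocab = [padding_token, bos_token, eos_token, unk_token] + user_vocab
assert vocab.index(padding_token) == 0    # padding takes index 0
assert vocab.index(unk_token) == 3        # matches unk_token_idx = 3
assert len(vocab) == len(user_vocab) + 4  # matches the vocab_size checks in the tests
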
12 changes: 9 additions & 3 deletions txtgen/data/vocabulary_test.py
@@ -14,6 +14,7 @@

from txtgen.data import vocabulary

# pylint: disable=protected-access

class VocabularyTest(tf.test.TestCase):
"""Tests vocabulary related operations.
@@ -26,15 +27,15 @@ def test_make_defaultdict(self):
values = [0, 1]
default_value = -1

dict_ = vocabulary._make_defaultdict(keys, values, default_value) # pylint: disable=protected-access
dict_ = vocabulary._make_defaultdict(keys, values, default_value)

self.assertEqual(len(dict_), 2)
self.assertEqual(dict_['word'], 0)
self.assertEqual(dict_['词'], 1)
self.assertEqual(dict_['sth_else'], -1)

def test_vocab_load(self):
"""Test vocabulary load function.
def test_vocab_construction(self):
"""Test vocabulary construction.
"""
vocab_list = ['word', '词']
vocab_file = tempfile.NamedTemporaryFile()
@@ -48,6 +49,11 @@ def test_vocab_load(self):
set(vocab.token_to_id_map_py.keys()),
set(['word', '词'.encode('utf8')] + vocab.special_tokens))

# Tests UNK token
unk_token_id = vocab.token_to_id_map_py['new']
unk_token_text = vocab.id_to_token_map_py[unk_token_id]
self.assertEqual(unk_token_text, vocab.unk_token)


if __name__ == "__main__":
tf.test.main()
