Commit
updating data module
Former-commit-id: 9b4161d
ZhitingHu committed Mar 15, 2018
1 parent ce1e5a8 commit 106ac50
Showing 19 changed files with 104 additions and 287 deletions.
4 changes: 2 additions & 2 deletions examples/cvae_dialog.py
@@ -11,7 +11,7 @@
 import tensorflow as tf

 # We shall wrap all these modules
-from texar.data import MultiSourceTextDataBase
+from texar.data import qMultiSourceTextData
 from texar.modules import ForwardConnector
 from texar.modules import BasicRNNDecoder, get_helper
 from texar.modules import HierarchicalEncoder
@@ -48,7 +48,7 @@
 }

 # Construct the database
-dialog_db = MultiSourceTextDataBase(data_hparams)
+dialog_db = qMultiSourceTextData(data_hparams)
 data_batch = dialog_db()

 # builder encoder
6 changes: 2 additions & 4 deletions examples/language_model.py
@@ -23,7 +23,7 @@
 np.random.seed(rseed)

 # We shall wrap all these modules
-from texar.data import MonoTextDataBase
+from texar.data import qMonoTextData
 from texar.modules import ConstantConnector
 from texar.modules import BasicRNNDecoder, get_helper
 from texar.losses import mle_losses
@@ -47,7 +47,7 @@ def load_data():
 }
 }
 # Construct the database
-text_db = MonoTextDataBase(data_hparams)
+text_db = qMonoTextData(data_hparams)
 # Get data minibatch, which is a dictionary:
 # {
 # "text": text_tensor, # text string minibatch,
@@ -86,7 +86,6 @@ def train():
 outputs, final_state, sequence_lengths = decoder(
 helper=helper_train, initial_state=connector(batch_size))

-print('decoder done')
 # Build loss
 mle_loss = mle_losses.average_sequence_sparse_softmax_cross_entropy(
 labels=data_batch['text_ids'][:, 1:],
@@ -115,7 +114,6 @@ def train():
 sess.run(tf.tables_initializer())

 tf.summary.FileWriter('language_models', sess.graph)
-exit()
 coord = tf.train.Coordinator()
 threads = tf.train.start_queue_runners(sess=sess, coord=coord)

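Taken together, the language-model example builds the dataset from hparams, calls it to obtain a batch dictionary, and drains the input queues with a coordinator. Below is a minimal sketch of that flow, assuming only the interface visible in this diff (qMonoTextData(data_hparams), the callable returning a dict with "text"/"text_ids", and the Coordinator/start_queue_runners calls); the hparams contents are placeholders, not the repository's real config.

import tensorflow as tf
from texar.data import qMonoTextData

# Sketch only: the hparams keys are placeholders, not the repository's config.
data_hparams = {
    # ... dataset file, vocabulary, batch size, etc.
}

text_db = qMonoTextData(data_hparams)
data_batch = text_db()  # dict of batched tensors, e.g. "text", "text_ids"

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    sess.run(tf.tables_initializer())

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        while not coord.should_stop():
            batch = sess.run(data_batch)  # dequeues one minibatch
    except tf.errors.OutOfRangeError:
        pass  # input queues exhausted
    finally:
        coord.request_stop()
        coord.join(threads)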
4 changes: 2 additions & 2 deletions examples/transformer/transformer.py
@@ -12,7 +12,7 @@
 import numpy as np
 import tensorflow as tf
 import logging
-from texar.data import PairedTextDataBase
+from texar.data import qPairedTextData
 from texar.core.utils import _bucket_boundaries
 from texar.modules import TransformerEncoder, TransformerDecoder
 from texar.losses import mle_losses
@@ -101,7 +101,7 @@ def config_logging(filepath):
 },
 }
 # Construct the database
-text_database = PairedTextDataBase(data_hparams)
+text_database = qPairedTextData(data_hparams)
 text_data_batch = text_database()
 ori_src_text = text_data_batch['source_text_ids']
 ori_tgt_text = text_data_batch['target_text_ids']
1 change: 0 additions & 1 deletion examples/transformer/transformer_dutil.py
@@ -9,7 +9,6 @@
 import random
 import numpy as np
 import tensorflow as tf
-from texar.data import PairedTextDataBase
 from texar.modules import TransformerEncoder, TransformerDecoder
 from texar.losses import mle_losses
 from texar.core import optimization as opt
1 change: 1 addition & 0 deletions
@@ -0,0 +1 @@
+f8459f7232cbe71d09081303f5d04e9dc7f6f21c
3 changes: 2 additions & 1 deletion texar/data/__init__.py
@@ -9,7 +9,8 @@

 # pylint: disable=wildcard-import

-from texar.data.databases import *
+#from texar.data.databases import *
+from texar.data.q_data import *
 from texar.data.vocabulary import *
 from texar.data.embedding import *
 from texar.data.constants import *
1 change: 1 addition & 0 deletions texar/data/constants.py
@@ -1,3 +1,4 @@
+#
 """
 Define a set of constants used to create default data readers. Most can be
 overwritten in the form of hparams
12 changes: 12 additions & 0 deletions texar/data/data/__init__.py
@@ -0,0 +1,12 @@
+#
+"""
+Modules of texar library data inputs.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=wildcard-import
+
+#from texar.data.data_base import *
19 changes: 0 additions & 19 deletions texar/data/databases/__init__.py

This file was deleted.

201 changes: 0 additions & 201 deletions texar/data/databases/mono_text_database.py

This file was deleted.

18 changes: 18 additions & 0 deletions texar/data/q_data/__init__.py
@@ -0,0 +1,18 @@
+#
+"""
+Modules of texar library queue-based data inputs.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=wildcard-import
+
+from texar.data.q_data.q_data_base import *
+from texar.data.q_data.q_mono_text_data import *
+from texar.data.q_data.q_paired_text_data import *
+from texar.data.q_data.q_multi_source_text_data import *
+from texar.data.q_data.q_multi_aligned_data import *
+from texar.data.q_data.q_data_providers import *
+from texar.data.q_data.q_data_decoders import *
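"Queue-based" here refers to the TensorFlow 1.x input pipeline driven by queue runners, the same Coordinator/start_queue_runners pattern used in the example scripts above. For orientation only, a self-contained sketch of that pattern in plain TF 1.x, independent of texar; the toy data is illustrative:

import numpy as np
import tensorflow as tf

# Minimal TF 1.x queue-runner pipeline: slice a tensor into examples,
# batch them through a queue, and drain the queue in a session.
data = np.arange(20, dtype=np.int32).reshape(10, 2)
example = tf.train.slice_input_producer([tf.constant(data)],
                                        num_epochs=1, shuffle=False)
batch = tf.train.batch(example, batch_size=4,
                       allow_smaller_final_batch=True)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())  # num_epochs uses a local counter
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        while not coord.should_stop():
            print(sess.run(batch))              # one minibatch per run call
    except tf.errors.OutOfRangeError:
        pass                                    # queues exhausted after one epoch
    finally:
        coord.request_stop()
        coord.join(threads)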
@@ -10,13 +10,15 @@
 from __future__ import print_function
 from __future__ import unicode_literals

+# pylint: disable=invalid-name
+
 from texar.hyperparams import HParams

 __all__ = [
-"DataBaseBase"
+"qDataBase"
 ]

-class DataBaseBase(object):
+class qDataBase(object):
 """Base class of all data classes.
 """

File renamed without changes.
File renamed without changes.
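Since qDataBase is re-exported through the wildcard imports added in texar/data/q_data/__init__.py and texar/data/__init__.py, downstream code would refer to the base class by its new name. A hedged sketch; the subclass name and body are illustrative, as qDataBase's interface beyond the class name is not shown in this diff:

from texar.data import qDataBase  # formerly DataBaseBase

class MyCustomData(qDataBase):
    """Illustrative placeholder subclass; the real base-class API is not shown here."""
    pass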
