Fix issues (#177)
* Fix bug in MultiAlignedData for ScalarData

* Expand types in ScalarData

* Improve docs for Conv nets

* Address review comments
AvinashBukkittu authored and gpengzhi committed Aug 28, 2019
1 parent dfe0332 commit 4ab7b7a
Showing 4 changed files with 39 additions and 29 deletions.
3 changes: 2 additions & 1 deletion texar/torch/data/data/multi_aligned_data.py
@@ -236,10 +236,11 @@ def _get_filter(max_seq_length):
compression_type=hparams_i.compression_type)
sources.append(source_i)
filters.append(None)
self._names.append({"label": hparams_i.data_name})
self._names.append({"data": hparams_i.data_name})

dataset_hparams = dict_fetch(
hparams_i, ScalarData.default_hparams()["dataset"])
dataset_hparams["data_name"] = "data"
self._databases.append(ScalarData(
hparams={"dataset": dataset_hparams}, device=device,
data_source=dummy_source))
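
For illustration only (not part of this commit): a minimal sketch of configuring a scalar field alongside a text field in MultiAlignedData. The file names, vocabulary file, and the "label" data_name are hypothetical.

    from texar.torch.data import DataIterator, MultiAlignedData

    hparams = {
        "batch_size": 2,
        "datasets": [
            # A text field (hypothetical files).
            {"files": "sentences.txt", "vocab_file": "vocab.txt",
             "data_name": "sentence"},
            # A scalar field; with the fix above, each batch exposes it under
            # the user-specified data_name instead of an internal name.
            {"files": "labels.txt", "data_type": "int", "data_name": "label"},
        ],
    }
    data = MultiAlignedData(hparams)
    batch = next(iter(DataIterator(data)))
    labels = batch["label"]  # an int tensor, one entry per example
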
24 changes: 8 additions & 16 deletions texar/torch/data/data/scalar_data.py
@@ -15,7 +15,7 @@
Various data classes that define data reading, parsing, batching, and other
preprocessing operations.
"""
- from typing import (List, Optional, Type, Union)
+ from typing import (List, Optional, Union)

import numpy as np
import torch
@@ -24,6 +24,8 @@
from texar.torch.data.data.dataset_utils import Batch
from texar.torch.data.data.text_data_base import TextLineDataSource
from texar.torch.hyperparams import HParams
+ from texar.torch.utils.dtypes import get_numpy_dtype


__all__ = [
"_default_scalar_dataset_hparams",
@@ -86,18 +88,7 @@ def __init__(self, hparams, device: Optional[torch.device] = None,
data_source: Optional[DataSource] = None):
self._hparams = HParams(hparams, self.default_hparams())
self._other_transforms = self._hparams.dataset.other_transformations
- data_type = self._hparams.dataset["data_type"]
- self._typecast_func: Union[Type[int], Type[float]]
- if data_type == "int":
-     self._typecast_func = int
-     self._to_data_type = np.int32
- elif data_type == "float":
-     self._typecast_func = float
-     self._to_data_type = np.float32
- else:
-     raise ValueError("Incorrect 'data_type'. Currently 'int' and "
-                      "'float' are supported. Received {}"
-                      .format(data_type))
+ self._data_type = get_numpy_dtype(self._hparams.dataset["data_type"])
if data_source is None:
data_source = TextLineDataSource(
self._hparams.dataset.files,
@@ -146,7 +137,8 @@ def default_hparams():
One of "" (no compression), "ZLIB", or "GZIP".
`"data_type"`: str
- The scalar type. Currently supports "int" and "float".
+ The scalar type. Types defined in
+ :meth:`~texar.torch.utils.dtypes.get_numpy_dtype` are supported.
`"other_transformations"`: list
A list of transformation functions or function names/paths to
@@ -170,15 +162,15 @@ def default_hparams():

def process(self, raw_example: List[str]) -> Union[int, float]:
assert len(raw_example) == 1
- example: Union[int, float] = self._typecast_func(raw_example[0])
+ example: Union[int, float] = self._data_type(raw_example[0])

for transform in self._other_transforms:
example = transform(example)
return example

def collate(self, examples: List[Union[int, float]]) -> Batch:
# convert the list of strings into appropriate tensors here
- examples_np = np.array(examples, dtype=self._to_data_type)
+ examples_np = np.array(examples, dtype=self._data_type)
collated_examples = torch.from_numpy(examples_np).to(device=self.device)
return Batch(len(examples),
batch={self.data_name: collated_examples})
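
For illustration only (not part of this commit): a minimal ScalarData sketch. The file name is hypothetical; with this change, "data_type" may name any type that get_numpy_dtype can resolve (presumably strings such as "int64" or "bool"), not just "int" and "float".

    from texar.torch.data import DataIterator, ScalarData

    hparams = {
        "batch_size": 2,
        "dataset": {
            "files": "scores.txt",   # hypothetical file: one number per line
            "data_type": "float",
        },
    }
    data = ScalarData(hparams)
    batch = next(iter(DataIterator(data)))
    scores = batch[data.data_name]  # the tensor built by collate() above
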
20 changes: 16 additions & 4 deletions texar/torch/modules/classifiers/conv_classifiers.py
@@ -32,22 +32,34 @@

class Conv1DClassifier(ClassifierBase):
r"""Simple `Conv-1D` classifier.
- This is a combination of the
- :class:`~texar.torch.modules.Conv1DEncoder` with a classification layer.
+ This is a combination of the :class:`~texar.torch.modules.Conv1DEncoder`
+ with a classification layer.
Args:
in_channels (int): Number of channels in the input tensor.
in_features (int): Size of the feature dimension in the input tensor.
hparams (dict, optional): Hyperparameters. Missing
hyperparameters will be set to default values. See
:meth:`default_hparams` for the hyperparameter structure and
default values.
+ See :meth:`forward` for the inputs and outputs. If :attr:`"data_format"` is
+ set to ``"channels_first"`` (this is the default), inputs must be a tensor
+ of shape `[batch_size, channels, length]`. If :attr:`"data_format"` is set
+ to ``"channels_last"``, inputs must be a tensor of shape
+ `[batch_size, length, channels]`. For example, for sequence classification,
+ `length` corresponds to time steps, and `channels` corresponds to embedding
+ dim.
Example:
.. code-block:: python
- clas = Conv1DClassifier(hparams={'num_classes': 10})
inputs = torch.randn([64, 20, 256])
+ clas = Conv1DClassifier(in_channels=20, in_features=256,
+                         hparams={'num_classes': 10})
logits, pred = clas(inputs)
# logits == Tensor of shape [64, 10]
# pred == Tensor of shape [64]
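
For illustration only (not part of this commit): a sketch of the channels-last case described above, assuming "data_format" is accepted as a hyperparameter key and that in_channels/in_features keep the same meaning as in the channels-first example.

    import torch
    from texar.torch.modules import Conv1DClassifier

    clas = Conv1DClassifier(in_channels=20, in_features=256,
                            hparams={"num_classes": 10,
                                     "data_format": "channels_last"})
    inputs = torch.randn(64, 256, 20)   # [batch_size, length, channels]
    logits, pred = clas(inputs)
    # logits: [64, 10], pred: [64] -- same outputs as the channels-first case
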
21 changes: 13 additions & 8 deletions texar/torch/modules/networks/conv_networks.py
@@ -54,26 +54,31 @@ class Conv1DNetwork(FeedForwardNetworkBase):
layers followed with a sequence of dense layers.
Args:
in_channels (int): Number of channels in the input tensor.
in_features (int): Size of the feature dimension in the input tensor.
hparams (dict, optional): Hyperparameters. Missing
hyperparameter will be set to default values. See
:meth:`default_hparams` for the hyperparameter structure and
default values.
- See :meth:`forward` for the inputs and outputs. The inputs must be a
- 3D Tensor of shape `[batch_size, channels, length]`. For example, for
- sequence classification, `length` corresponds to time steps, and `channels`
- corresponds to embedding dim.
+ See :meth:`forward` for the inputs and outputs. If :attr:`"data_format"` is
+ set to ``"channels_first"`` (this is the default), inputs must be a tensor
+ of shape `[batch_size, channels, length]`. If :attr:`"data_format"` is set
+ to ``"channels_last"``, inputs must be a tensor of shape
+ `[batch_size, length, channels]`. For example, for sequence classification,
+ `length` corresponds to time steps, and `channels` corresponds to embedding
+ dim.
Example:
.. code-block:: python
- nn = Conv1DNetwork() # Use the default structure
+ nn = Conv1DNetwork(in_channels=20, in_features=256) # Use the default
- inputs = tf.random_uniform([64, 20, 256])
+ inputs = torch.randn([64, 20, 256])
outputs = nn(inputs)
- # outputs == Tensor of shape [64, 128], because the final dense layer
- # has size 128.
+ # outputs == Tensor of shape [64, 256], because the final dense layer
+ # has size 256.
.. document private functions
"""
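
For illustration only (not part of this commit): a self-contained, runnable version of the example above; the expected output shape follows the comment in the updated docstring.

    import torch
    from texar.torch.modules import Conv1DNetwork

    net = Conv1DNetwork(in_channels=20, in_features=256)  # default structure
    inputs = torch.randn(64, 20, 256)   # channels-first, the default layout
    outputs = net(inputs)
    print(outputs.shape)                # torch.Size([64, 256])
    print(net)                          # nn.Module repr of the conv/pool/dense stack
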
