Fix MergeOp when using reduce pooling ops (#80)
Resolve #44
AvinashBukkittu authored and huzecong committed Jun 27, 2019
1 parent 70a37bb commit 67ec905
Showing 3 changed files with 95 additions and 40 deletions.
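A minimal sketch of the failure mode this commit fixes, in plain PyTorch (the tensor shapes are illustrative assumptions, not values from the diff):

    import torch

    x = torch.randn(32, 32, 10)        # (batch, channels, features)
    reduced = torch.max(x, dim=2)[0]   # reduce pooling drops the feature dim -> (32, 32)

    # Merging along a hardcoded dim=2 fails once the tensors are 2-D:
    #   torch.cat([reduced, reduced], dim=2)   # IndexError: dimension out of range
    # Merging along the last dim works, concatenating the channel dim instead:
    merged = torch.cat([reduced, reduced], dim=-1)
    print(merged.shape)                # torch.Size([32, 64])
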
58 changes: 29 additions & 29 deletions texar/core/layers.py
@@ -525,9 +525,9 @@ def forward(self, input: Tuple) -> torch.Tensor:  # type: ignore
         # :torch_docs:`torch.mean <torch.html#torch.mean>`
         # does not return a tuple
         if self._reduce_function == torch.mean:
-            output = self._reduce_function(input, dim=2, keepdim=True)
+            output = self._reduce_function(input, dim=2)
         else:
-            output, _ = self._reduce_function(input, dim=2, keepdim=True)
+            output, _ = self._reduce_function(input, dim=2)
         return output
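Dropping keepdim=True changes the reduce pooling output from (batch, channels, 1) to (batch, channels); a quick standalone check (shapes are assumptions):

    import torch

    x = torch.randn(64, 128, 50)                     # (batch, channels, features)
    print(torch.mean(x, dim=2, keepdim=True).shape)  # torch.Size([64, 128, 1]) -- old behavior
    print(torch.mean(x, dim=2).shape)                # torch.Size([64, 128])    -- new behavior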


@@ -629,8 +629,19 @@ class MergeLayer(nn.Module):
     :attr:`'elemwise_sum'` and :attr:`'elemwise_mul'`.
     """

+    _functions: Dict[str, Callable[[torch.Tensor, int], torch.Tensor]] = {
+        "sum": torch.sum,
+        "mean": torch.mean,
+        "prod": torch.prod,
+        "max": lambda tensors, dim: torch.max(tensors, dim)[0],
+        "min": lambda tensors, dim: torch.min(tensors, dim)[0],
+        "and": torch.all,
+        "or": torch.any,
+        "logsumexp": torch.logsumexp
+    }
+
     def __init__(self, layers: Optional[List[nn.Module]] = None,
-                 mode: str = 'concat', dim: int = 2):
+                 mode: str = 'concat', dim: Optional[int] = None):
         super().__init__()
         self._mode = mode
         self._dim = dim
@@ -656,16 +667,26 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:  # type: ignore
         Returns:
             The merged tensor.
         """
+        layer_outputs: List[torch.Tensor]
         if self._layers is None:
-            layer_outputs: Union[torch.Tensor, List[torch.Tensor]] = input
+            layer_outputs = input
             if not isinstance(layer_outputs, (list, tuple)):
                 layer_outputs = [layer_outputs]
         else:
            layer_outputs = []
            for layer in self._layers:
                layer_output = layer(input)
                layer_outputs.append(layer_output)

+        # The merge dimension cannot be determined until we get the outputs
+        # from the individual layers: reduce pooling ops remove the feature
+        # dim, so the channel dim is merged; non-reduce pooling ops keep it,
+        # so the feature dim is merged.
+        dim = self._dim if self._dim is not None else -1
+
         if self._mode == 'concat':
-            outputs = torch.cat(tensors=layer_outputs, dim=self._dim)
+            outputs = torch.cat(tensors=layer_outputs, dim=dim)
         elif self._mode == 'elemwise_sum':
             outputs = layer_outputs[0]
             for i in range(1, len(layer_outputs)):
@@ -674,30 +695,9 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:  # type: ignore
             outputs = layer_outputs[0]
             for i in range(1, len(layer_outputs)):
                 outputs = torch.mul(outputs, layer_outputs[i])
-        elif self._mode == 'sum':
-            _concat = torch.cat(tensors=layer_outputs, dim=self._dim)
-            outputs = torch.sum(_concat, dim=self._dim)
-        elif self._mode == 'mean':
-            _concat = torch.cat(tensors=layer_outputs, dim=self._dim)
-            outputs = torch.mean(_concat, dim=self._dim)
-        elif self._mode == 'prod':
-            _concat = torch.cat(tensors=layer_outputs, dim=self._dim)
-            outputs = torch.prod(_concat, dim=self._dim)
-        elif self._mode == 'max':
-            _concat = torch.cat(tensors=layer_outputs, dim=self._dim)
-            outputs, _ = torch.max(_concat, dim=self._dim)
-        elif self._mode == 'min':
-            _concat = torch.cat(tensors=layer_outputs, dim=self._dim)
-            outputs, _ = torch.min(_concat, dim=self._dim)
-        elif self._mode == 'and':
-            _concat = torch.cat(tensors=layer_outputs, dim=self._dim)
-            outputs = torch.all(_concat, dim=self._dim)
-        elif self._mode == 'or':
-            _concat = torch.cat(tensors=layer_outputs, dim=self._dim)
-            outputs = torch.any(_concat, dim=self._dim)
-        elif self._mode == 'logsumexp':
-            _concat = torch.cat(tensors=layer_outputs, dim=self._dim)
-            outputs = torch.logsumexp(_concat, dim=self._dim)
+        elif self._mode in self._functions:
+            _concat = torch.cat(tensors=layer_outputs, dim=dim)
+            outputs = self._functions[self._mode](_concat, dim)
         else:
             raise ValueError("Unknown merge mode: '%s'" % self._mode)
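The per-mode elif chain above collapses into a lookup in the _functions table. A self-contained sketch of the same dispatch pattern (standalone helper names, not the class itself):

    import torch
    from typing import Callable, Dict, List

    # torch.max/torch.min return (values, indices) tuples, hence the lambdas.
    _functions: Dict[str, Callable[[torch.Tensor, int], torch.Tensor]] = {
        "sum": torch.sum,
        "mean": torch.mean,
        "max": lambda tensors, dim: torch.max(tensors, dim)[0],
    }

    def merge(tensors: List[torch.Tensor], mode: str, dim: int = -1) -> torch.Tensor:
        if mode not in _functions:
            raise ValueError("Unknown merge mode: '%s'" % mode)
        concat = torch.cat(tensors, dim=dim)
        return _functions[mode](concat, dim)

    print(merge([torch.randn(4, 8), torch.randn(4, 8)], "max").shape)  # torch.Size([4])
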
45 changes: 34 additions & 11 deletions texar/core/layers_test.py
@@ -104,9 +104,9 @@ def test_max_reduce_pooling_layer(self):
         pool_layer = layers.MaxReducePool1d()
         inputs = torch.randn(self._batch_size, self._emb_dim, self._seq_length)
         output = pool_layer(inputs)
-        output_reduce, _ = torch.max(inputs, dim=2, keepdim=True)
+        output_reduce, _ = torch.max(inputs, dim=2)
         self.assertEqual(output.shape, torch.Size([self._batch_size,
-                                                   self._emb_dim, 1]))
+                                                   self._emb_dim]))
         self.assertEqual(torch.all(torch.eq(output, output_reduce)), 1)

     def test_average_reduce_pooling_layer(self):
@@ -115,33 +115,56 @@ def test_average_reduce_pooling_layer(self):
         pool_layer = layers.AvgReducePool1d()
         inputs = torch.randn(self._batch_size, self._emb_dim, self._seq_length)
         output = pool_layer(inputs)
-        output_reduce = torch.mean(inputs, dim=2, keepdim=True)
+        output_reduce = torch.mean(inputs, dim=2)
         self.assertEqual(output.shape, torch.Size([self._batch_size,
-                                                   self._emb_dim, 1]))
+                                                   self._emb_dim]))
         self.assertEqual(torch.all(torch.eq(output, output_reduce)), 1)


 class MergeLayerTest(unittest.TestCase):
     r"""Tests MergeLayer.
     """

-    def test_layer_logics(self):
+    def test_layer_logic(self):
         r"""Test the logic of MergeLayer.
         """
         layers_ = list()
         layers_.append(nn.Conv1d(in_channels=32, out_channels=32,
                                  kernel_size=3))
         layers_.append(nn.Conv1d(in_channels=32, out_channels=32,
-                                 kernel_size=4))
+                                 kernel_size=3))
         layers_.append(nn.Conv1d(in_channels=32, out_channels=32,
-                                 kernel_size=5))
-        layers_.append(nn.Linear(in_features=10, out_features=64))
-        layers_.append(nn.Linear(in_features=10, out_features=64))
-        m_layer = layers.MergeLayer(layers_)
-        input = torch.randn(32, 32, 10)
-        output = m_layer(input)
-        self.assertEqual(output.shape, torch.Size([32, 32, 149]))
+                                 kernel_size=3))
+
+        modes = ["concat", "sum", "mean", "prod", "max", "min", "logsumexp",
+                 "elemwise_sum", "elemwise_mul"]
+
+        for mode in modes:
+            m_layer = layers.MergeLayer(layers_, mode=mode)
+            input = torch.randn(32, 32, 10)
+            output = m_layer(input)
+
+            if mode == "concat":
+                self.assertEqual(output.shape, torch.Size([32, 32, 24]))
+            elif mode == "elemwise_sum" or mode == "elemwise_mul":
+                self.assertEqual(output.shape, torch.Size([32, 32, 8]))
+            else:
+                self.assertEqual(output.shape, torch.Size([32, 32]))
+
+        for mode in ["and", "or"]:
+            m_layer = layers.MergeLayer(layers=None, mode=mode)
+            input = torch.ones(32, 32, 10, dtype=torch.uint8)
+            output = m_layer(input)
+
+            self.assertEqual(output.shape, torch.Size([32, 32]))

     def test_empty_merge_layer(self):
         r"""Test the output of MergeLayer with empty layers.
         """
         m_layer = layers.MergeLayer(layers=None)
         input = torch.randn(32, 32, 10)
         output = m_layer(input)
         self.assertEqual(torch.all(torch.eq(output, input)), 1)


if __name__ == "__main__":
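The constants in test_layer_logic come straight from Conv1d shape arithmetic; a quick derivation of 24 and 8 (plain PyTorch, not repo code):

    import torch
    import torch.nn as nn

    conv = nn.Conv1d(in_channels=32, out_channels=32, kernel_size=3)
    x = torch.randn(32, 32, 10)
    print(conv(x).shape)  # torch.Size([32, 32, 8]): feature dim 10 - 3 + 1 = 8

    # 'concat' joins the three branches along the last dim: 3 * 8 = 24.
    # 'elemwise_sum'/'elemwise_mul' keep one branch's shape: (32, 32, 8).
    # Reduce modes ('sum', 'max', ...) remove the merge dim: (32, 32).
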
32 changes: 32 additions & 0 deletions texar/modules/networks/conv_networks_test.py
@@ -67,6 +67,38 @@ def test_feedforward(self):
         outputs_2 = network_2(inputs_2)
         self.assertEqual(outputs_2.shape, torch.Size([128, 10]))

+        # Test whether concatenation happens along the channel dim when the
+        # feature dim has been reduced.
+        hparams = {
+            # Conv layers
+            "num_conv_layers": 1,
+            "out_channels": 128,
+            "kernel_size": [3, 4, 5],
+            "other_conv_kwargs": {"padding": 0},
+            # Pooling layers
+            "pooling": "AvgPool1d",
+            "pool_size": None,
+            "pool_stride": 1,
+            # Dense layers
+            "num_dense_layers": 0,
+            "out_features": [],
+            "dense_activation": "ReLU",
+            "other_dense_kwargs": None,
+            # Dropout
+            "dropout_conv": [],
+            "dropout_dense": []
+        }
+
+        network_3 = Conv1DNetwork(in_channels=inputs_2.shape[1],
+                                  in_features=inputs_2.shape[2],
+                                  hparams=hparams)
+        inputs_3 = inputs_2
+        outputs_3 = network_3(inputs_3)
+        num_of_kernels = len(hparams["kernel_size"])
+        out_channels = hparams["out_channels"]
+        self.assertEqual(outputs_3.shape,
+                         torch.Size([128, num_of_kernels * out_channels]))


 if __name__ == "__main__":
     unittest.main()
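The expected shape follows because each of the three kernel sizes gets its own conv branch, the reduce pooling removes the feature dim of each branch, and the merge then concatenates along channels. A back-of-the-envelope check (only the hparams values come from the test; the per-branch shape walk is an assumption):

    batch, out_channels, num_of_kernels = 128, 128, 3  # kernel_size = [3, 4, 5]

    # Per branch: (batch, out_channels, features_k) --reduce pool--> (batch, out_channels)
    # Merge: concat along the channel dim across the three branches.
    merged_shape = (batch, num_of_kernels * out_channels)
    print(merged_shape)  # (128, 384)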
