Skip to content
Permalink
Browse files

Replace slicing with chunk along the last dim (#1244)

* Replace slicing with chunk along the last dim

* Elaborate dim in docs

* Add test

* Minor docstring fixes
  • Loading branch information...
mttk authored and matt-gardner committed May 18, 2018
1 parent a4f2fb8 commit 4f2db42f4077942e0b012300b70807296d0a6b88
Showing with 9 additions and 4 deletions.
  1. +3 −4 allennlp/modules/highway.py
  2. +6 −0 tests/modules/highway_test.py
@@ -22,7 +22,7 @@ class Highway(torch.nn.Module):
Parameters
----------
input_dim : ``int``
The dimensionality of :math:`x`. We assume the input has shape ``(batch_size,
The dimensionality of :math:`x`. We assume the input has shape ``(batch_size, ...,
input_dim)``.
num_layers : ``int``, optional (default=``1``)
The number of highway layers to apply to the input.
@@ -41,7 +41,7 @@ def __init__(self,
for layer in self._layers:
# We should bias the highway layer to just carry its input forward. We do that by
# setting the bias on `B(x)` to be positive, because that means `g` will be biased to
# be high, to we will carry the input forward. The bias on `B(x)` is the second half
# be high, so we will carry the input forward. The bias on `B(x)` is the second half
# of the bias vector in each Linear layer.
layer.bias[input_dim:].data.fill_(1)

@@ -53,8 +53,7 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: # pylint: disable=argu
linear_part = current_input
# NOTE: if you modify this, think about whether you should modify the initialization
# above, too.
nonlinear_part = projected_input[:, (0 * self._input_dim):(1 * self._input_dim)]
gate = projected_input[:, (1 * self._input_dim):(2 * self._input_dim)]
nonlinear_part, gate = projected_input.chunk(2, dim=-1)
nonlinear_part = self._activation(nonlinear_part)
gate = torch.nn.functional.sigmoid(gate)
current_input = gate * linear_part + (1 - gate) * nonlinear_part
@@ -20,3 +20,9 @@ def test_forward_works_on_simple_input(self):
assert result.shape == (2, 2)
# This was checked by hand.
assert_almost_equal(result, [[-0.0394, 0.0197], [1.7527, -0.5550]], decimal=4)

def test_forward_works_on_nd_input(self):
    # A Highway module built for input_dim=2 should accept tensors with
    # extra leading dimensions, transforming only the trailing axis and
    # preserving the overall shape.
    module = Highway(2, 2)
    data = Variable(torch.ones(2, 2, 2))
    result = module(data)
    assert result.size() == (2, 2, 2)

0 comments on commit 4f2db42

Please sign in to comment.
You can’t perform that action at this time.