
Variational dropout in Augmented LSTM (#2344)

This pull request fixes #2320. I have addressed point 2 on the list within the issue. Apart from the points made in the issue, I have also added more detail to the doc strings in the [Stacked Bi-Directional LSTM](https://github.com/allenai/allennlp/blob/master/allennlp/modules/stacked_bidirectional_lstm.py), such as the corrected shape of the `final_states` returned from the forward pass and the correct shape required for the `initial_state` argument of the forward pass. I have also documented the return type of the forward pass of the [Stacked Bi-Directional LSTM](https://github.com/allenai/allennlp/blob/master/allennlp/modules/stacked_bidirectional_lstm.py).

A screenshot of the changed doc string for the augmented LSTM is below. The only change is the wording of the `recurrent_dropout_probability` argument within the constructor:
![augmented doc string](https://user-images.githubusercontent.com/13574854/51051958-34cf7180-15cd-11e9-9a48-7bbb039ae504.png)

Screenshots of the changed doc strings for the Stacked Bi-Directional LSTM are below. The changes are those stated above regarding the forward pass, plus the wording of the `recurrent_dropout_probability` and `layer_dropout_probability` arguments within the constructor and the constructor's description text:
![stacked constructor doc string](https://user-images.githubusercontent.com/13574854/51051945-2719ec00-15cd-11e9-9809-4a603a8e1a66.png)
![stacked forward doc string](https://user-images.githubusercontent.com/13574854/51051952-2d0fcd00-15cd-11e9-906b-7a1cf3ab7b25.png)
apmoore1 authored and DeNeutoy committed Feb 6, 2019
1 parent 2e7acb0 commit ce83cb41bfec1f9c8398f6c77db78bf5a5b3adb5
@@ -35,7 +35,8 @@ class AugmentedLstm(torch.nn.Module):
`A Theoretically Grounded Application of Dropout in Recurrent Neural Networks
<https://arxiv.org/abs/1512.05287>`_ . Implementation wise, this simply
applies a fixed dropout mask per sequence to the recurrent connection of the
LSTM.
LSTM. Dropout is not applied to the output sequence or to the final hidden
state that is returned; it is only applied to the previous hidden states.
use_highway: bool, optional (default = True)
Whether or not to use highway connections between layers. This effectively involves
reparameterising the normal output of an LSTM as::
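For context, the dropout scheme this docstring describes can be sketched in a few lines of plain PyTorch. This is only an illustration of the idea (one mask per sequence, reused at every timestep, never applied to the output sequence or the final state), not `AugmentedLstm`'s own implementation; the helper and variable names are illustrative:

```python
import torch

def variational_dropout_mask(batch_size: int, hidden_size: int,
                             dropout_probability: float) -> torch.Tensor:
    # One Bernoulli mask per sequence, rescaled so the expected value of the
    # masked activations is unchanged (inverted dropout).
    keep_probability = 1.0 - dropout_probability
    mask = torch.bernoulli(torch.full((batch_size, hidden_size), keep_probability))
    return mask / keep_probability

# The same mask multiplies the recurrent hidden state at every timestep; the
# output sequence and the final hidden state are left untouched.
batch_size, hidden_size, num_timesteps = 5, 11, 7
mask = variational_dropout_mask(batch_size, hidden_size, dropout_probability=0.5)
previous_state = torch.zeros(batch_size, hidden_size)
for _ in range(num_timesteps):
    dropped_state = previous_state * mask  # reused, not resampled, each step
    # ... the LSTM cell would consume `dropped_state` here ...
```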
@@ -163,6 +164,9 @@ def forward(self, # pylint: disable=arguments-differ
# Actually get the slices of the batch which we need for the computation at this timestep.
previous_memory = full_batch_previous_memory[0: current_length_index + 1].clone()
previous_state = full_batch_previous_state[0: current_length_index + 1].clone()
# Only do recurrent dropout if the dropout prob is > 0.0 and we are in training mode.
if dropout_mask is not None and self.training:
previous_state = previous_state * dropout_mask[0: current_length_index + 1]
timestep_input = sequence_tensor[0: current_length_index + 1, index]

# Do the projections for all the gates all at once.
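Within the timestep loop above, the batch is sorted by length and shrinks as shorter sequences finish, which is why the mask is sliced to the still-active prefix (`dropout_mask[0: current_length_index + 1]`). A small standalone sketch of that slicing, with illustrative sizes and names:

```python
import torch

# A length-sorted batch of three sequences, as AugmentedLstm expects.
sequence_lengths = [7, 5, 3]
hidden_size = 4
keep_probability = 0.5
dropout_mask = torch.bernoulli(
        torch.full((len(sequence_lengths), hidden_size), keep_probability)) / keep_probability

for timestep in range(max(sequence_lengths)):
    # Index of the last sequence that is still running at this timestep.
    current_length_index = sum(length > timestep for length in sequence_lengths) - 1
    # Only the active slice of the state (and of the mask) takes part in this step.
    previous_state = torch.zeros(current_length_index + 1, hidden_size)
    previous_state = previous_state * dropout_mask[0: current_length_index + 1]
```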
@@ -188,10 +192,6 @@ def forward(self, # pylint: disable=arguments-differ
highway_input_projection = projected_input[:, 5 * self.hidden_size:6 * self.hidden_size]
timestep_output = highway_gate * timestep_output + (1 - highway_gate) * highway_input_projection

# Only do dropout if the dropout prob is > 0.0 and we are in training mode.
if dropout_mask is not None and self.training:
timestep_output = timestep_output * dropout_mask[0: current_length_index + 1]

# We've been doing computation with less than the full batch, so here we create a new
# variable for the the whole batch at this timestep and insert the result for the
# relevant elements of the batch into it.
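As a reminder of the highway reparameterisation referenced in the `use_highway` docstring above, a minimal sketch of the gate arithmetic that appears in this hunk. The tensor names and shapes here are illustrative, not `AugmentedLstm`'s internal tensors:

```python
import torch

batch_size, hidden_size = 5, 11
lstm_output = torch.randn(batch_size, hidden_size)       # ordinary LSTM output for a timestep
input_projection = torch.randn(batch_size, hidden_size)  # linear projection of the raw input
highway_gate = torch.sigmoid(torch.randn(batch_size, hidden_size))

# gate * lstm_output + (1 - gate) * projected_input, matching the hunk above.
timestep_output = highway_gate * lstm_output + (1 - highway_gate) * input_projection
```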
@@ -2,6 +2,7 @@
import torch
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence
from allennlp.modules.augmented_lstm import AugmentedLstm
from allennlp.modules.input_variational_dropout import InputVariationalDropout
from allennlp.common.checks import ConfigurationError


@@ -10,8 +11,9 @@ class StackedBidirectionalLstm(torch.nn.Module):
A standard stacked Bidirectional LSTM where the LSTM layers
are concatenated between each layer. The only difference between
this and a regular bidirectional LSTM is the application of
variational dropout to the hidden states of the LSTM.
Note that this will be slower, as it doesn't use CUDNN.
variational dropout to the hidden states and outputs of each layer apart
from the last layer of the LSTM. Note that this will be slower, as it
doesn't use CUDNN.
Parameters
----------
@@ -22,9 +24,13 @@ class StackedBidirectionalLstm(torch.nn.Module):
num_layers : int, required
The number of stacked Bidirectional LSTMs to use.
recurrent_dropout_probability: float, optional (default = 0.0)
The dropout probability to be used in a dropout scheme as stated in
`A Theoretically Grounded Application of Dropout in Recurrent Neural Networks
<https://arxiv.org/abs/1512.05287>`_ .
The recurrent dropout probability to be used in a dropout scheme as
stated in `A Theoretically Grounded Application of Dropout in Recurrent
Neural Networks <https://arxiv.org/abs/1512.05287>`_ .
layer_dropout_probability: float, optional (default = 0.0)
The layer wise dropout probability to be used in a dropout scheme as
stated in `A Theoretically Grounded Application of Dropout in
Recurrent Neural Networks <https://arxiv.org/abs/1512.05287>`_ .
use_highway: bool, optional (default = True)
Whether or not to use highway connections between layers. This effectively involves
reparameterising the normal output of an LSTM as::
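A brief construction sketch showing the two dropout arguments side by side. The sizes are illustrative; the new tests further down exercise exactly this signature:

```python
from allennlp.modules.stacked_bidirectional_lstm import StackedBidirectionalLstm

# Recurrent and layer-wise variational dropout are configured independently.
stacked_lstm = StackedBidirectionalLstm(input_size=10, hidden_size=11, num_layers=3,
                                        recurrent_dropout_probability=0.5,
                                        layer_dropout_probability=0.5)
```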
@@ -37,6 +43,7 @@ def __init__(self,
hidden_size: int,
num_layers: int,
recurrent_dropout_probability: float = 0.0,
layer_dropout_probability: float = 0.0,
use_highway: bool = True) -> None:
super(StackedBidirectionalLstm, self).__init__()

@@ -66,26 +73,28 @@ def __init__(self,
self.add_module('backward_layer_{}'.format(layer_index), backward_layer)
layers.append([forward_layer, backward_layer])
self.lstm_layers = layers
self.layer_dropout = InputVariationalDropout(layer_dropout_probability)

def forward(self, # pylint: disable=arguments-differ
inputs: PackedSequence,
initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None):
initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
) -> Tuple[PackedSequence, Tuple[torch.Tensor, torch.Tensor]]:
"""
Parameters
----------
inputs : ``PackedSequence``, required.
A batch first ``PackedSequence`` to run the stacked LSTM over.
initial_state : Tuple[torch.Tensor, torch.Tensor], optional, (default = None)
A tuple (state, memory) representing the initial hidden state and memory
of the LSTM. Each tensor has shape (1, batch_size, output_dimension * 2).
of the LSTM. Each tensor has shape (num_layers, batch_size, output_dimension * 2).
Returns
-------
output_sequence : PackedSequence
The encoded sequence of shape (batch_size, sequence_length, hidden_size * 2)
final_states: torch.Tensor
The per-layer final (state, memory) states of the LSTM, each with shape
(num_layers, batch_size, hidden_size * 2).
(num_layers * 2, batch_size, hidden_size * 2).
"""
if not initial_state:
hidden_states = [None] * len(self.lstm_layers)
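A minimal forward-pass sketch following the docstring above and mirroring the tests added further down; the concrete sizes are illustrative:

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from allennlp.modules.stacked_bidirectional_lstm import StackedBidirectionalLstm
from allennlp.nn.util import sort_batch_by_length

stacked_lstm = StackedBidirectionalLstm(input_size=10, hidden_size=11, num_layers=3,
                                        layer_dropout_probability=0.5)

tensor = torch.rand([5, 7, 10])
sequence_lengths = torch.LongTensor([7, 7, 7, 7, 7])
sorted_tensor, sorted_sequence, _, _ = sort_batch_by_length(tensor, sequence_lengths)
lstm_input = pack_padded_sequence(sorted_tensor, sorted_sequence.data.tolist(),
                                  batch_first=True)

output_sequence, final_states = stacked_lstm(lstm_input)
padded_output, _ = pad_packed_sequence(output_sequence, batch_first=True)
# (batch_size, sequence_length, hidden_size * 2), as documented above.
assert padded_output.shape == (5, 7, 22)
```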
@@ -110,6 +119,10 @@ def forward(self, # pylint: disable=arguments-differ
backward_output, _ = pad_packed_sequence(backward_output, batch_first=True)

output_sequence = torch.cat([forward_output, backward_output], -1)
# Apply layer-wise dropout to each layer's output sequence except the
# last layer's (the original input sequence is also left untouched).
if i < (self.num_layers - 1):
output_sequence = self.layer_dropout(output_sequence)
output_sequence = pack_padded_sequence(output_sequence, lengths, batch_first=True)

final_h.extend([final_forward_state[0], final_backward_state[0]])
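The layer-wise dropout applied here follows the same variational idea: one mask per sequence, broadcast across every timestep of that layer's output. A plain-PyTorch sketch of that behaviour (an illustration of the idea, not AllenNLP's `InputVariationalDropout` itself; names are illustrative):

```python
import torch

batch_size, num_timesteps, hidden_dim = 5, 7, 22
dropout_probability = 0.5
output_sequence = torch.rand(batch_size, num_timesteps, hidden_dim)

# One mask of shape (batch_size, hidden_dim), broadcast over the time dimension,
# so every timestep of a given sequence is dropped in the same positions.
keep_probability = 1.0 - dropout_probability
mask = torch.bernoulli(torch.full((batch_size, hidden_dim), keep_probability)) / keep_probability
dropped_output_sequence = output_sequence * mask.unsqueeze(1)
```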
@@ -99,3 +99,51 @@ def test_augmented_lstm_is_initialized_with_correct_biases(self):
lstm = AugmentedLstm(2, 3, use_highway=False)
true_state_bias = numpy.array([0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0])
numpy.testing.assert_array_equal(lstm.state_linearity.bias.data.numpy(), true_state_bias)

def test_dropout_is_not_applied_to_output_or_returned_hidden_states(self):
sorted_tensor, sorted_sequence, _, _ = sort_batch_by_length(self.random_tensor, self.sequence_lengths)
tensor = pack_padded_sequence(sorted_tensor, sorted_sequence.data.tolist(), batch_first=True)
lstm = AugmentedLstm(10, 11, recurrent_dropout_probability=0.5)
output, (hidden_state, _) = lstm(tensor)
output_sequence, _ = pad_packed_sequence(output, batch_first=True)
# Test returned output sequence
num_hidden_dims_zero_across_timesteps = ((output_sequence.sum(1) == 0).sum()).item()
# If this is not true, dropout has been applied to the output of the LSTM.
assert not num_hidden_dims_zero_across_timesteps
# Dropout should not be applied to the last hidden state, as it is not used
# within the LSTM, and leaving it untouched keeps the behaviour consistent with
# `torch.nn.LSTM`, where dropout is not applied to any of its outputs, as well
# as with the Keras LSTM implementation.
hidden_state = hidden_state.squeeze()
num_hidden_dims_zero_across_timesteps = ((hidden_state == 0).sum()).item()
assert not num_hidden_dims_zero_across_timesteps

def test_dropout_version_is_different_to_no_dropout(self):
augmented_lstm = AugmentedLstm(10, 11)
dropped_augmented_lstm = AugmentedLstm(10, 11, recurrent_dropout_probability=0.9)
# Initialize all weights to a constant value of 0.5.
constant_init = Initializer.from_params(Params({"type": "constant", "val": 0.5}))
initializer = InitializerApplicator([(".*", constant_init)])
initializer(augmented_lstm)
initializer(dropped_augmented_lstm)

initial_state = torch.randn([1, 5, 11])
initial_memory = torch.randn([1, 5, 11])

# If we use too big a number, as in the PyTorch test, the dropout has no effect.
sorted_tensor, sorted_sequence, _, _ = sort_batch_by_length(self.random_tensor, self.sequence_lengths)
lstm_input = pack_padded_sequence(sorted_tensor, sorted_sequence.data.tolist(), batch_first=True)

augmented_output, augmented_state = augmented_lstm(lstm_input, (initial_state, initial_memory))
dropped_output, dropped_state = dropped_augmented_lstm(lstm_input, (initial_state, initial_memory))
dropped_output_sequence, _ = pad_packed_sequence(dropped_output, batch_first=True)
augmented_output_sequence, _ = pad_packed_sequence(augmented_output, batch_first=True)
with pytest.raises(AssertionError):
numpy.testing.assert_array_almost_equal(dropped_output_sequence.data.numpy(),
augmented_output_sequence.data.numpy(), decimal=4)
with pytest.raises(AssertionError):
numpy.testing.assert_array_almost_equal(dropped_state[0].data.numpy(),
augmented_state[0].data.numpy(), decimal=4)
with pytest.raises(AssertionError):
numpy.testing.assert_array_almost_equal(dropped_state[1].data.numpy(),
augmented_state[1].data.numpy(), decimal=4)
@@ -1,16 +1,17 @@
# pylint: disable=no-self-use,invalid-name
import numpy
import pytest
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

from allennlp.modules.stacked_bidirectional_lstm import StackedBidirectionalLstm
from allennlp.modules.seq2seq_encoders import Seq2SeqEncoder
from allennlp.modules.seq2vec_encoders import Seq2VecEncoder
from allennlp.common.testing import AllenNlpTestCase
from allennlp.common.params import Params
from allennlp.nn import InitializerApplicator, Initializer
from allennlp.nn.util import sort_batch_by_length


class TestStackedBidirectionalLstm(AllenNlpTestCase):
class TestStackedBidirectionalLstm():
def test_stacked_bidirectional_lstm_completes_forward_pass(self):
input_tensor = torch.rand(4, 5, 3)
input_tensor[1, 4:, :] = 0.
@@ -55,3 +56,51 @@ def test_stacked_bidirectional_lstm_can_complete_forward_pass_seq2vec(self):
mask = torch.ones(4, 5)
output = encoder(input_tensor, mask)
assert output.detach().numpy().shape == (4, 18)


@pytest.mark.parametrize("dropout_name", ('layer_dropout_probability',
'recurrent_dropout_probability'))
def test_stacked_bidirectional_lstm_dropout_version_is_different(self, dropout_name: str):
stacked_lstm = StackedBidirectionalLstm(input_size=10, hidden_size=11,
num_layers=3)
if dropout_name == 'layer_dropout_probability':
dropped_stacked_lstm = StackedBidirectionalLstm(input_size=10, hidden_size=11,
num_layers=3,
layer_dropout_probability=0.9)
elif dropout_name == 'recurrent_dropout_probability':
dropped_stacked_lstm = StackedBidirectionalLstm(input_size=10, hidden_size=11,
num_layers=3,
recurrent_dropout_probability=0.9)
else:
raise ValueError('Do not recognise the following dropout name '
f'{dropout_name}')
# Initialize all weights to a constant value of 0.5.
constant_init = Initializer.from_params(Params({"type": "constant", "val": 0.5}))
initializer = InitializerApplicator([(".*", constant_init)])
initializer(stacked_lstm)
initializer(dropped_stacked_lstm)

initial_state = torch.randn([3, 5, 11])
initial_memory = torch.randn([3, 5, 11])

tensor = torch.rand([5, 7, 10])
sequence_lengths = torch.LongTensor([7, 7, 7, 7, 7])

sorted_tensor, sorted_sequence, _, _ = sort_batch_by_length(tensor, sequence_lengths)
lstm_input = pack_padded_sequence(sorted_tensor, sorted_sequence.data.tolist(), batch_first=True)

stacked_output, stacked_state = stacked_lstm(lstm_input, (initial_state, initial_memory))
dropped_output, dropped_state = dropped_stacked_lstm(lstm_input, (initial_state, initial_memory))
dropped_output_sequence, _ = pad_packed_sequence(dropped_output, batch_first=True)
stacked_output_sequence, _ = pad_packed_sequence(stacked_output, batch_first=True)
if dropout_name == 'layer_dropout_probability':
with pytest.raises(AssertionError):
numpy.testing.assert_array_almost_equal(dropped_output_sequence.data.numpy(),
stacked_output_sequence.data.numpy(), decimal=4)
if dropout_name == 'recurrent_dropout_probability':
with pytest.raises(AssertionError):
numpy.testing.assert_array_almost_equal(dropped_state[0].data.numpy(),
stacked_state[0].data.numpy(), decimal=4)
with pytest.raises(AssertionError):
numpy.testing.assert_array_almost_equal(dropped_state[1].data.numpy(),
stacked_state[1].data.numpy(), decimal=4)
