This repository has been archived by the owner on Dec 16, 2022. It is now read-only.
/
lstm_cell_with_projection.py
254 lines (216 loc) · 11.9 KB
/
lstm_cell_with_projection.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
"""
An LSTM with recurrent dropout, a projected hidden state, and clipping
applied to both the hidden state and the memory state of the LSTM.
"""
from typing import Optional, Tuple, List
import torch
from allennlp.nn.util import get_dropout_mask
from allennlp.nn.initializers import block_orthogonal
class LstmCellWithProjection(torch.nn.Module):
    """
    An LSTM with Recurrent Dropout and a projected and clipped hidden state and
    memory. Note: this implementation is slower than the native Pytorch LSTM because
    it cannot make use of CUDNN optimizations for stacked RNNs due to
    variational dropout and the custom nature of the cell state.

    [0]: https://arxiv.org/abs/1512.05287

    # Parameters

    input_size : `int`, required.
        The dimension of the inputs to the LSTM.
    hidden_size : `int`, required.
        The dimension of the outputs of the LSTM.
    cell_size : `int`, required.
        The dimension of the memory cell used for the LSTM.
    go_forward : `bool`, optional (default = `True`)
        The direction in which the LSTM is applied to the sequence.
        Forwards by default, or backwards if False.
    recurrent_dropout_probability : `float`, optional (default = `0.0`)
        The dropout probability to be used in a dropout scheme as stated in
        [A Theoretically Grounded Application of Dropout in Recurrent Neural Networks]
        [0]. Implementation wise, this simply
        applies a fixed dropout mask per sequence to the recurrent connection of the
        LSTM.
    state_projection_clip_value : `float`, optional, (default = `None`)
        The magnitude with which to clip the hidden_state after projecting it.
        A falsy value (`None` or `0.0`) disables this clipping.
    memory_cell_clip_value : `float`, optional, (default = `None`)
        The magnitude with which to clip the memory cell.
        A falsy value (`None` or `0.0`) disables this clipping.

    # Returns

    Calling the module (see `forward`) returns:

    output_accumulator : `torch.FloatTensor`
        The outputs of the LSTM for each timestep. A tensor of shape
        (batch_size, max_timesteps, hidden_size) where for a given batch
        element, all outputs past the sequence length for that batch are
        zero tensors.
    final_state : `Tuple[torch.FloatTensor, torch.FloatTensor]`
        The final (state, memory) states of the LSTM, with shape
        (1, batch_size, hidden_size) and (1, batch_size, cell_size)
        respectively. The first dimension is 1 in order to match the Pytorch
        API for returning stacked LSTM states.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        cell_size: int,
        go_forward: bool = True,
        recurrent_dropout_probability: float = 0.0,
        memory_cell_clip_value: Optional[float] = None,
        state_projection_clip_value: Optional[float] = None,
    ) -> None:
        super().__init__()
        # Required to be wrapped with a `PytorchSeq2SeqWrapper`.
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.cell_size = cell_size

        self.go_forward = go_forward
        self.state_projection_clip_value = state_projection_clip_value
        self.memory_cell_clip_value = memory_cell_clip_value
        self.recurrent_dropout_probability = recurrent_dropout_probability

        # We do the projections for all the gates all at once. Only one of the
        # two linearities carries a bias, so each gate gets exactly one bias term.
        self.input_linearity = torch.nn.Linear(input_size, 4 * cell_size, bias=False)
        self.state_linearity = torch.nn.Linear(hidden_size, 4 * cell_size, bias=True)

        # Additional projection matrix for making the hidden state smaller.
        self.state_projection = torch.nn.Linear(cell_size, hidden_size, bias=False)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """
        (Re-)initialise the gate projections: block-orthogonal weights, zero
        biases, and a forget-gate bias of 1.0.
        """
        # Use sensible default initializations for parameters.
        block_orthogonal(self.input_linearity.weight.data, [self.cell_size, self.input_size])
        block_orthogonal(self.state_linearity.weight.data, [self.cell_size, self.hidden_size])

        self.state_linearity.bias.data.fill_(0.0)
        # Initialize forget gate biases to 1.0 as per An Empirical
        # Exploration of Recurrent Network Architectures, (Jozefowicz, 2015).
        # The forget gate is the second of the four `cell_size`-wide gate chunks.
        self.state_linearity.bias.data[self.cell_size : 2 * self.cell_size].fill_(1.0)

    def forward(
        self,
        inputs: torch.FloatTensor,
        batch_lengths: List[int],
        initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Run the cell over a batch of sequences, in the direction given by
        `self.go_forward`.

        # Parameters

        inputs : `torch.FloatTensor`, required.
            A tensor of shape (batch_size, num_timesteps, input_size)
            to apply the LSTM over.
        batch_lengths : `List[int]`, required.
            A list of length batch_size containing the lengths of the sequences in batch.
            The batch must be ordered by length, longest sequence first.
        initial_state : `Tuple[torch.Tensor, torch.Tensor]`, optional, (default = `None`)
            A tuple (state, memory) representing the initial hidden state and memory
            of the LSTM. The `state` has shape (1, batch_size, hidden_size) and the
            `memory` has shape (1, batch_size, cell_size).

        # Returns

        output_accumulator : `torch.FloatTensor`
            The outputs of the LSTM for each timestep. A tensor of shape
            (batch_size, max_timesteps, hidden_size) where for a given batch
            element, all outputs past the sequence length for that batch are
            zero tensors. This also holds for timesteps that lie beyond the
            longest sequence in the batch.
        final_state : `Tuple[torch.FloatTensor, torch.FloatTensor]`
            A tuple (state, memory) representing the final hidden state and memory
            of the LSTM. The `state` has shape (1, batch_size, hidden_size) and the
            `memory` has shape (1, batch_size, cell_size).
        """
        batch_size = inputs.size()[0]
        total_timesteps = inputs.size()[1]
        num_sequences = len(batch_lengths)

        # Positions that are never written below stay zero, which is what gives
        # the "zero past the sequence length" guarantee.
        output_accumulator = inputs.new_zeros(batch_size, total_timesteps, self.hidden_size)

        if initial_state is None:
            full_batch_previous_memory = inputs.new_zeros(batch_size, self.cell_size)
            full_batch_previous_state = inputs.new_zeros(batch_size, self.hidden_size)
        else:
            full_batch_previous_state = initial_state[0].squeeze(0)
            full_batch_previous_memory = initial_state[1].squeeze(0)

        # Index of the last batch element which is still "active". Going forwards
        # we start with everything and drop elements; going backwards we start
        # with only the longest sequence and pick up more.
        current_length_index = batch_size - 1 if self.go_forward else 0
        if self.recurrent_dropout_probability > 0.0 and self.training:
            # One mask for the whole sequence (variational dropout).
            dropout_mask = get_dropout_mask(
                self.recurrent_dropout_probability, full_batch_previous_state
            )
        else:
            dropout_mask = None

        for timestep in range(total_timesteps):
            # The index depends on which end we start.
            index = timestep if self.go_forward else total_timesteps - timestep - 1

            # What we are doing here is finding the index into the batch dimension
            # which we need to use for this timestep, because the sequences have
            # variable length, so once the index is greater than the length of this
            # particular batch sequence, we no longer need to do the computation for
            # this sequence. The key thing to recognise here is that the batch inputs
            # must be _ordered_ by length from longest (first in batch) to shortest
            # (last) so initially, we are going forwards with every sequence and as we
            # pass the index at which the shortest elements of the batch finish,
            # we stop picking them up for the computation.
            if self.go_forward:
                # The lower bound stops the search from wrapping around to negative
                # indices when `inputs` is padded past the longest sequence.
                while current_length_index >= 0 and batch_lengths[current_length_index] <= index:
                    current_length_index -= 1
                if current_length_index < 0:
                    # Every sequence has finished: all remaining timesteps are
                    # padding, so their outputs stay zero and the state is final.
                    break
            # If we're going backwards, we are _picking up_ more indices.
            else:
                # First conditional: Are we already at the maximum number of elements in the batch?
                # Second conditional: Does the next shortest sequence beyond the current batch
                # index require computation use this timestep?
                while (
                    current_length_index < (num_sequences - 1)
                    and batch_lengths[current_length_index + 1] > index
                ):
                    current_length_index += 1
                if num_sequences == 0 or batch_lengths[0] <= index:
                    # Not even the longest sequence reaches this position, so it
                    # is padding: skip it rather than feed padding into the state.
                    continue

            num_active = current_length_index + 1

            # Actually get the slices of the batch which we
            # need for the computation at this timestep.
            # Shape (num_active, cell_size)
            previous_memory = full_batch_previous_memory[0:num_active].clone()
            # Shape (num_active, hidden_size)
            previous_state = full_batch_previous_state[0:num_active].clone()
            # Shape (num_active, input_size)
            timestep_input = inputs[0:num_active, index]

            # Do the projections for all the gates all at once and sum them.
            # Shape (num_active, 4 * cell_size)
            gate_preactivations = self.input_linearity(timestep_input) + self.state_linearity(
                previous_state
            )

            # Main LSTM equations using relevant chunks of the big linear
            # projections of the hidden state and inputs. The chunk order is
            # (input, forget, memory candidate, output), each `cell_size` wide.
            input_chunk, forget_chunk, memory_chunk, output_chunk = gate_preactivations.chunk(
                4, dim=1
            )
            input_gate = torch.sigmoid(input_chunk)
            forget_gate = torch.sigmoid(forget_chunk)
            memory_init = torch.tanh(memory_chunk)
            output_gate = torch.sigmoid(output_chunk)
            memory = input_gate * memory_init + forget_gate * previous_memory

            # Here is the non-standard part of this LSTM cell; first, we clip the
            # memory cell, then we project the output of the timestep to a smaller size
            # and again clip it. A falsy clip value (None or 0.0) means "no clipping".
            if self.memory_cell_clip_value:
                memory = torch.clamp(
                    memory, -self.memory_cell_clip_value, self.memory_cell_clip_value
                )

            # Shape (num_active, cell_size)
            pre_projection_timestep_output = output_gate * torch.tanh(memory)

            # Shape (num_active, hidden_size)
            timestep_output = self.state_projection(pre_projection_timestep_output)
            if self.state_projection_clip_value:
                timestep_output = torch.clamp(
                    timestep_output,
                    -self.state_projection_clip_value,
                    self.state_projection_clip_value,
                )

            # Only do dropout if the dropout prob is > 0.0 and we are in training mode.
            if dropout_mask is not None:
                timestep_output = timestep_output * dropout_mask[0:num_active]

            # We've been doing computation with less than the full batch, so here we create a new
            # variable for the whole batch at this timestep and insert the result for the
            # relevant elements of the batch into it. The clones keep the in-place
            # slice assignments from modifying tensors used by earlier timesteps.
            full_batch_previous_memory = full_batch_previous_memory.clone()
            full_batch_previous_state = full_batch_previous_state.clone()
            full_batch_previous_memory[0:num_active] = memory
            full_batch_previous_state[0:num_active] = timestep_output
            output_accumulator[0:num_active, index] = timestep_output

        # Mimic the pytorch API by returning state in the following shape:
        # (num_layers * num_directions, batch_size, ...). As this
        # LSTM cell cannot be stacked, the first dimension here is just 1.
        final_state = (
            full_batch_previous_state.unsqueeze(0),
            full_batch_previous_memory.unsqueeze(0),
        )
        return output_accumulator, final_state