-
Notifications
You must be signed in to change notification settings - Fork 375
/
DecoderRNN.py
202 lines (167 loc) · 9.12 KB
/
DecoderRNN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
import random
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
from .attention import Attention
from .baseRNN import BaseRNN
if torch.cuda.is_available():
import torch.cuda as device
else:
import torch as device
class DecoderRNN(BaseRNN):
    r"""
    Provides functionality for decoding in a seq2seq framework, with an option for attention.

    Args:
        vocab_size (int): size of the vocabulary
        max_len (int): a maximum allowed length for the sequence to be processed
        hidden_size (int): the number of features in the hidden state `h`
        sos_id (int): index of the start of sentence symbol
        eos_id (int): index of the end of sentence symbol
        n_layers (int, optional): number of recurrent layers (default: 1)
        rnn_cell (str, optional): type of RNN cell (default: gru)
        bidirectional (bool, optional): if the encoder is bidirectional (default False)
        input_dropout_p (float, optional): dropout probability for the input sequence (default: 0)
        dropout_p (float, optional): dropout probability for the output sequence (default: 0)
        use_attention (bool, optional): flag indicating whether to use an attention mechanism (default: False)

    Attributes:
        KEY_ATTN_SCORE (str): key used to indicate attention weights in `ret_dict`
        KEY_LENGTH (str): key used to indicate a list representing lengths of output sequences in `ret_dict`
        KEY_SEQUENCE (str): key used to indicate a list of sequences in `ret_dict`

    Inputs: inputs, encoder_hidden, encoder_outputs, function, teacher_forcing_ratio
        - **inputs** (batch, seq_len, input_size): list of sequences, whose length is the batch size and within which
          each sequence is a list of token IDs. It is used for teacher forcing when provided. (default `None`)
        - **encoder_hidden** (num_layers * num_directions, batch_size, hidden_size): tensor containing the features
          in the hidden state `h` of encoder. Used as the initial hidden state of the decoder. (default `None`)
        - **encoder_outputs** (batch, seq_len, hidden_size): tensor containing the outputs of the encoder.
          Used for attention mechanism (default is `None`).
        - **function** (torch.nn.Module): A function used to generate symbols from RNN hidden state
          (default is `torch.nn.functional.log_softmax`).
        - **teacher_forcing_ratio** (float): The probability that teacher forcing will be used. A random number is
          drawn uniformly from 0-1 for every decoding token, and if the sample is smaller than the given value,
          teacher forcing would be used (default is 0).

    Outputs: decoder_outputs, decoder_hidden, ret_dict
        - **decoder_outputs** (seq_len, batch, vocab_size): list of tensors with size (batch_size, vocab_size)
          containing the outputs of the decoding function.
        - **decoder_hidden** (num_layers * num_directions, batch, hidden_size): tensor containing the last hidden
          state of the decoder.
        - **ret_dict**: dictionary containing additional information as follows {*KEY_LENGTH* : list of integers
          representing lengths of output sequences, *KEY_SEQUENCE* : list of sequences, where each sequence is a
          list of predicted token IDs }.
    """

    KEY_ATTN_SCORE = 'attention_score'
    KEY_LENGTH = 'length'
    KEY_SEQUENCE = 'sequence'

    def __init__(self, vocab_size, max_len, hidden_size,
                 sos_id, eos_id,
                 n_layers=1, rnn_cell='gru', bidirectional=False,
                 input_dropout_p=0, dropout_p=0, use_attention=False):
        super(DecoderRNN, self).__init__(vocab_size, max_len, hidden_size,
                                         input_dropout_p, dropout_p,
                                         n_layers, rnn_cell)

        self.bidirectional_encoder = bidirectional
        # The decoder consumes its own embeddings, so RNN input size == hidden size.
        self.rnn = self.rnn_cell(hidden_size, hidden_size, n_layers,
                                 batch_first=True, dropout=dropout_p)

        self.output_size = vocab_size
        self.max_length = max_len
        self.use_attention = use_attention
        self.eos_id = eos_id
        self.sos_id = sos_id

        self.init_input = None

        self.embedding = nn.Embedding(self.output_size, self.hidden_size)
        if use_attention:
            self.attention = Attention(self.hidden_size)

        # Projects the (possibly attention-mixed) hidden state to vocabulary scores.
        self.out = nn.Linear(self.hidden_size, self.output_size)

    def forward_step(self, input_var, hidden, encoder_outputs, function):
        """Run the RNN over one chunk of input tokens and project to vocabulary scores.

        Args:
            input_var: (batch, chunk_len) LongTensor of token IDs.
            hidden: initial RNN hidden state (or None).
            encoder_outputs: encoder outputs for attention; only read when
                ``self.use_attention`` is set.
            function: output activation applied to the projected scores
                (e.g. ``F.log_softmax``).

        Returns:
            (predicted_softmax, hidden, attn) where predicted_softmax is
            (batch, chunk_len, vocab_size), and attn is None when attention
            is disabled.
        """
        batch_size = input_var.size(0)
        output_size = input_var.size(1)
        embedded = self.embedding(input_var)
        embedded = self.input_dropout(embedded)

        output, hidden = self.rnn(embedded, hidden)

        attn = None
        if self.use_attention:
            output, attn = self.attention(output, encoder_outputs)

        # .contiguous() guards the flattening view against a non-contiguous
        # attention output; it never changes the values.
        predicted_softmax = function(
            self.out(output.contiguous().view(-1, self.hidden_size))
        ).view(batch_size, output_size, -1)
        return predicted_softmax, hidden, attn

    def forward(self, inputs=None, encoder_hidden=None, function=F.log_softmax,
                encoder_outputs=None, teacher_forcing_ratio=0):
        ret_dict = dict()
        if self.use_attention:
            if encoder_outputs is None:
                raise ValueError("Argument encoder_outputs cannot be None when attention is used.")
            ret_dict[DecoderRNN.KEY_ATTN_SCORE] = list()

        # --- Infer batch size and prepare inputs / max decoding length. ---
        if inputs is None:
            if teacher_forcing_ratio > 0:
                raise ValueError("Teacher forcing has to be disabled (set 0) when no inputs is provided.")
        if inputs is None and encoder_hidden is None:
            batch_size = 1
        else:
            if inputs is not None:
                batch_size = inputs.size(0)
            else:
                # LSTM hidden state is an (h, c) tuple; GRU's is a single tensor.
                if self.rnn_cell is nn.LSTM:
                    batch_size = encoder_hidden[0].size(1)
                elif self.rnn_cell is nn.GRU:
                    batch_size = encoder_hidden.size(1)

        if inputs is None:
            # Free-running mode: start every sequence from <sos>.
            # NOTE(review): `volatile` is deprecated in PyTorch >= 0.4; kept for
            # compatibility with the old-style Variable API this file targets.
            inputs = Variable(torch.LongTensor([self.sos_id]),
                              volatile=True).view(batch_size, -1)
            if torch.cuda.is_available():
                inputs = inputs.cuda()
            max_length = self.max_length
        else:
            max_length = inputs.size(1) - 1  # minus the start of sequence symbol

        decoder_hidden = self._init_state(encoder_hidden)

        use_teacher_forcing = random.random() < teacher_forcing_ratio

        decoder_outputs = []
        sequence_symbols = []
        # Per-sequence output lengths; shortened in decode() once <eos> appears.
        lengths = np.array([max_length] * batch_size)

        def decode(step, step_output, step_attn):
            """Record one decoding step and return the greedily chosen symbols."""
            decoder_outputs.append(step_output)
            if self.use_attention:
                ret_dict[DecoderRNN.KEY_ATTN_SCORE].append(step_attn)
            symbols = decoder_outputs[-1].topk(1)[1]
            sequence_symbols.append(symbols)

            eos_batches = symbols.data.eq(self.eos_id)
            if eos_batches.dim() > 0:
                eos_batches = eos_batches.cpu().view(-1).numpy()
                # BUGFIX: use the `step` parameter instead of closing over the
                # loop variable `di` (the original worked only by accident
                # because every call site passed `di` as `step`).
                update_idx = ((lengths > step) & eos_batches) != 0
                lengths[update_idx] = len(sequence_symbols)
            return symbols

        # Manual unrolling is used to support random teacher forcing.
        # If teacher_forcing_ratio is True or False instead of a probability,
        # the unrolling can be done in graph.
        if use_teacher_forcing:
            # Feed the whole ground-truth sequence (minus the last token) at once.
            decoder_input = inputs[:, :-1]
            decoder_output, decoder_hidden, attn = self.forward_step(
                decoder_input, decoder_hidden, encoder_outputs, function=function)
            for di in range(decoder_output.size(1)):
                step_output = decoder_output[:, di, :]
                # BUGFIX: attn is None when attention is disabled; slicing it
                # unconditionally raised TypeError under teacher forcing.
                step_attn = attn[:, di, :] if attn is not None else None
                decode(di, step_output, step_attn)
        else:
            # Free-running: feed back the decoder's own prediction each step.
            decoder_input = inputs[:, 0].unsqueeze(1)
            for di in range(max_length):
                decoder_output, decoder_hidden, step_attn = self.forward_step(
                    decoder_input, decoder_hidden, encoder_outputs, function=function)
                step_output = decoder_output.squeeze(1)
                symbols = decode(di, step_output, step_attn)
                decoder_input = symbols

        ret_dict[DecoderRNN.KEY_SEQUENCE] = sequence_symbols
        ret_dict[DecoderRNN.KEY_LENGTH] = lengths.tolist()

        return decoder_outputs, decoder_hidden, ret_dict

    def _init_state(self, encoder_hidden):
        """ Initialize the decoder hidden state from the encoder's final state,
        folding bidirectional layers if needed. Returns None when no encoder
        hidden state is given. """
        if encoder_hidden is None:
            return None
        if isinstance(encoder_hidden, tuple):
            # LSTM: transform both h and c.
            encoder_hidden = tuple([self._cat_directions(h) for h in encoder_hidden])
        else:
            encoder_hidden = self._cat_directions(encoder_hidden)
        return encoder_hidden

    def _cat_directions(self, h):
        """ If the encoder is bidirectional, do the following transformation.
            (#directions * #layers, #batch, hidden_size) -> (#layers, #batch, #directions * hidden_size)
        """
        if self.bidirectional_encoder:
            # Even indices are forward-direction layers, odd are backward.
            h = torch.cat([h[0:h.size(0):2], h[1:h.size(0):2]], 2)
        return h