From f8fda5fcec8ca5fbabd55cbb0dcb36163fa67e18 Mon Sep 17 00:00:00 2001 From: Alexis Asseman <33075224+aasseman@users.noreply.github.com> Date: Tue, 23 Apr 2019 14:24:06 -0700 Subject: [PATCH 1/4] Correct RNN components such that output goes through FF before being fed back in autoregression mode --- .../wikitext_language_modeling_seq2seq.yml | 36 ++----- ...itext_language_modeling_seq2seq_simple.yml | 30 +----- .../models/recurrent_neural_network.py | 95 +++++++++++-------- ptp/components/models/seq2seq_rnn.py | 24 ++--- 4 files changed, 76 insertions(+), 109 deletions(-) diff --git a/configs/wikitext/wikitext_language_modeling_seq2seq.yml b/configs/wikitext/wikitext_language_modeling_seq2seq.yml index 84bbeaf..aa531fd 100644 --- a/configs/wikitext/wikitext_language_modeling_seq2seq.yml +++ b/configs/wikitext/wikitext_language_modeling_seq2seq.yml @@ -73,28 +73,6 @@ pipeline: streams: inputs: targets outputs: indexed_targets - - # Publish the hidden size of the seq2seq - global_publisher: - type: GlobalVariablePublisher - priority: 1 - # Add input_size to globals, so classifier will use it. - keys: s2s_hidden_size - values: 300 - - # FF, to resize the embeddings to whatever the hidden size of te seq2seq is. - ff_resize_s2s_input: - type: FeedForwardNetwork - priority: 2.5 - s2s_hidden_size: 300 - use_logsoftmax: False - dimensions: 3 - streams: - inputs: embedded_sources - predictions: embedded_sources_resized - globals: - input_size: embeddings_size - prediction_size: s2s_hidden_size # LSTM Encoder lstm_encoder: @@ -107,12 +85,12 @@ pipeline: output_last_state: True prediction_mode: Last streams: - inputs: embedded_sources_resized + inputs: embedded_sources predictions: s2s_encoder_output output_state: s2s_state_output globals: - input_size: s2s_hidden_size - prediction_size: s2s_hidden_size + input_size: embeddings_size + prediction_size: embeddings_size # LSTM Decoder lstm_decoder: @@ -130,10 +108,10 @@ pipeline: predictions: s2s_decoder_output input_state: s2s_state_output globals: - input_size: s2s_hidden_size - prediction_size: s2s_hidden_size + input_size: embeddings_size + prediction_size: embeddings_size - # FF, to resize the from the hidden size of the seq2seq to the size of the target vector + # FF, to resize the from the output size of the seq2seq to the size of the target vector ff_resize_s2s_output: type: FeedForwardNetwork use_logsoftmax: True @@ -142,7 +120,7 @@ pipeline: streams: inputs: s2s_decoder_output globals: - input_size: s2s_hidden_size + input_size: embeddings_size prediction_size: vocabulary_size # Loss diff --git a/configs/wikitext/wikitext_language_modeling_seq2seq_simple.yml b/configs/wikitext/wikitext_language_modeling_seq2seq_simple.yml index 731d590..fd489db 100644 --- a/configs/wikitext/wikitext_language_modeling_seq2seq_simple.yml +++ b/configs/wikitext/wikitext_language_modeling_seq2seq_simple.yml @@ -73,28 +73,6 @@ pipeline: streams: inputs: targets outputs: indexed_targets - - # Publish the hidden size of the seq2seq - global_publisher: - type: GlobalVariablePublisher - priority: 1 - # Add input_size to globals, so classifier will use it. - keys: s2s_hidden_size - values: 300 - - # FF, to resize the embeddings to whatever the hidden size of te seq2seq is. - ff_resize_s2s_input: - type: FeedForwardNetwork - priority: 2.5 - s2s_hidden_size: 300 - use_logsoftmax: False - dimensions: 3 - streams: - inputs: embedded_sources - predictions: embedded_sources_resized - globals: - input_size: embeddings_size - prediction_size: s2s_hidden_size # LSTM seq2seq lstm_encoder: @@ -105,11 +83,11 @@ pipeline: num_layers: 3 use_logsoftmax: False streams: - inputs: embedded_sources_resized + inputs: embedded_sources predictions: s2s_output globals: - input_size: s2s_hidden_size - prediction_size: s2s_hidden_size + input_size: embeddings_size + prediction_size: embeddings_size # FF, to resize the from the hidden size of the seq2seq to the size of the target vector ff_resize_s2s_output: @@ -120,7 +98,7 @@ pipeline: streams: inputs: s2s_output globals: - input_size: s2s_hidden_size + input_size: embeddings_size prediction_size: vocabulary_size # Loss diff --git a/ptp/components/models/recurrent_neural_network.py b/ptp/components/models/recurrent_neural_network.py index 75a7bd4..2a1706e 100644 --- a/ptp/components/models/recurrent_neural_network.py +++ b/ptp/components/models/recurrent_neural_network.py @@ -68,6 +68,9 @@ def __init__(self, name, config): self.prediction_size = self.prediction_size[0] else: raise ConfigurationError("RNN prediction size '{}' must be a single dimension (current {})".format(self.key_prediction_size, self.prediction_size)) + + if "Autoregression" in self.input_mode: + assert self.input_size == self.prediction_size, "In autoregression mode, needs input_size == prediction_size." # Retrieve hidden size from configuration. self.hidden_size = self.config["hidden_size"] @@ -213,10 +216,9 @@ def forward(self, data_dict): # Get inputs [BATCH_SIZE x SEQ_LEN x INPUT_SIZE] if "None" in self.input_mode: batch_size = data_dict[self.key_input_state][0].shape[1] - inputs = torch.zeros(batch_size, 1, self.hidden_size) + inputs = torch.zeros(batch_size, self.hidden_size) if next(self.parameters()).is_cuda: inputs = inputs.cuda() - else: inputs = data_dict[self.key_inputs] if inputs.dim() == 2: @@ -235,56 +237,71 @@ def forward(self, data_dict): # Autoregressive mode - feed back outputs in the input if "Autoregression" in self.input_mode: activations_partial, hidden = self.rnn_cell(inputs, hidden) + activations_partial = activations_partial.squeeze(1) + activations_partial = self.dropout(activations_partial) + activations_partial = self.activation2output(activations_partial) activations += [activations_partial] + # Feed back the outputs iteratively for i in range(self.autoregression_length - 1): - activations_partial, hidden = self.rnn_cell(activations_partial, hidden) + activations_partial, hidden = self.rnn_cell(activations_partial.unsqueeze(1), hidden) + activations_partial = activations_partial.squeeze(1) + activations_partial = self.dropout(activations_partial) + activations_partial = self.activation2output(activations_partial) # Add the single step output into list if self.prediction_mode == "Dense": activations += [activations_partial] # Reassemble all the outputs from list into an output sequence if self.prediction_mode == "Dense": - activations = torch.stack(activations, 1) - else: - activations = activations_partial + outputs = torch.stack(activations, 1) + # Log softmax - along PREDICTION dim. + if self.use_logsoftmax: + outputs = self.log_softmax(outputs) + # Add predictions to datadict. + data_dict.extend({self.key_predictions: outputs}) + elif self.prediction_mode == "Last": + if self.use_logsoftmax: + outputs = self.log_softmax(activations_partial) + # Add predictions to datadict. + data_dict.extend({self.key_predictions: outputs}) + # Normal mode - feed the entire input sequence at once else: activations, hidden = self.rnn_cell(inputs, hidden) - - # Propagate activations through dropout layer. - activations = self.dropout(activations) + # Propagate activations through dropout layer. + activations = self.dropout(activations) - if self.prediction_mode == "Dense": - # Pass every activation through the output layer. - # Reshape to 2D tensor [BATCH_SIZE * SEQ_LEN x HIDDEN_SIZE] - outputs = activations.contiguous().view(-1, self.hidden_size) - - # Propagate data through the output layer [BATCH_SIZE * SEQ_LEN x PREDICTION_SIZE] - outputs = self.activation2output(outputs) - - # Reshape back to 3D tensor [BATCH_SIZE x SEQ_LEN x PREDICTION_SIZE] - outputs = outputs.view(activations.size(0), activations.size(1), outputs.size(1)) - - # Log softmax - along PREDICTION dim. - if self.use_logsoftmax: - outputs = self.log_softmax(outputs) - - # Add predictions to datadict. - data_dict.extend({self.key_predictions: outputs}) - elif self.prediction_mode == "Last": - # Pass only the last activation through the output layer. - outputs = activations.contiguous()[:, -1, :].squeeze() - # Propagate data through the output layer [BATCH_SIZE x PREDICTION_SIZE] - outputs = self.activation2output(outputs) - # Log softmax - along PREDICTION dim. - if self.use_logsoftmax: - outputs = self.log_softmax(outputs) - # Add predictions to datadict. - data_dict.extend({self.key_predictions: outputs}) - elif self.prediction_mode == "None": - # Nothing, since we don't want to keep the RNN's outputs - pass + if self.prediction_mode == "Dense": + # Pass every activation through the output layer. + # Reshape to 2D tensor [BATCH_SIZE * SEQ_LEN x HIDDEN_SIZE] + outputs = activations.contiguous().view(-1, self.hidden_size) + + # Propagate data through the output layer [BATCH_SIZE * SEQ_LEN x PREDICTION_SIZE] + outputs = self.activation2output(outputs) + + # Reshape back to 3D tensor [BATCH_SIZE x SEQ_LEN x PREDICTION_SIZE] + outputs = outputs.view(activations.size(0), activations.size(1), outputs.size(1)) + + # Log softmax - along PREDICTION dim. + if self.use_logsoftmax: + outputs = self.log_softmax(outputs) + + # Add predictions to datadict. + data_dict.extend({self.key_predictions: outputs}) + elif self.prediction_mode == "Last": + # Pass only the last activation through the output layer. + outputs = activations.contiguous()[:, -1, :].squeeze() + # Propagate data through the output layer [BATCH_SIZE x PREDICTION_SIZE] + outputs = self.activation2output(outputs) + # Log softmax - along PREDICTION dim. + if self.use_logsoftmax: + outputs = self.log_softmax(outputs) + # Add predictions to datadict. + data_dict.extend({self.key_predictions: outputs}) + elif self.prediction_mode == "None": + # Nothing, since we don't want to keep the RNN's outputs + pass if self.output_last_state: data_dict.extend({self.key_output_state: hidden}) diff --git a/ptp/components/models/seq2seq_rnn.py b/ptp/components/models/seq2seq_rnn.py index 813ab92..16380a8 100644 --- a/ptp/components/models/seq2seq_rnn.py +++ b/ptp/components/models/seq2seq_rnn.py @@ -185,25 +185,19 @@ def forward(self, data_dict): # Encoder activations, hidden = self.rnn_cell_enc(inputs, hidden) + activations_partial = self.activation2output(activations[:, -1, :]) # Propagate inputs through rnn cell. - activations_partial, hidden = self.rnn_cell_dec(activations[:, -1, :].unsqueeze(1), hidden) - activations = [] - activations += [activations_partial] + activations_partial, hidden = self.rnn_cell_dec(activations_partial.unsqueeze(1), hidden) + activations_partial = activations_partial.squeeze(1) + activations_partial = self.activation2output(activations_partial) + activations = [activations_partial] for i in range(self.autoregression_length - 1): - activations_partial, hidden = self.rnn_cell_dec(activations_partial, hidden) + activations_partial, hidden = self.rnn_cell_dec(activations_partial.unsqueeze(1), hidden) + activations_partial = activations_partial.squeeze(1) + activations_partial = self.activation2output(activations_partial) activations += [activations_partial] - activations = torch.stack(activations, 1) - - # Pass every activation through the output layer. - # Reshape to 2D tensor [BATCH_SIZE * SEQ_LEN x HIDDEN_SIZE] - outputs = activations.contiguous().view(-1, self.hidden_size) - - # Propagate data through the output layer [BATCH_SIZE * SEQ_LEN x PREDICTION_SIZE] - outputs = self.activation2output(outputs) - - # Reshape back to 3D tensor [BATCH_SIZE x SEQ_LEN x PREDICTION_SIZE] - outputs = outputs.view(activations.size(0), activations.size(1), outputs.size(1)) + outputs = torch.stack(activations, 1) # Log softmax - along PREDICTION dim. if self.use_logsoftmax: From c7f1c3455ee54c950b64fa7d554b3fa53fb6d76a Mon Sep 17 00:00:00 2001 From: Alexis Asseman <33075224+aasseman@users.noreply.github.com> Date: Wed, 24 Apr 2019 11:39:53 -0700 Subject: [PATCH 2/4] Added option to RecurrentNeuralNetwork to switch off FFN output layer --- .gitignore | 3 + .../models/recurrent_neural_network.yml | 4 ++ .../models/recurrent_neural_network.py | 55 +++++++++++-------- 3 files changed, 38 insertions(+), 24 deletions(-) diff --git a/.gitignore b/.gitignore index 894a44c..7696fb1 100644 --- a/.gitignore +++ b/.gitignore @@ -102,3 +102,6 @@ venv.bak/ # mypy .mypy_cache/ + +# vscode +.vscode/ \ No newline at end of file diff --git a/configs/default/components/models/recurrent_neural_network.yml b/configs/default/components/models/recurrent_neural_network.yml index a0e6f5e..6767753 100644 --- a/configs/default/components/models/recurrent_neural_network.yml +++ b/configs/default/components/models/recurrent_neural_network.yml @@ -37,6 +37,10 @@ dropout_rate: 0 # * None (all outputs are discarded) prediction_mode: Dense +# Enable FFN layer at the output of the RNN (before eventual feed back in the case of autoregression). +# Useful if the raw outputs of the RNN are needed, for attention encoder-decoder for example. +ffn_output: True + # Input mode # Options: # * Dense (every iteration expects an input) diff --git a/ptp/components/models/recurrent_neural_network.py b/ptp/components/models/recurrent_neural_network.py index 2a1706e..053e4ef 100644 --- a/ptp/components/models/recurrent_neural_network.py +++ b/ptp/components/models/recurrent_neural_network.py @@ -38,6 +38,7 @@ def __init__(self, name, config): # Get input/output mode self.input_mode = self.config["input_mode"] self.output_last_state = self.config["output_last_state"] + self.ffn_output = self.config["ffn_output"] # Get prediction mode from configuration. self.prediction_mode = self.config["prediction_mode"] @@ -137,7 +138,9 @@ def __init__(self, name, config): self.logger.info("Initializing RNN with input size = {}, hidden size = {} and prediction size = {}".format(self.input_size, self.hidden_size, self.prediction_size)) # Create the output layer. - self.activation2output = torch.nn.Linear(self.hidden_size, self.prediction_size) + self.activation2output_lin = None + if(self.ffn_output): + self.activation2output_lin = torch.nn.Linear(self.hidden_size, self.prediction_size) # Create the final non-linearity. self.use_logsoftmax = self.config["use_logsoftmax"] @@ -160,6 +163,25 @@ def initialize_hiddens_state(self, batch_size): # Return hidden_state. return self.init_hidden.expand(self.num_layers, batch_size, self.hidden_size).contiguous() + def activation2output(self, activations): + output = self.dropout(activations) + + if(self.ffn_output): + #output = activations.squeeze(1) + shape = activations.shape + + # Reshape to 2D tensor [BATCH_SIZE * SEQ_LEN x HIDDEN_SIZE] + output = output.contiguous().view(-1, shape[2]) + + # Propagate data through the output layer [BATCH_SIZE * SEQ_LEN x PREDICTION_SIZE] + output = self.activation2output_lin(output) + #output = output.unsqueeze(1) + + # Reshape back to 3D tensor [BATCH_SIZE x SEQ_LEN x PREDICTION_SIZE] + output = output.view(shape[0], shape[1], output.size(1)) + + return output + def input_data_definitions(self): """ @@ -237,23 +259,19 @@ def forward(self, data_dict): # Autoregressive mode - feed back outputs in the input if "Autoregression" in self.input_mode: activations_partial, hidden = self.rnn_cell(inputs, hidden) - activations_partial = activations_partial.squeeze(1) - activations_partial = self.dropout(activations_partial) activations_partial = self.activation2output(activations_partial) activations += [activations_partial] # Feed back the outputs iteratively for i in range(self.autoregression_length - 1): - activations_partial, hidden = self.rnn_cell(activations_partial.unsqueeze(1), hidden) - activations_partial = activations_partial.squeeze(1) - activations_partial = self.dropout(activations_partial) + activations_partial, hidden = self.rnn_cell(activations_partial, hidden) activations_partial = self.activation2output(activations_partial) # Add the single step output into list if self.prediction_mode == "Dense": activations += [activations_partial] # Reassemble all the outputs from list into an output sequence if self.prediction_mode == "Dense": - outputs = torch.stack(activations, 1) + outputs = torch.cat(activations, 1) # Log softmax - along PREDICTION dim. if self.use_logsoftmax: outputs = self.log_softmax(outputs) @@ -261,7 +279,7 @@ def forward(self, data_dict): data_dict.extend({self.key_predictions: outputs}) elif self.prediction_mode == "Last": if self.use_logsoftmax: - outputs = self.log_softmax(activations_partial) + outputs = self.log_softmax(activations_partial.squeeze(1)) # Add predictions to datadict. data_dict.extend({self.key_predictions: outputs}) @@ -269,20 +287,10 @@ def forward(self, data_dict): else: activations, hidden = self.rnn_cell(inputs, hidden) - # Propagate activations through dropout layer. - activations = self.dropout(activations) - if self.prediction_mode == "Dense": # Pass every activation through the output layer. - # Reshape to 2D tensor [BATCH_SIZE * SEQ_LEN x HIDDEN_SIZE] - outputs = activations.contiguous().view(-1, self.hidden_size) - - # Propagate data through the output layer [BATCH_SIZE * SEQ_LEN x PREDICTION_SIZE] - outputs = self.activation2output(outputs) - - # Reshape back to 3D tensor [BATCH_SIZE x SEQ_LEN x PREDICTION_SIZE] - outputs = outputs.view(activations.size(0), activations.size(1), outputs.size(1)) - + outputs = self.activation2output(activations) + # Log softmax - along PREDICTION dim. if self.use_logsoftmax: outputs = self.log_softmax(outputs) @@ -290,10 +298,9 @@ def forward(self, data_dict): # Add predictions to datadict. data_dict.extend({self.key_predictions: outputs}) elif self.prediction_mode == "Last": - # Pass only the last activation through the output layer. - outputs = activations.contiguous()[:, -1, :].squeeze() - # Propagate data through the output layer [BATCH_SIZE x PREDICTION_SIZE] - outputs = self.activation2output(outputs) + outputs = self.activation2output(activations.contiguous()[:, -1, :].unsqueeze(1)) + outputs = outputs.squeeze(1) + # Log softmax - along PREDICTION dim. if self.use_logsoftmax: outputs = self.log_softmax(outputs) From 57c386ebca30fb25e40ef553c59ea827a1b40ca1 Mon Sep 17 00:00:00 2001 From: Alexis Asseman <33075224+aasseman@users.noreply.github.com> Date: Wed, 24 Apr 2019 14:09:31 -0700 Subject: [PATCH 3/4] Fixed DataDefinition of RecurrentNeuralNetwork's output and input state streams --- ptp/components/models/recurrent_neural_network.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/ptp/components/models/recurrent_neural_network.py b/ptp/components/models/recurrent_neural_network.py index 053e4ef..dda5118 100644 --- a/ptp/components/models/recurrent_neural_network.py +++ b/ptp/components/models/recurrent_neural_network.py @@ -198,7 +198,10 @@ def input_data_definitions(self): # Input hidden state if self.initial_state == "Input": - d[self.key_input_state] = DataDefinition([-1, 2 if self.cell_type == 'LSTM' else 1, self.input_size, 1, self.hidden_size], [torch.tensor], "Batch of RNN last states") + if self.cell_type == "LSTM": + d[self.key_input_state] = DataDefinition([2, self.num_layers, -1, self.hidden_size], [torch.tensor], "Batch of RNN last states") + else: + d[self.key_input_state] = DataDefinition([self.num_layers, -1, self.hidden_size], [torch.tensor], "Batch of RNN last states") return d @@ -218,8 +221,11 @@ def output_data_definitions(self): # Output hidden state stream if self.output_last_state: - d[self.key_output_state] = DataDefinition([-1, 2 if self.cell_type == 'LSTM' else 1, self.input_size, 1, self.hidden_size], [torch.tensor], "Batch of RNN last states") - + if self.cell_type == "LSTM": + d[self.key_output_state] = DataDefinition([2, self.num_layers, -1, self.hidden_size], [torch.tensor], "Batch of RNN last states") + else: + d[self.key_output_state] = DataDefinition([self.num_layers, -1, self.hidden_size], [torch.tensor], "Batch of RNN last states") + return d def forward(self, data_dict): From 6a259ea09961d1120ca20311db9c1ee237fc6cec Mon Sep 17 00:00:00 2001 From: Alexis Asseman <33075224+aasseman@users.noreply.github.com> Date: Wed, 24 Apr 2019 17:54:55 -0700 Subject: [PATCH 4/4] Fixing typos --- ptp/components/models/recurrent_neural_network.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ptp/components/models/recurrent_neural_network.py b/ptp/components/models/recurrent_neural_network.py index dda5118..041a8fa 100644 --- a/ptp/components/models/recurrent_neural_network.py +++ b/ptp/components/models/recurrent_neural_network.py @@ -199,9 +199,9 @@ def input_data_definitions(self): # Input hidden state if self.initial_state == "Input": if self.cell_type == "LSTM": - d[self.key_input_state] = DataDefinition([2, self.num_layers, -1, self.hidden_size], [torch.tensor], "Batch of RNN last states") + d[self.key_input_state] = DataDefinition([2, self.num_layers, -1, self.hidden_size], [torch.Tensor], "Batch of RNN last states") else: - d[self.key_input_state] = DataDefinition([self.num_layers, -1, self.hidden_size], [torch.tensor], "Batch of RNN last states") + d[self.key_input_state] = DataDefinition([self.num_layers, -1, self.hidden_size], [torch.Tensor], "Batch of RNN last states") return d @@ -222,9 +222,9 @@ def output_data_definitions(self): # Output hidden state stream if self.output_last_state: if self.cell_type == "LSTM": - d[self.key_output_state] = DataDefinition([2, self.num_layers, -1, self.hidden_size], [torch.tensor], "Batch of RNN last states") + d[self.key_output_state] = DataDefinition([2, self.num_layers, -1, self.hidden_size], [torch.Tensor], "Batch of RNN last states") else: - d[self.key_output_state] = DataDefinition([self.num_layers, -1, self.hidden_size], [torch.tensor], "Batch of RNN last states") + d[self.key_output_state] = DataDefinition([self.num_layers, -1, self.hidden_size], [torch.Tensor], "Batch of RNN last states") return d