diff --git a/configs/default/components/models/recurrent_neural_network.yml b/configs/default/components/models/recurrent_neural_network.yml
index 3249555..a0e6f5e 100644
--- a/configs/default/components/models/recurrent_neural_network.yml
+++ b/configs/default/components/models/recurrent_neural_network.yml
@@ -9,7 +9,15 @@ hidden_size: 100
 
 # Flag informing the model to learn the intial state (h0/c0) (LOADED)
 # When false, (c0/c0) will be initialized as zeros.
-initial_state_trainable: True
+
+# Initial state type:
+# * Zero (null vector)
+# * Trainable (xavier initialization, trainable)
+# * Input (the initial hidden state comes from an input stream)
+initial_state: Trainable
+
+# Whether to include the last hidden state in the outputs
+output_last_state: False
 
 # Type of recurrent cell (LOADED)
 # Options: LSTM | GRU | RNN_TANH | RNN_RELU
@@ -25,9 +33,19 @@ dropout_rate: 0
 
 # Prediction mode (LOADED)
 # Options:
 # * Dense (passes every activation through output layer) |
-# * Last (passes only the last activation though output layer)
+# * Last (passes only the last activation through output layer) |
+# * None (all outputs are discarded)
 prediction_mode: Dense
 
+# Input mode
+# Options:
+# * Dense (every iteration expects an input)
+# * Autoregression_First (Autoregression, expects an input for the first iteration)
+# * Autoregression_None (Autoregression, first input will be a null vector)
+input_mode: Dense
+
+autoregression_length: 42
+
 # If true, output of the last layer will be additionally processed with Log Softmax (LOADED)
 use_logsoftmax: True
 
@@ -39,9 +57,17 @@ streams:
   ####################################################################
   # 2. Keymappings associated with INPUT and OUTPUT streams.
   ####################################################################
 
   # Stream containing batch of images (INPUT)
   inputs: inputs
 
+  # Stream containing the initial state of the RNN (INPUT)
+  # The stream will actually be created only if `initial_state: Input`
+  input_state: input_state
+
   # Stream containing predictions (OUTPUT)
   predictions: predictions
 
+  # Stream containing the final output state of the RNN (OUTPUT)
+  # The stream will actually be created only if `output_last_state: True`
+  output_state: output_state
+
 globals:
   ####################################################################
   # 3. Keymappings of variables that will be RETRIEVED from GLOBALS.
diff --git a/configs/default/components/models/seq2seq_rnn.yml b/configs/default/components/models/seq2seq_rnn.yml
new file mode 100644
index 0000000..9d9350e
--- /dev/null
+++ b/configs/default/components/models/seq2seq_rnn.yml
@@ -0,0 +1,81 @@
+# This file defines the default values for the Seq2Seq_RNN model.
+
+####################################################################
+# 1. CONFIGURATION PARAMETERS that will be LOADED by the component.
+####################################################################
+
+# Size of the hidden state (LOADED)
+hidden_size: 100
+
+# Flag informing the model to learn the initial state (h0/c0) (LOADED)
+# When false, (h0/c0) will be initialized as zeros.
+
+# Initial state type:
+# * Zero (null vector)
+# * Trainable (xavier initialization, trainable)
+# * Input (the initial hidden state comes from an input stream)
+initial_state: Trainable
+
+# Whether to include the last hidden state in the outputs
+output_last_state: False
+
+# Type of recurrent cell (LOADED)
+# Options: LSTM | GRU | RNN_TANH | RNN_RELU
+cell_type: LSTM
+
+# Number of "stacked" layers (LOADED)
+num_layers: 1
+
+# Dropout rate (LOADED)
+# Default: 0 (means that it is turned off)
+dropout_rate: 0
+
+# Prediction mode (LOADED)
+# Options:
+# * Dense (passes every activation through output layer) |
+# * Last (passes only the last activation through output layer) |
+# * None (all outputs are discarded)
+prediction_mode: Dense
+
+# Input mode
+# Options:
+# * Dense (every iteration expects an input)
+# * Autoregression_First (Autoregression, expects an input for the first iteration)
+# * Autoregression_None (Autoregression, first input will be a null vector)
+input_mode: Dense
+
+autoregression_length: 50
+
+# If true, output of the last layer will be additionally processed with Log Softmax (LOADED)
+use_logsoftmax: True
+
+streams:
+  ####################################################################
+  # 2. Keymappings associated with INPUT and OUTPUT streams.
+  ####################################################################
+
+  # Stream containing batch of images (INPUT)
+  inputs: inputs
+
+  # Stream containing predictions (OUTPUT)
+  predictions: predictions
+
+globals:
+  ####################################################################
+  # 3. Keymappings of variables that will be RETRIEVED from GLOBALS.
+  ####################################################################
+
+  # Size of the input (RETRIEVED)
+  input_size: input_size
+
+  # Size of the prediction (RETRIEVED)
+  prediction_size: prediction_size
+
+  ####################################################################
+  # 4. Keymappings associated with GLOBAL variables that will be SET.
+  ####################################################################
+
+  ####################################################################
+  # 5. Keymappings associated with statistics that will be ADDED.
+  ####################################################################
+
diff --git a/configs/vqa_med_2019/c2_classification/default_c2_classification.yml b/configs/vqa_med_2019/c2_classification/default_c2_classification.yml
index f88e328..0f9e028 100644
--- a/configs/vqa_med_2019/c2_classification/default_c2_classification.yml
+++ b/configs/vqa_med_2019/c2_classification/default_c2_classification.yml
@@ -17,7 +17,6 @@ training:
     episode_limit: 10000
     epoch_limit: -1
 
-
 # Validation parameters:
 validation:
   problem:
diff --git a/configs/wikitext/wikitext_language_modeling_seq2seq.yml b/configs/wikitext/wikitext_language_modeling_seq2seq.yml
new file mode 100644
index 0000000..84bbeaf
--- /dev/null
+++ b/configs/wikitext/wikitext_language_modeling_seq2seq.yml
@@ -0,0 +1,196 @@
+# This pipeline applies seq2seq on wikitext-2 to make word-level prediction.
+# It's been made for test purposes only, as it is doing:
+# [word 0 , ... , word 49] -> [word 1 , ... , word 50] (basically copying most of the input)
+#
+# The seq2seq here is implemented through the use of two `RecurrentNeuralNetwork` components
+
+# Training parameters:
+training:
+  problem:
+    type: &p_type WikiTextLanguageModeling
+    data_folder: &data_folder ~/data/language_modeling/wikitext-2
+    dataset: &dataset wikitext-2
+    subset: train
+    sentence_length: 50
+    batch_size: 64
+
+  # optimizer parameters:
+  optimizer:
+    name: Adam
+    lr: 1.0e-3
+
+  # settings parameters
+  terminal_conditions:
+    loss_stop: 1.0e-2
+    episode_limit: 1000000
+    epoch_limit: 100
+
+# Validation parameters:
+validation:
+  partial_validation_interval: 100
+  problem:
+    type: *p_type
+    data_folder: *data_folder
+    dataset: *dataset
+    subset: valid
+    sentence_length: 50
+    batch_size: 64
+
+# Testing parameters:
+testing:
+  problem:
+    type: *p_type
+    data_folder: *data_folder
+    dataset: *dataset
+    subset: test
+    sentence_length: 50
+    batch_size: 64
+
+pipeline:
+  name: wikitext_language_modeling_seq2seq
+
+  # Source encoding - model 1.
+  source_sentence_embedding:
+    type: SentenceEmbeddings
+    priority: 1.1
+    embeddings_size: 50
+    pretrained_embeddings: glove.6B.50d.txt
+    data_folder: *data_folder
+    source_vocabulary_files: wiki.train.tokens,wiki.valid.tokens,wiki.test.tokens
+    vocabulary_mappings_file: wiki.all.tokenized_words
+    additional_tokens:
+    export_word_mappings_to_globals: True
+    streams:
+      inputs: sources
+      outputs: embedded_sources
+
+  # Target encoding.
+  target_indexer:
+    type: SentenceIndexer
+    priority: 2.1
+    data_folder: *data_folder
+    import_word_mappings_from_globals: True
+    streams:
+      inputs: targets
+      outputs: indexed_targets
+
+  # Publish the hidden size of the seq2seq
+  global_publisher:
+    type: GlobalVariablePublisher
+    priority: 1
+    # Add input_size to globals, so classifier will use it.
+    keys: s2s_hidden_size
+    values: 300
+
+  # FF, to resize the embeddings to whatever the hidden size of the seq2seq is.
+  ff_resize_s2s_input:
+    type: FeedForwardNetwork
+    priority: 2.5
+    s2s_hidden_size: 300
+    use_logsoftmax: False
+    dimensions: 3
+    streams:
+      inputs: embedded_sources
+      predictions: embedded_sources_resized
+    globals:
+      input_size: embeddings_size
+      prediction_size: s2s_hidden_size
+
+  # LSTM Encoder
+  lstm_encoder:
+    type: RecurrentNeuralNetwork
+    priority: 3
+    initial_state: Trainable
+    hidden_size: 300
+    num_layers: 3
+    use_logsoftmax: False
+    output_last_state: True
+    prediction_mode: Last
+    streams:
+      inputs: embedded_sources_resized
+      predictions: s2s_encoder_output
+      output_state: s2s_state_output
+    globals:
+      input_size: s2s_hidden_size
+      prediction_size: s2s_hidden_size
+
+  # LSTM Decoder
+  lstm_decoder:
+    type: RecurrentNeuralNetwork
+    priority: 4
+    initial_state: Input
+    hidden_size: 300
+    num_layers: 3
+    use_logsoftmax: False
+    input_mode: Autoregression_First
+    autoregression_length: 50
+    prediction_mode: Dense
+    streams:
+      inputs: s2s_encoder_output
+      predictions: s2s_decoder_output
+      input_state: s2s_state_output
+    globals:
+      input_size: s2s_hidden_size
+      prediction_size: s2s_hidden_size
+
+  # FF, to resize from the hidden size of the seq2seq to the size of the target vector
+  ff_resize_s2s_output:
+    type: FeedForwardNetwork
+    use_logsoftmax: True
+    dimensions: 3
+    priority: 5
+    streams:
+      inputs: s2s_decoder_output
+    globals:
+      input_size: s2s_hidden_size
+      prediction_size: vocabulary_size
+
+  # Loss
+  nllloss:
+    type: NLLLoss
+    priority: 6
+    num_targets_dims: 2
+    streams:
+      targets: indexed_targets
+      loss: loss
+
+  # Prediction decoding.
+  prediction_decoder:
+    type: SentenceIndexer
+    priority: 10
+    # Reverse mode.
+    reverse: True
+    # Use distributions as inputs.
+    use_input_distributions: True
+    data_folder: *data_folder
+    import_word_mappings_from_globals: True
+    streams:
+      inputs: predictions
+      outputs: prediction_sentences
+
+
+  # Statistics.
+  batch_size:
+    type: BatchSizeStatistics
+    priority: 100.0
+
+  #accuracy:
+  #  type: AccuracyStatistics
+  #  priority: 100.1
+  #  streams:
+  #    targets: indexed_targets
+
+  bleu:
+    type: BLEUStatistics
+    priority: 100.2
+    streams:
+      targets: indexed_targets
+
+
+  # Viewers.
+  viewer:
+    type: StreamViewer
+    priority: 100.3
+    input_streams: sources,targets,indexed_targets,prediction_sentences
+
+#: pipeline
diff --git a/configs/wikitext/wikitext_language_modeling_seq2seq_simple.yml b/configs/wikitext/wikitext_language_modeling_seq2seq_simple.yml
new file mode 100644
index 0000000..731d590
--- /dev/null
+++ b/configs/wikitext/wikitext_language_modeling_seq2seq_simple.yml
@@ -0,0 +1,167 @@
+# This pipeline applies seq2seq on wikitext-2 to make word-level prediction.
+# It's been made for test purposes only, as it is doing:
+# [word 0 , ... , word 49] -> [word 1 , ... , word 50] (basically copying most of the input)
+#
+# The seq2seq here is implemented through the use of a simplified seq2seq component `Seq2Seq_RNN`
+
+# Training parameters:
+training:
+  problem:
+    type: &p_type WikiTextLanguageModeling
+    data_folder: &data_folder ~/data/language_modeling/wikitext-2
+    dataset: &dataset wikitext-2
+    subset: train
+    sentence_length: 50
+    batch_size: 64
+
+  # optimizer parameters:
+  optimizer:
+    name: Adam
+    lr: 1.0e-3
+
+  # settings parameters
+  terminal_conditions:
+    loss_stop: 1.0e-2
+    episode_limit: 1000000
+    epoch_limit: 100
+
+# Validation parameters:
+validation:
+  partial_validation_interval: 100
+  problem:
+    type: *p_type
+    data_folder: *data_folder
+    dataset: *dataset
+    subset: valid
+    sentence_length: 50
+    batch_size: 64
+
+# Testing parameters:
+testing:
+  problem:
+    type: *p_type
+    data_folder: *data_folder
+    dataset: *dataset
+    subset: test
+    sentence_length: 50
+    batch_size: 64
+
+pipeline:
+  name: wikitext_language_modeling_rnn
+
+  # Source encoding - model 1.
+  source_sentence_embedding:
+    type: SentenceEmbeddings
+    priority: 1.1
+    embeddings_size: 50
+    pretrained_embeddings: glove.6B.50d.txt
+    data_folder: *data_folder
+    source_vocabulary_files: wiki.train.tokens,wiki.valid.tokens,wiki.test.tokens
+    vocabulary_mappings_file: wiki.all.tokenized_words
+    additional_tokens:
+    export_word_mappings_to_globals: True
+    streams:
+      inputs: sources
+      outputs: embedded_sources
+
+  # Target encoding.
+  target_indexer:
+    type: SentenceIndexer
+    priority: 2.1
+    data_folder: *data_folder
+    import_word_mappings_from_globals: True
+    streams:
+      inputs: targets
+      outputs: indexed_targets
+
+  # Publish the hidden size of the seq2seq
+  global_publisher:
+    type: GlobalVariablePublisher
+    priority: 1
+    # Add input_size to globals, so classifier will use it.
+    keys: s2s_hidden_size
+    values: 300
+
+  # FF, to resize the embeddings to whatever the hidden size of the seq2seq is.
+  ff_resize_s2s_input:
+    type: FeedForwardNetwork
+    priority: 2.5
+    s2s_hidden_size: 300
+    use_logsoftmax: False
+    dimensions: 3
+    streams:
+      inputs: embedded_sources
+      predictions: embedded_sources_resized
+    globals:
+      input_size: embeddings_size
+      prediction_size: s2s_hidden_size
+
+  # LSTM seq2seq
+  lstm_encoder:
+    type: Seq2Seq_RNN
+    priority: 3
+    initial_state: Trainable
+    hidden_size: 300
+    num_layers: 3
+    use_logsoftmax: False
+    streams:
+      inputs: embedded_sources_resized
+      predictions: s2s_output
+    globals:
+      input_size: s2s_hidden_size
+      prediction_size: s2s_hidden_size
+
+  # FF, to resize from the hidden size of the seq2seq to the size of the target vector
+  ff_resize_s2s_output:
+    type: FeedForwardNetwork
+    use_logsoftmax: True
+    dimensions: 3
+    priority: 5
+    streams:
+      inputs: s2s_output
+    globals:
+      input_size: s2s_hidden_size
+      prediction_size: vocabulary_size
+
+  # Loss
+  nllloss:
+    type: NLLLoss
+    priority: 6
+    num_targets_dims: 2
+    streams:
+      targets: indexed_targets
+      loss: loss
+
+  # Prediction decoding.
+  prediction_decoder:
+    type: SentenceIndexer
+    priority: 10
+    # Reverse mode.
+    reverse: True
+    # Use distributions as inputs.
+    use_input_distributions: True
+    data_folder: *data_folder
+    import_word_mappings_from_globals: True
+    streams:
+      inputs: predictions
+      outputs: prediction_sentences
+
+  # Statistics.
+  batch_size:
+    type: BatchSizeStatistics
+    priority: 100.0
+
+  bleu:
+    type: BLEUStatistics
+    priority: 100.2
+    streams:
+      targets: indexed_targets
+
+
+  # Viewers.
+  viewer:
+    type: StreamViewer
+    priority: 100.3
+    input_streams: sources,targets,indexed_targets,prediction_sentences
+
+#: pipeline
diff --git a/ptp/components/models/__init__.py b/ptp/components/models/__init__.py
index b8b6093..3451d2f 100644
--- a/ptp/components/models/__init__.py
+++ b/ptp/components/models/__init__.py
@@ -6,6 +6,7 @@
 from .model import Model
 from .recurrent_neural_network import RecurrentNeuralNetwork
 from .sentence_embeddings import SentenceEmbeddings
+from .seq2seq_rnn import Seq2Seq_RNN
 from .vqa.element_wise_multiplication import ElementWiseMultiplication
 from .vqa.multimodal_compact_bilinear_pooling import MultimodalCompactBilinearPooling
@@ -19,6 +20,7 @@
     'Model',
     'RecurrentNeuralNetwork',
     'SentenceEmbeddings',
+    'Seq2Seq_RNN',
     'ElementWiseMultiplication',
     'MultimodalCompactBilinearPooling',
     ]
diff --git a/ptp/components/models/recurrent_neural_network.py b/ptp/components/models/recurrent_neural_network.py
index 612c6dd..75a7bd4 100644
--- a/ptp/components/models/recurrent_neural_network.py
+++ b/ptp/components/models/recurrent_neural_network.py
@@ -35,9 +35,22 @@ def __init__(self, name, config):
         # Call constructors of parent classes.
         Model.__init__(self, name, RecurrentNeuralNetwork, config)
 
-        # Get key mappings.
-        self.key_inputs = self.stream_keys["inputs"]
-        self.key_predictions = self.stream_keys["predictions"]
+        # Get input/output mode
+        self.input_mode = self.config["input_mode"]
+        self.output_last_state = self.config["output_last_state"]
+
+        # Get prediction mode from configuration.
+        self.prediction_mode = self.config["prediction_mode"]
+        if self.prediction_mode not in ['Dense','Last', 'None']:
+            raise ConfigurationError("Invalid 'prediction_mode' (current {}, available {})".format(self.prediction_mode, ['Dense','Last', 'None']))
+
+        self.autoregression_length = self.config["autoregression_length"]
+
+        # Check if initial state (h0/c0) is zero, trainable, or coming from input stream.
+        self.initial_state = self.config["initial_state"]
+
+        # Get number of layers from config.
+        self.num_layers = self.config["num_layers"]
 
         # Retrieve input size from global variables.
         self.key_input_size = self.global_keys["input_size"]
@@ -56,11 +69,6 @@ def __init__(self, name, config):
             else:
                 raise ConfigurationError("RNN prediction size '{}' must be a single dimension (current {})".format(self.key_prediction_size, self.prediction_size))
 
-        # Get prediction mode from configuration.
-        self.prediction_mode = self.config["prediction_mode"]
-        if self.prediction_mode not in ['Dense','Last']:
-            raise ConfigurationError("Invalid 'prediction_mode' (current {}, available {})".format(self.prediction_mode, ['Dense','Last']))
-
         # Retrieve hidden size from configuration.
         self.hidden_size = self.config["hidden_size"]
         if type(self.hidden_size) == list:
@@ -69,14 +77,12 @@ def __init__(self, name, config):
             else:
                 raise ConfigurationError("RNN hidden_size must be a single dimension (current {})".format(self.hidden_size))
 
-        self.logger.info("Initializing RNN with input size = {}, hidden size = {} and prediction size = {}".format(self.input_size, self.hidden_size, self.prediction_size))
-
-        # Get number of layers from config.
-        self.num_layers = self.config["num_layers"]
-
         # Get dropout rate value from config.
         dropout_rate = self.config["dropout_rate"]
 
+        # Create dropout layer.
+        self.dropout = torch.nn.Dropout(dropout_rate)
+
         # Create RNN depending on the configuration
         self.cell_type = self.config["cell_type"]
         if self.cell_type in ['LSTM', 'GRU']:
@@ -88,18 +94,16 @@ def __init__(self, name, config):
                 nonlinearity = {'RNN_TANH': 'tanh', 'RNN_RELU': 'relu'}[self.cell_type]
                 # Create rnn cell.
                 self.rnn_cell = torch.nn.RNN(self.input_size, self.hidden_size, self.num_layers, nonlinearity=nonlinearity, dropout=dropout_rate, batch_first=True)
-
             except KeyError:
                 raise ConfigurationError( "Invalid RNN type, available options for 'cell_type' are ['LSTM', 'GRU', 'RNN_TANH', 'RNN_RELU'] (currently '{}')".format(self.cell_type))
 
-        # Check if initial state (h0/c0) are trainable or not.
-        self.initial_state_trainable = self.config["initial_state_trainable"]
-
-        # Parameters - for a single sample.
+        # Parameters - for a single sample.
         h0 = torch.zeros(self.num_layers, 1, self.hidden_size)
         c0 = torch.zeros(self.num_layers, 1, self.hidden_size)
 
-        if self.initial_state_trainable:
+        self.init_hidden = None
+
+        if self.initial_state == "Trainable":
             self.logger.info("Using trainable initial (h0/c0) state")
             # Initialize a single vector used as hidden state.
             # Initialize it using xavier initialization.
@@ -110,15 +114,24 @@ def __init__(self, name, config):
             if self.cell_type == 'LSTM':
                 torch.nn.init.xavier_uniform(c0)
                 self.init_memory = torch.nn.Parameter(c0, requires_grad=True)
-        else:
+        elif self.initial_state in ["Zero", "Input"]:
             self.logger.info("Using zero initial (h0/c0) state")
             # We will still embedd it into parameter to enable storing/loading of both types of models by each other.
             self.init_hidden = torch.nn.Parameter(h0, requires_grad=False)
             if self.cell_type == 'LSTM':
                 self.init_memory = torch.nn.Parameter(c0, requires_grad=False)
 
-        # Create dropout layer.
-        self.dropout = torch.nn.Dropout(dropout_rate)
+        # Get key mappings.
+        if "None" not in self.input_mode:
+            self.key_inputs = self.stream_keys["inputs"]
+        if "None" not in self.prediction_mode:
+            self.key_predictions = self.stream_keys["predictions"]
+        if self.initial_state == "Input":
+            self.key_input_state = self.stream_keys["input_state"]
+        if self.output_last_state:
+            self.key_output_state = self.stream_keys["output_state"]
+
+        self.logger.info("Initializing RNN with input size = {}, hidden size = {} and prediction size = {}".format(self.input_size, self.hidden_size, self.prediction_size))
 
         # Create the output layer.
         self.activation2output = torch.nn.Linear(self.hidden_size, self.prediction_size)
@@ -151,10 +164,18 @@ def input_data_definitions(self):
         :return: dictionary containing input data definitions (each of type :py:class:`ptp.utils.DataDefinition`).
         """
-        return {
-            self.key_inputs: DataDefinition([-1, -1, self.input_size], [torch.Tensor], "Batch of inputs, each represented as index [BATCH_SIZE x SEQ_LEN x INPUT_SIZE]"),
-            }
+        d = {}
+
+        if self.input_mode == "Dense":
+            d[self.key_inputs] = DataDefinition([-1, -1, self.input_size], [torch.Tensor], "Batch of inputs, each represented as index [BATCH_SIZE x SEQ_LEN x INPUT_SIZE]")
+        elif self.input_mode == "Autoregression_First":
+            d[self.key_inputs] = DataDefinition([-1, self.input_size], [torch.Tensor], "Batch of inputs, each represented as index [BATCH_SIZE x INPUT_SIZE]")
+
+        # Input hidden state
+        if self.initial_state == "Input":
+            d[self.key_input_state] = DataDefinition([-1, 2 if self.cell_type == 'LSTM' else 1, self.input_size, 1, self.hidden_size], [torch.Tensor], "Batch of RNN last states")
+
+        return d
@@ -162,17 +183,19 @@ def output_data_definitions(self):
         :return: dictionary containing output data definitions (each of type :py:class:`ptp.utils.DataDefinition`).
         """
+        d = {}
         if self.prediction_mode == "Dense":
-            return {
-                self.key_predictions: DataDefinition([-1, -1, self.prediction_size], [torch.Tensor], "Batch of predictions, each represented as probability distribution over classes [BATCH_SIZE x SEQ_LEN x PREDICTION_SIZE]")
-                }
-        else: # "Last"
-            return {
-                # Only last prediction.
-                self.key_predictions: DataDefinition([-1, self.prediction_size], [torch.Tensor], "Batch of predictions, each represented as probability distribution over classes [BATCH_SIZE x SEQ_LEN x PREDICTION_SIZE]")
-                }
-
+            d[self.key_predictions] = DataDefinition([-1, -1, self.prediction_size], [torch.Tensor], "Batch of predictions, each represented as probability distribution over classes [BATCH_SIZE x SEQ_LEN x PREDICTION_SIZE]")
+        elif self.prediction_mode == "Last":
+            # Only last prediction.
+            d[self.key_predictions] = DataDefinition([-1, self.prediction_size], [torch.Tensor], "Batch of predictions, each represented as probability distribution over classes [BATCH_SIZE x PREDICTION_SIZE]")
+
+        # Output hidden state stream
+        if self.output_last_state:
+            d[self.key_output_state] = DataDefinition([-1, 2 if self.cell_type == 'LSTM' else 1, self.input_size, 1, self.hidden_size], [torch.Tensor], "Batch of RNN last states")
+
+        return d
@@ -184,15 +207,50 @@ def forward(self, data_dict):
         """
         Forward pass of the model.
 
         :param data_dict: DataDict({'inputs', 'predictions ...}), where:
 
             - inputs: expected inputs [BATCH_SIZE x SEQ_LEN x INPUT_SIZE],
             - predictions: returned output with predictions (log_probs) [BATCH_SIZE x SEQ_LEN x PREDICTION_SIZE]
         """
+        inputs = None
+        batch_size = None
+
         # Get inputs [BATCH_SIZE x SEQ_LEN x INPUT_SIZE]
-        inputs = data_dict[self.key_inputs]
-        batch_size = inputs.shape[0]
+        if "None" in self.input_mode:
+            batch_size = data_dict[self.key_input_state][0].shape[1]
+            inputs = torch.zeros(batch_size, 1, self.hidden_size)
+            if next(self.parameters()).is_cuda:
+                inputs = inputs.cuda()
+
+        else:
+            inputs = data_dict[self.key_inputs]
+            if inputs.dim() == 2:
+                inputs = inputs.unsqueeze(1)
+            batch_size = inputs.shape[0]
 
         # Initialize hidden state.
-        hidden = self.initialize_hiddens_state(batch_size)
+        if self.initial_state == "Input":
+            hidden = data_dict[self.key_input_state]
+        else:
+            hidden = self.initialize_hiddens_state(batch_size)
+
+        activations = []
+
+        # Autoregressive mode - feed the outputs back as inputs
+        if "Autoregression" in self.input_mode:
+            activations_partial, hidden = self.rnn_cell(inputs, hidden)
+            activations += [activations_partial]
+            # Feed back the outputs iteratively
+            for i in range(self.autoregression_length - 1):
+                activations_partial, hidden = self.rnn_cell(activations_partial, hidden)
+                # Add the single-step output to the list
+                if self.prediction_mode == "Dense":
+                    activations += [activations_partial]
+            # Reassemble all the outputs from the list into an output sequence
+            if self.prediction_mode == "Dense":
+                activations = torch.stack(activations, 1)
+            else:
+                activations = activations_partial
+        # Normal mode - feed the entire input sequence at once
+        else:
+            activations, hidden = self.rnn_cell(inputs, hidden)
 
-        # Propagate inputs through rnn cell.
-        activations, hidden = self.rnn_cell(inputs, hidden)
         # Propagate activations through dropout layer.
         activations = self.dropout(activations)
@@ -211,7 +269,10 @@
             # Log softmax - along PREDICTION dim.
             if self.use_logsoftmax:
                 outputs = self.log_softmax(outputs)
-        else:
+
+            # Add predictions to datadict.
+            data_dict.extend({self.key_predictions: outputs})
+        elif self.prediction_mode == "Last":
             # Pass only the last activation through the output layer.
             outputs = activations.contiguous()[:, -1, :].squeeze()
             # Propagate data through the output layer [BATCH_SIZE x PREDICTION_SIZE]
@@ -219,6 +280,11 @@
             # Log softmax - along PREDICTION dim.
             if self.use_logsoftmax:
                 outputs = self.log_softmax(outputs)
-
-        # Add predictions to datadict.
-        data_dict.extend({self.key_predictions: outputs})
+            # Add predictions to datadict.
+            data_dict.extend({self.key_predictions: outputs})
+        elif self.prediction_mode == "None":
+            # Nothing, since we don't want to keep the RNN's outputs
+            pass
+
+        if self.output_last_state:
+            data_dict.extend({self.key_output_state: hidden})
diff --git a/ptp/components/models/seq2seq_rnn.py b/ptp/components/models/seq2seq_rnn.py
new file mode 100644
index 0000000..813ab92
--- /dev/null
+++ b/ptp/components/models/seq2seq_rnn.py
@@ -0,0 +1,214 @@
+# Copyright (C) aasseman, IBM Corporation 2019
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+__author__ = "Alexis Asseman"
+
+import torch
+
+from ptp.configuration.configuration_error import ConfigurationError
+from ptp.components.models.model import Model
+from ptp.data_types.data_definition import DataDefinition
+
+
+class Seq2Seq_RNN(Model):
+    """
+    Simple sequence-to-sequence model consisting of an RNN encoder and an autoregressive RNN decoder, followed by a fully connected output layer with optional log softmax non-linearity.
+    """
+    def __init__(self, name, config):
+        """
+        Initializes the model.
+
+        :param config: Dictionary of parameters (read from configuration ``.yaml`` file).
+        :type config: ``ptp.configuration.ConfigInterface``
+        """
+        # Call constructors of parent classes.
+        Model.__init__(self, name, Seq2Seq_RNN, config)
+
+        # Get input/output mode
+        self.input_mode = self.config["input_mode"]
+
+        self.autoregression_length = self.config["autoregression_length"]
+
+        # Check if initial state (h0/c0) is zero, trainable, or coming from input stream.
+        self.initial_state = self.config["initial_state"]
+
+        # Get number of layers from config.
+        self.num_layers = self.config["num_layers"]
+
+        # Retrieve input size from global variables.
+        self.key_input_size = self.global_keys["input_size"]
+        self.input_size = self.globals["input_size"]
+        if type(self.input_size) == list:
+            if len(self.input_size) == 1:
+                self.input_size = self.input_size[0]
+            else:
+                raise ConfigurationError("RNN input size '{}' must be a single dimension (current {})".format(self.key_input_size, self.input_size))
+
+        # Retrieve output (prediction) size from global params.
+        self.key_prediction_size = self.global_keys["prediction_size"]
+        self.prediction_size = self.globals["prediction_size"]
+        if type(self.prediction_size) == list:
+            if len(self.prediction_size) == 1:
+                self.prediction_size = self.prediction_size[0]
+            else:
+                raise ConfigurationError("RNN prediction size '{}' must be a single dimension (current {})".format(self.key_prediction_size, self.prediction_size))
+
+        # Retrieve hidden size from configuration.
+        self.hidden_size = self.config["hidden_size"]
+        if type(self.hidden_size) == list:
+            if len(self.hidden_size) == 1:
+                self.hidden_size = self.hidden_size[0]
+            else:
+                raise ConfigurationError("RNN hidden_size must be a single dimension (current {})".format(self.hidden_size))
+
+        # Create RNN depending on the configuration
+        self.cell_type = self.config["cell_type"]
+        if self.cell_type in ['LSTM', 'GRU']:
+            # Create rnn cell.
+            self.rnn_cell_enc = getattr(torch.nn, self.cell_type)(self.input_size, self.hidden_size, self.num_layers, batch_first=True)
+            self.rnn_cell_dec = getattr(torch.nn, self.cell_type)(self.input_size, self.hidden_size, self.num_layers, batch_first=True)
+        else:
+            try:
+                # Retrieve the non-linearity.
+                nonlinearity = {'RNN_TANH': 'tanh', 'RNN_RELU': 'relu'}[self.cell_type]
+                # Create rnn cell.
+                self.rnn_cell_enc = torch.nn.RNN(self.input_size, self.hidden_size, self.num_layers, nonlinearity=nonlinearity, batch_first=True)
+                self.rnn_cell_dec = torch.nn.RNN(self.input_size, self.hidden_size, self.num_layers, nonlinearity=nonlinearity, batch_first=True)
+            except KeyError:
+                raise ConfigurationError( "Invalid RNN type, available options for 'cell_type' are ['LSTM', 'GRU', 'RNN_TANH', 'RNN_RELU'] (currently '{}')".format(self.cell_type))
+
+        # Parameters - for a single sample.
+        h0 = torch.zeros(self.num_layers, 1, self.hidden_size)
+        c0 = torch.zeros(self.num_layers, 1, self.hidden_size)
+
+        self.init_hidden = None
+
+        if self.initial_state == "Trainable":
+            self.logger.info("Using trainable initial (h0/c0) state")
+            # Initialize a single vector used as hidden state.
+            # Initialize it using xavier initialization.
+            torch.nn.init.xavier_uniform(h0)
+            # It will be trainable, i.e. the system will learn what should be the right initialization state.
+            self.init_hidden = torch.nn.Parameter(h0, requires_grad=True)
+            # Initialize memory cell in a similar way.
+            if self.cell_type == 'LSTM':
+                torch.nn.init.xavier_uniform(c0)
+                self.init_memory = torch.nn.Parameter(c0, requires_grad=True)
+        elif self.initial_state == "Zero":
+            self.logger.info("Using zero initial (h0/c0) state")
+            # We will still embed it into a parameter to enable storing/loading of both types of models by each other.
+            self.init_hidden = torch.nn.Parameter(h0, requires_grad=False)
+            if self.cell_type == 'LSTM':
+                self.init_memory = torch.nn.Parameter(c0, requires_grad=False)
+
+        # Get key mappings.
+        self.key_inputs = self.stream_keys["inputs"]
+        self.key_predictions = self.stream_keys["predictions"]
+
+        self.logger.info("Initializing RNN with input size = {}, hidden size = {} and prediction size = {}".format(self.input_size, self.hidden_size, self.prediction_size))
+
+        # Create the output layer.
+        self.activation2output = torch.nn.Linear(self.hidden_size, self.prediction_size)
+
+        # Create the final non-linearity.
+        self.use_logsoftmax = self.config["use_logsoftmax"]
+        if self.use_logsoftmax:
+            # Used when returning dense prediction, i.e. every output of unfolded model.
+            self.log_softmax = torch.nn.LogSoftmax(dim=2)
+
+    def initialize_hiddens_state(self, batch_size):
+
+        if self.cell_type == 'LSTM':
+            # Return tuple (hidden_state, memory_cell).
+            return (self.init_hidden.expand(self.num_layers, batch_size, self.hidden_size).contiguous(),
+                    self.init_memory.expand(self.num_layers, batch_size, self.hidden_size).contiguous() )
+
+        else:
+            # Return hidden_state.
+            return self.init_hidden.expand(self.num_layers, batch_size, self.hidden_size).contiguous()
+
+
+    def input_data_definitions(self):
+        """
+        Function returns a dictionary with definitions of input data that are required by the component.
+
+        :return: dictionary containing input data definitions (each of type :py:class:`ptp.utils.DataDefinition`).
+        """
+        d = {}
+
+        d[self.key_inputs] = DataDefinition([-1, -1, self.input_size], [torch.Tensor], "Batch of inputs, each represented as index [BATCH_SIZE x SEQ_LEN x INPUT_SIZE]")
+
+        return d
+
+    def output_data_definitions(self):
+        """
+        Function returns a dictionary with definitions of output data produced by the component.
+
+        :return: dictionary containing output data definitions (each of type :py:class:`ptp.utils.DataDefinition`).
+        """
+        d = {}
+
+        d[self.key_predictions] = DataDefinition([-1, -1, self.prediction_size], [torch.Tensor], "Batch of predictions, each represented as probability distribution over classes [BATCH_SIZE x SEQ_LEN x PREDICTION_SIZE]")
+
+        return d
+
+    def forward(self, data_dict):
+        """
+        Forward pass of the model.
+
+        :param data_dict: DataDict({'inputs', 'predictions ...}), where:
+
+            - inputs: expected inputs [BATCH_SIZE x SEQ_LEN x INPUT_SIZE],
+            - predictions: returned output with predictions (log_probs) [BATCH_SIZE x SEQ_LEN x PREDICTION_SIZE]
+        """
+
+        # Get inputs [BATCH_SIZE x SEQ_LEN x INPUT_SIZE]
+        inputs = data_dict[self.key_inputs]
+        if inputs.dim() == 2:
+            inputs = inputs.unsqueeze(1)
+        batch_size = inputs.shape[0]
+
+        # Initialize hidden state.
+        hidden = self.initialize_hiddens_state(batch_size)
+
+        # Encoder
+        activations, hidden = self.rnn_cell_enc(inputs, hidden)
+
+        # Decoder - feed the last encoder activation as the first decoder input, then autoregress.
+        activations_partial, hidden = self.rnn_cell_dec(activations[:, -1, :].unsqueeze(1), hidden)
+        activations = []
+        activations += [activations_partial]
+        for i in range(self.autoregression_length - 1):
+            activations_partial, hidden = self.rnn_cell_dec(activations_partial, hidden)
+            activations += [activations_partial]
+        activations = torch.stack(activations, 1)
+
+        # Pass every activation through the output layer.
+        # Reshape to 2D tensor [BATCH_SIZE * SEQ_LEN x HIDDEN_SIZE]
+        outputs = activations.contiguous().view(-1, self.hidden_size)
+
+        # Propagate data through the output layer [BATCH_SIZE * SEQ_LEN x PREDICTION_SIZE]
+        outputs = self.activation2output(outputs)
+
+        # Reshape back to 3D tensor [BATCH_SIZE x SEQ_LEN x PREDICTION_SIZE]
+        outputs = outputs.view(activations.size(0), activations.size(1), outputs.size(1))
+
+        # Log softmax - along PREDICTION dim.
+        if self.use_logsoftmax:
+            outputs = self.log_softmax(outputs)
+
+        # Add predictions to datadict.
+        data_dict.extend({self.key_predictions: outputs})