Skip to content
This repository was archived by the owner on Jul 18, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,6 @@ venv.bak/

# mypy
.mypy_cache/

# vscode
.vscode/
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ dropout_rate: 0
# * None (all outputs are discarded)
prediction_mode: Dense

# Enable the FFN layer at the output of the RNN (applied before the output is fed back in the case of autoregression).
# Set to False if the raw outputs of the RNN are needed, for example in an attention encoder-decoder.
ffn_output: True

# Input mode
# Options:
# * Dense (every iteration expects an input)
Expand Down
36 changes: 7 additions & 29 deletions configs/wikitext/wikitext_language_modeling_seq2seq.yml
Original file line number Diff line number Diff line change
Expand Up @@ -73,28 +73,6 @@ pipeline:
streams:
inputs: targets
outputs: indexed_targets

# Publish the hidden size of the seq2seq
global_publisher:
type: GlobalVariablePublisher
priority: 1
# Add s2s_hidden_size to globals, so the seq2seq components can use it.
keys: s2s_hidden_size
values: 300

# FF, to resize the embeddings to whatever the hidden size of the seq2seq is.
ff_resize_s2s_input:
type: FeedForwardNetwork
priority: 2.5
s2s_hidden_size: 300
use_logsoftmax: False
dimensions: 3
streams:
inputs: embedded_sources
predictions: embedded_sources_resized
globals:
input_size: embeddings_size
prediction_size: s2s_hidden_size

# LSTM Encoder
lstm_encoder:
Expand All @@ -107,12 +85,12 @@ pipeline:
output_last_state: True
prediction_mode: Last
streams:
inputs: embedded_sources_resized
inputs: embedded_sources
predictions: s2s_encoder_output
output_state: s2s_state_output
globals:
input_size: s2s_hidden_size
prediction_size: s2s_hidden_size
input_size: embeddings_size
prediction_size: embeddings_size

# LSTM Decoder
lstm_decoder:
Expand All @@ -130,10 +108,10 @@ pipeline:
predictions: s2s_decoder_output
input_state: s2s_state_output
globals:
input_size: s2s_hidden_size
prediction_size: s2s_hidden_size
input_size: embeddings_size
prediction_size: embeddings_size

# FF, to resize from the hidden size of the seq2seq to the size of the target vector
# FF, to resize from the output size of the seq2seq to the size of the target vector
ff_resize_s2s_output:
type: FeedForwardNetwork
use_logsoftmax: True
Expand All @@ -142,7 +120,7 @@ pipeline:
streams:
inputs: s2s_decoder_output
globals:
input_size: s2s_hidden_size
input_size: embeddings_size
prediction_size: vocabulary_size

# Loss
Expand Down
30 changes: 4 additions & 26 deletions configs/wikitext/wikitext_language_modeling_seq2seq_simple.yml
Original file line number Diff line number Diff line change
Expand Up @@ -73,28 +73,6 @@ pipeline:
streams:
inputs: targets
outputs: indexed_targets

# Publish the hidden size of the seq2seq
global_publisher:
type: GlobalVariablePublisher
priority: 1
# Add s2s_hidden_size to globals, so the seq2seq components can use it.
keys: s2s_hidden_size
values: 300

# FF, to resize the embeddings to whatever the hidden size of the seq2seq is.
ff_resize_s2s_input:
type: FeedForwardNetwork
priority: 2.5
s2s_hidden_size: 300
use_logsoftmax: False
dimensions: 3
streams:
inputs: embedded_sources
predictions: embedded_sources_resized
globals:
input_size: embeddings_size
prediction_size: s2s_hidden_size

# LSTM seq2seq
lstm_encoder:
Expand All @@ -105,11 +83,11 @@ pipeline:
num_layers: 3
use_logsoftmax: False
streams:
inputs: embedded_sources_resized
inputs: embedded_sources
predictions: s2s_output
globals:
input_size: s2s_hidden_size
prediction_size: s2s_hidden_size
input_size: embeddings_size
prediction_size: embeddings_size

# FF, to resize from the hidden size of the seq2seq to the size of the target vector
ff_resize_s2s_output:
Expand All @@ -120,7 +98,7 @@ pipeline:
streams:
inputs: s2s_output
globals:
input_size: s2s_hidden_size
input_size: embeddings_size
prediction_size: vocabulary_size

# Loss
Expand Down
116 changes: 73 additions & 43 deletions ptp/components/models/recurrent_neural_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(self, name, config):
# Get input/output mode
self.input_mode = self.config["input_mode"]
self.output_last_state = self.config["output_last_state"]
self.ffn_output = self.config["ffn_output"]

# Get prediction mode from configuration.
self.prediction_mode = self.config["prediction_mode"]
Expand Down Expand Up @@ -68,6 +69,9 @@ def __init__(self, name, config):
self.prediction_size = self.prediction_size[0]
else:
raise ConfigurationError("RNN prediction size '{}' must be a single dimension (current {})".format(self.key_prediction_size, self.prediction_size))

if "Autoregression" in self.input_mode:
assert self.input_size == self.prediction_size, "In autoregression mode, needs input_size == prediction_size."

# Retrieve hidden size from configuration.
self.hidden_size = self.config["hidden_size"]
Expand Down Expand Up @@ -134,7 +138,9 @@ def __init__(self, name, config):
self.logger.info("Initializing RNN with input size = {}, hidden size = {} and prediction size = {}".format(self.input_size, self.hidden_size, self.prediction_size))

# Create the output layer.
self.activation2output = torch.nn.Linear(self.hidden_size, self.prediction_size)
self.activation2output_lin = None
if(self.ffn_output):
self.activation2output_lin = torch.nn.Linear(self.hidden_size, self.prediction_size)

# Create the final non-linearity.
self.use_logsoftmax = self.config["use_logsoftmax"]
Expand All @@ -157,6 +163,25 @@ def initialize_hiddens_state(self, batch_size):
# Return hidden_state.
return self.init_hidden.expand(self.num_layers, batch_size, self.hidden_size).contiguous()

def activation2output(self, activations):
    """
    Applies dropout to the RNN activations and, when ``ffn_output`` is enabled,
    projects them through the output linear layer.

    The last dimension of ``activations`` is treated as the feature dimension;
    every leading dimension is preserved, so both
    [BATCH_SIZE x SEQ_LEN x HIDDEN_SIZE] and [BATCH_SIZE x HIDDEN_SIZE] tensors
    are accepted.

    :param activations: Tensor of RNN activations, features in the last dimension.

    :return: Tensor with the same leading dimensions as ``activations``. The last \
        dimension is whatever ``activation2output_lin`` produces when ``ffn_output`` \
        is enabled; otherwise the (dropped-out) activations are returned as they are.
    """
    output = self.dropout(activations)

    if self.ffn_output:
        shape = activations.shape

        # Collapse all leading dimensions into one: [N x HIDDEN_SIZE].
        # Indexing from the end (instead of a hard-coded shape[2]) keeps this
        # working for inputs that are not exactly 3D.
        output = output.contiguous().view(-1, shape[-1])

        # Propagate data through the output layer: [N x PREDICTION_SIZE].
        output = self.activation2output_lin(output)

        # Restore the original leading dimensions: [... x PREDICTION_SIZE].
        output = output.view(*shape[:-1], output.size(-1))

    return output


def input_data_definitions(self):
"""
Expand All @@ -173,7 +198,10 @@ def input_data_definitions(self):

# Input hidden state
if self.initial_state == "Input":
d[self.key_input_state] = DataDefinition([-1, 2 if self.cell_type == 'LSTM' else 1, self.input_size, 1, self.hidden_size], [torch.tensor], "Batch of RNN last states")
if self.cell_type == "LSTM":
d[self.key_input_state] = DataDefinition([2, self.num_layers, -1, self.hidden_size], [torch.Tensor], "Batch of RNN last states")
else:
d[self.key_input_state] = DataDefinition([self.num_layers, -1, self.hidden_size], [torch.Tensor], "Batch of RNN last states")

return d

Expand All @@ -193,8 +221,11 @@ def output_data_definitions(self):

# Output hidden state stream
if self.output_last_state:
d[self.key_output_state] = DataDefinition([-1, 2 if self.cell_type == 'LSTM' else 1, self.input_size, 1, self.hidden_size], [torch.tensor], "Batch of RNN last states")

if self.cell_type == "LSTM":
d[self.key_output_state] = DataDefinition([2, self.num_layers, -1, self.hidden_size], [torch.Tensor], "Batch of RNN last states")
else:
d[self.key_output_state] = DataDefinition([self.num_layers, -1, self.hidden_size], [torch.Tensor], "Batch of RNN last states")

return d

def forward(self, data_dict):
Expand All @@ -213,10 +244,9 @@ def forward(self, data_dict):
# Get inputs [BATCH_SIZE x SEQ_LEN x INPUT_SIZE]
if "None" in self.input_mode:
batch_size = data_dict[self.key_input_state][0].shape[1]
inputs = torch.zeros(batch_size, 1, self.hidden_size)
inputs = torch.zeros(batch_size, self.hidden_size)
if next(self.parameters()).is_cuda:
inputs = inputs.cuda()

else:
inputs = data_dict[self.key_inputs]
if inputs.dim() == 2:
Expand All @@ -235,56 +265,56 @@ def forward(self, data_dict):
# Autoregressive mode - feed back outputs in the input
if "Autoregression" in self.input_mode:
activations_partial, hidden = self.rnn_cell(inputs, hidden)
activations_partial = self.activation2output(activations_partial)
activations += [activations_partial]

# Feed back the outputs iteratively
for i in range(self.autoregression_length - 1):
activations_partial, hidden = self.rnn_cell(activations_partial, hidden)
activations_partial = self.activation2output(activations_partial)
# Add the single step output into list
if self.prediction_mode == "Dense":
activations += [activations_partial]
# Reassemble all the outputs from list into an output sequence
if self.prediction_mode == "Dense":
activations = torch.stack(activations, 1)
else:
activations = activations_partial
outputs = torch.cat(activations, 1)
# Log softmax - along PREDICTION dim.
if self.use_logsoftmax:
outputs = self.log_softmax(outputs)
# Add predictions to datadict.
data_dict.extend({self.key_predictions: outputs})
elif self.prediction_mode == "Last":
if self.use_logsoftmax:
outputs = self.log_softmax(activations_partial.squeeze(1))
# Add predictions to datadict.
data_dict.extend({self.key_predictions: outputs})

# Normal mode - feed the entire input sequence at once
else:
activations, hidden = self.rnn_cell(inputs, hidden)


# Propagate activations through dropout layer.
activations = self.dropout(activations)

if self.prediction_mode == "Dense":
# Pass every activation through the output layer.
# Reshape to 2D tensor [BATCH_SIZE * SEQ_LEN x HIDDEN_SIZE]
outputs = activations.contiguous().view(-1, self.hidden_size)

# Propagate data through the output layer [BATCH_SIZE * SEQ_LEN x PREDICTION_SIZE]
outputs = self.activation2output(outputs)

# Reshape back to 3D tensor [BATCH_SIZE x SEQ_LEN x PREDICTION_SIZE]
outputs = outputs.view(activations.size(0), activations.size(1), outputs.size(1))

# Log softmax - along PREDICTION dim.
if self.use_logsoftmax:
outputs = self.log_softmax(outputs)

# Add predictions to datadict.
data_dict.extend({self.key_predictions: outputs})
elif self.prediction_mode == "Last":
# Pass only the last activation through the output layer.
outputs = activations.contiguous()[:, -1, :].squeeze()
# Propagate data through the output layer [BATCH_SIZE x PREDICTION_SIZE]
outputs = self.activation2output(outputs)
# Log softmax - along PREDICTION dim.
if self.use_logsoftmax:
outputs = self.log_softmax(outputs)
# Add predictions to datadict.
data_dict.extend({self.key_predictions: outputs})
elif self.prediction_mode == "None":
# Nothing, since we don't want to keep the RNN's outputs
pass
if self.prediction_mode == "Dense":
# Pass every activation through the output layer.
outputs = self.activation2output(activations)

# Log softmax - along PREDICTION dim.
if self.use_logsoftmax:
outputs = self.log_softmax(outputs)

# Add predictions to datadict.
data_dict.extend({self.key_predictions: outputs})
elif self.prediction_mode == "Last":
outputs = self.activation2output(activations.contiguous()[:, -1, :].unsqueeze(1))
outputs = outputs.squeeze(1)

# Log softmax - along PREDICTION dim.
if self.use_logsoftmax:
outputs = self.log_softmax(outputs)
# Add predictions to datadict.
data_dict.extend({self.key_predictions: outputs})
elif self.prediction_mode == "None":
# Nothing, since we don't want to keep the RNN's outputs
pass

if self.output_last_state:
data_dict.extend({self.key_output_state: hidden})
24 changes: 9 additions & 15 deletions ptp/components/models/seq2seq_rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,25 +185,19 @@ def forward(self, data_dict):

# Encoder
activations, hidden = self.rnn_cell_enc(inputs, hidden)
activations_partial = self.activation2output(activations[:, -1, :])

# Propagate inputs through rnn cell.
activations_partial, hidden = self.rnn_cell_dec(activations[:, -1, :].unsqueeze(1), hidden)
activations = []
activations += [activations_partial]
activations_partial, hidden = self.rnn_cell_dec(activations_partial.unsqueeze(1), hidden)
activations_partial = activations_partial.squeeze(1)
activations_partial = self.activation2output(activations_partial)
activations = [activations_partial]
for i in range(self.autoregression_length - 1):
activations_partial, hidden = self.rnn_cell_dec(activations_partial, hidden)
activations_partial, hidden = self.rnn_cell_dec(activations_partial.unsqueeze(1), hidden)
activations_partial = activations_partial.squeeze(1)
activations_partial = self.activation2output(activations_partial)
activations += [activations_partial]
activations = torch.stack(activations, 1)

# Pass every activation through the output layer.
# Reshape to 2D tensor [BATCH_SIZE * SEQ_LEN x HIDDEN_SIZE]
outputs = activations.contiguous().view(-1, self.hidden_size)

# Propagate data through the output layer [BATCH_SIZE * SEQ_LEN x PREDICTION_SIZE]
outputs = self.activation2output(outputs)

# Reshape back to 3D tensor [BATCH_SIZE x SEQ_LEN x PREDICTION_SIZE]
outputs = outputs.view(activations.size(0), activations.size(1), outputs.size(1))
outputs = torch.stack(activations, 1)

# Log softmax - along PREDICTION dim.
if self.use_logsoftmax:
Expand Down