Skip to content
This repository was archived by the owner on Jul 18, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions configs/mnist/mnist_classification_softmax.yml
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
# Load config defining MNIST problems for training, validation and testing.
default_configs: mnist/default_mnist.yml

training:
optimizer:
name: SGD


pipeline:
name: mnist_softmax_classifier

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@ default_configs: vqa_med_2019/default_vqa_med_2019.yml
# Training parameters:
training:
problem:
batch_size: 64
batch_size: 200 # requires 4 GPUs!
categories: C4
question_preprocessing: lowercase, remove_punctuation, tokenize #, random_remove_stop_words #,random_shuffle_words
answer_preprocessing: lowercase, remove_punctuation, tokenize
export_sample_weights: ~/data/vqa-med/answers.c4.weights.csv
batch_size: 32
sampler:
weights: ~/data/vqa-med/answers.c4.weights.csv
dataloader:
Expand All @@ -23,11 +22,10 @@ training:
# Validation parameters:
validation:
problem:
batch_size: 64
batch_size: 200
categories: C4
question_preprocessing: lowercase, remove_punctuation, tokenize
answer_preprocessing: lowercase, remove_punctuation, tokenize
batch_size: 32
dataloader:
num_workers: 4

Expand All @@ -49,7 +47,7 @@ pipeline:
pretrained_embeddings_file: glove.6B.100d.txt
data_folder: ~/data/vqa-med
word_mappings_file: questions.all.word.mappings.csv
fixed_padding: 10
fixed_padding: 10 # The longest question! max is 19!
additional_tokens: <PAD>,<EOS>
streams:
inputs: questions
Expand All @@ -66,7 +64,7 @@ pipeline:
export_pad_mapping_to_globals: True
additional_tokens: <PAD>,<EOS>
eos_token: True
fixed_padding: 10
fixed_padding: 10 # The longest question! max is 19!
streams:
inputs: answers
outputs: indexed_answers
Expand All @@ -88,9 +86,11 @@ pipeline:

# Single layer GRU Encoder
encoder:
priority: 3
type: RecurrentNeuralNetwork
# Do not wrap this model with DataDictParallel!
parallelize: False
cell_type: GRU
priority: 3
initial_state: Trainable
hidden_size: 100
num_layers: 1
Expand All @@ -110,7 +110,7 @@ pipeline:
reshaper_1:
priority: 3.01
type: ReshapeTensor
input_dims: [1, -1, 100]
input_dims: [-1, 1, 100]
output_dims: [-1, 100]
streams:
inputs: s2s_state_output
Expand Down Expand Up @@ -148,7 +148,7 @@ pipeline:
priority: 3.3
type: ReshapeTensor
input_dims: [-1, 100]
output_dims: [1, -1, 100]
output_dims: [-1, 1, 100]
streams:
inputs: question_image_activations
outputs: question_image_activations_reshaped
Expand All @@ -161,7 +161,7 @@ pipeline:
priority: 4
hidden_size: 100
use_logsoftmax: False
autoregression_length: 10
autoregression_length: 10 # Current implementation requires this value to be equal to fixed_padding in SentenceEmbeddings/Indexer...
prediction_mode: Dense
dropout_rate: 0.1
streams:
Expand Down Expand Up @@ -198,8 +198,8 @@ pipeline:

# Prediction decoding.
prediction_decoder:
type: SentenceIndexer
priority: 10
type: SentenceIndexer
# Reverse mode.
reverse: True
# Use distributions as inputs.
Expand Down
17 changes: 12 additions & 5 deletions ptp/components/models/attn_decoder_rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def __init__(self, name, config):
# Create dropout layer.
self.dropout = torch.nn.Dropout(dropout_rate)

# Create rnn cell.
# Create rnn cell: hardcoded one layer GRU.
self.rnn_cell = getattr(torch.nn, "GRU")(self.input_size, self.hidden_size, 1, dropout=dropout_rate, batch_first=True)

# Create layers for the attention
Expand Down Expand Up @@ -149,7 +149,7 @@ def input_data_definitions(self):
d[self.key_inputs] = DataDefinition([-1, -1, self.hidden_size], [torch.Tensor], "Batch of encoder outputs [BATCH_SIZE x SEQ_LEN x INPUT_SIZE]")

# Input hidden state
d[self.key_input_state] = DataDefinition([1, -1, self.hidden_size], [torch.Tensor], "Batch of RNN last states")
d[self.key_input_state] = DataDefinition([-1, 1, self.hidden_size], [torch.Tensor], "Batch of RNN last hidden states passed from another RNN that will be used as initial [BATCH_SIZE x NUM_LAYERS x SEQ_LEN x HIDDEN_SIZE]")

return d

Expand All @@ -167,9 +167,9 @@ def output_data_definitions(self):
# Only last prediction.
d[self.key_predictions] = DataDefinition([-1, self.prediction_size], [torch.Tensor], "Batch of predictions, each represented as probability distribution over classes [BATCH_SIZE x SEQ_LEN x PREDICTION_SIZE]")

# Output hidden state stream
# Output hidden state stream TODO: why do we need that?
if self.output_last_state:
d[self.key_output_state] = DataDefinition([1, -1, self.hidden_size], [torch.Tensor], "Batch of RNN last states")
d[self.key_output_state] = DataDefinition([-1, 1, self.hidden_size], [torch.Tensor], "Batch of RNN final hidden states [BATCH_SIZE x NUM_LAYERS x SEQ_LEN x HIDDEN_SIZE]")

return d

Expand All @@ -185,9 +185,13 @@ def forward(self, data_dict):

inputs = data_dict[self.key_inputs]
batch_size = inputs.shape[0]
#print("{}: input shape: {}, device: {}\n".format(self.name, inputs.shape, inputs.device))

# Initialize hidden state.
# Initialize hidden state from inputs - as last hidden state from external component.
hidden = data_dict[self.key_input_state]
# For RNNs (aside from LSTM): [BATCH_SIZE x NUM_LAYERS x HIDDEN_SIZE] -> [NUM_LAYERS x BATCH_SIZE x HIDDEN_SIZE]
hidden = hidden.transpose(0,1)
#print("{}: hidden shape: {}, device: {}\n".format(self.name, hidden.shape, hidden.device))

# List that will contain the output sequence
activations = []
Expand Down Expand Up @@ -232,4 +236,7 @@ def forward(self, data_dict):

# Output last hidden state, if requested
if self.output_last_state:
# For others: [NUM_LAYERS x BATCH_SIZE x HIDDEN_SIZE] -> [BATCH_SIZE x NUM_LAYERS x HIDDEN_SIZE]
hidden = hidden.transpose(0,1)
# Export last hidden state.
data_dict.extend({self.key_output_state: hidden})
30 changes: 23 additions & 7 deletions ptp/components/models/recurrent_neural_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def __init__(self, name, config):
except KeyError:
raise ConfigurationError( "Invalid RNN type, available options for 'cell_type' are ['LSTM', 'GRU', 'RNN_TANH', 'RNN_RELU'] (currently '{}')".format(self.cell_type))

# Parameters - for a single sample.
# Parameters - for a single sample 2 x [NUM_LAYERS x BATCH_SIZE x HIDDEN_SIZE]
h0 = torch.zeros(self.num_layers, 1, self.hidden_size)
c0 = torch.zeros(self.num_layers, 1, self.hidden_size)

Expand Down Expand Up @@ -228,9 +228,9 @@ def input_data_definitions(self):
# Input hidden state
if self.initial_state == "Input":
if self.cell_type == "LSTM":
d[self.key_input_state] = DataDefinition([2, self.num_layers, -1, self.hidden_size], [torch.Tensor], "Batch of LSTM initial hidden states (h0/c0) passed from another LSTM [2 x NUM_LAYERS x SEQ_LEN x HIDDEN_SIZE]")
d[self.key_input_state] = DataDefinition([-1, 2, self.num_layers, self.hidden_size], [torch.Tensor], "Batch of LSTM last hidden states (h0/c0) passed from another LSTM that will be used as initial [BATCH_SIZE x 2 x NUM_LAYERS x SEQ_LEN x HIDDEN_SIZE]")
else:
d[self.key_input_state] = DataDefinition([self.num_layers, -1, self.hidden_size], [torch.Tensor], "Batch of RNN initial hidden states passed from another RNN [NUM_LAYERS x SEQ_LEN x HIDDEN_SIZE]")
d[self.key_input_state] = DataDefinition([-1, self.num_layers, self.hidden_size], [torch.Tensor], "Batch of RNN last hidden states passed from another RNN that will be used as initial [BATCH_SIZE x NUM_LAYERS x SEQ_LEN x HIDDEN_SIZE]")

return d

Expand All @@ -253,9 +253,9 @@ def output_data_definitions(self):
# Output: hidden state stream.
if self.output_last_state:
if self.cell_type == "LSTM":
d[self.key_output_state] = DataDefinition([2, self.num_layers, -1, self.hidden_size], [torch.Tensor], "Batch of LSTM final hidden states (h0/c0) [2 x NUM_LAYERS x SEQ_LEN x HIDDEN_SIZE]")
d[self.key_output_state] = DataDefinition([-1, 2, self.num_layers, self.hidden_size], [torch.Tensor], "Batch of LSTM final hidden states (h0/c0) [BATCH_SIZE x 2 x NUM_LAYERS x SEQ_LEN x HIDDEN_SIZE]")
else:
d[self.key_output_state] = DataDefinition([self.num_layers, -1, self.hidden_size], [torch.Tensor], "Batch of RNN final hidden states [NUM_LAYERS x SEQ_LEN x HIDDEN_SIZE]")
d[self.key_output_state] = DataDefinition([-1, self.num_layers, self.hidden_size], [torch.Tensor], "Batch of RNN final hidden states [BATCH_SIZE x NUM_LAYERS x SEQ_LEN x HIDDEN_SIZE]")

return d

Expand Down Expand Up @@ -285,14 +285,22 @@ def forward(self, data_dict):
if inputs.dim() == 2:
inputs = inputs.unsqueeze(1)
batch_size = inputs.shape[0]

#print("{}: input shape: {}, device: {}\n".format(self.name, inputs.shape, inputs.device))

# Get initial state, depending on the settings.
if self.initial_state == "Input":
# Initialize hidden state.
# Initialize hidden state from inputs - as last hidden state from external component.
hidden = data_dict[self.key_input_state]
# Swap dims so that the batch dimension moves from first to third (LSTM) or second (other cell types).
if self.cell_type == 'LSTM':
# For LSTM: [BATCH_SIZE x NUM_LAYERS x 2 x HIDDEN_SIZE] -> [2 x NUM_LAYERS x BATCH_SIZE x HIDDEN_SIZE]
hidden = hidden.transpose(0,2)
else:
# For others: [BATCH_SIZE x NUM_LAYERS x HIDDEN_SIZE] -> [NUM_LAYERS x BATCH_SIZE x HIDDEN_SIZE]
hidden = hidden.transpose(0,1)
else:
hidden = self.initialize_hiddens_state(batch_size)
#print("{}: hidden shape: {}, device: {}\n".format(self.name, hidden.shape, hidden.device))

activations = []

Expand Down Expand Up @@ -355,4 +363,12 @@ def forward(self, data_dict):
pass

if self.output_last_state:
# Flip batch and num_layer dims so batch will be first!
if self.cell_type == 'LSTM':
# For LSTM: [2 x NUM_LAYERS x BATCH_SIZE x HIDDEN_SIZE] -> [BATCH_SIZE x NUM_LAYERS x 2 x HIDDEN_SIZE]
hidden = hidden.transpose(0,2)
else:
# For others: [NUM_LAYERS x BATCH_SIZE x HIDDEN_SIZE] -> [BATCH_SIZE x NUM_LAYERS x HIDDEN_SIZE]
hidden = hidden.transpose(0,1)
# Export last hidden state.
data_dict.extend({self.key_output_state: hidden})
35 changes: 7 additions & 28 deletions ptp/components/models/sentence_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ def __init__(self, name, config):
emb_vectors = emb.load_pretrained_embeddings(self.logger, self.data_folder, self.config["pretrained_embeddings_file"], self.word_to_ix, self.embeddings_size)
self.embeddings.weight = torch.nn.Parameter(emb_vectors)

# Get index of <PAD> from vocabulary.
self.pad_index = self.word_to_ix['<PAD>']



def input_data_definitions(self):
"""
Expand Down Expand Up @@ -110,12 +114,8 @@ def forward(self, data_dict):

# Unpack DataDict.
inputs = data_dict[self.key_inputs]

#print("{}: input len: {}, device: {}\n".format(self.name, len(inputs), "-"))

# Get index of padding.
pad_index = self.word_to_ix['<PAD>']

indices_list = []
# Process samples one by one.
for sample in inputs:
Expand All @@ -132,35 +132,14 @@ def forward(self, data_dict):
# Apply fixed padding to all sequences if requested
# Otherwise let torch.nn.utils.rnn.pad_sequence handle it and choose a dynamic padding
if self.fixed_padding > 0:
pad_trunc_list(output_sample, self.fixed_padding, padding_value=pad_index)
pad_trunc_list(output_sample, self.fixed_padding, padding_value=self.pad_index)

#indices_list.append(self.app_state.FloatTensor(output_sample))
indices_list.append(self.app_state.LongTensor(output_sample))

# Create list of (index,len) tuples.
#order_len = [(i, len(indices_list[i])) for i in range(len(indices_list))]

# Sort by seq_length.
#sorted_order_len = sorted(order_len, key=lambda x: x[1], reverse=True)

# Now sort indices list.
#sorted_indices_list = [indices_list[sorted_order_len[i][0]] for i in range(len(indices_list))]

# Pad the indices list.
#padded_indices = torch.nn.utils.rnn.pad_sequence(sorted_indices_list, batch_first=True)

# Revert to the original batch order!!!
#new_old_order = [(i, sorted_order_len[i][0]) for i in range(len(indices_list))]
#sorted_new_old_order = sorted(new_old_order, key=lambda x: x[1])
#unsorted_padded_indices = [padded_indices[sorted_new_old_order[i][0]] for i in range(len(indices_list))]

# Change to tensor.
#unsorted_padded_indices_tensor = torch.stack( [self.app_state.FloatTensor(lst) for lst in unsorted_padded_indices] )

# Pad indices using the pad index retrieved from the vocabulary.
padded_indices = torch.nn.utils.rnn.pad_sequence(indices_list, batch_first=True, padding_value=self.pad_index)
# Embed indices.
#embedds = self.embeddings(unsorted_padded_indices_tensor)

padded_indices = torch.nn.utils.rnn.pad_sequence(indices_list, batch_first=True, padding_value=pad_index)
embedds = self.embeddings(padded_indices)

#print("{}: embedds shape: {}, device: {}\n".format(self.name, embedds.shape, embedds.device))
Expand Down
5 changes: 3 additions & 2 deletions ptp/components/utils/word_mappings.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,13 +146,14 @@ def pad_trunc_list(l: list, length: int, padding_value = 0, eos_value = None):

:return: None
"""

if len(l) < length:
if eos_value is not None:
l.append(eos_value)
l.extend([padding_value]*(length-len(l)))

elif len(l) > length:
#print("pad_trunc_list to cat!: {}".format(len(l)))
#exit(1)
del l[length:]
if eos_value is not None:
l[length-1] = eos_value
l[length-1] = eos_value