diff --git a/configs/default/components/models/vqa/relational_network.yml b/configs/default/components/models/vqa/relational_network.yml
new file mode 100644
index 0000000..791d62e
--- /dev/null
+++ b/configs/default/components/models/vqa/relational_network.yml
@@ -0,0 +1,55 @@
+# This file defines the default values for the RelationalNetwork model.
+
+####################################################################
+# 1. CONFIGURATION PARAMETERS that will be LOADED by the component.
+####################################################################
+
+# Dropout rate (LOADED)
+# Default: 0 (means that it is turned off)
+dropout_rate: 0
+
+# Size of the output of the g_theta network, which is also the size of the component output (LOADED)
+output_size: 256
+
+streams:
+  ####################################################################
+  # 2. Keymappings associated with INPUT and OUTPUT streams.
+  ####################################################################
+
+  # Stream containing batch of image feature maps (INPUT)
+  feature_maps: feature_maps
+
+  # Stream containing batch of encoded questions (INPUT)
+  question_encodings: question_encodings
+
+  # Stream containing outputs (OUTPUT)
+  outputs: outputs
+
+globals:
+  ####################################################################
+  # 3. Keymappings of variables that will be RETRIEVED from GLOBALS.
+  ####################################################################
+
+  # Height of the features tensor (RETRIEVED)
+  feature_maps_height: feature_maps_height
+
+  # Width of the features tensor (RETRIEVED)
+  feature_maps_width: feature_maps_width
+
+  # Depth of the features tensor (RETRIEVED)
+  feature_maps_depth: feature_maps_depth
+
+  # Size of the question encodings input (RETRIEVED)
+  question_encoding_size: question_encoding_size
+
+  ####################################################################
+  # 4. Keymappings associated with GLOBAL variables that will be SET.
+  ####################################################################
+
+  # Size of the output (SET)
+  output_size: output_size
+
+  ####################################################################
+  # 5. Keymappings associated with statistics that will be ADDED.
+  ####################################################################
+
diff --git a/configs/vqa_med_2019/c2_classification/c2_classification_all_rnn_vgg16_ewm_size.yml b/configs/vqa_med_2019/c2_classification/c2_classification_all_rnn_vgg16_ewm_size.yml
index 91ff5d1..72c5e1a 100644
--- a/configs/vqa_med_2019/c2_classification/c2_classification_all_rnn_vgg16_ewm_size.yml
+++ b/configs/vqa_med_2019/c2_classification/c2_classification_all_rnn_vgg16_ewm_size.yml
@@ -1,6 +1,14 @@
 # Load config defining problems for training, validation and testing.
 default_configs: vqa_med_2019/c2_classification/default_c2_classification.yml
 
+# Training parameters:
+training:
+  problem:
+    batch_size: 64
+validation:
+  problem:
+    batch_size: 64
+
 pipeline:
   name: c2_classification_all_rnn_vgg16_ewm_size
 
@@ -24,8 +32,8 @@ pipeline:
   question_embeddings:
     priority: 1.2
     type: SentenceEmbeddings
-    embeddings_size: 50
-    pretrained_embeddings_file: glove.6B.50d.txt
+    embeddings_size: 100
+    pretrained_embeddings_file: glove.6B.100d.txt
     data_folder: ~/data/vqa-med
     word_mappings_file: questions.all.word.mappings.csv
     streams:
@@ -39,8 +47,9 @@ pipeline:
     cell_type: LSTM
     prediction_mode: Last
     use_logsoftmax: False
-    initial_state_trainable: False
+    initial_state_trainable: True
     hidden_size: 50
+    #dropout_rate: 0.5
     streams:
       inputs: embedded_questions
       predictions: question_activations
@@ -117,7 +126,7 @@ pipeline:
   classifier:
     priority: 5.3
     type: FeedForwardNetwork
-    hidden_sizes: [110]
+    hidden_sizes: [100]
    dropout_rate: 0.5
     streams:
       inputs: concatenated_activations
diff --git a/configs/vqa_med_2019/c2_classification/c2_classification_all_rnn_vgg16_mcb.yml b/configs/vqa_med_2019/c2_classification/c2_classification_all_rnn_vgg16_mcb.yml
index 0ea068e..75e41ed 100644
--- a/configs/vqa_med_2019/c2_classification/c2_classification_all_rnn_vgg16_mcb.yml
+++ b/configs/vqa_med_2019/c2_classification/c2_classification_all_rnn_vgg16_mcb.yml
@@ -72,7 +72,6 @@ pipeline:
   question_image_fusion:
     priority: 4.1
     type: MultimodalCompactBilinearPooling
-    dropout_rate: 0.5
     streams:
       image_encodings: image_activations
       question_encodings: question_activations
diff --git a/configs/vqa_med_2019/c2_classification/c2_classification_all_rnn_vgg16_relational_net.yml b/configs/vqa_med_2019/c2_classification/c2_classification_all_rnn_vgg16_relational_net.yml
new file mode 100644
index 0000000..5268ba7
--- /dev/null
+++ b/configs/vqa_med_2019/c2_classification/c2_classification_all_rnn_vgg16_relational_net.yml
@@ -0,0 +1,97 @@
+# Load config defining problems for training, validation and testing.
+default_configs: vqa_med_2019/c2_classification/default_c2_classification.yml
+
+# Training parameters:
+training:
+  problem:
+    batch_size: 64
+validation:
+  problem:
+    batch_size: 64
+
+pipeline:
+  name: c2_classification_all_rnn_vgg16_relational_net
+
+  global_publisher:
+    priority: 0
+    type: GlobalVariablePublisher
+    # Add question_encoder_output_size to globals.
+    keys: [question_encoder_output_size]
+    values: [100]
+
+  ################# PIPE 0: question #################
+  # Questions encoding.
+  question_tokenizer:
+    priority: 1.1
+    type: SentenceTokenizer
+    streams:
+      inputs: questions
+      outputs: tokenized_questions
+
+  # Model 1: Embeddings
+  question_embeddings:
+    priority: 1.2
+    type: SentenceEmbeddings
+    embeddings_size: 100
+    pretrained_embeddings_file: glove.6B.100d.txt
+    data_folder: ~/data/vqa-med
+    word_mappings_file: questions.all.word.mappings.csv
+    streams:
+      inputs: tokenized_questions
+      outputs: embedded_questions
+
+  # Model 2: RNN
+  question_lstm:
+    priority: 1.3
+    type: RecurrentNeuralNetwork
+    cell_type: LSTM
+    prediction_mode: Last
+    use_logsoftmax: False
+    initial_state_trainable: True
+    #dropout_rate: 0.5
+    hidden_size: 50
+    streams:
+      inputs: embedded_questions
+      predictions: question_activations
+    globals:
+      input_size: embeddings_size
+      prediction_size: question_encoder_output_size
+
+  ################# PIPE 2: image #################
+  # Image encoder.
+  image_encoder:
+    priority: 3.1
+    type: TorchVisionWrapper
+    return_feature_maps: True
+    frozen: True
+    freeze: True
+    streams:
+      inputs: images
+      outputs: feature_maps
+
+  ################# PIPE 3: fusion + classification #################
+  # Relational network.
+  question_image_fusion:
+    priority: 4.1
+    type: RelationalNetwork
+    dropout_rate: 0.5
+    output_size: 100
+    streams:
+      question_encodings: question_activations
+      outputs: fused_image_question_activations
+    globals:
+      question_encoding_size: question_encoder_output_size
+      output_size: fused_image_question_activation_size
+
+  classifier:
+    priority: 4.2
+    type: FeedForwardNetwork
+    hidden_sizes: [100,100]
+    dropout_rate: 0.5
+    streams:
+      inputs: fused_image_question_activations
+    globals:
+      input_size: fused_image_question_activation_size
+      prediction_size: vocabulary_size_c2
+
+#: pipeline
diff --git a/configs/vqa_med_2019/c4_classification/c4_classification_all_rnn_vgg16_ewm_size.yml b/configs/vqa_med_2019/c4_classification/c4_classification_all_rnn_vgg16_ewm_size.yml
new file mode 100644
index 0000000..a9b8266
--- /dev/null
+++ b/configs/vqa_med_2019/c4_classification/c4_classification_all_rnn_vgg16_ewm_size.yml
@@ -0,0 +1,130 @@
+# Load config defining problems for training, validation and testing.
+default_configs: vqa_med_2019/c4_classification/default_c4_classification.yml
+
+pipeline:
+  name: c4_classification_all_rnn_vgg16_ewm_size
+
+  global_publisher:
+    priority: 0
+    type: GlobalVariablePublisher
+    # Add sizes to globals.
+    keys: [question_encoder_output_size, image_encoder_output_size, element_wise_activation_size, image_size_encoder_input_size, image_size_encoder_output_size]
+    values: [100, 100, 100, 2, 10]
+
+  ################# PIPE 0: question #################
+  # Questions encoding.
+  question_tokenizer:
+    priority: 1.1
+    type: SentenceTokenizer
+    streams:
+      inputs: questions
+      outputs: tokenized_questions
+
+  # Model 1: Embeddings
+  question_embeddings:
+    priority: 1.2
+    type: SentenceEmbeddings
+    embeddings_size: 100
+    pretrained_embeddings_file: glove.6B.100d.txt
+    data_folder: ~/data/vqa-med
+    word_mappings_file: questions.all.word.mappings.csv
+    streams:
+      inputs: tokenized_questions
+      outputs: embedded_questions
+
+  # Model 2: RNN
+  question_lstm:
+    priority: 1.3
+    type: RecurrentNeuralNetwork
+    cell_type: LSTM
+    prediction_mode: Last
+    use_logsoftmax: False
+    initial_state_trainable: True
+    hidden_size: 50
+    #dropout_rate: 0.5
+    streams:
+      inputs: embedded_questions
+      predictions: question_activations
+    globals:
+      input_size: embeddings_size
+      prediction_size: question_encoder_output_size
+
+  ################# PIPE 2: image #################
+  # Image encoder.
+  image_encoder:
+    priority: 3.1
+    type: TorchVisionWrapper
+    streams:
+      inputs: images
+      outputs: image_activations
+    globals:
+      output_size: image_encoder_output_size
+
+  ################# PIPE 3: image-question fusion #################
+  # Element wise multiplication + FF.
+  question_image_fusion:
+    priority: 4.1
+    type: ElementWiseMultiplication
+    dropout_rate: 0.5
+    streams:
+      image_encodings: image_activations
+      question_encodings: question_activations
+      outputs: element_wise_activations
+    globals:
+      image_encoding_size: image_encoder_output_size
+      question_encoding_size: question_encoder_output_size
+      output_size: element_wise_activation_size
+
+  question_image_ffn:
+    priority: 4.2
+    type: FeedForwardNetwork
+    hidden_sizes: [100]
+    dropout_rate: 0.5
+    streams:
+      inputs: element_wise_activations
+      predictions: question_image_activations
+    globals:
+      input_size: element_wise_activation_size
+      prediction_size: element_wise_activation_size
+
+  ################# PIPE 4: image-question-image size fusion + classification #################
+  # 2nd subpipeline: image size.
+  # Model - image size encoder.
+  image_size_encoder:
+    priority: 5.1
+    type: FeedForwardNetwork
+    streams:
+      inputs: image_sizes
+      predictions: image_size_activations
+    globals:
+      input_size: image_size_encoder_input_size
+      prediction_size: image_size_encoder_output_size
+
+  # 4th subpipeline: concatenation + FF.
+  concat:
+    priority: 5.2
+    type: Concatenation
+    input_streams: [question_image_activations,image_size_activations]
+    # Concatenation
+    dim: 1 # default
+    input_dims: [[-1,100],[-1,10]]
+    output_dims: [-1,110]
+    streams:
+      outputs: concatenated_activations
+    globals:
+      output_size: concatenated_activations_size
+
+
+  classifier:
+    priority: 5.3
+    type: FeedForwardNetwork
+    hidden_sizes: [500]
+    dropout_rate: 0.5
+    streams:
+      inputs: concatenated_activations
+    globals:
+      input_size: concatenated_activations_size
+      prediction_size: vocabulary_size_c4
+
+
+#: pipeline
diff --git a/configs/vqa_med_2019/c4_classification/default_c4_classification.yml b/configs/vqa_med_2019/c4_classification/default_c4_classification.yml
new file mode 100644
index 0000000..e221187
--- /dev/null
+++ b/configs/vqa_med_2019/c4_classification/default_c4_classification.yml
@@ -0,0 +1,98 @@
+# Load config defining problems for training, validation and testing.
+default_configs: vqa_med_2019/default_vqa_med_2019.yml
+
+# Training parameters:
+training:
+  problem:
+    batch_size: 64
+    categories: C4
+  sampler:
+    name: WeightedRandomSampler
+    weights: ~/data/vqa-med/answers.c4.weights.csv
+  dataloader:
+    num_workers: 4
+  # Termination.
+  terminal_conditions:
+    loss_stop: 1.0e-2
+    episode_limit: 10000
+    epoch_limit: -1
+
+# Validation parameters:
+validation:
+  problem:
+    batch_size: 64
+    categories: C4
+  dataloader:
+    num_workers: 4
+
+
+pipeline:
+
+  # Answer encoding.
+  answer_indexer:
+    type: LabelIndexer
+    priority: 0.1
+    data_folder: ~/data/vqa-med
+    word_mappings_file: answers.c4.word.mappings.csv
+    # Export mappings and size to globals.
+    export_word_mappings_to_globals: True
+    streams:
+      inputs: answers
+      outputs: answers_ids
+    globals:
+      vocabulary_size: vocabulary_size_c4
+      word_mappings: word_mappings_c4
+
+
+  # Predictions decoder.
+  prediction_decoder:
+    type: WordDecoder
+    priority: 10.1
+    # Use the same word mappings as label indexer.
+    import_word_mappings_from_globals: True
+    streams:
+      inputs: predictions
+      outputs: predicted_answers
+    globals:
+      vocabulary_size: vocabulary_size_c4
+      word_mappings: word_mappings_c4
+
+  # Loss
+  nllloss:
+    type: NLLLoss
+    priority: 10.2
+    targets_dim: 1
+    streams:
+      targets: answers_ids
+      loss: loss
+
+  # Statistics.
+  batch_size:
+    type: BatchSizeStatistics
+    priority: 100.1
+
+  #accuracy:
+  #  type: AccuracyStatistics
+  #  priority: 100.2
+  #  streams:
+  #    targets: answers_ids
+
+  precision_recall:
+    type: PrecisionRecallStatistics
+    priority: 100.3
+    use_word_mappings: True
+    show_class_scores: True
+    show_confusion_matrix: True
+    streams:
+      targets: answers_ids
+    globals:
+      word_mappings: word_mappings_c4
+      num_classes: vocabulary_size_c4
+
+  # Viewers.
+  viewer:
+    type: StreamViewer
+    priority: 100.4
+    input_streams: questions,category_names,answers,predicted_answers
+
+#: pipeline
diff --git a/ptp/components/models/__init__.py b/ptp/components/models/__init__.py
index 3451d2f..8d04135 100644
--- a/ptp/components/models/__init__.py
+++ b/ptp/components/models/__init__.py
@@ -10,6 +10,7 @@
 from .vqa.element_wise_multiplication import ElementWiseMultiplication
 from .vqa.multimodal_compact_bilinear_pooling import MultimodalCompactBilinearPooling
+from .vqa.relational_network import RelationalNetwork
 
 
 __all__ = [
     'ConvNetEncoder',
@@ -23,4 +24,5 @@
     'Seq2Seq_RNN',
     'ElementWiseMultiplication',
     'MultimodalCompactBilinearPooling',
+    'RelationalNetwork',
     ]
diff --git a/ptp/components/models/vqa/relational_network.py b/ptp/components/models/vqa/relational_network.py
new file mode 100644
index 0000000..5b5763a
--- /dev/null
+++ b/ptp/components/models/vqa/relational_network.py
@@ -0,0 +1,149 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+#
+# Copyright (C) IBM Corporation 2018
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+__author__ = "Tomasz Kornuta"
+
+
+import torch
+
+from ptp.components.models.model import Model
+from ptp.data_types.data_definition import DataDefinition
+
+
+class RelationalNetwork(Model):
+    """
+    Model implementing a relational network.
+    It expects image (CNN) feature maps and an encoded question as inputs.
+
+    Santoro, A., Raposo, D., Barrett, D. G., Malinowski, M., Pascanu, R., Battaglia, P., & Lillicrap, T. (2017). A simple neural network module for relational reasoning. In Advances in neural information processing systems (pp. 4967-4976).
+    Reference paper: https://arxiv.org/abs/1706.01427.
+    """
+    def __init__(self, name, config):
+        """
+        Initializes the model, creates the required layers.
+
+        :param name: Name of the model (taken from the configuration file).
+
+        :param config: Parameters read from configuration file.
+        :type config: ``ptp.configuration.ConfigInterface``
+
+        """
+        super(RelationalNetwork, self).__init__(name, RelationalNetwork, config)
+
+        # Get key mappings.
+        self.key_feature_maps = self.stream_keys["feature_maps"]
+        self.key_question_encodings = self.stream_keys["question_encodings"]
+        self.key_outputs = self.stream_keys["outputs"]
+
+        # Retrieve input sizes from globals.
+        self.feature_maps_height = self.globals["feature_maps_height"]
+        self.feature_maps_width = self.globals["feature_maps_width"]
+        self.feature_maps_depth = self.globals["feature_maps_depth"]
+        self.question_encoding_size = self.globals["question_encoding_size"]
+
+        # Create "object" coordinates.
+        self.obj_coords = []
+        for h in range(self.feature_maps_height):
+            for w in range(self.feature_maps_width):
+                self.obj_coords.append((h,w))
+
+        # Get output_size from config and send it to globals.
+        self.output_size = self.config["output_size"]
+        self.globals["output_size"] = self.output_size
+
+        # Calculate the input size of g_theta: two "objects" + question (+ optionally: image size).
+        input_size = 2 * self.feature_maps_depth + self.question_encoding_size
+
+        # Retrieve dropout rate value - if set, will put dropout between every layer.
+        dropout_rate = self.config["dropout_rate"]
+
+        # Create the model, i.e. the "relational" g_theta MLP.
+        self.g_theta = torch.nn.Sequential(
+            torch.nn.Linear(input_size, self.output_size),
+            # Create activation layer.
+            torch.nn.ReLU(),
+            # Create dropout layer.
+            torch.nn.Dropout(dropout_rate),
+            torch.nn.Linear(self.output_size, self.output_size),
+            torch.nn.ReLU(),
+            torch.nn.Dropout(dropout_rate),
+            torch.nn.Linear(self.output_size, self.output_size),
+            torch.nn.ReLU(),
+            torch.nn.Dropout(dropout_rate),
+            torch.nn.Linear(self.output_size, self.output_size)
+        )
+
+
+    def input_data_definitions(self):
+        """
+        Function returns a dictionary with definitions of input data that are required by the component.
+
+        :return: dictionary containing input data definitions (each of type :py:class:`ptp.utils.DataDefinition`).
+        """
+        return {
+            self.key_feature_maps: DataDefinition([-1, self.feature_maps_depth, self.feature_maps_height, self.feature_maps_width], [torch.Tensor], "Batch of feature maps [BATCH_SIZE x FEAT_DEPTH x FEAT_HEIGHT x FEAT_WIDTH]"),
+            self.key_question_encodings: DataDefinition([-1, self.question_encoding_size], [torch.Tensor], "Batch of encoded questions [BATCH_SIZE x QUESTION_ENCODING_SIZE]"),
+        }
+
+
+    def output_data_definitions(self):
+        """
+        Function returns a dictionary with definitions of output data produced by the component.
+
+        :return: dictionary containing output data definitions (each of type :py:class:`ptp.utils.DataDefinition`).
+        """
+        return {
+            self.key_outputs: DataDefinition([-1, self.output_size], [torch.Tensor], "Batch of outputs [BATCH_SIZE x OUTPUT_SIZE]")
+        }
+
+    def forward(self, data_dict):
+        """
+        Main forward pass of the model.
+
+        :param data_dict: DataDict({'feature_maps', 'question_encodings', 'outputs'})
+        :type data_dict: ``ptp.data_types.DataDict``
+        """
+
+        # Unpack DataDict.
+        feat_m = data_dict[self.key_feature_maps]
+        enc_q = data_dict[self.key_question_encodings]
+
+        summed_relations = None
+        # Iterate through all pairs of "objects".
+        for (h1,w1) in self.obj_coords:
+            for (h2,w2) in self.obj_coords:
+                # Get feature maps.
+                fm1 = feat_m[:, :, h1,w1].view(-1, self.feature_maps_depth)
+                fm2 = feat_m[:, :, h2,w2].view(-1, self.feature_maps_depth)
+                # Concatenate with question.
+                concat = torch.cat([fm1, fm2, enc_q], dim=1)
+
+                # Pass it through g_theta.
+                rel = self.g_theta(concat)
+
+                # Add to relations.
+                if summed_relations is None:
+                    summed_relations = rel
+                else:
+                    # Element wise sum.
+                    summed_relations += rel
+
+        # Add outputs to datadict.
+        data_dict.extend({self.key_outputs: summed_relations})
diff --git a/ptp/workers/online_trainer.py b/ptp/workers/online_trainer.py
index 641a3c6..ec33760 100644
--- a/ptp/workers/online_trainer.py
+++ b/ptp/workers/online_trainer.py
@@ -334,6 +334,7 @@ def run_experiment(self):
         # Finalize statistics collection.
         self.finalize_statistics_collection()
         self.finalize_tensorboard()
+        self.logger.info("Experiment logged to: {}".format(self.log_dir))
 
 
 def main():
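For readers who want to sanity-check the new RelationalNetwork component outside of the PTP pipeline, the standalone sketch below reproduces its core computation: the sum of g_theta over all pairs of feature-map "objects", with each pair concatenated with the question encoding. The tensor sizes (7x7x512 feature maps, 100-d question encoding, 256-d output) are illustrative assumptions only, not values mandated by this patch, and dropout is omitted for brevity.

import torch

# Assumed sizes (illustrative only).
B, D, H, W = 2, 512, 7, 7      # batch, feature map depth, height, width
Q, OUT = 100, 256              # question encoding size, g_theta output size

# g_theta MLP, mirroring the four Linear layers built in relational_network.py (dropout omitted).
g_theta = torch.nn.Sequential(
    torch.nn.Linear(2 * D + Q, OUT), torch.nn.ReLU(),
    torch.nn.Linear(OUT, OUT), torch.nn.ReLU(),
    torch.nn.Linear(OUT, OUT), torch.nn.ReLU(),
    torch.nn.Linear(OUT, OUT),
)

feature_maps = torch.randn(B, D, H, W)   # [BATCH_SIZE x FEAT_DEPTH x FEAT_HEIGHT x FEAT_WIDTH]
question_enc = torch.randn(B, Q)         # [BATCH_SIZE x QUESTION_ENCODING_SIZE]

# Treat every spatial location as an "object": [BATCH_SIZE x H*W x FEAT_DEPTH].
objects = feature_maps.view(B, D, H * W).transpose(1, 2)

# Sum g_theta over all ordered pairs of objects, each concatenated with the question.
summed_relations = None
for i in range(H * W):
    for j in range(H * W):
        pair = torch.cat([objects[:, i], objects[:, j], question_enc], dim=1)
        rel = g_theta(pair)
        summed_relations = rel if summed_relations is None else summed_relations + rel

print(summed_relations.shape)  # torch.Size([2, 256])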