7 changes: 5 additions & 2 deletions configs/default/components/masking/string_to_mask.yml
@@ -4,9 +4,12 @@
# 1. CONFIGURATION PARAMETERS that will be LOADED by the component.
####################################################################

-# Value that will be used when word is out of vocavbulary (LOADED)
+# Value that will be used when word is out of vocabulary (LOADED)
 # (Mask for that element will be 0 as well)
-out_of_vocabulary_value: -1
+# -100 is the default value used by PyTorch loss functions to specify
+# target values that will be ignored and do not contribute to the input gradient.
+# (ignore_index=-100)
+out_of_vocabulary_value: -100

streams:
####################################################################
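For reference, -100 is indeed the default ignore_index of PyTorch classification losses such as torch.nn.NLLLoss and torch.nn.CrossEntropyLoss. A minimal sketch of the behaviour the new comment describes (the log-probabilities and targets below are made-up placeholders):

import torch
import torch.nn as nn

# Two samples, five classes; the second target is -100 (out of vocabulary).
log_probs = torch.log_softmax(torch.randn(2, 5), dim=1)
targets = torch.tensor([3, -100])

# NLLLoss skips targets equal to ignore_index (default -100): the second
# sample contributes neither to the loss nor to the gradient.
loss = nn.NLLLoss(ignore_index=-100)(log_probs, targets)
print(loss)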
7 changes: 5 additions & 2 deletions configs/default/components/text/label_indexer.yml
@@ -25,9 +25,12 @@ import_word_mappings_from_globals: False
# Flag informing whether word mappings will be exported to globals (LOADED)
export_word_mappings_to_globals: False

-# Value that will be used when word is out of vocavbulary (LOADED)
+# Value that will be used when word is out of vocabulary (LOADED)
 # (Mask for that element will be 0 as well)
-out_of_vocabulary_value: -1
+# -100 is the default value used by PyTorch loss functions to specify
+# target values that will be ignored and do not contribute to the input gradient.
+# (ignore_index=-100)
+out_of_vocabulary_value: -100

streams:
####################################################################
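A rough illustration of the out-of-vocabulary handling described above, in plain Python rather than the actual LabelIndexer code (the word mappings and answers are invented for the example):

import torch

OOV = -100  # out_of_vocabulary_value from the config

# Hypothetical word mappings, e.g. as loaded from a word.mappings.csv file.
word_mappings = {"yes": 0, "no": 1, "axial": 2}

answers = ["yes", "flair", "axial"]            # "flair" is out of vocabulary
ids = torch.tensor([word_mappings.get(a, OOV) for a in answers])
mask = (ids != OOV).long()                     # mask is 0 for OOV elements

print(ids)    # tensor([   0, -100,    2])
print(mask)   # tensor([1, 0, 1])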
@@ -0,0 +1,101 @@
# Load config defining problems for training, validation and testing.
default_configs: vqa_med_2019/c2_classification/default_c2_classification.yml

pipeline:
name: vqa_med_c2_classification_all_rnn_vgg_concat

global_publisher:
type: GlobalVariablePublisher
priority: 0
# Add input_size to globals.
keys: [question_embeddings_output_size, image_size_encoder_input_size, image_size_encoder_output_size, image_encoder_output_size]
values: [100, 2, 10, 100]

# First subpipeline: question.
# Questions encoding.
question_tokenizer:
type: SentenceTokenizer
priority: 1.1
streams:
inputs: questions
outputs: tokenized_questions

# Model 1: Embeddings
question_embeddings:
type: SentenceEmbeddings
priority: 1.2
embeddings_size: 50
pretrained_embeddings_file: glove.6B.50d.txt
data_folder: ~/data/vqa-med
word_mappings_file: questions.all.word.mappings.csv
streams:
inputs: tokenized_questions
outputs: embedded_questions

# Model 2: RNN
question_lstm:
type: RecurrentNeuralNetwork
cell_type: LSTM
prediction_mode: Last
priority: 1.3
use_logsoftmax: False
initial_state_trainable: False
#num_layers: 5
hidden_size: 50
streams:
inputs: embedded_questions
predictions: question_activations
globals:
input_size: embeddings_size
prediction_size: question_embeddings_output_size

# 2nd subpipeline: image size.
# Model - image size classifier.
image_size_encoder:
type: FeedForwardNetwork
priority: 2.1
streams:
inputs: image_sizes
predictions: image_size_activations
globals:
input_size: image_size_encoder_input_size
prediction_size: image_size_encoder_output_size

# 3rd subpipeline: image.
# Image encoder.
image_encoder:
type: TorchVisionWrapper
priority: 3.1
streams:
inputs: images
predictions: image_activations
globals:
prediction_size: image_encoder_output_size

# 4th subpipeline: concatenation + FF.
concat:
type: Concatenation
priority: 4.1
input_streams: [question_activations,image_size_activations,image_activations]
# Concatenation
dim: 1 # default
input_dims: [[-1,100],[-1,10],[-1,100]]
output_dims: [-1,210]
streams:
outputs: concatenated_activations
globals:
output_size: output_size


classifier:
type: FeedForwardNetwork
hidden_sizes: [100]
priority: 4.2
streams:
inputs: concatenated_activations
globals:
input_size: output_size
prediction_size: vocabulary_size_c2


#: pipeline
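The sizes published by global_publisher explain the Concatenation dims: 100 (question LSTM output) + 10 (image size encoder output) + 100 (image encoder output) = 210. A minimal PyTorch sketch of that fusion step, with random tensors standing in for the real activations:

import torch

batch = 4
question_activations   = torch.randn(batch, 100)  # question_embeddings_output_size
image_size_activations = torch.randn(batch, 10)   # image_size_encoder_output_size
image_activations      = torch.randn(batch, 100)  # image_encoder_output_size

# Concatenation along dim 1: input_dims [[-1,100],[-1,10],[-1,100]] -> output_dims [-1,210]
concatenated_activations = torch.cat(
    [question_activations, image_size_activations, image_activations], dim=1)

print(concatenated_activations.shape)  # torch.Size([4, 210])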
@@ -0,0 +1,91 @@
# Load config defining problems for training, validation and testing.
default_configs: vqa_med_2019/default_vqa_med_2019.yml

# Training parameters:
training:
problem:
categories: C2
sampler:
name: WeightedRandomSampler
weights: ~/data/vqa-med/answers.c2.weights.csv
dataloader:
num_workers: 4

# Validation parameters:
validation:
problem:
categories: C2
dataloader:
num_workers: 4


pipeline:

# Answer encoding.
answer_indexer:
type: LabelIndexer
priority: 0.1
data_folder: ~/data/vqa-med
word_mappings_file: answers.c2.word.mappings.csv
# Export mappings and size to globals.
export_word_mappings_to_globals: True
streams:
inputs: answers
outputs: answers_ids
globals:
vocabulary_size: vocabulary_size_c2
word_mappings: word_mappings_c2


# Predictions decoder.
prediction_decoder:
type: WordDecoder
priority: 10.1
# Use the same word mappings as label indexer.
import_word_mappings_from_globals: True
streams:
inputs: predictions
outputs: predicted_answers
globals:
vocabulary_size: vocabulary_size_c2
word_mappings: word_mappings_c2

# Loss
nllloss:
type: NLLLoss
priority: 10.2
targets_dim: 1
streams:
targets: answers_ids
loss: loss

# Statistics.
batch_size:
type: BatchSizeStatistics
priority: 100.1

#accuracy:
# type: AccuracyStatistics
# priority: 100.2
# streams:
# targets: answers_ids

precision_recall:
type: PrecisionRecallStatistics
priority: 100.3
use_word_mappings: True
show_class_scores: True
show_confusion_matrix: True
streams:
targets: answers_ids
globals:
word_mappings: word_mappings_c2
num_classes: vocabulary_size_c2

# Viewers.
viewer:
type: StreamViewer
priority: 100.4
input_streams: questions,category_names,answers,predicted_answers

#: pipeline
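The WeightedRandomSampler configured in the training section presumably wraps torch.utils.data.WeightedRandomSampler, drawing samples with probability proportional to the per-sample weights loaded from answers.c2.weights.csv, which over-samples rare answer classes. A small sketch with invented weights and a dummy dataset:

import torch
from torch.utils.data import WeightedRandomSampler, DataLoader, TensorDataset

# Hypothetical per-sample weights; in the config they come from a weights CSV.
weights = torch.tensor([0.1, 0.1, 0.4, 0.4])
dataset = TensorDataset(torch.arange(4))

# Indices are drawn with replacement, proportionally to their weight.
sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)
loader = DataLoader(dataset, batch_size=2, sampler=sampler)

for batch in loader:
    print(batch)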
@@ -394,7 +394,7 @@ pipeline:
# Viewers.
viewer:
type: StreamViewer
-priority: 4.3
+priority: 7.3
input_streams: questions,answers, category_names,predicted_question_categories_names, pipe5_c1_masks,pipe5_c1_answers_without_yn_ids,pipe5_c1_predictions, pipe6_binary_masks,pipe6_binary_answers_ids,pipe6_binary_predictions, pipe7_merged_predictions

