Skip to content
This repository was archived by the owner on Jul 18, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion configs/translation/eng_fra_translation_enc_attndec.yml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ pipeline:
source_vocabulary_files: eng-fra/eng.train.txt,eng-fra/eng.valid.txt,eng-fra/eng.test.txt
vocabulary_mappings_file: eng-fra/eng.all.tokenized_words
regenerate: True
additional_tokens: <eos>
additional_tokens: <PAD>,<EOS>
import_word_mappings_from_globals: False
export_word_mappings_to_globals: False
fixed_padding: 10
Expand All @@ -73,11 +73,15 @@ pipeline:
source_vocabulary_files: eng-fra/fra.train.txt,eng-fra/fra.valid.txt,eng-fra/fra.test.txt
import_word_mappings_from_globals: False
export_word_mappings_to_globals: True
export_pad_mapping_to_globals: True
eos_token: True
fixed_padding: 10
additional_tokens: <PAD>,<EOS>
regenerate: True
streams:
inputs: targets
outputs: indexed_targets
pad_index: tgt_pad_index

# Single layer GRU Encoder
encoder:
Expand Down Expand Up @@ -135,6 +139,8 @@ pipeline:
streams:
targets: indexed_targets
loss: loss
globals:
ignore_index: tgt_pad_index

# Prediction decoding.
prediction_decoder:
Expand All @@ -159,6 +165,7 @@ pipeline:
bleu:
type: BLEUStatistics
priority: 100.2
ignored_words: ["<PAD>", "<EOS>"]
streams:
targets: indexed_targets

Expand Down
8 changes: 5 additions & 3 deletions configs/vqa_med_2019/c4_classification/c4_enc_attndec.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@ training:
problem:
batch_size: 64
categories: C4
question_preprocessing: lowercase, remove_punctuation, tokenize
question_preprocessing: lowercase, remove_punctuation, tokenize, random_remove_stop_words #,random_shuffle_words
answer_preprocessing: lowercase, remove_punctuation, tokenize
export_sample_weights: ~/data/vqa-med/answers.c4.weights.csv
sampler:
weights: ~/data/vqa-med/answers.c4.weights.csv
dataloader:
num_workers: 8
num_workers: 2
# Termination.
terminal_conditions:
loss_stop: 1.0e-2
Expand All @@ -27,7 +27,7 @@ validation:
question_preprocessing: lowercase, remove_punctuation, tokenize
answer_preprocessing: lowercase, remove_punctuation, tokenize
dataloader:
num_workers: 8
num_workers: 2

pipeline:
name: c4_dec_attndecoder
Expand All @@ -54,6 +54,8 @@ pipeline:
word_mappings_file: answer_words.c4.preprocessed.word.mappings.csv
import_word_mappings_from_globals: False
export_word_mappings_to_globals: True
export_pad_mapping_to_globals: True
eos_token: True
fixed_padding: 10
streams:
inputs: answers
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
# Load config defining problems for training, validation and testing.
default_configs: vqa_med_2019/default_vqa_med_2019.yml

# Training parameters:
training:
problem:
batch_size: 64
categories: C4
question_preprocessing: lowercase, remove_punctuation, tokenize #, random_remove_stop_words #,random_shuffle_words
answer_preprocessing: lowercase, remove_punctuation, tokenize
export_sample_weights: ~/data/vqa-med/answers.c4.weights.csv
batch_size: 32
sampler:
weights: ~/data/vqa-med/answers.c4.weights.csv
dataloader:
num_workers: 4
# Termination.
terminal_conditions:
loss_stop: 1.0e-2
episode_limit: 1000000
epoch_limit: -1

# Validation parameters:
validation:
problem:
batch_size: 64
categories: C4
question_preprocessing: lowercase, remove_punctuation, tokenize
answer_preprocessing: lowercase, remove_punctuation, tokenize
batch_size: 32
dataloader:
num_workers: 4

pipeline:
name: c4_enc_attndec_resnet152_ewm_cat_is

global_publisher:
priority: 0
type: GlobalVariablePublisher
# Add input_size to globals.
keys: [question_encoder_output_size, image_encoder_output_size, element_wise_activation_size,image_size_encoder_input_size, image_size_encoder_output_size]
values: [100, 100, 100, 2, 10]

# Question embeddings
question_embeddings:
priority: 1.0
type: SentenceEmbeddings
embeddings_size: 100
pretrained_embeddings_file: glove.6B.100d.txt
data_folder: ~/data/vqa-med
word_mappings_file: questions.all.word.mappings.csv
fixed_padding: 10
additional_tokens: <PAD>,<EOS>
streams:
inputs: questions
outputs: embedded_questions

# Target encoding.
target_indexer:
type: SentenceIndexer
priority: 1.1
data_folder: ~/data/vqa-med
word_mappings_file: answer_words.c4.preprocessed.word.mappings.csv
import_word_mappings_from_globals: False
export_word_mappings_to_globals: True
export_pad_mapping_to_globals: True
additional_tokens: <PAD>,<EOS>
eos_token: True
fixed_padding: 10
streams:
inputs: answers
outputs: indexed_answers
globals:
vocabulary_size: ans_vocabulary_size
word_mappings: ans_word_mappings
pad_index: ans_pad_index

# Image encoder.
image_encoder:
priority: 2.0
type: TorchVisionWrapper
model_type: resnet152
streams:
inputs: images
outputs: image_activations
globals:
output_size: image_encoder_output_size

# Single layer GRU Encoder
encoder:
type: RecurrentNeuralNetwork
cell_type: GRU
priority: 3
initial_state: Trainable
hidden_size: 100
num_layers: 1
use_logsoftmax: False
output_last_state: True
prediction_mode: Dense
ffn_output: False
dropout_rate: 0.1
streams:
inputs: embedded_questions
predictions: s2s_encoder_output
output_state: s2s_state_output
globals:
input_size: embeddings_size
prediction_size: question_encoder_output_size

reshaper_1:
priority: 3.01
type: ReshapeTensor
input_dims: [1, -1, 100]
output_dims: [-1, 100]
streams:
inputs: s2s_state_output
outputs: s2s_state_output_reshaped
globals:
output_size: s2s_state_output_reshaped_size

# Element wise multiplication + FF.
question_image_fusion:
priority: 3.1
type: ElementWiseMultiplication
dropout_rate: 0.5
streams:
image_encodings: image_activations
question_encodings: s2s_state_output_reshaped
outputs: element_wise_activations
globals:
image_encoding_size: image_encoder_output_size
question_encoding_size: question_encoder_output_size
output_size: element_wise_activation_size

question_image_ffn:
priority: 3.2
type: FeedForwardNetwork
hidden_sizes: [100]
dropout_rate: 0.5
streams:
inputs: element_wise_activations
predictions: question_image_activations
globals:
input_size: element_wise_activation_size
prediction_size: element_wise_activation_size

reshaper_2:
priority: 3.3
type: ReshapeTensor
input_dims: [-1, 100]
output_dims: [1, -1, 100]
streams:
inputs: question_image_activations
outputs: question_image_activations_reshaped
globals:
output_size: question_image_activations_reshaped_size

# Single layer GRU Decoder with attention
decoder:
type: Attn_Decoder_RNN
priority: 4
hidden_size: 100
use_logsoftmax: False
autoregression_length: 10
prediction_mode: Dense
dropout_rate: 0.1
streams:
inputs: s2s_encoder_output
predictions: s2s_decoder_output
input_state: question_image_activations_reshaped
globals:
input_size: element_wise_activation_size
prediction_size: element_wise_activation_size

# FF, to resize the from the output size of the seq2seq to the size of the target vector
ff_resize_s2s_output:
type: FeedForwardNetwork
use_logsoftmax: True
dimensions: 3
priority: 5
dropout_rate: 0.1
streams:
inputs: s2s_decoder_output
globals:
input_size: element_wise_activation_size
prediction_size: ans_vocabulary_size

# Loss
nllloss:
type: NLLLoss
priority: 6
num_targets_dims: 2
streams:
targets: indexed_answers
loss: loss
globals:
ignore_index: ans_pad_index

# Prediction decoding.
prediction_decoder:
type: SentenceIndexer
priority: 10
# Reverse mode.
reverse: True
# Use distributions as inputs.
use_input_distributions: True
data_folder: ~/data/vqa-med
import_word_mappings_from_globals: True
globals:
word_mappings: ans_word_mappings
streams:
inputs: predictions
outputs: prediction_sentences

# Statistics.
batch_size:
type: BatchSizeStatistics
priority: 100.0

bleu:
type: BLEUStatistics
priority: 100.2
globals:
word_mappings: ans_word_mappings
streams:
targets: indexed_answers


# Viewers.
viewer:
type: StreamViewer
priority: 100.3
input_streams: questions,answers,indexed_answers,prediction_sentences

#: pipeline