Skip to content
This repository was archived by the owner on Jul 18, 2024. It is now read-only.

Commit a251630

Browse files
authored
Merge pull request #20 from IBM/c4_pipelines
vqa_med: question augmentations: random word shuffling and random word removal
2 parents 7c97fe2 + c9327d2 commit a251630

File tree

3 files changed

+141
-7
lines changed

3 files changed

+141
-7
lines changed
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
# Load config defining problems for training, validation and testing.
default_configs: vqa_med_2019/c2_classification/default_c2_classification.yml

# Training parameters:
training:
  problem:
    batch_size: 128

# Validation parameters:
validation:
  problem:
    batch_size: 128

pipeline:
  name: c2_word_answer_onehot_bow

  # Answer encoding: tokenize the raw answer strings into word lists.
  answer_tokenizer:
    type: SentenceTokenizer
    priority: 1.1
    preprocessing: lowercase,remove_punctuation
    # NOTE(review): the original file contained mis-encoded "smart" quotes
    # ([“,”,’]); straight quotes are required for valid YAML strings.
    remove_characters: [",","'"]
    streams:
      inputs: answers
      outputs: tokenized_answer_words

  # Map each answer word to a one-hot index using a precomputed vocabulary.
  answer_onehot_encoder:
    type: SentenceOneHotEncoder
    priority: 1.2
    data_folder: ~/data/vqa-med
    word_mappings_file: answer_words.c2.preprocessed.word.mappings.csv
    export_word_mappings_to_globals: True
    streams:
      inputs: tokenized_answer_words
      outputs: encoded_answer_words
    globals:
      vocabulary_size: answer_words_vocabulary_size
      word_mappings: answer_words_word_mappings

  # Aggregate the encoded words into a bag-of-words vector.
  answer_bow_encoder:
    type: BOWEncoder
    priority: 1.3
    streams:
      inputs: encoded_answer_words
      outputs: bow_answer_words
    globals:
      bow_size: answer_words_vocabulary_size

  # Model.
  classifier:
    type: FeedForwardNetwork
    hidden_sizes: [500, 500]
    dropout_rate: 0.5
    priority: 3
    streams:
      inputs: bow_answer_words
    globals:
      input_size: answer_words_vocabulary_size
      prediction_size: vocabulary_size_c2

  # Viewers.
  viewer:
    type: StreamViewer
    priority: 100.4
    input_streams: answers, tokenized_answer_words, predicted_answers

#: pipeline

configs/vqa_med_2019/c4_classification/c4_word_answer_onehot_bow.yml

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,11 @@ default_configs: vqa_med_2019/c4_classification/default_c4_classification.yml
55
training:
66
problem:
77
batch_size: 128
8-
remove_punctuation: all
98

109
# Validation parameters:
1110
validation:
1211
problem:
1312
batch_size: 128
14-
remove_punctuation: all
1513

1614
pipeline:
1715
name: c4_word_answer_onehot_bow
@@ -51,13 +49,19 @@ pipeline:
5149
# Model.
5250
classifier:
5351
type: FeedForwardNetwork
54-
hidden_sizes: [500]
52+
hidden_sizes: [500, 500]
5553
dropout_rate: 0.5
5654
priority: 3
5755
streams:
5856
inputs: bow_answer_words
5957
globals:
6058
input_size: answer_words_vocabulary_size
6159
prediction_size: vocabulary_size_c4
62-
60+
61+
# Viewers.
62+
viewer:
63+
type: StreamViewer
64+
priority: 100.4
65+
input_streams: answers, tokenized_answer_words, predicted_answers
66+
6367
#: pipeline

ptp/components/problems/image_text_to_class/vqa_med_2019.py

Lines changed: 66 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,10 @@
2020
import os
2121
import string
2222
import tqdm
23+
2324
import pandas as pd
2425
from PIL import Image
25-
26+
import numpy as np
2627
import nltk
2728

2829
import torch
@@ -307,6 +308,62 @@ def preprocess_text(self, text, lowercase = False, remove_punctuation = False, t
307308
# Return cleaned text.
308309
return cleansed_words
309310

311+
def random_remove_stop_words(self, words):
    """
    Randomly removes stop words from tokenized text, each stop word being
    dropped independently with probability 0.5.

    :param words: tokenized text (list of words).
    :return: resulting tokenized text (list of words).
    """
    # English stop-word set. NOTE(review): rebuilt on every call — could be
    # cached at module/class level if this shows up in profiling.
    stops = set(nltk.corpus.stopwords.words("english"))

    # Flag which positions hold stop words.
    stop_words = [word in stops for word in words]

    # No stop words present: return the input unchanged. (Fixes the original
    # code path in which `result` was only built inside the
    # `if sum(stop_words) > 0:` branch.)
    if not any(stop_words):
        return words

    # One independent coin flip (p=0.5) per token.
    remove_probs = np.random.binomial(1, 0.5, len(words))

    result = []
    for word, is_stop, rem_prob in zip(words, stop_words, remove_probs):
        if is_stop and rem_prob:
            # Drop this stop word.
            continue
        # Else: keep the word.
        result.append(word)

    return result
338+
339+
340+
def random_shuffle_words(self, words):
341+
"""
342+
Function randomly shuffles, with probability of 0.5, two consecutive words in text.
343+
344+
:param words: tokenized text
345+
:return: resulting tokenized text.
346+
"""
347+
# Do not shuffle if there are less than 2 words.
348+
if len(words) < 2:
349+
return words
350+
# Shuffle with probability of 0.5.
351+
if np.random.binomial(1, 0.5, 1):
352+
return words
353+
354+
# Find words to shuffle - random without replacement.
355+
shuffled_i = np.random.choice(len(words)-1, )
356+
indices = [i for i in range(len(words))]
357+
indices[shuffled_i] = shuffled_i+1
358+
indices[shuffled_i+1] = shuffled_i
359+
#print(indices)
360+
361+
# Create resulting table.
362+
result = [words[indices[i]] for i in range(len(words))]
363+
364+
return result
365+
366+
310367
def load_dataset(self, source_files, source_categories):
311368
"""
312369
Loads the dataset from one or more files.
@@ -368,7 +425,6 @@ def load_dataset(self, source_files, source_categories):
368425
return dataset
369426

370427

371-
372428
def __getitem__(self, index):
373429
"""
374430
Getter method to access the dataset and return a single sample.
@@ -424,7 +480,14 @@ def __getitem__(self, index):
424480

425481
# Apply question transformations.
426482
preprocessed_question = item[self.key_questions]
427-
# TODO: apply additional random transformations e.g. "shuffle_words"
483+
if 'tokenize' in self.question_preprocessing:
484+
# Apply them only if text is tokenized.
485+
if 'random_remove_stop_words' in self.question_preprocessing:
486+
preprocessed_question = self.random_remove_stop_words(preprocessed_question)
487+
488+
if 'random_shuffle_words' in self.question_preprocessing:
489+
preprocessed_question = self.random_shuffle_words(preprocessed_question)
490+
# Return question.
428491
data_dict[self.key_questions] = preprocessed_question
429492

430493
# Return answer.

0 commit comments

Comments
 (0)