diff --git a/configs/default/components/problems/image_text_to_class/gqa.yml b/configs/default/components/problems/image_text_to_class/gqa.yml
new file mode 100644
index 0000000..e1c662f
--- /dev/null
+++ b/configs/default/components/problems/image_text_to_class/gqa.yml
@@ -0,0 +1,77 @@
+# This file defines the default values for the GQA problem.
+
+####################################################################
+# 1. CONFIGURATION PARAMETERS that will be LOADED by the component.
+####################################################################
+
+# Folder where the problem will store data (LOADED)
+data_folder: '~/data/gqa'
+
+# Defines the set (split) that will be used (LOADED)
+# Options: training_0 | training | validation | test_dev | test | challenge | submission (?)
+# Note: test_dev should be used for validation.
+split: training_0
+
+# Flag indicating whether the problem will load and return images (LOADED)
+stream_images: True
+
+# Resize parameter (LOADED)
+# When present, resizes the images from their original size to [height, width]
+# Depth remains set to 3.
+resize_image: [224, 224]
+
+# Select applied image preprocessing/augmentations (LOADED)
+# Use one (or more) of the following transformations:
+# none | normalize | all
+# Accepted formats: a,b,c or [a,b,c]
+image_preprocessing: none
+
+streams:
+  ####################################################################
+  # 2. Keymappings associated with INPUT and OUTPUT streams.
+  ####################################################################
+
+  # Stream containing batch of indices (OUTPUT)
+  # Every problem MUST return that stream.
+  indices: indices
+
+  # Stream containing batch of sample (original) identifiers (OUTPUT)
+  sample_ids: sample_ids
+
+  # Stream containing batch of images (OUTPUT)
+  images: images
+
+  # Stream containing batch of image names (OUTPUT)
+  image_ids: image_ids
+
+  # Stream containing batch of questions (OUTPUT)
+  questions: questions
+
+  # Stream containing batch of target answers (labels) (OUTPUT)
+  answers: answers
+
+  # Stream containing batch of target full answers, each consisting of many words (OUTPUT)
+  full_answers: full_answers
+
+  # Stream containing scene descriptions (OUTPUT)
+  #scene_graphs: scene_graphs
+
+globals:
+  ####################################################################
+  # 3. Keymappings of variables that will be RETRIEVED from GLOBALS.
+  ####################################################################
+
+  ####################################################################
+  # 4. Keymappings associated with GLOBAL variables that will be SET.
+  ####################################################################
+
+  # Width of the image (SET)
+  input_width: image_width
+  # Height of the image (SET)
+  input_height: image_height
+  # Depth of the image (SET)
+  input_depth: image_depth
+
+  ####################################################################
+  # 5. Keymappings associated with statistics that will be ADDED.
+  ####################################################################
diff --git a/ptp/components/problems/image_text_to_class/clevr.py b/ptp/components/problems/image_text_to_class/clevr.py
index afc5fe3..ff18441 100644
--- a/ptp/components/problems/image_text_to_class/clevr.py
+++ b/ptp/components/problems/image_text_to_class/clevr.py
@@ -71,10 +71,6 @@ def __init__(self, name, config):
         # Call constructors of parent classes.
         Problem.__init__(self, name, CLEVR, config)
 
-        # (Eventually) download required packages.
-        #nltk.download('punkt')
-        #nltk.download('stopwords')
-
         # Get key mappings of all output streams.
         self.key_images = self.stream_keys["images"]
         self.key_image_ids = self.stream_keys["image_ids"]
@@ -246,7 +242,7 @@ def output_data_definitions(self):
         d[self.key_questions] = DataDefinition([-1, 1], [list, str], "Batch of questions, each being a string consisting of many words [BATCH_SIZE] x [STRING]")
 
         # Add stream with answers.
-        d[self.key_answers]= DataDefinition([-1, 1], [list, str], "Batch of target answers, each being a string consisting of many words [BATCH_SIZE] x [STRING]")
+        d[self.key_answers] = DataDefinition([-1, 1], [list, str], "Batch of target answers, each being a string consisting of a single word (label) [BATCH_SIZE] x [STRING]")
 
         return d
 
@@ -271,7 +267,7 @@ def load_dataset(self, source_data_file):
         dataset = []
 
         with open(source_data_file) as f:
-            self.logger.info('Loading samples from {} ...'.format(source_data_file))
+            self.logger.info("Loading samples from '{}'...".format(source_data_file))
             dataset = json.load(f)['questions']
         self.logger.info("Loaded dataset consisting of {} samples".format(len(dataset)))
 
diff --git a/ptp/components/problems/image_text_to_class/gqa.py b/ptp/components/problems/image_text_to_class/gqa.py
new file mode 100644
index 0000000..0f2da3c
--- /dev/null
+++ b/ptp/components/problems/image_text_to_class/gqa.py
@@ -0,0 +1,337 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+#
+# Copyright (C) IBM Corporation 2019
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+__author__ = "Tomasz Kornuta"
+
+import os
+import json
+import tqdm
+from PIL import Image
+
+import torch
+from torchvision import transforms
+
+from ptp.components.problems.problem import Problem
+from ptp.data_types.data_definition import DataDefinition
+
+from ptp.configuration.config_parsing import get_value_from_dictionary, get_value_list_from_dictionary
+from ptp.configuration.configuration_error import ConfigurationError
+
+class GQA(Problem):
+    """
+    Problem providing data associated with the GQA dataset (Question Answering on Image Scene Graphs).
+
+    The dataset consists of 22M questions about various day-to-day images.
+    Each image is associated with a scene graph of the image's objects, attributes and relations.
+    Each question is associated with a structured representation of its semantics, a functional program
+    that specifies the reasoning steps that have to be taken to answer it.
+
+    For more details please refer to the associated _website or _paper.
+    The test set with answers can be downloaded from a separate repository _repo.
+
+    .. _website: https://cs.stanford.edu/people/dorarad/gqa/index.html
+
+    .. _paper: http://openaccess.thecvf.com/content_CVPR_2019/html/Hudson_GQA_A_New_Dataset_for_Real-World_Visual_Reasoning_and_Compositional_CVPR_2019_paper.html
+
+    """
+    def __init__(self, name, config):
+        """
+        Initializes the problem object. Calls the base constructor and loads the question files adequate for the selected split.
+
+        :param name: Name of the component.
+
+        :param config: Dictionary of parameters (read from configuration ``.yaml`` file).
+        """
+        # Call constructors of parent classes.
+        Problem.__init__(self, name, GQA, config)
+
+        # Get key mappings of all output streams.
+        self.key_sample_ids = self.stream_keys["sample_ids"]
+        self.key_images = self.stream_keys["images"]
+        self.key_image_ids = self.stream_keys["image_ids"]
+        self.key_questions = self.stream_keys["questions"]
+        self.key_answers = self.stream_keys["answers"]
+        self.key_full_answers = self.stream_keys["full_answers"]
+
+        # Get flag informing whether we want to stream images or not.
+        self.stream_images = self.config['stream_images']
+
+        # Check the resize image option.
+        if len(self.config['resize_image']) != 2:
+            self.logger.error("'resize_image' field must contain 2 values: the desired height and width")
+            exit(-1)
+
+        # Output image dimensions.
+        self.height = self.config['resize_image'][0]
+        self.width = self.config['resize_image'][1]
+        self.depth = 3
+        self.logger.info("Setting image size to [D x H x W]: {} x {} x {}".format(self.depth, self.height, self.width))
+
+        # Set global variables - all dimensions ASIDE OF BATCH.
+        self.globals["image_height"] = self.height
+        self.globals["image_width"] = self.width
+        self.globals["image_depth"] = self.depth
+
+        # Get image preprocessing.
+        self.image_preprocessing = get_value_list_from_dictionary(
+            "image_preprocessing", self.config,
+            'none | normalize | all'.split(" | ")
+            )
+        if 'none' in self.image_preprocessing:
+            self.image_preprocessing = []
+        if 'all' in self.image_preprocessing:
+            self.image_preprocessing = ['normalize']
+        # Add resize as a transformation.
+        self.image_preprocessing = ["resize"] + self.image_preprocessing
+
+        self.logger.info("Applied image preprocessing: {}".format(self.image_preprocessing))
+
+        # Get the absolute path.
+        self.data_folder = os.path.expanduser(self.config['data_folder'])
+
+        # Get split.
+        split = get_value_from_dictionary('split', self.config, "training_0 | training | validation | test_dev | test".split(" | "))
+        self.split_image_folder = os.path.join(self.data_folder, "images")
+
+        # Set split-dependent data.
+        if split == 'training':
+            # Training split: files with all training questions.
+            data_files = []
+            for i in range(10):
+                data_files.append(os.path.join(self.data_folder, "questions1.2", "train_all_questions", "train_all_questions_{}.json".format(i)))
+
+        elif split == 'training_0':
+            # Training_0 split: file with the first part of the training questions.
+            data_files = [ os.path.join(self.data_folder, "questions1.2", "train_all_questions", "train_all_questions_0.json") ]
+            self.logger.warning("Please remember that this split constitutes only 10 percent of the whole training set!")
+
+        elif split == 'validation':
+            # Validation split: file with the validation questions.
+            data_files = [ os.path.join(self.data_folder, "questions1.2", "val_all_questions.json") ]
+            self.logger.warning("Please use 'test_dev' split for validation!")
+
+        elif split == 'test_dev':
+            # Test_dev split: file with the test-dev questions.
+            data_files = [ os.path.join(self.data_folder, "questions1.2", "testdev_all_questions.json") ]
+
+        elif split == 'test':
+            # Test split: file with the test questions.
+            data_files = [ os.path.join(self.data_folder, "questions1.2", "test_all_questions.json") ]
+
+        else:
+            raise ConfigurationError("Split {} not supported yet".format(split))
+
+        # Load dataset.
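+        # NOTE: each question file is a single JSON object mapping a sample id to
+        # a record containing (among other fields) "imageId", "question" and - for
+        # all splits except 'test' - "answer" and "fullAnswer" (see load_dataset).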
+        self.dataset = self.load_dataset(data_files)
+
+        # Display exemplary sample.
+        i = 0
+        sample = self.dataset[i]
+        self.logger.info("Exemplary sample {} ({}):\n  image_ids: {}\n  question: {}\n  answer: {} ({})".format(
+            i,
+            sample[self.key_sample_ids],
+            sample[self.key_image_ids],
+            sample[self.key_questions],
+            sample[self.key_answers],
+            sample[self.key_full_answers]
+            ))
+
+
+    def output_data_definitions(self):
+        """
+        Function returns a dictionary with definitions of output data produced by the component.
+
+        :return: dictionary containing output data definitions (each of type :py:class:`ptp.utils.DataDefinition`).
+        """
+        # Add all "standard" streams.
+        d = {
+            self.key_indices: DataDefinition([-1, 1], [list, int], "Batch of sample indices [BATCH_SIZE] x [1]"),
+            self.key_sample_ids: DataDefinition([-1, 1], [list, int], "Batch of sample ids [BATCH_SIZE] x [1]"),
+            self.key_image_ids: DataDefinition([-1, 1], [list, str], "Batch of image names, each being a single word [BATCH_SIZE] x [STRING]"),
+            }
+
+        # Return images only when required.
+        if self.stream_images:
+            d[self.key_images] = DataDefinition([-1, self.depth, self.height, self.width], [torch.Tensor], "Batch of images [BATCH_SIZE x IMAGE_DEPTH x IMAGE_HEIGHT x IMAGE_WIDTH]")
+
+        # Add stream with questions.
+        d[self.key_questions] = DataDefinition([-1, 1], [list, str], "Batch of questions, each being a string consisting of many words [BATCH_SIZE] x [STRING]")
+
+        # Add streams with answers.
+        d[self.key_answers] = DataDefinition([-1, 1], [list, str], "Batch of target answers, each being a string consisting of a few words (still treated as a single label) [BATCH_SIZE] x [STRING]")
+        d[self.key_full_answers] = DataDefinition([-1, 1], [list, str], "Batch of target full (long) answers, each being a string consisting of many words [BATCH_SIZE] x [STRING]")
+
+        return d
+
+
+    def __len__(self):
+        """
+        Returns the "size" of the "problem" (total number of samples).
+
+        :return: The size of the problem.
+        """
+        return len(self.dataset)
+
+
+    def load_dataset(self, source_files):
+        """
+        Loads the dataset from the source files.
+
+        :param source_files: list of JSON files with sample ids, image ids, questions, answers, etc.
+
+        """
+        self.logger.info("Loading dataset from:\n {}".format(source_files))
+        dataset = []
+
+        # Load and process files, one by one.
+        for source_file in source_files:
+            with open(source_file) as f:
+                self.logger.info("Loading samples from '{}'...".format(source_file))
+                json_dataset = json.load(f)
+            # Process samples.
+
+            # Add tqdm bar.
+            t = tqdm.tqdm(total=len(json_dataset))
+            for key, value in json_dataset.items():
+                # New sample.
+                sample = {}
+                sample[self.key_sample_ids] = key
+                sample[self.key_image_ids] = value["imageId"]
+                sample[self.key_questions] = value["question"]
+
+                # Extract answers (if present).
+                if "answer" in value.keys():
+                    sample[self.key_answers] = value["answer"]
+                    sample[self.key_full_answers] = value["fullAnswer"]
+                else:
+                    # Test set.
+                    sample[self.key_answers] = ""
+                    sample[self.key_full_answers] = ""
+
+                # Add to dataset.
+                dataset.append(sample)
+                t.update()
+            # Close the bar.
+            t.close()
+
+        self.logger.info("Loaded dataset consisting of {} samples".format(len(dataset)))
+        return dataset
+
+
+    def get_image(self, img_id):
+        """
+        Function loads and returns an image.
+        Additionally, it performs all the required transformations.
+
+        :param img_id: Identifier of the image.
+
+        :return: image (Tensor)
+        """
+
+        # Load the image.
+        img = Image.open(os.path.join(self.split_image_folder, img_id + ".jpg")).convert('RGB')
+
+        image_transformations_list = []
+
+        # Optional: resize.
+        if 'resize' in self.image_preprocessing:
+            image_transformations_list.append(transforms.Resize([self.height, self.width]))
+
+        # Add obligatory transformation.
+        image_transformations_list.append(transforms.ToTensor())
+
+        # Optional: normalization.
+        if 'normalize' in self.image_preprocessing:
+            # Use normalization that the pretrained models from TorchVision require.
+            image_transformations_list.append(transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
+
+        # Compose the transformations.
+        transforms_com = transforms.Compose(image_transformations_list)
+        # Apply transformations.
+        img = transforms_com(img)
+
+        # Return image.
+        return img
+
+    def __getitem__(self, index):
+        """
+        Getter method to access the dataset and return a single sample.
+
+        :param index: index of the sample to return.
+        :type index: int
+
+        :return: DataDict({'indices', 'sample_ids', 'images', 'image_ids', 'questions', 'answers', 'full_answers'})
+        """
+        # Get item.
+        item = self.dataset[index]
+
+        # Create the resulting sample (data dict).
+        data_dict = self.create_data_dict(index)
+
+        # Return sample id.
+        data_dict[self.key_sample_ids] = item[self.key_sample_ids]
+
+        # Return the image id.
+        img_id = item[self.key_image_ids]
+        data_dict[self.key_image_ids] = img_id
+
+        # Load the adequate image - only when required.
+        if self.stream_images:
+            data_dict[self.key_images] = self.get_image(img_id)
+
+        # Return question.
+        data_dict[self.key_questions] = item[self.key_questions]
+
+        # Return answers.
+        data_dict[self.key_answers] = item[self.key_answers]
+        data_dict[self.key_full_answers] = item[self.key_full_answers]
+
+        # Return sample.
+        return data_dict
+
+
+    def collate_fn(self, batch):
+        """
+        Combines a list of DataDict (retrieved with :py:func:`__getitem__`) into a batch.
+
+        :param batch: list of individual samples to combine
+        :type batch: list
+
+        :return: DataDict({'indices', 'sample_ids', 'images', 'image_ids', 'questions', 'answers', 'full_answers'})
+
+        """
+        # Collate indices.
+        data_dict = self.create_data_dict([sample[self.key_indices] for sample in batch])
+
+        # Collate sample ids.
+        data_dict[self.key_sample_ids] = [item[self.key_sample_ids] for item in batch]
+
+        # Collate image ids and stack images (when streamed).
+        data_dict[self.key_image_ids] = [item[self.key_image_ids] for item in batch]
+        if self.stream_images:
+            data_dict[self.key_images] = torch.stack([item[self.key_images] for item in batch]).type(torch.FloatTensor)
+
+        # Collate lists of strings.
+        data_dict[self.key_questions] = [item[self.key_questions] for item in batch]
+        data_dict[self.key_answers] = [item[self.key_answers] for item in batch]
+        data_dict[self.key_full_answers] = [item[self.key_full_answers] for item in batch]
+
+        # Return collated dict.
+        return data_dict
diff --git a/tests/__init__.py b/tests/__init__.py
index a634db8..9471702 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -4,7 +4,7 @@
 from .components.component_tests import TestComponent
 
 from .components.problems.clevr_tests import TestCLEVR
-#from .components.problems.gqa_tests import TestGQA
+from .components.problems.gqa_tests import TestGQA
 from .components.problems.problem_tests import TestProblem
 
 from .configuration.config_interface_tests import TestConfigInterface
@@ -25,7 +25,7 @@
     'TestkFoldWeightedRandomSampler',
     # Components
     'TestComponent',
-    #'TestGQA',
+    'TestGQA',
     'TestProblem',
     # Configuration
     'TestConfigRegistry',
diff --git a/tests/components/problems/clevr_tests.py b/tests/components/problems/clevr_tests.py
index 00d75af..791bf40 100644
--- a/tests/components/problems/clevr_tests.py
+++ b/tests/components/problems/clevr_tests.py
@@ -29,20 +29,20 @@ class TestCLEVR(unittest.TestCase):
 
     def test_training_set(self):
         """
-        Tests the CLEVR training split.
+        Tests the training split.
 
         ..note:
-            Test is performed only if json file '~/data/CLEVR_v1.0/questions/CLEVR_train_questions.json' is found.
+            Test on real data is performed only if json file '~/data/CLEVR_v1.0/questions/CLEVR_train_questions.json' is found.
         """
         # Empty config.
-        config = ConfigInterface()
-        config.add_config_params({"split": "training"})
+        config = ConfigInterface("CLEVR")
+        config.add_config_params({"clevr_training": {"split": "training", "globals": {"image_height": "clevr_image_height", "image_width": "clevr_image_width"}}})
 
         # Check the existence of test set.
-        if check_file_existence(path.expanduser('~/data/CLEVR_v1.0/questions'),'CLEVR_train_questions.json'):
+        if False: #check_file_existence(path.expanduser('~/data/CLEVR_v1.0/questions'),'CLEVR_train_questions.json'):
 
             # Create object.
-            clevr = CLEVR("CLEVR", config)
+            clevr = CLEVR("clevr_training", config["clevr_training"])
 
             # Check dataset size.
             self.assertEqual(len(clevr), 699989)
@@ -61,7 +61,7 @@ def test_training_set(self):
 
         # Mock up the load_dataset method.
         with patch( "ptp.components.problems.image_text_to_class.clevr.CLEVR.load_dataset", MagicMock( side_effect = [ dataset_content ] )):
-            clevr = CLEVR("CLEVR", config)
+            clevr = CLEVR("clevr_training", config["clevr_training"])
 
             # Mock up the get_image method.
             with patch( "ptp.components.problems.image_text_to_class.clevr.CLEVR.get_image", MagicMock( side_effect = [ "0" ] )):
@@ -78,20 +78,20 @@ def test_training_set(self):
 
     def test_validation_set(self):
         """
-        Tests the CLEVR validation split.
+        Tests the validation split.
 
         ..note:
            Test on real data is performed only if json file '~/data/CLEVR_v1.0/questions/CLEVR_val_questions.json' is found.
         """
         # Empty config.
         config = ConfigInterface()
-        config.add_config_params({"split": "validation"})
+        config.add_config_params({"clevr_validation": {"split": "validation", "globals": {"image_height": "clevr_image_height", "image_width": "clevr_image_width"}}})
 
         # Check the existence of test set.
-        if check_file_existence(path.expanduser('~/data/CLEVR_v1.0/questions'),'CLEVR_test_questions.json'):
+        if False: #check_file_existence(path.expanduser('~/data/CLEVR_v1.0/questions'),'CLEVR_val_questions.json'):
 
             # Create object.
-            clevr = CLEVR("CLEVR", config)
+            clevr = CLEVR("clevr_validation", config["clevr_validation"])
 
             # Check dataset size.
             self.assertEqual(len(clevr), 149991)
@@ -107,7 +107,7 @@ def test_validation_set(self):
 
         # Mock up the load_dataset method.
         with patch( "ptp.components.problems.image_text_to_class.clevr.CLEVR.load_dataset", MagicMock( side_effect = [ dataset_content ] )):
-            clevr = CLEVR("CLEVR", config)
+            clevr = CLEVR("clevr_validation", config["clevr_validation"])
 
             # Mock up the get_image method.
             with patch( "ptp.components.problems.image_text_to_class.clevr.CLEVR.get_image", MagicMock( side_effect = [ "0" ] )):
@@ -124,20 +124,20 @@ def test_test_set(self):
         """
-        Tests the CLEVR test split.
+        Tests the test split.
 
         ..note:
            Test on real data is performed only if json file '~/data/CLEVR_v1.0/questions/CLEVR_test_questions.json' is found.
         """
         # Empty config.
         config = ConfigInterface()
-        config.add_config_params({"split": "test"})
+        config.add_config_params({"clevr_test": {"split": "test", "globals": {"image_height": "clevr_image_height", "image_width": "clevr_image_width"}}})
 
         # Check the existence of test set.
-        if check_file_existence(path.expanduser('~/data/CLEVR_v1.0/questions'),'CLEVR_test_questions.json'):
+        if False: #check_file_existence(path.expanduser('~/data/CLEVR_v1.0/questions'),'CLEVR_test_questions.json'):
 
             # Create object.
-            clevr = CLEVR("CLEVR", config)
+            clevr = CLEVR("clevr_test", config["clevr_test"])
 
             # Check dataset size.
             self.assertEqual(len(clevr), 149988)
@@ -150,7 +150,7 @@ def test_test_set(self):
 
         # Mock up the load_dataset method.
         with patch( "ptp.components.problems.image_text_to_class.clevr.CLEVR.load_dataset", MagicMock( side_effect = [ dataset_content ] )):
-            clevr = CLEVR("CLEVR", config)
+            clevr = CLEVR("clevr_test", config["clevr_test"])
 
             # Mock up the get_image method.
             with patch( "ptp.components.problems.image_text_to_class.clevr.CLEVR.get_image", MagicMock( side_effect = [ "0" ] )):
diff --git a/tests/components/problems/gqa_tests.py b/tests/components/problems/gqa_tests.py
new file mode 100644
index 0000000..7174f71
--- /dev/null
+++ b/tests/components/problems/gqa_tests.py
@@ -0,0 +1,204 @@
+# -*- coding: utf-8 -*-
+#
+# Copyright (C) tkornuta, IBM Corporation 2019
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+__author__ = "Tomasz Kornuta"
+
+import unittest
+from unittest.mock import MagicMock, patch
+from os import path
+
+from ptp.components.utils.io import check_file_existence
+from ptp.components.problems.image_text_to_class.gqa import GQA
+from ptp.configuration.config_interface import ConfigInterface
+
+
+class TestGQA(unittest.TestCase):
+
+
+    def test_training_0_split(self):
+        """
+        Tests the training_0 split.
+
+        ..note:
+            Test on real data is performed only if the adequate json source file is found.
+        """
+        # Empty config.
+        config = ConfigInterface()
+        config.add_config_params({"gqa_training_0": {"split": "training_0", "globals": {"image_height": "gqa_image_height", "image_width": "gqa_image_width"}}})
+
+        # Check the existence of the source file.
+        if False: #check_file_existence(path.expanduser('~/data/gqa/questions1.2/train_all_questions'),'train_all_questions_0.json'):
+
+            # Create object.
+            problem = GQA("gqa_training_0", config["gqa_training_0"])
+
+            # Check dataset size.
+            self.assertEqual(len(problem), 1430536)
+
+            # Get sample.
+            sample = problem[0]
+
+        else:
+            processed_dataset_content = [ {'sample_ids': '07333408', 'image_ids': '2375429', 'questions': 'What is on the white wall?', 'answers': 'pipe', 'full_answers': 'The pipe is on the wall.'} ]
+
+            # Mock up the load_dataset method.
+            with patch( "ptp.components.problems.image_text_to_class.gqa.GQA.load_dataset", MagicMock( side_effect = [ processed_dataset_content ] )):
+                problem = GQA("gqa_training_0", config["gqa_training_0"])
+
+            # Mock up the get_image method.
+            with patch( "ptp.components.problems.image_text_to_class.gqa.GQA.get_image", MagicMock( side_effect = [ "0" ] )):
+                sample = problem[0]
+
+        # Check sample.
+        self.assertEqual(sample['indices'], 0)
+        self.assertEqual(sample['sample_ids'], '07333408')
+        self.assertEqual(sample['image_ids'], '2375429')
+        self.assertEqual(sample['questions'], 'What is on the white wall?')
+        self.assertEqual(sample['answers'], 'pipe')
+        self.assertEqual(sample['full_answers'], 'The pipe is on the wall.')
+
+
+    def test_validation_split(self):
+        """
+        Tests the validation split.
+
+        ..note:
+            Test on real data is performed only if the adequate json source file is found.
+        """
+        # Empty config.
+        config = ConfigInterface()
+        config.add_config_params({"gqa_validation": {"split": "validation", "globals": {"image_height": "gqa_image_height", "image_width": "gqa_image_width"}}})
+
+        # Check the existence of the source file.
+        if False: #check_file_existence(path.expanduser('~/data/gqa/questions1.2'),'val_all_questions.json'):
+
+            # Create object.
+            problem = GQA("gqa_validation", config["gqa_validation"])
+
+            # Check dataset size.
+            self.assertEqual(len(problem), 2011853)
+
+            # Get sample.
+            sample = problem[0]
+
+        else:
+            processed_dataset_content = [ {'sample_ids': '05451384', 'image_ids': '2382986', 'questions': 'Are there blankets under the brown cat?', 'answers': 'no', 'full_answers': 'No, there is a towel under the cat.'} ]
+
+            # Mock up the load_dataset method.
+            with patch( "ptp.components.problems.image_text_to_class.gqa.GQA.load_dataset", MagicMock( side_effect = [ processed_dataset_content ] )):
+                problem = GQA("gqa_validation", config["gqa_validation"])
+
+            # Mock up the get_image method.
+            with patch( "ptp.components.problems.image_text_to_class.gqa.GQA.get_image", MagicMock( side_effect = [ "0" ] )):
+                sample = problem[0]
+
+        # Check sample.
+        self.assertEqual(sample['indices'], 0)
+        self.assertEqual(sample['sample_ids'], '05451384')
+        self.assertEqual(sample['image_ids'], '2382986')
+        self.assertEqual(sample['questions'], 'Are there blankets under the brown cat?')
+        self.assertEqual(sample['answers'], 'no')
+        self.assertEqual(sample['full_answers'], 'No, there is a towel under the cat.')
+
+
+    def test_test_dev_split(self):
+        """
+        Tests the test_dev split.
+
+        ..note:
+            Test on real data is performed only if the adequate json source file is found.
+        """
+        # Empty config.
+        config = ConfigInterface()
+        config.add_config_params({"gqa_testdev": {"split": "test_dev", "globals": {"image_height": "gqa_image_height", "image_width": "gqa_image_width"}}})
+
+        # Check the existence of the source file.
+        if False: #check_file_existence(path.expanduser('~/data/gqa/questions1.2'),'testdev_all_questions.json'):
+
+            # Create object.
+            problem = GQA("gqa_testdev", config["gqa_testdev"])
+
+            # Check dataset size.
+            self.assertEqual(len(problem), 172174)
+
+            # Get sample.
+            sample = problem[0]
+
+        else:
+            processed_dataset_content = [ {'sample_ids': '20968379', 'image_ids': 'n288870', 'questions': 'Do the shorts have dark color?', 'answers': 'yes', 'full_answers': 'Yes, the shorts are dark.'} ]
+
+            # Mock up the load_dataset method.
+            with patch( "ptp.components.problems.image_text_to_class.gqa.GQA.load_dataset", MagicMock( side_effect = [ processed_dataset_content ] )):
+                problem = GQA("gqa_testdev", config["gqa_testdev"])
+
+            # Mock up the get_image method.
+            with patch( "ptp.components.problems.image_text_to_class.gqa.GQA.get_image", MagicMock( side_effect = [ "0" ] )):
+                sample = problem[0]
+
+        # Check sample.
+        self.assertEqual(sample['indices'], 0)
+        self.assertEqual(sample['sample_ids'], '20968379')
+        self.assertEqual(sample['image_ids'], 'n288870')
+        self.assertEqual(sample['questions'], 'Do the shorts have dark color?')
+        self.assertEqual(sample['answers'], 'yes')
+        self.assertEqual(sample['full_answers'], 'Yes, the shorts are dark.')
+
+
+    def test_test_split(self):
+        """
+        Tests the test split.
+
+        ..note:
+            Test on real data is performed only if the adequate json source file is found.
+        """
+        # Empty config.
+        config = ConfigInterface()
+        config.add_config_params({"gqa_test": {"split": "test", "globals": {"image_height": "gqa_image_height", "image_width": "gqa_image_width"}}})
+
+        # Check the existence of the source file.
+        if False: #check_file_existence(path.expanduser('~/data/gqa/questions1.2'),'test_all_questions.json'):
+
+            # Create object.
+            problem = GQA("gqa_test", config["gqa_test"])
+
+            # Check dataset size.
+            self.assertEqual(len(problem), 1340048)
+
+            # Get sample.
+            sample = problem[0]
+
+        else:
+            processed_dataset_content = [ {'sample_ids': '201971873', 'image_ids': 'n15740', 'questions': 'Is the blanket to the right of a pillow?', 'answers': '', 'full_answers': ''} ]
+
+            # Mock up the load_dataset method.
+            with patch( "ptp.components.problems.image_text_to_class.gqa.GQA.load_dataset", MagicMock( side_effect = [ processed_dataset_content ] )):
+                problem = GQA("gqa_test", config["gqa_test"])
+
+            # Mock up the get_image method.
+            with patch( "ptp.components.problems.image_text_to_class.gqa.GQA.get_image", MagicMock( side_effect = [ "0" ] )):
+                sample = problem[0]
+
+        # Check sample.
+        self.assertEqual(sample['indices'], 0)
+        self.assertEqual(sample['sample_ids'], '201971873')
+        self.assertEqual(sample['image_ids'], 'n15740')
+        self.assertEqual(sample['questions'], 'Is the blanket to the right of a pillow?')
+        self.assertEqual(sample['answers'], '')
+        self.assertEqual(sample['full_answers'], '')
+
+
+#if __name__ == "__main__":
+#    unittest.main()
\ No newline at end of file
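
For reference, a minimal usage sketch of the new GQA component, mirroring the configuration pattern from gqa_tests.py. It assumes the question files and images have been downloaded to '~/data/gqa' (the default data_folder from gqa.yml), that the framework resolves all remaining parameters from that default file, and that the section name "gqa_testdev" is arbitrary; the tests above additionally remap the "image_height"/"image_width" globals, which matters only when several problems run side by side.

    from ptp.components.problems.image_text_to_class.gqa import GQA
    from ptp.configuration.config_interface import ConfigInterface

    # Build a config section; every parameter not specified here
    # (data_folder, stream_images, resize_image, ...) falls back to gqa.yml.
    config = ConfigInterface()
    config.add_config_params({"gqa_testdev": {"split": "test_dev"}})

    # Create the problem and inspect a single sample.
    problem = GQA("gqa_testdev", config["gqa_testdev"])
    print(len(problem))  # Number of loaded question-answer samples.
    sample = problem[0]
    print(sample["questions"], "->", sample["answers"])

    # Collate the first four samples into a batch by hand.
    batch = problem.collate_fn([problem[i] for i in range(4)])
    print(batch["images"].shape)  # Expected torch.Size([4, 3, 224, 224]) with the default resize.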