Skip to content
This repository was archived by the owner on Jul 18, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion configs/default/workers/offline_trainer.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
####################################################################
# Section defining all the default values of parameters used during training when using ptp-offline-trainer.

# If you want to use a different section for "training", pass its name as the command line argument '--training_section_name' to the trainer (DEFAULT: training).
# Note: the following parameters will be (anyway) used as default values.
default_training:
Expand Down
35 changes: 26 additions & 9 deletions ptp/components/problems/image_text_to_class/clevr.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,13 +204,23 @@ def __init__(self, name, config):

# Display exemplary sample.
i = 0
sample = self.dataset[i]
# Check if this is a test set.
if "answer" not in sample.keys():
sample["answer"] = "<UNK>"
sample[self.key_question_type_ids] = -1
sample[self.key_question_type_names] = "<UNK>"
else:
sample[self.key_question_type_ids] = self.question_family_id_to_subtype_id_mapping[sample["question_family_index"]]
sample[self.key_question_type_names] = self.question_family_id_to_subtype_mapping[sample["question_family_index"]]

self.logger.info("Exemplary sample {} ({}):\n question_type: {} ({})\n image_ids: {}\n question: {}\n answer: {}".format(
i, self.dataset[i]["question_index"],
self.question_family_id_to_subtype_mapping[self.dataset[i]["question_family_index"]],
self.question_family_id_to_subtype_id_mapping[self.dataset[i]["question_family_index"]],
self.dataset[i]["image_filename"],
self.dataset[i]["question"],
self.dataset[i]["answer"]
i, sample["question_index"],
sample[self.key_question_type_ids],
sample[self.key_question_type_names],
sample["image_filename"],
sample["question"],
sample["answer"]
))


Expand Down Expand Up @@ -334,11 +344,18 @@ def __getitem__(self, index):
data_dict[self.key_questions] = item["question"]

# Return answer.
data_dict[self.key_answers] = item["answer"]
if "answer" in item.keys():
data_dict[self.key_answers] = item["answer"]
else:
data_dict[self.key_answers] = "<UNK>"

# Question type related variables.
data_dict[self.key_question_type_ids] = self.question_family_id_to_subtype_id_mapping[item["question_family_index"]]
data_dict[self.key_question_type_names] = self.question_family_id_to_subtype_mapping[item["question_family_index"]]
if "question_family_index" in item.keys():
data_dict[self.key_question_type_ids] = self.question_family_id_to_subtype_id_mapping[item["question_family_index"]]
data_dict[self.key_question_type_names] = self.question_family_id_to_subtype_mapping[item["question_family_index"]]
else:
data_dict[self.key_question_type_ids] = -1
data_dict[self.key_question_type_names] = "<UNK>"

# Return sample.
return data_dict
Expand Down
43 changes: 28 additions & 15 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,39 @@
from .app_state_tests import TestAppState
from .component_tests import TestComponent
from .config_interface_tests import TestConfigInterface
from .config_registry_tests import TestConfigRegistry
from .data_dict_tests import TestDataDict
from .data_definition_tests import TestDataDefinition
from .handshaking_tests import TestHandshaking
from .pipeline_tests import TestPipeline
from .problem_tests import TestProblem
from .sampler_factory_tests import TestSamplerFactory
from .samplers_tests import TestkFoldRandomSampler, TestkFoldWeightedRandomSampler
from .application.pipeline_tests import TestPipeline

from .components.component_tests import TestComponent
from .components.clevr_tests import TestCLEVR
from .components.problem_tests import TestProblem

from .configuration.config_interface_tests import TestConfigInterface
from .configuration.config_registry_tests import TestConfigRegistry
from .configuration.handshaking_tests import TestHandshaking

from .data_types.data_dict_tests import TestDataDict
from .data_types.data_definition_tests import TestDataDefinition

from .utils.app_state_tests import TestAppState
from .utils.sampler_factory_tests import TestSamplerFactory
from .utils.samplers_tests import TestkFoldRandomSampler, TestkFoldWeightedRandomSampler
from .utils.statistics_tests import TestStatistics

__all__ = [
'TestAppState',
# Application
'TestPipeline',
# Components
'TestComponent',
'TestCLEVR',
'TestProblem',
# Configuration
'TestConfigRegistry',
'TestConfigInterface',
'TestHandshaking',
# DataTypes
'TestDataDict',
'TestDataDefinition',
'TestHandshaking',
'TestPipeline',
'TestProblem',
# Utils
'TestAppState',
'TestSamplerFactory',
'TestkFoldRandomSampler',
'TestkFoldWeightedRandomSampler',
'TestStatistics',
]
125 changes: 125 additions & 0 deletions tests/components/clevr_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) tkornuta, IBM Corporation 2019
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__author__ = "Tomasz Kornuta"

import unittest
from os import path

from ptp.components.utils.io import check_file_existence
from ptp.components.problems.image_text_to_class.clevr import CLEVR
from ptp.data_types.data_definition import DataDefinition
from ptp.configuration.config_interface import ConfigInterface


class TestCLEVR(unittest.TestCase):
    """Checks the size and the first sample of each split of the CLEVR problem.

    Each test requires the matching questions json file to be present in
    ``~/data/CLEVR_v1.0/questions``. When the file is missing the test is
    reported as *skipped* instead of silently passing.
    """

    # Folder (before user expansion) probed for the CLEVR questions files.
    _QUESTIONS_DIR = '~/data/CLEVR_v1.0/questions'

    def __init__(self, *args, **kwargs):
        """Records, per split, whether the questions file needed by the test exists."""
        super(TestCLEVR, self).__init__(*args, **kwargs)
        questions_dir = path.expanduser(self._QUESTIONS_DIR)

        # NOTE(review): the training-split test is hard-disabled here; the reason
        # is not visible in this file -- confirm before re-enabling the probe:
        #   check_file_existence(questions_dir, 'CLEVR_train_questions.json')
        self.unittest_training_set = False
        # Check the existence of validation set.
        self.unittest_validation_set = check_file_existence(questions_dir, 'CLEVR_val_questions.json')
        # Check the existence of test set.
        self.unittest_test_set = check_file_existence(questions_dir, 'CLEVR_test_questions.json')

    def _create_problem(self, split):
        """Builds a CLEVR problem whose only explicitly set parameter is ``split``."""
        config = ConfigInterface()
        config.add_config_params({"split": split})
        return CLEVR("CLEVR", config)

    def _check_split(self, split, expected_size, expected_sample):
        """Asserts the size of the given split and the content of its first sample.

        :param split: Value passed as the "split" configuration parameter.
        :param expected_size: Expected number of samples in the split.
        :param expected_sample: Dict mapping sample keys to the values expected in sample 0.
        """
        clevr = self._create_problem(split)

        # Check dataset size.
        self.assertEqual(len(clevr), expected_size)

        # Check sample, naming the offending key on failure.
        sample = clevr[0]
        for key, value in expected_sample.items():
            self.assertEqual(sample[key], value, msg="Mismatch for key '{}'".format(key))

    def test_training_set(self):
        """
        Tests the CLEVR training split.

        .. note::
            Test is skipped unless the training-split check is enabled in the constructor
            (it requires '~/data/CLEVR_v1.0/questions/CLEVR_train_questions.json').
        """
        if not self.unittest_training_set:
            self.skipTest("CLEVR training split test is disabled or 'CLEVR_train_questions.json' was not found")

        self._check_split("training", 699989, {
            'indices': 0,
            'image_ids': 'CLEVR_train_000000.png',
            'question_type_ids': 4,
            'question_type_names': 'greater_than',
            'questions': 'Are there more big green things than large purple shiny cubes?',
            'answers': 'yes',
        })

    def test_validation_set(self):
        """
        Tests the CLEVR validation split.

        .. note::
            Test is performed only if json file '~/data/CLEVR_v1.0/questions/CLEVR_val_questions.json' is found;
            otherwise it is skipped.
        """
        if not self.unittest_validation_set:
            self.skipTest("'CLEVR_val_questions.json' was not found in {}".format(self._QUESTIONS_DIR))

        self._check_split("validation", 149991, {
            'indices': 0,
            'image_ids': 'CLEVR_val_000000.png',
            'question_type_ids': 10,
            'question_type_names': 'exist',
            'questions': 'Are there any other things that are the same shape as the big metallic object?',
            'answers': 'no',
        })

    def test_test_set(self):
        """
        Tests the CLEVR test split.

        .. note::
            Test is performed only if json file '~/data/CLEVR_v1.0/questions/CLEVR_test_questions.json' is found;
            otherwise it is skipped.
        """
        if not self.unittest_test_set:
            self.skipTest("'CLEVR_test_questions.json' was not found in {}".format(self._QUESTIONS_DIR))

        # The expected values below use -1 / '<UNK>' placeholders for the
        # question type and the answer of the test split.
        self._check_split("test", 149988, {
            'indices': 0,
            'image_ids': 'CLEVR_test_000000.png',
            'question_type_ids': -1,
            'question_type_names': '<UNK>',
            'questions': 'Is there anything else that is the same shape as the small brown matte object?',
            'answers': '<UNK>',
        })




#if __name__ == "__main__":
# unittest.main()
File renamed without changes.
File renamed without changes.