This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Added GQA reader #4832

Merged · 9 commits · Dec 2, 2020

1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -33,6 +33,7 @@ dataset at every epoch) and a `MultiTaskScheduler` (for ordering the instances w
- Added abstraction and concrete implementation for region detectors.
- Transformer toolkit to plug and play with modular components of transformer architectures.
- `VisionReader` and `VisionTextModel` base classes added. `VisualEntailment` and `VQA` inherit from these.
- Added reader for the GQA dataset.

### Changed

1 change: 1 addition & 0 deletions allennlp/data/dataset_readers/__init__.py
@@ -23,6 +23,7 @@
    from allennlp.data.dataset_readers.vision_reader import VisionReader
    from allennlp.data.dataset_readers.vqav2 import VQAv2Reader
    from allennlp.data.dataset_readers.visual_entailment import VisualEntailmentReader
    from allennlp.data.dataset_readers.gqa import GQAReader
except ModuleNotFoundError as err:
    if err.name not in ("detectron2", "torchvision"):
        raise
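
Note: these imports sit inside the existing `try:` block of `allennlp/data/dataset_readers/__init__.py`, so environments without the optional vision dependencies can still import the package. A minimal sketch of that guard pattern, not part of the diff (the fallback assignment is purely illustrative):

```python
# Sketch of the optional-import guard shown above; assumes only that the GQA reader
# pulls in torchvision/detectron2 indirectly. Not part of this PR's diff.
try:
    from allennlp.data.dataset_readers.gqa import GQAReader
except ModuleNotFoundError as err:
    # Swallow the error only when an optional vision dependency is missing.
    if err.name not in ("detectron2", "torchvision"):
        raise
    GQAReader = None  # hypothetical fallback for this sketch; the real module simply omits the name
```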
161 changes: 161 additions & 0 deletions allennlp/data/dataset_readers/gqa.py
@@ -0,0 +1,161 @@
from os import PathLike
from typing import (
    Dict,
    Union,
    Optional,
    Tuple,
)
import json
import os

from overrides import overrides
import torch
from torch import Tensor

from allennlp.common.file_utils import cached_path
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.fields import ArrayField, LabelField, TextField
from allennlp.data.image_loader import ImageLoader
from allennlp.data.instance import Instance
from allennlp.data.token_indexers import TokenIndexer
from allennlp.data.tokenizers import Tokenizer
from allennlp.modules.vision.grid_embedder import GridEmbedder
from allennlp.modules.vision.region_detector import RegionDetector
from allennlp.data.dataset_readers.vision_reader import VisionReader


@DatasetReader.register("gqa")
class GQAReader(VisionReader):
"""
Parameters
----------
image_dir: `str`
Path to directory containing `png` image files.
image_featurizer: `GridEmbedder`
The backbone image processor (like a ResNet), whose output will be passed to the region
detector for finding object boxes in the image.
region_detector: `RegionDetector`
For pulling out regions of the image (both coordinates and features) that will be used by
downstream models.
data_dir: `str`
Path to directory containing text files for each dataset split. These files contain
the sentences and metadata for each task instance.
tokenizer: `Tokenizer`, optional
token_indexers: `Dict[str, TokenIndexer]`
lazy : `bool`, optional
Whether to load data lazily. Passed to super class.
"""

    def __init__(
        self,
        image_dir: Union[str, PathLike],
        image_loader: ImageLoader,
        image_featurizer: GridEmbedder,
        region_detector: RegionDetector,
        *,
        feature_cache_dir: Optional[Union[str, PathLike]] = None,
        data_dir: Optional[Union[str, PathLike]] = None,
        tokenizer: Tokenizer = None,
        token_indexers: Dict[str, TokenIndexer] = None,
        cuda_device: Optional[Union[int, torch.device]] = None,
        max_instances: Optional[int] = None,
        image_processing_batch_size: int = 8,
        skip_image_feature_extraction: bool = False,
    ) -> None:
        super().__init__(
            image_dir,
            image_loader,
            image_featurizer,
            region_detector,
            feature_cache_dir=feature_cache_dir,
            tokenizer=tokenizer,
            token_indexers=token_indexers,
            cuda_device=cuda_device,
            max_instances=max_instances,
            image_processing_batch_size=image_processing_batch_size,
            skip_image_feature_extraction=skip_image_feature_extraction,
        )
        self.data_dir = data_dir

    @overrides
    def _read(self, split_or_filename: str):

        if not self.data_dir:
            self.data_dir = "https://nlp.stanford.edu/data/gqa/questions1.2.zip!"

        splits = {
            "challenge_all": f"{self.data_dir}challenge_all_questions.json",
            "challenge_balanced": f"{self.data_dir}challenge_balanced_questions.json",
            "test_all": f"{self.data_dir}test_all_questions.json",
            "test_balanced": f"{self.data_dir}test_balanced_questions.json",
            "testdev_all": f"{self.data_dir}testdev_all_questions.json",
            "testdev_balanced": f"{self.data_dir}testdev_balanced_questions.json",
            "train_balanced": f"{self.data_dir}train_balanced_questions.json",
            "train_all": f"{self.data_dir}train_all_questions",
            "val_all": f"{self.data_dir}val_all_questions.json",
            "val_balanced": f"{self.data_dir}val_balanced_questions.json",
        }

        filename = splits.get(split_or_filename, split_or_filename)

        # If we're considering a directory of files (such as train_all),
        # loop through each file in the directory.
        if os.path.isdir(filename):
            files = [f"{filename}{file_path}" for file_path in os.listdir(filename)]
        else:
            files = [filename]

        for data_file in files:
            with open(cached_path(data_file, extract_archive=True)) as f:
                questions_with_annotations = json.load(f)

            # It would be much easier to just process one image at a time, but it's faster to
            # process them in batches. So this code gathers up instances until it has enough to
            # fill up a batch that needs processing, and then processes them all.
            question_dicts = list(
                self.shard_iterable(
                    questions_with_annotations[q_id] for q_id in questions_with_annotations
                )
            )

            processed_images = self._process_image_paths(
                self.images[f"{question_dict['imageId']}.jpg"] for question_dict in question_dicts
            )

            for question_dict, processed_image in zip(question_dicts, processed_images):
                answers = {
                    "answer": question_dict["answer"],
                    "fullAnswer": question_dict["fullAnswer"],
                }
                yield self.text_to_instance(question_dict["question"], processed_image, answers)

    @overrides
    def text_to_instance(
        self,  # type: ignore
        question: str,
        image: Union[str, Tuple[Tensor, Tensor]],
        answers: Dict[str, str] = None,
        *,
        use_cache: bool = True,
    ) -> Instance:
        tokenized_question = self._tokenizer.tokenize(question)
        question_field = TextField(tokenized_question, None)
        if isinstance(image, str):
            features, coords = next(self._process_image_paths([image], use_cache=use_cache))
        else:
            features, coords = image

        fields = {
            "box_features": ArrayField(features),
            "box_coordinates": ArrayField(coords),
            "question": question_field,
        }

        if answers:
            fields["label"] = LabelField(answers["answer"], label_namespace="answer")

        return Instance(fields)

    @overrides
    def apply_token_indexers(self, instance: Instance) -> None:
        instance["question"].token_indexers = self._token_indexers  # type: ignore
43 changes: 43 additions & 0 deletions test_fixtures/data/gqa/question_dir/questions0.json
@@ -0,0 +1,43 @@
{
"202218649": {
"semantic": [
{
"operation": "select",
"dependencies": [],
"argument": "chalkboard (0)"
},
{
"operation": "relate",
"dependencies": [0],
"argument": "_,hanging above,s (12)"
},
{
"operation": "query",
"dependencies": [1],
"argument": "name"
}
],
"entailed": ["202218648"],
"equivalent": ["202218649"],
"question": "What is hanging above the chalkboard?",
"imageId": "n578564",
"isBalanced": true,
"groups": {
"global": "thing",
"local": "14-chalkboard_hanging above,s"
},
"answer": "picture",
"semanticStr": "select: chalkboard (0)->relate: _,hanging above,s (12) [0]->query: name [1]",
"annotations": {
"answer": {"0": "12"},
"question": {},
"fullAnswer": {"1": "12", "6": "0"}
},
"types": {
"detailed": "relS",
"semantic": "rel",
"structural": "query"
},
"fullAnswer": "The picture is hanging above the chalkboard."
}
}
51 changes: 51 additions & 0 deletions test_fixtures/data/gqa/question_dir/questions1.json
@@ -0,0 +1,51 @@
{
"20240871": {
"semantic": [
{
"operation": "select",
"dependencies": [],
"argument": "water (4)"
},
{
"operation": "relate",
"dependencies": [0],
"argument": "table,below,s (11)"
},
{
"operation": "verify shape",
"dependencies": [1],
"argument": "round"
},
{
"operation": "verify material",
"dependencies": [1],
"argument": "wood "
},
{"operation": "and",
"dependencies": [2, 3],
"argument": ""
}
],
"entailed": ["20240900", "20240892", "20240891", "20240890", "20240879", "20240896", "20240895", "20240894", "20240875", "20240897", "20240899", "20240898", "20240870", "20240878", "20240910", "20240877", "20240909", "20240886", "20240887", "20240882", "20240911", "20240872", "20240888", "20240889"], "equivalent": ["20240871", "20240870"],
"question": "Does the table below the water look wooden and round?",
"imageId": "n166008",
"isBalanced": false,
"groups": {
"global": null,
"local": "05-round_wood"
},
"answer": "yes",
"semanticStr": "select: water (4)->relate: table,below,s (11) [0]->verify shape: round [1]->verify material: wood [1]->and: [2, 3]",
"annotations": {
"answer": {},
"question": {"2": "11", "5": "4"},
"fullAnswer": {"2": "11"}
},
"types": {
"detailed": "verifyAttrs",
"semantic": "attr",
"structural": "logical"
},
"fullAnswer": "Yes, the table is wooden and round."
}
}
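
The fixture above also illustrates GQA's functional program annotation (the `semantic` field), which the reader does not consume. A small sketch, reusing that fixture path, of how the program steps could be walked:

```python
import json

# Pretty-print the reasoning program attached to the fixture question above.
# The reader itself ignores "semantic"; this is only to illustrate the data layout.
with open("test_fixtures/data/gqa/question_dir/questions1.json") as f:
    question = json.load(f)["20240871"]

for step_index, step in enumerate(question["semantic"]):
    deps = ", ".join(str(d) for d in step["dependencies"]) or "-"
    print(f"[{step_index}] {step['operation']}({step['argument']!r}) <- depends on: {deps}")

print("answer:", question["answer"], "/", question["fullAnswer"])
```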
43 changes: 43 additions & 0 deletions test_fixtures/data/gqa/questions.json
@@ -0,0 +1,43 @@
{
"202218649": {
"semantic": [
{
"operation": "select",
"dependencies": [],
"argument": "chalkboard (0)"
},
{
"operation": "relate",
"dependencies": [0],
"argument": "_,hanging above,s (12)"
},
{
"operation": "query",
"dependencies": [1],
"argument": "name"
}
],
"entailed": ["202218648"],
"equivalent": ["202218649"],
"question": "What is hanging above the chalkboard?",
"imageId": "n578564",
"isBalanced": true,
"groups": {
"global": "thing",
"local": "14-chalkboard_hanging above,s"
},
"answer": "picture",
"semanticStr": "select: chalkboard (0)->relate: _,hanging above,s (12) [0]->query: name [1]",
"annotations": {
"answer": {"0": "12"},
"question": {},
"fullAnswer": {"1": "12", "6": "0"}
},
"types": {
"detailed": "relS",
"semantic": "rel",
"structural": "query"
},
"fullAnswer": "The picture is hanging above the chalkboard."
}
}