Other VQA datasets (#4834)
* Make the VQA reader work for the other datasets

* Also find pngs

* Really support pngs

* Remove debug code

* More logging

* Unexpected formatting

* Respect the device

* This is how you replace things in named tuples.

* Remove unused import

* This is how you override a method properly.

* This is how you set parameters in detectron.

* Also set the device for the region detector

* Training configs for all three datasets contained in VQA

* Bigger batches

* Bigger batches for image processing

* Fix vilbert-from-huggingface config

* Make the config switch modes for constructing vocab

* More vocab, more docs, better way of deriving vocab

* Modernize the from_huggingface config

* More updates to the from_huggingface config

* Better hyperparameters stolen from another project

* Fix for inverted parameter

* Formatting

* Throw a meaningful error message when we don't have images

* Add a warning that includes instructions for how to fix things

* Remove unused script

* Merge issue
dirkgr committed Dec 3, 2020
1 parent e729e9a commit 7887119
Showing 12 changed files with 567 additions and 120 deletions.
4 changes: 2 additions & 2 deletions allennlp/data/dataset_readers/gqa.py
@@ -60,7 +60,7 @@ def __init__(
cuda_device: Optional[Union[int, torch.device]] = None,
max_instances: Optional[int] = None,
image_processing_batch_size: int = 8,
skip_image_feature_extraction: bool = False,
run_image_feature_extraction: bool = True,
) -> None:
super().__init__(
image_dir,
@@ -73,7 +73,7 @@ def __init__(
cuda_device=cuda_device,
max_instances=max_instances,
image_processing_batch_size=image_processing_batch_size,
skip_image_feature_extraction=skip_image_feature_extraction,
run_image_feature_extraction=run_image_feature_extraction,
)
self.data_dir = data_dir

34 changes: 22 additions & 12 deletions allennlp/data/dataset_readers/vision_reader.py
@@ -43,19 +43,29 @@ class VisionReader(DatasetReader):
----------
image_dir: `str`
Path to directory containing `png` image files.
Path to directory containing image files. The structure of the directory doesn't matter. We
find images by finding filenames that match `*[image_id].jpg`.
image_featurizer: `GridEmbedder`
The backbone image processor (like a ResNet), whose output will be passed to the region
detector for finding object boxes in the image.
region_detector: `RegionDetector`
For pulling out regions of the image (both coordinates and features) that will be used by
downstream models.
tokenizer: `Tokenizer`, optional
The `Tokenizer` to use to tokenize the text. By default, this uses the tokenizer for
`"bert-base-uncased"`.
token_indexers: `Dict[str, TokenIndexer]`, optional
The `TokenIndexer` to use. By default, this uses the indexer for `"bert-base-uncased"`.
cuda_device: `Union[int, torch.device]`, optional
Either a torch device or a GPU number. This is the GPU we'll use to featurize the images.
max_instances: `int`, optional
image_processing_batch_size: `int`, optional (default = `8`)
skip_image_feature_extraction: `bool`, optional (default = `False`)
For debugging, you can use this parameter to limit the number of instances the reader
returns.
image_processing_batch_size: `int`
The number of images to process at one time while featurizing. Default is 8.
run_image_feature_extraction: `bool`
If this is set to `False`, we skip featurizing images completely. This can be useful
for debugging or for generating the vocabulary ahead of time. Default is `True`.
"""

def __init__(
@@ -66,12 +76,12 @@ def __init__(
region_detector: RegionDetector,
*,
feature_cache_dir: Optional[Union[str, PathLike]] = None,
tokenizer: Tokenizer = None,
token_indexers: Dict[str, TokenIndexer] = None,
tokenizer: Optional[Tokenizer] = None,
token_indexers: Optional[Dict[str, TokenIndexer]] = None,
cuda_device: Optional[Union[int, torch.device]] = None,
max_instances: Optional[int] = None,
image_processing_batch_size: int = 8,
skip_image_feature_extraction: bool = False
run_image_feature_extraction: bool = True,
) -> None:
super().__init__(
max_instances=max_instances,
@@ -89,6 +99,7 @@ def __init__(
cuda_device = -1
check_for_gpu(cuda_device)
self.cuda_device = int_to_device(cuda_device)
logger.info(f"Processing images on device {cuda_device}")

# tokenizers and indexers
if tokenizer is None:
@@ -98,14 +109,15 @@ def __init__(
token_indexers = {"tokens": PretrainedTransformerIndexer("bert-base-uncased")}
self._token_indexers = token_indexers

self.skip_image_feature_extraction = skip_image_feature_extraction
if not skip_image_feature_extraction:
self.run_image_feature_extraction = run_image_feature_extraction
if run_image_feature_extraction:
logger.info("Discovering images ...")
self.images = {
os.path.basename(filename): filename
for extension in {"png", "jpg"}
for filename in tqdm(
glob.iglob(os.path.join(image_dir, "**", "*.jpg"), recursive=True),
desc="Discovering images",
glob.iglob(os.path.join(image_dir, "**", f"*.{extension}"), recursive=True),
desc=f"Discovering {extension} images",
)
}
logger.info("Done discovering images")
@@ -162,9 +174,7 @@ def yield_batch():
images = images.to(self.cuda_device)
sizes = sizes.to(self.cuda_device)
featurized_images = self.image_featurizer(images, sizes)
print("done featurizing")
detector_results = self.region_detector(images, sizes, featurized_images)
print("done detecting")
features = detector_results["features"]
coordinates = detector_results["coordinates"]

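Aside (not part of the commit): the vision_reader.py hunks above switch image discovery from jpg-only to both png and jpg, globbing recursively under image_dir and keying the result by basename. A minimal standalone sketch of that lookup, with a hypothetical directory path:

import glob
import os
from typing import Dict

def discover_images(image_dir: str) -> Dict[str, str]:
    """Map image basenames to full paths, searching image_dir recursively.

    Mirrors the reader's discovery step: the directory layout doesn't matter,
    as long as filenames are unique across the tree.
    """
    images: Dict[str, str] = {}
    for extension in ("png", "jpg"):
        pattern = os.path.join(image_dir, "**", f"*.{extension}")
        for filename in glob.iglob(pattern, recursive=True):
            images[os.path.basename(filename)] = filename
    return images

# Hypothetical usage:
# images = discover_images("/data/vqav2/images")
# images["COCO_train2014_000000000009.jpg"]
# -> "/data/vqav2/images/train2014/COCO_train2014_000000000009.jpg"
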
2 changes: 1 addition & 1 deletion allennlp/data/dataset_readers/visual_entailment.py
@@ -32,7 +32,7 @@ def _read(self, file_path: str):
lines = json_lines_from_file(file_path)
info_dicts: List[Dict] = list(self.shard_iterable(lines)) # type: ignore

if not self.skip_image_feature_extraction:
if self.run_image_feature_extraction:
# It would be much easier to just process one image at a time, but it's faster to process
# them in batches. So this code gathers up instances until it has enough to fill up a batch
# that needs processing, and then processes them all.
135 changes: 96 additions & 39 deletions allennlp/data/dataset_readers/vqav2.py
@@ -238,40 +238,70 @@ class VQAv2Reader(VisionReader):
----------
image_dir: `str`
Path to directory containing `png` image files.
image_loader: `ImageLoader`
The image loader component used to load the images.
image_featurizer: `GridEmbedder`
The backbone image processor (like a ResNet), whose output will be passed to the region
detector for finding object boxes in the image.
region_detector: `RegionDetector`
For pulling out regions of the image (both coordinates and features) that will be used by
downstream models.
data_dir: `str`
Path to directory containing text files for each dataset split. These files contain
the sentences and metadata for each task instance.
answer_vocab: `Union[Vocabulary, str]`, optional
The vocabulary to use for answers. The reader will look into the `"answers"` namespace
in the vocabulary to find possible answers.
If this is given, the reader only outputs instances with answers contained in this vocab.
If this is not given, the reader outputs all instances with all answers.
If this is a URL or filename, we will download a previously saved vocabulary from there.
feature_cache_dir: `Union[str, PathLike]`, optional
An optional directory to cache the featurized images in. Featurizing images takes a long
time, and many images are duplicated, so we highly recommend to use this cache.
tokenizer: `Tokenizer`, optional
token_indexers: `Dict[str, TokenIndexer]`
lazy : `bool`, optional
Whether to load data lazily. Passed to super class.
The `Tokenizer` to use to tokenize the text. By default, this uses the tokenizer for
`"bert-base-uncased"`.
token_indexers: `Dict[str, TokenIndexer]`, optional
The `TokenIndexer` to use. By default, this uses the indexer for `"bert-base-uncased"`.
cuda_device: `Union[int, torch.device]`, optional
Either a torch device or a GPU number. This is the GPU we'll use to featurize the images.
max_instances: `int`, optional
For debugging, you can use this parameter to limit the number of instances the reader
returns.
image_processing_batch_size: `int`
The number of images to process at one time while featurizing. Default is 8.
run_image_feature_extraction: `bool`
If this is set to `False`, we skip featurizing images completely. This can be useful
for debugging or for generating the vocabulary ahead of time. Default is `True`.
multiple_answers_per_question: `bool`
VQA questions have multiple answers. By default, we use all of them, and give more
points to the more common answer. But VQA also has a special answer, the so-called
"multiple choice answer". If this is set to `False`, we only use that answer.
"""

def __init__(
self,
image_dir: Union[str, PathLike],
image_loader: ImageLoader,
image_featurizer: GridEmbedder,
region_detector: RegionDetector,
image_dir: Union[str, PathLike] = None,
*,
answer_vocab: Union[
Vocabulary, str
] = "https://storage.googleapis.com/allennlp-public-data/vqav2/vqav2_vocab.tar.gz",
answer_vocab: Optional[Union[Vocabulary, str]] = None,
feature_cache_dir: Optional[Union[str, PathLike]] = None,
tokenizer: Tokenizer = None,
token_indexers: Dict[str, TokenIndexer] = None,
tokenizer: Optional[Tokenizer] = None,
token_indexers: Optional[Dict[str, TokenIndexer]] = None,
cuda_device: Optional[Union[int, torch.device]] = None,
max_instances: Optional[int] = None,
image_processing_batch_size: int = 8,
skip_image_feature_extraction: bool = False,
keep_unanswerable_questions: bool = True,
run_image_feature_extraction: bool = True,
multiple_answers_per_question: bool = True,
) -> None:
if image_dir is None:
raise ValueError(
"Because of the size of the image datasets, we don't download them automatically. "
"Please go to https://visualqa.org/download.html, download the datasets you need, "
"and set the image_dir parameter to point to your download location. This dataset "
"reader does not care about the exact directory structure. It finds the images "
"wherever they are."
)

super().__init__(
image_dir,
image_loader,
@@ -283,11 +313,11 @@ def __init__(
cuda_device=cuda_device,
max_instances=max_instances,
image_processing_batch_size=image_processing_batch_size,
skip_image_feature_extraction=skip_image_feature_extraction,
run_image_feature_extraction=run_image_feature_extraction,
)

# read answer vocab
if keep_unanswerable_questions:
if answer_vocab is None:
self.answer_vocab = None
else:
if isinstance(answer_vocab, str):
@@ -298,6 +328,25 @@ def __init__(
for a in answer_vocab.get_token_to_index_vocabulary("answers").keys()
)

if run_image_feature_extraction:
# normalize self.images some more
# At this point, self.images maps filenames to full paths, but we want to map image ids to full paths.
filename_re = re.compile(r".*(\d{12})\.((jpg)|(png))")

def id_from_filename(filename: str) -> Optional[int]:
match = filename_re.fullmatch(filename)
if match is None:
return None
return int(match.group(1))

self.images = {
id_from_filename(name): full_path for name, full_path in self.images.items()
}
if None in self.images:
del self.images[None]

self.multiple_answers_per_question = multiple_answers_per_question

@overrides
def _read(self, splits_or_list_of_splits: Union[str, List[str]]):
# if we are given a list of splits, concatenate them
@@ -369,48 +418,60 @@ class Split(NamedTuple):
try:
split = splits[split_name]
except KeyError:
raise ValueError(
f"Unrecognized split: {split_name}. We require a split, not a filename, for "
"VQA because the image filenames require using the split."
)
raise ValueError(f"Unrecognized split: {split_name}.")

annotations_by_question_id = {}
answers_by_question_id = {}
if split.annotations is not None:
with open(cached_path(split.annotations, extract_archive=True)) as f:
annotations = json.load(f)
for a in annotations["annotations"]:
annotations_by_question_id[a["question_id"]] = a
qid = a["question_id"]
answer_counts: MutableMapping[str, int] = Counter()
if self.multiple_answers_per_question:
for answer in (answer_dict["answer"] for answer_dict in a["answers"]):
answer_counts[preprocess_answer(answer)] += 1
else:
answer_counts[preprocess_answer(a["multiple_choice_answer"])] = 1
answers_by_question_id[qid] = answer_counts

questions = []
with open(cached_path(split.questions, extract_archive=True)) as f:
questions_file = json.load(f)
image_subtype = questions_file["data_subtype"]
for ques in questions_file["questions"]:
ques["image_subtype"] = image_subtype
questions.append(ques)
questions = questions[question_slice]

question_dicts = list(self.shard_iterable(questions))
processed_images: Iterable[Optional[Tuple[Tensor, Tensor]]]
if not self.skip_image_feature_extraction:
if self.run_image_feature_extraction:
# It would be much easier to just process one image at a time, but it's faster to process
# them in batches. So this code gathers up instances until it has enough to fill up a batch
# that needs processing, and then processes them all.
processed_images = self._process_image_paths(
self.images[
f"COCO_{question_dict['image_subtype']}_{question_dict['image_id']:012d}.jpg"

try:
image_paths = [
self.images[int(question_dict["image_id"])] for question_dict in question_dicts
]
for question_dict in question_dicts
)
except KeyError as e:
missing_id = e.args[0]
raise KeyError(
missing_id,
f"We could not find an image with the id {missing_id}. "
"Because of the size of the image datasets, we don't download them automatically. "
"Please go to https://visualqa.org/download.html, download the datasets you need, "
"and set the image_dir parameter to point to your download location. This dataset "
"reader does not care about the exact directory structure. It finds the images "
"wherever they are.",
)

processed_images = self._process_image_paths(image_paths)
else:
processed_images = [None for i in range(len(question_dicts))]

attempted_instances_count = 0
failed_instances_count = 0
for question_dict, processed_image in zip(question_dicts, processed_images):
answers = annotations_by_question_id.get(question_dict["question_id"])
if answers is not None:
answers = answers["answers"]
answers = answers_by_question_id.get(question_dict["question_id"])

instance = self.text_to_instance(question_dict["question"], processed_image, answers)
attempted_instances_count += 1
@@ -431,7 +492,7 @@ def text_to_instance(
self, # type: ignore
question: str,
image: Union[str, Tuple[Tensor, Tensor]],
answers: Optional[List[Dict[str, str]]] = None,
answer_counts: Optional[MutableMapping[str, int]] = None,
*,
use_cache: bool = True,
) -> Optional[Instance]:
@@ -452,13 +513,9 @@ def text_to_instance(
fields["box_features"] = ArrayField(features)
fields["box_coordinates"] = ArrayField(coords)

if answers:
if answer_counts is not None:
answer_fields = []
weights = []
answer_counts: MutableMapping[str, int] = Counter()

for answer in (a["answer"] for a in answers):
answer_counts[preprocess_answer(answer)] += 1

for answer, count in answer_counts.items():
if self.answer_vocab is None or answer in self.answer_vocab:
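Aside (not part of the commit): two helper steps introduced in the vqav2.py hunks above are the regex that turns COCO-style filenames into integer image ids and the per-question answer counting controlled by multiple_answers_per_question. A minimal sketch of both, with preprocess_answer replaced by a simple lower() stand-in and a made-up annotation dict:

import re
from collections import Counter
from typing import Optional

filename_re = re.compile(r".*(\d{12})\.((jpg)|(png))")

def id_from_filename(filename: str) -> Optional[int]:
    # Twelve digits immediately before the extension are the image id.
    match = filename_re.fullmatch(filename)
    if match is None:
        return None
    return int(match.group(1))

def answer_counts_for(annotation: dict, multiple_answers_per_question: bool) -> Counter:
    # With multiple_answers_per_question=True every annotator answer adds to the
    # count (and later to that answer's weight); otherwise only the single
    # "multiple_choice_answer" is kept.
    counts: Counter = Counter()
    if multiple_answers_per_question:
        for answer_dict in annotation["answers"]:
            counts[answer_dict["answer"].lower()] += 1  # stand-in for preprocess_answer
    else:
        counts[annotation["multiple_choice_answer"].lower()] = 1
    return counts

# Hypothetical annotation dict:
annotation = {
    "multiple_choice_answer": "red",
    "answers": [{"answer": "red"}, {"answer": "red"}, {"answer": "maroon"}],
}
print(id_from_filename("COCO_train2014_000000000009.jpg"))  # 9
print(answer_counts_for(annotation, True))                  # Counter({'red': 2, 'maroon': 1})
print(answer_counts_for(annotation, False))                 # Counter({'red': 1})
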
11 changes: 11 additions & 0 deletions allennlp/modules/vision/grid_embedder.py
@@ -114,3 +114,14 @@ def get_output_dim(self) -> int:

def get_stride(self) -> int:
return self.backbone.output_shape()["res4"].stride

def to(self, device):
if isinstance(device, int) or isinstance(device, torch.device):
if self._pipeline_object is not None:
self._pipeline_object.model.to(device)
if isinstance(device, torch.device):
device = device.index
self.flat_parameters = self.flat_parameters._replace(device=device)
return self
else:
return super().to(device)
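
Aside (not part of the commit): the to() override above moves the wrapped detectron model to the requested device and then swaps the device field on an immutable named tuple of parameters via _replace, which is the trick the commit message calls out. A minimal sketch of that pattern, with a hypothetical FlatParameters tuple standing in for the real detectron parameter tuple:

from typing import NamedTuple
import torch

class FlatParameters(NamedTuple):  # hypothetical stand-in for the real parameter tuple
    weights: str
    device: int

params = FlatParameters(weights="backbone.pkl", device=-1)

device = torch.device("cuda:0")
if isinstance(device, torch.device):
    device = device.index  # the parameter tuple stores a plain integer index
# Named tuples are immutable, so _replace returns a new tuple with the field swapped.
params = params._replace(device=device)
print(params)  # FlatParameters(weights='backbone.pkl', device=0)
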
11 changes: 11 additions & 0 deletions allennlp/modules/vision/region_detector.py
@@ -217,3 +217,14 @@ def forward(
"class_probs": probs_tensor,
"num_regions": batch_num_detections,
}

def to(self, device):
if isinstance(device, int) or isinstance(device, torch.device):
if self._model_object is not None:
self._model_object.model.to(device)
if isinstance(device, torch.device):
device = device.index
self.flat_parameters = self.flat_parameters._replace(device=device)
return self
else:
return super().to(device)
22 changes: 0 additions & 22 deletions test_fixtures/data/vqav2/save_answer_vocab.py

This file was deleted.
