Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
Make GQA work (#4884)
Browse files Browse the repository at this point in the history
* Refactored shared code

* typecheck fix

* rebase

* Refactored shared code

* typecheck fix

* rebase

* Cleaned up GQA reader tests

* Modify instance format for vilbert-vqa model

* update for vision branch bump

* Adding training config for GQA

* Unnamed variable

* Various GQA fixes

* Temporary extra configs needed to make vocab

* Remove unused file

* Optimize VQA score instead of F-Score

* Use our newly created vocab

* Remove temporary configs

* Don't fail when we don't need to create a directory

* Make a config that works on the servers as well

* Update comment

* A new command to count instances

* Temporary config to count instances

* Undo temporary changes

* Put in the correct number of steps per epoch

* Remove this number from the config because it's almost certainly wrong

* Don't put Fields in Tuples

* Formatting

* More informative error message when batches are heterogeneous

* Formatting

* Not my type

* Generate the fields properly when answers are missing

* Properly discard instances with missing answers

* Changelog

* Update number of steps per epoch

* Adds a config for balanced GQA

* fix file_utils extract with directory

* fix Batch._check_types

* Fill in URL

Co-authored-by: Jackson Stokes <jacksons@Jacksons-MacBook-Pro.local>
Co-authored-by: Akshita Bhagia <akshita23bhagia@gmail.com>
Co-authored-by: Evan Pete Walsh <epwalsh10@gmail.com>
  • Loading branch information
4 people committed Jan 4, 2021
1 parent fbab0bd commit 15d32da
Show file tree
Hide file tree
Showing 17 changed files with 361 additions and 168 deletions.
11 changes: 11 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ dataset at every epoch) and a `MultiTaskScheduler` (for ordering the instances w
- Transformer toolkit to plug and play with modular components of transformer architectures.
- `VisionReader` and `VisionTextModel` base classes added. `VisualEntailment` and `VQA` inherit from these.
- Added reader for the GQA dataset
- Added a config to traing a GQA model
- Added a command to count the number of instances we're going to be training with

### Changed

Expand All @@ -43,6 +45,15 @@ dataset at every epoch) and a `MultiTaskScheduler` (for ordering the instances w
With the `PyTorchDataLoader` this is controlled with the `lazy` parameter, but with
the `MultiProcessDataLoading` this is controlled by the `max_instances_in_memory` setting.
- `ArrayField` is now called `TensorField`, and implemented in terms of torch tensors, not numpy.
- If you are trying to create a heterogeneous batch, you now get a better error message.
- Readers using the new vision features now explicitly log how they are featurizing images.

### Fixed

- The `build-vocab` command no longer crashes when the resulting vocab file is
in the current working directory.
- VQA models now use the `vqa_score` metric for early stopping. This results in
much better scores.


## Unreleased (1.x branch)
Expand Down
1 change: 1 addition & 0 deletions allennlp/commands/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from allennlp.commands.subcommand import Subcommand
from allennlp.commands.test_install import TestInstall
from allennlp.commands.train import Train
from allennlp.commands.count_instances import CountInstances
from allennlp.common.plugins import import_plugins
from allennlp.common.util import import_module_and_submodules

Expand Down
3 changes: 2 additions & 1 deletion allennlp/commands/build_vocab.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ def build_vocab_from_args(args: argparse.Namespace):
raise RuntimeError(f"{args.output_path} already exists. Use --force to overwrite.")

output_directory = os.path.dirname(args.output_path)
os.makedirs(output_directory, exist_ok=True)
if len(output_directory) > 0:
os.makedirs(output_directory, exist_ok=True)

params = Params.from_file(args.param_path)

Expand Down
52 changes: 52 additions & 0 deletions allennlp/commands/count_instances.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""
Subcommand for counting the number of instances from a training config.
"""

import argparse
import logging

from overrides import overrides

from allennlp.commands.subcommand import Subcommand
from allennlp.common.params import Params


logger = logging.getLogger(__name__)


@Subcommand.register("count-instances")
class CountInstances(Subcommand):
@overrides
def add_subparser(self, parser: argparse._SubParsersAction) -> argparse.ArgumentParser:
description = """Count the number of training instances in an experiment config file."""
subparser = parser.add_parser(self.name, description=description, help=description)
subparser.add_argument("param_path", type=str, help="path to an experiment config file")

subparser.add_argument(
"-o",
"--overrides",
type=str,
default="",
help=(
"a json(net) structure used to override the experiment configuration, e.g., "
"'{\"vocabulary.min_count.labels\": 10}'. Nested parameters can be specified either"
" with nested dictionaries or with dot syntax."
),
)

subparser.set_defaults(func=count_instances_from_args)

return subparser


def count_instances_from_args(args: argparse.Namespace):
from allennlp.training.util import data_loaders_from_params

params = Params.from_file(args.param_path)

data_loaders = data_loaders_from_params(params, train=True, validation=False, test=False)
instances = sum(
1 for data_loader in data_loaders.values() for _ in data_loader.iter_instances()
)

print(f"Success! One epoch of training contains {instances} instances.")
6 changes: 5 additions & 1 deletion allennlp/common/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,11 @@ def cached_path(
# Normalize the path.
url_or_filename = os.path.abspath(url_or_filename)

if extract_archive and (is_zipfile(file_path) or tarfile.is_tarfile(file_path)):
if (
extract_archive
and os.path.isfile(file_path)
and (is_zipfile(file_path) or tarfile.is_tarfile(file_path))
):
# We'll use a unique directory within the cache to root to extract the archive to.
# The name of the directoy is a hash of the resource file path and it's modification
# time. That way, if the file changes, we'll know when to extract it again.
Expand Down
26 changes: 19 additions & 7 deletions allennlp/data/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""

import logging
from collections import defaultdict
from collections import defaultdict, Counter
from typing import Dict, Iterable, Iterator, List, Union

import numpy
Expand Down Expand Up @@ -39,12 +39,24 @@ def _check_types(self) -> None:
"""
Check that all the instances have the same types.
"""
all_instance_fields_and_types: List[Dict[str, str]] = [
{k: v.__class__.__name__ for k, v in x.fields.items()} for x in self.instances
]
# Check all the field names and Field types are the same for every instance.
if not all(all_instance_fields_and_types[0] == x for x in all_instance_fields_and_types):
raise ConfigurationError("You cannot construct a Batch with non-homogeneous Instances.")
field_name_to_type_counters: Dict[str, Counter] = defaultdict(lambda: Counter())
field_counts: Counter = Counter()
for instance in self.instances:
for field_name, value in instance.fields.items():
field_name_to_type_counters[field_name][value.__class__.__name__] += 1
field_counts[field_name] += 1
for field_name, type_counters in field_name_to_type_counters.items():
if len(type_counters) > 1:
raise ConfigurationError(
"You cannot construct a Batch with non-homogeneous Instances. "
f"Field '{field_name}' has {len(type_counters)} different types: "
f"{', '.join(type_counters.keys())}"
)
if field_counts[field_name] != len(self.instances):
raise ConfigurationError(
"You cannot construct a Batch with non-homogeneous Instances. "
f"Field '{field_name}' present in some Instances but not others."
)

def get_padding_lengths(self) -> Dict[str, Dict[str, int]]:
"""
Expand Down
90 changes: 59 additions & 31 deletions allennlp/data/dataset_readers/gqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
Union,
Optional,
Tuple,
Iterable,
)
import json
import os
Expand Down Expand Up @@ -111,73 +112,100 @@ def _read(self, split_or_filename: str):
}

filename = splits.get(split_or_filename, split_or_filename)
filename = cached_path(filename, extract_archive=True)

# If we're considering a directory of files (such as train_all)
# loop through each in file in generator
if os.path.isdir(filename):
files = [f"{filename}{file_path}" for file_path in os.listdir(filename)]
files = [os.path.join(filename, file_path) for file_path in os.listdir(filename)]
else:
files = [filename]

# Ensure order is deterministic.
files.sort()

for data_file in files:
with open(cached_path(data_file, extract_archive=True)) as f:
with open(data_file) as f:
questions_with_annotations = json.load(f)

# It would be much easier to just process one image at a time, but it's faster to process
# them in batches. So this code gathers up instances until it has enough to fill up a batch
# that needs processing, and then processes them all.
question_dicts = list(
self.shard_iterable(
questions_with_annotations[q_id] for q_id in questions_with_annotations
)
)

processed_images = self._process_image_paths(
self.images[f"{question_dict['imageId']}.jpg"] for question_dict in question_dicts
)
processed_images: Iterable[Optional[Tuple[Tensor, Tensor]]]
if self.produce_featurized_images:
# It would be much easier to just process one image at a time, but it's faster to process
# them in batches. So this code gathers up instances until it has enough to fill up a batch
# that needs processing, and then processes them all.
filenames = [f"{question_dict['imageId']}.jpg" for question_dict in question_dicts]
try:
processed_images = self._process_image_paths(
self.images[filename] for filename in filenames
)
except KeyError as e:
missing_filename = e.args[0]
raise KeyError(
missing_filename,
f"We could not find an image with the name {missing_filename}. "
"Because of the size of the image datasets, we don't download them automatically. "
"Please download the images from"
"https://nlp.stanford.edu/data/gqa/images.zip, "
"extract them into a directory, and set the image_dir parameter to point to that "
"directory. This dataset reader does not care about the exact directory structure. It "
"finds the images wherever they are.",
)
else:
processed_images = [None] * len(question_dicts)

for question_dict, processed_image in zip(question_dicts, processed_images):
answer = {
"answer": question_dict["answer"],
}
yield self.text_to_instance(question_dict["question"], processed_image, answer)
instance = self.text_to_instance(question_dict["question"], processed_image, answer)
if instance is not None:
yield instance

@overrides
def text_to_instance(
self, # type: ignore
question: str,
image: Union[str, Tuple[Tensor, Tensor]],
answer: Dict[str, str] = None,
image: Optional[Union[str, Tuple[Tensor, Tensor]]],
answer: Optional[Dict[str, str]] = None,
*,
use_cache: bool = True,
) -> Instance:
) -> Optional[Instance]:
from allennlp.data import Field

tokenized_question = self._tokenizer.tokenize(question)
question_field = TextField(tokenized_question, None)
if isinstance(image, str):
features, coords = next(self._process_image_paths([image], use_cache=use_cache))
else:
features, coords = image
fields: Dict[str, Field] = {"question": TextField(tokenized_question, None)}

fields = {
"box_features": ArrayField(features),
"box_coordinates": ArrayField(coords),
"box_mask": ArrayField(
if answer is not None:
labels_fields = []
weights = []
if not self.answer_vocab or answer["answer"] in self.answer_vocab:
labels_fields.append(LabelField(answer["answer"], label_namespace="answers"))
weights.append(1.0)

if len(labels_fields) <= 0:
return None

fields["label_weights"] = ArrayField(torch.tensor(weights))
fields["labels"] = ListField(labels_fields)

if image is not None:
if isinstance(image, str):
features, coords = next(self._process_image_paths([image], use_cache=use_cache))
else:
features, coords = image
fields["box_features"] = ArrayField(features)
fields["box_coordinates"] = ArrayField(coords)
fields["box_mask"] = ArrayField(
features.new_ones((features.shape[0],), dtype=torch.bool),
padding_value=False,
dtype=torch.bool,
),
"question": question_field,
}

if answer:
if not self.answer_vocab or answer["answer"] in self.answer_vocab:
fields["labels"] = ListField(
[LabelField(answer["answer"], label_namespace="answers")]
)
fields["label_weights"] = ArrayField(torch.tensor([1.0]))
)

return Instance(fields)

Expand Down
15 changes: 12 additions & 3 deletions allennlp/data/dataset_readers/vision_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def __init__(
self._coordinates_cache_instance: Optional[MutableMapping[str, Tensor]] = None

# image processors
self.image_loader = None
if image_loader and image_featurizer and region_detector:
if cuda_device is None:
if torch.cuda.device_count() > 0:
Expand All @@ -152,9 +153,17 @@ def __init__(
self._region_detector = None
self.image_processing_batch_size = image_processing_batch_size

self.produce_featurized_images = (
image_loader and image_featurizer and region_detector
) or (self.feature_cache_dir and self.coordinates_cache_dir)
self.produce_featurized_images = False
if self.feature_cache_dir and self.coordinates_cache_dir:
logger.info(f"Featurizing images with a cache at {self.feature_cache_dir}")
self.produce_featurized_images = True
if image_loader and image_featurizer and region_detector:
if self.produce_featurized_images:
logger.info("Falling back to a full image featurization pipeline")
else:
logger.info("Featurizing images with a full image featurization pipeline")
self.produce_featurized_images = True

if self.produce_featurized_images:
if image_dir is None:
if image_loader and image_featurizer and region_detector:
Expand Down
2 changes: 1 addition & 1 deletion allennlp/data/dataset_readers/vqav2.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ class Split(NamedTuple):

processed_images = self._process_image_paths(image_paths)
else:
processed_images = [None for i in range(len(question_dicts))]
processed_images = [None for _ in range(len(question_dicts))]

attempted_instances_count = 0
failed_instances_count = 0
Expand Down
4 changes: 2 additions & 2 deletions allennlp/training/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,8 +433,8 @@ def make_vocab_from_params(
"datasets_for_vocab_creation", None
)
# Do a quick sanity check here. There's no need to load any datasets if the vocab
# type is "empty".
if datasets_for_vocab_creation is None and vocab_params.get("type") in ("empty", "from_files"):
# type is "empty" or "from_files".
if datasets_for_vocab_creation is None and vocab_params.get("type") in {"empty", "from_files"}:
datasets_for_vocab_creation = []

data_loaders: Dict[str, DataLoader]
Expand Down

0 comments on commit 15d32da

Please sign in to comment.