This repository has been archived by the owner on Dec 16, 2022. It is now read-only.
/
visual_entailment.py
112 lines (91 loc) · 4.13 KB
/
visual_entailment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import logging
from typing import (
Dict,
List,
Union,
Optional,
Tuple,
)
from overrides import overrides
import torch
from torch import Tensor
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.fields import Field, ArrayField, LabelField, TextField
from allennlp.data.instance import Instance
from allennlp.common.file_utils import json_lines_from_file
from allennlp_models.vision.dataset_readers.vision_reader import VisionReader
logger = logging.getLogger(__name__)
@DatasetReader.register("visual-entailment")
class VisualEntailmentReader(VisionReader):
    """
    The dataset reader for visual entailment (SNLI-VE).

    Reads SNLI-VE jsonl files (optionally gzipped), pairs each hypothesis with the
    featurized Flickr30K image named by its ``Flickr30K_ID``, and yields instances
    with ``hypothesis``, image box fields, and (when present) a ``labels`` field.
    """

    @overrides
    def _read(self, file_path: str):
        # Allow the shorthand split names "dev" / "test" / "train"; anything else
        # is treated as a literal path or URL.
        split_prefix = "https://storage.googleapis.com/allennlp-public-data/snli-ve/"
        splits = {
            "dev": split_prefix + "snli_ve_dev.jsonl.gz",
            "test": split_prefix + "snli_ve_test.jsonl.gz",
            "train": split_prefix + "snli_ve_train.jsonl.gz",
        }
        file_path = splits.get(file_path, file_path)
        lines = json_lines_from_file(file_path)
        info_dicts: List[Dict] = list(self.shard_iterable(lines))  # type: ignore

        if self.produce_featurized_images:
            # It would be much easier to just process one image at a time, but it's
            # faster to process them in batches. So this code gathers up instances
            # until it has enough to fill up a batch that needs processing, and then
            # processes them all.
            filenames = [info_dict["Flickr30K_ID"] + ".jpg" for info_dict in info_dicts]
            try:
                processed_images = self._process_image_paths(
                    [self.images[filename] for filename in filenames]
                )
            except KeyError as e:
                missing_filename = e.args[0]
                raise KeyError(
                    missing_filename,
                    f"We could not find an image with the name {missing_filename}. "
                    "Because of the size of the image datasets, we don't download them automatically. "
                    "Please download the images from "
                    "https://storage.googleapis.com/allennlp-public-data/snli-ve/flickr30k_images.tar.gz, "
                    "extract them into a directory, and set the image_dir parameter to point to that "
                    "directory. This dataset reader does not care about the exact directory structure. It "
                    "finds the images wherever they are.",
                )
        else:
            # No image features requested: pair every info dict with a None image.
            processed_images = [None for _ in range(len(info_dicts))]  # type: ignore

        for info_dict, processed_image in zip(info_dicts, processed_images):
            hypothesis = info_dict["sentence2"]
            answer = info_dict["gold_label"]
            instance = self.text_to_instance(processed_image, hypothesis, answer)
            yield instance

    @overrides
    def text_to_instance(
        self,  # type: ignore
        image: Union[str, Tuple[Tensor, Tensor]],
        hypothesis: str,
        label: Optional[str] = None,
        *,
        use_cache: bool = True,
    ) -> Instance:
        """
        Build an ``Instance`` from a hypothesis string and an image.

        ``image`` may be a path (featurized on demand, honoring ``use_cache``) or an
        already-computed ``(features, coordinates)`` tensor pair; it may also be None,
        in which case no box fields are added.
        """
        tokenized_hypothesis = self._tokenizer.tokenize(hypothesis)
        # Token indexers are attached later, in apply_token_indexers().
        hypothesis_field = TextField(tokenized_hypothesis, None)

        fields: Dict[str, Field] = {"hypothesis": hypothesis_field}

        if image is not None:
            if isinstance(image, str):
                features, coords = next(self._process_image_paths([image], use_cache=use_cache))
            else:
                features, coords = image

            fields["box_features"] = ArrayField(features)
            fields["box_coordinates"] = ArrayField(coords)
            # One mask entry per box; all boxes are real (padding is added later).
            fields["box_mask"] = ArrayField(
                features.new_ones((features.shape[0],), dtype=torch.bool),
                padding_value=False,
                dtype=torch.bool,
            )

        if label:
            fields["labels"] = LabelField(label)

        return Instance(fields)

    @overrides
    def apply_token_indexers(self, instance: Instance) -> None:
        # Deferred indexer attachment keeps instances picklable across workers.
        instance["hypothesis"].token_indexers = self._token_indexers  # type: ignore