This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Make images easier to find for Visual Entailment #4878

Merged 73 commits on Dec 21, 2020
73 commits
f1ed56a
implement TorchImageLoader
epwalsh Nov 25, 2020
2ff1488
implement ResnetBackbone
epwalsh Nov 26, 2020
17df0df
add resize + normalize to image loader
epwalsh Nov 30, 2020
2395fec
finalize FasterRcnnRegionDetector
epwalsh Dec 1, 2020
cdf5090
pin torchvision
epwalsh Dec 1, 2020
850735c
fix VQAv2Reader
epwalsh Dec 1, 2020
037da33
add box mask field
epwalsh Dec 2, 2020
12979cc
fix merge conflicts
epwalsh Dec 2, 2020
99e5928
dataset reader fixes
epwalsh Dec 2, 2020
38499d2
fix model tests
epwalsh Dec 2, 2020
a2de991
doc fixes
epwalsh Dec 2, 2020
cbd5800
fix merge conflicts
epwalsh Dec 3, 2020
ae4c839
add threshold parameters to FasterRcnnRegionDetector
epwalsh Dec 3, 2020
1252879
address @dirkgr comments
epwalsh Dec 4, 2020
d5bd779
mask fixes
epwalsh Dec 4, 2020
f3df7a2
shape comments
epwalsh Dec 4, 2020
877b7b5
add some more comments
epwalsh Dec 4, 2020
c491536
cache answers_by_question_id
epwalsh Dec 5, 2020
7007fb4
implement LocalCacheResource
epwalsh Dec 5, 2020
38c32ee
fix
epwalsh Dec 5, 2020
4b33a63
Merge branch 'vision' into torchvision
epwalsh Dec 5, 2020
3376d08
add read-only option to cache
epwalsh Dec 6, 2020
415b6ef
fix
epwalsh Dec 6, 2020
a603611
simplify data loader
epwalsh Dec 7, 2020
797e95e
make featurizer and detector optional in readers
epwalsh Dec 7, 2020
a4baae6
Cache in memory
dirkgr Dec 8, 2020
18ca15f
back pressure is important I guess
epwalsh Dec 9, 2020
625f6f3
Merge branch 'torchvision' of github.com:allenai/allennlp into torchv…
epwalsh Dec 9, 2020
0aa3331
merge
epwalsh Dec 9, 2020
273b453
Updated configs
dirkgr Dec 9, 2020
0a206f3
Fixes the way we apply masks
dirkgr Dec 9, 2020
f806eb7
Merge branch 'torchvision' of https://github.com/allenai/allennlp int…
dirkgr Dec 9, 2020
2f5ebf1
Use more of Jiasen's real settings
dirkgr Dec 9, 2020
48c7620
Upgrade the from_huggingface config
dirkgr Dec 9, 2020
f031d5e
Switch back to the images on corpnet
dirkgr Dec 9, 2020
23c3986
Fix random seeds
dirkgr Dec 9, 2020
8dde832
Bigger model needs smaller batch size
dirkgr Dec 9, 2020
8266b83
Adds ability to selectively ignore one input
dirkgr Dec 9, 2020
0f8faac
address some comments
epwalsh Dec 14, 2020
83292d7
Merge branch 'vision' into torchvision
epwalsh Dec 14, 2020
dd4dff8
format + lint
epwalsh Dec 14, 2020
de7f706
fixes
epwalsh Dec 14, 2020
945896f
Bring back bert-base configs
dirkgr Dec 15, 2020
89b98e3
Merge branch 'torchvision' of https://github.com/allenai/allennlp int…
dirkgr Dec 15, 2020
e598453
Merge branch 'vision' into torchvision
dirkgr Dec 15, 2020
c06b0b3
Merge branch 'vision' into torchvision
dirkgr Dec 15, 2020
2303d6e
Merge branch 'vision' into torchvision
dirkgr Dec 15, 2020
55b3c4f
Merge branch 'vision' into torchvision
dirkgr Dec 16, 2020
7af271f
fix error handling
epwalsh Dec 16, 2020
1bf7fc7
fix test
epwalsh Dec 16, 2020
962c965
Merge branch 'vision' into torchvision
epwalsh Dec 16, 2020
490b909
Adds the ability to read from a feature cache, but not run any featur…
dirkgr Dec 17, 2020
edf57f4
Update tests
dirkgr Dec 17, 2020
a9bc3da
Let's stick with "feature_cache"
dirkgr Dec 17, 2020
a5e958d
More epochs, more random
dirkgr Dec 17, 2020
12ba53c
Use the new parameters
dirkgr Dec 17, 2020
14f8a0b
Fix initialization
dirkgr Dec 17, 2020
87647a5
Merge remote-tracking branch 'origin/vision' into NewFeatures
dirkgr Dec 17, 2020
5e9b393
Make tests work, add some documentation
dirkgr Dec 18, 2020
732dab2
Remove the read_from_cache parameter
dirkgr Dec 18, 2020
c84be68
Cleanup of training configs
dirkgr Dec 18, 2020
dbfa73b
Typecheck
dirkgr Dec 18, 2020
ec41a70
Building docs right
dirkgr Dec 18, 2020
87e4087
Better settings for VQA
dirkgr Dec 19, 2020
9aadaa4
Merge branch 'NewFeatures' of https://github.com/allenai/allennlp int…
dirkgr Dec 19, 2020
2565490
Open cached paths when reading json lines
dirkgr Dec 19, 2020
4d16e3f
By default, autodetect GPUs when training
dirkgr Dec 19, 2020
e1f31b7
Switch to torchvision
dirkgr Dec 19, 2020
dab8370
Download training data from the web
dirkgr Dec 19, 2020
33f9723
Merge remote-tracking branch 'origin/vision' into VEPaths
dirkgr Dec 21, 2020
31e866d
This needs to stay at 1024 until we get the new featurization model
dirkgr Dec 21, 2020
4dd7a8f
Have a more descriptive error message when images are missing
dirkgr Dec 21, 2020
4e9c8ea
Update vilbert_ve_from_huggingface.jsonnet
AkshitaB Dec 21, 2020
2 changes: 1 addition & 1 deletion allennlp/common/file_utils.py
@@ -891,7 +891,7 @@ def open_compressed(
         import bz2

         open_fn = bz2.open
-    return open_fn(filename, mode=mode, encoding=encoding, **kwargs)
+    return open_fn(cached_path(filename), mode=mode, encoding=encoding, **kwargs)


 def text_lines_from_file(filename: Union[str, PathLike], strip_lines: bool = True) -> Iterator[str]:
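
The one-line change above resolves the file name to a local path before handing it to the opener, which is what lets the training config point at `.jsonl.gz` URLs. A minimal sketch of that pattern follows. `resolve_to_local_path` is a hypothetical stand-in for `cached_path` (it only accepts paths that already exist and does no downloading), and `open_compressed_sketch` is an illustrative name, not the library function.

```python
import bz2
import gzip
import os
import tempfile
from typing import Callable, List


def resolve_to_local_path(filename: str) -> str:
    """Hypothetical stand-in for cached_path.

    A real implementation would download URLs into a local cache and return
    the cached location. This sketch only accepts paths that already exist.
    """
    if not os.path.exists(filename):
        raise FileNotFoundError(filename)
    return filename


def open_compressed_sketch(filename: str, mode: str = "rt", encoding: str = "UTF-8", **kwargs):
    # Pick the opener from the ORIGINAL name, because a cached copy of a URL
    # may not keep the ".gz" or ".bz2" extension.
    open_fn: Callable = open
    if filename.endswith(".gz"):
        open_fn = gzip.open
    elif filename.endswith(".bz2"):
        open_fn = bz2.open
    # Open the RESOLVED local path, as in the changed return statement.
    return open_fn(resolve_to_local_path(filename), mode=mode, encoding=encoding, **kwargs)


def demo_round_trip() -> List[str]:
    """Write a small gzipped JSON-lines file and read it back through the sketch."""
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "example.jsonl.gz")
        with gzip.open(path, "wt", encoding="UTF-8") as f:
            f.write('{"a": 1}\n{"a": 2}\n')
        with open_compressed_sketch(path) as f:
            return [line.strip() for line in f]
```

Calling `demo_round_trip()` reads the two JSON lines back as text. The extension check and the open call deliberately use different names, which is the point of the change.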
21 changes: 18 additions & 3 deletions allennlp/data/dataset_readers/visual_entailment.py
@@ -37,9 +37,24 @@ def _read(self, file_path: str):
             # It would be much easier to just process one image at a time, but it's faster to process
             # them in batches. So this code gathers up instances until it has enough to fill up a batch
             # that needs processing, and then processes them all.
-            processed_images = self._process_image_paths(
-                [self.images[info_dict["Flickr30K_ID"] + ".jpg"] for info_dict in info_dicts]
-            )
+            filenames = [info_dict["Flickr30K_ID"] + ".jpg" for info_dict in info_dicts]
+
+            try:
+                processed_images = self._process_image_paths(
+                    [self.images[filename] for filename in filenames]
+                )
+            except KeyError as e:
+                missing_filename = e.args[0]
+                raise KeyError(
+                    missing_filename,
+                    f"We could not find an image with the name {missing_filename}. "
+                    "Because of the size of the image datasets, we don't download them automatically. "
+                    "Please download the images from "
+                    "https://storage.googleapis.com/allennlp-public-data/snli-ve/flickr30k_images.tar.gz, "
+                    "extract them into a directory, and set the image_dir parameter to point to that "
+                    "directory. This dataset reader does not care about the exact directory structure. It "
+                    "finds the images wherever they are.",
+                )
         else:
             processed_images = [None for i in range(len(info_dicts))]  # type: ignore

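
The error message says the reader finds images wherever they sit under `image_dir`. The diff does not show how `self.images` is built, so the following is only a sketch under that assumption: it indexes images by bare file name with a recursive walk, then re-raises a lookup failure with a descriptive message. `index_images` and `lookup_images` are hypothetical names used for illustration.

```python
import os
import tempfile
from typing import Dict, List


def index_images(image_dir: str) -> Dict[str, str]:
    """Map bare file names to full paths, ignoring the directory layout."""
    index: Dict[str, str] = {}
    for root, _dirs, files in os.walk(image_dir):
        for name in files:
            if name.lower().endswith(".jpg"):
                index[name] = os.path.join(root, name)
    return index


def lookup_images(index: Dict[str, str], filenames: List[str]) -> List[str]:
    """Resolve file names to paths, raising a descriptive KeyError on a miss."""
    try:
        return [index[name] for name in filenames]
    except KeyError as e:
        missing = e.args[0]
        raise KeyError(
            missing,
            f"We could not find an image with the name {missing}. "
            "Download the images, extract them into a directory, "
            "and set image_dir to point to that directory.",
        ) from e


def demo_lookup() -> bool:
    """Create a nested directory with one image and look it up by bare name."""
    with tempfile.TemporaryDirectory() as tmp:
        nested = os.path.join(tmp, "some", "nested", "folder")
        os.makedirs(nested)
        target = os.path.join(nested, "12345.jpg")
        with open(target, "wb") as f:
            f.write(b"not a real image")
        index = index_images(tmp)
        return lookup_images(index, ["12345.jpg"]) == [target]
```

Keeping the missing file name as the first `KeyError` argument mirrors the diff, so callers that inspect `e.args[0]` still get the key.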
22 changes: 11 additions & 11 deletions training_configs/vilbert_ve_from_huggingface.jsonnet
@@ -1,7 +1,7 @@
 local model_name = "bert-base-uncased";
 local effective_batch_size = 128;
 local gpu_batch_size = 32;
-local num_gpus = 4;
+local num_gpus = 0;

 local datadir = "/net/s3/allennlp/akshitab/data/SNLI-VE/data/";

@@ -10,7 +10,7 @@ local datadir = "/net/s3/allennlp/akshitab/data/SNLI-VE/data/";
         "type": "visual-entailment",
         "image_dir": datadir + "Flickr30K/flickr30k_images",
         "feature_cache_dir": datadir + "/feature_cache",
-        "image_loader": "detectron",
+        "image_loader": "torch",
         "image_featurizer": "resnet_backbone",
         "region_detector": "faster_rcnn",
         "tokenizer": {
@@ -23,16 +23,15 @@ local datadir = "/net/s3/allennlp/akshitab/data/SNLI-VE/data/";
                 "model_name": model_name
             }
         },
-        "max_instances": 30000,
-        "image_processing_batch_size": 16,
     },
     "validation_dataset_reader": self.dataset_reader,
-    "train_data_path": datadir + "snli_ve_train.jsonl",
-    "validation_data_path": datadir + "snli_ve_dev.jsonl",
+    "train_data_path": "https://storage.googleapis.com/allennlp-public-data/snli-ve/snli_ve_train.jsonl.gz",
+    "validation_data_path": "https://storage.googleapis.com/allennlp-public-data/snli-ve/snli_ve_dev.jsonl.gz",
+    "test_data_path": "https://storage.googleapis.com/allennlp-public-data/snli-ve/snli_ve_test.jsonl.gz",
     "model": {
         "type": "ve_vilbert_from_huggingface",
         "model_name": model_name,
-        "image_feature_dim": 2048,
+        "image_feature_dim": 1024,
         "image_hidden_size": 1024,
         "image_num_attention_heads": 8,
         "image_num_hidden_layers": 6,
@@ -60,14 +59,15 @@ local datadir = "/net/s3/allennlp/akshitab/data/SNLI-VE/data/";
     "trainer": {
         "optimizer": {
             "type": "huggingface_adamw",
-            "lr": 4e-5
+            "lr": 4e-5,
+            "weight_decay": 0.01
         },
         "learning_rate_scheduler": {
             "type": "linear_with_warmup",
-            "warmup_steps": 2000,
-            "num_steps_per_epoch": std.ceil(30000 / $["data_loader"]["batch_size"] / $["trainer"]["num_gradient_accumulation_steps"])
+            "num_steps_per_epoch": std.ceil(529527 / $["data_loader"]["batch_size"] / $["trainer"]["num_gradient_accumulation_steps"]),
+            "warmup_steps": std.ceil(self.num_steps_per_epoch / 2),
         },
-        "validation_metric": "+f1",
+        "validation_metric": "+fscore",
         "num_epochs": 20,
         "num_gradient_accumulation_steps": effective_batch_size / gpu_batch_size / std.max(1, num_gpus)
     },
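
The scheduler and accumulation settings in this config are derived from each other, so it may help to see the arithmetic spelled out. The sketch below mirrors the Jsonnet expressions in Python. It assumes `data_loader.batch_size` equals `gpu_batch_size`, which this diff does not show, and `trainer_schedule` is a hypothetical helper name.

```python
import math
from typing import Dict


def trainer_schedule(
    num_train_instances: int,
    effective_batch_size: int,
    gpu_batch_size: int,
    num_gpus: int,
) -> Dict[str, float]:
    # Mirrors: effective_batch_size / gpu_batch_size / std.max(1, num_gpus)
    accumulation_steps = effective_batch_size / gpu_batch_size / max(1, num_gpus)
    # Assumption: data_loader.batch_size == gpu_batch_size (not visible in the diff).
    steps_per_epoch = math.ceil(num_train_instances / gpu_batch_size / accumulation_steps)
    # Mirrors: std.ceil(self.num_steps_per_epoch / 2), i.e. warm up for half an epoch.
    warmup_steps = math.ceil(steps_per_epoch / 2)
    return {
        "num_gradient_accumulation_steps": accumulation_steps,
        "num_steps_per_epoch": steps_per_epoch,
        "warmup_steps": warmup_steps,
    }


# Values from the config after this change: 529527 instances, num_gpus = 0.
cpu_schedule = trainer_schedule(529527, 128, 32, 0)
# Values with the previous num_gpus = 4, for comparison.
four_gpu_schedule = trainer_schedule(529527, 128, 32, 4)
```

With `num_gpus = 0`, the `std.max(1, num_gpus)` guard avoids a division by zero and yields 4 accumulation steps, 4137 steps per epoch, and 2069 warmup steps under the stated assumption. With `num_gpus = 4`, the accumulation factor drops to 1.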