Add VQA2, VisualGenome, FBResNet152 (for pytorch)
Factory

- vqa models, convnets and vqa datasets can be created via factories

VQA 2.0

- VQA2(AbstractVQA) added

VisualGenome

- VisualGenome(AbstractVQADataset) added for merging with VQA datasets
- VisualGenomeImages(AbstractImagesDataset) added to extract features
- `extract.py` can now extract VisualGenome features
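
  For example, as in the updated README below, VisualGenome features can be extracted with:

  ```
  python extract.py --dataset vgenome --dir_data data/vgenome --data_split train
  ```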

Variable features size

- `extract.py` can now extract features from images of size != 448 via the CLI argument `--size`
- FeaturesDataset now takes an optional `opt['size']` parameter
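
  As a hedged illustration (the flag names come from the new `extract.py` argparse options; the 224 value is only an example, the default stays 448):

  ```
  CUDA_VISIBLE_DEVICES=0 python extract.py --dataset coco --dir_data data/coco --data_split train --size 224
  ```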

FBResNet152

- `convnets.py` provides support for external pretrained models (via the pretrained-models.pytorch submodule) as well as the ResNets from torchvision
- in particular, FBResNet152 is a port of the fbresnet152torch model (Torch7) that was used until now
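
  A minimal sketch of the convnet factory, based on the call that appears in the new `extract.py` (passing only `arch` mirrors that call; anything beyond it here is an assumption):

  ```
  import vqa.models.convnets as convnets

  # Build the ported FBResNet152 through the new factory, on GPU with DataParallel,
  # exactly as extract.py does with its --arch argument.
  model = convnets.factory({'arch': 'fbresnet152'}, cuda=True, data_parallel=True)
  ```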
Cadene committed Jul 18, 2017
1 parent 57752d6 commit 42391fd
Showing 23 changed files with 1,070 additions and 126 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
@@ -4,3 +4,6 @@
[submodule "vqa/external/skip-thoughts.torch"]
path = vqa/external/skip-thoughts.torch
url = https://github.com/Cadene/skip-thoughts.torch.git
[submodule "vqa/external/pretrained-models.pytorch"]
path = vqa/external/pretrained-models.pytorch
url = https://github.com/Cadene/pretrained-models.pytorch.git
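
Since this commit adds a new submodule, the external repository can be fetched after checkout with the standard submodule command:

```
git submodule update --init vqa/external/pretrained-models.pytorch
```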
75 changes: 55 additions & 20 deletions README.md
@@ -1,13 +1,20 @@
# Visual Question Answering in pytorch

This repo was made by [Remi Cadene](http://remicadene.com) (LIP6) and [Hedi Ben-Younes](https://twitter.com/labegne) (LIP6-Heuritech), two PhD students working on VQA at [UPMC-LIP6](http://lip6.fr), and their professors [Matthieu Cord](http://webia.lip6.fr/~cord) (LIP6) and [Nicolas Thome](http://webia.lip6.fr/~thomen) (LIP6-CNAM). We developed this code in the context of a research paper called [MUTAN: Multimodal Tucker Fusion for VQA](https://arxiv.org/abs/1705.06676), which is (as far as we know) the current state of the art on the [VQA-1 dataset](http://visualqa.org).
This repo was made by [Remi Cadene](http://remicadene.com) (LIP6) and [Hedi Ben-Younes](https://twitter.com/labegne) (LIP6-Heuritech), two PhD students working on VQA at [UPMC-LIP6](http://lip6.fr), and their professors [Matthieu Cord](http://webia.lip6.fr/~cord) (LIP6) and [Nicolas Thome](http://webia.lip6.fr/~thomen) (LIP6-CNAM). We developed this code in the context of a research paper called [MUTAN: Multimodal Tucker Fusion for VQA](https://arxiv.org/abs/1705.06676), which is (as far as we know) the current state of the art on the [VQA 1.0 dataset](http://visualqa.org).

The goal of this repo is twofold:
- to make it easier to reproduce our results,
- to provide an efficient and modular code base to the community for further research on other VQA datasets.

If you have any questions about our code or model, don't hesitate to contact us or to submit an issue. Pull requests are welcome!

#### News:

- coming soon: pretrained models on VQA2, FBResNet152 features, web app demo
- 18th July 2017: VQA2, VisualGenome, FBResNet152 (for pytorch) added
- 16th July 2017: paper accepted at ICCV2017
- 30th May 2017: poster accepted at CVPR2017 (VQA Workshop)

#### Summary:

* [Introduction](#introduction)
@@ -27,7 +34,10 @@ If you have any questions about our code or model, don't hesitate to contact us
* [Models](#models)
* [Quick examples](#quick-examples)
* [Extract features from COCO](#extract-features-from-coco)
* [Train models on VQA](#train-models-on-vqa)
* [Extract features from VisualGenome](#extract-features-from-visualgenome)
* [Train models on VQA 1.0](#train-models-on-vqa-1-0)
* [Train models on VQA 2.0](#train-models-on-vqa-2-0)
* [Train models on VQA + VisualGenome](#train-models-on-vqa-2-0)
* [Monitor training](#monitor-training)
* [Restart training](#restart-training)
* [Evaluate models on VQA](#evaluate-models-on-vqa)
@@ -108,7 +118,7 @@ Our code has two external dependencies:
Data will be automatically downloaded and preprocessed when needed. Links to data are stored in `vqa/datasets/vqa.py` and `vqa/datasets/coco.py`.


## Reproducing results
## Reproducing results on VQA 1.0

### Features

@@ -173,7 +183,7 @@ To obtain test and testdev results, you will need to zip your result json file (
|
├── train.py # train & eval models
├── eval_res.py # eval results files with OpenEnded metric
├── extract.pt # extract features from coco with CNNs
├── extract.py # extract features from coco with CNNs
└── visu.py # visualize logs and monitor training
```

@@ -189,16 +199,15 @@ You can easily add new options in your custom yaml file if needed. Also, if you w

### Datasets

We currently provide three datasets:
We currently provide four datasets:

- [COCOImages](http://mscoco.org/): currently used to extract features; it comes with three datasets: trainset, valset and testset
- COCOFeatures: used by any VQA dataset
- [VQA](http://www.visualqa.org/vqa_v1_download.html) comes with four datasets: trainset, valset, testset (including test-std and test-dev) and "trainvalset" (concatenation of trainset and valset)
- [VisualGenomeImages](): currently used to extract features; it comes with one split: trainset
- [VQA 1.0](http://www.visualqa.org/vqa_v1_download.html) comes with four datasets: trainset, valset, testset (including test-std and test-dev) and "trainvalset" (concatenation of trainset and valset)
- [VQA 2.0](http://www.visualqa.org): the same splits, but with roughly twice as many questions (and the same images as VQA 1.0)

We plan to add:

- [VisualGenome](http://visualgenome.org/)
- [VQA2](http://www.visualqa.org/)
- [CLEVR](http://cs.stanford.edu/people/jcjohns/clevr/)

### Models
@@ -245,7 +254,16 @@ CUDA_VISIBLE_DEVICES=0 python extract.py
CUDA_VISIBLE_DEVICES=1,2 python extract.py
```

### Train models on VQA
### Extract features from VisualGenome

Same as for COCO above, but only the train split is available:

```
python extract.py --dataset vgenome --dir_data data/vgenome --data_split train
```


### Train models on VQA 1.0

Display the help message and the selected options, then run with the defaults. The needed data will be automatically downloaded and processed using the options in `options/default.yaml`.

@@ -258,19 +276,19 @@ python train.py
Run a MutanNoAtt model with default options.

```
python train.py --path_opt options/vqa/mutan_noatt.yaml --dir_logs logs/vqa/mutan_noatt
python train.py --path_opt options/vqa/mutan_noatt.yaml --dir_logs logs/vqa/mutan_noatt_train
```

Run a MutanAtt model on the trainset and evaluate on the valset after each epoch.

```
python train.py --vqa_trainsplit train --path_opt options/vqa/mutan_att.yaml
python train.py --vqa_trainsplit train --path_opt options/vqa/mutan_att_trainval.yaml
```

Run a MutanAtt model on the trainset and valset (by default) and run through the testset after each epoch (this produces a results file that you can submit to the evaluation server).

```
python train.py --vqa_trainsplit trainval --path_opt options/vqa/mutan_att.yaml
python train.py --vqa_trainsplit trainval --path_opt options/vqa/mutan_att_trainval.yaml
```

### Monitor training
@@ -301,6 +319,22 @@ Create a visualization of multiple experiments to compare them or monitor them l
python visu.py --dir_logs logs/vqa/mutan_noatt,logs/vqa/mutan_att
```

### Train models on VQA 2.0

See options of [vqa2/mutan_att_trainval](https://github.com/Cadene/vqa.pytorch/blob/master/options/vqa2/mutan_att_trainval.yaml):

```
python train.py --path_opt options/vqa2/mutan_att_trainval.yaml
```

### Train models on VQA (1.0 or 2.0) + VisualGenome

See options of [vqa2/mutan_att_trainval_vg](https://github.com/Cadene/vqa.pytorch/blob/master/options/vqa2/mutan_att_trainval_vg.yaml):

```
python train.py --path_opt options/vqa2/mutan_att_trainval_vg.yaml
```

### Restart training

Restart the model from the last checkpoint.
@@ -329,13 +363,14 @@ Please cite the arXiv paper if you use Mutan in your work:

```
@article{benyounescadene2017mutan,
title={MUTAN: Multimodal Tucker Fusion for Visual Question Answering},
author={Hedi Ben-Younes and
R{\'{e}}mi Cad{\`{e}}ne and
Nicolas Thome and
Matthieu Cord}},
journal={arXiv preprint arXiv:1705.06676},
year={2017}
author = {Hedi Ben-Younes and
R{\'{e}}mi Cad{\`{e}}ne and
Nicolas Thome and
Matthieu Cord},
title = {MUTAN: Multimodal Tucker Fusion for Visual Question Answering},
journal = {ICCV},
year = {2017},
url = {http://arxiv.org/abs/1705.06676}
}
```

84 changes: 52 additions & 32 deletions extract.py
@@ -8,60 +8,73 @@
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
from torch.autograd import Variable

import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models

import vqa.datasets.coco as coco
import vqa.models.convnets as convnets
import vqa.datasets as datasets
from vqa.lib.dataloader import DataLoader
from vqa.models.utils import ResNet
from vqa.lib.logger import AvgMeter

model_names = sorted(name for name in models.__dict__
if name.islower() and name.startswith("resnet")
and callable(models.__dict__[name]))

parser = argparse.ArgumentParser(description='Extract')
parser.add_argument('--dir_data', default='data/coco', metavar='DIR',
help='dir dataset: mscoco or visualgenome')
parser.add_argument('--dataset', default='coco',
choices=['coco', 'vgenome'],
help='dataset type: coco (default) | vgenome')
parser.add_argument('--dir_data', default='data/coco',
help='dir dataset to download or/and load images')
parser.add_argument('--data_split', default='train', type=str,
help='Options: (default) train | val | test')
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet152',
choices=model_names,
parser.add_argument('--arch', '-a', default='resnet152',
choices=convnets.model_names,
help='model architecture: ' +
' | '.join(model_names) +
' (default: resnet152)')
parser.add_argument('--workers', default=4, type=int, metavar='N',
help='number of data loading workers (default: 8)')
parser.add_argument('--batch_size', '-b', default=80, type=int, metavar='N',
' | '.join(convnets.model_names) +
' (default: fbresnet152)')
parser.add_argument('--workers', default=4, type=int,
help='number of data loading workers (default: 4)')
parser.add_argument('--batch_size', '-b', default=80, type=int,
help='mini-batch size (default: 80)')
parser.add_argument('--mode', default='both', type=str,
help='Options: att | noatt | (default) both')
parser.add_argument('--size', default=448, type=int,
help='Image size (448 for noatt := avg pooling to get 224) (default:448)')


def main():
global args
args = parser.parse_args()

print("=> using pre-trained model '{}'".format(args.arch))
model = models.__dict__[args.arch](pretrained=True)
model = ResNet(model, False)
model = nn.DataParallel(model).cuda()
model = convnets.factory({'arch':args.arch}, cuda=True, data_parallel=True)

#extract_name = 'arch,{}_layer,{}_resize,{}'.format()
extract_name = 'arch,{}'.format(args.arch)
extract_name = 'arch,{}_size,{}'.format(args.arch, args.size)

#dir_raw = os.path.join(args.dir_data, 'raw')
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])

dataset = coco.COCOImages(args.data_split, dict(dir=args.dir_data),
transform=transforms.Compose([
transforms.Scale(448),
transforms.CenterCrop(448),
transforms.ToTensor(),
normalize,
]))
if args.dataset == 'coco':
if 'coco' not in args.dir_data:
raise ValueError('"coco" string not in dir_data')
dataset = datasets.COCOImages(args.data_split, dict(dir=args.dir_data),
transform=transforms.Compose([
transforms.Scale(args.size),
transforms.CenterCrop(args.size),
transforms.ToTensor(),
normalize,
]))
elif args.dataset == 'vgenome':
if args.data_split != 'train':
raise ValueError('train split is required for vgenome')
if 'vgenome' not in args.dir_data:
raise ValueError('"vgenome" string not in dir_data')
dataset = datasets.VisualGenomeImages(args.data_split, dict(dir=args.dir_data),
transform=transforms.Compose([
transforms.Scale(args.size),
transforms.CenterCrop(args.size),
transforms.ToTensor(),
normalize,
]))

data_loader = DataLoader(dataset,
batch_size=args.batch_size, shuffle=False,
@@ -79,13 +92,19 @@ def extract(data_loader, model, path_file, mode):
path_txt = path_file + '.txt'
hdf5_file = h5py.File(path_hdf5, 'w')

# estimate output shapes
output = model(Variable(torch.ones(1, 3, args.size, args.size),
volatile=True))

nb_images = len(data_loader.dataset)
if mode == 'both' or mode == 'att':
shape_att = (nb_images, 2048, 14, 14)
shape_att = (nb_images, output.size(1), output.size(2), output.size(3))
print('Warning: shape_att={}'.format(shape_att))
hdf5_att = hdf5_file.create_dataset('att', shape_att,
dtype='f')#, compression='gzip')
if mode == 'both' or mode == 'noatt':
shape_noatt = (nb_images, 2048)
shape_noatt = (nb_images, output.size(1))
print('Warning: shape_noatt={}'.format(shape_noatt))
hdf5_noatt = hdf5_file.create_dataset('noatt', shape_noatt,
dtype='f')#, compression='gzip')

@@ -98,7 +117,7 @@ def extract(data_loader, model, path_file, mode):

idx = 0
for i, input in enumerate(data_loader):
input_var = torch.autograd.Variable(input['visual'], volatile=True)
input_var = Variable(input['visual'], volatile=True)
output_att = model(input_var)

nb_regions = output_att.size(2) * output_att.size(3)
@@ -111,6 +130,7 @@ def extract(data_loader, model, path_file, mode):
hdf5_noatt[idx:idx+batch_size] = output_noatt.data.cpu().numpy()
idx += batch_size

torch.cuda.synchronize()
batch_time.update(time.time() - end)
end = time.time()

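For reference, a minimal sketch of reading the extracted features back with h5py; the path below is only a placeholder (extract.py builds the real file name from the arch and size options), and which keys exist depends on the `--mode` used:

```
import h5py

# Placeholder path -- the actual name is produced by extract.py.
path_hdf5 = 'data/coco/extract/arch,fbresnet152_size,448/trainset.hdf5'

with h5py.File(path_hdf5, 'r') as f:
    att = f['att']      # attention features, e.g. (nb_images, 2048, 14, 14) for fbresnet152 at size 448
    noatt = f['noatt']  # average-pooled features, e.g. (nb_images, 2048)
    print(att.shape, noatt.shape)
```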
40 changes: 40 additions & 0 deletions options/vqa2/default.yaml
@@ -0,0 +1,40 @@
logs:
    dir_logs: logs/vqa2/default
vqa:
    dataset: VQA2
    dir: data/vqa2
    trainsplit: train
    nans: 2000
    maxlength: 26
    minwcount: 0
    nlp: mcb
    pad: right
    samplingans: True
coco:
    dir: data/coco
    arch: fbresnet152
    mode: noatt
    size: 448
model:
    arch: MLBNoAtt
    seq2vec:
        arch: skipthoughts
        dir_st: data/skip-thoughts
        type: UniSkip
        dropout: 0.25
        fixed_emb: False
    fusion:
        dim_v: 2048
        dim_q: 2400
        dim_h: 1200
        dropout_v: 0.5
        dropout_q: 0.5
        activation_v: tanh
        activation_q: tanh
    classif:
        activation: tanh
        dropout: 0.5
optim:
    lr: 0.0001
    batch_size: 512
    epochs: 100
49 changes: 49 additions & 0 deletions options/vqa2/mlb_att_trainval.yaml
@@ -0,0 +1,49 @@
logs:
    dir_logs: logs/vqa2/mlb_att_trainval
vqa:
    dataset: VQA2
    dir: data/vqa2
    trainsplit: trainval
    nans: 2000
    maxlength: 26
    minwcount: 0
    nlp: mcb
    pad: right
    samplingans: True
coco:
    dir: data/coco
    arch: fbresnet152
    mode: att
    size: 448
model:
    arch: MLBAtt
    dim_v: 2048
    dim_q: 2400
    seq2vec:
        arch: skipthoughts
        dir_st: data/skip-thoughts
        type: BayesianUniSkip
        dropout: 0.25
        fixed_emb: False
    attention:
        nb_glimpses: 4
        dim_h: 1200
        dropout_v: 0.5
        dropout_q: 0.5
        dropout_mm: 0.5
        activation_v: tanh
        activation_q: tanh
        activation_mm: tanh
    fusion:
        dim_h: 1200
        dropout_v: 0.5
        dropout_q: 0.5
        activation_v: tanh
        activation_q: tanh
    classif:
        activation: tanh
        dropout: 0.5
optim:
    lr: 0.0001
    batch_size: 128
    epochs: 100
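
Assuming these new option files are consumed the same way as the mutan variants shown in the README, training would presumably be launched with:

```
python train.py --path_opt options/vqa2/mlb_att_trainval.yaml
```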