Merge pull request #66 from ParadoxZW/master
MMNasNet for VQA-v2 is supported.
MIL-VLG committed Aug 13, 2020
2 parents 1d4ead9 + 7c93d72 commit b4654be
Showing 13 changed files with 588 additions and 10 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -5,7 +5,7 @@
<a href="https://github.com/MILVLG"><img alt="powered-by MILVLG" src="https://img.shields.io/badge/powered%20by-MILVLG-orange.svg?style=flat&amp;colorA=E1523D&amp;colorB=007D8A"/></a>
</div>

OpenVQA is a general platform for visual question answering (VQA) research, implementing state-of-the-art approaches (e.g., [BUTD](https://arxiv.org/abs/1707.07998), [MFH](https://arxiv.org/abs/1708.03619), [BAN](https://arxiv.org/abs/1805.07932) and [MCAN](https://arxiv.org/abs/1906.10770)) on different benchmark datasets such as [VQA-v2](https://visualqa.org/), [GQA](https://cs.stanford.edu/people/dorarad/gqa/index.html) and [CLEVR](https://cs.stanford.edu/people/jcjohns/clevr/). Support for more methods and datasets will be added continuously.
OpenVQA is a general platform for visual question answering (VQA) research, implementing state-of-the-art approaches (e.g., [BUTD](https://arxiv.org/abs/1707.07998), [MFH](https://arxiv.org/abs/1708.03619), [BAN](https://arxiv.org/abs/1805.07932), [MCAN](https://arxiv.org/abs/1906.10770) and [MMNasNet](https://arxiv.org/pdf/2004.12070.pdf)) on different benchmark datasets such as [VQA-v2](https://visualqa.org/), [GQA](https://cs.stanford.edu/people/dorarad/gqa/index.html) and [CLEVR](https://cs.stanford.edu/people/jcjohns/clevr/). Support for more methods and datasets will be added continuously.



@@ -30,6 +30,7 @@ Results and models are available in [MODEL ZOO](https://openvqa.readthedocs.io/e
| [MFH](https://arxiv.org/abs/1708.03619) | ✓ | | |
| [BAN](https://arxiv.org/abs/1805.07932) | ✓ | ✓ | |
| [MCAN](https://arxiv.org/abs/1906.10770) | ✓ | ✓ | ✓ |
| [MMNasNet](https://arxiv.org/pdf/2004.12070.pdf) | ✓ | | |

## News & Updates

28 changes: 28 additions & 0 deletions configs/vqa/mmnasnet_large.yml
@@ -0,0 +1,28 @@
# Network
MODEL_USE: mmnasnet
ARCH: {
enc: [SA, SA, SA, SA, FFN, FFN, FFN, FFN, SA, FFN, FFN, FFN],
dec: [GA, GA, FFN, FFN, GA, FFN, RSA, GA, FFN, GA, RSA, FFN, RSA, SA, FFN, RSA, GA, FFN]
}
HIDDEN_SIZE: 1024
REL_HBASE: 128
REL_SIZE: 64
MULTI_HEAD: 8
DROPOUT_R: 0.1
FLAT_MLP_SIZE: 1024
FLAT_GLIMPSES: 1
FLAT_OUT_SIZE: 2048

# Execution
BATCH_SIZE: 64
LR_BASE: 0.00007 # 5e-5 for train+val+vg->test
LR_DECAY_R: 0.2
LR_DECAY_LIST: [10, 12]
WARMUP_EPOCH: 3
MAX_EPOCH: 13
GRAD_NORM_CLIP: 1.0
GRAD_ACCU_STEPS: 1
LOSS_FUNC: bce
LOSS_REDUCTION: sum
OPT: Adam
OPT_PARAMS: {betas: '(0.9, 0.98)', eps: '1e-9'}
28 changes: 28 additions & 0 deletions configs/vqa/mmnasnet_small.yml
@@ -0,0 +1,28 @@
# Network
MODEL_USE: mmnasnet
ARCH: {
enc: [SA, SA, SA, SA, FFN, FFN, FFN, FFN, SA, FFN, FFN, FFN],
dec: [GA, GA, FFN, FFN, GA, FFN, RSA, GA, FFN, GA, RSA, FFN, RSA, SA, FFN, RSA, GA, FFN]
}
HIDDEN_SIZE: 512
REL_HBASE: 64
REL_SIZE: 64
MULTI_HEAD: 8
DROPOUT_R: 0.1
FLAT_MLP_SIZE: 512
FLAT_GLIMPSES: 1
FLAT_OUT_SIZE: 1024

# Execution
BATCH_SIZE: 64
LR_BASE: 0.00012 # 1e-4 for train+val+vg->test
LR_DECAY_R: 0.2
LR_DECAY_LIST: [10, 12]
WARMUP_EPOCH: 3
MAX_EPOCH: 13
GRAD_NORM_CLIP: 1.0
GRAD_ACCU_STEPS: 1
LOSS_FUNC: bce
LOSS_REDUCTION: sum
OPT: Adam
OPT_PARAMS: {betas: '(0.9, 0.98)', eps: '1e-9'}
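
Both configs share the same searched encoder/decoder operator sequences and differ only in width (HIDDEN_SIZE, REL_HBASE, FLAT_* sizes) and base learning rate. A quick way to inspect what the ARCH field encodes — a sketch assuming PyYAML is installed and the script is run from the repo root so the paths above resolve:

```python
# Sketch: tally the searched operators in the two MMNasNet configs.
from collections import Counter
import yaml

for path in ('configs/vqa/mmnasnet_small.yml', 'configs/vqa/mmnasnet_large.yml'):
    with open(path) as f:
        cfg = yaml.safe_load(f)
    print(path, '| HIDDEN_SIZE:', cfg['HIDDEN_SIZE'], '| LR_BASE:', cfg['LR_BASE'])
    print('  enc ops:', dict(Counter(cfg['ARCH']['enc'])))  # {'SA': 5, 'FFN': 7}
    print('  dec ops:', dict(Counter(cfg['ARCH']['dec'])))  # {'GA': 6, 'FFN': 7, 'RSA': 4, 'SA': 1}
```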
4 changes: 4 additions & 0 deletions docs/_source/basic/model_zoo.md
@@ -33,6 +33,8 @@ We provide three groups of results (including the accuracies of *Overall*, *Yes/
| [BAN-8](https://github.com/MILVLG/openvqa/tree/master/configs/vqa/ban_8.yml) | 2e-3 | 66.00 | 83.61 | 47.04 | 57.62 |
| [MCAN-small](https://github.com/MILVLG/openvqa/tree/master/configs/vqa/mcan_small.yml) | 1e-4 | 67.17 | 84.82 | 49.31 | 58.48 |
| [MCAN-large](https://github.com/MILVLG/openvqa/tree/master/configs/vqa/mcan_large.yml) | 7e-5 | 67.50 | 85.14 | 49.66 | 58.80 |
| [MMNasNet-small](https://github.com/MILVLG/openvqa/tree/master/configs/vqa/mmnasnet_small.yml) | 1.2e-4 | 67.79 | 85.02 | 52.25 | 58.80 |
| [MMNasNet-large](https://github.com/MILVLG/openvqa/tree/master/configs/vqa/mmnasnet_large.yml) | 7e-5 | 69.09 | 85.22 | 52.04 | 59.09 |

#### Train+val -> Test-dev

@@ -45,6 +47,8 @@ We provide three groups of results (including the accuracies of *Overall*, *Yes/
| [BAN-8](https://github.com/MILVLG/openvqa/tree/master/configs/vqa/ban_8.yml) | 1.4e-3 | 69.07 | 85.2 | 49.63 | 59.71 | [model](https://awma1-my.sharepoint.com/:u:/g/personal/yuz_l0_tn/EbJgyL7FPTFAqzMm3HB1xDIBjXpWygOoXrdnDZKEIu34rg?e=kxCVue) |
| [MCAN-small](https://github.com/MILVLG/openvqa/tree/master/configs/vqa/mcan_small.yml) | 1e-4 | 70.33 | 86.77 | 52.14 | 60.40 | [model](https://awma1-my.sharepoint.com/:u:/g/personal/yuz_l0_tn/EcFeQCi_9MVBn6MeESly8OYBZCeBEuaPQqZjT-oXidgKKg?e=5dGjUt) |
| [MCAN-large](https://github.com/MILVLG/openvqa/tree/master/configs/vqa/mcan_large.yml) | 5e-5 | 70.48 | 86.90 | 52.11 | 60.63 | [model](https://awma1-my.sharepoint.com/:u:/g/personal/yuz_l0_tn/Ee6HdFN_FcZAsQEm85WesHgBZBkY8dZ-278dDYG_ty_IwA?e=WK4SX4) |
| [MMNasNet-small](https://github.com/MILVLG/openvqa/tree/master/configs/vqa/mmnasnet_small.yml) | 1e-4 | 71.24 | 87.11 | 56.15 | 61.08 | [model](https://awma1-my.sharepoint.com/:u:/g/personal/yuz_l0_tn/EaUf4tRcw0FPghbwRoVcMo8BQT9SWzgiZBpD2CrFRfS54w?e=mthO4l) |
| [MMNasNet-large](https://github.com/MILVLG/openvqa/tree/master/configs/vqa/mmnasnet_large.yml) | 5e-5 | 71.45 | 87.29 | 55.71 | 61.45 | [model](https://awma1-my.sharepoint.com/:u:/g/personal/yuz_l0_tn/EQwNsq0AVehGqhWS4iwuWsYBPtP78xEqRgFKuRGKodkQWA?e=ZVsBVO) |

#### Train+val+vg -> Test-dev

3 changes: 3 additions & 0 deletions openvqa/core/base_cfgs.py
@@ -93,6 +93,9 @@ def __init__(self):
},
}

        # Whether bbox_feat should be normalized by the image size (default: False)
self.BBOX_NORMALIZE = False

# Default training batch size: 64
self.BATCH_SIZE = 64

19 changes: 11 additions & 8 deletions openvqa/datasets/vqa/vqa_loader.py
@@ -202,8 +202,9 @@ def load_img_feats(self, idx, iid):
),
img_feat_pad_size=self.__C.FEAT_SIZE['vqa']['BBOX_FEAT_SIZE'][0]
)
grid_feat_iter = np.zeros(1)

return frcn_feat_iter, np.zeros(1), bbox_feat_iter
return frcn_feat_iter, grid_feat_iter, bbox_feat_iter



@@ -226,15 +227,17 @@ def proc_img_feat(self, img_feat, img_feat_pad_size):


def proc_bbox_feat(self, bbox, img_shape):
bbox_feat = np.zeros((bbox.shape[0], 5), dtype=np.float32)
if self.__C.BBOX_NORMALIZE:
bbox_nm = np.zeros((bbox.shape[0], 4), dtype=np.float32)

bbox_feat[:, 0] = bbox[:, 0] / float(img_shape[1])
bbox_feat[:, 1] = bbox[:, 1] / float(img_shape[0])
bbox_feat[:, 2] = bbox[:, 2] / float(img_shape[1])
bbox_feat[:, 3] = bbox[:, 3] / float(img_shape[0])
bbox_feat[:, 4] = (bbox[:, 2] - bbox[:, 0]) * (bbox[:, 3] - bbox[:, 1]) / float(img_shape[0] * img_shape[1])
bbox_nm[:, 0] = bbox[:, 0] / float(img_shape[1])
bbox_nm[:, 1] = bbox[:, 1] / float(img_shape[0])
bbox_nm[:, 2] = bbox[:, 2] / float(img_shape[1])
bbox_nm[:, 3] = bbox[:, 3] / float(img_shape[0])
return bbox_nm
# bbox_feat[:, 4] = (bbox[:, 2] - bbox[:, 0]) * (bbox[:, 3] - bbox[:, 1]) / float(img_shape[0] * img_shape[1])

return bbox_feat
return bbox


def proc_ques(self, ques, token_to_ix, max_token):
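Taken together with the new BBOX_NORMALIZE flag in base_cfgs.py, the loader now either rescales the four box coordinates by image width/height or passes the raw boxes through; the old fifth area channel is no longer appended here. A standalone NumPy sketch of that logic (illustrative names only, not the dataset class itself):

```python
# Standalone illustration of the new proc_bbox_feat behavior.
import numpy as np

def proc_bbox_feat(bbox, img_shape, bbox_normalize):
    """bbox: (n_obj, 4) rows of [x1, y1, x2, y2]; img_shape: (height, width)."""
    if bbox_normalize:
        bbox_nm = np.zeros((bbox.shape[0], 4), dtype=np.float32)
        bbox_nm[:, 0] = bbox[:, 0] / float(img_shape[1])  # x1 / width
        bbox_nm[:, 1] = bbox[:, 1] / float(img_shape[0])  # y1 / height
        bbox_nm[:, 2] = bbox[:, 2] / float(img_shape[1])  # x2 / width
        bbox_nm[:, 3] = bbox[:, 3] / float(img_shape[0])  # y2 / height
        return bbox_nm
    return bbox

boxes = np.array([[40., 30., 200., 150.]], dtype=np.float32)
print(proc_bbox_feat(boxes, img_shape=(300, 400), bbox_normalize=True))
# -> [[0.1 0.1 0.5 0.5]]
```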
5 changes: 5 additions & 0 deletions openvqa/models/mcan/adapter.py
@@ -14,6 +14,9 @@ def __init__(self, __C):
super(Adapter, self).__init__(__C)
self.__C = __C

def bbox_proc(self, bbox):
area = (bbox[:, :, 2] - bbox[:, :, 0]) * (bbox[:, :, 3] - bbox[:, :, 1])
return torch.cat((bbox, area.unsqueeze(2)), -1)

def vqa_init(self, __C):
imgfeat_linear_size = __C.FEAT_SIZE['vqa']['FRCN_FEAT_SIZE'][1]
@@ -45,6 +48,7 @@ def vqa_forward(self, feat_dict):
img_feat_mask = make_mask(frcn_feat)

if self.__C.USE_BBOX_FEAT:
bbox_feat = self.bbox_proc(bbox_feat)
bbox_feat = self.bbox_linear(bbox_feat)
frcn_feat = torch.cat((frcn_feat, bbox_feat), dim=-1)
img_feat = self.frcn_linear(frcn_feat)
@@ -60,6 +64,7 @@ def gqa_forward(self, feat_dict):
img_feat_mask = make_mask(frcn_feat)

if self.__C.USE_BBOX_FEAT:
bbox_feat = self.bbox_proc(bbox_feat)
bbox_feat = self.bbox_linear(bbox_feat)
frcn_feat = torch.cat((frcn_feat, bbox_feat), dim=-1)
img_feat = self.frcn_linear(frcn_feat)
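Because the loader no longer appends the area, the adapter's new bbox_proc rebuilds the 5-d geometry vector just before bbox_linear. A minimal shape check with dummy tensors; the helper below copies the method above, and 2048 simply mirrors the BBOXFEAT_EMB_SIZE default shown further down, the exact value being incidental here:

```python
# Shape check for the bbox path: raw boxes -> +area (5-d) -> linear embedding.
import torch
import torch.nn as nn

def bbox_proc(bbox):
    # Same math as Adapter.bbox_proc above: append the box area as a 5th channel.
    area = (bbox[:, :, 2] - bbox[:, :, 0]) * (bbox[:, :, 3] - bbox[:, :, 1])
    return torch.cat((bbox, area.unsqueeze(2)), -1)

bbox_feat = torch.rand(2, 100, 4)     # dummy [batch, n_obj, (x1, y1, x2, y2)]
bbox_linear = nn.Linear(5, 2048)      # 5-d geometry -> bbox embedding

emb = bbox_linear(bbox_proc(bbox_feat))
print(bbox_proc(bbox_feat).shape, emb.shape)
# -> torch.Size([2, 100, 5]) torch.Size([2, 100, 2048])
```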
1 change: 1 addition & 0 deletions openvqa/models/mcan/model_cfgs.py
@@ -21,3 +21,4 @@ def __init__(self):
self.FLAT_OUT_SIZE = 1024
self.USE_AUX_FEAT = False
self.USE_BBOX_FEAT = False
self.BBOX_NORMALIZE = True
120 changes: 120 additions & 0 deletions openvqa/models/mmnasnet/adapter.py
@@ -0,0 +1,120 @@
# --------------------------------------------------------
# OpenVQA
# Written by Zhenwei Shao https://github.com/ParadoxZW
# --------------------------------------------------------

import torch.nn as nn
import torch
from openvqa.core.base_dataset import BaseAdapter
from openvqa.utils.make_mask import make_mask


class Adapter(BaseAdapter):
    def __init__(self, __C):
        super(Adapter, self).__init__(__C)
        self.__C = __C

    def bbox_proc(self, bbox):
        # Append the box area as a 5th channel (same helper as in the MCAN adapter),
        # so vqa_forward can feed the result to bbox_linear (nn.Linear(5, ...)).
        area = (bbox[:, :, 2] - bbox[:, :, 0]) * (bbox[:, :, 3] - bbox[:, :, 1])
        return torch.cat((bbox, area.unsqueeze(2)), -1)

def relation_embedding(self, f_g):
x_min, y_min, x_max, y_max = torch.chunk(f_g, 4, dim=2) # [bs, n_obj, 1]

cx = (x_min + x_max) * 0.5 # [bs, n_obj, 1]
cy = (y_min + y_max) * 0.5 # [bs, n_obj, 1]
w = (x_max - x_min) + 1. # [bs, n_obj, 1]
h = (y_max - y_min) + 1. # [bs, n_obj, 1]

delta_x = cx - cx.transpose(-1, -2)
delta_x = torch.clamp(torch.abs(delta_x / w), min=1e-3)
delta_x = torch.log(delta_x) # [bs, n_obj, n_obj]

delta_y = cy - cy.transpose(-1, -2)
delta_y = torch.clamp(torch.abs(delta_y / h), min=1e-3)
delta_y = torch.log(delta_y) # [bs, n_obj, n_obj]

delta_w = torch.log(w / w.transpose(-1, -2)) # [bs, n_obj, n_obj]
delta_h = torch.log(h / h.transpose(-1, -2)) # [bs, n_obj, n_obj]
size = delta_h.size()

delta_x = delta_x.view(size[0], size[1], size[2], 1)
delta_y = delta_y.view(size[0], size[1], size[2], 1)
delta_w = delta_w.view(size[0], size[1], size[2], 1)
delta_h = delta_h.view(size[0], size[1], size[2], 1) # [bs, n_obj, n_obj, 1]
position_mat = torch.cat(
(delta_x, delta_y, delta_w, delta_h), -1) # [bs, n_obj, n_obj, 4]

return position_mat

def vqa_init(self, __C):
imgfeat_linear_size = __C.FEAT_SIZE['vqa']['FRCN_FEAT_SIZE'][1]
if __C.USE_BBOX_FEAT:
self.bbox_linear = nn.Linear(5, __C.BBOXFEAT_EMB_SIZE)
imgfeat_linear_size += __C.BBOXFEAT_EMB_SIZE
self.frcn_linear = nn.Linear(imgfeat_linear_size, __C.HIDDEN_SIZE)


def gqa_init(self, __C):
imgfeat_linear_size = __C.FEAT_SIZE['gqa']['FRCN_FEAT_SIZE'][1]
if __C.USE_BBOX_FEAT:
self.bbox_linear = nn.Linear(5, __C.BBOXFEAT_EMB_SIZE)
imgfeat_linear_size += __C.BBOXFEAT_EMB_SIZE
self.frcn_linear = nn.Linear(imgfeat_linear_size, __C.HIDDEN_SIZE)

if __C.USE_AUX_FEAT:
self.grid_linear = nn.Linear(__C.FEAT_SIZE['gqa']['GRID_FEAT_SIZE'][1], __C.HIDDEN_SIZE)


def clevr_init(self, __C):
self.grid_linear = nn.Linear(__C.FEAT_SIZE['clevr']['GRID_FEAT_SIZE'][1], __C.HIDDEN_SIZE)


def vqa_forward(self, feat_dict):
frcn_feat = feat_dict['FRCN_FEAT']
bbox_feat = feat_dict['BBOX_FEAT']

img_feat_mask = make_mask(frcn_feat)

        # Build the relative-geometry embedding from the raw box coordinates
        # before bbox_feat is (optionally) projected below.
        rel_embed = self.relation_embedding(bbox_feat)

        if self.__C.USE_BBOX_FEAT:
            bbox_feat = self.bbox_proc(bbox_feat)
            bbox_feat = self.bbox_linear(bbox_feat)
            frcn_feat = torch.cat((frcn_feat, bbox_feat), dim=-1)
        img_feat = self.frcn_linear(frcn_feat)

return img_feat, rel_embed, img_feat_mask


def gqa_forward(self, feat_dict):
frcn_feat = feat_dict['FRCN_FEAT']
bbox_feat = feat_dict['BBOX_FEAT']
grid_feat = feat_dict['GRID_FEAT']

img_feat_mask = make_mask(frcn_feat)

        # As in vqa_forward, compute the geometry embedding from the raw boxes
        # before they are projected.
        rel_embed = self.relation_embedding(bbox_feat)

        if self.__C.USE_BBOX_FEAT:
            bbox_feat = self.bbox_linear(bbox_feat)
            frcn_feat = torch.cat((frcn_feat, bbox_feat), dim=-1)
        img_feat = self.frcn_linear(frcn_feat)

        if self.__C.USE_AUX_FEAT:
            grid_feat_mask = make_mask(grid_feat)
            img_feat_mask = torch.cat((img_feat_mask, grid_feat_mask), dim=-1)
            grid_feat = self.grid_linear(grid_feat)
            img_feat = torch.cat((img_feat, grid_feat), dim=1)

return img_feat, rel_embed, img_feat_mask


def clevr_forward(self, feat_dict):
grid_feat = feat_dict['GRID_FEAT']

img_feat_mask = make_mask(grid_feat)
img_feat = self.grid_linear(grid_feat)

        # CLEVR provides grid features only (no bounding boxes), so the
        # box-based relation embedding is unavailable on this path.
        rel_embed = None

return img_feat, rel_embed, img_feat_mask



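The relation_embedding method is what feeds the RSA operators: for every pair of boxes it encodes log-scaled center offsets and width/height ratios. Below is a standalone copy of that math, used only to confirm the expected [bs, n_obj, n_obj, 4] output shape on dummy boxes:

```python
# Standalone shape check of the pairwise geometry features built by
# relation_embedding above (same math, outside the Adapter class).
import torch

def relation_embedding(f_g):
    x_min, y_min, x_max, y_max = torch.chunk(f_g, 4, dim=2)   # each [bs, n_obj, 1]
    cx, cy = (x_min + x_max) * 0.5, (y_min + y_max) * 0.5
    w, h = (x_max - x_min) + 1., (y_max - y_min) + 1.

    delta_x = torch.log(torch.clamp(torch.abs((cx - cx.transpose(-1, -2)) / w), min=1e-3))
    delta_y = torch.log(torch.clamp(torch.abs((cy - cy.transpose(-1, -2)) / h), min=1e-3))
    delta_w = torch.log(w / w.transpose(-1, -2))
    delta_h = torch.log(h / h.transpose(-1, -2))
    return torch.stack((delta_x, delta_y, delta_w, delta_h), dim=-1)  # [bs, n_obj, n_obj, 4]

boxes = torch.rand(2, 100, 4)              # dummy [x1, y1, x2, y2] boxes
print(relation_embedding(boxes).shape)     # -> torch.Size([2, 100, 100, 4])
```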
28 changes: 28 additions & 0 deletions openvqa/models/mmnasnet/model_cfgs.py
@@ -0,0 +1,28 @@
# --------------------------------------------------------
# OpenVQA
# Written by Zhenwei Shao https://github.com/ParadoxZW
# --------------------------------------------------------

from openvqa.core.base_cfgs import BaseCfgs


class Cfgs(BaseCfgs):
def __init__(self):
super(Cfgs, self).__init__()

self.ARCH = {
'enc': ['SA', 'SA', 'SA', 'SA', 'FFN', 'FFN', 'FFN', 'FFN', 'SA', 'FFN', 'FFN', 'FFN'],
'dec': ['GA', 'GA', 'FFN', 'FFN', 'GA', 'FFN', 'RSA', 'GA', 'FFN', 'GA', 'RSA', 'FFN', 'RSA', 'SA', 'FFN', 'RSA', 'GA', 'FFN']
}
self.HIDDEN_SIZE = 512
self.BBOXFEAT_EMB_SIZE = 2048
self.FF_SIZE = 2048
self.MULTI_HEAD = 8
self.DROPOUT_R = 0.1
self.FLAT_MLP_SIZE = 512
self.FLAT_GLIMPSES = 1
self.FLAT_OUT_SIZE = 1024
self.USE_AUX_FEAT = False
self.USE_BBOX_FEAT = False
self.REL_HBASE = 64
self.REL_SIZE = 64

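These defaults mirror mmnasnet_small.yml; the YAML configs override them at run time. Purely as an illustration of what the ARCH operator lists drive (not the repository's actual builder — the real SA/RSA/GA/FFN modules are defined elsewhere in the model package and are not reproduced here), such a list can be dispatched into a layer stack:

```python
# Illustrative only: dispatching a searched op list such as ARCH['dec'] into layers.
# nn.Identity stands in for the real operator modules.
import torch.nn as nn

OPS = {
    'SA':  lambda d: nn.Identity(),   # self-attention
    'RSA': lambda d: nn.Identity(),   # relation self-attention (consumes rel_embed)
    'GA':  lambda d: nn.Identity(),   # question-guided attention
    'FFN': lambda d: nn.Identity(),   # position-wise feed-forward
}

def build_branch(arch_list, hidden_size=512):
    return nn.ModuleList([OPS[name](hidden_size) for name in arch_list])

dec = build_branch(['GA', 'GA', 'FFN', 'FFN', 'GA', 'FFN', 'RSA', 'GA', 'FFN',
                    'GA', 'RSA', 'FFN', 'RSA', 'SA', 'FFN', 'RSA', 'GA', 'FFN'])
print(len(dec))   # -> 18 searched decoder operations
```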