5 changes: 1 addition & 4 deletions models/sequence_generator.py
@@ -604,10 +604,7 @@ def _prefix_tokens(
prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat(1, beam_size).view(-1)
prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1))
prefix_mask = prefix_toks.ne(self.pad)
-if self.constraint_trie is None:
-    lprobs[prefix_mask] = torch.min(prefix_lprobs) - 1
-else:
-    lprobs[prefix_mask] = -math.inf
+lprobs[prefix_mask] = -math.inf
lprobs[prefix_mask] = lprobs[prefix_mask].scatter(
-1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lprobs[prefix_mask]
)
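The simplified `_prefix_tokens` logic above always masks every candidate to `-inf` for rows that carry a forced prefix token and then scatters the prefix token's original log-probability back in, so beam search has to emit the prefix at that step. Below is a minimal standalone sketch of that pattern, not the generator's actual method; the function name and the `(batch * beam, vocab)` shape assumption are illustrative only.

```python
import torch

# Minimal sketch of prefix-token forcing, assuming lprobs has shape
# (batch * beam, vocab) and prefix_toks holds one forced token id per row
# (pad where no prefix is given).
def force_prefix(lprobs: torch.Tensor, prefix_toks: torch.Tensor, pad: int) -> torch.Tensor:
    prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1))  # score of each forced token
    prefix_mask = prefix_toks.ne(pad)                             # rows that actually have a prefix
    # Disallow every candidate on those rows, then restore only the forced token's score,
    # so the forced token is the single finite-scoring choice at this step.
    lprobs[prefix_mask] = -float("inf")
    lprobs[prefix_mask] = lprobs[prefix_mask].scatter(
        -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lprobs[prefix_mask]
    )
    return lprobs
```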
41 changes: 41 additions & 0 deletions run_scripts/vqa/evaluate_vqa_unconstrained.sh
@@ -0,0 +1,41 @@
#!/usr/bin/env bash

# This script evaluates an **unconstrained** finetuned OFA-Large checkpoint (i.e. finetuned with --unconstrained-training enabled),
# which does not use a fixed candidate answer set (trainval_ans2label.pkl).
# For more details about unconstrained finetuning, refer to Lines 62-68 in train_vqa_distributed.sh

# Usage: bash evaluate_vqa_unconstrained.sh ${split} ${ckpt_path}

# The port for communication. Note that if you want to run multiple tasks on the same machine,
# you need to specify different port numbers.
export MASTER_PORT=8082

user_dir=../../ofa_module
bpe_dir=../../utils/BPE

# val or test
split=$1

data=../../dataset/vqa_data/vqa_${split}.tsv
path=$2 # please specify the path to your unconstrained finetuned checkpoint
result_path=../../results/vqa_${split}_unconstrained
selected_cols=0,5,2,3,4

CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m torch.distributed.launch --nproc_per_node=4 --master_port=${MASTER_PORT} ../../evaluate.py \
${data} \
--path=${path} \
--user-dir=${user_dir} \
--task=vqa_gen \
--batch-size=16 \
--log-format=simple --log-interval=10 \
--seed=7 \
--gen-subset=${split} \
--results-path=${result_path} \
--fp16 \
--ema-eval \
--beam-search-vqa-eval \
--beam=5 \
--unnormalized \
--temperature=1.0 \
--num-workers=0 \
--model-overrides="{\"data\":\"${data}\",\"bpe_dir\":\"${bpe_dir}\",\"selected_cols\":\"${selected_cols}\"}"
10 changes: 9 additions & 1 deletion run_scripts/vqa/train_vqa_base_distributed.sh
@@ -48,7 +48,6 @@ max_src_length=80
max_object_length=30
max_tgt_length=30
num_bins=1000
patch_image_size=480

uses_ema="--uses-ema"
store_ema="--store-ema"
@@ -60,6 +59,14 @@ ema_start_update=0
# As mentioned in the readme, you can choose from allcand or beamsearch evaluation, default to allcand
val_inference_type=allcand

# Specify whether to activate unconstrained VQA finetuning, which does not use a pre-defined candidate answer set.
# If --unconstrained-training is activated, --ans2label-file will **not be used even if it is specified**.
# Meanwhile, --val-inference-type must be set to **beamsearch**.
# By default, we follow the constrained finetuning described in the OFA paper, where the candidate answer set is specified by --ans2label-file.
# For more details about this option, please refer to issue #123 and PR #124
unconstrained_training_flag=""
# unconstrained_training_flag="--unconstrained-training"

for max_epoch in {15,}; do
echo "max_epoch "${max_epoch}
for warmup_ratio in {0.04,}; do
@@ -120,6 +127,7 @@ for max_epoch in {15,}; do
--find-unused-parameters \
--freeze-encoder-embedding \
--freeze-decoder-embedding \
${unconstrained_training_flag} \
--ans2label-file=${ans2label_file} \
--valid-batch-size=20 \
--add-type-embedding \
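The comment block added to both training scripts pins down two rules: with `--unconstrained-training`, `--ans2label-file` is ignored, and `--val-inference-type` must be `beamsearch`. As a quick illustration of the valid combinations, here is a hypothetical checker, not part of this PR and with made-up names, that encodes the same rules:

```python
from typing import Optional

# Hypothetical helper (illustration only) mirroring the rules in the training-script
# comments: unconstrained training ignores the answer-set file and requires
# beam-search validation.
def check_vqa_finetune_config(unconstrained_training: bool,
                              val_inference_type: str,
                              ans2label_file: Optional[str]) -> None:
    assert val_inference_type in ("allcand", "beamsearch"), \
        "unknown val_inference_type: {}".format(val_inference_type)
    if unconstrained_training:
        # open-ended training has no fixed candidate answer set, so beam search is mandatory
        assert val_inference_type == "beamsearch", \
            "--unconstrained-training requires --val-inference-type=beamsearch"
        if ans2label_file is not None:
            print("note: --ans2label-file is ignored when --unconstrained-training is set")

# default constrained setup: allcand validation over the fixed answer set
check_vqa_finetune_config(False, "allcand", "../../dataset/vqa_data/trainval_ans2label.pkl")
# unconstrained setup: open-ended training, validated with beam search
check_vqa_finetune_config(True, "beamsearch", None)
```

The actual enforcement lives in the `VqaGenTask` assertion shown in the `tasks/mm_tasks/vqa_gen.py` diff below.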
10 changes: 9 additions & 1 deletion run_scripts/vqa/train_vqa_distributed.sh
@@ -48,7 +48,6 @@ max_src_length=80
max_object_length=30
max_tgt_length=30
num_bins=1000
patch_image_size=480

uses_ema="--uses-ema"
store_ema="--store-ema"
@@ -60,6 +59,14 @@ ema_start_update=0
# As mentioned in the readme, you can choose from allcand or beamsearch evaluation, default to allcand
val_inference_type=allcand

# Specify whether to activate unconstrained VQA finetuning, which does not use a pre-defined candidate answer set.
# If --unconstrained-training is activated, --ans2label-file will **not be used even if it is specified**.
# Meanwhile, --val-inference-type must be set to **beamsearch**.
# By default, we follow the constrained finetuning described in the OFA paper, where the candidate answer set is specified by --ans2label-file.
# For more details about this option, please refer to issue #123 and PR #124
unconstrained_training_flag=""
# unconstrained_training_flag="--unconstrained-training"

for total_num_updates in {40000,}; do
echo "total_num_updates "${total_num_updates}
for warmup_updates in {1000,}; do
@@ -121,6 +128,7 @@ for total_num_updates in {40000,}; do
--find-unused-parameters \
--freeze-encoder-embedding \
--freeze-decoder-embedding \
${unconstrained_training_flag} \
--ans2label-file=${ans2label_file} \
--valid-batch-size=20 \
--add-type-embedding \
70 changes: 43 additions & 27 deletions tasks/mm_tasks/vqa_gen.py
@@ -55,7 +55,12 @@ class VqaGenConfig(OFAConfig):
default=None,
metadata={"help": "path to load ans2label file"},
)

unconstrained_training: bool = field(
default=False,
metadata={"help": "do not use Trie to constrain loss into the closed candidate answer set, default to False. \
If set to True, then open-ended training is facilitated, ans2label_file and ans2label_dict will not be used \
and inference type must be beamsearch"},
)
add_object: bool = field(
default=False,
metadata={"help": "add object to encoder"},
@@ -89,16 +94,19 @@ class VqaGenTask(OFATask):
def __init__(self, cfg: VqaGenConfig, src_dict, tgt_dict):
super().__init__(cfg, src_dict, tgt_dict)

self.ans2label_dict = None
if self.cfg.ans2label_file is not None:
self.ans2label_dict = pickle.load(open(self.cfg.ans2label_file, "rb"))
else:
self.ans2label_dict = json.loads(self.cfg.ans2label_dict)
if not self.cfg.unconstrained_training:
self.ans2label_dict = None
if self.cfg.ans2label_file is not None:
self.ans2label_dict = pickle.load(open(self.cfg.ans2label_file, "rb"))
else:
self.ans2label_dict = json.loads(self.cfg.ans2label_dict)

self.uses_ema = self.cfg.uses_ema

assert self.cfg.val_inference_type in ["allcand", "beamsearch"], \
"Unknown inference type encountered: {}, should be allcand or beamsearch.".format(self.cfg.val_inference_type)
assert not (self.cfg.unconstrained_training and self.cfg.val_inference_type != "beamsearch"), \
"For open-ended training, there is no fixed candidate answer set, then inference type must be beamsearch"

def load_dataset(self, split, epoch=1, combine=False, **kwargs):
paths = self.cfg.data.split(',')
@@ -128,29 +136,37 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs):

def build_model(self, cfg):
model = super().build_model(cfg)
answer_item_list = []
self.index2ans = {}
self.constraint_trie = Trie(self.tgt_dict.eos())
for i, answer in enumerate(self.ans2label_dict.keys()):
answer_item = self.tgt_dict.encode_line(
line=self.bpe.encode(' ' + answer),
add_if_not_exist=False,
append_eos=False
).long()
answer_item_list.append(answer_item)
self.index2ans[i] = answer
self.constraint_trie.insert([self.tgt_dict.bos()] + answer_item.tolist() + [self.tgt_dict.eos()])

constraint_mask_list = []
for answer_item in answer_item_list:
constraint_mask = torch.zeros((len(answer_item)+1, len(self.tgt_dict))).bool()
for i in range(len(answer_item)+1):
constraint_prefix_token = [self.src_dict.bos()] + answer_item[:i].tolist()
constraint_nodes = self.constraint_trie.get_next_layer(constraint_prefix_token)
constraint_mask[i][constraint_nodes] = True
constraint_mask_list.append(constraint_mask)

# for open-ended training without a fixed candidate answer set
if self.cfg.unconstrained_training:
self.constraint_trie = None
# (default) for trie-based constrained training with a fixed candidate answer set
# (provided by ans2label_file or ans2label_dict)
else:
answer_item_list = []
self.index2ans = {}
self.constraint_trie = Trie(self.tgt_dict.eos())
for i, answer in enumerate(self.ans2label_dict.keys()):
answer_item = self.tgt_dict.encode_line(
line=self.bpe.encode(' ' + answer),
add_if_not_exist=False,
append_eos=False
).long()
answer_item_list.append(answer_item)
self.index2ans[i] = answer
self.constraint_trie.insert([self.tgt_dict.bos()] + answer_item.tolist() + [self.tgt_dict.eos()])

constraint_mask_list = []
for answer_item in answer_item_list:
constraint_mask = torch.zeros((len(answer_item)+1, len(self.tgt_dict))).bool()
for i in range(len(answer_item)+1):
constraint_prefix_token = [self.src_dict.bos()] + answer_item[:i].tolist()
constraint_nodes = self.constraint_trie.get_next_layer(constraint_prefix_token)
constraint_mask[i][constraint_nodes] = True
constraint_mask_list.append(constraint_mask)

if self.cfg.val_inference_type == "allcand":
assert not self.cfg.unconstrained_training
self.valid_answers_list = []
self.valid_constraint_masks_list = []
for i in range(0, len(answer_item_list), self.cfg.valid_batch_size):
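For reference, in the constrained branch of `build_model` above, every candidate answer is tokenized, inserted into the `Trie`, and expanded into a per-prefix boolean mask over the vocabulary (`constraint_mask_list`). The toy sketch below illustrates that construction with a minimal dictionary-based trie and made-up token ids; it is a stand-in for OFA's `Trie` utility, not the class itself.

```python
import torch

# Toy stand-in for the constraint trie (not OFA's Trie class): maps a prefix of
# token ids to the token ids that may follow it.
class ToyTrie:
    def __init__(self, eos: int):
        self.root = {}
        self.eos = eos

    def insert(self, token_ids):
        node = self.root
        for tok in token_ids:
            node = node.setdefault(tok, {})

    def get_next_layer(self, prefix):
        node = self.root
        for tok in prefix:
            node = node.get(tok)
            if node is None:
                return []
        return list(node.keys())

# Made-up token ids: bos=0, eos=2; candidate answers "yes" -> [7] and "no" -> [8, 9].
bos, eos, vocab_size = 0, 2, 12
answers = {"yes": [7], "no": [8, 9]}

trie = ToyTrie(eos)
for toks in answers.values():
    trie.insert([bos] + toks + [eos])

# For one answer, build a (len+1, vocab) boolean mask: row i marks which tokens are
# allowed after the first i answer tokens -- the same shape as constraint_mask above.
toks = answers["no"]
constraint_mask = torch.zeros(len(toks) + 1, vocab_size, dtype=torch.bool)
for i in range(len(toks) + 1):
    allowed = trie.get_next_layer([bos] + toks[:i])
    constraint_mask[i, allowed] = True

print(constraint_mask.nonzero().tolist())
# step 0 allows {7, 8} (the first token of either answer), step 1 allows {9}, step 2 allows eos (2)
```

During constrained training and allcand validation, masks like these restrict the model's output distribution to token sequences that spell out one of the candidate answers; with `--unconstrained-training`, no trie or masks are built and decoding is left open-ended.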