
The Dialog Must Go On:
Improving Visual Dialog via Generative Self-Training

Gi-Cheon Kang,   Sungdong Kim*,   Jin-Hwa Kim*,   Donghyun Kwak*,   Byoung-Tak Zhang

(* Equal Contribution)

CVPR 2023 (Paper)

Overview



Citation

If you use this code or preprocessed data in your research, please consider citing:

@inproceedings{kang2023dialog,
  title={The Dialog Must Go On: Improving Visual Dialog via Generative Self-Training},
  author={Kang, Gi-Cheon and Kim, Sungdong and Kim, Jin-Hwa and Kwak, Donghyun and Zhang, Byoung-Tak},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  year={2023},
  pages={6746-6756}
}

Table of Contents

  1. Setup and Dependencies
  2. Download Data
  3. Pre-trained Checkpoints
  4. Training
  5. Adaptation to Discriminative Visual Dialog
  6. Visual Dialog Generation
  7. Evaluation
  8. Adversarial Robustness Study
  9. Demo
  10. Acknowledgements
  11. License

Setup and Dependencies

This code is implemented using PyTorch v1.7.1+ and provides out-of-the-box support for CUDA 11+ and cuDNN 7+. Anaconda or Miniconda is recommended for setting up this codebase:

  1. Install the Anaconda or Miniconda distribution (Python 3.8+) from their downloads site.
  2. Clone this repository and create an environment:
git clone https://www.github.com/gicheonkang/gst-visdial
conda create -n gst python=3.8 -y

# activate the environment and install all dependencies
conda activate gst
pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
pip install -r requirements.txt
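
As a quick sanity check (a minimal sketch, not part of the repository), you can verify that the pinned PyTorch build sees your GPUs before downloading any data:

# check_env.py -- minimal environment check, not part of the repository
import torch

print("PyTorch:", torch.__version__)            # expected: 1.7.1+cu110
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPUs:", torch.cuda.device_count())
    print("GPU 0:", torch.cuda.get_device_name(0))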

Download Data

  1. Download the preprocessed original VisDial data, collected by Das et al. It includes Faster R-CNN bounding-box image features of the MSCOCO dataset (80G) and preprocessed json files for the dialogs (2G).
chmod +x scripts/download_preprocessed_human_visdial.sh
./scripts/download_preprocessed_human_visdial.sh

  2. We also release the machine-generated VisDial data, which consists of Faster R-CNN bounding-box image features for a subset of the Conceptual Captions 12M dataset (nearly 2.4T with 3.6M images) and the corresponding machine-generated dialog data.

  3. If you just want to use the machine-generated dialog data along with the images, download the json files for the dialog data. The json files contain urls for the image data (see the inspection sketch after the download command below).

chmod +x scripts/download_preprocessed_machine_visdial.sh
./scripts/download_preprocessed_machine_visdial.sh
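
If you want a quick look at the downloaded dialog data, a minimal sketch like the one below prints the top-level structure of any of the json files without assuming a particular schema (pass the file path as an argument):

# inspect_json.py -- illustrative sketch, not part of the repository
import json
import sys

with open(sys.argv[1]) as f:        # path to one of the downloaded dialog json files
    data = json.load(f)

if isinstance(data, dict):
    print("top-level keys:", list(data.keys()))
elif isinstance(data, list):
    print("number of entries:", len(data))
    print("first entry:", data[0])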

Pre-trained Checkpoints

Please download the checkpoints to the checkpoints/ directory.

Model                         | Trained Data                                         | Link
------------------------------|------------------------------------------------------|---------
Questioner                    | VisDial v1.0                                         | Download
Teacher                       | VisDial v1.0                                         | Download
Student                       | VisDial v1.0 + CC12M with Synthetic Dialogs (iter3)  | Download
Student (Discriminative)      | VisDial v1.0 + CC12M with Synthetic Dialogs (iter3)  | Download
Base Model from VisDial-BERT  | CC3M + VQA                                           | Download
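
After downloading, you can sanity-check a checkpoint by loading it on CPU and listing its top-level keys. This is only an inspection sketch; the internal layout of each checkpoint is whatever the training scripts saved, so treat the printed names accordingly.

# inspect_ckpt.py -- illustrative sketch, not part of the repository
import torch

# map_location="cpu" so no GPU is needed just to inspect the file
ckpt = torch.load("checkpoints/teacher_v1.0.ckpt", map_location="cpu")
print(type(ckpt))
if isinstance(ckpt, dict):
    print("top-level keys:", list(ckpt.keys()))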

Training

Teacher model and questioner model training. Nearly 54G of GPU memory is required to train the model. The argument -model enc_dec_a selects the encoder-decoder answerer model, and -model enc_dec_q selects the encoder-decoder questioner model.

# Teacher model training
python train_gen.py \
  -mode vd_train \
  -start_path checkpoints/basemodel \
  -model enc_dec_a \
  -gpu_ids 0 1 2 3
# Questioner model training
python train_gen.py \
  -mode vd_train \
  -start_path checkpoints/basemodel \
  -model enc_dec_q \
  -gpu_ids 0 1 2 3

Student model training consists of two steps: (1) training on the synthetically generated visual dialog dataset and (2) finetuning on the original visual dialog dataset. The argument -chunk denotes the number of data chunks to use (default 30), and -select_data enables the perplexity-based data selection method. After training on the synthetic dialog data, the student model is finetuned on the original visual dialog data.

# training on the synthetic visual dialog dataset
python train_gen.py \
  -mode cc12m_train \
  -select_data \
  -start_path checkpoints/basemodel \
  -save_path checkpoints/iter1/ \
  -chunk 30 \
  -gpu_ids 0 1 2 3 \
  -iter 1
# finetuning on the original visual dialog dataset
python train_gen.py \
  -mode vd_train \
  -continue \
  -start_path checkpoints/iter1/cc12m_train_30_3.ckpt \
  -save_path checkpoints/iter1/ \
  -chunk 30 \
  -gpu_ids 0 1 2 3
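
For intuition, the perplexity-based selection keeps only the synthetic dialogs that the teacher scores as most likely (lowest perplexity). The sketch below is illustrative only, not the repository's implementation; the "perplexity" field and the keep ratio are assumptions made for the example.

# illustrative sketch of perplexity-based data selection (not the repository's code)
def select_by_perplexity(samples, keep_ratio=0.5):
    # keep the fraction of synthetic samples with the lowest perplexity
    ranked = sorted(samples, key=lambda s: s["perplexity"])
    return ranked[:int(len(ranked) * keep_ratio)]

# toy usage: four synthetic dialogs with teacher-assigned perplexities
toy = [{"dialog_id": i, "perplexity": p} for i, p in enumerate([3.2, 1.1, 7.8, 2.5])]
print(select_by_perplexity(toy))   # keeps dialogs 1 and 3 (perplexities 1.1 and 2.5)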

Adaptation to Discriminative Visual Dialog

A "discriminative" visual dialog model requires answer candidates for each question, but our proposed approach only generates the ground-truth answer. Hence, we propose tricks to train the discriminative model. Based on the encoder-decoder model pre-trained on the synthetic dataset, we finetune the encoder model on the original visdial dataset. Please see our paper (Appendix B) for more details.

python train_disc.py \
  -mode vd_train \
  -continue \
  -model enc_only_a \
  -batch_size 40 \
  -train_dense \
  -num_negative_samples 5 \
  -start_path checkpoints/x30_start_iter3.ckpt \
  -save_path checkpoints/disc \
  -chunk 30 \
  -gpu_ids 0 1 2 3
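
For intuition only: one common way to train such a discriminative head is to score the ground-truth answer against the -num_negative_samples sampled distractors and apply a cross-entropy loss over the candidate scores. The sketch below illustrates that loss shape; it is an assumption for exposition, not the procedure from Appendix B of the paper.

# illustrative candidate-ranking loss (assumption, not the repository's code)
import torch
import torch.nn.functional as F

# per question: 1 ground-truth answer + 5 sampled negatives -> 6 candidates
scores = torch.randn(2, 6, requires_grad=True)   # (batch, num_candidates), stand-in for encoder scores
target = torch.zeros(2, dtype=torch.long)        # ground-truth candidate placed at index 0

loss = F.cross_entropy(scores, target)           # softmax over candidates, NLL of the ground truth
loss.backward()
print(float(loss))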

Visual Dialog Generation



Visual dialog generation given image features and captions. The questioner and the teacher alternately generate visual questions and the corresponding answers, respectively.

You can generate your own visual dialog dataset simply by feeding Bottom-up Attention features and caption data. We extracted the image features using the docker container.

python generate.py \
  -mode cc12m_gen \
  -cc12m_image_feats data/cc12m/features/cc12m_img_feat_0.lmdb/ \
  -cc12m_caption data/cc12m/captions/cc12m_filtered_0.json \
  -start_path_q checkpoints/questioner_v1.0.ckpt \
  -start_path_a checkpoints/teacher_v1.0.ckpt \
  -save_name cc12m_dialogs_0.txt \
  -save_path data/gen_dialog \
  -gpu_ids 0 1
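
At a high level, generation alternates between the two models: the questioner produces a question conditioned on the image features, caption, and dialog so far, and the teacher answers it; the pair is appended to the history and the loop repeats (ten rounds mirrors the VisDial format). The sketch below shows only that control flow; ask() and answer() are illustrative stubs, not functions from this repository.

# illustrative sketch of the questioner/teacher generation loop (stubs, not the repository's code)
def ask(image_feat, caption, history):
    # stub for the questioner (enc_dec_q); returns a canned question here
    return "what color is it?"

def answer(image_feat, caption, history, question):
    # stub for the teacher / answerer (enc_dec_a); returns a canned answer here
    return "it looks red"

def generate_dialog(image_feat, caption, num_rounds=10):
    history = []
    for _ in range(num_rounds):
        q = ask(image_feat, caption, history)
        a = answer(image_feat, caption, history, q)
        history.append((q, a))
    return history

print(generate_dialog(image_feat=None, caption="a red car parked on the street"))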

Evaluation

Evaluation of the student model on the VisDial v1.0 validation split. Validation scores can be checked in an offline setting. If you want to evaluate the model on the test split, change the mode to vd_eval_test and submit the resulting text file to the EvalAI online evaluation server. Evaluation on the VisDial v0.9 validation split is also available; please add -vd_version 0.9.

python evaluate_gen.py \
  -mode vd_eval_val \
  -start_path checkpoints/student_v1.0_iter3.ckpt \
  -save_path results \
  -save_name gen.txt \
  -gpu_ids 0 1 2 3
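
For reference, the sparse VisDial metrics (MRR, Recall@k, mean rank) are simple functions of the rank assigned to the ground-truth answer at each round. The snippet below is an illustrative computation, not the repository's evaluator.

# illustrative metric computation from ground-truth answer ranks (1-indexed)
def mrr(ranks):
    return sum(1.0 / r for r in ranks) / len(ranks)

def recall_at_k(ranks, k):
    return sum(1 for r in ranks if r <= k) / len(ranks)

def mean_rank(ranks):
    return sum(ranks) / len(ranks)

ranks = [1, 3, 12, 2, 7]    # toy example: rank of the true answer per question
print("MRR :", mrr(ranks))
print("R@5 :", recall_at_k(ranks, 5))
print("Mean:", mean_rank(ranks))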

Evaluation for the discriminative model is as follows.

python evaluate_disc.py \
  -mode vd_eval_val \
  -start_path checkpoints/student_v1.0_iter3_disc_dense.ckpt \
  -save_path results \
  -save_name disc.txt \
  -gpu_ids 0 1 2 3

Adversarial Robustness Study

We propose three different adversarial attacks for VisDial: (1) the FGSM attack, (2) a coreference attack, and (3) a random token attack. The FGSM attack perturbs input visual features, and the others attack the dialog history (textual inputs).

Simply run the command below for the FGSM attack:

python evaluate_gen_attack.py \
  -mode vd_eval_val \
  -attack fgsm \
  -start_path checkpoints/student_v1.0_iter3.ckpt \
  -save_path results \
  -save_name fgsm.txt \
  -gpu_ids 0 1 2 3
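
FGSM takes a single step in the direction of the sign of the loss gradient with respect to the input features, x' = x + ε·sign(∇ₓL). A minimal PyTorch sketch of that step is below; the random projection standing in for the model's loss and the value of ε are placeholders, not the repository's settings.

# illustrative FGSM step on visual features (placeholder model/loss, not the repository's code)
import torch

def fgsm_perturb(features, loss, epsilon=0.05):
    grad, = torch.autograd.grad(loss, features)       # gradient of the loss w.r.t. the features
    return features + epsilon * grad.sign()           # perturb in the sign direction

feats = torch.randn(36, 2048, requires_grad=True)     # e.g. 36 region features, 2048-d each
loss = (feats @ torch.randn(2048, 1)).mean()          # stand-in for the dialog model's loss
adv_feats = fgsm_perturb(feats, loss)
print(adv_feats.shape)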

For the textual attacks, preprocessing is required. Download the counter-fitted word embeddings and run the preprocessing code below.

python comp_cos_sim_mat.py counter-fitted-vectors.txt
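
For context, this preprocessing presumably turns the counter-fitted embeddings into a word-by-word cosine similarity matrix that the textual attacks can query for near-synonyms. The sketch below shows that computation under this assumption; it is not the actual comp_cos_sim_mat.py, and for the full vocabulary the resulting matrix is very large.

# illustrative cosine-similarity-matrix construction (assumption, not the actual comp_cos_sim_mat.py)
import sys
import numpy as np

words, vecs = [], []
with open(sys.argv[1]) as f:                   # counter-fitted-vectors.txt: "word v1 v2 ... vd" per line
    for line in f:
        parts = line.rstrip().split()
        words.append(parts[0])
        vecs.append(np.asarray(parts[1:], dtype=np.float32))

emb = np.stack(vecs)
emb /= np.linalg.norm(emb, axis=1, keepdims=True)    # L2-normalize each embedding
cos_sim = emb @ emb.T                                # (vocab, vocab) cosine similarities
np.save("cos_sim_counter_fitted.npy", cos_sim)       # output name is illustrative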

Then, run the attack script:

python evaluate_gen_attack.py \
  -mode vd_eval_val \
  -attack coreference \
  -visdial_processed_val data/visdial/visdial_1.0_val_crowdsourced.json \
  -visdial_processed_val_dense_annotations data/visdial/visdial_1.0_val_dense_annotations_processed_crowdsourced.json \
  -start_path checkpoints/student_v1.0_iter3.ckpt \
  -save_path results \
  -save_name coreference.txt \
  -gpu_ids 0 1 2 3

Demo

We provide an interactive demo to easily show the answers generated by our model. Simply run the command below and enter an image id from the VisDial v1.0 validation images.

python inference.py

Acknowledgements

We use VisDial-BERT as reference code. Thanks!

License

MIT License

