
In-Context Learning for VQA on OFv2

The repository for our paper: How to Configure Good In-Context Sequence for Visual Question Answering


Preparation

We use OpenFlamingo and its framework to implement various retrieval strategies on three different VQA datasets.

Environment

Create a conda environment for running the following code. The repository is currently set up for anonymous submission and will be finalized in the formal version.

git clone https://github.com/GaryJiajia/OFv2_ICL_VQA.git
cd OFv2_ICL_VQA
conda env create -f environment.yml
conda activate ofv2
pip install git+https://github.com/openai/CLIP.git

Datasets

We use the VQAv2, OK-VQA, and VizWiz datasets. You need to download these datasets yourself, including both the images and the annotations.
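As a quick sanity check before launching an evaluation, you can verify that the expected files are in place. The paths below mirror the ones used in the VQAv2 evaluation script later in this README and are only an assumed layout; adjust them to wherever you stored the data.

from pathlib import Path

# Assumed layout, mirroring the paths used in open_flamingo/scripts/run_vqav2.sh;
# adjust these to your own download locations.
expected_paths = [
    "mscoco2014/train2014",
    "mscoco2014/val2014",
    "vqav2/v2_OpenEnded_mscoco_train2014_questions.json",
    "vqav2/v2_mscoco_train2014_annotations.json",
    "vqav2/v2_OpenEnded_mscoco_val2014_questions.json",
    "vqav2/v2_mscoco_val2014_annotations.json",
]
for p in expected_paths:
    print(("ok     " if Path(p).exists() else "MISSING"), p)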

To run evaluations on OK-VQA, you will first need to download the WordNet corpus for NLTK:

import nltk
nltk.download('wordnet')

Model

OpenFlamingo is a multimodal language model that can be used for a variety of tasks. It is trained on a large multimodal dataset (e.g. Multimodal C4) and can be used to generate text conditioned on interleaved images/text. You can read its blog and code for more information.

OpenFlamingo combines a pretrained vision encoder and a language model using cross-attention layers. In our experiments we use OpenFlamingo-9B, which uses the pretrained ViT-L-14 vision encoder from the OpenCLIP package and MPT-7B as the pretrained language model. Initialize the model with the following code.

from open_flamingo import create_model_and_transforms

model, image_processor, tokenizer = create_model_and_transforms(
    clip_vision_encoder_path="ViT-L-14",
    clip_vision_encoder_pretrained="openai",
    lang_encoder_path="anas-awadalla/mpt-7b",
    tokenizer_path="anas-awadalla/mpt-7b",
    cross_attn_every_n_layers=4
)

# grab model checkpoint from huggingface hub
from huggingface_hub import hf_hub_download
import torch

checkpoint_path = hf_hub_download("openflamingo/OpenFlamingo-9B-vitl-mpt7b", "checkpoint.pt")
model.load_state_dict(torch.load(checkpoint_path), strict=False)
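After loading the checkpoint, you will typically want to move the model to a GPU and, for FP16 inference, cast it to half precision. The snippet below is standard PyTorch usage rather than anything specific to this repository:

# Standard PyTorch setup (not specific to this repository): pick a device,
# move the model there, and cast to half precision for FP16 inference.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)
if device.type == "cuda":
    model = model.half()
model.eval()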

Usage

Demo test

We provide a demo in /demo_test/demo_test.ipynb; you can use different demonstrations and query samples to explore the results generated by OpenFlamingo. Below is a 2-shot demo test.

from open_flamingo import create_model_and_transforms
import torch
from PIL import Image
from PIL import ImageFilter 
import requests

class PATH:
    lm_path = "path for mpt-7b"
    lm_tokenizer_path = "path for mpt-7b"
    checkpoint_path = "path for openflamingo v2 checkpoint.pt"
args = PATH()
device_set = 'cuda:0'
device = torch.device(device_set)

flamingo, image_processor, tokenizer = create_model_and_transforms(
    clip_vision_encoder_path="ViT-L-14",
    clip_vision_encoder_pretrained="openai",
    lang_encoder_path=args.lm_path,
    tokenizer_path=args.lm_tokenizer_path,
    cross_attn_every_n_layers=4,
    # new params
    inference=True,
    precision="fp16",
    device=device_set,
    checkpoint_path=args.checkpoint_path,
)

demo_image_one = Image.open(
    requests.get(
        "http://images.cocodataset.org/val2017/000000039769.jpg", stream=True
    ).raw
)
demo_image_two = Image.open("test-006.jpg")
query_image = Image.open("test-006.jpg")
tokenizer.padding_side = "left"
lang_x = tokenizer(
    ["<image>Question: What kind of animals in the image? Answer: Dog. <|endofchunk|><image>Question: What kind of animals in the image? Answer: Dog. <|endofchunk|><image>Question: What kind of animals in the image? Answer:"],
    return_tensors="pt",
)
vision_x = [image_processor(demo_image_one).unsqueeze(0), image_processor(demo_image_two).unsqueeze(0), image_processor(query_image).unsqueeze(0)]
vision_x = torch.cat(vision_x, dim=0)
# add frame and batch dimensions: (batch, num_images, frames, channels, height, width)
vision_x = vision_x.unsqueeze(1).unsqueeze(0)
# load data to gpus
vision_x = vision_x.to(device).half()
print(vision_x.device)
input_ids=lang_x["input_ids"]
attention_mask = lang_x["attention_mask"]
input_ids = input_ids.to(device)
attention_mask = attention_mask.to(device)

generated_text = flamingo.generate(
    vision_x=vision_x,
    lang_x=input_ids,
    attention_mask=attention_mask,
    max_new_tokens=20,
    num_beams=3,
)
print(tokenizer.decode(generated_text[0]))
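The decoded string contains the full prompt followed by the newly generated tokens. A minimal post-processing sketch, assuming the answer follows the last "Answer:" marker and ends at the next <|endofchunk|> token, could look like this:

# Minimal post-processing sketch (assumption: the answer follows the last "Answer:"
# marker and ends at the first <|endofchunk|> token or the end of the string).
decoded = tokenizer.decode(generated_text[0])
answer = decoded.rsplit("Answer:", 1)[-1]
answer = answer.split("<|endofchunk|>", 1)[0].strip()
print(answer)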

Diverse Demonstration Retrieval Methods

An example evaluation script is at open_flamingo/scripts/run_vqav2.sh, as follows:

DEVICE=0 # gpu number

RANDOM_ID="VQAv2_Result_file_name"
RESULTS_FILE="results_${RANDOM_ID}.json"

export MASTER_ADDR='localhost'
export MASTER_PORT='10000' 

python open_flamingo/eval/evaluate_vqa.py \
    --retrieval_name $RANDOM_ID \
    --lm_path "Path for mpt-7b" \
    --lm_tokenizer_path "Path for mpt-7b" \
    --checkpoint_path "Path for OpenFlamingo-9B-vitl-mpt7b checkpoint.pt" \
    --vision_encoder_path "ViT-L-14" \
    --vision_encoder_pretrained 'openai' \
    --device $DEVICE \
    --vqav2_train_image_dir_path "mscoco2014/train2014/"  \
    --vqav2_train_questions_json_path "vqav2/v2_OpenEnded_mscoco_train2014_questions.json" \
    --vqav2_train_annotations_json_path  "vqav2/v2_mscoco_train2014_annotations.json" \
    --vqav2_test_image_dir_path "mscoco2014/val2014/" \
    --vqav2_test_questions_json_path "vqav2/v2_OpenEnded_mscoco_val2014_questions.json" \
    --vqav2_test_annotations_json_path "vqav2/v2_mscoco_val2014_annotations.json" \
    --results_file $RESULTS_FILE \
    --num_samples 5000 \
    --shots 4 8 16 32 \
    --num_trials 1 \
    --seed 5 \
    --batch_size 1 \
    --cross_attn_every_n_layers 4 \
    --precision fp16 \
    --dataset_name vqav2 \
    --eval_vqav2
echo "evaluation complete! results written to $RESULTS_FILE"

Change the parameters according to your needs and run the following commands; the evaluation can run on a single RTX 3090 GPU with FP16 precision.

cd OFv2_ICL_VQA
bash open_flamingo/scripts/run_vqav2.sh

Before running the script above, you have to run retrieval/img2img_clip_style.py to generate the "validation_xxx.npy" retrieval results file, which is used in eval/eval_datasets.py. For more details, see the answer in this issue.
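retrieval/img2img_clip_style.py implements this step. As a rough illustration of the general idea, not the repository's exact implementation, a CLIP-based image-to-image retrieval could encode the query and training images and keep the indices of the most similar training samples; the image path lists and the output file name below are placeholders:

# Rough sketch of CLIP-based image-to-image retrieval (not the repository's exact
# implementation). train_image_paths, query_image_paths and the output file name
# are placeholders.
import clip
import numpy as np
import torch
from PIL import Image

device = "cuda:0" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-L/14", device=device)

def encode_images(paths):
    feats = []
    with torch.no_grad():
        for p in paths:
            image = preprocess(Image.open(p)).unsqueeze(0).to(device)
            feat = model.encode_image(image)
            feats.append(feat / feat.norm(dim=-1, keepdim=True))
    return torch.cat(feats, dim=0)

train_feats = encode_images(train_image_paths)   # training (demonstration pool) images
query_feats = encode_images(query_image_paths)   # validation (query) images

# Cosine similarity between every query image and every training image,
# keeping the indices of the most similar training samples per query.
similarity = query_feats @ train_feats.T
topk_indices = similarity.topk(k=32, dim=-1).indices.cpu().numpy()
np.save("validation_xxx.npy", topk_indices)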

If you need to use different retrieval methods, you can change the parameters of control_signals in open_flamingo/eval/evaluate_vqa.py.

control_signals = {"clip": True, # If clip == False, random sampling (RS) is used instead of CLIP-based retrieval.
                   "retrieval_name": args.retrieval_name, # The results file name.
                   "retrieval_type": "SI", # Name of the retrieval method: SI/SQ/SI_Q...
                   "specification": False, # Add the instruction.
                   "declaration": False, # Add the declarative sentence to the demonstrations.
                   "add_declaration": False, # Add the declarative sentence to the demonstrations.
                   "gauss": True, # Blur the query image.
                   "None_ICE": False, # In the 0-shot setting, whether to provide demonstrations to the model.
                   "order": "order"} # The order of the demonstrations: order/reverse.

Acknowledgments

This code is based on the second version of OpenFlamingo.
