In [1]:
import os
import pickle
import json
from tqdm import tqdm
import os.path as osp
import re
from io import BytesIO
import argparse
import numpy as np

import requests
import torch
from PIL import Image

from llava.constants import (
    DEFAULT_IM_END_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IMAGE_TOKEN,
    IMAGE_PLACEHOLDER,
    IMAGE_TOKEN_INDEX,
)
from llava.conversation import SeparatorStyle, conv_templates
from llava.mm_utils import KeywordsStoppingCriteria, get_model_name_from_path, process_images, tokenizer_image_token
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from eval_datasets import VQADataset
from rice import RICES


# arg_parser = argparse.ArgumentParser()
# arg_parser.add_argument(
#     "--model_name_or_path",
#     type=str,
#     default= "Efficient-Large-Model/VILA1.5-13b",
#     help="vila model name or path"
# )
# arg_parser.add_argument(
#     "--n_shots",
#     type=int,
#     help="number of incontext examples",
# )
# arg_parser.add_argument(
#     "--use_random",
#     action="store_true",
#     help="sample the in-context examples at random instead of retrieving them",
# )
# arg_parser.add_argument(
#     "--n_random",
#     type=int,
#     default=0,
#     help="number of random incontext examples",
# )
# arg_parser.add_argument(
#     "--save_path",
#     type=str,
#     help="where to save model outputs",
# )


# args = arg_parser.parse_args()

# input args
# Hugging Face model id (or local path) of the VILA checkpoint to evaluate.
model_name_or_path = "Efficient-Large-Model/VILA1.5-13b"

# VQAv2 train split: images, questions and annotations.
train_image_dir_path = "/scratch/workspace/asureddy_umass_edu-llm_alignment/dataset/vqa/train2014"
train_questions_path = "/scratch/workspace/asureddy_umass_edu-llm_alignment/dataset/vqa/v2_OpenEnded_mscoco_train2014_questions.json"
train_annotations_path = "/scratch/workspace/asureddy_umass_edu-llm_alignment/dataset/vqa/v2_mscoco_train2014_annotations.json"

# VQAv2 val split: images, questions and annotations.
val_image_dir_path = "/scratch/workspace/asureddy_umass_edu-llm_alignment/dataset/vqa/val2014"
val_questions_path = "/scratch/workspace/asureddy_umass_edu-llm_alignment/dataset/vqa/v2_OpenEnded_mscoco_val2014_questions.json"
val_annotations_path = "/scratch/workspace/asureddy_umass_edu-llm_alignment/dataset/vqa/v2_mscoco_val2014_annotations.json"

# dataset = VQADataset(image_dir_path, questions_path, annotations_path,True, "vqav2")

# Pickled feature cache handed to the retriever; set to "" / None to run without a cache.
rice_cached_features_path = "/scratch/workspace/asureddy_umass_edu-llm_alignment/features-cache/coco_train_2014.pkl"
# The train split supplies in-context demonstrations; the val split supplies the queries.
# NOTE(review): the 4th positional argument is presumably an "is_train" flag — confirm in eval_datasets.
train_dataset = VQADataset(train_image_dir_path, train_questions_path, train_annotations_path,True, "vqav2")
val_dataset = VQADataset(val_image_dir_path, val_questions_path, val_annotations_path,False, "vqav2")

# Bind the name unconditionally so the RICES(...) call below does not raise
# NameError when no cache path is configured.
rice_cached_features = None
if rice_cached_features_path:
    # SECURITY: pickle.load can execute arbitrary code; only load cache files
    # produced by a trusted run of this project.
    with open(rice_cached_features_path, 'rb') as f:
        rice_cached_features = pickle.load(f)

# NOTE(review): positional args are presumably (dataset, device, batch_size) — confirm in rice.py.
retriever = RICES(train_dataset, 'cpu',1, cached_features=rice_cached_features)

model_name = get_model_name_from_path(model_name_or_path)

tokenizer, model, image_processor, context_len = load_pretrained_model(model_name_or_path, model_name, None)
image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN

# conv_mode = "hermes-2" # for vila-40b
# llava_v0 for vila-13b
def get_output_for_query(query, imgs, conv_mode="llava_v1", max_new_tokens=5,
                         temperature=0.2, num_beams=3, top_p=0.95):
    """Generate the model's answer for one multimodal prompt.

    Args:
        query: Prompt text. Every occurrence of ``IMAGE_PLACEHOLDER`` is
            replaced by ``DEFAULT_IMAGE_TOKEN`` before the conversation
            template is applied.
        imgs: Images forwarded to ``process_images``.
            NOTE(review): presumably PIL images in the same order as the
            image tokens in ``query`` — confirm against llava.mm_utils.
        conv_mode: Key into ``conv_templates`` selecting the chat template.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature; sampling is enabled only when > 0.
        num_beams: Beam count passed to ``model.generate``.
        top_p: Nucleus-sampling threshold passed to ``model.generate``.

    Returns:
        The decoded first output sequence, whitespace-stripped and with a
        trailing stop string removed.
    """
    query = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, query)
    conv = conv_templates[conv_mode].copy()
    conv.append_message(conv.roles[0], query)
    # A None assistant message leaves the assistant turn open for generation.
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    images_tensor = process_images(imgs, image_processor, model.config).to(model.device, dtype=torch.float16)
    # Keep the token ids on the same device as the model and the image tensor
    # instead of hard-coding the default CUDA device.
    input_ids = (
        tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
        .unsqueeze(0)
        .to(model.device)
    )
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=[
                images_tensor,
            ],
            do_sample=temperature > 0,
            temperature=temperature,
            top_p=top_p,
            num_beams=num_beams,
            max_new_tokens=max_new_tokens,
            use_cache=True,
            stopping_criteria=[stopping_criteria],
        )

    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
    # Guard against an empty stop string: `outputs[:-0]` would wipe the answer.
    if stop_str and outputs.endswith(stop_str):
        outputs = outputs[: -len(stop_str)]
    return outputs.strip()

def _sample_train_examples(count):
    """Draw `count` random examples from `train_dataset`, without duplicates when possible."""
    pool_size = len(train_dataset)
    # np.random.choice samples WITH replacement by default, which can put the
    # same demonstration into a prompt more than once. Sample without
    # replacement, falling back to replacement only when more examples are
    # requested than the pool holds (so such calls still succeed as before).
    idxs = np.random.choice(pool_size, count, replace=count > pool_size)
    return [train_dataset[idx] for idx in idxs]


def get_n_shot_demonstrations(item, n=2, use_random=True, n_random=0):
    """Select the in-context demonstrations for a single query item.

    Args:
        item: Query example; only ``item['image']`` is read, and only when
            demonstrations are retrieved.
        n: Total number of demonstrations. ``0`` yields no demonstrations.
        use_random: If true, all ``n`` demonstrations are sampled at random
            from ``train_dataset``; otherwise they come from ``retriever``.
        n_random: With ``use_random`` false, how many of the ``n``
            demonstrations are sampled at random; the remaining
            ``n - n_random`` are retrieved and placed after them.
            NOTE(review): values greater than ``n`` pass a negative count to
            ``retriever.find`` — behaviour then depends on RICES; confirm.

    Returns:
        A one-element list whose single entry is the list of demonstrations
        for ``item``.
    """
    if n == 0:
        return [[]]

    if use_random:
        return [_sample_train_examples(n)]

    random_demos = _sample_train_examples(n_random) if n_random else []
    return [random_demos + retriever.find([item['image']], n - n_random)[0]]

def construct_vqa_query(query_items, icl_demonstrs_list):
    """Build few-shot VQA prompts and their matching image lists.

    Each query item is paired positionally with its list of demonstrations.
    Every demonstration contributes a solved "Question / Short Answer"
    segment (using the first entry of its ``answers``), followed by one
    unsolved segment for the query itself. Images are collected in the same
    order as the ``<image>`` markers appear in the prompt.

    NOTE(review): demonstration segments start with a space before
    ``<image>`` while the final query segment does not, so the query's
    ``<image>`` directly abuts the previous answer — confirm this is intended.

    Returns:
        ``(prompts, image_lists)`` — two parallel lists with one entry per
        (query item, demonstrations) pair.
    """
    prompts = []
    image_lists = []
    for target, demos in zip(query_items, icl_demonstrs_list):
        # One pass over the demonstrations yields both the text and the image.
        solved = [
            (
                f" <image> Question: {demo['question']} Short Answer: {demo['answers'][0]}",
                demo['image'],
            )
            for demo in demos
        ]
        prompt_images = [image for _, image in solved]
        prompt_images.append(target['image'])
        prefix = "".join(text for text, _ in solved)
        prompts.append(prefix + f"<image> Question: {target['question']} Short Answer: ")
        image_lists.append(prompt_images)
    return prompts, image_lists

def get_output(item, n= 2, use_random= False, n_random=0, conv_mode="llava_v1", max_new_tokens=5):
    """Run the full few-shot VQA pipeline for one query item.

    Selects the demonstrations, builds the prompt and image list, and
    returns the model's generated answer.

    Args:
        item: Query example, forwarded to ``get_n_shot_demonstrations`` and
            ``construct_vqa_query``.
        n: Total number of in-context demonstrations.
        use_random: Sample demonstrations at random instead of retrieving them.
        n_random: Number of randomly sampled demonstrations when
            ``use_random`` is false.
        conv_mode: Conversation template name forwarded to
            ``get_output_for_query``; the right value depends on the model
            checkpoint, so it is exposed instead of being hard-coded.
        max_new_tokens: Generation length cap forwarded to
            ``get_output_for_query``.

    Returns:
        Whatever ``get_output_for_query`` returns for the constructed prompt.
    """
    icl_demonstrs = get_n_shot_demonstrations(item, n, use_random, n_random)
    querys, im_lists = construct_vqa_query([item], icl_demonstrs)

    # Only one query item was passed in, so only the first prompt/image list exists.
    return get_output_for_query(querys[0], im_lists[0], conv_mode, max_new_tokens)




[2024-11-17 05:44:43,358] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)




Fetching 21 files:   0%|          | 0/21 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]

In [10]:
# Run configuration consumed by the evaluation cell below.
# Total number of in-context demonstrations placed before each query.
n = 4
# When True, all demonstrations are sampled at random rather than retrieved.
use_random = False
# With use_random False: how many of the n demonstrations are sampled at random.
n_random = 0

In [11]:
out_data = {
    "outputs": []
}

# Cap on the number of validation examples evaluated in this run.
max_eval_examples = 100

model_name = model_name_or_path.split("/")[-1]
save_path = f"/home/asureddy_umass_edu/cs682/VILA/results/vqa_exp/{model_name}_{n}-shot"
if use_random:
    save_path += "_random-examples"
if n_random:
    # Leading underscore keeps the shot count and the random count apart
    # (previously produced e.g. "4-shot2_random-examples").
    save_path += f"_{n_random}_random-examples"
save_path += ".json"

# Create the results directory before the slow generation loop, so a missing
# folder cannot cause the outputs to be lost at write time.
os.makedirs(os.path.dirname(save_path), exist_ok=True)

# print(args)
for i in tqdm(range(min(max_eval_examples, len(val_dataset)))):
    out = get_output(val_dataset[i],n,use_random, n_random)
    out_data["outputs"].append(out)

with open(save_path,'w') as f:
    json.dump(out_data, f)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [01:25<00:00,  1.17it/s]


In [2]:
!python vila_e2e_vqa_coco.py --n_shots 8

[2024-11-17 16:27:43,968] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)
Fetching 21 files: 100%|██████████████████████| 21/21 [00:00<00:00, 3344.23it/s]
Loading checkpoint shards: 100%|██████████████████| 6/6 [00:33<00:00,  5.59s/it]
Namespace(model_name_or_path='Efficient-Large-Model/VILA1.5-13b', n_shots=8, use_random=False, n_random=0)
100%|█████████████████████████████████████████| 100/100 [02:17<00:00,  1.37s/it]


In [5]:
!nvidia-smi

Sun Nov 17 20:19:36 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.127.05             Driver Version: 550.127.05     CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-80GB          On  |   00000000:C0:00.0 Off |                    0 |
| N/A   33C    P0             72W /  400W |    7399MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

In [3]:
!python vila_e2e_vqa_coco.py --n_shots 8

[2024-11-17 20:07:40,180] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)
Fetching 17 files: 100%|██████████████████████| 17/17 [00:00<00:00, 2150.34it/s]
Loading checkpoint shards: 100%|██████████████████| 2/2 [00:06<00:00,  3.11s/it]
new vqa_coco vila 3b
Namespace(model_name_or_path='Efficient-Large-Model/VILA1.5-3b', n_shots=8, use_random=False, n_random=0)
100%|█████████████████████████████████████████| 100/100 [01:12<00:00,  1.39it/s]


In [4]:
!python vila_e2e_captioning_coco.py --n_shots 8

[2024-11-17 20:10:06,364] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)
Fetching 17 files: 100%|█████████████████████| 17/17 [00:00<00:00, 18510.69it/s]
Loading checkpoint shards: 100%|██████████████████| 2/2 [00:05<00:00,  2.91s/it]
Namespace(model_name_or_path='Efficient-Large-Model/VILA1.5-3b', n_shots=8, use_random=False, n_random=0)
100%|█████████████████████████████████████████| 100/100 [01:36<00:00,  1.03it/s]
