In [None]:
import torch
import os
import random
import sys

# Put the parent directory on the import search path exactly once
# (presumably so the project-local `llava` package imported below resolves).
module_path = os.path.abspath(os.path.join('..'))
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

import numpy as np
from tqdm import tqdm

from datasets import load_dataset, concatenate_datasets
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path

from argparse import ArgumentParser

from llava.data_utils.data_utils import load_yaml, construct_prompt, save_json, process_single_sample, CAT_SHORT2LONG
from llava.data_utils.model_utils import call_llava_engine_df, llava_image_processor
from llava.data_utils.eval_utils import parse_multi_choice_response, parse_open_response
from llava.data_utils.set_seed import set_seed

# Resolve the compute device once. It is always a ``torch.device`` (never a
# bare string) so later code sees the same type on GPU and CPU-only hosts.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fixed seed for repeatable runs (set_seed is a project helper; presumably it
# seeds the Python/NumPy/torch RNGs -- confirm in llava.data_utils.set_seed).
set_seed(42)

# Hooks selecting the LLaVA-specific inference and image-preprocessing
# functions; no separate processor object is used in this notebook.
processor = None
call_model_engine = call_llava_engine_df
vis_process_func = llava_image_processor

In [None]:
# Enable IPython's autoreload extension in mode 2, so edits to imported
# modules are picked up before each cell executes without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Load the MMMU validation split of every subject listed in CAT_SHORT2LONG.
sub_dataset_list = [
    load_dataset("MMMU/MMMU", subject_name, split="validation")
    for subject_name in CAT_SHORT2LONG.values()
]

# Merge the per-subject splits into one dataset used by the cells below.
dataset = concatenate_datasets(sub_dataset_list)

In [None]:
# Load the LLaVA checkpoint with 4-bit loading enabled (load_4bit=True).
# The checkpoint id is bound once so the name lookup and the loader agree.
model_path = "liuhaotian/llava-v1.5-13b"
model_name = get_model_name_from_path(model_path)

# The loader returns four values; the last one is not used in this notebook.
# The second positional argument is passed as None, as in the original call.
tokenizer, model, vis_processors, _ = load_pretrained_model(
    model_path,
    None,
    model_name,
    load_4bit=True,
)

In [None]:
# Collect the rows that have a non-empty explanation and are rated "Easy".
# The scan covers at most the first 900 rows (the original hard-coded bound)
# but never runs past the end of the dataset, and each row is fetched only
# once instead of once per field access plus once more for the append.
arr_easy_expl = []
for i in range(min(900, len(dataset))):
    row = dataset[i]
    if row["explanation"] != '' and row["topic_difficulty"] == "Easy":
        arr_easy_expl.append(row)

In [None]:
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.data_utils.model_utils import tokenizer_image_token, deal_with_prompt
import gc
# Release Python garbage and cached CUDA blocks before running generation.
gc.collect()
torch.cuda.empty_cache()

# Index (into arr_easy_expl) of the example to run.
n = 6
if not -len(arr_easy_expl) <= n < len(arr_easy_expl):
    raise IndexError(
        f"n={n} is out of range for arr_easy_expl (size {len(arr_easy_expl)})"
    )
sample = process_single_sample(arr_easy_expl[n])

# Preprocess the image (when the sample has one) and move it to the device.
if sample["image"]:
    sample["image"] = vis_process_func(sample["image"], vis_processors).to(device)

# The generate() call below builds its `images` input from this tensor, so a
# sample without an image cannot be run here. Fail with a clear message
# instead of the opaque TypeError that torch.zeros_like(None) would raise.
image = sample["image"]
if not torch.is_tensor(image):
    raise ValueError(
        f"arr_easy_expl[{n}] has no image tensor; choose a sample with an image"
    )

# Build the user prompt: question, options, then the answering instruction.
prompt = sample["question"] + sample["options"] + "Select the correct answer and reason in one short sentence why it is correct."

# Wrap the prompt in the "vicuna_v1" conversation template: one message for
# roles[0] carrying the prompt and an empty roles[1] slot for the model to fill.
conv = conv_templates["vicuna_v1"].copy()
conv.append_message(conv.roles[0], prompt)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
# Project helper; adjusts the prompt according to mm_use_im_start_end
# (presumably inserts the image placeholder token -- confirm in model_utils).
prompt = deal_with_prompt(prompt, model.config.mm_use_im_start_end)

# Tokenize, add a batch dimension, and place on the same device as the image.
input_ids = tokenizer_image_token(
    prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
).unsqueeze(0).to(device)

# NOTE(review): the model is fed an all-zero tensor shaped like the image, not
# the image itself. This looks like a deliberate "blank image" ablation; if the
# real image is intended, pass `image` in place of torch.zeros_like(image).
blank_image = torch.zeros_like(image).unsqueeze(0).half().to(device)

# do_sample=True together with num_beams=5 requests sampling within beam
# search; the KV cache is disabled (use_cache=False) as in the original call.
output_ids = model.generate(
    input_ids,
    images=blank_image,
    do_sample=True,
    temperature=1,
    top_p=None,
    num_beams=5,
    max_new_tokens=128,
    use_cache=False,
)

# Decode the first (only) sequence of the batch, dropping special tokens.
response = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]