In [1]:
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path
from llava.eval.run_llava import eval_model
import json
import argparse
import torch
from datasets import load_dataset

from llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
    IMAGE_PLACEHOLDER,
)
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import (
    process_images,
    tokenizer_image_token,
    get_model_name_from_path,
)

from PIL import Image

import requests
from PIL import Image
from io import BytesIO
import re


def image_parser(args):
    """Split ``args.image_file`` on ``args.sep`` and return the pieces as a list."""
    image_spec = args.image_file
    return image_spec.split(args.sep)


def load_image(image_file, timeout=30):
    """Load an image from a URL or a local path and return it in RGB mode.

    Args:
        image_file: Either an ``http://`` / ``https://`` URL or a filesystem
            path understood by ``PIL.Image.open``.
        timeout: Seconds to wait for the HTTP server before giving up.
            Only used for URLs.

    Returns:
        The decoded ``PIL.Image.Image`` converted to ``"RGB"``.

    Raises:
        requests.RequestException: If the download fails, times out, or the
            server answers with an error status.
    """
    # Match on the full scheme so that a local file whose name merely starts
    # with "http" (e.g. "http_cache/img.png") is not mistaken for a URL.
    if image_file.startswith(("http://", "https://")):
        # Without a timeout ``requests`` can block indefinitely.
        response = requests.get(image_file, timeout=timeout)
        # Fail here on 4xx/5xx instead of handing an error page to PIL.
        response.raise_for_status()
        source = BytesIO(response.content)
    else:
        source = image_file
    return Image.open(source).convert("RGB")


def load_images(image_files):
    """Load every entry of *image_files* with ``load_image``, preserving order."""
    return [load_image(path) for path in image_files]


def _infer_conv_mode(model_name):
    """Return the conversation-template key implied by *model_name*.

    The substring checks run in a fixed order; "v1.6-34b" must be tested
    before the broader "v1" so the more specific template wins.
    """
    lowered = model_name.lower()
    if "llama-2" in lowered:
        return "llava_llama_2"
    if "mistral" in lowered:
        return "mistral_instruct"
    if "v1.6-34b" in lowered:
        return "chatml_direct"
    if "v1" in lowered:
        return "llava_v1"
    if "mpt" in lowered:
        return "mpt"
    return "llava_v0"


def eval_model(args, tokenizer, model, image_processor, context_len):
    """Answer ``args.query`` about ``args.image_file`` with an already-loaded model.

    NOTE: this definition shadows the ``eval_model`` imported from
    ``llava.eval.run_llava`` at the top of the file; it takes the loaded
    tokenizer/model/image processor as arguments instead of loading them.

    Args:
        args: Object exposing ``model_path``, ``query``, ``conv_mode``,
            ``image_file``, ``sep``, ``temperature``, ``top_p``, ``num_beams``
            and ``max_new_tokens``. ``args.conv_mode`` is overwritten with the
            inferred mode when it is ``None`` (or already equal to it).
        tokenizer: Tokenizer passed to ``tokenizer_image_token`` and used to
            decode the generated ids.
        model: Model providing ``config``, ``device`` and ``generate``.
        image_processor: Processor forwarded to ``process_images``.
        context_len: Unused; kept so existing callers keep working.

    Returns:
        The first decoded sequence, stripped of surrounding whitespace.
        The same string is also printed.
    """
    # Model
    disable_torch_init()

    model_name = get_model_name_from_path(args.model_path)

    # Pick the image token form once, then either substitute it for the
    # placeholder or prepend it on its own line.
    if model.config.mm_use_im_start_end:
        image_token = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
    else:
        image_token = DEFAULT_IMAGE_TOKEN

    qs = args.query
    if IMAGE_PLACEHOLDER in qs:
        qs = re.sub(IMAGE_PLACEHOLDER, image_token, qs)
    else:
        qs = image_token + "\n" + qs

    conv_mode = _infer_conv_mode(model_name)

    # An explicitly requested mode wins over the inferred one; only warn.
    if args.conv_mode is not None and conv_mode != args.conv_mode:
        print(
            "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
                conv_mode, args.conv_mode, args.conv_mode
            )
        )
    else:
        args.conv_mode = conv_mode

    conv = conv_templates[args.conv_mode].copy()
    conv.append_message(conv.roles[0], qs)
    # ``None`` leaves the second role's turn open for the model to complete.
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    image_files = image_parser(args)
    images = load_images(image_files)
    image_sizes = [x.size for x in images]
    images_tensor = process_images(
        images,
        image_processor,
        model.config
    ).to(model.device, dtype=torch.float16)

    # Place the prompt on the same device as the model (and the image tensor)
    # rather than on whatever the current default CUDA device happens to be.
    input_ids = (
        tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
        .unsqueeze(0)
        .to(model.device)
    )

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=images_tensor,
            image_sizes=image_sizes,
            # Sample only for a positive temperature; otherwise decode greedily/with beams.
            do_sample=args.temperature > 0,
            temperature=args.temperature,
            top_p=args.top_p,
            num_beams=args.num_beams,
            max_new_tokens=args.max_new_tokens,
            use_cache=True,
        )

    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
    print(outputs)
    return outputs
def remove_prefix(text, prefix):
    """Return *text* with a leading *prefix* stripped, or *text* unchanged if absent."""
    return text[len(prefix):] if text.startswith(prefix) else text


def main():
    """Run the model over ``test.json`` and save the generated answers.

    Loads either the LoRA fine-tuned checkpoint or the base model (selected
    by the local ``finetune`` flag), then for every dataset entry sends the
    first ``'human'`` conversation turn, with the leading ``"<image>\\n"``
    prefix removed, together with the entry's image to ``eval_model``.
    The list of answers is written to ``result-7b-warehouse.pt`` with
    ``torch.save`` after each answer.
    """
    model_path = "liuhaotian/llava-v1.5-7b"
    prefix = "<image>\n"
    finetune = True
    if finetune:
        # Using finetune model.
        # model_checkpoint is the local path of the unzipped folder.
        # model_name doesn't have ./
        model_checkpoint = "./llava-v1.5-7b-task-lora-change"
        tokenizer, model, image_processor, context_len = load_pretrained_model(
            model_path=model_checkpoint,
            model_base="liuhaotian/llava-v1.5-7b",
            model_name="llava-v1.5-7b-task-lora-change"
        )
    else:
        # Original unfine-tuned model.
        tokenizer, model, image_processor, context_len = load_pretrained_model(
            model_path=model_path,
            model_base=None,
            model_name=get_model_name_from_path(model_path)
        )

    # Read the JSON as UTF-8 explicitly instead of the locale default.
    with open('test.json', 'r', encoding='utf-8') as file:
        dataset = json.load(file)

    # Loop-invariant: derive the model name once rather than per entry.
    model_name = get_model_name_from_path(model_path)

    result = []
    for entry in dataset:
        image_file = entry['image']  # Extract the image field
        # Iterate through each conversation in the entry
        for conversation in entry['conversations']:
            if conversation['from'] != 'human':
                continue
            prompt = remove_prefix(conversation['value'], prefix)
            # A fresh namespace per query: eval_model writes back to
            # ``conv_mode``, so the object must not be shared between calls.
            args = argparse.Namespace(
                model_path=model_path,
                model_base=None,
                model_name=model_name,
                query=prompt,
                conv_mode=None,
                image_file=image_file,
                sep=",",
                temperature=0,
                top_p=None,
                num_beams=1,
                max_new_tokens=512,
            )
            output = eval_model(args, tokenizer, model, image_processor, context_len)
            result.append(output)
            # The file is rewritten after every answer — presumably so that
            # partial results survive an interrupted run.
            torch.save(result, "result-7b-warehouse.pt")
            # Only the first human turn of each entry is evaluated.
            break


# Entry point: run the evaluation only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()



  from .autonotebook import tqdm as notebook_tqdm


Loading LLaVA from base model...


  return self.fget.__get__(instance, owner)()
Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.11it/s]


Loading additional LLaVA weights...
Loading LoRA weights...
Merging LoRA weights...
Model is loaded...




the man wearing a high visibility jacket is 0.0 inches wide.
the red truck on road is 0.0 inches from the orange location pins in air horizontally.
Incorrect, the grey metal building behind truck is not larger than the wooden pallets on trailer.
the woman in blue dress standing is 0.0 inches tall.
Positioned lower is woman wearing gray shirt.
man wearing a dress shirt is more to the left.
No, the gray metal pipe on ground is further to the viewer.
Incorrect, the gray cartoon character giving thumbs up is not positioned below the cartoon character wearing white gloves.
wooden pallets on floor is above.
man in a white shirt standing is above.
the yellow clipboard with white paper and the black and white book with white writing are 0.0 inches apart vertically.
Incorrect, the cardboard boxes stacked in a car trunk is not on the left side of the open trunk of a car.
In fact, the man in blue uniform loading boxes might be wider or the same width as the stack of cardboard boxes.
No, the grey 