In [1]:
import os
import torch
# Send both cache-related environment variables to the same directory,
# given relative to the kernel's working directory.
os.environ.update({
    "HF_DATASETS_CACHE": "./hf_cache",
    "HF_HOME": "./hf_cache",
})

In [2]:
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path
from llava.eval.run_llava import eval_model

# Checkpoint identifier handed to the LLaVA loader.
model_path = "liuhaotian/llava-v1.5-7b"
model_name = get_model_name_from_path(model_path)

# No separate base checkpoint is supplied (model_base=None) and the loader is
# asked for 8-bit weights (load_8bit=True).
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path=model_path,
    model_base=None,
    model_name=model_name,
    load_8bit=True,
)

  from .autonotebook import tqdm as notebook_tqdm
You are using a model of type llava to instantiate a model of type llava_llama. This is not supported for all configurations of models and can yield errors.
  return self.fget.__get__(instance, owner)()
Loading checkpoint shards: 100%|██████████| 2/2 [01:14<00:00, 37.30s/it]


In [5]:
from llava.mm_utils import process_images
from PIL import Image
import requests
from io import BytesIO
def load_image(image_file, timeout=30):
    """Load an image from a local path or an HTTP(S) URL as an RGB PIL image.

    Args:
        image_file: Filesystem path, or a URL starting with ``http://`` or
            ``https://``.
        timeout: Seconds to wait on the remote server before giving up.
            Only used when ``image_file`` is a URL.

    Returns:
        The opened image converted to ``'RGB'`` mode.

    Raises:
        requests.HTTPError: The server answered with a 4xx/5xx status.
        requests.Timeout: The server did not respond within ``timeout``.
    """
    if image_file.startswith(('http://', 'https://')):
        # Without a timeout a stalled server would block the kernel forever.
        response = requests.get(image_file, timeout=timeout)
        # Fail with a clear HTTP error instead of passing an error page to PIL.
        response.raise_for_status()
        source = BytesIO(response.content)
    else:
        source = image_file
    return Image.open(source).convert('RGB')


In [12]:
image_url = "https://patentimages.storage.googleapis.com/23/82/97/d28e8e17f1597a/US20140002137A1-20140102-D00002.png"
image = load_image(image_url)  # PIL image
image_size = image.size

# Run the image through process_images, then move the result to the model's
# device as float16. The result is handled both as a list and as a single
# object exposing .to().
image_tensor = process_images([image], image_processor, model.config)
if type(image_tensor) is not list:
    image_tensor = image_tensor.to(model.device, dtype=torch.float16)
else:
    image_tensor = [t.to(model.device, dtype=torch.float16) for t in image_tensor]

In [13]:
def infer_conv_mode(model_name):
    """Pick a conversation-template key from substrings of the model name.

    Matching is case-insensitive and the first rule that matches wins, so the
    order of the table below is significant (e.g. ``"v1.6-34b"`` must be
    checked before the broader ``"v1"``).

    Args:
        model_name: Model name or path to inspect.

    Returns:
        The template key for the first matching rule, or ``"llava_v0"`` when
        nothing matches.
    """
    lowered = model_name.lower()
    rules = (
        ("llama-2", "llava_llama_2"),
        ("mistral", "mistral_instruct"),
        ("v1.6-34b", "chatml_direct"),
        ("v1", "llava_v1"),
        ("mpt", "mpt"),
    )
    for needle, conv_mode in rules:
        if needle in lowered:
            return conv_mode
    return "llava_v0"

In [20]:
# User instruction sent to the model alongside the image.
# NOTE(review): "with the a 4-character" looks like a typo in the prompt; the
# text is kept exactly as-is here because it is runtime input to the model.
inp = (
    "Use this image to classify the patent into a class in the "
    "International Patent Classification system. "
    "Answer with the a 4-character IPC class, "
    "do not answer with anything else."
)

In [22]:
from llava.conversation import conv_templates, SeparatorStyle
# Start from a fresh copy of the conversation template so re-running this
# cell never appends to a previous conversation.
conv = conv_templates[infer_conv_mode(model_path)].copy()

# Build the user turn in its own variable. `inp` and `image` are deliberately
# left untouched: mutating them here made the cell non-idempotent (a second
# run after reloading the image would prepend a second image token).
if image is not None:
    if model.config.mm_use_im_start_end:
        image_prefix = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
    else:
        image_prefix = DEFAULT_IMAGE_TOKEN
    user_message = image_prefix + '\n' + inp
else:
    user_message = inp

conv.append_message(conv.roles[0], user_message)
# The second role's message is None, i.e. left for the model to fill in.
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

In [23]:
from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path
# Tokenize the prompt via tokenizer_image_token, add a leading axis with
# unsqueeze(0), and place the ids on the model's device.
prompt_token_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
input_ids = prompt_token_ids.unsqueeze(0).to(model.device)

# Separator used as the stop string: sep2 for the TWO style, sep otherwise.
if conv.sep_style == SeparatorStyle.TWO:
    stop_str = conv.sep2
else:
    stop_str = conv.sep
keywords = [stop_str]
# streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

In [24]:
# Run generation inside torch.inference_mode().
with torch.inference_mode():
    output_ids = model.generate(
        input_ids,
        images=image_tensor,
        image_sizes=[image_size],
        # temperature only matters when sampling is on; enable it explicitly
        # instead of depending on the checkpoint's generation defaults.
        do_sample=True,
        temperature=0.7,
        max_new_tokens=2048,
        use_cache=True,
    )

# Decode without special tokens so markers such as <s> / </s> are not part of
# the printed answer.
outputs = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
print(outputs)




<s> G06F19/00</s>
