## Demo of LLaVA model for image description

#### Install LLaVA from official Git repo

In [1]:
# clone https://github.com/haotian-liu/LLaVA and run the following
# cd LLaVA
# ! pip install -e .

#### Install the pinned dependencies

In [None]:
# ! pip install torch==2.1.0 torchvision==0.16.0
# ! pip uninstall bitsandbytes -y
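
After running the installs above, it can help to confirm that the pinned versions are the ones actually active in the kernel. The following is a minimal sketch using only the standard library; `check_pins` is a hypothetical helper introduced here for illustration and is not part of LLaVA.

```python
from importlib.metadata import PackageNotFoundError, version


def check_pins(pins):
    """Return {package: (expected, installed)} for every pin that is not satisfied."""
    mismatches = {}
    for package, expected in pins.items():
        try:
            installed = version(package)
        except PackageNotFoundError:
            installed = None  # package is not installed at all
        # Local build tags such as "2.1.0+cu121" still count as 2.1.0.
        if installed is None or installed.split("+")[0] != expected:
            mismatches[package] = (expected, installed)
    return mismatches


# An empty dict means both pins from the install cell are satisfied.
print(check_pins({"torch": "2.1.0", "torchvision": "0.16.0"}))
```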

#### Import all the necessary packages and code

In [2]:
import torch

from llava.serve.cli import load_image
from llava.constants import IMAGE_TOKEN_INDEX
from llava.model.builder import load_pretrained_model
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria

#### Load the pretrained model and process a given image

In [3]:
model_path = "liuhaotian/llava-v1.5-7b"
model_base = None
device = "cuda"
image_file = "https://llava-vl.github.io/static/images/view.jpg"

model_name = get_model_name_from_path(model_path)
tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, model_base, model_name, device=device)

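# Load the image from the URL, then convert it to a pixel tensor
image = load_image(image_file)
image_tensor = image_processor([image], return_tensors='pt')['pixel_values']
# Move the tensor to the model's device in half precision
image_tensor = image_tensor.to(model.device, dtype=torch.float16)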

#### Generate response based on the input prompt
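
The prompt in the cell below contains an `<image>` placeholder, and `tokenizer_image_token` turns it into a single sentinel id that the model later swaps for image features. The sketch below illustrates that idea only; it is not LLaVA's implementation. `toy_tokenize` and `splice_image_token` are hypothetical names, the toy tokenizer emits fake ids, and the value `-200` is assumed to match `IMAGE_TOKEN_INDEX` in `llava.constants`.

```python
IMAGE_TOKEN_INDEX = -200  # assumption: same value as llava.constants.IMAGE_TOKEN_INDEX
IMAGE_PLACEHOLDER = "<image>"


def toy_tokenize(text):
    """Hypothetical stand-in tokenizer: BOS id 1, then one fake id per word."""
    return [1] + [100 + len(word) for word in text.split()]


def splice_image_token(prompt, tokenize=toy_tokenize, bos_id=1,
                       image_token_index=IMAGE_TOKEN_INDEX):
    """Tokenize the text around each placeholder and join the pieces with the sentinel id."""
    chunks = [tokenize(chunk) for chunk in prompt.split(IMAGE_PLACEHOLDER)]
    input_ids = []
    for i, chunk in enumerate(chunks):
        if i > 0:
            # One sentinel id stands in for the whole image.
            input_ids.append(image_token_index)
            # Keep only the first BOS; later chunks would otherwise repeat it.
            if chunk and chunk[0] == bos_id:
                chunk = chunk[1:]
        input_ids.extend(chunk)
    return input_ids


ids = splice_image_token("USER: <image> Describe this picture. ASSISTANT:")
print(ids)
```

The sentinel is negative so that it cannot collide with a real vocabulary id, which is why the prompt has to go through a dedicated helper instead of a plain tokenizer call.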

In [4]:
prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image> Can you describe this picture in detail?\nASSISTANT:"

input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
stopping_criteria = KeywordsStoppingCriteria(["</s>"], tokenizer, input_ids)

with torch.inference_mode():
    output_ids = model.generate(
        input_ids,
        images=image_tensor,
        do_sample=True,
        temperature=0.2,
        max_new_tokens=512,
        use_cache=True,
        stopping_criteria=[stopping_criteria])

# Skip the prompt tokens so that only the newly generated text is decoded
output = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
print("Output: ", output)

Output:  The image features a pier extending over a large body of water, possibly a lake or a river. The pier is made of wood and has a bench situated on it, providing a place for people to sit and enjoy the view. The surrounding area is filled with trees, creating a serene and peaceful atmosphere.

In the background, there are mountains visible, adding to the picturesque scenery. The pier is located near the water's edge, allowing visitors to appreciate the beauty of the landscape and the tranquility of the water.</s>
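
The decoded text above still ends with the stop string `</s>`. One way to clean it up is a small post-processing step, sketched here; `strip_stop_string` is a hypothetical helper, not a LLaVA function. Passing `skip_special_tokens=True` to `tokenizer.decode` is another option that should have a similar effect.

```python
def strip_stop_string(text, stop_str="</s>"):
    """Remove one trailing stop string and any surrounding whitespace."""
    text = text.strip()
    if text.endswith(stop_str):
        text = text[: -len(stop_str)]
    return text.strip()


raw = "The pier is located near the water's edge.</s>"
clean = strip_stop_string(raw)
print(clean)
```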
