In [16]:
import os
import torch
import requests


from PIL import Image
from io import BytesIO

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from transformers import AutoTokenizer
from llava.mm_utils import tokenizer_image_token
from llava.model.language_model.llava_llama import LlavaLlamaForCausalLM

# Target device for inference; later cells place input tensors on a CUDA GPU.
device: str = "cuda"

In [25]:
prompt = "What is in this image?"
image_url = "https://buffer.com/cdn-cgi/image/w=1000,fit=contain,q=90,f=auto/library/content/images/size/w600/2023/10/free-images.jpg"

# Load image from url. Bound the wait so a stalled download cannot hang the kernel, and
# fail loudly on HTTP errors instead of handing an HTML error page to PIL.
response = requests.get(image_url, timeout=30)
response.raise_for_status()
# The image processor produces 3-channel tensors; convert() normalizes palette/RGBA/grayscale
# inputs to RGB so they are handled consistently.
image_data = Image.open(BytesIO(response.content)).convert("RGB")

# Instantiate Model and its encoders

In [3]:
model_name = "liuhaotian/llava-v1.5-7b"

# Load the LLaVA checkpoint in half precision (fp16 weights take half the memory of fp32,
# so the 7B model fits comfortably on a single A100). Move it to `device` so it sits on
# the same device as the input tensors built in later cells.
model = LlavaLlamaForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
).to(device)

# Text Encoder (slow tokenizer, use_fast=False)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

# Visual Encoder: its weights are loaded by an explicit load_model() call. Put it on the
# same device as the language model and cast it to fp16 to match the rest of the model.
vision_tower = model.get_vision_tower()
vision_tower.load_model(device_map=device)
vision_tower.to(device=device, dtype=torch.float16)
image_processor = vision_tower.image_processor

You are using a model of type llava to instantiate a model of type llava_llama. This is not supported for all configurations of models and can yield errors.
  return self.fget.__get__(instance, owner)()
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:02<00:00,  1.29s/it]


# Forward Pass

## Generate Input Embeddings

In [26]:
# Wrap the raw question in the chat template that Vicuna (LLaVA's language model) expects:
# a system preamble, then "USER: <image>\n<question>", then an empty "ASSISTANT:" turn.
# NOTE: this rebinds `prompt`, so re-running the cell wraps the template a second time;
# the helper it calls is defined in a later cell.
prompt = prepare_prompt_into_expected_format(prompt=prompt)
prompt

"A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nWhat is in this image? ASSISTANT:"

In [27]:
# Tokenize the prompt into token ids (not embeddings yet: the model's embedding layer maps
# ids to vectors). This is exactly what you would do for a text-only LLM, except that the
# "<image>" placeholder is mapped to IMAGE_TOKEN_INDEX instead of being tokenized as text.
# unsqueeze(0) adds a batch dimension, giving shape [1, seq_len].
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(device)
input_ids.shape

torch.Size([1, 49])

In [23]:
# Helper used by the prompt-formatting cell: builds the conversation string LLaVA expects.
def prepare_prompt_into_expected_format(prompt: str, conv_mode: str = "llava_v1") -> str:
    """Wrap a user question in a single-turn LLaVA conversation template.

    The user turn is the image placeholder token, a newline, then the question.
    The assistant turn is appended empty, so the rendered prompt ends with the
    assistant role and the model continues from there.

    Args:
        prompt: The raw user question, e.g. "What is in this image?".
        conv_mode: Key into ``conv_templates``. Defaults to "llava_v1", the
            template this notebook uses with llava-v1.5-7b.

    Returns:
        The fully rendered prompt string.

    Raises:
        KeyError: If ``conv_mode`` is not a known conversation template.
    """
    # Copy first: append_message mutates the conversation, and the template is shared.
    conv = conv_templates[conv_mode].copy()

    # Just one turn, so the image token is always prepended to the question.
    user_turn = DEFAULT_IMAGE_TOKEN + '\n' + prompt

    conv.append_message(conv.roles[0], user_turn)
    # A None message leaves the assistant slot open for generation.
    conv.append_message(conv.roles[1], None)
    return conv.get_prompt()

In [22]:
# Preprocess the image into the pixel-value tensor the CLIP vision tower consumes
# (shape [1, 3, H, W]). These are not embeddings yet: the vision tower and the projector
# turn them into embeddings later. Cast to fp16 to match the half-precision model.
# NOTE(review): LLaVA v1.5's own pipeline may pad images to a square before this step,
# depending on the checkpoint's `image_aspect_ratio` setting. Calling preprocess()
# directly skips that, so non-square images may be cropped. Confirm against the model config.
image_encodings = image_processor.preprocess(image_data, return_tensors='pt')['pixel_values'].to(device, dtype=torch.float16)
image_encodings.shape

torch.Size([1, 3, 336, 336])

This is a very important step: we have used the image processor that accompanies the visual encoder to turn the image into a tensor of pixel values. The visual encoder will embed this tensor, and the result will be projected into a space shared with the textual embeddings, from which the model generates its output. For LLaVA, as for many other multimodal models, this visual encoder is CLIP.

## CLIP



## Forward Pass

Now, if you have any experience with neural networks, you're probably scratching your head right now. The two inputs we just built have completely different shapes (`[1, 49]` for the text and `[1, 3, 336, 336]` for the image), yet you would expect them to share the same dimensions before they can be combined. For the uninitiated, deep learning is heavily based on matrix multiplication. These matrix multiplications can be heavily optimized to take advantage of the speed of GPUs. However, they require the inputs to have compatible dimensions. Attempting to multiply tensors with mismatched dimensions will lead to the dreaded
```
RuntimeError: mat1 and mat2 shapes cannot be multiplied (2x3 and 4x5)
```
Trying to batch (stack) tensors of different sizes fails in a similar way:
```
RuntimeError: stack expects each tensor to be equal size, but got [3, 224, 224] at entry 0 and [3, 224, 336] at entry 3
```
So what is going on?

Well, this is where the projection matrix (the main piece of the multimodal puzzle) comes into play.

In [24]:
# Display the model's module hierarchy (token embedding, the stack of LlamaDecoderLayers, ...).
# The printed output is long and appears truncated below.
model

LlavaLlamaForCausalLM(
  (model): LlavaLlamaModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=0)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaR

In [47]:
# Round-trip a plain-text sentence through the tokenizer already loaded from `model_name`
# in the model-loading cell; reloading the same checkpoint here would be redundant.
# NOTE: this rebinds `input_ids`, replacing the multimodal prompt ids built earlier.
input_ids = tokenizer("Hello, how are you?", return_tensors='pt').input_ids
input_ids.shape

torch.Size([1, 7])

In [46]:
# squeeze(0) drops the leading batch dimension ([1, 7] -> [7]), so decode() receives a
# single flat sequence of token ids rather than a batch of one.
tokenizer.decode(input_ids.squeeze(0), skip_special_tokens=True)

'Hello, how are you?'

In [48]:
# Shape once the batch dimension is removed: [1, seq_len] -> [seq_len].
input_ids.squeeze(dim=0).size()

torch.Size([7])