# Visual Large Language Models

> Models used
>> 1. PaliGemma
>> 2. LLaVA-OneVision
>> 3. Qwen2-VL

In [None]:
# Install transformers only when it is missing, so re-running this cell is cheap.
try:
    import transformers
    print("Transformers is already installed.")
except ImportError:
    print("Transformers not found. Installing...")
    # IPython shell escape: only valid inside a notebook, not in a plain .py file.
    !pip install transformers

In [None]:
# Extra dependencies for the cells below (IPython shell escapes).
# install LLAVA utils (LLaVA-NeXT provides the `llava.*` modules imported later)
!pip install git+https://github.com/LLaVA-VL/LLaVA-NeXT.git
# install utilities for the Qwen model (provides `process_vision_info`)
!pip install qwen-vl-utils
# install `datasets` (provides `load_dataset` for the platinum-bench dataset)
!pip install datasets

In [None]:
import torch
def memory_stats():
    """Print CUDA memory statistics for the current device, in MiB.

    Total / Available come from ``torch.cuda.mem_get_info()`` (driver view of
    the device); Allocated / Reserved come from PyTorch's caching allocator
    (``memory_allocated`` / ``memory_reserved``). When CUDA is unavailable a
    short notice is printed instead of raising.
    """
    if not torch.cuda.is_available():
        print("GPU memory: CUDA is not available.")
        return
    mib = 1024 ** 2
    free_mem, total = torch.cuda.mem_get_info()
    print(
        f"GPU memory Total: [{total / mib:.2f}] Available: [{free_mem / mib:.2f}] "
        f"Allocated: [{torch.cuda.memory_allocated() / mib:.2f}] "
        f"Reserved: [{torch.cuda.memory_reserved() / mib:.2f}]"
    )


memory_stats()

In [None]:
# On an A100 GPU we could use FlashAttention (FA); this notebook, however, runs on a P100/T4.
# !pip install flash-attn --no-build-isolation

# Accessing PaliGemma requires a Hugging Face login/token
### Follow these instructions to create an access token: https://huggingface.co/docs/hub/en/security-tokens

In [None]:
from huggingface_hub import notebook_login, login
# Read the token from the HF_TOKEN environment variable (e.g. a Colab/Kaggle
# secret) instead of pasting it into the notebook, where it would be saved and
# shared along with the code.
import os

hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(hf_token)
else:
    # Interactive prompt: paste a token created per the instructions above.
    notebook_login()

In [None]:
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX
from llava.conversation import conv_templates, SeparatorStyle

from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info

import matplotlib.pyplot as plt

from PIL import Image
import requests
import copy
import sys
import warnings
import gc
import copy

def getCocoImageURL(imagePath):
    """Map a COCO image path to its public val2014 download URL.

    Only the last "/"-separated component (the file name) is kept, so the full
    COCO validation set never needs to be downloaded locally.

    Args:
        imagePath: path whose final component is a COCO file name.

    Returns:
        The URL under http://images.cocodataset.org/val2014/.
    """
    *_, file_name = imagePath.split("/")
    return "http://images.cocodataset.org/val2014/" + file_name

def showImage(imageURL, timeout=30):
    """Download an image and draw it on the current matplotlib axes.

    Args:
        imageURL: HTTP(S) URL of the image.
        timeout: seconds to wait on the server before giving up; without a
            timeout a stalled server blocks the notebook indefinitely.

    Raises:
        requests.HTTPError: if the server responds with an error status, rather
            than letting PIL fail while decoding an error page.
    """
    response = requests.get(imageURL, stream=True, timeout=timeout)
    response.raise_for_status()
    image = Image.open(response.raw)
    plt.imshow(image)


import gc

def clean_mem(dsOnly=False):
    """Delete the notebook-global ``model`` and ``ds`` and release cached CUDA memory.

    Removing the global names drops this module's references so ``gc`` can
    reclaim the objects; ``torch.cuda.empty_cache()`` then releases blocks
    cached by PyTorch's allocator.

    Args:
        dsOnly: if True, delete only ``ds`` and keep ``model`` loaded.
    """
    global model, ds

    if not dsOnly:
        try:
            del model
        except NameError:
            pass  # name not bound: never created or already deleted

    try:
        del ds
    except NameError:
        pass  # name not bound: never created or already deleted

    gc.collect()
    torch.cuda.empty_cache()

# Prepare Inference functions for all 3 models

In [None]:
class VqAModels:
    """Common wrapper around three vision-language models for (V)QA inference.

    After construction, ``self.inference(image_url, question, isVQA=..., max_tokens=...)``
    is bound to the selected model's inference method. LLaVA and Qwen2-VL return
    a list of decoded strings; PaliGemma returns a single string.

    Args:
        model_name: "llava" (lmms-lab/llava-onevision-qwen2-0.5b-ov),
            "paliGemma" (google/paligemma-3b-mix-224) or
            "qwen-vl" (Qwen/Qwen2-VL-2B-Instruct).
        max_img_size: only used by Qwen2-VL; forwarded as ``max_image_size``,
            which sets the processor's ``max_pixels = max_img_size * 28 * 28``.
        printMemUsage: print GPU memory before/after loading (LLaVA, PaliGemma).

    Raises:
        ValueError: for an unsupported ``model_name``.
    """

    # Seconds to wait for an image download; without a timeout a stalled
    # server would hang the whole benchmark loop.
    image_timeout = 30

    def __init__(self, model_name="llava", max_img_size=1280, printMemUsage=False) -> None:
        self.max_img_size = max_img_size
        self.device = "cuda"
        self.device_map = "auto"
        self.printMemUsage = printMemUsage

        if model_name == "llava":
            # NOTE(review): the label says v1.5 but the checkpoint loaded is LLaVA-OneVision.
            self.model_name = "llava_v1_5"
            self.load_llava_model()
            self.inference = self.llava_inference
        elif model_name == "paliGemma":
            self.model_name = "paliGemma"
            self.load_paliGemma_model()
            self.inference = self.paliGemma_inference
        elif model_name == "qwen-vl":
            self.model_name = "qwen_vl"
            self.load_qwen_vl_model(max_image_size=self.max_img_size)
            self.inference = self.qwen_vl_inference
        else:
            raise ValueError("Possible values are: llava,paliGemma,qwen-vl")

    def _fetch_image(self, image_url):
        """Download ``image_url`` and open it with PIL.

        Raises a ``requests`` HTTP error on an error status (instead of handing
        an error page to PIL) and a ``requests`` timeout error if the server stalls.
        """
        response = requests.get(image_url, stream=True, timeout=self.image_timeout)
        response.raise_for_status()
        return Image.open(response.raw)

    def load_llava_model(self):
        """Load LLaVA-OneVision (Qwen2, 0.5B) through the LLaVA-NeXT model builder."""
        # NOTE: this silences *all* Python warnings for the rest of the session,
        # not only those raised while loading LLaVA.
        warnings.filterwarnings("ignore")
        pretrained = "lmms-lab/llava-onevision-qwen2-0.5b-ov"
        model_name = "llava_qwen"

        if self.printMemUsage:
            print("\n\nbefore loading model")
            memory_stats()

        # attn_implementation=None disables flash_attn, which the Colab T4 does not support.
        self.tokenizer, self.model, self.image_processor, self.max_length = load_pretrained_model(
            pretrained, None, model_name, device_map=self.device_map, attn_implementation=None
        )
        self.model.eval()

        if self.printMemUsage:
            print(f"\n\nLoaded model: {pretrained}")
            memory_stats()

    def load_paliGemma_model(self):
        """Load PaliGemma 3B (mix, 224) in bfloat16 on ``cuda:0``."""
        model_id = "google/paligemma-3b-mix-224"
        self.device = "cuda:0"
        dtype = torch.bfloat16

        if self.printMemUsage:
            print("\n\nbefore loading model")
            memory_stats()

        self.model = PaliGemmaForConditionalGeneration.from_pretrained(
            model_id,
            torch_dtype=dtype,
            device_map=self.device,
            revision="bfloat16",
        ).eval()

        self.processor = AutoProcessor.from_pretrained(model_id)

        if self.printMemUsage:
            print("\n\nAfter loading model")
            memory_stats()

    def load_qwen_vl_model(self, model_name="Qwen/Qwen2-VL-2B-Instruct", max_image_size=1280):
        """Load Qwen2-VL and its processor.

        The processor receives ``min_pixels=256*28*28`` and
        ``max_pixels=max_image_size*28*28`` to bound the image size it uses.
        """
        # Load the model on the available device(s).
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            model_name, torch_dtype="auto", device_map="auto"
        )

        min_pixels = 256 * 28 * 28
        max_pixels = max_image_size * 28 * 28
        self.processor = AutoProcessor.from_pretrained(model_name, min_pixels=min_pixels, max_pixels=max_pixels)
        self.device = "cuda"
        self.device_map = "auto"

    def llava_inference(self, image_url, question, max_tokens=4096, isVQA=False):
        """Answer ``question`` with LLaVA-OneVision; returns a list of decoded strings.

        ``isVQA`` is accepted for interface compatibility but not used: a blank
        white 224x224 image stands in whenever ``image_url`` is None.
        NOTE: the default ``max_tokens`` (4096) differs from the other models (1000).
        """
        if image_url is None:
            image = Image.new("RGB", (224, 224), (255, 255, 255))
        else:
            image = self._fetch_image(image_url)
        image_tensor = process_images([image], self.image_processor, self.model.config)
        image_tensor = [_image.to(dtype=torch.float16, device=self.device) for _image in image_tensor]

        conv_template = "qwen_1_5"  # chat template must match the checkpoint family
        question = DEFAULT_IMAGE_TOKEN + f"\n{question}"
        conv = copy.deepcopy(conv_templates[conv_template])
        conv.append_message(conv.roles[0], question)
        conv.append_message(conv.roles[1], None)
        prompt_question = conv.get_prompt()

        input_ids = tokenizer_image_token(
            prompt_question, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
        ).unsqueeze(0).to(self.device)
        image_sizes = [image.size]

        cont = self.model.generate(
            input_ids,
            images=image_tensor,
            image_sizes=image_sizes,
            do_sample=False,
            temperature=0,
            max_new_tokens=max_tokens,
        )
        return self.tokenizer.batch_decode(cont, skip_special_tokens=True)

    def paliGemma_inference(self, image_url, question, isVQA=False, max_tokens=1000):
        """Answer ``question`` with PaliGemma; returns the decoded string.

        When ``isVQA`` is False a black 224x224 placeholder image is used,
        because the prompt always starts with an ``<image>`` token.
        """
        if not isVQA:
            image = Image.new("RGB", (224, 224), (0, 0, 0))
        else:
            image = self._fetch_image(image_url)

        prompt = f"<image>{question}"
        model_inputs = self.processor(text=prompt, images=image, return_tensors="pt").to(self.model.device)

        input_len = model_inputs["input_ids"].shape[-1]

        with torch.inference_mode():
            # Greedy decoding keeps benchmark scores reproducible; do_sample=True
            # would make them vary from run to run. LLaVA also uses do_sample=False.
            generation = self.model.generate(**model_inputs, max_new_tokens=max_tokens, do_sample=False)
            # Keep only the newly generated tokens (drop the prompt tokens).
            generation = generation[0][input_len:]
            return self.processor.decode(generation, skip_special_tokens=True)

    def qwen_vl_inference(self, image_url, question, isVQA=False, max_tokens=1000):
        """Answer ``question`` with Qwen2-VL; returns a list of decoded strings.

        When ``isVQA`` is False the message is text-only; otherwise the image URL
        is passed in the chat message and resolved by ``process_vision_info``.
        """
        content = [{"type": "text", "text": question}]
        if isVQA:
            content.insert(0, {"type": "image", "image": image_url})
        messages = [{"role": "user", "content": content}]

        # Preparation for inference
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(self.device)

        # Generate, then strip the prompt tokens from each output sequence.
        generated_ids = self.model.generate(**inputs, max_new_tokens=max_tokens)
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        return self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

# Sample Inference

In [None]:
# Warm-up: download every model once and report the memory usage of each.
for modelName in ["paliGemma", "llava", "qwen-vl"]:
    # Print the name itself; `{modelName}` would build and print a one-element set.
    print("#" * 10, modelName, "#" * 10)
    # Release the previously loaded model (and any dataset) before the next load.
    for i in range(2): clean_mem()
    model = VqAModels(model_name=modelName, max_img_size=720, printMemUsage=True)

In [None]:
# Preview the sample COCO image that the next cell sends to every model.
showImage("http://images.cocodataset.org/val2017/000000039769.jpg")

In [None]:
import time
# Loop-invariant inputs: the same image and question go to every model.
ImageURL = "http://images.cocodataset.org/val2017/000000039769.jpg"
Question = "Describe the provided image in detail"

for modelName in ["paliGemma", "llava", "qwen-vl"]:
    for i in range(2): clean_mem()
    model = VqAModels(model_name=modelName, max_img_size=720)

    # Time only the inference call, not model loading.
    Start = time.perf_counter()
    answer = model.inference(image_url=ImageURL, question=Question, isVQA=True, max_tokens=1000)
    End = time.perf_counter()
    # Print the name itself; `{modelName}` would print a one-element set.
    print("#" * 10, modelName, "#" * 10)
    print(f"User:{Question}\n{modelName}:{answer}\nTook {End-Start} seconds")
    print("#" * 30)

# load benchmark dataset

In [None]:
from datasets import load_dataset

In [None]:
# Parsing functions taken from the original repo: https://github.com/MadryLab/platinum-benchmarks/blob/main/src/run_benchmark.py
import re

def parse_fn_vqa(response):
    """Reduce a free-form VQA answer to 'yes', 'no' or 'Parsing error'.

    Resolution order:
      1. exactly one of the whole words "yes" / "no" appears (case-insensitive)
         -> that word;
      2. otherwise the first word of the first line, if it is yes/no;
      3. otherwise the last word of the last line, if it is yes/no;
      4. otherwise "Parsing error".
    """
    def _normalize(token):
        return token.lower().strip().strip('.').strip(",")

    lines = response.strip("\n").split("\n")
    first_word = _normalize(lines[0].split(" ")[0])
    last_word = _normalize(lines[-1].split(" ")[-1])

    has_yes = re.search(r'\byes\b', response, flags=re.IGNORECASE) is not None
    has_no = re.search(r'\bno\b', response, flags=re.IGNORECASE) is not None

    # Unambiguous: exactly one of the two words occurs.
    if has_yes != has_no:
        return 'yes' if has_yes else 'no'

    for candidate in (first_word, last_word):
        if candidate in ('yes', 'no'):
            return candidate
    return "Parsing error"

def check_prediction(prediction, platinum_target, dataset_name):
    """Return True when ``prediction`` matches the platinum target.

    For numeric datasets the first target and the prediction are compared as
    floats (so "42" matches "42.0"); for every other dataset, and for a
    "Parsing error" prediction, plain membership in ``platinum_target`` is used.
    ``float()`` may raise ValueError for a non-numeric prediction on a numeric dataset.
    """
    numeric_datasets = (
        'math_eval__multiarith', 'math_eval__singleop', 'math_eval__singleq',
        'gsm8k', 'svamp', 'multiarith', 'singleop', 'singleq',
        'bbh_object_counting',
    )
    if dataset_name not in numeric_datasets or prediction == 'Parsing error':
        return prediction in platinum_target
    return float(platinum_target[0]) == float(prediction)

def get_parse_fn(parsing_strategy):
    """Return the answer parser for a platinum-bench ``parsing_strategy``.

    Every strategy parser is wrapped by ``create_parse_fn``, which first
    rewrites the raw model output into an ``"Answer: ..."`` string (using the
    content of a LaTeX ``\\boxed{...}`` when one is found, otherwise the last
    line when no "answer:" marker is present) and then applies the strategy
    parser to it.

    Args:
        parsing_strategy: one of 'math', 'multiple_choice',
            'bbh_multiple_choice', 'text' or 'squad'.

    Returns:
        A callable ``parse_fn(output)`` returning the parsed answer string.

    Raises:
        ValueError: for an unknown ``parsing_strategy``. The returned parsers can
        themselves raise (e.g. AttributeError from ``.group()`` when no number or
        option is found), so callers should be ready to catch that.
    """
    def parse_fn_math(output):
        """Used for singleop, singleq, multiarith, gsm8k, and svamp"""
        # First number-like token after the last "answer: " (commas, '*', '#' removed),
        # with a trailing ".0..." stripped so "42.0" becomes "42".
        return re.sub(r"\.0+$", "", (re.search(r'-?[0-9.]*[0-9]', output.replace('*','').replace('#','').lower().split('answer: ')[-1].replace(',', '')).group()))

    def parse_fn_multiple_choice(output):
        """Used for mmlu math and winograd schema challenge"""
        #return output.replace('*', '').lower().split("answer: ")[-1].replace(".", "").strip()[0:1].lower()

        x = output.replace('*', '').lower().split("answer: ")[-1].replace(".", "").strip()

        # Prefer the first character inside \boxed{...}; otherwise the first character of the answer.
        pattern = r'\\boxed\{([^}]+)\}'
        match = re.search(pattern, x)
        if match:
            return match.group(1)[0:1]
        else:
            return x[0:1]

    def parse_bbh_multiple_choice(output):
        """Used for BBH multiple choice questions, where the answer is in the form (A)"""
        result = output.replace('*', '').replace('#', '').lower().split('answer: ')[-1].replace('.', '').replace('\'', '').replace('\"', '').strip().lower()
        # First "(x)" option in the answer section; raises AttributeError if there is none.
        result = re.search(r'\([a-z]\)', result).group(0)
        return result

    def parse_fn_text(output):
        """Used by DROP and hotpotqa, where the answer is a string"""
        # First line after the last "answer: ", lowercased, with markup, quotes,
        # commas and periods removed and anything from "}" onward cut off.
        return (output.replace("#","").replace("*","").replace("\"", "").replace('\xa0', ' ')
                      .lower().split("answer: ")[-1].split('\n')[0].replace(",", "")
                      .replace(".","").split("}")[0].strip())

    def parse_fn_squad(output):
        """Like text parsing, but explicitly handles the case when there is text after n/a"""
        output_clean = parse_fn_text(output)
        if output_clean.startswith('n/a '):
            return 'n/a'
        return output_clean

    def create_parse_fn(specific_parsing_fn):
        """Wrap ``specific_parsing_fn`` with the shared "Answer: ..." normalization."""

        def parse_fn(output):
            # tex_pattern = r'\\boxed\{([^{}]+)\}|\\boxed\{\\text\{([^}]+)\}\}'
            # group(2) is the boxed content, with an optional \text{ wrapper skipped by group(1).
            tex_pattern = r'\\boxed\{(\\text\{)?([^\\{}]+)\}'

            # If answer is on the last line as expected, run as usual
            if "answer:" in output.lower().replace("*", ""):
                # If the answer is wrapped in latex (e.g., \boxed{...}), extract the content
                answer_section = output.lower().split("answer: ")[-1]
                if re.search(tex_pattern, answer_section):
                    match = re.search(tex_pattern, answer_section).group(2)
                    output = "Answer: " + match
            elif re.search(tex_pattern, output):
                # If the answer is not on the last line, try to recover by looking for a box
                output = "Answer: " + re.search(tex_pattern, output).group(2)
            else:
                # Otherwise, just return the last line
                last_line = output.strip("\n").split("\n")[-1].lower()
                output = "Answer: " + last_line
            return specific_parsing_fn(output)

        return parse_fn


    if parsing_strategy == 'math':
        return create_parse_fn(parse_fn_math)
    elif parsing_strategy == 'multiple_choice':
        return create_parse_fn(parse_fn_multiple_choice)
    elif parsing_strategy == 'bbh_multiple_choice':
        return create_parse_fn(parse_bbh_multiple_choice)
    elif parsing_strategy == 'text':
        return create_parse_fn(parse_fn_text)
    elif parsing_strategy == 'squad':
        return create_parse_fn(parse_fn_squad)
    else:
        raise ValueError(f"Invalid parsing strategy: {parsing_strategy}")

In [None]:
from tqdm import tqdm

def eval_func(model, ds, isVQA=False, isPali=False, prompt_type="platinum_prompt"):
    """Run ``model`` over every row of a platinum-bench split and score its answers.

    Args:
        model: object exposing ``inference(image_url=..., question=..., isVQA=...)``
            (e.g. ``VqAModels``).
        ds: sized iterable of row dicts with at least ``cleaning_status``,
            ``platinum_target`` and ``prompt_type`` keys; VQA rows also carry
            ``image_path``.
        isVQA: score with the yes/no VQA parser instead of the row's platinum
            parsing strategy (numeric comparison as for GSM8K).
        isPali: prefix every question with "QA en, " for PaliGemma.
        prompt_type: row field used as the prompt.

    Returns:
        ``{"ds_data": {index: per-row record}, "summary": {"total", "error", "accuracy"}}``.
        ``summary`` is always present; ``accuracy`` is 0.0 for an empty ``ds``.
    """
    total_count = len(ds)
    error_count = 0
    results = {"ds_data": {}}

    for counter, ds_instance in enumerate(tqdm(ds)):
        cleaning_status = ds_instance["cleaning_status"]
        platinum_target = ds_instance["platinum_target"]
        Question = ds_instance[prompt_type]
        if isPali:
            Question = f"QA en, {Question}"

        # Rows without an image (presumably text-only subsets such as gsm8k) get
        # no URL; each model then falls back to a placeholder or text-only input.
        try:
            image_path = ds_instance["image_path"]
        except KeyError:
            image_path = None
        ImageURL = getCocoImageURL(image_path) if image_path else None

        answer = model.inference(image_url=ImageURL, question=Question, isVQA=isVQA)

        if isVQA:
            parse_fn = parse_fn_vqa
        else:
            try:
                parsing_strategy = ds_instance["platinum_parsing_strategy"]
            except KeyError:
                # Fallback for the misspelled column name.
                parsing_strategy = ds_instance["platinum_parsing_stratagy"]
            parse_fn = get_parse_fn(parsing_strategy)

        try:
            # LLaVA / Qwen2-VL return a list of decoded strings; score the first.
            if isinstance(answer, list):
                answer = answer[0]

            prediction = parse_fn(answer)

            if isVQA:
                correct = prediction in platinum_target
            else:
                correct = check_prediction(prediction, platinum_target=platinum_target, dataset_name="gsm8k")
        except Exception:
            # Deliberate best-effort: any answer that cannot be parsed or compared
            # is recorded and counted as wrong instead of aborting the run.
            prediction = 'parsing error'
            correct = False

        results["ds_data"][counter] = {
            "raw_answer": answer,
            "parsed answer": prediction,
            "platinum_target": platinum_target,
            "correct": correct,
            "prompt": ds_instance[prompt_type],
            "prompt_type": prompt_type,
            "cleaning_status": cleaning_status,
        }

        if not correct:
            error_count += 1

    # Computed once after the loop so it also exists (and cannot divide by zero)
    # for an empty dataset.
    results["summary"] = {
        "total": total_count,
        "error": error_count,
        "accuracy": 1 - error_count / total_count if total_count else 0.0,
    }

    print(f"total : {total_count} error: {error_count}")
    return results


# Eval paliGemma

In [None]:
# Results store: all_results[dataset][model] -> eval_func output.
all_results = {
    ds_key: {model_key: {} for model_key in ("llava", "pali", "qwen2")}
    for ds_key in ("vqa", "gsm8k")
}

In [None]:
# Evaluate PaliGemma on both platinum-bench subsets.
memory_stats()
for _ in range(2):
    clean_mem()
ModelName = "pali"
model = VqAModels(model_name="paliGemma", max_img_size=720)
memory_stats()

subsets = (("vqa", True), ("gsm8k", False))
for ds_name, isVQA in subsets:
    clean_mem(dsOnly=True)
    ds = load_dataset("madrylab/platinum-bench", name=ds_name, split="test")
    # Keep only questions that were not rejected during benchmark cleaning.
    ds = ds.filter(lambda row: row["cleaning_status"] != "rejected")
    # NOTE(review): isPali=False, so eval_func's "QA en, " prefix is not applied
    # to PaliGemma here — confirm this is intended.
    results = eval_func(model, ds, isVQA=isVQA, isPali=False)
    all_results[ds_name][ModelName] = copy.deepcopy(results)
    print(ds_name, results["summary"])
    clean_mem(dsOnly=True)

In [None]:
# Display PaliGemma's summaries (total / error / accuracy) for VQA and GSM8K.
all_results["vqa"]["pali"]["summary"], all_results["gsm8k"]["pali"]["summary"]

# Eval LLAVA

In [None]:
# Evaluate LLaVA-OneVision on both platinum-bench subsets.
memory_stats()
for _ in range(2):
    clean_mem()
ModelName = "llava"
model = VqAModels(model_name=ModelName, max_img_size=720)
memory_stats()

subsets = (("vqa", True), ("gsm8k", False))
for ds_name, isVQA in subsets:
    clean_mem(dsOnly=True)
    ds = load_dataset("madrylab/platinum-bench", name=ds_name, split="test")
    # Keep only questions that were not rejected during benchmark cleaning.
    ds = ds.filter(lambda row: row["cleaning_status"] != "rejected")
    results = eval_func(model, ds, isVQA=isVQA, isPali=False)
    all_results[ds_name][ModelName] = copy.deepcopy(results)
    print(ds_name, results["summary"])
    clean_mem(dsOnly=True)

In [None]:
# Display the VQA and GSM8K summaries for ModelName (set to "llava" by the LLaVA eval cell).
all_results["vqa"][ModelName]["summary"], all_results["gsm8k"][ModelName]["summary"]

# Eval Qwen

In [None]:
# Evaluate Qwen2-VL on both platinum-bench subsets.
memory_stats()
for _ in range(2):
    clean_mem()
model = VqAModels(model_name="qwen-vl", max_img_size=720)
memory_stats()

ModelName = "qwen2"
subsets = (("vqa", True), ("gsm8k", False))
for ds_name, isVQA in subsets:
    clean_mem(dsOnly=True)
    ds = load_dataset("madrylab/platinum-bench", name=ds_name, split="test")
    # Keep only questions that were not rejected during benchmark cleaning.
    ds = ds.filter(lambda row: row["cleaning_status"] != "rejected")
    results = eval_func(model, ds, isVQA=isVQA, isPali=False)
    all_results[ds_name][ModelName] = copy.deepcopy(results)
    print(ds_name, results["summary"])
    clean_mem(dsOnly=True)

In [None]:
# Display the VQA and GSM8K summaries for ModelName (set to "qwen2" by the Qwen eval cell).
all_results["vqa"][ModelName]["summary"], all_results["gsm8k"][ModelName]["summary"]

# Summary

### Some example output 

In [None]:
# First per-question record (raw answer, parsed answer, target, correctness, ...) for PaliGemma on VQA.
all_results["vqa"]["pali"]["ds_data"][0]

In [None]:
# First per-question record (raw answer, parsed answer, target, correctness, ...) for Qwen2-VL on GSM8K.
all_results["gsm8k"]["qwen2"]["ds_data"][0]

In [None]:
# Collect data[dataset][model] = {"accuracy": ...} for plotting and print every summary.
data = {}
for ds_name_, ds_data in all_results.items():
    per_model = data.setdefault(ds_name_, {})
    for model_name, model_data in ds_data.items():
        summary = model_data["summary"]
        per_model.setdefault(model_name, {"accuracy": summary["accuracy"]})
        print(f"Dataset: [{ds_name_}] Model: [{model_name}] Eval: {summary}")

In [None]:
# plot eval data
import matplotlib.pyplot as plt
import numpy as np

# Grouped bar chart: one group per dataset, one bar per model.
datasets = list(data.keys())
# Keep first-seen model order; iterating a set of strings gives a different
# order (and thus bar positions/colours) from run to run.
models = list(dict.fromkeys(m for per_model in data.values() for m in per_model))

# Set width for bars
bar_width = 0.2
x = np.arange(len(datasets))

fig, ax = plt.subplots(figsize=(8, 6))

# A distinct loop name avoids rebinding the notebook-global `model`.
for i, model_label in enumerate(models):
    accuracies = [data[ds][model_label]["accuracy"] for ds in datasets]
    ax.bar(x + i * bar_width, accuracies, bar_width, label=model_label)

ax.set_xlabel("Datasets")
ax.set_ylabel("Accuracy")
ax.set_title("Model Accuracy Comparison Across Datasets and Models")
# Centre each tick under its group of bars for any number of models.
ax.set_xticks(x + bar_width * (len(models) - 1) / 2)
ax.set_xticklabels(datasets)
ax.legend()

plt.show()

In [None]:
# Qwen2-VL image-size sweep: QwenScalingResults[max_img_size][dataset][model] -> eval_func output.
QwenScalingResults = {}

# Qwen with Scaling on VQA

1.1 As the image size increases, inference time increases and accuracy gradually improves.

In [None]:
# Re-load Qwen2-VL at several max_img_size budgets and re-run the VQA subset.
# Set explicitly instead of relying on whichever eval cell ran last; the
# collection cell below reads the results back under "qwen2".
ModelName = "qwen2"
for img_resolution in [256, 360, 480, 720, 1280]:
    memory_stats()
    for i in range(2): clean_mem()
    model = VqAModels(model_name="qwen-vl", max_img_size=img_resolution)
    memory_stats()

    per_resolution = QwenScalingResults.setdefault(img_resolution, {})

    for ds_name, isVQA in [["vqa", True]]:
        per_dataset = per_resolution.setdefault(ds_name, {})

        clean_mem(dsOnly=True)
        ds = load_dataset("madrylab/platinum-bench", name=ds_name, split="test")  # or another subset
        ds = ds.filter(lambda x: x['cleaning_status'] != 'rejected')  # filter out rejected questions
        results = eval_func(model, ds, isVQA=isVQA, isPali=False)
        per_dataset[ModelName] = copy.deepcopy(results)
        print(ds_name, img_resolution, results["summary"])
        clean_mem(dsOnly=True)

In [None]:
# Gather scaling_data[max_img_size]["vqa"][model] = {"accuracy": ...}; only the
# "vqa" entries are collected.
scaling_data = {}
for img_resolution in [256, 360, 480, 720, 1280]:
    print(img_resolution, QwenScalingResults[img_resolution]["vqa"]["qwen2"]["summary"])
    resolution_entry = scaling_data.setdefault(img_resolution, {})

    datasets = QwenScalingResults[img_resolution]
    for ds_name_, ds_data in datasets.items():
        if ds_name_ == "vqa":
            ds_entry = resolution_entry.setdefault(ds_name_, {})
            for model_name, model_data in ds_data.items():
                ds_entry.setdefault(model_name, {"accuracy": model_data["summary"]["accuracy"]})

In [None]:
import matplotlib.pyplot as plt

# Data
# Qwen2-VL VQA accuracy as a function of the max_img_size setting.
resolutions = []
accuracies = []

# Sort by setting so the line is drawn left-to-right regardless of dict order.
for resol, eval_data in sorted(scaling_data.items()):
    resolutions.append(resol)
    accuracies.append(eval_data["vqa"]["qwen2"]["accuracy"])

model_name = "qwen2"

# Plot
plt.figure(figsize=(8, 6))
plt.plot(resolutions, accuracies, marker='o', linestyle='-', label=model_name)
# The swept value is max_img_size, which the processor receives as
# max_pixels = max_img_size * 28 * 28 — not a pixel width or height.
plt.xlabel("Max image size (max_pixels = value × 28 × 28)")
plt.ylabel("Accuracy")
plt.title("Accuracy vs max image size for Model: " + model_name)
plt.xticks(resolutions)
plt.legend()
plt.grid()

plt.show()

# Next: does scaling help on the GSM8K dataset?
## Try larger models, e.g. with 4B, 7B or 13B parameters