### LLaVA-MORE: 
Enhancing Visual Instruction Tuning with LLaMA 3.1 

In [1]:
## fix warning
## "bash: /miniconda/envs/llava-more/lib/libtinfo.so.6: no version information available (required by bash)"
## conda install -c conda-forge ncurses

In [2]:
## fix widgets
#!conda install -c conda-forge ipywidgets -y
#!jupyter nbextension enable --py widgetsnbextension

In [19]:
#!pip install flash-attn --no-build-isolation
#!pip install ipywidgets
#!pip install transformers==4.44.2

# from huggingface_hub import notebook_login
# notebook_login()

## Evaluation method 1: run the `llava/eval/run_llava.py` script

In [20]:
# %%bash
# export PYTHONPATH=.
# export HF_TOKEN="token"
# # tokenizer_model_path (local)
# export TOKENIZER_PATH="model/LLaVA_MORE-llama_3_1-8B-finetuning"

# time python -u llava/eval/run_llava.py

In [5]:
!python --version

Python 3.10.14


In [6]:
# Report the versions of the two core libraries this notebook was run with.
import torch
import transformers

print("transformers = ", transformers.__version__)
print("torch =", torch.__version__)

transformers =  4.44.2
torch = 2.3.1


## Evaluation method 2: inline (in-notebook) inference

In [7]:
#!pip install python-dotenv
# Load the `dotenv` IPython extension (shipped with python-dotenv, see the
# install hint above) and invoke its `%dotenv` line magic.
# NOTE(review): presumably this reads a local `.env` file (e.g. an HF token)
# into the process environment — confirm which variables are expected.
%load_ext dotenv
%dotenv

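Model: `aimagelab/LLaVA_MORE-llama_3_1-8B-finetuning` (Hugging Face Hub ID; the cells below load a local copy from `./model/LLaVA_MORE-llama_3_1-8B-finetuning`).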

In [8]:
import os
import json
import warnings
import shutil

from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
import torch
from llava.model import *
from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN


def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda",
                          use_flash_attn=False, mlp_path=None, **kwargs):
    """Load a tokenizer, a (LLaVA or plain causal-LM) model and its image processor.

    The branch taken is selected purely by substring matching on ``model_name``
    (``'llava'``, ``'lora'``, ``'mpt'``, ``'mistral'``) combined with whether
    ``model_base`` is provided.

    Args:
        model_path: Checkpoint location passed to ``from_pretrained`` (and, for
            LoRA / projector-only checkpoints, the place where the extra weight
            files are looked up).
        model_base: Base model to load before applying LoRA or projector
            weights, or ``None`` to load ``model_path`` directly.
        model_name: Name used only for the substring checks described above.
        load_8bit: Quantize with bitsandbytes 8-bit. Takes precedence over
            ``load_4bit``.
        load_4bit: Quantize with bitsandbytes 4-bit (NF4, double quantization,
            float16 compute).
        device_map: Forwarded to ``from_pretrained``; also used to place the
            vision tower when it is not ``'auto'``.
        device: When different from ``"cuda"``, the ``device_map`` forwarded to
            ``from_pretrained`` is replaced by ``{"": device}``.
        use_flash_attn: Request the ``flash_attention_2`` attention
            implementation.
        mlp_path: Optional path to a file with multimodal projector weights
            that are loaded into ``model.model.mm_projector``.
        **kwargs: Extra keyword arguments forwarded to ``from_pretrained``.

    Returns:
        A tuple ``(tokenizer, model, image_processor, context_len)``.
        ``image_processor`` is ``None`` unless ``'llava'`` is in ``model_name``
        or ``mlp_path`` is given. ``context_len`` is
        ``model.config.max_sequence_length`` when present, otherwise 2048.

    Raises:
        ValueError: If a LoRA checkpoint is requested without ``model_base`` and
            the base model name cannot be read from ``config.json``.
    """
    kwargs = {"device_map": device_map, **kwargs}

    if device != "cuda":
        kwargs['device_map'] = {"": device}

    # Quantization is requested exclusively through `quantization_config`:
    # the bare `load_in_8bit` / `load_in_4bit` kwargs are deprecated, and recent
    # transformers releases raise a ValueError when one of them is passed
    # together with `quantization_config` (which the 4-bit path used to do).
    if load_8bit:
        kwargs['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True)
    elif load_4bit:
        kwargs['quantization_config'] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4'
        )
    else:
        kwargs['torch_dtype'] = torch.float16

    if use_flash_attn:
        kwargs['attn_implementation'] = 'flash_attention_2'

    if 'llava' in model_name.lower():
        # Load LLaVA model
        if 'lora' in model_name.lower() and model_base is None:
            warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.')
            config_path = os.path.join(model_path, 'config.json')
            try:
                with open(config_path) as f:
                    configuration = json.load(f)
                model_base = configuration['_name_or_path']
            except (OSError, KeyError, TypeError, ValueError) as exc:
                # Missing/unreadable file, invalid JSON (a ValueError subclass)
                # or a config without `_name_or_path`. Interrupts are no longer
                # swallowed, and the original cause stays attached.
                raise ValueError('Cannot find the model name in the configuration file. Please provide the `model_base` argument.') from exc

        if 'lora' in model_name.lower() and model_base is not None:
            from llava.model.language_model.llava_llama import LlavaConfig
            lora_cfg_pretrained = LlavaConfig.from_pretrained(model_path)
            tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
            print('Loading LLaVA from base model...')
            model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
            token_num, token_dim = model.lm_head.out_features, model.lm_head.in_features
            # Re-allocate head/embedding storage if the loaded weight does not
            # match the sizes advertised by the linear layer.
            if model.lm_head.weight.shape[0] != token_num:
                model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
                model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))

            print('Loading additional LLaVA weights...')
            # NOTE: torch.load unpickles the file; only use trusted checkpoints.
            if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
                non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
            else:
                # this is probably from HF Hub
                from huggingface_hub import hf_hub_download

                def load_from_hf(repo_id, filename, subfolder=None):
                    cache_file = hf_hub_download(
                        repo_id=repo_id,
                        filename=filename,
                        subfolder=subfolder)
                    return torch.load(cache_file, map_location='cpu')
                non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
            # Strip the 'base_model.' prefix (11 characters), then, if keys are
            # still nested under 'model.model.', strip one leading 'model.'.
            non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
            if any(k.startswith('model.model.') for k in non_lora_trainables):
                non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
            model.load_state_dict(non_lora_trainables, strict=False)

            from peft import PeftModel
            print('Loading LoRA weights...')
            model = PeftModel.from_pretrained(model, model_path)
            print('Merging LoRA weights...')
            model = model.merge_and_unload()
            print('Model is loaded...')
        elif model_base is not None:
            # this may be mm projector only
            print('Loading LLaVA from base model...')
            if 'mpt' in model_name.lower():
                if not os.path.isfile(os.path.join(model_path, 'configuration_mpt.py')):
                    shutil.copyfile(os.path.join(model_base, 'configuration_mpt.py'), os.path.join(model_path, 'configuration_mpt.py'))
                tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
                cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
                model = LlavaMptForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
            else:
                tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
                cfg_pretrained = AutoConfig.from_pretrained(model_path)
                model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)

            mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
            mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
            model.load_state_dict(mm_projector_weights, strict=False)
        else:
            if 'mpt' in model_name.lower():
                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
                model = LlavaMptForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
            elif 'mistral' in model_name.lower():
                tokenizer = AutoTokenizer.from_pretrained(model_path)
                model = LlavaMistralForCausalLM.from_pretrained(
                    model_path,
                    low_cpu_mem_usage=True,
                    **kwargs
                )
            else:
                # some old checkpoints may not have the siglip parameter in configuration file
                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
                model = LlavaLlamaForCausalLM.from_pretrained(
                    model_path,
                    low_cpu_mem_usage=True,
                    **kwargs
                )
    else:
        # Load language model
        if model_base is not None:
            # PEFT model
            from peft import PeftModel
            tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
            model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs)
            print(f"Loading LoRA weights from {model_path}")
            model = PeftModel.from_pretrained(model, model_path)
            print("Merging weights")
            model = model.merge_and_unload()
            print('Convert to FP16...')
            model.to(torch.float16)
        else:
            if 'mpt' in model_name.lower():
                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
                model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs)
            else:
                tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
                model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)

    image_processor = None

    if 'llava' in model_name.lower() or mlp_path is not None:
        # Register the optional image tokens requested by the config and grow
        # the embedding table to the new vocabulary size.
        mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
        mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
        if mm_use_im_patch_token:
            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
        if mm_use_im_start_end:
            tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
        model.resize_token_embeddings(len(tokenizer))

        # TODO: revisit vision-tower selection. An earlier draft picked
        # SigLIP / S2 tower classes from `model.config.mm_vision_tower` and
        # `model.config.s2`; for now the tower attached to the model is used.
        vision_tower = model.get_vision_tower()

        if not vision_tower.is_loaded:
            vision_tower.load_model(device_map=device_map)
        if device_map != 'auto':
            vision_tower.to(device=device_map, dtype=torch.float16)
        image_processor = vision_tower.image_processor

    if hasattr(model.config, "max_sequence_length"):
        context_len = model.config.max_sequence_length
    else:
        context_len = 2048

    if mlp_path is not None:
        print('Loading mm projector weights...')
        # Load on CPU so the file can be read regardless of the device it was
        # saved from; load_state_dict copies into the existing parameters.
        mm_projector_weights = torch.load(mlp_path, map_location='cpu')
        # The first four tensors of the file are mapped, in order, onto the
        # projector's state-dict keys below.
        source_keys = list(mm_projector_weights.keys())
        new_keys = ['0.weight', '0.bias', '2.weight', '2.bias']
        new_dict = {new_key: mm_projector_weights[source_keys[idx]] for idx, new_key in enumerate(new_keys)}

        model.model.mm_projector.load_state_dict(new_dict)

    return tokenizer, model, image_processor, context_len

In [17]:
##%tb
import argparse
import torch
import requests
from PIL import Image
from io import BytesIO
import re

from llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
    IMAGE_PLACEHOLDER,
)
from llava.conversation import conv_templates, SeparatorStyle
##from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import (
    process_images,
    tokenizer_image_token,
    get_model_name_from_path,
)

def image_parser(args):
    """Split ``args.image_file`` on the separator ``args.sep``.

    Returns:
        The list of entries produced by ``str.split``.
    """
    separator = args.sep
    return args.image_file.split(separator)


def load_image(image_file, timeout=30):
    """Load an image from an HTTP(S) URL or a local path and convert it to RGB.

    Args:
        image_file: Either a URL starting with ``http://`` / ``https://`` or a
            path accepted by ``PIL.Image.open``.
        timeout: Seconds to wait for the HTTP request before giving up; only
            used for URLs.

    Returns:
        The image converted to ``"RGB"`` mode.

    Raises:
        requests.RequestException: On network failure, timeout, or a non-2xx
            HTTP status (previously an error page would be handed to PIL and
            fail with an unrelated decoding error).
    """
    # Match the full scheme so that local files whose names merely start with
    # "http" are not mistaken for URLs.
    if image_file.startswith(("http://", "https://")):
        response = requests.get(image_file, timeout=timeout)
        response.raise_for_status()
        source = BytesIO(response.content)
    else:
        source = image_file

    # The context manager releases the underlying file handle once the pixel
    # data has been copied by convert().
    with Image.open(source) as image:
        return image.convert("RGB")


def load_images(image_files):
    """Load every entry of ``image_files`` with ``load_image``, preserving order.

    Returns:
        A new list with one loaded image per input entry.
    """
    return [load_image(path) for path in image_files]

def eval_model(args):
    """Load the model, run one image-grounded generation and print the answer.

    Args:
        args: Object exposing the attributes read below:
            ``model_path``, ``model_base``, ``query``, ``conv_mode``,
            ``image_file``, ``sep``, ``temperature``, ``top_p``, ``num_beams``
            and ``max_new_tokens``. An optional boolean ``load_8bit``
            (default ``True``, the previously hard-coded value) controls 8-bit
            loading. ``args.conv_mode`` is overwritten with the inferred mode
            when it is ``None`` or already equal to it.

    Returns:
        None. The model config and the decoded generation are printed.
    """
    # Model
    disable_torch_init()

    # The name is deliberately pinned to 'llava' so that the loader always takes
    # its LLaVA branch regardless of how the checkpoint directory is called.
    # (The name derived from `args.model_path` used to be computed and then
    # immediately discarded.)
    model_name = 'llava'

    tokenizer, model, image_processor, context_len = load_pretrained_model(
        args.model_path, args.model_base, model_name, load_8bit=getattr(args, "load_8bit", True)
    )

    print(model.config)

    # Insert the image token(s) into the query: replace an explicit placeholder
    # when present, otherwise prepend them on their own line.
    qs = args.query
    image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
    if IMAGE_PLACEHOLDER in qs:
        if model.config.mm_use_im_start_end:
            qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs)
        else:
            qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs)
    else:
        if model.config.mm_use_im_start_end:
            qs = image_token_se + "\n" + qs
        else:
            qs = DEFAULT_IMAGE_TOKEN + "\n" + qs

    # Infer a conversation template from the model name; an explicit
    # `args.conv_mode` wins over the inferred one (with a warning).
    if "llama-2" in model_name.lower():
        conv_mode = "llava_llama_2"
    elif "mistral" in model_name.lower():
        conv_mode = "mistral_instruct"
    elif "v1.6-34b" in model_name.lower():
        conv_mode = "chatml_direct"
    elif "v1" in model_name.lower():
        conv_mode = "llava_v1"
    elif "mpt" in model_name.lower():
        conv_mode = "mpt"
    else:
        conv_mode = "llava_v0"

    if args.conv_mode is not None and conv_mode != args.conv_mode:
        print(
            "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
                conv_mode, args.conv_mode, args.conv_mode
            )
        )
    else:
        args.conv_mode = conv_mode

    conv = conv_templates[args.conv_mode].copy()
    conv.append_message(conv.roles[0], qs)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()
    # NOTE(review): the tokenizer is attached after the prompt was built, as in
    # the original code — confirm whether the template needs it earlier.
    conv.tokenizer = tokenizer

    image_files = image_parser(args)
    images = load_images(image_files)
    image_sizes = [x.size for x in images]
    images_tensor = process_images(
        images,
        image_processor,
        model.config
    ).to(model.device, dtype=torch.float16)

    # Place the prompt on the same device as the model (and the image tensor)
    # instead of the default CUDA device, so the two inputs can never disagree.
    input_ids = (
        tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
        .unsqueeze(0)
        .to(model.device)
    )

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=images_tensor,
            image_sizes=image_sizes,
            do_sample=args.temperature > 0,  # temperature 0 => no sampling
            temperature=args.temperature,
            top_p=args.top_p,
            num_beams=args.num_beams,
            max_new_tokens=args.max_new_tokens,
            use_cache=True,
        )

    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
    print(outputs)

In [18]:
from dataclasses import dataclass
from typing import Optional


@dataclass
class Inputs():
    """Default settings consumed by ``eval_model`` (a stand-in for CLI arguments)."""

    # Checkpoint location handed to the model loader.
    model_path: str = "./model/LLaVA_MORE-llama_3_1-8B-finetuning"
    # Optional base model (a path/identifier string); None loads model_path directly.
    model_base: Optional[str] = None
    # One image path/URL, or several joined by `sep`.
    #image_file: str = "https://farm2.staticflickr.com/1168/4723652147_ae14813f08_z.jpg"
    image_file: str = "https://makereceipt.com/images/ItemizedBarcode.jpg"
    # Text prompt sent together with the image.
    query: str = "Describe this image."
    # Conversation template key looked up in `conv_templates`.
    conv_mode: str = "llama_3_1"
    # Separator used to split `image_file` into individual entries.
    sep: str = ","
    # Sampling temperature; sampling is only enabled when this is > 0.
    temperature: float = 0.2
    # Nucleus-sampling probability mass passed to generate(); None leaves it unset.
    top_p: Optional[float] = None
    # Beam count passed to generate().
    num_beams: int = 1
    # Upper bound on newly generated tokens.
    max_new_tokens: int = 128
    
# Build the default settings and echo the conversation template and the
# checkpoint path that will be used.
args = Inputs()
print(f"conversation mode: {args.conv_mode}")
print(f"model name: {args.model_path}")
    
# Load the model, run generation for the configured image/query, and print the answer.
eval_model(args)

conversation mode: llama_3_1
model name: ./model/LLaVA_MORE-llama_3_1-8B-finetuning


The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Some weights of the model checkpoint at ./model/LLaVA_MORE-llama_3_1-8B-finetuning were not used when initializing LlavaLlamaForCausalLM: ['model.vision_tower.vision_tower.vision_model.embeddings.class_embedding', 'model.vision_tower.vision_tower.vision_model.embeddings.patch_embedding.weight', 'model.vision_tower.vision_tower.vision_model.embeddings.position_embedding.weight', 'model.vision_tower.vision_tower.vision_model.encoder.layers.0.layer_norm1.bias', 'model.vision_tower.vision_tower.vision_model.encoder.layers.0.layer_norm1.weight', 'model.vision_tower.vision_tower.vision_model.encoder.layers.0.layer_norm2.bias', 'model.vision_tower.vision_tower.vision_model.encoder.layers.0.layer_norm2.weight', 'model.vision_tower.vision_tower.vision_model.encoder.layers.0.mlp.fc1.bias', 'model.vision_tower.vision_tower.vision_model.encoder.layers.0.mlp.fc1.weight', 'model.vision_tower.vision_tower.vision_model.encoder.layers.0.mlp.fc2.bias', 'model.vision_tower.vision_tower.vision_model.encod

LlavaConfig {
  "_name_or_path": "./model/LLaVA_MORE-llama_3_1-8B-finetuning",
  "architectures": [
    "LlavaLlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "eos_token_id": [
    128001,
    128008,
    128009
  ],
  "freeze_mm_mlp_adapter": false,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "image_aspect_ratio": "pad",
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "max_position_embeddings": 131072,
  "mlp_bias": false,
  "mm_hidden_size": 1024,
  "mm_patch_merge_type": "flat",
  "mm_projector_lr": null,
  "mm_projector_type": "mlp2x_gelu",
  "mm_use_im_patch_token": false,
  "mm_use_im_start_end": false,
  "mm_vision_select_feature": "patch",
  "mm_vision_select_layer": -2,
  "mm_vision_tower": "openai/clip-vit-large-patch14-336",
  "model_type": "llava_llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "pretraining_tp": 1,
  "quantization_config": {
    "_load_in_

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


The image features a receipt for a purchase made at a cafe. The receipt is for a total of $8.75, with a tax of $1.00 and a tip of $2.00. The receipt is dated January 2, 2019, and the cafe is located at 111 Green Street, New York, NY. 

The receipt lists the items purchased, including almond milk and bottled water. The receipt also shows the name of the host, Maggie.


In [15]:
# from transformers import LlavaForConditionalGeneration, LlavaConfig, CLIPVisionConfig, LlamaConfig

# # Initializing a CLIP-vision config

# vision_config = CLIPVisionConfig()

# # Initializing a Llama config

# text_config = LlamaConfig()

# # Initializing a Llava llava-1.5-7b style configuration

# configuration = LlavaConfig(vision_config, text_config)

# # Initializing a model from the llava-1.5-7b style configuration

# model = LlavaForConditionalGeneration(configuration)

# # Accessing the model configuration

# configuration = model.config

In [16]:
# from PIL import Image
# import requests
# from transformers import AutoProcessor, LlavaForConditionalGeneration

# #model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
# #processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

# #aimagelab/LLaVA_MORE-llama_3_1-8B-finetuning
# model = LlavaForConditionalGeneration.from_pretrained("aimagelab/LLaVA_MORE-llama_3_1-8B-finetuning")
# processor = AutoProcessor.from_pretrained("aimagelab/LLaVA_MORE-llama_3_1-8B-finetuning")

# prompt = "USER: <image>\nWhat's the content of the image? ASSISTANT:"
# url = "https://www.ilankelman.org/stopsigns/australia.jpg"
# image = Image.open(requests.get(url, stream=True).raw)

# inputs = processor(text=prompt, images=image, return_tensors="pt")

# # Generate
# generate_ids = model.generate(**inputs, max_new_tokens=15)
# processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]