Install and import necessary libraries

In [None]:
!python -m pip install --upgrade pip -q
!pip install matplotlib -q -U

In [None]:
!pip install -q datasets
!pip install transformers -q -U
!pip install -q bitsandbytes sentencepiece accelerate loralib
!pip install -q -U git+https://github.com/huggingface/peft.git
!pip install hf_transfer -q -U
!pip install pickleshare -q

In [None]:
# allows for faster downloading
# Sets HF_HUB_ENABLE_HF_TRANSFER=1 in the kernel's environment; the hf_transfer
# package itself is installed by the pip cell above.
%env HF_HUB_ENABLE_HF_TRANSFER=1

In [None]:
import os 

# Clone the LLaVA repository only when no "LLaVA" directory exists in the
# current working directory, so re-running this cell skips the clone.
if not os.path.isdir("LLaVA"):
    !git clone https://github.com/haotian-liu/LLaVA.git
else:
    print('LLaVA directory already exists. Skipping clone.')

In [None]:
# Switch the kernel's working directory into the checkout; the editable
# installs in the next cells (`pip install -e .`) are relative to it.
%cd LLaVA

In [None]:
# can take up to 5 mins
!pip install -e . -q

In [None]:
!pip install protobuf -q -U
!pip install --upgrade Pillow -q
!pip install -e ".[train]" -q
!pip install flash-attn --no-build-isolation -q

Load the model

In [None]:
# load the model
import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from PIL import Image
import transformers
from transformers import AutoProcessor, Trainer, TrainingArguments, BitsAndBytesConfig
import torchvision.transforms as transforms

In [None]:
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path
from llava.eval.run_llava import eval_model

# Use the GPU when torch reports one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# supported models for inference and training 
model_path = 'liuhaotian/llava-v1.5-7b'

# Derive the model name from the path, then load everything the rest of the
# notebook relies on: tokenizer, model, image processor and context length.
# `model_name` is also used further down to pick the conversation template.
model_name = get_model_name_from_path(model_path)
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path = model_path, 
    model_base = None, 
    model_name = model_name,
    cache_dir='', # it will download the model to the directory we are currently at
    use_flash_attn=True
)

Inference 

In [None]:
# method to test the model's inference
import re 
import torch 
from torchvision.transforms import Compose, ToTensor, Normalize
from torchvision.transforms.functional import to_pil_image, to_tensor
from PIL import Image 
import requests 
from io import BytesIO

from llava.constants import (
    IMAGE_TOKEN_INDEX, 
    DEFAULT_IMAGE_TOKEN, 
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN, 
    IMAGE_PLACEHOLDER
)

from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import (
    process_images, 
    tokenizer_image_token, 
    get_model_name_from_path,
)

# common function to create prompts 
def create_prompt(query, model, model_name=model_name, caption=None):
    """Build the conversation prompt string for a single image query.

    Args:
        query: User text. Every ``IMAGE_PLACEHOLDER`` occurrence is replaced by
            the image token; when there is none, the image token plus a newline
            is put in front of the text.
        model: Object whose ``config.mm_use_im_start_end`` decides whether the
            image token is wrapped in the start/end marker tokens.
        model_name: Name passed to ``infer_conv_mode`` to select the template.
        caption: Optional assistant reply. ``None`` appends an empty assistant
            turn instead.

    Returns:
        The prompt produced by the selected conversation template.
    """
    if model.config.mm_use_im_start_end:
        image_token = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
    else:
        image_token = DEFAULT_IMAGE_TOKEN

    if IMAGE_PLACEHOLDER in query:
        query = re.sub(IMAGE_PLACEHOLDER, image_token, query)
    else:
        query = image_token + "\n" + query

    conversation = conv_templates[infer_conv_mode(model_name)].copy()
    conversation.append_message(conversation.roles[0], query)
    # `caption` is either the assistant text or None, which is exactly what the
    # assistant turn should receive in both cases.
    conversation.append_message(conversation.roles[1], caption)
    return conversation.get_prompt()

# common function to infer conversation mode
def infer_conv_mode(model_name):
    """Map a model name to the key of its conversation template.

    The name is matched case-insensitively against a list of substrings; the
    first hit wins and ``'llava_v0'`` is returned when nothing matches.
    """
    lowered = model_name.lower()
    # Order matters: 'v1.6-34b' has to be tested before the broader 'v1'.
    ordered_rules = (
        ('llama-2', 'llava_llama_2'),
        ('mistral', 'mistral_instruct'),
        ('v1.6-34b', 'chatml_direct'),
        ('v1', 'llava_v1'),
        ('mpt', 'mpt'),
    )
    for marker, conv_mode in ordered_rules:
        if marker in lowered:
            return conv_mode
    return 'llava_v0'

# common function to process images 
def process_and_prepare_images(image_files, image_processor, model, device):
    """Load the given images and turn them into a float16 batch on ``device``.

    Args:
        image_files: Iterable of inputs accepted by ``load_image``.
        image_processor: Processor handed to ``process_images``.
        model: Object whose ``config`` is handed to ``process_images``.
        device: Target device for the resulting tensor.

    Returns:
        ``(images_tensor, image_sizes)`` where ``image_sizes`` holds the
        ``.size`` of every loaded image, in input order.
    """
    loaded_images = []
    image_sizes = []
    for image_file in image_files:
        picture = load_image(image_file)
        loaded_images.append(picture)
        image_sizes.append(picture.size)

    images_tensor = process_images(loaded_images, image_processor, model.config)
    images_tensor = images_tensor.to(device, dtype=torch.float16)
    return images_tensor, image_sizes

Setup finetuning dataset

In [None]:
from PIL import Image

def expand2square(pil_img, background_color):
    """Pad a PIL image to a square canvas with the original centred on it.

    Args:
        pil_img: Source ``PIL.Image.Image``.
        background_color: Sequence of fill values given as fractions of 255
            (each value is multiplied by 255 here). When its length differs
            from the number of bands of ``pil_img`` the first value is used
            for every band.

    Returns:
        ``pil_img`` itself when it is already square, otherwise a new image of
        the same mode whose side equals the longer edge of the input.
    """
    width, height = pil_img.size
    if width == height:
        return pil_img

    # Build one fill value per band. A one-element colour on a multi-band image
    # is not broadcast by Pillow (it appears to be read as a packed integer,
    # which leaves the remaining bands at 0), so replicate it explicitly.
    band_count = len(pil_img.getbands())
    fill = [int(value * 255) for value in background_color]
    if len(fill) != band_count:
        fill = [fill[0]] * band_count

    side = max(width, height)
    result = Image.new(pil_img.mode, (side, side), tuple(fill))
    # Integer division centres the image along the shorter axis; the offset
    # along the longer axis is 0.
    result.paste(pil_img, ((side - width) // 2, (side - height) // 2))
    return result

In [None]:
import torch
import re
def load_image(image_input):
    """Load an image from a URL, a file path or a PIL object and pad it square.

    Args:
        image_input: An ``http://`` / ``https://`` URL, a local file path, or
            an already loaded ``PIL.Image.Image``.

    Returns:
        The image after ``expand2square``. Images read from a URL or a path are
        converted to RGB first; a PIL object is used as passed in.

    Raises:
        ValueError: If ``image_input`` is neither a string nor a PIL image.
        requests.exceptions.RequestException: If the download fails, times out
            or the server answers with an error status.
    """
    if isinstance(image_input, str):
        if image_input.startswith(('http://', 'https://')):
            # The timeout keeps the call from blocking forever on a stalled
            # connection; raise_for_status surfaces HTTP errors before PIL is
            # asked to decode an error page.
            response = requests.get(image_input, timeout=30)
            response.raise_for_status()
            # convert() is a method of the opened image, not of the byte buffer.
            image = Image.open(BytesIO(response.content)).convert('RGB')
        else:
            image = Image.open(image_input).convert('RGB')
    elif isinstance(image_input, Image.Image):
        # input is already an image object, use it as it is
        image = image_input
    else:
        raise ValueError("Unsupported image input type")

    # Original note: the equivalent line lives in process_images (mm_utils.py)
    # -- TODO confirm whether that copy has to be changed as well.
    return expand2square(image, (image_processor.image_mean[0],))

In [None]:
def eval_model(tokenizer, model, image_processor, context_len, image_file, query, model_name=model_name, sep=',', temperature=1.0, num_beams=1, max_new_tokens=512):
    """Generate an answer to ``query`` for the given image(s) and print it.

    This notebook-level definition replaces the ``eval_model`` imported from
    ``llava.eval.run_llava`` earlier in the notebook.

    Args:
        tokenizer: Tokenizer used for the prompt and for decoding the output.
        model: Model whose ``generate`` method is called.
        image_processor: Processor forwarded to ``process_and_prepare_images``.
        context_len: Accepted for call compatibility; not used here.
        image_file: One image (path, URL or PIL image) or a list/tuple of them.
        query: User question.
        model_name: Name used to select the conversation template.
        sep: Accepted for call compatibility; not used here.
        temperature: Sampling temperature. Sampling is enabled only for a
            positive value other than 1.0; 1.0 (the default) and values <= 0
            decode without sampling.
        num_beams: Forwarded to ``generate``.
        max_new_tokens: Forwarded to ``generate``.

    Returns:
        None. The prompt and the decoded output are printed.
    """
    # model
    disable_torch_init()

    # create prompt using the common function
    prompt = create_prompt(query, model, model_name)
    print(f"Prompt: {prompt}")

    # process images using the common function; a single input is wrapped so
    # that the helper always receives a list
    if isinstance(image_file, (list, tuple)):
        image_inputs = list(image_file)
    else:
        image_inputs = [image_file]
    images_tensor, image_sizes = process_and_prepare_images(image_inputs, image_processor, model, model.device)

    # tokenize the prompt using the custom tokenizer_image_token function
    input_ids = (
        tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
        .unsqueeze(0)
        .to(model.device)
    )

    # temperature == 1.0 keeps the previous non-sampling default; a value <= 0
    # must not enable sampling because a temperature has to be positive there.
    do_sample = temperature > 0 and temperature != 1.0
    generation_kwargs = {
        'images': images_tensor,
        'image_sizes': image_sizes,
        'do_sample': do_sample,
        'num_beams': num_beams,
        'max_new_tokens': max_new_tokens,
        'use_cache': True,
    }
    if temperature > 0:
        generation_kwargs['temperature'] = temperature

    with torch.inference_mode():
        output_ids = model.generate(input_ids, **generation_kwargs)
    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=False)[0].strip()
    print(outputs)



In [None]:
import requests
from PIL import Image
from io import BytesIO
import matplotlib.pyplot as plt

# image URL
image_url = 'https://raw.githubusercontent.com/TrelisResearch/install-guides/main/knight_and_rook.jpg'

try:
    # download image; the timeout keeps the cell from hanging on a stalled
    # connection (a timeout is reported by the RequestException handler below)
    response = requests.get(image_url, timeout=30)
    response.raise_for_status()  # Raise an exception for bad status codes

    # open it with PIL
    image = Image.open(BytesIO(response.content))

    # display image using matplotlib
    plt.imshow(image)
    plt.axis('off')
    plt.show()

    # now you can pass the processed image to eval_model
    eval_model(
        tokenizer, 
        model,
        image_processor, 
        context_len, 
        image, 
        'what do you see in this picture?'
    )

except requests.exceptions.RequestException as e:
    print(f"Error fetching the image: {e}")
except IOError as e:
    print(f"Error opening the image: {e}")
except Exception as e:
    # NOTE(review): this also hides errors raised inside eval_model and prints
    # only the message, without a traceback.
    print(f"An unexpected error occurred: {e}")


In [None]:
# fine tuning dataset
from torch.nn.utils.rnn import pad_sequence
def tokenize_and_create_labels(example_batch, image_processor, tokenizer, model, model_name, device, ignore_index=-100):
    """Turn a batch of (image, class id) examples into model training inputs.

    Each example becomes a conversation made of the fixed question and the
    class name as the assistant's answer. Only the answer tokens contribute to
    the labels; prompt and padding positions are set to ``ignore_index``.

    Args:
        example_batch: Mapping with an ``'image'`` list and a ``'label'`` list
            of class ids (0, 1 or 2).
        image_processor: Forwarded to ``process_and_prepare_images``.
        tokenizer: Tokenizer used for the prompts; supplies the padding id.
        model: Forwarded to the image and prompt helpers.
        model_name: Selects the conversation template.
        device: Device the returned tensors are moved to.
        ignore_index: Label value for positions that must not produce a loss.

    Returns:
        Dict with ``input_ids``, ``attention_mask``, ``images``,
        ``image_sizes`` and ``labels``. The id, mask and label tensors have
        shape (batch, longest sequence).
    """
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        # pad_sequence needs a number; the value is irrelevant for the model
        # because the attention mask below is derived from sequence lengths.
        pad_token_id = 0
    image_files = example_batch['image']

    images_tensor, image_sizes = process_and_prepare_images(image_files, image_processor, model, device)

    query = "What do you see in this picture?"

    # Define a mapping from int to str labels
    label_mapping = {
        0: "Covid",
        1: "Normal",
        2: "Viral Pneumonia"
    }

    # Convert int labels to str
    str_labels = [label_mapping[int(label)] for label in example_batch['label']]

    # The caption-free prompt is identical for every example, so it is
    # tokenized once; only its length is needed to locate the answer tokens.
    # Assumes the caption-free prompt tokenizes to a prefix of the captioned
    # prompt -- TODO confirm for the tokenizer in use.
    prompt_ids = tokenizer_image_token(
        create_prompt(query, model, model_name, None), tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt'
    ).squeeze(0)
    prompt_len = len(prompt_ids)

    tokenized_conversations_with_caption = [
        tokenizer_image_token(
            create_prompt(query, model, model_name, caption), tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt'
        ).squeeze(0)
        for caption in str_labels
    ]

    # pad the captioned conversations (these are what the model must see) to
    # the same length, batch dimension first
    input_ids = pad_sequence(
        tokenized_conversations_with_caption, batch_first=True, padding_value=pad_token_id
    )

    # attention mask: 1 for real tokens, 0 for padding; labels: answer tokens
    # only, everything else stays at ignore_index
    attention_mask = torch.zeros_like(input_ids)
    labels = torch.full_like(input_ids, fill_value=ignore_index)
    for i, conversation_ids in enumerate(tokenized_conversations_with_caption):
        sequence_len = len(conversation_ids)
        attention_mask[i, :sequence_len] = 1
        labels[i, prompt_len:sequence_len] = input_ids[i, prompt_len:sequence_len]

    inputs = {
        'input_ids': input_ids.to(device),
        'attention_mask': attention_mask.to(device),
        'images': images_tensor,
        'image_sizes': image_sizes,
        'labels': labels.to(device)
    }

    return inputs

# make sure to define the function outside of the lambda 
def transform_batch(batch):
    """Apply ``tokenize_and_create_labels`` to ``batch`` using the notebook globals."""
    return tokenize_and_create_labels(
        example_batch=batch,
        image_processor=image_processor,
        tokenizer=tokenizer,
        model=model,
        model_name=model_name,
        device=device,
        ignore_index=-100,
    )

# load and prepare dataset
# The 'train' split is used for training and the 'test' split for evaluation.
ds = load_dataset('yuighj123/covid-19-classification')
train_ds = ds['train']
eval_ds = ds['test']

# apply transformation function to the dataset
# NOTE: later cells call eval_ds.reset_format() to get the untransformed rows
# (PIL image and class id) back.
train_ds.set_transform(transform_batch)
eval_ds.set_transform(transform_batch)

In [None]:
# Show what load_dataset returned (the object holding the 'train' and 'test' splits).
print(ds)

Run the code below if the above does not work

In [None]:
def tokenize_and_create_labels(example_batch, image_processor, tokenizer, model, model_name, device, ignore_index=-100):
    """Turn a batch of (image, class id) examples into model training inputs.

    Each example becomes a conversation made of the fixed question and the
    class name as the assistant's answer. Only the answer tokens contribute to
    the labels; prompt and padding positions are set to ``ignore_index``.

    Args:
        example_batch: Mapping with an ``'image'`` list and a ``'label'`` list
            of class ids (0, 1 or 2).
        image_processor: Forwarded to ``process_and_prepare_images``.
        tokenizer: Tokenizer used for the prompts; supplies the padding id.
        model: Forwarded to the image and prompt helpers.
        model_name: Selects the conversation template.
        device: Device the returned tensors are moved to.
        ignore_index: Label value for positions that must not produce a loss.

    Returns:
        Dict with ``input_ids``, ``attention_mask``, ``images`` and ``labels``.
        The id, mask and label tensors have shape (batch, longest sequence).
    """
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        # pad_sequence needs a number; the value is irrelevant for the model
        # because the attention mask below is derived from sequence lengths.
        pad_token_id = 0
    image_files = example_batch['image']

    # Process images (the sizes are not part of this variant's output)
    images_tensor, _ = process_and_prepare_images(image_files, image_processor, model, device)

    query = "What do you see in this picture?"

    # Class id -> answer text the model is trained to produce
    label_mapping = {
        0: "Covid",
        1: "Normal",
        2: "Viral Pneumonia"
    }
    str_labels = [label_mapping[int(label)] for label in example_batch['label']]

    # The caption-free prompt is identical for every example, so it is
    # tokenized once; only its length is needed to locate the answer tokens.
    # Assumes the caption-free prompt tokenizes to a prefix of the captioned
    # prompt -- TODO confirm for the tokenizer in use.
    prompt_ids = tokenizer_image_token(
        create_prompt(query, model, model_name, caption=None), tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt'
    ).squeeze(0)
    prompt_len = len(prompt_ids)

    # Tokenize the conversations including the answer; without these the
    # labels would consist of ignore_index only and nothing would be learned.
    tokenized_conversations_with_caption = [
        tokenizer_image_token(
            create_prompt(query, model, model_name, caption=caption), tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt'
        ).squeeze(0)
        for caption in str_labels
    ]

    # Pad the tokenized conversations to the same length
    input_ids = pad_sequence(tokenized_conversations_with_caption,
                             batch_first=True, padding_value=pad_token_id)

    # Attention mask: 1 for real tokens, 0 for padding. Labels: answer tokens
    # only, everything else stays at ignore_index.
    attention_mask = torch.zeros_like(input_ids)
    labels = torch.full_like(input_ids, fill_value=ignore_index)
    for i, conversation_ids in enumerate(tokenized_conversations_with_caption):
        sequence_len = len(conversation_ids)
        attention_mask[i, :sequence_len] = 1
        labels[i, prompt_len:sequence_len] = input_ids[i, prompt_len:sequence_len]

    inputs = {
        'input_ids': input_ids.to(device),
        'attention_mask': attention_mask.to(device),
        'images': images_tensor,
        'labels': labels.to(device)
    }

    return inputs

# Make sure to define the function outside of the lambda
def transform_batch(batch, image_processor, tokenizer, model, model_name, device, ignore_index=-100):
    """Forward every argument unchanged to ``tokenize_and_create_labels``."""
    forwarded = (batch, image_processor, tokenizer, model, model_name, device, ignore_index)
    return tokenize_and_create_labels(*forwarded)


LORA

With LoRA we freeze the weights of the main model, train small low-rank adapter matrices attached to selected layers, and can afterwards merge those adapters back into the main model.

In [None]:
# The projector is targeted through the full names of the nn.Linear layers
# found under `mm_projector`. This works whether the projector is a single
# Linear or a container of layers (presumably the latter for this checkpoint
# -- TODO confirm); adapters can only be attached to supported layer types.
projector_linear_names = [
    name
    for name, module in model.named_modules()
    if 'mm_projector' in name and isinstance(module, torch.nn.Linear)
]

config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=[
        'q_proj', 'k_proj', 'v_proj',
        # 'fc1', 'fc2' # for llama,
        # Every entry needs its own trailing comma: two adjacent string
        # literals are concatenated by Python into one bogus module name.
        'up_proj', 'down_proj', 'gate_proj', # optionally train more linear layers
        *projector_linear_names, # multimodal projector layers
    ],
    lora_dropout=0.05,
    bias='none'
)

# Wrap the model so that only the adapter weights are trainable.
model = get_peft_model(model, config)

Pre-training Evaluation

In [None]:
import matplotlib.pyplot as plt

# temporarily disable the transformation to access the original data
eval_ds.reset_format()

# iterate over each example in the eval dataset
for i in range(len(eval_ds)):
    # fetch the row once; it holds the original image and class id
    example = eval_ds[i]
    image = example['image']
    label = example['label']

    # display the image using matplotlib
    plt.imshow(image)
    plt.axis('off')
    plt.show()

    eval_model(
        tokenizer, 
        model, 
        image_processor,
        context_len,
        image,
        'What do you see in this picture?'
    )
    print(f"\nCorrect label: {label}\n\n")

# re-enable the transformation; every positional argument of
# tokenize_and_create_labels (including model_name) has to be supplied
eval_ds.set_transform(
    lambda batch: tokenize_and_create_labels(batch, image_processor, tokenizer, model, model_name, device)
)

In [None]:
# Report the trainable parameters of the model returned by get_peft_model above.
model.print_trainable_parameters()

Training 

In [None]:
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

batch_size = 4
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)

# Inspect the first batch only: which keys it has and what each value looks like.
for batch in train_loader:
    print("Batch keys:", batch.keys())
    
    for key in batch:
        # Not every value is a tensor (plain lists have no .shape), so fall
        # back to printing the type instead of raising AttributeError.
        value_shape = getattr(batch[key], 'shape', None)
        if value_shape is not None:
            print(f"{key} shape: {value_shape}")
        else:
            print(f"{key} type: {type(batch[key])}")
    
    if 'image' in batch:  # Note: changed from 'images' to 'image'
        print('Images are included in the DataLoader')
        image_tensor = batch['image'][0]  # Note: using index 0 instead of 1
        print(f"First Image Data Type: {image_tensor.dtype}")
        print(f"First Image Shape: {image_tensor.shape}")
        print(f"First Image Value range: [{image_tensor.min()}, {image_tensor.max()}]")
    
    if 'label' in batch:  # Note: changed from 'labels' to 'label'
        print(f"Labels: {batch['label']}")
    
    break  # only check the first batch

Run the code below if the cell above does not work

In [None]:
from functools import partial
def transform_batch(batch, image_processor, tokenizer, model, model_name, device, ignore_index=-100):
    """Collate a list of row dicts into one tokenized training batch.

    ``batch`` is a sequence of rows that each provide ``'image'`` and
    ``'label'``; the rows are regrouped column-wise and passed on to
    ``tokenize_and_create_labels`` together with the remaining arguments.
    """
    # Regroup the rows into one list per column
    images = [row['image'] for row in batch]
    class_ids = [row['label'] for row in batch]
    column_batch = {'image': images, 'label': class_ids}

    # Process the regrouped batch
    return tokenize_and_create_labels(
        column_batch, image_processor, tokenizer, model, model_name, device, ignore_index
    )

# Create a partial function with the required arguments
transform_batch_partial = partial(transform_batch, 
                                  image_processor=image_processor, 
                                  tokenizer=tokenizer, 
                                  model=model, 
                                  model_name=model_name, 
                                  device=device)

# Create the DataLoader with the collate_fn
# The collate function reads the raw 'image' / 'label' columns, but train_ds
# still has the on-the-fly transform registered earlier via set_transform.
# with_format(None) hands the loader a copy without that transform and leaves
# train_ds itself untouched.
batch_size = 4
train_loader = DataLoader(
    train_ds.with_format(None),
    batch_size=batch_size,
    shuffle=True,
    collate_fn=transform_batch_partial,
)

# Inspect the first collated batch: keys, shapes, one label/mask row and the
# first image.
for batch in train_loader:
    print("Batch keys:", batch.keys())
    for key, value in batch.items():
        if isinstance(value, (list, tuple, np.ndarray, torch.Tensor)):
            print(f"{key} type: {type(value)}, shape: {np.shape(value)}")
        else:
            print(f"{key} type: {type(value)}")

    if 'images' in batch:
        print('Images are included in the DataLoader')
        
        # Safely print shapes
        for key in ['input_ids', 'attention_mask', 'labels']:
            if key in batch:
                print(f"Batch '{key}' shape: {np.shape(batch[key])}")
        
        # Safely print labels and attention mask
        if 'labels' in batch:
            labels = batch['labels'][0] if isinstance(batch['labels'], (list, tuple, np.ndarray, torch.Tensor)) else batch['labels']
            if isinstance(labels, (np.ndarray, torch.Tensor)):
                labels = labels.tolist()
            labels_str = ['[IGNORE]' if label == -100 else str(label) for label in labels]
            print(f"Labels: {labels_str}")
        
        if 'attention_mask' in batch:
            attention_mask = batch['attention_mask'][0] if isinstance(batch['attention_mask'], (list, tuple, np.ndarray, torch.Tensor)) else batch['attention_mask']
            if isinstance(attention_mask, (np.ndarray, torch.Tensor)):
                attention_mask = attention_mask.tolist()
            print(f"Attention Mask: {attention_mask}")
        
        # Display image information
        image_data = batch['images'][0] if isinstance(batch['images'], (list, tuple, np.ndarray, torch.Tensor)) else batch['images']
        print(f"Image data type: {type(image_data)}")
        
        if isinstance(image_data, (np.ndarray, torch.Tensor)):
            print(f"Image shape: {image_data.shape}")
            print(f"Image dtype: {image_data.dtype}")

            # The batch images are float16 tensors on `device` (see
            # process_and_prepare_images); NumPy and matplotlib need a CPU
            # array, so convert before computing the range or plotting.
            if isinstance(image_data, torch.Tensor):
                image_array = image_data.detach().float().cpu().numpy()
            else:
                image_array = np.asarray(image_data)
            value_min, value_max = image_array.min(), image_array.max()
            print(f"Image value range: [{value_min}, {value_max}]")

            # Rescale to [0, 1] for display only, so that values outside the
            # range imshow accepts for floats are not clipped.
            if value_max > value_min:
                display_array = (image_array - value_min) / (value_max - value_min)
            else:
                display_array = image_array
            
            # Try to display the image
            plt.figure(figsize=(10, 10))
            
            if display_array.ndim == 3:
                if display_array.shape[0] == 3:
                    # channels-first -> channels-last, as imshow expects
                    display_array = np.transpose(display_array, (1, 2, 0))
                plt.imshow(display_array)
            elif display_array.ndim == 2:
                plt.imshow(display_array, cmap='gray')
            else:
                print(f"Unexpected image shape: {image_data.shape}")
            
            plt.axis('off')
            plt.show()
        else:
            print("Image data is not a numpy array or torch tensor. Cannot display.")
    
    break  # only check the first batch

In [None]:
# Checkpoints are written to a directory named after the base model.
output_model_name = f"{model_name}-covid"

training_args = TrainingArguments(
    output_dir=output_model_name,
    learning_rate=1e-4,
    # fp16=True, #for non ampere gpus
    bf16=True,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=1,
    dataloader_pin_memory=False,
    save_total_limit=2,
    evaluation_strategy="steps",
    save_strategy="steps",
    # values below 1 are given as a fraction of the total training steps
    save_steps=0.2,
    eval_steps=0.2,
    logging_steps=1,
    num_train_epochs=3,
    # max_steps=3,
    # keep the columns produced by the dataset transform (e.g. 'images')
    remove_unused_columns=False,
    push_to_hub=False,
    label_names=["labels"],
    load_best_model_at_end=True,
    # The string "none" disables every reporting integration; the Python value
    # None does not (it falls back to the library default).
    report_to="none",
    optim="adamw_torch"
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    # compute_loss=compute_loss,  # Pass the custom compute_loss function
)

trainer.train()


Eval after training

In [None]:
import matplotlib.pyplot as plt

# Temporarily disable the transformation to access the original data
eval_ds.reset_format()

# Iterate over each example in the evaluation dataset
for i in range(len(eval_ds)):
    # Fetch the row once; the dataset provides 'image' and 'label' columns
    example = eval_ds[i]
    image = example['image']
    label = example['label']

    # Display the image using matplotlib
    plt.imshow(image)
    plt.axis('off')  # Turn off axis numbers and ticks
    plt.show()

    # Run the model inside the loop so every image is evaluated, not only the
    # last one
    eval_model(
        tokenizer,
        model,
        image_processor,
        context_len,
        image,
        "What do you see in this picture?"
    )
    print(f"\nCorrect label: {label}\n\n")