In [1]:
from datasets import load_dataset
ds = load_dataset("flaviagiammarino/path-vqa")
ds

DatasetDict({
    train: Dataset({
        features: ['image', 'question', 'answer'],
        num_rows: 19654
    })
    validation: Dataset({
        features: ['image', 'question', 'answer'],
        num_rows: 6259
    })
    test: Dataset({
        features: ['image', 'question', 'answer'],
        num_rows: 6719
    })
})

In [8]:
ds['train'][0]['question']

'where are liver stem cells (oval cells) located?'

In [2]:
import os
os.environ['HF_HOME'] = '/home/sa5u24/VQA'
hf_home = os.path.expanduser(
    os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
)
print(hf_home)

from huggingface_hub import login

# Replace 'your-hf-token-here' with your actual Hugging Face token
login(token="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")

/home/sa5u24/VQA


In [3]:
# note the image is not provided in the prompt its included as part of the "processor"

# prompt= """Create a Short Product description based on the provided ##PRODUCT NAME## and ##CATEGORY## and image.
# Only return description. The description should be SEO optimized and for a better mobile search experience.

# ##PRODUCT NAME##: {product_name}
# ##CATEGORY##: {category}"""

prompt= """Answer the question based on the provided ##Question## and pathology image. ##Question##: {question}"""

from datasets import load_dataset

# Convert dataset to OAI messages
def format_data(sample):
    return {"messages": [
                {
                    "role": "question",
                    "content": [
                        {
                            "type": "text",
                            "text": prompt.format(question=sample["question"]),
                        },{
                            "type": "image",
                            "image": sample["image"],
                        }
                    ],
                },
                {
                    "role": "answer",
                    "content": [{"type": "text", "text": sample["answer"]}],
                },
            ],
        }

# Load dataset from the hub

# Convert dataset to OAI messages
# need to use list comprehension to keep Pil.Image type, .mape convert image to bytes
dataset_train = [format_data(sample) for sample in ds['train']]
dataset_validation = [format_data(sample) for sample in ds['validation']]
dataset_test = [format_data(sample) for sample in ds['test']]





In [4]:
dataset_train[0]["messages"]

[{'role': 'question',
  'content': [{'type': 'text',
    'text': 'Answer the question based on the provided ##Question## and pathology image. ##Question##: where are liver stem cells (oval cells) located?'},
   {'type': 'image',
    'image': <PIL.JpegImagePlugin.JpegImageFile image mode=CMYK size=309x272>}]},
 {'role': 'answer',
  'content': [{'type': 'text', 'text': 'in the canals of hering'}]}]

In [5]:
import torch
from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig

# Hugging Face model id
model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"

# BitsAndBytesConfig int-4 config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
)

# Load model and tokenizer
model = AutoModelForVision2Seq.from_pretrained(
    model_id,
    device_map="auto",
    # attn_implementation="flash_attention_2", # not supported for training
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config
)
processor = AutoProcessor.from_pretrained(model_id)

The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.


Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

In [6]:
from peft import LoraConfig

# LoRA config based on QLoRA paper & Sebastian Raschka experiment
peft_config = LoraConfig(
        lora_alpha=16,
        lora_dropout=0.05,
        r=8,
        bias="none",
        target_modules=["q_proj", "v_proj"],
        task_type="CAUSAL_LM",
)

In [7]:
from trl import SFTConfig


args = SFTConfig(
    output_dir="fine-tuned-visionllamav1", # directory to save and repository id
    num_train_epochs=10,                     # number of training epochs
    per_device_train_batch_size=4,          # batch size per device during training
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=8,          # number of steps before performing a backward/update pass
    gradient_checkpointing=True,            # use gradient checkpointing to save memory
    optim="adamw_torch_fused",              # use fused adamw optimizer
    logging_steps=5,                       # log every 10 steps
    save_strategy="epoch",                  # save checkpoint every epoch
    learning_rate=2e-4,                     # learning rate, based on QLoRA paper
    bf16=True,                              # use bfloat16 precision
    # tf32=True,                              # use tf32 precision
    max_grad_norm=0.3,                      # max gradient norm based on QLoRA paper
    warmup_ratio=0.03,                      # warmup ratio based on QLoRA paper
    lr_scheduler_type="constant",           # use constant learning rate scheduler
    # push_to_hub=True,                       # push model to hub
    report_to="tensorboard",                # report metrics to tensorboard
    gradient_checkpointing_kwargs = {"use_reentrant": False}, # use reentrant checkpointing
    dataset_text_field="", # need a dummy field for collator
    dataset_kwargs = {"skip_prepare_dataset": True} # important for collator
)
args.remove_unused_columns=False

In [10]:
from transformers import Qwen2VLProcessor
from qwen_vl_utils import process_vision_info

def collate_fn(examples):
    # Get the texts and images, and apply the chat template
    texts = [processor.apply_chat_template(example["messages"], tokenize=False) for example in examples]
    image_inputs = [process_vision_info(example["messages"])[0] for example in examples]

    # Tokenize the texts and process the images
    batch = processor(text=texts, images=image_inputs, return_tensors="pt", padding=True)

    # The labels are the input_ids, and we mask the padding tokens in the loss computation
    labels = batch["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100  #
    # Ignore the image token index in the loss computation (model specific)
    if isinstance(processor, Qwen2VLProcessor):
        image_tokens = [151652,151653,151655]
    else:
        image_tokens = [processor.tokenizer.convert_tokens_to_ids(processor.image_token)]
    for image_token_id in image_tokens:
        labels[labels == image_token_id] = -100
    batch["labels"] = labels

    return batch

from transformers import TrainerCallback

class BestModelSaverCallback(TrainerCallback):
    def __init__(self, trainer):
        super().__init__()
        self.trainer = trainer
        self.best_val_loss = float("inf")
        self.best_epoch = 0

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        # Retrieve the validation loss from metrics
        val_loss = metrics.get("eval_loss")
        if val_loss is not None:
            # Check if this is the best loss
            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                self.best_epoch = state.epoch
                # Save the model as the best model so far
                output_dir = f"{self.trainer.args.output_dir}/best_model_epoch_{int(state.epoch)}"
                self.trainer.save_model(output_dir)
                print(f"Best model saved at epoch {state.epoch} with validation loss: {self.best_val_loss:.4f}")

    def on_train_end(self, args, state, control, **kwargs):
        print(f"Training completed. Best model was at epoch {self.best_epoch} with validation loss: {self.best_val_loss:.4f}")


In [None]:
from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=dataset_train,
    eval_dataset = dataset_validation,
    data_collator=collate_fn,
    dataset_text_field="", # needs dummy value
    peft_config=peft_config,
    tokenizer=processor.tokenizer,
    callbacks=[BestModelSaverCallback(trainer)]
)

trainer.train()


Deprecated positional argument(s) used in SFTTrainer, please use the SFTConfig to set these arguments instead.
Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]


Step,Training Loss
