# Hugging Face VLM LoRA Fine-tuning with GradES

This notebook demonstrates how to fine-tune a Vision-Language Model (VLM) with LoRA using the Hugging Face `Trainer`, with GradES providing gradient-based early stopping. The dataset (`unsloth/LaTeX_OCR`) and its preprocessing follow Unsloth's Qwen2.5-VL notebook.

In [None]:
!pip install grades
!pip install transformers datasets peft accelerate bitsandbytes torch

## 1. Imports

In [None]:
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForImageTextToText,
    AutoProcessor,
    TrainingArguments,
    Trainer,
)
from peft import LoraConfig, get_peft_model, TaskType
from dataclasses import dataclass
from typing import Any, Dict, List
from PIL import Image

from grades import VLMGradEarlyStoppingCallback, VLMGradEarlyStoppingConfig

## 2. Load Model and Processor

In [None]:
# Single checkpoint id shared by the model and its processor, so the two
# are always loaded from the same repository.
MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"

# Base VLM in bfloat16, moved onto the GPU before the LoRA adapters are attached.
model = AutoModelForImageTextToText.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)
model = model.to("cuda")

# The processor provides chat templating, tokenization (`processor.tokenizer`)
# and image preprocessing for this checkpoint.
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)

## 3. Configure LoRA

In [None]:
# Standard attention (q/k/v/o) and MLP (gate/up/down) projection names.
_proj_targets = [
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
]
# Modules literally named "qkv" / "proj" -- presumably the fused attention
# projections of the Qwen2.5-VL vision tower; confirm via model.named_modules().
_fused_targets = ["qkv", "proj"]

peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=16,
    lora_alpha=16,
    lora_dropout=0.0,
    bias="none",
    target_modules=_proj_targets + _fused_targets,
)

# Wrap the base model with LoRA adapters and report how many weights will train.
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

## 4. Prepare Dataset

In [None]:
# Training split of the LaTeX-OCR dataset from the Hugging Face Hub. Each row
# must provide an "image" and its LaTeX "text"; both fields are read by
# convert_to_conversation.
dataset = load_dataset("unsloth/LaTeX_OCR", split="train")

# Prompt sent to the model together with every formula image.
instruction = "Write the LaTeX representation for this image."


def convert_to_conversation(sample):
    """Wrap one dataset row as a two-turn chat record.

    The user turn holds the instruction text followed by the image; the
    assistant turn holds the target LaTeX string.

    Args:
        sample: Mapping with an ``"image"`` entry and a ``"text"`` entry.

    Returns:
        ``{"messages": [user_turn, assistant_turn]}``.
    """
    user_turn = {
        "role": "user",
        "content": [
            {"type": "text", "text": instruction},
            {"type": "image", "image": sample["image"]},
        ],
    }
    assistant_turn = {
        "role": "assistant",
        "content": [{"type": "text", "text": sample["text"]}],
    }
    return {"messages": [user_turn, assistant_turn]}

# Eagerly convert every row into a {"messages": [...]} record. The result is a
# plain Python list (not a datasets.Dataset), so every sample -- including its
# image object -- is held in memory at once. That is fine for this demo; for
# larger data consider a lazy alternative (e.g. dataset.map / a wrapper).
converted_dataset = [convert_to_conversation(sample) for sample in dataset]

## 5. Data Collator

In [None]:
@dataclass
class VLMDataCollator:
    """Collate chat-formatted image/text samples into a training batch.

    Each feature is a ``{"messages": [...]}`` record, as produced by
    ``convert_to_conversation``. The collator does three things:

    * renders each conversation with the processor's chat template;
    * runs the processor over the rendered text and each sample's image;
    * builds ``labels`` from ``input_ids``.

    In ``labels``, padding tokens and image-placeholder tokens are set to
    -100, so they do not contribute to the loss.

    Attributes:
        processor: Multimodal processor exposing ``apply_chat_template``,
            ``__call__`` and ``tokenizer``.
    """

    processor: Any

    def __call__(self, features: List[Dict]) -> Dict[str, torch.Tensor]:
        """Build model inputs plus ``labels`` for a list of samples.

        Raises:
            ValueError: If a sample's messages contain no image entry.
        """
        texts = [
            self.processor.apply_chat_template(
                f["messages"], tokenize=False, add_generation_prompt=False
            )
            for f in features
        ]
        # Find the image by its "type" rather than a fixed index so the
        # collator does not depend on the order of items inside a turn.
        images = [self._find_image(f["messages"]) for f in features]

        inputs = self.processor(text=texts, images=images, return_tensors="pt", padding=True)

        labels = inputs["input_ids"].clone()
        # Padding and image-placeholder positions are inputs, not targets;
        # -100 is the ignore index used by the causal-LM loss.
        ignore_ids = [self.processor.tokenizer.pad_token_id, self._image_token_id()]
        for token_id in ignore_ids:
            if token_id is not None:
                labels[labels == token_id] = -100
        inputs["labels"] = labels
        return inputs

    def _image_token_id(self):
        """Return the id of the processor's image placeholder token, or None if it has none."""
        token = getattr(self.processor, "image_token", None)
        if token is None:
            return None
        return self.processor.tokenizer.convert_tokens_to_ids(token)

    @staticmethod
    def _find_image(messages):
        """Return the payload of the first ``{"type": "image"}`` item in *messages*."""
        for message in messages:
            content = message.get("content")
            if not isinstance(content, list):
                continue
            for item in content:
                if isinstance(item, dict) and item.get("type") == "image":
                    return item["image"]
        raise ValueError("sample contains no image entry in its messages")

# Collator instance handed to the Trainer; it uses the processor for chat
# templating, tokenization and image preprocessing.
data_collator = VLMDataCollator(processor=processor)

## 6. Integrate GradES

In [None]:
# GradES settings. Vision and language components get separate thresholds
# (vision_tau / language_tau); together with `alpha`, their exact semantics
# are defined by the `grades` package. They are presumably per-component
# gradient-magnitude thresholds and a smoothing factor -- confirm against
# the GradES documentation.
vlm_config = VLMGradEarlyStoppingConfig(
    alpha=0.3,
    vision_tau=1e-4,
    language_tau=1e-3,
    enable_wandb_logging=False,  # W&B logging off for this notebook
)
vlm_callback = VLMGradEarlyStoppingCallback(vlm_config)

## 7. Set up Trainer

In [None]:
# Trainer configuration. The effective batch size per device is
# per_device_train_batch_size * gradient_accumulation_steps (2 * 4 = 8).
training_args = TrainingArguments(
    output_dir="./hf_vlm_lora_grades",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    max_steps=60,
    learning_rate=2e-4,
    logging_steps=5,
    report_to="none",
    save_strategy="no",
    bf16=True,
    # Required: each training sample only carries a "messages" key, which is
    # not an argument of the model's forward(). With the default
    # remove_unused_columns=True, Trainer strips that key from every feature
    # before the custom collator runs, and VLMDataCollator would fail with
    # KeyError('messages').
    remove_unused_columns=False,
)

# The GradES callback is registered alongside the custom collator.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=converted_dataset,
    data_collator=data_collator,
    callbacks=[vlm_callback],
)

## 8. Start Training

In [None]:
# Start training. It runs for at most max_steps optimizer steps; the GradES
# early-stopping callback attached to the trainer may presumably end it sooner.
trainer.train()