# Unsloth VLM Full Fine-tuning with GradES

This notebook demonstrates full fine-tuning (FFT) of a Vision-Language Model (VLM) with Unsloth, using GradES for gradient-based early stopping. The dataset and preprocessing follow the Unsloth Qwen2.5-VL notebook.

In [None]:
!pip install grades
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
!pip install --no-deps trl peft accelerate bitsandbytes

## 1. Imports

In [None]:
import torch
from datasets import load_dataset
from unsloth import FastVisionModel
from trl import SFTTrainer, SFTConfig
from unsloth.trainer import UnslothVisionDataCollator

from grades import VLMGradEarlyStoppingCallback, VLMGradEarlyStoppingConfig

## 2. Load Model for Full Fine-tuning

In [None]:
# Full fine-tuning updates every model weight, so the weights must be loaded
# unquantized: request the base checkpoint (not the pre-quantized "-bnb-4bit"
# variant) and turn 4-bit loading off. Combining load_in_4bit=True with
# full_finetuning=True asks for two mutually exclusive modes.
# NOTE(review): unquantized 7B weights need far more GPU memory than the 4-bit
# variant — confirm the target hardware can hold them, or pick a smaller model.
model, tokenizer = FastVisionModel.from_pretrained(
    "unsloth/Qwen2.5-VL-7B-Instruct",
    load_in_4bit=False,
    full_finetuning=True,  # Enable FFT
    trust_remote_code=True,
)

## 3. Prepare Dataset

In [None]:
# Prompt text placed in the user turn of every training conversation.
instruction = "Write the LaTeX representation for this image."

# Training split of the LaTeX OCR dataset; each sample is read below through
# its "image" and "text" fields.
dataset = load_dataset(path="unsloth/LaTeX_OCR", split="train")

def convert_to_conversation(sample):
    """Wrap one dataset sample as a two-turn chat conversation.

    Args:
        sample: Mapping providing an "image" entry and a "text" entry.

    Returns:
        A dict with a single "messages" key holding a user turn (the
        module-level ``instruction`` text followed by the sample image) and an
        assistant turn (the sample text).
    """
    user_turn = {
        "role": "user",
        "content": [
            {"type": "text", "text": instruction},
            {"type": "image", "image": sample["image"]},
        ],
    }
    assistant_turn = {
        "role": "assistant",
        "content": [{"type": "text", "text": sample["text"]}],
    }
    return {"messages": [user_turn, assistant_turn]}

# Eagerly convert every sample into chat format; the result is a plain Python
# list of {"messages": [...]} dicts.
converted_dataset = list(map(convert_to_conversation, dataset))

## 4. Integrate GradES

In [None]:
# GradES settings for the vision-language model.
# NOTE(review): vision_tau / language_tau are presumably per-tower gradient
# thresholds and alpha a grace-period fraction — confirm against the GradES
# documentation before tuning.
vlm_config = VLMGradEarlyStoppingConfig(
    enable_wandb_logging=False,
    alpha=0.1,
    language_tau=0.09,
    vision_tau=0.13,
)

# Callback object that is handed to the trainer via `callbacks=[...]`.
vlm_callback = VLMGradEarlyStoppingCallback(vlm_config)

## 5. Set up Trainer

In [None]:
# Switch the Unsloth model into training mode before building the trainer.
FastVisionModel.for_training(model)

# Choose the mixed-precision mode from the hardware instead of hard-coding
# bf16: GPUs without bfloat16 support (e.g. Colab T4) reject bf16=True.
use_bf16 = torch.cuda.is_bf16_supported()

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=converted_dataset,
    # Collates the chat-format samples (text + image) into model inputs.
    data_collator=UnslothVisionDataCollator(model, tokenizer),
    # GradES early-stopping callback built in the previous section.
    callbacks=[vlm_callback],
    args=SFTConfig(
        output_dir="unsloth_vlm_fft_grades",
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,  # effective batch of 8 samples per device
        max_steps=60,
        learning_rate=2e-5,
        logging_steps=5,
        report_to="none",
        save_strategy="no",
        bf16=use_bf16,
        fp16=not use_bf16,
        # Unsloth specific arguments: keep the raw "messages" column and skip
        # TRL's own dataset preparation so the vision collator sees the
        # untouched conversations.
        remove_unused_columns=False,
        dataset_text_field="",
        dataset_kwargs={"skip_prepare_dataset": True},
    ),
)

## 6. Start Training

In [None]:
# Run training with the trainer configured above (capped at max_steps=60, with
# the GradES callback registered); as the cell's last expression, the returned
# training result is displayed as the cell output.
trainer.train()