In [None]:
from unsloth import FastModel
from datasets import load_dataset
from PIL import Image

# Load the fine-tuning base checkpoint from local disk through Unsloth.
# The second return value is kept under the name `tokenizer`; later cells read
# both `tokenizer.image_processor` and `tokenizer.tokenizer` from it.
model, tokenizer = FastModel.from_pretrained(
    model_name="/root/autodl-tmp/kaggle408/checkpoints/gek_e2b",
    attn_implementation="eager",  # flagged as necessary by the original author
    load_in_4bit=False,           # 4-bit quantization is switched off here
    max_seq_length=4096,          # context length requested at load time
)

<img src="https://files.worldwildlife.org/wwfcmsprod/images/Sloth_Sitting_iStock_3_12_2014/story_full_width/8l7pbjmj29_iStock_000011145477Large_mini__1_.jpg" alt="Sloth sitting" height="256">

In [None]:
# Sanity check of the chat template: render a single user turn that pairs an
# image reference with a text question, then print the resulting prompt.
sloth_link = "https://files.worldwildlife.org/wwfcmsprod/images/Sloth_Sitting_iStock_3_12_2014/story_full_width/8l7pbjmj29_iStock_000011145477Large_mini__1_.jpg"

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": sloth_link},
            {"type": "text", "text": "Which films does this animal feature in?"},
        ],
    }
]

# `messages` is rebound from the message list to the rendered template output,
# with a leading '<bos>' marker stripped if present.
rendered = tokenizer.apply_chat_template(messages)
messages = rendered.removeprefix('<bos>')
print(messages)

In [None]:
from datasets import load_dataset

# Both subsets are sliced out of the same on-disk "train" split:
# the first 20% is used for training and the final 0.5% for evaluation.
rlaif_v_dir = "/root/autodl-tmp/kaggle408/dataset/rlaif-v"
train_set = load_dataset(rlaif_v_dir, split="train[:20%]")
eval_set = load_dataset(rlaif_v_dir, split="train[99.5%:]")

In [None]:
def format(example):
    """Convert one raw preference record into the conversational layout used for DPO.

    Reads the ``question``, ``chosen``, ``rejected`` and ``image`` fields of
    ``example`` and returns a dict with:

    * ``images``   - single-element list holding the (possibly resized) image
    * ``prompt``   - one user turn: an image placeholder followed by the question
    * ``chosen``   - one assistant turn with the preferred answer
    * ``rejected`` - one assistant turn with the dispreferred answer

    PIL images are converted to RGB and then shrunk so that neither side exceeds
    the largest value in ``tokenizer.image_processor.size``. Anything that is not
    a PIL image is passed through untouched.

    NOTE: relies on the notebook-level ``tokenizer`` and shadows the builtin
    ``format``; the name is kept because the mapping cell refers to it.
    """
    prompt = [
        {
            "role": "user",
            "content": [{"type": "image"}, {"type": "text", "text": example["question"]}],
        },
    ]
    chosen = [
        {
            "role": "assistant",
            "content": [{"type": "text", "text": example["chosen"]}],
        },
    ]
    rejected = [
        {
            "role": "assistant",
            "content": [{"type": "text", "text": example["rejected"]}],
        },
    ]

    image = example["image"]
    if isinstance(image, Image.Image):
        # Convert before resizing: Pillow falls back to nearest-neighbour
        # resampling for palette ("P") and bilevel ("1") images, so shrinking
        # first would degrade them. `convert` returns a new image; RGB inputs
        # are copied so the caller's object is never resized in place.
        image = image.convert("RGB") if image.mode != "RGB" else image.copy()

        # assumes every entry of image_processor.size is a number - TODO confirm
        max_size = max(tokenizer.image_processor.size.values())
        # thumbnail() only ever shrinks and keeps the aspect ratio.
        image.thumbnail((max_size, max_size))

    return {"images": [image], "prompt": prompt, "chosen": chosen, "rejected": rejected}

In [None]:
# Replace the raw columns with the images/prompt/chosen/rejected layout built by
# `format`. The mapping drops every original column, so a second run of this cell
# would fail when `format` looks up "question". The guard keeps the cell
# re-runnable: a dataset that no longer has a "question" column is already
# formatted and is left as is.
if "question" in train_set.column_names:
    train_set = train_set.map(format, remove_columns=train_set.column_names)
if "question" in eval_set.column_names:
    eval_set = eval_set.map(format, remove_columns=eval_set.column_names)

In [None]:
# Wrap the loaded model with LoRA adapters via Unsloth.
# NOTE: `model` is rebound to the wrapped model, so this cell is meant to run
# once per freshly loaded model.
model = FastModel.get_peft_model(
    model,
    # --- LoRA hyperparameters ---
    r=8,                 # larger = higher accuracy, but might overfit
    lora_alpha=8,        # recommended: at least equal to r
    lora_dropout=0,
    bias="none",
    random_state=3407,
    # --- which parts of the network receive adapters ---
    finetune_vision_layers=True,      # turn off for text-only training
    finetune_language_layers=True,    # should stay on
    finetune_attention_modules=True,
    finetune_mlp_modules=True,        # should always stay on
)

In [None]:
from swanlab.integration.transformers import SwanLabCallback
# Experiment-tracking callback; it is handed to the DPO trainer through `callbacks`.
# NOTE(review): "mutlti" in the experiment name looks like a typo for "multi";
# left unchanged because the string identifies the run at runtime.
swanlab_callback = SwanLabCallback(
    project="kaggle408",
    experiment_name="gemma3n-mutlti-dpo",
)

In [None]:
from unsloth import PatchDPOTrainer

# Apply Unsloth's DPO patch before the TRL trainer classes are imported below.
PatchDPOTrainer()

from trl import DPOTrainer, DPOConfig

# Build the DPO trainer.
# `beta`, `max_length` and `max_prompt_length` are DPOConfig fields, so they are
# set on the config object instead of being passed as DPOTrainer keyword
# arguments. This keeps every hyperparameter in one place and avoids relying on
# trainer-level keywords that TRL does not define in its current signature.
dpo_trainer = DPOTrainer(
    model = model,
    ref_model = None,  # no separate reference model is supplied
    args = DPOConfig(
        # DPO-specific settings
        beta = 0.1,
        max_length = 2048,
        max_prompt_length = 512,
        # Optimisation schedule
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,  # 2 x 4 = 8 samples per optimizer step per device
        gradient_checkpointing = True,
        num_train_epochs = 1,
        learning_rate = 5e-6,
        warmup_ratio = 0.1,
        lr_scheduler_type = "linear",
        optim = "adamw_8bit",
        weight_decay = 0.0,
        seed = 3407,
        # Logging / output / data loading
        logging_steps = 1,
        output_dir = "outputs",
        dataloader_num_workers = 8,
        dataset_num_proc = 8,
        # NOTE(review): no evaluation strategy is configured, so `eval_set` is
        # presumably only used if evaluation is triggered explicitly - confirm.
    ),
    # NOTE(review): only the inner text tokenizer is passed here although the
    # dataset carries an `images` column; confirm the images are actually
    # consumed, or whether the full processor (`tokenizer`) should be passed.
    processing_class = tokenizer.tokenizer,
    train_dataset = train_set,
    eval_dataset = eval_set,
    callbacks = [swanlab_callback],
)

In [None]:
# Start DPO fine-tuning with the trainer built in the previous cell; the value
# returned by train() is shown as this cell's output.
dpo_trainer.train()