In [1]:
# dpo_idefics2-8b.py
from datasets import features, load_dataset
from transformers import AutoModelForVision2Seq, AutoProcessor
import torch
from trl import DPOConfig, DPOTrainer
from peft import LoraConfig

import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

  from .autonotebook import tqdm as notebook_tqdm
No ROCm runtime is found, using ROCM_HOME='/usr'


In [3]:
# Load the model and processor
model = AutoModelForVision2Seq.from_pretrained(
    "llava-hf/llava-v1.6-mistral-7b-hf",
    torch_dtype=torch.float16,
    # cache_dir="./cache",
    load_in_4bit=True,
)
processor = AutoProcessor.from_pretrained(
    "llava-hf/llava-v1.6-mistral-7b-hf", do_image_splitting=False, 
    # cache_dir="./cache"
)

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
`low_cpu_mem_usage` was None, now default to True since model is quantized.


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.48, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [4]:
# Load the dataset
dataset = load_dataset(
    "openbmb/RLAIF-V-Dataset", split="train[:500]", cache_dir="cache"
)


def format(example):
    # Prepare the input for the chat template
    prompt = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": example["question"]},
            ],
        }
    ]
    chosen = [
        {"role": "assistant", "content": [{"type": "text", "text": example["chosen"]}]}
    ]
    rejected = [
        {
            "role": "assistant",
            "content": [{"type": "text", "text": example["rejected"]}],
        }
    ]
    # Apply the chat template
    prompt = processor.apply_chat_template(prompt, tokenize=False)
    chosen = processor.apply_chat_template(chosen, tokenize=False)
    rejected = processor.apply_chat_template(rejected, tokenize=False)
    # Resize the image to ensure it fits within the maximum allowable
    # size of the processor to prevent OOM errors.
    # max_size = processor.image_processor.size["longest_edge"] // 2
    # example["image"].thumbnail((max_size, max_size))
    return {
        "images": [example["image"]],
        "prompt": prompt,
        "chosen": chosen,
        "rejected": rejected,
    }


# Apply the formatting function to the dataset
dataset = dataset.map(format, remove_columns=dataset.column_names, num_proc=32)

# Make sure that the images are decoded, it prevents from storing bytes.
# More info here https://github.com/huggingface/blog/pull/2148#discussion_r1667400478
f = dataset.features
f["images"] = features.Sequence(features.Image(decode=True))
dataset = dataset.cast(f)




asting the dataset: 100%|██████████████████████████████████████████████████████████████████████████████████| 500/500 [00:00<00:00, 3595.59 examples/s]

In [5]:
dataset[:1]

{'chosen': [' A leather crafter is more likely to use these tools. The image shows various crafting tools, including scissors and a hole punch, which are commonly used in leatherworking projects. Leather is a material that requires cutting, shaping, and precise hole-punching techniques to create desired designs or patterns. In contrast, paper crafters typically use different types of tools, such as adhesives, decorative papers, or specialized cutting machines like the Silhouette Cameo, for their projects.<\\s> '],
 'rejected': [' A leather crafter is more likely to use these tools as they consist of a hole punch, scissors, and a knife. These items are typically used in crafting projects involving fabric or leather materials for various designs and patterns. Paper crafters may also benefit from some of these tools, but their primary focus would be on paper-related projects, which might require different types of tools such as paper cutters or scrapbooking supplies.<\\s> '],
 'images': [

In [6]:
# Train the model
training_args = DPOConfig(
    output_dir="llavaNextOutTry",
    bf16=False,
    gradient_checkpointing=True,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    num_train_epochs=1,
    dataset_num_proc=8,  # tokenization will use 32 processes
    dataloader_num_workers=32,  # data loading will use 32 workers
    logging_steps=10,
)

In [7]:
trainer = DPOTrainer(
    model,
    ref_model=None,  # not needed when using peft
    args=training_args,
    train_dataset=dataset,
    processing_class=processor,
    peft_config=LoraConfig(target_modules="all-linear", r=32),
)




okenizing train dataset (num_proc=8): 100%|██████████████████████████████████████████████████████████████████| 500/500 [00:07<00:00, 67.81 examples/s]

In [8]:
trainer.train()

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Step,Training Loss


KeyboardInterrupt: 