In [None]:
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, TrainingArguments
from pyreft import ReftConfig, LoreftIntervention, get_reft_model, ReftTrainerForCausalLM
from utils import make_last_position_supervised_data_module

# load model
model = AutoModelForCausalLM.from_pretrained("AIDC-AI/Ovis2-4B",
                                             torch_dtype=torch.bfloat16,
                                             multimodal_max_length=8192,
                                             trust_remote_code=True)
text_tokenizer = model.get_text_tokenizer()
visual_tokenizer = model.get_visual_tokenizer()
conversation_formatter = model.get_conversation_formatter()


SYS_PROMPT="<image>%s"

nnsight is not detected. Please install via 'pip install nnsight' for nnsight backend.


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.48, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [2]:
cfg = ReftConfig(
    representations={
        "component": "llm.model.layers[0].mlp.output",
        "low_rank_dimension": 64,
        "intervention":LoreftIntervention(
            embed_dim=2048,
            low_rank_dimension=4
        )
    }
)
reft_model = get_reft_model(model, cfg)
reft_model.set_device('cuda')
reft_model.print_trainable_parameters()

Intervention key: comp_llm_model_layers[0]_mlp_output_unit_pos_nunit_1#0
trainable intervention params: 16,388 || trainable model params: 0
model params: 4,304,941,558 || trainable%: 0.00038067880316622875


In [3]:
training_examples = [
    ["dataset/American-Water.png", "Whats the closest object in the image? Describe how it looks.", "🌊🧑‍🚒🚒🔧  🇺🇸🎨"],
    ["dataset/beach_dog.jpg", "Whats the animal in the image and what is it doing?", "🐶🌊🏖️🏃‍♂️🟠👄"],
    ["dataset/ClockTower.png", "Describe the buildings in picture.", "🕰️🗼🏰✨🏢🌆"],
    ["dataset/Elephants.png", "Which animal is in the picture and how many?", "🐘🐘🐘👶❤️ 🌳🌿🚶‍♂️"],
    ["dataset/fox.jpg", "What's the outfit of the player and which sports? Whats the player doing?", "🏀👕⚪👑➄ ✋🟠👀👥"],
    ["dataset/lacrose.jpg", "Which Sport is this? Whats written on the shirt?", "🥍🏃‍♂️💨👕🔴⚪🏙️"],
    ["dataset/Sea-Surfing.png", "Describe the whole scene", "🌊🌥️🧍‍♂️🏄‍♂️🛶🎡🎠🎢"],
    ["dataset/test2.jpg", "Why is cat dressed up like that?", "🐱💰💵💎📿🕴️🍣✨"],
    ["dataset/Dog-Red-Bucket.png", "What is the dog doing?", "🐕🖤 🔴🪣👄🏃💨"]
]

data_module = make_last_position_supervised_data_module(
    tokenizer=text_tokenizer, model=model, 
    images=[Image.open(s[0]) for s in training_examples],
    inputs=[SYS_PROMPT%s[1] for s in training_examples],
    outputs=[s[2] for s in training_examples]
)

In [4]:
from pyreft import ReftTrainerForCausalLM
import pyvene as pv
class MultiModalReftTrainerForCausalLM(ReftTrainerForCausalLM):
    def compute_loss(
        self,
        intervenable: pv.IntervenableModel,
        inputs,
        return_outputs=False,
        **kwargs
    ):
        # run intervened forward pass
        unit_locations = None
        if "intervention_locations" in inputs:
            if inputs["intervention_locations"].dim() == 3:
                unit_locations={"sources->base": (
                    None,
                    inputs["intervention_locations"].permute(1, 0, 2).tolist()
                )}
            else:
                # this is dummy for lora only baseline
                unit_locations={"sources->base": (None, 0)}
        base_outputs, cf_outputs = intervenable(
            {
                "input_ids": inputs["input_ids"],
                "attention_mask": inputs["attention_mask"],
                "pixel_values": inputs["pixel_values"] # For VLM
            },
            unit_locations=unit_locations,
            labels=inputs["labels"],
            subspaces=inputs["subspaces"].permute(1, 0, 2).tolist() if "subspaces" in inputs else None
        )
        # return
        output = cf_outputs
        if cf_outputs is None:
            output = base_outputs # in case of lora only training

        return (output, output) if return_outputs else output.loss

In [5]:
training_args = TrainingArguments(
    num_train_epochs=100.0, output_dir="./tmp", per_device_train_batch_size=10, 
    learning_rate=4e-3, logging_steps=20, report_to="none")
trainer = MultiModalReftTrainerForCausalLM(
    model=reft_model, tokenizer=text_tokenizer, args=training_args, **data_module)
_ = trainer.train()

  trainer = MultiModalReftTrainerForCausalLM(


ValueError: You are attempting to perform batched generation with padding_side='right' this may lead to unexpected behaviour for Flash Attention version of Qwen2. Make sure to  call `tokenizer.padding_side  = 'left'` before tokenizing the input. 