# Fine-tuning a Visual Language Model (VLM) using DPO

This notebook demonstrates how to fine-tune a Visual Language Model (VLM), specifically the Gemma 3-1B-it model, using the Direct Preference Optimization (DPO) algorithm.

The key steps involved are:

1.  **Setup and Installations**: Install necessary libraries and dependencies.
2.  **Model Loading**: Load the pre-trained Gemma 3-1B-it model.
3.  **LoRA Application**: Apply Low-Rank Adaptation (LoRA) to the model for efficient fine-tuning.
4.  **Data Loading and Preprocessing**: Load the RLAIF-V dataset and preprocess it for VLM training, including handling images and tokenizing text.
5.  **DPO Training**: Set up and run the DPO training loop to fine-tune the model based on preference data (chosen and rejected responses).
6.  **Logging and Visualization**: Log training metrics and visualize the training progress.

The goal is to train the VLM to better align with human preferences by optimizing directly on pairs of preferred and dispreferred responses.

In [None]:
# Environment setup: install the notebook's dependencies with pip.
# NOTE(review): none of these installs is version-pinned, so a fresh run may
# pick up different package versions than the ones this notebook was written
# against.
!pip install -q kagglehub

!pip install -q tensorflow
!pip install -q tensorboardX
!pip install -q grain
# tunix and qwix are installed straight from their GitHub repositories.
!pip install -q git+https://github.com/google/tunix
!pip install -q git+https://github.com/google/qwix

# Remove whatever flax is already installed and replace it with a build from
# the GitHub repository.
!pip uninstall -q -y flax
!pip install -q git+https://github.com/google/flax.git

!pip install -q huggingface_hub
!pip install -q datasets
!pip3 install jaxtyping

In [None]:
import functools
import gc
import os
from pprint import pprint
import re
import time

from flax import nnx
import grain
import humanize
import jax
import jax.numpy as jnp
import kagglehub
import optax
from orbax import checkpoint as ocp
import qwix
import tensorflow_datasets as tfds
from tqdm.auto import tqdm
import math
from tunix.generate import tokenizer_adapter as tokenizer_lib

from tunix.examples.data import translation_dataset as data_lib
from tunix.generate import sampler as sampler_lib
from tunix.generate.vlm_sampler import VLMSampler
from tunix.models.gemma3 import params as params_lib
from tunix.models.gemma3 import params_safetensors as params_safetensors_lib

from tunix.sft import metrics_logger
from datasets import load_dataset

from huggingface_hub import snapshot_download
from tunix.sft.dpo.dpo_trainer import _generate_ids_and_masks
from tunix.models.gemma3 import model as gemma3_model_lib
from datasets import concatenate_datasets

import types, json, os
import matplotlib.pyplot as plt
import numpy as np
from orbax import checkpoint as ocp
import optax
from tunix.sft.dpo.dpo_trainer import DpoTrainer, DpoTrainingConfig
from flax.serialization import to_state_dict
from datasets import load_dataset
import numpy as np, jax.numpy as jnp
from PIL import Image
from tunix.generate.utils import preprocess_image
from tunix.sft.dpo.dpo_trainer import TrainingInput

In [None]:
# ====== Model ======
GEMMA_TOKENIZER_PATH = "gs://gemma-data/tokenizers/tokenizer_gemma3.model"
model_id = "google/gemma-3-1b-it"
IMAGE_SIZE = 224
# ====== Data ======
TRAIN_DATA_DIR = "./data/train"
TEST_DATA_DIR = "./data/test"
TRAIN_FRACTION = 1.0

INTERMEDIATE_CKPT_DIR = "/content/intermediate_ckpt/"
# ====== LoRA ======
RANK = 32
ALPHA = 16.0

# ====== Sharding ======
MESH = [(1, 1), ("fsdp", "tp")]

# ====== Sequence lengths / sampling / DPO ======
MAX_PROMPT_LENGTH = 256
TOTAL_GENERATION_STEPS = 256
TEMPERATURE = 0.7
TOP_P = 1.0
TOP_K = 50
BETA = 0.1

# === AdamW, warmup, cosine scheduler ===
LEARNING_RATE = 5e-6
B1 = 0.9
B2 = 0.99
WEIGHT_DECAY = 0.1

# ====== Training loop ======
EVAL_EVERY_N_STEPS = 500
MAX_STEPS = 50000
BATCH_SIZE = 8
EPOCHS = 80

# == Cosine decay with warmup scheduler ==
# Intended schedule: linearly increase the learning rate from 0. to
# LEARNING_RATE over the first 10% of training steps, then decay it to 0 with
# a cosine schedule.
# A step count must be an integer; `0.1 * MAX_STEPS` alone is a float.
WARMUP_STEPS = int(round(0.1 * MAX_STEPS))
# == Grad clipping ==
# Grad clipping to prevent large gradients. Found this
# important to keep KL divergence in check.
MAX_GRAD_NORM = 0.1

# ====== Inference ======
GENERATION_CONFIGS = {
    # greedy search
    "greedy": {"temperature": 1e-4, "top_k": 1, "top_p": 1.0},
    # some randomness
    "standard": {"temperature": 0.7, "top_k": 50, "top_p": 0.95},
    # liberal
    "liberal": {"temperature": 0.85, "top_k": 2000, "top_p": 1.0},
}

In [None]:
# Authenticate with Hugging Face before the model snapshot is downloaded in
# the next cell.
!huggingface-cli login

In [None]:
# Download the `model_id` repository from the Hugging Face Hub, skipping the
# file patterns listed in `ignore_patterns`. `local_model_path` is the value
# returned by `snapshot_download` and is used as the checkpoint path later.
ignore_patterns = [
    "*.pth",  # Ignore PyTorch .pth weight files
]
print(f"Downloading {model_id} from Hugging Face...")
local_model_path = snapshot_download(
    repo_id=model_id, ignore_patterns=ignore_patterns
)
print(f"Model successfully downloaded to: {local_model_path}")

In [None]:
# Build the base Gemma 3 model from the downloaded safetensors checkpoint.
MODEL_CP_PATH = local_model_path

model_config = (
    gemma3_model_lib.Gemma3Config.gemma3_1b()
)  # pick the corresponding config based on the model version
# NOTE(review): MESH is re-bound here with the same value as in the config
# cell; keep the two definitions in sync (or drop one).
MESH = [(1, 1), ("fsdp", "tp")]
mesh = jax.make_mesh(*MESH)
with mesh:
  gemma3 = params_safetensors_lib.create_model_from_safe_tensors(
      MODEL_CP_PATH, model_config, mesh
  )
  nnx.display(gemma3)

In [None]:
# Tokenizer loaded from GEMMA_TOKENIZER_PATH; it is also used by the data
# pipeline below for encoding text and for the pad/eos ids.
gemma_tokenizer = tokenizer_lib.Tokenizer(tokenizer_path=GEMMA_TOKENIZER_PATH)

# Sampler wrapping the base (non-LoRA) model.
# NOTE(review): `vlm_sampler` is not referenced again in the visible cells.
vlm_sampler = VLMSampler(
    transformer=gemma3,
    tokenizer=gemma_tokenizer,
    image_size=IMAGE_SIZE,
)

In [None]:
def get_lora_model(base_model, mesh):
  """Wraps `base_model` with LoRA adapters and re-shards its state on `mesh`.

  Args:
    base_model: Model to adapt. It must expose `get_model_input()`, whose
      result is forwarded as keyword arguments to `qwix.apply_lora_to_model`.
    mesh: Mesh entered as a context manager while the sharding constraint is
      applied.

  Returns:
    The LoRA model, updated in place with the sharding-constrained state.
  """
  # Regex over module paths selecting which modules receive LoRA adapters.
  target_modules = (
      ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|"
      ".*attn_vec_einsum"
  )
  provider = qwix.LoraProvider(
      module_path=target_modules,
      rank=RANK,
      alpha=ALPHA,
  )

  example_inputs = base_model.get_model_input()
  lora_model = qwix.apply_lora_to_model(
      base_model, provider, **example_inputs
  )

  with mesh:
    lora_state = nnx.state(lora_model)
    constrained_state = jax.lax.with_sharding_constraint(
        lora_state, nnx.get_partition_spec(lora_state)
    )
    nnx.update(lora_model, constrained_state)

  return lora_model

In [None]:
# Policy model: the LoRA-wrapped version of the base model. The base `gemma3`
# itself is later passed to the trainer as the reference model.
lora_gemma = get_lora_model(gemma3, mesh=mesh)
nnx.display(lora_gemma)

In [None]:
# Training slice: the first 5000 rows of the RLAIF-V train split.
SPLIT = "train[:5000]"

ds = load_dataset("openbmb/RLAIF-V-Dataset", split=SPLIT)
# Keep only the columns the DPO pipeline reads; drop everything else.
cols = ["image", "question", "chosen", "rejected"]
ds = ds.remove_columns([c for c in ds.column_names if c not in cols])


def _pick_one_image(img_field):
  """Return a single PIL.Image from the dataset's image field."""
  x = img_field
  if isinstance(x, list):
    if not x:  # empty list, skip later
      return None
    x = x[0]
  if isinstance(x, Image.Image):
    return x.convert("RGB")
  arr = np.array(x)
  if arr.ndim == 3:
    return Image.fromarray(arr).convert("RGB")
  return None


def _pixels_for_image_field(image_field):
  """Returns preprocessed pixels for one example, or None without an image."""
  img = _pick_one_image(image_field)
  if img is None:
    return None
  # Add a leading batch axis of size 1 before calling preprocess_image.
  arr = np.array(img, dtype=np.uint8)[None, ...]
  # Presumably [1, S, S, 3] float32 with S == IMAGE_SIZE; verify against
  # tunix.generate.utils.preprocess_image.
  px = preprocess_image(jnp.asarray(arr), IMAGE_SIZE)
  return np.asarray(px[0])  # drop the leading batch axis again


def preprocess_item(ex):
  """Adds `pixel_values` to an example (or batch) and passes text through.

  The function is installed with `Dataset.with_transform`, which calls the
  transform with a *batch* (a dict of lists, also for single-row access) and
  then takes element 0 of every returned column to build a row. The output
  must therefore be a dict of lists in that case; returning row-shaped values
  makes `pixel_values` lose its leading axis and makes a `None` image fail.

  Args:
    ex: Either a single example (`ex["question"]` is a string) or a batch
      (`ex["question"]` is a list/tuple), with keys "image", "question",
      "chosen" and "rejected".

  Returns:
    A dict with keys "pixel_values", "question", "chosen" and "rejected".
    For a batch every value is a list with one entry per example; for a single
    example the values are the plain per-example objects. `pixel_values` is
    None for examples without a usable image.
  """
  is_batch = isinstance(ex["question"], (list, tuple))
  if not is_batch:
    return {
        "pixel_values": _pixels_for_image_field(ex["image"]),
        "question": ex["question"],
        "chosen": ex["chosen"],
        "rejected": ex["rejected"],
    }
  return {
      "pixel_values": [
          _pixels_for_image_field(image_field) for image_field in ex["image"]
      ],
      "question": list(ex["question"]),
      "chosen": list(ex["chosen"]),
      "rejected": list(ex["rejected"]),
  }


# Apply `preprocess_item` on access (when rows are read) rather than
# materializing the preprocessed pixels up front.
ds = ds.with_transform(preprocess_item)

# Special token ids reused by the padding/masking helpers below.
PAD = gemma_tokenizer.pad_id()
EOS = gemma_tokenizer.eos_id()


def _left_pad_np(ids, L, pad=PAD):
  """Returns `ids` as an int32 array of length `L`, padded on the left.

  Sequences longer than `L` keep only their last `L` tokens; shorter ones get
  `pad` prepended. `ids` must be a list.
  """
  if len(ids) > L:
    fitted = ids[-L:]
  else:
    missing = L - len(ids)
    fitted = [pad] * missing + ids
  return np.asarray(fitted, dtype=np.int32)


def _right_pad_np(ids, L, pad=PAD):
  """Returns `ids` as an int32 array of length `L`, padded on the right.

  Sequences longer than `L` keep only their first `L` tokens; shorter ones get
  `pad` appended. `ids` must be a list.
  """
  clipped = ids[:L]
  filler = [pad] * (L - len(clipped))
  return np.asarray(clipped + filler, dtype=np.int32)


def _make_mask(ids, pad=PAD):
  """Returns an int32 mask with 1 where `ids` differs from `pad`, else 0."""
  is_token = ids != pad
  return is_token.astype(np.int32)


def _collate_vlm_batch(examples):
  """Tokenizes, pads and stacks preprocessed examples into a TrainingInput.

  Prompts are left-padded/truncated to MAX_PROMPT_LENGTH. Chosen and rejected
  answers get EOS appended and are right-padded/truncated to
  TOTAL_GENERATION_STEPS. Masks are 1 wherever the id differs from PAD.
  """
  q_tok = [gemma_tokenizer.encode(ex["question"]) for ex in examples]
  ch_tok = [gemma_tokenizer.encode(ex["chosen"]) + [EOS] for ex in examples]
  rj_tok = [gemma_tokenizer.encode(ex["rejected"]) + [EOS] for ex in examples]

  prompt_ids = np.stack(
      [_left_pad_np(ids, MAX_PROMPT_LENGTH) for ids in q_tok], axis=0
  )
  chosen_ids = np.stack(
      [_right_pad_np(ids, TOTAL_GENERATION_STEPS) for ids in ch_tok], axis=0
  )
  rejected_ids = np.stack(
      [_right_pad_np(ids, TOTAL_GENERATION_STEPS) for ids in rj_tok], axis=0
  )
  pixels = np.stack([ex["pixel_values"] for ex in examples], axis=0).astype(
      np.float32
  )

  # The mask is element-wise, so it can be computed on the stacked arrays.
  return TrainingInput(
      prompt_ids=jnp.asarray(prompt_ids),
      prompt_mask=jnp.asarray(_make_mask(prompt_ids, PAD)),
      chosen_ids=jnp.asarray(chosen_ids),
      chosen_mask=jnp.asarray(_make_mask(chosen_ids, PAD)),
      rejected_ids=jnp.asarray(rejected_ids),
      rejected_mask=jnp.asarray(_make_mask(rejected_ids, PAD)),
      pixel_values=jnp.asarray(pixels),
  )


def numpy_batches_vlm(dataset, batch_size=1, shuffle=True, seed=0, epochs=None):
  """Yields `TrainingInput` batches built from `dataset`.

  Args:
    dataset: Indexable dataset whose rows have "pixel_values", "question",
      "chosen" and "rejected". Rows whose "pixel_values" is None are skipped.
    batch_size: Number of usable examples per batch. The last batch of an
      epoch may be smaller.
    shuffle: Whether to shuffle the row order at the start of every epoch.
    seed: Seed of the shuffling RNG (unused when `shuffle` is False).
    epochs: Number of passes over the dataset; None repeats indefinitely.

  Yields:
    One `TrainingInput` per batch.
  """
  rng = np.random.default_rng(seed)
  epoch = 0
  while True:
    idx = np.arange(len(dataset))
    if shuffle:
      rng.shuffle(idx)

    yielded_any = False
    buf = []
    for i in idx:
      ex = dataset[int(i)]
      if ex["pixel_values"] is None:
        continue
      buf.append(ex)
      if len(buf) == batch_size:
        yield _collate_vlm_batch(buf)
        yielded_any = True
        buf = []

    # Flush the trailing partial batch of this epoch.
    if buf:
      yield _collate_vlm_batch(buf)
      yielded_any = True

    epoch += 1
    if epochs is not None and epoch >= epochs:
      break
    if not yielded_any:
      # No usable example in a whole pass: with epochs=None the loop would
      # otherwise spin forever without ever yielding.
      break


# Smoke one batch: pull a single batch of 4 from a fresh generator and print
# the shape of every field to sanity-check the pipeline before training.

b0 = next(numpy_batches_vlm(ds, batch_size=4))
print("Batch pixels:", b0.pixel_values.shape, "| B:", b0.prompt_ids.shape[0])
print("Batch prompt_ids:", b0.prompt_ids.shape)
print("Batch prompt_mask:", b0.prompt_mask.shape)
print("Batch chosen_ids:", b0.chosen_ids.shape)
print("Batch chosen_mask:", b0.chosen_mask.shape)
print("Batch rejected_ids:", b0.rejected_ids.shape)
print("Batch rejected_mask:", b0.rejected_mask.shape)

print("Dataset size:", len(ds))

In [None]:
# --- Eval helpers (no grads) ---


def _make_pos_and_causal_mask(tokens: jnp.ndarray, pad_id: int):
  valid = tokens != pad_id
  positions = (jnp.cumsum(valid.astype(jnp.int32), axis=1) - 1) * valid.astype(
      jnp.int32
  )
  L = tokens.shape[1]
  causal = jnp.tril(jnp.ones((L, L), dtype=bool))
  attn_mask = valid[..., None] & valid[:, None, :] & causal[None, ...]
  return positions.astype(jnp.int32), attn_mask


def _vlm_forward_and_cache(
    model_mod, *, tokens: jnp.ndarray, pixel_values: jnp.ndarray, pad_id: int
):
  """Runs one forward pass of `model_mod` and returns only the logits.

  Positions and the attention mask are derived from `tokens` and `pad_id`.
  The model is called with an empty cache each time and must return a
  2-tuple; its second element is discarded.
  """
  positions, attn_mask = _make_pos_and_causal_mask(tokens, pad_id)
  call_kwargs = dict(
      last_tokens=tokens.astype(jnp.int32),
      positions=positions,
      cache={},  # fresh cache per call
      attention_mask=attn_mask,
      pixel_values=pixel_values.astype(jnp.float32),
      output_hidden_states=False,
  )
  logits, _ = model_mod(**call_kwargs)
  return logits


def _seq_logprob_answer(model_mod, px, prompt_ids, answer_ids, pad_id: int):
  """Sums the log-probabilities `model_mod` assigns to the answer tokens.

  The model input is the prompt followed by the answer without its last
  token (teacher forcing); the last `answer_ids.shape[1]` logit positions are
  scored against `answer_ids`. Positions where the answer equals `pad_id`
  contribute zero.

  Returns:
    One summed log-probability per batch row.
  """
  answer_len = answer_ids.shape[1]
  teacher_forced = jnp.concatenate([prompt_ids, answer_ids[:, :-1]], axis=1)
  all_logits = _vlm_forward_and_cache(
      model_mod, tokens=teacher_forced, pixel_values=px, pad_id=pad_id
  )

  log_probs = jax.nn.log_softmax(all_logits[:, -answer_len:, :], axis=-1)
  gathered = jnp.take_along_axis(log_probs, answer_ids[..., None], axis=-1)
  token_logp = gathered[..., 0]

  keep = (answer_ids != pad_id).astype(token_logp.dtype)
  return (token_logp * keep).sum(axis=1)  # [B]


def _eval_batch_metrics(policy_mod, ref_mod, batch, pad_id: int, beta: float):
  """Computes forward-only DPO metrics for one batch.

  Args:
    policy_mod: Model being evaluated.
    ref_mod: Reference model.
    batch: Object with `pixel_values`, `prompt_ids`, `chosen_ids` and
      `rejected_ids`.
    pad_id: Padding token id.
    beta: Scale applied to the policy-minus-reference margin in the loss.

  Returns:
    A dict of Python floats: "loss", "acc", "pol_margin_mean",
    "ref_margin_mean", "adv_mean", "num_correct" and "batch_size". An example
    counts as correct when the policy scores chosen above rejected.
  """
  pixels = jnp.asarray(batch.pixel_values, dtype=jnp.float32)
  prompt, chosen, rejected = (
      jnp.asarray(ids, dtype=jnp.int32)
      for ids in (batch.prompt_ids, batch.chosen_ids, batch.rejected_ids)
  )

  def seq_logp(model, answer):
    return _seq_logprob_answer(model, pixels, prompt, answer, pad_id)

  policy_chosen = seq_logp(policy_mod, chosen)
  policy_rejected = seq_logp(policy_mod, rejected)
  ref_chosen = seq_logp(ref_mod, chosen)
  ref_rejected = seq_logp(ref_mod, rejected)

  pol_margin = policy_chosen - policy_rejected  # [B]
  ref_margin = ref_chosen - ref_rejected  # [B]
  adv = pol_margin - ref_margin  # [B]
  # DPO loss (forward-only)
  loss = -jax.nn.log_sigmoid(beta * adv).mean()

  prefers_chosen = pol_margin > 0
  metrics = {}
  metrics["loss"] = float(loss)
  metrics["acc"] = float(prefers_chosen.mean())
  metrics["pol_margin_mean"] = float(pol_margin.mean())
  metrics["ref_margin_mean"] = float(ref_margin.mean())
  metrics["adv_mean"] = float(adv.mean())
  metrics["num_correct"] = float(prefers_chosen.sum())
  metrics["batch_size"] = float(pol_margin.shape[0])
  return metrics


def _mean_ci(xs):
  xs = [float(x) for x in xs]
  if not xs:
    return (float("nan"), float("nan"))
  m = sum(xs) / len(xs)
  var = sum((x - m) ** 2 for x in xs) / max(1, len(xs) - 1)
  se = math.sqrt(var / max(1, len(xs)))
  return m, 1.96 * se

In [None]:
def run_eval_vlm(
    policy_mod,
    ref_mod,
    dataset,
    *,
    batches=50,
    batch_size=16,
    seed=7,
    beta=BETA,
    pad_id=PAD,
):
  """Evaluates DPO metrics on up to `batches` shuffled batches of `dataset`.

  Batches come from a single shuffled pass over `dataset`, so evaluation
  stops early when the data runs out. Per-batch metrics are aggregated into
  means with 95% intervals across batches; the cumulative accuracy is
  computed over individual examples. A report is printed.

  Returns:
    The summary dict that the printed report is based on.
  """
  batch_iter = numpy_batches_vlm(
      dataset, batch_size=batch_size, shuffle=True, seed=seed, epochs=1
  )
  per_batch = {
      "loss": [],
      "acc": [],
      "pol_margin_mean": [],
      "ref_margin_mean": [],
      "adv_mean": [],
  }
  tot_corr = 0.0
  tot_seen = 0.0

  # zip stops at whichever runs out first: the batch budget or the data.
  for _, batch in zip(range(batches), batch_iter):
    out = _eval_batch_metrics(
        policy_mod, ref_mod, batch, pad_id=pad_id, beta=beta
    )
    for metric_name, series in per_batch.items():
      series.append(out[metric_name])
    tot_corr += out["num_correct"]
    tot_seen += out["batch_size"]

  loss_m, loss_ci = _mean_ci(per_batch["loss"])
  acc_m, acc_ci = _mean_ci(per_batch["acc"])
  pol_m, pol_ci = _mean_ci(per_batch["pol_margin_mean"])
  ref_m, ref_ci = _mean_ci(per_batch["ref_margin_mean"])
  adv_m, adv_ci = _mean_ci(per_batch["adv_mean"])

  summary = {
      "batches_evaluated": len(per_batch["loss"]),
      "examples_evaluated": int(tot_seen),
      "accuracy_mean": acc_m,
      "accuracy_95ci": acc_ci,
      "cumulative_accuracy": tot_corr / max(1.0, tot_seen),
      "loss_mean": loss_m,
      "loss_95ci": loss_ci,
      "policy_margin_mean": pol_m,
      "policy_margin_95ci": pol_ci,
      "ref_margin_mean": ref_m,
      "ref_margin_95ci": ref_ci,
      "advantage_mean": adv_m,
      "advantage_95ci": adv_ci,
      "reward_margin_mean (beta*adv)": beta * adv_m,
  }

  print("\n=== DPO Eval (VLM) ===")
  print(
      f"batches={summary['batches_evaluated']} "
      f" examples={summary['examples_evaluated']}"
  )
  print(
      f"Accuracy (mean±95% CI): {acc_m:.3f} ± {acc_ci:.3f} | Cumulative:"
      f" {summary['cumulative_accuracy']:.3f}"
  )
  print(f"Loss     (mean±95% CI): {loss_m:.4f} ± {loss_ci:.4f}")
  print(f"Pol Δ    (mean±95% CI): {pol_m:+.3f} ± {pol_ci:.3f}")
  print(f"Ref Δ    (mean±95% CI): {ref_m:+.3f} ± {ref_ci:.3f}")
  print(
      f"Adv      (mean±95% CI): {adv_m:+.3f} ± {adv_ci:.3f}  (reward ≈ β*adv)"
  )
  return summary

In [None]:
# Make a held-out slice (different from SPLIT used for training): rows
# 5000-5399 of the train split, reduced to the same four columns and wrapped
# with the same on-access preprocessing as the training set.
ds_eval = load_dataset("openbmb/RLAIF-V-Dataset", split="train[5000:5400]")
ds_eval = ds_eval.remove_columns([
    c
    for c in ds_eval.column_names
    if c not in ["image", "question", "chosen", "rejected"]
])
ds_eval = ds_eval.with_transform(preprocess_item)

# NOTE(review): the evaluation call below is disabled, so `ds_eval` and
# `run_eval_vlm` are never exercised by this notebook as written.
# with mesh:
#     eval_summary = run_eval_vlm(lora_gemma, gemma3, ds_eval, batches=50, batch_size=16, seed=11)

In [None]:
# NOTE(review): this re-binds INTERMEDIATE_CKPT_DIR, replacing the value set
# in the config cell ("/content/intermediate_ckpt/").
INTERMEDIATE_CKPT_DIR = "/content/intermediate_ckpt_vlm"
os.makedirs(INTERMEDIATE_CKPT_DIR, exist_ok=True)
# JSON file that the training history is written to after training.
HIST_PATH = os.path.join(INTERMEDIATE_CKPT_DIR, "train_history.json")

# Per-step training metrics, appended to by the patched trainer hook below.
# All lists are index-aligned: entry i of every list belongs to the same step.
HISTORY = {
    "step": [],
    "loss": [],
    "rewards/chosen": [],
    "rewards/rejected": [],
    "rewards/margin": [],
    "rewards/accuracy": [],
}

In [None]:
config = DpoTrainingConfig(
    eval_every_n_steps=EVAL_EVERY_N_STEPS,
    max_steps=MAX_STEPS,
    beta=BETA,
    label_smoothing=0.0,
)

# Build the optimizer the config cell describes: AdamW with the declared
# betas and weight decay, a linear-warmup + cosine-decay learning rate, and
# global-norm gradient clipping. A bare `optax.adamw(LEARNING_RATE)` silently
# ignores B1, B2, WEIGHT_DECAY, WARMUP_STEPS and MAX_GRAD_NORM.
lr_schedule = optax.schedules.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=LEARNING_RATE,
    warmup_steps=int(WARMUP_STEPS),  # step counts must be integers
    decay_steps=MAX_STEPS,
    end_value=0.0,
)
optimizer = optax.adamw(
    learning_rate=lr_schedule,
    b1=B1,
    b2=B2,
    weight_decay=WEIGHT_DECAY,
)
if MAX_GRAD_NORM is not None:
  # Clip first so AdamW sees the clipped gradients.
  optimizer = optax.chain(
      optax.clip_by_global_norm(max_norm=MAX_GRAD_NORM),
      optimizer,
  )

# NOTE(review): shuffle=False makes every epoch visit the data in the same
# order, and `seed` has no effect in that case.
train_batches = numpy_batches_vlm(
    ds, batch_size=BATCH_SIZE, shuffle=False, seed=42, epochs=EPOCHS
)

In [None]:
# Create the DPO trainer: the LoRA model is the policy being trained and the
# base model serves as the reference.
with mesh:
  trainer = DpoTrainer(
      model=lora_gemma,
      ref_model=gemma3,
      optimizer=optimizer,
      training_config=config,
  )

In [None]:
# NOTE(review): this cell is an incomplete duplicate of the next cell. The
# function below only reads the step counter and the latest loss into locals
# and then discards them; nothing in this cell installs it on the trainer,
# and both names defined here are re-bound by the following cell. Consider
# deleting this cell.
_orig_post = getattr(trainer, "_post_process_train_step", None)


def _patched_post_process_train_step(self, aux):
  if _orig_post is not None:
    _orig_post(aux)

  s = int(getattr(self, "_train_steps", 0))

  loss_val = float("nan")
  bm = getattr(self, "_buffered_train_metrics", None)
  if bm is not None and getattr(bm, "losses", None):
    try:
      loss_val = float(bm.losses[-1])
    except Exception:
      pass

In [None]:
# Resolve the trainer's *unpatched* hook. If this cell has already been run,
# the installed hook is our own patch, which carries the original under
# `_wrapped_post_process`. Wrapping the patch itself would make the hook call
# itself without end on the next train step.
_current_post = getattr(trainer, "_post_process_train_step", None)
_orig_post = getattr(_current_post, "_wrapped_post_process", _current_post)


def _make_history_hook(orig_post):
  """Builds a train-step hook that records metrics into HISTORY.

  Args:
    orig_post: The trainer's original bound hook, called first with `aux`;
      may be None when the trainer has no such hook.

  Returns:
    A function `(self, aux)` suitable for binding to the trainer. `aux` is
    indexed with the "rewards/..." keys recorded below.
  """

  def _hook(self, aux):
    if orig_post is not None:
      orig_post(aux)

    s = int(getattr(self, "_train_steps", 0))

    # Best effort: the loss is read from the trainer's buffered metrics and
    # stays NaN when it is unavailable or cannot be converted.
    loss_val = float("nan")
    bm = getattr(self, "_buffered_train_metrics", None)
    if bm is not None and getattr(bm, "losses", None):
      try:
        loss_val = float(bm.losses[-1])
      except Exception:
        pass

    HISTORY["step"].append(s)
    HISTORY["loss"].append(loss_val)
    HISTORY["rewards/chosen"].append(float(aux["rewards/chosen"]))
    HISTORY["rewards/rejected"].append(float(aux["rewards/rejected"]))
    HISTORY["rewards/margin"].append(float(aux["rewards/margin"]))
    HISTORY["rewards/accuracy"].append(float(aux["rewards/accuracy"]))

    if s % EVAL_EVERY_N_STEPS == 0:
      print(
          "[metric]"
          f" step={s} loss={loss_val:.4f} margin={float(aux['rewards/margin']):.4f}"
      )

  # Marker read back above when the cell is re-run (bound methods forward
  # attribute reads to the underlying function).
  _hook._wrapped_post_process = orig_post
  return _hook


_patched_post_process_train_step = _make_history_hook(_orig_post)
trainer._post_process_train_step = types.MethodType(
    _patched_post_process_train_step, trainer
)
print("✅ metrics capture patched on current trainer")

In [None]:
# Run DPO training on the batch generator; the patched hook fills HISTORY
# as steps complete.
with mesh:
  trainer.train(train_batches)

In [None]:
# optional: persist the collected per-step metrics as JSON so they can be
# re-plotted without re-running training.
with open(HIST_PATH, "w") as f:
  json.dump(HISTORY, f)
print("📈 history saved to:", HIST_PATH, "| points:", len(HISTORY["step"]))


# --- plotting ---
def _safe_xy(hist, key):
  x = np.array(hist.get("step", []), dtype=float)
  y = np.array(hist.get(key, []), dtype=float)
  return x, y


def moving_average(data, window_size):
  """Smooths `data` with an unweighted sliding window of `window_size`.

  Returns the input object unchanged when it has fewer than `window_size`
  elements; otherwise an array of `len(data) - window_size + 1` window means.
  """
  if window_size > len(data):
    # Not enough points for a single window: hand the data back as-is.
    return data
  kernel = np.ones(window_size) / window_size
  return np.convolve(data, kernel, mode="valid")


def _plot_series(x, y, title, ylabel, window_size=5):
  """Plots a moving-average-smoothed series against its steps.

  Args:
    x: Step values, index-aligned with `y`.
    y: Metric values.
    title: Figure title; also used in the "no data" message.
    ylabel: Label of the y axis.
    window_size: Width of the moving-average window.
  """
  if len(x) == 0:
    print(f"[plot] no data for {title}")
    return
  plt.figure()
  # Apply moving average
  y_smooth = moving_average(y, window_size)
  # Align x with the smoothed series by length. With at least `window_size`
  # points this is x[window_size - 1:]; with fewer points moving_average
  # returns y unchanged, and slicing x by the window size would leave x and y
  # with different lengths and make plt.plot fail.
  x_smooth = x[len(x) - len(y_smooth) :]
  plt.plot(x_smooth, y_smooth)
  plt.title(title)
  plt.xlabel("step")
  plt.ylabel(ylabel)
  plt.grid(True)
  plt.show()


# One smoothed figure per recorded metric.
x, y = _safe_xy(HISTORY, "loss")
_plot_series(x, y, "Training Loss (Smoothed)", "loss")

x, y = _safe_xy(HISTORY, "rewards/margin")
_plot_series(x, y, "Rewards Margin (chosen - rejected) (Smoothed)", "margin")

# Chosen and rejected rewards share one figure so they can be compared.
x, ch = _safe_xy(HISTORY, "rewards/chosen")
_, rj = _safe_xy(HISTORY, "rewards/rejected")
if len(x):
  plt.figure()
  window_size = 10
  ch_smooth = moving_average(ch, window_size)
  rj_smooth = moving_average(rj, window_size)
  # Align x with each smoothed series by length: with fewer than
  # `window_size` points moving_average returns its input unchanged, so
  # slicing x by the window size would produce mismatched lengths.
  plt.plot(x[len(x) - len(ch_smooth) :], ch_smooth, label="chosen (Smoothed)")
  plt.plot(x[len(x) - len(rj_smooth) :], rj_smooth, label="rejected (Smoothed)")
  plt.title("Chosen vs Rejected Rewards (Smoothed)")
  plt.xlabel("step")
  plt.ylabel("reward")
  plt.legend()
  plt.grid(True)
  plt.show()

x, y = _safe_xy(HISTORY, "rewards/accuracy")
_plot_series(x, y, "Rewards Accuracy (Smoothed)", "accuracy")