# Demo with a vlm dataset (openbmb/RLAIF-V-Dataset)

In [1]:
!pip install -q kagglehub

!pip install -q tensorflow
!pip install -q tensorboardX
!pip install -q grain
# !pip install -q git+https://github.com/google/tunix
!pip install -q git+https://github.com/google/qwix

!pip uninstall -q -y flax
!pip install -q git+https://github.com/google/flax.git

!pip install -q huggingface_hub
!pip install -q datasets

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Building wheel for flax (pyproject.toml) ... [?25l[?25hdone


In [2]:
%cd /content
!git clone --branch uiuc-vlm --single-branch https://github.com/PLAN-Lab/tunix.git


/content
fatal: destination path 'tunix' already exists and is not an empty directory.


In [3]:
%cd /content/tunix
!pip -q install -e .

/content/tunix
  Installing build dependencies ... [?25l[?25hdone
  Checking if build backend supports build_editable ... [?25l[?25hdone
  Getting requirements to build editable ... [?25l[?25hdone
  Preparing editable metadata (pyproject.toml) ... [?25l[?25hdone
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Building editable for tunix (pyproject.toml) ... [?25l[?25hdone


In [4]:
# --- autoreload for dev loop ---
import sys, types, importlib
if 'imp' not in sys.modules:
    imp = types.ModuleType('imp')
    imp.reload = importlib.reload
    sys.modules['imp'] = imp

%load_ext autoreload
%autoreload 2

# sanity check import
import tunix
print("tunix imported from:", tunix.__file__)



tunix imported from: /content/tunix/tunix/__init__.py


In [5]:
import functools
import gc
import os
from pprint import pprint
import re
import time

from flax import nnx
import grain
import humanize
import jax
import jax.numpy as jnp
import kagglehub
import optax
from orbax import checkpoint as ocp
import qwix
import tensorflow_datasets as tfds
from tqdm.auto import tqdm

from tunix.examples.data import translation_dataset as data_lib
from tunix.generate import sampler as sampler_lib
from tunix.generate.vlm_sampler import VLMSampler
from tunix.models.gemma3 import params as params_lib
from tunix.models.gemma3 import params_safetensors as params_safetensors_lib

from tunix.sft import metrics_logger
from datasets import load_dataset
from tunix.sft.dpo.dpo_trainer import DpoTrainingConfig
from tunix.sft.dpo.dpo_trainer import DpoTrainer
from tunix.sft.dpo.dpo_trainer import TrainingInput
from huggingface_hub import snapshot_download
from tunix.sft.dpo.dpo_trainer import _generate_ids_and_masks
from tunix.models.gemma3 import model as gemma3_model_lib
from datasets import concatenate_datasets

In [6]:
# ====== Data ======
TRAIN_DATA_DIR = "./data/train"
TEST_DATA_DIR = "./data/test"
TRAIN_FRACTION = 1.0

INTERMEDIATE_CKPT_DIR = "/content/intermediate_ckpt/"
# ====== LoRA ======
RANK = 8
ALPHA = 16.0

# ====== Sharding ======
MESH = [(1, 1), ("fsdp", "tp")]

MAX_PROMPT_LENGTH = 192
TOTAL_GENERATION_STEPS = 192
TEMPERATURE = 0.7
TOP_P = 1.0
TOP_K = 50
BETA = 0.1

# === AdamW, warmup, cosine scheduler ===
LEARNING_RATE = 3e-6
B1 = 0.9
B2 = 0.99
WEIGHT_DECAY = 0.1

# == Cosine decay with warmup scheduler ==
# Linearly increase learning rate from 0. to 5e-6 in the first 10% training
# steps, and then gradually decrease the learning rate to 0 using cosine
# scheduler.
BATCH_SIZE = 1
NUM_BATCHES = 512
NUM_TEST_BATCHES = 100
EVAL_EVERY_N_STEPS = 100

NUM_EPOCHS = 1  # can potentially train for more epochs
TRAIN_FRACTION = 1.0
MAX_STEPS = int(NUM_BATCHES * TRAIN_FRACTION * NUM_EPOCHS)

WARMUP_STEPS = 0.1 * MAX_STEPS
# == Grad clipping ==
# Grad clipping to prevent large gradients. Found this
# important to keep KL divergence in check.
MAX_GRAD_NORM = 0.1

# ====== Inference ======
GENERATION_CONFIGS = {
    # greedy search
    "greedy": {"temperature": 1e-4, "top_k": 1, "top_p": 1.0},
    # some randomness
    "standard": {"temperature": 0.7, "top_k": 50, "top_p": 0.95},
    # liberal
    "liberal": {"temperature": 0.85, "top_k": 2000, "top_p": 1.0},
}

In [None]:

from huggingface_hub import login, HfApi
import os

HF_TOKEN = "xxxxxxxxxx"  # <-- paste your Read token



In [8]:
model_id = "google/gemma-3-1b-it"
ignore_patterns = [
    "*.pth",  # Ignore PyTorch .pth weight files
]
print(f"Downloading {model_id} from Hugging Face...")
local_model_path = snapshot_download(
    repo_id=model_id, ignore_patterns=ignore_patterns, token=HF_TOKEN
)
print(f"Model successfully downloaded to: {local_model_path}")

Downloading google/gemma-3-1b-it from Hugging Face...


Fetching 10 files:   0%|          | 0/10 [00:00<?, ?it/s]

Model successfully downloaded to: /root/.cache/huggingface/hub/models--google--gemma-3-1b-it/snapshots/dcc83ea841ab6100d6b47a070329e1ba4cf78752


In [9]:
MODEL_CP_PATH = local_model_path

model_config = (
    gemma3_model_lib.Gemma3Config.gemma3_1b()
)  # pick correponding config based on model version
MESH = [(1, 1), ("fsdp", "tp")]
mesh = jax.make_mesh(*MESH)
with mesh:
  gemma3 = params_safetensors_lib.create_model_from_safe_tensors(
      MODEL_CP_PATH, model_config, mesh
  )
  nnx.display(gemma3)

In [10]:
# ==== 3) Load only the tokenizer first (cheap), NOT the whole processor yet ====
from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained("google/paligemma-3b-mix-224", token=HF_TOKEN)
print("Tokenizer loaded. pad_id:", tok.pad_token_id, "eos_id:", tok.eos_token_id)

# Adapter that your sampler expects
class HFTokenizerAdapter:
    def __init__(self, hf_tok):
        self.tok = hf_tok
        self._pad_id = self.tok.pad_token_id if self.tok.pad_token_id is not None else self.tok.eos_token_id
        self._eos_id = self.tok.eos_token_id
    def encode(self, s: str):
        return self.tok(s, add_special_tokens=False)["input_ids"]
    def decode(self, ids):
        return self.tok.decode(ids, skip_special_tokens=True)
    def pad_id(self) -> int: return int(self._pad_id)
    def eos_id(self) -> int: return int(self._eos_id)

gemma_tokenizer = HFTokenizerAdapter(tok)
print("Tokenizer adapter ready ✔️")

image_size = 224
print("Processor loaded. image_size:", image_size)

# ==== 5) Construct the sampler LAST (avoid touching model state in __init__) ====
# IMPORTANT: gemma3 must be your PaLI-Gemma nnx Module (not a text-only Gemma).
# If it's named differently, change transformer=... accordingly.
vlm_sampler = VLMSampler(
    transformer=gemma3,
    tokenizer=gemma_tokenizer,
    image_size=image_size,
)
print("VLMSampler ready. pad_id:", vlm_sampler.pad_id(), "eos_id:", vlm_sampler.eos_id())

# --- (Optional) 10-second smoke test with a dummy image ---
import numpy as np, jax.numpy as jnp
dummy = np.zeros((1, image_size, image_size, 3), dtype=np.uint8)
out = vlm_sampler(
    input_strings=["Describe the image:"],
    images=jnp.asarray(dummy),
    max_generation_steps=8,
    temperature=0.0,
    return_logits=False,
    echo=False,
)
print(out.text[0])



Tokenizer loaded. pad_id: 0 eos_id: 1
Tokenizer adapter ready ✔️
Processor loaded. image_size: 224
VLMSampler ready. pad_id: 0 eos_id: 1
:::DescribeDescribeDescribeDescribeeltas


In [11]:
def get_lora_model(base_model, mesh):
  lora_provider = qwix.LoraProvider(
      module_path=(
          ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|"
          ".*attn_vec_einsum"
      ),
      rank=RANK,
      alpha=ALPHA,
      #weight_qtype="nf4",
      #tile_size=4,
  )

  model_input = base_model.get_model_input()
  lora_model = qwix.apply_lora_to_model(
      base_model, lora_provider, **model_input
  )

  with mesh:
    state = nnx.state(lora_model)
    pspecs = nnx.get_partition_spec(state)
    sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
    nnx.update(lora_model, sharded_state)

  return lora_model

In [12]:
# Policy model
lora_gemma = get_lora_model(gemma3, mesh=mesh)
nnx.display(lora_gemma)

In [13]:
from datasets import load_dataset
import numpy as np, jax.numpy as jnp
from PIL import Image
from tunix.models.siglip import preprocess as siglip_pp
from transformers import AutoTokenizer # Import AutoTokenizer

# ==== 3) Load only the tokenizer first (cheap), NOT the whole processor yet ====
# Assuming gemma_tokenizer is already defined and is an instance of HFTokenizerAdapter
# If not, you would need to load it here or in a previous cell.
# from transformers import AutoTokenizer
# tok = AutoTokenizer.from_pretrained("google/paligemma-3b-mix-224", token=HF_TOKEN)
# class HFTokenizerAdapter: ... # Define the adapter if not already defined
# gemma_tokenizer = HFTokenizerAdapter(tok)


SPLIT = "train[:64]"
MAX_PROMPT_LEN = 128
IMAGE_SIZE = 224

ds = load_dataset("openbmb/RLAIF-V-Dataset", split=SPLIT, token=HF_TOKEN)
cols = ["image", "question", "chosen", "rejected"]
ds = ds.remove_columns([c for c in ds.column_names if c not in cols])

def _pick_one_image(img_field):
    """Return a single PIL.Image from the dataset's image field."""
    x = img_field
    if isinstance(x, list):
        if not x:  # empty list, skip later
            return None
        x = x[0]
    if isinstance(x, Image.Image):
        return x.convert("RGB")
    # HF 'Image' feature sometimes gives np.ndarray
    arr = np.array(x)
    if arr.ndim == 3:
        return Image.fromarray(arr).convert("RGB")
    return None

def preprocess_item(ex):
    img = _pick_one_image(ex["image"])
    if img is None:
        return {
            "pixel_values": None,
            "question": ex["question"],
            "chosen": ex["chosen"],
            "rejected": ex["rejected"],
        }
    arr = np.array(img, dtype=np.uint8)[None, ...]              # [1,H,W,3]
    px  = siglip_pp.preprocess(jnp.asarray(arr), IMAGE_SIZE)     # [1,S,S,3] float32
    return {
        "pixel_values": np.asarray(px[0]),                       # [S,S,3]
        "question": ex["question"],
        "chosen": ex["chosen"],
        "rejected": ex["rejected"],
    }

ds = ds.with_transform(preprocess_item)

def numpy_batches(dataset, batch_size=4, shuffle=True, seed=0):
    idx = np.arange(len(dataset))
    if shuffle:
        rng = np.random.default_rng(seed)
        rng.shuffle(idx)
    buf = []
    for i in idx:
        ex = dataset[int(i)]
        # skip examples where we couldn't pick an image
        if ex["pixel_values"] is None:
            continue
        buf.append(ex)
        if len(buf) == batch_size:
            # Tokenize text fields individually and convert to JAX arrays
            questions_tokens = [gemma_tokenizer.tok(q, truncation=True, max_length=MAX_PROMPT_LEN)["input_ids"] for q in [b["question"] for b in buf]]
            chosen_tokens = [gemma_tokenizer.tok(c, truncation=True, max_length=MAX_PROMPT_LEN)["input_ids"] for c in [b["chosen"] for b in buf]]
            rejected_tokens = [gemma_tokenizer.tok(r, truncation=True, max_length=MAX_PROMPT_LEN)["input_ids"] for r in [b["rejected"] for b in buf]]

            # Pad and stack token IDs
            max_q_len = max(len(ids) for ids in questions_tokens)
            max_c_len = max(len(ids) for ids in chosen_tokens)
            max_r_len = max(len(ids) for ids in rejected_tokens)

            padded_questions = jnp.stack([jnp.pad(jnp.asarray(ids), (0, max_q_len - len(ids)), constant_values=gemma_tokenizer.pad_id()) for ids in questions_tokens], axis=0)
            padded_chosen = jnp.stack([jnp.pad(jnp.asarray(ids), (0, max_c_len - len(ids)), constant_values=gemma_tokenizer.pad_id()) for ids in chosen_tokens], axis=0)
            padded_rejected = jnp.stack([jnp.pad(jnp.asarray(ids), (0, max_r_len - len(ids)), constant_values=gemma_tokenizer.pad_id()) for ids in rejected_tokens], axis=0)


            yield {
                "pixel_values": jnp.asarray(np.stack([b["pixel_values"] for b in buf], axis=0)),
                "question":     padded_questions,
                "chosen":       padded_chosen,
                "rejected":     padded_rejected,
            }
            buf = []
    if buf:
        # Tokenize text fields individually and convert to JAX arrays for the last partial batch
        questions_tokens = [gemma_tokenizer.tok(q, truncation=True, max_length=MAX_PROMPT_LEN)["input_ids"] for q in [b["question"] for b in buf]]
        chosen_tokens = [gemma_tokenizer.tok(c, truncation=True, max_length=MAX_PROMPT_LEN)["input_ids"] for c in [b["chosen"] for b in buf]]
        rejected_tokens = [gemma_tokenizer.tok(r, truncation=True, max_length=MAX_PROMPT_LEN)["input_ids"] for r in [b["rejected"] for b in buf]]

        # Pad and stack token IDs
        max_q_len = max(len(ids) for ids in questions_tokens)
        max_c_len = max(len(ids) for ids in chosen_tokens)
        max_r_len = max(len(ids) for ids in rejected_tokens)

        padded_questions = jnp.stack([jnp.pad(jnp.asarray(ids), (0, max_q_len - len(ids)), constant_values=gemma_tokenizer.pad_id()) for ids in questions_tokens], axis=0)
        padded_chosen = jnp.stack([jnp.pad(jnp.asarray(ids), (0, max_c_len - len(ids)), constant_values=gemma_tokenizer.pad_id()) for ids in chosen_tokens], axis=0)
        padded_rejected = jnp.stack([jnp.pad(jnp.asarray(ids), (0, max_r_len - len(ids)), constant_values=gemma_tokenizer.pad_id()) for ids in rejected_tokens], axis=0)

        yield {
            "pixel_values": jnp.asarray(np.stack([b["pixel_values"] for b in buf], axis=0)),
            "question":     padded_questions,
            "chosen":       padded_chosen,
            "rejected":     padded_rejected,
        }


# Smoke one batch
b0 = next(numpy_batches(ds, batch_size=2))
print("Batch pixels:", b0["pixel_values"].shape, "| B:", len(b0["question"]))
print("Batch question tokens:", b0["question"].shape)
print("Batch chosen tokens:", b0["chosen"].shape)
print("Batch rejected tokens:", b0["rejected"].shape)

Batch pixels: (2, 224, 3) | B: 2
Batch question tokens: (2, 12)
Batch chosen tokens: (2, 128)
Batch rejected tokens: (2, 102)


In [15]:
import jax, jax.numpy as jnp

def _make_pos_and_mask(tokens: jnp.ndarray, pad_id: int):
    mask = (tokens != pad_id).astype(jnp.int32)
    positions = (jnp.cumsum(mask, axis=1) - 1) * mask
    return positions.astype(jnp.int32), mask

def _vlm_forward_and_cache(policy_mod, *, tokens: jnp.ndarray, pixel_values: jnp.ndarray, pad_id: int):
    """Returns (logits, cache_out). Avoids side-effects by threading cache through."""
    tokens = tokens.astype(jnp.int32)
    pixel_values = pixel_values.astype(jnp.float32)
    positions, attn_mask = _make_pos_and_mask(tokens, pad_id)
    cache_in = {}  # local cache owned by this call

    logits, cache_out = policy_mod(
        last_tokens=tokens,
        positions=positions,
        cache=cache_in,              # important: NOT None
        attention_mask=attn_mask,
        pixel_values=pixel_values,
        output_hidden_states=False,
    )
    return logits, cache_out

def _seq_logprob_batch_and_cache(policy_mod, px, prompt_ids, answer_ids, pad_id):
    # teacher forcing: [prompt, answer[:-1]]
    x_ids = jnp.concatenate([prompt_ids, answer_ids[:, :-1]], axis=1)
    logits, cache = _vlm_forward_and_cache(policy_mod, tokens=x_ids, pixel_values=px, pad_id=pad_id)
    La = answer_ids.shape[1]
    ans_logits = logits[:, -La:, :]
    logp = jax.nn.log_softmax(ans_logits, axis=-1)
    tok_logp = jnp.take_along_axis(logp, answer_ids[..., None], axis=-1)[..., 0]
    ans_mask = (answer_ids != pad_id).astype(tok_logp.dtype)
    return (tok_logp * ans_mask).sum(axis=1), cache  # [B], dict
def dpo_loss_batch_and_aux(policy, ref, batch, pad_id: int, beta: float = 0.1):
    """
    ARRAY batch only:
      pixel_values [B,S,S,3] float32
      prompt_ids / chosen_ids / rejected_ids  (or question/chosen/rejected) int32
    Returns: (loss, aux) where aux holds caches so no tracer 'escapes'.
    """
    px = jnp.asarray(batch["pixel_values"], dtype=jnp.float32)
    q  = jnp.asarray(batch.get("prompt_ids",   batch.get("question")),   dtype=jnp.int32)
    ch = jnp.asarray(batch.get("chosen_ids",   batch.get("chosen")),     dtype=jnp.int32)
    rj = jnp.asarray(batch.get("rejected_ids", batch.get("rejected")),   dtype=jnp.int32)

    lp_ch, cache_p_ch = _seq_logprob_batch_and_cache(policy, px, q, ch, pad_id)
    lp_rj, cache_p_rj = _seq_logprob_batch_and_cache(policy, px, q, rj, pad_id)
    lq_ch, cache_q_ch = _seq_logprob_batch_and_cache(ref,    px, q, ch, pad_id)
    lq_rj, cache_q_rj = _seq_logprob_batch_and_cache(ref,    px, q, rj, pad_id)

    advantages = (lp_ch - lp_rj) - (lq_ch - lq_rj)
    loss = -jax.nn.log_sigmoid(beta * advantages).mean()

    aux = {
        "policy_cache_chosen":  cache_p_ch,
        "policy_cache_reject":  cache_p_rj,
        "ref_cache_chosen":     cache_q_ch,
        "ref_cache_reject":     cache_q_rj,
    }
    return loss, aux

In [14]:
policy = lora_gemma   # trainable LoRA-wrapped model
ref    = nnx.clone(gemma3)

In [16]:
from flax import nnx
import jax, jax.numpy as jnp
import optax

# Immutable templates (never mutate)
policy_template = nnx.clone(policy)
ref_template    = nnx.clone(ref)

# Live states
policy_state = nnx.state(policy)
ref_state    = nnx.state(ref)

LR     = 1e-5
PAD_ID = gemma_tokenizer.pad_id()

def loss_with_state(p_state, batch_arrays):
    pol = nnx.clone(policy_template); nnx.update(pol, p_state)
    refm = nnx.clone(ref_template);   nnx.update(refm, ref_state)
    loss, aux = dpo_loss_batch_and_aux(pol, refm, batch_arrays, pad_id=PAD_ID, beta=0.1)
    new_p_state = nnx.state(pol)  # return updated state so nothing “escapes”
    return loss, (new_p_state, aux)

def train_step(policy_state, batch_arrays):
    (loss, (policy_state_out, aux)), grads = jax.value_and_grad(
        loss_with_state, has_aux=True
    )(policy_state, batch_arrays)
    # manual SGD: tiniest memory footprint
    updates = jax.tree_util.tree_map(lambda g: -LR * g, grads)
    policy_state2 = optax.apply_updates(policy_state_out, updates)
    return loss, policy_state2, aux

# tiny smoke
policy_template = nnx.clone(gemma3)
policy_state    = nnx.state(gemma3)
ref_template    = nnx.clone(gemma3)
ref_state       = nnx.state(ref_template)

# same dpo_loss_* you used before, same batch prep
small_batch = next(numpy_batches(ds, batch_size=1))
loss, policy_state, aux = train_step(policy_state, small_batch)  # your minimal stateless version
print("baseline loss:", float(loss))

baseline loss: 0.69140625


# Tiny but real training loop

In [19]:
import optax

LR = 3e-6
BETA = 0.1
MAX_GRAD_NORM = 0.1

tx = optax.chain(
    optax.clip_by_global_norm(MAX_GRAD_NORM),
    optax.adamw(LR, b1=0.9, b2=0.99, weight_decay=0.1),
)

opt_state = tx.init(policy_state)
PAD_ID = gemma_tokenizer.pad_id()

def loss_with_state(p_state, batch_arrays):
    pol  = nnx.clone(policy_template); nnx.update(pol, p_state)
    refm = nnx.clone(ref_template);   nnx.update(refm, ref_state)
    loss, aux = dpo_loss_batch_and_aux(pol, refm, batch_arrays, pad_id=PAD_ID, beta=BETA)
    new_p_state = nnx.state(pol)
    return loss, (new_p_state, aux)

import jax

def train_step(policy_state, opt_state, batch_arrays):
    (loss, (policy_state_out, aux)), grads = jax.value_and_grad(
        loss_with_state, has_aux=True
    )(policy_state, batch_arrays)

    updates, opt_state2 = tx.update(grads, opt_state, params=policy_state_out)
    policy_state2 = optax.apply_updates(policy_state_out, updates)
    return loss, policy_state2, opt_state2, aux

In [20]:
# tiny run
steps = 10
losses = []
itr = numpy_batches(ds, batch_size=1, shuffle=True, seed=42)

for step in range(1, steps+1):
    batch = next(itr)
    loss, policy_state, opt_state, aux = train_step(policy_state, opt_state, batch)
    losses.append(float(loss))
    print(f"step {step}/{steps}  loss={losses[-1]:.4f}")

step 1/20  loss=0.7969
step 5/20  loss=0.4355
step 10/20  loss=0.3711
step 15/20  loss=2.1250
step 20/20  loss=1.1719
