In [1]:
import sys, os
sys.path.insert(0, os.path.abspath(".."))

import json

# point this at your actual data file
data_path = "/gpfs/data/oermannlab/public_data/llava-med-data/llava_med_instruct_60k_inline_mention_filtered.json"

with open(data_path) as f:
    data = json.load(f)

print(f"total samples: {len(data)}")

# scan all conversations
no_image = 0
image_not_first_turn = 0
multi_image = 0
max_turns = 0

for entry in data:
    convos = entry["conversations"]
    max_turns = max(max_turns, len(convos))
    
    # find all image blocks
    image_positions = []  # (turn_index, block_index)
    for t_idx, turn in enumerate(convos):
        content = turn.get("content", [])
        if isinstance(content, str):
            # old format?
            if "<image>" in content:
                image_positions.append((t_idx, -1))
            continue
        for b_idx, block in enumerate(content):
            if isinstance(block, dict) and block.get("type") == "image":
                image_positions.append((t_idx, b_idx))
    
    if len(image_positions) == 0:
        no_image += 1
    elif image_positions[0][0] != 0:
        image_not_first_turn += 1
    if len(image_positions) > 1:
        multi_image += 1

print(f"max turns in a conversation: {max_turns}")
print(f"samples with NO image:       {no_image}")
print(f"image NOT in first turn:     {image_not_first_turn}")
print(f"samples with multiple images: {multi_image}")

# show a couple of edge cases if they exist
if no_image > 0:
    for entry in data:
        has_img = any(
            isinstance(b, dict) and b.get("type") == "image"
            for t in entry["conversations"]
            for b in (t.get("content", []) if isinstance(t.get("content"), list) else [])
        )
        if not has_img:
            print(f"\nExample no-image sample:")
            print(json.dumps(entry["conversations"][:2], indent=2))
            break

total samples: 56658
max turns in a conversation: 10
samples with NO image:       0
image NOT in first turn:     0
samples with multiple images: 0


In [2]:
from transformers import AutoProcessor, LlavaNextProcessor
from PIL import Image
from io import BytesIO
import requests

# grab a test image
url = "https://upload.wikimedia.org/wikipedia/commons/c/c8/Chest_Xray_PA_3-8-2010.png"
headers = {"User-Agent": "Mozilla/5.0"}
r = requests.get(url, headers=headers, timeout=30)
r.raise_for_status()

# optional but very helpful sanity check
ct = r.headers.get("Content-Type", "")
if "image" not in ct:
    raise ValueError(f"Expected image, got Content-Type={ct}. First bytes: {r.content[:120]!r}")

image = Image.open(BytesIO(r.content)).convert("RGB")

llava_proc = LlavaNextProcessor.from_pretrained("NYU-OLAB/LLaVA-Next-Med-OLAB")
gemma_proc = AutoProcessor.from_pretrained("google/gemma-3-4b-it")

l_tok = llava_proc.tokenizer
g_tok = gemma_proc.tokenizer

print("=" * 70)
print("TOKENIZER COMPARISON: LLaVA-Next-34B vs Gemma 3")
print("=" * 70)

attrs = [
    "vocab_size", "model_max_length", "padding_side", "truncation_side",
    "pad_token", "pad_token_id", "eos_token", "eos_token_id",
    "bos_token", "bos_token_id", "unk_token", "unk_token_id",
]
for attr in attrs:
    lv = getattr(l_tok, attr, "N/A")
    gv = getattr(g_tok, attr, "N/A")
    flag = "" if lv == gv else "  <-- DIFFERS"
    print(f"  {attr:25s}  llava={str(lv):20s}  gemma={str(gv):20s}{flag}")

# special tokens
print(f"\n  all_special_tokens:")
l_special = set(l_tok.all_special_tokens)
g_special = set(g_tok.all_special_tokens)
shared = l_special & g_special
only_llava = l_special - g_special
only_gemma = g_special - l_special
print(f"    shared:         {shared}")
if only_llava:
    print(f"    only in llava:  {only_llava}")
if only_gemma:
    print(f"    only in gemma:  {only_gemma}")

# the tokens that matter for IFT masking
print(f"\n  IFT-relevant tokens (old vs new):")
# the 34B uses ChatML with <|im_start|> / <|im_end|>
llava_ift_tokens = ["<|im_start|>", "<|im_end|>"]
print(f"    LLaVA-Next-34B:")
for tok in llava_ift_tokens:
    tid = l_tok.convert_tokens_to_ids(tok)
    print(f"      {tok:20s} -> {tid}")
# also check the role words
for role_word in ["assistant", "user"]:
    tid = l_tok.encode(role_word, add_special_tokens=False)
    print(f"      {role_word:20s} -> {tid}")

print(f"    Gemma 3:")
gemma_ift_tokens = ["<start_of_turn>", "<end_of_turn>"]
for tok in gemma_ift_tokens:
    tid = g_tok.convert_tokens_to_ids(tok)
    print(f"      {tok:20s} -> {tid}")
for role_word in ["model", "user"]:
    tid = g_tok.encode(role_word, add_special_tokens=False)
    print(f"      {role_word:20s} -> {tid}")

# --- Processor-level ---
print("\n" + "=" * 70)
print("PROCESSOR COMPARISON")
print("=" * 70)

for attr in ["image_token", "image_token_id"]:
    lv = getattr(llava_proc, attr, "N/A")
    gv = getattr(gemma_proc, attr, "N/A")
    flag = "" if lv == gv else "  <-- DIFFERS"
    print(f"  {attr:25s}  llava={str(lv):20s}  gemma={str(gv):20s}{flag}")

# --- Image processor ---
print("\n" + "=" * 70)
print("IMAGE PROCESSOR COMPARISON")
print("=" * 70)

l_ip = llava_proc.image_processor
g_ip = gemma_proc.image_processor

print(f"  {'type':35s}  llava={type(l_ip).__name__:30s}  gemma={type(g_ip).__name__}")
ip_attrs = [
    "do_resize", "do_rescale", "do_normalize", "do_convert_rgb",
    "image_mean", "image_std", "size", "resample",
    "rescale_factor",
]
for attr in ip_attrs:
    lv = getattr(l_ip, attr, "N/A")
    gv = getattr(g_ip, attr, "N/A")
    flag = "" if lv == gv else "  <-- DIFFERS"
    print(f"  {attr:35s}  llava={str(lv):30s}  gemma={str(gv):30s}{flag}")

# gemma-specific attrs
for attr in ["do_pan_and_scan", "image_seq_length",
             "pan_and_scan_max_num_crops", "pan_and_scan_min_crop_size"]:
    gv = getattr(g_ip, attr, "N/A")
    print(f"  {attr:35s}  llava={'N/A':30s}  gemma={str(gv):30s}  (gemma only)")

# --- Process the actual image through both ---
print("\n" + "=" * 70)
print(f"PIXEL VALUES (actual image: {image.size})")
print("=" * 70)

l_pix = llava_proc.image_processor(images=image, return_tensors="pt")["pixel_values"]
g_pix = gemma_proc.image_processor(images=image, return_tensors="pt")["pixel_values"]

print(f"  llava pixel_values: shape={l_pix.shape}  dtype={l_pix.dtype}  "
      f"min={l_pix.min():.3f}  max={l_pix.max():.3f}  mean={l_pix.mean():.3f}")
print(f"  gemma pixel_values: shape={g_pix.shape}  dtype={g_pix.dtype}  "
      f"min={g_pix.min():.3f}  max={g_pix.max():.3f}  mean={g_pix.mean():.3f}")

# --- Chat template output ---
print("\n" + "=" * 70)
print("CHAT TEMPLATE OUTPUT (same conversation, same image)")
print("=" * 70)

convo_hf = [
    {"role": "user", "content": [
        {"type": "image"},
        {"type": "text", "text": "What is this scan?"},
    ]},
    {"role": "assistant", "content": [
        {"type": "text", "text": "This is a chest X-ray showing no acute findings."},
    ]},
]

g_text = gemma_proc.apply_chat_template(convo_hf, tokenize=False, add_generation_prompt=False)
print(f"  Gemma 3:\n{g_text}\n")

# LLaVA-Next 34B uses ChatML — try HF format first, fall back to old format
try:
    l_text = llava_proc.apply_chat_template(convo_hf, tokenize=False, add_generation_prompt=False)
    print(f"  LLaVA-Next-34B (HF format):\n{l_text}\n")
except Exception as e:
    print(f"  LLaVA-Next-34B HF format failed: {e}")
    llava_convo = [
        {"role": "user", "content": "<image>\nWhat is this scan?"},
        {"role": "assistant", "content": "This is a chest X-ray showing no acute findings."},
    ]
    l_text = llava_proc.apply_chat_template(llava_convo)
    print(f"  LLaVA-Next-34B (old format):\n{l_text}\n")

# --- Full processor call with image ---
print("=" * 70)
print("FULL PROCESSOR CALL (text + image)")
print("=" * 70)

g_batch = gemma_proc(text=[g_text], images=[image], padding=True, return_tensors="pt")
print(f"  Gemma 3 batch keys:  {sorted(g_batch.keys())}")
print(f"    input_ids:    {g_batch['input_ids'].shape}")
print(f"    pixel_values: {g_batch['pixel_values'].shape}")

try:
    l_batch = llava_proc(text=[l_text], images=[image], padding=True, return_tensors="pt")
    print(f"  LLaVA batch keys:    {sorted(l_batch.keys())}")
    print(f"    input_ids:    {l_batch['input_ids'].shape}")
    print(f"    pixel_values: {l_batch['pixel_values'].shape}")
except Exception as e:
    print(f"  LLaVA full processor call failed: {e}")


print("\n" + "=" * 70)
print("UNTOKENIZED PROMPTS (decoded from input_ids)")
print("=" * 70)


print("\ndone")

  from .autonotebook import tqdm as notebook_tqdm
The image processor of type `LlavaNextImageProcessor` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. 
The image processor of type `Gemma3ImageProcessor` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. 


TOKENIZER COMPARISON: LLaVA-Next-34B vs Gemma 3
  vocab_size                 llava=64000                 gemma=262144                <-- DIFFERS
  model_max_length           llava=4096                  gemma=1000000000000000019884624838656  <-- DIFFERS
  padding_side               llava=right                 gemma=left                  <-- DIFFERS
  truncation_side            llava=right                 gemma=right               
  pad_token                  llava=<unk>                 gemma=<pad>                 <-- DIFFERS
  pad_token_id               llava=0                     gemma=0                   
  eos_token                  llava=<|im_end|>            gemma=<eos>                 <-- DIFFERS
  eos_token_id               llava=7                     gemma=1                     <-- DIFFERS
  bos_token                  llava=<|startoftext|>       gemma=<bos>                 <-- DIFFERS
  bos_token_id               llava=1                     gemma=2                     <-- DIFFE

In [3]:
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")

convo = [
    {"role": "user", "content": [
        {"type": "text", "text": "What type of scan is this?"},
        {"type": "image"},
    ]},
    {"role": "assistant", "content": [
        {"type": "text", "text": "This is a chest X-ray."},
    ]},
]

# what the old code did (no kwargs)
try:
    result_old = processor.apply_chat_template(convo)
    print(f"No kwargs -> type: {type(result_old).__name__}")
    if isinstance(result_old, str):
        print(f"  value: {repr(result_old[:200])}")
    elif hasattr(result_old, 'shape'):
        print(f"  shape: {result_old.shape}")
    elif isinstance(result_old, list):
        print(f"  len: {len(result_old)}, first few: {result_old[:10]}")
    else:
        print(f"  value: {result_old}")
except Exception as e:
    print(f"No kwargs -> EXCEPTION: {e}")

print()

# what the new code does
result_new = processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=False)
print(f"tokenize=False -> type: {type(result_new).__name__}")
print(f"  value: {repr(result_new[:200])}")

No kwargs -> type: str
  value: '<bos><start_of_turn>user\nWhat type of scan is this?<start_of_image><end_of_turn>\n<start_of_turn>model\nThis is a chest X-ray.<end_of_turn>\n'

tokenize=False -> type: str
  value: '<bos><start_of_turn>user\nWhat type of scan is this?<start_of_image><end_of_turn>\n<start_of_turn>model\nThis is a chest X-ray.<end_of_turn>\n'


In [4]:
import sys
import os

from llava_med_dataset import LLaVAMedDataset


ModuleNotFoundError: No module named 'llava_med_dataset'

In [1]:
%load_ext autoreload
%autoreload 2

In [24]:
import sys, os
sys.path.insert(0, os.path.abspath(".."))

import torch
from PIL import Image
from transformers import AutoProcessor

# fake test image
fake_image = Image.new("RGB", (224, 224), color=(128, 64, 32))

# single-turn conversation matching your actual data format
single_turn = [
    {"role": "user", "content": [
        {"type": "text", "text": "What is the purpose of the flow diagram?"},
        {"type": "image"},
    ]},
    {"role": "assistant", "content": [
        {"type": "text", "text": "The purpose of the flow diagram is to illustrate the lung cancer screening process."},
    ]},
]

# multi-turn conversation matching your actual data format
multi_turn = [
    {"role": "user", "content": [
        {"type": "text", "text": "What is the purpose of the flow diagram?"},
        {"type": "image"},
    ]},
    {"role": "assistant", "content": [
        {"type": "text", "text": "The purpose of the flow diagram is to illustrate the lung cancer screening process."},
    ]},
    {"role": "user", "content": [
        {"type": "text", "text": "What is the primary screening method?"},
    ]},
    {"role": "assistant", "content": [
        {"type": "text", "text": "The primary screening method involves chest X-ray examinations."},
    ]},
]

print("setup done")

setup done


In [25]:
processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")

# single turn
text_single = processor.apply_chat_template(
    single_turn, tokenize=False, add_generation_prompt=False
)
print("=== Single turn ===")
print(repr(text_single))
print()

# multi turn
text_multi = processor.apply_chat_template(
    multi_turn, tokenize=False, add_generation_prompt=False
)
print("=== Multi turn ===")
print(repr(text_multi))
print()

# structural checks
assert "<start_of_turn>user" in text_single
assert "<start_of_turn>model" in text_single
assert "<end_of_turn>" in text_single
assert "flow diagram" in text_single
assert "lung cancer screening" in text_single

# multi-turn should have 2 model turns
assert text_multi.count("<start_of_turn>model") == 2, \
    f"expected 2 model turns, got {text_multi.count('<start_of_turn>model')}"
assert text_multi.count("<start_of_turn>user") == 2

print("apply_chat_template: ALL PASSED")

=== Single turn ===
'<bos><start_of_turn>user\nWhat is the purpose of the flow diagram?<start_of_image><end_of_turn>\n<start_of_turn>model\nThe purpose of the flow diagram is to illustrate the lung cancer screening process.<end_of_turn>\n'

=== Multi turn ===
'<bos><start_of_turn>user\nWhat is the purpose of the flow diagram?<start_of_image><end_of_turn>\n<start_of_turn>model\nThe purpose of the flow diagram is to illustrate the lung cancer screening process.<end_of_turn>\n<start_of_turn>user\nWhat is the primary screening method?<end_of_turn>\n<start_of_turn>model\nThe primary screening method involves chest X-ray examinations.<end_of_turn>\n'

apply_chat_template: ALL PASSED


In [26]:
text = processor.apply_chat_template(
    single_turn, tokenize=False, add_generation_prompt=False
)

batch = processor(
    text=[text],
    images=[fake_image],
    padding=True,
    truncation=True,
    max_length=512,
    return_tensors="pt",
)

print(f"Batch keys: {list(batch.keys())}")
print(f"input_ids shape:     {batch['input_ids'].shape}")
print(f"attention_mask shape: {batch['attention_mask'].shape}")
print(f"pixel_values shape:  {batch['pixel_values'].shape}")

assert "input_ids" in batch
assert "attention_mask" in batch
assert "pixel_values" in batch

decoded = processor.tokenizer.decode(batch["input_ids"][0], skip_special_tokens=False)
print(f"\nDecoded (first 500 chars):\n{decoded[:500]}")

print("\ntwo-step processor: ALL PASSED")

Batch keys: ['input_ids', 'attention_mask', 'token_type_ids', 'pixel_values']
input_ids shape:     torch.Size([1, 296])
attention_mask shape: torch.Size([1, 296])
pixel_values shape:  torch.Size([1, 3, 896, 896])

Decoded (first 500 chars):
<bos><bos><start_of_turn>user
What is the purpose of the flow diagram?

<start_of_image><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_toke

two-step processor: ALL PASSED


In [27]:
tokenizer = processor.tokenizer
unk_id = tokenizer.convert_tokens_to_ids("<unk>")

tokens_to_check = {
    "<start_of_turn>": tokenizer.convert_tokens_to_ids("<start_of_turn>"),
    "<end_of_turn>":   tokenizer.convert_tokens_to_ids("<end_of_turn>"),
    "model":           tokenizer.convert_tokens_to_ids("model"),
    "user":            tokenizer.convert_tokens_to_ids("user"),
}

print("=== Token ID map ===")
for name, tid in tokens_to_check.items():
    status = "OK" if tid != unk_id else "MISSING (resolved to UNK!)"
    print(f"  {name:20s} -> {tid:8d}  {status}")
    assert tid != unk_id, f"{name} resolved to UNK — not in vocab!"

# also check image token
if hasattr(processor, "image_token_id"):
    print(f"  {'image_token_id':20s} -> {processor.image_token_id:8d}")
else:
    print("  WARNING: processor has no image_token_id attribute")

print("\ntoken IDs: ALL PASSED")

=== Token ID map ===
  <start_of_turn>      ->      105  OK
  <end_of_turn>        ->      106  OK
  model                ->     4368  OK
  user                 ->     2364  OK
  image_token_id       ->   262144

token IDs: ALL PASSED


In [28]:
# tokenize the multi-turn sample
text = processor.apply_chat_template(
    multi_turn, tokenize=False, add_generation_prompt=False
)
batch = processor(
    text=[text], images=[fake_image],
    padding=True, truncation=True, max_length=512, return_tensors="pt",
)

tokenizer = processor.tokenizer
start_of_turn_id = tokenizer.convert_tokens_to_ids("<start_of_turn>")
end_of_turn_id = tokenizer.convert_tokens_to_ids("<end_of_turn>")
model_token_id = tokenizer.convert_tokens_to_ids("model")

input_ids = batch["input_ids"]
labels = torch.full_like(input_ids, fill_value=-100)

in_model_turn = False
for j in range(input_ids.shape[1]):
    tok = input_ids[0, j].item()
    if tok == start_of_turn_id:
        in_model_turn = False
        if j + 1 < input_ids.shape[1]:
            if input_ids[0, j + 1].item() == model_token_id:
                in_model_turn = True
    if in_model_turn:
        labels[0, j] = input_ids[0, j]
    if tok == end_of_turn_id and in_model_turn:
        in_model_turn = False

labels[input_ids == tokenizer.pad_token_id] = -100
if hasattr(processor, "image_token_id"):
    labels[input_ids == processor.image_token_id] = -100

# print token-by-token mask
tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
print("=== Token-level label mask (multi-turn) ===")
for j, (tok, lab) in enumerate(zip(tokens, labels[0].tolist())):
    if tok not in ["<pad>"]:
        status = "TRAIN" if lab != -100 else "-----"
        print(f"  [{j:3d}] {tok:30s}  {status}")

# verify: assistant content should be trained, user content should not
trained_text = tokenizer.decode(
    [t for t, l in zip(input_ids[0].tolist(), labels[0].tolist()) if l != -100]
)
print(f"\nTrained text:\n{trained_text}")

assert "lung cancer screening" in trained_text, "first assistant response should be trained"
assert "chest X-ray" in trained_text, "second assistant response should be trained"

# user questions should NOT appear in trained text
assert "What is the purpose of" not in trained_text, "user question should be masked"
assert "What is the primary" not in trained_text, "second user question should be masked"

print("\nIFT masking: ALL PASSED")

=== Token-level label mask (multi-turn) ===
  [  0] <bos>                           -----
  [  1] <bos>                           -----
  [  2] <start_of_turn>                 -----
  [  3] user                            -----
  [  4] 
                               -----
  [  5] What                            -----
  [  6] ▁is                             -----
  [  7] ▁the                            -----
  [  8] ▁purpose                        -----
  [  9] ▁of                             -----
  [ 10] ▁the                            -----
  [ 11] ▁flow                           -----
  [ 12] ▁diagram                        -----
  [ 13] ?                               -----
  [ 14] 

                              -----
  [ 15] <start_of_image>                -----
  [ 16] <image_soft_token>              -----
  [ 17] <image_soft_token>              -----
  [ 18] <image_soft_token>              -----
  [ 19] <image_soft_token>              -----
  [ 20] <image_soft_token>          

In [29]:
from monosemanticity.datasets.base_multimodal_dataset import BaseMultimodalDataModule

# fake samples: (PIL_image, conversation_list) — what __getitem__ returns with data_key="conversations"
samples = [
    (Image.new("RGB", (256, 256), color=(100, 50, 50)), [
        {"role": "user", "content": [
            {"type": "text", "text": "What type of scan is this?"},
            {"type": "image"},
        ]},
        {"role": "assistant", "content": [
            {"type": "text", "text": "This is an axial CT scan of the abdomen."},
        ]},
    ]),
    (Image.new("RGB", (256, 256), color=(50, 100, 50)), [
        {"role": "user", "content": [
            {"type": "text", "text": "Describe the pathology."},
            {"type": "image"},
        ]},
        {"role": "assistant", "content": [
            {"type": "text", "text": "There is a large mass in the left frontal lobe consistent with a glioma."},
        ]},
    ]),
    (Image.new("RGB", (300, 200), color=(50, 50, 100)), [
        {"role": "user", "content": [
            {"type": "text", "text": "What do you see?"},
            {"type": "image"},
        ]},
        {"role": "assistant", "content": [
            {"type": "text", "text": "Normal chest radiograph with no acute findings."},
        ]},
    ]),
]

# minimal DataModule subclass just to get the collate methods
class TestModule(BaseMultimodalDataModule):
    def setup(self, stage=None):
        pass

# --- reg_train_collate_fn ---
dm = TestModule(root_dir="/tmp", processor=processor, batch_size=3, max_tokens=512, ift=False)
reg_batch = dm.reg_train_collate_fn(samples)

print("=== reg_train_collate_fn ===")
print(f"input_ids:     {reg_batch['input_ids'].shape}")
print(f"attention_mask: {reg_batch['attention_mask'].shape}")
print(f"pixel_values:  {reg_batch['pixel_values'].shape}")
print(f"labels:        {reg_batch['labels'].shape}")

assert reg_batch["input_ids"].shape[0] == 3
assert (reg_batch["labels"][reg_batch["input_ids"] == processor.tokenizer.pad_token_id] == -100).all(), \
    "pad tokens should be masked"
if hasattr(processor, "image_token_id"):
    assert (reg_batch["labels"][reg_batch["input_ids"] == processor.image_token_id] == -100).all(), \
        "image tokens should be masked"

# --- ift_train_collate_fn ---
dm_ift = TestModule(root_dir="/tmp", processor=processor, batch_size=3, max_tokens=512, ift=True)
ift_batch = dm_ift.ift_train_collate_fn(samples)

print("\n=== ift_train_collate_fn ===")
print(f"input_ids: {ift_batch['input_ids'].shape}")
print(f"labels:    {ift_batch['labels'].shape}")

reg_trained = (reg_batch["labels"] != -100).sum().item()
ift_trained = (ift_batch["labels"] != -100).sum().item()
print(f"\nreg trained tokens: {reg_trained}")
print(f"ift trained tokens: {ift_trained}")
assert ift_trained < reg_trained, "IFT should have fewer trained tokens (user turns masked)"

for i in range(3):
    n = (ift_batch["labels"][i] != -100).sum().item()
    assert n > 0, f"sample {i} has 0 trained tokens — masking is broken"
    print(f"sample {i}: {n} trained tokens")

print("\nbatch collate: ALL PASSED")

[[<PIL.Image.Image image mode=RGB size=256x256 at 0x155393A9A840>], [<PIL.Image.Image image mode=RGB size=256x256 at 0x155393A9A090>], [<PIL.Image.Image image mode=RGB size=300x200 at 0x155393A999A0>]]
['<bos><start_of_turn>user\nWhat type of scan is this?<start_of_image><end_of_turn>\n<start_of_turn>model\nThis is an axial CT scan of the abdomen.<end_of_turn>\n', '<bos><start_of_turn>user\nDescribe the pathology.<start_of_image><end_of_turn>\n<start_of_turn>model\nThere is a large mass in the left frontal lobe consistent with a glioma.<end_of_turn>\n', '<bos><start_of_turn>user\nWhat do you see?<start_of_image><end_of_turn>\n<start_of_turn>model\nNormal chest radiograph with no acute findings.<end_of_turn>\n']
=== reg_train_collate_fn ===
input_ids:     torch.Size([3, 291])
attention_mask: torch.Size([3, 291])
pixel_values:  torch.Size([3, 3, 896, 896])
labels:        torch.Size([3, 291])

=== ift_train_collate_fn ===
input_ids: torch.Size([3, 291])
labels:    torch.Size([3, 291])

re

In [30]:
samples

[(<PIL.Image.Image image mode=RGB size=256x256>,
  [{'role': 'user',
    'content': [{'type': 'text', 'text': 'What type of scan is this?'},
     {'type': 'image'}]},
   {'role': 'assistant',
    'content': [{'type': 'text',
      'text': 'This is an axial CT scan of the abdomen.'}]}]),
 (<PIL.Image.Image image mode=RGB size=256x256>,
  [{'role': 'user',
    'content': [{'type': 'text', 'text': 'Describe the pathology.'},
     {'type': 'image'}]},
   {'role': 'assistant',
    'content': [{'type': 'text',
      'text': 'There is a large mass in the left frontal lobe consistent with a glioma.'}]}]),
 (<PIL.Image.Image image mode=RGB size=300x200>,
  [{'role': 'user',
    'content': [{'type': 'text', 'text': 'What do you see?'},
     {'type': 'image'}]},
   {'role': 'assistant',
    'content': [{'type': 'text',
      'text': 'Normal chest radiograph with no acute findings.'}]}])]

In [31]:
try:
    med_processor = AutoProcessor.from_pretrained("google/medgemma-4b-it")

    text = med_processor.apply_chat_template(
        single_turn, tokenize=False, add_generation_prompt=False
    )
    batch = med_processor(
        text=[text], images=[fake_image], padding=True, return_tensors="pt"
    )

    # should have the same keys and work the same way
    print(f"MedGemma batch keys: {list(batch.keys())}")
    print(f"input_ids shape: {batch['input_ids'].shape}")

    # verify same token vocabulary for boundary tokens
    med_tok = med_processor.tokenizer
    assert med_tok.convert_tokens_to_ids("<start_of_turn>") == \
           processor.tokenizer.convert_tokens_to_ids("<start_of_turn>"), \
        "start_of_turn token ID mismatch between gemma and medgemma"

    print("MedGemma processor: PASSED (same token IDs as Gemma 3)")
except Exception as e:
    print(f"MedGemma not available: {e}")

MedGemma batch keys: ['input_ids', 'attention_mask', 'token_type_ids', 'pixel_values']
input_ids shape: torch.Size([1, 296])
MedGemma processor: PASSED (same token IDs as Gemma 3)


In [32]:
# %% Cell — Config (edit these)
import sys, os
sys.path.insert(0, os.path.abspath(".."))

DATA_ROOT = "/gpfs/data/oermannlab/public_data/llava-med-data"
DATASET_FILE = "llava_med_instruct_60k_inline_mention_filtered.json"
MODEL_ID = "google/gemma-3-4b-it"  # or "google/medgemma-4b-it"

In [33]:
# %% Cell — Load processor and dataset
import torch
from transformers import AutoProcessor
from datasets.llava_med_dataset import LLaVAMedDataset, LLaVAMedDataModule

processor = AutoProcessor.from_pretrained(MODEL_ID)
processor.tokenizer.padding_side = "right"

# load just the raw dataset (no splits, no collation)
ds = LLaVAMedDataset(
    root_dir=DATA_ROOT,
    dataset_file=DATASET_FILE,
    data_key="conversations",
)
print(f"dataset loaded: {len(ds)} samples")

dataset loaded: 56658 samples


In [34]:
# %% Cell — Inspect a single raw sample
image, convo = ds[0]
print(f"image: {image.size}, mode={image.mode}")
print(f"conversation turns: {len(convo)}")
for i, turn in enumerate(convo):
    role = turn["role"]
    content = turn["content"]
    if isinstance(content, list):
        types = [c["type"] for c in content]
        text = next((c["text"] for c in content if c["type"] == "text"), "")
        print(f"  turn {i}: role={role}  blocks={types}  text={text[:80]}...")
    else:
        print(f"  turn {i}: role={role}  content={str(content)[:80]}...")

image: (800, 579), mode=L
conversation turns: 6
  turn 0: role=user  blocks=['text', 'image']  text=What is the main finding in the image?...
  turn 1: role=assistant  blocks=['text']  text=The main finding in the image is bilateral lower lobe consolidation, which is pr...
  turn 2: role=user  blocks=['text']  text=What does consolidation mean?...
  turn 3: role=assistant  blocks=['text']  text=Consolidation refers to a region of the lung where the air spaces (alveoli) are ...
  turn 4: role=user  blocks=['text']  text=Is there any other finding mentioned in the context?...
  turn 5: role=assistant  blocks=['text']  text=Which is bronchiectasis. bronchiectasis is a chronic lung condition characterize...


In [35]:
# %% Cell — Test apply_chat_template on real sample
image, convo = ds[0]
formatted = processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=False)
print("=== Formatted ===")
print(formatted)

=== Formatted ===
<bos><start_of_turn>user
What is the main finding in the image?<start_of_image><end_of_turn>
<start_of_turn>model
The main finding in the image is bilateral lower lobe consolidation, which is present in both lungs.<end_of_turn>
<start_of_turn>user
What does consolidation mean?<end_of_turn>
<start_of_turn>model
Consolidation refers to a region of the lung where the air spaces (alveoli) are filled with fluid, pus, blood, or cells, making the lung tissue appear more solid and dense on imaging studies like a computed tomography (CT) scan or X-ray. Consolidation can be caused by various conditions, such as pneumonia, pulmonary edema, or lung injury. It is important to consider the patient's clinical history and symptoms, as well as consult a healthcare professional for a thorough evaluation and proper diagnosis of the underlying cause of the consolidation.<end_of_turn>
<start_of_turn>user
Is there any other finding mentioned in the context?<end_of_turn>
<start_of_turn>mode

In [36]:
# %% Cell — Test full processor call (single sample)
image, convo = ds[0]
text = processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=False)
batch = processor(
    text=[text], images=[image],
    padding=True, truncation=True, max_length=4096, return_tensors="pt",
)

print("=== Single sample batch ===")
for k, v in batch.items():
    if hasattr(v, "shape"):
        print(f"  {k:20s} {str(v.shape):30s} dtype={v.dtype}")

# decode and verify roundtrip
decoded = processor.tokenizer.decode(batch["input_ids"][0], skip_special_tokens=False)
print(f"\nDecoded (first 500 chars):\n{decoded[:500]}")

=== Single sample batch ===
  input_ids            torch.Size([1, 502])           dtype=torch.int64
  attention_mask       torch.Size([1, 502])           dtype=torch.int64
  token_type_ids       torch.Size([1, 502])           dtype=torch.int64
  pixel_values         torch.Size([1, 3, 896, 896])   dtype=torch.float32

Decoded (first 500 chars):
<bos><bos><start_of_turn>user
What is the main finding in the image?

<start_of_image><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token>


In [37]:
# %% Cell — Test batch of 1 through actual DataModule collate
from datasets.base_multimodal_dataset import BaseMultimodalDataModule

class TestModule(BaseMultimodalDataModule):
    def setup(self, stage=None):
        pass

dm = TestModule(
    root_dir=DATA_ROOT,
    processor=processor,
    batch_size=1,
    max_tokens=4096,
    ift=True,
)

# grab one sample, wrap in list (that's what collate receives)
sample = [ds[0]]
batch = dm.ift_train_collate_fn(sample)

print("=== IFT collate, batch_size=1 ===")
for k, v in batch.items():
    if hasattr(v, "shape"):
        print(f"  {k:20s} {str(v.shape):30s} dtype={v.dtype}")

n_trained = (batch["labels"] != -100).sum().item()
n_total = (batch["input_ids"] != processor.tokenizer.pad_token_id).sum().item()
print(f"\n  non-pad tokens: {n_total}")
print(f"  trained tokens: {n_trained}")
print(f"  masked tokens:  {n_total - n_trained}")
print(f"  trained ratio:  {n_trained / n_total:.1%}")

# show what's being trained
trained_text = processor.tokenizer.decode(
    [t for t, l in zip(batch["input_ids"][0].tolist(), batch["labels"][0].tolist()) if l != -100]
)
print(f"\n=== Trained text ===\n{trained_text[:500]}")

=== IFT collate, batch_size=1 ===
  input_ids            torch.Size([1, 502])           dtype=torch.int64
  attention_mask       torch.Size([1, 502])           dtype=torch.int64
  token_type_ids       torch.Size([1, 502])           dtype=torch.int64
  pixel_values         torch.Size([1, 3, 896, 896])   dtype=torch.float32
  labels               torch.Size([1, 502])           dtype=torch.int64

  non-pad tokens: 502
  trained tokens: 198
  masked tokens:  304
  trained ratio:  39.4%

=== Trained text ===
<start_of_turn>model
The main finding in the image is bilateral lower lobe consolidation, which is present in both lungs.<end_of_turn><start_of_turn>model
Consolidation refers to a region of the lung where the air spaces (alveoli) are filled with fluid, pus, blood, or cells, making the lung tissue appear more solid and dense on imaging studies like a computed tomography (CT) scan or X-ray. Consolidation can be caused by various conditions, such as pneumonia, pulmonary edema, or lung inj

In [38]:
# %% Cell — Test batch of 4 through actual DataModule collate
samples = [ds[i] for i in range(4)]

batch = dm.ift_train_collate_fn(samples)

print("=== IFT collate, batch_size=4 ===")
for k, v in batch.items():
    if hasattr(v, "shape"):
        print(f"  {k:20s} {str(v.shape):30s} dtype={v.dtype}")

for i in range(4):
    n_trained = (batch["labels"][i] != -100).sum().item()
    n_total = (batch["input_ids"][i] != processor.tokenizer.pad_token_id).sum().item()
    seq_len = batch["input_ids"].shape[1]
    n_pad = (batch["input_ids"][i] == processor.tokenizer.pad_token_id).sum().item()
    print(f"  sample {i}: seq_len={seq_len}  pad={n_pad}  non-pad={n_total}  trained={n_trained}  ratio={n_trained/max(n_total,1):.1%}")

print(f"\n  total batch trained tokens: {(batch['labels'] != -100).sum().item()}")

=== IFT collate, batch_size=4 ===
  input_ids            torch.Size([4, 557])           dtype=torch.int64
  attention_mask       torch.Size([4, 557])           dtype=torch.int64
  token_type_ids       torch.Size([4, 557])           dtype=torch.int64
  pixel_values         torch.Size([4, 3, 896, 896])   dtype=torch.float32
  labels               torch.Size([4, 557])           dtype=torch.int64
  sample 0: seq_len=557  pad=55  non-pad=502  trained=198  ratio=39.4%
  sample 1: seq_len=557  pad=3  non-pad=554  trained=244  ratio=44.0%
  sample 2: seq_len=557  pad=0  non-pad=557  trained=243  ratio=43.6%
  sample 3: seq_len=557  pad=113  non-pad=444  trained=139  ratio=31.3%

  total batch trained tokens: 824


In [39]:
# %% Cell — Test through actual DataLoader (full integration)
dm_full = LLaVAMedDataModule(
    DATA_ROOT,
    processor,
    dataset_file=DATASET_FILE,
    data_key="conversations",
    batch_size=4,
    max_tokens=4096,
    ift=True,
)
dm_full.setup()

print(f"train: {len(dm_full.train_dataset)}  val: {len(dm_full.val_dataset)}  test: {len(dm_full.test_dataset)}")

train_dl = dm_full.train_dataloader()

# grab first batch from the actual dataloader
batch = next(iter(train_dl))

print(f"\n=== First batch from DataLoader ===")
for k, v in batch.items():
    if hasattr(v, "shape"):
        print(f"  {k:20s} {str(v.shape):30s} dtype={v.dtype}")

for i in range(4):
    n_trained = (batch["labels"][i] != -100).sum().item()
    n_total = (batch["input_ids"][i] != processor.tokenizer.pad_token_id).sum().item()
    print(f"  sample {i}: non-pad={n_total}  trained={n_trained}")

print("\nFULL INTEGRATION: PASSED")

train: 45328  val: 5664  test: 5664

=== First batch from DataLoader ===
  input_ids            torch.Size([4, 563])           dtype=torch.int64
  attention_mask       torch.Size([4, 563])           dtype=torch.int64
  token_type_ids       torch.Size([4, 563])           dtype=torch.int64
  pixel_values         torch.Size([4, 3, 896, 896])   dtype=torch.float32
  labels               torch.Size([4, 563])           dtype=torch.int64
  sample 0: non-pad=455  trained=148
  sample 1: non-pad=462  trained=154
  sample 2: non-pad=509  trained=203
  sample 3: non-pad=563  trained=259

FULL INTEGRATION: PASSED


In [43]:
# check if this removes the double <bos>
batch = processor(
    text=[text],
    images=[image],
    padding=True,
    truncation=True,
    max_length=4096,
    return_tensors="pt",
    add_special_tokens=False,  # <-- try this
)
decoded = processor.tokenizer.decode(batch["input_ids"][0][:], skip_special_tokens=False)
print(decoded)

<bos><start_of_turn>user
What is the main finding in the image?

<start_of_image><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><

In [42]:
# check if this removes the double <bos>
batch = processor(
    text=[text],
    images=[image],
    padding=True,
    truncation=True,
    max_length=4096,
    return_tensors="pt",
    add_special_tokens=True,  # <-- try this
)
decoded = processor.tokenizer.decode(batch["input_ids"][0][:], skip_special_tokens=False)
print(decoded)

<bos><bos><start_of_turn>user
What is the main finding in the image?

<start_of_image><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_token><image_soft_to

In [44]:
# %% Cell — Multi-turn batch test (4 samples, each multi-turn, 1 image)
import json

# grab 4 multi-turn samples from the actual dataset
multi_turn_samples = []
for i in range(len(ds)):
    image, convo = ds[i]
    if len(convo) >= 6:  # at least 3 Q&A pairs
        multi_turn_samples.append((image, convo))
    if len(multi_turn_samples) == 4:
        break

print(f"found {len(multi_turn_samples)} multi-turn samples")
for i, (img, convo) in enumerate(multi_turn_samples):
    n_turns = len(convo)
    n_user = sum(1 for t in convo if t["role"] == "user")
    n_asst = sum(1 for t in convo if t["role"] == "assistant")
    has_image = any(
        b.get("type") == "image"
        for t in convo
        for b in (t["content"] if isinstance(t["content"], list) else [])
    )
    print(f"  sample {i}: {n_turns} turns ({n_user} user, {n_asst} assistant), image={has_image}, img_size={img.size}")

# run through ift collate
batch = dm.ift_train_collate_fn(multi_turn_samples)

print(f"\n=== IFT collate, batch of {len(multi_turn_samples)} multi-turn ===")
for k, v in batch.items():
    if hasattr(v, "shape"):
        print(f"  {k:20s} {str(v.shape):30s} dtype={v.dtype}")

# per-sample breakdown
for i in range(len(multi_turn_samples)):
    input_ids = batch["input_ids"][i]
    labels = batch["labels"][i]
    
    n_total = (input_ids != processor.tokenizer.pad_token_id).sum().item()
    n_trained = (labels != -100).sum().item()
    n_pad = (input_ids == processor.tokenizer.pad_token_id).sum().item()
    
    # count how many model turns are being trained
    trained_text = processor.tokenizer.decode(
        [t for t, l in zip(input_ids.tolist(), labels.tolist()) if l != -100]
    )
    n_model_turns_trained = trained_text.count("<start_of_turn>model")
    n_model_turns_expected = sum(1 for t in multi_turn_samples[i][1] if t["role"] == "assistant")
    
    print(f"\n  sample {i}: non-pad={n_total}  pad={n_pad}  trained={n_trained}  ratio={n_trained/max(n_total,1):.1%}")
    print(f"    model turns trained: {n_model_turns_trained} / {n_model_turns_expected} expected")
    
    assert n_model_turns_trained == n_model_turns_expected, \
        f"sample {i}: trained {n_model_turns_trained} model turns but expected {n_model_turns_expected}"
    assert n_trained > 0, f"sample {i} has 0 trained tokens"

    # verify no user text leaked
    for t in multi_turn_samples[i][1]:
        if t["role"] == "user":
            for block in t["content"]:
                if block.get("type") == "text":
                    # grab first 30 chars of user question as a check
                    snippet = block["text"][:30]
                    assert snippet not in trained_text, \
                        f"sample {i}: user text leaked into trained tokens: '{snippet}'"

print("\n\nmulti-turn batch: ALL PASSED")

found 4 multi-turn samples
  sample 0: 6 turns (3 user, 3 assistant), image=True, img_size=(800, 579)
  sample 1: 6 turns (3 user, 3 assistant), image=True, img_size=(666, 797)
  sample 2: 6 turns (3 user, 3 assistant), image=True, img_size=(414, 414)
  sample 3: 6 turns (3 user, 3 assistant), image=True, img_size=(782, 583)

=== IFT collate, batch of 4 multi-turn ===
  input_ids            torch.Size([4, 556])           dtype=torch.int64
  attention_mask       torch.Size([4, 556])           dtype=torch.int64
  token_type_ids       torch.Size([4, 556])           dtype=torch.int64
  pixel_values         torch.Size([4, 3, 896, 896])   dtype=torch.float32
  labels               torch.Size([4, 556])           dtype=torch.int64

  sample 0: non-pad=501  pad=55  trained=198  ratio=39.5%
    model turns trained: 3 / 3 expected

  sample 1: non-pad=553  pad=3  trained=244  ratio=44.1%
    model turns trained: 3 / 3 expected

  sample 2: non-pad=556  pad=0  trained=243  ratio=43.7%
    model tu

In [45]:
# quick check: where are the pad tokens?
for i in range(4):
    ids = batch["input_ids"][i]
    pad_positions = (ids == processor.tokenizer.pad_token_id).nonzero(as_tuple=True)[0]
    if len(pad_positions) > 0:
        print(f"sample {i}: pad at positions {pad_positions[0].item()}..{pad_positions[-1].item()} (end={ids.shape[0]-1})")
    else:
        print(f"sample {i}: no padding")

sample 0: pad at positions 501..555 (end=555)
sample 1: pad at positions 553..555 (end=555)
sample 2: no padding
sample 3: pad at positions 443..555 (end=555)


In [46]:
# %% Cell — Load model and extract activations from the batch
from transformers import AutoModelForImageTextToText
import gc

model_id = "google/medgemma-4b-it"  # or "google/gemma-3-4b-it"

model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
model.eval()
print(f"loaded {model_id}")

`torch_dtype` is deprecated! Use `dtype` instead!
Loading weights: 100%|██████████| 883/883 [00:13<00:00, 66.65it/s, Materializing param=model.vision_tower.vision_model.post_layernorm.weight]                       


loaded google/medgemma-4b-it


In [52]:
# %% Cell — Forward pass, extract hidden states
# move batch to device
device = model.device
inputs = {k: v.to(device) for k, v in batch.items() if k != "labels"}

# cast pixel values to match model dtype
inputs["pixel_values"] = inputs["pixel_values"].to(dtype=torch.bfloat16)

with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)

print(f"layers: {len(outputs.hidden_states)}")
print(f"shape per layer: {outputs.hidden_states[0].shape}")
print(f"  (batch={outputs.hidden_states[0].shape[0]}, "
      f"seq_len={outputs.hidden_states[0].shape[1]}, "
      f"hidden_dim={outputs.hidden_states[0].shape[2]})")

layers: 35
shape per layer: torch.Size([4, 556, 2560])
  (batch=4, seq_len=556, hidden_dim=2560)


In [56]:
len(outputs.hidden_states)

35

In [48]:
# %% Cell — Extract assistant-only activations per sample
# for each sample, find where the model (assistant) turns are
# using the same boundary logic as the collate

start_of_turn_id = processor.tokenizer.convert_tokens_to_ids("<start_of_turn>")
end_of_turn_id = processor.tokenizer.convert_tokens_to_ids("<end_of_turn>")
model_token_id = processor.tokenizer.convert_tokens_to_ids("model")

all_sample_activations = []  # list of dicts per sample

for i in range(batch["input_ids"].shape[0]):
    ids = batch["input_ids"][i]
    
    # find assistant token ranges
    assistant_ranges = []
    in_model_turn = False
    turn_start = None
    
    for j in range(len(ids)):
        tok = ids[j].item()
        if tok == start_of_turn_id:
            if j + 1 < len(ids) and ids[j + 1].item() == model_token_id:
                in_model_turn = True
                turn_start = j
            else:
                in_model_turn = False
        if tok == end_of_turn_id and in_model_turn:
            assistant_ranges.append((turn_start, j + 1))  # inclusive of end_of_turn
            in_model_turn = False
    
    # extract activations for assistant ranges across all layers
    sample_hs = {}
    for layer_idx, hs in enumerate(outputs.hidden_states):
        layer_acts = []
        for start, end in assistant_ranges:
            layer_acts.append(hs[i, start:end, :].detach().cpu().float())
        # concatenate all assistant turns for this layer
        if layer_acts:
            sample_hs[layer_idx] = torch.cat(layer_acts, dim=0)
    
    all_sample_activations.append({
        "assistant_ranges": assistant_ranges,
        "hidden_states": sample_hs,
    })
    
    # print summary
    total_asst_tokens = sum(e - s for s, e in assistant_ranges)
    decoded_asst = processor.tokenizer.decode(
        [ids[j].item() for s, e in assistant_ranges for j in range(s, e)]
    )
    print(f"sample {i}: {len(assistant_ranges)} assistant turns, {total_asst_tokens} tokens")
    print(f"  ranges: {assistant_ranges}")
    print(f"  text: {decoded_asst[:120]}...")
    print(f"  last layer shape: {sample_hs[len(outputs.hidden_states)-1].shape}")
    print()

sample 0: 3 assistant turns, 198 tokens
  ranges: [(275, 298), (309, 424), (440, 500)]
  text: <start_of_turn>model
The main finding in the image is bilateral lower lobe consolidation, which is present in both lungs...
  last layer shape: torch.Size([198, 2560])

sample 1: 3 assistant turns, 244 tokens
  ranges: [(275, 322), (336, 408), (427, 552)]
  text: <start_of_turn>model
The CT scan is presented in a coronal view. This view is a vertical plane that divides the body int...
  last layer shape: torch.Size([244, 2560])

sample 2: 3 assistant turns, 243 tokens
  ranges: [(279, 328), (341, 452), (472, 555)]
  text: <start_of_turn>model
The main difference between this image and the previous one is the marked decrease in the size of t...
  last layer shape: torch.Size([243, 2560])

sample 3: 3 assistant turns, 139 tokens
  ranges: [(276, 316), (330, 367), (380, 442)]
  text: <start_of_turn>model
The image shows a histopathologic photograph of a keloid, which is a type of scar tissue tha

In [49]:
# %% Cell — Sanity check: activation stats
n_layers = len(outputs.hidden_states)

print("=== Activation stats (assistant tokens only) ===")
for sample_idx in range(len(all_sample_activations)):
    hs = all_sample_activations[sample_idx]["hidden_states"]
    first = hs[0]
    mid = hs[n_layers // 2]
    last = hs[n_layers - 1]
    
    print(f"\nsample {sample_idx} ({first.shape[0]} assistant tokens):")
    print(f"  layer  0:  mean={first.mean():.4f}  std={first.std():.4f}  min={first.min():.4f}  max={first.max():.4f}")
    print(f"  layer {n_layers//2:2d}:  mean={mid.mean():.4f}  std={mid.std():.4f}  min={mid.min():.4f}  max={mid.max():.4f}")
    print(f"  layer {n_layers-1:2d}:  mean={last.mean():.4f}  std={last.std():.4f}  min={last.min():.4f}  max={last.max():.4f}")

# verify no NaNs or Infs
for sample_idx in range(len(all_sample_activations)):
    for layer_idx, acts in all_sample_activations[sample_idx]["hidden_states"].items():
        assert not torch.isnan(acts).any(), f"NaN in sample {sample_idx} layer {layer_idx}"
        assert not torch.isinf(acts).any(), f"Inf in sample {sample_idx} layer {layer_idx}"

print("\nno NaNs or Infs: PASSED")

=== Activation stats (assistant tokens only) ===

sample 0 (198 assistant tokens):
  layer  0:  mean=0.0071  std=1.0164  min=-28.7500  max=17.5000
  layer 17:  mean=12.2498  std=610.5491  min=-1128.0000  max=38912.0000
  layer 34:  mean=0.0501  std=3.6441  min=-144.0000  max=104.0000

sample 1 (244 assistant tokens):
  layer  0:  mean=0.0112  std=1.0165  min=-28.7500  max=17.5000
  layer 17:  mean=12.2320  std=616.3966  min=-1152.0000  max=37376.0000
  layer 34:  mean=0.0537  std=3.5127  min=-133.0000  max=103.0000

sample 2 (243 assistant tokens):
  layer  0:  mean=0.0099  std=1.0172  min=-28.7500  max=17.5000
  layer 17:  mean=12.4186  std=618.2217  min=-1096.0000  max=40448.0000
  layer 34:  mean=0.0690  std=3.5310  min=-107.0000  max=113.5000

sample 3 (139 assistant tokens):
  layer  0:  mean=0.0069  std=1.0159  min=-28.7500  max=17.5000
  layer 17:  mean=12.4201  std=616.0704  min=-1016.0000  max=37376.0000
  layer 34:  mean=0.0772  std=4.2014  min=-157.0000  max=128.0000

no NaN

In [50]:
# %% Cell — Cleanup
del outputs, inputs
gc.collect()
torch.cuda.empty_cache()
print(f"GPU mem after cleanup: {torch.cuda.memory_allocated()/1e9:.1f} GB")

GPU mem after cleanup: 8.6 GB


In [57]:
# %% Cell — Extract vision encoder hidden states

# option 1: pass pixel_values directly through the vision tower
pixel_values = batch["pixel_values"].to(device, dtype=torch.bfloat16)

with torch.no_grad():
    vision_outputs = model.model.vision_tower(
        pixel_values, 
        output_hidden_states=True,
    )

print(f"vision layers: {len(vision_outputs.hidden_states)}")
print(f"shape per layer: {vision_outputs.hidden_states[0].shape}")
print(f"  (batch={vision_outputs.hidden_states[0].shape[0]}, "
      f"patches={vision_outputs.hidden_states[0].shape[1]}, "
      f"vision_dim={vision_outputs.hidden_states[0].shape[2]})")

# stats
for idx in [0, len(vision_outputs.hidden_states)//2, -1]:
    hs = vision_outputs.hidden_states[idx].float()
    label = {0: "first", -1: "last"}.get(idx, f"mid ({idx})")
    print(f"  {label:8s}: mean={hs.mean():.4f}  std={hs.std():.4f}")

vision layers: 28
shape per layer: torch.Size([4, 4096, 1152])
  (batch=4, patches=4096, vision_dim=1152)
  first   : mean=0.3025  std=1.9446
  mid (14): mean=0.0445  std=15.1405
  last    : mean=-1.3789  std=99.9512


In [58]:
# %% Cell — Vision activations: before and after projection

pixel_values = batch["pixel_values"].to(device, dtype=torch.bfloat16)

# Stage 1: raw vision encoder (you already have this)
with torch.no_grad():
    vision_outputs = model.model.vision_tower(
        pixel_values,
        output_hidden_states=True,
    )

vision_last = vision_outputs.hidden_states[-1]  # or .last_hidden_state
print(f"Stage 1 — Vision encoder output: {vision_last.shape}")
print(f"  (batch, patches={vision_last.shape[1]}, vision_dim={vision_last.shape[2]})")

# Stage 2: after multimodal projector (4096 patches -> 256 tokens, 1152 -> 2560)
with torch.no_grad():
    projected = model.model.multi_modal_projector(vision_last)

print(f"Stage 2 — After projector:       {projected.shape}")
print(f"  (batch, compressed={projected.shape[1]}, llm_dim={projected.shape[2]})")

# Stage 3: find the image token positions in the LLM hidden states
# these are the same 256 positions, but after going through LLM attention layers
image_token_id = processor.image_token_id
sample_ids = batch["input_ids"][0]
image_positions = (sample_ids == image_token_id).nonzero(as_tuple=True)[0]
print(f"Stage 3 — Image positions in LLM: {len(image_positions)} tokens at positions {image_positions[0].item()}..{image_positions[-1].item()}")

# extract LLM hidden states at image positions for each layer
# (you already have outputs.hidden_states from the earlier cell)
llm_image_hs_last = outputs.hidden_states[-1][0, image_positions, :].detach().cpu().float()
print(f"  LLM last layer at image positions: {llm_image_hs_last.shape}")

# compare all three
print(f"\n=== Dimension summary ===")
print(f"  vision encoder:  {vision_last.shape[1]} patches × {vision_last.shape[2]}d")
print(f"  after projector: {projected.shape[1]} tokens × {projected.shape[2]}d")
print(f"  LLM image slots: {len(image_positions)} tokens × {outputs.hidden_states[-1].shape[2]}d")

Stage 1 — Vision encoder output: torch.Size([4, 4096, 1152])
  (batch, patches=4096, vision_dim=1152)
Stage 2 — After projector:       torch.Size([4, 256, 2560])
  (batch, compressed=256, llm_dim=2560)
Stage 3 — Image positions in LLM: 256 tokens at positions 15..270
  LLM last layer at image positions: torch.Size([256, 2560])

=== Dimension summary ===
  vision encoder:  4096 patches × 1152d
  after projector: 256 tokens × 2560d
  LLM image slots: 256 tokens × 2560d
