I ran tests with this model, performing multiple rollouts of the same input. The cosine similarity between the generated embeddings was consistently around 0.98 or higher, even though the <think> content differed across rollouts.
I used a test script for the evaluation (included at the end of this post); the results are as follows:
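For reference, every similarity number below is a plain cosine similarity between L2-normalized hidden states taken at the <gen_emb> position, as in the script. A minimal sketch of that computation, where emb_a and emb_b are hypothetical stand-ins for two such hidden-state vectors:

import torch.nn.functional as F
# emb_a, emb_b: hypothetical hidden states at the <gen_emb> position, shape (hidden_dim,)
sim = (F.normalize(emb_a, p=2, dim=-1) * F.normalize(emb_b, p=2, dim=-1)).sum().item()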
======================================================================
TEST 1: Forward pass - same image input, different <think> content
======================================================================
[think_A] <gen_emb> positions: [236, 269], total tokens: 270
[think_B] <gen_emb> positions: [236, 271], total tokens: 272
[no_think] <gen_emb> positions: [236, 258], total tokens: 259
Cosine similarities:
think_A vs think_B: 0.953125
think_A vs no_think: 0.976562
think_B vs no_think: 0.976562
======================================================================
TEST 2: Natural generation vs manually altered think content
======================================================================
Generating with model.generate()...
Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.
Generated text:
<think>So, let's analyze the input. The image shows a puppy and a kitten on grass. The key elements are the two animals, their interaction (puppy looking at kitten), and the grassy background. The question is asking "What is in the image?" so we need to identify the main subjects. The image has a puppy and a kitten as the main entities. So the essence is the animals, probably describing the subjects.</think><answer>A puppy and a kitten on grass
(Knowing the image has a puppy and a kitten interacting, the key elements are the two animals in a grassy setting. Synthesizing, the main subjects are the puppy and kitten.)
<gen_emb><|im_end|>
<gen_emb> found in natural generation at shifted position(s): [235, 388]
Cosine similarity (natural think vs altered think): 0.992188
>>> RESULT: gen_emb has WEAK dependency on think content
======================================================================
TEST 3: Multiple rollouts of the same input (sampling variance)
======================================================================
--- Rollout 1/5 ---
Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.
<gen_emb> found, completion length: 163 tokens
Think (first 100 chars): So, let's analyze the input. The image shows a puppy and a kitten sitting on grass. The question is ...
--- Rollout 2/5 ---
Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.
<gen_emb> found, completion length: 112 tokens
Think (first 100 chars): So, let's analyze the input. The image shows a puppy and a kitten on grass. The key elements are the...
--- Rollout 3/5 ---
Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.
<gen_emb> found, completion length: 125 tokens
Think (first 100 chars): So, let's analyze the input. The image shows a yellow Labrador puppy and a tabby kitten sitting on g...
--- Rollout 4/5 ---
Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.
<gen_emb> found, completion length: 175 tokens
Think (first 100 chars): So, let's analyze the input. The image shows a yellow puppy and a small kitten interacting on grass....
--- Rollout 5/5 ---
Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.
<gen_emb> found, completion length: 258 tokens
Think (first 100 chars): So, let's analyze the input. The image shows a puppy and a kitten on a grassy area. The question is ...
--- Pairwise cosine similarities across 5 rollouts ---
Rollout 1 vs Rollout 2: 1.000000
Rollout 1 vs Rollout 3: 0.996094
Rollout 1 vs Rollout 4: 1.000000
Rollout 1 vs Rollout 5: 0.996094
Rollout 2 vs Rollout 3: 1.000000
Rollout 2 vs Rollout 4: 1.000000
Rollout 2 vs Rollout 5: 0.996094
Rollout 3 vs Rollout 4: 0.996094
Rollout 3 vs Rollout 5: 0.992188
Rollout 4 vs Rollout 5: 0.996094
Average: 0.997266, Min: 0.992188, Max: 1.000000
>>> RESULT: gen_emb has SMALL variance across rollouts
--- Are think contents different across rollouts? ---
Rollout 1 vs 2: think content DIFFERENT
Rollout 1 vs 3: think content DIFFERENT
Rollout 1 vs 4: think content DIFFERENT
Rollout 1 vs 5: think content DIFFERENT
Rollout 2 vs 3: think content DIFFERENT
Rollout 2 vs 4: think content DIFFERENT
Rollout 2 vs 5: think content DIFFERENT
Rollout 3 vs 4: think content DIFFERENT
Rollout 3 vs 5: think content DIFFERENT
Rollout 4 vs 5: think content DIFFERENT
======================================================================
DONE
======================================================================

The test script is as follows:
import re
import torch
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
# ─── Config ───────────────────────────────────────────────────────────────────
pretrained_path = "zhibinlan/UME-R1-2B"
device = "cuda:0"
# ─── Load model & processor ──────────────────────────────────────────────────
print("Loading model...")
model = Qwen2VLForConditionalGeneration.from_pretrained(
    pretrained_path,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map=device,
)
model.eval()
processor = AutoProcessor.from_pretrained(pretrained_path)
tokenizer = processor.tokenizer
gen_emb_id = tokenizer.get_vocab()["<gen_emb>"]
print(f"<gen_emb> token id: {gen_emb_id}")
def normalize(t):
    return torch.nn.functional.normalize(t, p=2, dim=-1)
def extract_gen_emb_from_forward(input_ids, attention_mask, **multimodal_inputs):
    """Full forward pass, extract hidden state at <gen_emb> position (train-aligned approach)."""
    with torch.no_grad():
        output = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True,
            **multimodal_inputs,
        )
    # Use shifted approach consistent with training code
    shifted_ids = input_ids[:, 1:]
    embeddings = output.hidden_states[-1][:, 1:, :]
    result = []
    for i in range(shifted_ids.size(0)):
        mask = shifted_ids[i] == gen_emb_id
        if mask.any():
            result.append(embeddings[i][mask][-1])
        else:
            result.append(embeddings[i][-1])
    return normalize(torch.stack(result))
# ─── Common input: image + text (same as demo) ───────────────────────────────
IMAGE_PATH = "assets/example.jpg"
base_text = "Represent the given image with the following question: What is in the image?\n<disc_emb>\n"
prompt_suffix = (
    "Represent the above input text, images, videos, or any combination of the three as embeddings. "
    "First output the thinking process in <think> </think> tags and then summarize the entire input in a word or sentence. "
    "Finally, use the <gen_emb> tag to represent the entire input."
)
MULTIMODAL_KEYS = ["pixel_values", "pixel_values_videos", "image_grid_thw", "video_grid_thw", "second_per_grid_ts"]
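# Note: these are the vision-related tensors returned by the processor; they have to be
# forwarded on every model call so the expanded image placeholder tokens in input_ids
# get resolved to actual pixel features.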
def build_multimodal_inputs(messages):
    """Build tokenized inputs with image pixel values from messages."""
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(device)
    return inputs

def get_multimodal_kwargs(inputs):
    """Extract multimodal keyword args from inputs dict."""
    return {k: inputs[k] for k in MULTIMODAL_KEYS if k in inputs}
def append_completion_ids(prompt_input_ids, prompt_attention_mask, completion_text):
    """Tokenize completion_text and concatenate it after the processor-encoded prompt ids.

    This avoids re-tokenizing the whole text (which would lose the expanded image placeholder tokens).
    Returns (input_ids, attention_mask) with shape (1, prompt_len + completion_len).
    """
    comp_ids = tokenizer(completion_text, add_special_tokens=False, return_tensors="pt")["input_ids"].to(device)
    input_ids = torch.cat([prompt_input_ids, comp_ids], dim=1)
    attention_mask = torch.cat([prompt_attention_mask, torch.ones_like(comp_ids)], dim=1)
    return input_ids, attention_mask
# Build the base message (used by all tests)
base_messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": IMAGE_PATH},
            {"type": "text", "text": base_text + prompt_suffix},
        ],
    }
]
# ─── Test 1: Forward pass with different <think> content ─────────────────────
print("\n" + "=" * 70)
print("TEST 1: Forward pass - same image input, different <think> content")
print("=" * 70)
# Three different "completions" with different think content
think_variants = {
    "think_A": "<think>\nThe image shows a cat and a dog sitting together.\n</think>\ncat and dog<gen_emb>",
    "think_B": "<think>\nThis is a completely different reasoning chain about quantum physics and mathematics.\n</think>\ncat and dog<gen_emb>",
    "no_think": "<think>\n\n</think>\ncat and dog<gen_emb>",
}
# Get processor-encoded prompt and pixel values
base_inputs = build_multimodal_inputs(base_messages)
mm_kwargs = get_multimodal_kwargs(base_inputs)
embeddings = {}
for name, completion in think_variants.items():
    # Append completion tokens after the processor-encoded prompt (preserves image tokens)
    input_ids, attn_mask = append_completion_ids(
        base_inputs["input_ids"], base_inputs["attention_mask"], completion
    )
    emb = extract_gen_emb_from_forward(input_ids, attn_mask, **mm_kwargs)
    embeddings[name] = emb
    ids = input_ids[0].tolist()
    gen_pos = [i for i, t in enumerate(ids) if t == gen_emb_id]
    print(f" [{name}] <gen_emb> positions: {gen_pos}, total tokens: {len(ids)}")
# Compare all pairs
print("\nCosine similarities:")
names = list(embeddings.keys())
for i in range(len(names)):
    for j in range(i + 1, len(names)):
        sim = (embeddings[names[i]] @ embeddings[names[j]].T).item()
        print(f" {names[i]} vs {names[j]}: {sim:.6f}")
# ─── Test 2: Generate naturally, then compare with shuffled think ─────────────
print("\n" + "=" * 70)
print("TEST 2: Natural generation vs manually altered think content")
print("=" * 70)
messages = base_messages
inputs = build_multimodal_inputs(messages)
mm_kwargs_t2 = get_multimodal_kwargs(inputs)
print("Generating with model.generate()...")
with torch.no_grad():
    gen_output = model.generate(
        **inputs,
        max_new_tokens=512,
        output_hidden_states=True,
        return_dict_in_generate=True,
        use_cache=True,
    )
generated_ids = gen_output.sequences
generated_text = tokenizer.decode(generated_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=False)
print(f"Generated text:\n{generated_text}\n")
# Extract gen_emb from natural generation (train_aligned: full forward pass)
prompt_ids = inputs["input_ids"]
prompt_mask = inputs["attention_mask"]
completion_ids = generated_ids[:, prompt_ids.size(1):]
attention_mask = torch.cat([prompt_mask, torch.ones_like(completion_ids)], dim=1)
with torch.no_grad():
    output = model(
        input_ids=generated_ids,
        attention_mask=attention_mask,
        output_hidden_states=True,
        return_dict=True,
        **mm_kwargs_t2,
    )
shifted_ids = generated_ids[:, 1:]
hidden = output.hidden_states[-1][:, 1:, :]
mask = shifted_ids[0] == gen_emb_id
if mask.any():
    natural_emb = normalize(hidden[0][mask][-1].unsqueeze(0))
    print(f"<gen_emb> found in natural generation at shifted position(s): {mask.nonzero().flatten().tolist()}")
else:
    natural_emb = normalize(hidden[0][-1].unsqueeze(0))
    print("WARNING: <gen_emb> NOT found in natural generation, using last token")
# Now replace think content in the generated completion and re-forward
# Only re-tokenize the completion part (no image tokens), then prepend original prompt ids
altered_completion = re.sub(
    r"<think>.*?</think>",
    "<think>\nThis is totally unrelated filler text about quantum physics and black holes.\n</think>",
    generated_text,
    flags=re.DOTALL,
)
altered_ids, altered_mask = append_completion_ids(
    prompt_ids, prompt_mask, altered_completion
)
altered_emb = extract_gen_emb_from_forward(altered_ids, altered_mask, **mm_kwargs_t2)
sim = (natural_emb @ altered_emb.T).item()
print(f"\nCosine similarity (natural think vs altered think): {sim:.6f}")
if sim > 0.999:
    print(">>> RESULT: gen_emb appears INDEPENDENT of think content (similarity ~1.0)")
elif sim > 0.95:
    print(">>> RESULT: gen_emb has WEAK dependency on think content")
else:
    print(">>> RESULT: gen_emb has STRONG dependency on think content")
# ─── Test 3: Multiple rollouts - same input, different sampling results ───────
print("\n" + "=" * 70)
print("TEST 3: Multiple rollouts of the same input (sampling variance)")
print("=" * 70)
NUM_ROLLOUTS = 5
rollout_embs = []
rollout_texts = []
inputs_r = build_multimodal_inputs(base_messages)
mm_kwargs_t3 = get_multimodal_kwargs(inputs_r)
for r in range(NUM_ROLLOUTS):
    print(f"\n--- Rollout {r+1}/{NUM_ROLLOUTS} ---")
    with torch.no_grad():
        gen_out = model.generate(
            **inputs_r,
            max_new_tokens=512,
            do_sample=True,
            temperature=1.0,
            top_p=1.0,
            output_hidden_states=True,
            return_dict_in_generate=True,
            use_cache=True,
        )
    gen_ids = gen_out.sequences
    completion_text = tokenizer.decode(gen_ids[0][inputs_r["input_ids"].shape[1]:], skip_special_tokens=False)
    rollout_texts.append(completion_text)
    # Extract gen_emb via train_aligned forward pass
    p_ids = inputs_r["input_ids"]
    p_mask = inputs_r["attention_mask"]
    c_ids = gen_ids[:, p_ids.size(1):]
    attn_mask = torch.cat([p_mask, torch.ones_like(c_ids)], dim=1)
    with torch.no_grad():
        out = model(
            input_ids=gen_ids,
            attention_mask=attn_mask,
            output_hidden_states=True,
            return_dict=True,
            **mm_kwargs_t3,
        )
    s_ids = gen_ids[:, 1:]
    h = out.hidden_states[-1][:, 1:, :]
    m = s_ids[0] == gen_emb_id
    if m.any():
        emb = normalize(h[0][m][-1].unsqueeze(0))
        print(f" <gen_emb> found, completion length: {c_ids.shape[1]} tokens")
    else:
        emb = normalize(h[0][-1].unsqueeze(0))
        print(" WARNING: <gen_emb> NOT found, using last token")
    rollout_embs.append(emb)
    # Print a short summary of think content
    think_match = re.search(r"<think>(.*?)</think>", completion_text, re.DOTALL)
    if think_match:
        think_text = think_match.group(1).strip()
        print(f" Think (first 100 chars): {think_text[:100]}...")
    else:
        print(" No <think> block found")
# Pairwise cosine similarity
print(f"\n--- Pairwise cosine similarities across {NUM_ROLLOUTS} rollouts ---")
for i in range(NUM_ROLLOUTS):
    for j in range(i + 1, NUM_ROLLOUTS):
        sim = (rollout_embs[i] @ rollout_embs[j].T).item()
        print(f" Rollout {i+1} vs Rollout {j+1}: {sim:.6f}")
# Stats
all_sims = []
for i in range(NUM_ROLLOUTS):
    for j in range(i + 1, NUM_ROLLOUTS):
        all_sims.append((rollout_embs[i] @ rollout_embs[j].T).item())
avg_sim = sum(all_sims) / len(all_sims)
min_sim = min(all_sims)
max_sim = max(all_sims)
print(f"\n Average: {avg_sim:.6f}, Min: {min_sim:.6f}, Max: {max_sim:.6f}")
if avg_sim > 0.999:
    print(">>> RESULT: gen_emb is nearly IDENTICAL across rollouts despite different think content")
elif avg_sim > 0.95:
    print(">>> RESULT: gen_emb has SMALL variance across rollouts")
else:
    print(">>> RESULT: gen_emb has LARGE variance across rollouts - think content matters")
# Also check: are the think contents actually different?
print(f"\n--- Are think contents different across rollouts? ---")
for i in range(NUM_ROLLOUTS):
think_i = re.search(r"<think>(.*?)</think>", rollout_texts[i], re.DOTALL)
text_i = think_i.group(1).strip() if think_i else ""
for j in range(i + 1, NUM_ROLLOUTS):
think_j = re.search(r"<think>(.*?)</think>", rollout_texts[j], re.DOTALL)
text_j = think_j.group(1).strip() if think_j else ""
same = text_i == text_j
print(f" Rollout {i+1} vs {j+1}: think content {'SAME' if same else 'DIFFERENT'}")
print("\n" + "=" * 70)
print("DONE")
print("=" * 70)