Why is gen_emb unrelated to the think trajectory? #12

@Wakings

I ran tests with this model and performed multiple rollouts. The cosine similarity between the generated embeddings reaches approximately 0.98 or higher.
I evaluated with the test script below; the results are as follows:

======================================================================
TEST 1: Forward pass - same image input, different <think> content
======================================================================
  [think_A] <gen_emb> positions: [236, 269], total tokens: 270
  [think_B] <gen_emb> positions: [236, 271], total tokens: 272
  [no_think] <gen_emb> positions: [236, 258], total tokens: 259

Cosine similarities:
  think_A vs think_B: 0.953125
  think_A vs no_think: 0.976562
  think_B vs no_think: 0.976562

======================================================================
TEST 2: Natural generation vs manually altered think content
======================================================================
Generating with model.generate()...
Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.
Generated text:
<think>So, let's analyze the input. The image shows a puppy and a kitten on grass. The key elements are the two animals, their interaction (puppy looking at kitten), and the grassy background. The question is asking "What is in the image?" so we need to identify the main subjects. The image has a puppy and a kitten as the main entities. So the essence is the animals, probably describing the subjects.</think><answer>A puppy and a kitten on grass
(Knowing the image has a puppy and a kitten interacting, the key elements are the two animals in a grassy setting. Synthesizing, the main subjects are the puppy and kitten.)
<gen_emb><|im_end|>

<gen_emb> found in natural generation at shifted position(s): [235, 388]

Cosine similarity (natural think vs altered think): 0.992188
>>> RESULT: gen_emb has WEAK dependency on think content

======================================================================
TEST 3: Multiple rollouts of the same input (sampling variance)
======================================================================

--- Rollout 1/5 ---
Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.
  <gen_emb> found, completion length: 163 tokens
  Think (first 100 chars): So, let's analyze the input. The image shows a puppy and a kitten sitting on grass. The question is ...

--- Rollout 2/5 ---
Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.
  <gen_emb> found, completion length: 112 tokens
  Think (first 100 chars): So, let's analyze the input. The image shows a puppy and a kitten on grass. The key elements are the...

--- Rollout 3/5 ---
Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.
  <gen_emb> found, completion length: 125 tokens
  Think (first 100 chars): So, let's analyze the input. The image shows a yellow Labrador puppy and a tabby kitten sitting on g...

--- Rollout 4/5 ---
Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.
  <gen_emb> found, completion length: 175 tokens
  Think (first 100 chars): So, let's analyze the input. The image shows a yellow puppy and a small kitten interacting on grass....

--- Rollout 5/5 ---
Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.
  <gen_emb> found, completion length: 258 tokens
  Think (first 100 chars): So, let's analyze the input. The image shows a puppy and a kitten on a grassy area. The question is ...

--- Pairwise cosine similarities across 5 rollouts ---
  Rollout 1 vs Rollout 2: 1.000000
  Rollout 1 vs Rollout 3: 0.996094
  Rollout 1 vs Rollout 4: 1.000000
  Rollout 1 vs Rollout 5: 0.996094
  Rollout 2 vs Rollout 3: 1.000000
  Rollout 2 vs Rollout 4: 1.000000
  Rollout 2 vs Rollout 5: 0.996094
  Rollout 3 vs Rollout 4: 0.996094
  Rollout 3 vs Rollout 5: 0.992188
  Rollout 4 vs Rollout 5: 0.996094

  Average: 0.997266, Min: 0.992188, Max: 1.000000
>>> RESULT: gen_emb has SMALL variance across rollouts

--- Are think contents different across rollouts? ---
  Rollout 1 vs 2: think content DIFFERENT
  Rollout 1 vs 3: think content DIFFERENT
  Rollout 1 vs 4: think content DIFFERENT
  Rollout 1 vs 5: think content DIFFERENT
  Rollout 2 vs 3: think content DIFFERENT
  Rollout 2 vs 4: think content DIFFERENT
  Rollout 2 vs 5: think content DIFFERENT
  Rollout 3 vs 4: think content DIFFERENT
  Rollout 3 vs 5: think content DIFFERENT
  Rollout 4 vs 5: think content DIFFERENT

======================================================================
DONE
======================================================================
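For reference, the pairwise similarities reported above are plain dot products of L2-normalized embedding vectors (as in the script below), so they are bounded by 1.0. A minimal standalone sketch, using two made-up vectors as stand-ins for the `<gen_emb>` hidden states:

```python
import torch
import torch.nn.functional as F

# Hypothetical stand-ins for two <gen_emb> hidden states; in the real
# test these come from the model's last hidden layer.
a = F.normalize(torch.tensor([[1.0, 2.0, 3.0]]), p=2, dim=-1)
b = F.normalize(torch.tensor([[1.0, 2.0, 2.9]]), p=2, dim=-1)

# After L2 normalization, the dot product IS the cosine similarity.
sim = (a @ b.T).item()
print(f"cosine similarity: {sim:.6f}")
```

Because the vectors are unit-length, a similarity near 1.0 means the directions are almost identical, not that the raw hidden states are equal in magnitude.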

The test script is as follows:

import re
import torch
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

# ─── Config ───────────────────────────────────────────────────────────────────
pretrained_path = "zhibinlan/UME-R1-2B"
device = "cuda:0"

# ─── Load model & processor ──────────────────────────────────────────────────
print("Loading model...")
model = Qwen2VLForConditionalGeneration.from_pretrained(
    pretrained_path,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map=device,
)
model.eval()
processor = AutoProcessor.from_pretrained(pretrained_path)
tokenizer = processor.tokenizer

gen_emb_id = tokenizer.get_vocab()["<gen_emb>"]
print(f"<gen_emb> token id: {gen_emb_id}")


def normalize(t):
    return torch.nn.functional.normalize(t, p=2, dim=-1)


def extract_gen_emb_from_forward(input_ids, attention_mask, **multimodal_inputs):
    """Full forward pass, extract hidden state at <gen_emb> position (train-aligned approach)."""
    with torch.no_grad():
        output = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True,
            **multimodal_inputs,
        )
    # Use shifted approach consistent with training code
    shifted_ids = input_ids[:, 1:]
    embeddings = output.hidden_states[-1][:, 1:, :]

    result = []
    for i in range(shifted_ids.size(0)):
        mask = shifted_ids[i] == gen_emb_id
        if mask.any():
            result.append(embeddings[i][mask][-1])
        else:
            result.append(embeddings[i][-1])
    return normalize(torch.stack(result))


# ─── Common input: image + text (same as demo) ───────────────────────────────
IMAGE_PATH = "assets/example.jpg"
base_text = "Represent the given image with the following question: What is in the image?\n<disc_emb>\n"
prompt_suffix = (
    "Represent the above input text, images, videos, or any combination of the three as embeddings. "
    "First output the thinking process in <think> </think> tags and then summarize the entire input in a word or sentence. "
    "Finally, use the <gen_emb> tag to represent the entire input."
)

MULTIMODAL_KEYS = ["pixel_values", "pixel_values_videos", "image_grid_thw", "video_grid_thw", "second_per_grid_ts"]

def build_multimodal_inputs(messages):
    """Build tokenized inputs with image pixel values from messages."""
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(device)
    return inputs

def get_multimodal_kwargs(inputs):
    """Extract multimodal keyword args from inputs dict."""
    return {k: inputs[k] for k in MULTIMODAL_KEYS if k in inputs}

def append_completion_ids(prompt_input_ids, prompt_attention_mask, completion_text):
    """Tokenize completion_text and concatenate it after the processor-encoded prompt ids.

    This avoids re-tokenizing the whole text (which would lose the expanded image placeholder tokens).
    Returns (input_ids, attention_mask) with shape (1, prompt_len + completion_len).
    """
    comp_ids = tokenizer(completion_text, add_special_tokens=False, return_tensors="pt")["input_ids"].to(device)
    input_ids = torch.cat([prompt_input_ids, comp_ids], dim=1)
    attention_mask = torch.cat([prompt_attention_mask, torch.ones_like(comp_ids)], dim=1)
    return input_ids, attention_mask

# Build the base message (used by all tests)
base_messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": IMAGE_PATH},
            {"type": "text", "text": base_text + prompt_suffix},
        ],
    }
]

# ─── Test 1: Forward pass with different <think> content ─────────────────────
print("\n" + "=" * 70)
print("TEST 1: Forward pass - same image input, different <think> content")
print("=" * 70)

# Three different "completions" with different think content
think_variants = {
    "think_A": "<think>\nThe image shows a cat and a dog sitting together.\n</think>\ncat and dog<gen_emb>",
    "think_B": "<think>\nThis is a completely different reasoning chain about quantum physics and mathematics.\n</think>\ncat and dog<gen_emb>",
    "no_think": "<think>\n\n</think>\ncat and dog<gen_emb>",
}

# Get processor-encoded prompt and pixel values
base_inputs = build_multimodal_inputs(base_messages)
mm_kwargs = get_multimodal_kwargs(base_inputs)

embeddings = {}
for name, completion in think_variants.items():
    # Append completion tokens after the processor-encoded prompt (preserves image tokens)
    input_ids, attn_mask = append_completion_ids(
        base_inputs["input_ids"], base_inputs["attention_mask"], completion
    )
    emb = extract_gen_emb_from_forward(input_ids, attn_mask, **mm_kwargs)
    embeddings[name] = emb

    ids = input_ids[0].tolist()
    gen_pos = [i for i, t in enumerate(ids) if t == gen_emb_id]
    print(f"  [{name}] <gen_emb> positions: {gen_pos}, total tokens: {len(ids)}")

# Compare all pairs
print("\nCosine similarities:")
names = list(embeddings.keys())
for i in range(len(names)):
    for j in range(i + 1, len(names)):
        sim = (embeddings[names[i]] @ embeddings[names[j]].T).item()
        print(f"  {names[i]} vs {names[j]}: {sim:.6f}")


# ─── Test 2: Generate naturally, then compare with shuffled think ─────────────
print("\n" + "=" * 70)
print("TEST 2: Natural generation vs manually altered think content")
print("=" * 70)

messages = base_messages
inputs = build_multimodal_inputs(messages)
mm_kwargs_t2 = get_multimodal_kwargs(inputs)

print("Generating with model.generate()...")
with torch.no_grad():
    gen_output = model.generate(
        **inputs,
        max_new_tokens=512,
        output_hidden_states=True,
        return_dict_in_generate=True,
        use_cache=True,
    )

generated_ids = gen_output.sequences
generated_text = tokenizer.decode(generated_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=False)
print(f"Generated text:\n{generated_text}\n")

# Extract gen_emb from natural generation (train_aligned: full forward pass)
prompt_ids = inputs["input_ids"]
prompt_mask = inputs["attention_mask"]
completion_ids = generated_ids[:, prompt_ids.size(1):]
attention_mask = torch.cat([prompt_mask, torch.ones_like(completion_ids)], dim=1)

with torch.no_grad():
    output = model(
        input_ids=generated_ids,
        attention_mask=attention_mask,
        output_hidden_states=True,
        return_dict=True,
        **mm_kwargs_t2,
    )

shifted_ids = generated_ids[:, 1:]
hidden = output.hidden_states[-1][:, 1:, :]

mask = shifted_ids[0] == gen_emb_id
if mask.any():
    natural_emb = normalize(hidden[0][mask][-1].unsqueeze(0))
    print(f"<gen_emb> found in natural generation at shifted position(s): {mask.nonzero().flatten().tolist()}")
else:
    natural_emb = normalize(hidden[0][-1].unsqueeze(0))
    print("WARNING: <gen_emb> NOT found in natural generation, using last token")

# Now replace think content in the generated completion and re-forward
# Only re-tokenize the completion part (no image tokens), then prepend original prompt ids
altered_completion = re.sub(
    r"<think>.*?</think>",
    "<think>\nThis is totally unrelated filler text about quantum physics and black holes.\n</think>",
    generated_text,
    flags=re.DOTALL,
)

altered_ids, altered_mask = append_completion_ids(
    prompt_ids, prompt_mask, altered_completion
)
altered_emb = extract_gen_emb_from_forward(altered_ids, altered_mask, **mm_kwargs_t2)

sim = (natural_emb @ altered_emb.T).item()
print(f"\nCosine similarity (natural think vs altered think): {sim:.6f}")

if sim > 0.999:
    print(">>> RESULT: gen_emb appears INDEPENDENT of think content (similarity ~1.0)")
elif sim > 0.95:
    print(">>> RESULT: gen_emb has WEAK dependency on think content")
else:
    print(">>> RESULT: gen_emb has STRONG dependency on think content")


# ─── Test 3: Multiple rollouts - same input, different sampling results ───────
print("\n" + "=" * 70)
print("TEST 3: Multiple rollouts of the same input (sampling variance)")
print("=" * 70)

NUM_ROLLOUTS = 5
rollout_embs = []
rollout_texts = []

inputs_r = build_multimodal_inputs(base_messages)
mm_kwargs_t3 = get_multimodal_kwargs(inputs_r)

for r in range(NUM_ROLLOUTS):
    print(f"\n--- Rollout {r+1}/{NUM_ROLLOUTS} ---")
    with torch.no_grad():
        gen_out = model.generate(
            **inputs_r,
            max_new_tokens=512,
            do_sample=True,
            temperature=1.0,
            top_p=1.0,
            output_hidden_states=True,
            return_dict_in_generate=True,
            use_cache=True,
        )

    gen_ids = gen_out.sequences
    completion_text = tokenizer.decode(gen_ids[0][inputs_r["input_ids"].shape[1]:], skip_special_tokens=False)
    rollout_texts.append(completion_text)

    # Extract gen_emb via train_aligned forward pass
    p_ids = inputs_r["input_ids"]
    p_mask = inputs_r["attention_mask"]
    c_ids = gen_ids[:, p_ids.size(1):]
    attn_mask = torch.cat([p_mask, torch.ones_like(c_ids)], dim=1)

    with torch.no_grad():
        out = model(
            input_ids=gen_ids,
            attention_mask=attn_mask,
            output_hidden_states=True,
            return_dict=True,
            **mm_kwargs_t3,
        )

    s_ids = gen_ids[:, 1:]
    h = out.hidden_states[-1][:, 1:, :]
    m = s_ids[0] == gen_emb_id
    if m.any():
        emb = normalize(h[0][m][-1].unsqueeze(0))
        print(f"  <gen_emb> found, completion length: {c_ids.shape[1]} tokens")
    else:
        emb = normalize(h[0][-1].unsqueeze(0))
        print("  WARNING: <gen_emb> NOT found, using last token")

    rollout_embs.append(emb)

    # Print a short summary of think content
    think_match = re.search(r"<think>(.*?)</think>", completion_text, re.DOTALL)
    if think_match:
        think_text = think_match.group(1).strip()
        print(f"  Think (first 100 chars): {think_text[:100]}...")
    else:
        print("  No <think> block found")

# Pairwise cosine similarity
print(f"\n--- Pairwise cosine similarities across {NUM_ROLLOUTS} rollouts ---")
for i in range(NUM_ROLLOUTS):
    for j in range(i + 1, NUM_ROLLOUTS):
        sim = (rollout_embs[i] @ rollout_embs[j].T).item()
        print(f"  Rollout {i+1} vs Rollout {j+1}: {sim:.6f}")

# Stats
all_sims = []
for i in range(NUM_ROLLOUTS):
    for j in range(i + 1, NUM_ROLLOUTS):
        all_sims.append((rollout_embs[i] @ rollout_embs[j].T).item())

avg_sim = sum(all_sims) / len(all_sims)
min_sim = min(all_sims)
max_sim = max(all_sims)
print(f"\n  Average: {avg_sim:.6f}, Min: {min_sim:.6f}, Max: {max_sim:.6f}")

if avg_sim > 0.999:
    print(">>> RESULT: gen_emb is nearly IDENTICAL across rollouts despite different think content")
elif avg_sim > 0.95:
    print(">>> RESULT: gen_emb has SMALL variance across rollouts")
else:
    print(">>> RESULT: gen_emb has LARGE variance across rollouts - think content matters")

# Also check: are the think contents actually different?
print(f"\n--- Are think contents different across rollouts? ---")
for i in range(NUM_ROLLOUTS):
    think_i = re.search(r"<think>(.*?)</think>", rollout_texts[i], re.DOTALL)
    text_i = think_i.group(1).strip() if think_i else ""
    for j in range(i + 1, NUM_ROLLOUTS):
        think_j = re.search(r"<think>(.*?)</think>", rollout_texts[j], re.DOTALL)
        text_j = think_j.group(1).strip() if think_j else ""
        same = text_i == text_j
        print(f"  Rollout {i+1} vs {j+1}: think content {'SAME' if same else 'DIFFERENT'}")

print("\n" + "=" * 70)
print("DONE")
print("=" * 70)
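One caveat when reading the similarity numbers above: the model runs in bfloat16, and the dot products appear to be computed directly on bfloat16 tensors, so the printed values are quantized. Near 1.0, adjacent bfloat16 values are 2**-8 = 0.00390625 apart, which matches the spacing between the reported 1.000000, 0.996094, and 0.992188 exactly. A small sketch of that quantization (it does not change the conclusion, but "identical" and "nearly identical" rollouts can differ by a single representable value):

```python
import torch

# bfloat16 keeps 8 bits of significand, so in [0.5, 1.0) consecutive
# representable values are 2**-8 = 0.00390625 apart.
for x in (0.9999, 0.997, 0.994):
    q = torch.tensor(x).to(torch.bfloat16).item()
    print(f"{x} -> {q}")
```

To distinguish similarities at finer granularity, one option would be to cast the extracted embeddings to float32 before normalizing and comparing.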
