In [None]:
#!/usr/bin/env python3
"""
vLLM Logprobs Extraction Test - Standalone Script
Test if we can extract logprobs from vLLM for reproducibility forensics

Paste this entire script into a Jupyter cell and run.
"""

# ============================================================================
# CONFIG SECTION - CHANGE THESE PARAMETERS
# ============================================================================

# Model configuration
MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"  # Small model for testing
TENSOR_PARALLEL_SIZE = 4  # Passed to vLLM as tensor_parallel_size (GPUs to shard across)
MAX_MODEL_LEN = 8048  # Context length (reduced for memory); NOTE(review): 8192 may have been intended
GPU_MEMORY_UTILIZATION = 0.8  # Use 80% of GPU memory (vLLM default is 0.9)

# Generation parameters
NUM_REPETITIONS = 5  # Measured passes compared against rep 0 (after one warm-up pass)
MAX_TOKENS = 20  # Generate 20 tokens per pass
TEMPERATURE = 0.0  # Greedy sampling for determinism
SEED = 42  # Sampling seed; only relevant if TEMPERATURE > 0
TOP_LOGPROBS = 10  # Store top-10 logprobs for forensic analysis

# Input text
INPUT_TEXT = """The field of artificial intelligence has witnessed remarkable 
transformation over the past decade, driven primarily by advances in deep learning 
and the emergence of increasingly sophisticated language models. These models, trained 
on vast corpora of text data, have demonstrated remarkable capabilities across a wide 
range of tasks."""

# ============================================================================
# IMPORTS
# ============================================================================

import os
os.environ['HF_HOME'] = '/workspace/huggingface_cache'
os.environ['TRANSFORMERS_CACHE'] = '/workspace/huggingface_cache'

from vllm import LLM, SamplingParams
import numpy as np
from datetime import datetime
import json
import torch

# ============================================================================
# SETUP
# ============================================================================

# Echo the run configuration first so the log is self-describing even if the
# GPU query below fails.
print("\n".join([
    "=" * 80,
    "vLLM LOGPROBS EXTRACTION TEST",
    "=" * 80,
    "\nConfiguration:",
    f"  Model: {MODEL_NAME}",
    f"  Tensor parallel: {TENSOR_PARALLEL_SIZE}",
    f"  Max model len: {MAX_MODEL_LEN}",
    f"  GPU memory utilization: {GPU_MEMORY_UTILIZATION} ({GPU_MEMORY_UTILIZATION*100:.0f}%)",
    f"  Max tokens: {MAX_TOKENS}",
    f"  Repetitions: {NUM_REPETITIONS}",
    f"  Temperature: {TEMPERATURE} (greedy)",
    "",
]))

# Report what torch sees; device 0 is taken as representative of the node.
print("\n".join([
    "GPU Info:",
    f"  Device: {torch.cuda.get_device_name(0)}",
    f"  Available: {torch.cuda.device_count()} GPUs",
    "",
]))

# ============================================================================
# LOAD MODEL
# ============================================================================

print("Loading model with vLLM...")

# Engine arguments, collected in one place. Prefix caching is switched off so
# repeated passes over the identical prompt cannot reuse cached prompt state.
# (The captured run log reports that trust_remote_code is ignored for this model.)
_engine_args = {
    "model": MODEL_NAME,
    "tensor_parallel_size": TENSOR_PARALLEL_SIZE,
    "download_dir": "/workspace/huggingface_cache",
    "dtype": "bfloat16",
    "max_model_len": MAX_MODEL_LEN,
    "gpu_memory_utilization": GPU_MEMORY_UTILIZATION,
    "trust_remote_code": True,
    "enable_prefix_caching": False,
}
llm = LLM(**_engine_args)

print("✓ Model loaded (prefix caching disabled)\n")

# ============================================================================
# SAMPLING PARAMETERS
# ============================================================================

# Greedy decoding (TEMPERATURE is 0.0) that also returns the TOP_LOGPROBS most
# likely alternatives at every generated position.
sampling_params = SamplingParams(
    seed=SEED,
    max_tokens=MAX_TOKENS,
    temperature=TEMPERATURE,
    logprobs=TOP_LOGPROBS,
)

_sampling_summary = ", ".join([
    f"temperature={TEMPERATURE}",
    f"max_tokens={MAX_TOKENS}",
    f"seed={SEED}",
    f"top_logprobs={TOP_LOGPROBS}",
])
print(f"Sampling params: {_sampling_summary}")
print()

# ============================================================================
# LOGPROBS EXTRACTION FUNCTION
# ============================================================================

def extract_logprobs_vector(output):
    """
    Extract logprobs from a vLLM generation output for forensic comparison.

    Only the first completion (``output.outputs[0]``) is inspected.

    Returns a tuple ``(selected_logprobs, top_k_distributions)``:
    - selected_logprobs: float ndarray, one entry per logprob position, holding
      the logprob of the token actually generated there. A position whose value
      is unavailable is ``np.nan``, so the array stays numeric and
      ``np.mean`` / ``np.linalg.norm`` keep working.
    - top_k_distributions: list with one entry per position; each entry is a
      list of ``(token_id, logprob)`` tuples sorted by logprob, descending.
      A position without logprob data gets an empty list, so index ``i``
      always refers to generated token ``i`` across repetitions.
    """
    completion = output.outputs[0]
    token_ids = completion.token_ids

    # completion.logprobs may be None (e.g. if logprobs were not requested).
    position_logprobs = completion.logprobs
    if position_logprobs is None:
        position_logprobs = []

    selected_logprobs = []
    top_k_distributions = []

    # Each entry maps token_id -> Logprob object for one generated position.
    for i, token_logprobs_dict in enumerate(position_logprobs):
        if token_logprobs_dict is None:
            # Keep a placeholder instead of skipping, so later positions do not
            # shift and get compared against the wrong position in other reps.
            selected_logprobs.append(np.nan)
            top_k_distributions.append([])
            continue

        # Logprob of the token that was actually generated at this position
        generated_token_id = token_ids[i]
        if generated_token_id in token_logprobs_dict:
            selected_logprobs.append(token_logprobs_dict[generated_token_id].logprob)
        else:
            print(f"Warning: Token {generated_token_id} not in logprobs dict at position {i}")
            selected_logprobs.append(np.nan)

        # Full returned distribution for this position, highest logprob first
        position_dist = sorted(
            ((tid, lp_obj.logprob) for tid, lp_obj in token_logprobs_dict.items()),
            key=lambda pair: pair[1],
            reverse=True,
        )
        top_k_distributions.append(position_dist)

    return np.array(selected_logprobs, dtype=float), top_k_distributions

# ============================================================================
# WARM-UP PASS
# ============================================================================

print("\n".join([
    "=" * 80,
    "WARM-UP PASS",
    "=" * 80,
    "",
    "Running warm-up pass to initialize CUDA kernels...",
]))
# The output is discarded; this pass only absorbs first-request effects so
# they do not land in the measured repetitions below.
_ = llm.generate([INPUT_TEXT], sampling_params)
print("✓ Warm-up complete - CUDA kernels compiled and cached\n")

# ============================================================================
# RUN REPEATED FORWARD PASSES
# ============================================================================

print("="*80)
print(f"RUNNING {NUM_REPETITIONS} FORWARD PASSES")
print("="*80)
print()

results_logprobs = []
results_distributions = []
results_tokens = []
results_texts = []

for rep in range(NUM_REPETITIONS):
    print(f"Rep {rep + 1}/{NUM_REPETITIONS}:", end=" ")

    # One single-prompt request per repetition
    output = llm.generate([INPUT_TEXT], sampling_params)[0]
    completion = output.outputs[0]

    # Extract logprobs and distributions. Force a float array: a missing
    # selected-token logprob (None or NaN) would otherwise produce an object
    # array that np.mean / np.linalg.norm reject.
    logprobs_vec, top_k_dist = extract_logprobs_vector(output)
    logprobs_vec = np.asarray(logprobs_vec, dtype=float)
    results_logprobs.append(logprobs_vec)
    results_distributions.append(top_k_dist)

    # Store token ids as a plain list so sequences compare by value across
    # reps and serialize to JSON regardless of the container vLLM returns.
    token_ids = list(completion.token_ids)
    results_tokens.append(token_ids)

    text = completion.text
    results_texts.append(text)

    # Average only the available logprobs so one missing value doesn't turn
    # the summary into nan (and an empty vector doesn't emit a warning).
    valid_logprobs = logprobs_vec[~np.isnan(logprobs_vec)]
    mean_logprob = valid_logprobs.mean() if valid_logprobs.size else float("nan")
    print(f"{len(token_ids)} tokens, mean logprob={mean_logprob:.4f}")

print()

# Show first generation
print("First generation:")
print(f"  Text: '{results_texts[0][:100]}...'")
print(f"  Tokens: {results_tokens[0][:10]}...")
print(f"  Logprobs (first 5): {results_logprobs[0][:5]}")
print()

# ============================================================================
# ANALYSIS
# ============================================================================

print("="*80)
print("REPRODUCIBILITY ANALYSIS")
print("="*80)
print()

# Check token sequences. Compare as plain lists so the container type holding
# the token ids (list vs tuple) cannot make equal sequences compare unequal.
first_tokens = list(results_tokens[0])
tokens_identical = all(list(tokens) == first_tokens for tokens in results_tokens)

print(f"Token sequences identical: {tokens_identical}")
if not tokens_identical:
    for i in range(1, NUM_REPETITIONS):
        if list(results_tokens[i]) != first_tokens:
            print(f"  Rep 0 vs Rep {i}: DIFFER")

print()

# Check logprobs
first_logprobs = results_logprobs[0]

# Bit-exact comparison (shape must match as well as every value)
logprobs_exact = all(np.array_equal(first_logprobs, results_logprobs[i])
                     for i in range(1, NUM_REPETITIONS))

print(f"Logprobs bit-exact: {logprobs_exact}")

# Check top-k distributions. Each position is a list of (token_id, logprob)
# tuples, so list equality requires the same token ids in the same order with
# exactly equal logprobs. That makes this genuinely bit-exact, matching the
# selected-logprob check above, and it treats -inf == -inf as equal. A
# subtraction-with-tolerance test would turn -inf - -inf into NaN and flag it.
print("\nChecking full top-k distributions...")
distributions_exact = True
distribution_mismatches = []

first_dist = results_distributions[0]
for rep_idx in range(1, NUM_REPETITIONS):
    rep_dist = results_distributions[rep_idx]
    # Walk the longer of the two, so positions that exist in only one rep
    # (generations of different length) are reported instead of raising IndexError.
    for pos_idx in range(max(len(first_dist), len(rep_dist))):
        in_both = pos_idx < len(first_dist) and pos_idx < len(rep_dist)
        if not in_both or first_dist[pos_idx] != rep_dist[pos_idx]:
            distributions_exact = False
            distribution_mismatches.append((rep_idx, pos_idx))

print(f"Top-k distributions bit-exact: {distributions_exact}")

if not distributions_exact:
    print(f"\n⚠ Found {len(distribution_mismatches)} position mismatches in distributions")
    if len(distribution_mismatches) <= 5:
        for rep_idx, pos_idx in distribution_mismatches:
            print(f"  Rep 0 vs Rep {rep_idx}, position {pos_idx}")
    else:
        print(f"  First 5: {distribution_mismatches[:5]}")

if not logprobs_exact:
    print("\nL2 distances:")
    # Truncate every vector to the shortest length; numpy would otherwise
    # raise on subtracting or stacking vectors of different lengths.
    # Coercing to float turns any None placeholder into nan.
    lengths = [len(lp) for lp in results_logprobs]
    common_len = min(lengths)
    if common_len != max(lengths):
        print(f"  (vector lengths differ {lengths}; comparing first {common_len} positions)")
    all_logprobs = np.array([np.asarray(lp, dtype=float)[:common_len] for lp in results_logprobs])

    l2_distances = []
    for i in range(1, NUM_REPETITIONS):
        l2 = np.linalg.norm(all_logprobs[0] - all_logprobs[i])
        l2_distances.append(l2)
        print(f"  Rep 0 vs Rep {i}: L2 = {l2:.6e}")

    print(f"\nMax L2: {max(l2_distances):.6e}")
    print(f"Mean L2: {np.mean(l2_distances):.6e}")

    # Element-wise statistics (skipped if there is no common position at all)
    std_per_token = all_logprobs.std(axis=0)
    if std_per_token.size:
        print(f"\nPer-token std statistics:")
        print(f"  Mean: {std_per_token.mean():.6e}")
        print(f"  Max: {std_per_token.max():.6e}")
        print(f"  Median: {np.median(std_per_token):.6e}")

print()

# ============================================================================
# VERDICT
# ============================================================================

print("="*80)
print("VERDICT")
print("="*80)
print()

# The branches are mutually exclusive and ordered from strongest to weakest
# reproducibility. (A former duplicate `tokens_identical and not logprobs_exact`
# branch reported "PERFECT REPRODUCIBILITY" when logprobs varied and made the
# "LOGPROBS VARY" branch unreachable.)
if tokens_identical and logprobs_exact and distributions_exact:
    print("✓ PERFECT REPRODUCIBILITY")
    print("  - Token sequences: bit-exact")
    print("  - Selected token logprobs: bit-exact")
    print("  - Full top-k distributions: bit-exact")
    print("  - vLLM is deterministic for this config")
    print("  → Ready to scale up to K2 Thinking on Blackwell")
elif tokens_identical and logprobs_exact and not distributions_exact:
    print("⚠ SELECTED TOKENS EXACT, DISTRIBUTIONS VARY")
    print("  - Token sequences: bit-exact")
    print("  - Selected token logprobs: bit-exact")
    print("  - Top-k distributions: numerical variation detected")
    print("  → May indicate computational instability in non-selected paths")
elif tokens_identical and not logprobs_exact:
    print("⚠ TOKENS IDENTICAL, LOGPROBS VARY")
    print("  - Token sequences: bit-exact")
    print("  - Logprobs: small numerical variation")
    # l2_distances is populated by the analysis section whenever logprobs_exact is False
    max_l2 = max(l2_distances)
    if max_l2 < 1e-6:
        print(f"  - Variation very small (L2={max_l2:.2e})")
        print("  → Likely acceptable for forensics")
    else:
        print(f"  - Variation notable (L2={max_l2:.2e})")
        print("  → Investigate noise source")
else:
    print("✗ TOKEN SEQUENCES DIFFER")
    print("  - This should NOT happen with temperature=0")
    print("  → Something is wrong, investigate")

print()

# ============================================================================
# SAVE RESULTS
# ============================================================================

# Take a single timestamp so the "timestamp" field and the filename agree.
run_time = datetime.now()

output_data = {
    "experiment": "vllm_logprobs_test",
    "timestamp": run_time.isoformat(),
    "config": {
        "model": MODEL_NAME,
        "tensor_parallel": TENSOR_PARALLEL_SIZE,
        "max_model_len": MAX_MODEL_LEN,
        "gpu_memory_utilization": GPU_MEMORY_UTILIZATION,
        "max_tokens": MAX_TOKENS,
        "repetitions": NUM_REPETITIONS,
        "temperature": TEMPERATURE,
        "seed": SEED,
        "warmup_enabled": True,
        "prefix_caching_disabled": True,
        "top_logprobs": TOP_LOGPROBS
    },
    # Exact prompt, so the run can be reproduced from the JSON alone
    "input_text": INPUT_TEXT,
    "results": {
        # bool() guards against numpy bools, which json cannot serialize
        "tokens_identical": bool(tokens_identical),
        "logprobs_exact": bool(logprobs_exact),
        "distributions_exact": bool(distributions_exact),
        "perfect_reproducibility": bool(tokens_identical and logprobs_exact and distributions_exact)
    },
    # list() so non-list sequence containers still serialize
    "token_sequences": [list(tokens) for tokens in results_tokens],
    # Missing logprobs (None or NaN) become null; json.dump would otherwise emit
    # the non-standard token NaN. `v != v` is true only for NaN.
    "logprobs_vectors": [
        [None if v is None or v != v else v for v in np.asarray(lp).tolist()]
        for lp in results_logprobs
    ],
    "generated_texts": results_texts,
    "note": "Top-k distributions not saved to JSON (too large), check bit-exact flag above"
}

output_file = f"/workspace/vllm_logprobs_test_{run_time.strftime('%Y%m%d_%H%M%S')}.json"
with open(output_file, "w", encoding="utf-8") as f:
    json.dump(output_data, f, indent=2)

print(f"Results saved to: {output_file}")
print()
print("="*80)
print("TEST COMPLETE")
print("="*80)
print()
print("If successful, next steps:")
print("  1. Scale up to K2 Thinking")
print("  2. Test on 4× B200 with tensor parallelism")
print("  3. Run with 100K context")
print()



INFO 11-09 20:41:40 [__init__.py:216] Automatically detected platform cuda.
vLLM LOGPROBS EXTRACTION TEST

Configuration:
  Model: Qwen/Qwen2.5-7B-Instruct
  Tensor parallel: 4
  Max model len: 8048
  GPU memory utilization: 0.8 (80%)
  Max tokens: 20
  Repetitions: 5
  Temperature: 0.0 (greedy)

GPU Info:
  Device: NVIDIA A100-SXM4-80GB
  Available: 4 GPUs

Loading model with vLLM...
INFO 11-09 20:41:44 [utils.py:233] non-default args: {'trust_remote_code': True, 'download_dir': '/workspace/huggingface_cache', 'dtype': 'bfloat16', 'max_model_len': 8048, 'tensor_parallel_size': 4, 'enable_prefix_caching': False, 'gpu_memory_utilization': 0.8, 'disable_log_stats': True, 'model': 'Qwen/Qwen2.5-7B-Instruct'}


The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.


config.json:   0%|          | 0.00/663 [00:00<?, ?B/s]

INFO 11-09 20:41:53 [model.py:547] Resolved architecture: Qwen2ForCausalLM


`torch_dtype` is deprecated! Use `dtype` instead!


INFO 11-09 20:41:53 [model.py:1510] Using max model len 8048
INFO 11-09 20:41:56 [scheduler.py:205] Chunked prefill is enabled with max_num_batched_tokens=8192.


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

generation_config.json:   0%|          | 0.00/243 [00:00<?, ?B/s]





INFO 11-09 20:42:02 [__init__.py:216] Automatically detected platform cuda.
[1;36m(EngineCore_DP0 pid=1359)[0;0m INFO 11-09 20:42:03 [core.py:644] Waiting for init message from front-end.
[1;36m(EngineCore_DP0 pid=1359)[0;0m INFO 11-09 20:42:03 [core.py:77] Initializing a V1 LLM engine (v0.11.0) with config: model='Qwen/Qwen2.5-7B-Instruct', speculative_config=None, tokenizer='Qwen/Qwen2.5-7B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=8048, download_dir='/workspace/huggingface_cache', load_format=auto, tensor_parallel_size=4, pipeline_parallel_size=1, data_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, device_config=cuda, structured_outputs_config=StructuredOutputsConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_parser=''), observabi



INFO 11-09 20:42:07 [__init__.py:216] Automatically detected platform cuda.
INFO 11-09 20:42:07 [__init__.py:216] Automatically detected platform cuda.
INFO 11-09 20:42:07 [__init__.py:216] Automatically detected platform cuda.
INFO 11-09 20:42:07 [__init__.py:216] Automatically detected platform cuda.
INFO 11-09 20:42:11 [shm_broadcast.py:289] vLLM message queue communication handle: Handle(local_reader_ranks=[0], buffer_handle=(1, 10485760, 10, 'psm_8735dea1'), local_subscribe_addr='ipc:///tmp/25bff50d-102f-4a83-b5d9-6c057738ca20', remote_subscribe_addr=None, remote_addr_ipv6=False)
INFO 11-09 20:42:11 [shm_broadcast.py:289] vLLM message queue communication handle: Handle(local_reader_ranks=[0], buffer_handle=(1, 10485760, 10, 'psm_02698cc8'), local_subscribe_addr='ipc:///tmp/2a291431-5a0e-4da4-8ce8-0dc51e3d89d1', remote_subscribe_addr=None, remote_addr_ipv6=False)
INFO 11-09 20:42:11 [shm_broadcast.py:289] vLLM message queue communication handle: Handle(local_reader_ranks=[0], buffe

Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  25% Completed | 1/4 [00:00<00:00,  3.10it/s]
Loading safetensors checkpoint shards:  50% Completed | 2/4 [00:00<00:00,  2.49it/s]
Loading safetensors checkpoint shards:  75% Completed | 3/4 [00:01<00:00,  2.32it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:01<00:00,  2.34it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:01<00:00,  2.40it/s]
[1;36m(Worker_TP0 pid=1493)[0;0m 


[1;36m(Worker_TP0 pid=1493)[0;0m INFO 11-09 20:44:29 [default_loader.py:267] Loading weights took 1.72 seconds
[1;36m(Worker_TP3 pid=1496)[0;0m INFO 11-09 20:44:29 [default_loader.py:267] Loading weights took 2.03 seconds
[1;36m(Worker_TP1 pid=1494)[0;0m INFO 11-09 20:44:30 [default_loader.py:267] Loading weights took 1.73 seconds
[1;36m(Worker_TP0 pid=1493)[0;0m INFO 11-09 20:44:30 [gpu_model_runner.py:2653] Model loading took 3.5547 GiB and 135.488367 seconds
[1;36m(Worker_TP2 pid=1495)[0;0m INFO 11-09 20:44:30 [default_loader.py:267] Loading weights took 1.89 seconds
[1;36m(Worker_TP3 pid=1496)[0;0m INFO 11-09 20:44:30 [gpu_model_runner.py:2653] Model loading took 3.5547 GiB and 135.660971 seconds
[1;36m(Worker_TP1 pid=1494)[0;0m INFO 11-09 20:44:30 [gpu_model_runner.py:2653] Model loading took 3.5547 GiB and 136.016446 seconds
[1;36m(Worker_TP2 pid=1495)[0;0m INFO 11-09 20:44:30 [gpu_model_runner.py:2653] Model loading took 3.5547 GiB and 136.018848 seconds
[1;36m(