In [2]:
#!/usr/bin/env python3
"""
vLLM Tensor Parallelism Test with Built-in Chat Templates
Tests bit-exact reproducibility across multiple runs with TP
Uses vLLM's built-in tokenizer.apply_chat_template() - no manual formatting
"""

import os
os.environ['HF_HOME'] = '/workspace/huggingface_cache'
os.environ['TRANSFORMERS_CACHE'] = '/workspace/huggingface_cache'

from vllm import LLM, SamplingParams
import numpy as np
from datetime import datetime
import json
import torch
import glob

# ============================================================================
# CONFIGURATION
# ============================================================================

# Model configuration - change this to test different models
MODEL_NAME = "Qwen/Qwen3-4B"  # or "moonshotai/Kimi-K2-Thinking"
TENSOR_PARALLEL_SIZE = 1
MAX_MODEL_LEN = 131072  # Qwen3 supports up to 262k, Kimi K2 supports 256k
GPU_MEMORY_UTILIZATION = 0.95

# Generation configuration
MAX_TOKENS = 200
NUM_REPETITIONS = 5
TEMPERATURE = 0.0  # For reproducibility testing
SEED = 42
TOP_LOGPROBS = 10

# System prompt
SYSTEM_PROMPT = "You are a helpful assistant."

# ============================================================================
# FILE LOADING UTILITIES
# ============================================================================

def find_pdfs_in_current_dir():
    """Find all PDF files in current directory"""
    pdfs = sorted(glob.glob("*.pdf"))
    return pdfs

def load_text_from_pdf(filepath):
    """
    Load text from PDF file.
    
    NOTE: This only extracts text. Images, charts, diagrams are ignored.
    PyPDF2 cannot extract image content. If images contain critical information,
    you need a vision-language model (Qwen3-VL, Kimi-VL) instead.
    """
    try:
        import PyPDF2
    except ImportError:
        raise ImportError("PyPDF2 required for PDF loading. Install with: pip install PyPDF2")
    
    text = []
    with open(filepath, 'rb') as f:
        pdf_reader = PyPDF2.PdfReader(f)
        num_pages = len(pdf_reader.pages)
        print(f"  Loading {num_pages} pages from: {filepath}")
        
        for page_num, page in enumerate(pdf_reader.pages, 1):
            page_text = page.extract_text()
            text.append(page_text)
            if page_num % 10 == 0:
                print(f"    Processed {page_num}/{num_pages} pages")
    
    full_text = '\n'.join(text)
    print(f"  Extracted {len(full_text)} characters ({num_pages} pages)")
    return full_text

# ============================================================================
# MAIN TEST
# ============================================================================

print("=" * 80)
print("vLLM LOGPROBS EXTRACTION TEST")
print("=" * 80)
print()

# Find and load PDFs
pdf_files = find_pdfs_in_current_dir()
if pdf_files:
    print(f"Found {len(pdf_files)} PDF(s): {pdf_files}")
    print()
    print("Loading PDFs (text only - images will be ignored)...")
    
    # Load all PDFs
    documents = []
    for pdf_file in pdf_files:
        text = load_text_from_pdf(pdf_file)
        documents.append({
            'filename': pdf_file,
            'text': text
        })
    print()
    
    # Format documents for prompt
    if len(documents) == 1:
        doc_text = f"Document: {documents[0]['filename']}\n\n{documents[0]['text']}"
    else:
        doc_parts = []
        for i, doc in enumerate(documents, 1):
            doc_parts.append(f"Document {i}: {doc['filename']}\n\n{doc['text']}")
        separator = "\n\n" + "="*80 + "\n\n"
        doc_text = separator.join(doc_parts)
    
    user_message = f"""I have provided you with {len(documents)} document(s). Please read them carefully and provide a brief summary of the main points from each document.

{doc_text}

Please think carefully and summarize the key points from each document."""
else:
    print("No PDFs found in current directory")
    print("Using default hardcoded prompt")
    
    pdf_files = None
    user_message = """Please explain the key architectural innovations in large language models over the past few years, focusing on attention mechanisms and efficiency improvements. Think step-by-step about the major breakthroughs."""

print()

print("Configuration:")
print(f"  Model: {MODEL_NAME}")
print(f"  Tensor parallel: {TENSOR_PARALLEL_SIZE}")
print(f"  Max model len: {MAX_MODEL_LEN}")
print(f"  GPU memory utilization: {GPU_MEMORY_UTILIZATION} ({int(GPU_MEMORY_UTILIZATION*100)}%)")
print(f"  Max tokens: {MAX_TOKENS}")
print(f"  Repetitions: {NUM_REPETITIONS}")
print(f"  Temperature: {TEMPERATURE} (greedy for reproducibility)")
print(f"  Top logprobs: {TOP_LOGPROBS}")
print()

print("GPU Info:")
print(f"  Device: {torch.cuda.get_device_name(0)}")
print(f"  Available: {torch.cuda.device_count()} GPUs")
print()

print("Loading model with vLLM...")
llm = LLM(
    model=MODEL_NAME,
    trust_remote_code=True,
    download_dir='/workspace/huggingface_cache',
    dtype='bfloat16',
    max_model_len=MAX_MODEL_LEN,
    tensor_parallel_size=TENSOR_PARALLEL_SIZE,
    enable_prefix_caching=False,
    gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
    disable_log_stats=True
)
print("Model loaded successfully!")
print()

# Get tokenizer and format prompt using built-in chat template
print("Formatting prompt with built-in chat template...")
tokenizer = llm.get_tokenizer()

messages = [
    {"role": "system", "content": SYSTEM_PROMPT},
    {"role": "user", "content": user_message}
]

# Let vLLM handle all the template formatting
PROMPT = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)

# Get actual token count
prompt_tokens = tokenizer.encode(PROMPT)
prompt_token_count = len(prompt_tokens)

print(f"Formatted prompt length: {len(PROMPT)} characters")
print(f"Actual token count: {prompt_token_count:,} tokens")
print(f"Max model length: {MAX_MODEL_LEN:,} tokens")
print(f"Template applied by vLLM tokenizer")

if prompt_token_count > MAX_MODEL_LEN:
    raise ValueError(
        f"Prompt is too long! {prompt_token_count:,} tokens exceeds "
        f"MAX_MODEL_LEN of {MAX_MODEL_LEN:,} tokens. "
        f"Either reduce your PDFs or increase MAX_MODEL_LEN."
    )

print()

# Sampling parameters
sampling_params = SamplingParams(
    temperature=TEMPERATURE,
    max_tokens=MAX_TOKENS,
    seed=SEED,
    logprobs=TOP_LOGPROBS,
    prompt_logprobs=None
)

# Warmup run
print("Running warmup...")
warmup_output = llm.generate([PROMPT], sampling_params=sampling_params)
print("Warmup complete")
print()

# ============================================================================
# TEST RUNS
# ============================================================================

print("=" * 80)
print(f"Running {NUM_REPETITIONS} test repetitions")
print("=" * 80)
print()

results_tokens = []
results_logprobs = []
results_texts = []
results_distributions = []

for rep in range(NUM_REPETITIONS):
    print(f"Repetition {rep + 1}/{NUM_REPETITIONS}...")
    
    outputs = llm.generate([PROMPT], sampling_params=sampling_params)
    output = outputs[0]
    
    # Extract token IDs
    token_ids = output.outputs[0].token_ids
    results_tokens.append(token_ids)
    
    # Extract generated text
    text = output.outputs[0].text
    results_texts.append(text)
    
    # Show preview of first generation
    if rep == 0:
        print(f"\n  First generation preview (first 400 chars):")
        preview = text[:400].replace('\n', ' ')
        print(f"  {preview}...")
        print()
    
    # Extract logprobs for selected tokens
    logprobs_data = output.outputs[0].logprobs
    selected_logprobs = [lp[token_ids[i]].logprob for i, lp in enumerate(logprobs_data)]
    results_logprobs.append(np.array(selected_logprobs))
    
    # Extract full top-k distributions
    rep_distributions = []
    for position_logprobs in logprobs_data:
        # Get top-k sorted by logprob (descending)
        sorted_items = sorted(position_logprobs.items(), 
                            key=lambda x: x[1].logprob, 
                            reverse=True)[:TOP_LOGPROBS]
        rep_distributions.append([(tok, lp.logprob) for tok, lp in sorted_items])
    results_distributions.append(rep_distributions)
    
    print(f"  Generated {len(token_ids)} tokens")

print()
print("All repetitions complete!")
print()

# ============================================================================
# ANALYSIS
# ============================================================================

print("=" * 80)
print("ANALYSIS")
print("=" * 80)
print()

# Check token sequence identity
print("Checking token sequences...")
tokens_identical = all(
    results_tokens[0] == results_tokens[i] 
    for i in range(1, NUM_REPETITIONS)
)
print(f"Token sequences identical: {tokens_identical}")

if not tokens_identical:
    print("\n⚠ Token sequences differ!")
    for i in range(1, NUM_REPETITIONS):
        if results_tokens[0] != results_tokens[i]:
            diff_positions = [
                j for j in range(min(len(results_tokens[0]), len(results_tokens[i])))
                if results_tokens[0][j] != results_tokens[i][j]
            ]
            print(f"  Rep 0 vs Rep {i}: {len(diff_positions)} positions differ")
            if diff_positions:
                print(f"    First difference at position {diff_positions[0]}")

# Check logprobs for selected tokens
print("\nChecking selected token logprobs...")
first_logprobs = results_logprobs[0]
logprobs_exact = all(
    np.allclose(first_logprobs, results_logprobs[i], rtol=0, atol=1e-10)
    for i in range(1, NUM_REPETITIONS)
)
print(f"Selected token logprobs bit-exact: {logprobs_exact}")

# Check top-k distributions
print("\nChecking full top-k distributions...")
distributions_exact = True
distribution_mismatches = []

first_dist = results_distributions[0]
for rep_idx in range(1, NUM_REPETITIONS):
    for pos_idx in range(len(first_dist)):
        dist_a = first_dist[pos_idx]
        dist_b = results_distributions[rep_idx][pos_idx]
        
        # Check if token IDs match in same order
        tokens_match = [t[0] for t in dist_a] == [t[0] for t in dist_b]
        
        # Check if logprobs are bit-exact
        if tokens_match:
            logprobs_match = all(
                abs(dist_a[i][1] - dist_b[i][1]) < 1e-10 
                for i in range(len(dist_a))
            )
            if not logprobs_match:
                distributions_exact = False
                distribution_mismatches.append((rep_idx, pos_idx))
        else:
            distributions_exact = False
            distribution_mismatches.append((rep_idx, pos_idx))

print(f"Top-k distributions bit-exact: {distributions_exact}")

if not distributions_exact:
    print(f"\n⚠ Found {len(distribution_mismatches)} position mismatches in distributions")
    if len(distribution_mismatches) <= 5:
        for rep_idx, pos_idx in distribution_mismatches:
            print(f"  Rep 0 vs Rep {rep_idx}, position {pos_idx}")
    else:
        print(f"  First 5: {distribution_mismatches[:5]}")

if not logprobs_exact:
    print("\nL2 distances:")
    l2_distances = []
    for i in range(1, NUM_REPETITIONS):
        l2 = np.linalg.norm(first_logprobs - results_logprobs[i])
        l2_distances.append(l2)
        print(f"  Rep 0 vs Rep {i}: L2 = {l2:.6e}")
    
    print(f"\nMax L2: {max(l2_distances):.6e}")
    print(f"Mean L2: {np.mean(l2_distances):.6e}")
    
    # Element-wise statistics
    all_logprobs = np.array(results_logprobs)
    std_per_token = all_logprobs.std(axis=0)
    print(f"\nPer-token std statistics:")
    print(f"  Mean: {std_per_token.mean():.6e}")
    print(f"  Max: {std_per_token.max():.6e}")
    print(f"  Median: {np.median(std_per_token):.6e}")

print()

# ============================================================================
# VERDICT
# ============================================================================

print("=" * 80)
print("VERDICT")
print("=" * 80)
print()

if tokens_identical and logprobs_exact and distributions_exact:
    print("✓ PERFECT REPRODUCIBILITY")
    print("  - Token sequences: bit-exact")
    print("  - Selected token logprobs: bit-exact")
    print("  - Full top-k distributions: bit-exact")
    print("  - Model is deterministic for this config")
elif tokens_identical and logprobs_exact and not distributions_exact:
    print("⚠ SELECTED TOKENS EXACT, DISTRIBUTIONS VARY")
    print("  - Token sequences: bit-exact")
    print("  - Selected token logprobs: bit-exact")
    print("  - Top-k distributions: numerical variation detected")
    print("  → May indicate computational instability in non-selected paths")
elif tokens_identical and not logprobs_exact:
    print("⚠ TOKENS IDENTICAL, LOGPROBS VARY")
    print("  - Token sequences: bit-exact")
    print("  - Logprobs: small numerical variation")
    max_l2 = max(l2_distances) if not logprobs_exact else 0.0
    if max_l2 < 1e-6:
        print(f"  - Variation very small (L2={max_l2:.2e})")
        print("  → Likely acceptable for forensics")
    else:
        print(f"  - Variation notable (L2={max_l2:.2e})")
        print("  → Investigate noise source")
else:
    print("✗ TOKEN SEQUENCES DIFFER")
    print("  - This should NOT happen with temperature=0")
    print("  → Something is wrong, investigate")

print()

# ============================================================================
# SAVE RESULTS
# ============================================================================

output_data = {
    "experiment": "vllm_thinking_test",
    "timestamp": datetime.now().isoformat(),
    "pdfs_used": pdf_files if pdf_files else None,
    "num_pdfs": len(pdf_files) if pdf_files else 0,
    "prompt_length_chars": len(PROMPT),
    "config": {
        "model": MODEL_NAME,
        "tensor_parallel": TENSOR_PARALLEL_SIZE,
        "max_model_len": MAX_MODEL_LEN,
        "max_tokens": MAX_TOKENS,
        "repetitions": NUM_REPETITIONS,
        "temperature": TEMPERATURE,
        "seed": SEED,
        "warmup_enabled": True,
        "prefix_caching_disabled": True,
        "top_logprobs": TOP_LOGPROBS,
        "system_prompt": SYSTEM_PROMPT,
        "template_method": "vllm_built_in"
    },
    "results": {
        "tokens_identical": tokens_identical,
        "logprobs_exact": logprobs_exact,
        "distributions_exact": distributions_exact,
        "perfect_reproducibility": tokens_identical and logprobs_exact and distributions_exact
    },
    "token_sequences": results_tokens,
    "logprobs_vectors": [lp.tolist() for lp in results_logprobs],
    "generated_texts": results_texts,
    "top_k_distributions": [
        [[(int(tok), float(prob)) for tok, prob in dist] for dist in rep_dists]
        for rep_dists in results_distributions
    ]
}

output_file = f"vllm_thinking_test_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
with open(output_file, "w") as f:
    json.dump(output_data, f, indent=2)

print(f"Results saved to: {output_file}")
print()
print("=" * 80)
print("TEST COMPLETE")
print("=" * 80)
print()

print("Usage:")
print("  1. For Qwen3-30B-A3B-Thinking: MODEL_NAME = 'Qwen/Qwen3-30B-A3B-Thinking-2507'")
print("  2. For Kimi K2 Thinking: MODEL_NAME = 'moonshotai/Kimi-K2-Thinking'")
print("  3. vLLM automatically handles chat templates for both")
print()
print("Note:")
print("  - PDFs are loaded with PyPDF2.extract_text() - text only")
print("  - Images, charts, diagrams in PDFs are IGNORED")
print("  - For vision inputs, use Qwen3-VL or Kimi-VL models instead")
print()

vLLM LOGPROBS EXTRACTION TEST

Found 2 PDF(s): ['Verification-for-International-AI-Governance.pdf', 'open_governance_problems_2025.pdf']

Loading PDFs (text only - images will be ignored)...
  Loading 172 pages from: Verification-for-International-AI-Governance.pdf
    Processed 10/172 pages
    Processed 20/172 pages
    Processed 30/172 pages
    Processed 40/172 pages
    Processed 50/172 pages
    Processed 60/172 pages
    Processed 70/172 pages
    Processed 80/172 pages
    Processed 90/172 pages
    Processed 100/172 pages
    Processed 110/172 pages
    Processed 120/172 pages
    Processed 130/172 pages
    Processed 140/172 pages
    Processed 150/172 pages
    Processed 160/172 pages
    Processed 170/172 pages
  Extracted 535619 characters (172 pages)
  Loading 99 pages from: open_governance_problems_2025.pdf
    Processed 10/99 pages
    Processed 20/99 pages
    Processed 30/99 pages
    Processed 40/99 pages
    Processed 50/99 pages
    Processed 60/99 pages
    Proces

The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.


config.json:   0%|          | 0.00/726 [00:00<?, ?B/s]

INFO 11-09 22:02:49 [model.py:547] Resolved architecture: Qwen3ForCausalLM


`torch_dtype` is deprecated! Use `dtype` instead!


ValidationError: 1 validation error for ModelConfig
  Value error, User-specified max_model_len (131072) is greater than the derived max_model_len (max_position_embeddings=40960 or model_max_length=None in model's config.json). To allow overriding this maximum, set the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN=1. VLLM_ALLOW_LONG_MAX_MODEL_LEN must be used with extreme caution. If the model uses relative position encoding (RoPE), positions exceeding derived_max_model_len lead to nan. If the model uses absolute position encoding, positions exceeding derived_max_model_len will cause a CUDA array out-of-bounds error. [type=value_error, input_value=ArgsKwargs((), {'model': ...rocessor_plugin': None}), input_type=ArgsKwargs]
    For further information visit https://errors.pydantic.dev/2.12/v/value_error