# In [3]:
#!/usr/bin/env python3
"""
vLLM Cross-Hardware Batch Size Detectability Experiment

Tests whether batch size claims can be verified across different GPU architectures
using logprob forensics with vLLM inference engine.

Unlike Transformers, vLLM doesn't expose key vectors, so we rely solely on logprobs.
vLLM may have finer-grained kernel selection than Transformers' 4 classes.

Workflow:

1. Run on Machine A (e.g., A100) with TEACHER_FORCING = False
   → Generates tokens, extracts logprobs, saves to JSON

2. Copy the JSON to Machine B (e.g., H100)

3. Run on Machine B with TEACHER_FORCING = True
   → Teacher-forces A's tokens via prompt_logprobs, compares

Usage:

Machine A: python vllm_batch_detect.py
Machine B: Set TEACHER_FORCING = True and REFERENCE_FILE, then run

"""

import os
os.environ['HF_HOME'] = '/workspace/huggingface_cache'
os.environ['TRANSFORMERS_CACHE'] = '/workspace/huggingface_cache'

from vllm import LLM, SamplingParams
import numpy as np
from datetime import datetime
import json
import socket
import platform
import sys
import glob
import PyPDF2

# ============================================================================
# CONFIGURATION
# ============================================================================

# Run mode switch: False = generation run, True = verification run against
# the reference JSON named below.
TEACHER_FORCING = False
# Reference JSON that main() loads when TEACHER_FORCING is True.
REFERENCE_FILE = "/workspace/experiments/vllm_reference.json"

# Model identifier and download/cache directory handed to vLLM's LLM().
MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"
CACHE_DIR = '/workspace/huggingface_cache'

# Batch sizes to compare; the largest one determines how many dummy
# sequences are cut for each reference sequence.
BATCH_SIZES = [1, 2, 3, 4, 5, 8, 9]
# Upper bound on generated tokens per request (SamplingParams.max_tokens).
MAX_NEW_TOKENS = 20
# Length, in tokens, of every text slice cut from the PDF token stream.
TOKENS_PER_SLICE = 200
# Number of reference sequences (each gets its own set of dummy neighbors).
NUM_REFERENCES = 3
TOP_K_LOGPROBS = 20  # Store top-20 to ensure overlap for comparison

# Threshold for considering two batch sizes "equivalent" (same kernel):
# distances strictly below this value count as identical.
EQUIVALENCE_THRESHOLD = 1e-9

# Will be initialized from PDF in main():
#   REFERENCE_SEQUENCES maps "ref_<i>" -> reference text slice
#   DUMMY_SETS          maps "ref_<i>" -> list of dummy text slices
REFERENCE_SEQUENCES = None
DUMMY_SETS = None

# ============================================================================
# LOGGING SETUP
# ============================================================================

# Open log file handle shared by the logging helpers; None while logging is off.
LOG_FILE = None

def setup_logging(output_dir='/workspace/experiments'):
    """Open a timestamped log file and make it the active log target.

    The file name encodes the run mode ("verify" when TEACHER_FORCING is
    true, otherwise "generate") and the current local time.

    Args:
        output_dir: Directory that receives the log file; created if missing.

    Returns:
        Path of the newly opened log file.
    """
    global LOG_FILE
    # Release a handle left over from an earlier call so it is not leaked.
    if LOG_FILE:
        LOG_FILE.close()
        LOG_FILE = None
    os.makedirs(output_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    mode = "verify" if TEACHER_FORCING else "generate"
    log_path = os.path.join(output_dir, f"vllm_experiment_{mode}_{timestamp}.txt")
    # Log lines contain non-ASCII symbols (✓, ⚠, →, ×); pin UTF-8 so writing
    # them does not depend on the platform's default encoding.
    LOG_FILE = open(log_path, 'w', encoding='utf-8')
    return log_path

def log_print(*args, **kwargs):
    """Behave like print(), and mirror the same output into the log file.

    The console call receives the arguments unchanged. When a log file is
    open, a second copy is written to it and flushed immediately.
    """
    print(*args, **kwargs)
    if not LOG_FILE:
        return
    # A caller-supplied ``file`` only redirects the console copy; the log
    # copy always goes to LOG_FILE.
    mirrored = dict(kwargs)
    mirrored.pop('file', None)
    print(*args, file=LOG_FILE, **mirrored)
    LOG_FILE.flush()

def close_logging():
    """Close the active log file, if any, and mark logging as inactive."""
    global LOG_FILE
    if not LOG_FILE:
        return
    LOG_FILE.close()
    LOG_FILE = None

# ============================================================================
# PDF LOADING
# ============================================================================

def load_pdf_text(pdf_path):
    """Return the text of every page of a PDF as one stripped string.

    Pages that yield no text are skipped; the remaining page texts are
    separated by a single space.
    """
    with open(pdf_path, 'rb') as handle:
        # Extraction must happen while the file is still open.
        page_texts = [page.extract_text() for page in PyPDF2.PdfReader(handle).pages]
    return "".join(chunk + " " for chunk in page_texts if chunk).strip()

def create_sequences_from_pdf(tokenizer, num_references=NUM_REFERENCES):
    """Build reference and dummy text slices from the available PDFs.

    PDFs are looked up in /workspace first and in the current directory as a
    fallback. Their text is tokenized and concatenated, and the token stream
    is cut into consecutive slices of TOKENS_PER_SLICE tokens that are
    decoded back to text. Each reference gets ``max(BATCH_SIZES)`` consecutive
    slices: the first is the reference sequence, the rest are its dummies.

    Args:
        tokenizer: Object providing ``encode(text)`` and ``decode(token_ids)``.
        num_references: Number of reference sequences to create.

    Returns:
        Tuple ``(reference_sequences, dummy_sets)``; both are dicts keyed by
        ``"ref_<i>"``, mapping to a text slice and to a list of text slices
        respectively.

    Raises:
        FileNotFoundError: If no PDF file is found in either location.
        ValueError: If the PDFs do not contain enough tokens.
    """
    # glob returns paths in arbitrary (filesystem-dependent) order. Sort them
    # so the token stream, and therefore every slice, is the same on every
    # machine that holds the same files.
    pdf_files = sorted(glob.glob("/workspace/*.pdf"))
    if not pdf_files:
        pdf_files = sorted(glob.glob("*.pdf"))
    if not pdf_files:
        raise FileNotFoundError("No PDF files found")

    log_print(f"Found {len(pdf_files)} PDF(s)")

    all_tokens = []
    for pdf_path in pdf_files:
        log_print(f"  Loading: {pdf_path}")
        tokens = tokenizer.encode(load_pdf_text(pdf_path))
        all_tokens.extend(tokens)
        log_print(f"    → {len(tokens)} tokens")

    log_print(f"Total tokens: {len(all_tokens)}")

    max_batch_size = max(BATCH_SIZES)
    slices_needed = num_references * max_batch_size
    tokens_needed = slices_needed * TOKENS_PER_SLICE

    if len(all_tokens) < tokens_needed:
        raise ValueError(f"Need {tokens_needed} tokens but only have {len(all_tokens)}")

    log_print(f"Creating {slices_needed} slices of {TOKENS_PER_SLICE} tokens each")

    slices = [
        tokenizer.decode(all_tokens[start:start + TOKENS_PER_SLICE])
        for start in range(0, tokens_needed, TOKENS_PER_SLICE)
    ]

    reference_sequences = {}
    dummy_sets = {}

    for ref_idx in range(num_references):
        ref_name = f"ref_{ref_idx}"
        base_idx = ref_idx * max_batch_size
        # First slice of the group is the reference; the remaining
        # max_batch_size - 1 slices are its batch neighbors.
        reference_sequences[ref_name] = slices[base_idx]
        dummy_sets[ref_name] = slices[base_idx + 1 : base_idx + max_batch_size]

    return reference_sequences, dummy_sets

# ============================================================================
# SYSTEM INFO
# ============================================================================

def collect_system_info():
    """Gather host, interpreter, library and GPU details into a dict.

    CUDA-related fields fall back to "N/A" (or 0 for the GPU count) when
    torch reports that CUDA is unavailable. The vLLM version is recorded as
    "unknown" if it cannot be determined.
    """
    import torch
    import transformers

    if torch.cuda.is_available():
        cuda_version = torch.version.cuda
        cudnn_version = str(torch.backends.cudnn.version())
        gpu_name = torch.cuda.get_device_name(0)
        gpu_count = torch.cuda.device_count()
    else:
        cuda_version = cudnn_version = gpu_name = "N/A"
        gpu_count = 0

    info = dict(
        hostname=socket.gethostname(),
        platform=platform.platform(),
        python_version=sys.version.split()[0],
        torch_version=torch.__version__,
        cuda_version=cuda_version,
        cudnn_version=cudnn_version,
        transformers_version=transformers.__version__,
        numpy_version=np.__version__,
        gpu_name=gpu_name,
        gpu_count=gpu_count,
    )

    try:
        import vllm
        vllm_version = vllm.__version__
    except (ImportError, AttributeError):
        vllm_version = "unknown"
    info["vllm_version"] = vllm_version

    return info

def validate_environment_match(reference_env, verifier_env):
    """Compare the reference and verifier environments and log the outcome.

    Software versions (vLLM, torch, CUDA) are expected to be equal; a
    difference is recorded as a mismatch. GPU name and hostname are expected
    to differ; equal values only produce a logged warning.

    Returns:
        Dict with ``'valid'`` (True when no software mismatch was found) and
        ``'mismatches'`` (list of ``(field, reference_value, verifier_value)``).
    """
    rule = "="*80
    log_print("\n" + rule)
    log_print("ENVIRONMENT VALIDATION")
    log_print(rule)

    def values_for(field):
        # Missing fields are reported as 'N/A' on either side.
        return reference_env.get(field, 'N/A'), verifier_env.get(field, 'N/A')

    mismatches = []

    log_print("\nCritical dependencies:")
    for field in ('vllm_version', 'torch_version', 'cuda_version'):
        ref_val, ver_val = values_for(field)
        if ref_val == ver_val:
            log_print(f"  ✓ {field}: {ref_val}")
            continue
        log_print(f"  ✗ {field}: reference={ref_val}, verifier={ver_val}")
        mismatches.append((field, ref_val, ver_val))

    log_print("\nExpected differences (hardware):")
    for field in ('gpu_name', 'hostname'):
        ref_val, ver_val = values_for(field)
        if ref_val != ver_val:
            log_print(f"  ✓ {field}: reference={ref_val}, verifier={ver_val}")
        else:
            log_print(f"  ⚠ {field}: SAME ({ref_val}) - are you on different hardware?")

    if mismatches:
        log_print("\n⚠ ENVIRONMENT MISMATCHES DETECTED")
        log_print("  Results may be affected by software differences, not just hardware.")
        return {'valid': False, 'mismatches': mismatches}

    log_print("\n✓ ENVIRONMENT VALIDATION PASSED")
    return {'valid': True, 'mismatches': []}

# ============================================================================
# LOGPROB EXTRACTION
# ============================================================================

def extract_logprobs_from_output(output, positions=(-3, -2, -1)):
    """Collect the stored logprobs of generated tokens at selected positions.

    Reads ``output.outputs[0].logprobs``, a sequence with one mapping per
    generated token (token id -> object exposing ``.logprob``).

    Args:
        output: Request output whose first completion carries ``logprobs``.
        positions: Indices into the generated tokens; negative values count
            from the end. Positions outside the generated range are skipped.

    Returns:
        Dict keyed by ``"pos_<position>"`` whose values have the shape
        ``{'logprobs': {'token_ids': [...], 'log_probs': [...]}}``. Empty when
        the output carries no logprobs.
    """
    # The default is a tuple rather than a list so the shared default object
    # can never be mutated between calls.
    signals = {}

    step_logprobs = output.outputs[0].logprobs
    if step_logprobs is None:
        return signals

    num_generated = len(step_logprobs)

    for pos in positions:
        actual_idx = pos if pos >= 0 else num_generated + pos
        if not 0 <= actual_idx < num_generated:
            continue

        candidates = step_logprobs[actual_idx]
        signals[f"pos_{pos}"] = {
            'logprobs': {
                'token_ids': list(candidates),
                'log_probs': [entry.logprob for entry in candidates.values()],
            }
        }

    return signals

def extract_prompt_logprobs(output, prompt_length, positions=(-3, -2, -1)):
    """Collect the stored logprobs of prompt tokens at selected positions.

    Reads ``output.prompt_logprobs``, a sequence with one entry per prompt
    token; an entry is either None or a mapping of token id -> object
    exposing ``.logprob``.

    Args:
        output: Request output carrying ``prompt_logprobs``.
        prompt_length: Length used to resolve negative positions.
        positions: Indices into the prompt; negative values are resolved as
            ``prompt_length + position``. Positions that fall outside the
            available entries, or whose entry is None, are skipped.

    Returns:
        Dict keyed by ``"pos_<position>"`` whose values have the shape
        ``{'logprobs': {'token_ids': [...], 'log_probs': [...]}}``. Empty when
        the output carries no prompt logprobs.
    """
    # The default is a tuple rather than a list so the shared default object
    # can never be mutated between calls.
    signals = {}

    prompt_entries = output.prompt_logprobs
    if prompt_entries is None:
        return signals

    for pos in positions:
        actual_idx = pos if pos >= 0 else prompt_length + pos
        if not 0 <= actual_idx < len(prompt_entries):
            continue

        candidates = prompt_entries[actual_idx]
        if candidates is None:
            continue

        signals[f"pos_{pos}"] = {
            'logprobs': {
                'token_ids': list(candidates),
                'log_probs': [entry.logprob for entry in candidates.values()],
            }
        }

    return signals

# ============================================================================
# GENERATION MODE
# ============================================================================

def compute_min_length_across_batches(ref_text, ref_name, tokenizer, batch_sizes):
    """Return the shortest tokenized length seen in any batch configuration.

    For every batch size the batch consists of the reference text followed by
    the first ``batch_size - 1`` dummies of ``ref_name``. The result is
    ``float('inf')`` when ``batch_sizes`` is empty.
    """
    dummies = DUMMY_SETS[ref_name]
    shortest = float('inf')

    for size in batch_sizes:
        members = [ref_text] if size == 1 else [ref_text] + dummies[:size-1]
        lengths = [len(tokenizer.encode(text)) for text in members]
        shortest = min(shortest, min(lengths))

    return shortest

def run_generation(llm, tokenizer, ref_text, ref_name, batch_size, forced_length=None):
    """Generate for one batch configuration and capture element 0's signals.

    The batch is the reference text plus the first ``batch_size - 1`` dummies
    of ``ref_name``. Every member is tokenized, cut to a common length
    (``forced_length`` if given, otherwise the shortest member), decoded back
    to text and submitted to ``llm.generate`` with greedy sampling.

    Returns:
        Dict with element 0's generated ids, the generated ids of every batch
        member, the truncated prompt token ids, the prompt length, the
        prefill and decode signals at positions -3/-2/-1, and the number of
        generated tokens.
    """
    dummies = DUMMY_SETS[ref_name]
    texts = [ref_text] if batch_size == 1 else [ref_text] + dummies[:batch_size-1]

    encoded = [tokenizer.encode(text) for text in texts]
    if forced_length is None:
        common_length = min(len(ids) for ids in encoded)
    else:
        common_length = forced_length

    # Equal-length prompts: truncate in token space, then hand text to vLLM.
    prompt_ids = [ids[:common_length] for ids in encoded]
    prompts = [tokenizer.decode(ids) for ids in prompt_ids]

    prompt_length = len(prompt_ids[0])
    log_print(f"      Prompt: {prompt_length} tokens", end="")

    params = SamplingParams(
        max_tokens=MAX_NEW_TOKENS,
        temperature=0.0,
        logprobs=TOP_K_LOGPROBS,
        prompt_logprobs=TOP_K_LOGPROBS,
    )
    outputs = llm.generate(prompts, params)

    # Only the first batch element (the reference) is analyzed.
    first = outputs[0]
    generated_ids = list(first.outputs[0].token_ids)
    num_generated = len(generated_ids)

    prefill_signals = extract_prompt_logprobs(first, prompt_length, positions=[-3, -2, -1])
    decode_signals = extract_logprobs_from_output(first, positions=[-3, -2, -1])

    per_member_ids = [list(item.outputs[0].token_ids) for item in outputs]

    log_print(f" → Final: {prompt_length + num_generated} tokens ({num_generated} generated)")

    result = {}
    result['generated_ids'] = generated_ids
    result['all_batch_generated_ids'] = per_member_ids
    result['prompt_token_ids'] = prompt_ids
    result['prompt_length'] = prompt_length
    result['prefill_signals'] = prefill_signals
    result['decode_signals'] = decode_signals
    result['num_generated'] = num_generated
    return result

# ============================================================================
# VERIFICATION MODE (TEACHER FORCING)
# ============================================================================

def run_teacher_forced_verification(llm, tokenizer, ref_name, reference_data,
                                     verify_batch_size, is_diagonal):
    """Re-run a reference prompt on this machine and capture its signals.

    Only the prompts are submitted: the batch is regenerated greedily for as
    many tokens as the reference produced, and a warning is logged when
    element 0's generated tokens differ from the reference tokens (the
    reference tokens themselves are not fed back to the model).

    Args:
        llm: vLLM engine used for generation.
        tokenizer: Object providing ``encode`` and ``decode``.
        ref_name: Key into DUMMY_SETS for the non-diagonal case.
        reference_data: Reference record with ``'prompt_token_ids'`` (one id
            list per batch member) and ``'generated_ids'``.
        verify_batch_size: Batch size to use when ``is_diagonal`` is false.
        is_diagonal: When true, rebuild the exact reference batch from its
            stored prompt ids; otherwise use the reference prompt plus
            ``verify_batch_size - 1`` dummies truncated to the prompt length.

    Returns:
        Dict with ``'prefill_signals'``, ``'decode_signals'`` (positions
        -3/-2/-1 where available) and ``'num_generated'`` (reference count).
    """
    ref_prompt_ids = reference_data['prompt_token_ids'][0]
    ref_generated_ids = reference_data['generated_ids']
    ref_batch_size = len(reference_data['prompt_token_ids'])

    prompt_length = len(ref_prompt_ids)
    num_generated = len(ref_generated_ids)

    log_print(f"      Prompt: {prompt_length}, Gen: {num_generated}", end="")

    if is_diagonal:
        log_print(f", exact neighbors (bs={ref_batch_size})", end="")
        batch_texts = [
            tokenizer.decode(member_ids)
            for member_ids in reference_data['prompt_token_ids']
        ]
    else:
        log_print(f", arb neighbors (bs={verify_batch_size})", end="")
        batch_texts = [tokenizer.decode(ref_prompt_ids)]

        ref_dummies = DUMMY_SETS[ref_name]
        # Indexing (not slicing) keeps a too-large batch size a loud error.
        for i in range(verify_batch_size - 1):
            dummy_ids = tokenizer.encode(ref_dummies[i])[:prompt_length]
            batch_texts.append(tokenizer.decode(dummy_ids))

    sampling_params = SamplingParams(
        max_tokens=num_generated,  # Generate same number of tokens
        temperature=0.0,
        prompt_logprobs=TOP_K_LOGPROBS,  # Prefill signals
        logprobs=TOP_K_LOGPROBS,         # Decode signals
    )

    outputs = llm.generate(batch_texts, sampling_params)
    output_0 = outputs[0]

    prefill_signals = extract_prompt_logprobs(output_0, prompt_length, positions=[-3, -2, -1])

    # Compare the tokens actually sampled here with the reference tokens.
    generated_ids = list(output_0.outputs[0].token_ids)
    if generated_ids != list(ref_generated_ids):
        log_print(f" WARN: token mismatch!", end="")

    # Same extraction as in generation mode, so both sides skip unavailable
    # positions identically and a missing logprobs list yields no signals
    # instead of raising.
    decode_signals = extract_logprobs_from_output(output_0, positions=[-3, -2, -1])

    log_print(f" → decoded")

    return {
        'prefill_signals': prefill_signals,
        'decode_signals': decode_signals,
        'num_generated': num_generated
    }

# ============================================================================
# ANALYSIS
# ============================================================================

def compute_logprob_distance_canonical(logprobs1, logprobs2, canonical_ids):
    """Return the Euclidean distance between two logprob sets.

    Only token ids listed in ``canonical_ids`` that are present in both sets
    contribute. If there is no such id, the distance is ``float('inf')``.
    """
    first = dict(zip(logprobs1['token_ids'], logprobs1['log_probs']))
    second = dict(zip(logprobs2['token_ids'], logprobs2['log_probs']))

    shared = [tid for tid in canonical_ids if tid in first and tid in second]
    if not shared:
        return float('inf')

    left = np.array([first[tid] for tid in shared])
    right = np.array([second[tid] for tid in shared])
    return float(np.linalg.norm(left - right))

def compare_signals(signals1, signals2):
    """Compare two signal sets position by position.

    For every position label present in both sets, the first five token ids
    stored in ``signals1`` are used as the canonical ids for the distance.

    Returns:
        Dict with ``'logprobs_max'`` and ``'logprobs_mean'`` over the common
        positions, both as Python floats; both are 0.0 when no position is
        shared.
    """
    all_dists = []

    # Sorted iteration: the order of a set of str labels depends on hash
    # randomization, and float summation order can change the mean's last bits.
    for pos_label in sorted(set(signals1) & set(signals2)):
        lp1 = signals1[pos_label]['logprobs']
        lp2 = signals2[pos_label]['logprobs']

        # Use top 5 for comparison (stored top 20 as buffer)
        canonical_ids = lp1['token_ids'][:5]
        all_dists.append(compute_logprob_distance_canonical(lp1, lp2, canonical_ids))

    if not all_dists:
        return {'logprobs_max': 0.0, 'logprobs_mean': 0.0}

    return {
        'logprobs_max': max(all_dists),
        # Plain float so both branches return the same type.
        'logprobs_mean': float(np.mean(all_dists)),
    }

def compare_signals_with_canonical_ids(signals1, signals2, canonical_token_ids):
    """Compare two signal sets using externally supplied canonical token ids.

    Args:
        signals1: First signal set, keyed by position label.
        signals2: Second signal set, keyed by position label.
        canonical_token_ids: Mapping of position label -> token id list. Only
            the first five ids are used; when a label is missing, the first
            five ids stored in ``signals1`` for that position are used.

    Returns:
        Dict with ``'logprobs_max'`` and ``'logprobs_mean'`` over the common
        positions, both as Python floats; both are 0.0 when no position is
        shared.
    """
    all_dists = []

    # Sorted iteration: the order of a set of str labels depends on hash
    # randomization, and float summation order can change the mean's last bits.
    for pos_label in sorted(set(signals1) & set(signals2)):
        lp1 = signals1[pos_label]['logprobs']
        lp2 = signals2[pos_label]['logprobs']

        # Use top 5 for comparison
        canonical_ids = canonical_token_ids.get(pos_label, lp1['token_ids'][:5])[:5]
        all_dists.append(compute_logprob_distance_canonical(lp1, lp2, canonical_ids))

    if not all_dists:
        return {'logprobs_max': 0.0, 'logprobs_mean': 0.0}

    return {
        'logprobs_max': max(all_dists),
        # Plain float so both branches return the same type.
        'logprobs_mean': float(np.mean(all_dists)),
    }

def check_token_consistency(measurements, tokenizer):
    """Log element 0's generated tokens per batch size and report agreement.

    ``measurements`` maps batch size -> record containing ``'generated_ids'``.
    The tokens of the smallest batch size serve as the baseline.

    Returns:
        True when every batch size produced the baseline tokens, else False.
    """
    rule = "="*80
    log_print("\n" + rule)
    log_print("TOKEN GENERATION CONSISTENCY CHECK")
    log_print(rule)

    tokens_by_bs = {bs: data['generated_ids'] for bs, data in measurements.items()}
    ordered_bs = sorted(tokens_by_bs.keys())
    baseline = tokens_by_bs[ordered_bs[0]]

    all_same = True
    log_print("\nGenerated tokens by batch size:")
    for bs in ordered_bs:
        tokens = tokens_by_bs[bs]
        if tokens == baseline:
            verdict = "✓"
        else:
            verdict = "✗ DIFFERENT"
        decoded_text = tokenizer.decode(tokens)
        log_print(f"  bs={bs}:")
        log_print(f"    IDs:  {tokens}")
        log_print(f"    Text: {repr(decoded_text)}")
        log_print(f"    {verdict}")
        if tokens != baseline:
            all_same = False

    summary = (
        "\n✓ Element 0 generates IDENTICAL tokens across all batch sizes"
        if all_same
        else "\n⚠ Element 0 generates DIFFERENT tokens across batch sizes"
    )
    log_print(summary)

    return all_same

def find_equivalent_pairs(matrix, batch_sizes, threshold=EQUIVALENCE_THRESHOLD):
    """List the batch-size pairs whose distance lies below ``threshold``.

    Only the upper triangle of ``matrix`` is inspected, so each pair appears
    once, ordered as the two sizes appear in ``batch_sizes``.
    """
    count = len(batch_sizes)
    return [
        (first, batch_sizes[col])
        for row, first in enumerate(batch_sizes)
        for col in range(row + 1, count)
        if matrix[row, col] < threshold
    ]

def format_kernel_classes(equivalent_pairs, batch_sizes):
    """Partition batch sizes into classes connected by equivalent pairs.

    Returns a list of sets. Classes are listed in the order in which their
    first member appears in ``batch_sizes``.
    """
    # Disjoint-set forest: every batch size starts as its own root.
    parent = {bs: bs for bs in batch_sizes}

    def root_of(node):
        root = node
        while parent[root] != root:
            root = parent[root]
        # Point every node on the walked path directly at the root.
        while parent[node] != root:
            parent[node], node = root, parent[node]
        return root

    for left, right in equivalent_pairs:
        left_root = root_of(left)
        right_root = root_of(right)
        if left_root != right_root:
            parent[left_root] = right_root

    classes = {}
    for bs in batch_sizes:
        classes.setdefault(root_of(bs), set()).add(bs)

    return list(classes.values())

def analyze_within_hardware(measurements, batch_sizes, signal_source='decode'):
    """Build and log batch-size distance matrices measured on one machine.

    For every reference, cell (i, j) holds the mean logprob distance between
    the signals recorded at ``batch_sizes[i]`` and ``batch_sizes[j]``, using
    the token ids stored for the first batch size as canonical ids. Cells for
    which a measurement is missing stay at 0. The per-reference matrices are
    averaged, summarized and grouped into equivalence classes.

    Args:
        measurements: Iterable of records with ``'ref_name'``,
            ``'batch_size'`` and the prefill/decode signal dicts.
        batch_sizes: Batch sizes defining the matrix axes.
        signal_source: ``'prefill'`` selects prefill signals; any other value
            selects decode signals.

    Returns:
        Dict with the averaged matrix, the per-reference matrices,
        off-diagonal mean and range, the equivalent pairs and the kernel
        classes.
    """
    rule = "="*80
    n_bs = len(batch_sizes)
    signals_key = 'prefill_signals' if signal_source == 'prefill' else 'decode_signals'

    def show_matrix(values):
        # Header row, then one formatted row per batch size.
        log_print("       " + "".join([f"bs={bs:>3} " for bs in batch_sizes]))
        for row, bs in enumerate(batch_sizes):
            cells = "".join(f"  {values[row,col]:6.2e}" for col in range(n_bs))
            log_print(f"bs={bs:<3}" + cells)

    log_print("\n" + rule)
    log_print(f"WITHIN-HARDWARE BATCH SIZE EFFECTS ({signal_source.upper()})")
    log_print(rule)

    # Group the selected signals as reference name -> batch size -> signals.
    by_ref = {}
    for record in measurements:
        per_bs = by_ref.setdefault(record['ref_name'], {})
        per_bs[record['batch_size']] = record[signals_key]

    all_matrices = []

    for ref_name in sorted(by_ref.keys()):
        log_print(f"\n{ref_name}:")

        ref_data = by_ref[ref_name]
        matrix = np.zeros((n_bs, n_bs))

        # Use first batch size as canonical
        canonical_signals = ref_data.get(batch_sizes[0])
        canonical_token_ids = {}
        if canonical_signals:
            canonical_token_ids = {
                label: entry['logprobs']['token_ids']
                for label, entry in canonical_signals.items()
            }

        for i, bs1 in enumerate(batch_sizes):
            if bs1 not in ref_data:
                continue
            for j, bs2 in enumerate(batch_sizes):
                # The diagonal stays at its initial 0.0.
                if i == j or bs2 not in ref_data:
                    continue
                distances = compare_signals_with_canonical_ids(
                    ref_data[bs1], ref_data[bs2], canonical_token_ids
                )
                matrix[i, j] = distances['logprobs_mean']

        show_matrix(matrix)
        all_matrices.append(matrix)

    # Aggregate
    avg_matrix = np.mean(all_matrices, axis=0)

    log_print("\n" + rule)
    log_print("AGGREGATE (average across references):")
    log_print(rule)
    show_matrix(avg_matrix)

    # Statistics over the strict upper triangle
    off_diag = avg_matrix[np.triu_indices(n_bs, k=1)]

    log_print(f"\nOff-diagonal stats:")
    log_print(f"  Mean: {np.mean(off_diag):.2e}")
    log_print(f"  Range: [{np.min(off_diag):.2e}, {np.max(off_diag):.2e}]")

    equivalent_pairs = find_equivalent_pairs(avg_matrix, batch_sizes)
    kernel_classes = format_kernel_classes(equivalent_pairs, batch_sizes)

    zero_count = np.sum(off_diag < EQUIVALENCE_THRESHOLD)
    total_count = len(off_diag)

    if zero_count == total_count:
        log_print(f"\n⚠ WARNING: {zero_count}/{total_count} comparisons are EXACTLY ZERO")
        log_print("  All batch sizes produce identical results (single kernel class)")
    elif zero_count > 0:
        log_print(f"\n⚠ NOTE: {zero_count}/{total_count} comparisons are equivalent (< {EQUIVALENCE_THRESHOLD})")

    log_print(f"\nKernel equivalence classes:")
    for number, members in enumerate(kernel_classes, start=1):
        log_print(f"  Class {number}: {sorted(members)}")

    if equivalent_pairs:
        log_print(f"\nEquivalent pairs (will be excluded from cross-hardware signal):")
        for bs1, bs2 in equivalent_pairs:
            log_print(f"  ({bs1}, {bs2})")

    return {
        'matrix': avg_matrix.tolist(),
        'per_reference_matrices': [m.tolist() for m in all_matrices],
        'off_diagonal_mean': float(np.mean(off_diag)),
        'off_diagonal_range': [float(np.min(off_diag)), float(np.max(off_diag))],
        'equivalent_pairs': equivalent_pairs,
        'kernel_classes': [sorted(list(members)) for members in kernel_classes]
    }

def analyze_cross_hardware(comparison_results, batch_sizes, signal_source='decode', 
                           equivalent_pairs=None):
    """Build and log claimed-vs-verified distance matrices and their SNR.

    Rows are claimed batch sizes, columns are verifier batch sizes. The
    diagonal mean is the baseline; the off-diagonal mean is the signal. The
    SNR is reported twice: over all off-diagonal cells, and with the cells of
    ``equivalent_pairs`` (both orderings) left out.

    Args:
        comparison_results: Iterable of records with ``'ref_name'``,
            ``'claimed_batch_size'``, ``'verify_batch_size'`` and the
            prefill/decode distance dicts.
        batch_sizes: Batch sizes defining the matrix axes.
        signal_source: ``'prefill'`` selects prefill distances; any other
            value selects decode distances.
        equivalent_pairs: Batch-size pairs to exclude from the meaningful
            signal; None means no exclusions.

    Returns:
        Dict with the averaged matrix, per-reference matrices, baseline and
        signal means, both SNR values and the exclusion bookkeeping.
    """
    rule = "="*80
    n_bs = len(batch_sizes)

    def show_matrix(values):
        # Header row, then one formatted row per claimed batch size.
        log_print("              " + "".join([f"v_bs={bs:>3} " for bs in batch_sizes]))
        for row, claimed in enumerate(batch_sizes):
            cells = "".join(f"  {values[row,col]:8.2e}" for col in range(n_bs))
            log_print(f"c_bs={claimed:<3} " + cells)

    log_print("\n" + rule)
    log_print(f"CROSS-HARDWARE BATCH SIZE DETECTABILITY ({signal_source.upper()})")
    log_print(rule)

    if equivalent_pairs is None:
        equivalent_pairs = []

    # Both orderings of every pair, for direct cell lookup
    equiv_set = set()
    for bs1, bs2 in equivalent_pairs:
        equiv_set.update(((bs1, bs2), (bs2, bs1)))

    dist_key = 'prefill_distances' if signal_source == 'prefill' else 'decode_distances'

    # Group results as reference name -> (claimed, verify) -> record.
    by_ref = {}
    for record in comparison_results:
        cell = (record['claimed_batch_size'], record['verify_batch_size'])
        by_ref.setdefault(record['ref_name'], {})[cell] = record

    all_matrices = []

    # Per-reference matrices
    for ref_name in sorted(by_ref.keys()):
        log_print(f"\n{ref_name}:")

        ref_data = by_ref[ref_name]
        matrix = np.zeros((n_bs, n_bs))

        for i, claimed_bs in enumerate(batch_sizes):
            for j, verify_bs in enumerate(batch_sizes):
                record = ref_data.get((claimed_bs, verify_bs))
                if record is not None:
                    matrix[i, j] = record[dist_key]['logprobs_mean']

        show_matrix(matrix)
        all_matrices.append(matrix)

    # Aggregate matrix
    avg_matrix = np.mean(all_matrices, axis=0)

    log_print("\n" + rule)
    log_print("AGGREGATE (average across references):")
    log_print(rule)
    show_matrix(avg_matrix)

    # Compute statistics
    diagonal = [avg_matrix[k, k] for k in range(n_bs)]

    off_diagonal_all = []
    off_diagonal_meaningful = []
    excluded_pairs = []

    for i, bs1 in enumerate(batch_sizes):
        for j, bs2 in enumerate(batch_sizes):
            if i == j:
                continue
            value = avg_matrix[i, j]
            off_diagonal_all.append(value)
            # Cells of equivalent pairs carry no batch-size signal.
            if (bs1, bs2) in equiv_set:
                excluded_pairs.append((bs1, bs2))
            else:
                off_diagonal_meaningful.append(value)

    baseline_mean = np.mean(diagonal)
    signal_all_mean = np.mean(off_diagonal_all) if off_diagonal_all else 0.0
    signal_meaningful_mean = np.mean(off_diagonal_meaningful) if off_diagonal_meaningful else 0.0

    if baseline_mean > 0:
        snr_all = signal_all_mean / baseline_mean
        snr_meaningful = signal_meaningful_mean / baseline_mean
    else:
        snr_all = float('inf')
        snr_meaningful = float('inf')

    log_print("\n" + rule)
    log_print("SNR ANALYSIS")
    log_print(rule)

    log_print(f"\nDiagonal (baseline = cross-hardware, same batch size):")
    log_print(f"  Mean: {baseline_mean:.2e}")

    log_print(f"\nOff-diagonal (all pairs):")
    log_print(f"  Count: {len(off_diagonal_all)}")
    log_print(f"  Mean: {signal_all_mean:.2e}")
    log_print(f"  SNR (all): {snr_all:.2f}×")

    if equivalent_pairs:
        log_print(f"\nExcluded equivalent pairs (same kernel within-hardware):")
        for bs1, bs2 in equivalent_pairs:
            log_print(f"  ({bs1}, {bs2}) and ({bs2}, {bs1})")
        log_print(f"  Total excluded: {len(excluded_pairs)} cells")

    log_print(f"\nOff-diagonal (meaningful pairs only):")
    log_print(f"  Count: {len(off_diagonal_meaningful)}")
    if off_diagonal_meaningful:
        log_print(f"  Mean: {signal_meaningful_mean:.2e}")
        log_print(f"  SNR (meaningful): {snr_meaningful:.2f}×")
    else:
        log_print("  No meaningful pairs (all batch sizes are equivalent)")

    return {
        'matrix': avg_matrix.tolist(),
        'per_reference_matrices': [m.tolist() for m in all_matrices],
        'baseline_mean': float(baseline_mean),
        'signal_all_mean': float(signal_all_mean),
        'signal_meaningful_mean': float(signal_meaningful_mean),
        'snr_all': float(snr_all),
        'snr_meaningful': float(snr_meaningful),
        'excluded_pairs': equivalent_pairs,
        'n_excluded_cells': len(excluded_pairs),
        'n_meaningful_pairs': len(off_diagonal_meaningful)
    }

# ============================================================================
# MAIN
# ============================================================================

def _save_json(payload, filepath):
    """Write *payload* to *filepath* as indented JSON."""
    with open(filepath, 'w', encoding='utf-8') as f:
        json.dump(payload, f, indent=2)


def _load_reference(system_info):
    """Load REFERENCE_FILE and check it against this run's configuration.

    Called before the model is loaded so that a missing or incompatible
    reference file aborts the run without paying for model initialization.

    Args:
        system_info: Environment dict of the verifier machine, as returned
            by collect_system_info().

    Returns:
        Tuple ``(reference, env_validation)``: the parsed reference JSON and
        the result of validate_environment_match().

    Raises:
        SystemExit: If the reference batch sizes differ from BATCH_SIZES, or
            the reference records a model other than MODEL_NAME.
    """
    log_print("Loading reference file...")
    with open(REFERENCE_FILE, 'r', encoding='utf-8') as f:
        reference = json.load(f)

    metadata = reference['metadata']
    ref_env = metadata['environment']
    log_print(f"Reference GPU: {ref_env['gpu_name']}")
    log_print(f"Verifier GPU:  {system_info['gpu_name']}")

    env_validation = validate_environment_match(ref_env, system_info)

    ref_batch_sizes = metadata['batch_sizes']
    if ref_batch_sizes != BATCH_SIZES:
        log_print("\n✗ BATCH SIZE MISMATCH")
        log_print(f"  Reference: {ref_batch_sizes}")
        log_print(f"  Verifier:  {BATCH_SIZES}")
        sys.exit(1)

    # Logprobs from different models are not comparable. Reference files that
    # carry no 'model' entry are accepted as-is.
    ref_model = metadata.get('model')
    if ref_model is not None and ref_model != MODEL_NAME:
        log_print("\n✗ MODEL MISMATCH")
        log_print(f"  Reference: {ref_model}")
        log_print(f"  Verifier:  {MODEL_NAME}")
        sys.exit(1)

    log_print("\n✓ Configuration matches\n")
    return reference, env_validation


def _run_verification(llm, tokenizer, system_info, reference, env_validation,
                      output_dir, timestamp):
    """Teacher-force the reference tokens and compare signals (verification mode).

    For every reference sequence and every claimed batch size present in the
    reference file, runs run_teacher_forced_verification() at each batch size
    in BATCH_SIZES, compares prefill/decode signals against the reference via
    compare_signals(), then runs analyze_cross_hardware() on the collected
    comparisons and writes everything to ``vllm_verify_<timestamp>.json``.

    Args:
        llm: Loaded vLLM engine.
        tokenizer: Tokenizer obtained from ``llm``.
        system_info: Environment dict of the verifier machine.
        reference: Parsed reference JSON (see _load_reference()).
        env_validation: Result of validate_environment_match().
        output_dir: Directory the results file is written to.
        timestamp: Timestamp string used in the results filename.
    """
    ref_env = reference['metadata']['environment']
    ref_gpu = ref_env['gpu_name']

    # Equivalent pairs come from the reference machine's within-hardware
    # analysis; JSON stores them as lists, so convert back to tuples.
    prefill_equiv_pairs = [
        tuple(p)
        for p in reference.get('prefill_sanity_check', {}).get('equivalent_pairs', [])
    ]
    decode_equiv_pairs = [
        tuple(p)
        for p in reference.get('decode_sanity_check', {}).get('equivalent_pairs', [])
    ]

    log_print("Loaded equivalent pairs from reference:")
    log_print(f"  Prefill: {prefill_equiv_pairs}")
    log_print(f"  Decode: {decode_equiv_pairs}")

    # Index reference measurements by (reference name, batch size).
    ref_by_key = {
        (m['ref_name'], m['batch_size']): m
        for m in reference['measurements']
    }

    comparison_results = []

    for ref_name in sorted(REFERENCE_SEQUENCES.keys()):
        log_print(f"\n{'='*80}")
        log_print(f"REFERENCE: {ref_name}")
        log_print("="*80)

        for claimed_bs in BATCH_SIZES:
            ref_key = (ref_name, claimed_bs)
            if ref_key not in ref_by_key:
                log_print(f"  ⚠ No reference data for {ref_name} bs={claimed_bs}")
                continue

            ref_data = ref_by_key[ref_key]
            log_print(f"\n  Claimed batch size: {claimed_bs}")

            for verify_bs in BATCH_SIZES:
                # Diagonal cell: verified at the same batch size as claimed.
                is_diagonal = (claimed_bs == verify_bs)

                log_print(f"    Verify bs={verify_bs} ({'diag' if is_diagonal else 'off'}):", end="")

                verify_result = run_teacher_forced_verification(
                    llm, tokenizer, ref_name, ref_data,
                    verify_bs, is_diagonal
                )

                prefill_distances = compare_signals(
                    ref_data['prefill_signals'],
                    verify_result['prefill_signals']
                )

                decode_distances = compare_signals(
                    ref_data['decode_signals'],
                    verify_result['decode_signals']
                )

                log_print(f"      Prefill: {prefill_distances['logprobs_mean']:.2e}, Decode: {decode_distances['logprobs_mean']:.2e}")

                comparison_results.append({
                    'ref_name': ref_name,
                    'claimed_batch_size': claimed_bs,
                    'verify_batch_size': verify_bs,
                    'is_diagonal': is_diagonal,
                    'prefill_distances': prefill_distances,
                    'decode_distances': decode_distances,
                    'verify_prefill_signals': verify_result['prefill_signals'],
                    'verify_decode_signals': verify_result['decode_signals']
                })

    # Analyze with equivalent pair exclusion
    prefill_analysis = analyze_cross_hardware(
        comparison_results, BATCH_SIZES, signal_source='prefill',
        equivalent_pairs=prefill_equiv_pairs
    )
    decode_analysis = analyze_cross_hardware(
        comparison_results, BATCH_SIZES, signal_source='decode',
        equivalent_pairs=decode_equiv_pairs
    )

    log_print("\n" + "="*80)
    log_print("PREFILL vs DECODE COMPARISON")
    log_print("="*80)
    log_print(f"Prefill SNR (meaningful): {prefill_analysis['snr_meaningful']:.2f}×")
    log_print(f"Decode SNR (meaningful):  {decode_analysis['snr_meaningful']:.2f}×")

    results = {
        'metadata': {
            'reference_gpu': ref_gpu,
            'verifier_gpu': system_info['gpu_name'],
            'reference_file': REFERENCE_FILE,
            'reference_environment': ref_env,
            'verifier_environment': system_info,
            'environment_validation': env_validation,
            'model': MODEL_NAME,
            'batch_sizes': BATCH_SIZES,
            'timestamp': timestamp,
            'prefill_equivalent_pairs': prefill_equiv_pairs,
            'decode_equivalent_pairs': decode_equiv_pairs
        },
        'comparisons': comparison_results,
        'prefill_analysis': prefill_analysis,
        'decode_analysis': decode_analysis
    }

    filepath = os.path.join(output_dir, f"vllm_verify_{timestamp}.json")
    _save_json(results, filepath)

    log_print(f"\n✓ Results saved to: {filepath}")


def _run_generation(llm, tokenizer, system_info, output_dir, timestamp):
    """Generate tokens at every batch size and save signals (generation mode).

    For each reference sequence, calls run_generation() once per batch size
    in BATCH_SIZES, stores the returned token IDs and prefill/decode signals,
    runs the token consistency check and analyze_within_hardware(), and
    writes the result to ``vllm_decode_<timestamp>.json``.

    Args:
        llm: Loaded vLLM engine.
        tokenizer: Tokenizer obtained from ``llm``.
        system_info: Environment dict of this machine.
        output_dir: Directory the results file is written to.
        timestamp: Timestamp string used in the results filename.
    """
    results = {
        'metadata': {
            'environment': system_info,
            'model': MODEL_NAME,
            'batch_sizes': BATCH_SIZES,
            'max_new_tokens': MAX_NEW_TOKENS,
            'top_k_logprobs': TOP_K_LOGPROBS,
            'timestamp': timestamp
        },
        'measurements': []
    }

    for ref_name, ref_text in REFERENCE_SEQUENCES.items():
        log_print(f"\n{'='*80}")
        log_print(f"REFERENCE: {ref_name}")
        log_print("="*80)

        min_prompt_length = compute_min_length_across_batches(
            ref_text, ref_name, tokenizer, BATCH_SIZES
        )
        log_print(f"\nGlobal minimum prompt length: {min_prompt_length} tokens\n")

        for batch_size in BATCH_SIZES:
            log_print(f"  bs={batch_size}:", end=" ")

            gen_data = run_generation(
                llm, tokenizer, ref_text, ref_name, batch_size,
                forced_length=min_prompt_length
            )

            results['measurements'].append({
                'ref_name': ref_name,
                'batch_size': batch_size,
                'generated_ids': gen_data['generated_ids'],
                'all_batch_generated_ids': gen_data['all_batch_generated_ids'],
                'prompt_token_ids': gen_data['prompt_token_ids'],
                'prompt_length': gen_data['prompt_length'],
                'prefill_signals': gen_data['prefill_signals'],
                'decode_signals': gen_data['decode_signals'],
                'num_generated': gen_data['num_generated']
            })

    # First save: the raw measurements are on disk before the analysis
    # steps below run, so they survive if one of those steps raises.
    filepath = os.path.join(output_dir, f"vllm_decode_{timestamp}.json")
    _save_json(results, filepath)

    log_print(f"\n✓ Generation results saved to: {filepath}")

    # Token consistency check
    for ref_name in REFERENCE_SEQUENCES.keys():
        log_print(f"\n--- Token consistency for {ref_name} ---")
        ref_measurements = {
            m['batch_size']: m
            for m in results['measurements']
            if m['ref_name'] == ref_name
        }
        check_token_consistency(ref_measurements, tokenizer)

    # Within-hardware analysis
    results['prefill_sanity_check'] = analyze_within_hardware(
        results['measurements'], BATCH_SIZES, signal_source='prefill'
    )
    results['decode_sanity_check'] = analyze_within_hardware(
        results['measurements'], BATCH_SIZES, signal_source='decode'
    )

    # Second save: same file, now including the sanity-check results that
    # verification mode reads back.
    _save_json(results, filepath)

    log_print(f"\nNext step: Copy {filepath} to verifier machine")
    log_print("Then set TEACHER_FORCING = True and REFERENCE_FILE = '<path>'")

    file_size_mb = os.path.getsize(filepath) / (1024 * 1024)
    log_print(f"File size: {file_size_mb:.1f} MB")


def main():
    """Run the experiment in the mode selected by TEACHER_FORCING.

    With TEACHER_FORCING False, generates tokens at every batch size and
    saves the reference JSON. With TEACHER_FORCING True, loads REFERENCE_FILE,
    teacher-forces its tokens and compares signals.

    Side effects: sets the module globals REFERENCE_SEQUENCES and DUMMY_SETS,
    writes a log file and a results JSON under the output directory.

    Raises:
        SystemExit: In verification mode, if the reference file does not
            match this run's configuration.
    """
    global REFERENCE_SEQUENCES, DUMMY_SETS

    output_dir = '/workspace/experiments'
    os.makedirs(output_dir, exist_ok=True)
    setup_logging(output_dir)

    # try/finally so the log is closed on every exit path, including
    # sys.exit() and unexpected exceptions.
    try:
        system_info = collect_system_info()

        mode = "VERIFICATION (teacher-forcing)" if TEACHER_FORCING else "GENERATION (decode)"
        log_print("="*80)
        log_print(f"vLLM CROSS-HARDWARE BATCH SIZE DETECTABILITY - {mode}")
        log_print("="*80)

        log_print(f"\nSystem: {system_info['hostname']}")
        log_print(f"GPU: {system_info['gpu_name']}")
        log_print(f"vLLM: {system_info['vllm_version']}")
        log_print(f"PyTorch: {system_info['torch_version']}")
        log_print(f"CUDA: {system_info['cuda_version']}")

        log_print("\nConfiguration:")
        log_print(f"  Model: {MODEL_NAME}")
        log_print(f"  Batch sizes: {BATCH_SIZES}")
        log_print(f"  Max tokens: {MAX_NEW_TOKENS}")
        log_print(f"  Top-k logprobs: {TOP_K_LOGPROBS}")
        if TEACHER_FORCING:
            log_print(f"  Reference file: {REFERENCE_FILE}")
        log_print()

        # Validate the reference before the (slow) model load so that a bad
        # path or mismatched configuration fails immediately.
        reference = None
        env_validation = None
        if TEACHER_FORCING:
            reference, env_validation = _load_reference(system_info)

        # Initialize vLLM
        log_print("Loading vLLM model...")
        llm = LLM(
            model=MODEL_NAME,
            download_dir=CACHE_DIR,
            dtype="bfloat16",
            trust_remote_code=True,
            gpu_memory_utilization=0.7,
        )
        tokenizer = llm.get_tokenizer()
        log_print("✓ Model loaded\n")

        # Initialize sequences from PDF
        REFERENCE_SEQUENCES, DUMMY_SETS = create_sequences_from_pdf(tokenizer)
        log_print(f"Created {len(REFERENCE_SEQUENCES)} reference sequences\n")

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        if TEACHER_FORCING:
            _run_verification(
                llm, tokenizer, system_info, reference, env_validation,
                output_dir, timestamp
            )
        else:
            _run_generation(llm, tokenizer, system_info, output_dir, timestamp)

        log_print(f"\n{'='*80}")
        log_print("EXPERIMENT COMPLETE")
        log_print("="*80 + "\n")
    finally:
        close_logging()

# Run the experiment only when this file is executed as a script.
if __name__ == "__main__":
    main()

vLLM CROSS-HARDWARE BATCH SIZE DETECTABILITY - GENERATION (decode)

System: c6ba2aea0d91
GPU: NVIDIA H100 80GB HBM3
vLLM: 0.11.2
PyTorch: 2.9.0+cu128
CUDA: 12.8

Configuration:
  Model: Qwen/Qwen2.5-7B-Instruct
  Batch sizes: [1, 2, 3, 4, 5, 8, 9]
  Max tokens: 20
  Top-k logprobs: 20

Loading vLLM model...
INFO 11-26 00:19:06 [utils.py:253] non-default args: {'trust_remote_code': True, 'download_dir': '/workspace/huggingface_cache', 'dtype': 'bfloat16', 'gpu_memory_utilization': 0.7, 'disable_log_stats': True, 'model': 'Qwen/Qwen2.5-7B-Instruct'}


The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.


INFO 11-26 00:19:07 [model.py:631] Resolved architecture: Qwen2ForCausalLM
INFO 11-26 00:19:07 [model.py:1745] Using max model len 32768
INFO 11-26 00:19:07 [scheduler.py:216] Chunked prefill is enabled with max_num_batched_tokens=16384.




[1;36m(EngineCore_DP0 pid=4081)[0;0m INFO 11-26 00:19:11 [core.py:93] Initializing a V1 LLM engine (v0.11.2) with config: model='Qwen/Qwen2.5-7B-Instruct', speculative_config=None, tokenizer='Qwen/Qwen2.5-7B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=32768, download_dir='/workspace/huggingface_cache', load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, data_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, device_config=cuda, structured_outputs_config=StructuredOutputsConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_parser='', reasoning_parser_plugin='', enable_in_reasoning=False), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=0

Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  25% Completed | 1/4 [00:00<00:01,  1.57it/s]
Loading safetensors checkpoint shards:  50% Completed | 2/4 [00:01<00:01,  1.51it/s]
Loading safetensors checkpoint shards:  75% Completed | 3/4 [00:01<00:00,  1.55it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00,  1.55it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00,  1.55it/s]
[1;36m(EngineCore_DP0 pid=4081)[0;0m 


[1;36m(EngineCore_DP0 pid=4081)[0;0m INFO 11-26 00:19:15 [default_loader.py:314] Loading weights took 2.63 seconds
[1;36m(EngineCore_DP0 pid=4081)[0;0m INFO 11-26 00:19:16 [gpu_model_runner.py:3338] Model loading took 14.2488 GiB memory and 3.159590 seconds
[1;36m(EngineCore_DP0 pid=4081)[0;0m INFO 11-26 00:19:19 [backends.py:631] Using cache directory: /root/.cache/vllm/torch_compile_cache/f97525d4f8/rank_0_0/backbone for vLLM's torch.compile
[1;36m(EngineCore_DP0 pid=4081)[0;0m INFO 11-26 00:19:19 [backends.py:647] Dynamo bytecode transform time: 3.32 s
[1;36m(EngineCore_DP0 pid=4081)[0;0m INFO 11-26 00:19:22 [backends.py:210] Directly load the compiled graph(s) for dynamic shape from the cache, took 2.451 s
[1;36m(EngineCore_DP0 pid=4081)[0;0m INFO 11-26 00:19:23 [monitor.py:34] torch.compile takes 5.77 s in total
[1;36m(EngineCore_DP0 pid=4081)[0;0m INFO 11-26 00:19:23 [gpu_worker.py:359] Available KV cache memory: 35.38 GiB
[1;36m(EngineCore_DP0 pid=4081)[0;0m INFO

[1;36m(EngineCore_DP0 pid=4081)[0;0m 2025-11-26 00:19:23,964 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ...
[1;36m(EngineCore_DP0 pid=4081)[0;0m 2025-11-26 00:19:23,975 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends
Capturing CUDA graphs (mixed prefill-decode, PIECEWISE): 100%|██████████| 51/51 [00:01<00:00, 26.39it/s]
Capturing CUDA graphs (decode, FULL): 100%|██████████| 51/51 [00:01<00:00, 44.70it/s]


[1;36m(EngineCore_DP0 pid=4081)[0;0m INFO 11-26 00:19:27 [gpu_model_runner.py:4244] Graph capturing finished in 3 secs, took -2.16 GiB
[1;36m(EngineCore_DP0 pid=4081)[0;0m INFO 11-26 00:19:27 [core.py:250] init engine (profile, create kv cache, warmup model) took 11.33 seconds
INFO 11-26 00:19:28 [llm.py:352] Supported tasks: ['generate']
✓ Model loaded

Found 1 PDF(s)
  Loading: /workspace/Verification-for-International-AI-Governance.pdf
    → 120214 tokens
Total tokens: 120214
Creating 27 slices of 200 tokens each
Created 3 reference sequences


REFERENCE: ref_0

Global minimum prompt length: 200 tokens

  bs=1:       Prompt: 200 tokens

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

 → Final: 220 tokens (20 generated)
  bs=2:       Prompt: 200 tokens

Adding requests:   0%|          | 0/2 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/2 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

 → Final: 220 tokens (20 generated)
  bs=3:       Prompt: 200 tokens

Adding requests:   0%|          | 0/3 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/3 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

 → Final: 220 tokens (20 generated)
  bs=4:       Prompt: 200 tokens

Adding requests:   0%|          | 0/4 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/4 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

 → Final: 220 tokens (20 generated)
  bs=5:       Prompt: 200 tokens

Adding requests:   0%|          | 0/5 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/5 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

 → Final: 220 tokens (20 generated)
  bs=8:       Prompt: 200 tokens

Adding requests:   0%|          | 0/8 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/8 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

 → Final: 220 tokens (20 generated)
  bs=9:       Prompt: 200 tokens

Adding requests:   0%|          | 0/9 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/9 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

 → Final: 220 tokens (20 generated)

REFERENCE: ref_1

Global minimum prompt length: 200 tokens

  bs=1:       Prompt: 200 tokens

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

 → Final: 220 tokens (20 generated)
  bs=2:       Prompt: 200 tokens

Adding requests:   0%|          | 0/2 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/2 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

 → Final: 220 tokens (20 generated)
  bs=3:       Prompt: 200 tokens

Adding requests:   0%|          | 0/3 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/3 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

 → Final: 220 tokens (20 generated)
  bs=4:       Prompt: 200 tokens

Adding requests:   0%|          | 0/4 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/4 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

 → Final: 220 tokens (20 generated)
  bs=5:       Prompt: 200 tokens

Adding requests:   0%|          | 0/5 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/5 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

 → Final: 220 tokens (20 generated)
  bs=8:       Prompt: 200 tokens

Adding requests:   0%|          | 0/8 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/8 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

 → Final: 220 tokens (20 generated)
  bs=9:       Prompt: 200 tokens

Adding requests:   0%|          | 0/9 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/9 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

 → Final: 220 tokens (20 generated)

REFERENCE: ref_2

Global minimum prompt length: 200 tokens

  bs=1:       Prompt: 200 tokens

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

 → Final: 220 tokens (20 generated)
  bs=2:       Prompt: 200 tokens

Adding requests:   0%|          | 0/2 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/2 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

 → Final: 220 tokens (20 generated)
  bs=3:       Prompt: 200 tokens

Adding requests:   0%|          | 0/3 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/3 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

 → Final: 220 tokens (20 generated)
  bs=4:       Prompt: 200 tokens

Adding requests:   0%|          | 0/4 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/4 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

 → Final: 220 tokens (20 generated)
  bs=5:       Prompt: 200 tokens

Adding requests:   0%|          | 0/5 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/5 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

 → Final: 220 tokens (20 generated)
  bs=8:       Prompt: 200 tokens

Adding requests:   0%|          | 0/8 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/8 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

 → Final: 220 tokens (20 generated)
  bs=9:       Prompt: 200 tokens

Adding requests:   0%|          | 0/9 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/9 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

 → Final: 220 tokens (20 generated)

✓ Generation results saved to: /workspace/experiments/vllm_decode_20251126_001933.json

--- Token consistency for ref_0 ---

TOKEN GENERATION CONSISTENCY CHECK

Generated tokens by batch size:
  bs=1:
    IDs:  [2571, 8199, 42, 20175, 440, 17, 19, 11, 2016, 344, 5559, 50, 76367, 559, 2907, 17, 20, 11, 17, 21]
    Text: 'lexanderKatzke24,ShivaniSrivastava25,26'
    ✓
  bs=2:
    IDs:  [2571, 8199, 42, 20175, 440, 17, 19, 11, 2016, 344, 5559, 50, 76367, 559, 2907, 17, 20, 11, 17, 21]
    Text: 'lexanderKatzke24,ShivaniSrivastava25,26'
    ✓
  bs=3:
    IDs:  [2571, 8199, 42, 20175, 440, 17, 19, 11, 2016, 344, 5559, 50, 76367, 559, 2907, 17, 20, 11, 17, 21]
    Text: 'lexanderKatzke24,ShivaniSrivastava25,26'
    ✓
  bs=4:
    IDs:  [2571, 8199, 42, 20175, 440, 17, 19, 11, 2016, 344, 5559, 50, 76367, 559, 2907, 17, 20, 11, 17, 21]
    Text: 'lexanderKatzke24,ShivaniSrivastava25,26'
    ✓
  bs=5:
    IDs:  [2571, 8199, 42, 20175, 440, 17, 19, 11, 2016, 3

