In [5]:
#!/usr/bin/env python3
"""
Prefill vs Decode Experiment - Claim Verification Matrix

Fixed version that:
1. Works purely at token ID level (no decode/re-encode cycles)
2. Stores exact token IDs from decode
3. Compares logprobs for SAME token IDs between decode and prefill
4. Properly serializes numpy arrays to JSON

Usage:
    python prefill_decode_experiment.py
"""

import os
os.environ['HF_HOME'] = '/workspace/huggingface_cache'
os.environ['TRANSFORMERS_CACHE'] = '/workspace/huggingface_cache'

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import torch.nn.functional as F
import numpy as np
from datetime import datetime
import json
import socket
import platform
import sys

# ============================================================================
# CONFIGURATION
# ============================================================================

MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"
CACHE_DIR = '/workspace/huggingface_cache'

REFERENCE_SEQUENCES = {
    "ref_technical": """Large language models have revolutionized natural language processing through their ability to capture complex patterns in text data. The transformer architecture, introduced in 2017, employs self-attention mechanisms that allow the model to weigh the importance of different tokens in the input sequence. During training, these models learn to predict the next token in a sequence by optimizing a cross-entropy loss function across billions of text examples.""",
    
    "ref_narrative": """The morning sun filtered through the ancient oak trees as Sarah walked along the forest path, her boots crunching softly on the fallen leaves. She had been coming to these woods since childhood, when her grandmother first taught her to identify the different bird calls echoing through the canopy. Now, decades later, she found herself returning to this same trail whenever life felt overwhelming.""",
    
    "ref_code": """The database migration system implements a sophisticated version control mechanism for schema changes. Each migration file contains both an upgrade and downgrade function, allowing the system to roll forward or backward through schema versions. The migration engine maintains a table tracking which migrations have been applied, using timestamps and hash values to ensure consistency across different environments."""
}

DUMMY_SETS = {
    "ref_technical": [
        """Quantum computing leverages the principles of quantum mechanics to perform computations that would be intractable for classical computers. At the heart of quantum computation lies the qubit, a quantum bit that can exist in a superposition of both 0 and 1 states simultaneously. When multiple qubits are entangled, they form a quantum register capable of representing an exponentially large state space.""",
        
        """The neural architecture search algorithm systematically explores different model configurations to identify optimal designs for specific tasks. Modern approaches use reinforcement learning or evolutionary algorithms to navigate the vast search space of possible architectures. The process evaluates candidate models on validation data, gradually converging toward efficient and effective network topologies.""",
        
        """Distributed consensus protocols enable multiple nodes in a network to agree on a single value despite potential failures or malicious actors. The Byzantine Generals Problem formalizes the challenge of achieving consensus when some participants may behave arbitrarily. Practical solutions like Paxos and Raft provide mechanisms for fault-tolerant agreement in real-world systems."""
    ],
    
    "ref_narrative": [
        """The old lighthouse stood sentinel on the rocky promontory, its weathered walls bearing testament to countless storms. Local legends spoke of the keeper who vanished one winter night, leaving only his log book with a final cryptic entry. Now automated, the beacon still swept across the dark waters, a guardian whose original purpose had long been superseded by modern navigation systems.""",
        
        """Marcus found the letter tucked between the pages of his grandfather's journal, the paper yellowed and fragile with age. The handwriting was unfamiliar, yet the words spoke of events his family had never discussed. As he read, pieces of his heritage began to fall into place, revealing a story that had been deliberately hidden for three generations.""",
        
        """The jazz club occupied a basement space that seemed to exist outside of time, where smoke still hung in the air despite the ban and the music felt like it emerged from another era. Every Thursday, the same musicians gathered to play standards that few in the younger generation recognized. Yet something about the atmosphere drew people in, seeking connection to an authenticity they sensed was disappearing from the world."""
    ],
    
    "ref_code": [
        """The distributed caching layer implements consistent hashing to minimize cache invalidation when nodes are added or removed from the cluster. Virtual nodes provide better load distribution across physical servers, while replication ensures availability even during node failures. The system monitors hit rates and eviction patterns to automatically adjust cache allocation strategies.""",
        
        """The API gateway performs request routing, authentication, rate limiting, and response transformation for microservices. Each service registers its endpoints with the gateway, which maintains a dynamic routing table. The gateway implements circuit breakers to prevent cascade failures and provides detailed metrics for monitoring service health.""",
        
        """The message queue system guarantees exactly-once delivery through a combination of acknowledgments, persistent storage, and idempotency tokens. Publishers receive confirmation only after messages are durably written to replicated storage. Consumers process messages within transactions, ensuring atomic updates across message consumption and business logic execution."""
    ]
}

BATCH_SIZES = [1, 2, 4]
LAYER_INDICES = [1, 4, 10, 18, 28]
MAX_NEW_TOKENS = 20

# ============================================================================
# SYSTEM INFO
# ============================================================================

def collect_system_info():
    """Collect comprehensive environment information."""
    import transformers
    
    info = {
        "hostname": socket.gethostname(),
        "platform": platform.platform(),
        "python_version": sys.version.split()[0],
        "torch_version": torch.__version__,
        "cuda_version": torch.version.cuda if torch.cuda.is_available() else "N/A",
        "transformers_version": transformers.__version__,
        "gpu_name": torch.cuda.get_device_name(0) if torch.cuda.is_available() else "N/A",
        "gpu_count": torch.cuda.device_count() if torch.cuda.is_available() else 0,
    }
    
    return info

# ============================================================================
# EXTRACTION
# ============================================================================

def extract_signals_from_output(outputs, layer_indices, position=-1):
    """
    Extract key vectors and logprobs from element 0 at specified position.
    Returns top-10 token IDs so prefill can compare the same tokens.
    """
    signals = {
        'key_vectors': {},
        'logprobs': {}
    }
    
    # Key vectors from element 0 (layer indices are 1-based; cache entry [0] holds keys)
    for layer_idx in layer_indices:
        layer_keys = outputs.past_key_values[layer_idx - 1][0]
        token_keys = layer_keys[0, :, position, :]
        key_dim = token_keys.shape[0] * token_keys.shape[1]
        key_vector = token_keys.reshape(key_dim).cpu().clone()
        signals['key_vectors'][f'layer_{layer_idx}'] = key_vector.float().numpy().tolist()
    
    # Logprobs from element 0 - get top 10
    logits = outputs.logits[0, position, :]
    log_probs = F.log_softmax(logits, dim=-1)
    top10 = torch.topk(log_probs, k=10)
    
    signals['logprobs'] = {
        'token_ids': top10.indices.cpu().tolist(),
        'log_probs': top10.values.cpu().tolist()
    }
    
    return signals


def extract_signals_for_token_ids(outputs, layer_indices, token_ids, position=-1):
    """
    Extract key vectors and logprobs for SPECIFIC token IDs.
    Used in prefill to compare the same tokens as decode.
    
    Args:
        token_ids: List of token IDs to extract logprobs for (from decode's top-10)
    """
    signals = {
        'key_vectors': {},
        'logprobs': {}
    }
    
    # Key vectors from element 0 (layer indices are 1-based; cache entry [0] holds keys)
    for layer_idx in layer_indices:
        layer_keys = outputs.past_key_values[layer_idx - 1][0]
        token_keys = layer_keys[0, :, position, :]
        key_dim = token_keys.shape[0] * token_keys.shape[1]
        key_vector = token_keys.reshape(key_dim).cpu().clone()
        signals['key_vectors'][f'layer_{layer_idx}'] = key_vector.float().numpy().tolist()
    
    # Logprobs for SPECIFIC token IDs (from decode's top 10)
    logits = outputs.logits[0, position, :]
    log_probs = F.log_softmax(logits, dim=-1)
    
    # Extract logprobs for the specified token IDs
    token_ids_tensor = torch.tensor(token_ids, device=logits.device)
    selected_logprobs = log_probs[token_ids_tensor]
    
    signals['logprobs'] = {
        'token_ids': token_ids,  # Same as decode
        'log_probs': selected_logprobs.cpu().tolist()
    }
    
    return signals

# ============================================================================
# DECODE GENERATION & EXTRACTION
# ============================================================================

def compute_min_length_across_batches(ref_text, ref_name, tokenizer, batch_sizes):
    """
    Pre-compute minimum sequence length across all batch configurations.
    This ensures all batch sizes use the same sequence length.
    """
    ref_dummies = DUMMY_SETS[ref_name]
    min_length = float('inf')
    
    for batch_size in batch_sizes:
        # Element 0 is always the reference; the other slots are dummy neighbors.
        batch_texts = [ref_text] + ref_dummies[:batch_size - 1]
        
        token_lengths = [len(tokenizer.encode(t, add_special_tokens=True)) for t in batch_texts]
        min_length = min(min_length, min(token_lengths))
    
    return min_length


def run_decode_with_extraction(model, tokenizer, ref_text, ref_name, batch_size, 
                                layer_indices, forced_length=None):
    """
    Run decode generation and extract signals from last 3 generation steps.
    Works purely at token ID level - no decode/re-encode cycles.
    
    Returns exact token IDs used for prefill to reproduce.
    """
    torch.cuda.empty_cache()
    
    # Build batch texts: element 0 is the reference, the rest are dummy neighbors
    ref_dummies = DUMMY_SETS[ref_name]
    batch_texts = [ref_text] + ref_dummies[:batch_size - 1]
    
    # Tokenize ONCE - store the exact token IDs
    all_token_ids = [tokenizer.encode(t, add_special_tokens=True) for t in batch_texts]
    
    # Truncate at token ID level
    if forced_length is not None:
        min_length = forced_length
    else:
        min_length = min(len(ids) for ids in all_token_ids)
    
    truncated_token_ids = [ids[:min_length] for ids in all_token_ids]
    
    # Build input tensors DIRECTLY from token IDs - no decode/re-encode
    input_ids = torch.tensor(truncated_token_ids, dtype=torch.long, device=model.device)
    attention_mask = torch.ones_like(input_ids)
    
    inputs = {
        'input_ids': input_ids,
        'attention_mask': attention_mask
    }
    
    prompt_length = input_ids.shape[1]
    print(f"      Prompt length: {prompt_length} tokens", end="")
    
    # Track generation for ALL batch positions
    all_batch_generated_ids = [[] for _ in range(batch_size)]
    generation_signals = []
    
    # FIRST STEP: Prefill with full prompt
    with torch.no_grad():
        outputs = model(**inputs, use_cache=True)
    
    past_kv = outputs.past_key_values
    
    # Get first token
    next_tokens = outputs.logits[:, -1, :].argmax(dim=-1)
    for batch_idx in range(batch_size):
        all_batch_generated_ids[batch_idx].append(next_tokens[batch_idx].item())
    
    # Extract signals from element 0
    signals = extract_signals_from_output(outputs, layer_indices, position=-1)
    
    # Absolute position: input length - 1
    absolute_position_index = inputs['input_ids'].shape[1] - 1
    
    generation_signals.append({
        'step': 0,
        'absolute_position': absolute_position_index,
        'signals': signals
    })
    
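    # Extend the attention mask by one column, keeping its dtype and device
    # (torch.ones defaults to float32, which would promote the integer mask)
    attention_mask = torch.cat([
        inputs['attention_mask'],
        torch.ones((inputs['attention_mask'].shape[0], 1),
                   dtype=inputs['attention_mask'].dtype,
                   device=inputs['attention_mask'].device)
    ], dim=1)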
    
    # SUBSEQUENT STEPS: True autoregressive with cache
    for step in range(1, MAX_NEW_TOKENS):
        new_inputs = {
            'input_ids': next_tokens.unsqueeze(1),
            'attention_mask': attention_mask,
            'past_key_values': past_kv,
            'use_cache': True
        }
        
        with torch.no_grad():
            outputs = model(**new_inputs)
        
        past_kv = outputs.past_key_values
        
        # Get next tokens
        next_tokens = outputs.logits[:, -1, :].argmax(dim=-1)
        for batch_idx in range(batch_size):
            all_batch_generated_ids[batch_idx].append(next_tokens[batch_idx].item())
        
        # Extract signals from element 0
        signals = extract_signals_from_output(outputs, layer_indices, position=-1)
        
        # Absolute position in full sequence
        current_cache_length = past_kv[0][0].shape[2]
        absolute_position_index = current_cache_length - 1
        
        generation_signals.append({
            'step': step,
            'absolute_position': absolute_position_index,
            'signals': signals
        })
        
        # Extend the attention mask by one column, keeping its dtype and device
        attention_mask = torch.cat([
            attention_mask,
            torch.ones((attention_mask.shape[0], 1),
                       dtype=attention_mask.dtype,
                       device=attention_mask.device)
        ], dim=1)
        
        # Stop once batch element 0 emits EOS (neighbor elements are not checked)
        if all_batch_generated_ids[0][-1] == tokenizer.eos_token_id:
            break
    
    # Extract last 3
    num_generated = len(generation_signals)
    if num_generated >= 3:
        last_3_signals = {
            'pos_-3': generation_signals[-3],
            'pos_-2': generation_signals[-2],
            'pos_-1': generation_signals[-1]
        }
    elif num_generated == 2:
        last_3_signals = {
            'pos_-2': generation_signals[-2],
            'pos_-1': generation_signals[-1]
        }
    elif num_generated == 1:
        last_3_signals = {
            'pos_-1': generation_signals[-1]
        }
    else:
        last_3_signals = {}
    
    del outputs
    del inputs
    torch.cuda.empty_cache()
    
    final_length = prompt_length + num_generated
    print(f" → Final: {final_length} tokens ({num_generated} generated)")
    
    return {
        'generated_ids': all_batch_generated_ids[0],  # Tokens generated by batch element 0
        'all_batch_generated_ids': all_batch_generated_ids,  # Tokens for ALL batch elements
        'prompt_token_ids': truncated_token_ids,  # EXACT token IDs used in prompt
        'prompt_length': prompt_length,
        'signals': last_3_signals,
        'num_generated': num_generated
    }
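

# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the original experiment): the position
# bookkeeping in run_decode_with_extraction, reproduced without a model.
# Assumption: step 0 reads the last prompt token (index prompt_length - 1) and
# every later step advances the KV cache by exactly one position. The name
# sketch_extraction_positions and its keep_last parameter are hypothetical.
# ----------------------------------------------------------------------------
def sketch_extraction_positions(prompt_length, num_steps, keep_last=3):
    """Return the absolute positions of the last `keep_last` decode steps."""
    if keep_last <= 0:
        return []
    # Step s is extracted at absolute index (prompt_length - 1) + s.
    positions = [prompt_length - 1 + step for step in range(num_steps)]
    return positions[-keep_last:]

# Usage: a 56-token prompt with 20 decode steps covers positions 55..74, so
# the last three extraction positions are [72, 73, 74].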

# ============================================================================
# PREFILL REPRODUCTION
# ============================================================================

def run_prefill_verification(model, tokenizer, ref_name, decode_metadata, 
                             batch_size, layer_indices, is_diagonal=False):
    """
    Run prefill to verify decoder's claim.
    Uses EXACT token IDs from decode - no tokenization artifacts.
    
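    Args:
        is_diagonal: True when verifying an honest claim (same batch size);
            the decode run's exact neighbor token IDs are reused. Otherwise,
            length-matched dummy neighbors are constructed.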
    """
    torch.cuda.empty_cache()
    
    # Get exact prompt IDs used in decode for element 0
    ref_prompt_ids = decode_metadata['prompt_token_ids'][0]
    generated_ids = decode_metadata['generated_ids']
    
    # Build extended sequence at ID level
    extended_ref_ids = ref_prompt_ids + generated_ids
    
    print(f" [Ext: {len(extended_ref_ids)} tokens", end="")
    
    # Build batch
    if batch_size == 1:
        batch_ids = [extended_ref_ids]
    else:
        batch_ids = [extended_ref_ids]
        
        if is_diagonal:
            # Use EXACT neighbor IDs from decode
            print(f", exact neighbors", end="")
            for i in range(1, batch_size):
                neighbor_prompt_ids = decode_metadata['prompt_token_ids'][i]
                neighbor_gen_ids = decode_metadata['all_batch_generated_ids'][i]
                extended_neighbor_ids = neighbor_prompt_ids + neighbor_gen_ids
                batch_ids.append(extended_neighbor_ids)
        else:
            # Off-diagonal: construct length-matched neighbors at ID level
            print(f", arb neighbors", end="")
            ref_dummies = DUMMY_SETS[ref_name]
            target_length = len(extended_ref_ids)
            
            for i in range(batch_size - 1):
                dummy_ids = tokenizer.encode(ref_dummies[i], add_special_tokens=True)
                
                # Truncate or repeat to match length
                if len(dummy_ids) >= target_length:
                    dummy_ids = dummy_ids[:target_length]
                else:
                    # Repeat tokens to reach target length
                    repeats = (target_length // len(dummy_ids)) + 1
                    dummy_ids = (dummy_ids * repeats)[:target_length]
                
                batch_ids.append(dummy_ids)
    
    # Build input tensors DIRECTLY from IDs - no decode/encode cycle
    input_ids = torch.tensor(batch_ids, dtype=torch.long, device=model.device)
    attention_mask = torch.ones_like(input_ids)
    
    inputs = {
        'input_ids': input_ids,
        'attention_mask': attention_mask
    }
    
    print(f"]", end="")
    
    # Single forward pass
    with torch.no_grad():
        outputs = model(**inputs, use_cache=True)
    
    # Extract at same positions as decode, using decode's token IDs for logprobs
    prefill_signals = {}
    extraction_positions = []
    
    for pos_label, decode_step_data in decode_metadata['signals'].items():
        decode_abs_pos = decode_step_data['absolute_position']
        decode_token_ids = decode_step_data['signals']['logprobs']['token_ids']
        
        extraction_positions.append(decode_abs_pos)
        
        # Extract using SAME token IDs as decode for logprobs
        signals = extract_signals_for_token_ids(
            outputs, layer_indices, decode_token_ids, position=decode_abs_pos
        )
        
        prefill_signals[pos_label] = {
            'absolute_position': decode_abs_pos,
            'signals': signals
        }
    
    print(f" → Extract at: {extraction_positions}")
    
    del outputs
    del inputs
    torch.cuda.empty_cache()
    
    return prefill_signals
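

# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the original experiment): the
# truncate-or-repeat rule used for off-diagonal neighbors, as a standalone
# helper. The name sketch_fit_to_length is hypothetical; the logic is meant to
# mirror the inline branch in run_prefill_verification.
# ----------------------------------------------------------------------------
def sketch_fit_to_length(token_ids, target_length):
    """Truncate or cyclically repeat `token_ids` to exactly `target_length`."""
    if not token_ids:
        raise ValueError("token_ids must be non-empty")
    if len(token_ids) >= target_length:
        return list(token_ids[:target_length])
    # One extra repeat guarantees enough tokens before the final slice.
    repeats = (target_length // len(token_ids)) + 1
    return (list(token_ids) * repeats)[:target_length]

# Usage: sketch_fit_to_length([1, 2, 3], 7) cycles the sequence, while
# sketch_fit_to_length([1, 2, 3, 4], 2) simply truncates it.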

# ============================================================================
# ANALYSIS
# ============================================================================

def compute_l2_distance(vec1, vec2):
    """Compute L2 distance between two vectors."""
    v1 = np.array(vec1)
    v2 = np.array(vec2)
    return float(np.linalg.norm(v1 - v2))


def compute_logprob_distance(logprobs1, logprobs2):
    """
    Compute L2 distance between logprob distributions.
    Both should have same token_ids in same order.
    """
    probs1 = np.array(logprobs1['log_probs'])
    probs2 = np.array(logprobs2['log_probs'])
    return float(np.linalg.norm(probs1 - probs2))
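

# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the original experiment): comparing
# logprobs for the SAME token IDs, in pure Python. It shows why both runs must
# be indexed by decode's token IDs instead of taking each run's own top-10.
# The sketch_* names are hypothetical, and plain lists stand in for the logits
# tensors used above.
# ----------------------------------------------------------------------------
import math

def sketch_log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]

def sketch_logprob_l2(logits_a, logits_b, token_ids):
    """L2 distance between two runs' logprobs at the given token IDs."""
    lp_a = sketch_log_softmax(logits_a)
    lp_b = sketch_log_softmax(logits_b)
    return math.sqrt(sum((lp_a[t] - lp_b[t]) ** 2 for t in token_ids))

# Usage: identical logits give distance 0.0, and adding a constant to every
# logit leaves the logprobs (and therefore the distance) unchanged.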


def compare_claim_vs_verification(decode_data, prefill_data, layer_indices):
    """Compare decoder's claim against prefiller's verification."""
    common_positions = set(decode_data['signals'].keys()) & set(prefill_data.keys())
    
    all_key_dists = []
    all_logprob_dists = []
    
    for pos_label in common_positions:
        decode_signals = decode_data['signals'][pos_label]['signals']
        prefill_signals = prefill_data[pos_label]['signals']
        
        # Key vectors
        for layer_name in decode_signals['key_vectors'].keys():
            dist = compute_l2_distance(
                decode_signals['key_vectors'][layer_name],
                prefill_signals['key_vectors'][layer_name]
            )
            all_key_dists.append(dist)
        
        # Logprobs
        dist = compute_logprob_distance(decode_signals['logprobs'], prefill_signals['logprobs'])
        all_logprob_dists.append(dist)
    
    # Cast to plain floats so the summary stays JSON-serializable
    return {
        'key_vectors_max': float(max(all_key_dists)) if all_key_dists else 0.0,
        'key_vectors_mean': float(np.mean(all_key_dists)) if all_key_dists else 0.0,
        'logprobs_max': float(max(all_logprob_dists)) if all_logprob_dists else 0.0,
        'logprobs_mean': float(np.mean(all_logprob_dists)) if all_logprob_dists else 0.0
    }


def check_token_consistency(decode_measurements):
    """Check if element 0 generates same tokens across batch sizes."""
    print("\n" + "="*80)
    print("TOKEN GENERATION CONSISTENCY CHECK")
    print("="*80)
    
    tokens_by_bs = {}
    for bs, data in decode_measurements.items():
        tokens_by_bs[bs] = data['generated_ids']
    
    bs_list = sorted(tokens_by_bs.keys())
    reference_tokens = tokens_by_bs[bs_list[0]]
    
    all_same = True
    print("\nGenerated tokens by batch size:")
    for bs in bs_list:
        tokens = tokens_by_bs[bs]
        match_str = "✓" if tokens == reference_tokens else "✗ DIFFERENT"
        print(f"  bs={bs}: {tokens[:10]}{'...' if len(tokens) > 10 else ''} {match_str}")
        if tokens != reference_tokens:
            all_same = False
    
    if all_same:
        print("\n✓ Element 0 generates IDENTICAL tokens across all batch sizes")
    else:
        print("\n⚠ Element 0 generates DIFFERENT tokens across batch sizes")
        print("  This is expected if batch composition affects generation")
    
    return all_same


def analyze_experiment(measurements, layer_indices):
    """Compute and display claim-verification matrices."""
    print("\n" + "="*80)
    print("ANALYSIS: CLAIM vs VERIFICATION MATRIX")
    print("="*80)
    
    # Group by reference
    by_ref = {}
    for m in measurements:
        ref = m['ref_name']
        if ref not in by_ref:
            by_ref[ref] = {}
        key = (m['decode_batch_size'], m['prefill_batch_size'])
        by_ref[ref][key] = m
    
    # Compute matrices for each reference
    all_matrices = {'key_vectors': [], 'logprobs': []}
    
    for ref_name in sorted(by_ref.keys()):
        print(f"\n{'='*80}")
        print(f"{ref_name.upper()}")
        print(f"{'='*80}")
        
        ref_data = by_ref[ref_name]
        
        # Compute distance matrices (rows: claimed bs, columns: verification bs)
        n_bs = len(BATCH_SIZES)
        matrix_key = np.zeros((n_bs, n_bs))
        matrix_logprob = np.zeros((n_bs, n_bs))
        
        for i, decode_bs in enumerate(BATCH_SIZES):
            for j, prefill_bs in enumerate(BATCH_SIZES):
                key = (decode_bs, prefill_bs)
                if key in ref_data:
                    m = ref_data[key]
                    distances = compare_claim_vs_verification(
                        m['decode_data'],
                        m['prefill_data'],
                        layer_indices
                    )
                    matrix_key[i, j] = distances['key_vectors_max']
                    matrix_logprob[i, j] = distances['logprobs_max']
        
        # Display matrices
        print("\nKey Vectors (max L2 distance):")
        print("                Verify bs=1   Verify bs=2   Verify bs=4")
        for i, decode_bs in enumerate(BATCH_SIZES):
            row_str = f"Claim bs={decode_bs}   "
            for j in range(3):
                row_str += f"  {matrix_key[i,j]:11.2e}"
            print(row_str)
        
        print("\nLogprobs (max L2 distance):")
        print("                Verify bs=1   Verify bs=2   Verify bs=4")
        for i, decode_bs in enumerate(BATCH_SIZES):
            row_str = f"Claim bs={decode_bs}   "
            for j in range(3):
                row_str += f"  {matrix_logprob[i,j]:11.2e}"
            print(row_str)
        
        all_matrices['key_vectors'].append(matrix_key)
        all_matrices['logprobs'].append(matrix_logprob)
    
    # Aggregate statistics
    print("\n" + "="*80)
    print("AGGREGATE STATISTICS (AVERAGE ACROSS REFERENCES)")
    print("="*80)
    
    results = {}
    
    for signal_type in ['key_vectors', 'logprobs']:
        matrices = all_matrices[signal_type]
        avg_matrix = np.mean(matrices, axis=0)
        
        print(f"\n{signal_type.upper()}:")
        print("                Verify bs=1   Verify bs=2   Verify bs=4")
        for i, decode_bs in enumerate(BATCH_SIZES):
            row_str = f"Claim bs={decode_bs}   "
            for j in range(3):
                row_str += f"  {avg_matrix[i,j]:11.2e}"
            print(row_str)
        
        # Extract diagonal (noise) and off-diagonal (signal)
        diagonal = np.array([avg_matrix[i, i] for i in range(3)])
        off_diagonal = np.array([avg_matrix[i, j] for i in range(3) for j in range(3) if i != j])
        
        noise_mean = np.mean(diagonal)
        noise_std = np.std(diagonal)
        signal_mean = np.mean(off_diagonal)
        signal_std = np.std(off_diagonal)
        snr = signal_mean / noise_mean if noise_mean > 0 else float('inf')
        
        print(f"\n  Diagonal (noise - claim matches verification):")
        print(f"    μ = {noise_mean:.2e}, σ = {noise_std:.2e}")
        print(f"    Values: {[f'{d:.2e}' for d in diagonal]}")
        
        print(f"\n  Off-diagonal (signal - claim doesn't match verification):")
        print(f"    μ = {signal_mean:.2e}, σ = {signal_std:.2e}")
        
        print(f"\n  SNR (signal/noise): {snr:.2f}×")
        
        results[signal_type] = {
            'matrix': avg_matrix.tolist(),
            'noise_mean': float(noise_mean),
            'noise_std': float(noise_std),
            'signal_mean': float(signal_mean),
            'signal_std': float(signal_std),
            'snr': float(snr)
        }
    
    # Conclusion
    print("\n" + "="*80)
    print("CONCLUSION")
    print("="*80)
    
    threshold = 1.5
    
    key_snr = results['key_vectors']['snr']
    log_snr = results['logprobs']['snr']
    
    print(f"\nKey Vectors: SNR = {key_snr:.2f}× {'✓ DETECTABLE' if key_snr >= threshold else '✗ NOT DETECTABLE'}")
    print(f"Logprobs:    SNR = {log_snr:.2f}× {'✓ DETECTABLE' if log_snr >= threshold else '✗ NOT DETECTABLE'}")
    
    if key_snr >= threshold and log_snr >= threshold:
        print("\n✓ BATCH SIZE MISMATCHES ARE DETECTABLE")
        print("  → Prefiller can detect when decoder lies about batch size")
        print("  → Verification fails when claimed bs ≠ verification bs")
    elif key_snr >= threshold or log_snr >= threshold:
        print("\n~ PARTIAL DETECTABILITY")
        det = 'Key vectors' if key_snr >= threshold else 'Logprobs'
        print(f"  → {det} can detect batch size mismatches")
    else:
        print("\n✗ BATCH SIZE MISMATCHES NOT RELIABLY DETECTABLE")
        print("  → Verification distance similar whether claimed bs matches or not")
        print("  → Cannot reliably catch decoder lying about batch size")
    
    # Convert numpy arrays to lists for JSON serialization
    matrices_serializable = {
        'key_vectors': [m.tolist() for m in all_matrices['key_vectors']],
        'logprobs': [m.tolist() for m in all_matrices['logprobs']]
    }
    
    return {
        'matrices': matrices_serializable,
        'statistics': results
    }
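

# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the original experiment): the
# diagonal-vs-off-diagonal ratio computed in analyze_experiment, on a plain
# nested list. Diagonal cells (claimed bs == verification bs) are treated as
# noise and the remaining cells as signal. The name sketch_snr is
# hypothetical; a square matrix with at least two rows is assumed.
# ----------------------------------------------------------------------------
def sketch_snr(matrix):
    """Return mean(off-diagonal) / mean(diagonal) for a square matrix."""
    n = len(matrix)
    diagonal = [matrix[i][i] for i in range(n)]
    off_diagonal = [matrix[i][j] for i in range(n) for j in range(n) if i != j]
    noise_mean = sum(diagonal) / len(diagonal)
    signal_mean = sum(off_diagonal) / len(off_diagonal)
    # Same convention as above: a zero noise floor yields an infinite ratio.
    return signal_mean / noise_mean if noise_mean > 0 else float('inf')

# Usage: for [[1.0, 2.0], [3.0, 1.0]] the noise mean is 1.0 and the signal
# mean is 2.5, which clears the 1.5 threshold used in the conclusion.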

# ============================================================================
# MAIN
# ============================================================================

def main():
    system_info = collect_system_info()
    
    print("="*80)
    print("PREFILL vs DECODE EXPERIMENT - CLAIM VERIFICATION")
    print("="*80)
    print(f"\nEnvironment:")
    for k, v in system_info.items():
        print(f"  {k}: {v}")
    
    print(f"\nConfiguration:")
    print(f"  Model: {MODEL_NAME}")
    print(f"  Layers: {LAYER_INDICES}")
    print(f"  Batch sizes: {BATCH_SIZES}")
    print(f"  References: {len(REFERENCE_SEQUENCES)}")
    print(f"  Max tokens: {MAX_NEW_TOKENS}")
    print()
    print("Experiment design:")
    print("  - Decoder makes claims at bs=1,2,4")
    print("  - Prefiller verifies each claim at bs=1,2,4")
    print("  - Matrix: 3 claims × 3 verifications = 9 comparisons per reference")
    print("  - Works purely at token ID level (no tokenization artifacts)")
    print()
    
    # Load model
    print("Loading model...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        dtype=torch.bfloat16,
        cache_dir=CACHE_DIR,
        low_cpu_mem_usage=True,
        device_map="auto"
    )
    print("✓ Model loaded\n")
    
    # Results structure
    results = {
        'metadata': {
            'environment': system_info,
            'model': MODEL_NAME,
            'layer_indices': LAYER_INDICES,
            'batch_sizes': BATCH_SIZES,
            'max_new_tokens': MAX_NEW_TOKENS,
            'timestamp': datetime.now().isoformat()
        },
        'measurements': []
    }
    
    # Run experiments
    for ref_name, ref_text in REFERENCE_SEQUENCES.items():
        print(f"\n{'='*80}")
        print(f"REFERENCE: {ref_name}")
        print(f"{'='*80}")
        
        # Pre-compute minimum length across all batch configurations
        min_prompt_length = compute_min_length_across_batches(
            ref_text, ref_name, tokenizer, BATCH_SIZES
        )
        print(f"\nGlobal minimum prompt length: {min_prompt_length} tokens")
        print("(All batch sizes will use this length)\n")
        
        # Step 1: Generate claims at each batch size
        print(f"Decode claims:")
        decode_measurements = {}
        
        for decode_bs in BATCH_SIZES:
            print(f"  bs={decode_bs}...", end=" ")
            decode_data = run_decode_with_extraction(
                model, tokenizer, ref_text, ref_name, decode_bs, LAYER_INDICES,
                forced_length=min_prompt_length
            )
            
            # Show extraction positions
            extract_positions = [step['absolute_position'] for step in decode_data['signals'].values()]
            print(f"    Extract positions: {extract_positions}")
            
            decode_measurements[decode_bs] = decode_data
        
        # Check token consistency
        check_token_consistency(decode_measurements)
        
        # Step 2: For each claim, verify at all batch sizes
        print(f"\nPrefill verifications:")
        
        for decode_bs in BATCH_SIZES:
            decode_data = decode_measurements[decode_bs]
            print(f"\n  Verifying claim from bs={decode_bs}:")
            
            for prefill_bs in BATCH_SIZES:
                print(f"    with bs={prefill_bs}...", end="")
                
                # Determine if this is diagonal (honest claim: same batch size)
                is_diagonal = (decode_bs == prefill_bs)
                
                prefill_data = run_prefill_verification(
                    model, tokenizer, ref_name, decode_data, prefill_bs,
                    LAYER_INDICES, is_diagonal=is_diagonal
                )
                
                results['measurements'].append({
                    'ref_name': ref_name,
                    'decode_batch_size': decode_bs,
                    'prefill_batch_size': prefill_bs,
                    'is_diagonal': is_diagonal,
                    'decode_data': {
                        'generated_ids': decode_data['generated_ids'],
                        'all_batch_generated_ids': decode_data['all_batch_generated_ids'],
                        'prompt_token_ids': decode_data['prompt_token_ids'],
                        'prompt_length': decode_data['prompt_length'],
                        'signals': decode_data['signals'],
                        'num_generated': decode_data['num_generated']
                    },
                    'prefill_data': prefill_data
                })
    
    # Save data
    output_dir = '/workspace/experiments'
    os.makedirs(output_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filepath = os.path.join(output_dir, f"prefill_decode_{timestamp}.json")
    
    with open(filepath, 'w') as f:
        json.dump(results, f, indent=2)
    
    print(f"\n✓ Data saved to: {filepath}")
    
    # Run analysis
    analysis_results = analyze_experiment(results['measurements'], LAYER_INDICES)
    
    # Save with analysis
    results['analysis'] = analysis_results
    with open(filepath, 'w') as f:
        json.dump(results, f, indent=2)
    
    file_size_mb = os.path.getsize(filepath) / (1024 * 1024)
    print(f"✓ Analysis saved (file size: {file_size_mb:.1f} MB)")
    
    print(f"\n{'='*80}")
    print("EXPERIMENT COMPLETE")
    print(f"{'='*80}\n")


if __name__ == "__main__":
    main()

PREFILL vs DECODE EXPERIMENT - CLAIM VERIFICATION

Environment:
  hostname: 812aecb237c2
  platform: Linux-6.8.0-49-generic-x86_64-with-glibc2.39
  python_version: 3.12.11
  torch_version: 2.8.0+cu126
  cuda_version: 12.6
  transformers_version: 4.57.1
  gpu_name: NVIDIA A100-SXM4-80GB
  gpu_count: 1

Configuration:
  Model: Qwen/Qwen2.5-7B-Instruct
  Layers: [1, 4, 10, 18, 28]
  Batch sizes: [1, 2, 4]
  References: 3
  Max tokens: 20

Experiment design:
  - Decoder makes claims at bs=1,2,4
  - Prefiller verifies each claim at bs=1,2,4
  - Matrix: 3 claims × 3 verifications = 9 comparisons per reference
  - Works purely at token ID level (no tokenization artifacts)

Loading model...
✓ Model loaded


REFERENCE: ref_technical

Global minimum prompt length: 56 tokens
(All batch sizes will use this length)

Decode claims:
  bs=1...       Prompt length: 56 tokens → Final: 76 tokens (20 generated)
    Extract positions: [72, 73, 74]
  bs=2...       Prompt length: 56 tokens → Final: 76 tokens (20 generated)
    Extract positions: [72, 73, 74]
  bs=4...       Prompt length: 56 tokens → Final: 76 tokens (20 generated)
    Extract positions: [72, 73, 74]

TOKEN GENERATION CONSISTENCY CHECK

Generated tokens by batch size:
  bs=1: [279, 1614, 374, 14900, 311, 12767, 14713, 315, 1467, 821]... ✓
  bs=2: [279, 1614, 374, 14900, 311, 12767, 14713, 315, 1467, 821]... ✓
  bs=4: [279, 1614, 374, 14900, 311, 12767, 14713, 315, 1467, 821]... ✓

✓ Element 0 generates IDENTICAL tokens across all batch sizes

Prefill verifications:

  Verifying claim from bs=1:
    with bs=1... [Ext: 76 tokens] → Extract at: [72, 73, 74]
    with bs=2... [Ext: 76 tokens, arb neighbors] → Extract at: [72