#!/usr/bin/env python3
"""
Prefill vs Decode Experiment - Batch Size Detectability

Fixed version that:
1. Works purely at token ID level (no decode/re-encode cycles)
2. Stores exact token IDs from decode
3. Compares logprobs for SAME token IDs between decode and prefill
4. Properly serializes numpy arrays to JSON
5. ADDED: Within-mode batch size comparisons (decode vs decode, prefill vs prefill)

Usage:
    python prefill_decode_experiment.py
"""

import os
os.environ['HF_HOME'] = '/workspace/huggingface_cache'
os.environ['TRANSFORMERS_CACHE'] = '/workspace/huggingface_cache'

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import torch.nn.functional as F
import numpy as np
from datetime import datetime
import json
import socket
import platform
import sys
import glob
import PyPDF2
from datetime import datetime as dt_for_log

# ============================================================================
# LOGGING SETUP
# ============================================================================

LOG_FILE = None

def setup_logging(output_dir='/workspace/experiments'):
    """
    Open a fresh, timestamped log file and make it the target of log_print().

    Args:
        output_dir: Directory that receives the log; created if missing.

    Returns:
        Path of the newly created log file.
    """
    global LOG_FILE
    # A repeated call (e.g. re-running a notebook cell) must not leak the
    # handle that an earlier call left open.
    if LOG_FILE is not None:
        LOG_FILE.close()
        LOG_FILE = None

    os.makedirs(output_dir, exist_ok=True)
    timestamp = dt_for_log.now().strftime("%Y%m%d_%H%M%S")
    log_path = os.path.join(output_dir, f"experiment_log_{timestamp}.txt")
    # Log messages contain non-ASCII symbols (arrows, check marks, Greek
    # letters); pin UTF-8 so writing never depends on the platform locale.
    LOG_FILE = open(log_path, 'w', encoding='utf-8')
    return log_path

def log_print(*args, **kwargs):
    """
    Behave like print(), and mirror the same output into the open log file.

    The console call receives the arguments untouched. The mirrored call
    drops any caller-supplied 'file' keyword so it can target LOG_FILE, and
    flushes after every message.
    """
    print(*args, **kwargs)
    if not LOG_FILE:
        return

    mirrored_kwargs = dict(kwargs)
    mirrored_kwargs.pop('file', None)  # the log copy always goes to LOG_FILE
    print(*args, file=LOG_FILE, **mirrored_kwargs)
    LOG_FILE.flush()

def close_logging():
    """Close the active log file, if any, and reset LOG_FILE to None."""
    global LOG_FILE
    if not LOG_FILE:
        return  # nothing is open
    LOG_FILE.close()
    LOG_FILE = None

# ============================================================================
# CONFIGURATION
# ============================================================================

# Hugging Face model identifier (presumably loaded in main(), which is not
# visible in this part of the file - confirm there).
MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"
# Same path as the HF_HOME / TRANSFORMERS_CACHE environment variables set at
# import time.
CACHE_DIR = '/workspace/huggingface_cache'

# Will be initialized from PDF in main()
REFERENCE_SEQUENCES = None
DUMMY_SETS = None

# Batch sizes swept by the experiment; the largest one determines how many
# slices create_sequences_from_pdf() cuts per reference.
BATCH_SIZES = [4, 5, 8, 9, 16, 17]
# 1-based layer numbers: the extraction helpers read
# past_key_values[layer_idx - 1].
LAYER_INDICES = [28]
# Upper bound on decode steps per run (the prompt forward pass is step 0).
MAX_NEW_TOKENS = 20
# Length, in tokens, of every slice cut from the concatenated PDF tokens.
TOKENS_PER_SLICE = 512


def load_pdf_text(pdf_path):
    """
    Return the text of every page of a PDF as one whitespace-trimmed string.

    Pages whose extraction yields nothing are skipped; the remaining page
    texts are separated by a single space.
    """
    with open(pdf_path, 'rb') as handle:
        extracted = (page.extract_text() for page in PyPDF2.PdfReader(handle).pages)
        # Consume the generator while the file is still open.
        return " ".join(chunk for chunk in extracted if chunk).strip()


def create_sequences_from_pdf(tokenizer, num_references=3):
    """
    Build reference and dummy text slices from the PDFs found on disk.

    PDFs are looked up next to this script (or in the working directory when
    __file__ is undefined, e.g. in a notebook), then in the working directory,
    then in /workspace; the first location with matches wins. All PDFs are
    tokenized, concatenated, and cut into TOKENS_PER_SLICE-token slices that
    are decoded back to text.

    Args:
        tokenizer: Object providing encode(text, add_special_tokens=...) and
            decode(token_ids).
        num_references: Number of reference sequences to create.

    Returns:
        (reference_sequences, dummy_sets): two dicts keyed "ref_<i>". The first
        maps to one slice of text; the second to the list of the following
        max(BATCH_SIZES) - 1 slices.

    Raises:
        FileNotFoundError: No PDF was found in any searched location.
        ValueError: The PDFs do not contain enough tokens for all slices.
    """
    try:
        script_dir = os.path.dirname(os.path.abspath(__file__))
    except NameError:
        script_dir = os.getcwd()  # Jupyter notebook fallback

    search_patterns = (
        os.path.join(script_dir, "*.pdf"),
        "*.pdf",
        "/workspace/*.pdf",
    )
    pdf_files = []
    for pattern in search_patterns:
        # glob order depends on the filesystem; sort so the same files always
        # produce the same references and dummies.
        pdf_files = sorted(glob.glob(pattern))
        if pdf_files:
            break
    if not pdf_files:
        raise FileNotFoundError("No PDF file found in current directory or /workspace")

    log_print(f"Found {len(pdf_files)} PDF(s)")

    # Load and tokenize all PDFs into one long token stream
    all_tokens = []
    for pdf_path in pdf_files:
        log_print(f"  Loading: {pdf_path}")
        tokens = tokenizer.encode(load_pdf_text(pdf_path), add_special_tokens=True)
        all_tokens.extend(tokens)
        log_print(f"    → {len(tokens)} tokens")

    log_print(f"Total tokens: {len(all_tokens)}")

    # Every reference needs itself plus (max batch size - 1) dummies
    max_batch_size = max(BATCH_SIZES)
    slices_needed = num_references * max_batch_size
    tokens_needed = slices_needed * TOKENS_PER_SLICE

    if len(all_tokens) < tokens_needed:
        raise ValueError(f"PDFs too short. Need {tokens_needed} tokens ({slices_needed} slices × {TOKENS_PER_SLICE} tokens) but only have {len(all_tokens)} tokens")

    log_print(f"Creating {slices_needed} slices of {TOKENS_PER_SLICE} tokens each")

    slices = [
        tokenizer.decode(all_tokens[start:start + TOKENS_PER_SLICE])
        for start in range(0, tokens_needed, TOKENS_PER_SLICE)
    ]

    # Slice layout per reference: [reference, dummy_1, ..., dummy_(max_bs - 1)]
    reference_sequences = {}
    dummy_sets = {}
    for ref_idx in range(num_references):
        ref_name = f"ref_{ref_idx}"
        base_idx = ref_idx * max_batch_size
        reference_sequences[ref_name] = slices[base_idx]
        dummy_sets[ref_name] = slices[base_idx + 1 : base_idx + max_batch_size]

    return reference_sequences, dummy_sets

# ============================================================================
# SYSTEM INFO
# ============================================================================

def collect_system_info():
    """
    Return a dict describing the host, Python, library and GPU environment.

    CUDA-related fields fall back to "N/A" (or 0 for the GPU count) when
    torch reports that CUDA is unavailable.
    """
    import transformers

    has_cuda = torch.cuda.is_available()

    return {
        "hostname": socket.gethostname(),
        "platform": platform.platform(),
        "python_version": sys.version.split()[0],
        "torch_version": torch.__version__,
        "cuda_version": torch.version.cuda if has_cuda else "N/A",
        "transformers_version": transformers.__version__,
        "gpu_name": torch.cuda.get_device_name(0) if has_cuda else "N/A",
        "gpu_count": torch.cuda.device_count() if has_cuda else 0,
    }

# ============================================================================
# EXTRACTION
# ============================================================================

def extract_signals_from_output(outputs, layer_indices, position=-1):
    """
    Collect key vectors and top-10 log-probabilities for batch element 0.

    Args:
        outputs: Model output exposing `past_key_values` and `logits`.
            Assumes each cache entry is a (key, value) pair whose key tensor
            is laid out (batch, heads, seq, head_dim) - TODO confirm.
        layer_indices: 1-based layer numbers to read keys from.
        position: Sequence position to read (default: last).

    Returns:
        {'key_vectors': {'layer_<n>': flat list of floats},
         'logprobs': {'token_ids': [...], 'log_probs': [...]}} where the
        logprobs entry holds the 10 highest log-probabilities and their ids.
    """
    key_vectors = {}
    for layer_number in layer_indices:
        cached_keys = outputs.past_key_values[layer_number - 1][0]
        # Flatten the per-head keys of element 0 at `position` into one vector
        flat_keys = cached_keys[0, :, position, :].reshape(-1)
        key_vectors[f'layer_{layer_number}'] = flat_keys.cpu().clone().float().numpy().tolist()

    position_log_probs = F.log_softmax(outputs.logits[0, position, :], dim=-1)
    best_values, best_ids = torch.topk(position_log_probs, k=10)

    return {
        'key_vectors': key_vectors,
        'logprobs': {
            'token_ids': best_ids.cpu().tolist(),
            'log_probs': best_values.cpu().tolist(),
        },
    }


def extract_signals_for_token_ids(outputs, layer_indices, token_ids, position=-1):
    """
    Extract key vectors and logprobs for SPECIFIC token IDs.
    Used in prefill to compare the same tokens as decode.

    Args:
        outputs: Model output exposing `past_key_values` and `logits`.
            Assumes each cache entry is a (key, value) pair whose key tensor
            is laid out (batch, heads, seq, head_dim) - TODO confirm.
        layer_indices: 1-based layer numbers to read keys from.
        token_ids: List of token IDs to extract logprobs for (from decode's
            top-10). May be empty, in which case no logprobs are returned.
        position: Sequence position to read (default: last).

    Returns:
        {'key_vectors': {'layer_<n>': flat list of floats},
         'logprobs': {'token_ids': token_ids, 'log_probs': [...]}} with the
        log-probabilities listed in the same order as `token_ids`.
    """
    signals = {
        'key_vectors': {},
        'logprobs': {}
    }

    # Key vectors from element 0, flattened across heads
    for layer_idx in layer_indices:
        layer_keys = outputs.past_key_values[layer_idx - 1][0]
        token_keys = layer_keys[0, :, position, :]
        key_vector = token_keys.reshape(-1).cpu().clone()
        signals['key_vectors'][f'layer_{layer_idx}'] = key_vector.float().numpy().tolist()

    logits = outputs.logits[0, position, :]
    log_probs = F.log_softmax(logits, dim=-1)

    # Explicit integer dtype: without it an empty list becomes a float tensor,
    # which is not a valid index.
    token_ids_tensor = torch.tensor(token_ids, dtype=torch.long, device=logits.device)
    selected_logprobs = log_probs[token_ids_tensor]

    signals['logprobs'] = {
        'token_ids': token_ids,  # Same as decode
        'log_probs': selected_logprobs.cpu().tolist()
    }

    return signals

# ============================================================================
# DECODE GENERATION & EXTRACTION
# ============================================================================

def compute_min_length_across_batches(ref_text, ref_name, tokenizer, batch_sizes):
    """
    Find the shortest tokenized length over every batch configuration.

    For each batch size the batch is the reference text followed by the first
    batch_size - 1 dummies of `ref_name`; the result is the smallest token
    count seen in any of those batches, so all batch sizes can be truncated
    to one common length. Returns float('inf') when `batch_sizes` is empty.
    """
    dummies = DUMMY_SETS[ref_name]
    shortest = float('inf')

    for size in batch_sizes:
        texts = [ref_text] if size == 1 else [ref_text] + dummies[:size - 1]
        lengths = (len(tokenizer.encode(text, add_special_tokens=True)) for text in texts)
        shortest = min(shortest, min(lengths))

    return shortest


def run_decode_with_extraction(model, tokenizer, ref_text, ref_name, batch_size, 
                                layer_indices, forced_length=None):
    """
    Greedy-decode a batch with a KV cache and record signals for element 0.

    The batch is the reference text plus the first batch_size - 1 dummies of
    `ref_name`. Everything after the single tokenization happens at token-ID
    level (no decode/re-encode cycles). Generation stops after
    MAX_NEW_TOKENS steps, or earlier when element 0 emits EOS on a cached
    step (the token produced by the prompt forward pass is not checked).

    Args:
        model: Callable returning an object with `logits` and
            `past_key_values` (presumably a Hugging Face causal LM on CUDA).
        tokenizer: Provides encode(...) and `eos_token_id`.
        ref_text: Text placed at batch position 0.
        ref_name: Key into DUMMY_SETS selecting the neighbor texts.
        batch_size: Number of rows in the batch.
        layer_indices: 1-based layer numbers passed to the signal extractor.
        forced_length: Prompt length to truncate every row to; defaults to
            the shortest row in this batch.

    Returns:
        Dict with 'generated_ids' (element 0), 'all_batch_generated_ids',
        'prompt_token_ids' (exact truncated prompt IDs for every row),
        'prompt_length', 'signals' (up to the last three steps, keyed
        'pos_-3' / 'pos_-2' / 'pos_-1') and 'num_generated'.

    Raises:
        ValueError: Truncation leaves rows of unequal length (forced_length
            exceeds the token count of some row).
    """
    torch.cuda.empty_cache()

    # Build batch texts: reference first, then its dummies
    ref_dummies = DUMMY_SETS[ref_name]
    if batch_size == 1:
        batch_texts = [ref_text]
    else:
        batch_texts = [ref_text] + ref_dummies[:batch_size-1]

    # Tokenize ONCE - store the exact token IDs
    all_token_ids = [tokenizer.encode(t, add_special_tokens=True) for t in batch_texts]

    # Truncate at token ID level
    if forced_length is not None:
        min_length = forced_length
    else:
        min_length = min(len(ids) for ids in all_token_ids)

    truncated_token_ids = [ids[:min_length] for ids in all_token_ids]

    # A ragged batch cannot become a tensor; fail with a readable message
    # instead of the generic tensor-construction error.
    row_lengths = {len(ids) for ids in truncated_token_ids}
    if len(row_lengths) > 1:
        raise ValueError(
            f"Cannot build a rectangular batch: truncating to {min_length} tokens "
            f"leaves rows of lengths {sorted(row_lengths)}"
        )

    # Build input tensors DIRECTLY from token IDs - no decode/re-encode
    input_ids = torch.tensor(truncated_token_ids, dtype=torch.long, device='cuda')
    attention_mask = torch.ones_like(input_ids)

    prompt_length = input_ids.shape[1]
    log_print(f"      Prompt length: {prompt_length} tokens", end="")

    # Track generation for ALL batch positions
    all_batch_generated_ids = [[] for _ in range(batch_size)]
    generation_signals = []

    # FIRST STEP: forward pass over the full prompt
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True)

    past_kv = outputs.past_key_values

    # Greedy choice of the first generated token for every row
    next_tokens = outputs.logits[:, -1, :].argmax(dim=-1)
    step_tokens = next_tokens.tolist()  # one host transfer per step
    for batch_idx in range(batch_size):
        all_batch_generated_ids[batch_idx].append(step_tokens[batch_idx])

    generation_signals.append({
        'step': 0,
        # The last prompt token is the position whose signals are read
        'absolute_position': prompt_length - 1,
        'signals': extract_signals_from_output(outputs, layer_indices, position=-1)
    })

    # Grow the mask by one column; new_ones keeps its dtype and device so the
    # mask is not promoted to float by the concatenation.
    attention_mask = torch.cat(
        [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=1
    )

    # SUBSEQUENT STEPS: true autoregressive decoding with the cache
    for step in range(1, MAX_NEW_TOKENS):
        with torch.no_grad():
            outputs = model(
                input_ids=next_tokens.unsqueeze(1),
                attention_mask=attention_mask,
                past_key_values=past_kv,
                use_cache=True,
            )

        past_kv = outputs.past_key_values

        next_tokens = outputs.logits[:, -1, :].argmax(dim=-1)
        step_tokens = next_tokens.tolist()
        for batch_idx in range(batch_size):
            all_batch_generated_ids[batch_idx].append(step_tokens[batch_idx])

        # Absolute position in the full sequence = newest cache entry
        current_cache_length = past_kv[0][0].shape[2]

        generation_signals.append({
            'step': step,
            'absolute_position': current_cache_length - 1,
            'signals': extract_signals_from_output(outputs, layer_indices, position=-1)
        })

        attention_mask = torch.cat(
            [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=1
        )

        # Stop once element 0 has produced EOS
        if all_batch_generated_ids[0][-1] == tokenizer.eos_token_id:
            break

    # Keep at most the last three steps, oldest first
    num_generated = len(generation_signals)
    last_3_signals = {
        f'pos_{-back}': generation_signals[-back]
        for back in range(min(3, num_generated), 0, -1)
    }

    # Drop GPU references before asking the allocator to release memory
    del outputs
    del past_kv
    del input_ids
    del attention_mask
    torch.cuda.empty_cache()

    final_length = prompt_length + num_generated
    log_print(f" → Final: {final_length} tokens ({num_generated} generated)")

    return {
        'generated_ids': all_batch_generated_ids[0],  # Position 0 tokens
        'all_batch_generated_ids': all_batch_generated_ids,  # ALL positions
        'prompt_token_ids': truncated_token_ids,  # EXACT token IDs used in prompt
        'prompt_length': prompt_length,
        'signals': last_3_signals,
        'num_generated': num_generated
    }

# ============================================================================
# PREFILL REPRODUCTION
# ============================================================================

def run_prefill_verification(model, tokenizer, ref_name, decode_metadata, 
                             batch_size, layer_indices, is_diagonal=False):
    """
    Re-run decode's sequence as one forward pass and extract the same signals.

    Element 0 is the exact prompt IDs from decode followed by the IDs decode
    generated. Signals are read at the absolute positions recorded by decode,
    and logprobs are looked up for decode's token IDs at each position.

    Args:
        model: Callable returning an object with `logits` and
            `past_key_values` (presumably a Hugging Face causal LM on CUDA).
        tokenizer: Used only to encode dummy texts for off-diagonal batches.
        ref_name: Key into DUMMY_SETS.
        decode_metadata: Result dict of run_decode_with_extraction().
        batch_size: Number of rows in the prefill batch.
        layer_indices: 1-based layer numbers passed to the signal extractor.
        is_diagonal: True if same batch size as decode; neighbors are then
            decode's own prompt + generated IDs. Otherwise neighbors are
            dummy texts truncated or repeated to the reference length.

    Returns:
        Dict keyed like decode_metadata['signals']; each value holds
        'absolute_position' and the extracted 'signals'.
    """
    torch.cuda.empty_cache()

    # Element 0: decode's exact prompt plus what it generated, at ID level
    extended_ref_ids = decode_metadata['prompt_token_ids'][0] + decode_metadata['generated_ids']
    log_print(f" [Ext: {len(extended_ref_ids)} tokens", end="")

    batch_ids = [extended_ref_ids]
    if batch_size != 1:
        if is_diagonal:
            # Reuse the exact neighbor sequences decode ran with
            log_print(", exact neighbors", end="")
            for row in range(1, batch_size):
                batch_ids.append(
                    decode_metadata['prompt_token_ids'][row]
                    + decode_metadata['all_batch_generated_ids'][row]
                )
        else:
            # Arbitrary neighbors, length-matched to the reference sequence
            log_print(", arb neighbors", end="")
            ref_dummies = DUMMY_SETS[ref_name]
            target_length = len(extended_ref_ids)

            for dummy_index in range(batch_size - 1):
                filler_ids = tokenizer.encode(ref_dummies[dummy_index], add_special_tokens=True)
                if len(filler_ids) < target_length:
                    # Too short: tile the tokens until the target is covered
                    copies = (target_length // len(filler_ids)) + 1
                    filler_ids = filler_ids * copies
                batch_ids.append(filler_ids[:target_length])

    # Tensors come straight from the IDs - no decode/encode cycle
    input_ids = torch.tensor(batch_ids, dtype=torch.long, device='cuda')
    attention_mask = torch.ones_like(input_ids)

    log_print("]", end="")

    # One forward pass over the whole extended batch
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True)

    # Read the same positions decode recorded, scoring decode's token IDs
    prefill_signals = {}
    extraction_positions = []
    for pos_label, decode_step in decode_metadata['signals'].items():
        abs_pos = decode_step['absolute_position']
        wanted_token_ids = decode_step['signals']['logprobs']['token_ids']
        extraction_positions.append(abs_pos)

        prefill_signals[pos_label] = {
            'absolute_position': abs_pos,
            'signals': extract_signals_for_token_ids(
                outputs, layer_indices, wanted_token_ids, position=abs_pos
            ),
        }

    log_print(f" → Extract at: {extraction_positions}")

    del outputs
    del input_ids
    del attention_mask
    torch.cuda.empty_cache()

    return prefill_signals

# ============================================================================
# ANALYSIS
# ============================================================================

def compute_l2_distance(vec1, vec2):
    """Return the Euclidean (L2) distance between two vectors as a float."""
    difference = np.subtract(np.array(vec1), np.array(vec2))
    return float(np.linalg.norm(difference))


def compute_logprob_distance(logprobs1, logprobs2):
    """
    Compute L2 distance between two sets of log-probabilities.

    Each argument is a dict with 'log_probs' and, optionally, 'token_ids'.
    When both sides carry identical 'token_ids' (or either side has none),
    the values are compared position by position. When the id lists differ,
    values are first aligned on the token ids the two sides share, so a
    log-probability is never compared against that of a different token;
    tokens present on only one side are left out. If the sides share no
    token at all, the positional comparison is used as a last resort.

    Returns:
        The distance as a Python float.
    """
    values1 = logprobs1['log_probs']
    values2 = logprobs2['log_probs']
    ids1 = logprobs1.get('token_ids')
    ids2 = logprobs2.get('token_ids')

    if ids1 is not None and ids2 is not None and list(ids1) != list(ids2):
        # Pair up values belonging to the same token id
        by_token = dict(zip(ids2, values2))
        shared = [(lp, by_token[tok]) for tok, lp in zip(ids1, values1) if tok in by_token]
        if shared:
            values1, values2 = zip(*shared)

    return float(np.linalg.norm(np.array(values1) - np.array(values2)))


def compare_signals_generic(signals1, signals2, layer_indices):
    """
    Compare two signal dicts over the position labels they have in common.

    Works for decode-decode and prefill-prefill alike. Each per-position
    entry may either be the signals dict itself or wrap it under a 'signals'
    key; the layout of `signals1` decides how both sides are unwrapped.
    `layer_indices` is accepted for interface symmetry and not consulted;
    the layers compared are those present in the first side's key vectors.

    Returns:
        Dict with the max and mean key-vector distance and the max and mean
        logprob distance; every value is 0.0 when nothing was comparable.
    """
    key_distances = []
    logprob_distances = []

    for label in set(signals1.keys()) & set(signals2.keys()):
        wrapped = 'signals' in signals1[label]
        first = signals1[label]['signals'] if wrapped else signals1[label]
        second = signals2[label]['signals'] if wrapped else signals2[label]

        # One distance per layer present on the first side
        key_distances.extend(
            compute_l2_distance(first['key_vectors'][name], second['key_vectors'][name])
            for name in first['key_vectors'].keys()
        )
        logprob_distances.append(
            compute_logprob_distance(first['logprobs'], second['logprobs'])
        )

    return {
        'key_vectors_max': max(key_distances) if key_distances else 0.0,
        'key_vectors_mean': np.mean(key_distances) if key_distances else 0.0,
        'logprobs_max': max(logprob_distances) if logprob_distances else 0.0,
        'logprobs_mean': np.mean(logprob_distances) if logprob_distances else 0.0
    }


def compare_decode_vs_prefill(decode_data, prefill_data, layer_indices):
    """Return the distance summary between decode's signals and a prefill reproduction."""
    decode_signals = decode_data['signals']
    return compare_signals_generic(decode_signals, prefill_data, layer_indices)


def check_token_consistency(decode_measurements, tokenizer):
    """
    Log element 0's generated tokens per batch size and report agreement.

    The smallest batch size serves as the baseline; every batch size is
    printed with its token IDs, decoded text and a match marker.

    Args:
        decode_measurements: Mapping of batch size to a dict containing
            'generated_ids'.
        tokenizer: Provides decode(token_ids) for the text display.

    Returns:
        True when all batch sizes produced the baseline's tokens.
    """
    log_print("\n" + "=" * 80)
    log_print("TOKEN GENERATION CONSISTENCY CHECK")
    log_print("=" * 80)

    tokens_by_bs = {bs: data['generated_ids'] for bs, data in decode_measurements.items()}
    ordered_sizes = sorted(tokens_by_bs.keys())
    baseline = tokens_by_bs[ordered_sizes[0]]

    mismatch_seen = False
    log_print("\nGenerated tokens by batch size:")
    for bs in ordered_sizes:
        tokens = tokens_by_bs[bs]
        matches = tokens == baseline
        marker = "✓" if matches else "✗ DIFFERENT"
        decoded_text = tokenizer.decode(tokens)
        log_print(f"  bs={bs}:")
        log_print(f"    IDs:  {tokens}")
        log_print(f"    Text: {repr(decoded_text)}")
        log_print(f"    {marker}")
        if not matches:
            mismatch_seen = True

    if mismatch_seen:
        log_print("\n⚠ Element 0 generates DIFFERENT tokens across batch sizes")
        log_print("  → Batch composition affects generation")
    else:
        log_print("\n✓ Element 0 generates IDENTICAL tokens across all batch sizes")
        log_print("  → Can meaningfully compare activations for same token sequence")

    return not mismatch_seen


def analyze_within_mode_batch_effects(measurements, layer_indices):
    """
    Compare activations across batch sizes WITHIN each mode.

    For every reference, decode signals are compared between all pairs of
    decode batch sizes, and diagonal prefill signals between all pairs of
    prefill batch sizes. Per-pair distances, aggregate statistics, exact-zero
    diagnostics and a per-mode conclusion are logged.

    Args:
        measurements: List of dicts with 'ref_name', 'decode_batch_size',
            'prefill_batch_size', 'is_diagonal', 'decode_data' (containing
            'signals') and 'prefill_data'.
        layer_indices: Forwarded to compare_signals_generic().

    Returns:
        Dict keyed by mode ('decode' / 'prefill', present only when that mode
        had at least one comparison) holding mean/max key-vector and logprob
        distances plus the list of individual comparisons.
    """
    log_print("\n" + "=" * 80)
    log_print("WITHIN-MODE BATCH SIZE EFFECTS")
    log_print("=" * 80)
    log_print("\nSanity check: Does batch size affect activations even with identical tokens?")
    log_print("Expected: YES - batch size changes computation even for same sequence\n")

    # DEBUG: describe what was received
    log_print(f"DEBUG: Total measurements received: {len(measurements)}")
    if measurements:
        first = measurements[0]
        log_print(f"DEBUG: First measurement keys: {list(first.keys())}")
        log_print(f"DEBUG: First measurement decode_batch_size: {first.get('decode_batch_size')}")
        log_print(f"DEBUG: First measurement ref_name: {first.get('ref_name')}")
        signal_keys = list(first.get('decode_data', {}).get('signals', {}).keys())
        log_print(f"DEBUG: decode_data signals keys: {signal_keys}")

    # Group signals by reference, then by mode and batch size
    by_ref = {}
    for m in measurements:
        bucket = by_ref.setdefault(m['ref_name'], {'decode': {}, 'prefill_diagonal': {}})
        bucket['decode'][m['decode_batch_size']] = m['decode_data']['signals']
        # Only diagonal prefill runs used decode's real neighbors
        if m['is_diagonal']:
            bucket['prefill_diagonal'][m['prefill_batch_size']] = m['prefill_data']

    results = {'decode': [], 'prefill': []}

    for ref_name in sorted(by_ref.keys()):
        decode_signals = by_ref[ref_name]['decode']
        prefill_signals = by_ref[ref_name]['prefill_diagonal']

        log_print(f"\n{ref_name.upper()}")
        log_print("-" * 80)
        log_print(f"  DEBUG: decode has {len(decode_signals)} batch sizes: {sorted(decode_signals.keys())}")
        log_print(f"  DEBUG: prefill_diagonal has {len(prefill_signals)} batch sizes: {sorted(prefill_signals.keys())}")

        log_print("\nDECODE mode (autoregressive):")
        results['decode'].extend(
            _log_pairwise_batch_distances(decode_signals, ref_name, layer_indices)
        )

        log_print("\nPREFILL mode (parallel forward):")
        results['prefill'].extend(
            _log_pairwise_batch_distances(prefill_signals, ref_name, layer_indices)
        )

    # Aggregate statistics
    log_print("\n" + "=" * 80)
    log_print("WITHIN-MODE AGGREGATE STATISTICS")
    log_print("=" * 80)

    summary = {}
    for mode in ('decode', 'prefill'):
        mode_results = results[mode]
        if not mode_results:
            continue

        key_dists = [record['key_distance'] for record in mode_results]
        logprob_dists = [record['logprob_distance'] for record in mode_results]
        key_mean, key_max = np.mean(key_dists), max(key_dists)
        logprob_mean, logprob_max = np.mean(logprob_dists), max(logprob_dists)

        log_print(f"\n{mode.upper()} mode:")
        log_print(f"  Key vectors: μ = {key_mean:.2e}, max = {key_max:.2e}")
        log_print(f"  Logprobs:    μ = {logprob_mean:.2e}, max = {logprob_max:.2e}")

        # Exact zeros are reported separately from merely small distances
        key_zeros = key_dists.count(0.0)
        logprob_zeros = logprob_dists.count(0.0)
        if key_zeros > 0:
            log_print(f"  ⚠ {key_zeros}/{len(key_dists)} key comparisons are EXACTLY ZERO")
        if logprob_zeros > 0:
            log_print(f"  ⚠ {logprob_zeros}/{len(logprob_dists)} logprob comparisons are EXACTLY ZERO")
        if key_zeros > 0 or logprob_zeros > 0:
            _log_zero_distance_locations(mode_results)

        summary[mode] = {
            'key_vectors_mean': float(key_mean),
            'key_vectors_max': float(key_max),
            'logprobs_mean': float(logprob_mean),
            'logprobs_max': float(logprob_max),
            'comparisons': mode_results
        }

    # Conclusion
    log_print("\n" + "=" * 80)
    log_print("CONCLUSION: WITHIN-MODE BATCH EFFECTS")
    log_print("=" * 80)

    threshold = 1e-10  # Essentially non-zero

    for mode in ('decode', 'prefill'):
        if mode not in summary:
            continue

        key_mean = summary[mode]['key_vectors_mean']
        log_mean = summary[mode]['logprobs_mean']
        keys_differ = key_mean > threshold
        logprobs_differ = log_mean > threshold

        log_print(f"\n{mode.upper()}:")
        if not (keys_differ or logprobs_differ):
            log_print("  ✗ Batch size does NOT affect computation (both zero)")
            log_print("    This would invalidate the experiment premise!")
            continue

        if keys_differ and logprobs_differ:
            headline = "  ✓ Batch size DOES affect computation (both keys and logprobs differ)"
        elif keys_differ:
            headline = "  ~ Batch size affects keys but NOT logprobs"
        else:
            headline = "  ~ Batch size affects logprobs but NOT keys"
        log_print(headline)
        log_print(f"    Keys: {key_mean:.2e}, Logprobs: {log_mean:.2e}")

    return summary


def _log_pairwise_batch_distances(signals_by_bs, ref_name, layer_indices):
    """Compare every pair of batch sizes for one reference; log and return the max distances."""
    sizes = sorted(signals_by_bs.keys())
    records = []
    for idx, bs1 in enumerate(sizes):
        for bs2 in sizes[idx + 1:]:
            distances = compare_signals_generic(
                signals_by_bs[bs1], signals_by_bs[bs2], layer_indices
            )

            log_print(f"  bs={bs1} vs bs={bs2}:")
            log_print(f"    Key vectors: Δ_max = {distances['key_vectors_max']:.2e}")
            log_print(f"    Logprobs:    Δ_max = {distances['logprobs_max']:.2e}")

            records.append({
                'ref': ref_name,
                'bs1': bs1,
                'bs2': bs2,
                'key_distance': distances['key_vectors_max'],
                'logprob_distance': distances['logprobs_max']
            })
    return records


def _log_zero_distance_locations(mode_results):
    """Log which batch-size pairs gave exactly-zero distances, per reference and across references."""
    zeros_by_ref = {}
    for record in mode_results:
        entry = zeros_by_ref.setdefault(record['ref'], {'key': [], 'logprob': []})
        pair = (record['bs1'], record['bs2'])
        if record['key_distance'] == 0.0:
            entry['key'].append(pair)
        if record['logprob_distance'] == 0.0:
            entry['logprob'].append(pair)

    log_print("\n  Zero locations by reference:")
    all_refs = sorted(zeros_by_ref.keys())
    for ref in all_refs:
        key_pairs = zeros_by_ref[ref]['key']
        log_pairs = zeros_by_ref[ref]['logprob']
        if not (key_pairs or log_pairs):
            continue
        log_print(f"    {ref}:")
        if key_pairs:
            log_print(f"      key zeros:     {sorted(key_pairs)}")
        if log_pairs:
            log_print(f"      logprob zeros: {sorted(log_pairs)}")

    # Cross-reference consistency needs at least two references
    if len(all_refs) < 2:
        return

    key_sets = [set(zeros_by_ref[ref]['key']) for ref in all_refs]
    log_sets = [set(zeros_by_ref[ref]['logprob']) for ref in all_refs]

    # A pair counts as consistent only if it is zero for every reference
    key_intersection = set.intersection(*key_sets) if all(key_sets) else set()
    log_intersection = set.intersection(*log_sets) if all(log_sets) else set()

    log_print("\n  Cross-reference consistency:")
    if key_intersection:
        log_print(f"    Key zeros consistent across ALL refs: {sorted(key_intersection)}")
    else:
        log_print("    Key zeros: NO pairs are zero across all refs (coincidental)")
    if log_intersection:
        log_print(f"    Logprob zeros consistent across ALL refs: {sorted(log_intersection)}")
    else:
        log_print("    Logprob zeros: NO pairs are zero across all refs (coincidental)")


def _log_distance_matrix(matrix):
    """Log one decode-by-prefill distance matrix as a text table.

    Args:
        matrix: 2-D array indexed ``[decode_idx, prefill_idx]``, with both
            axes following the order of ``BATCH_SIZES``.
    """
    log_print("                " + "".join([f"Prefill bs={bs:>3}  " for bs in BATCH_SIZES]))
    for decode_bs, row in zip(BATCH_SIZES, matrix):
        cells = "".join(f"  {value:11.2e}" for value in row)
        log_print(f"Decode bs={decode_bs:<4}  " + cells)


def analyze_experiment(measurements, layer_indices):
    """Compute and display decode vs prefill comparison matrices.

    For every reference sequence a ``len(BATCH_SIZES) x len(BATCH_SIZES)``
    matrix is built per signal (key vectors, logprobs); cell ``[i, j]`` holds
    the max distance reported by ``compare_decode_vs_prefill`` for decode batch
    size ``BATCH_SIZES[i]`` and prefill batch size ``BATCH_SIZES[j]``.  The
    matrices are then averaged across references; the diagonal is treated as
    noise, the off-diagonal as signal, and their ratio (SNR) is compared
    against a fixed threshold to print a detectability verdict.

    Args:
        measurements: Iterable of dicts, each providing ``'ref_name'``,
            ``'decode_batch_size'``, ``'prefill_batch_size'``,
            ``'decode_data'`` and ``'prefill_data'``.
        layer_indices: Forwarded unchanged to ``compare_decode_vs_prefill``.

    Returns:
        dict with two JSON-serializable entries:
            ``'matrices'``: per-reference matrices (nested lists) under
                ``'key_vectors'`` and ``'logprobs'``, in sorted ref-name order.
            ``'statistics'``: per signal type, the averaged ``'matrix'`` plus
                ``'noise_mean'``, ``'noise_std'``, ``'signal_mean'``,
                ``'signal_std'`` and ``'snr'``.

    Raises:
        ValueError: If ``measurements`` is empty (there is nothing to average).
    """
    log_print("\n" + "="*80)
    log_print("ANALYSIS: DECODE vs PREFILL MATRIX")
    log_print("="*80)

    # Group by reference: ref_name -> {(decode_bs, prefill_bs): measurement}
    by_ref = {}
    for m in measurements:
        pair = (m['decode_batch_size'], m['prefill_batch_size'])
        by_ref.setdefault(m['ref_name'], {})[pair] = m

    if not by_ref:
        # Without this guard np.mean([]) yields a scalar NaN and the function
        # fails later with an unrelated-looking IndexError.
        raise ValueError("analyze_experiment requires at least one measurement")

    n_bs = len(BATCH_SIZES)
    all_matrices = {'key_vectors': [], 'logprobs': []}

    # Compute matrices for each reference
    for ref_name in sorted(by_ref):
        log_print(f"\n{'='*80}")
        log_print(f"{ref_name.upper()}")
        log_print(f"{'='*80}")

        ref_data = by_ref[ref_name]
        matrix_key = np.zeros((n_bs, n_bs))
        matrix_logprob = np.zeros((n_bs, n_bs))

        for i, decode_bs in enumerate(BATCH_SIZES):
            for j, prefill_bs in enumerate(BATCH_SIZES):
                m = ref_data.get((decode_bs, prefill_bs))
                if m is None:
                    # The cell stays 0.0 (as before), but say so: a silent zero
                    # drags the noise/signal means down and skews the SNR.
                    log_print(f"  WARNING: no measurement for decode bs={decode_bs}, "
                              f"prefill bs={prefill_bs}; cell left at 0")
                    continue
                distances = compare_decode_vs_prefill(
                    m['decode_data'],
                    m['prefill_data'],
                    layer_indices
                )
                matrix_key[i, j] = distances['key_vectors_max']
                matrix_logprob[i, j] = distances['logprobs_max']

        log_print("\nKey Vectors (max L2 distance):")
        _log_distance_matrix(matrix_key)

        log_print("\nLogprobs (max L2 distance):")
        _log_distance_matrix(matrix_logprob)

        all_matrices['key_vectors'].append(matrix_key)
        all_matrices['logprobs'].append(matrix_logprob)

    # Aggregate statistics
    log_print("\n" + "="*80)
    log_print("AGGREGATE STATISTICS (AVERAGE ACROSS REFERENCES)")
    log_print("="*80)

    results = {}
    off_diagonal_mask = ~np.eye(n_bs, dtype=bool)

    for signal_type in ['key_vectors', 'logprobs']:
        avg_matrix = np.mean(all_matrices[signal_type], axis=0)

        log_print(f"\n{signal_type.upper()}:")
        _log_distance_matrix(avg_matrix)

        # Diagonal = matched batch sizes (noise floor);
        # off-diagonal = mismatched batch sizes (signal).
        diagonal = np.diagonal(avg_matrix).copy()
        off_diagonal = avg_matrix[off_diagonal_mask]

        noise_mean = np.mean(diagonal)
        noise_std = np.std(diagonal)
        signal_mean = np.mean(off_diagonal)
        signal_std = np.std(off_diagonal)

        if noise_mean > 0:
            snr = signal_mean / noise_mean
        elif signal_mean > 0:
            # Real signal over a zero noise floor.
            snr = float('inf')
        else:
            # 0/0: nothing differs anywhere, so there is no signal to detect.
            # Reporting inf here would wrongly print "DETECTABLE" below.
            snr = 0.0

        log_print(f"\n  Diagonal (noise - decode bs = prefill bs):")
        log_print(f"    μ = {noise_mean:.2e}, σ = {noise_std:.2e}")
        log_print(f"    Values: {[f'{d:.2e}' for d in diagonal]}")

        log_print(f"\n  Off-diagonal (signal - decode bs ≠ prefill bs):")
        log_print(f"    μ = {signal_mean:.2e}, σ = {signal_std:.2e}")

        log_print(f"\n  SNR (signal/noise): {snr:.2f}×")

        results[signal_type] = {
            'matrix': avg_matrix.tolist(),
            'noise_mean': float(noise_mean),
            'noise_std': float(noise_std),
            'signal_mean': float(signal_mean),
            'signal_std': float(signal_std),
            'snr': float(snr)
        }

    # Conclusion
    log_print("\n" + "="*80)
    log_print("CONCLUSION")
    log_print("="*80)

    threshold = 1.5  # minimum signal/noise ratio counted as "detectable"

    key_snr = results['key_vectors']['snr']
    log_snr = results['logprobs']['snr']
    key_detectable = key_snr >= threshold
    log_detectable = log_snr >= threshold

    log_print(f"\nKey Vectors: SNR = {key_snr:.2f}× {'✓ DETECTABLE' if key_detectable else '✗ NOT DETECTABLE'}")
    log_print(f"Logprobs:    SNR = {log_snr:.2f}× {'✓ DETECTABLE' if log_detectable else '✗ NOT DETECTABLE'}")

    if key_detectable and log_detectable:
        log_print("\n✓ BATCH SIZE MISMATCHES ARE DETECTABLE")
        log_print("  → Prefill can detect decode batch size")
        log_print("  → Off-diagonal distances >> diagonal distances")
    elif key_detectable or log_detectable:
        log_print("\n~ PARTIAL DETECTABILITY")
        det = 'Key vectors' if key_detectable else 'Logprobs'
        log_print(f"  → {det} can detect batch size mismatches")
    else:
        log_print("\n✗ BATCH SIZE MISMATCHES NOT RELIABLY DETECTABLE")
        log_print("  → Off-diagonal distances similar to diagonal distances")
        log_print("  → Cannot reliably detect decode batch size from prefill")

    # Convert numpy arrays to lists for JSON serialization
    matrices_serializable = {
        signal_type: [matrix.tolist() for matrix in matrices]
        for signal_type, matrices in all_matrices.items()
    }

    return {
        'matrices': matrices_serializable,
        'statistics': results
    }

# ============================================================================
# MAIN
# ============================================================================

def _run_experiment(log_path):
    """Run the decode/prefill sweep, save the results as JSON, and analyze them.

    Steps, per reference sequence: decode once at every batch size in
    ``BATCH_SIZES``, check token consistency across those decodes, then
    re-run each decode through prefill at every batch size.  The raw
    measurements are written to disk before analysis so they survive an
    analysis failure; the file is rewritten afterwards with the analysis added.

    Args:
        log_path: Path of the already-open log file (only reported in the log).
    """
    global REFERENCE_SEQUENCES, DUMMY_SETS

    system_info = collect_system_info()

    log_print("="*80)
    log_print("PREFILL vs DECODE EXPERIMENT")
    log_print("="*80)
    log_print(f"Log file: {log_path}")
    log_print(f"\nEnvironment:")
    for k, v in system_info.items():
        log_print(f"  {k}: {v}")

    log_print(f"\nConfiguration:")
    log_print(f"  Model: {MODEL_NAME}")
    log_print(f"  Layers: {LAYER_INDICES}")
    log_print(f"  Batch sizes: {BATCH_SIZES}")
    log_print(f"  Max tokens: {MAX_NEW_TOKENS}")
    log_print()

    # Load model
    log_print("Loading model...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)

    # Initialize sequences from PDF
    REFERENCE_SEQUENCES, DUMMY_SETS = create_sequences_from_pdf(tokenizer)
    # Size of the first dummy set; 0 if there are no sets at all.
    dummies_per_ref = len(next(iter(DUMMY_SETS.values()), ()))
    log_print(f"Created {len(REFERENCE_SEQUENCES)} reference sequences with {dummies_per_ref} dummies each\n")

    log_print("Experiment design:")
    log_print(f"  - Decode runs at {len(BATCH_SIZES)} batch sizes")
    log_print(f"  - Prefill reproduces at {len(BATCH_SIZES)} batch sizes")
    log_print(f"  - Matrix: {len(BATCH_SIZES)} decode × {len(BATCH_SIZES)} prefill = {len(BATCH_SIZES)**2} comparisons per reference")
    log_print("  - Works purely at token ID level (no tokenization artifacts)")
    log_print("  - NEW: Within-mode batch size comparisons (sanity check)")
    log_print()

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        dtype=torch.bfloat16,
        cache_dir=CACHE_DIR,
        low_cpu_mem_usage=True,
        device_map="auto"
    )
    log_print("✓ Model loaded\n")

    # Results structure
    results = {
        'metadata': {
            'environment': system_info,
            'model': MODEL_NAME,
            'layer_indices': LAYER_INDICES,
            'batch_sizes': BATCH_SIZES,
            'max_new_tokens': MAX_NEW_TOKENS,
            'timestamp': datetime.now().isoformat()
        },
        'measurements': []
    }

    # Run experiments
    for ref_name, ref_text in REFERENCE_SEQUENCES.items():
        log_print(f"\n{'='*80}")
        log_print(f"REFERENCE: {ref_name}")
        log_print(f"{'='*80}")

        # Pre-compute minimum length across all batch configurations
        min_prompt_length = compute_min_length_across_batches(
            ref_text, ref_name, tokenizer, BATCH_SIZES
        )
        log_print(f"\nGlobal minimum prompt length: {min_prompt_length} tokens")
        log_print("(All batch sizes will use this length)\n")

        # Step 1: Run decode at each batch size
        log_print(f"Decode runs:")
        decode_measurements = {}

        for decode_bs in BATCH_SIZES:
            log_print(f"  bs={decode_bs}...", end=" ")
            decode_data = run_decode_with_extraction(
                model, tokenizer, ref_text, ref_name, decode_bs, LAYER_INDICES,
                forced_length=min_prompt_length
            )

            # Show extraction positions
            extract_positions = [step['absolute_position'] for step in decode_data['signals'].values()]
            log_print(f"    Extract positions: {extract_positions}")

            decode_measurements[decode_bs] = decode_data

        # Check token consistency
        check_token_consistency(decode_measurements, tokenizer)

        # Step 2: Run prefill reproduction at all batch sizes
        log_print(f"\nPrefill reproductions:")

        for decode_bs in BATCH_SIZES:
            decode_data = decode_measurements[decode_bs]
            log_print(f"\n  Reproducing decode bs={decode_bs}:")

            for prefill_bs in BATCH_SIZES:
                log_print(f"    with bs={prefill_bs}...", end="")

                # Diagonal = same batch size for decode and prefill
                is_diagonal = (decode_bs == prefill_bs)

                prefill_data = run_prefill_verification(
                    model, tokenizer, ref_name, decode_data, prefill_bs,
                    LAYER_INDICES, is_diagonal=is_diagonal
                )

                results['measurements'].append({
                    'ref_name': ref_name,
                    'decode_batch_size': decode_bs,
                    'prefill_batch_size': prefill_bs,
                    'is_diagonal': is_diagonal,
                    'decode_data': {
                        'generated_ids': decode_data['generated_ids'],
                        'all_batch_generated_ids': decode_data['all_batch_generated_ids'],
                        'prompt_length': decode_data['prompt_length'],
                        'signals': decode_data['signals'],
                        'num_generated': decode_data['num_generated']
                    },
                    'prefill_data': prefill_data
                })

    # Save raw data first so measurements are on disk even if analysis fails
    output_dir = '/workspace/experiments'
    os.makedirs(output_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filepath = os.path.join(output_dir, f"prefill_decode_{timestamp}.json")

    with open(filepath, 'w') as f:
        json.dump(results, f, indent=2)

    log_print(f"\n✓ Data saved to: {filepath}")

    # NEW: Analyze within-mode batch effects
    within_mode_results = analyze_within_mode_batch_effects(results['measurements'], LAYER_INDICES)

    # Run decode vs prefill analysis
    analysis_results = analyze_experiment(results['measurements'], LAYER_INDICES)

    # Rewrite the same file, now including both analyses
    results['analysis'] = {
        'within_mode_batch_effects': within_mode_results,
        'decode_vs_prefill': analysis_results
    }
    with open(filepath, 'w') as f:
        json.dump(results, f, indent=2)

    file_size_mb = os.path.getsize(filepath) / (1024 * 1024)
    log_print(f"✓ Analysis saved (file size: {file_size_mb:.1f} MB)")

    log_print(f"\n{'='*80}")
    log_print("EXPERIMENT COMPLETE")
    log_print(f"{'='*80}\n")


def main():
    """Entry point: open the log file, run the experiment, always close the log.

    The ``try``/``finally`` guarantees the log file handle is flushed and
    closed even when model loading, generation, or analysis raises; the
    exception itself still propagates to the caller.
    """
    # Setup logging to file
    log_path = setup_logging()
    try:
        _run_experiment(log_path)
    finally:
        close_logging()


if __name__ == "__main__":
    # Runs the same way as a script and as a notebook cell.  The former
    # get_ipython() probe discarded its result and fell through to main()
    # on both paths, so it was dead code and has been removed.
    main()

PREFILL vs DECODE EXPERIMENT
Log file: /workspace/experiments/experiment_log_20251124_164613.txt

Environment:
  hostname: c895fe420c2d
  platform: Linux-5.15.0-140-generic-x86_64-with-glibc2.35
  python_version: 3.10.18
  torch_version: 2.6.0+cu118
  cuda_version: 11.8
  transformers_version: 4.57.1
  gpu_name: NVIDIA A100-SXM4-80GB
  gpu_count: 1

Configuration:
  Model: Qwen/Qwen2.5-7B-Instruct
  Layers: [28]
  Batch sizes: [4, 5, 8, 9, 16, 17]
  Max tokens: 20

Loading model...
Found 3 PDF(s)
  Loading: /workspace/Verification-for-International-AI-Governance.pdf
    → 120214 tokens
  Loading: /workspace/Epoch_data.pdf
    → 29497 tokens
  Loading: /workspace/Llama3.1.pdf
    → 99282 tokens
Total tokens: 248993
Creating 51 slices of 512 tokens each
Created 3 reference sequences with 16 dummies each

Experiment design:
  - Decode runs at 6 batch sizes
  - Prefill reproduces at 6 batch sizes
  - Matrix: 6 decode × 6 prefill = 36 comparisons per reference
  - Works purely at token ID l

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

✓ Model loaded


REFERENCE: ref_0

Global minimum prompt length: 512 tokens
(All batch sizes will use this length)

Decode runs:
  bs=4...       Prompt length: 512 tokens → Final: 532 tokens (20 generated)
    Extract positions: [528, 529, 530]
  bs=5...       Prompt length: 512 tokens → Final: 532 tokens (20 generated)
    Extract positions: [528, 529, 530]
  bs=8...       Prompt length: 512 tokens → Final: 532 tokens (20 generated)
    Extract positions: [528, 529, 530]
  bs=9...       Prompt length: 512 tokens → Final: 532 tokens (20 generated)
    Extract positions: [528, 529, 530]
  bs=16...       Prompt length: 512 tokens → Final: 532 tokens (20 generated)
    Extract positions: [528, 529, 530]
  bs=17...       Prompt length: 512 tokens → Final: 532 tokens (20 generated)
    Extract positions: [528, 529, 530]

TOKEN GENERATION CONSISTENCY CHECK

Generated tokens by batch size:
  bs=4:
    IDs:  [1096, 1895, 40324, 279, 4650, 369, 6489, 22901, 198, 351, 57775, 311, 1824, 279, 8480