# In [1]:
#!/usr/bin/env python3
"""
Cross-Hardware Batch Size Detectability Experiment

Tests whether batch size claims can be verified across different GPU architectures
using floating-point forensics (key vectors and logprobs).

Workflow:
1. Run on Machine A (e.g., A100) with TEACHER_FORCING = False
   → Generates tokens, extracts signals, saves to JSON
2. Copy JSON to Machine B (e.g., H100)
3. Run on Machine B with TEACHER_FORCING = True
   → Teacher-forces A's tokens, extracts signals, compares

Matrix interpretation:
- Diagonal (claimed_bs == verify_bs): hardware-only difference (baseline)
- Off-diagonal (claimed_bs != verify_bs): hardware + batch size difference (signal)
- Detectability: Is off-diagonal > diagonal?

Usage:
    # Machine A: Generate reference
    python cross_hardware_batch_size.py
    
    # Machine B: Verify (edit TEACHER_FORCING and REFERENCE_FILE first)
    python cross_hardware_batch_size.py
"""

import os
os.environ['HF_HOME'] = '/workspace/huggingface_cache'
os.environ['TRANSFORMERS_CACHE'] = '/workspace/huggingface_cache'

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import torch.nn.functional as F
import numpy as np
from datetime import datetime
import json
import socket
import platform
import sys
import glob
import PyPDF2

# ============================================================================
# CONFIGURATION
# ============================================================================

# Toggle: False = generate reference, True = verify against reference
TEACHER_FORCING = True
REFERENCE_FILE = "A100_reference.json"  # Used when TEACHER_FORCING=True

MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"
# Same path as the HF_HOME / TRANSFORMERS_CACHE environment variables set at import time.
CACHE_DIR = '/workspace/huggingface_cache'
# Recorded by collect_system_info() and required to match between reference and verifier.
ATTN_IMPLEMENTATION = "flash_attention_2"  # Options: "eager", "sdpa", "flash_attention_2"

# Batch sizes under test; the largest one determines how many text slices are cut per reference.
BATCH_SIZES = [4, 5, 8, 9, 16, 17]
# 1-based layer numbers: signal extraction reads past_key_values[layer_idx - 1].
LAYER_INDICES = [1, 4, 10, 18, 28]
# Upper bound on decode steps per sequence (the prefill step counts as step 0).
MAX_NEW_TOKENS = 20
# Number of tokens taken from the PDF token stream for each text slice.
TOKENS_PER_SLICE = 512
# Number of reference sequences (default for create_sequences_from_pdf).
NUM_REFERENCES = 3

# Will be initialized from PDF in main()
# DUMMY_SETS is indexed by reference name ("ref_0", ...) and sliced to get batch neighbours.
REFERENCE_SEQUENCES = None
DUMMY_SETS = None

# ============================================================================
# LOGGING SETUP
# ============================================================================

# Open file handle of the active experiment log; None while no log file is open.
# Set by setup_logging(), read by log_print(), reset by close_logging().
LOG_FILE = None

def setup_logging(output_dir='/workspace/experiments'):
    """Open a timestamped log file and register it as the log_print target.

    The file name encodes the run mode ("verify" when TEACHER_FORCING is
    true, otherwise "generate") and the current local time.

    Args:
        output_dir: Directory for the log file; created if it does not exist.

    Returns:
        Path of the newly opened log file.
    """
    global LOG_FILE
    # Calling this twice must not leak the previously opened handle.
    if LOG_FILE:
        LOG_FILE.close()
        LOG_FILE = None
    os.makedirs(output_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    mode = "verify" if TEACHER_FORCING else "generate"
    log_path = os.path.join(output_dir, f"experiment_{mode}_{timestamp}.txt")
    # Log messages contain non-ASCII symbols (✓, ✗, →, Δ); an explicit UTF-8
    # encoding keeps writes from failing under a non-UTF-8 default locale.
    LOG_FILE = open(log_path, 'w', encoding='utf-8')
    return log_path

def log_print(*args, **kwargs):
    """Echo a message to the console and, if a log file is open, to that file.

    Accepts the same arguments as ``print``. A caller-supplied ``file``
    keyword only affects the console call; the log copy always goes to
    LOG_FILE and is flushed immediately.
    """
    print(*args, **kwargs)
    if not LOG_FILE:
        return
    # Drop any explicit ``file`` target so it cannot clash with LOG_FILE.
    forwarded = dict(kwargs)
    forwarded.pop('file', None)
    print(*args, **forwarded, file=LOG_FILE)
    LOG_FILE.flush()

def close_logging():
    """Close the active log file, if any, and clear the module-level handle."""
    global LOG_FILE
    if not LOG_FILE:
        return
    LOG_FILE.close()
    LOG_FILE = None

# ============================================================================
# PDF LOADING
# ============================================================================

def load_pdf_text(pdf_path):
    """Return the text of every page of a PDF, joined by single spaces.

    Pages for which PyPDF2 yields no text are skipped; leading and trailing
    whitespace of the combined text is stripped.
    """
    with open(pdf_path, 'rb') as pdf_file:
        extracted = (page.extract_text() for page in PyPDF2.PdfReader(pdf_file).pages)
        # Materialise inside the ``with`` so the file is still open while reading.
        page_texts = [page_text for page_text in extracted if page_text]
    return " ".join(page_texts).strip()


def create_sequences_from_pdf(tokenizer, num_references=NUM_REFERENCES):
    """
    Load all PDFs and split their concatenated tokens into equal-length slices.

    PDFs are looked up in /workspace first, then in the current directory.
    Each reference gets ``max(BATCH_SIZES)`` consecutive slices: the first is
    the reference text, the remaining ones are its batch neighbours.

    Args:
        tokenizer: Object providing ``encode(text, add_special_tokens=...)``
            and ``decode(token_ids)``.
        num_references: Number of reference sequences to build.

    Returns:
        (reference_sequences, dummy_sets): dict of ``"ref_<i>"`` -> slice text,
        and dict of ``"ref_<i>"`` -> list of ``max(BATCH_SIZES) - 1`` slice texts.

    Raises:
        FileNotFoundError: No PDF found in either location.
        ValueError: The PDFs do not contain enough tokens for all slices.
    """
    # glob returns paths in filesystem-dependent order. Sorting makes the
    # concatenated token stream (and therefore every slice) reproducible on
    # any machine that holds the same set of PDFs.
    pdf_files = sorted(glob.glob("/workspace/*.pdf"))
    if not pdf_files:
        pdf_files = sorted(glob.glob("*.pdf"))
    if not pdf_files:
        raise FileNotFoundError("No PDF files found")

    log_print(f"Found {len(pdf_files)} PDF(s)")

    # Load and tokenize every document into one long token stream
    all_tokens = []
    for pdf_path in pdf_files:
        log_print(f"  Loading: {pdf_path}")
        text = load_pdf_text(pdf_path)
        tokens = tokenizer.encode(text, add_special_tokens=True)
        all_tokens.extend(tokens)
        log_print(f"    → {len(tokens)} tokens")

    log_print(f"Total tokens: {len(all_tokens)}")

    # One slice for the reference plus (max_batch_size - 1) neighbours, per reference
    max_batch_size = max(BATCH_SIZES)
    slices_needed = num_references * max_batch_size
    tokens_needed = slices_needed * TOKENS_PER_SLICE

    if len(all_tokens) < tokens_needed:
        raise ValueError(f"Need {tokens_needed} tokens but only have {len(all_tokens)}")

    log_print(f"Creating {slices_needed} slices of {TOKENS_PER_SLICE} tokens each")

    # Slices are stored as decoded text; callers re-tokenize them later.
    slices = [
        tokenizer.decode(all_tokens[start:start + TOKENS_PER_SLICE])
        for start in range(0, tokens_needed, TOKENS_PER_SLICE)
    ]

    # Build reference sequences and dummy sets
    reference_sequences = {}
    dummy_sets = {}
    for ref_idx in range(num_references):
        ref_name = f"ref_{ref_idx}"
        base_idx = ref_idx * max_batch_size
        reference_sequences[ref_name] = slices[base_idx]
        dummy_sets[ref_name] = slices[base_idx + 1:base_idx + max_batch_size]

    return reference_sequences, dummy_sets

# ============================================================================
# SYSTEM INFO
# ============================================================================

def collect_system_info():
    """Return a dict describing the host, library versions and GPU of this run.

    CUDA-related entries fall back to "N/A" (or 0 for the GPU count) when
    CUDA is unavailable; ``flash_attn_version`` is "N/A" when the package
    cannot be imported.
    """
    import transformers

    has_cuda = torch.cuda.is_available()

    info = {
        "hostname": socket.gethostname(),
        "platform": platform.platform(),
        "python_version": sys.version.split()[0],
        "torch_version": torch.__version__,
        "cuda_version": torch.version.cuda if has_cuda else "N/A",
        "cudnn_version": str(torch.backends.cudnn.version()) if has_cuda else "N/A",
        "transformers_version": transformers.__version__,
        "numpy_version": np.__version__,
        "attn_implementation": ATTN_IMPLEMENTATION,
        "gpu_name": torch.cuda.get_device_name(0) if has_cuda else "N/A",
        "gpu_count": torch.cuda.device_count() if has_cuda else 0,
    }

    # flash-attn is optional; record its version only when it is importable.
    try:
        import flash_attn
    except ImportError:
        info["flash_attn_version"] = "N/A"
    else:
        info["flash_attn_version"] = flash_attn.__version__

    return info


def validate_environment_match(reference_env, verifier_env):
    """
    Check that the reference and verifier software environments agree.

    Only the hardware (GPU) is supposed to differ between the two runs; a
    software or script-config difference would confound the results.

    Returns ``{'valid': True, 'mismatches': []}`` when everything matches.
    On any mismatch the function logs how to fix it and terminates the
    process with ``sys.exit(1)``.
    """
    rule = "=" * 80

    def lookup(field):
        # A field missing on either side is treated as the string 'N/A'.
        return reference_env.get(field, 'N/A'), verifier_env.get(field, 'N/A')

    def report(field, ref_val, ver_val, mismatches):
        # Log one ✓/✗ line; remember the field when the values disagree.
        if ref_val == ver_val:
            log_print(f"  ✓ {field}: {ref_val}")
            return
        log_print(f"  ✗ {field}: reference={ref_val}, verifier={ver_val}")
        mismatches.append((field, ref_val, ver_val))

    log_print("\n" + rule)
    log_print("ENVIRONMENT VALIDATION")
    log_print(rule)

    # Script configuration - must match, fixed by editing the script
    config_mismatches = []
    log_print("\nScript configuration:")
    for field in ('attn_implementation',):
        report(field, *lookup(field), config_mismatches)

    # Container-level versions - fixed by switching container/image
    container_mismatches = []
    log_print("\nContainer-level dependencies:")
    for field in ('python_version', 'cuda_version', 'cudnn_version'):
        report(field, *lookup(field), container_mismatches)

    # Pip-installable packages
    pip_mismatches = []
    log_print("\nPip-installable packages:")
    for field in ('torch_version', 'transformers_version', 'numpy_version', 'flash_attn_version'):
        ref_val, ver_val = lookup(field)

        if field == 'flash_attn_version':
            # The flash-attn version is only relevant when a side uses flash_attention_2.
            ref_attn = reference_env.get('attn_implementation', '')
            ver_attn = verifier_env.get('attn_implementation', '')
            if not (ref_attn == 'flash_attention_2' or ver_attn == 'flash_attention_2'):
                log_print(f"  - {field}: not using flash_attention_2 (skip)")
                continue

            # flash_attention_2 in use: the package must be present on both sides.
            if ref_val == 'N/A' or ver_val == 'N/A':
                log_print(f"  ✗ {field}: reference={ref_val}, verifier={ver_val}")
                log_print("      (flash_attention_2 requires flash-attn installed on both)")
                pip_mismatches.append((field, ref_val, ver_val))
                continue

        # Absent on both sides means the package is simply not used.
        if ref_val == 'N/A' and ver_val == 'N/A':
            log_print(f"  - {field}: not installed (OK)")
            continue

        report(field, ref_val, ver_val, pip_mismatches)

    # Fields that SHOULD differ (the point of the experiment)
    log_print("\nExpected differences (hardware):")
    for field in ('gpu_name', 'hostname'):
        ref_val, ver_val = lookup(field)
        if ref_val != ver_val:
            log_print(f"  ✓ {field}: reference={ref_val}, verifier={ver_val}")
        else:
            log_print(f"  ⚠ {field}: SAME ({ref_val}) - are you on different hardware?")

    if not (container_mismatches or pip_mismatches or config_mismatches):
        log_print("\n" + "-" * 60)
        log_print("✓ ENVIRONMENT VALIDATION PASSED")
        log_print("  All critical software versions match.")
        log_print("  Only hardware differs - results will be meaningful.")
        return {'valid': True, 'mismatches': []}

    # Mismatches found - print fix commands and exit
    log_print("\n" + rule)
    log_print("✗ ENVIRONMENT MISMATCH - FIX REQUIRED")
    log_print(rule)

    # Pick the torch wheel index from the reference CUDA version (cu121 if unrecognised).
    cuda_ver = reference_env.get('cuda_version', '')
    torch_index = 'https://download.pytorch.org/whl/cu121'
    for prefix, wheel_tag in (('11.8', 'cu118'), ('12.1', 'cu121'), ('12.4', 'cu124')):
        if cuda_ver.startswith(prefix):
            torch_index = f'https://download.pytorch.org/whl/{wheel_tag}'
            break

    if config_mismatches:
        log_print("\n--- SCRIPT CONFIG (edit ATTN_IMPLEMENTATION in script) ---\n")
        for field, ref_val, ver_val in config_mismatches:
            log_print(f"  {field}: need '{ref_val}', have '{ver_val}'")
            log_print(f"  Edit the script: ATTN_IMPLEMENTATION = \"{ref_val}\"")

    if container_mismatches:
        log_print("\n--- CONTAINER-LEVEL (use a different container/image) ---\n")
        display_names = {
            'python_version': 'Python',
            'cuda_version': 'CUDA',
            'cudnn_version': 'cuDNN',
        }
        for field, ref_val, ver_val in container_mismatches:
            if field in display_names:
                log_print(f"  {display_names[field]}: need {ref_val}, have {ver_val}")

        log_print("\n  Suggested Docker images (PyTorch NGC containers):")
        log_print("    nvcr.io/nvidia/pytorch:XX.XX-py3")
        log_print("    Or find matching: https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/")

    if pip_mismatches:
        log_print("\n--- PIP-INSTALLABLE (run these commands) ---\n")
        plain_packages = {
            'transformers_version': 'transformers',
            'numpy_version': 'numpy',
        }
        for field, ref_val, ver_val in pip_mismatches:
            if field == 'torch_version':
                log_print(f"  pip install torch=={ref_val} --index-url {torch_index}")
            elif field in plain_packages:
                log_print(f"  pip install {plain_packages[field]}=={ref_val}")
            elif field == 'flash_attn_version':
                if ref_val == 'N/A':
                    log_print("  pip uninstall flash-attn  # reference didn't use it")
                else:
                    log_print(f"  pip install flash-attn=={ref_val}  # may need to build from source")

    log_print("\n" + rule)
    log_print("Fix the above and re-run this script.")
    log_print(rule)

    sys.exit(1)

# ============================================================================
# SIGNAL EXTRACTION
# ============================================================================

def extract_signals_from_output(outputs, layer_indices, position=-1):
    """
    Collect key vectors and the top-10 log-probabilities for batch element 0.

    Args:
        outputs: Model output exposing ``past_key_values`` and ``logits``.
        layer_indices: 1-based layer numbers whose key vectors are recorded.
        position: Sequence position to read (default: last).

    Returns:
        ``{'key_vectors': {'layer_<n>': [floats]}, 'logprobs': {'token_ids':
        [10 ints], 'log_probs': [10 floats]}}``, with the top-10 entries in
        the order ``torch.topk`` returns them.
    """
    key_vectors = {}
    for layer_idx in layer_indices:
        # Layer numbers are 1-based; entry [0] of a cache layer is read as the
        # key tensor (assumed layout: batch, heads, seq, head_dim - TODO confirm).
        keys = outputs.past_key_values[layer_idx - 1][0]
        # Batch element 0, one sequence position, flattened over the remaining two axes.
        flat_keys = keys[0, :, position, :].reshape(-1).cpu().clone()
        key_vectors[f'layer_{layer_idx}'] = flat_keys.float().numpy().tolist()

    # Log-probabilities over the vocabulary for element 0 at ``position``
    position_logits = outputs.logits[0, position, :]
    top_values, top_ids = torch.topk(F.log_softmax(position_logits, dim=-1), k=10)

    return {
        'key_vectors': key_vectors,
        'logprobs': {
            'token_ids': top_ids.cpu().tolist(),
            'log_probs': top_values.cpu().tolist(),
        },
    }


def extract_signals_for_token_ids(outputs, layer_indices, token_ids, position=-1):
    """
    Extract key vectors and the log-probabilities of SPECIFIC token IDs.

    Used in verification so that log-probabilities are read for exactly the
    tokens the reference recorded, for batch element 0.

    Args:
        outputs: Model output exposing ``past_key_values`` and ``logits``.
        layer_indices: 1-based layer numbers whose key vectors are recorded.
        token_ids: Vocabulary IDs to look up; may be empty.
        position: Sequence position to read (default: last).

    Returns:
        ``{'key_vectors': {'layer_<n>': [floats]}, 'logprobs': {'token_ids':
        token_ids, 'log_probs': [floats, same order as token_ids]}}``.
    """
    signals = {
        'key_vectors': {},
        'logprobs': {}
    }

    # Key vectors from element 0
    for layer_idx in layer_indices:
        # Layer numbers are 1-based; entry [0] of a cache layer is read as the keys.
        layer_keys = outputs.past_key_values[layer_idx - 1][0]
        token_keys = layer_keys[0, :, position, :]
        key_vector = token_keys.reshape(-1).cpu().clone()
        signals['key_vectors'][f'layer_{layer_idx}'] = key_vector.float().numpy().tolist()

    # Logprobs for SPECIFIC token IDs
    logits = outputs.logits[0, position, :]
    log_probs = F.log_softmax(logits, dim=-1)
    # An explicit integer dtype is required: without it an empty list becomes
    # a float tensor, which cannot be used as an index.
    token_ids_tensor = torch.tensor(token_ids, dtype=torch.long, device=logits.device)
    selected_logprobs = log_probs[token_ids_tensor]

    signals['logprobs'] = {
        'token_ids': token_ids,
        'log_probs': selected_logprobs.cpu().tolist()
    }

    return signals

# ============================================================================
# DECODE GENERATION (TEACHER_FORCING = False)
# ============================================================================

def compute_min_length_across_batches(ref_text, ref_name, tokenizer, batch_sizes):
    """Return the shortest token length over every batch that will be built.

    A batch of size ``b`` is ``[ref_text]`` plus the first ``b - 1`` dummy
    texts of ``ref_name``, so every batch is a prefix of the largest one.
    Tokenizing the largest batch once therefore yields the same minimum as
    tokenizing each batch separately.

    Args:
        ref_text: Reference text placed at batch position 0.
        ref_name: Key into DUMMY_SETS selecting the neighbour texts.
        tokenizer: Object providing ``encode(text, add_special_tokens=...)``.
        batch_sizes: Iterable of batch sizes that will be run.

    Returns:
        The minimum token count, or ``float('inf')`` when ``batch_sizes`` is
        empty.
    """
    ref_dummies = DUMMY_SETS[ref_name]
    sizes = list(batch_sizes)
    if not sizes:
        return float('inf')

    neighbour_count = max(max(sizes) - 1, 0)
    batch_texts = [ref_text] + ref_dummies[:neighbour_count]
    return min(len(tokenizer.encode(text, add_special_tokens=True)) for text in batch_texts)


def run_decode_with_extraction(model, tokenizer, ref_text, ref_name, batch_size,
                                layer_indices, forced_length=None):
    """
    Greedy-decode a batch and record signals for the last 3 decode steps.

    The batch is ``ref_text`` (element 0) plus the first ``batch_size - 1``
    dummy texts of ``ref_name``. All prompts are truncated to a common token
    length so no padding is needed. Decoding runs for at most MAX_NEW_TOKENS
    steps (the prefill counts as step 0) and stops early once element 0
    emits the EOS token.

    Args:
        model: Causal LM called with input_ids / attention_mask /
            past_key_values / use_cache.
        tokenizer: Tokenizer used to encode the batch texts.
        ref_text: Text for batch element 0.
        ref_name: Key into DUMMY_SETS selecting the neighbour texts.
        batch_size: Number of sequences in the batch.
        layer_indices: 1-based layers whose key vectors are recorded.
        forced_length: Prompt length to truncate to; defaults to the shortest
            prompt in this batch.

    Returns:
        Dict with ``generated_ids`` (element 0), ``all_batch_generated_ids``,
        ``prompt_token_ids``, ``prompt_length``, ``signals`` (keys ``pos_-3``,
        ``pos_-2``, ``pos_-1`` in that order, fewer if fewer steps ran) and
        ``num_generated``.
    """
    torch.cuda.empty_cache()

    # Build batch texts: reference first, then its neighbours
    ref_dummies = DUMMY_SETS[ref_name]
    batch_texts = [ref_text] + ref_dummies[:batch_size - 1]

    # Tokenize ONCE and work purely at token-ID level from here on
    all_token_ids = [tokenizer.encode(t, add_special_tokens=True) for t in batch_texts]

    if forced_length is not None:
        min_length = forced_length
    else:
        min_length = min(len(ids) for ids in all_token_ids)

    truncated_token_ids = [ids[:min_length] for ids in all_token_ids]

    # Build input tensors DIRECTLY from token IDs
    input_ids = torch.tensor(truncated_token_ids, dtype=torch.long, device='cuda')
    attention_mask = torch.ones_like(input_ids)

    prompt_length = input_ids.shape[1]
    log_print(f"      Prompt: {prompt_length} tokens", end="")

    # Track generation for ALL batch positions
    all_batch_generated_ids = [[] for _ in range(batch_size)]
    generation_signals = []

    # FIRST STEP: Prefill with full prompt
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True)

    past_kv = outputs.past_key_values

    # Greedy first token for every batch element (single host transfer)
    next_tokens = outputs.logits[:, -1, :].argmax(dim=-1)
    for sequence, token in zip(all_batch_generated_ids, next_tokens.tolist()):
        sequence.append(token)

    generation_signals.append({
        'step': 0,
        'absolute_position': prompt_length - 1,
        'signals': extract_signals_from_output(outputs, layer_indices, position=-1)
    })

    # SUBSEQUENT STEPS: True autoregressive with cache
    for step in range(1, MAX_NEW_TOKENS):
        # Extend the mask by one column for the token fed in this step.
        # new_ones keeps the mask's integer dtype and device; concatenating a
        # default float32 tensor would promote the whole mask to float.
        attention_mask = torch.cat([
            attention_mask,
            attention_mask.new_ones((attention_mask.shape[0], 1))
        ], dim=1)

        with torch.no_grad():
            outputs = model(
                input_ids=next_tokens.unsqueeze(1),
                attention_mask=attention_mask,
                past_key_values=past_kv,
                use_cache=True
            )

        past_kv = outputs.past_key_values

        next_tokens = outputs.logits[:, -1, :].argmax(dim=-1)
        for sequence, token in zip(all_batch_generated_ids, next_tokens.tolist()):
            sequence.append(token)

        # The cache length after this step gives the position of the token just processed
        current_cache_length = past_kv[0][0].shape[2]
        generation_signals.append({
            'step': step,
            'absolute_position': current_cache_length - 1,
            'signals': extract_signals_from_output(outputs, layer_indices, position=-1)
        })

        # Only element 0 decides when generation stops
        if all_batch_generated_ids[0][-1] == tokenizer.eos_token_id:
            break

    # Keep the last (up to) 3 steps, labelled pos_-3, pos_-2, pos_-1 in that order
    num_generated = len(generation_signals)
    tail = generation_signals[-3:]
    last_3_signals = {
        f'pos_-{len(tail) - offset}': entry for offset, entry in enumerate(tail)
    }

    del outputs, past_kv
    torch.cuda.empty_cache()

    final_length = prompt_length + num_generated
    log_print(f" → Final: {final_length} tokens ({num_generated} generated)")

    return {
        'generated_ids': all_batch_generated_ids[0],
        'all_batch_generated_ids': all_batch_generated_ids,
        'prompt_token_ids': truncated_token_ids,
        'prompt_length': prompt_length,
        'signals': last_3_signals,
        'num_generated': num_generated
    }

# ============================================================================
# TEACHER-FORCED DECODE (TEACHER_FORCING = True)
# ============================================================================

def _reference_token_ids_by_step(reference_signals, num_steps):
    """Map decode step index -> token IDs whose logprobs the reference recorded."""
    token_ids_by_step = {}
    for pos_label, entry in (reference_signals or {}).items():
        step = entry.get('step')
        if step is None:
            # 'pos_-k' labels count back from the final decode step.
            step = num_steps + int(pos_label.rsplit('_', 1)[-1])
        inner = entry['signals'] if 'signals' in entry else entry
        token_ids_by_step[int(step)] = inner['logprobs']['token_ids']
    return token_ids_by_step


def run_teacher_forced_decode(model, tokenizer, ref_name, reference_data,
                               verify_batch_size, layer_indices, is_diagonal):
    """
    Teacher-forced decode: feed the reference tokens and extract signals.

    Element 0 is always fed the reference's generated tokens. The reference
    stores signals only for its last (up to) 3 decode steps; at exactly those
    steps the log-probabilities are read for the reference's recorded token
    IDs, so both runs report values for the same tokens. At all other steps
    the verifier's own top-10 is recorded.

    Args:
        model: Causal LM called with input_ids / attention_mask /
            past_key_values / use_cache.
        tokenizer: Tokenizer used to encode arbitrary neighbour texts.
        ref_name: Key into DUMMY_SETS (used for off-diagonal neighbours).
        reference_data: Dict with prompt_token_ids, generated_ids,
            all_batch_generated_ids, signals
        verify_batch_size: Batch size for verification
        layer_indices: 1-based layers whose key vectors are recorded.
        is_diagonal: If True, use exact reference neighbors (prompts and
            generated tokens, reference batch size). If False, use arbitrary
            neighbors that decode greedily.

    Returns:
        Dict with ``signals`` (keys ``pos_-3``, ``pos_-2``, ``pos_-1`` in that
        order, fewer if fewer steps ran) and ``num_generated``.
    """
    torch.cuda.empty_cache()

    ref_prompt_ids = reference_data['prompt_token_ids'][0]
    ref_generated_ids = reference_data['generated_ids']
    ref_batch_size = len(reference_data['prompt_token_ids'])
    prompt_length = len(ref_prompt_ids)

    log_print(f"      Prompt: {len(ref_prompt_ids)}, Gen: {len(ref_generated_ids)}", end="")

    # Build batch for verification
    if is_diagonal:
        # Use EXACT reference sequences (same batch size, same tokens)
        log_print(f", exact neighbors (bs={ref_batch_size})", end="")
        batch_prompt_ids = reference_data['prompt_token_ids']
        batch_generated_ids = reference_data['all_batch_generated_ids']
        actual_batch_size = ref_batch_size
    else:
        # Off-diagonal: different batch size, arbitrary neighbors for positions 1+
        log_print(f", arb neighbors (bs={verify_batch_size})", end="")
        batch_prompt_ids = [ref_prompt_ids]
        batch_generated_ids = [ref_generated_ids]

        pad_id = tokenizer.pad_token_id or 0
        ref_dummies = DUMMY_SETS[ref_name]
        for i in range(verify_batch_size - 1):
            # Truncate to the reference prompt length, right-pad if too short.
            # NOTE: padded positions are still attended to (mask is all ones).
            dummy_ids = tokenizer.encode(ref_dummies[i], add_special_tokens=True)[:prompt_length]
            dummy_ids = dummy_ids + [pad_id] * (prompt_length - len(dummy_ids))
            batch_prompt_ids.append(dummy_ids)
            # Neighbors generate freely - initialize with empty
            batch_generated_ids.append([])

        actual_batch_size = verify_batch_size

    # Build input tensors from prompt IDs
    input_ids = torch.tensor(batch_prompt_ids, dtype=torch.long, device='cuda')
    attention_mask = torch.ones_like(input_ids)

    num_steps = len(ref_generated_ids)
    # Look reference token IDs up by the decode step they were recorded at,
    # not by their position in the signals dict.
    ref_token_ids_by_step = _reference_token_ids_by_step(reference_data['signals'], num_steps)
    generation_signals = []

    def record(step, step_outputs, absolute_position):
        # Use the reference's token IDs when it recorded this step.
        ref_token_ids = ref_token_ids_by_step.get(step)
        if ref_token_ids is not None:
            signals = extract_signals_for_token_ids(
                step_outputs, layer_indices, ref_token_ids, position=-1)
        else:
            signals = extract_signals_from_output(step_outputs, layer_indices, position=-1)
        generation_signals.append({
            'step': step,
            'absolute_position': absolute_position,
            'signals': signals
        })

    def next_input_tokens(token_index, step_outputs):
        # Tokens to feed next: generated token ``token_index`` of every sequence.
        if is_diagonal:
            # All positions teacher-forced
            forced = [batch_generated_ids[i][token_index] for i in range(actual_batch_size)]
            return torch.tensor(forced, dtype=torch.long, device='cuda')
        # Position 0 teacher-forced, neighbours continue greedily
        neighbour_tokens = step_outputs.logits[1:, -1, :].argmax(dim=-1).tolist()
        for sequence, token in zip(batch_generated_ids[1:], neighbour_tokens):
            sequence.append(token)
        return torch.tensor(
            [ref_generated_ids[token_index]] + neighbour_tokens,
            dtype=torch.long, device='cuda'
        )

    # FIRST STEP: Prefill with full prompt
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True)

    past_kv = outputs.past_key_values
    record(0, outputs, input_ids.shape[1] - 1)

    # SUBSEQUENT STEPS: step s feeds generated token s-1, as in the reference run
    for step in range(1, num_steps):
        next_tokens = next_input_tokens(step - 1, outputs)

        # new_ones keeps the mask's integer dtype and device; concatenating a
        # default float32 tensor would promote the whole mask to float.
        attention_mask = torch.cat([
            attention_mask,
            attention_mask.new_ones((actual_batch_size, 1))
        ], dim=1)

        with torch.no_grad():
            outputs = model(
                input_ids=next_tokens.unsqueeze(1),
                attention_mask=attention_mask,
                past_key_values=past_kv,
                use_cache=True
            )

        past_kv = outputs.past_key_values
        record(step, outputs, past_kv[0][0].shape[2] - 1)

    # Keep the last (up to) 3 steps, labelled pos_-3, pos_-2, pos_-1 in that order
    num_generated = len(generation_signals)
    tail = generation_signals[-3:]
    last_3_signals = {
        f'pos_-{len(tail) - offset}': entry for offset, entry in enumerate(tail)
    }

    del outputs, past_kv
    torch.cuda.empty_cache()

    log_print(f" → {num_generated} steps")

    return {
        'signals': last_3_signals,
        'num_generated': num_generated
    }

# ============================================================================
# ANALYSIS
# ============================================================================

def check_token_consistency(decode_measurements, tokenizer):
    """
    Report whether element 0 produced the same tokens at every batch size.

    ``decode_measurements`` maps batch size -> dict containing
    ``generated_ids``. The smallest batch size serves as the baseline; every
    batch size is logged with its IDs, decoded text and a match marker.

    Returns True when all token lists equal the baseline, else False.
    """
    log_print("\n" + "="*80)
    log_print("TOKEN GENERATION CONSISTENCY CHECK")
    log_print("="*80)

    tokens_by_bs = {bs: data['generated_ids'] for bs, data in decode_measurements.items()}
    ordered_sizes = sorted(tokens_by_bs)
    baseline = tokens_by_bs[ordered_sizes[0]]

    log_print("\nGenerated tokens by batch size:")
    mismatch_found = False
    for bs in ordered_sizes:
        tokens = tokens_by_bs[bs]
        is_match = tokens == baseline
        decoded_text = tokenizer.decode(tokens)
        log_print(f"  bs={bs}:")
        log_print(f"    IDs:  {tokens}")
        log_print(f"    Text: {repr(decoded_text)}")
        log_print("    ✓" if is_match else "    ✗ DIFFERENT")
        if not is_match:
            mismatch_found = True

    if mismatch_found:
        log_print("\n⚠ Element 0 generates DIFFERENT tokens across batch sizes")
        log_print("  → Batch composition affects generation - comparison may be confounded")
    else:
        log_print("\n✓ Element 0 generates IDENTICAL tokens across all batch sizes")
        log_print("  → Can meaningfully compare activations for same token sequence")

    return not mismatch_found


def compute_l2_distance(vec1, vec2):
    """Return the Euclidean (L2) norm of ``vec1 - vec2`` as a Python float."""
    difference = np.array(vec1) - np.array(vec2)
    return float(np.linalg.norm(difference))


def compute_logprob_distance(logprobs1, logprobs2):
    """Return the L2 distance between the ``'log_probs'`` lists of two dicts.

    Values are compared position by position; the ``'token_ids'`` entries
    are not consulted.
    """
    first, second = (np.array(entry['log_probs']) for entry in (logprobs1, logprobs2))
    return float(np.linalg.norm(first - second))


def compare_signals(signals1, signals2, layer_indices):
    """Compare two signal sets and summarise their distances.

    Only position labels present in both sets are compared. Each entry may
    either be the signal dict itself or a wrapper holding it under
    ``'signals'``. For every common position the L2 distance of each layer's
    key vector and of the log-probabilities is computed.

    Args:
        signals1: Mapping of position label -> signal entry.
        signals2: Mapping of position label -> signal entry.
        layer_indices: Unused; layers are taken from ``signals1``'s key
            vectors. Kept for interface compatibility.

    Returns:
        Dict of Python floats: ``key_vectors_max``, ``key_vectors_mean``,
        ``logprobs_max``, ``logprobs_mean`` (all 0.0 when there is no common
        position).
    """
    def unwrap(entry):
        return entry['signals'] if 'signals' in entry else entry

    key_dists = []
    logprob_dists = []

    # Sorted iteration keeps the accumulation order - and therefore the
    # floating-point mean - independent of set ordering / hash randomization.
    for pos_label in sorted(set(signals1) & set(signals2)):
        sig1 = unwrap(signals1[pos_label])
        sig2 = unwrap(signals2[pos_label])

        # Key vectors, layer by layer
        for layer_name, key_vector in sig1['key_vectors'].items():
            key_dists.append(
                compute_l2_distance(key_vector, sig2['key_vectors'][layer_name])
            )

        # Logprobs
        logprob_dists.append(compute_logprob_distance(sig1['logprobs'], sig2['logprobs']))

    def summarise(distances):
        # Plain floats in both the empty and non-empty case.
        if not distances:
            return 0.0, 0.0
        return float(max(distances)), float(np.mean(distances))

    key_max, key_mean = summarise(key_dists)
    logprob_max, logprob_mean = summarise(logprob_dists)

    return {
        'key_vectors_max': key_max,
        'key_vectors_mean': key_mean,
        'logprobs_max': logprob_max,
        'logprobs_mean': logprob_mean
    }


def analyze_within_hardware_sanity_check(measurements, batch_sizes, layer_indices):
    """
    Sanity check: Does batch size affect activations on the SAME hardware?

    This validates the premise before cross-hardware comparison.
    Expected: YES - different batch sizes should produce different activations.

    Args:
        measurements: Iterable of dicts, each providing 'ref_name',
            'batch_size' and 'signals'.
        batch_sizes: Not consulted; the batch sizes compared are the ones
            actually present in ``measurements``.
        layer_indices: Forwarded unchanged to ``compare_signals``.

    Returns:
        Dict with 'key_vectors_mean', 'key_vectors_max', 'logprobs_mean' and
        'logprobs_max' (aggregated over the per-pair max distances) plus
        'comparisons', one record per (reference, batch-size pair). When no
        pair is available the statistics are 0.0 and 'comparisons' is empty.
    """
    log_print("\n" + "="*80)
    log_print("SANITY CHECK: WITHIN-HARDWARE BATCH SIZE EFFECTS")
    log_print("="*80)
    log_print("\nQuestion: Does batch size affect activations on the same GPU?")
    log_print("Expected: YES - batch size changes computation even for same sequence\n")

    # Group signals as by_ref[ref_name][batch_size]
    by_ref = {}
    for m in measurements:
        by_ref.setdefault(m['ref_name'], {})[m['batch_size']] = m['signals']

    all_key_dists = []
    all_logprob_dists = []
    comparison_details = []

    for ref_name in sorted(by_ref):
        log_print(f"\n{ref_name.upper()}")
        log_print("-" * 60)

        ref_data = by_ref[ref_name]
        available_bs = sorted(ref_data)

        # Every unordered pair of distinct batch sizes for this reference
        for i, bs1 in enumerate(available_bs):
            for bs2 in available_bs[i+1:]:
                distances = compare_signals(ref_data[bs1], ref_data[bs2], layer_indices)

                log_print(f"  bs={bs1} vs bs={bs2}:")
                log_print(f"    Key vectors: Δ_max = {distances['key_vectors_max']:.2e}")
                log_print(f"    Logprobs:    Δ_max = {distances['logprobs_max']:.2e}")

                all_key_dists.append(distances['key_vectors_max'])
                all_logprob_dists.append(distances['logprobs_max'])

                comparison_details.append({
                    'ref': ref_name,
                    'bs1': bs1,
                    'bs2': bs2,
                    'key_distance': distances['key_vectors_max'],
                    'logprob_distance': distances['logprobs_max']
                })

    # Summary statistics
    log_print("\n" + "="*80)
    log_print("SANITY CHECK SUMMARY")
    log_print("="*80)

    # Without at least one pair, max() would raise and the means would be NaN;
    # report an inconclusive check instead of crashing the run.
    if not comparison_details:
        log_print("\n⚠ WARNING: no batch-size pairs available to compare")
        log_print("  Need at least two batch sizes per reference - sanity check is inconclusive")
        return {
            'key_vectors_mean': 0.0,
            'key_vectors_max': 0.0,
            'logprobs_mean': 0.0,
            'logprobs_max': 0.0,
            'comparisons': comparison_details
        }

    key_mean = np.mean(all_key_dists)
    key_max = max(all_key_dists)
    logprob_mean = np.mean(all_logprob_dists)
    logprob_max = max(all_logprob_dists)

    log_print("\nKey vectors:")
    log_print(f"  μ = {key_mean:.2e}, max = {key_max:.2e}")
    log_print("\nLogprobs:")
    log_print(f"  μ = {logprob_mean:.2e}, max = {logprob_max:.2e}")

    # Check for zeros (would invalidate experiment)
    key_zeros = sum(1 for d in all_key_dists if d == 0.0)
    logprob_zeros = sum(1 for d in all_logprob_dists if d == 0.0)

    if key_zeros > 0:
        log_print(f"\n⚠ WARNING: {key_zeros}/{len(all_key_dists)} key comparisons are EXACTLY ZERO")
    if logprob_zeros > 0:
        log_print(f"⚠ WARNING: {logprob_zeros}/{len(all_logprob_dists)} logprob comparisons are EXACTLY ZERO")

    # Track which batch pairs produce zeros, grouped by reference
    if key_zeros > 0 or logprob_zeros > 0:
        zeros_by_ref = {}
        for r in comparison_details:
            # Every reference gets an entry, even if it has no zero pairs, so
            # the cross-reference intersection below covers all references.
            bucket = zeros_by_ref.setdefault(r['ref'], {'key': [], 'logprob': []})
            pair = (r['bs1'], r['bs2'])
            if r['key_distance'] == 0.0:
                bucket['key'].append(pair)
            if r['logprob_distance'] == 0.0:
                bucket['logprob'].append(pair)

        log_print("\n  Zero locations by reference:")
        for ref in sorted(zeros_by_ref):
            key_pairs = zeros_by_ref[ref]['key']
            log_pairs = zeros_by_ref[ref]['logprob']
            if key_pairs or log_pairs:
                log_print(f"    {ref}:")
                if key_pairs:
                    log_print(f"      key zeros:     {sorted(key_pairs)}")
                if log_pairs:
                    log_print(f"      logprob zeros: {sorted(log_pairs)}")

        # Check consistency across references
        all_refs = sorted(zeros_by_ref)
        if len(all_refs) >= 2:
            key_sets = [set(zeros_by_ref[ref]['key']) for ref in all_refs]
            log_sets = [set(zeros_by_ref[ref]['logprob']) for ref in all_refs]

            # A pair is "systematic" only if it is zero for every reference.
            key_intersection = set.intersection(*key_sets) if all(key_sets) else set()
            log_intersection = set.intersection(*log_sets) if all(log_sets) else set()

            log_print("\n  Cross-reference consistency:")
            if key_intersection:
                log_print(f"    Key zeros consistent across ALL refs: {sorted(key_intersection)}")
                log_print("    → SYSTEMATIC: These batch pairs produce identical activations!")
            else:
                log_print("    Key zeros: NO pairs are zero across all refs (coincidental)")
            if log_intersection:
                log_print(f"    Logprob zeros consistent across ALL refs: {sorted(log_intersection)}")
                log_print("    → SYSTEMATIC: These batch pairs produce identical logprobs!")
            else:
                log_print("    Logprob zeros: NO pairs are zero across all refs (coincidental)")

    # Conclusion: a signal "shows an effect" if its mean distance exceeds this floor
    threshold = 1e-10
    key_affected = key_mean > threshold
    logprob_affected = logprob_mean > threshold

    log_print("\n" + "-"*60)
    if key_affected and logprob_affected:
        log_print("✓ SANITY CHECK PASSED")
        log_print("  Batch size DOES affect computation on same hardware")
        log_print("  → Proceeding with cross-hardware comparison is meaningful")
    elif key_affected or logprob_affected:
        log_print("~ PARTIAL PASS")
        signal = "Key vectors" if key_affected else "Logprobs"
        log_print(f"  Only {signal} show batch size effects")
    else:
        log_print("✗ SANITY CHECK FAILED")
        log_print("  Batch size does NOT affect computation!")
        log_print("  → Cross-hardware experiment would be meaningless")

    return {
        'key_vectors_mean': float(key_mean),
        'key_vectors_max': float(key_max),
        'logprobs_mean': float(logprob_mean),
        'logprobs_max': float(logprob_max),
        'comparisons': comparison_details
    }


def _log_distance_matrix(matrix, batch_sizes):
    """Log the column header and one row per claimed batch size for ``matrix``."""
    header = "              " + "".join([f"Verify bs={bs:>3} " for bs in batch_sizes])
    log_print(header)
    for i, claimed_bs in enumerate(batch_sizes):
        cells = "".join(f"  {value:10.2e}" for value in matrix[i])
        log_print(f"Claim bs={claimed_bs:<3} " + cells)


def analyze_cross_hardware_matrix(comparison_results, batch_sizes):
    """Analyze the comparison matrix and determine detectability.

    Builds, per reference, a claimed-by-verified matrix of max key-vector and
    max logprob distances, averages the matrices across references, and
    compares the off-diagonal mean (batch-size mismatch) against the diagonal
    mean (same batch size) as a signal-to-baseline ratio.

    Args:
        comparison_results: Iterable of dicts with 'ref_name',
            'claimed_batch_size', 'verify_batch_size' and a 'distances' dict
            holding 'key_vectors_max' and 'logprobs_max'.
        batch_sizes: Batch sizes defining the row/column order of the matrices.

    Returns:
        Dict with 'matrices' (per-reference matrices as nested lists, keyed by
        signal type) and 'statistics' (per signal type: averaged matrix,
        baseline/signal mean and std, and SNR). With no comparison results,
        the matrix lists and 'statistics' are empty.
    """
    log_print("\n" + "="*80)
    log_print("CROSS-HARDWARE BATCH SIZE DETECTABILITY ANALYSIS")
    log_print("="*80)

    # Group as by_ref[ref_name][(claimed_bs, verify_bs)]
    by_ref = {}
    for result in comparison_results:
        pair = (result['claimed_batch_size'], result['verify_batch_size'])
        by_ref.setdefault(result['ref_name'], {})[pair] = result

    # Averaging an empty list of matrices yields a NaN scalar that cannot be
    # indexed below, so bail out early with an empty analysis.
    if not by_ref:
        log_print("\n⚠ No comparison results to analyze - skipping matrix analysis")
        return {
            'matrices': {'key_vectors': [], 'logprobs': []},
            'statistics': {}
        }

    n_bs = len(batch_sizes)
    all_matrices = {'key_vectors': [], 'logprobs': []}

    for ref_name in sorted(by_ref):
        log_print(f"\n{'='*80}")
        log_print(f"{ref_name.upper()}")
        log_print("="*80)

        ref_data = by_ref[ref_name]
        matrix_key = np.zeros((n_bs, n_bs))
        matrix_logprob = np.zeros((n_bs, n_bs))

        # Cells with no recorded comparison keep their initial 0.0.
        for i, claimed_bs in enumerate(batch_sizes):
            for j, verify_bs in enumerate(batch_sizes):
                entry = ref_data.get((claimed_bs, verify_bs))
                if entry is not None:
                    matrix_key[i, j] = entry['distances']['key_vectors_max']
                    matrix_logprob[i, j] = entry['distances']['logprobs_max']

        # Display matrices
        log_print("\nKey Vectors (max L2 distance):")
        _log_distance_matrix(matrix_key, batch_sizes)

        log_print("\nLogprobs (max L2 distance):")
        _log_distance_matrix(matrix_logprob, batch_sizes)

        all_matrices['key_vectors'].append(matrix_key)
        all_matrices['logprobs'].append(matrix_logprob)

    # Aggregate statistics
    log_print("\n" + "="*80)
    log_print("AGGREGATE STATISTICS (AVERAGE ACROSS REFERENCES)")
    log_print("="*80)

    results = {}

    for signal_type in ['key_vectors', 'logprobs']:
        avg_matrix = np.mean(all_matrices[signal_type], axis=0)

        log_print(f"\n{signal_type.upper()}:")
        _log_distance_matrix(avg_matrix, batch_sizes)

        # Extract diagonal (baseline: hardware-only) and off-diagonal (signal: hardware + bs)
        diagonal = np.array([avg_matrix[i, i] for i in range(n_bs)])
        off_diagonal = np.array([avg_matrix[i, j] for i in range(n_bs) for j in range(n_bs) if i != j])

        baseline_mean = np.mean(diagonal)
        baseline_std = np.std(diagonal)
        if off_diagonal.size:
            signal_mean = np.mean(off_diagonal)
            signal_std = np.std(off_diagonal)
        else:
            # A single batch size has no mismatched pairs; avoid NaN statistics.
            signal_mean = 0.0
            signal_std = 0.0

        if baseline_mean > 0:
            snr = signal_mean / baseline_mean
        elif signal_mean > 0:
            snr = float('inf')
        else:
            # Zero signal over zero baseline is no evidence of detectability.
            snr = 0.0

        log_print("\n  Diagonal (baseline - hardware difference only):")
        log_print(f"    μ = {baseline_mean:.2e}, σ = {baseline_std:.2e}")
        log_print(f"    Values: {[f'{d:.2e}' for d in diagonal]}")

        log_print("\n  Off-diagonal (signal - hardware + batch size difference):")
        log_print(f"    μ = {signal_mean:.2e}, σ = {signal_std:.2e}")

        log_print(f"\n  SNR (signal/baseline): {snr:.2f}×")

        results[signal_type] = {
            'matrix': avg_matrix.tolist(),
            'baseline_mean': float(baseline_mean),
            'baseline_std': float(baseline_std),
            'signal_mean': float(signal_mean),
            'signal_std': float(signal_std),
            'snr': float(snr)
        }

    # Conclusion
    log_print("\n" + "="*80)
    log_print("CONCLUSION")
    log_print("="*80)

    # A signal counts as detectable when off-diagonal mean is >= 1.5x the diagonal mean
    threshold = 1.5

    key_snr = results['key_vectors']['snr']
    log_snr = results['logprobs']['snr']
    key_detectable = key_snr >= threshold
    log_detectable = log_snr >= threshold

    log_print(f"\nKey Vectors: SNR = {key_snr:.2f}× {'✓ DETECTABLE' if key_detectable else '✗ NOT DETECTABLE'}")
    log_print(f"Logprobs:    SNR = {log_snr:.2f}× {'✓ DETECTABLE' if log_detectable else '✗ NOT DETECTABLE'}")

    if key_detectable and log_detectable:
        log_print("\n✓ BATCH SIZE MISMATCHES ARE DETECTABLE CROSS-HARDWARE")
        log_print("  → Verification cluster can detect batch size evasion")
        log_print("  → Off-diagonal (bs mismatch) >> diagonal (hardware-only)")
    elif key_detectable or log_detectable:
        log_print("\n~ PARTIAL DETECTABILITY")
        det = 'Key vectors' if key_detectable else 'Logprobs'
        log_print(f"  → {det} can detect batch size mismatches")
    else:
        log_print("\n✗ BATCH SIZE NOT RELIABLY DETECTABLE CROSS-HARDWARE")
        log_print("  → Batch size signal is comparable to hardware baseline")
        log_print("  → Cannot reliably distinguish bs mismatch from hardware difference")

    return {
        'matrices': {k: [m.tolist() for m in v] for k, v in all_matrices.items()},
        'statistics': results
    }

# ============================================================================
# MAIN
# ============================================================================

def main():
    """Run the experiment in verification or generation mode.

    The mode is selected by the module-level ``TEACHER_FORCING`` flag:

    * True (verification): load ``REFERENCE_FILE``, check that its model,
      batch sizes and layer indices equal this script's configuration, run a
      teacher-forced decode for every (claimed, verify) batch-size pair of
      every reference, compare the signals, analyze the resulting matrix and
      write ``verify_<timestamp>.json``.
    * False (generation): run a decode with signal extraction for every
      reference and batch size, write ``decode_<timestamp>.json``, run the
      within-hardware sanity check and re-write the file with its result.

    Calls ``sys.exit(1)`` when flash_attn is required but not importable, or
    when the reference file's model, batch sizes or layer indices differ from
    the local configuration.
    """
    global REFERENCE_SEQUENCES, DUMMY_SETS

    log_path = setup_logging()
    system_info = collect_system_info()

    mode = "VERIFICATION (teacher-forcing)" if TEACHER_FORCING else "GENERATION (reference)"

    log_print("="*80)
    log_print(f"CROSS-HARDWARE BATCH SIZE DETECTABILITY - {mode}")
    log_print("="*80)
    log_print(f"Log file: {log_path}")
    log_print(f"\nEnvironment:")
    for k, v in system_info.items():
        log_print(f"  {k}: {v}")

    log_print(f"\nConfiguration:")
    log_print(f"  Model: {MODEL_NAME}")
    log_print(f"  Layers: {LAYER_INDICES}")
    log_print(f"  Batch sizes: {BATCH_SIZES}")
    log_print(f"  Max tokens: {MAX_NEW_TOKENS}")
    if TEACHER_FORCING:
        log_print(f"  Reference file: {REFERENCE_FILE}")
    log_print()

    # Load model
    log_print("Loading model...")

    # Verify flash_attn is installed if needed
    if ATTN_IMPLEMENTATION == "flash_attention_2":
        try:
            import flash_attn
            log_print(f"  flash_attn {flash_attn.__version__} available")
        except ImportError:
            log_print("\n✗ ATTN_IMPLEMENTATION='flash_attention_2' but flash_attn not installed")
            log_print("  Either install: pip install flash-attn")
            log_print("  Or change: ATTN_IMPLEMENTATION = 'eager' or 'sdpa'")
            sys.exit(1)

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)

    # Initialize sequences from PDF (needed for dummy neighbors in off-diagonal)
    # Both results are stored in module globals; only REFERENCE_SEQUENCES is
    # read again inside this function.
    REFERENCE_SEQUENCES, DUMMY_SETS = create_sequences_from_pdf(tokenizer)
    log_print(f"Created {len(REFERENCE_SEQUENCES)} reference sequences\n")

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16,
        cache_dir=CACHE_DIR,
        low_cpu_mem_usage=True,
        device_map="auto",
        attn_implementation=ATTN_IMPLEMENTATION
    )
    log_print(f"✓ Model loaded (attn_implementation={ATTN_IMPLEMENTATION})\n")

    output_dir = '/workspace/experiments'
    os.makedirs(output_dir, exist_ok=True)
    # One timestamp is shared by the output file name and the saved metadata.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    if TEACHER_FORCING:
        # ================================================================
        # VERIFICATION MODE
        # ================================================================
        log_print("Loading reference file...")
        with open(REFERENCE_FILE, 'r') as f:
            reference = json.load(f)

        ref_env = reference['metadata']['environment']
        ref_gpu = ref_env['gpu_name']
        log_print(f"Reference GPU: {ref_gpu}")
        log_print(f"Verifier GPU:  {system_info['gpu_name']}")

        # Validate environments match (except hardware)
        # NOTE(review): the result is only stored in the output metadata here;
        # whether validate_environment_match itself aborts on a mismatch is
        # not visible from this function - confirm.
        env_validation = validate_environment_match(ref_env, system_info)

        # Validate model matches
        ref_model = reference['metadata']['model']
        if ref_model != MODEL_NAME:
            log_print(f"\n✗ MODEL MISMATCH: reference={ref_model}, verifier={MODEL_NAME}")
            log_print("  Aborting - models must match for meaningful comparison.")
            sys.exit(1)

        # Validate batch sizes match
        ref_batch_sizes = reference['metadata']['batch_sizes']
        if ref_batch_sizes != BATCH_SIZES:
            log_print(f"\n✗ BATCH SIZE MISMATCH:")
            log_print(f"  Reference: {ref_batch_sizes}")
            log_print(f"  Verifier:  {BATCH_SIZES}")
            log_print("  Aborting - batch sizes must match for comparison matrix.")
            sys.exit(1)

        # Validate layer indices match
        ref_layers = reference['metadata']['layer_indices']
        if ref_layers != LAYER_INDICES:
            log_print(f"\n✗ LAYER INDICES MISMATCH:")
            log_print(f"  Reference: {ref_layers}")
            log_print(f"  Verifier:  {LAYER_INDICES}")
            log_print("  Aborting - layer indices must match for signal comparison.")
            sys.exit(1)

        log_print("\n✓ Model, batch sizes, and layer indices match\n")

        comparison_results = []

        # Build lookup for reference measurements
        # Keyed by (ref_name, batch_size); a later duplicate overwrites an earlier one.
        ref_by_key = {}
        for m in reference['measurements']:
            key = (m['ref_name'], m['batch_size'])
            ref_by_key[key] = m

        for ref_name in sorted(REFERENCE_SEQUENCES.keys()):
            log_print(f"\n{'='*80}")
            log_print(f"REFERENCE: {ref_name}")
            log_print("="*80)

            for claimed_bs in BATCH_SIZES:
                ref_key = (ref_name, claimed_bs)
                if ref_key not in ref_by_key:
                    # Skipped rows produce no comparison entries for this claimed batch size.
                    log_print(f"  ⚠ No reference data for {ref_name} bs={claimed_bs}")
                    continue

                ref_data = ref_by_key[ref_key]
                log_print(f"\n  Claimed batch size: {claimed_bs}")

                # Full matrix: every claimed batch size is verified at every batch size.
                for verify_bs in BATCH_SIZES:
                    is_diagonal = (claimed_bs == verify_bs)

                    log_print(f"    Verify bs={verify_bs} ({'diagonal' if is_diagonal else 'off-diag'}):", end="")

                    verify_result = run_teacher_forced_decode(
                        model, tokenizer, ref_name, ref_data,
                        verify_bs, LAYER_INDICES, is_diagonal
                    )

                    # Compare signals
                    distances = compare_signals(
                        ref_data['signals'],
                        verify_result['signals'],
                        LAYER_INDICES
                    )

                    log_print(f"      → Key: {distances['key_vectors_max']:.2e}, Logprob: {distances['logprobs_max']:.2e}")

                    # The verifier's raw signals are kept alongside the distances in the output.
                    comparison_results.append({
                        'ref_name': ref_name,
                        'claimed_batch_size': claimed_bs,
                        'verify_batch_size': verify_bs,
                        'is_diagonal': is_diagonal,
                        'distances': distances,
                        'verify_signals': verify_result['signals']
                    })

        # Analyze
        analysis = analyze_cross_hardware_matrix(comparison_results, BATCH_SIZES)

        # Save results
        results = {
            'metadata': {
                'reference_gpu': ref_gpu,
                'verifier_gpu': system_info['gpu_name'],
                'reference_file': REFERENCE_FILE,
                'reference_environment': ref_env,
                'verifier_environment': system_info,
                'environment_validation': env_validation,
                'model': MODEL_NAME,
                'layer_indices': LAYER_INDICES,
                'batch_sizes': BATCH_SIZES,
                'timestamp': timestamp
            },
            'comparisons': comparison_results,
            'analysis': analysis
        }

        filepath = os.path.join(output_dir, f"verify_{timestamp}.json")
        with open(filepath, 'w') as f:
            json.dump(results, f, indent=2)

        log_print(f"\n✓ Verification results saved to: {filepath}")

    else:
        # ================================================================
        # GENERATION MODE
        # ================================================================
        # The metadata written here is what verification mode reads back and
        # validates ('environment', 'model', 'batch_sizes', 'layer_indices').
        results = {
            'metadata': {
                'environment': system_info,
                'model': MODEL_NAME,
                'layer_indices': LAYER_INDICES,
                'batch_sizes': BATCH_SIZES,
                'max_new_tokens': MAX_NEW_TOKENS,
                'timestamp': timestamp
            },
            'measurements': []
        }

        for ref_name, ref_text in REFERENCE_SEQUENCES.items():
            log_print(f"\n{'='*80}")
            log_print(f"REFERENCE: {ref_name}")
            log_print("="*80)

            # Pre-compute minimum length
            # The same value is passed as forced_length to every batch size below.
            min_prompt_length = compute_min_length_across_batches(
                ref_text, ref_name, tokenizer, BATCH_SIZES
            )
            log_print(f"\nGlobal minimum prompt length: {min_prompt_length} tokens\n")

            for batch_size in BATCH_SIZES:
                log_print(f"  bs={batch_size}:", end=" ")

                decode_data = run_decode_with_extraction(
                    model, tokenizer, ref_text, ref_name, batch_size, LAYER_INDICES,
                    forced_length=min_prompt_length
                )

                results['measurements'].append({
                    'ref_name': ref_name,
                    'batch_size': batch_size,
                    'generated_ids': decode_data['generated_ids'],
                    'all_batch_generated_ids': decode_data['all_batch_generated_ids'],
                    'prompt_token_ids': decode_data['prompt_token_ids'],
                    'prompt_length': decode_data['prompt_length'],
                    'signals': decode_data['signals'],
                    'num_generated': decode_data['num_generated']
                })

        # First write happens before the sanity check runs, so the measurements
        # are on disk even if the analysis below raises.
        filepath = os.path.join(output_dir, f"decode_{timestamp}.json")
        with open(filepath, 'w') as f:
            json.dump(results, f, indent=2)

        log_print(f"\n✓ Generation results saved to: {filepath}")

        # Run sanity check
        sanity_check = analyze_within_hardware_sanity_check(
            results['measurements'], BATCH_SIZES, LAYER_INDICES
        )
        results['sanity_check'] = sanity_check

        # Re-save with sanity check
        with open(filepath, 'w') as f:
            json.dump(results, f, indent=2)

        log_print(f"\nNext step: Copy {filepath} to verifier machine")
        log_print(f"Then set TEACHER_FORCING = True and REFERENCE_FILE = '<path>'")

    # filepath is assigned in both branches above.
    file_size_mb = os.path.getsize(filepath) / (1024 * 1024)
    log_print(f"File size: {file_size_mb:.1f} MB")

    log_print(f"\n{'='*80}")
    log_print("EXPERIMENT COMPLETE")
    log_print("="*80 + "\n")

    # Not reached on the sys.exit(1) paths above.
    close_logging()


# Script entry point: run the experiment only when executed directly, not on import.
if __name__ == "__main__":
    main()



CROSS-HARDWARE BATCH SIZE DETECTABILITY - VERIFICATION (teacher-forcing)
Log file: /workspace/experiments/experiment_verify_20251124_192512.txt

Environment:
  hostname: 06768f39097d
  platform: Linux-6.14.0-24-generic-x86_64-with-glibc2.39
  python_version: 3.12.3
  torch_version: 2.8.0+cu128
  cuda_version: 12.8
  cudnn_version: 91002
  transformers_version: 4.57.2
  numpy_version: 2.1.2
  attn_implementation: flash_attention_2
  gpu_name: NVIDIA H100 PCIe
  gpu_count: 1
  flash_attn_version: 2.8.3

Configuration:
  Model: Qwen/Qwen2.5-7B-Instruct
  Layers: [1, 4, 10, 18, 28]
  Batch sizes: [4, 5, 8, 9, 16, 17]
  Max tokens: 20
  Reference file: A100_reference.json

Loading model...
  flash_attn 2.8.3 available


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

Found 1 PDF(s)
  Loading: /workspace/Verification-for-International-AI-Governance.pdf
    → 120214 tokens
Total tokens: 120214
Creating 51 slices of 512 tokens each
Created 3 reference sequences



config.json:   0%|          | 0.00/663 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json: 0.00B [00:00, ?B/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/3.95G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/3.86G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/3.86G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/3.56G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/243 [00:00<?, ?B/s]

✓ Model loaded (attn_implementation=flash_attention_2)

Loading reference file...
Reference GPU: NVIDIA A100 80GB PCIe
Verifier GPU:  NVIDIA H100 PCIe

ENVIRONMENT VALIDATION

Script configuration:
  ✓ attn_implementation: flash_attention_2

Container-level dependencies:
  ✓ python_version: 3.12.3
  ✓ cuda_version: 12.8
  ✓ cudnn_version: 91002

Pip-installable packages:
  ✓ torch_version: 2.8.0+cu128
  ✓ transformers_version: 4.57.2
  ✓ numpy_version: 2.1.2
  ✓ flash_attn_version: 2.8.3

Expected differences (hardware):
  ✓ gpu_name: reference=NVIDIA A100 80GB PCIe, verifier=NVIDIA H100 PCIe
  ✓ hostname: reference=c81492e44ce1, verifier=06768f39097d

------------------------------------------------------------
✓ ENVIRONMENT VALIDATION PASSED
  All critical software versions match.
  Only hardware differs - results will be meaningful.

✓ Model, batch sizes, and layer indices match


REFERENCE: ref_0

  Claimed batch size: 4
    Verify bs=4 (diagonal):      Prompt: 512, Gen: 20, exact 