In [None]:
#!/usr/bin/env python3
"""
Cross-hardware attention implementation detectability experiment.
Compares eager, sdpa, and flash_attention_2.
"""

import torch
import torch.nn.functional as F
import numpy as np
import json
import os
import sys
import glob
import socket
import platform
from datetime import datetime
from transformers import AutoModelForCausalLM, AutoTokenizer
import PyPDF2

# ============================================================================
# CONFIGURATION
# ============================================================================

# Mode switch read by main(): True runs verification (teacher forcing against
# REFERENCE_FILE), False runs generation and writes a new reference JSON.
TEACHER_FORCING = True
# JSON produced by a previous generation-mode run; only read in verification mode.
REFERENCE_FILE = "A100_generate_10000in_300out.json"

# Hugging Face model id used for both the tokenizer and the model.
MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"
# Values passed as `attn_implementation` when loading the model; every entry is
# loaded in turn.
ATTN_IMPLEMENTATIONS = ['eager', 'sdpa', 'flash_attention_2']

# Indices into `outputs.past_key_values` whose key states are recorded
# (-1 selects the last entry).
LAYER_INDICES = [-1]
# Number of document tokens placed in each prompt.
TOKENS_PER_SLICE = 10000
# Upper bound on generated tokens per prompt (the token produced by the prefill
# pass counts as the first one).
MAX_NEW_TOKENS = 300
# Number of prompts built from the source document(s).
NUM_REFERENCES = 6

# Averaged key-vector distance below which two attention implementations are
# reported as equivalent in the within-hardware analysis.
EQUIVALENCE_THRESHOLD = 1e-9

# System message inserted into the chat-formatted prompt prefix.
SYSTEM_PROMPT = "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."

# ============================================================================
# LOGGING
# ============================================================================

# Open log file handle shared by setup_logging / log_print / close_logging;
# None while no log file is open.
LOG_FILE = None

def setup_logging(output_dir='/workspace/experiments'):
    """Open a timestamped log file and register it as the global LOG_FILE.

    The file name encodes the current mode ("verify" when TEACHER_FORCING is
    set, "generate" otherwise) and the current local time.

    Args:
        output_dir: Directory for the log file; created if missing.

    Returns:
        Path of the log file that was opened.
    """
    global LOG_FILE
    # Calling this twice (e.g. re-running a notebook cell) must not leak the
    # handle opened by the previous call.
    if LOG_FILE:
        LOG_FILE.close()
        LOG_FILE = None

    os.makedirs(output_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    mode = "verify" if TEACHER_FORCING else "generate"
    log_path = os.path.join(output_dir, f"attn_experiment_{mode}_{timestamp}.txt")
    # Explicit UTF-8: the script logs non-ASCII characters ("✓", "✗", "×"),
    # which would raise UnicodeEncodeError under a non-UTF-8 default encoding.
    LOG_FILE = open(log_path, 'w', encoding='utf-8')
    return log_path

def log_print(*args, **kwargs):
    """Print to the console and mirror the same output into the open log file.

    Accepts the same arguments as ``print``. A caller-supplied ``file``
    argument only affects the console copy; the mirrored copy always goes to
    LOG_FILE and is flushed immediately.
    """
    print(*args, **kwargs)
    if not LOG_FILE:
        return
    mirrored_kwargs = dict(kwargs)
    mirrored_kwargs.pop('file', None)
    print(*args, **mirrored_kwargs, file=LOG_FILE)
    LOG_FILE.flush()

def close_logging():
    """Close the global log file, if one is open, and reset LOG_FILE to None."""
    global LOG_FILE
    if not LOG_FILE:
        return
    LOG_FILE.close()
    LOG_FILE = None

# ============================================================================
# PDF LOADING
# ============================================================================

def load_pdf_text(pdf_path):
    """Return the text of every page of a PDF, joined by single spaces.

    Pages for which text extraction yields nothing are skipped; leading and
    trailing whitespace of the combined text is stripped.
    """
    with open(pdf_path, 'rb') as handle:
        extracted = (page.extract_text() for page in PyPDF2.PdfReader(handle).pages)
        # Materialise inside the `with` so pages are read while the file is open
        page_texts = [chunk for chunk in extracted if chunk]
    return " ".join(page_texts).strip()

def create_prompts(tokenizer, num_references=NUM_REFERENCES):
    """Build ``num_references`` chat-formatted prompts from local PDF text.

    PDFs are looked up in ``/workspace`` first and in the current directory as
    a fallback. Their text is tokenised once and cut into consecutive,
    non-overlapping slices of TOKENS_PER_SLICE tokens; each slice is wrapped
    between a fixed prefix (system + user preamble) and a fixed suffix
    (question + assistant header).

    Returns:
        A list of prompts, each a list of token ids.

    Raises:
        FileNotFoundError: no PDF was found in either location.
        ValueError: the PDFs contain fewer tokens than the slices require.
    """
    pdf_files = sorted(glob.glob("/workspace/*.pdf")) or sorted(glob.glob("*.pdf"))
    if not pdf_files:
        raise FileNotFoundError("No PDF files found")

    log_print(f"Found {len(pdf_files)} PDF(s)")
    for pdf_path in pdf_files:
        log_print(f"  Loading: {pdf_path}")

    # Each document's text is followed by one space before the next is appended
    all_text = "".join(load_pdf_text(pdf_path) + " " for pdf_path in pdf_files)

    content_tokens = tokenizer.encode(all_text, add_special_tokens=False)
    log_print(f"Total source tokens: {len(content_tokens)}")

    if len(content_tokens) < num_references * TOKENS_PER_SLICE:
        raise ValueError(f"Need {num_references * TOKENS_PER_SLICE} tokens but only have {len(content_tokens)}")

    prefix = f"""<|im_start|>system
{SYSTEM_PROMPT}<|im_end|>
<|im_start|>user
Here is an excerpt from a document:

\""""

    suffix = f""""

Based on this excerpt, what type of document do you think this is from, and what is its likely subject matter? Explain your reasoning.<|im_end|>
<|im_start|>assistant
"""

    # Prefix and suffix are tokenised separately from the snippet, so the
    # snippet tokens are inserted unchanged between them.
    prefix_tokens = tokenizer.encode(prefix, add_special_tokens=False)
    suffix_tokens = tokenizer.encode(suffix, add_special_tokens=False)

    total_len = len(prefix_tokens) + TOKENS_PER_SLICE + len(suffix_tokens)
    log_print(f"Prompt structure: {len(prefix_tokens)} prefix + {TOKENS_PER_SLICE} snippet + {len(suffix_tokens)} suffix = {total_len} tokens")

    slice_starts = (i * TOKENS_PER_SLICE for i in range(num_references))
    return [
        prefix_tokens + content_tokens[start:start + TOKENS_PER_SLICE] + suffix_tokens
        for start in slice_starts
    ]

# ============================================================================
# SYSTEM INFO
# ============================================================================

def collect_system_info(attn_impl):
    """Return a dict describing the current software/hardware environment.

    Records host, platform, library versions, the given attention
    implementation name and GPU details. CUDA-related fields fall back to
    "N/A" (or 0 for the GPU count) when CUDA is unavailable, and
    ``flash_attn_version`` is "N/A" when the package cannot be imported.
    """
    import transformers

    has_cuda = torch.cuda.is_available()

    if has_cuda:
        cuda_version = torch.version.cuda
        cudnn_version = str(torch.backends.cudnn.version())
        gpu_name = torch.cuda.get_device_name(0)
        gpu_count = torch.cuda.device_count()
    else:
        cuda_version = "N/A"
        cudnn_version = "N/A"
        gpu_name = "N/A"
        gpu_count = 0

    try:
        import flash_attn
    except ImportError:
        flash_attn_version = "N/A"
    else:
        flash_attn_version = flash_attn.__version__

    # Key order is kept stable because the dict is written to the results JSON
    return {
        "hostname": socket.gethostname(),
        "platform": platform.platform(),
        "python_version": sys.version.split()[0],
        "torch_version": torch.__version__,
        "cuda_version": cuda_version,
        "cudnn_version": cudnn_version,
        "transformers_version": transformers.__version__,
        "numpy_version": np.__version__,
        "attn_implementation": attn_impl,
        "gpu_name": gpu_name,
        "gpu_count": gpu_count,
        "flash_attn_version": flash_attn_version,
    }

# ============================================================================
# MODEL LOADING
# ============================================================================

def load_model(attn_impl):
    """Load MODEL_NAME in bfloat16 on CUDA with the given attention implementation.

    The model is switched to eval mode before being returned.
    """
    log_print(f"Loading model with attention: {attn_impl}")

    # NOTE(review): the captured run output shows a "`torch_dtype` is
    # deprecated" warning; the keyword is kept as-is so behaviour is unchanged.
    load_kwargs = dict(
        torch_dtype=torch.bfloat16,
        device_map="cuda",
        attn_implementation=attn_impl,
        trust_remote_code=True,
    )
    loaded = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **load_kwargs)
    loaded.eval()
    return loaded

def load_tokenizer():
    """Load the tokenizer for MODEL_NAME, using EOS as pad token if none is set."""
    tok = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    pad_is_missing = tok.pad_token is None
    if pad_is_missing:
        tok.pad_token = tok.eos_token
    return tok

# ============================================================================
# SIGNAL EXTRACTION
# ============================================================================

def extract_signals(outputs, layer_indices, position=-1):
    """Collect key vectors and the top-100 log-probabilities at one position.

    Args:
        outputs: Model output exposing ``past_key_values`` and ``logits``.
        layer_indices: Indices into ``outputs.past_key_values`` to record.
        position: Sequence index (may be negative) to read.

    Returns:
        ``{'key_vectors': {'layer_<idx>': [...]}, 'logprobs': {'token_ids':
        [...], 'log_probs': [...]}}`` with plain Python lists throughout.
    """
    key_vectors = {}
    for layer_idx in layer_indices:
        layer_keys = outputs.past_key_values[layer_idx][0]
        # Batch entry 0 at `position`, all of axes 1 and 3, flattened to 1-D
        # (presumably heads x head_dim - confirm against the cache layout).
        flat = layer_keys[0, :, position, :].detach().float().cpu().numpy().flatten()
        key_vectors[f'layer_{layer_idx}'] = flat.tolist()

    position_log_probs = F.log_softmax(outputs.logits[0, position, :], dim=-1)
    best_values, best_ids = torch.topk(position_log_probs, k=100)

    return {
        'key_vectors': key_vectors,
        'logprobs': {
            'token_ids': best_ids.cpu().tolist(),
            'log_probs': best_values.cpu().tolist(),
        },
    }

def extract_prefill_signals(outputs, layer_indices, positions=(-3, -2, -1)):
    """Extract signals at several sequence positions of a prefill pass.

    Args:
        outputs: Model output forwarded to ``extract_signals``.
        layer_indices: Layers whose key vectors are recorded.
        positions: Sequence indices to sample. The default is an immutable
            tuple so the shared default object can never be mutated.

    Returns:
        Dict mapping ``'pos_<position>'`` to the result of ``extract_signals``
        for that position.
    """
    return {
        f"pos_{pos}": extract_signals(outputs, layer_indices, position=pos)
        for pos in positions
    }

def extract_signals_for_token_ids(outputs, layer_indices, token_ids, position=-1):
    """Collect key vectors and the log-probabilities of given tokens at one position.

    Unlike ``extract_signals`` this does not pick the top-k tokens itself; it
    looks up the log-probabilities of exactly ``token_ids`` (used during
    verification so both sides report the same tokens).

    Returns:
        ``{'key_vectors': {'layer_<idx>': [...]}, 'logprobs': {'token_ids':
        token_ids, 'log_probs': [...]}}``; ``token_ids`` is returned as passed.
    """
    key_vectors = {}
    for layer_idx in layer_indices:
        layer_keys = outputs.past_key_values[layer_idx][0]
        # Batch entry 0 at `position`, flattened to a 1-D list
        flat = layer_keys[0, :, position, :].detach().float().cpu().numpy().flatten()
        key_vectors[f'layer_{layer_idx}'] = flat.tolist()

    position_logits = outputs.logits[0, position, :]
    position_log_probs = F.log_softmax(position_logits, dim=-1)
    lookup = torch.tensor(token_ids, device=position_logits.device)

    return {
        'key_vectors': key_vectors,
        'logprobs': {
            'token_ids': token_ids,
            'log_probs': position_log_probs[lookup].cpu().tolist(),
        },
    }

def extract_prefill_signals_for_token_ids(outputs, layer_indices, ref_prefill_signals, positions=(-3, -2, -1)):
    """Extract prefill signals using the token ids recorded in a reference.

    Args:
        outputs: Model output forwarded to ``extract_signals_for_token_ids``.
        layer_indices: Layers whose key vectors are recorded.
        ref_prefill_signals: Reference dict keyed by ``'pos_<position>'``; each
            entry supplies ``['logprobs']['token_ids']``.
        positions: Sequence indices to sample. The default is an immutable
            tuple so the shared default object can never be mutated.

    Returns:
        Dict keyed by ``'pos_<position>'``. Positions that are absent from the
        reference are skipped.
    """
    prefill_signals = {}
    for pos in positions:
        pos_label = f"pos_{pos}"
        if pos_label not in ref_prefill_signals:
            continue
        ref_token_ids = ref_prefill_signals[pos_label]['logprobs']['token_ids']
        prefill_signals[pos_label] = extract_signals_for_token_ids(
            outputs, layer_indices, ref_token_ids, position=pos
        )
    return prefill_signals

# ============================================================================
# GENERATION MODE
# ============================================================================

def run_generation(model, tokenizer, prompt_ids, layer_indices):
    """Greedy-decode from a prompt and record prefill and decode signals.

    One prefill pass over the whole prompt produces the first token; further
    tokens are decoded one at a time with the KV cache until MAX_NEW_TOKENS
    tokens exist or EOS is produced.

    Args:
        model: Causal LM called as ``model(input_ids=..., attention_mask=...,
            past_key_values=..., use_cache=True)``.
        tokenizer: Only ``eos_token_id`` is read.
        prompt_ids: Prompt as a list of token ids.
        layer_indices: Layers whose key vectors are recorded.

    Returns:
        Dict with ``prompt_ids``, ``generated_ids``, ``prompt_length``,
        ``prefill_signals`` (last three prompt positions), ``decode_signals``
        (the last up-to-three decode steps keyed ``pos_-3``..``pos_-1``) and
        ``num_generated``.
    """
    torch.cuda.empty_cache()

    input_ids = torch.tensor([prompt_ids], dtype=torch.long, device='cuda')
    attention_mask = torch.ones_like(input_ids)
    prompt_length = input_ids.shape[1]

    def extend_mask(mask):
        # Append one column with the mask's own dtype/device; a default
        # (float32) torch.ones would silently promote the integer mask to float.
        extra = torch.ones((1, 1), dtype=mask.dtype, device=mask.device)
        return torch.cat([mask, extra], dim=1)

    generated_ids = []
    generation_signals = []

    # PREFILL
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True)

    past_kv = outputs.past_key_values
    prefill_signals = extract_prefill_signals(outputs, layer_indices, positions=[-3, -2, -1])

    # First token comes from the last prompt position of the prefill pass
    next_token = outputs.logits[:, -1, :].argmax(dim=-1)
    generated_ids.append(next_token[0].item())
    generation_signals.append({
        'step': 0,
        'absolute_position': prompt_length - 1,
        'signals': extract_signals(outputs, layer_indices, position=-1)
    })
    attention_mask = extend_mask(attention_mask)

    # DECODE
    for step in range(1, MAX_NEW_TOKENS):
        # Checked before each step so that an EOS produced by the prefill pass
        # stops generation as well, not only an EOS produced inside this loop.
        if generated_ids[-1] == tokenizer.eos_token_id:
            break

        with torch.no_grad():
            outputs = model(
                input_ids=next_token.unsqueeze(1),
                attention_mask=attention_mask,
                past_key_values=past_kv,
                use_cache=True,
            )

        past_kv = outputs.past_key_values
        next_token = outputs.logits[:, -1, :].argmax(dim=-1)
        generated_ids.append(next_token[0].item())

        current_cache_length = past_kv[0][0].shape[2]
        generation_signals.append({
            'step': step,
            'absolute_position': current_cache_length - 1,
            'signals': extract_signals(outputs, layer_indices, position=-1)
        })
        attention_mask = extend_mask(attention_mask)

    # Keep only the last (up to) three decode steps, labelled relative to the end
    num_generated = len(generation_signals)
    tail = generation_signals[-3:]
    labels = ['pos_-3', 'pos_-2', 'pos_-1']
    last_3_signals = dict(zip(labels[len(labels) - len(tail):], tail))

    del outputs, past_kv
    torch.cuda.empty_cache()

    return {
        'prompt_ids': prompt_ids,
        'generated_ids': generated_ids,
        'prompt_length': prompt_length,
        'prefill_signals': prefill_signals,
        'decode_signals': last_3_signals,
        'num_generated': num_generated
    }

# ============================================================================
# VERIFICATION MODE (TEACHER FORCING)
# ============================================================================

def run_teacher_forced(model, tokenizer, reference_data, layer_indices):
    """Replay a reference generation with teacher forcing and record signals.

    The reference prompt is prefilled, then the reference's generated tokens
    are fed back one at a time (the model's own predictions are never fed
    back). Log-probabilities are read for the token ids stored in the
    reference wherever the reference has an entry for that position; other
    steps fall back to the model's own top-k.

    Args:
        model: Causal LM called as in ``run_generation``.
        tokenizer: Unused; kept for signature parity with ``run_generation``.
        reference_data: One generation record with ``prompt_ids``,
            ``generated_ids``, ``prefill_signals`` and ``decode_signals``.
        layer_indices: Layers whose key vectors are recorded.

    Returns:
        Dict with ``prefill_signals``, ``decode_signals`` (last up-to-three
        steps keyed ``pos_-3``..``pos_-1``) and ``num_generated``.
    """
    torch.cuda.empty_cache()

    prompt_ids = reference_data['prompt_ids']
    ref_generated_ids = reference_data['generated_ids']
    ref_prefill_signals = reference_data['prefill_signals']
    ref_decode_signals = reference_data['decode_signals']
    num_steps = len(ref_generated_ids)

    def extend_mask(mask):
        # Append one column with the mask's own dtype/device; a default
        # (float32) torch.ones would silently promote the integer mask to float.
        extra = torch.ones((1, 1), dtype=mask.dtype, device=mask.device)
        return torch.cat([mask, extra], dim=1)

    def signals_for_step(step_outputs, step):
        # The reference only stores the last three steps, keyed by their
        # distance from the end. Look up the entry that really belongs to
        # `step` (this also covers step 0 when the reference has <= 3 steps).
        offset = num_steps - step
        ref_entry = ref_decode_signals.get(f'pos_-{offset}') if 1 <= offset <= 3 else None
        if ref_entry is None:
            return extract_signals(step_outputs, layer_indices, position=-1)
        ref_token_ids = ref_entry['signals']['logprobs']['token_ids']
        return extract_signals_for_token_ids(step_outputs, layer_indices, ref_token_ids, position=-1)

    input_ids = torch.tensor([prompt_ids], dtype=torch.long, device='cuda')
    attention_mask = torch.ones_like(input_ids)

    generation_signals = []

    # PREFILL
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True)

    past_kv = outputs.past_key_values
    prefill_signals = extract_prefill_signals_for_token_ids(
        outputs, layer_indices, ref_prefill_signals, positions=[-3, -2, -1]
    )

    # Step 0 is read from the last prompt position of the prefill pass
    generation_signals.append({
        'step': 0,
        'absolute_position': input_ids.shape[1] - 1,
        'signals': signals_for_step(outputs, 0)
    })
    attention_mask = extend_mask(attention_mask)

    # Teacher-forced decode: step k consumes the reference token from step k-1
    for step in range(1, num_steps):
        forced_token = torch.tensor([[ref_generated_ids[step - 1]]], dtype=torch.long, device='cuda')

        with torch.no_grad():
            outputs = model(
                input_ids=forced_token,
                attention_mask=attention_mask,
                past_key_values=past_kv,
                use_cache=True,
            )

        past_kv = outputs.past_key_values

        current_cache_length = past_kv[0][0].shape[2]
        generation_signals.append({
            'step': step,
            'absolute_position': current_cache_length - 1,
            'signals': signals_for_step(outputs, step)
        })
        attention_mask = extend_mask(attention_mask)

    # Keep only the last (up to) three decode steps, labelled relative to the end
    num_generated = len(generation_signals)
    tail = generation_signals[-3:]
    labels = ['pos_-3', 'pos_-2', 'pos_-1']
    last_3_signals = dict(zip(labels[len(labels) - len(tail):], tail))

    del outputs, past_kv
    torch.cuda.empty_cache()

    return {
        'prefill_signals': prefill_signals,
        'decode_signals': last_3_signals,
        'num_generated': num_generated
    }

# ============================================================================
# DISTANCE METRICS
# ============================================================================

def compute_key_distance(signals1, signals2, layer_indices):
    """Return the mean over layers of the L2 distance between key vectors.

    Both arguments are signal dicts whose ``'key_vectors'`` entry maps
    ``'layer_<idx>'`` to a flat list of floats.
    """
    per_layer = []
    for layer_idx in layer_indices:
        name = f'layer_{layer_idx}'
        delta = np.array(signals1['key_vectors'][name]) - np.array(signals2['key_vectors'][name])
        per_layer.append(np.linalg.norm(delta))
    return np.mean(per_layer)

def compute_logprob_distance(signals1, signals2):
    """Return the RMS difference of log-probs over tokens present in both signals.

    Returns ``float('inf')`` when the two signals share no token id.
    """
    first = signals1['logprobs']
    second = signals2['logprobs']

    lookup1 = dict(zip(first['token_ids'], first['log_probs']))
    lookup2 = dict(zip(second['token_ids'], second['log_probs']))

    shared = set(first['token_ids']) & set(second['token_ids'])
    if len(shared) == 0:
        return float('inf')

    squared = [(lookup1[tok] - lookup2[tok])**2 for tok in shared]
    return np.sqrt(np.mean(squared))

def compare_signals(signals1, signals2, layer_indices):
    """Average key-vector and log-prob distances over positions shared by both inputs.

    Each input maps a position label to either a signal dict or a wrapper with
    the signal dict under ``'signals'``. Infinite log-prob distances are left
    out of the log-prob average.

    Returns:
        ``{'key_vectors_mean': ..., 'logprobs_mean': ...}``; the key mean is 0
        when no position is shared, the log-prob mean is ``inf`` when no
        finite distance exists.
    """
    def unwrap(entry):
        # Decode entries wrap the signals together with step metadata
        return entry['signals'] if 'signals' in entry else entry

    shared_positions = set(signals1.keys()) & set(signals2.keys())
    pairs = [(unwrap(signals1[pos]), unwrap(signals2[pos])) for pos in shared_positions]

    key_distances = [compute_key_distance(a, b, layer_indices) for a, b in pairs]
    logprob_distances = [compute_logprob_distance(a, b) for a, b in pairs]
    finite_logprobs = [d for d in logprob_distances if d != float('inf')]

    key_mean = np.mean(key_distances) if key_distances else 0
    logprob_mean = np.mean(finite_logprobs) if finite_logprobs else float('inf')
    return {
        'key_vectors_mean': key_mean,
        'logprobs_mean': logprob_mean
    }

# ============================================================================
# WITHIN-HARDWARE ANALYSIS
# ============================================================================

def analyze_within_hardware(measurements, attn_types, layer_indices, signal_source='decode'):
    """Compare attention implementations against each other on one machine.

    For every reference prompt a pairwise distance matrix (key vectors and
    log-probs) is logged, followed by the matrices averaged over references,
    off-diagonal means and the implementation pairs whose averaged key
    distance is below EQUIVALENCE_THRESHOLD.

    Args:
        measurements: ``{attn_type: [per-reference dict with 'prefill_signals'
            and 'decode_signals']}``.
        attn_types: Implementations to compare; defines row/column order.
        layer_indices: Layers used for key-vector distances.
        signal_source: ``'prefill'`` selects prefill signals; any other value
            selects decode signals.

    Returns:
        Dict with ``key_matrix`` and ``logprob_matrix`` (nested lists) and
        ``equivalent_pairs``.

    Raises:
        ValueError: no reference is available for every implementation.
    """
    log_print("\n" + "="*80)
    log_print(f"WITHIN-HARDWARE ATTENTION EFFECTS ({signal_source.upper()})")
    log_print("="*80)

    n = len(attn_types)
    signals_key = 'prefill_signals' if signal_source == 'prefill' else 'decode_signals'
    header = "           " + " ".join(f"{a:>10}" for a in attn_types)

    def log_matrix(title, matrix):
        # One labelled row per implementation, columns in attn_types order
        log_print(f"\n{title}")
        log_print(header)
        for i, attn in enumerate(attn_types):
            row = f"{attn:>10} " + " ".join(f"{matrix[i,j]:10.2e}" for j in range(n))
            log_print(row)

    # Never index past the measurements that actually exist; NUM_REFERENCES
    # stays the upper bound so a normal run is unchanged.
    available = min((len(measurements[a]) for a in attn_types), default=0)
    num_refs = min(NUM_REFERENCES, available)
    if num_refs == 0:
        raise ValueError("No reference measurements available for within-hardware analysis")
    if num_refs < NUM_REFERENCES:
        log_print(f"\nWARNING: only {num_refs} of {NUM_REFERENCES} references available")

    all_key_matrices = []
    all_logprob_matrices = []

    for ref_idx in range(num_refs):
        log_print(f"\n--- ref_{ref_idx} ---")

        matrix_key = np.zeros((n, n))
        matrix_logprob = np.zeros((n, n))

        for i, attn_i in enumerate(attn_types):
            sig_i = measurements[attn_i][ref_idx][signals_key]
            for j, attn_j in enumerate(attn_types):
                sig_j = measurements[attn_j][ref_idx][signals_key]
                distances = compare_signals(sig_i, sig_j, layer_indices)
                matrix_key[i, j] = distances['key_vectors_mean']
                matrix_logprob[i, j] = distances['logprobs_mean']

        log_matrix("Key Vectors (L2 distance):", matrix_key)
        log_matrix("Logprobs (L2 distance):", matrix_logprob)

        all_key_matrices.append(matrix_key)
        all_logprob_matrices.append(matrix_logprob)

    # Aggregate
    avg_key_matrix = np.mean(all_key_matrices, axis=0)
    avg_logprob_matrix = np.mean(all_logprob_matrices, axis=0)

    log_print(f"\nAGGREGATE (average across references):")
    log_matrix("Key Vectors (L2 distance):", avg_key_matrix)
    log_matrix("Logprobs (L2 distance):", avg_logprob_matrix)

    # Off-diagonal stats
    off_diag_key = [avg_key_matrix[i, j] for i in range(n) for j in range(n) if i != j]
    off_diag_logprob = [
        avg_logprob_matrix[i, j]
        for i in range(n) for j in range(n)
        if i != j and avg_logprob_matrix[i, j] != float('inf')
    ]
    # With no finite entry report inf (as analyze_cross_hardware does) instead
    # of taking the mean of an empty list, which yields NaN plus a warning.
    off_diag_logprob_mean = np.mean(off_diag_logprob) if off_diag_logprob else float('inf')

    log_print(f"\nOff-diagonal stats:")
    log_print(f"  Key vectors - Mean: {np.mean(off_diag_key):.2e}")
    log_print(f"  Logprobs - Mean: {off_diag_logprob_mean:.2e}")

    # Check equivalences (upper triangle only; each unordered pair once)
    equiv_pairs = [
        (attn_types[i], attn_types[j])
        for i in range(n) for j in range(i + 1, n)
        if avg_key_matrix[i, j] < EQUIVALENCE_THRESHOLD
    ]

    if equiv_pairs:
        log_print(f"\nEquivalent pairs:")
        for p in equiv_pairs:
            log_print(f"  {p}")

    return {
        'key_matrix': avg_key_matrix.tolist(),
        'logprob_matrix': avg_logprob_matrix.tolist(),
        'equivalent_pairs': equiv_pairs
    }

# ============================================================================
# CROSS-HARDWARE ANALYSIS
# ============================================================================

def analyze_cross_hardware(gen_measurements, ver_measurements, attn_types, layer_indices, signal_source='decode'):
    """Compare generator signals with verifier signals for every attention pairing.

    Entry (i, j) of each matrix is the distance between the signals recorded
    by the generator under ``attn_types[i]`` ("claimed") and the signals the
    verifier reproduced for that same generation under ``attn_types[j]``
    ("verified"), averaged over references. The diagonal is therefore the
    same-implementation baseline, and the SNR is the off-diagonal mean divided
    by the diagonal mean.

    Args:
        gen_measurements: ``{attn_type: [per-reference signal dicts]}``.
        ver_measurements: ``{(claimed, verified): [per-reference signal dicts]}``.
        attn_types: Implementations; defines row/column order.
        layer_indices: Layers used for key-vector distances.
        signal_source: ``'prefill'`` selects prefill signals; any other value
            selects decode signals.

    Returns:
        Dict with the averaged matrices (nested lists), diagonal and
        off-diagonal means and both SNR values.

    Raises:
        ValueError: no reference is available for every pairing.
    """
    log_print("\n" + "="*80)
    log_print(f"CROSS-HARDWARE VERIFICATION ({signal_source.upper()})")
    log_print("="*80)

    n = len(attn_types)
    signals_key = 'prefill_signals' if signal_source == 'prefill' else 'decode_signals'
    header = "           " + " ".join(f"{a:>10}" for a in attn_types)

    def log_matrix(title, matrix):
        # One labelled row per claimed implementation
        log_print(f"\n{title}")
        log_print(header)
        for i, attn in enumerate(attn_types):
            row = f"{attn:>10} " + " ".join(f"{matrix[i,j]:10.2e}" for j in range(n))
            log_print(row)

    # Never index past the measurements that actually exist; NUM_REFERENCES
    # stays the upper bound so a normal run is unchanged.
    lengths = [len(gen_measurements[a]) for a in attn_types]
    lengths += [len(ver_measurements[(c, v)]) for c in attn_types for v in attn_types]
    num_refs = min(NUM_REFERENCES, min(lengths, default=0))
    if num_refs == 0:
        raise ValueError("No reference measurements available for cross-hardware analysis")
    if num_refs < NUM_REFERENCES:
        log_print(f"\nWARNING: only {num_refs} of {NUM_REFERENCES} references available")

    all_key_matrices = []
    all_logprob_matrices = []

    for ref_idx in range(num_refs):
        matrix_key = np.zeros((n, n))
        matrix_logprob = np.zeros((n, n))

        for i, claimed in enumerate(attn_types):
            gen_sig = gen_measurements[claimed][ref_idx][signals_key]
            for j, verified in enumerate(attn_types):
                ver_sig = ver_measurements[(claimed, verified)][ref_idx][signals_key]
                distances = compare_signals(gen_sig, ver_sig, layer_indices)
                matrix_key[i, j] = distances['key_vectors_mean']
                matrix_logprob[i, j] = distances['logprobs_mean']

        all_key_matrices.append(matrix_key)
        all_logprob_matrices.append(matrix_logprob)

    avg_key_matrix = np.mean(all_key_matrices, axis=0)
    avg_logprob_matrix = np.mean(all_logprob_matrices, axis=0)

    log_print(f"\nAGGREGATE (average across references):")
    log_print("  Rows = claimed attention, Cols = verified attention")
    log_matrix("Key Vectors (L2 distance):", avg_key_matrix)
    log_matrix("Logprobs (L2 distance):", avg_logprob_matrix)

    # SNR calculation
    diagonal_key = np.mean([avg_key_matrix[i, i] for i in range(n)])
    finite_diag_logprob = [
        avg_logprob_matrix[i, i] for i in range(n)
        if avg_logprob_matrix[i, i] != float('inf')
    ]
    # NaN when no finite baseline exists (same value the mean of an empty
    # list would give, without the RuntimeWarning); the SNR then becomes inf.
    diagonal_logprob = np.mean(finite_diag_logprob) if finite_diag_logprob else float('nan')

    off_diag_key = [avg_key_matrix[i, j] for i in range(n) for j in range(n) if i != j]
    off_diag_logprob = [
        avg_logprob_matrix[i, j]
        for i in range(n) for j in range(n)
        if i != j and avg_logprob_matrix[i, j] != float('inf')
    ]

    off_diagonal_key = np.mean(off_diag_key)
    off_diagonal_logprob = np.mean(off_diag_logprob) if off_diag_logprob else float('inf')

    key_snr = off_diagonal_key / diagonal_key if diagonal_key > 0 else float('inf')
    logprob_snr = off_diagonal_logprob / diagonal_logprob if diagonal_logprob > 0 else float('inf')

    log_print(f"\n" + "="*80)
    log_print("SNR ANALYSIS")
    log_print("="*80)
    log_print(f"\nDiagonal (baseline = cross-hardware, same attention):")
    log_print(f"  Key vectors: {diagonal_key:.2e}")
    log_print(f"  Logprobs: {diagonal_logprob:.2e}")
    log_print(f"\nOff-diagonal (different attention):")
    log_print(f"  Key vectors - Mean: {off_diagonal_key:.2e}, SNR: {key_snr:.2f}×")
    log_print(f"  Logprobs - Mean: {off_diagonal_logprob:.2e}, SNR: {logprob_snr:.2f}×")

    return {
        'key_matrix': avg_key_matrix.tolist(),
        'logprob_matrix': avg_logprob_matrix.tolist(),
        'diagonal_key': diagonal_key,
        'diagonal_logprob': diagonal_logprob,
        'off_diagonal_key': off_diagonal_key,
        'off_diagonal_logprob': off_diagonal_logprob,
        'key_snr': key_snr,
        'logprob_snr': logprob_snr
    }

# ============================================================================
# MAIN
# ============================================================================

def _write_results_json(payload, file_prefix):
    """Write ``payload`` to a timestamped JSON file under /workspace/experiments."""
    output_dir = '/workspace/experiments'
    os.makedirs(output_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filepath = os.path.join(output_dir, f"{file_prefix}_{timestamp}.json")

    with open(filepath, 'w') as f:
        json.dump(payload, f, indent=2)

    log_print(f"\nSaved to: {filepath}")
    return filepath


def _run_generation_mode(tokenizer):
    """Generate with every attention implementation and save a reference file."""
    prompts = create_prompts(tokenizer, NUM_REFERENCES)
    log_print(f"\nCreated {len(prompts)} prompts")

    log_print("\n" + "=" * 80)
    log_print("GENERATION MODE")
    log_print("=" * 80)

    results = {
        'metadata': {
            'model': MODEL_NAME,
            'layer_indices': LAYER_INDICES,
            'tokens_per_slice': TOKENS_PER_SLICE,
            'max_new_tokens': MAX_NEW_TOKENS,
            'num_references': NUM_REFERENCES,
            'environments': {}
        },
        'generations': {}
    }

    measurements = {}

    for attn_impl in ATTN_IMPLEMENTATIONS:
        log_print(f"\n--- Attention: {attn_impl} ---")

        model = load_model(attn_impl)
        results['metadata']['environments'][attn_impl] = collect_system_info(attn_impl)

        results['generations'][attn_impl] = []
        measurements[attn_impl] = []

        for ref_idx, prompt_ids in enumerate(prompts):
            log_print(f"  ref_{ref_idx}: ", end="")
            gen_data = run_generation(model, tokenizer, prompt_ids, LAYER_INDICES)

            results['generations'][attn_impl].append({
                'ref_idx': ref_idx,
                'prompt_ids': gen_data['prompt_ids'],
                'generated_ids': gen_data['generated_ids'],
                'prompt_length': gen_data['prompt_length'],
                'prefill_signals': gen_data['prefill_signals'],
                'decode_signals': gen_data['decode_signals'],
                'num_generated': gen_data['num_generated']
            })

            measurements[attn_impl].append({
                'prefill_signals': gen_data['prefill_signals'],
                'decode_signals': gen_data['decode_signals']
            })

            log_print(f"{gen_data['num_generated']} tokens")
            log_print(f"    -> {tokenizer.decode(gen_data['generated_ids'][:20])}...")

        # Free the GPU before the next implementation is loaded
        del model
        torch.cuda.empty_cache()

    # Within-hardware analysis
    results['within_hardware'] = {
        'prefill': analyze_within_hardware(measurements, ATTN_IMPLEMENTATIONS, LAYER_INDICES, 'prefill'),
        'decode': analyze_within_hardware(measurements, ATTN_IMPLEMENTATIONS, LAYER_INDICES, 'decode')
    }

    _write_results_json(results, "attn_generate")
    log_print(f"Next: Copy to verifier, set TEACHER_FORCING=True, set REFERENCE_FILE")


def _run_verification_mode(tokenizer):
    """Replay every reference generation under every attention implementation."""
    log_print("\n" + "=" * 80)
    log_print("VERIFICATION MODE")
    log_print("=" * 80)

    # Python's json module parses Infinity/-Infinity/NaN natively, so the file
    # is loaded directly; rewriting the raw text could corrupt string values.
    with open(REFERENCE_FILE, 'r') as f:
        reference = json.load(f)

    log_print(f"Loaded reference: {REFERENCE_FILE}")

    # Environment validation (uses the first implementation's record)
    ref_env = reference['metadata']['environments'][ATTN_IMPLEMENTATIONS[0]]
    ver_env = collect_system_info(ATTN_IMPLEMENTATIONS[0])

    log_print("\nEnvironment comparison:")
    log_print(f"  Generator GPU: {ref_env['gpu_name']}")
    log_print(f"  Verifier GPU:  {ver_env['gpu_name']}")

    for field in ['torch_version', 'transformers_version', 'cuda_version']:
        ref_val = ref_env.get(field, 'N/A')
        ver_val = ver_env.get(field, 'N/A')
        match = "✓" if ref_val == ver_val else "✗"
        log_print(f"  {field}: {ref_val} vs {ver_val} {match}")

    # Store generation signals
    gen_measurements = {
        attn_impl: [
            {
                'prefill_signals': gen_data['prefill_signals'],
                'decode_signals': gen_data['decode_signals']
            }
            for gen_data in reference['generations'][attn_impl]
        ]
        for attn_impl in ATTN_IMPLEMENTATIONS
    }

    # Run verification for each combination; the outer loop is the verifier's
    # implementation so each model is loaded only once.
    ver_measurements = {}

    for verify_attn in ATTN_IMPLEMENTATIONS:
        log_print(f"\n--- Verifying with: {verify_attn} ---")

        model = load_model(verify_attn)

        for claimed_attn in ATTN_IMPLEMENTATIONS:
            log_print(f"  Claimed {claimed_attn}:")

            ver_measurements[(claimed_attn, verify_attn)] = []

            for ref_idx, gen_data in enumerate(reference['generations'][claimed_attn]):
                log_print(f"    ref_{ref_idx}: ", end="")

                ver_data = run_teacher_forced(model, tokenizer, gen_data, LAYER_INDICES)

                ver_measurements[(claimed_attn, verify_attn)].append({
                    'prefill_signals': ver_data['prefill_signals'],
                    'decode_signals': ver_data['decode_signals']
                })
                log_print("done")

        del model
        torch.cuda.empty_cache()

    # Cross-hardware analysis
    prefill_analysis = analyze_cross_hardware(
        gen_measurements, ver_measurements, ATTN_IMPLEMENTATIONS, LAYER_INDICES, 'prefill'
    )
    decode_analysis = analyze_cross_hardware(
        gen_measurements, ver_measurements, ATTN_IMPLEMENTATIONS, LAYER_INDICES, 'decode'
    )

    verify_results = {
        'metadata': {
            'reference_file': REFERENCE_FILE,
            'reference_environments': reference['metadata']['environments'],
            'verifier_environments': {
                attn: collect_system_info(attn) for attn in ATTN_IMPLEMENTATIONS
            }
        },
        'cross_hardware': {
            'prefill': prefill_analysis,
            'decode': decode_analysis
        }
    }

    _write_results_json(verify_results, "attn_verify")


def main():
    """Run the experiment in the mode selected by TEACHER_FORCING.

    Generation mode builds prompts from local PDFs, generates with every
    attention implementation and writes a reference JSON. Verification mode
    replays the generations stored in REFERENCE_FILE and writes the
    cross-hardware comparison; it takes its prompts from the reference file,
    so it does not need the source PDFs.
    """
    setup_logging()
    try:
        log_print("=" * 80)
        log_print("ATTENTION IMPLEMENTATION CROSS-HARDWARE EXPERIMENT")
        log_print("=" * 80)

        gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "N/A"
        log_print(f"GPU: {gpu_name}")
        log_print(f"Hostname: {socket.gethostname()}")

        tokenizer = load_tokenizer()

        if TEACHER_FORCING:
            _run_verification_mode(tokenizer)
        else:
            _run_generation_mode(tokenizer)
    finally:
        # Close the log file even when a run fails part-way through
        close_logging()

# Run the experiment only when executed as a script, not when imported.
if __name__ == "__main__":
    main()

ATTENTION IMPLEMENTATION CROSS-HARDWARE EXPERIMENT
GPU: NVIDIA H100 NVL
Hostname: d013d4f0b0a1
Found 1 PDF(s)
  Loading: /workspace/Verification-for-International-AI-Governance.pdf


`torch_dtype` is deprecated! Use `dtype` instead!


Total source tokens: 120215
Prompt structure: 33 prefix + 10000 snippet + 34 suffix = 10067 tokens

Created 6 prompts

VERIFICATION MODE
Loaded reference: A100_generate_10000in_300out.json

Environment comparison:
  Generator GPU: NVIDIA A100-SXM4-80GB
  Verifier GPU:  NVIDIA H100 NVL
  torch_version: 2.8.0+cu128 vs 2.8.0+cu128 ✓
  transformers_version: 4.57.3 vs 4.57.3 ✓
  cuda_version: 12.8 vs 12.8 ✓

--- Verifying with: eager ---
Loading model with attention: eager


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

  Claimed eager:
    ref_0: done
    ref_1: done
    ref_2: done
    ref_3: done
    ref_4: done
    ref_5: done
  Claimed sdpa:
    ref_0: done
    ref_1: done
    ref_2: done
    ref_3: done
    ref_4: done
    ref_5: done
  Claimed flash_attention_2:
    ref_0: done
    ref_1: done
    ref_2: done
    ref_3: done
    ref_4: done
    ref_5: done

--- Verifying with: sdpa ---
Loading model with attention: sdpa


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

  Claimed eager:
    ref_0: done
    ref_1: done
    ref_2: done
    ref_3: done
    ref_4: done
    ref_5: done
  Claimed sdpa:
    ref_0: done
    ref_1: done
    ref_2: done
    ref_3: done
    ref_4: done
    ref_5: done
  Claimed flash_attention_2:
    ref_0: done
    ref_1: done
    ref_2: done
    ref_3: done
    ref_4: done
    ref_5: done

--- Verifying with: flash_attention_2 ---
Loading model with attention: flash_attention_2


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

  Claimed eager:
    ref_0: done
    ref_1: done
    ref_2: done
    ref_3: done
    ref_4: done
    ref_5: done
  Claimed sdpa:
    ref_0: done
    ref_1: done
    ref_2: done
    ref_3: done
    ref_4: done
    ref_5: done
  Claimed flash_attention_2:
    ref_0: done
    ref_1: done
    ref_2: 