In [3]:
#!/usr/bin/env python3
"""
Experiment 1: Prefill + Concurrent Prefill
Tests: High AI + High AI interference

Uses very long prompts to ensure high compute utilization.
Same methodology as exp2_decode_timeline.py
"""

import os
os.environ['HF_HOME'] = '/workspace/huggingface_cache'

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import numpy as np
from datetime import datetime
import json
import time
import subprocess
import threading
import socket
from collections import defaultdict

# ============================================================================
# ATTESTATION
# ============================================================================

def attest_system():
    """Print a short system banner and return an attestation record.

    The returned dict holds the current local timestamp (ISO format), the
    hostname, the PyTorch and CUDA versions reported by ``torch``, and the
    name and compute capability of CUDA device 0.
    """
    rule = "=" * 70
    print(rule)
    print("SYSTEM ATTESTATION")
    print(rule)

    # Build the record field by field, in the same order the report lists them.
    attestation = {}
    attestation["timestamp"] = datetime.now().isoformat()
    attestation["hostname"] = socket.gethostname()
    attestation["pytorch"] = {
        "version": torch.__version__,
        "cuda_version": torch.version.cuda,
    }

    capability = torch.cuda.get_device_capability(0)
    attestation["gpu"] = {
        "name": torch.cuda.get_device_name(0),
        "capability": f"{capability[0]}.{capability[1]}",
    }

    gpu_name = attestation["gpu"]["name"]
    torch_version = attestation["pytorch"]["version"]
    cuda_version = attestation["pytorch"]["cuda_version"]
    print(f"GPU: {gpu_name}")
    print(f"PyTorch: {torch_version}")
    print(f"CUDA: {cuda_version}")
    print()
    return attestation

class GPUMonitor:
    """Background sampler of GPU utilization via ``nvidia-smi``.

    Between ``start()`` and ``stop()`` a daemon thread runs ``nvidia-smi``
    roughly every ``interval`` seconds and appends one
    ``{'gpu': float, 'memory': float}`` sample per successful poll to
    ``self.samples``. Sampling is best-effort: a failed poll is counted in
    ``self.errors`` and skipped, it never stops the experiment.
    """

    _QUERY_CMD = [
        'nvidia-smi', '--query-gpu=utilization.gpu,utilization.memory',
        '--format=csv,noheader,nounits',
    ]

    def __init__(self, interval=0.1):
        self.interval = interval
        self.running = False
        self.thread = None
        self.samples = []
        self.errors = 0  # number of polls that failed or could not be parsed

    @staticmethod
    def _parse_sample(output):
        """Turn ``nvidia-smi`` CSV output into one sample dict.

        ``nvidia-smi`` prints one line per GPU, so on multi-GPU hosts only
        the first line (the first GPU it lists) is used.

        Raises:
            ValueError: if the output is empty or not two numeric fields.
        """
        lines = output.strip().splitlines()
        if not lines:
            raise ValueError("empty nvidia-smi output")
        gpu_util, mem_util = map(float, lines[0].split(','))
        return {'gpu': gpu_util, 'memory': mem_util}

    def _monitor(self):
        """Polling loop executed by the background thread."""
        while self.running:
            try:
                output = subprocess.check_output(self._QUERY_CMD, encoding='utf-8')
                self.samples.append(self._parse_sample(output))
            except (subprocess.SubprocessError, OSError, ValueError):
                # Best-effort monitoring: nvidia-smi missing, exiting non-zero
                # or printing something unparsable (e.g. "[N/A]") only costs
                # us this sample.
                self.errors += 1
            time.sleep(self.interval)

    def start(self):
        """Discard earlier samples and start the background polling thread."""
        self.samples = []
        self.errors = 0
        self.running = True
        self.thread = threading.Thread(target=self._monitor, daemon=True)
        self.thread.start()

    def stop(self):
        """Stop polling and summarize the collected GPU-utilization samples.

        Returns:
            ``None`` when no sample was collected, otherwise a dict with
            ``gpu_mean``, ``gpu_p95``, ``gpu_max`` and the sample count
            ``samples``.
        """
        self.running = False
        if self.thread:
            # Bounded wait: the thread is a daemon, so a slow nvidia-smi call
            # cannot block shutdown.
            self.thread.join(timeout=2)

        if not self.samples:
            return None

        gpu_utils = [s['gpu'] for s in self.samples]
        return {
            'gpu_mean': np.mean(gpu_utils),
            'gpu_p95': np.percentile(gpu_utils, 95),
            'gpu_max': np.max(gpu_utils),
            'samples': len(self.samples)
        }

# ============================================================================
# CONFIGURATION
# ============================================================================

# Hugging Face model identifier passed to from_pretrained() in main().
MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"
# Layers whose keys are compared. extract_keys_prefill reads KV-cache entry
# (index - 1) for each value and skips 0, so these are 1-based layer numbers.
LAYER_INDICES = [1, 4, 10, 18, 28]
# NOTE(review): NUM_RUNS is not referenced anywhere in the visible code.
NUM_RUNS = 1  # Single run with verified parallel execution
SAMPLE_INTERVAL = 50  # Sample every 50th position (prompts are long)

# Very long prompts for high compute utilization during prefill.
# LONG_PROMPT_TEMPLATE is the base text of DEFAULT_PROMPT, the prompt that
# main() runs on the default stream for the baseline and both concurrent
# conditions.
LONG_PROMPT_TEMPLATE = """You are tasked with providing a comprehensive technical analysis of distributed database systems for a Fortune 500 company evaluating data infrastructure modernization. The company operates a legacy Oracle RAC database cluster with 500TB of structured data, 2PB of unstructured data in S3, serving 100,000 concurrent users across 50 global offices in 25 countries. They need to evaluate migration paths to modern distributed architectures supporting both OLTP and OLAP workloads with minimal downtime while maintaining ACID guarantees and supporting real-time analytics.

Current infrastructure: 12-node Oracle RAC cluster on bare metal with SAN storage, 50ms average query latency, 99.95% uptime SLA, backup RPO of 1 hour. Daily batch processing takes 8 hours using legacy ETL tools. The system processes 2 million transactions per day with peak loads of 50,000 TPS during business hours. Critical business applications include order management, inventory tracking, customer relationship management, financial reporting, and real-time fraud detection.

Technical constraints: Must maintain sub-100ms p99 latency for OLTP queries, support complex JOIN operations across 200+ tables, enable real-time materialized views for analytics, provide point-in-time recovery for compliance, support multi-region disaster recovery with RPO < 5 minutes, enable blue-green deployments for zero-downtime upgrades, maintain referential integrity across distributed transactions, support both SQL and NoSQL query patterns, enable real-time data replication to analytics warehouse.

Evaluation criteria: Total cost of ownership over 5 years including licensing, infrastructure, operations, migration costs. Performance benchmarks for OLTP (TPS, latency), OLAP (query response time, concurrency), mixed workloads. Operational complexity including monitoring, troubleshooting, capacity planning, disaster recovery procedures. Vendor ecosystem including tooling support, cloud integration, community resources. Migration strategy including data migration approach, application refactoring requirements, downtime windows, rollback procedures.

Modern architecture options to evaluate: (1) Amazon Aurora PostgreSQL with read replicas and Aurora Serverless v2 for variable workloads. (2) Google Cloud Spanner for globally distributed ACID transactions with TrueTime. (3) CockroachDB for distributed SQL with automatic sharding and rebalancing. (4) MongoDB Atlas with multi-region clusters and change streams. (5) Cassandra with Spark for analytics workloads. (6) Hybrid approach using PostgreSQL for OLTP with Snowflake for OLAP. (7) YugabyteDB for PostgreSQL compatibility with distributed architecture."""

# Create variations for concurrent stream
# CONCURRENT_PROMPT_TEMPLATE is the base text of CONCURRENT_PROMPT, which
# main() hands to ConcurrentStream both in full and truncated to the first
# third of its characters.
CONCURRENT_PROMPT_TEMPLATE = """You are providing expert analysis on cloud-native application architecture patterns for a global e-commerce platform processing $5B in annual revenue. The platform serves 50 million active users across web, mobile, and API channels with 99.99% availability requirements.

Current architecture: Monolithic Java application (2.5M lines of code) on 100 EC2 instances behind ELB, MySQL cluster with read replicas, Redis for caching, RabbitMQ for async processing. System handles 10M requests per day with average response time of 500ms. Critical services include product catalog, shopping cart, order processing, payment gateway integration, inventory management, recommendation engine, fraud detection, customer service portal.

Modernization drivers: Deployment velocity (currently 2-week release cycles, targeting daily deployments), operational costs (EC2 + RDS spending $2M annually), scaling challenges (Black Friday requires 10x capacity, manual scaling takes hours), developer productivity (build times 45 minutes, integration testing 4 hours), service reliability (cascading failures affect entire platform).

Architecture evaluation criteria: Operational complexity (service orchestration, observability, debugging distributed transactions), development velocity (microservice boundaries, API contracts, testing strategies), infrastructure costs (compute, storage, networking, managed services), reliability patterns (circuit breakers, bulkheads, rate limiting, chaos engineering), data consistency (eventual vs strong consistency, saga patterns, event sourcing), security posture (service mesh, mTLS, API gateways, secret management).

Proposed patterns: (1) Strangler fig migration with API gateway routing legacy vs new services. (2) Event-driven architecture with Kafka for service communication. (3) CQRS with separate read/write models and materialized views. (4) Service mesh with Istio for traffic management and observability. (5) Serverless for variable workloads using Lambda/Fargate. (6) GraphQL federation for unified API layer. (7) Feature flags and canary deployments for gradual rollout.

Technical considerations: Service decomposition strategy (bounded contexts, team ownership, API versioning), data management (database per service, shared database, event sourcing), inter-service communication (synchronous REST/gRPC vs asynchronous messaging), distributed transactions (saga patterns, compensating transactions, idempotency), observability (distributed tracing, metrics aggregation, log correlation), testing strategy (contract testing, chaos engineering, synthetic monitoring), deployment architecture (Kubernetes, service mesh, API gateway, monitoring stack)."""

# Scale to target length
# Each template is followed by a blank line and 50 repetitions of a short
# instruction sentence; .strip() drops the trailing space left by the last
# repetition.
DEFAULT_PROMPT = (LONG_PROMPT_TEMPLATE + "\n\n" + "Please analyze each option in detail. " * 50).strip()
CONCURRENT_PROMPT = (CONCURRENT_PROMPT_TEMPLATE + "\n\n" + "Provide detailed recommendations. " * 50).strip()

# ============================================================================
# CONCURRENT STREAM
# ============================================================================

class ConcurrentStream:
    """Concurrent prefill stream with execution timing.

    A background thread repeatedly tokenizes ``prompt`` and runs a
    prefill-only forward pass of ``model`` inside the context of a dedicated
    ``torch.cuda.Stream``. Each completed pass is recorded as a
    ``(start, end)`` pair of ``time.perf_counter()`` timestamps.
    """
    def __init__(self, model, tokenizer, prompt, device="cuda"):
        self.model = model
        self.tokenizer = tokenizer
        self.prompt = prompt
        self.device = device
        self.stream = torch.cuda.Stream()
        self.should_stop = threading.Event()
        self.thread = None

        self.execution_times = []
        self.lock = threading.Lock()  # guards execution_times across threads

    def _run_stream(self):
        """Thread body: run prefill passes back-to-back until asked to stop."""
        with torch.cuda.stream(self.stream):
            while not self.should_stop.is_set():
                try:
                    inputs = self.tokenizer([self.prompt], return_tensors="pt")
                    inputs = {k: v.to(self.device) for k, v in inputs.items()}

                    start = time.perf_counter()

                    with torch.no_grad():
                        # Prefill only - no generation
                        _ = self.model(**inputs, use_cache=True)

                    # Wait for THIS stream only. torch.cuda.synchronize()
                    # takes a device, not a stream; given a Stream it waits
                    # for the whole device, which folds the default stream's
                    # work into the interval measured here.
                    self.stream.synchronize()
                    end = time.perf_counter()

                    with self.lock:
                        self.execution_times.append((start, end))

                except Exception as e:
                    # Deliberately broad: any failure ends the background
                    # load with a warning instead of killing the experiment.
                    print(f"  [WARNING] Concurrent stream error: {e}")
                    break

    def start(self):
        """Clear earlier timings and launch the background prefill loop."""
        self.execution_times = []
        self.should_stop.clear()
        self.thread = threading.Thread(target=self._run_stream, daemon=True)
        self.thread.start()
        # Short pause before returning so the worker thread can get going.
        time.sleep(0.2)

    def stop(self):
        """Signal the loop to stop, wait for the thread, then sync the device."""
        self.should_stop.set()
        if self.thread:
            self.thread.join(timeout=10)
            if self.thread.is_alive():
                # The join timed out; the daemon thread may still be issuing
                # GPU work, which would contaminate later measurements.
                print("  [WARNING] Concurrent stream thread did not stop within 10s")
        torch.cuda.synchronize()

    def get_execution_times(self):
        """Return a copy of the recorded ``(start, end)`` timestamp pairs."""
        with self.lock:
            return list(self.execution_times)

# ============================================================================
# TIMELINE VISUALIZATION
# ============================================================================

def print_execution_timeline(default_start, default_end, concurrent_times, condition_name):
    """Print when the default pass and each concurrent pass ran.

    ``default_start``/``default_end`` and every ``(start, end)`` pair in
    ``concurrent_times`` are timestamps in seconds on the same clock. With
    concurrent passes present, times are printed in milliseconds relative to
    the earliest timestamp, each pass is annotated with its overlap with the
    default pass, and a summary follows when at least one pass overlaps.
    """
    rule = "-" * 70

    print(f"\n{condition_name} - EXECUTION TIMELINE:")
    print(rule)

    default_ms = (default_end - default_start) * 1000

    # Guard clause: nothing ran concurrently, report absolute timestamps only.
    if not concurrent_times:
        print(f"Default stream:    [{default_start:.3f} → {default_end:.3f}] {default_ms:.2f}ms")
        print("Concurrent stream: (none)")
        return

    # The earliest timestamp across both streams becomes time zero.
    stamps = [default_start, default_end]
    for run_start, run_end in concurrent_times:
        stamps += [run_start, run_end]
    origin = min(stamps)
    span = max(stamps) - origin

    print(
        f"Default stream:    [{(default_start - origin)*1000:.1f}ms → "
        f"{(default_end - origin)*1000:.1f}ms] duration={default_ms:.2f}ms"
    )
    print("Concurrent stream:")

    overlap_count = 0
    overlap_total = 0
    for number, (run_start, run_end) in enumerate(concurrent_times, start=1):
        run_ms = (run_end - run_start) * 1000

        # Intersection of this pass with the default pass, if any.
        shared_from = max(default_start, run_start)
        shared_to = min(default_end, run_end)

        if shared_from < shared_to:
            shared_ms = (shared_to - shared_from) * 1000
            note = f"overlap={shared_ms:.1f}ms ({(shared_ms / run_ms) * 100:.1f}%)"
            overlap_count += 1
            overlap_total += shared_to - shared_from
        else:
            note = "NO OVERLAP"

        print(
            f"  Execution {number}:  [{(run_start - origin)*1000:.1f}ms → "
            f"{(run_end - origin)*1000:.1f}ms] duration={run_ms:.2f}ms, {note}"
        )

    if overlap_count:
        coverage = (overlap_total / (default_end - default_start)) * 100

        print("\nSummary:")
        print(f"  Total time span: {span*1000:.2f}ms")
        print(f"  Concurrent executions overlapping with default: {overlap_count}/{len(concurrent_times)}")
        print(f"  Default stream coverage: {coverage:.1f}%")

    print(rule)

# ============================================================================
# KEY EXTRACTION FROM PREFILL
# ============================================================================

def extract_keys_prefill(model, tokenizer, prompt, layer_indices,
                        condition_name, sample_interval=50,
                        concurrent_stream=None):
    """
    Run one timed prefill pass and extract key vectors from the KV cache.

    The keys are read from ``outputs.past_key_values`` of a single forward
    pass over ``prompt`` (no generation).

    Args:
        model: model called as ``model(**inputs, use_cache=True, return_dict=True)``.
        tokenizer: tokenizer called as ``tokenizer([prompt], return_tensors="pt")``.
        prompt: text to prefill.
        layer_indices: layer numbers to capture; layer ``n`` reads cache
            entry ``n - 1`` and ``0`` is skipped.
        condition_name: label used in log lines.
        sample_interval: stride between sampled token positions, counted
            back from the last token; position 0 is always included.
        concurrent_stream: optional background load exposing ``start()``,
            ``stop()`` and ``get_execution_times()``; it runs for the
            duration of the measurement.

    Returns:
        Tuple ``(keys, timing_stats, gpu_stats, default_stream_time,
        concurrent_stream_times)`` where ``keys`` maps
        ``position -> {"layer_<n>": flat float32 numpy array}``,
        ``timing_stats`` describes the single timed pass in milliseconds,
        ``gpu_stats`` is the ``GPUMonitor.stop()`` summary (or ``None``),
        ``default_stream_time`` is the ``(start, end)`` perf_counter pair and
        ``concurrent_stream_times`` is ``None`` without a concurrent stream.
    """

    print(f"\n[{condition_name}] Starting...")

    device = "cuda"
    torch.cuda.empty_cache()
    torch.cuda.synchronize()

    all_keys = defaultdict(dict)
    monitor = GPUMonitor()
    gpu_stats = None

    try:
        if concurrent_stream:
            print(f"[{condition_name}] Starting concurrent stream...")
            concurrent_stream.start()
            # Wait before measuring so the background load is already running.
            time.sleep(1.0)

        inputs = tokenizer([prompt], return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        prompt_len = inputs['input_ids'].shape[1]

        print(f"[{condition_name}] Prompt: {prompt_len} tokens (prefill only)")

        monitor.start()

        # Synchronize on both sides of the forward pass so the wall-clock
        # interval covers the GPU work rather than just the kernel launches.
        torch.cuda.synchronize()
        start = time.perf_counter()

        with torch.no_grad():
            outputs = model(**inputs, use_cache=True, return_dict=True)

        torch.cuda.synchronize()
        end = time.perf_counter()
        elapsed = (end - start) * 1000

        default_stream_time = (start, end)

        # Extract keys from past_key_values
        past_kv = outputs.past_key_values

        # Sample positions: every sample_interval-th token counted back from
        # the last one, plus position 0.
        positions_to_sample = sorted(set(range(prompt_len - 1, -1, -sample_interval)))
        if positions_to_sample and positions_to_sample[0] != 0:
            positions_to_sample.insert(0, 0)

        print(f"[{condition_name}] Sampling {len(positions_to_sample)} positions from prompt")

        for layer_idx in layer_indices:
            if layer_idx == 0:
                continue

            keys_all_positions = past_kv[layer_idx - 1][0]  # [batch, heads, seq, head_dim]
            layer_key_name = f"layer_{layer_idx}"

            for position in positions_to_sample:
                key_vector = keys_all_positions[0, :, position, :]
                all_keys[position][layer_key_name] = key_vector.reshape(-1).float().cpu().numpy()
    finally:
        # Always shut the helper threads down, even if the forward pass
        # raised; otherwise the background load keeps running on the GPU.
        gpu_stats = monitor.stop()
        if concurrent_stream:
            concurrent_stream.stop()

    concurrent_stream_times = None
    if concurrent_stream:
        concurrent_stream_times = concurrent_stream.get_execution_times()
        print_execution_timeline(
            default_stream_time[0],
            default_stream_time[1],
            concurrent_stream_times,
            condition_name
        )

    torch.cuda.synchronize()

    # Single measurement, so mean == median and the spread is zero.
    timing_stats = {
        'mean_ms': float(elapsed),
        'median_ms': float(elapsed),
        'std_ms': 0.0,
        'all_times': [float(elapsed)]
    }

    # Test for emptiness, not truthiness of the first key: the first sampled
    # position is 0, which is falsy and used to suppress this line.
    if all_keys:
        layers_captured = list(next(iter(all_keys.values())).keys())
        print(f"[{condition_name}] ✓ Extracted: {len(layers_captured)} vectors per position")

    return dict(all_keys), timing_stats, gpu_stats, default_stream_time, concurrent_stream_times

# ============================================================================
# ANALYSIS
# ============================================================================

def compare_keys(baseline_keys, test_keys):
    """Report, per position and layer, whether two key captures are identical.

    Only positions present in both inputs, and layers present in both
    captures of a position, are compared. Each compared layer gets
    ``bit_exact`` (result of ``np.array_equal``) and ``l2`` (Euclidean norm
    of the difference, 0.0 when bit-exact). ``all_layers_exact`` and
    ``all_positions_exact`` aggregate those flags.
    """
    report = {
        'all_positions_exact': True,
        'positions': {}
    }

    common_positions = set(baseline_keys.keys()) & set(test_keys.keys())

    for pos in sorted(common_positions):
        ref_layers = baseline_keys[pos]
        probe_layers = test_keys[pos]

        per_layer = {}
        for name in ref_layers.keys() & probe_layers.keys():
            ref_vec = ref_layers[name]
            probe_vec = probe_layers[name]

            identical = np.array_equal(ref_vec, probe_vec)
            if identical:
                distance = 0.0
            else:
                distance = float(np.linalg.norm(ref_vec - probe_vec))

            per_layer[name] = {
                'bit_exact': identical,
                'l2': distance
            }

        # A position is exact only if every compared layer is exact.
        position_exact = all(entry['bit_exact'] for entry in per_layer.values())
        if not position_exact:
            report['all_positions_exact'] = False

        report['positions'][pos] = {
            'all_layers_exact': position_exact,
            'layers': per_layer
        }

    return report

def analyze_timing(baseline_time, test_time, baseline_std):
    """Compare a test latency against the baseline latency.

    Args:
        baseline_time: baseline latency in milliseconds.
        test_time: latency of the condition under test, in milliseconds.
        baseline_std: standard deviation of the baseline latency; when it is
            not positive, 10% of ``baseline_time`` is used as the threshold.

    Returns:
        Dict of JSON-serializable values: ``time_ms``, ``baseline_ms``,
        ``slowdown_ms`` and ``slowdown_pct`` as ``float``, and
        ``significant`` as ``bool`` (absolute slowdown above the threshold).

    Raises:
        ZeroDivisionError: if ``baseline_time`` is a Python number equal to 0.
    """
    slowdown_ms = test_time - baseline_time
    slowdown_pct = (slowdown_ms / baseline_time) * 100
    # Two standard deviations when a spread is known, otherwise 10% of the baseline.
    threshold = 2 * baseline_std if baseline_std > 0 else baseline_time * 0.1
    # bool(): comparing NumPy scalars yields numpy.bool_, which json.dump
    # rejects, while every other field here is already cast for JSON.
    significant = bool(abs(slowdown_ms) > threshold)

    return {
        'time_ms': float(test_time),
        'baseline_ms': float(baseline_time),
        'slowdown_ms': float(slowdown_ms),
        'slowdown_pct': float(slowdown_pct),
        'significant': significant
    }

# ============================================================================
# MAIN EXPERIMENT
# ============================================================================

def _condition_record(run_output):
    """Name the fields of the 5-tuple returned by ``extract_keys_prefill``."""
    keys, timing, gpu, default_time, concurrent_times = run_output
    return {
        'keys': keys,
        'timing': timing,
        'gpu': gpu,
        'default_time': default_time,
        'concurrent_times': concurrent_times
    }


def main():
    """Run the prefill + concurrent-prefill experiment end to end.

    Steps: print the system attestation, load the model, run a sanity pass
    on a perturbed prompt, run the baseline, then run the same prompt while
    a short and a full-length concurrent prefill stream are active. Keys and
    timings of the concurrent conditions are compared against the baseline,
    a verdict is printed, and a JSON summary is written under ``/workspace``.
    """
    attestation = attest_system()

    print("=" * 70)
    print("EXPERIMENT 1: PREFILL + CONCURRENT PREFILL")
    print("=" * 70)
    print("Configuration:")
    print("  • Prefill-only operations (high arithmetic intensity)")
    print("  • Very long prompts for high compute utilization")
    print("  • Extract keys from attention output")
    print("  • Timeline visualization for parallel execution")
    print()

    print("Loading model...")
    device = "cuda"
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16,
        device_map=device
    )
    model.eval()
    print("✓ Model loaded\n")

    default_tokens = len(tokenizer.encode(DEFAULT_PROMPT))
    concurrent_tokens = len(tokenizer.encode(CONCURRENT_PROMPT))

    print("Configuration:")
    print(f"  Default stream: {default_tokens} tokens (prefill)")
    print(f"  Concurrent stream: {concurrent_tokens} tokens (prefill)")
    print(f"  Sample interval: every {SAMPLE_INTERVAL}th position")
    print(f"  Layers: {LAYER_INDICES}")
    print()

    # Run conditions
    conditions = {}

    # 0. SANITY CHECK - Different prompt should give different keys
    print("\n" + "=" * 70)
    print("SANITY CHECK: Verifying measurement system works")
    print("=" * 70)
    perturbed_prompt = "[VERIFICATION TOKEN] " + DEFAULT_PROMPT  # Perturb at START for causal attention
    sanity_keys, _, _, _, _ = extract_keys_prefill(
        model, tokenizer, perturbed_prompt, LAYER_INDICES,
        "SANITY_CHECK", SAMPLE_INTERVAL, concurrent_stream=None
    )

    # 1. Baseline
    conditions['baseline'] = _condition_record(extract_keys_prefill(
        model, tokenizer, DEFAULT_PROMPT, LAYER_INDICES,
        "BASELINE", SAMPLE_INTERVAL, concurrent_stream=None
    ))
    base_keys = conditions['baseline']['keys']
    base_timing = conditions['baseline']['timing']

    # Verify sanity check. Tri-state result: True = a difference was found,
    # False = the checked keys were identical, None = nothing to compare.
    print("\nVerifying measurement system:")
    sanity_check_passed = False
    shared_positions = set(base_keys.keys()) & set(sanity_keys.keys())

    if not shared_positions:
        print("  ⚠️  WARNING: No shared positions between prompts!")
        print("  → Cannot verify measurement system")
        sanity_check_passed = None
    else:
        for pos in sorted(shared_positions)[:3]:  # Check first 3 shared positions
            for layer in base_keys[pos].keys():
                if not np.array_equal(base_keys[pos][layer], sanity_keys[pos][layer]):
                    l2 = np.linalg.norm(base_keys[pos][layer] - sanity_keys[pos][layer])
                    print(f"  Position {pos}, {layer}: DIFFERENT (L2={l2:.2e}) ✓")
                    sanity_check_passed = True
                    break
            if sanity_check_passed:
                break

    if sanity_check_passed is False:
        print("  ⚠️  WARNING: Perturbed prompt gave IDENTICAL keys!")
        print("  → Measurement system may be broken")
        print("  → Results below are INVALID")
    elif sanity_check_passed is True:
        print("  ✓ Measurement system verified - different prompts give different keys")
    print()

    # 2. Low concurrent (short prompt) and 3. High concurrent (full prompt).
    # Both run the same default prompt; only the background load differs.
    short_concurrent_prompt = CONCURRENT_PROMPT[:len(CONCURRENT_PROMPT)//3]  # 1/3 length
    concurrent_runs = (
        ('low_concurrent', "LOW_CONCURRENT", short_concurrent_prompt),
        ('high_concurrent', "HIGH_CONCURRENT", CONCURRENT_PROMPT),
    )
    for cond_name, label, stream_prompt in concurrent_runs:
        stream = ConcurrentStream(model, tokenizer, stream_prompt, device)
        conditions[cond_name] = _condition_record(extract_keys_prefill(
            model, tokenizer, DEFAULT_PROMPT, LAYER_INDICES,
            label, SAMPLE_INTERVAL, concurrent_stream=stream
        ))

    # Analyze vectors
    print("\n" + "=" * 70)
    print("VECTOR FORENSICS")
    print("=" * 70)

    key_results = {}

    for cond_name in ['low_concurrent', 'high_concurrent']:
        key_results[cond_name] = compare_keys(
            base_keys, conditions[cond_name]['keys']
        )

        print(f"\n{cond_name.upper()} - KEYS:")
        result = key_results[cond_name]
        total_key_diffs = sum(
            1 for pos_data in result['positions'].values()
            for layer_data in pos_data['layers'].values()
            if not layer_data['bit_exact']
        )
        print(f"  Total differences: {total_key_diffs}")
        print(f"  All positions bit-exact: {result['all_positions_exact']}")

        if not result['all_positions_exact']:
            # Show some examples (only the first three positions are inspected)
            positions_sorted = sorted(result['positions'].keys())[:3]
            for pos in positions_sorted:
                pos_data = result['positions'][pos]
                has_diffs = any(not ld['bit_exact'] for ld in pos_data['layers'].values())
                if has_diffs:
                    print(f"  Position {pos}:")
                    for layer_name in sorted(pos_data['layers'].keys()):
                        layer_data = pos_data['layers'][layer_name]
                        if not layer_data['bit_exact']:
                            print(f"    {layer_name}: ✗ DIFF (L2={layer_data['l2']:.2e})")

    # Timing analysis
    print("\n" + "=" * 70)
    print("TIMING FORENSICS")
    print("=" * 70)

    baseline_time = base_timing['median_ms']
    baseline_std = base_timing['std_ms']

    timing_results = {}
    for cond_name in ['low_concurrent', 'high_concurrent']:
        test_time = conditions[cond_name]['timing']['median_ms']
        timing_results[cond_name] = analyze_timing(baseline_time, test_time, baseline_std)

        result = timing_results[cond_name]
        print(f"\n{cond_name.upper()}:")
        print(f"  Time: {result['time_ms']:.2f}ms (baseline: {result['baseline_ms']:.2f}ms)")
        print(f"  Slowdown: {result['slowdown_ms']:+.2f}ms ({result['slowdown_pct']:+.1f}%)")
        print(f"  Significant: {'YES' if result['significant'] else 'NO'}")

    # Verdict
    print("\n" + "=" * 70)
    print("VERDICT")
    print("=" * 70)

    keys_exact = all(key_results[c]['all_positions_exact'] for c in key_results)
    timing_detects = any(timing_results[c]['significant'] for c in timing_results)

    if sanity_check_passed is False:
        print("\n⚠️  INVALID RESULTS - Measurement system verification failed")
        print("   Different prompts produced identical keys")
        print("   → Cannot trust any results from this run")
    elif sanity_check_passed is None:
        print("\n⚠️  INCONCLUSIVE - Could not verify measurement system")
        print("   No shared positions between prompts for comparison")
    elif not keys_exact:
        print("\n✓ FP INTERFERENCE DETECTED in prefill keys")
        print("   High-AI workloads show numerical sensitivity")
        print("   → Prefill forensics CAN detect hidden concurrent work")
    else:
        print("\n⚠️  NO FP INTERFERENCE in prefill keys")
        print("   Even high-AI workloads remain numerically stable")
        print("   → Prefill+prefill does not create detectable FP differences")
        print("   → Measurement system verified (sanity check passed)")

    if timing_detects:
        print("\n✓ Timing forensics works")
        print("   Compute/memory contention detected")
    else:
        print("\n✗ No timing slowdown detected")
        print("   Resources sufficient for concurrent prefill")

    # Save results
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

    if sanity_check_passed is True:
        sanity_description = 'Verified different prompts produce different keys'
    elif sanity_check_passed is None:
        sanity_description = 'Could not verify'
    else:
        sanity_description = 'Verification failed'

    output = {
        'experiment': 'exp1_prefill_prefill',
        'attestation': attestation,
        'config': {
            'model': MODEL_NAME,
            'default_tokens': default_tokens,
            'concurrent_tokens': concurrent_tokens,
            'sample_interval': SAMPLE_INTERVAL,
            'layer_indices': LAYER_INDICES
        },
        'sanity_check': {
            'passed': sanity_check_passed,
            'description': sanity_description
        },
        'forensics': {
            'keys': {c: {'all_exact': r['all_positions_exact']} for c, r in key_results.items()}
        },
        'timing_forensics': timing_results,
        'verdict': {
            'keys_exact': keys_exact,
            'timing_detects': timing_detects,
            'sanity_check_passed': sanity_check_passed
        }
    }

    output_file = f"exp1_prefill_prefill_{timestamp}.json"
    with open(f"/workspace/{output_file}", 'w') as f:
        json.dump(output, f, indent=2)

    print(f"\n✓ Results saved to /workspace/{output_file}")
    print("\n" + "=" * 70)
    print("EXPERIMENT COMPLETE")
    print("=" * 70)
# Run the experiment only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()

SYSTEM ATTESTATION
GPU: NVIDIA A100 80GB PCIe
PyTorch: 2.8.0+cu128
CUDA: 12.8

EXPERIMENT 1: PREFILL + CONCURRENT PREFILL
Configuration:
  • Prefill-only operations (high arithmetic intensity)
  • Very long prompts for high compute utilization
  • Extract keys from attention output
  • Timeline visualization for parallel execution

Loading model...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

✓ Model loaded

Configuration:
  Default stream: 875 tokens (prefill)
  Concurrent stream: 711 tokens (prefill)
  Sample interval: every 50th position
  Layers: [1, 4, 10, 18, 28]


SANITY CHECK: Verifying measurement system works

[SANITY_CHECK] Starting...
[SANITY_CHECK] Prompt: 880 tokens (prefill only)
[SANITY_CHECK] Sampling 19 positions from prompt

[BASELINE] Starting...
[BASELINE] Prompt: 875 tokens (prefill only)
[BASELINE] Sampling 19 positions from prompt

Verifying measurement system:
  Position 0, layer_1: DIFFERENT (L2=2.83e+01) ✓
  ✓ Measurement system verified - different prompts give different keys


[LOW_CONCURRENT] Starting...
[LOW_CONCURRENT] Starting concurrent stream...
[LOW_CONCURRENT] Prompt: 875 tokens (prefill only)
[LOW_CONCURRENT] Sampling 19 positions from prompt

LOW_CONCURRENT - EXECUTION TIMELINE:
----------------------------------------------------------------------
Default stream:    [1209.3ms → 1316.7ms] duration=107.37ms
Concurrent stream:
  Executio