In [None]:
#!/usr/bin/env python3
"""
Batch Size Matrix Experiment

Runs all batch size configurations (bs=1, bs=2, bs=4) for multiple reference sequences.
Each reference sequence gets unique dummy sequences for batching.

Design:
- 3 reference sequences
- Each tested at bs=1, bs=2, bs=4
- 9 total measurements per GPU
- Enables full cross-comparison matrix

CRITICAL: Truncates all sequences in batch to shortest length to avoid padding artifacts

Usage:
    python batch_matrix_experiment.py
"""

import os
os.environ['HF_HOME'] = '/workspace/huggingface_cache'
os.environ['TRANSFORMERS_CACHE'] = '/workspace/huggingface_cache'

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import numpy as np
from datetime import datetime
import json
import socket

# ============================================================================
# REFERENCE SEQUENCES
# ============================================================================

# Texts whose keys are measured. main() places each one at batch element 0
# and reads its keys at the last token that survives batch-wide truncation.
# The dict keys are also used to look up the matching DUMMY_SETS entry and as
# the prefix of each measurement name ("<ref_name>_bs<N>").
# Changing any of these strings changes the tokens fed to the model, so JSON
# results are only comparable between runs that used byte-identical text.
# NOTE(review): despite its name, "ref_code" is English prose describing a
# migration system, not source code.
REFERENCE_SEQUENCES = {
    "ref_technical": """Large language models have revolutionized natural language processing through their ability to capture complex patterns in text data. The transformer architecture, introduced in 2017, employs self-attention mechanisms that allow the model to weigh the importance of different tokens in the input sequence. During training, these models learn to predict the next token in a sequence by optimizing a cross-entropy loss function across billions of text examples.""",
    
    "ref_narrative": """The morning sun filtered through the ancient oak trees as Sarah walked along the forest path, her boots crunching softly on the fallen leaves. She had been coming to these woods since childhood, when her grandmother first taught her to identify the different bird calls echoing through the canopy. Now, decades later, she found herself returning to this same trail whenever life felt overwhelming.""",
    
    "ref_code": """The database migration system implements a sophisticated version control mechanism for schema changes. Each migration file contains both an upgrade and downgrade function, allowing the system to roll forward or backward through schema versions. The migration engine maintains a table tracking which migrations have been applied, using timestamps and hash values to ensure consistency across different environments."""
}

# ============================================================================
# DUMMY SEQUENCES (3 sets of 3 dummies, one set per reference)
# ============================================================================

# Filler texts batched behind each reference. Keys must match
# REFERENCE_SEQUENCES exactly (main() looks them up by reference name).
# bs=2 appends the first dummy of the list; bs=4 appends all three.
# Only element 0 (the reference) is measured. However, every sequence in a
# batch is truncated to the shortest token length, so a dummy that tokenizes
# shorter than its reference also shortens the measured reference prefix.
# As with the references, any text change invalidates comparison with
# earlier result files.
DUMMY_SETS = {
    "ref_technical": [
        """Quantum computing leverages the principles of quantum mechanics to perform computations that would be intractable for classical computers. At the heart of quantum computation lies the qubit, a quantum bit that can exist in a superposition of both 0 and 1 states simultaneously. When multiple qubits are entangled, they form a quantum register capable of representing an exponentially large state space.""",
        
        """The neural architecture search algorithm systematically explores different model configurations to identify optimal designs for specific tasks. Modern approaches use reinforcement learning or evolutionary algorithms to navigate the vast search space of possible architectures. The process evaluates candidate models on validation data, gradually converging toward efficient and effective network topologies.""",
        
        """Distributed consensus protocols enable multiple nodes in a network to agree on a single value despite potential failures or malicious actors. The Byzantine Generals Problem formalizes the challenge of achieving consensus when some participants may behave arbitrarily. Practical solutions like Paxos and Raft provide mechanisms for fault-tolerant agreement in real-world systems."""
    ],
    
    "ref_narrative": [
        """The old lighthouse stood sentinel on the rocky promontory, its weathered walls bearing testament to countless storms. Local legends spoke of the keeper who vanished one winter night, leaving only his log book with a final cryptic entry. Now automated, the beacon still swept across the dark waters, a guardian whose original purpose had long been superseded by modern navigation systems.""",
        
        """Marcus found the letter tucked between the pages of his grandfather's journal, the paper yellowed and fragile with age. The handwriting was unfamiliar, yet the words spoke of events his family had never discussed. As he read, pieces of his heritage began to fall into place, revealing a story that had been deliberately hidden for three generations.""",
        
        """The jazz club occupied a basement space that seemed to exist outside of time, where smoke still hung in the air despite the ban and the music felt like it emerged from another era. Every Thursday, the same musicians gathered to play standards that few in the younger generation recognized. Yet something about the atmosphere drew people in, seeking connection to an authenticity they sensed was disappearing from the world."""
    ],
    
    "ref_code": [
        """The distributed caching layer implements consistent hashing to minimize cache invalidation when nodes are added or removed from the cluster. Virtual nodes provide better load distribution across physical servers, while replication ensures availability even during node failures. The system monitors hit rates and eviction patterns to automatically adjust cache allocation strategies.""",
        
        """The API gateway performs request routing, authentication, rate limiting, and response transformation for microservices. Each service registers its endpoints with the gateway, which maintains a dynamic routing table. The gateway implements circuit breakers to prevent cascade failures and provides detailed metrics for monitoring service health.""",
        
        """The message queue system guarantees exactly-once delivery through a combination of acknowledgments, persistent storage, and idempotency tokens. Publishers receive confirmation only after messages are durably written to replicated storage. Consumers process messages within transactions, ensuring atomic updates across message consumption and business logic execution."""
    ]
}

# ============================================================================
# CONFIGURATION
# ============================================================================

# Hugging Face download cache. This is the same path the HF_HOME and
# TRANSFORMERS_CACHE environment variables are set to at import time.
CACHE_DIR = "/workspace/huggingface_cache"

# Model under test.
MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"

# 1-indexed decoder layers whose keys are recorded
# (layer N is read from past_key_values[N - 1]).
LAYER_INDICES = [1, 4, 10, 18, 28]

# Batch sizes tested for every reference; the reference is always batch element 0.
BATCH_SIZES = [1, 2, 4]

# ============================================================================
# EXTRACTION FUNCTION
# ============================================================================

def extract_last_token_keys(model, tokenizer, texts, device="cuda", layer_indices=None):
    """
    Extract key vectors from the LAST TOKEN position for element 0 in a batch.

    CRITICAL: Truncates all sequences to the shortest length to avoid padding
    artifacts. Truncation is applied to token IDs, and those IDs are fed to the
    model directly. No decode/re-encode round trip can alter the tokens, and
    every row of the batch has exactly the same length, so no padding is ever
    introduced.

    Args:
        model: causal LM callable as
            ``model(input_ids=..., attention_mask=..., use_cache=True)``.
            ``outputs.past_key_values[i][0]`` is expected to be the key tensor
            of decoder layer ``i`` (0-indexed), indexed as
            ``[batch, kv_head, position, head_dim]`` -- TODO confirm for new
            transformers cache classes.
        tokenizer: object providing ``encode(text, add_special_tokens=True)``
            that returns a list of token IDs.
        texts: list of strings (the batch). Element 0 is the one measured.
        device: device on which the input tensors are created.
        layer_indices: 1-indexed layer numbers to extract. Defaults to
            LAYER_INDICES.

    Returns:
        key_vectors: dict {'layer_<n>': CPU tensor of shape (heads * head_dim,)}
            holding the keys of the last (truncated) token of element 0, in
            the dtype produced by the model.
        extraction_info: dict with extraction_position, original_lengths,
            truncated_length.

    Raises:
        ValueError: if ``texts`` is empty or any text encodes to zero tokens.
    """
    if layer_indices is None:
        layer_indices = LAYER_INDICES
    if not texts:
        raise ValueError("texts must contain at least one sequence")

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Step 1: Tokenize every text exactly once and record its full length.
    token_lists = [tokenizer.encode(text, add_special_tokens=True) for text in texts]
    original_lengths = [len(tokens) for tokens in token_lists]

    # Step 2: The shortest sequence sets the common length for the whole batch.
    min_length = min(original_lengths)
    if min_length < 1:
        # A zero-length row would make the extraction position -1, which
        # silently reads the wrong position instead of failing.
        raise ValueError(
            f"Cannot extract keys: a text encodes to 0 tokens "
            f"(original_lengths={original_lengths})"
        )

    # Step 3: Truncate the token IDs (not the text) and build the batch directly.
    # Every row has min_length tokens, so the mask is all ones and no padding exists.
    input_ids = torch.tensor(
        [list(tokens[:min_length]) for tokens in token_lists],
        dtype=torch.long,
        device=device,
    )
    attention_mask = torch.ones_like(input_ids)

    # Step 4: The extraction position is the last token (0-indexed).
    extraction_pos = min_length - 1

    # Step 5: Run the model; use_cache=True is what populates past_key_values.
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True)

    # Step 6: Take element 0's keys at the last position and flatten the heads.
    key_vectors = {}
    for layer_idx in layer_indices:
        layer_keys = outputs.past_key_values[layer_idx - 1][0]  # assumed (batch, kv_heads, seq_len, head_dim)
        last_token_keys = layer_keys[0, :, extraction_pos, :]  # (kv_heads, head_dim)
        # clone() so the stored vector does not keep the full cache tensor alive.
        key_vectors[f'layer_{layer_idx}'] = last_token_keys.reshape(-1).cpu().clone()

    del outputs, input_ids, attention_mask
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    extraction_info = {
        'extraction_position': extraction_pos,
        'original_lengths': original_lengths,
        'truncated_length': min_length
    }

    return key_vectors, extraction_info

# ============================================================================
# MAIN EXPERIMENT
# ============================================================================

def _detect_hardware_label(gpu_name):
    """Map a CUDA device name to the short label used in the output filename."""
    if "H100" in gpu_name:
        return "h100"
    if "A100" in gpu_name:
        return "a100"
    if "L40S" in gpu_name or "L40s" in gpu_name:
        return "l40s"
    return "gpu"


def _build_batch(ref_text, dummies, batch_size):
    """Return the reference (element 0) followed by the first ``batch_size - 1`` dummies.

    Raises:
        ValueError: if ``batch_size`` < 1 or there are not enough dummies.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    num_dummies = batch_size - 1
    if num_dummies > len(dummies):
        raise ValueError(
            f"batch_size={batch_size} needs {num_dummies} dummy sequences, "
            f"but only {len(dummies)} are available"
        )
    return [ref_text] + list(dummies[:num_dummies])


def _run_repeats(model, tokenizer, batch_texts, device, num_repeats=3):
    """Extract keys ``num_repeats`` times and check bit-exact reproducibility.

    Returns:
        (runs, infos, reproducible): the per-run key dicts, the per-run
        extraction info, and True iff every run's tensors are ``torch.equal``
        to the first run's.
    """
    runs = []
    infos = []
    for _ in range(num_repeats):
        key_vectors, extraction_info = extract_last_token_keys(
            model, tokenizer, batch_texts, device=device
        )
        runs.append(key_vectors)
        infos.append(extraction_info)

    first = runs[0]
    reproducible = all(
        torch.equal(first[layer_name], other[layer_name])
        for other in runs[1:]
        for layer_name in first
    )
    return runs, infos, reproducible


def _warn_if_extraction_varied(infos):
    """Print a warning if the extraction position or truncated length differed across runs."""
    first = infos[0]
    for info in infos[1:]:
        if info['extraction_position'] != first['extraction_position']:
            print("\n  ⚠ WARNING: Extraction position varied across runs!")
        if info['truncated_length'] != first['truncated_length']:
            print("\n  ⚠ WARNING: Truncated length varied across runs!")


def _build_measurement(ref_name, batch_size, key_vectors, extraction_info, reproducible):
    """Assemble the JSON-serialisable record for one (reference, batch size) cell."""
    measurement = {
        'reference_sequence': ref_name,
        'batch_size': batch_size,
        'extraction_position': extraction_info['extraction_position'],
        'original_lengths': extraction_info['original_lengths'],
        'truncated_length': extraction_info['truncated_length'],
        'reproducible': reproducible,
        'layers': {}
    }
    for layer_name, layer_keys in key_vectors.items():
        # Single token: shape (key_dim,). Upcast once, then compute the norm on
        # the same float32 values that are stored. Computing it on the model's
        # bfloat16 tensor would round the norm to bf16 precision.
        as_float = layer_keys.float()
        measurement['layers'][layer_name] = {
            'key_vector': as_float.numpy().tolist(),
            'norm': float(torch.norm(as_float).item())
        }
    return measurement


def main():
    """Run every (reference, batch size) combination and save the keys as JSON.

    For each entry of REFERENCE_SEQUENCES and each size in BATCH_SIZES, the
    reference is placed at batch element 0, followed by ``batch_size - 1``
    sequences from its DUMMY_SETS entry. Keys are extracted three times to
    check bit-exact reproducibility, and the first run is stored. Results are
    written to ``/workspace/experiments/<hardware>_batch_matrix_<timestamp>.json``.
    """
    # Auto-detect hardware
    cuda_available = torch.cuda.is_available()
    gpu_name = torch.cuda.get_device_name(0) if cuda_available else "cpu"
    # Without CUDA the label is "cpu" rather than the generic "gpu".
    hardware_label = _detect_hardware_label(gpu_name) if cuda_available else "cpu"
    total_measurements = len(REFERENCE_SEQUENCES) * len(BATCH_SIZES)

    print("="*70)
    print("BATCH SIZE MATRIX EXPERIMENT")
    print("="*70)
    print(f"\nHardware: {hardware_label.upper()} ({gpu_name})")
    print(f"Model: {MODEL_NAME}")
    print(f"Layers: {LAYER_INDICES}")
    print(f"Reference sequences: {len(REFERENCE_SEQUENCES)}")
    print(f"Batch sizes: {BATCH_SIZES}")
    print(f"Total measurements: {total_measurements}")
    print()

    # System info
    hostname = socket.gethostname()

    print("System Info:")
    print(f"  Hostname: {hostname}")
    print(f"  GPU: {gpu_name}")
    print(f"  PyTorch: {torch.__version__}")
    print(f"  CUDA: {torch.version.cuda}")
    print()

    # Load model
    print("Loading model...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        dtype=torch.bfloat16,
        cache_dir=CACHE_DIR,
        low_cpu_mem_usage=True,
        device_map="auto"
    )
    print("✓ Model loaded\n")
    # Create inputs on the device the model reports, instead of a hard-coded
    # "cuda", so a CPU-only run does not fail on tensor placement.
    input_device = model.device

    # Prepare results structure
    results = {
        'metadata': {
            'hardware': hardware_label,
            'hostname': hostname,
            'gpu': gpu_name,
            'pytorch_version': torch.__version__,
            'cuda_version': torch.version.cuda,
            'model': MODEL_NAME,
            'layer_indices': LAYER_INDICES,
            'batch_sizes_tested': BATCH_SIZES,
            'num_reference_sequences': len(REFERENCE_SEQUENCES),
            'timestamp': datetime.now().isoformat(),
            'experiment_type': 'batch_matrix'
        },
        'measurements': {}
    }

    # Run all combinations
    measurement_count = 0
    for ref_name, ref_text in REFERENCE_SEQUENCES.items():
        ref_dummies = DUMMY_SETS[ref_name]

        print(f"Reference: {ref_name}")

        for batch_size in BATCH_SIZES:
            measurement_count += 1
            measurement_name = f"{ref_name}_bs{batch_size}"

            print(f"  [{measurement_count}/{total_measurements}] bs={batch_size}...", end=" ")

            # Works for any batch size up to len(ref_dummies) + 1 and fails
            # loudly otherwise, instead of reusing a stale or undefined batch.
            batch_texts = _build_batch(ref_text, ref_dummies, batch_size)

            runs, all_info, reproducible = _run_repeats(
                model, tokenizer, batch_texts, device=input_device
            )
            _warn_if_extraction_varied(all_info)

            if not reproducible:
                print("⚠ Non-reproducible!", end=" ")

            # Use first run
            extraction_info = all_info[0]
            results['measurements'][measurement_name] = _build_measurement(
                ref_name, batch_size, runs[0], extraction_info, reproducible
            )

            # Print length info
            if batch_size > 1:
                length_str = (
                    f"[{extraction_info['truncated_length']}"
                    f"/{extraction_info['original_lengths'][0]}]"
                )
                print(length_str, end=" ")

            print("✓" if reproducible else "")

        print()

    # Save results
    output_dir = '/workspace/experiments'
    os.makedirs(output_dir, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"{hardware_label}_batch_matrix_{timestamp}.json"
    filepath = os.path.join(output_dir, filename)

    with open(filepath, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    file_size_kb = os.path.getsize(filepath) / 1024

    print("="*70)
    print(f"✓ Results saved to: {filepath}")
    print(f"  File size: {file_size_kb:.1f} KB")
    print(f"  Measurements: {len(results['measurements'])}")
    print("="*70)
    print()
    print("Next steps:")
    print("1. Run on other hardware")
    print(f"2. Compare: python compare_batch_matrix.py {filename} <other_file>")

if __name__ == "__main__":
    # ``get_ipython`` is expected to be defined only when running under an
    # IPython/Jupyter kernel; a plain interpreter raises NameError here.
    try:
        get_ipython()
    except NameError:
        running_in_notebook = False
    else:
        running_in_notebook = True

    if running_in_notebook:
        print("Running in notebook mode...")
        print("Will collect all batch size configurations automatically.")
        print()

    # Both modes run the identical full experiment.
    main()



Running in notebook mode...
Will collect all batch size configurations automatically.

BATCH SIZE MATRIX EXPERIMENT

Hardware: A100 (NVIDIA A100 80GB PCIe)
Model: Qwen/Qwen2.5-7B-Instruct
Layers: [1, 4, 10, 18, 28]
Reference sequences: 3
Batch sizes: [1, 2, 4]
Total measurements: 9

System Info:
  Hostname: 34d100e052e4
  GPU: NVIDIA A100 80GB PCIe
  PyTorch: 2.8.0+cu129
  CUDA: 12.9

Loading model...


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/663 [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/3.95G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/3.56G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/3.86G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/3.86G [00:00<?, ?B/s]