In [3]:
#!/usr/bin/env python3
"""
Batch Composition Test
Critical Question: Does batch CONTENT affect sequence 0's key vectors?
- Fixed batch size (4)
- Fixed sequence 0 
- All sequences truncated to exactly 100 tokens
- Different sequences in positions 1,2,3 across runs
- Compare sequence 0's key vectors
"""

import os
os.environ['HF_HOME'] = '/workspace/huggingface_cache'
os.environ['TRANSFORMERS_CACHE'] = '/workspace/huggingface_cache'
os.environ['HF_DATASETS_CACHE'] = '/workspace/huggingface_cache'

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import numpy as np
from datetime import datetime
import json
import socket

HOSTNAME = socket.gethostname()
CONTAINER_ID = os.environ.get('HOSTNAME', 'unknown')

print("="*60)
print("BATCH COMPOSITION TEST")
print("="*60)
print(f"System Info:")
print(f"  Hostname: {HOSTNAME}")
print(f"  Container: {CONTAINER_ID}")
print(f"  GPU: {torch.cuda.get_device_name(0)}")
print(f"  PyTorch: {torch.__version__}")
print(f"  CUDA: {torch.version.cuda}")
print()

# Capture relevant environment variables
print("Environment Variables:")
env_vars = {}
for key in sorted(os.environ.keys()):
    if any(x in key.upper() for x in ['CUDA', 'TORCH', 'NCCL', 'CUDNN', 'PYTORCH']):
        env_vars[key] = os.environ[key]
        print(f"  {key}={os.environ[key]}")
if not env_vars:
    print("  (No CUDA/TORCH env vars set)")
print()

CACHE_DIR = '/workspace/huggingface_cache'
model_name = "Qwen/Qwen2.5-7B-Instruct"

print(f"Loading {model_name}...")
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=CACHE_DIR)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    cache_dir=CACHE_DIR,
    low_cpu_mem_usage=True,
    device_map="auto"
)

# Architecture info
num_layers = len(model.model.layers)
num_heads = model.config.num_attention_heads
head_dim = model.config.hidden_size // num_heads
key_vector_dim = num_heads * head_dim

print(f"Model: {model_name}")
print(f"Layers: {num_layers}, Heads: {num_heads}, Key dim: {key_vector_dim}\n")

# Seven unique sequences, all >100 tokens before truncation
raw_sequences = [
    # Sequence 0 - the one we always measure
    """The automated data-processing pipeline ingests raw telemetry from distributed sensors 
    across multiple geographic locations. A proprietary algorithm then normalizes the dataset, 
    filtering for anomalies based on predefined statistical parameters derived from historical 
    patterns. The resulting output is a clean, structured matrix ready for machine learning model 
    ingestion and downstream analytical workflows. System efficiency is monitored in real-time 
    through a comprehensive dashboard, with automated alerts triggered if latency exceeds the 
    established threshold or if data quality metrics fall below acceptable ranges. Advanced 
    compression techniques optimize storage utilization across the distributed infrastructure.
    Performance metrics are tracked continuously to ensure optimal throughput and minimal latency.""",
    
    # Sequences 1-6 - dummy sequences with different content
    """Climate modeling techniques have advanced significantly through the integration of 
    high-resolution satellite imagery and ground-based observation networks. Researchers combine 
    atmospheric physics equations with empirical data to simulate complex weather patterns and 
    predict long-term climate trends. These models incorporate ocean currents, ice sheet dynamics, 
    and greenhouse gas concentrations to provide increasingly accurate projections for policymakers 
    and environmental scientists worldwide. Modern computational infrastructure enables simulations 
    at unprecedented scales and temporal resolution with remarkable accuracy and scientific validity.
    International collaboration facilitates data sharing and model validation across research institutions.""",
    
    """Quantum entanglement represents one of the most counterintuitive phenomena in modern physics, 
    where particles become correlated in ways that defy classical explanations. When two particles 
    are entangled, measuring the state of one instantaneously affects the other regardless of the 
    distance separating them. This property has profound implications for quantum computing and 
    cryptography, enabling novel approaches to information processing and secure communication that 
    are fundamentally impossible with classical systems and traditional computational paradigms. Research 
    continues to explore practical applications of these quantum mechanical principles. Experimental 
    verification requires sophisticated detection equipment and precisely controlled laboratory conditions.""",
    
    """The human immune system comprises an intricate network of cells, tissues, and organs that 
    work collaboratively to defend against pathogens and foreign substances. White blood cells 
    patrol the bloodstream and tissues, identifying and neutralizing threats through both innate 
    and adaptive immune responses. B cells produce antibodies that target specific antigens, while 
    T cells orchestrate cellular immunity and eliminate infected cells. This sophisticated biological 
    defense mechanism evolved over millions of years to protect organisms. Memory cells enable rapid 
    responses to previously encountered pathogens through accelerated antibody production. The lymphatic 
    system transports immune cells throughout the body to maintain comprehensive surveillance.""",
    
    """Renaissance architecture flourished throughout Europe during the 14th to 17th centuries, 
    characterized by symmetry, proportion, and the revival of classical Greco-Roman design principles. 
    Architects like Brunelleschi and Palladio pioneered innovative structural techniques including 
    the use of mathematical ratios to create harmonious spatial relationships. Major cathedrals and 
    palaces from this era demonstrate remarkable engineering achievements, featuring elaborate domes, 
    precise geometric layouts, and ornamental details that reflected humanist ideals and aesthetic 
    philosophies of the period. These buildings continue to inspire contemporary architectural design.
    Patron families commissioned grand structures to demonstrate wealth and cultural sophistication.""",
    
    """Deep ocean hydrothermal vents create unique ecosystems that thrive in extreme conditions 
    without sunlight, relying instead on chemosynthetic bacteria that derive energy from volcanic 
    minerals. These remarkable habitats support diverse communities of specialized organisms including 
    tube worms, blind shrimp, and unusual fish species adapted to high pressure and temperature 
    gradients. Scientists study these environments to understand the origins of life on Earth and 
    explore possibilities for extraterrestrial biology in similarly extreme conditions on other 
    planetary bodies. Research expeditions use sophisticated submersibles to document biodiversity.
    Chemical analysis reveals novel metabolic pathways that challenge conventional biological theories.""",
    
    """Blockchain technology revolutionizes data integrity through distributed ledger systems that 
    maintain cryptographically secure transaction records across decentralized networks. Each block 
    contains a cryptographic hash of the previous block, creating an immutable chain resistant to 
    tampering and fraud. Consensus mechanisms like proof-of-work or proof-of-stake ensure network 
    participants agree on the validity of transactions without requiring a central authority. These 
    properties enable trustless systems for financial transactions, supply chain tracking, and digital 
    asset management in various industries. Smart contracts automate business logic on blockchain.
    Decentralized applications leverage these capabilities to create transparent and auditable systems.""",
]

# Tokenize and truncate to exactly 100 tokens
TARGET_LENGTH = 100
token_ids_list = []

print("Tokenizing and truncating sequences to exactly 100 tokens:")
for i, text in enumerate(raw_sequences):
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    if len(token_ids) < TARGET_LENGTH:
        raise ValueError(f"Sequence {i} has only {len(token_ids)} tokens, need at least {TARGET_LENGTH}")
    
    # Truncate to exactly TARGET_LENGTH
    token_ids = token_ids[:TARGET_LENGTH]
    token_ids_list.append(token_ids)
    
    # Decode back to text for display
    text_truncated = tokenizer.decode(token_ids, skip_special_tokens=True)
    print(f"  Seq {i}: {len(token_ids)} tokens - '{text_truncated[:60]}...'")

print()

def collect_key_vector(model, token_ids_batch, device="cuda"):
    """
    Forward pass with batch of pre-tokenized sequences.
    All sequences have exactly the same length (no padding needed).
    Returns key vector for sequence 0.
    """
    # Convert token IDs to tensor (batch_size, seq_len)
    input_ids = torch.tensor(token_ids_batch, dtype=torch.long).to(device)
    
    # All sequences same length, so attention mask is all ones
    attention_mask = torch.ones_like(input_ids)
    
    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            use_cache=True,
            return_dict=True
        )
    
    # Extract key vectors from last layer's KV cache
    past_kv = outputs.past_key_values
    last_layer_kv = past_kv[-1]  # Last layer
    key_cache = last_layer_kv[0]  # Keys (batch, num_heads, seq_len, head_dim)
    
    # Get sequence 0's keys at the last token position (position 99, 0-indexed)
    seq_0_keys = key_cache[0, :, -1, :]  # (num_heads, head_dim)
    key_vector = seq_0_keys.reshape(-1).cpu().clone()  # Flatten to 1D
    
    del outputs
    torch.cuda.empty_cache()
    
    return key_vector

# Define 2 different batch compositions (batch size 4, different content)
# Both include sequence 0, but different sequences in positions 1,2,3
compositions = [
    [token_ids_list[0], token_ids_list[1], token_ids_list[2], token_ids_list[3]],  # seq0 + seq1,2,3
    [token_ids_list[0], token_ids_list[4], token_ids_list[5], token_ids_list[6]],  # seq0 + seq4,5,6
]

print("="*60)
print("EXPERIMENT: Batch Composition Effect Test")
print("="*60)
print(f"Batch size: 4 (constant)")
print(f"Sequence length: {TARGET_LENGTH} tokens (all sequences)")
print(f"Sequence 0: IDENTICAL across both compositions")
print(f"Positions 1,2,3: DIFFERENT content per composition")
print(f"Repetitions per composition: 5")
print()

num_reps = 5
results = {}

for comp_idx, composition in enumerate(compositions, 1):
    print(f"Composition {comp_idx}:")
    print(f"  Position 0: Sequence 0 (always the same)")
    if comp_idx == 1:
        print(f"  Position 1: Sequence 1 (climate modeling)")
        print(f"  Position 2: Sequence 2 (quantum entanglement)")
        print(f"  Position 3: Sequence 3 (immune system)")
    else:
        print(f"  Position 1: Sequence 4 (Renaissance architecture)")
        print(f"  Position 2: Sequence 5 (ocean hydrothermal vents)")
        print(f"  Position 3: Sequence 6 (blockchain technology)")
    
    runs = []
    for rep in range(num_reps):
        key_vec = collect_key_vector(model, composition, device="cuda")
        runs.append(key_vec)
        if rep == 0:
            print(f"  Rep 0: norm={torch.norm(key_vec).item():.6f}, first_val={key_vec[0].item():.6f}")
    
    results[f"composition_{comp_idx}"] = torch.stack(runs)
    
    # Check repeatability within composition
    first_rep = runs[0]
    all_identical = all(torch.equal(first_rep, runs[i]) for i in range(1, num_reps))
    if all_identical:
        print(f"  ✓ All {num_reps} repetitions identical within composition")
    else:
        print(f"  ⚠ Repetitions vary within composition (unexpected!)")
        for i in range(1, num_reps):
            if not torch.equal(first_rep, runs[i]):
                l2 = torch.norm(first_rep - runs[i]).item()
                print(f"    Rep 0 vs {i}: L2={l2:.6f}")
    print()

# Compare across compositions
print("="*60)
print("CRITICAL ANALYSIS: Does Batch Composition Matter?")
print("="*60)

comp1_mean = results["composition_1"].mean(dim=0)
comp2_mean = results["composition_2"].mean(dim=0)

print(f"Key vector dimension: {comp1_mean.shape[0]}")
print(f"Comp 1 mean norm: {torch.norm(comp1_mean).item():.6f}")
print(f"Comp 2 mean norm: {torch.norm(comp2_mean).item():.6f}")
print()

l2_distance = torch.norm(comp1_mean - comp2_mean).item()

print(f"L2 distance (Comp 1 vs Comp 2): {l2_distance:.6f}")

if torch.norm(comp1_mean) > 0:
    rel_diff = l2_distance / torch.norm(comp1_mean).item()
    print(f"Relative difference: {rel_diff:.8f}")
print()

# Element-wise analysis
diff = (comp1_mean - comp2_mean).abs()
print("Element-wise differences:")
print(f"  Max |diff|: {diff.max().item():.8f}")
print(f"  Mean |diff|: {diff.mean().item():.8f}")
print(f"  Elements with |diff| > 1e-6: {(diff > 1e-6).sum().item()}/{diff.shape[0]}")
print()

# Check if compositions are exactly identical
exact_match = torch.equal(comp1_mean, comp2_mean)

print("="*60)
print("VERDICT")
print("="*60)
print(f"Test configuration:")
print(f"  - Sequence 0: IDENTICAL in both compositions")
print(f"  - Batch size: {len(compositions[0])} (constant)")
print(f"  - Sequence length: {TARGET_LENGTH} tokens (all equal)")
print(f"  - Batch content (pos 1,2,3): DIFFERENT between compositions")
print()

if exact_match:
    print("✓ PERFECT MATCH: Both compositions produce IDENTICAL key vectors")
    print("  → Batch composition does NOT matter")
    print("  → Only batch SIZE affects outputs")
    print("  → Verification needs only: batch size logs")
    print("  → ✓✓✓ BUBBLE-BASED VERIFICATION IS FEASIBLE")
elif l2_distance < 1e-6:
    print("✓ NEGLIGIBLE DIFFERENCE: Compositions produce nearly identical results")
    print(f"  L2 distance: {l2_distance:.8f} (effectively zero)")
    print("  → Batch composition effects are below noise floor")
    print("  → Verification viable with batch size alone")
    print("  → ✓✓ BUBBLE-BASED VERIFICATION IS FEASIBLE")
elif l2_distance < 0.01:
    print("⚠ SMALL DIFFERENCE: Detectable but small composition effect")
    print(f"  L2 distance: {l2_distance:.6f}")
    print("  → Batch composition might have minor effects")
    print("  → Need to test decode phase to confirm")
    print("  → Verification may require batch composition logs")
else:
    print("✗ SIGNIFICANT DIFFERENCE: Batch composition DOES matter")
    print(f"  L2 distance: {l2_distance:.6f}")
    print("  → Sequence 0's outputs depend on batch neighbors")
    print("  → Verification REQUIRES full batch composition logs")
    print("  → ✗✗ BUBBLE-BASED VERIFICATION NEEDS HEAVY LOGGING")

# Save results
output = {
    "experiment": "batch_composition_prefill_test",
    "timestamp": datetime.now().isoformat(),
    "model": model_name,
    "hardware": {
        "gpu": torch.cuda.get_device_name(0),
        "pytorch": torch.__version__,
        "cuda": torch.version.cuda,
        "hostname": HOSTNAME,
        "container_id": CONTAINER_ID
    },
    "environment": env_vars,
    "config": {
        "batch_size": len(compositions[0]),
        "sequence_length": TARGET_LENGTH,
        "num_compositions": len(compositions),
        "repetitions_per_composition": num_reps,
        "dtype": "bfloat16",
        "operation": "single_forward_pass_prefill",
        "extraction_method": "last_layer_keys_last_token_position"
    },
    "results": {
        "exact_match": exact_match,
        "l2_distance": l2_distance,
        "relative_difference": rel_diff if torch.norm(comp1_mean) > 0 else 0,
        "max_element_diff": diff.max().item(),
        "mean_element_diff": diff.mean().item(),
    },
    "conclusion": (
        "composition_irrelevant" if exact_match else
        "negligible_effect" if l2_distance < 1e-6 else
        "small_effect" if l2_distance < 0.01 else
        "significant_effect"
    )
}

output_file = f"batch_composition_test_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
output_path = f"/mnt/user-data/outputs/{output_file}"

os.makedirs("/mnt/user-data/outputs", exist_ok=True)
with open(output_path, "w") as f:
    json.dump(output, f, indent=2)

print(f"\n✓ Results saved to {output_path}")
print("="*60)
print("EXPERIMENT COMPLETE")
print("="*60)

BATCH COMPOSITION TEST
System Info:
  Hostname: 717db41b5256
  Container: 717db41b5256
  GPU: NVIDIA A40
  PyTorch: 2.8.0+cu128
  CUDA: 12.8

Environment Variables:
  CUDA_MODULE_LOADING=LAZY
  CUDA_VERSION=12.8.1
  NCCL_VERSION=2.25.1-1
  NVIDIA_REQUIRE_CUDA=cuda>=12.8 brand=unknown,driver>=470,driver<471 brand=grid,driver>=470,driver<471 brand=tesla,driver>=470,driver<471 brand=nvidia,driver>=470,driver<471 brand=quadro,driver>=470,driver<471 brand=quadrortx,driver>=470,driver<471 brand=nvidiartx,driver>=470,driver<471 brand=vapps,driver>=470,driver<471 brand=vpc,driver>=470,driver<471 brand=vcs,driver>=470,driver<471 brand=vws,driver>=470,driver<471 brand=cloudgaming,driver>=470,driver<471 brand=unknown,driver>=535,driver<536 brand=grid,driver>=535,driver<536 brand=tesla,driver>=535,driver<536 brand=nvidia,driver>=535,driver<536 brand=quadro,driver>=535,driver<536 brand=quadrortx,driver>=535,driver<536 brand=nvidiartx,driver>=535,driver<536 brand=vapps,driver>=535,driver<536 brand=v

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Model: Qwen/Qwen2.5-7B-Instruct
Layers: 28, Heads: 28, Key dim: 3584

Tokenizing and truncating sequences to exactly 100 tokens:
  Seq 0: 100 tokens - 'The automated data-processing pipeline ingests raw telemetry...'
  Seq 1: 100 tokens - 'Climate modeling techniques have advanced significantly thro...'
  Seq 2: 100 tokens - 'Quantum entanglement represents one of the most counterintui...'
  Seq 3: 100 tokens - 'The human immune system comprises an intricate network of ce...'
  Seq 4: 100 tokens - 'Renaissance architecture flourished throughout Europe during...'
  Seq 5: 100 tokens - 'Deep ocean hydrothermal vents create unique ecosystems that ...'
  Seq 6: 100 tokens - 'Blockchain technology revolutionizes data integrity through ...'

EXPERIMENT: Batch Composition Effect Test
Batch size: 4 (constant)
Sequence length: 100 tokens (all sequences)
Sequence 0: IDENTICAL across both compositions
Positions 1,2,3: DIFFERENT content per composition
Repetitions per composition: 5

Composition 1