In [1]:
import os
os.environ['HF_HOME'] = '/workspace/huggingface_cache'
os.environ['TRANSFORMERS_CACHE'] = '/workspace/huggingface_cache'
os.environ['HF_DATASETS_CACHE'] = '/workspace/huggingface_cache'

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import torch.nn as nn
import gc
import numpy as np
from datetime import datetime
import json
import socket

# Capture system info
HOSTNAME = socket.gethostname()
CONTAINER_ID = os.environ.get('HOSTNAME', 'unknown')

print(f"System Info:")
print(f"  Hostname: {HOSTNAME}")
print(f"  GPUs available: {torch.cuda.device_count()}")
for i in range(torch.cuda.device_count()):
    print(f"    GPU {i}: {torch.cuda.get_device_name(i)}")
print(f"  PyTorch: {torch.__version__}")
print(f"  CUDA: {torch.version.cuda}")
print()

# Configuration
CACHE_DIR = '/workspace/huggingface_cache'
model_name = "mistralai/Mistral-Small-Instruct-2409"  # 24B, 32 layers
# Note: If Mistral-Small-3.2 uses a different model ID, update this

print(f"Loading tokenizer for {model_name}...")
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=CACHE_DIR)

# Load test prompt
prompt = """The development of large language models has fundamentally transformed natural language processing 
and artificial intelligence more broadly. These models, trained on vast corpora of text data, have demonstrated 
remarkable capabilities across a wide range of tasks, from translation and summarization to question answering 
and creative writing. However, their deployment raises significant challenges related to computational efficiency, 
interpretability, and safety.

One critical challenge in deploying large language models at scale is ensuring computational efficiency. 
Modern language models can contain hundreds of billions of parameters, requiring substantial computational 
resources for both training and inference. Inference optimization techniques have become increasingly important 
as these models are deployed in production environments. Key approaches include quantization, where model 
weights and activations are represented with reduced precision; knowledge distillation, where a smaller 
student model learns to mimic a larger teacher model; and architectural innovations such as mixture-of-experts 
models that activate only relevant subnetworks for each input.

The inference stack itself introduces numerous sources of variation in model outputs. Floating-point arithmetic 
is inherently non-associative, meaning that the order of operations affects the final result. In distributed 
inference scenarios, where computation is parallelized across multiple GPUs, different parallelization strategies 
can lead to different operation orderings and thus different numerical results, even when using identical model 
weights and inputs. Factors such as batch size, communication patterns between GPUs, the specific CUDA kernels 
selected for various operations, and even the GPU architecture itself can all contribute to variations in output."""

prompt_tokens = len(tokenizer.encode(prompt))
print(f"Test prompt: {prompt_tokens} tokens\n")

# Global variable to capture keys
captured_keys = None

def create_hook(layer_idx, total_layers):
    """Create a hook to capture key vectors from the last layer"""
    def hook(module, input, output):
        global captured_keys
        # Only capture from the last layer
        if layer_idx == total_layers - 1:
            # For Mistral, the attention output is typically (batch, seq_len, num_heads, head_dim)
            # or the past_key_values contain the keys
            # We need to check the actual structure
            
            # Mistral attention returns: (attn_output, attn_weights, past_key_value)
            # past_key_value is (key, value) tuple
            if isinstance(output, tuple) and len(output) >= 3:
                past_kv = output[2]
                if past_kv is not None and isinstance(past_kv, tuple):
                    # past_kv[0] is keys: shape (batch, num_key_heads, seq_len, head_dim)
                    keys = past_kv[0]
                    # Extract last token position: (batch, num_key_heads, head_dim)
                    last_token_keys = keys[:, :, -1, :]
                    # Flatten to (batch, num_key_heads * head_dim)
                    captured_keys = last_token_keys.reshape(last_token_keys.shape[0], -1).cpu().clone()
    return hook

def collect_keys_simple(model, tokenizer, prompt, device="cuda"):
    """Collect keys without DeepSpeed - simple single GPU version"""
    global captured_keys
    captured_keys = None
    
    # Register hook on last layer
    num_layers = len(model.model.layers)
    last_layer = model.model.layers[-1]
    hook_handle = last_layer.self_attn.register_forward_hook(
        create_hook(num_layers - 1, num_layers)
    )
    
    # Forward pass
    inputs = tokenizer([prompt], return_tensors="pt", padding=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=False, use_cache=True)
    
    hook_handle.remove()
    
    # If hook didn't capture (sometimes past_key_values structure differs), 
    # try alternative extraction
    if captured_keys is None:
        print("    Warning: Hook didn't capture keys, trying alternative method...")
        # Try to get from model outputs if available
        if hasattr(outputs, 'past_key_values') and outputs.past_key_values is not None:
            # past_key_values is tuple of (key, value) for each layer
            last_layer_kv = outputs.past_key_values[-1]
            keys = last_layer_kv[0]  # (batch, num_key_heads, seq_len, head_dim)
            last_token_keys = keys[:, :, -1, :]
            captured_keys = last_token_keys.reshape(last_token_keys.shape[0], -1).cpu().clone()
    
    if captured_keys is None:
        raise RuntimeError("Failed to extract keys from model")
    
    result = captured_keys[0]  # Remove batch dimension
    
    del outputs, inputs
    torch.cuda.empty_cache()
    
    return result

def load_model_no_pp(model_name, cache_dir):
    """Load model without pipeline parallelism (single GPU or model parallel)"""
    print("\nLoading model WITHOUT pipeline parallelism...")
    gc.collect()
    torch.cuda.empty_cache()
    
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        cache_dir=cache_dir,
        low_cpu_mem_usage=True,
        device_map="auto"  # Will use single GPU if it fits
    )
    
    mem_after = torch.cuda.memory_allocated(0) / 1024**3
    print(f"  GPU 0 memory after load: {mem_after:.2f} GB")
    print(f"  Model layers: {len(model.model.layers)}")
    
    return model

def load_model_with_pp(model_name, cache_dir, num_stages):
    """Load model with pipeline parallelism using device_map"""
    print(f"\nLoading model WITH {num_stages}-way pipeline parallelism...")
    gc.collect()
    torch.cuda.empty_cache()
    
    # Get actual number of layers from config
    from transformers import AutoConfig
    config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir)
    num_layers = config.num_hidden_layers
    print(f"  Detected {num_layers} layers in model")
    
    layers_per_stage = num_layers // num_stages
    
    # Create device map for pipeline parallelism
    # We manually assign layers to GPUs
    device_map = {}
    
    # Embedding on GPU 0
    device_map["model.embed_tokens"] = 0
    device_map["model.norm"] = num_stages - 1  # Final norm on last GPU
    device_map["lm_head"] = num_stages - 1  # LM head on last GPU
    
    # Distribute layers across GPUs
    for layer_idx in range(num_layers):
        gpu_idx = layer_idx // layers_per_stage
        if gpu_idx >= num_stages:
            gpu_idx = num_stages - 1
        device_map[f"model.layers.{layer_idx}"] = gpu_idx
    
    print(f"  Pipeline configuration:")
    for gpu_idx in range(num_stages):
        start_layer = gpu_idx * layers_per_stage
        end_layer = min((gpu_idx + 1) * layers_per_stage, num_layers) - 1
        print(f"    GPU {gpu_idx}: layers {start_layer}-{end_layer}")
    
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        cache_dir=cache_dir,
        low_cpu_mem_usage=True,
        device_map=device_map
    )
    
    # Report memory usage across GPUs
    for gpu_idx in range(num_stages):
        mem = torch.cuda.memory_allocated(gpu_idx) / 1024**3
        print(f"  GPU {gpu_idx} memory: {mem:.2f} GB")
    
    return model

def collect_keys_pp(model, tokenizer, prompt):
    """Collect keys from pipeline parallel model"""
    global captured_keys
    captured_keys = None
    
    # Find which GPU has the last layer
    num_layers = len(model.model.layers)
    last_layer = model.model.layers[-1]
    last_layer_device = next(last_layer.parameters()).device
    
    print(f"    Last layer (layer {num_layers-1}) is on {last_layer_device}")
    
    # Register hook on last layer
    hook_handle = last_layer.self_attn.register_forward_hook(
        create_hook(num_layers - 1, num_layers)
    )
    
    # Forward pass - inputs start on GPU 0
    inputs = tokenizer([prompt], return_tensors="pt", padding=True)
    inputs = {k: v.to("cuda:0") for k, v in inputs.items()}
    
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=False, use_cache=True)
    
    hook_handle.remove()
    
    # Alternative extraction if hook didn't work
    if captured_keys is None:
        print("    Warning: Hook didn't capture keys, trying alternative method...")
        if hasattr(outputs, 'past_key_values') and outputs.past_key_values is not None:
            last_layer_kv = outputs.past_key_values[-1]
            keys = last_layer_kv[0]
            last_token_keys = keys[:, :, -1, :]
            captured_keys = last_token_keys.reshape(last_token_keys.shape[0], -1).cpu().clone()
    
    if captured_keys is None:
        raise RuntimeError("Failed to extract keys from pipeline parallel model")
    
    result = captured_keys[0]  # Remove batch dimension
    
    del outputs, inputs
    torch.cuda.empty_cache()
    
    return result

def unload_model(model):
    """Completely remove model from memory"""
    print(f"  Unloading model...")
    del model
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    for gpu_idx in range(torch.cuda.device_count()):
        mem = torch.cuda.memory_allocated(gpu_idx) / 1024**3
        if mem > 1:  # Only print if significant memory
            print(f"    GPU {gpu_idx} memory after unload: {mem:.2f} GB")

# ============================================================================
# MAIN EXPERIMENT
# ============================================================================

num_repetitions = 10
results = {}
all_keys = {}
config_info = {}

configurations = [
    ("no_pp", None),
    ("pp_2way", 2),
    ("pp_4way", 4)
]

print(f"{'='*70}")
print(f"PIPELINE PARALLELISM FORENSICS EXPERIMENT")
print(f"Model: {model_name}")
print(f"Precision: BF16 (bfloat16)")
print(f"Prompt tokens: {prompt_tokens}")
print(f"Extraction: Last layer keys, last token, all heads concatenated")
print(f"Repetitions per configuration: {num_repetitions}")
print(f"{'='*70}\n")

for config_name, num_stages in configurations:
    print(f"{'='*70}")
    print(f"TESTING: {config_name}")
    print(f"{'='*70}")
    
    # Load model with appropriate configuration
    if num_stages is None:
        model = load_model_no_pp(model_name, CACHE_DIR)
        collect_fn = lambda: collect_keys_simple(model, tokenizer, prompt, device="cuda")
    else:
        model = load_model_with_pp(model_name, CACHE_DIR, num_stages)
        collect_fn = lambda: collect_keys_pp(model, tokenizer, prompt)
    
    config_info[config_name] = {
        "pipeline_stages": num_stages if num_stages else 1,
        "description": f"{'No PP (single/auto device map)' if num_stages is None else f'{num_stages}-way pipeline parallelism'}"
    }
    
    # Collect keys
    print(f"\nCollecting key vectors ({num_repetitions} repetitions)...")
    runs = []
    
    for rep in range(num_repetitions):
        keys = collect_fn()
        runs.append(keys)
        
        if rep == 0:
            print(f"  Rep 0: shape={keys.shape}, norm={torch.norm(keys).item():.2f}, first_val={keys[0].item():.6f}")
        if (rep + 1) % 3 == 0:
            print(f"  Completed {rep + 1}/{num_repetitions} repetitions")
    
    # Calculate statistical noise
    first_rep = runs[0]
    all_identical = all(torch.equal(first_rep, runs[i]) for i in range(1, num_repetitions))
    
    stacked = torch.stack(runs)
    mean_keys = stacked.mean(dim=0)
    deviations = torch.stack([torch.norm(runs[i] - mean_keys) for i in range(num_repetitions)])
    mean_noise = deviations.mean().item()
    std_noise = deviations.std().item()
    
    if all_identical:
        print(f"  ✓ Statistical noise: mean=0.000000, std=0.000000 (perfect reproducibility)")
    else:
        print(f"  ⚠ Statistical noise: mean={mean_noise:.6f}, std={std_noise:.6f}")
    
    config_info[config_name]["statistical_noise"] = {
        "mean": mean_noise,
        "std": std_noise,
        "perfect_reproducibility": all_identical
    }
    
    results[config_name] = torch.stack(runs)
    all_keys[config_name] = results[config_name].float().numpy().tolist()
    
    mean_keys = results[config_name].mean(dim=0)
    print(f"  Mean key vector norm: {torch.norm(mean_keys).item():.2f}\n")
    
    # Unload model
    unload_model(model)
    print()

# ============================================================================
# ANALYSIS
# ============================================================================

print(f"{'='*70}")
print("=== SYSTEMATIC DEVIATION ANALYSIS ===")
print(f"{'='*70}\n")

comparisons = [
    ("no_pp", "pp_2way"),
    ("no_pp", "pp_4way"),
    ("pp_2way", "pp_4way")
]

deviations = {}

for config1, config2 in comparisons:
    mean1 = results[config1].mean(dim=0)
    mean2 = results[config2].mean(dim=0)
    
    l2 = torch.norm(mean1 - mean2).item()
    relative = (l2 / torch.norm(mean1)).item() if torch.norm(mean1) > 0 else 0
    
    deviations[f"{config1}_vs_{config2}"] = l2
    
    print(f"{config1} vs {config2}:")
    print(f"  L2 distance: {l2:.6f}")
    print(f"  Relative: {relative:.6f} ({relative*100:.3f}%)")
    print()

# Detailed comparison
print(f"{'='*70}")
print("=== KEY VECTOR STATISTICS ===")
print(f"{'='*70}\n")

for config_name in results.keys():
    mean_keys = results[config_name].mean(dim=0)
    print(f"{config_name}:")
    print(f"  Vector dimension: {mean_keys.shape[0]}")
    print(f"  Vector norm: {torch.norm(mean_keys).item():.2f}")
    print(f"  Mean value: {mean_keys.mean().item():.6f}")
    print(f"  Std dev: {mean_keys.std().item():.6f}")
    print()

# Interpretation
print(f"{'='*70}")
print("=== INTERPRETATION ===")
print(f"{'='*70}\n")

max_deviation = max(deviations.values())
max_statistical_noise = max(config_info[c]["statistical_noise"]["mean"] for c in config_info.keys())

print(f"Statistical noise summary:")
for config in config_info.keys():
    noise = config_info[config]["statistical_noise"]
    if noise["perfect_reproducibility"]:
        print(f"  {config}: Perfect (L2=0.000)")
    else:
        print(f"  {config}: mean={noise['mean']:.6f}, std={noise['std']:.6f}")
print()

print(f"Systematic deviations summary:")
for comp_name, dev in deviations.items():
    print(f"  {comp_name}: L2={dev:.6f}")
print()

# Signal-to-noise ratio
if max_statistical_noise > 0:
    snr = max_deviation / max_statistical_noise
    print(f"Signal-to-noise ratio: {snr:.2f}x")
    print(f"  Max systematic deviation: {max_deviation:.6f}")
    print(f"  Max statistical noise: {max_statistical_noise:.6f}")
    print()

if max_deviation < 0.001:
    print("✓ PIPELINE PARALLELISM IS NOT DETECTABLE")
    print(f"  Max L2 deviation: {max_deviation:.6f}")
    print(f"  All configurations produce essentially identical key vectors")
    print(f"  This confirms: PP only changes WHERE computation happens, not HOW")
    print(f"  Conclusion: PP configuration cannot be used for forensics")
elif max_deviation < 0.1:
    if max_statistical_noise > 0 and max_deviation < max_statistical_noise * 3:
        print("✗ SIGNAL TOO WEAK")
        print(f"  Max L2 deviation: {max_deviation:.6f}")
        print(f"  Systematic deviation is comparable to or smaller than statistical noise")
        print(f"  Not reliable for forensics")
    else:
        print("⚠ WEAK BUT POTENTIALLY DETECTABLE SIGNAL")
        print(f"  Max L2 deviation: {max_deviation:.6f}")
        print(f"  Small differences detected, may be reliable with sufficient samples")
else:
    print("✓ PIPELINE PARALLELISM IS DETECTABLE")
    print(f"  Max L2 deviation: {max_deviation:.6f}")
    print(f"  Different PP configurations produce measurably different key vectors")
    print(f"  This is surprising and worth investigating further")

# ============================================================================
# SAVE RESULTS
# ============================================================================

output = {
    "experiment": "pipeline_parallelism_forensics",
    "timestamp": datetime.now().isoformat(),
    "model": model_name,
    "hardware": {
        "num_gpus": torch.cuda.device_count(),
        "gpu_models": [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())],
        "pytorch": torch.__version__,
        "cuda": torch.version.cuda,
        "hostname": HOSTNAME,
        "container_id": CONTAINER_ID
    },
    "config": {
        "configurations_tested": [c[0] for c in configurations],
        "repetitions": num_repetitions,
        "prompt_tokens": prompt_tokens,
        "dtype": "bfloat16",
        "extraction": "last_layer_keys_last_token_all_heads"
    },
    "configuration_details": config_info,
    "results": {
        "key_vector_dims": {
            config: int(results[config].shape[1]) for config in results.keys()
        },
        "key_vector_norms": {
            config: float(torch.norm(results[config].mean(dim=0))) for config in results.keys()
        },
        "statistical_noise": {
            config: config_info[config]["statistical_noise"] for config in results.keys()
        },
        "systematic_deviations": deviations,
        "signal_to_noise_ratio": max(deviations.values()) / max(config_info[c]["statistical_noise"]["mean"] for c in config_info.keys()) if max(config_info[c]["statistical_noise"]["mean"] for c in config_info.keys()) > 0 else float('inf'),
        "within_config_reproducibility": {
            config: all(torch.equal(results[config][0], results[config][i]) 
                       for i in range(num_repetitions))
            for config in results.keys()
        }
    },
    "raw_keys": all_keys
}

gpu_name = torch.cuda.get_device_name(0).replace(' ', '_')
output_file = f"{gpu_name}_pp_forensics_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
output_path = f"/workspace/{output_file}"

with open(output_path, "w") as f:
    json.dump(output, f, indent=2)

print(f"\n✓ Results saved to {output_path}")
print(f"✓ File size: ~{len(json.dumps(output)) / 1024:.1f} KB")
print(f"\n{'='*70}")
print("EXPERIMENT COMPLETE")
print(f"{'='*70}")



System Info:
  Hostname: ec4b94e56229
  GPUs available: 4
    GPU 0: NVIDIA A100 80GB PCIe
    GPU 1: NVIDIA A100 80GB PCIe
    GPU 2: NVIDIA A100 80GB PCIe
    GPU 3: NVIDIA A100 80GB PCIe
  PyTorch: 2.8.0+cu128
  CUDA: 12.8

Loading tokenizer for mistralai/Mistral-Small-Instruct-2409...


`torch_dtype` is deprecated! Use `dtype` instead!


Test prompt: 356 tokens

PIPELINE PARALLELISM FORENSICS EXPERIMENT
Model: mistralai/Mistral-Small-Instruct-2409
Precision: BF16 (bfloat16)
Prompt tokens: 356
Extraction: Last layer keys, last token, all heads concatenated
Repetitions per configuration: 10

TESTING: no_pp

Loading model WITHOUT pipeline parallelism...


Loading checkpoint shards:   0%|          | 0/9 [00:00<?, ?it/s]

  GPU 0 memory after load: 9.82 GB
  Model layers: 56

Collecting key vectors (10 repetitions)...
  Rep 0: shape=torch.Size([1024]), norm=53.25, first_val=0.625000
  Completed 3/10 repetitions
  Completed 6/10 repetitions
  Completed 9/10 repetitions
  ✓ Statistical noise: mean=0.000000, std=0.000000 (perfect reproducibility)
  Mean key vector norm: 53.25

  Unloading model...
    GPU 0 memory after unload: 9.83 GB
    GPU 1 memory after unload: 10.91 GB
    GPU 2 memory after unload: 10.91 GB
    GPU 3 memory after unload: 9.83 GB

TESTING: pp_2way

Loading model WITH 2-way pipeline parallelism...
  Detected 56 layers in model
  Pipeline configuration:
    GPU 0: layers 0-27
    GPU 1: layers 28-55


Loading checkpoint shards:   0%|          | 0/9 [00:00<?, ?it/s]

  GPU 0 memory: 30.55 GB
  GPU 1 memory: 31.63 GB

Collecting key vectors (10 repetitions)...
    Last layer (layer 55) is on cuda:1
  Rep 0: shape=torch.Size([1024]), norm=53.25, first_val=0.625000
    Last layer (layer 55) is on cuda:1
    Last layer (layer 55) is on cuda:1
  Completed 3/10 repetitions
    Last layer (layer 55) is on cuda:1
    Last layer (layer 55) is on cuda:1
    Last layer (layer 55) is on cuda:1
  Completed 6/10 repetitions
    Last layer (layer 55) is on cuda:1
    Last layer (layer 55) is on cuda:1
    Last layer (layer 55) is on cuda:1
  Completed 9/10 repetitions
    Last layer (layer 55) is on cuda:1
  ✓ Statistical noise: mean=0.000000, std=0.000000 (perfect reproducibility)
  Mean key vector norm: 53.25

  Unloading model...
    GPU 0 memory after unload: 20.73 GB
    GPU 1 memory after unload: 20.73 GB

TESTING: pp_4way

Loading model WITH 4-way pipeline parallelism...
  Detected 56 layers in model
  Pipeline configuration:
    GPU 0: layers 0-13
    GPU

Loading checkpoint shards:   0%|          | 0/9 [00:00<?, ?it/s]

  GPU 0 memory: 31.27 GB
  GPU 1 memory: 30.90 GB
  GPU 2 memory: 10.18 GB
  GPU 3 memory: 10.56 GB

Collecting key vectors (10 repetitions)...
    Last layer (layer 55) is on cuda:3
  Rep 0: shape=torch.Size([1024]), norm=53.25, first_val=0.625000
    Last layer (layer 55) is on cuda:3
    Last layer (layer 55) is on cuda:3
  Completed 3/10 repetitions
    Last layer (layer 55) is on cuda:3
    Last layer (layer 55) is on cuda:3
    Last layer (layer 55) is on cuda:3
  Completed 6/10 repetitions
    Last layer (layer 55) is on cuda:3
    Last layer (layer 55) is on cuda:3
    Last layer (layer 55) is on cuda:3
  Completed 9/10 repetitions
    Last layer (layer 55) is on cuda:3
  ✓ Statistical noise: mean=0.000000, std=0.000000 (perfect reproducibility)
  Mean key vector norm: 53.25

  Unloading model...
    GPU 0 memory after unload: 10.56 GB
    GPU 1 memory after unload: 10.18 GB
    GPU 2 memory after unload: 10.18 GB
    GPU 3 memory after unload: 10.56 GB

=== SYSTEMATIC DEVIATIO