In [1]:
#!/usr/bin/env python3
"""
Transformers INT8-BitsAndBytes Determinism Test
Tests bit-exact reproducibility using HuggingFace Transformers with BitsAndBytes quantization
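Reports the INT8 result next to a BF16 baseline to isolate quantization-specific reproducibility issues
(the BF16 baseline is not run here; this script only prints a fixed statement about it)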

Dependencies:
- transformers
- torch
- numpy
- bitsandbytes (INT8 quantization backend)
"""

import os

# Route both Hugging Face cache variables to the same scratch directory.
# This runs ahead of the transformers import further down the script.
os.environ.update({
    'HF_HOME': '/tmp/hf_cache',
    'TRANSFORMERS_CACHE': '/tmp/hf_cache',
})

# Confirm the quantization backend is importable before any heavy work starts
print("Checking dependencies...")
missing = []

try:
    import bitsandbytes
except ImportError:
    missing.append("bitsandbytes")
    print("  bitsandbytes: MISSING")
else:
    print("  bitsandbytes: OK")

if missing:
    import sys

    # Assemble the whole install hint, then emit it in one go and stop the run
    report = [
        "",
        "ERROR: Missing required packages!",
        "",
        "Please install manually in a separate cell:",
        "",
    ]
    report.extend(f"  !pip install {pkg}" for pkg in missing)
    report += ["", "Then re-run this script."]
    print("\n".join(report))
    sys.exit(1)

print()

import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from datetime import datetime
import json
import time

# ============================================================================
# CONFIGURATION
# ============================================================================

MODEL_NAME = "Qwen/Qwen3-32B"
MAX_NEW_TOKENS = 50
NUM_REPETITIONS = 5
# Reported as metadata only; generation requests greedy decoding via do_sample=False
TEMPERATURE = 0.0
SEED = 42
# Not referenced by the rest of the script as visible here
TOP_LOGPROBS = 10

# Test prompt: an instruction plus the passage it applies to
USER_TASK = "Please provide a detailed summary of the following text."
DOCUMENT_CONTENT = (
    "The development of large language models has fundamentally "
    "transformed natural language processing and artificial "
    "intelligence more broadly. "
    "These models, trained on vast corpora of text data, have "
    "demonstrated remarkable capabilities across a wide range of "
    "tasks, from translation and summarization to question answering "
    "and creative writing. "
    "The scaling laws observed in these systems suggest that "
    "performance continues to improve with model size, data scale, "
    "and compute budget, though with diminishing returns."
)

# ============================================================================
# SETUP
# ============================================================================

_RULE = "=" * 80

print(_RULE)
print("TRANSFORMERS INT8-BITSANDBYTES DETERMINISM TEST")
print(_RULE)
print()

# Echo the run configuration so the log is self-describing.
# The trailing empty entry yields the blank line that closes the section.
_config_summary = [
    "Configuration:",
    f"  Model: {MODEL_NAME}",
    "  Library: HuggingFace Transformers + bitsandbytes",
    "  Quantization: BitsAndBytes INT8",
    "  Precision: INT8 with FP16 compute",
    f"  Max new tokens: {MAX_NEW_TOKENS}",
    f"  Temperature: {TEMPERATURE} (greedy decoding)",
    f"  Seed: {SEED}",
    f"  Repetitions: {NUM_REPETITIONS}",
    "",
]
print("\n".join(_config_summary))

# ============================================================================
# LOAD MODEL
# ============================================================================

print("Loading model with bitsandbytes INT8 quantization...")
load_start = time.time()

# Seed the torch CPU, torch CUDA and NumPy generators with the same value
for _seed_rng in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed):
    _seed_rng(SEED)

# Arguments shared by the tokenizer and model loads
_hub_kwargs = {"cache_dir": '/tmp/hf_cache', "trust_remote_code": True}

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, **_hub_kwargs)

# Request 8-bit weight loading through bitsandbytes
quantization_config = BitsAndBytesConfig(load_in_8bit=True)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    quantization_config=quantization_config,
    **_hub_kwargs,
)

# Switch the module to evaluation mode before generating
model.eval()

load_time = time.time() - load_start
print(f"Model loaded in {load_time:.2f}s")
print(f"Device: {model.device}")
print()

# ============================================================================
# PREPARE INPUT
# ============================================================================

# Single-turn conversation: the instruction, a blank line, then the document
_user_turn = {"role": "user", "content": f"{USER_TASK}\n\n{DOCUMENT_CONTENT}"}
messages = [_user_turn]

# Render the conversation to a prompt string (not token ids) that ends with
# the generation prompt
prompt_text = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
print(f"Prompt length: {len(prompt_text)} characters")

# Encode the prompt as PyTorch tensors and move them to the model's device
inputs = tokenizer(prompt_text, return_tensors="pt").to(model.device)

# Second dimension of input_ids is taken as the prompt's token count
prompt_length = inputs["input_ids"].shape[1]
print(f"Prompt tokens: {prompt_length}")
print()

# ============================================================================
# WARMUP
# ============================================================================

print("Running warmup...")

# Same generation settings as the measured repetitions
_warmup_kwargs = dict(
    max_new_tokens=MAX_NEW_TOKENS,
    do_sample=False,  # Greedy decoding
    pad_token_id=tokenizer.eos_token_id,
    output_scores=True,
    return_dict_in_generate=True,
)

with torch.no_grad():
    warmup_output = model.generate(**inputs, **_warmup_kwargs)

# Everything after the prompt in the first (only) sequence is newly generated
_warmup_new_tokens = warmup_output.sequences[0].shape[0] - prompt_length
print(f"Warmup complete - generated {_warmup_new_tokens} tokens")
print()

# ============================================================================
# MAIN EXPERIMENT
# ============================================================================

print("="*80)
print("RUNNING EXPERIMENT")
print("="*80)
print()


def _sync_cuda():
    """Block until queued CUDA work has finished on every visible device.

    Called immediately before and after the timed region so the wall-clock
    measurement covers the full generation. No-op when CUDA is unavailable.
    """
    if torch.cuda.is_available():
        for device_index in range(torch.cuda.device_count()):
            torch.cuda.synchronize(device_index)


def _selected_token_logprobs(step_scores, token_ids):
    """Return the log-probability assigned to each generated token.

    Args:
        step_scores: sequence of per-step score tensors (``outputs.scores``);
            row 0 of each tensor is used, i.e. a batch of one is assumed.
        token_ids: generated token ids, aligned position-by-position with
            ``step_scores``.

    Returns:
        1-D NumPy array with one log-probability per generated token
        (empty when nothing was generated).
    """
    selected = [
        # Upcast so the log-softmax is evaluated in float32 even if the
        # scores arrive in half precision (no-op when already float32).
        torch.nn.functional.log_softmax(scores[0].float(), dim=-1)[token_id].item()
        for token_id, scores in zip(token_ids, step_scores)
    ]
    return np.array(selected)


results_tokens = []
results_logprobs = []
results_texts = []
timing_data = []

for rep in range(NUM_REPETITIONS):
    print(f"Repetition {rep + 1}/{NUM_REPETITIONS}...")

    # Reset seeds for each repetition so every run starts from the same RNG state
    torch.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    np.random.seed(SEED)

    # Time generation with a monotonic clock, fenced by CUDA synchronization
    _sync_cuda()
    start_time = time.perf_counter()

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,  # Greedy decoding
            pad_token_id=tokenizer.eos_token_id,
            output_scores=True,
            return_dict_in_generate=True
        )

    _sync_cuda()
    elapsed = time.perf_counter() - start_time

    # Extract tokens (remove prompt)
    generated_ids = outputs.sequences[0][prompt_length:].cpu().tolist()
    num_tokens = len(generated_ids)
    results_tokens.append(generated_ids)

    # Decode text
    results_texts.append(tokenizer.decode(generated_ids, skip_special_tokens=True))

    # Log-probability of each selected token, taken from the per-step scores
    results_logprobs.append(_selected_token_logprobs(outputs.scores, generated_ids))

    # Timing. Guard the divisions: generation may stop immediately (zero new
    # tokens), which previously raised ZeroDivisionError. NaN marks an
    # undefined per-token time (json.dump writes it as NaN by default).
    tokens_per_sec = num_tokens / elapsed if elapsed > 0 else 0.0
    time_per_token = elapsed / num_tokens if num_tokens > 0 else float("nan")
    timing_data.append({
        'repetition': rep + 1,
        'elapsed_time': elapsed,
        'num_tokens': num_tokens,
        'tokens_per_sec': tokens_per_sec,
        'time_per_token': time_per_token
    })

    print(f"  {num_tokens} tokens in {elapsed:.3f}s ({tokens_per_sec:.1f} tok/s)")

print()
print("All repetitions complete!")
print()

# ============================================================================
# TIMING ANALYSIS
# ============================================================================

print("="*80)
print("TIMING ANALYSIS")
print("="*80)
print()

# Split the per-repetition records into one list per metric
times, tps, tpt = [], [], []
for _record in timing_data:
    times.append(_record['elapsed_time'])
    tps.append(_record['tokens_per_sec'])
    tpt.append(_record['time_per_token'])

print("Timing statistics:")

_mean_s, _std_s = np.mean(times), np.std(times)
print(f"  Mean time: {_mean_s:.3f}s (σ={_std_s:.4f}s)")

_fastest, _slowest = np.min(times), np.max(times)
print(f"  Min/Max: {_fastest:.3f}s / {_slowest:.3f}s")

_mean_tps, _std_tps = np.mean(tps), np.std(tps)
print(f"  Tokens/sec: {_mean_tps:.1f} (σ={_std_tps:.2f})")

# Per-token latency is reported in milliseconds
_mean_ms, _std_ms = np.mean(tpt) * 1000, np.std(tpt) * 1000
print(f"  Time/token: {_mean_ms:.2f}ms (σ={_std_ms:.3f}ms)")
print()

# ============================================================================
# REPRODUCIBILITY ANALYSIS
# ============================================================================

def _run_chain(count):
    """Return a label such as 'Run 1 = Run 2 = Run 3 = ... = Run N' for *count* runs."""
    labels = [f"Run {k}" for k in range(1, count + 1)]
    if len(labels) > 4:
        labels = labels[:3] + ["..."] + labels[-1:]
    return " = ".join(labels)


print("="*80)
print("REPRODUCIBILITY ANALYSIS")
print("="*80)
print()
print("Comparing across repetitions (same config, same seed):")
# Derive the label from the configured repetition count instead of hard-coding it
print(f"  Checking if {_run_chain(NUM_REPETITIONS)}")
print("  Testing: Can we reproduce inference bit-exactly?")
print()

# Check token sequences: every repetition is compared against the first one
print("Checking token sequences...")
reference_tokens = results_tokens[0]
tokens_identical = all(tokens == reference_tokens for tokens in results_tokens[1:])
print(f"Token sequences identical: {tokens_identical}")

if not tokens_identical:
    print("\n[WARNING] Token sequences differ!")
    for i, tokens in enumerate(results_tokens[1:], start=1):
        if tokens == reference_tokens:
            continue
        # zip stops at the shorter sequence, so only the common prefix is compared
        diff_positions = [
            j for j, (ref_tok, tok) in enumerate(zip(reference_tokens, tokens))
            if ref_tok != tok
        ]
        print(f"  Rep 0 vs Rep {i}: {len(diff_positions)} positions differ")
        if diff_positions:
            print(f"    First difference at position {diff_positions[0]}")
        # A pure length mismatch would otherwise be reported as "0 positions differ"
        if len(tokens) != len(reference_tokens):
            print(f"    Lengths differ: {len(reference_tokens)} vs {len(tokens)} tokens")

# Check logprobs
print("\nChecking selected token logprobs...")
first_logprobs = results_logprobs[0]
# np.array_equal demands identical shape and element-wise equality with no
# tolerance, which is what "bit-exact" claims. It also returns False, rather
# than raising, when two runs produced a different number of tokens.
logprobs_exact = all(
    np.array_equal(first_logprobs, other) for other in results_logprobs[1:]
)
print(f"Selected token logprobs bit-exact: {logprobs_exact}")

# Always defined so later sections can read it regardless of the outcome
l2_distances = []

if not logprobs_exact:
    # Runs may differ in length once tokens diverge; compare the shared prefix
    common_len = min(len(lp) for lp in results_logprobs)
    if any(len(lp) != common_len for lp in results_logprobs):
        print(f"\nNote: runs differ in length; statistics use the first {common_len} tokens")

    print("\nL2 distances:")
    for i, other in enumerate(results_logprobs[1:], start=1):
        l2 = np.linalg.norm(first_logprobs[:common_len] - other[:common_len])
        l2_distances.append(l2)
        print(f"  Rep 0 vs Rep {i}: L2 = {l2:.6e}")

    print(f"\nMax L2: {max(l2_distances):.6e}")
    print(f"Mean L2: {np.mean(l2_distances):.6e}")

    # Element-wise statistics across repetitions (rows = runs, columns = token positions)
    all_logprobs = np.array([lp[:common_len] for lp in results_logprobs])
    std_per_token = all_logprobs.std(axis=0)
    if std_per_token.size:
        print(f"\nPer-token std statistics:")
        print(f"  Mean: {std_per_token.mean():.6e}")
        print(f"  Max: {std_per_token.max():.6e}")
        print(f"  Median: {np.median(std_per_token):.6e}")

print()

# ============================================================================
# VERDICT
# ============================================================================

print("="*80)
print("VERDICT - TRANSFORMERS INT8-BITSANDBYTES")
print("="*80)
print()

_all_exact = tokens_identical and logprobs_exact

# Pick the verdict text for one of three outcomes: everything matched,
# only logprobs drifted, or the token sequences themselves diverged.
if _all_exact:
    _verdict = [
        "[PASS] PERFECT REPRODUCIBILITY WITH TRANSFORMERS INT8-BITSANDBYTES",
        "  - Token sequences: bit-exact",
        "  - Selected token logprobs: bit-exact",
        "  => Transformers + bitsandbytes maintains determinism",
        "  => INT8 quantization does not break forensic verification",
    ]
elif tokens_identical:
    # Reaching this branch means the logprob check failed, which is exactly
    # when the analysis section filled l2_distances.
    max_l2 = max(l2_distances)
    _verdict = [
        "[WARNING] TOKENS IDENTICAL, LOGPROBS VARY",
        "  - Token sequences: bit-exact",
        "  - Logprobs: numerical variation",
        f"  - L2 distance: {max_l2:.6e}",
        "  => Greedy decoding unaffected by quantization noise",
        "  => BitsAndBytes INT8 has non-deterministic dequantization",
    ]
else:
    _verdict = [
        "[FAIL] TOKEN SEQUENCES DIFFER",
        "  - Greedy decoding produces different tokens",
        "  => Severe non-determinism in Transformers BitsAndBytes",
    ]
print("\n".join(_verdict))

# The BF16 line is a fixed statement; no BF16 run happens in this script.
print()
print("Comparison with BF16 baseline:")
print("  BF16: Perfect bit-exact reproducibility")
print("  INT8-BitsAndBytes: [see above]")
print()

if not _all_exact:
    print("Hypothesis:")
    if tokens_identical:
        # Tokens matched, so only the logprobs varied
        print("  INT8 dequantization introduces numerical variation")
        print("  Variation is small enough to not affect argmax selection")
        print("  Hidden state forensics may still detect quantization")
    else:
        print("  Quantization introduces significant non-determinism")

print()

# ============================================================================
# SAVE RESULTS
# ============================================================================

# Take one timestamp so the value stored in the JSON and the one embedded in
# the file name always agree (two separate now() calls can straddle a second).
saved_at = datetime.now()

output_data = {
    "experiment": "transformers_int8_bitsandbytes_determinism",
    "timestamp": saved_at.isoformat(),
    "library": "transformers + bitsandbytes",
    # Library versions, recorded so a later run can be matched to this one
    "environment": {
        "torch": str(torch.__version__),
        "cuda": torch.version.cuda,
        "numpy": str(np.__version__),
        "bitsandbytes": getattr(bitsandbytes, "__version__", None),
    },
    "prompt_text": prompt_text,
    "prompt_length_tokens": prompt_length,
    "config": {
        "model": MODEL_NAME,
        "quantization": "bitsandbytes",
        "precision": "INT8 with FP16 compute",
        "max_new_tokens": MAX_NEW_TOKENS,
        "repetitions": NUM_REPETITIONS,
        "temperature": TEMPERATURE,
        "seed": SEED,
        "do_sample": False
    },
    "timing": {
        "model_load_time": load_time,
        "per_repetition": timing_data,
        "statistics": {
            # float() turns NumPy scalars into plain floats for json
            "mean_time": float(np.mean(times)),
            "std_time": float(np.std(times)),
            "mean_tokens_per_sec": float(np.mean(tps)),
            "std_tokens_per_sec": float(np.std(tps)),
            "mean_time_per_token_ms": float(np.mean(tpt) * 1000),
            "std_time_per_token_ms": float(np.std(tpt) * 1000)
        }
    },
    "results": {
        "tokens_identical": tokens_identical,
        "logprobs_exact": logprobs_exact,
        "perfect_reproducibility": tokens_identical and logprobs_exact
    },
    "token_sequences": results_tokens,
    # NumPy arrays are not JSON-serializable; store them as plain lists
    "logprobs_vectors": [lp.tolist() for lp in results_logprobs],
    "generated_texts": results_texts
}

output_file = f"transformers_int8_bnb_determinism_{saved_at.strftime('%Y%m%d_%H%M%S')}.json"
with open(output_file, "w", encoding="utf-8") as f:
    json.dump(output_data, f, indent=2)

print(f"Results saved to: {output_file}")
print()
print("="*80)
print("TEST COMPLETE")
print("="*80)

Checking dependencies...
  bitsandbytes: OK





TRANSFORMERS INT8-BITSANDBYTES DETERMINISM TEST

Configuration:
  Model: Qwen/Qwen3-32B
  Library: HuggingFace Transformers + bitsandbytes
  Quantization: BitsAndBytes INT8
  Precision: INT8 with FP16 compute
  Max new tokens: 50
  Temperature: 0.0 (greedy decoding)
  Seed: 42
  Repetitions: 5

Loading model with bitsandbytes INT8 quantization...


Loading checkpoint shards:   0%|          | 0/17 [00:00<?, ?it/s]

The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Model loaded in 48.97s
Device: cuda:0

Prompt length: 622 characters
Prompt tokens: 100

Running warmup...
Warmup complete - generated 50 tokens

RUNNING EXPERIMENT

Repetition 1/5...
  50 tokens in 7.247s (6.9 tok/s)
Repetition 2/5...
  50 tokens in 7.423s (6.7 tok/s)
Repetition 3/5...
  50 tokens in 7.313s (6.8 tok/s)
Repetition 4/5...
  50 tokens in 7.237s (6.9 tok/s)
Repetition 5/5...
  50 tokens in 7.291s (6.9 tok/s)

All repetitions complete!

TIMING ANALYSIS

Timing statistics:
  Mean time: 7.302s (σ=0.0667s)
  Min/Max: 7.237s / 7.423s
  Tokens/sec: 6.8 (σ=0.06)
  Time/token: 146.05ms (σ=1.334ms)

REPRODUCIBILITY ANALYSIS

Comparing across repetitions (same config, same seed):
  Checking if Run 1 = Run 2 = Run 3 = ... = Run 10
  Testing: Can we reproduce inference bit-exactly?

Checking token sequences...
Token sequences identical: True

Checking selected token logprobs...
Selected token logprobs bit-exact: True

VERDICT - TRANSFORMERS INT8-BITSANDBYTES

[PASS] PERFECT REPRODUCI