In [None]:
#!/usr/bin/env python3
"""
H100 INT8 Quantization Determinism Test - Qwen3-32B with YaRN
Tests bit-exact reproducibility with INT8 quantization via bitsandbytes
Includes detailed timing measurements for performance comparison
Uses YaRN RoPE scaling to extend context to 131K tokens
"""

# ============================================================================
# SUPPRESS VERBOSE LOGGING
# ============================================================================
import os
# Point the Hugging Face caches at /tmp/hf_cache. These are set before
# transformers / huggingface_hub are imported further down in this file.
# NOTE(review): TRANSFORMERS_CACHE appears to be superseded by HF_HOME in newer
# transformers releases; kept for older versions — confirm it is still needed.
os.environ['HF_HOME'] = '/tmp/hf_cache'
os.environ['TRANSFORMERS_CACHE'] = '/tmp/hf_cache'
# Turn off the tokenizers library's internal parallelism.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

import warnings
# Silences ALL Python warnings for the rest of the session, including any
# transformers/torch warnings about generation or determinism settings.
warnings.filterwarnings('ignore')

import logging
# transformers / torch loggers: errors only.
logging.getLogger('transformers').setLevel(logging.ERROR)
logging.getLogger('torch').setLevel(logging.ERROR)
# huggingface_hub loggers are set to INFO, i.e. MORE verbose than the rest.
# NOTE(review): presumably to keep checkpoint-download activity visible despite
# the "suppress" intent of this section — confirm this is deliberate.
logging.getLogger('huggingface_hub').setLevel(logging.INFO)
logging.getLogger('huggingface_hub.file_download').setLevel(logging.INFO)

# ============================================================================
# IMPORTS
# ============================================================================

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, AutoConfig
import numpy as np
from datetime import datetime
import json
import torch
import time
import glob

# ============================================================================
# CONFIGURATION
# ============================================================================

# Model configuration - INT8 quantized Qwen3-32B via bitsandbytes with YaRN
MODEL_NAME = "Qwen/Qwen3-32B"
QUANTIZATION = "bitsandbytes_int8"  # Label recorded in the log and JSON output
TENSOR_PARALLEL_SIZE = 1  # Single GPU for clean timing (reported only; not passed to the loader)

# YaRN configuration for long context
YARN_ENABLED = True
YARN_FACTOR = 4.0
YARN_ORIGINAL_MAX = 32768  # Native context window that YaRN stretches by YARN_FACTOR
# Extended context implied by the YaRN settings (131,072 tokens for factor 4.0).
# Derived rather than hard-coded so that tuning YARN_FACTOR or YARN_ORIGINAL_MAX
# can never leave the prompt-length check out of sync with the RoPE scaling
# that is actually applied to the model config.
MAX_CONTEXT_LENGTH = int(YARN_ORIGINAL_MAX * YARN_FACTOR)

GPU_MEMORY_UTILIZATION = 0.9  # Kept for consistency; not directly used here

# Generation configuration
MAX_TOKENS = 20  # New tokens generated per run
NUM_REPETITIONS = 5  # Identical runs compared against each other
TEMPERATURE = 0.0  # Greedy decoding (informational; generation itself uses do_sample=False)
SEED = 42
TOP_LOGPROBS = 10  # Size of the per-step top-k distribution recorded for comparison

# Timing configuration
NUM_WARMUP_RUNS = 1  # Warmup for stable timing

# Prompt source
AUTO_FIND_FILE = True  # Use a .txt/.pdf from the working directory when one exists

# User task
USER_TASK = "Please provide a detailed summary of the following text."

# Hardcoded content (fallback if no txt/pdf present)
HARDCODED_CONTENT = """The development of large language models has fundamentally transformed natural language processing and artificial intelligence more broadly. These models, trained on vast corpora of text data, have demonstrated remarkable capabilities across a wide range of tasks, from translation and summarization to question answering and creative writing. The scaling laws observed in these systems suggest that performance continues to improve with model size, data scale, and compute budget, though with diminishing returns. Recent advances in architecture, training techniques, and inference optimization have made these powerful models increasingly accessible for practical applications."""

# ============================================================================
# FILE LOADING UTILITIES
# ============================================================================

def find_prompt_file(directory=None):
    """Return the prompt document to use, or ``None`` if there is none.

    Searches *directory* (default: the current working directory) for
    ``*.txt`` files first and ``*.pdf`` files second, and returns the
    alphabetically first match of the preferred kind.

    Args:
        directory: Folder to search; ``None`` means ``os.getcwd()``.

    Returns:
        ``os.path.join(directory, name)`` for the chosen file, or ``None``
        when no ``.txt``/``.pdf`` file is present.
    """
    base = os.getcwd() if directory is None else directory
    # Escape the directory part so characters such as "[" or "*" in the path
    # are matched literally instead of being interpreted as glob syntax.
    escaped_base = glob.escape(base)
    for pattern in ("*.txt", "*.pdf"):
        # glob() returns entries in arbitrary, filesystem-dependent order;
        # sorting makes the chosen document the same on every machine, which a
        # reproducibility test relies on when several candidates are present.
        matches = sorted(glob.glob(os.path.join(escaped_base, pattern)))
        if matches:
            return matches[0]
    return None

def load_text_from_file(filepath):
    """Load the full text of a ``.txt`` or ``.pdf`` document.

    The extension check is case-insensitive (``NOTES.TXT`` and ``Paper.PDF``
    are accepted). Text files are decoded as UTF-8. PDF text is extracted page
    by page with PyPDF2, imported lazily so the dependency is only needed when
    a PDF is actually loaded, and pages are joined with newlines.

    Args:
        filepath: Path to the document.

    Returns:
        The document text as a single string.

    Raises:
        FileNotFoundError: if *filepath* does not exist.
        ImportError: if a PDF is requested but PyPDF2 is not installed.
        ValueError: if the extension is neither ``.txt`` nor ``.pdf``.
    """
    if not os.path.exists(filepath):
        raise FileNotFoundError(f"File not found: {filepath}")

    lowered = filepath.lower()
    if lowered.endswith(".txt"):
        with open(filepath, "r", encoding="utf-8") as f:
            text = f.read()
        print(f"Loaded {len(text)} characters from txt file")
        return text

    if lowered.endswith(".pdf"):
        try:
            import PyPDF2
        except ImportError as err:
            raise ImportError("PyPDF2 required for PDF loading. Install with: pip install PyPDF2") from err
        pages = []
        with open(filepath, "rb") as f:
            pdf_reader = PyPDF2.PdfReader(f)
            num_pages = len(pdf_reader.pages)
            print(f"Loading {num_pages} pages from PDF...")
            for page_num, page in enumerate(pdf_reader.pages, 1):
                # Treat a missing/None page text as "" so str.join below
                # cannot fail with a TypeError on a text-less page.
                pages.append(page.extract_text() or "")
                if page_num % 10 == 0:
                    print(f"  Processed {page_num}/{num_pages} pages")
        full_text = "\n".join(pages)
        print(f"Loaded {len(full_text)} characters from PDF ({num_pages} pages)")
        return full_text

    raise ValueError(f"Unsupported file type: {filepath}. Use .txt or .pdf")

# ============================================================================
# PROMPT LOADING
# ============================================================================

print("=" * 80, "H100 INT8 QUANTIZATION DETERMINISM TEST - Qwen3-32B with YaRN", "=" * 80, "", sep="\n")

# Choose the document to summarise: a txt/pdf discovered in the working
# directory when auto-discovery is on; otherwise (or when nothing is found)
# the hardcoded fallback text.
prompt_file = find_prompt_file() if AUTO_FIND_FILE else None
if prompt_file:
    print(f"Found file: {os.path.basename(prompt_file)}")
    DOCUMENT_CONTENT = load_text_from_file(prompt_file)
else:
    if AUTO_FIND_FILE:
        print("No txt/pdf files found in current directory")
    print("Using hardcoded content")
    DOCUMENT_CONTENT = HARDCODED_CONTENT
print()

# Single user turn: the task instruction, a blank line, then the document.
messages = [dict(role="user", content=f"{USER_TASK}\n\n{DOCUMENT_CONTENT}")]

print(f"Message content length: {len(messages[0]['content'])} characters", "", sep="\n")

# ============================================================================
# TOKENIZER PRECHECK AND CONFIG WITH YARN
# ============================================================================

# Load tokenizer and model config, patch RoPE scaling when YaRN is on, then
# render the chat prompt once and check that prompt + generation budget fit in
# the usable context window before the model itself is loaded.
print("Loading tokenizer and config with YaRN configuration...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True, cache_dir="/tmp/hf_cache")
config = AutoConfig.from_pretrained(MODEL_NAME, trust_remote_code=True, cache_dir="/tmp/hf_cache")

if not YARN_ENABLED:
    # Native window from the checkpoint config; 32768 when it does not say.
    max_context_len = (
        int(config.max_position_embeddings)
        if getattr(config, "max_position_embeddings", None) is not None
        else 32768
    )
    print(f"YaRN disabled: native context {max_context_len:,} tokens")
else:
    # Replace the config's rope_scaling with a YaRN entry; this config object
    # is later passed to the model loader.
    config.rope_scaling = dict(
        rope_type="yarn",
        factor=YARN_FACTOR,
        original_max_position_embeddings=YARN_ORIGINAL_MAX,
    )
    max_context_len = MAX_CONTEXT_LENGTH
    print(f"YaRN enabled: factor={YARN_FACTOR}, extended to {max_context_len:,} tokens")

# Render the conversation as text in the model's chat format with the
# assistant turn opened, then count its tokens.
prompt_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
prompt_tokens = tokenizer.encode(prompt_text)
prompt_length = len(prompt_tokens)

print(
    "",
    "Prompt statistics:",
    f"  Characters: {len(prompt_text):,}",
    f"  Tokens: {prompt_length:,}",
    f"  Max context (with YaRN): {max_context_len:,}",
    f"  Planned generation tokens: {MAX_TOKENS}",
    f"  Total required: {prompt_length + MAX_TOKENS:,}",
    "",
    sep="\n",
)

# Abort before loading the model if prompt + generation budget exceed the window.
if prompt_length + MAX_TOKENS > max_context_len:
    print(f"[ERROR] Requested context length ({prompt_length + MAX_TOKENS}) exceeds model max ({max_context_len}).")
    print("        Truncate the input document or increase YaRN factor.")
    raise SystemExit(1)

# ============================================================================
# MODEL LOADING
# ============================================================================

# RoPE/context wording follows YARN_ENABLED so the log never claims YaRN
# scaling on a run where the config was left at its native setting.
if YARN_ENABLED:
    rope_summary = f"YaRN (factor={YARN_FACTOR}, base={YARN_ORIGINAL_MAX})"
    context_notes = (
        "IMPORTANT: Using bitsandbytes INT8 quantization with YaRN extended context",
        "           NOTE: Static YaRN may degrade performance on SHORT contexts (<32K)",
    )
    load_message = "Loading INT8 quantized model with YaRN-extended context..."
else:
    rope_summary = "default (YaRN disabled)"
    context_notes = ("IMPORTANT: Using bitsandbytes INT8 quantization with native context (YaRN disabled)",)
    load_message = "Loading INT8 quantized model with native context..."

print("Configuration:")
print(f"  Model: {MODEL_NAME}")
print(f"  Quantization: {QUANTIZATION}")
print("  Precision: INT8 weights via bitsandbytes, higher-precision activations")
# NOTE(review): bitsandbytes LLM.int8 is usually described as int8 matmul with
# higher-precision handling of outlier features; confirm this wording.
print("  Compute: INT8-storage weights with dequantization for matmul")
print("  KV cache: model default (higher precision)")
print(f"  RoPE scaling: {rope_summary}")
print(f"  Tensor parallel: {TENSOR_PARALLEL_SIZE}")
print(f"  Max model len: {max_context_len:,}")
print(f"  Max new tokens: {MAX_TOKENS}")
print(f"  Temperature: {TEMPERATURE}")
print(f"  Seed: {SEED}")
print(f"  Repetitions: {NUM_REPETITIONS}")
print(f"  Warmup runs: {NUM_WARMUP_RUNS}")
print()
for note in context_notes:
    print(note)
print()

print(load_message)
# perf_counter() is monotonic and high-resolution; time.time() is wall-clock
# and can jump (e.g. clock adjustments) during a long checkpoint load.
load_start = time.perf_counter()

# 8-bit weight quantization via bitsandbytes; all other options at defaults.
bnb_config = BitsAndBytesConfig(
    load_in_8bit=True,
)

torch.manual_seed(SEED)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    cache_dir="/tmp/hf_cache",
    config=config,  # Pass the (possibly YaRN-patched) config prepared above
    quantization_config=bnb_config,
    device_map={"": 0},  # force all modules onto cuda:0
    trust_remote_code=True
)

model.eval()

load_time = time.perf_counter() - load_start
print(f"Model loaded in {load_time:.2f}s")
print()

# ============================================================================
# SAMPLING CONFIGURATION
# ============================================================================

# Greedy decoding (do_sample=False), so temperature is explicitly None.
# Per-step scores are requested and returned in a dict-like output so that
# logprobs and top-k distributions can be compared across repetitions.
generation_kwargs = dict(
    max_new_tokens=MAX_TOKENS,
    do_sample=False,
    temperature=None,
    return_dict_in_generate=True,
    output_scores=True,
)

# ============================================================================
# WARMUP
# ============================================================================

# Warmup passes use exactly the inputs and generation settings of the measured
# runs below, presumably to absorb one-time startup costs before timing.
print(f"Running {NUM_WARMUP_RUNS} warmup iterations...")
warmup_times = []
for i in range(NUM_WARMUP_RUNS):
    inputs = tokenizer(
        prompt_text,
        return_tensors="pt"
    )
    input_ids = inputs["input_ids"].to(model.device)
    attention_mask = inputs["attention_mask"].to(model.device)

    # CUDA kernels run asynchronously: synchronize on both sides of the timed
    # region so it covers exactly the generate() call, and use the monotonic,
    # high-resolution perf_counter() instead of wall-clock time.time().
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    warmup_start = time.perf_counter()
    with torch.no_grad():
        gen_out = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **generation_kwargs
        )
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    warmup_time = time.perf_counter() - warmup_start

    # sequences holds prompt + generated ids; keep only the generated part.
    generated_tokens = gen_out.sequences[0][input_ids.shape[1]:]
    num_tokens = generated_tokens.shape[0]
    warmup_times.append(warmup_time)
    if num_tokens > 0 and warmup_time > 0:
        print(f"  Warmup {i+1}: {warmup_time:.3f}s ({num_tokens / warmup_time:.1f} tok/s)")
    else:
        print(f"  Warmup {i+1}: {warmup_time:.3f}s (0 tok/s)")

if warmup_times:
    print(f"Warmup complete - avg time: {np.mean(warmup_times):.3f}s")
else:
    print("Warmup complete - no valid timing data")
print()

# ============================================================================
# MAIN EXPERIMENT
# ============================================================================

print("=" * 80)
print("RUNNING EXPERIMENT")
print("=" * 80)
print()

# Per-repetition results, index-aligned with the 0-based repetition number.
results_tokens = []         # generated token ids (list[int])
results_logprobs = []       # np.ndarray: chosen token's logprob at each step
results_texts = []          # decoded generation, special tokens kept
results_distributions = []  # per step: list of (token_id, logprob) for the top-k
timing_data = []            # one timing dict per repetition

for rep in range(NUM_REPETITIONS):
    print(f"Repetition {rep + 1}/{NUM_REPETITIONS}...")

    # Tokenized afresh every repetition, exactly as in warmup.
    inputs = tokenizer(
        prompt_text,
        return_tensors="pt"
    )
    input_ids = inputs["input_ids"].to(model.device)
    attention_mask = inputs["attention_mask"].to(model.device)

    # Time only generate(): drain pending GPU work first and synchronize again
    # afterwards so asynchronous CUDA kernels are fully counted. perf_counter()
    # is monotonic and high-resolution, unlike wall-clock time.time().
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start_time = time.perf_counter()
    with torch.no_grad():
        gen_out = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **generation_kwargs
        )
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    end_time = time.perf_counter()

    elapsed = end_time - start_time

    sequences = gen_out.sequences
    scores = gen_out.scores  # list of (batch, vocab) tensors, one per generated token

    # sequences holds prompt + generated ids; keep only the generated part.
    gen_token_ids = sequences[0][input_ids.shape[1]:]
    token_ids = gen_token_ids.tolist()
    num_tokens = len(token_ids)
    tokens_per_sec = (num_tokens / elapsed) if elapsed > 0 else 0.0

    results_tokens.append(token_ids)

    timing_data.append({
        "repetition": rep + 1,
        "elapsed_time": elapsed,
        "num_tokens": num_tokens,
        "tokens_per_sec": tokens_per_sec,
        "time_per_token": (elapsed / num_tokens) if num_tokens > 0 else 0.0,
    })

    text = tokenizer.decode(gen_token_ids, skip_special_tokens=False)
    results_texts.append(text)

    # Selected-token logprob and top-k distribution at every generated step.
    # NOTE(review): output_scores yields scores after any logits processors the
    # generation config applies; for plain greedy decoding these are presumably
    # the raw logits — confirm if processors (e.g. repetition penalty) are added.
    selected_logprobs = []
    per_token_topk = []
    for score_t, tok_id in zip(scores, token_ids):
        logprobs = torch.log_softmax(score_t[0], dim=-1)
        selected_logprobs.append(float(logprobs[tok_id].item()))

        topk = torch.topk(logprobs, k=min(TOP_LOGPROBS, logprobs.shape[-1]))
        per_token_topk.append(list(zip(topk.indices.cpu().tolist(), topk.values.cpu().tolist())))

    results_logprobs.append(np.array(selected_logprobs))
    results_distributions.append(per_token_topk)

    print(f"  {num_tokens} tokens in {elapsed:.3f}s ({tokens_per_sec:.1f} tok/s)")

print()
print("All repetitions complete!")
print()

# ============================================================================
# TIMING ANALYSIS
# ============================================================================

print("=" * 80, "TIMING ANALYSIS", "=" * 80, "", sep="\n")

# Flatten the per-repetition timing records; runs that produced no tokens
# (time_per_token == 0) are left out of the per-token statistic.
times = [entry["elapsed_time"] for entry in timing_data]
tps = [entry["tokens_per_sec"] for entry in timing_data]
tpt = [entry["time_per_token"] for entry in timing_data if entry["time_per_token"] > 0]

print("Timing statistics:")
if not times:
    print("  No timing data collected")
else:
    print(
        f"  Mean time: {np.mean(times):.3f}s (σ={np.std(times):.4f}s)",
        f"  Min/Max: {np.min(times):.3f}s / {np.max(times):.3f}s",
        sep="\n",
    )

print(f"  Tokens/sec: {np.mean(tps):.1f} (σ={np.std(tps):.2f})" if tps else "  Tokens/sec: N/A")
print(
    f"  Time/token: {np.mean(tpt) * 1000:.2f}ms (σ={np.std(tpt) * 1000:.3f}ms)"
    if tpt
    else "  Time/token: N/A"
)
print()

# ============================================================================
# REPRODUCIBILITY ANALYSIS
# ============================================================================

print("=" * 80)
print("REPRODUCIBILITY ANALYSIS")
print("=" * 80)
print()
print("Comparing across repetitions (same config, same seed):")
print("  Checking if Run 1 = Run 2 = ... = Run N")
print("  Testing: Can we reproduce inference bit-exactly with YaRN?")
print()

# Every check compares repetition i against repetition 0 (printed as
# "Rep 0 vs Rep i"; the run log numbers the same repetitions from 1).
# All comparisons are exact, matching the "bit-exact" wording of the report.

print("Checking token sequences...")
tokens_identical = all(
    results_tokens[0] == results_tokens[i]
    for i in range(1, NUM_REPETITIONS)
)
print(f"Token sequences identical: {tokens_identical}")

if not tokens_identical:
    print("\n[WARNING] Token sequences differ!")
    for i in range(1, NUM_REPETITIONS):
        if results_tokens[0] == results_tokens[i]:
            continue
        # Position-wise differences over the shared prefix...
        diff_positions = [
            j
            for j, (tok_a, tok_b) in enumerate(zip(results_tokens[0], results_tokens[i]))
            if tok_a != tok_b
        ]
        print(f"  Rep 0 vs Rep {i}: {len(diff_positions)} differing positions")
        # ...plus an explicit note when the runs generated different numbers of
        # tokens; otherwise a pure length difference would only show up as
        # "0 differing positions".
        if len(results_tokens[0]) != len(results_tokens[i]):
            print(
                f"    Length mismatch: Rep 0 has {len(results_tokens[0])} tokens, "
                f"Rep {i} has {len(results_tokens[i])}"
            )
        if diff_positions:
            pos = diff_positions[0]
            print(f"    First difference at position {pos}")
            print(f"      Rep 0: token {results_tokens[0][pos]}")
            print(f"      Rep {i}: token {results_tokens[i][pos]}")

print("\nChecking selected token logprobs...")
if results_logprobs:
    first_logprobs = results_logprobs[0]
    # np.array_equal is an exact comparison that simply returns False on a
    # shape mismatch. np.allclose would raise a broadcast error when two runs
    # generated different numbers of tokens (or silently broadcast a
    # length-1 vector), and its tolerance could hide tiny differences.
    logprobs_exact = all(
        np.array_equal(first_logprobs, results_logprobs[i])
        for i in range(1, NUM_REPETITIONS)
    )
else:
    first_logprobs = None
    logprobs_exact = False

print(f"Selected token logprobs bit-exact: {logprobs_exact}")

if not logprobs_exact and first_logprobs is not None:
    print("\nL2 distances:")
    l2_distances = []
    for i in range(1, NUM_REPETITIONS):
        if len(results_logprobs[i]) == len(first_logprobs):
            l2 = np.linalg.norm(first_logprobs - results_logprobs[i])
            l2_distances.append(l2)
            print(f"  Rep 0 vs Rep {i}: L2 = {l2:.6e}")
        else:
            print(f"  Rep 0 vs Rep {i}: length mismatch")
    if l2_distances:
        print(f"\nMax L2: {max(l2_distances):.6e}")
        print(f"Mean L2: {np.mean(l2_distances):.6e}")

print("\nChecking full top-k distributions...")
distributions_exact = True
distribution_mismatches = []  # (rep_idx, pos_idx) or (rep_idx, "length_mismatch")

if results_distributions:
    first_dist = results_distributions[0]
    for rep_idx in range(1, NUM_REPETITIONS):
        other_dist = results_distributions[rep_idx]
        if len(other_dist) != len(first_dist):
            distributions_exact = False
            distribution_mismatches.append((rep_idx, "length_mismatch"))
            continue
        for pos_idx, (dist_a, dist_b) in enumerate(zip(first_dist, other_dist)):
            # Exact equality of the ordered (token_id, logprob) pairs. Unlike
            # abs(a - b) < tol, this also treats matching -inf values as equal.
            if dist_a != dist_b:
                distributions_exact = False
                distribution_mismatches.append((rep_idx, pos_idx))

print(f"Top-k distributions bit-exact: {distributions_exact}")

if not distributions_exact:
    print(f"\nFound {len(distribution_mismatches)} mismatches")
print()

# ============================================================================
# VERDICT
# ============================================================================

print("=" * 80, "VERDICT - INT8 QUANTIZATION + YaRN (bitsandbytes + Qwen3-32B)", "=" * 80, "", sep="\n")

# Checks go from coarsest to finest: a token divergence outranks everything,
# then selected-token logprobs, then the full top-k distributions.
if not tokens_identical:
    verdict_lines = (
        "[FAIL] TOKEN SEQUENCES DIFFER - DETERMINISM BROKEN",
        "  - INT8 quantization + YaRN breaks greedy decoding determinism",
    )
elif not logprobs_exact:
    verdict_lines = (
        "[WARNING] TOKENS IDENTICAL, LOGPROBS VARY",
        "  - Token sequences: bit-exact",
        "  - Logprobs: numerical variation",
        "  => INT8 + YaRN introduces FP variation in computations",
    )
elif not distributions_exact:
    verdict_lines = (
        "[WARNING] SELECTED TOKENS EXACT, DISTRIBUTIONS VARY",
        "  - Token sequences: bit-exact",
        "  - Selected token logprobs: bit-exact",
        "  - Top-k distributions: numerical variation",
        "  => INT8 + YaRN introduces distribution-level noise",
    )
else:
    verdict_lines = (
        "[PASS] PERFECT REPRODUCIBILITY WITH INT8 + YaRN",
        "  - Token sequences: bit-exact",
        "  - Selected token logprobs: bit-exact",
        "  - Full top-k distributions: bit-exact",
        "  => INT8 quantization + YaRN maintains determinism",
    )
print(*verdict_lines, sep="\n")
print()

# ============================================================================
# SAVE RESULTS
# ============================================================================

# Read the clock once so the JSON "timestamp" and the filename suffix always
# describe the same instant (two datetime.now() calls can straddle a second).
run_timestamp = datetime.now()

# Note: the top-k distributions compared above (results_distributions) are not
# written to the JSON; only their pass/fail flag is.
output_data = {
    "experiment": "h100_int8_determinism_test_bitsandbytes_qwen3_32b_yarn",
    "timestamp": run_timestamp.isoformat(),
    "hardware": "H100",  # fixed label, not detected at runtime
    "prompt_source": "file" if prompt_file else "hardcoded",
    "prompt_file": os.path.basename(prompt_file) if prompt_file else None,
    "prompt_text": prompt_text,
    "prompt_length_chars": len(prompt_text),
    "prompt_length_tokens": prompt_length,
    "config": {
        "model": MODEL_NAME,
        "quantization": QUANTIZATION,
        "precision": "INT8",
        "yarn_enabled": YARN_ENABLED,
        "yarn_factor": YARN_FACTOR,
        "yarn_original_max": YARN_ORIGINAL_MAX,
        "tensor_parallel": TENSOR_PARALLEL_SIZE,
        "max_context_len": max_context_len,
        "max_tokens": MAX_TOKENS,
        "repetitions": NUM_REPETITIONS,
        "warmup_runs": NUM_WARMUP_RUNS,
        "temperature": TEMPERATURE,
        "seed": SEED,
        "top_logprobs": TOP_LOGPROBS,
    },
    "timing": {
        "model_load_time": load_time,
        "warmup_times": warmup_times,
        "per_repetition": timing_data,
        "statistics": {
            "mean_time": float(np.mean(times)) if times else None,
            "std_time": float(np.std(times)) if times else None,
            "mean_tokens_per_sec": float(np.mean(tps)) if tps else None,
            "std_tokens_per_sec": float(np.std(tps)) if tps else None,
            "mean_time_per_token_ms": float(np.mean(tpt) * 1000) if tpt else None,
            "std_time_per_token_ms": float(np.std(tpt) * 1000) if tpt else None,
        },
    },
    # bool() casts keep the flags JSON-serializable even if a comparison
    # ever yields a NumPy bool instead of a Python bool.
    "results": {
        "tokens_identical": bool(tokens_identical),
        "logprobs_exact": bool(logprobs_exact),
        "distributions_exact": bool(distributions_exact),
        "perfect_reproducibility": bool(tokens_identical and logprobs_exact and distributions_exact),
    },
    "token_sequences": results_tokens,
    "logprobs_vectors": [lp.tolist() for lp in results_logprobs],
    "generated_texts": results_texts,
}

output_file = f"h100_int8_determinism_qwen3_32b_yarn_{run_timestamp.strftime('%Y%m%d_%H%M%S')}.json"
with open(output_file, "w", encoding="utf-8") as f:
    json.dump(output_data, f, indent=2)

print(f"Results saved to: {output_file}")
print()
print("=" * 80)
print("TEST COMPLETE")
print("=" * 80)

H100 INT8 QUANTIZATION DETERMINISM TEST - Qwen3-32B with YaRN

Found file: NSA.pdf
Loading 24 pages from PDF...
  Processed 10/24 pages
  Processed 20/24 pages
Loaded 59348 characters from PDF (24 pages)

Message content length: 59406 characters

Loading tokenizer and config with YaRN configuration...
YaRN enabled: factor=4.0, extended to 131,072 tokens

Prompt statistics:
  Characters: 59,456
  Tokens: 16,432
  Max context (with YaRN): 131,072
  Planned generation tokens: 20
  Total required: 16,452

Configuration:
  Model: Qwen/Qwen3-32B
  Quantization: bitsandbytes_int8
  Precision: INT8 weights via bitsandbytes, higher-precision activations
  Compute: INT8-storage weights with dequantization for matmul
  KV cache: model default (higher precision)
  RoPE scaling: YaRN (factor=4.0, base=32768)
  Tensor parallel: 1
  Max model len: 131,072
  Max new tokens: 20
  Temperature: 0.0
  Seed: 42
  Repetitions: 5
  Warmup runs: 1

IMPORTANT: Using bitsandbytes INT8 quantization with YaRN ext

Loading checkpoint shards:   0%|          | 0/17 [00:00<?, ?it/s]