# FP8 Conversion Test - Mac M3 Ultra

This notebook tests the BF16 → FP8 checkpoint conversion for gpt-oss-20b on a Mac (JAX running on CPU), printing RAM usage at each stage.

**Requirements:**
- Mac M3 Ultra with 96GB RAM
- gpt-oss-20b downloaded locally
- JAX, Orbax, and dependencies installed

**Cells:**
1. Cell 0: Setup (configure paths)
2. Cell 1: Download check (skips if path already set)
3. Cell 2: FP8 Conversion (with RAM monitoring)

In [None]:
# Cell 0: Local Mac M3 Ultra setup - simulate TPU v6e for FP8 conversion testing
import os
import jax
import jax.numpy as jnp
from pathlib import Path

# Force JAX to use CPU (Mac doesn't have TPU).
# `import jax` has already executed above, so exporting JAX_PLATFORMS at this
# point is too late to influence backend selection by itself. The environment
# variable is still set (child processes inherit it), and the equivalent
# config option is updated directly, which takes effect as long as no backend
# has been initialised yet.
os.environ['JAX_PLATFORMS'] = 'cpu'
jax.config.update('jax_platforms', 'cpu')

# Pretend to be TPU v6e to trigger FP8 conversion
# (This is just for testing the conversion logic)
os.environ['COLAB_TPU_ADDR'] = 'fake-tpu-v6e'

# Verify JAX is using CPU
print("JAX devices:", jax.devices())
print("JAX backend:", jax.default_backend())

# Set strategy manually for local testing
STRATEGY = "fp8"  # Test FP8 conversion
DTYPE = jnp.float8_e4m3fn

# Set local paths (Mac uses current directory, not /content/)
# ⚠️ UPDATE THIS: Point to the directory containing .safetensors files
# Common paths:
#   - "gpt-oss-20b/original" (HuggingFace cache structure)
#   - "/Users/yourname/models/gpt-oss-20b/original"
#   - "~/Downloads/gpt-oss-20b/original"
# expanduser() makes the "~/..." form above work; Path() alone leaves "~"
# as a literal directory name.
safetensors_path = Path("gpt-oss-20b/original").expanduser()  # ← UPDATE THIS PATH!
orbax_path = f"gpt-oss-20b-orbax-{STRATEGY}"

# Verify path exists and contains .safetensors files
if not safetensors_path.exists():
    raise FileNotFoundError(
        f"Path does not exist: {safetensors_path}\n"
        f"Please update the safetensors_path in Cell 0 to point to your gpt-oss-20b download"
    )

st_files = sorted(safetensors_path.glob('*.safetensors'))
if not st_files:
    # Nothing directly in the configured directory: try common subdirectories
    # and adopt the first one that contains weights.
    for subdir in ('original', 'models'):
        candidate = safetensors_path / subdir
        candidate_files = sorted(candidate.glob('*.safetensors'))
        if candidate_files:
            safetensors_path = candidate
            st_files = candidate_files
            break
    else:
        # No subdirectory matched either; safetensors_path is still the
        # user-configured directory, which is what the message reports.
        raise FileNotFoundError(
            f"No .safetensors files found in: {safetensors_path}\n"
            f"Please update the safetensors_path in Cell 0 to the correct directory\n"
            f"Expected files like: model.safetensors or model-00001-of-00002.safetensors"
        )

print("\n✓ Local Mac setup complete")
print(f"  Strategy: {STRATEGY.upper()}")
print(f"  Target dtype: {DTYPE}")
print(f"  SafeTensors path: {safetensors_path}")
print(f"  Found {len(st_files)} .safetensors file(s)")
print(f"  Orbax output: {orbax_path}")
print("\n⚠️ Now run Cell 1 (Download check) then Cell 2 (Conversion)")

In [None]:
# Cell 1: Download Check (skips download on Mac)
from huggingface_hub import snapshot_download
from pathlib import Path

# Reuse the SafeTensors directory when an earlier cell (e.g. Cell 0 for Mac
# testing) has already defined it; otherwise download the checkpoint.
if 'safetensors_path' in globals():
    print(f"✓ Using existing SafeTensors path: {safetensors_path}")
    print("  (Set in Cell 0 - skipping download)")
else:
    print("Downloading GPT-OSS-20B (13.8 GB)...")
    # Only the "original/" files are requested, and they are written as
    # real files under local_dir.
    checkpoint_dir = snapshot_download(
        repo_id="openai/gpt-oss-20b",
        revision="main",
        local_dir="/content/gpt-oss-20b-dl",
        local_dir_use_symlinks=False,
        allow_patterns=["original/*"],
    )
    safetensors_path = Path("/content/gpt-oss-20b-dl/original")
    print(f"✓ Downloaded: {safetensors_path}")

In [None]:
# Cell 2: FP8 Conversion with RAM Monitoring
import time
import numpy as np
import psutil
from gpt_oss.jax.config import ModelConfig
from gpt_oss.jax.loader_safetensors import WeightLoader
from orbax.checkpoint import PyTreeCheckpointer
import orbax.checkpoint as ocp
from safetensors import safe_open
from pathlib import Path
from flax import traverse_util
from tqdm import tqdm

# Helper function to track memory usage
def get_ram_gb():
    """Return the resident set size (RSS) of the current process in GiB.

    The value is a point-in-time sample read through psutil, not a
    high-water mark.
    """
    rss_bytes = psutil.Process().memory_info().rss
    return rss_bytes / 1024**3

config = ModelConfig()
# Set orbax_path if not already set (e.g., from Cell 0 for Mac testing)
if 'orbax_path' not in globals():
    orbax_path = f"/content/gpt-oss-20b-orbax-{STRATEGY}"
# Cell 0 supplies a relative directory name, while Orbax expects an absolute
# checkpoint directory, so normalise before saving.
orbax_path = str(Path(orbax_path).expanduser().resolve())

print(f"Converting to Orbax ({STRATEGY.upper()})...")
print("Loading tensors on CPU to avoid TPU OOM...")
ram_start = get_ram_gb()
# Every RSS sample taken below is collected here so the reported peak is the
# highest value observed rather than simply the last one.
ram_samples = [ram_start]
print(f"Starting RAM: {ram_start:.2f} GB")
t0 = time.time()

# For TPU v6e with FP8: Load on CPU, convert, then save
# This avoids loading 42GB BF16 into 32GB TPU memory
with jax.default_device(jax.devices('cpu')[0]):
    loader = WeightLoader(str(safetensors_path))
    # Loads all weights as BF16:
    # - Regular weights: Loaded as BF16 directly from SafeTensors
    # - MXFP4 MoE weights: Decompressed MXFP4 → BF16 (see loader_safetensors.py)
    params = loader.load_params(config, show_progress=True)
    ram_after_load = get_ram_gb()
    ram_samples.append(ram_after_load)
    print(f"After loading BF16: {ram_after_load:.2f} GB (+{ram_after_load - ram_start:.2f} GB)")

    # Convert BF16 → FP8 if using FP8 strategy (TPU v6e)
    if STRATEGY == "fp8":
        print("Converting BF16 → FP8 on CPU...")
        # Converts all BF16 tensors to FP8, leaves other dtypes unchanged.
        # NOTE(review): this is a plain cast with no scaling factor — confirm
        # the weight range fits the FP8 format.
        params = jax.tree_util.tree_map(
            lambda x: x.astype(DTYPE) if x.dtype == jnp.bfloat16 else x,
            params
        )
        # JAX may dispatch work asynchronously; wait for the casts to finish
        # so the RAM sample reflects the converted tree.
        params = jax.block_until_ready(params)
        ram_after_fp8 = get_ram_gb()
        ram_samples.append(ram_after_fp8)
        print(f"After FP8 conversion: {ram_after_fp8:.2f} GB (+{ram_after_fp8 - ram_start:.2f} GB total)")

# Save to Orbax format (still on CPU)
# NOTE(review): SaveArgs(aggregate=...) is believed to be deprecated in recent
# Orbax releases, and save() is expected to fail if the target directory
# already exists — confirm both against the installed version.
print("Saving to Orbax...")
checkpointer = ocp.PyTreeCheckpointer()
checkpointer.save(orbax_path, params, save_args=ocp.SaveArgs(aggregate=True))
ram_after_save = get_ram_gb()
ram_samples.append(ram_after_save)
# Highest RSS seen at any sampling point. Spikes between samples are not
# captured, so treat this as a lower bound on the true peak.
ram_peak = max(ram_samples)

print(f"✓ Converted in {time.time()-t0:.1f}s")
print(f"  Orbax checkpoint: {orbax_path}")
print(f"  After save: {ram_after_save:.2f} GB (+{ram_after_save - ram_start:.2f} GB)")
print(f"  Peak RAM: {ram_peak:.2f} GB (+{ram_peak - ram_start:.2f} GB)")

# Free CPU memory
del params
import gc
gc.collect()
ram_after_cleanup = get_ram_gb()
# "Freed" is measured against the sample taken just before cleanup.
print(f"  After cleanup: {ram_after_cleanup:.2f} GB ({ram_after_save - ram_after_cleanup:.2f} GB freed)")

# Note: On Mac we keep the SafeTensors (no /content/ to clean up)
print("\n✓ FP8 conversion complete!")
print("\nExpected RAM usage:")
print("  - Starting: ~0.2 GB")
print("  - After BF16 load: ~42-44 GB")
print("  - After FP8 convert: ~21-23 GB")
print("  - After save: ~25-30 GB")
print("  - Peak (highest sample): ~42-44 GB")