## 0. Local Mac Testing Setup (Optional)

**⚠️ Skip this cell if running on Colab TPU**

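This cell configures JAX on a Mac M3 Ultra to test the FP8 conversion path that Colab uses on TPU v6e:
- Sets `STRATEGY = "fp8"` manually, the same strategy Cell 2 selects on TPU v6e
- Forces CPU execution (on Colab, the conversion also runs on the CPU)
- Useful for testing locally with 96GB of Mac RAM before deploying to Colab

After running this, skip cells 1-2 (dependencies and TPU detection) and go directly to cell 3 (download from HuggingFace), which reuses the local path set here instead of downloading. Cell 7 depends on `tpu_type`, `num_devices`, and `MEM_GB` from Cell 2, so it will not run on this path.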

In [None]:
# Local Mac M3 Ultra setup - simulate the TPU v6e FP8 conversion on CPU
import os
from pathlib import Path

# Force JAX to use CPU (Mac doesn't have a TPU).
# Set this before importing jax so the setting is picked up reliably.
os.environ['JAX_PLATFORMS'] = 'cpu'

# Placeholder TPU address kept for local testing; the FP8 conversion is
# triggered by STRATEGY below, not by this variable
os.environ['COLAB_TPU_ADDR'] = 'fake-tpu-v6e'

import jax
import jax.numpy as jnp

# Verify JAX is using CPU
print("JAX devices:", jax.devices())
print("JAX backend:", jax.default_backend())

# Set strategy manually for local testing
STRATEGY = "fp8"  # Test FP8 conversion
DTYPE = jnp.float8_e4m3fn

# Set local paths (Mac uses current directory, not /content/)
# ⚠️ UPDATE THIS: Point to the directory containing .safetensors files
# Common paths:
#   - "gpt-oss-20b/original" (HuggingFace cache structure)
#   - "/Users/yourname/models/gpt-oss-20b/original"
#   - "~/Downloads/gpt-oss-20b/original"
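safetensors_path = Path("gpt-oss-20b/original").expanduser()  # ← UPDATE THIS PATH! (expanduser handles "~/...")
# Use an absolute path; Orbax may reject relative checkpoint directories
orbax_path = str(Path(f"gpt-oss-20b-orbax-{STRATEGY}").resolve())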

# Verify path exists and contains .safetensors files
if not safetensors_path.exists():
    raise FileNotFoundError(
        f"Path does not exist: {safetensors_path}\n"
        f"Please update the safetensors_path in Cell 0 to point to your gpt-oss-20b download"
    )

st_files = list(safetensors_path.glob('*.safetensors'))
if len(st_files) == 0:
    # Try common subdirectories
    for subdir in ['original', 'models', '']:
        candidate = safetensors_path / subdir if subdir else safetensors_path
        st_files = list(candidate.glob('*.safetensors'))
        if st_files:
            safetensors_path = candidate
            break
    
    if len(st_files) == 0:
        raise FileNotFoundError(
            f"No .safetensors files found in: {safetensors_path}\n"
            f"Please update the safetensors_path in Cell 0 to the correct directory\n"
            f"Expected files like: model.safetensors or model-00001-of-00002.safetensors"
        )

print(f"\n✓ Local Mac setup complete")
print(f"  Strategy: {STRATEGY.upper()}")
print(f"  Target dtype: {DTYPE}")
print(f"  SafeTensors path: {safetensors_path}")
print(f"  Found {len(st_files)} .safetensors file(s)")
print(f"  Orbax output: {orbax_path}")
print("\n⚠️ Now skip cells 1-2 and run cell 3 (Download from HuggingFace)")

# JAX: GPT-OSS-20B on Google Colab TPU

<span style="color: #e67e22; font-weight: bold;">⚠️ Note: This is a basic non-optimized implementation for educational purposes.</span>

## Adaptive Precision Inference

Repository: [gpt-oss-jax](https://github.com/atsentia/gpt-oss-jax)

### Adaptive Precision Strategy

| TPU Type | Memory | Strategy | Model Size |
|----------|--------|----------|------------|
| **v2-8** | 64GB (8x8GB) | BF16 (16-bit) | ~42GB |
| **v6e** | 32GB | FP8 (8-bit) | ~21GB |
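
The Model Size column is parameter count × bytes per parameter. A minimal sketch of that arithmetic, assuming ~21B parameters and decimal gigabytes (10⁹ bytes), and counting weights only (no activations or KV cache):

```python
# Bytes per parameter for each strategy in the table above
BYTES_PER_PARAM = {"bf16": 2, "fp8": 1}


def weights_gb(num_params: float, strategy: str) -> float:
    """Approximate weight memory in decimal GB (1e9 bytes)."""
    return num_params * BYTES_PER_PARAM[strategy] / 1e9


def hbm_utilization(num_params: float, strategy: str, hbm_gb: float) -> float:
    """Fraction of HBM taken by the weights alone."""
    return weights_gb(num_params, strategy) / hbm_gb


NUM_PARAMS = 21e9  # approximate GPT-OSS-20B parameter count

for strategy, hbm_gb in [("bf16", 64), ("fp8", 32)]:
    gb = weights_gb(NUM_PARAMS, strategy)
    util = hbm_utilization(NUM_PARAMS, strategy, hbm_gb)
    print(f"{strategy}: {gb:.0f} GB of {hbm_gb} GB HBM ({util:.0%})")
```

Both strategies come out at about 66% of HBM, matching the Utilization row of the Performance Comparison table further down.
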

### ⚠️ Setup Required

**Runtime → Change runtime type → TPU** (before running cells)

## 1. Install Dependencies

This cell installs all required packages:
- JAX with TPU support - Core ML framework optimized for TPUs
- Flax & Orbax - Neural network library and checkpoint utilities
- openai-harmony - Harmony protocol for multi-channel reasoning
- gpt-oss-jax - Our GPT-OSS-20B implementation

Expected time: ~2-3 minutes

In [None]:
# Install dependencies
!pip install -q "jax[tpu]>=0.4.20" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
!pip install -q flax orbax-checkpoint safetensors openai-harmony tiktoken tqdm huggingface_hub

# Clone repo
!git clone -q https://github.com/atsentia/gpt-oss-jax.git 2>/dev/null || true
%cd gpt-oss-jax
!pip install -q -e ".[jax]"

print("✓ Setup complete")

## 2. Verify TPU Backend & Select Precision Strategy

This cell:
1. Detects your TPU type (v2-8, v6e, etc.)
2. Validates TPU is available (not CPU)
3. Automatically selects precision strategy:
   - TPU v2-8 with 8 devices → BF16 (16-bit, 64GB HBM)
   - TPU v6e → FP8 (8-bit, 32GB HBM)

What to expect: Should print your TPU type and selected strategy
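
The selection rule in this cell can be factored into a pure function, which makes it easy to check without a TPU (for example on the Mac path from Cell 0). A sketch mirroring the cell's logic; the function name `select_strategy` is hypothetical, dtypes are returned as strings so the example does not need JAX, and the example `device_kind` strings are assumptions:

```python
def select_strategy(device_kind: str, num_devices: int) -> tuple[str, str, int]:
    """Return (strategy, dtype name, expected weight GB) for a TPU device kind.

    Mirrors the notebook cell: 8-device TPU v2 -> BF16, any TPU v6 -> FP8.
    """
    if "v2" in device_kind and num_devices == 8:
        return "bf16", "bfloat16", 42
    if "v6" in device_kind:
        return "fp8", "float8_e4m3fn", 21
    raise RuntimeError(f"Unsupported: {device_kind}")


print(select_strategy("TPU v2", 8))       # hypothetical device_kind string
print(select_strategy("TPU v6 lite", 1))  # hypothetical device_kind string
```

A TPU v2 runtime with fewer than 8 devices falls through to the error branch, exactly as the cell does.
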

In [None]:
import jax
import jax.numpy as jnp

devices = jax.devices()
backend = jax.default_backend()
assert backend == "tpu", f"TPU not found (got {backend})"

tpu_type = devices[0].device_kind
num_devices = len(devices)
print(f"✓ {tpu_type} ({num_devices} devices)")

# Select precision
if "v2" in tpu_type and num_devices == 8:
    STRATEGY, DTYPE, MEM_GB = "bf16", jnp.bfloat16, 42
    print("Strategy: BF16 (16-bit) - 64GB HBM")
elif "v6" in tpu_type:
    STRATEGY, DTYPE, MEM_GB = "fp8", jnp.float8_e4m3fn, 21
    print("Strategy: FP8 (8-bit) - 32GB HBM")
else:
    raise RuntimeError(f"Unsupported: {tpu_type}")

## 3. Download GPT-OSS-20B Weights from HuggingFace

Downloads the official GPT-OSS-20B model checkpoint (13.8 GB) from HuggingFace.

What happens:
- Downloads .safetensors files containing 21B parameters
- Saves to /content/gpt-oss-20b-dl/original/
- Uses HuggingFace's snapshot downloader for efficient transfers

Expected time: ~3-5 minutes (depending on HuggingFace bandwidth)

Note: If you hit rate limits, wait a few minutes and re-run this cell

In [None]:
from huggingface_hub import snapshot_download
from pathlib import Path

# Only download if not already set (e.g., from Cell 0 for Mac testing)
if 'safetensors_path' not in globals():
    print("Downloading GPT-OSS-20B (13.8 GB)...")
    checkpoint_dir = snapshot_download(
        repo_id="openai/gpt-oss-20b",
        revision="main",
        allow_patterns=["original/*"],
        local_dir="/content/gpt-oss-20b-dl",
        # local_dir_use_symlinks is deprecated in recent huggingface_hub releases
    )
    safetensors_path = Path("/content/gpt-oss-20b-dl/original")
    print(f"✓ Downloaded: {safetensors_path}")
else:
    print(f"✓ Using existing SafeTensors path: {safetensors_path}")
    print("  (Set in Cell 0 - skipping download)")

## 4. Convert SafeTensors to Orbax Format

Converts the HuggingFace checkpoint to Orbax format (JAX native).

Why convert to Orbax?
- Faster loading than SafeTensors: the weights are stored already converted, so loading skips MXFP4 decompression and dtype casts
- Optimized for JAX PyTree structures
- Supports sharded checkpoints across TPU devices
- The JAX-native way to store checkpoints

This is a one-time conversion. Future sessions can load from Orbax directly.

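Note for TPU v6e (FP8): The conversion runs on the CPU, so the BF16 intermediate (~42GB) must fit in host RAM, not in the 32GB of HBM. During the FP8 cast the BF16 and FP8 copies are alive at the same time (~63GB peak). If the runtime runs out of host memory, use a pre-converted FP8 Orbax checkpoint instead.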

Expected time: ~15-20 seconds

In [None]:
import gc
import time

import jax
import jax.numpy as jnp
import orbax.checkpoint as ocp
import psutil
from gpt_oss.jax.config import ModelConfig
from gpt_oss.jax.loader_safetensors import WeightLoader

# Helper function to track memory usage
def get_ram_gb():
    """Get current RAM usage in GB."""
    return psutil.Process().memory_info().rss / 1024**3

config = ModelConfig()
# Set orbax_path if not already set (e.g., from Cell 0 for Mac testing)
if 'orbax_path' not in globals():
    orbax_path = f"/content/gpt-oss-20b-orbax-{STRATEGY}"

print(f"Converting to Orbax ({STRATEGY.upper()})...")
print("Loading tensors on CPU to avoid TPU OOM...")
ram_start = get_ram_gb()
print(f"Starting RAM: {ram_start:.2f} GB")
t0 = time.time()

# For TPU v6e with FP8: Load on CPU, convert, then save
# This avoids loading 42GB BF16 into 32GB TPU memory
with jax.default_device(jax.devices('cpu')[0]):
    loader = WeightLoader(str(safetensors_path))
    # Loads all weights as BF16:
    # - Regular weights: Loaded as BF16 directly from SafeTensors
    # - MXFP4 MoE weights: Decompressed MXFP4 → BF16 (see loader_safetensors.py)
    params = loader.load_params(config, show_progress=True)
    ram_after_load = get_ram_gb()
    print(f"After loading BF16: {ram_after_load:.2f} GB (+{ram_after_load - ram_start:.2f} GB)")
    
    # Convert BF16 → FP8 if using FP8 strategy (TPU v6e)
    if STRATEGY == "fp8":
        print("Converting BF16 → FP8 on CPU...")
        # Converts all BF16 tensors to FP8, leaves other dtypes unchanged
        params = jax.tree_util.tree_map(
            lambda x: x.astype(DTYPE) if x.dtype == jnp.bfloat16 else x,
            params
        )
        ram_after_fp8 = get_ram_gb()
        print(f"After FP8 conversion: {ram_after_fp8:.2f} GB (+{ram_after_fp8 - ram_start:.2f} GB total)")

# Save to Orbax format (still on CPU).
# force=True overwrites a checkpoint left by an earlier run of this cell.
print("Saving to Orbax...")
checkpointer = ocp.PyTreeCheckpointer()
checkpointer.save(orbax_path, params, force=True)
ram_peak = get_ram_gb()

print(f"✓ Converted in {time.time()-t0:.1f}s")
print(f"  Orbax checkpoint: {orbax_path}")
print(f"  Peak RAM: {ram_peak:.2f} GB (+{ram_peak - ram_start:.2f} GB)")

# Free CPU memory
del params
import gc
gc.collect()
ram_after_cleanup = get_ram_gb()
print(f"  After cleanup: {ram_after_cleanup:.2f} GB ({ram_peak - ram_after_cleanup:.2f} GB freed)")

# Clean up SafeTensors to save space
!rm -rf /content/gpt-oss-20b-dl
print("  Cleaned up SafeTensors (13.8 GB freed from disk)")

## 5. Load Model Parameters from Orbax

Loads the 21B parameters from Orbax checkpoint (JAX native format).

BF16 Strategy (TPU v2-8):
- Memory footprint: ~42GB
- Highest accuracy of the two strategies (no rounding beyond the released checkpoint)

FP8 Strategy (TPU v6e):
- Memory footprint: ~21GB (50% reduction!)
- Small expected accuracy loss from 8-bit rounding (<2% perplexity increase; not measured in this notebook)

Expected time: ~2-3 seconds (much faster than SafeTensors!)

In [None]:
import time
import orbax.checkpoint as ocp

print(f"Loading from Orbax ({STRATEGY.upper()})...")
t0 = time.time()

checkpointer = ocp.PyTreeCheckpointer()
params = checkpointer.restore(orbax_path)

elapsed = time.time() - t0
print(f"✓ Loaded in {elapsed:.1f}s")
# Speedup vs. an assumed ~15 s SafeTensors load (not measured in this session)
print(f"  ~{15/elapsed:.1f}x faster than a ~15 s SafeTensors load")

## 6. Initialize Model & Tokenizer

Creates the GPT-OSS-20B Transformer model and tokenizer.

What happens:
- Initializes the model architecture (24 layers, 2880 hidden dim, 64 query heads with 8 key/value heads, mixture-of-experts with 32 experts and 4 active per token)
- Loads the tokenizer (o200k_harmony BPE, 201,088-token vocabulary)
- Prints the parameter dtype(s) so you can check they match your strategy

Model Architecture:
- Parameters: 21B total (3.6B active per token)
- Layers: 24
- Context: up to 128k tokens
- Vocab: 201,088 tokens

In [None]:
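from collections import Counter
from gpt_oss.jax.model import Transformer
from gpt_oss.tokenizer import get_tokenizer

model = Transformer(config=config)
tokenizer = get_tokenizer()

# Count leaves per dtype: the FP8 cast only converts BF16 leaves, so other dtypes may remain
dtype_counts = Counter(str(x.dtype) for x in jax.tree_util.tree_leaves(params))
print("✓ Model: GPT-OSS-20B")
print(f"  Dtypes: {dict(dtype_counts)}")
print(f"  Strategy: {STRATEGY.upper()}")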

## 7. Memory Utilization Analysis

Calculates actual memory usage and compares to TPU HBM capacity.

What you'll see:
- Actual memory: Size of loaded parameters in GB
- TPU HBM: Total high-bandwidth memory available
- Utilization: Percentage of HBM used

Target utilization: ~66% (leaves headroom for activations and KV cache)

If you see >90% utilization: Model may not fit for inference
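
The remaining ~34% of HBM has to hold activations and the KV cache. A back-of-envelope KV cache estimate, assuming the attention shape listed in the model's published config (24 layers, 8 key/value heads, head dim 64) and a BF16 cache. It treats every layer as full attention, so it overestimates: alternating layers use a 128-token sliding window. The function name is hypothetical:

```python
def kv_cache_bytes(seq_len: int, batch: int = 1, num_layers: int = 24,
                   num_kv_heads: int = 8, head_dim: int = 64,
                   bytes_per_elem: int = 2) -> int:
    """Bytes for keys + values across all layers (leading 2 = K and V)."""
    return (2 * num_layers * num_kv_heads * head_dim * bytes_per_elem
            * seq_len * batch)


per_token = kv_cache_bytes(seq_len=1)
print(f"Per token: {per_token / 1024:.0f} KiB")
for seq_len in (8_192, 131_072):
    print(f"{seq_len:>7} tokens: {kv_cache_bytes(seq_len) / 1e9:.2f} GB")
```

Under these assumptions even a full 128k-token context for one sequence (about 6.4 GB) fits in the ~11 GB left after FP8 weights on v6e, while the cache grows linearly with batch size.
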

In [None]:
def mem_gb(p):
    return sum(x.nbytes for x in jax.tree_util.tree_leaves(p)) / 1e9

actual = mem_gb(params)
print(f"Memory: {actual:.1f} GB (expected: {MEM_GB} GB)")

tpu_hbm = 64 if "v2" in tpu_type and num_devices == 8 else 32
print(f"TPU HBM: {tpu_hbm} GB ({actual/tpu_hbm*100:.0f}% used)")

## 8. Run Inference with Harmony Protocol

Demonstrates multi-channel reasoning using the Harmony protocol.

Harmony Protocol Features:
- Multi-channel output: Separate analysis and final answer channels
- Structured reasoning: Model shows its thought process
- Efficient inference: Uses KV cache for fast token generation

Example Question: "What is the capital of France?"

Expected output:
- Analysis channel: Model's reasoning process (📊)
- Answer channel: Final response (💬)
- Performance: Tokens/second metric

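Try it: Edit the `msg` variable to ask your own questions. If only the analysis channel prints, increase `max_new_tokens`: the 50-token budget can run out before the model reaches its final answer.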
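
The cell extracts channels from the decoded text with regular expressions. A stdlib sketch of the same idea that collects every channel into a dict. It assumes the Harmony markers `<|start|>`, `<|channel|>`, `<|message|>`, `<|end|>` and `<|return|>` appear verbatim in the decoded string and that channel names are plain words; `split_channels` is a hypothetical helper, and `openai_harmony` also ships a token-level parser, which is the more robust choice:

```python
import re

# Hypothetical helper: split a decoded Harmony completion into {channel: text}.
_CHANNEL_RE = re.compile(
    r"<\|channel\|>(\w+)<\|message\|>(.*?)"
    r"(?=<\|end\|>|<\|return\|>|<\|start\|>|<\|channel\|>|$)",
    re.DOTALL,
)


def split_channels(text: str) -> dict[str, str]:
    """Map each channel name to its message text (repeats joined by newlines)."""
    parts: dict[str, list[str]] = {}
    for name, body in _CHANNEL_RE.findall(text):
        parts.setdefault(name, []).append(body.strip())
    return {name: "\n".join(chunks) for name, chunks in parts.items()}


sample = (
    "<|channel|>analysis<|message|>User asks for the capital of France.<|end|>"
    "<|start|>assistant<|channel|>final<|message|>Paris.<|return|>"
)
print(split_channels(sample))
```

The `$` alternative matters for truncated output: when generation stops mid-message, the partial text is still returned.
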

In [None]:
import re
from IPython.display import HTML, display
from gpt_oss.jax.inference import generate

try:
    from openai_harmony import (
        load_harmony_encoding,
        HarmonyEncodingName,
        Conversation,
        Message,
        Role
    )
    
    encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
    msg = "What is the capital of France?"
    conv = Conversation.from_messages([Message.from_role_and_content(Role.USER, msg)])
    prompt_tokens = encoding.render_conversation_for_completion(conv, Role.ASSISTANT)
    
    output_tokens, stats = generate(
        model=model, params=params, prompt_tokens=prompt_tokens,
        max_new_tokens=50, temperature=0.0, rng_key=jax.random.PRNGKey(42),
        config=config, use_kv_cache=True, show_progress=False, return_stats=True
    )
    
    stop_tokens = encoding.stop_tokens_for_assistant_actions()
    filtered = [t for t in output_tokens[len(prompt_tokens):] if t not in stop_tokens]
    generated = tokenizer.decode(filtered)
    
    # Parse channels
    analysis = re.search(r'<\|channel\|>analysis<\|message\|>(.*?)(?:<\|end\|>|<\|channel\|>|$)', generated, re.DOTALL)
    final = re.search(r'<\|channel\|>(main|final)<\|message\|>(.*?)(?:<\|end\|>|<\|channel\|>|$)', generated, re.DOTALL)
    
    if analysis:
        print(f"📊 Analysis: {analysis.group(1).strip()}")
    if final:
        print(f"💬 Answer: {final.group(2).strip()}")
    
    print(f"\nPerf: {stats['tokens_per_second']:.2f} tok/s")
except Exception as e:
    print(f"Harmony demo error: {e}")

## Performance Comparison

| Metric | TPU v2-8 (BF16) | TPU v6e (FP8) |
|--------|-----------------|---------------|
| Precision | 16-bit | 8-bit |
| Memory | ~42 GB | ~21 GB |
| TPU HBM | 64 GB | 32 GB |
| Utilization | 66% | 66% |
| Load Time | ~5s | ~5s |
| Tokens/sec | 50-100 | 80-150* |

\* FP8 halves the bytes read from HBM per generated token, so memory-bandwidth-bound decoding can run faster
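
The footnote's reasoning can be made concrete with a roofline-style bound: single-sequence decoding reads every active weight once per token, so tokens/s ≤ memory bandwidth ÷ bytes read per token. A sketch assuming ~3.6B active parameters per token (GPT-OSS-20B is a mixture-of-experts model) and a hypothetical 1,000 GB/s of HBM bandwidth; substitute your TPU's actual figure:

```python
def max_tokens_per_second(active_params: float, bytes_per_param: int,
                          bandwidth_gb_s: float) -> float:
    """Upper bound on decode speed when weight reads dominate."""
    bytes_per_token = active_params * bytes_per_param
    return bandwidth_gb_s * 1e9 / bytes_per_token


ACTIVE_PARAMS = 3.6e9     # active parameters per token (MoE)
BANDWIDTH_GB_S = 1000.0   # hypothetical HBM bandwidth

bf16 = max_tokens_per_second(ACTIVE_PARAMS, 2, BANDWIDTH_GB_S)
fp8 = max_tokens_per_second(ACTIVE_PARAMS, 1, BANDWIDTH_GB_S)
print(f"BF16 bound: {bf16:.0f} tok/s, FP8 bound: {fp8:.0f} tok/s")
```

Halving bytes per parameter doubles the bound. Real throughput is lower, because KV cache reads, expert routing, and kernel overheads also cost bandwidth and time.
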

## 9. Optional: Save to Google Drive

Uncomment the code below to save your Orbax checkpoint to Google Drive.

Why save to Drive?
- Colab sessions are temporary (max 12 hours)
- Avoid re-downloading model weights in future sessions
- Reloading the Orbax checkpoint from Drive skips both the HuggingFace download and the conversion step

Note: Requires ~21-42 GB of Drive storage depending on precision strategy

In [None]:
# Optional: Save to Google Drive
# from google.colab import drive
# drive.mount('/content/drive')
# !cp -r {orbax_path} /content/drive/MyDrive/
# print("✓ Saved to Drive")

## 10. TPU Memory Monitoring

Snapshot of TPU memory usage across all devices; re-run the cell to refresh it.

What you'll see:
- Per-device breakdown: Memory usage for each TPU core
- Bytes in use: Current memory consumption
- Bytes limit: Maximum available memory
- Utilization percentage: How much of each device's memory is used

Use this to:
- Debug OOM (Out of Memory) errors
- Verify memory is balanced across devices
- Monitor memory during inference

In [None]:
import jax

print("TPU Monitoring:")
for i, dev in enumerate(jax.devices()):
    try:
        # Device.memory_stats() returns a dict on TPU/GPU (may be None on CPU)
        stats = dev.memory_stats()
    except Exception:
        stats = None
    if not stats:
        print(f"  Device {i}: Memory info unavailable")
        continue
    used = stats["bytes_in_use"] / 1e9
    limit = stats["bytes_limit"] / 1e9
    print(f"  Device {i}: {used:.1f}/{limit:.1f} GB ({used/limit*100:.0f}%)")

## 11. Cleanup Temporary Files

Removes the temporary download directory to free up disk space.

What gets deleted:
- /content/gpt-oss-20b-dl/ (13.8 GB of original SafeTensors files)

Cell 4 already deletes this directory after conversion, so this cell is a no-op if Cell 4 completed.

What's preserved:
- Loaded parameters in memory
- Orbax checkpoint (written by Cell 4)

Safe to run: the parameters are already loaded

In [None]:
# Cleanup temp files
!rm -rf /content/gpt-oss-20b-dl
print("✓ Cleaned temp files")

## Troubleshooting

**OOM Errors**: Verify the TPU type matches the selected strategy (Cell 2)

**TPU Not Detected**: Runtime → Change runtime type → TPU, then restart

**Slow Download**: HuggingFace rate limits - wait and retry

**Import Errors**: Re-run Cell 1 (install dependencies)

## 🚀 Optimization Exercises

### 1. JAX Code Optimization
Profile with `jax.profiler`, optimize bottlenecks

[Code](https://github.com/atsentia/gpt-oss-jax/blob/main/gpt_oss/jax/model.py)

### 2. KV Cache Optimization
Implement INT8/FP8 KV cache for 2-4x memory savings

[Code](https://github.com/atsentia/gpt-oss-jax/blob/main/gpt_oss/jax/kv_cache.py)
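
One common way to approach this exercise is symmetric per-vector INT8 quantization: store each cached key/value vector as int8 plus one float scale, `scale = max|x| / 127`. A stdlib sketch of the quantize/dequantize round trip; plain Python lists stand in for JAX arrays, and the function names are hypothetical rather than taken from `kv_cache.py`:

```python
def quantize_int8(vec: list[float]) -> tuple[list[int], float]:
    """Symmetric per-vector INT8 quantization: returns (int8 values, scale)."""
    amax = max((abs(v) for v in vec), default=0.0)
    scale = amax / 127 if amax > 0 else 1.0  # avoid division by zero
    q = [max(-127, min(127, round(v / scale))) for v in vec]
    return q, scale


def dequantize_int8(q: list[int], scale: float) -> list[float]:
    """Recover approximate floats; error per element is at most scale / 2."""
    return [v * scale for v in q]


k = [0.5, -1.27, 0.01, 1.0]
q, scale = quantize_int8(k)
print(q, scale)
print(dequantize_int8(q, scale))
```

For a 64-element head vector this stores 64 bytes plus one scale instead of 128 bytes in BF16, so the saving is close to 2x; reaching the 4x end of the range would need 4-bit values.
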

### 3. Advanced Quantization
On-the-fly MXFP4 dequantization: 10.5 GB vs 21 GB

[Code](https://github.com/atsentia/gpt-oss-jax/tree/main/gpt_oss/jax/quantization)
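
MXFP4 (from the OCP Microscaling spec) stores blocks of 32 FP4 E2M1 values that share one 8-bit power-of-two (E8M0) scale; in the released checkpoint only the MoE weights use it. A stdlib sketch of dequantizing one block, assuming two FP4 codes packed per byte with the low nibble first and a scale bias of 127, the convention used by OpenAI's reference gpt-oss loader; check `loader_safetensors.py` for this repository's exact layout. `dequantize_mxfp4_block` is a hypothetical name:

```python
import math

# FP4 E2M1 code -> value (the nibble's high bit is the sign)
FP4_VALUES = [+0.0, +0.5, +1.0, +1.5, +2.0, +3.0, +4.0, +6.0,
              -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]


def dequantize_mxfp4_block(packed: bytes, scale_byte: int) -> list[float]:
    """Decode one MXFP4 block: 16 packed bytes -> 32 floats.

    Assumes low nibble first and an E8M0 scale with bias 127.
    """
    exponent = scale_byte - 127
    out = []
    for byte in packed:
        for code in (byte & 0x0F, byte >> 4):  # low nibble, then high nibble
            out.append(math.ldexp(FP4_VALUES[code], exponent))
    return out


block = bytes([0x21] * 16)   # each byte holds codes 1 (0.5) and 2 (1.0)
values = dequantize_mxfp4_block(block, scale_byte=128)  # scale = 2**1
print(values[:4], len(values))
```

Storage is 4 bits per weight plus 8 bits per 32 weights (4.25 bits per parameter), so keeping weights in this form and dequantizing inside the matmul is roughly where the ~10.5 GB figure (21B × 4 bits) comes from; the shared scales add about 6% on top.
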

### 4. Speculative Decoding
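Draft model + verification: 2-3x speedup. The draft must share GPT-OSS's o200k_harmony tokenizer, so GPT-2 (50,257-token vocabulary) cannot serve as a drop-in draft.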
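
The core loop of speculative decoding: a cheap draft model proposes k tokens, the target model checks them, and the longest agreeing prefix is kept plus one token from the target. A toy sketch of the greedy variant (a draft token is accepted only if it equals the target's argmax), with plain functions standing in for the two models. The sampling-based variant uses a probabilistic accept/reject rule instead, and a real implementation scores all k drafts in one batched target forward pass rather than one call per position:

```python
from typing import Callable

NextToken = Callable[[list[int]], int]  # greedy next-token function (toy model)


def speculative_step(prefix: list[int], draft: NextToken, target: NextToken,
                     k: int = 4) -> list[int]:
    """Return the tokens appended by one draft-then-verify round."""
    # 1. Draft model proposes k tokens autoregressively.
    proposal, ctx = [], list(prefix)
    for _ in range(k):
        tok = draft(ctx)
        proposal.append(tok)
        ctx.append(tok)
    # 2. Target verifies; keep drafts while they match its greedy choice.
    accepted, ctx = [], list(prefix)
    for tok in proposal:
        expected = target(ctx)
        if tok != expected:
            accepted.append(expected)  # target's correction ends the round
            return accepted
        accepted.append(tok)
        ctx.append(tok)
    accepted.append(target(ctx))       # all k accepted: one bonus token
    return accepted


def target(ctx: list[int]) -> int:
    return ctx[-1] + 1                 # toy target: count upward


def draft(ctx: list[int]) -> int:
    return ctx[-1] + 1 if ctx[-1] < 5 else 0  # toy draft: wrong after 5


print(speculative_step([1], draft, target))  # draft agrees on every token
print(speculative_step([4], draft, target))  # mismatch after one token
```

Every emitted token matches what the target alone would produce greedily, so the speedup comes only from verifying several tokens per target pass, and it grows with how often the draft agrees.
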

### 5. Continuous Batching
Batch multiple requests: 5-10x throughput
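
Continuous batching refills a batch slot as soon as its request finishes instead of waiting for the whole batch. A toy scheduler sketch, assuming each request is just a positive number of tokens to generate and each step decodes one token for every active request; the function names and request model are hypothetical:

```python
from collections import deque


def continuous_batching_steps(lengths: list[int], max_batch: int) -> int:
    """Number of decode steps to finish all requests with slot refilling."""
    queue = deque(lengths)
    active: list[int] = []  # remaining tokens per occupied slot
    steps = 0
    while queue or active:
        # Admit waiting requests into free slots before each step.
        while queue and len(active) < max_batch:
            active.append(queue.popleft())
        # One decode step: every active request emits one token.
        active = [remaining - 1 for remaining in active]
        steps += 1
        # Free the slots of finished requests.
        active = [remaining for remaining in active if remaining > 0]
    return steps


def static_batching_steps(lengths: list[int], max_batch: int) -> int:
    """Each fixed batch runs until its longest request finishes."""
    return sum(max(lengths[i:i + max_batch])
               for i in range(0, len(lengths), max_batch))


requests = [8, 1, 1, 1, 8, 1]
print(continuous_batching_steps(requests, max_batch=2))
print(static_batching_steps(requests, max_batch=2))
```

Here continuous batching finishes in 11 steps versus 17 for static batching, because short requests no longer wait behind long ones; the 5-10x figure above refers to real serving workloads and is not something this toy reproduces.
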

**Discuss**: [GitHub Discussions](https://github.com/atsentia/gpt-oss-jax/discussions)

## Conclusion

✅ Demonstrated adaptive precision (BF16 vs FP8)

✅ 2x memory reduction enables TPU v6e

✅ Production patterns: monitoring, error handling

✅ Harmony protocol multi-channel reasoning

### Resources
- [Repository](https://github.com/atsentia/gpt-oss-jax)
- [Model Card](https://huggingface.co/openai/gpt-oss-20b)
- [JAX Docs](https://jax.readthedocs.io/)

**Issues?** [Open an issue](https://github.com/atsentia/gpt-oss-jax/issues)