# GPT-OSS-20B: Convert to JAX and Test Inference

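This notebook converts the downloaded GPT-OSS-20B model from safetensors format to JAX-compatible NumPy arrays (using the TensorPort converter) and runs a basic smoke test on the converted weights: an embedding lookup and a small lookup benchmark. Full model inference is not implemented yet.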

## 1. Setup and Imports

In [None]:
import gc
import json
import os
import re
import subprocess
import sys
import time
from pathlib import Path
from typing import Dict, Any

import jax
import jax.numpy as jnp
import numpy as np
import psutil

In [None]:
# Install required packages if not already installed
!pip install -q safetensors transformers jax jaxlib psutil

In [None]:
# Configuration
# Each location can be overridden with an environment variable so the notebook
# is not tied to a single machine's mount layout. The defaults are the
# original hard-coded paths, so behavior is unchanged when nothing is set.
MODEL_PATH = os.environ.get("GPT_OSS_MODEL_PATH", "/mnt/example-runs-vol/gpt-oss-20b")
OUTPUT_PATH = os.environ.get("GPT_OSS_OUTPUT_PATH", "/mnt/example-runs-vol/gpt-oss-20b-jax")
TENSORPORT_PATH = os.environ.get("TENSORPORT_PATH", "/mnt/example-runs-vol/tensorport")

print(f"Model path: {MODEL_PATH}")
print(f"Output path: {OUTPUT_PATH}")
print(f"JAX devices: {jax.devices()}")
print(f"JAX backend: {jax.default_backend()}")

## 2. Build TensorPort (Rust Converter)

In [None]:
# Clone and build TensorPort if not already present
if not os.path.exists(TENSORPORT_PATH):
    print("Cloning TensorPort repository...")
    !git clone https://github.com/your-username/tensorport.git {TENSORPORT_PATH}
    
# Install Rust if not already installed
!curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
!source $HOME/.cargo/env

# Build TensorPort
print("Building TensorPort...")
os.chdir(TENSORPORT_PATH)
!$HOME/.cargo/bin/cargo build --release

tensorport_binary = f"{TENSORPORT_PATH}/target/release/tensorport"
print(f"TensorPort binary: {tensorport_binary}")

## 3. Convert Model to JAX Format

In [None]:
def get_memory_usage():
    """Return the resident set size (RSS) of the current process in GB (1024-based)."""
    rss_bytes = psutil.Process().memory_info().rss
    # bytes -> KB -> MB -> GB
    return rss_bytes / 1024 / 1024 / 1024

print(f"Memory before conversion: {get_memory_usage():.2f} GB")

In [None]:
# Run TensorPort conversion
os.makedirs(OUTPUT_PATH, exist_ok=True)

conversion_cmd = [
    tensorport_binary,
    "convert",
    "--input", MODEL_PATH,
    "--output", OUTPUT_PATH,
    "--format", "numpy-direct",
    "--precision", "float16",
    "--verbose"
]

print("Starting conversion...")
print(f"Command: {' '.join(conversion_cmd)}")

start_time = time.time()
result = subprocess.run(conversion_cmd, capture_output=True, text=True)
# Measure once, right after the process exits, so the duration is reported
# for failed runs as well.
conversion_seconds = time.time() - start_time

# The converter runs with --verbose but its stdout is captured rather than
# streamed; show the tail so the log is not silently discarded.
if result.stdout:
    stdout_lines = result.stdout.splitlines()
    if len(stdout_lines) > 20:
        print(f"... ({len(stdout_lines) - 20} earlier output lines omitted)")
    print("\n".join(stdout_lines[-20:]))

if result.returncode == 0:
    print("✅ Conversion successful!")
else:
    print("❌ Conversion failed!")
    print(f"Error: {result.stderr}")
print(f"Time taken: {conversion_seconds:.2f} seconds")

# NOTE(review): get_memory_usage() is called without a pid, so this presumably
# reports the notebook kernel's memory, not the converter subprocess's.
print(f"Memory after conversion: {get_memory_usage():.2f} GB")

# Stop "Run All" here on failure: every later cell depends on the converted
# files, and continuing would produce misleading results.
if result.returncode != 0:
    raise RuntimeError(
        f"TensorPort conversion exited with status {result.returncode}"
    )

In [None]:
# Summarize the conversion output: every .npy file directly inside OUTPUT_PATH
output_files = [*Path(OUTPUT_PATH).glob("*.npy")]
print(f"\nConverted {len(output_files)} tensor files:")

# Preview only the first ten names (in sorted order) with their size in MB
for f in sorted(output_files)[:10]:
    size_mb = f.stat().st_size / 1024 / 1024
    print(f"  {f.name}: {size_mb:.2f} MB")

# Mention how many files the preview left out
if len(output_files) > 10:
    print(f"  ... and {len(output_files) - 10} more files")

## 4. Load Model Configuration

In [None]:
# Load model config
config_path = Path(MODEL_PATH) / "config.json"
# JSON is UTF-8 by specification; be explicit rather than relying on the
# platform's default encoding.
with open(config_path, 'r', encoding='utf-8') as f:
    config = json.load(f)

print("Model Configuration:")
# `architectures` may be absent, null or an empty list; fall back to 'Unknown'
# in all three cases instead of raising on the [0] lookup.
architectures = config.get('architectures') or ['Unknown']
print(f"  Architecture: {architectures[0]}")
print(f"  Hidden size: {config.get('hidden_size', 'N/A')}")
print(f"  Num layers: {config.get('num_hidden_layers', 'N/A')}")
print(f"  Num heads: {config.get('num_attention_heads', 'N/A')}")
print(f"  Vocab size: {config.get('vocab_size', 'N/A')}")
print(f"  Max position embeddings: {config.get('max_position_embeddings', 'N/A')}")

# Check if it's MoE. The expert-count key differs between checkpoints:
# Hugging Face MoE configs (e.g. GptOssConfig) name it `num_local_experts`,
# so accept that spelling in addition to the original `num_experts`.
expert_count_key = next(
    (key for key in ('num_experts', 'num_local_experts') if key in config),
    None,
)
if expert_count_key is not None:
    print(f"\nMixture of Experts:")
    print(f"  Num experts: {config[expert_count_key]}")
    print(f"  Experts per token: {config.get('num_experts_per_tok', 'N/A')}")

## 5. Simple JAX Inference Test

In [None]:
class SimpleGPTModel:
    """Simple GPT model for testing weight loading and basic inference.

    Only the token-embedding table is handled: it is loaded from a converted
    ``.npy`` file and exposed through a plain lookup.

    Attributes:
        weights_path: Directory that is searched for ``.npy`` weight files.
        config: Model configuration dictionary (stored, not interpreted here).
        weights: Loaded arrays by name; ``'embed'`` is set after a successful
            ``load_embedding_weights()`` call.
    """

    def __init__(self, weights_path: str, config: Dict[str, Any]):
        self.weights_path = Path(weights_path)
        self.config = config
        self.weights = {}

    def load_embedding_weights(self):
        """Load just the embedding weights for a simple test.

        Looks for ``*embed*.npy`` first and falls back to ``*token*.npy``.

        Returns:
            True if a file was found and loaded into ``self.weights['embed']``,
            False if no candidate file exists.
        """
        print("Loading embedding weights...")

        # glob() yields entries in arbitrary filesystem order; sort so the
        # same file is picked on every run when several names match.
        embed_files = sorted(self.weights_path.glob("*embed*.npy"))
        if not embed_files:
            embed_files = sorted(self.weights_path.glob("*token*.npy"))

        if not embed_files:
            print("❌ Could not find embedding weights")
            return False

        embed_path = embed_files[0]
        if len(embed_files) > 1:
            print(f"  Note: {len(embed_files)} candidate files found; using the first in sorted order")
        print(f"Loading {embed_path.name}...")
        embed_weight = np.load(embed_path)
        self.weights['embed'] = jnp.array(embed_weight)
        print(f"  Shape: {self.weights['embed'].shape}")
        print(f"  Dtype: {self.weights['embed'].dtype}")
        return True

    def simple_embed(self, token_ids):
        """Simple embedding lookup.

        Args:
            token_ids: Integer indices into the first axis of the embedding table.

        Returns:
            The rows of the embedding table selected by ``token_ids``.

        Raises:
            ValueError: If the embedding weights have not been loaded.
        """
        if 'embed' not in self.weights:
            raise ValueError("Embedding weights not loaded")

        # NOTE(review): JAX indexing does not raise on out-of-range indices
        # (they are clamped), so callers must keep token_ids inside the table.
        return self.weights['embed'][token_ids]

# Build the model wrapper around the converted weights directory
model = SimpleGPTModel(OUTPUT_PATH, config)

# Attempt the embedding load and report which way it went
print(
    "\n✅ Successfully loaded embedding weights!"
    if model.load_embedding_weights()
    else "\n❌ Failed to load embedding weights"
)

In [None]:
# Test embedding lookup
if 'embed' in model.weights:
    print("Testing embedding lookup...")

    # Create some test token IDs
    test_tokens = jnp.array([1, 100, 1000, 5000], dtype=jnp.int32)
    print(f"Test tokens: {test_tokens}")

    # Run embedding lookup. JAX dispatches work asynchronously, so block until
    # the result is materialized; otherwise only the dispatch would be timed.
    start_time = time.perf_counter()
    embeddings = model.simple_embed(test_tokens).block_until_ready()
    embed_time = time.perf_counter() - start_time

    # Compute the summary statistics in float32 so the six printed decimals
    # are not limited by a half-precision result (the conversion above
    # requests float16 output).
    embeddings_f32 = embeddings.astype(jnp.float32)

    print(f"\nEmbedding results:")
    print(f"  Output shape: {embeddings.shape}")
    print(f"  Output dtype: {embeddings.dtype}")
    print(f"  Time taken: {embed_time*1000:.2f} ms")
    print(f"  Mean value: {jnp.mean(embeddings_f32):.6f}")
    print(f"  Std dev: {jnp.std(embeddings_f32):.6f}")

## 6. Performance Benchmark

In [None]:
def benchmark_embedding_lookup(model, num_tokens=1024, num_iterations=100, num_warmup=10):
    """Benchmark embedding lookup performance.

    Args:
        model: Object exposing a ``weights`` dict (with an ``'embed'`` array)
            and a ``simple_embed(token_ids)`` method.
        num_tokens: Number of random token IDs looked up per iteration.
        num_iterations: Number of timed iterations.
        num_warmup: Number of untimed warm-up iterations run first.

    Returns:
        A dict of timing statistics in seconds (``mean_s``, ``median_s``,
        ``min_s``, ``max_s``, ``std_s``) plus ``tokens_per_second``,
        ``num_tokens`` and ``num_iterations``; ``None`` when the embedding
        weights are not loaded.

    Raises:
        ValueError: If ``num_tokens`` or ``num_iterations`` is not positive.
    """
    if 'embed' not in model.weights:
        print("Embedding weights not loaded")
        return None
    if num_tokens <= 0 or num_iterations <= 0:
        # With no samples every statistic below would be NaN
        raise ValueError("num_tokens and num_iterations must be positive")

    vocab_size = model.weights['embed'].shape[0]

    # Generate random token IDs (fixed seed keeps the benchmark reproducible)
    key = jax.random.PRNGKey(42)
    token_ids = jax.random.randint(key, (num_tokens,), 0, vocab_size)

    # Warmup. Block on each result so no queued asynchronous work spills
    # into the first timed iteration.
    print(f"Warming up with {num_warmup} iterations...")
    for _ in range(num_warmup):
        model.simple_embed(token_ids).block_until_ready()

    # Benchmark
    print(f"\nBenchmarking {num_iterations} iterations with {num_tokens} tokens...")
    times = []

    for i in range(num_iterations):
        # perf_counter is monotonic and high resolution, which matters for
        # intervals this short
        start = time.perf_counter()
        model.simple_embed(token_ids).block_until_ready()  # Wait for JAX computation
        times.append(time.perf_counter() - start)

        if (i + 1) % 20 == 0:
            print(f"  Iteration {i+1}/{num_iterations}")

    # Calculate statistics
    times = np.array(times)
    mean_s = float(np.mean(times))
    stats = {
        'mean_s': mean_s,
        'median_s': float(np.median(times)),
        'min_s': float(np.min(times)),
        'max_s': float(np.max(times)),
        'std_s': float(np.std(times)),
        'tokens_per_second': num_tokens / mean_s if mean_s > 0 else float('inf'),
        'num_tokens': num_tokens,
        'num_iterations': num_iterations,
    }

    print(f"\n📊 Benchmark Results:")
    print(f"  Mean time: {stats['mean_s']*1000:.2f} ms")
    print(f"  Median time: {stats['median_s']*1000:.2f} ms")
    print(f"  Min time: {stats['min_s']*1000:.2f} ms")
    print(f"  Max time: {stats['max_s']*1000:.2f} ms")
    print(f"  Std dev: {stats['std_s']*1000:.2f} ms")
    print(f"  Throughput: {stats['tokens_per_second']:.0f} tokens/second")

    return stats

# Kick off the benchmark only when the embedding table is available
if 'embed' in model.weights.keys():
    benchmark_embedding_lookup(model)

## 7. Load and Test More Layers (Optional)

In [None]:
def load_layer_weights(weights_path: Path, layer_idx: int = 0, max_files=5):
    """Load weights for a specific transformer layer.

    File names are searched for ``<prefix>.<layer_idx>`` using the prefixes
    ``layer``, ``layers``, ``h`` and ``block`` in that order; the first prefix
    with any match wins. The index must not be followed by another digit, so
    asking for layer 1 does not also pick up layers 10-19.

    Args:
        weights_path: Directory containing the converted ``.npy`` files.
        layer_idx: Index of the layer to load.
        max_files: Maximum number of matching files to load (default 5, to
            keep this smoke test light); pass ``None`` to load every match.

    Returns:
        Dict mapping each file stem (dots replaced by underscores) to a JAX
        array; empty if no file matches.
    """
    layer_weights = {}

    # Common naming prefixes for layer weight files, in priority order
    prefixes = ("layer", "layers", "h", "block")

    print(f"\nSearching for layer {layer_idx} weights...")

    # Sort once so the selection is deterministic; glob order is arbitrary
    candidates = sorted(weights_path.glob("*.npy"))

    for prefix in prefixes:
        # (?!\d) rejects a longer index that merely starts with layer_idx
        matcher = re.compile(rf"{re.escape(prefix)}\.{int(layer_idx)}(?!\d)")
        files = [f for f in candidates if matcher.search(f.name)]
        if files:
            print(f"Found {len(files)} files matching pattern: *{prefix}.{layer_idx}*.npy")
            selected = files if max_files is None else files[:max_files]
            for f in selected:
                weight = np.load(f)
                key = f.stem.replace('.', '_')
                layer_weights[key] = jnp.array(weight)
                print(f"  Loaded {key}: shape={weight.shape}, dtype={weight.dtype}")
            if len(selected) < len(files):
                # Make the truncation visible instead of silently dropping files
                print(f"  ... {len(files) - len(selected)} more files not loaded (max_files={max_files})")
            break

    return layer_weights

# Probe the converted output for the first transformer layer
layer_0_weights = load_layer_weights(Path(OUTPUT_PATH), layer_idx=0)

# An empty dict means no file name matched any of the known patterns
if not layer_0_weights:
    print("\n⚠️ Could not find layer 0 weights (this is okay for initial test)")
else:
    print(f"\n✅ Successfully loaded {len(layer_0_weights)} weights for layer 0")

## 8. Memory and Storage Summary

In [None]:
def get_directory_size(path):
    """Return the combined size, in GB (1024-based), of all files under ``path``.

    The directory is walked recursively and only regular-file entries are
    counted.
    """
    size_in_bytes = sum(
        entry.stat().st_size
        for entry in Path(path).rglob('*')
        if entry.is_file()
    )
    # bytes -> KB -> MB -> GB
    return size_in_bytes / 1024 / 1024 / 1024

# Calculate sizes
original_size = get_directory_size(MODEL_PATH)
converted_size = get_directory_size(OUTPUT_PATH)

print("\n📦 Storage Summary:")
print(f"  Original model size: {original_size:.2f} GB")
print(f"  Converted model size: {converted_size:.2f} GB")
# The figure is the percentage by which the converted copy is smaller than
# the original (negative when it is larger), hence "size reduction" rather
# than "compression ratio". Guard the division: an empty or missing source
# directory reports a size of 0.
if original_size > 0:
    print(f"  Size reduction: {(1 - converted_size/original_size)*100:.1f}%")
else:
    print("  Size reduction: N/A (original model directory is empty or missing)")

# Sample system memory once so "available" and "total" come from the same snapshot
system_memory = psutil.virtual_memory()

print(f"\n💾 Memory Usage:")
print(f"  Current memory: {get_memory_usage():.2f} GB")
print(f"  Available memory: {system_memory.available / 1024 / 1024 / 1024:.2f} GB")
print(f"  Total system memory: {system_memory.total / 1024 / 1024 / 1024:.2f} GB")

## 9. Next Steps

In [None]:
# Derive the status marks from what actually happened earlier in the notebook
# instead of reporting success unconditionally.
conversion_ok = result.returncode == 0    # exit status of the TensorPort run
embedding_ok = 'embed' in model.weights   # present only after a successful load

print("\n🎯 Conversion and Test Complete!\n")
print("Next steps:")
if conversion_ok:
    print("1. ✅ Model successfully converted to JAX format")
else:
    print("1. ❌ Model conversion to JAX format failed")
if embedding_ok:
    print("2. ✅ Basic embedding lookup working")
else:
    print("2. ❌ Basic embedding lookup not verified (embedding weights not loaded)")
print("3. 📋 TODO: Implement full model architecture in JAX")
print("4. 📋 TODO: Add MXFP4 quantization support")
print("5. 📋 TODO: Implement text generation")
print("6. 📋 TODO: Compare with PyTorch baseline")

# These figures are rough expectations, not measurements from this notebook
print("\n📊 Performance expectations:")
print("  - CPU: ~1-2 seconds per token (current)")
print("  - T4 GPU: ~100-200 tokens/second (expected)")
print("  - A10G GPU: ~200-400 tokens/second (expected)")
print("  - With MXFP4: Additional 2-4x speedup")