# Download GPT-OSS-20B from Hugging Face

This notebook downloads the GPT-OSS-20B model from Hugging Face to a Modal volume.

## Prerequisites
- Modal.com account (GPU access is only needed for the later inference steps; the download itself does not require a GPU)
- Attached Modal volume for model storage
- ~13GB of storage space for the quantized model

## 1. Install Required Libraries

In [None]:
# Install the Hugging Face Hub client used by the download cells below, plus tqdm.
# NOTE(review): without `-U`, pip leaves any already-installed huggingface_hub version in place.
!pip install -q huggingface_hub tqdm

## 2. Set Up Download Configuration

In [None]:
import os
from pathlib import Path
from huggingface_hub import snapshot_download, hf_hub_download
import json

# Configuration
MODEL_ID = "openai/gpt-oss-20b"
LOCAL_DIR = "/mnt/models/gpt-oss-20b"  # Adjust this path to your Modal volume mount point


def _nearest_mount_point(path):
    """Return the closest ancestor of *path* (inclusive) that is a mount point.

    The path does not need to exist yet. If nothing above it is a mount point
    other than the filesystem root, the root (anchor) is returned.
    """
    resolved = Path(path).resolve()
    for candidate in (resolved, *resolved.parents):
        if os.path.ismount(candidate):
            return candidate
    return Path(resolved.anchor)


# os.makedirs succeeds even when no volume is attached, so the download would
# silently land on the container's own filesystem. Warn before that happens.
_mount = _nearest_mount_point(LOCAL_DIR)
if _mount == Path(_mount.anchor):
    print(
        f"⚠ {LOCAL_DIR} is not under a separate mount point (nearest mount: {_mount}). "
        "Check that your Modal volume is attached; otherwise the download may not persist."
    )

# Create directory if it doesn't exist
os.makedirs(LOCAL_DIR, exist_ok=True)

print(f"Model will be downloaded to: {LOCAL_DIR}")
print(f"Model ID: {MODEL_ID}")

## 3. Download Model Files

This downloads the safetensors weight shards together with the configuration and tokenizer files; documentation files are skipped.

In [None]:
# Download the entire model repository
print("Starting download...")
print("This may take several minutes depending on your connection speed.")
print("Expected download size: ~13GB")
print("-" * 50)

# Files fetched one by one by the fallback path; ones the repo lacks are skipped.
FALLBACK_CONFIG_FILES = [
    "config.json",
    "generation_config.json",
    "tokenizer_config.json",
    "tokenizer.json",
    "special_tokens_map.json",
]

try:
    # Download all model files.
    # NOTE(review): `local_dir_use_symlinks` / `resume_download` are deprecated in
    # recent huggingface_hub releases (ignored with a warning). They are kept so
    # older versions still copy real files into LOCAL_DIR instead of symlinking.
    downloaded_path = snapshot_download(
        repo_id=MODEL_ID,
        local_dir=LOCAL_DIR,
        local_dir_use_symlinks=False,
        resume_download=True,
        ignore_patterns=[
            "*.md", "*.txt", ".git*",  # Skip documentation files
            # Skip alternate-format checkpoint folders the repo may carry; the
            # cells below only use the top-level safetensors shards, and these
            # copies would push the download well past the ~13GB expected above.
            # NOTE(review): confirm folder names against the repo file listing.
            "original/*", "metal/*",
        ],
    )

    print(f"\n✅ Model downloaded successfully to: {downloaded_path}")

except Exception as e:
    print(f"❌ Error downloading model: {e}")
    print("\nTrying alternative download method...")

    # Alternative: fetch the index, the config files and every shard individually.
    try:
        # The index maps each tensor name to the shard file that stores it.
        index_file = hf_hub_download(
            repo_id=MODEL_ID,
            filename="model.safetensors.index.json",
            local_dir=LOCAL_DIR,
            local_dir_use_symlinks=False,
        )

        for config_file in FALLBACK_CONFIG_FILES:
            try:
                hf_hub_download(
                    repo_id=MODEL_ID,
                    filename=config_file,
                    local_dir=LOCAL_DIR,
                    local_dir_use_symlinks=False,
                )
                print(f"✓ Downloaded {config_file}")
            except Exception:
                # Best effort: not every repo ships every file. `Exception` rather
                # than a bare `except` keeps Ctrl-C / kernel interrupts working.
                print(f"⚠ Could not download {config_file} (might not exist)")

        # Read the index from the path hf_hub_download actually returned.
        with open(index_file, "r", encoding="utf-8") as f:
            index_data = json.load(f)

        # Several tensors share a shard, so de-duplicate before downloading.
        shard_files = sorted(set(index_data.get("weight_map", {}).values()))

        print(f"\nDownloading {len(shard_files)} shard files...")
        for shard_file in shard_files:
            print(f"Downloading {shard_file}...")
            hf_hub_download(
                repo_id=MODEL_ID,
                filename=shard_file,
                local_dir=LOCAL_DIR,
                local_dir_use_symlinks=False,
                resume_download=True,
            )
            print(f"✓ Downloaded {shard_file}")

        print(f"\n✅ Model downloaded successfully to: {LOCAL_DIR}")

    except Exception as e2:
        print(f"❌ Alternative download also failed: {e2}")

## 4. Verify Download

In [None]:
import os
from pathlib import Path

def get_dir_size(path):
    """Return the total size, in bytes, of every file found under *path*.

    The tree is walked recursively with ``os.walk``. Entries for which
    ``os.path.exists`` is false at check time (e.g. broken symlinks) are
    skipped. A *path* that does not exist yields 0.
    """
    every_file = (
        os.path.join(root, name)
        for root, _subdirs, names in os.walk(path)
        for name in names
    )
    return sum(os.path.getsize(fp) for fp in every_file if os.path.exists(fp))

def format_size(bytes):
    """Render a byte count as a human-readable string with two decimals.

    The value is divided by 1024 until it drops below 1024, choosing the
    matching unit from B through TB. Anything that is still 1024 TB or more
    is reported in PB.
    """
    scaled = bytes
    for suffix in ("B", "KB", "MB", "GB", "TB"):
        if scaled < 1024.0:
            break
        scaled /= 1024.0
    else:
        # Never fell below 1024 within the table: report in petabytes.
        suffix = "PB"
    return f"{scaled:.2f} {suffix}"

# List downloaded files
import json  # also imported in earlier cells; repeated so this cell stands on its own

print("Downloaded files:")
print("-" * 50)

model_path = Path(LOCAL_DIR)
if model_path.exists():
    # Only regular top-level files are listed and counted, so "Total files"
    # matches the listing. Subdirectories are covered by the recursive total.
    files = sorted(p for p in model_path.glob("*") if p.is_file())

    total_size = 0
    for file in files:
        size = file.stat().st_size
        total_size += size
        print(f"{file.name:50} {format_size(size):>10}")

    print("-" * 50)
    print(f"Total files: {len(files)}")
    print(f"Total size: {format_size(total_size)}")
    print(f"Total size incl. subdirectories: {format_size(get_dir_size(model_path))}")

    # Check for critical files
    critical_files = [
        "config.json",
        "model.safetensors.index.json",
    ]

    print("\nCritical files check:")
    for critical_file in critical_files:
        if (model_path / critical_file).exists():
            print(f"✅ {critical_file} - Found")
        else:
            print(f"❌ {critical_file} - Missing")

    # Count safetensors shards
    shard_files = list(model_path.glob("*.safetensors"))
    if shard_files:
        print(f"\n✅ Found {len(shard_files)} safetensors shard files")
    else:
        print("\n⚠ No safetensors files found")

    # Cross-check against the index. Otherwise a partially completed download
    # (some shards present, others missing) would go unnoticed.
    index_path = model_path / "model.safetensors.index.json"
    if index_path.exists():
        try:
            with open(index_path, "r", encoding="utf-8") as f:
                expected_shards = set(json.load(f).get("weight_map", {}).values())
        except (OSError, json.JSONDecodeError) as err:
            print(f"❌ Could not read {index_path.name}: {err}")
        else:
            missing = sorted(s for s in expected_shards if not (model_path / s).is_file())
            if missing:
                print(f"❌ {len(missing)} shard(s) referenced by the index are missing:")
                for shard_name in missing:
                    print(f"   - {shard_name}")
            else:
                print(f"✅ All {len(expected_shards)} shards referenced by the index are present")

else:
    print(f"❌ Directory {LOCAL_DIR} does not exist")

## 5. Load Model Configuration

Let's inspect the model configuration to understand its architecture.

In [None]:
import json

config_path = os.path.join(LOCAL_DIR, "config.json")

if os.path.exists(config_path):
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    print("Model Configuration:")
    print("-" * 50)

    # Key configuration details. Dict-valued entries (e.g. quantization_config)
    # are printed one sub-key per line.
    key_configs = [
        "architectures",
        "hidden_size",
        "intermediate_size",
        "num_attention_heads",
        "num_hidden_layers",
        "num_key_value_heads",
        "vocab_size",
        "max_position_embeddings",
        "torch_dtype",
        "quantization_config",
    ]

    for key in key_configs:
        if key not in config:
            continue
        value = config[key]
        if isinstance(value, dict):
            print(f"{key}:")
            for k, v in value.items():
                print(f"  {k}: {v}")
        else:
            print(f"{key}: {value}")

    # Summarize the core dimensions (no parameter count is computed here).
    if "num_hidden_layers" in config and "hidden_size" in config:
        layers = config["num_hidden_layers"]
        hidden = config["hidden_size"]
        vocab = config.get("vocab_size")

        print("\nModel Statistics:")
        print(f"- Layers: {layers}")
        print(f"- Hidden size: {hidden}")
        # Report an absent/non-integer vocab_size as N/A rather than "0 tokens".
        if isinstance(vocab, int):
            print(f"- Vocabulary: {vocab:,} tokens")
        else:
            print("- Vocabulary: N/A")

        # Check for MoE configuration
        if "num_local_experts" in config:
            print(f"- Experts: {config['num_local_experts']} (Mixture of Experts)")
            print(f"- Active experts: {config.get('num_experts_per_tok', 'N/A')}")
else:
    print(f"❌ Config file not found at {config_path}")

## 6. Check Quantization Format

Let's check whether the model uses MXFP4 quantization, as expected.

In [None]:
# Check the safetensors index for weight information
from collections import Counter
from itertools import islice

# Tensor-name suffixes mapped to statistics buckets, checked in order.
# Quantized tensors come as a "blocks" / "scales" pair. Both "<name>.blocks"
# and "<name>_blocks" spellings are accepted because checkpoints differ in
# separator. Suffix matching avoids counting names that only contain
# ".blocks" mid-path (e.g. "model.blocks.0.attn.weight") as quantized.
_WEIGHT_SUFFIXES = (
    ("quantized_blocks", (".blocks", "_blocks")),
    ("quantized_scales", (".scales", "_scales")),
    ("regular_weights", (".weight", "_weight")),
    ("biases", (".bias", "_bias")),
)


def _classify_weight(name):
    """Return the statistics bucket for a tensor *name* based on its suffix."""
    for bucket, suffixes in _WEIGHT_SUFFIXES:
        if name.endswith(suffixes):
            return bucket
    return "other"


index_path = os.path.join(LOCAL_DIR, "model.safetensors.index.json")

if os.path.exists(index_path):
    with open(index_path, "r", encoding="utf-8") as f:
        index_data = json.load(f)

    print("Model Weight Information:")
    print("-" * 50)

    # Get metadata
    metadata = index_data.get("metadata", {})
    if metadata:
        print("Metadata:")
        for key, value in metadata.items():
            print(f"  {key}: {value}")

    # Tensor name -> shard file
    weight_map = index_data.get("weight_map", {})
    weight_types = Counter(_classify_weight(name) for name in weight_map)

    print("\nWeight Statistics:")
    print(f"Total weights: {len(weight_map)}")
    for wtype, count in sorted(weight_types.items()):
        print(f"  {wtype}: {count}")

    # Counter returns 0 for buckets that never occurred.
    n_blocks = weight_types["quantized_blocks"]
    if n_blocks > 0:
        print("\n✅ Model appears to use MXFP4 quantization")
        print(f"   Found {n_blocks} quantized blocks")
        print(f"   Found {weight_types['quantized_scales']} scale tensors")
    else:
        print("\n⚠ Model may not be quantized or uses different format")

    # Show example weight names
    print("\nExample weight names (first 10):")
    for name in islice(weight_map, 10):
        print(f"  {name}")

else:
    print(f"❌ Index file not found at {index_path}")

## 7. Next Steps

Now that the model is downloaded, you can:

1. **Convert to JAX format**: Use TensorPort to convert the safetensors to JAX-compatible NumPy arrays
2. **Run inference**: Load the model with JAX and run inference
3. **Benchmark performance**: Test inference speed on different GPU types
4. **Fine-tune**: Use the model for downstream tasks

### Quick Conversion Example

In [None]:
# Example command to convert using TensorPort (if installed)
import shlex

# Build the command from its arguments. Paths are shell-quoted so the printed
# command stays copy-pasteable even if LOCAL_DIR contains spaces. shlex.quote
# leaves plain paths such as the default unchanged.
conversion_args = [
    "tensorport convert",
    f"    --input {shlex.quote(LOCAL_DIR)}",
    f"    --output {shlex.quote(LOCAL_DIR + '-jax')}",
    "    --format numpy-direct",
    "    --precision float16",
]

print("To convert the model to JAX format using TensorPort:")
print("-" * 50)
# Each argument on its own line, joined with shell line continuations.
print(" \\\n".join(conversion_args))
print()
print("This will create JAX-loadable NumPy arrays from the safetensors files.")

## Summary

This notebook has downloaded the GPT-OSS-20B model from Hugging Face. The model uses MXFP4 quantization for efficient storage and inference.

Key points:
- Model size: ~13GB (quantized)
- Architecture: Mixture of Experts with 21B total parameters, 3.6B active
- Quantization: MXFP4 (4-bit floating-point values that share one power-of-two scale per 32-element block)
- Ready for conversion to JAX format using TensorPort