# Testing jax-diffucoder from TestPyPI

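This notebook smoke-tests the `jax-diffucoder` package (imported as `jax_lm`) installed from TestPyPI. It checks installation, the public imports, the default model configuration, hardware detection, and model loading/generation.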

In [None]:
# Install from TestPyPI
# Note: --extra-index-url is needed for dependencies not on TestPyPI
!pip install --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple/ jax-diffucoder

In [None]:
# Verify installation: report the package version, the JAX version,
# and the devices JAX can see.
import jax
import jax_lm

package_version = jax_lm.__version__
print(f"✅ jax-diffucoder version: {package_version}")

jax_version = jax.__version__
print(f"JAX version: {jax_version}")

visible_devices = jax.devices()
print(f"JAX devices: {visible_devices}")

In [None]:
# Test imports: pull in the public jax_lm names used later in this notebook
from jax_lm import DiffuCoderConfig, load_model, generate, load_tokenizer

print("✅ All imports successful")

In [None]:
# Test configuration: build a default config and list its key dimensions.
# Attributes are read one at a time inside the loop so output order matches
# the attribute access order.
config = DiffuCoderConfig()
print("Model configuration:")
for label, attr_name in (
    ("Hidden size", "hidden_size"),
    ("Num layers", "num_hidden_layers"),
    ("Num heads", "num_attention_heads"),
    ("Vocab size", "vocab_size"),
):
    print(f"  {label}: {getattr(config, attr_name)}")

In [None]:
# Check hardware
# jax.devices("tpu") / jax.devices("gpu") raise RuntimeError when that backend
# is not present, so probing each platform crashes on CPU-only (and GPU-only)
# machines. Ask JAX for its default backend instead and list its devices.
backend = jax.default_backend()
devices = jax.devices()

if backend == "tpu":
    print(f"🚀 Running on TPU with {len(devices)} cores")
    print(f"TPU devices: {devices}")
elif backend == "gpu":
    print("🎮 Running on GPU")
    print(f"GPU devices: {devices}")
elif backend == "cpu":
    print("💻 Running on CPU")
    print(f"CPU devices: {devices}")
else:
    # Any other platform JAX may report (e.g. a plugin backend).
    print(f"❓ Running on {backend}")
    print(f"Devices: {devices}")

In [None]:
# Test model loading (will fail until HuggingFace upload)
# Loading and generation are handled separately so a generation error is not
# misreported as the "expected" loading failure.
MODEL_ID = "atsentia/DiffuCoder-7B-JAX"
PROMPT = "def hello_world():"
MAX_NEW_TOKENS = 50

model_loaded = False
try:
    print("Attempting to load model from HuggingFace...")
    model, params, tokenizer = load_model(MODEL_ID)
    model_loaded = True
    print("✅ Model loaded successfully!")
except Exception as e:
    # Broad catch is deliberate: while the model is unpublished, any failure is
    # reported rather than raised so the rest of the notebook still runs.
    print("❌ Model loading failed (expected until HF upload is complete)")
    print(f"Error: {type(e).__name__}: {e}")

# Test generation (only meaningful once loading succeeded)
if model_loaded:
    try:
        output = generate(model, params, PROMPT, tokenizer, max_new_tokens=MAX_NEW_TOKENS)
        print(f"\nGenerated output:\n{output}")
    except Exception as e:
        print("❌ Generation failed (model loaded, but generate() raised)")
        print(f"Error: {type(e).__name__}: {e}")

## Next Steps

1. Upload the model to Hugging Face as `atsentia/DiffuCoder-7B-JAX`
2. Re-run the model loading cell above
3. Test generation with different prompts
4. Benchmark generation on TPU, GPU, and CPU