In [None]:
# --- 1. Verify TPU and Install JAX/Flax Dependencies ---
import jax
import os

# This is the most important step. We must confirm a TPU is available.
# Counting devices is not sufficient on its own: without a TPU, JAX falls back
# to the CPU backend and still reports at least one device, so the platform of
# every visible device is inspected explicitly.
tpu_device_count = 0   # number of visible devices whose platform is "tpu"
tpu_available = False  # True only when at least one TPU device was found
failure_reason = None  # human-readable explanation printed when the check fails

try:
    detected_devices = jax.devices()
except Exception as e:
    # Backend initialisation itself failed; keep the cause so it can be shown
    # to the user instead of being silently discarded.
    failure_reason = f"{type(e).__name__}: {e}"
else:
    tpu_device_count = sum(1 for device in detected_devices if device.platform == "tpu")
    tpu_available = tpu_device_count > 0
    if not tpu_available:
        failure_reason = (
            f"JAX is running on the '{jax.default_backend()}' backend with "
            f"{len(detected_devices)} device(s); no TPU devices found."
        )

if tpu_available:
    print(f"✅ Success! Found {tpu_device_count} JAX devices (TPUs).")
    # NOTE(review): this flag sets how many virtual devices the *host (CPU)*
    # platform exposes; it does not silence a warning. XLA reads XLA_FLAGS when
    # the backend initialises, which has already happened above, so it most
    # likely has no effect in this session. The assignment is kept so the
    # process environment stays the same for anything that inspects it --
    # confirm whether it can be dropped.
    os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'
else:
    print("❌ ERROR: Could not initialize JAX with a TPU backend.")
    print("Please go to 'Runtime' -> 'Change runtime type' and select a 'TPU' hardware accelerator.")
    print(f"Details: {failure_reason}")

# --- Install Libraries ---
# We need specific versions of these libraries that are compatible with TPUs.
print("\\nInstalling JAX, Flax, and Hugging Face Transformers...")
!pip install -q "jax[tpu]>=0.4.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
!pip install -q flax transformers sentencepiece

print("\\n✅ Dependencies installed.")


In [None]:
# --- 2. Load TPU-Compatible Model and Tokenizer ---
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM
import jax.numpy as jnp

# --- Model Selection ---
# The checkpoint has to provide JAX/Flax weights on the Hugging Face Hub,
# because it is loaded through FlaxAutoModelForCausalLM below.
# NOTE(review): confirm that this repository id is loadable by transformers
# and whether it requires accepting a licence / an access token.
MODEL_NAME = "google/gemma-2b-it-flax"

print(f"--- Loading Model: {MODEL_NAME} ---")
print("This can take a few minutes as the model is downloaded to the Colab instance...")

# bfloat16 is requested because it is the reduced-precision format TPUs are
# built around.
# NOTE(review): per the transformers Flax docs, `dtype` selects the computation
# dtype and may not cast the stored parameters -- verify if memory matters.
model_dtype = jnp.bfloat16

# Tokenizer first, then the Flax model; any failure in either step is reported
# to the user and the cell finishes without raising.
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = FlaxAutoModelForCausalLM.from_pretrained(MODEL_NAME, dtype=model_dtype)
except Exception as load_error:
    print(f"❌ ERROR: Failed to load the model. This could be due to a model name typo or a Hugging Face Hub issue.")
    print(load_error)
else:
    print(f"✅ Model '{MODEL_NAME}' and tokenizer loaded successfully.")
