# Tunix Inference & Evaluation

This notebook loads a trained Tunix SFT checkpoint (Gemma 2 2B + LoRA) and runs inference on evaluation prompts.
Use this to verify model performance without re-running training.


In [None]:
# --- Setup & Install ---
# Supporting packages, installed one at a time with pip's quiet flag (-q).
# None of these carry a version pin, so each run installs whatever pip
# resolves at that moment.
!pip install -q kagglehub
!pip install -q ipywidgets
!pip install -q tensorflow
!pip install -q tensorflow_datasets
!pip install -q tensorboardX
!pip install -q transformers
!pip install -q grain

import socket
import os

def is_connected(host="1.1.1.1", port=53, timeout=3.0):
    """Report whether an outbound TCP connection can be opened.

    Args:
        host: Address to probe. Defaults to "1.1.1.1".
        port: TCP port to probe. Defaults to 53.
        timeout: Seconds to wait for the connection before giving up, so the
            probe cannot stall the notebook on a network that silently drops
            packets.

    Returns:
        True if the connection was established, False if the attempt raised
        OSError (which includes timeouts).
    """
    try:
        # The context manager closes the probe socket as soon as the check is
        # done instead of leaving it open for the garbage collector.
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        return False

if is_connected():
    !pip install -q -U chex==0.1.90
    !pip install -q -U google-tunix[prod]==0.1.5 distrax==0.1.7 optax==0.2.6
    !pip install git+https://github.com/google/qwix
else:
    print("Offline mode detected. Assuming dependencies are installed or wheels provided.")
    # Fallback: Try installing from local wheels if available
    if os.path.exists("/kaggle/input/tunix-wheels"):
        !pip install --no-index --find-links=/kaggle/input/tunix-wheels google-tunix
        !pip install --no-index --find-links=/kaggle/input/tunix-wheels qwix

# Fix Flax Version
!pip uninstall -q -y flax
!pip install flax==0.12.0

!pip install -q datasets==3.2.0 optax==0.2.4 chex==0.1.88

# --- Imports ---
import functools
import gc
import os
import re
import time
import shutil
from pprint import pprint

from flax import nnx
import jax
import jax.numpy as jnp
import kagglehub
from orbax import checkpoint as ocp
import qwix
import numpy as np

# Tunix Imports
from tunix.generate import sampler as sampler_lib
from tunix.generate import tokenizer_adapter as tokenizer_lib
from tunix.models.gemma import model as gemma_lib
from tunix.models.gemma import params as params_lib

# --- Config ---
# Both settings are applied before the first device query below.
# NOTE(review): XLA_PYTHON_CLIENT_MEM_FRACTION is presumably the share of
# accelerator memory the XLA client may preallocate — confirm it has an effect
# on the backend in use.
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.95'
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")

print(f"JAX Devices: {jax.devices()}")

# Paths
# CHECKPOINT_DIR is scanned for numeric step sub-directories by the
# "Load Trained Adapters" cell.
CHECKPOINT_DIR = "/kaggle/working/sft_checkpoint"  # Where checkpoints were saved during training
# CHECKPOINT_DIR = "/kaggle/input/your-model-dataset/sft_checkpoint" # Uncomment if loading from uploaded dataset

# LoRA rank and alpha handed to qwix.LoraProvider in get_lora_model.
# NOTE(review): presumably these must equal the values used during training for
# the adapter checkpoint to restore cleanly — confirm.
RANK = 64
ALPHA = 64.0
# Used only to size the sampler's cache (cache_size = MAX_SEQ_LEN + 512).
MAX_SEQ_LEN = 2048

# Inference Params
# Passed to the sampler call as temperature / top_k / top_p; EVAL_MAX_TOKENS is
# passed as max_generation_steps.
INFERENCE_TEMPERATURE = 0.7
INFERENCE_TOP_K = 50
INFERENCE_TOP_P = 0.95
EVAL_MAX_TOKENS = 1024


In [None]:
# --- Model Loading Utilities ---
# Unpacked into jax.make_mesh(*MESH) by get_gemma_model: axis sizes (8, 1) for
# the axes named "fsdp" and "tp".
# NOTE(review): the 8 presumably matches the accelerator's device count —
# confirm before running on different hardware.
MESH = [(8, 1), ("fsdp", "tp")]

def get_gemma_model(ckpt_path):
    """Build a Gemma 2 2B transformer and fill it with weights from ``ckpt_path``.

    The model is first created abstractly (shapes only), every leaf of its
    state is turned into a bfloat16 ``ShapeDtypeStruct`` carrying the named
    sharding derived from the mesh, and that tree is used as the restore
    target for an Orbax ``StandardCheckpointer``.

    Args:
        ckpt_path: Checkpoint location handed to ``StandardCheckpointer.restore``.

    Returns:
        A ``(model, mesh, model_config)`` tuple: the merged transformer, the
        mesh built from ``MESH``, and the Gemma 2 2B model config.
    """
    device_mesh = jax.make_mesh(*MESH)
    config = gemma_lib.ModelConfig.gemma2_2b()

    def _build_transformer():
        return gemma_lib.Transformer(config, rngs=nnx.Rngs(params=0))

    def _bf16_spec(leaf, leaf_sharding):
        # Keep the leaf's shape, force bfloat16, attach the mesh sharding.
        return jax.ShapeDtypeStruct(leaf.shape, jnp.bfloat16, sharding=leaf_sharding)

    # Shape-only instantiation: no parameter values are created here.
    abstract_model: nnx.Module = nnx.eval_shape(_build_transformer)

    shape_state = nnx.state(abstract_model)
    named_shardings = nnx.get_named_sharding(shape_state, device_mesh)
    restore_target = jax.tree.map(_bf16_spec, shape_state, named_shardings)

    # Restore base model params
    restored_state = ocp.StandardCheckpointer().restore(ckpt_path, target=restore_target)

    # Re-attach the restored values to the abstract model's graph definition.
    model_graph, _ = nnx.split(abstract_model)
    restored_model = nnx.merge(model_graph, restored_state)
    return restored_model, device_mesh, config

def get_lora_model(base_model, mesh):
    """Wrap ``base_model`` with LoRA adapters and constrain its state to ``mesh``.

    Args:
        base_model: Model to adapt; must provide ``get_model_input()``, whose
            result is forwarded as keyword arguments to qwix.
        mesh: Context manager entered while the partition specs are applied.

    Returns:
        The LoRA-wrapped model produced by ``qwix.apply_lora_to_model``, with
        its state updated to the sharding-constrained version.
    """
    # Regex of module paths that receive adapters.
    target_modules = (
        ".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj|"
        ".*attn_vec_einsum"
    )
    provider = qwix.LoraProvider(module_path=target_modules, rank=RANK, alpha=ALPHA)

    example_inputs = base_model.get_model_input()
    adapted_model = qwix.apply_lora_to_model(
        base_model,
        provider,
        rngs=nnx.Rngs(params=0),
        **example_inputs,
    )

    with mesh:
        # Apply the state's own partition specs as a sharding constraint, then
        # write the constrained state back into the model.
        model_state = nnx.state(adapted_model)
        partition_specs = nnx.get_partition_spec(model_state)
        nnx.update(
            adapted_model,
            jax.lax.with_sharding_constraint(model_state, partition_specs),
        )

    return adapted_model

def restore_lora_checkpoint(lora_model, checkpoint_path):
    """Load LoRA adapter weights from an Orbax checkpoint into ``lora_model``.

    Args:
        lora_model: Model whose ``nnx.LoRAParam`` state is used as the restore
            target and then overwritten in place.
        checkpoint_path: Location passed to ``StandardCheckpointer.restore``.

    Returns:
        The same ``lora_model`` object, after the update.
    """
    print(f"Restoring LoRA weights from {checkpoint_path}...")

    # Only the LoRA parameter subtree is requested from the checkpoint.
    lora_target = nnx.state(lora_model, nnx.LoRAParam)
    lora_weights = ocp.StandardCheckpointer().restore(checkpoint_path, target=lora_target)

    # Write the restored adapter values back into the live model.
    nnx.update(lora_model, lora_weights)
    print("LoRA weights restored.")
    return lora_model


In [None]:
# --- 1. Load Base Model ---
# Interactive Kaggle login is only requested when no username is exported.
if os.environ.get("KAGGLE_USERNAME") is None:
    kagglehub.login()

# Download Base Gemma 2
model_path = {"gemma2": "google/gemma-2/flax/"}
model_version = "gemma2-2b-it"
kaggle_model_handle = f"{model_path['gemma2']}{model_version}"
kaggle_ckpt_path = kagglehub.model_download(kaggle_model_handle)

# Convert/Prepare Base Checkpoint
# The converted state is written once under INTERMEDIATE_CKPT_DIR/state and
# reused whenever that directory already exists.
INTERMEDIATE_CKPT_DIR = "/tmp/content/intermediate_ckpt/"
base_state_dir = os.path.join(INTERMEDIATE_CKPT_DIR, "state")
base_state_cached = os.path.exists(base_state_dir)
if not base_state_cached:
    print("Converting base checkpoint...")
    raw_params_dir = os.path.join(kaggle_ckpt_path, "gemma2-2b-it")
    params = params_lib.load_and_format_params(raw_params_dir)
    gemma = gemma_lib.Transformer.from_params(params, version="2-2b-it")
    _, state = nnx.split(gemma)
    checkpointer = ocp.StandardCheckpointer()
    checkpointer.save(base_state_dir, state)
    checkpointer.wait_until_finished()
    # Drop the converted copies before the model is loaded again below.
    del params, gemma, state
    gc.collect()

# Load Base Model
print("Loading Base Model...")
base_model, mesh, model_config = get_gemma_model(base_state_dir)
lora_model = get_lora_model(base_model, mesh=mesh)

# Setup Tokenizer
tokenizer_file = os.path.join(kaggle_ckpt_path, "tokenizer.model")
tokenizer = tokenizer_lib.Tokenizer(tokenizer_path=tokenizer_file)


In [None]:
# --- 2. Load Trained Adapters ---
# Find the newest step checkpoint under CHECKPOINT_DIR and restore the LoRA
# weights from it. Expected layout: CHECKPOINT_DIR/<step>/default/..., with a
# fallback to CHECKPOINT_DIR/<step>/ when there is no "default" item directory.
import glob  # NOTE(review): not used in this cell; kept in case a later cell relies on it.

try:
    # Step directories are the sub-directories whose names are plain integers.
    step_names = [
        name
        for name in os.listdir(CHECKPOINT_DIR)
        if name.isdigit() and os.path.isdir(os.path.join(CHECKPOINT_DIR, name))
    ]
    if not step_names:
        raise ValueError(f"No step checkpoints found in {CHECKPOINT_DIR}")

    # Compare numerically but keep the on-disk name, so a zero-padded step
    # directory (e.g. "00020500") still resolves to a path that exists.
    latest_name = max(step_names, key=int)
    latest_step = int(latest_name)
    checkpoint_path = os.path.join(CHECKPOINT_DIR, latest_name)

    print(f"Loading checkpoint from step: {latest_step}")
    print(f"Path: {checkpoint_path}")

    # A CheckpointManager-style layout keeps the item under <step>/default;
    # use that directory when it exists, otherwise restore from the step
    # directory itself.
    potential_path = os.path.join(checkpoint_path, "default")
    if os.path.exists(potential_path):
        checkpoint_path = potential_path

    restore_lora_checkpoint(lora_model, checkpoint_path)

except Exception as e:
    print(f"Failed to load checkpoint: {e}")
    print("Listing directory:")
    if os.path.exists(CHECKPOINT_DIR):
        pprint(os.listdir(CHECKPOINT_DIR))
    else:
        print("Checkpoint dir not found.")
    # Stop here: continuing would run the evaluation on a model whose adapters
    # were never restored and report it as the trained checkpoint.
    raise


In [None]:
# --- 3. Run Inference ---
print("Running Evaluation...")

# Fixed evaluation questions; each one is wrapped in the chat template below.
prompts = [
    "Write a short story about a robot learning to paint.",
    "Write a haiku about artificial intelligence.",
    "Propose three innovative uses for AI in education.",
    "Summarize the key benefits and risks of renewable energy in 3 paragraphs.",
    "Solve step-by-step: If 2x + 5 = 15, what is x?",
    "Write a Python function to check if a string is a palindrome.",
    "Explain why the sky is blue to a 5-year-old.",
    "Explain the process of photosynthesis step by step.",
    "What are the ethical implications of AI in healthcare?",
    "Should AI systems have rights? Argue both sides.",
]

# The system prompt is baked into the template; only {question} stays open.
SYSTEM_PROMPT = "You are a deep thinking AI. Think step by step about the problem and provide your reasoning between <reasoning> and </reasoning> tags. Then, provide the final answer between <answer> and </answer> tags."
TEMPLATE = f"<start_of_turn>user\n{SYSTEM_PROMPT}\n\n{{question}}<end_of_turn>\n<start_of_turn>model"
formatted_prompts = []
for question in prompts:
    formatted_prompts.append(TEMPLATE.format(question=question))

# Cache dimensions come from the model config; its size is MAX_SEQ_LEN + 512.
kv_cache_config = sampler_lib.CacheConfig(
    cache_size=MAX_SEQ_LEN + 512,
    num_layers=model_config.num_layers,
    num_kv_heads=model_config.num_kv_heads,
    head_dim=model_config.head_dim,
)
inference_sampler = sampler_lib.Sampler(
    transformer=lora_model,
    tokenizer=tokenizer,
    cache_config=kv_cache_config,
)

sampling_options = dict(
    max_generation_steps=EVAL_MAX_TOKENS,
    temperature=INFERENCE_TEMPERATURE,
    top_k=INFERENCE_TOP_K,
    top_p=INFERENCE_TOP_P,
    echo=False,
)
out_data = inference_sampler(input_strings=formatted_prompts, **sampling_options)

print("--- Results ---")
separator = "-" * 50
for question, completion in zip(prompts, out_data.text):
    print(f"Prompt: {question}")
    print(f"Output: {completion}\n")
    print(separator)
