# FunctionGemma Load Test

Tests loading the fine-tuned model from Google Drive in a fresh session.

**⚠️ CRITICAL Loading Parameters:**
- `torch_dtype=torch.bfloat16` (NOT float16!)
- `attn_implementation="eager"`

These must match the training parameters; otherwise the model outputs garbage.

**Requirements:**
- GPU runtime (T4 or better)
- Model ZIP on Google Drive (`functiongemma-flutter-demo-final.zip`)

## 1. Install Dependencies (SAME versions as finetuning!)

In [None]:
# CRITICAL: Use same transformers version as finetuning notebook!
# The pin below keeps this session on the exact release named here; -q only
# reduces pip's console output.
!pip install transformers==4.57.3 -q
# sentencepiece is installed unpinned — presumably required by the tokenizer;
# TODO confirm against the finetuning notebook.
!pip install sentencepiece -q

# Import only after the install and print the version so the pin can be
# verified by eye in the cell output.
# NOTE(review): if transformers was already imported earlier in this session,
# a runtime restart may be needed before the pinned version is picked up.
import transformers
print(f"transformers version: {transformers.__version__}")

## 2. Mount Drive & Load Model

In [None]:
from google.colab import drive
import os

drive.mount('/content/drive')

MODEL_DIR = "functiongemma-flutter-demo-final"
DRIVE_ZIP = f"/content/drive/MyDrive/{MODEL_DIR}.zip"

if not os.path.exists(MODEL_DIR):
    !unzip -q "{DRIVE_ZIP}"
    print(f"Extracted: {os.listdir(MODEL_DIR)}")
else:
    print(f"Already exists: {os.listdir(MODEL_DIR)}")

## 3. Load Model & Check Config

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Check GPU availability (queried once and reused below).
has_cuda = torch.cuda.is_available()
print("GPU Check:")
print(f"   CUDA available: {has_cuda}")
if has_cuda:
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    print(f"   GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

print(f"\nLoading model from {MODEL_DIR}...")

# CRITICAL: Must match training parameters!
# - bfloat16 (NOT float16!)
# - attn_implementation="eager"
if has_cuda:
    load_dtype, load_device = torch.bfloat16, "cuda:0"
else:
    # CPU fallback: float32 on the CPU device. Results are not guaranteed to
    # match the GPU/bfloat16 setup used for training.
    print("⚠️ WARNING: No GPU - model may not work correctly")
    load_dtype, load_device = torch.float32, "cpu"

# A single call for both paths so the attention implementation is always
# "eager"; previously the CPU branch silently fell back to the library
# default, contradicting the "must match training" requirement above.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_DIR,
    torch_dtype=load_dtype,            # CRITICAL on GPU: bfloat16, same as training!
    device_map=load_device,
    attn_implementation="eager",       # CRITICAL: same as training!
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)

# Report where/how the model was loaded and the special-token ids stored in
# its config, so a mismatch with the tokenizer can be spotted by eye.
print(f"\n✅ Model loaded on: {model.device}")
print(f"   dtype: {model.dtype}")
print(f"\nConfig:")
print(f"   model.config.pad_token_id = {model.config.pad_token_id}")
print(f"   model.config.eos_token_id = {model.config.eos_token_id}")
print(f"   model.config.bos_token_id = {model.config.bos_token_id}")

## 4. Test Generation

In [None]:
# Define tool: a single function schema the model is expected to call.
tools = [{
    "type": "function",
    "function": {
        "name": "change_background_color",
        "description": "Changes the background color",
        "parameters": {
            "type": "object",
            "properties": {
                "color": {"type": "string", "description": "Color name"}
            },
            "required": ["color"]
        }
    }
}]

# User utterances that should each trigger a function call.
test_prompts = [
    "make it red",
    "change background to blue",
    "I want purple",
]

print("Testing model:")
print(f"Device: {model.device}")
print("=" * 60)

all_passed = True

for prompt in test_prompts:
    messages = [{"role": "user", "content": prompt}]

    # Render the chat template (including the tool schema) to plain text.
    input_text = tokenizer.apply_chat_template(
        messages,
        tools=tools,
        add_generation_prompt=True,
        tokenize=False
    )

    # NOTE(review): the rendered template may already contain a BOS token and
    # tokenizer(...) may add another one — confirm this matches how the
    # training inputs were built before changing it.
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[1]

    # Greedy decoding. The attention mask is passed explicitly so generate()
    # does not have to infer it from pad_token_id.
    outputs = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=80,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id
    )

    # Decode only the newly generated tokens; special tokens are kept because
    # the validity check below looks for them.
    response = tokenizer.decode(
        outputs[0][prompt_len:],
        skip_special_tokens=False
    )

    # Check result: a function-call marker must be present, and the first
    # 50 characters must not contain padding.
    is_valid = "<start_function_call>" in response and "<pad>" not in response[:50]
    status = "✅" if is_valid else "❌"
    if not is_valid:
        all_passed = False

    print(f"\n[{status}] User: {prompt}")
    print(f"      Model: {response.strip()[:80]}...")
    print("-" * 60)

print("\n" + "=" * 60)
if all_passed:
    print("✅ RESULT: ALL TESTS PASSED!")
    print("Model works correctly after save/load.")
else:
    print("❌ RESULT: TESTS FAILED!")
    # Compare the device type rather than its string form.
    if model.device.type == "cpu":
        print("⚠️  Model is on CPU - this may be the cause!")
        print("   Try: Runtime → Change runtime type → T4 GPU")
    else:
        print("   Model is on GPU but still fails - investigate tokenizer.")
print("=" * 60)

## 5. Version Info (for debugging)

In [None]:
import torch
import transformers
import sys

# Dump interpreter / library versions and GPU status for bug reports.
gpu_present = torch.cuda.is_available()
env_rows = (
    ("Python", sys.version),
    ("PyTorch", torch.__version__),
    ("Transformers", transformers.__version__),
    ("CUDA available", gpu_present),
)

print("Environment:")
for label, value in env_rows:
    print(f"   {label}: {value}")
# The device name is only queried when a CUDA device is actually present.
if gpu_present:
    print(f"   GPU: {torch.cuda.get_device_name(0)}")