# FunctionGemma Fine-tune với Unsloth

Train AI quyết định: `plant(plant_type, row, col)` hoặc `wait()`

**Ưu điểm Unsloth:**
- Nhanh hơn 2-5x so với HuggingFace
- Ít VRAM hơn (chạy được trên T4 free)
- Hỗ trợ FunctionGemma 270M native

**Output:** PyTorch model (convert sang OpenVINO trên máy local)

## 1. Cài đặt Unsloth

In [None]:
%%capture
# First install pulls unsloth and all of its dependencies from PyPI; the second
# line then swaps only the unsloth package for the latest GitHub build
# (--no-deps leaves the already-installed dependencies untouched).
# %%capture hides pip's output - remove it temporarily to debug install errors.
!pip install unsloth
!pip uninstall unsloth -y && pip install --upgrade --no-cache-dir --no-deps git+https://github.com/unslothai/unsloth.git

## 2. Upload Training Data

In [None]:
from google.colab import files
from collections import Counter
import json

# Ask for the JSON training set via the Colab file picker.
print("Upload training_data.json...")
uploaded = files.upload()
if not uploaded:
    # Fail with a clear message instead of an IndexError on an empty upload.
    raise RuntimeError("No file uploaded - please select training_data.json")

# Keys of `uploaded` are the uploaded file names; only the first one is used.
filename = next(iter(uploaded))
# Explicit UTF-8 so non-ASCII text decodes the same regardless of the runtime locale.
with open(filename, 'r', encoding='utf-8') as f:
    raw_data = json.load(f)

if not raw_data:
    # Later cells divide by the sample count and split the data; stop early.
    raise ValueError(f"{filename} contains no samples")

print(f"\n‚úì Loaded {len(raw_data)} samples")
# Number of samples per action value (insertion-ordered, like a plain dict).
stats = dict(Counter(s['action'] for s in raw_data))
print(f"  Actions: {stats}")

## 3. Load FunctionGemma với Unsloth

In [None]:
from unsloth import FastLanguageModel
import torch

max_seq_length = 2048

# FunctionGemma 270M is small, so it is loaded unquantized in 16-bit and every
# weight is trained (full_finetuning=True) - no LoRA adapters are attached.
load_options = dict(
    model_name="unsloth/functiongemma-270m-it",
    max_seq_length=max_seq_length,
    load_in_4bit=False,    # no 4-bit quantization
    load_in_16bit=True,    # 16-bit weights
    full_finetuning=True,  # update all parameters rather than adapters
)
model, tokenizer = FastLanguageModel.from_pretrained(**load_options)

print(f"‚úì Model loaded: {model.config._name_or_path}")

## 4. Define Tools

In [None]:
# Tool definition for the bot's "plant" action. It is not called anywhere in
# this notebook; it is passed in TOOLS to apply_chat_template, which builds the
# tool schema from this signature and docstring.
# NOTE: the docstring below therefore becomes part of the prompt text at both
# training and inference time - editing it (or adding a return annotation)
# changes the model input, so keep it in sync with any already-trained model.
# The returned dict mirrors the tool response used when formatting training data.
def plant(plant_type: str, row: int, col: int):
    """
    Plant a plant at grid position.

    Args:
        plant_type: Type of plant (pea_shooter, sunflower, wall_nut, cherry_bomb, etc)
        row: Row index 0-4 (0=top, 4=bottom)
        col: Column index 0-8 (0=left, 8=right)

    Returns:
        result: Action result
    """
    return {"result": "planted"}

# Tool definition for the bot's no-op action. Like plant(), it is only passed
# in TOOLS to apply_chat_template so the schema (built from this docstring) is
# rendered into the prompt; it is not called anywhere in this notebook.
# NOTE: editing the docstring changes the prompt seen by the model.
# The returned dict mirrors the tool response used when formatting training data.
def wait():
    """
    Wait and do nothing. Use when seed is on cooldown or no good action available.

    Returns:
        result: Action result
    """
    return {"result": "waiting"}

# Tool functions exposed to the chat template (order preserved: plant, then wait).
TOOLS = [plant, wait]
tool_names = [tool.__name__ for tool in TOOLS]
print("‚úì Tools defined:", tool_names)

## 5. Prepare Dataset (với Thinking Block)

In [None]:
from datasets import Dataset
import random

# Developer prompt shared by training-data formatting and inference; it explains
# the three sections of the game-state string the model receives.
SYSTEM_MSG = "\n".join([
    "PvZ bot. Analyze game state and choose action.",
    "- PLANTS: planted plants (type,row,col)",
    "- ZOMBIES: zombies (type,row,col)",
    "- SEEDS: seed packets (type,status)",
])

def format_for_training(sample):
    """Render one raw sample as a FunctionGemma chat transcript for SFT.

    The transcript is: developer prompt -> user game state -> assistant tool
    call -> tool response, rendered with the tokenizer's chat template and the
    TOOLS schemas. If the sample has a non-empty ``thinking`` string, it is
    injected as a ``<think>...</think>`` block at the start of the model turn,
    i.e. before the function call (following the Unsloth FunctionGemma docs).

    Args:
        sample: Dict with ``action`` and ``game_state``; optionally
            ``arguments`` (the plant() arguments) and ``thinking``.

    Returns:
        ``{"text": <rendered transcript string>}``.
    """
    action = sample["action"]
    # `or` also covers keys that are present but null, so a None never ends up
    # serialized as the tool-call arguments.
    args = sample.get("arguments") or {}
    thinking = sample.get("thinking") or ""

    if action == "plant":
        name, call_args, result = "plant", args, "planted"
    else:
        # Any action other than "plant" is rendered as wait(), which takes no arguments.
        name, call_args, result = "wait", {}, "waiting"

    tool_call = {"type": "function", "function": {"name": name, "arguments": call_args}}
    func_response = {"name": name, "response": {"result": result}}

    messages = [
        {"role": "developer", "content": SYSTEM_MSG},
        {"role": "user", "content": sample["game_state"]},
        {"role": "assistant", "tool_calls": [tool_call]},
        {"role": "tool", "content": [func_response]},  # function response turn
    ]

    text = tokenizer.apply_chat_template(
        messages,
        tools=TOOLS,
        tokenize=False,
        add_generation_prompt=False,
    )

    if thinking:
        model_turn = "<start_of_turn>model\n"
        if model_turn in text:
            think_block = f"<think>\n{thinking}\n</think>\n"
            # Only the first model turn (the assistant's tool call) gets the block.
            text = text.replace(model_turn, model_turn + think_block, 1)

    return {"text": text}

# Count samples that carry a non-empty reasoning trace.
num_with_thinking = sum(1 for s in raw_data if s.get("thinking"))
has_thinking = num_with_thinking > 0
print(f"Dataset has thinking: {has_thinking} ({num_with_thinking}/{len(raw_data)} samples)")

# Format in plain Python *before* building the Dataset. Dataset.from_list takes
# its columns from the first record only, so with heterogeneous samples (e.g. a
# leading `wait` sample without `thinking`/`arguments`) those fields would be
# silently dropped for every row. Pre-formatting yields uniform {"text": ...} rows.
dataset = Dataset.from_list([format_for_training(s) for s in raw_data])

# Seeded split (same seed as the trainer) for reproducible runs; the split
# already shuffles, so no separate random.shuffle pass is needed.
dataset = dataset.train_test_split(test_size=0.1, shuffle=True, seed=3407)
print(f"‚úì Train: {len(dataset['train'])}, Test: {len(dataset['test'])}")
print(f"\nSample:\n{dataset['train'][0]['text'][:800]}...")

## 6. Training với Unsloth

In [None]:
from trl import SFTTrainer
from transformers import TrainingArguments
from unsloth import is_bfloat16_supported

num_samples = len(raw_data)
if num_samples == 0:
    raise ValueError("raw_data is empty - upload a non-empty training_data.json")
# Scale epochs so the model sees roughly 500 samples in total, clamped to 3..20 epochs.
epochs = max(3, min(20, 500 // num_samples))

batch_size = 4
grad_accum_steps = 2
use_bf16 = is_bfloat16_supported()

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    dataset_num_proc=2,
    packing=False,
    args=TrainingArguments(
        output_dir="pvz_gemma",
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        gradient_accumulation_steps=grad_accum_steps,
        warmup_steps=5,
        num_train_epochs=epochs,
        # NOTE(review): 2e-4 is a typical LoRA learning rate; this run updates
        # all weights, where lower rates are common - tune if it overfits.
        learning_rate=2e-4,
        fp16=not use_bf16,
        bf16=use_bf16,
        logging_steps=10,
        eval_strategy="epoch",
        save_strategy="epoch",
        # Full fine-tune checkpoints hold all weights plus optimizer state; keep
        # only the latest two so up to 20 per-epoch saves don't fill the disk.
        save_total_limit=2,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="linear",
        seed=3407,
        report_to="none",
    ),
)

print(f"Training {num_samples} samples for {epochs} epochs...")
print(f"Batch size: {batch_size} x {grad_accum_steps} = {batch_size * grad_accum_steps} effective")

trainer_stats = trainer.train()
print(f"\n‚úì Training complete in {trainer_stats.metrics['train_runtime']:.1f}s")

## 7. Test Model

In [None]:
import re

# Whole call: <start_function_call>call:NAME{ARGS}<end_function_call>. The lazy
# body stops at the first "}" that is directly followed by the end marker.
_FUNCTION_CALL_RE = re.compile(
    r"<start_function_call>call:(\w+)\{(.*?)\}<end_function_call>", re.DOTALL
)
# One key:value pair. A value is either an <escape>-delimited string (which may
# itself contain ',' or '}') or a bare token running up to the next ',' or '}'.
_ARGUMENT_RE = re.compile(r"(\w+):(?:<escape>(.*?)<escape>|([^,}]+))", re.DOTALL)


def extract_tool_call(text):
    """Parse the first FunctionGemma function call found in generated text.

    Any text before the call (e.g. a ``<think>`` block) is ignored.

    Args:
        text: Decoded model output, with special tokens kept.

    Returns:
        ``{"name": str, "arguments": dict}`` or ``None`` if no complete call is
        present. ``<escape>``-delimited values are returned as strings verbatim;
        bare values are converted to ``int`` when possible, else kept as
        stripped strings.
    """
    match = _FUNCTION_CALL_RE.search(text)
    if not match:
        return None
    name, args_str = match.group(1), match.group(2)

    args = {}
    for arg in _ARGUMENT_RE.finditer(args_str):
        key, escaped, bare = arg.groups()
        if escaped is not None:
            args[key] = escaped
            continue
        # Drop any stray (unbalanced) escape marker before converting.
        value = bare.replace("<escape>", "").strip()
        try:
            args[key] = int(value)
        except ValueError:
            args[key] = value
    return {"name": name, "arguments": args}

def test_bot(game_state, max_new_tokens=256):
    """Run the fine-tuned model on one game state and parse its tool call.

    Args:
        game_state: State string in the same format as the training
            ``game_state`` field, e.g. "PLANTS:[...]. ZOMBIES:[...]. SEEDS:[...]".
        max_new_tokens: Generation budget. A model trained with ``<think>``
            blocks emits its reasoning before the function call, so the budget
            must cover both; a small one (such as 64) can cut the call off and
            make this return ``None``.

    Returns:
        ``{"name": ..., "arguments": {...}}`` from ``extract_tool_call``, or
        ``None`` if no complete function call was generated.
    """
    messages = [
        {"role": "developer", "content": SYSTEM_MSG},
        {"role": "user", "content": game_state},
    ]
    inputs = tokenizer.apply_chat_template(
        messages, tools=TOOLS, add_generation_prompt=True,
        return_dict=True, return_tensors="pt"
    )
    out = model.generate(
        **inputs.to(model.device),
        max_new_tokens=max_new_tokens,
        # NOTE(review): these sampling settings only apply if sampling is enabled
        # (do_sample in the generation config) - confirm; results may vary per call.
        top_k=64, top_p=0.95, temperature=1.0,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decode only the newly generated tokens; special tokens are kept because the
    # parser looks for the <start_function_call>/<end_function_call> markers.
    prompt_len = inputs["input_ids"].shape[-1]
    output = tokenizer.decode(out[0][prompt_len:], skip_special_tokens=False)
    return extract_tool_call(output)

# Smoke-test the trained bot on a few hand-written game states and show the
# parsed tool call for each.
banner = "=" * 50
print(banner)
print("TEST PVZ BOT")
print(banner)

test_cases = [
    "PLANTS:[]. ZOMBIES:[]. SEEDS:[(pea_shooter,ready)]",
    "PLANTS:[(pea_shooter,2,0)]. ZOMBIES:[(zombie,2,7)]. SEEDS:[(pea_shooter,cooldown)]",
    "PLANTS:[(pea_shooter,2,0)]. ZOMBIES:[(zombie,1,6)]. SEEDS:[(pea_shooter,ready)]",
    "PLANTS:[]. ZOMBIES:[(zombie,0,8),(zombie,4,7)]. SEEDS:[(pea_shooter,ready),(sunflower,ready)]",
]

for state in test_cases:
    prediction = test_bot(state)
    print(f"\nüì• {state}")
    print(f"üì§ {prediction}")

## 8. Save & Download PyTorch Model (FP16)

In [None]:
# Save model ·ªü FP16 (gi·ªØ nguy√™n ch·∫•t l∆∞·ª£ng, size ~536MB thay v√¨ ~1GB)
print("Saving model (FP16)...")

# Full finetune kh√¥ng d√πng LoRA n√™n d√πng save_pretrained
model.save_pretrained("pvz_gemma_merged")
tokenizer.save_pretrained("pvz_gemma_merged")

# Convert sang FP16 ƒë·ªÉ gi·∫£m size
import torch
from pathlib import Path

model_path = Path("pvz_gemma_merged")

# Convert safetensors files
for sf_file in model_path.glob("*.safetensors"):
    print(f"Converting {sf_file.name} to FP16...")
    from safetensors.torch import load_file, save_file
    tensors = load_file(sf_file)
    tensors = {k: v.half() if v.dtype == torch.float32 else v for k, v in tensors.items()}
    save_file(tensors, sf_file)

print("‚úì Model saved to pvz_gemma_merged/")

# Zip v√† download
!zip -r pvz_gemma_merged.zip pvz_gemma_merged/
print("\n‚úì Ready for download: pvz_gemma_merged.zip")
!ls -lh pvz_gemma_merged.zip

In [None]:
from google.colab import files

# Trigger a browser download of the archive produced by the save cell.
archive_name = 'pvz_gemma_merged.zip'
files.download(archive_name)

## (Optional) Push to HuggingFace Hub

In [None]:
# Uncomment to push the model and tokenizer to the Hugging Face Hub
# from huggingface_hub import login
# login()
# model.push_to_hub("your-username/pvz-gemma-bot")
# tokenizer.push_to_hub("your-username/pvz-gemma-bot")