# FunctionGemma Fine-tune cho PvZ Bot

Train AI quyết định: `plant(plant_type, row, col)` hoặc `wait()`

**Game State bao gồm:**
- Plants đã trồng: vị trí (row, col) + loại
- Zombies: số lượng + vị trí (row, col)
- Có thể trồng hay không (CAN_PLANT/CANNOT_PLANT)

**Yêu cầu:** GPU Runtime + HuggingFace account

## Workflow:
1. Upload `training_data.json` (từ video_dataset_builder)
2. Chạy notebook
3. Download model

## 1. Cài đặt

In [None]:
# Install the training stack used by the cells below (-q keeps pip output quiet).
!pip install torch transformers datasets accelerate trl protobuf sentencepiece -q

## 2. Upload Training Data

Upload file `training_data.json` từ `data/processed/`

In [None]:
import json
from collections import Counter

from google.colab import files

print("Upload training_data.json...")
uploaded = files.upload()

# Fail with a clear message instead of an IndexError further down when
# nothing was uploaded (presumably what happens if the dialog is dismissed).
if not uploaded:
    raise RuntimeError(
        "No file was uploaded - re-run this cell and select training_data.json"
    )

# Load data: only the first uploaded file is used.
filename = next(iter(uploaded))
# Explicit encoding so the JSON is decoded the same way regardless of the
# runtime's default locale.
with open(filename, 'r', encoding='utf-8') as f:
    raw_data = json.load(f)

print(f"\n‚úì Loaded {len(raw_data)} samples from {filename}")

# Stats: number of samples per action, in first-seen order.
# Converted back to a plain dict so the printed summary keeps its format.
stats = dict(Counter(s['action'] for s in raw_data))
print(f"  Actions: {stats}")

## 3. Login HuggingFace

Cần token từ https://huggingface.co/settings/tokens

In [None]:
from huggingface_hub import login
# Authenticate this runtime with the Hugging Face Hub.
# Called without arguments - presumably prompts for the access token
# interactively (TODO confirm against the huggingface_hub docs).
login()

## 4. Load Model

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.utils import get_json_schema

# Hub identifier of the checkpoint that gets fine-tuned.
BASE_MODEL = "google/functiongemma-270m-it"

# Loading options: float32 weights, automatic device placement and the
# "eager" attention implementation.
load_options = dict(
    torch_dtype=torch.float32,
    device_map="auto",
    attn_implementation="eager",
)

model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, **load_options)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

print(f"‚úì Model loaded! Device: {model.device}")

## 5. Define Tools (Actions)

Chỉ 2 actions:
- `plant(plant_type, row, col)` - Trồng cây
- `wait()` - Chờ

In [None]:
# Tool stub for the "plant" action. In this notebook the function is only
# passed to get_json_schema() to build the tool schema, so the signature and
# the docstring (summary + "Args:" section) are effectively data that ends up
# in the prompt - edit them with care. The return value is a placeholder.
def plant(plant_type: str, row: int, col: int) -> str:
    """
    Plant a plant at grid position.
    
    Args:
        plant_type: Type of plant (pea_shooter, sunflower, wall_nut, cherry_bomb, snow_pea, repeater)
        row: Row index 0-4 (0=top, 4=bottom)
        col: Column index 0-8 (0=left, 8=right)
    """
    return "Planted"

# Tool stub for the "wait" action (no parameters). Like plant(), it is only
# passed to get_json_schema() in this notebook, so the docstring serves as the
# tool description; the return value is a placeholder.
def wait() -> str:
    """Wait and do nothing this turn. Use when seed is on cooldown or not enough sun."""
    return "Waiting"

# JSON schemas of the two available actions, in the order plant, wait.
# The same list is attached to every training conversation and passed to the
# chat template at inference time.
TOOLS = [get_json_schema(fn) for fn in (plant, wait)]

print("‚úì Tools defined:")
for schema in TOOLS:
    tool_name = schema['function']['name']
    print(f"  - {tool_name}")

## 6. Format Data cho Training

In [None]:
from datasets import Dataset
import random

# Instruction prompt sent under the "developer" role. The same text is used
# when building training conversations (create_conversation) and at inference
# (test_bot), so the two stay in sync.
SYSTEM_MSG = """You are a PvZ game bot. Analyze game state and choose ONE action.

Game state format:
- PLANTS: list of (type, row, col) for planted plants
- ZOMBIES: count and positions (row, col)
- CAN_PLANT or CANNOT_PLANT

Strategy:
- Plant in rows where zombies are approaching
- Prioritize defense over expansion
- Wait if seed is on cooldown or not enough sun"""

def create_conversation(sample):
    """Convert one raw sample into a chat-format training example.

    Args:
        sample: Mapping with the keys ``"action"`` (``"plant"`` selects the
            plant tool; any other value is treated as wait), ``"game_state"``
            (text used as the user message) and ``"arguments"`` (mapping with
            optional ``plant_type``, ``row`` and ``col``; may be missing or
            ``None``, e.g. for wait samples).

    Returns:
        A dict with ``"messages"`` (developer prompt, user game state and an
        assistant turn holding exactly one tool call) and ``"tools"`` (the
        shared tool schemas).
    """
    action = sample["action"]
    # Tolerate a missing or None "arguments" field instead of raising.
    args = sample.get("arguments") or {}

    if action == "plant":
        # Fallbacks used when an argument is absent.
        defaults = {"plant_type": "pea_shooter", "row": 2, "col": 0}
        # Compare against None rather than relying on dict.get's default or
        # truthiness: a key can be present with a None value (presumably what
        # Dataset.from_list produces for fields other rows define - verify),
        # and 0 is a legitimate row/col that must not be replaced.
        arguments = {
            key: fallback if args.get(key) is None else args[key]
            for key, fallback in defaults.items()
        }
        tool_call = {
            "type": "function",
            "function": {"name": "plant", "arguments": arguments},
        }
    else:  # wait
        tool_call = {
            "type": "function",
            "function": {"name": "wait", "arguments": {}},
        }

    return {
        "messages": [
            {"role": "developer", "content": SYSTEM_MSG},
            {"role": "user", "content": sample["game_state"]},
            {"role": "assistant", "tool_calls": [tool_call]},
        ],
        "tools": TOOLS,
    }

# Fixed seed so the shuffle and the train/test split are reproducible; without
# it every run evaluates on a different held-out set.
SEED = 42

# Shuffle a copy instead of raw_data itself, so re-running this cell starts
# from the same order and yields the same split.
shuffled_data = random.Random(SEED).sample(raw_data, k=len(raw_data))

dataset = Dataset.from_list(shuffled_data)
# Replace the raw columns with the chat-format columns from create_conversation.
dataset = dataset.map(
    create_conversation,
    remove_columns=dataset.column_names,
    batched=False,
)
# Hold out 20% of the samples for evaluation.
dataset = dataset.train_test_split(test_size=0.2, shuffle=True, seed=SEED)

print(f"‚úì Train: {len(dataset['train'])}, Test: {len(dataset['test'])}")

## 7. Training

In [None]:
from trl import SFTTrainer, SFTConfig

# Epoch heuristic: smaller datasets get more epochs, never fewer than 10.
# The floor division runs before the multiplication, so the result moves in
# steps of 10 and reaches the floor once there are more than 100 samples.
num_samples = len(raw_data)
scaled_epochs = (100 // num_samples) * 10
epochs = max(10, scaled_epochs)

args = SFTConfig(
    # Output and reporting
    output_dir="pvz_functiongemma",
    logging_steps=10,
    eval_strategy="epoch",
    report_to="none",
    # Sequence handling
    max_length=512,
    packing=False,
    # Optimisation schedule
    num_train_epochs=epochs,
    per_device_train_batch_size=4,
    learning_rate=5e-5,
    lr_scheduler_type="constant",
    optim="adamw_torch",
    # Both mixed-precision flags and gradient checkpointing are switched off
    fp16=False,
    bf16=False,
    gradient_checkpointing=False,
)

trainer = SFTTrainer(
    model=model,
    args=args,
    processing_class=tokenizer,
    train_dataset=dataset['train'],
    eval_dataset=dataset['test'],
)

print(f"Training {num_samples} samples for {epochs} epochs...")
trainer.train()
print("\n‚úì Training complete!")

## 8. Test Model

In [None]:
def test_bot(game_state):
    """Run the model on one game-state string and return its raw reply.

    The prompt mirrors the training conversations: the developer message
    carries SYSTEM_MSG, the user message carries ``game_state``, and the tool
    schemas are supplied to the chat template. Only the newly generated tokens
    are decoded, and special tokens are kept in the returned text.
    """
    prompt = [
        {"role": "developer", "content": SYSTEM_MSG},
        {"role": "user", "content": game_state},
    ]

    encoded = tokenizer.apply_chat_template(
        prompt,
        tools=TOOLS,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    )
    # Length of the prompt, used below to cut it off the generated sequence.
    prompt_len = len(encoded["input_ids"][0])

    generated = model.generate(
        **encoded.to(model.device),
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=64,
    )
    reply_tokens = generated[0][prompt_len:]
    return tokenizer.decode(reply_tokens, skip_special_tokens=False)

banner = "=" * 50
print(banner)
print("TEST PVZ BOT")
print(banner)

# Hand-written game states covering the situations the bot should handle.
test_cases = [
    # No plants, no zombies, planting is allowed
    "PLANTS:[]. ZOMBIES:[]. CAN_PLANT",

    # One plant, one zombie approaching in row 2
    "PLANTS:[(pea_shooter,2,0)]. ZOMBIES:[(zombie,2,7)]. CAN_PLANT",

    # Several plants and several zombies (each zombie listed separately with its type)
    "PLANTS:[(pea_shooter,2,0),(pea_shooter,2,1)]. ZOMBIES:[(zombie,1,6),(zombie,2,5),(cone_zombie,3,7)]. CAN_PLANT",

    # Planting is not possible (cooldown)
    "PLANTS:[(pea_shooter,2,0)]. ZOMBIES:[(zombie,2,6),(zombie,2,7)]. CANNOT_PLANT",

    # Zombie in a row that has no plant yet
    "PLANTS:[(pea_shooter,2,0)]. ZOMBIES:[(bucket_zombie,0,8)]. CAN_PLANT",
]

for state in test_cases:
    print(f"\nüì• Input: {state}")
    print(f"üì§ Output: {test_bot(state)}")

## 9. Save & Download

In [None]:
# Save model
# Weights and tokenizer are written to the same directory so the folder can
# be loaded as a single checkpoint.
model.save_pretrained("pvz_functiongemma_final")
tokenizer.save_pretrained("pvz_functiongemma_final")

# Zip for download
# Shell magic: recursively archives the directory saved above.
!zip -r pvz_functiongemma_final.zip pvz_functiongemma_final/

print("\n‚úì Model saved! Download pvz_functiongemma_final.zip")

In [None]:
from google.colab import files
# Hand the archive created by the previous cell to the browser for download.
files.download('pvz_functiongemma_final.zip')