### Install dependencies

In [1]:
#pip install torch unsloth dataset transformers pillow scikit-learn

### Import libs

In [2]:
import torch, json, random, re
from unsloth import FastLanguageModel
from datasets import load_dataset
from transformers import AutoProcessor
from tqdm import tqdm
from config import *

ðŸ¦¥ Unsloth: Will patch your computer to enable 2x faster free finetuning.
Unsloth: Your Flash Attention 2 installation seems to be broken?
A possible explanation is you have a new CUDA version which isn't
yet compatible with FA2? Please file a ticket to Unsloth or FA2.
We shall now use Xformers instead, which does not have any performance hits!
We found this negligible impact by benchmarking on 1x A100.
ðŸ¦¥ Unsloth Zoo will now patch everything to make training faster!


### Load the Model

In [3]:
def load_vlm_model():
    global model, tokenizer, processor
    print(f"Loading Model: {MODEL_NAME}")

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name = MODEL_NAME,
        max_seq_length = 4096,
        dtype = None,
        load_in_4bit = True,
    )

    processor = AutoProcessor.from_pretrained(MODEL_NAME)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    print("Model and Tokenizer loaded successfully")
    return model, tokenizer, processor

In [4]:
def extract_json_content(text):
    try:
        match = re.search(r"\{.*\}", text, re.DOTALL)
        if match:
            return json.loads(match.group(0))
            
        return None
    except:
        return None

def generate_response(model, processor, tokenizer, image, system, user):
    messages = [
        {"role": "system", "content": [{"type": "text", "text": system}]},
        {"role": "user", "content": [{"type": "text", "text": user}, {"type": "image"}]}
    ]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[text], images=[image], padding=True, return_tensors="pt").to(model.device)
    
    with torch.no_grad():
        outputs = model.generate(
            **inputs, 
            max_new_tokens=512,
            do_sample=False, 
            temperature=0.1,
            repetition_penalty=1.1
        )
    
    gen_ids = [out[len(inp):] for inp, out in zip(inputs.input_ids, outputs)]
    return tokenizer.decode(gen_ids[0], skip_special_tokens=True)

In [5]:
def run_eval(model, tokenizer, processor, test_data_path):
    try:
        dataset = load_dataset("parquet", data_files={"test": test_data_path}, split="test")
    except Exception as e: print(f"Error: {e}"); return

    num_samples = len(dataset)
    print(f"\n--- STARTING EVAL ON {num_samples} SAMPLES ---\n")

    metrics = {
        "json_success": 0, "total_tasks": 0,
        "move_active_grid": 0, "move_legal": 0, "move_exact": 0, "total_move": 0,
        "state_similarity": 0, "state_exact": 0, "total_state": 0, 
        "legality_acc": 0, "total_legality": 0,
        "active_grid_task_acc": 0, "total_active_task": 0
    }

    sys_prompt = ("You are an Ultimate Tic-Tac-Toe Visual Engine. **STRICT PROTOCOL:**\n"
                  "1. FORMAT: All responses must be valid JSON snippets wrapped in JSON_START/JSON_END anchors.\n"
                  "2. COORDINATES: Use 0-indexed integers (0, 1, 2) for all Row/Col values.\n"
                  "3. VISUAL ANCHOR: The allowed square (active subgrid) is highlighted in GREEN. You must play there.")

    model.eval()

    for i in tqdm(range(num_samples)):
        sample = dataset[i]
        image = sample["image"].convert("RGB")
        p_char = "X" if sample["player"] == 1 else "O"
        
        gt_allowed = sample.get("allowed_squares")
        legal_list = sample["legal_moves"]
        gt_best = sample["best_move"]
        full_state = sample["global_state"]
        
        gr, gc = random.randint(0, 2), random.randint(0, 2)

        is_legal_truth = any(m['global_row'] == gr and m['global_col'] == gc and m['local_row'] == 1 and m['local_col'] == 1 for m in legal_list)

        tasks = [
            ("ALLOWED_SQUARE", "Identify the allowed square (active global subgrid) based on the green highlight. If none is highlighted, report null."),
            ("MOVE", f"Visually analyze the board. It is Player {p_char}'s turn. Identify the allowed square (green highlight) and then select the best move."),
            ("STATE", f"Examine Global Subgrid ({gr}, {gc}). Represent the 3x3 local grid state as a matrix of 'X', 'O', or '.' (Empty)."),
            ("LEGALITY", f"Is it legal for Player {p_char} to play at Global({gr},{gc}), Local(1,1)? Step 1: Inspect the square state. Step 2: Check allowed square constraint. Step 3: Verdict."),
        ]
        
        def check_allowed(pred):
            if gt_allowed is None or len(gt_allowed) < 2:
                return pred is None or pred == "null"
            if not pred: return False 
            if isinstance(pred, dict): return pred.get("global_row") == gt_allowed[0] and pred.get("global_col") == gt_allowed[1]
            return False

        gt_matrix = []
        for lr in range(3):
            row = []
            for lc in range(3):
                occ = next((p['player'] for p in full_state if p['global_row'] == gr and p['global_col'] == gc and p['local_row'] == lr and p['local_col'] == lc), 0)
                row.append("X" if occ == 1 else "O" if occ == 2 else ".")
            gt_matrix.append(row)

        for task_name, prompt in tasks:
            metrics["total_tasks"] += 1
            
            try:
                resp = generate_response(model, processor, tokenizer, image, sys_prompt, prompt)
                parsed = extract_json_content(resp)
            except:
                parsed = None

            if parsed: metrics["json_success"] += 1
            else: continue

            if task_name == "ALLOWED_SQUARE":
                metrics["total_active_task"] += 1
                pred = parsed.get("allowed_square")
                if check_allowed(pred): metrics["active_grid_task_acc"] += 1

            elif task_name == "MOVE":
                metrics["total_move"] += 1
                pred_allowed = parsed.get("allowed_square")
                if check_allowed(pred_allowed): metrics["move_active_grid"] += 1
                
                pm = parsed.get("move") or parsed.get("best_move")
                if isinstance(pm, dict):
                    pr, pc, plr, plc = pm.get("global_row"), pm.get("global_col"), pm.get("local_row"), pm.get("local_col")
                    if pr is not None:
                        is_legal = any(m['global_row']==pr and m['global_col']==pc and m['local_row']==plr and m['local_col']==plc for m in legal_list)
                        if is_legal: metrics["move_legal"] += 1
                        
                        if pr==gt_best['global_row'] and pc==gt_best['global_col'] and plr==gt_best['local_row'] and plc==gt_best['local_col']:
                            metrics["move_exact"] += 1

            elif task_name == "STATE":
                metrics["total_state"] += 1
                pm = parsed.get("grid_matrix")
                
                if pm == gt_matrix: metrics["state_exact"] += 1
                if pm and isinstance(pm, list) and len(pm) == 3:
                    matches = 0
                    for r in range(3):
                        for c in range(3):
                            if len(pm[r]) > c and pm[r][c] == gt_matrix[r][c]: matches += 1
                    score = matches / 9.0
                    metrics["state_similarity"] += score

            elif task_name == "LEGALITY":
                metrics["total_legality"] += 1
                pl = parsed.get("is_legal") 
                if pl is not None:
                    if pl == is_legal_truth: metrics["legality_acc"] += 1

    print("\n" + "="*40)
    print(f"JSON Syntax:     {metrics['json_success']/max(1, metrics['total_tasks'])*100:.1f}%")
    print("-" * 40)
    print(f"[ALLOWED] Acc:   {metrics['active_grid_task_acc']/max(1, metrics['total_active_task'])*100:.1f}%")
    print(f"[MOVE] Legal:    {metrics['move_legal']/max(1, metrics['total_move'])*100:.1f}%")
    print(f"[MOVE] Exact:    {metrics['move_exact']/max(1, metrics['total_move'])*100:.1f}%")
    print("-" * 40)
    print(f"[STATE] Exact:   {metrics['state_exact']/max(1, metrics['total_state'])*100:.1f}%")
    print(f"[STATE] Sim %:   {(metrics['state_similarity']/max(1, metrics['total_state']))*100:.1f}%")
    print("-" * 40)
    print(f"[LEGAL] Logic:   {metrics['legality_acc']/max(1, metrics['total_legality'])*100:.1f}%")
    print("="*40 + "\n")

In [6]:
def run_debug_eval(model, tokenizer, processor, test_data_path, num_samples=5):
    print(f"\n--- STARTING EVAL ON {num_samples} SAMPLES ---\n")
    try:
        dataset = load_dataset("parquet", data_files={"test": test_data_path}, split="test")
        dataset = dataset.shuffle(seed=42).select(range(min(num_samples, len(dataset))))
    except Exception as e: print(f"Error: {e}"); return

    metrics = {
        "json_success": 0, "total_tasks": 0,
        "move_active_grid": 0, "move_legal": 0, "move_exact": 0, "total_move": 0,
        "state_similarity": 0, "state_exact": 0, "total_state": 0, 
        "legality_acc": 0, "total_legality": 0,
        "active_grid_task_acc": 0, "total_active_task": 0
    }

    sys_prompt = ("You are an Ultimate Tic-Tac-Toe Visual Engine. **STRICT PROTOCOL:**\n"
                  "1. FORMAT: All responses must be valid JSON snippets wrapped in JSON_START/JSON_END anchors.\n"
                  "2. COORDINATES: Use 0-indexed integers (0, 1, 2) for all Row/Col values.\n"
                  "3. VISUAL ANCHOR: The allowed square (active subgrid) is highlighted in GREEN. You must play there.")

    model.eval()

    for i in range(len(dataset)):
        sample = dataset[i]
        image = sample["image"].convert("RGB")
        p_char = "X" if sample["player"] == 1 else "O"
        
        gt_allowed = sample.get("allowed_squares")
        legal_list = sample["legal_moves"]
        gt_best = sample["best_move"]
        full_state = sample["global_state"]
        ascii_board = sample.get("ascii_board", "No ASCII available")

        gr, gc = random.randint(0, 2), random.randint(0, 2)

        print(f"\n{'='*20} SAMPLE {i} {'='*20}")
        print(f"Player: {p_char} | GT Allowed: {gt_allowed}")
        print(ascii_board)
        print("-" * 40)

        is_legal_truth = any(m['global_row'] == gr and m['global_col'] == gc and m['local_row'] == 1 and m['local_col'] == 1 for m in legal_list)

        tasks = [
            ("ALLOWED_SQUARE", "Identify the allowed square (active global subgrid) based on the green highlight. If none is highlighted, report null."),
            ("MOVE", f"Visually analyze the board. It is Player {p_char}'s turn. Identify the allowed square (green highlight) and then select the best move."),
            ("STATE", f"Examine Global Subgrid ({gr}, {gc}). Represent the 3x3 local grid state as a matrix of 'X', 'O', or '.' (Empty)."),
            ("LEGALITY", f"Is it legal for Player {p_char} to play at Global({gr},{gc}), Local(1,1)? Step 1: Inspect the square state. Step 2: Check allowed square constraint. Step 3: Verdict."),
        ]
        
        def check_allowed(pred):
            if gt_allowed is None or len(gt_allowed) < 2:
                return pred is None or pred == "null"
            if not pred: return False 
            if isinstance(pred, dict): return pred.get("global_row") == gt_allowed[0] and pred.get("global_col") == gt_allowed[1]
            return False

        gt_matrix = []
        for lr in range(3):
            row = []
            for lc in range(3):
                occ = next((p['player'] for p in full_state if p['global_row'] == gr and p['global_col'] == gc and p['local_row'] == lr and p['local_col'] == lc), 0)
                row.append("X" if occ == 1 else "O" if occ == 2 else ".")
            gt_matrix.append(row)

        for task_name, prompt in tasks:
            metrics["total_tasks"] += 1
            print(f"[{task_name}] Prompt: {prompt}")
            
            resp = generate_response(model, processor, tokenizer, image, sys_prompt, prompt)
            print(f"[{task_name}] Raw Response:\n{resp}\n")
            
            parsed = extract_json_content(resp)
            if parsed: metrics["json_success"] += 1
            else: 
                print(f"!!! JSON PARSE FAILED for {task_name} !!!")
                continue

            if task_name == "ALLOWED_SQUARE":
                metrics["total_active_task"] += 1
                pred = parsed.get("allowed_square")
                if check_allowed(pred): metrics["active_grid_task_acc"] += 1
                else: print(f"-> ALLOWED SQ FAIL. GT: {gt_allowed}, Pred: {pred}")

            elif task_name == "MOVE":
                metrics["total_move"] += 1
                pred_allowed = parsed.get("allowed_square")
                if check_allowed(pred_allowed): metrics["move_active_grid"] += 1
                
                pm = parsed.get("move") or parsed.get("best_move")
                if isinstance(pm, dict):
                    pr, pc, plr, plc = pm.get("global_row"), pm.get("global_col"), pm.get("local_row"), pm.get("local_col")
                    if pr is not None:
                        is_legal = any(m['global_row']==pr and m['global_col']==pc and m['local_row']==plr and m['local_col']==plc for m in legal_list)
                        if is_legal: metrics["move_legal"] += 1
                        else: print(f"-> ILLEGAL MOVE: {pm}")
                        
                        if pr==gt_best['global_row'] and pc==gt_best['global_col'] and plr==gt_best['local_row'] and plc==gt_best['local_col']:
                            metrics["move_exact"] += 1

            elif task_name == "STATE":
                metrics["total_state"] += 1
                pm = parsed.get("grid_matrix")
                
                if pm == gt_matrix: metrics["state_exact"] += 1
                if pm and isinstance(pm, list) and len(pm) == 3:
                    matches = 0
                    for r in range(3):
                        for c in range(3):
                            if len(pm[r]) > c and pm[r][c] == gt_matrix[r][c]: matches += 1
                    score = matches / 9.0
                    metrics["state_similarity"] += score
                    if score < 1.0: print(f"-> STATE PARTIAL ({int(score*100)}%). GT: {gt_matrix} vs Pred: {pm}")

            elif task_name == "LEGALITY":
                metrics["total_legality"] += 1
                pl = parsed.get("is_legal") 
                if pl is not None:
                    if pl == is_legal_truth: metrics["legality_acc"] += 1
                    else: print(f"-> LEGALITY FAIL. GT: {is_legal_truth}, Pred: {pl}")

    print("\n" + "="*40)
    print(f"JSON Syntax:     {metrics['json_success']/metrics['total_tasks']*100:.1f}%")
    print("-" * 40)
    print(f"[ALLOWED] Acc:   {metrics['active_grid_task_acc']/metrics['total_active_task']*100:.1f}%")
    print(f"[MOVE] Legal:    {metrics['move_legal']/metrics['total_move']*100:.1f}%")
    print(f"[MOVE] Exact:    {metrics['move_exact']/metrics['total_move']*100:.1f}%")
    print("-" * 40)
    print(f"[STATE] Exact:   {metrics['state_exact']/metrics['total_state']*100:.1f}%")
    print(f"[STATE] Sim %:   {(metrics['state_similarity']/metrics['total_state'])*100:.1f}%")
    print("-" * 40)
    print(f"[LEGAL] Logic:   {metrics['legality_acc']/metrics['total_legality']*100:.1f}%")
    print("="*40 + "\n")

In [7]:
if __name__ == "__main__":
    model, tokenizer, processor = load_vlm_model()
    model.load_adapter(OUTPUT_DIR_V2)
    run_eval(model, tokenizer, processor, DATASET_TEST_PATH)

Loading Model: unsloth/Qwen3-VL-8B-Instruct-unsloth-bnb-4bit
==((====))==  Unsloth 2026.1.4: Fast Qwen3_Vl patching. Transformers: 4.57.6.
   \\   /|    NVIDIA A100 80GB PCIe MIG 1g.20gb. Num GPUs = 1. Max memory: 19.5 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.10.0+cu128. CUDA: 8.0. CUDA Toolkit: 12.8. Triton: 3.6.0
\        /    Bfloat16 = TRUE. FA [Xformers = None. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Model and Tokenizer loaded successfully

--- STARTING EVAL ON 5 SAMPLES ---



100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 400/400 [4:51:26<00:00, 43.72s/it]  


JSON Syntax:     100.0%
----------------------------------------
[ALLOWED] Acc:   100.0%
[MOVE] Legal:    93.8%
[MOVE] Exact:    13.2%
----------------------------------------
[STATE] Exact:   82.5%
[STATE] Sim %:   92.0%
----------------------------------------
[LEGAL] Logic:   95.0%




