### Install dependances

In [1]:
#pip install torch unsloth dataset transformers pillow

### Import libs

In [2]:
import torch
from unsloth import FastLanguageModel
from datasets import load_dataset
from transformers import AutoProcessor
from PIL import Image as PILImage
import json
import re
from config import *

ðŸ¦¥ Unsloth: Will patch your computer to enable 2x faster free finetuning.
ðŸ¦¥ Unsloth Zoo will now patch everything to make training faster!


### Load the Model

In [3]:
def load_vlm_model():
    print(f"Loading Model: {MODEL_NAME}")

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name = MODEL_NAME,
        max_seq_length = MAX_SEQ_LENGTH,
        dtype = DTYPE,
        load_in_4bit = True,
        trust_remote_code = True
    )

    processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)

    model.config.use_cache = False

    print("Model and Tokenizer loaded successfully")
    return model, tokenizer, processor

### Evaluation Functions

In [4]:
def evaluate_baseline(model, tokenizer, processor, test_data_path):
    """
    baseline evaluation on evaluation set, to test model's ability.
    """

    try:
        test_data = load_dataset("parquet", data_files={"test": test_data_path}, split="test")
    except Exception as e:
        print(f"Error loading test dataset: {e}")
        return 0,0,0

    # small sample to avoid long initial run time
    sample_size = min(len(test_data), 5)
    print(f"\nDataset size: {len(test_data)}")
    print(f"Starting baseline evaluation on {sample_size} samples")

    correct_moves = 0
    correct_allowed_squares = 0
    legal_move_count = 0
    invalid_model_move = 0
    total_similarity_score = 0
    total_moves = 0

    for i in range(sample_size):
        sample = test_data[i]

        try:
            ground_truth_move = sample["best_move"]
            allowed_squares_list = sample["allowed_squares"]
            
            if isinstance(allowed_squares_list, list) and len(allowed_squares_list) >= 2:
                ground_truth_allowed_square = {
                    "global_row": allowed_squares_list[0],
                    "global_col": allowed_squares_list[1]
                }
            else:
                ground_truth_allowed_square = {"global_row": -1, "global_col": -1}
            
            legal_moves = sample["legal_moves"]
            unplayable_boards = sample["unplayable_boards"]
            image = sample["image"].convert("RGB")
            player_turn = sample["player"]
            ascii_board = sample["ascii_board"]
            cot_text = sample["chain_of_thought"]

        except KeyError as e:
            print(f"Skipping sample {i}: Missing required data field {e}.")
            continue
        except Exception as e:
            print(f"Skipping sample {i}: General data processing error: {type(e).__name__}: {e}")
            continue

        try:
            unplayable_list_str = format_squares_to_str(unplayable_boards)
            legal_moves_str = format_moves_to_str(legal_moves)

            system_content = (
                f"You are an expert Ultimate Tic-Tac-Toe player. "
                f"Your goal is to identify the optimal, legal move based on the provided image and context. "
                f"The final output must be **ONLY** a raw JSON object containing the chosen move."
                f"{{\"global_row\": r, \"global_col\": c, \"local_row\": lr, \"local_col\": lc}}."
            )
            
            user_prompt_text = (
                f"Player: {player_turn} (X=Player 1, O=Player 2)\n"
                f"Analyze the board state in the image and determine the optimal move.\n\n"
                f"--- BOARD CONTEXT ---\n"
                f"**Allowed/Active Board:** The global board highlighted in **BRIGHT GREEN** in the image is the current active board constraint. If this board is already won/tied, you must select any other available board (Free Play).\n"
                f"**Unplayable Boards:** The following Global Boards are already WON or TIED and cannot be played: {unplayable_list_str}\n\n"
                f"--- ASCII VISUALIZATION ---\n"
                f"Use this labeled diagram to cross-reference the image coordinates (0, 1, 2) with the piece locations and board status:\n"
                f"{ascii_board}\n\n"
                f"CRITICAL RULE: The target local cell (local_row, local_col) MUST be **EMPTY** on the global board (global_row, global_col).\n"
                f"CRITICAL RULE: All output coordinates (global_row, global_col, local_row, local_col) MUST be **0, 1 or 2**."
            )
        except Exception as e:
            print(f"Skipping sample {i}: Error formatting prompt: {type(e).__name__}: {e}")
            continue
            
        messages = [
            {"role": "system", "content": [{"type": "text", "text": system_content}]},
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": user_prompt_text},
                ],
            },
        ]

        try:
            inputs = processor.apply_chat_template(
                messages,
                tokenize=True,
                add_generation_prompt=True,
                return_dict=True,
                return_tensors="pt"
            ).to(model.device)

            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=256,
                    do_sample=False,
                    temperature=0.1,
                    top_p=0.9,
                    use_cache=True
                )
            response_text = tokenizer.decode(outputs[0], skip_special_tokens=False)
            # remove input prompt to get just the response
            response_text = response_text.split("<|im_start|>assistant\n")[-1].strip()

            # parse model's move
            model_move = parse_move_from_text(response_text)
            model_allowed_square = None
            if model_move and all(k in model_move for k in ["global_row", "global_col"]):
                model_allowed_square = {
                    "global_row": model_move.get("global_row"),
                    "global_col": model_move.get("global_col")
                }

            total_moves += 1
            is_correct_move = compare_moves(model_move, ground_truth_move)
            is_correct_allowed_square = compare_sqs(model_allowed_square, ground_truth_allowed_square)
            is_legal = is_move_legal(model_move, legal_moves)
            is_model_move_invalid = model_move_invalid(model_move)
            similarity_score = move_similarity(model_move, ground_truth_move)

            total_similarity_score += similarity_score

            if is_correct_move:
                correct_moves += 1

            if is_correct_allowed_square:
                correct_allowed_squares += 1

            if is_legal:
                legal_move_count += 1

            if is_model_move_invalid:
                invalid_model_move += 1

            if i < 3:
                print(f"\nSample {i} structure:")
                print(f"  best_move type: {type(ground_truth_move)}, value: {ground_truth_move}")
                print(f"  allowed_squares type: {type(allowed_squares_list)}, value: {allowed_squares_list}")
                print(f"  converted allowed_square: {ground_truth_allowed_square}")
                print(f"\n--- Raw response from model for sample {i}: ---")
                print(f"Response length: {len(response_text)}")
                print(f"Response: {response_text}")
                print("--- End of raw response ---")

            if i < 5 or is_correct_move or not is_correct_move:
                print(f"\nSample {i+1}")
                print(f"Ground Truth Best Move: {ground_truth_move}")
                print(f"Ground Truth Allowed Square: {ground_truth_allowed_square}")
                print("-" * 100)
                print(f"Model Move: {model_move}")
                print(f"Model Allowed Square: {model_allowed_square}")
                print(f"Move Result: {'CORRECT MOVE' if is_correct_move else 'INCORRECT MOVE'}")
                print(f"Allowed Square Result: {'CORRECT Square' if is_correct_allowed_square else 'INCORRECT Square'}")
                print(f"Move Similarity Score: {similarity_score}")
                print(f"Move Allowed?: {'Legal Move' if is_legal else 'Illegal Move'}")
                print(f"{'Invalid Model Move' if is_model_move_invalid else ''}")
                
        except Exception as e:
            print(f"Error processing sample {i}: {type(e).__name__}: {e}")
            continue

    if total_moves == 0:
        print("WARNING: No samples were successfully processed!")
        return 0, 0, 0, 0, 0, 0
    
    move_accuracy = (correct_moves / total_moves) * 100
    square_accuracy = (correct_allowed_squares / total_moves) * 100
    legal_move_accuracy = (legal_move_count / total_moves) * 100
    avg_similarity_score = total_similarity_score / total_moves
    invalid_moves = invalid_model_move
    return move_accuracy, square_accuracy, legal_move_accuracy, total_moves, invalid_moves, avg_similarity_score

### Parsing Model Response to Evalute

In [5]:
def format_squares_to_str(squares_list):
    """Converts a list of global square dicts into a formatted string: [(r, c), (r, c), ...]."""
    if not squares_list:
        return "[]"
    
    formatted = []
    for sq in squares_list:
        if isinstance(sq, dict) and 'global_row' in sq and 'global_col' in sq:
            formatted.append(f"({sq['global_row']}, {sq['global_col']})")
    
    return "[" + ", ".join(formatted) + "]"

def format_moves_to_str(moves_list):
    """Converts a list of move dicts into a formatted string: [(gr, gc:lr, lc), ...]."""
    if not moves_list:
        return "[]"
    
    formatted = []
    for move in moves_list:
        if isinstance(move, dict) and all(k in move for k in ["global_row", "global_col", "local_row", "local_col"]):
            formatted.append(f"({move['global_row']},{move['global_col']}:{move['local_row']},{move['local_col']})")
    
    return "[\n" + ",\n".join(formatted) + "\n]"

def parse_move_from_text(text):
    """Parse move JSON from model response, more robust version"""
    if not text:
        return None
    
    # multiple patterns to find the JSON
    patterns = [
        # Standard JSON pattern
        r'\{\s*"global_row":\s*\d+,\s*"global_col":\s*\d+,\s*"local_row":\s*\d+,\s*"local_col":\s*\d+\s*\}',
        # Allow for extra whitespace and line breaks
        r'\{\s*[\n\r]*"global_row"\s*:\s*\d+\s*,\s*[\n\r]*"global_col"\s*:\s*\d+\s*,\s*[\n\r]*"local_row"\s*:\s*\d+\s*,\s*[\n\r]*"local_col"\s*:\s*\d+\s*[\n\r]*\}',
        # Try to find any JSON-like structure
        r'\{\s*["\']global_row["\']\s*:\s*\d+\s*,\s*["\']global_col["\']\s*:\s*\d+\s*,\s*["\']local_row["\']\s*:\s*\d+\s*,\s*["\']local_col["\']\s*:\s*\d+\s*\}'
    ]
    
    for pattern in patterns:
        match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
        if match:
            move_str = match.group(0)
            move_str = move_str.replace("'", '"').replace("\n", "").replace("\r", "")
            try:
                return json.loads(move_str)
            except json.JSONDecodeError as e:
                continue
    
    # If no pattern matched, try to find any JSON and validate
    try:
        json_start = text.find('{')
        json_end = text.rfind('}')
        if json_start != -1 and json_end > json_start:
            json_str = text[json_start:json_end+1]
            data = json.loads(json_str)
            required_keys = ["global_row", "global_col", "local_row", "local_col"]
            if all(key in data for key in required_keys):
                return data
    except:
        pass
    
    return None

def compare_moves(model_move, ground_truth_move):
    # compare model move against ground truth
    if model_move is None:
        return False

    keys = ["global_row", "global_col", "local_row", "local_col"]

    try:
        return all(model_move.get(k) == ground_truth_move.get(k) for k in keys)
    except Exception:
        return False

def compare_sqs(model_square, ground_truth_square):
    # compare model allowed square against ground truth allowed square
    if model_square is None:
        return False
    
    # Check if ground_truth_square is valid
    if ground_truth_square is None:
        return False
    
    # Check for the 'Any Board' case (-1, -1)
    if ground_truth_square.get("global_row") == -1:
        # If Ground Truth is 'Any Board' (-1, -1), any board coordinate (0, 1, 2) 
        # is technically a valid target for the model's prediction.
        return True 

    keys = ["global_row", "global_col"]

    try:
        return all(model_square.get(k) == ground_truth_square.get(k) for k in keys)
    except Exception:
        return False

def is_move_legal(model_move, legal_moves):
    if model_move is None:
        return False
    for m in legal_moves:
        if (
                model_move.get("global_row") == m["global_row"] and
                model_move.get("global_col") == m["global_col"] and
                model_move.get("local_row") == m["local_row"] and
                model_move.get("local_col") == m["local_col"]
        ):
            return True
    return False

def model_move_invalid(model_move):
    if model_move is None:
        return True
    required_keys = ['global_row', 'global_col', 'local_row', 'local_col']
    if not all(key in model_move for key in required_keys):
        return True
    coordinates = [
        model_move.get('global_row'),
        model_move.get('global_col'),
        model_move.get('local_row'),
        model_move.get('local_col'),
    ]

    for coord in coordinates:
        # Check if coordinates are integers and within the 0, 1, 2 bounds
        if not isinstance(coord, int) or coord not in {0, 1, 2}:
            return True

    return False

def move_similarity(model_move, ground_truth_move):
    score = 0
    if model_move is None:
        return 0.0

    keys = ["global_row", "global_col", "local_row", "local_col"]
    if not all(key in model_move for key in keys):
        return 0.0

    if model_move["global_row"] == ground_truth_move["global_row"]:
        score += 1
    if model_move["global_col"] == ground_truth_move["global_col"]:
        score += 1
    if model_move["local_row"] == ground_truth_move["local_row"]:
        score += 1
    if model_move["local_col"] == ground_truth_move["local_col"]:
        score += 1
    return score / 4.0

### Main Execution

In [6]:
if __name__ == "__main__":
    model, tokenizer, processor = load_vlm_model()
    model.load_adapter(f"{OUTPUT_DIR_V4}/checkpoint-1206")
    
    move_accuracy, square_accuracy, legal_move_accuracy, total_samples, invalid_moves, avg_similarity_score = evaluate_baseline(model, tokenizer, processor, DATASET_TEST_PATH)
    
    print("\n" + "="*100)
    print("BASELINE EVALUATION RESULTS")
    print(f"Total samples: {total_samples}")
    print(f"Baseline Move Accuracy: {move_accuracy:.2f}%")
    print(f"Baseline Allowed Square Accuracy: {square_accuracy:.2f}%")
    print(f"Average Move Similarity Score: {avg_similarity_score}")
    print(f"Baseline Legal Move Accuracy: {legal_move_accuracy:.2f}%")
    print(f"Invalid Moves by Model (Hallucination/Out-of-bounds): {invalid_moves} moves")
    print("\n" + "="*100)

Loading Model: unsloth/Qwen3-VL-8B-Instruct-unsloth-bnb-4bit
Are you certain you want to do remote code execution?
==((====))==  Unsloth 2025.11.6: Fast Qwen3_Vl patching. Transformers: 4.57.2.
   \\   /|    NVIDIA A100-SXM4-40GB MIG 4g.20gb. Num GPUs = 1. Max memory: 19.625 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.9.1+cu128. CUDA: 8.0. CUDA Toolkit: 12.8. Triton: 3.5.1
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.33.post2. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: Qwen3_Vl does not support SDPA - switching to fast eager.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Model and Tokenizer loaded successfully


Generating test split: 0 examples [00:00, ? examples/s]


Dataset size: 400
Starting baseline evaluation on 5 samples

Sample 0 structure:
  best_move type: <class 'dict'>, value: {'global_row': 0, 'global_col': 2, 'local_row': 2, 'local_col': 2}
  allowed_squares type: <class 'list'>, value: [0, 2]
  converted allowed_square: {'global_row': 0, 'global_col': 2}

--- Raw response from model for sample 0: ---
Response length: 76
Response: {"global_row": 0, "global_col": 2, "local_row": 1, "local_col": 1}<|im_end|>
--- End of raw response ---

Sample 1
Ground Truth Best Move: {'global_row': 0, 'global_col': 2, 'local_row': 2, 'local_col': 2}
Ground Truth Allowed Square: {'global_row': 0, 'global_col': 2}
----------------------------------------------------------------------------------------------------
Model Move: {'global_row': 0, 'global_col': 2, 'local_row': 1, 'local_col': 1}
Model Allowed Square: {'global_row': 0, 'global_col': 2}
Move Result: INCORRECT MOVE
Allowed Square Result: CORRECT Square
Move Similarity Score: 0.5
Move Allowed?: L