In [1]:
#!pip install unsloth transformers accelerate datasets bitsandbytes pandas pillow packaging ninja

In [2]:
import torch
import os, json
import pandas as pd
from unsloth import FastLanguageModel
from transformers import TrainingArguments, Trainer, AutoProcessor
from datasets import Dataset, load_from_disk
from PIL import Image as PILImage
from PIL import ImageDraw, ImageFilter, ImageEnhance
from config import *

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!


In [3]:
# Shared handles: load_vlm_model() assigns them (via `global`) and
# MultimodalDataCollator reads them, refusing to run while any is still None.
model = tokenizer = processor = None

In [4]:
def load_vlm_model():
    """Load the 4-bit vision-language model, its tokenizer and its processor.

    Also stores the three objects in the module-level ``model``, ``tokenizer``
    and ``processor`` globals, which MultimodalDataCollator reads.

    Returns:
        (model, tokenizer, processor)
    """
    global model, tokenizer, processor
    print(f"Loading Model: {MODEL_NAME}")

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name = MODEL_NAME,
        max_seq_length = MAX_SEQ_LENGTH,
        dtype = DTYPE,
        load_in_4bit = True,
    )

    processor = AutoProcessor.from_pretrained(MODEL_NAME)
    tokenizer.pad_token_id = processor.tokenizer.pad_token_id
    # The collator pads batches through `processor.tokenizer` (a separately
    # loaded object), not through `tokenizer`. Its label masking assumes the
    # prompt begins at position 0, so right padding must be configured on the
    # tokenizer that actually pads.
    tokenizer.padding_side = "right"
    processor.tokenizer.padding_side = "right"

    print("Model and Tokenizer loaded successfully")
    return model, tokenizer, processor

In [5]:
def visual_spotlight_active_board(image, active_row, active_col,
                                  saturation_factor=0.5, brightness_factor=0.4):
    """Dim every global board except the active one.

    The image is split into a 3x3 grid of global boards. The whole image is
    desaturated and darkened, and the active board's original pixels are then
    pasted back on top.

    Args:
        image: PIL image of the full Ultimate Tic-Tac-Toe board.
        active_row, active_col: global-board coordinates in 0..2. Any value
            outside that range (e.g. -1, used for Free Play) returns ``image``
            unchanged.
        saturation_factor: ``ImageEnhance.Color`` factor for the dimmed area.
        brightness_factor: ``ImageEnhance.Brightness`` factor for the dimmed area.

    Returns:
        A new image with only the active board at full colour/brightness, or
        the original ``image`` object when no board is active.
    """
    if not (0 <= active_row <= 2 and 0 <= active_col <= 2):
        return image

    width, height = image.size

    dimmed = ImageEnhance.Color(image).enhance(saturation_factor)
    dimmed = ImageEnhance.Brightness(dimmed).enhance(brightness_factor)

    # Compute each edge as k * size // 3 (not k * (size // 3)) so the last
    # row/column of boards reaches the image edge when size % 3 != 0.
    left = active_col * width // 3
    top = active_row * height // 3
    right = (active_col + 1) * width // 3
    bottom = (active_row + 1) * height // 3

    dimmed.paste(image.crop((left, top, right, bottom)), (left, top))
    return dimmed

In [6]:
def MultimodalDataCollator(data):
    """
    Batch, tokenize and label-mask formatted examples for multimodal training.

    1. Extracts images (converted to RGB) and conversation messages.
    2. Renders each full conversation (System + User + Assistant) into a single
       prompt string with the tokenizer's chat template.
    3. Runs the VLM's multimodal processor over images + prompts to build a
       padded batch.
    4. Builds ``labels`` from ``input_ids`` and sets to -100 (ignored by the loss):
       - every System/User token plus the assistant header, so the model only
         learns to predict the Assistant's (target JSON) response;
       - every padding position, so pad tokens never contribute to the loss.

    Args:
        data: list of dataset rows, each with a PIL ``"image"`` and a
            ``"messages"`` list whose first two entries are the system and user
            turns.

    Returns:
        The processor's batch with an added ``"labels"`` tensor.

    Raises:
        ValueError: if the module-level model/processor/tokenizer are not loaded.
        Any exception raised by the processor on the full batch (re-raised).
    """
    global model, processor, tokenizer
    if model is None or processor is None or tokenizer is None:
        raise ValueError("Model, processor, tokenizer must be loaded globally before using this.")

    # 1. Extract images and messages.
    images = [item["image"].convert("RGB") for item in data]
    messages = [item["messages"] for item in data]

    # 2. Render the FULL conversation (including the assistant answer) to text.
    text_prompts = [
        tokenizer.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=False,
        )
        for conversation in messages
    ]

    # 3. Preprocess images and tokenize text for the whole (padded) batch.
    try:
        inputs = processor(
            images = images,
            text = text_prompts,
            return_tensors = "pt",
            padding = True,
            max_length = processor.tokenizer.model_max_length,
        )
    except Exception as e:
        print(f"Unable to create batch. Actual error: {e}")
        raise

    # 4. Label masking.
    labels = inputs["input_ids"].clone()
    attention_mask = inputs["attention_mask"] if "attention_mask" in inputs else None

    for i, conversation in enumerate(messages):
        # Render System + User and append the assistant header
        # (add_generation_prompt=True), so the cutoff lands exactly where the
        # target JSON begins.
        system_user_text = tokenizer.apply_chat_template(
            conversation[:2],
            tokenize=False,
            add_generation_prompt=True,
        )

        # Measure the prompt length with the full multimodal processor: a
        # text-only tokenizer would not count the image patch tokens the
        # processor inserts for the image placeholder.
        try:
            prompt_ids = processor(
                images = [images[i]],
                text = system_user_text,
                return_tensors = "pt",
                padding = False,
            ).input_ids[0]  # index (not squeeze) so a 1-token prompt stays 1-D
            cutoff_length = prompt_ids.shape[0]
        except Exception as e:
            print(f"Error processing single prompt for masking length: {e}")
            print(f"Warning: Masking calculation failed for sample {i}. Using safe, non-masked sequence.")
            cutoff_length = 0

        # Under left padding the prompt starts after the run of pad tokens;
        # under right padding this offset stays 0.
        prompt_start = 0
        if attention_mask is not None:
            real_positions = attention_mask[i].nonzero(as_tuple=True)[0]
            if real_positions.numel() > 0:
                prompt_start = int(real_positions[0])

        labels[i, : prompt_start + cutoff_length] = -100

    # Exclude padding from the loss; this matters whenever the batch holds
    # sequences of different lengths.
    if attention_mask is not None:
        labels[attention_mask == 0] = -100

    inputs["labels"] = labels
    return inputs

In [7]:
# Prompt fragments shared by the training and evaluation formatters.
_COORDINATE_DEFINITION = (
    "**COORDINATE SYSTEM DEFINITION:**\n"
    "The board uses a 0-indexed system (0, 1, 2) for both rows and columns.\n"
    "Row 0 is the TOP row. Column 0 is the LEFT column.\n"
    "Your output must STRICTLY use the integers 0, 1, or 2 for all coordinates.\n"
    "Example: The top-left corner is (0, 0)."
)

_SYSTEM_CONTENT = (
    "You are an expert Ultimate Tic-Tac-Toe player. "
    "Your goal is to identify the optimal, legal move based on the provided image and context. "
    f"{_COORDINATE_DEFINITION}\n\n"
    "The final output must be **ONLY** a raw JSON object containing the chosen move."
)

# Active-board instructions differ per split: training images get the spotlight
# treatment ("visually emphasized"), evaluation images do not.
_TRAIN_RESTRICTION_TEXT = (
    "The global board highlighted in **BRIGHT GREEN** (and visually emphasized) is the current active board constraint. You MUST select a local cell within this board unless it is Free Play (where no board is emphasized).\n"
)
_EVAL_RESTRICTION_TEXT = (
    "The global board highlighted in **BRIGHT GREEN** is the current active board constraint. You MUST select a local cell within this board unless it is Free Play (where no board is highlighted).\n"
)


def _active_board_coords(allowed_squares):
    """Return (row, col) of the forced global board, or (-1, -1) if none is forced."""
    if isinstance(allowed_squares, (list, tuple)) and len(allowed_squares) == 2:
        return allowed_squares[0], allowed_squares[1]
    return -1, -1


def _format_examples(examples, restriction_text, spotlight_active_board):
    """Build chat ``messages`` for one batched ``datasets.map`` call.

    Shared implementation of format_data_for_training / format_data_for_eval.

    Args:
        examples: batch dict of columns (image, player, allowed_squares,
            unplayable_boards, best_move, ascii_board, ...).
        restriction_text: split-specific active-board instruction line.
        spotlight_active_board: if True, replace each image in
            ``examples["image"]`` with its spotlighted version.

    Returns:
        ``examples`` with an added ``"messages"`` column. Rows whose image is not
        a PIL image get ``[]`` so the caller can filter them out.
    """
    images = examples.get("image", [])
    batch_size = len(images)

    if batch_size == 0:
        return {"messages": []}

    messages_list = []
    for i in range(batch_size):
        image = images[i]
        if not isinstance(image, PILImage.Image):
            messages_list.append([])
            continue

        if spotlight_active_board:
            active_row, active_col = _active_board_coords(examples["allowed_squares"][i])
            examples["image"][i] = visual_spotlight_active_board(image, active_row, active_col)

        unplayable_list_str = json.dumps(examples["unplayable_boards"][i])
        user_prompt = (
            f"Player: {examples['player'][i]} (X=Player 1, O=Player 2)\n"
            f"Analyze the board state in the image and determine the optimal move.\n\n"

            f"--- CRITICAL MOVE RESTRICTION ---\n"
            f"{restriction_text}"
            f"**Unplayable Boards:** The following Global Boards are already WON or TIED and cannot be played: {unplayable_list_str}\n\n"

            f"--- ASCII VISUALIZATION ---\n"
            f"Use this labeled diagram to cross-reference the image coordinates (0, 1, 2) with the piece locations and board status:\n"
            f"{examples['ascii_board'][i]}\n\n"

            f"CRITICAL RULE: The target local cell (local_row, local_col) MUST be **EMPTY** on the active global board.\n"
            f"CRITICAL RULE: All output coordinates (global_row, global_col, local_row, local_col) MUST be **0, 1 or 2**."
        )

        messages_list.append([
            {"role": "system", "content": [{"type": "text", "text": _SYSTEM_CONTENT}]},
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": user_prompt},
                ],
            },
            {"role": "assistant", "content": [{"type": "text", "text": json.dumps(examples["best_move"][i])}]},
        ])

    examples["messages"] = messages_list
    return examples


def format_data_for_training(examples):
    """Format a training batch: spotlight the active board and build chat messages.

    NOTE(review): only training images are spotlighted (evaluation images are
    left as-is and get a slightly different prompt), so eval_loss is measured
    on differently preprocessed inputs. Confirm this matches inference-time
    preprocessing.
    """
    return _format_examples(examples, _TRAIN_RESTRICTION_TEXT, spotlight_active_board=True)


def format_data_for_eval(examples):
    """Format an evaluation batch: build chat messages without altering the images."""
    return _format_examples(examples, _EVAL_RESTRICTION_TEXT, spotlight_active_board=False)

In [8]:
def get_columns_to_remove(raw_dataset, columns_needed_for_map):
    """List the dataset columns that formatting can drop.

    A column is kept if it is in ``columns_needed_for_map`` or is named
    ``"messages"``. Every other column name is returned, in dataset order.
    """
    removable = []
    for column_name in raw_dataset.column_names:
        keep_column = column_name in columns_needed_for_map or column_name == "messages"
        if not keep_column:
            removable.append(column_name)
    return removable

def _load_parquet_split(path, announcement):
    """Print ``announcement``, load a parquet split and report its size.

    Returns the Dataset, or None (after printing the error) if loading fails.
    """
    print(announcement)
    try:
        dataset = Dataset.from_parquet(path)
        print(f"Dataset loaded with {len(dataset)} examples.")
    except Exception as e:
        print(f"Error loading dataset: {e}")
        return None
    return dataset


def _format_split(raw_dataset, formatter, columns_needed_for_map):
    """Apply a batched formatter, drop unneeded columns, and discard rows without messages."""
    return raw_dataset.map(
        formatter,
        remove_columns = get_columns_to_remove(raw_dataset, columns_needed_for_map),
        batched = True,
    ).filter(lambda x: len(x["messages"]) > 0)


def run_fine_tuning():
    """Fine-tune the 4-bit VLM with LoRA on the Ultimate Tic-Tac-Toe move dataset.

    Steps:
      1. Load model, tokenizer and processor, then attach LoRA adapters.
      2. Load and format the train/eval parquet splits.
      3. Train with per-epoch evaluation and checkpointing, keeping the
         checkpoint with the lowest ``eval_loss``.
      4. Save the adapter, processor and tokenizer to ``OUTPUT_DIR_V4``.

    Returns None. Returns early (after printing the error) if either dataset
    fails to load.
    """
    # These are locals. The module-level globals set by load_vlm_model() keep
    # the pre-PEFT model, which MultimodalDataCollator only checks for None.
    model, tokenizer, processor = load_vlm_model()

    print("Applying PEFT (QLoRA) layer...")
    model = FastLanguageModel.get_peft_model(
        model,
        r = 32,
        target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"],
        lora_alpha = 32,
        lora_dropout = 0,
        bias = "none",
        use_gradient_checkpointing = "unsloth",
        random_state = 42,
        # NOTE(review): `embedding_layer_names` is not a documented
        # get_peft_model argument, and Qwen3-VL may have no module named
        # `vision_tower.image_projection`. Confirm this has any effect.
        embedding_layer_names = ["vision_tower.image_projection"]
    )

    raw_train_dataset = _load_parquet_split(DATASET_TRAIN_PATH, "Loading training dataset...")
    if raw_train_dataset is None:
        return

    raw_eval_dataset = _load_parquet_split(DATASET_EVALUATION_PATH, "Loading evaluation dataset...")
    if raw_eval_dataset is None:
        return

    print("Formatting datasets for multimodal training...")
    # Columns kept through formatting; everything else except "messages" is dropped.
    columns_needed_for_map = ["image", "player", "allowed_squares", "best_move",
                              "chain_of_thought", "legal_moves", "ascii_board",
                              "unplayable_boards"]

    train_dataset = _format_split(raw_train_dataset, format_data_for_training, columns_needed_for_map)
    eval_dataset = _format_split(raw_eval_dataset, format_data_for_eval, columns_needed_for_map)

    print("Setting up training arguments...")
    training_arguments = TrainingArguments(
        per_device_train_batch_size = 1,
        per_device_eval_batch_size = 1,
        gradient_accumulation_steps = 16,
        warmup_steps = 40,
        num_train_epochs = 8,
        learning_rate = 2e-5,
        fp16 = False,
        bf16 = True,
        output_dir = OUTPUT_DIR_V4,
        optim = "paged_adamw_8bit",
        seed = 42,
        # Evaluation and checkpointing both run once per epoch. The former
        # eval_steps/save_steps values were ignored under the "epoch" strategy.
        eval_strategy = "epoch",
        save_strategy = "epoch",
        logging_steps = 5,
        load_best_model_at_end = True,
        metric_for_best_model = "eval_loss",
        greater_is_better = False,
        report_to = "none",
        # The custom collator needs the raw "image" and "messages" columns.
        remove_unused_columns = False,
        weight_decay = 0.01,
    )

    print("Initializing Trainer and starting fine-tuning...")
    trainer = Trainer(
        model = model,
        # `processing_class` replaces the deprecated `tokenizer` argument.
        processing_class = tokenizer,
        train_dataset = train_dataset,
        eval_dataset = eval_dataset,
        args = training_arguments,
        data_collator = MultimodalDataCollator,
    )

    trainer.train()

    print("\nTraining complete.")
    # With load_best_model_at_end=True this is the lowest-eval_loss checkpoint.
    trainer.model.save_pretrained(OUTPUT_DIR_V4)
    # Save the processor as well, so the image-preprocessing config travels with
    # the adapter. The tokenizer is saved last so its settings take precedence.
    processor.save_pretrained(OUTPUT_DIR_V4)
    tokenizer.save_pretrained(OUTPUT_DIR_V4)

In [9]:
if __name__ == "__main__":
    # Fine-tuning needs a GPU: run only when CUDA is present, otherwise report why not.
    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        run_fine_tuning()
    else:
        print("Error: CUDA not detected!")

Loading Model: unsloth/Qwen3-VL-8B-Instruct-unsloth-bnb-4bit
==((====))==  Unsloth 2025.11.6: Fast Qwen3_Vl patching. Transformers: 4.57.2.
   \\   /|    NVIDIA A100-SXM4-40GB MIG 4g.20gb. Num GPUs = 1. Max memory: 19.625 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.9.1+cu128. CUDA: 8.0. CUDA Toolkit: 12.8. Triton: 3.5.1
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.33.post2. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Model and Tokenizer loaded successfully
Applying PEFT (QLoRA) layer...
Loading training dataset...
Dataset loaded with 3203 examples.
Loading evaluation dataset...
Dataset loaded with 401 examples.
Formatting datasets for multimodal training...
Setting up training arguments...
Initializing Trainer and starting fine-tuning...


  trainer = Trainer(
==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 3,203 | Num Epochs = 8 | Total steps = 1,608
O^O/ \_/ \    Batch size per device = 1 | Gradient accumulation steps = 16
\        /    Data Parallel GPUs = 1 | Total batch size (1 x 16 x 1) = 16
 "-____-"     Trainable parameters = 30,670,848 of 8,797,794,544 (0.35% trained)


Unsloth: Will smartly offload gradients to save VRAM!


Epoch,Training Loss,Validation Loss
1,0.0811,0.078109
2,0.0652,0.069689
3,0.0661,0.065381
4,0.0583,0.063964
5,0.0586,0.063729
6,0.0579,0.065534
7,0.0559,0.06658
8,0.0489,0.068977


Unsloth: Not an error, but Qwen3VLForConditionalGeneration does not accept `num_items_in_batch`.
Using gradient accumulation will be very slightly less accurate.
Read more on gradient accumulation issues here: https://unsloth.ai/blog/gradient



Training complete.
