In [1]:
#!pip install unsloth transformers accelerate datasets bitsandbytes pandas pillow packaging ninja

In [2]:
import torch
import json, random
from unsloth import FastLanguageModel
from transformers import TrainingArguments, Trainer, AutoProcessor
from datasets import Dataset
from config import *

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
Unsloth: Your Flash Attention 2 installation seems to be broken?
A possible explanation is you have a new CUDA version which isn't
yet compatible with FA2? Please file a ticket to Unsloth or FA2.
We shall now use Xformers instead, which does not have any performance hits!
We found this negligible impact by benchmarking on 1x A100.
🦥 Unsloth Zoo will now patch everything to make training faster!


In [3]:
# Module-level handles populated by load_vlm_model() (it declares them `global`);
# MultimodalDataCollator reads `processor` from here when building batches.
model = tokenizer = processor = None

### Load the model

In [4]:
def load_vlm_model(max_seq_length=4096):
    """Load the 4-bit base VLM, its tokenizer and its multimodal processor.

    The three objects are also stored in the module-level ``model``,
    ``tokenizer`` and ``processor`` globals; ``MultimodalDataCollator`` reads
    ``processor`` from there at batch time.

    Args:
        max_seq_length: maximum sequence length passed to Unsloth. The default of
            4096 is the previously hard-coded value. The collator separately
            truncates batches to 2048 tokens.

    Returns:
        ``(model, tokenizer, processor)``.
    """
    global model, tokenizer, processor
    print(f"Loading Model: {MODEL_NAME}")

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name = MODEL_NAME,
        max_seq_length = max_seq_length,
        dtype = None,          # None: let Unsloth choose the compute dtype
        load_in_4bit = True,
    )

    processor = AutoProcessor.from_pretrained(MODEL_NAME)

    # The collator tokenizes through ``processor.tokenizer``, which is loaded
    # separately from the tokenizer returned above. Configuring only ``tokenizer``
    # would never affect the batches used for training, so both get the same setup.
    # NOTE: if the pad token falls back to eos, the collator's pad masking will also
    # mask genuine eos tokens in the labels.
    for tok in (tokenizer, processor.tokenizer):
        if tok.pad_token is None:
            tok.pad_token = tok.eos_token
        tok.padding_side = "right"

    print("Model and Tokenizer loaded successfully")
    return model, tokenizer, processor

### Data Collator

In [5]:
def MultimodalDataCollator(batch):
    """Collate formatted examples into a padded multimodal training batch.

    Each item in ``batch`` provides an ``"image"`` (anything with a PIL-style
    ``.convert("RGB")``) and a ``"messages"`` conversation as built by
    ``format_data``. Conversations are rendered with the module-level
    ``processor``'s chat template and tokenized together with the images
    (padded, truncated to 2048 tokens). A ``labels`` tensor is added so that
    loss is computed only on the final assistant turn.

    Label positions set to -100 (ignored by the loss):
      * everything up to and including the last ``<|im_start|>assistant\\n``
        header, i.e. system prompt, user prompt and image placeholder tokens;
      * padding tokens;
      * the entire row when no ``<|im_start|>`` is present, or when the last one
        does not open an assistant turn. That happens when truncation cut the
        reply off, and the row must not train the model to reproduce its prompt.

    Args:
        batch: list of dataset rows with ``"image"`` and ``"messages"`` keys.

    Returns:
        The processor output with an added ``"labels"`` tensor shaped like
        ``input_ids``.
    """
    images = [item["image"].convert("RGB") for item in batch]
    texts = [
        processor.apply_chat_template(
            item["messages"],
            tokenize=False,
            add_generation_prompt=False,
        )
        for item in batch
    ]

    inputs = processor(
        images=images,
        text=texts,
        padding=True,
        truncation=True,
        max_length=2048,
        return_tensors="pt",
    )

    input_ids = inputs["input_ids"]
    labels = input_ids.clone()
    token_handler = processor.tokenizer

    im_start_token_id = token_handler.convert_tokens_to_ids("<|im_start|>")
    if im_start_token_id is None:
        # Qwen chat-template id of <|im_start|>, used if the tokenizer cannot resolve it.
        im_start_token_id = 151644

    # Token ids of the assistant turn header "<|im_start|>assistant\n". They are derived
    # from the tokenizer rather than assumed to be 3 tokens long, so the mask boundary
    # always matches and the header can be verified below.
    assistant_header = [im_start_token_id] + list(
        token_handler.encode("assistant\n", add_special_tokens=False)
    )
    header_len = len(assistant_header)
    pad_token_id = token_handler.pad_token_id

    for i in range(len(batch)):
        start_indices = (input_ids[i] == im_start_token_id).nonzero(as_tuple=True)[0]
        if len(start_indices) == 0:
            labels[i, :] = -100
            continue

        last_start_idx = start_indices[-1].item()
        header_tokens = input_ids[i, last_start_idx:last_start_idx + header_len].tolist()
        if header_tokens != assistant_header:
            # The last turn is not the assistant's: its reply was truncated away, so
            # this row has no valid target tokens.
            labels[i, :] = -100
            continue

        # Mask the prompt and the assistant header; keep the reply (and its end token).
        labels[i, :last_start_idx + header_len] = -100
        if pad_token_id is not None:
            labels[i][input_ids[i] == pad_token_id] = -100

    inputs["labels"] = labels
    return inputs

### Formatting the data in batches for training

In [6]:
def format_data(examples: dict) -> dict:
    """Build one chat-formatted training conversation per row of a batched slice.

    Used as a ``batched=True`` ``datasets.map`` function. ``examples`` maps column
    names to equal-length lists. This function reads the ``allowed_squares``,
    ``image``, ``rotation_angle``, ``player``, ``best_move``, ``legal_moves``,
    ``global_state``, ``chain_of_thought`` and ``unplayable_boards`` columns.

    For every row, one task is sampled with the global ``random`` module:
      * ALLOWED_SQUARE - report the active subgrid, or null when there is none;
      * MOVE           - scan the active subgrid, then give the best move
                         (only sampled when ``rotation_angle == 0``);
      * STATE          - reproduce one subgrid's 3x3 contents as a matrix;
      * LEGALITY       - judge whether a sampled move is legal and explain why.
    Each assistant answer is a JSON object wrapped in ``JSON_START``/``JSON_END``
    lines. Task choice and sampled targets depend on the state of ``random``,
    which is not seeded here. Outputs are therefore only reproducible if the
    caller seeds it.

    Returns:
        The same ``examples`` dict, mutated in place with an added ``"messages"``
        column. Each row holds a [system, user (text + image placeholder),
        assistant] conversation whose contents are lists of typed parts.
    """
    messages_list = []
    # Rows in this batch: every column list has the same length.
    batch_size = len(next(iter(examples.values())))

    # Output contract appended to every system prompt.
    coordinate_definition = (
        "**STRICT PROTOCOL:**\n"
        "1. FORMAT: All responses must be valid JSON snippets wrapped in JSON_START/JSON_END anchors.\n"
        "2. COORDINATES: Use 0-indexed integers (0, 1, 2) for all Row/Col values.\n"
        "3. VISUAL ANCHOR: The allowed square (active subgrid) is highlighted in GREEN. You must play there."
    )

    for i in range(batch_size):
        raw_allowed = examples["allowed_squares"][i]

        # The first two entries of allowed_squares are used as (global_row, global_col)
        # of the active subgrid. An empty/short value is reported as null (free move).
        ag_r, ag_c = None, None
        if raw_allowed and len(raw_allowed) >= 2:
            ag_r, ag_c = raw_allowed[0], raw_allowed[1]
            gt_active_json = {"global_row": ag_r, "global_col": ag_c}
        else:
            gt_active_json = None

        # NOTE: `image` is not used below. The user turn only carries an image
        # placeholder, and the pixels stay in the dataset's "image" column.
        image = examples["image"][i]
        rot_angle = examples["rotation_angle"][i]
        player_val = examples["player"][i]
        player_char = "X" if player_val == 1 else "O"
        best_move = examples["best_move"][i]
        legal_moves = examples["legal_moves"][i]
        # Placed pieces: dicts with global_row/global_col/local_row/local_col and
        # player (1 -> X, 2 -> O). Cells that are not listed are treated as empty.
        full_state = examples["global_state"][i]
        original_cot = examples["chain_of_thought"][i]
        # Completed boards: entries are handled either as [row, col] lists or as dicts.
        unplayable = examples["unplayable_boards"][i]

        system_content = f"You are an Ultimate Tic-Tac-Toe Visual Engine. {coordinate_definition}"

        # Sample this row's task. MOVE is only offered for unrotated images.
        # NOTE(review): presumably because best_move / chain_of_thought use unrotated
        # coordinates. Confirm that the labels of the other tasks are rotation-adjusted.
        if rot_angle == 0:
            task_type = random.choices(
                ["ALLOWED_SQUARE", "MOVE", "STATE", "LEGALITY"], 
                weights=[10, 50, 20, 20],
                k=1
            )[0]
        else:
            task_type = random.choices(
                ["ALLOWED_SQUARE", "STATE", "LEGALITY"], 
                weights=[20, 40, 40],
                k=1
            )[0]

        assistant_content = ""

        # --- 1. ALLOWED_SQUARE ---
        if task_type == "ALLOWED_SQUARE":
            user_prompt = "Identify the allowed square (active global subgrid) based on the green highlight. If none is highlighted, report null."
            assistant_content = f"JSON_START\n{{\"allowed_square\": {json.dumps(gt_active_json)}}}\nJSON_END"

        # --- 2. MOVE ---
        elif task_type == "MOVE":
            user_prompt = f"Visually analyze the board. It is Player {player_char}'s turn. Identify the allowed square (green highlight) and then select the best move."

            final_move_json = {
                "global_row": best_move["global_row"], "global_col": best_move["global_col"], 
                "local_row": best_move["local_row"], "local_col": best_move["local_col"]
            }

            clean_cot = original_cot
            # Drop the trailing "FINAL MOVE:" section of the CoT: the move is already
            # emitted as the structured best_move JSON, so repeating it only wastes tokens.
            if "FINAL MOVE:" in original_cot:
                clean_cot = original_cot.split("FINAL MOVE:")[0].strip()

            if ag_r is not None:
                # Prefix the reasoning with an explicit scan of the active subgrid so the
                # model learns to read the board state before choosing a move.
                scan_text = f"Active Global Board is ({ag_r}, {ag_c}). Visual Scan:\n"
                empty_spots = []

                for r in range(3):
                    row_desc = []
                    for c in range(3):
                        # Occupant of local cell (r, c) in the active subgrid; 0 = empty.
                        occ = next((p['player'] for p in full_state 
                                    if p['global_row'] == ag_r and p['global_col'] == ag_c 
                                    and p['local_row'] == r and p['local_col'] == c), 0)

                        symbol = "X" if occ == 1 else "O" if occ == 2 else "."
                        row_desc.append(symbol)
                        if occ == 0:
                            empty_spots.append(f"({r},{c})")
                    scan_text += f"Row {r}: {row_desc}\n"

                scan_text += f"Available Squares: {', '.join(empty_spots)}.\n"
                scan_text += f"Strategy: {clean_cot}"
            else:
                scan_text = f"No active constraint. Free move. {clean_cot}"

            assistant_content = (
                "JSON_START\n"
                f'{{"allowed_square": {json.dumps(gt_active_json)}, '
                f'"thinking": {json.dumps(scan_text)}, ' 
                f'"best_move": {json.dumps(final_move_json)}}}'
                "\nJSON_END"
            )

        # --- 3. STATE ---
        elif task_type == "STATE":
            target_r, target_c = None, None

            # Half of the time (when one exists), target the allowed square to improve
            # recognition of the subgrid the model must play in.
            if raw_allowed and len(raw_allowed) >= 2 and random.random() < 0.5:
                 target_r, target_c = raw_allowed[0], raw_allowed[1]

            # Otherwise prefer a subgrid that contains at least one piece over an empty one.
            if target_r is None:
                occupied_globals = list(set((x['global_row'], x['global_col']) for x in full_state))
                if occupied_globals:
                    target_r, target_c = random.choice(occupied_globals)

            # Start of the game: occupied_globals is empty, so pick any subgrid.
            if target_r is None:
                target_r, target_c = random.randint(0, 2), random.randint(0, 2)

            user_prompt = f"Examine Global Subgrid ({target_r}, {target_c}). Represent the 3x3 local grid state as a matrix of 'X', 'O', or '.' (Empty)."

            # 3x3 matrix of "X" / "O" / "." for the target subgrid.
            matrix = []
            for lr in range(3):
                row_list = []
                for lc in range(3):
                    occ = next((p['player'] for p in full_state if p['global_row'] == target_r and p['global_col'] == target_c and p['local_row'] == lr and p['local_col'] == lc), 0)
                    val = "X" if occ == 1 else "O" if occ == 2 else "."
                    row_list.append(val)
                matrix.append(row_list)

            assistant_content = f"JSON_START\n{{\"target_global\": [{target_r}, {target_c}], \"grid_matrix\": {json.dumps(matrix)}}}\nJSON_END"

        # --- 4. LEGALITY ---
        elif task_type == "LEGALITY":
            choice_roll = random.random()
            move = None

            # roll < 0.40 (if any exist): a cell inside a completed board, so the model
            # learns not to play in finished boards.
            if choice_roll < 0.40 and unplayable:
                dead_board = random.choice(unplayable)

                db_r = dead_board[0] if isinstance(dead_board, list) else dead_board['global_row']
                db_c = dead_board[1] if isinstance(dead_board, list) else dead_board['global_col']

                move = {
                    "global_row": db_r, "global_col": db_c, 
                    "local_row": random.randint(0, 2), "local_col": random.randint(0, 2)
                }

            # otherwise, roll < 0.80: an already-occupied cell picked at random
            elif choice_roll < 0.80 and full_state:
                p = random.choice(full_state)
                move = {
                    "global_row": p['global_row'], "global_col": p['global_col'], 
                    "local_row": p['local_row'], "local_col": p['local_col']
                }

            # otherwise: one of the listed legal moves
            elif legal_moves:
                move = random.choice(legal_moves)

            # fallback (e.g. start of a game, when the checks above fail): any random cell
            else:
                move = {
                    "global_row": random.randint(0, 2), "global_col": random.randint(0, 2), 
                    "local_row": random.randint(0, 2), "local_col": random.randint(0, 2)
                }

            # Ground-truth legality is defined purely as membership in legal_moves.
            is_legal = any(m['global_row'] == move['global_row'] and m['global_col'] == move['global_col'] and 
                           m['local_row'] == move['local_row'] and m['local_col'] == move['local_col'] for m in legal_moves)

            occ = next((p['player'] for p in full_state if p['global_row'] == move['global_row'] and 
                        p['global_col'] == move['global_col'] and p['local_row'] == move['local_row'] and 
                        p['local_col'] == move['local_col']), 0)
            sq_state = "X" if occ == 1 else "O" if occ == 2 else "Empty"

            is_dead_board = False
            if unplayable:
                for u in unplayable:
                    u_r = u[0] if isinstance(u, list) else u.get('global_row')
                    u_c = u[1] if isinstance(u, list) else u.get('global_col')
                    if u_r == move['global_row'] and u_c == move['global_col']:
                        is_dead_board = True
                        break

            # Explanation precedence for illegal moves: completed board, then occupied
            # cell, otherwise the move is assumed to be outside the allowed square.
            if is_legal:
                reason = "Square is Empty, board is playable, and within active constraint."
            else:
                if is_dead_board:
                    reason = f"Global Board ({move['global_row']},{move['global_col']}) is already completed."

                elif sq_state != "Empty":
                    reason = f"Square is occupied by {sq_state}."

                else:
                    reason = "Violates allowed square constraint."

            user_prompt = (
                f"Is it legal for Player {player_char} to play at Global({move['global_row']},{move['global_col']}), Local({move['local_row']},{move['local_col']})? "
                "Step 1: Inspect the square state. Step 2: Check allowed square constraint. Step 3: Verdict."
            )

            # sq_state is always "X", "O" or "Empty", so embedding it unescaped is safe.
            assistant_content = (
                "JSON_START\n"
                f'{{"step_1_square_state": "{sq_state}", '
                f'"step_2_allowed_square": {json.dumps(gt_active_json)}, ' 
                f'"is_legal": {json.dumps(is_legal)}, '
                f'"reason": {json.dumps(reason)}}}'
                "\nJSON_END"
            )

        # The image placeholder is placed after the text in the user turn.
        messages = [
            {"role": "system", "content": [{"type": "text", "text": system_content}]},
            {"role": "user", "content": [{"type": "text", "text": user_prompt}, {"type": "image"}]},
            {"role": "assistant", "content": [{"type": "text", "text": assistant_content}]}
        ]
        messages_list.append(messages)

    examples["messages"] = messages_list
    return examples

### Training & Finetuning

In [7]:
def get_columns_to_remove(raw_dataset, columns_needed_for_map):
    """List the columns of ``raw_dataset`` that may be dropped after formatting.

    A column survives when it appears in ``columns_needed_for_map`` or is the
    ``"messages"`` column generated by ``format_data``. Every other column name
    is returned, in the dataset's own column order.
    """
    removable = []
    for name in raw_dataset.column_names:
        if name == "messages" or name in columns_needed_for_map:
            continue
        removable.append(name)
    return removable

def _latest_checkpoint(output_dir):
    """Return the newest ``checkpoint-*`` folder inside ``output_dir``, or None if there is none."""
    import os
    from transformers.trainer_utils import get_last_checkpoint

    # get_last_checkpoint lists the directory, so guard against it not existing yet.
    if not os.path.isdir(output_dir):
        return None
    return get_last_checkpoint(output_dir)


def _format_split(raw_dataset, columns_needed_for_map):
    """Apply ``format_data`` to one raw split and drop the raw columns that are not kept."""
    columns_to_remove = get_columns_to_remove(raw_dataset, columns_needed_for_map)
    return raw_dataset.map(
        format_data,
        remove_columns = columns_to_remove,
        batched=True,
    ).filter(lambda x: len(x['messages']) > 0)


def run_fine_tuning(resume=True):
    """Load the base VLM, attach LoRA adapters, format both datasets and train.

    Uses ``DATASET_TRAIN_PATH``, ``DATASET_EVAL_PATH`` and ``OUTPUT_DIR_V2`` from
    the ``config`` star-import. The checkpoint with the lowest ``eval_loss`` is
    reloaded at the end (``load_best_model_at_end``). The trained model and the
    tokenizer are then saved to ``OUTPUT_DIR_V2``.

    Args:
        resume: when True (default), continue from the newest ``checkpoint-*``
            directory in ``OUTPUT_DIR_V2`` if one exists, and train from scratch
            otherwise. When False, existing checkpoints are ignored.
    """
    model, tokenizer, processor = load_vlm_model()

    print("Applying PEFT (QLoRA) layer...")
    model = FastLanguageModel.get_peft_model(
        model,
        r=128,
        target_modules="all-linear",
        lora_alpha=128,
        lora_dropout=0,
        bias="none",
        use_gradient_checkpointing="unsloth",
        random_state=42,
        embedding_layer_names=["visual"],
    )

    print("Loading training dataset...")
    raw_train_dataset = Dataset.from_parquet(DATASET_TRAIN_PATH)
    print(f"Dataset loaded with {len(raw_train_dataset)} examples.")

    print("Loading evaluation dataset...")
    raw_eval_dataset = Dataset.from_parquet(DATASET_EVAL_PATH)
    print(f"Dataset loaded with {len(raw_eval_dataset)} examples.")

    # Raw columns kept next to the generated "messages" column. datasets.map only
    # drops `remove_columns` after format_data has run, so this list controls which
    # columns survive into the formatted dataset, not which ones format_data sees.
    # (Previously misspelled "global state", which silently dropped global_state.)
    columns_needed_for_map = ["image", "player", "allowed_squares", "best_move",
                              "chain_of_thought", "legal_moves", "ascii_board",
                              "unplayable_boards", "global_state", "rotation_angle"]

    print("Formatting datasets for multimodal training...")
    train_dataset = _format_split(raw_train_dataset, columns_needed_for_map)
    eval_dataset = _format_split(raw_eval_dataset, columns_needed_for_map)

    print("Setting up training arguments...")
    training_arguments = TrainingArguments(
        per_device_train_batch_size = 4,
        per_device_eval_batch_size = 1,
        gradient_accumulation_steps = 4,
        warmup_ratio = 0.1,
        num_train_epochs = 4,

        learning_rate = 2e-4,
        max_grad_norm = 0.3,
        weight_decay = 0.05,

        fp16 = False,
        bf16 = True,
        lr_scheduler_type = "cosine",
        output_dir = OUTPUT_DIR_V2,
        optim = "paged_adamw_8bit",
        seed = 42,

        eval_strategy = "steps",
        eval_steps = 100,
        save_strategy = "steps",
        save_steps = 100,
        logging_steps = 5,

        load_best_model_at_end = True,
        metric_for_best_model = "eval_loss",
        save_total_limit = 5,
        report_to = "none",
        # The collator needs the raw "image" and "messages" columns.
        remove_unused_columns = False,
        # NOTE(review): get_peft_model above already enables Unsloth's gradient
        # checkpointing; confirm this flag does not replace its offloading variant.
        gradient_checkpointing = True,
    )

    print("Initializing Trainer and starting fine-tuning...")
    trainer = Trainer(
        model = model,
        processing_class = tokenizer,  # replaces the deprecated `tokenizer=` argument
        train_dataset = train_dataset,
        eval_dataset = eval_dataset,
        args = training_arguments,
        data_collator = MultimodalDataCollator,
    )

    # `resume_from_checkpoint=True` raises when OUTPUT_DIR_V2 has no checkpoint yet
    # (e.g. the first run), so resolve the path explicitly. None trains from scratch.
    checkpoint = _latest_checkpoint(OUTPUT_DIR_V2) if resume else None
    if checkpoint:
        print(f"Resuming from {checkpoint}")
    trainer.train(resume_from_checkpoint=checkpoint)

    print("\nTraining complete.")
    trainer.model.save_pretrained(OUTPUT_DIR_V2)
    # NOTE(review): the processor (image-preprocessing config) is not saved here;
    # inference code must load it from MODEL_NAME.
    tokenizer.save_pretrained(OUTPUT_DIR_V2)

In [8]:
# Entry point: runs the whole pipeline (load model, attach LoRA, format data, train, save).
# Inside a Jupyter kernel __name__ is "__main__", so executing this cell starts training.
if __name__ == "__main__":
    run_fine_tuning()

Loading Model: unsloth/Qwen3-VL-8B-Instruct-unsloth-bnb-4bit
==((====))==  Unsloth 2026.1.4: Fast Qwen3_Vl patching. Transformers: 4.57.6.
   \\   /|    NVIDIA A100 80GB PCIe MIG 1g.20gb. Num GPUs = 1. Max memory: 19.5 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.10.0+cu128. CUDA: 8.0. CUDA Toolkit: 12.8. Triton: 3.6.0
\        /    Bfloat16 = TRUE. FA [Xformers = None. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Model and Tokenizer loaded successfully
Applying PEFT (QLoRA) layer...
Loading training dataset...


Generating train split: 0 examples [00:00, ? examples/s]

Dataset loaded with 3203 examples.
Loading evaluation dataset...


Generating train split: 0 examples [00:00, ? examples/s]

Dataset loaded with 401 examples.
Formatting datasets for multimodal training...


Map:   0%|          | 0/3203 [00:00<?, ? examples/s]

Filter:   0%|          | 0/3203 [00:00<?, ? examples/s]

Map:   0%|          | 0/401 [00:00<?, ? examples/s]

Filter:   0%|          | 0/401 [00:00<?, ? examples/s]

Setting up training arguments...
Initializing Trainer and starting fine-tuning...


  trainer = Trainer(
==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 3,203 | Num Epochs = 4 | Total steps = 804
O^O/ \_/ \    Batch size per device = 4 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (4 x 4 x 1) = 16
 "-____-"     Trainable parameters = 410,775,552 of 9,177,899,248 (4.48% trained)


Unsloth: Will smartly offload gradients to save VRAM!


Step,Training Loss,Validation Loss
200,0.3009,0.082992
300,0.2576,0.08221
400,0.1562,0.074895
500,0.0837,0.076684
600,0.0892,0.076509
700,0.0527,0.09078
800,0.0523,0.091683


Unsloth: Not an error, but Qwen3VLForConditionalGeneration does not accept `num_items_in_batch`.
Using gradient accumulation will be very slightly less accurate.
Read more on gradient accumulation issues here: https://unsloth.ai/blog/gradient



Training complete.
