In [1]:
# Environment setup: remove any preinstalled flash_attn / xformers builds, then install
# unsloth and the other training dependencies (all unpinned). Per the captured log below,
# installing unsloth pulls xformers back in as one of its requirements.
!pip uninstall -y flash_attn xformers
!pip install unsloth transformers accelerate datasets bitsandbytes pandas pillow packaging ninja

Found existing installation: flash-attn 2.3.6
Uninstalling flash-attn-2.3.6:
  Successfully uninstalled flash-attn-2.3.6
Found existing installation: xformers 0.0.33.post2
Uninstalling xformers-0.0.33.post2:
  Successfully uninstalled xformers-0.0.33.post2
Looking in indexes: https://nexus.iisys.de/repository/ki-awz-pypi-group/simple, https://pypi.org/simple
Collecting xformers>=0.0.27.post2 (from unsloth)
  Using cached xformers-0.0.33.post2-cp39-abi3-manylinux_2_28_x86_64.whl.metadata (1.2 kB)
Using cached xformers-0.0.33.post2-cp39-abi3-manylinux_2_28_x86_64.whl (122.9 MB)
Installing collected packages: xformers
Successfully installed xformers-0.0.33.post2


In [1]:
import torch
import os, json
import pandas as pd
from unsloth import FastLanguageModel
from transformers import TrainingArguments, Trainer, AutoProcessor
from datasets import Dataset, load_from_disk
from PIL import Image as PILImage
from config import *
import gc

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!


In [2]:
# Module-level handles shared between load_vlm_model() (which fills them in) and
# MultimodalDataCollator() (which reads them). All start out unset.
model = tokenizer = processor = None

In [3]:
def load_vlm_model():
    """Load the model, its tokenizer and its multimodal processor into the module globals.

    The model/tokenizer pair comes from ``FastLanguageModel.from_pretrained`` (4-bit
    loading enabled, sequence length and dtype taken from ``MAX_SEQ_LENGTH`` / ``DTYPE``);
    the processor comes from ``AutoProcessor.from_pretrained`` for the same ``MODEL_NAME``.
    The tokenizer's pad token id is copied from the processor's tokenizer and its
    padding side is set to ``"right"``.

    Returns:
        tuple: ``(model, tokenizer, processor)`` — the same objects stored in the globals.
    """
    global model, tokenizer, processor

    print(f"Loading Model: {MODEL_NAME}")

    pretrained_kwargs = dict(
        model_name=MODEL_NAME,
        max_seq_length=MAX_SEQ_LENGTH,
        dtype=DTYPE,
        load_in_4bit=True,
    )
    model, tokenizer = FastLanguageModel.from_pretrained(**pretrained_kwargs)
    processor = AutoProcessor.from_pretrained(MODEL_NAME)

    # Keep the tokenizer's padding configuration in step with the processor's tokenizer.
    tokenizer.pad_token_id = processor.tokenizer.pad_token_id
    tokenizer.padding_side = "right"

    print("Model and Tokenizer loaded successfully")
    return model, tokenizer, processor

In [4]:
def _prompt_token_count(conversation, image):
    """Return how many tokens the prompt part of one example occupies, or ``None``.

    The prompt is the first two turns of ``conversation`` rendered through the
    tokenizer's chat template with the generation prompt appended. It is then run
    through the multimodal processor together with ``image`` so that whatever tokens
    the processor emits for the image are included in the count.

    ``None`` is returned when the processor call raises or yields a 0-dim result.
    """
    prompt_text = tokenizer.apply_chat_template(
        conversation[:2],
        tokenize=False,
        # Include the assistant turn opener so it is masked together with the prompt.
        add_generation_prompt=True,
    )
    try:
        prompt_ids = processor(
            images=[image],
            text=prompt_text,
            return_tensors="pt",
            padding=False,  # single example, nothing to pad against
        ).input_ids.squeeze()
    except Exception as e:
        print(f"Error processing single prompt for masking length: {e}")
        return None

    if prompt_ids.dim() == 0:
        return None
    return prompt_ids.shape[0]


def MultimodalDataCollator(data):
    """
    Data collator responsible for batching, tokenization, image processing,
    and applying label masking for multimodal training.

    1. Extracts images (converted to RGB) and conversation messages.
    2. Renders each full conversation into a single prompt string using the
       tokenizer's chat template.
    3. Runs the multimodal processor over images + prompt strings to obtain a
       padded batch.
    4. Builds ``labels`` from ``input_ids`` and sets to -100 (a) the prompt part
       of every example (first two turns plus the generation prompt) and
       (b) every padding position, so only the remaining (assistant) tokens
       contribute to the loss.

    The prompt mask is positioned relative to the first non-padding token of each
    row, so it is correct for both right- and left-padded batches.

    Args:
        data: list of examples, each a mapping with an ``"image"`` (object exposing
            ``.convert("RGB")``) and ``"messages"`` (conversation list).

    Returns:
        The processor output with an added ``"labels"`` tensor.

    Raises:
        ValueError: if the global model / processor / tokenizer are not loaded.
        Exception: any error raised by the processor while building the batch is
            logged and re-raised.
    """
    global model, processor, tokenizer
    if model is None or processor is None or tokenizer is None:
        raise ValueError("Model, processor, tokenizer must be loaded globally before using this.")

    # 1. Extract images and messages
    images = [item["image"].convert("RGB") for item in data]
    messages = [item["messages"] for item in data]

    # 2. Raw prompt strings for the FULL conversation
    text_prompts = [
        tokenizer.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=False,
        )
        for conversation in messages
    ]

    # 3. Process images and tokenize text for the whole batch
    try:
        inputs = processor(
            images=images,
            text=text_prompts,
            return_tensors="pt",
            padding=True,
            max_length=processor.tokenizer.model_max_length,
        )
    except Exception as e:
        print(f"Unable to create batch. Actual error: {e}")
        raise

    # 4. Labels start as a copy of the input ids; masked positions become -100,
    #    which the loss ignores.
    labels = inputs["input_ids"].clone()
    attention_mask = inputs.get("attention_mask")

    for i, conversation in enumerate(messages):
        cutoff_length = _prompt_token_count(conversation, images[i])
        if cutoff_length is None:
            # Best-effort fallback kept from the original behaviour: the sample is
            # trained on unmasked (prompt tokens included) rather than dropped.
            print(f"Warning: Masking calculation failed for sample {i}. Using safe, non-masked sequence.")
            cutoff_length = 0

        # Number of padding tokens in front of the real sequence (0 when the batch
        # is right-padded). Without this offset a left-padded row would have its
        # prompt mask applied to the wrong positions.
        if attention_mask is not None:
            leading_pad = int((attention_mask[i].cumsum(dim=0) == 0).sum())
        else:
            leading_pad = 0

        labels[i, : leading_pad + cutoff_length] = -100

    # Padding positions must never contribute to the loss.
    if attention_mask is not None:
        labels[attention_mask == 0] = -100

    inputs["labels"] = labels
    return inputs

In [5]:
def format_data_for_training(examples):
    """Attach a chat-style ``messages`` column to a batched ``examples`` mapping.

    For every row whose ``image`` is a PIL image, a three-turn conversation is built:
    a fixed system instruction, a user turn (image placeholder + text prompt describing
    the player, unplayable boards, active-board constraint, ASCII board and legal moves),
    and an assistant turn holding the JSON-encoded best move, preceded by a
    ``<think>...</think>`` section when a chain of thought is present.
    Rows without a PIL image get an empty list.

    Args:
        examples: mapping of column name -> list of values for the batch.

    Returns:
        ``{"messages": []}`` when the batch has no images; otherwise the same
        ``examples`` mapping with ``"messages"`` set.
    """
    images = examples.get("image", [])
    row_count = len(images)
    if row_count == 0:
        return {"messages": []}

    # --- System prompt (defines the model's role); identical for every row ---
    system_content = (
        "You are an expert Ultimate Tic-Tac-Toe player. "
        "Your task is to determine the optimal move for the current player. "
        "Your output **MUST** be a single, raw JSON object containing the chosen move."
    )

    conversations = []
    for idx in range(row_count):
        # Rows without a usable image yield an empty conversation.
        if not isinstance(images[idx], PILImage.Image):
            conversations.append([])
            continue

        active_squares = examples["allowed_squares"][idx]
        blocked_boards = examples["unplayable_boards"][idx]
        best_move = examples["best_move"][idx]
        current_player = examples["player"][idx]
        reasoning = examples["chain_of_thought"][idx]
        moves = examples["legal_moves"][idx]
        board_ascii = examples["ascii_board"][idx]

        has_active_board = isinstance(active_squares, (list, tuple)) and len(active_squares) == 2
        if has_active_board:
            active_row, active_col = active_squares
            active_board_desc = f"The active board coordinates are Global[{active_row}, {active_col}] (highlighted in GREEN)."
        else:
            active_board_desc = "The player has a FREE MOVE, and can play in any board that is NOT UNPLAYABLE."

        blocked_json = json.dumps(blocked_boards)
        move_json = json.dumps(best_move)
        moves_json = json.dumps(moves, indent=2)

        # --- User prompt (provides all game context) ---
        prompt_parts = [
            f"Player: {current_player} (X=Player 1, O=Player 2)\n",
            "Analyze the board state in the image and determine the optimal move.\n\n",
            "--- CRITICAL UTTT RULES ---\n",
            "* **UNPLAYABLE BOARDS:** The following Global Boards are WON or TIED and **STRICTLY FORBIDDEN** for play:\n",
            f"    Unplayable Boards (Global R, C): {blocked_json}\n",
            f"* **ACTIVE BOARD CONSTRAINT:** {active_board_desc}\n",
            "    -   If the highlighted board is UNPLAYABLE, the player gets a **FREE MOVE**.\n",
            "    -   Otherwise, the move **MUST** be in the highlighted board.\n",
            "\n--- BOARD VISUALIZATION ---\n",
            "Use the image and the following ASCII diagram for coordinate reference:\n",
            f"{board_ascii}\n\n",
            "CRITICAL RULE: All coordinates (global_row, global_col, local_row, local_col) MUST be **0, 1 or 2**.\n",
            f"The set of all legal moves is provided for reference:\n{moves_json}\n",
            "\n*** Output the optimal move as a single, raw JSON object. ***",
        ]
        user_prompt = "".join(prompt_parts)

        # Assistant target: optional reasoning block, then the move as JSON.
        if reasoning:
            assistant_content = f"<think>{reasoning}</think>\n{move_json}\n"
        else:
            assistant_content = f"{move_json}"

        system_turn = {"role": "system", "content": [{"type": "text", "text": system_content}]}
        user_turn = {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": user_prompt},
            ],
        }
        assistant_turn = {"role": "assistant", "content": [{"type": "text", "text": assistant_content}]}
        conversations.append([system_turn, user_turn, assistant_turn])

    examples["messages"] = conversations
    return examples

In [6]:
def run_fine_tuning():
    """Run the full QLoRA fine-tuning pipeline and save the resulting adapter.

    Steps:
      1. Load model / tokenizer / processor via ``load_vlm_model`` (which also
         populates the module globals read by ``MultimodalDataCollator``).
      2. Wrap the model with LoRA adapters on the attention projections.
      3. Load the train and evaluation parquet datasets; on a load error the
         message is printed and the function returns ``None`` without training.
      4. Build the ``messages`` column with ``format_data_for_training`` and drop
         rows whose conversation is empty.
      5. Train with ``transformers.Trainer`` (evaluating and checkpointing every
         epoch, reloading the checkpoint with the lowest ``eval_loss`` at the end).
      6. Save the trained model and the tokenizer to ``OUTPUT_DIR_V3``.
    """
    model, tokenizer, _processor = load_vlm_model()

    print("Applying PEFT (QLoRA) layer...")
    model = FastLanguageModel.get_peft_model(
        model,
        r = 32,
        target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"],
        lora_alpha = 32,
        lora_dropout = 0,
        bias = "none",
        use_gradient_checkpointing = "unsloth",
        random_state = 42,
        embedding_layer_names = ["vision_tower.image_projection"]
    )

    print("Loading training dataset...")
    try:
        raw_train_dataset = Dataset.from_parquet(DATASET_TRAIN_PATH)
        print(f"Dataset loaded with {len(raw_train_dataset)} examples.")
    except Exception as e:
        print(f"Error loading dataset: {e}")
        return

    print("Loading evaluation dataset...")
    try:
        raw_eval_dataset = Dataset.from_parquet(DATASET_EVALUATION_PATH)
        print(f"Dataset loaded with {len(raw_eval_dataset)} examples.")
    except Exception as e:
        print(f"Error loading dataset: {e}")
        return

    print("Formatting datasets for multimodal training...")

    columns_needed_for_map = ["image", "player", "allowed_squares", "best_move",
                              "chain_of_thought", "legal_moves", "ascii_board",
                              "unplayable_boards"]

    def _format_split(raw_dataset):
        """Add ``messages`` to one split and drop rows with an empty conversation."""
        # The removable columns are derived from THIS split's schema: reusing the
        # train split's list for the eval split breaks when the two differ.
        columns_to_remove = [
            col for col in raw_dataset.column_names
            if col not in columns_needed_for_map and col != "messages"
        ]
        return raw_dataset.map(
            format_data_for_training,
            remove_columns = columns_to_remove,
            batched = True,
        ).filter(lambda x: len(x['messages']) > 0)

    train_dataset = _format_split(raw_train_dataset)
    eval_dataset = _format_split(raw_eval_dataset)

    print("VRAM Cleanup: Forcing garbage collection before training.")
    del raw_train_dataset, raw_eval_dataset
    gc.collect() # Trigger Python garbage collection
    if torch.cuda.is_available():
        torch.cuda.empty_cache() # Clear PyTorch's VRAM cache
    print("Cleanup complete. Starting Trainer initialization.")

    print("Setting up training arguments...")
    training_arguments = TrainingArguments(
        per_device_train_batch_size = 1,
        per_device_eval_batch_size = 1,
        gradient_accumulation_steps = 8,
        warmup_steps = 5,
        num_train_epochs = 10,
        learning_rate = 5e-5,
        fp16 = False,
        bf16 = True,
        output_dir = OUTPUT_DIR_V3,
        optim = "paged_adamw_8bit",
        seed = 42,
        eval_strategy = "epoch",
        eval_steps = 50,   # NOTE(review): presumably unused while eval_strategy is "epoch" — confirm
        logging_steps = 5,
        load_best_model_at_end = True,
        metric_for_best_model = "eval_loss",
        greater_is_better = False,
        save_strategy = "epoch",
        save_steps = 100,  # NOTE(review): presumably unused while save_strategy is "epoch" — confirm
        report_to = "none",
        # The collator needs the raw "image" / "messages" columns, so the Trainer
        # must not prune columns that the model's forward() does not accept.
        remove_unused_columns = False,
        weight_decay = 0.01,
    )

    print("Initializing Trainer and starting fine-tuning...")
    trainer = Trainer(
        model = model,
        # `tokenizer=` is the deprecated spelling of this argument.
        processing_class = tokenizer,
        train_dataset = train_dataset,
        eval_dataset = eval_dataset,
        args = training_arguments,
        data_collator = MultimodalDataCollator,
    )

    trainer.train()

    print("\nTraining complete.")
    trainer.model.save_pretrained(OUTPUT_DIR_V3)
    tokenizer.save_pretrained(OUTPUT_DIR_V3)

In [7]:
# Entry point: fine-tuning is only attempted when a CUDA device is available.
if __name__ == "__main__":
    if torch.cuda.is_available():
        run_fine_tuning()
    else:
        print("Error: CUDA not detected!")

Loading Model: unsloth/Qwen3-VL-8B-Instruct-unsloth-bnb-4bit
==((====))==  Unsloth 2025.11.6: Fast Qwen3_Vl patching. Transformers: 4.57.2.
   \\   /|    NVIDIA A100-SXM4-40GB MIG 4g.20gb. Num GPUs = 1. Max memory: 19.625 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.9.1+cu128. CUDA: 8.0. CUDA Toolkit: 12.8. Triton: 3.5.1
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.33.post2. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Model and Tokenizer loaded successfully
Applying PEFT (QLoRA) layer...
Loading training dataset...
Dataset loaded with 800 examples.
Loading evaluation dataset...
Dataset loaded with 101 examples.
Formatting datasets for multimodal training...
VRAM Cleanup: Forcing garbage collection before training.
Cleanup complete. Starting Trainer initialization.
Setting up training arguments...
Initializing Trainer and starting fine-tuning...


  trainer = Trainer(
==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 800 | Num Epochs = 10 | Total steps = 1,000
O^O/ \_/ \    Batch size per device = 1 | Gradient accumulation steps = 8
\        /    Data Parallel GPUs = 1 | Total batch size (1 x 8 x 1) = 8
 "-____-"     Trainable parameters = 30,670,848 of 8,797,794,544 (0.35% trained)


Unsloth: Will smartly offload gradients to save VRAM!


Epoch,Training Loss,Validation Loss
1,1.1012,0.959011
2,0.901,0.876762
3,0.8455,0.838616
4,0.8059,0.818492
5,0.7376,0.802874
6,0.7099,0.789876
7,0.7357,0.789669
8,0.5782,0.794076
9,0.6416,0.794031
10,0.6452,0.799393


Unsloth: Not an error, but Qwen3VLForConditionalGeneration does not accept `num_items_in_batch`.
Using gradient accumulation will be very slightly less accurate.
Read more on gradient accumulation issues here: https://unsloth.ai/blog/gradient



Training complete.
