In [1]:
pip install fastapi uvicorn unsloth transformers pillow torch numpy pydantic nest_asyncio asyncio

Looking in indexes: https://nexus.iisys.de/repository/ki-awz-pypi-group/simple, https://pypi.org/simple
Note: you may need to restart the kernel to use updated packages.


In [2]:
from fastapi import FastAPI, HTTPException
from unsloth import FastLanguageModel
from transformers import AutoProcessor
from pydantic import BaseModel
from PIL import Image as PILImage
import torch, io, base64, json, re, nest_asyncio, uvicorn, asyncio
import numpy as np
from config import *

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
Unsloth: Your Flash Attention 2 installation seems to be broken?
A possible explanation is you have a new CUDA version which isn't
yet compatible with FA2? Please file a ticket to Unsloth or FA2.
We shall now use Xformers instead, which does not have any performance hits!
We found this negligible impact by benchmarking on 1x A100.
🦥 Unsloth Zoo will now patch everything to make training faster!


In [3]:
# Presumably applied so asyncio/uvicorn can run inside the notebook kernel's
# already-running event loop.
nest_asyncio.apply()
api = FastAPI()

# Load the base vision-language model (MODEL_NAME comes from the `config` star-import)
# quantized to 4 bit.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = MODEL_NAME,
    max_seq_length = 2048,
    dtype = None,  # None = let Unsloth auto-detect the dtype
    load_in_4bit = True,
    trust_remote_code = True,
)

# The processor builds the multimodal (image + text) chat inputs in predict_move.
processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)

# This notebook only runs inference. `use_cache = False` is a training-time setting
# (commonly paired with gradient checkpointing); for generation the KV cache should stay on.
model.config.use_cache = True

# Attach the fine-tuned adapter (OUTPUT_DIR_V4 from `config`) and switch Unsloth to inference mode.
model.load_adapter(OUTPUT_DIR_V4)
FastLanguageModel.for_inference(model)

class MoveRequest(BaseModel):
    """Request body for ``POST /predict_move``."""

    # Base64-encoded board image; decoded with base64.b64decode and opened with PIL.
    image_base64: str
    # Player to move; interpolated verbatim into the user prompt ("X=Player 1, O=Player 2").
    player_turn: str
    # Occupied cells. Each entry is read as a dict with keys 'global_row', 'global_col',
    # 'local_row', 'local_col' and 'player' (rendered as 0 -> '.', 1 -> 'X', 2 -> 'O').
    # Pydantic only checks that this is a list; entries are not validated here.
    global_state: list

def check_win(grid):
    """Evaluate a single 3x3 tic-tac-toe board.

    Args:
        grid: 3x3 nested sequence of cell values (converted with ``np.array``);
            0 marks an empty cell.

    Returns:
        The value filling a complete row, column or diagonal. Lines are checked
        in the order row 0, column 0, row 1, column 1, row 2, column 2, main
        diagonal, anti-diagonal, and the first match wins. If no line is
        complete, returns -1 when no cell is 0 (full board) and 0 otherwise.
        A winner is returned as a numpy scalar taken from the array.
    """
    board = np.array(grid)

    candidate_lines = []
    for k in range(3):
        candidate_lines.append(board[k, :])   # row k
        candidate_lines.append(board[:, k])   # column k
    candidate_lines.append(board.diagonal())              # top-left -> bottom-right
    candidate_lines.append(np.fliplr(board).diagonal())   # top-right -> bottom-left

    for line in candidate_lines:
        owner = line[0]
        if owner != 0 and (line == owner).all():
            return owner

    has_empty_cell = (board == 0).any()
    return 0 if has_empty_cell else -1

def reconstruct_board_matrix(global_state_list):
    """Build the full 4-level board and the status of each sub-board.

    Args:
        global_state_list: Iterable of cell dicts with keys 'global_row',
            'global_col', 'local_row', 'local_col' (each an int in 0..2) and
            'player' (stored as-is; 0 is treated as empty by check_win).

    Returns:
        ``(board_matrix, global_status)`` where ``board_matrix[g_r][g_c][l_r][l_c]``
        holds the player value of that cell (0 if not listed), and
        ``global_status[g_r][g_c]`` is ``check_win`` of the corresponding sub-board.

    Raises:
        ValueError: if any coordinate is not an int in 0..2.
        KeyError: if a cell dict lacks one of the required keys.
    """
    board_matrix = [[[[0] * 3 for _ in range(3)] for _ in range(3)] for _ in range(3)]

    for cell in global_state_list:
        coords = (cell['global_row'], cell['global_col'], cell['local_row'], cell['local_col'])
        # Coordinates come from client JSON. Negative values would otherwise wrap around
        # silently (e.g. -1 -> index 2) and place the piece on the wrong cell.
        if not all(isinstance(v, int) and 0 <= v <= 2 for v in coords):
            raise ValueError(f"Cell coordinates must be integers in 0..2, got {cell!r}")
        g_r, g_c, l_r, l_c = coords
        board_matrix[g_r][g_c][l_r][l_c] = cell['player']

    global_status = [[check_win(board_matrix[g_r][g_c]) for g_c in range(3)] for g_r in range(3)]
    return board_matrix, global_status

def get_unplayable_boards(global_status):
    """List the sub-boards that can no longer receive moves.

    Args:
        global_status: 3x3 grid of sub-board statuses (as produced by
            check_win: 0 while still open, otherwise a winner value or -1).

    Returns:
        A list of ``{"global_row": r, "global_col": c}`` dicts, in row-major
        order, for every status that is non-zero.
    """
    return [
        {"global_row": row, "global_col": col}
        for row in range(3)
        for col in range(3)
        if global_status[row][col] != 0
    ]

def render_ascii_board(global_state_list):
    """Render all nine sub-boards as labelled ASCII grids.

    Args:
        global_state_list: Iterable of cell dicts with 'global_row',
            'global_col', 'local_row', 'local_col' and 'player' keys.
            Player 1 is drawn as 'X', 2 as 'O'; 0, missing or unknown values as '.'.

    Returns:
        Nine sections in row-major sub-board order separated by blank lines.
        Each section has a header, a column index line, a rule line, and three
        ``"<row> | a b c"`` lines.
    """
    glyph = {0: '.', 1: 'X', 2: 'O'}

    occupied = {}
    for cell in global_state_list:
        key = (cell['global_row'], cell['global_col'], cell['local_row'], cell['local_col'])
        occupied[key] = glyph.get(cell['player'], '.')

    def draw(gr, gc):
        # One sub-board: header, column labels, rule, then one line per local row.
        lines = [f"=== Global Board [{gr}, {gc}] ===", "    0 1 2", "   -------"]
        lines.extend(
            f"{lr} | " + " ".join(occupied.get((gr, gc, lr, lc), '.') for lc in range(3))
            for lr in range(3)
        )
        return "\n".join(lines)

    return "\n\n".join(draw(gr, gc) for gr in range(3) for gc in range(3))

def parse_move_from_text(text):
    """Extract the first JSON object embedded in a model reply.

    Every ``{`` in ``text`` is tried as the start of a JSON document, and the
    first one that decodes to a dict is returned. If the verbatim text yields
    nothing, the same scan is repeated with single quotes replaced by double
    quotes. This tolerates Python-dict-style replies such as
    ``{'global_row': 0, ...}``.

    Args:
        text: Raw generated text; may be empty or None.

    Returns:
        The decoded dict, or None if no JSON object could be parsed.
    """
    if not text:
        return None

    decoder = json.JSONDecoder()
    # Try the verbatim text first so valid JSON containing apostrophes (e.g. "it's")
    # is not corrupted by the quote substitution.
    for candidate in (text, text.replace("'", '"')):
        for opening in re.finditer(r'\{', candidate):
            try:
                # raw_decode stops at the end of the first complete document, so
                # trailing prose or a second object does not break parsing.
                obj, _ = decoder.raw_decode(candidate[opening.start():])
            except json.JSONDecodeError:
                continue
            if isinstance(obj, dict):
                return obj
    return None

def format_squares_to_str(squares_list):
    """Format sub-board coordinates as ``"[(r, c), (r, c), ...]"`` for the prompt.

    Args:
        squares_list: Iterable of dicts carrying 'global_row' and 'global_col';
            entries that are not dicts or lack either key are skipped. May be
            empty or None.

    Returns:
        The bracketed, comma-separated list of ``(row, col)`` pairs, or ``"[]"``.
    """
    if not squares_list:
        return "[]"

    pairs = (
        f"({square['global_row']}, {square['global_col']})"
        for square in squares_list
        if isinstance(square, dict) and square.keys() >= {'global_row', 'global_col'}
    )
    return f"[{', '.join(pairs)}]"

@api.post("/predict_move")
async def predict_move(request: MoveRequest):
    """Ask the fine-tuned vision-language model for the next move.

    Decodes the board image and derives textual context (unplayable
    sub-boards and an ASCII rendering) from ``request.global_state``. It then
    builds a system + user chat prompt, runs greedy generation, and returns
    the first JSON object found in the model's reply.

    Args:
        request: Board image (base64), player to move and occupied cells.

    Returns:
        The dict parsed from the model reply. Its keys and values are whatever
        the model produced; they are not checked for legality here.

    Raises:
        HTTPException: 400 if ``image_base64`` or ``global_state`` cannot be
            decoded, 422 if the reply contains no parsable JSON object, 500 for
            any other failure.
    """
    # NOTE(review): model.generate below is synchronous, so this coroutine blocks the
    # event loop for the whole generation and requests are effectively serialized.
    try:
        # Malformed client input is reported as 400 rather than a generic 500.
        # binascii.Error (bad base64) subclasses ValueError; PIL's
        # UnidentifiedImageError (bytes are not an image) subclasses OSError.
        try:
            img_bytes = base64.b64decode(request.image_base64)
            image = PILImage.open(io.BytesIO(img_bytes)).convert("RGB")
        except (ValueError, OSError) as e:
            raise HTTPException(status_code=400, detail=f"Invalid image_base64: {e}") from e

        try:
            _, global_status = reconstruct_board_matrix(request.global_state)
            ascii_board = render_ascii_board(request.global_state)
        except (KeyError, TypeError, IndexError, ValueError) as e:
            raise HTTPException(status_code=400, detail=f"Invalid global_state: {e!r}") from e
        unplayable_boards = get_unplayable_boards(global_status)
        unplayable_list_str = format_squares_to_str(unplayable_boards)

        # NOTE(review): the prompt text is kept byte-identical because the adapter was
        # presumably fine-tuned on it. There is no space/separator between
        # "chosen move." and the JSON template; change it only together with the
        # training prompt.
        system_content = (
            f"You are an expert Ultimate Tic-Tac-Toe player. "
            f"Your goal is to identify the optimal, legal move based on the provided image and context. "
            f"The final output must be **ONLY** a raw JSON object containing the chosen move."
            f"{{\"global_row\": r, \"global_col\": c, \"local_row\": lr, \"local_col\": lc}}."
        )

        user_prompt_text = (
            f"Player: {request.player_turn} (X=Player 1, O=Player 2)\n"
            f"Analyze the board state in the image and determine the optimal move.\n\n"
            f"--- BOARD CONTEXT ---\n"
            f"**Allowed/Active Board:** The global board highlighted in **BRIGHT GREEN** in the image is the current active board constraint. If this board is already won/tied, you must select any other available board (Free Play).\n"
            f"**Unplayable Boards:** The following Global Boards are already WON or TIED and cannot be played: {unplayable_list_str}\n\n"
            f"--- ASCII VISUALIZATION ---\n"
            f"Use this labeled diagram to cross-reference the image coordinates (0, 1, 2) with the piece locations and board status:\n"
            f"{ascii_board}\n\n"
            f"CRITICAL RULE: The target local cell (local_row, local_col) MUST be **EMPTY** on the global board (global_row, global_col).\n"
            f"CRITICAL RULE: All output coordinates (global_row, global_col, local_row, local_col) MUST be **0, 1 or 2**."
        )

        messages = [
            {"role": "system", "content": [{"type": "text", "text": system_content}]},
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": user_prompt_text},
                ],
            },
        ]

        inputs = processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(model.device)

        with torch.no_grad():
            # Greedy decoding. `temperature` is omitted: it has no effect when
            # do_sample=False and only triggers a transformers warning.
            outputs = model.generate(**inputs, max_new_tokens=128, do_sample=False)

        # generate() returns prompt + completion. Decode only the newly generated tokens
        # instead of splitting the full transcript on the word "assistant".
        prompt_len = inputs["input_ids"].shape[1]
        response_text = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()

        move = parse_move_from_text(response_text)
        if not move:
            raise HTTPException(status_code=422, detail=f"Model failed: {response_text}")
        return move

    except HTTPException:
        # Propagate deliberate 4xx responses unchanged. Without this, the generic
        # handler below would turn them into 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

Are you certain you want to do remote code execution?
==((====))==  Unsloth 2026.1.3: Fast Qwen3_Vl patching. Transformers: 4.57.3.
   \\   /|    NVIDIA A100 80GB PCIe MIG 1g.20gb. Num GPUs = 1. Max memory: 19.5 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.9.1+cu128. CUDA: 8.0. CUDA Toolkit: 12.8. Triton: 3.5.1
\        /    Bfloat16 = TRUE. FA [Xformers = None. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: Qwen3_Vl does not support SDPA - switching to fast eager.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
if __name__ == "__main__":
    # NOTE(review): 0.0.0.0 exposes the (unauthenticated) endpoint on every network interface.
    config = uvicorn.Config(
        api,
        host="0.0.0.0",
        port=8000,
        log_level="info",
    )
    server = uvicorn.Server(config)

    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # No running loop (executed as a plain script): serve in the foreground until shutdown.
        asyncio.run(server.serve())
    else:
        # Inside Jupyter the kernel's loop is already running, so schedule the server on it.
        # Keep a strong reference to the task: the event loop holds only weak references,
        # and an unreferenced task may be garbage-collected mid-execution.
        server_task = loop.create_task(server.serve())

INFO:     Started server process [7832]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
