<a href="https://colab.research.google.com/github/arrshad/llama3-meets-Tic-Tac-Toe/blob/main/llama3_meets_Tic_Tac_Toe.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install together langchain  langchain_openai ipywidgets==7.7.1



In [None]:
import os
import re

from together import Together
from langchain.chains import LLMChain
from langchain_openai import ChatOpenAI
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.prompts import ChatPromptTemplate


# Llama 3 70B chat model reached through Together's OpenAI-compatible endpoint.
# os.environ[...] raises KeyError right here if TOGETHER_API_KEY is not set.
model_together = ChatOpenAI(
    model="meta-llama/Meta-Llama-3-70B-Instruct-Turbo",
    base_url="https://api.together.xyz/v1",
    api_key=os.environ["TOGETHER_API_KEY"],
)

In [None]:
def ask_llm(board: str, x_positions: list, o_positions: list) -> int:
    """Ask the LLM for player O's next move and return it as a board index.

    A player prompt proposes a move. A separate "judge" prompt reviews it. If the
    judge's verdict is INCORRECT, the player is asked once more and is told which
    move was rejected and why.

    Args:
        board: Text rendering of the grid, embedded verbatim in the prompt.
        x_positions: 1-based (row, column) pairs already taken by X.
        o_positions: 1-based (row, column) pairs already taken by O.

    Returns:
        Row-major index 0-8 of an empty cell. The final reply is used if it
        names a legal empty cell, otherwise the initial reply. If neither does,
        a move is chosen without the LLM: complete O's line, else block X's
        line, else center, corners, sides. Callers therefore never receive an
        occupied or out-of-range index.

    Raises:
        ValueError: If x_positions and o_positions already cover all 9 cells.
    """
    code_prompt = ChatPromptTemplate.from_messages([
        ("system", """
  You are an expert Tic-tac-toe player. The game is played on a 3x3 grid. You are the second player (O), and your opponent is X.

  CRITICAL: Follow this algorithm EXACTLY:

  1. Analyze the board state:
    - Create a 3x3 grid representation of the board based on the given moves.
    - Mark 'X' for opponent's moves, 'O' for your moves, and '.' for empty cells.

  2. Check for a winning move:
    - Look for any row, column, or diagonal with TWO 'O' and ONE empty cell.
    - If found, place your 'O' in that empty cell to win. Return this move immediately.

  3. Check for immediate threats:
    - Look for any row, column, or diagonal with TWO 'X' and ONE empty cell.
    - If found, you MUST place your 'O' in that empty cell. Return this move immediately.

  4. If there is no winning move and no immediate threat:
    - Choose the center (2,2) if it's empty.
    - If center is taken, choose a corner (1,1), (1,3), (3,1), or (3,3) if available.
    - If no corners are available, choose any empty side (1,2), (2,1), (2,3), or (3,2).

  Your response MUST be in the format: "response: row,column"
  Rows are numbered 1 (top) to 3 (bottom). Columns are numbered 1 (left) to 3 (right).

  Example:
  Given board:
  .|X|.
  .|X|.
  O|.|.

  Correct response: "response: 3,2" (blocking the vertical threat)

  Analyze the provided game state carefully and follow the algorithm step by step.
      """),
        ("human", "{question}"),
    ])

    judge_prompt = ChatPromptTemplate.from_messages([
        ("system", """
    You are an impartial Tic-tac-toe judge. Your task is to evaluate a move made by player O (second player) and determine if it's the best possible move.

    CRITICAL: Follow these steps to evaluate the move:

    1. Analyze the current board state.
    2. Check if the proposed move by player O is legal (in an empty cell).
    3. Determine if O had a winning move (two Os and an empty cell in a line) that it should have taken.
    4. Otherwise, determine if there was an immediate threat (two Xs in a line) that O should have blocked.
    5. If there was neither, evaluate if the move was strategic (center, corner, or side in that order of preference).

    Respond with one of the following:
    - "CORRECT" if the move was the best possible move.
    - "INCORRECT" if there was a better move available, especially if it missed a win or failed to block an immediate threat.

    Start your reply with the verdict word, then provide a brief explanation of your judgment.
        """),
        ("human", "{question}"),
    ])

    # One chain per role; both use the same underlying model.
    player_chain = LLMChain(llm=model_together, prompt=code_prompt)
    judge_chain = LLMChain(llm=model_together, prompt=judge_prompt)

    question = f"""
    **The moves by the first player (marked by X): {x_positions}. The moves by the second player (marked by O): {o_positions}. \
    You are a master at the Tic-tac-toe game, and you are unbeatable. You are the second player. What would be your next move \
    to prevent the first player from winning? Do not explain your move. Just give it in the format 'response: row number, column \
    number' for an available position. The board looks like this:**

    ```
    {board}
    ```

    **The correct move would be:**
    """

    x_cells = {tuple(p) for p in x_positions}
    o_cells = {tuple(p) for p in o_positions}
    occupied = x_cells | o_cells

    def judge_move(question, proposed_move):
        """Ask the judge model to evaluate ``proposed_move``; returns its raw reply."""
        judge_question = f"""
        Current game state:
        {question}

        Proposed move by player O:
        {proposed_move}

        Is this the best move? Explain your reasoning.
        """
        return judge_chain.run(question=judge_question)

    def is_rejected(judgment):
        """Return True if the first verdict word in ``judgment`` is INCORRECT."""
        # \b keeps "CORRECT" from matching inside "INCORRECT", and taking the first
        # verdict word stops a phrase like "not incorrect" in the explanation from
        # flipping the result.
        verdict = re.search(r'\b(INCORRECT|CORRECT)\b', judgment.upper())
        return verdict is not None and verdict.group(1) == "INCORRECT"

    def get_new_move(question, previous_move, judgment):
        """Re-ask the player, passing along the rejected move and the judge's feedback."""
        # Each chain.run call is stateless, so the rejected move must be restated here.
        new_question = f"""
        {question}
        Your previous move was incorrect. Please analyze the board carefully and make a new move.
        Your previous move: {previous_move}
        Judge's feedback: {judgment}
        """
        return player_chain.run(question=new_question)

    def parse_move(move_string):
        """Extract a (row, column) pair from an LLM reply, or return None."""
        match = re.search(r'response:\s*(\d+)\s*,\s*(\d+)', move_string, re.IGNORECASE)
        if match:
            return tuple(map(int, match.groups()))
        # Lenient fallback: accept a reply that contains exactly two numbers.
        numbers = re.findall(r'\d+', move_string)
        if len(numbers) == 2:
            return tuple(map(int, numbers))
        return None

    def to_legal_index(move_string):
        """Map a reply to a 0-8 index, or None if it is unparseable, off-board or occupied."""
        move = parse_move(move_string)
        if move is None or move in occupied or not all(1 <= v <= 3 for v in move):
            return None
        row, col = move
        return (row - 1) * 3 + (col - 1)

    def fallback_cell():
        """Choose a move without the LLM: win, else block, else center/corner/side."""
        lines = [[(r, c) for c in (1, 2, 3)] for r in (1, 2, 3)]
        lines += [[(r, c) for r in (1, 2, 3)] for c in (1, 2, 3)]
        lines += [[(1, 1), (2, 2), (3, 3)], [(1, 3), (2, 2), (3, 1)]]
        for own in (o_cells, x_cells):  # complete O's line first, then block X's
            for line in lines:
                empty = [cell for cell in line if cell not in occupied]
                if len(empty) == 1 and sum(cell in own for cell in line) == 2:
                    return empty[0]
        preference = [(2, 2), (1, 1), (1, 3), (3, 1), (3, 3), (1, 2), (2, 1), (2, 3), (3, 2)]
        for cell in preference:
            if cell not in occupied:
                return cell
        raise ValueError("No empty cell left on the board")

    initial_move = player_chain.run(question=question)
    print("Initial move:", initial_move)

    judgment = judge_move(question, initial_move)

    # If the judge rejects the move, try once more.
    if is_rejected(judgment):
        final_move = get_new_move(question, initial_move, judgment)
    else:
        final_move = initial_move
    print("Final move:", final_move)

    for candidate in (final_move, initial_move):
        index = to_legal_index(candidate)
        if index is not None:
            return index

    row, col = fallback_cell()
    print(f"No legal move in the LLM replies; playing {row},{col} instead.")
    return (row - 1) * 3 + (col - 1)


In [None]:
from typing import List, Literal, Tuple

import ipywidgets as widgets
from IPython.display import display, clear_output


# Module-level game state shared by the widget callbacks.
board = [''] * 9  # row-major cells: 'X', 'O' or '' (empty)
current_player: Literal['X', 'O'] = 'X'  # the human plays X and moves first
game_over = False  # set once a win or draw has been shown
buttons: List[widgets.Button] = []  # filled by create_board(); buttons[i] <-> board[i]


def index_to_coordinates(index: int) -> Tuple[int, int]:
    """Convert a row-major board index (0-8) to a 1-based (row, column) pair."""
    zero_row, zero_col = divmod(index, 3)
    return (zero_row + 1, zero_col + 1)

def generate_board_str(board: List[Literal['X', 'O', '']]) -> str:
    """Render the 9-cell board as three '|'-separated rows, using '.' for empty cells.

    >>> print(generate_board_str(['X', '', '', '', 'O', '', '', '', '']))
    X|.|.
    .|O|.
    .|.|.
    """
    # Substitute per cell. The previous row.replace('', '.') inserted '.' between
    # every character ('X|O|X' -> '.X.|.O.|.X.'), garbling the board sent to the LLM.
    cells = [cell if cell else '.' for cell in board]
    return '\n'.join('|'.join(cells[row * 3:(row + 1) * 3]) for row in range(3))

def is_winner(board: List[Literal['X', 'O', '']]) -> bool:
    """Return True if any row, column or diagonal holds three identical non-empty marks."""
    winning_lines = (
        (0, 1, 2), (3, 4, 5), (6, 7, 8),  # rows
        (0, 3, 6), (1, 4, 7), (2, 5, 8),  # columns
        (0, 4, 8), (2, 4, 6),             # diagonals
    )
    return any(
        board[a] != '' and board[a] == board[b] == board[c]
        for a, b, c in winning_lines
    )

def button_click(b: widgets.Button):
    """Handle a click on cell ``b``; after a human (X) move, let the LLM play O.

    Places the current player's mark, then checks for a win or draw and shows
    the result in ``output``. Otherwise the turn passes to the other player.
    When that player is O, the LLM's move is requested and applied by calling
    this function again with the chosen button.
    """
    global current_player, game_over

    # Ignore clicks once the game is over or on an already-filled cell.
    if game_over or b.description != '':
        return

    b.description = current_player
    idx = int(b.tooltip)  # create_board stores each cell's 0-8 index in its tooltip
    board[idx] = current_player

    if is_winner(board):
        output.clear_output()
        with output:
            show_message(f"{'Human' if current_player == 'X' else 'AI'} wins!")
        game_over = True
    elif '' not in board:
        output.clear_output()
        with output:
            show_message("It's a draw!")
        game_over = True
    else:
        current_player = 'O' if current_player == 'X' else 'X'
        if current_player == 'O':  # the LLM plays O
            x_positions = [index_to_coordinates(i) for i, val in enumerate(board) if val == 'X']
            o_positions = [index_to_coordinates(i) for i, val in enumerate(board) if val == 'O']

            try:
                next_idx = ask_llm(generate_board_str(board), x_positions, o_positions)
            except Exception as err:  # UI-callback boundary: API/network failure or bad reply
                print(f"AI move failed ({err!r}); playing the first empty cell instead.")
                next_idx = None

            # O must actually move now. If the recursive click were ignored (invalid
            # or occupied cell), current_player would stay 'O' and the human's next
            # click would place an O on the AI's behalf.
            if (not isinstance(next_idx, int)
                    or not 0 <= next_idx < len(board)
                    or board[next_idx] != ''):
                next_idx = board.index('')  # safe: this branch runs only while '' in board
            button_click(buttons[next_idx])

def show_message(text: str):
    """Display ``text`` upper-cased in a 340x50 gray banner with white, centered 20px text.

    The styling uses inline CSS inside an HTML widget. ``font_size``, ``color``,
    ``text_align`` and ``background_color`` are not properties of
    ``widgets.Layout`` in ipywidgets 7.x (the version pinned by the install
    cell), so passing them to ``Layout`` had no effect and the banner rendered
    without its intended styling.
    """
    import html  # local import: only needed to escape the message text

    banner = widgets.HTML(
        value=(
            '<div style="width:340px;height:50px;box-sizing:border-box;'
            'border:1px solid gray;background-color:gray;color:white;'
            'font-size:20px;text-align:center;display:flex;'
            'align-items:center;justify-content:center;">'
            f'{html.escape(text.upper())}</div>'
        )
    )

    display(banner)

def create_horizontal_line() -> widgets.Box:
    """Return an empty Box styled as the gray separator between two grid rows."""
    separator_layout = widgets.Layout(
        height='1px',
        width='340px',
        border='7px solid gray',
    )
    return widgets.Box(layout=separator_layout)

def create_vertical_line() -> widgets.Box:
    """Return an empty Box styled as the gray separator between two cells in a row."""
    separator_layout = widgets.Layout(
        width='1px',
        height='105px',
        border='7px solid gray',
    )
    return widgets.Box(layout=separator_layout)

def create_board() -> widgets.VBox:
    """Build the 3x3 grid of clickable cells, separated by gray lines.

    Rebinds the module-level ``buttons`` list. Cell ``i`` carries ``str(i)`` in
    its tooltip, which ``button_click`` reads back as the board index. Every
    cell is wired to ``button_click``.
    """
    global buttons

    def make_cell(position):
        # NOTE(review): ButtonStyle in ipywidgets 7.x may not support font_size;
        # confirm whether the 32px size actually takes effect.
        return widgets.Button(
            description='',
            tooltip=str(position),
            layout=widgets.Layout(width='100px', height='100px'),
            style={'font_weight': 'bold', 'font_size': '32px', 'button_color': 'transparent'},
        )

    buttons = list(map(make_cell, range(9)))
    for cell in buttons:
        cell.on_click(button_click)

    grid_rows = []
    for first in range(0, 9, 3):
        if grid_rows:  # a horizontal rule goes between consecutive rows only
            grid_rows.append(create_horizontal_line())
        grid_rows.append(widgets.HBox([
            buttons[first], create_vertical_line(),
            buttons[first + 1], create_vertical_line(),
            buttons[first + 2],
        ]))
    return widgets.VBox(grid_rows)


# Message area: button_click clears it and writes win/draw banners into it.
output = widgets.Output()
display(
    create_board(),  # the 3x3 grid of clickable cells
    output,
)

VBox(children=(HBox(children=(Button(layout=Layout(height='100px', width='100px'), style=ButtonStyle(button_co…

Output()

Initial move: response: 2,2
Final move: response: 2,2
Initial move: response: 1,1
Final move: response: 1,1
Initial move: response: 1,2
Final move: response: 1,3
Initial move: response: 1,2
Final move: response: 1,2
