In [8]:
from typing import List, Tuple, Optional
from dataclasses import dataclass

# All eight winning lines of a 3x3 board, as index triples into the flat,
# row-major 9-cell board (index = 3 * row + col; see pretty_board's slicing).
LINES = [
    (0, 1, 2), (3, 4, 5), (6, 7, 8),  # rows
    (0, 3, 6), (1, 4, 7), (2, 5, 8),  # cols
    (0, 4, 8), (2, 4, 6)              # diags
]

def check_winner(board: List[str]) -> Optional[str]:
    """Return the mark that fills a complete line, or None if there is none.

    Lines are checked in ``LINES`` order and the first complete one decides;
    a line of blanks (' ') never counts as a win.
    """
    return next(
        (board[first] for first, second, third in LINES
         if board[first] != ' ' and board[first] == board[second] == board[third]),
        None,
    )

def is_full(board: List[str]) -> bool:
    """Return True when no cell of *board* is blank (' ')."""
    return ' ' not in board

def terminal(board: List[str]) -> bool:
    """Return True if the game is over: a line is complete or no blank remains."""
    if check_winner(board) is not None:
        return True
    return is_full(board)

def eval_terminal(board: List[str]) -> int:
    """Score a position exactly from X's (the maximiser's) point of view.

    +100 when X owns a complete line, -100 when O does, and 0 otherwise
    (draw or no winner yet).
    """
    exact_scores = {'X': 100, 'O': -100}
    return exact_scores.get(check_winner(board), 0)

def heuristic(board: List[str]) -> int:
    """Static evaluation of a non-terminal position; positive favours X.

    Every line that contains marks of only one player adds that player's mark
    count (+ for X, - for O). Empty lines and lines holding both marks add
    nothing.
    """
    total = 0
    for line in LINES:
        cells = [board[idx] for idx in line]
        x_count, o_count = cells.count('X'), cells.count('O')
        if x_count and o_count:
            continue  # contested line: neither side can complete it
        total += x_count - o_count
    return total

def evaluate(board: List[str]) -> int:
    """Score *board* for X: the exact result if the game is over, else the heuristic."""
    return eval_terminal(board) if terminal(board) else heuristic(board)

def available_moves(board: List[str]) -> List[int]:
    """Return the indices (0..8, ascending) of the blank cells of *board*."""
    return list(filter(lambda idx: board[idx] == ' ', range(9)))

def apply_move(board: List[str], move: int, player: str) -> List[str]:
    """Return a new board equal to *board* with *player*'s mark at index *move*.

    The input list is left unmodified.
    """
    new_board = list(board)
    new_board[move] = player
    return new_board

@dataclass
class SearchResult:
    """Outcome of a game-tree search rooted at one position."""
    value: int  # backed-up score; positive favours X (the maximiser)
    best_move: Optional[int]  # chosen move index 0..8; None at leaf / depth-cutoff nodes
    nodes: int  # nodes visited in this subtree, counting the root itself
    pv: List[int]  # principal variation: the move sequence leading to `value`

def minimax(board: List[str], player: str, depth: int, max_depth: int) -> SearchResult:
    """Depth-limited minimax without pruning; X maximises, O minimises.

    Args:
        board: 9-cell board, ' ' for empty.
        player: side to move. 'X' is the maximiser; any other value is
            treated as the minimiser (and is the mark placed on the board).
        depth: current ply measured from the search root (0 at the root).
        max_depth: ply at which the search stops and evaluates statically.

    Returns:
        SearchResult holding the backed-up value, the first move reaching it
        (ties go to the lowest index, because moves are tried in ascending
        order and only strict improvements replace the incumbent), the number
        of nodes visited, and the principal variation.
    """
    # ``>=`` rather than ``==``: if a caller starts with depth > max_depth
    # (e.g. a negative limit) we evaluate immediately instead of silently
    # searching the whole tree.
    if depth >= max_depth or terminal(board):
        return SearchResult(evaluate(board), None, 1, [])

    maximizing = player == 'X'
    opponent = 'O' if maximizing else 'X'
    best_val = -10**9 if maximizing else 10**9
    best_move: Optional[int] = None
    best_pv: List[int] = []
    total_nodes = 1  # this node

    for m in available_moves(board):
        res = minimax(apply_move(board, m, player), opponent, depth + 1, max_depth)
        total_nodes += res.nodes
        improved = res.value > best_val if maximizing else res.value < best_val
        if improved:
            best_val, best_move, best_pv = res.value, m, [m] + res.pv
    return SearchResult(best_val, best_move, total_nodes, best_pv)

def alphabeta(board: List[str], player: str, depth: int, max_depth: int,
              alpha: int = -10**9, beta: int = 10**9) -> SearchResult:
    """Depth-limited minimax with (fail-soft) alpha-beta pruning.

    X maximises and O minimises, exactly as in ``minimax``; with the default
    full window the root value and best move match it, while far fewer nodes
    are visited.

    Args:
        board: 9-cell board, ' ' for empty.
        player: side to move. 'X' is the maximiser; any other value is
            treated as the minimiser (and is the mark placed on the board).
        depth: current ply measured from the search root (0 at the root).
        max_depth: ply at which the search stops and evaluates statically.
        alpha: lower bound of the search window (best value MAX can force).
        beta: upper bound of the search window (best value MIN can force).

    Returns:
        SearchResult with the backed-up value, best move, nodes visited and
        principal variation. Below the root, a node that was cut off returns
        only a bound on its true value.
    """
    # ``>=`` rather than ``==`` so depth > max_depth cannot trigger a
    # full-tree search.
    if depth >= max_depth or terminal(board):
        return SearchResult(evaluate(board), None, 1, [])

    maximizing = player == 'X'
    opponent = 'O' if maximizing else 'X'
    best_val = -10**9 if maximizing else 10**9
    best_move: Optional[int] = None
    best_pv: List[int] = []
    total_nodes = 1  # this node

    for m in available_moves(board):
        res = alphabeta(apply_move(board, m, player), opponent, depth + 1,
                        max_depth, alpha, beta)
        total_nodes += res.nodes
        improved = res.value > best_val if maximizing else res.value < best_val
        if improved:
            best_val, best_move, best_pv = res.value, m, [m] + res.pv
        # Tighten this side's bound of the window with the best value so far.
        if maximizing:
            alpha = max(alpha, best_val)
        else:
            beta = min(beta, best_val)
        if alpha >= beta:
            break  # the opponent already has a better alternative; prune the rest
    return SearchResult(best_val, best_move, total_nodes, best_pv)

def pretty_board(board: List[str]) -> str:
    """Render *board* as three ' | '-joined rows separated by dashes; blanks show as '.'."""
    cells = ['.' if c == ' ' else c for c in board]
    separator = '\n---------\n'
    return separator.join(' | '.join(cells[start:start + 3]) for start in (0, 3, 6))

def board_from_pv(start_board: List[str], pv: List[int],
                  first_player: str = 'X') -> List[List[str]]:
    """Replay a move sequence and return every intermediate board.

    Args:
        start_board: position before the first move of *pv*; not modified.
        pv: move indices to play, with the two sides alternating.
        first_player: side that makes the first move of *pv*. Defaults to
            'X' (the previous hard-coded behaviour); pass 'O' when replaying
            from a position where O is to move.

    Returns:
        ``len(pv) + 1`` boards: the start position followed by the board after
        each move. Every entry is an independent list.
    """
    boards = [list(start_board)]
    player = first_player
    for m in pv:
        # apply_move already returns a fresh list, so no extra copy is needed.
        boards.append(apply_move(boards[-1], m, player))
        player = 'O' if player == 'X' else 'X'
    return boards

def demo(initial_board: Optional[List[str]] = None, max_depth: int = 6) -> None:
    """Search *initial_board* with minimax and alpha-beta and print a comparison.

    Prints the starting position, then for each search its value, best first
    move, node count and principal variation, and finally replays the
    alpha-beta principal variation board by board.

    Args:
        initial_board: 9-cell starting position (' ' for empty); defaults to
            the empty board. NOTE: both searches and the replay assume X is
            to move in this position.
        max_depth: ply limit passed to both searches.
    """
    if initial_board is None:
        initial_board = [' '] * 9

    print("=== Tic-Tac-Toe (depth-limited) ===")
    # Heading and board are printed separately: print("...\n", board) would
    # insert print's default ' ' separator before the first board row.
    print("Initial position:")
    print(pretty_board(initial_board))
    print(f"\nDepth limit: {max_depth}\n")

    mm_res = minimax(initial_board, 'X', 0, max_depth)
    ab_res = alphabeta(initial_board, 'X', 0, max_depth, -10**9, 10**9)

    for header, res in (("--- Minimax (no pruning) ---", mm_res),
                        ("\n--- Alpha-Beta pruning ---", ab_res)):
        print(header)
        print(f"Best value: {res.value}")
        print(f"Best first move: {res.best_move} (0..8 indexing)")
        print(f"Nodes visited: {res.nodes}")
        print(f"Principal variation: {res.pv}")

    pv_boards = board_from_pv(initial_board, ab_res.pv[:9])
    print("\n--- Boards along Alpha-Beta principal line (first few plies) ---")
    for i, b in enumerate(pv_boards):
        print(f"\nPly {i} (to-move: {'X' if i % 2 == 0 else 'O'})")
        print(pretty_board(b))

# Script entry point. A depth limit of 9 amounts to a full search from the
# empty board, since at most 9 moves can be played.
if __name__ == "__main__":
    demo(max_depth=9)


=== Tic-Tac-Toe (depth-limited) ===
Initial position:
 . | . | .
---------
. | . | .
---------
. | . | .

Depth limit: 9

--- Minimax (no pruning) ---
Best value: 0
Best first move: 0 (0..8 indexing)
Nodes visited: 549946
Principal variation: [0, 4, 1, 2, 6, 3, 5, 7, 8]

--- Alpha-Beta pruning ---
Best value: 0
Best first move: 0 (0..8 indexing)
Nodes visited: 18297
Principal variation: [0, 4, 1, 2, 6, 3, 5, 7, 8]

--- Boards along Alpha-Beta principal line (first few plies) ---

Ply 0 (to-move: X)
. | . | .
---------
. | . | .
---------
. | . | .

Ply 1 (to-move: O)
X | . | .
---------
. | . | .
---------
. | . | .

Ply 2 (to-move: X)
X | . | .
---------
. | O | .
---------
. | . | .

Ply 3 (to-move: O)
X | X | .
---------
. | O | .
---------
. | . | .

Ply 4 (to-move: X)
X | X | O
---------
. | O | .
---------
. | . | .

Ply 5 (to-move: O)
X | X | O
---------
. | O | .
---------
X | . | .

Ply 6 (to-move: X)
X | X | O
---------
O | O | .
---------
X | . | .

Ply 7 (to-move: O)
X | X 

In [9]:
from typing import Dict, List, Tuple
from dataclasses import dataclass
import json

# Type aliases for readability: states and actions are identified by plain strings.
State = str
Action = str

@dataclass
class MDP:
    """A finite Markov decision process given by explicit transition lists."""
    states: List[State]  # all states; value_iteration iterates and prints them in this order
    actions: Dict[State, List[Action]]  # legal actions per state; value_iteration holds the value of a state with no entry / an empty list fixed
    P: Dict[State, Dict[Action, List[Tuple[float, State, float]]]]  # P[s][a] -> list of (probability, next_state, reward)
    gamma: float = 0.9  # discount factor applied to the successor's value


def value_iteration(
    mdp: MDP,
    tol: float = 1e-6,
    max_iter: int = 10_000,
    V0: Optional[Dict[State, float]] = None,
    verbose: bool = True
) -> Tuple[Dict[State, float], Dict[State, Action], int]:
    """Solve *mdp* with synchronous value iteration.

    Each sweep applies the Bellman optimality backup
    ``V(s) <- max_a sum p * (r + gamma * V(s'))`` to every state that has
    actions, using only the previous sweep's values. States without actions
    keep their initial value.

    Args:
        mdp: the process to solve.
        tol: stop as soon as the largest per-state change in a sweep is < tol.
        max_iter: maximum number of sweeps.
        V0: optional initial values; states it omits start at 0.0. The dict
            passed in is not modified.
        verbose: print the per-sweep values, the greedy policy and final values.

    Returns:
        ``(V, pi, k)``: the final value function; the greedy policy with
        respect to it (only for states with actions; ties go to the first
        listed action); and the number of sweeps performed (0 if
        ``max_iter < 1``).
    """
    S = mdp.states
    A = mdp.actions
    P = mdp.P
    gamma = mdp.gamma

    # Zeros for every state overlaid with V0: a V0 that omits some states no
    # longer causes a KeyError during the first sweep.
    V: Dict[State, float] = {s: 0.0 for s in S}
    if V0 is not None:
        V.update(V0)

    def q_value(s: State, a: Action, Vref: Dict[State, float]) -> float:
        # Expected one-step return of action a in state s, bootstrapping from Vref.
        return sum(p * (r + gamma * Vref[s2]) for p, s2, r in P[s][a])

    if verbose:
        print("Value Iteration start")
        print(f"gamma = {gamma}, tol = {tol}, max_iter = {max_iter}")
        print(f"Initial V: {json.dumps(V, indent=2)}\n")

    k = 0  # sweeps performed; stays defined even if the loop below never runs
    for k in range(1, max_iter + 1):
        V_new: Dict[State, float] = {}
        delta = 0.0
        for s in S:
            if s not in A or len(A[s]) == 0:
                V_new[s] = V[s]  # no actions: value is held fixed
                continue
            V_new[s] = max(q_value(s, a, V) for a in A[s])
            delta = max(delta, abs(V_new[s] - V[s]))
        V = V_new  # synchronous update: all backups in this sweep read the old V

        if verbose:
            pretty = ", ".join(f"{s}: {V[s]:.6f}" for s in S)
            print(f"Iter {k:3d} | {pretty} | delta={delta:.6e}")

        if delta < tol:
            if verbose:
                print(f"\nConverged in {k} iterations (delta={delta:.3e} < tol={tol}).\n")
            break
    else:
        if verbose:
            print(f"\nStopped after max_iter={max_iter} without meeting tol={tol}.\n")

    pi: Dict[State, Action] = {}
    for s in S:
        if s in A and len(A[s]) > 0:
            pi[s] = max(A[s], key=lambda a: q_value(s, a, V))

    if verbose:
        print("Greedy policy π* (w.r.t. converged V):")
        for s in S:
            print(f"  {s}: {pi.get(s, '-')}")
        print("\nFinal values:")
        for s in S:
            print(f"  V({s}) = {V[s]:.6f}")

    return V, pi, k


# S = {s1, s2}, A = {a1, a2}, gamma=0.9
# From s1 with a1 → s1 (0.5, r=5), s2 (0.5, r=10)
# From s1 with a2 → s2 (1.0, r=3)
# From s2 with a1 → s1 (1.0, r=2)
# From s2 with a2 → s2 (1.0, r=4)

def build_two_state_mdp() -> MDP:
    """Build the two-state, two-action example MDP with gamma = 0.9.

    Transitions as (probability, next_state, reward):
      s1 --a1--> s1 (0.5, 5.0) or s2 (0.5, 10.0)
      s1 --a2--> s2 (1.0, 3.0)
      s2 --a1--> s1 (1.0, 2.0)
      s2 --a2--> s2 (1.0, 4.0)
    """
    table = {
        ("s1", "a1"): [(0.5, "s1", 5.0), (0.5, "s2", 10.0)],
        ("s1", "a2"): [(1.0, "s2", 3.0)],
        ("s2", "a1"): [(1.0, "s1", 2.0)],
        ("s2", "a2"): [(1.0, "s2", 4.0)],
    }
    states = ["s1", "s2"]
    P = {s: {} for s in states}
    for (state, action), outcomes in table.items():
        P[state][action] = outcomes
    actions = {s: list(P[s]) for s in states}
    return MDP(states=states, actions=actions, P=P, gamma=0.9)


# Script entry point: solve the two-state example with a tight tolerance and
# print the full iteration trace (the returned values are discarded).
if __name__ == "__main__":
    mdp = build_two_state_mdp()
    value_iteration(mdp, tol=1e-8, max_iter=1000, verbose=True)

Value Iteration start
gamma = 0.9, tol = 1e-08, max_iter = 1000
Initial V: {
  "s1": 0.0,
  "s2": 0.0
}

Iter   1 | s1: 7.500000, s2: 4.000000 | delta=7.500000e+00
Iter   2 | s1: 12.675000, s2: 8.750000 | delta=5.175000e+00
Iter   3 | s1: 17.141250, s2: 13.407500 | delta=4.657500e+00
Iter   4 | s1: 21.246938, s2: 17.427125 | delta=4.105688e+00
Iter   5 | s1: 24.903328, s2: 21.122244 | delta=3.695119e+00
Iter   6 | s1: 28.211507, s2: 24.412995 | delta=3.308179e+00
Iter   7 | s1: 31.181026, s2: 27.390357 | delta=2.977361e+00
Iter   8 | s1: 33.857122, s2: 30.062924 | delta=2.676096e+00
Iter   9 | s1: 36.264021, s2: 32.471410 | delta=2.408486e+00
Iter  10 | s1: 38.430944, s2: 34.637619 | delta=2.166923e+00
Iter  11 | s1: 40.380853, s2: 36.587849 | delta=1.950231e+00
Iter  12 | s1: 42.135916, s2: 38.342768 | delta=1.755063e+00
Iter  13 | s1: 43.715408, s2: 39.922325 | delta=1.579557e+00
Iter  14 | s1: 45.136980, s2: 41.343867 | delta=1.421572e+00
Iter  15 | s1: 46.416381, s2: 42.623282 | de