In [None]:
import numpy as np
from collections import deque

eps = 1e-12

#Use State instead of Game class to simplify tree operations below (also undo is expensive :/)
class TicTacToeState:
    def __init__(self, board=None, to_move=1):
        self.board = np.zeros((3,3), dtype=int) if board is None else board.copy()
        self.to_move = to_move  # 1 for X, -1 for O

    def legal_actions(self):
        return [(i,j) for i in range(3) for j in range(3) if self.board[i,j] == 0]

    def is_terminal(self):
        # Win conditions
        for player in (1, -1):
            # rows and columns
            if any(np.all(self.board[i,:] == player) for i in range(3)) or \
               any(np.all(self.board[:,j] == player) for j in range(3)):
                return True, player
            # diagonals
            if (self.board[0,0] == player and self.board[1,1] == player and self.board[2,2] == player) or \
               (self.board[0,2] == player and self.board[1,1] == player and self.board[2,0] == player):
                return True, player
        if not self.legal_actions():
            return True, 0  # draw
        return False, None

    def next_state(self, action):
        new_state = TicTacToeState(self.board, -self.to_move)
        new_state.board[action] = self.to_move
        return new_state

    def __hash__(self):
        return hash(self.board.tobytes()) ^ hash(self.to_move)

    def __eq__(self, other):
        return isinstance(other, TicTacToeState) and self.to_move == other.to_move and np.array_equal(self.board, other.board)

    def draw(self):
        def convert(n):
            return "X"*(n==1) + " "*(n==0) + "O"*(n==-1)
        
        temp = [[convert(int(self.board[j][i])) for i in range(3)] for j in range(3)]
        print(f"{temp[0][0]}|{temp[0][1]}|{temp[0][2]}")
        print("-----")
        print(f"{temp[1][0]}|{temp[1][1]}|{temp[1][2]}")
        print("-----")
        print(f"{temp[2][0]}|{temp[2][1]}|{temp[2][2]}")  
        print()
    
class SequenceForm:
    def __init__(self):
        self.states = []
        self.state_to_id = {}
        self.indices = {}  # (state_id, action) -> seq_idx
        self.num_sequences = 0

    def build(self, root_state):
        # BFS to collect all states
        queue = deque([root_state])
        while queue:
            st = queue.popleft()
            if st in self.state_to_id:
                continue
            sid = len(self.states)
            self.state_to_id[st] = sid
            self.states.append(st)
            terminal, _ = st.is_terminal()
            if not terminal:
                for a in st.legal_actions():
                    queue.append(st.next_state(a))
        # Assign sequence indices
        seq_idx = 0
        for sid, st in enumerate(self.states):
            terminal, _ = st.is_terminal()
            if terminal:
                continue
            for a in st.legal_actions():
                self.indices[(sid, a)] = seq_idx
                seq_idx += 1
        self.num_sequences = seq_idx

def GRADIENT(x):
    return 1.0 + np.log(np.clip(x, EPS, None))

def ARGCONJUGATE(g):
    exp_g = np.exp(g - np.max(g))
    return exp_g / exp_g.sum() # Just a point estimate

# --- Bandit regret solver ---
class BanditRegretSolver:
    def __init__(self, sf: SequenceForm, eta: float):
        self.sf = sf
        self.eta = eta
        self.x = np.ones(sf.num_sequences) / sf.num_sequences

    def next_strategy(self):
        return self.x.copy()

    def observe_loss(self, loss_estimate):
        g = self.eta * loss_estimate - GRADIENT(self.x)
        self.x = ARGCONJUGATE(-g)

    def build_loss_estimate(self, path, payoff, x):
        est = np.zeros_like(x)
        for i in path:
            if x[i] > 0:
                est[i] = payoff / x[i]
        return est

def sample_self_play(sf: SequenceForm, x, o):
    path_X, path_O = [], []
    st = sf.states[0]  # root state
    while True:
        terminal, payoff = st.is_terminal()
        if terminal:
            return path_X, path_O, payoff
        sid = sf.state_to_id[st]
        legal = st.legal_actions()
        seqs = [sf.indices[(sid, a)] for a in legal]
        probs = (x if st.to_move == 1 else o)[seqs]
        total = probs.sum()
        if total <= 0 or np.isnan(total):
            probs = np.ones(len(legal)) / len(legal) # Handle edges cases 
        else:
            probs = probs / total
        choice = np.random.choice(len(legal), p=probs)
        idx = seqs[choice]
        if st.to_move == 1:
            path_X.append(idx)
        else:
            path_O.append(idx)
        st = st.next_state(legal[choice])


def train_self_play(T=10000, eta=0.1):
    root = TicTacToeState()
    sf = SequenceForm()
    sf.build(root)
    X_solver = BanditRegretSolver(sf, eta)
    O_solver = BanditRegretSolver(sf, eta)
    avg_x = np.zeros(sf.num_sequences)
    avg_o = np.zeros(sf.num_sequences)

    for t in range(1, T+1):
        x_t = X_solver.next_strategy()
        o_t = O_solver.next_strategy()
        path_X, path_O, payoff = sample_self_play(sf, x_t, o_t)
        lx = X_solver.build_loss_estimate(path_X, payoff, x_t)
        lo = O_solver.build_loss_estimate(path_O, -payoff, o_t)
        X_solver.observe_loss(lx)
        O_solver.observe_loss(lo)
        avg_x += x_t
        avg_o += o_t
        
        if t %100 == 0:
            print(f"Completed {t=} iterations   ", end="\r")
    return avg_x / T, avg_o / T



def play_against_agent(pi, sf, human_player= -1, deterministic = True):
    """
    Play a human vs. agent game. 
    human_player: 1 for X (goes first), -1 for O.
    """
    state = TicTacToeState()
    while True:
        terminal, payoff = state.is_terminal()
        if terminal:
            if payoff == 0:
                print("Draw!")
            else:
                winner = 'X' if payoff == 1 else 'O'
                print(f"Winner: {winner}")
            break
        state.draw()
        if state.to_move == human_player:
            # human turn
            while True:
                move = input(f"Your move {('X' if human_player==1 else 'O')}, enter row,col: ")
                try:
                    i,j = map(int, move.split(','))
                    if (i,j) in state.legal_actions():
                        break
                    else:
                        print("Illegal move")
                except:
                    print("Invalid input, format: row,col")
            state = state.next_state((i,j))
        else:
            # agent turn
            sid = sf.state_to_id[state]
            legal = state.legal_actions()
            seqs = [sf.indices[(sid,a)] for a in legal]
            probs = pi[seqs]
            probs = probs / probs.sum()
            print(f"Probabilities {probs=}")
            if deterministic: 
                choice = int(np.argmax(probs))
            else:
                choice = np.random.choice(len(legal), p=probs)
            print(f"Agent plays {legal[choice]}")
            state = state.next_state(legal[choice])

np.random.seed(420)
pi_X, pi_O = train_self_play(T=100000, eta=0.05)
sf = SequenceForm()
sf.build(TicTacToeState())

np.save('pi_X.npy', pi_X)
np.save('pi_O.npy', pi_O)

Completed t=95800 iterations   

In [None]:
pi_X = np.load('pi_X.npy')
pi_O = np.load('pi_O.npy')

In [112]:
play_against_agent(pi_X, sf, human_player=-1)

 | | 
-----
 | | 
-----
 | | 

Probabilities probs=array([0.10002728, 0.12805571, 0.09729799, 0.12908966, 0.08385077,
       0.13132796, 0.10456902, 0.1193195 , 0.10646212])
Agent plays (1, 2)
 | | 
-----
 | |X
-----
 | | 

Your move O, enter row,col: 1, 1
 | | 
-----
 |O|X
-----
 | | 

Probabilities probs=array([0.1452948 , 0.13557789, 0.10604309, 0.23036927, 0.1180822 ,
       0.14521477, 0.11941798])
Agent plays (1, 0)
 | | 
-----
X|O|X
-----
 | | 

Your move O, enter row,col: 0,0
O| | 
-----
X|O|X
-----
 | | 

Probabilities probs=array([0.11997939, 0.14253484, 0.30026015, 0.36249353, 0.07473208])
Agent plays (2, 1)
O| | 
-----
X|O|X
-----
 |X| 

Your move O, enter row,col: 2,2
Winner: O


In [None]:
play_against_agent(pi_O, sf, human_player=1)

In [40]:
# Define the game and sequence form utils
import numpy as np

class Game: 
    
    def __init__(self):
        self.board = np.zeros(10, dtype = int)
        self.player = 1 # Alternates between -1 and 1
        self.board[9] = self.player
        self.hist = []
     
    def get_actions(self, board = None):
        if board == None: board = self.board
        return [i for i in range(9) if board[i] == 0]
    
    def winner(self): #Return tuple (done, outcome)
        b = self.board[:9].reshape(3, 3)

        # Check rows and columns
        rows, cols = b.sum(axis = 1), b.sum(axis = 0)
        if 3 in rows or 3 in cols: return True, 1
        elif -3 in rows or -3 in cols: return True, -1
        
        # Check diagonals
        if abs(b[0, 0] + b[1, 1] + b[2, 2]) == 3: return True, np.sign(b[0, 0])
        if abs(b[0, 2] + b[1, 1] + b[2, 0]) == 3: return True, np.sign(b[0, 2])

        # Check for draw
        if np.all(self.board[:9] != 0):
            return True, 0

        return False, None  # Game still ongoing

    
    def update(self, a: int):
        self.board[a] = self.player
        self.player *= -1
        self.hist.append(a)
        self.board[9] = self.player
        
    def undo(self):
        a = self.hist.pop()
        self.player *= -1
        self.board[a] = 0
        self.board[9] = self.player
            
    def get_state(self):
        return self.board
    
    def get_state_id(self):
        return hash(self.board.tobytes())
    
    def __eq__(self, other):
        return np.array_equal(self.board, other.board)
    
    
    def draw(self):
        def convert(n):
            return "X"*(n==1) + " "*(n==0) + "O"*(n==-1)
        
        temp = [[convert(int(self.board[3*j+i])) for i in range(3)] for j in range(3)]
        print(f"{temp[0][0]}|{temp[0][1]}|{temp[0][2]}")
        print("-----")
        print(f"{temp[1][0]}|{temp[1][1]}|{temp[1][2]}")
        print("-----")
        print(f"{temp[2][0]}|{temp[2][1]}|{temp[2][2]}")  
        print()
    
class SequenceForm: 
    
    def __init__(self, game):
        self.game = game
        
        self.state_to_id = {} # board -> id
        self.indices = {}  # (state_id, action) -> seq_idx
        self.num_sequences = 0
        
        self.build()
        
    def build(self):
        def dfs():
            state_hash = self.game.get_state_id()
            if state_hash in self.state_to_id:
                return

            sid = len(self.state_to_id)
            self.state_to_id[state_hash] = sid

            done, _ = self.game.winner()
            if done: return

            for a in self.game.get_actions():
                self.indices[(sid, a)] = self.num_sequences
                self.num_sequences += 1
                self.game.update(a)
                dfs()
                self.game.undo()
        
        dfs()
        
def GRADIENT(x):
    pass

def ARGCONJUGATE(g):
    pass

# --- Bandit regret solver ---
class BanditRegretSolver:
    def __init__(self, sf: SequenceForm, eta: float):
        pass

    def next_strategy(self):
        pass

    def observe_loss(self, loss_estimate):
        pass

    def build_loss_estimate(self, path, payoff, x):
        pass

# --- Self-play sampling and training ---
def sample_self_play(sf: SequenceForm, x, o):
    pass


def train_self_play(T=10000, eta=0.1):
    pass

sf = SequenceForm(Game())
sf.num_sequences

16167