In [1]:
import numpy as np
import time
from ipywidgets import IntProgress, HTML, VBox
from IPython.display import display

In [2]:
def drawTicTacToe(state_vector):
    '''
    Print a tic-tac-toe board to stdout.

    Input
        state_vector : 9 cell codes in row-major order (cell index 3*row + col),
            e.g. [1,2,1,0,2,0,0,1,0] where {0: empty cell, 1: X, 2: O}
    Output (printed; rows are separated by a line of 9 dashes and the board
    is followed by blank lines)
        X | O | X
        ---------
          | O |
        ---------
          | X |
    '''
    marks = (' ', 'X', 'O')
    for r in range(3):
        print(' | '.join(marks[int(state_vector[3 * r + c])] for c in range(3)))
        # between rows: a separator; after the last row: trailing blank lines
        print('\n\n' if r == 2 else '-' * 9)

# Demo: render an arbitrary board. This is only a display check - the position
# (3 X's, 6 O's) could not arise in a real game.
drawTicTacToe([1,2,2,2,2,2,1,1,2])

X | O | O
---------
O | O | O
---------
X | X | O





In [3]:
def is_terminal_state(state_vector):
    '''
    Classify a board as ongoing, drawn or won.

    Input:
        state_vector : 9 cell codes (list or array) in row-major order,
            where {0: empty cell, 1: X, 2: O}
    Output:
        state : int {0 : not terminal, 1 : terminal draw (board full, no line),
                     2 : terminal win (some player owns a full row, column or diagonal)}
    A completed line takes precedence over a full board.
    '''
    board = np.asarray(state_vector).reshape(3, 3)
    # every line that can win: 3 rows, 3 columns, main diagonal, anti-diagonal
    lines = list(board) + list(board.T) + [np.diag(board), np.diag(np.fliplr(board))]
    for player in (1, 2):
        if any(np.all(line == player) for line in lines):
            return 2
    if not np.any(board == 0):
        return 1
    return 0

In [4]:
def find_turn(state_vector):
    '''
    Return which player moves next on the given board.

    Input
        state_vector : 9 cell codes (list or numpy array) where
            {0: Empty cell, 1 : X, 2 : O}
    Output
        turn : int {1 : X to move, 2 : O to move}
    Raises
        ValueError : if the board holds more O's than X's (X moves first, so such
            a board cannot arise in play).
    '''
    # asarray: `list == 1` would be a plain False and silently report X's turn
    board = np.asarray(state_vector)
    x = np.sum(board == 1)
    o = np.sum(board == 2)
    if x > o:
        return 2
    if x == o:
        return 1
    raise ValueError(f'Invalid board: {o} O marks but only {x} X marks')

In [5]:
def stateAction(state_vector,action):
    '''
    Outputs next state and reward given a state vector and action.

    The player to move is inferred from the board: X (1) when both players have
    the same number of marks, O (2) when X has more.

    Input
        state_vector : Vector of 9 elements
            eg : [1,2,1,0,2,0,0,1,0] where {0: Empty cell, 1 : X, 2 : O}
        action : index (0-8) of the cell that receives the current player's mark
    Output
        Next state vector : new numpy array of 9 elements (the input is not modified)
        reward : int
            1     : the move completes a line (win for the mover)
            0     : the game continues or ends in a draw
            -1000 : the chosen cell is already occupied; the board is returned unchanged
    Raises
        ValueError : the target cell is empty but the board has more O's than X's
    '''
    next_state = np.copy(state_vector)
    if next_state[action] > 0:   # illegal move: cell already taken
        return next_state, -1000

    n_x = np.sum(next_state == 1)
    n_o = np.sum(next_state == 2)
    if n_x == n_o:     # 'X' turn
        turn = 1
    elif n_x > n_o:    # 'O' turn
        turn = 2
    else:
        raise ValueError(f'Invalid board: {n_o} O marks but only {n_x} X marks')

    next_state[action] = turn
    # is_terminal_state: 2 = win, 1 = draw, 0 = ongoing; only a win is rewarded
    reward = 1 if is_terminal_state(next_state) == 2 else 0
    return next_state, reward

# Demo: the board holds 4 X's and 4 O's, so X is to move and fills the last empty cell 8.
# The board is then full with no completed line, so the game is a draw and reward is 0.
state,reward = stateAction(np.array([2,1,1,1,1,2,2,2,0]),8)
drawTicTacToe(state)
print(reward)

O | X | X
---------
X | X | O
---------
O | O | X



0


In [6]:
def state_oneHotEncoding(state_vector):
    '''
    Map a board to a unique integer key.

    Despite the name this is not a one-hot encoding: the flattened board is read
    as a base-3 number with cell 0 as the most significant digit,
        key = sum_i state_vector[i] * 3**(n-1-i)
    so each 9-cell board with values in {0, 1, 2} gets a distinct key in
    [0, 3**9 - 1]. Weights are int64, so boards of up to 39 cells are supported.

    Input:
        state_vector : list or array of cell codes in {0, 1, 2}; any shape,
            flattened in row-major order
    Output
        state : int
    '''
    cells = np.asarray(state_vector).ravel()
    weights = 3 ** np.arange(cells.size - 1, -1, -1)
    # int(): float boards (e.g. np.zeros(9)) would otherwise yield a float key
    return int(np.dot(cells, weights))

# Demo: X in cell 3 and O in cells 7, 8 -> 1*3**5 + 2*3**1 + 2*3**0 = 243 + 6 + 2 = 251
state_oneHotEncoding(np.array([0,0,0,1,0,0,0,2,2]))

251

In [7]:
class TicTacToe:
    '''
    Tabular SARSA agent that learns tic-tac-toe by playing against itself.

    Both sides share one table:
        states : 1-D array of board encodings (from state_oneHotEncoding); entry 0
                 is a -1 sentinel that no real board encodes to
        q_pi   : action values, one row of 9 per entry in `states`
                 (row i belongs to states[i])
    Occupied cells of a newly seen state start at -1000 so greedy selection avoids them.
    '''

    def __init__(self,epsilon=1,alpha=0.1,gamma = 0.9):
        '''
        epsilon : probability of playing a uniformly random empty cell (exploration)
        alpha   : step size of the value updates
        gamma   : discount factor of the SARSA target
        '''
        self.states = np.array([-1])   # sentinel entry so the table is never empty
        self.q_pi = np.zeros(9)
        self.epsilon = epsilon
        self.alpha = alpha
        self.gamma = gamma

    def _register(self, state, state_vector):
        '''Append a Q-row for encoding `state` if it has not been seen before.'''
        if not np.any(self.states == state):
            self.states = np.hstack([self.states, state])
            self.q_pi = np.vstack([self.q_pi, -1000*(np.asarray(state_vector) > 0)])

    def _move_toward(self, state, action, target):
        '''Q(state, action) <- Q(state, action) + alpha * (target - Q(state, action)).'''
        row = self.states == state
        self.q_pi[row, action] += self.alpha*(target - self.q_pi[row, action])

    def epsilon_greedy(self,state_vector):
        '''
        Choose a move for the player to move in `state_vector`.

        With probability epsilon a uniformly random empty cell is returned;
        otherwise a cell with the highest Q-value, ties broken uniformly at
        random. Unseen states are registered first.
        '''
        state = state_oneHotEncoding(state_vector)
        self._register(state, state_vector)
        if np.random.rand() > self.epsilon:   # exploitation
            q_state = self.q_pi[self.states == state][0]
            actions = np.flatnonzero(q_state == np.max(q_state))
        else:   # exploration over empty cells only
            actions = np.flatnonzero(np.asarray(state_vector) == 0)
        return actions[np.random.randint(np.size(actions))]

    def sarsa_update(self,state_vector,action,reward,next_state_vector,next_action):
        '''
        One SARSA step:
            Q(s,a) <- Q(s,a) + alpha * (reward + gamma * Q(s',a') - Q(s,a))
        Both boards must already be registered (epsilon_greedy does this).
        '''
        next_row = self.states == state_oneHotEncoding(next_state_vector)
        target = reward + self.gamma*self.q_pi[next_row, next_action]
        self._move_toward(state_oneHotEncoding(state_vector), action, target)

    def gen_episode(self):
        '''
        Play one self-play game from the empty board and update q_pi along the way.

        X and O share the Q-table. Each SARSA update links a position to the one
        two plies later, where the same player is to move again. When the game
        ends, each player's last move is moved toward its final return: +1 for
        the winning move and -1 for the loser's last move, or 0 for both on a draw.
        '''
        state_vector = np.zeros(9)
        state = state_oneHotEncoding(state_vector)
        action = self.epsilon_greedy(state_vector)

        state_vector_1, reward = stateAction(state_vector, action)
        state_1 = state_oneHotEncoding(state_vector_1)
        action_1 = self.epsilon_greedy(state_vector_1)

        while True:
            state_vector_2, reward_1 = stateAction(state_vector_1, action_1)
            if reward_1 > 0:   # action_1 completed a line: its player wins
                self._move_toward(state, action, -1)
                self._move_toward(state_1, action_1, reward_1)
                break
            if is_terminal_state(state_vector_2) == 1:   # board full, no line: draw
                self._move_toward(state, action, reward)
                self._move_toward(state_1, action_1, reward_1)
                break
            # Game continues: pick the next move, then update the move made two plies ago.
            action_2 = self.epsilon_greedy(state_vector_2)
            self.sarsa_update(state_vector, action, reward, state_vector_2, action_2)
            # Shift the window by one ply. The reward travels with its action so
            # later updates use the reward that move actually received.
            state_vector, state, action, reward = state_vector_1, state_1, action_1, reward_1
            state_vector_1, action_1 = state_vector_2, action_2
            state_1 = state_oneHotEncoding(state_vector_2)

        
        

In [8]:
iters = 100000
UPDATE_EVERY = 1000   # refresh the widgets every N episodes instead of on every one of them
SEED = 0
np.random.seed(SEED)  # the agent samples from NumPy's global RNG; seeding makes training reproducible

progress = IntProgress(value=0, min=0, max=iters)
label = HTML(value="Starting...")
box = VBox([label, progress])
display(box)

agent = TicTacToe(epsilon=1)
for i in range(1, iters + 1):   # i counts completed episodes, so the bar reaches max
    agent.gen_episode()
    if i % UPDATE_EVERY == 0 or i == iters:
        progress.value = i
        # subtract 1 for the -1 sentinel entry TicTacToe keeps in `states`
        label.value = f"Progress: {i} | Number of states visited: {np.size(agent.states) - 1}"

VBox(children=(HTML(value='Starting...'), IntProgress(value=0, max=100000)))

In [9]:
def _banner(message):
    '''Print `message` framed above and below by a row of 20 '#'.'''
    print('\n'+'#'*20+'\n\n'+message+'\n\n'+'#'*20)


def _read_human_move(state_vector):
    '''Prompt until the user types the index (0-8) of an empty cell; return it.'''
    n_cells = np.size(state_vector)
    while True:
        raw = input('Enter move : ')
        try:
            action = int(raw)
        except ValueError:
            action = -1
        # reject negatives too: Python would silently wrap them to cells from the end
        if 0 <= action < n_cells and state_vector[action] == 0:
            return action
        print(f'Please enter the number (0-{n_cells - 1}) of an empty cell.')


def _agent_move(agent, state_vector):
    '''
    Greedy move of `agent` restricted to empty cells (first maximum wins ties).
    Positions the agent never saw during training fall back to a uniformly
    random empty cell instead of crashing on an empty Q selection.
    '''
    empty = np.flatnonzero(np.asarray(state_vector) == 0)
    # atleast_2d: an untrained agent stores q_pi as a single 1-D row
    q_rows = np.atleast_2d(agent.q_pi)[agent.states == state_oneHotEncoding(state_vector)]
    if len(q_rows) == 0:
        return int(np.random.choice(empty))
    return int(empty[np.argmax(q_rows[0][empty])])


def play_game(agent, human_first=None):
    '''
    Play one interactive game of tic-tac-toe against a trained agent.

    Moves are typed as cell indices 0-8 in row-major order; whoever opens plays X.

    Input
        agent : object exposing `states` and `q_pi` arrays as built by TicTacToe
        human_first : True/False to choose who opens; None (default) flips a fair coin
    '''
    if human_first is None:
        human_first = np.random.rand() < 0.5
    state_vector = np.zeros(9)
    drawTicTacToe(state_vector)

    human_turn = human_first
    while True:
        action = _read_human_move(state_vector) if human_turn else _agent_move(agent, state_vector)
        state_vector, reward = stateAction(state_vector, action)
        drawTicTacToe(state_vector)
        if is_terminal_state(state_vector) > 0:
            if reward > 0:   # the player who just moved completed a line
                _banner(' You won the game :) ' if human_turn else ' You lost the game :( ')
            else:
                _banner(' You drew the game :| ')
            return
        human_turn = not human_turn

In [20]:
# Interactive game against the trained agent; moves are typed as cell indices 0-8
# in row-major order (0-2 top row, 3-5 middle row, 6-8 bottom row).
play_game(agent)

  |   |  
---------
  |   |  
---------
  |   |  



Enter move : 0
X |   |  
---------
  |   |  
---------
  |   |  



X | O |  
---------
  |   |  
---------
  |   |  



Enter move : 3
X | O |  
---------
X |   |  
---------
  |   |  



X | O |  
---------
X |   |  
---------
O |   |  





KeyboardInterrupt: Interrupted by user

In [12]:
# Encoding of the position with X in cell 0 and O in cell 1: 1*3**8 + 2*3**7 = 10935
state_oneHotEncoding([1,2,0,0,0,0,0,0,0])

10935

In [13]:
# Inspect the agent's Q-row for that position (encoding 10935). Occupied cells were
# set to -1000 when the state was first registered, which keeps greedy selection off them.
agent.q_pi[agent.states==10935]

array([[-1000., -1000.,     0.,     0.,     0.,     0.,     0.,     0.,
            0.]])