In [1]:
import numpy as np
import pickle

In [2]:
class Environment:
    def __init__(self, p1, p2):
        self.board = np.zeros((3, 3))
        self.p1 = p1
        self.p2 = p2
        self.isEnd = False
        self.boardHash = None
        self.playerSymbol = 1
    
    def getHash(self):
        self.boardHash = str(self.board.reshape(3*3))
        return self.boardHash
        
    def winner(self):
        for i in range(3):
            if sum(self.board[i, :])== 3 or sum(self.board[:, i])== 3:
                self.isEnd = True
                return 1
            if sum(self.board[:, i])== -3 or sum(self.board[i, :])== -3:
                self.isEnd = True
                return -1
           
        diag_sum = sum([self.board[i, i] for i in range(3)])
        diag_sum2 = sum([self.board[i, 2 - i] for i in range(3)])
        
        if diag_sum == 3 or diag_sum2 == 3:
            isEnd = True
            return 1
        elif diag_sum == -3 or diag_sum2 == -3:
            isEnd = True
            return -1
        if len(self.availablePositions()) == 0:
            isEnd = True
            return 0
        self.isEnd = False
        return None
    
    def availablePositions(self):
        ps = []
        for i in range(3):
            for j in range(3):
                if self.board[i, j] == 0:
                    ps.append((i, j))
                
        return ps
    
    def updateState(self, position):
        self.board[position] = self.playerSymbol
        self.playerSymbol = -1 if self.playerSymbol == 1 else 1
        
    def giveReward(self):
        result = self.winner()

        if result == 1:
            self.p1.feedforward(1)
            self.p2.feedforward(0)
        elif result == -1:
            self.p1.feedforward(0)
            self.p2.feedforward(1)
        else:
            self.p1.feedforward(0.1)
            self.p2.feedforward(0.5)

    def reset(self):
        self.board = np.zeros((3, 3))
        self.isEnd = False
        self.playerSymbol = 1

    def play(self, epochs):
        for i in range(epochs):
            while not self.isEnd:
                positions = self.availablePositions()
                p1_action = self.p1.chooseAction(positions, self.board, self.playerSymbol)

                self.updateState(p1_action)
                board_hash = self.getHash()
                self.p1.addState(board_hash)

                win = self.winner()
                if win is not None:
                    self.giveReward()
                    self.p1.reset()
                    self.p2.reset()
                    break

                else:
                    positions = self.availablePositions()
                    p1_action = self.p1.chooseAction(positions, self.board, self.playerSymbol)

                    self.updateState(p1_action)
                    board_hash = self.getHash()
                    self.p2.addState(board_hash)

                    win = self.winner()
                    if win is not None:   
                        self.giveReward()
                        self.p1.reset()
                        self.p2.reset()
                        break
    def play2(self):
        '''
        This function is required for the game play between Human and AI agent.
        '''
        while not self.isEnd:
            positions = self.availablePositions()
            p1_action = self.p1.chooseAction(positions, self.board, self.playerSymbol)

            self.updateState(p1_action)
            self.showBoard()
            
            win = self.winner()
            if win is not None:
                if win== 1:
                    print(self.p1.name, 'win!!')
                else: 
                    print('tie!!')
                self.reset()
                break
            else:
                #player 2
                positions = self.availablePositions()
                p2_action = self.p2.chooseAction(positions)

                self.updateState(p2_action)
                self.showBoard()
            
                win = self.winner()
                if win is not None:
                    if win== -1:
                        print(self.p2.name, 'win!!')
                    else: 
                        print('tie!!')
                    self.reset()
                    break
                
            
        
    def showBoard(self):
        for i in range(0, 3):
            print('----------------')
            out = '|'
            for j in range(0, 3):
                if self.board[i, j] == 1:
                    token = 'X'
                if self.board[i, j]== -1:
                    token = 'O'
                if self.board[i, j] == 0:
                    token = ' '
                out += token + ' | '
            print(out)
        print('------------------')

In [3]:
class Player:
    def __init__(self, name, exp_rate = 0.2):
        self.name = name
        self.states = []
        self.exp_rate = exp_rate
        self.decay_gamma = 0.9
        self.lr = 0.2
        
        self.states_value = {} #state -> value
        
    def getHash(self, board):
        boardHash = str(board.reshape(3*3))
        return boardHash
        
    def chooseAction(self, positions, current_board, symbol):
        if np.random.uniform(0, 1) <= self.exp_rate:
            idx = np.random.choice(len(positions))
            action = positions[idx]
        else:
            value_max = -999
            for p in positions:
                next_board = current_board.copy()
                next_board[p] = symbol
                next_boardHash = self.getHash(next_board)
                value = 0 if self.states_value.get(next_boardHash) is None else self.states_value.get(next_boardHash)
                if value >= value_max:
                    value_max = value
                    action = p
        return action
    
    def feedforward(self, reward):
        for st in reversed(self.states):
            self.states_value[st] = 0
        self.states_value[st] += self.lr*(self.decay_gamma*reward - self.states_value[st])
        reward = self.states_value[st]
            
    def addState(self, state):
        self.states.append(state)
        
    def reset(self):
        self.states= []
    
    def savePolicy(self):
        f = open('policy_' + str(self.name), 'wb')
        pickle.dump(self.states_value, f)
        f.close()
        
    def loadPolicy(self, file):
        f = open(file, 'rb')
        self.states_value = pickle.load(f)
        f.close()

In [4]:
class HumanPlayer:
    def __init__(self, name):
        self.name = name
        
    def chooseAction(self, positions):
        while True:
            row = int(input("Input your action row: "))
            col = int(input("Input your action col: "))
            action = (row, col)
            if action in positions:
                return action
            
            

In [5]:
p1 = Player('p1')
p2 = Player('p2')

st = Environment(p1, p2)
print('training...')
st.play(5)

training...


In [6]:
p1.savePolicy()
p2.savePolicy()

In [7]:
p1 = Player('Computer', exp_rate=0.001)
p1.loadPolicy('policy_p1')

p2 = HumanPlayer('Human')

game = Environment(p1, p2)
game.play2()

----------------
|  |   |   | 
----------------
|  |   |   | 
----------------
|  | X |   | 
------------------
Input your action row: 1
Input your action col: 1
----------------
|  |   |   | 
----------------
|  | O |   | 
----------------
|  | X |   | 
------------------
----------------
|  |   |   | 
----------------
|  | O |   | 
----------------
|  | X | X | 
------------------
Input your action row: 2
Input your action col: 0
----------------
|  |   |   | 
----------------
|  | O |   | 
----------------
|O | X | X | 
------------------
----------------
|  |   |   | 
----------------
|  | O | X | 
----------------
|O | X | X | 
------------------
Input your action row: 0
Input your action col: 0
----------------
|O |   |   | 
----------------
|  | O | X | 
----------------
|O | X | X | 
------------------
----------------
|O |   |   | 
----------------
|X | O | X | 
----------------
|O | X | X | 
------------------
Input your action row: 0
Input your action col: 1
----------------