In [1]:
import torch 
import pickle
# Prefer Apple's Metal (MPS) backend when it is available, otherwise use the CPU.
# NOTE(review): `device` is not referenced anywhere else in this notebook; the
# tic-tac-toe agent below is pure numpy.
device = torch.device("mps" if torch.backends.mps.is_available() else 'cpu')

print(device)

mps


### Steps

**Game state (`Board`)**
1. Initialize everything.
2. Get the board state and generate a hash for each new board.
3. Check the available positions.
4. Update the state with the chosen position.
5. Check for a winner.
6. Give rewards.

**Player setup (`Player`)**
7. Initialize everything.
8. Choose an action.

- State-value update
- Training
- Saving and loading the "policy"
- Inference: human vs. computer



In [128]:
import numpy as np
# Board dimensions (rows x columns). Used for the board array shape and for
# flattening boards into the string hashes that key the learned policy.
BOARD_ROWS = 3
BOARD_COLS = 3


class Board:
    """Tic-tac-toe game state plus the loops that drive two players.

    The board is a float array of shape (BOARD_ROWS, BOARD_COLS) where 0 is an
    empty cell, 1 is player1's mark ("x") and -1 is player2's mark ("o").
    Player objects must provide ``choosePosition``, ``addState``,
    ``feedReward`` and ``reset``; ``play_human`` additionally uses ``name``.
    """

    def __init__(self, player1, player2):
        # Use the constructor arguments, never notebook-level globals.
        self.player1 = player1
        self.player2 = player2
        self.currSymbol = 1  # player1 always moves first

        self.board = np.zeros((BOARD_ROWS, BOARD_COLS))
        self.boardHash = None
        self.isEnd = False

    def generateHash(self):
        """Return (and cache in ``boardHash``) a string key for the current board.

        The format must stay identical to ``Player.getHash`` because learned
        policies are dictionaries keyed by these strings.
        """
        self.boardHash = str(self.board.reshape(BOARD_ROWS * BOARD_COLS))
        return self.boardHash

    def availablePositions(self):
        """Return the (row, col) tuples of all empty cells, in row-major order."""
        return [
            (i, j)
            for i in range(BOARD_ROWS)
            for j in range(BOARD_COLS)
            if self.board[i, j] == 0
        ]

    def checkWinner(self):
        """Return the game result and update ``isEnd`` accordingly.

        Returns:
            1 if player1 fills a whole row, column or diagonal, -1 if player2
            does, 0 for a tie (board full, no complete line), and None while
            the game is still in progress.
        """
        lines = [self.board[i, :] for i in range(BOARD_ROWS)]
        lines += [self.board[:, j] for j in range(BOARD_COLS)]
        lines.append(np.diagonal(self.board))  # top-left -> bottom-right
        # Anti-diagonal (top-right -> bottom-left), i.e. board[i][n - 1 - i].
        lines.append(np.diagonal(np.fliplr(self.board)))

        for line in lines:
            # A line is won when every cell holds the same non-zero symbol.
            total = line.sum()
            if total == len(line):
                self.isEnd = True
                return 1
            if total == -len(line):
                self.isEnd = True
                return -1

        # tie?
        if not self.availablePositions():
            self.isEnd = True
            return 0
        self.isEnd = False
        return None

    def updateState(self, position):
        """Place the current symbol at ``position`` and hand the turn over."""
        self.board[position] = self.currSymbol
        self.currSymbol = -self.currSymbol  # symbols are always 1 or -1

    def allocateReward(self):
        """Feed both players their terminal reward based on ``checkWinner()``.

        A win gives the winner 1 and the loser 0. Any other result (a tie, or
        None if called mid-game) gives player1 0.2 and player2 0.5.
        """
        result = self.checkWinner()
        if result == 1:
            p1_reward, p2_reward = 1, 0
        elif result == -1:
            p1_reward, p2_reward = 0, 1
        else:
            # NOTE(review): ties reward player2 (0.5) more than player1 (0.2);
            # the intent is not documented -- confirm.
            p1_reward, p2_reward = 0.2, 0.5
        self.player1.feedReward(p1_reward)
        self.player2.feedReward(p2_reward)

    def resetBoard(self):
        """Clear the board so a new game can start with player1 to move."""
        self.isEnd = False
        self.board = np.zeros((BOARD_ROWS, BOARD_COLS))
        self.currSymbol = 1
        self.boardHash = None

    def _takeTurn(self):
        """Let the player owning ``currSymbol`` choose a move and apply it.

        Returns the player who just moved.
        """
        player = self.player1 if self.currSymbol == 1 else self.player2
        position = player.choosePosition(
            self.availablePositions(), self.board, self.currSymbol
        )
        self.updateState(position)
        return player

    def play_computer(self, num_round=100):
        """Train the two players against each other for ``num_round`` games.

        After every move the mover records the resulting board hash. When a
        game ends both players are rewarded, their histories are cleared and
        the board is reset.
        """
        for i in range(num_round):
            if i % 500 == 0:
                print("round: {}".format(i))
            while not self.isEnd:
                player = self._takeTurn()
                player.addState(self.generateHash())

                if self.checkWinner() is not None:
                    # someone won or game tied
                    self.allocateReward()
                    self.player1.reset()
                    self.player2.reset()
                    self.resetBoard()
                    break

    def play_human(self, num_round=100):
        """Play a single interactive game, printing the board after each move.

        ``num_round`` is accepted for signature compatibility but is unused.
        No rewards are fed and no states are recorded.
        """
        while not self.isEnd:
            self._takeTurn()
            self.showBoard()

            win = self.checkWinner()
            if win is not None:
                if win == 1:
                    print(self.player1.name, "wins")
                elif win == -1:
                    print(self.player2.name, "wins")
                else:
                    print("tie")
                self.resetBoard()
                break

    def showBoard(self):
        """Print the board with "x" for player1, "o" for player2, blank if empty."""
        tokens = {1: "x", -1: "o"}
        string = ""
        for i in range(BOARD_ROWS):
            string += "\n---------------\n | "
            for j in range(BOARD_COLS):
                string += tokens.get(self.board[i, j], " ") + " | "
        string += "\n---------------"
        print(string)
    

In [132]:
class Player:
    """Tabular value-learning agent for tic-tac-toe.

    ``state_values`` maps board-hash strings to estimated values. During a
    game the agent records the hashes of the boards it produced
    (``addState``); at the end ``feedReward`` propagates the terminal reward
    backwards through that history.
    """

    def __init__(self, name, random_selection= 0.3):
        """
        Args:
            name: player name; also used for the default policy filename.
            random_selection: probability of playing a uniformly random move
                instead of the greedy one (exploration rate).
        """
        self.name = name
        self.random_selection = random_selection
        self.states = []  # board hashes produced by this player in the current game
        self.lr = 0.2  # step size of the value update
        self.gamma_decay = 0.9  # discount applied per step when back-propagating
        self.state_values = {}

    def getHash(self, board):
        """Return the string key for ``board`` (same format as ``Board.generateHash``)."""
        return str(board.reshape(-1))

    def choosePosition(self, positions, current_board, symbol):
        """Pick a move from ``positions`` using an epsilon-greedy rule.

        With probability ``random_selection`` a uniformly random position is
        returned. Otherwise the position whose resulting board has the highest
        stored value is returned. Unseen boards count as 0, and ties go to the
        earliest position in the list.

        Args:
            positions: list of (row, col) tuples of empty cells.
            current_board: current board array; it is copied, never modified.
            symbol: the mark this player places.

        Raises:
            ValueError: if ``positions`` is empty.
        """
        if not positions:
            raise ValueError("no available positions to choose from")

        # Strict `<` so that random_selection=0.0 is always greedy.
        if np.random.uniform(0, 1) < self.random_selection:
            return positions[np.random.choice(len(positions))]

        best_position, best_value = None, -np.inf
        for pos in positions:
            next_board = current_board.copy()
            next_board[pos] = symbol
            value = self.state_values.get(self.getHash(next_board), 0)
            if value > best_value:
                best_value, best_position = value, pos
        return best_position

    def addState(self, state):
        """Record a board hash produced by this player in the current game."""
        self.states.append(state)

    def feedReward(self, reward):
        """Back-propagate a terminal ``reward`` through this game's states.

        Walking from the last recorded state to the first, each value moves
        toward the discounted value of its successor:

            V(s) <- V(s) + lr * (gamma * target - V(s))

        where ``target`` is ``reward`` for the last state and, for every
        earlier state, the freshly updated value of the state after it.
        """
        target = reward
        for state in reversed(self.states):
            value = self.state_values.get(state, 0)
            value += self.lr * (self.gamma_decay * target - value)
            self.state_values[state] = value
            target = value

    def reset(self):
        """Forget the current game's state history (learned values are kept)."""
        self.states = []

    def save_policy(self, file=None):
        """Pickle ``state_values`` to ``file`` (default: ``'policy_' + name``)."""
        if file is None:
            file = 'policy_' + str(self.name)
        with open(file, 'wb') as f:
            pickle.dump(self.state_values, f)

    def load_policy(self, file):
        """Replace ``state_values`` with the dictionary pickled in ``file``.

        Security: unpickling can execute arbitrary code, so only load policy
        files you created yourself.
        """
        with open(file, 'rb') as f:
            self.state_values = pickle.load(f)

        
    

In [133]:
class HumanPlayer:
    """Player whose moves are typed in by a human on standard input.

    Exposes the same methods ``Board`` calls on a ``Player``; the learning
    hooks (``addState``, ``feedReward``, ``reset``) do nothing.
    """

    def __init__(self, name):
        self.name = name

    def choosePosition(self, positions, current_board, symbol):
        """Prompt until the user enters a (row, col) contained in ``positions``.

        Non-integer input and cells not in ``positions`` are rejected with a
        message, and the user is asked again. ``current_board`` and ``symbol``
        are unused.
        """
        while True:
            try:
                row = int(input("Input row number: "))
                col = int(input("Input col number: "))
            except ValueError:
                print("Please enter whole numbers.")
                continue
            selection = (row, col)
            if selection in positions:
                return selection
            print("Position {} is not available; choose one of {}.".format(selection, positions))

    def addState(self, state):
        pass  # humans do not learn from recorded states

    def feedReward(self, reward):
        pass

    def reset(self):
        pass

In [135]:
# Self-play training: both agents explore with the default random_selection.
# Seed numpy's global RNG (used by Player.choosePosition) so training is reproducible.
SEED = 42
NUM_TRAINING_ROUNDS = 10000
np.random.seed(SEED)

p1 = Player("p1")
p2 = Player("p2")

st = Board(p1, p2)
print("Training .... ")
st.play_computer(NUM_TRAINING_ROUNDS)
p1.save_policy()
print("Policy saved for p1")


Training .... 
round: 0
round: 500
round: 1000
round: 1500
round: 2000
round: 2500
round: 3000
round: 3500
round: 4000
round: 4500
round: 5000
round: 5500
round: 6000
round: 6500
round: 7000
round: 7500
round: 8000
round: 8500
round: 9000
round: 9500


In [136]:

# Human vs. computer: the human plays second ("o"); the trained p1 policy moves first.
p2 = HumanPlayer("Human")
p1 = Player("p1", random_selection=0.0)  # purely greedy: no exploration at inference
p1.load_policy("policy_p1")

st = Board(p1, p2)
st.play_human(num_round=1)


---------------
 |   |   |   | 
---------------
 |   |   |   | 
---------------
 |   | x |   | 
---------------
Input row number: 1
Input col number: 1

---------------
 |   |   |   | 
---------------
 |   | o |   | 
---------------
 |   | x |   | 
---------------

---------------
 |   |   |   | 
---------------
 |   | o |   | 
---------------
 |   | x | x | 
---------------
Input row number: 2
Input col number: 0

---------------
 |   |   |   | 
---------------
 |   | o |   | 
---------------
 | o | x | x | 
---------------

---------------
 |   |   | x | 
---------------
 |   | o |   | 
---------------
 | o | x | x | 
---------------
Input row number: 1
Input col number: 0

---------------
 |   |   | x | 
---------------
 | o | o |   | 
---------------
 | o | x | x | 
---------------

---------------
 |   |   | x | 
---------------
 | o | o | x | 
---------------
 | o | x | x | 
---------------
p1 wins
