In [1]:
import numpy as np
import pyspiel
import copy

BOARD_SIZE = 3
game = pyspiel.load_game("hex",{"board_size":BOARD_SIZE})
BLACK, WHITE = 1, -1  # first turn or second turn player

class State:
    '''Board implementation of BOARD_SIZE x BOARD_SIZE Hex Board'''
    X, Y = 'ABCDEFGHI'[0:BOARD_SIZE],  '123456789'[0:BOARD_SIZE]
    C = {0: '_', BLACK: 'O', WHITE: 'X'}

    def __init__(self):
        self.board = np.zeros((BOARD_SIZE, BOARD_SIZE)) # (x, y)
        self.color = 1
        self.win_color = 0
        self.record = []
        self.hex_state = game.new_initial_state()

    def action2str(self, a: int):
        return self.X[a // BOARD_SIZE] + self.Y[a % BOARD_SIZE]

    def str2action(self, s: str):
        return self.X.find(s[0]) * BOARD_SIZE + self.Y.find(s[1])

    def record_string(self):
        return ' '.join([self.action2str(a) for a in self.record])
    
    def __deepcopy__(self):
        newState = State()
        newState.board = copy.deepcopy(self.board)
        newState.win_color = copy.deepcopy(self.win_color)
        newState.record = copy.deepcopy(self.record)
        newState.hex_state = copy.deepcopy(self.hex_state)
        return newState

    def __str__(self):
        final_bd = [" "+" ".join(self.Y)]
        hex_bd = str(self.hex_state).split("\n")
        for i in range(len(hex_bd)):
            final_bd.append(self.X[i]+" "+hex_bd[i])
        return "\n".join(final_bd)

    def dry_play(self, action):
        freshState = self.__deepcopy__()
        freshState.play(action)
        return freshState

    def play(self, action):
        # state transition function
        # action is position interger (0~8) or string representation of action sequence
        # Handles the case where action is sequence of actions "0 1 2 3 4"
        if isinstance(action, str):
            for astr in action.split():
                self.play(self.str2action(astr))
            return self

        # Single action case
        x, y = action // BOARD_SIZE, action % BOARD_SIZE
        self.board[x, y] = self.color
        self.hex_state.apply_action(action)

        # check whether 3 stones are on the line
        if self.hex_state.is_terminal():
            self.win_color = self.color

        self.color = -self.color
        self.record.append(action)
        return self

    def terminal(self):
        # terminal state check
        return self.hex_state.is_terminal()

    def terminal_reward(self):
        # terminal reward 
        # return self.win_color if self.color == BLACK else -self.win_color
        return self.win_color

    def legal_actions(self):
        # list of legal actions on each state
        return [a for a in range(BOARD_SIZE * BOARD_SIZE) if self.board[a // BOARD_SIZE, a % BOARD_SIZE] == 0]

    def feature(self):
        # input tensor for neural net (state)
        # return np.stack([self.board == self.color, self.board == -self.color]).astype(np.float32)
        observation =  np.array(self.hex_state.observation_tensor(), np.float32)
        return observation.reshape(9,BOARD_SIZE,BOARD_SIZE)

    def action_feature(self, action):
        # input tensor for neural net (action)
        a = np.zeros((1, BOARD_SIZE, BOARD_SIZE), dtype=np.float32)
        a[0, action // BOARD_SIZE, action % BOARD_SIZE] = 1
        return a

In [12]:
from collections import defaultdict
import math

class MiniMaxAgent:
    def __init__(self):
        self.result = defaultdict(lambda x: defaultdict(int))
    
    def think(self, state: State, sim_num: int, temperature:int, show=False):
        return self.minimax(state, BOARD_SIZE*BOARD_SIZE, -math.inf, math.inf, state.color==1)

    def minimax(self, state, depth, alpha, beta, maximizing_player):

        if depth == 0 or state.terminal():
            return None, state.terminal_reward()

        children = state.legal_actions()
        best_move = children[0]
        
        if maximizing_player:
            max_eval = -math.inf        
            for child in children:
                freshState = state.__deepcopy__()
                freshState.play(child)
                current_eval = self.minimax(freshState, depth - 1, alpha, beta, False)[1]
                if current_eval > max_eval:
                    max_eval = current_eval
                    best_move = child
                alpha = max(alpha, current_eval)
                if beta <= alpha:
                    break
            return best_move, max_eval

        else:
            min_eval = math.inf
            for child in children:
                freshState = state.__deepcopy__()
                freshState.play(child)
                current_eval = self.minimax(freshState, depth - 1, alpha, beta, True)[1]
                if current_eval < min_eval:
                    min_eval = current_eval
                    best_move = child
                beta = min(beta, current_eval)
                if beta <= alpha:
                    break
            return best_move, min_eval

    def my_minimax(self, state: State):
        if str(state) in self.result:
            return self.result[str(state)]

        if state.terminal():
            self.result[str(state)] = state.terminal_reward()
            return state.terminal_reward()
        else:
            results = []
            for a in state.legal_actions():
                freshState = state.__deepcopy__()
                freshState.play(a)
                value = self.my_minimax(freshState)
                self.result[str(freshState)] = value
                results.append((a, value))
            if state.color==1:
                maxPair = max(results, key=lambda x: x[1])
                return maxPair[0]
            else:
                minPair = min(results, key=lambda x: x[1])
                return minPair[0]

    def pv(self, state):
        nextStates = [(a,self.result[str(state.dry_play(a))]) for a in state.legal_actions()]
        return [max(nextStates, key=lambda x: x[1])[1] ]


In [13]:
# Search with trained net

# tree = Tree(net)
tree = MiniMaxAgent()
state = State()
while True:  
  move, eval = tree.think(state, 5000, temperature=1, show=True)
  # pv_seq = tree.pv(state)
  # print(pv_seq)
  state.play(move)
  # display_heatmap2(distb.reshape((BOARD_SIZE,BOARD_SIZE)))
  print(state)
  if state.terminal():
    break
  noCorrectMove = True
  while noCorrectMove:
    user_input = input("Input move: ")
    if state.str2action(user_input) in state.legal_actions():
      noCorrectMove = False
  state.play(user_input)
  if state.terminal():
    break
print(state)
print(state.terminal_reward())

 1 2 3
A y . . 
B  . . . 
C   . . . 
 1 2 3
A y y . 
B  . o . 
C   . . . 
 1 2 3
A y y y 
B  p p . 
C   . . . 
 1 2 3
A y y y 
B  p p y 
C   . . q 
 1 2 3
A y y y 
B  p p y 
C   . O q 
-1


In [None]:

gl = [(1,2),(3,4),(5,6)]
f = lambda x: x[0]
print(max(gl, key=lambda x: x[1]))
print(f((3,4)))

(5, 6)
3
