In [38]:
import numpy as np
import pickle
import os


In [39]:
# Board geometry. Everything below derives its loop bounds and win thresholds
# from these constants.
BOARD_ROWS = 4
BOARD_COLS = 4
BOARD_SIZE = BOARD_ROWS * BOARD_COLS


class State:
    """A single board position.

    ``data`` is a BOARD_ROWS x BOARD_COLS float array in which
    1 marks a piece of the player who moves first, -1 marks a piece of the
    other player and 0 marks an empty cell.

    ``winner`` (1, -1, 0 for a tie, or None while undecided) and ``end`` are
    filled in lazily by :meth:`is_end`; ``hash_val`` is filled in lazily by
    :meth:`hash`.  Instances are treated as immutable once created: moves are
    made through :meth:`next_state`, which returns a new object.
    """

    def __init__(self):
        self.data = np.zeros((BOARD_ROWS, BOARD_COLS))
        self.winner = None
        self.hash_val = None
        self.end = None

    def hash(self):
        """Return a unique integer key for this position (cached).

        The board is read in row-major order as a base-3 number whose digits
        are ``cell + 1`` (so 0, 1 or 2).  The result is a plain Python ``int``
        so that it stays exact for any board size; it compares and hashes
        equal to the float value the same formula would produce.
        """
        if self.hash_val is None:
            digest = 0
            for cell in self.data.flat:
                digest = digest * 3 + int(cell) + 1
            self.hash_val = digest
        return self.hash_val

    def is_end(self):
        """Return True if the game is over in this position (cached).

        A player wins by filling an entire row or column, or (on a square
        board) an entire diagonal.  If nobody has won and no cell is empty the
        game is a tie.  Sets ``self.winner`` to 1, -1 or 0 when the game is
        over and leaves it untouched otherwise.
        """
        if self.end is not None:
            return self.end

        # Each entry is (sum of the cells on a line, number of cells on it).
        # A line belongs entirely to one player exactly when |sum| == length.
        lines = [(np.sum(self.data[r, :]), BOARD_COLS) for r in range(BOARD_ROWS)]
        lines += [(np.sum(self.data[:, c]), BOARD_ROWS) for c in range(BOARD_COLS)]
        if BOARD_ROWS == BOARD_COLS:
            # Full corner-to-corner diagonals only exist on a square board.
            lines.append((np.trace(self.data), BOARD_ROWS))
            lines.append((np.trace(np.fliplr(self.data)), BOARD_ROWS))

        for total, length in lines:
            if total == length:
                self.winner = 1
                self.end = True
                return self.end
            if total == -length:
                self.winner = -1
                self.end = True
                return self.end

        # No completed line: a full board is a tie.
        if not np.any(self.data == 0):
            self.winner = 0
            self.end = True
            return self.end

        # Game is still going on.
        self.end = False
        return self.end

    def next_state(self, i, j, symbol):
        """Return a new State with ``symbol`` (1 or -1) placed at row i, column j.

        The current state is not modified.  No check is made that the target
        cell is empty.
        """
        new_state = State()
        new_state.data = np.copy(self.data)
        new_state.data[i, j] = symbol
        return new_state

    def print_state(self):
        """Print the board and return the printed text.

        First-player pieces are drawn as '*', second-player pieces as 'x' and
        every other cell as '0'.
        """
        rule = '------------------\n'
        glyphs = {1: '*', -1: 'x'}
        pieces = []
        for r in range(BOARD_ROWS):
            cells = ''.join(
                glyphs.get(int(self.data[r, c]), '0') + ' | '
                for c in range(BOARD_COLS)
            )
            pieces.append(rule + '| ' + cells + '\n')
        output_str = ''.join(pieces) + rule
        print(output_str)
        return output_str


In [40]:
def get_all_states_impl(current_state, current_symbol, all_states):
    """Recursively record every position reachable from ``current_state``.

    ``current_symbol`` (1 or -1) is the piece to be placed next.  Each newly
    discovered position is stored in ``all_states`` as
    ``hash -> (state, is_end)``; positions already present are not expanded
    again, and finished games are not expanded at all.
    """
    for row in range(BOARD_ROWS):
        for col in range(BOARD_COLS):
            if current_state.data[row][col] != 0:
                continue  # cell already taken
            successor = current_state.next_state(row, col, current_symbol)
            key = successor.hash()
            if key in all_states:
                continue  # reached earlier through another move order
            finished = successor.is_end()
            all_states[key] = (successor, finished)
            if not finished:
                # Negating the symbol hands the move to the other player.
                get_all_states_impl(successor, -current_symbol, all_states)


def get_all_states():
    """Build and return the table of all positions reachable from an empty board.

    The result maps ``state.hash()`` to ``(state, is_end)``, starting from the
    empty board with the first player (symbol 1) to move.
    """
    root = State()
    table = {root.hash(): (root, root.is_end())}
    get_all_states_impl(root, 1, table)
    return table

# Module-level table of every reachable board configuration,
# hash -> (State, is_end). Judger.play and Player.set_symbol read it by name,
# so this cell must run before either class is used. The full enumeration
# runs here, when the cell is executed.
all_states = get_all_states()

In [41]:
class Judger:
    """Referee that plays games between two players.

    ``player1`` moves first and is assigned symbol 1; ``player2`` is assigned
    symbol -1.  Both players are told their symbol on construction.
    """

    def __init__(self, player1, player2):
        self.p1, self.p2 = player1, player2
        self.current_player = None
        self.p1_symbol, self.p2_symbol = 1, -1
        self.p1.set_symbol(self.p1_symbol)
        self.p2.set_symbol(self.p2_symbol)
        self.current_state = State()

    def reset(self):
        """Clear the per-game history of both players."""
        for contestant in (self.p1, self.p2):
            contestant.reset()

    def alternate(self):
        """Yield the players forever in turn order: p1, p2, p1, p2, ..."""
        while True:
            yield from (self.p1, self.p2)

    def _announce(self, board, show):
        """Hand ``board`` to both players and print it when ``show`` is true."""
        self.p1.set_state(board)
        self.p2.set_state(board)
        if show:
            board.print_state()

    def play(self, print_state=True):
        """Play one game to the end and return the winner (1, -1, or 0 for a tie).

        When ``print_state`` is true, every board of the game is printed.
        Each move is resolved through the global ``all_states`` table, so the
        states handed to the players are the shared, pre-evaluated instances.
        """
        turn_order = self.alternate()
        self.reset()
        board = State()
        self._announce(board, print_state)
        while True:
            mover = next(turn_order)
            row, col, mark = mover.act()
            key = board.next_state(row, col, mark).hash()
            board, finished = all_states[key]
            self._announce(board, print_state)
            if finished:
                return board.winner


In [42]:
# AI player
class Player:
    """Learning agent that keeps one value estimate per board position.

    ``step_size`` scales each value update, ``epsilon`` is the probability of
    playing a random move instead of the best-valued one, and ``epochs`` is
    only used to name the policy file written/read by save_policy/load_policy.
    """

    def __init__(self, step_size=0.1, epsilon=0.1, epochs=int(1e5)):
        self.estimations = dict()
        self.step_size = step_size
        self.epsilon = epsilon
        self.states = []
        self.greedy = []
        self.symbol = 0
        self.epochs = epochs

    def reset(self):
        """Forget the positions and greedy flags recorded for the last game."""
        self.states, self.greedy = [], []

    def set_state(self, state):
        """Record ``state`` as the latest position; it starts out flagged greedy."""
        self.states.append(state)
        self.greedy.append(True)

    def set_symbol(self, symbol):
        """Assign this player's piece and initialise every value estimate.

        Finished positions get 1.0 for a win, 0.5 for a tie (so that a tie is
        preferred to a loss) and 0 for a loss; unfinished positions get 0.5.
        """
        self.symbol = symbol
        for key, (board, finished) in all_states.items():
            if not finished:
                value = 0.5
            elif board.winner == self.symbol:
                value = 1.0
            elif board.winner == 0:
                value = 0.5
            else:
                value = 0
            self.estimations[key] = value

    def backup(self):
        """Update value estimates from the recorded game, last position first.

        Each position moves towards the estimate of its successor by
        ``step_size`` times the difference; positions whose greedy flag is
        False contribute a zero update.
        """
        keys = [visited.hash() for visited in self.states]
        for idx in range(len(keys) - 2, -1, -1):
            here, after = keys[idx], keys[idx + 1]
            td_error = self.greedy[idx] * (
                self.estimations[after] - self.estimations[here]
            )
            self.estimations[here] += self.step_size * td_error

    def act(self):
        """Choose a move for the latest recorded position.

        Returns ``[row, col, symbol]``.  With probability ``epsilon`` a random
        empty cell is chosen and the current position is flagged non-greedy;
        otherwise the move leading to the highest-valued position is chosen,
        with ties broken at random.
        """
        board = self.states[-1]
        open_cells = [
            [row, col]
            for row in range(BOARD_ROWS)
            for col in range(BOARD_COLS)
            if board.data[row, col] == 0
        ]
        successor_keys = [
            board.next_state(row, col, self.symbol).hash()
            for row, col in open_cells
        ]

        if np.random.rand() < self.epsilon:
            move = open_cells[np.random.randint(len(open_cells))]
            move.append(self.symbol)
            self.greedy[-1] = False
            return move

        ranked = [
            (self.estimations[key], cell)
            for key, cell in zip(successor_keys, open_cells)
        ]
        # Shuffling first and relying on the stable sort picks uniformly
        # among moves of equal value.
        np.random.shuffle(ranked)
        ranked.sort(key=lambda pair: pair[0], reverse=True)
        move = ranked[0][1]
        move.append(self.symbol)
        return move

    def _policy_filename(self):
        """Return the file name used to store this player's estimates."""
        return 'policy_%s_%d.bin' % ('first' if self.symbol == 1 else 'second', self.epochs)

    def save_policy(self):
        """Pickle the value estimates to the policy file in the working directory."""
        with open(self._policy_filename(), 'wb') as f:
            pickle.dump(self.estimations, f)

    def load_policy(self):
        """Replace the value estimates with those stored in the policy file."""
        # NOTE: unpickling can execute arbitrary code; only load policy files
        # that were produced by save_policy on a trusted machine.
        with open(self._policy_filename(), 'rb') as f:
            self.estimations = pickle.load(f)



In [53]:
def train(epochs, print_every_n=500):
    """Train two players against each other for ``epochs`` games and save both policies.

    Both players explore with epsilon 0.01.  Every ``print_every_n`` games the
    cumulative win rates are printed.  After each game both players update
    their value estimates; at the end each writes its policy file.
    """
    first = Player(epsilon=0.01, epochs=epochs)
    second = Player(epsilon=0.01, epochs=epochs)
    referee = Judger(first, second)
    wins = {1: 0.0, -1: 0.0}  # keyed by the winner value returned by play()
    for game in range(1, epochs + 1):
        outcome = referee.play(print_state=False)
        if outcome in wins:
            wins[outcome] += 1
        if game % print_every_n == 0:
            print('Epoch %d, player 1 winrate: %.02f, player 2 winrate: %.02f' % (game, wins[1] / game, wins[-1] / game))
        first.backup()
        second.backup()
        referee.reset()
    first.save_policy()
    second.save_policy()


def compete(turns, epochs1, epochs2):
    """Play ``turns`` printed games between two saved policies and report the results.

    Player 1 loads the first-mover policy saved for ``epochs1`` games and
    player 2 the second-mover policy saved for ``epochs2`` games.  Both play
    without exploration (epsilon 0).
    """
    first = Player(epsilon=0, epochs=epochs1)
    second = Player(epsilon=0, epochs=epochs2)
    referee = Judger(first, second)
    first.load_policy()
    second.load_policy()
    wins = {1: 0.0, -1: 0.0}  # keyed by the winner value returned by play()
    for game in range(turns):
        print("------------- %s -------------" % ("TURN " + str(game + 1)))
        outcome = referee.play()
        if outcome in wins:
            wins[outcome] += 1
        referee.reset()
    print('%d turns, player 1 win %.02f, player 2 win %.02f, tie %.02f'
          % (turns, wins[1] / turns, wins[-1] / turns, (turns - wins[1] - wins[-1]) / turns))


In [44]:
# Train and save three policy pairs, from 1000, 10000 and 100000 self-play games.
train(1000)
train(10000)
train(100000)

Epoch 500, player 1 winrate: 0.47, player 2 winrate: 0.28
Epoch 1000, player 1 winrate: 0.48, player 2 winrate: 0.25
Epoch 500, player 1 winrate: 0.45, player 2 winrate: 0.35
Epoch 1000, player 1 winrate: 0.44, player 2 winrate: 0.34
Epoch 1500, player 1 winrate: 0.47, player 2 winrate: 0.30
Epoch 2000, player 1 winrate: 0.48, player 2 winrate: 0.27
Epoch 2500, player 1 winrate: 0.49, player 2 winrate: 0.25
Epoch 3000, player 1 winrate: 0.43, player 2 winrate: 0.22
Epoch 3500, player 1 winrate: 0.42, player 2 winrate: 0.21
Epoch 4000, player 1 winrate: 0.41, player 2 winrate: 0.21
Epoch 4500, player 1 winrate: 0.37, player 2 winrate: 0.19
Epoch 5000, player 1 winrate: 0.37, player 2 winrate: 0.18
Epoch 5500, player 1 winrate: 0.38, player 2 winrate: 0.18
Epoch 6000, player 1 winrate: 0.38, player 2 winrate: 0.19
Epoch 6500, player 1 winrate: 0.37, player 2 winrate: 0.18
Epoch 7000, player 1 winrate: 0.36, player 2 winrate: 0.17
Epoch 7500, player 1 winrate: 0.34, player 2 winrate: 0.16

In [54]:
# 10 games: first-mover policy from the 10000-game run vs second-mover policy from the 1000-game run.
compete(turns=10, epochs1=10000, epochs2=1000)

------------- TURN 1 -------------
------------------
| 0 | 0 | 0 | 0 | 
------------------
| 0 | 0 | 0 | 0 | 
------------------
| 0 | 0 | 0 | 0 | 
------------------
| 0 | 0 | 0 | 0 | 
------------------

------------------
| 0 | 0 | 0 | 0 | 
------------------
| 0 | 0 | 0 | 0 | 
------------------
| 0 | * | 0 | 0 | 
------------------
| 0 | 0 | 0 | 0 | 
------------------

------------------
| 0 | 0 | 0 | 0 | 
------------------
| 0 | 0 | 0 | 0 | 
------------------
| 0 | * | x | 0 | 
------------------
| 0 | 0 | 0 | 0 | 
------------------

------------------
| 0 | 0 | 0 | * | 
------------------
| 0 | 0 | 0 | 0 | 
------------------
| 0 | * | x | 0 | 
------------------
| 0 | 0 | 0 | 0 | 
------------------

------------------
| 0 | x | 0 | * | 
------------------
| 0 | 0 | 0 | 0 | 
------------------
| 0 | * | x | 0 | 
------------------
| 0 | 0 | 0 | 0 | 
------------------

------------------
| 0 | x | 0 | * | 
------------------
| 0 | 0 | * | 0 | 
------------------
| 0 | * | 

In [46]:
# 10 games: first-mover policy from the 1000-game run vs second-mover policy from the 10000-game run.
compete(turns=10, epochs1=1000, epochs2=10000)

------------- TURN 1 -------------
------------------
| 0 | 0 | 0 | 0 | 
------------------
| 0 | 0 | 0 | 0 | 
------------------
| 0 | 0 | 0 | 0 | 
------------------
| 0 | 0 | 0 | 0 | 
------------------

------------------
| 0 | 0 | 0 | 0 | 
------------------
| 0 | 0 | 0 | 0 | 
------------------
| 0 | * | 0 | 0 | 
------------------
| 0 | 0 | 0 | 0 | 
------------------

------------------
| 0 | 0 | x | 0 | 
------------------
| 0 | 0 | 0 | 0 | 
------------------
| 0 | * | 0 | 0 | 
------------------
| 0 | 0 | 0 | 0 | 
------------------

------------------
| * | 0 | x | 0 | 
------------------
| 0 | 0 | 0 | 0 | 
------------------
| 0 | * | 0 | 0 | 
------------------
| 0 | 0 | 0 | 0 | 
------------------

------------------
| * | 0 | x | 0 | 
------------------
| 0 | 0 | 0 | x | 
------------------
| 0 | * | 0 | 0 | 
------------------
| 0 | 0 | 0 | 0 | 
------------------

------------------
| * | 0 | x | 0 | 
------------------
| 0 | 0 | 0 | x | 
------------------
| 0 | * | 

In [47]:
# 10 games: first-mover policy from the 100000-game run vs second-mover policy from the 10000-game run.
compete(turns=10, epochs1=100000, epochs2=10000)

------------- TURN 1 -------------
------------------
| 0 | 0 | 0 | 0 | 
------------------
| 0 | 0 | 0 | 0 | 
------------------
| 0 | 0 | 0 | 0 | 
------------------
| 0 | 0 | 0 | 0 | 
------------------

------------------
| 0 | 0 | 0 | 0 | 
------------------
| 0 | 0 | 0 | 0 | 
------------------
| 0 | * | 0 | 0 | 
------------------
| 0 | 0 | 0 | 0 | 
------------------

------------------
| 0 | 0 | x | 0 | 
------------------
| 0 | 0 | 0 | 0 | 
------------------
| 0 | * | 0 | 0 | 
------------------
| 0 | 0 | 0 | 0 | 
------------------

------------------
| 0 | 0 | x | 0 | 
------------------
| 0 | 0 | 0 | 0 | 
------------------
| 0 | * | 0 | * | 
------------------
| 0 | 0 | 0 | 0 | 
------------------

------------------
| 0 | 0 | x | 0 | 
------------------
| 0 | 0 | 0 | 0 | 
------------------
| 0 | * | x | * | 
------------------
| 0 | 0 | 0 | 0 | 
------------------

------------------
| 0 | 0 | x | 0 | 
------------------
| 0 | 0 | * | 0 | 
------------------
| 0 | * | 

In [48]:
# 10 games: first-mover policy from the 10000-game run vs second-mover policy from the 100000-game run.
compete(turns=10, epochs1=10000, epochs2=100000)

------------- TURN 1 -------------
------------------
| 0 | 0 | 0 | 0 | 
------------------
| 0 | 0 | 0 | 0 | 
------------------
| 0 | 0 | 0 | 0 | 
------------------
| 0 | 0 | 0 | 0 | 
------------------

------------------
| 0 | 0 | 0 | 0 | 
------------------
| 0 | 0 | 0 | 0 | 
------------------
| 0 | * | 0 | 0 | 
------------------
| 0 | 0 | 0 | 0 | 
------------------

------------------
| 0 | 0 | 0 | 0 | 
------------------
| 0 | 0 | 0 | 0 | 
------------------
| x | * | 0 | 0 | 
------------------
| 0 | 0 | 0 | 0 | 
------------------

------------------
| 0 | 0 | 0 | 0 | 
------------------
| 0 | 0 | 0 | 0 | 
------------------
| x | * | 0 | 0 | 
------------------
| * | 0 | 0 | 0 | 
------------------

------------------
| 0 | 0 | 0 | 0 | 
------------------
| 0 | x | 0 | 0 | 
------------------
| x | * | 0 | 0 | 
------------------
| * | 0 | 0 | 0 | 
------------------

------------------
| 0 | 0 | 0 | 0 | 
------------------
| 0 | x | 0 | 0 | 
------------------
| x | * | 

In [49]:
# 10 games: first-mover policy from the 1000-game run vs second-mover policy from the 100000-game run.
compete(turns=10, epochs1=1000, epochs2=100000)

------------- TURN 1 -------------
------------------
| 0 | 0 | 0 | 0 | 
------------------
| 0 | 0 | 0 | 0 | 
------------------
| 0 | 0 | 0 | 0 | 
------------------
| 0 | 0 | 0 | 0 | 
------------------

------------------
| 0 | 0 | 0 | 0 | 
------------------
| 0 | 0 | 0 | 0 | 
------------------
| 0 | * | 0 | 0 | 
------------------
| 0 | 0 | 0 | 0 | 
------------------

------------------
| 0 | 0 | 0 | 0 | 
------------------
| 0 | 0 | 0 | 0 | 
------------------
| x | * | 0 | 0 | 
------------------
| 0 | 0 | 0 | 0 | 
------------------

------------------
| * | 0 | 0 | 0 | 
------------------
| 0 | 0 | 0 | 0 | 
------------------
| x | * | 0 | 0 | 
------------------
| 0 | 0 | 0 | 0 | 
------------------

------------------
| * | 0 | 0 | 0 | 
------------------
| 0 | x | 0 | 0 | 
------------------
| x | * | 0 | 0 | 
------------------
| 0 | 0 | 0 | 0 | 
------------------

------------------
| * | 0 | * | 0 | 
------------------
| 0 | x | 0 | 0 | 
------------------
| x | * | 

In [50]:
# 10 games: first-mover policy from the 100000-game run vs second-mover policy from the 1000-game run.
compete(turns=10, epochs1=100000, epochs2=1000)

------------- TURN 1 -------------
------------------
| 0 | 0 | 0 | 0 | 
------------------
| 0 | 0 | 0 | 0 | 
------------------
| 0 | 0 | 0 | 0 | 
------------------
| 0 | 0 | 0 | 0 | 
------------------

------------------
| 0 | 0 | 0 | 0 | 
------------------
| 0 | 0 | 0 | 0 | 
------------------
| 0 | * | 0 | 0 | 
------------------
| 0 | 0 | 0 | 0 | 
------------------

------------------
| 0 | 0 | 0 | 0 | 
------------------
| 0 | 0 | 0 | 0 | 
------------------
| 0 | * | x | 0 | 
------------------
| 0 | 0 | 0 | 0 | 
------------------

------------------
| 0 | 0 | 0 | 0 | 
------------------
| 0 | 0 | 0 | 0 | 
------------------
| 0 | * | x | 0 | 
------------------
| * | 0 | 0 | 0 | 
------------------

------------------
| 0 | x | 0 | 0 | 
------------------
| 0 | 0 | 0 | 0 | 
------------------
| 0 | * | x | 0 | 
------------------
| * | 0 | 0 | 0 | 
------------------

------------------
| 0 | x | 0 | 0 | 
------------------
| 0 | 0 | * | 0 | 
------------------
| 0 | * | 