In [1]:
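# Super tic-tac-toe. Win with 4 in a row or column, or 5 on a diagonal.
# The board is 5 x (4x4) squares laid out as a plus sign (up, left, center, right, down).
# A chosen move is accepted as-is with probability 1/2; otherwise it lands on one of the
# 8 neighbouring squares (probability 1/16 each) and is forfeited if that square is
# occupied or off the board. The agent is a DQN implemented with TensorFlow/Keras.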


In [2]:
# import
import numpy as np
import tensorflow as tf
import random
from tensorflow.keras.initializers import RandomNormal, GlorotNormal, Zeros
import tqdm

In [None]:

# tictactoe env
class SuperTicTacToeEnv:
    """
    Environment for the Super Tic-Tac-Toe game.

    The game is played on 5 boards, each a 4x4 grid, arranged as a plus sign
    (up, left, center, right, down). Two players take turns placing their marks
    (1 or -1). A player wins with 4 in a row or column, or 5 on a diagonal;
    lines are checked on the assembled 12x12 grid, so they may span adjacent
    boards. If all squares are filled and no player has won, the game is a draw.

    Moves are noisy: a move onto an empty square is accepted as chosen with
    probability 1/2; otherwise it is shifted to a uniformly chosen neighbouring
    square and is forfeited when that square is occupied or off the board.

    Attributes:
        _state (np.ndarray): The current state of the game, shape (5, 4, 4).
        _current_player (int): The current player (1 or -1).
        _num_actions (int): The number of possible actions (5 * 4 * 4).
        _episode_ended (bool): Whether the episode has ended.
        nearbys (list): (drow, dcol) offsets of the 8 neighbouring squares.
        board_layout (dict): Board index -> (row, col) block position in the plus
            layout. Informational only; not read by the code.
    """
    def __init__(self):
        """
        Initializes the SuperTicTacToeEnv with empty boards and player 1 to move.
        """
        self._num_actions = 5 * 4 * 4
        self.nearbys = [(-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1), (1, -1), (1, 0), (1, 1)]
        self.board_layout = {  # not used, just for understand
            0: (-1, 0),  # up
            1: (0, -1),  # left
            2: (0, 0),   # center
            3: (0, 1),   # right
            4: (1, 0)    # down
        }
        self._state = np.zeros((5, 4, 4), dtype=np.int32)  # 5 boards, of which each 4x4
        self._current_player = 1  # two players, represented by 1 and -1
        self._episode_ended = False

    def reset(self):
        """
        Resets the environment to the initial state.

        Returns:
            np.ndarray: The initial state of the game (a fresh all-zero array).
        """
        self._state = np.zeros((5, 4, 4), dtype=np.int32)
        self._current_player = 1
        self._episode_ended = False
        return self._state

    def step(self, action, verbose=False):
        """
        Takes a step in the environment for the current player.

        Args:
            action (int): Flat square index in [0, 80): board * 16 + row * 4 + col.
            verbose (bool, optional): Whether to print debugging information. Defaults to False.

        Returns:
            tuple: (state, reward, episode_ended). The reward is 1.0 when the move
            wins the game, -1.0 when the chosen square was already occupied, and 0
            otherwise. `state` is the environment's internal array, not a copy.

        Raises:
            IndexError: If `action` is outside [0, 80). Negative actions used to
                wrap around silently through negative indexing.
        """
        if self._episode_ended:
            # reset() is evaluated first, so the flag returned here is already False.
            return self.reset(), 0, self._episode_ended

        if not 0 <= action < self._num_actions:
            raise IndexError(
                "action {} is out of range [0, {})".format(action, self._num_actions))

        # parse action to (board, row, col)
        board, square = divmod(action, 4 * 4)
        row, col = divmod(square, 4)
        if verbose:
            print("player {} action {}, board {}, row {}, col {}".format(self._current_player, action, board, row, col))

        occupied_penalty = 0
        if self._state[board, row, col] != 0:
            # Occupied square: the state is unchanged and the turn is lost with a penalty.
            occupied_penalty = - 1.0
            if verbose:
                print("action {} occupied already, reward {}".format(action, occupied_penalty))
        elif np.random.rand() < 0.5:  # action accepted
            self._state[board, row, col] = self._current_player
            if verbose:
                print("action {} accepted directly".format(action))
        else:  # randomly select a nearby square
            drow, dcol = random.choice(self.nearbys)
            nboard, nrow, ncol = self._shift_target(board, row + drow, col + dcol)
            if verbose:
                print("action {} shift from (b{}, r{} c{}) to (b{}, r{}, c{}), d=({}, {})".format(action, board, row, col, nboard, nrow, ncol, drow, dcol))

            on_board = 0 <= nrow <= 3 and 0 <= ncol <= 3
            if on_board and self._state[nboard, nrow, ncol] == 0:  # empty legal square
                self._state[nboard, nrow, ncol] = self._current_player
                if verbose:
                    print("action {} accepted".format(action))
            elif verbose:
                # Shifted move is forfeited: nothing is placed.
                print("action {} denied".format(action))

        if self._check_win(self._current_player):
            reward = 1.0
            self._episode_ended = True
        elif np.all(self._state != 0):
            reward = 0
            self._episode_ended = True
        else:
            # Only non-terminal moves hand the turn to the other player.
            self._current_player = -self._current_player
            reward = occupied_penalty
        if verbose:
            print("reward {}, episode_ended {}".format(reward, self._episode_ended))
            print("boards:")
            self._draw_boards()
        return self._state, reward, self._episode_ended

    def _shift_target(self, board, nrow, ncol):
        """
        Resolves a shifted square that may have crossed a board edge.

        Only edges shared between the center board and an outer board are
        crossable, and only when the other coordinate stays within 0..3.
        Anything else is returned unchanged (possibly still off the board) so
        that the caller can forfeit the move.

        Returns:
            tuple: (board, row, col) of the resolved square.
        """
        row_ok = 0 <= nrow <= 3
        col_ok = 0 <= ncol <= 3
        if board == 0 and col_ok and nrow > 3:  # up board goes to center board
            return 2, nrow - 4, ncol
        if board == 1 and row_ok and ncol > 3:  # left board goes to center board
            return 2, nrow, ncol - 4
        if board == 3 and row_ok and ncol < 0:  # right board goes to center board
            return 2, nrow, ncol + 4
        if board == 4 and col_ok and nrow < 0:  # down board goes to center board
            return 2, nrow + 4, ncol
        if board == 2:  # center board goes to nearby board
            if col_ok and nrow < 0:  # up
                return 0, nrow + 4, ncol
            if col_ok and nrow > 3:  # down
                return 4, nrow - 4, ncol
            if row_ok and ncol < 0:  # left
                return 1, nrow, ncol + 4
            if row_ok and ncol > 3:  # right
                return 3, nrow, ncol - 4
        return board, nrow, ncol

    def _full_boards(self):
        """
        Pads the boards to create a full 12x12 grid.

        The five boards are placed in a plus layout; the four corner blocks are
        filled with zeros.

        Returns:
            np.ndarray: The padded boards, shape (12, 12), dtype int64.
        """
        blank = np.zeros((4, 4), dtype=self._state.dtype)
        up, left, center, right, down = self._state
        return np.block([
            [blank, up, blank],
            [left, center, right],
            [blank, down, blank],
        ]).astype(np.int64)

    def _draw_boards(self):
        """
        Prints the current state of the game boards in a human-readable format.

        The four corner blocks of the 12x12 grid are printed as blanks.
        """
        boards = self._full_boards()
        for row_idx in range(12):
            row_in_outer_band = row_idx <= 3 or row_idx >= 8
            cells = []
            for col_idx in range(12):
                if row_in_outer_band and (col_idx <= 3 or col_idx >= 8):
                    cells.append("   ")
                else:
                    cells.append("{:3d}".format(boards[row_idx][col_idx]))
            print("".join(cells))

    def _check_win(self, player):
        """
        Checks if the specified player has won the game.

        Args:
            player (int): The player to check (1 or -1).

        Returns:
            bool: True if the player has 4 in a row or column, or 5 on a
            diagonal, anywhere on the padded 12x12 grid; False otherwise.
        """
        owned = self._full_boards() == player
        size = owned.shape[0]

        # 4 in a row / 4 in a column: slide a width-4 window along each axis.
        for start in range(size - 3):
            stop = start + 4
            if owned[:, start:stop].all(axis=1).any():
                return True
            if owned[start:stop, :].all(axis=0).any():
                return True

        # 5 on a diagonal: test both diagonals of every 5x5 window.
        for top in range(size - 4):
            for left in range(size - 4):
                window = owned[top:top + 5, left:left + 5]
                if window.diagonal().all():  # upleft to downright
                    return True
                if np.fliplr(window).diagonal().all():  # upright to downleft
                    return True
        return False

# Module-level environment instance; the test cells below operate on it.
env = SuperTicTacToeEnv()

In [4]:
# Smoke-test step(): start from a clean board and play 20 uniformly random
# actions with verbose tracing so every move and board can be inspected.
env.reset()
for i in range(20):
    x = random.randrange(5 * 4 * 4)
    env.step(x, verbose=True)

player 1 action 10, board 0, row 2, col 2
action 10 shift from (b0, r2 c2) to (b0, r3, c2), d=(1, 0)
action 10 accepted
reward 0, episode_ended False
boards:
              0  0  0  0            
              0  0  0  0            
              0  0  0  0            
              0  0  1  0            
  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0
              0  0  0  0            
              0  0  0  0            
              0  0  0  0            
              0  0  0  0            
player -1 action 22, board 1, row 1, col 2
action 22 accepted directly
reward 0, episode_ended False
boards:
              0  0  0  0            
              0  0  0  0            
              0  0  0  0            
              0  0  1  0            
  0  0  0  0  0  0  0  0  0  0  0  0
  0  0 -1  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0

In [5]:
# Exercise check_win() on hand-built positions.
# Every print shows (win for player -1, win for player 1); the value recorded in
# the cell output is noted next to each call.

# empty boards: nobody has won
env.reset()
print(env._check_win(-1), env._check_win(1))  # False False

# row of four inside a single board (down board, top row)
env.reset()
env._state[4, 0, :] = 1
print(env._check_win(-1), env._check_win(1))  # False True
env._state[4, 0, 3] = -1
env._state[4, 1, 3] = 1
print(env._check_win(-1), env._check_win(1))  # False False

# row of four spanning two boards (center -> right)
env.reset()
env._state[2, 2, 2:] = -1
env._state[3, 2, :2] = -1
print(env._check_win(-1), env._check_win(1))  # True False
env._state[3, 2, 1] = 1
env._state[3, 2, 2] = -1
print(env._check_win(-1), env._check_win(1))  # False False

# column of four inside a single board (up board, last column)
env.reset()
env._state[0, :, 3] = -1
print(env._check_win(-1), env._check_win(1))  # True False
env._state[0, 3, 3] = 1
env._state[0, 3, 2] = -1
print(env._check_win(-1), env._check_win(1))  # False False

# column of four spanning two boards (up -> center)
env.reset()
env._state[0, 3, 1] = 1
env._state[2, :3, 1] = 1
print(env._check_win(-1), env._check_win(1))  # False True
env._state[2, 2, 1] = -1
env._state[0, 2, 1] = 1
print(env._check_win(-1), env._check_win(1))  # False True
env._state[2, 1, 1] = -1
print(env._check_win(-1), env._check_win(1))  # False False

# diagonals of five across boards
env.reset()
env._state[0, 2, 2] = 1
env._state[0, 3, 1] = 1
env._state[2, 0, 0] = 1
env._state[1, 1, 3] = 1
env._state[1, 2, 2] = 1
env._state[0, 2, 0] = -1
env._state[2, 0, 2] = -1
env._state[2, 1, 3] = -1
env._state[3, 2, 0] = -1
print(env._check_win(-1), env._check_win(1))  # False True
# flipping the shared square completes player -1's diagonal instead
env._state[0, 3, 1] = -1
print(env._check_win(-1), env._check_win(1))  # True False


False False
False True
False False
True False
False False
True False
False False
False True
False True
False False
False True
True False


In [None]:
class DQN(tf.keras.Model):
    """
    Deep Q-Network (DQN) model for reinforcement learning.

    The model predicts Q-values for each possible action given the current state.
    In this notebook the input is the flattened board followed by the flag of
    the player to move.

    Attributes:
        num_actions (int): The number of possible actions.
        flatten (tf.keras.layers.Layer): Layer to flatten the input.
        dense1, dense2, dense3, dense4 (tf.keras.layers.Layer): Dense layers for the network.
    """
    def __init__(self, num_actions):
        """
        Initializes the DQN model with the specified number of actions.

        Args:
            num_actions (int): The number of possible actions.
        """
        super().__init__()
        self.num_actions = num_actions
        self.flatten = tf.keras.layers.Flatten()
        self.dense1 = tf.keras.layers.Dense(256, activation='relu', kernel_initializer=GlorotNormal(), bias_initializer=Zeros())
        self.dense2 = tf.keras.layers.Dense(128, activation='relu', kernel_initializer=GlorotNormal(), bias_initializer=Zeros())
        self.dense3 = tf.keras.layers.Dense(128, activation='relu', kernel_initializer=GlorotNormal(), bias_initializer=Zeros())
        self.dense4 = tf.keras.layers.Dense(num_actions, activation='linear', kernel_initializer=GlorotNormal(), bias_initializer=Zeros())

    def call(self, x):
        """
        Performs a forward pass through the network.

        Args:
            x (tf.Tensor): The input tensor.

        Returns:
            tf.Tensor: The output Q-values for each action.
        """
        x = self.flatten(x)
        x = self.dense1(x)
        x = self.dense2(x)
        x = self.dense3(x)
        return self.dense4(x)

    def epsilon_greedy(self, state, curr_player, epsilon=0.1):
        """
        Selects an action using the epsilon-greedy strategy.

        With probability `epsilon` a uniformly random action is returned (it may
        point at an occupied square). Otherwise this model's own Q-values are
        evaluated, occupied squares are masked out with a large negative
        penalty, and ties between the best remaining actions are broken at random.

        Args:
            state (np.ndarray): The current state of the environment; its
                flattened size must equal `num_actions`.
            curr_player (int): The current player (1 or -1); appended to the
                flattened state as the last input feature.
            epsilon (float): The probability of selecting a random action.

        Returns:
            int: The selected action.
        """
        if np.random.rand() < epsilon:
            return np.random.choice(self.num_actions)

        board_features = np.asarray(state).reshape(1, -1)
        player_feature = np.array(curr_player).reshape(1, 1)
        model_input = np.concatenate([board_features, player_feature], axis=1).astype(np.float32)

        # Use this instance and the caller-supplied player rather than the
        # notebook-level `dqn` / `env` globals.
        q_values = self(model_input).numpy()[0]

        # penalty for occupied actions
        q_values = q_values + np.where(board_features[0] == 0, 0, -1e6)

        # random select when there is a tie
        best_actions = np.flatnonzero(q_values == q_values.max())
        return np.random.choice(best_actions)


In [None]:
class ReplayBuffer:
    """
    Fixed-capacity ring buffer of experience tuples for replay.

    New experiences are appended until `capacity` is reached; after that the
    oldest slot is overwritten.

    Attributes:
        capacity (int): The maximum number of experiences to store.
        buffer (list): The list of stored experiences.
        position (int): The slot the next experience will be written to.
    """
    def __init__(self, capacity):
        """
        Creates an empty buffer.

        Args:
            capacity (int): The maximum number of experiences to store.
        """
        self.buffer = []
        self.position = 0
        self.capacity = capacity

    def push(self, state, action, reward, next_state, done, curr_player):
        """
        Stores one experience, overwriting the oldest one when full.

        Args:
            state (np.ndarray): The current state.
            action (int): The action taken.
            reward (float): The reward received.
            next_state (np.ndarray): The next state.
            done (bool): Whether the episode has ended.
            curr_player (int): The current player.
        """
        experience = (state, action, reward, next_state, done, curr_player)
        still_growing = len(self.buffer) < self.capacity
        if still_growing:
            # Reserve one more slot; it is filled right below.
            self.buffer.append(None)
        slot = self.position
        self.buffer[slot] = experience
        self.position = (slot + 1) % self.capacity

    def sample(self, batch_size):
        """
        Draws `batch_size` distinct experiences uniformly at random.

        Args:
            batch_size (int): The number of experiences to sample.

        Returns:
            tuple: A tuple containing arrays of states, actions, rewards, next states, dones, and current players.
        """
        drawn = random.sample(self.buffer, batch_size)
        # Transpose the list of experiences into one sequence per field.
        states, actions, rewards, next_states, dones, players = zip(*drawn)
        fields = (states, actions, rewards, next_states, dones, players)
        return tuple(np.array(field) for field in fields)

    def __len__(self):
        """
        Returns the number of experiences currently stored in the buffer.

        Returns:
            int: The number of stored experiences.
        """
        stored = len(self.buffer)
        return stored

In [None]:
def test_dqn(env, dqn, num_games=10, max_steps=1000):
    """
    Tests the DQN model by playing a specified number of games.

    Args:
        env (SuperTicTacToeEnv): The game environment.
        dqn (DQN): The trained DQN model.
        num_games (int): The number of games to play.
        max_steps (int): The maximum number of steps per game.
    """
    for _ in range(num_games):
        state = env.reset()
        state = np.array([state])
        done = False
        num_steps = 0
        while not done:
            curr_player = env._current_player
            action = dqn.epsilon_greedy(state, curr_player)
            state, reward, done = env.step(action, verbose=True)
            state = np.array([state])
            num_steps += 1
            if num_steps >= max_steps:
                break
        print("Game End")

In [None]:
def train_dqn(env, dqn, target_dqn, optimizer, replay_buffer, gamma=0.99, batch_size=32, num_episodes=1000, verbose=False):
    """
    Trains the DQN model using experience replay.

    The online network plays both sides of every episode with its
    epsilon-greedy policy. Each transition is stored in the replay buffer and,
    once the buffer holds more than `batch_size` transitions, one gradient
    step on the mean-squared TD error is taken per environment step. The
    target network is synchronised with the online network every 10 episodes.

    Args:
        env (SuperTicTacToeEnv): The game environment.
        dqn (DQN): The DQN model to train.
        target_dqn (DQN): The target DQN model for stable training.
        optimizer (tf.keras.optimizers.Optimizer): The optimizer for training.
        replay_buffer (ReplayBuffer): The replay buffer for storing experiences.
        gamma (float): The discount factor for future rewards.
        batch_size (int): The number of experiences to sample per training step.
        num_episodes (int): The number of training episodes.
        verbose (bool): Whether to print detailed training information.
    """
    # Call both networks once so their weights exist before set_weights()/get_weights().
    input_dim = 5 * 4 * 4 + 1  # flattened boards + player flag
    dqn(np.zeros((1, input_dim)))
    target_dqn(np.zeros((1, input_dim)))

    for episode in tqdm.trange(num_episodes):
        state = np.array([env.reset()])
        done = False

        while not done:
            curr_player = env._current_player
            action = dqn.epsilon_greedy(state, curr_player)

            next_state, reward, done = env.step(action, verbose)
            next_state = np.array([next_state])  # copies, so the stored transition is a snapshot

            replay_buffer.push(state[0], action, reward, next_state[0], done, curr_player)
            if verbose:
                print("epoch {}".format(episode), state[0].shape, action, reward, next_state[0].shape, done, curr_player)

            if len(replay_buffer) > batch_size:
                states, actions, rewards, next_states, dones, curr_players = replay_buffer.sample(batch_size)
                player_col = np.asarray(curr_players).reshape(batch_size, 1)
                states_ = np.concatenate(
                    [np.asarray(states).reshape(batch_size, -1), player_col], axis=1).astype(np.float32)
                # NOTE(review): the next state is paired with the same player flag as
                # the current state although env.step() hands the turn to the opponent
                # on non-terminal moves, and the bootstrap term below is added rather
                # than negated (negamax). Left as is because changing it alters the
                # learning target; confirm the intended self-play formulation.
                next_states_ = np.concatenate(
                    [np.asarray(next_states).reshape(batch_size, -1), player_col], axis=1).astype(np.float32)

                # TD target, shape (batch,). It is computed outside the tape because
                # it only involves the target network.
                rewards_t = tf.convert_to_tensor(np.asarray(rewards).astype(np.float32))
                not_done_t = tf.convert_to_tensor(1.0 - np.asarray(dones).astype(np.float32))
                max_next_q_values = tf.reduce_max(target_dqn(next_states_), axis=1)
                target_q_values = rewards_t + not_done_t * gamma * max_next_q_values

                actions_t = tf.convert_to_tensor(np.asarray(actions).astype(np.int32))
                indices = tf.stack([tf.range(batch_size), actions_t], axis=1)

                with tf.GradientTape() as tape:
                    # Q-value of the action actually taken, shape (batch,).
                    current_q_values = tf.gather_nd(dqn(states_), indices)
                    # Both operands are (batch,). Comparing a (batch, 1) target with
                    # (batch,) predictions would broadcast to a (batch, batch) matrix
                    # and match every target against every prediction.
                    loss = tf.reduce_mean(tf.square(target_q_values - current_q_values))

                gradients = tape.gradient(loss, dqn.trainable_variables)
                optimizer.apply_gradients(zip(gradients, dqn.trainable_variables))

            state = next_state

        # update target_dqn periodically
        if episode % 10 == 0:
            target_dqn.set_weights(dqn.get_weights())

In [10]:
# Build a fresh environment, the online and target networks, the optimizer and
# the replay memory, then train for 500 episodes with mini-batches of 64.
env = SuperTicTacToeEnv()
num_actions = 5 * 4 * 4  # 5 boards x 4 rows x 4 columns
dqn, target_dqn = DQN(num_actions), DQN(num_actions)
optimizer = tf.keras.optimizers.legacy.Adam(learning_rate=1e-2)
replay_buffer = ReplayBuffer(capacity=10000)

train_dqn(
    env,
    dqn,
    target_dqn,
    optimizer,
    replay_buffer,
    batch_size=64,
    num_episodes=500,
)

100%|█████████████████████████████████████████| 500/500 [03:34<00:00,  2.33it/s]


In [11]:
# Play one traced game with the trained network, capped at 500 steps.
test_dqn(env, dqn, num_games=1, max_steps=500)

player 1 action 18, board 1, row 0, col 2
action 18 accepted directly
reward 0, episode_ended False
boards:
              0  0  0  0            
              0  0  0  0            
              0  0  0  0            
              0  0  0  0            
  0  0  1  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0
              0  0  0  0            
              0  0  0  0            
              0  0  0  0            
              0  0  0  0            
player -1 action 77, board 4, row 3, col 1
action 77 accepted directly
reward 0, episode_ended False
boards:
              0  0  0  0            
              0  0  0  0            
              0  0  0  0            
              0  0  0  0            
  0  0  1  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0
              0  0  0  0            
      