# Creating an AI Opponent in Connect4 with Reinforcement Learning

This notebook explores Reinforcement Learning techniques that can be used to train an AI to play Connect4.

First, let's import the packages necessary to complete this task.

In [1]:
#Import packages
import pygame
import sys
import numpy as np
import random
import os

pygame-ce 2.4.1 (SDL 2.28.4, Python 3.12.7)


Next, we'll define the constants the game needs to run: the board dimensions, the player markers, and the Pygame display sizes and colors.

In [2]:
#Create constants for the game
# Board geometry and cell markers
ROW_COUNT, COLUMN_COUNT = 6, 7
PLAYER1, PLAYER2 = 1, 2
EMPTY = 0

#Pygame display constants
SQUARESIZE = 100                    # pixel edge length of one board cell
RADIUS = int(SQUARESIZE / 2 - 5)    # disc radius, leaving a 5px margin inside each cell
# One extra row on top is used for the hovering disc and the result banner.
width, height = COLUMN_COUNT * SQUARESIZE, (ROW_COUNT + 1) * SQUARESIZE
size = (width, height)

# RGB colours
BLUE, BLACK, RED, YELLOW = (0, 0, 255), (0, 0, 0), (255, 0, 0), (255, 255, 0)

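Then we'll define the Pygame functions: `draw_board` renders the board, and `play_gui` runs an interactive game in which a human (red) plays against the AI (yellow).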

In [3]:
#Pygame Functions
def draw_board(board, screen):
    """Render the whole board onto *screen* and refresh the display.

    The grid sits below the top banner row. Board row 0 is drawn at the
    bottom of the window, so pieces dropped into the lowest row index appear
    at the bottom of their column.
    """
    half = SQUARESIZE / 2
    disc_colour = {PLAYER1: RED, PLAYER2: YELLOW}

    # Pass 1: blue frame with a black hole in every cell.
    for col in range(COLUMN_COUNT):
        left = col * SQUARESIZE
        centre_x = int(left + half)
        for row in range(ROW_COUNT):
            top = (row + 1) * SQUARESIZE
            pygame.draw.rect(screen, BLUE, (left, top, SQUARESIZE, SQUARESIZE))
            pygame.draw.circle(screen, BLACK, (centre_x, int(top + half)), RADIUS)

    # Pass 2: coloured discs, flipped vertically so row 0 is at the bottom.
    for col in range(COLUMN_COUNT):
        centre_x = int(col * SQUARESIZE + half)
        for row in range(ROW_COUNT):
            colour = disc_colour.get(board[row][col])
            if colour is not None:
                centre_y = height - int((row + 1) * SQUARESIZE - half)
                pygame.draw.circle(screen, colour, (centre_x, centre_y), RADIUS)

    pygame.display.update()

def _show_game_over(screen, font, text, color):
    """Draw *text* in the top banner, refresh, and hold it for 3 seconds."""
    label = font.render(text, 1, color)
    screen.blit(label, (40, 10))
    pygame.display.update()
    pygame.time.wait(3000)


def play_gui(agent, env):
    """Play one interactive game of Connect 4: human (PLAYER1) vs *agent*.

    The human moves by clicking a column; while it is the human's turn a red
    disc follows the mouse in the top banner. After each human move,
    ``agent.choose_action(env)`` picks a column that is played as PLAYER2.
    The result is shown for 3 seconds before the window closes.

    Args:
        agent: Object exposing ``choose_action(env) -> int``. It always plays
            PLAYER2 here, regardless of its own ``player`` attribute.
        env: A ``Connect4Env``; it is reset before the first move.

    Closing the window calls ``sys.exit()``. Pygame is shut down in a
    ``finally`` clause so the window is released on that path (e.g. when the
    resulting SystemExit is caught by Jupyter) and on any exception, not only
    when the game finishes normally.
    """
    pygame.init()
    try:
        screen = pygame.display.set_mode(size)
        font = pygame.font.SysFont("monospace", 75)
        pygame.display.set_caption("Connect 4 - Play vs AI")

        state = env.reset()
        draw_board(state, screen)  # draw_board refreshes the display itself
        turn = PLAYER1
        game_over = False

        while not game_over:
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    sys.exit()  # SystemExit still runs the finally clause below

                #Have coin follow mouse
                if event.type == pygame.MOUSEMOTION:
                    pygame.draw.rect(screen, BLACK, (0, 0, width, SQUARESIZE))
                    posx = event.pos[0]
                    if turn == PLAYER1:
                        pygame.draw.circle(screen, RED, (posx, int(SQUARESIZE / 2)), RADIUS)
                    pygame.display.update()

                #Put coin where clicked
                if turn == PLAYER1 and event.type == pygame.MOUSEBUTTONDOWN:
                    col = int(event.pos[0] / SQUARESIZE)
                    if col not in env.valid_actions():
                        continue  # full column: ignore the click

                    state, _, done = env.step(col, PLAYER1)
                    draw_board(state, screen)

                    if done:
                        text = "You win!" if env.winner == PLAYER1 else "Draw!"
                        _show_game_over(screen, font, text, RED)
                        game_over = True
                        break

                    turn = PLAYER2

            if turn == PLAYER2 and not game_over:
                pygame.time.wait(500)  # Delay for realism
                col = agent.choose_action(env)
                state, _, done = env.step(col, PLAYER2)
                draw_board(state, screen)

                if done:
                    text = "AI wins!" if env.winner == PLAYER2 else "Draw!"
                    _show_game_over(screen, font, text, YELLOW)
                    game_over = True

                turn = PLAYER1
    finally:
        pygame.quit()

## The Environment

The environment for this project is the classic game of Connect 4.

- **State:** The state of the environment is represented by a 2D NumPy array (6 rows, 7 columns) where each element indicates whether a cell is empty (0), occupied by Player 1 (1), or occupied by Player 2 (2).
- **Actions:** The available actions are the columns in which a player can drop their piece. A column is a valid action if the topmost cell in that column is empty.
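- **Reward:** A move that wins the game earns a reward of 1. Every other legal move earns 0, including a move that fills the board for a draw. An illegal move (a full or out-of-range column) earns -10 and ends the episode. During training, the agent that lost is updated with the negated reward.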
- **Completed:** The episode is done when a player wins or when the board is full (a draw).

In [4]:
#Connect4 Environment
class Connect4Env:
    """Connect 4 board with a minimal RL-style ``reset``/``step`` interface.

    ``board`` is a (ROW_COUNT, COLUMN_COUNT) int array holding EMPTY,
    PLAYER1 or PLAYER2. Pieces fill the lowest empty row index of a column,
    so row 0 is the bottom row.

    Attributes:
        board: current board array.
        done: True once the episode has ended (win, draw, or illegal move).
        winner: the winning piece, 0 for a draw, or None while the game is
            running (and after an illegal move).
    """

    def __init__(self):
        self.board = np.zeros((ROW_COUNT, COLUMN_COUNT), dtype=int)
        self.done = False
        self.winner = None

    def reset(self):
        """Clear the board in place and return a copy of the empty state."""
        self.board[:] = 0
        self.done = False
        self.winner = None
        return self.get_state()

    def get_state(self):
        """Return a copy of the board that callers may keep or mutate."""
        return self.board.copy()

    def valid_actions(self):
        """Return the columns whose top cell is still empty."""
        return [c for c in range(COLUMN_COUNT) if self.board[ROW_COUNT-1][c] == EMPTY]

    def drop_piece(self, col, piece):
        """Place *piece* in the lowest empty cell of column *col*.

        Returns the row index used, or None if the column is full. This does
        not check for wins or update ``done``/``winner``.
        """
        for r in range(ROW_COUNT):
            if self.board[r][col] == EMPTY:
                self.board[r][col] = piece
                return r
        return None

    def winning_move(self, piece):
        """Return True if *piece* has four in a row anywhere on the board."""
        # Horizontal
        for c in range(COLUMN_COUNT - 3):
            for r in range(ROW_COUNT):
                if np.all(self.board[r, c:c+4] == piece):
                    return True
        # Vertical
        for c in range(COLUMN_COUNT):
            for r in range(ROW_COUNT - 3):
                if np.all(self.board[r:r+4, c] == piece):
                    return True
        # Diagonal with increasing row and column
        for c in range(COLUMN_COUNT - 3):
            for r in range(ROW_COUNT - 3):
                if all(self.board[r+i][c+i] == piece for i in range(4)):
                    return True
        # Diagonal with decreasing row, increasing column
        for c in range(COLUMN_COUNT - 3):
            for r in range(3, ROW_COUNT):
                if all(self.board[r-i][c+i] == piece for i in range(4)):
                    return True
        return False

    def step(self, action, piece):
        """Play *piece* in column *action*.

        Returns:
            (state, reward, done): ``state`` is a copy of the board. ``reward``
            is 1 for a winning move, 0 for any other legal move (including
            one that fills the board for a draw), and -10 for an illegal move.

        An illegal move is a full or non-existent column, or any move after
        the game has already ended. It leaves the board untouched and ends
        the episode. ``self.done`` is set so the environment agrees with the
        returned ``done`` flag.
        """
        if self.done or action not in self.valid_actions():
            self.done = True
            return self.get_state(), -10, True
        self.drop_piece(action, piece)

        reward = 0
        if self.winning_move(piece):
            self.done = True
            self.winner = piece
            reward = 1
        elif not self.valid_actions():
            self.done = True
            self.winner = 0  # board full without a winner: draw

        return self.get_state(), reward, self.done

## The Agent

The agent implemented is a TD(lambda) agent, a type of reinforcement learning agent that uses temporal difference learning with eligibility traces.

- **TD(lambda):** This method updates the agent's value function based on the difference between the predicted value of a state and the actual outcome (reward plus discounted value of the next state), considering a trace of recent states visited. This helps to attribute credit for the outcome to earlier states that contributed to it.
- **Player:** The agent is initialized with a player number (PLAYER1 or PLAYER2) to distinguish its pieces from the opponent's.
- **Weights (w):** The agent learns a set of weights that represent the importance of different features in the state for determining its value. These weights are updated during training.
- **Eligibility Traces (e):** These traces keep track of which features were active and how recently they were active, allowing for credit assignment over multiple time steps.

## The Policy

The agent uses a combination of strategies for choosing an action:

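- **Epsilon-Greedy (during training):** The agent explores the environment by choosing a random valid action with a probability of epsilon. Otherwise, it chooses the action that is expected to lead to the state with the highest estimated value. Epsilon decays over training episodes to shift from exploration to exploitation. **Note:** in the current code, `choose_action` never reads `self.epsilon`. `set_epsilon` stores the value, but action selection always follows the prioritized/greedy policy described below, so this random exploration does not actually occur.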
- **Prioritized Actions (during play):** The agent first checks for immediate winning moves and blocks the opponent's winning moves. It also checks for and blocks "double-sided two-in-a-row" threats. If none of these immediate threats or opportunities exist, it selects the action that leads to the highest predicted state value.
- **Middle Column Preference:** Among actions that yield the same highest predicted value, the agent prioritizes dropping a piece in the middle column, which is a common strategy in Connect 4.

The update method used is the TD(lambda) update rule, which modifies the weights based on the temporal difference error and the eligibility traces.

In [5]:
#Agent Class
class TDLambdaAgent:
    """Linear TD(lambda) agent with value V(s) = w . enhanced_features(s, player).

    Args:
        player: piece this agent plays (PLAYER1 or PLAYER2); features are
            computed from this player's point of view.
        alpha: learning rate.
        gamma: discount factor.
        lam: eligibility-trace decay (lambda).
        epsilon: exploration rate kept for ``set_epsilon``. NOTE: it is not
            read by ``choose_action``, which is always rule-based/greedy.
    """

    def __init__(self, player, alpha=0.03, gamma=0.9, lam=0.8, epsilon=1.0):
        self.player = player
        self.alpha = alpha
        self.gamma = gamma
        self.lam = lam
        self.epsilon = epsilon
        # 87 = 42 own-piece cells + 42 opponent cells + centre control + 2 threat counts
        self.w = np.zeros(87)
        self.e = np.zeros_like(self.w)  # eligibility traces, one per weight

    def reset_traces(self):
        """Zero the eligibility traces (they must not carry across episodes)."""
        self.e = np.zeros_like(self.w)

    def set_epsilon(self, eps):
        """Store the exploration rate (not consulted by ``choose_action``)."""
        self.epsilon = eps

    def save_weights(self, path):
        """Save the weight vector with ``np.save``."""
        np.save(path, self.w)

    def load_weights(self, path):
        """Load a weight vector saved by ``save_weights`` and clear the traces.

        Traces are reset so they match the loaded vector's shape and carry no
        credit from before the load.
        """
        self.w = np.load(path)
        self.reset_traces()

    def value(self, state):
        """Return ``(V(state), feature_vector)`` from this agent's point of view."""
        x = enhanced_features(state, self.player)
        return np.dot(self.w, x), x

    @staticmethod
    def _simulate(board, col, piece):
        """Return a scratch Connect4Env holding a copy of *board* with *piece* dropped in *col*."""
        sim = Connect4Env()
        sim.board = board.copy()
        sim.drop_piece(col, piece)
        return sim

    def choose_action(self, env):
        """Pick a column for ``self.player``, in priority order:

        1. a move that wins immediately;
        2. a move that blocks an opponent's immediate win;
        3. a cell where the opponent would form an open-ended two-in-a-row
           (see ``creates_double_sided_two``);
        4. the move whose resulting board has the highest value, preferring
           the middle column on ties and otherwise breaking ties randomly.

        Assumes at least one valid action exists (the game is not over).
        """
        valid = env.valid_actions()
        MIDDLE_COL = 3
        opponent = PLAYER1 if self.player == PLAYER2 else PLAYER2

        #Check if AI can win immediately
        for col in valid:
            if self._simulate(env.board, col, self.player).winning_move(self.player):
                return col  # play winning move

        #Check for immediate block against opponent's 3-in-a-row (winning threat)
        for col in valid:
            if self._simulate(env.board, col, opponent).winning_move(opponent):
                return col  # block winning move

        #Check for "double-sided two-in-a-row" threats from opponent
        for col in valid:
            row = next((r for r in range(ROW_COUNT) if env.board[r][col] == EMPTY), None)
            if row is not None:
                temp_board = env.board.copy()
                temp_board[row][col] = opponent
                if creates_double_sided_two(temp_board, row, col, opponent):
                    return col

        #Score each move by the value of the board it produces
        values = [self.value(self._simulate(env.board, a, self.player).board)[0] for a in valid]
        max_val = max(values)
        max_actions = [a for a, v in zip(valid, values) if v == max_val]

        #Prioritize middle column if it's one of the best options
        if MIDDLE_COL in max_actions:
            return MIDDLE_COL

        return random.choice(max_actions)

    def update(self, state, reward, next_state, done):
        """Apply one TD(lambda) update for the transition state -> next_state.

        delta = reward + gamma * V(next_state) - V(state), with V(next_state)
        taken as 0 when ``done``. Traces accumulate the features of ``state``
        (e <- gamma * lambda * e + x(state)) and the weights move by
        alpha * delta * e. After a terminal update the traces are cleared so
        credit does not leak into the next episode.
        """
        v_s, x_s = self.value(state)
        v_s_next = 0.0 if done else self.value(next_state)[0]
        delta = reward + self.gamma * v_s_next - v_s
        self.e = self.gamma * self.lam * self.e + x_s
        self.w += self.alpha * delta * self.e
        if done:
            self.reset_traces()

## Value Function

The agent uses a linear value function approximator:

$V(s) = w^T \phi(s)$

where:
- $V(s)$ is the estimated value of state $s$.
- $w$ is the vector of learned weights.
- $\phi(s)$ is the feature vector extracted from state $s$.

The `enhanced_features` function extracts several features from the game board to represent the state:

- **Player and Opponent Piece Positions:** Binary features indicating the presence of the agent's or opponent's pieces in each cell.
- **Center Column Control:** A feature representing the proportion of the center column occupied by the agent's pieces.
- **Threat Potential:** Two features, one for the agent and one for the opponent, that count the windows of four cells (potential winning lines) containing at least one of that player's pieces and no opponent pieces. Each count is divided by 50.

## Parameters

- **alpha ($\alpha$):** The learning rate (0.03). This parameter determines how much the weights are adjusted with each update. A smaller value leads to slower but potentially more stable learning.
- **gamma ($\gamma$):** The discount factor (0.9). This parameter determines the importance of future rewards. A value closer to 1 means the agent considers future rewards more heavily.
- **lambda ($\lambda$):** The eligibility trace decay rate (0.8). This parameter controls how quickly the eligibility traces decay. A higher value means that credit is given to more distant past states.
- **epsilon ($\epsilon$):** The exploration rate (starts at 1.0 and decays to 0.05). This parameter determines the probability of choosing a random action during training. It is crucial for exploring the state space and finding optimal strategies. The decay schedule ensures that the agent exploits its learned knowledge more as training progresses.
- **Minimax Depth:** The `minimax_move` function uses a depth of 3 for evaluating moves when playing against the minimax agent during training or evaluation. This determines how many steps ahead the minimax algorithm looks.
- **Training Episodes:** The number of episodes the agent trains for. `train_td_lambda` defaults to 100,000; the training cell below runs 10,000. More episodes generally lead to better performance, but also require more computational time.
- **Evaluation Games:** The number of games played against the minimax agent during evaluation (100). This helps to get a more reliable estimate of the agent's win rate.

These parameters were chosen based on common values used in reinforcement learning for similar problems, then fine-tuned to where they are now. The epsilon decay schedule is designed to allow for sufficient exploration early in training and transition to exploitation later.

In [6]:
#Training Functions
def enhanced_features(board, player):
    """Build the feature vector used by the linear value function.

    Layout, all from *player*'s point of view:
      * ROW_COUNT * COLUMN_COUNT binary flags for cells held by *player*
        (row-major order),
      * the same flags for cells held by the opponent,
      * the fraction of the centre column held by *player*,
      * ``count_threats`` for *player*, then for the opponent, each / 50.
    """
    opponent = PLAYER1 if player == PLAYER2 else PLAYER2

    own_cells = (board == player).astype(int).ravel()
    opp_cells = (board == opponent).astype(int).ravel()
    centre_share = np.sum(board[:, COLUMN_COUNT // 2] == player) / ROW_COUNT
    own_threats = count_threats(board, player) / 50.0
    opp_threats = count_threats(board, opponent) / 50.0

    return np.array([*own_cells, *opp_cells, centre_share, own_threats, opp_threats])


def count_threats(board, piece):
    """Count the 4-cell windows that *piece* could still complete.

    A window counts when it holds at least one *piece* and every other cell
    is EMPTY (see ``is_threat_window``). Horizontal, vertical and both
    diagonal directions are scanned, matching ``winning_move``.

    Every window is passed as a NumPy array because ``is_threat_window``
    compares it element-wise (``window == piece``). A plain Python list
    would be compared as a whole against an int and could never match.
    """
    count = 0

    # Horizontal windows
    for r in range(ROW_COUNT):
        for c in range(COLUMN_COUNT - 3):
            if is_threat_window(board[r, c:c+4], piece):
                count += 1

    # Vertical windows
    for r in range(ROW_COUNT - 3):
        for c in range(COLUMN_COUNT):
            if is_threat_window(board[r:r+4, c], piece):
                count += 1

    # Both diagonal directions
    for r in range(ROW_COUNT - 3):
        for c in range(COLUMN_COUNT - 3):
            diag = np.array([board[r+i][c+i] for i in range(4)])
            anti_diag = np.array([board[r+3-i][c+i] for i in range(4)])
            if is_threat_window(diag, piece):
                count += 1
            if is_threat_window(anti_diag, piece):
                count += 1

    return count

def is_threat_window(window, piece):
    """Return True if *window* holds at least one *piece* and otherwise only EMPTY cells.

    *window* may be a NumPy array or any sequence of cells (e.g. a list). It
    is converted with ``np.asarray`` so the comparisons are element-wise; for
    a plain list, ``window == piece`` would compare the list as a whole and
    the window could never qualify. Any window length is accepted.
    """
    cells = np.asarray(window)
    n_piece = np.count_nonzero(cells == piece)
    return n_piece > 0 and np.count_nonzero(cells == EMPTY) == cells.size - n_piece

def creates_double_sided_two(board, row, col, piece):
    """Return True if the disc at (row, col) is part of an open-ended pair.

    For each of the four line directions, the contiguous run of *piece*
    through (row, col) is measured. The function succeeds when that run has
    length exactly two and the first cell past each end is an in-bounds
    EMPTY cell. The open cells are not required to be playable (supported
    from below).
    """
    def in_bounds(r, c):
        return 0 <= r < ROW_COUNT and 0 <= c < COLUMN_COUNT

    def scan(step_r, step_c):
        """Walk away from (row, col); return (matching cells passed, whether the stop cell is empty)."""
        run = 0
        r, c = row + step_r, col + step_c
        while in_bounds(r, c) and board[r][c] == piece:
            run += 1
            r, c = r + step_r, c + step_c
        return run, in_bounds(r, c) and board[r][c] == EMPTY

    # horizontal, vertical, and the two diagonals
    for dr, dc in ((0, 1), (1, 0), (1, 1), (1, -1)):
        ahead, open_ahead = scan(dr, dc)
        behind, open_behind = scan(-dr, -dc)
        if 1 + ahead + behind == 2 and open_ahead and open_behind:
            return True

    return False

def train_td_lambda(agent1, agent2, env, episodes=100000, *, eval_every=500,
                    checkpoint_path="best_agent2.npy"):
    """Train two TD(lambda) agents against each other and against minimax.

    In even-numbered (0-based) episodes agent2's moves are replaced by
    ``minimax_move``; odd-numbered episodes are pure self-play. Epsilon decays
    linearly from 1.0 to 0.05 over the first 1/20 of the episodes. Every
    ``eval_every`` episodes, agent1 is evaluated against minimax, and its
    weights are saved to ``checkpoint_path`` whenever the win rate improves.

    Returns:
        List of ``env.winner`` values, one per finished episode.
    """
    win_history = []
    best_winrate = 0
    # Reach 0.05 after 1/20th of the episodes; max(1, ...) avoids a
    # ZeroDivisionError when fewer than 20 episodes are requested.
    decay_episodes = max(1, episodes // 20)

    for ep in range(episodes):
        #Epsilon decay
        epsilon = max(0.05, 1.0 - (0.95 * ep / decay_episodes))
        agent1.set_epsilon(epsilon)
        agent2.set_epsilon(epsilon)

        # Eligibility traces belong to a single episode.
        for agent in (agent1, agent2):
            agent.e = np.zeros_like(agent.w)

        state = env.reset()
        done = False
        state_prev = None

        #Alternate between self-play and minimax every other episode
        use_minimax = (ep % 2 == 0)
        current_agent, other_agent = agent1, agent2
        current_player = agent1.player

        while not done:
            if use_minimax and current_agent is agent2:
                action, _ = minimax_move(env, player=current_player)
                if action is None:
                    action = random.choice(env.valid_actions())
            else:
                action = current_agent.choose_action(env)

            next_state, reward, done = env.step(action, current_player)

            if done:
                current_agent.update(state, reward, next_state, done)
                if state_prev is not None:
                    # The other agent is credited with the negated outcome.
                    other_agent.update(state_prev, -reward, next_state, done)
                win_history.append(env.winner)
                break

            if state_prev is not None:
                current_agent.update(state_prev, 0, state, False)

            state_prev = state
            state = next_state

            current_agent, other_agent = other_agent, current_agent
            current_player = PLAYER1 if current_player == PLAYER2 else PLAYER2

        #Evaluation and checkpoint
        if (ep + 1) % eval_every == 0:
            winrate = evaluate_vs_minimax(agent1, env)
            print(f"Episode {ep+1}, Epsilon: {epsilon:.4f}, Winrate: {winrate:.2f}")
            if winrate > best_winrate:
                best_winrate = winrate
                agent1.save_weights(checkpoint_path)

    return win_history

def minimax_move(env, depth=3, maximizing=True, player=PLAYER2):
    """Search *depth* plies ahead and return ``(best_move, score)`` for *player*.

    Scores are from *player*'s point of view:
      * a position where *player* has four in a row scores ``1000 + depth``
        (remaining depth, so quicker wins score higher);
      * a position where the opponent has won scores ``-(1000 + depth)``;
      * any other leaf (depth exhausted or a full board) scores by piece-count
        difference.

    ``best_move`` is None at leaf/terminal positions; callers fall back to a
    random valid move in that case.

    Args:
        env: position to search from. Children are simulated on scratch
            ``Connect4Env`` copies, so *env* itself is not modified.
        depth: plies remaining to search.
        maximizing: True when *player* is to move.
        player: the piece the search optimizes for.
    """
    opponent = PLAYER1 if player == PLAYER2 else PLAYER2
    win_score = 1000

    # Child nodes set ``done`` when the move just made completed four in a row.
    if env.done:
        if env.winning_move(player):
            return None, win_score + depth
        if env.winning_move(opponent):
            return None, -(win_score + depth)

    valid_moves = env.valid_actions()
    if depth == 0 or env.done or not valid_moves:
        # Non-decisive leaf or full board: fall back to material balance.
        return None, np.sum(env.board == player) - np.sum(env.board == opponent)

    mover = player if maximizing else opponent
    best_move = None
    best_eval = -float('inf') if maximizing else float('inf')
    for move in valid_moves:
        child = Connect4Env()
        child.board = env.board.copy()
        child.drop_piece(move, mover)
        child.done = child.winning_move(mover)
        _, child_eval = minimax_move(child, depth - 1, not maximizing, player)
        # Strict comparison keeps the first (lowest-column) move among ties.
        improved = child_eval > best_eval if maximizing else child_eval < best_eval
        if improved:
            best_eval = child_eval
            best_move = move
    return best_move, best_eval

def evaluate_vs_minimax(agent, env, games=100):
    """Return *agent*'s win rate over *games* games against ``minimax_move``.

    PLAYER1 always moves first. The agent plays ``agent.player`` and minimax
    plays the other piece. Draws and losses both count as non-wins. ``env``
    is reset before every game. Returns 0.0 when ``games`` is 0.
    """
    wins = 0
    for _ in range(games):
        env.reset()
        done = False
        player = PLAYER1

        while not done:
            if player == agent.player:
                action = agent.choose_action(env)
            else:
                # ``player`` must be passed by keyword: positionally it would
                # land in ``depth`` and the search would run for the wrong side.
                action, _ = minimax_move(env, player=player)
                if action is None:
                    action = random.choice(env.valid_actions())

            _, _, done = env.step(action, player)
            player = PLAYER1 if player == PLAYER2 else PLAYER2

        if env.winner == agent.player:
            wins += 1

    return wins / games if games else 0.0

Now, we will train the agent using weights from a previous run-through as the starting point.

In [7]:
#Train the Agent
if __name__ == "__main__":
    env = Connect4Env()
    agent1 = TDLambdaAgent(PLAYER2)
    agent2 = TDLambdaAgent(PLAYER1)

    #Load previous weights if a checkpoint exists; otherwise start from zero weights.
    weights_path = "best_agent1.npy"
    if os.path.exists(weights_path):
        agent1.load_weights(weights_path)
        print("Loaded saved weights.")
    else:
        print(f"No saved weights at {weights_path}; starting from scratch.")

    print("Training agents with TD(lambda) learning and self-play...")
    train_td_lambda(agent1, agent2, env, episodes=10000)

Loaded saved weights.
Training agents with TD(lambda) learning and self-play...


This code allows us to play against the agent using a PyGame GUI.

In [8]:
#Play against the Agent
if __name__ == "__main__":
    agent1 = TDLambdaAgent(PLAYER2)
    agent1.load_weights("best_agent2.npy")

    # A fresh environment keeps this cell independent of the training cell
    # having defined ``env``; play_gui resets it before the first move anyway.
    play_gui(agent1, Connect4Env())

## Conclusion

I am pretty happy with the result of this project. I would have preferred to have it train for much longer (100,000 episodes rather than 10,000), but with the help of some functions to reduce the training time, it performs somewhat well. Some of the moves are not ideal and seem random but it can still beat me almost as many times as I can beat it. Next time, I will try to have it train longer and maybe make it unique by having the board rotate 90 degrees every so often.