I was helping at some friends' garage sale, and saw the game Connect 4 lying in a pile of old board games to be sold for a dollar. One of the friends mentioned they 

This game is a great candidate for a "toy" reinforcement learning problem for a few reasons:
* **The state space of the game is discrete and finite**, being simply the $6 \times 7$ rectangular array of tokens. Each slot in the array is either empty (0) or is occupied by a red (1) or black (2) token. Thus, the state space is the set of 42-vectors whose components are either 0, 1, or 2. There are thus at most $3^{42}$ possible states of the board, though most of these will never occur in the normal course of gameplay (for example, the board consisting of all red tokens).
* **The agents know the state entirely**.
* **The dynamics of the game are deterministic and known**: when a token is placed in a column, it falls to the bottom of the column, without fail. 

There is one issue: to apply standard tabular value function methods, each agent would need to maintain a value for every of the $3^{42}$ states. This is an astronomically large number. We will need to use some value function approximation in order to make this problem tractable.

In [6]:
import torch
import numpy as np
from typing import Tuple, Any
from copy import deepcopy

In [7]:
class C4Grid:
    def __init__(self):
        self._grid_size = torch.Size([6,7])
        self._nwin = 4
        self._grid_state = torch.zeros(self._grid_size, dtype=torch.int8)
        self._player_chars = { 0: ' ', 1: 'O', 2: 'X' }
        self._complete_flag = False


    def restart(self) -> None:
        print('Restarting board!')
        self._complete_flag = False
        self._grid_state = torch.zeros(self._grid_size, dtype=torch.int8)


    def make_move(self, player: int, move_column: int) -> Tuple[int, Any, bool]:
        """Make a move.
        
            Returns:
                int: reward for the move.
                Any: state of the board after the move.
                bool: True if game is won by this move.
        """
        if self._complete_flag:
            raise GameCompleteError('Game is complete!')

        if not self.__is_valid_move(move_column):
            # move is invalid: 
            # raise ValueError("Move invalid!")
            return -1_000, self._emit_state(), False
        else:
            # make the move
            mr, _ = self.__drop_token(p=player, mc=move_column)
            # check if game is won
            if self._is_game_won(mr=mr, mc=move_column):
                self._complete_flag = True
                return 100, self._emit_state(), True
            elif self._is_game_draw():
                self._complete_flag = True
                return 0, self._emit_state(), True
        return 0, self._emit_state(), False


    def __is_valid_move(self, mc: int) -> bool:
        """Verify if move is valid, given the state of the board."""
        if not isinstance(mc, int) or mc < 0 or mc >= self._grid_size[1]:
            return False
        if not self._grid_state[0, mc] == 0:
            return False
        return True


    def __drop_token(self, p: int, mc: int) -> Tuple[int, int]:
        """Drop a token into a valid column.
        
            Starts at bottom row and moves up the rows, checking.
            Assumes that the column is not empty.

        Returns:
            (row, col) where the move placed a token
        """
        i = self._grid_size[0] - 1
        while (self._grid_state[i, mc] != 0):
            i -= 1
        self._grid_state[i, mc] = p
        return (i, mc)


    def print_board(self) -> None:
        """Print a representation of the board to the console."""
        # print(self._grid_state)
        nrows, ncols = self._grid_size
        print('----' * ncols + '-')
        for r in range(nrows):
                self.__print_row_moves(r)
                self.__print_row_lines()
    

    def __print_row_moves(self, r: int) -> None:
        row_chars = [
            self._player_chars[self._grid_state[r,c].item()]
            for c in range(self._grid_size[1])
            ]
        row_str = ' | '.join(row_chars)
        print('| ' + row_str + ' |')


    def __print_row_lines(self) -> None:
        print('-' * (4 * self._grid_size[1] + 1))

    @property
    def complete(self):
        return self._complete_flag


    def _is_game_won(self, mr: int, mc: int) -> Tuple[bool, int]:
        """Check if game is won after player made move in move column mc.
        
            Since we are checking after a particular player's move, we need
            only check in the vicinity of that move, i.e., we can simply
            check if that move was a winning move, rather than check the
            entirety of the board.
        """
        return True if any([
            self.__check_row(mr, mc),
            self.__check_col(mr, mc),
            self.__check_diags(mr, mc),
            ]) else False


    def _is_game_draw(self):
        """Check if the game is a draw."""
        return not (self._grid_state == 0).any()


    def __check_row(self, mr: int, mc: int) -> bool:
        ptok = self._grid_state[mr, mc]  # player token
        cmin = max(0, mc - (self._nwin - 1))
        while cmin + self._nwin <= self._grid_size[1]:
            if all(self._grid_state[mr, cmin:cmin+self._nwin] == ptok):
                return True
            cmin += 1
        return False


    def __check_col(self, mr: int, mc: int) -> bool:
        ptok = self._grid_state[mr, mc]  # player token
        rmin = max(0, mr - 3)
        while rmin + self._nwin <= self._grid_size[0]:
            if all(self._grid_state[rmin:rmin+self._nwin, mc] == ptok):
                return True
            rmin += 1
        return False

    
    def __check_downward_diag(self, gs: torch.tensor, mr: int, mc: int, ptok: int):
        # downward diagonal
        i = 0
        rmin = mr
        cmin = mc
        # find upper-leftmost point
        while i < 3 and rmin - i > 0 and cmin - i > 0:
            i += 1
        rmin = rmin - i
        cmin = cmin - i
        
        while (
            rmin + self._nwin <= self._grid_size[0] and
            cmin + self._nwin <= self._grid_size[1]
            ):
            if all(
                gs[
                    rmin:rmin+self._nwin,
                    cmin:cmin+self._nwin
                ].diag() == ptok
                ):
                return True
            rmin += 1
            cmin += 1
        return False

    def __check_diags(self, mr: int, mc: int) -> bool:
        """Check both diagonals."""
        ptok = self._grid_state[mr, mc]  # player token

        return (
            self.__check_downward_diag(
                gs=self._grid_state,
                mr=mr,
                mc=mc,
                ptok=ptok,
                )
            or
            self.__check_downward_diag(
                gs=torch.flip(self._grid_state, dims=[1]),
                mr=self._grid_size[0] - mr,
                mc=self._grid_size[0] - mc,
                ptok=ptok,
                )
        )

    def _emit_state(self):
        """Emit the game state. To be consumed by the agent(s)."""
        return deepcopy(self._grid_state)


class GameCompleteError(Exception):
    pass


In [8]:
game = C4Grid()

game.make_move(1, 0)
game.make_move(1, 0)
game.make_move(1, 0)
game.make_move(1, 0)

game.print_board()

game.complete

-----------------------------
|   |   |   |   |   |   |   |
-----------------------------
|   |   |   |   |   |   |   |
-----------------------------
| O |   |   |   |   |   |   |
-----------------------------
| O |   |   |   |   |   |   |
-----------------------------
| O |   |   |   |   |   |   |
-----------------------------
| O |   |   |   |   |   |   |
-----------------------------


True

In [9]:
game = C4Grid()

game.make_move(1, 0)
game.make_move(2, 0)
game.make_move(2, 0)
game.make_move(2, 0)
game.make_move(1, 0)

game.make_move(2, 1)
game.make_move(2, 1)
game.make_move(2, 1)
game.make_move(1, 1)

game.make_move(2, 2)
game.make_move(2, 2)
game.make_move(1, 2)

game.make_move(1, 3)
game.make_move(1, 3)

game.print_board()

-----------------------------
|   |   |   |   |   |   |   |
-----------------------------
| O |   |   |   |   |   |   |
-----------------------------
| X | O |   |   |   |   |   |
-----------------------------
| X | X | O |   |   |   |   |
-----------------------------
| X | X | X | O |   |   |   |
-----------------------------
| O | X | X | O |   |   |   |
-----------------------------


In [10]:
class C4Agent:
    def __init__(self):
        self._t: int = 0  # time (num moves)
        self._n_actions = 7
        self.reward = 0

    def policy(self, state: Any = None) -> int:
        """Mapping of (enviroment) states to agent's actions."""
        # random policy
        return int(np.random.choice(a=range(self._n_actions)))


agent = C4Agent()

In [11]:

niter = 0
nitermax = 10
game = C4Grid()

player_1 = C4Agent()
player_2 = C4Agent()

state = game._emit_state()
try:
    while niter < nitermax:
        p1_move = player_1.policy(state=state)
        p1_reward, state, _ = game.make_move(1, move_column=p1_move)
        player_1.reward += p1_reward
        print(f'Player 1\tselected move: {p1_move:2d}\treward: {p1_reward:5d} (total: {player_1.reward:5d})')
        
        p2_move = player_2.policy(state=state)
        p2_reward, state, _ = game.make_move(2, move_column=p2_move)
        player_2.reward += p2_reward
        print(f'Player 2\tselected move: {p2_move:2d}\treward: {p2_reward:5d} (total: {player_2.reward:5d})')
    niter += 1
    print('Iterations over, game incomplete!')
except GameCompleteError:
    print('Game is complete!')
    if player_1.reward > player_2.reward:
        print('Player 1 wins!')
    elif player_1.reward < player_2.reward:
        print('Player 2 wins!')
    else:
        print('Game is a draw.')

Player 1	selected move:  3	reward:     0 (total:     0)
Player 2	selected move:  1	reward:     0 (total:     0)
Player 1	selected move:  0	reward:     0 (total:     0)
Player 2	selected move:  4	reward:     0 (total:     0)
Player 1	selected move:  3	reward:     0 (total:     0)
Player 2	selected move:  4	reward:     0 (total:     0)
Player 1	selected move:  3	reward:     0 (total:     0)
Player 2	selected move:  1	reward:     0 (total:     0)
Player 1	selected move:  0	reward:     0 (total:     0)
Player 2	selected move:  4	reward:     0 (total:     0)
Player 1	selected move:  5	reward:     0 (total:     0)
Player 2	selected move:  1	reward:     0 (total:     0)
Player 1	selected move:  2	reward:     0 (total:     0)
Player 2	selected move:  2	reward:     0 (total:     0)
Player 1	selected move:  4	reward:     0 (total:     0)
Player 2	selected move:  6	reward:     0 (total:     0)
Player 1	selected move:  0	reward:     0 (total:     0)
Player 2	selected move:  2	reward:     0 (total:

In [12]:
game.print_board()

-----------------------------
|   |   |   |   |   |   |   |
-----------------------------
|   |   |   |   |   |   |   |
-----------------------------
| O |   |   |   | O |   |   |
-----------------------------
| O | X | X | O | X |   |   |
-----------------------------
| O | X | X | O | X |   |   |
-----------------------------
| O | X | O | O | X | O | X |
-----------------------------


In [13]:
import torch
import torch.nn
import torch.nn.functional

class C4AgentNN:
    def __init__(self):
        self._t: int = 0  # time (num moves)
        self._n_actions = 7
        self.reward = 0

    def policy(self, state: Any = None) -> int:
        """Mapping of (enviroment) states to agent's actions."""
        # random policy
        return int(np.random.choice(a=range(self._n_actions)))

In [31]:
class DQN(torch.nn.Module):
    def __init__(self):
        super(DQN, self).__init__()
        self.lin1 = torch.nn.Linear(in_features=6*7, out_features=6*5)
        self.lin2 = torch.nn.Linear(in_features=6*5, out_features=6*3)
        self.lin3 = torch.nn.Linear(in_features=6*3, out_features=7)

    # Called with either one element to determine next action, or a batch
    # during optimization. Returns tensor([[left0exp,right0exp]...]).
    def forward(self, x):
        # x = x.to(device)
        x = torch.nn.functional.relu(self.lin1(x))
        x = torch.nn.functional.relu(self.lin2(x))
        x = torch.nn.functional.relu(self.lin3(x))
        return x

In [32]:
dqn = DQN()
dqn.forward(game._emit_state().flatten().float())

tensor([0.2020, 0.0000, 0.1648, 0.3150, 0.2519, 0.0999, 0.1664],
       grad_fn=<ReluBackward0>)