In [None]:
from copy import deepcopy
import numpy as np
import matplotlib.pyplot as plt
import random

import gym
from gym import spaces

from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import configure

In [None]:
# Board dimensions (rows x columns).
ROWS = 7
COLS = 8

# Canonical colour names, keyed by colour index.
COLORS_ORIGINAL = dict(enumerate(['red', 'yellow', 'green', 'blue', 'purple', 'black']))

# Palette used for drawing: same as above, with hex shades for green and blue.
COLORS = {**COLORS_ORIGINAL, 2: '#a4de02', 3: '#6badce'}

In [None]:
class FillerEnv(gym.Env):
    """Two-player Filler board game exposed through the Gym ``Env`` interface.

    The board is a ``rows x cols`` grid of colour indices (the keys of
    ``COLORS``). Player 1 owns the bottom-left corner and player 2 the
    top-right corner. On each turn the active player picks a colour that
    neither player currently holds; their territory is repainted in that colour
    and every orthogonally adjacent square already showing it is absorbed.

    The episode ends once a player owns more than half of the board or every
    square is owned.

    NOTE(review): ``reset()`` always restores the single board generated in
    ``__init__``, so every episode is played on the same layout.
    NOTE(review): rewards are always expressed from player 1's point of view,
    even for player 2's moves, and the observation does not encode whose turn
    it is -- confirm that this is what training is meant to optimise.
    """

    # Reward returned when the chosen colour is already held by either player.
    INVALID_MOVE_PENALTY = -0.5

    def __init__(self) -> None:
        super().__init__()

        # env params
        self.rows, self.cols = ROWS, COLS
        self.colors = list(COLORS.keys())
        self.original_grid = self.generate_grid()
        self.reset()

        # observation/state space: grid, p1 color, p2 color
        n_colors = len(self.colors)
        self.observation_space = spaces.Dict({
            'grid': spaces.Box(low=0, high=n_colors - 1, shape=(self.rows, self.cols), dtype=np.int32),
            'p1_color': spaces.Discrete(n_colors),
            'p2_color': spaces.Discrete(n_colors),
        })

        # action space: choosing a color
        self.action_space = spaces.Discrete(n_colors)
        self.render()  # render the initial grid

    def reset(self) -> dict:
        """Restore the initial board and return the first observation.

        Player 1 starts in the bottom-left corner, player 2 in the top-right
        corner, and player 1 moves first.
        """
        self.grid = deepcopy(self.original_grid)
        self.p1_color = self.grid[-1][0]
        self.p2_color = self.grid[0][-1]
        # Corner coordinates are derived from the board size instead of being
        # hard-coded, so the environment keeps working if ROWS/COLS change.
        self.squares = {1: {(self.rows - 1, 0)}, 2: {(0, self.cols - 1)}}
        self.turn = 1  # p1 starts
        return self._get_obs()

    def step(self, action: int) -> tuple[dict, float, bool, dict]:
        """Play ``action`` (a colour index) for the player whose turn it is.

        Choosing a colour currently held by either player is an invalid move:
        the board and the turn are left untouched and a fixed penalty is
        returned.

        Returns:
            ``(observation, reward, done, info)``.

        Raises:
            ValueError: if ``action`` is not one of the known colour indices.
        """
        # SB3 hands over numpy integers; keep the board state in plain ints.
        action = int(action)
        if action not in self.colors:
            raise ValueError(f'Unknown color index: {action!r}')

        if action in (self.p1_color, self.p2_color):
            return self._get_obs(), self.INVALID_MOVE_PENALTY, False, {}

        self._update_grid(color=action, player=self.turn)
        if self.turn == 1:
            self.p1_color = action
        else:
            self.p2_color = action

        reward = self._compute_reward()
        done = self._check_termination()
        self.turn = 3 - self.turn  # change the turn

        return self._get_obs(), reward, done, {}

    def _get_obs(self) -> dict:
        """Build an observation that conforms to ``observation_space``.

        The grid is returned as a fresh ``int32`` array so callers never hold a
        reference to the mutable internal board.
        """
        return {
            'grid': np.array(self.grid, dtype=np.int32),
            'p1_color': int(self.p1_color),
            'p2_color': int(self.p2_color),
        }

    def _update_grid(self, color: int, player: int) -> None:
        """Repaint ``player``'s territory in ``color`` and absorb matching neighbours.

        Only squares directly adjacent to the territory held at the start of
        the call are examined. ``generate_grid`` never places equal colours
        next to each other, so newly absorbed squares have no further matching
        unowned neighbours. ``step`` guarantees ``color`` is not the opponent's
        colour, so the opponent's squares are never absorbed.
        """
        neighbour_offsets = ((1, 0), (-1, 0), (0, 1), (0, -1))
        owned = self.squares[player]
        # Iterate over a snapshot because the set grows while we scan it.
        for row, col in tuple(owned):
            self.grid[row][col] = color
            for d_row, d_col in neighbour_offsets:
                n_row, n_col = row + d_row, col + d_col
                inside = 0 <= n_row < self.rows and 0 <= n_col < self.cols
                if inside and self.grid[n_row][n_col] == color:
                    owned.add((n_row, n_col))

    def _compute_reward(self) -> float:
        """Reward from player 1's perspective.

        ``1`` once player 1 owns more than half of the board, ``-1`` once
        player 2 does, otherwise the fraction of the board owned by player 1.
        """
        total_squares = self.rows * self.cols
        majority = total_squares // 2
        p1_score = len(self.squares[1])
        p2_score = len(self.squares[2])
        if p1_score > majority:
            return 1.0  # win reward
        if p2_score > majority:
            return -1.0  # lose penalty
        return p1_score / total_squares  # Fraction of board captured

    def _check_termination(self) -> bool:
        """Return True when a player holds a majority or the board is full."""
        total_squares = self.rows * self.cols
        majority = total_squares // 2
        p1_score = len(self.squares[1])
        p2_score = len(self.squares[2])
        return (
            p1_score > majority
            or p2_score > majority
            or p1_score + p2_score == total_squares
        )

    def render(self, mode: str = 'human') -> None:
        """Draw the current board with matplotlib.

        ``mode`` is accepted (and ignored) so that callers passing the classic
        Gym ``render(mode=...)`` argument keep working.
        """
        fig, ax = plt.subplots()

        for i, row in enumerate(self.grid):
            for j, key in enumerate(row):
                # Row 0 is drawn at the top, hence the vertical flip.
                ax.add_patch(plt.Rectangle((j, self.rows - i - 1), 1, 1, color=COLORS[key]))

        ax.set_xlim(0, self.cols)
        ax.set_ylim(0, self.rows)
        ax.set_aspect('equal')
        ax.set_xticks([])
        ax.set_yticks([])

        plt.show()
        plt.close(fig)  # a figure is created on every call; do not let them pile up

    def generate_grid(self) -> list:
        """Generate a random starting board as a list of rows of colour indices.

        Guarantees, for a board with at least two rows and two columns:
          * no two orthogonally adjacent squares share a colour;
          * the two neighbours of each starting corner differ from each other;
          * the two players are not offered the identical pair of opening colours;
          * the two starting corners have different colours.

        Raises:
            RuntimeError: if no colour satisfies the constraints for a square
                (cannot happen with the six default colours).
        """
        last_row, last_col = self.rows - 1, self.cols - 1
        p1_start = (last_row, 0)       # bottom-left corner
        p1_right = (last_row, 1)       # neighbour to the right of p1's corner
        p2_below = (1, last_col)       # neighbour below p2's (top-right) corner

        grid = [[-1] * self.cols for _ in range(self.rows)]

        # determine if we can change grid[row][col] to color
        def is_valid(row, col, color):
            if row > 0 and grid[row - 1][col] == color:  # on top of current cell
                return False
            if col > 0 and grid[row][col - 1] == color:  # to the left of current cell
                return False

            cell = (row, col)
            # check 2 neighbors of top-right
            if cell == p2_below and grid[0][last_col - 1] == color:
                return False
            if cell == p1_right:
                p1_above = grid[last_row - 1][0]
                # check 2 neighbors of bottom-left
                if p1_above == color:
                    return False
                # check 2 options at start are different
                if p1_above == grid[0][last_col - 1] and grid[1][last_col] == color:
                    return False
            # ensure difference in P1/P2 start colors
            if cell == p1_start and grid[0][last_col] == color:
                return False
            return True

        # Cells are filled in row-major order, so every cell consulted by
        # is_valid has already been assigned by the time it is read.
        for row in range(self.rows):
            for col in range(self.cols):
                candidates = list(self.colors)
                random.shuffle(candidates)  # shuffle colors for randomness
                for color in candidates:
                    if is_valid(row, col, color):
                        grid[row][col] = color
                        break
                else:
                    raise RuntimeError(f'No valid color available for cell ({row}, {col})')
        return grid

    def human_turn(self):
        """Handles the human player's turn by taking input.

        Returns the chosen colour index, or ``-1`` if the player pressed ENTER
        without typing anything (request to quit).
        """
        taken = {self.p1_color, self.p2_color}
        print(f'VALID TURNS for Player {self.turn}: {sorted(set(self.colors) - taken)}')

        while True:
            raw = input("Choose a color (hit ENTER to exit) ({0: 'red', 1: 'yellow', 2: 'green', 3: 'blue', 4: 'purple', 5: 'black'}) :: ")
            if raw == '':
                return -1

            try:
                action = int(raw)
            except ValueError:
                print('Please enter a number between 0-5, inclusive.')
                continue

            if action not in self.colors:
                # Out-of-range numbers are an input error, not a "current colour" clash.
                print('Please enter a number between 0-5, inclusive.')
            elif action in taken:
                print('Invalid move! You cannot choose the current colors.')
            else:
                return action

    def computer_turn(self, trained_model):
        """Handles the computer's turn using the trained model.

        The observation comes from this instance, not from a notebook-level
        global, so the method works for any ``FillerEnv`` object.
        """
        action, _ = trained_model.predict(self._get_obs())
        return int(action)

In [None]:
# Custom callback for detailed logging
class DebugCallback(BaseCallback):
    """Print the latest step rewards roughly every ``log_interval`` timesteps.

    ``num_timesteps`` advances by the number of parallel environments on every
    call, so it is not guaranteed to ever land exactly on a multiple of the
    interval. The callback therefore logs whenever an interval boundary has
    been crossed since the previous call.

    Args:
        verbose: verbosity level forwarded to ``BaseCallback``.
        log_interval: number of timesteps between log lines (must be positive).

    Raises:
        ValueError: if ``log_interval`` is not positive.
    """

    def __init__(self, verbose=0, log_interval=100000):
        super().__init__(verbose)
        if log_interval <= 0:
            raise ValueError('log_interval must be a positive number of timesteps')
        self.log_interval = log_interval
        # Index of the last interval that has already been reported.
        self._last_bucket = 0

    def _on_training_start(self) -> None:
        # Re-sync with the model's counter so resumed or repeated training runs
        # keep logging at interval boundaries.
        self._last_bucket = self.model.num_timesteps // self.log_interval

    def _on_step(self) -> bool:
        # Log additional information during training
        bucket = self.num_timesteps // self.log_interval
        crossed_boundary = bucket > self._last_bucket
        self._last_bucket = bucket
        if crossed_boundary:
            print(f"Timestep: {self.num_timesteps}, Reward: {self.locals['rewards']}")
        return True  # never request early stopping

In [None]:
# log to the console, a CSV file and TensorBoard (all written under ./logs/)
new_logger = configure(folder="./logs/", format_strings=["stdout", "csv", "tensorboard"])

env = FillerEnv()

# Every vectorised slot must be its own environment object: handing the same
# instance to all 8 slots would make them share (and corrupt) one game state.
# Deep copies keep the board generated above, so training still happens on the
# layout that `env` is later played on.
vec_env = make_vec_env(lambda: deepcopy(env), n_envs=8)

model = PPO('MultiInputPolicy', vec_env, verbose=1)
model.set_logger(new_logger)

# train with debug callback
callback = DebugCallback(verbose=1)

In [None]:
# Run PPO for two million environment steps, reporting through the debug callback.
model.learn(callback=callback, total_timesteps=2_000_000)

In [None]:
# Persist the trained policy to disk.
model.save(path="filler_rl_model_2M")

In [None]:
# simulate environment to play against model
def _choose_computer_action(game, trained_model, max_attempts: int = 20) -> int:
    """Return a legal colour for the computer player.

    The model may propose a colour that is already held by a player. Such a
    move would be rejected by ``step`` without changing the turn, so the model
    is re-sampled up to ``max_attempts`` times, after which a random legal
    colour is used to guarantee the game keeps moving.
    """
    taken = (game.p1_color, game.p2_color)
    for _ in range(max_attempts):
        action = game.computer_turn(trained_model=trained_model)
        if action not in taken:
            return action
    legal = [color for color in COLORS_ORIGINAL if color not in taken]
    return random.choice(legal)


def play_game(trained_model, game_env=None) -> None:
    """Play an interactive game: the human is player 1, the model is player 2.

    Args:
        trained_model: object with a ``predict(observation)`` method, passed on
            to the environment's ``computer_turn``.
        game_env: environment to play on. Defaults to the notebook-level
            ``env``. The environment is not reset here; the caller decides the
            starting state.
    """
    game = env if game_env is None else game_env

    print('State of game board at the start :: ')
    game.render()

    done = False
    while not done:
        print(f'env.turn == {game.turn}')
        if game.turn == 1:  # human's turn
            action = game.human_turn()
            if action == -1:
                print('Ending game...')
                break

            print(f'You chose {COLORS_ORIGINAL[action]}')
        else:  # computer's turn
            action = _choose_computer_action(game, trained_model)
            print(f'Computer chose {COLORS_ORIGINAL[action]}')

        _, _, done, _ = game.step(action)
        game.render()

        p1_score = len(game.squares[1])
        p2_score = len(game.squares[2])
        print(f'You: {p1_score}')
        print(f'Computer: {p2_score}')

        if done:
            print('Game Over!')
            print(f'Player 1 (You): {p1_score} squares')
            print(f'Player 2 (AI): {p2_score} squares')

            if p1_score > p2_score:
                print(f'You win! {p1_score} > {p2_score}')
            elif p1_score < p2_score:
                print(f'The AI wins {p2_score} > {p1_score}!')
            else:
                print('Draw!')

In [None]:
# Put the board back to its starting position, then play against the trained agent.
env.reset()
play_game(model)