In [None]:
import gymnasium as gym
from gymnasium import spaces
import numpy as np
import itertools

from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3.common.callbacks import BaseCallback, CheckpointCallback
import copy
import random, math
import os
import torch as th
from torch import nn
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

In [None]:
import math
import numpy as np
import pexpect
import ctypes

class Connect4Solver:
    !pip install stable_baselines3
    !git clone https://github.com/TonyCongqianWang/connect4_solver_fork.git && cd connect4_solver_fork && make
    !curl -L https://github.com/PascalPons/connect4/releases/download/book/7x6.book --output 7x6.book
    solver_path='./connect4_solver_fork/c4solver_c_interface.so'
    solver_lib = ctypes.CDLL(solver_path)
            
    solver_lib.solver_init.argtypes = [ctypes.c_char_p]
    solver_lib.solver_init.restype = ctypes.POINTER(ctypes.c_void_p)
    
    solver_lib.solver_delete.argtypes = [ctypes.POINTER(ctypes.c_void_p)]
    solver_lib.solver_delete.restype = None
    
    solver_lib.solver_solve.argtypes = [ctypes.POINTER(ctypes.c_void_p), ctypes.c_char_p, ctypes.c_bool, ctypes.c_bool, ctypes.c_char_p, ctypes.c_size_t]
    solver_lib.solver_solve.restype = ctypes.c_char_p
    def __init__(self):
        """
        Initializes the Connect4Solver with the path to the solver executable.

        Args:
            solver_path (str): Path to the Connect4 solver executable.
        """
        self.MAX_SCORE = 24
        self.handle = Connect4Solver.solver_lib.solver_init(None)
        self.result_buffer = ctypes.create_string_buffer(256)

    def __del__(self):
        """
        Destructor that sends EOF to the solver process.
        """
        if hasattr(self, 'child') and self.child is not None:
            try:
                self.child.sendeof()
            except:
                pass

    def _process_output(self, prompt_str, answer_str):
        """
        Processes the output from the solver.

        Args:
            prompt_str (str): The prompt string.
            answer_str (str): The answer string.

        Returns:
            list: List of floats representing the processed output.
        """
        if answer_str.startswith(prompt_str):
            answer_str = answer_str[len(prompt_str):].strip()
            
        answer_list = [float(x) for x in answer_str.split()]
        return answer_list

    def _softmax(self, x, temperature=1.0):
        """
        Calculates a modified softmax that approaches argmax for small temperatures.

        For very small temperatures, indices with the maximum value will receive
        equal probability, and the rest will receive 0.

        Args:
            x (list): List of values.
            temperature (float): Temperature parameter for softmax.

        Returns:
            list: List of probabilities.
        """
        if temperature <= 5e-2:  # Consider a very small temperature as argmax
            max_val = max(x)
            max_indices = [i for i, val in enumerate(x) if val == max_val]
            probabilities = [0.0] * len(x)
            prob = 1.0 / len(max_indices)
            for i in max_indices:
                probabilities[i] = prob
            return probabilities
        else:
            e_x = []
            for i in x:
                exponent = i / temperature
                if exponent < -100:
                    e_x.append(0.0)
                else:
                    e_x.append(math.exp(exponent))

            sum_e_x = sum(e_x)
            if sum_e_x == 0:
                return [1.0 / len(x) for _ in range(len(x))]
            return [e / sum_e_x for e in e_x]

    def _transform(self, data, score_offset):
        transformed_data = []
        for x in data:
            sign = 1 if x > 0 else -1 if x < 0 else 0
            if x > -1000:
                transformed_x = sign * ((abs(x) + score_offset) / self.MAX_SCORE * 5)
            else:
                transformed_x = -1000
            transformed_data.append(transformed_x)
        #print(transformed_data)
        return transformed_data

    def _random_index(self, softmax_probs):
        """
        Selects a random index based on softmax probabilities.

        Args:
            softmax_probs (list): List of softmax probabilities.

        Returns:
            int: Selected index.
        """
        selected_index = np.random.choice(len(softmax_probs), p=softmax_probs)
        return selected_index

    def get_solver_move(self, move_str, temperature=1.0):
        """
        Gets a move from the solver.

        Args:
            move_str (str): Move string to send to the solver.
            temperature (float): Temperature parameter for softmax.

        Returns:
            int: Selected move index.
        """
        try:
            result = Connect4Solver.solver_lib.solver_solve(self.handle, move_str.encode("utf-8"), False, True, self.result_buffer, 256)
            answer = result.decode()
            score_offset = math.floor(len(move_str) / 2)
            transformed = self._transform(self._process_output(move_str, answer), score_offset)
            probas = self._softmax(transformed, temperature)
            #print(f"{answer}")
            #print(probas)
            return self._random_index(probas)
        except Exception as e:
            print(f"{e}")
            print(f"{temperature=}")
            print(f"{answer=}")
            try:
                print(f"{transformed=}")
                print(f"{probas=}")
            except:
                pass
        return 0

In [None]:
import zipfile

def zip_directories(directory_paths, working_dir):
    """
    Zips each given directory into ``working_dir`` as ``<directory name>.zip``.

    ``working_dir`` is created if it does not exist. Entries that are not
    directories are skipped with a warning, and a directory that fails to zip
    is reported and skipped, so one bad entry does not abort the rest.

    Args:
        directory_paths (list): A list of paths to directories.
        working_dir (str): Directory in which the zip files are created.

    Returns:
        list: A list of paths to the created zip files.
    """
    os.makedirs(working_dir, exist_ok=True)

    zip_file_paths = []
    for dir_path in directory_paths:
        if not os.path.isdir(dir_path):
            print(f"Warning: {dir_path} is not a directory. Skipping.")
            continue

        # normpath drops a trailing separator; without it basename("a/b/") is ""
        # and every such directory would be written to the same ".zip" file.
        dir_name = os.path.basename(os.path.normpath(dir_path))
        zip_file_path = os.path.join(working_dir, f"{dir_name}.zip")

        try:
            with zipfile.ZipFile(zip_file_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
                for root, _, files in os.walk(dir_path):
                    for file in files:
                        file_path = os.path.join(root, file)
                        # Store entries relative to the zipped directory itself.
                        relative_path = os.path.relpath(file_path, dir_path)
                        zipf.write(file_path, relative_path)
            zip_file_paths.append(zip_file_path)
        except Exception as e:
            print(f"Error zipping {dir_path}: {e}")

    return zip_file_paths

In [None]:
class ConnectFourEnv(gym.Env):
    """
    Connect Four as a Gymnasium environment in which both players act through
    the same ``step`` method, taking turns.

    The board is an int8 matrix holding 1 / -1 for the two players and 0 for
    empty cells; row 0 is the top row. An action is the column to drop a stone
    into. Observations are uint8 arrays of shape (2, rows, cols) with values
    0 or 255: plane 0 marks the stones of the player about to move, plane 1
    the stones of the other player.
    """
    metadata = {"render_modes": ["human", "ansi", "rgb_array"], "render_fps": 1}

    def __init__(self, render_mode=None, board_rows=6, board_cols=7):
        """Stores the board dimensions, builds the spaces and resets the game."""
        super().__init__()
        self.render_mode = render_mode
        self.board_rows = board_rows
        self.board_cols = board_cols
        # One action per column.
        self.action_space = spaces.Discrete(self.board_cols)
        # Two binary (0 / 255) planes, one per player's stones.
        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(2, self.board_rows, self.board_cols),
            dtype=np.uint8,
        )
        # Not read anywhere in this class; the move log lives in move_history_str.
        self.move_history = ""

        self.reset()

    def reset(self, seed=None, options=None):
        """Clears the board and all game state; returns (observation, info)."""
        super().reset(seed=seed, options=options)
        self.board = np.zeros((self.board_rows, self.board_cols), dtype=np.int8)
        self.player = 1  # Player 1 starts
        self.turns = 0
        self.winner = None
        self.done = False
        self.move_history_str = ""
        return self._get_observation(), {}

    def step(self, action):
        """
        Drops a stone for the current player into column ``action``.

        Returns the usual (observation, reward, terminated, truncated, info)
        tuple. A finished game returns reward 0 and terminated=True without
        changing anything. A full column returns reward -50 and leaves both
        the board and the player to move unchanged. A winning move is rewarded
        with 80 plus up to 20 depending on how many cells remain unplayed;
        every other move, including a draw, gives 0. After a legal move the
        player to move is switched, also when the game just ended.
        """
        if self.done:
            return self._get_observation(), 0, True, False, {}
        if not self._is_valid_move(action):
            return self._get_observation(), -50, False, False, {}

        self._drop_piece(action)
        # Moves are logged as 1-based column digits.
        self.move_history_str += str(action + 1)
        self.turns += 1

        reward = 0
        if self._check_win():
            total_cells = self.board.size
            self.done = True
            self.winner = self.player
            reward = 80 + 20 * (total_cells - self.turns) / total_cells
        elif self._check_draw():
            self.done = True

        self.player = -self.player  # Switch players
        return self._get_observation(), reward, self.done, False, {}

    def get_valid_moves(self):
        """Returns the list of columns that still have room for a stone."""
        return [col for col in range(self.board_cols) if self._is_valid_move(col)]

    def _get_observation(self):
        """Builds the two-plane observation from the mover's point of view."""
        # After this multiplication the player to move always owns the +1 stones.
        relative_board = self.board * self.player
        planes = np.stack([relative_board == 1, relative_board == -1])
        return planes.astype(np.uint8) * 255

    def _is_valid_move(self, col):
        """A column is playable while its top cell is empty."""
        return self.board[0, col] == 0

    def _drop_piece(self, col):
        """Places the current player's stone in the lowest empty cell of ``col``."""
        for row in reversed(range(self.board_rows)):
            if self.board[row, col] == 0:
                self.board[row, col] = self.player
                break

    def _check_win(self):
        """Returns True if any four equal, non-empty cells line up on the board."""
        # (row step, column step): horizontal, vertical, and the two diagonals.
        directions = ((0, 1), (1, 0), (1, 1), (-1, 1))
        for d_row, d_col in directions:
            for r in range(self.board_rows):
                last_r = r + 3 * d_row
                if last_r < 0 or last_r >= self.board_rows:
                    continue
                for c in range(self.board_cols):
                    if c + 3 * d_col >= self.board_cols:
                        continue
                    stone = self.board[r, c]
                    if stone == 0:
                        continue
                    if all(self.board[r + k * d_row, c + k * d_col] == stone for k in (1, 2, 3)):
                        return True
        return False

    def _check_draw(self):
        """The board is full when no cell equals 0."""
        return (self.board != 0).all()

    def render(self):
        """Prints the board as text, using 'x' for player 1 and 'o' for player -1."""
        border = "-" * (self.board_cols * 2 + 3)
        symbols = {1: "x ", -1: "o "}
        lines = [border]
        for row in self.board:
            cells = "".join(symbols.get(int(cell), "  ") for cell in row)
            lines.append("| " + cells + "|")
        lines.append(border)
        print("\n".join(lines))

In [None]:
# Best effort: re-zip every entry of the agents input directory into
# /kaggle/working/opponents. When the input is unavailable the notebook simply
# continues without pre-trained opponents, but the reason is now reported.
try:
    agent_dir = "/kaggle/input/connect-4-agents/"
    agent_files = os.listdir(agent_dir)
    agent_paths = [os.path.join(agent_dir, f) for f in agent_files]
    agent_paths = zip_directories(agent_paths, "/kaggle/working/opponents")
except Exception as e:
    print(f"No pre-trained opponents prepared: {e}")
    agent_paths = []

In [None]:
# Best effort: re-zip every entry of the checkpoints input directory into
# /kaggle/working/. Any failure is printed and otherwise ignored.
checkpoint_dir = "/kaggle/input/connect-4-checkpoints/"
try:
    checkpoint_files = os.listdir(checkpoint_dir)
    checkpoint_paths = [os.path.join(checkpoint_dir, name) for name in checkpoint_files]
    zip_directories(checkpoint_paths, "/kaggle/working/")
except Exception as e:
    print(e)

In [None]:
class Connect4TrainingEnv(gym.Wrapper):
    """
    Single-agent view of a ConnectFourEnv: after every legal agent move an
    opponent replies automatically.

    Each entry of ``opponents`` is either a model exposing ``predict`` (and
    ``set_parameters`` for self-play updates) or a float. A float selects the
    Connect4Solver as opponent and is the upper bound of the solver
    temperature drawn anew for every episode.
    """

    def __init__(self, env, opponents):
        """
        Args:
            env: The wrapped Connect Four environment.
            opponents (list): Non-empty pool of opponents (models and/or floats).
        """
        super().__init__(env)  # gym.Wrapper stores the environment as self.env
        self.player_move = 0
        self.opponents = opponents
        self.current_opponent = opponents[0]
        self.solver = Connect4Solver()

    def reset(self, seed=None, options=None):
        """
        Starts a new episode against a randomly chosen opponent.

        Between 0 and 5 opening moves are played by the opponent logic (half
        of them replaced by random moves) before control goes to the agent, so
        the agent starts from varied positions and as either player.

        Returns:
            tuple: (observation, info) for the position the agent has to play.
        """
        self.player_move = 0
        self.current_opponent = random.choice(self.opponents)
        if isinstance(self.current_opponent, float):
            # Solver opponent: draw this episode's temperature from [0, max].
            self.current_opponent = random.uniform(0, self.current_opponent)
        # Gymnasium declares seed/options as keyword-only, so pass them by name.
        obs, info = self.env.reset(seed=seed, options=options)
        for _ in range(random.randint(0, 5)):
            opponent_action = self._get_valid_opp_move(obs, 0.5)
            obs, *_ = self.env.step(opponent_action)
        return obs, info

    def step(self, action):
        """
        Plays the agent's move and, if the game continues, the opponent's reply.

        The episode is truncated once the agent has used rows * cols / 1.5
        moves (illegal attempts count as moves). The opponent only replies
        when the agent's move was actually placed. A reward earned by the
        opponent is subtracted from the agent's reward.

        Returns:
            tuple: (observation, reward, terminated, truncated, info).
        """
        if self.player_move >= self.env.board_rows * self.env.board_cols / 1.5:
            return self.env._get_observation(), 0, False, True, {}
        old_turns = self.env.turns
        observation, reward, done, truncated, info = self.env.step(action)
        self.player_move += 1
        # An unchanged turn counter means the agent's move was rejected.
        if self.env.turns > old_turns and not done and not truncated:
            opponent_action = self._get_valid_opp_move(observation)
            observation, opp_reward, done, truncated, info = self.env.step(opponent_action)
            reward -= opp_reward
        return observation, reward, done, truncated, info

    def render(self):
        """Prints the agent's move counter followed by the board."""
        print(f"..... MOVE {self.player_move} .....")
        self.env.render()

    def update_opponents(self, model, replace_fraction = 0.5):
        """
        Copies ``model``'s parameters into a random subset of the opponent pool.

        Args:
            model: Model providing ``get_parameters()``.
            replace_fraction (float): Fraction of the pool to overwrite
                (rounded up to a whole number of opponents).
        """
        num_opps = len(self.opponents)
        update_idxs = random.sample(range(num_opps), math.ceil(num_opps * replace_fraction))
        if not update_idxs:
            return
        parameters = model.get_parameters()  # fetched once, reused for every opponent
        for i in update_idxs:
            opponent = self.opponents[i]
            if isinstance(opponent, float):
                # Solver entries are temperatures and have no parameters to update.
                continue
            opponent.set_parameters(parameters)

    def _get_valid_opp_move(self, observation, random_freq = 0.1):
        """
        Picks the opponent's move for the current position.

        The solver (float opponent) or the opponent model proposes a column.
        The proposal is replaced by a uniformly random legal column when it is
        illegal, and also with probability ``random_freq``.

        Args:
            observation: Observation passed to a model opponent's ``predict``.
            random_freq (float): Probability of playing a random legal move.

        Returns:
            The chosen column.
        """
        valid_moves = self.env.get_valid_moves()
        if isinstance(self.current_opponent, float):
            opponent_action = self.solver.get_solver_move(self.env.move_history_str, temperature = self.current_opponent)
        else:
            opponent_action, _ = self.current_opponent.predict(observation, deterministic=False)
        if opponent_action not in valid_moves or random.random() < random_freq:
            opponent_action = random.choice(valid_moves)
        return opponent_action

class OpponentUpdateCallback(BaseCallback):
    """
    Callback that keeps a self-play opponent pool in step with the model being
    trained by calling ``update_opponents`` on the given environment around
    every rollout.
    """

    def __init__(self, self_play_env, verbose: int = 0):
        """Remembers the environment whose opponent pool is to be refreshed."""
        super().__init__(verbose)
        self.self_play_env = self_play_env

    def _sync_opponents(self, fraction) -> None:
        """Overwrites the given fraction of the pool with the current model."""
        self.self_play_env.update_opponents(self.model, fraction)

    def _on_training_start(self) -> None:
        """Nothing to do when training starts."""
        return None

    def _on_rollout_start(self) -> None:
        """Refreshes 40% of the opponents before a rollout is collected."""
        self._sync_opponents(0.4)

    def _on_step(self) -> bool:
        """Always returns True, so training is never stopped by this callback."""
        return True

    def _on_rollout_end(self) -> None:
        """Refreshes 80% of the opponents after a rollout has been collected."""
        self._sync_opponents(0.8)

    def _on_training_end(self) -> None:
        """Nothing to do when training ends."""
        return None
    

def train_agent(agent_name, num_iterations, working_dir, opponents_dir=None, num_solver_opponents=0, solver_max_temp=0.5, rollout_len = 2**14):
    """
    Trains (or continues training) a PPO agent and saves it under
    ``{working_dir}/models/{agent_name}``.

    The opponent pool is assembled as follows:
      * every ``.zip`` model found in ``opponents_dir`` is loaded as an opponent;
      * solver opponents are added as ``solver_max_temp`` entries;
      * if the pool is still empty, training falls back to self-play against a
        pool of freshly created PPO models that is refreshed during training.

    If ``working_dir`` contains checkpoints named ``checkpoints_{agent_name}*.zip``
    the most recently modified one is loaded; if none can be loaded a new model
    is created.

    Args:
        agent_name (str): Name used for checkpoints and the saved model.
        num_iterations (int): Number of rollouts of ``rollout_len`` steps to train.
        working_dir (str): Directory for checkpoints and the final model.
        opponents_dir (str | None): Directory with opponent ``.zip`` models, or
            None to load no opponent models.
        num_solver_opponents (int | float): When opponent models were listed
            from ``opponents_dir``, the number of solver entries per loaded
            model (rounded up after multiplying). Otherwise the number of
            solver entries itself, rounded up.
        solver_max_temp (float): Value stored for each solver entry.
        rollout_len (int): ``n_steps`` of a newly created model; also sets the
            checkpoint frequency and the training length.
    """
    policy_kwargs = dict(
        net_arch = [2**13, 2**8, 2**9]
    )
    opponent_pool_size = 5

    selfplay = False
    opponents = []
    # Used when no opponent models are listed: an absolute, whole number of
    # solver entries (a float cannot be used to repeat a list).
    solver_opponent_count = math.ceil(num_solver_opponents)
    # os.listdir(None) would list the current directory, so a missing
    # opponents_dir must skip loading instead of depending on what is there.
    if opponents_dir is not None:
        try:
            opponent_paths = [
                os.path.join(opponents_dir, f)
                for f in os.listdir(opponents_dir)
                if f.endswith(".zip")
            ]
            for opponent_path in opponent_paths:
                opponents.append(PPO.load(opponent_path))
            print(f"training against opponents: {opponent_paths}")
            solver_opponent_count = math.ceil(num_solver_opponents * len(opponent_paths))
        except Exception as e:
            print(f"Error loading opponents: {e}")

    if solver_opponent_count > 0:
        opponents += solver_opponent_count * [solver_max_temp]
        print(f"training against solver")
    if not opponents:
        print(f"training using selfplay")
        selfplay = True
        opponents = [PPO("MlpPolicy", ConnectFourEnv(), policy_kwargs=policy_kwargs, verbose=0) for _ in range(opponent_pool_size)]

    training_env = Connect4TrainingEnv(ConnectFourEnv(), opponents)

    # Check for existing checkpoints
    checkpoint_prefix = f"checkpoints_{agent_name}"
    try:
        checkpoint_dir = f"{working_dir}"
        checkpoint_files = [f for f in os.listdir(checkpoint_dir) if f.startswith(checkpoint_prefix) and f.endswith(".zip")]
        checkpoint_paths = [os.path.join(checkpoint_dir, f) for f in checkpoint_files]
        # max() raises ValueError when there is no checkpoint, which is handled
        # below like any other loading problem.
        latest_checkpoint = max(checkpoint_paths, key=os.path.getmtime)
        print(f"Loading checkpoint: {latest_checkpoint}")
        model = PPO.load(latest_checkpoint, env=training_env, policy_kwargs=policy_kwargs)
        total_timesteps = model.num_timesteps
        print(f"Checkpoint loaded. Timesteps: {total_timesteps}")
    except Exception as e:
        print(f"No checkpoint loaded {e}. Starting from scratch.")
        model = PPO("MlpPolicy", training_env, policy_kwargs=policy_kwargs, verbose=1, n_steps=rollout_len)

    print(model.policy)
    checkpoint_callback = CheckpointCallback(save_freq=rollout_len * 2, save_path=f"{working_dir}", name_prefix=checkpoint_prefix)

    callbacks = [checkpoint_callback]
    if selfplay:
        # Start every pool member from the current model, then keep refreshing.
        training_env.update_opponents(model, 1)
        callbacks += [OpponentUpdateCallback(training_env)]

    # reset_num_timesteps=False keeps the step counter of a loaded checkpoint.
    model.learn(total_timesteps=rollout_len * num_iterations, callback=callbacks, reset_num_timesteps=False)

    model.save(f"{working_dir}/models/{agent_name}")

In [None]:
agent_name = "agent_mlp_1"

# Training curriculum, executed top to bottom. Each entry is
# (num_iterations, opponents_dir, extra keyword arguments); an opponents_dir of
# None is train_agent's own default.
_training_phases = [
    (5, "/kaggle/working/opponents/", {"num_solver_opponents": 0.2}),
    (15, None, {}),
    (5, "/kaggle/working/opponents/", {"num_solver_opponents": 0.5}),
    (3, None, {"num_solver_opponents": 1}),
    (15, None, {}),
    (3, None, {"num_solver_opponents": 1}),
]

for _num_iterations, _opponents_dir, _extra_kwargs in _training_phases:
    train_agent(agent_name, _num_iterations, "/kaggle/working/", _opponents_dir, rollout_len = 2 ** 16, **_extra_kwargs)