In [1]:
from pettingzoo.classic import connect_four_v3
import random
import numpy as np
from tqdm import tqdm
import time

In [2]:
class Player:
    """Adapter that gives any policy callable a ``get_action(state)`` attribute."""

    def __init__(self, get_action) -> None:
        """Keep ``get_action``, a callable that picks an action for a given state."""
        policy = get_action
        self.get_action = policy

In [3]:
class Qlearning:
    """Tabular Q-learning agent trained by self-play on PettingZoo Connect Four.

    Both seats (``player_0`` and ``player_1``) read and write one shared
    ``q_table``. Each observation is keyed by its flattened values joined into a
    string, and maps to a list of 7 action values (one per column).
    """

    def __init__(self, exploration_factor=0.1, discount_factor=0.7, learning_rate=0.1) -> None:
        """Create an empty Q-table and a headless Connect Four environment.

        Args:
            exploration_factor: Probability (epsilon) of playing a random legal
                move instead of the greedy one during training.
            discount_factor: Weight (gamma) applied to the best next-state value.
            learning_rate: Step size (alpha) of each Q-value update.
        """
        self.exploration_factor = exploration_factor
        self.discount_factor = discount_factor
        self.learning_rate = learning_rate
        self.q_table = {}
        self.env = connect_four_v3.env()

    def __create_q_table_entry(self, state):
        """Return the Q-table key for observation array ``state``.

        A row of seven zeros is added the first time a key is seen.
        """
        name = "".join(str(x) for x in state.flatten())
        if name not in self.q_table:
            self.q_table[name] = [0] * 7
        return name

    def __update_q_value(self, key, action, reward, next_key=None, next_mask=None):
        """Move ``q_table[key][action]`` toward its one-step Q-learning target.

        The target is ``reward`` alone when ``next_key`` is None (the move ended
        the game). Otherwise it is ``reward + discount_factor * max Q(next_key, a)``,
        with the max taken over the actions allowed by ``next_mask`` (all actions
        if ``next_mask`` is None).
        """
        next_max = 0.0
        if next_key is not None:
            next_values = self.q_table[next_key]
            if next_mask is not None:
                # Illegal columns are never played, so their value stays at the
                # initial 0; leaving them in would inflate the max whenever every
                # legal move is valued below 0.
                next_values = [value for value, legal in zip(next_values, next_mask) if legal]
            if next_values:
                next_max = max(next_values)
        old_value = self.q_table[key][action]
        self.q_table[key][action] = (1 - self.learning_rate) * old_value + self.learning_rate * (
            reward + self.discount_factor * next_max
        )

    def training(self, n_training_game=1000):
        """Train ``q_table`` with epsilon-greedy self-play.

        Each seat's (state, action) pair is updated once the consequence of that
        move is known. That happens either at the seat's next turn, bootstrapping
        from the state it then has to act in, or when the game ends, using the
        seat's final reward.

        Args:
            n_training_game: Number of complete games to play.
        """
        for _ in tqdm(range(n_training_game)):
            self.env.reset()
            # Most recent (state_key, action) of each seat, awaiting its update.
            pending = {}

            while True:
                current_agent = self.env.agent_selection
                state = self.env.observe(current_agent)
                key = self.__create_q_table_entry(state["observation"])

                # The opponent has replied, so this seat's previous move can now
                # be scored against the state it has to act in.
                if current_agent in pending:
                    last_key, last_action = pending[current_agent]
                    self.__update_q_value(
                        last_key,
                        last_action,
                        self.env.rewards[current_agent],
                        next_key=key,
                        next_mask=state["action_mask"],
                    )

                if random.uniform(0, 1) < self.exploration_factor:
                    action = self.env.action_space(current_agent).sample(state["action_mask"])
                else:
                    action = self.get_action(state)

                # Remember this seat's own action; the opponent's moves must not
                # be credited to it.
                pending[current_agent] = (key, action)
                self.env.step(action)

                if self.env.terminations[current_agent] or self.env.truncations[current_agent]:
                    # Game over: settle both seats' last moves with their final
                    # rewards. These are read per seat from env.rewards rather than
                    # from env.last(), which reports whichever agent is selected.
                    for agent_name, (last_key, last_action) in pending.items():
                        self.__update_q_value(last_key, last_action, self.env.rewards[agent_name])
                    break

        self.env.close()

    def play(self):
        """Render one game in which the greedy policy plays both seats.

        A separate ``render_mode="human"`` environment is used, so ``self.env``
        stays headless for later ``training`` calls.
        """
        env = connect_four_v3.env(render_mode="human")
        try:
            env.reset()
            while True:
                current_agent = env.agent_selection
                env.step(self.get_action(env.observe(current_agent)))
                time.sleep(0.3)  # pause so each move is visible
                if env.terminations[current_agent] or env.truncations[current_agent]:
                    break
        finally:
            env.close()

    def get_action(self, state):
        """Return the legal column with the highest Q-value for ``state``.

        ``state`` is an observation dict with ``"observation"`` and
        ``"action_mask"`` entries. Masked-out columns are never chosen. Ties go to
        the lowest column index, per ``np.argmax``.
        """
        key = self.__create_q_table_entry(state["observation"])
        possible = [
            value if legal != 0 else -np.inf
            for value, legal in zip(self.q_table[key], state["action_mask"])
        ]
        return np.argmax(possible)

In [4]:
# Hyperparameters spelled out (these equal the constructor defaults) so they can be tuned here.
Q = Qlearning(exploration_factor=0.1, discount_factor=0.7, learning_rate=0.1)

In [5]:
N_TRAINING_GAMES = 100  # number of self-play games in this run
Q.training(n_training_game=N_TRAINING_GAMES)

100%|██████████| 100/100 [00:00<00:00, 167.44it/s]


In [6]:
# Opens a render_mode="human" Connect Four window and plays one game with the greedy policy on both seats.
Q.play()

In [7]:
# Expose the trained agent's greedy policy through the Player interface.
# Passing the bound method directly avoids a redundant lambda wrapper.
player = Player(get_action=Q.get_action)