In [79]:
!pip uninstall tic_env -y

Found existing installation: tic_env 0.0.1
Uninstalling tic_env-0.0.1:
  Successfully uninstalled tic_env-0.0.1


In [1]:
!pip install git+https://github.com/JamorMoussa/Tic-Tac-Toe-Reinforcement-Learning.git

Collecting git+https://github.com/JamorMoussa/Tic-Tac-Toe-Reinforcement-Learning.git
  Cloning https://github.com/JamorMoussa/Tic-Tac-Toe-Reinforcement-Learning.git to /tmp/pip-req-build-ns2ng61n
  Running command git clone --filter=blob:none --quiet https://github.com/JamorMoussa/Tic-Tac-Toe-Reinforcement-Learning.git /tmp/pip-req-build-ns2ng61n
  Resolved https://github.com/JamorMoussa/Tic-Tac-Toe-Reinforcement-Learning.git to commit 1056abc9f59714bb1763b3f91ac542c50f840cb1
  Preparing metadata (setup.py) ... [?25l[?25hdone


In [3]:
import gymnasium as gym
import tic_env

import numpy as np
import random

from tqdm.notebook import trange

In [4]:
env = gym.make("tic_env/TicTacToe-v0").unwrapped

In [5]:
env

<tic_env.envs.tic_env.TicTacToeEnv at 0x7e7390c5c340>

In [6]:
print("Observation Space", env.observation_space)
print("Sample observation", env.observation_space.sample()) # display a random observation

Observation Space Discrete(511)
Sample observation 155


In [7]:
print("Action Space Shape", env.action_space.n)
print("Action Space Sample", env.action_space.sample())

Action Space Shape 9
Action Space Sample 7


In [8]:
num_states = env.observation_space.n

num_actions = env.action_space.n

In [9]:
num_states, num_actions

(np.int64(511), np.int64(9))

In [10]:
class QTable:

  def __init__(self, num_states: int, num_actions: int):
    self.num_states = num_states
    self.num_actions = num_actions

    self.q_table = np.zeros((self.num_states, self.num_actions))

  def get(self, state: int, action: int):
    action -= 1
    state -= 1
    return self.q_table[state, action]

  def argmax(self, state: int):
    state -= 1
    return np.argmax(self.q_table[state])

  def max(self, state: int):
    state -= 1
    return np.max(self.q_table[state])

  def update(self, state: int, action: int, q_value: float):
    action -= 1
    state -= 1
    self.q_table[state, action] = q_value

  def __repr__(self):
    return self.q_table.__repr__()

In [11]:
def eps_greedy_policy(q_table: QTable, state: int, eps: float):

  if random.uniform(0, 1) < eps:
    return env.action_space.sample()

  return q_table.argmax(state)

In [12]:
def greedy_policy(q_table: QTable, state: int):

  return q_table.argmax(state)

In [172]:
# Training parameters
n_training_episodes = 3 * 10000
lr = 0.2

# Evaluation parameters
n_eval_episodes = 100

max_steps = 100
gamma = 0.95
eval_seed = []

# Exploration parameters
max_epsilon = 1.0
min_epsilon = 0
decay_rate = 2e-3

In [173]:
float(min_epsilon + (max_epsilon - min_epsilon) * np.exp(-decay_rate * 0.5 * 10000))

4.5399929762484854e-05

In [174]:
def train(
    num_iters: int,
    min_eps: int,
    max_eps: int,
    deacy_rate: float,
    env: gym.Env,
    max_steps: int,
    q_table: QTable
):

  for episode in trange(num_iters):

    eps = float(min_eps + (max_eps - min_eps) * np.exp(-decay_rate * episode))

    state = env.reset()

    step = 0
    done = False

    for step in range(max_steps):

      action = eps_greedy_policy(q_table, state, eps)

      new_state, reward, done, info = env.step(action)

      q_value = q_table.get(state, action) + lr * (
          reward  + gamma * q_table.max(new_state) - q_table.get(state, action)
      )

      q_table.update(state, action, q_value)

      if done:
        break

      state = new_state

  return q_table

In [175]:
q_table = QTable(num_states, num_actions)

In [176]:
q_table = train(n_training_episodes, min_epsilon, max_epsilon, max_epsilon, env, max_steps, q_table)

  0%|          | 0/30000 [00:00<?, ?it/s]

In [153]:
q_table.q_table[0]

array([-0.86931777,  0.1428359 ,  0.14475404,  0.14402784,  0.14351932,
        0.14337215,  0.14281248,  0.12584032,  0.10555922])

In [177]:
q_table.q_table[0]

array([-0.93364533,  0.06984703,  0.06320568,  0.06144047,  0.06102971,
        0.06113571,  0.06041233,  0.04229524,  0.0141167 ])

In [178]:
env = gym.make("tic_env/TicTacToe-v0").unwrapped

obv = env.reset()

In [179]:
from tic_env.tic_game import PlayerId

In [180]:
def human_vs_machine(
    env, q_table
):

  state = env.reset()

  env.game.change_player()

  done = False

  while not done:

      if env.game.cur_player.id == PlayerId.O:
        print("allowed action", env.game.get_allowed_actions())
        action = int(input("Enter and action: "))
        while action not in env.game.get_allowed_actions():
          print("This action is not allowed")
          action = int(input("Enter and action: "))

        print(f"Human takes this action: {action}")

      else:
        action = greedy_policy(q_table=q_table, state=state) + 1
        print(f"AI takes this action: {action}")

      state, reward, done, _ = env.step(action=action)

      env.render()

In [188]:
human_vs_machine(
    env=env, q_table=q_table
)

AI takes this action: 1
+---+---+---+
| X |   |   |
+---+---+---+
|   |   |   |
+---+---+---+
|   |   |   |
+---+---+---+

allowed action [2, 3, 4, 5, 6, 7, 8, 9]
Enter and action: 1
This action is not allowed
Enter and action: 2
Human takes this action: 2
+---+---+---+
| X | O |   |
+---+---+---+
|   |   |   |
+---+---+---+
|   |   |   |
+---+---+---+

AI takes this action: 5
+---+---+---+
| X | O |   |
+---+---+---+
|   | X |   |
+---+---+---+
|   |   |   |
+---+---+---+

allowed action [3, 4, 6, 7, 8, 9]
Enter and action: 9
Human takes this action: 9
+---+---+---+
| X | O |   |
+---+---+---+
|   | X |   |
+---+---+---+
|   |   | O |
+---+---+---+

AI takes this action: 7
+---+---+---+
| X | O |   |
+---+---+---+
|   | X |   |
+---+---+---+
| X |   | O |
+---+---+---+

allowed action [3, 4, 6, 8]
Enter and action: 3
Human takes this action: 3
+---+---+---+
| X | O | O |
+---+---+---+
|   | X |   |
+---+---+---+
| X |   | O |
+---+---+---+

AI takes this action: 8
+---+---+---+
| X | 