In [1]:
import numpy as np
import matplotlib.pyplot as plt

In [11]:
class Environment:
  """Tic-tac-toe board shared by both players.

  The board is a flat array of 9 cells indexed in row-major order
  (0-2 top row, 3-5 middle row, 6-8 bottom row). A cell holds 0 when empty,
  1 for an O and -1 for an X. O moves first.
  """

  def __init__(self):
    self.reset()

  def reset(self):
    """Replace the board with a fresh, empty one."""
    self.state = np.zeros(3*3)

  def check_turn(self):
    """Return the mark of the player to move: 1 for O, -1 for X.

    The board sums to 0 when both players have placed the same number of
    marks (O to move) and to 1 when O is one mark ahead (X to move).
    """
    return -(np.sum(self.state) * 2 - 1) # 1 for O, -1 for X

  def check_valid(self, action):
    """Return True if `action` names an empty cell on the board.

    Parameters
    ----------
    action : int
        cell index, 0 to 8

    Returns
    -------
    bool
        True when the cell exists and is empty
    """
    # Reject out-of-range indices explicitly: a negative index would
    # otherwise wrap around and address a cell at the end of the board.
    if not 0 <= action < self.state.size:
      return False
    return bool(self.state[action] == 0)

  def check_winner(self):
    """Return the mark (1 or -1) owning a completed line, or 0 if none."""
    for i in range(3):
      # column i
      if self.state[i] == self.state[i+3] == self.state[i+6] != 0:
        return self.state[i]
      # row i
      if self.state[i*3] == self.state[i*3+1] == self.state[i*3+2] != 0:
        return self.state[i*3]
    # main diagonal
    if self.state[0] == self.state[4] == self.state[8] != 0:
      return self.state[0]
    # anti-diagonal
    if self.state[2] == self.state[4] == self.state[6] != 0:
      return self.state[2]
    return 0

  def step(self, action) -> tuple[np.ndarray, int, bool]:
    """Place the current player's mark on cell `action`.

    Parameters
    ----------
    action : int
        action to take

    Returns
    -------
    tuple[np.ndarray, int, bool]
        next state, reward, done

        The reward is from the point of view of the player who just moved:
        -1 for an invalid move (the board is left unchanged and the episode
        ends), 1 for a move that completes a line, 0 otherwise. `done` is
        True after an invalid move, a win, or when no empty cell is left.
        The returned state is the environment's own array, not a copy.
    """
    if not self.check_valid(action):
      return self.state, -1, True

    self.state[action] = self.check_turn()

    # Only the player who just moved can have completed a line.
    if self.check_winner() != 0:
      return self.state, 1, True

    board_full = not np.any(self.state == 0)
    return self.state, 0, board_full

  def str_state(self):
    """Return the board as a 9-character string of '-', 'O' and 'X'.

    Empty cells map to '-', cells equal to 1 map to 'O' and every other
    value maps to 'X'.
    """
    return "".join(
      "-" if s == 0 else "O" if s == 1 else "X"
      for s in self.state
    )

In [14]:
class Agent:
  """Epsilon-greedy agent with per-state tables keyed by the board string.

  Attributes
  ----------
  env : object
      environment the agent reads the current state from via `str_state()`
  Q : dict
      maps a state string to an array of 9 action values
  R : dict
      maps a state string to an array of 9 zeros; it is only allocated
      here, never read (presumably a reward table -- TODO confirm)
  epsilon : float
      probability of picking a uniformly random action
  """
  env = None
  epsilon = 0.1

  def __init__(self, env):
    self.env = env
    # Created per instance: as class-level dicts these tables would be
    # shared (and mutated) by every Agent.
    self.Q = {}
    self.R = {}

  def get_action(self):
    """Choose an action for the environment's current state.

    A state seen for the first time gets zero-initialised entries in `Q`
    and `R`. The choice does not check whether the cell is free.

    Returns
    -------
    int
        cell index, 0 to 8
    """
    state = self.env.str_state()
    if state not in self.Q:
      self.Q[state] = np.zeros(3*3)
      self.R[state] = np.zeros(3*3)

    # epsilon greedy
    if np.random.rand() < self.epsilon:
      return np.random.randint(3*3)

    return np.argmax(self.Q[state])

In [9]:
def main():
  """Let the agent play 1000 moves, starting a new game whenever one ends."""
  env = Environment()
  agent = Agent(env)

  for _ in range(1000):
    # Agent exposes get_action(); it has no act() method.
    action = agent.get_action()
    was_valid = env.check_valid(action)
    env.step(action)

    # Episode end is derived from the board itself: an invalid move, a
    # completed line, or no empty cell left.
    game_over = (
      not was_valid
      or env.check_winner() != 0
      or not np.any(env.state == 0)
    )
    if game_over:
      env.reset()