In [1]:
import pickle
import time
import matplotlib.pyplot as plt
from tic_tac_toe import *

pygame 2.5.2 (SDL 2.28.3, Python 3.10.11)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
def time_until_done(start_time, episode, num_episodes):
    current_time = time.time()
    elapsed_time = current_time - start_time

    average_time_per_episode = elapsed_time / episode
    remaining_episodes = num_episodes - episode
    estimated_remaining_time = remaining_episodes * average_time_per_episode

    # Calculate days, hours, minutes, and seconds
    days, remainder = divmod(estimated_remaining_time, 86400)
    hours, remainder = divmod(remainder, 3600)
    minutes, seconds = divmod(remainder, 60)

    return f"{int(days)} days, {int(hours)} hours, {int(minutes)} minutes, {int(seconds)} seconds"

In [3]:
num_episodes = 5000000
a = 0.1  # learning rate
y = 0.99  # discount factor
initial_epsilon = 1.0
final_epsilon = 0.1

epsilon_decay_rate = (initial_epsilon - final_epsilon) / num_episodes
epsilon = initial_epsilon
win_rate_check = max(num_episodes // 100, 100)

env = TicTacToeEnv(render=False, human_env=False, auto_block=False, auto_win=False)
q_table = {}
win_rate_history = ([], [])
epsilon_history = ([], [])

In [4]:
try:
    with open(r".\qtable.pkl", "rb") as file:
        q_table = pickle.load(file)
        print("Loaded q_table")
except:
    games_won = 0

    for episode in range(1, num_episodes + 1):
        epsilon_history[0].append(episode)
        epsilon_history[1].append(epsilon)
        obs = env.reset(random.choice([True, False]))
        while True:
            # choosing an action
            s1 = hash_obs(obs)
            if not s1 in q_table:
                q_table[s1] = [0] * 9
                action = random_action(obs)
            else:
                action = epsilon_greedy_action(epsilon, q_table[s1], obs)
            # get observation data
            obs, reward, done = env.step(action)
            s2 = hash_obs(obs)
            if not s2 in q_table:
                q_table[s2] = [0] * 9
            # update the current Q value: Q(S,A) = Q(S,A) + a(R + y max Q(S',a) - Q(S,A))
            q_table[s1][action] = q_table[s1][action] + a * (
                reward + y * max(q_table[s2]) - q_table[s1][action]
            )
            if done:
                if reward == 10:
                    games_won += 1
                break
        # Update epsilon for the next episode
        epsilon -= epsilon_decay_rate
        epsilon = max(final_epsilon, epsilon)  # Ensure epsilon doesn't go below the final value
        if episode % win_rate_check == 0:
            win_rate = games_won / win_rate_check * 100
            print(f"Episods done: {episode}, win rate: {win_rate:.1f}%")
            win_rate_history[0].append(episode)
            win_rate_history[1].append(win_rate)
            games_won = 0

    env.close()

Loaded q_table


In [5]:
# Win rates
if win_rate_history[0]:
    plt.plot(win_rate_history[0], win_rate_history[1])
    plt.title("Agent win rates")
    plt.ylabel("Win rate in %")
    plt.xlabel("Episode")
    plt.ylim(bottom=0, top=100)
    plt.xlim(left=0, right=num_episodes)
    plt.show()

In [6]:
# Epsilon
if epsilon_history[0]:
    plt.plot(epsilon_history[0], epsilon_history[1])
    plt.title("Epsilon decay")
    plt.ylabel("Epsilon")
    plt.xlabel("Episode")
    plt.ylim(bottom=0, top=1)
    plt.xlim(left=0, right=num_episodes)
    plt.show()

In [7]:
print(len(q_table), q_table)

6436 {(0, 0, 0, 0, 0, 0, 0, 0, 0): [9.577256517912712, 9.61234480847263, 9.559950000373284, 9.638119030227113, 9.660488488662024, 9.57776023350775, 9.607258718491593, 9.483278137125986, 9.602013596989298], (0, 2, 0, 0, 0, 0, 1, 0, 0): [9.894946153579095, 0, 8.99761832041786, 8.228188107105948, 9.703663650490824, 8.923280825786712, 0, 9.550926390199525, 8.77994979794505], (0, 2, 2, 1, 0, 0, 1, 0, 0): [9.999999999999993, 0, 0, 0, 4.7817685477469025, 7.954402997071179, 0, 4.945455420748474, 6.513644973610699], (0, 2, 2, 1, 0, 1, 1, 0, 2): [9.999999999999993, 0, 0, 0, -2.422013180150944, 0, 0, -9.999999998007436, 0], (0, 2, 2, 1, 1, 1, 1, 2, 2): [0, 0, 0, 0, 0, 0, 0, 0, 0], (0, 0, 0, 0, 0, 1, 0, 0, 0): [9.17183019938727, 8.189316011489801, 8.704120750425734, 6.914333732579094, 9.163228349117933, 0, 9.630466875047976, 9.148346955639099, 8.01855457002368], (0, 0, 0, 1, 0, 1, 2, 0, 0): [2.1132784697225677, 2.580224692814658, 5.825449331861121, 0, 9.879069863135424, 0, 0, 3.3930359122157436, 5

In [8]:
env = TicTacToeEnv(render=True, fps=60, human_env=True)
closed = False

while not closed:
    try:
        obs = env.reset(random.choice([True, False]))
    except KeyboardInterrupt:  # The game was closed
        break
    while True:
        # choosing an action
        s1 = hash_obs(obs)
        if not s1 in q_table:
            q_table[s1] = [0] * 9
            action = random_action(obs)
        else:
            action = best_action(q_table[s1], obs)
        # get observation data
        try:
            obs, reward, done = env.step(action)
        except KeyboardInterrupt:  # The game was closed
            closed = True
            break
        s2 = hash_obs(obs)
        if not s2 in q_table:
            q_table[s2] = [0] * 9
        # update the current Q value: Q(S,A) = Q(S,A) + a(R + y max Q(S',a) - Q(S,A))
        q_table[s1][action] = q_table[s1][action] + a * (
            reward + y * max(q_table[s2]) - q_table[s1][action]
        )
        if done:
            for _ in range(150):
                env.clock.tick(60)
            break

In [9]:
with open(r".\qtable.pkl", "wb") as file:
    pickle.dump(q_table, file)