In [None]:
# Load the training data
from training import PostgresPlaydata
# Take the connection string from the PLAYDATA_CONN_STR environment variable
# when it is set, so real credentials never need to be committed with the
# notebook. The fallback is the original local development default.
import os

conn_str = os.environ.get(
    "PLAYDATA_CONN_STR",
    "postgresql://user:password@localhost:5432/playdata",
)
playdata = PostgresPlaydata(conn_str=conn_str).get()
print(len(playdata))

In [None]:
from tictactoe.evaluate import evaluate
from tictactoe.agent import QLearningAgent, RandomAgent
from pathlib import Path
import numpy as np

def win_percent(data, rounds=10000):
    """Train a fresh Q-learning agent on ``data`` and score it against a random opponent.

    Args:
        data: Play data passed unchanged to ``QLearningAgent.train``.
        rounds: Number of evaluation games played against ``RandomAgent``.
            Defaults to the previously hard-coded 10000.

    Returns:
        numpy.ndarray: the result of ``evaluate`` converted to percentages of
        ``rounds``. Presumably one entry per outcome (tie / win / lose, going
        by the labels used when plotting) -- confirm against ``evaluate``.
    """
    import tempfile

    # Each call gets its own scratch directory for the agent's save path.
    # This function is run from several worker processes at once; a shared
    # "./agent.pickle" would let one run collide with (or pick up state from)
    # another. A path inside a brand-new temporary directory is guaranteed not
    # to exist yet, which is what the previous `assert` was checking for.
    with tempfile.TemporaryDirectory() as scratch_dir:
        save_path = Path(scratch_dir) / "agent.pickle"
        rl_agent = QLearningAgent(save_path=save_path)
        rl_agent.train(data=data)

        # Evaluate the RL agent against a random opponent
        wins = np.array(evaluate(rounds=rounds, agent1=rl_agent, agent2=RandomAgent()))

    return 100 * wins / rounds

In [None]:
from math import log10

# Sample the amount of data linearly: `samples` evenly spaced training-set
# sizes from 0 up to the full data set (np.linspace is linear, not
# logarithmic). The first size is 0, i.e. an empty training slice.
samples = 25
data_amounts = np.linspace(0, len(playdata), num=samples).astype(np.int64)
print(data_amounts)

In [None]:
from tqdm.notebook import tqdm
import multiprocessing

def slice_data(amount):
    """Return ``win_percent`` for an agent trained on the first ``amount`` entries of ``playdata``."""
    training_subset = playdata[:amount]
    return win_percent(training_subset)

# Run one training/evaluation job per data amount on a worker pool.
# `imap` yields results in input order, so win_percents[i] corresponds to
# data_amounts[i]; tqdm shows progress as each result arrives.
win_percents = []
with multiprocessing.Pool() as pool:
    ordered_results = pool.imap(slice_data, data_amounts)
    win_percents.extend(tqdm(ordered_results, total=len(data_amounts)))

In [None]:
import matplotlib.pyplot as plt

# Dump the raw percentages, then plot them against the training-set size.
print(win_percents)
plt.plot(
    data_amounts,
    win_percents,
    label=["Tie %", "Win %", "Lose %"],
)
plt.xlabel('Data Points')
plt.ylabel('Percent (%)')
plt.legend()
# Log-scaled x axis is currently disabled.
# plt.xscale('log')