In [None]:
import gymnasium as gym
from src.QLearning import QLearningAgent
import numpy as np
from tqdm import tqdm
from src.Helpers import store_data

In [None]:
# Hyperparameter grid for the sweep. The training cell builds every
# combination of these four lists and trains one agent per combination.

# Step sizes handed to the agent as `learning_rate`.
learning_rates = [
    0.1,
    0.05,
    0.01,
]

# Number of training episodes per run; also the denominator of the
# epsilon decay step (decay = initial_epsilon / (n_episodes / 2)).
n_episodes = [
    10_000,
    100_000,
    1_000_000,
]

# Exploration rate at the start of training (`initial_epsilon`).
start_epsilons = [
    1.0,
    0.5,
    0.25,
]

# Value handed to the agent as `final_epsilon`.
final_epsilons = [
    0.25,
    0.1,
    0.0,
]

In [None]:
# Hyperparameter sweep: train one agent on Blackjack-v1 for every combination
# of the values in the grid cell above, then persist the agent, the
# statistics-recording env and the parameters through `store_data`.
#
# `confs` holds one row per combination, laid out as
# (learning_rate, n_episodes, initial_epsilon, final_epsilon).
# The meshgrid/transposed-reshape construction is kept as-is because the row
# order determines the index `i` used in the stored run name.
confs = np.array(np.meshgrid(learning_rates, n_episodes, start_epsilons, final_epsilons)).T.reshape(-1, 4)
n_confs = len(confs)
for i, p in enumerate(confs):
    # meshgrid promotes every value to np.float64; convert back to plain
    # Python numbers once so the episode count is a real int and `params`
    # does not carry NumPy scalars into the printout or the stored data.
    learning_rate = float(p[0])
    total_episodes = int(p[1])
    initial_epsilon = float(p[2])
    final_epsilon = float(p[3])
    # Per-episode epsilon decrement: initial_epsilon spread over half of the
    # episodes. Computed once so the agent and `params` cannot disagree.
    epsilon_decay = initial_epsilon / (total_episodes / 2)

    env: gym.Env = gym.make('Blackjack-v1', render_mode="rgb_array", natural=False, sab=False)
    # Size the statistics buffer to the run length so no episode is dropped.
    env = gym.wrappers.RecordEpisodeStatistics(env, deque_size=total_episodes)

    try:
        agent = QLearningAgent(
            action_space=env.action_space,
            learning_rate=learning_rate,
            initial_epsilon=initial_epsilon,
            epsilon_decay=epsilon_decay,
            final_epsilon=final_epsilon,
        )

        params = {
            'n_episodes': total_episodes,
            'learning_rate': learning_rate,
            'initial_epsilon': initial_epsilon,
            'epsilon_decay': epsilon_decay,
            'final_epsilon': final_epsilon,
        }
        print('Configuration', i, 'of', n_confs)
        print('params:', params)

        for episode in tqdm(range(total_episodes)):
            curr_observation, info = env.reset()
            curr_action: int = agent.get_action(curr_observation)
            # play one episode
            while True:
                # act upon the environment
                next_observation, reward, terminated, truncated, info = env.step(curr_action)
                is_terminal: bool = terminated or truncated
                # select next action
                # NOTE(review): this also runs on the final step of an episode;
                # kept because `agent.update` takes `next_action` as an argument.
                next_action: int = agent.get_action(next_observation)
                # update the agent (only `terminated` is passed, not `truncated`)
                agent.update(curr_observation, curr_action, reward, terminated, next_observation, next_action)
                # update the current observation and action
                curr_observation = next_observation
                curr_action = next_action
                # end the episode
                if is_terminal:
                    break
            # reduce exploration factor
            agent.decay_epsilon()

        # Store before closing: `store_data` receives the env itself.
        store_data(f'blackjack_qlearning_{i}', agent, env, params)
    finally:
        # A new env is created for every configuration; release it even if
        # training or storing fails so environments do not pile up.
        env.close()