In [None]:
from pathlib import Path
from random import randint

import numpy as np

In [None]:
from src.agent import DominoAgent
from src.environment import DominoEnvironment
from src.action import DominoAction, TOTAL_ACTIONS
from src.replay_buffer import ReplayBuffer

In [None]:
def _legal_action_indices(state) -> list:
    """Return sorted action indices that are both playable from the hand and legal on the board.

    Each held tile ``i`` contributes action indices ``2 * i`` and ``2 * i + 1``; the pass
    action ``TOTAL_ACTIONS - 1`` is always offered from the hand side. The result is the
    intersection with the indices whose ``state.legal_actions`` flag is truthy.
    """
    hand_actions = {TOTAL_ACTIONS - 1}
    for tile, held in enumerate(state.hand_tiles):
        if held:
            hand_actions.update((2 * tile, 2 * tile + 1))
    board_actions = {idx for idx, valid in enumerate(state.legal_actions) if valid}
    return sorted(hand_actions & board_actions)


def _legal_actions_after_drawing(env) -> list:
    """Draw tiles while passing is the only legal move and ``env.draw_pile`` is non-empty.

    Returns the legal action indices for ``env.current_state`` once a non-pass action is
    available or the draw pile is exhausted.
    """
    legal_actions = _legal_action_indices(env.current_state)
    while legal_actions == [TOTAL_ACTIONS - 1] and env.draw_pile:
        _, _ = env.draw_tile()
        legal_actions = _legal_action_indices(env.current_state)
    return legal_actions


def train(num_episodes: int = 10_000, batch_size: int = 64, buffer_capacity: int = 100_000, min_buffer_size: int = 1000,
          target_update_freq: int = 100, checkpoint_path: str = "domino_agent.pt"):
    """Train a DominoAgent with replay-buffer learning over randomized games.

    Each episode plays one game with 2-4 players, the learning agent seated at a random
    position. Every agent move is pushed into a ReplayBuffer; once the buffer is warm, one
    ``agent.train_step`` runs per move. Every ``target_update_freq`` episodes this calls
    ``agent.update_target()``, prints windowed reward / win-rate stats and saves a checkpoint.

    Args:
        num_episodes: number of games to play.
        batch_size: transitions sampled per training step.
        buffer_capacity: maximum ReplayBuffer size.
        min_buffer_size: transitions required before training starts (effectively at least
            ``batch_size`` so sampling always has enough data).
        target_update_freq: episodes between target updates, stat reports and checkpoints.
        checkpoint_path: file to resume from (only if it exists) and to save to.

    Raises:
        ValueError: if ``target_update_freq`` is not positive.
    """
    if target_update_freq < 1:
        raise ValueError(f"target_update_freq must be >= 1, got {target_update_freq}")

    agent = DominoAgent()
    # Resume when a checkpoint exists; on a first run there is nothing to load yet.
    if Path(checkpoint_path).exists():
        agent.load(checkpoint_path)
    buffer = ReplayBuffer(buffer_capacity)
    # Never try to sample more transitions than the buffer holds.
    warmup_size = max(min_buffer_size, batch_size)

    loss: float = 0.0
    window_rewards: list = []
    window_wins: list = []

    for episode in range(num_episodes):
        num_players = randint(2, 4)
        agent_index = randint(0, num_players - 1)
        env = DominoEnvironment(num_players=num_players, agent_indices=[agent_index])
        env.reset()
        done = False
        episode_reward: float = 0.0

        while not done:
            legal_actions = _legal_actions_after_drawing(env)

            state = env.current_state
            action_index = agent.select_action(state, legal_actions, training=True)

            # Snapshot the pre-action state BEFORE stepping: env.step() may replace or
            # mutate current_state, which would otherwise store the post-action state.
            state_array = np.array(state.to_array())
            state_mask = np.array(state.legal_actions)

            next_state, reward, done = env.step(DominoAction(action_index))
            episode_reward += reward

            # NOTE(review): the stored masks are the board-level legal_actions flags only,
            # not intersected with the hand as in _legal_action_indices — confirm intended.
            buffer.push(
                state_array,
                action_index,
                reward,
                np.array(next_state.to_array()),
                done,
                state_mask,
                np.array(next_state.legal_actions)
            )

            if len(buffer) >= warmup_size:
                batch = buffer.sample(batch_size)
                # NOTE(review): presumably batch[5] / batch[6] hold the pushed legal masks and
                # are not forwarded; confirm agent.train_step does not need them.
                batch_dict = {
                    'states': batch[0],
                    'actions': batch[1],
                    'rewards': batch[2],
                    'next_states': batch[3],
                    'dones': batch[4]
                }
                loss = agent.train_step(batch_dict)

        print(f"[{episode}] Reward: {episode_reward:.2f}, Epsilon: {agent.epsilon:.3f}")
        window_rewards.append(episode_reward)
        # The agent sits at agent_index, not seat 0. Assumes final_rewards is keyed by seat
        # index (the .get default suggests some seats may be absent) — confirm in the env.
        won = env.final_rewards.get(agent_index, 0) > 0
        window_wins.append(int(won))
        agent.decay_epsilon()

        # episode + 1 so each report covers exactly target_update_freq finished episodes.
        if (episode + 1) % target_update_freq == 0:
            agent.update_target()
            print(f"[{episode}] Loss = {loss:.4f}")
            print(f"[{episode}] Avg reward (last {target_update_freq}): {sum(window_rewards) / len(window_rewards):.2f}")
            print(f"[{episode}] Win rate (last {target_update_freq}): {sum(window_wins) / len(window_wins):.2%}")
            window_rewards = []
            window_wins = []
            agent.save(checkpoint_path)

    # Persist a trailing partial window that did not land on a save boundary.
    if num_episodes % target_update_freq != 0:
        agent.save(checkpoint_path)


In [None]:
# Training run: keyword arguments make each hyperparameter self-describing.
train(
    num_episodes=5000,
    batch_size=128,
    buffer_capacity=100_000,
    min_buffer_size=2000,
    target_update_freq=50,
)