# In [None]:
import copy
import os
import numpy as np
# ML libraries
import torch
import torch.nn as nn
from collections import deque

from agents.random_agent import Random_Agent
from agents.dqn_agent import DQN_Agent
from envs._env import JassEnv
import utils

# Seed all RNGs via the project helper (seed 99, deterministic=False).
utils.seed_everything(99, deterministic=False)

# How many training episodes the loop below runs.
NUM_EPISODES = 10000

# Prefer the first CUDA device; fall back to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Seating around the table (seat 0 is the DQN agent created below):
#     P2
#   P3  P1
#     P0
agent = DQN_Agent(player_id=0, team_id=0, device=device)
# The three remaining seats are all configured as "greedy" players.
players = {f"P{seat}": "greedy" for seat in (1, 2, 3)}
# Seat that leads the first episode; the training loop rotates it.
starting_player_id = 0

# Rolling window of the 10 most recent episode returns, shown in the progress line.
total_rewards = deque(maxlen=10)
# Checkpoints are written here as dqn_agent_<episode>.pt.
directory = "./agents/models"


def _save_checkpoint(episode):
    """Write the agent's network weights to `directory`/dqn_agent_<episode>.pt."""
    # makedirs(exist_ok=True) also creates missing parent directories and
    # tolerates an existing one, unlike a bare os.mkdir.
    os.makedirs(directory, exist_ok=True)
    torch.save(
        agent.network.state_dict(),
        os.path.join(directory, f"dqn_agent_{episode}.pt"),
    )


for i in range(NUM_EPISODES):
    # Fresh environment each episode; the starting seat cycles through 0-3.
    env = JassEnv(starting_player_id=starting_player_id, players=players)
    state = env.reset()
    done = False
    total_reward = 0

    # The rolling average only changes between episodes, so compute it once.
    # On the first episode the window is empty: np.average would return nan
    # with a RuntimeWarning on every step, so report nan explicitly instead.
    avg_reward = np.average(total_rewards) if total_rewards else float("nan")

    while not done:
        status = (
            f"Running episode {i} of {NUM_EPISODES}. "
            f"Agent Parameters: Epsilon = {agent.epsilon:.6f}, "
            f"Memory Size = {len(agent.memory.memory)}. "
            f"AVG_total_reward = {avg_reward}"
        )
        # '\r' plus padding overwrites the previous progress line in place.
        print("\r" + status.ljust(200), end="", flush=True)

        action = agent.act(state)
        next_state, reward, done = env.step(action)

        agent.remember(state, action, reward, next_state, done)
        agent.optimize_model()

        # Deep copy kept from the original; presumably guards against the env
        # mutating next_state in place — TODO confirm against JassEnv.step.
        state = copy.deepcopy(next_state)
        total_reward += reward

    starting_player_id = (starting_player_id + 1) % 4
    total_rewards.append(total_reward)

    if i % 100 == 0:
        # Leading newline terminates the '\r' progress line so the summary
        # is not appended onto it.
        print(f"\nEpisode {i} done")
        print(f"Total reward: {total_reward}")

    # Checkpoint every 1000 episodes AND after the final episode; with only
    # the modulo check, everything learned after the last multiple of 1000
    # (episode 9000 for NUM_EPISODES=10000) was never saved.
    if i % 1000 == 0 or i == NUM_EPISODES - 1:
        _save_checkpoint(i)

# KeyboardInterrupt: (captured cell output from the exported notebook; the training run above was interrupted)