In [290]:
from catanatron import Game, Player, RandomPlayer, Color
from catanatron_gym.envs.catanatron_env import from_action_space, to_action_space
from catanatron_gym.features import create_sample_vector
import gymnasium as gym

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import numpy as np
import pandas as pd
import random

from collections import namedtuple, deque
from tqdm import tqdm

In [134]:
class DQN(nn.Module):
    """Two-layer MLP that maps a state vector to one Q-value per discrete action.

    Args:
        dim_state: length of the input feature vector.
        num_actions: number of outputs (one Q-value per action index).
        hidden_size: width of the single ReLU hidden layer.
    """

    def __init__(self, dim_state, num_actions, hidden_size):
        super().__init__()
        # Attribute names lin1/lin2 are part of the state_dict keys; keep them stable.
        self.lin1 = nn.Linear(dim_state, hidden_size)
        self.lin2 = nn.Linear(hidden_size, num_actions)

    def forward(self, x):
        hidden = torch.relu(self.lin1(x))
        q_values = self.lin2(hidden)
        return q_values

In [102]:
# One environment step as stored in the replay buffer.
Transition = namedtuple(
    "Transition", ["state", "action", "next_state", "reward", "done"]
)


class ReplayMemory:
    """Bounded FIFO replay buffer of Transitions.

    Once `capacity` entries are stored, each new push evicts the oldest one.
    """

    def __init__(self, capacity):
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Append one transition; positional args follow Transition's field order."""
        transition = Transition(*args)
        self.memory.append(transition)

    def sample(self, batch_size):
        """Return `batch_size` distinct stored transitions chosen uniformly at random."""
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

In [118]:
# Fill these in
def reward(state, action):
    """Placeholder for a shaped reward of taking `action` in `state`.

    TODO: not implemented -- always returns 0. Within this notebook the training
    loop uses the reward returned by `env.step` instead of this function.
    """
    shaped_reward = 0
    return shaped_reward

def state_tensor(state):
    """Convert an observation into a float32 tensor with a leading batch dimension of 1.

    TODO: feature selection / normalisation is not implemented; values pass through as-is.
    The tensor is created on the default (CPU) device.
    """
    as_tensor = torch.tensor(state, dtype=torch.float32)
    return as_tensor[None]

In [200]:
# --- DQN hyperparameters ---
GAMMA = 0.99        # discount factor for future rewards
LR = 1e-3           # Adam learning rate
BATCH_SIZE = 128    # transitions per optimisation step
HIDDEN_SIZE = 128   # width of the DQN hidden layer

# --- Training schedule ---
POLICY_UPDATE = 100  # environment steps between policy-network updates
TARGET_UPDATE = 5    # episodes between target-network syncs
NUM_EPISODES = 1_000

In [142]:
MEMORY_CAPACITY = 10000  # max transitions kept in the replay buffer

env = gym.make("catanatron_gym:catanatron-v0")
state, info = env.reset()

num_actions = env.action_space.n
# Network input width = length of the (batched) observation vector.
dim_state = state_tensor(state).shape[1]

memory = ReplayMemory(MEMORY_CAPACITY)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# TODO: epsilon-greedy exploration is not implemented yet; training currently samples
# actions uniformly at random. Planned schedule: EPS_START = 0.9, EPS_END = 0.05,
# EPS_DECAY = 200.

In [140]:
policy_net = DQN(dim_state, num_actions, HIDDEN_SIZE).to(device)
target_net = DQN(dim_state, num_actions, HIDDEN_SIZE).to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

optimizer = optim.Adam(policy_net.parameters(), LR)
loss_fn = nn.MSELoss()

In [193]:
def train_batch(grad_clip=100.0):
    """Run one DQN optimisation step on a random minibatch from the replay `memory`.

    Reads the globals memory, BATCH_SIZE, GAMMA, policy_net, target_net, optimizer
    and device. The TD target is ``reward + GAMMA * max_a target_net(next_state)[a]``
    for transitions that stored a next_state, and just ``reward`` for terminal ones
    (next_state is None).

    Args:
        grad_clip: if not None, every gradient element of policy_net is clamped to
            [-grad_clip, grad_clip] before the optimizer step. Pass None to disable.

    Returns:
        The minibatch loss as a Python float, or None when memory holds fewer than
        BATCH_SIZE transitions (no update is made).
    """
    if len(memory) < BATCH_SIZE:
        return None

    transitions = memory.sample(BATCH_SIZE)
    batch = Transition(*zip(*transitions))

    # state_tensor() builds CPU tensors; move everything onto the networks' device.
    states = torch.cat(batch.state).to(device)
    actions = torch.tensor(batch.action, dtype=torch.long, device=device)
    rewards = torch.tensor(batch.reward, dtype=torch.float32, device=device)
    # Derive the non-terminal mask from next_state itself so it always lines up with
    # the tensors concatenated below (and never relies on `~` over a non-bool dtype).
    non_final_mask = torch.tensor(
        [s is not None for s in batch.next_state], dtype=torch.bool, device=device
    )
    non_final_next_states = [s for s in batch.next_state if s is not None]

    cur_Q = policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)

    next_Q = torch.zeros(BATCH_SIZE, device=device)
    if non_final_next_states:  # torch.cat([]) raises on an all-terminal batch
        with torch.no_grad():
            next_Q[non_final_mask] = (
                target_net(torch.cat(non_final_next_states).to(device)).max(1).values
            )
    target_Q = rewards + GAMMA * next_Q

    # Update the policy network
    loss = F.mse_loss(cur_Q, target_Q)
    optimizer.zero_grad()
    loss.backward()
    # Clipping must happen after backward() (gradients exist) and before step().
    if grad_clip is not None:
        torch.nn.utils.clip_grad_value_(policy_net.parameters(), grad_clip)
    optimizer.step()
    return loss.item()

In [195]:
def train_episode():
    """Play one full episode, storing transitions in the global replay `memory`.

    Actions are drawn uniformly at random from the environment's valid actions
    (TODO: switch to an epsilon-greedy policy over `policy_net`). Every
    POLICY_UPDATE environment steps, `train_batch()` performs one optimisation step.

    Accepts both the gymnasium 5-tuple ``step`` result
    (obs, reward, terminated, truncated, info) and the legacy 4-tuple
    (obs, reward, done, info).
    """
    obs, info = env.reset()
    state = state_tensor(obs)
    done = False
    step = 0

    while not done:
        # `unwrapped` reaches the Catanatron env through gymnasium wrappers, which do
        # not forward custom methods in gymnasium >= 1.0.
        action = random.choice(env.unwrapped.get_valid_actions())
        step_result = env.step(action)
        if len(step_result) == 5:
            next_obs, reward, terminated, truncated, info = step_result
        else:
            next_obs, reward, terminated, info = step_result
            truncated = False
        terminated = bool(terminated)

        # Only true termination drops the bootstrap term; a truncated episode still
        # has a valid next state. train_batch treats next_state=None as terminal.
        next_state = None if terminated else state_tensor(next_obs)
        memory.push(state, action, next_state, reward, terminated)
        state = next_state
        done = terminated or bool(truncated)

        step += 1
        if step % POLICY_UPDATE == 0:
            train_batch()

        

In [205]:
def train():
    """Run NUM_EPISODES training episodes.

    target_net is hard-synced from policy_net after episode 0 and then after every
    TARGET_UPDATE-th episode.
    """
    for episode in tqdm(range(NUM_EPISODES)):
        train_episode()
        sync_target = episode % TARGET_UPDATE == 0
        if not sync_target:
            continue
        target_net.load_state_dict(policy_net.state_dict())
        target_net.eval()

In [56]:
# def select_action(state, eps):
#     if random.random() < eps:
#         return torch.tensor([[random.randrange(num_actions)]], device=device, dtype=torch.long)
#     else:
#         with torch.no_grad():
#             return policy_net(state).max(1)[1].view(1, 1)

In [206]:
train()

100%|██████████| 1000/1000 [05:06<00:00,  3.26it/s]


In [308]:
class DQNPlayer(Player):
    """Catanatron player that greedily picks the valid action with the highest Q-value.

    Args:
        model: Q-network mapping a (1, dim_state) feature tensor to (1, num_actions)
            Q-values indexed by `to_action_space`.
        color: this player's Color.
    """

    def __init__(self, model, color):
        super().__init__(color)
        self.model = model
        # NOTE: this switches the (possibly shared) model object into eval mode.
        self.model.eval()

    def decide(self, game, valid_actions):
        """Return the element of `valid_actions` whose Q-value is highest."""
        if len(valid_actions) == 1:
            # Forced move (e.g. a lone roll): no need to run the network.
            return valid_actions[0]

        features = np.array(create_sample_vector(game, self.color))
        # state_tensor() builds a CPU tensor; move it to wherever the model lives.
        model_device = next(self.model.parameters()).device
        state = state_tensor(features).to(model_device)

        action_ints = [to_action_space(action) for action in valid_actions]
        with torch.no_grad():
            q_values = self.model(state)[0, action_ints]
        best_action = action_ints[q_values.argmax().item()]
        return from_action_space(best_action, valid_actions)

In [316]:
N_EVAL_GAMES = 100  # games played to estimate the DQN player's win rate

players = [
    RandomPlayer(Color.WHITE),
    DQNPlayer(policy_net, Color.ORANGE),
]

# Each entry is whatever Game.play() returns for that game (the winner's Color;
# presumably None when nobody wins -- TODO confirm).
winners = []
for i in tqdm(range(N_EVAL_GAMES)):
    game = Game(players)  # left as a global: a later cell inspects the last `game`
    winners.append(game.play())

100%|██████████| 100/100 [00:41<00:00,  2.40it/s]


In [318]:
orange_wins = sum(winner == Color.ORANGE for winner in winners)
orange_wins / len(winners)

0.0

In [235]:
import catanatron_gym

In [254]:
# Sanity check: the feature vector DQNPlayer builds at play time must have the same
# width as the observations the network was trained on (dim_state).
feature_len = len(create_sample_vector(game, Color.WHITE))
assert feature_len == dim_state, f"feature length {feature_len} != network input {dim_state}"
feature_len

614