In [59]:
from catanatron import Game, Player, RandomPlayer, Color, ActionType
from catanatron_gym.envs.catanatron_env import from_action_space, to_action_space
from catanatron_gym.features import create_sample_vector
import gymnasium as gym

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import numpy as np
import random

from collections import namedtuple, deque
from tqdm import tqdm

In [27]:
class DQN(nn.Module):
    """Fully connected Q-network with two ReLU hidden layers.

    Args:
        dim_state: width of the input state vector.
        num_actions: number of outputs (one Q-value per action index).
        hidden_size: width of both hidden layers.
    """

    def __init__(self, dim_state, num_actions, hidden_size):
        super().__init__()
        # The attribute names double as state_dict keys, so they are kept stable.
        self.lin1 = nn.Linear(dim_state, hidden_size)
        self.lin2 = nn.Linear(hidden_size, hidden_size)
        self.lin3 = nn.Linear(hidden_size, num_actions)

    def forward(self, x):
        """Return one row of Q-values (``num_actions`` wide) per input row."""
        for hidden_layer in (self.lin1, self.lin2):
            x = F.relu(hidden_layer(x))
        # The output layer is linear: Q-values are unbounded.
        return self.lin3(x)

In [36]:
# One step of experience as stored in the replay buffer.
Transition = namedtuple(
    'Transition',
    ('state', 'action', 'next_state', 'reward', 'done'),
)


class ReplayMemory:
    """Fixed-capacity buffer of ``Transition`` records.

    Once ``capacity`` records are stored, appending a new one discards the
    oldest (behaviour of ``deque`` with ``maxlen``).
    """

    def __init__(self, capacity):
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Wrap the positional arguments in a ``Transition`` and store it."""
        record = Transition(*args)
        self.memory.append(record)

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored transitions chosen at random."""
        return random.sample(self.memory, batch_size)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.memory)

In [60]:
def reward_f(env, action):
    """Reward for having just played ``action`` in ``env``.

    A finished game dominates everything else: +1000 when BLUE is the winning
    color, -1000 when any other color is. Otherwise a small shaping bonus is
    paid for building actions (road .1, settlement .2, city .3) and 0 for
    every other action type.
    """
    winner = env.game.winning_color()
    if winner == Color.BLUE:
        return 1000
    if winner is not None:
        return -1000

    # Game still running: only building actions earn a shaping reward.
    kind = action.action_type
    if kind == ActionType.BUILD_ROAD:
        return .1
    if kind == ActionType.BUILD_SETTLEMENT:
        return .2
    if kind == ActionType.BUILD_CITY:
        return .3
    return 0

def state_tensor(state):
    """Encode a game state as a ``(1, 133)`` float32 feature tensor.

    Layout of the single row:
      * 7 entries read from ``state.player_state`` (two victory-point
        counters and the five resource counts of ``P0``),
      * 54 flags, one per integer ``0..53``, set when that integer is a key
        of ``state.board.buildings``,
      * 72 flags, one per integer ``0..71``, set when that integer is a key
        of ``state.board.roads``.

    NOTE(review): the flags only record presence, not the owner.
    NOTE(review): presumably ``P0`` is the learning agent's seat — confirm.
    NOTE(review): if ``board.roads`` is keyed by edge tuples rather than
    ints, the road flags can never be set — verify against the board API.
    """
    player_keys = (
        'P0_ACTUAL_VICTORY_POINTS',
        'P0_VICTORY_POINTS',
        'P0_WOOD_IN_HAND',
        'P0_BRICK_IN_HAND',
        'P0_ORE_IN_HAND',
        'P0_SHEEP_IN_HAND',
        'P0_WHEAT_IN_HAND',
    )
    features = [state.player_state[key] for key in player_keys]
    features.extend(node in state.board.buildings for node in range(54))
    features.extend(edge in state.board.roads for edge in range(72))

    # unsqueeze adds the leading batch dimension expected by the network.
    return torch.tensor(features, dtype=torch.float32).unsqueeze(0)

In [None]:
def select_action(state, action_ints, policy_net, eps):
    """Choose an action epsilon-greedily among the currently valid actions.

    Args:
        state: feature tensor with a leading batch dimension of 1.
        action_ints: sequence of valid action indices to choose from.
        policy_net: network mapping ``state`` to one Q-value per action index.
        eps: exploration probability; a uniformly random valid action is
            returned when ``random.random() <= eps``.

    Returns:
        One element of ``action_ints``.
    """
    if random.random() > eps:
        # Run the forward pass on the network's own device: the state tensor
        # may live on the CPU while the network has been moved to the GPU.
        first_param = next(policy_net.parameters(), None)
        if first_param is not None:
            state = state.to(first_param.device)
        with torch.no_grad():
            # Only compare Q-values of valid actions.
            Q = policy_net(state)[0, action_ints]
        return action_ints[Q.max(0).indices.item()]
    return random.choice(action_ints)

In [39]:
# Hyperparameters read by the training cells below.

# Discount factor applied to the bootstrapped next-state value in train_batch.
GAMMA = 0.9
# Learning rate passed to the Adam optimizer.
LR = 0.01
# Number of transitions sampled from replay memory per gradient step;
# train_batch does nothing until the memory holds at least this many.
BATCH_SIZE = 128
# Width of both hidden layers of the DQN.
HIDDEN_SIZE = 128
# Interpolation weight of the soft target-network update in train().
TAU = .005

POLICY_UPDATE = 1 # Number of environment steps between calls to train_batch
TARGET_UPDATE = 1 # Number of episodes between soft updates of the target network
# Number of episodes run by train(); reassigned in a later cell before training.
NUM_EPISODES = 1000

In [40]:
# Create the Catan environment and reset it once so a game state exists.
env = gym.make("catanatron_gym:catanatron-v0")
state, info = env.reset()

# Network output width: one Q-value per entry of the discrete action space.
num_actions = env.action_space.n
# Network input width: taken from the custom encoding in state_tensor, not
# from the observation returned by env.reset().
dim_state = state_tensor(env.game.state).shape[1]

# Replay buffer holding at most 10000 transitions.
memory = ReplayMemory(10000)
# Use the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# EPS_START = 0.9
# EPS_END = 0.05
# EPS_DECAY = 200
# TARGET_UPDATE = 10

  logger.warn(


In [41]:
# Two identically shaped networks: policy_net is trained, target_net supplies
# the bootstrapped targets in train_batch.
policy_net = DQN(dim_state, num_actions, HIDDEN_SIZE).to(device)
target_net = DQN(dim_state, num_actions, HIDDEN_SIZE).to(device)
# Start the target network from the same weights as the policy network.
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

# Only the policy network's parameters are optimized.
optimizer = optim.Adam(policy_net.parameters(), LR, amsgrad=True)
# NOTE(review): train_batch calls F.mse_loss directly; loss_fn is not
# referenced by any cell shown in this notebook.
loss_fn = nn.MSELoss()

In [45]:
def train_batch():
    """Run one gradient step on ``policy_net`` from a random replay batch.

    Reads the module-level ``memory``, ``policy_net``, ``target_net``,
    ``optimizer``, ``device``, ``BATCH_SIZE`` and ``GAMMA``. Returns without
    doing anything while the memory holds fewer than ``BATCH_SIZE``
    transitions. Prints the loss of the step.
    """
    if len(memory) < BATCH_SIZE:
        return

    transitions = memory.sample(BATCH_SIZE)
    # Transpose a list of Transitions into one Transition of tuples.
    batch = Transition(*zip(*transitions))

    # Build every tensor on the networks' device with an explicit dtype, so
    # the step works on GPU and does not depend on dtype inference (e.g. a
    # batch whose rewards all happen to be Python ints).
    states = torch.cat(batch.state).to(device)
    actions = torch.tensor(batch.action, dtype=torch.long, device=device)
    rewards = torch.tensor(batch.reward, dtype=torch.float32, device=device)
    dones = torch.tensor(batch.done, dtype=torch.bool, device=device)

    # Q(s, a) of the action that was actually taken.
    cur_Q = policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)

    # max_a' Q_target(s', a') for non-terminal transitions, 0 for terminal.
    next_Q = torch.zeros(BATCH_SIZE, device=device)
    non_terminal = [s for s in batch.next_state if s is not None]
    if non_terminal:
        # torch.cat raises on an empty list, so skip when every sampled
        # transition is terminal.
        next_states = torch.cat(non_terminal).to(device)
        with torch.no_grad():
            next_Q[~dones] = target_net(next_states).max(1).values
    target_Q = rewards + GAMMA * next_Q

    # Update the policy network
    loss = F.mse_loss(cur_Q, target_Q)
    print("Loss:", loss.item())

    optimizer.zero_grad()
    loss.backward()
    # Clamp each gradient element to [-100, 100] before the optimizer step.
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()

In [49]:
def train_episode(eps=0.5):
    """Play one game in ``env`` epsilon-greedily, training after each step.

    Every step is stored in the module-level replay ``memory``, and
    ``train_batch`` is called every ``POLICY_UPDATE`` steps. Prints the number
    of steps taken when the episode ends.

    Args:
        eps: exploration probability forwarded to ``select_action``.
    """
    env.reset()
    state = state_tensor(env.game.state)
    done = False
    i = 0

    while not done:
        action_int = select_action(state, env.get_valid_actions(), policy_net, eps)
        # Resolve the action object before stepping; reward_f needs the type
        # of the action that was taken.
        action_struct = from_action_space(action_int, env.game.state.playable_actions)

        step_result = env.step(action_int)
        if len(step_result) == 5:
            # Gymnasium API: (obs, reward, terminated, truncated, info).
            # Either flag ends the episode.
            _, _, terminated, truncated, _ = step_result
            done = bool(terminated or truncated)
        else:
            # Legacy Gym API: (obs, reward, done, info).
            _, _, done, _ = step_result
            done = bool(done)

        # The environment's own reward is ignored in favour of reward_f.
        reward = reward_f(env, action_struct)
        # Terminal transitions carry no next state; train_batch relies on this.
        next_state = state_tensor(env.game.state) if not done else None

        memory.push(state, action_int, next_state, reward, done)
        state = next_state

        i += 1
        if i % POLICY_UPDATE == 0:
            train_batch()
    print(i)

        

In [50]:
# Run a single training episode as a quick check of the loop.
# NOTE(review): the saved output below contains "Reward:" lines that no print
# in the current cells produces, so it presumably predates an edit — re-run.
train_episode()

Reward: 0.2
Loss: 2.5498762130737305
Reward: 0.1
Loss: 2.569211006164551
Reward: 0.2
Loss: 1.9969182014465332
Reward: 0.1
Loss: 2.1411640644073486
Reward: 0
Loss: 1.876656413078308
Reward: 0
Loss: 7815.42626953125
Reward: 0
Loss: 0.8467344045639038
Reward: 0
Loss: 15590.744140625
Reward: 0
Loss: 7788.66845703125
Reward: 0
Loss: 0.7628897428512573
Reward: 0
Loss: 0.6623737215995789
Reward: 0
Loss: 0.4462103843688965
Reward: 0
Loss: 0.46230024099349976
Reward: 0
Loss: 0.40422940254211426
Reward: 0
Loss: 0.3573697805404663
Reward: 0
Loss: 7813.20947265625
Reward: 0
Loss: 7764.36572265625
Reward: 0
Loss: 0.23999238014221191
Reward: 0
Loss: 7743.00146484375
Reward: 0
Loss: 0.26404422521591187
Reward: 0
Loss: 7821.513671875
Reward: 0
Loss: 0.24358396232128143
Reward: 0
Loss: 0.5576533079147339
Reward: 0
Loss: 0.2691742181777954
Reward: 0
Loss: 7638.966796875
Reward: 0
Loss: 0.3722044825553894
Reward: 0
Loss: 0.2417464405298233
Reward: 0
Loss: 0.34350359439849854
Reward: 0
Loss: 7526.20214843

In [51]:
def train():
    """Run ``NUM_EPISODES`` training episodes with periodic soft target updates.

    After every ``TARGET_UPDATE``-th episode (starting with the first), each
    entry of the target network's state dict is moved towards the policy
    network's: ``target = TAU * policy + (1 - TAU) * target``.
    """
    for episode in tqdm(range(NUM_EPISODES)):
        train_episode()
        if episode % TARGET_UPDATE != 0:
            continue

        policy_weights = policy_net.state_dict()
        blended = target_net.state_dict()
        for name, policy_value in policy_weights.items():
            blended[name] = policy_value*TAU + blended[name]*(1-TAU)
        target_net.load_state_dict(blended)
        target_net.eval()

In [71]:
# Override the episode count from the hyperparameter cell; train() reads the
# global at call time.
NUM_EPISODES = 10
train()

In [56]:
# def select_action(state, eps):
#     if random.random() < eps:
#         return torch.tensor([[random.randrange(num_actions)]], device=device, dtype=torch.long)
#     else:
#         with torch.no_grad():
#             return policy_net(state).max(1)[1].view(1, 1)

In [62]:
class DQNPlayer(Player):
    """Player that always plays the valid action with the highest Q-value.

    Args:
        model: network mapping a ``state_tensor`` encoding to one Q-value per
            action index; it is switched to eval mode on construction.
        color: must be ``Color.BLUE``.
    """

    def __init__(self, model, color):
        super().__init__(color)
        assert color == Color.BLUE, "DQNPlayer only supports blue"
        self.model = model
        self.model.eval()

    def decide(self, game, valid_actions):
        """Return the element of ``valid_actions`` the model rates highest."""
        state = state_tensor(game.state)
        # state_tensor builds its tensor without a device argument; move it to
        # the model's device so a GPU-resident model can be evaluated.
        first_param = next(self.model.parameters(), None)
        if first_param is not None:
            state = state.to(first_param.device)

        action_ints = list(map(to_action_space, valid_actions))
        with torch.no_grad():
            # Only compare Q-values of actions that are currently valid.
            Q = self.model(state)[0, action_ints]
        best_action = action_ints[Q.max(0).indices.item()]
        return from_action_space(best_action, valid_actions)

In [68]:
# Evaluation: the trained policy (BLUE) against a random player (ORANGE).
players = [
    RandomPlayer(Color.ORANGE),
    DQNPlayer(policy_net, Color.BLUE),
]

# Play 100 games, each a fresh Game built from the same two player objects,
# and record what game.play() returns for each.
winners = []
for i in tqdm(range(100)):
    game = Game(players)
    winners.append(game.play())

  0%|          | 0/100 [00:00<?, ?it/s]

100%|██████████| 100/100 [00:59<00:00,  1.67it/s]


In [69]:
# Fraction of the recorded results that equal Color.BLUE (the DQN player).
winners.count(Color.BLUE) / len(winners)

0.46

In [70]:
# Persist only the policy network's weights (its state_dict) to disk.
torch.save(policy_net.state_dict(), "policy_net.pt")