In [1]:
from catanatron import Game, Player, RandomPlayer, Color, ActionType
from catanatron_gym.envs.catanatron_env import from_action_space, to_action_space
from catanatron_gym.features import create_sample_vector
import gymnasium as gym

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import numpy as np
import random

from collections import namedtuple, deque
from tqdm import tqdm

In [2]:
class DQN(nn.Module):
    def __init__(self, dim_state, num_actions, hidden_size):
        super(DQN, self).__init__()
        self.lin1 = nn.Linear(dim_state, hidden_size)
        self.lin2 = nn.Linear(hidden_size, hidden_size)
        self.lin3 = nn.Linear(hidden_size, num_actions)

    def forward(self, x):
        x = F.relu(self.lin1(x))
        x = F.relu(self.lin2(x))
        return self.lin3(x)

In [3]:
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward', 'done'))

class ReplayMemory(object):
    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

In [4]:
def reward_f(env, action):
    action_type = action.action_type
    reward = 0
    if action_type == ActionType.BUILD_ROAD:
       reward = 1
    elif action_type == ActionType.BUILD_SETTLEMENT:
       reward = 2
    elif action_type == ActionType.BUILD_CITY:
       reward = 3
    if env.game.winning_color() == Color.WHITE:
        reward = 1000
    elif env.game.winning_color() is not None:
        reward = -1000
    return reward

def state_tensor(state):
    res = {k: state.player_state[k] for k in ('P0_ACTUAL_VICTORY_POINTS', 'P0_VICTORY_POINTS', 'P0_WOOD_IN_HAND', 'P0_BRICK_IN_HAND', 'P0_ORE_IN_HAND', 'P0_SHEEP_IN_HAND', 'P0_WHEAT_IN_HAND')}
    for i in range(54):
        building_feature = "LOC_" + str(i) + "_HAS_BUILDING"
        if i in state.board.buildings:
            res[building_feature] = True
        else:
            res[building_feature] = False
    for i in range(72):
        road_feature = "LOC_" + str(i) + "_HAS_ROAD"
        if i in state.board.roads:
            res[road_feature] = True
        else:
            res[road_feature] = False

    return torch.tensor(list(res.values()), dtype=torch.float32).unsqueeze(0)

In [5]:
GAMMA = 0.9
LR = 0.05
BATCH_SIZE = 128
HIDDEN_SIZE = 128

POLICY_UPDATE = 1 # Number of actions before retraining
TARGET_UPDATE = 1 # Number of episodes before updating target network
NUM_EPISODES = 1000

In [6]:
env = gym.make("catanatron_gym:catanatron-v0")
state, info = env.reset()

num_actions = env.action_space.n
dim_state = state_tensor(env.game.state).shape[1]

memory = ReplayMemory(10000)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# EPS_START = 0.9
# EPS_END = 0.05
# EPS_DECAY = 200
# TARGET_UPDATE = 10

  logger.warn(


In [7]:
policy_net = DQN(dim_state, num_actions, HIDDEN_SIZE).to(device)
target_net = DQN(dim_state, num_actions, HIDDEN_SIZE).to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

optimizer = optim.Adam(policy_net.parameters(), LR)
loss_fn = nn.MSELoss()

In [8]:
def train_batch():
    if len(memory) < BATCH_SIZE:
        return

    transitions = memory.sample(BATCH_SIZE)
    batch = Transition(*zip(*transitions))

    states = torch.cat(batch.state)
    actions = torch.tensor(batch.action)
    next_states = torch.cat([state for state in batch.next_state if state is not None])
    rewards = torch.tensor(batch.reward)
    dones = torch.tensor(batch.done)

    cur_Q = policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)

    target_Q = torch.zeros(BATCH_SIZE, device=device)
    with torch.no_grad():
        target_Q[~dones] = target_net(next_states).max(1).values
    target_Q = rewards + GAMMA * target_Q

    # Update the policy network
    loss = F.mse_loss(cur_Q, target_Q)
    print(loss.item())
    optimizer.zero_grad()
    ### TODO: Clip Gradients #######################
    loss.backward()
    optimizer.step()

In [9]:
def train_episode():
    env.reset()
    state = state_tensor(env.game.state)
    done = False
    i = 0

    while not done:
        action_int = random.choice(env.get_valid_actions())
        action_struct = from_action_space(action_int, env.game.state.playable_actions)
        print(action_struct)

        _, _, done, _ = env.step(action_int)
        reward = reward_f(env, action_struct)
        next_state = state_tensor(env.game.state) if not done else None

        memory.push(state, action_int, next_state, reward, done)
        state = next_state

        i += 1
        if i % POLICY_UPDATE == 0:
            train_batch()

        

In [10]:
def train():
    for i in tqdm(range(NUM_EPISODES)):
        train_episode()
        if i % TARGET_UPDATE == 0:
            target_net.load_state_dict(policy_net.state_dict())
            target_net.eval()

In [11]:
NUM_EPISODES = 1
train()

  logger.deprecation(


Action(color=<Color.BLUE: 'BLUE'>, action_type=<ActionType.BUILD_SETTLEMENT: 'BUILD_SETTLEMENT'>, value=25)
Action(color=<Color.BLUE: 'BLUE'>, action_type=<ActionType.BUILD_ROAD: 'BUILD_ROAD'>, value=(24, 25))
Action(color=<Color.BLUE: 'BLUE'>, action_type=<ActionType.BUILD_SETTLEMENT: 'BUILD_SETTLEMENT'>, value=37)
Action(color=<Color.BLUE: 'BLUE'>, action_type=<ActionType.BUILD_ROAD: 'BUILD_ROAD'>, value=(36, 37))
Action(color=<Color.BLUE: 'BLUE'>, action_type=<ActionType.ROLL: 'ROLL'>, value=None)
Action(color=<Color.BLUE: 'BLUE'>, action_type=<ActionType.END_TURN: 'END_TURN'>, value=None)
Action(color=<Color.BLUE: 'BLUE'>, action_type=<ActionType.ROLL: 'ROLL'>, value=None)
Action(color=<Color.BLUE: 'BLUE'>, action_type=<ActionType.MARITIME_TRADE: 'MARITIME_TRADE'>, value=('ORE', 'ORE', 'ORE', None, 'BRICK'))
Action(color=<Color.BLUE: 'BLUE'>, action_type=<ActionType.END_TURN: 'END_TURN'>, value=None)
Action(color=<Color.BLUE: 'BLUE'>, action_type=<ActionType.ROLL: 'ROLL'>, value=No

100%|██████████| 1/1 [00:04<00:00,  4.44s/it]

0.004191169980913401
Action(color=<Color.BLUE: 'BLUE'>, action_type=<ActionType.END_TURN: 'END_TURN'>, value=None)
0.0006655871402472258
Action(color=<Color.BLUE: 'BLUE'>, action_type=<ActionType.ROLL: 'ROLL'>, value=None)
0.0007120549562387168
Action(color=<Color.BLUE: 'BLUE'>, action_type=<ActionType.END_TURN: 'END_TURN'>, value=None)
0.000513885694090277
Action(color=<Color.BLUE: 'BLUE'>, action_type=<ActionType.ROLL: 'ROLL'>, value=None)
0.001959867076948285
Action(color=<Color.BLUE: 'BLUE'>, action_type=<ActionType.MOVE_ROBBER: 'MOVE_ROBBER'>, value=((-2, 1, 1), None, None))
0.0014524100115522742
Action(color=<Color.BLUE: 'BLUE'>, action_type=<ActionType.END_TURN: 'END_TURN'>, value=None)
0.0010053979931399226
Action(color=<Color.BLUE: 'BLUE'>, action_type=<ActionType.ROLL: 'ROLL'>, value=None)
0.0012064989423379302
Action(color=<Color.BLUE: 'BLUE'>, action_type=<ActionType.MARITIME_TRADE: 'MARITIME_TRADE'>, value=('WHEAT', 'WHEAT', 'WHEAT', None, 'BRICK'))
0.0025874320417642593
A




In [56]:
# def select_action(state, eps):
#     if random.random() < eps:
#         return torch.tensor([[random.randrange(num_actions)]], device=device, dtype=torch.long)
#     else:
#         with torch.no_grad():
#             return policy_net(state).max(1)[1].view(1, 1)

In [12]:
class DQNPlayer(Player):
    def __init__(self, model, color):
        super().__init__(color)
        assert color == Color.WHITE, "DQNPlayer only supports white"
        self.model = model
        self.model.eval()

    def decide(self, game, valid_actions):
        state = state_tensor(game.state)
        action_ints = list(map(to_action_space, valid_actions))
        with torch.no_grad():
            Q = self.model(state)[0, action_ints]
        best_action = action_ints[Q.max(0).indices.item()]
        return from_action_space(best_action, valid_actions)

In [13]:
players = [
    RandomPlayer(Color.ORANGE),
    DQNPlayer(policy_net, Color.WHITE),
]

winners = []
for i in tqdm(range(10)):
    game = Game(players)
    winners.append(game.play())

  0%|          | 0/10 [00:00<?, ?it/s]

100%|██████████| 10/10 [00:06<00:00,  1.45it/s]


In [15]:
winners.count(Color.WHITE) / len(winners)

0.0