In [None]:
import base64
import bz2
import pickle

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Checkpoints to embed into the submission. Every key must have a matching
# placeholder entry in the PARAM dict that the %%writefile cell writes out.
model_path = {
    "a": "../input/hungry-geese-models/geese_net_fold1_best.pth",
    "b": "../input/hungry-geese-models/geese_net_fold2_best.pth",
    "c": "../input/hungry-geese-models/geese_net_fold4_best.pth",
    # "d": "../input/hungry-geese-models/geese_net_fold3_best.pth",
    # "e": "../input/hungry-geese-models/geese_net_fold4_best.pth",
    # "f": "../input/hungry-geese-models/geese_net_fold5_best.pth",
}

# Load each state_dict onto the CPU and pack it as pickle -> bz2 -> base64 bytes;
# submission.py undoes exactly these three steps when it starts up.
PARAM = {
    key: base64.b64encode(bz2.compress(pickle.dumps(torch.load(path, map_location=torch.device("cpu")))))
    for key, path in model_path.items()
}

In [None]:
%%writefile submission.py

# Placeholders for the encoded model weights. A later notebook cell rewrites each
# 10-character run of the key's letter below with that model's base64-encoded
# checkpoint, so these tokens must be kept exactly as written. Commented-out
# entries are not loaded.
PARAM = {
    "a": b"aaaaaaaaaa",
    "b": b"bbbbbbbbbb",
    "c": b"cccccccccc",
    # "d": b"dddddddddd",
    # "e": b"eeeeeeeeee",
    # "f": b"ffffffffff",
}


# This is a lightweight ML agent trained by self-play.
# After sharing this notebook,
# we will add Hungry Geese environment in our HandyRL library.
# https://github.com/DeNA/HandyRL
# We hope you enjoy reinforcement learning!


import base64
import bz2
import math
import pickle
import random
import time
from collections import defaultdict, deque
from copy import deepcopy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from kaggle_environments.envs.hungry_geese.hungry_geese import Action, translate
from kaggle_environments.helpers import histogram

# MCTS


class MCTS:
    """Simultaneous-move tree search over Hungry Geese states.

    Nodes are keyed by ``game.stringRepresentation(obs)``; statistics are stored
    per (state, player) and per (state, player, action index). In the active
    configuration each simulation follows, for every alive goose, the valid
    action with the highest network prior; the PUCT selection rules are kept
    in ``search`` only as disabled string blocks.
    """

    def __init__(self, game, nn_agent, eps=1e-8, cpuct=1.0, pb_c_base=19652, pb_c_init=1.25):
        self.game = game
        self.nn_agent = nn_agent
        # eps / cpuct / pb_c_base / pb_c_init are only referenced by the disabled PUCT variants in search().
        self.eps = eps
        self.cpuct = cpuct
        self.pb_c_base = pb_c_base
        self.pb_c_init = pb_c_init

        self.Qsa = {}  # (s, i, a) -> value for player i after taking action a in state s (averaged over visits)
        self.Nsa = {}  # (s, i, a) -> number of visits for player i taking action a in state s
        self.Ns = {}  # s -> number of visits to state s
        self.Ps = {}  # (s, i) -> player i's action scores in state s (policy-network output, masked by valid moves)

        self.Es = {}  # s -> per-player outcomes if the game has ended in state s, otherwise None
        self.Vs = {}  # (s, i) -> valid-move mask for player i in state s

        # Observation seen on the previous getActionProb call; passed to search() as the parent observation.
        self.last_obs = None

    def getActionProb(self, obs, timelimit=1.0):
        """Search from ``obs`` for about ``timelimit`` seconds and pick player ``obs.index``'s move.

        Args:
            obs: current observation (needs ``index``, ``step``, ``geese``, ``food``).
            timelimit: wall-clock search budget in seconds.

        Returns:
            The index (a numpy integer from ``np.argmax``) of the most-visited
            root action, used by the caller as an index into ``game.actions``.

        Side effects: prints a debug line and stores ``obs`` as ``self.last_obs``.
        """
        start_time = time.time()
        while time.time() - start_time < timelimit:
            self.search(obs, self.last_obs)

        s = self.game.stringRepresentation(obs)
        i = obs.index
        counts = [self.Nsa[(s, i, a)] if (s, i, a) in self.Nsa else 0 for a in range(self.game.getActionSize())]

        # If two or more directions share the highest visit count, run one more search to break the tie.
        if len([v for v in counts if v == max(counts)]) > 1:
            self.search(obs, self.last_obs)

        counts = [self.Nsa[(s, i, a)] if (s, i, a) in self.Nsa else 0 for a in range(self.game.getActionSize())]
        # NOTE(review): if no visit was ever recorded for player i, np.sum(counts) is 0, prob is NaN
        # and the Qsa lookup in the print below raises KeyError.
        prob = counts / np.sum(counts)
        a = np.argmax(prob)

        print(f"step: {obs['step']}, player: {i}, value: {self.Qsa[(s, i, a)]:.3}, count: {counts} / {np.sum(counts)}")

        self.last_obs = obs
        return a

    def search(self, obs, last_obs):
        """Run one simulation from ``obs`` and return a list of four per-player values.

        Terminology:
            leaf node: a node that has not been simulated from yet.

        - Terminal state: returns the outcomes from ``game.getGameEnded``.
        - Leaf (first visit): returns the network value of each alive goose
          (-10 for dead geese) without expanding further.
        - Otherwise: every alive goose takes its highest-prior valid action, the
          (randomly food-spawned) next state is searched recursively, and the
          returned values are averaged into Qsa/Nsa for the chosen actions.
        """
        s = self.game.stringRepresentation(obs)

        if s not in self.Es:
            self.Es[s] = self.game.getGameEnded(obs, last_obs)
        if self.Es[s] is not None:
            return self.Es[s]

        # Run inference on every visit (not only the first) so the random test-time
        # augmentation / model choice inside nn_agent.predict takes effect.
        values = [-10] * 4
        for i in range(4):
            if len(obs.geese[i]) == 0:
                continue

            # Evaluate the position with the neural network.
            # NOTE(review): the policy output is used as-is; no softmax is visible here.
            self.Ps[(s, i)], values[i] = self.nn_agent.predict(obs, last_obs, i)

            if (s, i) not in self.Vs:
                self.Vs[(s, i)] = self.game.getValidMoves(obs, last_obs, i)
            self.Ps[(s, i)] = self.Ps[(s, i)] * self.Vs[(s, i)]  # masking invalid moves
            sum_Ps_s = np.sum(self.Ps[(s, i)])
            if sum_Ps_s > 0:
                self.Ps[(s, i)] /= sum_Ps_s  # renormalize

        # If the current state is a leaf node
        if s not in self.Ns:
            self.Ns[s] = 0

            # return each player's value estimate for the current state
            return values

        best_acts = [None] * 4
        for i in range(4):
            if len(obs.geese[i]) == 0:
                continue

            valids = self.Vs[(s, i)]
            cur_best = -float("inf")
            # Fallback when no move is valid: the last action in game.actions.
            best_act = self.game.actions[-1]

            # pick the action with the highest upper confidence bound
            # (decide player i's best action in the current state s)
            for a in range(self.game.getActionSize()):
                if valids[a]:

                    # PUCT (AlphaGo) -- disabled: kept as an unused string literal.
                    """
                    if (s, i, a) in self.Qsa:
                        u = self.Qsa[(s, i, a)] + self.cpuct * self.Ps[(s, i)][a] * math.sqrt(self.Ns[s]) / (
                            1 + self.Nsa[(s, i, a)]
                        )
                    else:
                        u = self.cpuct * self.Ps[(s, i)][a] * math.sqrt(self.Ns[s] + self.eps)
                    """

                    # PUCT (AlphaZero) -- disabled: kept as an unused string literal.
                    """
                    cs = math.log((1 + self.Ns[s] + self.pb_c_base) / self.pb_c_base) + self.pb_c_init

                    if (s, i, a) in self.Qsa:
                        u = self.Qsa[(s, i, a)] + cs * self.Ps[(s, i)][a] * math.sqrt(self.Ns[s]) / (
                            1 + self.Nsa[(s, i, a)]
                        )
                    else:
                        u = cs * self.Ps[(s, i)][a] * math.sqrt(self.Ns[s] + self.eps)
                    """

                    # Use only policy (the active rule: greedy on the network prior)
                    u = self.Ps[(s, i)][a]

                    if u > cur_best:
                        cur_best = u
                        best_act = self.game.actions[a]

            best_acts[i] = best_act

        # Build the state reached after every player takes its chosen action.
        next_obs = self.game.getNextState(obs, last_obs, best_acts)

        # Search the generated next state.
        values = self.search(next_obs, obs)

        # Back up: running average of the returned value for each player's chosen action.
        for i in range(4):
            if len(obs.geese[i]) == 0:
                continue

            a = self.game.actions.index(best_acts[i])
            v = values[i]

            if (s, i, a) in self.Qsa:
                self.Qsa[(s, i, a)] = (self.Nsa[(s, i, a)] * self.Qsa[(s, i, a)] + v) / (self.Nsa[(s, i, a)] + 1)
                self.Nsa[(s, i, a)] += 1

            else:
                self.Qsa[(s, i, a)] = v
                self.Nsa[(s, i, a)] = 1

        self.Ns[s] += 1
        return values


class HungryGeese(object):
    """Minimal forward model of the Hungry Geese rules used by the tree search.

    Board cells are integer indices in ``[0, rows * columns)`` on a torus.
    Observations must expose ``geese`` (four position lists, head first),
    ``food`` and ``step``.
    """

    def __init__(self, rows=7, columns=11, actions=None, hunger_rate=40):
        # None instead of a list literal so instances never share one mutable default.
        if actions is None:
            actions = [Action.NORTH, Action.SOUTH, Action.WEST, Action.EAST]
        self.rows = rows
        self.columns = columns
        self.actions = actions
        self.hunger_rate = hunger_rate

    def getActionSize(self):
        """Number of available actions."""
        return len(self.actions)

    def getNextState(self, obs, last_obs, directions):
        """Return a new observation after every alive goose moves in ``directions[i]``.

        ``obs`` is not modified (it is deep-copied). Eaten food is replaced at
        random free cells, so the result is stochastic.
        """
        next_obs = deepcopy(obs)
        next_obs.step += 1
        geese = next_obs.geese
        food = next_obs.food
        new_food = 0

        for i in range(4):
            goose = geese[i]

            if len(goose) == 0:
                continue

            head = translate(goose[0], directions[i], self.columns, self.rows)

            # Moving back onto the previous head position (a reversal) kills the goose.
            if last_obs is not None and head == last_obs.geese[i][0]:
                geese[i] = []
                continue

            # Consume food or drop a tail piece.
            if head in food:
                food.remove(head)
                new_food += 1
            else:
                goose.pop()

            # Add New Head to the Goose.
            goose.insert(0, head)

            # If hunger strikes remove from the tail.
            if next_obs.step % self.hunger_rate == 0:
                if len(goose) > 0:
                    goose.pop()

            geese[i] = goose

        goose_positions = histogram(position for goose in geese for position in goose)

        # Check for collisions (with any goose body, including its own).
        for i in range(4):
            if len(geese[i]) > 0:
                head = geese[i][0]
                if goose_positions[head] > 1:
                    geese[i] = []

        if new_food > 0:
            collisions = {position for goose in geese for position in goose}
            available_positions = set(range(self.rows * self.columns)).difference(collisions).difference(food)
            # Ensure we don't sample more food than available positions.
            needed_food = min(new_food, len(available_positions))
            # random.sample requires a sequence (sets are rejected since Python 3.11);
            # sorting also makes the draw reproducible under a fixed random seed.
            food.extend(random.sample(sorted(available_positions), needed_food))

        next_obs.geese = geese
        next_obs.food = food

        return next_obs

    def getValidMoves(self, obs, last_obs, index):
        """Return one bool per action: True if goose ``index``'s head may move there.

        Obstacles are every goose cell except each goose's tail, plus this goose's
        previous head position (from ``last_obs``) to forbid reversing.
        """
        geese = obs.geese
        pos = geese[index][0]
        obstacles = {position for goose in geese for position in goose[:-1]}
        if last_obs is not None:
            obstacles.add(last_obs.geese[index][0])

        valid_moves = [translate(pos, action, self.columns, self.rows) not in obstacles for action in self.actions]

        return valid_moves

    def getGameEnded(self, obs, last_obs):
        """Return None if the game is not over, else a list of four per-player outcomes.

        The game is over when at most one goose is alive or ``obs.step >= 199``.
        Alive geese score ``length + 100``; geese that died on the last transition
        score their length in ``last_obs`` (when available). Each player gets +1
        for every opponent it beats and -2 for every opponent that beats it.
        """
        active_geese = len([goose for goose in obs.geese if len(goose) > 0])
        if active_geese > 1 and obs.step < 199:
            return None

        rewards = [0.0] * 4
        for p, geese in enumerate(obs.geese):
            if len(geese) > 0:
                rewards[p] = len(geese) + 100
        # Without a previous observation, dead geese simply keep a reward of 0.
        if last_obs is not None:
            for p, geese in enumerate(last_obs.geese):
                if len(geese) > 0 and rewards[p] == 0:
                    rewards[p] = len(geese)

        outcomes = [0.0] * 4
        for p, r in enumerate(rewards):
            for pp, rr in enumerate(rewards):
                if p != pp:
                    if r > rr:
                        outcomes[p] += 1.0
                    elif r < rr:
                        outcomes[p] -= 2.0

        return outcomes

    def stringRepresentation(self, obs):
        """Hashable state key built from geese and food only (step and history are ignored)."""
        return str(obs.geese + obs.food)


# Neural Network for Hungry Geese


class TorusConv2d(nn.Module):
    """2D convolution with wrap-around (toroidal) padding, optionally followed by BatchNorm.

    The input is padded circularly by ``kernel_size // 2`` on both sides of the
    height and width axes before an unpadded ``nn.Conv2d``, so odd kernel sizes
    keep the spatial size unchanged.

    Args:
        input_dim: number of input channels.
        output_dim: number of output channels.
        kernel_size: (height, width) of the kernel.
        bn: if truthy, apply ``nn.BatchNorm2d`` after the convolution.
    """

    def __init__(self, input_dim, output_dim, kernel_size, bn):
        super().__init__()
        self.edge_size = (kernel_size[0] // 2, kernel_size[1] // 2)
        self.conv = nn.Conv2d(input_dim, output_dim, kernel_size=kernel_size)
        self.bn = nn.BatchNorm2d(output_dim) if bn else None

    def forward(self, x):
        edge_h, edge_w = self.edge_size
        h = x
        # Skip axes that need no padding: ``x[..., -0:]`` would select the whole
        # axis instead of nothing and double that dimension.
        if edge_w > 0:
            h = torch.cat([h[:, :, :, -edge_w:], h, h[:, :, :, :edge_w]], dim=3)
        if edge_h > 0:
            h = torch.cat([h[:, :, -edge_h:], h, h[:, :, :edge_h]], dim=2)
        h = self.conv(h)
        if self.bn is not None:
            h = self.bn(h)
        return h


class GeeseNetAlpha(nn.Module):
    """Policy/value network for Hungry Geese.

    Input ``x`` has 26 channels: the first 25 are spatial feature planes (the
    first four act as head-position masks), and the last channel carries
    integer-coded scalar features in its first 8 cells (step, hunger, three
    length differences, three head distances).

    Returns ``(p, v)``: raw policy scores of shape (batch, 4) and a
    tanh-squashed value of shape (batch, 1).
    """

    def __init__(self):
        super().__init__()

        n_layers, n_filters, n_hidden = 12, 48, 270

        self.embed_step = nn.Embedding(5, 3)
        self.embed_hunger = nn.Embedding(5, 3)
        self.embed_diff_len = nn.Embedding(7, 4)
        self.embed_diff_head = nn.Embedding(9, 4)

        self.conv0 = TorusConv2d(25, n_filters, (3, 3), True)
        self.blocks = nn.ModuleList(
            TorusConv2d(n_filters, n_filters, (3, 3), True) for _ in range(n_layers)
        )

        self.conv1 = TorusConv2d(n_filters, n_filters, (5, 5), True)

        # n_hidden = 5 pooled CNN vectors (4 heads + mean) * n_filters + 3 + 3 + 3*4 + 3*4 embedding dims
        self.head_p1 = nn.Linear(n_hidden, n_hidden // 2, bias=False)
        self.head_p2 = nn.Linear(n_hidden // 2, 4, bias=False)
        self.head_v1 = nn.Linear(n_hidden, n_hidden // 2, bias=False)
        self.head_v2 = nn.Linear(n_hidden // 2, 1, bias=False)

    def forward(self, x, _=None):
        batch = x.size(0)
        codes = x[:, -1].view(batch, -1).long()

        # Embeddings for the integer-coded scalar features.
        embedded = [
            self.embed_step(codes[:, 0]),
            self.embed_hunger(codes[:, 1]),
            self.embed_diff_len(codes[:, 2:5]).view(batch, -1),
            self.embed_diff_head(codes[:, 5:8]).view(batch, -1),
        ]

        planes = x[:, :-1].float()

        # Residual torus CNN over the board planes.
        h = F.relu_(self.conv0(planes))
        for block in self.blocks:
            h = F.relu_(h + block(h))
        h = F.relu_(self.conv1(h))

        # Pool CNN features at each goose's head (planes 0-3) plus a global mean.
        channels = h.size(1)
        pooled = [(h * planes[:, k : k + 1]).view(batch, channels, -1).sum(-1) for k in range(4)]
        pooled.append(h.view(batch, channels, -1).mean(-1))

        merged = torch.cat(pooled + embedded, 1)

        p = self.head_p2(F.relu_(self.head_p1(merged)))
        v = torch.tanh(self.head_v2(F.relu_(self.head_v1(merged))))

        return p, v  # {"policy": p, "value": v}


def identity(image):
    """No-op augmentation: a copy of ``image`` and the unchanged action order."""
    action_order = list(range(4))
    return image.copy(), action_order


def h_flip(image):
    """Mirror a (C, H, W) array left-right; policy indices 2 and 3 swap places."""
    mirrored = np.flip(image, axis=2).copy()
    return mirrored, [0, 1, 3, 2]


def v_flip(image):
    """Mirror a (C, H, W) array top-bottom; policy indices 0 and 1 swap places."""
    mirrored = np.flip(image, axis=1).copy()
    return mirrored, [1, 0, 2, 3]


def hv_flip(image):
    """Mirror a (C, H, W) array on both board axes; both policy index pairs swap."""
    mirrored = np.flip(image, axis=(1, 2)).copy()
    return mirrored, [1, 0, 3, 2]


class NNAgent:
    """Ensemble of GeeseNetAlpha models plus the feature encoding they expect.

    Each prediction picks one model and one flip augmentation at random, so
    repeated calls on the same state can return different outputs.
    """

    # For every cell of the 7x11 torus: the set of its four orthogonal neighbours.
    # A comprehension keeps loop variables out of the class namespace.
    next_position_map = {
        pos: {
            (pos + 11) % 77,  # next row (wraps)
            (pos - 11) % 77,  # previous row (wraps)
            pos - pos % 11 + (pos % 11 + 1) % 11,  # next column, wrapping within the row
            pos - pos % 11 + (pos % 11 - 1) % 11,  # previous column, wrapping within the row
        }
        for pos in range(77)
    }

    def __init__(self, state_dicts):
        """Build one GeeseNetAlpha per entry of ``state_dicts`` (key -> state_dict), in eval mode."""
        self.models = {}
        for key, state in state_dicts.items():
            self.models[key] = GeeseNetAlpha()
            self.models[key].load_state_dict(state)
            self.models[key].eval()

    def predict(self, obs, last_obs, index):
        """Return ``(policy_scores, value)`` for goose ``index`` under a random flip augmentation."""
        x, info = self._make_input(obs, last_obs, index)

        transform = random.choice([identity, h_flip, v_flip, hv_flip])
        p, v = self._predict(x, transform, info)

        return p, v

    def _predict(self, x, transform, info=None):
        """Run one randomly chosen model on ``transform(x)`` (+ unflipped ``info`` planes).

        ``transform`` returns the transformed planes and the index order that maps
        the model's policy output back to the original action order.
        """
        x, slices = transform(x)
        if info is not None:
            x = np.concatenate([x, info], axis=0)

        # Choose among the models this agent holds, not the module-level PARAM dict,
        # so an agent built from any state_dicts works.
        model_key = random.choice(list(self.models))
        with torch.no_grad():
            xt = torch.from_numpy(x).unsqueeze(0)
            p, v = self.models[model_key](xt)

        p = p.squeeze(0).detach().numpy()
        p = p[slices]
        return p, v.item()

    # Input for Neural Network
    def _make_input(self, obs, last_obs, index):
        """Return ``(x, info)``: 25 spatial planes and 1 scalar-feature plane, each 7x11."""
        x_ = []
        x_.append(self._make_input_normal(obs, last_obs, index))
        x_.append(self._get_reverse_cube(obs, index))
        x_.append(self._get_next_disappear_cube(obs, index))
        x = np.concatenate(x_)

        info_ = []
        # info_.append(self._get_step_cube_v2(obs))
        # info_.append(self._get_length_cube(obs, index))
        info_.append(self._get_features(obs, index))
        info = np.concatenate(info_)

        return x, info

    def _make_input_normal(self, obs, last_obs, index):
        """17 planes, players rotated so ``index`` comes first: heads, tails, bodies, previous heads, food."""
        b = np.zeros((17, 7 * 11), dtype=np.float32)

        for p, pos_list in enumerate(obs.geese):
            # head position
            for pos in pos_list[:1]:
                b[0 + (p - index) % 4, pos] = 1
            # tip position
            for pos in pos_list[-1:]:
                b[4 + (p - index) % 4, pos] = 1
            # whole position
            for pos in pos_list:
                b[8 + (p - index) % 4, pos] = 1

        # previous head position
        if last_obs is not None:
            for p, pos_list in enumerate(last_obs.geese):
                for pos in pos_list[:1]:
                    b[12 + (p - index) % 4, pos] = 1

        # food
        for pos in obs.food:
            b[16, pos] = 1

        return b.reshape(-1, 7, 11)

    def _get_reverse_cube(self, obs, index):
        """
        4 planes (one per rotated player): body cells valued 1, 0.9, 0.8, ...
        counting from the tail toward the head.
        NOTE(review): geese longer than 10 cells get values <= 0 near the head.
        """
        b = np.zeros((4, 7 * 11), dtype=np.float32)

        for p, geese in enumerate(obs["geese"]):
            # whole position reverse
            for num_reverse, pos in enumerate(geese[::-1]):
                b[(p - index) % 4, pos] = 1 - num_reverse * 0.1

        return b.reshape(-1, 7, 11)

    def _get_next_disappear_cube(self, obs, index):
        """
        4 planes (one per rotated player) marking tail cells about to be vacated:
        1 = vacated next step, 0.5 = may be vacated (depends on eating food).
        """
        b = np.zeros((4, 7 * 11), dtype=np.float32)
        step = obs["step"]

        # Whether each goose's head is adjacent to food (it may eat next step).
        eat_food_possibility = defaultdict(int)
        for p, geese in enumerate(obs["geese"]):
            for pos in geese[:1]:
                if not self.next_position_map[pos].isdisjoint(obs["food"]):
                    eat_food_possibility[p] = 1

        if (step % 40) == 39:  # hunger shrinks every goose by one next step
            for p, geese in enumerate(obs["geese"]):
                if eat_food_possibility[p]:  # may eat -> tail 1, second-to-last 0.5
                    for pos in geese[-1:]:
                        b[(p - index) % 4, pos] = 1
                    for pos in geese[-2:-1]:
                        b[(p - index) % 4, pos] = 0.5
                else:  # cannot eat -> tail 1, second-to-last 1
                    for pos in geese[-2:]:
                        b[(p - index) % 4, pos] = 1
        else:  # no hunger shrink next step
            for p, geese in enumerate(obs["geese"]):
                if eat_food_possibility[p]:  # may eat -> tail 0.5
                    for pos in geese[-1:]:
                        b[(p - index) % 4, pos] = 0.5
                else:  # cannot eat -> tail 1
                    for pos in geese[-1:]:
                        b[(p - index) % 4, pos] = 1

        return b.reshape(-1, 7, 11)

    def _get_step_cube_v2(self, obs):
        """
        step0: 0, step199: 1
        step0: 0, step39 + 40n: 1
        (Currently unused by _make_input.)
        """
        b = np.zeros((1, 7, 11), dtype=np.float32)
        step = obs["step"]

        b[:, :, :5] = (step % 200) / 199
        b[:, :, 5:] = (step % 40) / 39

        return b

    def _get_length_cube(self, obs, index):
        """Two planes of scaled length features. (Currently unused by _make_input.)"""
        b = np.zeros((2, 7, 11), dtype=np.float32)

        my_length = len(obs["geese"][index])
        opposite1_length = len(obs["geese"][(index + 1) % 4])
        opposite2_length = len(obs["geese"][(index + 2) % 4])
        opposite3_length = len(obs["geese"][(index + 3) % 4])

        b[0] = my_length / 10
        max_opposite_length = max(opposite1_length, opposite2_length, opposite3_length)
        b[1, :, 0:2] = (my_length - max_opposite_length) / 10
        b[1, :, 2:5] = (my_length - opposite1_length) / 10
        b[1, :, 5:8] = (my_length - opposite2_length) / 10
        b[1, :, 8:11] = (my_length - opposite3_length) / 10

        return b

    def _get_features(self, obs, index):
        """Integer-coded scalar features in cells 0-7 of a single 7x11 plane (embedding indices)."""
        b = np.zeros((7 * 11), dtype=np.float32)
        step = obs["step"]

        my_goose = obs["geese"][index]
        my_length = len(my_goose)

        # 0: 1..4 during steps 195..198, else 0; 1: 1..4 when step % 40 is 36..39, else 0.
        b[0] = (step - 194) if step >= 195 else 0
        b[1] = (step % 40 - 35) if step % 40 > 35 else 0

        # 2-4: difference between my_length and opponent length (clipped to -3..3, shifted to 0..6)
        for p, pos_list in enumerate(obs["geese"]):
            pid = (p - index) % 4
            p_length = len(pos_list)

            if pid == 0:
                continue

            b[1 + pid] = max(min(my_length - p_length, 3), -3) + 3

        # 5-7: difference between my head position and each opponent's head.
        # NOTE(review): the row/column split of abs(index difference) is not the true
        # torus distance (e.g. cells 10 and 11). Kept unchanged because the trained
        # weights presumably depend on this exact encoding.
        if my_length != 0:

            for p, pos_list in enumerate(obs["geese"]):
                pid = (p - index) % 4

                if pid == 0 or len(pos_list) == 0:
                    continue

                diff = abs(my_goose[0] - pos_list[0])
                x_ = diff % 11
                x = min(x_, 11 - x_)
                y_ = diff // 11
                y = min(y_, 7 - y_)
                b[4 + pid] = x + y

        return b.reshape(1, 7, 11)


# Load PyTorch Model


# Unpack the embedded checkpoints: base64 -> bz2 -> pickle (the reverse of the
# packing done in the notebook). The payload is produced by this notebook itself,
# so unpickling it is trusted here.
state_dicts = {
    key: pickle.loads(bz2.decompress(base64.b64decode(encoded)))
    for key, encoded in PARAM.items()
}

game = HungryGeese()
agent = NNAgent(state_dicts)
# pb_c_base / pb_c_init only feed the disabled PUCT selection rules in MCTS.search.
mcts = MCTS(game, agent, pb_c_base=10, pb_c_init=1.0)


def alphageese_agent(obs, config):
    """Kaggle entry point: search for about 0.9 s and return the chosen action's name.

    Args:
        obs: Kaggle observation for this goose (``index``, ``step``, ``geese``, ``food``).
        config: Kaggle configuration (unused; ``actTimeout`` could drive the time budget).
    """
    # ``mcts`` is module-level and outlives a single episode. At the first step of
    # a new episode, forget the previous game's final observation so its head
    # positions are not treated as this goose's previous move.
    if obs.step == 0:
        mcts.last_obs = None
    action = game.actions[mcts.getActionProb(obs, timelimit=0.9)]  # timelimit=config.actTimeout
    return action.name

In [None]:
# Read in the submission file
with open(
    "submission.py",
) as file:
    filedata = file.read()

# Replace the target string
for key, val in PARAM.items():
    filedata = filedata.replace(key * 10, val.decode("utf-8"))

# Write the file out again
with open("submission.py", "w") as file:
    file.write(filedata)

In [None]:
from kaggle_environments import make

env = make("hungry_geese", debug=True)
env.reset()

# Self-play sanity check: four copies of the submission play one episode.
agents = ["submission.py"] * 4
env.run(agents)
env.render(mode="ipython", width=800, height=700)