In [None]:
!pip install -q -U kaggle-environments
!pip list | grep kaggle

In [None]:
import base64
import bz2
import pickle

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

model_path = {
    "a": "../input/hungry-geese-models/first_stage_2462.pth",
    "b": "../input/hungry-geese-models/IMO_pretrain_1st_4_280.pth",
    # "c": "../input/hungry-geese-models/first_stage_2237.pth",
    # "d": "../input/hungry-geese-models/second_stage_3257.pth",
    # "e": "../input/hungry-geese-models/CNN_first_2_2414.pth",
    # "f": "../input/hungry-geese-models/CNN_first_2_2540.pth",
    # "g": "../input/hungry-geese-models/CNN_first_2_2545.pth",
}

# Encode every checkpoint as base64(bz2(pickle(state_dict))) so it can be pasted
# into submission.py as a bytes literal.
# map_location="cpu": submission.py runs on CPU, so tensors saved from a GPU run
# must not be pickled as CUDA tensors (they would fail to unpickle there).
PARAM = {}
for key, val in model_path.items():
    weights = torch.load(val, map_location="cpu")
    PARAM[key] = base64.b64encode(bz2.compress(pickle.dumps(weights)))

In [None]:
%%writefile submission.py

# This is a lightweight ML agent trained by self-play.
# After sharing this notebook,
# we will add Hungry Geese environment in our HandyRL library.
# https://github.com/DeNA/HandyRL
# We hope you enjoy reinforcement learning!


import base64
import bz2
import pickle
from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Neural Network for Hungry Geese


class Dense(nn.Module):
    """Linear layer with optional BatchNorm1d over groups of `bnunits` features.

    When ``bnunits > 0`` the linear bias is disabled regardless of `bias`
    (BatchNorm supplies its own shift). The output is then reshaped to
    ``(-1, bnunits)`` for normalisation and restored to its original shape.
    """

    def __init__(self, units0, units1, bnunits=0, bias=True):
        super().__init__()
        with_bn = bnunits > 0
        self.dense = nn.Linear(units0, units1, bias=False if with_bn else bias)
        self.bnunits = bnunits
        self.bn = nn.BatchNorm1d(bnunits) if with_bn else None

    def forward(self, x):
        out = self.dense(x)
        if self.bn is None:
            return out
        shape = out.size()
        return self.bn(out.view(-1, self.bnunits)).view(*shape)


class TorusConv2d(nn.Module):
    """Conv2d on a torus: the input is wrapped around both spatial axes before the conv.

    Each side is padded with ``kernel_size // 2`` wrapped rows/columns, so odd
    kernels keep the spatial size unchanged. An optional BatchNorm2d follows
    the convolution.
    """

    def __init__(self, input_dim, output_dim, kernel_size, bn):
        super().__init__()
        self.edge_size = (kernel_size[0] // 2, kernel_size[1] // 2)
        self.conv = nn.Conv2d(input_dim, output_dim, kernel_size=kernel_size)
        self.bn = nn.BatchNorm2d(output_dim) if bn else None

    def forward(self, x):
        edge_row, edge_col = self.edge_size
        h = x
        # Skip zero-width edges: `t[..., -0:]` selects the WHOLE tensor rather than
        # an empty slice, which would duplicate the board for 1-wide kernels.
        if edge_col > 0:
            h = torch.cat([h[:, :, :, -edge_col:], h, h[:, :, :, :edge_col]], dim=3)
        if edge_row > 0:
            h = torch.cat([h[:, :, -edge_row:], h, h[:, :, :edge_row]], dim=2)
        h = self.conv(h)
        if self.bn is not None:
            h = self.bn(h)
        return h


class GeeseNet(nn.Module):
    """Residual torus-CNN producing policy logits and a value for the acting goose.

    Expects a (batch, 17, rows, cols) board stack whose plane 0 marks the acting
    goose's head, as built by ``make_input``. Returns a dict with ``policy``
    (batch, 4), ``value`` (batch, 1, tanh), and the intermediate vectors
    ``h_head_p``, ``h_head_v`` and ``h_avg_v`` (batch, 32 each), which
    ``GeeseNetIMO`` reuses as its embedding.
    """

    def __init__(self):
        super().__init__()
        num_blocks, channels = 12, 32
        self.conv0 = TorusConv2d(17, channels, (3, 3), True)
        self.blocks = nn.ModuleList([TorusConv2d(channels, channels, (3, 3), True) for _ in range(num_blocks)])

        self.conv_p = TorusConv2d(channels, channels, (3, 3), True)
        self.conv_v = TorusConv2d(channels, channels, (3, 3), True)

        self.head_p = nn.Linear(channels, 4, bias=False)
        self.head_v1 = nn.Linear(channels * 2, channels, bias=False)
        self.head_v2 = nn.Linear(channels, 1, bias=False)

    @staticmethod
    def _pool_at_head(features, board):
        """Per channel, sum `features` over the cells set in plane 0 of `board`."""
        masked = features * board[:, :1]
        return masked.view(masked.size(0), masked.size(1), -1).sum(-1)

    def forward(self, x, _=None):
        trunk = F.relu_(self.conv0(x))
        for block in self.blocks:
            trunk = F.relu_(trunk + block(trunk))

        policy_feat = F.relu_(self.conv_p(trunk))
        h_head_p = self._pool_at_head(policy_feat, x)
        policy = self.head_p(h_head_p)

        value_feat = F.relu_(self.conv_v(trunk))
        h_head_v = self._pool_at_head(value_feat, x)
        h_avg_v = value_feat.view(value_feat.size(0), value_feat.size(1), -1).mean(-1)

        hidden = F.relu_(self.head_v1(torch.cat([h_head_v, h_avg_v], 1)))
        value = torch.tanh(self.head_v2(hidden))

        return {"policy": policy, "value": value, "h_head_p": h_head_p, "h_head_v": h_head_v, "h_avg_v": h_avg_v}


class GeeseNet2(nn.Module):
    """GeeseNet variant whose policy branch stacks two torus convolutions.

    Input and output dict match GeeseNet. Checkpoint keys differ (``conv_p1`` /
    ``conv_p2`` instead of ``conv_p``), so weights are not interchangeable.
    """

    def __init__(self):
        super().__init__()
        depth, width = 12, 32
        self.conv0 = TorusConv2d(17, width, (3, 3), True)
        self.blocks = nn.ModuleList([TorusConv2d(width, width, (3, 3), True) for _ in range(depth)])

        self.conv_p1 = TorusConv2d(width, width, (3, 3), True)
        self.conv_p2 = TorusConv2d(width, width, (3, 3), True)
        self.conv_v = TorusConv2d(width, width, (3, 3), True)

        self.head_p = nn.Linear(width, 4, bias=False)
        self.head_v1 = nn.Linear(width * 2, width, bias=False)
        self.head_v2 = nn.Linear(width, 1, bias=False)

    def forward(self, x, _=None):
        head_mask = x[:, :1]  # plane 0: acting goose's head

        def flat(t):
            # (batch, channels, rows, cols) -> (batch, channels, rows * cols)
            return t.view(t.size(0), t.size(1), -1)

        trunk = F.relu_(self.conv0(x))
        for block in self.blocks:
            trunk = F.relu_(trunk + block(trunk))

        pol = F.relu_(self.conv_p2(F.relu_(self.conv_p1(trunk))))
        h_head_p = flat(pol * head_mask).sum(-1)

        val = F.relu_(self.conv_v(trunk))
        h_head_v = flat(val * head_mask).sum(-1)
        h_avg_v = flat(val).mean(-1)
        hidden = F.relu_(self.head_v1(torch.cat([h_head_v, h_avg_v], 1)))

        return {
            "policy": self.head_p(h_head_p),
            "value": torch.tanh(self.head_v2(hidden)),
            "h_head_p": h_head_p,
            "h_head_v": h_head_v,
            "h_avg_v": h_avg_v,
        }


class GeeseNetIMO(nn.Module):
    """GeeseNet feature extractor followed by a stack of self-attention blocks.

    A GeeseNet backbone produces a 64-dim embedding per sample (its 32-dim
    ``h_head_p`` concatenated with its 32-dim ``h_avg_v``). That embedding is
    fed as a length-1 sequence through ``blocks`` attention layers, fused with
    the original embedding in ``control``, and decoded by ``head`` into 4
    policy logits and a tanh value. GeeseNet's own policy/value outputs are
    computed but not used here.
    """

    class GeeseBlock(nn.Module):
        """Self-attention layer that returns only the attended output (no residual, no norm)."""

        def __init__(self, embed_dim, num_heads):
            super().__init__()
            self.attention = nn.MultiheadAttention(embed_dim, num_heads)

        def forward(self, x):
            # nn.MultiheadAttention defaults to a (seq, batch, embed) layout; the
            # attention weights returned alongside the output are discarded.
            h, _ = self.attention(x, x, x)
            return h

    class GeeseControll(nn.Module):
        """Fuse the attended features, the raw embedding and one more attention pass.

        Concatenates [x, e, attention(x)] along the feature axis (3 * filters),
        flattens to (batch, 3 * filters) and applies Dense + BatchNorm, giving
        (batch, final_filters).
        """

        def __init__(self, filters, final_filters):
            super().__init__()
            self.filters = filters  # stored but not read within this class
            self.attention = nn.MultiheadAttention(filters, 1)
            self.fc_control = Dense(filters * 3, final_filters, bnunits=final_filters)

        def forward(self, x, e):
            h, _ = self.attention(x, x, x)

            # (seq, batch, 3 * filters) -> (batch, 3 * filters). This view is a
            # per-sample flatten only because GeeseNetIMO.forward builds sequences
            # of length 1; with a longer sequence it would mix samples.
            h = torch.cat([x, e, h], dim=2).view(x.size(1), -1)
            h = self.fc_control(h)
            return h

    class GeeseHead(nn.Module):
        """Policy (4 logits) and value (tanh, 1 unit) heads, each two stacked Linear layers.

        NOTE: there is no nonlinearity between the two Linear layers of either
        head, so each head composes to a single affine map.
        """

        def __init__(self, filters):
            super().__init__()
            f = filters // 2  # hidden width shared by both heads
            self.head_p_1 = nn.Linear(filters, f, bias=False)
            self.head_p_2 = nn.Linear(f, 4, bias=False)
            self.head_v_1 = nn.Linear(filters, f, bias=True)
            self.head_v_2 = nn.Linear(f, 1, bias=True)

        def forward(self, x):
            p = self.head_p_1(x)
            p = self.head_p_2(p)
            v = self.head_v_1(x)
            v = torch.tanh(self.head_v_2(v))
            return p, v

    def __init__(self):
        super().__init__()
        blocks = 5
        # Must equal the embedding width built in forward: 32 (h_head_p) + 32 (h_avg_v).
        filters = 64
        final_filters = 128

        self.geese_net = GeeseNet()

        self.blocks = nn.ModuleList([self.GeeseBlock(filters, 8) for _ in range(blocks)])
        self.control = self.GeeseControll(filters, final_filters)
        self.head = self.GeeseHead(final_filters)

    def forward(self, x, _=None):
        x_ = self.geese_net(x)
        # Per-sample embedding reshaped into a length-1 sequence: (1, batch, 64).
        e = torch.cat([x_["h_head_p"], x_["h_avg_v"]], 1).view(1, x.size()[0], -1)
        h = e
        # NOTE(review): with a single token the softmax attention weight is always 1,
        # so each block reduces to its value and output projections.
        for block in self.blocks:
            h = block(h)
        h = self.control(h, e)
        p, v = self.head(h)
        return {"policy": p, "value": v}


# Input for Neural Network


NUM_ROW = 7
NUM_COL = 11
CENTER_ROW = NUM_ROW // 2
CENTER_COL = NUM_COL // 2


def to_offset(x):
    """Return the (row, col) shift that moves flat cell index `x` to the board centre.

    Passing the result to ``to_row`` / ``to_col`` re-centres any other cell so
    that `x` itself lands on (CENTER_ROW, CENTER_COL).
    """
    row = CENTER_ROW - x // NUM_COL
    col = CENTER_COL - x % NUM_COL
    return row, col


def to_row(offset, x):
    """Row of flat cell index `x` after shifting by `offset` rows, wrapped on the torus."""
    return (x // NUM_COL + offset) % NUM_ROW


def to_col(offset, x):
    """Column of flat cell index `x` after shifting by `offset` columns, wrapped on the torus.

    ``(x + offset) % NUM_COL`` equals ``(x % NUM_COL + offset) % NUM_COL``.
    """
    return (x + offset) % NUM_COL


def make_input(obses):
    """Encode the game state as a (17, NUM_ROW, NUM_COL) float32 feature stack.

    `obses` is the list of observations seen so far; only the last two are read.
    Goose planes are ordered relative to the acting goose (``obs["index"]``),
    so slot 0 of every group is always "me":
      0-3    head cell of each goose
      4-7    tail-tip cell of each goose
      8-11   every cell occupied by each goose
      12-15  head cell in the previous observation (zero on the first step)
      16     food
    """
    board = np.zeros((17, NUM_ROW * NUM_COL), dtype=np.float32)
    obs = obses[-1]
    me = obs["index"]

    for p, pos_list in enumerate(obs["geese"]):
        rel = (p - me) % 4  # 4 planes per group: seat of goose p relative to me
        # Slices (not [0] / [-1]) so eliminated geese with empty bodies are skipped.
        for pos in pos_list[:1]:
            board[0 + rel, pos] = 1
        for pos in pos_list[-1:]:
            board[4 + rel, pos] = 1
        for pos in pos_list:
            board[8 + rel, pos] = 1

    if len(obses) > 1:
        for p, pos_list in enumerate(obses[-2]["geese"]):
            for pos in pos_list[:1]:
                board[12 + (p - me) % 4, pos] = 1

    for pos in obs["food"]:
        board[16, pos] = 1

    return board.reshape(-1, NUM_ROW, NUM_COL)


def make_input_centering_head_for_rule(obses):
    """Lay out the board relative to the player's head for the rule layer.

    Every cell is shifted on the torus so that the player's head lands on
    (CENTER_ROW, CENTER_COL). Returns a ``defaultdict(list)`` mapping a category
    key to a list of (row, col) tuples (a key exists only once something is added):
      "pb" / "ob"  body segments, head and tail tip excluded (player / opponents)
      "pt" / "ot"  tail tips
      "ph" / "oh"  heads
      "pp"         the player's head in the previous observation
      "f"          food
    """
    layout = defaultdict(list)
    latest = obses[-1]
    me = latest["index"]

    row_shift, col_shift = to_offset(latest["geese"][me][0])

    def centred(pos):
        return (to_row(row_shift, pos), to_col(col_shift, pos))

    for idx, goose in enumerate(latest["geese"]):
        mine = (idx - me) % 4 == 0
        body_key, tip_key, head_key = ("pb", "pt", "ph") if mine else ("ob", "ot", "oh")
        for pos in goose[1:-1]:
            layout[body_key].append(centred(pos))
        for pos in goose[-1:]:
            layout[tip_key].append(centred(pos))
        for pos in goose[:1]:
            layout[head_key].append(centred(pos))

    if len(obses) > 1:
        for idx, goose in enumerate(obses[-2]["geese"]):
            if (idx - me) % 4 != 0:
                continue
            for pos in goose[:1]:
                layout["pp"].append(centred(pos))

    for pos in latest["food"]:
        layout["f"].append(centred(pos))

    return layout


def distance(a, b):
    """Displacement from cell `a` to cell `b` and its Manhattan length.

    Both arguments are (row, col) pairs. The difference is taken directly,
    without toroidal wrap-around. Returns ``((drow, dcol), |drow| + |dcol|)``.
    """
    delta = (b[0] - a[0], b[1] - a[1])
    return delta, sum(abs(d) for d in delta)


def apply_rule(b, prob):
    """Overwrite the score of every move into an occupied cell with -inf.

    ``b`` is the head-centred layout from ``make_input_centering_head_for_rule``.
    Its keys map to lists of (row, col) tuples, and the player's head is always
    at (3, 5), so the four candidate cells are fixed. ``prob`` holds one score
    per action in the order ["NORTH", "SOUTH", "WEST", "EAST"]. It is modified
    in place and also returned.

    Blocked targets:
      "pb" / "ob"  body segments (head and tail tip excluded) of me / opponents
      "oh"         opponents' current heads: after the simultaneous move that
                   cell becomes the opponent's neck for any goose of length >= 2
      "pp"         my previous head, normally the cell I just left
    Tail tips stay open because they usually move away this turn. Possible
    head-on collisions (opponent heads two cells away) are not penalised.
    Missing keys are treated as empty, so a plain dict works too.
    """
    neighbor = [(2, 5), (4, 5), (3, 4), (3, 6)]

    blocked = set()
    for key in ("pb", "ob", "oh", "pp"):
        blocked.update(b.get(key, ()))

    for i, cell in enumerate(neighbor):
        if cell in blocked:
            prob[i] = -np.inf

    return prob


# Load PyTorch Model


PARAM = {
    "a": b"aaaaaaaaaa",
    "imo_b": b"bbbbbbbbbb",
    # "c": b"cccccccccc",
    # "d": b"dddddddddd",
    # "2_e": b"eeeeeeeeee",
    # "2_f": b"ffffffffff",
    # "2_g": b"gggggggggg",
}


# The key prefix selects the architecture: "imo_" -> GeeseNetIMO, "2_" -> GeeseNet2,
# anything else -> GeeseNet. Prefix (not substring) matching, so a key such as
# "v2_x" is not misrouted to GeeseNet2.
# SECURITY: pickle.loads can execute arbitrary code. This is acceptable only because
# the blobs are embedded by the author's own notebook, never taken from user input.
model = {}
for key, param in PARAM.items():
    state_dict = pickle.loads(bz2.decompress(base64.b64decode(param)))
    if key.startswith("imo_"):
        net = GeeseNetIMO()
    elif key.startswith("2_"):
        net = GeeseNet2()
    else:
        net = GeeseNet()
    net.load_state_dict(state_dict)
    model[key] = net.eval()


# Main Function of Agent

obses = []


def agent(obs, _):
    """Kaggle entry point: return "NORTH", "SOUTH", "WEST" or "EAST".

    Averages the policy logits of every loaded model. ``apply_rule`` then
    overwrites the moves it considers fatal with -inf, and the argmax is taken.
    """
    # `obses` is module state. If this module plays another episode, drop the
    # old history so obses[-2] is never a frame from the previous game.
    if obs.get("step") == 0:
        obses.clear()
    obses.append(obs)
    x = make_input(obses)
    y = make_input_centering_head_for_rule(obses)

    # Every model sees the same input, so build the tensor once.
    xt = torch.from_numpy(x).unsqueeze(0)
    preds = np.zeros((len(PARAM), 4), dtype=np.float32)
    with torch.no_grad():
        for i, key in enumerate(PARAM.keys()):
            preds[i] = model[key](xt)["policy"].squeeze(0).numpy()

    inf = np.mean(preds, axis=0)
    inf = apply_rule(y, inf)

    actions = ["NORTH", "SOUTH", "WEST", "EAST"]
    return actions[np.argmax(inf)]

In [None]:
# Splice the encoded weights into submission.py: each 10-character placeholder
# literal is replaced by the base64 text of the matching PARAM entry.
placeholder_to_param = {
    "aaaaaaaaaa": "a",
    "bbbbbbbbbb": "b",
    # "cccccccccc": "c",
    # "dddddddddd": "d",
    # "eeeeeeeeee": "e",
    # "ffffffffff": "f",
    # "gggggggggg": "g",
}

with open("submission.py") as src:
    filedata = src.read()

for placeholder, param_key in placeholder_to_param.items():
    filedata = filedata.replace(placeholder, PARAM[param_key].decode("utf-8"))

with open("submission.py", "w") as dst:
    dst.write(filedata)

In [None]:
from kaggle_environments import make

# Smoke test: four copies of the submission play one game against each other.
env = make("hungry_geese", debug=True)
env.reset()
env.run(["submission.py"] * 4)
env.render(mode="ipython", width=800, height=700)