In [None]:
import base64
import bz2
import pickle

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Trained weights; submission.py later loads them with GeeseNet().load_state_dict.
model_path = "../input/hungry-geese-models/latest.pth"

# map_location="cpu": the tensors are pickled into PARAM below and unpickled in
# submission.py. CUDA tensors could not be restored on a machine without a GPU,
# so move everything to CPU before serializing.
weights = torch.load(model_path, map_location="cpu")

# pickle -> bz2 -> base64 yields ASCII bytes that can be pasted into
# submission.py as a bytes literal (see the placeholder substitution cell).
PARAM = base64.b64encode(bz2.compress(pickle.dumps(weights)))

# Round-trip decode: exactly what submission.py will do with PARAM.
state_dict = pickle.loads(bz2.decompress(base64.b64decode(PARAM)))

In [None]:
%%writefile submission.py

# This is a lightweight ML agent trained by self-play.
# After sharing this notebook, we will add the Hungry Geese environment
# to our HandyRL library:
# https://github.com/DeNA/HandyRL
# We hope you enjoy reinforcement learning!


import base64
import bz2
import pickle

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Neural Network for Hungry Geese


class Dense(nn.Module):
    """Linear layer optionally followed by BatchNorm1d over groups of ``bnunits`` features.

    When ``bnunits > 0`` the linear bias is disabled (the batch norm supplies the
    shift). The output is flattened to ``(-1, bnunits)`` for normalization and
    then restored to its original shape.
    """

    def __init__(self, units0, units1, bnunits=0, bias=True):
        super().__init__()
        use_bn = bnunits > 0
        self.dense = nn.Linear(units0, units1, bias=False if use_bn else bias)
        self.bnunits = bnunits
        if use_bn:
            self.bn = nn.BatchNorm1d(bnunits)
        else:
            self.bn = None

    def forward(self, x):
        out = self.dense(x)
        if self.bn is None:
            return out
        original_shape = out.size()
        normed = self.bn(out.view(-1, self.bnunits))
        return normed.view(*original_shape)


class TorusConv2d(nn.Module):
    """2-D convolution on a torus: the input wraps around both spatial edges.

    The input is circularly padded by ``kernel_size // 2`` on each side of each
    spatial dim, then an unpadded ``nn.Conv2d`` runs over it. For odd kernel
    sizes the output therefore keeps the input's height and width.

    Args:
        input_dim: number of input channels.
        output_dim: number of output channels.
        kernel_size: ``(kh, kw)`` tuple.
        bn: if truthy, apply ``nn.BatchNorm2d`` after the convolution.
    """

    def __init__(self, input_dim, output_dim, kernel_size, bn):
        super().__init__()
        self.edge_size = (kernel_size[0] // 2, kernel_size[1] // 2)
        self.conv = nn.Conv2d(input_dim, output_dim, kernel_size=kernel_size)
        self.bn = nn.BatchNorm2d(output_dim) if bn else None

    def forward(self, x):
        """Wrap-pad ``x`` (N, C, H, W), convolve, and optionally batch-normalize."""
        pad_h, pad_w = self.edge_size
        # Circular padding prepends the last `pad` columns/rows and appends the
        # first ones; corners wrap in both dims. Slicing with x[..., -edge:]
        # would be wrong for edge == 0 (kernel size 1): "-0:" selects the whole
        # tensor instead of nothing.
        h = F.pad(x, (pad_w, pad_w, pad_h, pad_h), mode="circular")
        h = self.conv(h)
        if self.bn is not None:
            h = self.bn(h)
        return h


class Conv2d(nn.Module):
    """Plain, unpadded 2-D convolution with an optional BatchNorm2d afterwards."""

    def __init__(self, input_dim, output_dim, kernel_size, bn):
        super().__init__()
        self.conv = nn.Conv2d(input_dim, output_dim, kernel_size=kernel_size)
        if bn:
            self.bn = nn.BatchNorm2d(output_dim)
        else:
            self.bn = None

    def forward(self, x):
        out = self.conv(x)
        return out if self.bn is None else self.bn(out)


class ChannelSELayer(nn.Module):
    """Squeeze-and-excitation gate that rescales each channel of a 4-D feature map.

    Each channel is squeezed to its spatial mean and passed through a bottleneck
    MLP (``channel -> channel // reduction -> channel``) ending in a sigmoid.
    The input is then multiplied by the resulting per-channel weight in (0, 1).
    """

    def __init__(self, channel, reduction=8):
        super().__init__()
        hidden = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # The submodule order fixes the state_dict keys (fc.0.weight, fc.2.weight).
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        batch, channels, _, _ = x.size()
        squeezed = self.avg_pool(x).reshape(batch, channels)
        gate = self.fc(squeezed).reshape(batch, channels, 1, 1)
        # gate broadcasts over the spatial dims.
        return x * gate


class GeeseNet(nn.Module):
    """Residual torus-CNN with a policy head and a value head for Hungry Geese.

    Args:
        in_channels: number of input feature planes (``make_input`` emits 17).
        layers: number of conv + squeeze-excitation residual blocks.
        filters: channel width of every hidden layer.
        board_cells: number of board cells (rows * cols); sizes ``head_v``.

    The defaults reproduce the original fixed architecture, so existing
    checkpoints still load into ``GeeseNet()``.
    """

    def __init__(self, in_channels=17, layers=12, filters=32, board_cells=7 * 11):
        super().__init__()
        self.conv0 = TorusConv2d(in_channels, filters, (3, 3), True)
        self.cnn_blocks = nn.ModuleList([TorusConv2d(filters, filters, (3, 3), True) for _ in range(layers)])
        self.cse_blocks = nn.ModuleList([ChannelSELayer(filters, 4) for _ in range(layers)])

        self.conv_p = TorusConv2d(filters, filters, (3, 3), True)
        self.conv_v = TorusConv2d(filters, filters, (3, 3), True)

        self.head_v = nn.Linear(board_cells, 1, bias=False)

    def forward(self, x):
        """Compute per-direction action scores and a state value.

        Args:
            x: ``(N, in_channels, rows, cols)`` feature planes. Channel 0 must mark
                the controlled goose's head, as ``make_input`` produces.

        Returns:
            dict with ``"policy"`` of shape ``(N, 4)``, ordered NORTH, SOUTH,
            WEST, EAST, and ``"value"`` of shape ``(N, 1)`` in (-1, 1).
        """
        h = F.relu_(self.conv0(x))
        for cnn, cse in zip(self.cnn_blocks, self.cse_blocks):
            h = cnn(h)
            h = F.relu_(h + cse(h))
        batch, channels = h.size(0), h.size(1)

        # Policy: read the policy feature map at the four cells next to the head.
        # torch.roll shifts the head plane (channel 0) so its mark lands on the
        # neighbouring cell, wrapping around the torus.
        p = self.conv_p(h)
        head = x[:, :1]
        neighbour_masks = [
            torch.roll(head, shifts=-1, dims=-2),  # NORTH: row above
            torch.roll(head, shifts=1, dims=-2),  # SOUTH: row below
            torch.roll(head, shifts=-1, dims=-1),  # WEST: column to the left
            torch.roll(head, shifts=1, dims=-1),  # EAST: column to the right
        ]
        per_direction = [(p * mask).view(batch, channels, -1).sum(-1) for mask in neighbour_masks]
        p = torch.stack(per_direction, dim=1).mean(-1)

        # Value: channel-average of the value feature map, then one weight per cell.
        v = F.relu_(self.conv_v(h))
        v = v.view(batch, channels, -1).mean(1)
        v = torch.tanh(self.head_v(v))

        return {"policy": p, "value": v}


class GeeseNetA(nn.Module):
    """Attention-based policy/value network; an alternative to ``GeeseNet``.

    The encoder keeps the central 7x7 window of the 7x11 board. It cuts that
    window into four overlapping 7x4 patches (NORTH, SOUTH, WEST, EAST) and turns
    each patch into one token. A stack of self-attention blocks mixes the four
    tokens. The head reads one policy score per direction and a scalar value
    from them.

    NOTE(review): the patches are symmetric around the window centre, which is
    the cell ``make_input_centering_head`` moves the controlled goose's head to.
    Presumably this net expects that encoder's input; confirm before use.
    """

    class GeeseEncoder(nn.Module):
        """Project (N, 17, 7, 11) feature planes to four direction tokens."""

        def __init__(self, d_model):
            super().__init__()
            # self.edge_size = (1, 2)
            self.d_model = d_model
            # Each 7x4 patch has 28 cells, so a flattened patch has 28 * filters
            # features, which equals d_model when d_model is a multiple of 28.
            self.filters = d_model // 28

            # 1x1 conv: per-cell projection from 17 planes to `filters` channels.
            self.conv0 = Conv2d(17, self.filters, (1, 1), True)

        def forward(self, x):
            """Return tokens of shape (4, N, 28 * filters), ordered NORTH, SOUTH, WEST, EAST."""
            x = self.conv0(x)

            h = x.view(x.size(0), self.filters, 7, 11)

            # make torus. size: (..., 9, 15)  (disabled)
            # h = torch.cat([h[:,:,:,-self.edge_size[1]:], h, h[:,:,:,:self.edge_size[1]]], dim=3)
            # h = torch.cat([h[:,:,-self.edge_size[0]:], h, h[:,:,:self.edge_size[0]]], dim=2)

            # Drop the 2 outermost columns on each side: (N, f, 7, 11) -> (N, f, 7, 7).
            h = h[:, :, :, 2:-2]

            # Info: ['NORTH', 'SOUTH', 'WEST', 'EAST']
            # Split into four 7x4 patches (the centre row/column is shared by
            # opposite patches), rotate each into a 7x4 shape, and flatten to
            # (N, 28 * f).
            n = torch.rot90(h[:, :, :4], 1, (2, 3)).reshape(h.size(0), -1)  # top 4 rows
            s = torch.rot90(h[:, :, 3:], 3, (2, 3)).reshape(h.size(0), -1)  # bottom 4 rows
            w = h[:, :, :, :4].reshape(h.size(0), -1)  # left 4 columns
            e = torch.rot90(h[:, :, :, 3:], 2, (2, 3)).reshape(h.size(0), -1)  # right 4 columns

            z = torch.stack([n, s, w, e])

            # z = z.permute(1, 0, 2)
            # Kept as (seq=4, N, features), the default (batch_first=False)
            # input layout of nn.MultiheadAttention.

            return z

    class GeeseBlock(nn.Module):
        """Pre-LayerNorm multi-head self-attention with a residual connection.

        There is no feed-forward sublayer. Input and output share the
        (seq, N, embed_dim) layout.
        """

        def __init__(self, embed_dim, num_heads):
            super().__init__()
            self.ln = nn.LayerNorm(embed_dim)
            self.attention = nn.MultiheadAttention(embed_dim, num_heads)

        def forward(self, x):
            h = self.ln(x)
            # Self-attention: query, key and value are all the normalized tokens.
            h, _ = self.attention(h, h, h)
            h = h + x
            return h

    class GeeseControll(nn.Module):
        """Identity pass-through; its use in GeeseNetA is currently commented out."""

        def __init__(self, d_model):
            super().__init__()
            self.d_model = d_model

        def forward(self, x, e):
            # `e` is ignored; `x` is returned unchanged.
            # h = x.permute(1, 0, 2)
            h = x
            return h

    class GeeseHead(nn.Module):
        """Map the (4, N, 28 * filters) direction tokens to a policy and a value.

        Returns ``(p, v)``: ``p`` is (N, 4) with one score per direction, and
        ``v`` is (N, 1) in (-1, 1).
        """

        def __init__(self, d_model):
            super().__init__()
            self.filters = d_model // 28
            # One weight per cell of the reassembled 7x7 window.
            self.head_v = nn.Linear(49, 1, bias=False)

        def forward(self, x):
            # Policy: average each direction token's features -> (N, 4).
            p = x.permute(1, 0, 2).mean(-1)

            # Value: unflatten every token back into its (filters, 7, 4) patch.
            v = x.view(x.size(0), x.size(1), self.filters, 7, 4)

            # Info: ['NORTH', 'SOUTH', 'WEST', 'EAST']
            # Undo the encoder's rotations (1 + 3, 3 + 1 and 2 + 2 quarter turns
            # are full turns), so each patch is back in window orientation.
            n = torch.rot90(v[0], 3, (2, 3))
            s = torch.rot90(v[1], 1, (2, 3))
            w = v[2]
            e = torch.rot90(v[3], 2, (2, 3))

            # Zero-pad each patch to 7x7 at its place in the window. F.pad's tuple
            # lists (left, right) for the last dim, then (top, bottom) for the
            # second-to-last, and so on. NORTH fills rows 0-3, SOUTH rows 3-6,
            # WEST columns 0-3, EAST columns 3-6.
            n = F.pad(n, (0, 0, 0, 3, 0, 0, 0, 0), mode="constant", value=0)
            s = F.pad(s, (0, 0, 3, 0, 0, 0, 0, 0), mode="constant", value=0)
            w = F.pad(w, (0, 3, 0, 0, 0, 0, 0, 0), mode="constant", value=0)
            e = F.pad(e, (3, 0, 0, 0, 0, 0, 0, 0), mode="constant", value=0)

            # Reassemble the window, average over channels, flatten to (N, 49).
            v = n + s + w + e
            v = v.mean(1).view(x.size(1), -1)

            v = torch.tanh(self.head_v(v))
            return p, v

    def __init__(self):
        super().__init__()
        d_model = 224  # 28x8: 28 cells per 7x4 patch times 8 channels
        n_heads = 8
        blocks = 6

        # self.geese_net = GeeseNet()
        self.encoder = self.GeeseEncoder(d_model)

        self.blocks = nn.ModuleList([self.GeeseBlock(d_model, n_heads) for _ in range(blocks)])
        # self.gtrxl = GTrXL(d_model, n_heads, blocks)

        # self.control = self.GeeseControll(d_model)
        self.head = self.GeeseHead(d_model)

    def forward(self, x, _=None):
        """Return {"policy": (N, 4), "value": (N, 1)}; the second argument is ignored."""
        # e = self.geese_net(x)["hidden"]
        # h = self.encoder(e["p"])
        h = self.encoder(x)

        for block in self.blocks:
            h = block(h)
        # h = self.gtrxl(h)

        # h = self.control(h, e)
        p, v = self.head(h)
        return {"policy": p, "value": v}


# Input for Neural Network


# Board size (rows x columns), and the cell that the head-centred encoding
# moves the controlled goose's head to.
NUM_ROW, NUM_COL = 7, 11
CENTER_ROW, CENTER_COL = NUM_ROW // 2, NUM_COL // 2


def to_offset(x):
    """Return the (row, col) shift that moves flat cell index ``x`` to (CENTER_ROW, CENTER_COL)."""
    cell_row, cell_col = divmod(x, NUM_COL)
    return CENTER_ROW - cell_row, CENTER_COL - cell_col


def to_row(offset, x):
    """Row of flat cell ``x`` after adding ``offset`` to its row index, wrapping modulo NUM_ROW."""
    cell_row = x // NUM_COL
    return (cell_row + offset) % NUM_ROW


def to_col(offset, x):
    """Column of flat cell ``x`` after adding ``offset`` to its column index, wrapping modulo NUM_COL.

    The flat index can be used directly: its row part is a multiple of NUM_COL,
    so it vanishes modulo NUM_COL.
    """
    shifted = x + offset
    return shifted % NUM_COL


def make_input(obses):
    """Encode the latest observation as 17 binary 7x11 feature planes.

    Geese are indexed relative to the controlled goose: goose ``p`` maps to slot
    ``(p - obs["index"]) % 4``, so slot 0 is always the agent itself.

    * 0-3:   head cell of each goose
    * 4-7:   tail-tip cell of each goose
    * 8-11:  every body cell of each goose
    * 12-15: head cell of each goose in the previous observation (if any)
    * 16:    food cells

    Args:
        obses: list of observations, oldest first; only the last two are read.

    Returns:
        float32 array of shape (17, 7, 11).
    """
    planes = np.zeros((17, 7 * 11), dtype=np.float32)
    current = obses[-1]
    me = current["index"]

    for goose, cells in enumerate(current["geese"]):
        slot = (goose - me) % 4
        if len(cells) > 0:
            planes[slot, cells[0]] = 1
            planes[4 + slot, cells[-1]] = 1
        for cell in cells:
            planes[8 + slot, cell] = 1

    if len(obses) > 1:
        for goose, cells in enumerate(obses[-2]["geese"]):
            if len(cells) > 0:
                planes[12 + (goose - me) % 4, cells[0]] = 1

    for cell in current["food"]:
        planes[16, cell] = 1

    return planes.reshape(-1, 7, 11)


def make_input_centering_head(obses):
    """Encode the latest observation as 17 binary 7x11 planes, re-centred on our head.

    Every cell is shifted (with wrap-around) so the controlled goose's head lands
    at (CENTER_ROW, CENTER_COL). Geese map to slot ``(p - obs["index"]) % 4``.

    * 0-3:   every body cell of each goose
    * 4-7:   tail-tip cell of each goose
    * 8-11:  head cell of each goose
    * 12-15: head cell of each goose in the previous observation (if any)
    * 16:    food cells

    NOTE(review): the order of planes 0-11 differs from ``make_input``, which puts
    the head planes first. ``GeeseNet.forward`` reads channel 0 as the head, so
    only pair this encoder with a network trained on it.

    Raises IndexError if the controlled goose has no cells.
    """
    planes = np.zeros((17, 7, 11), dtype=np.float32)
    current = obses[-1]
    me = current["index"]
    row_shift, col_shift = to_offset(current["geese"][me][0])

    def mark(channel, cell):
        # Set the plane entry for `cell` after re-centring.
        planes[channel, to_row(row_shift, cell), to_col(col_shift, cell)] = 1

    for goose, cells in enumerate(current["geese"]):
        slot = (goose - me) % 4
        for cell in cells:
            mark(slot, cell)
        for cell in cells[-1:]:
            mark(4 + slot, cell)
        for cell in cells[:1]:
            mark(8 + slot, cell)

    if len(obses) > 1:
        for goose, cells in enumerate(obses[-2]["geese"]):
            for cell in cells[:1]:
                mark(12 + (goose - me) % 4, cell)

    for cell in current["food"]:
        mark(16, cell)

    return planes


# Load PyTorch Model


# The bytes literal below is a placeholder. The next notebook cell rewrites it in
# this file with the base64(bz2(pickle(state_dict))) payload built earlier. Its
# exact text must not change, because the notebook matches it verbatim, and that
# text must not appear anywhere else in this file.
PARAM = b"xxxxxxxxxx"

# Reverse the encoding: base64 -> bz2 -> pickle. pickle.loads can execute code
# embedded in the payload, so only embed weights you produced yourself.
state_dict = pickle.loads(bz2.decompress(base64.b64decode(PARAM)))
model = GeeseNet()
# model = GeeseNetA()  # alternative architecture; the embedded weights must match the class built
model.load_state_dict(state_dict)
# Inference mode: BatchNorm layers use their running statistics.
model.eval()


# Main Function of Agent

# Observation history for the current episode; make_input reads the last two entries.
obses = []


def agent(obs, _):
    """Kaggle agent entry point: return the highest-scoring move for ``obs``.

    Args:
        obs: environment observation mapping ("geese", "food", "index"; Kaggle's
            Hungry Geese observations also carry "step").
        _: environment configuration (unused).

    Returns:
        One of "NORTH", "SOUTH", "WEST", "EAST".
    """
    # `obses` is module-level state. If the same agent is reused for another
    # game, start a fresh history on step 0 so the "previous head" planes never
    # come from the last game's final observation.
    if obs.get("step") == 0:
        obses.clear()
    obses.append(obs)
    x = make_input(obses)
    with torch.no_grad():
        xt = torch.from_numpy(x).unsqueeze(0)
        o = model(xt)
    p = o["policy"].squeeze(0).detach().numpy()

    # Order matches the policy head's direction order.
    actions = ["NORTH", "SOUTH", "WEST", "EAST"]
    return actions[np.argmax(p)]

In [None]:
# Embed the encoded weights (PARAM, from the first cell) into submission.py by
# replacing its placeholder bytes literal. Only that exact literal is replaced,
# and only once. If it is missing (e.g. this cell already ran after the last
# %%writefile), fail loudly instead of silently leaving the file unchanged.
placeholder = 'b"xxxxxxxxxx"'

with open("submission.py") as file:
    filedata = file.read()

if placeholder not in filedata:
    raise RuntimeError("Placeholder not found in submission.py; re-run the %%writefile cell first.")

filedata = filedata.replace(placeholder, 'b"' + PARAM.decode("utf-8") + '"', 1)

with open("submission.py", "w") as file:
    file.write(filedata)

In [None]:
from kaggle_environments import make

# Self-play smoke test: four copies of the submission play one game.
env = make("hungry_geese", debug=True)
env.reset()
env.run(["submission.py"] * 4)
env.render(mode="ipython", width=800, height=700)