In [9]:
import importlib
from collections import defaultdict, deque

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from kaggle_environments import make

%matplotlib inline

In [10]:
model_module = importlib.import_module("handyrl.model")
env_module = importlib.import_module("handyrl.envs.kaggle.hungry_geese")

In [11]:
e = env_module.Environment()
e.reset()

In [12]:
# while not e.terminal():
for _ in range(50):
    obs = e.observation()
    actions = {}
    for player in e.turns():
        actions[player] = e.rule_based_action_smart_geese(player)
    e.step(actions)
    rewards = e.reward()
# e.outcome()

In [13]:
obs = e.observation()

In [35]:
NUM_ROW = 7
NUM_COL = 11
CENTER_ROW = NUM_ROW // 2
CENTER_COL = NUM_COL // 2


def to_offset(x):
    row = CENTER_ROW - x // NUM_COL
    col = CENTER_COL - x % NUM_COL
    return row, col


def to_row(offset, x):
    return (x // NUM_COL + offset) % NUM_ROW


def to_col(offset, x):
    return (x + offset) % NUM_COL

In [33]:
def make_input_num_turn_of_fill(obses):
    b = np.zeros((7, 11), dtype=np.float32)
    obs = obses[-1]

    player_goose_head = obs[0]["observation"]["geese"][obs[0]["observation"]["index"]][0]
    o_row, o_col = to_offset(player_goose_head)

    for p, geese in enumerate(obs[0]["observation"]["geese"]):
        # マスが、何ターン後に空くか
        for i, pos in enumerate(geese[::-1]):
            b[to_row(o_row, pos), to_col(o_col, pos)] = i + 1

        # 自分の頭は、逆に進めないので、空くのは1周後
        if (p - obs[0]["observation"]["index"]) % 4 == 0:
            for pos in geese[:1]:
                if b[to_row(o_row, pos), to_col(o_col, pos)] < 4:
                    b[to_row(o_row, pos), to_col(o_col, pos)] = 4

    # previous head position
    # 自分が直前にいた場所も、最短で、一周後
    if len(obses) > 1:
        obs_prev = obses[-2]
        for p, geese in enumerate(obs_prev[0]["observation"]["geese"]):
            if (p - obs[0]["observation"]["index"]) % 4 == 0:
                for pos in geese[:1]:
                    if b[to_row(o_row, pos), to_col(o_col, pos)] < 3:
                        b[to_row(o_row, pos), to_col(o_col, pos)] = 3

    return b

In [124]:
def distance(a, b):
    x = b[0] - a[0]
    y = b[1] - a[1]
    return (x, y), abs(x) + abs(y)


def around(a):
    return [
        ((a[0] - 1) % 7, a[1]),
        ((a[0] + 1) % 7, a[1]),
        (a[0], (a[1] - 1) % 11),
        (a[0], (a[1] + 1) % 11),
    ]


def empty_around_head(field, head, x):
    return [e for e in around(x) if field[e[0], e[1]] <= (distance(head, x)[1] + 1)]


def bfs_close_route(field, head, q_border=10):
    q = deque([head])
    searched = defaultdict(bool)
    num_seed = 0
    while len(q) != 0:
        v = q.popleft()
        searched[v] = True
        num_seed += 1
        edges = [a for a in empty_around_head(field, head, v) if not searched[a]]
        print(f"num_seed: {num_seed}, len_q: {len(q)}, queue: {q}")
        if len(q) > q_border:
            break
        for edge in edges:
            q.append(edge)
    return num_seed, len(q)

In [125]:
def apply_rule2(b, prob):
    """
    player head = (3, 5)
    ["NORTH", "SOUTH", "WEST", "EAST"]
    """
    north = (2, 5)
    south = (4, 5)
    west = (3, 4)
    east = (3, 6)
    neighbor = [north, south, west, east]

    # 閉域探索
    for i, n in enumerate(neighbor):
        print(f"neighbor: {n}")
        if b[n[0], n[1]] != 0:
            continue
        num_seed, len_q = bfs_close_route(b, n)
        if len_q == 0:
            prob[i] -= 100_000 // (num_seed + 1)

    return prob

In [129]:
demo_obses = [
    [
        {
            "action": "EAST",
            "reward": 1002,
            "info": {},
            "observation": {
                "remainingOverageTime": 60,
                "step": 9,
                "geese": [
                    [39],
                    [66, 55, 44, 45, 46, 35, 24, 25, 26, 27, 28, 29],
                    [8, 7, 73, 72, 71, 70, 69, 58, 47, 48, 59, 60, 61, 50, 51, 52],
                    [1, 2, 3],
                ],
                "food": [29, 0],
                "index": 0,
            },
            "status": "ACTIVE",
        },
        {
            "action": "NORTH",
            "reward": 1001,
            "info": {},
            "observation": {"remainingOverageTime": 60, "index": 1},
            "status": "ACTIVE",
        },
        {
            "action": "EAST",
            "reward": 1003,
            "info": {},
            "observation": {"remainingOverageTime": 60, "index": 2},
            "status": "ACTIVE",
        },
        {
            "action": "EAST",
            "reward": 1003,
            "info": {},
            "observation": {"remainingOverageTime": 60, "index": 3},
            "status": "ACTIVE",
        },
    ],
    [
        {
            "action": "EAST",
            "reward": 1002,
            "info": {},
            "observation": {
                "remainingOverageTime": 60,
                "step": 9,
                "geese": [
                    [38],
                    [66, 67,  56, 45, 34, 23, 24, 13, 14, 15, 26, 27, 28, 29],
                    [8, 7, 73, 72, 71, 70, 69, 58, 47, 48, 59, 60, 61, 50],
                    [1, 2, 3, 4, 5, 6, 17, 18, 19, 20],
                ],
                "food": [29, 0],
                "index": 0,
            },
            "status": "ACTIVE",
        },
        {
            "action": "NORTH",
            "reward": 1001,
            "info": {},
            "observation": {"remainingOverageTime": 60, "index": 1},
            "status": "ACTIVE",
        },
        {
            "action": "EAST",
            "reward": 1003,
            "info": {},
            "observation": {"remainingOverageTime": 60, "index": 2},
            "status": "ACTIVE",
        },
        {
            "action": "EAST",
            "reward": 1003,
            "info": {},
            "observation": {"remainingOverageTime": 60, "index": 3},
            "status": "ACTIVE",
        },
    ],
]

In [130]:
board = make_input_num_turn_of_fill(demo_obses)
board

array([[ 0., 10.,  9.,  8.,  7.,  6.,  5., 13., 14.,  0.,  0.],
       [ 0.,  0.,  7.,  6.,  5.,  0.,  4.,  3.,  2.,  1.,  0.],
       [ 0.,  9.,  8.,  0.,  4.,  3.,  2.,  1.,  0.,  0.,  0.],
       [ 0., 10.,  0.,  0.,  0.,  4.,  3.,  0.,  0.,  0.,  0.],
       [ 0., 11.,  0.,  6.,  5.,  0.,  1.,  0.,  0.,  0.,  0.],
       [ 0., 12.,  0.,  7.,  4.,  3.,  2.,  0.,  0.,  0.,  0.],
       [14., 13.,  0.,  8.,  9., 10., 11., 12.,  0.,  0.,  0.]],
      dtype=float32)

In [131]:
PROB = np.array([0, 0, 0, 0])

apply_rule2(board, PROB)

neighbor: (2, 5)
neighbor: (4, 5)
num_seed: 1, len_q: 0, queue: deque([])
num_seed: 2, len_q: 0, queue: deque([])
num_seed: 3, len_q: 1, queue: deque([(4, 7)])
num_seed: 4, len_q: 2, queue: deque([(5, 5), (5, 7)])
num_seed: 5, len_q: 4, queue: deque([(5, 7), (3, 7), (5, 7), (4, 8)])
num_seed: 6, len_q: 3, queue: deque([(3, 7), (5, 7), (4, 8)])
num_seed: 7, len_q: 3, queue: deque([(5, 7), (4, 8), (5, 8)])
num_seed: 8, len_q: 5, queue: deque([(4, 8), (5, 8), (2, 7), (3, 6), (3, 8)])
num_seed: 9, len_q: 5, queue: deque([(5, 8), (2, 7), (3, 6), (3, 8), (5, 8)])
num_seed: 10, len_q: 7, queue: deque([(2, 7), (3, 6), (3, 8), (5, 8), (3, 8), (5, 8), (4, 9)])
num_seed: 11, len_q: 8, queue: deque([(3, 6), (3, 8), (5, 8), (3, 8), (5, 8), (4, 9), (6, 8), (5, 9)])
num_seed: 12, len_q: 10, queue: deque([(3, 8), (5, 8), (3, 8), (5, 8), (4, 9), (6, 8), (5, 9), (1, 7), (2, 6), (2, 8)])
num_seed: 13, len_q: 10, queue: deque([(5, 8), (3, 8), (5, 8), (4, 9), (6, 8), (5, 9), (1, 7), (2, 6), (2, 8), (2, 6)]

array([     0,      0, -12500,      0])