In [1]:
import importlib
from collections import defaultdict, deque

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from kaggle_environments import make

%matplotlib inline

Loading environment football failed: No module named 'gfootball'


In [2]:
model_module = importlib.import_module("handyrl.model")
env_module = importlib.import_module("handyrl.envs.kaggle.hungry_geese")

In [3]:
e = env_module.Environment()
e.reset()

In [4]:
# while not e.terminal():
for _ in range(50):
    obs = e.observation()
    actions = {}
    for player in e.turns():
        actions[player] = e.rule_based_action_smart_geese(player)
    e.step(actions)
    rewards = e.reward()
# e.outcome()

In [5]:
obs = e.observation()

In [6]:
NUM_ROW = 7
NUM_COL = 11
CENTER_ROW = NUM_ROW // 2
CENTER_COL = NUM_COL // 2


def to_offset(x):
    row = CENTER_ROW - x // NUM_COL
    col = CENTER_COL - x % NUM_COL
    return row, col


def to_row(offset, x):
    return (x // NUM_COL + offset) % NUM_ROW


def to_col(offset, x):
    return (x + offset) % NUM_COL

In [7]:
def make_input_num_turn_of_fill(obses):
    b = np.zeros((7, 11), dtype=np.float32)
    obs = obses[-1]

    player_goose_head = obs[0]["observation"]["geese"][obs[0]["observation"]["index"]][0]
    o_row, o_col = to_offset(player_goose_head)

    for p, geese in enumerate(obs[0]["observation"]["geese"]):
        # マスが、何ターン後に空くか
        for i, pos in enumerate(geese[::-1]):
            b[to_row(o_row, pos), to_col(o_col, pos)] = i + 1

        # 自分の頭は、逆に進めないので、空くのは1周後
        if (p - obs[0]["observation"]["index"]) % 4 == 0:
            for pos in geese[:1]:
                if b[to_row(o_row, pos), to_col(o_col, pos)] < 4:
                    b[to_row(o_row, pos), to_col(o_col, pos)] = 4

    # previous head position
    # 自分が直前にいた場所も、最短で、一周後
    if len(obses) > 1:
        obs_prev = obses[-2]
        for p, geese in enumerate(obs_prev[0]["observation"]["geese"]):
            if (p - obs[0]["observation"]["index"]) % 4 == 0:
                for pos in geese[:1]:
                    if b[to_row(o_row, pos), to_col(o_col, pos)] < 3:
                        b[to_row(o_row, pos), to_col(o_col, pos)] = 3

    return b

In [8]:
def distance(a, b):
    x = b[0] - a[0]
    y = b[1] - a[1]
    return (x, y), abs(x) + abs(y)


def around(a):
    return [
        ((a[0] - 1) % 7, a[1]),
        ((a[0] + 1) % 7, a[1]),
        (a[0], (a[1] - 1) % 11),
        (a[0], (a[1] + 1) % 11),
    ]


def empty_around_head(field, head, x):
    around_x = []
    for e in around(x):
        print(f"head: {head}, pos: {e}, val: {field[e[0], e[1]]}, dist_head: {distance(head, e)[1]}")
        if field[e[0], e[1]] <= (distance(head, e)[1] + 1):
            around_x.append(e)
    return around_x


def bfs_close_route(field, head, q_border=10):
    q = deque([(1, head)])
    searched = defaultdict(bool)
    num_seed = 0
    while len(q) != 0:
        print(f"num_seed: {num_seed}, len_q: {len(q)}, queue: {q}")
        v = q.popleft()
        searched[v[1]] = True
        num_seed += 1
        edges = [(v[0] + 1, a) for a in empty_around_head(field, head, v[1]) if not searched[a]]
        if len(q) > q_border:
            break
        for edge in edges:
            q.append(edge)
    return num_seed, len(q)

In [9]:
def apply_rule2(b, prob):
    """
    player head = (3, 5)
    ["NORTH", "SOUTH", "WEST", "EAST"]
    """
    north = (2, 5)
    south = (4, 5)
    west = (3, 4)
    east = (3, 6)
    neighbor = [north, south, west, east]

    # 閉域探索
    for i, n in enumerate(neighbor):
        print(f"neighbor: {n}")
        if b[n[0], n[1]] != 0:
            continue
        num_seed, len_q = bfs_close_route(b, n)
        if len_q == 0:
            prob[i] -= 100_000 // (num_seed + 1)

    return prob

In [10]:
demo_obses = [
    [
        {
            "action": "EAST",
            "reward": 1002,
            "info": {},
            "observation": {
                "remainingOverageTime": 60,
                "step": 9,
                "geese": [
                    [39],
                    [66, 55, 44, 45, 46, 35, 24, 25, 26, 27, 28, 29],
                    [8, 7, 73, 72, 71, 70, 69, 58, 47, 48, 59, 60, 61, 50, 51, 52],
                    [1, 2, 3],
                ],
                "food": [29, 0],
                "index": 0,
            },
            "status": "ACTIVE",
        },
        {
            "action": "NORTH",
            "reward": 1001,
            "info": {},
            "observation": {"remainingOverageTime": 60, "index": 1},
            "status": "ACTIVE",
        },
        {
            "action": "EAST",
            "reward": 1003,
            "info": {},
            "observation": {"remainingOverageTime": 60, "index": 2},
            "status": "ACTIVE",
        },
        {
            "action": "EAST",
            "reward": 1003,
            "info": {},
            "observation": {"remainingOverageTime": 60, "index": 3},
            "status": "ACTIVE",
        },
    ],
    [
        {
            "action": "EAST",
            "reward": 1002,
            "info": {},
            "observation": {
                "remainingOverageTime": 60,
                "step": 9,
                "geese": [
                    [38],
                    [66, 67,  56, 45, 34, 23, 24, 13, 14, 15, 26, 27, 28, 29],
                    [8, 7, 73, 72, 71, 70, 69, 58, 47, 48, 59, 60, 61, 50, 51],
                    [1, 2, 3, 4, 5, 6, 17, 18, 19, 20],
                ],
                "food": [29, 0],
                "index": 0,
            },
            "status": "ACTIVE",
        },
        {
            "action": "NORTH",
            "reward": 1001,
            "info": {},
            "observation": {"remainingOverageTime": 60, "index": 1},
            "status": "ACTIVE",
        },
        {
            "action": "EAST",
            "reward": 1003,
            "info": {},
            "observation": {"remainingOverageTime": 60, "index": 2},
            "status": "ACTIVE",
        },
        {
            "action": "EAST",
            "reward": 1003,
            "info": {},
            "observation": {"remainingOverageTime": 60, "index": 3},
            "status": "ACTIVE",
        },
    ],
]

In [11]:
board = make_input_num_turn_of_fill(demo_obses)
board

array([[ 0., 10.,  9.,  8.,  7.,  6.,  5., 14., 15.,  0.,  0.],
       [ 0.,  0.,  7.,  6.,  5.,  0.,  4.,  3.,  2.,  1.,  0.],
       [ 0.,  9.,  8.,  0.,  4.,  3.,  2.,  1.,  0.,  0.,  0.],
       [ 0., 10.,  0.,  0.,  0.,  4.,  3.,  0.,  0.,  0.,  0.],
       [ 0., 11.,  0.,  7.,  6.,  0.,  2.,  1.,  0.,  0.,  0.],
       [ 0., 12.,  0.,  8.,  5.,  4.,  3.,  0.,  0.,  0.,  0.],
       [14., 13.,  0.,  9., 10., 11., 12., 13.,  0.,  0.,  0.]],
      dtype=float32)

In [12]:
PROB = np.array([0, 0, 0, 0])

apply_rule2(board, PROB)

neighbor: (2, 5)
neighbor: (4, 5)
num_seed: 0, len_q: 1, queue: deque([(1, (4, 5))])
head: (4, 5), pos: (3, 5), val: 4.0, dist_head: 1
head: (4, 5), pos: (5, 5), val: 4.0, dist_head: 1
head: (4, 5), pos: (4, 4), val: 6.0, dist_head: 1
head: (4, 5), pos: (4, 6), val: 2.0, dist_head: 1
num_seed: 1, len_q: 1, queue: deque([(2, (4, 6))])
head: (4, 5), pos: (3, 6), val: 3.0, dist_head: 2
head: (4, 5), pos: (5, 6), val: 3.0, dist_head: 2
head: (4, 5), pos: (4, 5), val: 0.0, dist_head: 0
head: (4, 5), pos: (4, 7), val: 1.0, dist_head: 2
num_seed: 2, len_q: 3, queue: deque([(3, (3, 6)), (3, (5, 6)), (3, (4, 7))])
head: (4, 5), pos: (2, 6), val: 2.0, dist_head: 3
head: (4, 5), pos: (4, 6), val: 2.0, dist_head: 1
head: (4, 5), pos: (3, 5), val: 4.0, dist_head: 1
head: (4, 5), pos: (3, 7), val: 0.0, dist_head: 3
num_seed: 3, len_q: 4, queue: deque([(3, (5, 6)), (3, (4, 7)), (4, (2, 6)), (4, (3, 7))])
head: (4, 5), pos: (4, 6), val: 2.0, dist_head: 1
head: (4, 5), pos: (6, 6), val: 12.0, dist_head

array([     0,      0, -12500,      0])