In [1]:
import copy
import importlib
import random
from collections import defaultdict, deque

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from kaggle_environments import make
from kaggle_environments.envs.hungry_geese.hungry_geese import Action, Configuration, Observation, row_col

%matplotlib inline

Loading environment football failed: No module named 'gfootball'


In [2]:
# Resolve the HandyRL modules by dotted name at run time.
# model_module is not referenced again in the visible cells.
model_module = importlib.import_module("handyrl.model")
env_module = importlib.import_module("handyrl.envs.kaggle.hungry_geese")

In [3]:
# Build one environment instance from the module loaded above and reset it
# before any stepping is done.
e = env_module.Environment()
e.reset()

In [4]:
# Advance the environment by a fixed 50 steps; every player listed by
# e.turns() moves according to e.rule_based_action_smart_geese.
# NOTE(review): the run-until-terminal variant below is disabled, so this loop
# never checks whether the episode has already ended -- confirm that stepping
# past the end of an episode is safe for this environment.
# while not e.terminal():
for _ in range(50):
    obs = e.observation()  # rebound on every iteration; not read inside the loop
    actions = {}
    for player in e.turns():
        actions[player] = e.rule_based_action_smart_geese(player)
    e.step(actions)
    rewards = e.reward()  # rebound on every iteration; not read inside the loop
# e.outcome()

In [5]:
# Capture the environment's observation after the rollout loop above.
obs = e.observation()

In [28]:
# Hand-written game settings; wrapped in Configuration(...) further down.
# The helper functions in this notebook read only `columns` and `rows` from it.
demo_config = {
    "actTimeout": 1,
    "columns": 11,  # board width, used for cell index <-> (row, col) conversion
    "episodeSteps": 200,
    "hunger_rate": 40,  # NOTE(review): hunger() does not read this key
    "max_length": 99,
    "min_food": 2,
    "rows": 7,  # board height, used for vertical wrap-around
    "runTimeout": 1200,
}

In [73]:
# Board state at step 177, used below as the "previous" observation so that
# each goose's last move can be derived by comparing head positions.
# Cells are flat board indices (row * columns + col); the first entry of each
# goose list is treated as its head by the helper functions.
demo_last_obs = {
    "food": [18, 53],
    "geese": [
        [3, 2, 13, 12, 11, 21, 32, 43, 33, 22, 23, 24, 25, 14],
        [74, 75, 9, 8, 19, 30, 31, 42, 41, 52, 63],
        [27, 16, 5, 6, 7, 73, 62, 51, 40, 29, 28],
        [71, 70, 4, 15, 26, 37, 48, 49, 38, 39, 50, 61, 72],
    ],
    "index": 0,  # the goose that simulateMatches scores and chooses a move for
    "remainingOverageTime": 20.105263000000026,
    "step": 177,
}

In [74]:
# Board state at step 178, one step after demo_last_obs: every goose has a new
# head cell in front and has dropped its former last cell; the food is unchanged.
demo_obs = {
    "food": [18, 53],
    "geese": [
        [14, 3, 2, 13, 12, 11, 21, 32, 43, 33, 22, 23, 24, 25],
        [63, 74, 75, 9, 8, 19, 30, 31, 42, 41, 52],
        [28, 27, 16, 5, 6, 7, 73, 62, 51, 40, 29],
        [72, 71, 70, 4, 15, 26, 37, 48, 49, 38, 39, 50, 61],
    ],
    "index": 0,  # the goose that simulateMatches scores and chooses a move for
    "remainingOverageTime": 20.020537000000026,
    "step": 178,
}

In [75]:
# Each action mapped to the action pointing the other way; getValidDirections
# uses it to reject a move that reverses the goose's last action.
opposites = {Action.EAST: Action.WEST, Action.WEST: Action.EAST, Action.NORTH: Action.SOUTH, Action.SOUTH: Action.NORTH}
# Action -> (column delta, row delta); read by getRowColForAction.
action_meanings = {Action.EAST: (1, 0), Action.WEST: (-1, 0), Action.NORTH: (0, -1), Action.SOUTH: (0, 1)}
# (column delta, row delta) -> Action, including deltas produced by wrapping
# around an 11-column / 7-row board. setLastAction reduces both deltas with
# `%` before the lookup, so with positive board sizes only the non-negative
# keys are ever requested there.
action_names = {
    (1, 0): Action.EAST,
    (-10, 0): Action.EAST,
    (-1, 0): Action.WEST,
    (10, 0): Action.WEST,
    (0, -1): Action.NORTH,
    (0, 6): Action.NORTH,
    (0, -6): Action.SOUTH,
    (0, 1): Action.SOUTH,
}
# Action -> display name; not referenced in the visible cells.
strValue = {Action.EAST: "EAST", Action.WEST: "WEST", Action.NORTH: "NORTH", Action.SOUTH: "SOUTH"}

# Mutable module-level state shared by the functions below.
frame = 0  # current game step; 0 makes setLastActions skip the per-goose update
all_last_actions = [None, None, None, None]  # last action per goose, None if unknown
revert_last_actions = [None, None, None, None]  # snapshot restored by revertLastActions
last_observation = None  # previous board state, read by setLastAction

In [76]:
class Obs:
    """Bare attribute container; cloneObservation attaches index, geese and food to instances."""

In [77]:
def setLastActions(observation, configuration):
    """Refresh the recorded last action of all four geese, then snapshot the result.

    While the module-level ``frame`` is 0 the per-goose refresh is skipped.
    The snapshot is taken in every case and stored in ``revert_last_actions``
    so that ``revertLastActions`` can restore it later.

    Args:
        observation: current board state, forwarded to ``setLastAction``.
        configuration: board settings, forwarded to ``setLastAction``.
    """
    global revert_last_actions
    if frame != 0:
        for goose_idx in range(4):
            setLastAction(observation, configuration, goose_idx)
    # Deep copy so later edits to all_last_actions cannot leak into the snapshot.
    revert_last_actions = copy.deepcopy(all_last_actions)

In [78]:
def revertLastActions():
    """Restore ``all_last_actions`` from the snapshot kept in ``revert_last_actions``."""
    global all_last_actions
    snapshot = revert_last_actions
    # Rebind to a fresh deep copy so the snapshot itself stays untouched.
    all_last_actions = copy.deepcopy(snapshot)

In [79]:
def setLastAction(observation, configuration, gooseIndex):
    """Record which action moved one goose from its previous head cell to its current one.

    Nothing is recorded when the goose has no cells in ``observation``.
    Otherwise the head cell stored in the module-level ``last_observation`` is
    compared with the current head, both deltas are reduced modulo the board
    size, and the matching entry of ``action_names`` is written into
    ``all_last_actions[gooseIndex]``.

    Args:
        observation: current board state exposing ``geese``.
        configuration: board settings exposing ``columns`` and ``rows``.
        gooseIndex: position of the goose in ``observation.geese``.

    Raises:
        KeyError: if the reduced (column, row) delta is not a key of ``action_names``.
    """
    body = observation.geese[gooseIndex]
    if len(body) == 0:
        return
    prev_row, prev_col = row_col(last_observation.geese[gooseIndex][0], configuration.columns)
    cur_row, cur_col = row_col(body[0], configuration.columns)
    col_shift = (cur_col - prev_col) % configuration.columns
    row_shift = (cur_row - prev_row) % configuration.rows
    all_last_actions[gooseIndex] = action_names[(col_shift, row_shift)]

In [80]:
def getValidDirections(observation, configuration, gooseIndex):
    """List the actions a goose may take without an obvious collision or a reversal.

    An action is kept when ``willGooseBeThere`` reports its target cell as free
    and it is not the opposite of the goose's recorded last action.

    Args:
        observation: board state exposing ``geese``.
        configuration: board settings exposing ``columns`` and ``rows``.
        gooseIndex: position of the goose in ``observation.geese``.

    Returns:
        The kept actions in EAST, WEST, NORTH, SOUTH order, or all four
        actions when none passes the filter.
    """
    candidates = [Action.EAST, Action.WEST, Action.NORTH, Action.SOUTH]

    def _is_allowed(move):
        row, col = getRowColForAction(observation, configuration, gooseIndex, move)
        if willGooseBeThere(observation, configuration, row, col):
            return False
        return all_last_actions[gooseIndex] != opposites[move]

    allowed = [move for move in candidates if _is_allowed(move)]
    # With nothing left, fall back to every direction instead of an empty list.
    return allowed if allowed else candidates

In [81]:
def randomTurn(observation, configuration, actionOverrides, rewards, fr):
    """Simulate one step for up to four geese and return the resulting board state.

    Geese with no cells are skipped. A goose whose index is a key of
    ``actionOverrides`` performs that action; every other goose takes a random
    valid action. Moves are decided from ``observation`` and applied to a
    clone of it. Collisions, rewards and hunger are then processed, in that
    order, on the clone.

    Args:
        observation: board state before the step; left unmodified.
        configuration: board settings.
        actionOverrides: dict mapping goose index to a forced action.
        rewards: list updated in place by ``updateRewards``.
        fr: frame number of the simulated step.

    Returns:
        The new board state.
    """
    successor = cloneObservation(observation)
    for goose_idx in range(4):
        if not len(observation.geese[goose_idx]) > 0:
            continue
        if goose_idx in actionOverrides:
            successor = performActionForGoose(
                observation, configuration, goose_idx, successor, actionOverrides[goose_idx]
            )
        else:
            successor = randomActionForGoose(observation, configuration, goose_idx, successor)

    checkForCollisions(successor, configuration)
    updateRewards(successor, configuration, rewards, fr)
    hunger(successor, fr)
    return successor

In [82]:
def hunger(observation, fr, hunger_rate=40):
    """Shorten every goose by one cell on frames that are a multiple of ``hunger_rate``.

    On any other frame nothing happens. A goose with a single cell ends up
    with an empty list; an already-empty goose stays empty.

    Args:
        observation: board state whose ``geese`` list is updated in place.
        fr: frame number of the step being simulated.
        hunger_rate: interval, in frames, between shrink events. Defaults to
            40, the interval that was previously hard-coded.
    """
    if fr % hunger_rate != 0:
        return
    for g, goose in enumerate(observation.geese):
        # Write back through the index: rebinding the loop variable alone
        # would leave observation.geese unchanged.
        observation.geese[g] = goose[:-1]

In [83]:
def updateRewards(observation, configuration, rewards, fr):
    """Overwrite the reward of every goose that still has cells.

    The reward is ``100 * fr`` plus the goose's length. Entries belonging to
    geese with no cells are left as they are.

    Args:
        observation: board state exposing ``geese``.
        configuration: unused; kept for a uniform helper signature.
        rewards: list indexed by goose, updated in place.
        fr: frame number of the simulated step.
    """
    frame_bonus = 100 * fr
    for slot, body in enumerate(observation.geese):
        size = len(body)
        if size > 0:
            rewards[slot] = frame_bonus + size

In [84]:
def checkForCollisions(observation, configuration):
    """Empty every goose whose head shares a cell with any other goose segment.

    A head collides when its cell equals any cell of any goose other than the
    head's own slot. That covers the goose's own body as well as the bodies
    and heads of the others. All collisions are found first and applied
    afterwards, so two heads meeting on one cell remove both geese.

    Args:
        observation: board state whose ``geese`` list is updated in place.
        configuration: unused; kept for a uniform helper signature.
    """
    flock = observation.geese
    doomed = []
    for idx, body in enumerate(flock):
        if not len(body) > 0:
            continue
        head = body[0]
        collided = any(
            cell == head
            for other_idx, other in enumerate(flock)
            for pos, cell in enumerate(other)
            if (other_idx, pos) != (idx, 0)
        )
        if collided:
            doomed.append(idx)

    for idx in doomed:
        flock[idx] = []

In [85]:
def cloneObservation(observation):
    """Return an independent ``Obs`` copy of an observation.

    ``index`` is carried over as-is, while ``geese`` and ``food`` are deep
    copied so that edits to the clone cannot affect the source. No other
    attribute of ``observation`` is copied.
    """
    duplicate = Obs()
    duplicate.index = observation.index
    for field in ("geese", "food"):
        setattr(duplicate, field, copy.deepcopy(getattr(observation, field)))
    return duplicate

In [86]:
def randomActionForGoose(observation, configuration, gooseIndex, newObservation):
    """Move one goose in a uniformly random valid direction.

    The direction is drawn from ``getValidDirections`` and then applied by
    ``performActionForGoose``, so random and forced moves share one
    implementation of the "add a head, drop the tail unless food was eaten"
    rule instead of keeping two copies of it.

    Args:
        observation: board state before the step; used to pick and evaluate the move.
        configuration: board settings.
        gooseIndex: position of the goose in ``observation.geese``.
        newObservation: board state being built for the next step; the goose's
            cell list in it is updated (the new head is inserted in place).

    Returns:
        ``newObservation``, after the move has been applied.
    """
    validActions = getValidDirections(observation, configuration, gooseIndex)
    action = random.choice(validActions)
    return performActionForGoose(observation, configuration, gooseIndex, newObservation, action)

In [87]:
def performActionForGoose(observation, configuration, gooseIndex, newObservation, action):
    """Apply a specific action to one goose in ``newObservation``.

    The target cell is computed from the goose's head in ``observation`` and
    inserted at the front of the goose's cell list in ``newObservation``.
    Unless ``observation`` has food on that cell, the last cell is dropped so
    the length stays the same.

    Args:
        observation: board state before the step.
        configuration: board settings exposing ``columns``.
        gooseIndex: position of the goose in both observations.
        newObservation: board state being built for the next step.
        action: the action to perform.

    Returns:
        ``newObservation``, after the move has been applied.
    """
    row, col = getRowColForAction(observation, configuration, gooseIndex, action)
    head_cell = row * configuration.columns + col
    body = newObservation.geese[gooseIndex]
    body.insert(0, head_cell)
    ate_food = isFoodThere(observation, configuration, row, col)
    if not ate_food:
        newObservation.geese[gooseIndex] = body[:-1]
    return newObservation

In [88]:
def isFoodThere(observation, configuration, row, col):
    """Report whether any food item of ``observation`` lies on cell (row, col).

    Args:
        observation: board state exposing ``food`` as flat cell indices.
        configuration: board settings exposing ``columns``.
        row: row of the cell to test.
        col: column of the cell to test.

    Returns:
        True if at least one food item is on that cell, otherwise False.
    """
    wanted = (row, col)
    return any(tuple(row_col(cell, configuration.columns)) == wanted for cell in observation.food)

In [89]:
def willGooseBeThere(observation, configuration, row, col):
    """Report whether cell (row, col) is covered by any goose segment except a tail.

    The last cell of every goose is ignored, so a goose with a single cell
    never blocks anything.

    Args:
        observation: board state exposing ``geese`` as lists of flat cell indices.
        configuration: board settings exposing ``columns``.
        row: row of the cell to test.
        col: column of the cell to test.

    Returns:
        True if a non-tail segment lies on that cell, otherwise False.
    """
    for body in observation.geese:
        # body[:-1] is every segment but the last one.
        for segment in body[:-1]:
            seg_row, seg_col = row_col(segment, configuration.columns)
            if seg_row == row and seg_col == col:
                return True
    return False

In [90]:
def getRowColForAction(observation, configuration, gooseIndex, action):
    """Compute the cell a goose's head would reach by taking ``action``.

    The (column, row) offset comes from ``action_meanings``. Both coordinates
    wrap around the board edges via modulo.

    Args:
        observation: board state exposing ``geese``.
        configuration: board settings exposing ``columns`` and ``rows``.
        gooseIndex: position of the goose in ``observation.geese``.
        action: the action to evaluate.

    Returns:
        A ``(row, col)`` tuple for the target cell.
    """
    head_cell = observation.geese[gooseIndex][0]
    head_row, head_col = row_col(head_cell, configuration.columns)
    col_step, row_step = action_meanings[action]
    target_row = (head_row + row_step) % configuration.rows
    target_col = (head_col + col_step) % configuration.columns
    return target_row, target_col

In [91]:
def simulateMatch(observation, configuration, firstMove, depth):
    """Play one random rollout that starts with ``firstMove`` for the observed goose.

    The recorded last actions are restored from their snapshot first. The goose
    at ``observation.index`` is forced to play ``firstMove`` on the first
    simulated step only; every other move in the rollout is random. Simulated
    frame numbers start at the module-level ``frame`` plus one.

    Args:
        observation: starting board state; left unmodified.
        configuration: board settings.
        firstMove: action forced for the observed goose on the first step.
        depth: number of steps to simulate.

    Returns:
        A four-element reward list as last written by ``updateRewards``.
    """
    forced = {observation.index: firstMove}
    revertLastActions()
    tick = frame + 1
    state = cloneObservation(observation)
    rewards = [0, 0, 0, 0]
    remaining = depth
    while remaining > 0:
        state = randomTurn(state, configuration, forced, rewards, tick)
        # The forced move applies to the first simulated step only.
        forced = {}
        tick += 1
        remaining -= 1
    return rewards

In [92]:
def simulateMatches(observation, configuration, numMatches, depth):
    """Pick the move whose random rollouts score best for the observed goose.

    For every valid direction of the goose at ``observation.index``,
    ``numMatches`` rollouts of ``depth`` steps are played and their rewards
    summed per goose. An option's score is the observed goose's summed reward
    divided by the mean summed reward over all four geese, or 0 when that mean
    is 0. Diagnostics are printed before returning.

    Args:
        observation: current board state.
        configuration: board settings.
        numMatches: number of rollouts per candidate move.
        depth: number of steps per rollout.

    Returns:
        The candidate action with the highest score (the first one on ties).
    """
    options = getValidDirections(observation, configuration, observation.index)

    rewardTotals = []
    for option in options:
        totals = [0, 0, 0, 0]
        for _ in range(numMatches):
            outcome = simulateMatch(observation, configuration, option, depth)
            totals = [running + outcome[slot] for slot, running in enumerate(totals)]
        rewardTotals.append(totals)

    scores = []
    for totals in rewardTotals:
        mean = sum(totals) / len(totals) if len(totals) > 0 else 0
        scores.append(0 if mean == 0 else totals[observation.index] / mean)

    print("frame: ", frame)
    print("options: ", options)
    print("scores: ", scores)
    print("reward totals: ", rewardTotals)
    print("lengths: ")
    for slot, label in enumerate(("0: ", "1: ", "2: ", "3: ")):
        print(label, len(observation.geese[slot]))

    best = max(range(len(scores)), key=lambda k: scores[k])
    return options[best]

In [93]:
# Wrap the plain settings dict so the helpers can use attribute access
# (configuration.columns, configuration.rows).
configuration = Configuration(demo_config)

In [94]:
# Start from the step-177 board; it is only used to seed last_observation below.
observation = Observation(demo_last_obs)

In [95]:
# Keep an independent copy of the step-177 board as the "previous" state that
# setLastAction compares against.
last_observation = cloneObservation(observation)

In [96]:
# Matches demo_obs["step"]; being non-zero, it also makes setLastActions run
# its per-goose update.
frame = 178

In [97]:
# Rebind observation to the step-178 board, the state a move is chosen for.
observation = Observation(demo_obs)

In [98]:
# Derive each goose's last action from the step-177 -> step-178 head movement
# and snapshot the result for revertLastActions.
setLastActions(observation, configuration)

In [99]:
# Length of the observed goose; not referenced again in the visible cells.
myLength = len(observation.geese[observation.index])

In [100]:
# Play 10 random rollouts of 5 steps for each candidate move; the chosen
# action is shown as the cell result.
simulateMatches(observation, configuration, 10, 5)

frame:  178
options:  [<Action.SOUTH: 3>]
scores:  [1.0046031957408428]
reward totals:  [[182940, 182919, 181118, 181430]]
lengths: 
0:  14
1:  11
2:  11
3:  13


<Action.SOUTH: 3>