In [None]:
from kaggle_environments.envs.hungry_geese.hungry_geese import Observation, Configuration, Action, row_col
from kaggle_environments import evaluate, make, utils

from enum import auto, Enum
from tqdm import tqdm

import numpy as np
import random
import pickle
import json

In [None]:
class Cell(Enum):
    """
    Content of a single board cell as seen by the agent.

    The integer values are the digits written into the encoded state string.
    """
    EMPTY = 0   # nothing on the cell
    FOOD = 1    # a food item
    GOOSE = 2   # any segment of any goose

class QAgent():
    """
    Tabular Q-learning agent for the Hungry Geese environment.

    The state is the content of the (2*FOV+1) x (2*FOV+1) window of cells
    centred on the agent's head (with wrap-around), encoded as a string of
    Cell values. The Q-table maps that string to a list with one Q-value
    per action.
    """

    def __init__(self, alpha, gamma, epsilon):
        """
        alpha   -- learning rate used in the Q-value update
        gamma   -- discount factor applied to the next state's best Q-value
        epsilon -- probability of taking a random action in get_action
        """
        self.q_table = {}
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        # one Q-value per Action member; list index i <-> Action(i + 1)
        self.action_size = 4
        self.prev_action = None
        self.FOV = 2
        self.rows = 7
        self.columns = 11

    def _row_col(self, pos):
        """
        Split a flat, row-major cell index into (row, column)
        """
        # divmod keeps the board geometry independent of the environment package
        return divmod(pos, self.columns)

    def _wrapped_deltas(self, pos_a, pos_b):
        """
        Shortest (row distance, column distance) between two cells on the
        wrapping board
        """
        row_a, col_a = self._row_col(pos_a)
        row_b, col_b = self._row_col(pos_b)
        d_row = abs(row_a - row_b) % self.rows
        d_col = abs(col_a - col_b) % self.columns
        # going around the edge may be shorter than going straight
        return min(d_row, self.rows - d_row), min(d_col, self.columns - d_col)

    def encode_state(self, observation):
        """
        Encode state of size FOV (includes wrapping)

        Returns a string with one digit (a Cell value) per cell of the window
        around the agent's head, scanned row by row.
        """
        player_head = observation.geese[observation.index][0]
        row_0, col_0 = self._row_col(player_head)

        # built once per call; sets give O(1) membership tests in the loop
        goose_cells = {cell for goose in observation.geese for cell in goose}
        food_cells = set(observation.food)

        window = range(-self.FOV, self.FOV + 1)
        digits = []
        for row_delta in window:
            row_i = (row_0 + row_delta) % self.rows
            for col_delta in window:
                col_i = (col_0 + col_delta) % self.columns
                pos = self.columns * row_i + col_i

                # a goose takes precedence over food on the same cell
                if pos in goose_cells:
                    cell = Cell.GOOSE
                elif pos in food_cells:
                    cell = Cell.FOOD
                else:
                    cell = Cell.EMPTY
                digits.append(str(cell.value))

        return "".join(digits)

    def get_action(self, obs, cfg=None):
        """
        Returns action for training using exploration and exploitation

        Unknown states are added to the Q-table with all-zero values. The
        chosen action is remembered in self.prev_action, and an action that
        is the opposite of the previous one is replaced by random retries.
        Returns the action's name.
        """
        state = self.encode_state(obs)

        if state not in self.q_table:
            print("state doesn't exist, choosing random action")
            self.q_table[state] = [0]*self.action_size
            action = random.choice(range(self.action_size))
        elif random.random() < self.epsilon:
            print("exploration")
            action = random.choice(range(self.action_size))
        else:
            print("exploitation")
            action = int(np.argmax(self.q_table[state]))

        # make sure action is not the opposite of the previous action
        if self.prev_action is not None:
            while Action(action+1) == self.prev_action.opposite():
                print("opposite action trying again :(")
                action = random.choice(range(self.action_size))

        self.prev_action = Action(action+1)
        return self.prev_action.name

    def get_action_max(self, obs, cfg=None):
        """
        Returns action using exploitation for testing

        Falls back to a random action for states missing from the Q-table.
        NOTE(review): unlike get_action, this neither reads nor updates
        self.prev_action, so it may return the opposite of the last move.
        """
        state = self.encode_state(obs)

        if state not in self.q_table:
            action = random.choice(range(self.action_size))
        else:
            action = int(np.argmax(self.q_table[state]))

        return Action(action+1).name

    def train(self, experience):
        """
        Apply one Q-learning update.

        experience -- (state, action, next_state, reward) where state and
                      next_state are observations, action is an action name
                      and reward is the training reward for that move.
        """
        state, action, next_state, reward = experience
        state = self.encode_state(state)
        action_idx = Action[action].value - 1

        # An empty goose list has no head to encode (presumably the goose is
        # dead -- confirm against the environment); treat it as a terminal
        # state with no future value instead of failing on the lookup.
        if next_state.geese[next_state.index]:
            q_next = self.q_table.get(self.encode_state(next_state), 0)
        else:
            q_next = 0

        # states not seen by get_action yet start from all-zero values
        q_state = self.q_table.setdefault(state, [0]*self.action_size)

        q_target = reward + self.gamma * np.max(q_next)
        q_temporal_diff = q_target - q_state[action_idx]
        q_state[action_idx] += self.alpha * q_temporal_diff

    def translate_goose(self, pos, action):
        """
        Return the cell index reached by moving one step from pos.

        action is 'NORTH', 'SOUTH' or 'WEST'; any other value moves east.
        Moves off one edge of the board re-enter on the opposite edge.
        """
        row, col = self._row_col(pos)
        # parentheses matter: `row - 1 % rows` is `row - (1 % rows)` and
        # would never wrap
        if action == 'NORTH':
            row = (row - 1) % self.rows
        elif action == 'SOUTH':
            row = (row + 1) % self.rows
        elif action == 'WEST':
            col = (col - 1) % self.columns
        else:
            col = (col + 1) % self.columns

        return self.columns*row + col

    def compute_reward(self, state, action):
        """
        Shaped training reward for taking action in state.

        Returns -10 if the destination cell holds any goose segment, 10 if it
        holds food, 1 if the move reduces the wrap-around distance to a food
        item inside the field of view, and 0 otherwise.
        """
        player_head = state.geese[state.index][0]
        pos = self.translate_goose(player_head, action)

        goose_cells = {cell for goose in state.geese for cell in goose}
        if pos in goose_cells:
            # negative reward for colliding
            return -10
        if pos in state.food:
            # reward for eating food
            return 10

        # reward +1 for moving towards food
        for food in state.food:
            head_rows, head_cols = self._wrapped_deltas(player_head, food)
            if head_rows > self.FOV or head_cols > self.FOV:
                # food outside the window the agent can see
                continue
            new_rows, new_cols = self._wrapped_deltas(pos, food)
            if new_rows + new_cols < head_rows + head_cols:
                return 1

        return 0

    def save_pickle(self, name):
        """
        Write the Q-table to the file `name` using pickle
        """
        with open(name, 'wb') as handle:
            pickle.dump(self.q_table, handle, protocol=pickle.HIGHEST_PROTOCOL)

    def load_pickle(self, name):
        """
        Replace the Q-table with the one stored in the file `name`

        WARNING: unpickling can execute arbitrary code; only load files
        written by save_pickle from a trusted source.
        """
        with open(name, 'rb') as handle:
            self.q_table = pickle.load(handle)

In [None]:
# Create the agent and resume from a previously saved Q-table when one exists.
agent = QAgent(alpha=0.05, gamma=0.8, epsilon=0.1)
try:
    agent.load_pickle("qtable")
except FileNotFoundError:
    # first run: nothing has been saved yet, so start from an empty table
    print("no saved qtable found, starting with an empty one")

In [None]:
# train the qtable
env = make("hungry_geese", debug=True)

trainer = env.train([None, "greedy"]) # modify num goose in training
EPISODES = 500000
episodes = tqdm(range(EPISODES))
try:
    # iterate the tqdm wrapper itself, otherwise the progress bar never moves
    for eps in episodes:
        episode_reward = 0
        # forget the last move of the previous episode so it cannot veto
        # the first move of this one
        agent.prev_action = None
        state = trainer.reset()
        while not env.done:
            action = agent.get_action(state)
            # shaped reward is computed from the state before the move
            train_reward = agent.compute_reward(state, action)
            next_state, reward, done, _ = trainer.step(action)
            episode_reward += reward
            # NOTE(review): the final transition of each episode is skipped,
            # so the move that ends the game is never learned from
            if (env.done):
                break
            agent.train((state, action, next_state, train_reward))
            state = next_state

        print(episode_reward)
finally:
    # save even when training is interrupted, so progress is not lost
    agent.save_pickle("qtable")

In [None]:
# Play one game with the trained Q-table (pure exploitation, no learning)
# against the built-in "greedy" agent and render it in the notebook.
test_env = make("hungry_geese", debug=False)
test_agents = [agent.get_action_max, "greedy"]
test_env.run(test_agents)
test_env.render(mode="ipython", width=500, height=400)