In [1]:
from kaggle_environments import make
from kaggle_environments.envs.hungry_geese.hungry_geese import Observation, Configuration, Action, row_col

import numpy as np

Loading environment football failed: No module named 'gfootball'


In [55]:
def geese_heads(obs_dict, config_dict):
    """
    Return the (row, column) position of every goose's head.

    Parameters
    ----------
    obs_dict : observation for one agent; only its 'geese' list is read.
    config_dict : environment configuration; only the board width is used.

    Returns
    -------
    list of (row, column) tuples, one per goose in observation order, with
    (None, None) for a goose that has no segments left.
    """
    configuration = Configuration(config_dict)
    observation = Observation(obs_dict)

    positions = []
    for goose in observation.geese:
        if goose:
            row, column = row_col(goose[0], configuration.columns)
        else:
            # Dead goose: no head on the board.
            row, column = None, None
        positions.append((row, column))
    return positions

def get_last_actions(previous_geese_heads, heads_positions, *, rows=7, columns=11):
    """
    Infer the move each goose made between two consecutive steps.

    Parameters
    ----------
    previous_geese_heads : list of (row, column) head positions from the
        previous step (as returned by geese_heads); empty on the first step.
    heads_positions : list of (row, column) head positions for the current
        step; (None, None) marks a goose with no segments.
    rows, columns : board size, used to recognise moves that wrap around an
        edge (defaults match the standard 7x11 board).

    Returns
    -------
    list with one entry per goose: an action name (Action.<X>.name), or the
    string "None" when the move cannot be inferred (first step, dead goose,
    or a head displacement that is not a single step).
    """
    if not previous_geese_heads:
        return ["None"] * len(heads_positions)

    def get_last_action(prev, cur):
        prev_row, prev_col = prev
        cur_row, cur_col = cur
        if cur_row is None or prev_row is None:
            return "None"
        # Differences modulo the board size, so a wrap-around move
        # (e.g. row rows-1 -> row 0) counts as a single step.
        row_step = (cur_row - prev_row) % rows
        col_step = (cur_col - prev_col) % columns
        if row_step == 1:
            return Action.SOUTH.name
        if row_step == rows - 1:
            return Action.NORTH.name
        if col_step == 1:
            return Action.EAST.name
        if col_step == columns - 1:
            return Action.WEST.name
        return "None"

    return [get_last_action(prev, cur) for prev, cur in zip(previous_geese_heads, heads_positions)]
    
def central_state_space(obs_dict, config_dict, last_actions):
    """
    Rebuild the board centred on the player's head and derive features from it.

    The board is shifted with wrap-around so the player's head lands on cell
    (rows // 2, columns // 2), i.e. (3, 5) on the default 7x11 board.

    Cell codes on the returned board:
        0      empty
        1-4    a goose head whose last move was WEST, EAST, NORTH, SOUTH
        16     a goose head whose last move is unknown ("None" or unrecognised)
        5-14   body segments: 5 is the tail, increasing toward the head,
               capped at 14
        15     food (only the first two food items are drawn)

    Parameters
    ----------
    obs_dict : observation of the player ('index', 'geese', 'food').
    config_dict : environment configuration (board 'rows' and 'columns').
    last_actions : one entry per goose, an action name or "None"
        (e.g. the output of get_last_actions).

    Returns
    -------
    (board, player_length, longuest_opponent,
     food1_row_feat, food1_col_feat, food2_row_feat, food2_col_feat)
    where each food feature is the food's signed offset from the head on the
    centred board, divided by 5.
    """
    last_actions_dict = {
        Action.WEST.name: 1,
        Action.EAST.name: 2,
        Action.NORTH.name: 3,
        Action.SOUTH.name: 4,
        "None": 16,
    }
    unknown_head = last_actions_dict["None"]

    configuration = Configuration(config_dict)
    observation = Observation(obs_dict)
    rows, columns = configuration.rows, configuration.columns
    center_row, center_col = rows // 2, columns // 2

    player_index = observation.index
    player_goose = observation.geese[player_index]
    player_row, player_column = row_col(player_goose[0], columns)
    longuest_opponent = max(
        (len(goose) for i, goose in enumerate(observation.geese) if i != player_index),
        default=0,
    )

    def centralize(position):
        # Shift a flat board position so the player's head sits at the centre.
        # Python's % is always non-negative, which handles the wrap-around.
        row, col = row_col(position, columns)
        return (row - player_row + center_row) % rows, (col - player_column + center_col) % columns

    foods = observation['food']
    food1_row, food1_column = centralize(foods[0])
    food2_row, food2_column = centralize(foods[1])

    food1_row_feat = float(food1_row - center_row) / 5
    food2_row_feat = float(food2_row - center_row) / 5
    food1_col_feat = float(food1_column - center_col) / 5
    food2_col_feat = float(food2_column - center_col) / 5

    # Create the grid and add food first; geese are drawn on top.
    board = np.zeros([rows, columns])
    board[food1_row, food1_column] = 15
    board[food2_row, food2_column] = 15

    for goose_id, goose in enumerate(observation.geese):
        if not goose:
            continue
        head_code = last_actions_dict.get(last_actions[goose_id], unknown_head)
        # Walk from the tail (i == 0) to the head (last element of reversed list).
        for i, position in enumerate(reversed(goose)):
            is_head = i == len(goose) - 1
            board[centralize(position)] = head_code if is_head else min(i + 5, 14)

    return board, len(player_goose), longuest_opponent, food1_row_feat, food1_col_feat, food2_row_feat, food2_col_feat

In [60]:
class RuleBasedAgent:
    """
    Rule based agent -
    We will use this rule-based agent to collect state-space data and the actions to take.
    An initial neural network will be trained to learn this rule-based policy.
    The neural network will then be improved using RL methods.

    Each call scores the four moves on the head-centred board built by
    central_state_space and returns the highest-scoring one (ties go to the
    earliest entry of _ACTION_ORDER). Scores are sums of decreasing powers of
    ten (see _score_action), so higher-priority rules normally dominate;
    a hunger boost can let the food terms outweigh the two lowest safety rules.
    """

    # central_state_space places our head on this cell of the 7x11 board.
    _CENTER = (3, 5)
    # Action name (same strings as Action.<X>.name) -> (row, column) step.
    _OFFSETS = {
        "NORTH": (-1, 0),
        "SOUTH": (1, 0),
        "WEST": (0, -1),
        "EAST": (0, 1),
    }
    _OPPOSITES = {"NORTH": "SOUTH", "SOUTH": "NORTH", "WEST": "EAST", "EAST": "WEST"}
    # Code central_state_space writes for a head whose last move was this action.
    _HEAD_CODES = {"WEST": 1, "EAST": 2, "NORTH": 3, "SOUTH": 4}
    _FOOD_CODE = 15
    # Candidate order; on equal scores np.argmax picks the first.
    _ACTION_ORDER = ("SOUTH", "NORTH", "EAST", "WEST")
    # Period of the 'hunger' feature and of the hunger boost.
    _HUNGER_RATE = 40

    def __init__(self):
        self.last_action = None
        self.last_heads_positions = []
        self.stateSpace = None

    def getStateSpace(self, obs_dict, config_dict):
        """
        Build the feature dict for the current observation.

        Returns (cur_obs, heads_positions, last_actions). cur_obs holds the
        scaled food offsets, our length, the longest opponent length, the
        head-centred board and the scaled hunger/step features.
        """
        heads_positions = geese_heads(obs_dict, config_dict)
        last_actions = get_last_actions(self.last_heads_positions, heads_positions)

        (board, player_goose_len, longuest_opponent,
         food1_row_feat, food1_col_feat, food2_row_feat, food2_col_feat) = central_state_space(
            obs_dict, config_dict, last_actions)

        cur_obs = {
            'food1_col': food1_col_feat,
            'food2_col': food2_col_feat,
            'food1_row': food1_row_feat,
            'food2_row': food2_row_feat,
            'goose_size': player_goose_len,
            'longuest_opponent': longuest_opponent,
            'board': board,
            # Position within the hunger period, scaled to [-1, 1).
            'hunger': -1 + (float(obs_dict['step'] % self._HUNGER_RATE) / 20),
            'step': (float(obs_dict['step']) / 100) - 1,
        }
        return cur_obs, heads_positions, last_actions

    @staticmethod
    def _hunger_boost(goose_len, step, hunger_rate=40):
        """
        Multiplier for the food terms: 100 when a one-segment goose has fewer
        than 3 steps left before step % hunger_rate wraps to 0, 10 when fewer
        than 6, otherwise 1.
        """
        if goose_len != 1:
            return 1
        steps_left = hunger_rate - step % hunger_rate
        # Check the tighter threshold first; otherwise it is unreachable.
        if steps_left < 3:
            return 100
        if steps_left < 6:
            return 10
        return 1

    def _score_action(self, board, action_name, nearest_food, other_food, hunger_boost):
        """
        Score one move on the head-centred board; higher is better.

        nearest_food / other_food are (row_feat, col_feat) offsets of food
        relative to our head. Rules, highest priority first:
            1e7  the move does not reverse our last move
            1e6  the target cell is empty or food
            1e5  not every other neighbour of the target is a body segment
                 (6-14), i.e. the target is not a dead end
            1e4  not every other neighbour of the target is a head (1-4)
            1e3  no head (1-4) two cells beyond the target (straight ahead
                 or diagonally ahead)
            1e2  every other neighbour of the target is empty, food, or a head
                 whose last move points away from the target (it cannot
                 reverse into it)
            10 * hunger_boost  the move heads toward the nearest food
            1 * hunger_boost   the move heads toward the other food
        """
        d_row, d_col = self._OFFSETS[action_name]
        opposite = self._OPPOSITES[action_name]
        t_row, t_col = self._CENTER[0] + d_row, self._CENTER[1] + d_col

        # Neighbours of the target other than our own head, keyed by the
        # direction from the target to that neighbour.
        neighbours = {
            name: board[t_row + r, t_col + c]
            for name, (r, c) in self._OFFSETS.items()
            if name != opposite
        }
        # Cells two steps from the target: straight ahead and diagonally ahead.
        further = [
            board[t_row + d_row + r, t_col + d_col + c]
            for name, (r, c) in self._OFFSETS.items()
            if name != opposite
        ]

        def is_body(value):
            return 6 <= value <= 14

        def is_head(value):
            return 1 <= value <= 4

        sign = d_row or d_col

        def heads_toward(food):
            row_feat, col_feat = food
            return (row_feat if d_row else col_feat) * sign > 0

        score = 0.0
        if self.last_action != opposite:
            score += 1e7
        if board[t_row, t_col] in (0, self._FOOD_CODE):
            score += 1e6
        if not all(is_body(v) for v in neighbours.values()):
            score += 1e5
        # NOTE(review): this only fires when all three neighbours are heads,
        # which is almost never; an "any head" test may have been intended.
        if not all(is_head(v) for v in neighbours.values()):
            score += 1e4
        if not any(is_head(v) for v in further):
            score += 1e3
        if all(v in (0, self._HEAD_CODES[name], self._FOOD_CODE) for name, v in neighbours.items()):
            score += 1e2
        if heads_toward(nearest_food):
            score += 1e1 * hunger_boost
        if heads_toward(other_food):
            score += 1e0 * hunger_boost
        return score

    def __call__(self, obs_dict, config_dict):
        """
        Choose a move for the current observation.

        Stores the feature dict in self.stateSpace, remembers the chosen move
        and all head positions for the next call, and returns the move name.
        """
        cur_obs, heads_positions, _ = self.getStateSpace(obs_dict, config_dict)
        self.stateSpace = cur_obs

        food1 = (cur_obs['food1_row'], cur_obs['food1_col'])
        food2 = (cur_obs['food2_row'], cur_obs['food2_col'])
        # Prioritize the food that is closer (Manhattan distance on the scaled
        # offsets; ties go to food 1).
        if abs(food1[0]) + abs(food1[1]) <= abs(food2[0]) + abs(food2[1]):
            nearest_food, other_food = food1, food2
        else:
            nearest_food, other_food = food2, food1

        hunger_boost = self._hunger_boost(cur_obs['goose_size'], obs_dict['step'], self._HUNGER_RATE)
        board = cur_obs['board']
        values = [
            self._score_action(board, name, nearest_food, other_food, hunger_boost)
            for name in self._ACTION_ORDER
        ]
        action = self._ACTION_ORDER[int(np.argmax(values))]

        self.last_action = action
        self.last_heads_positions = heads_positions
        return action

In [61]:
from random import choice
from copy import deepcopy
from kaggle_environments.envs.hungry_geese.hungry_geese import Observation, Configuration, Action, \
                                                                row_col, adjacent_positions, translate, min_distance
class GreedyAgent:
    """
    Greedy agent: each call moves to the neighbouring cell closest to food,
    skipping cells occupied by any goose, cells adjacent to an opponent's
    head, and the reverse of its previous move. If every move is excluded it
    picks a random action.

    Every call also appends a diagnostic dict to self.observations.
    """

    def __init__(self):
        self.last_action = None
        self.observations = []

    def __call__(self, observation: Observation, configuration: Configuration):
        """Return the chosen action name for `observation`."""
        self.configuration = configuration
        rows, columns = configuration.rows, configuration.columns
        board_shape = (rows, columns)

        food = observation.food
        geese = observation.geese

        opponents = [
            goose
            for index, goose in enumerate(geese)
            if index != observation.index and len(goose) > 0
        ]
        # Don't move adjacent to any opponent head.
        head_adjacent_positions = {
            adjacent
            for opponent in opponents
            for adjacent in adjacent_positions(opponent[0], columns, rows)
        }
        # Recorded for diagnostics only; not used to choose the move.
        tail_adjacent_positions = {
            adjacent
            for opponent in opponents
            for adjacent in adjacent_positions(opponent[-1], columns, rows)
        }

        # Heads of every goose still on the board (a one-segment goose is all head).
        heads = [goose[0] for goose in geese if len(goose) > 0]
        # Don't move into any bodies (every segment of every goose).
        bodies = [position for goose in geese for position in goose]
        occupied = set(bodies)

        board_bodies = np.zeros(rows * columns)
        board_heads = np.zeros(rows * columns)
        board_bodies[bodies] = 1
        board_heads[heads] = 1

        # Move to the closest food among the allowed moves.
        position = geese[observation.index][0]
        actions = {
            action: min_distance(new_position, food, columns)
            for action in Action
            for new_position in [translate(position, action, columns, rows)]
            if (
                new_position not in head_adjacent_positions
                and new_position not in occupied
                and (self.last_action is None or action != self.last_action.opposite())
            )
        }
        action = min(actions, key=actions.get) if actions else choice(list(Action))

        self.observations.append({
            'head_adjacent_positions': head_adjacent_positions,
            'bodies': bodies,
            'board_bodies': board_bodies.reshape(board_shape),
            'board_heads': board_heads.reshape(board_shape),
            'tails': tail_adjacent_positions,
            'actions': actions,
            'action': action,
            'last_action': self.last_action,
            'cur_action': action,
        })

        self.last_action = action
        return action.name


# One GreedyAgent per goose index, kept alive across calls so each agent
# remembers its last action.
cached_greedy_agents = {}


def greedy_agent(obs, config):
    """
    Agent function (obs, config) -> action name backed by a cached GreedyAgent.

    GreedyAgent() takes no constructor arguments and its __call__ needs both
    the observation and the configuration, so both are passed per call.
    """
    index = obs["index"]
    if index not in cached_greedy_agents:
        cached_greedy_agents[index] = GreedyAgent()
    return cached_greedy_agents[index](Observation(obs), Configuration(config))

In [62]:
from kaggle_environments import evaluate, make, utils

# Play one hungry_geese game: our rule-based agent against seven built-in
# "greedy" agents, then render the replay inline in the notebook.
env = make("hungry_geese", debug=True)
my_agent = RuleBasedAgent()
env.run([my_agent] + ["greedy"] * 7)
env.render(mode="ipython", width=600, height=650)

Opposite action: (3, <Action.WEST: 4>, <Action.EAST: 2>)
Goose Collision: NORTH
Goose Collision: WEST
Body Hit: (4, <Action.EAST: 2>, 40, [39, 50, 51, 52, 41, 40, 29])
Opposite action: (1, <Action.NORTH: 1>, <Action.SOUTH: 3>)
Opposite action: (2, <Action.NORTH: 1>, <Action.SOUTH: 3>)
Opposite action: (5, <Action.WEST: 4>, <Action.EAST: 2>)


In [63]:
# Display the feature dict the rule-based agent stored on its most recent move.
my_agent.stateSpace

{'food1_col': -0.6,
 'food2_col': -0.2,
 'food1_row': -0.4,
 'food2_row': -0.6,
 'goose_size': 15,
 'longuest_opponent': 13,
 'board': array([[ 0.,  0.,  9.,  8., 15., 14.,  0.,  6.,  0.,  0.,  0.],
        [ 0.,  0., 15.,  7.,  0., 14.,  0.,  5.,  0.,  0.,  0.],
        [ 0.,  0.,  5.,  6.,  0., 14.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  4.,  0.,  0.,  0.,  0.,  0.],
        [13., 14., 14., 14.,  2.,  0.,  0.,  9., 10., 11., 12.],
        [ 0.,  0., 11., 12., 13.,  0.,  0.,  8.,  0.,  0.,  0.],
        [ 0.,  0., 10.,  0., 14., 14.,  0.,  7.,  0.,  0.,  0.]]),
 'hunger': 0.55,
 'step': 0.51}

In [599]:
# Fixed ordering of the four moves used for one-hot targets and model outputs.
actions_list = np.array(['EAST',
                        'WEST',
                        'SOUTH',
                        'NORTH'])


def action_to_target(action):
    """
    One-hot encode an action name following the order of `actions_list`.

    Raises ValueError for a name that is not in `actions_list`, rather than
    silently encoding it as the first action.
    """
    matches = np.flatnonzero(actions_list == action)
    if matches.size == 0:
        raise ValueError(f"unknown action {action!r}; expected one of {actions_list.tolist()}")
    target = np.zeros(len(actions_list))
    target[matches[0]] = 1
    return target


def target_to_action(target):
    """Return the action name with the largest score in `target`."""
    pos = np.argmax(target)
    return actions_list[pos]


def pred_to_action(pred):
    """
    Sample an action name from the probability vector `pred`.

    `pred` is renormalised first so small floating-point drift (e.g. from a
    softmax) does not make np.random.multinomial reject it.
    """
    probs = np.asarray(pred, dtype=float)
    probs = probs / probs.sum()
    # multinomial(1, p) returns a one-hot count vector; argmax gives the index.
    draw = np.random.multinomial(1, probs)
    return actions_list[np.argmax(draw)]

In [600]:
def add_numerical(steps):
    """
    Attach scalar feature vectors to every step, in place.

    Reads step['cur_state'] and writes:
      - step['numerical']: [goose_size, longuest_opponent, hunger, step, gap]
        where both lengths are rescaled as (x - 8) / 16 and
        gap = (goose_size - longuest_opponent) / 10.
      - step['food_position_vector']: [food1_col, food2_col, food1_row, food2_row]

    Returns None.
    """
    length_keys = ('goose_size', 'longuest_opponent')
    scalar_keys = ('goose_size', 'longuest_opponent', 'hunger', 'step')
    food_keys = ('food1_col', 'food2_col', 'food1_row', 'food2_row')

    for record in steps:
        state = record['cur_state']

        features = np.zeros(len(scalar_keys) + 1)
        for slot, key in enumerate(scalar_keys):
            value = state[key]
            features[slot] = (float(value) - 8) / 16 if key in length_keys else value
        features[len(scalar_keys)] = float(state['goose_size'] - state['longuest_opponent']) / 10

        food_vector = np.zeros(len(food_keys))
        food_vector[:] = [state[key] for key in food_keys]

        record['numerical'] = features
        record['food_position_vector'] = food_vector
    return None

In [601]:
def add_embeddings(steps):
    """
    Flatten each step's 7x11 board into a row-major list of 77 Python ints.

    Reads step['cur_state']['board'] and writes step['embeddings'] in place.
    Returns None.
    """
    for step in steps:
        board = np.asarray(step['cur_state']['board'])
        step['embeddings'] = [int(board[row, col]) for row in range(7) for col in range(11)]
    return None

In [602]:
def add_state_value(discount, steps):
    """
    Store each step's discounted return-to-go in step['v'], in place.

    Computed backwards from the last step:
    v_t = reward_t + discount * v_{t+1}, with 0 after the final step.
    Returns None.
    """
    running_return = 0
    for record in reversed(steps):
        running_return = record['reward'] + discount * running_return
        record['v'] = running_return
    return None

In [603]:
def add_next_state(steps):
    """
    Copy the following step's features onto each step, in place.

    For a step whose status is 'ACTIVE' and that has a successor in `steps`,
    sets 'next_embeddings', 'next_food_position_vector' and 'next_numerical'
    from the next step. Otherwise (terminal status, or the last recorded step
    of an episode cut off while still 'ACTIVE') those keys are set to None.
    Returns None.
    """
    last_index = len(steps) - 1
    for i, step in enumerate(steps):
        has_next = step['status'] == 'ACTIVE' and i < last_index
        following = steps[i + 1] if has_next else None
        if following is not None:
            step['next_embeddings'] = following['embeddings']
            step['next_food_position_vector'] = following['food_position_vector']
            step['next_numerical'] = following['numerical']
        else:
            step['next_embeddings'] = None
            step['next_food_position_vector'] = None
            step['next_numerical'] = None
    return None

In [604]:
def process(discount, episodes):
    """
    Run the feature/target passes over every episode, in place.

    Pass order matters: add_next_state reads the 'embeddings',
    'numerical' and 'food_position_vector' keys written by the earlier passes.
    Returns None.
    """
    passes = (
        add_embeddings,
        add_numerical,
        lambda episode_steps: add_state_value(discount, episode_steps),
        add_next_state,
    )
    for episode_steps in episodes:
        for apply_pass in passes:
            apply_pass(episode_steps)
    return None

In [None]:
def training_data(episodes):
    """
    Flatten processed episodes into model-ready arrays.

    Every step must already carry the keys written by `process` plus the
    collection loop: 'action', 'numerical', 'embeddings', 'next_numerical',
    'next_embeddings', 'reward', 'v', 'done'.

    Returns a dict:
      'state'      : list of (n_steps, 1) column arrays, numerical columns
                     first, then one per board cell
      'action'     : list of action names, one per step
      'next_state' : same layout as 'state' for the following step; all zeros
                     for steps without a successor (their next_* is None)
      'y'          : (n_steps, 4) one-hot action targets
      'reward', 'v', 'done' : per-step lists

    Raises ValueError if `episodes` contains no steps.
    """
    steps = [step for episode in episodes for step in episode]
    if not steps:
        raise ValueError("training_data needs at least one step")

    actions = [step['action'] for step in steps]
    targets = np.array([action_to_target(action) for action in actions]).reshape(-1, 4)

    numerical = np.array([step['numerical'] for step in steps])
    embeddings = np.array([step['embeddings'] for step in steps])

    # Steps without a successor get zero features so the arrays stay
    # rectangular; the 'done' flag tells them apart.
    numerical_zeros = np.zeros(numerical.shape[1])
    embedding_zeros = np.zeros(embeddings.shape[1], dtype=embeddings.dtype)
    next_numerical = np.array([
        numerical_zeros if step['next_numerical'] is None else step['next_numerical']
        for step in steps
    ])
    next_embeddings = np.array([
        embedding_zeros if step['next_embeddings'] is None else step['next_embeddings']
        for step in steps
    ])

    def split_columns(array):
        # One (n_steps, 1) array per feature column.
        return [array[:, i].reshape(-1, 1) for i in range(array.shape[1])]

    return {
        'state': split_columns(numerical) + split_columns(embeddings),
        'action': actions,
        'next_state': split_columns(next_numerical) + split_columns(next_embeddings),
        'y': targets,
        'reward': [step['reward'] for step in steps],
        'v': [step['v'] for step in steps],
        'done': [step['done'] for step in steps],
    }

In [683]:
step_reward = 0
dying_reward = 0
winning_reward = 1


def step_200_reward(my_goose, longuest_opponent):
    """Final-step reward: `winning_reward` if our goose is strictly longer than every opponent, else 0."""
    return winning_reward if my_goose > longuest_opponent else 0


def win_game_reward(step, my_goose, longuest_opponent):
    """
    Reward for a game that ends early with our goose still alive.

    Currently a flat `winning_reward`; the arguments are kept so a shaped
    variant such as max(200 - step, winning_reward) can be swapped in.
    """
    return winning_reward


discount = 1

nb_opponents = 1

steps_per_ep = 200
num_episodes = 100000


# Environment reused by the data-collection loop; its configuration object is
# passed to every agent call there.
env = make("hungry_geese", debug=True)
config = env.configuration

In [684]:
import pickle
import os

# Create the output directory up front so a missing folder fails immediately
# instead of after a whole iteration of episodes has been simulated.
os.makedirs('../data', exist_ok=True)

for it in range(100):
    print(f'starting iteration {it}')
    name = f'it_{it}.pkl'
    episodes = []
    for ep in range(num_episodes):
        print('episode number: ', ep)
        steps = []
        my_agent = RuleBasedAgent()
        # Each opponent is rule-based with probability 0.7, greedy otherwise.
        agents = [my_agent] + [
            (RuleBasedAgent() if np.random.rand() < 0.7 else GreedyAgent())
            for _ in range(nb_opponents)
        ]
        state_dict = env.reset(num_agents=nb_opponents + 1)[0]
        observation = state_dict['observation']
        my_goose_ind = observation['index']

        done = False
        for step in range(1, steps_per_ep):
            # Every agent gets its own copy of the shared observation with its
            # own index filled in.
            actions = []
            for i, agent in enumerate(agents):
                obs = deepcopy(observation)
                obs['index'] = i
                actions.append(agent(obs, config))

            # Take the state of agent 0, which is my_agent.
            state_dict = env.step(actions)[0]
            observation = state_dict['observation']
            my_goose_ind = observation['index']

            my_goose_length = len(observation['geese'][my_goose_ind])
            longuest_opponent = max(
                (len(goose) for i, goose in enumerate(observation['geese']) if i != my_goose_ind),
                default=0,
            )

            action = state_dict['action']
            status = state_dict['status']

            if status != "ACTIVE":
                done = True

            if my_goose_length == 0:
                # Our goose died.
                done = True
                reward = dying_reward
            elif (step + 1) == steps_per_ep:
                # Step limit reached.
                reward = step_200_reward(my_goose_length, longuest_opponent)
                done = True
            elif status != "ACTIVE":
                # Game ended early with our goose still alive.
                reward = win_game_reward(step, my_goose_length, longuest_opponent)
            else:
                reward = step_reward

            steps.append({
                'cur_state': my_agent.stateSpace,
                'action': action,
                'reward': reward,
                'new_state': '',  # placeholder; next-state features come from add_next_state
                'status': status,
                'done': done,
            })
            if done:
                break
        episodes.append(steps)
    process(discount, episodes)
    train_data = training_data(episodes)
    with open(f'../data/{name}', 'wb') as f:
        pickle.dump(train_data, f)

starting iteration 0
episode number:  0
Goose Collision: SOUTH
episode number:  1
Goose Collision: EAST
episode number:  2
Goose Collision: NORTH
Goose Collision: SOUTH
episode number:  3
Goose Collision: EAST
episode number:  4
Goose Collision: SOUTH
episode number:  5
Goose Collision: EAST
episode number:  6
Goose Collision: WEST
Goose Collision: NORTH
episode number:  7
Goose Collision: NORTH
episode number:  8
Goose Collision: NORTH
episode number:  9
Goose Starved: Action.NORTH
Goose Starved: Action.NORTH
episode number:  10
Body Hit: (0, <Action.NORTH: 1>, 51, [62, 73, 7, 18, 29, 40, 51, 52, 63])
episode number:  11
Goose Collision: EAST
episode number:  12
episode number:  13
Body Hit: (1, <Action.WEST: 4>, 53, [54, 65, 76, 10, 9, 75, 64, 53, 42, 43, 33, 34, 35, 46, 45, 44, 55, 66, 0])
episode number:  14
episode number:  15
episode number:  16
Goose Starved: Action.WEST
Goose Starved: Action.WEST
episode number:  17
Body Hit: (1, <Action.WEST: 4>, 8, [9, 20, 19, 8, 7, 73, 72, 6

Body Hit: (0, <Action.WEST: 4>, 1, [2, 13, 24, 23, 12, 1, 67, 68, 69, 3, 14, 25, 36])
episode number:  153
Goose Starved: Action.SOUTH
episode number:  154
Goose Collision: EAST
episode number:  155
Body Hit: (0, <Action.NORTH: 1>, 75, [9, 10, 21, 32, 31, 20, 19, 8, 74, 75, 76, 66, 0])
episode number:  156
Goose Collision: WEST
episode number:  157
Goose Collision: WEST
Goose Collision: SOUTH
episode number:  158
Goose Collision: SOUTH
episode number:  159
Body Hit: (0, <Action.EAST: 2>, 56, [55, 65, 54, 44, 45, 56, 67])
episode number:  160
Goose Collision: EAST
episode number:  161
Goose Collision: WEST
episode number:  162
episode number:  163
Goose Starved: Action.WEST
episode number:  164
Goose Collision: NORTH
episode number:  165
Goose Collision: WEST
episode number:  166
Goose Collision: NORTH
episode number:  167
Goose Collision: NORTH
Goose Collision: EAST
episode number:  168
Goose Collision: EAST
Goose Collision: SOUTH
episode number:  169
Goose Collision: WEST
episode numb

Body Hit: (0, <Action.EAST: 2>, 67, [66, 55, 44, 33, 34, 45, 56, 67, 68, 2])
episode number:  309
Goose Collision: NORTH
episode number:  310
Body Hit: (0, <Action.SOUTH: 3>, 64, [53, 42, 31, 32, 43, 54, 65, 64, 63, 52])
episode number:  311
episode number:  312
Goose Collision: SOUTH
episode number:  313
Goose Collision: SOUTH
episode number:  314
Goose Collision: SOUTH
episode number:  315
Goose Collision: NORTH
episode number:  316
Goose Collision: NORTH
episode number:  317
Goose Collision: NORTH
Goose Collision: WEST
episode number:  318
Goose Collision: SOUTH
episode number:  319
episode number:  320
episode number:  321
Goose Collision: EAST
episode number:  322
episode number:  323
Goose Collision: NORTH
episode number:  324
Goose Collision: NORTH
episode number:  325
Body Hit: (1, <Action.WEST: 4>, 34, [35, 24, 23, 34, 33, 44, 54, 53, 42, 31, 32, 21, 10, 0])
episode number:  326
Goose Collision: EAST
episode number:  327
Body Hit: (0, <Action.SOUTH: 3>, 58, [47, 48, 59, 58, 57

Body Hit: (0, <Action.WEST: 4>, 31, [32, 43, 54, 53, 42, 31, 30, 29, 40, 51, 62, 61, 60, 59])
episode number:  456
Goose Collision: NORTH
episode number:  457
Goose Collision: EAST
episode number:  458
episode number:  459
Goose Collision: EAST
episode number:  460
Body Hit: (0, <Action.NORTH: 1>, 42, [53, 52, 51, 40, 41, 42, 31, 32])
episode number:  461
Goose Collision: EAST
episode number:  462
Goose Collision: SOUTH
episode number:  463
episode number:  464
Goose Collision: SOUTH
Goose Collision: NORTH
episode number:  465
Goose Collision: EAST
episode number:  466
Goose Starved: Action.EAST
Goose Starved: Action.EAST
episode number:  467
Goose Collision: NORTH
episode number:  468
Goose Collision: NORTH
episode number:  469
Goose Starved: Action.EAST
Goose Starved: Action.EAST
episode number:  470
Body Hit: (0, <Action.EAST: 2>, 57, [56, 55, 65, 76, 66, 67, 68, 57, 46, 45, 44, 54, 53, 52, 51, 50, 39, 40])
episode number:  471
Goose Starved: Action.NORTH
Goose Starved: Action.NORTH

episode number:  605
Goose Collision: NORTH
episode number:  606
Goose Collision: SOUTH
Goose Collision: WEST
episode number:  607
Goose Starved: Action.NORTH
episode number:  608
Goose Collision: NORTH
Goose Collision: WEST
episode number:  609
Goose Collision: SOUTH
episode number:  610
Goose Collision: WEST
Goose Collision: EAST
episode number:  611
Goose Starved: Action.SOUTH
episode number:  612
Goose Collision: SOUTH
episode number:  613
Goose Collision: EAST
episode number:  614
episode number:  615
Goose Collision: WEST
episode number:  616
Goose Collision: NORTH
Goose Collision: EAST
episode number:  617
Body Hit: (1, <Action.SOUTH: 3>, 36, [25, 24, 23, 34, 35, 36, 37, 26, 15, 4, 3, 14, 13, 12, 1, 67, 56, 57])
episode number:  618
Body Hit: (0, <Action.NORTH: 1>, 18, [29, 30, 41, 40, 51, 52, 63, 62, 73, 7, 18, 19, 20, 31])
episode number:  619
Goose Starved: Action.NORTH
episode number:  620
Body Hit: (0, <Action.SOUTH: 3>, 21, [10, 9, 20, 21, 11, 0, 1])
episode number:  621
G

Goose Collision: SOUTH
Goose Collision: EAST
episode number:  768
Goose Collision: WEST
episode number:  769
Goose Collision: NORTH
Goose Collision: EAST
episode number:  770
episode number:  771
Goose Collision: EAST
Goose Collision: SOUTH
episode number:  772
Body Hit: (0, <Action.NORTH: 1>, 58, [69, 70, 71, 60, 59, 58, 47, 36, 37, 38, 49, 50, 61, 72])
episode number:  773
Body Hit: (1, <Action.NORTH: 1>, 38, [49, 60, 71, 5, 16, 27, 38, 39, 50, 61])
episode number:  774
Goose Collision: WEST
Goose Collision: EAST
episode number:  775
Body Hit: (1, <Action.EAST: 2>, 34, [33, 43, 32, 22, 23, 34, 35, 46, 47, 48, 49, 50, 51, 40, 41, 42, 53])
episode number:  776
Goose Collision: NORTH
episode number:  777
Goose Starved: Action.EAST
episode number:  778
episode number:  779
Goose Collision: EAST
episode number:  780
Goose Collision: EAST
episode number:  781
Goose Collision: EAST
episode number:  782
Goose Starved: Action.WEST
episode number:  783
Goose Starved: Action.SOUTH
episode numbe

Goose Collision: EAST
episode number:  928
Goose Collision: SOUTH
episode number:  929
Goose Starved: Action.NORTH
episode number:  930
Goose Starved: Action.SOUTH
episode number:  931
Goose Collision: EAST
episode number:  932
Body Hit: (0, <Action.SOUTH: 3>, 10, [76, 66, 55, 65, 54, 43, 32, 21, 10, 0, 1, 67, 56, 45, 34])
episode number:  933
Goose Collision: WEST
Goose Collision: SOUTH
episode number:  934
Body Hit: (1, <Action.WEST: 4>, 68, [69, 70, 71, 60, 59, 58, 57, 68, 2, 3, 4, 15, 14, 25])
episode number:  935
episode number:  936
Goose Collision: EAST
Goose Collision: NORTH
episode number:  937
Body Hit: (1, <Action.WEST: 4>, 59, [60, 71, 70, 59, 58, 69, 68, 2, 13, 14, 15])
episode number:  938
episode number:  939
Goose Collision: EAST
episode number:  940
Goose Collision: NORTH
episode number:  941
Goose Starved: Action.SOUTH
episode number:  942
Goose Collision: SOUTH
episode number:  943
Goose Starved: Action.SOUTH
Goose Starved: Action.SOUTH
episode number:  944
Goose Col

Goose Collision: WEST
episode number:  1077
Goose Collision: WEST
Goose Collision: NORTH
episode number:  1078
Goose Collision: EAST
Goose Collision: NORTH
episode number:  1079
Goose Starved: Action.WEST
Goose Starved: Action.WEST
episode number:  1080
Goose Collision: SOUTH
Goose Collision: EAST
episode number:  1081
Body Hit: (0, <Action.NORTH: 1>, 67, [1, 0, 66, 67, 68, 2, 3, 14])
episode number:  1082
episode number:  1083
Body Hit: (1, <Action.WEST: 4>, 30, [31, 32, 21, 10, 9, 20, 19, 30, 41, 52, 53, 42, 43, 33, 22, 11, 12])
episode number:  1084
Goose Collision: SOUTH
Goose Collision: EAST
episode number:  1085
Body Hit: (0, <Action.NORTH: 1>, 54, [65, 55, 56, 57, 68, 67, 66, 76, 75, 64, 53, 54])
episode number:  1086
Body Hit: (0, <Action.NORTH: 1>, 65, [76, 75, 74, 73, 62, 63, 64, 65, 55, 66, 67])
episode number:  1087
Body Hit: (1, <Action.WEST: 4>, 14, [15, 26, 37, 36, 25, 14, 13, 2, 68, 67, 66, 55])
episode number:  1088
episode number:  1089
Goose Collision: EAST
episode n

episode number:  1220
Body Hit: (0, <Action.NORTH: 1>, 70, [4, 5, 16, 17, 6, 72, 71, 70, 69, 3, 14, 15])
episode number:  1221
Goose Collision: SOUTH
episode number:  1222
Goose Starved: Action.NORTH
Goose Starved: Action.NORTH
episode number:  1223
Goose Starved: Action.SOUTH
episode number:  1224
Body Hit: (0, <Action.WEST: 4>, 0, [1, 67, 56, 45, 44, 54, 43, 32, 21, 11, 0, 66])
episode number:  1225
Goose Starved: Action.SOUTH
episode number:  1226
Goose Collision: NORTH
episode number:  1227
Body Hit: (0, <Action.SOUTH: 3>, 22, [11, 12, 13, 24, 23, 22, 32, 21, 20, 19, 8, 9, 10, 0, 1, 67, 56])
episode number:  1228
Goose Collision: SOUTH
episode number:  1229
Goose Starved: Action.EAST
Goose Starved: Action.SOUTH
episode number:  1230
Goose Collision: NORTH
Goose Collision: SOUTH
episode number:  1231
Goose Starved: Action.NORTH
Goose Starved: Action.SOUTH
episode number:  1232
Body Hit: (1, <Action.EAST: 2>, 2, [1, 12, 13, 2, 3, 69, 58, 47])
episode number:  1233
Body Hit: (1, <Acti

episode number:  1369
Goose Collision: EAST
episode number:  1370
Goose Starved: Action.NORTH
episode number:  1371
episode number:  1372
Body Hit: (0, <Action.SOUTH: 3>, 70, [59, 58, 57, 56, 55, 66, 0, 1, 67, 68, 69, 70, 71])
episode number:  1373
Goose Collision: SOUTH
episode number:  1374
Goose Starved: Action.NORTH
Goose Starved: Action.NORTH
episode number:  1375
Goose Starved: Action.WEST
episode number:  1376
Goose Collision: WEST
episode number:  1377
Goose Collision: NORTH
Goose Collision: SOUTH
episode number:  1378
Goose Collision: SOUTH
episode number:  1379
Goose Starved: Action.WEST
Goose Starved: Action.WEST
episode number:  1380
episode number:  1381
Goose Collision: EAST
episode number:  1382
Goose Collision: EAST
Goose Collision: NORTH
episode number:  1383
Goose Starved: Action.SOUTH
episode number:  1384
Goose Starved: Action.SOUTH
Goose Starved: Action.NORTH
episode number:  1385
Goose Collision: WEST
Goose Collision: NORTH
episode number:  1386
Goose Collision: S

Goose Starved: Action.EAST
episode number:  1515
Goose Collision: SOUTH
episode number:  1516
Goose Collision: EAST
episode number:  1517
Goose Collision: WEST
episode number:  1518
Goose Collision: NORTH
Goose Collision: SOUTH
episode number:  1519
Goose Collision: SOUTH
episode number:  1520
episode number:  1521
Body Hit: (1, <Action.WEST: 4>, 19, [20, 31, 32, 21, 10, 9, 8, 19, 30])
episode number:  1522
Body Hit: (0, <Action.EAST: 2>, 57, [56, 45, 34, 23, 12, 13, 24, 35, 46, 57, 58])
episode number:  1523
Goose Collision: SOUTH
episode number:  1524
episode number:  1525
Goose Collision: NORTH
episode number:  1526
Goose Collision: NORTH
episode number:  1527
Goose Collision: NORTH
episode number:  1528
Goose Collision: EAST
episode number:  1529
Goose Collision: SOUTH
episode number:  1530
episode number:  1531
Goose Collision: WEST
Goose Collision: SOUTH
episode number:  1532
Goose Starved: Action.WEST
Goose Starved: Action.EAST
episode number:  1533
Goose Starved: Action.SOUTH
e

Goose Starved: Action.EAST
episode number:  1667
Goose Collision: EAST
episode number:  1668
Goose Starved: Action.SOUTH
episode number:  1669
Goose Collision: EAST
episode number:  1670
Goose Collision: WEST
episode number:  1671
Goose Collision: EAST
episode number:  1672
Goose Collision: WEST
episode number:  1673
Goose Starved: Action.NORTH
Goose Starved: Action.NORTH
episode number:  1674
Goose Collision: NORTH
Goose Collision: EAST
episode number:  1675
episode number:  1676
Goose Collision: SOUTH
Goose Collision: WEST
episode number:  1677
Goose Starved: Action.SOUTH
episode number:  1678
Goose Collision: SOUTH
Goose Collision: EAST
episode number:  1679
Goose Collision: EAST
episode number:  1680
Body Hit: (0, <Action.EAST: 2>, 44, [54, 53, 42, 43, 33, 44, 55, 65, 64])
episode number:  1681
Goose Collision: SOUTH
episode number:  1682
Body Hit: (0, <Action.EAST: 2>, 2, [1, 0, 10, 76, 66, 67, 56, 57, 58, 69, 68, 2, 13, 12, 11, 21, 20])
episode number:  1683
Goose Collision: NORT

Body Hit: (1, <Action.NORTH: 1>, 38, [49, 48, 37, 38, 27, 28, 17, 18, 19])
episode number:  1819
Goose Collision: WEST
episode number:  1820
Goose Collision: NORTH
Goose Collision: WEST
episode number:  1821
Body Hit: (1, <Action.EAST: 2>, 8, [7, 18, 17, 6, 72, 73, 62, 63, 74, 8, 19, 30, 29, 40, 39])
episode number:  1822
Goose Collision: SOUTH
Goose Collision: WEST
episode number:  1823
Goose Starved: Action.EAST
episode number:  1824
Goose Collision: NORTH
episode number:  1825
Body Hit: (1, <Action.NORTH: 1>, 49, [60, 71, 72, 61, 50, 49, 38])
episode number:  1826
episode number:  1827
Goose Collision: NORTH
episode number:  1828
Body Hit: (0, <Action.WEST: 4>, 73, [74, 63, 62, 73, 7, 6, 5, 4, 3, 2])
episode number:  1829
Goose Collision: NORTH
Goose Collision: WEST
episode number:  1830
Goose Collision: EAST
Goose Collision: NORTH
episode number:  1831
Goose Collision: EAST
episode number:  1832
episode number:  1833
Goose Collision: NORTH
episode number:  1834
Body Hit: (1, <Actio

Goose Starved: Action.NORTH
episode number:  1956
Goose Starved: Action.WEST
Goose Starved: Action.EAST
episode number:  1957
Body Hit: (0, <Action.NORTH: 1>, 74, [8, 19, 30, 31, 20, 9, 75, 74, 73, 7, 18, 29, 40, 41, 52, 53])
episode number:  1958
Goose Collision: NORTH
episode number:  1959
Body Hit: (0, <Action.WEST: 4>, 3, [4, 15, 14, 3, 69, 70, 71, 5, 16, 27, 26, 25, 36])
episode number:  1960
Goose Collision: EAST
episode number:  1961
episode number:  1962
Goose Collision: WEST
episode number:  1963
Goose Collision: NORTH
Goose Collision: EAST
episode number:  1964
Goose Collision: EAST
Goose Collision: WEST
episode number:  1965
episode number:  1966
Goose Starved: Action.SOUTH
episode number:  1967
Body Hit: (0, <Action.SOUTH: 3>, 26, [15, 16, 27, 26, 37, 38, 39, 40, 51, 50, 61, 72, 6])
episode number:  1968
Goose Starved: Action.SOUTH
episode number:  1969
Goose Collision: SOUTH
episode number:  1970
Goose Collision: NORTH
episode number:  1971
episode number:  1972
Body Hit: 

Goose Collision: WEST
Goose Collision: NORTH
episode number:  2109
Goose Collision: SOUTH
episode number:  2110
Goose Collision: EAST
Goose Collision: SOUTH
episode number:  2111
Goose Collision: EAST
episode number:  2112
Goose Collision: WEST
episode number:  2113
Body Hit: (0, <Action.WEST: 4>, 27, [28, 17, 6, 72, 61, 50, 39, 38, 27, 16])
episode number:  2114
episode number:  2115
Goose Collision: SOUTH
episode number:  2116
episode number:  2117
Goose Starved: Action.SOUTH
Goose Starved: Action.NORTH
episode number:  2118
Goose Collision: NORTH
episode number:  2119
episode number:  2120
Goose Collision: SOUTH
episode number:  2121
episode number:  2122
Goose Collision: EAST
episode number:  2123
Goose Collision: WEST
episode number:  2124
Body Hit: (1, <Action.EAST: 2>, 20, [19, 8, 74, 75, 64, 65, 55, 66, 76, 10, 9, 20, 31])
episode number:  2125
episode number:  2126
episode number:  2127
Goose Collision: WEST
episode number:  2128
Body Hit: (1, <Action.EAST: 2>, 28, [27, 16, 5,

Goose Collision: SOUTH
Goose Collision: EAST
episode number:  2253
Goose Collision: EAST
episode number:  2254
Body Hit: (0, <Action.WEST: 4>, 75, [76, 65, 64, 75, 9, 20])
episode number:  2255
Goose Collision: NORTH
episode number:  2256
episode number:  2257
Goose Collision: SOUTH
episode number:  2258
Goose Starved: Action.SOUTH
Goose Starved: Action.NORTH
episode number:  2259
episode number:  2260
Goose Starved: Action.NORTH
episode number:  2261
Goose Starved: Action.NORTH
episode number:  2262
Goose Collision: EAST
episode number:  2263
episode number:  2264
Goose Collision: WEST
episode number:  2265
episode number:  2266
Goose Collision: EAST
Goose Collision: SOUTH
episode number:  2267
Body Hit: (1, <Action.EAST: 2>, 50, [49, 60, 71, 5, 6, 17, 18, 19, 8, 7, 73, 72, 61, 50, 51, 52, 63, 64])
episode number:  2268
Goose Collision: SOUTH
Goose Collision: EAST
episode number:  2269
Goose Collision: NORTH
episode number:  2270
Goose Collision: SOUTH
episode number:  2271
Goose Coll

Goose Starved: Action.NORTH
episode number:  2411
episode number:  2412
Goose Starved: Action.NORTH
episode number:  2413
Goose Collision: EAST
episode number:  2414
Goose Collision: SOUTH
episode number:  2415
Goose Collision: NORTH
episode number:  2416
Goose Starved: Action.SOUTH
Goose Starved: Action.SOUTH
episode number:  2417
Body Hit: (1, <Action.SOUTH: 3>, 32, [21, 20, 19, 18, 7, 8, 9, 10, 0, 11, 22, 32, 31, 30, 29])
episode number:  2418
Body Hit: (0, <Action.EAST: 2>, 6, [5, 16, 17, 6, 7, 8, 9, 10, 0, 66])
episode number:  2419
episode number:  2420
Goose Starved: Action.NORTH
episode number:  2421
Body Hit: (1, <Action.EAST: 2>, 12, [11, 0, 10, 21, 32, 22, 33, 44, 45, 34, 35, 24, 13, 12, 1, 67, 66])
episode number:  2422
Goose Collision: WEST
episode number:  2423
Goose Collision: NORTH
episode number:  2424
Goose Collision: NORTH
episode number:  2425
Body Hit: (1, <Action.SOUTH: 3>, 74, [63, 52, 41, 30, 19, 20, 31, 42, 53, 64, 75, 74, 73, 72, 61, 62, 51])
episode number:  

Goose Collision: NORTH
episode number:  2573
Goose Collision: SOUTH
episode number:  2574
episode number:  2575
Goose Collision: EAST
episode number:  2576
Goose Collision: NORTH
Goose Collision: EAST
episode number:  2577
Goose Collision: EAST
episode number:  2578
Goose Collision: NORTH
episode number:  2579
Goose Collision: EAST
Goose Collision: SOUTH
episode number:  2580
Body Hit: (1, <Action.EAST: 2>, 22, [32, 43, 33, 22, 11, 21, 10, 76])
episode number:  2581
Body Hit: (0, <Action.EAST: 2>, 23, [22, 33, 44, 55, 65, 54, 43, 32, 21, 11, 0, 66, 67, 68, 2, 13, 24, 23])
episode number:  2582
Goose Collision: NORTH
episode number:  2583
Goose Collision: EAST
Goose Collision: WEST
episode number:  2584
Goose Collision: NORTH
episode number:  2585
Body Hit: (0, <Action.EAST: 2>, 52, [51, 40, 29, 30, 41, 52, 53, 64, 75, 76, 10, 0, 11, 12])
episode number:  2586
episode number:  2587
Goose Collision: SOUTH
Goose Collision: NORTH
episode number:  2588
Body Hit: (0, <Action.WEST: 4>, 72, [7

Goose Collision: NORTH
episode number:  2715
episode number:  2716
Goose Collision: EAST
episode number:  2717
episode number:  2718
Goose Starved: Action.SOUTH
episode number:  2719
Goose Collision: SOUTH
episode number:  2720
Goose Starved: Action.WEST
episode number:  2721
Goose Collision: SOUTH
episode number:  2722
Goose Starved: Action.NORTH
episode number:  2723
Body Hit: (0, <Action.SOUTH: 3>, 67, [56, 57, 58, 69, 70, 4, 3, 2, 68, 67, 66, 55, 44, 45, 46, 35, 36, 37])
episode number:  2724
episode number:  2725
Goose Starved: Action.SOUTH
Goose Starved: Action.NORTH
episode number:  2726
Goose Collision: WEST
episode number:  2727
Goose Collision: NORTH
episode number:  2728
Goose Collision: WEST
episode number:  2729
Goose Collision: NORTH
episode number:  2730
Goose Starved: Action.SOUTH
Goose Starved: Action.NORTH
episode number:  2731
Goose Collision: NORTH
Goose Collision: WEST
episode number:  2732
Goose Collision: WEST
episode number:  2733
Goose Collision: WEST
episode n

Goose Collision: NORTH
episode number:  2855
Goose Collision: EAST
episode number:  2856
episode number:  2857
episode number:  2858
Goose Collision: WEST
episode number:  2859
Body Hit: (0, <Action.NORTH: 1>, 25, [36, 47, 48, 37, 26, 25, 14, 13, 2, 1, 0, 10])
episode number:  2860
episode number:  2861
Goose Collision: SOUTH
Goose Collision: NORTH
episode number:  2862
Goose Collision: NORTH
Goose Collision: EAST
episode number:  2863
episode number:  2864
episode number:  2865
Goose Starved: Action.SOUTH
Goose Starved: Action.SOUTH
episode number:  2866
Goose Starved: Action.SOUTH
Goose Starved: Action.NORTH
episode number:  2867
Goose Collision: WEST
episode number:  2868
Goose Collision: NORTH
episode number:  2869
episode number:  2870
episode number:  2871
Goose Collision: NORTH
episode number:  2872
Goose Collision: EAST
episode number:  2873
Goose Starved: Action.NORTH
Goose Starved: Action.NORTH
episode number:  2874
Goose Collision: SOUTH
episode number:  2875
Body Hit: (1, <

Goose Starved: Action.EAST
episode number:  3006
Body Hit: (1, <Action.EAST: 2>, 60, [59, 48, 37, 26, 27, 28, 39, 38, 49, 60])
episode number:  3007
Body Hit: (0, <Action.WEST: 4>, 27, [28, 29, 18, 19, 30, 41, 52, 51, 50, 39, 38, 27, 16, 17, 6, 7])
episode number:  3008
Goose Collision: EAST
episode number:  3009
Goose Collision: SOUTH
Goose Collision: WEST
episode number:  3010
Body Hit: (1, <Action.EAST: 2>, 8, [7, 73, 72, 6, 17, 18, 19, 8, 74])
episode number:  3011
Goose Collision: WEST
episode number:  3012
episode number:  3013
Goose Collision: EAST
episode number:  3014
Goose Starved: Action.SOUTH
Goose Starved: Action.NORTH
episode number:  3015
Goose Collision: NORTH
Goose Collision: WEST
episode number:  3016
Body Hit: (1, <Action.WEST: 4>, 68, [69, 3, 2, 68, 57, 58, 47, 36, 37, 38, 39, 40, 41, 42, 43, 32])
episode number:  3017
Goose Collision: NORTH
episode number:  3018
Goose Collision: NORTH
Goose Collision: SOUTH
episode number:  3019
Goose Starved: Action.SOUTH
episode 

Goose Starved: Action.SOUTH
episode number:  3152
episode number:  3153
Goose Collision: SOUTH
episode number:  3154
episode number:  3155
Goose Collision: NORTH
episode number:  3156
Goose Collision: NORTH
Goose Collision: EAST
episode number:  3157
Goose Collision: SOUTH
episode number:  3158
Goose Collision: WEST
episode number:  3159
Goose Collision: EAST
episode number:  3160
Goose Collision: EAST
episode number:  3161
Goose Starved: Action.SOUTH
episode number:  3162
Body Hit: (1, <Action.EAST: 2>, 46, [45, 56, 67, 1, 12, 23, 34, 35, 24, 13, 2, 68, 57, 46])
episode number:  3163
Goose Collision: SOUTH
episode number:  3164
Goose Starved: Action.NORTH
episode number:  3165
Goose Collision: SOUTH
Goose Collision: EAST
episode number:  3166
Goose Collision: NORTH
episode number:  3167
Body Hit: (1, <Action.EAST: 2>, 24, [23, 34, 35, 24, 25, 26, 27, 16])
episode number:  3168
Goose Collision: NORTH
Goose Collision: WEST
episode number:  3169
episode number:  3170
Goose Collision: NOR

episode number:  3308
Goose Starved: Action.WEST
episode number:  3309
Body Hit: (1, <Action.SOUTH: 3>, 16, [5, 4, 15, 16, 27, 28, 39, 50, 61, 72, 6])
episode number:  3310
Goose Collision: NORTH
episode number:  3311
Goose Starved: Action.NORTH
episode number:  3312
Goose Starved: Action.SOUTH
episode number:  3313
Goose Collision: NORTH
episode number:  3314
Goose Collision: SOUTH
episode number:  3315
Goose Collision: SOUTH
episode number:  3316
Goose Collision: EAST
episode number:  3317
Goose Collision: WEST
episode number:  3318
Body Hit: (0, <Action.NORTH: 1>, 9, [20, 31, 42, 41, 30, 19, 8, 9, 10, 21])
episode number:  3319
Body Hit: (1, <Action.EAST: 2>, 12, [11, 21, 10, 9, 75, 74, 63, 64, 65, 76, 66, 0, 1, 12, 23])
episode number:  3320
Goose Collision: EAST
episode number:  3321
episode number:  3322
episode number:  3323
Goose Collision: EAST
episode number:  3324
Goose Collision: EAST
Goose Collision: SOUTH
episode number:  3325
Goose Collision: EAST
episode number:  3326
e

Body Hit: (1, <Action.SOUTH: 3>, 29, [18, 17, 16, 5, 6, 7, 8, 19, 30, 29, 28])
episode number:  3450
Goose Starved: Action.NORTH
Goose Starved: Action.NORTH
episode number:  3451
Goose Collision: NORTH
episode number:  3452
Body Hit: (0, <Action.SOUTH: 3>, 55, [44, 33, 43, 54, 65, 55, 66, 0, 11, 22, 23, 34])
episode number:  3453
Goose Collision: NORTH
episode number:  3454
Goose Collision: EAST
Goose Collision: SOUTH
episode number:  3455
Goose Starved: Action.NORTH
episode number:  3456
Body Hit: (0, <Action.SOUTH: 3>, 6, [72, 71, 5, 6, 7, 73, 62])
episode number:  3457
Goose Collision: WEST
episode number:  3458
Goose Collision: EAST
Goose Collision: NORTH
episode number:  3459
Goose Starved: Action.NORTH
Goose Starved: Action.NORTH
episode number:  3460
Goose Collision: NORTH
Goose Collision: EAST
episode number:  3461
Goose Collision: WEST
episode number:  3462
Body Hit: (1, <Action.SOUTH: 3>, 21, [10, 76, 75, 9, 20, 21, 32, 43, 54, 65, 55, 66])
episode number:  3463
Goose Starved

Goose Collision: NORTH
Goose Collision: EAST
episode number:  3600
Body Hit: (1, <Action.WEST: 4>, 75, [76, 10, 0, 66, 55, 65, 64, 75, 9, 20, 21])
episode number:  3601
episode number:  3602
Body Hit: (1, <Action.EAST: 2>, 14, [13, 2, 1, 12, 11, 22, 33, 34, 35, 24, 25, 14, 3, 69, 68])
episode number:  3603
Goose Collision: NORTH
episode number:  3604
Goose Collision: NORTH
episode number:  3605
Body Hit: (1, <Action.SOUTH: 3>, 75, [64, 53, 42, 41, 52, 51, 50, 61, 62, 63, 74, 75, 9, 20, 19])
episode number:  3606
Goose Collision: NORTH
episode number:  3607
episode number:  3608
Body Hit: (1, <Action.NORTH: 1>, 5, [16, 15, 14, 3, 4, 5, 6, 17, 18, 19, 30, 41, 42, 31, 32, 21])
episode number:  3609
episode number:  3610
Goose Collision: SOUTH
episode number:  3611
episode number:  3612
episode number:  3613
Body Hit: (1, <Action.WEST: 4>, 34, [35, 36, 25, 24, 23, 34, 33, 22, 11, 21, 32, 43])
episode number:  3614
Goose Collision: WEST
episode number:  3615
Body Hit: (1, <Action.WEST: 4>, 

Body Hit: (0, <Action.NORTH: 1>, 1, [12, 13, 14, 3, 2, 1, 0, 11, 22, 23, 24, 25, 26, 15])
episode number:  3739
Goose Collision: NORTH
episode number:  3740
episode number:  3741
Goose Collision: SOUTH
episode number:  3742
Body Hit: (0, <Action.SOUTH: 3>, 66, [55, 65, 64, 75, 76, 66, 67, 56, 45])
episode number:  3743
Goose Collision: NORTH
episode number:  3744
Goose Collision: WEST
episode number:  3745
episode number:  3746
Goose Collision: NORTH
episode number:  3747
episode number:  3748
Goose Collision: EAST
Goose Collision: NORTH
episode number:  3749
Goose Collision: NORTH
episode number:  3750
Body Hit: (1, <Action.WEST: 4>, 28, [29, 40, 39, 28, 27, 16, 15, 4, 70, 59, 58, 57, 56])
episode number:  3751
Goose Collision: EAST
Goose Collision: WEST
episode number:  3752
episode number:  3753
Goose Collision: WEST
episode number:  3754
Body Hit: (1, <Action.SOUTH: 3>, 18, [7, 73, 74, 8, 19, 18, 29, 40, 41, 42])
episode number:  3755
Goose Collision: NORTH
Goose Collision: SOUTH
e

episode number:  3890
Goose Collision: EAST
episode number:  3891
Body Hit: (0, <Action.SOUTH: 3>, 28, [17, 16, 27, 28, 29, 18, 7, 6, 5, 71])
episode number:  3892
Goose Collision: EAST
episode number:  3893
episode number:  3894
episode number:  3895
Goose Collision: WEST
episode number:  3896
Goose Collision: SOUTH
episode number:  3897
Goose Collision: WEST
episode number:  3898
Body Hit: (1, <Action.EAST: 2>, 3, [2, 13, 24, 35, 36, 25, 14, 3, 4, 5])
episode number:  3899
Goose Collision: EAST
Goose Collision: SOUTH
episode number:  3900
Goose Collision: WEST
episode number:  3901
Goose Starved: Action.SOUTH
Goose Starved: Action.SOUTH
episode number:  3902
Goose Collision: NORTH
episode number:  3903
Goose Collision: NORTH
episode number:  3904
Goose Starved: Action.NORTH
episode number:  3905
Goose Collision: SOUTH
episode number:  3906
Body Hit: (0, <Action.NORTH: 1>, 19, [30, 31, 20, 19, 18, 17, 16, 15, 4, 3, 2])
episode number:  3907
Body Hit: (0, <Action.SOUTH: 3>, 2, [68, 69,

Body Hit: (0, <Action.WEST: 4>, 68, [69, 58, 59, 48, 37, 36, 47, 46, 57, 68, 2, 3, 4, 70, 71, 60, 49])
episode number:  4040
Goose Collision: SOUTH
episode number:  4041
Goose Collision: SOUTH
episode number:  4042
Goose Collision: EAST
episode number:  4043
Goose Collision: SOUTH
Goose Collision: WEST
episode number:  4044
Goose Collision: NORTH
episode number:  4045
episode number:  4046
Goose Collision: EAST
Goose Collision: WEST
episode number:  4047
Goose Collision: SOUTH
episode number:  4048
Goose Collision: NORTH
episode number:  4049
episode number:  4050
Body Hit: (1, <Action.WEST: 4>, 70, [71, 60, 59, 70, 4, 5, 6, 72, 61, 50, 49, 38, 27, 26, 25])
episode number:  4051
Goose Collision: SOUTH
episode number:  4052
Goose Collision: NORTH
Goose Collision: EAST
episode number:  4053
episode number:  4054
Goose Collision: NORTH
episode number:  4055
episode number:  4056
Goose Collision: NORTH
episode number:  4057
episode number:  4058
Body Hit: (0, <Action.WEST: 4>, 1, [2, 13, 1

Goose Starved: Action.NORTH
episode number:  4191
Goose Collision: SOUTH
episode number:  4192
Goose Collision: WEST
episode number:  4193
Goose Collision: NORTH
episode number:  4194
Goose Collision: NORTH
episode number:  4195
Goose Collision: WEST
Goose Collision: NORTH
episode number:  4196
Body Hit: (0, <Action.WEST: 4>, 71, [72, 61, 50, 49, 48, 59, 70, 71, 5, 6, 7, 73, 62])
episode number:  4197
episode number:  4198
Goose Collision: NORTH
Goose Collision: EAST
episode number:  4199
Goose Collision: NORTH
episode number:  4200
Body Hit: (0, <Action.NORTH: 1>, 7, [18, 29, 28, 17, 6, 7, 8, 19, 30])
episode number:  4201
episode number:  4202
Body Hit: (0, <Action.NORTH: 1>, 12, [23, 34, 45, 44, 33, 43, 32, 22, 11, 12, 13, 24, 35, 46])
episode number:  4203
Goose Starved: Action.WEST
episode number:  4204
Goose Starved: Action.SOUTH
Goose Starved: Action.SOUTH
episode number:  4205
Goose Collision: WEST
episode number:  4206
Body Hit: (1, <Action.SOUTH: 3>, 23, [12, 13, 24, 23, 34, 

episode number:  4337
Goose Collision: EAST
Goose Collision: WEST
episode number:  4338
Goose Starved: Action.NORTH
episode number:  4339
Goose Starved: Action.EAST
episode number:  4340
Goose Collision: NORTH
episode number:  4341
Body Hit: (0, <Action.NORTH: 1>, 31, [42, 43, 32, 31, 30, 41, 52, 53, 54, 44, 33])
episode number:  4342
Goose Starved: Action.SOUTH
episode number:  4343
Body Hit: (0, <Action.WEST: 4>, 41, [42, 43, 33, 22, 32, 31, 30, 41, 52, 53, 54])
episode number:  4344
Goose Collision: WEST
episode number:  4345
episode number:  4346
Goose Starved: Action.SOUTH
Goose Starved: Action.SOUTH
episode number:  4347
Goose Starved: Action.NORTH
episode number:  4348
Goose Collision: WEST
Goose Collision: SOUTH
episode number:  4349
Goose Collision: SOUTH
Goose Collision: EAST
episode number:  4350
Body Hit: (1, <Action.WEST: 4>, 61, [62, 73, 72, 61, 60, 71, 70, 4, 15, 14, 13])
episode number:  4351
episode number:  4352
Goose Starved: Action.WEST
episode number:  4353
Goose C

Body Hit: (0, <Action.SOUTH: 3>, 32, [21, 11, 0, 10, 76, 65, 54, 43, 32, 22, 23, 12, 13, 14, 25])
episode number:  4475
Goose Starved: Action.SOUTH
episode number:  4476
Goose Collision: EAST
episode number:  4477
Body Hit: (1, <Action.NORTH: 1>, 6, [17, 28, 39, 38, 27, 16, 15, 4, 70, 59, 60, 61, 72, 6])
episode number:  4478
Goose Collision: SOUTH
episode number:  4479
Goose Collision: EAST
episode number:  4480
Goose Collision: EAST
episode number:  4481
Goose Collision: EAST
episode number:  4482
episode number:  4483
Goose Starved: Action.WEST
episode number:  4484
Goose Collision: NORTH
episode number:  4485
Body Hit: (1, <Action.EAST: 2>, 21, [20, 31, 32, 21, 11, 22, 23, 24, 25])
episode number:  4486
Goose Collision: EAST
Goose Collision: SOUTH
episode number:  4487
episode number:  4488
Goose Starved: Action.SOUTH
episode number:  4489
episode number:  4490
Body Hit: (1, <Action.SOUTH: 3>, 0, [66, 55, 56, 67, 1, 0, 11, 22])
episode number:  4491
Goose Collision: WEST
episode nu

KeyboardInterrupt: 

In [687]:
# `process` and `discount` are defined in an earlier cell not shown here.
# Comparing the episodes[0] displays before (In [686]) and after this call
# (In [688]), steps gain an 'embeddings' key — presumably process() adds
# per-step model features to `episodes` in place; TODO confirm against its
# definition.
process(discount, episodes)

In [688]:
# Display the first recorded episode: a list of per-step dicts
# ('cur_state', 'action', 'reward', 'status', 'done', ...).
episodes[0]

[{'cur_state': {'food1_col': -0.2,
   'food2_col': -0.6,
   'food1_row': -0.2,
   'food2_row': -0.2,
   'goose_size': 1,
   'longuest_opponent': 1,
   'board': array([[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0., 16.,  0.,  0.,  0.,  0.,  0.,  0.],
          [ 0.,  0., 15.,  0., 15.,  0.,  0.,  0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0.,  0., 16.,  0.,  0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]]),
   'hunger': -1.0,
   'step': -1.0},
  'action': 'NORTH',
  'reward': 0,
  'new_state': '',
  'status': 'ACTIVE',
  'done': False,
  'embeddings': [0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   16,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   15,
   0,
   15,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
  

### Let's create a network that predicts which action to take from the current state

In [589]:
import os
from math import ceil
import numpy as np
from tensorflow.keras.layers import Dense,Input, Embedding, concatenate,\
    Flatten, Average, Dropout, BatchNormalization, Activation
from tensorflow.keras import Sequential, Model
from tensorflow import config, distribute
import tensorflow as tf


In [590]:
# Reset the TF1-compat default graph before (re)building the model below.
# NOTE(review): under TF2 eager Keras this likely does not reset Keras' own
# global state (e.g. auto-generated layer names); tf.keras.backend.clear_session()
# may be what was intended — confirm.
tf.compat.v1.reset_default_graph()

In [591]:
# One scalar Keras Input per numerical feature (5 of them), in order.
num_layers = [Input(shape=1) for _ in range(5)]

In [592]:
# Build one small Sequential sub-model per board cell (7 rows x 11 columns):
# the `embedding` layer followed by a Flatten, collected in `emb_layers`.
# NOTE(review): `embedding` is defined outside this cell and is the SAME layer
# object on every iteration, so all 77 sub-models share one set of embedding
# weights, and assigning the private `_name` attribute only renames that single
# layer (the last assignment, 'embeddings_76', wins). If independent per-cell
# embeddings were intended, a fresh Embedding layer must be created inside the
# loop — confirm which behavior is wanted.
emb_layers = []
for i in range(7*11):
    m = Sequential()
    embedding._name = f'embeddings_{i}'
    m.add(embedding)
    m.add(Flatten(name=f'flat_embeddings-{i}'))
    emb_layers.append(m)

In [593]:
# Model inputs: the 5 numerical Inputs, then each embedding sub-model's Input.
inputs = [*num_layers, *(sub.input for sub in emb_layers)]
# Tensors fed to the concatenation: numerical Inputs as-is, embeddings via
# each sub-model's flattened output.
outputs = [*num_layers, *(sub.output for sub in emb_layers)]

In [594]:
# Join every input branch into one vector, then run it through a funnel of
# ELU dense layers ending in a 4-way softmax over the actions.
# Note: `model` holds the output tensor here, not a keras Model.
c = concatenate(outputs)
model = c
for units in (100, 50, 20, 10):
    model = Dense(units, activation='elu')(model)
model = Dense(4, activation='softmax')(model)

In [595]:
# Wrap the functional graph into a trainable Model. The head ends in a 4-way
# softmax, so it is trained with categorical cross-entropy.
m = Model(inputs=inputs, outputs=model)
m.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

In [597]:
# Smoke test: feed the same 3-sample column (16, 5, 10) to each of the
# 5 numerical + 7*11 embedding inputs and display the predictions.
m.predict(list(np.tile([16, 5, 10], (5 + 11 * 7, 1))[..., np.newaxis]))

array([[0.89772683, 0.02073192, 0.01104983, 0.07049137],
       [0.62568307, 0.07258962, 0.15280053, 0.1489268 ],
       [0.8255264 , 0.03592074, 0.04352432, 0.09502851]], dtype=float32)

In [655]:
# Number of per-input probe arrays used by the smoke test (5 + 7*11).
len(np.tile([16, 5, 10], (5 + 11 * 7, 1)))

82

In [654]:
# Display the probe inputs: one (3, 1) column array per model input.
list(np.tile([16, 5, 10], (5 + 11 * 7, 1))[..., np.newaxis])

[array([[16],
        [ 5],
        [10]]),
 array([[16],
        [ 5],
        [10]]),
 array([[16],
        [ 5],
        [10]]),
 array([[16],
        [ 5],
        [10]]),
 array([[16],
        [ 5],
        [10]]),
 array([[16],
        [ 5],
        [10]]),
 array([[16],
        [ 5],
        [10]]),
 array([[16],
        [ 5],
        [10]]),
 array([[16],
        [ 5],
        [10]]),
 array([[16],
        [ 5],
        [10]]),
 array([[16],
        [ 5],
        [10]]),
 array([[16],
        [ 5],
        [10]]),
 array([[16],
        [ 5],
        [10]]),
 array([[16],
        [ 5],
        [10]]),
 array([[16],
        [ 5],
        [10]]),
 array([[16],
        [ 5],
        [10]]),
 array([[16],
        [ 5],
        [10]]),
 array([[16],
        [ 5],
        [10]]),
 array([[16],
        [ 5],
        [10]]),
 array([[16],
        [ 5],
        [10]]),
 array([[16],
        [ 5],
        [10]]),
 array([[16],
        [ 5],
        [10]]),
 array([[16],
        [ 5],
    

In [598]:
# Display the probe inputs: one (3, 1) column array per model input.
list(np.tile([16, 5, 10], (5 + 11 * 7, 1))[..., np.newaxis])

[array([[16],
        [ 5],
        [10]]),
 array([[16],
        [ 5],
        [10]]),
 array([[16],
        [ 5],
        [10]]),
 array([[16],
        [ 5],
        [10]]),
 array([[16],
        [ 5],
        [10]]),
 array([[16],
        [ 5],
        [10]]),
 array([[16],
        [ 5],
        [10]]),
 array([[16],
        [ 5],
        [10]]),
 array([[16],
        [ 5],
        [10]]),
 array([[16],
        [ 5],
        [10]]),
 array([[16],
        [ 5],
        [10]]),
 array([[16],
        [ 5],
        [10]]),
 array([[16],
        [ 5],
        [10]]),
 array([[16],
        [ 5],
        [10]]),
 array([[16],
        [ 5],
        [10]]),
 array([[16],
        [ 5],
        [10]]),
 array([[16],
        [ 5],
        [10]]),
 array([[16],
        [ 5],
        [10]]),
 array([[16],
        [ 5],
        [10]]),
 array([[16],
        [ 5],
        [10]]),
 array([[16],
        [ 5],
        [10]]),
 array([[16],
        [ 5],
        [10]]),
 array([[16],
        [ 5],
    

In [686]:
# Display the first recorded episode: a list of per-step dicts
# ('cur_state', 'action', 'reward', 'status', 'done', ...).
episodes[0]

[{'cur_state': {'food1_col': -0.2,
   'food2_col': -0.6,
   'food1_row': -0.2,
   'food2_row': -0.2,
   'goose_size': 1,
   'longuest_opponent': 1,
   'board': array([[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0., 16.,  0.,  0.,  0.,  0.,  0.,  0.],
          [ 0.,  0., 15.,  0., 15.,  0.,  0.,  0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0.,  0., 16.,  0.,  0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]]),
   'hunger': -1.0,
   'step': -1.0},
  'action': 'NORTH',
  'reward': 0,
  'new_state': '',
  'status': 'ACTIVE',
  'done': False},
 {'cur_state': {'food1_col': -0.6,
   'food2_col': 0.4,
   'food1_row': 0.0,
   'food2_row': 0.0,
   'goose_size': 1,
   'longuest_opponent': 2,
   'board': array([[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
          [ 0.,  0.

In [None]:
def _split_columns(rows, nb_cols):
    """Stack ``rows`` into a 2-D array once and return its first ``nb_cols`` columns as (n, 1) arrays."""
    arr = np.array(rows)
    return [arr[:, i].reshape(-1, 1) for i in range(nb_cols)]


def training_data(episodes, nb_cells=7 * 11, nb_numerical=5):
    """
    Flatten recorded episodes into per-feature arrays for a multi-input model.

    Parameters
    ----------
    episodes : iterable of iterables of dict
        Each step dict must provide the keys 'action', 'numerical',
        'embeddings', 'next_numerical', 'next_embeddings', 'v', 'done'
        and 'reward'.
    nb_cells : int, optional
        Number of embedding columns split out per step (default 7*11).
    nb_numerical : int, optional
        Number of numerical columns split out per step (default 5).

    Returns
    -------
    dict
        'state' / 'next_state': list of ``nb_numerical + nb_cells`` arrays of
        shape (n_steps, 1), numerical columns first, then embedding columns;
        'action': list of the raw actions, one per step;
        'y': targets from ``action_to_target`` reshaped to (n_steps, 4);
        'reward', 'v', 'done': per-step lists.

    Raises
    ------
    ValueError
        If ``episodes`` contains no steps.
    """
    targets = []
    actions = []
    numerical = []
    embeddings = []
    next_numerical = []
    next_embeddings = []
    reward = []
    done = []
    v = []
    for episode in episodes:
        for step in episode:
            action = step['action']
            targets.append(action_to_target(action))
            actions.append(action)
            numerical.append(step['numerical'])
            embeddings.append(step['embeddings'])
            # Append (not rebind) so every transition's next state is kept,
            # aligned index-for-index with the current-state rows.
            next_numerical.append(step['next_numerical'])
            next_embeddings.append(step['next_embeddings'])
            v.append(step['v'])
            done.append(step['done'])
            reward.append(step['reward'])

    if not targets:
        raise ValueError("training_data: episodes contain no steps")

    target_reshaped = np.array(targets).reshape(-1, 4)
    train = (_split_columns(numerical, nb_numerical)
             + _split_columns(embeddings, nb_cells))
    train_next = (_split_columns(next_numerical, nb_numerical)
                  + _split_columns(next_embeddings, nb_cells))

    return {'state': train,
            'action': actions,
            'next_state': train_next,
            'y': target_reshaped,
            'reward': reward,
            'v': v,
            'done': done}

[0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,


In [None]:
# Stack the per-step targets into a (n_steps, 4) matrix for display.
np.reshape(np.array(targets), (-1, 4))

In [660]:
# Show the first embedding column (board cell 0) across all recorded steps.
np.take(np.array(embeddings), 0, axis=1)

array([ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  7,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  5,  5,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0, 15,
        0,  0,  0,  7,  7,  7,  0,  0,  0,  0,  0,  6,  0,  7,  7,  7,  7,
        7,  0,  0,  0,  5,  5,  5,  5,  0,  0,  6,  0,  0,  0,  0, 11,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  8,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0, 14, 13, 11,  9,
        0,  3,  1,  3,  1,  0,  0,  0, 14, 14,  6,  6,  6,  6,  0,  0,  0,
        0, 14,  0,  0,  0,  0,  8, 14, 12,  0, 14, 13, 12])

In [661]:
# Build the embeddings matrix once (instead of once per column) and split its
# first 7*11 columns into a list of (n_steps, 1) arrays, one per model input.
e = np.hsplit(np.array(embeddings)[:, :7 * 11], 7 * 11)

In [666]:
# Build the numerical-feature matrix once and split its first 5 columns into a
# list of (n_steps, 1) arrays, one per model input.
n = np.hsplit(np.array(numerical)[:, :5], 5)

In [668]:
# Model inputs: numerical columns first, then embedding columns.
train = [*n, *e]

In [None]:
from tensorflow.keras.callbacks import Callback
from sklearn.metrics import roc_auc_score, mean_squared_error

def reset_weights(model):
    """
    Re-draw each layer's kernel and bias from that layer's own initializer.

    Works eagerly (TF2) by assigning freshly initialized values to the
    variables, instead of re-running TF1 graph initializer ops through a
    session. Layers without a kernel/bias initializer are left untouched, and
    layers built with ``use_bias=False`` (whose ``bias`` is None) are skipped.
    """
    for layer in model.layers:
        kernel = getattr(layer, 'kernel', None)
        if hasattr(layer, 'kernel_initializer') and kernel is not None:
            kernel.assign(layer.kernel_initializer(kernel.shape, dtype=kernel.dtype))
        bias = getattr(layer, 'bias', None)
        if hasattr(layer, 'bias_initializer') and bias is not None:
            bias.assign(layer.bias_initializer(bias.shape, dtype=bias.dtype))

class perf_callback(Callback):
    """
    Score the model on held-out data at the end of every epoch.

    The score is stored in ``logs['validation']``, intended to be read by a
    later callback in the same list (e.g. early stopping with
    ``monitor='validation'``).

    Parameters
    ----------
    data : tuple
        ``(X, y)`` passed to ``model.predict`` and the scoring function.
    target_type : {'classification', 'regression'}
        'classification' scores with micro-averaged ROC AUC (higher is
        better); 'regression' scores with mean squared error (lower is better).

    Raises
    ------
    ValueError
        If ``target_type`` is not one of the supported values.
    """

    def __init__(self, data, target_type='classification'):
        super().__init__()
        if target_type not in ('classification', 'regression'):
            raise ValueError(
                f"target_type must be 'classification' or 'regression', got {target_type!r}")
        self.X = data[0]
        self.y = data[1]
        self.target_type = target_type

    def on_epoch_end(self, epoch, logs=None):
        # Avoid a shared mutable default; Keras normally passes its own dict.
        if logs is None:
            logs = {}
        y_pred = self.model.predict(self.X)
        if self.target_type == 'classification':
            perf = roc_auc_score(self.y, y_pred, average='micro')
        else:
            perf = mean_squared_error(self.y, y_pred)
        logs['validation'] = perf
        

from tensorflow.keras.callbacks import EarlyStopping

# Early stopping on the held-out score that perf_callback writes to logs['validation'].
# These settings are defined at module level because this cell runs at notebook top level.
EARLY_STOP_PATIENCE = 10  # epochs without improvement before stopping; tune as needed
EARLY_STOP_MODE = 'max'   # ROC AUC: higher is better (use 'min' for a regression/MSE score)

early_stop = EarlyStopping(patience=EARLY_STOP_PATIENCE,
                           monitor='validation',
                           mode=EARLY_STOP_MODE)

In [676]:
from sklearn.model_selection import train_test_split
y = np.array(targets).reshape(-1, 4)

# perf_callback needs held-out data. Split row indices once and apply the same
# split to every per-feature input array in `train` and to the targets, so the
# rows stay aligned. Assumes len(train[i]) == len(y) for every i — verify upstream.
train_idx, val_idx = train_test_split(np.arange(len(y)), test_size=0.2, random_state=0)
X_train_adj = [feature[train_idx] for feature in train]
X_val_adj = [feature[val_idx] for feature in train]
y_train_adj, y_val_adj = y[train_idx], y[val_idx]

callbacks = [
    perf_callback((X_val_adj, y_val_adj), 'classification'),
    early_stop]

ModuleNotFoundError: No module named 'sklearn'

In [675]:
# Fit the model on all inputs/targets for up to 100 epochs with the callbacks above.
m.fit(x=train,
      y=np.array(targets).reshape(-1, 4),
      batch_size=32,
      epochs=100,
      callbacks=callbacks)

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78

Epoch 85/100
Epoch 86/100
Epoch 87/100
Epoch 88/100
Epoch 89/100
Epoch 90/100
Epoch 91/100
Epoch 92/100
Epoch 93/100
Epoch 94/100
Epoch 95/100
Epoch 96/100
Epoch 97/100
Epoch 98/100
Epoch 99/100
Epoch 100/100


<tensorflow.python.keras.callbacks.History at 0x7ff291e9eb20>

In [612]:
# Display the collected per-step target arrays (4 entries each).
targets

[array([0., 1., 0., 0.]),
 array([0., 1., 0., 0.]),
 array([0., 0., 0., 1.]),
 array([0., 1., 0., 0.]),
 array([0., 0., 0., 1.]),
 array([0., 1., 0., 0.]),
 array([0., 0., 0., 1.]),
 array([1., 0., 0., 0.]),
 array([0., 0., 1., 0.]),
 array([0., 0., 1., 0.]),
 array([1., 0., 0., 0.]),
 array([0., 0., 1., 0.]),
 array([1., 0., 0., 0.]),
 array([0., 0., 0., 1.]),
 array([0., 1., 0., 0.]),
 array([0., 0., 1., 0.]),
 array([0., 1., 0., 0.]),
 array([0., 0., 0., 1.]),
 array([0., 1., 0., 0.]),
 array([0., 0., 0., 1.]),
 array([0., 1., 0., 0.]),
 array([0., 0., 0., 1.]),
 array([0., 0., 0., 1.]),
 array([1., 0., 0., 0.]),
 array([1., 0., 0., 0.]),
 array([0., 0., 1., 0.]),
 array([0., 0., 1., 0.]),
 array([0., 0., 1., 0.]),
 array([0., 0., 1., 0.]),
 array([0., 0., 1., 0.]),
 array([0., 1., 0., 0.]),
 array([0., 1., 0., 0.]),
 array([0., 0., 0., 1.]),
 array([0., 1., 0., 0.]),
 array([0., 1., 0., 0.]),
 array([0., 1., 0., 0.]),
 array([0., 1., 0., 0.]),
 array([0., 1., 0., 0.]),
 array([0., 