https://www.youtube.com/watch?v=IS0V8z8HXrM&list=PL-9x0_FO_lgkwi8ES611NsV-cjYaH_nLa&index=2

In [1]:
from YambEnv import ROW, COL, YambEnv, Action
from tensorflow.keras.layers import Dense, Activation, Input
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import CategoricalCrossentropy
import tensorflow.keras.backend as K
import numpy as np
from functools import lru_cache
from typing import List, Tuple

In [2]:
# Enumerate every way of dropping k indistinguishable balls into n buckets.
# With n=6 and k=5 this gives every length-6 count array summing to 5, i.e.
# (presumably) how many of the five dice show each of the six faces.
@lru_cache(maxsize=10)
def dp(n, k):
    """Return the set of all n-tuples of non-negative ints whose sum is k.

    Every solution for ``k`` is some solution for ``k - 1`` with exactly one
    bucket incremented, so the set is built from the ``k - 1`` result.
    Results are memoised by ``lru_cache``: the returned set is the cached
    object itself and must not be mutated by callers.
    """
    if k == 0:
        return {(0,) * n}
    return {
        prev[:slot] + (prev[slot] + 1,) + prev[slot + 1:]
        for prev in dp(n, k - 1)
        for slot in range(n)
    }


In [3]:
# Every possible count array of dice (length 6, summing to 5), in descending
# lexicographic order. A tuple's position in this list is its index in the
# one-hot roll encoding and in the keep-actions built by the agents, so the
# order must stay stable across runs.
COUNT_ARRAYS = sorted(dp(6, 5), reverse=True)

In [4]:
class Agent:
    """Policy-gradient agent responsible for one roll of a Yamb turn.

    ``agent_type`` must equal the ``roll_number`` of every observation the
    agent receives, and selects its action space:

    * type 1 -- a count array of dice to keep, plus either an announced row
      or no announcement;
    * type 2 -- a count array of dice to keep;
    * type 3 -- the (row, column) cell to fill.

    Each agent owns a one-hidden-layer softmax policy network over its own
    discrete action space. Transitions are buffered by ``store_transition``
    and consumed by ``learn``, which weights a categorical cross-entropy loss
    by mean/std-normalised discounted returns (REINFORCE with a baseline).
    """

    def __init__(self, agent_type : int):
        assert 1 <= agent_type <= 3, "Agent must be of type 1, 2 or 3"
        self.agent_type = agent_type
        self.discount_rate = 0.99
        self.learning_rate = 0.01
        self.action_space = self._build_action_space()
        # id(action) -> position in action_space. choose_action returns the very
        # objects stored in action_space, so learn() can find an action's index
        # in O(1) without relying on Action.__eq__ (Actions hold numpy arrays).
        self._action_index_by_id = {id(a): i for i, a in enumerate(self.action_space)}
        self.input_dim, self.output_dim = self._get_dims()
        assert len(self.action_space) == self.output_dim, "Action space and output of network must be of same dimension"
        # Per-episode buffers, filled by store_transition and cleared by learn.
        self.state_memory = []
        self.action_memory = []
        self.reward_memory = []
        self.policy, self.predict = self._build_policy_network()

        self.model_file = 'agent_{}'.format(agent_type)

    def _get_dims(self) -> Tuple[int, int]:
        """Get the dimensions of the network input and output.

        The input concatenates the flattened 14x4 score grid, a one-hot
        encoding of the current roll over COUNT_ARRAYS and, for agent types
        2 and 3, a 15-way one-hot of the announced row (index 14 means "no
        announcement"). The output has one unit per action.

        :return: input_dim, output_dim
        """
        grid_dim = 14 * 4
        announced_dim = 15  # first 14 tell us which row we announced, 15th tells us we did not announce
        roll_input_dim = len(COUNT_ARRAYS)

        input_dim = grid_dim + roll_input_dim
        if self.agent_type in (2, 3):
            input_dim += announced_dim
        return input_dim, len(self.action_space)

    def _convert_observation_to_input(self, observation : dict) -> np.ndarray:
        """Convert an environment observation into a network input row.

        :param observation: dictionary from the environment; reads the keys
            "roll_number", "grid", "roll" and, for types 2 and 3,
            "announced" and "announced_row"
        :return: float array of shape (1, input_dim)
        """
        assert self.agent_type == observation["roll_number"], "Agent type should match roll number"
        # Filled cells are scaled down by 100; empty cells (NaN) become -1.
        grid = np.nan_to_num(observation["grid"].flatten() / 100, nan=-1)

        # One-hot without materialising a full identity matrix per call.
        roll = np.zeros(len(COUNT_ARRAYS))
        roll[COUNT_ARRAYS.index(tuple(observation["roll"]))] = 1
        parts = [grid, roll]

        if self.agent_type in (2, 3):
            announced = np.zeros(15)
            announced[observation["announced_row"].value if observation["announced"] else 14] = 1
            parts.append(announced)

        return np.hstack(parts)[np.newaxis, :]

    def _build_action_space(self) -> List[Action]:
        """Build the action space of this agent based on its agent type.

        The order is significant: position ``i`` corresponds to output unit
        ``i`` of the policy network.

        :return: list of actions which are consumable by the environment
        """
        if self.agent_type == 1:
            # All (keep, announce row) pairs first, then keep-without-announcing.
            announcing = [
                Action(roll_number=1, keep=np.array(ca), announce=True, announce_row=row)
                for ca in COUNT_ARRAYS
                for row in ROW
            ]
            silent = [Action(roll_number=1, keep=np.array(ca), announce=False) for ca in COUNT_ARRAYS]
            return announcing + silent

        if self.agent_type == 2:
            return [Action(roll_number=2, keep=np.array(ca)) for ca in COUNT_ARRAYS]

        if self.agent_type == 3:
            return [
                Action(roll_number=3, keep=np.zeros(6), row_to_fill=row, col_to_fill=col)
                for row in ROW
                for col in COL
            ]

        return []

    def _build_policy_network(self) -> Tuple:
        """Build the policy network for this agent (any agent type).

        Architecture: input -> Dense(100, relu) -> Dense(output_dim, softmax).

        :return: (policy, predict). ``policy`` is compiled with Adam and
            categorical cross-entropy and is the model that gets trained.
            ``predict`` is an uncompiled Model over the same layers, so it
            always reflects the policy's current weights.
        """
        x = Input(shape=(self.input_dim,))
        dense = Dense(100, activation='relu')(x)
        probs = Dense(self.output_dim, activation='softmax')(dense)
        policy = Model(inputs=x, outputs=probs, name='policy_network_{}'.format(self.agent_type))
        policy.compile(optimizer=Adam(learning_rate=self.learning_rate), loss=CategoricalCrossentropy())
        predict = Model(inputs=x, outputs=probs)
        return policy, predict

    def choose_action(self, observation):
        """Sample an action from the current policy.

        :param observation: observation from the environment (converted
            internally to the network's input format)
        :return: one of the Action objects in ``self.action_space``
        """
        input_to_network = self._convert_observation_to_input(observation)
        probs = self.predict.predict(input_to_network, verbose=0)[0].astype(np.float64)
        # The softmax output is float32; renormalising in float64 guards
        # np.random.choice's sum-to-one check against rounding error.
        probs /= probs.sum()
        # Sample an index rather than passing the Action list itself, which
        # numpy would otherwise convert into an array on every call.
        idx = np.random.choice(len(self.action_space), p=probs)
        return self.action_space[idx]

    def store_transition(self, observation, action, reward):
        """Buffer one (state, action, reward) transition for the next ``learn`` call."""
        self.state_memory.append(self._convert_observation_to_input(observation)[0])
        self.action_memory.append(action)
        self.reward_memory.append(reward)

    def _action_index(self, action) -> int:
        """Return the position of ``action`` in ``self.action_space``.

        Uses the identity map when ``action`` is one of our own objects, and
        otherwise falls back to an equality search.
        """
        idx = self._action_index_by_id.get(id(action))
        if idx is not None:
            return idx
        return self.action_space.index(action)

    def learn(self, truncated):
        """Run one policy-gradient update on the buffered episode, then clear it.

        :param truncated: if true, every advantage is clamped to at most
            -1000, strongly discouraging all actions taken in the episode
        :return: the loss reported by ``train_on_batch``, or None when no
            transitions were buffered
        """
        if not self.reward_memory:
            return None

        state_memory = np.array(self.state_memory)
        # float64 so discounted returns are not truncated when rewards are ints.
        reward_memory = np.asarray(self.reward_memory, dtype=np.float64)

        actions = np.zeros((len(self.action_memory), len(self.action_space)))
        for i, a in enumerate(self.action_memory):
            actions[i, self._action_index(a)] = 1

        # Discounted return G_t = r_t + gamma * G_{t+1}, accumulated backwards.
        G = np.zeros_like(reward_memory)
        running = 0.0
        for t in reversed(range(len(reward_memory))):
            running = reward_memory[t] + self.discount_rate * running
            G[t] = running

        std = np.std(G)
        adv = (G - np.mean(G)) / (std if std > 0 else 1)
        if truncated:
            adv = np.minimum(adv, -1000)

        cost = self.policy.train_on_batch(x=state_memory, y=actions, sample_weight=adv)

        self.state_memory = []
        self.action_memory = []
        self.reward_memory = []

        return cost

    def save_model(self):
        """Save the policy network to ``self.model_file``."""
        # NOTE(review): Keras 3 requires a ".keras"/".h5" suffix on this path; a
        # bare name only works with the TF2 SavedModel format -- confirm version.
        self.policy.save(self.model_file)

    def load_model(self):
        """Load the policy network from ``self.model_file``."""
        self.policy = load_model(self.model_file)
        # choose_action samples from self.predict, so point it at the loaded
        # network; otherwise it would keep using the freshly initialised weights.
        self.predict = self.policy

In [None]:
# Training loop: three per-roll agents play Yamb episodes together; each one is
# updated once per episode from the transitions it stored.
env = YambEnv()
n_episodes = 1
agent_1 = Agent(1)
agent_2 = Agent(2)
agent_3 = Agent(3)
agents = {1: agent_1, 2: agent_2, 3: agent_3}

score_history = []
for i in range(n_episodes):
    observation = env.reset()
    truncated, terminated = False, False
    # Per-turn buffers keyed by roll number: the observation each agent acted
    # on, the action it chose, and the reward accumulated since it acted.
    observations = {1: None, 2: None, 3: None}
    actions = {1: None, 2: None, 3: None}
    rewards = {1: 0, 2: 0, 3: 0}
    score = 0
    roll_number = 1
    while not(terminated or truncated):
        print("Roll number: {}".format(roll_number))
        # Resample until the environment accepts the action (truncated is False).
        # NOTE(review): this spins forever if the env keeps truncating, and it
        # means `truncated` is always False once the loop exits.
        truncated = True
        while truncated:
            action = agents[roll_number].choose_action(observation)
            observation_new, reward, terminated, truncated, truncation_reason = env.step(action)

        actions[roll_number] = action
        observations[roll_number] = observation
        observation = observation_new
        score += reward

        if roll_number==3:
            # Turn finished: every agent is credited with the reward earned
            # from its own roll through the end of the turn.
            print(observation["grid"])
            print(observation["score"])
            rewards[1] += reward
            rewards[2] += reward
            rewards[3] += reward
            agents[1].store_transition(observations[1], actions[1], rewards[1])
            agents[2].store_transition(observations[2], actions[2], rewards[2])
            agents[3].store_transition(observations[3], actions[3], rewards[3])
            roll_number = 1
            observations = {1: None, 2: None, 3: None}
            actions = {1: None, 2: None, 3: None}
            rewards = {1: 0, 2: 0, 3: 0}
        elif roll_number == 2:
            rewards[1] += reward
            rewards[2] += reward
            roll_number += 1
        elif roll_number == 1:
            print("Announced {}".format(action.announce))
            rewards[1] += reward
            roll_number += 1

    # Flush a partially played turn. Only reachable if the retry loop above is
    # changed to let truncation end the episode.
    if truncated:
        if observations[1] is not None:
            agents[1].store_transition(observations[1], actions[1], rewards[1])
        if observations[2] is not None:
            agents[2].store_transition(observations[2], actions[2], rewards[2])
        if observations[3] is not None:
            agents[3].store_transition(observations[3], actions[3], rewards[3])

    score_history.append(score)
    agent_1.learn(truncated)
    if len(agent_2.state_memory) > 0:
        agent_2.learn(truncated)
    if len(agent_3.state_memory) > 0:
        agent_3.learn(truncated)
    # Mean over the most recent (up to) 100 episodes; dividing by a fixed 100
    # under-reports the average while fewer than 100 episodes have been played.
    print("Episode: {}, Score: {}, Average score: {}".format(i, score, np.mean(score_history[-100:])))
        
    

Roll number: 1
Announced True
Roll number: 2
Roll number: 3
[[nan nan nan nan]
 [nan nan nan  0.]
 [nan nan nan nan]
 [nan nan nan nan]
 [nan nan nan nan]
 [nan nan nan nan]
 [nan nan nan nan]
 [nan nan nan nan]
 [nan nan nan nan]
 [nan nan nan nan]
 [nan nan nan nan]
 [nan nan nan nan]
 [nan nan nan nan]
 [nan nan nan nan]]
0.0
Roll number: 1
Announced True
Roll number: 2
Roll number: 3
[[nan nan nan nan]
 [nan nan nan  0.]
 [nan nan nan nan]
 [nan nan nan nan]
 [nan nan nan  5.]
 [nan nan nan nan]
 [nan nan nan nan]
 [nan nan nan nan]
 [nan nan nan nan]
 [nan nan nan nan]
 [nan nan nan nan]
 [nan nan nan nan]
 [nan nan nan nan]
 [nan nan nan nan]]
5.0
Roll number: 1
Announced True
Roll number: 2
Roll number: 3
[[nan nan nan nan]
 [nan nan nan  0.]
 [nan nan nan  3.]
 [nan nan nan nan]
 [nan nan nan  5.]
 [nan nan nan nan]
 [nan nan nan nan]
 [nan nan nan nan]
 [nan nan nan nan]
 [nan nan nan nan]
 [nan nan nan nan]
 [nan nan nan nan]
 [nan nan nan nan]
 [nan nan nan nan]]
8.0
Roll nu

Announced False
Roll number: 2
Roll number: 3
[[ 3. nan  0.  0.]
 [nan nan  0.  0.]
 [nan nan  3.  3.]
 [nan nan nan  0.]
 [nan nan 10.  5.]
 [nan nan  0.  0.]
 [nan nan nan 11.]
 [nan nan 17. 21.]
 [nan nan  0.  0.]
 [nan nan nan 26.]
 [nan nan  0.  0.]
 [nan nan  0.  0.]
 [nan nan  0.  0.]
 [nan nan  0.  0.]]
50.0
Roll number: 1
Announced False
Roll number: 2
Roll number: 3
[[ 3. nan  0.  0.]
 [ 2. nan  0.  0.]
 [nan nan  3.  3.]
 [nan nan nan  0.]
 [nan nan 10.  5.]
 [nan nan  0.  0.]
 [nan nan nan 11.]
 [nan nan 17. 21.]
 [nan nan  0.  0.]
 [nan nan nan 26.]
 [nan nan  0.  0.]
 [nan nan  0.  0.]
 [nan nan  0.  0.]
 [nan nan  0.  0.]]
52.0
Roll number: 1
Announced False
Roll number: 2
Roll number: 3
[[ 3. nan  0.  0.]
 [ 2. nan  0.  0.]
 [nan nan  3.  3.]
 [nan nan nan  0.]
 [nan nan 10.  5.]
 [nan nan  0.  0.]
 [nan nan nan 11.]
 [nan nan 17. 21.]
 [nan nan  0.  0.]
 [nan nan nan 26.]
 [nan nan  0.  0.]
 [nan nan  0.  0.]
 [nan nan  0.  0.]
 [nan  0.  0.  0.]]
52.0
Roll number: 1
A

3780

In [29]:
# Smoke test: a fresh environment plus an agent for roll number 1.
env = YambEnv()
observation = env.reset()
agent = Agent(agent_type=1)
# A single interaction step, currently disabled:
#   action = agent.choose_action(observation)
#   observation_new, reward, terminated, truncated, truncation_reason = env.step(action)
#   agent.store_transition(observation, action, reward)
#   agent.learn(truncated)

In [32]:
# Sanity check: decode the one-hot roll segment of the network input back into
# a count array and print it next to the raw roll from the environment.
env = YambEnv()
observation = env.reset()
agent = Agent(agent_type=1)
print(observation['roll'])
network_row = agent._convert_observation_to_input(observation)[0]
# For agent type 1 the roll one-hot follows directly after the grid cells
# and runs to the end of the input row.
roll_offset = agent.input_dim - len(COUNT_ARRAYS)
roll_one_hot = list(network_row[roll_offset:])
print(COUNT_ARRAYS[roll_one_hot.index(1)])

[1 0 0 0 2 2]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
(1, 0, 0, 0, 2, 2)


[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]


(0, 2, 2, 1, 0, 0)

[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]


[0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 1.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0