In [1]:
from optionCritic.Qlearning import IntraOptionQLearning, IntraOptionActionQLearning
from optionCritic.policies import SoftmaxOptionPolicy, SoftmaxActionPolicy

from distributed.belief import MultinomialDirichletBelief
from distributed.broadcast import Broadcast
from random import shuffle
import numpy as np


class DOC:
    def __init__(self, env, options, mu_policy):
        '''
        :param states_list: all combination of joint states. This is an input from the environment
        :param lr_thea: list of learning rates for learning policy parameters (pi), for all the agents
        :param lr_phi: list of learning rates for learning termination functions (beta), for all the agents
        :param init_observation: list of joint observation of all the agents
        '''

        self.env = env
        self.options = options
    
        '''
        2. Start with initial common belief b
        '''
        # set initial belief
        initial_joint_observation = params['env']['initial_joint_state']
        self.belief = MultinomialDirichletBelief(env, initial_joint_observation)
        #self.b0 = Belief(env)

        '''
        3. Sample a joint state s := vec(s_1,...,s_n) according to b_0
        '''
        self.joint_state = self.belief.sampleJointState()
        print('joint_state',self.joint_state)

        # policy over options
        self.mu_policy = mu_policy
        #print(self.mu_policy)

        self.joint_option = self.chooseOption() #self.chooseOption()
        self.joint_action = self.chooseAction(self.joint_state,self.joint_option)


    def chooseOption(self):
        # Choose joint-option o based on softmax option-policy
        joint_state = tuple(np.sort(self.joint_state))

        joint_option = self.mu_policy.sample(joint_state)
        print('joint_option',joint_option)

        for option in self.options:
            option.available = True

        for option in joint_option:
            self.options[option].available = False

        return joint_option

#     def chooseAction(self):
#         joint_action = []
#         for agent in self.env.agents:
#             print('agent state', agent.state, 'agent option', agent.option)
#             action = self.options[agent.option].policy.sample(agent.state)
#             print('agent ID:', agent.ID, 'state:', agent.state, 'option ID:', agent.option, 'action:', action)
#             agent.action = action
#             joint_action.append(action)

#         return joint_action

    def chooseAction(self, joint_state, joint_option):
        joint_action = []
        for agent in self.env.agents:
            print('agent state', agent.state, 'agent option', agent.option)
            agent.state = joint_state[agent.ID]
            agent.option = joint_option[agent.ID]
            agent_action = self.options[agent.option].policy.sample(agent.state)
            print('agent ID:', agent.ID, 'state:', agent.state, 'option ID:', agent.option, 'agent action:', agent_action)
            agent.action = agent_action
            joint_action.append(agent_action)

        return joint_action


 

    def evaluateOption(self, critic, action_critic, terminations, baseline=False):
        # critic.start(joint_state, joint_option)
        # action_critic.start(joint_state, joint_option, joint_action)
        
        reward, next_true_joint_state, done, _ = self.env.step(joint_action)

        broadcasts = Broadcast(self.env, next_true_joint_state, self.joint_state, self.joint_option,done).broadcastBasedOnQ(critic,reward)

        #broadcasts = self.env.broadcast(reward, next_true_joint_state, self.s, self.o, terminations)
        joint_observation = self.env.get_observation(broadcasts)

        self.belief = MultinomialDirichletBelief(self.env, joint_observation)
        self.joint_state = self.belief.sampleJointState()

        # Critic update
        update_target = critic.update(self.joint_state, self.joint_option, reward, done)
        action_critic.update(self.joint_state, self.joint_option, self.joint_action, reward, done)


        critic_feedback = action_critic.getQvalue(self.joint_state, self.joint_option, self.joint_action)  #Q(s,o,a)

        if baseline:
            critic_feedback -= critic.value(self.joint_state, self.joint_option)
        return critic_feedback


    def improveOption_of_agent(self, agentID, intra_option_policy_improvement, termination_improvement, critic_feedback):
        return intra_option_policy_improvement.update(agent_state, agent_action, critic_feedback), termination_improvement.update(agentID, self.joint_state, self.joint_option)

In [2]:
import itertools
from fourroomsEnv import FourroomsMA
from modelConfig import params
from optionCritic.option import Option, createOptions
from optionCritic.policies import SoftmaxOptionPolicy, SoftmaxActionPolicy
from optionCritic.termination import SigmoidTermination
from optionCritic.Qlearning import IntraOptionQLearning, IntraOptionActionQLearning
import optionCritic.gradients as grads

env = FourroomsMA()
avail_options, mu_policies = createOptions(env)
#print(avail_options[0].optionID, avail_options[0].policy, avail_options[0].termination, avail_options[0].available)

joint_state_list = set([tuple(np.sort(s)) for s in env.states_list])
joint_option_list = list(itertools.permutations(range(params['agent']['n_options']), params['env']['n_agents']))
# joint_action_list = list(itertools.product(range(len(env.agent_actions)), repeat=params['env']['n_agents']))

# mu_policy is the policy over options
mu_weights = dict.fromkeys(joint_state_list, dict.fromkeys(joint_option_list, 0))
# mu_policy = SoftmaxOptionPolicy(mu_weights)


pi_policies = [SoftmaxActionPolicy(len(env.cell_list), len(env.agent_actions)) for _ in avail_options]

# terminations take agent's state (not joint-state)
option_terminations = [SigmoidTermination(len(env.cell_list)) for _ in range(params['agent']['n_options'])]
critic = IntraOptionQLearning(params['env']['discount'], params['doc']['lr_Q'], option_terminations, SoftmaxOptionPolicy(mu_weights).weights)


action_critic = IntraOptionActionQLearning(params['env']['discount'], params['doc']['lr_Q'], option_terminations, SoftmaxActionPolicy(len(env.cell_list), len(env.agent_actions)).weights, critic)

joint_state = DOC(env, avail_options, mu_policies).joint_state
print('joint state', joint_state)
joint_option = DOC(env, avail_options, mu_policies).joint_option
print('joint option', joint_option)
# for agent in env.agents:
#     print('agent ID', agent.ID)
#     agent.state = joint_state[agent.ID]
#     agent.option = joint_option[agent.ID]
    

joint_action = DOC(env, avail_options, mu_policies).joint_action
# for agent in env.agents:
#     agent.action = joint_action[agent.ID]


# test with Agent~1
agent1 = env.agents[0]
print('agent1',agent1, 'agent1 option', agent1.option)

agent1_state = agent1.state
agent1_option = agent1.option
agent1_action = agent1.action

pi_policy_of_agent_option = pi_policies[agent1_option]
agent_option_termination = option_terminations[agent1_option]

intra_option_policy_improvement = grads.IntraOptionGradient(pi_policy_of_agent_option, params['doc']['lr_theta'])
termination_improvement = grads.TerminationGradient(agent_option_termination, critic, params['doc']['lr_phi'])



critic.start(joint_state, joint_option)
action_critic.start(joint_state, joint_option, joint_action)

evalOption = DOC(env, avail_options, mu_policies).evaluateOption(critic, action_critic, option_terminations, baseline=False)
imprvOption = DOC(env, avail_options, mu_policies).improveOption_of_agent(agent1.ID, intra_option_policy_improvement, termination_improvement, evalOption)

joint_state (55, 50, 34)
joint_option (0, 4, 3)
agent state 39 agent option None
agent ID: 0 state: 55 option ID: 0 agent action: 2
agent state 35 agent option None
agent ID: 1 state: 50 option ID: 4 agent action: 1
agent state 66 agent option None
agent ID: 2 state: 34 option ID: 3 agent action: 3
joint state (55, 50, 34)
joint_state (32, 81, 91)
joint_option (3, 4, 1)
agent state 55 agent option 0
agent ID: 0 state: 32 option ID: 3 agent action: 1
agent state 50 agent option 4
agent ID: 1 state: 81 option ID: 4 agent action: 1
agent state 34 agent option 3
agent ID: 2 state: 91 option ID: 1 agent action: 3
joint option (3, 4, 1)
joint_state (6, 90, 80)
joint_option (4, 3, 0)
agent state 32 agent option 3
agent ID: 0 state: 6 option ID: 4 agent action: 3
agent state 81 agent option 4
agent ID: 1 state: 90 option ID: 3 agent action: 1
agent state 91 agent option 1
agent ID: 2 state: 80 option ID: 0 agent action: 2
agent1 <agent.Agent object at 0x111a782e8> agent1 option 4
joint_state (

KeyError: (40, 43, 46)