In [None]:
import numpy as np
from scipy.special import softmax

In [None]:
# Numerical guards: eps_ keeps probabilities strictly inside (0, 1) before a
# logit is taken; max_ is the magnitude of the lower clip used by clip_exp.
eps_ = 1e-12
max_ = 1e12


def sigmoid(x):
    '''Logistic function 1 / (1 + exp(-x)), with the exponential taken by clip_exp.'''
    return 1.0 / (1.0 + clip_exp(-x))


def logit(p):
    '''Log-odds of p, evaluated as log(p) - log1p(-p).'''
    return np.log(p) - np.log1p(-p)

def clip_exp(x):
    '''Exponential with a bounded argument and tiny results flushed to zero.

    The argument is limited to the interval [-max_, 50] before
    exponentiation; every result that is not greater than 1e-11 is
    replaced by 0.
    '''
    capped = np.clip(x, -max_, 50)
    result = np.exp(capped)
    keep = result > 1e-11
    return np.where(keep, result, 0)

class simpleBuffer:
    '''Simple Buffer 2.0

    Keeps the most recently pushed record as a dict.

    Update log:
        Storage is a dict rather than a list, so fields are written and
        read by name and naive writing mistakes are avoided.
    '''

    def __init__(self):
        # Starts empty; every push() replaces the whole record.
        self.m = dict()

    def push(self, m_dict):
        '''Replace the stored record with a shallow copy of m_dict.'''
        record = {}
        for key in m_dict.keys():
            record[key] = m_dict[key]
        self.m = record

    def sample(self, *args):
        '''Fetch stored values by key.

        A single key returns the bare value; any other number of keys
        (including none) returns a list of values in the order requested.
        '''
        if len(args) == 1:
            return self.m[args[0]]
        return [self.m[key] for key in args]

In [None]:
class HC():
    '''Two-level (goal / action) value-learning agent.

    Goals are state indices: a goal is drawn from ``range(env.nS)`` and is
    considered reached when the next state equals it.  Two value tables are
    kept:

    * ``Q_goal``   -- indexed ``[state, goal]``
    * ``Q_action`` -- indexed ``[state, action, goal]``

    Both levels choose through a softmax over their values, scaled by
    ``beta_goal`` / ``beta_action``, and are updated with learning rates
    ``alpha_goal`` / ``alpha_action``.

    The environment passed in must provide ``nS``, ``nA`` and
    ``set_reward(g)``.
    '''
    name = 'Hierarchical Control'
    # Bounds of each parameter, in the order given by p_name.
    bnds = [(0,1),(0,5),(0,1),(0,5)]
    # NOTE(review): presumably a narrower range used to draw starting values
    # for fitting -- confirm against the fitting code.
    pbnds = [(.1,.5),(.1,2),(.1,.5),(.1,2)]
    p_name = ['alpha_goal', 'beta_goal', 'alpha_action', 'beta_action']
    n_params = len(p_name)

    # Unconstrained value -> bounded parameter (scaled sigmoid), one per p_name.
    p_trans = [lambda x: 0.0 + (1 - 0.0) * sigmoid(x),
               lambda x: 0.0 + (5 - 0.0) * sigmoid(x),
               lambda x: 0.0 + (1 - 0.0) * sigmoid(x),
               lambda x: 0.0 + (5 - 0.0) * sigmoid(x)]
    # Bounded parameter -> unconstrained value (inverse of p_trans); the ratio
    # is clipped to [eps_, 1 - eps_] so the logit stays finite at the bounds.
    p_links = [lambda y: logit(np.clip((y - 0.0) / (1 - 0.0), eps_, 1 - eps_)),
               lambda y: logit(np.clip((y - 0.0) / (5 - 0.0), eps_, 1 - eps_)),
               lambda y: logit(np.clip((y - 0.0) / (1 - 0.0), eps_, 1 - eps_)),
               lambda y: logit(np.clip((y - 0.0) / (5 - 0.0), eps_, 1 - eps_))]

    def __init__(self, env, params):
        '''Build the agent.

        Args:
            env: environment exposing ``nS``, ``nA`` and ``set_reward``.
            params: unconstrained parameter values in ``p_name`` order; each
                is mapped into its bounded range by the matching ``p_trans``.
        '''
        self.env = env
        self._init_mem()
        self._init_critic()
        self._load_params(params)
        # No goal is active until policy()/eval_act() samples one or learn()
        # adopts the one stored in memory.
        self.current_goal = None

    def _init_mem(self):
        '''Create the buffer that holds the latest transition.'''
        self.mem = simpleBuffer()

    def _load_params(self, params):
        '''Map raw params through p_trans and store them as attributes.'''
        params = [fn(p) for p, fn in zip(params, self.p_trans)]
        self.alpha_goal = params[0]
        self.beta_goal  = params[1]
        self.alpha_action = params[2]
        self.beta_action  = params[3]

    def _init_critic(self):
        '''Zero the value tables and the bookkeeping reported by learn().'''
        n_states, n_actions = self.env.nS, self.env.nA
        self.Q_goal = np.zeros([n_states, n_states])
        self.Q_action = np.zeros([n_states, n_actions, n_states])
        # learn() returns these even when it performs no update; give them
        # defined values so a first call that skips the update cannot raise
        # AttributeError.
        self.goal_PE = 0.0
        self.Q_goal_old = 0.0
        self.action_PE = 0.0
        self.Q_action_old = 0.0

    def select_goal(self, s):
        '''Return the softmax distribution over goals for state s.'''
        q_goal = self.Q_goal[s, :]
        return softmax(self.beta_goal * q_goal)

    def _action_probs(self, s):
        '''Softmax action distribution in state s under the active goal.

        When no goal is active, one is sampled from select_goal(s) first and
        kept in ``current_goal``.
        '''
        if self.current_goal is None:
            goal_probs = self.select_goal(s)
            self.current_goal = np.random.choice(self.env.nS, p=goal_probs)

        q_action = self.Q_action[s, :, self.current_goal]
        return softmax(self.beta_action * q_action)

    def policy(self, s):
        '''Return action probabilities for state s (may sample a new goal).'''
        return self._action_probs(s)

    def eval_act(self, s, a):
        '''Return the probability of action a in state s (may sample a new goal).'''
        return self._action_probs(s)[int(a)]

    def learn(self):
        '''Update both value tables from the transition stored in memory.

        Reads ``'s', 'a', 's_next', 'r', 'g'`` from ``self.mem`` and passes
        ``g`` to ``env.set_reward``.  When ``s < 5`` the action level is
        updated with an intrinsic reward (1.0 if the goal was reached, else
        0.0) and the goal level with the external reward ``r``; reaching the
        goal clears ``current_goal``.

        Returns:
            ``(goal_PE, Q_goal_old, Q_goal)``.  ``Q_goal`` is the live table,
            not a copy.  When no update is performed, ``goal_PE`` and
            ``Q_goal_old`` keep their values from the previous update (0.0 if
            there has been none).
        '''
        s, a, s_next, r, g = self.mem.sample(
            's', 'a', 's_next', 'r', 'g')

        self.env.set_reward(g)

        # NOTE(review): only state indices below 5 are learned from; the
        # threshold is hard-coded -- confirm it matches the environment.
        if s < 5:
            if self.current_goal is None:
                self.current_goal = g
            goal = self.current_goal

            goal_reached = (s_next == goal)
            intrinsic_reward = 1.0 if goal_reached else 0.0

            # Action level: undiscounted one-step target under the active goal.
            self.Q_action_old = self.Q_action[s, a, goal]
            action_target = intrinsic_reward + np.max(self.Q_action[s_next, :, goal])
            self.action_PE = action_target - self.Q_action[s, a, goal]
            self.Q_action[s, a, goal] += self.alpha_action * self.action_PE

            # Goal level: same form of update, driven by the external reward.
            self.Q_goal_old = self.Q_goal[s, goal]
            goal_target = r + np.max(self.Q_goal[s_next, :])
            self.goal_PE = goal_target - self.Q_goal[s, goal]
            self.Q_goal[s, goal] += self.alpha_goal * self.goal_PE

            if goal_reached:
                self.current_goal = None

        return self.goal_PE, self.Q_goal_old, self.Q_goal