In [None]:
import numpy as np
from scipy.special import softmax

In [None]:
# eps_: margin used when clipping probabilities away from 0 and 1 before logit.
# max_: magnitude of the lower clipping bound applied inside clip_exp.
eps_ = 1e-12
max_ = 1e12


def sigmoid(x):
    '''Logistic function 1 / (1 + exp(-x)).

    The exponential is evaluated through clip_exp, which is looked up
    at call time (it is defined further down in this cell).
    '''
    return 1.0 / (1.0 + clip_exp(-x))


def logit(p):
    '''Log-odds of p, computed as log(p) - log1p(-p).'''
    return np.log(p) - np.log1p(-p)

def clip_exp(x):
    '''Exponential with a bounded argument and small results flushed to zero.

    The input is clipped to the interval [-max_, 50] before np.exp is
    applied; any result that is not greater than 1e-11 is replaced by 0.
    '''
    exp_vals = np.exp(np.clip(x, a_min=-max_, a_max=50))
    keep = exp_vals > 1e-11
    return np.where(keep, exp_vals, 0)

class simpleBuffer:
    '''Simple Buffer 2.0

    Holds the most recently pushed record as a dict.

    Update log:
        Storage was switched from a list to a dict to prevent
        naive writing mistakes.
    '''

    def __init__(self):
        # The buffer starts empty; push() replaces its whole content.
        self.m = {}

    def push(self, m_dict):
        '''Replace the stored record with a shallow copy of m_dict.'''
        snapshot = {}
        for key in m_dict.keys():
            snapshot[key] = m_dict[key]
        self.m = snapshot

    def sample(self, *args):
        '''Look up the given keys in the stored record.

        Returns the bare value when exactly one key is requested,
        otherwise a list of values in the order the keys were given.
        '''
        picked = [self.m[key] for key in args]
        return picked[0] if len(picked) == 1 else picked

In [None]:
class MF():
    '''Model-free agent: tabular Q-values with a softmax policy.

    Q[s, a] is updated from the reward prediction error
    r + gamma * Q[s_next, a_next] - Q[s, a] (the bootstrap term is
    dropped when the transition is flagged as done).
    '''
    name = 'Model Free'
    bnds = [(0,1),(0,5)]          # parameter bounds
    pbnds = [(.1,.5),(.1,2)]      # sampling bounds
    p_name   = ['alpha', 'beta']  # parameter names
    n_params = len(p_name)

    # Transforms applied to the raw values handed to the constructor:
    # a sigmoid scaled to (0, 1) for alpha and to (0, 5) for beta.
    p_trans = [lambda x: 0.0 + (1 - 0.0) * sigmoid(x),
               lambda x: 0.0 + (5 - 0.0) * sigmoid(x)]
    # Inverse direction: rescale to the unit interval, clip by eps_, take the logit.
    p_links = [lambda y: logit(np.clip((y - 0.0) / (1 - 0.0), eps_, 1 - eps_)),
               lambda y: logit(np.clip((y - 0.0) / (5 - 0.0), eps_, 1 - eps_))]

    gamma = 1

    def __init__(self, env, params):
        '''Store the environment, then set up memory, Q-table and parameters.

        env must expose nS and nA; params holds the raw values that are
        passed through p_trans (alpha first, beta second).
        '''
        self.env = env
        self.gamma = MF.gamma
        self._init_mem()
        self._init_critic()
        self._load_params(params)

    def _init_mem(self):
        '''Create the buffer that learn() reads the latest transition from.'''
        self.mem = simpleBuffer()

    def _load_params(self, params):
        '''Transform the raw parameters and store them as alpha and beta.'''
        transformed = [fn(p) for p, fn in zip(params, self.p_trans)]
        self.alpha, self.beta = transformed[0], transformed[1]  # learning rate, inverse temperature

    def _init_critic(self):
        '''Start with an all-zero Q-table of shape (nS, nA).'''
        self.Q = np.zeros([self.env.nS, self.env.nA])

    # ----------- decision ----------- #

    def policy(self, s):
        '''Softmax over beta * Q[s, :].'''
        return softmax(self.beta * self.Q[s, :])

    def eval_act(self, s, a):
        '''Probability that the softmax policy assigns to action a in state s.'''
        action_probs = self.policy(s)
        return action_probs[int(a)]

    # ----------- learning ----------- #

    def learn(self):
        '''Apply one Q-update from the transition stored in self.mem.

        Returns (RPE, Q_old, Q), where Q_old is Q[s, a] before the update
        and Q is the table itself, not a copy.
        '''
        fields = ('s', 'a', 's_next', 'a_next', 'r', 'done')
        s, a, s_next, a_next, r, done = self.mem.sample(*fields)

        # Bootstrap from the next state-action value unless the episode ended.
        target = r + self.gamma*self.Q[s_next,a_next] if done != True else r
        self.RPE = target - self.Q[s,a]

        # Q-update
        self.Q_old = self.Q[s,a]
        self.Q[s,a] = self.Q_old + self.alpha*self.RPE

        return self.RPE, self.Q_old, self.Q

    def bw_update(self, g):
        '''Return the Q-table unchanged; g is not used by this agent.'''
        return self.Q

In [None]:
class MB():
    '''Model-based agent with a learned transition belief.

    The agent keeps T_bel[s, a, s'], initialised uniformly over s', and a
    Q-table. learn() updates the belief for the observed transition and
    recomputes Q[s, a] as the belief-weighted sum of reward plus the best
    Q-value of each successor. bw_update() recomputes Q level by level
    after the reward target has changed.
    '''
    name = 'Model Based raw'
    bnds = [(0,1),(0,5)]          # parameter bounds
    pbnds = [(.1,.5),(.1,2)]      # sampling bounds
    p_name   = ['alpha', 'beta']  # parameter names
    n_params = len(p_name)

    # Transforms applied to the raw values handed to the constructor:
    # a sigmoid scaled to (0, 1) for alpha and to (0, 5) for beta.
    p_trans = [lambda x: 0.0 + (1 - 0.0) * sigmoid(x),
               lambda x: 0.0 + (5 - 0.0) * sigmoid(x)]
    # Inverse direction: rescale to the unit interval, clip by eps_, take the logit.
    p_links = [lambda y: logit(np.clip((y - 0.0) / (1 - 0.0), eps_, 1 - eps_)),
               lambda y: logit(np.clip((y - 0.0) / (5 - 0.0), eps_, 1 - eps_))]

    def __init__(self,env,params):
        '''Store the environment, then set up memory, model, Q-table and parameters.

        env must expose nS, nA, R, level and set_reward(g); params holds the
        raw values that are passed through p_trans (alpha first, beta second).
        '''
        self.env = env
        self._init_mem()
        self._init_env_model()
        self._init_critic()
        self._load_params(params)

    def _init_mem(self):
        '''Create the buffer that learn() reads the latest transition from.'''
        self.mem = simpleBuffer()

    def _load_params(self, params):
        '''Transform the raw parameters and store them as alpha and beta.'''
        params = [fn(p) for p, fn in zip(params, self.p_trans)]
        self.alpha = params[0]
        self.beta  = params[1]

    def _init_env_model(self):
        '''Initialise T_bel[s, a, s'] to the uniform value 1/nS.'''
        n_states = self.env.nS
        self.T_bel = np.full([n_states, self.env.nA, n_states], 1/n_states)  # (s,a,s')

    def _init_critic(self):
        '''Start with an all-zero Q-table of shape (nS, nA).'''
        self.Q = np.zeros([self.env.nS, self.env.nA])

    # ----------- decision ----------- #

    def policy(self, s):
        '''Softmax over beta * Q[s, :].'''
        q = self.Q[s,:]
        return softmax(self.beta*q)

    def eval_act(self, s, a):
        '''Evaluate the probability of given state and action
        '''
        q = self.Q[s, :]
        prob = softmax(self.beta*q)
        return prob[int(a)]

    # ----------- learning ----------- #
    def learn(self):
        '''Update the transition belief and Q[s, a] from the stored transition.

        Returns (SPE, Q_old, Q). SPE is 1 - T_bel[s, a, s_next] before the
        update, Q_old is Q[s, a] before the update and Q is the table
        itself, not a copy. For s >= 5 nothing is learned; SPE is then 0.0
        and Q_old is the current Q[s, a], so values from an earlier call
        are never returned.
        '''
        s, a, s_next, r, g = self.mem.sample(
                        's','a','s_next','r','g')

        self.env.set_reward(g)

        # NOTE(review): 5 looks like the number of states with outgoing
        # transitions in the task -- confirm against the environment.
        if s < 5:
            self.SPE = 1-self.T_bel[s,a,s_next]
            # T-update: raise T(s,a,s') and shrink every other T(s,a,-) by
            # (1 - alpha), which keeps the row sum unchanged.
            self.T_bel[s,a,s_next] = self.T_bel[s,a,s_next] + self.alpha * self.SPE
            array_rest = np.where(np.arange(self.env.nS) != s_next)[0] # the rest of the states
            self.T_bel[s,a,array_rest] = self.T_bel[s,a,array_rest]*(1-self.alpha)

            # Belief-weighted value: the observed reward for s_next, the
            # environment's reward table for all other successors.
            Q_sum = self.T_bel[s,a,s_next]*(r+max(self.Q[s_next,:]))
            for j in array_rest:
                Q_sum = Q_sum+self.T_bel[s,a,j]*(self.env.R[j]+max(self.Q[j,:]))

            self.Q_old = self.Q[s,a]
            self.Q[s,a] = Q_sum
        else:
            # No update for this state: report a zero error and the untouched value.
            self.SPE = 0.0
            self.Q_old = self.Q[s,a]
        return self.SPE,self.Q_old,self.Q

    def bw_update(self,g): #update all Q(s) when switch reward target
        '''Recompute Q from the highest level down to level 1 for reward target g.

        Also stores the change of the Q-table in self.dQ, its Euclidean norm
        in self.dQ_bwd_energy and its mean in self.dQ_mean_energy.
        Returns the Q-table itself, not a copy.
        '''
        self.env.set_reward(g)

        # Snapshot by value; binding self.Q directly would alias the table
        # and make every difference below zero.
        Q_bwd_before = self.Q.copy()
        level = np.asarray(self.env.level)
        for i in range(max(self.env.level),0,-1):
            states = np.where(level==i)[0]
            for current_A in range(self.env.nA):
                # Values are rebuilt per action so that Q entries written for
                # the previous action are already visible.
                values = np.array([self.env.R[j]+max(self.Q[j,:])
                                   for j in range(self.env.nS)])
                self.Q[states,current_A] = self.T_bel[states,current_A,:] @ values
        self.dQ = self.Q - Q_bwd_before
        self.dQ_bwd_energy = np.sqrt(np.sum(self.dQ**2))
        self.dQ_mean_energy = np.mean(self.dQ)
        return self.Q

In [None]:
class MB4MDT():
    '''Model-based agent with a hard-coded successor structure.

    self.T lists, for each of the two actions, the two successor states of
    states 0-4; the belief T_bel starts at 0.5 for each listed successor
    and 0 elsewhere. learn() updates the belief for the observed transition
    and recomputes Q[s, a]; bw_update() recomputes Q level by level after
    the reward target has changed.
    '''
    name = 'Model Based'
    bnds = [(0,1),(0,5)]          # parameter bounds
    pbnds = [(.1,.5),(.1,2)]      # sampling bounds
    p_name   = ['alpha', 'beta']  # parameter names
    n_params = len(p_name)

    # Transforms applied to the raw values handed to the constructor:
    # a sigmoid scaled to (0, 1) for alpha and to (0, 5) for beta.
    p_trans = [lambda x: 0.0 + (1 - 0.0) * sigmoid(x),
               lambda x: 0.0 + (5 - 0.0) * sigmoid(x)]
    # Inverse direction: rescale to the unit interval, clip by eps_, take the logit.
    p_links = [lambda y: logit(np.clip((y - 0.0) / (1 - 0.0), eps_, 1 - eps_)),
               lambda y: logit(np.clip((y - 0.0) / (5 - 0.0), eps_, 1 - eps_))]


    def __init__(self,env,params):
        '''Store the environment, then set up memory, model, Q-table and parameters.

        env must expose nS, nA, R, level and set_reward(g); params holds the
        raw values that are passed through p_trans (alpha first, beta second).
        '''
        self.env = env
        self._init_mem()
        self._init_env_model()
        self._init_critic()
        self._load_params(params)

    def _init_mem(self):
        '''Create the buffer that learn() reads the latest transition from.'''
        self.mem = simpleBuffer()

    def _load_params(self, params):
        '''Transform the raw parameters and store them as alpha and beta.'''
        params = [fn(p) for p, fn in zip(params, self.p_trans)]
        self.alpha = params[0]
        self.beta  = params[1]

    def _init_env_model(self):
        '''Build the successor table self.T and the initial belief T_bel.'''
        self.T_bel = np.zeros(
            [self.env.nS, self.env.nA, self.env.nS] #(s,a,s')
            )
        # self.T[a][s] holds the two possible successors of state s under action a.
        self.T = {0:[[1,2], [6,7], [7,8], [6,5], [6,8]],
                  1:[[3,4], [7,8], [6,8], [5,8], [8,5]]}

        for a in range(self.env.nA):
            for s, successors in enumerate(self.T[a]):
                # Both listed successors start out equally likely.
                self.T_bel[s,a,successors] = [0.5,1-0.5]

    def _init_critic(self):
        '''Start with an all-zero Q-table of shape (nS, nA).'''
        self.Q = np.zeros([self.env.nS, self.env.nA])

    # ----------- decision ----------- #

    def policy(self, s):
        '''Softmax over beta * Q[s, :].'''
        q = self.Q[s,:]
        return softmax(self.beta*q)

    def eval_act(self, s, a):
        '''Evaluate the probability of given state and action
        '''
        q = self.Q[s, :]
        prob = softmax(self.beta*q)
        return prob[int(a)]

    # ----------- learning ----------- #
    def learn(self):
        '''Update the transition belief and Q[s, a] from the stored transition.

        Returns (SPE, Q_old, Q). SPE is 1 - T_bel[s, a, s_next] before the
        update, Q_old is Q[s, a] before the update and Q is the table
        itself, not a copy. For s >= 5 nothing is learned; SPE is then 0.0
        and Q_old is the current Q[s, a], so values from an earlier call
        are never returned.
        '''
        s, a, s_next, r, g = self.mem.sample(
                        's','a','s_next','r','g')

        self.env.set_reward(g)

        # Only states 0-4 have an entry in self.T.
        if s < 5:
            self.SPE = 1-self.T_bel[s,a,s_next]
            # T-update: raise T(s,a,s') and shrink the other listed
            # successor(s) by (1 - alpha), which keeps the row sum unchanged.
            self.T_bel[s,a,s_next] = self.T_bel[s,a,s_next] + self.alpha * self.SPE
            s_unchosen = [elem for elem in self.T[a][s] if elem != s_next] # the rest of the states
            for j in s_unchosen:
                self.T_bel[s,a,j] = self.T_bel[s,a,j]*(1-self.alpha)

            # Q update: the observed reward for s_next, the environment's
            # reward table for the successors that were not reached.
            Q_sum = self.T_bel[s,a,s_next]*(r+max(self.Q[s_next,:]))
            for j in s_unchosen:
                Q_sum = Q_sum+self.T_bel[s,a,j]*(self.env.R[j]+max(self.Q[j,:]))

            self.Q_old = self.Q[s,a]
            self.Q[s,a] = Q_sum
        else:
            # No update for this state: report a zero error and the untouched value.
            self.SPE = 0.0
            self.Q_old = self.Q[s,a]
        return self.SPE,self.Q_old,self.Q

    def bw_update(self,g): #update all Q(s) when switch reward target
        '''Recompute Q from the highest level down to level 1 for reward target g.

        Also stores the change of the Q-table in self.dQ, its Euclidean norm
        in self.dQ_bwd_energy and its mean in self.dQ_mean_energy.
        Returns the Q-table itself, not a copy.
        '''
        self.env.set_reward(g)

        # Snapshot by value; binding self.Q directly would alias the table
        # and make every difference below zero.
        Q_bwd_before = self.Q.copy()
        level = np.asarray(self.env.level)
        for i in range(max(self.env.level),0,-1):
            states = np.where(level==i)[0]
            for current_A in range(self.env.nA):
                # Values are rebuilt per action so that Q entries written for
                # the previous action are already visible.
                values = np.array([self.env.R[j]+max(self.Q[j,:])
                                   for j in range(self.env.nS)])
                self.Q[states,current_A] = self.T_bel[states,current_A,:] @ values
        self.dQ = self.Q - Q_bwd_before
        self.dQ_bwd_energy = np.sqrt(np.sum(self.dQ**2))
        self.dQ_mean_energy = np.mean(self.dQ)
        return self.Q