In [None]:
import numpy as np
from scipy.special import softmax

In [None]:
# Numerical guards: eps_ keeps probabilities away from 0 and 1 before a logit,
# max_ is the magnitude used as the lower clipping bound in clip_exp.
eps_ = 1e-12
max_ = 1e12


def sigmoid(x):
    '''Logistic function 1 / (1 + exp(-x)), with the exponential taken by clip_exp.'''
    return 1.0 / (1.0 + clip_exp(-x))


def logit(p):
    '''Log-odds of p, evaluated as log(p) - log1p(-p).'''
    return np.log(p) - np.log1p(-p)

def clip_exp(x):
    '''Exponential with a bounded argument and tiny results flushed to zero.

    The argument is clipped to the interval [-max_, 50] before np.exp is
    applied, and any result that is not greater than 1e-11 is replaced by 0.
    '''
    bounded = np.clip(x, -max_, 50)
    result = np.exp(bounded)
    keep = result > 1e-11
    return np.where(keep, result, 0)

class simpleBuffer:
    '''Single-slot buffer whose contents are addressed by key.

    The stored items live in the dict `self.m`, so they are retrieved by
    name rather than by position.
    '''

    def __init__(self):
        self.m = {}

    def push(self, m_dict):
        '''Discard the current contents and store a shallow copy of m_dict.'''
        snapshot = {}
        for key in m_dict.keys():
            snapshot[key] = m_dict[key]
        self.m = snapshot

    def sample(self, *args):
        '''Look up the stored values for the given keys.

        A single key yields the bare value; any other number of keys yields
        a list of values in the order the keys were given.
        '''
        values = [self.m[key] for key in args]
        return values[0] if len(values) == 1 else values

In [None]:
class SR:
    '''Successor-representation agent with a softmax policy.

    The agent maintains three tables:
        M: array [nS, nA, nS]; M[s, a, :] is the successor row of the
           state-action pair (s, a), initialised to the one-hot vector of s.
        R: array [nS]; the reward most recently recorded for each state.
        Q: array [nS, nA]; action values, recomputed as the dot product of
           every successor row with R after each update.

    The constructor receives the parameters as unconstrained reals and maps
    them through `p_trans`: alpha (the step size applied to the successor
    row) lands in (0, 1) and beta (the softmax inverse temperature) lands in
    (0, 5). `p_links` holds the inverse mappings.
    '''
    name = 'Successor Representation'
    # NOTE(review): presumably the hard parameter bounds and a narrower range
    # used to draw starting points -- confirm against the fitting code.
    bnds = [(0, 1), (0, 5)]
    pbnds = [(.1, .5), (.1, 2)]
    p_name = ['alpha', 'beta']
    n_params = len(p_name)

    # Unconstrained value -> bounded parameter (scaled sigmoid), one per name.
    p_trans = [lambda x: 0.0 + (1 - 0.0) * sigmoid(x),
               lambda x: 0.0 + (5 - 0.0) * sigmoid(x)]
    # Bounded parameter -> unconstrained value (logit of the rescaled value,
    # clipped away from 0 and 1 so the logit stays finite).
    p_links = [lambda y: logit(np.clip((y - 0.0) / (1 - 0.0), eps_, 1 - eps_)),
               lambda y: logit(np.clip((y - 0.0) / (5 - 0.0), eps_, 1 - eps_))]

    def __init__(self, env, params):
        '''Build the agent for `env` from unconstrained `params`.

        `env` must expose `nS`, `nA` and `set_reward`; `params` is a sequence
        ordered like `p_name`.
        '''
        self.env = env
        self.gamma = 1.0
        self._init_mem()
        self._init_SR()
        self._init_critic()
        self._load_params(params)

    def _init_mem(self):
        '''Create the buffer that holds the latest transition.'''
        self.mem = simpleBuffer()

    def _load_params(self, params):
        '''Map the unconstrained params into their ranges and store them.'''
        params = [fn(p) for p, fn in zip(params, self.p_trans)]
        self.alpha = params[0]
        self.beta = params[1]

    def _init_SR(self):
        '''Initialise M so that every (s, a) row is the one-hot vector of s.'''
        n_states, n_actions = self.env.nS, self.env.nA
        self.M = np.zeros([n_states, n_actions, n_states])
        states = np.arange(n_states)
        self.M[states, :, states] = 1.0
        # Defined up front so learn() can report an error even when its
        # first call performs no update.
        self.SR_error = np.zeros(n_states)

    def _init_critic(self):
        '''Initialise the action values and the reward vector to zero.'''
        self.Q = np.zeros([self.env.nS, self.env.nA])
        self.R = np.zeros(self.env.nS)
        # Defined up front for the same reason as SR_error in _init_SR.
        self.Q_old = 0.0

    def policy(self, s):
        '''Return the softmax action distribution in state s.'''
        return softmax(self.beta * self.Q[s, :])

    def eval_act(self, s, a):
        '''Return the probability the policy assigns to action a in state s.'''
        return self.policy(s)[int(a)]

    def learn(self):
        '''Update M, R and Q from the transition currently held in the buffer.

        Reads 's', 'a', 's_next', 'r' and 'g' from `self.mem` and passes g to
        `env.set_reward`. When s < 5, the reward r is recorded for s_next,
        the successor row M[s, a, :] moves a fraction alpha towards
        one_hot(s) + gamma * M[s_next, a_next, :], where a_next is the greedy
        action in s_next, and Q is recomputed for every state-action pair.

        Returns:
            A tuple (mean absolute successor error, Q[s, a] before the
            update, the Q table). When no update is performed the first two
            entries are those of the most recent update, or zero if there
            has been none.
        '''
        s, a, s_next, r, g = self.mem.sample('s', 'a', 's_next', 'r', 'g')

        self.env.set_reward(g)

        # NOTE(review): only states with index below 5 trigger learning --
        # presumably the remaining states are terminal; confirm against env.
        if s < 5:
            self.R[s_next] = r

            phi_s = np.zeros(self.env.nS)
            phi_s[s] = 1.0

            a_next = np.argmax(self.Q[s_next, :])

            SR_target = phi_s + self.gamma * self.M[s_next, a_next, :]
            self.SR_error = SR_target - self.M[s, a, :]

            self.M_old = self.M[s, a, :].copy()
            self.M[s, a, :] += self.alpha * self.SR_error

            self.Q_old = self.Q[s, a]
            # R changed, so every entry of Q is refreshed, written in place
            # so existing references to the Q array stay valid.
            self.Q[:] = np.dot(self.M, self.R)

        return np.mean(np.abs(self.SR_error)), self.Q_old, self.Q