
## Generalized Advantage Estimation in Reinforcement Learning

*****************************************************************************
### References:
[1] Schulman, J., Moritz, P., Levine, S., Jordan, M., & Abbeel, P. (2015). High-dimensional continuous control using generalized advantage estimation. arXiv preprint arXiv:1506.02438.


[2] The code in this notebook demo is adapted from https://colab.research.google.com/drive/1Wb_2zKgAqhI2tVK19Y1QC8AHImrzlcme?usp=sharing

*****************************************************************************
*Comments of this notebook to be added later*


**Some Transition Dynamics**

In [2]:
import numpy as np

In [3]:
# Transition tensor: P[a, s, s'] is the probability of moving from state s to
# state s' when taking action a.  The chain has seven states s_0 .. s_6 and two
# actions: a_0 steps one state to the left, a_1 steps one state to the right.
# Every transition is deterministic, and the two end states s_0 and s_6 are
# terminal, i.e. they map back onto themselves under both actions.
P = np.zeros((2, 7, 7), dtype=int)
P[:, 0, 0] = 1                                # s_0 terminal: stays in s_0
P[:, 6, 6] = 1                                # s_6 terminal: stays in s_6
P[0, np.arange(1, 6), np.arange(0, 5)] = 1    # a_0 (left):  s_k -> s_{k-1}
P[1, np.arange(1, 6), np.arange(2, 7)] = 1    # a_1 (right): s_k -> s_{k+1}

# Reward table: r[s, a] is the reward for taking action a in state s.  The only
# non-zero entry is for taking a_1 (right) in s_5, which leads into s_6.
r = np.zeros((7, 2), dtype=int)
r[5, 1] = 1

# Policy table: pi[s, a] is the probability of picking action a in state s.
# Both actions are equally likely in every state.
pi = np.full((7, 2), 0.5)

Given a linear chain, the `step` function samples the next state from the transition dynamics and returns the (optionally noisy) reward for the chosen action.

In [4]:
class LinearChain():
    """Chain-shaped MDP environment with a gym-like ``reset``/``step`` API.

    Args:
        P: Transition tensor indexed as ``P[action, state, next_state]``;
            each ``P[action, state]`` row is used as a probability vector
            over next states.
        r: Reward table indexed as ``r[state, action]``.
        start_state: State that ``reset()`` puts the environment in.
        terminal_states: Collection of states that end an episode.
        noise: Standard deviation of the zero-mean Gaussian noise added to
            every reward (0 means the rewards are noise-free).
    """

    def __init__(self, P, r, start_state, terminal_states, noise=0):
        self.P = P
        self.r = r
        self.noise = noise
        self.n = P.shape[-1]  # number of states
        self.start_state = start_state
        self.terminal_states = terminal_states

        self.observation_space = self.n
        # Derived from P so it always agrees with the range of actions
        # that step() accepts.
        self.action_space = P.shape[0]
        self.state = None  # stays None until reset() is called

        self.t = 0  # total number of step() calls so far

    def reset(self):
        """Put the environment back in ``start_state`` and return that state.

        Note that the step counter ``t`` is not reset here.
        """
        self.state = self.start_state
        return self.state

    def step(self, action):
        """Apply ``action`` and advance the environment by one transition.

        Returns:
            A ``(state, reward, done, info)`` tuple: the sampled next state,
            the reward for taking ``action`` in the previous state (plus
            Gaussian noise), whether the new state is terminal, and an empty
            info dict.

        Raises:
            Exception: If called before ``reset()``.
            AssertionError: If ``action`` is not a valid action index.
        """
        if self.state is None:
            raise Exception('step() used before calling reset()')
        # Use this instance's own dynamics and rewards rather than whatever
        # module-level ``P`` / ``r`` happen to be defined.
        assert action in range(self.P.shape[0])

        reward = self.r[self.state, action] \
            + np.random.normal(loc=0, scale=self.noise)
        self.state = np.random.choice(a=self.n, p=self.P[action, self.state])
        self.t = self.t + 1

        done = self.state in self.terminal_states

        return self.state, reward, done, {}

In [None]:
class Agent():
    """Policy-gradient agent with a softmax policy and linear value function.

    Action preferences are ``policy_features @ policy_weight`` and are turned
    into probabilities with a row-wise softmax.  State values are estimated as
    ``value_features @ value_weight``.

    Args:
        num_actions: Number of discrete actions.
        policy_features: 2-D array with one feature row per state, used for
            the policy.
        value_features: 2-D array with one feature row per state, used for
            the value function, or ``None`` for no value function.
        policy_stepsize: Step size for policy updates.
        value_stepsize: Step size for value-function updates.
        nstep: Number of steps summed by the generalized advantage, or the
            string ``'inf'`` to use the whole remaining trajectory.
        lambda_: The lambda parameter of the generalized advantage.
        gamma: Discount factor.
        FLAG_BASELINE: Stored on the instance; not read by the methods below.
        FLAG_POPULAR_PG: Stored on the instance; not read by the methods
            below.
    """

    def __init__(self, num_actions, policy_features, value_features,
                 policy_stepsize, value_stepsize, nstep, lambda_, gamma,
                 FLAG_BASELINE, FLAG_POPULAR_PG=False):
        self.policy_features = policy_features
        self.value_features = value_features
        self.num_actions = num_actions

        # All weights start at zero, so the initial policy is uniform.
        self.policy_weight = np.zeros((policy_features.shape[1],
                                       num_actions))
        if value_features is None:
            self.value_weight = None
        else:
            self.value_weight = np.zeros((value_features.shape[1], 1))

        self.policy_stepsize = policy_stepsize
        self.value_stepsize = value_stepsize

        self.FLAG_BASELINE = FLAG_BASELINE
        self.FLAG_POPULAR_PG = FLAG_POPULAR_PG
        self.gamma = gamma
        # Parameter for calculating the generalized advantage.
        self.lambda_ = lambda_
        self.nstep = nstep

        # Cached policy table; recomputed lazily by take_action() whenever
        # FLAG_POLICY_UPDATED is set.
        self.pi = None
        self.FLAG_POLICY_UPDATED = True

    @staticmethod
    def softmax(x):
        """Return the row-wise softmax of the 2-D array ``x``.

        The row maximum is subtracted before exponentiating so that large
        preferences do not overflow.
        """
        e_x = np.exp(x - np.max(x, 1).reshape(-1, 1))
        out = e_x / e_x.sum(1).reshape(-1, 1)
        return out

    # At a given state, use the existing stochastic policy to decide which
    # action to take.
    def take_action(self, state):
        """Sample an action for ``state`` from the current policy.

        Returns:
            ``(action, probability)``: the sampled action index and the
            probability the policy assigned to it in ``state``.
        """
        if self.FLAG_POLICY_UPDATED:
            # The policy weights changed since the table was last built.
            action_prefs = np.matmul(self.policy_features, self.policy_weight)
            self.pi = self.softmax(action_prefs)
            self.FLAG_POLICY_UPDATED = False

        action = np.random.choice(self.num_actions, p=self.pi[state])
        return action, self.pi[state, action]

    # Use the current value functions to make predictions.
    def calc_v_pi_pred(self):
        """Return the value estimate of every state as a column vector."""
        return np.matmul(self.value_features, self.value_weight)

    # =========================================================================
    # Calculate the advantage for a specific step.
    def calc_advantage(self, curr_state, next_state, reward):
        """Return the one-step TD residual ``r + gamma * V(s') - V(s)``.

        Args:
            curr_state: Value-feature vector of the state the step started in.
            next_state: Value-feature vector of the state the step ended in.
            reward: Reward received on that step.

        Returns:
            The TD residual as a length-1 array (when the feature vectors
            are 1-D).
        """
        v_curr = curr_state @ self.value_weight
        v_next = next_state @ self.value_weight
        return reward + self.gamma * v_next - v_curr

    # After taking n steps forward and getting the trajectory,  calculate the
    # advantage at each timestep t.
    def calc_generalized_advantage(self, t, traj, v_pi):
        """Return the generalized advantage estimate for timestep ``t``.

        Sums the TD residuals of steps ``t, t+1, ...`` weighted by
        ``(gamma * lambda_) ** l``, over at most ``nstep`` steps and never
        past the end of the trajectory.

        Args:
            t: Timestep in the trajectory to compute the advantage for.
            traj: Dict with ``'state_list'``, ``'reward_list'`` and
                ``'next_state_list'``; entry ``i`` of each list describes
                step ``i``.  The state entries are assumed to be integer
                state indices (rows of ``value_features``) -- TODO confirm
                against the code that builds the trajectory.
            v_pi: Unused; kept so existing callers keep working.  Values
                are computed from the current ``value_weight`` instead.

        Returns:
            The advantage estimate, or ``0`` if ``t`` is at or beyond the end
            of the trajectory.
        """
        state_list = traj['state_list']
        reward_list = traj['reward_list']
        next_state_list = traj['next_state_list']
        traj_length = len(reward_list)

        nstep = self.nstep
        assert nstep == 'inf' or nstep > 0
        if nstep == 'inf' or nstep > traj_length:
            nstep = traj_length

        GAE = 0
        discount = 1
        for i in range(t, min(t + nstep, traj_length)):
            # Step i goes from state_list[i] to next_state_list[i] and earns
            # reward_list[i]; calc_advantage expects feature vectors, so map
            # the states to their value-feature rows first.
            delta = self.calc_advantage(
                self.value_features[state_list[i]],
                self.value_features[next_state_list[i]],
                reward_list[i])
            GAE += discount * delta
            discount *= self.gamma * self.lambda_

        return GAE
    # =========================================================================
    
    # # Using the difference between the accumulated n-step reward and the estimated
    # # value predicted by the value function, update the value function parameter.
    # def update_value_fn(self, traj):
    #     state_list = traj['state_list']
    #     traj_length = len(state_list)

    #     for t in range(traj_length):
    #         state = state_list[t]
    #         v_pi = self.calc_v_pi_pred()
    #         G = self.calc_nstep_return(t, traj, v_pi)
                    
    #         v_pred = v_pi[state]
    #         self.value_weight = self.value_weight \
    #             + self.value_stepsize * (G - v_pred) \
    #             * self.value_features[state].reshape(self.value_weight.shape)

    # # helper function for calculating the policy gradient.
    # def calc_grad_log_pi(self, state, action):
    #     x = self.policy_features[state].reshape(-1, 1)
    #     action_prefs = np.matmul(x.T, self.policy_weight)
    #     pi = self.softmax(action_prefs).T

    #     I_action = np.zeros((self.num_actions, 1))
    #     I_action[action] = 1

    #     one_vec = np.ones((1, self.num_actions))

    #     return np.matmul(x, one_vec) * (I_action - pi).T

    # # Calculate the REINFORCE based policy gradient.
    # def calc_reinforce_pg(self, traj, v_pi):
    #     state_list = traj['state_list']
    #     action_list = traj['action_list']
    #     traj_length = len(state_list)
        
    #     policy_grad = np.zeros(self.policy_weight.shape)
    #     for t in range(traj_length):
    #         state = state_list[t]
    #         action = action_list[t]
    #         G = self.calc_nstep_return(t, traj, v_pi)
    #         grad_log_pi = self.calc_grad_log_pi(state, action)
            
    #         if self.FLAG_BASELINE:
    #             baseline = v_pi[state]
    #         else:
    #             baseline = 0

    #         if self.FLAG_POPULAR_PG == False:
    #             policy_grad += self.gamma**t * (G - baseline) * grad_log_pi
    #         else:
    #             policy_grad += (G - baseline) * grad_log_pi

    #     return policy_grad

    # # Use the policy gradient to update the policy function.
    # def update_policy(self, traj, v_pi):
    #     policy_grad = self.calc_reinforce_pg(traj, v_pi)
        
    #     self.policy_weight = self.policy_weight \
    #         + self.policy_stepsize * policy_grad

    #     self.FLAG_POLICY_UPDATED = True