# REINFORCE : Acrobot

In [1]:
import numpy as np
import gymnasium as gym

In [8]:
env = gym.make("Acrobot-v1", render_mode = None)
n_actions = env.action_space.n
shape_states = env.observation_space.shape

In [9]:
env.reset()

(array([ 0.99830556, -0.05818904,  0.9965354 , -0.08316938, -0.05834483,
        -0.0116648 ], dtype=float32),
 {})

## REINFORCE par paramétrisation

On va utiliser la fonction softmax pour paramétriser la politique. La fonction softmax prend en entrée un vecteur de paramètres $\theta$ et renvoie un vecteur de probabilité de même taille. 

The log-softmax function is defined as follows:

$$\log\left(\frac{\exp(z_i)}{\sum_j \exp(z_j)}\right)$$

where $z_i$ is the $i$-th element of a vector of logits. The gradient of the log-softmax function with respect to the logits is:

$$\frac{\partial}{\partial z_i} \log\left(\frac{\exp(z_i)}{\sum_j \exp(z_j)}\right) = \frac{\partial}{\partial z_i} \left( z_i - \log\left(\sum_j \exp(z_j)\right) \right)$$

To compute this gradient, we can start by taking the derivative of the logarithm term:

$$\frac{\partial}{\partial z_i} \log\left(\sum_j \exp(z_j)\right) = \frac{1}{\sum_j \exp(z_j)} \cdot \frac{\partial}{\partial z_i} \sum_j \exp(z_j)$$

Note that the sum in the denominator is over all the logits $z_j$, and the sum in the numerator is only over the logits that correspond to the $i$-th action. Therefore, we can simplify the expression as follows:

$$\frac{\partial}{\partial z_i} \log\left(\sum_j \exp(z_j)\right) = \frac{\exp(z_i)}{\sum_j \exp(z_j)}$$

Next, we can use the chain rule to compute the derivative of the entire log-softmax function:

$$\frac{\partial}{\partial z_i} \log\left(\frac{\exp(z_i)}{\sum_j \exp(z_j)}\right) = \frac{\partial}{\partial z_i} \left( z_i - \log\left(\sum_j \exp(z_j)\right) \right) = 1 - \frac{\exp(z_i)}{\sum_j \exp(z_j)} = 1 - \mathrm{softmax}(z)_i$$

where $\mathrm{softmax}(z)_i$ is the $i$-th element of the softmax function applied to the logits $z$. Therefore, the gradient of the log-softmax function with respect to the logits is simply the difference between 1 and the probability of selecting the $i$-th action according to the softmax function

In [17]:
def softmax(alpha) :
    proba = np.exp(alpha)
    return proba/proba.sum()

def parametrization(x, theta) :
    return softmax(np.dot(np.transpose(x), theta))

In [51]:
def REINFORCE(env, learning_rate, initial_param, episodes, steps, gamma) :
    
    params = initial_param
    
    for ep in range(episodes) :
        env.reset()
        
        states = []
        actions = []
        rewards = []
        proba = []
        
        for stp in range(steps) :
            
            state = env.state
            states.append(state)

            proba_vector = parametrization(state, params)
            proba.append(proba_vector)

            action = np.random.choice([0, 1, 2], p = proba_vector)
                
            actions.append(action)
            
            observation, reward, terminated, truncated, info = env.step(action)
            
            rewards.append(reward)
            
            if terminated or truncated :
                break
        
        for t in range(len(actions)) :
            G = 0
            for k in range(t+1, len(actions)) :
                G += gamma**(k - t - 1)*rewards[k]

            params[: , actions[t]] += learning_rate*gamma**t*G*(1 - proba[t])[actions[t]]
            
    return params

In [53]:
param_opt = REINFORCE(env, learning_rate = 0.0001, initial_param = np.ones((4, 3)), episodes = 10000, steps = 100, gamma = 0.8)

In [54]:
param_opt

array([[-4.54859299, -4.57261727, -4.54538462],
       [-4.54859299, -4.57261727, -4.54538462],
       [-4.54859299, -4.57261727, -4.54538462],
       [-4.54859299, -4.57261727, -4.54538462]])

In [56]:
env = gym.make("Acrobot-v1", render_mode = 'human')
obs, info = env.reset()

reward = 0
while True :
    
    proba = parametrization(env.state, param_opt)
    action = np.random.choice([0, 1, 2], p = proba)
    
    s, r, done, tr,_ = env.step(action)
    
    reward += r
    
    if terminated or truncated :
        break
        
env.close()

KeyboardInterrupt: 