# Acrobot with Deep Reinforcement Learning

![Acrobot animation](fig/acrobot.gif)

## Reinforcement Learning

### Illustration

![Reinforcement learning illustration](fig/RL_illustration.png)

*Reference: https://lilianweng.github.io/posts/2018-02-19-rl-overview/*

### Formulation

MDP = $\langle \mathcal{S}, \mathcal{A}, \mathcal{P}, R, \rho, \gamma \rangle$

- State $s \in \mathcal{S}$
- Action $a \in \mathcal{A}$
- Transition function $s^{'} \sim \mathcal{P}(s^{'} | s, a)$
- Reward function $r = R(s, a, s^{'})$
- Initial state distribution $s_{0} \sim \rho$
- Discount factor $\gamma \in [0, 1]$

## Objective

An episode is a trajectory $\tau = (s_{0}, a_{0}, r_{1}, s_{1}, a_{1}, r_{2}, \dots)$.

The agent's objective is to maximize the expected discounted return:

$ G = r_{1} + \gamma r_{2} + \gamma^{2} r_{3} + \dots = \sum_{t=0}^{\infty} \gamma^{t} r_{t+1} $
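As a minimal sketch of the sum above, the return of a finite reward sequence can be accumulated backwards using $G_{t} = r_{t+1} + \gamma G_{t+1}$. The function name `discounted_return` and the reward values are illustrative assumptions, not part of the notebook.

```python
def discounted_return(rewards, gamma):
    """Return sum_t gamma^t * r_{t+1} for a finite list of rewards."""
    g = 0.0
    # Walk the rewards backwards: G_t = r_{t+1} + gamma * G_{t+1}
    for r in reversed(rewards):
        g = r + gamma * g
    return g


# Illustrative input: three steps with a reward of -1 each.
print(discounted_return([-1.0, -1.0, -1.0], gamma=0.99))
```

With `gamma=1.0` the function reduces to a plain sum of the rewards.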

## Policy (Actor)

The agent uses a policy to select actions.

$ a \sim \pi(a|s) $

## Value Function (Critic)

The value function helps the agent evaluate states and actions under a policy $\pi$.

State-value function: $v(s) = \mathbb{E}_{\pi} [G_{t} | s_{t}=s]$

Action-value function: $q(s,a) = \mathbb{E}_{\pi} [G_{t} | s_{t}=s, a_{t}=a]$
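For a discrete action set, the two functions are linked by two standard identities, written here in the notation above (they are textbook relations, not something specific to this notebook):

$ v(s) = \sum_{a \in \mathcal{A}} \pi(a|s)~q(s,a) $

$ q(s,a) = \mathbb{E}_{s^{'} \sim \mathcal{P}(\cdot | s, a)} [R(s,a,s^{'}) + \gamma v(s^{'})] $

The first averages the action values over the policy; the second unrolls the return by one step through the transition function.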

## Definitions

- **Episode**: a sequence of transitions of finite length $T$
- **Horizon**: the maximum number of steps $T$ in an episode
- **Epoch**: one round of learning followed by evaluation

___

## Mushroom RL

![](fig/mushroom_rl.png)

The three basic interfaces of MushroomRL are the Agent, the Environment, and the Core.

- The **Agent** is the basic interface for any Reinforcement Learning algorithm.

- The **Environment** is the basic interface for every problem/task that the agent should solve.

- The **Core** is a class used to control the interaction between an agent and an environment.
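The division of labour between the three interfaces can be sketched with toy classes. Everything below (`ToyEnvironment`, `ToyAgent`, `ToyCore`, the corridor task) is hypothetical and only loosely mirrors the roles described above; it is not the MushroomRL API.

```python
import random


class ToyEnvironment:
    """Hypothetical 1-D corridor: start at 0, the episode ends at position 3."""

    def reset(self):
        self.position = 0
        return self.position

    def step(self, action):
        # action is -1 (left) or +1 (right); positions are clipped at 0
        self.position = max(0, self.position + action)
        absorbing = self.position == 3
        reward = 0.0 if absorbing else -1.0
        return self.position, reward, absorbing


class ToyAgent:
    """Hypothetical agent that acts uniformly at random and never learns."""

    def __init__(self, seed=0):
        self._rng = random.Random(seed)

    def draw_action(self, state):
        return self._rng.choice([-1, 1])

    def fit(self, dataset):
        pass  # a real agent would update its parameters here


class ToyCore:
    """Hypothetical driver controlling the agent-environment interaction."""

    def __init__(self, agent, env, horizon):
        self.agent = agent
        self.env = env
        self.horizon = horizon

    def run_episode(self):
        dataset = []
        state = self.env.reset()
        for _ in range(self.horizon):
            action = self.agent.draw_action(state)
            next_state, reward, absorbing = self.env.step(action)
            dataset.append((state, action, reward, next_state, absorbing))
            state = next_state
            if absorbing:
                break
        # Hand the collected transitions to the agent
        self.agent.fit(dataset)
        return dataset


dataset = ToyCore(ToyAgent(seed=0), ToyEnvironment(), horizon=50).run_episode()
print(len(dataset), "transitions collected")
```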

## Import Modules

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from mushroom_rl.algorithms.value import DQN
from mushroom_rl.core import Core, Logger
from mushroom_rl.environments import *
from mushroom_rl.policy import EpsGreedy
from mushroom_rl.approximators.parametric.torch_approximator import *
from mushroom_rl.utils.dataset import compute_J
from mushroom_rl.utils.parameters import Parameter, LinearParameter

from tqdm import tqdm, trange



## Define Environment

In [2]:
horizon = 1000
gamma = 0.99
gamma_eval = 1.
mdp = Gym('Acrobot-v1', horizon, gamma)

## Define Policy

$ a_{greedy} = \text{argmax}_{a}~q(s,a)$

$a = 
\left\{\begin{matrix}
a_{greedy} & \text{with probability}~1-\epsilon\\ 
a_{random} & \text{with probability}~\epsilon
\end{matrix}\right.$

Here $a_{random}$ is drawn uniformly from $\mathcal{A}$.
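The rule can be sketched in plain Python, independent of the library. The helper `epsilon_greedy` and its arguments are assumptions made for illustration; `EpsGreedy` itself is used in the next cell.

```python
import random


def epsilon_greedy(q_values, epsilon, rng=random):
    """Pick an action index given one q-value per action."""
    if rng.random() < epsilon:
        # Explore: uniform random action (may coincide with the greedy one)
        return rng.randrange(len(q_values))
    # Exploit: index of the largest q-value (first one on ties)
    return max(range(len(q_values)), key=lambda a: q_values[a])


# With epsilon = 0 the choice is always greedy.
print(epsilon_greedy([0.1, 0.7, 0.2], epsilon=0.0))  # 1
```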

In [3]:
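# Training: exploration rate moves linearly from 1.0 to 0.01 over 5000 updates
epsilon = LinearParameter(value=1., threshold_value=.01, n=5000)
# Evaluation: purely greedy
epsilon_test = Parameter(value=0.)
# Initial data collection: purely random
epsilon_random = Parameter(value=1.)
pi = EpsGreedy(epsilon=epsilon_random)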
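To make the training schedule concrete, here is a sketch of a linear decay with the same three numbers. It assumes the parameter interpolates linearly from `value` to `threshold_value` over `n` updates and then stays constant; `linear_epsilon` is a hypothetical helper, not the `LinearParameter` implementation.

```python
def linear_epsilon(step, start=1.0, end=0.01, n=5000):
    """Linearly interpolate from start to end over n steps, then hold end."""
    fraction = min(step, n) / n
    return start + fraction * (end - start)


for step in (0, 2500, 5000, 10000):
    print(step, linear_epsilon(step))
```

Under that assumption, and if the parameter advances once per action drawn during training, exploration would reach its floor of 0.01 after the first 5000 training steps.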

## Define Agent

In [4]:
# Settings
n_epochs = 20          # number of learn + evaluate rounds
n_steps = 1000         # training steps per epoch
n_steps_test = 2000    # evaluation steps per epoch

initial_replay_size = 500
max_replay_size = 5000
target_update_frequency = 100
batch_size = 200
n_features = 80
train_frequency = 1
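The replay settings describe a bounded memory from which mini-batches are drawn. A minimal sketch of that idea follows; the `ReplayBuffer` class is an assumption for illustration, and it samples without replacement, which may differ from what the library does internally.

```python
import random
from collections import deque


class ReplayBuffer:
    """Hypothetical fixed-size memory that discards the oldest transitions."""

    def __init__(self, max_size, seed=None):
        self._data = deque(maxlen=max_size)
        self._rng = random.Random(seed)

    def add(self, transition):
        self._data.append(transition)

    def sample(self, batch_size):
        # Uniform mini-batch, drawn without replacement
        return self._rng.sample(list(self._data), batch_size)

    def __len__(self):
        return len(self._data)


buffer = ReplayBuffer(max_size=5000, seed=0)
for t in range(6000):
    buffer.add((t,))  # integer stand-ins for real transitions
batch = buffer.sample(200)
print(len(buffer), len(batch))  # 5000 200
```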

In [5]:
# Network

# Maps a state s to one value per action: q(s, .) = [q1, q2, q3]
class Network(nn.Module):
    def __init__(self, input_shape, output_shape, n_features, **kwargs):
        super().__init__()

        n_input = input_shape[-1]
        n_output = output_shape[0]

        self._h1 = nn.Linear(n_input, n_features)
        self._h2 = nn.Linear(n_features, n_features)
        self._h3 = nn.Linear(n_features, n_output)

        nn.init.xavier_uniform_(self._h1.weight,
                                gain=nn.init.calculate_gain('relu'))
        nn.init.xavier_uniform_(self._h2.weight,
                                gain=nn.init.calculate_gain('relu'))
        nn.init.xavier_uniform_(self._h3.weight,
                                gain=nn.init.calculate_gain('linear'))

    def forward(self, state, action=None):
        features1 = F.relu(self._h1(torch.squeeze(state, 1).float()))
        features2 = F.relu(self._h2(features1))
        q = self._h3(features2)

        if action is None:
            return q
        else:
            action = action.long()
            q_acted = torch.squeeze(q.gather(1, action))

            return q_acted
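The `gather` call in `forward` picks, for each sample in the batch, the q-value of the action that was actually taken. The same selection written in plain Python, as a sketch with made-up numbers (`select_q` is a hypothetical helper):

```python
def select_q(q_batch, actions):
    """q_batch: one list of q-values per sample.
    actions: one single-element list per sample, i.e. shape (batch, 1)."""
    return [q_row[a[0]] for q_row, a in zip(q_batch, actions)]


q_batch = [[0.1, 0.5, -0.2], [1.0, 0.0, 2.0]]
actions = [[1], [2]]
print(select_q(q_batch, actions))  # [0.5, 2.0]
```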

In [6]:
# Approximator
input_shape = mdp.info.observation_space.shape
approximator_params = dict(network=Network,
                           optimizer={'class': optim.Adam,
                                      'params': {'lr': .001}},
                           loss=F.smooth_l1_loss,
                           n_features=n_features,
                           input_shape=input_shape,
                           output_shape=mdp.info.action_space.size,
                           n_actions=mdp.info.action_space.n)

In [7]:
# Agent
agent = DQN(mdp.info, pi, TorchApproximator,
            approximator_params=approximator_params, batch_size=batch_size,
            initial_replay_size=initial_replay_size,
            max_replay_size=max_replay_size,
            target_update_frequency=target_update_frequency)
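For orientation, the textbook DQN regression target is $r + \gamma \max_{a^{'}} q_{target}(s^{'}, a^{'})$, with no bootstrap term on absorbing transitions, and the target network is refreshed periodically. The sketch below shows that rule in plain Python; `dqn_targets` and `should_update_target` are hypothetical helpers and the library's internals may differ in detail.

```python
def dqn_targets(rewards, next_q_target, absorbing, gamma):
    """Compute r + gamma * max_a' q_target(s', a') for each transition."""
    targets = []
    for r, q_next, done in zip(rewards, next_q_target, absorbing):
        # No bootstrapping from an absorbing next state
        bootstrap = 0.0 if done else gamma * max(q_next)
        targets.append(r + bootstrap)
    return targets


def should_update_target(n_fits, target_update_frequency=100):
    """True when the target network would be synced with the online one."""
    return n_fits % target_update_frequency == 0


# Made-up batch of two transitions; the second one is absorbing.
print(dqn_targets([-1.0, 0.0],
                  [[1.0, 2.0, 0.5], [3.0, 4.0, 5.0]],
                  [False, True],
                  gamma=0.5))  # [0.0, 0.0]
```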

## Define Training/Evaluation Loop

In [8]:
core = Core(agent, mdp)

### Collect random transitions

In [9]:
core.learn(n_steps=initial_replay_size, n_steps_per_fit=initial_replay_size)

                                                   

### Define Logger

In [10]:
logger = Logger(DQN.__name__, results_dir=None)
logger.strong_line()
logger.info('Experiment Algorithm: ' + DQN.__name__)

14/01/2024 23:26:44 [INFO] ###################################################################################################
14/01/2024 23:26:44 [INFO] Experiment Algorithm: DQN


### Evaluate the initial policy

In [11]:
pi.set_epsilon(epsilon_test)
dataset = core.evaluate(n_steps=n_steps_test, render=False)
J = compute_J(dataset, gamma_eval)
logger.epoch_info(0, J=np.mean(J))

                                                     

14/01/2024 23:26:46 [INFO] Epoch 0 | J: -1000.0
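The logged `J` is the mean over evaluation episodes of each episode's discounted return (undiscounted here, since `gamma_eval = 1.`). A sketch of that bookkeeping, assuming each step is reduced to a `(reward, last)` pair and that a trailing unfinished episode is also counted; `episode_returns` is a hypothetical helper, not `compute_J` itself.

```python
def episode_returns(steps, gamma):
    """steps: list of (reward, last) pairs; last marks the end of an episode."""
    returns = []
    g, discount, open_episode = 0.0, 1.0, False
    for reward, last in steps:
        g += discount * reward
        discount *= gamma
        open_episode = True
        if last:
            returns.append(g)
            g, discount, open_episode = 0.0, 1.0, False
    if open_episode:
        # Count the unfinished episode at the end of the dataset
        returns.append(g)
    return returns


steps = [(-1.0, False), (-1.0, True), (-1.0, False)]
J = episode_returns(steps, gamma=1.0)
print(J, sum(J) / len(J))  # [-2.0, -1.0] -1.5
```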




### Save the initial agent

In [12]:
agent.save("agents/acrobot_agent_initial.msh", full_save=True)

### Loop

In [17]:
for n in trange(n_epochs):
    pi.set_epsilon(epsilon)
    core.learn(n_steps=n_steps, n_steps_per_fit=train_frequency)
    pi.set_epsilon(epsilon_test)
    dataset = core.evaluate(n_steps=n_steps_test, render=False)
    J = compute_J(dataset, gamma_eval)
    logger.epoch_info(n+1, J=np.mean(J))

14/01/2024 23:17:52 [INFO] Epoch 1 | J: -1000.0
14/01/2024 23:17:58 [INFO] Epoch 2 | J: -1000.0
14/01/2024 23:18:04 [INFO] Epoch 3 | J: -199.1
14/01/2024 23:18:11 [INFO] Epoch 4 | J: -1000.0
14/01/2024 23:18:17 [INFO] Epoch 5 | J: -1000.0
14/01/2024 23:18:23 [INFO] Epoch 6 | J: -666.3333333333334
14/01/2024 23:18:30 [INFO] Epoch 7 | J: -104.3157894736842
14/01/2024 23:18:36 [INFO] Epoch 8 | J: -499.5
14/01/2024 23:18:43 [INFO] Epoch 9 | J: -132.4
14/01/2024 23:18:49 [INFO] Epoch 10 | J: -1000.0
14/01/2024 23:18:56 [INFO] Epoch 11 | J: -132.4
14/01/2024 23:19:03 [INFO] Epoch 12 | J: -1000.0
14/01/2024 23:19:09 [INFO] Epoch 13 | J: -1000.0
14/01/2024 23:19:15 [INFO] Epoch 14 | J: -666.3333333333334
14/01/2024 23:19:22 [INFO] Epoch 15 | J: -1000.0
14/01/2024 23:19:28 [INFO] Epoch 16 | J: -332.6666666666667
14/01/2024 23:19:34 [INFO] Epoch 17 | J: -499.5
14/01/2024 23:19:40 [INFO] Epoch 18 | J: -94.23809523809524
14/01/2024 23:19:47 [INFO] Epoch 19 | J: -499.5
14/01/2024 23:19:53 [INFO] Epoch 20 | J: -82.375

100%|██████████| 20/20 [02:08<00:00,  6.42s/it]


### Save the final agent

In [18]:
agent.save("agents/acrobot_agent_final.msh", full_save=True)