Added Boltzmann torch policy
- Boltzmann torch policy added
- added test
- added example on Acrobot with A2C
boris-il-forte committed Jul 30, 2020
1 parent e8fc96a commit 4d20e68
Showing 7 changed files with 204 additions and 10 deletions.
108 changes: 108 additions & 0 deletions examples/acrobot_a2c.py
@@ -0,0 +1,108 @@
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from mushroom_rl.algorithms.actor_critic import A2C
from mushroom_rl.core import Core
from mushroom_rl.environments import Gym
from mushroom_rl.policy import BoltzmannTorchPolicy
from mushroom_rl.approximators.parametric.torch_approximator import *
from mushroom_rl.utils.dataset import compute_J
from mushroom_rl.utils.parameters import Parameter
from tqdm import tqdm, trange


class Network(nn.Module):
def __init__(self, input_shape, output_shape, n_features, **kwargs):
super(Network, self).__init__()

n_input = input_shape[-1]
n_output = output_shape[0]

self._h1 = nn.Linear(n_input, n_features)
self._h2 = nn.Linear(n_features, n_features)
self._h3 = nn.Linear(n_features, n_output)

nn.init.xavier_uniform_(self._h1.weight,
gain=nn.init.calculate_gain('relu'))
nn.init.xavier_uniform_(self._h2.weight,
gain=nn.init.calculate_gain('relu'))
nn.init.xavier_uniform_(self._h3.weight,
gain=nn.init.calculate_gain('linear'))

def forward(self, state, **kwargs):
features1 = torch.relu(self._h1(torch.squeeze(state, 1).float()))
features2 = torch.relu(self._h2(features1))
a = self._h3(features2)

return a


def experiment(n_epochs, n_steps, n_steps_per_fit, n_step_test):
np.random.seed()

# MDP
horizon = 1000
gamma = 0.99
gamma_eval = 1.
mdp = Gym('Acrobot-v1', horizon, gamma)

# Policy
policy_params = dict(
n_features=32,
use_cuda=False
)

beta = Parameter(1e0)
pi = BoltzmannTorchPolicy(Network,
mdp.info.observation_space.shape,
(mdp.info.action_space.n,),
beta=beta,
**policy_params)

# Agent
critic_params = dict(network=Network,
optimizer={'class': optim.RMSprop,
'params': {'lr': 1e-3,
'eps': 1e-5}},
loss=F.mse_loss,
n_features=32,
batch_size=64,
input_shape=mdp.info.observation_space.shape,
output_shape=(1,))

alg_params = dict(actor_optimizer={'class': optim.RMSprop,
'params': {'lr': 1e-3,
'eps': 3e-3}},
critic_params=critic_params,
#max_grad_norm=10.0,
ent_coeff=0.01
)

agent = A2C(mdp.info, pi, **alg_params)

# Algorithm
core = Core(agent, mdp)

core.learn(n_steps=n_steps, n_steps_per_fit=n_steps_per_fit)

# RUN
dataset = core.evaluate(n_steps=n_step_test, render=False)
J = compute_J(dataset, gamma_eval)
print('J: ', np.mean(J))

for n in trange(n_epochs):
tqdm.write('Epoch: ' + str(n))
core.learn(n_steps=n_steps, n_steps_per_fit=n_steps_per_fit)
dataset = core.evaluate(n_steps=n_step_test, render=False)
J = compute_J(dataset, gamma_eval)
tqdm.write('J: ' + str(np.mean(J)))
# core.evaluate(n_episodes=2, render=True)

print('Press a button to visualize acrobot')
input()
core.evaluate(n_episodes=5, render=True)


if __name__ == '__main__':
experiment(n_epochs=40, n_steps=1000, n_steps_per_fit=5, n_step_test=2000)
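A side note on the example above: beta is kept constant at 1.0. A hypothetical variant, not part of this commit, could instead anneal the inverse temperature with LinearParameter (already imported by the DQN example below), so that the policy starts exploratory and becomes progressively greedier:

# Hypothetical variant, not in this commit: ramp the inverse temperature
# linearly from 0.5 to 5.0 over 50000 parameter updates.
from mushroom_rl.utils.parameters import LinearParameter

beta = LinearParameter(value=0.5, threshold_value=5.0, n=50000)
pi = BoltzmannTorchPolicy(Network,
                          mdp.info.observation_space.shape,
                          (mdp.info.action_space.n,),
                          beta=beta,
                          **policy_params)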
8 changes: 5 additions & 3 deletions examples/acrobot_dqn.py
@@ -10,6 +10,8 @@
from mushroom_rl.utils.dataset import compute_J
from mushroom_rl.utils.parameters import Parameter, LinearParameter

from tqdm import tqdm, trange


class Network(nn.Module):
def __init__(self, input_shape, output_shape, n_features, **kwargs):
@@ -95,14 +97,14 @@ def experiment(n_epochs, n_steps, n_steps_test):
J = compute_J(dataset, gamma_eval)
print('J: ', np.mean(J))

for n in range(n_epochs):
print('Epoch: ', n)
for n in trange(n_epochs):
tqdm.write('Epoch: ' + str(n))
pi.set_epsilon(epsilon)
core.learn(n_steps=n_steps, n_steps_per_fit=train_frequency)
pi.set_epsilon(epsilon_test)
dataset = core.evaluate(n_steps=n_steps_test, render=False)
J = compute_J(dataset, gamma_eval)
print('J: ', np.mean(J))
tqdm.write('J: ' + str(np.mean(J)))

print('Press a button to visualize acrobot')
input()
5 changes: 3 additions & 2 deletions examples/pendulum_a2c.py
@@ -58,6 +58,7 @@ def experiment(alg, env_id, horizon, gamma, n_epochs, n_steps, n_steps_per_fit,

alg_params['critic_params'] = critic_params


policy = GaussianTorchPolicy(Network,
mdp.info.observation_space.shape,
mdp.info.action_space.shape,
@@ -101,8 +102,8 @@ def experiment(alg, env_id, horizon, gamma, n_epochs, n_steps, n_steps_per_fit,
(A2C, 'a2c', a2c_params)
]

for alg, alg_name, alg_params in algs_params:
for alg, alg_name, params in algs_params:
experiment(alg=alg, env_id='Pendulum-v0', horizon=200, gamma=.99,
n_epochs=40, n_steps=30000, n_steps_per_fit=5,
n_step_test=5000, alg_params=alg_params,
n_step_test=5000, alg_params=params,
policy_params=policy_params)
4 changes: 2 additions & 2 deletions mushroom_rl/policy/__init__.py
@@ -4,14 +4,14 @@
from .gaussian_policy import GaussianPolicy, DiagonalGaussianPolicy, \
StateStdGaussianPolicy, StateLogStdGaussianPolicy
from .deterministic_policy import DeterministicPolicy
from .torch_policy import TorchPolicy, GaussianTorchPolicy
from .torch_policy import TorchPolicy, GaussianTorchPolicy, BoltzmannTorchPolicy


__all_td__ = ['TDPolicy', 'Boltzmann', 'EpsGreedy', 'Mellowmax']
__all_parametric__ = ['ParametricPolicy', 'GaussianPolicy',
'DiagonalGaussianPolicy', 'StateStdGaussianPolicy',
'StateLogStdGaussianPolicy']
__all_torch__ = ['TorchPolicy', 'GaussianTorchPolicy']
__all_torch__ = ['TorchPolicy', 'GaussianTorchPolicy', 'BoltzmannTorchPolicy']

__all__ = ['Policy', 'DeterministicPolicy', 'OrnsteinUhlenbeckPolicy'] \
+ __all_td__ + __all_parametric__ + __all_torch__
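With this export in place, the new policy is importable from the package root, exactly as the Acrobot example above already does:

from mushroom_rl.policy import BoltzmannTorchPolicy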
65 changes: 64 additions & 1 deletion mushroom_rl/policy/torch_policy.py
@@ -105,7 +105,7 @@ def log_prob_t(self, state, action):
"""
raise NotImplementedError

def entropy_t(self, state=None):
def entropy_t(self, state):
"""
Compute the entropy of the policy.
@@ -249,3 +249,66 @@ def get_weights(self):

def parameters(self):
return chain(self._mu.model.network.parameters(), [self._log_sigma])


class BoltzmannTorchPolicy(TorchPolicy):
"""
Torch policy implementing a Boltzmann policy.
"""
def __init__(self, network, input_shape, output_shape, beta, use_cuda=False, **params):
"""
Constructor.
Args:
network (object): the network class used to implement the logits
regressor;
input_shape (tuple): the shape of the state space;
output_shape (tuple): the shape of the action space;
beta (Parameter): the inverse temperature of the Boltzmann
distribution. As the temperature approaches infinity (beta going to
zero), the policy becomes more and more random; as the temperature
approaches 0 (beta growing large), the policy becomes more and more
greedy.
params (dict): parameters used by the network constructor.
"""
super().__init__(use_cuda)

self._action_dim = output_shape[0]

self._logits = Regressor(TorchApproximator, input_shape, output_shape,
network=network, use_cuda=use_cuda, **params)
self._beta = beta

self._add_save_attr(
_action_dim='primitive',
_beta='pickle',
_logits='mushroom'
)

def draw_action_t(self, state):
action = self.distribution_t(state).sample().detach()
if len(action.shape) > 1:
return action
else:
return action.unsqueeze(0)

def log_prob_t(self, state, action):
return self.distribution_t(state).log_prob(action.squeeze())[:, None]

def entropy_t(self, state):
return torch.mean(self.distribution_t(state).entropy())

def distribution_t(self, state):
logits = self._logits(state, output_tensor=True) * self._beta(state.numpy())
return torch.distributions.Categorical(logits=logits)

def set_weights(self, weights):
self._logits.set_weights(weights)

def get_weights(self):
return self._logits.get_weights()

def parameters(self):
return self._logits.model.network.parameters()
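For intuition on the beta parameter documented above: distribution_t multiplies the network logits by beta(state) before building a Categorical distribution, so a small beta yields a near-uniform policy and a large beta a near-greedy one. A minimal standalone sketch in plain PyTorch (arbitrary example logits, not part of the commit) illustrates this:

import torch

logits = torch.tensor([2.0, 1.0, 0.0])
for beta in (0.1, 1.0, 10.0):
    dist = torch.distributions.Categorical(logits=beta * logits)
    # small beta -> near-uniform probabilities, large beta -> near-greedy
    print(beta, dist.probs)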
2 changes: 1 addition & 1 deletion tests/distributions/test_gaussian_distribution.py
@@ -37,7 +37,7 @@ def test_gaussian():
assert np.array_equal(dist.get_parameters(), weights.dot(theta) / np.sum(weights))

entropy = dist.entropy()
assert entropy == 4.749208309037535
assert entropy == 4.74920830903762


def test_diagonal_gaussian():
22 changes: 21 additions & 1 deletion tests/policy/test_torch_policy.py
@@ -3,7 +3,8 @@

import numpy as np

from mushroom_rl.policy.torch_policy import TorchPolicy, GaussianTorchPolicy
from mushroom_rl.policy.torch_policy import TorchPolicy, GaussianTorchPolicy, BoltzmannTorchPolicy
from mushroom_rl.utils.parameters import Parameter


def abstract_method_tester(f, *args):
@@ -73,3 +74,22 @@ def test_gaussian_torch_policy():
assert np.allclose(entropy, entropy_test)


def test_boltzmann_torch_policy():
np.random.seed(88)
torch.manual_seed(88)
beta = Parameter(1.0)
pi = BoltzmannTorchPolicy(Network, (3,), (2,), beta, n_features=50)

state = np.random.rand(3)
action = pi.draw_action(state)
action_test = np.array([0])
assert np.allclose(action, action_test)

p_sa = pi(state, action)
p_sa_test = 0.7594595984401512
assert np.allclose(p_sa, p_sa_test)

states = np.random.rand(1000, 3)
entropy = pi.entropy(states)
entropy_test = 0.5429736971855164
assert np.allclose(entropy, entropy_test)
