perf: optimized c51, rainbow (#34)
StepNeverStop committed Aug 27, 2021
1 parent 71115ea · commit d20cf42
Showing 8 changed files with 118 additions and 156 deletions.
88 changes: 33 additions & 55 deletions rls/algorithms/single/c51.py
@@ -35,12 +35,12 @@ def __init__(self,
**kwargs):
super().__init__(**kwargs)
assert not self.is_continuous, 'c51 only support discrete action space'
self.v_min = v_min
self.v_max = v_max
self.atoms = atoms
self.delta_z = (self.v_max - self.v_min) / (self.atoms - 1)
self.z = t.tensor(
[self.v_min + i * self.delta_z for i in range(self.atoms)]).float().to(self.device) # [N,]
self._v_min = v_min
self._v_max = v_max
self._atoms = atoms
self._delta_z = (self._v_max - self._v_min) / (self._atoms - 1)
self._z = t.linspace(self._v_min, self._v_max,
self._atoms).float().to(self.device) # [N,]
self.expl_expt_mng = ExplorationExploitationClass(eps_init=eps_init,
eps_mid=eps_mid,
eps_final=eps_final,
@@ -50,7 +50,7 @@ def __init__(self,
self.q_net = TargetTwin(C51Distributional(self.obs_spec,
rep_net_params=self._rep_net_params,
action_dim=self.a_dim,
atoms=self.atoms,
atoms=self._atoms,
network_settings=network_settings)).to(self.device)
self.oplr = OPLR(self.q_net, lr)
self._trainer_modules.update(model=self.q_net,
@@ -61,70 +61,48 @@ def select_action(self, obs):
if self._is_train_mode and self.expl_expt_mng.is_random(self.cur_train_step):
actions = np.random.randint(0, self.a_dim, self.n_copys)
else:
feat = self.q_net(obs, cell_state=self.cell_state) # [B, N, A]
feat = self.q_net(obs, cell_state=self.cell_state) # [B, A, N]
self.next_cell_state = self.q_net.get_cell_state()
feat = feat.swapaxes(-1, -2) # [B, N, A] => [B, A, N]
q = (self.z * feat).sum(-1) # [B, A, N] * [N,] => [B, A]
q = (self._z * feat).sum(-1) # [B, A, N] * [N,] => [B, A]
actions = q.argmax(-1) # [B,]
return actions, Data(action=actions)

@iTensor_oNumpy
def _train(self, BATCH):
time_step = BATCH.reward.shape[0]
batch_size = BATCH.reward.shape[1]
indexes = t.arange(time_step*batch_size).view(-1, 1) # [T*B, 1]
q_dist = self.q_net(BATCH.obs) # [T, B, A, N]
# [T, B, A, N] * [T, B, A, 1] => [T, B, A, N] => [T, B, N]
q_dist = (q_dist * BATCH.action.unsqueeze(-1)).sum(-2)

q_dist = self.q_net(BATCH.obs) # [T, B, N, A]
q_dist = q_dist.permute(2, 0, 1, 3) # [T, B, N, A] => [N, T, B, A]
q_dist = (q_dist * BATCH.action).sum(-1) # [N, T, B, A] => [N, T, B]
q_dist = q_dist.permute(1, 2, 0) # [N, T, B] => [T, B, N]
q_eval = (q_dist * self.z).sum(-1) # [T, B, N] * [N,] => [T, B]
q_dist = q_dist.view(-1, self.atoms) # [T, B, N] => [T*B, N]
q_eval = (q_dist * self._z).sum(-1) # [T, B, N] * [N,] => [T, B]

target_q_dist = self.q_net.t(BATCH.obs) # [T, B, N, A]
# [T, B, N, A] => [T, B, A, N] * [1, N] => [T, B, A]
target_q = (target_q_dist.swapaxes(-1, -2) * self.z).sum(-1)
target_q_dist = self.q_net.t(BATCH.obs) # [T, B, A, N]
# [T, B, A, N] * [1, N] => [T, B, A]
target_q = (target_q_dist * self._z).sum(-1)
a_ = target_q.argmax(-1) # [T, B]
a_onehot = t.nn.functional.one_hot(a_, self.a_dim).float() # [T, B, A]
target_q_dist = target_q_dist.permute(
2, 0, 1, 3) # [T, B, N, A] => [N, T, B, A]
# [N, T, B, A] => [N, T, B]
target_q_dist = (target_q_dist * a_onehot).sum(-1)
target_q_dist = target_q_dist.permute(
1, 2, 0) # [N, T, B] => [T, B, N]
# [T, B, N] => [T*B, N]
target_q_dist = target_q_dist.view(-1, self.atoms)
# [T, B, A, N] * [T, B, A, 1] => [T, B, A, N] => [T, B, N]
target_q_dist = (target_q_dist * a_onehot.unsqueeze(-1)).sum(-2)

target = q_target_func(BATCH.reward.repeat(1, 1, self.atoms),
target = q_target_func(BATCH.reward.repeat(1, 1, self._atoms),
self.gamma,
BATCH.done.repeat(1, 1, self.atoms),
self.z.view(1, 1, self.atoms).repeat(
time_step, batch_size, 1),
BATCH.begin_mask,
BATCH.done.repeat(1, 1, self._atoms),
target_q_dist,
BATCH.begin_mask.repeat(1, 1, self._atoms),
use_rnn=self.use_rnn) # [T, B, N]
target = target.clamp(self.v_min, self.v_max) # [T, B, N]
target = target.view(-1, self.atoms) # [T, B, N] => [T*B, N]
b = (target - self.v_min) / self.delta_z # [T*B, N]
u, l = b.ceil(), b.floor() # [T*B, N]
u_id, l_id = u.long(), l.long() # [T*B, N]
u_minus_b, b_minus_l = u - b, b - l # [T*B, N]
target = target.clamp(self._v_min, self._v_max) # [T, B, N]
# An amazing trick for calculating the projection gracefully.
# ref: https://github.com/ShangtongZhang/DeepRL
target_dist = (1 - (target.unsqueeze(-1) -
self._z.view(1, 1, -1, 1)).abs() / self._delta_z
).clamp(0, 1) * target_q_dist.unsqueeze(-1) # [T, B, N, 1]
target_dist = target_dist.sum(-1) # [T, B, N]

index_help = indexes.repeat(1, self.atoms) # [T*B, 1] => [T*B, N]
index_help = index_help.unsqueeze(-1) # [T*B, N, 1]
u_id = t.cat([index_help, u_id.unsqueeze(-1)], -1) # [T*B, N, 2]
l_id = t.cat([index_help, l_id.unsqueeze(-1)], -1) # [T*B, N, 2]
u_id = u_id.long().permute(2, 0, 1) # [2, T*B, N]
l_id = l_id.long().permute(2, 0, 1) # [2, T*B, N]
_cross_entropy = (target_q_dist * u_minus_b).detach() * q_dist[list(l_id)].log()\
+ (target_q_dist * b_minus_l).detach() * \
q_dist[list(u_id)].log() # [T*B, N]
td_error = cross_entropy = - \
_cross_entropy.sum(-1).view(time_step, batch_size) # [T, B]

loss = (cross_entropy*BATCH.get('isw', 1.0)).mean() # 1
_cross_entropy = - (target_dist * t.log(q_dist +
t.finfo().eps)).sum(-1, keepdim=True) # [T, B, 1]
loss = (_cross_entropy*BATCH.get('isw', 1.0)).mean() # 1

self.oplr.step(loss)
return td_error, dict([
return _cross_entropy, dict([
['LEARNING_RATE/lr', self.oplr.lr],
['LOSS/loss', loss],
['Statistics/q_max', q_eval.max()],
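
The substance of the c51.py change: the fixed support is now built with `t.linspace`, and the explicit floor/ceil projection with its `T*B` index bookkeeping is replaced by the closed-form projection credited to ShangtongZhang/DeepRL, followed by a plain cross-entropy. A minimal standalone sketch of that projection, with illustrative names (`shifted_atoms`, `next_probs`) and a [B, N] layout rather than the repo's [T, B, N]:

```python
import torch

def project_distribution(shifted_atoms, next_probs, v_min, v_max, atoms):
    delta_z = (v_max - v_min) / (atoms - 1)
    # Fixed support z_j; equivalent to [v_min + j * delta_z for j in range(atoms)].
    support = torch.linspace(v_min, v_max, atoms)                      # [N]
    tz = shifted_atoms.clamp(v_min, v_max)                             # [B, N_i]
    # 1 - |Tz_i - z_j| / delta_z, clamped to [0, 1], is exactly the linear
    # interpolation weight the old floor/ceil projection assigned source
    # atom i to bin j; atoms farther than delta_z away get weight 0.
    weights = (1 - (tz.unsqueeze(-1) - support.view(1, 1, -1)).abs()
               / delta_z).clamp(0, 1)                                  # [B, N_i, N_j]
    # Weight each source atom by its probability and sum over the source axis.
    return (weights * next_probs.unsqueeze(-1)).sum(-2)                # [B, N_j]
```

Each row of the result still sums to 1 whenever `next_probs` does, since every clamped shifted atom spreads exactly unit weight across its two neighbouring bins; the projected distribution then feeds the cross-entropy against the online distribution, as in the diff above.
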
9 changes: 4 additions & 5 deletions rls/algorithms/single/iqn.py
@@ -6,8 +6,7 @@

from rls.algorithms.base.sarl_off_policy import SarlOffPolicy
from rls.utils.expl_expt import ExplorationExploitationClass
from rls.utils.torch_utils import (huber_loss,
q_target_func)
from rls.utils.torch_utils import q_target_func
from rls.nn.models import IqnNet
from rls.nn.utils import OPLR
from rls.common.decorator import iTensor_oNumpy
@@ -164,10 +163,10 @@ def _train(self, BATCH):
1, 2, 0, 3) # [N, T, B, 1] => [T, B, N, 1]
# [T, B, N, 1] - [T, B, 1, N'] => [T, B, N, N']
quantile_error = quantiles_value_online - quantiles_value_target
huber = huber_loss(
quantile_error, delta=self.huber_delta) # [T, B, N, N']
huber = t.nn.functional.huber_loss(
quantiles_value_online, quantiles_value_target, reduction="none", delta=self.huber_delta) # [T, B, N, N]
# [T, B, N, 1] - [T, B, N, N'] => [T, B, N, N']
huber_abs = (quantiles - t.where(quantile_error < 0, 1., 0.)).abs()
huber_abs = (quantiles - quantile_error.detach().le(0.).float()).abs()
loss = (huber_abs * huber).mean(-1) # [T, B, N, N'] => [T, B, N]
loss = loss.sum(-1, keepdim=True) # [T, B, N] => [T, B, 1]

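
For iqn.py, the hand-rolled `huber_loss` helper is dropped in favour of `t.nn.functional.huber_loss(..., reduction="none")`, and the quantile indicator is now taken on the detached error. A hedged sketch of the resulting quantile Huber loss in a [B, N] layout; the names `q_online`, `q_target`, `taus`, `kappa` are illustrative, and the sign convention for the indicator varies slightly between implementations:

```python
import torch
import torch.nn.functional as F

def quantile_huber_loss(q_online, q_target, taus, kappa=1.0):
    # Pairwise TD errors u_ij = target_j - online_i: [B, 1, N'] - [B, N, 1].
    td = q_target.detach().unsqueeze(-2) - q_online.unsqueeze(-1)      # [B, N, N']
    # Element-wise Huber term; expand both sides so the built-in loss sees
    # matching shapes and applies no reduction.
    huber = F.huber_loss(q_online.unsqueeze(-1).expand_as(td),
                         q_target.detach().unsqueeze(-2).expand_as(td),
                         reduction="none", delta=kappa)                # [B, N, N']
    # Asymmetric quantile weight |tau_i - 1{u_ij < 0}|; detaching the
    # indicator keeps gradients flowing only through the Huber term.
    weight = (taus.unsqueeze(-1) - td.detach().lt(0.).float()).abs()   # [B, N, N']
    return (weight * huber).mean(-1).sum(-1, keepdim=True)             # [B, 1]
```
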
16 changes: 11 additions & 5 deletions rls/algorithms/single/npg.py
@@ -133,28 +133,34 @@ def _train(self, BATCH):
if self.is_continuous:
mu, log_std = output # [T, B, A], [T, B, A]
dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
new_log_prob = dist.log_prob(BATCH.action).unsqueeze(-1) # [T, B, 1]
new_log_prob = dist.log_prob(
BATCH.action).unsqueeze(-1) # [T, B, 1]
entropy = dist.entropy().mean() # 1
else:
logits = output # [T, B, A]
logp_all = logits.log_softmax(-1) # [T, B, A]
new_log_prob = (BATCH.action * logp_all).sum(-1, keepdim=True) # [T, B, 1]
new_log_prob = (BATCH.action * logp_all).sum(-1,
keepdim=True) # [T, B, 1]
entropy = -(logp_all.exp() * logp_all).sum(-1).mean() # 1
ratio = (new_log_prob - BATCH.log_prob).exp() # [T, B, 1]
actor_loss = -(ratio * BATCH.gae_adv).mean() # 1

flat_grads = self._get_flat_grad(actor_loss, self.actor, retain_graph=True).detach() # [1,]
flat_grads = self._get_flat_grad(
actor_loss, self.actor, retain_graph=True).detach() # [1,]

if self.is_continuous:
kl = td.kl_divergence(
td.Independent(td.Normal(BATCH.mu, BATCH.log_std.exp()), 1),
td.Independent(td.Normal(mu, log_std.exp()), 1)
).mean()
else:
kl = (BATCH.logp_all.exp() * (BATCH.logp_all - logp_all)).sum(-1).mean() # 1
kl = (BATCH.logp_all.exp() * (BATCH.logp_all - logp_all)
).sum(-1).mean() # 1

flat_kl_grad = self._get_flat_grad(kl, self.actor, create_graph=True)
search_direction = -self._conjugate_gradients(flat_grads, flat_kl_grad, cg_iters=self._cg_iters) # [1,]
search_direction = - \
self._conjugate_gradients(
flat_grads, flat_kl_grad, cg_iters=self._cg_iters) # [1,]

with t.no_grad():
flat_params = t.cat([param.data.view(-1)
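
The npg.py hunk is only re-wrapped for line length; the logic (likelihood ratio, KL between old and new policies, conjugate gradients over the KL Hessian) is unchanged. For reference, a hedged sketch of what a routine like `_conjugate_gradients` typically computes, solving H x = g with Hessian-vector products taken through a KL gradient built with `create_graph=True`; the helper names and the `damping` term are illustrative, not this repo's exact API:

```python
import torch

def hessian_vector_product(flat_kl_grad, params, v, damping=0.1):
    # d/dtheta (kl_grad . v) equals H v when flat_kl_grad was built with
    # create_graph=True; a small damping term keeps the system well posed.
    grad_v = (flat_kl_grad * v).sum()
    hv = torch.autograd.grad(grad_v, params, retain_graph=True)
    return torch.cat([g.contiguous().view(-1) for g in hv]) + damping * v

def conjugate_gradients(flat_grads, flat_kl_grad, params, cg_iters=10, eps=1e-8):
    # Iteratively solve H x = g so the natural-gradient direction H^-1 g
    # never requires forming (or inverting) the Fisher/Hessian explicitly.
    # flat_grads is assumed to be the detached policy gradient.
    x = torch.zeros_like(flat_grads)
    r = flat_grads.detach().clone()      # residual g - H x, with x = 0 initially
    p = r.clone()
    r_dot_r = r.dot(r)
    for _ in range(cg_iters):
        hp = hessian_vector_product(flat_kl_grad, params, p)
        alpha = r_dot_r / (p.dot(hp) + eps)
        x += alpha * p
        r -= alpha * hp
        new_r_dot_r = r.dot(r)
        p = r + (new_r_dot_r / r_dot_r) * p
        r_dot_r = new_r_dot_r
    return x
```
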
14 changes: 8 additions & 6 deletions rls/algorithms/single/qrdqn.py
@@ -6,8 +6,7 @@

from rls.algorithms.base.sarl_off_policy import SarlOffPolicy
from rls.utils.expl_expt import ExplorationExploitationClass
from rls.utils.torch_utils import (huber_loss,
q_target_func)
from rls.utils.torch_utils import q_target_func
from rls.nn.models import QrdqnDistributional
from rls.nn.utils import OPLR
from rls.common.decorator import iTensor_oNumpy
@@ -92,13 +91,16 @@ def _train(self, BATCH):
q_target = target.mean(-1, keepdim=True) # [T, B, 1]
td_error = q_target - q_eval # [T, B, 1], used for PER

target = target.unsqueeze(-2) # [T, B, 1, N]
q_dist = q_dist.unsqueeze(-1) # [T, B, N, 1]

# [T, B, 1, N] - [T, B, N, 1] => [T, B, N, N]
quantile_error = target.unsqueeze(-2) - q_dist.unsqueeze(-1)
huber = huber_loss(
quantile_error, delta=self.huber_delta) # [T, B, N, N]
quantile_error = target - q_dist
huber = t.nn.functional.huber_loss(
target, q_dist, reduction="none", delta=self.huber_delta) # [T, B, N, N]
# [N,] - [T, B, N, N] => [T, B, N, N]
huber_abs = (self.quantiles -
t.where(quantile_error < 0, 1., 0.)).abs()
quantile_error.detach().le(0.).float()).abs()
loss = (huber_abs * huber).mean(-1) # [T, B, N, N] => [T, B, N]
loss = loss.sum(-1, keepdim=True) # [T, B, N] => [T, B, 1]
loss = (loss*BATCH.get('isw', 1.0)).mean() # 1
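
qrdqn.py gets the same treatment as iqn.py. As a quick sanity check (illustrative, not repo code, and assuming the removed helper implemented the standard Huber form), `torch.nn.functional.huber_loss` with `reduction="none"` reproduces the hand-rolled version, and the new indicator differs from the old `t.where(quantile_error < 0, 1., 0.)` only at exactly zero error while making the stop-gradient explicit via `detach()`:

```python
import torch
import torch.nn.functional as F

def manual_huber(error, delta=1.0):
    # Standard Huber form: quadratic inside |e| <= delta, linear (slope delta) outside.
    abs_e = error.abs()
    return torch.where(abs_e <= delta,
                       0.5 * error.pow(2),
                       delta * (abs_e - 0.5 * delta))

e = torch.randn(4, 5)
builtin = F.huber_loss(e, torch.zeros_like(e), reduction="none", delta=1.0)
assert torch.allclose(builtin, manual_huber(e, delta=1.0))

old_indicator = torch.where(e < 0, torch.ones_like(e), torch.zeros_like(e))
new_indicator = e.detach().le(0.).float()
# The two indicators can only disagree where e == 0 exactly.
assert torch.equal(old_indicator[e != 0], new_indicator[e != 0])
```
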
70 changes: 26 additions & 44 deletions rls/algorithms/single/rainbow.py
@@ -44,12 +44,12 @@ def __init__(self,
**kwargs):
super().__init__(**kwargs)
assert not self.is_continuous, 'rainbow only support discrete action space'
self.v_min = v_min
self.v_max = v_max
self.atoms = atoms
self.delta_z = (self.v_max - self.v_min) / (self.atoms - 1)
self.z = t.tensor(
[self.v_min + i * self.delta_z for i in range(self.atoms)]).float().to(self.device) # [N,]
self._v_min = v_min
self._v_max = v_max
self._atoms = atoms
self._delta_z = (self._v_max - self._v_min) / (self._atoms - 1)
self._z = t.linspace(self._v_min, self._v_max,
self._atoms).float().to(self.device) # [N,]
self.expl_expt_mng = ExplorationExploitationClass(eps_init=eps_init,
eps_mid=eps_mid,
eps_final=eps_final,
@@ -59,7 +59,7 @@ def __init__(self,
self.rainbow_net = TargetTwin(RainbowDueling(self.obs_spec,
rep_net_params=self._rep_net_params,
action_dim=self.a_dim,
atoms=self.atoms,
atoms=self._atoms,
network_settings=network_settings)).to(self.device)
self.oplr = OPLR(self.rainbow_net, lr)
self._trainer_modules.update(model=self.rainbow_net,
@@ -73,67 +73,49 @@ def select_action(self, obs):
q_values = self.rainbow_net(
obs, cell_state=self.cell_state) # [B, A, N]
self.next_cell_state = self.rainbow_net.get_cell_state()
q = (self.z * q_values).sum(-1) # [B, A, N] * [N, ] => [B, A]
q = (self._z * q_values).sum(-1) # [B, A, N] * [N, ] => [B, A]
actions = q.argmax(-1) # [B,]
return actions, Data(action=actions)

@iTensor_oNumpy
def _train(self, BATCH):
time_step = BATCH.reward.shape[0]
batch_size = BATCH.reward.shape[1]
indexes = t.arange(time_step*batch_size).view(-1, 1) # [T*B, 1]

q_dist = self.rainbow_net(BATCH.obs) # [T, B, A, N]
# [T, B, A, N] * [T, B, A, 1] => [T, B, A, N] => [T, B, N]
q_dist = (q_dist * BATCH.action.unsqueeze(-1)).sum(-2)
# [T, B, N] * [N, ] => [T, B, N] => [T, B]
q_eval = (q_dist * self.z).sum(-1)
q_dist = q_dist.view(-1, self.atoms) # [T, B, N] => [T*B, N]
q_eval = (q_dist * self._z).sum(-1)

target_q = self.rainbow_net(BATCH.obs_) # [T, B, A, N]
# [T, B, A, N] * [N, ] => [T, B, A, N] => [T, B, A]
target_q = (self.z * target_q).sum(-1)
target_q = (self._z * target_q).sum(-1)
_a = target_q.argmax(-1) # [T, B]
next_max_action = t.nn.functional.one_hot(
_a, self.a_dim).float().unsqueeze(-1) # [T, B, A, 1]

target_q_dist = self.rainbow_net.t(BATCH.obs_) # [T, B, A, N]
# [T, B, A, N] => [T, B, N]
target_q_dist = (target_q_dist * next_max_action).sum(-2)
# [T, B, N] => [T*B, N]
target_q_dist = target_q_dist.view(-1, self.atoms)

target = q_target_func(BATCH.reward.repeat(1, 1, self.atoms),
target = q_target_func(BATCH.reward.repeat(1, 1, self._atoms),
self.gamma,
BATCH.done.repeat(1, 1, self.atoms),
self.z.view(1, 1, self.atoms).repeat(
time_step, batch_size, 1),
BATCH.begin_mask.repeat(1, 1, self.atoms),
BATCH.done.repeat(1, 1, self._atoms),
target_q_dist,
BATCH.begin_mask.repeat(1, 1, self._atoms),
use_rnn=self.use_rnn) # [T, B, N]

target = target.clamp(self.v_min, self.v_max) # [T, B, N]
target = target.view(-1, self.atoms) # [T, B, N] => [T*B, N]
b = (target - self.v_min) / self.delta_z # [T*B, N]
u, l = b.ceil(), b.floor() # [T*B, N]
u_id, l_id = u.long(), l.long() # [T*B, N]
u_minus_b, b_minus_l = u - b, b - l # [T*B, N]

index_help = indexes.repeat(1, self.atoms) # [T*B, 1] => [T*B, N]
index_help = index_help.unsqueeze(-1) # [T*B, N, 1]
u_id = t.cat([index_help, u_id.unsqueeze(-1)], -1) # [T*B, N, 2]
l_id = t.cat([index_help, l_id.unsqueeze(-1)], -1) # [T*B, N, 2]
u_id = u_id.long().permute(2, 0, 1) # [2, T*B, N]
l_id = l_id.long().permute(2, 0, 1) # [2, T*B, N]
_cross_entropy = (target_q_dist * u_minus_b).detach() * q_dist[list(l_id)].log()\
+ (target_q_dist * b_minus_l).detach() * \
q_dist[list(u_id)].log() # [T*B, N]
td_error = cross_entropy = - \
_cross_entropy.sum(-1).view(time_step, batch_size) # [T, B]

loss = (cross_entropy*BATCH.get('isw', 1.0)).mean() # 1
target = target.clamp(self._v_min, self._v_max) # [T, B, N]
# An amazing trick for calculating the projection gracefully.
# ref: https://github.com/ShangtongZhang/DeepRL
target_dist = (1 - (target.unsqueeze(-1) -
self._z.view(1, 1, -1, 1)).abs() / self._delta_z
).clamp(0, 1) * target_q_dist.unsqueeze(-1) # [T, B, N, 1]
target_dist = target_dist.sum(-1) # [T, B, N]

_cross_entropy = - (target_dist * t.log(q_dist +
t.finfo().eps)).sum(-1, keepdim=True) # [T, B, 1]
loss = (_cross_entropy*BATCH.get('isw', 1.0)).mean() # 1

self.oplr.step(loss)
return td_error, dict([
return _cross_entropy, dict([
['LEARNING_RATE/lr', self.oplr.lr],
['LOSS/loss', loss],
['Statistics/q_max', q_eval.max()],
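
rainbow.py mirrors the c51.py rewrite: a `linspace` support, the same closed-form projection, and a cross-entropy loss whose per-sample value is returned in place of the old explicit `td_error`, presumably serving as the priority signal for prioritised replay. A small sketch of that final step, assuming `q_dist` and `target_dist` are [T, B, N] categorical distributions and `isw` is the importance-sampling weight taken from the batch:

```python
import torch

def categorical_loss(q_dist, target_dist, isw=1.0):
    eps = torch.finfo(q_dist.dtype).eps                  # guards against log(0)
    cross_entropy = -(target_dist * (q_dist + eps).log()).sum(-1, keepdim=True)  # [T, B, 1]
    loss = (cross_entropy * isw).mean()                  # scalar to optimise
    # cross_entropy is also returned per sample, doubling as the replay priority.
    return loss, cross_entropy
```
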
(Diffs for the remaining 3 changed files are not shown.)
