-
Notifications
You must be signed in to change notification settings - Fork 144
/
categorical_dqn.py
152 lines (118 loc) · 5.43 KB
/
categorical_dqn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
from copy import deepcopy
import torch.nn as nn
import torch.nn.functional as F
from mushroom_rl.algorithms.value.dqn import AbstractDQN
from mushroom_rl.approximators.parametric.torch_approximator import *
def categorical_loss(input, target, reduction='sum'):
    """
    Cross-entropy loss between a predicted categorical distribution and a
    target one, computed per sample along dimension 1.

    Args:
        input (torch.Tensor): predicted probabilities, clamped below at 1e-5
            so the logarithm stays finite;
        target (torch.Tensor): target probabilities, same shape as ``input``;
        reduction (str, 'sum'): ``'sum'`` averages the per-sample losses over
            the batch, ``'none'`` returns them unreduced.

    Returns:
        The scalar mean loss, or the per-sample loss vector.

    Raises:
        ValueError: if ``reduction`` is neither ``'sum'`` nor ``'none'``.

    """
    probs = input.clamp(1e-5)
    per_sample = -(target * probs.log()).sum(1)

    if reduction == 'none':
        return per_sample
    if reduction == 'sum':
        return per_sample.mean()

    raise ValueError
class CategoricalNetwork(nn.Module):
    """
    Torch network for Categorical DQN (C51): for each action it outputs a
    softmax over ``n_atoms`` fixed support values evenly spaced in
    ``[v_min, v_max]``, on top of a user-provided feature extractor.

    """
    def __init__(self, input_shape, output_shape, features_network, n_atoms,
                 v_min, v_max, n_features, use_cuda, **kwargs):
        """
        Constructor.

        Args:
            input_shape (tuple): shape of the state input;
            output_shape (tuple): ``(n_actions,)``;
            features_network (class): torch module class building the shared
                feature extractor; instantiated as
                ``features_network(input_shape, (n_features,),
                n_features=n_features, **kwargs)``;
            n_atoms (int): number of atoms of the support;
            v_min (float): minimum value of the support;
            v_max (float): maximum value of the support;
            n_features (int): number of features extracted by the network;
            use_cuda (bool): whether to move the support values to the GPU.

        """
        super().__init__()

        self._n_output = output_shape[0]
        self._phi = features_network(input_shape, (n_features,),
                                     n_features=n_features, **kwargs)
        self._n_atoms = n_atoms
        self._v_min = v_min
        self._v_max = v_max

        # linspace is used instead of arange(v_min, v_max + delta, delta):
        # with a floating-point step, arange may return n_atoms + 1 values
        # due to rounding of the endpoint.
        self._a_values = torch.linspace(self._v_min, self._v_max,
                                        self._n_atoms)
        if use_cuda:
            self._a_values = self._a_values.cuda()

        # One independent linear+softmax head per action.
        self._p = nn.ModuleList(
            [nn.Linear(n_features, n_atoms) for _ in range(self._n_output)])

        for i in range(self._n_output):
            nn.init.xavier_uniform_(self._p[i].weight,
                                    gain=nn.init.calculate_gain('linear'))

    def forward(self, state, action=None, get_distribution=False):
        """
        Args:
            state (torch.Tensor): batch of states;
            action (torch.Tensor, None): if provided, only the entries for
                these actions are returned;
            get_distribution (bool, False): if ``True``, return the atom
                probabilities instead of the expected action values.

        Returns:
            Q-values of shape ``(batch, n_actions)`` (or gathered per
            ``action``), or the distribution of shape
            ``(batch, n_actions, n_atoms)`` when ``get_distribution`` is
            ``True``.

        """
        features = self._phi(state)

        a_p = [F.softmax(self._p[i](features), -1)
               for i in range(self._n_output)]
        a_p = torch.stack(a_p, dim=1)  # (batch, n_actions, n_atoms)

        if not get_distribution:
            # Expected value of each categorical distribution. The batched
            # matmul keeps the result on the same device as a_p; the previous
            # per-sample loop filled a CPU-allocated torch.empty tensor,
            # which broke CUDA runs.
            q = a_p @ self._a_values

            if action is not None:
                return torch.squeeze(q.gather(1, action))
            else:
                return q
        else:
            if action is not None:
                action = torch.unsqueeze(
                    action.long(), 2).repeat(1, 1, self._n_atoms)

                return torch.squeeze(a_p.gather(1, action))
            else:
                return a_p
class CategoricalDQN(AbstractDQN):
    """
    Categorical DQN algorithm.
    "A Distributional Perspective on Reinforcement Learning".
    Bellemare M. et al.. 2017.

    """
    def __init__(self, mdp_info, policy, approximator_params, n_atoms, v_min,
                 v_max, **params):
        """
        Constructor.

        Args:
            n_atoms (int): number of atoms;
            v_min (float): minimum value of value-function;
            v_max (float): maximum value of value-function.

        """
        # The user-provided network becomes the feature extractor inside a
        # CategoricalNetwork, which adds the per-action atom heads.
        features_network = approximator_params['network']
        params['approximator_params'] = deepcopy(approximator_params)
        params['approximator_params']['network'] = CategoricalNetwork
        params['approximator_params']['features_network'] = features_network
        params['approximator_params']['n_atoms'] = n_atoms
        params['approximator_params']['v_min'] = v_min
        params['approximator_params']['v_max'] = v_max
        params['approximator_params']['loss'] = categorical_loss

        self._n_atoms = n_atoms
        self._v_min = v_min
        self._v_max = v_max
        self._delta = (v_max - v_min) / (n_atoms - 1)
        # linspace is used instead of arange(v_min, v_max + delta, delta):
        # with a floating-point step, arange may return n_atoms + 1 values,
        # which would break the (batch, n_atoms) projection in fit().
        self._a_values = np.linspace(v_min, v_max, n_atoms)

        self._add_save_attr(
            _n_atoms='primitive',
            _v_min='primitive',
            _v_max='primitive',
            _delta='primitive',
            _a_values='numpy'
        )

        super().__init__(mdp_info, policy, TorchApproximator, **params)

    def fit(self, dataset):
        """
        Add the dataset to the replay memory and, once it is initialized,
        perform one training step projecting the target distribution onto
        the fixed support (Algorithm 1 of the C51 paper).

        Args:
            dataset (list): experience to add to the replay memory.

        """
        self._replay_memory.add(dataset)
        if self._replay_memory.initialized:
            state, action, reward, next_state, absorbing, _ =\
                self._replay_memory.get(self._batch_size())

            if self._clip_reward:
                reward = np.clip(reward, -1, 1)

            # Greedy next action w.r.t. the expected value of the target
            # distribution.
            q_next = self.target_approximator.predict(next_state, **self._predict_params)
            a_max = np.argmax(q_next, 1)
            gamma = self.mdp_info.gamma * (1 - absorbing)
            p_next = self.target_approximator.predict(next_state, a_max,
                                                      get_distribution=True, **self._predict_params)

            # Bellman-updated atom positions, clipped to the support range.
            gamma_z = gamma.reshape(-1, 1) * np.expand_dims(
                self._a_values, 0).repeat(len(gamma), 0)
            bell_a = (reward.reshape(-1, 1) + gamma_z).clip(self._v_min,
                                                            self._v_max)

            # Project each updated atom onto its two neighboring support
            # indices l (floor) and u (ceil), splitting its probability
            # proportionally to the distance from each.
            b = (bell_a - self._v_min) / self._delta
            l = np.floor(b).astype(int)
            u = np.ceil(b).astype(int)

            m = np.zeros((self._batch_size.get_value(), self._n_atoms))
            for i in range(self._n_atoms):
                # When b lands exactly on an atom (l == u), shift one of the
                # two indices so the probability mass is not dropped.
                l[:, i][(u[:, i] > 0) * (l[:, i] == u[:, i])] -= 1
                u[:, i][(l[:, i] < (self._n_atoms - 1)) * (l[:, i] == u[:, i])] += 1

                m[np.arange(len(m)), l[:, i]] += p_next[:, i] * (u[:, i] - b[:, i])
                m[np.arange(len(m)), u[:, i]] += p_next[:, i] * (b[:, i] - l[:, i])

            # Fit the online network's distribution to the projected target.
            self.approximator.fit(state, action, m, get_distribution=True,
                                  **self._fit_params)

            self._n_updates += 1

            if self._n_updates % self._target_update_frequency == 0:
                self._update_target()