/**
* @file n_step_q_learning_worker.hpp
* @author Shangtong Zhang
*
 * This file contains the definition of the NStepQLearningWorker class, which
 * runs episodes for the asynchronous n-step Q-learning algorithm.
*
* mlpack is free software; you may redistribute it and/or modify it under the
* terms of the 3-clause BSD license. You should have received a copy of the
* 3-clause BSD license along with mlpack. If not, see
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
*/
#ifndef MLPACK_METHODS_RL_WORKER_N_STEP_Q_LEARNING_WORKER_HPP
#define MLPACK_METHODS_RL_WORKER_N_STEP_Q_LEARNING_WORKER_HPP
#include <mlpack/methods/reinforcement_learning/training_config.hpp>
namespace mlpack {
namespace rl {
/**
 * N-step Q-learning worker. Each worker holds a local copy of the shared
 * network, interacts with its own environment instance, and periodically
 * applies its accumulated n-step gradients to the shared learning network.
*
* @tparam EnvironmentType The type of the reinforcement learning task.
* @tparam NetworkType The type of the network model.
* @tparam UpdaterType The type of the optimizer.
 * @tparam PolicyType The type of the behavior policy.
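 *
 * A minimal instantiation sketch (the concrete types below are illustrative
 * assumptions, not requirements of this class):
 *
 * @code
 * NStepQLearningWorker<CartPole, FFN<MeanSquaredError<>>, VanillaUpdate,
 *     GreedyPolicy<CartPole>> worker(updater, environment, config, false);
 * worker.Initialize(learningNetwork);
 * @endcode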
*/
template <
typename EnvironmentType,
typename NetworkType,
typename UpdaterType,
typename PolicyType
>
class NStepQLearningWorker
{
public:
using StateType = typename EnvironmentType::State;
using ActionType = typename EnvironmentType::Action;
using TransitionType = std::tuple<StateType, ActionType, double, StateType>;
/**
* @param updater The optimizer.
* @param environment The reinforcement learning task.
* @param config Hyper-parameters.
   * @param deterministic Whether the worker should act deterministically
   *     (evaluation mode: no gradients are computed or applied).
*/
NStepQLearningWorker(
const UpdaterType& updater,
const EnvironmentType& environment,
const TrainingConfig& config,
bool deterministic):
updater(updater),
environment(environment),
config(config),
deterministic(deterministic),
pending(config.UpdateInterval())
{ reset(); }
/**
* Initialize the worker.
* @param learningNetwork The shared network.
*/
void Initialize(NetworkType& learningNetwork)
{
updater.Initialize(learningNetwork.Parameters().n_rows,
learningNetwork.Parameters().n_cols);
// Build local network.
network = learningNetwork;
}
/**
   * The agent executes one step in the environment.
*
* @param learningNetwork The shared learning network.
* @param targetNetwork The shared target network.
* @param totalSteps The shared counter for total steps.
* @param policy The shared behavior policy.
   * @param totalReward Set to the episode return if the episode ends after
   *     this step; otherwise left unmodified.
   * @return Whether the current episode ends after this step.
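   *
   * A driver loop might look like the following (a sketch; the surrounding
   * variables are assumed to be set up by the caller):
   *
   * @code
   * double totalReward = 0;
   * // Keep stepping until the episode ends.
   * while (!worker.Step(learningNetwork, targetNetwork, totalSteps, policy,
   *     totalReward)) { }
   * @endcode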
*/
bool Step(NetworkType& learningNetwork,
NetworkType& targetNetwork,
size_t& totalSteps,
PolicyType& policy,
double& totalReward)
{
// Interact with the environment.
arma::colvec actionValue;
network.Predict(state.Encode(), actionValue);
ActionType action = policy.Sample(actionValue, deterministic);
StateType nextState;
double reward = environment.Sample(state, action, nextState);
bool terminal = environment.IsTerminal(nextState);
episodeReturn += reward;
steps++;
terminal = terminal || steps >= config.StepLimit();
if (deterministic)
{
if (terminal)
{
totalReward = episodeReturn;
reset();
// Sync with latest learning network.
network = learningNetwork;
return true;
}
state = nextState;
return false;
}
#pragma omp atomic
totalSteps++;
pending[pendingIndex] = std::make_tuple(state, action, reward, nextState);
pendingIndex++;
if (terminal || pendingIndex >= config.UpdateInterval())
{
      // Initialize the gradient storage with zeros; gradients from each
      // pending transition are accumulated into it below.
      arma::mat totalGradients(learningNetwork.Parameters().n_rows,
          learningNetwork.Parameters().n_cols, arma::fill::zeros);
// Bootstrap from the value of next state.
arma::colvec actionValue;
double target = 0;
if (!terminal)
{
#pragma omp critical
        { targetNetwork.Predict(nextState.Encode(), actionValue); }
target = actionValue.max();
}
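      // The loop below computes each transition's n-step return recursively:
      // starting from the bootstrap value G = max_a Q'(s_{t+n}, a) (or 0 at a
      // terminal state), each iteration applies G <- r_i + discount * G, so
      // every stored transition is regressed towards its own n-step target.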
      // Update in reverse order; only the first pendingIndex entries of the
      // buffer belong to the current batch.
      for (int i = ((int) pendingIndex) - 1; i >= 0; --i)
{
TransitionType &transition = pending[i];
target = config.Discount() * target + std::get<2>(transition);
// Compute the training target for current state.
network.Forward(std::get<0>(transition).Encode(), actionValue);
actionValue[std::get<1>(transition)] = target;
// Compute gradient.
arma::mat gradients;
network.Backward(actionValue, gradients);
// Accumulate gradients.
totalGradients += gradients;
}
      // Clamp the accumulated gradients to [-GradientLimit, GradientLimit] to
      // keep the asynchronous updates stable.
totalGradients.transform(
[&](double gradient)
{ return std::min(std::max(gradient, -config.GradientLimit()),
config.GradientLimit()); });
// Perform async update of the global network.
updater.Update(learningNetwork.Parameters(),
config.StepSize(), totalGradients);
// Sync the local network with the global network.
network = learningNetwork;
pendingIndex = 0;
}
// Update global target network.
if (totalSteps % config.TargetNetworkSyncInterval() == 0)
{
#pragma omp critical
{ targetNetwork = learningNetwork; }
}
policy.Anneal();
if (terminal)
{
totalReward = episodeReturn;
reset();
return true;
}
state = nextState;
return false;
}
private:
/**
   * Reset the worker for a new episode.
*/
void reset()
{
steps = 0;
episodeReturn = 0;
pendingIndex = 0;
state = environment.InitialSample();
}
//! Locally-stored optimizer.
UpdaterType updater;
//! Locally-stored task.
EnvironmentType environment;
//! Locally-stored hyper-parameters.
TrainingConfig config;
  //! Whether the worker acts deterministically (evaluation mode).
bool deterministic;
//! Total steps in current episode.
size_t steps;
//! Total reward in current episode.
double episodeReturn;
//! Buffer for delayed update.
std::vector<TransitionType> pending;
//! Current position of the buffer.
size_t pendingIndex;
//! Local network of the worker.
NetworkType network;
//! Current state of the agent.
StateType state;
};
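
// A rough sketch of how an asynchronous learner might drive several workers
// in parallel (hypothetical driver code; in mlpack this orchestration is
// handled by the async learning framework rather than written by hand):
//
//   #pragma omp parallel for
//   for (size_t i = 0; i < workers.size(); ++i)
//   {
//     double totalReward = 0;
//     while (totalSteps < stepLimit)
//       workers[i].Step(learningNetwork, targetNetwork, totalSteps, policy,
//           totalReward);
//   }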
} // namespace rl
} // namespace mlpack
#endif