-
Notifications
You must be signed in to change notification settings - Fork 320
/
dueling_dqn_model.py
60 lines (38 loc) · 1.42 KB
/
dueling_dqn_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
"""
Dueling DQN
"""
import argparse
import pytorch_lightning as pl
from pl_bolts.models.rl.common.networks import DuelingCNN
from pl_bolts.models.rl.dqn_model import DQN
class DuelingDQN(DQN):
    """
    PyTorch Lightning implementation of `Dueling DQN <https://arxiv.org/abs/1511.06581>`_

    Paper authors: Ziyu Wang, Tom Schaul, Matteo Hessel, Hado van Hasselt, Marc Lanctot, Nando de Freitas

    Model implemented by:

        - `Donal Byrne <https://github.com/djbyrne>`_

    This class changes a single thing relative to its ``DQN`` parent: the
    networks it builds. Both ``self.net`` and ``self.target_net`` are
    ``DuelingCNN`` instances; everything else is inherited.

    Example:

        >>> from pl_bolts.models.rl.dueling_dqn_model import DuelingDQN
        ...
        >>> model = DuelingDQN("PongNoFrameskip-v4")

    Train::

        trainer = Trainer()
        trainer.fit(model)

    .. note:: Currently only supports CPU and single GPU training with `distributed_backend=dp`
    """

    def build_networks(self) -> None:
        """Create the dueling train network and its target counterpart.

        Two separate ``DuelingCNN`` instances are constructed from the same
        observation shape and action count and stored on ``self.net`` and
        ``self.target_net`` respectively.
        """
        # Both networks share constructor arguments; read them once.
        obs_shape, n_actions = self.obs_shape, self.n_actions

        # The train network is built first, then the target network.
        self.net = DuelingCNN(obs_shape, n_actions)
        self.target_net = DuelingCNN(obs_shape, n_actions)
def cli_main():
    """Parse command-line options, build a ``DuelingDQN`` and train it.

    The parser is populated with the Lightning ``Trainer`` options first and
    the model-specific options second. Every parsed value is forwarded to the
    ``DuelingDQN`` constructor as a keyword argument, and the same namespace
    is used to configure the ``Trainer`` that fits the model.
    """
    # The parser is created without the automatic -h/--help option.
    arg_parser = argparse.ArgumentParser(add_help=False)

    # Register trainer args (inner call) and then model args (outer call).
    arg_parser = DuelingDQN.add_model_specific_args(
        pl.Trainer.add_argparse_args(arg_parser)
    )
    parsed = arg_parser.parse_args()

    # The model is constructed before the trainer.
    agent = DuelingDQN(**vars(parsed))
    pl.Trainer.from_argparse_args(parsed).fit(agent)
# Run the command-line entry point only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    cli_main()