-
Notifications
You must be signed in to change notification settings - Fork 205
Expand file tree
/
Copy patha2c_common.py
More file actions
1212 lines (992 loc) · 50.8 KB
/
a2c_common.py
File metadata and controls
1212 lines (992 loc) · 50.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import os
from rl_games.common import tr_helpers
from rl_games.common import vecenv
from rl_games.algos_torch.running_mean_std import RunningMeanStd
from rl_games.algos_torch.moving_mean_std import MovingMeanStd
from rl_games.algos_torch.self_play_manager import SelfPlayManager
from rl_games.algos_torch import torch_ext
from rl_games.common import schedulers
from rl_games.common.experience import ExperienceBuffer
from rl_games.common.interval_summary_writer import IntervalSummaryWriter
from rl_games.common.diagnostics import DefaultDiagnostics, PpoDiagnostics
from rl_games.algos_torch import model_builder
from rl_games.interfaces.base_algorithm import BaseAlgorithm
import numpy as np
import time
import gym
from datetime import datetime
from tensorboardX import SummaryWriter
import torch
from torch import nn
from time import sleep
from rl_games.common import common_losses
def swap_and_flatten01(arr):
    """
    Exchange the first two axes of ``arr`` and fuse them into one axis.

    A tensor shaped (A, B, *rest) comes back shaped (A * B, *rest), ordered so
    that row ``i * A + j`` equals ``arr[j, i]``. ``None`` is passed through.
    """
    if arr is not None:
        shape = arr.size()
        swapped = arr.transpose(0, 1)
        return swapped.reshape(shape[0] * shape[1], *shape[2:])
    return arr
def rescale_actions(low, high, action):
    """Map ``action`` affinely from [-1, 1] onto [low, high].

    Evaluates ``action * (high - low) / 2 + (high + low) / 2``; works with any
    operands that support these arithmetic operators (scalars, arrays, tensors).
    """
    half_span = (high - low) / 2.0
    center = (high + low) / 2.0
    return action * half_span + center
class A2CBase(BaseAlgorithm):
    """Shared machinery for rl_games A2C/PPO agents.

    Parses the algorithm config, builds (or adopts) the vectorized environment,
    sets up LR schedulers, bookkeeping, experiment directories and the
    TensorBoard writer, and implements rollout collection (``play_steps`` /
    ``play_steps_rnn``), GAE, observation preprocessing and checkpoint
    (de)serialisation.

    ``self.model``, ``self.optimizer``, ``self.dataset`` and
    ``self.central_value_net`` are used here but never assigned in this class;
    subclasses are expected to provide them, along with ``train``,
    ``train_epoch``, ``train_actor_critic`` and ``calc_gradients``.
    """

    def __init__(self, base_name, params):
        """Read the algorithm configuration and initialise shared state.

        Args:
            base_name: Stored as ``self.name`` and forwarded to
                ``algo_observer.before_init``.
            params: Full parameter dict. ``params['config']`` holds the
                algorithm config; ``params`` itself is handed to the model
                builder in ``load_networks``.
        """
        self.config = config = params['config']
        pbt_str = ''
        self.population_based_training = config.get('population_based_training', False)
        if self.population_based_training:
            # in PBT, make sure experiment name contains a unique id of the policy within a population
            pbt_str = f'_pbt_{config["pbt_idx"]:02d}'

        # This helps in PBT when we need to restart an experiment with the exact same name, rather than
        # generating a new name with the timestamp every time.
        full_experiment_name = config.get('full_experiment_name', None)
        if full_experiment_name:
            print(f'Exact experiment name requested from command line: {full_experiment_name}')
            self.experiment_name = full_experiment_name
        else:
            self.experiment_name = config['name'] + pbt_str + datetime.now().strftime("_%d-%H-%M-%S")

        self.algo_observer = config['features']['observer']
        self.algo_observer.before_init(base_name, config, self.experiment_name)
        self.load_networks(params)

        self.multi_gpu = config.get('multi_gpu', False)
        self.rank = 0
        self.rank_size = 1
        self.curr_frames = 0
        if self.multi_gpu:
            from rl_games.distributed.hvd_wrapper import HorovodWrapper
            self.hvd = HorovodWrapper()
            # NOTE(review): from here on self.config is whatever update_algo_config returns,
            # while several reads below still go through the local ``config`` — kept as-is.
            self.config = self.hvd.update_algo_config(config)
            self.rank = self.hvd.rank
            self.rank_size = self.hvd.rank_size

        self.use_diagnostics = config.get('use_diagnostics', False)
        if self.use_diagnostics and self.rank == 0:
            self.diagnostics = PpoDiagnostics()
        else:
            self.diagnostics = DefaultDiagnostics()

        self.network_path = config.get('network_path', "./nn/")
        self.log_path = config.get('log_path', "runs/")
        self.env_config = config.get('env_config', {})
        self.num_actors = config['num_actors']
        self.env_name = config['env_name']

        self.vec_env = None
        self.env_info = config.get('env_info')
        if self.env_info is None:
            # No env info supplied: create the vectorized env and ask it for its spaces.
            self.vec_env = vecenv.create_vec_env(self.env_name, self.num_actors, **self.env_config)
            self.env_info = self.vec_env.get_env_info()

        self.ppo_device = config.get('device', 'cuda:0')
        self.value_size = self.env_info.get('value_size', 1)
        self.observation_space = self.env_info['observation_space']
        self.weight_decay = config.get('weight_decay', 0.0)
        self.use_action_masks = config.get('use_action_masks', False)
        self.is_train = config.get('is_train', True)

        self.central_value_config = self.config.get('central_value_config', None)
        self.has_central_value = self.central_value_config is not None
        self.truncate_grads = self.config.get('truncate_grads', False)
        if self.has_central_value:
            self.state_space = self.env_info.get('state_space', None)
            if isinstance(self.state_space, gym.spaces.Dict):
                self.state_shape = {k: v.shape for k, v in self.state_space.spaces.items()}
            else:
                self.state_shape = self.state_space.shape

        self.self_play_config = self.config.get('self_play_config', None)
        self.has_self_play_config = self.self_play_config is not None

        self.self_play = config.get('self_play', False)
        self.save_freq = config.get('save_frequency', 0)
        self.save_best_after = config.get('save_best_after', 100)
        self.print_stats = config.get('print_stats', True)
        self.rnn_states = None
        self.name = base_name

        self.ppo = config.get('ppo', True)
        self.max_epochs = self.config.get('max_epochs', 1e6)

        self.is_adaptive_lr = config['lr_schedule'] == 'adaptive'
        self.linear_lr = config['lr_schedule'] == 'linear'
        if self.is_adaptive_lr:
            self.kl_threshold = config['kl_threshold']
            self.scheduler = schedulers.AdaptiveScheduler(self.kl_threshold)
        elif self.linear_lr:
            self.scheduler = schedulers.LinearScheduler(float(config['learning_rate']),
                max_steps=self.max_epochs,
                apply_to_entropy=config.get('schedule_entropy', False),
                start_entropy_coef=config.get('entropy_coef'))
        else:
            self.scheduler = schedulers.IdentityScheduler()

        self.e_clip = config['e_clip']
        self.clip_value = config['clip_value']
        self.network = config['network']
        self.rewards_shaper = config['reward_shaper']
        self.num_agents = self.env_info.get('agents', 1)
        self.horizon_length = config['horizon_length']
        self.seq_len = self.config.get('seq_length', 4)
        self.bptt_len = self.config.get('bptt_length', self.seq_len)
        self.normalize_advantage = config['normalize_advantage']
        self.normalize_rms_advantage = config.get('normalize_rms_advantage', False)
        self.normalize_input = self.config['normalize_input']
        self.normalize_value = self.config.get('normalize_value', False)
        self.has_phasic_policy_gradients = False

        if isinstance(self.observation_space, gym.spaces.Dict):
            self.obs_shape = {k: v.shape for k, v in self.observation_space.spaces.items()}
        else:
            self.obs_shape = self.observation_space.shape

        self.critic_coef = config['critic_coef']
        self.grad_norm = config['grad_norm']
        self.gamma = self.config['gamma']
        self.tau = self.config['tau']

        self.games_to_track = self.config.get('games_to_track', 100)
        print('current training device:', self.ppo_device)
        self.game_rewards = torch_ext.AverageMeter(self.value_size, self.games_to_track).to(self.ppo_device)
        self.game_lengths = torch_ext.AverageMeter(1, self.games_to_track).to(self.ppo_device)
        self.obs = None

        self.batch_size = self.horizon_length * self.num_actors * self.num_agents
        self.batch_size_envs = self.horizon_length * self.num_actors

        assert(('minibatch_size_per_env' in self.config) or ('minibatch_size' in self.config))
        self.minibatch_size_per_env = self.config.get('minibatch_size_per_env', 0)
        self.minibatch_size = self.config.get('minibatch_size', self.num_actors * self.minibatch_size_per_env)
        # Derived from the resolved minibatch size (not config['minibatch_size']) so that configs that
        # only set 'minibatch_size_per_env' — allowed by the assert above — do not raise KeyError.
        self.games_num = self.minibatch_size // self.seq_len # it is used only for current rnn implementation
        self.mini_epochs_num = self.config['mini_epochs']
        self.num_minibatches = self.batch_size // self.minibatch_size
        assert(self.batch_size % self.minibatch_size == 0)

        self.mixed_precision = self.config.get('mixed_precision', False)
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.mixed_precision)

        self.last_lr = self.config['learning_rate']
        self.frame = 0
        self.update_time = 0
        self.mean_rewards = self.last_mean_rewards = -100500
        self.play_time = 0
        self.epoch_num = 0

        # allows us to specify a folder where all experiments will reside
        self.train_dir = config.get('train_dir', 'runs')
        # a folder inside of train_dir containing everything related to a particular experiment
        self.experiment_dir = os.path.join(self.train_dir, self.experiment_name)
        # folders inside <train_dir>/<experiment_dir> for a specific purpose
        self.nn_dir = os.path.join(self.experiment_dir, 'nn')
        self.summaries_dir = os.path.join(self.experiment_dir, 'summaries')
        os.makedirs(self.train_dir, exist_ok=True)
        os.makedirs(self.experiment_dir, exist_ok=True)
        os.makedirs(self.nn_dir, exist_ok=True)
        os.makedirs(self.summaries_dir, exist_ok=True)

        self.entropy_coef = self.config['entropy_coef']

        # Only rank 0 writes TensorBoard summaries.
        if self.rank == 0:
            writer = SummaryWriter(self.summaries_dir)
            if self.population_based_training:
                self.writer = IntervalSummaryWriter(writer, self.config)
            else:
                self.writer = writer
        else:
            self.writer = None

        self.value_bootstrap = self.config.get('value_bootstrap')
        self.use_smooth_clamp = self.config.get('use_smooth_clamp', False)
        if self.use_smooth_clamp:
            self.actor_loss_func = common_losses.smoothed_actor_loss
        else:
            self.actor_loss_func = common_losses.actor_loss

        # The RMS advantage normalizer exists only when both flags are on. It defaults to None so that
        # set_eval/set_train and the checkpoint helpers can test for it instead of raising
        # AttributeError when only normalize_rms_advantage is set.
        self.advantage_mean_std = None
        if self.normalize_advantage and self.normalize_rms_advantage:
            momentum = self.config.get('adv_rms_momentum', 0.5) #'0.25'
            self.advantage_mean_std = MovingMeanStd((1,), momentum=momentum).to(self.ppo_device)

        self.is_tensor_obses = False
        self.last_rnn_indices = None
        self.last_state_indices = None

        #self_play
        if self.has_self_play_config:
            print('Initializing SelfPlay Manager')
            self.self_play_manager = SelfPlayManager(self.self_play_config, self.writer)

        # features (self.algo_observer was already set from config['features'] above)
        self.soft_aug = config['features'].get('soft_augmentation', None)
        self.has_soft_aug = self.soft_aug is not None
        # soft augmentation not yet supported
        assert not self.has_soft_aug

    def trancate_gradients_and_step(self):
        """Optionally clip gradients, then step the optimizer through the AMP scaler.

        The (misspelled) name is kept for backward compatibility. With
        multi-GPU, gradients are synchronized first and the optimizer step is
        taken inside ``skip_synchronize()`` to avoid a second sync.
        """
        if self.multi_gpu:
            self.optimizer.synchronize()

        if self.truncate_grads:
            # Unscale first so clipping operates on the true gradient magnitudes.
            self.scaler.unscale_(self.optimizer)
            nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_norm)

        if self.multi_gpu:
            with self.optimizer.skip_synchronize():
                self.scaler.step(self.optimizer)
                self.scaler.update()
        else:
            self.scaler.step(self.optimizer)
            self.scaler.update()

    def load_networks(self, params):
        """Build the policy network (and the central value network, if configured) into ``self.config``."""
        builder = model_builder.ModelBuilder()
        self.config['network'] = builder.load(params)
        has_central_value_net = self.config.get('central_value_config') is not None
        if has_central_value_net:
            print('Adding Central Value Network')
            if 'model' not in params['config']['central_value_config']:
                params['config']['central_value_config']['model'] = {'name': 'central_value'}
            network = builder.load(params['config']['central_value_config'])
            self.config['central_value_config']['network'] = network

    def write_stats(self, total_time, epoch_num, step_time, play_time, update_time, a_losses, c_losses, entropies, kls, last_lr, lr_mul, frame, scaled_time, scaled_play_time, curr_frames):
        """Log performance, loss and schedule scalars to ``self.writer`` at step ``frame``.

        Assumes ``self.writer`` is not None (it is only created on rank 0).
        """
        # do we need scaled time?
        self.diagnostics.send_info(self.writer)
        self.writer.add_scalar('performance/step_inference_rl_update_fps', curr_frames / scaled_time, frame)
        self.writer.add_scalar('performance/step_inference_fps', curr_frames / scaled_play_time, frame)
        self.writer.add_scalar('performance/step_fps', curr_frames / step_time, frame)
        self.writer.add_scalar('performance/rl_update_time', update_time, frame)
        self.writer.add_scalar('performance/step_inference_time', play_time, frame)
        self.writer.add_scalar('performance/step_time', step_time, frame)
        self.writer.add_scalar('losses/a_loss', torch_ext.mean_list(a_losses).item(), frame)
        self.writer.add_scalar('losses/c_loss', torch_ext.mean_list(c_losses).item(), frame)

        self.writer.add_scalar('losses/entropy', torch_ext.mean_list(entropies).item(), frame)
        self.writer.add_scalar('info/last_lr', last_lr * lr_mul, frame)
        self.writer.add_scalar('info/lr_mul', lr_mul, frame)
        self.writer.add_scalar('info/e_clip', self.e_clip * lr_mul, frame)
        self.writer.add_scalar('info/kl', torch_ext.mean_list(kls).item(), frame)
        self.writer.add_scalar('info/epochs', epoch_num, frame)
        self.algo_observer.after_print_stats(frame, epoch_num, total_time)

    def set_eval(self):
        """Put the model (and the advantage normalizer, if any) in eval mode."""
        self.model.eval()
        if self.advantage_mean_std is not None:
            self.advantage_mean_std.eval()

    def set_train(self):
        """Put the model (and the advantage normalizer, if any) in train mode."""
        self.model.train()
        if self.advantage_mean_std is not None:
            self.advantage_mean_std.train()

    def update_lr(self, lr):
        """Set ``lr`` on every optimizer param group; with multi-GPU, rank 0's value is broadcast first."""
        if self.multi_gpu:
            lr_tensor = torch.tensor([lr])
            self.hvd.broadcast_value(lr_tensor, 'learning_rate')
            lr = lr_tensor.item()

        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

        # NOTE: the central value network's learning rate is not updated here.

    def get_action_values(self, obs):
        """Run the policy without gradients on ``obs``.

        Returns the model's result dict. When a central value network is
        configured, its ``'values'`` entry is replaced with the central value
        computed from ``obs['states']``.
        """
        processed_obs = self._preproc_obs(obs['obs'])
        self.model.eval()
        input_dict = {
            'is_train': False,
            'prev_actions': None,
            'obs' : processed_obs,
            'rnn_states' : self.rnn_states
        }

        with torch.no_grad():
            res_dict = self.model(input_dict)
            if self.has_central_value:
                states = obs['states']
                input_dict = {
                    'is_train': False,
                    'states' : states,
                }
                value = self.get_central_value(input_dict)
                res_dict['values'] = value
        return res_dict

    def get_values(self, obs):
        """Return value estimates for ``obs``, from the central value net if configured, else the policy model."""
        with torch.no_grad():
            if self.has_central_value:
                states = obs['states']
                self.central_value_net.eval()
                input_dict = {
                    'is_train': False,
                    'states' : states,
                    'actions' : None,
                    'is_done': self.dones,
                }
                value = self.get_central_value(input_dict)
            else:
                self.model.eval()
                processed_obs = self._preproc_obs(obs['obs'])
                input_dict = {
                    'is_train': False,
                    'prev_actions': None,
                    'obs' : processed_obs,
                    'rnn_states' : self.rnn_states
                }
                result = self.model(input_dict)
                value = result['values']
            return value

    @property
    def device(self):
        """Device used for training tensors (``config['device']``, default ``'cuda:0'``)."""
        return self.ppo_device

    def reset_envs(self):
        """Reset all environments and store the converted observations in ``self.obs``."""
        self.obs = self.env_reset()

    def init_tensors(self):
        """Allocate the experience buffer and per-actor running statistics.

        ``current_rewards`` is (num_agents * num_actors, value_size);
        ``current_lengths`` and ``dones`` are (num_agents * num_actors,), with
        ``dones`` starting as all ones. For RNN models the default RNN state is
        moved to the device and a per-sequence buffer ``mb_rnn_states`` is
        allocated (one slot every ``seq_len`` steps).
        """
        batch_size = self.num_agents * self.num_actors
        algo_info = {
            'num_actors' : self.num_actors,
            'horizon_length' : self.horizon_length,
            'has_central_value' : self.has_central_value,
            'use_action_masks' : self.use_action_masks
        }
        self.experience_buffer = ExperienceBuffer(self.env_info, algo_info, self.ppo_device)

        current_rewards_shape = (batch_size, self.value_size)
        self.current_rewards = torch.zeros(current_rewards_shape, dtype=torch.float32, device=self.ppo_device)
        self.current_lengths = torch.zeros(batch_size, dtype=torch.float32, device=self.ppo_device)
        self.dones = torch.ones((batch_size,), dtype=torch.uint8, device=self.ppo_device)

        if self.is_rnn:
            self.rnn_states = self.model.get_default_rnn_state()
            self.rnn_states = [s.to(self.ppo_device) for s in self.rnn_states]

            total_agents = self.num_agents * self.num_actors
            num_seqs = self.horizon_length // self.seq_len
            assert((self.horizon_length * total_agents // self.num_minibatches) % self.seq_len == 0)
            self.mb_rnn_states = [torch.zeros((num_seqs, s.size()[0], total_agents, s.size()[2]), dtype = torch.float32, device=self.ppo_device) for s in self.rnn_states]

    def init_rnn_from_model(self, model):
        """Set ``self.is_rnn`` from the given model's ``is_rnn()``."""
        self.is_rnn = model.is_rnn()

    def cast_obs(self, obs):
        """Convert ``obs`` to a tensor on the training device.

        Tensors are returned unchanged (and mark the env as producing tensor
        observations). NumPy arrays become uint8 tensors when the observation
        space dtype is uint8, otherwise float32 tensors. Other types are
        returned unchanged.
        """
        if isinstance(obs, torch.Tensor):
            self.is_tensor_obses = True
        elif isinstance(obs, np.ndarray):
            assert(self.observation_space.dtype != np.int8)
            if self.observation_space.dtype == np.uint8:
                obs = torch.ByteTensor(obs).to(self.ppo_device)
            else:
                obs = torch.FloatTensor(obs).to(self.ppo_device)
        return obs

    def obs_to_tensors(self, obs):
        """Convert an env observation (possibly a nested dict) into tensors.

        The result is always a dict: unless ``obs`` is already a dict with an
        ``'obs'`` key, the converted value is wrapped as ``{'obs': ...}``.
        """
        obs_is_dict = isinstance(obs, dict)
        if obs_is_dict:
            upd_obs = {}
            for key, value in obs.items():
                upd_obs[key] = self._obs_to_tensors_internal(value)
        else:
            upd_obs = self.cast_obs(obs)
        if not obs_is_dict or 'obs' not in obs:
            upd_obs = {'obs' : upd_obs}
        return upd_obs

    def _obs_to_tensors_internal(self, obs):
        """Recursively apply ``cast_obs`` to every leaf of a (possibly nested) dict."""
        if isinstance(obs, dict):
            upd_obs = {}
            for key, value in obs.items():
                upd_obs[key] = self._obs_to_tensors_internal(value)
        else:
            upd_obs = self.cast_obs(obs)
        return upd_obs

    def preprocess_actions(self, actions):
        """Move actions to a CPU NumPy array unless the environment works with tensors."""
        if not self.is_tensor_obses:
            actions = actions.cpu().numpy()
        return actions

    def env_step(self, actions):
        """Step the vectorized env.

        Returns ``(obs_dict, rewards, dones, infos)`` with rewards and dones as
        tensors on the training device. A trailing axis is added to rewards
        when ``value_size == 1``. On the NumPy path rewards are cast to float.
        """
        actions = self.preprocess_actions(actions)
        obs, rewards, dones, infos = self.vec_env.step(actions)

        if self.is_tensor_obses:
            if self.value_size == 1:
                rewards = rewards.unsqueeze(1)
            return self.obs_to_tensors(obs), rewards.to(self.ppo_device), dones.to(self.ppo_device), infos
        else:
            if self.value_size == 1:
                rewards = np.expand_dims(rewards, axis=1)
            return self.obs_to_tensors(obs), torch.from_numpy(rewards).to(self.ppo_device).float(), torch.from_numpy(dones).to(self.ppo_device), infos

    def env_reset(self):
        """Reset the vectorized env and return the observations converted by ``obs_to_tensors``."""
        obs = self.vec_env.reset()
        obs = self.obs_to_tensors(obs)
        return obs

    def discount_values(self, fdones, last_extrinsic_values, mb_fdones, mb_extrinsic_values, mb_rewards):
        """Compute GAE advantages over the rollout horizon.

        Walking backwards in time::

            delta_t = r_t + gamma * V_{t+1} * (1 - d_{t+1}) - V_t
            A_t     = delta_t + gamma * tau * (1 - d_{t+1}) * A_{t+1}

        For the last step, ``fdones`` / ``last_extrinsic_values`` provide
        d_{t+1} / V_{t+1}. ``mb_fdones[t]`` is the done flag stored before
        acting at step t (see ``play_steps``). The done flags get a trailing
        axis added so they broadcast against per-step rewards/values —
        presumably shaped (num_envs, value_size); TODO confirm.
        """
        lastgaelam = 0
        mb_advs = torch.zeros_like(mb_rewards)

        for t in reversed(range(self.horizon_length)):
            if t == self.horizon_length - 1:
                nextnonterminal = 1.0 - fdones
                nextvalues = last_extrinsic_values
            else:
                nextnonterminal = 1.0 - mb_fdones[t+1]
                nextvalues = mb_extrinsic_values[t+1]
            nextnonterminal = nextnonterminal.unsqueeze(1)

            delta = mb_rewards[t] + self.gamma * nextvalues * nextnonterminal - mb_extrinsic_values[t]
            mb_advs[t] = lastgaelam = delta + self.gamma * self.tau * nextnonterminal * lastgaelam
        return mb_advs

    def discount_values_masks(self, fdones, last_extrinsic_values, mb_fdones, mb_extrinsic_values, mb_rewards, mb_masks):
        """Same GAE recursion as ``discount_values``, with each step's result multiplied by ``mb_masks[t]``.

        Because the masked value feeds the recursion, a zero mask also cuts the
        bootstrap into earlier steps.
        """
        lastgaelam = 0
        mb_advs = torch.zeros_like(mb_rewards)
        for t in reversed(range(self.horizon_length)):
            if t == self.horizon_length - 1:
                nextnonterminal = 1.0 - fdones
                nextvalues = last_extrinsic_values
            else:
                nextnonterminal = 1.0 - mb_fdones[t+1]
                nextvalues = mb_extrinsic_values[t+1]
            nextnonterminal = nextnonterminal.unsqueeze(1)
            masks_t = mb_masks[t].unsqueeze(1)
            delta = (mb_rewards[t] + self.gamma * nextvalues * nextnonterminal - mb_extrinsic_values[t])
            mb_advs[t] = lastgaelam = (delta + self.gamma * self.tau * nextnonterminal * lastgaelam) * masks_t
        return mb_advs

    def clear_stats(self):
        """Reset episode reward/length trackers and the best-reward bookkeeping."""
        self.game_rewards.clear()
        self.game_lengths.clear()
        self.mean_rewards = self.last_mean_rewards = -100500
        self.algo_observer.after_clear_stats()

    def update_epoch(self):
        """Hook for subclasses; no-op here."""
        pass

    def train(self):
        """Hook for subclasses; no-op here."""
        pass

    def prepare_dataset(self, batch_dict):
        """Hook for subclasses; no-op here."""
        pass

    def train_epoch(self):
        """Forward the current frame count (and this agent) to the env via ``set_train_info``."""
        self.vec_env.set_train_info(self.frame, self)

    def train_actor_critic(self, obs_dict, opt_step=True):
        """Hook for subclasses; no-op here."""
        pass

    def calc_gradients(self):
        """Hook for subclasses; no-op here."""
        pass

    def get_central_value(self, obs_dict):
        """Delegate value estimation to the central value network."""
        return self.central_value_net.get_value(obs_dict)

    def train_central_value(self):
        """Delegate one training pass to the central value network and return its result."""
        return self.central_value_net.train_net()

    def get_full_state_weights(self):
        """Return a checkpoint dict: weights, epoch, optimizer, frame, best reward and env state."""
        state = self.get_weights()
        state['epoch'] = self.epoch_num
        state['optimizer'] = self.optimizer.state_dict()
        if self.has_central_value:
            state['assymetric_vf_nets'] = self.central_value_net.state_dict()
        state['frame'] = self.frame

        # This is actually the best reward ever achieved. last_mean_rewards is perhaps not the best variable name
        # We save it to the checkpoint to prevent overriding the "best ever" checkpoint upon experiment restart
        state['last_mean_rewards'] = self.last_mean_rewards

        if self.vec_env is not None:
            env_state = self.vec_env.get_env_state()
            state['env_state'] = env_state

        return state

    def set_full_state_weights(self, weights):
        """Restore everything saved by ``get_full_state_weights``.

        Missing ``frame`` / ``last_mean_rewards`` fall back to defaults. The env
        state (or None if absent) is passed to the vec env when one exists.
        """
        self.set_weights(weights)
        self.epoch_num = weights['epoch']
        if self.has_central_value:
            self.central_value_net.load_state_dict(weights['assymetric_vf_nets'])
        self.optimizer.load_state_dict(weights['optimizer'])
        self.frame = weights.get('frame', 0)
        self.last_mean_rewards = weights.get('last_mean_rewards', -100500)

        env_state = weights.get('env_state', None)

        if self.vec_env is not None:
            self.vec_env.set_env_state(env_state)

    def get_weights(self):
        """Return the model state dict plus the normalisation/scaler statistics."""
        state = self.get_stats_weights()
        state['model'] = self.model.state_dict()
        return state

    def get_stats_weights(self, model_stats=False):
        """Collect non-parameter statistics.

        Always includes the AMP scaler (if mixed precision), central value
        stats (if configured) and the RMS advantage normalizer (if created).
        With ``model_stats=True`` it also includes the input and value
        normalizers of the model.
        """
        state = {}
        if self.mixed_precision:
            state['scaler'] = self.scaler.state_dict()
        if self.has_central_value:
            state['central_val_stats'] = self.central_value_net.get_stats_weights(model_stats)
        if model_stats:
            if self.normalize_input:
                state['running_mean_std'] = self.model.running_mean_std.state_dict()
            if self.normalize_value:
                state['reward_mean_std'] = self.model.value_mean_std.state_dict()
        # Saved so that set_stats_weights can restore it; previously it was loaded but never saved.
        if self.advantage_mean_std is not None:
            state['advantage_mean_std'] = self.advantage_mean_std.state_dict()

        return state

    def set_stats_weights(self, weights):
        """Restore the statistics present in ``weights``; missing entries are skipped."""
        if self.advantage_mean_std is not None and 'advantage_mean_std' in weights:
            self.advantage_mean_std.load_state_dict(weights['advantage_mean_std'])
        if self.normalize_input and 'running_mean_std' in weights:
            self.model.running_mean_std.load_state_dict(weights['running_mean_std'])
        # The value normalizer is stored under 'reward_mean_std' (see get_stats_weights).
        if self.normalize_value and 'reward_mean_std' in weights:
            self.model.value_mean_std.load_state_dict(weights['reward_mean_std'])
        if self.mixed_precision and 'scaler' in weights:
            self.scaler.load_state_dict(weights['scaler'])

    def set_weights(self, weights):
        """Load model parameters from ``weights['model']``, then the statistics."""
        self.model.load_state_dict(weights['model'])
        self.set_stats_weights(weights)

    def _preproc_obs(self, obs_batch):
        """Scale uint8 observations to floats in [0, 1].

        For dict observations a NEW dict is returned, so the caller's dict
        (e.g. ``self.obs['obs']``, which ``play_steps`` later stores in the
        experience buffer) is not modified in place.
        """
        if isinstance(obs_batch, dict):
            return {
                k: (v.float() / 255. if v.dtype == torch.uint8 else v)
                for k, v in obs_batch.items()
            }
        if obs_batch.dtype == torch.uint8:
            obs_batch = obs_batch.float() / 255.0
        return obs_batch

    def play_steps(self):
        """Collect ``horizon_length`` environment steps with a feed-forward policy.

        Fills the experience buffer, tracks finished-episode rewards/lengths,
        computes GAE returns and returns the flattened batch dict (plus
        ``returns``, ``played_frames`` and ``step_time``).
        """
        update_list = self.update_list

        step_time = 0.0

        for n in range(self.horizon_length):
            if self.use_action_masks:
                masks = self.vec_env.get_action_masks()
                res_dict = self.get_masked_action_values(self.obs, masks)
            else:
                res_dict = self.get_action_values(self.obs)
            self.experience_buffer.update_data('obses', n, self.obs['obs'])
            self.experience_buffer.update_data('dones', n, self.dones)

            for k in update_list:
                self.experience_buffer.update_data(k, n, res_dict[k])
            if self.has_central_value:
                self.experience_buffer.update_data('states', n, self.obs['states'])

            step_time_start = time.time()
            self.obs, rewards, self.dones, infos = self.env_step(res_dict['actions'])
            step_time_end = time.time()

            step_time += (step_time_end - step_time_start)

            shaped_rewards = self.rewards_shaper(rewards)
            if self.value_bootstrap and 'time_outs' in infos:
                # Bootstrap from the current value estimate on time-limit terminations.
                shaped_rewards += self.gamma * res_dict['values'] * self.cast_obs(infos['time_outs']).unsqueeze(1).float()

            self.experience_buffer.update_data('rewards', n, shaped_rewards)

            self.current_rewards += rewards
            self.current_lengths += 1
            env_done_indices = self.dones.view(self.num_actors, self.num_agents).all(dim=1).nonzero(as_tuple=False)

            self.game_rewards.update(self.current_rewards[env_done_indices])
            self.game_lengths.update(self.current_lengths[env_done_indices])
            self.algo_observer.process_infos(infos, env_done_indices)

            # Zero the running totals of actors whose episode just ended.
            not_dones = 1.0 - self.dones.float()

            self.current_rewards = self.current_rewards * not_dones.unsqueeze(1)
            self.current_lengths = self.current_lengths * not_dones

        last_values = self.get_values(self.obs)

        fdones = self.dones.float()
        mb_fdones = self.experience_buffer.tensor_dict['dones'].float()
        mb_values = self.experience_buffer.tensor_dict['values']
        mb_rewards = self.experience_buffer.tensor_dict['rewards']
        mb_advs = self.discount_values(fdones, last_values, mb_fdones, mb_values, mb_rewards)
        mb_returns = mb_advs + mb_values

        batch_dict = self.experience_buffer.get_transformed_list(swap_and_flatten01, self.tensor_list)
        batch_dict['returns'] = swap_and_flatten01(mb_returns)
        batch_dict['played_frames'] = self.batch_size
        batch_dict['step_time'] = step_time

        return batch_dict

    def play_steps_rnn(self):
        """RNN variant of ``play_steps``.

        Snapshots the RNN state every ``seq_len`` steps into
        ``mb_rnn_states``, zeroes the hidden state of agents whose episode
        ended, and adds the reshaped per-sequence states to the batch dict
        under ``'rnn_states'``.
        """
        update_list = self.update_list
        mb_rnn_states = self.mb_rnn_states
        step_time = 0.0

        for n in range(self.horizon_length):
            if n % self.seq_len == 0:
                for s, mb_s in zip(self.rnn_states, mb_rnn_states):
                    mb_s[n // self.seq_len,:,:,:] = s

            if self.has_central_value:
                self.central_value_net.pre_step_rnn(n)

            if self.use_action_masks:
                masks = self.vec_env.get_action_masks()
                res_dict = self.get_masked_action_values(self.obs, masks)
            else:
                res_dict = self.get_action_values(self.obs)

            self.rnn_states = res_dict['rnn_states']
            self.experience_buffer.update_data('obses', n, self.obs['obs'])
            self.experience_buffer.update_data('dones', n, self.dones.byte())

            for k in update_list:
                self.experience_buffer.update_data(k, n, res_dict[k])
            if self.has_central_value:
                self.experience_buffer.update_data('states', n, self.obs['states'])

            step_time_start = time.time()
            self.obs, rewards, self.dones, infos = self.env_step(res_dict['actions'])
            step_time_end = time.time()

            step_time += (step_time_end - step_time_start)

            shaped_rewards = self.rewards_shaper(rewards)

            if self.value_bootstrap and 'time_outs' in infos:
                shaped_rewards += self.gamma * res_dict['values'] * self.cast_obs(infos['time_outs']).unsqueeze(1).float()

            self.experience_buffer.update_data('rewards', n, shaped_rewards)

            self.current_rewards += rewards
            self.current_lengths += 1
            all_done_indices = self.dones.nonzero(as_tuple=False)
            env_done_indices = self.dones.view(self.num_actors, self.num_agents).all(dim=1).nonzero(as_tuple=False)
            if len(all_done_indices) > 0:
                # Reset the hidden state of every agent whose episode just ended.
                for s in self.rnn_states:
                    s[:, all_done_indices, :] = s[:, all_done_indices, :] * 0.0
                if self.has_central_value:
                    self.central_value_net.post_step_rnn(all_done_indices)

            self.game_rewards.update(self.current_rewards[env_done_indices])
            self.game_lengths.update(self.current_lengths[env_done_indices])
            self.algo_observer.process_infos(infos, env_done_indices)

            not_dones = 1.0 - self.dones.float()

            self.current_rewards = self.current_rewards * not_dones.unsqueeze(1)
            self.current_lengths = self.current_lengths * not_dones

        last_values = self.get_values(self.obs)

        fdones = self.dones.float()
        mb_fdones = self.experience_buffer.tensor_dict['dones'].float()

        mb_values = self.experience_buffer.tensor_dict['values']
        mb_rewards = self.experience_buffer.tensor_dict['rewards']
        mb_advs = self.discount_values(fdones, last_values, mb_fdones, mb_values, mb_rewards)
        mb_returns = mb_advs + mb_values
        batch_dict = self.experience_buffer.get_transformed_list(swap_and_flatten01, self.tensor_list)
        batch_dict['returns'] = swap_and_flatten01(mb_returns)
        batch_dict['played_frames'] = self.batch_size
        states = []
        for mb_s in mb_rnn_states:
            t_size = mb_s.size()[0] * mb_s.size()[2]
            h_size = mb_s.size()[3]
            states.append(mb_s.permute(1,2,0,3).reshape(-1,t_size, h_size))
        batch_dict['rnn_states'] = states
        batch_dict['step_time'] = step_time
        return batch_dict
class DiscreteA2CBase(A2CBase):
def __init__(self, base_name, params):
    """Initialise shared A2C state and derive discrete action metadata.

    Sets ``actions_shape`` (rollout-buffer shape for actions), ``actions_num``
    and ``is_multi_discrete`` from ``env_info['action_space']``:

    * ``gym.spaces.Discrete`` -> shape (horizon, batch), ``actions_num = n``.
    * ``gym.spaces.Tuple`` of discrete spaces -> shape (horizon, batch,
      len(space)), ``actions_num`` is the list of per-component ``n``.

    Raises:
        NotImplementedError: for any other action space. Previously such
            spaces left these attributes undefined and failed later.
    """
    A2CBase.__init__(self, base_name, params)
    batch_size = self.num_agents * self.num_actors
    action_space = self.env_info['action_space']
    if isinstance(action_space, gym.spaces.Discrete):
        self.actions_shape = (self.horizon_length, batch_size)
        self.actions_num = action_space.n
        self.is_multi_discrete = False
    elif isinstance(action_space, gym.spaces.Tuple):
        self.actions_shape = (self.horizon_length, batch_size, len(action_space))
        self.actions_num = [action.n for action in action_space]
        self.is_multi_discrete = True
    else:
        raise NotImplementedError(
            f'DiscreteA2CBase does not support action space {action_space!r}')
    self.is_discrete = True
def init_tensors(self):
    """Allocate the base rollout buffers, then record which tensors are tracked.

    ``update_list`` names the per-step policy outputs copied from the model's
    result dict into the experience buffer (action masks included only when
    enabled). ``tensor_list`` is that list plus ``obses``, ``states`` and
    ``dones``, i.e. everything flattened into the training batch.
    """
    A2CBase.init_tensors(self)
    tracked = ['actions', 'neglogpacs', 'values']
    if self.use_action_masks:
        tracked = tracked + ['action_masks']
    self.update_list = tracked
    self.tensor_list = [*tracked, 'obses', 'states', 'dones']
def train_epoch(self):
    """Run one epoch: collect a rollout, then do ``mini_epochs_num`` passes over it.

    Steps: forward train info to the env (base class), roll out with the model
    in eval mode and without gradients, build the dataset, optionally train
    the central value net, then call ``train_actor_critic`` on every minibatch
    of ``self.dataset`` (``self.dataset`` is not assigned in the visible code;
    presumably set by a subclass — verify). After each mini-epoch the mean KL
    (averaged across workers with multi-GPU) drives the LR/entropy scheduler.

    Returns:
        tuple: (step_time, play_time, update_time, total_time, a_losses,
        c_losses, entropies, kls, last_lr, lr_mul). The loss/entropy lists hold
        one entry per minibatch; ``kls`` holds one averaged entry per
        mini-epoch; ``last_lr`` / ``lr_mul`` come from the final
        ``train_actor_critic`` call (UnboundLocalError if no minibatch ran).
    """
    super().train_epoch()

    self.set_eval()
    play_time_start = time.time()
    with torch.no_grad():
        if self.is_rnn:
            batch_dict = self.play_steps_rnn()
        else:
            batch_dict = self.play_steps()

    self.set_train()

    play_time_end = time.time()
    update_time_start = time.time()
    rnn_masks = batch_dict.get('rnn_masks', None)  # not used further in this method
    self.curr_frames = batch_dict.pop('played_frames')
    self.prepare_dataset(batch_dict)
    self.algo_observer.after_steps()

    a_losses = []
    c_losses = []
    entropies = []
    kls = []
    if self.has_central_value:
        self.train_central_value()

    for mini_ep in range(0, self.mini_epochs_num):
        ep_kls = []
        for i in range(len(self.dataset)):
            a_loss, c_loss, entropy, kl, last_lr, lr_mul = self.train_actor_critic(self.dataset[i])
            a_losses.append(a_loss)
            c_losses.append(c_loss)
            ep_kls.append(kl)
            entropies.append(entropy)

        av_kls = torch_ext.mean_list(ep_kls)
        if self.multi_gpu:
            # Average the mini-epoch KL across workers so all ranks schedule identically.
            av_kls = self.hvd.average_value(av_kls, 'ep_kls')
        self.last_lr, self.entropy_coef = self.scheduler.update(self.last_lr, self.entropy_coef, self.epoch_num, 0, av_kls.item())
        self.update_lr(self.last_lr)
        kls.append(av_kls)
        self.diagnostics.mini_epoch(self, mini_ep)
        if self.normalize_input:
            self.model.running_mean_std.eval() # don't need to update statistics more than one miniepoch

    update_time_end = time.time()
    play_time = play_time_end - play_time_start
    update_time = update_time_end - update_time_start
    total_time = update_time_end - play_time_start

    return batch_dict['step_time'], play_time, update_time, total_time, a_losses, c_losses, entropies, kls, last_lr, lr_mul
def prepare_dataset(self, batch_dict):
    """Convert a collected rollout into the PPO training dataset.

    Advantages are computed as ``returns - values`` *before* any value
    normalization, summed over axis 1, and then optionally normalized. The
    normalization uses the running ``advantage_mean_std``, or plain
    per-batch standardization; both are masked with ``rnn_masks`` for
    recurrent models.

    The result is pushed into ``self.dataset``. When a central value
    network is configured, a second dict whose ``obs`` is
    ``batch_dict['states']`` goes to ``self.central_value_net``.

    Args:
        batch_dict: rollout dict with ``returns``, ``values``, ``actions``,
            ``neglogpacs``, ``dones`` and ``obses``. It may also carry
            ``rnn_masks``, ``rnn_states``, ``action_masks`` (read when
            action masks are enabled) and ``states`` (read with a central
            value net).
    """
    returns, values, actions, neglogpacs, dones = (
        batch_dict[key] for key in ('returns', 'values', 'actions', 'neglogpacs', 'dones')
    )
    rnn_masks = batch_dict.get('rnn_masks')
    rnn_states = batch_dict.get('rnn_states')

    # Advantages are taken in the raw value space, before normalization.
    raw_advantages = returns - values
    obses = batch_dict['obses']

    if self.normalize_value:
        normalizer = self.value_mean_std
        normalizer.train()
        values, returns = normalizer(values), normalizer(returns)
        normalizer.eval()

    advantages = torch.sum(raw_advantages, axis=1)

    if self.normalize_advantage:
        if self.normalize_rms_advantage:
            rms_kwargs = {'mask': rnn_masks} if self.is_rnn else {}
            advantages = self.advantage_mean_std(advantages, **rms_kwargs)
        elif self.is_rnn:
            advantages = torch_ext.normalization_with_masks(advantages, rnn_masks)
        else:
            centered = advantages - advantages.mean()
            advantages = centered / (advantages.std() + 1e-8)

    dataset_dict = {
        'old_values': values,
        'old_logp_actions': neglogpacs,
        'advantages': advantages,
        'returns': returns,
        'actions': actions,
        'obs': obses,
        'dones': dones,
        'rnn_states': rnn_states,
        'rnn_masks': rnn_masks,
    }
    if self.use_action_masks:
        dataset_dict['action_masks'] = batch_dict['action_masks']
    self.dataset.update_values_dict(dataset_dict)

    if self.has_central_value:
        self.central_value_net.update_dataset({
            'old_values': values,
            'advantages': advantages,
            'returns': returns,
            'actions': actions,
            'dones': dones,
            'obs': batch_dict['states'],
            'rnn_masks': rnn_masks,
        })
def train(self):
    """Run the main training loop until a stop condition is reached.

    Each iteration runs ``train_epoch`` and accumulates frame and wall-clock
    counters. On rank 0 only, it then logs statistics and writes reward and
    episode-length scalars. Checkpoints are saved as follows:

    * every ``save_freq`` epochs when the first reward component did not
      improve, as ``'last_' + checkpoint_name``;
    * whenever the first reward component reaches a new best (once
      ``epoch_num >= save_best_after``), under ``config['name']``;
    * when that best exceeds ``config['score_to_win']`` (if the key is
      configured); this also stops training;
    * when ``epoch_num`` exceeds ``max_epochs``; this also stops training.

    With multi-GPU the rank-0 stop decision is broadcast to all workers.

    Returns:
        tuple: ``(self.last_mean_rewards, epoch_num)`` at exit.
    """
    self.init_tensors()
    self.mean_rewards = self.last_mean_rewards = -100500
    total_time = 0
    # self.frame is deliberately not reset here (it may come from a loaded checkpoint).
    # checkpoint_name is only set once an episode has finished; None until then.
    checkpoint_name = None
    self.obs = self.env_reset()

    if self.multi_gpu:
        self.hvd.setup_algo(self)

    while True:
        epoch_num = self.update_epoch()
        (step_time, play_time, update_time, sum_time, a_losses, c_losses,
         entropies, kls, last_lr, lr_mul) = self.train_epoch()

        if self.multi_gpu:
            self.hvd.sync_stats(self)

        # cleaning memory to optimize space
        self.dataset.update_values_dict(None)
        total_time += sum_time
        curr_frames = self.curr_frames
        self.frame += curr_frames
        should_exit = False

        if self.rank == 0:
            self.diagnostics.epoch(self, current_epoch=epoch_num)
            scaled_time = self.num_agents * sum_time
            scaled_play_time = self.num_agents * play_time
            frame = self.frame // self.num_agents

            if self.print_stats:
                fps_step = curr_frames / step_time
                fps_step_inference = curr_frames / scaled_play_time
                fps_total = curr_frames / scaled_time
                print(f'fps step: {fps_step:.1f} fps step and policy inference: {fps_step_inference:.1f} fps total: {fps_total:.1f} epoch: {epoch_num}/{self.max_epochs}')

            self.write_stats(total_time, epoch_num, step_time, play_time, update_time, a_losses, c_losses, entropies, kls, last_lr, lr_mul, frame, scaled_time, scaled_play_time, curr_frames)
            self.algo_observer.after_print_stats(frame, epoch_num, total_time)

            if self.game_rewards.current_size > 0:
                mean_rewards = self.game_rewards.get_mean()
                mean_lengths = self.game_lengths.get_mean()
                self.mean_rewards = mean_rewards[0]

                for i in range(self.value_size):
                    rewards_name = 'rewards' if i == 0 else 'rewards{0}'.format(i)
                    self.writer.add_scalar(rewards_name + '/step', mean_rewards[i], frame)
                    self.writer.add_scalar(rewards_name + '/iter', mean_rewards[i], epoch_num)
                    self.writer.add_scalar(rewards_name + '/time', mean_rewards[i], total_time)

                self.writer.add_scalar('episode_lengths/step', mean_lengths, frame)
                self.writer.add_scalar('episode_lengths/iter', mean_lengths, epoch_num)
                self.writer.add_scalar('episode_lengths/time', mean_lengths, total_time)

                if self.has_self_play_config:
                    self.self_play_manager.update(self)

                # removed equal signs (i.e. "rew=") from the checkpoint name since it messes with hydra CLI parsing
                checkpoint_name = self.config['name'] + '_ep_' + str(epoch_num) + '_rew_' + str(mean_rewards[0])

                if self.save_freq > 0:
                    # Compare only the first reward component. mean_rewards is
                    # indexed per value head, and comparing the whole array is
                    # ambiguous (ValueError) when value_size > 1.
                    if (epoch_num % self.save_freq == 0) and (mean_rewards[0] <= self.last_mean_rewards):
                        self.save(os.path.join(self.nn_dir, 'last_' + checkpoint_name))

                if mean_rewards[0] > self.last_mean_rewards and epoch_num >= self.save_best_after:
                    print('saving next best rewards: ', mean_rewards)
                    self.last_mean_rewards = mean_rewards[0]
                    self.save(os.path.join(self.nn_dir, self.config['name']))
                    # score_to_win is optional; without it training only stops on max_epochs.
                    if 'score_to_win' in self.config and self.last_mean_rewards > self.config['score_to_win']:
                        print('Network won!')
                        self.save(os.path.join(self.nn_dir, checkpoint_name))
                        should_exit = True

            if epoch_num > self.max_epochs:
                if checkpoint_name is None:
                    # No episode has finished yet, so there is no reward-based
                    # name; fall back to the current (placeholder) mean reward.
                    checkpoint_name = self.config['name'] + '_ep_' + str(epoch_num) + '_rew_' + str(self.mean_rewards)
                self.save(os.path.join(self.nn_dir, 'last_' + checkpoint_name))
                print('MAX EPOCHS NUM!')
                should_exit = True

        if self.multi_gpu:
            # Rank 0 decides; every worker must agree to leave the loop together.
            should_exit_t = torch.tensor(should_exit).float()
            self.hvd.broadcast_value(should_exit_t, 'should_exit')
            should_exit = should_exit_t.bool().item()

        if should_exit:
            return self.last_mean_rewards, epoch_num
class ContinuousA2CBase(A2CBase):
def __init__(self, base_name, params):
    """Set up the shared A2C state for a continuous action space.

    The action dimensionality comes from ``env_info['action_space'].shape[0]``.
    The config supplies ``bounds_loss_coef`` (default ``None``) and
    ``clip_actions`` (default ``True``). The space's ``low``/``high`` bounds
    are stored as float tensors on ``self.ppo_device`` for
    ``preprocess_actions``.
    """
    A2CBase.__init__(self, base_name, params)
    self.is_discrete = False

    space = self.env_info['action_space']
    self.actions_num = space.shape[0]

    cfg = self.config
    self.bounds_loss_coef = cfg.get('bounds_loss_coef', None)
    self.clip_actions = cfg.get('clip_actions', True)

    def _bound_tensor(arr):
        # copy() so the tensor does not share memory with the space's array
        return torch.from_numpy(arr.copy()).float().to(self.ppo_device)

    self.actions_low = _bound_tensor(space.low)
    self.actions_high = _bound_tensor(space.high)
def preprocess_actions(self, actions):
    """Convert raw policy actions into what is sent to the environment.

    When ``clip_actions`` is set, the actions are clamped to [-1, 1] and
    passed to ``rescale_actions`` together with ``actions_low`` /
    ``actions_high``. Otherwise they are used unchanged. The result stays a
    tensor for tensor-observation environments and is otherwise moved to
    the CPU as a numpy array.
    """
    if not self.clip_actions:
        env_actions = actions
    else:
        bounded = torch.clamp(actions, -1.0, 1.0)
        env_actions = rescale_actions(self.actions_low, self.actions_high, bounded)
    return env_actions if self.is_tensor_obses else env_actions.cpu().numpy()
def init_tensors(self):
    """Allocate the base rollout tensors and register continuous-PPO keys.

    Besides the common action/log-prob/value keys, the continuous policy
    also tracks ``mus`` and ``sigmas``. ``tensor_list`` is a new list of
    those keys followed by the observation/state/done keys.
    """
    A2CBase.init_tensors(self)
    tracked = ['actions', 'neglogpacs', 'values', 'mus', 'sigmas']
    self.update_list = tracked
    self.tensor_list = tracked + ['obses', 'states', 'dones']
def train_epoch(self):