Implement distributed runs for multi-agents
Toni-SM committed Jun 24, 2024
1 parent c739b75 commit c6301fb
Showing 2 changed files with 30 additions and 2 deletions.
skrl/multi_agents/torch/ippo/ippo.py (16 changes: 15 additions & 1 deletion)
@@ -9,6 +9,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
+from skrl import config, logger
 from skrl.memories.torch import Memory
 from skrl.models.torch import Model
 from skrl.multi_agents.torch import MultiAgent
@@ -109,9 +110,18 @@ def __init__(self,
         self.values = {uid: self.models[uid].get("value", None) for uid in self.possible_agents}
 
         for uid in self.possible_agents:
+            # checkpoint models
             self.checkpoint_modules[uid]["policy"] = self.policies[uid]
             self.checkpoint_modules[uid]["value"] = self.values[uid]
 
+            # broadcast models' parameters in distributed runs
+            if config.torch.is_distributed:
+                logger.info(f"Broadcasting models' parameters")
+                if self.policies[uid] is not None:
+                    self.policies[uid].broadcast_parameters()
+                if self.values[uid] is not None and self.policies[uid] is not self.values[uid]:
+                    self.values[uid].broadcast_parameters()
+
         # configuration
         self._learning_epochs = self._as_dict(self.cfg["learning_epochs"])
         self._mini_batches = self._as_dict(self.cfg["mini_batches"])
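A note for readers: broadcast_parameters() is not defined in this diff; it comes from skrl's model base class. The intent is that, at construction time, every worker overwrites its randomly initialized weights with the source rank's, so all processes start from identical models. A minimal sketch of how such a broadcast can be implemented with torch.distributed (the helper below is illustrative, not skrl's actual code):

    import torch
    import torch.distributed as dist

    def broadcast_parameters(model: torch.nn.Module, src_rank: int = 0) -> None:
        # sketch (not skrl's implementation): pack the state dict into a
        # one-element list so it can travel through a single collective call
        state = [model.state_dict()]
        # the source rank sends; every other rank receives and overwrites its copy
        dist.broadcast_object_list(state, src=src_rank)
        model.load_state_dict(state[0])

Note the `self.policies[uid] is not self.values[uid]` guard in the hunk above: when policy and value share one network, broadcasting it a second time would be redundant.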
@@ -437,6 +447,10 @@ def compute_gae(rewards: torch.Tensor,
                     # optimization step
                     self.optimizers[uid].zero_grad()
                     (policy_loss + entropy_loss + value_loss).backward()
+                    if config.torch.is_distributed:
+                        policy.reduce_parameters()
+                        if policy is not value:
+                            value.reduce_parameters()
                     if self._grad_norm_clip[uid] > 0:
                         if policy is value:
                             nn.utils.clip_grad_norm_(policy.parameters(), self._grad_norm_clip[uid])
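Like the broadcast, reduce_parameters() is provided by the model class rather than defined here. It runs after backward() and before gradient clipping and the optimizer step, so that every process applies the same averaged update. A sketch under that assumption (flattening all gradients into one buffer is an implementation choice, not necessarily skrl's):

    import torch
    import torch.distributed as dist

    def reduce_parameters(model: torch.nn.Module) -> None:
        # sketch: average .grad across all processes
        grads = [p.grad for p in model.parameters() if p.grad is not None]
        # flatten into one contiguous buffer so a single all-reduce suffices
        flat = torch.cat([g.view(-1) for g in grads])
        dist.all_reduce(flat, op=dist.ReduceOp.SUM)  # sum over workers...
        flat /= dist.get_world_size()                # ...then divide to get the mean
        # copy the averaged values back into each parameter's gradient
        offset = 0
        for g in grads:
            g.copy_(flat[offset:offset + g.numel()].view_as(g))
            offset += g.numel()

Again, the `policy is not value` check prevents a shared network's gradients from being reduced twice, which would double-count them.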
@@ -453,7 +467,7 @@ def compute_gae(rewards: torch.Tensor,
                 # update learning rate
                 if self._learning_rate_scheduler[uid]:
                     if isinstance(self.schedulers[uid], KLAdaptiveLR):
-                        self.schedulers[uid].step(torch.tensor(kl_divergences).mean())
+                        self.schedulers[uid].step(torch.tensor(kl_divergences, device=self.device).mean())
                     else:
                         self.schedulers[uid].step()

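The one-line change above builds the KL tensor on self.device instead of the CPU default. That matters for distributed runs: if the KL-adaptive scheduler synchronizes the divergence across processes, NCCL collectives require a GPU-resident tensor. A hedged sketch of what a KL-adaptive learning-rate step of this kind can look like (thresholds and factors are illustrative defaults, not necessarily skrl's):

    import torch
    import torch.distributed as dist

    def kl_adaptive_step(optimizer: torch.optim.Optimizer, kl: torch.Tensor,
                         kl_threshold: float = 0.008, lr_factor: float = 1.5,
                         min_lr: float = 1e-6, max_lr: float = 1e-2) -> None:
        # sketch: agree on a single KL value across workers; NCCL needs `kl`
        # on the GPU, which is why the diff passes device=self.device
        if dist.is_initialized():
            dist.all_reduce(kl, op=dist.ReduceOp.SUM)
            kl /= dist.get_world_size()
        for group in optimizer.param_groups:
            if kl > kl_threshold * 2.0:    # policy moved too far: cool down
                group["lr"] = max(group["lr"] / lr_factor, min_lr)
            elif kl < kl_threshold / 2.0:  # policy barely moved: speed up
                group["lr"] = min(group["lr"] * lr_factor, max_lr)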
skrl/multi_agents/torch/mappo/mappo.py (16 changes: 15 additions & 1 deletion)
@@ -9,6 +9,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
+from skrl import config, logger
 from skrl.memories.torch import Memory
 from skrl.models.torch import Model
 from skrl.multi_agents.torch import MultiAgent
@@ -116,9 +117,18 @@ def __init__(self,
         self.values = {uid: self.models[uid].get("value", None) for uid in self.possible_agents}
 
         for uid in self.possible_agents:
+            # checkpoint models
             self.checkpoint_modules[uid]["policy"] = self.policies[uid]
             self.checkpoint_modules[uid]["value"] = self.values[uid]
 
+            # broadcast models' parameters in distributed runs
+            if config.torch.is_distributed:
+                logger.info(f"Broadcasting models' parameters")
+                if self.policies[uid] is not None:
+                    self.policies[uid].broadcast_parameters()
+                if self.values[uid] is not None and self.policies[uid] is not self.values[uid]:
+                    self.values[uid].broadcast_parameters()
+
         # configuration
         self._learning_epochs = self._as_dict(self.cfg["learning_epochs"])
         self._mini_batches = self._as_dict(self.cfg["mini_batches"])
@@ -457,6 +467,10 @@ def compute_gae(rewards: torch.Tensor,
                     # optimization step
                     self.optimizers[uid].zero_grad()
                     (policy_loss + entropy_loss + value_loss).backward()
+                    if config.torch.is_distributed:
+                        policy.reduce_parameters()
+                        if policy is not value:
+                            value.reduce_parameters()
                     if self._grad_norm_clip[uid] > 0:
                         if policy is value:
                             nn.utils.clip_grad_norm_(policy.parameters(), self._grad_norm_clip[uid])
@@ -473,7 +487,7 @@ def compute_gae(rewards: torch.Tensor,
                 # update learning rate
                 if self._learning_rate_scheduler[uid]:
                     if isinstance(self.schedulers[uid], KLAdaptiveLR):
-                        self.schedulers[uid].step(torch.tensor(kl_divergences).mean())
+                        self.schedulers[uid].step(torch.tensor(kl_divergences, device=self.device).mean())
                     else:
                         self.schedulers[uid].step()

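The mappo.py changes mirror ippo.py line for line: broadcast at construction, gradient reduction before the optimizer step, and the device-aware KL tensor. Everything hinges on skrl's global `config.torch.is_distributed` flag, which is plausibly derived from the environment variables that torchrun sets for each worker (the exact derivation below is an assumption, not skrl's code):

    import os

    # torchrun exports these for every process it spawns
    local_rank = int(os.getenv("LOCAL_RANK", "0"))   # GPU index on this node
    rank = int(os.getenv("RANK", "0"))               # global process index
    world_size = int(os.getenv("WORLD_SIZE", "1"))   # total number of processes
    is_distributed = world_size > 1                  # the flag both agents check

With that in place, a training script (name hypothetical) is launched with one process per GPU, e.g. `torchrun --nproc_per_node=4 train_multi_agent.py`.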
