🚀 [RofuncRL] Still struggling in finding the ASEHRL bug
Now 1 step LLC works, but more steps doesn't
Skylark0924 committed Jun 26, 2023
1 parent 12fac37 commit c7c8000
Showing 36 changed files with 5,142 additions and 78 deletions.
8 changes: 4 additions & 4 deletions examples/learning_rl/example_HumanoidASE_RofuncRL.py
@@ -91,7 +91,7 @@ def inference(custom_args):


if __name__ == '__main__':
-gpu_id = 1
+gpu_id = 0

parser = argparse.ArgumentParser()
# Available tasks and motion files:
@@ -101,15 +101,15 @@ def inference(custom_args):
# HumanoidASEReachSwordShield -> reallusion_sword_shield/RL_Avatar_Idle_Ready_Motion.npy
# HumanoidASELocationSwordShield -> reallusion_sword_shield/RL_Avatar_Idle_Ready_Motion.npy
# HumanoidASEStrikeSwordShield -> reallusion_sword_shield/RL_Avatar_Idle_Ready_Motion.npy
parser.add_argument("--task", type=str, default="HumanoidASEGetupSwordShield")
parser.add_argument("--task", type=str, default="HumanoidASEHeadingSwordShield")
parser.add_argument("--motion_file", type=str,
default="reallusion_sword_shield/dataset_reallusion_sword_shield.yaml")
default="reallusion_sword_shield/RL_Avatar_Idle_Ready_Motion.npy")
parser.add_argument("--agent", type=str, default="ase") # Available agent: ase
parser.add_argument("--num_envs", type=int, default=4096)
parser.add_argument("--sim_device", type=str, default="cuda:{}".format(gpu_id))
parser.add_argument("--rl_device", type=str, default="cuda:{}".format(gpu_id))
parser.add_argument("--graphics_device_id", type=int, default=gpu_id)
parser.add_argument("--headless", type=str, default="True")
parser.add_argument("--headless", type=str, default="False")
parser.add_argument("--inference", action="store_true", help="turn to inference mode while adding this argument")
parser.add_argument("--ckpt_path", type=str, default=None)
custom_args = parser.parse_args()
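A note on the argparse defaults above: --headless is declared with type=str, so it arrives as the string "False" rather than a boolean, and whatever consumes custom_args has to interpret that string itself. A minimal, self-contained illustration of that conversion (purely illustrative, not part of this commit):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--headless", type=str, default="False")
args = parser.parse_args([])                       # parse only the defaults for this demo
headless = args.headless.lower() in ("true", "1")  # "False" -> False, "True" -> True
print(headless)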
@@ -54,7 +54,7 @@ Agent:
kl_threshold: 0 # Initial coefficient for KL divergence.

llc_ckpt_path:
-llc_steps_per_high_action: 5
+llc_steps_per_high_action: 1

# state_preprocessor: # State preprocessor type.
# state_preprocessor_kwargs: # State preprocessor kwargs.
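The commit message ("Now 1 step LLC works, but more steps doesn't") points at this setting: with llc_steps_per_high_action reduced from 5 to 1 the hierarchical agent trains, while larger values still fail. Conceptually, in an ASE-style hierarchy the high-level policy emits a latent that the pretrained low-level controller (LLC) executes for that many simulation steps before the next high-level decision. A minimal, hypothetical sketch of that inner loop (not the RofuncRL implementation; env, llc_policy and the reward handling here are assumptions):

def step_with_high_action(env, llc_policy, obs, high_latent, llc_steps):
    # Execute one high-level latent for llc_steps low-level control steps and
    # accumulate the reward credited to the high-level action.
    total_reward = 0.0
    for _ in range(llc_steps):
        low_action = llc_policy(obs, high_latent)  # LLC is conditioned on the latent
        obs, reward, done, info = env.step(low_action)
        total_reward += reward
        if done:
            break
    return obs, total_reward, done, info

Keeping this accumulation, and the handling of environments that terminate inside the inner loop, consistent with what the high-level buffer stores is one of the places where a multi-step LLC can break while the single-step case still works.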
@@ -53,7 +53,7 @@ Agent:

kl_threshold: 0 # Initial coefficient for KL divergence.

-llc_ckpt_path:
+llc_ckpt_path: /home/ubuntu/Github/Knowledge-Universe/Robotics/Roadmap-for-robot-science/examples/learning_rl/runs/RofuncRL_ASETrainer_HumanoidASEGetupSwordShield_23-06-26_12-49-35-111331/checkpoints/ckpt_87000.pth
llc_steps_per_high_action: 5

# state_preprocessor: # State preprocessor type.
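For context: judging by the run directory name, this llc_ckpt_path points at a checkpoint (ckpt_87000.pth) from an earlier RofuncRL_ASETrainer run on HumanoidASEGetupSwordShield, which the hierarchical agent presumably loads as its pretrained low-level controller. Being an absolute /home/ubuntu/... path, it only resolves on that machine and would normally be overridden per user.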
2 changes: 1 addition & 1 deletion rofunc/learning/RofuncRL/agents/base_agent.py
@@ -118,7 +118,7 @@ def track_data(self, tag: str, value: float) -> None:
def store_transition(self, states: torch.Tensor, actions: torch.Tensor, next_states: torch.Tensor,
rewards: torch.Tensor, terminated: torch.Tensor, truncated: torch.Tensor, infos: torch.Tensor):
"""
-Record the transition.
+Record the transition. (Only rewards, truncated and terminated are used in this base class)
"""
if self.cumulative_rewards is None:
self.cumulative_rewards = torch.zeros_like(rewards, dtype=torch.float32)
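The clarified docstring says the base class only consumes rewards, terminated and truncated in store_transition, and the visible line lazily allocates the running-return buffer. A sketch of the bookkeeping this implies (illustrative only, not the exact RofuncRL code; assumes boolean terminated/truncated tensors):

self.cumulative_rewards += rewards
done = terminated | truncated                           # per-env episode-end mask
if done.any():
    mean_return = self.cumulative_rewards[done].mean()  # e.g. for logging/tracking
    self.cumulative_rewards[done] = 0.0                 # reset returns of finished envs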
49 changes: 35 additions & 14 deletions rofunc/learning/RofuncRL/agents/mixline/ase_agent.py
@@ -26,6 +26,7 @@
import rofunc as rf
from rofunc.learning.RofuncRL.agents.base_agent import BaseAgent
from rofunc.learning.RofuncRL.agents.mixline.amp_agent import AMPAgent
+from rofunc.learning.RofuncRL.models.misc_models import ASEDiscEnc
from rofunc.learning.RofuncRL.models.base_models import BaseMLP
from rofunc.learning.RofuncRL.utils.memory import Memory

@@ -72,6 +73,13 @@ def __init__(self,
self._enc_reward_weight = cfg.Agent.enc_reward_weight

'''Define ASE specific models except for AMP'''
+# self.discriminator = ASEDiscEnc(cfg.Model,
+#                                 input_dim=amp_observation_space.shape[0],
+#                                 enc_output_dim=self._ase_latent_dim,
+#                                 disc_output_dim=1,
+#                                 cfg_name='encoder').to(device)
+# self.encoder = self.discriminator
+
self.encoder = BaseMLP(cfg.Model,
input_dim=amp_observation_space.shape[0],
output_dim=self._ase_latent_dim,
@@ -95,10 +103,11 @@

def _set_up(self):
super()._set_up()
-self.optimizer_enc = torch.optim.Adam(self.encoder.parameters(), lr=self._lr_e, eps=self._adam_eps)
-if self._lr_scheduler is not None:
-    self.scheduler_enc = self._lr_scheduler(self.optimizer_enc, **self._lr_scheduler_kwargs)
-self.checkpoint_modules["optimizer_enc"] = self.optimizer_enc
+if self.encoder is not self.discriminator:
+    self.optimizer_enc = torch.optim.Adam(self.encoder.parameters(), lr=self._lr_e, eps=self._adam_eps)
+    if self._lr_scheduler is not None:
+        self.scheduler_enc = self._lr_scheduler(self.optimizer_enc, **self._lr_scheduler_kwargs)
+    self.checkpoint_modules["optimizer_enc"] = self.optimizer_enc

def act(self, states: torch.Tensor, deterministic: bool = False, ase_latents: torch.Tensor = None):
if self._current_states is not None:
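Background for the changes in this file: the commented-out block in __init__ would replace the standalone BaseMLP encoder with ASEDiscEnc, a single module acting as both discriminator and encoder, and the new "self.encoder is (not) self.discriminator" guards throughout handle that shared case (single optimizer, single backward pass, get_enc for the encoder head). A rough sketch of what such a shared-trunk module might look like (an assumption about ASEDiscEnc's structure, not its actual definition in misc_models):

import torch.nn as nn

class SharedDiscEnc(nn.Module):
    # Shared trunk with two heads: a 1-dim discriminator logit and an
    # ase_latent_dim-dim encoder output, mirroring the get_enc calls below.
    def __init__(self, input_dim, enc_output_dim, hidden_dims=(1024, 512)):
        super().__init__()
        layers, last = [], input_dim
        for h in hidden_dims:
            layers += [nn.Linear(last, h), nn.ReLU()]
            last = h
        self.trunk = nn.Sequential(*layers)
        self.disc_head = nn.Linear(last, 1)
        self.enc_head = nn.Linear(last, enc_output_dim)

    def forward(self, amp_obs):   # discriminator path
        return self.disc_head(self.trunk(amp_obs))

    def get_enc(self, amp_obs):   # encoder path
        return self.enc_head(self.trunk(amp_obs))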
@@ -173,7 +182,10 @@ def update_net(self):
style_rewards *= self._discriminator_reward_scale

# Compute encoder reward
-enc_output = self.encoder(self._amp_state_preprocessor(amp_states))
+if self.encoder is self.discriminator:
+    enc_output = self.encoder.get_enc(self._amp_state_preprocessor(amp_states))
+else:
+    enc_output = self.encoder(self._amp_state_preprocessor(amp_states))
enc_output = torch.nn.functional.normalize(enc_output, dim=-1)
enc_reward = torch.clamp_min(torch.sum(enc_output * ase_latents, dim=-1, keepdim=True), 0.0)
enc_reward *= self._enc_reward_scale
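In other words, the encoder reward computed above is r_enc = enc_reward_scale * max(0, q(s) · z), where q(s) is the L2-normalized encoder output for the AMP observation and z is the ASE latent the rollout was conditioned on. The clamp zeroes the reward whenever the encoder's prediction points away from the latent, so the agent is only rewarded for behaviour from which its skill latent can be recovered.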
@@ -311,7 +323,10 @@ def update_net(self):
discriminator_loss *= self._discriminator_loss_scale

# encoder loss
-enc_output = self.encoder(self._amp_state_preprocessor(sampled_amp_states))
+if self.encoder is self.discriminator:
+    enc_output = self.encoder.get_enc(self._amp_state_preprocessor(sampled_amp_states))
+else:
+    enc_output = self.encoder(self._amp_state_preprocessor(sampled_amp_states_batch))
enc_output = torch.nn.functional.normalize(enc_output, dim=-1)
enc_err = -torch.sum(enc_output * sampled_ase_latents, dim=-1, keepdim=True)
enc_loss = torch.mean(enc_err)
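The encoder's training signal is the negative of that same inner product, enc_loss = -mean(q(s) · z), pushing the normalized encoder output toward the latent that generated each sampled transition; in the shared ASEDiscEnc case this loss also flows back through the discriminator's trunk via the combined backward pass below.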
@@ -357,17 +372,21 @@ def update_net(self):

# Update discriminator network
self.optimizer_disc.zero_grad()
-discriminator_loss.backward()
+if self.encoder is self.discriminator:
+    (discriminator_loss + enc_loss).backward()
+else:
+    discriminator_loss.backward()
if self._grad_norm_clip > 0:
nn.utils.clip_grad_norm_(self.discriminator.parameters(), self._grad_norm_clip)
self.optimizer_disc.step()

# Update encoder network
-self.optimizer_enc.zero_grad()
-enc_loss.backward()
-if self._grad_norm_clip > 0:
-    nn.utils.clip_grad_norm_(self.encoder.parameters(), self._grad_norm_clip)
-self.optimizer_enc.step()
+if self.encoder is not self.discriminator:
+    self.optimizer_enc.zero_grad()
+    enc_loss.backward()
+    if self._grad_norm_clip > 0:
+        nn.utils.clip_grad_norm_(self.encoder.parameters(), self._grad_norm_clip)
+    self.optimizer_enc.step()

# update cumulative losses
cumulative_policy_loss += policy_loss.item()
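A note on the branching above: when the encoder and the discriminator are the same module there is no separate optimizer_enc, so both losses have to be accumulated into the shared parameters' gradients before the single optimizer_disc step; summing them and calling backward once does exactly that. The separate zero_grad/backward/step sequence is kept only for the standalone BaseMLP encoder, which has its own optimizer.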
@@ -382,7 +401,8 @@ def update_net(self):
self.scheduler_policy.step()
self.scheduler_value.step()
self.scheduler_disc.step()
-self.scheduler_enc.step()
+if self.encoder is not self.discriminator:
+    self.scheduler_enc.step()

# update AMP replay buffer
self.replay_buffer.add_samples(states=amp_states.view(-1, amp_states.shape[-1]))
@@ -407,4 +427,5 @@ def update_net(self):
self.track_data("Learning / Learning rate (policy)", self.scheduler_policy.get_last_lr()[0])
self.track_data("Learning / Learning rate (value)", self.scheduler_value.get_last_lr()[0])
self.track_data("Learning / Learning rate (discriminator)", self.scheduler_disc.get_last_lr()[0])
self.track_data("Learning / Learning rate (encoder)", self.scheduler_enc.get_last_lr()[0])
if self.encoder is not self.discriminator:
self.track_data("Learning / Learning rate (encoder)", self.scheduler_enc.get_last_lr()[0])