Skip to content

Commit

Permalink
Fix the buffer bug and add td3 for elegantrl
Browse files Browse the repository at this point in the history
  • Loading branch information
Skylark0924 committed Dec 17, 2022
1 parent 277bfe4 commit 2c421be
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 8 deletions.
23 changes: 15 additions & 8 deletions rofunc/examples/learning/CURICabinet_elegantrl.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,25 @@ def eval(custom_args, ckpt_path=None):
# TODO: add support for eval mode
beauty_print("Start evaluating")

env, args = setup(custom_args, eval_mode=True)
env, agent = setup(custom_args, eval_mode=True)

# load checkpoint
if ckpt_path is None:
ckpt_path = model_zoo(name="CURICabinetPPO_right_arm.pt")
agent.load(ckpt_path)
agent.save_or_load_agent(cwd=ckpt_path, if_save=False)

# evaluate the agent
trainer.eval()
state = env.reset()
episode_reward = 0
for i in range(2 ** 10):
action = agent.act.get_action(state).detach()
next_state, reward, done, _ = env.step(action)
episode_reward += reward.mean()
# if done:
# print(f'Step {i:>6}, Episode return {episode_reward:8.3f}')
# break
# else:
state = next_state


if __name__ == '__main__':
Expand All @@ -47,15 +57,12 @@ def eval(custom_args, ckpt_path=None):
parser.add_argument("--sim_device", type=str, default="cuda:{}".format(gpu_id))
parser.add_argument("--rl_device", type=str, default="cuda:{}".format(gpu_id))
parser.add_argument("--graphics_device_id", type=int, default=gpu_id)
parser.add_argument("--headless", type=str, default="False")
parser.add_argument("--headless", type=str, default="True")
parser.add_argument("--test", action="store_true", help="turn to test mode while adding this argument")
custom_args = parser.parse_args()

if not custom_args.test:
train(custom_args)
else:
# TODO: add support for eval mode
folder = 'CURICabinetSAC_22-11-27_18-38-53-296354'
ckpt_path = "/home/ubuntu/Github/Knowledge-Universe/Robotics/Roadmap-for-robot-science/rofunc/examples/learning/runs/{}/checkpoints/best_agent.pt".format(
folder)
ckpt_path = "/home/ubuntu/Github/Knowledge-Universe/Robotics/Roadmap-for-robot-science/rofunc/examples/learning/result/CURICabinet_SAC_42/actor_53608448_00007.742.pth"
eval(custom_args, ckpt_path=ckpt_path)
11 changes: 11 additions & 0 deletions rofunc/lfd/rl/utils/elegantrl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@
from elegantrl.train.config import Arguments
from elegantrl.agents.AgentPPO import AgentPPO
from elegantrl.agents.AgentSAC import AgentSAC
from elegantrl.agents.AgentTD3 import AgentTD3
from elegantrl.train.run import init_agent

from rofunc.config.utils import get_config, omegaconf_to_dict
from rofunc.lfd.rl.tasks import task_map
from rofunc.utils.logger.beauty_logger import beauty_print


class ElegantRLIsaacGymEnvWrapper:
Expand Down Expand Up @@ -96,6 +99,7 @@ def step(
def setup(custom_args, eval_mode=False):
# get config
sys.argv.append("task={}".format(custom_args.task))
beauty_print("Agent: {}{}ElegantRL".format(custom_args.task, custom_args.agent.upper()), 2)
sys.argv.append("sim_device={}".format(custom_args.sim_device))
sys.argv.append("rl_device={}".format(custom_args.rl_device))
sys.argv.append("graphics_device_id={}".format(custom_args.graphics_device_id))
Expand All @@ -106,6 +110,7 @@ def setup(custom_args, eval_mode=False):

if eval_mode:
task_cfg_dict['env']['numEnvs'] = 16
cfg.headless = False

env = task_map[custom_args.task](cfg=task_cfg_dict,
rl_device=cfg.rl_device,
Expand All @@ -121,6 +126,8 @@ def setup(custom_args, eval_mode=False):
agent_class = AgentPPO
elif custom_args.agent.lower() == "sac":
agent_class = AgentSAC
elif custom_args.agent.lower() == "td3":
agent_class = AgentTD3
else:
raise ValueError("Agent not supported")

Expand All @@ -143,4 +150,8 @@ def setup(custom_args, eval_mode=False):
args.learner_gpus = cfg.graphics_device_id
args.random_seed = 42

if eval_mode:
agent = init_agent(args, args.learner_gpus, env)
return env, agent

return env, args

0 comments on commit 2c421be

Please sign in to comment.