In [1]:
import time
import mujoco.viewer
import mujoco
import numpy as np
import torch
from model.actor_critic import EncoderNet, StochasticDDPGActor
from RLAlg.nn.steps import DeterministicContinuousPolicyStep

In [2]:
# Build the MuJoCo model from the SO101 scene file and allocate its simulation data.
m = mujoco.MjModel.from_xml_path("env/assets/so101/scene.xml")
d = mujoco.MjData(m)
# One physics step spans 1/30 s; the viewer loop paces itself with this same value.
m.opt.timestep = 1.0 / 30.0

# Goal vector handed to the policy (11 values in total).
# NOTE(review): the 3 + 4 + 4 grouping mirrors the encoder's input-size
# expression and presumably means position + two quaternions -- confirm
# against the training environment.
goal_state = np.concatenate([
    [0.25, 0.0, 0.17],
    [1.0, 0.0, 0.0, 0.0],
    [0.7071, 0.7071, 0.0, 0.0],
])

In [3]:
# Joint-position history fed to the policy: both slots start out as
# independent copies of the simulator's current qpos.
current_pos = d.qpos.copy()
pre_pos = current_pos.copy()
# No action has been applied yet, so the "previous action" is six zeros.
pre_action = np.zeros(6, dtype=int)

In [4]:
# Use the first CUDA device when one is present; otherwise fall back to the
# CPU so the notebook still runs on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Encoder input width is 6+6+6+3+4+4 = 29, which matches the length of the
# observation assembled in the control loop (11 + 6 + 6 + 6).
encoder = EncoderNet(6+6+6+3+4+4, [256, 256, 256]).to(device)
actor = StochasticDDPGActor(encoder.dim, [256, 256], 6).to(device)

# NOTE(security): torch.load unpickles the file, so only load checkpoints
# from a trusted source.
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a GPU can still be loaded when that GPU is not available here.
# The checkpoint unpacks into three entries; the third one is not used.
encoder_params, actor_params, _ = torch.load("model.pth", map_location=device)
encoder.load_state_dict(encoder_params)
actor.load_state_dict(actor_params)

# Switch both networks to evaluation mode; the final expression also displays
# the actor's module structure as the cell output.
encoder.eval()
actor.eval()

StochasticDDPGActor(
  (layers): Sequential(
    (0): MLPLayer(
      (linear): Linear(in_features=256, out_features=256, bias=False)
      (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (activate_func): SiLU()
    )
    (1): MLPLayer(
      (linear): Linear(in_features=256, out_features=256, bias=False)
      (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (activate_func): SiLU()
    )
  )
  (policy_layer): DeterministicHead(
    (linear): Linear(in_features=256, out_features=6, bias=True)
  )
)

In [5]:
@torch.no_grad()
def get_action(obs):
    """Run the encoder and actor on one observation and return the mean action.

    Args:
        obs: NumPy array holding a single, unbatched observation. It is
            converted to a float32 tensor with a leading batch axis and moved
            to the module-level ``device``.

    Returns:
        NumPy array taken from the ``mean`` field of the actor's output, with
        the batch axis removed and the data copied back to the CPU.

    Note:
        ``std=1.0`` is passed to the actor, but only ``.mean`` is read from
        the result. Gradient tracking is disabled by the decorator.
    """
    obs_batch = torch.from_numpy(obs).unsqueeze(0).float().to(device)
    policy_out: DeterministicContinuousPolicyStep = actor(encoder(obs_batch), std=1.0)
    return policy_out.mean.squeeze(0).cpu().numpy()

In [None]:
with mujoco.viewer.launch_passive(m, d) as viewer:
    # Stop after 5 wall-clock seconds, or sooner if the viewer window is closed.
    run_began = time.time()
    while viewer.is_running() and time.time() - run_began < 5:
        tick_began = time.time()

        # Policy input: goal, current joints, previous joints, previous action.
        obs = np.concatenate([goal_state, current_pos, pre_pos, pre_action])
        action = get_action(obs)

        # The policy output is scaled by 0.25 and added to the current joint
        # positions; the result is then clamped to the model's joint ranges.
        joint_delta = action * 0.25
        target_pos = np.clip(
            current_pos + joint_delta,
            m.jnt_range[:, 0],
            m.jnt_range[:, 1],
        )

        # Debug trace: scaled action, current positions, clamped target.
        for debug_row in (joint_delta, current_pos, target_pos):
            print(debug_row)
        print("---------")

        # The target is written straight into qpos before the physics step.
        # NOTE(review): this bypasses d.ctrl -- confirm it matches how the
        # policy was driven during training.
        d.qpos[:] = target_pos
        mujoco.mj_step(m, d)

        # Shift the history the policy will see on the next iteration.
        pre_pos, current_pos = current_pos.copy(), d.qpos.copy()
        pre_action = action.copy()

        viewer.sync()

        # Rudimentary pacing: sleep away whatever is left of this timestep.
        # This will drift relative to the wall clock.
        time_left = m.opt.timestep - (time.time() - tick_began)
        if time_left > 0:
            time.sleep(time_left)

[-0.00378472  0.0639085  -0.09148762  0.18301126  0.03379055 -0.01046629]
[-0.00294758  0.03169514  0.01633593  1.32891674  0.00330186 -0.00301716]
[-0.00673231  0.09560364 -0.0751517   1.511928    0.03709242 -0.01348345]
---------
[-0.00378473  0.06390853 -0.09148759  0.18301134  0.03379057 -0.0104663 ]
[-0.00294758  0.03169514  0.01633593  1.32891671  0.00330186 -0.00301715]
[-0.00673231  0.09560367 -0.07515166  1.51192805  0.03709243 -0.01348345]
---------
[-0.00378473  0.06390848 -0.09148765  0.18301128  0.03379057 -0.0104663 ]
[-0.00294758  0.03169516  0.01633595  1.32891676  0.00330186 -0.00301716]
[-0.00673232  0.09560364 -0.07515169  1.51192804  0.03709243 -0.01348346]
---------
[-0.00378472  0.0639085  -0.09148763  0.18301126  0.03379055 -0.0104663 ]
[-0.00294759  0.03169513  0.01633593  1.32891675  0.00330187 -0.00301716]
[-0.00673231  0.09560363 -0.0751517   1.51192801  0.03709242 -0.01348346]
---------
[-0.00378474  0.06390854 -0.09148758  0.18301134  0.03379056 -0.01046631

In [9]:
# Display the second column of the model's joint ranges, i.e. the column the
# control loop uses as the upper clip bound for the target positions.
m.jnt_range[:, 1]

array([1.91986218, 1.74532925, 1.69      , 1.65806273, 2.84120631,
       1.7453292 ])