In [1]:
import time
import mujoco.viewer
import mujoco
import numpy as np
import torch
from model.actor_critic import EncoderNet, StochasticDDPGActor
from RLAlg.nn.steps import DeterministicContinuousPolicyStep

In [2]:
# Build the MuJoCo model from the SO101 scene file and allocate the matching
# simulation state container.
scene_xml = "env/assets/so101/scene.xml"
m = mujoco.MjModel.from_xml_path(scene_xml)
d = mujoco.MjData(m)

# Each mj_step call advances the simulation by 1/30 s.
m.opt.timestep = 1.0 / 30.0

# 11-element goal vector that is prepended to every observation.
# NOTE(review): the grouping below (3 + 4 + 4) mirrors the trailing "3+4+4" in
# the encoder input size; presumably a position followed by two quaternions —
# TODO confirm against the training environment.
goal_state = np.array(
    [
        0.25, 0.0, 0.17,
        1.0, 0.0, 0.0, 0.0,
        0.7071, 0.7071, 0.0, 0.0,
    ]
)

In [3]:
# Rollout bookkeeping: the current and previous joint positions both start from
# the simulator's present qpos, and the previous action starts as six zeros.
current_pos = d.qpos.copy()
pre_pos = current_pos.copy()
pre_action = np.zeros(6, dtype=int)

In [4]:
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU
# so the notebook still runs on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# The encoder consumes a 29-dimensional observation (6+6+6+3+4+4), which is the
# length of the vector concatenated in the rollout cell; the actor emits 6 values.
encoder = EncoderNet(6+6+6+3+4+4, [256, 256, 256]).to(device)
actor = StochasticDDPGActor(encoder.dim, [256, 256], 6).to(device)

# map_location remaps the stored tensors onto `device`, so a checkpoint saved on
# a GPU can also be loaded on a CPU-only machine.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
# The third element of the checkpoint tuple is not needed for inference.
encoder_params, actor_params, _ = torch.load("model.pth", map_location=device)
encoder.load_state_dict(encoder_params)
actor.load_state_dict(actor_params)

# Switch both networks to evaluation mode; the final expression is left bare so
# the notebook displays the actor's module structure.
encoder.eval()
actor.eval()

StochasticDDPGActor(
  (layers): Sequential(
    (0): MLPLayer(
      (linear): Linear(in_features=256, out_features=256, bias=False)
      (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (activate_func): SiLU()
    )
    (1): MLPLayer(
      (linear): Linear(in_features=256, out_features=256, bias=False)
      (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (activate_func): SiLU()
    )
  )
  (policy_layer): DeterministicHead(
    (linear): Linear(in_features=256, out_features=6, bias=True)
  )
)

In [5]:
@torch.no_grad()
def get_action(obs):
    """Run the encoder and actor on one observation and return the action.

    Args:
        obs: 1-D NumPy array holding a single observation. It is converted to
            a float32 tensor with a leading batch dimension of 1 and moved to
            the module-level ``device``.

    Returns:
        NumPy array taken from the ``mean`` field of the actor's output, with
        the leading batch dimension squeezed away.

    Note:
        The actor is called with ``std=1.0`` but only ``step.mean`` is read, so
        no sampling happens in this function itself.
    """
    obs_batch = torch.from_numpy(obs).unsqueeze(0).to(device=device, dtype=torch.float32)

    step: DeterministicContinuousPolicyStep = actor(encoder(obs_batch), std=1.0)

    return step.mean.squeeze(0).cpu().numpy()

In [6]:
with mujoco.viewer.launch_passive(m, d) as viewer:
    # Stop after `run_seconds` of wall-clock time, or earlier if the viewer
    # window is closed.
    run_seconds = 5
    # Factor applied to the policy output before it is added to the current
    # joint positions.
    action_scale = 0.25
    # Per-joint lower / upper limits used to clip the commanded positions.
    joint_low, joint_high = m.jnt_range[:, 0], m.jnt_range[:, 1]

    start = time.time()
    while viewer.is_running() and time.time() - start < run_seconds:
        step_start = time.time()

        # Observation layout: goal, current joints, previous joints, previous action.
        obs = np.concatenate([goal_state, current_pos, pre_pos, pre_action])
        action = get_action(obs)

        delta = action * action_scale
        target_pos = np.clip(current_pos + delta, joint_low, joint_high)

        # Debug trace: scaled action, position before the step, clipped target.
        for row in (delta, current_pos, target_pos, "---------"):
            print(row)

        # NOTE(review): the target is written straight into qpos rather than
        # being sent through d.ctrl, and mj_step then integrates from that
        # state — confirm this matches how the policy was trained.
        d.qpos[:] = target_pos
        mujoco.mj_step(m, d)

        # Shift the history: the old current position becomes the previous one,
        # and the post-step qpos becomes the new current position.
        pre_pos, current_pos, pre_action = (
            current_pos.copy(),
            d.qpos.copy(),
            action.copy(),
        )

        viewer.sync()

        # Rudimentary pacing: sleep off whatever is left of one timestep. This
        # will drift relative to the wall clock.
        remaining = m.opt.timestep - (time.time() - step_start)
        if remaining > 0:
            time.sleep(remaining)

[ 0.01975363  0.12390633  0.25        0.25        0.02286527 -0.12073833]
[0. 0. 0. 0. 0. 0.]
[ 0.01975363  0.12390633  0.25        0.25        0.02286527 -0.12073833]
---------
[ 0.00595753 -0.10739202  0.19186655  0.25        0.02918612 -0.07603059]
[ 0.01440974  0.1040628   0.19934614  0.18095974  0.01465216 -0.07245386]
[ 0.02036727 -0.00332922  0.39121269  0.43095974  0.04383827 -0.14848445]
---------
[-0.00160523 -0.09071025  0.12178066  0.25        0.02750999 -0.06706887]
[ 1.10729532e-02 -1.27037510e-04  3.00473093e-01  3.18305067e-01
  2.21977724e-02 -6.04724938e-02]
[ 0.00946773 -0.09083729  0.42225375  0.56830507  0.04970776 -0.12754136]
---------
[ 0.00111714 -0.0575632   0.10018256  0.25        0.03044422 -0.06620227]
[ 0.00126619 -0.04444928  0.3029932   0.42852543  0.01781955 -0.02491155]
[ 0.00238333 -0.10201248  0.40317575  0.67852543  0.04826377 -0.09111382]
---------
[ 0.00180581 -0.04534572  0.07847452  0.25        0.03381475 -0.06927224]
[-0.00263573 -0.02535977  0

In [7]:
# Display the joint positions currently stored in the simulation data.
d.qpos

array([-0.00331967,  0.03169701,  0.01719076,  1.2730679 ,  0.00331237,
       -0.00334428])

In [8]:
# Display the joint velocities currently stored in the simulation data.
d.qvel

array([ 0.50007907, -1.15025615,  2.26017634, -5.47823369, -1.033261  ,
        1.17004507])