In [None]:
# ============================================================
# Shadow Dexterous Hand + Crosley Alarm Clock
# 5 grasp trials + METRICS + summary, viewer stays open
# ============================================================

import os, time
import numpy as np
import mujoco
import mujoco.viewer
import gymnasium_robotics

# -------------------------------
# CONFIG
# -------------------------------
# Hand assets directory inside the installed gymnasium_robotics package.
ASSETS_DIR = os.path.join(
    os.path.dirname(gymnasium_robotics.__file__),
    "envs",
    "assets",
    "hand",
)
XML_PATH = os.path.join(ASSETS_DIR, "hand_manipulate_clock.xml")

# Number of trials and the step budget of each phase within one trial.
NUM_TRIALS = 5
OPEN_STEPS, CLOSE_STEPS, HOLD_STEPS = 40, 120, 200

# HOLD-phase "success" criteria (stability thresholds).
Z_MIN = 0.05               # the object must not drop below this height
V_MAX = 2.0                # the object's speed must stay below this
DRIFT_MAX = 0.25           # drift from the post-close position must stay below this
SUCCESS_HOLD_RATIO = 0.90  # minimum fraction of HOLD steps that must be stable

print("Assets dir:", ASSETS_DIR)
print("Clock scene XML exists:", os.path.exists(XML_PATH))

# -------------------------------
# LOAD MODEL
# -------------------------------
# Compile the scene from XML and allocate the matching simulation state.
model = mujoco.MjModel.from_xml_path(XML_PATH)
data = mujoco.MjData(model)

print("\nScene loaded successfully")
# One-line size report of the compiled model.
print(
    "Bodies:", model.nbody,
    "Joints:", model.njnt,
    "Actuators:", model.nu,
    "Sensors:", model.nsensor,
)

# -------------------------------
# OBJECT IDs
# -------------------------------
OBJ_BODY_NAME = "object"
OBJ_JOINT_NAME = "object:joint"

# Resolve the object's body id; a negative id is treated as "name not found".
obj_bid = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_BODY, OBJ_BODY_NAME)
if obj_bid < 0:
    raise RuntimeError(f"Body '{OBJ_BODY_NAME}' not found in XML.")

# Same lookup for the joint that carries the object's pose.
obj_jid = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_JOINT, OBJ_JOINT_NAME)
if obj_jid < 0:
    raise RuntimeError(f"Joint '{OBJ_JOINT_NAME}' not found. Check your XML joint name.")

# First qpos index of the object's joint. The pose code in this file reads
# qpos[adr:adr+3] as position and qpos[adr+3:adr+7] as quaternion, i.e. it
# assumes this is a free joint.
obj_qpos_adr = int(model.jnt_qposadr[obj_jid])

# -------------------------------
# BUILD SENSOR SLICES
# -------------------------------
# Map each sensor name to (start offset in data.sensordata, width, sensor type).
# The offset and width come straight from the model's flat sensor_adr /
# sensor_dim tables instead of being re-derived by accumulation, and indexing
# those tables yields plain scalars, so int() is never applied to an array.
sensor_slices = {}
for i in range(model.nsensor):
    sensor_slices[model.sensor(i).name] = (
        int(model.sensor_adr[i]),
        int(model.sensor_dim[i]),
        int(model.sensor_type[i]),
    )

def sensor_sum(name: str) -> float:
    """Return the sum of the current readings of the sensor called *name*.

    The sensor's span inside ``data.sensordata`` is looked up in
    ``sensor_slices``. A name that is not in ``sensor_slices`` yields 0.0.
    """
    entry = sensor_slices.get(name)
    if entry is None:
        return 0.0
    start, width, _ = entry
    return float(data.sensordata[start:start + width].sum())

# Fingertip touch sensors (available if the XML includes shared_touch_sensors_*.xml).
FINGERTIP_TOUCH = {
    "FF": "robot0:ST_Tch_fftip",
    "MF": "robot0:ST_Tch_mftip",
    "RF": "robot0:ST_Tch_rftip",
    "LF": "robot0:ST_Tch_lftip",
    "TH": "robot0:ST_Tch_thtip",
}
# Finger order used everywhere else in this file (FF, MF, RF, LF, TH).
FINGERS = list(FINGERTIP_TOUCH)

# Which fingertip sensors actually exist in the loaded model.
touch_present = {
    finger: sensor_name in sensor_slices
    for finger, sensor_name in FINGERTIP_TOUCH.items()
}
print("Fingertip touch present:", touch_present)

# "Palm": collect every sensor whose name looks palm-related (if any exist).
# The case-insensitive "palm" test also covers names containing "TS_palm".
PALM_SENSOR_NAMES = [n for n in sensor_slices if "palm" in n.lower()]
# -------------------------------
# ACTUATORS PER FINGER
# -------------------------------
def build_finger_actuator_map():
    """Group actuator indices by the finger tag embedded in each actuator name.

    Returns:
        A dict with keys "FF", "MF", "RF", "LF", "TH" (in that order). Each
        value lists, in ascending order, the actuator indices whose name
        contains ``":A_<key>"`` (e.g. ``robot0:A_FFJ3`` is filed under "FF").
        Actuators matching none of the tags are left out.
    """
    tags = ("FF", "MF", "RF", "LF", "TH")
    grouped = {tag: [] for tag in tags}
    for idx in range(model.nu):
        act_name = model.actuator(idx).name
        # Every tag is tested independently; there is no early exit.
        for tag in tags:
            if f":A_{tag}" in act_name:
                grouped[tag].append(idx)
    return grouped

# Actuator indices per finger, built once and reused by the controller.
FINGER_ACT = build_finger_actuator_map()
print(
    "Actuators per finger:",
    {finger: len(indices) for finger, indices in FINGER_ACT.items()},
)

# -------------------------------
# CONTACT COUNT (contacts between the object and any other body)
# -------------------------------
def count_object_contacts():
    """Count active contacts between the object and some other body.

    Walks the first ``data.ncon`` entries of ``data.contact``, maps both
    geoms of each contact to their owning bodies via ``model.geom_bodyid``,
    and counts the contact when exactly one of the two bodies is the object.
    The other body is not checked to be part of the hand, so any body counts.

    Returns:
        The number of such contacts as an int.
    """
    def _involves_object_once(contact):
        body_a = int(model.geom_bodyid[int(contact.geom1)])
        body_b = int(model.geom_bodyid[int(contact.geom2)])
        # True only when one side is the object and the other is not.
        return (body_a == obj_bid) != (body_b == obj_bid)

    return sum(
        1 for k in range(data.ncon) if _involves_object_once(data.contact[k])
    )

# -------------------------------
# RANDOMIZE OBJECT START (XY + yaw)
# -------------------------------
def set_object_pose_random(seed: int, xy_sigma=0.015, yaw_range=np.deg2rad(25)):
    """Perturb the object's pose in ``data.qpos`` in place, reproducibly per *seed*.

    The object's current position gets independent Gaussian offsets on X and
    Y; Z is untouched. Its current quaternion is left-multiplied by a rotation
    about the Z axis whose angle is drawn uniformly from
    ``[-yaw_range, yaw_range]``.

    Args:
        seed: Seed for ``np.random.default_rng``; same seed, same perturbation.
        xy_sigma: Standard deviation of the X and Y offsets.
        yaw_range: Half-width of the yaw interval (default is 25 degrees
            converted with ``np.deg2rad``).

    Returns:
        None. Only ``data.qpos`` is modified; no forward pass is run here.
    """
    rng = np.random.default_rng(seed)

    pos_span = slice(obj_qpos_adr, obj_qpos_adr + 3)
    quat_span = slice(obj_qpos_adr + 3, obj_qpos_adr + 7)

    new_pos = data.qpos[pos_span].copy()
    start_quat = data.qpos[quat_span].copy()

    # Draw order (XY offsets first, yaw second) fixes what a given seed produces.
    new_pos[:2] += rng.normal(0, xy_sigma, size=2)
    yaw = rng.uniform(-yaw_range, yaw_range)

    z_axis = np.array([0.0, 0.0, 1.0], dtype=float)
    yaw_quat = np.zeros(4, dtype=float)
    mujoco.mju_axisAngle2Quat(yaw_quat, z_axis, float(yaw))

    new_quat = np.zeros(4, dtype=float)
    mujoco.mju_mulQuat(new_quat, yaw_quat, start_quat)

    data.qpos[pos_span] = new_pos
    data.qpos[quat_span] = new_quat

# -------------------------------
# CONTROLLER (theta)
# theta = [wFF,wMF,wRF,wLF,wTH, th_touch, k_hold]
# -------------------------------
def controller_step(t, T_close, theta, latched):
    """Compute one control vector for the closing/holding grasp controller.

    Args:
        t: Current step index within the closing ramp.
        T_close: Ramp length in steps; the base command is ``t / max(1, T_close)``
            clipped to [0, 1].
        theta: Sequence ``[wFF, wMF, wRF, wLF, wTH, th_touch, k_hold]``: five
            per-finger weights, the fingertip touch threshold, and the gain
            applied to a finger's command once it has latched.
        latched: Dict finger -> bool. Mutated in place: a finger becomes
            latched the first time its touch reading exceeds ``th_touch``.

    Returns:
        ``(u, touch)`` where ``u`` is a float array of length ``model.nu``
        (entries for actuators not assigned to any finger stay 0) and
        ``touch`` maps each finger to its fingertip sensor sum (0.0 when the
        sensor is absent from the model).
    """
    weights = np.array(theta[:5], dtype=float)
    touch_threshold = float(theta[5])
    hold_gain = float(theta[6])

    ramp = np.clip(t / max(1, T_close), 0.0, 1.0)

    # Read every fingertip first, then update the latch flags.
    touch = {
        f: (sensor_sum(FINGERTIP_TOUCH[f]) if touch_present[f] else 0.0)
        for f in FINGERS
    }
    for f, reading in touch.items():
        if reading > touch_threshold and not latched[f]:
            latched[f] = True

    u = np.zeros(model.nu, dtype=float)
    for idx, f in enumerate(FINGERS):
        # Latched fingers follow the ramp scaled by the hold gain.
        scale = hold_gain * ramp if latched[f] else ramp
        command = float(np.clip(float(weights[idx]) * scale, 0.0, 1.0))
        for act in FINGER_ACT[f]:
            u[act] = command

    return u, touch

# -------------------------------
# RUN ONE TRIAL (with metrics)
# -------------------------------
def run_trial(theta, seed, viewer=None, sleep_open=0.003, sleep_close=0.003, sleep_hold=0.006):
    """Run one OPEN -> CLOSE -> HOLD grasp trial and return its metrics.

    The simulation is reset, the object pose is randomized from *seed*, and
    three phases are stepped: OPEN (zero controls for ``OPEN_STEPS``), CLOSE
    (``controller_step`` ramp over ``CLOSE_STEPS``) and HOLD
    (``controller_step`` at the end of the ramp for ``HOLD_STEPS``).

    Args:
        theta: Controller parameters passed through to ``controller_step``.
        seed: Seed for the object pose randomization.
        viewer: Optional viewer handle; when given, it is synced after every
            step and the loop sleeps for the phase's pause.
        sleep_open: Pause per OPEN step (seconds), only used with a viewer.
        sleep_close: Pause per CLOSE step (seconds), only used with a viewer.
        sleep_hold: Pause per HOLD step (seconds), only used with a viewer.

    Returns:
        Dict with keys ``contacts``, ``tip_active``, ``tip_sum``, ``palm_sum``
        and ``z`` (all measured right after CLOSE), ``hold_ratio``,
        ``max_speed`` and ``max_drift`` (measured during HOLD), ``energy`` and
        ``smooth`` (accumulated over all three phases), and ``success``
        (1 when ``hold_ratio >= SUCCESS_HOLD_RATIO``, else 0).
    """
    mujoco.mj_resetData(model, data)
    mujoco.mj_forward(model, data)

    set_object_pose_random(seed)
    mujoco.mj_forward(model, data)

    latched = dict.fromkeys(FINGERS, False)

    # Running control costs: sum of u^2 and sum of squared step-to-step change.
    cost = {"energy": 0.0, "smooth": 0.0}
    last_cmd = np.zeros(model.nu, dtype=float)

    def _account_and_step(cmd):
        """Add *cmd* to the running costs, then advance the physics one step."""
        nonlocal last_cmd
        cost["energy"] += float(np.sum(cmd * cmd))
        cost["smooth"] += float(np.sum((cmd - last_cmd) ** 2))
        last_cmd = cmd
        mujoco.mj_step(model, data)

    def _show(pause):
        """Refresh the viewer (if any) and pause so the motion is watchable."""
        if viewer is not None:
            viewer.sync()
            time.sleep(pause)

    # ---- OPEN ----
    for _ in range(OPEN_STEPS):
        data.ctrl[:] = 0.0
        _account_and_step(data.ctrl.copy())
        _show(sleep_open)

    # ---- CLOSE ----
    last_touch = dict.fromkeys(FINGERS, 0.0)
    for step in range(1, CLOSE_STEPS + 1):
        cmd, last_touch = controller_step(
            t=step, T_close=CLOSE_STEPS, theta=theta, latched=latched
        )
        data.ctrl[:] = cmd
        _account_and_step(cmd)
        _show(sleep_close)

    # Snapshot taken right after closing; also the reference for HOLD drift.
    contacts = int(count_object_contacts())
    pos_ref = data.xpos[obj_bid].copy()
    z = float(pos_ref[2])

    # Fingertip metrics after close (instantaneous, from the last CLOSE step).
    tip_vals = np.array([last_touch[f] for f in FINGERS], dtype=float)
    tip_sum = float(np.sum(tip_vals))
    tip_active = int(np.count_nonzero(tip_vals > 0.0))

    # Palm sensor sum (instantaneous), when palm-like sensors exist.
    if PALM_SENSOR_NAMES:
        palm_sum = float(np.sum([sensor_sum(n) for n in PALM_SENSOR_NAMES]))
    else:
        palm_sum = 0.0

    # ---- HOLD ----
    stable_steps = 0
    max_speed = 0.0
    max_drift = 0.0

    for _ in range(HOLD_STEPS):
        cmd, _touch = controller_step(
            t=CLOSE_STEPS, T_close=CLOSE_STEPS, theta=theta, latched=latched
        )
        data.ctrl[:] = cmd
        _account_and_step(cmd)

        pos = data.xpos[obj_bid]
        # Components 3..5 of the body's cvel row are used as its linear velocity.
        speed = float(np.linalg.norm(data.cvel[obj_bid][3:6]))
        drift = float(np.linalg.norm(pos - pos_ref))

        max_speed = max(max_speed, speed)
        max_drift = max(max_drift, drift)

        # A step is stable only when all three thresholds hold.
        if (pos[2] > Z_MIN) and (speed < V_MAX) and (drift < DRIFT_MAX):
            stable_steps += 1

        _show(sleep_hold)

    hold_ratio = stable_steps / float(HOLD_STEPS)
    success = 1 if hold_ratio >= SUCCESS_HOLD_RATIO else 0

    return {
        "contacts": contacts,
        "tip_active": tip_active,
        "tip_sum": tip_sum,
        "palm_sum": palm_sum,
        "z": z,
        "hold_ratio": float(hold_ratio),
        "max_speed": float(max_speed),
        "max_drift": float(max_drift),
        "energy": float(cost["energy"]),
        "smooth": float(cost["smooth"]),
        "success": int(success),
    }

# -------------------------------
# PICK A THETA (manually or from your optimizer)
# -------------------------------
# Example: thumb slightly stronger + moderate latch threshold + strong reduction after touch
theta = np.asarray(
    [
        1.00, 1.00, 1.00, 0.90, 1.20,  # per-finger weights: wFF, wMF, wRF, wLF, wTH
        0.010,                         # th_touch: fingertip reading that latches a finger
        0.30,                          # k_hold: command gain applied once a finger is latched
    ],
    dtype=float,
)
print("\ntheta = [wFF,wMF,wRF,wLF,wTH, th_touch, k_hold] =", theta)

# -------------------------------
# LAUNCH VIEWER
# -------------------------------
# The passive viewer handle is used as a context manager so that the window
# and its background thread are closed on every exit path, including when a
# trial raises part-way through.
with mujoco.viewer.launch_passive(model, data) as viewer:
    # -------------------------------
    # RUN TRIALS + PRINT METRICS
    # -------------------------------
    all_metrics = []

    for trial in range(1, NUM_TRIALS + 1):
        print("\n==============================")
        print(f" TRIAL {trial}/{NUM_TRIALS}")
        print("==============================")

        # Each trial gets its own seed, so object poses differ but are repeatable.
        m = run_trial(theta=theta, seed=2000 + trial, viewer=viewer)

        # Compact two-line metrics printout.
        print(f"contacts={m['contacts']} | tip_active={m['tip_active']} | tip_sum={m['tip_sum']:.4f} | palm_sum={m['palm_sum']:.4f} | z={m['z']:.3f}")
        print(f"hold_ratio={m['hold_ratio']:.2f} | max_speed={m['max_speed']:.3f} | max_drift={m['max_drift']:.3f} | energy={m['energy']:.1f} | smooth={m['smooth']:.1f} | SUCCESS={m['success']}")

        all_metrics.append(m)

    # -------------------------------
    # SUMMARY
    # -------------------------------
    # Per-metric averages over all trials that were run.
    succ_rate = float(np.mean([m["success"] for m in all_metrics]))
    avg_contacts = float(np.mean([m["contacts"] for m in all_metrics]))
    avg_tip_active = float(np.mean([m["tip_active"] for m in all_metrics]))
    avg_energy = float(np.mean([m["energy"] for m in all_metrics]))

    print("\n==============================")
    print(" SUMMARY (over trials)")
    print("==============================")
    print(f"Success rate: {succ_rate:.2f}")
    print(f"Avg contacts after close: {avg_contacts:.2f}")
    print(f"Avg active fingertips: {avg_tip_active:.2f}")
    print(f"Avg ctrl energy (sum u^2): {avg_energy:.1f}")

    # -------------------------------
    # KEEP WINDOW OPEN
    # -------------------------------
    print("\nAll trials finished.")
    print("MuJoCo window will stay open.")
    print("Close the viewer window manually to exit.")

    # Keep stepping with the last applied controls until the user closes the window.
    while viewer.is_running():
        mujoco.mj_step(model, data)
        viewer.sync()
        time.sleep(0.02)
