# In [None]:
# ============================================================
# Shadow Dexterous Hand + Crosley Alarm Clock
# Auto-tune theta via Random Search (ONE CELL)
# - Baseline A, Baseline B, Method (auto theta)
# - Metrics + score
# - LF (little finger) is emphasized via constraints + score bonuses
# - Viewer stays open
# ============================================================

import os, time
import numpy as np
import mujoco
import mujoco.viewer

# -------------------------------
# PATHS
# -------------------------------
# Project root. The default is machine-specific; set the HAND_CLOCK_ROOT
# environment variable to point at another checkout instead of editing this
# line (an empty value falls back to the default).
ROOT_DIR = os.environ.get("HAND_CLOCK_ROOT") or r"C:\Users\d_v_o\Desktop\new_roms"
XML_PATH = os.path.join(ROOT_DIR, "project", "models", "hand", "hand_manipulate_clock.xml")

print("XML exists:", os.path.exists(XML_PATH))
print("XML:", XML_PATH)

# -------------------------------
# EPISODE CONFIG
# -------------------------------
# Episode phases, in simulation steps:
#   OPEN  - all controls held at zero,
#   CLOSE - control ramps from 0 to 1 over CLOSE_STEPS (see controller_step),
#   HOLD  - control kept at the end-of-ramp level while stability is measured.
OPEN_STEPS  = 40
CLOSE_STEPS = 120
HOLD_STEPS  = 200

# Do not enable the touch latch right away, so a finger does not "stick"
# (latch) because of an incidental touch while the object is falling.
MIN_CLOSE_STEPS_BEFORE_LATCH = 15

# Success criteria during HOLD: a step counts as stable when the object's
# z > Z_MIN, its speed < V_MAX and its drift from the end-of-CLOSE position
# < DRIFT_MAX (all in model units); a trial succeeds when the fraction of
# stable HOLD steps is >= SUCCESS_HOLD_RATIO.
Z_MIN      = 0.05
V_MAX      = 2.0
DRIFT_MAX  = 0.25
SUCCESS_HOLD_RATIO = 0.90

# -------------------------------
# OPTIMIZER CONFIG
# -------------------------------
EVAL_TRIALS_PER_THETA = 3   # trials (randomized object start poses) averaged per candidate theta
RANDOM_ITERS = 80           # number of random-search candidates
SEED0 = 2000                # base seed; evaluation/visualization seeds are offsets from it

# Penalty weights in the score (tune these)
LAMBDA_ENERGY = 1e-4
LAMBDA_SMOOTH = 1e-3
LAMBDA_DRIFT  = 1.0

# Search ranges (sampled uniformly by the random search)
TH_TOUCH_RANGE = (0.001, 0.05)   # fingertip touch threshold that latches a finger
K_HOLD_RANGE   = (0.10, 0.60)    # control scale applied to a finger once it is latched
W_RANGE        = (0.5, 1.5)      # per-finger control weights

# ---- LF (little finger) emphasis ----
LF_MIN = 1.10             # minimum little-finger weight
LF_RANGE = (1.10, 1.80)   # separate sampling range for wLF
BONUS_LF_TOUCH = 0.02     # score bonus per unit of little-finger touch (end-of-CLOSE snapshot)
BONUS_LF_ACTIVE = 0.08    # score bonus if the little finger registers any touch
PENALTY_NO_LF = 0.15      # score penalty if the little finger did not touch at all

# -------------------------------
# LOAD MODEL
# -------------------------------
# Compile the MJCF scene once; every trial reuses this model/data pair.
model = mujoco.MjModel.from_xml_path(XML_PATH)
data = mujoco.MjData(model)

print("\nScene loaded successfully")
print(
    f"Bodies: {model.nbody} Joints: {model.njnt} "
    f"Actuators: {model.nu} Sensors: {model.nsensor}"
)

# -------------------------------
# OBJECT IDs
# -------------------------------
OBJ_BODY_NAME  = "object"
OBJ_JOINT_NAME = "object:joint"

obj_bid = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_BODY, OBJ_BODY_NAME)
if obj_bid < 0:
    raise RuntimeError(f"Body '{OBJ_BODY_NAME}' not found in XML.")

obj_jid = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_JOINT, OBJ_JOINT_NAME)
if obj_jid < 0:
    raise RuntimeError(f"Joint '{OBJ_JOINT_NAME}' not found. Check your XML joint name.")

# The pose randomizer writes qpos[adr:adr+7] as (x, y, z, quaternion). That
# layout only belongs to this joint if it is a free joint; for any other joint
# type the write would silently overwrite other joints' qpos.
if int(model.jnt_type[obj_jid]) != int(mujoco.mjtJoint.mjJNT_FREE):
    raise RuntimeError(
        f"Joint '{OBJ_JOINT_NAME}' must be a free joint (qpos layout: xyz + quaternion)."
    )

obj_qpos_adr = int(model.jnt_qposadr[obj_jid])  # free joint: qpos[adr:adr+7]

# -------------------------------
# BUILD SENSOR SLICES
# -------------------------------
# sensor name -> (start index into data.sensordata, dimension, sensor type id).
# Start indices come from model.sensor_adr (MuJoCo's own record of where each
# sensor's output lives) instead of being re-derived by summing dimensions.
# Unnamed sensors are skipped: they would all share the key "" and overwrite
# one another.
sensor_slices = {}
for i in range(model.nsensor):
    s_name = model.sensor(i).name
    if not s_name:
        continue
    sensor_slices[s_name] = (
        int(model.sensor_adr[i]),
        int(model.sensor_dim[i]),
        int(model.sensor_type[i]),
    )
# Kept for compatibility with the previous module-level name: total sensordata length.
off = int(model.nsensordata)

def sensor_sum(name: str) -> float:
    """Return the sum of every output component of sensor `name`.

    Reads the current `data.sensordata`; an unknown sensor name yields 0.0
    instead of raising.
    """
    entry = sensor_slices.get(name)
    if entry is None:
        return 0.0
    start, width, _sensor_type = entry
    return float(data.sensordata[start:start + width].sum())

# Fingertip touch sensor names, keyed by finger code (LF = little finger).
FINGERTIP_TOUCH = {
    "FF": "robot0:ST_Tch_fftip",
    "MF": "robot0:ST_Tch_mftip",
    "RF": "robot0:ST_Tch_rftip",
    "LF": "robot0:ST_Tch_lftip",
    "TH": "robot0:ST_Tch_thtip",
}
# Canonical finger order; theta's per-finger weights are indexed in this order.
FINGERS = list(FINGERTIP_TOUCH)
# Which fingertip touch sensors actually exist in the loaded model.
touch_present = {finger: sensor_name in sensor_slices
                 for finger, sensor_name in FINGERTIP_TOUCH.items()}
print("Fingertip touch present:", touch_present)

# Palm sensors (if any): every sensor whose name contains "palm", case-insensitively.
PALM_SENSOR_NAMES = [sensor_name for sensor_name in sensor_slices
                     if "palm" in sensor_name.lower()]

# -------------------------------
# ACTUATORS PER FINGER
# -------------------------------
def build_finger_actuator_map():
    """Group actuator indices by finger code.

    An actuator is assigned to finger XX when its name contains ":A_XX"
    (e.g. "robot0:A_FFJ3" -> "FF"). Actuators whose names match none of the
    five patterns are not included. Indices keep ascending actuator order.
    """
    groups = {code: [] for code in ("FF", "MF", "RF", "LF", "TH")}
    for act_id in range(model.nu):
        act_name = model.actuator(act_id).name
        for code, ids in groups.items():
            if f":A_{code}" in act_name:
                ids.append(act_id)
    return groups

FINGER_ACT = build_finger_actuator_map()
print("Actuators per finger:", {finger: len(ids) for finger, ids in FINGER_ACT.items()})

# -------------------------------
# CONTACT COUNT (object vs. any other body)
# -------------------------------
def count_object_contacts():
    """Count active contacts in which exactly one geom belongs to the object body.

    Self-contacts of the object are excluded. Contacts with ANY other body are
    counted, not only the hand's (e.g. a supporting surface, if the scene has one).
    """
    geom_body = model.geom_bodyid
    n_hits = 0
    for k in range(data.ncon):
        con = data.contact[k]
        side1_is_obj = int(geom_body[int(con.geom1)]) == obj_bid
        side2_is_obj = int(geom_body[int(con.geom2)]) == obj_bid
        # XOR: the object is on one side of the pair, something else on the other
        if side1_is_obj != side2_is_obj:
            n_hits += 1
    return n_hits

# -------------------------------
# RANDOMIZE OBJECT START (XY + yaw)
# -------------------------------
def set_object_pose_random(seed: int, xy_sigma=0.015, yaw_range=np.deg2rad(25)):
    """Perturb the object's free-joint pose in `data.qpos`, reproducibly from `seed`.

    The XY position gets Gaussian noise with std `xy_sigma`. The orientation is
    pre-multiplied by a rotation about the world +Z axis whose angle is drawn
    uniformly from [-yaw_range, yaw_range) radians. Only qpos is written;
    run_trial calls mj_forward afterwards.
    """
    gen = np.random.default_rng(seed)
    start = obj_qpos_adr
    pos_sl = slice(start, start + 3)
    quat_sl = slice(start + 3, start + 7)

    pos = data.qpos[pos_sl].copy()
    quat0 = data.qpos[quat_sl].copy()

    # Draw order matters for reproducibility: XY noise first, then yaw.
    pos[:2] += gen.normal(0, xy_sigma, size=2)
    yaw = float(gen.uniform(-yaw_range, yaw_range))

    yaw_quat = np.zeros(4, dtype=float)
    mujoco.mju_axisAngle2Quat(yaw_quat, np.array([0.0, 0.0, 1.0], dtype=float), yaw)
    new_quat = np.zeros(4, dtype=float)
    mujoco.mju_mulQuat(new_quat, yaw_quat, quat0)

    data.qpos[pos_sl] = pos
    data.qpos[quat_sl] = new_quat

# -------------------------------
# CONTROLLER (theta)
# theta = [wFF,wMF,wRF,wLF,wTH, th_touch, k_hold]
# -------------------------------
def controller_step(t_close, theta, latched):
    """Compute one control vector for the ramp-and-latch grasp controller.

    Args:
        t_close: step index within CLOSE (HOLD passes CLOSE_STEPS).
        theta: [wFF, wMF, wRF, wLF, wTH, th_touch, k_hold].
        latched: dict finger -> bool, updated IN PLACE. Once
            t_close >= MIN_CLOSE_STEPS_BEFORE_LATCH, a finger whose fingertip
            touch exceeds th_touch becomes latched and stays latched.

    Returns:
        Array of length model.nu. Every actuator of a finger receives
        clip(w_f * ramp, 0, 1), or clip(w_f * k_hold * ramp, 0, 1) once that
        finger is latched; ramp = t_close / CLOSE_STEPS clipped to [0, 1].
        Actuators not mapped to a finger get 0.
        NOTE(review): values are written straight into data.ctrl, so whether
        0..1 is a sensible command range depends on the actuators' ctrlrange in
        the XML; confirm.
    """
    weights = np.array(theta[:5], dtype=float)
    th_touch = float(theta[5])
    k_hold = float(theta[6])

    ramp = np.clip(t_close / max(1, CLOSE_STEPS), 0.0, 1.0)

    # Latching is disabled for the first MIN_CLOSE_STEPS_BEFORE_LATCH steps.
    if t_close >= MIN_CLOSE_STEPS_BEFORE_LATCH:
        for finger in FINGERS:
            if not touch_present[finger] or latched[finger]:
                continue
            if sensor_sum(FINGERTIP_TOUCH[finger]) > th_touch:
                latched[finger] = True

    ctrl = np.zeros(model.nu, dtype=float)
    for idx, finger in enumerate(FINGERS):
        level = k_hold * ramp if latched[finger] else ramp
        command = float(np.clip(float(weights[idx]) * level, 0.0, 1.0))
        ctrl[FINGER_ACT[finger]] = command
    return ctrl

# -------------------------------
# RUN ONE TRIAL (with metrics)
# -------------------------------
def run_trial(theta, seed, viewer=None, do_render=False,
              sleep_open=0.0, sleep_close=0.0, sleep_hold=0.0):
    """Run one OPEN -> CLOSE -> HOLD grasp episode and return its metrics.

    The simulation is reset, the object gets a seeded random XY/yaw offset and
    the hand is driven by `controller_step` with parameters `theta`.

    Args:
        theta: [wFF, wMF, wRF, wLF, wTH, th_touch, k_hold].
        seed: seed for the object start-pose randomization.
        viewer: optional passive viewer, synced every step when `do_render`.
        do_render: whether to sync the viewer and sleep after each step.
        sleep_open, sleep_close, sleep_hold: per-step pause in seconds for each
            phase while rendering.

    Returns:
        dict with
          - end-of-CLOSE snapshot: contacts, tip_active, tip_sum, palm_sum, z,
            touch_per_finger;
          - HOLD phase: hold_ratio, max_speed, max_drift, success (0/1);
          - whole episode: energy (sum of u^2), smooth (sum of squared
            control changes).
    """
    mujoco.mj_resetData(model, data)
    # set_object_pose_random only reads/writes qpos, so a single forward pass
    # after it is enough to make derived quantities consistent.
    set_object_pose_random(seed)
    mujoco.mj_forward(model, data)

    latched = {f: False for f in FINGERS}
    energy = 0.0
    smooth = 0.0
    prev_u = np.zeros(model.nu, dtype=float)

    def actuate(u):
        # Apply `u`, accumulate effort/smoothness costs and advance one step.
        nonlocal energy, smooth, prev_u
        data.ctrl[:] = u
        energy += float(np.sum(u * u))
        smooth += float(np.sum((u - prev_u) ** 2))
        prev_u = u
        mujoco.mj_step(model, data)

    def render(pause):
        # Push the current state to the viewer (when rendering) and pause.
        if viewer is not None and do_render:
            viewer.sync()
            time.sleep(pause)

    # ---- OPEN: all controls at zero ----
    for _ in range(OPEN_STEPS):
        actuate(np.zeros(model.nu, dtype=float))
        render(sleep_open)

    # ---- CLOSE: ramped, touch-latched control ----
    for t in range(1, CLOSE_STEPS + 1):
        actuate(controller_step(t_close=t, theta=theta, latched=latched))
        render(sleep_close)

    # ---- end-of-CLOSE snapshot ----
    # Fingertip touch is read after the final step, so it comes from the same
    # mj_step as the contact list counted right below.
    last_touch = {
        f: (sensor_sum(FINGERTIP_TOUCH[f]) if touch_present[f] else 0.0)
        for f in FINGERS
    }
    contacts = int(count_object_contacts())
    obj_pos = data.xpos[obj_bid].copy()
    z = float(obj_pos[2])

    tip_vals = np.array([last_touch[f] for f in FINGERS], dtype=float)
    tip_sum = float(np.sum(tip_vals))
    tip_active = int(np.sum(tip_vals > 0.0))
    palm_sum = float(sum(sensor_sum(n) for n in PALM_SENSOR_NAMES))

    # ---- HOLD: keep the end-of-ramp command, measure stability ----
    pos_ref = obj_pos.copy()
    stable_steps = 0
    max_speed = 0.0
    max_drift = 0.0

    for _ in range(HOLD_STEPS):
        actuate(controller_step(t_close=CLOSE_STEPS, theta=theta, latched=latched))

        p = data.xpos[obj_bid]
        # translational part of MuJoCo's com-based body velocity (cvel = [rot; lin])
        v = data.cvel[obj_bid][3:6]
        speed = float(np.linalg.norm(v))
        drift = float(np.linalg.norm(p - pos_ref))

        max_speed = max(max_speed, speed)
        max_drift = max(max_drift, drift)

        if (p[2] > Z_MIN) and (speed < V_MAX) and (drift < DRIFT_MAX):
            stable_steps += 1

        render(sleep_hold)

    # Guard HOLD_STEPS == 0, which would otherwise raise ZeroDivisionError.
    hold_ratio = stable_steps / float(HOLD_STEPS) if HOLD_STEPS > 0 else 0.0
    success = 1 if hold_ratio >= SUCCESS_HOLD_RATIO else 0

    return {
        "contacts": contacts,
        "tip_active": tip_active,
        "tip_sum": tip_sum,
        "palm_sum": palm_sum,
        "z": z,
        "hold_ratio": float(hold_ratio),
        "max_speed": float(max_speed),
        "max_drift": float(max_drift),
        "energy": float(energy),
        "smooth": float(smooth),
        "success": int(success),
        "touch_per_finger": {f: float(last_touch[f]) for f in FINGERS},
    }

# -------------------------------
# SCORE FUNCTION (higher better)
# -------------------------------
def score_metrics(m):
    """Collapse one trial's metrics dict (as returned by run_trial) into a scalar score.

    Rewards success, hold ratio, object contacts and active fingertips, plus a
    little-finger (LF) bonus/penalty. Subtracts weighted drift, control energy
    and control roughness. Higher is better.
    """
    lf_touch = m["touch_per_finger"]["LF"]

    score = 0.0
    # primary objective: a stable hold
    score += 2.0 * m["success"]
    score += 1.0 * m["hold_ratio"]
    # grasp-quality shaping
    score += 0.05 * m["contacts"]
    score += 0.03 * m["tip_active"]
    # little-finger priority: proportional bonus, then active bonus / no-touch penalty
    score += BONUS_LF_TOUCH * lf_touch
    score += BONUS_LF_ACTIVE if lf_touch > 0.0 else -PENALTY_NO_LF
    # penalties, applied in a fixed order
    for weight, key in ((LAMBDA_DRIFT, "max_drift"),
                        (LAMBDA_ENERGY, "energy"),
                        (LAMBDA_SMOOTH, "smooth")):
        score -= weight * m[key]
    return float(score)

def eval_theta(theta, base_seed):
    """Score `theta` as the mean over EVAL_TRIALS_PER_THETA headless trials.

    Trial k uses seed `base_seed + k`. Returns (mean_score, list_of_metrics_dicts).
    """
    trial_metrics = [
        run_trial(theta=theta, seed=base_seed + k, viewer=None, do_render=False)
        for k in range(EVAL_TRIALS_PER_THETA)
    ]
    trial_scores = [score_metrics(tm) for tm in trial_metrics]
    return float(np.mean(trial_scores)), trial_metrics

# -------------------------------
# BASELINES + RANDOM SEARCH
# -------------------------------
rng = np.random.default_rng(123)

# Every candidate (both baselines and all random samples) is scored on the SAME
# object start poses (common random numbers). With a different seed range per
# candidate, the `>` comparisons below would mix controller quality with how
# lucky each candidate's start poses were, and taking the max over many
# candidates would systematically favour lucky seeds. The visualization further
# down uses a disjoint seed range (SEED0 + 50000 + t).
EVAL_BASE_SEED = SEED0 + 10000

baseline_A = np.array([1.0, 1.0, 1.0, max(1.0, LF_MIN), 1.0, 0.010, 0.30], dtype=float)  # uniform weights, LF >= min
baseline_B = np.array([1.2, 0.9, 0.9, max(1.15, LF_MIN), 1.4, 0.010, 0.30], dtype=float)  # stronger FF+TH, plus LF

print("\n--- Evaluate baselines ---")
sA, mA = eval_theta(baseline_A, EVAL_BASE_SEED)
sB, mB = eval_theta(baseline_B, EVAL_BASE_SEED)
print("Baseline A theta:", baseline_A, "score:", round(sA, 4))
print("Baseline B theta:", baseline_B, "score:", round(sB, 4))

best_theta = baseline_A.copy()
best_score = sA
best_pack  = mA
if sB > best_score:
    best_theta, best_score, best_pack = baseline_B.copy(), sB, mB

# wLF is sampled directly from the admissible range, so the LF_MIN constraint
# holds by construction instead of silently discarding iterations.
lf_low = max(LF_RANGE[0], LF_MIN)

print("\n--- Random search (LF constrained) ---")
for it in range(1, RANDOM_ITERS + 1):
    w = rng.uniform(W_RANGE[0], W_RANGE[1], size=5)
    w[3] = rng.uniform(lf_low, LF_RANGE[1])   # LF index = 3, separate range

    th_touch = rng.uniform(TH_TOUCH_RANGE[0], TH_TOUCH_RANGE[1])
    k_hold   = rng.uniform(K_HOLD_RANGE[0], K_HOLD_RANGE[1])
    theta = np.array([*w, th_touch, k_hold], dtype=float)

    s, pack = eval_theta(theta, EVAL_BASE_SEED)
    if s > best_score:
        best_theta, best_score, best_pack = theta, s, pack
        print(f"[NEW BEST] it={it:03d} score={best_score:.4f} theta={best_theta}")

print("\n==============================")
print(" BEST RESULT")
print("==============================")
print("best_score:", round(best_score, 4))
print("best_theta [wFF,wMF,wRF,wLF,wTH, th_touch, k_hold] =")
print(best_theta)

# -------------------------------
# SHOW BEST IN VIEWER (5 trials, visual)
# -------------------------------
# Passive viewer: this script keeps stepping `data` itself and calls
# viewer.sync() (inside run_trial and the idle loop) to refresh the window.
viewer = mujoco.viewer.launch_passive(model, data)

def print_trial_line(trial_idx, m):
    """Print a three-line human-readable report of one trial's metrics dict `m`."""
    tp = m["touch_per_finger"]
    report = [
        f"TRIAL {trial_idx}: contacts={m['contacts']} | tip_active={m['tip_active']} | tip_sum={m['tip_sum']:.4f} | palm_sum={m['palm_sum']:.4f} | z={m['z']:.3f}",
        f"        hold_ratio={m['hold_ratio']:.2f} | max_speed={m['max_speed']:.3f} | max_drift={m['max_drift']:.3f} | energy={m['energy']:.1f} | smooth={m['smooth']:.1f} | SUCCESS={m['success']}",
        f"        touch FF={tp['FF']:.3f} MF={tp['MF']:.3f} RF={tp['RF']:.3f} LF={tp['LF']:.3f} TH={tp['TH']:.3f}",
    ]
    print("\n".join(report))

visual_trials = 5
print(f"\n--- Visualize best_theta ({visual_trials} trials) ---")
all_visual = []
for t in range(1, visual_trials + 1):
    # Stop running further rendered trials (with their sleeps) once the user
    # has closed the viewer window.
    if not viewer.is_running():
        break
    m = run_trial(theta=best_theta, seed=SEED0 + 50000 + t, viewer=viewer, do_render=True,
                  sleep_open=0.003, sleep_close=0.003, sleep_hold=0.006)
    print_trial_line(t, m)
    all_visual.append(m)

print("\n==============================")
print(" VISUAL SUMMARY")
print("==============================")
if all_visual:
    succ_rate = float(np.mean([m["success"] for m in all_visual]))
    avg_contacts = float(np.mean([m["contacts"] for m in all_visual]))
    avg_tip_active = float(np.mean([m["tip_active"] for m in all_visual]))
    avg_energy = float(np.mean([m["energy"] for m in all_visual]))
    avg_lf_touch = float(np.mean([m["touch_per_finger"]["LF"] for m in all_visual]))
    print(f"Success rate: {succ_rate:.2f}")
    print(f"Avg contacts after close: {avg_contacts:.2f}")
    print(f"Avg active fingertips: {avg_tip_active:.2f}")
    print(f"Avg ctrl energy (sum u^2): {avg_energy:.1f}")
    print(f"Avg LF touch (after close snapshot): {avg_lf_touch:.3f}")
else:
    # np.mean([]) would yield nan plus a RuntimeWarning; report explicitly instead.
    print("No visual trials completed (viewer was closed).")

print("\nAll trials finished.")
print("MuJoCo window will stay open.")
print("Close the viewer window manually to exit.")

# Keep simulating with the last control vector until the window is closed.
while viewer.is_running():
    mujoco.mj_step(model, data)
    viewer.sync()
    time.sleep(0.02)
