In [None]:
# Environment setup (Colab-style). Versions are unpinned; the last recorded run
# resolved JAX 0.5.3 and MuJoCo 3.3.6.
# NOTE(review): pin versions (e.g. `pkg==x.y.z`) so Restart & Run All is reproducible.
%pip install -q --upgrade pip
# JAX with CUDA 12 support. NOTE(review): recent JAX releases publish CUDA wheels
# on PyPI, so the -f find-links index may no longer be needed — confirm.
%pip install -q "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
%pip install -q optax matplotlib
# MJX, the JAX backend of MuJoCo (used as `from mujoco import mjx`).
%pip install -q mujoco-mjx
# Signal temporal logic (STL) utilities (used as stljax.formula / stljax.viz).
%pip install -q stljax
import os
import subprocess
from pathlib import Path

repo_url = "https://github.com/LiHuaqing-tum/Meta-Learning-for-control-Semesterproject-.git"
repo_name = "Meta-Learning-for-control-Semesterproject-"

# Resolve the checkout directory once, *before* changing into it. The previous
# version looked for `repo_name` relative to the current directory and then
# cd-ed into it, so every re-run from inside the checkout cloned a fresh nested
# copy (hence the deeply nested working directory in the recorded output).
_cwd = Path.cwd()
if _cwd.name == repo_name and (_cwd / ".git").is_dir():
    repo_dir = _cwd  # cell is being re-run from inside the checkout
else:
    repo_dir = _cwd / repo_name

if repo_dir.is_dir():
    # Best-effort refresh: a failed pull (e.g. offline) keeps the existing checkout.
    _pull = subprocess.run(["git", "-C", str(repo_dir), "pull", "origin", "main"])
    if _pull.returncode != 0:
        print(f"warning: git pull failed (exit {_pull.returncode}); using existing checkout")
else:
    # Without a checkout nothing downstream can run, so a failed clone is fatal.
    subprocess.run(["git", "clone", repo_url, str(repo_dir)], check=True)

os.chdir(repo_dir)
print(Path.cwd())



Cloning into 'Meta-Learning-for-control-Semesterproject-'...
remote: Enumerating objects: 165, done.[K
remote: Counting objects: 100% (165/165), done.[K
remote: Compressing objects: 100% (144/144), done.[K
remote: Total 165 (delta 33), reused 121 (delta 16), pack-reused 0 (from 0)[K
Receiving objects: 100% (165/165), 4.84 MiB | 3.36 MiB/s, done.
Resolving deltas: 100% (33/33), done.
/content/Meta-Learning-for-control-Semesterproject-/Meta-Learning-for-control-Semesterproject-/Meta-Learning-for-control-Semesterproject-/Meta-Learning-for-control-Semesterproject-/Meta-Learning-for-control-Semesterproject-/Meta-Learning-for-control-Semesterproject-/Meta-Learning-for-control-Semesterproject-/Meta-Learning-for-control-Semesterproject-/Meta-Learning-for-control-Semesterproject-/Meta-Learning-for-control-Semesterproject-


In [None]:
# Smoke test: importing the core packages confirms the installs succeeded.
import jax
import mujoco
import optax  # import-only check; not used in this cell

print("✅ Setup complete!")
print("JAX version:", jax.__version__)
print("MuJoCo version:", mujoco.__version__)


✅ Setup complete!
JAX version: 0.5.3
MuJoCo version: 3.3.6


In [None]:
import sys
import pathlib

import jax
import jax.numpy as jnp
from mujoco import mjx

import matplotlib.pyplot as plt
from stljax.formula import *
from stljax.viz import *
import optax
import functools

In [None]:

from pathlib import Path

# Paths are resolved relative to the repository checkout, which the setup cell
# makes the current working directory.
repo_root = Path.cwd()
xml_path = repo_root / "task" / "scenes" / "panda_push_scene_without_obstacle.xml"
# Fail fast with an actionable message if the working directory is not the repo
# root (e.g. the setup cell was skipped or cd-ed somewhere else).
if not xml_path.is_file():
    raise FileNotFoundError(
        f"Scene file not found: {xml_path}. Run the setup cell so the working "
        "directory is the repository root."
    )
from src.environment.SImpleReacherEnv import SimpleReacherEnv


In [None]:
def sdist_circle_2d(p_xy, center_xy, radius):
    """Signed distance from point(s) ``p_xy`` to a circle of ``radius`` about ``center_xy``.

    Negative inside the circle, zero on its boundary, positive outside. The norm
    is taken over the last axis, so a batch of points yields one distance per point.
    """
    offset = p_xy - center_xy
    return jnp.linalg.norm(offset, axis=-1) - radius

In [None]:
import jax
import jax.numpy as jnp
# Workspace geometry in the table (x, y) plane; the *_H arrays are presumably
# half-extents of square regions — TODO confirm against the scene XML.
# NOTE(review): none of these constants are referenced by the cells shown here;
# spec_robustness takes the obstacle/target centres from `env` and uses its own
# margins. Confirm whether these radii should feed safe_margin / inside_tol.
GREEN_C = jnp.array([0.0, 0.90])
GREEN_H = jnp.array([0.08, 0.08])
R_OBS   = jnp.linalg.norm(GREEN_H)        # half-diagonal: circle enclosing the green square

RED_C   = jnp.array([0.0, 1.40])
RED_H   = jnp.array([0.15, 0.15])
R_GOAL  = jnp.minimum(RED_H[0], RED_H[1])  # smaller half-extent: circle inside the red square

BLOCK_H = jnp.array([0.06, 0.06])
R_BLOCK = jnp.linalg.norm(BLOCK_H)         # half-diagonal of the block footprint

# Reuse xml_path from the path-setup cell so the scene file is named in one place.
env = SimpleReacherEnv(
    model_path=str(xml_path),
    return_full_qpos_qvel=False,
)
# The reset's return value is discarded; the simulation state used as the rollout
# start is read from env.d afterwards (presumably populated by reset — confirm).
_ = env.reset(jax.random.PRNGKey(0))
initial_data = env.d


obstacle_center_xy = env.obstacle_center_xy
target_center_xy = env.target_center_xy
def safe_norm(v, eps=1e-8):
    """Euclidean norm over the last axis, floored at ``sqrt(eps)``.

    The floor keeps the value and its gradient finite when ``v`` is zero.
    """
    squared = jnp.sum(v * v, axis=-1)
    return jnp.sqrt(jnp.maximum(squared, eps))

def avoid_green_signal_circ_circ(traj_xy, obstacle_center_xy, safe_margin, eps=1e-8):
    """Avoidance predicate: positive when a point is farther than ``safe_margin``
    from ``obstacle_center_xy`` (distance taken over the last axis)."""
    clearance = safe_norm(traj_xy - obstacle_center_xy, eps=eps)
    return clearance - safe_margin

def reach_red_signal_circ_goal(traj_xy, target_center_xy, inside_tol, eps=1e-8):
    """Reach predicate: positive when a point lies within ``inside_tol`` of
    ``target_center_xy`` (distance taken over the last axis)."""
    offset = traj_xy - target_center_xy
    return inside_tol - safe_norm(offset, eps=eps)

def softmax_reduce_max(x, beta=50.0):
    """Smooth lower approximation of ``max(x)`` via log-mean-exp.

    Computes ``max(x) + log(mean(exp(beta * (x - max(x))))) / beta``. Subtracting
    the max keeps ``exp`` from overflowing. Because it averages rather than sums,
    the result lies in ``[max(x) - log(n)/beta, max(x)]``.
    """
    peak = jnp.max(x)
    scaled = jnp.exp(beta * (x - peak))
    return peak + jnp.log(jnp.mean(scaled)) / beta

def softmin_reduce_min(x, beta=50.0):
    """Smooth upper approximation of ``min(x)``: the negated soft-max of ``-x``."""
    return -softmax_reduce_max(-x, beta)

def spec_robustness(traj_xy, obstacle_center_xy, target_center_xy,
                    safe_margin=0.02, inside_tol=0.02, H: int = 1, beta: float = 50.0):
    """Smooth STL robustness of  (□ avoid) ∧ (◇ reach)  over steps 1..H.

    Args:
        traj_xy: xy positions indexed by time along axis 0. Index 0 (the initial
            state) is excluded from the evaluation window.
        obstacle_center_xy: centre used by the avoid predicate.
        target_center_xy: centre used by the reach predicate.
        safe_margin: required distance from the obstacle centre.
        inside_tol: distance from the target centre that counts as "reached".
        H: number of steps after the initial state to evaluate.
        beta: sharpness of the soft min/max (larger -> closer to hard min/max).
            The default 50.0 matches the previously hard-coded value.

    Returns:
        Scalar robustness; positive means the specification is (softly) satisfied.

    Raises:
        ValueError: if the window ``traj_xy[1:H+1]`` is empty (H < 1 or a
            trajectory with fewer than 2 samples).

    NOTE(review): the avoid predicate measures point-to-centre distance only, so
    with the default safe_margin=0.02 it ignores the obstacle and block extents.
    Confirm whether the margin should include their radii.
    """
    window = traj_xy[1:H + 1]
    # Shapes are static under jit, so this check runs at trace time.
    if window.shape[0] == 0:
        raise ValueError(
            f"spec_robustness: empty evaluation window (H={H}, trajectory length "
            f"{traj_xy.shape[0]}); need H >= 1 and at least 2 samples."
        )

    avoid_t = jax.vmap(lambda p: avoid_green_signal_circ_circ(p, obstacle_center_xy, safe_margin))(window)
    reach_t = jax.vmap(lambda p: reach_red_signal_circ_goal(p, target_center_xy, inside_tol))(window)

    rho_always = softmin_reduce_min(avoid_t, beta=beta)       # □ always
    rho_eventually = softmax_reduce_max(reach_t, beta=beta)   # ◇ eventually
    return jnp.minimum(rho_always, rho_eventually)            # ∧ and


horizon: int = 5  # simulation steps per rollout; also the STL evaluation window H

In [None]:
# Dimension sanity check. The values are read directly from `env` because
# obs_dim / act_dim are only assigned in the policy-definition cell further
# down; printing those names here raised NameError on a fresh kernel.
print(env.get_observation_space()[0])
print(env.get_action_space()[0])

8
7


In [None]:
# Actuator control limits, read from `env` because tau_lo / tau_hi are only
# defined in a later cell (this cell previously failed on Restart & Run All).
# Mirrors that cell's convention: None when the model declares no ctrlrange.
_ctrlrange = env.ctrlrange_jnp
print(_ctrlrange[:, 0] if _ctrlrange.size else None)
print(_ctrlrange[:, 1] if _ctrlrange.size else None)

[-87. -87. -87. -87. -12. -12. -12.]
[87. 87. 87. 87. 12. 12. 12.]


In [None]:
# Index of the site whose position is used as the block position
# (block_xy_from_data reads data.site_xpos[env.site_id_block, :2]).
print(env.site_id_block)

0


In [None]:

# MLP policy layout: observation -> 64 -> 64 -> action.
obs_dim, act_dim = env.get_observation_space()[0], env.get_action_space()[0]
layer_sizes = [obs_dim] + [64, 64] + [act_dim]

def init_layer_params(key, in_dim, out_dim):
    """Initialise one dense layer as ``(W, b)``.

    ``W`` has shape ``(out_dim, in_dim)`` (the layer computes ``W @ x + b``) and is
    drawn from a zero-mean normal with Glorot/Xavier std ``sqrt(2 / (in_dim + out_dim))``.
    ``b`` starts at zero.
    """
    weight_key, _ = jax.random.split(key)
    std = jnp.sqrt(2.0 / (in_dim + out_dim))
    weights = std * jax.random.normal(weight_key, (out_dim, in_dim))
    biases = jnp.zeros((out_dim,))
    return weights, biases

def init_policy_params(key, sizes):
    """Build a list of ``(W, b)`` layers, one per consecutive pair in ``sizes``.

    Each layer is initialised from its own subkey of ``key``.
    """
    layer_keys = jax.random.split(key, len(sizes) - 1)
    params = []
    for layer_key, fan_in, fan_out in zip(layer_keys, sizes[:-1], sizes[1:]):
        params.append(init_layer_params(layer_key, fan_in, fan_out))
    return params

policy_key = jax.random.PRNGKey(42)  # fixed seed -> reproducible initial weights
policy_params = init_policy_params(key=policy_key, sizes=layer_sizes)


# Actuator control limits from the model. Both stay None when the model declares
# no ctrlrange, in which case policy_apply returns its unscaled tanh output.
if env.ctrlrange_jnp.size:
    tau_lo, tau_hi = env.ctrlrange_jnp[:, 0], env.ctrlrange_jnp[:, 1]
else:
    tau_lo = tau_hi = None

def policy_apply(params, obs):
    """MLP policy with tanh hidden layers and a tanh output layer.

    When control limits exist (``tau_lo`` is not None), the output in [-1, 1] is
    mapped affinely onto ``[tau_lo, tau_hi]``. Otherwise the raw tanh output is returned.
    """
    *hidden_layers, (W_out, b_out) = params
    h = obs
    for W, b in hidden_layers:
        h = jnp.tanh(W @ h + b)
    squashed = jnp.tanh(W_out @ h + b_out)
    if tau_lo is None:
        return squashed
    # Map [-1, 1] -> [tau_lo, tau_hi].
    return tau_lo + 0.5 * (squashed + 1.0) * (tau_hi - tau_lo)


def block_xy_from_data(data: mjx.Data) -> jnp.ndarray:
    """Planar (x, y) position of the block site (``env.site_id_block``) in ``data``."""
    block_pos = data.site_xpos[env.site_id_block]
    return block_pos[:2]

def rollout_block_xy(params, data0):
    """Simulate ``horizon`` MJX steps under the policy and record the block's path.

    Returns:
        xy_traj:  (horizon + 1, 2) block xy positions, including the initial state.
        obs_traj: (horizon + 1, obs_dim) observations, including the initial one.
        act_seq:  per-step actions stacked along axis 0 (length ``horizon``).
    """
    def step(carry, _unused):
        data, obs = carry
        action = policy_apply(params, obs)
        next_data = mjx.step(env.m, data.replace(ctrl=action))
        next_obs = env._obs_from_data(next_data)
        return (next_data, next_obs), (next_obs, action, block_xy_from_data(next_data))

    init_obs = env._obs_from_data(data0)
    init_xy = block_xy_from_data(data0)
    # The step ignores its per-iteration input, so iterate a fixed number of times.
    _, (obs_seq, act_seq, xy_seq) = jax.lax.scan(
        step, (data0, init_obs), None, length=horizon
    )

    obs_traj = jnp.concatenate([init_obs[None], obs_seq], axis=0)
    xy_traj = jnp.concatenate([init_xy[None], xy_seq], axis=0)
    return xy_traj, obs_traj, act_seq


def loss_fn(params):
    """Hinge loss on STL robustness: 0 once the spec is satisfied, else -robustness.

    Rolls the policy out from the module-level ``initial_data`` and scores the
    block's xy trajectory over the full ``horizon``.
    """
    block_xy, _obs, _acts = rollout_block_xy(params, initial_data)
    robustness = spec_robustness(
        block_xy, obstacle_center_xy, target_center_xy, H=horizon
    )
    return jax.nn.relu(-robustness)


In [None]:
from jax import jacfwd
from jax.flatten_util import ravel_pytree
import optax

In [None]:
# Adam with a fixed learning rate of 1e-3; state initialised from the starting policy.
optimizer = optax.adam(learning_rate=1e-3)
opt_state = optimizer.init(policy_params)

@jax.jit
def update(params, opt_state):
    """One Adam step on ``loss_fn``; returns ``(new_params, new_opt_state, loss)``.

    Gradients use forward-mode ``jacfwd`` on the flattened parameter vector.
    The loss is returned as ``jacfwd``'s auxiliary output, which avoids tracing a
    second, separate forward rollout just to report it. (The previous version
    called ``loss_fn`` once more outside the Jacobian.)

    NOTE(review): forward mode costs one JVP per parameter (~5k for the
    8-64-64-7 MLP), which is slow. It was presumably chosen because reverse mode
    through ``mjx.step`` was not usable — confirm before switching to ``jax.grad``.
    NOTE(review): the recorded run shows ||Δθ|| ~1e-15 per step, i.e. the
    gradient is effectively zero; the block may not move within ``horizon`` steps.
    """
    flat_params, unravel = ravel_pytree(params)

    def loss_flat(p):
        loss = loss_fn(unravel(p))
        return loss, loss  # (value to differentiate, aux copy of that value)

    grad_flat, loss_val = jacfwd(loss_flat, has_aux=True)(flat_params)
    grads = unravel(grad_flat)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_val

prev_flat, _ = ravel_pytree(policy_params)
num_steps = 2000
for step in range(1, num_steps + 1):
    policy_params, opt_state, train_loss = update(policy_params, opt_state)

    # ---- Monitor how much the parameters actually move each step ----
    flat, _ = ravel_pytree(policy_params)
    delta   = flat - prev_flat
    d_norm  = float(jnp.linalg.norm(delta))
    d_max   = float(jnp.max(jnp.abs(delta)))
    w_norm  = float(jnp.linalg.norm(flat))

    # Probe a couple of first-layer parameters to see whether individual values change.
    W0, b0 = policy_params[0]  # first layer (W, b)
    probe_w = float(W0[0, 0])
    probe_b = float(b0[0])

    # train_loss = relu(-robustness). The old label printed -train_loss as
    # "robustness", but -train_loss is min(robustness, 0), so it never showed a
    # positive (satisfied) robustness. Report the loss itself instead.
    print(f"step {step:04d} | STL loss {float(train_loss):.4f} | "
          f"||Δθ||2={d_norm:.3e}  max|Δ|={d_max:.3e}  ||θ||2={w_norm:.3e}  "
          f"W0[0,0]={probe_w:.5f}  b0[0]={probe_b:.5f}")

    prev_flat = flat  # baseline for the next step's Δθ


step 0001 | STL loss -0.3800 robustness | ||Δθ||2=3.727e-15  max|Δ|=1.041e-15  ||θ||2=9.509e+00  W0[0,0]=-0.03669  b0[0]=-0.00000
step 0002 | STL loss -0.3800 robustness | ||Δθ||2=1.850e-15  max|Δ|=4.996e-16  ||θ||2=9.509e+00  W0[0,0]=-0.03669  b0[0]=-0.00000
step 0003 | STL loss -0.3800 robustness | ||Δθ||2=1.816e-15  max|Δ|=4.996e-16  ||θ||2=9.509e+00  W0[0,0]=-0.03669  b0[0]=-0.00000
step 0004 | STL loss -0.3800 robustness | ||Δθ||2=1.908e-15  max|Δ|=6.939e-16  ||θ||2=9.509e+00  W0[0,0]=-0.03669  b0[0]=-0.00000
step 0005 | STL loss -0.3800 robustness | ||Δθ||2=1.886e-15  max|Δ|=5.274e-16  ||θ||2=9.509e+00  W0[0,0]=-0.03669  b0[0]=-0.00000
step 0006 | STL loss -0.3800 robustness | ||Δθ||2=1.519e-15  max|Δ|=4.441e-16  ||θ||2=9.509e+00  W0[0,0]=-0.03669  b0[0]=-0.00000
step 0007 | STL loss -0.3800 robustness | ||Δθ||2=1.520e-15  max|Δ|=3.886e-16  ||θ||2=9.509e+00  W0[0,0]=-0.03669  b0[0]=-0.00000
step 0008 | STL loss -0.3800 robustness | ||Δθ||2=1.554e-15  max|Δ|=4.718e-16  ||θ||2=9.50

KeyboardInterrupt: 