In [None]:
# Install runtime dependencies into the kernel's environment (versions are not pinned).
%pip install -q --upgrade pip
# NOTE(review): recent `jax[cuda12]` wheels may not need the `-f` find-links URL — confirm.
%pip install -q "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
%pip install -q optax matplotlib
%pip install -q mujoco-mjx
%pip install -q stljax
import os

# Project repository: cloned on the first run, updated with `git pull` on later runs.
repo_url = "https://github.com/LiHuaqing-tum/Meta-Learning-for-control-Semesterproject-.git"
repo_name = "Meta-Learning-for-control-Semesterproject-"

# `repo_name` is checked relative to the current working directory.
if not os.path.exists(repo_name):
    !git clone {repo_url}
else:
    # Enter the existing checkout, pull `main`, then return to the starting directory.
    %cd {repo_name}
    !git pull origin main
    %cd ..

# Work from inside the repository for the remaining cells.
# NOTE(review): after this `%cd`, re-running the cell looks for `repo_name` *inside* the
# repository and would clone a nested copy; restart the kernel (or `%cd ..`) before re-running.
%cd {repo_name}

In [None]:
import jax
import mujoco
import optax

# Sanity check: the freshly installed packages import and report their versions.
print("✅ Setup complete!")
print(f"JAX version: {jax.__version__}")
print(f"MuJoCo version: {mujoco.__version__}")



In [None]:
import functools
import pathlib
import sys

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax
from jax.flatten_util import ravel_pytree
from mujoco import mjx

# NOTE(review): these star imports supply the STL classes used below (Predicate, Always,
# Eventually); prefer importing those names explicitly once the needed set is confirmed.
from stljax.formula import *
from stljax.viz import *

In [None]:
def sdist_circle_2d(p_xy, center_xy, radius):
    """Signed distance from 2-D point(s) to a circle's boundary.

    Negative inside the circle, zero on the boundary, positive outside. The norm is
    taken over the last axis, so ``p_xy`` may be a single point or a batch of points.
    """
    offset = p_xy - center_xy
    dist_to_center = jnp.linalg.norm(offset, axis=-1)
    return dist_to_center - radius

In [None]:

# Green obstacle: centre and half-extents; R_OBS is the radius of the circle that
# circumscribes its square footprint.
GREEN_C, GREEN_H = jnp.array([0.0, 0.90]), jnp.array([0.08, 0.08])
R_OBS = jnp.linalg.norm(GREEN_H)

# Red goal region: centre and half-extents; R_GOAL is its inscribed radius
# (the smaller half-extent).
RED_C, RED_H = jnp.array([0.0, 1.40]), jnp.array([0.15, 0.15])
R_GOAL = jnp.min(RED_H)

# Pushed block: half-extents and circumscribed radius.
BLOCK_H = jnp.array([0.06, 0.06])
R_BLOCK = jnp.linalg.norm(BLOCK_H)


def avoid_green_signal_circ_circ(states_xy, obstacle_center_xy, safe_margin=0.02):
    """Clearance signal for the avoid-obstacle predicate.

    The block and the obstacle are treated as circles of radius R_BLOCK and R_OBS.
    The signal is positive when their centres are further apart than
    R_BLOCK + R_OBS + safe_margin, and negative when they are closer.
    """
    centre_gap = jnp.linalg.norm(states_xy - obstacle_center_xy, axis=-1)
    return centre_gap - (R_BLOCK + R_OBS + safe_margin)


def reach_red_signal_circ_goal(states_xy, target_center_xy, inside_tol=0.02):
    """Containment signal for the reach-goal predicate.

    The allowed centre offset is R_GOAL - (R_BLOCK + inside_tol), clamped at 0. The
    signal is positive when the block centre is closer to the target centre than that
    offset. If the clamp is active, the signal can never exceed 0.
    """
    allowed_offset = jnp.maximum(R_GOAL - (R_BLOCK + inside_tol), 0.0)
    centre_gap = jnp.linalg.norm(states_xy - target_center_xy, axis=-1)
    return allowed_offset - centre_gap

In [None]:
def find_repo_root(start: pathlib.Path) -> pathlib.Path:
    """Return the closest directory at or above ``start`` containing both ``src/`` and ``task/``.

    The setup cell ``%cd``s into the cloned repository, so the working directory can be the
    repository root itself. It may also be a sub-folder of it, presumably when the notebook
    is run from a local checkout. Falls back to ``start.parent`` (the previous hard-coded
    behaviour) when no such directory is found.
    """
    for candidate in (start, *start.parents):
        if (candidate / "src").is_dir() and (candidate / "task").is_dir():
            return candidate
    return start.parent


repo_root = find_repo_root(pathlib.Path.cwd().resolve())
if str(repo_root) not in sys.path:
    sys.path.append(str(repo_root))

from src.environment.SImpleReacherEnv import SimpleReacherEnv


env = SimpleReacherEnv(
    model_path=str(repo_root / "task" / "scenes" / "panda_push_scene_without_obstacle.xml"),
    return_full_qpos_qvel=False,
)
_ = env.reset(jax.random.PRNGKey(0))
# MJX state right after reset; the loss below starts every rollout from it.
initial_data = env.d

# Obstacle and target centres as reported by the environment; the STL predicates use these.
obstacle_center_xy = env.obstacle_center_xy
target_center_xy = env.target_center_xy

# STL predicates that capture the current obstacle/target centers
def build_spec(obstacle_center_xy, target_center_xy, safe_margin=0.02, inside_tol=0.02):
    """Build the two atomic STL formulas for the push task.

    The predicate functions close over the given centres and tolerances.

    Returns:
        (avoid_atom, reach_atom): the formulas ``avoid_green_margin > 0`` and
        ``reach_red_margin > 0``.
    """
    def avoid_signal(traj_xy):
        return avoid_green_signal_circ_circ(traj_xy, obstacle_center_xy, safe_margin)

    def reach_signal(traj_xy):
        return reach_red_signal_circ_goal(traj_xy, target_center_xy, inside_tol)

    avoid_atom = Predicate("avoid_green_margin", avoid_signal) > 0.0
    reach_atom = Predicate("reach_red_margin", reach_signal) > 0.0
    return avoid_atom, reach_atom

horizon = 150  # number of discrete control steps per rollout (trajectories hold horizon + 1 samples)

avoid_atom, reach_atom = build_spec(obstacle_center_xy, target_center_xy)

# Task: stay clear of the obstacle throughout the window AND reach the goal at some point in it.
spec = Always(avoid_atom, interval=(0, horizon)) & Eventually(reach_atom, interval=(0, horizon))






In [None]:

# Policy network widths: observation -> 64 -> 64 -> action.
obs_dim = env.get_observation_space()[0]
act_dim = env.get_action_space()[0]
layer_sizes = [obs_dim] + [64, 64] + [act_dim]

def init_layer_params(key, in_dim, out_dim):
    """Initialise one dense layer with Glorot-normal weights and zero biases.

    Returns ``(W, b)``, where ``W`` has shape ``(out_dim, in_dim)`` so that the layer
    computes ``W @ x + b``.
    """
    weight_key, _bias_key = jax.random.split(key)  # bias key unused: biases start at zero
    std = jnp.sqrt(2.0 / (in_dim + out_dim))
    weights = std * jax.random.normal(weight_key, (out_dim, in_dim))
    return weights, jnp.zeros((out_dim,))

def init_policy_params(key, sizes):
    """Initialise an MLP as a list of ``(W, b)`` pairs, one per consecutive pair in ``sizes``."""
    layer_keys = jax.random.split(key, len(sizes) - 1)
    in_out_dims = zip(sizes[:-1], sizes[1:])
    return [
        init_layer_params(layer_key, d_in, d_out)
        for layer_key, (d_in, d_out) in zip(layer_keys, in_out_dims)
    ]

policy_key = jax.random.PRNGKey(42)
policy_params = init_policy_params(policy_key, layer_sizes)

# Actuator limits used to rescale the policy's tanh output; None when ctrlrange is empty.
if env.ctrlrange_jnp.size:
    ctrl_lo, ctrl_hi = env.ctrlrange_jnp[:, 0], env.ctrlrange_jnp[:, 1]
else:
    ctrl_lo = ctrl_hi = None

def policy_apply(params, obs):
    """Evaluate the MLP policy on one observation.

    Hidden layers use tanh. The output layer is tanh-squashed into (-1, 1) and, when
    control limits are available (``ctrl_lo`` is not None), mapped affinely onto
    ``[ctrl_lo, ctrl_hi]``.
    """
    hidden_layers, (W_out, b_out) = params[:-1], params[-1]
    h = obs
    for W, b in hidden_layers:
        h = jnp.tanh(W @ h + b)
    squashed = jnp.tanh(W_out @ h + b_out)
    if ctrl_lo is None:
        return squashed
    return ctrl_lo + 0.5 * (squashed + 1.0) * (ctrl_hi - ctrl_lo)

def block_xy_from_data(data: mjx.Data) -> jnp.ndarray:
    """Return the first two coordinates (x, y) of the block site's world position in ``data``."""
    block_xy = data.site_xpos[env.site_id_block, 0:2]
    return block_xy

def rollout_block_xy(params, data0):
    """Simulate the closed-loop policy for ``horizon`` MJX steps starting from ``data0``.

    Each step feeds the current observation to ``policy_apply``, writes the action into
    ``ctrl`` and advances the physics with ``mjx.step``.

    Returns:
        xy_traj:  block (x, y) positions, shape (horizon + 1, 2); row 0 comes from ``data0``.
        obs_traj: observations, shape (horizon + 1, obs_dim); row 0 comes from ``data0``.
        act_seq:  the ``horizon`` applied actions, stacked along axis 0.
    """
    def _advance(state, _):
        data, obs = state
        action = policy_apply(params, obs)
        next_data = mjx.step(env.m, data.replace(ctrl=action))
        next_obs = env._obs_from_data(next_data)
        return (next_data, next_obs), (next_obs, action, block_xy_from_data(next_data))

    start_obs = env._obs_from_data(data0)
    start_xy = block_xy_from_data(data0)
    _final_state, (obs_seq, act_seq, xy_seq) = jax.lax.scan(
        _advance, (data0, start_obs), xs=None, length=horizon
    )

    # Prepend the initial sample so trajectories cover time steps 0..horizon.
    xy_traj = jnp.vstack([start_xy, xy_seq])
    obs_traj = jnp.vstack([start_obs, obs_seq])
    return xy_traj, obs_traj, act_seq

def loss_fn(params):
    """Hinge loss on STL robustness of the rollout from ``initial_data``.

    Returns ``max(0, -robustness)``: zero once ``spec`` is satisfied (robustness >= 0),
    otherwise the size of the violation.
    """
    block_xy, _obs_traj, _act_seq = rollout_block_xy(params, initial_data)
    robustness = spec.robustness(block_xy)
    return jax.nn.relu(-robustness)

In [None]:
# Adam with a fixed step size of 1e-3; its state is initialised from the policy parameter pytree.
optimizer = optax.adam(learning_rate=1e-3)
opt_state = optimizer.init(policy_params)


@jax.jit
def update(params, opt_state):
    """One Adam step on the STL loss using forward-mode gradients.

    Forward mode (``jax.jacfwd`` on the flattened parameters) is kept from the original
    design. NOTE(review): presumably chosen over reverse mode because of how ``mjx.step``
    differentiates — confirm before switching to ``jax.grad``.

    The whole step is jitted so the MJX rollout is traced and compiled once, not on
    every call.

    Returns:
        (new_params, new_opt_state, loss evaluated at the *input* params).
    """
    flat_params, unravel = ravel_pytree(params)

    def loss_flat(p):
        loss = loss_fn(unravel(p))
        # Return the loss again as auxiliary output so that the value and the gradient
        # come out of one differentiation pass instead of a separate extra rollout.
        return loss, loss

    grad_flat, loss_val = jax.jacfwd(loss_flat, has_aux=True)(flat_params)
    grads = unravel(grad_flat)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_val


num_steps = 2000
for step in range(1, num_steps + 1):
    policy_params, opt_state, train_loss = update(policy_params, opt_state)
    if step % 100 == 0:
        # train_loss = relu(-robustness); it is 0 once the spec is satisfied. Its negation
        # is min(robustness, 0), not the robustness itself, so report the loss directly.
        print(f"step {step:04d} | STL loss {float(train_loss):.4f}")
