# Signal Temporal Logic with stljax

In [None]:
import jax.numpy as jnp
import jax
from stljax.formula import Predicate, Always, Eventually, Until

# Create predicates. `Predicate("x", f)` wraps a signal transform f; with the
# identity transform below, comparisons are made directly on the raw values.
x_pred = Predicate("x", lambda x: x)
pred_gt = x_pred > 0.0  # atomic proposition x > 0
pred_lt = x_pred < 5.0  # atomic proposition x < 5 (illustrative; not used in the cells shown here)

# Temporal operators
formula_always = Always(pred_gt)  # G(x > 0)
formula_bounded = Eventually(pred_gt, interval=[0, 3])  # F[0,3](x > 0)
formula_unbounded = Eventually(pred_gt)  # F(x > 0): no interval argument, i.e. unbounded

# Boolean connectives: the &, | and ~ operators compose existing formulas into
# new ones (these three are illustrative and not evaluated in the cells shown).
combined = formula_always & formula_bounded  # AND
either = formula_always | formula_bounded  # OR
negated = ~formula_always  # NOT

In [None]:
# Evaluate G(x > 0) (`formula_always`, built in the previous cell) on a concrete
# signal. The call returns a robustness trace with one entry per time step; per
# the recorded output, entry t equals min(x[t:]) for this formula, and entry 0
# scores the whole signal.
x_signal = jnp.array([1.0, 2.0, 3.0, 4.0, 1.0, 2.0])
rob_trace = formula_always(x_signal)
print(rob_trace)
# [1. 1. 1. 1. 1. 2.]

[1. 1. 1. 1. 1. 2.]


## Globally

In [None]:
# Two six-step signals that differ only at index 1:
#   x_sat stays strictly positive, so G(x > 0) holds;
#   x_vio dips to -1.0 at time step 1, so G(x > 0) fails.
x_sat = jnp.array([1.0, 2.0, 3.0, 4.0, 1.0, 2.0])
x_vio = jnp.array([1.0, -1.0, 3.0, 4.0, 5.0, 6.0])

# Specification G(x > 0). The identity predicate means the atom's robustness
# is x - 0, which is why the traces below are suffix minima of the signals.
x_pred = Predicate("x", lambda x: x)
formula_always = Always(x_pred > 0.0)

# Evaluate both signals (x_sat first, then x_vio). Entry 0 of each trace scores
# the whole signal: > 0 is reported as satisfied, anything else as violated.
rob_sat, rob_vio = [formula_always(signal) for signal in (x_sat, x_vio)]

print("G(x > 0):")
print(f"Robustness trace for `rob_sat`: {rob_sat}")
print(f"Overall robustness: {rob_sat[0]:.2f}. Specification {'SATISFIED' if rob_sat[0] > 0 else 'VIOLATED'}")
print(f"\nRobustness trace for `rob_vio`: {rob_vio}")
print(f"Overall robustness: {rob_vio[0]:.2f}. Specification {'SATISFIED' if rob_vio[0] > 0 else 'VIOLATED'}")
# Expected output:
# G(x > 0):
# Robustness trace for `rob_sat`: [1. 1. 1. 1. 1. 2.]
# Overall robustness: 1.00. Specification SATISFIED
#
# Robustness trace for `rob_vio`: [-1. -1.  3.  4.  5.  6.]
# Overall robustness: -1.00. Specification VIOLATED

G(x > 0):
Robustness trace for `rob_sat`: [1. 1. 1. 1. 1. 2.]
Overall robustness: 1.00. Specification SATISFIED

Robustness trace for `rob_vio`: [-1. -1.  3.  4.  5.  6.]
Overall robustness: -1.00. Specification VIOLATED


## Eventually

In [None]:
# Dataset 1: x exceeds 5 at time 2, which lies inside the window [0, 3] seen
# from t = 0, so F[0,3](x > 5) holds for the signal as a whole.
x_sat = jnp.array([1.0, 2.0, 6.0, 4.0, 3.0, 2.0])

# Dataset 2: x exceeds 5 only from time 4 onward (times 4 and 5), both outside
# the window [0, 3] seen from t = 0, so the spec fails at t = 0. Later trace
# entries (t >= 1) are positive because their shifted windows [t, t+3] do
# reach time 4.
x_vio = jnp.array([1.0, 2.0, 3.0, 4.0, 6.0, 7.0])

# Bounded eventually: entry t of the trace is the best value of (x - 5) over
# the window [t, t+3], truncated at the end of the signal (see outputs below).
x_pred = Predicate("x", lambda x: x)
formula_bounded = Eventually(x_pred > 5.0, interval=[0, 3])

rob_sat = formula_bounded(x_sat)
rob_vio = formula_bounded(x_vio)

print("F[0,3](x > 5):")
print(f"Robustness trace for `rob_sat`: {rob_sat}")
print(f"Overall robustness: {rob_sat[0]:.2f}. Specification {'SATISFIED' if rob_sat[0] > 0 else 'VIOLATED'}")

print(f"\nRobustness trace for `rob_vio`: {rob_vio}")
print(f"Overall robustness: {rob_vio[0]:.2f}. Specification {'SATISFIED' if rob_vio[0] > 0 else 'VIOLATED'}")
# F[0,3](x > 5):
# Robustness trace for `rob_sat`: [ 1.  1.  1. -1. -2. -3.]
# Overall robustness: 1.00. Specification SATISFIED

# Robustness trace for `rob_vio`: [-1.  1.  2.  2.  2.  2.]
# Overall robustness: -1.00. Specification VIOLATED

F[0,3](x > 5):
Robustness trace for `rob_sat`: [ 1.  1.  1. -1. -2. -3.]
Overall robustness: 1.00. Specification SATISFIED

Robustness trace for `rob_vio`: [-1.  1.  2.  2.  2.  2.]
Overall robustness: -1.00. Specification VIOLATED


In [None]:
# Unbounded F(x > 5) on the two signals defined in the previous cell (`x_sat`,
# `x_vio`, which must already exist). With no interval the look-ahead runs to
# the end of the signal, so `x_vio`, which violated F[0,3](x > 5), now
# satisfies the unbounded version (the `_vio` name refers to the bounded spec).
x_pred = Predicate("x", lambda x: x)
formula_unbounded = Eventually(x_pred > 5.0)

rob_sat_ub = formula_unbounded(x_sat)
rob_vio_ub = formula_unbounded(x_vio)

# Header without the stray leading space, matching the other cells' headers.
print("F(x > 5):")
print(f"Robustness trace for `rob_sat_ub`: {rob_sat_ub}")
print(f"Overall robustness: {rob_sat_ub[0]:.2f}. Specification {'SATISFIED' if rob_sat_ub[0] > 0 else 'VIOLATED'}")

print(f"\nRobustness trace for `rob_vio_ub`: {rob_vio_ub}")
print(f"Overall robustness: {rob_vio_ub[0]:.2f}. Specification {'SATISFIED' if rob_vio_ub[0] > 0 else 'VIOLATED'}")
# F(x > 5):
# Robustness trace for `rob_sat_ub`: [ 1.  1.  1. -1. -2. -3.]
# Overall robustness: 1.00. Specification SATISFIED

# Robustness trace for `rob_vio_ub`: [2. 2. 2. 2. 2. 2.]
# Overall robustness: 2.00. Specification SATISFIED

 F(x > 5):
Robustness trace for `rob_sat_ub`: [ 1.  1.  1. -1. -2. -3.]
Overall robustness: 1.00. Specification SATISFIED

Robustness trace for `rob_vio_ub`: [2. 2. 2. 2. 2. 2.]
Overall robustness: 2.00. Specification SATISFIED


## Until

In [None]:
# Satisfying pair: p stays positive on steps 0-2 and q turns positive at
# step 2 (it is positive on steps 2-3), so p holds until q within [0, 3].
p_sat = jnp.array([1.0, 1.0, 1.0, 0.0, 0.0, 0.0])
q_sat = jnp.array([0.0, 0.0, 1.0, 1.0, 0.0, 0.0])

# Violating pair: p drops to -1 at step 2, one step before q first becomes
# positive (step 3), so the until-obligation is broken.
p_vio = jnp.array([1.0, 1.0, -1.0, 0.0, 0.0, 0.0])
q_vio = jnp.array([-1.0, -1.0, -1.0, 1.0, 0.0, 0.0])

# Each atom applies the identity transform to its own signal component.
p_pred = Predicate("p", lambda x: x)
q_pred = Predicate("q", lambda x: x)
formula_until = Until(p_pred > 0.0, q_pred > 0.0, interval=[0, 3])

# The two-operand formula is evaluated on a (p_signal, q_signal) tuple;
# the satisfying pair is evaluated first, then the violating one.
rob_sat, rob_vio = [
    formula_until(pair) for pair in ((p_sat, q_sat), (p_vio, q_vio))
]

print("(p > 0) U[0,3] (q > 0):")
print(f"Robustness trace for `rob_sat`: {rob_sat}")
print(f"Overall robustness: {rob_sat[0]:.2f}. Specification {'SATISFIED' if rob_sat[0] > 0 else 'VIOLATED'}")

print(f"Robustness trace for `rob_vio`: {rob_vio}")
print(f"Overall robustness: {rob_vio[0]:.2f}. Specification {'SATISFIED' if rob_vio[0] > 0 else 'VIOLATED'}")
# Expected output:
# (p > 0) U[0,3] (q > 0):
# Robustness trace for `rob_sat`: [ 1.  1.  1. -0. -0. -0.]
# Overall robustness: 1.00. Specification SATISFIED
# Robustness trace for `rob_vio`: [-1. -1. -1. -0. -0. -0.]
# Overall robustness: -1.00. Specification VIOLATED

(p > 0) U[0,3] (q > 0):
Robustness trace for `rob_sat`: [ 1.  1.  1. -0. -0. -0.]
Overall robustness: 1.00. Specification SATISFIED
Robustness trace for `rob_vio`: [-1. -1. -1. -0. -0. -0.]
Overall robustness: -1.00. Specification VIOLATED


## Differentiation

In [None]:
x_pred = Predicate("x", lambda x: x)
formula = Always(x_pred > 0.0)


# Function mapping signal values -> scalar robustness at t=0
def robustness_fn(x_values):
    """Return the robustness of `formula` (G(x > 0)) at time 0 for `x_values`.

    `jax.grad` differentiates scalar-valued functions only, so just entry 0 of
    the robustness trace (the score of the whole signal) is returned.
    """
    rob = formula(x_values)
    return rob[0]


# A signal where the tightest point is x[2] = 0.5
x_signal = jnp.array([1.0, 2.0, 0.5, 1.5, 0.8, 1.2])

# Compute gradient via JAX automatic differentiation
grad_fn = jax.grad(robustness_fn)
gradient = grad_fn(x_signal)

# Sample whose increase raises the robustness the most, read off the gradient
# instead of being hard-coded so the message stays right if the signal changes.
critical_t = int(jnp.argmax(gradient))

print(f"Signal:   {x_signal}")
print(f"Gradient: {gradient}")
print(
    f"\nRobustness G(x>0) = min(x) = {robustness_fn(x_signal):.2f}, achieved at x[{critical_t}]={x_signal[critical_t]}"
)
print(f"Increasing x[{critical_t}] would most improve the overall robustness.")

Signal:   [1.  2.  0.5 1.5 0.8 1.2]
Gradient: [0. 0. 1. 0. 0. 0.]

Robustness G(x>0) = min(x) = 0.50, achieved at x[2]=0.5
Increasing x[2] would most improve the overall robustness.


# Human-Robot Distance Example

In [19]:
import jax
import jax.numpy as jnp
from stljax.formula import Predicate, Always, Eventually, Until

# Simulated distance trace (meters): the robot starts 3.0 m away and gradually
# approaches to 0.6 m.
distance = jnp.array([3.0, 2.8, 2.5, 2.2, 1.8, 1.5, 1.2, 0.9, 0.7, 0.6])
# Horizon derived from the data (10 steps here) instead of hard-coded, so the
# liveness window below always spans the whole trace if samples are edited.
T = distance.shape[0]

# Define the STL specification (identity transform: robustness is in meters)
d_pred = Predicate("d", lambda x: x)

# G(d >= 0.5): always keep at least 0.5m safety distance
safety_spec = Always(d_pred >= 0.5)

# F[0, T-1](d <= 2.0): come within 2.0m at some point during the trace
liveness_spec = Eventually(d_pred <= 2.0, interval=[0, T - 1])

# Combined specification: both safety and liveness must hold
phi_hri = safety_spec & liveness_spec

# Evaluate robustness; entry 0 scores the whole interaction
rob_trace = phi_hri(distance)
overall_rob = rob_trace[0]

print(f"Distance trace: {distance}")
print(f"Robustness trace: {rob_trace}")
print(f"Overall robustness: {overall_rob:.4f}")
print(f"Specification {'SATISFIED' if overall_rob > 0 else 'VIOLATED'}")

Distance trace: [3.  2.8 2.5 2.2 1.8 1.5 1.2 0.9 0.7 0.6]
Robustness trace: [0.10000002 0.10000002 0.10000002 0.10000002 0.10000002 0.10000002
 0.10000002 0.10000002 0.10000002 0.10000002]
Overall robustness: 0.1000
Specification SATISFIED


In [20]:
# Violated trace: robot gets too close (below 0.5m at times 6-8)
# Reuses `phi_hri` from the previous cell. Safety requires d >= 0.5, so the
# closest approach (0.2 m at t=7) yields 0.2 - 0.5 = -0.3, matching the
# recorded output below.
distance_vio = jnp.array([3.0, 2.5, 2.0, 1.5, 1.0, 0.5, 0.3, 0.2, 0.4, 0.6])

rob_trace_vio = phi_hri(distance_vio)
print(f"Violated trace robustness: {rob_trace_vio[0]:.4f}")
print(f"Specification {'SATISFIED' if rob_trace_vio[0] > 0 else 'VIOLATED'}")

Violated trace robustness: -0.3000
Specification VIOLATED


In [None]:
def hri_robustness(d):
    """Return the robustness of `phi_hri` at time 0 for the distance trace `d`.

    Reduced to a scalar so that `jax.grad` can differentiate it. Relies on
    `phi_hri` being defined in an earlier cell.
    """
    rob = phi_hri(d)
    return rob[0]


grad_fn = jax.grad(hri_robustness)
gradient = grad_fn(distance)

# Time step where increasing the distance helps robustness most; computed from
# the gradient instead of being hard-coded to the last sample.
t_star = int(jnp.argmax(gradient))

print("Gradient of HRI robustness w.r.t. distance trace:")
print(gradient)
print("\nInterpretation:")
print("  Positive gradient at time t: increasing distance at t improves robustness.")
print("  Negative gradient at time t: decreasing distance at t improves robustness.")
# 0.5 is the safety threshold used by `safety_spec` (G(d >= 0.5)).
print(
    f"\nLargest gradient at t={t_star} (d={distance[t_star]:.1f}m, only {distance[t_star] - 0.5:.1f}m above safety threshold)."
)
if t_star == distance.shape[0] - 1:
    print("The robot should move further away at the end of the interaction.")
else:
    print(f"The robot should move further away around t={t_star}.")

Gradient of HRI robustness w.r.t. distance trace:
[0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]

Interpretation:
  Positive gradient at time t: increasing distance at t improves robustness.
  Negative gradient at time t: decreasing distance at t improves robustness.

Largest gradient at t=9 (d=0.6m, only 0.1m above safety threshold).
The robot should move further away at the end of the interaction.
