# Differentiability Check

In [6]:
import torch
from torch.autograd import grad
from torcheck.stl import Reach, Escape, Atom

# Dummy identity distance function
def identity_distance(x):
    return x

# Create a small graph: edge from node 0 to node 1
edge_index = torch.tensor([[0], [1]], dtype=torch.long)
edge_weights = torch.tensor([1.0], dtype=torch.double, requires_grad=True)
num_nodes = 2

# Generate dummy input signal: batch=1, 2 vars, 1 time point
x = torch.randn((1, 2, 1), dtype=torch.double, requires_grad=True)

# Use Atom nodes as input predicates
phi1 = Atom(var_index=0, threshold=0.0, lte=False)
phi2 = Atom(var_index=1, threshold=0.0, lte=False)

# Instantiate operators
reach = Reach(
    phi1=phi1,
    phi2=phi2,
    distance_fn=identity_distance,
    edge_index=edge_index,
    edge_weights=edge_weights,
    d1=0.0,
    d2=2.0,
    num_nodes=num_nodes,
    beta=5.0,
    max_iter=1,
)

escape = Escape(
    phi=phi2,
    distance_fn=identity_distance,
    edge_index=edge_index,
    edge_weights=edge_weights,
    d1=0.0,
    d2=2.0,
    num_nodes=num_nodes,
    beta=5.0,
    max_iter=1,
)

# Fix dtype mismatch in adjacency matrix creation (inside _build_soft_adjacency)
def patch_build_soft_adjacency(self):
    def _build_soft_adjacency(phi1_val):
        src, dst = self.edge_index
        dtype = self.edge_dists.dtype
        A = torch.zeros(self.num_nodes, self.num_nodes, device=phi1_val.device, dtype=dtype)
        A[src, dst] = torch.exp(-self.beta * self.edge_dists.to(dtype)).clamp(min=1e-6)
        phi1_gate = torch.sigmoid(self.beta * phi1_val)
        gate_mask = phi1_gate.unsqueeze(1) * phi1_gate.unsqueeze(2)
        return A.unsqueeze(0) * gate_mask
    return _build_soft_adjacency

reach._build_soft_adjacency = patch_build_soft_adjacency(reach)

# Fix dtype mismatch inside soft distance matrix
for op in [reach, escape]:
    def fixed_soft_distance_matrix(self):
        N = self.num_nodes
        device = self.edge_weights.device
        dtype = self.edge_dists.dtype
        dist = torch.full((N, N), float("inf"), device=device, dtype=dtype)
        dist.fill_diagonal_(0)
        src, dst = self.edge_index
        dist[src, dst] = self.edge_dists.to(dtype)
        for _ in range(self.max_iter):
            current = dist.unsqueeze(0) + dist.unsqueeze(1)
            softmin = -torch.logsumexp(-self.beta * current, dim=2) / self.beta
            dist = torch.minimum(dist, softmin)
        return dist
    op._soft_distance_matrix = fixed_soft_distance_matrix.__get__(op)

# Safely handle different tensor shapes
phi1_val = phi1._quantitative(x, normalize=False)
phi2_val = phi2._quantitative(x, normalize=False)
if phi1_val.dim() == 4:
    phi1_val = phi1_val.squeeze(1)
if phi2_val.dim() == 4:
    phi2_val = phi2_val.squeeze(1)

# Repeat phi values across nodes (for Reach)
phi1_val_exp = phi1_val.expand(-1, num_nodes, -1)
phi2_val_exp = phi2_val.expand(-1, num_nodes, -1)

# Monkey-patch quantitative methods with expanded node-wise inputs
reach.phi1._quantitative = lambda x, normalize=False: phi1_val_exp.unsqueeze(1)
reach.phi2._quantitative = lambda x, normalize=False: phi2_val_exp.unsqueeze(1)
escape.phi._quantitative = lambda x, normalize=False: phi2_val_exp.unsqueeze(1)

# Compute outputs
reach_output = reach.quantitative(x, evaluate_at_all_times=True).sum()
escape_output = escape.quantitative(x, evaluate_at_all_times=True).sum()

# Run backward
reach_output.backward(retain_graph=True)
escape_output.backward()

# Print gradient info
print("Gradients for input x (Reach):", x.grad)
print("Gradients for edge weights (Reach):", edge_weights.grad)
print("---")
print("Gradients for input x (Escape):", x.grad)
print("Gradients for edge weights (Escape):", edge_weights.grad)


Gradients for input x (Reach): tensor([[[-2.8497e-06],
         [ 4.0135e+00]]], dtype=torch.float64)
Gradients for edge weights (Reach): tensor([0.0097], dtype=torch.float64)
---
Gradients for input x (Escape): tensor([[[-2.8497e-06],
         [ 4.0135e+00]]], dtype=torch.float64)
Gradients for edge weights (Escape): tensor([0.0097], dtype=torch.float64)


Both Reach and Escape produced non-zero gradients, which means that differentiability is guaranteed