# STREL

In [2]:
import torch
import torch.nn as nn

## Reach

In [None]:
class Reach():
    """
    Reachability operator for STREL. Models bounded or unbounded reach
    over a spatial graph.
    """
    def __init__(
        self,
        distance_matrix,
        d1,
        d2,
        graph_nodes,
        is_unbounded: bool = False,
        distance_domain_min=None,
        distance_domain_max=None,
    ) -> None:
        super().__init__()
        self.d1 = d1
        self.d2 = d2
        self.is_unbounded = is_unbounded
        self.distance_domain_min = distance_domain_min
        self.distance_domain_max = distance_domain_max
        self.graph_nodes = graph_nodes

        self.weight_matrix = distance_matrix
        self.adjacency_matrix = (distance_matrix > 0).int()

        self.boolean_min_satisfaction = torch.tensor(0.0)
        self.quantitative_min_satisfaction = torch.tensor(float('-inf'))

    def __call__(self, s1, s2):
        return self._boolean(s1, s2)
        '''
        if quantitative:
            return self._quantitative(s1, s2)
        else:
            return self._boolean(s1, s2)
        '''

    def __str__(self) -> str:
        bound_type = "unbounded" if self.is_unbounded else f"[{self.d1},{self.d2}]"
        return f"Reach{bound_type}"

    # def time_depth(self) -> int:
        # return 0

    def neighbors_fn(self, node):
        mask = self.adjacency_matrix[:, node] > 0
        neighbors = self.graph_nodes[mask]
        neigh_pairs = [(i.item(), self.weight_matrix[i, node].item()) for i in neighbors]
        print('node = ', node, ' has neigh_pairs = ', neigh_pairs)
        return neigh_pairs

    def distance_function(self, weight):
        return weight

    # ----------------
    #    Boolean
    # ----------------

    def _boolean(self, s1, s2):
        if self.is_unbounded:
            return self._unbounded_reach_boolean(s1, s2)
        else:
            return self._bounded_reach_boolean(s1, s2)

    def _bounded_reach_boolean(self, s1, s2):
        s = torch.zeros(len(self.graph_nodes), requires_grad=True)

        for i, lt in enumerate(self.graph_nodes):
            l = lt.item()
            if self.d1 == self.distance_domain_min:
                s = s.clone().scatter_(0, torch.tensor([l]), s2[l])
            else:
                s = s.clone().scatter_(0, torch.tensor([l]), self.boolean_min_satisfaction)

        Q = {llt.item(): [(s2[llt.item()], self.distance_domain_min)] for llt in self.graph_nodes}

        while Q:
            print('Q = ', Q)
            Q_prime = {}
            for l in Q.keys():
                print('l = ', l)
                for v, d in Q[l]:
                    for l_prime, w in self.neighbors_fn(l):
                        print(f'neigh of {l}: {l_prime}')
                        v_new = torch.minimum(v, s1[l_prime])
                        d_new = d + w

                        if self.d1 <= d_new <= self.d2:
                            current_val = s[l_prime]
                            new_val = torch.maximum(current_val, v_new)
                            s = s.clone().scatter_(0, torch.tensor([l_prime]), new_val)
                            print('s = ', s)

                        if d_new < self.d2:

                            existing_entries = Q_prime.get(l_prime, [])
                            updated = False
                            new_entries = []
                            for vv, dd in existing_entries:
                                if dd == d_new:
                                    new_v = torch.maximum(vv, v_new)
                                    new_entries.append((new_v, dd))
                                    updated = True
                                else:
                                    new_entries.append((vv, dd))

                            if not updated:
                                new_entries.append((v_new, d_new))
                            Q_prime[l_prime] = new_entries

            print('Q_prime = ', Q_prime)
            Q = Q_prime
        return s

    def _unbounded_reach_boolean(self, s1, s2):

        if self.d1 == self.distance_domain_min:
            s = s2
        else:
            d_max = torch.max(self.distance_function(self.weight_matrix))
            self.d2 = self.d1 + d_max
            s = self._bounded_reach_boolean(s1, s2)

        T = [n for n in self.graph_nodes]

        while T:
            T_prime = []

            for l in T:
                for l_prime, w in self.neighbors_fn(l):
                    v_prime = torch.minimum(s[l], s1[l_prime])
                    v_prime = torch.maximum(v_prime, s[l_prime])

                    if v_prime != s[l_prime]:
                        s = s.clone().scatter_(0, torch.tensor([l_prime]), v_prime)
                        T_prime.append(l_prime)

            T = T_prime

        return s

    # ---------------------
    #     Quantitative
    # ---------------------

    def _quantitative(self, s1, s2):
        if self.is_unbounded:
            return self._unbounded_reach_quantitative(s1, s2)
        else:
            return self._bounded_reach_quantitative(s1, s2)

    def _bounded_reach_quantitative(self, s1, s2):
        s = torch.full((len(self.graph_nodes),), self.quantitative_min_satisfaction, requires_grad=True)

        for i, lt in enumerate(self.graph_nodes):
            l = lt.item()
            if self.d1 == self.distance_domain_min:
                s = s.clone().scatter_(0, torch.tensor([l]), s2[l])
            else:
                s = s.clone().scatter_(0, torch.tensor([l]), self.quantitative_min_satisfaction)

        Q = {
            llt.item(): [(s2[llt.item()], self.distance_domain_min)]
            for llt in self.graph_nodes
        }

        while Q:
            Q_prime = {}
            for l in Q.keys():
                for (v, d) in Q[l]:
                    for l_prime, w in self.neighbors_fn(l):
                        v_new = torch.minimum(v, s1[l_prime])
                        d_new = d + w

                        if self.d1 <= d_new <= self.d2:
                            current_val = s[l_prime]
                            new_val = torch.maximum(current_val, v_new)
                            s = s.clone().scatter_(0, torch.tensor([l_prime]), new_val)

                        if d_new < self.d2:
                            existing_entries = Q_prime.get(l_prime, [])
                            updated = False
                            new_entries = []
                            for (vv, dd) in existing_entries:
                                if dd == d_new:
                                    new_v = torch.maximum(vv, v_new)
                                    new_entries.append((new_v, dd))
                                    updated = True
                                else:
                                    new_entries.append((vv, dd))
                            if not updated:
                                new_entries.append((v_new, d_new))
                            Q_prime[l_prime] = new_entries
            Q = Q_prime

        return s

    def _unbounded_reach_quantitative(self, s1, s2):

        if self.d1 == self.distance_domain_min:
            s = s2
        else:
            d_max = torch.max(self.distance_function(self.weight_matrix))
            self.d2 = self.d1 + d_max
            s = self._bounded_reach_quantitative(s1, s2)

        T = [n for n in self.graph_nodes]

        while T:
            T_prime = []

            for l in T:
                for l_prime, w in self.neighbors_fn(l):
                    v_prime = torch.minimum(s[l], s1[l_prime])
                    new_val = torch.maximum(v_prime, s[l_prime])
                    if new_val != s[l_prime]:
                        s = s.clone().scatter_(0, torch.tensor([l_prime]), new_val)
                        T_prime.append(l_prime)
            T = T_prime

        return s

## Escape

In [None]:
class Escape():
    """
    Escape operator for STREL. Models escape condition over a spatial graph.
    """
    def __init__(
        self,
        distance_matrix,
        d1,
        d2,
        graph_nodes,
        distance_domain_min=None,
        distance_domain_max=None,
    ) -> None:
        self.d1 = d1
        self.d2 = d2
        self.distance_domain_min = distance_domain_min
        self.distance_domain_max = distance_domain_max
        self.graph_nodes = graph_nodes

        self.weight_matrix = distance_matrix
        self.adjacency_matrix = (distance_matrix > 0).int()

        self.boolean_min_satisfaction = torch.tensor(0.0)
        self.quantitative_min_satisfaction = torch.tensor(float('-inf'))

    def __call__(self, s1, quantitative=False):
        return self._quantitative(s1) if quantitative else self._boolean(s1)

    # def time_depth(self) -> int:
        # return 0

    def neighbors_fn(self, node):
        mask = (self.adjacency_matrix[:, node] > 0)
        neighbors = self.graph_nodes[mask]
        return [(i.item(), self.weight_matrix[i, node].item()) for i in neighbors]

    def forward_neighbors_fn(self, node):
        mask = (self.adjacency_matrix[node, :] > 0)
        neighbors = self.graph_nodes[mask]
        return [(j.item(), self.weight_matrix[node, j].item()) for j in neighbors]

    def compute_min_distance_matrix(self):

        n = len(self.graph_nodes) 
        D = torch.full((n, n), float('inf')) 

        for start in range(n): 
            visited = torch.zeros(n, dtype=torch.bool) 
            distance = torch.full((n,), float('inf')) 
            distance[start] = 0 

            frontier = torch.zeros(n, dtype=torch.bool) 
            frontier[start] = True

            while frontier.any(): 
                next_frontier = torch.zeros(n, dtype=torch.bool) 

                for node in torch.nonzero(frontier).flatten(): 
                    node = node.item()
                    visited[node] = True

                    for neighbor, weight in self.forward_neighbors_fn(node): 
                        if visited[neighbor]:
                            continue
                        
                        edge_cost = weight 
                        new_dist = distance[node] + edge_cost 

                        if new_dist < distance[neighbor]:
                            distance[neighbor] = new_dist
                            next_frontier[neighbor] = True

                frontier = next_frontier

            D[start] = distance

        return D

    # ----------------
    #    Boolean
    # ----------------

    def _boolean(self, s1):
        L = self.graph_nodes
        n = len(L)

        D = self.compute_min_distance_matrix()

        e = torch.ones((n, n), requires_grad=True) * self.boolean_min_satisfaction
        e = e - torch.diag(torch.diag(e)) + torch.diag(s1)

        T = [(i, i) for i in range(n)]

        while T:
            T_prime = []
            e_prime = e.clone()

            for l1, l2 in T:
                for l1_prime, w in self.neighbors_fn(l1):
                    new_val = torch.minimum(s1[l1_prime], e[l1, l2])
                    old_val = e[l1_prime, l2]
                    combined = torch.maximum(old_val, new_val)

                    if combined != old_val:
                        e_prime = e_prime.clone().index_put_(tuple(torch.tensor([[l1_prime], [l2]])), combined)
                        T_prime.append((l1_prime, l2))

            T = T_prime
            e = e_prime

        s = torch.ones(n, requires_grad=True) * self.boolean_min_satisfaction
        for i in range(n):
            vals = [
                e[i, j] for j in range(n)
                if self.d1 <= D[i, j] <= self.d2
            ]
            if vals:
                max_val = torch.stack(vals).max()
                s = s.clone().scatter_(0, torch.tensor([i]), max_val.unsqueeze(0))

        return s

    # ---------------------
    #     Quantitative
    # ---------------------

    def _quantitative(self, s1):
        L = self.graph_nodes
        n = len(L)

        D = self.compute_min_distance_matrix()

        e = torch.ones((n, n), requires_grad=True) * self.quantitative_min_satisfaction
        e = e - torch.diag(torch.diag(e)) + torch.diag(s1)

        T = [(i, i) for i in range(n)]

        while T:
            T_prime = []
            e_prime = e.clone()

            for l1, l2 in T:
                for l1_prime, w in self.neighbors_fn(l1):
                    new_val = torch.minimum(s1[l1_prime], e[l1, l2])
                    old_val = e[l1_prime, l2]
                    combined = torch.maximum(old_val, new_val)

                    if combined != old_val:
                        e_prime = e_prime.clone().index_put_(
                            tuple(torch.tensor([[l1_prime], [l2]])), combined
                        )
                        T_prime.append((l1_prime, l2))

            T = T_prime
            e = e_prime

        s = torch.ones(n, requires_grad=True) * self.quantitative_min_satisfaction
        for i in range(n):
            vals = [
                e[i, j] for j in range(n)
                if self.d1 <= D[i, j] <= self.d2
            ]
            if vals:
                max_val = torch.stack(vals).max()
                s = s.clone().scatter_(0, torch.tensor([i]), max_val.unsqueeze(0))

        return s

## Somewhere

In [None]:
class Somewhere():
    """
    Somewhere operator for STREL. Models existence of a satisfying location within a distance interval.
    """
    def __init__(
        self,
        distance_matrix,
        d2,
        graph_nodes,
        distance_domain_min=None,
        distance_domain_max=None,
    ) -> None:
        self.d1 = 0
        self.d2 = d2
        self.distance_domain_min = distance_domain_min
        self.distance_domain_max = distance_domain_max
        self.graph_nodes = graph_nodes

        self.reach_op = Reach(
            distance_matrix=distance_matrix,
            d1=self.d1,
            d2=d2,
            graph_nodes=graph_nodes,
            distance_domain_min=distance_domain_min,
            distance_domain_max=distance_domain_max
        )

        self.boolean_min_satisfaction = torch.tensor(0.0)
        self.quantitative_min_satisfaction = torch.tensor(float('-inf'))

    def __call__(self, phi, quantitative=False):
        if quantitative:
            return self._quantitative(phi)
        else:
            return self._boolean(phi)

    def __str__(self) -> str:
        return f"somewhere_[{self.d1},{self.d2}]"

    # def time_depth(self) -> int:
        # return 0

    # ----------------
    #    Boolean
    # ----------------

    def _boolean(self, phi):
        # somewhere_[0, d] φ = true R_[0, d] φ
        true = torch.ones(len(self.graph_nodes), requires_grad=True)
        return self.reach_op._boolean(true, phi)

    # ---------------------
    #     Quantitative
    # ---------------------

    def _quantitative(self, phi):
        # somewhere_[0, d] φ = true R_[0, d] φ
        true = torch.ones(len(self.graph_nodes), requires_grad=True) * float('inf')
        return self.reach_op._quantitative(true, phi)

## Everywhere

In [None]:
class Everywhere():
    """
    Everywhere operator for STREL. Models satisfaction of a property at all locations within a distance interval.
    """
    def __init__(
        self,
        distance_matrix,
        d2,
        graph_nodes,
        distance_domain_min=None,
        distance_domain_max=None,
    ) -> None:
        self.d1 = 0
        self.d2 = d2
        self.distance_domain_min = distance_domain_min
        self.distance_domain_max = distance_domain_max
        self.graph_nodes = graph_nodes

        self.somewhere_op = Somewhere(
            distance_matrix=distance_matrix,
            d2=d2,
            graph_nodes=graph_nodes,
            distance_domain_min=distance_domain_min,
            distance_domain_max=distance_domain_max
        )

        self.boolean_min_satisfaction = torch.tensor(0.0)
        self.quantitative_min_satisfaction = torch.tensor(float('-inf'))

    def __call__(self, phi, quantitative=False):
        if quantitative:
            return self._quantitative(phi)
        else:
            return self._boolean(phi)

    def __str__(self) -> str:
        return f"everywhere_[{self.d1},{self.d2}]"

    # def time_depth(self) -> int:
        # return 0

    # ----------------
    #    Boolean
    # ----------------

    def _boolean(self, phi):
        # everywhere_[0, d] φ = ¬somewhere_[0, d] ¬φ
        neg_phi = 1 - phi # usa not
        somewhere_neg_phi = self.somewhere_op._boolean(neg_phi)
        return 1 - somewhere_neg_phi

    # ---------------------
    #     Quantitative
    # ---------------------

    def _quantitative(self, phi):
        # everywhere_[0, d] φ = ¬somewhere_[0, d] ¬φ
        neg_phi = -phi
        somewhere_neg_phi = self.somewhere_op._quantitative(neg_phi)
        return -somewhere_neg_phi

## Sorround

In [None]:
class Surround():
    """
    Surround operator for STREL. Models being surrounded by φ2 while in φ1 with distance constraints.
    """
    def __init__(
        self,
        distance_matrix,
        d2,
        graph_nodes,
        distance_domain_min=None,
        distance_domain_max=None,
    ) -> None:
        self.d1 = 0
        self.d2 = d2
        self.distance_domain_min = distance_domain_min
        self.distance_domain_max = distance_domain_max
        self.graph_nodes = graph_nodes

        # Reach operator for [0,d] interval
        self.reach_op = Reach(
            distance_matrix=distance_matrix,
            d1=self.d1, # 0
            d2=d2,
            graph_nodes=graph_nodes,
            distance_domain_min=distance_domain_min,
            distance_domain_max=distance_domain_max
        )

        # Escape operator for [the negated interval [d,∞]
        self.escape_op = Escape(
            distance_matrix=distance_matrix,
            d1=d2,
            d2=distance_domain_max,
            graph_nodes=graph_nodes,
            distance_domain_min=distance_domain_min,
            distance_domain_max=distance_domain_max
        )

        self.boolean_min_satisfaction = torch.tensor(0.0)
        self.quantitative_min_satisfaction = torch.tensor(float('-inf'))

    def __call__(self, phi1, phi2, quantitative=False):
        if quantitative:
            return self._quantitative(phi1, phi2)
        else:
            return self._boolean(phi1, phi2)

    def __str__(self) -> str:
        return f"Surround^f_[{self.d1},{self.d2}]"

    # def time_depth(self) -> int:
        # return 0

    # ----------------
    #    Boolean
    # ----------------

    def _boolean(self, phi1, phi2):
        # φ1 sorround_[0, d] φ2 = φ1 ∧ ¬(φ1 R_[0, d] ¬(φ1 ∨ φ2)) ∧ ¬(E_[d,∞] φ1)

        # First part: φ1
        part1 = phi1

        # Second part: ¬(φ1 R_[0, d] ¬(φ1 ∨ φ2))
        phi1_or_phi2 = torch.maximum(phi1, phi2) # TODO: quando incorpori con STL metti AND, OR e NEGAZIONE di STL
        neg_phi1_or_phi2 = 1 - phi1_or_phi2
        reach_part = self.reach_op._boolean(phi1, neg_phi1_or_phi2)
        part2 = 1 - reach_part

        # Third part: ¬(E_[d,∞] φ1)
        escape_part = self.escape_op._boolean(phi1)
        part3 = 1 - escape_part

        # Combine all parts
        result = torch.minimum(part1, torch.minimum(part2, part3))
        return result

    # ---------------------
    #     Quantitative
    # ---------------------

    def _quantitative(self, phi1, phi2):
        # φ1 sorround_[0, d] φ2 = φ1 ∧ ¬(φ1 R_[0, d] ¬(φ1 ∨ φ2)) ∧ ¬(E_[d,∞] φ1)

        # First part: φ1
        part1 = phi1

        # Second part: ¬(φ1 R_[0, d] ¬(φ1 ∨ φ2))
        phi1_or_phi2 = torch.maximum(phi1, phi2)
        neg_phi1_or_phi2 = -phi1_or_phi2
        reach_part = self.reach_op._quantitative(phi1, neg_phi1_or_phi2)
        part2 = -reach_part

        # Third part: ¬(E_[d,∞] φ1)
        escape_part = self.escape_op._quantitative(phi1)
        part3 = -escape_part

        # Combine all parts
        result = torch.minimum(part1, torch.minimum(part2, part3))
        return result

# Testing

In [None]:
# 1. Graph initialization (same as before)
graph_nodes = torch.tensor(list(range(16)))
N = len(graph_nodes)
distance_matrix = torch.zeros((N, N))

edges = [
    (7, 0), (7, 3), (7, 5), (7, 6), (6, 7), (6, 4), (4, 6), (4, 5),
    (6, 1), (6, 9), (9, 6), (9, 2), (9, 13), (9, 8), (8, 9), (8, 13),
    (9, 15), (15, 9), (9, 10), (10, 9), (15, 10), (10, 15), (15, 12),
    (10, 14), (10, 11),
    (0, 7), (3, 7), (5, 7), (5, 4),
    (1, 6), (2, 9), (13, 9), (13, 8),
    (12, 15), (14, 10), (11, 10)
]
for i, j in edges:
    distance_matrix[i, j] = 1

# 2. Node roles and features (same as before)
router_nodes = [4, 6, 7, 8, 10, 15]
coord_node = 9
end_dev_nodes = [0, 1, 2, 3, 5, 11, 12, 13, 14]

node_features = torch.zeros((N, 2))
for i in router_nodes:
    node_features[i, 0] = 1
node_features[coord_node, 0] = 0
for i in end_dev_nodes:
    node_features[i, 0] = 2
node_features[:, 1] = torch.rand(N)

# 3. Define signals
coord_signal = (node_features[:, 0] == 0).float()  # Coordinator (node 9)
router_signal = (node_features[:, 0] == 1).float() # Routers
end_dev_signal = (node_features[:, 0] == 2).float() # End devices
router_or_coord = torch.maximum(router_signal, coord_signal) # φ1 for surround

# ======================
# TEST 1: Somewhere^hop_[0,4] coord
# "Is there a coordinator reachable within 4 hops?"
# ======================
somewhere_op = Somewhere(
    distance_matrix=distance_matrix,
    d2=1,
    graph_nodes=graph_nodes,
    distance_domain_min=0,
    distance_domain_max=float('inf')
)

somewhere_result = somewhere_op(coord_signal)
print("\n=== Somewhere^hop_[0,4] coord ===")
print("(Is there a coordinator within 4 hops?)")
for i in range(N):
    print(f"Node {i} ({'coord' if i==9 else 'router' if i in router_nodes else 'end_dev'}): {'✅ Yes' if somewhere_result[i] > 0.5 else '❌ No'}")

# ======================
# TEST 2: Everywhere^hop_[0,2] router
# "Are all nodes within 2 hops routers?"
# ======================
everywhere_op = Everywhere(
    distance_matrix=distance_matrix,
    # d1=0,
    d2=0,
    graph_nodes=graph_nodes,
    distance_domain_min=0,
    distance_domain_max=float('inf')
)

everywhere_result = everywhere_op(router_signal)
print("\n=== Everywhere^hop_[0,2] router ===")
print("(Are all nodes within 2 hops routers?)")
for i in range(N):
    print(f"Node {i} ({'coord' if i==9 else 'router' if i in router_nodes else 'end_dev'}): {'✅ Yes' if everywhere_result[i] > 0.5 else '❌ No'}")

# ======================
# TEST 3: (coord ∨ router) Surround^hop_[0,3] end_dev
# "Is the node surrounded by end devices within 3 hops, while being a coordinator or router?"
# ======================
surround_op = Surround(
    distance_matrix=distance_matrix,
    # d1=0,
    d2=4 ,
    graph_nodes=graph_nodes,
    distance_domain_min=0,
    distance_domain_max=float('inf')
)

surround_result = surround_op(router_or_coord, end_dev_signal)
print("\n=== (coord ∨ router) Surround^hop_[0,3] end_dev ===")
print("(Is the node surrounded by end devices within 3 hops while being coord/router?)")
for i in range(N):
    node_type = ('coord' if i == 9 else
                'router' if i in router_nodes else
                'end_dev')
    print(f"Node {i} ({node_type}): {'✅ Yes' if surround_result[i] > 0.5 else '❌ No'}")

Q =  {0: [(tensor(0.), 0)], 1: [(tensor(0.), 0)], 2: [(tensor(0.), 0)], 3: [(tensor(0.), 0)], 4: [(tensor(0.), 0)], 5: [(tensor(0.), 0)], 6: [(tensor(0.), 0)], 7: [(tensor(0.), 0)], 8: [(tensor(0.), 0)], 9: [(tensor(1.), 0)], 10: [(tensor(0.), 0)], 11: [(tensor(0.), 0)], 12: [(tensor(0.), 0)], 13: [(tensor(0.), 0)], 14: [(tensor(0.), 0)], 15: [(tensor(0.), 0)]}
l =  0
node =  0  has neigh_pairs =  [(7, 1.0)]
neigh of 0: 7
s =  tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
       grad_fn=<ScatterBackward0>)
l =  1
node =  1  has neigh_pairs =  [(6, 1.0)]
neigh of 1: 6
s =  tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
       grad_fn=<ScatterBackward0>)
l =  2
node =  2  has neigh_pairs =  [(9, 1.0)]
neigh of 2: 9
s =  tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
       grad_fn=<ScatterBackward0>)
l =  3
node =  3  has neigh_pairs =  [(7, 1.0)]
neigh of 3: 7
s =  tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 

