# Test for the non-differentiable version of the Reach Operator from STREL

In [1]:
# Libraries

import torch
import torch.nn as nn

## Non-differentiable Reach

In [2]:
class NonDiffReach():
    """
    Reachability operator for STREL (non-differentiable version).

    Evaluates ``s1 R[d1, d2] s2`` (or its unbounded variant) on a weighted
    graph using min/max propagation. For every node the result is the best
    value, over paths starting at that node, of ``min(s2 at the path's end,
    s1 at the other nodes of the path)``, restricted to paths whose
    accumulated distance lies in the requested interval.

    Args:
        adjacency_matrix: 2-D tensor; a strictly positive entry ``[i, j]`` is
            read as an edge from ``i`` to ``j`` carrying that weight.
        d1: lower bound on the accumulated path distance.
        d2: upper bound on the accumulated path distance. Not read when
            ``is_unbounded`` is true.
        graph_nodes: 1-D integer tensor of node ids. The ids are used
            directly as indices into the signals and the adjacency matrix.
        is_unbounded: if true, only the lower bound ``d1`` constrains paths.
        distance_domain_min: smallest element of the distance domain; when
            left as ``None`` it is treated as 0.
        distance_domain_max: stored but not read by this class.
    """
    def __init__(
        self,
        adjacency_matrix,
        d1,
        d2,
        graph_nodes,
        is_unbounded: bool = False,
        distance_domain_min=None,
        distance_domain_max=None,
    ) -> None:
        self.d1 = d1
        self.d2 = d2
        self.is_unbounded = is_unbounded
        self.distance_domain_min = distance_domain_min
        self.distance_domain_max = distance_domain_max
        self.adjacency_matrix = adjacency_matrix
        self.graph_nodes = graph_nodes

    def __call__(self, s1, s2):
        """Evaluate the operator on signals ``s1`` and ``s2`` (one value per node)."""
        return self._boolean(s1, s2)

    def __str__(self) -> str:
        bound_type = "unbounded" if self.is_unbounded else f"[{self.d1},{self.d2}]"
        return f"Reach{bound_type}"

    def time_depth(self) -> int:
        """Return 0: this operator is purely spatial."""
        return 0

    def neighbors_fn(self, node):
        """
        Return ``[(i, weight), ...]`` for every node ``i`` with an edge into ``node``.

        The lookup reads column ``node`` of the adjacency matrix, so these are
        the nodes from which ``node`` is reachable in one step. The result is
        also printed (debug trace kept from the original notebook).
        """
        has_edge = self.adjacency_matrix[:, node] > 0
        sources = self.graph_nodes[has_edge]
        neigh_pairs = [(i.item(), self.adjacency_matrix[i, node].item()) for i in sources]
        print('node = ', node, ' has neigh_pairs = ', neigh_pairs)
        return neigh_pairs

    def distance_function(self, weight):
        """Map an edge weight to a distance increment (identity by default)."""
        return weight

    def _boolean(self, s1, s2):
        """Boolean semantics: dispatch on ``is_unbounded``."""
        return self._reach(s1, s2)

    def _quantitative(self, s1, s2):
        """Quantitative semantics: same min/max propagation as ``_boolean``."""
        return self._reach(s1, s2)

    def _reach(self, s1, s2):
        """Run the bounded or unbounded algorithm depending on ``is_unbounded``."""
        if self.is_unbounded:
            return self._unbounded_reach(s1, s2)
        return self._bounded_reach(s1, s2)

    def _distance_bottom(self):
        """Return the minimum of the distance domain, defaulting to 0 when unset."""
        if self.distance_domain_min is None:
            return 0
        return self.distance_domain_min

    def _node_ids(self):
        """Return the node ids as plain Python numbers."""
        return [n.item() for n in self.graph_nodes]

    def _initial_values(self, s2, nodes, d_min):
        """
        Build the starting result vector (float tensor, one slot per node).

        When ``d1`` equals the distance minimum a zero-length path is allowed,
        so every node starts from its own ``s2`` value. Otherwise every node
        starts from ``d_min``.
        NOTE(review): the second case reuses the distance minimum as the
        "false" value of the signal domain; that coincides with boolean
        false only when the distance minimum is 0 - confirm for other domains.
        """
        s = torch.zeros(len(nodes))  # No requires_grad
        zero_length_ok = self.d1 == d_min
        for l in nodes:
            s[l] = s2[l] if zero_length_ok else d_min
        return s

    def _bounded_reach(self, s1, s2, d2=None):
        """
        Reach restricted to paths with accumulated distance in ``[d1, d2]``.

        Args:
            s1: signal that must hold along the path (indexable per node).
            s2: signal that must hold at the path's end (indexable per node).
            d2: optional override of the upper bound; defaults to ``self.d2``.

        Returns:
            Float tensor with one value per node.
        """
        if d2 is None:
            d2 = self.d2
        d_min = self._distance_bottom()
        nodes = self._node_ids()
        s = self._initial_values(s2, nodes, d_min)

        # frontier[node] maps an accumulated distance to the best value known
        # for a path of exactly that distance ending in an s2-node. Values are
        # plain Python numbers in every round, so later rounds can reuse them
        # without tensor-only calls.
        frontier = {l: {d_min: s2[l].item()} for l in nodes}

        while frontier:
            next_frontier = {}
            for l, entries in frontier.items():
                neighbours = self.neighbors_fn(l)  # invariant for all entries of l
                for d, v in entries.items():
                    for l_prime, w in neighbours:
                        v_new = min(v, s1[l_prime].item())
                        d_new = d + self.distance_function(w)

                        if self.d1 <= d_new <= d2:
                            s[l_prime] = max(s[l_prime].item(), v_new)

                        # Only paths that can still be extended within the
                        # bound are carried to the next round; entries with
                        # equal distance are merged by keeping the maximum.
                        if d_new < d2:
                            bucket = next_frontier.setdefault(l_prime, {})
                            bucket[d_new] = max(bucket.get(d_new, v_new), v_new)
            frontier = next_frontier
        return s

    def _unbounded_reach(self, s1, s2):
        """
        Reach with only the lower distance bound ``d1``.

        If ``d1`` is the distance minimum the search starts from a copy of
        ``s2``; otherwise a bounded pass over ``[d1, d1 + d_max]`` (``d_max``
        being the largest single-edge distance) seeds the result. In both
        cases the values are then propagated along edges until nothing
        changes. ``s2`` itself is never modified.

        Returns:
            Float tensor with one value per node.
        """
        d_min = self._distance_bottom()
        nodes = self._node_ids()

        if self.d1 == d_min:
            s = self._initial_values(s2, nodes, d_min)
        else:
            sub = self.adjacency_matrix[self.graph_nodes][:, self.graph_nodes]
            edge_distances = [self.distance_function(w) for w in sub[sub > 0].tolist()]
            d_max = max([0] + edge_distances)
            s = self._bounded_reach(s1, s2, d2=self.d1 + d_max)

        T = set(nodes)
        while T:
            T_prime = set()
            for l in T:
                for l_prime, w in self.neighbors_fn(l):
                    current = s[l_prime].item()
                    v = max(min(s[l].item(), s1[l_prime].item()), current)
                    if v != current:
                        s[l_prime] = v
                        T_prime.add(l_prime)
            T = T_prime

        return s


## Test

In [4]:
# 1. Graph set-up: 16 nodes, 0-based indexing
graph_nodes = torch.arange(16)
N = graph_nodes.numel()

# Adjacency matrix: entry [i, j] is set to 1 for every (i, j) pair in `edges`
adjacency_matrix = torch.zeros(N, N)

edges = [
    (7, 0), (7, 3), (7, 5), (7, 6), (6, 7), (6, 4), (4, 6), (4, 5),
    (6, 1), (6, 9), (9, 6), (9, 2), (9, 13), (9, 8), (8, 9), (8, 13),
    (9, 15), (15, 9), (9, 10), (10, 9), (15, 10), (10, 15), (15, 12),
    (10, 14), (10, 11),
    (0, 7), (3, 7), (5, 7), (5, 4),
    (1, 6), (2, 9), (13, 9), (13, 8),
    (12, 15), (14, 10), (11, 10),
]

edge_rows, edge_cols = zip(*edges)
adjacency_matrix[list(edge_rows), list(edge_cols)] = 1

# 2. Roles:
#     - router (can forward data)
#     - coordinator (initialises the network)
#     - end device (can only receive, does not forward)
# The trailing comments give the same nodes with 1-based labels.
router_nodes = [4, 6, 7, 8, 10, 15]  # nodes 5, 7, 8, 9, 11, 16
coord_node = 9  # node 10
end_dev_nodes = [0, 1, 2, 3, 5, 11, 12, 13, 14]  # nodes 1, 2, 3, 4, 6, 12, 13, 14, 15

# 3. Signals: s1 is 1 exactly on the end devices, s2 is 1 exactly on the
#    routers (the kind of node we want to reach); both are 0 elsewhere.
s1 = torch.zeros(N)
s1[end_dev_nodes] = 1
s2 = torch.zeros(N)
s2[router_nodes] = 1

# 4. NonDiffReach operator
reach_op = NonDiffReach(
    adjacency_matrix=adjacency_matrix,
    d1=0,
    d2=1,  # at most 1 hop
    graph_nodes=graph_nodes,
    is_unbounded=False,
    distance_domain_min=0,
    distance_domain_max=1,
)

# 5. Evaluate the formula: s1 R[0,1] s2
# The result is read below as a length-N tensor where result[i] == 1 means
# node i can reach a router within 0 or 1 hops.
result = reach_op(s1, s2)

# 6. Report: check the end devices only (printed with 1-based labels)
print("\nVerifica: ogni end device può raggiungere un router entro 1 hop?\n")
for i in end_dev_nodes:
    print(f"End device {i+1} → {'✅ Sì' if result[i]==1 else '❌ No'}")

# 7. Q matrix for logging: row 0 = node index, row 1 = reach value,
#    row 2 = 1 where the reach value equals 1, else 0.
# NOTE(review): the printed header calls the third row "distanza", but the
# row holds a 0/1 satisfaction flag, not a distance - confirm the intent.
Q = torch.stack([
    torch.arange(N, dtype=torch.float32),
    result.to(torch.float32),
    (result == 1).to(torch.float32),
])

print("\nMatrice Q (nodo, soddisfazione, distanza):")
print(Q)

node =  0  has neigh_pairs =  [(7, 1.0)]
node =  1  has neigh_pairs =  [(6, 1.0)]
node =  2  has neigh_pairs =  [(9, 1.0)]
node =  3  has neigh_pairs =  [(7, 1.0)]
node =  4  has neigh_pairs =  [(5, 1.0), (6, 1.0)]
node =  5  has neigh_pairs =  [(4, 1.0), (7, 1.0)]
node =  6  has neigh_pairs =  [(1, 1.0), (4, 1.0), (7, 1.0), (9, 1.0)]
node =  7  has neigh_pairs =  [(0, 1.0), (3, 1.0), (5, 1.0), (6, 1.0)]
node =  8  has neigh_pairs =  [(9, 1.0), (13, 1.0)]
node =  9  has neigh_pairs =  [(2, 1.0), (6, 1.0), (8, 1.0), (10, 1.0), (13, 1.0), (15, 1.0)]
node =  10  has neigh_pairs =  [(9, 1.0), (11, 1.0), (14, 1.0), (15, 1.0)]
node =  11  has neigh_pairs =  [(10, 1.0)]
node =  12  has neigh_pairs =  [(15, 1.0)]
node =  13  has neigh_pairs =  [(8, 1.0), (9, 1.0)]
node =  14  has neigh_pairs =  [(10, 1.0)]
node =  15  has neigh_pairs =  [(9, 1.0), (10, 1.0), (12, 1.0)]

Verifica: ogni end device può raggiungere un router entro 1 hop?

End device 1 → ✅ Sì
End device 2 → ✅ Sì
End device 3 → ❌ No