In [None]:
import torch
import torch.nn as nn

In [None]:
class DiffReach():
    """
    Differentiable reachability operator for STREL over a weighted spatial graph.

    Calling the operator with two per-node signals ``s1`` and ``s2`` returns a
    1-D tensor holding, for every node, the best (max) value over walks that
    start at that node and end at a node taken for its ``s2`` value, where the
    value of a walk is the min of ``s2`` at its end and ``s1`` at every node
    before the end.  Only walks whose accumulated edge distance lies in
    ``[d1, d2]`` (bounded) or is at least ``d1`` (unbounded) are considered.

    Node ids are used directly as indices into ``s1``, ``s2`` and the result,
    so ``graph_nodes`` is expected to hold the ids ``0 .. n-1``.

    Args:
        adjacency_matrix: 2-D tensor of edge weights; entries ``> 0`` are
            edges.  Column ``j`` is read to find the nodes linked to ``j``.
        d1: Lower distance bound.
        d2: Upper distance bound (ignored when ``is_unbounded`` is true).
        graph_nodes: 1-D tensor of node ids.
        is_unbounded: Evaluate the unbounded variant instead of ``[d1, d2]``.
        distance_domain_min: Distance assigned to a zero-length walk.  It is
            added to edge weights, so it must be numeric for evaluation.  It is
            also the value used to initialise the result when ``d1`` differs
            from it.
        distance_domain_max: Stored for callers; not read by this class.
        verbose: When true, print the search trace to stdout (off by default).
    """
    def __init__(
        self,
        adjacency_matrix,
        d1,
        d2,
        graph_nodes,
        is_unbounded: bool = False,
        distance_domain_min=None,
        distance_domain_max=None,
        verbose: bool = False,
    ) -> None:
        super().__init__()
        self.d1 = d1
        self.d2 = d2
        self.is_unbounded = is_unbounded
        self.distance_domain_min = distance_domain_min
        self.distance_domain_max = distance_domain_max
        self.adjacency_matrix = adjacency_matrix
        self.graph_nodes = graph_nodes
        self.verbose = verbose

    def __call__(self, s1, s2):
        """Evaluate the operator; equivalent to ``_boolean(s1, s2)``."""
        return self._boolean(s1, s2)

    def __str__(self) -> str:
        bound_type = "unbounded" if self.is_unbounded else f"[{self.d1},{self.d2}]"
        return f"Reach{bound_type}"

    def time_depth(self) -> int:
        """Always returns 0."""
        return 0

    def _trace(self, *parts) -> None:
        """Print a debug line, but only when tracing was requested."""
        if getattr(self, "verbose", False):
            print(*parts)

    def neighbors_fn(self, node):
        """
        Return ``[(i, weight), ...]`` for every node ``i`` whose entry
        ``adjacency_matrix[i, node]`` is positive.
        """
        incoming = (self.adjacency_matrix[:, node] > 0)
        sources = self.graph_nodes[incoming]
        neigh_pairs = [(i.item(), self.adjacency_matrix[i, node].item()) for i in sources]
        self._trace('node = ', node, ' has neigh_pairs = ', neigh_pairs)
        return neigh_pairs

    def distance_function(self, weight):
        """Map an edge weight to a distance increment (identity here)."""
        return weight

    def _boolean(self, s1, s2):
        """Evaluate the bounded or unbounded operator according to ``is_unbounded``."""
        return self._reach(s1, s2)

    def _quantitative(self, s1, s2):
        """Same computation as ``_boolean``; min/max work on real-valued signals too."""
        return self._reach(s1, s2)

    def _reach(self, s1, s2):
        """Dispatch on ``is_unbounded``."""
        if self.is_unbounded:
            return self._unbounded_reach(s1, s2)
        return self._bounded_reach(s1, s2)

    @staticmethod
    def _assign(s, node, value):
        """
        Return a copy of ``s`` with ``s[node]`` replaced by ``value``.

        Works on a clone so no tensor already recorded by autograd is modified
        in place; gradients flow into ``value`` when it is a tensor.
        """
        index = torch.tensor([node], device=s.device)
        if torch.is_tensor(value):
            value = value.to(dtype=s.dtype, device=s.device)
        return s.clone().scatter_(0, index, value)

    def _initial_state(self, s1, s2, d1):
        """
        Build the starting result vector: ``s2`` when zero-length walks are
        admissible (``d1 == distance_domain_min``), otherwise every entry is
        set to ``distance_domain_min``.
        """
        # Follow the signals' dtype/device instead of assuming float32 on CPU.
        s = torch.zeros(
            len(self.graph_nodes),
            dtype=torch.promote_types(s1.dtype, s2.dtype),
            device=s2.device,
            requires_grad=True,
        )
        zero_length_allowed = (d1 == self.distance_domain_min)
        for node_t in self.graph_nodes:
            node = node_t.item()
            if zero_length_allowed:
                s = self._assign(s, node, s2[node])
            else:
                s = self._assign(s, node, self.distance_domain_min)
        return s

    def _reach_within(self, s1, s2, d1, d2):
        """
        Breadth-first search over ``(value, distance)`` pairs, starting from
        every node with ``(s2[node], distance_domain_min)`` and moving along
        ``neighbors_fn``.  Terminates because a pair is only extended while its
        distance is below ``d2`` and every edge weight is positive.
        """
        s = self._initial_state(s1, s2, d1)
        Q = {n.item(): [(s2[n.item()], self.distance_domain_min)] for n in self.graph_nodes}

        while Q:
            self._trace('Q = ', Q)
            Q_prime = {}
            for l, entries in Q.items():
                self._trace('l = ', l)
                neighbors = self.neighbors_fn(l)  # same for every entry of l
                for v, d in entries:
                    for l_prime, w in neighbors:
                        self._trace(f'neigh of {l}: {l_prime}')
                        v_new = torch.minimum(v, s1[l_prime])
                        d_new = d + self.distance_function(w)

                        if d1 <= d_new <= d2:
                            s = self._assign(s, l_prime, torch.maximum(s[l_prime], v_new))
                            self._trace('s = ', s)

                        if d_new < d2:
                            # Keep one entry per distance, holding the best value seen.
                            merged = []
                            found = False
                            for vv, dd in Q_prime.get(l_prime, []):
                                if dd == d_new:
                                    merged.append((torch.maximum(vv, v_new), dd))
                                    found = True
                                else:
                                    merged.append((vv, dd))
                            if not found:
                                merged.append((v_new, d_new))
                            Q_prime[l_prime] = merged
            self._trace('Q_prime = ', Q_prime)
            Q = Q_prime
        return s

    def _bounded_reach(self, s1, s2):
        """Reach restricted to walks whose distance lies in ``[self.d1, self.d2]``."""
        return self._reach_within(s1, s2, self.d1, self.d2)

    def _unbounded_reach(self, s1, s2):
        """
        Reach over walks whose distance is at least ``self.d1`` (``self.d2`` is
        not used).

        Seeds the result with ``s2`` when ``d1`` is the domain minimum, else
        with a bounded search over ``[d1, d1 + largest edge distance]``; that
        window is wide enough to contain the point where any longer walk first
        reaches distance ``d1``.  The seed is then relaxed to a fixed point with
        ``s[i] = max(s[i], min(s1[i], s[j]))`` for every edge ``i -> j``.
        """
        if self.d1 == self.distance_domain_min:
            s = self._initial_state(s1, s2, self.d1)
        else:
            weights = self.adjacency_matrix[self.adjacency_matrix > 0]
            d_max = max((self.distance_function(w.item()) for w in weights), default=0)
            s = self._reach_within(s1, s2, self.d1, self.d1 + d_max)

        frontier = [n.item() for n in self.graph_nodes]
        while frontier:
            improved = []
            for l in frontier:
                for l_prime, _ in self.neighbors_fn(l):
                    current = s[l_prime]
                    candidate = torch.maximum(torch.minimum(s[l], s1[l_prime]), current)
                    # '>' rather than '!=' so a NaN can never keep the loop alive.
                    if bool(candidate > current):
                        s = self._assign(s, l_prime, candidate)
                        if l_prime not in improved:
                            improved.append(l_prime)
            frontier = improved
        return s

In [None]:
# Example: five nodes with unit-weight links (0,1), (0,2), (1,3), (2,3), (2,4).
graph_nodes = torch.arange(5)
adjacency_matrix = torch.zeros((5, 5))
adjacency_matrix[[0, 0, 1, 2, 2], [1, 2, 3, 3, 4]] = 1

# Per-node signals: s1 is 1 only at node 2, s2 is 1 only at node 4.
s1 = torch.tensor([0., 0., 1., 0., 0.], requires_grad=True)
s2 = torch.tensor([0., 0., 0., 0., 1.], requires_grad=True)

# Reach operator over distances in [0, 2].
reach = DiffReach(
    adjacency_matrix=adjacency_matrix,
    d1=0,
    d2=2,
    graph_nodes=graph_nodes,
    distance_domain_min=0,
    distance_domain_max=100,
)

In [None]:
# Evaluate through the operator's call interface, which delegates to _boolean.
print('sat = ', reach(s1, s2))

Q =  {0: [(tensor(0., grad_fn=<SelectBackward0>), 0)], 1: [(tensor(0., grad_fn=<SelectBackward0>), 0)], 2: [(tensor(0., grad_fn=<SelectBackward0>), 0)], 3: [(tensor(0., grad_fn=<SelectBackward0>), 0)], 4: [(tensor(1., grad_fn=<SelectBackward0>), 0)]}
l =  0
node =  0  has neigh_pairs =  []
l =  1
node =  1  has neigh_pairs =  [(0, 1.0)]
neigh of 1: 0
s =  tensor([0., 0., 0., 0., 1.], grad_fn=<ScatterBackward0>)
l =  2
node =  2  has neigh_pairs =  [(0, 1.0)]
neigh of 2: 0
s =  tensor([0., 0., 0., 0., 1.], grad_fn=<ScatterBackward0>)
l =  3
node =  3  has neigh_pairs =  [(1, 1.0), (2, 1.0)]
neigh of 3: 1
s =  tensor([0., 0., 0., 0., 1.], grad_fn=<ScatterBackward0>)
neigh of 3: 2
s =  tensor([0., 0., 0., 0., 1.], grad_fn=<ScatterBackward0>)
l =  4
node =  4  has neigh_pairs =  [(2, 1.0)]
neigh of 4: 2
s =  tensor([0., 0., 1., 0., 1.], grad_fn=<ScatterBackward0>)
Q_prime =  {0: [(tensor(0., grad_fn=<MaximumBackward0>), 1.0)], 1: [(tensor(0., grad_fn=<MinimumBackward0>), 1.0)], 2: [(tensor