# What would be the ideal way to do this?

1. Define points $x_0,\dots,x_{n-1}$ (in an arbitrary but fixed order) and pick some $\varepsilon$.
2. Define the $\varepsilon$ graph: vertices $i$ and $j$ are connected by an edge if $D[i,j] \leq \varepsilon$, where $D[i,j] = \|x_i - x_j\|$.
3. Define the VR complex at $\varepsilon$: every complete subgraph (clique) of the $\varepsilon$ graph becomes a simplex. This is where we have to ask whether enumerating all simplices creates duplicates.

Idea: only generate simplices with increasing vertex order, using the forward neighbor sets
$$N^+(i) = \{\, j>i : D[i,j]\leq \varepsilon \,\},$$
i.e., the neighbors of $i$ that come after it in the ordering.

4. Now we have a simplex $\sigma = [i_0<i_1<\dots<i_k]$ and want to make it larger by adding a new vertex $n$. For $\sigma \cup \{n\}$ to be a simplex, $n$ must be connected to every vertex of $\sigma$ in the $\varepsilon$ graph; to keep the vertex order increasing we also require $n > i_k$. Together these two conditions say exactly that

$$n \in N^+(i_0)\cap N^+(i_1)\cap \cdots \cap N^+(i_k)$$

This motivates maintaining a candidate set $C(\sigma) = N^+(i_0)\cap N^+(i_1)\cap \cdots \cap N^+(i_k)$. We start at vertex $i$ with $C([i]) = N^+(i)$. When we extend by adding $n \in C(\sigma)$, the new candidate set is $C(\sigma \cup \{n\}) = C(\sigma)\cap N^+(n)$, so it is updated with one intersection instead of being recomputed from scratch.

This guarantees that only vertices connected to every vertex of the simplex are considered, and that vertices are always added in increasing order, so each simplex is generated exactly once.
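A small hand-checkable illustration of the recurrence, on a made-up $\varepsilon$ graph with 5 vertices (the edge list is chosen only for the example, not taken from any data):

```python
# Hypothetical epsilon graph on vertices 0..4 (edges chosen for illustration only).
edges = [(0, 1), (0, 2), (1, 2), (1, 3), (2, 3), (3, 4)]
n = 5

# N_plus[i] = neighbors of i with a larger index.
N_plus = {i: set() for i in range(n)}
for a, b in edges:
    N_plus[min(a, b)].add(max(a, b))

# C([i]) = N+(i); extending by v gives C(sigma + [v]) = C(sigma) & N+(v).
C_0 = N_plus[0]            # C([0])     = {1, 2}
C_01 = C_0 & N_plus[1]     # C([0,1])   = {1, 2} & {2, 3} = {2}
C_012 = C_01 & N_plus[2]   # C([0,1,2]) = {2} & {3} = empty

print(sorted(C_0), sorted(C_01), sorted(C_012))   # [1, 2] [2] []
```

Vertex 3 never becomes a candidate for $[0,1]$ because the edge $0$-$3$ is missing, and the empty set after $[0,1,2]$ says the triangle cannot be extended.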

5. If we were to compute all edges, then all triangles, ..., we would store too many simplices. We want to build one simplex at a time, emit its contribution as soon as possible, then backtrack.
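Steps 4 and 5 together can be sketched as a Python generator that yields each simplex as soon as it is built and then backtracks. This is a minimal sketch under stated assumptions: `cliques_dfs` and `cliques_brute` are hypothetical names, the graph is a made-up toy, and candidate sets are Python sets for readability rather than sorted arrays.

```python
from itertools import combinations
from typing import Iterator

def cliques_dfs(N_plus: list[set[int]]) -> Iterator[tuple[int, ...]]:
    """Yield each clique exactly once, with vertices in increasing order."""
    def expand(simplex: list[int], candidates: set[int]) -> Iterator[tuple[int, ...]]:
        yield tuple(simplex)                 # emit the current simplex immediately
        for v in sorted(candidates):
            simplex.append(v)                # go one dimension deeper
            yield from expand(simplex, candidates & N_plus[v])
            simplex.pop()                    # backtrack

    for i in range(len(N_plus)):
        yield from expand([i], N_plus[i])

def cliques_brute(n: int, edges: set[tuple[int, int]]) -> set[tuple[int, ...]]:
    """Reference answer: test every vertex subset (only feasible for tiny n)."""
    return {
        combo
        for k in range(1, n + 1)
        for combo in combinations(range(n), k)
        if all(pair in edges for pair in combinations(combo, 2))
    }

# Hypothetical 5-vertex epsilon graph, edges stored as (smaller, larger).
edges = {(0, 1), (0, 2), (1, 2), (1, 3), (2, 3), (3, 4)}
N_plus = [{j for (a, j) in edges if a == i} for i in range(5)]

found = list(cliques_dfs(N_plus))
print(len(found), len(set(found)) == len(found), set(found) == cliques_brute(5, edges))
# 13 True True
```

The simplices are collected into `found` here only so they can be compared with the brute force; for the EC each one would be turned into a (filtration, sign) pair and dropped.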

# What does Dlotko do compared to this?

1. The points are already indexed in an arbitrary but fixed order, and the distance matrix and forward neighbor sets are built from that order:
```python
import numpy as np

def pairwise_dist(X: np.ndarray) -> np.ndarray:  # X has shape (n, d): n points in d dimensions
    diff = X[:, None, :] - X[None, :, :]  # diff[i, j] = X[i] - X[j], shape (n, n, d)
    return np.linalg.norm(diff, axis=2)  # D[i, j] = ||X[i] - X[j]||

def subseq_neighbors(D: np.ndarray, i: int, epsilon: float) -> list[int]:
    n = D.shape[0]  # D is the precomputed distance matrix, n is the number of points
    js = np.where((D[i] <= epsilon) & (np.arange(n) > i))[0]
    return js.tolist()
```
`subseq_neighbors` returns the forward neighbor set of a vertex, i.e. only neighbors with a larger index ($N^+(i) = \{\, j > i \mid D[i,j] \leq \varepsilon \,\}$). This is how we enforce the ordering and avoid duplicate simplices. Take the line `js = np.where((D[i] <= epsilon) & (np.arange(n) > i))[0]` piece by piece. `D[i] <= epsilon` takes the $i$th row of the distance matrix and produces a boolean array that is True at position $j$ if point $j$ is within $\varepsilon$ of point $i$. `np.arange(n) > i` builds $[0,1,\dots,n-1]$ and produces a boolean array that is True where the index is greater than $i$. The elementwise `&` of the two is True exactly where point $j$ is within $\varepsilon$ of $i$ and $j > i$. With a single argument, `np.where` returns a tuple of index arrays (one per dimension), and `[0]` takes the only one, so `js` holds the forward neighbors of vertex $i$.
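A small worked example of the two masks, on five hand-picked points on a line with $\varepsilon = 0.25$ (the values are chosen only for illustration):

```python
import numpy as np

def subseq_neighbors(D: np.ndarray, i: int, epsilon: float) -> list[int]:
    n = D.shape[0]
    js = np.where((D[i] <= epsilon) & (np.arange(n) > i))[0]
    return js.tolist()

# Five points on a line (values chosen by hand for the example).
X = np.array([[0.0], [0.1], [0.2], [0.3], [0.7]])
D = np.linalg.norm(X[:, None, :] - X[None, :, :], axis=2)   # same as pairwise_dist

i, eps = 1, 0.25
close = D[i] <= eps                  # [ True  True  True  True False]
later = np.arange(D.shape[0]) > i    # [False False  True  True  True]
print(close & later)                 # [False False  True  True False]
print(subseq_neighbors(D, i, eps))   # [2, 3]
```

Vertex 0 is within $\varepsilon$ of vertex 1 but is dropped because $0 < 1$; the edge $\{0, 1\}$ is found from vertex 0's side instead.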

`pairwise_dist` takes a point cloud $X$ of shape $(n, d)$ and returns the symmetric distance matrix $D[i,j] = \|X[i] - X[j]\|$. This becomes a problem for large inputs: if $n = 10{,}000$, then $D$ has $10^8$ entries, about 800 MB as float64, and the intermediate `diff` array of shape $(n, n, d)$ is $d$ times larger again. It is worth avoiding storing the full distance matrix.
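One way to avoid the full matrix is to compute it a block of rows at a time and keep only the $N^+$ lists. A minimal sketch, assuming Euclidean distance; `nplus_blocked` and its `block` parameter are hypothetical, and `block = 1024` is an arbitrary default:

```python
import numpy as np

def nplus_blocked(X: np.ndarray, epsilon: float, block: int = 1024) -> list[np.ndarray]:
    """Hypothetical helper: N+(i) for every i without the full n x n matrix.

    Only a block of rows of the distance matrix exists at a time; the largest
    temporary is the (block, n, d) difference array, about block * n * d * 8 bytes.
    """
    X = np.asarray(X, dtype=float)
    n = X.shape[0]
    N_plus: list[np.ndarray] = []
    for start in range(0, n, block):
        stop = min(start + block, n)
        # rows start..stop-1 of the distance matrix, shape (stop - start, n)
        slab = np.linalg.norm(X[start:stop, None, :] - X[None, :, :], axis=2)
        for row, i in enumerate(range(start, stop)):
            # same rule as before: neighbors j > i with D[i, j] <= epsilon
            js = np.nonzero(slab[row, i + 1:] <= epsilon)[0] + (i + 1)
            N_plus.append(js.astype(np.int32))
    return N_plus
```

With $n = 10{,}000$, $d = 1$ and `block = 1024`, each temporary is about $1024 \times 10{,}000 \times 8$ bytes, roughly 82 MB, instead of 800 MB for $D$. The $N^+$ lists themselves can still hold up to $n(n-1)/2$ indices when $\varepsilon$ is large, so this removes the dense matrix but not the size of the graph.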

2. Looking at the Algorithms

Abstractly what we want to do is:

```python
for i in range(n):
    expand([i], 0.0, N_plus[i])  # 0-simplex [i], filtration value 0, candidates C([i]) = N+(i)
```

We currently have:

```python
for i in range(n):
    simplices = [[i]]  # list of simplices, each a list of vertex indices; starts with the 0-simplex [i]
    filtrations = [0]  # filtration value of each simplex in `simplices`
    common_subseq_neighs = [subseq_neighbors(D, i, epsilon)]  # one candidate list per simplex: C([i]) = N+(i)
```
These two set-ups match: both start from the 0-simplex $[i]$ with filtration value $0$ and candidate set $N^+(i)$.


Next, for each simplex we want to `emit(f_sigma, (-1)^dim)`. We have:

```python
for sigma, f_sigma in zip(simplices, filtrations):
    dim_sigma = len(sigma) - 1
    sign = 1 if (dim_sigma % 2 == 0) else -1  # (-1)^dim: sign of sigma in the EC alternating sum
    C.append((f_sigma, sign))  # C is the output list of (filtration value, +-1) pairs here
```
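As a check on the sign convention, summing $(-1)^{\dim \sigma}$ over all simplices should reproduce known Euler characteristics: 1 for a solid tetrahedron (contractible) and 0 for a hollow triangle (a circle). `euler_characteristic` is a hypothetical helper for the check:

```python
from itertools import combinations

def euler_characteristic(simplices) -> int:
    """Alternating sum: each simplex with len(s) vertices contributes (-1)**(len(s) - 1)."""
    return sum(1 if (len(s) - 1) % 2 == 0 else -1 for s in simplices)

# All faces of a solid tetrahedron on vertices 0..3.
tetra = [c for k in range(1, 5) for c in combinations(range(4), k)]
print(euler_characteristic(tetra))     # 4 - 6 + 4 - 1 = 1

# Hollow triangle: 3 vertices and 3 edges, no 2-simplex.
circle = [(0,), (1,), (2,), (0, 1), (0, 2), (1, 2)]
print(euler_characteristic(circle))    # 3 - 3 = 0
```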

Now we wish to grow the simplex. 

Ideally:
```python
for n in C:  # C = C(sigma), the current candidate set
    f_new = max(f_sigma, max(D[v, n] for v in sigma))
    C_new = sorted(set(C) & set(N_plus[n]))  # C(sigma) intersect N+(n)
```
We have:
```python
for sigma, f_sigma, commonN in zip(simplices, filtrations, common_subseq_neighs):
    for n in commonN:
        sigma2 = sigma + [n]
```
This corresponds to step 4 of the outline at the top.

Now we need to update the filtration value for the new simplex $\sigma \cup \{n\}$. In the pseudocode this is `f_new = max(f_sigma, longest_edge)`, where `longest_edge` $= \max_{v \in \sigma} D[v,n]$. We have:

```python
longest_edge = 0.0
for v in sigma:
    d = D[v,n]
    if d > longest_edge:
        longest_edge = d
new_f = max(f_sigma, longest_edge)
``` 
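This is mathematically the same as the pseudocode: the filtration value of a VR simplex is its longest edge, and the only new edges are those from $n$ to the vertices of $\sigma$.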
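The inner loop over $\sigma$ can also be written with NumPy fancy indexing, since `D[sigma, n]` with a list `sigma` picks out the entries $D[v, n]$ for $v \in \sigma$. A small sketch on random data comparing the two (the simplex, new vertex and `f_sigma` are arbitrary choices, and `n_new` is renamed to avoid clashing with the number of points):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.random((6, 2))
D = np.linalg.norm(X[:, None, :] - X[None, :, :], axis=2)

sigma, n_new, f_sigma = [0, 2, 3], 5, 0.1   # arbitrary simplex, new vertex, filtration

# Loop version, as in the code above
longest_edge = 0.0
for v in sigma:
    if D[v, n_new] > longest_edge:
        longest_edge = D[v, n_new]
new_f_loop = max(f_sigma, longest_edge)

# Vectorised version: D[sigma, n_new] is the array [D[0, 5], D[2, 5], D[3, 5]]
new_f_vec = max(f_sigma, float(D[sigma, n_new].max()))

print(new_f_loop == new_f_vec)   # True
```

For simplices with only a few vertices the cost of building the index array may outweigh the gain, so this is a readability choice as much as a speed one.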

After this we need to update the candidate set `C_new = C intersect N+(n)`. We have:

```python
neigh_n = subseq_neighbors(D, n, epsilon)
intersection = sorted(set(commonN).intersection(neigh_n))  # C(sigma) intersect N+(n)
```
which is the same mathematically as well.
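Both lists are already sorted, so the intersection can also be done by a two-pointer merge in $O(|a| + |b|)$ without building sets or re-sorting. A minimal sketch; `intersect_sorted_lists` is a hypothetical name, and `np.intersect1d` (used further down) is another option:

```python
def intersect_sorted_lists(a: list[int], b: list[int]) -> list[int]:
    """Two-pointer intersection of two ascending lists without repeats."""
    i = j = 0
    out: list[int] = []
    while i < len(a) and j < len(b):
        if a[i] == b[j]:
            out.append(a[i])
            i += 1
            j += 1
        elif a[i] < b[j]:
            i += 1
        else:
            j += 1
    return out

commonN = [3, 5, 7, 9, 12]
neigh_n = [5, 6, 9, 13]
print(intersect_sorted_lists(commonN, neigh_n))     # [5, 9]
print(sorted(set(commonN).intersection(neigh_n)))   # [5, 9]
```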

Now we hit a difference!

Ideally we extend one chain of simplices as deep as possible, then backtrack:
```python
expand(i):
    expand(i,j):
        expand(i,j,k):
            expand(i,j,k,l):
                ...
```
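Only the current chain lives on the recursion stack, so the working memory grows with the largest simplex dimension (at most `max_dim + 1` frames, each holding one candidate set), not with the total number of simplices.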

We currently have
```python
while simplices not empty:
    record all simplices at this dimension
    simplices = inc_dim(...)
```

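For each starting vertex $i$, this records the vertex, then every edge whose smallest vertex is $i$, then every such triangle, and so on. All simplices of the current dimension, together with their candidate lists, are held in memory at once, and that is the memory risk.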
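To put a number on that risk, take the worst case where vertex $i$ and its $m-1$ later neighbors are all pairwise within $\varepsilon$. Then the number of $k$-simplices with smallest vertex $i$ is $\binom{m-1}{k}$, and the level-by-level loop holds a whole level at once, while the DFS holds at most $m$ frames. `level_sizes` is a hypothetical helper for the arithmetic:

```python
from math import comb

def level_sizes(m: int) -> list[int]:
    """Number of k-simplices with smallest vertex i, for k = 0..m-1, when i and its
    m - 1 later neighbors are pairwise within epsilon (the worst case)."""
    return [comb(m - 1, k) for k in range(m)]

sizes = level_sizes(21)   # vertex i plus 20 later neighbors
print(max(sizes))         # 184756 simplices held at once (k = 10)
print(len(sizes))         # 21 = maximum DFS depth
```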

Another inefficiency, a time cost rather than a memory one, is that we recompute `neigh_n = subseq_neighbors(D, n, epsilon)` every time we extend a simplex by $n$. Since $N^+(n)$ depends only on $n$, it can be computed once per vertex and reused.
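Besides precomputing every list up front (as `Nplus` does below), $N^+(v)$ can be memoised on first use. A minimal sketch with `functools.lru_cache`, assuming `D` and `epsilon` are fixed module-level values; `nplus_cached` is a hypothetical name:

```python
from functools import lru_cache

import numpy as np

rng = np.random.default_rng(1)
X = rng.random((30, 1))
D = np.linalg.norm(X[:, None, :] - X[None, :, :], axis=2)
epsilon = 0.2

@lru_cache(maxsize=None)
def nplus_cached(v: int) -> tuple[int, ...]:
    """N+(v) for the fixed D and epsilon above; computed on first use, then cached."""
    n = D.shape[0]
    return tuple(np.nonzero((D[v] <= epsilon) & (np.arange(n) > v))[0].tolist())

for _ in range(3):
    nplus_cached(4)
info = nplus_cached.cache_info()
print(info.hits, info.misses)   # 2 1
```

Returning a tuple keeps the cached value immutable. If `D` or `epsilon` change, `nplus_cached.cache_clear()` has to be called, otherwise stale neighbor lists are returned.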


# Let's try to optimize this (sampling from [0,1])

```python
import numpy as np

def pairwise_dist(X: np.ndarray) -> np.ndarray:
    # same as before: takes X of shape (n, d), returns the symmetric distance matrix D
    diff = X[:, None, :] - X[None, :, :]
    return np.linalg.norm(diff, axis=2)

def Nplus(D: np.ndarray, epsilon: float) -> list[np.ndarray]:
    # precompute Nplus[i] = sorted array of neighbors j > i with D[i, j] <= epsilon
    n = D.shape[0]  # number of vertices
    N = []
    for i in range(n):
        js = np.where(D[i, i + 1:] <= epsilon)[0] + (i + 1)
        N.append(js.astype(np.int32))
    return N  # N[i] = N+(i) for every vertex; keeping only j > i prevents double counting

def intersect_sorted(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    # intersection of two sorted integer arrays with NumPy's intersect1d;
    # this is how we enforce that future vertices must neighbor all current simplex vertices
    if a.size == 0 or b.size == 0:
        return np.empty(0, dtype=np.int32)
    return np.intersect1d(a, b, assume_unique=False)


def local_contributions_vr_dfs(  # enumerate all simplices of the epsilon graph up to max_dim
        X: np.ndarray,
        epsilon: float,
        max_dim: int | None = None,
        ) -> list[tuple[float, int]]:
    """
    Enumerate the cliques (simplices) of the epsilon graph by DFS over ordered neighbor sets.
    Each simplex is turned into a (filtration_value, (-1)^dim) event as soon as it is found,
    so simplices themselves are never stored; only the current DFS chain is kept.
    Returns C, the list of events sorted by filtration value: at filtration value f, add +1 or -1.
    """
    X = np.asarray(X, dtype=float)
    n = X.shape[0]
    D = pairwise_dist(X)
    N = Nplus(D, epsilon)

    C: list[tuple[float, int]] = []  # output events (filtration_value, sign), not the candidate set

    def expand(simplex: list[int], candidates: np.ndarray, f_simplex: float):
        # simplex: vertex indices of sigma. candidates: C(sigma), the vertices allowed next
        # (the intersection set). f_simplex: filtration value of sigma (its longest edge so far).

        # emit the contribution of the current simplex right away
        dim = len(simplex) - 1
        sign = 1 if (dim % 2 == 0) else -1
        C.append((f_simplex, sign))

        # dimension cap (don't expand beyond max_dim)
        if max_dim is not None and dim >= max_dim:
            return

        # try extending by each candidate vertex v
        for v in candidates:
            v = int(v)
            # new filtration value: max(f_sigma, max of D[u, v] for u in sigma)
            longest_edge = 0.0
            for u in simplex:
                d = D[u, v]
                if d > longest_edge:
                    longest_edge = d
            f_new = max(f_simplex, longest_edge)

            # update the candidate set: C(sigma union {v}) = C(sigma) intersect N+(v)
            candidates_new = intersect_sorted(candidates, N[v])

            simplex.append(v)  # add v
            expand(simplex, candidates_new, f_new)
            simplex.pop()  # backtrack

    # start from each vertex i as the smallest vertex index
    for i in range(n):
        expand([i], N[i], 0.0)

    C.sort(key=lambda t: t[0])  # sort by filtration value so cumulative sums can be taken in order
    return C

def EC_at_eps(C: list[tuple[float, int]], r: float) -> int:
    # C must be sorted by filtration value (local_contributions_vr_dfs returns it sorted);
    # the result is only meaningful for r <= the epsilon used to build C
    total = 0
    for f, s in C:
        if f <= r:
            total += s
        else:
            break
    return total
```


# What now?

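Now we have C, the sorted list of (filtration value, sign) events, one per simplex with filtration value at most $\varepsilon$ and dimension at most `max_dim`. Summing the signs of all events with $f \leq r$ gives the Euler characteristic at $r$, for any $r \leq \varepsilon$.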
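`EC_at_eps` rescans C from the start for every $r$. When many thresholds are needed, a prefix sum of the signs plus a binary search gives all of them in one pass. A minimal sketch; `ec_curve` is a hypothetical helper, and the events below are a hand-made example (points at 0, 0.2 and 0.5 on a line):

```python
import numpy as np

def ec_curve(C: list[tuple[float, int]], rs) -> np.ndarray:
    """EC at every threshold in rs from one event list C sorted by filtration value.

    EC(r) is the sum of the signs of the events with f <= r, i.e. the prefix sum of
    the signs up to the number of such events (found by binary search).
    """
    f = np.array([t[0] for t in C], dtype=float)
    s = np.array([t[1] for t in C], dtype=np.int64)
    prefix = np.concatenate(([0], np.cumsum(s)))   # prefix[k] = sum of the first k signs
    idx = np.searchsorted(f, np.asarray(rs, dtype=float), side="right")  # events with f <= r
    return prefix[idx]

# Three vertices at 0, edges at 0.2, 0.3 and 0.5, and the triangle at 0.5.
C = [(0.0, 1), (0.0, 1), (0.0, 1), (0.2, -1), (0.3, -1), (0.5, -1), (0.5, 1)]
print(ec_curve(C, [0.0, 0.1, 0.2, 0.3, 0.5]).tolist())   # [3, 3, 2, 1, 1]
```

For a sorted C, `ec_curve(C, rs)` should agree with calling `EC_at_eps(C, r)` for each `r` in `rs`, since both sum the signs of the events with $f \leq r$.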

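```python
n = 100
X = np.random.rand(n, 1)
eps_enum = 0.5  # enumerate every simplex up to this scale once

C = local_contributions_vr_dfs(X, eps_enum, max_dim=3)

for eps in [0.1, 0.25, 0.4]:  # each eps <= eps_enum, so EC_at_eps is valid
    print(eps, EC_at_eps(C, eps))
```

```
0.1 -8463
0.25 -159404
0.4 -611852
```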


```python
import matplotlib.pyplot as plt

def simulate_EC_dist(
        eps_list,
        n=200,
        d=1,
        trials=30,
        max_dim=30,
        seed=0,
        ):
    """
    For each epsilon in eps_list, sample X ~ Unif([0,1]^d) `trials` times,
    compute C with eps_enum = epsilon, and record the EC at r = epsilon.
    Returns a dict eps -> np.array of EC values (length = trials).
    """
    rng = np.random.default_rng(seed)
    results = {}

    for eps in eps_list:
        ecs = []
        for _ in range(trials):
            X = rng.random((n, d))  # uniform on [0,1]^d

            C = local_contributions_vr_dfs(X, eps, max_dim=max_dim)
            ec = EC_at_eps(C, eps)
            ecs.append(ec)

        results[eps] = np.array(ecs, dtype=int)
    return results
```

```python
eps_list = [0.1, 0.25, 0.4]
results = simulate_EC_dist(eps_list, n=50, d=1, trials=30, max_dim=2, seed=42)

for eps in eps_list:
    vals = results[eps]
    print(eps, "mean", vals.mean(), "std", vals.std(), "min", vals.min(), "max", vals.max())
```

```
0.1 mean 356.8666666666667 std 82.01046003745836 min 244 max 526
0.25 mean 2564.5333333333333 std 321.4607008986047 min 1988 max 3229
0.4 mean 6219.4 std 905.3609077784026 min 4693 max 8317
```


```python
# C here is still the one from the n = 100, eps_enum = 0.5, max_dim = 3 run above;
# simulate_EC_dist only builds local copies of C
for r in np.linspace(0.01, 1.0, 20):
    ec = EC_at_eps(C, r)
    print(r, ec)
```

```
0.01 39
0.06210526315789474 -1694
0.11421052631578947 -14328
0.16631578947368422 -44357
0.21842105263157896 -99060
0.2705263157894737 -201135
0.32263157894736844 -334935
0.37473684210526315 -516044
0.4268421052631579 -698684
0.4789473684210527 -943645
0.5310526315789474 -1041976
0.5831578947368421 -1041976
0.6352631578947369 -1041976
0.6873684210526316 -1041976
0.7394736842105263 -1041976
0.791578947368421 -1041976
0.8436842105263158 -1041976
0.8957894736842106 -1041976
0.9478947368421053 -1041976
1.0 -1041976
```
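Two things to keep in mind when reading this curve. C was enumerated with `eps_enum = 0.5`, so it contains no simplex with filtration value above 0.5, and every $r \geq 0.5$ returns the same total; that is why the values stop changing from $r \approx 0.53$ on. And `max_dim = 3` means these are Euler characteristics of the 3-skeleton, not of the full VR complex. As a hedged sanity check of the enumeration in 1-D: the $\varepsilon$ graph of points on a line is an interval graph, which is chordal, and the clique complex of a chordal graph has contractible connected components, so the untruncated EC should equal the number of connected components. The sketch below compares the two on a small random sample; `ec_full` and `components_1d` are hypothetical helpers (a plain recursive clique DFS, not the notebook's function).

```python
import numpy as np

def ec_full(D: np.ndarray, r: float) -> int:
    """EC of the untruncated VR complex at scale r (ordered clique DFS, no max_dim)."""
    n = D.shape[0]
    adj = D <= r

    def expand(size: int, candidates: list[int]) -> int:
        # contribution of the current simplex (size vertices, dim = size - 1)
        total = 1 if (size - 1) % 2 == 0 else -1
        for k, v in enumerate(candidates):
            # later candidates that are also neighbors of v
            nxt = [w for w in candidates[k + 1:] if adj[v, w]]
            total += expand(size + 1, nxt)
        return total

    return sum(expand(1, [j for j in range(i + 1, n) if adj[i, j]]) for i in range(n))

def components_1d(x: np.ndarray, r: float) -> int:
    """Components of the r-graph of points on a line: one plus the number of
    gaps between consecutive sorted points that exceed r."""
    return 1 + int(np.sum(np.diff(np.sort(x)) > r))

rng = np.random.default_rng(0)
x = rng.random(15)
D = np.abs(x[:, None] - x[None, :])
for r in (0.05, 0.1, 0.2, 0.3):
    print(r, ec_full(D, r), components_1d(x, r))   # the last two columns should agree
```

With a `max_dim` cap the alternating sum stops early, which is where the large negative values above come from.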
