In [None]:
# default_exp eeemd
from nbdev.showdoc import *
import numpy as np
import matplotlib.pyplot as plt
import torch
import FRED
# Pick the best available accelerator: CUDA first, then Apple-silicon MPS, then CPU.
# MPS is feature-detected instead of matching the version string '1.13', so other
# PyTorch releases (where `torch.has_mps` is deprecated) can still use the Apple GPU,
# and builds without an MPS backend fall through to CPU without raising.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print("Using device", device)
%load_ext autoreload
%autoreload 2

ModuleNotFoundError: No module named 'nbdev'

# EEEMD: Edge-to-Edge Earth Mover's Distance
> EMD between graphs with the same number of nodes

We want a distance between graphs with the same number of nodes, obeying:
1. The cost of moving an edge from one neighbor to another neighbor is lower than the cost of moving it from a neighbor to a non-neighbor.
2. The distance between a pair of edges having one point in common is smaller than the distance between a pair of edges having no points in common.
3. The above differences in magnitude correspond to the geodesic (manifold) distances on the graph.

Such a distance could be useful to FRED. We've previously used the KLD between adjacency matrices as a proxy for this graph-to-graph distance, but the KLD doesn't respect the graph geometry — all edges are equally far apart. As a result, gradients flowing from the KLD can't give the model information as useful as the mythical EEEMD might — in particular, *which* way to tweak the edges or move the nodes.

## City-block EEEMD
Here's one very simple such metric satisfying the above. Let $e_{ij}$ denote the edge between $i$ and $j$. Then we define the distance as:

Neighboring edges have distance equal to the transport cost between nodes:
$$d(e_{ij}, e_{ik}) = d_{graph}(j,k)$$

By the triangle inequality, other edges have distance less than or equal to the distance obtained by moving one node to a neighboring node and then moving the other:
$$d(e_{ij},e_{kl}) \leq d(e_{ij}, e_{kj}) + d(e_{kj},e_{kl})$$

We'll go ahead and declare this a city-block style metric, in which the distance between two edges equals the cheaper of the two ways of moving one endpoint at a time:
$$d(e_{ij},e_{kl}) = \min\left(d(e_{ij}, e_{kj}) + d(e_{kj},e_{kl}),\; d(e_{ij}, e_{il}) + d(e_{il},e_{kl})\right)$$

This defines a distance metric between pairs of edges. Using this as a *ground* distance over the adjacency matrix, we define the City-block EEEMD as the EMD between the adjacency matrices under the city-block edge-to-edge ground distance.

In [None]:
from scipy.sparse import csr_array
def EEEMD_cityblock(A1, A2, Dgraph):
    """Return the indices of entries that are nonzero in the sum of two adjacency matrices.

    Work in progress: this is only the first step of the city-block EEEMD. The
    edge-to-edge ground distance and the EMD itself are not computed yet.

    Parameters
    ----------
    A1, A2 :
        Adjacency matrices of the two graphs; they are added elementwise, so
        they must have compatible shapes.
    Dgraph :
        Currently unused. Presumably the node-to-node graph distance matrix
        that the ground distance will be built from — TODO confirm.

    Returns
    -------
    numpy.ndarray
        Array with two rows: row indices on top, column indices below, one
        column per nonzero entry of ``A1 + A2``.
    """
    # Only entries present in at least one graph matter, so store the summed
    # adjacency sparsely and read off where it is nonzero.
    combined = csr_array(A1 + A2)
    rows, cols = combined.nonzero()
    return np.vstack((rows, cols))


In [None]:
# Test case: a cycle graph with 10 nodes, built with FRED's CycleGraph.
# NOTE(review): this comment previously said "branch graph", but the code constructs a CycleGraph.
# Presumably `A` is meant as an adjacency input for EEEMD_cityblock — TODO confirm what CycleGraph returns.
from FRED.graph_datasets import CycleGraph
A = CycleGraph(num_nodes=10)

ImportError: Numba needs NumPy 1.22 or less

In [None]:
from scipy.sparse import csr_array
# Seed the generator so the demo matrix is identical on every Restart & Run All.
rng = np.random.default_rng(0)
U = rng.random((10, 10))  # dense 10x10 matrix of uniform samples from [0, 1)
Usparse = csr_array(U)    # the same values stored in CSR sparse format

In [None]:
# Row and column index arrays of the nonzero entries of Usparse, returned as a tuple.
Usparse.nonzero()

(array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6,
        6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9], dtype=int32),
 array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,
        2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3,
        4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5,
        6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7,
        8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32))

In [None]:
# Stack the row and column index arrays into a single two-row integer array,
# the same form that EEEMD_cityblock returns.
np.vstack(Usparse.nonzero())

array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6,
        6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9],
       [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,
        2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3,
        4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5,
        6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7,
        8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], dtype=int32)