# Toy topology and signals dataset

In [1]:
import numpy as np 
from itertools import combinations
from tqdm import tqdm

In [2]:
# Let's generate a toy topology for our example

nodes = [i for i in range(7)]
edges = [
    (0,1),
    (0,3),
    (0,6),
    (1,2),
    (1,5),
    (2,4),
    (4,6),
    (5,6)
]

V = 7
E = len(edges)

d = 100                                          # Node and edges stalks dimension

F = {
    e:{
        e[0]:np.random.randn(d,d),
        e[1]:np.random.randn(d,d)
        } 
        for e in edges
    }                                           # Incidency linear maps

# Sheaf representation 

# Coboundary map

B = np.zeros((d*E, d*V))

for i in range(len(edges)):
    edge = edges[i]

    u = edge[0] 
    v = edge[1] 

    B_u = F[edge][u]
    B_v = F[edge][v]

    B[i*d:(i+1)*d, u*d:(u+1)*d] = B_u
    B[i*d:(i+1)*d, v*d:(v+1)*d] = - B_v

# Sheaf Laplacian

L_f = B.T @ B

In [3]:
N = 100
X = np.random.randn(V*d,N)

In [4]:
Lambda, U = np.linalg.eig(L_f)
H = 1/(1 + 10*Lambda)

In [5]:
Y = U @ np.diag(H) @ U.T @ X

Y += np.random.normal(0, 10e-2, size=Y.shape)

# Imposed dictionary for learned sparsity

In [6]:
def proj(D):

    return np.apply_along_axis(lambda d: d/np.max([1, np.linalg.norm(d)]), axis=0, arr=D)

In [7]:
M = 100                      # Number of atoms
D = np.random.randn(d, M)    # Dictionary of random atoms

D = proj(D)

In [8]:
# Subroutine for learned sparsity - OTM

def ortho_match_pursuit(y, A, max_iter = 10, eps = 1e-2):
    S = []

    x = np.zeros(A.shape[1])
    iters = 0
    rate = 0
    R = y

    while iters < max_iter or rate > eps:

        # Retrieve the maximum correlation between atoms and residuals of the previous iteration
        S.append(np.argmax(np.abs(A.T @ R)))
        # Expand the dictionary for the representation
        dic = A[:,S]

        # Solve subproblems and update x
        x[S] = np.linalg.inv(dic.T @ dic) @ dic.T @ y
        
        # Update the residuals
        R = y - A @ x
        iters += 1

    return x

In [9]:
S = {node:None for node in nodes}

In [12]:
for node in nodes:
    S_ = np.zeros((M,N))
    y = Y[node*d:(node+1)*d,:]
    for n in tqdm(range(N)):
        S_[:,n] = ortho_match_pursuit(y[:,n],D,M)
    S[node] = S_

100%|██████████| 100/100 [00:18<00:00,  5.41it/s]
100%|██████████| 100/100 [00:18<00:00,  5.45it/s]
100%|██████████| 100/100 [00:18<00:00,  5.34it/s]
100%|██████████| 100/100 [00:18<00:00,  5.41it/s]
100%|██████████| 100/100 [00:18<00:00,  5.40it/s]
100%|██████████| 100/100 [00:18<00:00,  5.47it/s]
100%|██████████| 100/100 [00:18<00:00,  5.44it/s]


In [13]:
def premultiplier(Xu, Xv):
    uu = np.linalg.pinv(Xu @ Xu.T)
    uv = Xu @ Xv.T
    vv = np.linalg.pinv(Xv @ Xv.T)
    vu = Xv @ Xu.T

    return (uu, uv, vv, vu)

def chi_u(uu, uv, vv, vu):

    return ((uu @ uv - np.eye(uu.shape[0])) @ vv @ np.linalg.pinv(vu @ uu @ uv @ vv - np.eye(uu.shape[0])) @ vu - np.eye(uu.shape[0])) @ uu

def chi_v(uu, uv, vv, vu):

    return (uu @ uv - np.eye(uu.shape[0])) @ vv @ np.linalg.pinv(vu @ uu @ uv @ vv - np.eye(uu.shape[0]))

In [14]:
T = 0

H = {
    edge : {
        edge[0] : None,
        edge[1] : None
    }
for edge in combinations(nodes, 2)
}

In [15]:
for e in tqdm(combinations(nodes,2)):
    u = e[0]
    v = e[1]

    X_u = S[u]
    X_v = S[v]
    uu, uv, vv, vu = premultiplier(X_u, X_v)

    H[e][u] = chi_u(uu, uv, vv, vu)
    H[e][v] = chi_u(uu, uv, vv, vu)
    
    T += np.trace(H[e][u]) + np.trace(H[e][v])

21it [00:01, 20.45it/s]


In [16]:
T

5548209873579350.0

In [17]:
mu = 50

In [18]:
H = {
    edge : {
        edge[0] : mu/T * (H[edge][edge[0]]),
        edge[1] : mu/T * (H[edge][edge[1]])
    }
for edge in combinations(nodes, 2)
}

In [19]:
all_edges = list(combinations(range(V), 2))

energies = {
    e : 0
    for e in all_edges
    }

for e in (all_edges):
    u = e[0]
    v = e[1]
    
    '''
    X_ = np.zeros_like(X)
    X_[u*d:(u+1)*d,:] = X[u*d:(u+1)*d,:]
    X_[v*d:(v+1)*d,:] = X[v*d:(v+1)*d,:]
    energies[e] = np.linalg.norm(BB @ X_)
    '''

    energies[e] = np.linalg.norm(H[e][e[0]] @ S[e[0]] - H[e][e[1]] @ S[e[1]])

In [20]:
retrieved = sorted(energies.items(), key=lambda x:x[1])[:E]

In [21]:
len(set(list(map(lambda x: x[0], retrieved))).intersection(set(edges))) / E

0.375

In [22]:
retrieved

[((3, 4), 1.4243393477821986),
 ((3, 6), 1.868983475544535),
 ((5, 6), 2.519860649506638),
 ((0, 6), 2.7372507561512043),
 ((2, 6), 4.482251222636327),
 ((0, 1), 4.872186070771289),
 ((0, 5), 5.98541236642465),
 ((0, 2), 7.197714531956308)]