Skip to content

Commit

Permalink
change botteneck and update exmaple
Browse files Browse the repository at this point in the history
  • Loading branch information
YuanL12 committed Jan 14, 2023
1 parent b005ec2 commit e10d55b
Show file tree
Hide file tree
Showing 7 changed files with 344 additions and 184 deletions.
320 changes: 195 additions & 125 deletions docs/examples/PD_opt.ipynb

Large diffs are not rendered by default.

5 changes: 2 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,17 @@
# would reflect directly in your environment.

from setuptools import setup
import setuptools

with open('README.md') as f:
long_description = f.read()

setup(name='torch-tda',
version='0.0.1',
version='0.0.2',
description='Automatic differentiation for topological data analysis',
long_description=long_description,
long_description_content_type="text/markdown",
author='Brad Nelson, Yuan Luo',
author_email='bradnelson@uchicago.edu, yuanluo@uchicago.edu',
author_email='bradnelson@uchicago.edu, luoyuan9809@gmail.com',
url='https://github.com/CompTop/torch-tda',
project_urls={
"Documentation": "https://torch-tda.readthedocs.io/en/latest/",
Expand Down
36 changes: 31 additions & 5 deletions torch_tda/nn/diagram_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,46 @@
import torch.nn as nn
from .functional import BottleneckDistance, WassersteinDistance, BottleneckDistanceHera
from .poly_feat import remove_zero_bars
import numpy as np
# import hera_tda as hera
import hera
import torch

# `from . import` is called Intra-package References
# see in https://docs.python.org/3/tutorial/modules.html#intra-package-references

class BottleneckLayerHera(nn.Module):
def __init__(self):
super(BottleneckLayerHera, self).__init__()
self.D = BottleneckDistanceHera()
# self.D = BottleneckDistanceHera()

def forward(self, dgm0, dgm1):
dgm0 = remove_zero_bars(dgm0)
dgm1 = remove_zero_bars(dgm1)
def forward(self, dgm1, dgm2, zero_out = True):
print("new hera bottleneck layer")
if not zero_out:
dgm1 = remove_zero_bars(dgm1)
dgm2 = remove_zero_bars(dgm2)

# return self.D.apply(dgm0, dgm1)

d1 = dgm1.detach().numpy()
d2 = dgm2.detach().numpy()
# find the bottleneck distance and the maixmum edge (two points in R^2)
dist, edge = hera.bottleneck_dist(d1, d2, return_bottleneck_edge=True)

# change the data type
b = [edge[1].get_birth(), edge[1].get_death()]
a = [edge[0].get_birth(), edge[0].get_death()]
# find the index of persistence pair in origin input diagrams
idx1, idx2 = np.where((d1 == np.array(a)).all(axis=1)), np.where((d2 == np.array(b)).all(axis=1))


if dgm1[idx1].shape[0] == 0:
return (dgm2[idx2][0][1] - dgm2[idx2][0][0])/2
elif dgm2[idx2].shape[0] == 0:
return (dgm1[idx1][0][1] - dgm1[idx1][0][0])/2
else:
return torch.max(torch.abs(dgm1[idx1] - dgm2[idx2]))

return self.D.apply(dgm0, dgm1)

class BottleneckLayer(nn.Module):
def __init__(self):
Expand Down
127 changes: 90 additions & 37 deletions torch_tda/nn/functional/bottleneck.py
Original file line number Diff line number Diff line change
@@ -1,60 +1,113 @@
# Bottleneck distance

"""
Bottleneck distance
This is still under development as an open question:
Given two diagrams, it is possible that in the max edge (p1,p2) of the inf matching
where the bottleck distance is obtained, one point p1 happens to be a point on the diagonal and it is not a
'real' persistent pair (have a correspondence to the two birth and death simplices in the complex).
How can we deal with situation?
My tentative solution is to find the closet point on the diagonal to p2
to approximate the true bottleneck distance.
"""

import torch
from torch.autograd import Function
from persim import bottleneck
import numpy as np
import hera_tda as hera
# import hera_tda as hera
import hera

class BottleneckDistanceHera(Function):
"""
Compute bottleneck distance between two persistence diagrams

forward inputs:
dgm0 - N x 2 torch.float tensor of birth-death pairs
dgm1 - M x 2 torch.float tensor of birth-death pairs
def find_index_of_nearest(arr, pt):
distance = np.linalg.norm(arr - pt, ord=2, axis = 1)
return np.argmin(distance)

def seperate_zero_bars(dgm):
"""
@staticmethod
def forward(ctx, dgm0, dgm1):
ctx.dtype = dgm0.dtype
d0 = dgm0.detach().numpy()
remove zero bars from diagram
"""
inds = dgm[:,0] != dgm[:,1]
return dgm[inds,:], dgm[~inds,:]

def bott_dist_torch(in_dgm1, in_dgm2, zero_out = False):
# print("new hera bottleneck layer")
if not zero_out:
dgm1, zero_dgm1 = seperate_zero_bars(in_dgm1)
dgm2, zero_dgm2 = seperate_zero_bars(in_dgm2)


d1 = dgm1.detach().numpy()
n0 = len(dgm0)
ctx.n0 = n0
n1 = len(dgm1)
ctx.n1 = n1
d2 = dgm2.detach().numpy()
# find the bottleneck distance and the maixmum edge (two points in R^2)
dist, edge = hera.bottleneck_dist(d1, d2, return_bottleneck_edge=True)

# change the data type
b = [edge[1].get_birth(), edge[1].get_death()]
a = [edge[0].get_birth(), edge[0].get_death()]
# find the index of persistence pair in origin input diagrams
idx1 = np.where((d1 == np.array(a)).all(axis=1))
idx2 = np.where((d2 == np.array(b)).all(axis=1))

# Assume at least one point is off-diagonal as both diagonal situtation is rare
if dgm1[idx1].shape[0] == 0:
# one point on the diagonal dgm1 is matched to a point in dgm2
if dgm2.requires_grad: # do not bother to modify dgm1
return (dgm2[idx2][0][1] - dgm2[idx2][0][0])/2
else: # now we do not have
# print("undefined matching")
closet_idx = find_index_of_nearest(zero_dgm1.detach().numpy(), d2[idx2][0])
# need to find the closet pt in Diag of dgm1
return torch.max(torch.abs(dgm2[idx2] - zero_dgm1[closet_idx]))
elif dgm2[idx2].shape[0] == 0:
if dgm1.requires_grad:
return (dgm1[idx1][1] - dgm1[idx1][0])/2
else:
# print("undefined matching")
closet_idx = find_index_of_nearest(zero_dgm2.detach().numpy(), d1[idx1][0])
# need to find the closet pt in Diag of dgm1
return torch.max(torch.abs(dgm1[idx1] - zero_dgm2[closet_idx])), dist
else:
return torch.max(torch.abs(dgm1[idx1] - dgm2[idx2]))

dist, match = hera.bottleneck.BottleneckDistance(d0, d1)
i0, i1 = match

# TODO check for -1 as index

ctx.i0 = i0
ctx.i1 = i1
class BottleneckDistanceHera(Function):
"""
Compute bottleneck distance between two persistence diagrams
d01 = torch.tensor(d0[i0] - d1[i1], dtype=ctx.dtype)
ctx.d01 = d01
dist01 = np.linalg.norm(d0[i0] - d1[i1], np.inf)
ctx.indmax = np.argmax(np.abs(d0[i0] - d1[i1]))
TODO: This torch function is problematic, should use the above bott_dist_torch() function
return torch.tensor(dist01, dtype=ctx.dtype)
forward inputs:
dgm1 - N x 2 torch.float tensor of birth-death pairs
dgm2 - M x 2 torch.float tensor of birth-death pairs
"""

@staticmethod
def backward(ctx, grad_dist):
n0 = ctx.n0
n1 = ctx.n1
i0 = ctx.i0
i1 = ctx.i1
d01 = ctx.d01
def forward(ctx, dgm1, dgm2):
d1 = dgm1.detach().numpy()
d2 = dgm2.detach().numpy()
# find the bottleneck distance and the maixmum edge (two points in R^2)
dist, edge = hera.bottleneck_dist(d1, d2, return_bottleneck_edge=True)

# change the data type
b = [edge[1].get_birth(), edge[1].get_death()]
a = [edge[0].get_birth(), edge[0].get_death()]
# find the index of persistence pair in origin input diagrams
idx1, idx2 = np.where((d1 == np.array(a)).all(axis=1)), np.where((d2 == np.array(b)).all(axis=1))

gd0 = torch.zeros(n0, 2, dtype=ctx.dtype)
gd1 = torch.zeros(n1, 2, dtype=ctx.dtype)

if dgm1[idx1].shape[0] == 0:
return (dgm2[idx2][0][1] - dgm2[idx2][0][0])/2
elif dgm2[idx2].shape[0] == 0:
return (dgm1[idx1][0][1] - dgm1[idx1][0][0])/2
else:
return torch.max(torch.abs(dgm1[idx1] - dgm2[idx2]))

gd0[i0, ctx.indmax] = np.sign(d01[ctx.indmax]) * grad_dist
gd1[i1, ctx.indmax] = -np.sign(d01[ctx.indmax]) * grad_dist

return gd0, gd1


class BottleneckDistance(Function):
Expand Down
2 changes: 1 addition & 1 deletion torch_tda/nn/functional/rips.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def sparse_pairwise_dist(D, eps = 0.15, dense_output=False):

class RipsDiagram(Function):
"""
This can be uncessary because we can do auto-diff by computing Diagram direcely from input point set matrix
(Outdated) we can do auto-diff by computing Diagram direcely from input point set matrix
Compute Rips complex persistence using point coordinates
forward inputs:
Expand Down
11 changes: 5 additions & 6 deletions torch_tda/nn/poly_feat.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,21 +131,20 @@ class BarcodePolyFeature(nn.Module):
sum length^p * mean^q
over lengths and means of barcode
parameters:
dim - homology dimension to work over
p - exponent for lengths
q - exponent for means
remove_zero = Flag to remove zero-length bars (default=True)
"""
def __init__(self, dim, p, q, remove_zero=True):
def __init__(self, p, q, remove_zero=True):
super(BarcodePolyFeature, self).__init__()
self.dim = dim
# self.dim = dim
self.p = p
self.q = q
self.remove_zero = remove_zero

def forward(self, dgms):
issublevel = True
dgm = dgms[self.dim]
def forward(self, dgm, issublevel = True):
# dgm: torch tensor of shape (n,2)
# dgm = dgms[self.dim]
if self.remove_zero:
dgm = remove_zero_bars(dgm)
lengths, means = get_barcode_lengths_means(dgm, issublevel)
Expand Down
27 changes: 20 additions & 7 deletions torch_tda/nn/rips.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,17 @@ def RCC_to_persistence_vertex_indices(R, scplex, imap):
imap (list of length complex's highest dimension):
inverse map that is able to map the index of simplex to its edge index
Output:
- First the regular persistence pairs of dimension 0, with one vertex for birth and two for death;
- then the other regular persistence pairs, grouped by dimension, with 2 vertices per extremity;
- then the connected components, with one vertex each;
- finally the other essential features, grouped by dimension, with 2 vertices for birth.
Return type: Tuple[
- numpy.array[int] of shape (n,3),
- List[numpy.array[int] of shape (m,4)], where the i-th element is the array of persistence pairs at dimension i+1
- numpy.array[int] of shape (l,),
- List[numpy.array[int] of shape (k,2)]
'''
def find_vet_idx(a):
"""Find indices vertices of a given edge index, a of shape (1,) """
Expand Down Expand Up @@ -62,7 +73,7 @@ def find_vet_idx(a):
class RipsLayer(nn.Module):
"""
Define a Rips persistence layer that will use the Rips Diagram function.
Here we return the all essential and regular persistence pairs
Here we return the all essential(infinite death) and regular(finite death) persistence pairs
we leave users to decide if they want to use essential pairs or zero-length bars
in practice
Input:
Expand All @@ -76,23 +87,25 @@ class RipsLayer(nn.Module):
https://bats-tda.readthedocs.io/en/latest/tutorials/Rips.html#Algorithm-optimization
Output:
dgms : list of length `maxdim`, where each element is an numpy array of shape (n,2)
note: infinite death == float('inf')
bdinds : list of length `maxdim`, where each element is an numpy array of shape (n,2)
note: infinite death index == -1
persistence_dgs: list (length 4) of persistence diagrams in each dimension. We separate them by dimension
and regular/essential, which are 0-dim regular pairs, 1-dim regular pairs,
0-dim essential pair(only one in the case Rips) and
Return type: List[
- tensor[float] of shape (n-1,2), where n is the number of points
- List[tensor[float] of shape (m,2)]
- tensor[float] of shape (1,1)
- List[tensor[float] of shape (k,)]
"""
def __init__(self, maxdim = 0, degree = -1, metric = 'euclidean', sparse = False, eps=0.5, reduction_flags=()):
super(RipsLayer, self).__init__()
self.maxdim = maxdim
self.degree = degree
self.sparse = sparse
self.eps = eps
# self.PD = RipsDiagram()
self.metric = metric
self.reduction_flags = reduction_flags

def forward(self, X):
# dgms = self.PD.apply(x, self.maxdim, self.degree, self.metric , self.sparse, self.eps, *self.reduction_flags)
# change dgms to make it able auto-diff
Xnp = X.cpu().detach().numpy() # convert to numpy array
# Xnp.astype('double')
Expand Down

0 comments on commit e10d55b

Please sign in to comment.