In [1]:
import torch
import gudhi
import numpy as np

In [2]:
class SFTDLossGudhi:
    """Persistence-barcode loss on a cubical complex built from two fields.

    A ``(3, *F1.shape)`` filtration is stacked from ``min(F1, G1)``, ``F1``
    and the global minimum of ``min(F1, G1)`` broadcast to ``F1.shape``; its
    persistent homology (Z/2 coefficients) is computed with gudhi. For every
    requested homology dimension the finite persistence pairs are mapped back
    to (birth, death) values *of the stacked tensor*, so the loss stays
    differentiable w.r.t. ``F1`` / ``G1``. The loss is the sum of
    ``(death - birth) ** p`` over the ``card`` longest bars per dimension.

    Args:
        dims: homology dimension(s) to include - a single int or an iterable
            of ints.
        card: maximum number of (longest) bars per dimension that contribute.
        p: exponent applied to each bar length.
        min_barcode: bars with ``death - birth`` not strictly greater than
            this are discarded (default ``0.``).
    """

    def __init__(self, dims = 1, card = 100, p = 1, min_barcode = 0.):
        # A bare int used to be stored as-is and then crashed in
        # ``for dim in self.dims`` - the default ``dims = 1`` never worked.
        self.dims = [int(dims)] if isinstance(dims, (int, np.integer)) else list(dims)
        self.card = card
        self.p = p
        self.min_barcode = min_barcode

    def __call__(self, F1, G1):
        """Compute the loss for two same-shaped fields ``F1`` and ``G1``.

        Side effects: sets ``self.cubical_complex`` (the gudhi complex) and
        ``self.barcodes`` (``{dim: [(birth, death), ...]}`` with 0-d tensor
        entries, in the order gudhi reports the pairs).

        Returns:
            A scalar tensor, or ``0.`` if no requested dimension has pairs.
        """
        D = self._stack_filtration(F1, G1)

        # gudhi needs a CPU numpy copy; gradients flow through D / D_flat instead.
        cubical_complex = gudhi.CubicalComplex(vertices = D.detach().cpu().numpy())
        cubical_complex.compute_persistence(homology_coeff_field = 2, min_persistence = 0.0)
        self.cubical_complex = cubical_complex

        # Element [0]: per-dimension arrays of (birth-vertex, death-vertex)
        # index pairs. Only these finite pairs are used by this loss.
        regular_pairs = cubical_complex.vertices_of_persistence_pairs()[0]
        D_flat = self._fortran_flatten(D)

        loss = 0.
        self.barcodes = {}
        for dim in self.dims:
            self.barcodes[dim] = []
            if not regular_pairs or len(regular_pairs) <= dim:
                continue
            births, deaths = self._gather_bars(D_flat, regular_pairs[dim], self.min_barcode)
            self.barcodes[dim] = list(zip(births, deaths))
            loss += self._top_bars_loss(deaths - births, self.card, self.p)
        return loss

    @staticmethod
    def _stack_filtration(F1, G1):
        """Return ``stack([min(F1, G1), F1, global min of min(F1, G1)])``.

        ``torch.stack`` keeps the inputs' dtype and device (a preallocated
        ``torch.zeros`` buffer would silently cast to CPU float32).
        """
        lower = torch.min(F1, G1)
        return torch.stack([lower, F1, lower.min().expand(F1.shape)])

    @staticmethod
    def _fortran_flatten(D):
        """Flatten ``D`` in column-major (Fortran) order.

        NOTE(review): gudhi's vertex indices are used to index this flat view,
        i.e. this assumes gudhi reports them in Fortran order for this input.
        """
        return D.permute(*range(D.ndim - 1, -1, -1)).reshape(-1)

    @staticmethod
    def _gather_bars(D_flat, pairs, min_barcode = 0.):
        """Look up (birth, death) values for index pairs and drop short bars.

        Args:
            D_flat: 1-D tensor of filtration values.
            pairs: array-like of shape (k, 2) holding (birth, death) indices
                into ``D_flat``; an empty array is allowed.
            min_barcode: keep only bars with ``death - birth > min_barcode``.

        Returns:
            ``(births, deaths)`` 1-D tensors, in the order of ``pairs``.
        """
        idx = torch.as_tensor(np.asarray(pairs, dtype = np.int64), device = D_flat.device).reshape(-1, 2)
        births = D_flat[idx[:, 0]]
        deaths = D_flat[idx[:, 1]]
        keep = deaths - births > min_barcode
        return births[keep], deaths[keep]

    @staticmethod
    def _top_bars_loss(lengths, card, p):
        """Sum ``length ** p`` over the ``card`` largest entries of ``lengths``."""
        longest = torch.sort(lengths, descending = True).values[:card]
        return (longest ** p).sum()

In [3]:
def cmp(a, b, tol = 1e-10):
    """Return True when ``a`` and ``b`` differ by more than ``tol``.

    Note the polarity: despite the name, this is a *mismatch* predicate -
    True means "not equal within the absolute tolerance ``tol``".
    """
    return abs(a - b) > tol

In [4]:
# Cross-check: the barcodes SFTDLossGudhi recovers (via vertex indices into the
# flattened filtration) must match gudhi's own finite persistence intervals.
torch.manual_seed(0)  # reproducible random fields

N_TRIALS = 100
DIMS = [0, 1, 2]
mismatches = []

for trial in range(N_TRIALS):
    F1 = torch.rand((64, 64))
    G1 = torch.rand((64, 64))

    loss_fn = SFTDLossGudhi(dims = DIMS, card = 100, p = 2)
    loss_fn(F1, G1)  # populates loss_fn.cubical_complex and loss_fn.barcodes

    for dim in DIMS:
        # Reference intervals from gudhi; infinite (essential) intervals are
        # dropped because the loss only uses finite pairs.
        expected = sorted(
            (float(birth), float(death))
            for birth, death in loss_fn.cubical_complex.persistence_intervals_in_dimension(dim)
            if death < np.inf
        )
        actual = sorted((birth.item(), death.item()) for birth, death in loss_fn.barcodes[dim])

        # zip() alone would silently ignore missing or extra bars.
        if len(expected) != len(actual):
            mismatches.append((trial, dim, 'count', len(expected), len(actual)))
            continue
        for elem1, elem2 in zip(expected, actual):
            if cmp(elem1[0], elem2[0]) or cmp(elem1[1], elem2[1]):
                mismatches.append((trial, dim, elem1, elem2))

assert not mismatches, f'{len(mismatches)} mismatches, first: {mismatches[0]}'
print('OK!')

OK!
