In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as functional
import torch.optim as optim
from homology import rips
from math import prod
import numpy as np
from util import create_distance_matrix


class TopologicalAutoEncoder(nn.Module):
    def __init__(self, **kwargs):
        # Input_shape is tuple with: (Amount of simplices, input dimensions)
        # Latent_space_size is: latent dimensions

        assert kwargs["input_shape"] is not None, "parameter input_shape required"
        assert kwargs["latent_space_size"] is not None, "parameter latent_space_size required"

        super().__init__()
        self.act = nn.ReLU()

        intermediate_layer_dimensions = (kwargs["input_shape"][1] + kwargs["latent_space_size"]) // 2

        self.lin_1 = nn.Linear(
            in_features=prod(kwargs["input_shape"]),
            out_features=kwargs["input_shape"][0] * intermediate_layer_dimensions
        )

        self.lin_2 = nn.Linear(
            in_features=kwargs["input_shape"][0] * intermediate_layer_dimensions,
            out_features=kwargs["input_shape"][0] * kwargs["latent_space_size"]
        )

        self.lin_3 = nn.Linear(
            in_features=kwargs["input_shape"][0] * kwargs["latent_space_size"],
            out_features=kwargs["input_shape"][0] * intermediate_layer_dimensions
        )

        self.lin_4 = nn.Linear(
            in_features=kwargs["input_shape"][0] * intermediate_layer_dimensions,
            out_features=prod(kwargs["input_shape"])
        )

    def forward(self, x):
        out = self.act(self.lin_1(x))
        latent = self.act(self.lin_2(out))
        out = self.act(self.lin_3(latent))
        out = self.act(self.lin_4(out))
        return out, latent


class TopAELoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.relevant_input = None
        self.relevant_latent = None
        self.A_input = None
        self.A_latent = None

    def forward(self, input, latent, output, point_count):
        # assert(dimensions == 1, "This implementation only supports 1 dimensional homology")

        loss = functional.mse_loss(input, output)

        # flattened layers into dimensional vectors
        # Select "important" simplices
        # print(input.reshape((point_count, input.shape[0] // point_count)))
        A_input = create_distance_matrix(input.reshape((point_count, input.shape[0] // point_count)))
        A_latent = create_distance_matrix(latent.reshape((point_count, latent.shape[0] // point_count)))

        relevant_input = self.__relevant_points(rips(A_input, dimensions=1)[0])
        relevant_latent = self.__relevant_points(rips(A_latent, dimensions=1)[0])

        loss_x = .5 * torch.norm(self.__relevant_distances(relevant_input, A_input) - self.__relevant_distances(relevant_input, A_latent))
        loss_z = .5 * torch.norm(self.__relevant_distances(relevant_latent, A_latent) - self.__relevant_distances(relevant_latent, A_input))

        print(loss_z)
        print(loss_x)
        print(loss)
        loss += loss_x + loss_z

        return loss

    def __relevant_distances(self, relevent_indices, distance_matrix):
        return torch.tensor([distance_matrix[i[0], i[1]] for i in relevent_indices])

    def __relevant_points(self, diagram):
        diagram = torch.from_numpy(diagram)
        non_zeros = np.count_nonzero(diagram, axis=0)
        simplex_indices = torch.where(diagram.T[np.where(non_zeros == 2)] == 1)[1]
        index_tuples = simplex_indices.reshape((len(simplex_indices) // 2, 2))
        return index_tuples


In [16]:
data = torch.flatten(torch.rand(5, 3))

torch.manual_seed(42)
ae = TopologicalAutoEncoder(input_shape=(5, 3), latent_space_size=2)
ae.train()

optimizer = optim.SGD(ae.parameters(), lr=0.01, momentum=0.9)
out, latent = ae(data)

# print(out.shape)
# print(data.shape)

loss_fn = TopAELoss()
loss = loss_fn(data, latent, out, 5)
print(loss.grad)
print(loss.data)
loss.backward()

print(data)
print()




tensor(1.9477)
tensor(1.9477)
tensor(0.3396, grad_fn=<MseLossBackward0>)
None
tensor(4.2351)
tensor([0.5475, 0.7896, 0.8881, 0.9037, 0.3273, 0.3882, 0.7410, 0.3636, 0.7341,
        0.3908, 0.1609, 0.7035, 0.5767, 0.7229, 0.9967])

