In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import LambdaLR

import numpy as np
import gudhi
from gudhi.wasserstein import wasserstein_distance

import matplotlib.pyplot as plt
from IPython.display import clear_output
import ot
import time

In [None]:
# Define constants
num_train_samples = 1   # always 1 because we are using iterative learning
cloud = 500             # points in point cloud
pts = 3                 # xyz coordinates
channels = 1            # no color channels

output_classes = cloud*pts

# Seed torch and numpy so weight initialisation (and therefore every result
# below) is reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# load input and target models which have already been normalized
loaded_array = np.load("npy/input_norm.npy")
loaded_array2 = np.load("npy/target_norm.npy")

# The model reshapes to rows of 3 and the plots read columns 0..2, so both
# clouds must be (n, 3) xyz arrays.
assert loaded_array.ndim == 2 and loaded_array.shape[1] == 3, f"input cloud has shape {loaded_array.shape}, expected (n, 3)"
assert loaded_array2.ndim == 2 and loaded_array2.shape[1] == 3, f"target cloud has shape {loaded_array2.shape}, expected (n, 3)"

# Convert data to PyTorch tensors
X_train_tensor = torch.FloatTensor(loaded_array)
targets_train_tensor = torch.FloatTensor(loaded_array2)

In [None]:
class SharedMLP(nn.Module):
    """Small MLP applied independently to each row of its input.

    Layer sizes: in_channels -> 64 -> 32 -> out_channels, with ReLU after
    the two hidden layers and no activation on the output layer.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.fc1 = nn.Linear(in_channels, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, out_channels)

    def forward(self, x):
        # Hidden layers share the same Linear -> ReLU pattern
        for hidden_layer in (self.fc1, self.fc2):
            x = F.relu(hidden_layer(x))
        # Output layer stays linear so coordinates can take any sign
        return self.fc3(x)

class PointNet(nn.Module):
    """Maps every xyz point independently through a SharedMLP(3, 3).

    The input is viewed as rows of 3 values, so the output has shape (n, 3)
    where n is the number of input elements divided by 3.
    """

    def __init__(self):
        super().__init__()
        self.shared_mlp = SharedMLP(3, 3)

    def forward(self, x):
        # Flatten to one point per row, transform each point, keep (n, 3) layout
        points = x.view(-1, 3)
        return self.shared_mlp(points).view(-1, 3)

# Instantiate the per-point network and an Adam optimizer over its weights.
model = PointNet()
optimizer = optim.Adam(params=model.parameters(), lr=1e-2)

In [None]:
def preventCollapse(inputs, points, targets, epoch, delta=0.2, epsilon=1e-6, alpha=0.3, beta=0.7):
    """Repulsion penalty that discourages predicted points from bunching together.

    For every ordered pair of distinct points (i != j) whose distance is below
    ``delta``, the term ``1 / (distance + epsilon)`` is accumulated; the loss is
    the mean of those terms.

    Args:
        inputs: unused; kept so every loss shares the same call signature.
        points: predicted point cloud, rows are points (passed to ``torch.cdist``).
        targets: unused; kept so every loss shares the same call signature.
        epoch: current epoch; the loss is printed every 10 epochs.
        delta: distance threshold below which a pair is penalised.
        epsilon: stabiliser added to each distance and to the pair count, so
            coincident points (distance 0) or an empty set of close pairs
            never cause a division by zero.
        alpha, beta: unused; kept for backward compatibility.

    Returns:
        Scalar tensor; 0 when no pair of distinct points is closer than ``delta``.
    """
    pairwise_distances = torch.cdist(points, points)

    # Close pairs only, excluding each point's zero distance to itself
    mask = pairwise_distances < delta
    mask.fill_diagonal_(False)
    num_terms = mask.sum()

    # epsilon keeps duplicated points from producing an infinite penalty
    sum_inv_distances = torch.sum(1.0 / (pairwise_distances[mask] + epsilon))

    loss = sum_inv_distances / (num_terms + epsilon)

    if epoch % 10 == 0:
        print('COLLAPSE LOSS\t:', loss)

    return loss

In [None]:
def geoloss(inpts, pts, pts2, idx):
    """Symmetric Chamfer distance between the predicted points and the input cloud.

    Args:
        inpts: input point cloud, shape (m, d).
        pts: predicted point cloud, shape (n, d).
        pts2: unused; kept so every loss shares the same call signature.
        idx: current epoch; the loss is printed every 10 epochs.

    Returns:
        Scalar tensor: mean nearest-neighbour distance pts -> inpts plus the
        mean nearest-neighbour distance inpts -> pts.
    """
    # One (n, m) distance matrix serves both directions: row minima give
    # pts -> inpts and column minima give inpts -> pts.
    distances = torch.cdist(pts, inpts, p=2)
    min_pts_to_inputs = distances.min(dim=1).values  # (n,)
    min_inputs_to_pts = distances.min(dim=0).values  # (m,)

    chamfer_dist = min_pts_to_inputs.mean() + min_inputs_to_pts.mean()

    if idx % 10 == 0:
        print('GEOMETRIC LOSS\t:', chamfer_dist)
    return chamfer_dist


In [None]:
def _dim0_diagram(points, gens0):
    """Differentiable finite dimension-0 diagram from generator rows (vertex, edge_v1, edge_v2).

    In a Rips filtration every component is born at 0 and dies at the length
    of the edge (edge_v1, edge_v2) that merges it.
    """
    deaths = torch.norm(points[gens0[:, 1]] - points[gens0[:, 2]], dim=-1)
    return torch.stack((torch.zeros_like(deaths), deaths), dim=1)


def _dim1_diagram(points, gens1):
    """Differentiable finite dimension-1 diagram from generator rows (b1, b2, d1, d2).

    Birth is the length of edge (b1, b2), death the length of edge (d1, d2) —
    the same values as the finite part of persistence_intervals_in_dimension(1).
    """
    return torch.norm(points[gens1[:, (0, 2)]] - points[gens1[:, (1, 3)]], dim=-1)


def _plot_critical_points(points, gens1):
    """Scatter the cloud with the vertices of dimension-1 birth edges (red) and death edges (yellow)."""
    coords = points.detach().cpu().numpy()
    gen_idx = gens1.cpu().numpy()
    # Columns of a dimension-1 generator row: birth edge (0, 1), death edge (2, 3)
    birth_vertices = coords[gen_idx[:, [0, 1]].ravel()]
    death_vertices = coords[gen_idx[:, [2, 3]].ravel()]

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(coords[:, 0], coords[:, 1], coords[:, 2])
    ax.scatter(death_vertices[:, 0], death_vertices[:, 1], death_vertices[:, 2], color='yellow', s=20)
    ax.scatter(birth_vertices[:, 0], birth_vertices[:, 1], birth_vertices[:, 2], color='red', s=20)

    # Set labels and title for the plot
    ax.set_xlabel('X Label')
    ax.set_ylabel('Y Label')
    ax.set_zlabel('Z Label')
    ax.set_title('3D Scatter Plot of 3D Points', color="black")
    plt.show()


def topoloss(inpts, pts, pts2, idx):
    """Differentiable topological loss between the predicted and target point clouds.

    gudhi computes Rips persistence on the coordinate values; the finite
    diagrams are then rebuilt by indexing the torch tensors with the flag
    persistence generators so gradients flow back to ``pts``. Dimension-1
    diagrams are compared when both clouds have dimension-1 generators,
    otherwise dimension-0 diagrams are compared.

    Args:
        inpts: unused; kept so every loss shares the same call signature.
        pts: predicted point cloud tensor (rows are xyz points).
        pts2: target point cloud tensor (rows are xyz points).
        idx: current epoch; sets the minimum persistence and triggers the
            every-10-epoch log and plot.

    Returns:
        Order-2 Wasserstein distance between the two diagrams (a tensor
        supporting autograd).
    """
    # Max Rips edge length, suitable for the normalized point clouds
    max_length_rips = 0.3
    # Keep a low minimum persistence early so the prediction's diagram contains
    # points, then raise it after epoch 40 to match the target's threshold.
    min_per = 0.01 if idx <= 40 else 0.04

    # gudhi only needs coordinate values: detached CPU copies also work when
    # training on CUDA. Gradients come from re-indexing pts below.
    rips = gudhi.RipsComplex(points=pts.detach().cpu().numpy(), max_edge_length=max_length_rips)
    rips2 = gudhi.RipsComplex(points=pts2.detach().cpu().numpy(), max_edge_length=max_length_rips)

    # Create simplex tree from rips complex and compute persistence
    st = rips.create_simplex_tree(max_dimension=2)
    st.compute_persistence(min_persistence=min_per)
    st2 = rips2.create_simplex_tree(max_dimension=2)
    st2.compute_persistence(min_persistence=.04)

    # Vertex indices of the edges that create / destroy each persistence pair
    gens = st.flag_persistence_generators()
    gens2 = st2.flag_persistence_generators()

    # It is better to avoid this condition completely: fall back to dimension 0
    if len(gens[1]) == 0 or len(gens2[1]) == 0:
        i0 = torch.as_tensor(gens[0], dtype=torch.long, device=pts.device).reshape(-1, 3)
        j0 = torch.as_tensor(gens2[0], dtype=torch.long, device=pts2.device).reshape(-1, 3)
        diag1 = _dim0_diagram(pts, i0)
        diag2 = _dim0_diagram(pts2, j0)
        return wasserstein_distance(diag1, diag2, order=2, keep_essential_parts=False, enable_autodiff=True)

    i1 = torch.as_tensor(gens[1][0], dtype=torch.long, device=pts.device).reshape(-1, 4)
    j1 = torch.as_tensor(gens2[1][0], dtype=torch.long, device=pts2.device).reshape(-1, 4)

    diag1 = _dim1_diagram(pts, i1)
    diag2 = _dim1_diagram(pts2, j1)

    perstot1 = wasserstein_distance(diag1, diag2, order=2, keep_essential_parts=False, enable_autodiff=True)

    if idx % 10 == 0:
        print('TOPO LOSS\t:', perstot1)
        _plot_critical_points(pts, i1)

    return perstot1

In [None]:
# Move the model and both point clouds to the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
inputs = X_train_tensor.to(device)
targets = targets_train_tensor.to(device)

num_epochs = 1000

# Loss weights (constant for the whole run): geometry keeps the output close
# to the input, topology pulls it toward the target's persistence diagram,
# and collapse keeps points from bunching up.
lambda_geo = 0.6
lambda_topo = 0.3
lambda_collapse = 0.1
# Distance below which preventCollapse penalises a pair of points.
collapse_delta = 0.05

# The optimizer's learning rate stays fixed at the value it was created with;
# no scheduler is applied.
for epoch in range(num_epochs):
    if epoch % 10 == 0:
        clear_output(wait=True)
        print("EPOCH:", epoch)

    # Forward pass
    outputs = model(inputs)
    optimizer.zero_grad()

    loss = (
        lambda_topo * topoloss(inputs, outputs, targets, epoch)
        + lambda_geo * geoloss(inputs, outputs, targets, epoch)
        + lambda_collapse * preventCollapse(inputs, outputs, targets, epoch, collapse_delta)
    )

    if epoch % 10 == 0:
        print("Total Loss\t:", loss)

    loss.backward()
    optimizer.step()

print("Training Complete")


In [None]:
# Final Output Model
# Run the trained model without building an autograd graph, then copy the
# result to the CPU (.numpy() fails on CUDA tensors).
with torch.no_grad():
    outputs = model(inputs)
P = outputs.cpu().numpy()
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(P[:, 0], P[:, 1], P[:, 2])
ax.set_xlabel('X Label')
ax.set_ylabel('Y Label')
ax.set_zlabel('Z Label')
ax.set_title('3D Scatter Plot of 3D Points')
plt.show()

In [None]:
# Save npy and xyz file for visualization
# Convert once to a CPU numpy array: .numpy() fails on CUDA tensors, and
# f-formatting torch scalars would write "tensor(...)" text instead of numbers.
output_points = outputs.detach().cpu().numpy()
np.save('output_3d.npy', output_points)
xyz_file = 'output_3d.xyz'
with open(xyz_file, 'w') as f:
    # Write each point to the .xyz file as an "x y z" line
    for x, y, z in output_points:
        f.write(f"{x} {y} {z}\n")

In [None]:
# Input model
# .cpu() is required before .numpy() when the tensors live on CUDA
P = inputs.detach().cpu().numpy()
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(P[:, 0], P[:, 1], P[:, 2])
ax.set_xlabel('X Label')
ax.set_ylabel('Y Label')
ax.set_zlabel('Z Label')
ax.set_title('3D Scatter Plot of 3D Points')
plt.show()

In [None]:
# Target topology model
# .cpu() is required before .numpy() when the tensors live on CUDA
P = targets.detach().cpu().numpy()
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(P[:, 0], P[:, 1], P[:, 2])
ax.set_xlabel('X Label')
ax.set_ylabel('Y Label')
ax.set_zlabel('Z Label')
ax.set_title('3D Scatter Plot of 3D Points')
plt.show()