In [18]:
import sys
import os
from pathlib import Path

# Add project root to path and change working directory
current_dir = os.getcwd()
project_root = os.path.dirname(current_dir)
sys.path.insert(0, project_root)
os.chdir(project_root)

import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
import ot
import scipy as sp
import matplotlib.pyplot as plt
from src import LoadCloudPoint, DistanceProfile, compute_W_matrix_distance_matrix_input, plot_3d_points_and_connections, accuracy

In [23]:
import random

# Fix RNG state so the randomly chosen pair of frames is reproducible.
# NOTE(review): which RNG `get_two_random_point_cloud` draws from is not visible
# here, so seed both Python's `random` and NumPy's legacy global RNG.
SEED = 10
random.seed(SEED)
np.random.seed(SEED)

# Relative to the project root (the working directory set in the first cell).
DATA_PATH = "datasets/csv_files/0005_Jogging001.csv"
lcp = LoadCloudPoint(filepath=DATA_PATH)
source_pc, target_pc = lcp.get_two_random_point_cloud()

# Intra-frame distance matrices for the two clouds; presumably index 0 is the
# source frame and index 1 the target frame — confirm against src.DistanceProfile.
dp = DistanceProfile(source_pc, target_pc)
distance_matrix = dp.compute_L2_matrix()

Loaded point cloud data from datasets/csv_files/0005_Jogging001.csv, number of frames: 1377


In [24]:
# Sanity check: shape of the first distance matrix (presumably the source frame);
# a square (n_points, n_points) result is expected.
np.shape(distance_matrix[0])

(26, 26)

In [25]:
# Sanity check: shape of the second distance matrix (presumably the target frame);
# it should match the first one for the square GW / W1 comparisons below.
np.shape(distance_matrix[1])

(26, 26)

In [7]:
# %load_ext autoreload

# Distance profiles + W1 matching, using L1 distance matrices

In [26]:
# Distance-profile matching on L1 intra-frame distance matrices. A distinct name
# is used so the L2 `distance_matrix` shown by the shape checks is not silently
# overwritten. `compute_W_matrix_distance_matrix_input` returns W and a
# point-to-point map; its exact semantics live in src (not visible here).
l1_distance_matrices = dp.compute_L1_matrix()
W, map_matrix = compute_W_matrix_distance_matrix_input(l1_distance_matrices[0], l1_distance_matrices[1])
plot_3d_points_and_connections(source_pc, target_pc, map_matrix)

In [27]:
# Score the L1 distance-profile match; presumably the fraction of points mapped
# to their true counterpart (scoring rule defined in src.accuracy).
accuracy(map_matrix)

0.38461538461538464

# Distance profiles with W1-loss OT

In [28]:
# Re-plots the `map_matrix` produced by the L1 distance-profile cell.
# NOTE(review): this is identical to that cell's plot; the section title suggests
# a separate W1-loss OT variant was intended here — confirm or remove.
plot_3d_points_and_connections(source_pc, target_pc, map_matrix)

# Vanilla OT on raw position coordinates

In [30]:
# Baseline: unregularized OT directly between the raw point coordinates of the
# two frames. ot.dist uses the squared Euclidean cost by default.
M = ot.dist(source_pc, target_pc)

# Uniform weights over the points of each frame. Each marginal is sized by its
# own cloud, so the problem stays well-posed if the frames differ in point count.
N = source_pc.shape[0]
N_target = target_pc.shape[0]
a = np.ones(N) / N
b = np.ones(N_target) / N_target
G = ot.solve(M, a, b).plan

In [31]:
# Visualize the vanilla OT plan G as connections between source and target points.
plot_3d_points_and_connections(source_pc, target_pc, G)

In [32]:
# Score the vanilla OT plan with the same accuracy metric used for the other methods.
accuracy(G)

0.038461538461538464

# Gromov–Wasserstein (GW) on L2 distance matrices

In [33]:
# Gromov–Wasserstein compares the frames only through their intra-frame (L2)
# distance matrices, so no shared coordinate system is required.
l2_distance_matrices = dp.compute_L2_matrix()
C_source, C_target = l2_distance_matrices[0], l2_distance_matrices[1]
# Uniform marginals sized from this cell's own matrices, so the cell does not
# depend on `a`/`b` from the vanilla OT section.
p_source = np.ones(C_source.shape[0]) / C_source.shape[0]
q_target = np.ones(C_target.shape[0]) / C_target.shape[0]
T, logs = ot.gromov_wasserstein(C_source, C_target, p_source, q_target, 'square_loss', log=True)
plot_3d_points_and_connections(source_pc, target_pc, T)

In [16]:
# Score the GW plan computed on L2 distance matrices.
# NOTE(review): this cell's execution count is lower than the GW cell above it,
# so the recorded score may come from an earlier `T`; re-run top to bottom to confirm.
accuracy(T)

0.38461538461538464

# Gromov–Wasserstein (GW) on L1 distance matrices

In [34]:
# Gromov–Wasserstein matching on the L1 intra-frame distance matrices.
l1_distance_matrices = dp.compute_L1_matrix()
C_source, C_target = l1_distance_matrices[0], l1_distance_matrices[1]
# Uniform marginals sized from this cell's own matrices, so the cell does not
# depend on `a`/`b` from the vanilla OT section.
p_source = np.ones(C_source.shape[0]) / C_source.shape[0]
q_target = np.ones(C_target.shape[0]) / C_target.shape[0]
T, logs = ot.gromov_wasserstein(C_source, C_target, p_source, q_target, 'square_loss', log=True)
# The helper returns an object with .show() (presumably a plotly Figure).
fig = plot_3d_points_and_connections(source_pc, target_pc, T)
fig.show()
accuracy(T)

0.3076923076923077