In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Local Libraries 
import sys
sys.path.append("..")
from SO3.utils.curve_utils import vector_to_skew_matrix_single_rotation as hatmap, log_map 

In [None]:
def is_SO3(matrix):
    """Return True if ``matrix`` (a (3, 3) array or a stack shaped (..., 3, 3)) lies in SO(3).

    Every 3x3 slice must be orthogonal (``R @ R.T`` close to the identity under
    ``np.allclose``) and every determinant must be close to +1 (``np.isclose``),
    which rules out reflections.
    """
    # Anything whose trailing axes are not 3x3 cannot be a rotation.
    if matrix.shape[-2:] != (3, 3):
        return False
    identity = np.eye(3)
    stacked = np.reshape(matrix, (-1, 3, 3))
    # Orthogonality, checked one rotation at a time (stops at the first failure).
    if not all(np.allclose(R @ R.T, identity) for R in stacked):
        return False
    # Orientation: det(R) = +1 for proper rotations.
    return bool(np.all(np.isclose(np.linalg.det(matrix), 1)))

def plot_rotations(c, plot_sphere=True):
    """Scatter the image of the z-axis under every rotation in ``c`` on a new 3-D axis.

    NOTE(review): this definition is shadowed later in the notebook by
    ``plot_rotations(c, filename, plot_sphere=True)``; only cells executed between
    the two definitions use this version.

    Parameters
    ----------
    c : iterable of (3, 3) arrays
        Rotation matrices; each one is applied to e_z = (0, 0, 1).
    plot_sphere : bool, optional
        Also draw a translucent unit sphere for reference (default True).

    Returns
    -------
    The 3-D axes that was drawn on.
    """
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.grid(False)
    ax.set_axis_off()

    # Original unit vector (z-axis)
    original_vector = np.array([0, 0, 1])

    rotations = np.asarray(c if isinstance(c, np.ndarray) else list(c), dtype=float)
    if rotations.size:
        # R @ e_z for all rotations at once -> (N, 3). One scatter call instead of one
        # artist per point; depthshade=False matches the look of single-point scatters.
        rotated = rotations.reshape(-1, 3, 3) @ original_vector
        ax.scatter(rotated[:, 0], rotated[:, 1], rotated[:, 2],
                   color='red', alpha=0.5, depthshade=False)

    if plot_sphere:
        # Parametric unit sphere drawn as a translucent surface.
        u = np.linspace(0, 2 * np.pi, 100)
        v = np.linspace(0, np.pi, 100)
        x = np.outer(np.cos(u), np.sin(v))
        y = np.outer(np.sin(u), np.sin(v))
        z = np.outer(np.ones(np.size(u)), np.cos(v))
        ax.plot_surface(x, y, z, color='y', alpha=.5)

    return ax

def create_g_dot(xi):
    """Build the ODE right-hand side  g'(t) = g(t) @ hatmap(xi(t)).

    The returned callable follows the ``f(t, y)`` convention of ``solve_ivp_rk4``:
    the state ``g`` is a flattened 3x3 matrix (9 entries). The derivative is
    returned flattened the same way.

    Parameters
    ----------
    xi : callable
        Maps a time ``t`` to a 3-vector. ``hatmap`` (imported as
        ``vector_to_skew_matrix_single_rotation``) turns it into a 3x3 matrix.
    """
    def g_dot(t, g):
        # Right-multiply the current state by the hat-mapped velocity at time t.
        velocity = hatmap(xi(t))
        return np.ravel(np.reshape(g, (3, 3)) @ velocity)
    return g_dot

In [None]:
# Initial condition for every curve: the identity rotation, flattened to the
# length-9 state vector used by solve_ivp_rk4.
g0 = np.eye(3).ravel()

# Hand-picked time-dependent angular velocities xi_k(t) -> R^3. Each is turned into
# an ODE g' = g @ hatmap(xi(t)) by create_g_dot and integrated from g0, giving one
# synthetic rotation curve per xi_k.
xi_1 = lambda t: np.array([2 * np.sin(3*t) * np.exp(t), 4 * np.cos(3 * t), 3 * t* np.sin(t) * np.cos(t)])
xi_2 = lambda t: np.array([t * 3, 4 * np.sin(10 * t), 3 * t* np.sin(t) * np.cos(t)])
xi_3 = lambda t: np.array([4 * t ** 2, 5 * np.sin(4 * t) * np.sin(6 * t), 3 * t * np.cos(t)])

xi_4 = lambda t: np.array([3 * np.sin(2 * t) * np.exp(-t), 5 * np.cos(2 * t), 4 * t * np.sin(t) * np.cos(t)])
xi_5 = lambda t: np.array([t ** 3, 5 * np.sin(5 * t), 2 * t * np.sin(t) * np.sin(t)])
xi_6 = lambda t: np.array([2 * t ** 2, 3 * np.sin(6 * t) * np.cos(t), 5 * t * np.cos(t) * np.cos(t)])
# First component is a constant angular velocity.
xi_7 = lambda t: np.array([-1, 6 * np.sin(3 * t), 3 * t * np.sin(t)])
xi_8 = lambda t: np.array([t ** 2, 4 * np.cos(3 * t), t * np.sin(t)])
xi_9 = lambda t: np.array([3 * t, 5 * np.sin(2 * t), 2 * t * np.cos(t)])

# Order matters: later cells slice the resulting curves in consecutive groups of three
# (xi_1..3, xi_4..6, xi_7..9) to build curves on SO(3)^3.
xi_lst = [xi_1, xi_2, xi_3, xi_4, xi_5, xi_6, xi_7, xi_8, xi_9]


In [None]:
def rk4_step(func, t, y, dt):
    """Advance ``y' = func(t, y)`` by one classical 4th-order Runge-Kutta step.

    Parameters
    ----------
    func : callable ``(t, y) -> dy/dt`` returning an array shaped like ``y``.
    t : float
        Current time.
    y : ndarray
        Current state.
    dt : float
        Step size (may be negative).

    Returns
    -------
    ndarray
        Approximate state at ``t + dt``.
    """
    half = dt / 2
    k1 = func(t, y)
    k2 = func(t + half, y + half * k1)
    k3 = func(t + half, y + half * k2)
    k4 = func(t + dt, y + dt * k3)
    return y + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)


def solve_ivp_rk4(func, y0, t):
    """Integrate ``y' = func(t, y)`` with RK4, taking one step per interval of ``t``.

    The grid need not be uniform: each step uses ``t[i] - t[i-1]``. This is what
    allows a curve to be evaluated on a warped grid such as ``varphi(x)``. RK4 is
    not a Lie-group integrator, so solutions on SO(3) drift slightly off the group.

    Parameters
    ----------
    func : callable ``(t, y) -> dy/dt``.
    y0 : array_like, shape (d,)
        Initial state at ``t[0]``.
    t : array_like, shape (n,)
        Time points.

    Returns
    -------
    ndarray, shape (n, d)
        Row ``i`` approximates the state at ``t[i]``; shape ``(0, d)`` if ``t`` is empty.
    """
    y0 = np.asarray(y0)
    times = np.asarray(t)
    # Promote integer states to float, but keep complex states complex instead of
    # silently discarding the imaginary part.
    dtype = np.result_type(y0.dtype, float)
    y = np.empty((len(times), len(y0)), dtype=dtype)
    if len(times) == 0:
        return y
    y[0] = y0
    for i in range(1, len(times)):
        y[i] = rk4_step(func, times[i - 1], y[i - 1], times[i] - times[i - 1])
    return y

In [None]:
t_eval = np.linspace(0, 1, 100)

# Integrate g' = g @ hatmap(xi(t)) from the identity for every angular velocity in
# xi_lst. Keep both the ODE right-hand sides and the (len(t_eval), 3, 3) solutions.
g_dot_lst = []
g_t_lst = []

for xi in xi_lst:
    g_dot = create_g_dot(xi)
    g_t = solve_ivp_rk4(g_dot, g0, t_eval).reshape(len(t_eval), 3, 3)

    g_dot_lst.append(g_dot)
    g_t_lst.append(g_t)

# RK4 does not preserve SO(3) exactly; make sure the drift stays within tolerance.
for i, g_t in enumerate(g_t_lst):
    assert is_SO3(g_t), f"Solution {i} is not in SO(3)"

In [None]:
import os

def plot_rotations(c, filename, plot_sphere=True, output_dir="figures/syntetic_data/"):
    """Scatter R @ e_z for every rotation R in ``c`` and save the figure as a PNG.

    Redefines (and shadows) the earlier ``plot_rotations(c, plot_sphere=True)``.

    Parameters
    ----------
    c : iterable of (3, 3) arrays
        Rotation matrices; each is applied to e_z = (0, 0, 1).
    filename : str
        Basename (without extension) of the PNG to write.
    plot_sphere : bool, optional
        Draw a translucent unit sphere behind the points (default True).
    output_dir : str, optional
        Target directory, created if missing. The default keeps the original
        folder name so existing paths are unchanged.

    Returns
    -------
    str
        Path of the written PNG file.
    """
    fig = plt.figure(figsize=(5, 5))
    ax = fig.add_subplot(111, projection='3d')
    ax.grid(False)  # Turn off the grid
    ax.set_axis_off()  # Turn off the axis

    # Original unit vector (z-axis)
    original_vector = np.array([0, 0, 1])

    rotations = np.asarray(c if isinstance(c, np.ndarray) else list(c), dtype=float)
    if rotations.size:
        # All R @ e_z at once -> (N, 3). A single scatter call replaces one artist per
        # point; depthshade=False matches the look of single-point scatters.
        rotated = rotations.reshape(-1, 3, 3) @ original_vector
        ax.scatter(rotated[:, 0], rotated[:, 1], rotated[:, 2],
                   color='red', alpha=0.5, depthshade=False)

    # NOTE(review): limits are tighter than the unit sphere, presumably to zoom in and
    # reduce whitespace in the saved image — confirm.
    ax.set_xlim(-0.65, 0.65)
    ax.set_ylim(-0.65, 0.65)
    ax.set_zlim(-0.5, 0.5)

    if plot_sphere:
        # Parametric unit sphere drawn as a translucent surface.
        u = np.linspace(0, 2 * np.pi, 100)
        v = np.linspace(0, np.pi, 100)
        x = np.outer(np.cos(u), np.sin(v))
        y = np.outer(np.sin(u), np.sin(v))
        z = np.outer(np.ones(np.size(u)), np.cos(v))
        ax.plot_surface(x, y, z, color='y', alpha=.5)

    plt.tight_layout(pad=0)

    # Without this, savefig raises FileNotFoundError on a fresh checkout.
    os.makedirs(output_dir, exist_ok=True)
    fig_path = os.path.join(output_dir, f"{filename}.png")
    fig.savefig(fig_path, format='png', bbox_inches='tight', pad_inches=0)
    # Release the figure; this function is meant to be called in a loop.
    plt.close(fig)
    return fig_path

In [None]:
# for i, g_t in enumerate(g_t_lst):
#     plot_rotations(g_t, f"SO3_fig{i+1}")

# plot_rotations(g_t_1, "SO3_fig1")
# plot_rotations(g_t_2, "SO3_fig2")
# plot_rotations(g_t_3, "SO3_fig3")

### Create random reparametrizations (increasing maps of [0, 1] onto itself)

In [None]:
import numpy as np

def is_diff_plus(c, atol=0.0):
    """Check that a sampled map ``c`` on [0, 1] fixes the end points and is non-decreasing.

    A diagnostic is printed for the first failed condition:
    ``c[0] == 0``, ``c[-1] == 1``, then ``np.diff(c) >= 0``.

    Parameters
    ----------
    c : array_like, shape (n,)
        Sampled values of the candidate reparameterization.
    atol : float, optional
        Absolute tolerance for the two end-point checks. The default 0.0 keeps
        exact comparison. A small positive value accepts numerically computed
        maps whose end point is e.g. ``1 - 1e-16``.

    Returns
    -------
    bool
        False for empty input.
    """
    c = np.asarray(c, dtype=float)
    if c.size == 0:
        print('c is empty')
        return False
    # Written as `not (... <= atol)` so that NaN end points are rejected.
    if not abs(c[0]) <= atol:
        print('c[0] != 0')
        return False
    if not abs(c[-1] - 1) <= atol:
        print('c[-1] != 1')
        print(c[-1])
        return False
    # A single sample has no increments; np.diff(...).min() would raise on it.
    if c.size > 1 and np.diff(c).min() < 0:
        print('np.diff(c).min() < 0')
        return False
    return True

def basis_function(n,x):
    """Sine basis element  sin(n*pi*x) / (n*pi).

    It vanishes at x = 0 and x = 1, and its derivative cos(n*pi*x) has magnitude <= 1.
    """
    return np.sin(n*np.pi*x) / (n * np.pi)

def I(x):
    """Identity map on [0, 1]."""
    return x

def varphi_func(x, I, f, *args):
    """Reparameterization ``I(x) + f(x, *args)``: the identity plus a perturbation ``f``."""
    return I(x) + f(x, *args)

def pi(w, epsilon):
    """Scale ``w`` into the L1 ball of radius ``1 - epsilon``; no-op if already inside.

    The L1 norm is used because delta = sum_j w_j * basis_function(j, x) has
    derivative sum_j w_j * cos(j*pi*x), whose magnitude is at most ||w||_1.
    Keeping ||w||_1 <= 1 - epsilon therefore makes ``I + delta`` strictly increasing.
    """
    norm_w = np.linalg.norm(w, 1)
    scaling_factor = (1 - epsilon) / max(1 - epsilon, norm_w)
    return scaling_factor * w

def generate_and_transform_weights(random, epsilon, M, std):
    """Return ``M - 1`` weights projected by ``pi``.

    If ``random`` is true, the weights are drawn from N(0, std^2) using the global
    ``np.random`` state; otherwise they are all ones.
    """
    if random:
        weights = np.random.normal(0, std, M - 1)
    else:
        weights = np.ones(M - 1)
    return pi(weights, epsilon)

def generate_delta_from_basis(x, M, random=True, epsilon=1e-8, std = 1):
    """Random perturbation delta(x) = sum_{j=1}^{M-1} w_j * basis_function(j, x).

    delta(0) = delta(1) = 0, and |delta'| <= 1 - epsilon, so ``x + delta(x)`` is an
    increasing map of [0, 1] onto itself.

    Parameters
    ----------
    x : array_like or scalar
        Evaluation points.
    M : int
        Number of basis terms plus one. ``M = 1`` gives delta = 0.
    random, epsilon, std
        Forwarded to ``generate_and_transform_weights``.

    Returns
    -------
    ndarray
        Same shape as ``x``. Values with magnitude below 1e-15 are set exactly to 0.
    """
    x = np.asarray(x, dtype=float)
    weights = generate_and_transform_weights(random, epsilon, M, std)
    # Start from a float array (not the int 0 of sum()) so M == 1 and scalar x work;
    # the terms are still accumulated in the same order.
    delta = np.zeros_like(x)
    for j in range(1, M):
        delta = delta + weights[j - 1] * basis_function(j, x)
    # Remove round-off residue such as sin(pi) ~ 1e-16 at the end points.
    delta[np.abs(delta) < 1e-15] = 0
    return delta

In [None]:
import matplotlib.pyplot as plt

M = 4
random = True
epsilon = 1e-8
std = 2
x = np.linspace(0, 1, 100)

def _draw_varphi(seed):
    """Reseed the global RNG and draw one random reparameterization I + delta on x."""
    np.random.seed(seed)
    return varphi_func(x, I, generate_delta_from_basis, M, random, epsilon, std)

# Three reproducible random reparameterizations (seeds 2, 3 and 5), drawn in order.
varphi_1, varphi_2, varphi_3 = (_draw_varphi(seed) for seed in (2, 3, 5))

for k, curve in enumerate((varphi_1, varphi_2, varphi_3), start=1):
    plt.plot(x, curve, label=rf'$\varphi_{k}(x)$')

# import pandas as pd

# # Create a DataFrame from your data
# df_varphi_1 = pd.DataFrame({'x': x, 'varphi_1': varphi_1})
# df_varphi_2 = pd.DataFrame({'x': x, 'varphi_2': varphi_2})
# df_varphi_3 = pd.DataFrame({'x': x, 'varphi_3': varphi_3})

# df_varphi_1.to_csv("figures/syntetic_data/parameterization/varphi_1.csv", index=False)
# df_varphi_2.to_csv("figures/syntetic_data/parameterization/varphi_2.csv", index=False)
# df_varphi_3.to_csv("figures/syntetic_data/parameterization/varphi_3.csv", index=False)

plt.legend()
plt.show()

In [None]:
# One more random perturbation delta(x), drawn from whatever global RNG state the
# previous cell left behind (not reseeded here).
delta = generate_delta_from_basis(x, M, random, epsilon, std)
fig, ax = plt.subplots()
ax.plot(x, delta)
ax.axhline(0, color='gray', linewidth=0.5)
ax.set(xlabel='x', ylabel=r'$\delta(x)$', title=r'Random perturbation $\delta$ (vanishes at 0 and 1)')
plt.show()

### Find the optimal reparametrization between two curves

In [None]:
from SO3.utils.curve_utils import * 
from SO3.utils.reparameterization_utils import *

def find_reparameterization(c1, c2, depth = 10):
    """Estimate the reparameterization that best aligns rotation curve ``c2`` with ``c1``.

    Both curves are moved to start at the identity
    (``move_rotation_origin_to_identity``), mapped through ``SRVT_single_rotation``,
    and converted from skew matrices to vectors. ``find_optimal_diffeomorphism``
    is then run on a uniform grid over [0, 1].

    Parameters
    ----------
    c1, c2 : sequences of rotations sampled on the same uniform grid
        Presumably shape (n, 3, 3) — TODO confirm against curve_utils.
    depth : int, optional
        Passed through to ``find_optimal_diffeomorphism``.

    Returns
    -------
    Whatever ``find_optimal_diffeomorphism`` returns. Callers use it as the new
    grid values ``I_new`` (e.g. with ``reparameterize_rotation``).

    Raises
    ------
    ValueError
        If the curves have different numbers of samples; a single grid is used
        for both, so a mismatch would silently misalign them.
    """
    if len(c1) != len(c2):
        raise ValueError(
            f"c1 and c2 must have the same number of samples, got {len(c1)} and {len(c2)}"
        )
    # Named `grid` rather than `I` so it is not confused with the identity map I(x).
    grid = np.linspace(0, 1, len(c1))
    c1 = move_rotation_origin_to_identity(c1)
    c2 = move_rotation_origin_to_identity(c2)

    q1 = skew_matrix_to_vector_single_rotation(SRVT_single_rotation(grid, c1))
    q2 = skew_matrix_to_vector_single_rotation(SRVT_single_rotation(grid, c2))

    return find_optimal_diffeomorphism(q1, q2, grid, grid, depth)

In [None]:
# import pandas as pd

# x = np.linspace(0, 1, 100)

# # Initialize empty dataframes
# df_varphi_1 = pd.DataFrame({'x' : x})
# df_varphi_2 = pd.DataFrame({'x' : x})
# df_varphi_3 = pd.DataFrame({'x' : x})

# for i, g_func in enumerate(g_dot_lst[:3]):
#     for j, varphi in enumerate([varphi_1, varphi_2, varphi_3]):
#         c2 = solve_ivp_rk4(g_func, g0, x)
#         c2 = c2.reshape(len(x), 3, 3)

#         c1 = solve_ivp_rk4(g_func, g0, varphi)
#         c1 = c1.reshape(len(varphi), 3, 3)

#         I_new = find_reparameterization(c1, c2)

#         # Store the result in the appropriate dataframe
#         if j == 0:
#             df_varphi_1[f"g{i}"] = I_new
#         elif j == 1:
#             df_varphi_2[f"g{i}"] = I_new
#         else:
#             df_varphi_3[f"g{i}"] = I_new

# df_varphi_1.to_csv('figures/syntetic_data/reparameterization_SO3/df_varphi_1.csv', index=False)
# df_varphi_2.to_csv('figures/syntetic_data/reparameterization_SO3/df_varphi_2.csv', index=False)
# df_varphi_3.to_csv('figures/syntetic_data/reparameterization_SO3/df_varphi_3.csv', index=False)

In [None]:
# import numpy as np
# import pandas as pd
# import time 

# x = np.linspace(0, 1, 100)
# time_data = {}

# for i, g_func in enumerate(g_dot_lst[:3]):
#     for j, varphi in enumerate([varphi_1, varphi_2, varphi_3]):
#         # Prepare an empty list to store individual rows as dictionaries
#         data = []

#         for depth in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]:
#             start_time = time.time()

#             c2 = solve_ivp_rk4(g_func, g0, x)
#             c2 = c2.reshape(len(x), 3, 3)

#             c1 = solve_ivp_rk4(g_func, g0, varphi)
#             c1 = c1.reshape(len(varphi), 3, 3)

#             I_new = find_reparameterization(c1, c2, depth=depth)
#             I = np.linspace(0, 1, len(c1))

#             c2 = reparameterize_rotation(I_new, I, c2)
#             c1 = move_rotation_origin_to_identity(c1)
#             c2 = move_rotation_origin_to_identity(c2)

#             q1 = skew_matrix_to_vector_single_rotation(SRVT_single_rotation(I, c1))
#             q2 = skew_matrix_to_vector_single_rotation(SRVT_single_rotation(I, c2))

#             L2_distance = L2_metric(q1, q2, I, I)

#             elapsed_time = time.time() - start_time 

#             # Append the result as a dictionary to the data list
#             data.append({'depth': depth, 'L2_distance': L2_distance})

#             if depth not in time_data:
#                 time_data[depth] = []

#             time_data[depth].append(elapsed_time)

#         df = pd.DataFrame(data)
#         dir_path = 'figures/syntetic_data/reparameterization_SO3/depth_error'
        # os.makedirs(dir_path, exist_ok=True)
        # df.to_csv(f'{dir_path}/g{i}_varphi{j}.csv', index=False)

# mean_time_data = {}

# for depth, times in time_data.items():
#     mean_time_data[depth] = np.mean(times)

# # If you want to convert it to a DataFrame
# mean_time_df = pd.DataFrame(list(mean_time_data.items()), columns=['depth', 'mean_time'])

# mean_time_df.to_csv(f'{dir_path}/mean_time.csv', index=False)



In [None]:
def varphi_func_1(x):
    """Deterministic warp sin(pi*x/4) / sin(pi/4); maps 0 -> 0 and 1 -> 1."""
    return np.sin(np.pi / 4 * x) / np.sin(np.pi / 4)

def varphi_func_2(x):
    """Deterministic warp cos(pi*x/6)*sin(pi*x/7), normalized so that 1 -> 1 (and 0 -> 0)."""
    return np.cos(np.pi / 6 * x) * np.sin(np.pi / 7 * x) / (np.sin(np.pi / 7) * np.cos(np.pi / 6))

x = np.linspace(0, 1, 100)
fig, ax = plt.subplots()
ax.plot(x, varphi_func_1(x), label=r'$\varphi_1$')
ax.plot(x, varphi_func_2(x), label=r'$\varphi_2$')
ax.plot(x, x, label='identity')
ax.set(xlabel='x', ylabel=r'$\varphi(x)$', title='Deterministic reparameterizations')
ax.legend()
plt.show()

In [None]:
import numpy as np
import pandas as pd
import time 

# Convergence study: for grids of increasing size, sample a curve on the uniform grid
# and on a known warp, recover the warp with find_reparameterization, and record the
# remaining L2 distance between the SRVTs.
num_elements_lst = np.array([11, 21, 41, 81, 161])

# Only these (curve index i, warp index j) pairs are computed; the other entries of
# `matrix` stay zero. (Previously hard-coded as `if i != 0 or j != 0: continue`.)
selected_pairs = {(0, 0)}

# matrix[i, j, k]: L2 distance for curve i, warp j, grid size num_elements_lst[k].
matrix = np.zeros((3, 2, len(num_elements_lst)))

for i, g_func in enumerate(g_dot_lst[:3]):
    # Loop variable is `warp`: naming it `varphi_func` would overwrite the helper
    # varphi_func(x, I, f, *args) and break re-running the earlier cells.
    for j, warp in enumerate([varphi_func_1, varphi_func_2]):
        if (i, j) not in selected_pairs:
            continue

        for k, num_elements in enumerate(num_elements_lst):
            print(f"i: {i}, j: {j}, k: {k}")

            x = np.linspace(0, 1, num_elements)
            varphi = warp(x)
            # NOTE(review): depth is the *index* of the grid point nearest 0.1, so it
            # grows with the grid (1, 2, 4, 8, 16 for the sizes above) — confirm intended.
            target_value = 1/10
            depth = np.argmin(np.abs(x - target_value))
            print(f"depth: {depth}")
            print(f"The depth element is: {x[depth]}")

            c2 = solve_ivp_rk4(g_func, g0, x).reshape(len(x), 3, 3)
            c1 = solve_ivp_rk4(g_func, g0, varphi).reshape(len(varphi), 3, 3)

            I_new = find_reparameterization(c1, c2, depth=depth)
            # Named `grid`: assigning to the global `I` here would replace the
            # identity map I(x) used by the varphi cells.
            grid = np.linspace(0, 1, len(c1))

            c2 = reparameterize_rotation(I_new, grid, c2)
            c1 = move_rotation_origin_to_identity(c1)
            c2 = move_rotation_origin_to_identity(c2)

            q1 = skew_matrix_to_vector_single_rotation(SRVT_single_rotation(grid, c1))
            q2 = skew_matrix_to_vector_single_rotation(SRVT_single_rotation(grid, c2))

            matrix[i, j, k] = L2_metric(q1, q2, grid, grid)


In [None]:
# Plot loglog 
# Log-log convergence plot. matrix[i, j, k] belongs to num_elements_lst[k], so the
# x-values must be in the same (ascending) order as the data.
fig, ax = plt.subplots()
for i in range(1):
    for j in range(1):
        ax.loglog(num_elements_lst, matrix[i, j], marker='o', label=f"g{i+1}, varphi{j+1}")

# First-order reference: error proportional to 1/N, anchored at the first plotted
# point (falls back to 1 if that error is not positive).
anchor = matrix[0, 0, 0] if matrix[0, 0, 0] > 0 else 1.0
ax.loglog(num_elements_lst, anchor * num_elements_lst[0] / num_elements_lst,
          linestyle='--', color='black', label='Order 1')
ax.set(xlabel='number of grid points N', ylabel='L2 distance after reparameterization',
       title='Convergence of the reparameterization')
ax.legend()
plt.show()

In [None]:
import numpy as np
import pandas as pd
import time 

x = np.linspace(0, 1, 100)
time_data = {}
curves = {}

# Twelve single-rotation curves: each of the first three ODE right-hand sides is
# evaluated on the three random warps and on the uniform grid x. Keys run c_0 ... c_11,
# with g_dot_lst[0] first (c_0..c_3), then g_dot_lst[1], then g_dot_lst[2].
counter = 0
for i, g_func in enumerate(g_dot_lst[:3]):
    grids = (varphi_1, varphi_2, varphi_3, x)
    for j, varphi in enumerate(grids):
        curve = solve_ivp_rk4(g_func, g0, varphi).reshape(len(varphi), 3, 3)
        curves["c_%d" % counter] = curve
        counter += 1


In [None]:
# Display the keys of `curves`: c_0 ... c_11 (3 ODE right-hand sides x 4 parameterizations).
curves.keys()

In [None]:
# distances = np.zeros((len(curves), len(curves)))

# for i, c1 in enumerate(curves.values()):
#     for j, c2 in enumerate(curves.values()):
#         if i <= j:
#             continue
            
#         I_new = find_reparameterization(c1, c2, depth=4)
#         I = np.linspace(0, 1, len(c1))

#         c2 = reparameterize_rotation(I_new, I, c2)
#         c1 = move_rotation_origin_to_identity(c1)
#         c2 = move_rotation_origin_to_identity(c2)

#         q1 = skew_matrix_to_vector_single_rotation(SRVT_single_rotation(I, c1))
#         q2 = skew_matrix_to_vector_single_rotation(SRVT_single_rotation(I, c2))

#         L2_distance = L2_metric(q1, q2, I, I)
#         distances[i, j] = L2_distance
#         distances[j, i] = L2_distance


In [None]:
# import matplotlib.pyplot as plt

# # Plot the distance matrix
# plt.figure(figsize=(8, 5))
# plt.imshow(distances, cmap='hot', interpolation='nearest')
# plt.colorbar()
# # plt.show()

# plt.savefig(f"figures/syntetic_data/distance_matrix/SO3.png")

### Signature

In [None]:
# import iisignature
# import numpy as np
# from scipy.stats import wasserstein_distance

# import numpy.fft as fft

# for level in [7]:
#     preprocessed_data = iisignature.prepare(3, level)
#     signatures = {}

#     for name, c in curves.items():
#         I = np.linspace(0, 1, len(c))
#         c = move_rotation_origin_to_identity(c)
#         q = skew_matrix_to_vector_single_rotation(right_log_single_rotation(I, c))
#         sig = iisignature.logsig(q, preprocessed_data)


#         # Set a sparsity threshold
#         threshold = 0.01

#         # Zero out all elements below the threshold
#         sparse_sig = np.where(np.abs(sig) > threshold * np.abs(sig).max(), sig, 0)

#         signatures[name] = sparse_sig


#     signature_matrix = np.zeros((len(signatures), len(signatures)))
#     for i, s1 in enumerate(signatures.values()):
#         for j, s2 in enumerate(signatures.values()):
#             distance =  np.linalg.norm(s1 / np.linalg.norm(s1) - s2 / np.linalg.norm(s2))
#             # distance = np.linalg.norm(s1 - s2)
#             # distance = 1 - np.dot(s1, s2) / (np.linalg.norm(s1) * np.linalg.norm(s2))
#             # distance = wasserstein_distance(s1, s2)
#             signature_matrix[i, j] = distance

#     plt.figure(figsize=(8, 5))
#     plt.imshow(signature_matrix, cmap='hot', interpolation='nearest')
#     plt.colorbar()
#     # plt.savefig(f"figures/syntetic_data/distance_matrix/SO3_signature_{level}.png", bbox_inches='tight', pad_inches=0)
#     plt.show()

In [None]:
# import iisignature
# import numpy as np
# from scipy.stats import wasserstein_distance
# from sklearn.decomposition import PCA
# from sklearn.preprocessing import StandardScaler
# from scipy.spatial.distance import pdist, squareform


# level = 7
# preprocessed_data = iisignature.prepare(3, level)
# signatures = {}

# for name, c in curves.items():
#     I = np.linspace(0, 1, len(c))
#     c = move_rotation_origin_to_identity(c)
#     q = skew_matrix_to_vector_single_rotation(right_log_single_rotation(I, c))
#     sig = iisignature.logsig(q, preprocessed_data)
#     signatures[name] = sig

# signature_matrix = np.zeros((len(signatures), len(signatures)))
# for i, s1 in enumerate(signatures.values()):
#     for j, s2 in enumerate(signatures.values()):
#         distance =  np.linalg.norm(s1 / np.linalg.norm(s1) - s2 / np.linalg.norm(s2))
#         # distance = np.linalg.norm(s1 - s2)
#         # distance = 1 - np.dot(s1, s2) / (np.linalg.norm(s1) * np.linalg.norm(s2))
#         # distance = wasserstein_distance(s1, s2)
#         signature_matrix[i, j] = distance

# plt.figure(figsize=(8, 5))
# plt.imshow(signature_matrix, cmap='hot', interpolation='nearest')
# plt.colorbar()
# # plt.savefig(f"figures/syntetic_data/distance_matrix/SO3_signature_{level}.png", bbox_inches='tight', pad_inches=0)
# plt.show()


# # Standardize the data
# scaler = StandardScaler()
# X_standardized = scaler.fit_transform(signature_matrix)

# # Apply PCA
# num_components = 2
# pca = PCA(n_components=num_components)
# X_reduced = pca.fit_transform(X_standardized)

# # Compute the pairwise distance matrix
# distance_matrix = squareform(pdist(X_reduced, metric='cosine'))

# # Visualize the distance matrix
# plt.figure(figsize=(8, 5))
# plt.imshow(distance_matrix, cmap='hot', interpolation='nearest')
# plt.colorbar()
# plt.show()


In [None]:
# level = 5
# preprocessed_data = iisignature.prepare(3, level)
# signatures = {}

# for name, c in curves.items():
#     I = np.linspace(0, 1, len(c))
#     c = move_rotation_origin_to_identity(c)
#     q = skew_matrix_to_vector_single_rotation(right_log_single_rotation(I, c))
#     signatures[name] = iisignature.logsig(q, preprocessed_data)

# signature_matrix = np.zeros((len(signatures), len(signatures)))
# for i, s1 in enumerate(signatures.values()):
#     for j, s2 in enumerate(signatures.values()):
#         # distance =  np.linalg.norm(s1 / np.linalg.norm(s1) - s2 / np.linalg.norm(s2))
#         # distance = np.linalg.norm(s1 - s2)
#         distance = 1 - np.dot(s1, s2) / (np.linalg.norm(s1) * np.linalg.norm(s2))
#         # distance = wasserstein_distance(s1, s2)
#         signature_matrix[i, j] = distance

# plt.figure(figsize=(8, 5))
# plt.imshow(signature_matrix, cmap='viridis', interpolation='nearest')
# plt.colorbar()
# # plt.savefig(f"figures/syntetic_data/distance_matrix/SO3_signature_{level}.png", bbox_inches='tight', pad_inches=0)
# plt.show()


# # Standardize the data
# scaler = StandardScaler()
# X_standardized = scaler.fit_transform(signature_matrix)

# # Apply PCA
# num_components = 2
# pca = PCA(n_components=num_components)
# X_reduced = pca.fit_transform(X_standardized)

# # Compute the pairwise distance matrix
# distance_matrix = squareform(pdist(X_reduced, metric='cosine'))

# # Visualize the distance matrix
# plt.figure(figsize=(8, 5))
# plt.imshow(distance_matrix, cmap='viridis', interpolation='nearest')
# plt.colorbar()
# plt.show()


### Curves on $SO(3)^n$ (several rotations at once)

In [None]:
# import numpy as np
# import pandas as pd
# import time 

# x = np.linspace(0, 1, 100)
# curves = {}

# for i, varphi in enumerate([varphi_1, varphi_2, varphi_3, x]):
#     curve = np.zeros((len(x), 3, 3, 3))
#     for j, g_func in enumerate(g_dot_lst[:3]):
#         c_j = solve_ivp_rk4(g_func, g0, varphi)
#         c_j = c_j.reshape(len(varphi), 3, 3)
#         curve[:, j] = c_j

#     # Reshape curve from (100, 3, 3, 3) to (3, 100, 3, 3)
#     curve = np.moveaxis(curve, 1, 0)
#     print(curve.shape)
#     curves[f"c_{i}"] = curve

In [None]:
# from SO3.utils.curve_utils import * 
# from SO3.utils.multiple_curves_utils import *

# def find_reparameterization_several(c1, c2, depth = 10): 
#     I = np.linspace(0, 1, len(c1))
#     c1 = move_several_rotations_origins_to_identity(c1)
#     c2 = move_several_rotations_origins_to_identity(c2)

#     q1 = skew_matrix_to_vector_several_rotations(SRVT_multiple_rotations(I,c1))
#     q2 = skew_matrix_to_vector_several_rotations(SRVT_multiple_rotations(I,c2))

#     I_new = find_optimal_diffeomorphism(q1, q2, I , I, depth)
#     return I_new

In [None]:
# Sanity check of the single-rotation `curves` dict built earlier (the SO(3)^n cells
# above that would rebuild it are commented out).
print(curves.keys())    

In [None]:
# plt.plot(x, I_0, label=r'$\hat \varphi_1$', color='blue')
# plt.plot(x, varphi_1, label=r'$\varphi_1$', color='blue', marker = '.')

# plt.plot(x, I_1, label=r'$\hat \varphi_2$', color='red')
# plt.plot(x, varphi_2, label=r'$\varphi_2$', color='red', marker = '.')

# plt.plot(x, I_2, label=r'$\hat \varphi_3$', color='green')
# plt.plot(x, varphi_3, label=r'$\varphi_3$', color='green', marker = '.')

# plt.legend()
# plt.show()

In [None]:
import numpy as np
import pandas as pd
import time 

x = np.linspace(0, 1, 100)
curves = {}

# Curves on SO(3)^3. The ODE right-hand sides in g_dot_lst are split into consecutive
# groups of three rotations. Each group is evaluated on the three random warps and on
# the uniform grid x. Key "c_{group}_{param}" maps to an array of shape
# (3, len(x), 3, 3); param 4 is the uniform parameterization. If len(g_dot_lst) is not
# a multiple of three, the last group is smaller.
group_size = 3
for i, varphi in enumerate([varphi_1, varphi_2, varphi_3, x]):
    # One (len(varphi), 3, 3) solution per right-hand side, stacked along a new
    # leading axis -> (len(g_dot_lst), len(varphi), 3, 3).
    curve = np.stack([
        solve_ivp_rk4(g_func, g0, varphi).reshape(len(varphi), 3, 3)
        for g_func in g_dot_lst
    ])
    for group, start in enumerate(range(0, len(curve), group_size), start=1):
        curves[f"c_{group}_{i+1}"] = curve[start:start + group_size]

# Order keys by (group, parameterization): c_1_1 ... c_1_4, c_2_1, ..., c_3_4.
keys = sorted(curves.keys(), key=lambda key: (int(key.split('_')[1]), int(key.split('_')[2])))
sorted_curves = {key: curves[key] for key in keys}

In [None]:
# Keys ordered by (group, parameterization): c_1_1 ... c_1_4, c_2_1, ..., c_3_4.
print(sorted_curves.keys())

In [None]:
# from SO3.movement_data.calculate_reparameterized_distance import reparameterized_distance

# df_varphi_1 = pd.DataFrame({'x' : x})
# df_varphi_2 = pd.DataFrame({'x' : x})
# df_varphi_3 = pd.DataFrame({'x' : x})

# depth = 10

# from tqdm import tqdm
# pbar = tqdm(total=3 * 3)

# for i, key in enumerate(['c_1_1', 'c_1_2', 'c_1_3']):
#     c1 = sorted_curves[key]
#     c2 = sorted_curves['c_1_4']
#     I_new = reparameterized_distance(c1, c2, depth=depth).I_new
#     df_varphi_1[f"g{i}"] = I_new
#     pbar.update(1)

# for i, key in enumerate(['c_2_1', 'c_2_2', 'c_2_3']):
#     c1 = sorted_curves[key]
#     c2 = sorted_curves['c_2_4']
#     I_new = reparameterized_distance(c1, c2, depth=depth).I_new
#     df_varphi_2[f"g{i}"] = I_new
#     pbar.update(1)

# for i, key in enumerate(['c_3_1', 'c_3_2', 'c_3_3']):
#     c1 = sorted_curves[key]
#     c2 = sorted_curves['c_3_4']
#     I_new = reparameterized_distance(c1, c2, depth=depth).I_new
#     df_varphi_3[f"g{i}"] = I_new
#     pbar.update(1)

# pbar.close()

# df_varphi_1.to_csv('figures/syntetic_data/reparameterization_SO3_3/df_varphi_1.csv', index=False)
# df_varphi_2.to_csv('figures/syntetic_data/reparameterization_SO3_3/df_varphi_2.csv', index=False)
# df_varphi_3.to_csv('figures/syntetic_data/reparameterization_SO3_3/df_varphi_3.csv', index=False)

In [None]:
# from tqdm import tqdm

# total = (len(sorted_curves) * len(sorted_curves) - len(sorted_curves)) // 2
# pbar = tqdm(total=total)

# depth = 10
# distance_matrix = np.zeros((len(sorted_curves), len(sorted_curves)))
# for i, c1 in enumerate(sorted_curves.values()):
#     for j, c2 in enumerate(sorted_curves.values()):
#         if i <= j:
#             continue

#         # start = time.time()
#         distance = reparameterized_distance(c1, c2, depth).distance
#         # end = time.time()
#         # print(f"Time: {end - start}")

#         distance_matrix[i, j] = distance
#         distance_matrix[j, i] = distance

#         pbar.update()

# pbar.close()

# plt.figure(figsize=(8, 5))
# plt.imshow(distance_matrix, cmap='hot', interpolation='nearest')
# plt.colorbar()
# plt.savefig(f"figures/syntetic_data/distance_matrix/SO3_3.png", bbox_inches='tight', pad_inches=0)
# # plt.show()

In [None]:
# import numpy as np
# import pandas as pd
# import time 

# x = np.linspace(0, 1, 100)
# time_data = {}

# for key in sorted_curves.keys():
#     _, i, j = key.split('_')

#     # If equidistant parameterization, skip
#     if j == '4':
#         continue
    
#     c1 = curves[key]
#     c2 = curves[f'c_{i}_4']

#     # Prepare an empty list to store individual rows as dictionaries
#     data = []

#     for depth in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]:
#         start_time = time.time()

#         distance = reparameterized_distance(c1, c2, depth).distance

#         elapsed_time = time.time() - start_time 

#         # Append the result as a dictionary to the data list
#         data.append({'depth': depth, 'L2_distance': distance})

#         if depth not in time_data:
#             time_data[depth] = []

#         time_data[depth].append(elapsed_time)

#         df = pd.DataFrame(data)
#         dir_path = 'figures/syntetic_data/reparameterization_SO3_3/depth_error'
#         os.makedirs(dir_path, exist_ok=True)
#         df.to_csv(f'{dir_path}/g{int(i)-1}_varphi{int(j)-1}.csv', index=False)
    
#     print(f"Done with {key}")

# mean_time_data = {}

# for depth, times in time_data.items():
#     mean_time_data[depth] = np.mean(times)

# # If you want to convert it to a DataFrame
# mean_time_df = pd.DataFrame(list(mean_time_data.items()), columns=['depth', 'mean_time'])

# mean_time_df.to_csv(f'{dir_path}/mean_time.csv', index=False)

In [None]:
# for key, c in sorted_curves.items():
#     _, i, j = key.split('_')
#     print(f"Key: {key}, Indices: ({i}, {j})")
#     print(c.shape)

In [None]:
# import iisignature


# for level in [1, 2, 3, 4, 5, 6, 7]:
#     preprocessed_data = iisignature.prepare(9, level)
#     signatures = {}

#     for key, c in sorted_curves.items():
#         c = move_several_rotations_origins_to_identity(c)
#         I = create_parameterization_several_rotations(c)
#         right_log = right_log_several_rotations(I, c)
#         vector = skew_matrix_to_vector_several_rotations(right_log)
#         signatures[key] = iisignature.logsig(vector, preprocessed_data)

#     signature_matrix = np.zeros((len(signatures), len(signatures)))
#     for i, s1 in enumerate(signatures.values()):
#         for j, s2 in enumerate(signatures.values()):
#             distance = 1 - np.dot(s1, s2) / (np.linalg.norm(s1) * np.linalg.norm(s2))
#             signature_matrix[i, j] = distance

#     plt.figure(figsize=(8, 5))
#     plt.imshow(signature_matrix, cmap='hot', interpolation='nearest')
#     plt.colorbar()
#     plt.savefig(f"figures/syntetic_data/distance_matrix/SO3_3_signature_{level}.png", bbox_inches='tight', pad_inches=0)
#     # plt.show()

### Perturbation analysis

In [None]:
t_eq = np.linspace(0, 1, 100)

def pertubate_parameterization(t, max_pertubation):
    t = t.copy()
    perturbation = np.random.normal(0, max_pertubation, len(t) - 2)
    # perturbation = np.random.uniform(-max_pertubation, max_pertubation, len(t) - 2)
    perturbation = np.clip(perturbation, -max_pertubation, max_pertubation)
    t[1:-1] += perturbation
    return t

In [None]:
# from tqdm import tqdm


# n = 10
# num_g = 3

# pbar = tqdm(total=num_g)
# distances = np.zeros((num_g, n))


# for i, g_dot in enumerate(g_dot_lst[:num_g]):

#     perturbations = []
#     max_pertubation = (t_eq[1] - t_eq[0]) / 2 - 1e-12

#     c_ref = solve_ivp_rk4(g_dot_lst[0], g0, t_eq)
#     c_ref = c_ref.reshape(len(varphi), 3, 3)

#     c1 = move_rotation_origin_to_identity(c_ref)
#     q1 = skew_matrix_to_vector_single_rotation(SRVT_single_rotation(t_eq,c1))



#     for j in range(n):
#         t_pertubated = pertubate_parameterization(t_eq, max_pertubation)
#         c_pertubated = solve_ivp_rk4(g_dot_lst[0], g0, t_pertubated)
#         c_pertubated = c_pertubated.reshape(len(varphi), 3, 3)

#         c2 = move_rotation_origin_to_identity(c_pertubated)
#         q2 = skew_matrix_to_vector_single_rotation(SRVT_single_rotation(t_eq,c2))

#         distances[i, j] = L2_metric(q1, q2, t_eq, t_eq)
#         perturbations.append(max_pertubation)
        
#         max_pertubation = max_pertubation / 2
        
#     pbar.update(1)
    
# pbar.close()


In [None]:
# df = pd.DataFrame(distances)
# df = df.T
# df.columns = ['c1', 'c2', 'c3']
# df['perturbation'] = perturbations

# # df.to_csv('figures/syntetic_data/perturbation-analysis/SO3.csv', index=False)

In [None]:
# # Plot the distances with labels
# for i, distance in enumerate(distances):
#     plt.loglog(perturbations, distance, label=f'$\omega_{i}$')

# # Plot the reference line for order 1
# x = np.linspace(min(perturbations), max(perturbations), 100)
# plt.loglog(x, x**1, label=r'$\mathcal{O}(\epsilon^{1})$', linestyle='--', color='black')


# # Improve plot aesthetics
# plt.xlabel(r'$\epsilon$')
# plt.ylabel(r'$\|q - q \circ \varphi_{\epsilon}\|_{L^2}$')
# plt.legend(fontsize='small')
# plt.title('Perturbation Analysis')
# # plt.grid(True, which="both", ls="--")
# plt.tight_layout()

# # Show the plot
# plt.show()


In [None]:
# from tqdm import tqdm


# n = 10
# num_g = 3

# pbar = tqdm(total=num_g)
# distances = np.zeros((num_g, n))


# for i, g_dot in enumerate(g_dot_lst[:num_g]):

#     perturbations = []
#     max_pertubation = (t_eq[1] - t_eq[0]) / 2 - 1e-12

#     c_ref = solve_ivp_rk4(g_dot_lst[0], g0, t_eq)
#     c_ref = c_ref.reshape(len(varphi), 3, 3)

#     c1 = move_rotation_origin_to_identity(c_ref)
#     q1 = skew_matrix_to_vector_single_rotation(SRVT_single_rotation(t_eq,c1))



#     for j in range(n):
#         t_pertubated = pertubate_parameterization(t_eq, max_pertubation)
#         c_pertubated = solve_ivp_rk4(g_dot_lst[0], g0, t_pertubated)
#         c_pertubated = c_pertubated.reshape(len(varphi), 3, 3)
#         c2 = move_rotation_origin_to_identity(c_pertubated)


#         I_new = find_reparameterization(c1, c2, depth=10)
#         # I = np.linspace(0, 1, len(c1))
#         c2 = reparameterize_rotation(I_new, t_eq, c2)

#         c2 = move_rotation_origin_to_identity(c2)
#         q2 = skew_matrix_to_vector_single_rotation(SRVT_single_rotation(t_eq,c2))

#         distances[i, j] = L2_metric(q1, q2, t_eq, t_eq)
#         perturbations.append(max_pertubation)
        
#         max_pertubation = max_pertubation / 2
        
#     pbar.update(1)
    
# pbar.close()


In [None]:
# # Plot the distances with labels
# for i, distance in enumerate(distances):
#     plt.loglog(perturbations, distance, label=f'$\omega_{i}$')

# # Plot the reference line for order 1
# x = np.linspace(min(perturbations), max(perturbations), 100)
# plt.loglog(x, x**1, label=r'$\mathcal{O}(\epsilon^{1})$', linestyle='--', color='black')


# # Improve plot aesthetics
# plt.xlabel(r'$\epsilon$')
# plt.ylabel(r'$\|q - q \circ \varphi_{\epsilon}\|_{L^2}$')
# plt.legend(fontsize='small')
# plt.title('Perturbation Analysis')
# # plt.grid(True, which="both", ls="--")
# plt.tight_layout()

# # Show the plot
# plt.show()


In [None]:
# df = pd.DataFrame(distances)
# df = df.T
# df.columns = ['c1', 'c2', 'c3']
# df['perturbation'] = perturbations

# # df.to_csv('figures/syntetic_data/perturbation-analysis/reparameterized-SO3.csv', index=False)

### Signature-based perturbation analysis

In [None]:
from tqdm import tqdm
import iisignature

# Perturbation analysis of the log-signature: for each of the first num_g curves,
# compare the normalized level-`level` log-signature of q on the uniform grid with
# that of the same curve sampled on jittered grids (jitter halved n times).
level = 8
preprocessed_data = iisignature.prepare(3, level)

n = 10
num_g = 3

distances = np.zeros((num_g, n))

# Context manager closes the progress bar even if a step raises.
with tqdm(total=num_g) as pbar:
    for i, g_dot in enumerate(g_dot_lst[:num_g]):

        perturbations = []
        # Just under half the grid spacing, so the perturbed grid stays increasing.
        max_pertubation = (t_eq[1] - t_eq[0]) / 2 - 1e-12

        # Reshape by the grid's own length. The old code used a leaked global `varphi`
        # that only happened to have the same length.
        c_ref = solve_ivp_rk4(g_dot, g0, t_eq).reshape(len(t_eq), 3, 3)

        c1 = move_rotation_origin_to_identity(c_ref)
        q1 = skew_matrix_to_vector_single_rotation(right_log_single_rotation(t_eq, c1))
        # The reference log-signature does not depend on j: compute it once per curve.
        sig1 = iisignature.logsig(q1, preprocessed_data)
        sig1_unit = sig1 / np.linalg.norm(sig1)

        for j in range(n):
            t_pertubated = pertubate_parameterization(t_eq, max_pertubation)
            c_pertubated = solve_ivp_rk4(g_dot, g0, t_pertubated)
            c_pertubated = c_pertubated.reshape(len(t_pertubated), 3, 3)

            c2 = move_rotation_origin_to_identity(c_pertubated)
            # Evaluated on t_eq: the jittered samples are treated as a uniformly
            # sampled curve, i.e. q composed with the perturbation.
            q2 = skew_matrix_to_vector_single_rotation(right_log_single_rotation(t_eq, c2))

            sig2 = iisignature.logsig(q2, preprocessed_data)

            # Distance between log-signatures normalized to unit length.
            distances[i, j] = np.linalg.norm(sig1_unit - sig2 / np.linalg.norm(sig2))
            perturbations.append(max_pertubation)

            max_pertubation = max_pertubation / 2

        pbar.update(1)

In [None]:
# Convergence of the normalized log-signature distance as the perturbation shrinks.
fig, ax = plt.subplots()
for i, distance in enumerate(distances):
    # rf-string: in a plain f-string "\o" is an invalid escape sequence.
    ax.loglog(perturbations, distance, label=rf'$\omega_{i}$')

# First-order reference line. Named eps_ref (not x) so this cell does not overwrite
# the global identity parameterization `x` used by later cells.
eps_ref = np.linspace(min(perturbations), max(perturbations), 100)
ax.loglog(eps_ref, eps_ref, label=r'$\mathcal{O}(\epsilon^{1})$', linestyle='--', color='black')

ax.set_xlabel(r'$\epsilon$')
# The plotted quantity is the distance between unit-normalized log-signatures,
# not an L^2 norm of q.
ax.set_ylabel(r'$\|\widehat{\log S}(q) - \widehat{\log S}(q \circ \varphi_{\epsilon})\|_2$')
ax.legend(fontsize='small')
ax.set_title('Perturbation Analysis (normalized log-signature)')
fig.tight_layout()

plt.show()


In [None]:
import pandas as pd  # pandas may not be imported yet at this point in the notebook

# One row per perturbation size, one column per analysed generator (c1, c2, ...).
df = pd.DataFrame(distances.T, columns=[f'c{k + 1}' for k in range(distances.shape[0])])
df['perturbation'] = perturbations
df
# df.to_csv('figures/syntetic_data/perturbation-analysis/signature-SO3.csv', index=False)

In [None]:
from tqdm import tqdm
import iisignature

# Single-curve comparison of the normalized log-signature distance computed with
# iisignature's default basis and with the 'DH' methods string, across truncation
# levels, for one perturbation of the reference curve from g_dot_lst[0].
max_pertubation = (t_eq[1] - t_eq[0]) / 2 - 1e-12

# reshape(-1, 3, 3): the row count comes from the ODE solution itself rather than
# from the unrelated global `varphi`.
c_ref = solve_ivp_rk4(g_dot_lst[0], g0, t_eq).reshape(-1, 3, 3)
c1 = move_rotation_origin_to_identity(c_ref)
q1 = skew_matrix_to_vector_single_rotation(right_log_single_rotation(t_eq, c1))

t_pertubated = pertubate_parameterization(t_eq, max_pertubation)
c_pertubated = solve_ivp_rk4(g_dot_lst[0], g0, t_pertubated).reshape(-1, 3, 3)
c2 = move_rotation_origin_to_identity(c_pertubated)
q2 = skew_matrix_to_vector_single_rotation(right_log_single_rotation(t_eq, c2))

for level in range(1, 11):
    prepared_default = iisignature.prepare(3, level)
    sig1_L = iisignature.logsig(q1, prepared_default)
    sig2_L = iisignature.logsig(q2, prepared_default)

    prepared_dh = iisignature.prepare(3, level, 'DH')
    sig1_H = iisignature.logsig(q1, prepared_dh)
    sig2_H = iisignature.logsig(q2, prepared_dh)

    dist_L = np.linalg.norm(sig1_L / np.linalg.norm(sig1_L) - sig2_L / np.linalg.norm(sig2_L))
    dist_H = np.linalg.norm(sig1_H / np.linalg.norm(sig1_H) - sig2_H / np.linalg.norm(sig2_H))
    print(f'level {level:2d}: default basis {dist_L:.6e} | DH {dist_H:.6e}')

### Finding the optimal reparameterization with gradient-based optimization (SLSQP)

In [None]:
import numpy as np
from scipy.optimize import minimize
from scipy.spatial.transform import Rotation
from joblib import Parallel, delayed



def find_optimal_nodes(x_init, cost, niter=10, n_jobs=1, maxiter=30):
    """Find a non-decreasing node vector on [0, 1] that minimizes ``cost``.

    Runs ``niter`` independent SLSQP optimizations, each from a random sorted start,
    and returns the solution with the lowest objective value.

    Parameters
    ----------
    x_init : array-like
        Only its length ``n`` is used; every restart draws its own random start.
    cost : callable
        Objective ``cost(x) -> float`` for a length-``n`` node vector.
    niter : int
        Number of random restarts.
    n_jobs : int
        Passed to ``joblib.Parallel``; with ``n_jobs == 1`` restarts run in a plain loop.
    maxiter : int
        SLSQP iteration cap per restart (previously hard-coded to 30).

    Returns
    -------
    numpy.ndarray
        Best node vector of length ``n``. The first and last nodes are bounded to 0 and
        1; consecutive nodes are constrained by ``x[i + 1] - x[i] >= 0``.
    """
    n = len(x_init)
    # First node fixed at 0, last at 1; interior nodes are unbounded (monotonicity is
    # enforced through the inequality constraints below).
    bounds = [(0, 0)] + [(None, None)] * (n - 2) + [(1, 1)]

    def increasing_constraint(x, i):
        return x[i + 1] - x[i]

    constraints = [{'type': 'ineq', 'fun': increasing_constraint, 'args': (i,)} for i in range(n - 1)]

    options = {
        'disp': False,
        'ftol': 1e-9,
        'maxiter': maxiter,
    }

    def run_optimization(_restart):
        # Random sorted start with the SAME length as x_init; this was hard-coded to
        # 100, which broke bounds/constraints for any other number of nodes.
        start = np.sort(np.random.uniform(0, 1, n))
        start[0], start[-1] = 0, 1
        return minimize(cost,
                        start,
                        method='SLSQP',
                        bounds=bounds,
                        constraints=constraints,
                        options=options)

    if n_jobs == 1:
        # Serial path: same sequential execution joblib uses for n_jobs=1, without the
        # dispatch overhead.
        results = [run_optimization(k) for k in range(niter)]
    else:
        results = Parallel(n_jobs=n_jobs)(delayed(run_optimization)(k) for k in range(niter))

    return min(results, key=lambda result: result.fun).x

def rotation_difference(y1, y2):
    """Average squared discrepancy between two equally long sequences of rotations.

    For each sample k the relative rotation ``y1[k].T @ y2[k]`` is passed through
    ``log_map`` and its Frobenius norm is squared; the mean over all samples is
    returned. Raises AssertionError when the two arrays differ in shape.
    """
    assert y1.shape == y2.shape, 'y1 and y2 must have the same shape'
    relative_logs = np.array([log_map(a.T @ b) for a, b in zip(y1, y2)])
    frobenius = np.linalg.norm(relative_logs, axis=(1, 2))
    return np.mean(frobenius ** 2)

def create_slerp_f(x, g):
    """Return a piecewise-geodesic interpolant through rotations ``g`` at nodes ``x``.

    For a query t with segment index k (k = searchsorted(x, t, 'right') clipped to
    [1, len(x) - 1]) the value is ``g[k-1] @ exp_map(s * log_map(g[k-1].T @ g[k]))``
    with ``s = (t - x[k-1]) / (x[k] - x[k-1])``. Queries outside [x[0], x[-1]] are
    extrapolated along the first/last segment (s < 0 or s > 1).

    NOTE(review): assumes ``x`` is sorted with distinct entries (equal neighbours
    would divide by zero) and that log_map/exp_map are the SO(3) log/exp — confirm.
    """
    # segment_logs[k - 1] holds the relative rotation of segment [x[k-1], x[k]].
    segment_logs = [log_map(g[k - 1].T @ g[k]) for k in range(1, len(g))]

    def interpSO3(x_eval):
        samples = []
        for t in x_eval:
            hi = np.clip(np.searchsorted(x, t, side='right'), 1, len(x) - 1)
            lo = hi - 1
            s = (t - x[lo]) / (x[hi] - x[lo])
            samples.append(g[lo] @ exp_map(s * segment_logs[lo]))
        return np.array(samples)

    return interpSO3

def C_factory(f1, y2):
    """Build the reparameterization cost for fitting curve ``f1`` to samples ``y2``.

    Parameters
    ----------
    f1 : callable
        Interpolant of the curve being reparameterized; ``f1(nodes)`` returns one
        rotation per node (e.g. the function returned by ``create_slerp_f``).
    y2 : numpy.ndarray
        Target rotations, one per node.

    Returns
    -------
    callable
        ``C(x)`` returning ``rotation_difference(f1(clip(x, 0, 1)), y2)``.
    """
    def C(x):
        # Interior nodes are not bounded by the optimizer, so clamp them into the
        # curve's parameter domain before sampling.
        x = np.clip(x, 0, 1)
        # The samples used to be re-interpolated on an equidistant grid and then
        # evaluated at that same grid, which only reproduces f1(x) (up to round-off,
        # assuming exp_map/log_map are mutually inverse) while costing a log/exp map
        # per node on every cost evaluation.
        return rotation_difference(f1(x), y2)
    return C

# Identity parameterization on [0, 1], built here rather than relying on whatever
# global `x` an earlier cell left behind.
x_unit = np.linspace(0, 1, len(varphi_1))

# 3 generators x (varphi_1, varphi_2, varphi_3, identity), keyed c_0, c_1, ... in
# generator-major order.
curves = {}
counter = 0
for i, g_func in enumerate(g_dot_lst[:3]):
    for j, varphi in enumerate([varphi_1, varphi_2, varphi_3, x_unit]):
        curve = solve_ivp_rk4(g_func, g0, varphi).reshape(-1, 3, 3)
        curves[f"c_{counter}"] = curve
        counter += 1


In [None]:
# One fit/target pair for generator g_dot_lst[2]: the curve on the identity
# parameterization vs. the curve on varphi_1. Everything is stated explicitly instead
# of relying on `x`, `g_func` and `varphi` leaked from earlier cells.
x = np.linspace(0, 1, len(varphi_1))

c_fit = solve_ivp_rk4(g_dot_lst[2], g0, x).reshape(-1, 3, 3)
c_target = solve_ivp_rk4(g_dot_lst[2], g0, varphi_1).reshape(-1, 3, 3)

In [None]:
import os

# os.cpu_count() returns None when the count cannot be determined; fall back to 1 so
# num_cores is always a valid job count.
num_cores = os.cpu_count() or 1
print(f'Number of CPU cores available: {num_cores}')

def find_reparameterization_SLERP(x, c_fit, c_target, restart, num_cores):
    """Estimate the node warp that best aligns ``c_fit`` with ``c_target``.

    Both curves are given as one rotation per node of ``x``. The target is sampled
    through its SLERP interpolant at ``x``; the fitted curve's interpolant is handed
    to the cost built by ``C_factory``. The optimal nodes are found by
    ``find_optimal_nodes`` with ``restart`` restarts over ``num_cores`` jobs.
    """
    fit_interpolant = create_slerp_f(x, c_fit)
    target_interpolant = create_slerp_f(x, c_target)
    target_samples = target_interpolant(x)
    objective = C_factory(fit_interpolant, target_samples)
    return find_optimal_nodes(x, objective, restart, num_cores)

# x_opt = find_reparameterization_SLERP(x, c_fit = c_fit, c_target = c_target, restart = 5, num_cores = num_cores)

In [None]:
# The call that computes x_opt in the previous cell is commented out, so compute it
# here to keep this cell runnable on a fresh kernel.
x_opt = find_reparameterization_SLERP(x, c_fit=c_fit, c_target=c_target, restart=5, num_cores=num_cores)

fig, ax = plt.subplots()
ax.plot(x, x_opt, marker='.', label='Optimal reparameterization')
ax.plot(x, varphi_1, label='Actual reparameterization (varphi_1)')
ax.set(xlabel='t', ylabel='reparameterization', title='Single-curve SLERP fit, g_dot_lst[2]')
ax.legend()
plt.show()

# Summary instead of dumping every node.
print(f'max |x_opt - varphi_1| = {np.max(np.abs(x_opt - varphi_1)):.3e}')

In [None]:
# import pandas as pd
# df = pd.read_csv('figures/syntetic_data/reparameterization_SO3_SLERP/df_varphi_2.csv')
# df['g2'] = x_opt

# df.to_csv('figures/syntetic_data/reparameterization_SO3_SLERP/df_varphi_2.csv', index=False)


In [None]:
# n = 100
# # x_init_equidistant = np.linspace(0, 1, n)
# # x_init_equidistant[1:-1] += np.random.normal(0, 0.05, size=n-2)
# # x_init_equidistant = np.sort(x_init_equidistant)
# for i in range(5):
#     random_values = np.sort(np.random.uniform(0, 1, 100))
#     random_values[0], random_values[-1] = 0, 1
#     plt.plot(x, random_values, marker='.')
# # plt.plot(x, x_init_equidistant, marker='.')
# plt.show()

In [None]:
import os
import pandas as pd

x = np.linspace(0, 1, 100)

OUT_DIR = 'figures/syntetic_data/reparameterization_SO3_SLERP'
os.makedirs(OUT_DIR, exist_ok=True)  # to_csv fails if the directory is missing

# One table per ground-truth reparameterization varphi_{j+1}; column g{i} holds the
# estimate obtained with generator g_dot_lst[i].
df_varphi_1 = pd.DataFrame({'x' : x})
df_varphi_2 = pd.DataFrame({'x' : x})
df_varphi_3 = pd.DataFrame({'x' : x})
result_tables = [df_varphi_1, df_varphi_2, df_varphi_3]

for i, g_func in enumerate(g_dot_lst[:3]):
    # The identity-parameterized curve depends only on the generator: solve it once.
    c_fit = solve_ivp_rk4(g_func, g0, x).reshape(-1, 3, 3)

    for j, varphi in enumerate([varphi_1, varphi_2, varphi_3]):
        c_target = solve_ivp_rk4(g_func, g0, varphi).reshape(-1, 3, 3)

        x_new = find_reparameterization_SLERP(x, c_fit = c_fit, c_target = c_target, restart = 6, num_cores = num_cores)

        fig, ax = plt.subplots()
        ax.plot(x, x_new, label='Optimal reparameterization')
        ax.plot(x, varphi, label='Actual reparameterization', marker='.')
        ax.set_title(f'g{i}, varphi_{j + 1}')
        ax.legend()
        plt.show()

        result_tables[j][f"g{i}"] = x_new
        print(f"Done with g{i} and varphi_{j + 1}")

for k, table in enumerate(result_tables, start=1):
    table.to_csv(f'{OUT_DIR}/df_varphi_{k}.csv', index=False)

In [None]:
x = np.linspace(0, 1, 100)
curves = {}

# Generators are split into consecutive groups of GROUP_SIZE; keys are
# c_{group}_{v} with group = 1, 2, ... and v = 1..3 for varphi_v, v = 4 for identity x.
GROUP_SIZE = 3
assert len(g_dot_lst) % GROUP_SIZE == 0, 'g_dot_lst must split into groups of 3 generators'

for i, varphi in enumerate([varphi_1, varphi_2, varphi_3, x]):
    # Shape (num_generators, len(varphi), 3, 3): one trajectory per generator.
    curve = np.stack([solve_ivp_rk4(g_func, g0, varphi).reshape(-1, 3, 3) for g_func in g_dot_lst])
    for group_start in range(0, len(curve), GROUP_SIZE):
        curves[f"c_{group_start // GROUP_SIZE + 1}_{i+1}"] = curve[group_start:group_start + GROUP_SIZE]

keys = sorted(curves.keys(), key=lambda name: (int(name.split('_')[1]), int(name.split('_')[2])))
sorted_curves = {key: curves[key] for key in keys}

In [None]:
import numpy as np
from scipy.optimize import minimize
from scipy.spatial.transform import Rotation
from joblib import Parallel, delayed

def find_optimal_nodes(x_init, cost, niter=10, n_jobs=1, maxiter=10):
    """Find a non-decreasing node vector on [0, 1] that minimizes ``cost``.

    NOTE: this cell redefines ``find_optimal_nodes``; this version (10 SLSQP
    iterations by default) is the one every later cell uses.

    Runs ``niter`` independent SLSQP optimizations, each from a random sorted start,
    and returns the solution with the lowest objective value.

    Parameters
    ----------
    x_init : array-like
        Only its length ``n`` is used; every restart draws its own random start.
    cost : callable
        Objective ``cost(x) -> float`` for a length-``n`` node vector.
    niter : int
        Number of random restarts.
    n_jobs : int
        Passed to ``joblib.Parallel``; with ``n_jobs == 1`` restarts run in a plain loop.
    maxiter : int
        SLSQP iteration cap per restart (previously hard-coded to 10).

    Returns
    -------
    numpy.ndarray
        Best node vector of length ``n``. The first and last nodes are bounded to 0 and
        1; consecutive nodes are constrained by ``x[i + 1] - x[i] >= 0``.
    """
    n = len(x_init)
    # First node fixed at 0, last at 1; interior nodes are unbounded (monotonicity is
    # enforced through the inequality constraints below).
    bounds = [(0, 0)] + [(None, None)] * (n - 2) + [(1, 1)]

    def increasing_constraint(x, i):
        return x[i + 1] - x[i]

    constraints = [{'type': 'ineq', 'fun': increasing_constraint, 'args': (i,)} for i in range(n - 1)]

    options = {
        'disp': False,
        'ftol': 1e-9,
        'maxiter': maxiter,
    }

    def run_optimization(_restart):
        # Random sorted start with the SAME length as x_init; this was hard-coded to
        # 100, which broke bounds/constraints for any other number of nodes.
        start = np.sort(np.random.uniform(0, 1, n))
        start[0], start[-1] = 0, 1
        return minimize(cost,
                        start,
                        method='SLSQP',
                        bounds=bounds,
                        constraints=constraints,
                        options=options)

    if n_jobs == 1:
        # Serial path: same sequential execution joblib uses for n_jobs=1, without the
        # dispatch overhead.
        results = [run_optimization(k) for k in range(niter)]
    else:
        results = Parallel(n_jobs=n_jobs)(delayed(run_optimization)(k) for k in range(niter))

    return min(results, key=lambda result: result.fun).x

def rotation_difference_multiple(y1, y2):
    """Sum over paired curves of their mean squared rotation discrepancy.

    ``y1[k]`` and ``y2[k]`` are sequences of rotations indexed in lockstep. For each
    pair the squared Frobenius norm of ``log_map(y1[k][m].T @ y2[k][m])`` is averaged
    over m; the per-curve averages are added up (starting from 0).
    """
    def _curve_mean_sq(curve_a, curve_b):
        logs = np.array([log_map(np.dot(curve_a[m].T, curve_b[m])) for m in range(len(curve_a))])
        return np.mean(np.linalg.norm(logs, axis=(1, 2)) ** 2)

    return sum(_curve_mean_sq(y1[k], y2[k]) for k in range(len(y1)))

def C_factory_multiple(f1, y2):
    """Build a joint reparameterization cost for several curves sharing one warp.

    Parameters
    ----------
    f1 : list of callable
        Interpolants of the curves being reparameterized; each ``f1_i(nodes)`` returns
        one rotation per node.
    y2 : list of numpy.ndarray
        Target rotations for each curve, in the same order as ``f1``.

    Returns
    -------
    callable
        ``C(x)`` returning ``rotation_difference_multiple([f(clip(x, 0, 1)) for f in f1], y2)``.
    """
    def C(x):
        # Interior nodes are not bounded by the optimizer, so clamp them into the
        # curves' parameter domain before sampling.
        x = np.clip(x, 0, 1)
        # Sample each interpolant directly. Re-interpolating the samples on an
        # equidistant grid and evaluating at that same grid only reproduced them (up
        # to round-off) at the cost of extra log/exp maps per evaluation.
        y1 = [f1_i(x) for f1_i in f1]
        return rotation_difference_multiple(y1, y2)
    return C

def find_reparameterization_SLERP_multiple(x, c_fit, c_target, restart, num_cores):
    """Estimate one node warp that jointly aligns every curve in ``c_fit`` with ``c_target``.

    ``c_fit`` and ``c_target`` are parallel collections of curves, each given as one
    rotation per node of ``x``. Targets are sampled through their SLERP interpolants
    at ``x``; the joint cost comes from ``C_factory_multiple`` and is minimized by
    ``find_optimal_nodes`` with ``restart`` restarts over ``num_cores`` jobs.
    """
    fit_interpolants = [create_slerp_f(x, curve) for curve in c_fit]
    target_interpolants = [create_slerp_f(x, curve) for curve in c_target]
    target_samples = [interp(x) for interp in target_interpolants]
    objective = C_factory_multiple(fit_interpolants, target_samples)
    return find_optimal_nodes(x, objective, restart, num_cores)

In [None]:
import os
from SO3.movement_data.calculate_reparameterized_distance import reparameterized_distance

OUT_DIR_MULTI = 'figures/syntetic_data/reparameterization_SO3_3_SLERP'
os.makedirs(OUT_DIR_MULTI, exist_ok=True)  # to_csv fails if the directory is missing

# sorted_curves keys are c_{group}_{v}: generator group `group` sampled on varphi_v
# (v = 1..3) or on the identity x (v = 4). Fitting c_{group}_4 to c_{group}_v should
# therefore recover varphi_v. The previous version iterated c_1_1..c_1_3 (i.e.
# varphi_1..varphi_3 of group 1), stored them all in df_varphi_1 and plotted them
# against varphi_1. Results now go to df_varphi_v, column g{group-1}.
varphis = [varphi_1, varphi_2, varphi_3]
df_varphi_1 = pd.DataFrame({'x' : x})
df_varphi_2 = pd.DataFrame({'x' : x})
df_varphi_3 = pd.DataFrame({'x' : x})
result_tables = [df_varphi_1, df_varphi_2, df_varphi_3]

for group in (1, 2, 3):
    c_fit = sorted_curves[f'c_{group}_4']
    for v, (varphi, table) in enumerate(zip(varphis, result_tables), start=1):
        c_target = sorted_curves[f'c_{group}_{v}']
        I_new = find_reparameterization_SLERP_multiple(x, c_fit, c_target, restart=5, num_cores=num_cores)

        fig, ax = plt.subplots()
        ax.plot(x, I_new, label='Optimal reparameterization')
        ax.plot(x, varphi, label=f'Actual reparameterization (varphi_{v})')
        ax.set_title(f'Generator group {group}, varphi_{v}')
        ax.legend()
        plt.show()

        table[f"g{group - 1}"] = I_new

for v, table in enumerate(result_tables, start=1):
    table.to_csv(f'{OUT_DIR_MULTI}/df_varphi_{v}.csv', index=False)

### Distance matrix

In [None]:
# Test curves for the distance matrix: each of the first 3 generators sampled on
# varphi_1, varphi_2, varphi_3 and the identity x, keyed c_0, c_1, ... in
# generator-major order.
curves = {}
counter = 0
for g_func in g_dot_lst[:3]:
    for varphi in (varphi_1, varphi_2, varphi_3, x):
        flat = solve_ivp_rk4(g_func, g0, varphi)
        curves[f"c_{counter}"] = flat.reshape(len(varphi), 3, 3)
        counter += 1

In [None]:
from tqdm.auto import tqdm

curve_list = list(curves.values())
n_curves = len(curve_list)
total = n_curves * (n_curves - 1) // 2

distance_matrix = np.zeros((n_curves, n_curves))
x = np.linspace(0, 1, 100)

# Only pairs with j < i are optimized (curve i is reparameterized onto curve j); the
# value is mirrored, i.e. the matrix is treated as symmetric.
with tqdm(total=total, desc='pairwise distances') as progress:
    for i in range(n_curves):
        for j in range(i):
            c1, c2 = curve_list[i], curve_list[j]

            x_opt = find_reparameterization_SLERP(x, c1, c2, restart=5, num_cores=num_cores)
            x_opt = np.clip(x_opt, 0, 1)
            # Sample c1's interpolant at the optimal nodes. Re-interpolating those
            # samples on x and evaluating at x again only reproduced them.
            c1_new = create_slerp_f(x, c1)(x_opt)

            distance = rotation_difference(c1_new, c2)
            distance_matrix[i, j] = distance
            distance_matrix[j, i] = distance
            progress.update(1)


In [None]:
fig, ax = plt.subplots(figsize=(8, 5))
im = ax.imshow(distance_matrix, cmap='hot', interpolation='nearest')
# rotation_difference returns the mean squared Frobenius norm of log_map(R1^T R2).
fig.colorbar(im, ax=ax, label='mean squared rotation difference')

curve_names = list(curves.keys())
ax.set_xticks(range(len(curve_names)))
ax.set_xticklabels(curve_names, rotation=90)
ax.set_yticks(range(len(curve_names)))
ax.set_yticklabels(curve_names)
ax.set_title('Pairwise reparameterized distances on SO(3)');
# plt.show()
# fig.savefig("figures/syntetic_data/distance_matrix/SO3_interpolation.png", bbox_inches='tight', pad_inches=0)