In [None]:
import sys
sys.path.append("..")

from SO3.utils.curve_utils import * 
from SO3.utils.reparameterization_utils import *
from scipy.spatial.transform import Rotation as R
import numpy as np
import matplotlib.pyplot as plt

In [None]:
def rotate(axis: np.ndarray, angle: float) -> np.ndarray:
    """Return the 3x3 rotation matrix for the rotation vector ``axis * angle``.

    The product ``axis * angle`` is handed to ``Rotation.from_rotvec``; ``axis`` is
    not normalised here, so ``angle`` is the rotation angle in radians only when
    ``axis`` has unit length.
    """
    rotation_vector = angle * axis
    return R.from_rotvec(rotation_vector).as_matrix()

def create_so3_curve(n: int, I: np.ndarray, angle_x: float, angle_y: float) -> np.ndarray:
    """Sample a discrete curve of 3x3 rotation matrices from the parameters in ``I``.

    For each of the first ``n`` entries ``t = I[i]`` the sample is
    ``rotate(e_x, angle_x * t) @ rotate(e_y, angle_y * t - 1)``.

    NOTE(review): the y-rotation argument is ``angle_y * t - 1`` (a constant 1 is
    subtracted from the angle), not ``angle_y * (t - 1)`` -- confirm this is intended.

    Returns an array of shape ``(n, 3, 3)``.
    """
    e_x = np.array([1, 0, 0])
    e_y = np.array([0, 1, 0])
    curve = np.zeros((n, 3, 3))
    for idx in range(n):
        t = I[idx]
        about_x = rotate(e_x, angle_x * t)
        about_y = rotate(e_y, angle_y * t - 1)
        curve[idx] = about_x @ about_y
    return curve

def plot_rotations(c, plot_sphere=True):
    """Scatter the image of the unit z-vector under every rotation in ``c``.

    A new figure with a single 3D axes is created. Each element of ``c`` is applied
    to ``(0, 0, 1)`` and the result is drawn as a semi-transparent red point. When
    ``plot_sphere`` is true, a semi-transparent unit sphere is drawn as well.
    Nothing is returned and ``plt.show()`` is left to the caller.
    """
    figure = plt.figure()
    axes = figure.add_subplot(111, projection='3d')

    # Reference vector that every rotation is applied to.
    north_pole = np.array([0, 0, 1])

    for rotation in c:
        px, py, pz = np.dot(rotation, north_pole)
        axes.scatter(px, py, pz, color='red', alpha=0.5)

    if not plot_sphere:
        return

    # Parametrise the unit sphere by azimuth and polar angle on a 100 x 100 grid.
    azimuth = np.linspace(0, 2 * np.pi, 100)
    polar = np.linspace(0, np.pi, 100)
    sphere_x = np.outer(np.cos(azimuth), np.sin(polar))
    sphere_y = np.outer(np.sin(azimuth), np.sin(polar))
    sphere_z = np.outer(np.ones(np.size(azimuth)), np.cos(polar))
    axes.plot_surface(sphere_x, sphere_y, sphere_z, color='y', alpha=.5)

def is_SO3(matrix):
    """Check whether ``matrix`` is a rotation matrix or a stack of rotation matrices.

    Returns ``False`` unless the last two dimensions are ``(3, 3)``, every 3x3 slice
    ``m`` satisfies ``m @ m.T ~= I`` (``np.allclose``), and every determinant is
    close to 1 (``np.isclose``). Otherwise returns ``True``.
    """
    if matrix.shape[-2:] != (3, 3):
        return False

    identity = np.eye(3)
    slices = np.reshape(matrix, (-1, 3, 3))
    orthogonal = all(np.allclose(m @ m.T, identity) for m in slices)
    if not orthogonal:
        return False

    # Orthogonality alone admits reflections (det = -1); require det = +1.
    unit_determinant = np.isclose(np.linalg.det(matrix), 1).all()
    return bool(unit_determinant)

def rotate_new(axis, theta):
    """
    Build the 3x3 matrix of the rotation by `theta` radians about `axis`.

    The axis is normalised to unit length, the half-angle Euler parameters
    (a, b, c, d) are formed from it, and the matrix is assembled from their
    pairwise products (Euler-Rodrigues formula). A zero-length axis is not handled
    specially and leads to a division by zero.
    """
    direction = np.asarray(axis)
    direction = direction / np.sqrt(np.dot(direction, direction))

    half_angle = theta / 2
    a = np.cos(half_angle)
    b, c, d = -direction * np.sin(half_angle)

    # Squares and mixed products of the Euler parameters.
    aa, bb, cc, dd = a * a, b * b, c * c, d * d
    ab, ac, ad = a * b, a * c, a * d
    bc, bd, cd = b * c, b * d, c * d

    first_row = [aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)]
    second_row = [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)]
    third_row = [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]
    return np.array([first_row, second_row, third_row])

def create_advanced_so3_curve(n: int, I: np.ndarray, angle_scale: float) -> np.ndarray:
    """Sample a curve of rotations whose axis angles vary non-linearly with ``I``.

    For each of the first ``n`` entries ``t = I[i]``, with ``p = angle_scale * t``:
    the x angle is ``sin(p) * pi``, the y angle is ``cos(p) * pi`` and the z angle is
    ``(sin(p) + cos(p)) * pi``. The sample is ``Rz @ Ry @ Rx`` built with
    ``rotate_new``.

    Returns an array of shape ``(n, 3, 3)``.
    """
    curve = np.zeros((n, 3, 3))
    for idx in range(n):
        phase = angle_scale * I[idx]
        sin_phase = np.sin(phase)
        cos_phase = np.cos(phase)

        about_x = rotate_new([1, 0, 0], sin_phase * np.pi)
        about_y = rotate_new([0, 1, 0], cos_phase * np.pi)
        about_z = rotate_new([0, 0, 1], (sin_phase + cos_phase) * np.pi)

        # The x rotation is applied first, then y, then z.
        curve[idx] = about_z @ about_y @ about_x

    return curve

In [None]:
# Build the same SO(3) curve on a uniform grid (c2) and on the grid warped by
# phi(t) = t**3 (c1), then show where each curve sends the unit z-vector.
def phi(x):
    return x**3


n = 50
I = np.linspace(0, 1, n)
half_turn = np.pi

c1 = create_so3_curve(n, phi(I), half_turn, half_turn)
c2 = create_so3_curve(n, I, half_turn, half_turn)

plot_rotations(c1)
plot_rotations(c2)
plt.show()

In [None]:
# Compare the reparameterization returned by find_optimal_diffeomorphism (psi)
# with the identity and with the warp phi used to build c1.
c1, c2 = (move_rotation_origin_to_identity(curve) for curve in (c1, c2))

q1, q2 = (
    skew_matrix_to_vector_single_rotation(SRVT_single_rotation(I, curve))
    for curve in (c1, c2)
)

I_new = find_optimal_diffeomorphism(q1, q2, I, I, 10)

plt.plot(I, I, label='I')
plt.plot(I, phi(I), label='phi(I)')
plt.plot(I, I_new, label='psi(I)')
plt.legend()
plt.show()

In [None]:
# Sensitivity of the SRVT L2 distance to random jitter of the sample locations.
np.random.seed(0)  # the jitter is random; seed it so the figure is reproducible

n = 50
I = np.linspace(0, 1, n)
angle_x = np.pi / 2
angle_y = np.pi * 2

# The reference curve does not depend on the jitter size, so build it once.
c = move_rotation_origin_to_identity(create_so3_curve(n, I, angle_x, angle_y))
q = skew_matrix_to_vector_single_rotation(SRVT_single_rotation(I, c))

pert_lst = []
err_lst = []
for const in [4, 8, 16, 32, 64]:
    # Standard deviation of the jitter: roughly half a grid step divided by const.
    pert = (I[1] - I[0]) * 0.50001 / const

    # Jitter the interior nodes only; the end points stay at 0 and 1.
    I_p = I.copy()
    I_p[1:-1] = I_p[1:-1] + np.random.normal(0, pert, n - 2)
    assert np.diff(I_p).min() > 0, "I_p is not strictly increasing"

    c_p = move_rotation_origin_to_identity(create_so3_curve(n, I_p, angle_x, angle_y))

    # NOTE(review): the jittered samples are paired with the uniform grid I
    # (not I_p) when computing the SRVT -- confirm this is the intended comparison.
    q_p = skew_matrix_to_vector_single_rotation(SRVT_single_rotation(I, c_p))

    err = L2_metric(q, q_p, I, I)
    print(err)

    pert_lst.append(pert)
    err_lst.append(err)

plt.loglog(pert_lst, err_lst, 'o-')

# Slope-1 reference line through (pert_lst[0], err_lst[0] + pert_lst[0]).
x = np.linspace(pert_lst[0], pert_lst[-1], 100)
y = x * (err_lst[0] + x[0]) / x[0]
plt.loglog(x, y, 'r--')

plt.xlabel("Perturbation")
plt.ylabel("Error")
plt.title("Error vs Perturbation")
plt.show()

# plot_rotations(c)
# plot_rotations(c_p)


In [None]:
# Recover a known smooth reparameterization I_p = I + a * I * (I - 1) of the
# reference curve and record the SRVT L2 error that remains after alignment.
n = 20
I = np.linspace(0, 1, n)
angle_scale = 1.4

# The reference curve is the same for every warp strength, so build it once.
c = move_rotation_origin_to_identity(create_advanced_so3_curve(n, I, angle_scale))
q = skew_matrix_to_vector_single_rotation(SRVT_single_rotation(I, c))

pert_lst = []
err_lst = []
for a in [1, 1/2, 1/4, 1/8]:
    # Quadratic warp that keeps the end points 0 and 1 fixed; a sets its strength.
    I_p = I + a * I * (I - 1)
    pert = a
    assert np.diff(I_p).min() > 0, "I_p is not strictly increasing"

    c_p = move_rotation_origin_to_identity(create_advanced_so3_curve(n, I_p, angle_scale))
    q_p = skew_matrix_to_vector_single_rotation(SRVT_single_rotation(I, c_p))

    I_new = find_optimal_diffeomorphism(q, q_p, I, I, depth = int(n / 4))
    c_new = reparameterize_rotation(I_new, I, c_p)
    q_new = skew_matrix_to_vector_single_rotation(SRVT_single_rotation(I, c_new))

    err_before = L2_metric(q, q_p, I, I)
    err_after = L2_metric(q, q_new, I, I)
    print(f"Error: {err_before} -> {err_after}")
    print(f"Pert: {pert}")

    pert_lst.append(pert)
    err_lst.append(err_after)

plt.loglog(pert_lst, err_lst, 'o-')

# Slope-1 reference line through (pert_lst[0], err_lst[0] + pert_lst[0]).
x = np.linspace(pert_lst[0], pert_lst[-1], 100)
y = x * (err_lst[0] + x[0]) / x[0]
plt.loglog(x, y, 'r--')
plt.xlabel("Perturbation")
plt.ylabel("Error")
plt.title("Error vs Perturbation")
plt.show()

# Reference curve and the warped curve from the last iteration of the loop.
plot_rotations(c)
plot_rotations(c_p)


In [None]:
def perturb_rotation(R, eps):
    """Left-multiply ``R`` by a random rotation with angle drawn from ``[-eps, eps)``.

    The axis is a standard-normal 3-vector scaled to unit length and the angle is
    drawn with ``np.random.uniform(-eps, eps)``. The small rotation is assembled
    with Rodrigues' formula ``I + sin(theta) K + (1 - cos(theta)) K @ K``, where
    ``K`` is the cross-product matrix of the axis. Uses the global NumPy random state.
    """
    # Draw the axis first and the angle second (the order fixes the random stream).
    direction = np.random.normal(0, 1, 3)
    direction = direction / np.linalg.norm(direction)
    angle = np.random.uniform(-eps, eps)

    ux, uy, uz = direction
    cross_matrix = np.array([[0, -uz, uy],
                             [uz, 0, -ux],
                             [-uy, ux, 0]])

    small_rotation = (
        np.eye(3)
        + np.sin(angle) * cross_matrix
        + (1 - np.cos(angle)) * cross_matrix.dot(cross_matrix)
    )
    return small_rotation.dot(R)

In [None]:
# Perturb every sample of the curve by an independent random rotation of size
# at most eps and compare the SRVT L2 error before and after reparameterization.
np.random.seed(1)

n = 40
I = np.linspace(0, 1, n)
angle_scale = 1.4

pert_lst = []
err_lst_org = []
err_lst_rep = []
for eps in [.5, .25, .125, .0625]:
    c = create_advanced_so3_curve(n, I, angle_scale)
    # A comprehension keeps the loop variable local. Binding a module-level `R`
    # here would overwrite the `Rotation as R` import that rotate() relies on.
    c_p = np.array([perturb_rotation(rotation, eps) for rotation in c])

    print(is_SO3(c))
    print(is_SO3(c_p))

    c = move_rotation_origin_to_identity(c)
    c_p = move_rotation_origin_to_identity(c_p)

    print(is_SO3(c))
    print(is_SO3(c_p))

    q = skew_matrix_to_vector_single_rotation(SRVT_single_rotation(I, c))
    q_p = skew_matrix_to_vector_single_rotation(SRVT_single_rotation(I, c_p))

    I_new = find_optimal_diffeomorphism(q, q_p, I, I, depth = int(n / 4))
    c_new = reparameterize_rotation(I_new, I, c_p)
    q_new = skew_matrix_to_vector_single_rotation(SRVT_single_rotation(I, c_new))

    err_org = L2_metric(q, q_p, I, I)
    err_rep = L2_metric(q, q_new, I, I)
    print(f"Error: {err_org} -> {err_rep}")
    print(f"Pert: {eps}")

    pert_lst.append(eps)
    err_lst_org.append(err_org)
    err_lst_rep.append(err_rep)

plt.loglog(pert_lst, err_lst_org, 'o-', label='before reparameterization')
plt.loglog(pert_lst, err_lst_rep, 'o-', label='after reparameterization')

# Slope-1 reference line anchored on this cell's own data (err_lst_org), not on
# a list left over from another cell.
x = np.linspace(pert_lst[0], pert_lst[-1], 100)
y = x * (err_lst_org[0] + x[0]) / x[0]
plt.loglog(x, y, 'r--', label='slope 1 reference')
plt.xlabel("Perturbation")
plt.ylabel("Error")
plt.title("Error vs Perturbation")
plt.legend()
plt.show()

# Curves from the last iteration of the loop (smallest eps).
plot_rotations(c)
plot_rotations(c_p)
plot_rotations(c_new)


In [None]:
def local_cost_regulated(k, l, i, j, q0, q1, I, lambda_reg, gamma_reg = 1):
    """Local cost of matching grid cell (k, i) to (l, j) with a magnitude penalty.

    Returns ``local_cost(k, l, i, j, q0, q1, I)`` plus ``lambda_reg`` times the
    mean of ``|s * q1[l:j]| ** gamma_reg``, where
    ``s = sqrt((I[j] - I[l]) / (I[i] - I[k]))``.

    Parameters
    ----------
    k, l, i, j : int
        Grid indices forwarded to ``local_cost`` and used to index ``I`` and ``q1``.
    q0, q1 : array-like
        Forwarded to ``local_cost``; ``q1[l:j]`` is also used for the penalty.
    I : array-like
        Parameter grid indexed by ``k``, ``l``, ``i`` and ``j``.
    lambda_reg : float
        Weight of the penalty term.
    gamma_reg : float, optional
        Exponent applied to the absolute value of the rescaled ``q1`` segment.

    Notes
    -----
    Only the magnitude penalty is added to the L2 cost; no penalty on how far the
    warp deviates from the identity enters the returned value.
    """
    cost_l2 = local_cost(k, l, i, j, q0, q1, I)

    # Rescale the q1 segment by the square root of the ratio of interval lengths.
    scale = np.sqrt((I[j] - I[l]) / (I[i] - I[k]))
    rescaled_q1 = scale * q1[l:j]

    # Penalise large magnitudes of the rescaled segment.
    magnitude_penalty = np.mean(np.abs(rescaled_q1) ** gamma_reg)

    return cost_l2 + lambda_reg * magnitude_penalty

def find_optimal_diffeomorphism_reg(q0, q1, I0, I1, depth, lambda_reg):
    """Find a reparameterization using the regularised local cost.

    ``q0`` and ``q1`` are first brought onto a shared grid with
    ``create_shared_parameterization``. ``dynamic`` is then run with
    ``local_cost_regulated`` (bound to that grid and to ``lambda_reg``) and
    ``depth``, and a path ending at ``(M - 1, M - 1)`` is rebuilt with
    ``reconstruct``. The path indices are divided by ``M - 1`` and linearly
    interpolated at ``I1``.

    Returns
    -------
    np.ndarray
        The interpolated values at ``I1`` (same length as ``I1``).
    """
    # Imported locally so the function does not rely on `functools` being leaked
    # into the namespace by one of the star imports.
    import functools

    I, q0_new, q1_new = create_shared_parameterization(q0, q1, I0, I1)
    M = I.shape[0]
    local_cost_partial = functools.partial(
        local_cost_regulated, q0 = q0_new, q1 = q1_new, I = I, lambda_reg = lambda_reg
    )

    # The second return value of dynamic() is not needed here.
    pointers, _ = dynamic(local_cost_partial, M, depth)
    path = reconstruct(pointers, M-1, M-1)

    # Construct the reparameterization: normalise the path indices by M - 1.
    x = np.array([p[0] for p in path])/float(M-1)
    y = np.array([p[1] for p in path])/float(M-1)

    I_new = np.interp(I1, x, y)
    return I_new

In [None]:
# Sweep the regularisation weight lambda and compare the aligned SRVT L2 error
# against the unregularised alignment.
np.random.seed(1)

pert_lst = []
err_lst_org = []
err_lst_rep = []

eps = 0.1
n = 20
I = np.linspace(0, 1, n)
angle_scale = 1.4

c = create_advanced_so3_curve(n, I, angle_scale)
# A comprehension keeps the loop variable local. Binding a module-level `R`
# here would overwrite the `Rotation as R` import that rotate() relies on.
c_p = np.array([perturb_rotation(rotation, eps) for rotation in c])

print(is_SO3(c))
print(is_SO3(c_p))


c = move_rotation_origin_to_identity(c)
c_p = move_rotation_origin_to_identity(c_p)

print(is_SO3(c))
print(is_SO3(c_p))

q = skew_matrix_to_vector_single_rotation(SRVT_single_rotation(I, c))
q_p = skew_matrix_to_vector_single_rotation(SRVT_single_rotation(I, c_p))

# Unregularised baseline, using the same search depth as the sweep below.
I_new = find_optimal_diffeomorphism(q, q_p, I, I, depth = int(n / 4))
c_new = reparameterize_rotation(I_new, I, c_p)
q_new = skew_matrix_to_vector_single_rotation(SRVT_single_rotation(I, c_new))

# These two errors do not depend on lambda, so compute them once.
err_org = L2_metric(q, q_p, I, I)
err_rep = L2_metric(q, q_new, I, I)

lam_lst = np.logspace(-3, 3, 50)
error_lst = []
for lambda_reg in lam_lst:
    I_new_reg = find_optimal_diffeomorphism_reg(q, q_p, I, I, depth = int(n / 4), lambda_reg = lambda_reg)
    c_new_reg = reparameterize_rotation(I_new_reg, I, c_p)
    q_new_reg = skew_matrix_to_vector_single_rotation(SRVT_single_rotation(I, c_new_reg))
    err_reg = L2_metric(q, q_new_reg, I, I)
    error_lst.append(err_reg)
    print(f"Error: {err_org} -> {err_rep} -> {err_reg}")
print(f"Pert: {eps}")

plt.loglog(lam_lst, error_lst, 'o-', label='regularised')
# Horizontal line at the unregularised error for comparison.
plt.loglog(np.linspace(lam_lst[0], lam_lst[-1], 100), np.ones(100) * err_rep, label='unregularised')
plt.xlabel("Lambda")
plt.ylabel("Error")
plt.legend()
plt.show()