In [11]:
import numpy as np
from pyquaternion import Quaternion
import math

In [75]:
def NewtonRaphson(proposed_eigen, a, b, c, d, sigma):
    return proposed_eigen - (proposed_eigen**4 - (a+b)*(proposed_eigen**2) - c*proposed_eigen + (a*b + c*sigma - d)) \
    / (4*(proposed_eigen**3) - 2*(a+b)*proposed_eigen - c)

def QUEST(observation_vectors, reference_vectors):
    iters = 10
    #all vectors should be unit vectors of shape (3, )
    vector_count = len(observation_vectors)
    weights = np.random.random(vector_count)
    weights /= weights.sum()

    B = np.zeros((3, 3))
    Z = np.zeros(3)
    for i in range(vector_count):
        B += weights[i] * np.outer(observation_vectors[i], reference_vectors[i])
        Z += weights[i] * np.cross(observation_vectors[i],reference_vectors[i])
    S = B + B.T

    delta = np.linalg.det(S)
    kappa = np.trace(np.linalg.inv(S).T * delta)
    sigma = 1/2*np.trace(S)
    d = Z@S@S@Z.T
    c = delta + Z@S@Z.T
    b = sigma**2 + Z@Z.T
    a = sigma**2 - kappa

    proposed_eigen = 1
    for i in range(iters):
        print(proposed_eigen)
        proposed_eigen = NewtonRaphson(proposed_eigen, a, b, c, d, sigma)

    alpha = proposed_eigen**2 - sigma**2 + kappa
    beta = proposed_eigen - sigma
    gamma = (proposed_eigen + sigma)*alpha - delta
    X = -(alpha*np.eye(3) + beta*S + S@S)@Z#look idfk but my vectors were always off by a negative sign

    return 1/math.sqrt(gamma**2 + (np.linalg.norm(X))**2) * Quaternion(scalar = gamma, vector = X)
    
    

In [76]:
real_1 = np.array([1, 0, 0])
real_2 = np.array([0, 1, 0])
real_3 = np.array([0, 0, 1])

random_4 = []
while(True):
    random_3 = np.random.random(3)
    if(np.linalg.norm(random_3) < 1):
        random_4 = np.array([random_3[0], random_3[1], random_3[2], math.sqrt(1 - np.linalg.norm(random_3)**2)])
        break

quat = Quaternion(random_4)
observed_v1 = quat.rotate(real_1)
observed_v2 = quat.rotate(real_2)
observed_v3 = quat.rotate(real_3)
observed = np.array([observed_v1, observed_v2, observed_v3])
real = np.array([real_1, real_2, real_3])
quat_received = QUEST(observed, real)
print(quat_received.rotate(real_1))
print(observed_v1)

1
0.9999999999999999
0.9999999999999999
0.9999999999999999
0.9999999999999999
0.9999999999999999
0.9999999999999999
0.9999999999999999
0.9999999999999999
0.9999999999999999
[-0.09509865  0.81092509 -0.57737055]
[-0.09509865  0.81092509 -0.57737055]
