In [1]:
from scipy.io import loadmat
import matplotlib.pyplot as plt
import cv2
import numpy as np
import math
import random
from scipy.linalg import null_space, inv, svd, det

In [2]:
def ComputeReprojectionError(P_1, P_2, X_j, x_1j, x_2j):
    
    r_1 = np.array([
        x_1j[0] - (P_1[0, :] @ X_j) / (P_1[2, :] @ X_j),
        x_1j[1] - (P_1[1, :] @ X_j) / (P_1[2, :] @ X_j)
    ])

    r_2 = np.array([
        x_2j[0] - (P_2[0, :] @ X_j) / (P_2[2, :] @ X_j),
        x_2j[1] - (P_2[1, :] @ X_j) / (P_2[2, :] @ X_j)
    ])

    rs    = np.hstack((r_1, r_2))
    error = np.linalg.norm(r_1)**2 + np.linalg.norm(r_2)**2

    return error, rs

def LinearizeReprojErr(P_1, P_2, X_j, x_1j, x_2j):
    
    _, rs = ComputeReprojectionError(P_1, P_2, X_j, x_1j, x_2j)

    J_1 = np.zeros((2, 4))
    J_2 = np.zeros((2, 4))

    J_1[0, :] = (((P_1[0, :] @ X_j) / (P_1[2, :] @ X_j)**2) * P_1[2, :]) - ((1 / (P_1[2, :] @ X_j)) * P_1[0, :])
    J_1[1, :] = (((P_1[1, :] @ X_j) / (P_1[2, :] @ X_j)**2) * P_1[2, :]) - ((1 / (P_1[2, :] @ X_j)) * P_1[1, :])

    J_2[0, :] = (((P_2[0, :] @ X_j) / (P_2[2, :] @ X_j)**2) * P_2[2, :]) - ((1 / (P_2[2, :] @ X_j)) * P_2[0, :])
    J_2[1, :] = (((P_2[1, :] @ X_j) / (P_2[2, :] @ X_j)**2) * P_2[2, :]) - ((1 / (P_2[2, :] @ X_j)) * P_2[1, :])

    J = np.vstack((J_1, J_2))

    return rs, J

def ComputeUpdate(r, J, mu):
    
    JTJ = J.T @ J
    JTr = J.T @ r
    delta_X_j = -(np.linalg.inv(JTJ + mu * np.eye(JTJ.shape[0])) @ JTr)
    
    return delta_X_j

def Refine3DPoints(P_1, P_2, X, x_1, x_2, max_iterations=50, initial_mu=0.1):

    n = X.shape[0]
    m = X.shape[1]
    refined_X = np.zeros((n, m))
    mu = initial_mu

    for j in range(m):
        X_j = X[:, j]
        x_1j = x_1[:, j]
        x_2j = x_2[:, j]

        prev_error, _ = ComputeReprojectionError(P_1, P_2, X_j, x_1j, x_2j)

        for i in range(max_iterations):
            residuals, J = LinearizeReprojErr(P_1, P_2, X_j, x_1j, x_2j)

            delta_X_j = ComputeUpdate(residuals, J, mu)

            updated_X_j = X_j + delta_X_j

            updated_error, _ = ComputeReprojectionError(P_1, P_2, updated_X_j, x_1j, x_2j)

            if updated_error < prev_error:
                X_j = updated_X_j
                prev_error = updated_error
                mu /= 10
            else:
                mu *= 10
                
            mu = max(min(mu, 1e10), 1e-10)

        refined_X[:, j] = X_j

    return refined_X


def ComputeAllReprojectionErrors(P_1, P_2, X, x_1, x_2):
    errors = []
    for j in range(X.shape[1]):
        X_j = X[:, j]
        x_1j = x_1[:, j]
        x_2j = x_2[:, j]
        error, _ = ComputeReprojectionError(P_1, P_2, X_j, x_1j, x_2j)
        errors.append(error)
    return errors


In [3]:
data = loadmat('data/compEx3data')
P = data["P"][0]
P_1 = P[0]
P_2 = P[1]

X = data["X"]
x = data["x"][0]
x_1 = x[0]
x_2 = x[1]

In [4]:
r_X = Refine3DPoints(P_1, P_2, X, x_1, x_2)

In [5]:
error_before_LM, _ = ComputeReprojectionError(P_1, P_2, X, x_1, x_2)
error_after_LM, _  = ComputeReprojectionError(P_1, P_2, r_X, x_1, x_2)

before = ComputeAllReprojectionErrors(P_1, P_2, X, x_1, x_2)
after = ComputeAllReprojectionErrors(P_1, P_2, r_X, x_1, x_2)

median_error_before_LM = np.median(before)
median_error_after_LM = np.median(after)

print(f"Total reprojection error before refining the 3D points: {error_before_LM:.4f}")
print(f"Total reprojection error after refining the 3D points: {error_after_LM:.4f}")
print(f"Median reprojection error before refining the 3D points: {median_error_before_LM:.4f}")
print(f"Median reprojection error after refining the 3D points: {median_error_after_LM:.4f}")

Total reprojection error before refining the 3D points: 22354.2450
Total reprojection error after refining the 3D points: 21566.4200
Median reprojection error before refining the 3D points: 11.6599
Median reprojection error after refining the 3D points: 11.1979


In [6]:
%matplotlib qt
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(X[0, :], X[1, :], X[2, :], s=0.5, c="b", label="Original Points", edgecolors="b", facecolors="b")
ax.scatter(r_X[0, :], r_X[1, :], r_X[2, :], s=10, edgecolors="r", facecolors="none", label="Refined Points", linewidth=0.8)
ax.set_title("3D Points Before and After Refinement")
ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.set_zlabel("Z")
ax.legend()
plt.show()