In [None]:
import cv2
import os
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import minimize, least_squares
import csv
from sklearn.neighbors import NearestNeighbors
import networkx as nx
from scipy.spatial.transform import Rotation as R

In [None]:

#### READ THE IMU FILE (SPACE-SEPARATED ASCII COLUMNS) ####
def read_imu_data(file_path):
    """Read a whitespace-separated ASCII IMU log into an array.

    The first line is treated as a header and skipped. Every remaining
    non-blank line must start with four numeric columns
    (timestamp, heading, pitch, roll); any extra columns are ignored.

    Args:
        file_path: Path to the IMU log file.

    Returns:
        np.ndarray of shape (N, 4) with columns
        [timestamp, heading, pitch, roll]. N may be 0; the result is always
        2-D so callers can index ``imu_data[:, 0]``.

    Raises:
        ValueError: if a data line has fewer than four columns or a value
            cannot be parsed as a float.
    """
    imu_data = []
    with open(file_path, 'r') as file:
        next(file, None)  # Skip header (also tolerates a completely empty file)
        for line_number, line in enumerate(file, start=2):
            # str.split() treats any run of spaces/tabs as one separator, so
            # aligned columns and leading whitespace parse correctly.
            fields = line.split()
            if not fields:
                continue  # blank line, e.g. a trailing newline at end of file
            if len(fields) < 4:
                raise ValueError(
                    f"{file_path}:{line_number}: expected at least 4 columns, got {len(fields)}"
                )
            imu_data.append([float(value) for value in fields[:4]])
    # reshape keeps a (0, 4) shape when there are no data rows.
    return np.array(imu_data, dtype=float).reshape(-1, 4)

#### PARSE TRIGGER TIMESTAMPS FROM IMAGE FILENAMES. NAMING FORMAT: {SERIAL NUMBER}_{FALLING EDGE OF TRIGGER SIGNAL}_{RISING EDGE OF TRIGGER SIGNAL}.jpg ####
def extract_timestamps_from_filename(filename):
    """Parse the trigger timestamps encoded in an image filename.

    Expected format: ``{SERIAL}_{FALLING_EDGE}_{RISING_EDGE}.<ext>``.
    The extension is removed with ``os.path.splitext`` *before* splitting on
    '_', so fractional timestamps such as ``SN_10.25_10.75.jpg`` keep their
    decimals (splitting the last field on '.' would truncate 10.75 to 10).

    Args:
        filename: Image file name (a leading directory path is ignored).

    Returns:
        (falling_edge, rising_edge) as floats.

    Raises:
        IndexError: if the name has fewer than three '_'-separated fields.
        ValueError: if a timestamp field is not numeric.
    """
    stem, _extension = os.path.splitext(os.path.basename(filename))
    parts = stem.split('_')
    falling_edge = float(parts[1])
    rising_edge = float(parts[2])
    return falling_edge, rising_edge

#### FIND THE IMU SAMPLE THAT MATCHES AN IMAGE'S TRIGGER WINDOW (TIME SYNC) ####
def find_closest_imu_data(imu_data, falling_edge, rising_edge):
    """Return the IMU sample recorded during an image's trigger window.

    Only samples whose timestamp (column 0) lies in
    ``[falling_edge, rising_edge]`` are eligible. Among those, the one closest
    to ``falling_edge`` is returned. Restricting the search to the window
    first means a frame is not rejected just because the nearest sample
    overall lies slightly before the falling edge while a valid sample exists
    inside the window.

    Args:
        imu_data: 2-D array whose first column holds timestamps
            (as produced by ``read_imu_data``).
        falling_edge: Start of the trigger window.
        rising_edge: End of the trigger window.

    Returns:
        The matching row of ``imu_data``, or None if ``imu_data`` is empty
        or no sample falls inside the window.
    """
    imu_data = np.asarray(imu_data)
    if imu_data.size == 0:
        return None
    times = imu_data[:, 0]
    candidates = np.flatnonzero((times >= falling_edge) & (times <= rising_edge))
    if candidates.size == 0:
        return None
    best = candidates[np.argmin(np.abs(times[candidates] - falling_edge))]
    return imu_data[best]
        
#### DETECT ORB KEYPOINTS AND COMPUTE DESCRIPTORS (OPENCV). SCALE FACTOR MUST BE > 1: LARGER VALUES ARE FASTER BUT COARSER, VALUES NEAR 1 ARE FINER BUT SLOWER ####
def orb_detector_descriptor(image, nfeatures=5000, scale_factor=1.2, nlevels=8):
    """Detect ORB keypoints in ``image`` and compute their descriptors.

    Args:
        image: Image handed directly to OpenCV's ``detectAndCompute``
            (this notebook passes a CLAHE-enhanced grayscale image).
        nfeatures: Maximum number of keypoints ORB retains.
        scale_factor: Pyramid decimation ratio (> 1).
        nlevels: Number of pyramid levels.

    Returns:
        (keypoints, descriptors) as produced by ``detectAndCompute``.
    """
    detector = cv2.ORB_create(nfeatures=nfeatures, scaleFactor=scale_factor, nlevels=nlevels)
    return detector.detectAndCompute(image, None)

#### IDENTIFY IF THE CURRENT FRAME IS A KEYFRAME ####
def is_keyframe(current_kp, last_keyframe_kp, min_matches=50):
    """Decide whether the current frame should be treated as a keyframe.

    Returns True immediately when there is no previous keyframe
    (``last_keyframe_kp`` is empty). Otherwise returns True when the current
    frame has at least ``min_matches`` keypoints.

    NOTE(review): despite its name, ``min_matches`` is compared against the
    number of keypoints in the current frame, not against a count of
    descriptor matches with the last keyframe.
    """
    # Short-circuit keeps the original order: current_kp is only inspected
    # when a previous keyframe exists.
    return len(last_keyframe_kp) == 0 or len(current_kp) >= min_matches


#### REFINE THE ESSENTIAL MATRIX BY MINIMIZING THE EPIPOLAR ERROR OVER KEYPOINT MATCHES BETWEEN 2 FRAMES ####
def refine_E(E, p1, p2, K):
    """Refine an essential matrix by minimizing the epipolar residual.

    The objective is sum_i |x2_i^T E x1_i|, where x1/x2 are the pixel points
    ``p1``/``p2`` converted to normalized camera coordinates with
    ``cv2.undistortPoints`` (no distortion coefficients).

    Because the residual is linear in E, an unconstrained minimizer can
    shrink E toward the meaningless zero matrix. To prevent this, E is
    normalized to unit Frobenius norm inside the objective. The optimum is
    then projected back onto the essential-matrix manifold (singular values
    1, 1, 0).

    Args:
        E: Initial 3x3 essential matrix (e.g. from ``cv2.findEssentialMat``).
        p1, p2: Corresponding pixel coordinates, reshapeable to (N, 2).
        K: 3x3 camera intrinsic matrix.

    Returns:
        Refined 3x3 essential matrix with singular values (1, 1, 0).
    """
    # Normalization does not depend on E: compute it once rather than on
    # every Nelder-Mead evaluation.
    x1 = cv2.undistortPoints(np.asarray(p1, dtype=np.float64).reshape(-1, 1, 2), K, None).reshape(-1, 2)
    x2 = cv2.undistortPoints(np.asarray(p2, dtype=np.float64).reshape(-1, 1, 2), K, None).reshape(-1, 2)
    ones = np.ones((len(x1), 1))
    x1_h = np.hstack([x1, ones])
    x2_h = np.hstack([x2, ones])

    def objective(E_vec):
        scale = np.linalg.norm(E_vec)
        if scale == 0.0:
            return np.inf  # zero matrix is a degenerate, meaningless solution
        E_mat = (E_vec / scale).reshape(3, 3)
        # x2_i^T E x1_i for every correspondence at once.
        residuals = np.einsum('ij,jk,ik->i', x2_h, E_mat, x1_h)
        return np.abs(residuals).sum()

    E0 = np.asarray(E, dtype=np.float64).reshape(3, 3)
    result = minimize(objective, E0.ravel(), method='Nelder-Mead')

    # Enforce the essential-matrix structure: two equal singular values, one zero.
    U, _, Vt = np.linalg.svd(result.x.reshape(3, 3))
    return U @ np.diag([1.0, 1.0, 0.0]) @ Vt

#### MOTION-ONLY BUNDLE ADJUSTMENT (REFINES THE CAMERA POSE WITH 3D POINTS HELD FIXED) ####
def motion_only_bundle_adjustment(R, t, points_3d, points_2d, K, loss='soft_l1', f_scale=1.0):
    """Refine a camera pose against fixed 3D points (motion-only BA).

    Minimizes the robust reprojection error between ``points_3d`` projected
    with pose (R, t) and intrinsics K, and the observed ``points_2d``. Only
    the pose is optimized; 3D points and K stay fixed.

    Args:
        R: Initial 3x3 rotation, in the convention used by ``cv2.projectPoints``.
        t: Initial translation, any shape with 3 elements.
        points_3d: 3D points, reshapeable to (N, 3).
        points_2d: Observed pixel positions, reshapeable to (N, 2). The
            (N, 1, 2) layout (as used for ``src_pts``/``dst_pts`` in this
            notebook) is accepted. Without the reshape, (N, 2) - (N, 1, 2)
            would broadcast to an (N, N, 2) residual and silently optimize
            the wrong objective.
        K: 3x3 camera intrinsic matrix.
        loss: Robust loss forwarded to ``scipy.optimize.least_squares``.
        f_scale: Inlier/outlier soft margin forwarded to ``least_squares``.

    Returns:
        (R_opt, t_opt): refined 3x3 rotation and (3, 1) translation.
    """
    object_points = np.asarray(points_3d, dtype=np.float64).reshape(-1, 3)
    observed = np.asarray(points_2d, dtype=np.float64).reshape(-1, 2)

    def project(rvec, tvec):
        projected, _ = cv2.projectPoints(object_points, rvec, tvec, K, None)
        return projected.reshape(-1, 2)

    def residuals(params):
        # params = [rvec (3, Rodrigues), tvec (3)]
        return (project(params[:3], params[3:]) - observed).ravel()

    rvec, _ = cv2.Rodrigues(np.asarray(R, dtype=np.float64))
    params0 = np.hstack((rvec.ravel(), np.asarray(t, dtype=np.float64).ravel()))

    result = least_squares(residuals, params0, loss=loss, f_scale=f_scale, verbose=0)

    R_opt, _ = cv2.Rodrigues(result.x[:3])
    t_opt = result.x[3:].reshape(3, 1)
    return R_opt, t_opt


#### DETECT LOOP CLOSURES: SPATIALLY CLOSE KEYFRAMES WHOSE DESCRIPTORS MATCH WELL ARE TREATED AS A REVISITED PLACE ####
def detect_loop_closures(trajectory, descriptors, distance_threshold=5.0, similarity_threshold=0.7,
                         min_frame_gap=10, matcher=None):
    """Find pairs of keyframes that revisit the same place.

    A pair (i, j) with ``j > i + min_frame_gap`` is reported when both
    conditions hold:
      1. their positions are closer than ``distance_threshold``, and
      2. the fraction of frame i's descriptors whose k=2 matches against
         frame j pass Lowe's ratio test (0.7) exceeds ``similarity_threshold``.

    Args:
        trajectory: Array-like of positions, one row per keyframe.
        descriptors: Per-keyframe ORB descriptor arrays, aligned with
            ``trajectory``. None entries are skipped.
        distance_threshold: Maximum spatial distance between candidates.
        similarity_threshold: Minimum ratio-test pass rate.
        min_frame_gap: Minimum index separation. It keeps temporally adjacent
            keyframes from being reported as loops.
        matcher: Object with OpenCV's ``knnMatch`` interface. Defaults to a
            FLANN LSH matcher with the same parameters as the main cell's
            ``flann``, so the function no longer depends on that global.

    Returns:
        List of (i, j) index tuples with i < j.
    """
    positions = np.asarray(trajectory, dtype=float)
    if len(positions) < 2:
        return []
    if matcher is None:
        matcher = cv2.FlannBasedMatcher(
            dict(algorithm=6, table_number=6, key_size=12, multi_probe_level=1),  # 6 = FLANN_INDEX_LSH
            dict(checks=50),
        )

    # A radius query returns every neighbour in range. A 1-nearest-neighbour
    # query against the fitted set only ever returns the query point itself,
    # so it could never yield a loop candidate.
    nn = NearestNeighbors(radius=distance_threshold, metric='euclidean')
    nn.fit(positions)
    neighbour_dists, neighbour_idxs = nn.radius_neighbors(positions)

    loop_closures = []
    for i, (dists, idxs) in enumerate(zip(neighbour_dists, neighbour_idxs)):
        for j, distance in sorted(zip(idxs, dists)):
            if j <= i + min_frame_gap or distance >= distance_threshold:
                continue
            desc_i, desc_j = descriptors[i], descriptors[j]
            if desc_i is None or desc_j is None or len(desc_i) == 0 or len(desc_j) < 2:
                continue
            matches = matcher.knnMatch(desc_i, desc_j, k=2)
            if len(matches) == 0:
                continue
            # LSH matching can return fewer than two candidates for a query.
            good_matches = [pair[0] for pair in matches
                            if len(pair) == 2 and pair[0].distance < 0.7 * pair[1].distance]
            if len(good_matches) / len(matches) > similarity_threshold:
                loop_closures.append((i, int(j)))

    return loop_closures

#### CREATE THE POSE GRAPH FOR OPTIMIZATION ####
def create_pose_graph(trajectory, loop_closures):
    """Build an undirected pose graph over the trajectory.

    Nodes are pose indices ``0 .. len(trajectory) - 1``. Consecutive poses
    are joined by odometry edges with weight 1. Each (i, j) in
    ``loop_closures`` adds an edge with weight 0.1, so loop constraints count
    less than odometry.

    Args:
        trajectory: Sequence of poses; only its length is used.
        loop_closures: Iterable of (i, j) pose-index pairs.

    Returns:
        networkx.Graph with a ``weight`` attribute on every edge.
    """
    g = nx.Graph()
    # Add every pose explicitly so that a trajectory without odometry edges
    # (a single pose) still has its node in the graph.
    g.add_nodes_from(range(len(trajectory)))
    g.add_edges_from((i, i + 1, {'weight': 1}) for i in range(len(trajectory) - 1))
    g.add_edges_from((i, j, {'weight': 0.1}) for i, j in loop_closures)  # Lower weight for loop closures
    return g

#### POSE GRAPH OPTIMIZATION ####
def optimize_pose_graph(trajectory, pose_graph, anchor_weight=100.0):
    """Translation-only pose-graph optimization.

    Each edge (u, v) with weight w contributes the residual
        w * ((x_v - x_u) - d_uv)
    where d_uv is the measured displacement:
      * consecutive poses (|u - v| == 1, odometry edges): the displacement
        in the input ``trajectory``;
      * any other edge (loop closure): zero, i.e. both poses are believed
        to be at the same place.
    Pose 0 is also tied to its input position with ``anchor_weight``, to
    remove the global-translation (gauge) freedom.

    Penalizing (x_v - x_u) alone, without subtracting a measurement, would be
    minimized by collapsing every pose onto a single point.

    Args:
        trajectory: (N, 3) array-like of positions.
        pose_graph: Graph whose ``edges(data=True)`` yields (u, v, attrs) with
            integer node ids in [0, N). ``attrs['weight']`` defaults to 1.
        anchor_weight: Weight of the prior that holds pose 0 in place.

    Returns:
        (N, 3) array of optimized positions. With only odometry edges every
        residual is already zero, so the input trajectory is returned.
    """
    initial = np.asarray(trajectory, dtype=float).reshape(-1, 3)
    if len(initial) == 0:
        return initial.copy()

    edges = list(pose_graph.edges(data=True))
    u_idx = np.array([u for u, _, _ in edges], dtype=int)
    v_idx = np.array([v for _, v, _ in edges], dtype=int)
    weights = np.array([data.get('weight', 1.0) for _, _, data in edges], dtype=float)
    is_odometry = np.abs(v_idx - u_idx) == 1
    measured = np.where(is_odometry[:, None], initial[v_idx] - initial[u_idx], 0.0)

    def residuals(x):
        poses = x.reshape(-1, 3)
        edge_res = ((poses[v_idx] - poses[u_idx]) - measured) * weights[:, None]
        anchor_res = anchor_weight * (poses[0] - initial[0])
        return np.concatenate([edge_res.ravel(), anchor_res])

    result = least_squares(residuals, initial.ravel())
    return result.x.reshape(-1, 3)


#### CLAHE IMPLEMENTATION TO IMPROVE CONTRAST IN IMAGE FOR BETTER KEYPOINT IDENTIFICATION ####
def enhance_contrast(image, clip_limit=2.0, tile_grid_size=(8, 8)):
    """Apply CLAHE (contrast-limited adaptive histogram equalization).

    Raising local contrast helps ORB find keypoints in low-contrast regions.

    Args:
        image: Single-channel image. This notebook passes the output of
            ``cv2.cvtColor(..., cv2.COLOR_BGR2GRAY)``.
        clip_limit: CLAHE contrast clip limit.
        tile_grid_size: Size of the CLAHE tile grid, as a 2-tuple.

    Returns:
        The contrast-enhanced image returned by ``CLAHE.apply``.
    """
    clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tuple(tile_grid_size))
    return clahe.apply(image)


In [None]:

# Kalman filter initialization.
# State (9): [x, y, z, vx, vy, vz, a0, a1, a2]. F_kalman advances position by
# dt * velocity, and a0..a2 receive the IMU angle columns 1-3
# (heading, pitch, roll, in read_imu_data's column order).
# Measurement (6): [x, y, z, a0, a1, a2].
x_kalman = np.zeros(9)               # initial state estimate
P_kalman = np.identity(9)            # initial state covariance
Q_kalman = 1e-4 * np.identity(9)     # process noise covariance
R_kalman = 0.01 * np.identity(6)     # measurement noise: 3 position + 3 IMU angles

dt = 2.0  # Time step between filter updates; set it from your frame rate
# Constant-velocity transition: position += dt * velocity
F_kalman = np.identity(9)
F_kalman[[0, 1, 2], [3, 4, 5]] = dt

# Measurement model selects position (state 0-2) and angles (state 6-8)
H_kalman = np.zeros((6, 9))
H_kalman[[0, 1, 2, 3, 4, 5], [0, 1, 2, 6, 7, 8]] = 1

# Main execution
# Camera intrinsics in pixels: fx, fy on the diagonal, principal point (cx, cy)
# in the last column. NOTE(review): hard-coded for this camera; confirm
# against its calibration.
K = np.array([
    [4085.11, 0, 3000],
    [0, 4102.56, 2000],
    [0, 0, 1]
])

# Strip surrounding whitespace and quotes (terminals often add quotes when a
# path is dragged in), then fail fast with a clear message on a bad path.
image_folder = input("Enter the path to the folder containing images: ").strip().strip('"\'')
imu_file = input("Enter the path to the IMU data file: ").strip().strip('"\'')
if not os.path.isdir(image_folder):
    raise FileNotFoundError(f"Image folder not found: {image_folder}")
if not os.path.isfile(imu_file):
    raise FileNotFoundError(f"IMU data file not found: {imu_file}")

# Read IMU data
imu_data = read_imu_data(imu_file)

# Match extensions case-insensitively (cameras commonly write ".JPG").
image_files = sorted(f for f in os.listdir(image_folder) if f.lower().endswith(('.jpg', '.png')))

# Per-frame lists; all are appended together so index i refers to the same frame.
images = []
gray_images = []
keypoints = []
descriptors = []
imu_matches = []

for image_file in image_files:
    img_path = os.path.join(image_folder, image_file)
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread returns None (it does not raise) for unreadable files;
        # skip the frame entirely so the per-frame lists stay aligned.
        print(f"Warning: could not read {img_path}; skipping")
        continue
    gray_img = enhance_contrast(cv2.cvtColor(img, cv2.COLOR_BGR2GRAY))

    images.append(img)
    gray_images.append(gray_img)

    k, d = orb_detector_descriptor(gray_img)
    keypoints.append(k)
    descriptors.append(d)

    falling_edge, rising_edge = extract_timestamps_from_filename(image_file)
    imu_data_for_image = find_closest_imu_data(imu_data, falling_edge, rising_edge)

    if imu_data_for_image is not None:
        imu_matches.append(imu_data_for_image)
    else:
        imu_matches.append([np.nan]*4)  # Placeholder for no match

# FLANN matcher for binary (ORB) descriptors using locality-sensitive hashing
FLANN_INDEX_LSH = 6
index_params = dict(algorithm=FLANN_INDEX_LSH, table_number=6, key_size=12, multi_probe_level=1)
search_params = dict(checks=50)
flann = cv2.FlannBasedMatcher(index_params, search_params)

if not images:
    raise RuntimeError(f"No readable .jpg/.png images found in {image_folder}")

# Pose convention:
#   R_total - rotation mapping world coordinates into the current keyframe's
#             camera coordinates (world frame = first camera's frame);
#   t_total - current keyframe's camera centre in world coordinates.
# Monocular VO: recoverPose's translation has no metric scale.
R_total = np.eye(3)
t_total = np.zeros((3, 1))
trajectory = [np.zeros(3)]  # Start at origin
all_descriptors = [descriptors[0]]

last_keyframe_kp = keypoints[0]
last_keyframe_desc = descriptors[0]
last_keyframe_index = 0

for i in range(1, len(images)):
    if not is_keyframe(keypoints[i], last_keyframe_kp):
        continue
    # Frames without descriptors cannot be matched.
    if last_keyframe_desc is None or descriptors[i] is None or len(descriptors[i]) < 2:
        continue

    matches = flann.knnMatch(last_keyframe_desc, descriptors[i], k=2)
    # Lowe's ratio test. LSH matching can return fewer than 2 neighbours for a
    # query; those entries are skipped instead of failing to unpack.
    good_matches = [pair[0] for pair in matches
                    if len(pair) == 2 and pair[0].distance < 0.75 * pair[1].distance]
    if len(good_matches) < 8:
        continue

    src_pts = np.float32([last_keyframe_kp[m.queryIdx].pt for m in good_matches]).reshape(-1, 1, 2)
    dst_pts = np.float32([keypoints[i][m.trainIdx].pt for m in good_matches]).reshape(-1, 1, 2)

    E, mask = cv2.findEssentialMat(src_pts, dst_pts, K, method=cv2.RANSAC, prob=0.999, threshold=1.0)
    if E is None:
        continue
    E = E[:3]  # findEssentialMat may stack several 3x3 candidates; use the first

    # Refine and decompose with RANSAC inliers only, so outliers cannot bias E.
    inliers = mask.ravel().astype(bool) if mask is not None else np.ones(len(src_pts), dtype=bool)
    if inliers.sum() < 8:
        continue
    src_in, dst_in = src_pts[inliers], dst_pts[inliers]
    E_refined = refine_E(E, src_in, dst_in, K)

    # recoverPose gives x_cur = R_rel @ x_key + t_rel (keyframe -> current camera).
    # Named R_rel/t_rel so the scipy `Rotation as R` import is not shadowed.
    _, R_rel, t_rel, _ = cv2.recoverPose(E_refined, src_in, dst_in, K)
    R_total = R_rel @ R_total
    # New camera centre: C_new = C_old - R_total_new^T @ t_rel
    t_total = t_total - R_total.T @ t_rel

    last_keyframe_kp = keypoints[i]
    last_keyframe_desc = descriptors[i]
    last_keyframe_index = i

    trajectory.append(t_total.flatten())
    all_descriptors.append(descriptors[i])

    # Kalman filter update. A frame without a synchronised IMU sample has NaN
    # angles; fusing NaN would corrupt the state for every later frame, so
    # only the position rows of the measurement model are used in that case.
    imu_angles = np.asarray(imu_matches[i][1:4], dtype=float)
    if np.all(np.isfinite(imu_angles)):
        z = np.hstack([t_total.flatten(), imu_angles])  # Position and IMU angles
        H_used, R_used = H_kalman, R_kalman
    else:
        z = t_total.flatten()  # Position only
        H_used, R_used = H_kalman[:3], R_kalman[:3, :3]
    x_kalman_pred = F_kalman @ x_kalman
    P_kalman_pred = F_kalman @ P_kalman @ F_kalman.T + Q_kalman
    y_kalman = z - H_used @ x_kalman_pred
    S_kalman = H_used @ P_kalman_pred @ H_used.T + R_used
    K_kalman = P_kalman_pred @ H_used.T @ np.linalg.inv(S_kalman)
    x_kalman = x_kalman_pred + K_kalman @ y_kalman
    P_kalman = (np.eye(len(x_kalman)) - K_kalman @ H_used) @ P_kalman_pred

    trajectory[-1] = x_kalman[:3]  # Update the trajectory with the filtered position

# Loop-closure detection on the keyframe trajectory
trajectory_np = np.array(trajectory)
loop_closures = detect_loop_closures(trajectory_np, all_descriptors)

# Pose-graph optimization of the trajectory
pose_graph = create_pose_graph(trajectory_np, loop_closures)
optimized_trajectory = optimize_pose_graph(trajectory_np, pose_graph)

# X against Z for the raw and optimized trajectories
fig, ax = plt.subplots(figsize=(10, 10))
ax.plot(trajectory_np[:, 0], trajectory_np[:, 2], label="Original Trajectory")
ax.plot(optimized_trajectory[:, 0], optimized_trajectory[:, 2], label="Optimized Trajectory", linestyle='dashed')
ax.set_xlabel("X")
ax.set_ylabel("Z")
ax.set_title("Trajectory")
ax.legend()
plt.show()


In [None]:
# X against Y of the keyframe trajectory. `trajectory` is a Python list of
# 3-vectors, so it must be converted to an array before column slicing.
trajectory_arr = np.asarray(trajectory)
plt.figure()
plt.plot(trajectory_arr[:, 0], trajectory_arr[:, 1], 'ro-')
plt.xlabel('X')
plt.ylabel('Y')
plt.title('2D Trajectory')
plt.show()

# Pose-graph topology: odometry chain plus any loop-closure edges
nx.draw(pose_graph)
plt.show()