In [None]:
%load_ext autoreload
%autoreload 2
import tqdm, tqdm.notebook

tqdm.tqdm = tqdm.notebook.tqdm  # notebook-friendly progress bars
from pathlib import Path
import numpy as np

from hloc import (
    extract_features,
    match_features,
    reconstruction,
    visualization,
    pairs_from_exhaustive,
)
from hloc.visualization import plot_images, read_image
from hloc.utils import viz_3d
import torch
import pycolmap

## Load the SfM model and display it in 3D (DISK features + LightGlue matcher)

In [None]:
# Configuration: GPU selection and all input/output paths of this comparison run.
GPU_ID = 3  # CUDA device to use on the shared machine
PROJECT_ROOT = Path("/home/cv_stu_03/Hierarchical-Localization/project")
COMPARISON_DIR = PROJECT_ROOT / "Comparison" / "disk_disk+lightglue"

if torch.cuda.is_available():
    torch.cuda.set_device(GPU_ID)
    # get_device_name() reports the name of the now-current device.
    print("torch.cuda.get_device_name()=", torch.cuda.get_device_name())
else:
    # set_device would raise without CUDA; continue so CPU-only runs still work.
    print("CUDA not available; running without GPU")

images = PROJECT_ROOT / "images" / "MainLibrary" / "ReconstructionDataset"
sfm_dir = COMPARISON_DIR / "sfm"
sfm_pairs = COMPARISON_DIR / "pairs-sfm.txt"
features = COMPARISON_DIR / "features.h5"
matches = COMPARISON_DIR / "matches.h5"

# Reference image names relative to `images`. Sorted because Path.iterdir()
# yields entries in arbitrary order. Only files are kept, so stray
# sub-directories are not treated as images.
references = sorted(
    p.relative_to(images).as_posix() for p in images.iterdir() if p.is_file()
)

In [None]:

# Read the existing reconstruction from sfm_dir and draw it in an interactive 3D figure.
model = pycolmap.Reconstruction(str(sfm_dir))

fig = viz_3d.init_figure()
viz_3d.plot_reconstruction(
    fig,
    model,
    color="rgba(255,0,0,0.5)",
    name="mapping",
    points_rgb=True,
)
fig.show()

### Visualize keypoints

In [None]:
# Show 2D keypoints of reconstruction images colored by visibility
# (n=2: presumably the number of images sampled — confirm in hloc docs).
visualization.visualize_sfm_2d(
    model,
    images,
    color_by="visibility",
    n=2,
)

# Localization

In [None]:
image_path = Path("/home/cv_stu_03/Hierarchical-Localization/project/local_data/images_from_local")

# Query images to localize: names relative to image_path, sorted.
query_names = (p.relative_to(image_path).as_posix() for p in image_path.iterdir())
image_reference = sorted(query_names)

print(len(image_reference), "localization images")

query_images = [read_image(image_path / name) for name in image_reference]
plot_images(query_images, dpi=25)

## Extract and match query features

In [None]:
loc_pairs = Path("/home/cv_stu_03/Hierarchical-Localization/project/pairs-from-loc.txt")
feature_conf = extract_features.confs["disk"]
matcher_conf = match_features.confs["disk+lightglue"]

# 1) DISK features of the query images, written into the same features.h5
#    that holds the reference features.
#    NOTE(review): if a query name equals a reference name, that entry is
#    presumably overwritten — keep query and reference names distinct.
extract_features.main(
    feature_conf,
    image_path,
    image_list=image_reference,
    feature_path=features,
    overwrite=True,
)

# 2) Pair every query image with every reference image.
pairs_from_exhaustive.main(loc_pairs, image_list=image_reference, ref_list=references)

# 3) Match those pairs with LightGlue into matches.h5.
match_features.main(matcher_conf, loc_pairs, features=features, matches=matches, overwrite=True)

In [None]:
import pycolmap
from hloc.localize_sfm import QueryLocalizer, pose_from_cluster

# Localizer settings shared by all queries: RANSAC max error (presumably in
# pixels — confirm) plus refinement of the query camera's focal length and
# extra parameters. Nothing here depends on the query, so it is built once.
conf = {
    "estimation": {"ransac": {"max_error": 12}},
    "refinement": {"refine_focal_length": True, "refine_extra_params": True},
}
localizer = QueryLocalizer(model, conf)

# Model image ids of the reference images. find_image_with_name may return
# None for an image that is not part of the reconstruction (e.g. it failed to
# register); skip those instead of crashing on `.image_id`.
ref_ids = []
for ref_name in references:
    ref_image = model.find_image_with_name(ref_name)
    if ref_image is not None:
        ref_ids.append(ref_image.image_id)
if not ref_ids:
    raise RuntimeError("none of the reference images are registered in the SfM model")

cameras = []
rets = []
logs = []
for name in image_reference:
    camera = pycolmap.infer_camera_from_image(image_path / name)
    ret, log = pose_from_cluster(localizer, name, camera, ref_ids, features, matches)
    if ret is None:
        # NOTE(review): recent hloc/pycolmap versions appear to return None
        # when pose estimation fails; later cells index ret["cam_from_world"].
        print(f"WARNING: localization failed for {name}")
    cameras.append(camera)
    rets.append(ret)
    logs.append(log)

In [None]:
# Dump the estimated cam_from_world pose of every query image.
for idx in range(len(image_reference)):
    pose_i = rets[idx]["cam_from_world"]
    print("matrix[", idx, "]:   ", pose_i)

# Visualize

In [None]:
# for i in range(23,26):
#     pose = pycolmap.Image(cam_from_world=rets[i]["cam_from_world"])
#     viz_3d.plot_camera_colmap(
#         fig, pose, cameras[i], color="rgba(0,255,0,0.5)", name=image_reference[i], fill=True
#     )
#     # visualize 2D-3D correspodences
#     inl_3d = np.array(
#         [model.points3D[pid].xyz for pid in np.array(logs[i]["points3D_ids"])[rets[i]["inliers"]]]
#     )
#     viz_3d.plot_points(fig, inl_3d, color="lime", ps=1, name=image_reference[i])

# fig.show()

## Print individual camera poses (rotation matrix and translation)

In [None]:
# print(rets[2]["cam_from_world"])
# print(rets[3]["cam_from_world"])
# print(rets[4]["cam_from_world"])

# rm1 = rets[2]["cam_from_world"].rotation.matrix()
# rm2 = rets[3]["cam_from_world"].rotation.matrix()
# rm3 = rets[4]["cam_from_world"].rotation.matrix()

# print("rm1: ", rm1)
# print("rm2: ", rm2)
# print("rm3: ", rm3)

# tm1 = rets[2]["cam_from_world"].translation
# tm2 = rets[3]["cam_from_world"].translation
# tm3 = rets[4]["cam_from_world"].translation

# print("tm1: ", tm1)
# print("tm2: ", tm2)
# print("tm3: ", tm3)

# rotation_matrix = ret["cam_from_world"].rotation.matrix()
# print("rotaion: " ,rotation_matrix)

# translation_matrix = ret["cam_from_world"].translation
# print("translation: ",translation_matrix)


# Check frame-to-frame pose consistency

In [None]:
from scipy.spatial.transform import Rotation as R

def check_pose_consistency(prev_rotation, curr_rotation, prev_translation, curr_translation, threshold_angle=10, threshold_translation=0.5):
    """
    Check whether two poses are close in both rotation and translation.

    Args:
        prev_rotation: Rotation matrix (3x3) of the previous / reference pose.
            A non-orthonormal matrix (e.g. an element-wise average) is accepted;
            scipy's ``Rotation.from_matrix`` approximates it by a rotation.
        curr_rotation: Rotation matrix (3x3) of the current pose.
        prev_translation: Translation of the previous / reference pose; any
            array-like with 3 elements, e.g. shape (3,) or (3, 1).
        curr_translation: Translation of the current pose, same convention.
        threshold_angle: The relative rotation angle (degrees) must be
            strictly below this value.
        threshold_translation: The Euclidean distance between the two
            translations must be strictly below this value. It is in the units
            of the inputs; for an SfM model that is the reconstruction's
            (arbitrary) scale, not necessarily meters.

    Returns:
        True (a Python bool) if both checks pass, False otherwise.

    Note:
        Translation vectors are compared directly. For world-to-camera poses
        (e.g. pycolmap ``cam_from_world``) this is not the distance between
        camera centres, which would be computed from ``-R.T @ t``.
    """
    # Relative rotation curr^T @ prev. Its inverse has the same angle, so the
    # multiplication order does not affect the check.
    relative_rotation = np.asarray(curr_rotation, dtype=float).T @ np.asarray(prev_rotation, dtype=float)
    rotation_angle = np.degrees(np.linalg.norm(R.from_matrix(relative_rotation).as_rotvec()))

    # Flatten both translations so (3,) and (3, 1) inputs do not broadcast to a
    # 3x3 difference, whose norm would not be the distance between them.
    relative_translation = np.ravel(curr_translation) - np.ravel(prev_translation)
    translation_distance = np.linalg.norm(relative_translation)

    is_rotation_consistent = rotation_angle < threshold_angle
    is_translation_consistent = translation_distance < threshold_translation

    # Both rotation and translation must be consistent.
    return bool(is_rotation_consistent and is_translation_consistent)



In [None]:
# Spot-check frame 23 against frames 24 and 25 using the default thresholds.
poses_by_idx = {idx: rets[idx]["cam_from_world"] for idx in (23, 24, 25)}
anchor_rm = poses_by_idx[23].rotation.matrix()
anchor_tm = poses_by_idx[23].translation

for other, ok_msg, bad_msg in (
    (24, "pose between 23 and 24 is consistent", "pose between 23 and 24 is not consistent"),
    (25, "pose between 23 and 25 is consistent", "pose between 23 and 25 is not consistent"),
):
    other_pose = poses_by_idx[other]
    consistent = check_pose_consistency(
        anchor_rm, other_pose.rotation.matrix(), anchor_tm, other_pose.translation
    )
    print(ok_msg if consistent else bad_msg)


## Initialize a rough reference camera pose from the first 10 frames and replace the inconsistent poses among them

In [None]:
from scipy.spatial.transform import Rotation

N_INIT_FRAMES = 10  # leading frames used to bootstrap the reference pose
INIT_THRESHOLD_ANGLE = 30  # degrees
INIT_THRESHOLD_TRANSLATION = 10  # units of the SfM reconstruction (arbitrary scale)


def mean_pose_components(poses):
    """Average a non-empty list of pycolmap.Rigid3d poses.

    Returns ``(rotation_matrix, translation)``. The rotation is the
    element-wise mean of the rotation matrices, projected back onto SO(3).
    The translation is the mean translation. The raw element-wise mean is
    generally not a rotation matrix, so it is not passed to
    ``pycolmap.Rigid3d`` as-is.
    """
    mean_rm = np.mean([p.rotation.matrix() for p in poses], axis=0)
    mean_tm = np.mean([np.ravel(p.translation) for p in poses], axis=0)
    return Rotation.from_matrix(mean_rm).as_matrix(), mean_tm


# Rough estimate from all of the first frames, outliers included.
n_init = min(N_INIT_FRAMES, len(rets))
init_poses = [rets[i]["cam_from_world"] for i in range(n_init)]
if not init_poses:
    raise RuntimeError("no localized frames available to initialize the reference pose")
rough_rm, rough_tm = mean_pose_components(init_poses)

# Split the first frames into poses that agree with the rough estimate and
# those that do not, remembering the frame index of every rejected pose.
correct_camera_pose = []
wrong_camera_pose = []
wrong_indices = []
for i, pose in enumerate(init_poses):
    if check_pose_consistency(
        rough_rm, pose.rotation.matrix(), rough_tm, pose.translation,
        threshold_angle=INIT_THRESHOLD_ANGLE,
        threshold_translation=INIT_THRESHOLD_TRANSLATION,
    ):
        print("pose[", i, "] is consistent")
        correct_camera_pose.append(pose)
    else:
        print("pose[", i, "] is not consistent")
        wrong_camera_pose.append(pose)
        wrong_indices.append(i)

if not correct_camera_pose:
    raise RuntimeError(
        f"none of the first {n_init} poses is consistent with their mean; "
        "cannot initialize the reference camera pose"
    )

# Refined reference pose: mean over the consistent frames only.
temp_rm, temp_tm = mean_pose_components(correct_camera_pose)
camera_pose = pycolmap.Rigid3d(temp_rm, temp_tm)

print(camera_pose)

# Overwrite exactly the frames that were flagged as inconsistent. Iterating
# range(len(wrong_camera_pose)) would hit frames 0..k-1 instead, which are
# not necessarily the bad ones.
for i in wrong_indices:
    rets[i]["cam_from_world"] = camera_pose




## Check whether the poses of the remaining frames are reasonable

In [None]:
# Compare every frame from index 10 on against the reference pose.
ref_rotation = camera_pose.rotation.matrix()
ref_translation = camera_pose.translation
for idx in range(10, len(image_reference)):
    candidate = rets[idx]["cam_from_world"]
    ok = check_pose_consistency(
        ref_rotation,
        candidate.rotation.matrix(),
        ref_translation,
        candidate.translation,
        threshold_angle=30,
        threshold_translation=10,
    )
    suffix = "] is consistent" if ok else "] is not consistent"
    print("pose[", idx, suffix)

# Helper: repair the pose of an inconsistent frame from its neighbours

In [None]:
def fix(i, flag):
    """Overwrite ``rets[i]["cam_from_world"]`` with a pose derived from neighbouring frames.

    Args:
        i: Index into the global ``rets`` list of the frame to repair.
        flag: ``0`` interpolates halfway between frames ``i - 1`` and
            ``i + 1``. Any other value extrapolates from frames ``i - 2`` and
            ``i - 1``, assuming the last step repeats (constant motion).

    Rotations are combined on SO(3). The element-wise forms ``(R1 + R2) / 2``
    and ``2 * R2 - R1`` are generally not rotation matrices, so the resulting
    ``pycolmap.Rigid3d`` would not hold a valid rotation. Translations are
    still interpolated / extrapolated linearly.
    """
    from scipy.spatial.transform import Rotation

    if flag == 0:
        prev_pose = rets[i - 1]["cam_from_world"]
        next_pose = rets[i + 1]["cam_from_world"]
        # For two rotations, the chordal mean is their geodesic midpoint.
        new_rm = Rotation.from_matrix(
            np.stack([prev_pose.rotation.matrix(), next_pose.rotation.matrix()])
        ).mean().as_matrix()
        new_tm = (prev_pose.translation + next_pose.translation) / 2
    else:
        older_pose = rets[i - 2]["cam_from_world"]
        newer_pose = rets[i - 1]["cam_from_world"]
        r_older = older_pose.rotation.matrix()
        r_newer = newer_pose.rotation.matrix()
        # Repeat the last relative rotation: with R_{i-1} = D @ R_{i-2},
        # predict R_i = D @ R_{i-1} = R_{i-1} @ R_{i-2}.T @ R_{i-1}.
        new_rm = r_newer @ r_older.T @ r_newer
        new_tm = newer_pose.translation * 2 - older_pose.translation

    rets[i]["cam_from_world"] = pycolmap.Rigid3d(new_rm, new_tm)



# Repair inconsistent poses in the remaining frames

In [None]:
from scipy.spatial.transform import Rotation

FIRST_TRACKED_FRAME = 10  # earlier frames were handled by the initialization cell
TRACK_THRESHOLD_ANGLE = 30  # degrees
TRACK_THRESHOLD_TRANSLATION = 10  # units of the SfM reconstruction


def _consistent_with(reference, pose):
    """Run check_pose_consistency between two pycolmap.Rigid3d poses."""
    return check_pose_consistency(
        reference.rotation.matrix(),
        pose.rotation.matrix(),
        reference.translation,
        pose.translation,
        threshold_angle=TRACK_THRESHOLD_ANGLE,
        threshold_translation=TRACK_THRESHOLD_TRANSLATION,
    )


last_idx = len(image_reference) - 1
for i in range(FIRST_TRACKED_FRAME, last_idx):
    if _consistent_with(camera_pose, rets[i]["cam_from_world"]):
        print("pose[", i, "] is consistent")
    elif _consistent_with(camera_pose, rets[i + 1]["cam_from_world"]):
        # The next frame looks fine: interpolate between frames i-1 and i+1.
        print("pose[", i, "] is not consistent but pose[", i + 1, "] is consistent")
        fix(i, 0)
    else:
        # No trustworthy next frame: extrapolate from frames i-2 and i-1.
        print("pose[", i, "] is not consistent and pose[", i + 1, "] is not consistent")
        fix(i, 1)

    # Running mean over frames 0..i, i.e. i + 1 frames. The previous reference
    # stands in for the first i of them, so the weights are i/(i+1) and 1/(i+1).
    current = rets[i]["cam_from_world"]
    mean_rotation = (i * camera_pose.rotation.matrix() + current.rotation.matrix()) / (i + 1)
    mean_translation = (i * camera_pose.translation + current.translation) / (i + 1)
    # An element-wise mean of rotation matrices is not itself a rotation;
    # project it back onto SO(3) before building the Rigid3d.
    camera_pose = pycolmap.Rigid3d(
        Rotation.from_matrix(mean_rotation).as_matrix(), mean_translation
    )

# The last frame has no successor, so it falls back to the reference pose.
if _consistent_with(camera_pose, rets[last_idx]["cam_from_world"]):
    print("pose[", last_idx, "] is consistent")
else:
    print("pose[", last_idx, "] is not consistent")
    rets[last_idx]["cam_from_world"] = camera_pose

print(camera_pose)


In [None]:
# Print every query pose again, now including the repaired frames.
for idx, _ in enumerate(image_reference):
    print("matrix[", idx, "]:   ", rets[idx]["cam_from_world"])

# Visualize

In [None]:
# Overlay a window of localized query cameras, plus their 2D-3D inlier points,
# on the SfM figure. The window is clamped to the number of query images so a
# shorter sequence does not raise IndexError.
VIS_START, VIS_STOP = 30, 35
for i in range(VIS_START, min(VIS_STOP, len(image_reference))):
    pose = pycolmap.Image(cam_from_world=rets[i]["cam_from_world"])
    viz_3d.plot_camera_colmap(
        fig, pose, cameras[i], color="rgba(0,255,0,0.5)", name=image_reference[i], fill=True
    )
    # 2D-3D inlier correspondences of the localization. The repair cells only
    # replace "cam_from_world", so for a repaired frame these inliers still
    # come from the originally estimated pose.
    inlier_ids = np.array(logs[i]["points3D_ids"])[rets[i]["inliers"]]
    inl_3d = np.array([model.points3D[pid].xyz for pid in inlier_ids])
    viz_3d.plot_points(fig, inl_3d, color="lime", ps=1, name=image_reference[i])

fig.show()