# Image stitching and localization

## Extract the frames from a video

In [None]:
from pathlib import Path
import cv2
import numpy as np
import matplotlib.pyplot as plt
import random

In [None]:
# Presumably the directory holding the source video(s) — confirm.
video_path = Path("/home/cv_stu_03/Hierarchical-Localization/project/local_data/video")
# Directory whose files are listed below as the ordered frame sequence.
images_path = Path("/home/cv_stu_03/Hierarchical-Localization/project/local_data/images_from_local")
# Directory the stitched images are saved to and later listed from.
output_path = Path("/home/cv_stu_03/Hierarchical-Localization/project/local_data/image_stitch")
# do extract
# NOTE(review): no frame-extraction code is present in this cell.

## Do stitching

### Set the parameters

- interval 为帧间隔：选取需要进行定位的那张照片前后各相隔 interval 帧的照片来进行 image stitching

In [None]:
# Number of frames between the query frame and each of the two neighbouring
# frames (one before, one after) that get_images loads for stitching.
interval = 24

### Get the images

- reference 是按顺序排列的视频帧，我们需要用它来查找待定位图片前后的帧

In [None]:
# Names of all entries in images_path (relative, POSIX style), sorted
# lexicographically so that list positions follow the file-name order.
images_references = sorted(
    frame.relative_to(images_path).as_posix() for frame in images_path.iterdir()
)
# references
# for i in range(len(references)):
#     print(references[i])

In [None]:
def get_images(references, images_path, image_reference, interval):
    """Load the query frame and its two neighbours `interval` frames away.

    Args:
        references: Ordered list of frame file names, relative to
            `images_path`.
        images_path: Directory (``pathlib.Path``) containing the frames.
        image_reference: Name of the frame to localize; it is looked up in
            `references`.
        interval: Offset, in list positions, between the query frame and
            each neighbour.

    Returns:
        An empty list when either neighbour would fall outside
        `references`. Otherwise an array stacking the frames at
        ``index - interval``, ``index`` and ``index + interval`` (in that
        order), each converted to RGB channel order.

    Raises:
        ValueError: If `image_reference` is not in `references`.
        FileNotFoundError: If one of the three frames cannot be read.
    """
    index = references.index(image_reference)
    if index - interval < 0 or index + interval >= len(references):
        return []

    frames = []
    for position in (index - interval, index, index + interval):
        frame_path = images_path / references[position]
        frame = cv2.imread(str(frame_path))
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than later inside numpy.
        if frame is None:
            raise FileNotFoundError(f"could not read image: {frame_path}")
        # cv2.imread yields BGR, but the rest of the notebook displays the
        # frames with matplotlib and saves them through PIL as "RGB".
        frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    return np.array(frames)

- Do the test

In [None]:
# Smoke test: load the frame at list position 24 together with the frames
# `interval` positions before and after it.
ret_images = get_images(images_references, images_path, images_references[24], interval)
# get_images returns an empty list when a neighbour would fall outside the
# sequence, so an empty result means there is nothing to stitch.
if len(ret_images)==0 :
    print("no images")
else:
    print(ret_images.shape)

### Do image stitching

- show images

In [None]:
# Previous, query and next frame, in the order get_images stacked them.
img0_rgb, img1_rgb, img2_rgb = ret_images[0], ret_images[1], ret_images[2]

print(img0_rgb.shape)

# Show the three frames side by side in a 1x3 grid.
for position, frame in enumerate((img0_rgb, img1_rgb, img2_rgb), start=1):
    plt.subplot(1, 3, position)
    plt.imshow(frame)
plt.show()

In [None]:
# Show the middle (query) frame on its own.
plt.imshow(img1_rgb)

- Use SIFT to compute the keypoints and descriptors

In [None]:
def get_good_matches(image0, image1):
    """Match SIFT features between two images and return matched coordinates.

    Descriptors of `image0` are matched against those of `image1` with a
    brute-force 2-nearest-neighbour search; a match is kept only when the
    best distance is below 0.6 times the second-best distance.

    Args:
        image0: Query image.
        image1: Train image.

    Returns:
        A tuple ``(mkpts0, mkpts1)`` of float arrays of shape ``(N, 2)``.
        Row ``i`` holds the keypoint coordinates (``KeyPoint.pt``) of the
        i-th retained match in `image0` and `image1` respectively. Both
        arrays have shape ``(0, 2)`` when no match is retained.
    """
    sift = cv2.SIFT_create()
    keys0, descb0 = sift.detectAndCompute(image0, None)
    keys1, descb1 = sift.detectAndCompute(image1, None)

    good_matches = []
    # Descriptors may be None when an image yields no keypoints; the matcher
    # cannot take None, so such a pair simply produces no matches.
    if descb0 is not None and descb1 is not None:
        for candidates in cv2.BFMatcher().knnMatch(descb0, descb1, 2):
            # Fewer than two neighbours come back when image1 has a single
            # descriptor; the ratio test is undefined for those.
            if len(candidates) < 2:
                continue
            best, second = candidates
            if best.distance < 0.6 * second.distance:
                good_matches.append(best)

    print("the lenth of good matches is {}".format(len(good_matches)))

    # reshape keeps the (N, 2) contract even when N == 0.
    mkpts0 = np.array(
        [keys0[m.queryIdx].pt for m in good_matches], dtype=np.float64
    ).reshape(-1, 2)
    mkpts1 = np.array(
        [keys1[m.trainIdx].pt for m in good_matches], dtype=np.float64
    ).reshape(-1, 2)

    return mkpts0, mkpts1


- RANSAC

In [None]:
# calculate the inliers
def cal_inliers(sample1, sample2, h, inlier_thr):
    """Count the correspondences that the affine model `h` explains.

    Args:
        sample1: ``(N, 2)`` array of source points.
        sample2: ``(N, 2)`` array of target points, row-aligned with
            `sample1`.
        h: Six affine parameters, reshaped row-major into the 2x3 matrix
            ``[A | t]`` so that a source point ``p`` maps to ``A @ p + t``.
        inlier_thr: Upper bound (exclusive) on the *squared* Euclidean
            distance between the mapped source point and its target.

    Returns:
        The number of inliers as a Python ``int``.
    """
    affine = np.asarray(h, dtype=float).reshape((2, 3))
    src = np.asarray(sample1, dtype=float).reshape((-1, 2))
    dst = np.asarray(sample2, dtype=float).reshape((-1, 2))

    # Apply the transform to all points at once instead of one by one.
    projected = src @ affine[:, :2].T + affine[:, 2]
    squared_error = np.sum(np.square(projected - dst), axis=1)
    return int(np.count_nonzero(squared_error < inlier_thr))


# implement your own RANSAC
def ransac_to_estimate_H(sample1, sample2, K, inlier_thr, M):
    """Estimate a 2D affine transform from `sample1` to `sample2` with RANSAC.

    Each iteration draws `K` distinct correspondences, solves the linear
    least-squares system for the six affine parameters and scores the
    candidate with ``cal_inliers`` over all correspondences.

    Args:
        sample1: ``(N, 2)`` array of source points.
        sample2: ``(N, 2)`` array of target points, row-aligned with
            `sample1`.
        K: Number of correspondences drawn per iteration.
        inlier_thr: Threshold forwarded to ``cal_inliers``.
        M: Number of iterations.

    Returns:
        The length-6 parameter vector ``[a, b, tx, c, d, ty]`` (with
        ``x' = a*x + b*y + tx`` and ``y' = c*x + d*y + ty``) that scored the
        most inliers, or ``None`` when there are fewer than `K`
        correspondences or no candidate had at least one inlier.
    """
    num_points = sample1.shape[0]
    if num_points < K:
        return None

    best_h = None
    max_inliers = 0
    for _ in range(M):
        # Distinct indices only: a repeated correspondence duplicates rows
        # of the system and makes it rank-deficient.
        picked = random.sample(range(num_points), K)

        # Two equations per correspondence, so the system has 2*K rows.
        A = np.zeros((2 * K, 6))
        b = np.zeros(2 * K)
        for row, idx in enumerate(picked):
            A[2 * row, 0:2] = sample1[idx]
            A[2 * row, 2] = 1
            A[2 * row + 1, 3:5] = sample1[idx]
            A[2 * row + 1, 5] = 1
            b[2 * row:2 * row + 2] = sample2[idx]

        h = np.linalg.lstsq(A, b, rcond=None)[0]
        inliers = cal_inliers(sample1, sample2, h, inlier_thr)
        if inliers > max_inliers:
            max_inliers = inliers
            best_h = h
    return best_h

- Define the stitching function

In [None]:
def image_stitch(image1, image2):
    """Warp `image2` into the coordinate frame of `image1`.

    Features are matched between the two images, an affine transform that
    maps `image2` coordinates to `image1` coordinates is estimated with
    RANSAC (3 correspondences per sample, threshold 5, 100 iterations), and
    `image2` is warped with it onto a canvas twice as wide as `image1` and
    of the same height.

    NOTE(review): `image1` itself is never copied onto the canvas, so the
    result contains only the warped `image2` — confirm whether compositing
    both images was intended.

    Args:
        image1: Image that defines the target coordinate frame and canvas
            size.
        image2: Image that gets warped.

    Returns:
        The warped image, ``2 * width`` by ``height`` of `image1`.

    Raises:
        ValueError: If no transform could be estimated from the matches.
    """
    mkpts0, mkpts1 = get_good_matches(image1, image2)

    # Source points come from image2 (mkpts1), targets from image1 (mkpts0),
    # so the estimated transform moves image2 onto image1.
    H = ransac_to_estimate_H(mkpts1, mkpts0, 3, 5, 100)
    if H is None:
        raise ValueError("could not estimate a transform between the two images")

    # Embed the 2x3 affine parameters in a 3x3 matrix for warpPerspective.
    H_optimized = np.eye(3)
    H_optimized[:2, :] = H.reshape((2, 3))

    # Only the first two dimensions are needed; this also accepts 2-D input.
    height, width = image1.shape[:2]
    dsize = (width * 2, height)

    return cv2.warpPerspective(image2, H_optimized, dsize)

- Do the stitching

In [None]:
# plt.imshow(img0_rgb)
# First pass: warp the previous frame (img0) into the frame of the query
# image (img1) on a canvas twice as wide as img1.
temp_stitch = image_stitch(img1_rgb, img0_rgb)
# print(temp_stitch.shape)
# Second pass: warp that intermediate result into the frame of the next
# image (img2).
final_image = image_stitch(img2_rgb, temp_stitch)
plt.imshow(final_image)

- Shape the image

In [None]:
from PIL import Image

def image_shape(image):
    """Trim the all-zero columns at the right edge of `image`.

    Args:
        image: ``(height, width, channels)`` array; it is handed to PIL with
            mode "RGB".

    Returns:
        A ``PIL.Image`` cropped to the columns from 0 up to and including
        the right-most column that contains a non-zero value. When every
        column is zero the image is returned uncropped.
    """
    # A column counts as content when any pixel in any channel is non-zero.
    content_columns = np.flatnonzero(np.any(image != 0, axis=(0, 2)))
    if content_columns.size:
        right = int(content_columns[-1]) + 1
    else:
        # Nothing to anchor the crop on; keep the full width rather than
        # producing an arbitrary sliver.
        right = image.shape[1]

    image_reshape = Image.fromarray(image, mode="RGB")
    return image_reshape.crop((0, 0, right, image.shape[0]))

# Crop the stitched result and save it under the same file name as the
# query frame (list position 24), inside output_path.
query_name = images_references[24]
image_stitched = image_shape(final_image)

# Image.save does not create missing directories, so make sure the target
# directory exists before writing.
output_path.mkdir(parents=True, exist_ok=True)
image_output_path = output_path / query_name
print(query_name)

image_stitched.save(image_output_path, quality=100)
plt.imshow(image_stitched)

- import the library

In [None]:
%load_ext autoreload
%autoreload 2
import tqdm, tqdm.notebook

tqdm.tqdm = tqdm.notebook.tqdm  # notebook-friendly progress bars
from pathlib import Path
import numpy as np

from hloc import (
    extract_features,
    match_features,
    reconstruction,
    visualization,
    pairs_from_exhaustive,
)
from hloc.visualization import plot_images, read_image
from hloc.utils import viz_3d
import torch
import pycolmap

- Initialize the path

In [None]:
# Select CUDA device 3 for this process.
# NOTE(review): the device index is hard-coded — confirm it exists on the
# machine this runs on.
torch.cuda.set_device(3)
# NOTE(review): the label says current_device() but the value printed is the
# device name.
print("torch.cuda.current_device()=", torch.cuda.get_device_name())

# Directory of reference images; presumably the images the SfM model in
# sfm_dir was built from — confirm.
images = Path("/home/cv_stu_03/Hierarchical-Localization/project/images/MainLibrary/ReconstructionDataset")
# Reconstruction directory, pair list, feature file and match file of the
# disk + lightglue run.
sfm_dir= Path("/home/cv_stu_03/Hierarchical-Localization/project/Comparison/disk_disk+lightglue/sfm")
sfm_pairs= Path("/home/cv_stu_03/Hierarchical-Localization/project/Comparison/disk_disk+lightglue/pairs-sfm.txt")
features= Path("/home/cv_stu_03/Hierarchical-Localization/project/Comparison/disk_disk+lightglue/features.h5")
matches= Path("/home/cv_stu_03/Hierarchical-Localization/project/Comparison/disk_disk+lightglue/matches.h5")
# Names of all entries in the reference image directory, relative to it
# (unsorted: iterdir order).
references = [p.relative_to(images).as_posix() for p in (images ).iterdir()]

In [None]:
# Load the reconstruction stored in sfm_dir.
model = pycolmap.Reconstruction(str(sfm_dir))

# Create a 3D figure and draw the reconstruction into it (semi-transparent
# red, trace name "mapping", points_rgb enabled). The figure is not shown
# here.
fig = viz_3d.init_figure()
viz_3d.plot_reconstruction(
    fig, model, color="rgba(255,0,0,0.5)", name="mapping", points_rgb=True
)
# fig.show()

In [None]:
# hloc configurations selected by name: "disk" features, "disk+lightglue"
# matcher.
feature_conf = extract_features.confs["disk"]
matcher_conf = match_features.confs["disk+lightglue"]
# File the query-to-reference pair list is written to.
loc_pairs = Path("/home/cv_stu_03/Hierarchical-Localization/project/pairs-from-loc.txt")
# Extract features for the frames listed in images_references, read from
# images_path, into the shared features file.
# NOTE(review): features come from images_path, while the localization cell
# infers cameras from the stitched files in output_path and looks features up
# by the stitched file names — confirm which images the features should be
# extracted from.
extract_features.main(
    feature_conf, images_path, image_list=images_references, feature_path=features, overwrite=True
)
# Pair the query frames with the reference images (exhaustive pairing).
pairs_from_exhaustive.main(loc_pairs, image_list=images_references, ref_list=references)
# Match the listed pairs and store the results in the matches file.
match_features.main(
    matcher_conf, loc_pairs, features=features, matches=matches, overwrite=True
)

In [None]:
# Sorted names of all entries in output_path (the stitched images), relative
# to that directory.
stitch_references = sorted([p.relative_to(output_path).as_posix() for p in (output_path).iterdir()])

In [None]:
import pycolmap
from hloc.localize_sfm import QueryLocalizer, pose_from_cluster

# Per-query results, index-aligned with stitch_references.
cameras = []
rets = []
logs = []

# None of the following depends on the query image, so build it once
# instead of on every iteration.
ref_ids = [model.find_image_with_name(r).image_id for r in references]
conf = {
    "estimation": {"ransac": {"max_error": 12}},
    "refinement": {"refine_focal_length": True, "refine_extra_params": True},
}
localizer = QueryLocalizer(model, conf)

for stitch_name in stitch_references:
    # The camera is inferred from the stitched file on disk; the pose is
    # estimated from the features/matches stored under the same name.
    camera = pycolmap.infer_camera_from_image(output_path / stitch_name)
    ret, log = pose_from_cluster(localizer, str(stitch_name), camera, ref_ids, features, matches)
    cameras.append(camera)
    rets.append(ret)
    logs.append(log)

In [None]:
# Draw the estimated camera of the first stitched image into the 3D figure
# (semi-transparent green, filled).
pose = pycolmap.Image(cam_from_world=rets[0]["cam_from_world"])
viz_3d.plot_camera_colmap(
    fig, pose, cameras[0], color="rgba(0,255,0,0.5)", name=stitch_references[0], fill=True
)
# visualize 2D-3D correspondences: collect the xyz of the 3D points whose
# ids are selected by the inlier mask of the first result
inl_3d = np.array(
    [model.points3D[pid].xyz for pid in np.array(logs[0]["points3D_ids"])[rets[0]["inliers"]]]
)
viz_3d.plot_points(fig, inl_3d, color="lime", ps=1, name=stitch_references[0])
fig.show()