In [1]:
import matplotlib.pyplot as plt
from pathlib import Path
import scipy.io as spio
from tqdm import tqdm
from PIL import Image
import numpy as np
import cv2
import os

%load_ext autoreload
%autoreload 2

In [2]:
# NOTE(review): not referenced by any cell in this notebook; presumably intended as a
# threshold on the per-pixel confidence maps ('conf' in cams_info) — confirm or remove.
MIN_CONF: float = 0.7

# 1. Pre-process data

In [3]:
# All inputs live under ./office relative to the kernel's working directory.
data_path = Path.cwd() / 'office'

# Exported point clouds go here; create it up front so later cells can write freely.
output_path = data_path / 'output'
output_path.mkdir(parents=True, exist_ok=True)

# RGB frames in lexicographic filename order.
# NOTE(review): not used by any later visible cell.
imgs_fpaths = sorted(data_path.joinpath('rgb').glob('*.png'))

# MATLAB inputs: camera data (no extrinsics), SIFT keypoints, world info.
cams_info_fpath, kp_fpath, world_info_fpath = (
    data_path / fname for fname in ('cams_info_no_extr.mat', 'kp.mat', 'wrld_info.mat')
)

In [4]:
# Load per-camera data (RGB image, depth map, confidence map, focal length) from the
# 'cams_info' struct array in the .mat file.
cams_info = []
for record in spio.loadmat(cams_info_fpath)['cams_info']:
    # loadmat wraps each struct in extra array levels; [0][0][0] reaches the field
    # sequence for this file's layout.
    fields = record[0][0][0]
    cams_info.append({
        'focal_length': fields[3][0][0],  # scalar stored as a 1x1 array
        'rgb': fields[0],
        'depth': fields[1],
        'conf': fields[2],
    })

In [5]:
# Load per-image SIFT keypoints and descriptors from kp.mat.
kp = spio.loadmat(kp_fpath)

# Skip loadmat's metadata entries, then order the remaining variables by the integer
# found after the first three characters of their second '_'-separated token.
# sift[i] is later paired with cams_info[i], so this ordering must match the cameras.
metadata_keys = {'__header__', '__version__', '__globals__'}
keys = sorted(
    (name for name in kp.keys() if name not in metadata_keys),
    key=lambda name: int(name.split('_')[1][3:]),
)

sift = [{'kp': kp[name][0][0][0], 'desc': kp[name][0][0][1]} for name in keys]

In [6]:
from utils.ImageNode import ImageNode

# Camera data and SIFT features are paired by position, so both sources must describe
# the same images. Fail loudly instead of hitting an IndexError (fewer SIFT entries)
# or silently dropping features (more SIFT entries).
assert len(cams_info) == len(sift), (
    f'cams_info has {len(cams_info)} entries but kp.mat has {len(sift)}'
)

# One ImageNode per image. idx is the position in image_nodes; later cells index
# image_nodes by node idx (e.g. when exporting optimizer.global_poses).
image_nodes: list[ImageNode] = [
    ImageNode(
        idx=i,
        rgb=cam['rgb'],
        keypoints=features['kp'],
        descriptors=features['desc'],
        depth_map=cam['depth'],
        conf_map=cam['conf'],
        focal_length=cam['focal_length'],
    )
    for i, (cam, features) in enumerate(zip(cams_info, sift))
]

In [None]:
from utils.plot import draw_image_keypoints, draw_matches_points, plot_point_cloud_with_keypoints

# Sanity check: overlay the SIFT keypoints on a single image.
preview_idx = 2
node = image_nodes[preview_idx]

fig, ax = plt.subplots()
ax.imshow(draw_image_keypoints(node.rgb, node.keypoints))
ax.set_title(f'SIFT keypoints - image {preview_idx}')
ax.axis('off')
plt.show()

In [8]:
# Disabled: draw the keypoints of images 1 and 2 together with draw_matches_points
# (imported in the keypoint-preview cell). Uncomment to run.
# node1 = image_nodes[1]
# node2 = image_nodes[2]

# canvas = draw_matches_points(img1=node1.rgb, img2=node2.rgb, src_points=node1.keypoints, dst_points=node2.keypoints, max_points=50)
# plt.imshow(canvas)

# 2. Visualize point cloud

In [9]:
# Disabled: 3D view of node1's point cloud together with its 3D keypoints.
# NOTE(review): this needs node1, which is only defined by the disabled cell above.
# Uncomment both cells together, otherwise a fresh kernel raises NameError here.
# from utils.plot import plot_point_cloud_with_keypoints

# plot_point_cloud_with_keypoints(node1.point_cloud, node1.keypoints_3d)

# 3. Compute 2D and 3D matches between image pairs

In [10]:
from itertools import product

from utils.ImageNodeMatch import ImageNodesMatch

# Match every ordered pair of images, including self-pairs (i, i); self-pairs are
# skipped later when errors are ranked. Entries are keyed by ImageNodesMatch.pair_idx
# (presumably (src idx, dst idx) - later cells index it with tuples like (0, 3)).
# This is O(n^2) in the number of images, so show progress with tqdm.
# Unlike a nested `for node1 ... for node2` loop, this does not leave `node1`/`node2`
# behind in the global namespace where other cells might pick them up by accident.
nodes_match = {}
for src_node, dst_node in tqdm(
    product(image_nodes, repeat=2),
    total=len(image_nodes) ** 2,
    desc='Matching image pairs',
):
    pair_match = ImageNodesMatch(src_node, dst_node)
    nodes_match[pair_match.pair_idx] = pair_match

# 4. Compute, rank, and visualize pairwise registration errors

In [11]:
# Collect the ICP registration error for every pair of distinct images.
# Errors are kept at full precision. Rounding here would create artificial ties and
# distort the ranking, so round only when displaying.
pairs = []
errors = []
for pair_idx, pair_match in nodes_match.items():
    if pair_idx[0] == pair_idx[1]:
        continue  # an image matched against itself is not useful for ranking
    pairs.append(pair_idx)
    errors.append(float(pair_match.get_error(transform_type='icp')))

In [None]:
# Rank pairs from lowest to highest ICP error; equal errors are ordered by pair index.
# Built with comprehensions so an empty ranking yields empty lists instead of the
# ValueError raised by unpacking zip(*[]).
ranked = sorted(zip(errors, pairs))
sorted_errors = [error for error, _ in ranked]
sorted_pairs = [pair for _, pair in ranked]

# Show each pair next to its error (rounded for display only).
for pair, error in zip(sorted_pairs, sorted_errors):
    print(f'{pair}: {error:.4f}')

In [13]:
from utils.plot import register_point_cloud, write_point_cloud

# Export the lowest-error pairs for inspection. Each pair gets its own folder
# output/<pair[0]>_<pair[1]>/ containing:
#   dst.ply   - the destination node's cloud as-is
#   src_t.ply - the source node's cloud after applying the pair's ICP transform
N_BEST_PAIRS = 10

for pair in sorted_pairs[:N_BEST_PAIRS]:
    pair_match = nodes_match[pair]
    pair_dir = output_path / f'{pair[0]}_{pair[1]}'
    pair_dir.mkdir(parents=True, exist_ok=True)

    src_cloud_t = pair_match.apply_transform(pair_match.src_node.point_cloud, pair_match.transform_icp)

    dst_pcd = register_point_cloud(pair_match.dst_node.point_cloud, pair_match.dst_node.cloud_colors)
    write_point_cloud(dst_pcd, str(pair_dir / 'dst.ply'))

    src_pcd = register_point_cloud(src_cloud_t, pair_match.src_node.cloud_colors)
    write_point_cloud(src_pcd, str(pair_dir / 'src_t.ply'))

In [14]:
from utils.plot import register_point_cloud, write_point_cloud

# For comparison with ICP, map one source cloud with the pair's Procrustes transform.
# Written into output_path (like the per-pair exports), with the pair in the filename
# so that choosing a different pair does not overwrite the previous file.
procrustes_pair = (0, 3)
procrustes_match = nodes_match[procrustes_pair]
src_transformed = procrustes_match.apply_transform(
    procrustes_match.src_node.point_cloud, procrustes_match.transform_procrustes
)
pcd = register_point_cloud(src_transformed, procrustes_match.src_node.cloud_colors)
write_point_cloud(
    pcd, str(output_path / f'src_procrustes_{procrustes_pair[0]}_{procrustes_pair[1]}.ply')
)


# 5. Manually transform the office dataset

In [15]:
from utils.plot import register_point_cloud, write_point_cloud

# Apply each pair's ICP transform directly for a few hand-picked pairs that all share
# image 0 as their second element.
direct_matches = [(3, 0), (4, 0), (7, 0)]
for match_id in direct_matches:
    m = nodes_match[match_id]
    src_transformed = m.apply_transform(m.src_node.point_cloud, m.transform_icp)
    pcd = register_point_cloud(src_transformed, m.src_node.cloud_colors)
    # Name the file '<a>_<b>' rather than formatting the tuple itself, which would put
    # '(', ',' and a space into the filename. Write into output_path with other exports.
    write_point_cloud(pcd, str(output_path / f'src_icp_{match_id[0]}_{match_id[1]}.ply'))

In [None]:
from utils.PoseGraph import PoseGraph

# Nodes taking part in global registration, and the node whose frame is the reference.
# NOTE(review): the direct ICP exports above all target image 0, while the pose graph
# uses 3 as its base - confirm 3 is the intended reference frame.
POSE_GRAPH_NODES = [0, 3, 4, 7]
BASE_NODE_IDX = 3
assert BASE_NODE_IDX in POSE_GRAPH_NODES, 'base node must be part of the pose graph'

# 1) Build the pose graph over the selected nodes, reusing the pairwise matches.
optimizer = PoseGraph([image_nodes[i] for i in POSE_GRAPH_NODES], nodes_match)

# 2) Incrementally register every node, using BASE_NODE_IDX as the base.
optimizer.incremental_registration(base_node_idx=BASE_NODE_IDX)

# Afterwards, optimizer.global_poses maps each node idx to its (presumably refined)
# transform in the global coordinate frame.


In [17]:
# Write every registered node's cloud, transformed into the global frame, into output_path.
# apply_transform is an ImageNodesMatch method that takes the cloud and transform as
# arguments. It is bound explicitly here rather than taken from a loop variable left over
# by an earlier cell, so this cell does not depend on execution order. (7, 0) is the
# pair that variable held after Restart & Run All.
# NOTE(review): assumes apply_transform does not depend on which pair it is called on - confirm.
apply_transform = nodes_match[(7, 0)].apply_transform

for idx, transform in optimizer.global_poses.items():
    pcloud_transf = apply_transform(image_nodes[idx].point_cloud, transform)
    pcd = register_point_cloud(pcloud_transf, image_nodes[idx].cloud_colors)
    write_point_cloud(pcd, str(output_path / f'{idx}_transformed.ply'))