In [1]:
import numpy as np
import os
import sys
import cv2
import matplotlib.pyplot as plt
import open3d as o3d
%matplotlib inline 

src_path = os.path.abspath("../..")
if src_path not in sys.path:
    sys.path.append(src_path)
%load_ext autoreload
from dataset.kitti_odometry_dataset import KittiOdometryDataset, KittiOdometryDatasetConfig
from dataset.filters.filter_list import FilterList
from dataset.filters.kitti_gt_mo_filter import KittiGTMovingObjectFilter
from dataset.filters.range_filter import RangeFilter
from dataset.filters.apply_pose import ApplyPose

import networkx as nx
from scipy.spatial import cKDTree
from scipy.spatial.distance import cdist
from sklearn.cluster import Birch, KMeans, MeanShift, DBSCAN
from sklearn.decomposition import TruncatedSVD

from point_cloud_utils import get_pcd
from aggregate_pointcloud import aggregate_pointcloud
from reproject_merged_pointcloud import reproject_points_to_label, merge_associations

Here we define the dataset for the selected KITTI odometry sequence.

In [2]:
# Root folder of the fused KITTI dataset. The default is the original
# machine-specific location; set the KITTI_FUSED_DATASET_PATH environment
# variable to run the notebook on another machine without editing this cell.
DATASET_PATH = os.environ.get(
    "KITTI_FUSED_DATASET_PATH",
    os.path.join('/Users/laurenzheidrich/Downloads/', 'fused_dataset'),
)
SEQUENCE_NUM = 7

# <DATASET_PATH>/sequences/<two-digit sequence>/labels, passed to the
# ground-truth moving-object filter.
sequence_label_dir = os.path.join(
    DATASET_PATH,
    "sequences",
    "%.2d" % SEQUENCE_NUM,
    "labels",
)

# Filters are passed as one FilterList in this order: moving-object filter,
# range filter, pose application.
# NOTE(review): RangeFilter(2.5, 120) presumably keeps points between 2.5 and
# 120 (metres?) from the sensor - confirm against the filter implementation.
config_filtered = KittiOdometryDatasetConfig(
    cache=True,
    dataset_path=DATASET_PATH,
    correct_scan_calibration=True,
    filters=FilterList(
        [
            KittiGTMovingObjectFilter(sequence_label_dir),
            RangeFilter(2.5, 120),
            ApplyPose(),
        ]
    ),
)

dataset = KittiOdometryDataset(config_filtered, SEQUENCE_NUM)

Now we read in the point cloud and the left and right images of the stereo camera. If labels for those images are available, they can be read in, too.

In [3]:
# Index bounds handed to aggregate_pointcloud.
# NOTE(review): presumably first/last scan index of the aggregation window;
# whether ind_end is inclusive depends on aggregate_pointcloud - confirm.
ind_start = 58
ind_end = 59

# Returns the aggregated point cloud and an accompanying transform T_pcd,
# which is later passed on to reproject_points_to_label.
pcd, T_pcd = aggregate_pointcloud(dataset, ind_start, ind_end, clip_to_imageframe=True)

In [44]:
cams = ["cam2", "cam3"]
cam_ind = 0  # index into `cams`; 0 selects "cam2"
point_to_label_reprojections = []

# Reproject the aggregated cloud into the label image of every selected frame.
# NOTE(review): the frame range is shifted by +5 relative to ind_start/ind_end
# - confirm this offset is intentional.
for points_index in range(ind_start + 5, ind_end + 5):
    # Label image for this camera/frame, converted from RGB to BGR channel order.
    label_PIL = dataset.get_sam_label(cams[cam_ind], points_index)
    label = cv2.cvtColor(np.array(label_PIL), cv2.COLOR_RGB2BGR)

    # Chain the inverted frame pose with the lidar-to-camera calibration to get
    # a single world-to-camera transform.
    T_world2lidar = np.linalg.inv(dataset.get_pose(points_index))
    T_lidar2cam, K = dataset.get_calibration_matrices(cams[cam_ind])
    T_world2cam = T_lidar2cam @ T_world2lidar

    reprojection = reproject_points_to_label(
        np.array(pcd.points),
        T_pcd,
        label,
        T_world2cam,
        K,
        hidden_point_removal=False,
    )
    point_to_label_reprojections.append(reprojection)

In [45]:
# Combine the per-frame reprojection results into one association matrix; the
# second argument is the number of points in the aggregated cloud.
# NOTE(review): the result is used below as a row-indexable sparse matrix
# (rows are densified with .toarray()) - presumably one row per point; confirm.
association_matrix = merge_associations(point_to_label_reprojections, len(pcd.points))

In [34]:
# Count the points (rows of the association matrix) that carry no non-zero
# association at all.
# The rows are taken from the matrix itself instead of `num_points`, which is
# only assigned in a later cell and would raise a NameError on a fresh
# top-to-bottom run. Comparing against 0 (rather than counting stored entries)
# keeps explicitly stored zeros from being mistaken for associations, matching
# the original per-row `np.any(row.toarray())` check without densifying rows.
nonzeros_per_point = np.asarray((association_matrix.tocsr() != 0).sum(axis=1)).ravel()
counter = int(np.count_nonzero(nonzeros_per_point == 0))

In [36]:
# Number of points in the aggregated cloud. Read directly from the cloud
# because `num_points` is only assigned in a later cell, so referencing it here
# fails on a fresh top-to-bottom run.
print(len(pcd.points))

227357


In [24]:
# Point coordinates as a NumPy array, and the number of points.
points = np.asarray(pcd.points)
num_points = points.shape[0]

# Two points are connected in the graph when their distance is at most this value.
proximity_threshold = 0.5
# Decay rates of the edge weight: alpha scales the distance term and beta the
# association-difference term (both enter as exp(-rate * value)).
alpha = 1.0
beta = 1.0

In [7]:
# Dense pairwise distance matrix between all points (n x n).
# NOTE(review): memory grows quadratically with the number of points. For the
# 227357 points printed earlier in this notebook, a float64 matrix would need
# roughly 413 GB - confirm this cell is only run on a much smaller cloud.
dist_matrix = cdist(points, points)

In [8]:
# One node per point, carrying its 3D position and its row of the
# association matrix as node attributes.
G = nx.Graph()
G.add_nodes_from(
    (idx, {"pos": points[idx], "associations": association_matrix[idx]})
    for idx in range(num_points)
)

In [9]:
# Connect every pair of points whose Euclidean distance is at most
# `proximity_threshold`. A KD-tree radius query returns exactly those (i, j)
# pairs with i < j, which avoids both the dense n x n distance matrix and the
# pure-Python double loop over all n*(n-1)/2 pairs.
tree = cKDTree(points)
close_pairs = tree.query_pairs(r=proximity_threshold, output_type="ndarray")
pair_dists = np.linalg.norm(
    points[close_pairs[:, 0]] - points[close_pairs[:, 1]], axis=1
)

# Edge weight = exp(-alpha * distance) * exp(-beta * number of differing
# association entries), so close points with identical associations get a
# weight near 1.
for (i, j), dist in zip(close_pairs.tolist(), pair_dists.tolist()):
    dist_weight = np.exp(-alpha * dist)
    feature_weight = np.exp(-beta * np.sum(G.nodes[i]['associations'] != G.nodes[j]['associations']))
    G.add_edge(i, j, weight=(dist_weight * feature_weight))

In [15]:
# Nodes without any edge (degree 0): report how many there are, then drop them
# from the graph.
isolated_nodes = list(nx.isolates(G))
print(len(isolated_nodes))
G.remove_nodes_from(isolated_nodes)

53


In [16]:
# Materialise the connected components of the graph (one set of node ids per
# component) and print how many there are.
connected_components = list(nx.connected_components(G))
print(len(connected_components))

139


In [28]:
# Disabled debug visualisation of a single reprojection result.
# NOTE(review): filter_points_from_dict and color_pcd_with_labels are not
# imported in the cells shown here, and index 5 must exist in
# point_to_label_reprojections before this can be re-enabled.
#pcd_test= filter_points_from_dict(np.array(pcd.points), point_to_label_reprojections[5])
#pcd_test_o3d = color_pcd_with_labels(pcd_test, point_to_label_reprojections[5])
#o3d.visualization.draw_geometries([pcd_test_o3d])

In [46]:
# Reduce the association matrix to a low-dimensional dense embedding before
# clustering.
# `n_components` previously held 75 but was never passed on; the estimator was
# built with a hard-coded 68. The variable now carries the value actually used
# so the two cannot drift apart.
n_components = 68
# A fixed random_state makes the default randomized solver reproducible
# between runs.
tSVD = TruncatedSVD(n_components=n_components, random_state=0)
transformed_data = tSVD.fit_transform(association_matrix)

In [47]:
# Birch is fastest, results are also quite nice
birch_model = Birch(threshold=0.1, n_clusters=67)
birch_model.fit(transformed_data)
labels = birch_model.predict(transformed_data)

# kmeans is quite fast, results look okay, some noise
#kmeans = KMeans(n_clusters=80)
#labels = kmeans.fit_predict(transformed_data)

# DBSCAN doesn't really produce great results and takes really long
#dbscan = DBSCAN(eps=0.5, min_samples=100)
#labels = dbscan.fit_predict(transformed_data)

# Meanshift takes way too long
#meanshift = MeanShift()
#labels = meanshift.fit_predict(transformed_data)

In [48]:
# Report how many distinct cluster ids were actually assigned.
n_clusters_found = len(set(labels))
print("Cluster Assignments:")
print(n_clusters_found)

Cluster Assignments:
67


In [49]:
# Give every cluster its own RGB colour from the rainbow colormap.
# Labels are first mapped to contiguous indices 0..k-1: the raw cluster ids are
# not guaranteed to be gap-free (a clusterer may leave some ids unused, or use
# -1 for noise), and indexing the k-row colour table with the raw ids would
# then raise an IndexError or silently wrap around. For ids that already are
# exactly 0..k-1 the resulting colours are unchanged.
unique_labels, label_indices = np.unique(labels, return_inverse=True)
colorspace = plt.cm.rainbow(np.linspace(0, 1, len(unique_labels)))[:, :3]  # drop alpha channel
colors = colorspace[label_indices.ravel()]
pcd.colors = o3d.utility.Vector3dVector(np.array(colors))

In [50]:
# Save the coloured point cloud as "pcd_merge_clustered.pcd" (relative path),
# written non-ASCII and uncompressed. The call's return value is shown as the
# cell output.
o3d.io.write_point_cloud("pcd_merge_clustered.pcd", pcd, write_ascii=False, compressed=False, print_progress=False)

True