<a href="https://colab.research.google.com/github/S1ink/Colabs/blob/main/models/pc_transformation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Open3D provides o3d.io.read_point_cloud, used below to load the .pcd point clouds.
!pip install open3d

In [None]:
import os
import shutil
import json
import random
import numpy as np
import pandas as pd
from tqdm import tqdm
from glob import glob
from scipy.spatial.transform import Rotation as R
from google.colab import files
from google.colab import drive

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import utils

import matplotlib.pyplot as plt
import open3d as o3d
import plotly.graph_objects as go
import plotly.express as px

In [None]:
# Upload one or more archives through the Colab file picker and extract each one
# into the current working directory (shutil.unpack_archive's default target).
content = files.upload()
pnames = list(content)
for archive_name in tqdm(pnames):
    shutil.unpack_archive(archive_name)
    print("\nUnpacked {}!".format(archive_name))

Saving 566215_266254_2023-09-18 11_53_31.878.tar.gz to 566215_266254_2023-09-18 11_53_31.878.tar.gz


100%|██████████| 1/1 [00:00<00:00, 65.05it/s]


Unpacked 566215_266254_2023-09-18 11_53_31.878.tar.gz!





In [None]:
# Mount Google Drive; the dataset path configured further down reads from ./gdrive/...
GDRIVE_MOUNT_POINT = "/content/gdrive"
drive.mount(GDRIVE_MOUNT_POINT)

Mounted at /content/gdrive


In [None]:
# Sanity check of the working directory (the Drive mount appears here as `gdrive`).
!ls -l

total 8
drwx------ 6 root root 4096 Sep 24 03:27 gdrive
drwxr-xr-x 1 root root 4096 Sep 21 13:49 sample_data


In [None]:
# Dataset root (must end with "/"): point it at the extracted upload or the mounted Drive copy.
DS_BASE = "./gdrive/MyDrive/CSM/shoes-grass-dataset/"  # << SET THIS BASED ON IMPORTED FILES/CONNECTED GDRIVE FILE STRUCTURE
KEY_MAP_JSON = f"{DS_BASE}key_id_map.json"
META_JSON = f"{DS_BASE}meta.json"


def _annotation(fname):
    """Path of `fname` inside the dataset's ds0/ann/ (annotation JSON) directory."""
    return f"{DS_BASE}ds0/ann/{fname}"


def _source(fname):
    """Path of `fname` inside the dataset's ds0/pointcloud/ directory."""
    return f"{DS_BASE}ds0/pointcloud/{fname}"


def _basename(path):
    """Last '/'-separated component of `path`, truncated at its FIRST '.'.

    e.g. 'a/b/scan.pcd.json' -> 'scan'. A stem that itself contains '.' is cut short too.
    """
    _, _, leaf = path.rpartition("/")
    stem, _, _ = leaf.partition(".")
    return stem


# Project-level files: the key/id map and the class metadata (title + color per class).
with open(KEY_MAP_JSON) as json_file:
    keymap = json.load(json_file)

with open(META_JSON) as json_file:
    meta = json.load(json_file)   # metadata for all the classes and their colors

# DEFAULT_LABEL is the class given to points that fall outside every annotated cuboid.
DEFAULT_LABEL = "none"
class_colors = {DEFAULT_LABEL: "gray"}
ALL_LABELS = set()
ALL_LABELS.add(DEFAULT_LABEL)
for cls in meta["classes"]:
    title = cls["title"]
    class_colors[title] = cls["color"]
    ALL_LABELS.add(title)

# Per-cloud accumulators, filled by the annotation pass that follows.
point_sets = []         # xyz points of each point cloud
point_set_labels = []   # class title of each point, per point cloud
point_set_1hot = []     # one-hot encoding of each point's class, per point cloud

print(class_colors, ALL_LABELS)

# Label every point of every annotated cloud with the class of the cuboid that contains it.
# For each ds0/ann/*.pcd.json whose matching ds0/pointcloud/*.pcd exists, this appends:
#   point_sets[k]       -> (N, 3) xyz array read with Open3D
#   point_set_labels[k] -> list of N class titles
#   point_set_1hot[k]   -> (N, len(ALL_LABELS)) one-hot array whose columns follow LABELS

# sorted() pins the one-hot column order. list(ALL_LABELS) followed set iteration order,
# which depends on str hash randomization and can differ between kernel sessions.
LABELS = sorted(ALL_LABELS)
_LABEL_INDEX = {label: i for i, label in enumerate(LABELS)}


def _cuboid_bounds(a_data):
    """Collect the oriented-box label regions of one annotation dict.

    Returns a list of ``(center, half_extents, axes, class_label)`` tuples, one per
    ``cuboid_3d`` figure in file order (other geometry types are ignored). ``axes`` is a
    3x3 array whose rows are the box's local x/y/z unit vectors rotated into the cloud frame.
    Raises KeyError if a figure's objectKey matches no object.
    """
    # objectKey -> classTitle, keeping the FIRST object with a given key
    key_to_class = {}
    for obj in a_data["objects"]:
        key_to_class.setdefault(obj["key"], obj["classTitle"])

    bounds = []
    for figure in a_data["figures"]:
        if figure["geometryType"] != "cuboid_3d":
            continue
        geo = figure["geometry"]
        # read components by name rather than relying on the JSON key order
        center = np.array([geo["position"][k] for k in "xyz"], dtype=float)
        half_extents = np.array([geo["dimensions"][k] for k in "xyz"], dtype=float) / 2
        # NOTE(review): the x/y/z rotation triple is treated as a rotation vector; if the
        # export stores Euler angles this only agrees for rotation about a single axis.
        rot = R.from_rotvec([geo["rotation"][k] for k in "xyz"])
        axes = rot.apply(np.eye(3))   # rows: rotated unit x, y, z
        bounds.append((center, half_extents, axes, key_to_class[figure["objectKey"]]))
    return bounds


def _label_points(points, label_bounds):
    """Class title per point: the first cuboid (in list order) containing it, else DEFAULT_LABEL.

    A point is inside when |offset . axes[k]| <= half_extents[k] for all three local axes
    (boundaries inclusive). Returns a plain list of str, one entry per point.
    """
    labels = np.full(len(points), DEFAULT_LABEL, dtype=object)
    unassigned = np.ones(len(points), dtype=bool)
    for center, half_extents, axes, class_label in label_bounds:
        local = (points - center) @ axes.T   # column k = offset projected onto axes[k]
        inside = unassigned & np.all(np.abs(local) <= half_extents, axis=1)
        labels[inside] = class_label
        unassigned &= ~inside   # first matching cuboid wins
    return labels.tolist()


annotations = sorted(glob(_annotation("*.pcd.json")))   # sorted -> stable cloud order across runs
for a in tqdm(annotations, "Files: "):
    source = _source(_basename(a) + ".pcd")
    if not os.path.exists(source):
        # report instead of skipping silently so a missing/misnamed cloud gets noticed
        tqdm.write("Skipping {}: no point cloud at {}".format(a, source))
        continue
    with open(a) as json_file:
        a_data = json.load(json_file)    # objects + figures (cuboid bounds) for this cloud

    pcd = o3d.io.read_point_cloud(source)
    pc_data = np.asarray(pcd.points)     # all points as an (N, 3) numpy array

    set_labels = _label_points(pc_data, _cuboid_bounds(a_data))
    set_1hot = tf.keras.utils.to_categorical(
        [_LABEL_INDEX[label] for label in set_labels], num_classes=len(ALL_LABELS)
    )

    point_sets.append(pc_data)
    point_set_labels.append(set_labels)
    point_set_1hot.append(set_1hot)

for i in tqdm(range(len(point_sets))):
    print(np.array(point_sets[i]).shape, np.array(point_set_1hot[i]).shape, np.array(point_set_labels[i]).shape)


{'none': 'gray', 'Shoe': '#50E3C2'} {'Shoe', 'none'}


Files:   0%|          | 0/13 [00:00<?, ?it/s]
Point Cloud[0]:   0%|          | 0/26005 [00:00<?, ?it/s][A
Point Cloud[0]:   5%|▍         | 1282/26005 [00:00<00:01, 12815.40it/s][A
Point Cloud[0]:  10%|▉         | 2564/26005 [00:00<00:02, 11002.82it/s][A
Point Cloud[0]:  15%|█▍        | 3828/26005 [00:00<00:01, 11686.94it/s][A
Point Cloud[0]:  19%|█▉        | 5012/26005 [00:00<00:01, 11404.90it/s][A
Point Cloud[0]:  24%|██▎       | 6162/26005 [00:00<00:01, 10674.80it/s][A
Point Cloud[0]:  29%|██▉       | 7530/26005 [00:00<00:01, 11615.37it/s][A
Point Cloud[0]:  33%|███▎      | 8707/26005 [00:00<00:01, 10402.84it/s][A
Point Cloud[0]:  38%|███▊      | 9811/26005 [00:00<00:01, 10581.31it/s][A
Point Cloud[0]:  42%|████▏     | 10891/26005 [00:01<00:01, 10146.63it/s][A
Point Cloud[0]:  46%|████▌     | 11922/26005 [00:01<00:01, 9726.76it/s] [A
Point Cloud[0]:  50%|████▉     | 12907/26005 [00:01<00:01, 9748.09it/s][A
Point Cloud[0]:  55%|█████▍    | 14209/26005 [00:01<00:01, 10672.8

(26005, 3) (26005, 2) (26005,)
(22696, 3) (22696, 2) (22696,)
(14028, 3) (14028, 2) (14028,)
(30738, 3) (30738, 2) (30738,)
(37981, 3) (37981, 2) (37981,)
(26588, 3) (26588, 2) (26588,)
(27960, 3) (27960, 2) (27960,)
(53990, 3) (53990, 2) (53990,)
(28928, 3) (28928, 2) (28928,)
(14749, 3) (14749, 2) (14749,)
(27801, 3) (27801, 2) (27801,)
(37981, 3) (37981, 2) (37981,)
(22273, 3) (22273, 2) (22273,)





In [None]:
def visualize_data(pc, labels, title=None):
    """Show a static matplotlib 3-D scatter of one point cloud, one colored series per class.

    Args:
        pc: (N, 3) array-like of x, y, z coordinates.
        labels: length-N sequence of class titles; colors come from ``class_colors``.
        title: optional axes title (defaults to None, so existing calls are unchanged).

    All three axes get the same span (the largest x/y/z extent), measured from each
    axis's minimum, so the cloud is drawn without stretching.
    """
    pc = np.asarray(pc)
    lo = pc.min(axis=0)
    span = np.max(pc.max(axis=0) - lo)
    df = pd.DataFrame({"x": pc[:, 0], "y": pc[:, 1], "z": pc[:, 2], "label": labels})

    fig = plt.figure(figsize=(15, 10))
    ax = fig.add_subplot(projection="3d")
    ax.set_xlim3d(left=lo[0], right=(lo[0] + span))
    ax.set_ylim3d(bottom=lo[1], top=(lo[1] + span))
    ax.set_zlim3d(bottom=lo[2], top=(lo[2] + span))
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_zlabel("z")
    if title is not None:
        ax.set_title(title)

    # sorted() gives a stable legend order; iterating the set directly depends on str hashing
    for label in sorted(ALL_LABELS):
        c_df = df[df["label"] == label]
        if c_df.empty:
            # class absent from this cloud: skip it explicitly instead of swallowing
            # whatever exception scatter() might raise
            continue
        ax.scatter(
            c_df["x"], c_df["y"], c_df["z"], label=label, alpha=0.5, c=class_colors[label]
        )
    ax.legend()
    plt.show()

# One static scatter per labeled cloud.
n_clouds = len(point_sets)
for i in tqdm(range(n_clouds)):
    cloud, cloud_labels = point_sets[i], point_set_labels[i]
    visualize_data(cloud, cloud_labels)

In [None]:
def visualize_rotate(points, labels):
    """Build an animated Plotly 3-D scatter of a labeled point cloud.

    Each point is colored with ``class_colors[label]``. The layout has a "Play" button
    whose frames orbit the camera about the z axis, starting from eye (1.25, 1.25, 0.8)
    and stepping through angles 0, -0.1, ..., -10.2 rad.

    Args:
        points: (N, 3) array-like of x, y, z coordinates.
        labels: length-N sequence of class titles present in ``class_colors``.

    Returns:
        The plotly Figure (not displayed; call ``.show()`` or leave it as the cell result).
    """
    colors = list(map(class_colors.__getitem__, labels))
    xs, ys, zs = np.array(points).T
    eye_x, eye_y, eye_z = 1.25, 1.25, 0.8

    def orbit_z(px, py, pz, angle):
        # rotate (px, py) by `angle` in the xy-plane via complex multiplication
        spun = np.exp(1j * angle) * (px + 1j * py)
        return np.real(spun), np.imag(spun), pz

    def camera_frame(step):
        fx, fy, fz = orbit_z(eye_x, eye_y, eye_z, -step)
        return dict(layout=dict(scene=dict(camera=dict(eye=dict(x=fx, y=fy, z=fz)))))

    frames = [camera_frame(step) for step in np.arange(0, 10.26, 0.1)]

    play_button = dict(
        label='Play',
        method='animate',
        args=[
            None,
            dict(
                frame=dict(duration=50, redraw=True),
                transition=dict(duration=0),
                fromcurrent=True,
                mode='immediate',
            ),
        ],
    )
    scatter = go.Scatter3d(
        x=xs, y=ys, z=zs,
        mode='markers',
        opacity=0.6,
        marker=dict(size=5, color=colors),
    )
    layout = go.Layout(
        updatemenus=[
            dict(
                type='buttons',
                showactive=False,
                y=1,
                x=0.8,
                xanchor='left',
                yanchor='bottom',
                pad=dict(t=45, r=10),
                buttons=[play_button],
            )
        ],
        scene_aspectmode='data',
    )
    return go.Figure(data=scatter, layout=layout, frames=frames)

# Animated preview of one labeled cloud; change preview_idx to inspect a different one.
preview_idx = 10
visualize_rotate(point_sets[preview_idx], point_set_labels[preview_idx]).show()


In [None]:
# Write each cloud's per-point class titles as comma-separated text to
# ds0/ann/all_labels/<basename>.txt.
LABEL_POINT_PATH = _annotation("all_labels/")
os.makedirs(LABEL_POINT_PATH, exist_ok=True)  # tofile() does not create missing directories

# point_set_labels only holds clouds whose .pcd existed (the labeling pass skips the rest),
# so pair it with the same filtered list of annotation paths. Indexing `annotations`
# directly would write labels under the wrong file names after any skip.
labeled_annotations = [a for a in annotations if os.path.exists(_source(_basename(a) + ".pcd"))]
assert len(labeled_annotations) == len(point_set_labels), "annotation/label list mismatch"

for ann_path, ls in zip(labeled_annotations, tqdm(point_set_labels)):
    f = LABEL_POINT_PATH + _basename(ann_path) + ".txt"
    np.array(ls).tofile(f, sep=',', format='%s')


100%|██████████| 13/13 [00:00<00:00, 15.39it/s]


In [None]:
# Resample every cloud to exactly STATIC_COUNT points (with matching one-hot rows and
# labels), then center it on its centroid and scale it so the farthest point has norm 1.
STATIC_COUNT = 25000
SAMPLE_SEED = 0   # fixed seed so the resampling is reproducible across kernel restarts

_rng = np.random.default_rng(SAMPLE_SEED)
for index in tqdm(range(len(point_sets))):
    current_point_cloud = np.asarray(point_sets[index])
    current_label_cloud = np.asarray(point_set_1hot[index])
    current_labels = np.asarray(point_set_labels[index])
    num_points = len(current_point_cloud)
    if num_points == 0:
        raise ValueError(
            "point cloud {} is empty; cannot resample to {} points".format(index, STATIC_COUNT)
        )
    # Large enough: draw STATIC_COUNT distinct indices. Too small: repeat every point
    # STATIC_COUNT // num_points times, then top up with distinct random indices.
    if num_points < STATIC_COUNT:
        n = STATIC_COUNT // num_points
        sampled_indices = np.concatenate([
            np.tile(np.arange(num_points), n),
            _rng.choice(num_points, STATIC_COUNT % num_points, replace=False),
        ])
    else:
        sampled_indices = _rng.choice(num_points, STATIC_COUNT, replace=False)
    # The same indices select points, one-hot rows and labels, so they stay aligned.
    sampled_point_cloud = current_point_cloud[sampled_indices]
    sampled_label_cloud = current_label_cloud[sampled_indices]
    sampled_labels = current_labels[sampled_indices]
    # Normalize: centroid at the origin, farthest point at distance 1.
    norm_point_cloud = sampled_point_cloud - np.mean(sampled_point_cloud, axis=0)
    max_radius = np.max(np.linalg.norm(norm_point_cloud, axis=1))
    if max_radius > 0:   # all-identical points would otherwise divide by zero -> NaNs
        norm_point_cloud /= max_radius
    point_sets[index] = norm_point_cloud
    point_set_1hot[index] = sampled_label_cloud
    point_set_labels[index] = sampled_labels

100%|██████████| 13/13 [00:01<00:00,  7.25it/s]


In [None]:
# After resampling, every row should read (STATIC_COUNT, 3) (STATIC_COUNT, n_classes) (STATIC_COUNT,).
for i in tqdm(range(len(point_sets))):
    shapes = (point_sets[i].shape, point_set_1hot[i].shape, point_set_labels[i].shape)
    print(*shapes)

visualize_rotate(point_sets[10], point_set_labels[10])

100%|██████████| 13/13 [00:00<00:00, 17894.96it/s]


(25000, 3) (25000, 2) (25000,)
(25000, 3) (25000, 2) (25000,)
(25000, 3) (25000, 2) (25000,)
(25000, 3) (25000, 2) (25000,)
(25000, 3) (25000, 2) (25000,)
(25000, 3) (25000, 2) (25000,)
(25000, 3) (25000, 2) (25000,)
(25000, 3) (25000, 2) (25000,)
(25000, 3) (25000, 2) (25000,)
(25000, 3) (25000, 2) (25000,)
(25000, 3) (25000, 2) (25000,)
(25000, 3) (25000, 2) (25000,)
(25000, 3) (25000, 2) (25000,)
