In [100]:
import os
import json
import numpy as np
from typing import Dict, List, Tuple, Optional
import torch
import open3d as o3d
from collections import defaultdict

In [101]:
def format_scan_dict(unformated_dict: Dict, attribute: str) -> Dict:
    """Index the per-scan entries of a 3DSSG JSON payload by scan id.

    Args:
        unformated_dict: Parsed JSON whose ``"scans"`` entry is a list of dicts,
            each carrying a ``"scan"`` id and the requested ``attribute``.
        attribute: Key to pull out of every scan entry (e.g. ``"objects"``).

    Returns:
        Dict mapping scan id -> ``entry[attribute]``. If a scan id repeats, the
        last entry wins.
    """
    return {entry["scan"]: entry[attribute] for entry in unformated_dict["scans"]}


def format_sem_seg_dict(sem_seg_dict: Dict) -> List[Dict]:
    """Return the segment-group list from a parsed ``semseg.v2.json`` payload.

    Args:
        sem_seg_dict: Parsed semantic-segmentation JSON containing ``"segGroups"``.

    Returns:
        The ``"segGroups"`` list itself (not a copy). Callers that mutate its
        entries therefore also mutate ``sem_seg_dict``.
    """
    return sem_seg_dict["segGroups"]


def get_dataset_files(scan_dir):
    """Load 3DSSG object and relationship annotations, indexed by scan id.

    Args:
        scan_dir: Dataset root containing a ``3DSSG`` folder with
            ``objects.json`` and ``relationships.json``.

    Returns:
        ``(objects_dict, relationships_dict)``: each maps a scan id to that
        scan's ``"objects"`` / ``"relationships"`` entry.
    """
    ssg_dir = os.path.join(scan_dir, "3DSSG")
    # Context managers close the JSON files deterministically instead of
    # leaving the handles to the garbage collector.
    with open(os.path.join(ssg_dir, "objects.json")) as f:
        object_data = json.load(f)
    with open(os.path.join(ssg_dir, "relationships.json")) as f:
        relationship_data = json.load(f)
    objects_dict = format_scan_dict(object_data, "objects")
    relationships_dict = format_scan_dict(relationship_data, "relationships")
    return objects_dict, relationships_dict

def get_semseg(scan_dir, scan_id):
    """Load a scan's ``semseg.v2.json`` and return its segment-group list.

    Args:
        scan_dir: Dataset root containing ``semantic_segmentation_data``.
        scan_id: Scan folder name.

    Returns:
        The ``"segGroups"`` list of the file (see ``format_sem_seg_dict``).
    """
    sem_seg_path = os.path.join(scan_dir, "semantic_segmentation_data", scan_id, "semseg.v2.json")
    # Close the file deterministically rather than leaking the handle.
    with open(sem_seg_path) as f:
        sem_seg_data = json.load(f)
    return format_sem_seg_dict(sem_seg_data)

def get_annotated_ply(scan_dir, scan_id):
    """Read a scan's instance-annotated PLY with Open3D.

    Args:
        scan_dir: Dataset root containing ``semantic_segmentation_data``.
        scan_id: Scan folder name.

    Returns:
        Whatever ``o3d.io.read_point_cloud`` returns for
        ``<scan_dir>/semantic_segmentation_data/<scan_id>/labels.instances.annotated.v2.ply``.
    """
    scan_folder = os.path.join(scan_dir, "semantic_segmentation_data", scan_id)
    return o3d.io.read_point_cloud(os.path.join(scan_folder, "labels.instances.annotated.v2.ply"))


In [102]:
def read_ply_vertices(scan_dir, scan_id):
    """Parse an ASCII annotated PLY and group its vertices by ``objectId``.

    Reads ``<scan_dir>/semantic_segmentation_data/<scan_id>/labels.instances.annotated.v2.ply``.
    Every line after ``end_header`` with exactly 11 whitespace-separated columns
    is treated as a vertex row::

        x y z r g b objectId globalId NYU40 Eigen13 RIO27

    All other lines (e.g. ``3 i j k`` face rows) are ignored.

    Args:
        scan_dir: Dataset root containing ``semantic_segmentation_data``.
        scan_id: Scan folder name.

    Returns:
        ``defaultdict`` keyed by int ``objectId``, in order of first appearance.
        Each value holds:

        - ``"xyz"``: float32 array of shape (N, 3).
        - ``"color"``: uint8 array of shape (N, 3).
        - ``"globalId"`` (uint16) and ``"NYU40"`` / ``"Eigen13"`` / ``"RIO27"``
          (uint8): 1-D arrays of length N, each collapsed to a single scalar
          when all N values are equal.
        - ``"color_hex"``: ``"#rrggbb"``, added only when every vertex of the
          object has the same color. ``"color"`` stays an (N, 3) array.

    Raises:
        ValueError: If the file has no ``end_header`` line.
    """
    filename = os.path.join(scan_dir, "semantic_segmentation_data", scan_id, "labels.instances.annotated.v2.ply")
    label_keys = ("globalId", "NYU40", "Eigen13", "RIO27")  # columns 7..10, in file order
    grouped_data = defaultdict(lambda: {"xyz": [], "color": [], "globalId": [], "NYU40": [], "Eigen13": [], "RIO27": []})

    with open(filename, "r") as f:
        # Stream the file line by line instead of readlines(), so a large scan
        # is not held in memory twice.
        for line in f:
            if line.strip() == "end_header":
                break
        else:
            raise ValueError("Invalid PLY file: No 'end_header' found.")

        for line in f:
            tokens = line.split()
            if len(tokens) != 11:  # not a vertex row
                continue
            # Parse integer columns straight to int. Packing them into one float32
            # array would round ids above 2**24 and merge distinct objects.
            group = grouped_data[int(tokens[6])]
            group["xyz"].append([float(t) for t in tokens[:3]])
            group["color"].append([int(t) for t in tokens[3:6]])
            for key, token in zip(label_keys, tokens[7:]):
                group[key].append(int(token))

    # Convert per-object lists to compact NumPy arrays. Values are assumed to fit
    # the narrow dtypes; they are not range-checked.
    for group in grouped_data.values():
        group["xyz"] = np.array(group["xyz"], dtype=np.float32)
        group["color"] = np.array(group["color"]).astype(np.uint8)
        group["globalId"] = np.array(group["globalId"]).astype(np.uint16)
        for key in ("NYU40", "Eigen13", "RIO27"):
            group[key] = np.array(group[key]).astype(np.uint8)

    # Collapse per-vertex labels to a scalar when they are uniform for the object.
    # Every group has at least one vertex, so indexing [0] is safe.
    for group in grouped_data.values():
        color = group["color"]
        if np.all(color == color[0]):
            red, green, blue = color[0]
            group["color_hex"] = "#{:02x}{:02x}{:02x}".format(red, green, blue)
        for key in label_keys:
            values = group[key]
            if np.all(values == values[0]):
                group[key] = values[0]

    return grouped_data

In [103]:
import numpy as np
import open3d as o3d
from collections import defaultdict

def read_ply_vertices_faces(scan_dir, scan_id):
    """Parse vertices, triangle faces and per-object vertex groups from an ASCII annotated PLY.

    Reads ``<scan_dir>/semantic_segmentation_data/<scan_id>/labels.instances.annotated.v2.ply``.
    The header's ``element vertex N`` / ``element face M`` lines say how many rows
    follow ``end_header``:

    - first N vertex rows with 11 columns
      (``x y z r g b objectId globalId NYU40 Eigen13 RIO27``); rows with any
      other column count are skipped and do not count toward N;
    - then M face rows ``3 i j k``; rows with fewer than 4 columns are skipped.

    Args:
        scan_dir: Dataset root containing ``semantic_segmentation_data``.
        scan_id: Scan folder name.

    Returns:
        ``(grouped_data, xyz, faces)``:

        - ``grouped_data``: ``defaultdict`` keyed by int ``objectId``. Each value
          holds ``"xyz"`` float32 (K, 3), ``"color"`` uint8 (K, 3),
          ``"globalId"`` uint16 (K,), and ``"NYU40"`` / ``"Eigen13"`` /
          ``"RIO27"`` uint8 (K,).
        - ``xyz``: float32 array (N, 3) of all vertex positions, in file order.
        - ``faces``: int32 array (M, 3) of vertex indices.

        Both arrays keep their trailing dimension of 3 even when empty.

    Raises:
        ValueError: If there is no ``end_header`` line, or a face is not a triangle.
    """
    filename = os.path.join(scan_dir, "semantic_segmentation_data", scan_id, "labels.instances.annotated.v2.ply")
    label_keys = ("globalId", "NYU40", "Eigen13", "RIO27")  # columns 7..10, in file order
    grouped_data = defaultdict(lambda: {"xyz": [], "color": [], "globalId": [], "NYU40": [], "Eigen13": [], "RIO27": []})
    vertex_count = 0
    face_count = 0
    positions = []
    faces = []

    with open(filename, "r") as f:
        # Header: record the declared element counts. A missing element leaves its
        # count at 0, so the corresponding rows are never read.
        for line in f:
            if line.startswith("element vertex"):
                vertex_count = int(line.split()[-1])
            elif line.startswith("element face"):
                face_count = int(line.split()[-1])
            elif line.strip() == "end_header":
                break
        else:
            raise ValueError("Invalid PLY file: No 'end_header' found.")

        # Body: vertex rows come first, then face rows.
        for line in f:
            if len(positions) >= vertex_count and len(faces) >= face_count:
                break  # everything declared in the header has been read
            tokens = line.split()
            if len(positions) < vertex_count:
                if len(tokens) != 11:
                    continue
                xyz_row = [float(t) for t in tokens[:3]]
                positions.append(xyz_row)
                # Integer columns are parsed directly (no float32 round-trip that
                # would corrupt ids above 2**24).
                group = grouped_data[int(tokens[6])]
                group["xyz"].append(xyz_row)
                group["color"].append([int(t) for t in tokens[3:6]])
                for key, token in zip(label_keys, tokens[7:]):
                    group[key].append(int(token))
            elif len(faces) < face_count:
                if len(tokens) < 4:
                    continue
                if int(tokens[0]) != 3:  # leading count = number of vertices in the face
                    raise ValueError("Only triangular faces are supported.")
                faces.append([int(t) for t in tokens[1:4]])

    # Convert per-object lists to compact NumPy arrays. Values are assumed to fit
    # the narrow dtypes; they are not range-checked.
    for group in grouped_data.values():
        group["xyz"] = np.array(group["xyz"], dtype=np.float32)
        group["color"] = np.array(group["color"]).astype(np.uint8)
        group["globalId"] = np.array(group["globalId"]).astype(np.uint16)
        for key in ("NYU40", "Eigen13", "RIO27"):
            group[key] = np.array(group[key]).astype(np.uint8)

    # reshape(-1, 3) keeps an (0, 3) shape when the file has no vertices/faces.
    xyz = np.array(positions, dtype=np.float32).reshape(-1, 3)
    faces = np.array(faces, dtype=np.int32).reshape(-1, 3)
    return grouped_data, xyz, faces


In [105]:
def get_semseg_with_pointcloud(scan_dir, scan_id):
    """Attach each PLY object's vertex coordinates to its semantic-segmentation group.

    Loads the scan's segment groups (``get_semseg``) and per-object vertex
    groups (``read_ply_vertices``). It then sets ``seg["xyz"]`` on the first
    segment group whose ``"id"`` equals the PLY ``objectId``. PLY objects without
    a matching segment group are reported on stdout and skipped.

    Args:
        scan_dir: Dataset root containing ``semantic_segmentation_data``.
        scan_id: Scan folder name.

    Returns:
        The segment-group list, with matched entries mutated in place.
    """
    sem_seg_dict = get_semseg(scan_dir, scan_id)
    ply_data = read_ply_vertices(scan_dir, scan_id)

    # Index segment groups by id once instead of rescanning the list for every
    # PLY object. setdefault keeps the first occurrence, as a linear scan would.
    seg_by_id = {}
    for seg in sem_seg_dict:
        seg_by_id.setdefault(seg["id"], seg)

    for object_id, object_data in ply_data.items():
        seg = seg_by_id.get(object_id)
        if seg is None:
            print(f"Object {object_id} not found in semseg dict")
            continue
        seg["xyz"] = object_data["xyz"]

    return sem_seg_dict

In [99]:
# Dataset locations. Set the VSG_DATA_ROOT environment variable to point at a
# different copy of the data without editing the notebook.
root = os.environ.get("VSG_DATA_ROOT", "/mnt/Backup/Dataset/3d_vsg/data")
scan_dir = os.path.join(root, "raw")
with open(os.path.join(scan_dir, "3RScan.json")) as f:
    scans = json.load(f)
objects_dict, relationships_dict = get_dataset_files(scan_dir)


In [88]:
# Load every per-scan artifact for one example scan: its segment groups, the
# Open3D point cloud, the per-object vertex groups parsed from the PLY, and the
# 3DSSG objects / relationships for that scan.
scan_id = "ddc737b3-765b-241a-9c35-6b7662c04fc9"
sem_seg_dict = get_semseg(scan_dir, scan_id)
o3d_ply = get_annotated_ply(scan_dir, scan_id)
ply_data = read_ply_vertices(scan_dir, scan_id)

object_dict = objects_dict[scan_id]
relationship_dict = relationships_dict[scan_id]

# Disabled per-object viewer: for each PLY object id with a 3DSSG entry, print
# its attributes and open an Open3D window showing that object's points.
# # among object_dict, find the element that has the same "id" as the key in ply_data
# # print the attributes of the object
# for key in ply_data:
#     object_id = key
#     matched_obj = [obj for obj in object_dict if int(obj["id"]) == object_id]
#     if len(matched_obj) > 0:
#         matched_obj = matched_obj[0]

#         # print the attributes of the object
#         print(f"Object Attributes: {matched_obj}")

#         # visualize the object with open3d using the xyz and color data
#         pcd = o3d.geometry.PointCloud()
#         pcd.points = o3d.utility.Vector3dVector(ply_data[key]["xyz"])
#         pcd.colors = o3d.utility.Vector3dVector(ply_data[key]["color"] / 255.0)
#         o3d.visualization.draw_geometries([pcd])


In [None]:
# Attach per-object point coordinates to the example scan's segment groups, then
# list the keys each group carries (matched groups gain an "xyz" entry).
sem_seg_dict_with_pc = get_semseg_with_pointcloud(scan_dir, scan_id)
for seg in sem_seg_dict_with_pc:
    print(f"Segment Attributes: {seg.keys()}")

In [107]:
def get_scene_list(scene: Dict, cumulative: bool = True):
    """Return scan ids, transforms and changed-instance lists for one 3RScan scene.

    Args:
        scene: One entry of ``3RScan.json``: a ``"reference"`` scan id plus a
            ``"scans"`` list of rescans. Each rescan has a ``"reference"`` id, an
            optional ``"transform"`` (16 numbers) and an optional ``"rigid"`` list
            whose items are ints or dicts with an ``"instance_reference"``.
        cumulative: If True (default, the historical behaviour), each rescan's
            change list also contains the changes of every earlier rescan in
            ``scene["scans"]``. If False, each list holds only that rescan's own
            changes.
            NOTE(review): if ``"rigid"`` lists are relative to the reference scan,
            ``cumulative=False`` is probably what downstream wants — confirm.

    Returns:
        ``(scan_id_set, scan_tf_set, scan_changes)``: three parallel lists whose
        first element describes the reference scan (identity transform, no changes).

        - ``scan_id_set``: scan ids.
        - ``scan_tf_set``: 4x4 numpy transforms. The 16 values are reshaped
          row-major then transposed (i.e. read column-major); identity when
          ``"transform"`` is absent.
        - ``scan_changes``: lists of changed instance ids.
    """
    scan_id_set = [scene["reference"]]
    scan_tf_set = [np.eye(4)]
    scan_changes = [[]]
    changes = []
    for follow_scan in scene["scans"]:
        scan_id_set.append(follow_scan["reference"])
        if "transform" in follow_scan:
            scan_tf_set.append(np.array(follow_scan["transform"]).reshape((4, 4)).T)
        else:
            scan_tf_set.append(np.eye(4))

        if not cumulative:
            changes = []
        # A rescan without a "rigid" list simply contributes no changes.
        for change in follow_scan.get("rigid", []):
            changes.append(change if isinstance(change, int) else change["instance_reference"])
        # Snapshot the list so later rescans do not alter this entry.
        scan_changes.append(changes.copy())

    return scan_id_set, scan_tf_set, scan_changes

In [None]:
# Inspect the reference scan and rescans of the second scene in 3RScan.json:
# their ids, 4x4 transforms and changed-instance lists.
scan_id_set, scan_tf_set, scan_changes = get_scene_list(scans[1])

print(scan_id_set)
print(scan_tf_set)
print(scan_changes)

In [117]:
import os
import json
import torch
import networkx as nx
from typing import Dict, List, Tuple, Optional


def format_sem_seg_dict_orig(sem_seg_dict: Dict) -> Dict:
    """Map each segment-group id to its oriented-bounding-box centroid.

    Args:
        sem_seg_dict: Parsed ``semseg.v2.json`` with a ``"segGroups"`` list whose
            entries carry ``"id"`` and ``"obb"["centroid"]``.

    Returns:
        ``{id: centroid}``. A repeated id keeps the last group's centroid.
    """
    return {group["id"]: group["obb"]["centroid"] for group in sem_seg_dict["segGroups"]}


def build_scene_graph(nodes_dict: Dict, edges_dict: Dict, scan_id: str, scan_dir: str, graph_out=False) -> Tuple:
    """Build the node list, edge list and (optionally) a networkx graph for one scan.

    Args:
        nodes_dict: Scan id -> list of 3DSSG object dicts (each with an ``"id"``).
        edges_dict: Scan id -> list of relationships; items are indexed as
            ``edge[0]`` / ``edge[1]`` for the two endpoints.
        scan_id: Scan to build.
        scan_dir: Dataset root containing ``semantic_segmentation_data``.
        graph_out: If True, also build an undirected ``nx.Graph`` (for visualization).

    Returns:
        ``(graph, input_node_list, edges)``:

        - ``graph``: ``nx.Graph`` when ``graph_out`` is True, else None.
        - ``input_node_list``: ``[(object_id, attrs), ...]``. ``attrs`` has
          ``label``, ``affordances``, ``attributes``, ``global_id`` and ``color``
          (taken from ``ply_color``). ``attributes`` is a copy of the node's dict
          with ``lexical`` removed and a float32 ``location`` tensor added (the
          segment-group centroid).
        - ``edges``: ``edges_dict[scan_id]`` unchanged.

        Returns ``(None, None, None)`` when the scan is missing from either dict
        or has no ``semseg.v2.json``.

    Raises:
        KeyError: If a node id has no segment group in ``semseg.v2.json``.
    """
    if scan_id not in nodes_dict or scan_id not in edges_dict:
        return None, None, None

    # Object positions come from the semantic-segmentation OBB centroids.
    scan_sem_seg_file = os.path.join(scan_dir, "semantic_segmentation_data", scan_id, "semseg.v2.json")
    if not os.path.isfile(scan_sem_seg_file):
        print(f"No Semantic Segmentation File Available for {scan_id}")
        return None, None, None
    with open(scan_sem_seg_file) as f:
        object_pos_list = format_sem_seg_dict_orig(json.load(f))

    input_node_list = []
    for node in nodes_dict[scan_id]:
        node_id = int(node["id"])
        # Copy the nested attributes dict. A shallow node.copy() would share it,
        # so adding "location" / dropping "lexical" would mutate nodes_dict.
        attributes = dict(node.get("attributes") or {})
        # Clamp centroid coordinates to [-100, 100] (presumably to guard against
        # outlier centroids — TODO confirm).
        attributes["location"] = torch.tensor(np.clip(object_pos_list[node_id], -100, 100)).to(torch.float32)
        attributes.pop("lexical", None)
        att_dict = {
            "label": node.get("label"),
            "affordances": node.get("affordances"),
            "attributes": attributes,
            "global_id": node.get("global_id"),
            "color": node.get("ply_color"),
        }
        input_node_list.append((node_id, att_dict))

    edges = edges_dict[scan_id]

    graph = None
    if graph_out:
        graph = nx.Graph()
        graph.add_nodes_from(input_node_list)
        for edge in edges:
            graph.add_edge(edge[0], edge[1])

    return graph, input_node_list, edges

In [None]:
# Build the example scan's scene graph and draw it. build_scene_graph returns
# (None, None, None) when the scan's data is missing, in which case the calls
# below would fail.
graph, nodes, edges = build_scene_graph(objects_dict, relationships_dict, scan_id, scan_dir, graph_out=True)

# visualize the scene graph
import matplotlib.pyplot as plt
nx.draw(graph, with_labels=True, pos=nx.spring_layout(graph))
plt.show()

# Inspect the attribute dict of the first node.
print(nodes[0][1].keys())
print(nodes[0][1].values())

In [None]:
# Dump every scene entry of 3RScan.json, one key/value per line, with a blank
# line between scenes.
for scan in scans:
    for k, v in scan.items():
        print(k, v)
    print()

In [None]:
# Preview the objects of the first scan only (note the `break`), keeping just
# the fields build_scene_graph reads.
for key, value in objects_dict.items():
    print(key, "total objects:", len(value))
    for obj in value:
        new_value = {}
        for k, v in obj.items():
            if k in {"ply_color", "label", "affordances", "id", "global_id", "attributes"}:
                new_value[k] = v
        print(new_value)
    break

In [None]:
# Print every relationship of the first scan only (note the `break`).
for key, value in relationships_dict.items():
    print(key)
    for rel in value:
        print(rel)
    break

In [None]:
# First scene: its reference scan id and its rescans.
scan = scans[0]
ref_scan_id = scan["reference"]
# NOTE(review): despite the name, this is the list of rescan entries (dicts with
# their own "reference", "transform", "rigid", ...), not bare ids.
rescan_ids = scan["scans"]