In [16]:
import os
import open3d as o3d
import numpy as np
import torch
from tqdm import tqdm

data_dir = "/data/Predict-Pneumoperitoneum_LaB-GATr/pointclouds_and_meshes"
disp_dir = "/data/Predict-Pneumoperitoneum_LaB-GATr/dataset/raw"


def orient_normals_negative_zaxis(normals):
    normals = np.asarray(normals)
    mask = normals[:, 2] > 0
    normals[mask] *= -1
    normals = o3d.utility.Vector3dVector(normals)
    return normals


patients = sorted([os.path.join(data_dir, p) for p in os.listdir(data_dir) if p.startswith("24_")])
for patient in tqdm(patients):
    # Keys: pointcloud_begin (N, 3), input_points (N, 3), displacements (N, 3), target_points (M, 3),
    # patient_features (5), annotations_start (13, 3), annotations_end (13, 3)
    data = torch.load(os.path.join(disp_dir, os.path.basename(patient) + ".pt"))
    begin_mesh = o3d.io.read_triangle_mesh(os.path.join(patient, "filtered_begin_mesh.ply"))
    end_mesh = o3d.io.read_triangle_mesh(os.path.join(patient, "filtered_end_mesh.ply"))

    # Direct all normals such that they face the camera direction (fix inverted normals)
    begin_mesh.vertex_normals = orient_normals_negative_zaxis(begin_mesh.vertex_normals)
    end_mesh.vertex_normals = orient_normals_negative_zaxis(end_mesh.vertex_normals)

    # Check if lengths are correct
    N = len(begin_mesh.vertex_normals)
    M = len(end_mesh.vertex_normals)
    if len(data["input_points"]) != N or len(data["displacements"]) != N or len(data["target_points"]) != M:
        print(f"{N=}, {len(data['input_points'])=}, {len(data['displacements'])}")
        print(f"{M=}, {len(data['target_points'])=}")

    # Write normals to data files
    data["input_normals"] = torch.as_tensor(np.asarray(begin_mesh.vertex_normals))
    data["target_normals"] = torch.as_tensor(np.asarray(end_mesh.vertex_normals))
    torch.save(data, os.path.join(disp_dir, os.path.basename(patient) + ".pt"))
    o3d.io.write_triangle_mesh(os.path.join(patient, "filtered_begin_mesh_normals.ply"), begin_mesh)
    o3d.io.write_triangle_mesh(os.path.join(patient, "filtered_end_mesh_normals.ply"), end_mesh)

 98%|█████████▊| 64/65 [00:08<00:00,  8.11it/s]

N=100875, len(data['input_points'])=100875, 100875
M=107261, len(data['target_points'])=106367


100%|██████████| 65/65 [00:08<00:00,  7.97it/s]
