In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
import os
import particle_builder as pb
import sam3d as s3d
import open3d as o3d
from psdframe import Frame
from psdstaticdataset import StaticDataset
from pathlib import Path
from matplotlib import pyplot as plt
from segment_anything import build_sam, SamAutomaticMaskGenerator
from util import Voxelize, num_to_natural
from mesh_to_gaussians import batch_triangles_to_splats, create_aligned_ellipsoid, splats_to_oriented_discs
import numpy as np

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [3]:
# Dataset and scratch-output locations. The defaults are the original author's paths;
# set the environment variables to run elsewhere without editing this cell.
DS_PATH = Path(os.environ.get(
    "EG_STATIC_DATASET_DIR",
    "/home/david/projects/embodied_gaussians/datasets/simulated/single_1/modelling/static",
))
intermediate_outputs_path = Path(os.environ.get(
    "EG_INTERMEDIATE_OUTPUTS_DIR",
    "/home/david/projects/SegmentAnything3D/outputs/notebook",
))

# Fail fast with an actionable message instead of deep inside StaticDataset.
transforms_path = DS_PATH / "transforms.json"
if not transforms_path.is_file():
    raise FileNotFoundError(
        f"{transforms_path} not found; set EG_STATIC_DATASET_DIR to the static dataset directory"
    )

d = StaticDataset(transforms_path)
frames = d.frames
mask_generator = SamAutomaticMaskGenerator(build_sam(checkpoint=pb.sam_checkpoint).to(device="cuda"))
voxelize = Voxelize(voxel_size=pb.VOXEL_SIZE, mode="train", keys=("coord", "color", "group", "normals"))

# Segment the dataset's frames into one point cloud. pcd_dict is indexed by 'coord' and
# 'color' below (presumably also 'group'/'normals', per the voxelize keys — confirm).
pcd_dict = pb.seg_pcd(d, mask_generator, voxelize, intermediate_outputs_path)

# Quick visual check of the merged, segmented cloud.
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(pcd_dict['coord'])
pcd.colors = o3d.utility.Vector3dVector(pcd_dict['color'])
o3d.visualization.draw_geometries([pcd])

merging 6 point clouds


  idx = torch.cuda.IntTensor(m, nsample).zero_()


merging 3 point clouds
merging 2 point clouds


In [4]:
# Turn the segmented point cloud into a list of meshes (pb.get_object_meshes) and view them together.
meshes = pb.get_object_meshes(pcd_dict, d)
o3d.visualization.draw_geometries(meshes, mesh_show_back_face=True)


          Initialize
          Found bad data: 228




In [5]:
# Take the first mesh for a closer look.
# NOTE(review): index 0 depends on the order pb.get_object_meshes returns meshes in — confirm it is the T-block.
tblock = meshes[0]
# Populate tblock.triangle_normals; the next cell reads them.
tblock.compute_triangle_normals()
o3d.visualization.draw_geometries([tblock], mesh_show_back_face=True)

In [6]:
# Gather the three corner positions of every face, shape (F, 3, 3), plus one normal per face, shape (F, 3).
triangles = np.asarray(tblock.vertices)[np.asarray(tblock.triangles)]
triangle_normals = np.asarray(tblock.triangle_normals)

# Guards against stale state: triangle_normals is empty unless compute_triangle_normals() ran.
assert triangles.shape[0] == triangle_normals.shape[0], (
    "triangle normals missing or stale — run tblock.compute_triangle_normals() first"
)

triangles.shape, triangle_normals.shape

((15680, 3, 3), (15680, 3))

In [7]:
# One splat per triangle (the shape check below shows 15680 in, 15680 out).
# The torch.cross deprecation warning appears to be raised inside mesh_to_gaussians, not by this cell.
splats = batch_triangles_to_splats(triangles, triangle_normals)

Please either pass the dim explicitly or simply use torch.linalg.cross.
The default value of dim will change to agree with that of linalg.cross in a future release. (Triggered internally at ../aten/src/ATen/native/Cross.cpp:62.)
  t = 2.0 * torch.cross(q_xyz, v)


In [8]:
# Expect one center (3), scale (3) and rotation (4) per input triangle.
splats['centers'].shape, splats['scales'].shape, splats['rotations'].shape

((15680, 3), (15680, 3), (15680, 4))

In [10]:
# Render each splat as an oriented disc mesh to eyeball splat placement and orientation.
discs = splats_to_oriented_discs(splats['centers'], splats['scales'], splats['rotations'])
o3d.visualization.draw_geometries(discs, mesh_show_back_face=True)

In [18]:
# Sanity check on a simple shape: a cube whose face normals are exactly +/-x, +/-y, +/-z.
BOX_SIZE = 0.5
# Set to 0, 1 or 2 to keep only faces whose normal points along +x / +y / +z; None keeps every face.
# FIXME(review): the original note said the y axis "looks borked" — use this filter to
# isolate one face direction when debugging splat orientation.
KEEP_FACES_ALONG_AXIS = None

box = o3d.geometry.TriangleMesh.create_box(width=BOX_SIZE, height=BOX_SIZE, depth=BOX_SIZE)
box.compute_triangle_normals()
triangles = np.asarray(box.vertices)[np.asarray(box.triangles)]
triangle_normals = np.asarray(box.triangle_normals)

if KEEP_FACES_ALONG_AXIS is not None:
    # isclose rather than == so the filter does not depend on exact float equality.
    face_mask = np.isclose(triangle_normals[:, KEEP_FACES_ALONG_AXIS], 1.0)
    triangle_normals = triangle_normals[face_mask]
    triangles = triangles[face_mask]

splats = batch_triangles_to_splats(triangles, triangle_normals)
discs = splats_to_oriented_discs(splats['centers'], splats['scales'], splats['rotations'])

o3d.visualization.draw_geometries([box, *discs], mesh_show_back_face=True)

KeyboardInterrupt: 

In [55]:
# Inspect the disc meshes for the current `splats` without drawing them (displayed, not stored).
splats_to_oriented_discs(splats['centers'], splats['scales'], splats['rotations'])

[TriangleMesh with 33 points and 32 triangles.,
 TriangleMesh with 33 points and 32 triangles.]

In [53]:
# Per-splat scales; the third component is tiny (1e-4 in the output), i.e. each splat is flattened along one axis.
splats['scales']

array([[3.53553391e-01, 1.76776695e-01, 1.00000000e-04],
       [3.53553391e-01, 1.76776695e-01, 1.00000000e-04]])

In [50]:
# NOTE(review): the output below shows only two +y normals, but the In [18] cell keeps every box face
# (its mask is commented out) — this output apparently reflects a run not shown here. Re-run from the top.
triangle_normals

array([[0., 1., 0.],
       [0., 1., 0.]])

In [49]:
# Per-splat quaternions. The identity [1, 0, 0, 0] returned for +z in In [46] indicates scalar-first (w, x, y, z) order.
splats['rotations']

array([[ 0.7071068, -0.7071068,  0.       ,  0.       ],
       [ 0.7071068, -0.7071068,  0.       ,  0.       ]], dtype=float32)

In [39]:
# Box face normals with zero x component (the +/-y and +/-z faces, per the output).
# Exact == is fine here because the box normals are exact 0/+-1 values.
mask = triangle_normals[:, 0] == 0
triangle_normals[mask]

array([[ 0.,  1.,  0.],
       [ 0.,  1.,  0.],
       [ 0., -1.,  0.],
       [ 0., -1.,  0.],
       [ 0.,  0.,  1.],
       [ 0.,  0.,  1.],
       [ 0.,  0., -1.],
       [ 0.,  0., -1.]])

In [44]:
# Which rotation does create_aligned_ellipsoid pick for +x? Compare with normal_to_quaternion further down.
# (create_aligned_ellipsoid is imported in the imports cell rather than mid-notebook.)
create_aligned_ellipsoid(np.array([[1,0,0]]))


array([[0.7071068, 0.       , 0.7071068, 0.       ]], dtype=float32)

In [45]:
# +y normal: the result matches normal_to_quaternion([0, 1, 0]) in In [51].
create_aligned_ellipsoid(np.array([[0,1,0]]))

array([[ 0.7071068, -0.7071068,  0.       ,  0.       ]], dtype=float32)

In [46]:
# +z normal: expect the identity quaternion.
create_aligned_ellipsoid(np.array([[0,0,1]]))

array([[1., 0., 0., 0.]], dtype=float32)

In [47]:
# -y normal: compare with the +y result above (the sign of the x component flips).
create_aligned_ellipsoid(np.array([[0,-1,0]]))

array([[ 0.7071068,  0.7071068,  0.       , -0.       ]], dtype=float32)

In [51]:

def normal_to_quaternion(normal):
    """
    Convert a normal vector to the rotation quaternion that rotates [0, 0, 1] onto it.

    Args:
        normal: Array-like holding three numbers [x, y, z]; need not be unit length.
            Shapes (3,), (1, 3) and (3, 1) are all accepted.

    Returns:
        quaternion: float64 numpy array [w, x, y, z] (scalar first, unit length).

    Raises:
        ValueError: if ``normal`` has zero or non-finite length.
    """
    normal = np.asarray(normal, dtype=float).reshape(3)
    length = np.linalg.norm(normal)
    if not np.isfinite(length) or length == 0.0:
        # Normalising would otherwise silently produce a NaN quaternion.
        raise ValueError(f"normal must be a non-zero, finite 3-vector, got {normal}")
    normal = normal / length

    # Reference direction the quaternion rotates *from*.
    up = np.array([0.0, 0.0, 1.0])

    axis = np.cross(up, normal)

    # Parallel / anti-parallel: the cross product vanishes, so the axis is undefined.
    if np.allclose(axis, 0):
        if np.allclose(normal, up):
            return np.array([1.0, 0.0, 0.0, 0.0])  # Identity quaternion
        # normal is (close to) -up: any axis perpendicular to up works; use X for a 180° turn.
        return np.array([0.0, 1.0, 0.0, 0.0])

    axis = axis / np.linalg.norm(axis)
    # Clip keeps arccos in its domain if rounding pushes the dot product just past +/-1.
    angle = np.arccos(np.clip(np.dot(up, normal), -1.0, 1.0))

    # Axis-angle -> quaternion.
    half_angle = angle / 2.0
    w = np.cos(half_angle)
    xyz = axis * np.sin(half_angle)

    return np.array([w, xyz[0], xyz[1], xyz[2]])

# Example usage for [0,1,0]: the printed quaternion matches create_aligned_ellipsoid([[0, 1, 0]]) from In [45].
normal = [0, 1, 0]
quaternion = normal_to_quaternion(normal)
print(f"Normal {normal} converted to quaternion [w,x,y,z]: {quaternion}")

Normal [0, 1, 0] converted to quaternion [w,x,y,z]: [ 0.70710678 -0.70710678  0.          0.        ]


In [22]:
# Repeat segmentation and meshing (see the log) and build the scene objects in one call.
# NOTE(review): the second argument is None here, whereas pb.seg_pcd above received mask_generator —
# confirm what initialize_scene does with None.
result = pb.initialize_scene(d, None, intermediate_outputs_path)

merging 6 point clouds
merging 3 point clouds
merging 2 point clouds


          Initialize
          Found bad data: 228
Mesh 1 is not clean and watertight, ignoring


Dims: [15 15  3]
Dims: [21 17 44]
Dims: [13  5  2]


In [24]:
# Number of scene objects produced (the log above reports that one mesh was ignored as not watertight).
len(result.objects)

3

In [26]:
# Inspect the first object's Gaussians and particles (index 0 chosen arbitrarily).
gaussians = result.objects[0].gaussians
particles = result.objects[0].particles

In [27]:
# Visualise particles: one sphere per particle, sized by its radius and translated to its xyz.
spheres = []
for particle in particles:
    sphere = o3d.geometry.TriangleMesh.create_sphere(radius=particle.radius)
    sphere.translate(particle.xyz)
    spheres.append(sphere)
o3d.visualization.draw_geometries(spheres, mesh_show_back_face=True)

In [30]:
# Draw the object's Gaussians as coloured oriented discs.
# NOTE(review): resolution=12 presumably sets the disc tessellation (coarser than the
# 33-point discs produced with the default earlier) — confirm in splats_to_oriented_discs.
discs = splats_to_oriented_discs(gaussians.xyz, gaussians.scaling, gaussians.rotations, gaussians.colors, resolution=12)
o3d.visualization.draw_geometries(discs, mesh_show_back_face=True)