Here is some sample code to show how to load/use the latents and poses.

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Automatically reload imported modules (e.g. the local `primitives` module
# used further down) before executing each cell, so edits are picked up.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import h5py
import numpy as np
import trimesh

# Prefer the GPU when PyTorch can see one; otherwise everything runs on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load latents and poses
Latent vectors are stored as an $N\times P\times L$ tensor, with $N$ the number of shapes in the dataset, $P$ the number of primitives (parts) per shape, and $L$ the latent dimension.

Similarly, poses are stored as an $N\times P\times 10$ tensor, whose last dimension (of size 10) concatenates:
* rotation (as quaternion) $\in\mathbb{R}^4$
* translation $\in\mathbb{R}^3$
* scale $\in\mathbb{R}^3$ (half-lengths along each axis)

For smaller tests, feel free to subsample fewer shapes.

In [None]:
# How to load latents as a tensor:
latents = torch.load("latents.pth", map_location=device)['weight']
print(latents.shape)

In [None]:
# How to load poses as a tensor:
poses = torch.load("poses.pth", map_location=device)['weight']
print(poses.shape)

In [None]:
# Imported here so this cell runs on its own; previously `R` was only
# imported in a later cell, causing a NameError when run in order.
from scipy.spatial.transform import Rotation as R

# Sanity-check scene: a unit cube moved by +1 along X, shown next to the
# colored world axes, to confirm the rotation/translation conventions.
scene = trimesh.Scene()

mesh = trimesh.primitives.Box(extents=(1, 1, 1))

# Build a 4x4 homogeneous transform: rotation in the upper-left 3x3 block,
# translation in the last column. scipy's Rotation.from_quat uses
# scalar-last (x, y, z, w) order, so [0, 0, 0, 1] is the identity rotation.
rotation_quat = [0, 0, 0, 1]
translation = [1, 0, 0]
rottrans_matrix = np.eye(4)
rottrans_matrix[:3, :3] = R.from_quat(rotation_quat).as_matrix()
rottrans_matrix[:3, 3] = translation

# Apply rotation and translation
mesh.apply_transform(rottrans_matrix)

scene.add_geometry(mesh)

# Axis length for visualization
axis_length = 1.0

# Create the X, Y, and Z axis lines as Path3D objects
x_axis = trimesh.load_path([[0, 0, 0], [axis_length, 0, 0]])
y_axis = trimesh.load_path([[0, 0, 0], [0, axis_length, 0]])
z_axis = trimesh.load_path([[0, 0, 0], [0, 0, axis_length]])

# Color each axis line
x_axis.entities[0].color = [255, 0, 0, 255]  # Red for X-axis
y_axis.entities[0].color = [0, 255, 0, 255]  # Green for Y-axis
z_axis.entities[0].color = [0, 0, 255, 255]  # Blue for Z-axis

# Add the axes to the scene
scene.add_geometry(x_axis)
scene.add_geometry(y_axis)
scene.add_geometry(z_axis)

# Show the scene once (the original called show() twice in a row).
scene.show()



### Visualize Poses

In [None]:
import numpy as np
import trimesh
from scipy.spatial.transform import Rotation as R

# Render every part of one shape as an oriented box, together with world axes.
# Each part's pose row is laid out as
# [rotation quaternion (4) | translation (3) | scale (3)].
example = poses[76]  # shape index 76 (an arbitrary pick from the dataset)
print(example.shape)

# One scene collects all parts of the selected shape.
scene = trimesh.Scene()

num_parts = example.shape[0]

for i in range(num_parts):
    print(f"Processing part {i}")
    # detach() so the conversion also works if the tensor tracks gradients.
    pose = example[i].detach().cpu().numpy()
    rotation_quat = pose[0:4]  # quaternion, interpreted by scipy as (x, y, z, w)
    translation = pose[4:7]    # translation (x, y, z)
    scale = pose[7:]           # per-axis half-lengths (x, y, z)

    # trimesh Box extents are full side lengths, while the stored scales are
    # half-lengths; double them so boxes have their true size.
    mesh = trimesh.primitives.Box(extents=scale * 2)

    # 4x4 homogeneous transform: rotation block plus translation column.
    rottrans_matrix = np.eye(4)
    rottrans_matrix[:3, :3] = R.from_quat(rotation_quat).as_matrix()
    rottrans_matrix[:3, 3] = translation

    # Apply rotation and translation, then add the part to the scene.
    mesh.apply_transform(rottrans_matrix)
    scene.add_geometry(mesh)

# Axis length for visualization
axis_length = 1.0

# Create the X, Y, and Z axis lines as Path3D objects
x_axis = trimesh.load_path([[0, 0, 0], [axis_length, 0, 0]])
y_axis = trimesh.load_path([[0, 0, 0], [0, axis_length, 0]])
z_axis = trimesh.load_path([[0, 0, 0], [0, 0, axis_length]])

# Color each axis line
x_axis.entities[0].color = [255, 0, 0, 255]  # Red for X-axis
y_axis.entities[0].color = [0, 255, 0, 255]  # Green for Y-axis
z_axis.entities[0].color = [0, 0, 255, 255]  # Blue for Z-axis

# Add the axes to the scene
scene.add_geometry(x_axis)
scene.add_geometry(y_axis)
scene.add_geometry(z_axis)

# Visualize the entire scene
scene.show()


In [None]:
# Display the scale of the last part processed by the pose-visualization loop
# (a NumPy array produced via `.cpu().numpy()` in that loop).
scale

In [None]:
from primitives import mesh_cuboid, mesh_cylinder

In [None]:
# Mesh all parts of shape #211: part 0 as a cuboid and parts 1-4 as cylinders.
idx = 211
pose = poses[idx].detach().cpu().numpy()

# Split every per-part pose row into rotation, translation and scale columns.
quaternions, translations, scales = pose[:, :4], pose[:, 4:7], pose[:, 7:10]

# The stored scales are half-lengths. For the cuboid the whole scale vector is
# doubled; for each cylinder the x-scale is passed as-is and the z-scale is
# doubled, as required when converting to trimesh primitives.
cuboids = [mesh_cuboid(scales[0] * 2, translations[0], quaternions[0])]
cylinders = [
    mesh_cylinder(scales[k, 0], scales[k, 2] * 2, translations[k], quaternions[k])
    for k in range(1, 5)
]

# Show the parts together with a reference axis marker.
trimesh.Scene([trimesh.creation.axis()] + cylinders + cuboids).show()

You will likely need to flatten them into a single vector per shape for the diffusion model:

In [None]:
# print(latents.flatten(1).shape)
# print(poses.flatten(1).shape)

# # And if you want to concatenate them:
# print("\nWhen combined per parts:")
# combined = torch.cat([latents, poses], dim=-1)
# print(combined.shape)
# print(combined.flatten(1).shape)

### Dataset Creation

In [None]:
# Bring both tables to host memory as NumPy arrays for h5py.
g_js_affine_np = poses.cpu().numpy()
s_j_affine_np = latents.cpu().numpy()

# Write both tables into one HDF5 file, each upcast to float64 on write.
output_hdf5_path = '/auto/k2/ademirtas/codes/diffusion/salad/salad/data/arda_cars.hdf5'
with h5py.File(output_hdf5_path, 'w') as hdf5_file:
    for dataset_key, table in (('g_js_affine', g_js_affine_np), ('s_j_affine', s_j_affine_np)):
        hdf5_file.create_dataset(dataset_key, data=table, dtype=np.float64)

print(f"HDF5 file created at {output_hdf5_path} with datasets 'g_js_affine', 's_j_affine'")

In [None]:
# Callback for h5py's visititems: print the full path and object of every
# group/dataset. It returns None, so the traversal visits every item.
def print_hdf5_item(name, item):
    print(name, ":", item)

# Inspect the exported dataset file: top-level keys first, then a recursive walk.
with h5py.File("/auto/k2/ademirtas/codes/diffusion/salad/salad/data/arda_cars.hdf5", 'r') as f:
    print("Keys:", list(f.keys()))
    f.visititems(print_hdf5_item)

In [None]:
output_hdf5_path = "/auto/k2/ademirtas/codes/diffusion/salad/salad/data/arda_cars_mean_std.hdf5"

# Per-channel statistics, reduced over the shape (N) and part (P) axes.
# Accumulate in float64 so the statistics are computed at the same precision
# they are stored at below, even if the exported arrays are lower precision
# (e.g. float32 -- TODO confirm the source dtype).
mean = np.mean(g_js_affine_np, axis=(0, 1), dtype=np.float64)  # shape: (pose width,), 10 per the documented pose layout
std = np.std(g_js_affine_np, axis=(0, 1), dtype=np.float64)    # shape: (pose width,)

sj_mean = np.mean(s_j_affine_np, axis=(0, 1), dtype=np.float64)  # shape: (L,), the latent dimension
sj_std = np.std(s_j_affine_np, axis=(0, 1), dtype=np.float64)    # shape: (L,)

# NOTE(review): any channel with std == 0 will divide by zero if consumers
# normalize with these statistics -- confirm downstream handles that.

# Create an HDF5 file for the calculated statistics
with h5py.File(output_hdf5_path, 'w') as output_file:
    output_file.create_dataset('mean', data=mean, dtype=np.float64)
    output_file.create_dataset('std', data=std, dtype=np.float64)
    output_file.create_dataset('sj_mean', data=sj_mean, dtype=np.float64)
    output_file.create_dataset('sj_std', data=sj_std, dtype=np.float64)

print(f"Mean and std dataset created at {output_hdf5_path} with keys ['mean', 'std', 'sj_mean', 'sj_std'].")


In [None]:
# Callback for h5py's visititems: print the full path and object of every
# group/dataset. It returns None, so the traversal visits every item.
def print_hdf5_item(name, item):
    print(name, ":", item)

# Inspect the statistics file: top-level keys first, then a recursive walk.
with h5py.File("/auto/k2/ademirtas/codes/diffusion/salad/salad/data/arda_cars_mean_std.hdf5", 'r') as f:
    print("Keys:", list(f.keys()))
    f.visititems(print_hdf5_item)