Skip to content

Commit

Permalink
Add construct mesh tasks (#26)
Browse files Browse the repository at this point in the history
* Add construct mesh from coeffs task

* Move construct mesh from points task

* Add construct mesh from array task

* Add new construct tasks to init

* Update mypy ignore for vtk
  • Loading branch information
jessicasyu committed May 19, 2023
1 parent f4a1759 commit 3113464
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 25 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ module = [
"skimage.*",
"sklearn.*",
"trimesh.*",
"vtkmodules.*",
"vtk.*",
]
ignore_missing_imports = true

Expand Down
6 changes: 6 additions & 0 deletions src/abm_shape_collection/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
from .calculate_shape_stats import calculate_shape_stats
from .calculate_size_stats import calculate_size_stats
from .compile_shape_modes import compile_shape_modes
from .construct_mesh_from_array import construct_mesh_from_array
from .construct_mesh_from_coeffs import construct_mesh_from_coeffs
from .construct_mesh_from_points import construct_mesh_from_points
from .extract_shape_modes import extract_shape_modes
from .fit_pca_model import fit_pca_model
from .get_shape_coefficients import get_shape_coefficients
Expand All @@ -13,6 +16,9 @@
calculate_shape_stats = task(calculate_shape_stats)
calculate_size_stats = task(calculate_size_stats)
compile_shape_modes = task(compile_shape_modes)
construct_mesh_from_array = task(construct_mesh_from_array)
construct_mesh_from_coeffs = task(construct_mesh_from_coeffs)
construct_mesh_from_points = task(construct_mesh_from_points)
extract_shape_modes = task(extract_shape_modes)
fit_pca_model = task(fit_pca_model)
get_shape_coefficients = task(get_shape_coefficients)
Expand Down
10 changes: 10 additions & 0 deletions src/abm_shape_collection/construct_mesh_from_array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import numpy as np
from aicsshparam import shtools
from vtk import vtkPolyData


def construct_mesh_from_array(array: np.ndarray, reference: np.ndarray) -> vtkPolyData:
_, angle = shtools.align_image_2d(image=reference)
aligned_array = shtools.apply_image_alignment_2d(array, angle).squeeze()
mesh, _, _ = shtools.get_mesh_from_image(image=aligned_array)
return mesh
34 changes: 34 additions & 0 deletions src/abm_shape_collection/construct_mesh_from_coeffs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import numpy as np
import pandas as pd
from aicsshparam import shtools
from vtk import vtkPolyData, vtkTransform, vtkTransformPolyDataFilter


def construct_mesh_from_coeffs(
coeffs: pd.DataFrame,
order: int,
prefix: str = "",
suffix: str = "",
scale: float = 1.0,
) -> vtkPolyData:
coeffs_map = np.zeros((2, order + 1, order + 1), dtype=np.float32)

for l in range(order + 1):
for m in range(order + 1):
coeffs_map[0, l, m] = coeffs[f"{prefix}shcoeffs_L{l}M{m}C{suffix}"]
coeffs_map[1, l, m] = coeffs[f"{prefix}shcoeffs_L{l}M{m}S{suffix}"]

mesh, _ = shtools.get_reconstruction_from_coeffs(coeffs_map)

if scale != 1:
transform = vtkTransform()
transform.Scale((scale, scale, scale))

transform_filter = vtkTransformPolyDataFilter()
transform_filter.SetInputData(mesh)
transform_filter.SetTransform(transform)
transform_filter.Update()

mesh = transform_filter.GetOutput(0)

return mesh
19 changes: 19 additions & 0 deletions src/abm_shape_collection/construct_mesh_from_points.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import numpy as np
import pandas as pd
from sklearn.decomposition import PCA
from vtk import vtkPolyData

from abm_shape_collection.construct_mesh_from_coeffs import construct_mesh_from_coeffs


def construct_mesh_from_points(
pca: PCA,
points: np.ndarray,
feature_names: list[str],
order: int,
prefix: str = "",
suffix: str = "",
) -> vtkPolyData:
"""Constructs mesh given PCA transformation points."""
coeffs = pd.Series(pca.inverse_transform(points), index=feature_names)
return construct_mesh_from_coeffs(coeffs, order, prefix, suffix)
27 changes: 3 additions & 24 deletions src/abm_shape_collection/extract_shape_modes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import numpy as np
import pandas as pd
import trimesh
from aicsshparam import shtools
from sklearn.decomposition import PCA
from vtkmodules.vtkCommonDataModel import vtkPolyData
from vtkmodules.vtkIOPLY import vtkPLYWriter
from vtk import vtkPLYWriter, vtkPolyData

from abm_shape_collection.construct_mesh_from_points import construct_mesh_from_points


def extract_shape_modes(
Expand Down Expand Up @@ -50,27 +50,6 @@ def extract_shape_mode_slices(
return slices


def construct_mesh_from_points(
pca: PCA,
points: np.ndarray,
feature_names: list[str],
order: int,
prefix: str = "",
suffix: str = "",
) -> vtkPolyData:
"""Constructs mesh given PCA transformation points."""
coeffs = pd.Series(pca.inverse_transform(points), index=feature_names)
coeffs_map = np.zeros((2, order + 1, order + 1), dtype=np.float32)

for l in range(order + 1):
for m in range(order + 1):
coeffs_map[0, l, m] = coeffs[f"{prefix}shcoeffs_L{l}M{m}C{suffix}"]
coeffs_map[1, l, m] = coeffs[f"{prefix}shcoeffs_L{l}M{m}S{suffix}"]

mesh, _ = shtools.get_reconstruction_from_coeffs(coeffs_map)
return mesh


def convert_vtk_to_trimesh(mesh: vtkPolyData) -> trimesh.Trimesh:
with tempfile.NamedTemporaryFile() as temp:
writer = vtkPLYWriter()
Expand Down

0 comments on commit 3113464

Please sign in to comment.