Skip to content

Commit

Permalink
Merge pull request #51 from kaitj/enh/3d-viz
Browse files Browse the repository at this point in the history
Plotting in 3D
  • Loading branch information
kaitj committed Dec 12, 2023
2 parents 801f263 + 2ae1a60 commit caed712
Show file tree
Hide file tree
Showing 2 changed files with 258 additions and 2 deletions.
175 changes: 174 additions & 1 deletion afids_utils/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from numpy.typing import NDArray
from plotly.graph_objs._figure import Figure as goFigure

from afids_utils.afids import AfidPosition, AfidSet, AfidVoxel
from afids_utils.afids import AfidDistance, AfidPosition, AfidSet, AfidVoxel
from afids_utils.transforms import world_to_voxel

# Matplotlib colormap object with 32-discrete colors
Expand Down Expand Up @@ -54,6 +54,12 @@
name="afids_cmap", colors=COLORS, N=len(COLORS)
)

SCATTER_DICT = {
"size": 4,
"color": "rgba(0,0,0,0.9)",
"line": {"width": 1.5, "color": "rgba(50,50,50,1.0)"},
}


def _create_afid_nii(
afid_voxels: list[AfidVoxel], afid_nii: nib.nifti1.Nifti1Image
Expand Down Expand Up @@ -268,6 +274,53 @@ def _create_scatter_plot(
return view # pyright: ignore


def _gen_distance_lines(
distances: list[AfidDistance],
) -> dict[str, list[float | int | None]]:
"""Internal function to generate lines between corresponding AFIDs from
two collections
Parameters
----------
distances
Collection of AfidDistances
Returns
-------
dict[str, list[float | int | None]]
Dictionary of parameters to draw connecting lines between corresponding
AFIDs
"""
lines: dict[str, list[float | int | None]] = {
"x": [],
"y": [],
"z": [],
"magnitude": [],
}

for distance in distances:
# Compute midpoints
x_mid = (distance.afid_position1.x + distance.afid_position2.x) / 2.0
y_mid = (distance.afid_position1.y + distance.afid_position2.y) / 2.0
z_mid = (distance.afid_position1.z + distance.afid_position2.z) / 2.0

# Compute lines
lines["x"].extend(
[distance.afid_position1.x, x_mid, distance.afid_position2.x, None]
)
lines["y"].extend(
[distance.afid_position1.y, y_mid, distance.afid_position2.y, None]
)
lines["z"].extend(
[distance.afid_position1.z, z_mid, distance.afid_position2.z, None]
)
lines["magnitude"].extend(
[distance.distance, distance.distance, distance.distance, 0]
)

return lines # pyright: ignore


def plot_ortho(
afids: AfidVoxel | AfidPosition | list[AfidVoxel | AfidPosition],
afid_nii: nib.nifti1.Nifti1Image,
Expand Down Expand Up @@ -380,3 +433,123 @@ def plot_distance_summary(
)

return view # pyright: ignore


def plot_3d(
afids: list[AfidPosition],
template_afids: list[AfidPosition] | None = None,
afids_scatter_dict: dict[
str, int | str | dict[str, float | str]
] = SCATTER_DICT,
title: str = "",
show_distance: bool = False,
) -> goFigure:
"""Generate 3D plot of AFIDs. Optionally visualize distance against
template and/or overlay with surface mesh.
Parameters
----------
afids
Collection of AfidPositions to visualize.
template_afids
Optional collection of template AfidPositions to visualize and compare
AFIDs against
afids_scatter_dict
Dictionary containing parameters for modifying visualization of afid
scatter points
title
Main title of figure
show_distance
Flag to indicate drawing of line between afids and template_afids
(if provided)
"""
go_afids = go.Scatter3d( # pyright: ignore
x=[afid.x for afid in afids],
y=[afid.y for afid in afids],
z=[afid.z for afid in afids],
showlegend=True,
mode="markers",
marker=afids_scatter_dict,
hovertemplate=("%{text}<br>x: %{x:.4f}<br>y: %{y:.4f}<br>z: %{z:.4f}"),
text=[f"<b>{afid.desc} ({afid.label})</b>" for afid in afids],
name="Subject AFIDs",
)

view = go.Figure() # pyright: ignore
view.add_trace(go_afids) # pyright: ignore

# Plot templates
if template_afids:
# Check if same number of fiducials provided
if len(afids) != len(template_afids):
raise ValueError("Mismatched number of fiducials")

go_template_afids = go.Scatter3d( # pyright: ignore
x=[template_afid.x for template_afid in template_afids],
y=[template_afid.y for template_afid in template_afids],
z=[template_afid.z for template_afid in template_afids],
showlegend=True,
mode="markers",
marker=afids_scatter_dict.update(
{"color": "rgba(255,191,31,0.9)"}
),
hovertemplate=(
"%{text}<br>x: %{x:.4f}<br>y: %{y:.4f}<br>z: %{z:.4f}"
),
text=[
f"<b>{template_afid.desc} ({template_afid.label})</b>"
for template_afid in template_afids
],
name="Template AFIDs",
)

view.add_trace(go_template_afids) # pyright: ignore

if show_distance:
distances = [
AfidDistance(afid, template_afid)
for afid, template_afid in zip(afids, template_afids)
]
lines = _gen_distance_lines(distances=distances)

# Draw line between afids
go_lines = go.Scatter3d( # pyright: ignore
x=lines["x"],
y=lines["y"],
z=lines["z"],
showlegend=False,
mode="lines",
hovertemplate="%{text}",
text=[
(
f"<b>{template_afids[idx // 4].desc} "
f"({template_afids[idx // 4].label})</b>"
f"<br>{lines['magnitude'][idx]:.3f} mm"
)
for idx in range(len(lines["magnitude"]))
],
line={
"color": lines["magnitude"],
"colorscale": "Bluered",
"width": 8,
"showscale": True,
"colorbar": {"title": {"text": "Euclidean distance"}},
},
name="Euclidean Distance",
)

view.add_trace(go_lines) # pyright: ignore

view.update_layout( # pyright: ignore
title_text=title,
autosize=True,
barmode="stack",
coloraxis={"colorscale": "Bluered"},
legend_orientation="h",
)

return view # pyright: ignore
85 changes: 84 additions & 1 deletion afids_utils/tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import afids_utils.plotting as af_plot
import afids_utils.tests.helpers as af_helpers
import afids_utils.tests.strategies as af_st
from afids_utils.afids import AfidPosition, AfidVoxel
from afids_utils.afids import AfidDistance, AfidPosition, AfidVoxel


@pytest.fixture
Expand Down Expand Up @@ -250,3 +250,86 @@ def test_plot_summary(
view.close() # pyright: ignore
else:
del view


class TestPlot3D:
@given(afid_positions=af_st.position_lists())
def test_plot_scatter3d(self, afid_positions: list[AfidPosition]):
view = af_plot.plot_3d(afids=afid_positions)
assert isinstance(view, goFigure)

del view

@given(afid_positions=af_st.position_lists())
def test_plot_scatter3d_dict(self, afid_positions: list[AfidPosition]):
# Static dict for testing
scatter_dict = {
"size": 8,
"color": "rgba(125, 125, 125, 0.5)",
"line": {"width": 3.0, "color": "rgba(0,0,0,1.)"},
}
view = af_plot.plot_3d(
afids=afid_positions, afids_scatter_dict=scatter_dict
)
assert isinstance(view, goFigure)

del view

@given(afid_positions=af_st.position_lists())
def test_plot_scatter3d_title(self, afid_positions: list[AfidPosition]):
# Static title for testing
title = "Test Title"
view = af_plot.plot_3d(afids=afid_positions, title=title)
assert isinstance(view, goFigure)

del view

@given(afid_positions=af_st.position_lists())
def test_plot_scatter3d_template(
self,
afid_positions: list[AfidPosition],
):
view = af_plot.plot_3d(
afids=afid_positions, template_afids=afid_positions
)
assert isinstance(view, goFigure)

del view

@given(afid_positions=af_st.position_lists())
def test_plot_scatter3d_invalid_template(
self,
afid_positions: list[AfidPosition],
):
with pytest.raises(ValueError, match="Mismatched number"):
af_plot.plot_3d(
afids=afid_positions, template_afids=afid_positions[1:]
)

@given(afid_positions=af_st.position_lists())
def test_gen_distance_lines(self, afid_positions: list[AfidPosition]):
distances = [
AfidDistance(
afid_position1=afid_position, afid_position2=afid_position
)
for afid_position in afid_positions
]
assert all(isinstance(dist, AfidDistance) for dist in distances)

lines = af_plot._gen_distance_lines(distances=distances)
assert isinstance(lines, dict)
for key in lines.keys():
assert len(lines[key]) == len(distances) * 4

@given(afid_positions=af_st.position_lists())
def test_plot_scatter3d_show_distance(
self, afid_positions: list[AfidPosition]
):
view = af_plot.plot_3d(
afids=afid_positions,
template_afids=afid_positions,
show_distance=True,
)
assert isinstance(view, goFigure)

del view

0 comments on commit caed712

Please sign in to comment.