Skip to content

Commit

Permalink
Add plotting in 3D for afids
Browse files Browse the repository at this point in the history
This commit allows lone afids to be plotted in 3d, as well as plot
comparisons against "template" afids, optionally drawing lines between
corresponding points.
  • Loading branch information
kaitj committed Dec 7, 2023
1 parent 22c4ae6 commit 877e095
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 1 deletion.
121 changes: 120 additions & 1 deletion afids_utils/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from numpy.typing import NDArray
from plotly.graph_objs._figure import Figure as goFigure

from afids_utils.afids import AfidPosition, AfidSet, AfidVoxel
from afids_utils.afids import AfidDistance, AfidPosition, AfidSet, AfidVoxel
from afids_utils.transforms import world_to_voxel

# Matplotlib colormap object with 32-discrete colors
Expand Down Expand Up @@ -274,6 +274,53 @@ def _create_scatter_plot(
return view # pyright: ignore


def _gen_distance_lines(
distances: list[AfidDistance],
) -> dict[str, list[float | int | None]]:
"""Internal function to generate lines between corresponding AFIDs from
two collections
Parameters
----------
distances
Collection of AfidDistances
Returns
-------
dict[str, list[float | int | None]]
Dictionary of parameters to draw connecting lines between corresponding
AFIDs
"""
lines: dict[str, list[float | int | None]] = {
"x": [],
"y": [],
"z": [],
"magnitude": [],
}

for distance in distances:
# Compute midpoints
x_mid = (distance.afid_position1.x + distance.afid_position2.x) / 2.0
y_mid = (distance.afid_position1.y + distance.afid_position2.y) / 2.0
z_mid = (distance.afid_position1.z + distance.afid_position2.z) / 2.0

# Compute lines
lines["x"].extend(
[distance.afid_position1.x, x_mid, distance.afid_position2.x, None]
)
lines["y"].extend(
[distance.afid_position1.y, y_mid, distance.afid_position2.y, None]
)
lines["z"].extend(
[distance.afid_position1.z, z_mid, distance.afid_position2.z, None]
)
lines["magnitude"].extend(
[distance.distance, distance.distance, distance.distance, 0]
)

return lines # pyright: ignore


def plot_ortho(
afids: AfidVoxel | AfidPosition | list[AfidVoxel | AfidPosition],
afid_nii: nib.nifti1.Nifti1Image,
Expand Down Expand Up @@ -390,10 +437,12 @@ def plot_distance_summary(

def plot_3d(
afids: list[AfidPosition],
template_afids: list[AfidPosition] | None = None,
afids_scatter_dict: dict[
str, int | str | dict[str, float | str]
] = SCATTER_DICT,
title: str = "",
show_distance: bool = False,
) -> goFigure:
"""Generate 3D plot of AFIDs. Optionally visualize distance against
template and/or overlay with surface mesh.
Expand All @@ -403,12 +452,20 @@ def plot_3d(
afids
Collection of AfidPositions to visualize.
template_afids
Optional collection of template AfidPositions to visualize and compare
AFIDs against
afids_scatter_dict
Dictionary containing parameters for modifying visualization of afid
scatter points
title
Main title of figure
show_distance
Flag to indicate drawing of line between afids and template_afids
(if provided)
"""
go_afids = go.Scatter3d( # pyright: ignore
x=[afid.x for afid in afids],
Expand All @@ -425,6 +482,68 @@ def plot_3d(
view = go.Figure() # pyright: ignore
view.add_trace(go_afids) # pyright: ignore

# Plot templates
if template_afids:
# Check if same number of fiducials provided
if len(afids) != len(template_afids):
raise ValueError("Mismatched number of fiducials")

go_template_afids = go.Scatter3d( # pyright: ignore
x=[template_afid.x for template_afid in template_afids],
y=[template_afid.y for template_afid in template_afids],
z=[template_afid.z for template_afid in template_afids],
showlegend=True,
mode="markers",
marker=afids_scatter_dict.update(
{"color": "rgba(255,191,31,0.9)"}
),
hovertemplate=(
"%{text}<br>x: %{x:.4f}<br>y: %{y:.4f}<br>z: %{z:.4f}"
),
text=[
f"<b>{template_afid.desc} ({template_afid.label})</b>"
for template_afid in template_afids
],
name="Template AFIDs",
)

view.add_trace(go_template_afids) # pyright: ignore

if show_distance:
distances = [
AfidDistance(afid, template_afid)
for afid, template_afid in zip(afids, template_afids)
]
lines = _gen_distance_lines(distances=distances)

# Draw line between afids
go_lines = go.Scatter3d( # pyright: ignore
x=lines["x"],
y=lines["y"],
z=lines["z"],
showlegend=False,
mode="lines",
hovertemplate="%{text}",
text=[
(
f"<b>{template_afids[idx // 4].desc} "
f"({template_afids[idx // 4].label})</b>"
f"<br>{lines['magnitude'][idx // 4]:.3f} mm"
)
for idx in range(len(lines["magnitude"]))
],
line={
"color": lines["magnitude"],
"colorscale": "Bluered",
"width": 8,
"showscale": True,
"colorbar": {"title": {"text": "Euclidean distance"}},
},
name="Euclidean Distance",
)

view.add_trace(go_lines) # pyright: ignore

view.update_layout( # pyright: ignore
title_text=title,
autosize=True,
Expand Down
35 changes: 35 additions & 0 deletions afids_utils/tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,3 +283,38 @@ def test_plot_scatter3d_title(self, afid_positions: list[AfidPosition]):
assert isinstance(view, goFigure)

del view

@given(afid_positions=af_st.position_lists())
def test_plot_scatter3d_template(
self,
afid_positions: list[AfidPosition],
):
view = af_plot.plot_3d(
afids=afid_positions, template_afids=afid_positions
)
assert isinstance(view, goFigure)

del view

@given(afid_positions=af_st.position_lists())
def test_plot_scatter3d_invalid_template(
self,
afid_positions: list[AfidPosition],
):
with pytest.raises(ValueError, match="Mismatched number"):
af_plot.plot_3d(
afids=afid_positions, template_afids=afid_positions[1:]
)

@given(afid_positions=af_st.position_lists())
def test_plot_scatter3d_show_distance(
self, afid_positions: list[AfidPosition]
):
view = af_plot.plot_3d(
afids=afid_positions,
template_afids=afid_positions,
show_distance=True,
)
assert isinstance(view, goFigure)

del view

0 comments on commit 877e095

Please sign in to comment.