Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Plotting in 3D #51

Merged
merged 4 commits into from
Dec 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 174 additions & 1 deletion afids_utils/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from numpy.typing import NDArray
from plotly.graph_objs._figure import Figure as goFigure

from afids_utils.afids import AfidPosition, AfidSet, AfidVoxel
from afids_utils.afids import AfidDistance, AfidPosition, AfidSet, AfidVoxel
from afids_utils.transforms import world_to_voxel

# Matplotlib colormap object with 32-discrete colors
Expand Down Expand Up @@ -54,6 +54,12 @@
name="afids_cmap", colors=COLORS, N=len(COLORS)
)

SCATTER_DICT = {
"size": 4,
"color": "rgba(0,0,0,0.9)",
"line": {"width": 1.5, "color": "rgba(50,50,50,1.0)"},
}


def _create_afid_nii(
afid_voxels: list[AfidVoxel], afid_nii: nib.nifti1.Nifti1Image
Expand Down Expand Up @@ -268,6 +274,53 @@ def _create_scatter_plot(
return view # pyright: ignore


def _gen_distance_lines(
distances: list[AfidDistance],
) -> dict[str, list[float | int | None]]:
"""Internal function to generate lines between corresponding AFIDs from
two collections

Parameters
----------
distances
Collection of AfidDistances

Returns
-------
dict[str, list[float | int | None]]
Dictionary of parameters to draw connecting lines between corresponding
AFIDs
"""
lines: dict[str, list[float | int | None]] = {
"x": [],
"y": [],
"z": [],
"magnitude": [],
}

for distance in distances:
# Compute midpoints
x_mid = (distance.afid_position1.x + distance.afid_position2.x) / 2.0
y_mid = (distance.afid_position1.y + distance.afid_position2.y) / 2.0
z_mid = (distance.afid_position1.z + distance.afid_position2.z) / 2.0

# Compute lines
lines["x"].extend(
[distance.afid_position1.x, x_mid, distance.afid_position2.x, None]
)
lines["y"].extend(
[distance.afid_position1.y, y_mid, distance.afid_position2.y, None]
)
lines["z"].extend(
[distance.afid_position1.z, z_mid, distance.afid_position2.z, None]
)
lines["magnitude"].extend(
[distance.distance, distance.distance, distance.distance, 0]
)

return lines # pyright: ignore


def plot_ortho(
afids: AfidVoxel | AfidPosition | list[AfidVoxel | AfidPosition],
afid_nii: nib.nifti1.Nifti1Image,
Expand Down Expand Up @@ -380,3 +433,123 @@ def plot_distance_summary(
)

return view # pyright: ignore


def plot_3d(
afids: list[AfidPosition],
template_afids: list[AfidPosition] | None = None,
afids_scatter_dict: dict[
str, int | str | dict[str, float | str]
] = SCATTER_DICT,
title: str = "",
show_distance: bool = False,
) -> goFigure:
"""Generate 3D plot of AFIDs. Optionally visualize distance against
template and/or overlay with surface mesh.

Parameters
----------
afids
Collection of AfidPositions to visualize.

template_afids
Optional collection of template AfidPositions to visualize and compare
AFIDs against

afids_scatter_dict
Dictionary containing parameters for modifying visualization of afid
scatter points

title
Main title of figure

show_distance
Flag to indicate drawing of line between afids and template_afids
(if provided)
"""
go_afids = go.Scatter3d( # pyright: ignore
x=[afid.x for afid in afids],
y=[afid.y for afid in afids],
z=[afid.z for afid in afids],
showlegend=True,
mode="markers",
marker=afids_scatter_dict,
hovertemplate=("%{text}<br>x: %{x:.4f}<br>y: %{y:.4f}<br>z: %{z:.4f}"),
text=[f"<b>{afid.desc} ({afid.label})</b>" for afid in afids],
name="Subject AFIDs",
)

view = go.Figure() # pyright: ignore
view.add_trace(go_afids) # pyright: ignore

# Plot templates
if template_afids:
# Check if same number of fiducials provided
if len(afids) != len(template_afids):
raise ValueError("Mismatched number of fiducials")

go_template_afids = go.Scatter3d( # pyright: ignore
x=[template_afid.x for template_afid in template_afids],
y=[template_afid.y for template_afid in template_afids],
z=[template_afid.z for template_afid in template_afids],
showlegend=True,
mode="markers",
marker=afids_scatter_dict.update(
{"color": "rgba(255,191,31,0.9)"}
),
hovertemplate=(
"%{text}<br>x: %{x:.4f}<br>y: %{y:.4f}<br>z: %{z:.4f}"
),
text=[
f"<b>{template_afid.desc} ({template_afid.label})</b>"
for template_afid in template_afids
],
name="Template AFIDs",
)

view.add_trace(go_template_afids) # pyright: ignore

if show_distance:
distances = [
AfidDistance(afid, template_afid)
for afid, template_afid in zip(afids, template_afids)
]
lines = _gen_distance_lines(distances=distances)

# Draw line between afids
go_lines = go.Scatter3d( # pyright: ignore
x=lines["x"],
y=lines["y"],
z=lines["z"],
showlegend=False,
mode="lines",
hovertemplate="%{text}",
text=[
(
f"<b>{template_afids[idx // 4].desc} "
f"({template_afids[idx // 4].label})</b>"
f"<br>{lines['magnitude'][idx]:.3f} mm"
)
for idx in range(len(lines["magnitude"]))
],
line={
"color": lines["magnitude"],
"colorscale": "Bluered",
"width": 8,
"showscale": True,
"colorbar": {"title": {"text": "Euclidean distance"}},
},
name="Euclidean Distance",
)

view.add_trace(go_lines) # pyright: ignore

view.update_layout( # pyright: ignore
title_text=title,
autosize=True,
barmode="stack",
coloraxis={"colorscale": "Bluered"},
legend_orientation="h",
)

return view # pyright: ignore
85 changes: 84 additions & 1 deletion afids_utils/tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import afids_utils.plotting as af_plot
import afids_utils.tests.helpers as af_helpers
import afids_utils.tests.strategies as af_st
from afids_utils.afids import AfidPosition, AfidVoxel
from afids_utils.afids import AfidDistance, AfidPosition, AfidVoxel


@pytest.fixture
Expand Down Expand Up @@ -250,3 +250,86 @@ def test_plot_summary(
view.close() # pyright: ignore
else:
del view


class TestPlot3D:
@given(afid_positions=af_st.position_lists())
def test_plot_scatter3d(self, afid_positions: list[AfidPosition]):
view = af_plot.plot_3d(afids=afid_positions)
assert isinstance(view, goFigure)

del view

@given(afid_positions=af_st.position_lists())
def test_plot_scatter3d_dict(self, afid_positions: list[AfidPosition]):
# Static dict for testing
scatter_dict = {
"size": 8,
"color": "rgba(125, 125, 125, 0.5)",
"line": {"width": 3.0, "color": "rgba(0,0,0,1.)"},
}
view = af_plot.plot_3d(
afids=afid_positions, afids_scatter_dict=scatter_dict
)
assert isinstance(view, goFigure)

del view

@given(afid_positions=af_st.position_lists())
def test_plot_scatter3d_title(self, afid_positions: list[AfidPosition]):
# Static title for testing
title = "Test Title"
view = af_plot.plot_3d(afids=afid_positions, title=title)
assert isinstance(view, goFigure)

del view

@given(afid_positions=af_st.position_lists())
def test_plot_scatter3d_template(
self,
afid_positions: list[AfidPosition],
):
view = af_plot.plot_3d(
afids=afid_positions, template_afids=afid_positions
)
assert isinstance(view, goFigure)

del view

@given(afid_positions=af_st.position_lists())
def test_plot_scatter3d_invalid_template(
self,
afid_positions: list[AfidPosition],
):
with pytest.raises(ValueError, match="Mismatched number"):
af_plot.plot_3d(
afids=afid_positions, template_afids=afid_positions[1:]
)

@given(afid_positions=af_st.position_lists())
def test_gen_distance_lines(self, afid_positions: list[AfidPosition]):
distances = [
AfidDistance(
afid_position1=afid_position, afid_position2=afid_position
)
for afid_position in afid_positions
]
assert all(isinstance(dist, AfidDistance) for dist in distances)

lines = af_plot._gen_distance_lines(distances=distances)
assert isinstance(lines, dict)
for key in lines.keys():
assert len(lines[key]) == len(distances) * 4

@given(afid_positions=af_st.position_lists())
def test_plot_scatter3d_show_distance(
self, afid_positions: list[AfidPosition]
):
view = af_plot.plot_3d(
afids=afid_positions,
template_afids=afid_positions,
show_distance=True,
)
assert isinstance(view, goFigure)

del view