# 3D Visualization of Binary Masks

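This notebook loads two NIfTI segmentation label files, reorients both to RAS, resamples them to a spacing of 2.0 along each axis, crops each volume to its own foreground, and pads them to at least 128×128×128. It then extracts the voxels labelled 2 or 4 from each volume as a binary mask and plots a random sample of up to 5000 voxels per mask as an interactive 3D scatter plot, which is also saved as an HTML file.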

In [None]:
import nibabel as nib
import numpy as np
import plotly.graph_objects as go
from scipy import ndimage
from monai.transforms import Compose, LoadImaged, Orientationd, Spacingd, CropForegroundd, SpatialPadd

In [None]:
# Preprocessing pipeline shared by both label volumes (dictionary keys "A" and "B").
transforms = Compose([
    # Read each file and move the channel axis to the front.
    LoadImaged(["A", "B"], image_only=True, ensure_channel_first=True),
    # Put both volumes into the same RAS axis orientation.
    Orientationd(["A", "B"], axcodes="RAS"),
    # Resample to a spacing of 2.0 on every axis; nearest-neighbour
    # interpolation keeps the label values discrete.
    Spacingd(["A", "B"], [2.0, 2.0, 2.0], mode="nearest"),
    # Crop each volume to the bounding box of its own foreground
    # (using MONAI's default foreground criterion).
    CropForegroundd("A", source_key="A"),
    CropForegroundd("B", source_key="B"),
    # Pad symmetrically up to 128 voxels per axis with constant zeros.
    # "constant" is used instead of "minimum" because NumPy-style "minimum"
    # padding takes the minimum of each individual vector along the axis:
    # a line that is entirely foreground after the tight crop would be
    # extended with a non-zero label, smearing it into the padded region.
    # NOTE(review): assumes label 0 is background in both datasets - confirm.
    SpatialPadd(["A", "B"], (128, 128, 128), method="symmetric", mode="constant")
])

In [None]:
# Paths of the two label volumes to compare (plain strings; the pipeline loads them)
img_A = '/mnt/data/Experiment/nnUNet/nnUNet_raw/Dataset020_SCOTHEART/labelsTr/110021_CE-ED.nii.gz'
img_B = '/mnt/data/Experiment/nnUNet/nnUNet_raw/Dataset021_ACDC/labelsTs/patient002_frame01.nii.gz'

# Run the preprocessing pipeline on both files in a single call
data = transforms({"A": img_A, "B": img_B})

# Extract the arrays and drop the leading channel axis added at load time
data_A, data_B = (data[key].get_array()[0] for key in ("A", "B"))

In [None]:
# Labels that make up the binary mask, the per-mask point budget for plotting,
# and a fixed seed so Restart & Run All always draws the same subset.
MASK_LABELS = [2, 4]
MAX_POINTS_PER_MASK = 5000
SAMPLING_SEED = 42


def sample_mask_coords(mask, max_points, rng):
    """Return up to ``max_points`` randomly chosen voxel coordinates of ``mask``.

    Args:
        mask: boolean array; True marks the voxels of interest.
        max_points: maximum number of coordinate rows to return.
        rng: ``numpy.random.Generator`` used for the random ordering.

    Returns:
        Integer array with one row per selected voxel and one column per
        mask axis, in random order. Contains every True voxel when the mask
        has fewer than ``max_points`` of them (zero rows for an empty mask).
    """
    voxel_coords = np.argwhere(mask)
    # permutation() returns a shuffled copy (rows are shuffled, not elements),
    # so the full coordinate list is never modified in place.
    return rng.permutation(voxel_coords)[:max_points]


# Convert segmentation to binary mask (values 2 or 4)
mask_A = np.isin(data_A, MASK_LABELS)
mask_B = np.isin(data_B, MASK_LABELS)

# Randomly select up to 5000 mask voxels per volume for visualization
rng = np.random.default_rng(SAMPLING_SEED)
coords_A = sample_mask_coords(mask_A, MAX_POINTS_PER_MASK, rng)
coords_B = sample_mask_coords(mask_B, MAX_POINTS_PER_MASK, rng)

In [None]:
def _point_cloud_trace(points, colour, label):
    """Build a 3D marker trace from a coordinate array.

    ``points`` is indexed as ``points[:, 0..2]``, so it needs at least three
    columns. Column 1 is plotted on the x axis and column 0 on the y axis.
    NOTE(review): confirm this x/y swap is intended, given that the figure
    labels its axes 'X' and 'Y'.
    """
    marker_style = dict(size=2, color=colour, opacity=0.8)
    return go.Scatter3d(
        x=points[:, 1],
        y=points[:, 0],
        z=points[:, 2],
        mode='markers',
        marker=marker_style,
        name=label,
    )


# One marker cloud per mask: A in red, B in blue
trace_A = _point_cloud_trace(coords_A, 'red', 'Mask A')
trace_B = _point_cloud_trace(coords_B, 'blue', 'Mask B')

# Scene settings: axis titles, and an aspect ratio derived from the data ranges
scene_settings = dict(
    xaxis_title='X',
    yaxis_title='Y',
    zaxis_title='Z',
    aspectmode='data',
)

# Put both mask traces into one figure and apply the layout
fig = go.Figure(data=[trace_A, trace_B])
fig.update_layout(
    title='3D Visualization of Binary Masks',
    width=800,
    height=800,
    scene=scene_settings,
)

# Render the figure inline
fig.show()

# Also write a standalone interactive copy to the working directory
fig.write_html("binary_masks_3d.html")