In [3]:
# If needed (run once in its own cell; %pip installs into this kernel's environment):
# %pip install plotly scikit-image ipywidgets nibabel numpy-stl

import os
import numpy as np
import nibabel as nib
from skimage import measure
import plotly.graph_objects as go
from ipywidgets import IntSlider, ToggleButton, HBox, VBox, interactive_output

# --- Configurable paths & case --- #
# All paths hang off one project root. Set the VESSELFM_ROOT environment
# variable to run this notebook on another machine without editing code.
CASE_ID       = "6"
VESSELFM_ROOT = os.environ.get("VESSELFM_ROOT", r"C:\Users\giles\Github\vesselFM")

SCAN_DIR     = os.path.join(VESSELFM_ROOT, "data", "d_real", "ImageCAS-Raw", "img")

PRED_DIR     = os.path.join(VESSELFM_ROOT, "data", "inference", "1e-3LRDecay-RandCrop", "postp")

GT_DIR       = os.path.join(VESSELFM_ROOT, "data", "d_real", "ImageCAS-Raw", "gt")
GT_MASK_PATH = os.path.join(GT_DIR, f"{CASE_ID}.label.nii.gz")

# 1. Load data once.
# Request float32 instead of nibabel's float64 default. That halves memory:
# a 512x512x275 scan is ~0.58 GB as float64. Masks are binarised at 0.5.
scan_vol  = nib.load(os.path.join(SCAN_DIR, f"{CASE_ID}.img.nii.gz")).get_fdata(dtype=np.float32)
pred_mask = nib.load(os.path.join(PRED_DIR, f"{CASE_ID}_pred.nii.gz")).get_fdata(dtype=np.float32) > 0.5
gt_mask   = nib.load(GT_MASK_PATH).get_fdata(dtype=np.float32) > 0.5

# The voxel-wise comparisons below and the slice overlay assume one shared grid.
if not (scan_vol.shape == pred_mask.shape == gt_mask.shape):
    raise ValueError(
        f"Shape mismatch for case {CASE_ID}: scan {scan_vol.shape}, "
        f"prediction {pred_mask.shape}, ground truth {gt_mask.shape}"
    )

intersect = np.logical_and(pred_mask, gt_mask)   # GT voxels the prediction found
missing   = np.logical_and(~pred_mask, gt_mask)  # GT voxels the prediction missed

# 2. Precompute meshes
def mesh3d(mask, color, name, opacity, spacing=(1.0, 1.0, 1.0)):
    """Build a Plotly ``Mesh3d`` iso-surface of a binary mask.

    Parameters
    ----------
    mask : np.ndarray
        3-D boolean (or 0/1) volume. The surface is extracted at level 0.5.
    color, name, opacity
        Passed straight through to ``go.Mesh3d``.
    spacing : tuple of 3 floats, optional
        Voxel size along each axis, forwarded to ``marching_cubes``.
        The default (1, 1, 1) keeps vertices in voxel-index coordinates, which
        the CT-slice overlay in ``update_plot`` relies on. Pass physical voxel
        sizes only when the mesh is used on its own, e.g. for an STL export.

    Returns
    -------
    go.Mesh3d
        The surface trace. If the mask has no iso-surface (all False or all
        True), an empty trace is returned instead of letting
        ``marching_cubes`` raise ``ValueError``.
    """
    mask = np.asarray(mask)
    level = 0.5
    # marching_cubes rejects a level outside [min, max]. That is exactly the
    # case for an empty mask, e.g. when the prediction misses nothing.
    if mask.size == 0 or not (mask.min() <= level <= mask.max()):
        return go.Mesh3d(
            x=[], y=[], z=[], i=[], j=[], k=[],
            color=color, opacity=opacity, name=name
        )
    verts, faces, _, _ = measure.marching_cubes(mask, level=level, spacing=spacing)
    return go.Mesh3d(
        x=verts[:, 0], y=verts[:, 1], z=verts[:, 2],
        i=faces[:, 0], j=faces[:, 1], k=faces[:, 2],
        color=color, opacity=opacity, name=name
    )

# One precomputed trace per mask; each row is (mask, colour, legend name, opacity).
mesh_pred, mesh_gt, mesh_int, mesh_miss = [
    mesh3d(vol, colour, label, alpha)
    for vol, colour, label, alpha in (
        (pred_mask, 'cyan',    'Prediction',   0.9),
        (gt_mask,   'red',     'Ground Truth', 0.4),
        (intersect, 'yellow',  'Intersection', 0.6),
        (missing,   'magenta', 'Missing',      0.6),
    )
]

# 3. Prepare CT slice grid.
# The grid covers every in-plane voxel with 'ij' layout: xx[i, j] == i and yy[i, j] == j.
nx, ny, nz = scan_vol.shape
xx, yy = np.indices((nx, ny))

# 4. The plotting function
def update_plot(z_idx, show_pred, show_gt, show_int, show_miss, show_slice):
    """Render the selected meshes plus, optionally, one CT slice.

    Called by ``interactive_output`` whenever a widget value changes.

    Parameters
    ----------
    z_idx : int
        Index along the third volume axis of the CT slice to draw.
    show_pred, show_gt, show_int, show_miss : bool
        Whether to include the precomputed prediction / ground-truth /
        intersection / missing meshes.
    show_slice : bool
        Whether to draw the CT slice as a grey textured plane at ``z = z_idx``.
    """
    toggled = (
        (show_pred, mesh_pred),
        (show_gt,   mesh_gt),
        (show_int,  mesh_int),
        (show_miss, mesh_miss),
    )
    data = [trace for enabled, trace in toggled if enabled]

    if show_slice:
        sl = scan_vol[:, :, z_idx]
        # np.ptp replaces the ndarray.ptp method, which NumPy 2.0 removed.
        norm = (sl - sl.min()) / (np.ptp(sl) + 1e-9)
        slice_surf = go.Surface(
            x=xx, y=yy, z=np.full_like(xx, z_idx),
            # xx/yy use 'ij' indexing (xx[i, j] == i, yy[i, j] == j), which is
            # also the mesh vertex convention. The colour array must keep that
            # layout: transposing it mirrored the CT image across the diagonal.
            surfacecolor=norm,
            colorscale='Gray', cmin=0, cmax=1,
            opacity=0.7, showscale=False, name=f"Slice {z_idx}"
        )
        data.append(slice_surf)

    fig = go.Figure(data=data)
    fig.update_layout(
        width=1000, height=800,
        margin=dict(l=0, r=0, b=0, t=80),
        title=dict(text=f"Case {CASE_ID}: 3D + Slice z={z_idx}", y=0.9),
        scene=dict(
            xaxis_title='X', yaxis_title='Y', zaxis_title='Z',
            aspectmode='data'
        )
    )
    fig.show()

# 5. Widgets: a slider choosing the CT slice and one toggle per trace.
slice_slider = IntSlider(
    min=0,
    max=nz - 1,
    step=1,
    value=nz // 2,
    description='Slice Z:',
    continuous_update=False,
    layout={'width': '60%'},
)

# Each row is (initially on?, label, Bootstrap button style), in display order.
tb_slice, tb_pred, tb_gt, tb_int, tb_miss = [
    ToggleButton(value=on, description=label, button_style=style)
    for on, label, style in (
        (True,  'Show Slice',   ''),
        (True,  'Prediction',   'info'),
        (False, 'Ground Truth', 'danger'),
        (False, 'Intersection', 'warning'),
        (False, 'Missing',      'primary'),
    )
]

controls_top    = HBox([slice_slider])
controls_bottom = HBox([tb_slice, tb_pred, tb_gt, tb_int, tb_miss])
controls        = VBox([controls_top, controls_bottom])

# 6. Bind each widget to the update_plot keyword of the same role.
#    interactive_output creates no widgets of its own, so only our layout is shown.
out = interactive_output(
    update_plot,
    dict(
        z_idx=slice_slider,
        show_pred=tb_pred,
        show_gt=tb_gt,
        show_int=tb_int,
        show_miss=tb_miss,
        show_slice=tb_slice,
    ),
)

display(VBox([controls, out]))


VBox(children=(VBox(children=(HBox(children=(IntSlider(value=137, continuous_update=False, description='Slice …

Save the prediction and ground-truth meshes to STL files

In [None]:
from stl import mesh # Add this import

# NOTE(review): MESH_DIR is not referenced by the cells in this notebook; meshes are
# written to output_mesh_dir below. It is kept only in case something else refers to it.
MESH_DIR = r'C:\Users\giles\Github\vesselFM\data\inference\basefoundation\meshes'

# STL files go to a "meshes_stl" folder beside the prediction directory.
# normpath collapses the "postp\.." hop so the logged paths are readable.
output_mesh_dir = os.path.normpath(os.path.join(PRED_DIR, os.pardir, "meshes_stl"))
os.makedirs(output_mesh_dir, exist_ok=True)

def save_plotly_mesh_to_stl(plotly_mesh_obj, filename_base, case_id_str, base_output_dir):
    """Save a Plotly ``Mesh3d`` trace as ``<base_output_dir>/<case_id_str>_<filename_base>.stl``.

    Parameters
    ----------
    plotly_mesh_obj : go.Mesh3d (or any object with x, y, z, i, j, k attributes)
        Vertex coordinates in ``x``/``y``/``z``; triangle corner indices in ``i``/``j``/``k``.
    filename_base : str
        File-name suffix, also used in log messages.
    case_id_str : str
        Case identifier used as the file-name prefix.
    base_output_dir : str
        Destination directory. It is created if it does not exist.

    Returns
    -------
    str or None
        Path of the written file, or ``None`` when the mesh has no vertices or
        faces and nothing was written.
    """
    coords = [plotly_mesh_obj.x, plotly_mesh_obj.y, plotly_mesh_obj.z]
    corners = [plotly_mesh_obj.i, plotly_mesh_obj.j, plotly_mesh_obj.k]

    # Plotly leaves unset arrays as None. Treat them like empty arrays so such
    # a trace is skipped, instead of failing later on an object-dtype index.
    if any(arr is None or len(arr) == 0 for arr in coords + corners):
        print(f"Mesh for {filename_base} (Case: {case_id_str}) is empty. Skipping STL export.")
        return None

    # (N, 3) vertex positions and (M, 3) integer corner indices per triangle.
    verts_np = np.column_stack([np.asarray(arr, dtype=float) for arr in coords])
    faces_np = np.column_stack([np.asarray(arr, dtype=np.intp) for arr in corners])

    # Fancy indexing yields (M, 3, 3): the three corner positions of each triangle.
    triangles = verts_np[faces_np]

    # numpy-stl stores one record per triangle; fill its corner coordinates.
    stl_data = mesh.Mesh(np.zeros(triangles.shape[0], dtype=mesh.Mesh.dtype))
    stl_data.vectors = triangles

    os.makedirs(base_output_dir, exist_ok=True)
    filepath = os.path.join(base_output_dir, f"{case_id_str}_{filename_base}.stl")

    stl_data.save(filepath)
    print(f"Mesh saved to {filepath}")
    return filepath

# Export the prediction and ground-truth meshes (intersection/missing are not saved).
for trace, label in ((mesh_pred, "prediction_mesh"), (mesh_gt, "ground_truth_mesh")):
    save_plotly_mesh_to_stl(trace, label, CASE_ID, output_mesh_dir)


Mesh saved to C:\Users\giles\Github\vesselFM\data\inference\1e-3LRDecay-RandCrop\postp\..\meshes_stl\6_prediction_mesh.stl
Mesh saved to C:\Users\giles\Github\vesselFM\data\inference\1e-3LRDecay-RandCrop\postp\..\meshes_stl\6_ground_truth_mesh.stl


Load the saved STL meshes and display them (this cell re-imports everything, so it runs on its own)

In [3]:
import time
import os
import numpy as np
import plotly.graph_objects as go
from stl import mesh

print("Initializing STL loading cell...")


def load_stl_to_plotly_mesh(stl_filepath, mesh_name, color='cyan', opacity=0.8):
    """Read an STL file into a Plotly ``Mesh3d`` trace, printing progress and timings.

    numpy-stl exposes the triangles as per-triangle corner coordinates. Identical
    corners are merged here, so Plotly receives an indexed mesh: unique vertices
    plus i/j/k index triples.

    Returns
    -------
    go.Mesh3d or None
        ``None`` (after printing a message) when ``stl_filepath`` does not exist.
    """
    start = time.time()
    print(f"Checking file: {stl_filepath}")
    if not os.path.exists(stl_filepath):
        print(f"File not found: {stl_filepath}")
        return None
    print("File found. Reading STL...")

    stl_mesh = mesh.Mesh.from_file(stl_filepath)
    print(f"STL file loaded. Time taken: {time.time() - start:.2f} s")

    dedup_start = time.time()
    # Flatten the triangles' corners into one (3 * n_triangles, 3) list.
    corners = stl_mesh.vectors.reshape(-1, 3)
    # `inverse` maps every corner back to its row in `vertices`.
    vertices, inverse = np.unique(corners, axis=0, return_inverse=True)
    tri = inverse.reshape(-1, 3)
    print(f"Processed vertices/faces in {time.time() - dedup_start:.2f} s")

    x, y, z = vertices.T
    i, j, k = tri.T
    return go.Mesh3d(
        x=x, y=y, z=z,
        i=i, j=j, k=k,
        color=color, opacity=opacity, name=mesh_name
    )

# Standalone viewer: re-declares the case and paths so this cell works after a
# kernel restart without the cells above.
# NOTE(review): if run in the same session as the first cell, these lines
# overwrite CASE_ID / PRED_DIR / output_mesh_dir; keep the values in sync.
CASE_ID         = "6"
print("Starting STL visualization setup...")
PRED_DIR        = r"C:\Users\giles\Github\vesselFM\data\inference\1e-3LRDecay-RandCrop\postp"  # Change this to your prediction directory
# Same "<PRED_DIR>/../meshes_stl" folder the export cell writes to; normpath
# collapses the ".." so the logged paths are readable.
output_mesh_dir = os.path.normpath(os.path.join(PRED_DIR, os.pardir, "meshes_stl"))
stl_pred        = os.path.join(output_mesh_dir, f"{CASE_ID}_prediction_mesh.stl")
stl_gt          = os.path.join(output_mesh_dir, f"{CASE_ID}_ground_truth_mesh.stl")

print(f"Loading prediction mesh: {stl_pred}")
pred_mesh = load_stl_to_plotly_mesh(stl_pred, "Prediction Mesh",  color='cyan',  opacity=0.9)

print(f"Loading ground truth mesh: {stl_gt}")
gt_mesh   = load_stl_to_plotly_mesh(stl_gt,   "Ground Truth Mesh", color='red',   opacity=0.4)

print("Preparing data for visualization...")
# load_stl_to_plotly_mesh returns None for a missing file. Test for that
# explicitly rather than relying on the truthiness of a Plotly trace object.
data = [m for m in (pred_mesh, gt_mesh) if m is not None]
if not data:
    print(f"No STL meshes found in {output_mesh_dir}; run the export cell first.")

print("Creating final figure...")
fig = go.Figure(data=data)
fig.update_layout(
    width=800, height=700,
    title=f"Loaded STL Meshes for Case {CASE_ID}",
    scene=dict(
        xaxis_title='X', yaxis_title='Y', zaxis_title='Z',
        aspectmode='data'
    )
)
print("Displaying figure...")
fig.show()

Initializing STL loading cell...
Starting STL visualization setup...
Loading prediction mesh: C:\Users\giles\Github\vesselFM\data\inference\1e-3LRDecay-RandCrop\postp\..\meshes_stl\6_prediction_mesh.stl
Checking file: C:\Users\giles\Github\vesselFM\data\inference\1e-3LRDecay-RandCrop\postp\..\meshes_stl\6_prediction_mesh.stl
File found. Reading STL...
STL file loaded. Time taken: 0.02 s
Processed vertices/faces in 0.24 s
Loading ground truth mesh: C:\Users\giles\Github\vesselFM\data\inference\1e-3LRDecay-RandCrop\postp\..\meshes_stl\6_ground_truth_mesh.stl
Checking file: C:\Users\giles\Github\vesselFM\data\inference\1e-3LRDecay-RandCrop\postp\..\meshes_stl\6_ground_truth_mesh.stl
File found. Reading STL...
STL file loaded. Time taken: 0.02 s
Processed vertices/faces in 0.23 s
Preparing data for visualization...
Creating final figure...
Displaying figure...


Previous attempt: plotting voxels with Matplotlib (very slow; kept commented out for reference only)

In [None]:
# import os
# import numpy as np
# import nibabel as nib
# import matplotlib.pyplot as plt
# from mpl_toolkits.mplot3d import Axes3D
# from ipywidgets import interact, IntSlider, VBox, Output
# from IPython.display import display

# # Define base paths
# MASK_DIR = r"C:\Users\giles\Github\vesselFM\data\inference"
# SCAN_DIR = r"C:\Users\giles\Github\vesselFM\data\d_real\ImageCAS-Raw"

# # Global variables
# scan_data_global = None
# mask_data_global = None
# fig_global = None
# ax_global = None
# out_plot_global = None # This will hold the plot output
# current_case_id = None 

# def load_data(case_id_str):
#     """Loads scan and mask data for a given case ID and sets global variables."""
#     global current_case_id, mask_data_global, scan_data_global
#     current_case_id = case_id_str 

#     mask_filename = f"{case_id_str}_pred.nii.gz"
#     scan_filename = f"{case_id_str}.img.nii.gz"
#     mask_path = os.path.join(MASK_DIR, mask_filename)
#     scan_path = os.path.join(SCAN_DIR, scan_filename)

#     if not os.path.exists(mask_path) or not os.path.exists(scan_path):
#         print(f"Error: File not found. Mask: {mask_path}, Scan: {scan_path}")
#         scan_data_global, mask_data_global = None, None
#         return None, None
#     try:
#         print(f"Loading scan: {scan_path}")
#         scan_nii = nib.load(scan_path)
#         scan_data_global = scan_nii.get_fdata()
#         print(f"Scan data loaded. Shape: {scan_data_global.shape}")

#         print(f"Loading mask: {mask_path}")
#         mask_nii = nib.load(mask_path)
#         mask_data_raw = mask_nii.get_fdata()
#         mask_data_global = mask_data_raw > 0.5 
#         print(f"Mask data loaded and thresholded. Shape: {mask_data_global.shape}, Sum True: {np.sum(mask_data_global)}")
        
#         if np.sum(mask_data_global) > 0:
#             true_voxel_indices = np.argwhere(mask_data_global)
#             print(f"Min/Max coords of True voxels: {true_voxel_indices.min(axis=0)} / {true_voxel_indices.max(axis=0)}")
        
#         return scan_data_global, mask_data_global
#     except Exception as e:
#         print(f"Error loading NIfTI files: {e}")
#         scan_data_global, mask_data_global = None, None
#         return None, None

# def normalize_slice(slice_data):
#     """Normalize slice data to [0, 1] for colormapping."""
#     min_val = np.min(slice_data)
#     max_val = np.max(slice_data)
#     if max_val - min_val > 1e-9: 
#         return (slice_data - min_val) / (max_val - min_val)
#     return np.zeros_like(slice_data)

# def update_plot(slice_idx):
#     """Updates the 3D plot with a sub-mask and the selected scan slice."""
#     global scan_data_global, mask_data_global, ax_global, fig_global, out_plot_global, current_case_id

#     if mask_data_global is None or scan_data_global is None or ax_global is None or fig_global is None or out_plot_global is None:
#         print("Global variables not fully initialized. Cannot update plot.")
#         return
    
#     with out_plot_global: # Crucial: all plotting actions for the output area happen here
#         out_plot_global.clear_output(wait=True) # Clear previous content of the output widget
#         ax_global.clear() # Clear the axes for redrawing

#         # 1. Plot the FULL mask
#         print(f"Plotting full mask_data_global (shape {mask_data_global.shape}, sum {np.sum(mask_data_global)}). This might take a moment...")
#         if np.sum(mask_data_global) > 0:
#             ax_global.voxels(mask_data_global, facecolors='cyan', edgecolor='k', alpha=0.3, linewidth=0.1) # Reduced linewidth for dense plots
#             print("Full mask plotting call completed.")
#         else:
#             print("Full mask is empty. Not plotting voxels.")

#         # --- Commented out sub-mask plotting ---
#         # x_start, x_end = 0, 20 
#         # y_start, y_end = 0, 20
#         # z_start, z_end = 155, 175 
#         # x_end = min(x_end, mask_data_global.shape[0])
#         # y_end = min(y_end, mask_data_global.shape[1])
#         # z_end = min(z_end, mask_data_global.shape[2])
#         # x_start = min(x_start, x_end -1 if x_end > 0 else 0)
#         # y_start = min(y_start, y_end -1 if y_end > 0 else 0)
#         # z_start = min(z_start, z_end -1 if z_end > 0 else 0)
#         # sub_mask_to_plot = mask_data_global[x_start:x_end, y_start:y_end, z_start:z_end]
#         # if np.sum(sub_mask_to_plot) > 0:
#         #     print(f"Plotting sub-mask (shape {sub_mask_to_plot.shape}) from original region [{x_start}:{x_end}, {y_start}:{y_end}, {z_start}:{z_end}] at plot origin.")
#         #     ax_global.voxels(sub_mask_to_plot, facecolors='green', edgecolor='k', alpha=0.7, linewidth=0.5)
#         # else:
#         #     print(f"Sub-mask from region [{x_start}:{x_end}, {y_start}:{y_end}, {z_start}:{z_end}] is empty. Not plotting sub-mask.")

#         # 2. Plot the 2D scan slice
#         scan_slice = scan_data_global[:, :, slice_idx]
#         scan_slice_normalized = normalize_slice(scan_slice)
#         xx, yy = np.meshgrid(np.arange(scan_data_global.shape[0]), 
#                              np.arange(scan_data_global.shape[1]))
#         cmap = plt.cm.gray
#         facecolors = cmap(scan_slice_normalized.T)
#         # Set rstride and cstride to 1 for a solid surface, removing checkerboard
#         ax_global.plot_surface(xx, yy, np.full(xx.shape, slice_idx),
#                                facecolors=facecolors, rstride=1, cstride=1, shade=False, alpha=0.7) # Adjusted alpha for better visibility with voxels
#         print(f"Plotting scan slice at Dim2 = {slice_idx}")

#         # Set labels and limits for the entire volume
#         ax_global.set_xlabel(f'Dim 0 (up to {scan_data_global.shape[0]})')
#         ax_global.set_ylabel(f'Dim 1 (up to {scan_data_global.shape[1]})')
#         ax_global.set_zlabel(f'Dim 2 (up to {scan_data_global.shape[2]})')
#         ax_global.set_title(f'Case: {current_case_id}, Scan Slice: {slice_idx}')
        
#         ax_global.set_xlim([0, scan_data_global.shape[0]])
#         ax_global.set_ylim([0, scan_data_global.shape[1]])
#         ax_global.set_zlim([0, scan_data_global.shape[2]])

#         display(fig_global) # This re-renders the fig_global into the out_plot_global

#     print(f"--- update_plot finished for slice_idx: {slice_idx} ---")


# def run_visualization_notebook():
#     """Main function to set up and run the visualization in a notebook."""
#     global scan_data_global, mask_data_global, fig_global, ax_global, out_plot_global, current_case_id

#     print("--- Vessel FM 3D Mask Visualization (Notebook Mode) ---")
#     case_id_str_input = input("Enter Case ID (e.g., 6): ")
    
#     load_data(case_id_str_input) 
#     if scan_data_global is None or mask_data_global is None:
#         print("Failed to load data. Exiting.")
#         return
    
#     # # Initialize Matplotlib figure and Output widget ONCE, or reuse if cell is re-run
#     # if fig_global is None or not plt.fignum_exists(fig_global.number): # If no fig or fig was closed
#     #     print("Creating new figure, axes, and output widget.")
#     #     fig_global = plt.figure(figsize=(10, 8))
#     #     ax_global = fig_global.add_subplot(111, projection='3d')
#     #     out_plot_global = Output() # Create the output widget to hold the plot
#     # else:
#     #     print("Reusing existing figure and axes. Clearing output widget for new plot.")
#     #     if ax_global is None or ax_global not in fig_global.axes : # If axes were somehow removed
#     #          ax_global = fig_global.add_subplot(111, projection='3d')
#     #     out_plot_global.clear_output(wait=True) # Clear previous content if any

#     # max_slice_idx = scan_data_global.shape[2] - 1
#     # slice_slider = IntSlider(min=0, max=max_slice_idx, step=1, value=max_slice_idx // 2,
#     #                          description='Scan Slice (Dim2):', continuous_update=False,
#     #                          layout={'width': '80%'}) 

#     # # Display the interactive components: slider and the output area for the plot
#     # # The VBox will be displayed below the cell. Updates to the plot will happen inside out_plot_global.
#     # display(VBox([slice_slider, out_plot_global]))
#     # print("Displayed VBox with slider and output area.")

#     # # Connect slider to update_plot. interact will call update_plot initially.
#     # interact(update_plot, slice_idx=slice_slider)
#     # print("`interact` has been set up. Initial plot should be generated by `update_plot`.")
    
#     #----------


#     import numpy as np
#     from skimage import measure
#     import plotly.graph_objects as go
#     from ipywidgets import interact, IntSlider

#     # Precompute the mesh once
#     verts, faces, _, _ = measure.marching_cubes(mask_data_global, level=0.5)

#     # Extract voxel dimensions
#     nx, ny, nz = scan_data_global.shape

#     # Make the base mesh trace
#     mesh3d = go.Mesh3d(
#         x=verts[:, 0], y=verts[:, 1], z=verts[:, 2],
#         i=faces[:, 0], j=faces[:, 1], k=faces[:, 2],
#         opacity=0.2, color='black', flatshading=True,
#         name='Vessel Surface'
#     )

#     def plot_with_slice(z_idx):
#         # Extract the 2D slice and normalize
#         slice_img = scan_data_global[:, :, z_idx]
#         norm = (slice_img - slice_img.min()) / (slice_img.ptp() + 1e-9)
        
#         # Build a textured surface at z = z_idx
#         xx, yy = np.meshgrid(np.arange(nx), np.arange(ny), indexing='ij')
#         surface = go.Surface(
#             x=xx, y=yy, z=np.full_like(xx, z_idx),
#             surfacecolor=norm.T,  # transpose so axes align
#             colorscale='Gray',
#             cmin=0, cmax=1,
#             opacity=0.7,
#             showscale=False,
#             name=f'Slice {z_idx}'
#         )

#         fig = go.Figure([mesh3d, surface])
#         fig.update_layout(
#             title=f'Case {current_case_id}: 3D Mask + Slice at z={z_idx}',
#             scene=dict(
#                 xaxis=dict(title='X'),
#                 yaxis=dict(title='Y'),
#                 zaxis=dict(title='Z'),
#                 aspectmode='data'
#             ),
#             margin=dict(l=0, r=0, b=0, t=30)
#         )
#         fig.show()

#     # Slider to choose the slice
#     max_z = nz - 1
#     slider = IntSlider(value=max_z//2, min=0, max=max_z, step=1,
#                     description='Slice Z:', continuous_update=False)

#     interact(plot_with_slice, z_idx=slider)


    
# run_visualization_notebook()

--- Vessel FM 3D Mask Visualization (Notebook Mode) ---
Loading scan: C:\Users\giles\Github\vesselFM\data\d_real\ImageCAS-Raw\6.img.nii.gz
Scan data loaded. Shape: (512, 512, 275)
Loading mask: C:\Users\giles\Github\vesselFM\data\inference\6_pred.nii.gz
Mask data loaded and thresholded. Shape: (512, 512, 275), Sum True: 319739
Min/Max coords of True voxels: [0 0 0] / [511 511 274]


interactive(children=(IntSlider(value=137, continuous_update=False, description='Slice Z:', max=274), Output()…