In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pydicom
import vtk
from pydicom.errors import InvalidDicomError
from pydicom.multival import MultiValue
from skimage import measure
from stl import mesh

In [None]:
# Directory scanned for the input DICOM slices (path relative to the notebook's working directory).
dicom_dir = "data/dicom_files/"

In [None]:
def is_suitable_for_3d(slice):
    """Return True when a DICOM dataset looks usable as an axial slice of a 3D volume.

    The dataset's ImageType must contain 'AXIAL' and at least one of
    'ORIGINAL' or 'PRIMARY'. A missing ImageType counts as empty (not
    suitable); a single non-list value is treated as a one-element list.
    """
    flags = slice.get('ImageType', [])
    # pydicom returns MultiValue for multi-valued elements; wrap anything else.
    if not isinstance(flags, (list, MultiValue)):
        flags = [flags]
    is_acquired = any(tag in flags for tag in ('ORIGINAL', 'PRIMARY'))
    return is_acquired and 'AXIAL' in flags


def show_images_grid(pixel_arrays, images_per_row=5, cmap='gray'):
    """Display a sequence of 2D images in a grid of subplots.

    Args:
        pixel_arrays: A list or numpy array of 2D pixel arrays (images).
        images_per_row: Number of subplot columns per row; must be >= 1.
        cmap: The colormap to use (default is 'gray').

    Does nothing when ``pixel_arrays`` is empty.

    Raises:
        ValueError: if ``images_per_row`` is less than 1.
    """
    if images_per_row < 1:
        raise ValueError(f"images_per_row must be >= 1, got {images_per_row}")

    num_images = len(pixel_arrays)
    if num_images == 0:
        return  # nothing to draw; plt.subplots(0, ...) would raise

    num_rows = (num_images + images_per_row - 1) // images_per_row  # ceiling division

    # squeeze=False always yields a 2D array of Axes, including the 1-row and
    # 1-column cases where plt.subplots would otherwise return a 1D array or a bare Axes.
    fig, axes = plt.subplots(
        num_rows, images_per_row, figsize=(20, 5 * num_rows), squeeze=False
    )
    flat_axes = axes.ravel()

    for ax, pixel_array in zip(flat_axes, pixel_arrays):
        ax.imshow(pixel_array, cmap=cmap)
        ax.axis('off')

    # Hide the trailing subplots left empty when num_images is not a multiple of images_per_row.
    for ax in flat_axes[num_images:]:
        ax.axis('off')
        ax.set_visible(False)

    plt.tight_layout()  # fit subplots into the figure area
    plt.show()

In [None]:
# Read every regular file in dicom_dir and keep the slices usable for 3D reconstruction.
slices = []
for filename in sorted(os.listdir(dicom_dir)):  # sorted: deterministic order, independent of the filesystem
    filepath = os.path.join(dicom_dir, filename)
    if not os.path.isfile(filepath):
        continue  # skip sub-directories
    try:
        ds = pydicom.dcmread(filepath)
    except InvalidDicomError:
        # Stray non-DICOM files (e.g. .DS_Store, notes) should not abort the whole load.
        print(f"Skipping non-DICOM file: {filename}")
        continue
    if is_suitable_for_3d(ds):
        slices.append(ds)

# Sort slices by Instance Number to ensure correct order.
# NOTE(review): InstanceNumber is not guaranteed to follow spatial position;
# sorting by ImagePositionPatient may be more robust — confirm for this dataset.
slices.sort(key=lambda ds: ds.InstanceNumber)


In [None]:
assert slices, f"No suitable slices found in {dicom_dir!r}"
print(f"Found {len(slices)} suitable slices for 3D reconstruction.")
print("Sample slice metadata:")

# Show only acquisition/geometry tags: the full dataset repr also prints patient
# identifiers (PatientName, PatientID, ...), which should not end up in saved notebook output.
sample_slice = slices[0]
for keyword in (
    "Modality", "ImageType", "Rows", "Columns", "PixelSpacing", "SliceThickness",
    "InstanceNumber", "ImagePositionPatient", "RescaleSlope", "RescaleIntercept",
):
    print(f"  {keyword}: {sample_slice.get(keyword, '<missing>')}")

# Decode pixel data only for the previewed slices; decoding all of them first is wasted work.
show_images_grid([s.pixel_array for s in slices[:40]], images_per_row=5)

In [None]:
# get spacing between slices
x_spacings = [s.PixelSpacing[0] for s in slices]
y_spacings = [s.PixelSpacing[1] for s in slices]
z_spacings = [s.SliceThickness for s in slices]

assert np.unique(x_spacings).size == 1, "x_spacing is not consistent"
assert np.unique(y_spacings).size == 1, "y_spacing is not consistent"
assert np.unique(z_spacings).size == 1, "z_spacing is not consistent"

x_spacing = x_spacings[0]
y_spacing = y_spacings[0]
z_spacing = z_spacings[0]

rescaled_slope = slices[0].RescaleSlope
rescaled_intercept = slices[0].RescaleIntercept
window_center = slices[0].WindowCenter
window_width = slices[0].WindowWidth

print(f"Pixel Spacing: {x_spacing} x {y_spacing} mm")
print(f"Slice Thickness: {z_spacing} mm")

def process_pixel_data(pixel_data, slope=None, intercept=None, center=None, width=None):
    """Convert stored pixel values into an 8-bit windowed image.

    Steps:
      1. Rescale stored values to modality units (HU for CT): ``value * slope + intercept``.
      2. Clip to the window ``[center - width / 2, center + width / 2]``.
      3. Map that window linearly onto 0..255.

    Args:
        pixel_data: Array-like of stored pixel values (e.g. ``Dataset.pixel_array``).
        slope: Rescale slope; defaults to the notebook-level ``rescaled_slope``.
        intercept: Rescale intercept; defaults to the notebook-level ``rescaled_intercept``.
        center: Window center; defaults to the notebook-level ``window_center``.
        width: Window width; defaults to the notebook-level ``window_width``.
            For multi-valued center/width (several presets) the first value is used.

    Returns:
        ``np.uint8`` array with the same shape as ``pixel_data``.

    Raises:
        ValueError: if the window width is not positive.
    """
    # Fall back to the values read from the first slice in the notebook.
    if slope is None:
        slope = rescaled_slope
    if intercept is None:
        intercept = rescaled_intercept
    if center is None:
        center = window_center
    if width is None:
        width = window_width

    slope = float(slope)
    intercept = float(intercept)
    center = float(np.ravel(center)[0])
    width = float(np.ravel(width)[0])
    if width <= 0:
        raise ValueError(f"window width must be positive, got {width}")

    # Apply rescaling (in float, so int16/uint16 stored values cannot overflow).
    hu_data = np.asarray(pixel_data, dtype=np.float64) * slope + intercept

    # Apply windowing
    min_val = center - width / 2
    max_val = center + width / 2
    windowed_data = np.clip(hu_data, min_val, max_val)

    # Normalize the window onto 0..255 (astype truncates toward zero).
    normalized_data = (windowed_data - min_val) / (max_val - min_val) * 255
    return normalized_data.astype(np.uint8)

# Stored (raw) pixel values, one 2D array per slice, in the sorted slice order.
# For 8-bit windowed images use process_pixel_data(s.pixel_array) instead.
pixel_arrays = [s.pixel_array for s in slices]
assert pixel_arrays, "no slices loaded"

# All slices must share one (rows, cols) shape to be stacked into a volume.
# Compare the shape tuples themselves: np.unique on a list of shapes flattens them into
# individual dimension values, which wrongly rejects consistent non-square images.
slice_shapes = {pa.shape for pa in pixel_arrays}
assert len(slice_shapes) == 1, f"pixel arrays are not the same size: {sorted(slice_shapes)}"

In [None]:
# Plot the distribution of pixel values across all slices.
pixel_values = np.concatenate([pa.ravel() for pa in pixel_arrays])

fig, ax = plt.subplots()
# Let hist bin over the data's actual range: pixel_arrays holds raw stored values
# (commonly 12/16-bit), so a fixed (0, 255) range would silently drop most of them.
ax.hist(pixel_values, bins=100, color='gray', alpha=0.6)
ax.set(xlabel='Pixel Value', ylabel='Frequency', title='Pixel Value Distribution')
plt.show()

In [None]:
# Extract an iso-surface from the volume with marching cubes and save it as an STL file.

# Iso-surface level in modality units (Hounsfield units for CT).
# NOTE(review): ~300 HU is a common bone threshold — an assumption; tune for the
# tissue/modality of interest.
SURFACE_LEVEL = 300.0

# Stack slices into a (slice, row, col) volume and convert stored values to modality units.
# Casting the raw 12/16-bit stored values straight to uint8 would wrap them modulo 256.
volume = (
    np.stack(pixel_arrays).astype(np.float32) * float(rescaled_slope)
    + float(rescaled_intercept)
)

# marching_cubes scales the vertices by `spacing`, given per array axis: (slice, row, col).
# DICOM PixelSpacing is ordered (row spacing, column spacing).
row_spacing, col_spacing = (float(v) for v in slices[0].PixelSpacing)
verts, faces, _, _ = measure.marching_cubes(
    volume,
    level=SURFACE_LEVEL,
    spacing=(float(z_spacing), row_spacing, col_spacing),
)

# Save the mesh: one STL facet per marching-cubes triangle.
mesh_obj = mesh.Mesh(np.zeros(faces.shape[0], dtype=mesh.Mesh.dtype))
mesh_obj.vectors[:] = verts[faces]
mesh_obj.save("output.stl")

In [None]:
def visualize_stl_vtk(stl_file):
    """Open an interactive VTK render window showing the mesh stored in an STL file.

    Args:
        stl_file: Path of the STL file to load and display.
    """
    # Data pipeline: STL reader -> poly-data mapper -> actor.
    stl_reader = vtk.vtkSTLReader()
    stl_reader.SetFileName(stl_file)
    stl_reader.Update()

    poly_mapper = vtk.vtkPolyDataMapper()
    poly_mapper.SetInputConnection(stl_reader.GetOutputPort())

    mesh_actor = vtk.vtkActor()
    mesh_actor.SetMapper(poly_mapper)

    # Scene: a single renderer (dark blue background) in a window driven by an interactor.
    scene = vtk.vtkRenderer()
    scene.AddActor(mesh_actor)
    scene.SetBackground(0.1, 0.2, 0.3)

    window = vtk.vtkRenderWindow()
    window.AddRenderer(scene)

    interactor = vtk.vtkRenderWindowInteractor()
    interactor.SetRenderWindow(window)

    window.Render()
    interactor.Start()

# Display the STL mesh written by the marching-cubes cell.
visualize_stl_vtk(stl_file="output.stl")