In [212]:
# This project focuses on brain CT image segmentation and denoising using Deep Learning.
# Al Smith and Will Newman
# c 2024

# Importing the necessary libraries:
import pydicom
from collections import defaultdict
import torch
import numpy as np
import zipfile
import os
import sys
import nibabel as nib

import matplotlib.pyplot as plt
from ipywidgets import widgets
from scipy.ndimage import label
from mayavi import mlab
from mpl_toolkits.mplot3d import Axes3D

## Tasks:
1) Load data from CQ500
    * Download from: http://headctstudy.qure.ai/dataset
    * Explore the DICOM Header for voxel size and imaging information (ideally the CT machine model)
    * Resolution, dose, parameters, etc.
2) Preprocess the data (noise addition, downsampling, augmentation, etc.)
    * 2a) Add noise similar to low-resolution CT
    * 2b) Downsample the images to lower-resolution scale
    * 2c) Split data into train/test sets
3) Build a 3D U-Net model for segmentation
    * 3a) start with training the model on CQ500 and normal masks
    * 3b) train the model on CQ500 with noise added
    * 3c) train the model on CQ500 with images processed by denoising model
4) Denoising Model for noisy CT images
5) Train the model
6) Evaluate the model

In [26]:
# Extract zip files from CQ500
data_dir = './CQ500'
extract_dir = '/media/hal9000/Database/CQ500_extracted'

# Create the extraction directory if it doesn't exist
os.makedirs(extract_dir, exist_ok=True)

# Archives that could not be extracted; reported at the end instead of aborting the run
skipped = []

# NOTE(review): the range starts at 121 because scan 120 is corrupted; presumably the
# earlier archives were extracted in a previous run -- confirm before a fresh re-run.
for i in range(121, 491): # CQ500 scan 120 corrupted...
    zip_file = data_dir + "/CQ500-CT-{}.zip".format(i)
    try:
        with zipfile.ZipFile(zip_file, 'r') as zip_ref:
            # Extract all the contents of the zip file in the extraction directory
            zip_ref.extractall(path=extract_dir)
    except (FileNotFoundError, zipfile.BadZipFile) as e:
        # A missing or corrupt archive must not stop the remaining extractions
        skipped.append(i)
        print(f"skipping CT {i}: {e}")
        continue
    print("completed extraction of CT", i)

print("Extraction completed.")
if skipped:
    print(f"Skipped {len(skipped)} archive(s): {skipped}")

completed extraction of CT 121
completed extraction of CT 122
completed extraction of CT 123
completed extraction of CT 124
completed extraction of CT 125
completed extraction of CT 126
completed extraction of CT 127
completed extraction of CT 128
completed extraction of CT 129
completed extraction of CT 130
completed extraction of CT 131
completed extraction of CT 132
completed extraction of CT 133
completed extraction of CT 134
completed extraction of CT 135
completed extraction of CT 136
completed extraction of CT 137
completed extraction of CT 138
completed extraction of CT 139
completed extraction of CT 140
completed extraction of CT 141
completed extraction of CT 142
completed extraction of CT 143
completed extraction of CT 144
completed extraction of CT 145
completed extraction of CT 146
completed extraction of CT 147
completed extraction of CT 148
completed extraction of CT 149
completed extraction of CT 150
completed extraction of CT 151
completed extraction of CT 152
complete

In [215]:
# Global Variables
# Intensity bounds, expressed in the units of the stored DICOM pixel values
# (volumes are built from `pixel_array` without any rescaling):
#   lb - lower bound of the display window; also the cut-off passed to
#        largest_connected_component_3d (voxels with value < lb become mask candidates)
#   ub - upper bound of the display window used by apply_window_level
lb, ub = 1040, 1080

In [214]:
def save_nifti(mask, output_dir, series_uid, affine=None):
    """Save the mask as a NIfTI file.

    The mask is cast to int16 and written to
    ``<output_dir>/<series_uid>_mask.nii``.

    Parameters:
    - mask: numpy array holding the mask (any dtype castable to int16).
    - output_dir: Directory the file is written to; it must already exist.
    - series_uid: Identifier used as the file-name prefix.
    - affine: Optional 4x4 voxel-to-world matrix stored in the NIfTI header.
      Defaults to the identity, which records no voxel spacing or orientation.
    """
    if affine is None:
        # Backward-compatible default: identity transform
        affine = np.eye(4)
    nifti_image = nib.Nifti1Image(mask.astype(np.int16), affine=affine)
    nib.save(nifti_image, os.path.join(output_dir, f'{series_uid}_mask.nii'))

In [216]:
# Window/level helper for numpy arrays of DICOM pixel data
def apply_window_level(data, lb, ub):
    """
    Clip `data` to the window [lb, ub] and rescale it linearly to [0, 1].

    Parameters:
    - data: numpy array of pixel values.
    - lb: Lower window bound; values at or below it map to 0.
    - ub: Upper window bound; values at or above it map to 1.

    Returns:
    - Array of the same shape as `data` with values between 0 and 1.
    """
    window_width = ub - lb
    clipped = np.clip(data, lb, ub)
    return (clipped - lb) / window_width

# Remove unnecessary CSF spaces from the mask
def remove_mask_below_slice(volume_mask, slice_index):
    """
    Clear every slice of a 3D mask whose index along the first axis is < slice_index.

    The mask is modified in place and the same array object is returned.

    Parameters:
    - volume_mask: 3D numpy array (boolean) of the mask.
    - slice_index: Integer, the first slice index that is kept.

    Returns:
    - The (mutated) `volume_mask`.
    """
    # Index selecting the leading slices [0, slice_index) across both in-plane axes
    leading_slices = np.s_[0:slice_index, :, :]
    volume_mask[leading_slices] = False
    return volume_mask


In [217]:
# Largest connected component segmentation
def create_spherical_mask(shape, center, radius):
    """
    Build a boolean mask that is True inside a sphere.

    Parameters:
    - shape: Sequence whose first three entries give the (z, y, x) extent of the mask.
    - center: Sequence (z, y, x) with the sphere centre in index coordinates.
    - radius: Sphere radius in voxels; voxels exactly on the surface are included.

    Returns:
    - Boolean array of shape (shape[0], shape[1], shape[2]).
    """
    cz, cy, cx = center[0], center[1], center[2]

    # One coordinate vector per axis, shaped so the three terms broadcast to 3D
    z_offsets = np.arange(shape[0])[:, None, None] - cz
    y_offsets = np.arange(shape[1])[None, :, None] - cy
    x_offsets = np.arange(shape[2])[None, None, :] - cx

    squared_distance = z_offsets ** 2 + y_offsets ** 2 + x_offsets ** 2
    return squared_distance <= radius ** 2


def largest_connected_component_3d(volume_data, lb, center, radius, threshold=0.4):
    """
    Find the largest connected component in a 3D volume that intersects with a specified spherical region.

    Candidate voxels are those with a value strictly below `lb`; the leading
    `threshold` fraction of slices (along the first axis) is excluded before the
    components are labelled.

    Parameters:
    - volume_data: 3D numpy array of DICOM data.
    - lb: Bound used to create the binary mask (`volume_data < lb`).
    - center: Tuple, the center coordinates (z, y, x) of the spherical region.
    - radius: Integer, the radius used to define the spherical region.
    - threshold: Fraction (0..1) of the slice count; slices with an index below
      int(n_slices * threshold) are removed from the candidate mask.

    Returns:
    - 3D mask (boolean array) of the same shape as volume_data for the largest component intersecting the sphere.
      All False when no component exists or none intersects the sphere.
    """
    # Generate the binary mask; the cut-off slice honours `threshold` instead of a
    # hard-coded 0.4 so the parameter actually has an effect
    slice_index = int(volume_data.shape[0] * threshold)
    binary_mask = remove_mask_below_slice(volume_data < lb, slice_index)

    # Label all components
    labeled_volume, num_features = label(binary_mask)
    if num_features == 0:
        return np.zeros_like(volume_data, dtype=bool)  # No components found

    # Generate the spherical mask
    spherical_mask = create_spherical_mask(volume_data.shape, center, radius)

    # Labels that touch the sphere, background (0) excluded
    intersecting_labels = np.unique(labeled_volume[spherical_mask])
    intersecting_labels = intersecting_labels[intersecting_labels != 0]
    if intersecting_labels.size == 0:
        return np.zeros_like(volume_data, dtype=bool)  # Nothing intersects the sphere

    # Voxel count of every label in a single pass (index == label value)
    component_sizes = np.bincount(labeled_volume.ravel(), minlength=num_features + 1)

    # argmax returns the first maximum, so ties resolve to the smallest label
    largest_label = intersecting_labels[np.argmax(component_sizes[intersecting_labels])]
    return labeled_volume == largest_label


In [431]:
def load_dicom_series_volumes(base_directory, idx):
    """
    Load every DICOM series of one CQ500 case as a 3D volume and save a mask per series.

    The case folder is ``<base_directory>/CQ500CT<idx> CQ500CT<idx>/Unknown Study``.
    Every sub-directory found beneath it is treated as one series: its ``.dcm``
    files are read, ordered by ``ImagePositionPatient[2]`` and stacked into a
    volume. For each volume a mask is computed with
    ``largest_connected_component_3d`` (using the notebook-level ``lb``) and
    written into the case folder with ``save_nifti``.

    Parameters:
    - base_directory: Directory that contains the extracted case folders.
    - idx: Case number used to build the case folder name.

    Returns:
    - defaultdict mapping SeriesInstanceUID -> list of 3D numpy volumes.
    """
    series_volumes = defaultdict(list)
    dicom_file_count = 0  # DICOM files read successfully

    case_dir = os.path.join(base_directory, f"CQ500CT{idx} CQ500CT{idx}", "Unknown Study")

    if not os.path.exists(case_dir):
        print(f"Directory does not exist: {case_dir}")
    else:
        for parent, subdirs, _ in os.walk(case_dir):
            for subdir in subdirs:
                series_path = os.path.join(parent, subdir)
                dcm_names = [
                    entry for entry in os.listdir(series_path)
                    if entry.lower().endswith('.dcm')
                ]

                # (dataset, z position) pairs; unreadable files are reported and skipped
                positioned_slices = []
                for slice_file in dcm_names:
                    try:
                        dataset = pydicom.dcmread(os.path.join(series_path, slice_file))
                        z_position = dataset.ImagePositionPatient[2]
                        positioned_slices.append((dataset, z_position))
                        dicom_file_count += 1
                    except Exception as e:
                        print(f"Failed to read {slice_file} as DICOM: {e}")

                if not positioned_slices:
                    continue  # directory without usable DICOM slices

                # Order the slices by z coordinate before stacking them
                positioned_slices.sort(key=lambda pair: pair[1])
                series_uid = positioned_slices[0][0].SeriesInstanceUID
                volume = np.stack([pair[0].pixel_array for pair in positioned_slices])
                series_volumes[series_uid].append(volume)

                # Sphere at the volume centre, radius one sixth of the slice count
                n_slices = volume.shape[0]
                center = (n_slices // 2, volume.shape[1] // 2, volume.shape[2] // 2)
                radius = int(n_slices // 6)

                # Generate mask and save
                component = largest_connected_component_3d(volume, lb, center, radius)
                mask = remove_mask_below_slice(component, n_slices * 2 // 5)
                save_nifti(mask, case_dir, series_uid)
                print("Saved mask to ", case_dir)

    print(f"Processed {dicom_file_count} DICOM files.")
    return series_volumes

# Usage example: load all series of case 50. As a side effect,
# load_dicom_series_volumes also writes one `<series_uid>_mask.nii` file per
# series into the case folder.
# base_directory holds the same location as `extract_dir` in the extraction cell.
base_directory = '/media/hal9000/Database/CQ500_extracted'  # Adjust this path
dicom_volumes = load_dicom_series_volumes(base_directory, 50)
print(f"Number of series with loaded volumes: {len(dicom_volumes)}")

# DICOM Directories that are useful
* Series Number    0, 3, 4, 5, 7, 9, 14, 15, 19, 23, 26, 28, 33, 36, 37, 39, 40, 41, 43, 45, 46,
* Number of Masks: 5, 2, 1, 1, 1, 2,  2,  1,  4,  2,  3,  2,  2,  3,  2,  2,  2,  1,  1,  1,  1,
# DICOM Directories with surgical pathology (less useful)
* Series Number    2, 6, 10, 11, 17, 18, 20, 22, 34, 48
* Number of Masks: 2, 3,  1,  2,  2,  2,  2,  2,  3,  4

In [430]:
# Visualize Loaded DICOM Data

# Requires `dicom_volumes` and `base_directory` from the loading cell above.
series_selection = 0  # position of the series within dicom_volumes
dicom = 50  # case number; must match the idx passed to load_dicom_series_volumes
series_uid = list(dicom_volumes.keys())[series_selection]
volume = dicom_volumes[series_uid][0]  # first volume stored for the selected series

center = (volume.shape[0] // 2, volume.shape[1] // 2, volume.shape[2] // 2)
radius = int(volume.shape[0] // 6)  # Define the radius as desired
# Alternative: recompute the mask instead of loading the saved file
#volume_mask = remove_mask_below_slice(largest_connected_component_3d(volume, lb, center, radius), volume.shape[0] * 2 // 5)

# Build the mask path from base_directory so it always points at the folder that
# load_dicom_series_volumes wrote to, instead of repeating the absolute path
mask_path = os.path.join(
    base_directory,
    f'CQ500CT{dicom} CQ500CT{dicom}',
    'Unknown Study',
    f'{series_uid}_mask.nii',
)
volume_mask = nib.load(mask_path).get_fdata() > 0

# Function to display a single slice
def view_slice(slice_index):
    """
    Show one windowed slice of `volume` with `volume_mask` overlaid in translucent red.

    Reads the notebook-level `volume`, `volume_mask`, `lb` and `ub`.

    Parameters:
    - slice_index: Integer, zero-based index along the first axis of `volume`.
      The figure title shows the one-based slice number.
    """
    fig = plt.figure(figsize=(4, 4))
    windowed_slice = apply_window_level(volume[slice_index], lb, ub)

    # RGBA overlay: red everywhere, alpha 0.5 inside the mask and 0 outside
    red_overlay = np.zeros((*windowed_slice.shape, 4))
    red_overlay[..., 0] = 1.0
    red_overlay[..., 3] = volume_mask[slice_index] * 0.5

    ax = fig.add_subplot(1, 1, 1)
    ax.imshow(windowed_slice, cmap='gray')
    ax.imshow(red_overlay)
    ax.axis('off')
    ax.set_title(f'Slice {slice_index + 1}')
    plt.show()

# Slider to select the slice index; its range covers every index along the
# first axis of `volume` (0 .. volume.shape[0] - 1)
slice_slider = widgets.IntSlider(
    value=0,
    min=0,
    max=volume.shape[0] - 1,  # max slice index
    step=1,
    description='Slice Index:',
    continuous_update=True
)

# Use ipywidgets' interactive functionality to bind the slider and the display function.
# This bare expression is the cell's last statement, so the resulting widget is
# what the cell displays.
widgets.interactive(view_slice, slice_index=slice_slider)

interactive(children=(IntSlider(value=0, description='Slice Index:', max=231), Output()), _dom_classes=('widge…

In [356]:
# Visualize the CSF Space Segmentation Mask in 3D dynamic viewer
def visualize_3d_mask(volume_mask):
    """Render a 3D mask as a translucent red iso-surface in a mayavi window.

    Parameters:
    - volume_mask: 3D numpy array of the mask; it is cast to int before rendering.

    Calls mlab.show() at the end to display the scene.
    """
    # Create a figure (black background, 800x800)
    fig = mlab.figure(bgcolor=(0, 0, 0), size=(800, 800))
    
    # Wrap the mask in a scalar field; the int cast maps True -> 1 and False -> 0
    src = mlab.pipeline.scalar_field(volume_mask.astype(int))
    # Iso-surface levels: mask minimum + 0.5 and mask maximum
    # (0.5 and 1 for a mask that contains both False and True voxels)
    mlab.pipeline.iso_surface(src, contours=[volume_mask.min()+0.5, volume_mask.max()], opacity=0.4, color=(1, 0, 0))
    
    # Set the camera position and roll
    mlab.view(azimuth=180, elevation=180, distance=400)
    mlab.roll(180)
    
    # Add axes and outline for better visual orientation
    mlab.outline(src, color=(1, 1, 1))
    mlab.axes(src, color=(1, 1, 1), xlabel='X', ylabel='Y', zlabel='Z')

    # Show the plot
    mlab.show()

# Uses the volume_mask loaded from the saved NIfTI file in the visualization cell above
visualize_3d_mask(volume_mask)

In [None]:
# Task #1 Convert loaded data into torch training and validation datasets

In [None]:
# Task #2 Preprocess the data
# Task #2a Add noise similar to low-resolution CT

# Task #2b Downsample the images to lower-resolution scale

# Task #2c Split the data into training and testing sets

In [None]:
# Task #3 Build a 3D U-Net model for segmentation
# Task #3a Start with training the model on CQ500 and normal masks

# Task #3b Train the model on CQ500 with noise added

# Task #3c Train the model on CQ500 with images processed by denoising model