# Tooth Segmentation with 3D U-Net

In [1]:

# Import necessary libraries
import os
import re
import time

import numpy as np
import pydicom
import tensorflow as tf
from sklearn.model_selection import train_test_split
from tensorflow.keras import layers, models
    

### Step 1: Load and Normalize DICOM Images

In [2]:

# Function to load DICOM images and stack them into a 3D volume
def load_dicom_images(dicom_folder_path):
    """
    Read every ``.dcm`` file in a folder and stack the slices into one volume.

    Files are ordered with a natural (numeric-aware) sort of their names, so
    ``slice_2.dcm`` comes before ``slice_10.dcm``. A plain lexicographic sort
    would place ``slice_10`` first and scramble the slice axis whenever the
    file names are not zero-padded.

    Args:
        dicom_folder_path: Directory that directly contains the ``.dcm`` files.

    Returns:
        A float32 array holding each file's ``pixel_array`` stacked along a
        new last axis (one entry per file).

    Raises:
        ValueError: If the folder contains no ``.dcm`` files.
    """
    def natural_key(file_name):
        # re.split with a capturing group alternates text / digit chunks and
        # always starts with a (possibly empty) text chunk, so odd positions
        # are the digit runs and keys never compare a str against an int.
        chunks = re.split(r'(\d+)', file_name)
        return [int(chunk) if position % 2 else chunk
                for position, chunk in enumerate(chunks)]

    file_names = [f for f in os.listdir(dicom_folder_path) if f.endswith('.dcm')]
    if not file_names:
        raise ValueError(f"No .dcm files found in {dicom_folder_path!r}")

    # The raw name is a tie-breaker so names with equal natural keys
    # (e.g. "01.dcm" vs "1.dcm") still sort deterministically.
    file_names.sort(key=lambda name: (natural_key(name), name))

    # NOTE(review): pixel_array holds the stored pixel values; no rescale
    # slope/intercept is applied here. Confirm that matches what downstream
    # intensity windowing expects.
    slices = [
        pydicom.dcmread(os.path.join(dicom_folder_path, name)).pixel_array
        for name in file_names
    ]
    return np.stack(slices, axis=-1).astype(np.float32)

# Function to normalize the DICOM volumes
def normalize(volume, min_bound=-1000, max_bound=400):
    """
    Linearly rescale ``volume`` into the range [0, 1].

    ``min_bound`` maps to 0 and ``max_bound`` maps to 1; anything outside
    that window is clipped to the nearest end of the range.

    Args:
        volume: Array of intensities to rescale.
        min_bound: Intensity that becomes 0 after scaling.
        max_bound: Intensity that becomes 1 after scaling.

    Returns:
        The rescaled and clipped array.
    """
    window_width = max_bound - min_bound
    scaled = (volume - min_bound) / window_width
    return np.clip(scaled, 0, 1)
    

### Step 2: Load and Prepare Data

In [6]:
import os
import numpy as np

# Dataset locations on disk
input_folder_path = "D:/DIcom gans/Data/raw"  # raw DICOM input volumes
label_folder_path = "D:/DIcom gans/Data/labels"  # ground truth label volumes

# Accumulators that the loading step appends volumes to
input_volumes, target_volumes = [], []

# Load DICOM volume(s) from a folder into a list, optionally normalizing them
def load_and_append_dicom(folder_path, volume_list, is_normalized=True):
    """
    Load one or more DICOM volumes from ``folder_path`` into ``volume_list``.

    If ``folder_path`` directly contains ``.dcm`` files it is treated as a
    single volume. Otherwise every immediate subdirectory is loaded as its own
    volume. Subdirectories are visited in sorted name order: ``os.listdir``
    returns entries in arbitrary order, and separate calls for images and
    labels must produce volumes in the same order to stay paired.

    Args:
        folder_path: Folder holding ``.dcm`` files, or subfolders of them.
        volume_list: List that receives the loaded volumes (mutated in place).
        is_normalized: When True, each volume is passed through ``normalize``
            before being appended. Despite the name, this requests
            normalization rather than describing the data.

    Returns:
        None. Results are appended to ``volume_list``.
    """
    entries = sorted(os.listdir(folder_path))

    if any(name.endswith('.dcm') for name in entries):
        # .dcm files sit directly inside the folder: it is one volume.
        volume_dirs = [folder_path]
    else:
        # Otherwise each subfolder is a separate volume; plain files are skipped.
        candidates = [os.path.join(folder_path, name) for name in entries]
        volume_dirs = [path for path in candidates if os.path.isdir(path)]

    for volume_dir in volume_dirs:
        dicom_volume = load_dicom_images(volume_dir)
        if is_normalized:
            dicom_volume = normalize(dicom_volume)
        volume_list.append(dicom_volume)

# Load input volumes (passed through normalize)
load_and_append_dicom(input_folder_path, input_volumes, is_normalized=True)

# Load label volumes (kept at their stored values, no normalization)
load_and_append_dicom(label_folder_path, target_volumes, is_normalized=False)

# Fail early with a clear message instead of letting a mismatch surface later
# as a ragged array or an obscure error during splitting / training.
if not input_volumes:
    raise ValueError(f"No DICOM volumes found under {input_folder_path!r}")
if len(input_volumes) != len(target_volumes):
    raise ValueError(
        f"Found {len(input_volumes)} input volume(s) but "
        f"{len(target_volumes)} label volume(s); they must be paired one-to-one"
    )
for index, (image, label) in enumerate(zip(input_volumes, target_volumes)):
    if image.shape != label.shape:
        raise ValueError(
            f"Volume {index}: input shape {image.shape} does not match "
            f"label shape {label.shape}"
        )

# Stack into arrays with a leading sample axis. np.stack raises if the
# volumes differ in shape, rather than silently building a ragged array.
input_volumes = np.stack(input_volumes)
target_volumes = np.stack(target_volumes)

# Expand dimensions to add a trailing channel axis (required for U-Net input)
input_volumes = np.expand_dims(input_volumes, axis=-1)
target_volumes = np.expand_dims(target_volumes, axis=-1)


FileNotFoundError: [WinError 3] The system cannot find the path specified: 'D:/DIcom gans/Data/labels'

### Step 3: Split Data into Training and Validation Sets

In [None]:

# Hold out 20% of the volumes for validation; the fixed seed keeps the split reproducible
train_X, val_X, train_y, val_y = train_test_split(
    input_volumes,
    target_volumes,
    test_size=0.2,
    random_state=42,
)
    

### Step 4: Define U-Net Model

In [None]:

# Two stacked Conv3D + BatchNormalization stages, shared by every U-Net level
def _conv_block(x, filters):
    """Apply two 3x3x3 same-padded ReLU convolutions, each followed by batch normalization."""
    for _ in range(2):
        x = layers.Conv3D(filters, 3, activation="relu", padding="same")(x)
        x = layers.BatchNormalization()(x)
    return x


# Define U-Net model for 3D segmentation
def unet_model(input_shape=(128, 128, 64, 1)):
    """
    Build a two-level 3D U-Net that outputs a per-voxel sigmoid probability.

    The encoder downsamples twice with 2x2x2 max pooling and the decoder
    upsamples twice with stride-2 transposed convolutions, concatenating the
    matching encoder features at each level.

    Args:
        input_shape: Shape of one sample with the channel axis last. Every
            axis before the channel axis must be a multiple of 4 (or None).

    Returns:
        An uncompiled Keras model mapping the input volume to a
        single-channel output.

    Raises:
        ValueError: If a known spatial size is not divisible by 4.
    """
    # Two poolings floor-divide each spatial axis by 2, while each transposed
    # convolution doubles it. The skip concatenations only line up when every
    # spatial size is a multiple of 4, so check before building any layers.
    for axis, size in enumerate(input_shape[:-1]):
        if size is not None and size % 4 != 0:
            raise ValueError(
                f"input_shape axis {axis} has size {size}, which is not divisible "
                f"by 4; pad or crop the volume so the skip connections align"
            )

    inputs = layers.Input(input_shape)

    # Encoder
    c1 = _conv_block(inputs, 32)
    p1 = layers.MaxPooling3D((2, 2, 2))(c1)

    c2 = _conv_block(p1, 64)
    p2 = layers.MaxPooling3D((2, 2, 2))(c2)
    p2 = layers.Dropout(0.3)(p2)

    # Bottleneck
    c3 = _conv_block(p2, 128)

    # Decoder
    u1 = layers.Conv3DTranspose(64, 3, strides=(2, 2, 2), padding="same")(c3)
    u1 = layers.concatenate([u1, c2])
    c4 = _conv_block(u1, 64)

    u2 = layers.Conv3DTranspose(32, 3, strides=(2, 2, 2), padding="same")(c4)
    u2 = layers.concatenate([u2, c1])
    c5 = _conv_block(u2, 32)

    # Output Layer: 1x1x1 convolution down to a single sigmoid channel
    outputs = layers.Conv3D(1, 1, activation="sigmoid")(c5)

    model = models.Model(inputs, outputs)
    return model

# Build the network for the per-sample shape (training array shape minus the leading sample axis)
model = unet_model(train_X.shape[1:])
    

### Step 5: Compile the Model

In [None]:

# Configure training: Adam at a 1e-4 learning rate, binary cross-entropy loss, accuracy metric
model.compile(
    loss='binary_crossentropy',
    metrics=['accuracy'],
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
)
    

### Step 6: Train the Model with Time Estimation

In [None]:

# Number of epochs requested; shared by fit() and the time estimate so they cannot drift apart
EPOCHS = 50


class FirstEpochTimer(tf.keras.callbacks.Callback):
    """Time the first epoch and print a projected total training time.

    The projection is first-epoch time multiplied by ``total_epochs``. It is a
    rough upper bound: early stopping can end training sooner than
    ``total_epochs``.
    """

    def __init__(self, total_epochs):
        super().__init__()
        self.total_epochs = total_epochs
        self.estimated_total_time = None  # seconds; set after the first epoch
        self._epoch_start = None

    def on_epoch_begin(self, epoch, logs=None):
        if epoch == 0:
            self._epoch_start = time.time()

    def on_epoch_end(self, epoch, logs=None):
        if epoch == 0 and self._epoch_start is not None:
            first_epoch_time = time.time() - self._epoch_start
            self.estimated_total_time = first_epoch_time * self.total_epochs
            print(f"Time for one epoch: {first_epoch_time:.2f} seconds")
            print(f"Estimated total training time: {self.estimated_total_time / 60:.2f} minutes")


# Callbacks
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint("best_model.h5", save_best_only=True)
early_stopping_cb = tf.keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)
time_estimate_cb = FirstEpochTimer(total_epochs=EPOCHS)

# Wall-clock start of the whole training run
start_time = time.time()

# Train the model
history = model.fit(train_X, train_y,
                    validation_data=(val_X, val_y),
                    epochs=EPOCHS,
                    batch_size=1,
                    callbacks=[checkpoint_cb, early_stopping_cb, time_estimate_cb],
                    verbose=1)

# Report what actually happened. fit() only returns after every epoch has run,
# so this elapsed time is the full training time, not the time for one epoch.
elapsed_time = time.time() - start_time
estimated_total_time = time_estimate_cb.estimated_total_time
epochs_run = len(history.history["loss"])
print(f"Actual training time: {elapsed_time / 60:.2f} minutes over {epochs_run} epoch(s)")
    