In [1]:
import tensorflow as tf
import os
import numpy as np
import random
from tqdm import tqdm
from skimage.io import imread, imshow
from skimage.transform import resize
import matplotlib.pyplot as plt

In [None]:


# Set random seed for reproducibility.
# Three independent generators are in play: NumPy (used for the dataset
# shuffle), Python's `random` (used to pick the sanity-check sample) and
# TensorFlow (used for weight initialisation and dropout). All three must be
# seeded, otherwise training results differ from run to run.
seed = 42
np.random.seed(seed)
random.seed(seed)
tf.random.set_seed(seed)
# NOTE(review): some GPU kernels are still non-deterministic even with a
# fixed seed -- confirm whether bit-exact runs are required.

# Image parameters: every image and mask is resized to
# IMG_HEIGHT x IMG_WIDTH, and only the first IMG_CHANNELS channel(s) of each
# input image are kept.
IMG_WIDTH = 128
IMG_HEIGHT = 128
IMG_CHANNELS = 1

# Folders containing the data. Each folder is scanned for "<name>.png" images
# and for the mask files whose names start with "<name>_mask".
folders = ['benign', 'malignant', 'normal']

images = []
masks = []


def _load_image(path):
    """Read one image and return it resized to (IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS).

    Pixel values keep their original range (`preserve_range=True`), so the
    result is a float array holding the source intensities.
    """
    img = imread(path)
    if img.ndim == 2:
        # Grayscale files are returned without a channel axis; add one so the
        # channel slice below does not raise an IndexError.
        img = img[:, :, np.newaxis]
    img = img[:, :, :IMG_CHANNELS]
    return resize(img, (IMG_HEIGHT, IMG_WIDTH), mode='constant', preserve_range=True)


def _load_mask(path):
    """Read one mask file and return it as a boolean (IMG_HEIGHT, IMG_WIDTH, 1) array."""
    m = imread(path)
    if m.ndim == 3:
        # Some mask files may be stored with colour channels; keep the first
        # one (assumes all channels carry the same annotation -- TODO confirm).
        m = m[:, :, 0]
    # Nearest-neighbour resizing without anti-aliasing keeps the mask strictly
    # two-valued. Interpolating would create faint intermediate values along
    # the border that a later boolean conversion would count as foreground.
    m = resize(m, (IMG_HEIGHT, IMG_WIDTH), mode='constant', order=0,
               anti_aliasing=False, preserve_range=True)
    return np.expand_dims(m > 0, axis=-1)


print("Processing images and corresponding masks...")
for folder in folders:
    # os.listdir() returns entries in arbitrary order; sorting makes the
    # sample order (and therefore any seeded shuffle of it) repeatable.
    file_list = sorted(os.listdir(folder))
    # Get only the image files (exclude mask files)
    image_files = [f for f in file_list if f.endswith('.png') and '_mask' not in f]
    for img_file in tqdm(image_files, desc=f"Processing {folder}"):
        images.append(_load_image(os.path.join(folder, img_file)))

        # Start from an empty mask; images without any mask file keep it.
        mask = np.zeros((IMG_HEIGHT, IMG_WIDTH, 1), dtype=bool)
        # Match on "<base name>_mask" rather than on the base name alone so
        # that, e.g., "img1" can never pick up the masks of "img10".
        mask_prefix = os.path.splitext(img_file)[0] + '_mask'
        mask_files = [f for f in file_list if f.startswith(mask_prefix)]
        for m_file in mask_files:
            # Union of all masks belonging to this image
            mask = np.logical_or(mask, _load_mask(os.path.join(folder, m_file)))
        masks.append(mask)

# Stack the per-sample lists into two batched arrays: images as 8-bit
# intensities, masks as booleans (any non-zero mask value becomes True).
X = np.array(images, dtype=np.uint8)
Y = np.array(masks, dtype=bool)

print("Total images:", len(X))

# Shuffle: draw a single permutation of the sample indices and apply it to
# both arrays so that every image stays paired with its own mask.
indices = np.arange(len(X))
np.random.shuffle(indices)
X, Y = X[indices], Y[indices]

# Hold out the last 10% of the shuffled samples for validation and train on
# the first 90%.
split_index = int(0.9 * len(X))
X_train, X_val = X[:split_index], X[split_index:]
Y_train, Y_val = Y[:split_index], Y[split_index:]

print("Training samples:", len(X_train), "Validation samples:", len(X_val))

# Build U-Net model
def _double_conv(x, filters, dropout_rate):
    """Apply Conv2D -> Dropout -> Conv2D, the unit repeated at every U-Net level.

    Both convolutions use 3x3 kernels, ReLU, He-normal initialisation and
    'same' padding, so the spatial size of `x` is unchanged.
    """
    x = tf.keras.layers.Conv2D(filters, (3, 3), activation='relu',
                               kernel_initializer='he_normal', padding='same')(x)
    x = tf.keras.layers.Dropout(dropout_rate)(x)
    return tf.keras.layers.Conv2D(filters, (3, 3), activation='relu',
                                  kernel_initializer='he_normal', padding='same')(x)


def _upsample_and_merge(x, skip, filters):
    """Double the spatial size of `x` and concatenate the encoder tensor `skip`."""
    x = tf.keras.layers.Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(x)
    return tf.keras.layers.concatenate([x, skip])


inputs = tf.keras.layers.Input((IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS))
# Scale 0-255 intensities to 0-1 inside the model. The built-in Rescaling
# layer is used instead of a Lambda wrapping a Python lambda, because the
# latter makes a saved model awkward to reload.
s = tf.keras.layers.Rescaling(1.0 / 255)(inputs)

# Contraction path: each level doubles the filter count and halves the
# resolution (128 -> 64 -> 32 -> 16 -> 8 for the configured input size).
c1 = _double_conv(s, 16, 0.1)
p1 = tf.keras.layers.MaxPooling2D((2, 2))(c1)

c2 = _double_conv(p1, 32, 0.1)
p2 = tf.keras.layers.MaxPooling2D((2, 2))(c2)

c3 = _double_conv(p2, 64, 0.2)
p3 = tf.keras.layers.MaxPooling2D((2, 2))(c3)

c4 = _double_conv(p3, 128, 0.2)
p4 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(c4)

# Bottleneck
c5 = _double_conv(p4, 256, 0.3)

# Expansive path: upsample, merge with the encoder output of the same
# resolution, then refine with another double convolution.
u6 = _upsample_and_merge(c5, c4, 128)
c6 = _double_conv(u6, 128, 0.2)

u7 = _upsample_and_merge(c6, c3, 64)
c7 = _double_conv(u7, 64, 0.2)

u8 = _upsample_and_merge(c7, c2, 32)
c8 = _double_conv(u8, 32, 0.1)

u9 = _upsample_and_merge(c8, c1, 16)
c9 = _double_conv(u9, 16, 0.1)

# 1x1 convolution with sigmoid: one foreground probability per pixel
outputs = tf.keras.layers.Conv2D(1, (1, 1), activation='sigmoid')(c9)

model = tf.keras.Model(inputs=[inputs], outputs=[outputs])
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
model.summary()

# Set up callbacks
# ModelCheckpoint with save_best_only=True writes the model to disk only when
# the monitored quantity (val_loss by default) improves.
checkpointer = tf.keras.callbacks.ModelCheckpoint('model_ultrasound_unet.h5', verbose=1, save_best_only=True)
callbacks = [
    # The checkpointer must be part of this list, otherwise Keras never calls
    # it and no model file is written.
    checkpointer,
    # Stop once val_loss has failed to improve for 2 consecutive epochs
    tf.keras.callbacks.EarlyStopping(patience=2, monitor='val_loss'),
    tf.keras.callbacks.TensorBoard(log_dir='logs')
]

# Train the model; `results` holds the Keras History object returned by fit()
results = model.fit(X_train, Y_train,
                    validation_data=(X_val, Y_val),
                    batch_size=16,
                    epochs=25,
                    callbacks=callbacks)

def _show_image(image, title, vmax):
    """Show a single-channel image in grayscale on its own figure.

    Any length-1 axes (such as a trailing channel axis) are squeezed away,
    and the display range is fixed to [0, vmax] so that brightness means the
    same thing from one plot to the next.
    """
    fig, ax = plt.subplots()
    ax.imshow(np.squeeze(image), cmap='gray', vmin=0, vmax=vmax)
    ax.set_title(title)
    plt.show()


# Perform a sanity check on a random training sample.
# matplotlib is used directly because skimage.io.imshow is deprecated in
# recent scikit-image releases.
ix = random.randint(0, X_train.shape[0] - 1)
_show_image(X_train[ix], "Input Image", vmax=255)

_show_image(Y_train[ix].astype(np.uint8), "Ground Truth Mask", vmax=1)

# Predict on the same image. predict() expects a batch, so a leading batch
# axis of length 1 is added; probabilities above 0.5 count as foreground.
pred = model.predict(np.expand_dims(X_train[ix], axis=0), verbose=1)
pred_mask = (pred[0] > 0.5).astype(np.uint8)
_show_image(pred_mask, "Predicted Mask", vmax=1)
