In [1]:
import os

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Dropout, Conv2DTranspose, concatenate
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.optimizers import Adam
from scipy import ndimage
from skimage import filters, morphology, measure
from tqdm import tqdm
import cv2

2025-04-27 13:39:56.245240: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1745761196.452576     257 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1745761196.514484     257 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
# Use Mixed Precision (save VRAM)
# Sets the global Keras dtype policy to "mixed_float16".
# NOTE(review): presumably this has to run before any layers are created so
# that they pick up the policy -- confirm against the Keras mixed-precision docs.
from tensorflow.keras import mixed_precision
mixed_precision.set_global_policy("mixed_float16")
print("mixed precision enabled.")

mixed precision enabled.


In [3]:
# Set random seeds for reproducibility
# The same SEED value is applied to NumPy's global RNG and to TensorFlow's
# global seed.
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [4]:
# Parameters
# Spatial size of the U-Net input. It has to match the preprocessed arrays,
# which this notebook's run log reports as 224x224x1, and it must be divisible
# by 16 because the encoder halves the resolution four times.
IMG_SIZE = 224
BATCH_SIZE = 8
EPOCHS = 2
DATASET_ROOT = '/kaggle/input/preprocessed-mammo-splits'
OUTPUT_DIR = './cell_roi_extracted'
SEGMENTATION_MODEL_PATH = './cell_unet_model.keras'

In [5]:
# Create output directory if it doesn't exist
# (exist_ok=True means re-running this cell does not raise when the directory
# is already there)
os.makedirs(OUTPUT_DIR, exist_ok=True)

In [6]:
# Define custom IoU metric
def iou_metric(y_true, y_pred):
    """Intersection-over-Union between a binary mask and a thresholded prediction.

    Args:
        y_true: Ground-truth mask tensor (any numeric dtype; values 0/1).
        y_pred: Predicted probabilities; values above 0.5 count as foreground.

    Returns:
        Scalar float32 tensor: IoU computed over every element of the batch
        at once (not averaged per sample).
    """
    # Bring both operands to float32. Without the cast on y_true, integer
    # masks or float16 tensors cannot be multiplied with the float32 prediction.
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred > 0.5, tf.float32)
    intersection = tf.reduce_sum(y_true * y_pred)
    union = tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) - intersection
    # epsilon keeps the ratio finite when both masks are empty.
    return intersection / (union + tf.keras.backend.epsilon())

In [7]:
def _double_conv(x, filters):
    """Apply two 3x3 same-padded ReLU convolutions with `filters` channels each."""
    x = Conv2D(filters, 3, activation='relu', padding='same')(x)
    return Conv2D(filters, 3, activation='relu', padding='same')(x)


def _up_block(x, skip, filters):
    """Upsample `x` by 2, concatenate the encoder tensor `skip`, then double-conv."""
    up = Conv2DTranspose(filters, 2, strides=(2, 2), padding='same')(x)
    merged = concatenate([skip, up], axis=3)
    return _double_conv(merged, filters)


# Build U-Net model for cell segmentation
def build_unet_model(input_size=(IMG_SIZE, IMG_SIZE, 1)):
    """Build and compile a U-Net for binary segmentation.

    Args:
        input_size: (height, width, channels) of the input images. Height and
            width must be divisible by 16 so the four 2x2 poolings and the
            four 2x upsamplings line up for the skip concatenations.

    Returns:
        A compiled Keras `Model` (Adam at 1e-4, binary cross-entropy loss,
        `accuracy` and `iou_metric` metrics) whose output is a single-channel
        sigmoid probability map with the same spatial size as the input.
    """
    inputs = Input(input_size)

    # Encoder (Downsampling)
    conv1 = _double_conv(inputs, 64)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = _double_conv(pool1, 128)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = _double_conv(pool2, 256)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    conv4 = _double_conv(pool3, 512)
    drop4 = Dropout(0.5)(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)

    # Bridge
    conv5 = _double_conv(pool4, 1024)
    drop5 = Dropout(0.5)(conv5)

    # Decoder (Upsampling) -- each stage reuses the matching encoder output
    conv6 = _up_block(drop5, drop4, 512)
    conv7 = _up_block(conv6, conv3, 256)
    conv8 = _up_block(conv7, conv2, 128)
    conv9 = _up_block(conv8, conv1, 64)

    # Output layer - sigmoid for binary segmentation.
    # dtype='float32' pins this layer to full precision even when a
    # mixed_float16 global policy is active, so the probabilities fed to the
    # loss and metrics are not computed in float16.
    outputs = Conv2D(1, 1, activation='sigmoid', dtype='float32')(conv9)

    model = Model(inputs, outputs)
    model.compile(optimizer=Adam(learning_rate=1e-4),
                  loss='binary_crossentropy',
                  metrics=['accuracy', iou_metric])

    return model

In [8]:
def _load_split(filename):
    """Read the `X` and `y` arrays of one .npz split stored under DATASET_ROOT."""
    # np.load on an .npz returns an archive object that holds the file open;
    # the context manager closes it once the arrays have been read out.
    with np.load(os.path.join(DATASET_ROOT, filename)) as split:
        images, labels = split['X'], split['y']

    # Ensure images are in grayscale for segmentation: average a trailing
    # 3-channel axis down to one channel. Checked per split, so a split is
    # converted even if the others are already single-channel.
    if images.shape[-1] == 3:
        images = np.mean(images, axis=-1, keepdims=True)
    return images, labels


def load_preprocessed_data():
    """Load the train/val/test splits from DATASET_ROOT.

    Reads `train_data.npz`, `val_data.npz` and `test_data.npz`, each of which
    must contain arrays under the keys `X` and `y`. Image arrays whose last
    axis has length 3 are averaged to a single channel.

    Returns:
        Tuple `(X_train, y_train, X_val, y_val, X_test, y_test)`.
    """
    X_train, y_train = _load_split('train_data.npz')
    X_val, y_val = _load_split('val_data.npz')
    X_test, y_test = _load_split('test_data.npz')

    print(f"Loaded data shapes: Train {X_train.shape}, Val {X_val.shape}, Test {X_test.shape}")
    return X_train, y_train, X_val, y_val, X_test, y_test

In [9]:
# Generate masks for cell images using adaptive thresholding
def generate_cell_masks(images, block_size=11, c=2, min_size=30):
    """Build binary masks from images with adaptive thresholding.

    Each image is scaled to 8-bit, thresholded (inverted Gaussian adaptive
    threshold), opened with a 3x3 kernel, and stripped of connected
    components smaller than `min_size` pixels.

    Args:
        images: Iterable of image arrays; each is squeezed to 2-D.
            Values are assumed to be scaled to [0, 1] -- they are multiplied
            by 255 before the uint8 conversion.
        block_size: Neighbourhood size passed to `cv2.adaptiveThreshold`.
        c: Constant subtracted from the weighted mean in the threshold.
        min_size: Minimum component area (in pixels) to keep.

    Returns:
        Array of masks with values 0/1, one per image, each reshaped to the
        shape of its source image.
    """
    # The structuring element does not depend on the image, so build it once.
    kernel = np.ones((3, 3), np.uint8)

    masks = []
    for img in tqdm(images, desc="Generating cell masks"):
        # Convert to single channel and scale to 0-255. Clipping keeps values
        # slightly outside [0, 1] from wrapping around in the uint8 cast.
        img_2d = img.squeeze()
        img_2d = np.clip(img_2d * 255, 0, 255).astype(np.uint8)

        # Apply adaptive thresholding
        binary = cv2.adaptiveThreshold(
            img_2d,
            255,
            cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
            cv2.THRESH_BINARY_INV,
            block_size,
            c
        )

        # Clean up the mask with morphological operations
        opening = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel, iterations=2)

        # Remove small objects (noise)
        nb_components, output, stats, _ = cv2.connectedComponentsWithStats(opening, connectivity=8)

        # Per-label lookup table: entry i is 1 when component i is large
        # enough. Label 0 is the background and always stays 0. Indexing the
        # table with the label image builds the mask in a single pass instead
        # of one full-image comparison per component.
        keep = np.zeros(nb_components, dtype=output.dtype)
        keep[1:] = stats[1:, cv2.CC_STAT_AREA] >= min_size
        clean_mask = keep[output]

        # Add channel dimension back
        masks.append(clean_mask.reshape(img.shape))

    return np.array(masks)

In [10]:
# Extract ROI using segmentation masks
def extract_roi(images, masks):
    """Mask each image with its segmentation mask.

    Images and masks are paired positionally (pairing stops at the shorter of
    the two) and multiplied element-wise, so pixels where the mask is 0 are
    zeroed out.

    Returns:
        Array stacking the masked images in input order.
    """
    return np.array([img * msk for img, msk in zip(images, masks)])

In [11]:
# Process individual cells for feature extraction
def extract_cell_features(image, mask):
    """Measure every connected region of `mask` on `image`.

    Args:
        image: Intensity image; squeezed to 2-D before use.
        mask: Segmentation mask of the same spatial size; squeezed and
            labelled into connected regions.

    Returns:
        List with one dict per region, holding the keys `cell_image`
        (bounding-box crop of the image), `mean_intensity`, `area`,
        `perimeter` and `eccentricity`.
    """
    intensity = image.squeeze()

    # Label individual cells
    labeled_mask = measure.label(mask.squeeze())
    # The intensity image has to be supplied here; without it the region
    # objects cannot provide `mean_intensity` and accessing it raises.
    props = measure.regionprops(labeled_mask, intensity_image=intensity)

    cell_features = []
    for prop in props:
        # Bounding box is (min_row, min_col, max_row, max_col), max exclusive
        min_row, min_col, max_row, max_col = prop.bbox

        cell_features.append({
            'cell_image': intensity[min_row:max_row, min_col:max_col],
            'mean_intensity': prop.mean_intensity,
            'area': prop.area,
            'perimeter': prop.perimeter,
            'eccentricity': prop.eccentricity
        })

    return cell_features

In [12]:
# Save some example segmentation results for visualization
def save_segmentation_examples(images, masks, rois, filename="cell_segmentation_examples.png"):
    """Save a grid of (image, mask, ROI) triplets to OUTPUT_DIR/`filename`.

    Up to five samples are drawn, one per row. When fewer than five are
    supplied the grid shrinks accordingly; when none are supplied nothing is
    written.

    Args:
        images: Original images.
        masks: Segmentation masks aligned with `images`.
        rois: Masked images aligned with `images`.
        filename: Name of the PNG file created inside OUTPUT_DIR.
    """
    # Never index past the shortest input.
    n_rows = min(5, len(images), len(masks), len(rois))
    if n_rows == 0:
        return

    # squeeze=False keeps `axes` two-dimensional even for a single row.
    fig, axes = plt.subplots(n_rows, 3, figsize=(15, 5 * n_rows), squeeze=False)

    columns = (
        (images, 'Original Cell Image'),
        (masks, 'Cell Segmentation Mask'),
        (rois, 'Extracted Cell ROI'),
    )
    for i in range(n_rows):
        for j, (stack, title) in enumerate(columns):
            axes[i, j].imshow(stack[i].squeeze(), cmap='gray')
            axes[i, j].set_title(title)
            axes[i, j].axis('off')

    fig.tight_layout()
    fig.savefig(os.path.join(OUTPUT_DIR, filename))
    # Close this specific figure so repeated calls do not accumulate figures.
    plt.close(fig)

In [13]:
# Create a visualization showing cell segmentation outlines
def save_cell_contour_visualization(images, masks, filename="cell_contours.png"):
    """Save a 3x3 grid of images with mask outlines drawn in green.

    Up to nine samples are shown in row-major order; panels without a sample
    are left blank.

    Args:
        images: Images to display; each is squeezed to 2-D.
            The green colour (0, 1, 0) assumes intensities scaled to [0, 1].
        masks: Binary masks aligned with `images`.
        filename: Name of the PNG file created inside OUTPUT_DIR.
    """
    fig, axes = plt.subplots(3, 3, figsize=(15, 15))
    panels = axes.ravel()

    # Hide every axis up front so panels that receive no sample stay empty
    # instead of showing ticks and a frame.
    for ax in panels:
        ax.axis('off')

    n_samples = min(len(panels), len(images), len(masks))
    for i in range(n_samples):
        # Get original image
        img = images[i].squeeze()
        mask = masks[i].squeeze()

        # Create RGB visualization
        img_rgb = np.stack([img, img, img], axis=-1)

        # Find contours (external outlines only)
        contours, _ = cv2.findContours(
            (mask * 255).astype(np.uint8),
            cv2.RETR_EXTERNAL,
            cv2.CHAIN_APPROX_SIMPLE
        )

        # Draw contours on RGB image (drawn in place)
        cv2.drawContours(img_rgb, contours, -1, (0, 1, 0), 1)  # Green contours

        # Display
        panels[i].imshow(img_rgb)
        panels[i].set_title(f'Cell Sample {i+1}')

    fig.tight_layout()
    fig.savefig(os.path.join(OUTPUT_DIR, filename))
    plt.close(fig)

In [14]:
def _postprocess_masks(masks, min_size=30):
    """Clean binary masks in place: drop small objects, close gaps, fill holes."""
    footprint = morphology.disk(3)
    for i in range(len(masks)):
        # Remove small objects
        binary = morphology.remove_small_objects(masks[i].squeeze().astype(bool), min_size=min_size)
        # Fill holes. Hole filling comes from scipy.ndimage; skimage.morphology
        # does not provide a binary_fill_holes function.
        binary = morphology.binary_closing(binary, footprint)
        binary = ndimage.binary_fill_holes(binary)
        masks[i] = binary.astype(np.float32).reshape(masks[i].shape)
    return masks


def _plot_training_history(history):
    """Save loss and accuracy curves of a Keras History to OUTPUT_DIR."""
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))

    loss_ax.plot(history.history['loss'], label='Train Loss')
    loss_ax.plot(history.history['val_loss'], label='Validation Loss')
    loss_ax.set_title('Loss')
    loss_ax.legend()

    acc_ax.plot(history.history['accuracy'], label='Train Accuracy')
    acc_ax.plot(history.history['val_accuracy'], label='Validation Accuracy')
    acc_ax.set_title('Accuracy')
    acc_ax.legend()

    fig.tight_layout()
    fig.savefig(os.path.join(OUTPUT_DIR, 'cell_training_history.png'))
    plt.close(fig)


def _overlap_scores(pred, true):
    """Return (IoU, Dice) for two boolean masks; 1e-7 guards empty masks."""
    intersection = np.logical_and(pred, true).sum()
    union = np.logical_or(pred, true).sum()
    iou = intersection / (union + 1e-7)
    dice = (2 * intersection) / (pred.sum() + true.sum() + 1e-7)
    return iou, dice


# Main function for segmentation and ROI extraction
def main():
    """Run the segmentation pipeline end to end.

    Steps: load the preprocessed splits, load or train the U-Net, predict and
    post-process masks for every split, save the masked (ROI) datasets and
    example figures to OUTPUT_DIR, and write IoU/Dice of the predicted
    validation masks against threshold-based reference masks.
    """
    # Load preprocessed data
    print("Loading preprocessed data...")
    X_train, y_train, X_val, y_val, X_test, y_test = load_preprocessed_data()

    # Create or load U-Net model
    if os.path.exists(SEGMENTATION_MODEL_PATH):
        print(f"Loading existing U-Net model from {SEGMENTATION_MODEL_PATH}")
        # The model was compiled with the custom iou_metric function, which
        # has to be supplied again for deserialization.
        model = tf.keras.models.load_model(
            SEGMENTATION_MODEL_PATH,
            custom_objects={'iou_metric': iou_metric}
        )
    else:
        print("Building and training U-Net segmentation model...")
        # Size the network from the data itself so the input layer always
        # matches the loaded arrays.
        model = build_unet_model(input_size=X_train.shape[1:])

        # Generate training masks
        print("Generating cell masks...")
        train_masks = generate_cell_masks(X_train)
        val_masks = generate_cell_masks(X_val)

        # Set up callbacks
        callbacks = [
            ModelCheckpoint(SEGMENTATION_MODEL_PATH, save_best_only=True, monitor='val_loss'),
            EarlyStopping(patience=10, monitor='val_loss'),
            ReduceLROnPlateau(factor=0.1, patience=5, min_lr=1e-6, monitor='val_loss')
        ]

        # Train the model
        history = model.fit(
            X_train, train_masks,
            batch_size=BATCH_SIZE,
            epochs=EPOCHS,
            validation_data=(X_val, val_masks),
            callbacks=callbacks
        )

        # Plot training history
        _plot_training_history(history)

    # Predict segmentation masks
    print("Predicting cell segmentation masks...")
    train_pred_masks = model.predict(X_train, verbose=1, batch_size=BATCH_SIZE)
    val_pred_masks = model.predict(X_val, verbose=1, batch_size=BATCH_SIZE)
    test_pred_masks = model.predict(X_test, verbose=1, batch_size=BATCH_SIZE)

    # Convert predictions to binary masks, then clean them up
    threshold = 0.5
    train_binary_masks = _postprocess_masks((train_pred_masks > threshold).astype(np.float32))
    val_binary_masks = _postprocess_masks((val_pred_masks > threshold).astype(np.float32))
    test_binary_masks = _postprocess_masks((test_pred_masks > threshold).astype(np.float32))

    # Extract ROIs
    print("Extracting cell ROIs...")
    train_roi = extract_roi(X_train, train_binary_masks)
    val_roi = extract_roi(X_val, val_binary_masks)
    test_roi = extract_roi(X_test, test_binary_masks)

    # Save ROI extracted datasets
    print("Saving cell ROI extracted datasets...")
    np.savez_compressed(os.path.join(OUTPUT_DIR, 'cell_roi_train_data.npz'), X=train_roi, y=y_train)
    np.savez_compressed(os.path.join(OUTPUT_DIR, 'cell_roi_val_data.npz'), X=val_roi, y=y_val)
    np.savez_compressed(os.path.join(OUTPUT_DIR, 'cell_roi_test_data.npz'), X=test_roi, y=y_test)

    # Save segmentation examples
    print("Saving cell segmentation examples...")
    save_segmentation_examples(X_train[:5], train_binary_masks[:5], train_roi[:5], "train_cell_segmentation.png")
    save_segmentation_examples(X_val[:5], val_binary_masks[:5], val_roi[:5], "val_cell_segmentation.png")
    save_segmentation_examples(X_test[:5], test_binary_masks[:5], test_roi[:5], "test_cell_segmentation.png")

    # Save contour visualization
    save_cell_contour_visualization(X_train[:9], train_binary_masks[:9], "cell_contours_train.png")

    # Calculate and save segmentation metrics
    print("Calculating segmentation metrics...")

    # For demonstration, we'll use a small sample of the validation masks.
    # The reference masks come from the validation IMAGES (the same
    # threshold-based procedure that produced the training targets), not from
    # the model's own predictions, and are generated in one call.
    sample_size = min(100, len(val_binary_masks))
    reference_masks = generate_cell_masks(X_val[:sample_size])

    dice_scores = []
    iou_scores = []
    for pred_mask, reference_mask in zip(val_binary_masks[:sample_size], reference_masks):
        pred = pred_mask.squeeze().astype(np.bool_)
        true = reference_mask.squeeze().astype(np.bool_)
        iou, dice = _overlap_scores(pred, true)
        iou_scores.append(iou)
        dice_scores.append(dice)

    # Save metrics
    with open(os.path.join(OUTPUT_DIR, 'segmentation_metrics.txt'), 'w') as f:
        f.write(f"Mean IoU: {np.mean(iou_scores):.4f}\n")
        f.write(f"Mean Dice: {np.mean(dice_scores):.4f}\n")

    print("Cell ROI extraction complete! Results saved to:", OUTPUT_DIR)
    print("You can now use the extracted ROIs for your transfer learning hybrid approach.")

In [None]:
# Run the whole pipeline only when this code executes as the top-level
# program (which includes running the cell in a notebook), not when imported.
if __name__ == "__main__":
    main()

Loading preprocessed data...
Loaded data shapes: Train (17755, 224, 224, 1), Val (3134, 224, 224, 1), Test (3687, 224, 224, 1)
Building and training U-Net segmentation model...


I0000 00:00:1745761239.734669     257 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 15513 MB memory:  -> device: 0, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:04.0, compute capability: 6.0


Generating cell masks...


Generating cell masks: 100%|██████████| 17755/17755 [00:25<00:00, 698.84it/s]
Generating cell masks: 100%|██████████| 3134/3134 [00:03<00:00, 815.09it/s]
