In [None]:
import os
import glob
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import cv2
from skimage.transform import resize

# Input/output locations. Each one can be overridden through the environment
# variable named next to it, so the notebook runs on another machine without
# editing the source; the defaults are the original locations. Relative paths
# resolve against the current working directory.
SEGMENTATION_MODEL_PATH = os.environ.get(
    "DISCAI_SEGMENTATION_MODEL_PATH",
    "/home/furix/Desktop/DiscAI/notebooks/segmentation_model_best.keras",
)
IMAGE_DIR = os.environ.get("DISCAI_IMAGE_DIR", "../data/JUH/train/images")
OUTPUT_MASK_DIR = os.environ.get(
    "DISCAI_OUTPUT_MASK_DIR", "../data/JUH_segmentation_output/train_masks"
)
OUTPUT_PREPROCESS_DIR = os.environ.get(
    "DISCAI_OUTPUT_PREPROCESS_DIR", "../data/JUH_segmentation_output/train_preprocessed"
)

# Output switches.
SAVE_MASKS = True         # write one 0/255 mask PNG per processed image
DISPLAY_IMAGES = False    # show original image and mask side by side (matplotlib)
SAVE_PREPROCESS = True    # write the cropped, normalized model input as a PNG

# Geometry: images are resized to PRE_CROP_TARGET_SHAPE (height, width), then
# CROP_Y_START rows are removed from the top, CROP_X_START columns from the
# left and abs(CROP_X_END_OFFSET) columns from the right.
PRE_CROP_TARGET_SHAPE = (512, 512)
CROP_Y_START = 96
CROP_X_START = 48
CROP_X_END_OFFSET = -48
# (height, width) after cropping -- (416, 416) with the values above.
FINAL_INPUT_SHAPE = (
    PRE_CROP_TARGET_SHAPE[0] - CROP_Y_START,
    PRE_CROP_TARGET_SHAPE[1] - CROP_X_START - abs(CROP_X_END_OFFSET),
)
# Model outputs strictly greater than this become 1 in the binary mask.
PREDICTION_THRESHOLD = 0.5

print("Configuring GPU...")

# Ask TensorFlow to grow GPU memory on demand instead of reserving the whole
# device up front. set_memory_growth raises RuntimeError if the GPUs have
# already been initialized, so that error is reported rather than fatal.
gpus = tf.config.experimental.list_physical_devices('GPU')
if not gpus:
    print("No GPU detected by TensorFlow.")
else:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        print(f"TensorFlow GPU memory growth enabled for {len(gpus)} GPU(s).")
    except RuntimeError as e:
        print(f"Error setting memory growth: {e}")


@tf.keras.utils.register_keras_serializable()
def dice_loss(y_true, y_pred, epsilon=1e-6):
    """Soft Dice loss: 1 - (2*sum(t*p) + eps) / (sum(t) + sum(p) + eps).

    The sums run over every element of the tensors (no axis argument), so the
    result is one scalar for the whole batch. ``epsilon`` keeps the ratio
    defined when both tensors are all zeros.
    """
    truth = tf.cast(y_true, tf.float32)
    pred = tf.cast(y_pred, tf.float32)
    overlap = tf.reduce_sum(truth * pred)
    total = tf.reduce_sum(truth) + tf.reduce_sum(pred)
    return 1 - (2. * overlap + epsilon) / (total + epsilon)
@tf.keras.utils.register_keras_serializable()
def combo_loss(y_true, y_pred):
    """Binary cross-entropy plus soft Dice loss.

    Keras' binary_crossentropy averages over the last axis only, so the BCE
    term keeps the leading dimensions and the scalar Dice term is broadcast
    onto it.
    """
    truth = tf.cast(y_true, tf.float32)
    pred = tf.cast(y_pred, tf.float32)
    return tf.keras.losses.binary_crossentropy(truth, pred) + dice_loss(truth, pred)
@tf.keras.utils.register_keras_serializable()
def dice_coefficient(y_true, y_pred, epsilon=1e-6):
    """Soft Dice coefficient over all elements; 1.0 when sum(t) + sum(p) == 0."""
    truth = tf.cast(y_true, tf.float32)
    pred = tf.cast(y_pred, tf.float32)
    overlap = tf.reduce_sum(truth * pred)
    total = tf.reduce_sum(truth) + tf.reduce_sum(pred)
    score = (2. * overlap + epsilon) / (total + epsilon)
    return tf.where(total == 0, 1.0, score)
@tf.keras.utils.register_keras_serializable()
def iou_metric(y_true, y_pred, epsilon=1e-6):
    """Soft IoU (Jaccard) over all elements; 1.0 when the computed union is 0.

    Union is sum(t) + sum(p) - sum(t*p).
    """
    truth = tf.cast(y_true, tf.float32)
    pred = tf.cast(y_pred, tf.float32)
    overlap = tf.reduce_sum(truth * pred)
    joint = tf.reduce_sum(truth) + tf.reduce_sum(pred) - overlap
    score = (overlap + epsilon) / (joint + epsilon)
    return tf.where(joint == 0, 1.0, score)
@tf.keras.utils.register_keras_serializable()
def channel_avg_pool(x):
    """Mean over the last axis, kept as a size-1 axis.

    Registered (and passed via custom_objects when loading) so the saved
    model can be deserialized.
    """
    return tf.math.reduce_mean(x, axis=-1, keepdims=True)
@tf.keras.utils.register_keras_serializable()
def channel_max_pool(x):
    """Maximum over the last axis, kept as a size-1 axis.

    Registered (and passed via custom_objects when loading) so the saved
    model can be deserialized.
    """
    return tf.math.reduce_max(x, axis=-1, keepdims=True)

def normalize_image(image_slice_np):
    """Min-max scale an array to [0, 1] as float32.

    An array whose maximum is not greater than its minimum (constant input,
    or NaN extrema) maps to an all-zero float32 array of the same shape.
    """
    data = image_slice_np.astype(np.float32)
    lo = np.min(data)
    hi = np.max(data)
    if not hi > lo:
        return np.zeros_like(data)
    return (data - lo) / (hi - lo)

def preprocess_and_crop_image_for_inference(image_path, pre_crop_shape=PRE_CROP_TARGET_SHAPE, final_shape=FINAL_INPUT_SHAPE):
    """Load one image and turn it into a model-ready batch of size 1.

    Steps: read as grayscale, resize to ``pre_crop_shape`` with bicubic
    interpolation (order=3), min-max normalize to [0, 1] over the whole
    resized image, then crop CROP_Y_START rows from the top, CROP_X_START
    columns from the left and abs(CROP_X_END_OFFSET) columns from the right.

    Args:
        image_path: Path of the image file to read.
        pre_crop_shape: (height, width) the image is resized to before cropping.
        final_shape: Expected (height, width) after cropping; a mismatch only
            prints a warning.

    Returns:
        float32 array of shape (1, H, W, 1), or None if the image could not be
        read or any preprocessing step raised (the reason is printed).
    """
    try:
        img_np = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        if img_np is None:
            print(f"Warning: Could not read image {image_path}. Skipping.")
            return None
        resized_slice = resize(img_np, pre_crop_shape, order=3, preserve_range=True, anti_aliasing=True)
        normalized_slice = normalize_image(resized_slice)
        # Turn the right-hand crop into an absolute end column. A plain
        # ``[:CROP_X_END_OFFSET]`` slice yields the wrong width for an offset
        # of 0 or a positive offset. This form matches how FINAL_INPUT_SHAPE is
        # derived (width - CROP_X_START - abs(CROP_X_END_OFFSET)).
        x_end = normalized_slice.shape[1] - abs(CROP_X_END_OFFSET)
        cropped_slice = normalized_slice[CROP_Y_START:, CROP_X_START:x_end]
        # tuple() so a list-valued final_shape does not always mismatch.
        if cropped_slice.shape != tuple(final_shape):
            print(f"Warning: Cropped shape {cropped_slice.shape} for {image_path} != expected {final_shape}.")
        # Add the batch and channel axes: (H, W) -> (1, H, W, 1).
        final_tensor = cropped_slice[np.newaxis, ..., np.newaxis]
        return final_tensor.astype(np.float32)
    except Exception as e:
        # Best effort per image: report and let the caller skip this file.
        print(f"Error preprocessing {image_path}: {e}")
        return None

def generate_segmentation_mask(input_tensor, model, threshold=PREDICTION_THRESHOLD):
    """Run the segmentation model on one preprocessed image and binarize it.

    Args:
        input_tensor: Batch of one image, e.g. the (1, H, W, 1) array returned
            by preprocess_and_crop_image_for_inference.
        model: Loaded Keras model exposing ``input_shape`` and ``predict``.
        threshold: Model outputs strictly greater than this become 1.

    Returns:
        (H, W) uint8 array of 0/1 values, or None if prediction raised or the
        output was not shaped (1, H, W, 1) (the reason is printed).

    Raises:
        ValueError: If ``model`` or ``input_tensor`` is None.
    """
    if model is None:
        raise ValueError("Segmentation model is not loaded.")
    if input_tensor is None:
        raise ValueError("Input tensor is None.")
    try:
        actual = input_tensor.shape[1:]
        expected = model.input_shape[1:]
        # Keras reports flexible dimensions as None; those accept any size, so
        # only fixed dimensions are compared (avoids spurious warnings for
        # fully-convolutional models).
        mismatch = len(actual) != len(expected) or any(
            exp is not None and act != exp for act, exp in zip(actual, expected)
        )
        if mismatch:
            print(f"Warning: Input tensor shape {actual} differs from model expected {expected}.")
        prediction = model.predict(input_tensor, verbose=0)
        if prediction.ndim == 4 and prediction.shape[0] == 1 and prediction.shape[-1] == 1:
            return (prediction[0, ..., 0] > threshold).astype(np.uint8)
        print(f"Error: Unexpected prediction output shape: {prediction.shape}")
        return None
    except Exception as e:
        # Best effort per image: report and let the caller skip this file.
        print(f"Error during segmentation prediction: {e}")
        return None

if __name__ == "__main__":
    # sys.exit is used instead of the bare exit() helper, which the site
    # module injects for interactive use and is not guaranteed to exist.
    import sys

    # Accepted input image extensions, compared case-insensitively.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    # --- Load the model together with the custom objects it references. ---
    print(f"Loading segmentation model from: {SEGMENTATION_MODEL_PATH}")
    if not os.path.exists(SEGMENTATION_MODEL_PATH):
        print(f"FATAL: Model file not found: {SEGMENTATION_MODEL_PATH}. Exiting.")
        sys.exit(1)
    try:
        custom_objects = {
            'dice_loss': dice_loss,
            'combo_loss': combo_loss,
            'dice_coefficient': dice_coefficient,
            'iou_metric': iou_metric,
            'channel_avg_pool': channel_avg_pool,
            'channel_max_pool': channel_max_pool,
        }
        segmentation_model = tf.keras.models.load_model(SEGMENTATION_MODEL_PATH, custom_objects=custom_objects)
        print("Segmentation model loaded successfully.")
        print(f"Model expects input shape (HxWxC): {segmentation_model.input_shape[1:]}")
        if segmentation_model.input_shape[1:3] != FINAL_INPUT_SHAPE:
            print(f"WARNING: Model expects {segmentation_model.input_shape[1:3]} but preprocessing generates {FINAL_INPUT_SHAPE}.")
    except Exception as e:
        print(f"FATAL: Error loading segmentation model: {e}")
        sys.exit(1)

    # --- Prepare output directories once, before the loop. ---
    if SAVE_MASKS:
        os.makedirs(OUTPUT_MASK_DIR, exist_ok=True)
        print(f"Output masks will be saved to: {OUTPUT_MASK_DIR}")
    if SAVE_PREPROCESS:
        os.makedirs(OUTPUT_PREPROCESS_DIR, exist_ok=True)
        print(f"Preprocessed images will be saved to: {OUTPUT_PREPROCESS_DIR}")

    # --- Collect input images. ---
    # glob.escape protects directory names containing [, ], * or ?; the
    # lower-cased extension check also picks up files like "scan.JPG"; sorting
    # makes the processing order reproducible.
    print(f"Searching for images in: {IMAGE_DIR}")
    image_files = sorted(
        path
        for path in glob.glob(os.path.join(glob.escape(IMAGE_DIR), '*'))
        if os.path.isfile(path) and os.path.splitext(path)[1].lower() in IMAGE_EXTENSIONS
    )
    if not image_files:
        print(f"FATAL: No JPG or PNG images found in {IMAGE_DIR}. Exiting.")
        sys.exit(1)
    print(f"Found {len(image_files)} images to process.")

    processed_count = 0
    skipped_count = 0   # deliberately excluded by file name
    failed_count = 0    # preprocessing or segmentation returned None
    for img_path in image_files:
        base_filename = os.path.basename(img_path)
        stem = os.path.splitext(base_filename)[0]

        # Files whose names start with 'x' or 'X' are excluded on purpose.
        if base_filename.lower().startswith('x'):
            print(f"Skipping file: {base_filename} (starts with 'x')")
            skipped_count += 1
            continue

        print(f"\nProcessing: {base_filename}")

        input_tensor = preprocess_and_crop_image_for_inference(img_path)
        if input_tensor is None:
            # The preprocessing helper has already printed the reason.
            failed_count += 1
            continue
        print(f"  Preprocessed input tensor shape: {input_tensor.shape}")

        mask = generate_segmentation_mask(input_tensor, segmentation_model)
        if mask is None:
            print(f"  Segmentation failed for {base_filename}")
            failed_count += 1
            continue
        print(f"  Generated mask shape: {mask.shape}")
        processed_count += 1

        if SAVE_MASKS:
            output_path = os.path.join(OUTPUT_MASK_DIR, f"mask_{stem}.png")
            try:
                # cv2.imwrite reports most failures by returning False, not raising.
                if cv2.imwrite(output_path, mask * 255):
                    print(f"  Saved mask to: {output_path}")
                else:
                    print(f"  Error saving mask for {base_filename}: cv2.imwrite failed for {output_path}")
            except Exception as e:
                print(f"  Error saving mask for {base_filename}: {e}")

        if SAVE_PREPROCESS:
            output_path = os.path.join(OUTPUT_PREPROCESS_DIR, f"preprocess_{stem}.png")
            try:
                # Round (not truncate) the [0, 1] floats to the nearest 8-bit level.
                image_to_save_uint8 = np.clip(np.rint(np.squeeze(input_tensor) * 255.0), 0, 255).astype(np.uint8)
                if cv2.imwrite(output_path, image_to_save_uint8):
                    print(f"  Saved preprocessed image to: {output_path}")
                else:
                    print(f"  Error saving preprocessed image for {base_filename}: cv2.imwrite failed for {output_path}")
            except Exception as e:
                print(f"  Error saving preprocessed image for {base_filename}: {e}")

        if DISPLAY_IMAGES:
            # NOTE: the mask covers the resized and cropped region, not the full
            # original image shown on the left.
            try:
                original_img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
                plt.figure(figsize=(10, 5))
                plt.subplot(1, 2, 1)
                plt.imshow(original_img, cmap='gray')
                plt.title(f"Original: {base_filename}")
                plt.axis("off")
                plt.subplot(1, 2, 2)
                plt.imshow(mask, cmap='gray')
                plt.title("Generated Mask")
                plt.axis("off")
                plt.show()
            except Exception as e:
                print(f"  Error displaying image/mask for {base_filename}: {e}")

    print(f"\nFinished processing. Processed: {processed_count}, Skipped: {skipped_count}, Failed: {failed_count}")

Configuring GPU...
TensorFlow GPU memory growth enabled for 1 GPU(s).
Loading segmentation model from: /home/furix/Desktop/DiscAI/notebooks/segmentation_model_best.keras
Segmentation model loaded successfully.
Model expects input shape (HxWxC): (416, 416, 1)
Output masks will be saved to: ../data/JUH_segmentation_output/train_masks
Searching for images in: ../data/JUH/train/images
Found 2918 images to process.

Processing: S273588_jpg.rf.ca89534b63a7755dddf3bbad66d3ba10.jpg
  Preprocessed input tensor shape: (1, 416, 416, 1)
  Generated mask shape: (416, 416)
  Saved mask to: ../data/JUH_segmentation_output/train_masks/mask_S273588_jpg.rf.ca89534b63a7755dddf3bbad66d3ba10.png
  Saved preprocessed image to: ../data/JUH_segmentation_output/train_preprocessed/preprocess_S273588_jpg.rf.ca89534b63a7755dddf3bbad66d3ba10.png
Skipping file: X1355114_L2_jpg.rf.38bfbd9b3279773d5c48d10b82600cc7.jpg (starts with 'x')

Processing: S863745_jpg.rf.68672ffa4072173a96f85a6856f29f3b.jpg
  Preprocessed in