# **ENSEMBLE**

**Important Note on Patchify Installation:** Installing patchify may trigger a runtime restart prompt in Colab due to library dependencies or version conflicts. If prompted, click "Restart runtime".


In [None]:
!pip install patchify

Authenticate with your Google account when prompted. This mounts your Drive at /content/gdrive/, giving the notebook access to the `MyDrive/ENSEMBLE_model` folder (helper code, trained models, and input images).

In [None]:
from google.colab import drive
# Mount Google Drive so the notebook can read /content/gdrive/MyDrive/ENSEMBLE_model
# (helper module, models, and test images used by later cells).
drive.mount('/content/gdrive')

Mounted at /content/gdrive


Import Libraries

In [None]:
import time
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.models import load_model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.optimizers.schedules import ExponentialDecay
from tensorflow.keras.saving import register_keras_serializable
from patchify import patchify, unpatchify
import os
from tqdm import tqdm
import sys
sys.path.append("/content/gdrive/MyDrive/ENSEMBLE_model")
from customfunc import * # recall_m, focal_dice_boundary_loss, precision_m, f1_m, jaccard_score, custom_objects


In [None]:
# --- Timing Decorator ---
def timeit(func):
    """Decorator that prints how long each call to ``func`` takes, in seconds.

    ``functools.wraps`` copies ``func``'s ``__name__``/``__doc__`` onto the
    wrapper, so introspection (help(), tracebacks, further decorators) still
    sees the real function instead of ``wrapper``.

    Returns:
        The wrapped callable; it returns whatever ``func`` returns.
    """
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic and meant for measuring intervals.
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start_time
        print(f"{func.__name__} executed in {elapsed:.2f} seconds")
        return result
    return wrapper

# --- Model DenseNet ---
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, Conv2DTranspose, Concatenate, BatchNormalization, Activation, Dropout, Lambda
from tensorflow.keras.models import Model
from tensorflow.keras.regularizers import l2
from tensorflow.keras.applications import DenseNet121


# Build one U-Net instance. densenet_unet is not defined in this notebook;
# presumably it comes from `from customfunc import *`. Building it triggers the
# DenseNet121 ImageNet weight download shown in this cell's output.
# NOTE(review): `model` is not referenced by the later cells shown here
# (load_keras_models builds its own instances). Building it also appears to
# initialise the GPU, which is likely why set_memory_growth fails in the next
# cell ("Physical devices cannot be modified after being initialized").
model = densenet_unet(
    input_shape=(256, 256, 3),
    num_classes=1,
    l2_reg=1e-4,
    dropout_rate=0.5
)

# Predict and postprocess
def predict_mask(model, image):
    """Binarise the first sample / first channel of ``model.predict(image)``.

    Args:
        model: object with a ``predict`` method returning an array indexable as
            ``[sample, row, col, channel]``.
        image: the input passed straight to ``model.predict`` (assumed to be a
            batch whose first sample is the image of interest — TODO confirm).

    Returns:
        uint8 array of 0/1: 1 where the prediction is strictly above 0.5.
    """
    probability_map = model.predict(image)
    first_channel = probability_map[0, :, :, 0]
    is_foreground = first_channel > 0.5
    return is_foreground.astype(np.uint8)


Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/densenet/densenet121_weights_tf_dim_ordering_tf_kernels_notop.h5
[1m29084464/29084464[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 0us/step


In [None]:
from tqdm.notebook import tqdm
import numpy as np
import os
import time
from PIL import Image
import tensorflow as tf
import subprocess

# --- Configure TensorFlow to limit memory growth ---
# Ask TensorFlow to allocate GPU memory on demand instead of reserving it up front.
# set_memory_growth is only accepted before the GPU has been initialised. If an
# earlier cell already ran TF on the GPU (presumably the densenet_unet build above),
# it raises RuntimeError, which is caught and printed (see output below).
# NOTE(review): for this to take effect it must run before any TensorFlow GPU work.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        print("Enabled memory growth for GPUs")
    except RuntimeError as e:
        print(f"Error setting memory growth: {e}")

Error setting memory growth: Physical devices cannot be modified after being initialized


In [None]:
# --- Model Preloading ---
def load_keras_models(keras_model_paths):
    """Load each model file and sort it into one of two ensemble groups.

    Files whose basename starts with ``z_`` are treated as weight files for a
    freshly built ``densenet_unet`` (256x256x3 input, 1 class). Every other file
    is loaded as a full Keras model with ``compile=False``. A model that fails to
    load is reported and skipped, so the returned lists may be shorter than the
    input.

    Returns:
        (standard_models, z_models): two lists of loaded models.
    """
    standard_models, z_models = [], []

    for model_path in keras_model_paths:
        try:
            if os.path.basename(model_path).startswith("z_"):
                print(f"Loading z_ model weights from: {model_path}")
                rebuilt = densenet_unet(input_shape=(256, 256, 3), num_classes=1, l2_reg=1e-4, dropout_rate=0.5)
                rebuilt.load_weights(model_path)
                z_models.append(rebuilt)
                continue
            print(f"Loading standard model weights from: {model_path}")
            standard_models.append(load_model(model_path, compile=False))
        except Exception as exc:
            print(f"Failed to load model {model_path}: {exc}")

    return standard_models, z_models

In [None]:
# --- Dynamic Batch Size Estimation ---
def get_available_gpu_memory():
    """Estimate free GPU memory (MiB) on the first GPU for batch-size selection.

    ``tf.config.experimental.get_memory_info()['current']`` is the memory that
    TensorFlow is *currently using*, not the memory that is free. It therefore
    cannot serve as "available" memory on its own. This version reads
    total/free memory from ``nvidia-smi``. If TensorFlow's own usage can be read,
    it also considers ``total - tf_in_use``, because memory TensorFlow has
    reserved but is not using shows up as "used" in nvidia-smi.

    Returns:
        Estimated available memory in MiB (float), or None when no GPU is
        visible or no estimate can be made (callers then use a default batch size).
    """
    try:
        gpus = tf.config.list_physical_devices('GPU')
    except Exception as e:
        print(f"Error checking GPU memory: {e}. Using default batch size.")
        return None
    if not gpus:
        print("No GPU detected, using default batch size.")
        return None

    try:
        smi_output = subprocess.check_output(
            ["nvidia-smi", "--query-gpu=memory.total,memory.free",
             "--format=csv,noheader,nounits", "--id=0"],
            encoding='utf-8',
        )
        total_mb, free_mb = (float(v) for v in smi_output.strip().splitlines()[0].split(','))
    except (OSError, subprocess.SubprocessError, ValueError, IndexError) as e:
        print(f"Error checking GPU memory: {e}. Using default batch size.")
        return None

    try:
        tf_in_use_mb = tf.config.experimental.get_memory_info('GPU:0')['current'] / (1024 ** 2)
    except (ValueError, RuntimeError, KeyError):
        return free_mb
    # Assumes TF is the main GPU user in this process (typical in Colab). If other
    # processes share the GPU this over-estimates; the caller's OOM retry covers that.
    return max(free_mb, total_mb - tf_in_use_mb)

def estimate_batch_size(available_mem, patch_size=(256, 256, 3), model_memory_factor=2.0,
                        *, buffer_factor=0.7, min_batch_size=4, max_batch_size_limit=32,
                        default_batch_size=8):
    """Pick an inference batch size from an estimate of available GPU memory.

    Per-patch cost is the float32 size of one patch multiplied by
    ``model_memory_factor``. Only ``buffer_factor`` of the available memory is
    budgeted. The result is clamped to ``[min_batch_size, max_batch_size_limit]``.

    Args:
        available_mem: available memory in MiB, or None if unknown.
        patch_size: shape of one input patch (any rank; elements are multiplied).
        model_memory_factor: multiplier applied to the raw patch size.
        buffer_factor: fraction of ``available_mem`` to budget (keyword-only).
        min_batch_size / max_batch_size_limit: clamp bounds (keyword-only).
        default_batch_size: returned when ``available_mem`` is None (keyword-only).

    Returns:
        int batch size.
    """
    if available_mem is None:
        return default_batch_size  # Conservative default when memory is unknown
    patch_memory = (int(np.prod(patch_size)) * 4) / (1024 ** 2)  # float32 bytes -> MiB
    total_patch_memory = patch_memory * model_memory_factor
    fitting_batches = int((available_mem * buffer_factor) // total_patch_memory)
    batch_size = max(min_batch_size, min(fitting_batches, max_batch_size_limit))
    print(f"Estimated batch size: {batch_size} (Available GPU memory: {available_mem:.2f} MB)")
    return batch_size

In [None]:
# --- Segmentation Function ---
def timeit(func):
    """Decorator that prints the wall-clock duration of each call to ``func``.

    Note: this redefinition replaces the ``timeit`` defined in an earlier cell;
    it is the one applied to ``segment_with_keras_models`` below.
    ``functools.wraps`` keeps ``func``'s ``__name__``/``__doc__`` on the wrapper.

    Returns:
        The wrapped callable; it returns whatever ``func`` returns.
    """
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        result = func(*args, **kwargs)
        print(f"{func.__name__} took {time.perf_counter() - start:.2f} seconds")
        return result
    return wrapper

# Side length of the square patches the ensemble models were built for.
_SEG_PATCH_SIZE = 256


def _predict_mask_batch(model, patch_batch):
    """Run ``model`` on one batch and threshold output channel 0 at 0.5 (uint8 0/1)."""
    preds = model(patch_batch, training=False)
    return (preds.numpy()[..., 0] > 0.5).astype(np.uint8)


def _predict_all_patches(model, all_patches, batch_size):
    """Predict every patch with ``model`` in mini-batches, halving the batch size on OOM.

    Each retry restarts from the first patch with an empty result list, so a
    failed attempt never leaves duplicated predictions behind. If a batch of 1
    still fails, RuntimeError is raised instead of retrying forever.

    Returns:
        The concatenated per-batch masks, in patch order.
    """
    current_batch_size = batch_size
    while True:
        preds = []  # fresh per attempt: partial results of a failed attempt must not leak
        try:
            for batch_start in range(0, len(all_patches), current_batch_size):
                batch_input = all_patches[batch_start:batch_start + current_batch_size]
                preds.append(_predict_mask_batch(model, batch_input))
            return np.concatenate(preds, axis=0)
        except (tf.errors.ResourceExhaustedError, tf.errors.InternalError) as e:
            if current_batch_size <= 1:
                raise RuntimeError("Out of memory even with batch size 1, cannot continue.") from e
            print(f"Memory error with batch size {current_batch_size}: {e}. Reducing batch size.")
            current_batch_size = max(current_batch_size // 2, 1)
            tf.keras.backend.clear_session()


def _run_model_group(models, all_patches, p_rows, p_cols, batch_size):
    """Majority-vote the masks of all ``models``; None if the group is empty.

    Returns:
        uint8 array shaped (p_rows, p_cols, patch, patch).
    """
    if not models:
        return None
    group_preds = []
    for model in models:
        preds = _predict_all_patches(model, all_patches, batch_size)
        group_preds.append(preds.reshape(p_rows, p_cols, _SEG_PATCH_SIZE, _SEG_PATCH_SIZE))
        # Reset Keras global state between models (kept from the original flow).
        tf.keras.backend.clear_session()
    return majority_vote(np.stack(group_preds, axis=0))


@timeit
def segment_with_keras_models(img_arr, original_shape, standard_models, z_models):
    """Segment an RGB image with two model ensembles and return a 0/1 mask.

    Steps:
    1. Reflect-pad the image to a multiple of 256.
    2. Tile it into non-overlapping 256x256 patches.
    3. Run every patch through each model.
    4. Majority-vote within each group (``standard_models`` / ``z_models``).
    5. Combine the two group votes pixel-wise with ``np.maximum``, which is a
       logical OR for 0/1 masks.

    Args:
        img_arr: image array; patchify is called with a (256, 256, 3) window, so an
            (H, W, 3) array is expected.
        original_shape: (H, W) of the unpadded image; the result is cropped to it.
        standard_models, z_models: lists (possibly empty) of callable Keras models.

    Returns:
        uint8 mask of shape ``original_shape``, or None if any step fails
        (the error is printed).
    """
    try:
        total_start = time.time()
        img_arr, _ = pad_image_to_multiple(img_arr, patch_size=_SEG_PATCH_SIZE)
        H, W, _ = img_arr.shape

        patches = patchify(img_arr, (_SEG_PATCH_SIZE, _SEG_PATCH_SIZE, 3), step=_SEG_PATCH_SIZE)
        p_rows, p_cols = patches.shape[:2]
        # Normalise once for both groups. The row-major reshape yields patches in
        # (row, col) order, which _run_model_group relies on when reshaping back.
        all_patches = (
            patches.reshape(-1, _SEG_PATCH_SIZE, _SEG_PATCH_SIZE, 3).astype(np.float32) / 255.0
        )

        # Dynamically estimate batch size
        available_mem = get_available_gpu_memory()
        batch_size = estimate_batch_size(available_mem)

        std_vote = _run_model_group(standard_models, all_patches, p_rows, p_cols, batch_size)
        z_vote = _run_model_group(z_models, all_patches, p_rows, p_cols, batch_size)

        if std_vote is not None and z_vote is not None:
            patch_class_maps = np.maximum(std_vote, z_vote)
        elif std_vote is not None:
            patch_class_maps = std_vote
        elif z_vote is not None:
            patch_class_maps = z_vote
        else:
            raise RuntimeError("No valid predictions from any model.")

        padded_result = unpatchify(patch_class_maps, (H, W))
        original_H, original_W = original_shape
        result = padded_result[:original_H, :original_W]

        print(f"Total segmentation time: {time.time() - total_start:.2f} seconds")
        return result

    except Exception as e:
        # Best-effort per image: the caller treats None as "failed on this image".
        print(f"\nSegmentation error: {str(e)}")
        return None

In [None]:
# --- Majority Voting for Ensemble ---
# Module-level model cache read/updated by the main inference cell.
# Re-running this cell resets it, which forces the models to be reloaded.
standard_models = None
z_models = None
models_loaded = False

def majority_vote(pred_stack):
    """Pixel-wise majority vote over the first axis of a stack of 0/1 masks.

    The vote is the mean across voters rounded to the nearest integer. NumPy
    rounds half to even, so an exact tie (0.5) resolves to 0 (background).

    Returns:
        uint8 array with the stack's shape minus axis 0.
    """
    as_uint8 = pred_stack.astype(np.uint8)
    fraction_positive = as_uint8.mean(axis=0)
    return np.rint(fraction_positive).astype(np.uint8)

# --- Display and Save Segmentation ---
def display_and_save_segmentation(seg_result, save_path, original_path=None):
    """Write a 0/1 mask to ``save_path`` as an 8-bit image (1 -> 255).

    Despite the name, nothing is displayed. ``original_path`` is accepted for
    call-site compatibility but is not used. The file format follows
    ``save_path``'s extension; with a lossy format such as JPEG the stored mask
    may no longer be strictly 0/255.
    """
    scaled = (seg_result * 255).astype(np.uint8)
    Image.fromarray(scaled).save(save_path)

def pad_image_to_multiple(img_arr, patch_size=256):
    """Pad an (H, W, C) image on the bottom/right up to multiples of ``patch_size``.

    Padding uses NumPy's ``reflect`` mode (mirror without repeating the edge).

    Returns:
        (padded_image, (H, W)) where (H, W) is the original spatial size.
    """
    height, width, _ = img_arr.shape
    # (-n) % p is the smallest non-negative amount that makes n a multiple of p.
    pad_bottom = (-height) % patch_size
    pad_right = (-width) % patch_size
    pad_spec = ((0, pad_bottom), (0, pad_right), (0, 0))
    return np.pad(img_arr, pad_spec, mode='reflect'), (height, width)

def overlay_mask_on_image(original_img_path, mask, alpha=0.4, mask_color=(255, 0, 0)):
    """Tint the masked pixels of an image with ``mask_color``.

    Only pixels where ``mask == 1`` are blended:
    ``pixel * (1 - alpha) + mask_color * alpha`` (truncated to uint8). All other
    pixels keep their original colour. Blending them against an all-zero layer,
    as a full-image blend would, darkens them by ``(1 - alpha)``.

    Args:
        original_img_path: path of the image the mask was computed for.
        mask: 2-D array with the image's height/width; 1 marks foreground.
        alpha: weight of ``mask_color`` in masked pixels (0..1).
        mask_color: RGB tint.

    Returns:
        PIL RGB image with the overlay applied.
    """
    # Context manager closes the file handle once the pixels are copied out.
    with Image.open(original_img_path) as src:
        orig_arr = np.array(src.convert("RGB"))
    selected = (mask == 1)
    blended = orig_arr.copy()
    tint = np.asarray(mask_color, dtype=np.float64)
    blended[selected] = (orig_arr[selected] * (1 - alpha) + tint * alpha).astype(np.uint8)
    return Image.fromarray(blended)


In [None]:
import os
from google.colab import drive
import ipywidgets as widgets
from IPython.display import display
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module="keras.src.models.functional")

# Mount Google Drive
try:
    # force_remount=False: do not force a remount if Drive is already mounted.
    drive.mount('/content/gdrive', force_remount=False)
except Exception as e:
    print(f"Error mounting Google Drive: {e}")
    print("Please ensure Google Drive is accessible.")

# Directory globals read by the processing cell. Values left over from a previous
# run of this cell are kept, so re-running the GUI does not discard earlier choices.
# (These are module-level names already; a `global` statement here would be a no-op.)
input_dir = globals().get('input_dir', "")
output_dir = globals().get('output_dir', "")
overlay_dir = globals().get('overlay_dir', "")

# Base paths
base_path_input = "/content/gdrive/MyDrive/ENSEMBLE_model/TEST_DATA"   # read-only shared data
base_path_output = "/content/gdrive/MyDrive/ENSEMBLE_model/ENSEMBLE_output"
# NOTE: "ouptut" is a misspelling kept because this name is referenced elsewhere.
base_path_ouptut2 = "/content/gdrive/MyDrive/"             # writable safe area

# Function to list subdirectories in the base path
def get_subdirs(path):
    """List the immediate subdirectories of ``path`` as full paths, sorted.

    Returns ``["No directories available"]`` if ``path`` does not exist or cannot
    be scanned (the error is printed). An existing folder with no subdirectories
    yields an empty list.
    """
    try:
        if not os.path.exists(path):
            return ["No directories available"]
        with os.scandir(path) as entries:
            found = [entry.path for entry in entries if entry.is_dir()]
        found.sort()
        return found
    except Exception as e:
        print(f"Error accessing {path}: {e}")
        return ["No directories available"]

# Get available directories
# Full, sorted paths of the folders offered in the dropdowns. Each list is
# ["No directories available"] when its base path is missing or unreadable.
subdir_options_input = get_subdirs(base_path_input)
subdir_options_output = get_subdirs(base_path_output)
# NOTE(review): subdir_options_output2 is not used by any widget in this cell.
subdir_options_output2 = get_subdirs(base_path_ouptut2)

# Function to find the matching dropdown option for a given path
def get_dropdown_value(path, options):
    """Return ``path`` to preselect it in a dropdown, or None for no selection.

    ``path`` is returned only when it is non-empty and is one of ``options``.
    Otherwise the dropdown starts blank instead of defaulting to the first option.
    """
    is_known = bool(path) and path in options
    return path if is_known else None

# Create dropdown for input directory (start with None)
# Input folder picker. None is prepended as a "blank" entry and selected
# initially, so the user has to pick an input folder actively.
# NOTE(review): unlike the output/overlay dropdowns below, this does not preselect
# a previously chosen input_dir, even though that global may still hold one.
input_dropdown = widgets.Dropdown(
    options=[None] + subdir_options_input,  # prepend None as "blank"
    value=None,                             # start with no selection
    description="Input Dir:",
    layout={'width': '500px'}
)

# Checkbox for auto-create: when checked, output/overlay folders are derived from
# the input folder name under base_path_output instead of chosen manually.
auto_create_checkbox = widgets.Checkbox(
    value=False,
    description="Auto-create output/overlay in RESULTS",
    indent=False,
    layout={'width': '500px'}
)

# Create dropdowns for output and overlay dirs. Both list the subfolders of
# base_path_output and preselect the current global if it is one of them.
output_dropdown = widgets.Dropdown(
    options=subdir_options_output,
    value=get_dropdown_value(output_dir, subdir_options_output),
    description="Output Dir:",
    layout={'width': '500px'}
)

overlay_dropdown = widgets.Dropdown(
    options=subdir_options_output,
    value=get_dropdown_value(overlay_dir, subdir_options_output),
    description="Overlay Dir:",
    layout={'width': '500px'}
)

# Text widgets to show selected paths
# Read-only text fields showing the directory paths that will actually be used.
# All three start from the current globals. After a re-run of this cell they
# therefore show any path carried over from an earlier selection, including
# input_dir, which the processing cell will use.
input_dir_text = widgets.Text(
    value=input_dir,
    placeholder='Selected input directory',
    description='Input Path:',
    disabled=True,
    layout={'width': '500px'}
)

output_dir_text = widgets.Text(
    value=output_dir,
    placeholder='Selected output directory',
    description='Output Path:',
    disabled=True,
    layout={'width': '500px'}
)

overlay_dir_text = widgets.Text(
    value=overlay_dir,
    placeholder='Selected overlay directory',
    description='Overlay Path:',
    disabled=True,
    layout={'width': '500px'}
)

# Container for manual dropdowns
# Groups the manual output/overlay dropdowns so they can be hidden as one unit
# while auto-create mode is on (update_dropdown_visibility toggles this later).
output_overlay_container = widgets.VBox([output_dropdown, overlay_dropdown])
output_overlay_container.layout.display = 'none' if auto_create_checkbox.value else 'block'

# Function to update auto-create paths
def update_dropdown_visibility(change):
    """Observer for ``auto_create_checkbox``: switch between auto and manual paths.

    Checked: hide the manual dropdowns. If an input folder is selected, point
    output_dir/overlay_dir at
    ``<base_path_output>/<input name>_RESULTS/{Outputs,Overlays}``.
    Unchecked: show the dropdowns again and adopt whichever folders they hold.
    """
    unavailable = "No directories available"

    if not change['new']:
        output_overlay_container.layout.display = 'block'
        for dropdown, global_name, text_widget in (
            (output_dropdown, 'output_dir', output_dir_text),
            (overlay_dropdown, 'overlay_dir', overlay_dir_text),
        ):
            chosen = dropdown.value
            if chosen and chosen != unavailable:
                globals()[global_name] = chosen
                text_widget.value = chosen
        return

    output_overlay_container.layout.display = 'none'
    selected_input = input_dropdown.value
    if not selected_input or selected_input == unavailable:
        return
    result_base = os.path.join(base_path_output, f"{os.path.basename(selected_input)}_RESULTS")
    globals()['output_dir'] = os.path.join(result_base, "Outputs")
    globals()['overlay_dir'] = os.path.join(result_base, "Overlays")
    output_dir_text.value = globals()['output_dir']
    overlay_dir_text.value = globals()['overlay_dir']

# Update input dir
def update_input_dir(change):
    """Observer for ``input_dropdown``: record the chosen input folder.

    Empty/None selections and the "No directories available" placeholder are
    ignored. In auto-create mode the output/overlay paths are re-derived from the
    new input folder's name.
    """
    chosen = change['new']
    if not chosen or chosen == "No directories available":
        return

    globals()['input_dir'] = chosen
    input_dir_text.value = chosen

    if not auto_create_checkbox.value:
        return
    result_base = os.path.join(base_path_output, f"{os.path.basename(chosen)}_RESULTS")
    globals()['output_dir'] = os.path.join(result_base, "Outputs")
    globals()['overlay_dir'] = os.path.join(result_base, "Overlays")
    output_dir_text.value = globals()['output_dir']
    overlay_dir_text.value = globals()['overlay_dir']

# Update output dir manually
def update_output_dir(change):
    """Observer for ``output_dropdown``: adopt a manually chosen output folder.

    Ignored while auto-create mode is on. Like update_input_dir, it also ignores
    empty/None values and the "No directories available" placeholder, so the
    output_dir global and its Text widget never receive None.
    """
    new_value = change['new']
    if auto_create_checkbox.value or not new_value or new_value == "No directories available":
        return
    globals()['output_dir'] = new_value
    output_dir_text.value = new_value

# Update overlay dir manually
def update_overlay_dir(change):
    """Observer for ``overlay_dropdown``: adopt a manually chosen overlay folder.

    Ignored while auto-create mode is on. Like update_input_dir, it also ignores
    empty/None values and the "No directories available" placeholder, so the
    overlay_dir global and its Text widget never receive None.
    """
    new_value = change['new']
    if auto_create_checkbox.value or not new_value or new_value == "No directories available":
        return
    globals()['overlay_dir'] = new_value
    overlay_dir_text.value = new_value

# Bind
# Bind the observers. They are registered after the widgets were constructed with
# their initial values, so only later user changes trigger them.
input_dropdown.observe(update_input_dir, names='value')
output_dropdown.observe(update_output_dir, names='value')
overlay_dropdown.observe(update_overlay_dir, names='value')
auto_create_checkbox.observe(update_dropdown_visibility, names='value')

# Confirm button
confirm_button = widgets.Button(
    description="Confirm Selection",
    button_style='success',
    tooltip='Click to confirm directory selections',
)

# Output widget that captures the confirm callback's messages (`with output:`).
output = widgets.Output()

def on_confirm_clicked(b):
    """Button callback: validate the selected directories and report them.

    All three paths must be non-empty. In auto-create mode output_dir/overlay_dir
    are only derived once an input folder has been chosen, so they are checked
    too. Without that check ``os.makedirs('')`` would raise inside the callback.
    In auto-create mode the derived folders are created; a failure is reported
    rather than raised.
    """
    with output:
        output.clear_output()
        if not (input_dir and output_dir and overlay_dir):
            print("❌ Please select all required directories.")
            return
        if auto_create_checkbox.value:
            try:
                os.makedirs(output_dir, exist_ok=True)
                os.makedirs(overlay_dir, exist_ok=True)
            except OSError as e:
                print(f"❌ Could not create output directories: {e}")
                return
        print("✅ Selected paths:")
        print(f"Input Dir: {input_dir}")
        print(f"Output Dir: {output_dir}")
        print(f"Overlay Dir: {overlay_dir}")
        print("You can now use input_dir, output_dir, and overlay_dir in your code.")

confirm_button.on_click(on_confirm_clicked)

# Initialize if auto-create is checked.
# NOTE(review): auto_create_checkbox is constructed with value=False above, so this
# branch does not run unless that default changes. It assigns the module-level
# output_dir/overlay_dir directly.
if auto_create_checkbox.value and input_dir and input_dir != "No directories available":
    input_dir_name = os.path.basename(input_dir)
    result_base = os.path.join(base_path_output, f"{input_dir_name}_RESULTS")
    output_dir = os.path.join(result_base, "Outputs")
    overlay_dir = os.path.join(result_base, "Overlays")
    output_dir_text.value = output_dir
    overlay_dir_text.value = overlay_dir

# Display GUI: input picker and path, auto-create toggle, manual dropdowns,
# resolved output/overlay paths, confirm button, and the message area.
print(f"Select input directories from {base_path_input} (read-only).")
print(f"Outputs/overlays will be created under {base_path_output}/<InputFolderName>_RESULTS")
display(widgets.VBox([
    widgets.HBox([input_dropdown, input_dir_text]),
    auto_create_checkbox,
    output_overlay_container,
    widgets.HBox([output_dir_text, overlay_dir_text]),
    confirm_button,
    output
]))

In [None]:
import os
from PIL import Image
import numpy as np
from tqdm.notebook import tqdm
import time

if __name__ == "__main__":
    # input_dir, output_dir and overlay_dir are set by the directory-selection GUI
    # cell. They may still be "" (or undefined) if nothing was confirmed, and
    # os.makedirs("") would then fail with an unhelpful error.
    _missing = [n for n in ("input_dir", "output_dir", "overlay_dir") if not globals().get(n)]
    if _missing:
        raise ValueError(
            f"No directory selected for: {', '.join(_missing)}. "
            "Run the directory-selection cell and click 'Confirm Selection' first."
        )

    # Create output directories if they don't exist
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(overlay_dir, exist_ok=True)

    keras_models = [
        "/content/gdrive/MyDrive/ENSEMBLE_model/models/z_008DensDICEOvrSmpHard_06.keras",
        "/content/gdrive/MyDrive/ENSEMBLE_model/models/z_008DensDICEOvrSmpHard_07.keras",
        "/content/gdrive/MyDrive/ENSEMBLE_model/models/z_008DensDICEOvrSmpHard_08.keras",
        "/content/gdrive/MyDrive/ENSEMBLE_model/models/z_008DensFOCALOvrSmpHard_15.keras",
        "/content/gdrive/MyDrive/ENSEMBLE_model/models/z_008DensFOCALOvrSmpHard_40.keras",
        "/content/gdrive/MyDrive/ENSEMBLE_model/models/_008vggDICEOvrSmpHard_06.keras",
        "/content/gdrive/MyDrive/ENSEMBLE_model/models/_008vggDICEOvrSmpHard_32.keras",
        "/content/gdrive/MyDrive/ENSEMBLE_model/models/Keras_VggSoft_34.keras",
        "/content/gdrive/MyDrive/ENSEMBLE_model/models/Keras_Vgg_hard13.keras",
        "/content/gdrive/MyDrive/ENSEMBLE_model/models/Keras_Vgg_hard14.keras",
        "/content/gdrive/MyDrive/ENSEMBLE_model/models/_attn_dice_21.keras",
        "/content/gdrive/MyDrive/ENSEMBLE_model/models/_attn_dice_22.keras",
        "/content/gdrive/MyDrive/ENSEMBLE_model/models/_008VggFOCAL_ADV_Atten_IA_8m_increased_42.keras",
        "/content/gdrive/MyDrive/ENSEMBLE_model/models/_008VggFOCAL_ADV_Atten_IA_8m_increased_45.keras",
    ]

    # Load the models once per kernel session. models_loaded is reset to False
    # when the cell that defines majority_vote is re-run.
    if not models_loaded:
        print("Loading all models once...")
        standard_models, z_models = load_keras_models(keras_models)
        n_loaded = len(standard_models) + len(z_models)
        if n_loaded == 0:
            raise RuntimeError("No models could be loaded; see the errors printed above.")
        models_loaded = True
        if n_loaded < len(keras_models):
            print(f"Loaded {n_loaded} of {len(keras_models)} models; continuing with a partial ensemble.")
        else:
            print("Models loaded successfully.")
    else:
        print("Models already loaded, skipping loading.")

    # Only regular, non-hidden files are considered. Sub-folders and dot-files
    # would otherwise crash Image.open below.
    image_files = sorted(
        name for name in os.listdir(input_dir)
        if not name.startswith(".") and os.path.isfile(os.path.join(input_dir, name))
    )

    for img_nm in tqdm(image_files, desc="Processing images"):
        image_path = os.path.join(input_dir, img_nm)
        try:
            with Image.open(image_path) as img:
                img_arr = np.array(img.convert("RGB"))
        except (OSError, ValueError) as e:
            # PIL raises UnidentifiedImageError (an OSError) for non-image files.
            print(f"Skipping {img_nm}: cannot read image ({e})")
            continue
        original_shape = img_arr.shape[:2]

        segmentation_result = segment_with_keras_models(img_arr, original_shape, standard_models, z_models)

        if segmentation_result is not None:
            # NOTE(review): masks reuse the input file name and extension. A .jpg
            # input therefore yields a lossily compressed mask that may not be
            # strictly 0/255.
            output_path = os.path.join(output_dir, img_nm)
            overlay_save_path = os.path.join(overlay_dir, img_nm)

            save_start = time.time()
            display_and_save_segmentation(segmentation_result, output_path, image_path)
            overlay_img = overlay_mask_on_image(image_path, segmentation_result, alpha=0.4, mask_color=(255, 0, 0))
            overlay_img.save(overlay_save_path)
            print(f"Saved: {img_nm} in {time.time() - save_start:.2f}s")
        else:
            print(f"Failed on: {img_nm}")