In [None]:
import os
import glob
import cv2
import random
import numpy as np
import matplotlib.pyplot as plt
from keras.utils import normalize
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from tensorflow.keras.utils import to_categorical
from keras.models import Model, load_model
from keras.layers import Input, Conv2DTranspose, Conv2D, MaxPooling2D, UpSampling2D, concatenate, Dropout, BatchNormalization, Activation, MaxPool2D, Multiply, GlobalAveragePooling2D, Reshape, Dense, Lambda
from tensorflow.keras.regularizers import l2
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import backend as K
from keras.losses import binary_crossentropy
from tensorflow.keras.applications import VGG16
from sklearn.metrics import confusion_matrix, precision_score, recall_score, roc_auc_score, roc_curve
import seaborn as sns
from skimage.restoration import denoise_nl_means, estimate_sigma
from skimage import img_as_ubyte, img_as_float, io
from scipy.ndimage import binary_opening

def load_and_preprocess_data(folder_path, imagesf, maskf):
    """
    Load every image and mask found under a dataset folder.

    Images are the ``*.tif`` files in ``folder_path/imagesf`` and masks are the
    ``*.png`` files in ``folder_path/maskf``. Both listings are sorted by path,
    so an image and its mask are paired only by their position in the two
    sorted listings.

    Files that OpenCV cannot read are skipped, but a warning is emitted for
    each one, and another warning is emitted when the number of loaded images
    and masks differ, because either situation can shift the pairing.

    :param folder_path: Root folder of the dataset.
    :param imagesf: Name of the sub-folder that holds the images.
    :param maskf: Name of the sub-folder that holds the masks.
    :return: Tuple ``(images, masks)``: a list of images read in colour mode
             and a list of masks converted to single-channel grayscale.
    """
    import warnings

    def _read_sorted(subfolder, pattern):
        # Read all matching files in sorted order, reporting unreadable ones.
        loaded = []
        for path in sorted(glob.glob(os.path.join(folder_path, subfolder, pattern))):
            pixels = cv2.imread(path, cv2.IMREAD_COLOR)
            if pixels is None:
                warnings.warn(f"Could not read {path!r}; skipping it.")
                continue
            loaded.append(pixels)
        return loaded

    images = _read_sorted(imagesf, "*.tif")
    # Masks are read in colour mode and then collapsed to a single channel.
    masks = [cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY) for mask in _read_sorted(maskf, "*.png")]

    if len(images) != len(masks):
        warnings.warn(
            f"Loaded {len(images)} images but {len(masks)} masks; "
            "image/mask pairs may be misaligned."
        )

    return images, masks


def create_multiscale_patches(image, large_patch_size, medium_patch_size, small_patch_size, stride):
    """
    Cut an image into large patches and, for each one, concentric medium and small patches.

    Large patches are taken on a regular grid with the given stride. The medium
    patch is inset inside its large patch by half the size difference, and the
    small patch is inset inside the medium patch in the same way. A medium or
    small patch is only produced when it fits inside the image; a small patch
    is only produced when its medium patch was produced.

    :param image: Input image (array indexed as [row, column, ...]).
    :param large_patch_size: Side length of each large patch.
    :param medium_patch_size: Side length of each medium patch.
    :param small_patch_size: Side length of each small patch.
    :param stride: Step in pixels between consecutive large patches, both directions.
    :return: Three lists: large, medium and small patches, in row-major grid order.
    """
    height, width = image.shape[:2]

    # Top-left offsets of the nested patches relative to their parent patch.
    medium_inset = (large_patch_size - medium_patch_size) // 2
    small_inset = (medium_patch_size - small_patch_size) // 2

    def crop(top, left, size):
        return image[top:top + size, left:left + size]

    large_patches, medium_patches, small_patches = [], [], []

    for top in range(0, height - large_patch_size + 1, stride):
        for left in range(0, width - large_patch_size + 1, stride):
            large_patches.append(crop(top, left, large_patch_size))

            med_top = top + medium_inset
            med_left = left + medium_inset
            if med_left + medium_patch_size > width or med_top + medium_patch_size > height:
                continue
            medium_patches.append(crop(med_top, med_left, medium_patch_size))

            small_top = med_top + small_inset
            small_left = med_left + small_inset
            if small_left + small_patch_size > width or small_top + small_patch_size > height:
                continue
            small_patches.append(crop(small_top, small_left, small_patch_size))

    return large_patches, medium_patches, small_patches


def resize_patches(patches, target_size, is_mask=False):
    """
    Resize every patch to a common size and stack the results.

    Masks are resized with nearest-neighbour interpolation so that no new
    label values are introduced; images use bilinear interpolation.

    :param patches: Iterable of patches (arrays accepted by ``cv2.resize``).
    :param target_size: Either a single number, giving a square
                        ``target_size x target_size`` output, or a
                        ``(width, height)`` pair.
    :param is_mask: Boolean indicating if the patches are masks. Defaults to False.
    :return: Resized patches as a numpy array.
    """
    interpolation_method = cv2.INTER_NEAREST if is_mask else cv2.INTER_LINEAR

    # cv2.resize expects the output size as (width, height).
    if isinstance(target_size, (tuple, list)):
        width, height = target_size
        dsize = (width, height)
    else:
        dsize = (target_size, target_size)

    resized_patches = [cv2.resize(patch, dsize, interpolation=interpolation_method) for patch in patches]
    return np.array(resized_patches)

def encode_mask(mask_patches):
    """
    Replace the raw pixel values of a stack of masks with integer class labels.

    All pixels of the stack are flattened and passed through a single
    ``LabelEncoder`` fitted on those same pixels, then put back into the
    original layout with an extra trailing axis of length 1.

    :param mask_patches: Array of masks; it is unpacked as (n, h, w).
    :return: Encoded masks with shape (n, h, w, 1).
    """
    count, height, width = mask_patches.shape

    flat_pixels = mask_patches.reshape(-1, 1).ravel()
    flat_labels = LabelEncoder().fit_transform(flat_pixels)

    encoded = flat_labels.reshape(count, height, width)
    return np.expand_dims(encoded, axis=3)


def categorize_and_reshape_masks(y, n_classes):
    """
    One-hot encode label masks.

    :param y: Array of integer-labelled masks; its first three axes give the
              leading dimensions of the result.
    :param n_classes: Number of classes, i.e. the length of the one-hot axis.
    :return: Array of shape (y.shape[0], y.shape[1], y.shape[2], n_classes).
    """
    one_hot = to_categorical(y, num_classes=n_classes)

    # Force the layout to (first three axes of y, n_classes).
    target_shape = (y.shape[0], y.shape[1], y.shape[2], n_classes)
    return one_hot.reshape(target_shape)

In [None]:
# Load and preprocess training data
# The dataset location is a hard-coded Kaggle input path.
train_folder = "/kaggle/input/monusac-public/MoNuSac"
train_images, train_masks = load_and_preprocess_data(train_folder,"images","binary_masks")



# Define patch sizes and stride
# The stride equals the large patch size, so consecutive large patches do not overlap.
large_patch_size = 256
medium_patch_size = 192
small_patch_size = 128
stride = 256

# Create multi-scale patches for training data
# Images contribute patches at all three scales; masks contribute large patches only.
train_image_large_patches, train_image_medium_patches, train_image_small_patches = [], [], []
train_mask_large_patches = []

# NOTE(review): zip() stops at the shorter list, so a differing number of images
# and masks is silently truncated. Pairing relies on both lists being in the same order.
for image, mask in zip(train_images, train_masks):
    large_patches, medium_patches, small_patches = create_multiscale_patches(image, large_patch_size, medium_patch_size, small_patch_size, stride)
    train_image_large_patches.extend(large_patches)
    train_image_medium_patches.extend(medium_patches)
    train_image_small_patches.extend(small_patches)

    # Same grid applied to the mask; the medium and small mask patches are discarded.
    large_patches, _, _ = create_multiscale_patches(mask, large_patch_size, medium_patch_size, small_patch_size, stride)
    train_mask_large_patches.extend(large_patches)



# Convert lists to numpy arrays
train_image_large_patches = np.array(train_image_large_patches)
train_image_medium_patches = np.array(train_image_medium_patches)
train_image_small_patches = np.array(train_image_small_patches)
train_mask_large_patches = np.array(train_mask_large_patches)



In [None]:
# Resize patches
# Medium and small image patches are resized to the large patch size so that all
# three scales share one spatial size (is_mask is left at its default, so the
# non-mask interpolation is used). Large patches and masks are already that size.
train_image_large_patches_resized = train_image_large_patches  # No need to resize (same array object, not a copy)
train_image_medium_patches_resized = resize_patches(train_image_medium_patches, large_patch_size)
train_image_small_patches_resized = resize_patches(train_image_small_patches, large_patch_size)
train_mask_large_patches_resized = train_mask_large_patches  # No need to resize (same array object, not a copy)


In [None]:
# Sanity-check the shapes of the three image scales and of the mask patches.
print("Image small patch shape is: ", train_image_small_patches_resized.shape)
print("Image Medium patch shape is: ", train_image_medium_patches_resized.shape)
print("Image Large patch shape is: ", train_image_large_patches_resized.shape, "and Mask Large patch shape is: ", train_mask_large_patches_resized.shape)

# Report the image value range and the distinct raw values found in the masks.
print("Max pixel value in image is: ", train_image_large_patches_resized.max())
unique_labels = np.unique(train_mask_large_patches_resized)
print("Labels in the mask are : ", unique_labels)
# Number of distinct raw mask values; recomputed after label encoding in a later cell.
num_classes = len(unique_labels)
print("Total Classes in the mask are : ", num_classes)

In [None]:
# Normalize images
# Cast to float32 before dividing: dividing the integer patch arrays by a Python
# float yields float64, which doubles the memory of every patch array.
train_image_large_patches_resized = train_image_large_patches_resized.astype(np.float32) / 255.
train_image_medium_patches_resized = train_image_medium_patches_resized.astype(np.float32) / 255.
train_image_small_patches_resized = train_image_small_patches_resized.astype(np.float32) / 255.


In [None]:
# Encode masks
# Raw mask values are mapped to integer labels; the result gains a trailing axis of length 1.
train_mask_large_dataset_encoded = encode_mask(train_mask_large_patches_resized)


# Get number of classes
# Counted from the labels actually present in the encoded training masks.
unique_labels = np.unique(train_mask_large_dataset_encoded)
num_classes = len(unique_labels)

In [None]:
# Categorize and reshape masks
# One-hot encode the integer-labelled masks into num_classes channels.
train_large_mask_cat = categorize_and_reshape_masks(train_mask_large_dataset_encoded, num_classes)


# Split training data into training and validation sets
# All four arrays are split together, so each sample keeps its three scales and its mask.
# NOTE(review): the split is made per patch, so patches cut from the same source
# image can land in both the training and the validation set.
X_large_train, X_large_val, X_medium_train, X_medium_val, X_small_train, X_small_val, y_train, y_val = train_test_split(
    train_image_large_patches_resized, train_image_medium_patches_resized, train_image_small_patches_resized, train_large_mask_cat, test_size=0.2, random_state=42
)

# Print shapes
print("Training data shapes:")
print("X_large_train:", X_large_train.shape)
print("X_medium_train:", X_medium_train.shape)
print("X_small_train:", X_small_train.shape)
print("y_train:", y_train.shape)

print("\nValidation data shapes:")
print("X_large_val:", X_large_val.shape)
print("X_medium_val:", X_medium_val.shape)
print("X_small_val:", X_small_val.shape)
print("y_val:", y_val.shape)


In [None]:
import tensorflow as tf

# List the GPUs TensorFlow can see.
gpus = tf.config.list_physical_devices('GPU')
print("GPUs found:", gpus)

# # Optional: enable mixed precision for speed on modern GPUs
# NOTE(review): mixed_precision is not imported in this cell; import it before enabling the line below.
# mixed_precision.set_global_policy('mixed_float16')

# Use all available GPUs on this machine
strategy = tf.distribute.MirroredStrategy()

print("Number of GPUs in sync:", strategy.num_replicas_in_sync)

In [None]:
# Define input shapes for the model
# Each shape is axes 1..3 of the corresponding patch array, i.e. the per-sample
# shape without the leading sample axis.
large_input_shape = (train_image_large_patches_resized.shape[1], train_image_large_patches_resized.shape[2], train_image_large_patches_resized.shape[3])
medium_input_shape = (train_image_medium_patches_resized.shape[1], train_image_medium_patches_resized.shape[2], train_image_medium_patches_resized.shape[3])
small_input_shape = (train_image_small_patches_resized.shape[1], train_image_small_patches_resized.shape[2], train_image_small_patches_resized.shape[3])

print(large_input_shape)
print(medium_input_shape)
print(small_input_shape)

In [None]:
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping

# Turn off the "layout_optimizer" and "model_pruner" graph-optimizer options.
# NOTE(review): the reason for disabling them is not recorded here.
tf.config.optimizer.set_experimental_options({"layout_optimizer": False, "model_pruner": False})

# Save the model to the Kaggle working directory whenever val_loss reaches a new minimum.
checkpoint = ModelCheckpoint(
    filepath='/kaggle/working/Binary_model_ER_IHC.keras',
    monitor='val_loss',     
    save_best_only=True,
    mode='min',             
    verbose=1
)

# Stop after 25 epochs without a val_loss improvement and restore the best weights.
early_stopping = EarlyStopping(
    monitor='val_loss',     
    patience=25,
    restore_best_weights=True,
    verbose=1,
    mode='min'
)


In [None]:



# Define the VGG16 backbone function
def create_vgg16_backbone(input_shape):
    """
    Build a multi-output feature extractor from an ImageNet-pretrained VGG16.

    The VGG16 layers are not frozen here, so they stay trainable.

    :param input_shape: Input shape handed to VGG16 (without the top classifier).
    :return: Keras Model mapping the VGG16 input to a list of three feature
             maps, taken from ``block1_conv2``, ``block2_conv2`` and
             ``block3_conv3``, in that order.
    """
    tap_layer_names = ('block1_conv2', 'block2_conv2', 'block3_conv3')

    base = VGG16(weights='imagenet', include_top=False, input_shape=input_shape)
    feature_maps = [base.get_layer(layer_name).output for layer_name in tap_layer_names]

    return Model(inputs=base.input, outputs=feature_maps)

def channel_attention(x):
    """
    Reweight the channels of a feature map with a learned per-channel gate.

    The map is globally average-pooled, passed through a two-layer dense
    bottleneck (channels // 8 units with ReLU, then channels units with
    sigmoid), and the resulting (1, 1, channels) gate is multiplied with the
    input.

    :param x: Feature map tensor whose last axis is the channel axis.
    :return: Tensor produced by multiplying ``x`` with the gate.
    """
    channels = x.shape[-1]
    gate_shape = (1, 1, channels)

    # Squeeze: one average value per channel, shaped so it broadcasts over the map.
    pooled = GlobalAveragePooling2D()(x)
    pooled = Reshape(gate_shape)(pooled)

    # Excite: bottleneck followed by a sigmoid gate per channel.
    bottleneck = Dense(channels // 8, activation='relu')(pooled)
    gate = Dense(channels, activation='sigmoid')(bottleneck)
    gate = Reshape(gate_shape)(gate)

    return Multiply()([x, gate])




In [None]:
# All three inputs use the large patch shape.
target_shape= large_input_shape

from tensorflow.keras.optimizers import Adam
from tensorflow.keras import mixed_precision
with strategy.scope():
    input_shape = target_shape

    large_input = Input(shape=input_shape)
    medium_input = Input(shape=input_shape)
    small_input = Input(shape=input_shape)

    # One backbone instance is applied to all three inputs, so its weights are shared.
    vgg16_backbone = create_vgg16_backbone(input_shape)

    # Extract features for each patch size using VGG16
    large_features = vgg16_backbone(large_input)
    medium_features = vgg16_backbone(medium_input)
    small_features = vgg16_backbone(small_input)

    # Concatenate the features from different scales, level by level
    c1 = concatenate([large_features[0], medium_features[0], small_features[0]])
    c2 = concatenate([large_features[1], medium_features[1], small_features[1]])
    c3 = concatenate([large_features[2], medium_features[2], small_features[2]])

    # Apply attention mechanism
    c3 = channel_attention(c3)
    c2 = channel_attention(c2)
    c1 = channel_attention(c1)

    # Decoder with skip connections and increased filters with dropout:
    # each stage is a stride-2 transposed convolution, a skip concatenation,
    # a 3x3 convolution, batch normalisation and dropout.
    x = Conv2DTranspose(1024, (3, 3), strides=(2, 2), padding='same', kernel_regularizer=l2(0.01))(c3)
    x = concatenate([x, c2])
    x = Conv2D(512, (3, 3), activation='relu', padding='same')(x)
    x = BatchNormalization()(x)
    x = Dropout(0.5)(x)

    x = Conv2DTranspose(512, (3, 3), strides=(2, 2), padding='same', kernel_regularizer=l2(0.01))(x)
    x = concatenate([x, c1])
    x = Conv2D(256, (3, 3), activation='relu', padding='same')(x)
    x = BatchNormalization()(x)
    x = Dropout(0.5)(x)

    # NOTE(review): this is a second Dropout(0.5) directly after the one above.
    shared = Dropout(0.5)(x)

    # Output layer
    # NOTE(review): Edge_output is built but not passed to Model(...) below,
    # so it is not part of the trained model.
    Edge_output = Conv2D(1, kernel_size=1, activation='sigmoid', name='Binary_Edge_Branch')(shared)
    Segmentation_output = Conv2D(num_classes, 1, activation='sigmoid', name='Segmentation_branch')(shared)

    # Create the final model
    model = Model(inputs=[large_input, medium_input, small_input], outputs=Segmentation_output)

    model.compile(
        optimizer=Adam(),
        loss='binary_crossentropy',
        metrics=['accuracy']
    )

    # Check the model summary
    model.summary()
    total_params = model.count_params()
    # count_params() also includes non-trainable weights (e.g. the moving
    # statistics of BatchNormalization), so the trainable total is summed
    # separately to match what the message reports.
    trainable_params = sum(int(np.prod(tuple(w.shape))) for w in model.trainable_weights)
    print(f"Total Trainable Parameters: {trainable_params / 1e6:.2f}M")

    # Train the model
    history = model.fit(
        [X_large_train, X_medium_train, X_small_train],
        y_train,
        validation_data=(
            [X_large_val, X_medium_val, X_small_val],
            y_val
        ),
        epochs=200,
        batch_size=16,
        callbacks=[early_stopping, checkpoint]
    )


In [None]:
import matplotlib.pyplot as plt

# Pull the per-epoch curves out of the Keras History object.
history_dict = history.history

train_loss, val_loss = history_dict['loss'], history_dict['val_loss']
train_acc, val_acc = history_dict['accuracy'], history_dict['val_accuracy']

epochs = range(1, len(train_loss) + 1)


def _plot_training_curves(x_values, train_values, val_values, metric, y_label):
    """Draw one figure comparing the training and validation curve of a metric."""
    plt.figure(figsize=(8, 6))
    plt.plot(x_values, train_values, label=f'Training {metric}')
    plt.plot(x_values, val_values, label=f'Validation {metric}')
    plt.xlabel('Epoch')
    plt.ylabel(y_label)
    plt.title(f'Training vs Validation {metric}')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()


# Loss first (the most informative curve for segmentation), then accuracy,
# which is secondary and can be misleading for segmentation.
_plot_training_curves(epochs, train_loss, val_loss, 'Loss', 'Binary Cross-Entropy Loss')
_plot_training_curves(epochs, train_acc, val_acc, 'Accuracy', 'Accuracy')


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import cv2
import random
import time

# ----------------------------------------------------------
# 1. Select up to 5 random held-out images
# ----------------------------------------------------------
# Use the validation split rather than the full training arrays, so the
# displayed predictions are not made on samples the model was fitted on.
X_test_large  = X_large_val
X_test_medium = X_medium_val
X_test_small  = X_small_val
y_test        = y_val

num_samples = len(X_test_large)
# random.sample raises ValueError when asked for more items than exist.
random_indices = random.sample(range(num_samples), min(5, num_samples))

print("Selected random test indices:", random_indices)

# ----------------------------------------------------------
# Helper Functions
# ----------------------------------------------------------

def prepare_rgb(img):
    """
    Convert a BGR image to a uint8 RGB image.

    Non-uint8 input is scaled by 255 before conversion. The scaled values are
    rounded and clipped to [0, 255]: plain truncation can lose one intensity
    level when a value lands just below an integer, and out-of-range values
    would otherwise wrap around in the uint8 cast.

    :param img: Image array in BGR channel order, either uint8 or a
                non-uint8 array that is scaled by 255
                (presumably floats in [0, 1] - confirm against callers).
    :return: New uint8 array in RGB channel order; the input is not modified.
    """
    if img.dtype != np.uint8:
        img = np.clip(np.rint(img * 255), 0, 255).astype(np.uint8)
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

def colormap_mask(mask):
    """
    Turn a label mask into a colour image using a fixed two-class palette.

    Label 0 (background) and label 1 (nuclei) each get their own colour; any
    other label stays black.

    :param mask: Integer label array.
    :return: uint8 array with the shape of ``mask`` plus a trailing axis of 3.
    """
    palette = (
        (0, (40, 0, 60)),        # background
        (1, (255, 235, 50)),     # yellow nuclei
    )

    colored = np.zeros(mask.shape + (3,), dtype=np.uint8)
    for class_id, rgb in palette:
        colored[mask == class_id] = rgb
    return colored

def instance_segmentation_watershed(img_rgb, semantic_mask, fg_threshold=0.4,
                                    contour_color=(255, 0, 0), contour_thickness=2):
    """
    Separate touching objects of a semantic mask with watershed and draw their outlines.

    :param img_rgb: Image passed to ``cv2.watershed`` and used as the drawing
                    canvas (a copy is drawn on; the input is not modified).
    :param semantic_mask: Label mask; every value greater than 0 is foreground.
    :param fg_threshold: Fraction of the maximum distance-transform value above
                         which a pixel counts as sure foreground. Defaults to 0.4.
    :param contour_color: Colour used to draw instance contours. Defaults to (255, 0, 0).
    :param contour_thickness: Contour line thickness in pixels. Defaults to 2.
    :return: Copy of ``img_rgb`` with one contour drawn per detected instance.
    """
    # 1. Binary mask
    binary_mask = (semantic_mask > 0).astype(np.uint8)

    # 2. Morphological cleaning (3x3 opening)
    kernel = np.ones((3, 3), np.uint8)
    binary_clean = cv2.morphologyEx(binary_mask, cv2.MORPH_OPEN, kernel)

    # 3. Distance transform, scaled by its maximum (epsilon avoids division by zero)
    dist = cv2.distanceTransform(binary_clean, cv2.DIST_L2, 5)
    dist_norm = dist / (dist.max() + 1e-7)

    # 4. Foreground estimation: object cores are sure foreground, the rest is unknown
    _, sure_fg = cv2.threshold(dist_norm, fg_threshold, 1.0, cv2.THRESH_BINARY)
    sure_fg = np.uint8(sure_fg)
    unknown = cv2.subtract(binary_clean, sure_fg)

    # 5. Connected components -> markers
    _, markers = cv2.connectedComponents(sure_fg)
    markers = markers + 1          # shift so that no component keeps label 0
    markers[unknown == 1] = 0      # label 0 marks the pixels watershed has to decide

    # 6. Watershed
    markers = cv2.watershed(img_rgb.copy(), markers)

    # 7. Draw contours; labels <= 1 are skipped, only labels above 1 are treated as instances
    contour_img = img_rgb.copy()
    for label in np.unique(markers):
        if label <= 1:
            continue

        inst_mask = np.uint8(markers == label)
        cnts, _ = cv2.findContours(inst_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        cv2.drawContours(contour_img, cnts, -1, contour_color, contour_thickness)

    return contour_img


# ----------------------------------------------------------
# 2. Plot 5 Samples
# ----------------------------------------------------------
plt.figure(figsize=(22, 20))

for i, idx in enumerate(random_indices):

    # The three scale inputs and the one-hot ground truth of this sample
    img_large, img_medium, img_small = X_test_large[idx], X_test_medium[idx], X_test_small[idx]
    gt_mask = y_test[idx]

    # Give every input a leading batch axis of length 1
    test_input = [scale_img[np.newaxis, ...] for scale_img in (img_large, img_medium, img_small)]

    # Predict and collapse the class axis to a label map
    pred = model.predict(test_input, verbose=0)
    pred_probs = pred[0]
    pred_mask = pred_probs.argmax(axis=-1)

    gt_labels = gt_mask.argmax(axis=-1)

    # Images to display
    rgb_img = prepare_rgb(img_large)
    gt_colored = colormap_mask(gt_labels)
    pred_colored = colormap_mask(pred_mask)

    # Instance separation (watershed)
    instance_contours_img = instance_segmentation_watershed(rgb_img, pred_mask)

    # One row of four panels per sample in a 5 x 4 grid
    row = i * 4 + 1
    panels = (
        (f"Original (idx={idx})", rgb_img),
        ("Ground Truth Mask", gt_colored),
        ("Predicted Mask", pred_colored),
        ("Instance Contours", instance_contours_img),
    )
    for column, (title, panel) in enumerate(panels):
        plt.subplot(5, 4, row + column)
        plt.imshow(panel)
        plt.title(title)
        plt.axis("off")

plt.tight_layout()
plt.show()
