In [None]:
# Environment setup (Kaggle).
# NOTE(review): tensorflow-addons / tfa-nightly look unused in this notebook — the only
# `tensorflow_addons` import is commented out — and TF-Addons is in end-of-life mode
# upstream (verify); consider dropping these two installs. Versions are unpinned, so
# re-runs may pick up different releases.
!pip install tensorflow-addons
!pip install tfa-nightly
!pip install scikit-image

In [None]:
import os
import glob
import cv2
import random
import numpy as np
import matplotlib.pyplot as plt
from keras.utils import normalize
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from tensorflow.keras.utils import to_categorical
from keras.models import Model, load_model
from keras.layers import Input, Conv2DTranspose, Conv2D, MaxPooling2D, UpSampling2D, concatenate, Dropout, BatchNormalization, Activation, MaxPool2D, Multiply, GlobalAveragePooling2D, Reshape, Dense, Lambda
from tensorflow.keras.regularizers import l2
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import backend as K
from keras.losses import binary_crossentropy
from sklearn.metrics import confusion_matrix, precision_score, recall_score, roc_auc_score, roc_curve
import seaborn as sns
from skimage.restoration import denoise_nl_means, estimate_sigma
from skimage import img_as_ubyte, img_as_float, io
from scipy.ndimage import binary_opening

def load_and_preprocess_data(folder_path, imagesf, maskf, pattern="*.png"):
    """Load paired images and colour masks from two sub-folders of ``folder_path``.

    Files are paired by position after sorting each folder's file names, so the
    image and mask folders are expected to hold identically ordered names. Both
    files of a pair are read with ``cv2.IMREAD_COLOR`` (3-channel, BGR order).
    A pair is dropped as a whole when either file cannot be decoded; this keeps
    ``images[i]`` aligned with ``masks[i]``. If the folders hold different numbers
    of files, only the first ``min(n_images, n_masks)`` positions are paired and a
    warning is printed.

    :param folder_path: Root folder containing both sub-folders.
    :param imagesf: Name of the sub-folder holding the input images.
    :param maskf: Name of the sub-folder holding the colour masks.
    :param pattern: Glob pattern selecting the files (default ``"*.png"``).
    :return: ``(images, masks)`` — two equal-length lists of BGR arrays.
    """
    image_names = sorted(glob.glob(os.path.join(folder_path, imagesf, pattern)))
    mask_names = sorted(glob.glob(os.path.join(folder_path, maskf, pattern)))
    if len(image_names) != len(mask_names):
        print(f"Warning: found {len(image_names)} images but {len(mask_names)} masks; "
              "pairing by sorted position and ignoring the surplus files.")

    images, masks = [], []
    for image_path, mask_path in zip(image_names, mask_names):
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        mask = cv2.imread(mask_path, cv2.IMREAD_COLOR)  # colour mask, BGR order
        # Skip the pair together: dropping only one side would shift every
        # later image onto the wrong mask.
        if image is None or mask is None:
            print(f"Warning: skipping unreadable pair {image_path!r} / {mask_path!r}")
            continue
        images.append(image)
        masks.append(mask)

    return images, masks  # No heatmaps loaded here


def create_multiscale_patches(image, large_patch_size, medium_patch_size, small_patch_size, stride):
    """
    Tile an image into large patches and cut two nested crops from every tile.

    For each large tile, a medium crop is taken centred inside it, and a small crop
    centred inside that medium crop. Tiles are visited row by row (top to bottom,
    then left to right). A nested crop is only kept when it fits inside the image.
    A small crop is only attempted when its medium crop was kept.

    :param image: Array whose first two axes are height and width.
    :param large_patch_size: Side length of each square large tile.
    :param medium_patch_size: Side length of the medium crop centred in each tile.
    :param small_patch_size: Side length of the small crop centred in the medium crop.
    :param stride: Step in pixels between consecutive tiles along both axes.
    :return: Three lists (large, medium, small) of slices of ``image``.
    """
    height, width = image.shape[:2]

    # Top-left offsets of the nested crops; identical for every tile.
    medium_offset = (large_patch_size - medium_patch_size) // 2
    small_offset = (medium_patch_size - small_patch_size) // 2

    large_out, medium_out, small_out = [], [], []
    for top in range(0, height - large_patch_size + 1, stride):
        for left in range(0, width - large_patch_size + 1, stride):
            large_out.append(image[top:top + large_patch_size, left:left + large_patch_size])

            m_top, m_left = top + medium_offset, left + medium_offset
            if m_left + medium_patch_size > width or m_top + medium_patch_size > height:
                continue
            medium_out.append(image[m_top:m_top + medium_patch_size, m_left:m_left + medium_patch_size])

            s_top, s_left = m_top + small_offset, m_left + small_offset
            if s_left + small_patch_size <= width and s_top + small_patch_size <= height:
                small_out.append(image[s_top:s_top + small_patch_size, s_left:s_left + small_patch_size])

    return large_out, medium_out, small_out


def resize_patches(patches, target_size, is_mask=False):
    """
    Resize every patch to a common size.

    :param patches: Iterable of HxW or HxWxC arrays.
    :param target_size: Either an int ``s`` (each patch becomes ``s x s``) or a
        ``(width, height)`` tuple/list, following OpenCV's ``dsize`` ordering.
    :param is_mask: When True use nearest-neighbour interpolation so that only
        existing label values appear in the output; otherwise use
        ``cv2.INTER_LINEAR``. Defaults to False.
    :return: Resized patches stacked into a numpy array.
    """
    # The docstring always advertised a (width, height) tuple, but only an int worked.
    if isinstance(target_size, (tuple, list)):
        width, height = target_size
    else:
        width = height = target_size
    dsize = (int(width), int(height))

    interpolation_method = cv2.INTER_NEAREST if is_mask else cv2.INTER_LINEAR
    resized_patches = [cv2.resize(patch, dsize, interpolation=interpolation_method) for patch in patches]
    return np.array(resized_patches)

def encode_mask(mask_patches):
    """Replace the raw values in a stack of masks with ``LabelEncoder`` class ids.

    :param mask_patches: Array of shape (n, h, w) holding raw mask values.
    :return: ``(encoded, encoder)`` — ``encoded`` has shape (n, h, w, 1) and
        ``encoder`` is the fitted ``LabelEncoder`` (``encoder.classes_[i]`` is
        the raw value that was mapped to id ``i``).
    """
    n, h, w = mask_patches.shape
    encoder = LabelEncoder()
    flat_ids = encoder.fit_transform(mask_patches.reshape(-1, 1).ravel())
    encoded = flat_ids.reshape(n, h, w)[..., np.newaxis]
    return encoded, encoder



def categorize_and_reshape_masks(y, n_classes):
    """One-hot encode integer class-id masks.

    :param y: Integer masks whose first three axes are (n, h, w), e.g. (n, h, w, 1).
    :param n_classes: Length of the one-hot axis.
    :return: Array of shape (n, h, w, n_classes) produced by ``to_categorical``.
    """
    n, h, w = y.shape[0], y.shape[1], y.shape[2]
    return to_categorical(y, num_classes=n_classes).reshape((n, h, w, n_classes))

In [None]:
# Load and preprocess training data
train_folder = "/kaggle/input/er-20x-corrected-balanced-centroids/20XCorrectedBalancedER-IHC"
train_images, train_masks = load_and_preprocess_data(train_folder, "images", "masks")

# Patch geometry in pixels. stride == large_patch_size, so large tiles do not overlap.
large_patch_size, medium_patch_size, small_patch_size = 512, 256, 192
stride = 512

# Cut every (image, mask) pair into the same three concentric scales. Assuming each
# mask has its image's height/width, index i refers to the same region in all six lists.
train_image_large_patches, train_image_medium_patches, train_image_small_patches = [], [], []
train_mask_large_patches, train_mask_medium_patches, train_mask_small_patches = [], [], []

for image, mask in zip(train_images, train_masks):
    img_large, img_medium, img_small = create_multiscale_patches(
        image, large_patch_size, medium_patch_size, small_patch_size, stride)
    msk_large, msk_medium, msk_small = create_multiscale_patches(
        mask, large_patch_size, medium_patch_size, small_patch_size, stride)

    train_image_large_patches += img_large
    train_image_medium_patches += img_medium
    train_image_small_patches += img_small
    train_mask_large_patches += msk_large
    train_mask_medium_patches += msk_medium
    train_mask_small_patches += msk_small

# Stack each list of patches into one array (patch index on axis 0).
train_image_large_patches = np.array(train_image_large_patches)
train_image_medium_patches = np.array(train_image_medium_patches)
train_image_small_patches = np.array(train_image_small_patches)
train_mask_large_patches = np.array(train_mask_large_patches)
train_mask_medium_patches = np.array(train_mask_medium_patches)
train_mask_small_patches = np.array(train_mask_small_patches)


In [None]:
# Bring the medium and small crops back up to the large-patch resolution so the
# three scales can be stacked into one array. Large patches already have that size.
train_image_large_patches_resized = train_image_large_patches
train_mask_large_patches_resized = train_mask_large_patches

train_image_medium_patches_resized = resize_patches(train_image_medium_patches, large_patch_size)
train_image_small_patches_resized = resize_patches(train_image_small_patches, large_patch_size)

# Masks use nearest-neighbour resizing (is_mask=True) so no new colour values appear.
train_mask_medium_patches_resized = resize_patches(train_mask_medium_patches, large_patch_size, is_mask=True)
train_mask_small_patches_resized = resize_patches(train_mask_small_patches, large_patch_size, is_mask=True)

In [None]:
# Report the per-scale stacks. num_samples (= number of large patches) is the size
# of each scale's block once the scales are concatenated.
num_samples = train_image_large_patches_resized.shape[0]
for caption, stack in (
    ("large patch shape:", train_image_large_patches_resized),
    ("medium shape:", train_image_medium_patches_resized),
    ("small shape:", train_image_small_patches_resized),
):
    print(caption, stack.shape)
print("Number of samples:", num_samples)

In [None]:
# Stack the three scales along the sample axis, always in the order
# large -> medium -> small; the same region is then found at offsets of num_samples.
unified_images = np.concatenate((
    train_image_large_patches_resized,
    train_image_medium_patches_resized,
    train_image_small_patches_resized,
))
unified_masks_color = np.concatenate((
    train_mask_large_patches_resized,
    train_mask_medium_patches_resized,
    train_mask_small_patches_resized,
))

# Verify the shapes
for caption, stacked in (("Unified Images shape:", unified_images),
                         ("Unified Masks shape:", unified_masks_color)):
    print(caption, stacked.shape)

In [None]:
import numpy as np
import cv2
from scipy import ndimage as ndi

def centroids_dot_mask_from_rgb_mask(mask_rgb, dot_size=3, min_area=5):
    """
    One square dot per nucleus (connected foreground component), regardless of shape/class.
    0 = background, 1 = centroid.

    The dot is centred on the component pixel with the largest distance-transform
    value (the pixel farthest from background). Ties go to the first such pixel in
    row-major order.

    Args:
        mask_rgb: HxWx3 colour mask (0,0,0 background; anything else = nucleus).
            An HxW single-channel mask (0 = background) is also accepted.
        dot_size: side length of the square dot (3 -> 3x3). Values <= 1 keep a
            single pixel; odd values keep the dot centred on the chosen pixel.
        min_area: when > 1, components with fewer pixels than this are ignored.

    Returns:
        HxW uint8 array of 0/1 values.
    """
    # 1) Binary foreground (any non-black pixel is nucleus)
    if mask_rgb.ndim == 2:
        fg = mask_rgb > 0
    else:
        fg = (mask_rgb[..., 0] | mask_rgb[..., 1] | mask_rgb[..., 2]) > 0
    H, W = fg.shape

    dots = np.zeros((H, W), dtype=np.uint8)
    if not fg.any():
        return dots

    # 2) Label components once (ndi.label default connectivity) and drop tiny specks,
    #    which would otherwise create fake centroids.
    labeled, n = ndi.label(fg)
    keep = np.ones(n + 1, dtype=bool)
    keep[0] = False  # label 0 is background
    if min_area > 1:
        sizes = np.bincount(labeled.ravel(), minlength=n + 1)
        keep[sizes < min_area] = False
    fg_u8 = keep[labeled].astype(np.uint8)

    # 3) Distance transform inside the cleaned foreground
    dt = cv2.distanceTransform(fg_u8, distanceType=cv2.DIST_L2, maskSize=3)

    # 4) Per component, pick the pixel with the maximum distance value. Work inside each
    #    component's bounding box instead of scanning the full image once per component.
    #    Row-major order within the box matches row-major order in the image, so ties
    #    resolve to the same pixel as a full-image argmax would.
    for lab, bbox in enumerate(ndi.find_objects(labeled), start=1):
        if bbox is None or not keep[lab]:
            continue
        in_comp = labeled[bbox] == lab
        # Pixels of other components inside the box can never win (dt >= 0 > -1).
        scores = np.where(in_comp, dt[bbox], -1.0)
        r, c = np.unravel_index(np.argmax(scores), scores.shape)
        dots[bbox[0].start + r, bbox[1].start + c] = 1

    # 5) Expand to a dot_size x dot_size square (clipped at the borders)
    if dot_size > 1:
        kernel = np.ones((dot_size, dot_size), np.uint8)
        dots = cv2.dilate(dots, kernel, iterations=1)

    return dots.astype(np.uint8)

In [None]:
# === Centroid-dot targets for every patch, in the same large -> medium -> small
# order used to build unified_images / unified_masks_color ===
heatmap_sources = (
    ("Generating Centroids For Large Patches", train_mask_large_patches_resized),
    ("Generating Centroids For Medium Patches", train_mask_medium_patches_resized),
    ("Generating Centroids For Small Patches", train_mask_small_patches_resized),
)

unified_heatmaps = []
for banner, scale_masks in heatmap_sources:
    print(banner)
    total = len(scale_masks)
    for idx, mask_patch in enumerate(scale_masks, 1):
        unified_heatmaps.append(centroids_dot_mask_from_rgb_mask(mask_patch, 3, 5))
        print(f"\rProcessed {idx}/{total}", end='', flush=True)
    print()

unified_heatmaps = np.array(unified_heatmaps, dtype=np.uint8)
print("Done")

In [None]:
# Collapse each BGR colour mask to a single gray channel (assumes every class colour
# lands on a distinct gray level).
unified_masks = np.array(list(map(lambda m: cv2.cvtColor(m, cv2.COLOR_BGR2GRAY), unified_masks_color)))

In [None]:

# Verify the shapes, value ranges and label sets before training.
mask_values = np.unique(unified_masks)
heatmap_values = np.unique(unified_heatmaps)
num_classes = len(mask_values)
num_classesheatmap = len(heatmap_values)

print("Unified Images shape:", unified_images.shape)
print("Unified Masks shape:", unified_masks.shape)
print("Unified heatmat shape:", unified_heatmaps.shape)
print("____________________________________________________")
print("Max pixel value in image is: ", unified_images.max())
print("Max pixel value in heatmap is: ", unified_heatmaps.max())
print("Min pixel value in heatmap is: ", unified_heatmaps.min())
print("____________________________________________________")
print("Labels in the mask are : ", mask_values)
print("Total Classes in the mask are : ", num_classes)
print("____________________________________________________")
print("Labels in the heatmap are : ", heatmap_values)
print("Total Classes in the heatmap are : ", num_classesheatmap)

# Leave the notebook-level name pointing at the heatmap label set, as before.
unique_labels = heatmap_values

In [None]:
small_sample = num_samples + num_samples


def _squeeze_channel(arr):
    """Drop a trailing singleton channel so imshow receives a 2-D array."""
    return arr[..., 0] if arr.ndim == 3 and arr.shape[-1] == 1 else arr


def _bgr_to_rgb(img):
    """Reverse the channel order of a 3-channel (OpenCV BGR) image for matplotlib."""
    return img[..., ::-1] if img.ndim == 3 and img.shape[-1] == 3 else _squeeze_channel(img)


def display_multiscale_pairs(images, masks, heatmaps, base_index, scale_offset=None):
    """Display the large, medium and small patches of one region side by side.

    The unified arrays stack the three scales along axis 0 in the order
    large -> medium -> small, each block holding ``scale_offset`` patches. The same
    region therefore sits at ``base_index``, ``base_index + scale_offset`` and
    ``base_index + 2 * scale_offset``.
    Rows: image, mask, heatmap. Columns: large, medium, small.

    :param images: Stack of images in OpenCV BGR order; converted to RGB for display.
    :param masks: Stack of single-channel masks.
    :param heatmaps: Stack of centroid maps, shape (N, H, W) or (N, H, W, 1).
    :param base_index: Index of the region inside the large-patch block.
    :param scale_offset: Patches per scale; defaults to the notebook's ``num_samples``.
    """
    offset = num_samples if scale_offset is None else scale_offset
    columns = (
        ("Large", base_index, "heatmaps"),
        ("Medium", base_index + offset, "heatmap"),
        ("Small", base_index + 2 * offset, "heatmap"),
    )

    fig, axes = plt.subplots(3, 3, figsize=(15, 8))
    for col, (scale, idx, heat_word) in enumerate(columns):
        panels = (
            (_bgr_to_rgb(images[idx]), None, f"{scale} Image {idx}"),
            (_squeeze_channel(masks[idx]), 'gray', f"{scale} Mask {idx}"),
            (_squeeze_channel(heatmaps[idx]), 'gray', f"{scale} {heat_word} {idx}"),
        )
        for row, (data, cmap, title) in enumerate(panels):
            ax = axes[row, col]
            ax.imshow(data, cmap=cmap)
            ax.set_title(title)
            ax.axis('off')

    plt.tight_layout()
    plt.show()


# Example: display the three scales of region 5
display_multiscale_pairs(unified_images, unified_masks, unified_heatmaps, 5)

In [None]:
# Normalize images to [0, 1] as float32. A plain `/ 255.` produces float64, which
# doubles memory for no benefit because the model consumes float32. The integer-dtype
# guard makes the cell safe to re-run: once converted, the array is not divided again.
if np.issubdtype(unified_images.dtype, np.integer):
    unified_images = unified_images.astype(np.float32)
    unified_images /= 255.



In [None]:
# Encode gray-level masks as LabelEncoder class ids
unified_masks_encoded, mask_label_encoder = encode_mask(unified_masks)

# After encoding masks
unique_labels = np.unique(unified_masks_encoded)
numclasses = len(unique_labels)
print("Actual number of classes:", numclasses)

# One-hot masks for the softmax 'NS' head
train_mask_cat = categorize_and_reshape_masks(unified_masks_encoded, numclasses)

# Give the centroid maps a trailing channel axis, matching the 1-channel 'CM' head.
# Guarded so that re-running this cell does not keep adding axes.
if unified_heatmaps.ndim == 3:
    unified_heatmaps = np.expand_dims(unified_heatmaps, axis=-1)

In [None]:
# Per-class weights for the NS loss: every class gets 1.0 except the background class.
# Background is the encoder id whose raw gray value is 0, and it gets BG_CLASS_WEIGHT.
BG_CLASS_WEIGHT = 1.5  # try 1.5–2.0
class_weights = np.ones((numclasses,), dtype=np.float32)

bg_candidates = np.flatnonzero(mask_label_encoder.classes_ == 0)
if bg_candidates.size:
    bg_index = int(bg_candidates[0])
    class_weights[bg_index] = BG_CLASS_WEIGHT
else:
    # No pure-black pixels in any mask: fall back to uniform weights instead of an IndexError.
    bg_index = None
    print("Warning: background value 0 not found in masks; using uniform class weights.")

In [None]:
# Model input shape = (height, width, channels) of one patch.
input_shape = tuple(unified_images.shape[axis] for axis in (1, 2, 3))


In [None]:
import tensorflow as tf
from tensorflow.keras.callbacks import (
    EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
)

# Turn off two Grappler passes (layout optimizer, model pruner) for this session.
tf.config.optimizer.set_experimental_options({"layout_optimizer": False, "model_pruner": False})

# 8) Callbacks, all driven by the total validation loss:
#    * halve the learning rate after 10 stagnant epochs (floor 1e-6)
#    * stop after 15 stagnant epochs and roll back to the best weights
#    * keep only the best checkpoint on disk
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=10, min_lr=1e-6, verbose=1)
early_stop = EarlyStopping(monitor='val_loss', mode='min', patience=15,
                           restore_best_weights=True, verbose=1)
checkpoint = ModelCheckpoint(
    '/kaggle/working/Binary_model_ER_IHC.keras',
    monitor='val_loss',
    save_best_only=True,
    verbose=1,
)

In [None]:
import numpy as np
from collections import Counter
from sklearn.model_selection import train_test_split


# 5. Split by *region*, not by patch.
# unified_images stacks three scales (large | medium | small), each block holding
# num_samples patches. The medium/small patch at index i is a centred crop of the
# large patch at index i. A plain random split puts views of one region into both
# train and validation, which leaks training content into the validation score.
# Here all scales of a region land on the same side. (The old comment called the
# split "stratified"; it never was.)
VAL_FRACTION = 0.2
SPLIT_SEED = 42

n_total = unified_images.shape[0]
assert n_total == 3 * num_samples, (
    f"expected 3 scales x {num_samples} patches, got {n_total} samples")

region_of_sample = np.arange(n_total) % num_samples
split_rng = np.random.default_rng(SPLIT_SEED)
n_val_regions = int(np.ceil(VAL_FRACTION * num_samples))  # ceil, like train_test_split's test_size
val_regions = split_rng.permutation(num_samples)[:n_val_regions]
is_val = np.isin(region_of_sample, val_regions)

train_idx = split_rng.permutation(np.flatnonzero(~is_val))  # shuffled so scales are mixed
val_idx = np.flatnonzero(is_val)

X_train, X_val = unified_images[train_idx], unified_images[val_idx]
y_train_mask, y_val_mask = train_mask_cat[train_idx], train_mask_cat[val_idx]
y_train_heatmap, y_val_heatmap = unified_heatmaps[train_idx], unified_heatmaps[val_idx]

# 6. Verify shapes
print("X_train shape:", X_train.shape)
print("X_val   shape:", X_val.shape)
print("y_train mask shape:", y_train_mask.shape)
print("y_val   mask shape:", y_val_mask.shape)
print("y_train heatmap shape:", y_train_heatmap.shape)
print("y_val   heatmap shape:", y_val_heatmap.shape)



In [None]:
import tensorflow as tf

# Use the Kaggle TPU ('local' resolver) when one is attached. A ValueError from
# resolving or initialising it falls back to the default single-device strategy;
# any other exception type propagates.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect(tpu='local')
    strategy = tf.distribute.TPUStrategy(tpu)
except ValueError as err:
    print("TPU not found, using default strategy (CPU/GPU)")
    print("Error message:", str(err))
    strategy = tf.distribute.get_strategy()

print("Number of accelerators: ", strategy.num_replicas_in_sync)

In [None]:
import tensorflow as tf

GLOBAL_BS = 64   # global batch, split across strategy.num_replicas_in_sync (8 TPU cores -> 8 per core)


def ds_from_numpy(X, y_mask, y_heat, shuffle=True, drop_remainder=True):
    """Build a batched tf.data pipeline yielding ``(image, {'NS': mask, 'CM': heatmap})``.

    :param X: Input images.
    :param y_mask: One-hot targets for the 'NS' output.
    :param y_heat: Centroid-map targets for the 'CM' output.
    :param shuffle: Shuffle with a 2048-element buffer, reshuffled every epoch.
        Pass ``False`` for validation so every epoch scores the same samples.
    :param drop_remainder: Drop the final partial batch to keep batch shapes fixed
        (important on TPU). Up to ``GLOBAL_BS - 1`` samples are then never seen.
    :return: A prefetched ``tf.data.Dataset``.
    """
    ds = tf.data.Dataset.from_tensor_slices((X, {'NS': y_mask, 'CM': y_heat}))
    if shuffle:
        ds = ds.shuffle(2048, reshuffle_each_iteration=True)
    ds = ds.batch(GLOBAL_BS, drop_remainder=drop_remainder)
    ds = ds.prefetch(tf.data.AUTOTUNE)
    return ds


train_ds = ds_from_numpy(X_train, y_train_mask, y_train_heatmap)
# Validation is not shuffled. With drop_remainder=True, a reshuffled validation set
# would drop a different subset every epoch and add noise to val_loss. val_loss drives
# ReduceLROnPlateau, EarlyStopping and ModelCheckpoint.
val_ds = ds_from_numpy(X_val, y_val_mask, y_val_heatmap, shuffle=False)


In [None]:
# Per-class weights as a TF constant, shape [C]; built from the NumPy class_weights.
class_weights_tf = tf.constant(class_weights, dtype=tf.float32)


@tf.keras.utils.register_keras_serializable()
def weighted_cce(y_true, y_pred):
    """Class-weighted categorical cross-entropy, normalised by the total weight.

    Each pixel's cross-entropy is scaled by the weight of its true class. The result
    is ``sum(w * ce) / sum(w)``, a weighted mean rather than a plain mean.

    y_true is one-hot; both tensors are [B, H, W, C].
    """
    per_pixel_ce = tf.keras.losses.categorical_crossentropy(y_true, y_pred)  # [B, H, W]
    broadcast_w = class_weights_tf[tf.newaxis, tf.newaxis, tf.newaxis, :]    # [1, 1, 1, C]
    true_class_w = tf.reduce_sum(broadcast_w * y_true, axis=-1)              # [B, H, W]
    return tf.math.divide(tf.reduce_sum(true_class_w * per_pixel_ce), tf.reduce_sum(true_class_w))


In [None]:
from tensorflow.keras.layers import Resizing, Conv2D, BatchNormalization, Activation, MaxPooling2D, UpSampling2D, Concatenate, Add, Layer, Input, Dense, LayerNormalization, SpatialDropout2D
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
import tensorflow as tf
# import tensorflow_addons as tfa
import math
from tensorflow.keras import mixed_precision
# Keep every layer in full float32 (mixed precision is explicitly not used).
mixed_precision.set_global_policy('float32')
# Custom rotation layer
class RotationLayer(Layer):
    """Rotate NHWC feature maps by ``k`` quarter turns using ``tf.image.rot90``.

    Args:
        k: Number of 90-degree turns; reduced modulo 4. ``k % 4 == 0`` is the
            identity. Default 1.
        **kwargs: Standard ``Layer`` arguments (``name``, ``trainable``, ...).
    """

    def __init__(self, k=1, **kwargs):
        super(RotationLayer, self).__init__(**kwargs)
        self.k = k % 4  # ensure valid rotation value

    def call(self, inputs):
        # Skip rotation if k == 0 (no rotation)
        if self.k == 0:
            return inputs
        return tf.image.rot90(inputs, k=self.k)

    def get_config(self):
        # Needed so models saved by ModelCheckpoint (.keras) can be rebuilt on load.
        config = super().get_config()
        config.update({"k": self.k})
        return config



# Spiral fusion block: concatenate x with three rotated projections of the newest
# feature in level_features (90°, 180°, 270°).
def rotation_block(x, level_features, filters):
    """Concatenate ``x`` with rotated, projected copies of ``level_features[-1]``.

    Only the last entry of ``level_features`` is used. It is rotated by k = 1, 2, 3
    quarter turns (``RotationLayer``). Each rotation is resized to ``x``'s spatial
    size when the shapes differ, then projected to ``filters`` channels by a 3x3
    ``Conv2D``. The unrotated ``level_features[-1]`` itself is not included; only
    ``x`` is kept in its original orientation.

    :return: ``Concatenate()([x, rot90, rot180, rot270])``.
    """
    source = level_features[-1]
    target_h, target_w = x.shape[1], x.shape[2]

    branches = [x]
    for quarter_turns in (1, 2, 3):
        branch = RotationLayer(k=quarter_turns)(source)
        if branch.shape[1:3] != x.shape[1:3]:
            branch = Resizing(target_h, target_w)(branch)
        branches.append(Conv2D(filters, kernel_size=3, padding='same')(branch))

    return Concatenate()(branches)




# Standard conv block
def conv_block(x, filters, kernel_size=3, activation='relu', drop_rate=0.0, spatial=True):
    """Two (Conv2D -> BatchNormalization -> activation) stages, optional dropout between.

    :param x: Input tensor.
    :param filters: Filters for both convolutions ('same' padding).
    :param kernel_size: Convolution kernel size.
    :param activation: Activation applied after each BatchNormalization.
    :param drop_rate: Dropout rate after the first stage; 0/None disables it.
    :param spatial: Use ``SpatialDropout2D`` when True, plain ``Dropout`` otherwise.
    :return: Output tensor with ``filters`` channels.
    """
    for stage in range(2):
        x = Conv2D(filters, kernel_size, padding='same')(x)
        x = BatchNormalization()(x)
        x = Activation(activation)(x)
        if stage == 0 and drop_rate and drop_rate > 0:
            dropout_layer = SpatialDropout2D if spatial else Dropout
            x = dropout_layer(drop_rate)(x)
    return x
# Transformer-based integration module
class TransformerModule(Layer):
    """Post-norm transformer encoder block applied to an NHWC feature map.

    The H*W positions are flattened into a token sequence. The sequence goes through
    multi-head self-attention and then a two-layer feed-forward network; each sub-block
    has a residual connection followed by LayerNormalization. The result is reshaped
    back to (B, H, W, C). ``key_dim * num_heads`` must equal C, because the FFN output
    is added onto the C-channel residual.

    Args:
        num_heads: Number of attention heads.
        key_dim: Size of each attention head.
        ff_dim: Hidden width of the feed-forward network.
        dropout: Dropout rate used inside the FFN and on both residual branches.
        **kwargs: Standard ``Layer`` arguments (``name``, ``trainable``, ``dtype``).
            Accepting them lets Keras rebuild the layer from ``get_config()`` when a
            saved ``.keras`` checkpoint is loaded.
    """

    def __init__(self, num_heads, key_dim, ff_dim, dropout=0.1, **kwargs):
        super(TransformerModule, self).__init__(**kwargs)
        # Hyper-parameters kept for get_config().
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.ff_dim = ff_dim
        self.dropout_rate = dropout

        self.mha = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=key_dim)
        self.ffn = tf.keras.Sequential([
            tf.keras.layers.Dense(ff_dim, activation='relu'),
            tf.keras.layers.Dense(key_dim * num_heads),
            tf.keras.layers.Dropout(dropout)
        ])
        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.dropout = tf.keras.layers.Dropout(dropout)

    def call(self, inputs):
        batch_size, H, W, C = tf.shape(inputs)[0], tf.shape(inputs)[1], tf.shape(inputs)[2], inputs.shape[-1]
        x = tf.reshape(inputs, [batch_size, H * W, C])
        attn = self.mha(x, x)
        attn = self.dropout(attn)
        out1 = self.layernorm1(x + attn)
        ffn = self.ffn(out1)
        # NOTE(review): the FFN output passes through two Dropout layers (the last
        # layer of self.ffn and self.dropout), so its effective drop rate exceeds
        # `dropout`. Kept as-is to preserve the trained architecture.
        ffn = self.dropout(ffn)
        out2 = self.layernorm2(out1 + ffn)
        return tf.reshape(out2, [batch_size, H, W, C])

    def get_config(self):
        config = super().get_config()
        config.update({
            'num_heads': self.num_heads,
            'key_dim': self.key_dim,
            'ff_dim': self.ff_dim,
            'dropout': self.dropout_rate,
        })
        return config

# Fuse multi-scale features and apply transformer
def multi_scale_integration_hub(features, filters):
    """Fuse feature maps at the resolution of the last one, then apply self-attention.

    Each map is resized to ``features[-1]``'s spatial size and projected to ``filters``
    channels by a 1x1 conv. The projections are summed, batch-normalised and passed
    through a ``TransformerModule`` (8 heads, ``key_dim = filters // 8``,
    ``ff_dim = filters * 8``; the transformer's residual add presumably needs
    ``filters`` divisible by 8). A ``conv_block`` with ``filters * 3`` channels
    finishes the block.
    """
    deepest = features[-1]
    to_common_size = Resizing(deepest.shape[1], deepest.shape[2])

    projected = []
    for feature_map in features:
        resized = to_common_size(feature_map)
        projected.append(Conv2D(filters, kernel_size=1, padding='same')(resized))

    fused = Add()(projected)
    fused = BatchNormalization()(fused)

    n_heads = 8
    attended = TransformerModule(n_heads, filters // n_heads, filters * 8)(fused)
    return conv_block(attended, filters * 3)




#RoMT-Net: A Rotation-aware Multi-Scale Transformer Network for Histopathological Segmentation and Nuclei Localization
def CRAFTNet(input_shape=input_shape, num_classes=numclasses):
    """Build the two-headed segmentation / centroid-map model.

    The model has three parts:
      * Encoder: a 32-filter stem ``conv_block``, then four stages of
        max-pool -> ``rotation_block`` -> ``conv_block``.
      * Bottleneck: ``multi_scale_integration_hub`` over all encoder levels.
      * Decoder: four stages of upsample -> ``rotation_block`` -> ``conv_block``.

    The defaults for ``input_shape`` and ``num_classes`` are the notebook globals
    captured when this function is defined.

    :return: ``Model`` with outputs ``NS`` (softmax over ``num_classes``) and
        ``CM`` (1-channel sigmoid).
    """
    # (rotation filters, conv filters, dropout) for encoder stages 1..4
    encoder_plan = ((32, 64, 0.2), (64, 128, 0.3), (128, 256, 0.4), (256, 512, 0.5))
    # (encoder levels passed to rotation_block, filters, dropout) for decoder stages
    decoder_plan = ((4, 256, 0.5), (3, 128, 0.1), (2, 64, 0.2), (1, 32, 0.3))

    inputs = Input(input_shape)
    x = conv_block(inputs, 32, drop_rate=0.1)
    skips = [x]

    # Encoder; the output of the last stage is the bottleneck input, not a skip.
    for stage, (rot_filters, conv_filters, rate) in enumerate(encoder_plan, start=1):
        x = MaxPooling2D(2)(x)
        x = rotation_block(x, skips[:stage], rot_filters)
        x = conv_block(x, conv_filters, drop_rate=rate)
        if stage < len(encoder_plan):
            skips.append(x)

    # Multi-scale integration
    x = multi_scale_integration_hub(skips + [x], 512)

    # Decoder
    for depth, filters, rate in decoder_plan:
        x = UpSampling2D(2)(x)
        x = rotation_block(x, skips[:depth], filters)
        x = conv_block(x, filters, drop_rate=rate)

    # Outputs
    NS_out = Conv2D(num_classes, 1, activation='softmax', name='NS')(x)
    CM_out = Conv2D(1, 1, activation='sigmoid', name='CM')(x)
    return Model(inputs=inputs, outputs=[NS_out, CM_out])

# Build & compile inside the distribution strategy scope so that the model and the
# optimizer variables are created under it.
with strategy.scope():
    model = CRAFTNet(input_shape=input_shape, num_classes=numclasses)
    head_losses = {'NS': weighted_cce, 'CM': 'binary_crossentropy'}
    head_loss_weights = {'NS': 1.0, 'CM': 1.0}
    head_metrics = {'NS': 'accuracy', 'CM': ['mse']}
    model.compile(optimizer=Adam(), loss=head_losses,
                  loss_weights=head_loss_weights, metrics=head_metrics)

model.summary()


In [None]:
# Train for up to 350 epochs; EarlyStopping normally ends the run sooner.
history = model.fit(
    train_ds,
    epochs=350,
    validation_data=val_ds,
    callbacks=[reduce_lr, early_stop, checkpoint],
)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

h = history.history
epochs = range(1, len(h['loss']) + 1)

# --- Pull the right series (Keras logs per-output entries as '<output>_<metric>') ---
ns_tr_loss   = h.get('NS_loss', [])
ns_va_loss   = h.get('val_NS_loss', [])
ns_tr_acc    = h.get('NS_accuracy', [])
ns_va_acc    = h.get('val_NS_accuracy', [])

cm_tr_loss   = h.get('CM_loss', [])
cm_va_loss   = h.get('val_CM_loss', [])
cm_tr_mse    = h.get('CM_mse', [])
cm_va_mse    = h.get('val_CM_mse', [])

tot_tr_loss  = h['loss']
tot_va_loss  = h['val_loss']


def _best_epoch(series, pick):
    """Return the 1-based epoch chosen by ``pick`` (np.argmin / np.argmax), or None if empty."""
    return int(pick(series)) + 1 if len(series) else None


# --- Best epochs ---
best_tot_epoch  = _best_epoch(tot_va_loss, np.argmin)
best_ns_acc_ep  = _best_epoch(ns_va_acc, np.argmax)
best_ns_loss_ep = _best_epoch(ns_va_loss, np.argmin)
best_cm_loss_ep = _best_epoch(cm_va_loss, np.argmin)
best_cm_mse_ep  = _best_epoch(cm_va_mse, np.argmin)


def maybe_vline(ep, label, ax=None):
    """Draw a dashed green marker at epoch ``ep`` (no-op when ``ep`` is None)."""
    if ep is not None:
        target = plt if ax is None else ax
        target.axvline(ep, color='g', linestyle='--', label=f'{label} (Epoch {ep})')


def plot_train_val(train_vals, val_vals, best_ep, best_label, title, ylabel, train_label, val_label):
    """Plot one training/validation curve pair; skip it when either series was not logged."""
    if not (len(train_vals) and len(val_vals)):
        return
    fig, ax = plt.subplots(figsize=(9, 6))
    ax.plot(epochs, train_vals, label=train_label)
    ax.plot(epochs, val_vals, label=val_label)
    maybe_vline(best_ep, best_label, ax=ax)
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.legend()
    ax.grid(True)
    fig.tight_layout()
    plt.show()


# --- NS (segmentation) Accuracy ---
plot_train_val(ns_tr_acc, ns_va_acc, best_ns_acc_ep, 'Best NS Val Acc',
               'NS (Segmentation) Accuracy', 'Accuracy',
               'NS Training Accuracy', 'NS Validation Accuracy')

# --- NS (segmentation) Loss --- (the NS head is trained with weighted_cce)
plot_train_val(ns_tr_loss, ns_va_loss, best_ns_loss_ep, 'Lowest NS Val Loss',
               'NS (Segmentation) Loss (class-weighted categorical crossentropy)', 'Loss',
               'NS Training Loss', 'NS Validation Loss')

# --- CM (centroid/heatmap) Loss ---
plot_train_val(cm_tr_loss, cm_va_loss, best_cm_loss_ep, 'Lowest CM Val Loss',
               'CM (Centroid Map) Loss (binary_crossentropy)', 'Loss',
               'CM Training Loss', 'CM Validation Loss')

# --- CM (centroid/heatmap) MSE ---
plot_train_val(cm_tr_mse, cm_va_mse, best_cm_mse_ep, 'Lowest CM Val MSE',
               'CM (Centroid Map) Mean Squared Error', 'MSE',
               'CM Training MSE', 'CM Validation MSE')

# --- Overall (weighted) model loss ---
# The model is compiled with loss_weights {'NS': 1.0, 'CM': 1.0}, so total = NS + CM
# (the old "0.8×CM" title was wrong). Update this title if the loss weights change.
plot_train_val(tot_tr_loss, tot_va_loss, best_tot_epoch, 'Lowest Total Val Loss',
               'Overall Model Loss (NS + CM)', 'Loss',
               'Total Training Loss', 'Total Validation Loss')


In [None]:
# import tensorflow as tf
# from tensorflow.keras.layers import Layer
# @tf.keras.utils.register_keras_serializable()
# def weighted_cce(y_true, y_pred):
#     w = tf.reshape(class_weights_tf, (1, 1, 1, -1))
#     pix_w = tf.reduce_sum(y_true * w, axis=-1)
#     ce = tf.keras.losses.categorical_crossentropy(y_true, y_pred)
#     loss = tf.reduce_sum(ce * pix_w) / tf.reduce_sum(pix_w)
#     return loss


# @tf.keras.utils.register_keras_serializable()
# class RotationLayer(Layer):
#     def __init__(self, k=0, **kwargs):
#         super(RotationLayer, self).__init__(**kwargs)
#         self.k = k

#     def call(self, inputs):
#         return tf.image.rot90(inputs, k=self.k)

#     def get_config(self):
#         config = super().get_config()
#         config.update({'k': self.k})
#         return config

# @tf.keras.utils.register_keras_serializable()
# class TransformerModule(Layer):
#     def __init__(self, num_heads, key_dim, ff_dim, dropout=0.1, **kwargs):
#         super(TransformerModule, self).__init__(**kwargs)
#         self.num_heads = num_heads
#         self.key_dim = key_dim
#         self.ff_dim = ff_dim
#         self.dropout_rate = dropout

#         self.mha = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=key_dim)
#         self.ffn = tf.keras.Sequential([
#             tf.keras.layers.Dense(ff_dim, activation='relu'),
#             tf.keras.layers.Dense(key_dim * num_heads),
#             tf.keras.layers.Dropout(dropout)
#         ])
#         self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
#         self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
#         self.dropout = tf.keras.layers.Dropout(dropout)

#     def call(self, inputs):
#         batch_size, H, W, C = tf.shape(inputs)[0], tf.shape(inputs)[1], tf.shape(inputs)[2], inputs.shape[-1]
#         x = tf.reshape(inputs, [batch_size, H * W, C])
#         attn = self.mha(x, x)
#         attn = self.dropout(attn)
#         out1 = self.layernorm1(x + attn)
#         ffn = self.ffn(out1)
#         ffn = self.dropout(ffn)
#         out2 = self.layernorm2(out1 + ffn)
#         return tf.reshape(out2, [batch_size, H, W, C])

#     def get_config(self):
#         config = super().get_config()
#         config.update({
#             'num_heads': self.num_heads,
#             'key_dim': self.key_dim,
#             'ff_dim': self.ff_dim,
#             'dropout': self.dropout_rate
#         })
#         return config


# model = load_model(
#     '/kaggle/input/deletethismodel/other/default/1/Binary_model_ER_IHC (2).keras',
#     custom_objects={
#         'RotationLayer': RotationLayer,
#         'TransformerModule': TransformerModule,
#         'weighted_cce': weighted_cce,   # optional if registered, but safe
#     },
#     compile=True
# )

# # # Export to a TF SavedModel for broad TF compatibility
# # model.export("MOSAIC_savedmodel", format="tf_saved_model")
# # # (or also create an H5 if your layers are compatible)
# # model.save("MOSAIC.h5")

In [None]:
from skimage.feature import peak_local_max
from skimage.segmentation import watershed
from scipy.ndimage import distance_transform_edt
import numpy as np
import matplotlib.pyplot as plt
import cv2
import random
from skimage.measure import label, regionprops
from scipy.stats import mode
# ER IHC palette: grayscale class value found in the label masks -> RGB colour
# used when rendering masks and contours (insertion order kept as before).
COLOR_MAPPING = {
    0:   [  0,   0,   0],  # Black  (Background)
    122: [  0, 159, 255],  # Blue   (Normal)
    150: [  0, 255,   0],  # Green  (Weak)
    203: [255, 216,   0],  # Yellow (Moderate)
    76:  [255,   0,   0],  # Red    (Strong)
}

# #For MonuSac
# COLOR_MAPPING = {
#     0:   [  0,   0,   0],  # Black   (Background)
#     29:  [  0,   0, 255],  # Blue    (Neutrophils)
#     76:  [255,   0,   0],  # Red     (Epithelial)
#     150: [  0, 255,   0],  # Green   (Macrophages)
#     179: [  0, 255, 255],  # Cyan    (Moderate)
# }

# # === Final 4-Class Color Mapping ===
# COLOR_MAPPING = {
#     0:   [  0,   0,   0],     # Background - Black
#     76:  [255,   0,   0],     # Epithelial - Red
#     104: [  0, 128, 255],     # Inflammatory - Blue
#     151: [255, 128,   0],     # Spindle-shaped - Orange
# }

def separate_nuclei(seg_mask, heatmap_pred, min_distance=5, threshold_abs=0.3):
    """Split touching nuclei with a marker-controlled watershed.

    Local maxima of the predicted centroid heatmap become watershed seeds. The
    watershed floods the negated Euclidean distance transform of the
    foreground, restricted to the foreground, so each seed claims the nucleus
    pixels around it.

    Args:
        seg_mask: 2-D array; every non-zero pixel is treated as nucleus
            foreground (callers pass an argmax class map).
        heatmap_pred: 2-D single-channel centroid heatmap with the same
            height/width as ``seg_mask`` (squeeze any channel axis first).
        min_distance: Minimum pixel separation between two detected peaks.
        threshold_abs: Minimum heatmap value for a pixel to count as a peak.

    Returns:
        Integer label array shaped like ``seg_mask``: 0 is background and
        1..K are the K instances seeded from the detected peaks.

    Raises:
        ValueError: if the heatmap and mask shapes differ.
    """
    foreground = np.asarray(seg_mask) != 0
    heatmap = np.asarray(heatmap_pred)
    if heatmap.shape != foreground.shape:
        raise ValueError(
            f"heatmap_pred shape {heatmap.shape} does not match seg_mask shape {foreground.shape}"
        )

    # Find peaks in the single-channel heatmap -> (num_peaks, 2) row/col coordinates.
    peaks = peak_local_max(heatmap, min_distance=min_distance, threshold_abs=threshold_abs)

    # Explicit int32 markers: np.zeros_like(seg_mask) would inherit the mask's
    # dtype, collapsing all labels to 1 for a bool mask or wrapping past 255
    # instances for a uint8 mask.
    markers = np.zeros(foreground.shape, dtype=np.int32)
    if len(peaks):
        markers[tuple(peaks.T)] = np.arange(1, len(peaks) + 1)  # unique label per nucleus

    # Watershed on the negated distance map so basins sit at nucleus centres.
    distance_map = distance_transform_edt(foreground)
    labels = watershed(-distance_map, markers, mask=foreground)

    return labels


def apply_fixed_colormap(mask, color_mapping=None):
    """Map grayscale mask values to RGB colours using a fixed palette.

    Args:
        mask: 2-D array of grayscale class values.
        color_mapping: Optional dict of ``gray_value -> [R, G, B]``. Defaults
            to the module-level ``COLOR_MAPPING``. Pass one of the alternative
            dataset palettes to reuse this helper without editing the global.

    Returns:
        ``uint8`` array of shape ``mask.shape + (3,)``. Pixels whose value is
        not a key of the palette stay black.
    """
    if color_mapping is None:
        color_mapping = COLOR_MAPPING
    mask = np.asarray(mask)
    colored = np.zeros((*mask.shape, 3), dtype=np.uint8)

    for gray_value, rgb_color in color_mapping.items():
        colored[mask == gray_value] = rgb_color

    return colored


def refine_segmentation(pred_mask):
    """Assign one class per nucleus using majority voting.

    Connected components of the foreground (``pred_mask > 0``) are found with
    ``skimage.measure.label``. Every pixel of a component is overwritten with
    that component's most frequent value (``scipy.stats.mode``). Background
    pixels stay 0.

    Args:
        pred_mask: 2-D class map (grayscale class values, 0 = background).

    Returns:
        Array with the same shape and dtype as ``pred_mask``.
    """
    # Local, aliased import so this function does not depend on the global
    # name ``label``, which notebook code may rebind (e.g. as a loop variable).
    # Calling a rebound integer would raise "object is not callable".
    from skimage.measure import label as _sk_label, regionprops
    from scipy.stats import mode

    refined_mask = np.zeros_like(pred_mask)
    labeled_mask = _sk_label(pred_mask > 0)  # nuclei are non-zero classes

    for region in regionprops(labeled_mask):
        rows, cols = region.coords[:, 0], region.coords[:, 1]
        majority_class = mode(pred_mask[rows, cols], axis=None).mode.item()
        refined_mask[rows, cols] = majority_class

    return refined_mask
    


# Pick one random image from the unified multi-scale set.
# NOTE(review): this indexes the same arrays as `train_mask_cat`, i.e. the
# training data, not a held-out validation split.
test_img_number = random.randint(0, len(unified_images) - 1)

# The unified set is assumed to be three consecutive blocks of `num_samples`
# images at 20X, 40X and 80X. Half-open ranges so index 2*num_samples is
# labelled 80X (the old inclusive `<=` bound mislabelled it as 40X).
if 0 <= test_img_number < num_samples:
    scale = "20X scale"
elif num_samples <= test_img_number < 2 * num_samples:
    scale = "40X scale"
elif 2 * num_samples <= test_img_number < 3 * num_samples:
    scale = "80X scale"
else:
    scale = "scale"

test_img = unified_images[test_img_number]
ground_truth_seg = train_mask_cat[test_img_number]
ground_truth_heatmap = unified_heatmaps[test_img_number]


def _decode_classes(class_scores):
    """Argmax over the last (class) axis, then map encoder indices back to the
    original mask gray values via ``mask_label_encoder``."""
    encoded = np.argmax(class_scores, axis=-1)
    return mask_label_encoder.inverse_transform(encoded.ravel()).reshape(encoded.shape)


def _draw_class_contours(canvas, binary_source, class_map, thickness=2):
    """Draw every external contour of ``binary_source`` onto ``canvas`` (in place).

    Each contour is coloured with ``COLOR_MAPPING`` using the class found in
    ``class_map`` at the contour's centroid. Zero-area contours and contours
    whose centroid falls outside ``class_map`` are skipped. Classes missing
    from the palette are drawn black. Returns ``canvas``.
    """
    contours, _ = cv2.findContours(binary_source, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    height, width = class_map.shape[:2]
    for contour in contours:
        moments = cv2.moments(contour)
        if moments["m00"] == 0:
            continue
        cx = int(moments["m10"] / moments["m00"])
        cy = int(moments["m01"] / moments["m00"])
        if 0 <= cy < height and 0 <= cx < width:
            color = tuple(COLOR_MAPPING.get(class_map[cy, cx], (0, 0, 0)))
            cv2.drawContours(canvas, [contour], -1, color, thickness)
    return canvas


# Predict on a batch of one. The model has two outputs:
# [0] per-pixel class scores, [1] centroid heatmap.
prediction = model.predict(np.expand_dims(test_img, axis=0))
seg_pred = prediction[0][0]      # segmentation output
heatmap_pred = prediction[1][0]  # single-channel heatmap output

ground_truth_class = _decode_classes(ground_truth_seg)
predicted_class = refine_segmentation(_decode_classes(seg_pred))

# Assumes test_img is float in [0, 1] in BGR order (cv2.imread); matplotlib and
# the RGB palette need RGB.
test_img_display = cv2.cvtColor((test_img * 255).astype(np.uint8), cv2.COLOR_BGR2RGB)

# Coarse contours: one per connected class region.
contoured_image = _draw_class_contours(
    test_img_display.copy(), predicted_class.astype(np.uint8), predicted_class
)

# Instance separation from the segmentation + centroid heatmap.
separated_instances = separate_nuclei(
    seg_mask=np.argmax(seg_pred, axis=-1),
    heatmap_pred=np.squeeze(heatmap_pred),
    min_distance=5,
)

# Class-coloured image with a 1-px black boundary around every instance.
# `inst_label` (not `label`) so skimage.measure.label is not shadowed.
colored_instances = apply_fixed_colormap(predicted_class)
for inst_label in np.unique(separated_instances):
    if inst_label == 0:  # skip background
        continue
    inst_mask = np.uint8(separated_instances == inst_label)
    inst_contours, _ = cv2.findContours(inst_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(colored_instances, inst_contours, -1, (0, 0, 0), 1)  # black boundaries

# Instance contours: the black boundaries split touching nuclei, so contours of
# the non-black pixels follow individual instances.
colored_gray = cv2.cvtColor(colored_instances, cv2.COLOR_RGB2GRAY)  # image is RGB
contoured_from_colored = _draw_class_contours(
    test_img_display.copy(), colored_gray, predicted_class
)

# 3x3 figure: inputs/segmentation, contours/heatmaps, instance separation.
panels = [
    ('Testing Image ', test_img_display, None),
    ('Segmentation Ground Truth ', apply_fixed_colormap(ground_truth_class), None),
    ('Predicted Segmentation ', apply_fixed_colormap(predicted_class), None),
    ('Coarse Contours on Image ', contoured_image, None),
    ('Centroid Ground Truth ', np.squeeze(ground_truth_heatmap), 'gray'),
    ('Predicted centroids ', np.squeeze(heatmap_pred), 'gray'),
    ('Watershed Labels (For Debug) ', separated_instances, 'nipy_spectral'),
    ('Separated Instances with Class Colors ', colored_instances, None),
    ('Instance Contours on Image ', contoured_from_colored, None),
]
fig, axes = plt.subplots(3, 3, figsize=(24, 18))
for ax, (title, image, cmap) in zip(axes.flat, panels):
    ax.set_title(title + scale)
    ax.imshow(image, cmap=cmap)
    ax.axis('off')
fig.tight_layout()
plt.show()

# Diagnostic information
print(f"Displaying results for sample {test_img_number}")
print("Unique labels in ground truth:", np.unique(ground_truth_class))
print("Unique labels in prediction:", np.unique(predicted_class))
print("Number of detected instances:", len(np.unique(separated_instances)) - 1)

**Batch export: save only the final instance-contour overlay (512×512 PNG) for every image**

In [None]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from skimage.measure import label as sk_label, regionprops
from scipy.stats import mode   # make sure this is imported

def refine_segmentation(pred_mask):
    """Collapse each connected nucleus to a single class by majority vote.

    Foreground components (``pred_mask > 0``) are labelled with ``sk_label``.
    Every pixel of a component receives that component's most frequent value
    according to ``scipy.stats.mode``. Background stays 0, and the result keeps
    the shape and dtype of ``pred_mask``.
    """
    output = np.zeros_like(pred_mask)
    components = sk_label(pred_mask > 0)   # aliased skimage label, immune to `label` shadowing
    for component in regionprops(components):
        rr, cc = component.coords.T
        winner = mode(pred_mask[rr, cc], axis=None).mode.item()
        output[rr, cc] = winner
    return output

# === Output location on Kaggle ===
# Created up front so the export loop can write into it; exist_ok makes
# re-running this cell harmless.
out_dir = "/kaggle/working/romtnet_ER_Segmentation"
os.makedirs(name=out_dir, exist_ok=True)

def determine_scale(idx, num_samples):
    """Return the magnification label for position ``idx`` in the unified set.

    ``[0, num_samples)`` -> "20X scale", ``[num_samples, 2*num_samples)`` ->
    "40X scale"; every other index (including negative ones) falls through to
    "80X scale".
    """
    in_first_block = 0 <= idx < num_samples
    if in_first_block:
        return "20X scale"
    in_second_block = num_samples <= idx < 2 * num_samples
    return "40X scale" if in_second_block else "80X scale"

# Total number of images to export (denominator of the progress counter).
total_images = len(unified_images)

# NOTE: relies on separate_nuclei, apply_fixed_colormap, COLOR_MAPPING, model
# and mask_label_encoder from earlier cells. Ground truth is not needed for the
# export, so only the images are iterated. The old zip() with the masks and
# heatmaps decoded GT it never used, and it silently truncated if lengths differed.
for idx, test_img in enumerate(unified_images):
    scale = determine_scale(idx, num_samples)

    # Single-image inference. Output [0] is per-pixel class scores; output [1]
    # is the centroid heatmap.
    prediction = model.predict(np.expand_dims(test_img, axis=0), verbose=0)
    seg_pred = prediction[0][0]        # segmentation output [H, W, C]
    heatmap_pred = prediction[1][0]    # heatmap output [H, W, 1] or [H, W]

    # Class scores -> encoder index -> original mask gray value, then one class
    # per nucleus by majority vote.
    pred_encoded = np.argmax(seg_pred, axis=-1)
    predicted_class = refine_segmentation(
        mask_label_encoder.inverse_transform(pred_encoded.ravel()).reshape(pred_encoded.shape)
    )

    # Presumably BGR in [0, 1] (images are read with cv2.imread). Convert to RGB
    # so the RGB palette colours are drawn as intended.
    test_img_display = cv2.cvtColor((test_img * 255).astype(np.uint8), cv2.COLOR_BGR2RGB)

    # Instance separation from seg + centroid heatmap (reuses the argmax above).
    separated_instances = separate_nuclei(
        seg_mask=pred_encoded,
        heatmap_pred=np.squeeze(heatmap_pred),
        min_distance=5,
    )

    # Class colours with a 1-px black boundary around each instance.
    colored_instances = apply_fixed_colormap(predicted_class)
    for inst_label in np.unique(separated_instances):
        if inst_label == 0:
            continue
        inst_mask = np.uint8(separated_instances == inst_label)
        inst_contours, _ = cv2.findContours(inst_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        cv2.drawContours(colored_instances, inst_contours, -1, (0, 0, 0), 1)

    # Instance contours on the image, coloured by the class at each centroid.
    contoured_from_colored = test_img_display.copy()
    colored_gray = cv2.cvtColor(colored_instances, cv2.COLOR_RGB2GRAY)  # image is RGB
    contours2, _ = cv2.findContours(colored_gray, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    height, width = predicted_class.shape[:2]
    for contour in contours2:
        M = cv2.moments(contour)
        if M["m00"] == 0:
            continue
        cx = int(M["m10"] / M["m00"])
        cy = int(M["m01"] / M["m00"])
        if 0 <= cy < height and 0 <= cx < width:
            color = tuple(COLOR_MAPPING.get(predicted_class[cy, cx], (0, 0, 0)))
            cv2.drawContours(contoured_from_colored, [contour], -1, color, 2)

    # ===== Save ONLY contoured_from_colored as a 512x512x3 PNG =====
    # The overlay is RGB; cv2.imwrite expects BGR.
    cfc_resized = cv2.resize(contoured_from_colored, (512, 512), interpolation=cv2.INTER_LINEAR)
    cfc_bgr = cv2.cvtColor(cfc_resized, cv2.COLOR_RGB2BGR)
    out_img_path = os.path.join(out_dir, f"contoured_{idx:05d}_{scale.replace(' ', '')}.png")
    # cv2.imwrite reports failure only through its return value. Check it so
    # that a bad path or a full disk does not end with a misleading "Saved" message.
    if not cv2.imwrite(out_img_path, cfc_bgr):
        raise OSError(f"cv2.imwrite failed for {out_img_path}")

    # === Print progress ===
    print(f"Processed {idx+1}/{total_images} images", end='\r', flush=True)

print(f"\nSaved PNGs to: {out_dir}")

**Zip the exported PNGs for download**

In [None]:
# Package the exported PNGs into ./ERReults.zip (archive name kept unchanged).
# The standard library avoids the `apt-get install zip` step, which typically
# needs network/package-manager access. Paths are stored relative to
# /kaggle/working instead of carrying the absolute kaggle/working/... prefix.
import shutil

archive_path = shutil.make_archive(
    base_name="ERReults",
    format="zip",
    root_dir="/kaggle/working",
    base_dir="romtnet_ER_Segmentation",
)
print(f"Archive written to: {archive_path}")