In [1]:
# %% Cell 0: Downgrade NumPy in the notebook environment and force a restart

# Use the IPython “magic” so it installs into this kernel
#%pip install --upgrade "numpy<2.0"

# Then immediately exit so Jupyter will prompt you to restart
#import os
#os._exit(0)


In [2]:
# %% Cell 0: Install missing dependency
#%pip install albumentations


In [6]:
# %% Cell 1: Imports & Hyperparameters (TensorFlow Keras only)
import os
from glob import glob
import numpy as np
import time
import cv2
import tensorflow as tf
#from tensorflow.keras.callbacks import CSVLogger, LearningRateScheduler, ModelCheckpoint
#from tensorflow.keras.layers import *
#from tensorflow.keras.models import Model, load_model
#from tensorflow.keras.optimizers import Adam
from albumentations import *
#from tensorflow.keras import backend as K
from skimage.feature import peak_local_max
from scipy import ndimage as ndi
from skimage.segmentation import watershed
import skimage.morphology
from skimage.io import imsave
from skimage.morphology import remove_small_objects
import tqdm
from random import shuffle
import matplotlib.pyplot as plt

# Hyperparameters
# NOTE: the key 'treshold' (sic) is read by the evaluation cell; keep the spelling.
opts = {
    'number_of_channel':       3,
    'treshold':                0.5,
    'epoch_num':               30,
    'quick_run':               1,
    'batch_size':              16,
    'random_seed_num':         19,
    'crop_size':               256,
    'init_LR':                 0.001,
    'LR_decay_factor':         0.5,
    'LR_drop_after_nth_epoch': 20,
    'result_save_path':        'prediction_image/',
    'model_save_path':         'output_model/'
}

# Ensure output dirs exist (exist_ok makes this cell safe to re-run)
os.makedirs(opts['model_save_path'], exist_ok=True)
os.makedirs(opts['result_save_path'] + 'validation/unet',         exist_ok=True)
os.makedirs(opts['result_save_path'] + 'validation/watershed_unet', exist_ok=True)


In [7]:
# %% Cell 2: Data splits → globs
# Every split lives under data_sr_x2/<split>/ with one sub-folder per target.
# Files are paired across sub-folders purely by their sorted filename order.
def _split_globs(split):
    """Return sorted (images, masks, distance_maps, label_masks, vague_masks) paths for one split."""
    root = f'data_sr_x2/{split}/'
    patterns = (
        'images/*.png',
        'masks/*.png',
        'distance_maps/*.png',
        'label_masks/*.tif',
        'vague_masks/*.png',
    )
    return tuple(sorted(glob(root + pattern)) for pattern in patterns)


# TRAIN
train_img, train_mask, train_dist, train_label, train_vague = _split_globs('train')

# VAL
val_img, val_mask, val_dist, val_label, val_vague = _split_globs('val')

# TEST
test_img, test_mask, test_dist, test_label, test_vague = _split_globs('test')


In [8]:
# %% Cell 3: Losses & Scheduler
def dice_coef(y_true, y_pred):
    """Soft Dice coefficient over the whole (flattened) batch, smoothed by 1.

    Computes (2 * sum(t * p) + 1) / (sum(t) + sum(p) + 1), where the sums run
    over every element of the flattened tensors.
    """
    K = tf.keras.backend
    flat_true = K.flatten(y_true)
    flat_pred = K.flatten(y_pred)
    numerator = 2. * K.sum(flat_true * flat_pred) + 1.
    denominator = K.sum(flat_true) + K.sum(flat_pred) + 1.
    return numerator / denominator

def dice_loss(y_true, y_pred):
    """Dice loss: one minus the smoothed Dice coefficient."""
    coefficient = dice_coef(y_true, y_pred)
    return 1 - coefficient

def bce_dice_loss(y_true, y_pred):
    """Half binary cross-entropy minus the Dice coefficient.

    The Dice coefficient is subtracted (rather than adding `1 - dice`), so the
    reported loss can go negative. It differs from 0.5*BCE + dice_loss only by
    the constant 1, so the gradients are the same.
    """
    bce_term = 0.5 * tf.keras.losses.binary_crossentropy(y_true, y_pred)
    dice_term = dice_coef(y_true, y_pred)
    return bce_term - dice_term

def step_decay_schedule(initial_lr=1e-3, decay_factor=0.75, epochs_drop=1000):
    """Build a LearningRateScheduler implementing step decay.

    lr(epoch) = initial_lr * decay_factor ** floor(epoch / epochs_drop),
    i.e. the rate is multiplied by `decay_factor` every `epochs_drop` epochs.
    """
    def schedule(epoch):
        drops_so_far = np.floor(epoch / epochs_drop)
        return initial_lr * (decay_factor ** drops_so_far)

    return tf.keras.callbacks.LearningRateScheduler(schedule, verbose=1)


In [9]:
# %% Cell 4: U-Net model definitions
def deep_unet(IMG_CHANNELS, LearnRate):
    """Build and compile a 4-level U-Net with a single sigmoid output channel.

    Encoder widths are 16/32/64/128 with a 256-filter bottleneck. Each stage
    is Conv3x3-ReLU -> Dropout(0.1) -> Conv3x3-ReLU. The decoder upsamples with
    stride-2 Conv2DTranspose and concatenates the matching encoder output.
    Spatial dims are declared as None, but with four 2x poolings the input
    H and W must be divisible by 16 for the skip concatenations to line up.

    Args:
        IMG_CHANNELS: number of input channels.
        LearnRate: Adam learning rate.

    Returns:
        A compiled Keras Model (loss=bce_dice_loss, metrics=[dice_coef]).
    """
    layers = tf.keras.layers

    def double_conv(tensor, filters):
        # Conv -> Dropout -> Conv, all 3x3 'same' ReLU; layer order matches the original graph
        tensor = layers.Conv2D(filters, 3, activation='relu', padding='same')(tensor)
        tensor = layers.Dropout(0.1)(tensor)
        return layers.Conv2D(filters, 3, activation='relu', padding='same')(tensor)

    inputs = layers.Input((None, None, IMG_CHANNELS))

    # Encoder: keep every level's output for the skip connections
    skips = []
    x = inputs
    for filters in (16, 32, 64, 128):
        x = double_conv(x, filters)
        skips.append(x)
        x = layers.MaxPooling2D()(x)

    # Bottleneck
    x = double_conv(x, 256)

    # Decoder: upsample, concatenate [upsampled, skip], refine
    for filters, skip in zip((128, 64, 32, 16), reversed(skips)):
        x = layers.Conv2DTranspose(filters, 2, strides=2, padding='same')(x)
        x = layers.concatenate([x, skip])
        x = double_conv(x, filters)

    outputs = layers.Conv2D(1, 1, activation='sigmoid')(x)
    model = tf.keras.models.Model(inputs, outputs)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=LearnRate),
        loss=bce_dice_loss,
        metrics=[dice_coef],
    )
    return model


In [10]:
from scipy.optimize   import linear_sum_assignment
import scipy           # for scipy.spatial.distance
# %% Cell 5 (update): Augmentation function without invalid always_apply
def albumentation_aug(p=1.0):
    """Build the training augmentation pipeline.

    Only intensity transforms are listed (no flips, crops or rotations):
    CLAHE (p=0.5) followed by RandomBrightnessContrast (+/-0.15, p=0.4).
    The whole Compose is applied with probability `p`.
    """
    clahe = CLAHE(clip_limit=4.0, tile_grid_size=(8, 8), p=0.5)
    brightness_contrast = RandomBrightnessContrast(
        brightness_limit=0.15,
        contrast_limit=0.15,
        brightness_by_max=True,
        p=0.4,
    )
    return Compose([clahe, brightness_contrast], p=p)


def get_fast_aji(true, pred):
    """Aggregated Jaccard Index (AJI), as distributed by MoNuSeg.

    Each GT instance is paired with the predicted instance of highest IoU.
    A prediction may be reused for several GT instances, so there is no
    permutation problem. Paired intersections and unions are summed, and
    every unpaired GT or predicted instance adds its full area to the union.
    Like DICE2, this over-penalises.

    Instance ids are looked up by value, so they need not be contiguous and a
    map need not contain background; id 0 is always treated as background.

    Args:
        true: 2-D integer instance map of the ground truth (0 = background).
        pred: 2-D integer instance map of the prediction (0 = background).

    Returns:
        The AJI score, or 0 when `pred` contains no instances.
    """
    true = np.asarray(true)
    pred = np.asarray(pred)
    true_ids = [i for i in np.unique(true) if i != 0]
    pred_ids = [i for i in np.unique(pred) if i != 0]
    if not pred_ids:
        return 0

    # id -> binary mask, and pred id -> matrix column (ids may be sparse)
    true_masks = {i: np.array(true == i, np.uint8) for i in true_ids}
    pred_masks = {i: np.array(pred == i, np.uint8) for i in pred_ids}
    pred_col = {pid: col for col, pid in enumerate(pred_ids)}

    pairwise_inter = np.zeros((len(true_ids), len(pred_ids)), dtype=np.float64)
    pairwise_union = np.zeros((len(true_ids), len(pred_ids)), dtype=np.float64)

    for row, true_id in enumerate(true_ids):
        t_mask = true_masks[true_id]
        # only predictions touching this GT instance can have a non-zero overlap
        for pred_id in np.unique(pred[t_mask > 0]):
            if pred_id == 0:  # overlapping background
                continue
            p_mask = pred_masks[pred_id]
            total = (t_mask + p_mask).sum()
            inter = (t_mask * p_mask).sum()
            col = pred_col[pred_id]
            pairwise_inter[row, col] = inter
            pairwise_union[row, col] = total - inter

    pairwise_iou = pairwise_inter / (pairwise_union + 1.0e-6)
    # best prediction for each GT (reuse of a prediction is allowed)
    best_col = np.argmax(pairwise_iou, axis=1)
    best_iou = np.max(pairwise_iou, axis=1)
    # exclude GT instances with no intersection at all
    paired_rows = np.nonzero(best_iou > 0.0)[0]
    paired_cols = best_col[paired_rows]
    overall_inter = pairwise_inter[paired_rows, paired_cols].sum()
    overall_union = pairwise_union[paired_rows, paired_cols].sum()

    # add every unpaired GT and prediction to the union (sets: O(1) lookups)
    paired_true = {true_ids[r] for r in paired_rows}
    paired_pred = {pred_ids[c] for c in paired_cols}
    for true_id in true_ids:
        if true_id not in paired_true:
            overall_union += true_masks[true_id].sum()
    for pred_id in pred_ids:
        if pred_id not in paired_pred:
            overall_union += pred_masks[pred_id].sum()

    return overall_inter / overall_union

def get_fast_pq(true, pred, match_iou=0.5):
    """Panoptic Quality (PQ = DQ * SQ) between two instance maps.

    `match_iou` is the IoU threshold for pairing a GT instance `g` with a
    predicted instance `p`. They form a pair if IoU(g, p) > `match_iou`, and
    each instance can appear in at most one pair.
    If `match_iou` >= 0.5, every pair with IoU > 0.5 is provably unique, so
    thresholding alone yields the maximal unique pairing.
    If `match_iou` < 0.5, Munkres assignment (scipy `linear_sum_assignment`,
    maximising total IoU) finds the unique pairing. Pairs at or below the
    threshold are then dropped.

    Instance ids are looked up by value, so they need not be contiguous and a
    map need not contain background; id 0 is always treated as background.

    Args:
        true: 2-D integer instance map of the ground truth (0 = background).
        pred: 2-D integer instance map of the prediction (0 = background).
        match_iou: non-negative IoU threshold for a match.

    Returns:
        [dq, sq, pq]: detection quality (F1), segmentation quality (mean IoU
            of matched pairs) and their product.
        [paired_true, paired_pred, unpaired_true, unpaired_pred]: instance
            ids of matched pairs and of unmatched GT (FN) / predictions (FP).
            paired_* are numpy arrays when `match_iou` >= 0.5, lists otherwise.
        When `pred` has no instances: ([0, 0, 0], [0, 0, 0, 0]).

    Raises:
        AssertionError: if `match_iou` is negative.
    """
    assert match_iou >= 0.0, "Cant' be negative"

    true = np.asarray(true)
    pred = np.asarray(pred)
    true_ids = [i for i in np.unique(true) if i != 0]
    pred_ids = [i for i in np.unique(pred) if i != 0]

    if not pred_ids:
        # shape kept for backward compatibility with existing callers
        return [0, 0, 0], [0, 0, 0, 0]

    # id -> binary mask, and pred id -> matrix column (ids may be sparse)
    true_masks = {i: np.array(true == i, np.uint8) for i in true_ids}
    pred_masks = {i: np.array(pred == i, np.uint8) for i in pred_ids}
    pred_col = {pid: col for col, pid in enumerate(pred_ids)}

    pairwise_iou = np.zeros((len(true_ids), len(pred_ids)), dtype=np.float64)
    for row, true_id in enumerate(true_ids):
        t_mask = true_masks[true_id]
        # only predictions touching this GT instance can have IoU > 0
        for pred_id in np.unique(pred[t_mask > 0]):
            if pred_id == 0:  # overlapping background
                continue
            p_mask = pred_masks[pred_id]
            total = (t_mask + p_mask).sum()
            inter = (t_mask * p_mask).sum()
            pairwise_iou[row, pred_col[pred_id]] = inter / (total - inter)

    if match_iou >= 0.5:
        pairwise_iou[pairwise_iou <= match_iou] = 0.0
        rows, cols = np.nonzero(pairwise_iou)
        paired_iou = pairwise_iou[rows, cols]
        paired_true = np.array(true_ids, dtype=true.dtype)[rows]
        paired_pred = np.array(pred_ids, dtype=pred.dtype)[cols]
    else:  # exhaustive maximal unique pairing
        # negate so that minimum cost == maximum total IoU
        rows, cols = linear_sum_assignment(-pairwise_iou)
        paired_iou = pairwise_iou[rows, cols]
        # pairs with IoU <= threshold (incl. 0, no overlap) become FP/FN
        keep = paired_iou > match_iou
        paired_true = [true_ids[r] for r in rows[keep]]
        paired_pred = [pred_ids[c] for c in cols[keep]]
        paired_iou = paired_iou[keep]

    # the actual FN and FP (sets give O(1) membership)
    paired_true_set = set(paired_true)
    paired_pred_set = set(paired_pred)
    unpaired_true = [idx for idx in true_ids if idx not in paired_true_set]
    unpaired_pred = [idx for idx in pred_ids if idx not in paired_pred_set]

    tp = len(paired_true)
    fp = len(unpaired_pred)
    fn = len(unpaired_true)
    # F1-score, i.e. DQ
    dq = tp / (tp + 0.5 * fp + 0.5 * fn)
    # SQ: mean IoU over matched pairs (no pair has 0 IoU, so no distortion)
    sq = paired_iou.sum() / (tp + 1.0e-6)

    return [dq, sq, dq * sq], [paired_true, paired_pred, unpaired_true, unpaired_pred]


#############################################################################################################
def get_dice_1(true, pred):
    """Traditional (binary) Dice between two masks.

    Any value > 0 counts as foreground, so 0/255 masks, 0/1 masks, boolean
    masks and instance-label maps are all accepted. Foreground pixels are
    counted explicitly. The older `true + pred` form became a logical OR
    for boolean inputs and so used |A∪B| instead of |A| + |B|.

    Returns:
        1 when both masks are empty (images without any nuclei); otherwise
        2 * |A ∩ B| / (|A| + |B| + 1e-4).
    """
    true_fg = np.asarray(true) > 0
    pred_fg = np.asarray(pred) > 0
    inter = np.count_nonzero(true_fg & pred_fg)
    denom = np.count_nonzero(true_fg) + np.count_nonzero(pred_fg)
    if denom == 0:
        return 1  # to handle cases without any nuclei
    return 2.0 * inter / (denom + 0.0001)

#############################################################################################################
def remap_label(pred, by_size=False):
    """Renumber instance ids so they are contiguous, e.g. [0, 1, 2, 3]
    instead of [0, 2, 4, 6].

    Instance order (which one comes first) is preserved, i.e. ascending
    original id, unless `by_size=True`. In that case bigger instances get
    smaller ids; ties keep ascending-id order because the sort is stable.
    Background (0) stays 0. A map with no 0 pixels is also accepted.

    Args:
        pred    : 2-D array where each instance is marked by a non-zero integer.
        by_size : renumber so that larger nuclei get smaller ids (on-top).

    Returns:
        A new int32 array with contiguous ids. If there are no instances,
        `pred` itself is returned unchanged.
    """
    # filter explicitly: list.remove(0) raised ValueError when no background was present
    pred_id = [i for i in np.unique(pred) if i != 0]
    if len(pred_id) == 0:
        return pred  # no label
    if by_size:
        pred_size = [(pred == inst_id).sum() for inst_id in pred_id]
        # sort ids by size in descending order (stable for equal sizes)
        pair_list = sorted(zip(pred_id, pred_size), key=lambda x: x[1], reverse=True)
        pred_id = [inst_id for inst_id, _ in pair_list]

    new_pred = np.zeros(pred.shape, np.int32)
    for idx, inst_id in enumerate(pred_id):
        new_pred[pred == inst_id] = idx + 1
    return new_pred

def get_id_from_file_path(fp, indicator):
    """Drop the directory part of `fp` and remove every occurrence of `indicator` from the file name."""
    file_name = os.path.basename(fp)
    return file_name.replace(indicator, '')

def chunker(seq, seq2, size):
    """Lazily yield [seq[i:i+size], seq2[i:i+size]] pairs, stepping i by `size` over len(seq).

    The offsets range is built immediately, so an invalid `size` (e.g. 0)
    raises at call time rather than on first iteration.
    """
    offsets = range(0, len(seq), size)

    def _pairs():
        for start in offsets:
            stop = start + size
            yield [seq[start:stop], seq2[start:stop]]

    return _pairs()


In [11]:
# %% Cell 6: Data Generator (fixed to yield float32 masks)
def data_gen(list_files, list_masks, batch_size, p,
            augment=False):
    """Endless generator of (images, binary masks) batches for Keras fit/predict.

    Images and masks are paired by position and batched via `chunker`. The
    last batch of each pass over the data may be smaller than `batch_size`.

    Args:
        list_files: image paths (read with cv2, converted BGR -> RGB).
        list_masks: mask paths, read as grayscale, same length/order as `list_files`.
        batch_size: number of samples per yielded batch.
        p: probability passed to `albumentation_aug` for the whole pipeline.
        augment: apply the augmentation pipeline when True.

    Yields:
        (X, Y) float32 arrays: X = images / 255, Y = masks with an added
        trailing channel axis, 1.0 where the mask pixel == 255 else 0.0.
        Assumes every image in a batch has the same H and W.

    Raises:
        ValueError: if the lists are empty or differ in length (an empty list
            would otherwise make the outer loop spin forever without yielding).
        FileNotFoundError: if an image or mask cannot be read.
    """
    if len(list_files) != len(list_masks):
        raise ValueError(
            f"data_gen: {len(list_files)} images but {len(list_masks)} masks"
        )
    if not list_files:
        raise ValueError("data_gen: no input files")

    aug = albumentation_aug(p)
    while True:
        for batch_imgs, batch_msks in chunker(list_files, list_masks, batch_size):
            X, Y = [], []
            for img_p, m_p in zip(batch_imgs, batch_msks):
                x = cv2.imread(img_p)
                if x is None:  # cv2 returns None instead of raising
                    raise FileNotFoundError(f"Could not read image: {img_p}")
                x = cv2.cvtColor(x, cv2.COLOR_BGR2RGB)
                m = cv2.imread(m_p, cv2.IMREAD_GRAYSCALE)
                if m is None:
                    raise FileNotFoundError(f"Could not read mask: {m_p}")
                # binarize mask (0 or 1)
                m_bin = (m == 255).astype(np.uint8)

                if augment:
                    augd = aug(image=x, mask=m_bin)
                    x, m_bin = augd['image'], augd['mask']

                X.append(x / 255.0)
                Y.append(m_bin)

            # convert to float32 so loss/metrics see float32 * float32
            X = np.array(X, dtype=np.float32)
            Y = np.expand_dims(np.array(Y, dtype=np.float32), -1)

            yield X, Y


In [12]:

# %% Cell 7: Train → Validation → Test Evaluation
# Apply the configured seed (Python, NumPy and TensorFlow RNGs) so weight
# initialisation and dropout are reproducible. Augmentation is also covered,
# presumably, if albumentations draws from these RNGs.
tf.keras.utils.set_random_seed(opts['random_seed_num'])

# Keras 3 expects the `.weights.h5` suffix when save_weights_only=True
weights_path = opts['model_save_path'] + 'unet.weights.h5'

# callbacks
logger = tf.keras.callbacks.CSVLogger(opts['model_save_path'] + 'unet.log')
lr_drop = step_decay_schedule(
    initial_lr=opts['init_LR'],
    decay_factor=opts['LR_decay_factor'],
    epochs_drop=opts['LR_drop_after_nth_epoch'],
)
# keep only the weights of the epoch with the best validation Dice
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    weights_path,
    monitor='val_dice_coef',
    verbose=1,
    save_best_only=True,
    mode='max',
    save_weights_only=True,
)

# build model
model = deep_unet(opts['number_of_channel'], opts['init_LR'])

# train
history = model.fit(
    data_gen(train_img, train_mask, opts['batch_size'], 1,
             augment=True),
    validation_data=data_gen(val_img, val_mask, opts['batch_size'], 1,
                             augment=False),
    # max(1, ...) mirrors validation_steps: a split smaller than one batch still runs
    steps_per_epoch=max(1, len(train_img) // opts['batch_size']),
    validation_steps=max(1, len(val_img) // opts['batch_size']),
    epochs=opts['epoch_num'],
    callbacks=[checkpoint, logger, lr_drop],
    verbose=1,
)

# load best
model.load_weights(weights_path)

# evaluate on test (batch_size=1, one step per image, no augmentation)
preds = model.predict(
    data_gen(test_img, test_mask, batch_size=1, p=1, augment=False),
    steps=len(test_img)
)
preds_bin = (preds > opts['treshold']).astype('uint8')

dice_scores = []
for i, gt_p in enumerate(test_mask):
    gt = cv2.imread(gt_p, cv2.IMREAD_GRAYSCALE)
    # nearest-neighbour keeps the prediction strictly binary
    pr = cv2.resize(preds_bin[i, ..., 0], (gt.shape[1], gt.shape[0]),
                    interpolation=cv2.INTER_NEAREST)
    dice_scores.append(get_dice_1(gt, pr))

print(f"Test Dice: {np.mean(dice_scores):.4f} ± {np.std(dice_scores):.4f}")



Epoch 1: LearningRateScheduler setting learning rate to 0.001.
Epoch 1/30
[1m33/33[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 14s/step - dice_coef: 0.2239 - loss: 0.0370 
Epoch 1: val_dice_coef improved from -inf to 0.25334, saving model to output_model/unet.weights.h5
[1m33/33[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m495s[0m 15s/step - dice_coef: 0.2240 - loss: 0.0354 - val_dice_coef: 0.2533 - val_loss: -0.0844 - learning_rate: 0.0010

Epoch 2: LearningRateScheduler setting learning rate to 0.001.
Epoch 2/30
[1m33/33[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 14s/step - dice_coef: 0.3226 - loss: -0.1040 
Epoch 2: val_dice_coef improved from 0.25334 to 0.37702, saving model to output_model/unet.weights.h5
[1m33/33[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m460s[0m 14s/step - dice_coef: 0.3231 - loss: -0.1031 - val_dice_coef: 0.3770 - val_loss: -0.2134 - learning_rate: 0.0010

Epoch 3: LearningRateScheduler setting learning rate to 0.001.
Epoch 3/3