# CIA-Net: Contour-aware Information Aggregation Network

## 1. Environment Setup & Configuration
Initializes the runtime environment, installs dependencies (Albumentations), defines the global hyperparameters, and seeds the random number generators for reproducibility.

In [None]:
import os
import sys
import random
import warnings
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers, models, backend as K
import cv2
import matplotlib.pyplot as plt
import xml.etree.ElementTree as ET
from glob import glob
from tqdm import tqdm
from scipy import ndimage as ndi
from skimage.morphology import remove_small_objects
from skimage.measure import label

# --- 3rd Party Libraries ---
# Albumentations is installed at runtime (notebook-style setup) right before
# it is imported.
os.system('pip install -q albumentations')
import albumentations as A

# --- Configuration ---
# Silence Python warnings and reduce TensorFlow's C++ log output.
# NOTE(review): TensorFlow is already imported at this point, so the log-level
# variable presumably has limited effect here — confirm.
warnings.filterwarnings('ignore')
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# Keep every layer in float32 (mixed precision disabled).
tf.keras.mixed_precision.set_global_policy('float32')

# --- Hyperparameters ---
SEED = 42
# Crop size as (height, width); the augmentation pipelines index it that way.
PATCH_SIZE = (224, 224)
IMG_SIZE = PATCH_SIZE
BATCH_SIZE = 8
EPOCHS = 100
# Weight of the contour loss in the Phase 3 compile step.
LAMBDA_CONTOUR = 0.42  # Paper Eq. 4
# Truncation threshold passed to SmoothTruncatedLoss in Phase 3.
GAMMA_STL = 0.2

# --- Phase Configs ---
# Epoch budgets of the three training phases (warm-up, stabilization, fine-tuning).
WARMUP_EPOCHS = 20
RESCUE_EPOCHS = 30
SOTA_EPOCHS = 100

# --- Paths ---
# Drive holds the dataset zips and receives logs, weights and figures;
# the archives are extracted to local runtime storage.
DRIVE_MOUNT_PATH = "/content/drive"
DRIVE_DATA_PATH = "/content/drive/MyDrive/"
MONUSEG_TRAIN_ZIP = "MoNuSeg_Training_Data.zip"
MONUSEG_TEST_ZIP = "MoNuSeg_Test_Data.zip"
TRAIN_EXTRACT_DIR = "/content/MoNuSeg_Train"
TEST_EXTRACT_DIR = "/content/MoNuSeg_Test"

# --- Reproducibility ---
def set_seed(seed=42):
    """Seed all random number generators used by the notebook.

    Seeds Python's ``random``, NumPy and TensorFlow with the same value and
    exports it as ``PYTHONHASHSEED``.

    Args:
        seed: Integer seed shared by every generator.
    """
    # NOTE(review): PYTHONHASHSEED is only read at interpreter start-up, so
    # setting it here presumably affects child processes only — confirm.
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seed_fn in (random.seed, np.random.seed, tf.random.set_seed):
        seed_fn(seed)


set_seed(SEED)

# --- UI Utilities ---
class Colors:
    """ANSI escape sequences used to colour the notebook's status messages."""

    HEADER = "\033[95m"   # section banners
    INFO = "\033[94m"     # progress / informational lines
    SUCCESS = "\033[92m"  # "[OK]" confirmations
    WARNING = "\033[93m"  # non-fatal problems
    ERROR = "\033[91m"    # failures
    ENDC = "\033[0m"      # reset to the terminal's default style

# Confirm that setup finished and echo the configured patch size.
print(f"{Colors.SUCCESS}[OK] Environment Initialized. IMG_SIZE: {IMG_SIZE}{Colors.ENDC}")

## 2. Data Acquisition & Ingestion
Mounts Google Drive and safely extracts the MoNuSeg dataset zips into the local runtime storage for high-speed I/O.

In [None]:
from google.colab import drive
import zipfile

# Mount Google Drive only when the mount point does not exist yet, so that
# re-running the cell does not trigger another mount.
if not os.path.exists(DRIVE_MOUNT_PATH):
    drive.mount(DRIVE_MOUNT_PATH)

def unzip_data(zip_path, extract_to):
    """Extract a dataset archive from Drive into local storage.

    The extraction is skipped when a previous run completed it. Completion is
    recorded with a marker file inside ``extract_to``; a directory left behind
    by an interrupted or failed extraction has no marker and is therefore
    extracted again instead of being mistaken for a finished one.

    Args:
        zip_path: Archive file name, relative to ``DRIVE_DATA_PATH``.
        extract_to: Directory that receives the archive contents.

    Returns:
        ``True`` when the data is available in ``extract_to``, ``False`` when
        the archive does not exist.
    """
    full_zip_path = os.path.join(DRIVE_DATA_PATH, zip_path)
    if not os.path.exists(full_zip_path):
        print(f"{Colors.ERROR}[ERROR] File not found: {full_zip_path}{Colors.ENDC}")
        return False

    done_marker = os.path.join(extract_to, ".extraction_complete")
    if os.path.exists(done_marker):
        return True

    os.makedirs(extract_to, exist_ok=True)
    with zipfile.ZipFile(full_zip_path, 'r') as zip_ref:
        zip_ref.extractall(extract_to)

    # Written only after extractall() returned, i.e. after a full extraction.
    with open(done_marker, "w", encoding="utf-8") as marker_file:
        marker_file.write(full_zip_path)
    return True

# Extract both archives; unzip_data() returns False when an archive is missing.
status_train = unzip_data(MONUSEG_TRAIN_ZIP, TRAIN_EXTRACT_DIR)
status_test = unzip_data(MONUSEG_TEST_ZIP, TEST_EXTRACT_DIR)

# Abort the notebook early if either dataset could not be located.
if not (status_train and status_test):
    raise FileNotFoundError("Required Dataset Zip files not found in Drive.")
print(f"{Colors.SUCCESS}[OK] Data Extracted Successfully.{Colors.ENDC}")

## 3. Data Processing Pipeline
Handles high-resolution image loading, XML annotation parsing, and stain normalization (optional, currently disabled), then builds the TensorFlow data pipeline with Albumentations-based augmentation.

In [None]:
# --- 3.1 Preprocessing Utilities ---

def normalize_staining(img, Io=240, alpha=1, beta=0.15):
    """Normalize H&E staining with a Macenko-style stain-vector estimate.

    This is a best-effort transform: whenever the estimate cannot be computed
    the input image is returned unchanged.

    Args:
        img: RGB image array of shape (H, W, 3).
        Io: Reference light intensity used for the optical-density conversion.
        alpha: Percentile (and ``100 - alpha``) of the angle distribution used
            to select the two extreme stain directions.
        beta: Optical-density threshold; only pixels whose three channels all
            exceed it contribute to the stain-vector estimate.

    Returns:
        A uint8 array of shape (H, W, 3) with the normalized image, or ``img``
        itself when the input is not a 3-channel image, fewer than 10 pixels
        pass the ``beta`` threshold, or the linear algebra fails.
    """
    # Reject anything that is not (H, W, 3) up front; a bare reshape((-1, 3))
    # could otherwise succeed on e.g. RGBA data and yield garbage.
    if img.ndim != 3 or img.shape[2] != 3:
        return img

    # NOTE(review): the rows look like the Ruifrok H / E / residual stain
    # vectors, yet the reconstruction below slices COLUMNS (HER[:, 0:2]) —
    # presumably HER[0:2].T was intended; confirm before enabling this step.
    HER = np.array([[0.650, 0.704, 0.286], [0.072, 0.990, 0.105], [0.268, 0.570, 0.776]])
    h, w, c = img.shape

    try:
        flat = img.reshape((-1, 3))
        # Optical density; the +1 avoids log(0) on black pixels.
        OD = -np.log((flat.astype(float) + 1) / Io)
        ODhat = OD[np.all(OD > beta, axis=1)]
        if len(ODhat) < 10:
            return img

        # eigh returns eigenvalues in ascending order, so columns 1:3 span the
        # plane of the two largest principal directions.
        _, eigvecs = np.linalg.eigh(np.cov(ODhat.T))
        plane = eigvecs[:, 1:3]
        That = ODhat.dot(plane)
        phi = np.arctan2(That[:, 1], That[:, 0])
        minPhi, maxPhi = np.percentile(phi, alpha), np.percentile(phi, 100 - alpha)
        vMin = plane.dot(np.array([(np.cos(minPhi), np.sin(minPhi))]).T)
        vMax = plane.dot(np.array([(np.cos(maxPhi), np.sin(maxPhi))]).T)

        # Order the two vectors by their first (red-channel) component.
        if vMin[0, 0] > vMax[0, 0]:
            HE = np.array((vMin[:, 0], vMax[:, 0])).T
        else:
            HE = np.array((vMax[:, 0], vMin[:, 0])).T

        # Per-pixel stain concentrations via least squares.
        Y = np.reshape(OD, (-1, 3)).T
        C = np.linalg.lstsq(HE, Y, rcond=None)[0]

        # NOTE(review): dividing by maxC and multiplying by it again below
        # cancels out, so concentrations are not rescaled — confirm intent.
        maxC = np.array([1.9705, 1.0308])
        C = np.array([C[0] / maxC[0], C[1] / maxC[1]])
        Inorm = Io * np.exp(-np.dot(HER[:, 0:2], C * maxC[:, np.newaxis]))
        return np.clip(np.reshape(Inorm.T, (h, w, c)), 0, 255).astype(np.uint8)
    except (ValueError, np.linalg.LinAlgError):
        # Numerical failure: fall back to the untouched input.
        return img

def process_xml_annotations(xml_path, image_shape):
    """Rasterize the polygon annotations of one image into two masks.

    Args:
        xml_path: Path of the annotation XML; every ``Region`` element holds
            ``Vertex`` children with ``X`` / ``Y`` attributes.
        image_shape: Shape of the matching image; only the first two entries
            (height, width) are used.

    Returns:
        ``(nuclei_mask, contour_mask)``, both uint8 arrays of shape
        (height, width) with values 0 or 255. The contour mask marks pixels
        where a 3x3 dilation and a 3x3 erosion of the nuclei mask differ.
    """
    annotation_root = ET.parse(xml_path).getroot()
    nuclei_mask = np.zeros(image_shape[:2], dtype=np.uint8)

    for region in annotation_root.findall(".//Region"):
        vertices = region.findall(".//Vertex")
        if not vertices:
            continue
        coords = [[float(vertex.get('X')), float(vertex.get('Y'))] for vertex in vertices]
        # cv2.fillPoly expects int32 points shaped (n_points, 1, 2).
        polygon = np.array(coords, np.int32).reshape((-1, 1, 2))
        cv2.fillPoly(nuclei_mask, [polygon], 255)

    structuring = np.ones((3, 3), np.uint8)
    grown = cv2.dilate(nuclei_mask, structuring, iterations=1)
    shrunk = cv2.erode(nuclei_mask, structuring, iterations=1)
    boundary = (grown - shrunk) > 0
    contour_mask = boundary.astype(np.uint8) * 255
    return nuclei_mask, contour_mask

def load_data_high_res(data_dir):
    """Load every annotated ``.tif`` image below ``data_dir`` at full resolution.

    An image is paired with the XML file that shares its base name. Images
    without an annotation are skipped silently; images that OpenCV cannot
    decode are skipped with a warning.

    Args:
        data_dir: Root directory that is searched recursively.

    Returns:
        ``(images, nuclei, contours)``: three parallel lists holding the RGB
        image, the nuclei mask and the contour mask of every loaded sample.
    """
    image_paths = sorted(glob(os.path.join(data_dir, "**", "*.tif"), recursive=True))

    # Index the annotations once instead of walking the tree for every image.
    # setdefault keeps the first match per base name, like the former xml[0].
    xml_by_base = {}
    for xml_path in glob(os.path.join(data_dir, "**", "*.xml"), recursive=True):
        xml_base = os.path.splitext(os.path.basename(xml_path))[0]
        xml_by_base.setdefault(xml_base, xml_path)

    images, nuclei, contours = [], [], []
    print(f"{Colors.INFO}Loading High-Res Data from: {data_dir}{Colors.ENDC}")

    for p in tqdm(image_paths):
        base = os.path.splitext(os.path.basename(p))[0]
        xml_path = xml_by_base.get(base)
        if xml_path is None:
            continue

        # cv2.imread signals a decode failure by returning None.
        bgr = cv2.imread(p)
        if bgr is None:
            print(f"{Colors.WARNING}[WARN] Could not read image, skipping: {p}{Colors.ENDC}")
            continue
        img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        # img = normalize_staining(img) # Optional: Macenko norm

        n_mask, c_mask = process_xml_annotations(xml_path, img.shape)

        images.append(img)
        nuclei.append(n_mask)
        contours.append(c_mask)

    return images, nuclei, contours

# --- 3.2 Loading Raw Data ---
# Each call returns three parallel lists: RGB images, nuclei masks, contour masks.
X_train_raw, y_nuc_train_raw, y_con_train_raw = load_data_high_res(TRAIN_EXTRACT_DIR)
X_test_raw, y_nuc_test_raw, y_con_test_raw = load_data_high_res(TEST_EXTRACT_DIR)

# --- 3.3 Augmentation Pipeline ---
# Training: random PATCH_SIZE crop followed by geometric and photometric
# augmentation. Registering 'nuclei' and 'contour' as mask targets makes
# Albumentations apply the same spatial transform to both masks.
# NOTE(review): `alpha_affine` may be deprecated in recent Albumentations
# releases — confirm against the version that pip installs.
train_transform = A.Compose([
    A.RandomCrop(width=PATCH_SIZE[1], height=PATCH_SIZE[0]),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=0.5),
    A.GridDistortion(p=0.5),
    A.RandomBrightnessContrast(p=0.2),
], additional_targets={'nuclei': 'mask', 'contour': 'mask'})

# Evaluation: deterministic centre crop only, so just the central
# PATCH_SIZE window of each test image is used.
test_transform = A.Compose([
    A.CenterCrop(width=PATCH_SIZE[1], height=PATCH_SIZE[0])
], additional_targets={'nuclei': 'mask', 'contour': 'mask'})

def process_data_train_flat(img, nuc, con):
    """Augment one training sample and convert it to network-ready arrays.

    Args:
        img: Image array passed to ``train_transform`` as ``image``.
        nuc: Nuclei mask (0/255) passed as ``nuclei``.
        con: Contour mask (0/255) passed as ``contour``.

    Returns:
        ``(image, nuclei, contour)`` as float32 arrays: the image scaled by
        1/255, and both masks binarized at 127 with a trailing channel axis.
    """
    def _to_binary(mask):
        # Threshold the 0/255 mask and append the channel axis.
        return (mask > 127).astype(np.float32)[..., np.newaxis]

    augmented = train_transform(image=img, nuclei=nuc, contour=con)
    scaled_image = augmented["image"].astype(np.float32) / 255.0
    return scaled_image, _to_binary(augmented["nuclei"]), _to_binary(augmented["contour"])

def process_data_test_flat(img, nuc, con):
    """Crop one evaluation sample and convert it to network-ready arrays.

    Args:
        img: Image array passed to ``test_transform`` as ``image``.
        nuc: Nuclei mask (0/255) passed as ``nuclei``.
        con: Contour mask (0/255) passed as ``contour``.

    Returns:
        ``(image, nuclei, contour)`` as float32 arrays: the image scaled by
        1/255, and both masks binarized at 127 with a trailing channel axis.
    """
    cropped = test_transform(image=img, nuclei=nuc, contour=con)

    image_out = cropped["image"].astype(np.float32) / 255.0
    nuclei_out, contour_out = (
        (cropped[key] > 127).astype(np.float32)[..., np.newaxis]
        for key in ("nuclei", "contour")
    )
    return image_out, nuclei_out, contour_out

def tf_process_wrapper(img, nuc, con, is_train=True):
    """Run the NumPy preprocessing inside a ``tf.data`` pipeline.

    ``tf.numpy_function`` drops static shape information, so the output shapes
    are restored afterwards. They are derived from ``PATCH_SIZE`` (the crop
    size used by both Albumentations pipelines) rather than hard-coded, which
    keeps this wrapper correct when the patch size is changed.

    Args:
        img: uint8 image tensor.
        nuc: uint8 nuclei mask tensor.
        con: uint8 contour mask tensor.
        is_train: Selects the augmenting (``True``) or the centre-crop
            (``False``) preprocessing function.

    Returns:
        ``(image, targets)`` where ``targets`` maps the model's output names
        ``'nuclei_output'`` and ``'contour_output'`` to the mask tensors.
    """
    patch_h, patch_w = PATCH_SIZE[0], PATCH_SIZE[1]
    [img_out, nuc_out, con_out] = tf.numpy_function(
        func=process_data_train_flat if is_train else process_data_test_flat,
        inp=[img, nuc, con],
        Tout=[tf.float32, tf.float32, tf.float32]
    )
    img_out.set_shape([patch_h, patch_w, 3])
    nuc_out.set_shape([patch_h, patch_w, 1])
    con_out.set_shape([patch_h, patch_w, 1])
    return img_out, {'nuclei_output': nuc_out, 'contour_output': con_out}

# --- 3.4 Dataset Generator ---
def get_generator(images, nucs, cons, is_train=True):
    """Build a zero-argument generator function over the raw samples.

    Args:
        images: Sequence of images.
        nucs: Sequence of nuclei masks, indexed like ``images``.
        cons: Sequence of contour masks, indexed like ``images``.
        is_train: When true every sample is yielded 10 times in a row (each
            copy is augmented independently downstream); otherwise once.

    Returns:
        A callable that returns an iterator of ``(image, nuclei, contour)``
        tuples.
    """
    def gen():
        copies = 10 if is_train else 1
        for idx, image in enumerate(images):
            sample = (image, nucs[idx], cons[idx])
            for _ in range(copies):
                yield sample
    return gen

# Signature of the raw generator output: a variable-size uint8 RGB image and
# two variable-size uint8 masks (cropping to a fixed size happens in map()).
output_signature = (
    tf.TensorSpec(shape=(None, None, 3), dtype=tf.uint8),
    tf.TensorSpec(shape=(None, None), dtype=tf.uint8),
    tf.TensorSpec(shape=(None, None), dtype=tf.uint8)
)

# Training pipeline: generator -> augmentation -> batch -> prefetch.
# NOTE(review): there is no shuffle() here, so the repeated copies of one
# image arrive back to back — confirm this is intended.
train_ds = tf.data.Dataset.from_generator(
    get_generator(X_train_raw, y_nuc_train_raw, y_con_train_raw, True),
    output_signature=output_signature
).map(
    lambda i, n, c: tf_process_wrapper(i, n, c, is_train=True),
    num_parallel_calls=tf.data.AUTOTUNE
).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

# Evaluation pipeline: same structure with the centre-crop preprocessing.
test_ds = tf.data.Dataset.from_generator(
    get_generator(X_test_raw, y_nuc_test_raw, y_con_test_raw, False),
    output_signature=output_signature
).map(
    lambda i, n, c: tf_process_wrapper(i, n, c, is_train=False),
    num_parallel_calls=tf.data.AUTOTUNE
).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

# --- Validation Check ---
# Pull one training batch and verify that the first nuclei mask is in [0, 1].
print(f"\n{Colors.WARNING}Pipeline Integrity Check (Masks must be 0-1):{Colors.ENDC}")
for img, target in train_ds.take(1):
    nuc = target['nuclei_output'][0].numpy()
    print(f"Nuclei Max: {nuc.max()}, Min: {nuc.min()}")
    if nuc.max() > 1.0:
        print(f"{Colors.ERROR}CRITICAL: Mask values > 1.0!{Colors.ENDC}")
    else:
        print(f"{Colors.SUCCESS}Data Format Correct.{Colors.ENDC}")
    break

## 4. CIA-Net Architecture
Constructs the functional Keras model. Key components include:
1.  **Encoder:** DenseNet121 (pre-trained).
2.  **IAM (Information Aggregation Module):** Facilitates flow between Nucleus and Contour tasks.
3.  **Decoder:** Multi-scale feature fusion.

In [None]:
def IAM_Module(nuc_feat, con_feat, filters, name="IAM"):
    """Information Aggregation Module shared by the two decoder branches.

    The nuclei and contour features are concatenated, passed through a shared
    3x3 convolution without activation, and then refined by one
    branch-specific 3x3 ReLU convolution each.

    Args:
        nuc_feat: Feature map of the nuclei branch.
        con_feat: Feature map of the contour branch.
        filters: Number of output channels of every convolution.
        name: Prefix used for all layer names of this module.

    Returns:
        ``(nuc_refine, con_refine)``: the refined feature maps.
    """
    merged = layers.Concatenate(name=f"{name}_concat")([nuc_feat, con_feat])
    shared = layers.Conv2D(
        filters, 3, padding='same', activation=None, name=f"{name}_smooth"
    )(merged)

    # One refinement convolution per task, nuclei first.
    refined = [
        layers.Conv2D(
            filters, 3, padding='same', activation='relu', name=f"{name}_{branch}_refine"
        )(shared)
        for branch in ("nuc", "con")
    ]
    return refined[0], refined[1]

def build_cia_net(input_shape):
    """Build the two-headed CIA-Net segmentation model.

    An ImageNet-pretrained DenseNet121 is used as encoder. The decoder walks
    back up through four levels; at each level the nuclei and contour branches
    are upsampled, merged with a 1x1-projected encoder feature (lateral
    connection) and exchanged through an ``IAM_Module``.

    Args:
        input_shape: Shape of one input image, e.g. ``(224, 224, 3)``.

    Returns:
        A Keras model named ``"CIA-Net"`` with two sigmoid outputs,
        ``'nuclei_output'`` and ``'contour_output'`` (one channel each).
    """
    inputs = layers.Input(shape=input_shape)
    # Passing input_tensor wires the DenseNet layers directly onto `inputs`,
    # so they become individual layers of the final model.
    base_model = tf.keras.applications.DenseNet121(include_top=False, weights='imagenet', input_tensor=inputs)

    # Encoder Features (spatial sizes refer to a 224x224 input)
    enc_block1 = base_model.get_layer('conv1_relu').output               # 112x112
    enc_block2 = base_model.get_layer('conv2_block6_concat').output      # 56x56
    enc_block3 = base_model.get_layer('conv3_block12_concat').output     # 28x28
    enc_block4 = base_model.get_layer('conv4_block24_concat').output     # 14x14
    enc_bottleneck = base_model.get_layer('relu').output                 # 7x7

    # --- DECODER ---
    # Level 4
    # Both branches start from the same tensor here (upsampled bottleneck plus
    # lateral feature); they only diverge after the first IAM.
    x = layers.Conv2D(256, 3, padding='same', activation='relu')(enc_bottleneck)
    x = layers.UpSampling2D()(x)
    enc4_lat = layers.Conv2D(256, 1, padding='same')(enc_block4)
    nuc_m4 = layers.Add()([x, enc4_lat])
    con_m4 = layers.Add()([x, enc4_lat])
    nuc_d4, con_d4 = IAM_Module(nuc_m4, con_m4, 256, name="IAM_4")

    # Level 3
    # Pattern repeated on levels 3-1: upsample, 1x1 conv to the level's width,
    # add the lateral encoder feature, then aggregate via IAM.
    nuc_up3 = layers.Conv2D(128, 1, padding='same')(layers.UpSampling2D()(nuc_d4))
    con_up3 = layers.Conv2D(128, 1, padding='same')(layers.UpSampling2D()(con_d4))
    enc3_lat = layers.Conv2D(128, 1, padding='same')(enc_block3)
    nuc_m3 = layers.Add()([nuc_up3, enc3_lat])
    con_m3 = layers.Add()([con_up3, enc3_lat])
    nuc_d3, con_d3 = IAM_Module(nuc_m3, con_m3, 128, name="IAM_3")

    # Level 2
    nuc_up2 = layers.Conv2D(64, 1, padding='same')(layers.UpSampling2D()(nuc_d3))
    con_up2 = layers.Conv2D(64, 1, padding='same')(layers.UpSampling2D()(con_d3))
    enc2_lat = layers.Conv2D(64, 1, padding='same')(enc_block2)
    nuc_m2 = layers.Add()([nuc_up2, enc2_lat])
    con_m2 = layers.Add()([con_up2, enc2_lat])
    nuc_d2, con_d2 = IAM_Module(nuc_m2, con_m2, 64, name="IAM_2")

    # Level 1
    nuc_up1 = layers.Conv2D(32, 1, padding='same')(layers.UpSampling2D()(nuc_d2))
    con_up1 = layers.Conv2D(32, 1, padding='same')(layers.UpSampling2D()(con_d2))
    enc1_lat = layers.Conv2D(32, 1, padding='same')(enc_block1)
    nuc_m1 = layers.Add()([nuc_up1, enc1_lat])
    con_m1 = layers.Add()([con_up1, enc1_lat])
    nuc_d1, con_d1 = IAM_Module(nuc_m1, con_m1, 32, name="IAM_1")

    # Output
    # One more 2x upsampling (112 -> 224 for a 224 input) before the heads.
    nuc_final = layers.UpSampling2D()(nuc_d1)
    con_final = layers.UpSampling2D()(con_d1)

    # The layer names double as the keys of the target / loss / metric dicts.
    nuc_out = layers.Conv2D(1, 1, activation='sigmoid', name='nuclei_output', dtype='float32')(nuc_final)
    con_out = layers.Conv2D(1, 1, activation='sigmoid', name='contour_output', dtype='float32')(con_final)

    return models.Model(inputs=inputs, outputs=[nuc_out, con_out], name="CIA-Net")

# Build the model once as a smoke test of the architecture definition.
# NOTE(review): the shape is hard-coded here, whereas the training cell
# derives it from IMG_SIZE — keep the two in sync.
model = build_cia_net((224, 224, 3))
print(f"{Colors.SUCCESS}[OK] CIA-Net Model Built.{Colors.ENDC}")

## 5. Training Components
Definitions for custom Loss Functions, Metrics, and the specific AJI Monitor Callback used for model checkpointing.

In [None]:
# --- Loss Functions ---
class SmoothTruncatedLoss(tf.keras.losses.Loss):
    """Cross-entropy that is smoothly truncated for low-confidence pixels.

    With ``pt`` the probability assigned to the true class of a pixel:

    * ``pt >= gamma``: ordinary cross-entropy ``-log(pt)``.
    * ``pt <  gamma``: ``-log(gamma) + 0.5 * (1 - pt**2 / gamma**2)``, which
      caps the penalty of outlier pixels.

    Args:
        gamma: Truncation threshold on ``pt``.
        name: Name of the loss instance.
        **kwargs: Forwarded to ``tf.keras.losses.Loss`` (e.g. ``reduction``),
            so that ``from_config`` can rebuild the object from the
            dictionary produced by ``get_config``.
    """

    def __init__(self, gamma=0.2, name="smooth_truncated_loss", **kwargs):
        super().__init__(name=name, **kwargs)
        self.gamma = gamma

    def call(self, y_true, y_pred):
        """Return the mean smooth-truncated loss over all elements."""
        y_true = tf.cast(y_true, tf.float32)
        y_pred = tf.cast(y_pred, tf.float32)
        # Keep probabilities away from 0 and 1 so the logarithm stays finite.
        y_pred = tf.clip_by_value(y_pred, 1e-7, 1.0 - 1e-7)
        pt = tf.where(tf.equal(y_true, 1), y_pred, 1 - y_pred)
        loss_outlier = -tf.math.log(self.gamma) + 0.5 * (1 - (pt**2)/(self.gamma**2))
        loss_inlier = -tf.math.log(pt)
        return tf.reduce_mean(tf.where(pt < self.gamma, loss_outlier, loss_inlier))

    def get_config(self):
        """Include ``gamma`` so a saved model restores the same threshold."""
        config = super().get_config()
        config.update({"gamma": self.gamma})
        return config

def soft_dice_loss(y_true, y_pred):
    """Soft Dice loss with squared terms in the denominator.

    Computed over the whole batch at once as
    ``1 - (2 * sum(t * p) + eps) / (sum(t**2) + sum(p**2) + eps)``.

    Args:
        y_true: Ground-truth mask tensor.
        y_pred: Predicted probability tensor of the same shape.

    Returns:
        Scalar loss tensor.
    """
    eps = 1e-5
    truth = tf.cast(y_true, tf.float32)
    probs = tf.cast(y_pred, tf.float32)

    overlap = tf.reduce_sum(truth * probs)
    squared_mass = tf.reduce_sum(truth**2) + tf.reduce_sum(probs**2)
    dice = (2 * overlap + eps) / (squared_mass + eps)
    return 1 - dice

def bce_dice_loss(y_true, y_pred):
    """Sum of binary cross-entropy and a batch-level Dice loss.

    Args:
        y_true: Ground-truth mask tensor.
        y_pred: Predicted probability tensor of the same shape.

    Returns:
        The binary cross-entropy tensor with the scalar Dice loss added to
        every element.
    """
    eps = 1e-5
    truth = tf.cast(y_true, tf.float32)
    probs = tf.cast(y_pred, tf.float32)

    cross_entropy = tf.keras.losses.binary_crossentropy(truth, probs)

    overlap = tf.reduce_sum(truth * probs)
    total_mass = tf.reduce_sum(truth) + tf.reduce_sum(probs)
    dice_term = 1 - ((2. * overlap + eps) / (total_mass + eps))
    return cross_entropy + dice_term

# --- Metrics ---
def dice_coef(y_true, y_pred):
    """Dice coefficient of the prediction thresholded at 0.5.

    Args:
        y_true: Ground-truth mask tensor.
        y_pred: Predicted probability tensor of the same shape.

    Returns:
        Scalar Dice score computed over the whole batch.
    """
    eps = 1e-5
    truth = tf.cast(y_true, tf.float32)
    probs = tf.cast(y_pred, tf.float32)

    hard_pred = tf.cast(probs > 0.5, tf.float32)
    overlap = K.sum(truth * hard_pred)
    return (2. * overlap + eps) / (K.sum(truth) + K.sum(hard_pred) + eps)

# --- Callbacks ---
class AJIMonitor(tf.keras.callbacks.Callback):
    """Periodically score the nuclei output on a validation dataset.

    Despite its name, the score is a fast approximation: the plain IoU of the
    binarized nuclei prediction against the ground-truth mask, averaged over
    images. The weights with the best score are saved to Drive.

    Args:
        val_dataset: Batched dataset yielding ``(images, targets)`` where
            ``targets['nuclei_output']`` holds the ground-truth masks.
        freq: Evaluate every ``freq``-th epoch.
    """

    def __init__(self, val_dataset, freq=2):
        super().__init__()
        self.val_dataset = val_dataset
        self.freq = freq
        self.best_aji = 0.0

    def calculate_aji_vectorized(self, gt_mask, pred_mask):
        """Return the IoU of two binary masks (0 when both are empty)."""
        # Fast approximation: pixel-level IoU instead of instance-level AJI
        intersection = np.logical_and(gt_mask, pred_mask).sum()
        union = np.logical_or(gt_mask, pred_mask).sum()
        return intersection / (union + 1e-7)

    def on_epoch_end(self, epoch, logs=None):
        """Evaluate, report, and checkpoint on every ``freq``-th epoch."""
        if (epoch + 1) % self.freq != 0:
            return

        aji_scores = []
        # Sample limited batch for speed (20 batches)
        for images, targets in self.val_dataset.take(20):
            preds = self.model.predict(images, verbose=0)
            nucs = preds[0]
            gt_batch = targets['nuclei_output'].numpy()

            for i in range(images.shape[0]):
                pred_mask = (nucs[i, :, :, 0] > 0.5).astype(np.uint8)
                gt_mask = gt_batch[i, :, :, 0].astype(np.uint8)
                aji_scores.append(self.calculate_aji_vectorized(gt_mask, pred_mask))

        # An empty validation set would otherwise produce a NaN mean.
        if not aji_scores:
            print(f"\n{Colors.WARNING}[AJI Monitor] Validation dataset is empty; skipping.{Colors.ENDC}")
            return

        mean_aji = np.mean(aji_scores)
        print(f"\n{Colors.INFO}[AJI Monitor] Val AJI (Approx): {mean_aji:.4f} (Best: {self.best_aji:.4f}){Colors.ENDC}")
        # `logs` defaults to None, so only record the score when a dict is given.
        if logs is not None:
            logs['val_aji'] = mean_aji

        if mean_aji > self.best_aji:
            self.best_aji = mean_aji
            self.model.save_weights(os.path.join(DRIVE_DATA_PATH, 'cia_net_best_aji.weights.h5'))

print(f"{Colors.SUCCESS}[OK] Loss Functions & Monitors Ready.{Colors.ENDC}")

## 6. Execution: 3-Stage Training Strategy
Implements a production-grade automated training loop:
1.  **Warm-up:** Frozen encoder, high LR.
2.  **Stabilization:** Full network, moderate LR, standard loss.
3.  **SOTA Fine-tuning:** Advanced Optimizer (AdamW + Cosine Decay) & Robust Loss.

In [None]:
# --- 0. SAFETY CHECKS ---
print(f"{Colors.INFO}Config: Warmup={WARMUP_EPOCHS}, Rescue={RESCUE_EPOCHS}, SOTA={SOTA_EPOCHS}{Colors.ENDC}")

# If a finished run is still in memory (history_3 exists), only back up the
# weights instead of retraining from scratch.
# NOTE(review): test_ds doubles as validation data in all three phases, so
# checkpoint selection and early stopping see the test set.
if 'model' in globals() and 'history_3' in globals():
    print(f"{Colors.SUCCESS}Training appears complete. Saving backup weights.{Colors.ENDC}")
    model.save_weights(os.path.join(DRIVE_DATA_PATH, 'cia_net_final_sota.weights.h5'))
else:
    K.clear_session()
    try:
        model = build_cia_net((IMG_SIZE[0], IMG_SIZE[1], 3))
    except NameError as err:
        raise RuntimeError("Build function not found. Run previous cells.") from err

    print(f"{Colors.HEADER}Starting Fresh Model Training...{Colors.ENDC}")

    # ---------------------------------------------------------
    # PHASE 1: WARM-UP (Frozen Encoder)
    # ---------------------------------------------------------
    print(f"\n{Colors.INFO}>>> PHASE 1: WARM-UP (Frozen Encoder)...{Colors.ENDC}")
    # build_cia_net wires DenseNet121 onto the model's own input tensor, so
    # the encoder layers are individual layers of `model` (conv1_conv, ...)
    # and none of them is called 'densenet121'. Matching on that substring
    # froze nothing. The encoder layer names are therefore taken from a
    # weight-less reference DenseNet121 instead.
    encoder_reference = tf.keras.applications.DenseNet121(
        include_top=False, weights=None, input_shape=(IMG_SIZE[0], IMG_SIZE[1], 3)
    )
    encoder_layer_names = {layer.name for layer in encoder_reference.layers}
    del encoder_reference

    frozen_count = 0
    for layer in model.layers:
        if layer.name in encoder_layer_names:
            layer.trainable = False
            frozen_count += 1
    print(f"{Colors.INFO}Frozen encoder layers: {frozen_count}{Colors.ENDC}")

    model.compile(
        optimizer=tf.keras.optimizers.Adam(1e-3),
        loss={'nuclei_output': bce_dice_loss, 'contour_output': bce_dice_loss},
        loss_weights={'nuclei_output': 1.0, 'contour_output': 0.5},
        metrics={'nuclei_output': ['accuracy', dice_coef], 'contour_output': [dice_coef]}
    )

    history_1 = model.fit(
        train_ds,
        validation_data=test_ds,
        epochs=WARMUP_EPOCHS,
        callbacks=[tf.keras.callbacks.CSVLogger(os.path.join(DRIVE_DATA_PATH, 'log_phase1.csv'))]
    )
    model.save_weights(os.path.join(DRIVE_DATA_PATH, 'weights_phase1.weights.h5'))

    # ---------------------------------------------------------
    # PHASE 2: STABILIZATION (Full Network)
    # ---------------------------------------------------------
    print(f"\n{Colors.INFO}>>> PHASE 2: STABILIZATION (Full Network)...{Colors.ENDC}")
    # Unfreeze everything; the recompile below makes the change effective.
    for layer in model.layers:
        layer.trainable = True

    model.compile(
        optimizer=tf.keras.optimizers.Adam(1e-4),
        loss={'nuclei_output': bce_dice_loss, 'contour_output': bce_dice_loss},
        loss_weights={'nuclei_output': 1.0, 'contour_output': 1.0},
        metrics={'nuclei_output': ['accuracy', dice_coef], 'contour_output': [dice_coef]}
    )

    history_2 = model.fit(
        train_ds,
        validation_data=test_ds,
        epochs=RESCUE_EPOCHS,
        callbacks=[
            tf.keras.callbacks.ModelCheckpoint(
                os.path.join(DRIVE_DATA_PATH, 'weights_phase2_best.keras'),
                save_best_only=True, monitor='val_nuclei_output_dice_coef', mode='max'
            ),
            tf.keras.callbacks.CSVLogger(os.path.join(DRIVE_DATA_PATH, 'log_phase2.csv'))
        ]
    )

    # ---------------------------------------------------------
    # PHASE 3: SOTA FINE-TUNING (Advanced)
    # ---------------------------------------------------------
    print(f"\n{Colors.INFO}>>> PHASE 3: SOTA FINE-TUNING...{Colors.ENDC}")
    # Best effort: continue with the current weights if the checkpoint cannot
    # be restored, but report the real cause instead of assuming "not found".
    try:
        model.load_weights(os.path.join(DRIVE_DATA_PATH, 'weights_phase2_best.keras'))
        print(f"{Colors.SUCCESS}Loaded Best Phase 2 Weights.{Colors.ENDC}")
    except Exception as err:
        print(f"{Colors.WARNING}Warning: Phase 2 weights could not be loaded ({err}), continuing...{Colors.ENDC}")

    # Every training image is yielded 10 times per epoch and the dataset keeps
    # its partial final batch, hence the ceiling division.
    num_train_samples = len(X_train_raw) * 10
    steps_per_epoch = max(1, -(-num_train_samples // BATCH_SIZE))
    lr_schedule = tf.keras.optimizers.schedules.CosineDecayRestarts(
        initial_learning_rate=1e-4, first_decay_steps=20 * steps_per_epoch,
        t_mul=2.0, m_mul=0.9, alpha=1e-6
    )
    optimizer = tf.keras.optimizers.AdamW(
        learning_rate=lr_schedule, weight_decay=1e-4, global_clipnorm=1.0
    )

    model.compile(
        optimizer=optimizer,
        loss={'nuclei_output': SmoothTruncatedLoss(gamma=GAMMA_STL), 'contour_output': soft_dice_loss},
        loss_weights={'nuclei_output': 1.0, 'contour_output': LAMBDA_CONTOUR},
        metrics={
            'nuclei_output': [
                'accuracy', dice_coef,
                tf.keras.metrics.Precision(name='precision'),
                tf.keras.metrics.Recall(name='recall')
            ],
            'contour_output': [dice_coef]
        }
    )

    history_3 = model.fit(
        train_ds,
        validation_data=test_ds,
        epochs=SOTA_EPOCHS,
        callbacks=[
            tf.keras.callbacks.ModelCheckpoint(
                os.path.join(DRIVE_DATA_PATH, 'cia_net_final_sota.keras'),
                save_best_only=True, monitor='val_nuclei_output_dice_coef', mode='max', verbose=1
            ),
            tf.keras.callbacks.EarlyStopping(
                monitor='val_nuclei_output_dice_coef', patience=25, mode='max', restore_best_weights=True
            ),
            tf.keras.callbacks.CSVLogger(os.path.join(DRIVE_DATA_PATH, 'log_phase3.csv')),
            AJIMonitor(test_ds, freq=5)
        ]
    )

    print(f"{Colors.SUCCESS}All Training Phases Completed Successfully!{Colors.ENDC}")

## 7. Inference & Explainability (XAI)
Performs marker-based watershed post-processing and generates an XAI visualization panel showing the predicted probabilities, uncertainty (entropy), the post-processed segmentation masks, and an error map.

In [None]:
from skimage.segmentation import watershed
from skimage.feature import peak_local_max

# --- Post-Processing Logic ---
def sophisticated_post_processing(pred_nuc, pred_con):
    """Refine a nuclei probability map with marker-based watershed.

    Markers are taken from nuclei pixels that are not contour pixels, cleaned
    with a 3x3 opening, and seeded at the local maxima of their distance
    transform. The watershed is restricted to the thresholded nuclei mask.

    Args:
        pred_nuc: 2-D nuclei probability map.
        pred_con: 2-D contour probability map of the same shape.

    Returns:
        2-D float array with 1.0 on pixels that belong to any watershed
        region and 0.0 elsewhere.
    """
    nuc_fg = pred_nuc > 0.5
    con_fg = pred_con > 0.3
    nuc_mask = nuc_fg.astype(np.uint8)

    # Nuclei interior = nuclei AND NOT contour. This is done on booleans:
    # subtracting the uint8 masks wraps 0 - 1 around to 255, which a later
    # clip turns into 1 and thereby marks contour-only pixels as nuclei.
    markers_raw = (nuc_fg & ~con_fg).astype(np.uint8)
    kernel = np.ones((3, 3), np.uint8)
    markers_clean = cv2.morphologyEx(markers_raw, cv2.MORPH_OPEN, kernel, iterations=1)

    # One seed per local maximum of the distance to the marker boundary.
    distance = ndi.distance_transform_edt(markers_clean)
    coords = peak_local_max(distance, footprint=np.ones((3, 3)), labels=markers_clean)
    mask = np.zeros(distance.shape, dtype=bool)
    mask[tuple(coords.T)] = True
    markers, _ = ndi.label(mask)

    final_labels = watershed(-distance, markers, mask=nuc_mask)
    return (final_labels > 0).astype(float)

def paper_post_processing(pred_nuc, pred_con):
    """Threshold the nuclei-minus-contour map and clean it with an opening.

    Args:
        pred_nuc: Nuclei probability map.
        pred_con: Contour probability map of the same shape.

    Returns:
        Float array with 1.0 where ``pred_nuc - pred_con`` exceeds 0.3 and the
        pixel survives a 3x3 morphological opening, 0.0 elsewhere.
    """
    structuring = np.ones((3, 3), np.uint8)
    foreground = ((pred_nuc - pred_con) > 0.3).astype(np.uint8)
    opened = cv2.morphologyEx(foreground, cv2.MORPH_OPEN, structuring)
    return opened.astype(float)

def compute_entropy(probs):
    """Element-wise binary entropy (in nats) of a probability array.

    Args:
        probs: Array of probabilities; values are clipped to
            ``[1e-7, 1 - 1e-7]`` so the logarithms stay finite.

    Returns:
        Array of the same shape holding ``-(p*log(p) + (1-p)*log(1-p))``.
    """
    p = np.clip(probs, 1e-7, 1 - 1e-7)
    positive_term = p * np.log(p)
    negative_term = (1 - p) * np.log(1 - p)
    return - (positive_term + negative_term)

# --- Visualization Engine ---
def visualize_xai_quality(model, dataset, num_samples=3):
    """Plot a diagnostic panel for the first samples of one dataset batch.

    Each row shows seven views of one sample: input image, nuclei
    probability, contour probability, entropy, the thresholded
    nuclei-minus-contour mask, the watershed mask, and an error map
    (yellow = true positive, red = false negative, blue = false positive).
    The figure is saved to Drive and displayed.

    Args:
        model: Model whose ``predict`` returns ``[nuclei, contour]`` maps.
        dataset: Batched dataset yielding ``(images, targets)`` where
            ``targets['nuclei_output']`` holds the ground-truth masks.
        num_samples: Maximum number of rows; capped at the batch size.
    """
    images, targets = next(iter(dataset.take(1)))
    imgs = images[:num_samples]
    gt_nuc = targets['nuclei_output'][:num_samples]

    # The batch may hold fewer samples than requested.
    num_samples = int(imgs.shape[0])
    if num_samples == 0:
        print(f"{Colors.WARNING}No samples to visualize.{Colors.ENDC}")
        return

    print(f"{Colors.INFO}Generating Predictions...{Colors.ENDC}")
    preds = model.predict(imgs, verbose=0)
    pred_nuc = preds[0]
    pred_con = preds[1]

    plt.figure(figsize=(24, 6 * num_samples))

    for i in range(num_samples):
        final_mask = sophisticated_post_processing(pred_nuc[i,:,:,0], pred_con[i,:,:,0])
        paper_mask = paper_post_processing(pred_nuc[i,:,:,0], pred_con[i,:,:,0])
        entropy = compute_entropy(pred_nuc[i,:,:,0])

        gt_mask = gt_nuc[i,:,:,0].numpy()
        # Sized from the mask itself so any patch size works.
        error_map = np.zeros(gt_mask.shape + (3,))
        error_map[(gt_mask == 1) & (final_mask == 1)] = [1, 1, 0] # Yellow: TP
        error_map[(gt_mask == 1) & (final_mask == 0)] = [1, 0, 0] # Red: FN
        error_map[(gt_mask == 0) & (final_mask == 1)] = [0, 0, 1] # Blue: FP

        titles = ["Original", "Nuclei Prob", "Contour Prob", "Uncertainty", "Paper Method", "SOTA Watershed", "Error Map (Y:TP)"]
        contents = [imgs[i], pred_nuc[i,:,:,0], pred_con[i,:,:,0], entropy, paper_mask, final_mask, error_map]
        cmaps = [None, 'jet', 'magma', 'inferno', 'gray', 'gray', None]

        for j, (title, content, cmap) in enumerate(zip(titles, contents, cmaps)):
            plt.subplot(num_samples, 7, i*7 + j + 1)
            plt.imshow(content, cmap=cmap)
            plt.title(title, fontsize=10)
            plt.axis('off')

    plt.tight_layout()
    save_path = os.path.join(DRIVE_DATA_PATH, 'advanced_xai_analysis_sota.png')
    plt.savefig(save_path, dpi=300)
    print(f"{Colors.SUCCESS}Visualization Saved: {save_path}{Colors.ENDC}")
    plt.show()

visualize_xai_quality(model, test_ds)

## 8. Quantitative Analysis (Training History)
Consolidates logs from all 3 training phases into a single, comprehensive visualization of performance convergence and stability.

In [None]:
import seaborn as sns

def plot_robust_training_history():
    """Plot Dice and loss curves across all three training phases.

    The per-phase CSV logs are read from ``DRIVE_DATA_PATH``, their epoch
    numbers are shifted so the phases line up on one axis, and two panels
    (Dice score, loss) are drawn with vertical lines at the phase boundaries.
    Missing or unreadable logs are reported and skipped. The figure is saved
    to Drive and displayed.
    """
    phase_logs = (
        ('log_phase1.csv', 'Phase 1: Warm-up (Encoder Frozen)'),
        ('log_phase2.csv', 'Phase 2: Rescue/Stabilization (Full Net)'),
        ('log_phase3.csv', 'Phase 3: SOTA Fine-tuning (AdamW + Cosine)'),
    )

    frames = []
    boundaries = []  # (last cumulative epoch, phase title) per loaded log
    epochs_so_far = 0

    print(f"{Colors.INFO}Analyzing Training Logs...{Colors.ENDC}")

    for filename, phase_title in phase_logs:
        csv_path = os.path.join(DRIVE_DATA_PATH, filename)
        if not os.path.exists(csv_path):
            print(f"{filename} not found.")
            continue
        try:
            frame = pd.read_csv(csv_path)
            if len(frame) > 0:
                # Shift this phase's epochs behind the previous phases.
                frame['epoch'] = frame['epoch'] + epochs_so_far
                frames.append(frame)
                epochs_so_far += len(frame)
                boundaries.append((epochs_so_far, phase_title))
                print(f"{filename} loaded ({len(frame)} epochs)")
        except Exception as e:
            print(f"{filename} error: {e}")

    if not frames:
        print(f"{Colors.ERROR}No logs found.{Colors.ENDC}")
        return

    full_df = pd.concat(frames, ignore_index=True)
    sns.set_theme(style="whitegrid")
    fig, (dice_ax, loss_ax) = plt.subplots(1, 2, figsize=(24, 8))

    palette = {'train': '#2ecc71', 'val_nuc': '#3498db', 'val_con': '#e67e22'}

    def draw_phase_boundaries(ax, text_y, color, linestyle, linewidth):
        # The last boundary is the end of training and is not drawn.
        for boundary_epoch, title in boundaries[:-1]:
            ax.axvline(x=boundary_epoch, color=color, linestyle=linestyle, linewidth=linewidth)
            ax.text(boundary_epoch + 1, text_y, title.split(':')[0], rotation=90, color=color, fontweight='bold')

    # --- Dice Score ---
    sns.lineplot(data=full_df, x='epoch', y='nuclei_output_dice_coef', ax=dice_ax,
                 label='Train Nuclei Dice', color=palette['train'], linewidth=2, alpha=0.7)
    sns.lineplot(data=full_df, x='epoch', y='val_nuclei_output_dice_coef', ax=dice_ax,
                 label='Val Nuclei Dice', color=palette['val_nuc'], linewidth=3)

    if 'val_contour_output_dice_coef' in full_df.columns:
        sns.lineplot(data=full_df, x='epoch', y='val_contour_output_dice_coef', ax=dice_ax,
                     label='Val Contour Dice', color=palette['val_con'], linestyle='--', linewidth=2)

    # Highlight the epoch with the best validation nuclei Dice.
    best_row = full_df.loc[full_df['val_nuclei_output_dice_coef'].idxmax()]
    best_epoch = best_row['epoch']
    best_score = best_row['val_nuclei_output_dice_coef']

    dice_ax.scatter(best_epoch, best_score, color='red', s=100, zorder=5, label=f'Best: {best_score:.4f}')
    dice_ax.set_title('Segmentation Performance (Dice Score)', fontsize=16, fontweight='bold')
    dice_ax.set_xlabel('Epochs', fontsize=14)
    dice_ax.set_ylabel('Dice Coefficient', fontsize=14)
    dice_ax.set_ylim(0, 1.0)
    dice_ax.legend(loc='lower right', fontsize=12, frameon=True)

    draw_phase_boundaries(dice_ax, 0.05, 'gray', ':', 2)

    # --- Loss ---
    sns.lineplot(data=full_df, x='epoch', y='loss', ax=loss_ax,
                 label='Total Train Loss', color=palette['train'], linewidth=2, alpha=0.7)
    sns.lineplot(data=full_df, x='epoch', y='val_loss', ax=loss_ax,
                 label='Total Val Loss', color='green', linewidth=3)

    loss_ax.set_title('Loss Convergence', fontsize=16, fontweight='bold')
    loss_ax.set_xlabel('Epochs', fontsize=14)
    loss_ax.set_ylabel('Loss Value', fontsize=14)
    loss_ax.legend(loc='upper right', fontsize=12, frameon=True)

    draw_phase_boundaries(loss_ax, full_df['loss'].max()*0.8, 'red', '--', 1.5)

    plt.tight_layout()
    save_path = os.path.join(DRIVE_DATA_PATH, 'robust_training_history.png')
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"{Colors.SUCCESS}History Plot Saved: {save_path}{Colors.ENDC}")
    plt.show()

plot_robust_training_history()