In [None]:
import numpy as np
import matplotlib.pyplot as plt
from skimage import io, color
from skimage.transform import resize
import os
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import sys
from stardist import models, data
from stardist.models import Config2D
from stardist.plot import render_label, render_label_pred
from stardist.plot import random_label_cmap, draw_polygons
from csbdeep.utils import normalize
from stardist.matching import matching
from skimage.measure import label as sklabel
from torch.utils.data import random_split

### Data Loading and Processing

In [None]:
class NucleiDatasetStardist(Dataset):
    """Paired image / instance-mask dataset for StarDist nuclei segmentation.

    Images and masks are paired by position after sorting the file paths of
    each directory, so both directories must hold the same number of files and
    corresponding files must sort into the same order.

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the masks.
        image_size: ``(height, width)`` every image and mask is resized to.
    """

    def __init__(self, image_dir, mask_dir, image_size=(256, 256)):
        self.image_size = image_size
        self.image_files = self._list_files(image_dir)
        self.mask_files = self._list_files(mask_dir)
        assert len(self.image_files) == len(self.mask_files), "Number of images and masks should be the same."

    @staticmethod
    def _list_files(directory):
        """Return the sorted paths of the regular, non-hidden files in ``directory``.

        Sub-directories and dot-entries (e.g. ``.ipynb_checkpoints``, ``.DS_Store``)
        are skipped: they are not images and would otherwise break the
        image/mask pairing or crash the image reader.
        """
        candidates = (
            os.path.join(directory, name)
            for name in os.listdir(directory)
            if not name.startswith('.')
        )
        return sorted(path for path in candidates if os.path.isfile(path))

    @staticmethod
    def _to_grayscale(array):
        """Collapse an RGB(A) array to a single channel; 2-D input is returned unchanged."""
        if array.ndim == 3 and array.shape[2] == 4:
            array = array[:, :, :3]  # drop the alpha channel before the RGB conversion
        if array.ndim == 3:
            array = color.rgb2gray(array)
        return array

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load, resize and label the sample at position ``idx``.

        Returns:
            tuple: ``(image, mask)``. ``image`` is a float32 array of shape
            ``image_size + (1,)``. ``mask`` is a uint16 array of shape
            ``image_size`` holding connected-component labels of the
            thresholded mask (0 is background).

        Note:
            Instances are derived by labelling connected foreground regions,
            so objects that touch in the mask end up sharing one label.
        """
        img = self._to_grayscale(io.imread(self.image_files[idx]))
        mask = self._to_grayscale(io.imread(self.mask_files[idx]))

        img_resized = resize(img, self.image_size, anti_aliasing=True).astype(np.float32)
        img_expanded = np.expand_dims(img_resized, axis=-1)  # add trailing channel axis

        # order=0 (nearest neighbour) without anti-aliasing keeps the mask values
        # crisp instead of blending foreground and background at the borders.
        mask_resized = resize(mask, self.image_size, anti_aliasing=False, order=0, preserve_range=True)
        mask_binary = mask_resized > 0.5
        mask_instance = sklabel(mask_binary).astype(np.uint16)

        return img_expanded, mask_instance

### Training StarDist

In [None]:
def _collect_samples(dataset):
    """Read ``dataset`` exactly once and return its images and masks as two parallel lists."""
    images, masks = [], []
    for image, mask in dataset:
        images.append(image)
        masks.append(mask)
    return images, masks


def train_stardist(train_loader, val_loader, model_name='stardist_nuclei', n_rays=32, epochs=10, learning_rate=1e-4):
    """Build and train a StarDist2D model on the data behind two loaders.

    Only the ``dataset`` (and ``batch_size``) attributes of the loaders are
    used: every sample is materialised in memory and handed to StarDist, which
    performs its own batching.

    Args:
        train_loader: Loader whose ``dataset`` yields ``(image, label_mask)`` pairs.
        val_loader: Loader whose ``dataset`` yields the validation pairs.
        model_name: Name passed to ``StarDist2D``.
        n_rays: Number of rays of the star-convex polygon representation.
        epochs: Number of training epochs.
        learning_rate: Training learning rate.

    Returns:
        The trained ``StarDist2D`` model.

    Raises:
        ValueError: If the training dataset is empty.
    """
    # Single pass per dataset: each item access may hit the disk and resize.
    X_train, Y_train = _collect_samples(train_loader.dataset)
    X_val, Y_val = _collect_samples(val_loader.dataset)
    if not X_train:
        raise ValueError("Cannot train StarDist: the training dataset is empty.")

    config = Config2D(
        n_rays=n_rays,
        n_channel_in=1,  # grayscale input images
        train_epochs=epochs,
        train_learning_rate=learning_rate,
        # Follow the loader's batch size; fall back to 4 when it does not define one.
        train_batch_size=getattr(train_loader, 'batch_size', None) or 4,
    )

    model = models.StarDist2D(config, name=model_name)

    # csbdeep's normalize() takes *percentiles* (pmin, pmax), not a target range.
    # (0, 1) would scale by the 0th..1st percentile and blow up the intensities;
    # 1 / 99.8 is the percentile range used in the StarDist examples. The same
    # percentiles must be used for every image later fed to the model.
    X_train = [normalize(x, 1, 99.8) for x in X_train]
    X_val = [normalize(x, 1, 99.8) for x in X_val]

    model.train(X_train, Y_train, validation_data=(X_val, Y_val),
                epochs=epochs)

    return model

### Prediction with StarDist

In [None]:
def predict_stardist(model, test_loader, prob_threshold=0.5, nms_threshold=0.4):
    """Run StarDist instance prediction on every image produced by ``test_loader``.

    Args:
        model: Object exposing ``predict_instances(img, prob_thresh=..., nms_thresh=...)``.
        test_loader: Iterable of ``(image_batch, mask_batch)`` pairs, where
            ``image_batch`` is a tensor whose first axis indexes the samples.
            Any batch size is supported.
        prob_threshold: Object probability threshold forwarded to the model.
        nms_threshold: Non-maximum-suppression threshold forwarded to the model.

    Returns:
        list: One predicted label image per input image, in loader order.
    """
    predictions = []
    for image_batch, _ in tqdm(test_loader, desc="Predicting"):
        # Iterate over the batch axis so batches larger than one are not
        # handed to the model as if they were a single image.
        for sample in image_batch:
            img_np = sample.detach().cpu().numpy()
            if img_np.ndim == 3 and img_np.shape[-1] == 1:
                img_np = img_np[..., 0]  # drop the singleton channel axis -> (H, W)
            # csbdeep's normalize() takes percentiles (pmin, pmax); (0, 1) would
            # scale by the 0th..1st percentile. These values must match the
            # normalization the model was trained with.
            img_norm = normalize(img_np, 1, 99.8)
            labels, _details = model.predict_instances(
                img_norm, prob_thresh=prob_threshold, nms_thresh=nms_threshold
            )
            predictions.append(labels)
    return predictions

### Evaluation (Instance Segmentation)

In [None]:
def _mean_iou_from_stats(stats):
    """Extract a per-image mean IoU from a matching result.

    StarDist's ``matching`` result exposes the mean IoU of matched objects as
    ``mean_matched_score``; ``mean_iou`` / ``iou`` are honoured as well in case
    a different result object provides them.
    """
    for field in ('mean_iou', 'mean_matched_score'):
        if hasattr(stats, field):
            return float(getattr(stats, field))
    ious = getattr(stats, 'iou', None)
    if ious is not None and len(ious) > 0:
        return float(np.mean(ious))
    return 0.0


def evaluate_stardist(ground_truth_masks, predictions, iou_threshold=0.5):
    """Score predicted instance masks against ground-truth instance masks.

    Args:
        ground_truth_masks: Sequence of integer label images (0 is background).
        predictions: Sequence of predicted label images, aligned with
            ``ground_truth_masks``.
        iou_threshold: IoU above which a predicted object matches a true one.

    Returns:
        tuple: ``(mean_iou, precision, recall, f1_score, accuracy)`` where

        * ``mean_iou`` is the mean IoU of matched objects, averaged over the
          images that contain at least one true or predicted object;
        * ``precision`` / ``recall`` / ``f1_score`` are object-level scores
          computed from the TP / FP / FN counts summed over all images;
        * ``accuracy`` is the pixel-wise foreground/background agreement.

        All values are ``0.0`` for empty input.

    Raises:
        ValueError: If the two sequences differ in length.
    """
    if len(ground_truth_masks) != len(predictions):
        raise ValueError(
            f"Got {len(ground_truth_masks)} ground-truth masks but {len(predictions)} predictions."
        )

    iou_sum = 0.0
    scored_images = 0
    total_tp = 0
    total_fp = 0
    total_fn = 0
    agreeing_pixels = 0
    total_pixels = 0

    for gt_inst, pred_inst in zip(ground_truth_masks, predictions):
        gt_inst = np.asarray(gt_inst)
        pred_inst = np.asarray(pred_inst)

        gt_foreground = gt_inst > 0
        pred_foreground = pred_inst > 0
        agreeing_pixels += int(np.count_nonzero(gt_foreground == pred_foreground))
        total_pixels += gt_inst.size

        # Nothing to match when neither mask contains an object.
        if gt_foreground.any() or pred_foreground.any():
            stats = matching(gt_inst, pred_inst, thresh=iou_threshold)
            if stats is not None:
                iou_sum += _mean_iou_from_stats(stats)
                scored_images += 1
                total_tp += stats.tp
                total_fp += stats.fp
                total_fn += stats.fn

    mean_iou = iou_sum / scored_images if scored_images else 0.0

    # The small epsilon keeps the ratios defined when a denominator is zero.
    precision = total_tp / (total_tp + total_fp + 1e-8)
    recall = total_tp / (total_tp + total_fn + 1e-8)
    f1_score = 2 * precision * recall / (precision + recall + 1e-8)
    accuracy = agreeing_pixels / total_pixels if total_pixels else 0.0

    return mean_iou, precision, recall, f1_score, accuracy

### Visualization

In [None]:
def visualize_stardist_results(images, ground_truth_masks, predictions, num_samples=5, save_dir='results/stardist'):
    """Save side-by-side comparison figures for the first few samples.

    For each of the first ``num_samples`` samples (or fewer, if ``images`` is
    shorter) a three-panel figure -- input image, ground-truth mask, predicted
    mask -- is written to ``save_dir`` as ``stardist_result_<index>.png``.
    ``save_dir`` is created when it does not exist.
    """
    os.makedirs(save_dir, exist_ok=True)

    # (source sequence, colormap, panel title, squeeze singleton axes before drawing?)
    panel_specs = (
        (images, 'gray', 'Original Image', True),
        (ground_truth_masks, 'nipy_spectral', 'Ground Truth (Instance)', False),
        (predictions, 'nipy_spectral', 'StarDist Prediction (Instance)', False),
    )

    sample_count = min(num_samples, len(images))
    for sample_idx in range(sample_count):
        plt.figure(figsize=(15, 5))

        for column, (source, colormap, heading, needs_squeeze) in enumerate(panel_specs, start=1):
            plt.subplot(1, 3, column)
            panel = source[sample_idx]
            if needs_squeeze:
                panel = panel.squeeze()
            plt.imshow(panel, cmap=colormap)
            plt.title(heading)
            plt.axis('off')

        output_path = os.path.join(save_dir, f'stardist_result_{sample_idx}.png')
        plt.savefig(output_path)
        plt.close()

### Main Execution

In [None]:
if __name__ == '__main__':
    # Dataset locations, relative to the current working directory.
    train_image_dir = 'NucleiSegmentationDataset/all_images'
    train_mask_dir = 'NucleiSegmentationDataset/merged_masks'
    test_image_dir = 'TestDataset/images'
    test_mask_dir = 'TestDataset/masks'

    # Fixed seed so the train/validation split is identical on every run.
    split_seed = 42

    # Create the full training dataset
    full_train_dataset = NucleiDatasetStardist(train_image_dir, train_mask_dir, image_size=(256, 256))

    # 80 % of the samples are used for training, the remainder for validation.
    train_size = int(0.8 * len(full_train_dataset))
    val_size = len(full_train_dataset) - train_size

    # Split the training dataset into training and validation sets (reproducibly).
    train_dataset, val_dataset = random_split(
        full_train_dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(split_seed),
    )

    # Create the test dataset
    test_dataset = NucleiDatasetStardist(test_image_dir, test_mask_dir, image_size=(256, 256))

    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

    # --- INSPECTION ---
    print("\n--- Data Inspection ---")
    print(f"Training dataset size: {len(train_dataset)}")
    print(f"Validation dataset size: {len(val_dataset)}")
    print(f"Testing dataset size: {len(test_dataset)}")

    if len(train_dataset) > 0:
        first_train_img, first_train_mask = train_dataset[0]
        print(f"First training image shape: {first_train_img.shape}, min: {first_train_img.min():.4f}, max: {first_train_img.max():.4f}")
        print(f"First training mask shape: {first_train_mask.shape}, unique values: {np.unique(first_train_mask)}")

    if len(test_dataset) > 0:
        first_test_img, first_test_mask = test_dataset[0]
        print(f"First testing image shape: {first_test_img.shape}, min: {first_test_img.min():.4f}, max: {first_test_img.max():.4f}")
        print(f"First testing mask shape: {first_test_mask.shape}, unique values: {np.unique(first_test_mask)}")
    print("--- End Data Inspection ---\n")
    # --- END INSPECTION ---

    # Train the StarDist model for 15 epochs; adjust epochs / learning_rate if needed.
    stardist_model = train_stardist(train_loader, val_loader, epochs=15, learning_rate=1e-4)

    # Load every test sample once and keep images and masks for evaluation / plotting.
    test_samples = [test_dataset[i] for i in range(len(test_dataset))]
    test_images = np.array([image for image, _ in test_samples])
    test_ground_truth_masks = np.array([mask for _, mask in test_samples])

    # Predict on the test set with lowered thresholds; try other values if the
    # predictions come out empty.
    stardist_predictions = predict_stardist(stardist_model, test_loader, prob_threshold=0.1, nms_threshold=0.3)

    # --- INSPECTION OF PREDICTIONS ---
    print("\n--- Prediction Inspection ---")
    if len(stardist_predictions) > 0:
        print(f"Shape of first prediction: {stardist_predictions[0].shape}")
        print(f"Unique values in first prediction: {np.unique(stardist_predictions[0])}")
        if np.max(stardist_predictions[0]) == 0:
            print("WARNING: First prediction is entirely black (no objects detected).")
    else:
        print("No predictions generated.")
    print("--- End Prediction Inspection ---\n")
    # --- END INSPECTION ---


    # Evaluate (using mean IoU for instance segmentation)
    mean_iou, precision, recall, f1, accuracy = evaluate_stardist(test_ground_truth_masks, stardist_predictions, iou_threshold=0.5)
    print(f'Mean IoU on test set: {mean_iou:.4f}')
    print(f'Precision: {precision:.4f}, Recall: {recall:.4f}, F1-Score: {f1:.4f}, Accuracy: {accuracy:.4f}')

    # Visualize results
    visualize_stardist_results(test_images, test_ground_truth_masks, stardist_predictions, num_samples=5)

    # Optionally save the underlying Keras model in the .keras format.
    model_save_path = 'stardist_nuclei_model.keras'
    stardist_model.keras_model.save(model_save_path)
    print(f'Trained StarDist model saved to: {model_save_path}')