In [1]:
import os
import sys
import yaml
import itertools

from glob import glob
from tqdm.auto import tqdm

import torch
from torch.utils.data import Dataset, DataLoader

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Pytorch StarDist3D
sys.path.append('..')
from pytorch_stardist.data.utils import normalize
from pytorch_stardist.models.config import Config3D
from pytorch_stardist.models.stardist3d import StarDist3D
from utils import seed_all, prepare_conf

from stardist_tools.matching import matching_dataset

# Need this even when not using multiprocessing.
# Both rank variables are pinned to '0' (a single process acting as rank 0).
# NOTE(review): presumably read by the pytorch_stardist code when the model is
# built or used — confirm which component actually requires them.
os.environ["LOCAL_RANK"] = '0'
os.environ["RANK"] = '0'

  setattr(self, word, getattr(machar, word).flat[0])
  return self._float_to_str(self.smallest_subnormal)
  setattr(self, word, getattr(machar, word).flat[0])
  return self._float_to_str(self.smallest_subnormal)
__init__.py (48): Importing from timm.models.layers is deprecated, please import via timm.layers


### 1. Load Configuration and Pre-trained Model
Load the model configuration from the YAML file and instantiate the `StarDist3D` model with pre-trained weights.

In [2]:
config_file = '../confs/train_convnext_unet_base-3D.yaml'
checkpoint_file = '../model_checkpoints/convnext_unet_base-3D.pth'

# Parse the YAML configuration into a plain dict of options.
with open(config_file) as yconf:
    opt = yaml.safe_load(yconf)

Config = Config3D
StarDist = StarDist3D

conf = Config(**opt, allow_new_params=True)

# Set random seed
seed_all(conf.random_seed)

# Process the configuration variables.
# Note: `opt` is rebound here from the raw YAML dict to the prepared options.
opt = prepare_conf(conf)

# Model instantiation
model = StarDist(opt)

# Restore the pre-trained weights. `map_location` remaps the stored tensors onto
# the model's device, so a checkpoint saved on a GPU still loads on a CPU-only
# machine (without it torch.load tries to restore onto the original device).
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
state_dict = torch.load(checkpoint_file, map_location=model.device)
model.net.load_state_dict(state_dict)
model.net.to(model.device)

# This notebook only runs inference, so switch the network to evaluation mode
# (disables dropout-style layers and freezes normalisation statistics).
# The trailing ';' keeps the large module repr out of the cell output.
model.net.eval();

### 2. Define and Prepare the Dataset
We define the `BlastospimDataset` class to load images and masks. For a robust grid search, all available test sets are combined. We also load the ground truth masks into memory once to speed up the evaluation loop.

In [3]:
class BlastospimDataset(Dataset):
    """Dataset of 3D image volumes and their instance masks stored as ``.npy`` files.

    For every sample name ``<name>`` the files are expected at
    ``<source_dir>/<name>/<name>/images/<name>_image_0001.npy`` and
    ``<source_dir>/<name>/<name>/masks/<name>_masks_0001.npy``.

    Args:
        image_names: Iterable of sample names.
        source_dir: Root directory containing one folder per sample.
        divisor: Each of the first three axes is cropped (from the end) to the
            largest multiple of this value. Defaults to 32, the value that was
            previously hard-coded.

    Raises:
        ValueError: If ``divisor`` is smaller than 1.
    """

    def __init__(self, image_names, source_dir, divisor=32):
        if divisor < 1:
            raise ValueError(f'divisor must be >= 1, got {divisor}')
        self.divisor = divisor

        # Materialise once so that one-shot iterables (generators) also work.
        names = list(image_names)
        self.image_paths = [
            f'{source_dir}/{name}/{name}/images/{name}_image_0001.npy' for name in names
        ]
        self.mask_paths = [
            f'{source_dir}/{name}/{name}/masks/{name}_masks_0001.npy' for name in names
        ]

    def __len__(self):
        """Return the number of samples."""
        return len(self.image_paths)

    @staticmethod
    def _crop_to_multiple(volume, divisor):
        """Trim the end of the first three axes down to a multiple of ``divisor``."""
        d0, d1, d2 = (size - size % divisor for size in volume.shape[:3])
        return volume[:d0, :d1, :d2]

    def __getitem__(self, idx):
        """Load, crop and normalise sample ``idx``.

        Returns:
            dict with ``'image'`` (float32, a leading channel axis of size 1 is
            added) and ``'mask'`` (int16, same spatial shape as the image).

        Raises:
            ValueError: If an axis is shorter than ``divisor``, which would
                leave an empty volume after cropping.
            AssertionError: If image and mask shapes differ after cropping.
        """
        image_path = self.image_paths[idx]
        mask_path = self.mask_paths[idx]

        image = np.load(image_path)
        mask = np.load(mask_path)

        # Make image dimensions divisible by `divisor`.
        image = self._crop_to_multiple(image, self.divisor)
        mask = self._crop_to_multiple(mask, self.divisor)
        if image.size == 0 or mask.size == 0:
            # Previously an empty array was silently passed on to normalize().
            raise ValueError(
                f'{image_path}: volume is smaller than divisor={self.divisor} '
                'along at least one axis; nothing is left after cropping'
            )
        assert image.shape == mask.shape, (
            f'image {image.shape} and mask {mask.shape} shapes differ for {image_path}'
        )

        # Normalize image. The axis tuple covers all three axes of the volume,
        # i.e. the whole single-channel volume is normalised jointly.
        # NOTE(review): 1 / 99.8 are presumably the lower / upper percentiles
        # used by `normalize` — confirm against pytorch_stardist.data.utils.
        axis_norm = (0, 1, 2)
        image = normalize(image, 1, 99.8, axis=axis_norm)
        image = np.expand_dims(image, 0)  # Add channel axis for one color

        return {
            'image': image.astype(np.float32),
            # NOTE(review): int16 cannot hold label ids above 32767.
            'mask': mask.astype(np.int16)
        }

In [4]:
# Root folder that holds one sub-directory per annotated sample.
source_dir = '/mnt/ceph/users/alu10/datasets/GTSets/2023_Full_Iso-Trilinear_Image'

# Held-out sample names, grouped into five lists.
# NOTE(review): the numeric suffix presumably refers to the embryo cell stage — confirm.
testset8 = [
    'F24_001', 'F24_002', 'F24_006', 'F25_002', 'F25_008',
    'F27_010', 'F27_007', 'F27_009', 'F29_003', 'F29_004',
    'F30_004', 'F30_008', 'F30_009', 'M6_021', 'M6_012',
]
testset16 = [
    'M7_004', 'M7_000', 'F42_063', 'F41_056',
    'F34_073', 'F33_067', 'F26_008', 'F24_010',
]
testset32 = ['F8_072', 'F44_087', 'F44_089', 'F39_117']
testset64 = ['F40_136', 'F49_148']
testset128 = ['F55_185']

# Concatenate every group, in order, for a comprehensive evaluation.
all_test_names = list(
    itertools.chain(testset8, testset16, testset32, testset64, testset128)
)

test_dataset = BlastospimDataset(image_names=all_test_names, source_dir=source_dir)
# One volume per batch, fixed order, so predictions line up with `gt_masks`.
test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=1)

# Read every ground-truth mask into memory a single time; index [0] drops the
# batch axis added by the loader.
gt_masks = []
for sample in tqdm(test_dataloader, desc="Loading GT masks"):
    gt_masks.append(sample['mask'][0].numpy())

print(f'Combined test set contains {len(test_dataset)} images.')

### 3. Perform Grid Search
Now we loop through different combinations of `prob_thresh` and `nms_thresh`. For each combination, we generate instance labels for the *entire dataset* and calculate the F1-score at IoU=0.7.

**⚠️ Warning:** This method is computationally expensive. It re-runs the full prediction pipeline for every image at every point in the grid. This may take a very long time to complete.

In [5]:
# Define the grid of thresholds to search: 0.1, 0.2, ..., 0.9.
# np.arange with a float step accumulates rounding error (e.g. 0.30000000000000004)
# and does not guarantee the number of points; linspace pins the count and the
# rounding gives clean values for the results table and the heatmap labels.
prob_thresholds = np.round(np.linspace(0.1, 0.9, 9), 1)
nms_thresholds = np.round(np.linspace(0.1, 0.9, 9), 1)
iou_thresh_fixed = 0.7

results = []
grid = list(itertools.product(prob_thresholds, nms_thresholds))

print(f"Starting grid search over {len(grid)} combinations...")
for prob_thresh, nms_thresh in tqdm(grid, desc="Grid Search Progress"):
    # Plain Python floats rather than numpy scalars in the model state and records.
    prob_thresh = float(prob_thresh)
    nms_thresh = float(nms_thresh)

    # Set the model's thresholds for this iteration.
    # Note: this mutates `model`; after the loop it keeps the last grid point.
    model.thresholds['prob'] = prob_thresh
    model.thresholds['nms'] = nms_thresh

    predicted_labels = []
    # Pure inference: make sure autograd never records a graph.
    # NOTE(review): harmless if predict_instance already disables gradients.
    with torch.no_grad():
        # Loop through the dataset and predict for each image
        for batch in test_dataloader:
            image = batch['image'][0].numpy()  # drop the batch axis
            labels, _ = model.predict_instance(image, patch_size=[256, 256, 256], context=[64, 64, 64])
            predicted_labels.append(labels)

    # Evaluate this set of predictions against all ground truth masks
    # (same order as `gt_masks`, since the loader does not shuffle).
    stats = matching_dataset(gt_masks, predicted_labels, thresh=iou_thresh_fixed, show_progress=False)

    results.append({
        'prob_thresh': prob_thresh,
        'nms_thresh': nms_thresh,
        'iou_thresh': iou_thresh_fixed,
        'f1': stats.f1,
        'precision': stats.precision,
        'recall': stats.recall
    })

# NOTE(review): every grid point re-reads the volumes and re-runs the full
# prediction; if the raw network output is independent of the thresholds it
# could be cached once per image — confirm against the StarDist3D API.
results_df = pd.DataFrame(results)

### 4. Analyze and Visualize Results
First, we find and print the best combination of thresholds that maximized the F1-score. Then, we create a heatmap to visualize the performance across the entire grid.

In [6]:
# Fail with a clear message instead of a cryptic idxmax error when the grid
# search produced no rows or no valid score.
if results_df.empty or results_df['f1'].isna().all():
    raise ValueError("Grid search produced no valid F1 scores; nothing to analyze.")

# Find the best result (first row reaching the maximum F1)
best_result = results_df.loc[results_df['f1'].idxmax()]

print("--- Grid Search Complete ---")
print(f"Best F1-Score: {best_result['f1']:.4f}")
print(f"Optimal prob_thresh: {best_result['prob_thresh']:.2f}")
print(f"Optimal nms_thresh: {best_result['nms_thresh']:.2f}")
print(f"Precision at best F1: {best_result['precision']:.4f}")
print(f"Recall at best F1: {best_result['recall']:.4f}")

# Pivot the data for the heatmap.
# The threshold values become the tick labels, so strip floating-point noise
# (e.g. 0.30000000000000004) first. Six decimals removes the noise without
# merging distinct grid points; `results_df` itself is left untouched.
heatmap_df = results_df.assign(
    prob_thresh=results_df['prob_thresh'].round(6),
    nms_thresh=results_df['nms_thresh'].round(6),
)
f1_pivot = heatmap_df.pivot(index='prob_thresh', columns='nms_thresh', values='f1')

# Plot the heatmap on an explicit Axes
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(f1_pivot, annot=True, fmt=".3f", cmap="viridis", cbar_kws={'label': 'F1-Score'}, ax=ax)
ax.set_title(f'F1-Score Grid Search (IoU Threshold = {iou_thresh_fixed})')
ax.set_xlabel('NMS Threshold')
ax.set_ylabel('Probability Threshold')
ax.invert_yaxis()  # lowest probability threshold at the bottom
plt.show()