In [1]:
import warnings
import torch
from matplotlib import pyplot as plt
from skimage.measure import find_contours
from torch.cuda import amp
from torch.nn import functional as F
from tqdm import tqdm
import sys
import os
import random
import cv2
import numpy as np
from skimage import filters, morphology

sys.path.append('./architectures/')
from architectures.lr_aspp import LiteRASPP

sys.path.append('./datasets/')
from datasets import phantom_dataset
from datasets.cvc_dataset import CVC_Dataset
from datasets.real_dataset import RealSegDataset
from datasets.broncho_dataset import Broncho_Dataset

import utils
import json

# Import the new model and utilities
sys.path.append('../')
from depth_anything_v2.dpt import DepthAnythingV2
from util import ZSegmentationExtractor

class DatasetEvaluator:
    def __init__(self, dataset_type, img_resolution=128, gray_scaled=True):
        """
        dataset_type: str - one of 'cvc', 'real', 'phantom', 'broncho'
        """
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.spatial_dim = img_resolution
        self.gray_scaled = gray_scaled
        self.model = self._setup_model()
        self.dataset = self._setup_dataset(dataset_type)
        self.dataset_type = dataset_type
        
    def _setup_model(self):
        model_configs = {
            'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
            'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
            'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
            'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
        }
        encoder = 'vitl'
        model = DepthAnythingV2(**model_configs[encoder])
        model.load_state_dict(torch.load(f'../depth_anything_v2_{encoder}.pth', map_location=self.device, weights_only=True))
        return model.to(self.device).eval()
    
    def _setup_dataset(self, dataset_type):
        dataset_map = {
            'cvc': lambda: CVC_Dataset(self.spatial_dim, normalize=False, gray_scale=self.gray_scaled),
            'real': lambda: RealSegDataset(self.spatial_dim, normalize=False, gray_scale=self.gray_scaled),
            'phantom': lambda: phantom_dataset.PhantomSegDataset('blob', 'test', self.spatial_dim, normalize=False, gray_scaled=self.gray_scaled),
            'broncho': lambda: Broncho_Dataset(self.spatial_dim, normalize=False, gray_scale=self.gray_scaled)
        }
        
        if dataset_type not in dataset_map:
            raise ValueError(f"Dataset type must be one of {list(dataset_map.keys())}")
            
        return dataset_map[dataset_type]()

    @staticmethod
    def show_points(coords, labels, ax, marker_size=100):
        cmap = plt.get_cmap('tab10')
        unique_labels = np.unique(labels)
        for i, label in enumerate(unique_labels):
            color = cmap(i % 10)
            points = coords[labels == label]
            ax.scatter(points[:, 1], points[:, 0], color=color, marker='o', s=marker_size, edgecolor='white', linewidth=1.25)

    @staticmethod
    def get_filtered(image, cutoffs, squared_butterworth=False, order=0.0, npad=32):
        highpass_filtered = []
        for cutoff in cutoffs:
            highpass_filtered.append(
                filters.butterworth(
                    image,
                    cutoff_frequency_ratio=cutoff,
                    order=order,
                    high_pass=True,
                    squared_butterworth=squared_butterworth,
                    npad=npad,
                )
            )
        return highpass_filtered

    def evaluate(self, sample_interval=1):
        # Set up sampling for different datasets
        if self.dataset_type == 'phantom':
            indices = list(range(0, len(self.dataset), 30))
        elif self.dataset_type == 'broncho':
            indices = list(range(0, len(self.dataset), 5))
        else:
            indices = list(range(0, len(self.dataset), sample_interval))

        num_diff = torch.empty(len(indices), dtype=torch.int32)
        dice = torch.empty_like(num_diff, dtype=torch.float)
        c_dist = torch.empty_like(dice)

        # Set up specific parameters for each dataset
        params = {
            'cvc': {'corner_margin': 25, 'edge_margin': 7, 'intensity_threshold': 0.5},
            'real': {'corner_margin': 25, 'edge_margin': 7, 'intensity_threshold': 0.5},
            'phantom': {'corner_margin': 25, 'edge_margin': 7, 'intensity_threshold': 0.5},
            'broncho': {'corner_margin': 25, 'edge_margin': 7, 'intensity_threshold': 0.5}
        }

        current_params = params[self.dataset_type]

        for idx, i in enumerate(tqdm(indices)):
            x, y = self.dataset[i]
            with amp.autocast(), torch.no_grad():
                # Preprocess the input image
                raw_img = x.numpy()
                resized_img = cv2.resize(raw_img.transpose(1, 2, 0), (128, 128))
                
                # Infer depth and perform segmentation
                depth = self.model.infer_image(resized_img)
                max_depth = depth.max()
                inverted_depth = max_depth - depth
                
                # Apply high-pass filter to inverted depth
                cutoffs = [0.0126, 0.05, 0.32]
                highpass_depths = self.get_filtered(inverted_depth, cutoffs, squared_butterworth=False, order=5, npad=211)
                high_pass_depth = highpass_depths[0]
                
                depth_tensor = torch.from_numpy(high_pass_depth)

                # Initialize segmentation extractor with dataset-specific parameters
                segmentation_extractor = ZSegmentationExtractor(
                    self.spatial_dim, 
                    watershed_compactness=3, 
                    avg_pool_kernel_size=1, 
                    **current_params
                )

                # Extract predesegmentation
                d = segmentation_extractor.extract_segmentation(depth_tensor, rgb_img=resized_img, return_plot_data=True)
                seg_mask = d['seg_mask']
                label_y_hat = d['z_labels']
                c_y_hat = d['airway_centroids']
                num_y_hat = len(d['z_labels'])

                n_points = len(c_y_hat)
                labels = np.arange(1, n_points + 1)

                # Convert instance segmentation to binary
                small_object_size = 100
                
                y_hat_cleaned = (seg_mask > 0).squeeze(0).cpu().numpy().astype(np.uint8)
                y_hat_cleaned = morphology.remove_small_objects(y_hat_cleaned.astype(bool), min_size=small_object_size)
                y_hat_cleaned = morphology.binary_opening(y_hat_cleaned, morphology.disk(3))
                y_hat = y_hat_cleaned

            y[y == -1] = 0
            y = y.squeeze(0)

            # Plot the results
            self.plot_results(resized_img, high_pass_depth, y, y_hat, c_y_hat, labels, i)

            assert torch.all(torch.isin(y.to(int), torch.tensor([0, 1])))

            # Extract airways using ZSegmentationExtractor
            label_y, c_y, num_y = utils.extract_airways(y.to(torch.bool).numpy())

            # Calculate metrics
            if num_y > 0:
                dice[idx] = utils.dice_coeff(torch.from_numpy(y_hat), y, max_label=1)
            else:
                dice[idx] = -1

            if num_y > 0 and num_y_hat > 0:
                if isinstance(self.dataset, phantom_dataset.PhantomSegDataset) and False:  # in_mm is always False in original code
                    # Handle phantom dataset specific calculations (keeping original logic)
                    z = phantom_dataset.PhantomSegDataset('z', 'test', self.spatial_dim).gt
                    z_y = torch.empty(num_y)
                    contours = find_contours(y.numpy(), fully_connected='high')

                    assert len(contours) == num_y
                    for c_idx in range(len(contours)):
                        z_contour = utils.sample_from_image(z[i], torch.from_numpy(contours[c_idx]).flip(-1))
                        z_y[c_idx] = z_contour.median()

                    c_y = phantom_dataset.pixel_to_normalized_world_coordinates(torch.as_tensor(c_y), [self.spatial_dim, self.spatial_dim])
                    c_y_hat = phantom_dataset.pixel_to_normalized_world_coordinates(torch.as_tensor(c_y_hat), [self.spatial_dim, self.spatial_dim])

                    d = c_y.unsqueeze(1) - c_y_hat.unsqueeze(0)
                    d = d.norm(p=2, dim=-1).min(1)[0]
                    c_dist[idx] = (d * z_y * phantom_dataset.Z_MAX_DEPTH).mean()
                else:
                    try:
                        c_y_tensor = torch.as_tensor(c_y)
                        c_y_hat_tensor = torch.as_tensor(c_y_hat)

                        # Check for empty tensors
                        if c_y_tensor.numel() == 0 or c_y_hat_tensor.numel() == 0:
                            c_dist[idx] = -1
                            continue

                        # Ensure we have 2D coordinates
                        if c_y_tensor.size(-1) != c_y_hat_tensor.size(-1):
                            c_y_tensor = c_y_tensor[:, :2]
                            c_y_hat_tensor = c_y_hat_tensor[:, :2]

                        # Additional shape check
                        if c_y_tensor.size(0) == 0 or c_y_hat_tensor.size(0) == 0:
                            c_dist[idx] = -1
                            continue

                        # Calculate distances
                        d = c_y_tensor.unsqueeze(1) - c_y_hat_tensor.unsqueeze(0)  # nxmx2
                        if d.size(0) > 0 and d.size(1) > 0:
                            d = d.norm(p=2, dim=-1)
                            if d.size(1) > 0:
                                d = d.min(1)[0]
                                c_dist[idx] = d.mean()
                            else:
                                c_dist[idx] = -1
                        else:
                            c_dist[idx] = -1
                    except Exception as e:
                        print(f"Warning: Error calculating centroid distance for index {idx}: {e}")
                        c_dist[idx] = -1
            else:
                c_dist[idx] = -1

            num_diff[idx] = num_y_hat - num_y

        # Print metrics
        self.print_metrics(dice, c_dist, num_diff)
        
        # Plot final statistics
        self.plot_statistics(dice, c_dist, num_diff)

    def plot_results(self, raw_img, depth_map, ground_truth, pred_mask, centroids, labels, index):
        fig, axes = plt.subplots(1, 4, figsize=(8, 8))

        axes[0].imshow(raw_img)
        axes[0].set_title(f'Raw Image {index}')
        axes[0].axis('off')

        axes[1].imshow(depth_map, cmap='viridis')
        axes[1].set_title(f'Inverted Depth {index}')
        self.show_points(centroids, labels, plt.gca())
        axes[1].axis('off')

        axes[2].imshow(ground_truth, cmap='gray')
        axes[2].set_title(f'Ground Truth {index}')
        axes[2].axis('off')

        axes[3].imshow(pred_mask, cmap='gray')
        axes[3].set_title(f'Predicted Mask {index}')
        self.show_points(centroids, labels, plt.gca())
        axes[3].axis('off')

        plt.tight_layout()
        plt.show()

    def print_metrics(self, dice, c_dist, num_diff):
        # Filter out invalid measurements
        dice = dice[dice != -1]
        c_dist = c_dist[c_dist != -1]

        print('Dataset:', self.dataset)
        print(f'num diff centroids: {num_diff.float().mean().item()} +- {num_diff.float().std().item()}')
        print(f'DSC: {dice.mean().item()} +- {dice.std().item()}, median: {dice.median().item()}')
        print(f'D_c[px]: {c_dist.mean().item()} +- {c_dist.std().item()}, median: {c_dist.median().item()}')

    def plot_statistics(self, dice, c_dist, num_diff):
        plt.hist(num_diff.numpy(), bins='auto')
        plt.xlabel('|C_predicted|-|C_ground truth|')
        plt.figure()
        
        plt.boxplot(dice.numpy(), showfliers=False, notch=True, showmeans=True)
        plt.title('Dice mean')
        plt.figure()
        
        plt.boxplot(c_dist.numpy(), showfliers=False, notch=True, showmeans=True)
        plt.title('Mean centroid distance')
        plt.ylabel('distance [px]')
        
        plt.show()


  check_for_updates()
xFormers not available
xFormers not available


### CVC DS

In [2]:
# evaluator = DatasetEvaluator('cvc')
# evaluator.evaluate()

### Hauser DS

In [3]:
# evaluator = DatasetEvaluator('real')
# evaluator.evaluate()

### Phantom DS

In [4]:
# evaluator = DatasetEvaluator('phantom')
# evaluator.evaluate()

### BronchoLC DS

In [5]:
# evaluator = DatasetEvaluator('broncho')
# evaluator.evaluate()

In [6]:
import warnings
warnings.filterwarnings("ignore")  # Suppress all warnings

# Temporarily modify both plotting methods to do nothing
def no_plot(self, *args, **kwargs):
    pass

def no_statistics_plot(self, dice, c_dist, num_diff):
    # Just print the statistics without plotting
    print("\nStatistics Summary:")
    print(f"Number of differences - Mean: {num_diff.float().mean().item():.2f}, Std: {num_diff.float().std().item():.2f}")
    print(f"Dice Score - Mean: {dice[dice != -1].mean().item():.2f}, Std: {dice[dice != -1].std().item():.2f}")
    print(f"Centroid Distance - Mean: {c_dist[c_dist != -1].mean().item():.2f}, Std: {c_dist[c_dist != -1].std().item():.2f}")

# Store the original methods
original_plot_results = DatasetEvaluator.plot_results
original_plot_statistics = DatasetEvaluator.plot_statistics

# Replace the methods
DatasetEvaluator.plot_results = no_plot
DatasetEvaluator.plot_statistics = no_statistics_plot

# Run evaluations
datasets = ['cvc', 'real', 'phantom', 'broncho']

for dataset_type in datasets:
    print(f"\n{'='*50}")
    print(f"Evaluating {dataset_type.upper()} Dataset")
    print('='*50)
    
    try:
        evaluator = DatasetEvaluator(dataset_type)
        evaluator.evaluate()
    except Exception as e:
        print(f"Error evaluating {dataset_type} dataset: {str(e)}")
        continue
    
    print()  # Add blank line between datasets

# Restore the original methods
DatasetEvaluator.plot_results = original_plot_results
DatasetEvaluator.plot_statistics = original_plot_statistics


Evaluating CVC Dataset


100%|██████████| 125/125 [00:18<00:00,  6.71it/s]


Dataset: <datasets.cvc_dataset.CVC_Dataset object at 0x7f16fc5629d0>
num diff centroids: 12801.8876953125 +- 0.6119245886802673
DSC: 0.6239727139472961 +- 0.2181776463985443, median: 0.6757317185401917
D_c[px]: 13.935718536376953 +- 7.137856483459473, median: 12.693500518798828

Statistics Summary:
Number of differences - Mean: 12801.89, Std: 0.61
Dice Score - Mean: 0.62, Std: 0.22
Centroid Distance - Mean: 13.94, Std: 7.14


Evaluating REAL Dataset


100%|██████████| 107/107 [00:17<00:00,  6.09it/s]


Dataset: <datasets.real_dataset.RealSegDataset object at 0x7f16fc560e10>
num diff centroids: 12681.6357421875 +- 1237.546142578125
DSC: 0.4214887320995331 +- 0.2711307108402252, median: 0.43695518374443054
D_c[px]: 19.391002655029297 +- 14.375557899475098, median: 14.708398818969727

Statistics Summary:
Number of differences - Mean: 12681.64, Std: 1237.55
Dice Score - Mean: 0.42, Std: 0.27
Centroid Distance - Mean: 19.39, Std: 14.38


Evaluating PHANTOM Dataset


load data: 100%|██████████| 2/2 [00:00<00:00, 21.63it/s]
100%|██████████| 173/173 [00:28<00:00,  6.08it/s]


Dataset: <datasets.phantom_dataset.PhantomSegDataset object at 0x7f170437a790>
num diff centroids: 12526.3125 +- 1951.8133544921875
DSC: 0.5992451906204224 +- 0.19701744616031647, median: 0.6424276828765869
D_c[px]: 16.537973403930664 +- 10.855913162231445, median: 13.857728004455566

Statistics Summary:
Number of differences - Mean: 12526.31, Std: 1951.81
Dice Score - Mean: 0.60, Std: 0.20
Centroid Distance - Mean: 16.54, Std: 10.86


Evaluating BRONCHO Dataset


100%|██████████| 239/239 [00:37<00:00,  6.43it/s]

Dataset: <datasets.broncho_dataset.Broncho_Dataset object at 0x7f15f73437d0>
num diff centroids: 12748.0791015625 +- 828.069580078125
DSC: 0.5151475071907043 +- 0.17124058306217194, median: 0.5144803524017334
D_c[px]: 19.54058074951172 +- 11.895088195800781, median: 16.930543899536133

Statistics Summary:
Number of differences - Mean: 12748.08, Std: 828.07
Dice Score - Mean: 0.52, Std: 0.17
Centroid Distance - Mean: 19.54, Std: 11.90




