<a href="https://colab.research.google.com/github/MarvinEhab/medical-image-segmention/blob/main/medsam.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Enhanced Google Colab NIFTI Data Reader with MedSAM AI Segmentation
# Using Segment Anything Model adapted for Medical Imaging

# Install required packages
!pip install nibabel matplotlib seaborn plotly
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install segment-anything
!pip install git+https://github.com/facebookresearch/segment-anything.git
!pip install SimpleITK
!pip install scikit-image
!pip install opencv-python
!pip install timm

import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import pandas as pd
from google.colab import files
import io
import os
import torch
import torch.nn.functional as F
from torchvision import transforms
import SimpleITK as sitk
from skimage import measure, morphology
from skimage.segmentation import watershed
from scipy import ndimage
from scipy.ndimage import zoom, gaussian_filter
import cv2
import warnings
warnings.filterwarnings('ignore')

# Try to import SAM. SAM_AVAILABLE records whether the optional dependency
# could be imported; load_medsam_model() checks it before touching SAM.
try:
    from segment_anything import sam_model_registry, SamPredictor
    SAM_AVAILABLE = True
except Exception:
    # Catch Exception instead of a bare ``except`` so that KeyboardInterrupt
    # and SystemExit still propagate, while any import-time failure of the
    # optional package still degrades to the traditional methods.
    print("SAM not available, will use traditional methods")
    SAM_AVAILABLE = False

class MedSAMOrganSegmenter:
    """AI-powered organ segmentation using MedSAM (Medical Segment Anything Model)"""

    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {self.device}")
        self.models = {}
        self.predictor = None
        self.organ_configs = {
            'liver': {
                'parts': ['parenchyma', 'vessels', 'lesions'],
                'colors': ['Reds', 'Blues', 'Oranges'],
                'thresholds': [0.3, 0.4, 0.5],
                'prompt_boxes': [  # Relative coordinates (x1, y1, x2, y2) from 0 to 1
                    (0.3, 0.3, 0.7, 0.7),
                    (0.35, 0.35, 0.65, 0.65),
                    (0.4, 0.4, 0.6, 0.6)
                ]
            },
            'kidney': {
                'parts': ['cortex', 'medulla', 'pelvis'],
                'colors': ['Greens', 'Oranges', 'Purples'],
                'thresholds': [0.35, 0.4, 0.45],
                'prompt_boxes': [
                    (0.25, 0.25, 0.75, 0.75),
                    (0.35, 0.35, 0.65, 0.65),
                    (0.4, 0.4, 0.6, 0.6)
                ]
            },
            'heart': {
                'parts': ['left_ventricle', 'right_ventricle', 'myocardium'],
                'colors': ['Reds', 'Blues', 'RdPu'],
                'thresholds': [0.25, 0.2, 0.55],
                'prompt_boxes': [
                    (0.3, 0.35, 0.55, 0.65),  # LV (left side)
                    (0.45, 0.35, 0.7, 0.65),  # RV (right side)
                    (0.25, 0.3, 0.75, 0.7)    # Myocardium (full heart)
                ]
            },
            'brain': {
                'parts': ['gray_matter', 'white_matter', 'csf'],
                'colors': ['Greys', 'bone', 'Blues'],
                'thresholds': [0.3, 0.3, 0.4],
                'prompt_boxes': [
                    (0.2, 0.2, 0.8, 0.8),
                    (0.3, 0.3, 0.7, 0.7),
                    (0.35, 0.35, 0.65, 0.65)
                ]
            }
        }

    def load_medsam_model(self, model_type='vit_b'):
        """
        Load a SAM checkpoint and create the predictor used for segmentation.

        The checkpoint is downloaded on first use to ``sam_<model_type>.pth``
        in the current directory. The URLs below point at the
        ``segment_anything`` checkpoints.

        Args:
            model_type: 'vit_b' (base), 'vit_l' (large), or 'vit_h' (huge).

        Returns:
            True when the model was loaded and ``self.predictor`` is ready;
            False when SAM is not installed or any step failed (the error is
            printed, not raised).
        """
        if not SAM_AVAILABLE:
            print("SAM not available. Please install: pip install segment-anything")
            return False

        try:
            print(f"Loading MedSAM model: {model_type}")

            # Download SAM checkpoint
            checkpoint_urls = {
                'vit_b': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth',
                'vit_l': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
                'vit_h': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'
            }

            if model_type not in checkpoint_urls:
                # Report a readable error instead of a bare KeyError message.
                raise ValueError(
                    f"Unknown model_type {model_type!r}; expected one of {sorted(checkpoint_urls)}"
                )

            checkpoint_path = f'sam_{model_type}.pth'

            if not os.path.exists(checkpoint_path):
                print(f"Downloading {model_type} checkpoint...")
                import urllib.request
                # Download to a temporary name and rename on success, so an
                # interrupted download never leaves a truncated file at
                # checkpoint_path that later calls would treat as complete.
                partial_path = checkpoint_path + '.part'
                try:
                    urllib.request.urlretrieve(checkpoint_urls[model_type], partial_path)
                    os.replace(partial_path, checkpoint_path)
                finally:
                    if os.path.exists(partial_path):
                        os.remove(partial_path)
                print("Download complete!")

            # Load SAM model
            sam = sam_model_registry[model_type](checkpoint=checkpoint_path)
            sam = sam.to(self.device)
            sam.eval()

            # Create predictor
            self.predictor = SamPredictor(sam)
            self.models['medsam'] = sam

            print(f"MedSAM {model_type} loaded successfully!")
            return True

        except Exception as e:
            print(f"Error loading MedSAM: {str(e)}")
            print("Falling back to traditional methods")
            return False

    def segment_with_medsam(self, data, organ_type='liver'):
        """Segment a volume by prompting SAM on key slices and blending with thresholds.

        Steps, as implemented below:
          1. Load the predictor on demand; if loading fails, return the
             result of ``segment_with_traditional_methods`` directly.
          2. Run SAM with the organ's (up to three) prompt boxes on up to five
             slices around the middle of the third axis.
          3. Copy each predicted mask to the remaining slices, scaled by
             ``exp(-distance / 5)`` to the nearest predicted slice.
          4. Blend 70% SAM with 30% traditional segmentation, Gaussian-smooth,
             recompute the background channel and clip to [0, 1].

        Args:
            data: 3-D volume, or a 4-D array of which index 0 along the last
                axis is used.
            organ_type: key of ``self.organ_configs``. Unknown keys fall back
                to the 'liver' configuration, matching
                ``segment_with_traditional_methods``.

        Returns:
            Array of shape ``(4,) + volume_shape``: channel 0 is background,
            channels 1-3 are the organ parts.
        """
        print(f"Using MedSAM AI segmentation for {organ_type}")

        if self.predictor is None:
            print("MedSAM not loaded. Loading now...")
            success = self.load_medsam_model('vit_b')
            if not success:
                return self.segment_with_traditional_methods(data, organ_type)

        if len(data.shape) == 4:
            data = data[:, :, :, 0]

        original_shape = data.shape
        segmentations = np.zeros((4,) + original_shape)

        # Same fallback as the traditional path, so an organ name that the
        # traditional method accepts does not raise KeyError here.
        config = self.organ_configs.get(organ_type, self.organ_configs['liver'])
        prompt_boxes = config['prompt_boxes']

        # Normalize data (epsilon keeps a constant volume from dividing by zero)
        data_norm = (data - np.min(data)) / (np.max(data) - np.min(data) + 1e-8)

        # Process key slices with MedSAM
        print("   Processing slices with MedSAM...")
        mid_z = original_shape[2] // 2
        key_slices = [
            mid_z - 10, mid_z - 5, mid_z, mid_z + 5, mid_z + 10
        ]
        key_slices = [z for z in key_slices if 0 <= z < original_shape[2]]

        slice_predictions = {i: [] for i in range(3)}

        for z_idx in key_slices:
            try:
                # Extract slice
                slice_2d = data_norm[:, :, z_idx]

                # Convert to 3-channel RGB
                slice_rgb = np.stack([slice_2d] * 3, axis=-1)
                slice_rgb = (slice_rgb * 255).astype(np.uint8)

                # Resize so the longest side is 1024, keeping the aspect ratio
                h, w = slice_rgb.shape[:2]
                target_size = 1024
                scale = target_size / max(h, w)
                new_h, new_w = int(h * scale), int(w * scale)
                slice_resized = cv2.resize(slice_rgb, (new_w, new_h), interpolation=cv2.INTER_LINEAR)

                # Pad to square, centring the resized slice on the canvas
                slice_padded = np.zeros((target_size, target_size, 3), dtype=np.uint8)
                y_offset = (target_size - new_h) // 2
                x_offset = (target_size - new_w) // 2
                slice_padded[y_offset:y_offset+new_h, x_offset:x_offset+new_w] = slice_resized

                # Set image for SAM
                self.predictor.set_image(slice_padded)

                # Segment each part using prompt boxes
                for part_idx, box_coords in enumerate(prompt_boxes[:3]):
                    # The boxes are relative to the image, so map them into the
                    # region of the padded canvas occupied by the slice. For a
                    # square slice this equals scaling by target_size; for a
                    # non-square slice it keeps the box off the zero padding.
                    x1 = x_offset + int(box_coords[0] * new_w)
                    y1 = y_offset + int(box_coords[1] * new_h)
                    x2 = x_offset + int(box_coords[2] * new_w)
                    y2 = y_offset + int(box_coords[3] * new_h)

                    box = np.array([x1, y1, x2, y2])

                    # Predict with SAM
                    masks, scores, _ = self.predictor.predict(
                        box=box,
                        multimask_output=True
                    )

                    # Select best mask
                    best_mask_idx = np.argmax(scores)
                    mask = masks[best_mask_idx]

                    # Extract relevant region and resize back
                    mask_region = mask[y_offset:y_offset+new_h, x_offset:x_offset+new_w]
                    mask_original = cv2.resize(
                        mask_region.astype(np.uint8),
                        (w, h),
                        interpolation=cv2.INTER_LINEAR
                    ).astype(np.float32)

                    slice_predictions[part_idx].append((z_idx, mask_original))

            except Exception as e:
                # Best effort: a failing slice is reported and skipped.
                print(f"   Error on slice {z_idx}: {e}")
                continue

        # Combine SAM predictions with traditional methods
        print("   Combining MedSAM with traditional segmentation...")
        traditional_seg = self.segment_with_traditional_methods(data, organ_type)

        if any(slice_predictions.values()):
            # Spread SAM predictions across the volume
            for part_idx in range(3):
                predictions = slice_predictions[part_idx]
                if predictions:
                    # Fill in predicted slices
                    for z_idx, mask in predictions:
                        segmentations[part_idx+1, :, :, z_idx] = mask

                    # Fill every still-empty slice from its nearest predicted slice
                    predicted_slices = sorted([z for z, _ in predictions])
                    for z in range(original_shape[2]):
                        if segmentations[part_idx+1, :, :, z].sum() == 0:
                            # Find nearest predicted slices
                            distances = [abs(z - pz) for pz in predicted_slices]
                            if distances:
                                nearest_idx = np.argmin(distances)
                                nearest_z = predicted_slices[nearest_idx]

                                # Copy of the nearest mask, attenuated with distance
                                decay = np.exp(-distances[nearest_idx] / 5.0)
                                segmentations[part_idx+1, :, :, z] = (
                                    segmentations[part_idx+1, :, :, nearest_z] * decay
                                )

            # Blend SAM (70%) with traditional (30%) for robustness
            alpha = 0.7
            for i in range(1, 4):
                segmentations[i] = alpha * segmentations[i] + (1 - alpha) * traditional_seg[i]
        else:
            # Use traditional if SAM failed
            segmentations = traditional_seg

        # Apply smoothing
        print("   Applying final smoothing...")
        for i in range(1, 4):
            segmentations[i] = gaussian_filter(segmentations[i], sigma=1.0)

        # Background is whatever the three parts leave over
        segmentations[0] = 1.0 - np.sum(segmentations[1:], axis=0)
        segmentations = np.clip(segmentations, 0, 1)

        return segmentations

    def segment_with_traditional_methods(self, data, organ_type='liver'):
        """Optimized segmentation with smooth upsampling"""
        print(f"Using traditional segmentation methods for {organ_type}")

        if len(data.shape) == 4:
            data = data[:, :, :, 0]

        original_shape = data.shape
        if np.prod(data.shape) > 10**7:
            print("   Downsampling large volume for faster processing...")
            downsample_factor = int(np.ceil((np.prod(data.shape) / 5**6)**(1/3)))
            data_small = data[::downsample_factor, ::downsample_factor, ::downsample_factor]
            print(f"   Reduced from {original_shape} to {data_small.shape}")
        else:
            data_small = data
            downsample_factor = 1

        data_small = data_small.astype(np.float32)
        if np.max(data_small) > 0:
            data_small = (data_small - np.min(data_small)) / (np.max(data_small) - np.min(data_small))

        config = self.organ_configs.get(organ_type, self.organ_configs['liver'])
        segmentations = np.zeros((4,) + data_small.shape)

        print("   Applying intensity-based segmentation...")

        if organ_type == 'brain':
            brain_mask = data_small > 0.1
            gray_matter = (data_small > 0.4) & (data_small < 0.8) & brain_mask
            segmentations[1] = gray_matter.astype(np.float32)
            white_matter = (data_small > 0.6) & brain_mask
            segmentations[2] = white_matter.astype(np.float32)
            csf = (data_small > 0.1) & (data_small < 0.3) & brain_mask
            segmentations[3] = csf.astype(np.float32)

        elif organ_type == 'heart':
            print("   Using specialized cardiac segmentation...")
            cardiac_mask = data_small > 0.2
            blood_pool = (data_small > 0.15) & (data_small < 0.55) & cardiac_mask
            blood_pool = morphology.binary_opening(blood_pool, morphology.ball(2))
            blood_pool = morphology.binary_closing(blood_pool, morphology.ball(3))

            labeled = measure.label(blood_pool)
            regions = measure.regionprops(labeled)

            if len(regions) >= 2:
                regions_sorted = sorted(regions, key=lambda r: r.area, reverse=True)
                lv_label = regions_sorted[0].label
                lv_mask = (labeled == lv_label).astype(np.float32)
                rv_label = regions_sorted[1].label
                rv_mask = (labeled == rv_label).astype(np.float32)
                lv_mask = morphology.binary_dilation(lv_mask, morphology.ball(1)).astype(np.float32)
                rv_mask = morphology.binary_dilation(rv_mask, morphology.ball(1)).astype(np.float32)
                segmentations[1] = lv_mask
                segmentations[2] = rv_mask
                print(f"      - Found LV with {regions_sorted[0].area} voxels")
                print(f"      - Found RV with {regions_sorted[1].area} voxels")
            elif len(regions) == 1:
                print("      - Only one chamber detected, splitting spatially...")
                single_mask = blood_pool.astype(np.float32)
                com = ndimage.center_of_mass(single_mask)
                x_center = int(com[0])
                lv_mask = single_mask.copy()
                rv_mask = single_mask.copy()
                lv_mask[x_center:, :, :] *= 0.3
                rv_mask[:x_center, :, :] *= 0.3
                segmentations[1] = lv_mask
                segmentations[2] = rv_mask
            else:
                print("      - No distinct chambers found, using intensity thresholds...")
                segmentations[1] = ((data_small > 0.2) & (data_small < 0.5)).astype(np.float32)
                segmentations[2] = ((data_small > 0.15) & (data_small < 0.35)).astype(np.float32)

            myocardium = (data_small > 0.5) & cardiac_mask
            myocardium = morphology.binary_opening(myocardium, morphology.ball(1))
            myocardium = myocardium & ~(segmentations[1] > 0.3) & ~(segmentations[2] > 0.3)
            segmentations[3] = myocardium.astype(np.float32)
            print(f"      - Myocardium segmented")

        else:
            for i, threshold in enumerate(config['thresholds']):
                mask = data_small > threshold
                if mask.any():
                    mask = morphology.binary_opening(mask, morphology.ball(1))
                    mask = morphology.binary_closing(mask, morphology.ball(1))
                segmentations[i+1] = mask.astype(np.float32)

        if organ_type != 'brain' and organ_type != 'heart' and np.prod(data_small.shape) < 10**6:
            print("   Applying watershed refinement...")
            try:
                threshold = np.percentile(data_small[data_small > 0], 85)
                markers = measure.label(data_small > threshold)
                if np.max(markers) > 0 and np.max(markers) < 1000:
                    watershed_result = watershed(-data_small, markers, mask=data_small > np.mean(data_small))
                    unique_labels = np.unique(watershed_result)[1:]
                    if len(unique_labels) > 3:
                        for i in range(3):
                            label_subset = unique_labels[i::3]
                            for label in label_subset:
                                mask = (watershed_result == label).astype(np.float32)
                                segmentations[i+1] = np.maximum(segmentations[i+1], mask * 0.3)
            except Exception as e:
                print(f"   Watershed skipped due to: {e}")

        for i in range(1, 4):
            for j in range(i+1, 4):
                overlap = segmentations[i] * segmentations[j]
                segmentations[i] -= overlap * 0.5
                segmentations[j] -= overlap * 0.5

        segmentations[0] = 1.0 - np.sum(segmentations[1:], axis=0)
        segmentations[0] = np.clip(segmentations[0], 0, 1)

        if downsample_factor > 1:
            print("   Upsampling with enhanced smooth interpolation...")
            segmentations_full = np.zeros((4,) + original_shape)
            for i in range(4):
                upsampled = zoom(segmentations[i], downsample_factor, order=3)
                segmentations_full[i] = upsampled[:original_shape[0], :original_shape[1], :original_shape[2]]

            print("   Applying multi-stage smoothing...")
            for i in range(1, 4):
                segmentations_full[i] = gaussian_filter(segmentations_full[i], sigma=1.5)
                if np.any(segmentations_full[i] > 0.3):
                    binary_mask = segmentations_full[i] > 0.5
                    smoothed_mask = morphology.binary_closing(binary_mask, morphology.ball(2))
                    smoothed_mask = morphology.binary_opening(smoothed_mask, morphology.ball(1))
                    segmentations_full[i] = 0.7 * segmentations_full[i] + 0.3 * smoothed_mask.astype(np.float32)
                segmentations_full[i] = gaussian_filter(segmentations_full[i], sigma=0.5)

            for i in range(4):
                segmentations_full[i] = np.clip(segmentations_full[i], 0, 1)

            return segmentations_full

        return segmentations

    def segment_organ(self, data, organ_type='liver', method='ai'):
        """Main segmentation function"""
        print(f"\nStarting {organ_type} segmentation using {method} method...")

        if method == 'ai':
            segmentation = self.segment_with_medsam(data, organ_type)
        else:
            segmentation = self.segment_with_traditional_methods(data, organ_type)

        if segmentation is not None:
            print("Segmentation completed successfully!")
            return segmentation
        else:
            print("Segmentation failed")
            return None

    def visualize_segmentation(self, original_data, segmentation, organ_type='liver'):
        """Plot the middle slice of the volume with each segmented part.

        Draws a 2x3 figure: the original slice, one panel per organ part
        (part mask under a faint copy of the image), and a combined RGB
        overlay of all parts thresholded at 0.5.

        Args:
            original_data: 3-D volume, or a 4-D array of which index 0 along
                the last axis is used.
            segmentation: array indexed as ``[channel, x, y, z]`` with part
                channels 1-3; when None the method prints a notice and returns.
            organ_type: key of ``self.organ_configs``; unknown keys use the
                'liver' configuration, matching the segmentation methods.
        """
        if segmentation is None:
            print("No segmentation to visualize")
            return

        config = self.organ_configs.get(organ_type, self.organ_configs['liver'])
        parts = config['parts']
        colors = config['colors']

        if len(original_data.shape) == 4:
            original_data = original_data[:, :, :, 0]

        mid_z = original_data.shape[2] // 2
        # Transposed once so the first array axis runs horizontally in imshow.
        base_img = original_data[:, :, mid_z].T

        fig, axes = plt.subplots(2, 3, figsize=(18, 12))
        fig.suptitle(f'{organ_type.title()} Segmentation Results (MedSAM)', fontsize=16, fontweight='bold')

        axes[0, 0].imshow(base_img, cmap='gray', origin='lower')
        axes[0, 0].set_title('Original Image')
        axes[0, 0].axis('off')

        positions = [(0, 1), (0, 2), (1, 0)]

        for i, (part, color) in enumerate(zip(parts, colors)):
            if i < len(positions):
                row, col = positions[i]
                part_mask = segmentation[i+1, :, :, mid_z].T
                axes[row, col].imshow(part_mask, cmap=color, origin='lower', alpha=0.8)
                axes[row, col].imshow(base_img, cmap='gray', origin='lower', alpha=0.3)
                axes[row, col].set_title(f'{part.replace("_", " ").title()}')
                axes[row, col].axis('off')

        overlay = np.zeros((*base_img.shape, 3))
        # Guard against a constant slice (e.g. an empty middle slice), which
        # would otherwise divide by zero and fill the overlay with NaNs.
        intensity_range = np.max(base_img) - np.min(base_img)
        if intensity_range > 0:
            base_img_norm = (base_img - np.min(base_img)) / intensity_range
        else:
            base_img_norm = np.zeros(base_img.shape)

        # Flat RGB colour used in the overlay for each colormap name.
        rgb_colors = {
            'Reds': [1, 0, 0], 'Blues': [0, 0, 1], 'Greens': [0, 1, 0],
            'Oranges': [1, 0.5, 0], 'Purples': [0.5, 0, 1], 'RdPu': [1, 0, 0.5],
            'Greys': [0.5, 0.5, 0.5], 'bone': [0.8, 0.8, 0.6]
        }

        for j, part_color in enumerate(colors):
            mask = segmentation[j+1, :, :, mid_z].T > 0.5
            if part_color in rgb_colors:
                # Later parts overwrite earlier ones where masks overlap.
                overlay[mask] = rgb_colors[part_color]

        alpha = 0.6
        for c in range(3):
            overlay[:, :, c] = alpha * overlay[:, :, c] + (1-alpha) * base_img_norm

        axes[1, 1].imshow(overlay, origin='lower')
        axes[1, 1].set_title('Combined Segmentation')
        axes[1, 1].axis('off')
        axes[1, 2].axis('off')

        plt.tight_layout()
        plt.show()

    def generate_segmentation_stats(self, segmentation, organ_type='liver', voxel_size=(1.0, 1.0, 1.0)):
        """Generate statistics for the segmentation"""
        if segmentation is None:
            return

        config = self.organ_configs[organ_type]
        parts = config['parts']

        print(f"\nSEGMENTATION STATISTICS - {organ_type.upper()}")
        print("="*50)

        voxel_volume = np.prod(voxel_size)
        total_volume = 0

        for i, part in enumerate(parts):
            mask = segmentation[i+1] > 0.5
            volume_voxels = np.sum(mask)
            volume_mm3 = volume_voxels * voxel_volume
            volume_ml = volume_mm3 / 1000

            print(f"   {part.replace('_', ' ').title()}:")
            print(f"     - Volume: {volume_ml:.2f} ml ({volume_voxels:,} voxels)")
            print(f"     - Percentage: {(volume_voxels / segmentation[0].size) * 100:.2f}%")
            total_volume += volume_ml

        print(f"\n   Total segmented volume: {total_volume:.2f} ml")
        print(f"   Segmentation coverage: {(np.sum(segmentation[1:]) / segmentation[0].size) * 100:.2f}%")

    def compute_dice_coefficient(self, pred_mask, gt_mask, smooth=1e-6):
        """Return the smoothed Dice score ``(2*|A.B| + s) / (|A| + |B| + s)``."""
        p = pred_mask.flatten()
        g = gt_mask.flatten()
        overlap = (p * g).sum()
        mass = p.sum() + g.sum()
        return (2.0 * overlap + smooth) / (mass + smooth)

    def compute_iou(self, pred_mask, gt_mask, smooth=1e-6):
        """Return the smoothed IoU (Jaccard index) of two masks."""
        p, g = pred_mask.flatten(), gt_mask.flatten()
        overlap = (p * g).sum()
        # Union = |A| + |B| - |A.B|
        combined = p.sum() + g.sum() - overlap
        return (overlap + smooth) / (combined + smooth)

    def compute_hausdorff_distance(self, pred_mask, gt_mask):
        """Return the symmetric Hausdorff distance between two masks, in voxels.

        All non-zero voxels of each mask are used as the point set. Returns
        ``inf`` when either mask has no non-zero voxel.
        """
        from scipy.spatial.distance import directed_hausdorff

        pred_pts = np.argwhere(pred_mask.astype(bool))
        gt_pts = np.argwhere(gt_mask.astype(bool))

        if min(len(pred_pts), len(gt_pts)) == 0:
            return float('inf')

        pred_to_gt = directed_hausdorff(pred_pts, gt_pts)[0]
        gt_to_pred = directed_hausdorff(gt_pts, pred_pts)[0]
        return max(pred_to_gt, gt_to_pred)

    def evaluate_segmentation(self, pred_segmentation, gt_segmentation, organ_type='liver', voxel_size=(1.0, 1.0, 1.0)):
        """Evaluate segmentation quality"""
        if pred_segmentation is None or gt_segmentation is None:
            print("Error: Both predicted and ground truth segmentations required")
            return None

        config = self.organ_configs[organ_type]
        parts = config['parts']

        print(f"\n{'='*60}")
        print(f"SEGMENTATION EVALUATION METRICS - {organ_type.upper()}")
        print(f"{'='*60}\n")

        results = {}

        for i, part in enumerate(parts):
            print(f"   {part.replace('_', ' ').title()}:")
            print(f"   {'-'*40}")

            pred_mask = (pred_segmentation[i+1] > 0.5).astype(np.float32)
            gt_mask = (gt_segmentation[i+1] > 0.5).astype(np.float32)

            dice = self.compute_dice_coefficient(pred_mask, gt_mask)
            iou = self.compute_iou(pred_mask, gt_mask)

            try:
                hd_voxels = self.compute_hausdorff_distance(pred_mask, gt_mask)
                hd_mm = hd_voxels * np.mean(voxel_size)
            except:
                hd_voxels = float('inf')
                hd_mm = float('inf')

            results[part] = {
                'dice': dice,
                'iou': iou,
                'hausdorff_distance_voxels': hd_voxels,
                'hausdorff_distance_mm': hd_mm
            }

            print(f"     • Dice Coefficient:        {dice:.4f}  {'✓ Excellent' if dice > 0.9 else '✓ Good' if dice > 0.7 else '⚠ Fair' if dice > 0.5 else '✗ Poor'}")
            print(f"     • IoU (Jaccard Index):     {iou:.4f}  {'✓ Excellent' if iou > 0.8 else '✓ Good' if iou > 0.6 else '⚠ Fair' if iou > 0.4 else '✗ Poor'}")

            if hd_voxels != float('inf'):
                print(f"     • Hausdorff Distance:      {hd_mm:.2f} mm ({hd_voxels:.2f} voxels)")
                print(f"                               {'✓ Excellent' if hd_mm < 2 else '✓ Good' if hd_mm < 5 else '⚠ Fair' if hd_mm < 10 else '✗ Poor'}")
            else:
                print(f"     • Hausdorff Distance:      N/A (no boundary detected)")
            print()

        print(f"   {'='*40}")
        print(f"   OVERALL PERFORMANCE:")
        print(f"   {'='*40}")

        avg_dice = np.mean([results[part]['dice'] for part in parts])
        avg_iou = np.mean([results[part]['iou'] for part in parts])
        valid_hd = [results[part]['hausdorff_distance_mm'] for part in parts
                    if results[part]['hausdorff_distance_mm'] != float('inf')]
        avg_hd = np.mean(valid_hd) if valid_hd else float('inf')

        print(f"     • Average Dice:            {avg_dice:.4f}")
        print(f"     • Average IoU:             {avg_iou:.4f}")
        if avg_hd != float('inf'):
            print(f"     • Average Hausdorff Dist:  {avg_hd:.2f} mm")

        print(f"\n   {'='*40}")
        print(f"   INTERPRETATION GUIDE:")
        print(f"   {'='*40}")
        print(f"     Dice & IoU:   >0.9 = Excellent | >0.7 = Good | >0.5 = Fair")
        print(f"     Hausdorff:    <2mm = Excellent | <5mm = Good | <10mm = Fair")
        print(f"   {'='*40}\n")

        return results

class NiftiReader:
    """Base NIFTI reader class"""
    def __init__(self):
        self.nifti_img = None
        self.data = None
        self.header = None
        self.affine = None
        self.filename = None

    def upload_nifti_file(self):
        """Prompt for a file upload in Colab and load the first uploaded file."""
        print("Select a NIFTI file (.nii or .nii.gz) to upload:")
        uploaded = files.upload()

        if not uploaded:
            print("No file uploaded")
            return

        # Only the first uploaded entry is used.
        filename = next(iter(uploaded))
        self.filename = filename
        print(f"Uploaded: {filename}")
        self.load_nifti_file(filename)

    def load_sample_data(self):
        """Download a sample brain volume with ``wget`` and load it.

        Failures are reported on stdout rather than raised: both a ``wget``
        that cannot be started and a ``wget`` that exits with a non-zero
        status print "Could not download sample data".
        """
        print("Loading sample brain data...")
        import subprocess

        target = 'sample_brain.nii.gz'
        try:
            result = subprocess.run(['wget', '-O', target,
                                   "https://github.com/neurolabusc/niivue-images/raw/main/chris_t1.nii.gz"],
                                  capture_output=True, text=True)
        except Exception as e:
            # e.g. wget is not installed. ``Exception`` instead of a bare
            # except so Ctrl-C during the download is not swallowed.
            print("Could not download sample data")
            print(f"   Reason: {e}")
            return

        if result.returncode != 0:
            # Previously a failed download was ignored without any message.
            print("Could not download sample data")
            print(f"   wget exited with status {result.returncode}")
            return

        self.filename = target
        self.load_nifti_file(target)

    def load_cardiac_sample_data(self):
        """Generate the synthetic heart volume, save it as NIfTI and load it.

        The volume is written to ``synthetic_heart.nii.gz`` with an identity
        affine. Any error is printed instead of raised.
        """
        print("Loading sample cardiac MRI data...")
        out_path = 'synthetic_heart.nii.gz'
        try:
            print("   Generating synthetic cardiac MRI data...")
            volume = self.generate_synthetic_heart_data()
            image = nib.Nifti1Image(volume, np.eye(4))
            nib.save(image, out_path)
            self.filename = out_path
            self.load_nifti_file(out_path)
            print("Synthetic cardiac data loaded!")
        except Exception as err:
            print(f"Could not load cardiac data: {err}")

    def generate_synthetic_heart_data(self):
        """Generate a synthetic 128x128x80 heart volume with values in [0, 1].

        The volume contains two ellipsoidal chambers (a darker cavity inside
        a brighter wall), an extra ring of wall-like tissue around them,
        Gaussian noise (sigma 0.03) and an intensity fall-off away from the
        middle slice. The result depends on NumPy's global random state.

        The per-voxel rules are evaluated with NumPy broadcasting instead of
        nested Python loops over all ~1.3M voxels.

        Returns:
            Array of shape ``(128, 128, 80)``.
        """
        shape = (128, 128, 80)
        data = np.zeros(shape, dtype=np.float32)
        center = (64, 64, 40)

        # Open coordinate grids of shape (128,1,1), (1,128,1), (1,1,80) that
        # broadcast to the full volume.
        x, y, z = np.ogrid[:shape[0], :shape[1], :shape[2]]

        print("   Creating left ventricle...")
        lv_center = (center[0] - 12, center[1], center[2])
        # The third axis is stretched by 1.5, making the chamber an ellipsoid.
        dist_lv = np.sqrt((x - lv_center[0])**2 + (y - lv_center[1])**2 + ((z - lv_center[2])*1.5)**2)
        data[dist_lv < 12] = 0.4
        data[(dist_lv >= 12) & (dist_lv < 18)] = 0.75

        print("   Creating right ventricle...")
        rv_center = (center[0] + 18, center[1] + 8, center[2])
        dist_rv = np.sqrt((x - rv_center[0])**2 + (y - rv_center[1])**2 + ((z - rv_center[2])*1.5)**2)
        rv_cavity = dist_rv < 10
        # The RV wall only fills voxels not already holding bright LV wall.
        rv_wall = (dist_rv >= 10) & (dist_rv < 15) & (data < 0.5)
        data[rv_cavity] = np.maximum(data[rv_cavity], 0.35)
        data[rv_wall] = 0.7

        print("   Adding myocardial wall...")
        # In-plane (first two axes) distances only, restricted to a sub-box.
        plane_lv = np.sqrt((x - lv_center[0])**2 + (y - lv_center[1])**2)
        plane_rv = np.sqrt((x - rv_center[0])**2 + (y - rv_center[1])**2)
        in_box = (x >= 40) & (x < 90) & (y >= 35) & (y < 90) & (z >= 20) & (z < 60)
        in_ring = ((plane_lv > 18) & (plane_lv < 28)) | ((plane_rv > 15) & (plane_rv < 22))
        data[in_box & in_ring & (data < 0.3)] = 0.65

        noise = np.random.normal(0, 0.03, shape)
        data = np.clip(data + noise, 0, 1)

        # Attenuate slices linearly with distance from the middle slice,
        # down to a factor of 0.7 at the first slice.
        slice_factor = 1.0 - np.abs(np.arange(shape[2]) - center[2]) / center[2] * 0.3
        data *= slice_factor

        print(f"   Synthetic heart data created: {shape}")
        print(f"   Intensity range: {data.min():.3f} to {data.max():.3f}")
        print(f"   Mean intensity: {data.mean():.3f}")

        return data

    def load_nifti_file(self, filename):
        """Read a NIfTI file with nibabel and store its image, data, header and affine.

        On success the summary from ``analyze_nifti_data`` is printed. Any
        error is printed instead of raised.
        """
        try:
            print(f"Loading NIFTI file: {filename}")
            image = nib.load(filename)
            self.nifti_img = image
            self.data = image.get_fdata()
            self.header = image.header
            self.affine = image.affine
            print("NIFTI file loaded successfully!")
            self.analyze_nifti_data()
        except Exception as e:
            print(f"Error loading NIFTI file: {str(e)}")

    def analyze_nifti_data(self):
        """Analyze NIFTI data"""
        if self.data is None:
            return

        print("\n" + "="*60)
        print("NIFTI FILE ANALYSIS")
        print("="*60)
        print(f"\nFilename: {self.filename}")
        print(f"Data shape: {self.data.shape}")
        print(f"Data type: {self.data.dtype}")
        print(f"Voxel size: {self.header.get_zooms()[:3]} mm")
        print(f"Memory usage: {self.data.nbytes / (1024**2):.2f} MB")

class EnhancedNiftiReader(NiftiReader):
    """NIFTI reader that adds MedSAM segmentation, evaluation and export.

    Attributes:
        segmenter: ``MedSAMOrganSegmenter`` that performs segmentation,
            visualization, statistics and evaluation.
        current_segmentation: Result of the last successful
            ``segment_organ_parts`` call, or None.
        current_organ_type: Organ name used for that call, or None.
        ground_truth_segmentation: Reference volume set by
            ``load_ground_truth``, or None.
    """

    def __init__(self):
        super().__init__()
        self.segmenter = MedSAMOrganSegmenter()
        self.current_segmentation = None
        self.current_organ_type = None
        self.ground_truth_segmentation = None

    def _voxel_size(self):
        """Return the first three header zooms, or unit spacing without a header."""
        if self.header is not None:
            return self.header.get_zooms()[:3]
        return (1.0, 1.0, 1.0)

    @staticmethod
    def _normalize_to_uint16(volume):
        """Min-max scale ``volume`` onto the full uint16 range.

        The scaling is gated on the intensity range (max - min), which is
        the actual divisor. A constant volume therefore maps to all zeros
        instead of dividing by zero, and volumes whose maximum is not
        positive are still normalized.
        """
        as_float = volume.astype(np.float32)
        low = np.min(as_float)
        span = np.max(as_float) - low
        if span > 0:
            scaled = (as_float - low) / span
        else:
            scaled = np.zeros_like(as_float)
        return (scaled * 65535).astype(np.uint16)

    @staticmethod
    def _write_volume(volume, affine, dtype, filename):
        """Reverse the axes of ``volume``, write it as NIFTI and return ``filename``."""
        # NOTE(review): the axes are reversed but the affine is passed
        # through unchanged - confirm this is the intended orientation.
        oriented = np.transpose(volume, (2, 1, 0))
        image = nib.Nifti1Image(oriented, affine)
        image.header.set_data_dtype(dtype)
        nib.save(image, filename)
        print(f"   Saved: {filename}")
        return filename

    def load_ground_truth(self, gt_filename, num_classes=4):
        """Load a ground truth segmentation for evaluation.

        A 3D label volume is expanded into one binary channel per label
        value ``0 .. num_classes - 1`` (channel-first). Volumes of any
        other dimensionality are stored as loaded.

        Args:
            gt_filename: Path passed to ``nib.load``.
            num_classes: Number of label values to expand a 3D volume
                into. Defaults to 4, the previously hard-coded value.

        Returns:
            True when the file was loaded, False when an error occurred
            (the error is printed, not raised).
        """
        try:
            print(f"Loading ground truth: {gt_filename}")
            gt_data = nib.load(gt_filename).get_fdata()

            if gt_data.ndim == 3:
                channels = np.zeros((num_classes,) + gt_data.shape)
                for label in range(num_classes):
                    channels[label] = (gt_data == label).astype(np.float32)
                self.ground_truth_segmentation = channels
            else:
                self.ground_truth_segmentation = gt_data

            print("Ground truth loaded successfully!")
            return True
        except Exception as e:
            print(f"Error loading ground truth: {e}")
            return False

    def evaluate_current_segmentation(self):
        """Evaluate the current segmentation against the loaded ground truth.

        Returns:
            Whatever ``segmenter.evaluate_segmentation`` returns, or None
            when either the segmentation or the ground truth is missing.
        """
        if self.current_segmentation is None:
            print("No current segmentation to evaluate")
            return None

        if self.ground_truth_segmentation is None:
            print("No ground truth loaded. Use reader.load_ground_truth('filename.nii.gz')")
            return None

        return self.segmenter.evaluate_segmentation(
            self.current_segmentation,
            self.ground_truth_segmentation,
            self.current_organ_type,
            self._voxel_size()
        )

    def compare_segmentations(self, seg1, seg2, organ_type='liver', labels=('Method 1', 'Method 2')):
        """Compare two segmentations using the segmenter's evaluation.

        ``seg1`` is passed as the segmentation under test and ``seg2`` as
        the reference.

        Args:
            seg1: First segmentation.
            seg2: Second segmentation.
            organ_type: Organ name forwarded to the evaluator.
            labels: Two display names printed in the report header.

        Returns:
            The evaluator's result, or None when either input is None
            (``segment_organ_parts`` returns None on failure).
        """
        if seg1 is None or seg2 is None:
            print("Cannot compare segmentations: at least one segmentation is missing")
            return None

        print(f"\n{'='*60}")
        print(f"COMPARING SEGMENTATION METHODS - {organ_type.upper()}")
        print(f"{'='*60}\n")

        print(f"Method 1: {labels[0]}")
        print(f"Method 2: {labels[1]}\n")

        return self.segmenter.evaluate_segmentation(seg1, seg2, organ_type, self._voxel_size())

    def segment_organ_parts(self, organ_type='liver', method='ai', sam_model='vit_b'):
        """Segment the loaded volume and remember the result.

        On success the result is stored as the current segmentation, then
        visualized and summarized via the segmenter.

        Args:
            organ_type: Organ name forwarded to the segmenter.
            method: Segmentation method; ``'ai'`` loads the SAM model
                named by ``sam_model`` first.
            sam_model: Model name passed to ``load_medsam_model``.

        Returns:
            The segmentation, or None when no data is loaded or the
            segmenter returned None.
        """
        if self.data is None:
            print("No data loaded")
            return None

        if method == 'ai':
            self.segmenter.load_medsam_model(sam_model)

        segmentation = self.segmenter.segment_organ(self.data, organ_type, method)
        if segmentation is None:
            return None

        self.current_segmentation = segmentation
        self.current_organ_type = organ_type
        self.segmenter.visualize_segmentation(self.data, segmentation, organ_type)
        self.segmenter.generate_segmentation_stats(segmentation, organ_type, self._voxel_size())
        return segmentation

    def save_segmentation(self, output_prefix='segmentation', export_format='both'):
        """Save the current segmentation and the original image as NIFTI files.

        Args:
            output_prefix: Prefix for every output filename.
            export_format: ``'masks'`` writes one 0/255 uint8 mask per
                organ part, ``'labels'`` writes a single argmax label
                map, ``'both'`` writes both. The original image (first
                volume of 4D data), rescaled to uint16, is always
                written.

        Returns:
            List of the filenames written; empty when there is no
            segmentation to save.
        """
        if self.current_segmentation is None:
            print("No segmentation to save")
            return []

        if export_format not in ('masks', 'labels', 'both'):
            # Previously this silently skipped masks and labels.
            print(f"Unknown export_format '{export_format}'; only the original image will be saved")

        print("Saving segmentation results...")
        affine = self.affine if self.affine is not None else np.eye(4)
        organ = self.current_organ_type
        parts = self.segmenter.organ_configs[organ]['parts']
        saved_files = []

        if export_format in ['masks', 'both']:
            print("\nSaving individual binary masks...")
            # Channel 0 is skipped (presumably background - confirm with
            # the segmenter); part i lives in channel i + 1.
            for channel, part in enumerate(parts, start=1):
                mask = (self.current_segmentation[channel] > 0.5).astype(np.uint8) * 255
                saved_files.append(self._write_volume(
                    mask, affine, np.uint8,
                    f"{output_prefix}_{organ}_{part}_mask.nii.gz"
                ))

        if export_format in ['labels', 'both']:
            print("\nSaving combined label map...")
            combined = np.argmax(self.current_segmentation, axis=0).astype(np.uint8)
            saved_files.append(self._write_volume(
                combined, affine, np.uint8,
                f"{output_prefix}_{organ}_labels.nii.gz"
            ))

        print("\nSaving original image...")
        original_data = self.data[:, :, :, 0] if len(self.data.shape) == 4 else self.data
        saved_files.append(self._write_volume(
            self._normalize_to_uint16(original_data), affine, np.uint16,
            f"{output_prefix}_{organ}_original.nii.gz"
        ))

        print("\n3D VIEWER READY FILES:")
        for filename in saved_files:
            print(f"   {filename}")

        return saved_files

# Build the shared reader instance and print the notebook's usage guide.
_RULE = "=" * 60

for _banner_line in ("Enhanced NIFTI Data Reader with MedSAM AI Segmentation", _RULE):
    print(_banner_line)

# Constructed between the banner and the guide so that any output from the
# constructor appears in the same position as before.
reader = EnhancedNiftiReader()

_USAGE_GUIDE = (
    # Step-by-step usage
    "\nHOW TO USE:",
    "\n1. Load Data:",
    "   reader.upload_nifti_file()",
    "   reader.load_sample_data()",
    "   reader.load_cardiac_sample_data()",
    "\n2. Segment Organs with MedSAM AI:",
    "   reader.segment_organ_parts('brain', method='ai', sam_model='vit_b')",
    "   reader.segment_organ_parts('heart', method='ai')",
    "   reader.segment_organ_parts('liver', method='ai')",
    "   reader.segment_organ_parts('kidney', method='traditional')  # fallback",
    "\n3. Available SAM Models:",
    "   - 'vit_b': ViT-B (Base) - Faster, ~375MB",
    "   - 'vit_l': ViT-L (Large) - Better accuracy, ~1.2GB",
    "   - 'vit_h': ViT-H (Huge) - Best accuracy, ~2.4GB",
    "\n4. Evaluate Segmentation (requires ground truth):",
    "   reader.load_ground_truth('ground_truth.nii.gz')",
    "   results = reader.evaluate_current_segmentation()",
    "\n5. Compare Two Methods:",
    "   seg1 = reader.segment_organ_parts('heart', method='ai')",
    "   seg2 = reader.segment_organ_parts('heart', method='traditional')",
    "   reader.compare_segmentations(seg1, seg2, 'heart', ('MedSAM', 'Traditional'))",
    "\n6. Save Results:",
    "   reader.save_segmentation('my_segmentation')",
    # Feature list
    "\n" + _RULE,
    "FEATURES:",
    "  ✓ MedSAM (Segment Anything for Medical Imaging)",
    "  ✓ Prompt-based segmentation with bounding boxes",
    "  ✓ Multiple SAM model sizes (vit_b, vit_l, vit_h)",
    "  ✓ Enhanced cardiac segmentation with ventricle detection",
    "  ✓ Automatic fallback to traditional methods",
    "  ✓ Smooth upsampling and interpolation",
    "  ✓ GPU acceleration when available",
    "  ✓ Export to NIFTI format for 3D viewing",
    "  ✓ EVALUATION METRICS:",
    "    • Dice Coefficient (overlap accuracy)",
    "    • IoU/Jaccard Index (intersection over union)",
    "    • Hausdorff Distance (boundary accuracy)",
    _RULE,
    # Background blurb
    "\n🎯 ABOUT MedSAM:",
    "MedSAM uses Meta's Segment Anything Model (SAM) adapted for medical",
    "imaging. It uses prompt engineering (bounding boxes) to segment specific",
    "anatomical structures with high precision. The model excels at finding",
    "exact boundaries and works well with various medical imaging modalities.",
    _RULE,
    "\nReady for NIFTI analysis with MedSAM AI!",
)

for _guide_line in _USAGE_GUIDE:
    print(_guide_line)

ERROR: Operation cancelled by user
Collecting segment-anything
  Downloading segment_anything-1.0-py3-none-any.whl.metadata (487 bytes)
Downloading segment_anything-1.0-py3-none-any.whl (36 kB)
Installing collected packages: segment-anything
Successfully installed segment-anything-1.0
Collecting git+https://github.com/facebookresearch/segment-anything.git
  Cloning https://github.com/facebookresearch/segment-anything.git to /tmp/pip-req-build-hjcncy3i
  Running command git clone --filter=blob:none --quiet https://github.com/facebookresearch/segment-anything.git /tmp/pip-req-build-hjcncy3i
  Resolved https://github.com/facebookresearch/segment-anything.git to commit dca509fe793f601edb92606367a655c15ac00fdf
  Preparing metadata (setup.py) ... done
Collecting SimpleITK
  Downloading simpleitk-2.5.2-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (7.2 kB)
Downloading simpleitk-2.5.2-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (