In [1]:
!pip install segmentation_models_pytorch

Collecting segmentation_models_pytorch
  Downloading segmentation_models_pytorch-0.5.0-py3-none-any.whl.metadata (17 kB)
Downloading segmentation_models_pytorch-0.5.0-py3-none-any.whl (154 kB)
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m154.8/154.8 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: segmentation_models_pytorch
Successfully installed segmentation_models_pytorch-0.5.0


In [3]:
# Install SAM
!pip install segment-anything

# Download SAM checkpoint (choose one):
# ViT-B (smallest, ~375MB)
!wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth

# ViT-L (~1.2GB)
# !wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth

# ViT-H (largest, ~2.4GB)
# !wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

Collecting segment-anything
  Downloading segment_anything-1.0-py3-none-any.whl.metadata (487 bytes)
Downloading segment_anything-1.0-py3-none-any.whl (36 kB)
Installing collected packages: segment-anything
Successfully installed segment-anything-1.0
--2025-11-09 03:49:16--  https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 52.85.129.113, 52.85.129.86, 52.85.129.4, ...
Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|52.85.129.113|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 375042383 (358M) [binary/octet-stream]
Saving to: ‘sam_vit_b_01ec64.pth’


2025-11-09 03:49:17 (257 MB/s) - ‘sam_vit_b_01ec64.pth’ saved [375042383/375042383]

--2025-11-09 03:49:18--  https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth
Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 52.85.129.113, 52.85.129.86, 52.85.129.4, ...
Connecting to dl.fbaipublicfi

In [4]:
"""
GRADIO WEB APPLICATION FOR BRAIN TUMOR CLASSIFICATION AND SEGMENTATION
Enhanced with SAM refinement for improved segmentation accuracy
"""

import gradio as gr
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import plotly.graph_objects as go
import os
from datetime import datetime
import cv2
import segmentation_models_pytorch as smp
import albumentations as A
from albumentations.pytorch import ToTensorV2
import matplotlib
matplotlib.use('Agg')  # Use non-interactive backend
import matplotlib.pyplot as plt
import io
import base64
import glob
import warnings
warnings.filterwarnings('ignore')

# Import SAM dependencies. SAM is optional: it is only used to refine the
# U-Net mask, so a missing package just disables refinement.
try:
    from segment_anything import sam_model_registry, SamPredictor
    SAM_AVAILABLE = True
except ImportError:
    print("Warning: SAM not installed. Install with: pip install segment-anything")
    SAM_AVAILABLE = False

# Allow-list numpy scalars for torch.load(weights_only=True), the default since
# PyTorch 2.6. add_safe_globals does not exist on older PyTorch releases, so the
# call is guarded instead of raising AttributeError and aborting the whole cell.
# (The loaders below pass weights_only=False, so this is only a safety net.)
import torch.serialization
import numpy.core.multiarray
if hasattr(torch.serialization, "add_safe_globals"):
    torch.serialization.add_safe_globals([numpy.core.multiarray.scalar])

class BrainTumorSegmenter:
    """Enhanced segmentation model handler with SAM refinement.

    Wraps a segmentation_models_pytorch U-Net (3 input channels, 1 output
    logit channel) and, when a SAM checkpoint is available, a SamPredictor
    used to refine the U-Net mask from a bounding-box prompt.
    """

    # Fallbacks used when the checkpoint carries no preprocessing/config info.
    DEFAULT_IMG_SIZE = 256
    DEFAULT_ENCODER = 'efficientnet-b1'
    DEFAULT_MEAN = (0.485, 0.456, 0.406)
    DEFAULT_STD = (0.229, 0.224, 0.225)

    def __init__(self, model_path, sam_checkpoint_path=None):
        """Load the U-Net from ``model_path`` and optionally SAM.

        Never raises: on failure an error is printed and ``self.model`` stays
        ``None``, which callers must check.

        Args:
            model_path: path to a U-Net checkpoint; either a raw state dict or a
                dict with 'model_state_dict'/'state_dict' and optional 'config',
                'preprocessing' and 'model_architecture' entries.
            sam_checkpoint_path: optional path to a SAM checkpoint (.pth).
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = None
        self.sam_predictor = None
        # Set preprocessing defaults up front so these attributes exist even
        # when loading fails below.
        self.img_size = self.DEFAULT_IMG_SIZE
        self.mean = list(self.DEFAULT_MEAN)
        self.std = list(self.DEFAULT_STD)

        # Initialize U-Net
        if not os.path.exists(model_path):
            print(f"Segmentation model not found at: {model_path}")
            return

        try:
            print(f"Loading segmentation model from: {model_path}")

            # weights_only=False unpickles arbitrary objects: only load trusted checkpoints.
            checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)

            self.img_size, encoder_name, self.mean, self.std = self._read_config(checkpoint)

            # Create model
            print(f"  Creating U-Net with {encoder_name} encoder...")
            self.model = smp.Unet(
                encoder_name=encoder_name,
                encoder_weights=None,
                in_channels=3,
                classes=1,
                activation=None,
            )

            # strict=False tolerates extra metadata keys, but would also silently
            # accept weights for a different architecture (leaving layers randomly
            # initialised), so report any mismatch.
            incompatible = self.model.load_state_dict(self._extract_state_dict(checkpoint), strict=False)
            if incompatible.missing_keys or incompatible.unexpected_keys:
                print(f"  Warning: {len(incompatible.missing_keys)} missing and "
                      f"{len(incompatible.unexpected_keys)} unexpected keys in checkpoint")
            self.model.to(self.device)
            self.model.eval()

            print(f"‚úì Segmentation model loaded successfully!")

            # Initialize SAM if available and checkpoint provided
            if SAM_AVAILABLE and sam_checkpoint_path:
                self._initialize_sam(sam_checkpoint_path)

        except Exception as e:
            print(f"Error loading segmentation model: {str(e)}")
            self.model = None

    @classmethod
    def _read_config(cls, checkpoint):
        """Return ``(img_size, encoder_name, mean, std)`` from a checkpoint.

        Sections are applied in order, later ones overriding earlier ones:
        'config' (img_size, encoder), 'preprocessing' (input_size, mean, std),
        'model_architecture' (encoder_name). Anything missing, or a non-dict
        checkpoint, falls back to the class defaults.
        """
        img_size = cls.DEFAULT_IMG_SIZE
        encoder_name = cls.DEFAULT_ENCODER
        mean, std = list(cls.DEFAULT_MEAN), list(cls.DEFAULT_STD)

        if not isinstance(checkpoint, dict):
            return img_size, encoder_name, mean, std

        if 'config' in checkpoint:
            config = checkpoint['config']
            img_size = config.get('img_size', cls.DEFAULT_IMG_SIZE)
            encoder_name = config.get('encoder', cls.DEFAULT_ENCODER)

        if 'preprocessing' in checkpoint:
            prep = checkpoint['preprocessing']
            img_size = prep.get('input_size', img_size)
            mean = prep.get('mean', mean)
            std = prep.get('std', std)

        if 'model_architecture' in checkpoint:
            encoder_name = checkpoint['model_architecture'].get('encoder_name', encoder_name)

        return img_size, encoder_name, mean, std

    @staticmethod
    def _extract_state_dict(checkpoint):
        """Return the weights dict from a checkpoint with any 'module.' key prefix stripped.

        Looks for 'model_state_dict', then 'state_dict'; otherwise the
        checkpoint itself is treated as the state dict. The 'module.' prefix is
        what wrappers such as nn.DataParallel add to parameter names.
        """
        state_dict = checkpoint
        if isinstance(checkpoint, dict):
            if 'model_state_dict' in checkpoint:
                state_dict = checkpoint['model_state_dict']
            elif 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']

        prefix = 'module.'
        return {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

    def _initialize_sam(self, checkpoint_path):
        """Load SAM into ``self.sam_predictor``; leaves it ``None`` on any failure.

        The SAM variant ('vit_h', 'vit_l' or 'vit_b') is inferred from the
        checkpoint filename, defaulting to 'vit_b'.
        """
        try:
            if not os.path.exists(checkpoint_path):
                print(f"SAM checkpoint not found at: {checkpoint_path}")
                return

            print(f"Loading SAM model from: {checkpoint_path}")

            # Determine SAM model type based on file name
            if 'vit_h' in checkpoint_path:
                model_type = 'vit_h'
            elif 'vit_l' in checkpoint_path:
                model_type = 'vit_l'
            elif 'vit_b' in checkpoint_path:
                model_type = 'vit_b'
            else:
                model_type = 'vit_b'  # Default

            sam = sam_model_registry[model_type](checkpoint=checkpoint_path)
            sam.to(self.device)
            self.sam_predictor = SamPredictor(sam)

            print(f"‚úì SAM model ({model_type}) loaded successfully!")

        except Exception as e:
            print(f"Error loading SAM model: {str(e)}")
            self.sam_predictor = None

    def get_bounding_box_from_mask(self, mask):
        """Return the padded bounding box ``[x_min, y_min, x_max, y_max]`` of a 2-D mask.

        The box around all pixels > 0 is grown by 5% of the mask height/width
        on each side and clipped to the mask. Returns ``None`` for an empty mask.
        """
        points = np.where(mask > 0)

        if len(points[0]) == 0:  # No mask found
            return None

        y_min = points[0].min()
        y_max = points[0].max()
        x_min = points[1].min()
        x_max = points[1].max()

        # Add small padding (5% of image size)
        h, w = mask.shape
        padding_y = int(h * 0.05)
        padding_x = int(w * 0.05)

        y_min = max(0, y_min - padding_y)
        y_max = min(h - 1, y_max + padding_y)
        x_min = max(0, x_min - padding_x)
        x_max = min(w - 1, x_max + padding_x)

        return np.array([x_min, y_min, x_max, y_max])

    def refine_with_sam(self, image_np, initial_mask):
        """Refine ``initial_mask`` with SAM using its bounding box as the prompt.

        Takes the highest-scoring of SAM's multimask outputs and returns its
        intersection with ``initial_mask``. If the intersection keeps less than
        30% of the initial pixels, the SAM mask alone is returned instead.
        Returns ``initial_mask`` unchanged when SAM is unavailable, the mask is
        empty, or SAM raises.
        """
        if self.sam_predictor is None:
            return initial_mask

        try:
            bbox = self.get_bounding_box_from_mask(initial_mask)

            if bbox is None:
                return initial_mask

            self.sam_predictor.set_image(image_np)

            masks, scores, _ = self.sam_predictor.predict(
                box=bbox,
                multimask_output=True  # Get multiple mask options
            )

            # Select best mask (highest score)
            best_idx = np.argmax(scores)
            refined_mask = masks[best_idx].astype(np.uint8)

            # Intersection is the conservative combination of the two masks.
            combined_mask = np.logical_and(refined_mask, initial_mask).astype(np.uint8)

            # If combined mask is too small, use the refined mask alone
            if combined_mask.sum() < initial_mask.sum() * 0.3:
                return refined_mask

            return combined_mask

        except Exception as e:
            print(f"SAM refinement error: {str(e)}")
            return initial_mask

    @staticmethod
    def _to_rgb_array(image):
        """Return ``image`` as an HxWx3 RGB array.

        Accepts a numpy array (HxW, HxWx1, HxWx3 or HxWx4), a PIL image, or a
        path/file object understood by ``PIL.Image.open``. The U-Net transform
        normalises with 3-channel mean/std and SAM expects RGB input, so
        grayscale and RGBA inputs are converted here.
        """
        if isinstance(image, np.ndarray):
            arr = image
            if arr.ndim == 2:
                arr = np.stack([arr, arr, arr], axis=-1)
            elif arr.ndim == 3 and arr.shape[2] == 1:
                arr = np.repeat(arr, 3, axis=2)
            elif arr.ndim == 3 and arr.shape[2] == 4:
                arr = arr[:, :, :3]
            return np.ascontiguousarray(arr)

        if isinstance(image, Image.Image):
            return np.array(image.convert('RGB'))

        # Same path handling as BrainTumorClassifier.predict; the context manager closes the file.
        with Image.open(image) as opened:
            return np.array(opened.convert('RGB'))

    def segment(self, image, use_sam_refinement=True):
        """Segment the tumor in ``image``, optionally refining with SAM.

        Args:
            image: numpy array, PIL image or image path.
            use_sam_refinement: apply SAM refinement when SAM is loaded and the
                U-Net mask is non-empty.

        Returns:
            ``(final_mask, prob_map, unet_mask)`` at the input resolution, where
            the masks are uint8 0/1 arrays and ``prob_map`` holds sigmoid
            probabilities; ``(None, None, None)`` if no model is loaded or
            segmentation fails.
        """
        if self.model is None:
            return None, None, None

        try:
            image_np = self._to_rgb_array(image)
            orig_h, orig_w = image_np.shape[:2]

            # U-Net segmentation
            transform = A.Compose([
                A.Resize(self.img_size, self.img_size),
                A.Normalize(mean=self.mean, std=self.std),
                ToTensorV2(),
            ])

            transformed = transform(image=image_np)
            image_tensor = transformed['image'].unsqueeze(0).to(self.device)

            with torch.no_grad():
                output = self.model(image_tensor)
                prob_map = torch.sigmoid(output).cpu().squeeze().numpy()

            # Initial U-Net mask
            binary_mask = (prob_map > 0.5).astype(np.uint8)

            # Resize back to the input resolution (nearest keeps the mask binary)
            binary_mask = cv2.resize(binary_mask, (orig_w, orig_h), interpolation=cv2.INTER_NEAREST)
            prob_map = cv2.resize(prob_map, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)

            # Store U-Net mask for comparison
            unet_mask = binary_mask.copy()

            refined_mask = binary_mask
            if use_sam_refinement and self.sam_predictor is not None and binary_mask.sum() > 0:
                print("Applying SAM refinement...")
                refined_mask = self.refine_with_sam(image_np, binary_mask)
                print(f"Mask refined - Original pixels: {binary_mask.sum()}, Refined pixels: {refined_mask.sum()}")

            return refined_mask, prob_map, unet_mask

        except Exception as e:
            print(f"Segmentation error: {str(e)}")
            return None, None, None

class BrainTumorClassifier:
    """Classification model handler.

    Rebuilds a torchvision ResNet with the MLP head used in training from a
    checkpoint, and classifies single images via ``predict``.
    """

    def __init__(self, model_path):
        """Load a classification checkpoint and rebuild the model.

        Args:
            model_path: path to a checkpoint dict with 'class_names',
                'num_classes', 'config' ('img_size'), 'model_name', 'metrics'
                ('test_accuracy'), 'normalize_mean', 'normalize_std' and
                'model_state_dict'. A missing key raises KeyError.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Loading classification model from: {model_path}")

        # weights_only=False unpickles arbitrary objects: only load trusted checkpoints.
        self.checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)

        self.class_names = self.checkpoint['class_names']
        self.num_classes = self.checkpoint['num_classes']
        self.img_size = self.checkpoint['config']['img_size']
        self.model_name = self.checkpoint['model_name']
        self.test_accuracy = self.checkpoint['metrics']['test_accuracy']

        self.normalize_mean = self.checkpoint['normalize_mean']
        self.normalize_std = self.checkpoint['normalize_std']

        self.model = self._create_model()
        self.model.load_state_dict(self.checkpoint['model_state_dict'])
        self.model.eval()

        self.transform = self._create_transform()

        print(f"‚úì Classification model loaded successfully!")
        print(f"  Classes: {self.class_names}")

    def _create_model(self):
        """Recreate the ResNet backbone with the training-time classification head."""
        model_classes = {
            'resnet18': models.resnet18,
            'resnet34': models.resnet34,
            'resnet50': models.resnet50,
            'resnet101': models.resnet101
        }

        model = model_classes[self.model_name](weights=None)
        num_features = model.fc.in_features
        # Head layout must match training exactly so state-dict keys (fc.0 ... fc.4) line up.
        model.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(num_features, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, self.num_classes)
        )

        return model.to(self.device)

    def _create_transform(self):
        """Create the resize -> tensor -> normalize preprocessing pipeline."""
        return transforms.Compose([
            transforms.Resize((self.img_size, self.img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.normalize_mean, std=self.normalize_std)
        ])

    def predict(self, image):
        """Classify a single image.

        Args:
            image: numpy array, PIL image, or a path/file object accepted by
                ``PIL.Image.open``. Every input is converted to 3-channel RGB.

        Returns:
            ``(predicted_class, confidence_percent, all_probs)`` where
            ``all_probs`` maps each class name to its softmax probability in
            percent; ``(None, None, None)`` when ``image`` is None.
        """
        if image is None:
            return None, None, None

        if isinstance(image, np.ndarray):
            image = Image.fromarray(image).convert('RGB')
        elif isinstance(image, Image.Image):
            # PIL inputs may be grayscale ('L') or RGBA; without conversion they
            # give 1- or 4-channel tensors that the 3-channel Normalize and
            # ResNet stem cannot handle.
            image = image.convert('RGB')
        else:
            # The context manager closes the file handle opened for a path.
            with Image.open(image) as opened:
                image = opened.convert('RGB')

        image_tensor = self.transform(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            outputs = self.model(image_tensor)
            probabilities = torch.nn.functional.softmax(outputs, dim=1)
            confidence, predicted_idx = torch.max(probabilities, 1)

        predicted_class = self.class_names[predicted_idx.item()]
        confidence_percent = confidence.item() * 100
        all_probs = {
            self.class_names[i]: prob.item() * 100
            for i, prob in enumerate(probabilities[0])
        }

        return predicted_class, confidence_percent, all_probs

# Module-level model handles; both start empty and are filled in by initialize_models().
classifier, segmenter = None, None

def load_sample_images():
    """Load sample images from the specified directory"""
    sample_dir = "/content/drive/MyDrive/IP Project/sample_images"
    sample_images = []

    if os.path.exists(sample_dir):
        print(f"Loading sample images from: {sample_dir}")

        # Get all image files
        image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.tif', '*.tiff', '*.bmp']
        image_files = []

        for ext in image_extensions:
            image_files.extend(glob.glob(os.path.join(sample_dir, ext)))
            image_files.extend(glob.glob(os.path.join(sample_dir, ext.upper())))

        # Sort for consistent ordering
        image_files = sorted(image_files)

        # Limit to reasonable number of samples
        max_samples = 10
        for img_path in image_files[:max_samples]:
            try:
                # Verify the image can be opened
                img = Image.open(img_path)
                img.verify()

                # Get filename for label
                filename = os.path.basename(img_path)

                # Try to determine tumor type from filename
                label = filename
                if 'glioma' in filename.lower():
                    label = f"üìå Glioma - {filename}"
                elif 'meningioma' in filename.lower():
                    label = f"üìå Meningioma - {filename}"
                elif 'pituitary' in filename.lower():
                    label = f"üìå Pituitary - {filename}"
                elif 'notumor' in filename.lower() or 'no_tumor' in filename.lower() or 'normal' in filename.lower():
                    label = f"üìå Normal - {filename}"
                else:
                    label = f"üìå {filename}"

                sample_images.append([img_path, label])
                print(f"  ‚úì Loaded: {filename}")

            except Exception as e:
                print(f"  ‚úó Failed to load {os.path.basename(img_path)}: {str(e)}")

        print(f"Successfully loaded {len(sample_images)} sample images")
    else:
        print(f"Sample directory not found: {sample_dir}")

    if not sample_images:
        print("No sample images found - examples will not be available")
        return None

    return [[img_path] for img_path, _ in sample_images], [label for _, label in sample_images]

def initialize_models():
    """Initialize both models with better path detection"""
    global classifier, segmenter

    search_dirs = [
        "/content/drive/MyDrive/IP Project",
        "/kaggle/working",
        ".",
        os.getcwd()
    ]

    # SAM checkpoint paths to try
    sam_checkpoints = [
        "/content/drive/MyDrive/IP Project/sam_vit_b_01ec64.pth",
        "/content/sam_vit_b_01ec64.pth",
        "/kaggle/working/sam_vit_b_01ec64.pth",
        None  # Will proceed without SAM if not found
    ]

    # Find and load classification model
    classification_found = False
    for base_dir in search_dirs:
        if not os.path.exists(base_dir):
            continue

        for file in os.listdir(base_dir) if os.path.isdir(base_dir) else []:
            if 'resnet' in file.lower() and file.endswith('.pth'):
                try:
                    model_path = os.path.join(base_dir, file)
                    classifier = BrainTumorClassifier(model_path)
                    print(f"‚úì Classification model loaded from: {model_path}")
                    classification_found = True
                    break
                except Exception as e:
                    print(f"Failed to load {file}: {e}")

        if classification_found:
            break

    # Specific paths for classification model
    if not classification_found:
        specific_paths = [
            "/content/drive/MyDrive/IP Project/resnet101_brain_tumor_20251105_163019.pth",
            "/kaggle/working/best_model.pth"
        ]
        for path in specific_paths:
            if os.path.exists(path):
                try:
                    classifier = BrainTumorClassifier(path)
                    print(f"‚úì Classification model loaded from: {path}")
                    break
                except Exception as e:
                    print(f"Failed to load classification model: {e}")

    # Find and load segmentation model
    segmentation_found = False
    sam_checkpoint = None

    # Find SAM checkpoint
    for sam_path in sam_checkpoints:
        if sam_path and os.path.exists(sam_path):
            sam_checkpoint = sam_path
            print(f"‚úì SAM checkpoint found at: {sam_checkpoint}")
            break

    if sam_checkpoint is None and SAM_AVAILABLE:
        print("‚ö†Ô∏è SAM checkpoint not found - will proceed without SAM refinement")

    for base_dir in search_dirs:
        if not os.path.exists(base_dir):
            continue

        for file in os.listdir(base_dir) if os.path.isdir(base_dir) else []:
            if ('segment' in file.lower() or 'unet' in file.lower()) and file.endswith('.pth'):
                try:
                    model_path = os.path.join(base_dir, file)
                    segmenter = BrainTumorSegmenter(model_path, sam_checkpoint)
                    print(f"‚úì Segmentation model loaded from: {model_path}")
                    segmentation_found = True
                    break
                except Exception as e:
                    print(f"Failed to load {file}: {e}")

        if segmentation_found:
            break

    # Specific paths for segmentation model
    if not segmentation_found:
        specific_paths = [
            "/content/drive/MyDrive/IP Project/brain_tumor_segmentation_model.pth",
            "/kaggle/working/brain_tumor_segmentation_model.pth"
        ]
        for path in specific_paths:
            if os.path.exists(path):
                try:
                    segmenter = BrainTumorSegmenter(path, sam_checkpoint)
                    print(f"‚úì Segmentation model loaded from: {path}")
                    break
                except Exception as e:
                    print(f"Failed to load segmentation model: {e}")

    return classifier is not None, segmenter is not None

# Startup: load both models once and print which components are available.
print("=" * 60)
print("üöÄ Initializing Brain Tumor Analysis System...")
print("=" * 60)
class_loaded, seg_loaded = initialize_models()
print("=" * 60)
print("System Status:")
print("  Classification Model: " + ('‚úÖ Ready' if class_loaded else '‚ùå Not Found'))
print("  Segmentation Model: " + ('‚úÖ Ready' if seg_loaded else '‚ùå Not Found'))
print("  SAM Refinement: " + ('‚úÖ Available' if (segmenter and segmenter.sam_predictor) else '‚ùå Not Available'))
print("=" * 60)

# Gather example images for the UI; both names are None when there are none.
print("\nüìÅ Loading sample images...")
sample_data = load_sample_images()
if not sample_data:
    sample_images, sample_labels = None, None
    print("‚ö†Ô∏è No sample images available")
else:
    sample_images, sample_labels = sample_data
    print(f"‚úÖ {len(sample_images)} sample images ready for testing")
print("=" * 60)

def create_segmentation_figure(original_image, mask, prob_map, tumor_class, unet_mask=None):
    """Create enhanced segmentation visualization with SAM refinement comparison"""

    if isinstance(original_image, Image.Image):
        img_array = np.array(original_image.convert('RGB'))
    else:
        img_array = original_image.copy()

    # Determine if SAM was used
    sam_used = unet_mask is not None and mask is not None and not np.array_equal(mask, unet_mask)

    if tumor_class.lower() != 'notumor' and mask is not None:
        # Calculate tumor statistics
        tumor_pixels = mask.sum()
        total_pixels = mask.size
        tumor_percentage = (tumor_pixels / total_pixels) * 100

        if sam_used:
            # Create figure with 3 subplots for comparison
            fig = plt.figure(figsize=(18, 6))

            # 1. U-Net ONLY (Left)
            plt.subplot(1, 3, 1)
            overlay_unet = img_array.copy()
            red_overlay = np.zeros_like(img_array)
            red_overlay[:, :, 0] = 255

            for c in range(3):
                overlay_unet[:, :, c] = np.where(unet_mask == 1,
                                           img_array[:, :, c] * 0.6 + red_overlay[:, :, c] * 0.4,
                                           img_array[:, :, c])

            alpha = 0.4
            blended_unet = cv2.addWeighted(img_array, 1-alpha, overlay_unet.astype(np.uint8), alpha, 0)
            plt.imshow(blended_unet)
            plt.title('U-Net Segmentation', fontsize=14, fontweight='bold')
            plt.axis('off')

            # 2. SAM REFINED (Middle)
            plt.subplot(1, 3, 2)
            overlay_sam = img_array.copy()
            green_overlay = np.zeros_like(img_array)
            green_overlay[:, :, 1] = 255

            for c in range(3):
                overlay_sam[:, :, c] = np.where(mask == 1,
                                           img_array[:, :, c] * 0.6 + green_overlay[:, :, c] * 0.4,
                                           img_array[:, :, c])

            blended_sam = cv2.addWeighted(img_array, 1-alpha, overlay_sam.astype(np.uint8), alpha, 0)
            plt.imshow(blended_sam)
            plt.title('SAM Refined Segmentation', fontsize=14, fontweight='bold', color='green')
            plt.axis('off')

            # 3. COMPARISON (Right)
            plt.subplot(1, 3, 3)
            comparison_img = img_array.copy()

            # Show U-Net contours in red
            contours_unet, _ = cv2.findContours(unet_mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            cv2.drawContours(comparison_img, contours_unet, -1, (255, 0, 0), 2)

            # Show SAM contours in green
            contours_sam, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            cv2.drawContours(comparison_img, contours_sam, -1, (0, 255, 0), 2)

            plt.imshow(comparison_img)
            plt.title('Comparison (Red: U-Net, Green: SAM)', fontsize=14, fontweight='bold')
            plt.axis('off')

            # Calculate improvement
            unet_pixels = unet_mask.sum()
            improvement = ((tumor_pixels - unet_pixels) / unet_pixels * 100) if unet_pixels > 0 else 0

            plt.suptitle(
                f'{tumor_class.upper()} TUMOR - SAM REFINED\n' +
                f'Coverage: {tumor_percentage:.1f}% | Refinement: {improvement:+.1f}%',
                fontsize=16,
                fontweight='bold',
                color='darkgreen'
            )

        else:
            # Original 2-subplot layout when SAM not used
            fig = plt.figure(figsize=(12, 6))

            # 1. BLENDED VIEW (Left)
            plt.subplot(1, 2, 1)
            overlay = img_array.copy()
            red_overlay = np.zeros_like(img_array)
            red_overlay[:, :, 0] = 255

            for c in range(3):
                overlay[:, :, c] = np.where(mask == 1,
                                           img_array[:, :, c] * 0.6 + red_overlay[:, :, c] * 0.4,
                                           img_array[:, :, c])

            alpha = 0.4
            blended = cv2.addWeighted(img_array, 1-alpha, overlay.astype(np.uint8), alpha, 0)
            plt.imshow(blended)
            plt.title('Tumor Region Overlay', fontsize=14, fontweight='bold')
            plt.axis('off')

            # 2. TUMOR BOUNDARIES (Right)
            plt.subplot(1, 2, 2)
            contour_img = img_array.copy()
            contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            cv2.drawContours(contour_img, contours, -1, (0, 255, 0), 3)

            overlay_contour = img_array.copy()
            cv2.fillPoly(overlay_contour, contours, (0, 255, 0))
            contour_img = cv2.addWeighted(contour_img, 0.8, overlay_contour, 0.2, 0)

            plt.imshow(contour_img)
            plt.title('Tumor Boundaries', fontsize=14, fontweight='bold')
            plt.axis('off')

            plt.suptitle(
                f'{tumor_class.upper()} TUMOR DETECTED\nEstimated Coverage: {tumor_percentage:.1f}% of Brain Area',
                fontsize=16,
                fontweight='bold',
                color='darkred'
            )
    else:
        # No tumor case
        fig = plt.figure(figsize=(12, 6))
        plt.subplot(1, 1, 1)
        plt.imshow(img_array)
        plt.title('NO TUMOR DETECTED', fontsize=18, fontweight='bold', color='green')
        plt.text(0.5, -0.05, 'Brain tissue appears healthy',
                ha='center', transform=plt.gca().transAxes,
                fontsize=12, style='italic', color='green')
        plt.axis('off')

    plt.tight_layout()

    # Convert to PIL Image
    buf = io.BytesIO()
    plt.savefig(buf, format='png', dpi=120, bbox_inches='tight', facecolor='white', edgecolor='none')
    plt.close('all')
    buf.seek(0)

    return Image.open(buf)

def create_probability_plot(probabilities):
    """Build a horizontal Plotly bar chart of per-class probabilities.

    Args:
        probabilities: mapping of class name -> probability in percent, or None.

    Returns:
        A ``plotly.graph_objects.Figure``, or None when ``probabilities`` is
        None. Bars with the top probability are red (green for 'notumor');
        all other bars are gray.
    """
    if probabilities is None:
        return None

    names = list(probabilities)
    values = list(probabilities.values())
    top = max(values, default=None)

    def bar_colour(name, value):
        if value != top:
            return '#95a5a6'  # Gray for other classes
        return '#2ecc71' if name.lower() == 'notumor' else '#e74c3c'  # Green for no tumor, red for tumor

    bar = go.Bar(
        x=values,
        y=[name.upper() for name in names],
        orientation='h',
        marker=dict(color=[bar_colour(n, v) for n, v in zip(names, values)]),
        text=[f'{v:.1f}%' for v in values],
        textposition='outside',
        hovertemplate='<b>%{y}</b><br>Probability: %{x:.2f}%<extra></extra>'
    )
    fig = go.Figure(data=[bar])

    fig.update_layout(
        title={
            'text': "Classification Confidence Scores",
            'font': {'size': 16, 'color': '#2c3e50'}
        },
        xaxis_title="Probability (%)",
        yaxis_title="Tumor Type",
        xaxis=dict(range=[0, 105]),
        height=350,
        template="plotly_white",
        showlegend=False,
        margin=dict(l=100, r=20, t=50, b=50)
    )

    return fig

def analyze_brain_mri(image, use_sam):
    """Classify a brain MRI and, for tumor predictions, visualize the tumor region.

    Gradio callback behind the "Analyze MRI" button. It reads the module-level
    ``classifier`` and ``segmenter`` objects and uses the
    ``create_probability_plot`` / ``create_segmentation_figure`` helpers.

    Args:
        image: The uploaded MRI (PIL image from the Gradio input), or ``None``
            when nothing has been uploaded yet.
        use_sam: Checkbox value. When truthy, the segmenter is asked to refine
            the U-Net mask with SAM.

    Returns:
        A 5-tuple ``(diagnosis, confidence_text, prob_chart, seg_image,
        interpretation_markdown)`` matching the five Gradio output components.
        Exceptions raised during analysis are caught, printed with a traceback
        and reported through this tuple instead of propagating.
    """

    if classifier is None:
        return (
            "‚ùå Model Not Loaded",
            "",
            None,
            None,
            "### ‚ö†Ô∏è Classification Model Not Found\nPlease ensure the model file exists."
        )

    if image is None:
        return (
            "Awaiting Upload",
            "",
            None,
            None,
            "### üì§ Upload Required\nPlease upload a brain MRI scan or select a sample image."
        )

    try:
        # Perform classification
        predicted_class, confidence, probabilities = classifier.predict(image)
        is_healthy = predicted_class.lower() == 'notumor'

        # Format diagnosis
        if is_healthy:
            diagnosis = "‚úÖ HEALTHY"
            emoji = "‚úÖ"
        else:
            diagnosis = f"‚ö†Ô∏è {predicted_class.upper()} TUMOR"
            emoji = "‚ö†Ô∏è"

        confidence_text = f"{confidence:.1f}%"

        # Create probability chart
        prob_chart = create_probability_plot(probabilities)

        # Perform segmentation
        seg_image = None
        seg_status = "Not Performed"
        refinement_status = ""
        # True only when the displayed mask really came from SAM refinement.
        # The "SAM Refinement Applied" report section depends on this flag, so
        # healthy scans and failed or U-Net-only segmentations don't claim SAM ran.
        sam_applied = False

        if segmenter and segmenter.model is not None:
            if not is_healthy:
                # Use SAM refinement based on checkbox
                result = segmenter.segment(image, use_sam_refinement=use_sam)

                # segment() returns either (mask, prob_map) or (mask, prob_map, unet_mask)
                if len(result) == 3:
                    mask, prob_map, unet_mask = result
                else:
                    mask, prob_map = result
                    unet_mask = None

                if mask is not None:
                    seg_image = create_segmentation_figure(image, mask, prob_map, predicted_class, unet_mask)

                    if use_sam and segmenter.sam_predictor and unet_mask is not None:
                        sam_applied = True
                        seg_status = "‚úÖ Tumor Localized with SAM Refinement"
                        refinement_status = " (SAM-Enhanced)"
                    else:
                        seg_status = "‚úÖ Tumor Localized with U-Net"
                        refinement_status = " (U-Net Only)"
                else:
                    seg_image = image
                    seg_status = "‚ö†Ô∏è Segmentation Failed"
            else:
                seg_image = create_segmentation_figure(image, None, None, predicted_class)
                seg_status = "‚úÖ No Segmentation Needed (Healthy)"
        else:
            seg_image = image if isinstance(image, Image.Image) else Image.fromarray(image)
            seg_status = "‚ùå Segmentation Model Unavailable"

        # Generate interpretation
        interpretation = f"""
## {emoji} Analysis Results

### üîç Primary Diagnosis
- **Classification:** {predicted_class.capitalize()}
- **Confidence:** {confidence:.1f}%
- **Segmentation:** {seg_status}{refinement_status}

### üìä Class Probabilities
"""

        # Add probability breakdown, highest first
        for cls, prob in sorted(probabilities.items(), key=lambda x: x[1], reverse=True):
            if cls == predicted_class:
                interpretation += f"\n**‚Üí {cls.capitalize()}: {prob:.1f}%** *(Detected)*"
            else:
                interpretation += f"\n   {cls.capitalize()}: {prob:.1f}%"

        # Add SAM refinement info only when SAM actually refined the shown mask
        if sam_applied:
            interpretation += "\n\n### üéØ SAM Refinement Applied\n"
            interpretation += "- Using Segment Anything Model for enhanced boundary detection\n"
            interpretation += "- Bounding box extracted from U-Net prediction\n"
            interpretation += "- SAM refined segmentation within tumor region"

        # Add clinical notes
        clinical_info = {
            'glioma': """
### üè• Clinical Notes: GLIOMA
- **Type:** Primary brain tumor from glial cells
- **Action:** Urgent neurology referral recommended
- **Treatment:** May include surgery, radiation, chemotherapy""",
            'meningioma': """
### üè• Clinical Notes: MENINGIOMA
- **Type:** Usually benign, slow-growing tumor
- **Action:** Neurosurgical consultation advised
- **Treatment:** Monitoring or surgical removal""",
            'pituitary': """
### üè• Clinical Notes: PITUITARY TUMOR
- **Type:** Adenoma affecting hormone production
- **Action:** Endocrinology evaluation needed
- **Treatment:** Medication, surgery, or radiation""",
            'notumor': """
### ‚úÖ Clinical Notes: HEALTHY BRAIN
- **Finding:** No abnormalities detected
- **Action:** Continue routine health monitoring
- **Note:** Regular checkups still recommended"""
        }

        if predicted_class.lower() in clinical_info:
            interpretation += f"\n{clinical_info[predicted_class.lower()]}"

        interpretation += f"""

---
**‚ö†Ô∏è Medical Disclaimer:** This AI analysis is for screening purposes only. Always consult qualified healthcare professionals for diagnosis and treatment.

*Analysis performed: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}*
*Model: {classifier.model_name.upper()} | Accuracy: {classifier.test_accuracy:.1f}%*
"""

        return diagnosis, confidence_text, prob_chart, seg_image, interpretation

    except Exception as e:
        # Keep the UI responsive: log the full traceback to the console and
        # surface the message in the report panel instead of raising.
        print(f"Error during analysis: {str(e)}")
        import traceback
        traceback.print_exc()

        return (
            "‚ùå Error",
            "",
            None,
            None,
            f"### Analysis Error\n```\n{str(e)}\n```"
        )

# Create Gradio Interface
# Layout: the left column holds the inputs (upload, SAM toggle, buttons,
# optional samples, system status); the right column holds the outputs.
# Components appear on screen in the order they are created inside each
# `with` context, so statement order here is significant.
with gr.Blocks(
    title="üß† Brain Tumor AI Analysis",
    theme=gr.themes.Soft(),
    css="""
    .gradio-container {font-family: 'Inter', 'Arial', sans-serif;}
    .output-class {font-size: 22px !important; font-weight: bold !important; color: #2c3e50;}
    .sample-gallery {margin-top: 15px; border-radius: 8px;}
    h1 {color: #2c3e50 !important;}
    """
) as demo:

    gr.Markdown("""
    # üß† AI Brain Tumor Detection & Localization System

    **Advanced Deep Learning Analysis** - Instant classification and visual tumor localization with SAM refinement for brain MRI scans.
    """)

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(
                label="üì§ Upload Brain MRI",
                type="pil",
                elem_id="input-image"
            )

            # SAM refinement checkbox: checked and editable only when a SAM
            # predictor was loaded; otherwise shown disabled and unchecked.
            use_sam_refinement = gr.Checkbox(
                label="üéØ Enable SAM Refinement",
                value=True if (segmenter and segmenter.sam_predictor) else False,
                interactive=True if (segmenter and segmenter.sam_predictor) else False,
                info="Use Segment Anything Model to refine tumor boundaries"
            )

            with gr.Row():
                analyze_btn = gr.Button(
                    "üîç Analyze MRI",
                    variant="primary",
                    size="lg",
                    elem_id="analyze-button"
                )
                clear_btn = gr.Button("üóëÔ∏è Clear", variant="secondary")

            # Sample images section (only rendered when samples were loaded)
            if sample_images and sample_labels:
                gr.Markdown("### üñºÔ∏è Quick Test Samples")
                gr.Markdown("*Click any sample to load it*")

                examples = gr.Examples(
                    examples=sample_images,
                    inputs=input_image,
                    label="Available Samples",
                    examples_per_page=5,
                )

            # System status, evaluated once when the UI is built.
            # Each entry is a list item: plain consecutive lines would be
            # merged by Markdown into a single paragraph on one line.
            gr.Markdown(f"""
            ### üíª System Status
            - Classification: {'üü¢ Online' if classifier else 'üî¥ Offline'}
            - Segmentation: {'üü¢ Online' if (segmenter and segmenter.model) else 'üî¥ Offline'}
            - SAM Refinement: {'üü¢ Ready' if (segmenter and segmenter.sam_predictor) else 'üî¥ Not Available'}
            - Samples: {'‚úÖ Loaded' if sample_images else '‚ùå None'}
            - Device: {'üöÄ GPU' if torch.cuda.is_available() else 'üíª CPU'}
            """)

        with gr.Column(scale=2):
            with gr.Row():
                diagnosis_output = gr.Textbox(
                    label="üìã Diagnosis",
                    elem_classes="output-class"
                )
                confidence_output = gr.Textbox(
                    label="üéØ Confidence",
                )

            prob_chart = gr.Plot(label="üìä Classification Analysis")

            seg_output = gr.Image(
                label="üîç Tumor Visualization (Enhanced with SAM)",
                type="pil"
            )

            interpretation = gr.Markdown(label="üìù Detailed Report")

    # Event handlers: output order must match the 5-tuple returned by analyze_brain_mri
    analyze_btn.click(
        fn=analyze_brain_mri,
        inputs=[input_image, use_sam_refinement],
        outputs=[diagnosis_output, confidence_output, prob_chart, seg_output, interpretation]
    )

    # Clear resets the image and every output, and restores the checkbox default
    clear_btn.click(
        fn=lambda: (None, True if (segmenter and segmenter.sam_predictor) else False, "", "", None, None, ""),
        inputs=[],
        outputs=[input_image, use_sam_refinement, diagnosis_output, confidence_output, prob_chart, seg_output, interpretation]
    )

    gr.Markdown("""
    ---
    ### üìö System Information

    **Capabilities:**
    - üéØ 4-Class tumor classification (Glioma, Meningioma, Pituitary, No Tumor)
    - üîç Automatic tumor region visualization with U-Net
    - üéØ SAM refinement for enhanced boundary detection
    - üìä Confidence scoring across all classes

    **Visualization Modes:**
    - **U-Net Only:** Fast segmentation using trained U-Net model
    - **SAM Refined:** Enhanced boundaries using Segment Anything Model
    - **Comparison View:** Side-by-side comparison when SAM is enabled

    *For medical professionals and research purposes only.*
    """)

if __name__ == "__main__":
    # Print a short pre-launch summary of what loaded, then start the web UI.
    _banner_rule = "=" * 60
    print("\n" + _banner_rule)
    print("üöÄ Launching Brain Tumor Analysis System...")
    print(_banner_rule)

    print(
        f"‚úÖ {len(sample_images)} sample images ready"
        if sample_images
        else "‚ö†Ô∏è No sample images available"
    )

    _sam_ready = segmenter and segmenter.sam_predictor
    print(
        "‚úÖ SAM refinement capability enabled"
        if _sam_ready
        else "‚ö†Ô∏è SAM refinement not available"
    )

    print("\nüì± Starting web interface...")
    # share=True asks Gradio for a public share link in addition to the local server.
    demo.launch(share=True, debug=False)

üöÄ Initializing Brain Tumor Analysis System...
Loading classification model from: /content/drive/MyDrive/IP Project/resnet101_brain_tumor_20251105_163019.pth
‚úì Classification model loaded successfully!
  Classes: ['glioma', 'meningioma', 'notumor', 'pituitary']
‚úì Classification model loaded from: /content/drive/MyDrive/IP Project/resnet101_brain_tumor_20251105_163019.pth
‚úì SAM checkpoint found at: /content/sam_vit_b_01ec64.pth
Loading segmentation model from: /content/drive/MyDrive/IP Project/brain_tumor_segmentation_model.pth
  Creating U-Net with efficientnet-b3 encoder...
‚úì Segmentation model loaded successfully!
Loading SAM model from: /content/sam_vit_b_01ec64.pth
‚úì SAM model (vit_b) loaded successfully!
‚úì Segmentation model loaded from: /content/drive/MyDrive/IP Project/brain_tumor_segmentation_model.pth
System Status:
  Classification Model: ‚úÖ Ready
  Segmentation Model: ‚úÖ Ready
  SAM Refinement: ‚úÖ Available

üìÅ Loading sample images...
Loading sample image