# 🏥 Interactive Humerus Segmentation with SAM

## Universidad Nacional de Colombia - Computational Geometry

**Student:** Thomas Molina Molina  
**Instructor:** Johan Felipe Garcia Vargas

This notebook implements interactive segmentation of medical images using the **Segment Anything Model (SAM)**, with positive and negative point prompts and a mask preview that updates after every click.

---

## 📦 Step 1: Install Dependencies

In [None]:
!pip install opencv-python-headless scikit-image matplotlib pillow torch torchvision
!pip install git+https://github.com/facebookresearch/segment-anything.git

## 📥 Step 2: Download the SAM Checkpoint

In [None]:
import os
import urllib.request

# Create a directory for the checkpoints
os.makedirs('checkpoints', exist_ok=True)

# Download SAM ViT-H (632M parameters, the highest-quality variant)
checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
checkpoint_path = "checkpoints/sam_vit_h_4b8939.pth"

if not os.path.exists(checkpoint_path):
    print("⏬ Downloading SAM ViT-H checkpoint (~2.4 GB)...")
    urllib.request.urlretrieve(checkpoint_url, checkpoint_path)
    print("✅ Download complete!")
else:
    print("✅ Checkpoint already exists")
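
The existence check above cannot tell a complete checkpoint from one left behind by an interrupted download. The following is a minimal sketch of one way to guard against that, assuming a hypothetical helper named `download_atomic` (not part of the notebook) that writes to a temporary `.part` file and renames it only after the transfer finishes.

```python
import os
import urllib.request


def download_atomic(url, dest_path):
    """Download url to dest_path via a temporary '.part' file.

    Hypothetical helper (not part of the notebook): the file only appears at
    dest_path after the transfer has finished, so a later existence check
    never mistakes a truncated download for a complete checkpoint.
    """
    if os.path.exists(dest_path):
        return dest_path  # already downloaded in a previous run
    parent = os.path.dirname(dest_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    part_path = dest_path + ".part"
    try:
        urllib.request.urlretrieve(url, part_path)
        os.replace(part_path, dest_path)  # rename within the same directory
    finally:
        # Remove leftovers if the download or the rename failed
        if os.path.exists(part_path):
            os.remove(part_path)
    return dest_path
```

In this notebook it could be called as `download_atomic(checkpoint_url, checkpoint_path)` in place of the `if not os.path.exists(...)` block.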

## 📤 Step 3: Load the Medical Image

**Option 1:** Download the example image from GitHub (recommended)

**Option 2:** Upload your own image manually

In [None]:
import os
import shutil
import urllib.request

# Create a directory for the images
os.makedirs('images', exist_ok=True)

# OPTION 1: Download the example image from GitHub
print("Option 1: downloading the example image from GitHub")
image_url = "https://raw.githubusercontent.com/ThomasMolina19/interactive-medsam/main/dicom_pngs/I11.png"
image_path = "images/I11.png"

try:
    print("⏬ Downloading image from GitHub...")
    urllib.request.urlretrieve(image_url, image_path)
    print(f"✅ Image downloaded: {image_path}")
except Exception as e:
    # Fall back to a manual upload (Colab only) if the download fails
    print(f"⚠️ Download failed: {e}")
    print("\n📤 Please upload your image manually:")
    from google.colab import files
    uploaded = files.upload()
    uploaded_filename = list(uploaded.keys())[0]
    image_path = f"images/{uploaded_filename}"
    shutil.move(uploaded_filename, image_path)
    print(f"✅ Image loaded: {image_path}")

# OPTION 2: To upload your own image instead, uncomment these lines:
# from google.colab import files
# print("📤 Upload your medical image (PNG or JPG):")
# uploaded = files.upload()
# uploaded_filename = list(uploaded.keys())[0]
# image_path = f"images/{uploaded_filename}"
# import shutil
# shutil.move(uploaded_filename, image_path)
# print(f"✅ Image loaded: {image_path}")

## 🔧 Step 4: Import Libraries and Configure the Model

In [None]:
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import torch
from segment_anything import sam_model_registry, SamPredictor
from scipy import ndimage
from skimage import morphology
import cv2
from IPython.display import display, clear_output
import ipywidgets as widgets

# Select the device (GPU if available)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Show detailed hardware information
print("="*60)
print("🖥️  HARDWARE INFORMATION")
print("="*60)
print(f"Selected device: {device.upper()}")

if torch.cuda.is_available():
    print(f"✅ GPU detected: {torch.cuda.get_device_name(0)}")
    print(f"📊 Total GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
    print(f"📈 Available GPU memory: {(torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated(0)) / 1024**3:.2f} GB")
    print(f"🔢 Number of GPUs: {torch.cuda.device_count()}")
    print(f"⚡ CUDA version: {torch.version.cuda}")
else:
    print("⚠️  No GPU detected - using CPU")
    print("💡 To enable a GPU in Colab:")
    print("   Runtime → Change runtime type → Hardware accelerator: T4 GPU")

print(f"🐍 PyTorch version: {torch.__version__}")
print("="*60)

# Load the SAM model
print("\n⏳ Loading SAM model...")
sam = sam_model_registry["vit_h"](checkpoint=checkpoint_path)
sam = sam.to(device)
predictor = SamPredictor(sam)

# Show memory usage after loading the model
if torch.cuda.is_available():
    print(f"📊 GPU memory used by the model: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")
    print(f"📈 Available GPU memory: {(torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated(0)) / 1024**3:.2f} GB")

print("✅ SAM model loaded successfully!")

## 🖼️ Step 5: Load and Preprocess the Image

In [None]:
# Load the image
img = np.array(Image.open(image_path).convert("RGB"))

# Boost contrast for medical images
img_enhanced = cv2.convertScaleAbs(img, alpha=1.2, beta=10)

# Give the predictor the enhanced image (this computes the image embedding)
predictor.set_image(img_enhanced)

H, W = img.shape[:2]
print(f"📏 Image dimensions: {W} x {H}")

# Show the original and enhanced images side by side
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
ax1.imshow(img)
ax1.set_title("Original Image")
ax1.axis('off')

ax2.imshow(img_enhanced)
ax2.set_title("Contrast-Enhanced Image")
ax2.axis('off')

plt.tight_layout()
plt.show()
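
For readers unfamiliar with `cv2.convertScaleAbs`, the sketch below reproduces its per-pixel arithmetic in plain NumPy: each value is scaled by `alpha`, shifted by `beta`, taken in absolute value, and saturated to the 8-bit range. The function name `scale_abs` is hypothetical and the code is only an illustration of the math, not OpenCV's implementation.

```python
import numpy as np


def scale_abs(image, alpha=1.2, beta=10.0):
    """NumPy sketch of the per-pixel math behind cv2.convertScaleAbs."""
    scaled = np.abs(image.astype(np.float64) * alpha + beta)
    # Round, then saturate to the 8-bit range instead of wrapping around
    return np.clip(np.rint(scaled), 0, 255).astype(np.uint8)


pixels = np.array([0, 100, 250], dtype=np.uint8)
print(scale_abs(pixels))  # dark pixels are lifted, bright pixels saturate at 255
```

With `alpha=1.2` and `beta=10`, any input above roughly 204 saturates to 255, so very bright bone regions lose some internal detail in exchange for stronger overall contrast.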

## 🎯 Step 6: Interactive Segmentation

### Instructions:
- **RIGHT click**: add a POSITIVE point (green ⭐) that marks the object of interest
- **LEFT click**: add a NEGATIVE point (red ✖️) that marks a region to exclude
- **'z' key**: undo the last point
- **'c' key**: clear all points
- **When you are done**: run the next step; the selected points stay stored in `selector_obj`
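
As a rough illustration of how clicks become SAM prompts: `SamPredictor.predict` takes an `(N, 2)` array of `(x, y)` pixel coordinates and an `(N,)` array of labels, where 1 marks foreground and 0 marks background. The helper name `build_point_prompts` below is hypothetical and the sample coordinates are made up.

```python
import numpy as np


def build_point_prompts(positive_points, negative_points):
    """Stack clicked points into the arrays used as point prompts.

    Returns an (N, 2) float array of (x, y) pixel coordinates and an (N,)
    integer array of labels: 1 = foreground, 0 = background.
    """
    points = list(positive_points) + list(negative_points)
    labels = [1] * len(positive_points) + [0] * len(negative_points)
    coords = np.asarray(points, dtype=np.float32).reshape(-1, 2)
    return coords, np.asarray(labels, dtype=np.int64)


# Made-up example: one positive click and two negative clicks
coords, labels = build_point_prompts([[120.0, 80.5]], [[30.0, 40.0], [200.0, 10.0]])
print(coords.shape, labels)
```

The two returned arrays would be passed as `point_coords` and `point_labels`, respectively.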

In [None]:
# Install the interactive matplotlib widget backend for Colab
!pip install ipympl -q
from google.colab import output
output.enable_custom_widget_manager()

In [None]:
%matplotlib widget

class PointSelector:
    def __init__(self, ax_img, ax_mask):
        self.positive_points = []
        self.negative_points = []
        self.ax_img = ax_img
        self.ax_mask = ax_mask
        self.point_markers = []
        self.mask_display = None
        
    def update_segmentation(self):
        """Update segmentation in real-time"""
        # Clear previous mask
        if self.mask_display is not None:
            self.mask_display.remove()
            self.mask_display = None
        
        # If there are no points, reset the preview and return
        if len(self.positive_points) == 0 and len(self.negative_points) == 0:
            self.ax_mask.clear()
            self.ax_mask.imshow(img)
            self.ax_mask.set_title("Mask (add points to see it)")
            self.ax_mask.axis('off')
            self.ax_mask.figure.canvas.draw_idle()
            return
        
        # Prepare points and labels
        input_points = []
        input_labels = []
        
        for point in self.positive_points:
            input_points.append(point)
            input_labels.append(1)
        
        for point in self.negative_points:
            input_points.append(point)
            input_labels.append(0)
        
        input_points = np.array(input_points)
        input_labels = np.array(input_labels)
        
        # Generate mask
        try:
            masks, scores, _ = predictor.predict(
                point_coords=input_points,
                point_labels=input_labels,
                multimask_output=True
            )
            
            best_mask = masks[np.argmax(scores)]
            
            # Display mask on right subplot
            self.ax_mask.clear()
            self.ax_mask.imshow(img)
            self.mask_display = self.ax_mask.imshow(best_mask, alpha=0.6, cmap='Blues')
            
            # Show points on mask view too
            for point in self.positive_points:
                self.ax_mask.plot(point[0], point[1], 'g*', markersize=15, markeredgewidth=2)
            for point in self.negative_points:
                self.ax_mask.plot(point[0], point[1], 'rx', markersize=12, markeredgewidth=3)
            
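            score = scores[np.argmax(scores)]
            area = np.sum(best_mask)
            self.ax_mask.set_title(f"Segmentation | Score: {score:.3f} | Area: {area} px")
            self.ax_mask.axis('off')
            
        except Exception as e:
            print(f"⚠️ Segmentation error: {e}")
        
        # Redraw through the axes' own figure rather than the global `fig`,
        # which a later cell rebinds to a different figure
        self.ax_mask.figure.canvas.draw_idle()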
        
    def onclick(self, event):
        if event.inaxes != self.ax_img:
            return
        if event.xdata is None or event.ydata is None:
            return
            
        x, y = event.xdata, event.ydata
        
        # Left button (1) = NEGATIVE point (red)
        if event.button == 1:
            self.negative_points.append([x, y])
            marker = self.ax_img.plot(x, y, 'rx', markersize=15, markeredgewidth=3)[0]
            self.point_markers.append(('neg', marker))
            print(f"❌ NEGATIVE point added: ({x:.0f}, {y:.0f})")
            
        # Right button (3) = POSITIVE point (green)
        elif event.button == 3:
            self.positive_points.append([x, y])
            marker = self.ax_img.plot(x, y, 'g*', markersize=20, markeredgewidth=2)[0]
            self.point_markers.append(('pos', marker))
            print(f"✅ POSITIVE point added: ({x:.0f}, {y:.0f})")
        
        # Update title with counts
        self.ax_img.set_title(f"✅ Positive: {len(self.positive_points)} | ❌ Negative: {len(self.negative_points)} | 'z': undo | 'c': clear")
        
        # Update segmentation in real-time
        self.update_segmentation()
        
    def onkey(self, event):
        """Handle keyboard events"""
        # Z = Undo last point
        if event.key == 'z':
            if len(self.point_markers) > 0:
                point_type, marker = self.point_markers.pop()
                marker.remove()
                
                if point_type == 'pos' and len(self.positive_points) > 0:
                    removed = self.positive_points.pop()
                    print(f"↩️  Undid POSITIVE point: ({removed[0]:.0f}, {removed[1]:.0f})")
                elif point_type == 'neg' and len(self.negative_points) > 0:
                    removed = self.negative_points.pop()
                    print(f"↩️  Undid NEGATIVE point: ({removed[0]:.0f}, {removed[1]:.0f})")
                
                self.ax_img.set_title(f"✅ Positive: {len(self.positive_points)} | ❌ Negative: {len(self.negative_points)} | 'z': undo | 'c': clear")
                self.update_segmentation()
        
        # C = Clear all points
        elif event.key == 'c':
            for _, marker in self.point_markers:
                marker.remove()
            self.point_markers.clear()
            self.positive_points.clear()
            self.negative_points.clear()
            print("🧹 All points cleared")
            self.ax_img.set_title("✅ Positive: 0 | ❌ Negative: 0 | 'z': undo | 'c': clear")
            self.update_segmentation()

# Create the selector object with 2 subplots
fig, (ax_img, ax_mask) = plt.subplots(1, 2, figsize=(20, 8))

# Left: Image with points
ax_img.imshow(img)
ax_img.set_title("🎯 Original Image | Right click = POSITIVE | Left click = NEGATIVE")
ax_img.axis('off')

# Right: Real-time mask
ax_mask.imshow(img)
ax_mask.set_title("Segmentation (add points to see it)")
ax_mask.axis('off')

selector_obj = PointSelector(ax_img, ax_mask)

# Connect events
fig.canvas.mpl_connect('button_press_event', selector_obj.onclick)
fig.canvas.mpl_connect('key_press_event', selector_obj.onkey)

plt.tight_layout()
plt.show()

print("🎯 Point selection ready...")
print("   - RIGHT click: mark POSITIVE points (object of interest)")
print("   - LEFT click: mark NEGATIVE points (regions to exclude)")
print("   - 'z' key: undo the last point")
print("   - 'c' key: clear all points")
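
The `PointSelector` class keeps three parallel structures (`positive_points`, `negative_points`, and `point_markers`) and relies on them staying in sync for undo to work. The sketch below shows an alternative bookkeeping design with a single ordered history, stripped of all plotting so the logic is easy to test. `PointHistory` is a hypothetical class for illustration, not something the notebook defines.

```python
class PointHistory:
    """Hypothetical, plotting-free model of the selector's point bookkeeping."""

    def __init__(self):
        self.entries = []  # (kind, (x, y)) tuples in click order

    def add(self, kind, x, y):
        if kind not in ("pos", "neg"):
            raise ValueError("kind must be 'pos' or 'neg'")
        self.entries.append((kind, (x, y)))

    def undo(self):
        """Remove and return the most recent point, or None if empty."""
        return self.entries.pop() if self.entries else None

    def clear(self):
        self.entries.clear()

    def points(self, kind):
        """Return all points of one kind, in the order they were clicked."""
        return [p for k, p in self.entries if k == kind]


history = PointHistory()
history.add("pos", 120, 80)
history.add("neg", 30, 40)
history.undo()  # removes the negative point, whichever kind came last
print(history.points("pos"), history.points("neg"))
```

Because every click lands in one list, undo always removes the most recent click regardless of its type, and the positive and negative lists are derived on demand instead of being maintained separately.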

## 🔬 Step 7: Generate the Final Segmentation and Post-process It

In [None]:
# Prepare points and labels for SAM
input_points = []
input_labels = []

# Add positive points (label = 1)
for point in selector_obj.positive_points:
    input_points.append(point)
    input_labels.append(1)

# Add negative points (label = 0)
for point in selector_obj.negative_points:
    input_points.append(point)
    input_labels.append(0)

if len(input_points) == 0:
    print("⚠️ No points were selected. Please run the previous cell and select points.")
else:
    input_points = np.array(input_points)
    input_labels = np.array(input_labels)

    print(f"✅ Total points: {len(input_points)}")
    print(f"   - Positive: {len(selector_obj.positive_points)}")
    print(f"   - Negative: {len(selector_obj.negative_points)}")

    # Generate masks using the selected points
    masks, scores, _ = predictor.predict(
        point_coords=input_points,
        point_labels=input_labels,
        multimask_output=True
    )

    # Select best mask
    best_mask = masks[np.argmax(scores)]

    # Post-process mask
    def refine_medical_mask(mask):
        """Clean up the segmentation mask for medical images"""
        # Remove small objects
        mask_clean = morphology.remove_small_objects(mask, min_size=500)
        
        # Fill holes
        mask_filled = ndimage.binary_fill_holes(mask_clean)
        
        # Smooth with morphological operations
        kernel = morphology.disk(2)
        mask_smooth = morphology.binary_opening(mask_filled, kernel)
        mask_smooth = morphology.binary_closing(mask_smooth, kernel)
        
        return mask_smooth

    refined_mask = refine_medical_mask(best_mask)
    
    print("✅ Segmentation completed and refined!")
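
To make the hole-filling step less of a black box, here is a plain-NumPy sketch of the underlying idea: flood-fill the background starting from the image border, then treat everything the fill could not reach as part of the object. This is an illustrative re-implementation under the assumption of 4-connected background (it is not SciPy's code and is much slower); the name `fill_holes` is hypothetical.

```python
from collections import deque

import numpy as np


def fill_holes(mask):
    """Fill background regions that are not connected to the image border."""
    mask = np.asarray(mask, dtype=bool)
    h, w = mask.shape
    outside = np.zeros((h, w), dtype=bool)  # background reachable from the border
    queue = deque()
    # Seed the search with every background pixel on the border
    for r in range(h):
        for c in range(w):
            on_border = r in (0, h - 1) or c in (0, w - 1)
            if on_border and not mask[r, c]:
                outside[r, c] = True
                queue.append((r, c))
    # Breadth-first flood fill through 4-connected background pixels
    while queue:
        r, c = queue.popleft()
        for nr, nc in ((r - 1, c), (r + 1, c), (r, c - 1), (r, c + 1)):
            if 0 <= nr < h and 0 <= nc < w and not mask[nr, nc] and not outside[nr, nc]:
                outside[nr, nc] = True
                queue.append((nr, nc))
    return ~outside  # foreground plus every enclosed hole


ring = np.zeros((5, 5), dtype=bool)
ring[1:4, 1:4] = True
ring[2, 2] = False  # a one-pixel hole in the middle
print(fill_holes(ring).astype(int))
```

For a bone cross-section this means an enclosed dark cavity inside the segmented region is absorbed into the mask, while background that touches the image edge is left alone.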

## 📊 Step 8: Visualize the Results

In [None]:
%matplotlib inline

# Enhanced visualization
fig, axes = plt.subplots(2, 3, figsize=(18, 12))

# Row 1: Original results
axes[0,0].imshow(img)
axes[0,0].set_title("Original Image")
axes[0,0].axis('off')

axes[0,1].imshow(img)
axes[0,1].imshow(best_mask, alpha=0.5, cmap='Reds')
# Show selected points
for point in selector_obj.positive_points:
    axes[0,1].plot(point[0], point[1], 'g*', markersize=15, markeredgewidth=2)
for point in selector_obj.negative_points:
    axes[0,1].plot(point[0], point[1], 'rx', markersize=12, markeredgewidth=3)
axes[0,1].set_title("SAM Result with Points")
axes[0,1].axis('off')

axes[0,2].imshow(best_mask, cmap='gray')
axes[0,2].set_title("Raw Mask")
axes[0,2].axis('off')

# Row 2: Enhanced results
axes[1,0].imshow(img_enhanced)
axes[1,0].set_title("Contrast-Enhanced Image")
axes[1,0].axis('off')

axes[1,1].imshow(img)
axes[1,1].imshow(refined_mask, alpha=0.5, cmap='Blues')
# Show selected points on refined view too
for point in selector_obj.positive_points:
    axes[1,1].plot(point[0], point[1], 'g*', markersize=15, markeredgewidth=2)
for point in selector_obj.negative_points:
    axes[1,1].plot(point[0], point[1], 'rx', markersize=12, markeredgewidth=3)
axes[1,1].set_title("Refined Segmentation")
axes[1,1].axis('off')

axes[1,2].imshow(refined_mask, cmap='gray')
axes[1,2].set_title("Refined Mask")
axes[1,2].axis('off')

plt.tight_layout()
plt.show()

# Results summary
print(f"\n{'='*50}")
print(f"🎯 Segmentation completed on {device}")
print(f"🟢 Positive points: {len(selector_obj.positive_points)}")
print(f"🔴 Negative points: {len(selector_obj.negative_points)}")
print(f"📏 Mask area: {np.sum(refined_mask)} pixels")
print(f"⭐ Best mask score: {scores[np.argmax(scores)]:.4f}")
print(f"🎭 Total masks generated: {len(masks)}")
print(f"{'='*50}")
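
The summary reports only the mask area. One way to quantify how much post-processing changed the result is to compute overlap metrics between two binary masks, for example the raw and the refined mask. The sketch below assumes a hypothetical helper `overlap_metrics`; note that comparing the two masks to each other measures the effect of refinement only and is not a validation against ground truth.

```python
import numpy as np


def overlap_metrics(mask_a, mask_b):
    """Return (IoU, Dice) for two binary masks of the same shape."""
    a = np.asarray(mask_a, dtype=bool)
    b = np.asarray(mask_b, dtype=bool)
    intersection = np.logical_and(a, b).sum()
    union = np.logical_or(a, b).sum()
    total = a.sum() + b.sum()
    # Two empty masks are treated as identical
    iou = intersection / union if union else 1.0
    dice = 2 * intersection / total if total else 1.0
    return float(iou), float(dice)


raw = np.array([[1, 1], [0, 0]], dtype=bool)
refined = np.array([[1, 0], [0, 0]], dtype=bool)
iou, dice = overlap_metrics(raw, refined)
print(f"IoU: {iou:.3f} | Dice: {dice:.3f}")
```

In this notebook the call would be `overlap_metrics(best_mask, refined_mask)`; the same function could score a prediction against a manual annotation if one became available.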

## 💾 Step 9: Save the Results

In [None]:
# Create output directory
os.makedirs('results', exist_ok=True)

# Save refined mask
refined_mask_pil = Image.fromarray((refined_mask * 255).astype(np.uint8))
refined_mask_pil.save("results/segmentation_mask.png")

# Save overlay
overlay = img.copy()
overlay[refined_mask] = (overlay[refined_mask] * 0.5 + np.array([0, 0, 255]) * 0.5).astype(np.uint8)
overlay_pil = Image.fromarray(overlay)
overlay_pil.save("results/segmentation_overlay.png")

print("💾 Results saved:")
print("   - results/segmentation_mask.png")
print("   - results/segmentation_overlay.png")

# Download results (Colab only). `files` is imported here because the earlier
# import only runs when the example image download fails.
from google.colab import files

print("\n📥 Downloading results:")
files.download("results/segmentation_mask.png")
files.download("results/segmentation_overlay.png")

---

## 📝 Final Notes

### Model details:
- **Model used**: SAM ViT-H (632M parameters)
- **Training dataset**: SA-1B (11M natural images)
- **Architecture**: Vision Transformer image encoder + prompt encoder + mask decoder

### Advantages of this approach:
1. ✅ **Minimal interaction**: one or two points are often enough
2. ✅ **Interactive speed**: the image embedding is computed once in `predictor.set_image`, so each click only re-runs the lightweight prompt encoder and mask decoder
3. ✅ **Robustness**: works across different image orientations
4. ✅ **Reproducible**: the code is fully documented

### Limitations:
1. ⚠️ DICOM images must be converted to PNG manually beforehand
2. ⚠️ Works on 2D slices only (does not exploit the volumetric nature of the data)
3. ⚠️ No ground truth is available for quantitative validation
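
Regarding the first limitation, a common ingredient of DICOM-to-PNG conversion is intensity windowing: raw values are clipped to a range of interest and rescaled to 8 bits. The sketch below is a simplified linear window under stated assumptions: the pixel array has already been read by some DICOM reader (for instance pydicom, which this notebook does not use), the `center` and `width` values are placeholders, and the formula is a simplification rather than the exact DICOM windowing definition.

```python
import numpy as np


def apply_window(pixels, center, width):
    """Map raw intensities to uint8 using a simple linear intensity window."""
    if width <= 0:
        raise ValueError("width must be positive")
    pixels = np.asarray(pixels, dtype=np.float64)
    low = center - width / 2.0
    high = center + width / 2.0
    # Clip to the window, then rescale [low, high] to [0, 255]
    scaled = (np.clip(pixels, low, high) - low) / (high - low)
    return np.rint(scaled * 255).astype(np.uint8)


# Placeholder intensities and window values, for illustration only
raw = np.array([[-100, 0], [1000, 2000]])
print(apply_window(raw, center=500, width=1000))
```

The resulting 2D `uint8` array can be stacked into three channels and passed through the same pipeline as the PNG example, since the notebook converts every input to RGB before calling `predictor.set_image`.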

---

**Developed by:** Thomas Molina Molina  
**Course:** Computational Geometry - Universidad Nacional de Colombia  
**Repositorio:** [https://github.com/ThomasMolina19/interactive-medsam](https://github.com/ThomasMolina19/interactive-medsam)