# 8. Prueba de Modelos SAM 3 (REAL) - FIXED

**Objetivo:** Evaluar el nuevo modelo **SAM 3** (lanzado en noviembre de 2025) para determinar si sus capacidades avanzadas de segmentación mejoran el refinamiento de las detecciones de corrosión generadas por YOLO.

## Novedades de SAM 3
- **Promptable Concept Segmentation:** Capacidad de entender conceptos visuales.
- **Mejor Segmentación Interactiva:** Supera a SAM 2 en precisión.
- **Arquitectura Unificada:** Mejor manejo de contexto.

## Metodología
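1. **Modelo:** Se intenta cargar primero `sam3_b.pt` (Base), luego `sam3_t.pt` (Tiny) y por último `sam3_l.pt` (Large); los modelos más ligeros van primero para evitar errores de memoria. Si ningún SAM 3 carga, se usa `sam2.1_b.pt` como respaldo.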
2. **Correcciones:** Incluye todas las correcciones previas (umbral de confianza YOLO `conf=0.25`, caché de predicciones, etc.).
3. **Comparación:** Se evaluará contra YOLO Base y SAM 2.1

In [1]:
# Install dependencies (make sure ultralytics is recent enough to support SAM 3).
# NOTE: loading SAM 3 needs the `ftfy` package. In the recorded run every SAM 3
# load failed with "No module named 'ftfy'", most likely because this cell
# aborted with "OSError: [Errno 28] No space left on device" before installing
# it. Free disk space, re-run this cell, then restart the kernel.
!pip install -U -q ultralytics pandas matplotlib opencv-python seaborn scikit-learn pillow tqdm ftfy
!pip install -q 'git+https://github.com/facebookresearch/segment-anything-2.git'  # Dependencias base

ERROR: Exception:
Traceback (most recent call last):
  File "C:\Users\lbuln\anaconda3\Lib\site-packages\pip\_vendor\urllib3\response.py", line 438, in _error_catcher
    yield
  File "C:\Users\lbuln\anaconda3\Lib\site-packages\pip\_vendor\urllib3\response.py", line 561, in read
    data = self._fp_read(amt) if not fp_closed else b""
           ~~~~~~~~~~~~~^^^^^
  File "C:\Users\lbuln\anaconda3\Lib\site-packages\pip\_vendor\urllib3\response.py", line 527, in _fp_read
    return self._fp.read(amt) if amt is not None else self._fp.read()
           ~~~~~~~~~~~~~^^^^^
  File "C:\Users\lbuln\anaconda3\Lib\site-packages\pip\_vendor\cachecontrol\filewrapper.py", line 102, in read
    self.__buf.write(data)
    ~~~~~~~~~~~~~~~~^^^^^^
  File "C:\Users\lbuln\anaconda3\Lib\tempfile.py", line 499, in func_wrapper
    return func(*args, **kwargs)
OSError: [Errno 28] No space left on device

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  Fi

In [2]:
import os
import cv2
import torch
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from ultralytics import YOLO, SAM
from ultralytics.utils.metrics import box_iou
from pathlib import Path
from tqdm import tqdm
from scipy.ndimage import binary_fill_holes, binary_erosion, binary_dilation
from sklearn.metrics import precision_recall_curve, average_precision_score

# Use the GPU when PyTorch reports CUDA support; otherwise run on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Dispositivo: {device}")

if device == 'cuda':
    # Return unoccupied cached CUDA memory before any model is loaded.
    torch.cuda.empty_cache()

# Global plotting defaults: seaborn grid style plus matplotlib rcParams.
sns.set_style('whitegrid')
plt.rcParams.update({'figure.dpi': 100, 'font.size': 10})

Dispositivo: cpu


## 1. Sistema de Caché (Independiente)

In [3]:
class PredictionCache:
    """Pickle-backed cache of YOLO predictions, keyed by image path.

    Avoids re-running YOLO on images processed in an earlier run. Keys are
    ``str(image_path)`` only, so the cache does NOT notice when the YOLO
    checkpoint or its inference settings change; call ``clear()`` in that case.

    Security note: the cache file is read with ``pickle``, which can execute
    arbitrary code. Only point ``cache_file`` at files this notebook wrote.
    """

    def __init__(self, cache_file='yolo_predictions_cache_sam3.pkl'):
        self.cache_file = cache_file
        self.cache = self._load_cache()

    def _load_cache(self):
        """Return the persisted cache dict, or ``{}`` if missing or unreadable."""
        if not os.path.exists(self.cache_file):
            return {}
        try:
            with open(self.cache_file, 'rb') as f:
                return pickle.load(f)
        except (EOFError, pickle.UnpicklingError) as e:
            # A truncated/corrupt file (e.g. a save interrupted by a full disk)
            # should only cost a recomputation, not crash the notebook.
            print(f"⚠️ Caché corrupta en {self.cache_file} ({e}); se ignora.")
            return {}

    def save(self):
        """Persist the cache atomically.

        The data is written to ``<cache_file>.tmp`` and then renamed over the
        real file, so a failed write never leaves a half-written cache behind.
        """
        tmp_file = f"{self.cache_file}.tmp"
        try:
            with open(tmp_file, 'wb') as f:
                pickle.dump(self.cache, f)
            os.replace(tmp_file, self.cache_file)
        finally:
            # Only still present if dump/replace failed.
            if os.path.exists(tmp_file):
                os.remove(tmp_file)

    def get(self, image_path):
        """Return the cached predictions for ``image_path`` or ``None``."""
        return self.cache.get(str(image_path))

    def set(self, image_path, predictions):
        """Store ``predictions`` for ``image_path`` (in memory until ``save()``)."""
        self.cache[str(image_path)] = predictions

    def clear(self):
        """Drop all cached entries and delete the cache file if it exists."""
        self.cache = {}
        if os.path.exists(self.cache_file):
            os.remove(self.cache_file)

pred_cache = PredictionCache()

## 2. Carga de Modelos (SAM 3 REAL)

In [4]:
# Rutas
TEST_IMAGES_DIR = Path('./dataset_yolo/images/test')
TEST_LABELS_DIR = Path('./dataset_yolo/labels/test')
RESULTS_DIR = Path('./resultados_sam3_real')
RESULTS_DIR.mkdir(exist_ok=True)

# --- YOLO (base detector) ----------------------------------------------------
# Preferred checkpoints first, then any other "*-m.pt" checkpoint. glob order
# is filesystem-dependent, so it is sorted to keep the choice reproducible.
import glob
yolo_candidates = [
    './modelos_entrenados/modelo-acuatico-m.pt',
    './modelos_entrenados/modelo-mixto-m.pt',
    *sorted(glob.glob('./modelos_entrenados/*-m.pt'))
]
MODELO_YOLO_PATH = next((p for p in yolo_candidates if os.path.exists(p)), None)
if not MODELO_YOLO_PATH:
    raise FileNotFoundError("No se encontró modelo YOLO Medium")

print(f"Cargando YOLO: {MODELO_YOLO_PATH}...")
model_yolo = YOLO(MODELO_YOLO_PATH)
print("✓ YOLO cargado")

# --- SAM 3, with SAM 2.1 as fallback -------------------------------------------
print("\nIntentando cargar SAM 3...")
sam_models = {}  # NOTE(review): not used in the visible cells
# Lighter checkpoints first to reduce the chance of a MemoryError.
# NOTE(review): confirm these checkpoint names exist in the installed ultralytics.
sam_candidates = [
    'sam3_b.pt',   # SAM 3 Base (Recomendado por memoria)
    'sam3_t.pt',   # SAM 3 Tiny (Fallback ligero)
    'sam3_l.pt',   # SAM 3 Large (Solo si hay mucha RAM/VRAM)
]

model_sam = None
sam_name = ""
missing_modules = set()

for weight in sam_candidates:
    try:
        print(f"Intentando descargar/cargar {weight}...")
        if device == 'cuda':
            torch.cuda.empty_cache()
        model_sam = SAM(weight)
        sam_name = weight.replace('.pt', '')
        print(f"✓ ÉXITO: {sam_name.upper()} cargado correctamente")
        break
    except MemoryError:
        print(f"  - Falló {weight}: Memoria insuficiente (MemoryError)")
    except ModuleNotFoundError as e:
        # A missing Python package is not a memory problem; remember it so the
        # user gets an actionable hint instead of only the truncated message.
        missing_modules.add(e.name or str(e))
        print(f"  - Falló {weight}: {str(e)[:100]}...")
    except Exception as e:
        print(f"  - Falló {weight}: {str(e)[:100]}...")

if missing_modules:
    print(f"  → Faltan dependencias para SAM 3: {', '.join(sorted(missing_modules))}. "
          "Instálalas con pip y reinicia el kernel.")

if model_sam is None:
    print("\n⚠️ ADVERTENCIA: No se pudo cargar SAM 3. Intentando SAM 2.1 como fallback...")
    if device == 'cuda':
        torch.cuda.empty_cache()
    try:
        model_sam = SAM('sam2.1_b.pt') # Usar Base para evitar MemoryError
    except Exception as e:
        # Chain the real cause: the failure is not necessarily memory-related
        # (download error, missing dependency, ...). Unlike a bare `except:`,
        # this also lets KeyboardInterrupt propagate.
        raise RuntimeError("No se pudo cargar ningún modelo SAM. Intenta reiniciar el kernel para liberar memoria.") from e
    sam_name = 'sam2.1_b'
    print(f"✓ Fallback: {sam_name.upper()} cargado")

print(f"\n→ Modelo activo para pruebas: {sam_name.upper()}")

Cargando YOLO: ./modelos_entrenados/modelo-acuatico-m.pt...
✓ YOLO cargado

Intentando cargar SAM 3...
Intentando descargar/cargar sam3_b.pt...
  - Falló sam3_b.pt: No module named 'ftfy'...
Intentando descargar/cargar sam3_t.pt...
  - Falló sam3_t.pt: No module named 'ftfy'...
Intentando descargar/cargar sam3_l.pt...
  - Falló sam3_l.pt: No module named 'ftfy'...

⚠️ ADVERTENCIA: No se pudo cargar SAM 3. Intentando SAM 2.1 como fallback...


RuntimeError: No se pudo cargar ningún modelo SAM. Intenta reiniciar el kernel para liberar memoria.

## 3. Funciones de Utilidad (Optimizadas)

In [None]:
def read_yolo_labels(label_path, img_width, img_height):
    """Read a YOLO detection label file as absolute pixel boxes.

    Each non-empty line must be ``class x_center y_center width height``, with
    coordinates normalized by the image size (they are scaled by
    ``img_width`` / ``img_height`` here). Blank lines are ignored.

    Args:
        label_path: path (str or Path) to the ``.txt`` label file.
        img_width: image width in pixels.
        img_height: image height in pixels.

    Returns:
        np.ndarray with rows ``[class, x1, y1, x2, y2]`` (pixels), or an empty
        ``np.array([])`` when the file is missing or contains no boxes.

    Raises:
        ValueError: if a non-empty line does not have exactly 5 values.
    """
    if not os.path.exists(label_path):
        return np.array([])

    boxes = []
    with open(label_path, 'r') as f:
        for line_no, line in enumerate(f, start=1):
            parts = line.split()
            if not parts:
                continue  # blank / trailing lines are common in label files
            if len(parts) != 5:
                raise ValueError(
                    f"{label_path}:{line_no}: se esperaban 5 valores "
                    f"(class x_c y_c w h), se encontraron {len(parts)}"
                )
            cls, x_c, y_c, w, h = map(float, parts)
            x1 = (x_c - w/2) * img_width
            y1 = (y_c - h/2) * img_height
            x2 = (x_c + w/2) * img_width
            y2 = (y_c + h/2) * img_height
            boxes.append([int(cls), x1, y1, x2, y2])

    return np.array(boxes) if boxes else np.array([])

def mask_to_box(mask):
    """Return the tight bounding box of a mask's foreground pixels.

    Foreground is every element ``> 0``. The result is
    ``[x_min, y_min, x_max, y_max]`` using inclusive pixel indices, or ``None``
    when the mask has no foreground at all.
    """
    rows, cols = np.nonzero(mask > 0)
    if cols.size == 0:
        return None
    return [cols.min(), rows.min(), cols.max(), rows.max()]

def postprocess_mask(mask, min_area=50, dilation_iterations=1):
    """Fill holes in a mask and grow it slightly.

    Steps: ``binary_fill_holes`` closes interior gaps, then ``binary_dilation``
    (SciPy's default structuring element) runs ``dilation_iterations`` times.
    No erosion/opening is applied.

    Args:
        mask: 2-D array; non-zero elements are foreground.
        min_area: if the processed mask has fewer foreground pixels than this,
            the processing is discarded and the (binarized) input is returned.
        dilation_iterations: number of dilation passes. The default of 1 is the
            original behaviour; values < 1 skip dilation (SciPy would instead
            dilate until convergence).

    Returns:
        ``np.uint8`` array of 0/1 values with the same shape as ``mask``.
    """
    filled = binary_fill_holes(mask)
    if dilation_iterations > 0:
        dilated = binary_dilation(filled, iterations=dilation_iterations)
    else:
        dilated = filled
    if np.count_nonzero(dilated) < min_area:
        # Too small to trust the processed result: keep the raw mask, but in
        # the same 0/1 uint8 form as the processed branch.
        return (np.asarray(mask) > 0).astype(np.uint8)
    return dilated.astype(np.uint8)

## 4. Estrategias de Refinamiento (SAM 3)

In [None]:
def _refine_with_sam(img, boxes, sam_model, mask_transform=None):
    """Prompt SAM with ``boxes`` and tighten each box to its mask's extent.

    Args:
        img: image passed unchanged to ``sam_model``.
        boxes: non-empty array with rows ``[x1, y1, x2, y2, conf, cls]`` (the
            layout of YOLO ``boxes.data`` used in this notebook).
        sam_model: Ultralytics SAM model, called with ``bboxes=``.
        mask_transform: optional function applied to each mask before its box
            is extracted.

    Returns:
        A list with exactly one row per input box, in input order. A row is
        replaced by ``[mask_x1, mask_y1, mask_x2, mask_y2, conf, cls]`` when
        SAM produced a non-empty mask for it; otherwise the YOLO row is kept.
    """
    # Start from the YOLO boxes so that prompts without a SAM mask are kept
    # instead of silently dropped (which would lower recall).
    refined = [row.tolist() for row in boxes]
    results = sam_model(img, bboxes=boxes[:, :4], verbose=False)[0]
    if results.masks is None:
        return refined

    # NOTE(review): assumes SAM returns masks in the same order as the prompts.
    masks = results.masks.data.cpu().numpy()
    for i, mask in enumerate(masks[:len(boxes)]):
        if mask_transform is not None:
            mask = mask_transform(mask)
        box = mask_to_box(mask)
        if box is not None:
            refined[i] = box + [boxes[i, 4], boxes[i, 5]]
    return refined


def refine_basic(img, yolo_boxes, sam_model):
    """Strategy 1: refine every YOLO box with a SAM box prompt.

    Returns ``yolo_boxes`` unchanged when it is empty; otherwise an array with
    one row per YOLO box (mask box when available, YOLO box otherwise).
    """
    if len(yolo_boxes) == 0:
        return yolo_boxes
    return np.array(_refine_with_sam(img, yolo_boxes, sam_model))


def refine_adaptive(img, yolo_boxes, sam_model, conf_threshold=0.25):
    """Strategy 2: refine only boxes with confidence >= ``conf_threshold``.

    Lower-confidence boxes pass through unchanged. Output rows are the
    (refined) high-confidence boxes first, then the low-confidence ones.
    """
    if len(yolo_boxes) == 0:
        return yolo_boxes

    is_high = yolo_boxes[:, 4] >= conf_threshold
    refined = []
    if np.any(is_high):
        # The helper keeps one row per prompt, so high-confidence boxes survive
        # even when SAM returns no masks at all.
        refined.extend(_refine_with_sam(img, yolo_boxes[is_high], sam_model))
    refined.extend(yolo_boxes[~is_high].tolist())
    return np.array(refined)


def refine_postprocess(img, yolo_boxes, sam_model):
    """Strategy 3: like ``refine_basic``, but each mask goes through
    ``postprocess_mask`` (hole filling + dilation) before box extraction."""
    if len(yolo_boxes) == 0:
        return yolo_boxes
    return np.array(_refine_with_sam(img, yolo_boxes, sam_model,
                                     mask_transform=postprocess_mask))

## 5. Evaluación

In [None]:
def calculate_batch_stats(predictions, targets, iou_threshold=0.5):
    """Greedily match one image's predictions to its ground truth.

    Predictions are visited from highest to lowest confidence. Each one claims
    the unclaimed ground-truth box of the same class with the highest IoU; it
    is a true positive when that IoU reaches ``iou_threshold``.

    Args:
        predictions: array with rows ``[x1, y1, x2, y2, conf, cls]``.
        targets: array with rows ``[cls, x1, y1, x2, y2]`` (may be empty).
        iou_threshold: minimum IoU for a match.

    Returns:
        List of ``[is_tp, conf, cls, iou]`` rows, one per prediction (``iou``
        is 0.0 for false positives).
    """
    if len(predictions) == 0:
        return []
    if len(targets) == 0:
        return [[0, float(p[4]), float(p[5]), 0.0] for p in predictions]

    iou_matrix = box_iou(
        torch.tensor(predictions[:, :4], dtype=torch.float32),
        torch.tensor(targets[:, 1:], dtype=torch.float32),
    )

    claimed = set()
    rows = []
    for pred_idx in np.argsort(-predictions[:, 4]):
        pred = predictions[pred_idx]
        best_iou, best_gt = 0, -1
        for gt_idx, iou in enumerate(iou_matrix[pred_idx]):
            # Skip ground truth already matched or of another class.
            if gt_idx in claimed or targets[gt_idx, 0] != pred[5]:
                continue
            if iou > best_iou:
                best_iou, best_gt = iou, gt_idx

        if best_iou >= iou_threshold:
            claimed.add(best_gt)
            rows.append([1, pred[4], pred[5], float(best_iou)])
        else:
            rows.append([0, pred[4], pred[5], 0.0])

    return rows

def compute_pr_metrics(stats, total_gt):
    """Summarize matched detections into precision/recall/F1/AP/IoU.

    Args:
        stats: rows ``[is_tp, confidence, class_id, iou]`` pooled over all
            images (as produced by ``calculate_batch_stats``).
        total_gt: number of ground-truth boxes over all images.

    Returns:
        dict with:
        - ``precision``, ``recall``, ``f1``: values at the confidence cut-off
          that maximizes F1;
        - ``mAP``: 11-point interpolated AP of the PR curve built from the
          pooled list of all classes (not averaged per class);
        - ``avg_iou``: mean IoU of the true positives.
        All values are 0 when ``stats`` is empty.
    """
    if len(stats) == 0:
        return {'precision': 0, 'recall': 0, 'f1': 0, 'mAP': 0, 'avg_iou': 0}

    table = np.array(stats)
    table = table[np.argsort(-table[:, 1])]  # highest confidence first
    hits = table[:, 0]

    cum_tp = np.cumsum(hits)
    cum_fp = np.cumsum(1 - hits)
    precision = cum_tp / (cum_tp + cum_fp + 1e-16)
    recall = cum_tp / (total_gt + 1e-16)
    f1 = 2 * (precision * recall) / (precision + recall + 1e-16)

    # 11-point interpolation: at each recall level, take the best precision
    # reached at that recall or beyond (0 if never reached).
    interpolated = []
    for level in np.linspace(0, 1, 11):
        reachable = precision[recall >= level]
        interpolated.append(reachable.max() if reachable.size else 0)
    mAP = np.mean(interpolated)

    best = np.argmax(f1)
    tp_mask = hits == 1
    avg_iou = np.mean(table[tp_mask, 3]) if tp_mask.any() else 0

    return {
        'precision': precision[best],
        'recall': recall[best],
        'f1': f1[best],
        'mAP': mAP,
        'avg_iou': avg_iou,
    }

In [None]:
# Ejecutar Evaluación
# Runs YOLO (with a persistent prediction cache) and the three SAM refinement
# strategies on every test image, accumulating per-prediction match stats.
image_files = sorted(list(TEST_IMAGES_DIR.glob('*.jpg')) + list(TEST_IMAGES_DIR.glob('*.png')))
stats_collections = {'yolo': [], 'sam_basic': [], 'sam_adaptive': [], 'sam_postprocess': []}
total_gt = 0
skipped_images = []

# SAM refinement strategies, keyed like stats_collections (same call order).
sam_strategies = {
    'sam_basic': refine_basic,
    'sam_adaptive': refine_adaptive,
    'sam_postprocess': refine_postprocess,
}

print(f"Evaluando {len(image_files)} imágenes con {sam_name.upper()}...")

try:
    for img_path in tqdm(image_files):
        img = cv2.imread(str(img_path))
        if img is None:
            skipped_images.append(img_path)
            continue

        h, w = img.shape[:2]
        label_path = TEST_LABELS_DIR / (img_path.stem + '.txt')
        gt_boxes = read_yolo_labels(label_path, w, h)
        total_gt += len(gt_boxes)

        # YOLO (con conf=0.25); cached by image path across runs.
        yolo_boxes = pred_cache.get(img_path)
        if yolo_boxes is None:
            results = model_yolo(img, conf=0.25, iou=0.45, verbose=False)[0]
            yolo_boxes = results.boxes.data.cpu().numpy()
            pred_cache.set(img_path, yolo_boxes)

        stats_collections['yolo'].extend(calculate_batch_stats(yolo_boxes, gt_boxes))

        # SAM Strategies
        for key, refine in sam_strategies.items():
            refined = refine(img, yolo_boxes, model_sam)
            stats_collections[key].extend(calculate_batch_stats(refined, gt_boxes))
finally:
    # Persist YOLO predictions even if the loop is interrupted (e.g. a memory
    # error during SAM inference) so the next run can reuse them.
    pred_cache.save()

if skipped_images:
    print(f"⚠️ {len(skipped_images)} imágenes no se pudieron leer y se omitieron")
print("✓ Evaluación completada")

In [None]:
# Resultados
results = {name: compute_pr_metrics(stats, total_gt) for name, stats in stats_collections.items()}

# Label the SAM rows after the model that was actually loaded: the loading cell
# falls back to SAM 2.1 when no SAM 3 checkpoint can be loaded.
is_sam3 = sam_name.startswith('sam3')
sam_label = 'SAM 3' if is_sam3 else sam_name.upper()
strategy_labels = {
    'yolo': 'YOLO Base',
    'sam_basic': f'{sam_label} Básico',
    'sam_adaptive': f'{sam_label} Adaptativo',
    'sam_postprocess': f'{sam_label} Post-Proc.',
}
# NOTE: compute_pr_metrics' 'mAP' is an 11-point AP over all classes pooled
# together, not a per-class mean.
metric_columns = {'F1-Score': 'f1', 'mAP@0.5': 'mAP', 'IoU Prom.': 'avg_iou'}

df_results = pd.DataFrame({
    'Estrategia': list(strategy_labels.values()),
    **{
        column: [results[key][metric] for key in strategy_labels]
        for column, metric in metric_columns.items()
    },
})

header = (f"RESULTADOS SAM 3 REAL ({sam_name.upper()})" if is_sam3
          else f"RESULTADOS {sam_label} (respaldo: SAM 3 no disponible)")
print("\n" + "="*80)
print(header)
print("="*80)
print(df_results.to_string(index=False, float_format=lambda x: f'{x:.4f}'))
print("="*80)

df_results.to_csv(RESULTS_DIR / 'metricas_sam3.csv', index=False)