In [16]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import cv2
import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from sklearn.model_selection import train_test_split
import urllib.request
import tarfile
from pathlib import Path

sys.path.append('../src')
from gcn_models import GCN_Segmenter, GraphSAGE_Segmenter
from grabcut_ops import GrabCutRefiner
from graph_utils import build_superpixel_graph, iou_pixel
from superpixel import SuperpixelExtractor

In [18]:


DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {DEVICE}")

DATA_DIR = Path("data/oxford_pets")
def download_oxford_pets(data_dir):
    """Descarga el dataset Oxford-IIIT Pet"""
    images_url = "https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz"
    masks_url = "https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz"
    
    images_tar = data_dir / "images.tar.gz"
    masks_tar = data_dir / "annotations.tar.gz"
    
    if not (data_dir / "images").exists():
        print("Downloading images...")
        urllib.request.urlretrieve(images_url, images_tar)
        with tarfile.open(images_tar) as tar:
            tar.extractall(data_dir)
        print("Images downloaded!")
    
    if not (data_dir / "annotations").exists():
        print("Downloading annotations...")
        urllib.request.urlretrieve(masks_url, masks_tar)
        with tarfile.open(masks_tar) as tar:
            tar.extractall(data_dir)
        print("Annotations downloaded!")

download_oxford_pets(DATA_DIR)

def download_oxford_pets(data_dir):
    """Descarga el dataset Oxford-IIIT Pet"""
    # Asegurar que el directorio existe
    data_dir = Path(data_dir)
    data_dir.mkdir(parents=True, exist_ok=True)
    
    images_url = "https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz"
    masks_url = "https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz"
    
    images_tar = data_dir / "images.tar.gz"
    masks_tar = data_dir / "annotations.tar.gz"
    
    if not (data_dir / "images").exists():
        print("Downloading images...")
        try:
            urllib.request.urlretrieve(images_url, str(images_tar))
            print("Extracting images...")
            with tarfile.open(images_tar) as tar:
                tar.extractall(data_dir)
            images_tar.unlink()  # Eliminar tar después de extraer
            print("Images downloaded and extracted!")
        except Exception as e:
            print(f"Error downloading images: {e}")
    else:
        print("Images already downloaded!")
    
    if not (data_dir / "annotations").exists():
        print("Downloading annotations...")
        try:
            urllib.request.urlretrieve(masks_url, str(masks_tar))
            print("Extracting annotations...")
            with tarfile.open(masks_tar) as tar:
                tar.extractall(data_dir)
            masks_tar.unlink()  # Eliminar tar después de extraer
            print("Annotations downloaded and extracted!")
        except Exception as e:
            print(f"Error downloading annotations: {e}")
    else:
        print("Annotations already downloaded!")

# Descomentar para descargar
download_oxford_pets(DATA_DIR)



Using device: cuda
Downloading images...


FileNotFoundError: [Errno 2] No such file or directory: 'data/oxford_pets/images.tar.gz'

In [None]:
class PetSegmentationDataset:
    def __init__(self, data_dir, image_size=(256, 256)):
        self.data_dir = Path(data_dir)
        self.image_size = image_size
        self.images_dir = self.data_dir / "images"
        self.masks_dir = self.data_dir / "annotations" / "trimaps"
        
        if self.images_dir.exists():
            self.image_files = sorted(list(self.images_dir.glob("*.jpg")))[:50]  
        else:
            self.image_files = []
            print("Warning: Images directory not found!")
    
    def __len__(self):
        return len(self.image_files)
    
    def __getitem__(self, idx):
        img_path = self.image_files[idx]
        mask_path = self.masks_dir / (img_path.stem + ".png")
        
        img = cv2.imread(str(img_path))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, self.image_size)
        
        if mask_path.exists():
            mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
            mask = cv2.resize(mask, self.image_size, interpolation=cv2.INTER_NEAREST)
            mask = (mask == 1).astype(np.uint8) 
        else:
            mask = None
        
        return img, mask, img_path.name

def image_to_graph(img_rgb, sp_extractor, n_segments=200):
    """
    Convierte imagen en grafo de superpixels
    """
    segments = sp_extractor.compute(img_rgb)
    
    features = sp_extractor.features(img_rgb, segments)
    
    G, sp, node_color, centroids_norm, counts = build_superpixel_graph(
        img_rgb, n_segments=n_segments
    )
    
    edge_index = []
    edge_weight = []
    for u, v, data in G.edges(data=True):
        edge_index.append([u, v])
        edge_index.append([v, u])
        edge_weight.append(data['weight'])
        edge_weight.append(data['weight'])
    
    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    edge_weight = torch.tensor(edge_weight, dtype=torch.float)
    
    x = torch.cat([
        torch.tensor(node_color, dtype=torch.float),
        torch.tensor(centroids_norm, dtype=torch.float),
        torch.tensor(counts, dtype=torch.float).unsqueeze(1)
    ], dim=1)
    
    data = Data(x=x, edge_index=edge_index, edge_attr=edge_weight)
    
    return data, segments, G

class HybridSegmenter:
    def __init__(self, model, sp_extractor, grabcut_refiner, device='cpu'):
        self.model = model.to(device)
        self.sp_extractor = sp_extractor
        self.grabcut = grabcut_refiner
        self.device = device
    
    def predict(self, img_rgb, threshold=0.5, use_grabcut=True):
        """
        Pipeline completo:
        1. GCN predice en superpixels
        2. GrabCut refina la máscara
        """
        self.model.eval()
        
        # Crear grafo
        graph_data, segments, G = image_to_graph(img_rgb, self.sp_extractor)
        graph_data = graph_data.to(self.device)
        
        # Predicción GCN
        with torch.no_grad():
            logits = self.model(graph_data)
            probs = F.softmax(logits, dim=1)[:, 1]  # Probabilidad foreground
        
        # Convertir a máscara de superpixels
        sp_mask = (probs.cpu().numpy() > threshold).astype(np.uint8)
        
        # Mapear a píxeles
        pixel_mask = sp_mask[segments]
        
        if use_grabcut:
            # Inicializar GrabCut desde superpixels
            fg_ids = np.where(sp_mask == 1)[0]
            bg_ids = np.where(sp_mask == 0)[0]
            
            init_mask = self.grabcut.init_from_superpixels(segments, fg_ids, bg_ids)
            
            # Aplicar GrabCut
            img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
            refined_mask = self.grabcut.run(img_bgr, init_mask)
            final_mask = self.grabcut.to_binary(refined_mask)
        else:
            final_mask = pixel_mask
        
        return {
            'gcn_probs': probs.cpu().numpy(),
            'sp_mask': sp_mask,
            'pixel_mask': pixel_mask,
            'final_mask': final_mask,
            'segments': segments
        }

# %% Función de entrenamiento
def train_gcn(model, train_data_list, epochs=20, lr=0.001, device='cpu'):
    """
    Entrena el modelo GCN
    """
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = torch.nn.CrossEntropyLoss()
    
    losses = []
    
    for epoch in range(epochs):
        model.train()
        epoch_loss = 0
        
        for data, y in train_data_list:
            data = data.to(device)
            y = y.to(device)
            
            optimizer.zero_grad()
            out = model(data)
            loss = criterion(out, y)
            loss.backward()
            optimizer.step()
            
            epoch_loss += loss.item()
        
        avg_loss = epoch_loss / len(train_data_list)
        losses.append(avg_loss)
        
        if (epoch + 1) % 5 == 0:
            print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")
    
    return losses

# %% Preparar datos de entrenamiento (ejemplo sintético)
def prepare_training_data(dataset, sp_extractor, num_samples=20):
    """
    Prepara datos para entrenar GCN usando máscaras ground truth
    """
    train_data_list = []
    
    for idx in range(min(num_samples, len(dataset))):
        img, mask, name = dataset[idx]
        
        if mask is None:
            continue
        
        # Crear grafo
        graph_data, segments, G = image_to_graph(img, sp_extractor)
        
        # Etiquetar superpixels según ground truth
        n_sp = segments.max() + 1
        sp_labels = np.zeros(n_sp, dtype=np.long)
        
        for k in range(n_sp):
            sp_pixels = (segments == k)
            fg_ratio = mask[sp_pixels].mean()
            sp_labels[k] = 1 if fg_ratio > 0.5 else 0
        
        y = torch.tensor(sp_labels, dtype=torch.long)
        
        train_data_list.append((graph_data, y))
    
    return train_data_list

# %% Visualización
def visualize_results(img, results, gt_mask=None):
    """
    Visualiza resultados del pipeline híbrido
    """
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    
    # Imagen original
    axes[0, 0].imshow(img)
    axes[0, 0].set_title("Original Image")
    axes[0, 0].axis('off')
    
    # Superpixels con predicción GCN
    segments = results['segments']
    sp_viz = results['sp_mask'][segments]
    axes[0, 1].imshow(img)
    axes[0, 1].imshow(sp_viz, alpha=0.5, cmap='jet')
    axes[0, 1].set_title("GCN Superpixel Prediction")
    axes[0, 1].axis('off')
    
    # Máscara píxel (GCN)
    axes[0, 2].imshow(results['pixel_mask'], cmap='gray')
    axes[0, 2].set_title("GCN Pixel Mask")
    axes[0, 2].axis('off')
    
    # Máscara final (GrabCut)
    axes[1, 0].imshow(results['final_mask'], cmap='gray')
    axes[1, 0].set_title("Final Mask (with GrabCut)")
    axes[1, 0].axis('off')
    
    # Overlay final
    axes[1, 1].imshow(img)
    axes[1, 1].imshow(results['final_mask'], alpha=0.5, cmap='Reds')
    axes[1, 1].set_title("Final Overlay")
    axes[1, 1].axis('off')
    
    # Ground truth (si existe)
    if gt_mask is not None:
        axes[1, 2].imshow(gt_mask, cmap='gray')
        axes[1, 2].set_title("Ground Truth")
        iou = iou_pixel(results['final_mask'], gt_mask)
        axes[1, 2].text(10, 30, f"IoU: {iou:.3f}", 
                       bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
    else:
        axes[1, 2].axis('off')
    
    axes[1, 2].axis('off')
    
    plt.tight_layout()
    plt.show()

# %% DEMO: Pipeline completo
def run_demo():
    """
    Ejecuta demo completo del pipeline híbrido
    """
    print("="*60)
    print("GCN-GrabCut Hybrid Segmentation Demo")
    print("="*60)
    
    # 1. Inicializar componentes
    sp_extractor = SuperpixelExtractor(num_segments=150, compactness=10)
    grabcut_refiner = GrabCutRefiner(iterations=5)
    
    # 2. Crear modelo
    model = GCN_Segmenter(
        in_channels=6,  # [L, a, b, y_norm, x_norm, size]
        hidden_channels=32,
        out_channels=2,
        dropout=0.2
    )
    
    print(f"\nModel: {model.__class__.__name__}")
    print(f"Parameters: {sum(p.numel() for p in model.parameters())}")
    
    # 3. Cargar dataset
    dataset = PetSegmentationDataset(DATA_DIR, image_size=(256, 256))
    print(f"\nDataset size: {len(dataset)} images")
    
    if len(dataset) == 0:
        print("\n⚠️  No images found! Please download the Oxford-IIIT Pet Dataset:")
        print("    Uncomment the 'download_oxford_pets(DATA_DIR)' line above")
        return
    
    # 4. Entrenar modelo (opcional, con datos sintéticos si no hay GT)
    print("\n--- Training GCN ---")
    train_data = prepare_training_data(dataset, sp_extractor, num_samples=10)
    
    if len(train_data) > 0:
        losses = train_gcn(model, train_data, epochs=20, lr=0.001, device=DEVICE)
        
        plt.figure(figsize=(8, 4))
        plt.plot(losses)
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title('Training Loss')
        plt.grid(True)
        plt.show()
    else:
        print("⚠️  No training data with masks available")
    
    # 5. Crear segmentador híbrido
    hybrid = HybridSegmenter(model, sp_extractor, grabcut_refiner, device=DEVICE)
    
    # 6. Probar en imágenes
    print("\n--- Testing on images ---")
    for idx in [0, 5, 10]:  # Probar 3 imágenes
        if idx >= len(dataset):
            break
        
        img, mask, name = dataset[idx]
        print(f"\nProcessing: {name}")
        
        # Predicción
        results = hybrid.predict(img, threshold=0.5, use_grabcut=True)
        
        # Visualizar
        visualize_results(img, results, gt_mask=mask)
        
        # Métricas
        if mask is not None:
            iou_gcn = iou_pixel(results['pixel_mask'], mask)
            iou_final = iou_pixel(results['final_mask'], mask)
            print(f"  IoU (GCN only): {iou_gcn:.3f}")
            print(f"  IoU (with GrabCut): {iou_final:.3f}")
            print(f"  Improvement: {(iou_final - iou_gcn):.3f}")

# %% Ejecutar demo
if __name__ == "__main__":
    run_demo()

# %% Experimento: Comparar GCN vs GraphSAGE
def compare_architectures():
    """
    Compara diferentes arquitecturas de GNN
    """
    sp_extractor = SuperpixelExtractor(num_segments=150, compactness=10)
    grabcut_refiner = GrabCutRefiner(iterations=5)
    dataset = PetSegmentationDataset(DATA_DIR, image_size=(256, 256))
    
    if len(dataset) == 0:
        print("No images available!")
        return
    
    models = {
        'GCN': GCN_Segmenter(in_channels=6, hidden_channels=32, out_channels=2),
        'GraphSAGE': GraphSAGE_Segmenter(in_channels=6, hidden_channels=32, out_channels=2)
    }
    
    results_comparison = {}
    
    for model_name, model in models.items():
        print(f"\n{'='*40}")
        print(f"Testing: {model_name}")
        print('='*40)
        
        # Entrenar
        train_data = prepare_training_data(dataset, sp_extractor, num_samples=10)
        if len(train_data) > 0:
            train_gcn(model, train_data, epochs=15, device=DEVICE)
        
        # Evaluar
        hybrid = HybridSegmenter(model, sp_extractor, grabcut_refiner, device=DEVICE)
        
        ious = []
        for idx in range(min(5, len(dataset))):
            img, mask, _ = dataset[idx]
            if mask is None:
                continue
            
            results = hybrid.predict(img, use_grabcut=True)
            iou = iou_pixel(results['final_mask'], mask)
            ious.append(iou)
        
        results_comparison[model_name] = {
            'mean_iou': np.mean(ious) if ious else 0,
            'std_iou': np.std(ious) if ious else 0
        }
        
        print(f"Mean IoU: {results_comparison[model_name]['mean_iou']:.3f} ± "
              f"{results_comparison[model_name]['std_iou']:.3f}")
    
    # Visualizar comparación
    plt.figure(figsize=(8, 5))
    names = list(results_comparison.keys())
    means = [results_comparison[n]['mean_iou'] for n in names]
    stds = [results_comparison[n]['std_iou'] for n in names]
    
    plt.bar(names, means, yerr=stds, capsize=5, alpha=0.7)
    plt.ylabel('Mean IoU')
    plt.title('Architecture Comparison')
    plt.ylim(0, 1)
    plt.grid(axis='y', alpha=0.3)
    plt.show()

# Descomentar para ejecutar comparación
# compare_architectures()