In [None]:
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset
from imageio import imread
from torchvision. transforms. functional import to_pil_image
import matplotlib.pyplot as plt

class CCPDataset(Dataset):
    """
    Dataset for multiclass semantic segmentation.

    Each row of ``df`` points to an image and a colour-coded mask. Mask colours
    are mapped to integer class ids through the class-dictionary CSV, where the
    row position of a colour is its class id.

    Args:
        df: DataFrame with columns 'image_path', 'mask_path' and optionally
            'coords' (a (top, left) pair used in 'eval' mode).
        patch_size: Side length of the square patch to extract.
        transforms: Either a single callable applied to the image only, or a
            list ``[joint_transform, image_transform]`` where:
              - joint_transform(image, mask) -> (image, mask) is applied to both
                (e.g. flips, rotations);
              - image_transform(image) is applied to the image only
                (e.g. normalisation).
            Either list entry may be None.
        mode: 'train' (random crop, reflect-padded if the image is smaller than
            the patch), 'eval' (crop at 'coords', else centre crop) or 'full'
            (whole image, no crop).
        class_dict_path: Path to a CSV with columns 'r', 'g', 'b' and optionally
            'class_name'.
    """

    def __init__(self, df, patch_size=224, transforms=None, mode='train', class_dict_path='clothes/class_dict.csv'):
        super().__init__()
        self.df = df.reset_index(drop=True)
        self.ps = patch_size
        self.transforms = transforms
        self.mode = mode

        class_df = pd.read_csv(class_dict_path)

        # Pack each (r, g, b) triplet into one 24-bit integer so a colour can be
        # looked up with a single dict access instead of per-channel compares.
        r = class_df['r'].to_numpy(dtype=np.uint32)
        g = class_df['g'].to_numpy(dtype=np.uint32)
        b = class_df['b'].to_numpy(dtype=np.uint32)
        keys = (r << 16) | (g << 8) | b

        self.color_to_class = dict(zip(keys.tolist(), range(len(keys))))
        # Class ids run over 0..len(keys)-1 (one per CSV row), so count rows
        # rather than distinct colours.
        self.num_classes = len(keys)

        if 'class_name' in class_df.columns:
            self.class_names = class_df['class_name'].tolist()
            if self.class_names and self.class_names[0] == 'null':
                self.class_names[0] = 'background'
        else:
            self.class_names = [f'class_{i}' for i in range(self.num_classes)]

    @staticmethod
    def _to_rgb_uint8(img):
        """Return ``img`` as an (H, W, 3) uint8 array (grey promoted, alpha dropped)."""
        img = np.asarray(img)
        if img.ndim == 2:
            img = img[..., None]
        if img.shape[2] == 1:
            img = np.repeat(img, 3, axis=2)
        return img[..., :3].astype(np.uint8, copy=False)

    @staticmethod
    def _pad_width(arr, pad_h, pad_w):
        """np.pad width spec that pads only the bottom/right of the first two axes."""
        return ((0, pad_h), (0, pad_w)) + ((0, 0),) * (arr.ndim - 2)

    def mask_rgb_to_ids(self, mask):
        """
        Convert a colour mask of shape (H, W, C>=3) to class ids.

        Only the first three channels are used (an alpha channel is ignored).
        Colours absent from the class dictionary map to class 0 (background);
        in 'train' mode a warning listing up to three of them is printed.

        Returns:
            np.ndarray of dtype int64 and shape (H, W).
        """
        rgb = mask[..., :3].astype(np.uint32)
        packed = ((rgb[..., 0] << 16) | (rgb[..., 1] << 8) | rgb[..., 2]).ravel()

        # Resolve each distinct colour once, then scatter back through the
        # inverse index; much cheaper than a per-pixel Python call.
        colours, inverse = np.unique(packed, return_inverse=True)
        lut = np.array([self.color_to_class.get(int(c), 0) for c in colours], dtype=np.int64)

        if self.mode == 'train':  # only warn in train to avoid flooding the logs
            unknown = [int(c) for c in colours if int(c) not in self.color_to_class]
            if unknown:
                unknown_colors = [(p >> 16, (p >> 8) & 0xFF, p & 0xFF) for p in unknown[:3]]
                print(f"Warning: Found {len(unknown)} unknown colors (showing first 3): {unknown_colors}")

        return lut[inverse].reshape(mask.shape[:2])

    def __random_crop__(self, img, mask):
        """
        Return a random ``self.ps`` x ``self.ps`` crop of ``img`` and ``mask``.

        If either spatial side is smaller than the patch, both arrays are first
        reflect-padded on the bottom/right. Works for (H, W) and (H, W, C) arrays.
        """
        H, W = img.shape[:2]
        pad_h = max(0, self.ps - H)
        pad_w = max(0, self.ps - W)
        if pad_h or pad_w:
            img = np.pad(img, self._pad_width(img, pad_h, pad_w), mode='reflect')
            mask = np.pad(mask, self._pad_width(mask, pad_h, pad_w), mode='reflect')
            H, W = img.shape[:2]

        top = np.random.randint(0, H - self.ps + 1)
        left = np.random.randint(0, W - self.ps + 1)

        img_patch = img[top:top + self.ps, left:left + self.ps]
        mask_patch = mask[top:top + self.ps, left:left + self.ps]
        return img_patch, mask_patch

    def __getitem__(self, idx):
        """Return ``(image_tensor, mask_tensor)`` for row ``idx`` (mask is int64 (H, W))."""
        r = self.df.iloc[idx]

        # Keep the image as 3-channel uint8: to_pil_image treats float input as
        # [0, 1] data, so a float32 copy of 0-255 pixels would be corrupted.
        x = self._to_rgb_uint8(imread(r.image_path))
        y = np.asarray(imread(r.mask_path))

        if self.mode == 'train':
            x, y = self.__random_crop__(x, y)
        elif self.mode == 'eval':
            coords = r['coords'] if 'coords' in r.index else None
            if np.ndim(coords) == 1:
                top, left = int(coords[0]), int(coords[1])
            else:
                # No usable coords (missing, None or NaN): take a centred crop.
                H, W = x.shape[:2]
                top = max(0, (H - self.ps) // 2)
                left = max(0, (W - self.ps) // 2)
            x = x[top:top + self.ps, left:left + self.ps]
            y = y[top:top + self.ps, left:left + self.ps]
        # mode == 'full': no crop, use the whole image.

        # Colour masks (RGB or RGBA) -> class ids; single-channel masks are
        # assumed to already hold ids.
        if y.ndim == 3 and y.shape[2] >= 3:
            y = self.mask_rgb_to_ids(y)
        else:
            y = (y[..., 0] if y.ndim == 3 else y).astype(np.int64)

        x = to_pil_image(x)
        y = to_pil_image(y.astype('uint8'), mode='L')

        if self.transforms:
            if isinstance(self.transforms, list):
                # Format: [joint_transform, image_transform]
                if self.transforms[0] is not None:
                    x, y = self.transforms[0](x, y)
                if self.transforms[1] is not None:
                    x = self.transforms[1](x)
            else:
                # Image-only transform.
                x = self.transforms(x)

        # Fallback when no transform produced a tensor: HWC uint8 -> CHW float in [0, 1].
        if not isinstance(x, torch.Tensor):
            x = torch.from_numpy(np.array(x)).permute(2, 0, 1).float() / 255.0

        y = torch.from_numpy(np.array(y, dtype=np.int64))

        return x, y

    def __len__(self):
        return self.df.shape[0]

    def __show_item__(self, x, y, denormalize=None):
        """
        Plot a sample as image, mask and overlay side by side.

        Args:
            x: Image tensor of shape (C, H, W).
            y: Mask tensor of shape (H, W) with class ids.
            denormalize: Optional callable applied to ``x`` before display
                (e.g. to undo mean/std normalisation).
        """
        f, ax = plt.subplots(1, 3, figsize=(15, 5))

        x_vis = denormalize(x) if denormalize is not None else x

        # imshow expects (H, W, C) floats in [0, 1]; rescale 0-255 data if needed.
        x_vis = x_vis.permute(1, 2, 0).cpu().numpy()
        if x_vis.max() > 1.0:
            x_vis = x_vis / 255.0
        x_vis = np.clip(x_vis, 0, 1)

        y_vis = y.cpu().numpy()
        mask_style = dict(cmap='tab20', vmin=0, vmax=self.num_classes - 1)

        ax[0].imshow(x_vis)
        ax[0].set_title('Image')
        ax[0].axis('off')

        ax[1].imshow(y_vis, **mask_style)
        ax[1].set_title('Mask')
        ax[1].axis('off')

        ax[2].imshow(x_vis)
        ax[2].imshow(y_vis, alpha=0.5, **mask_style)
        ax[2].set_title('Overlay')
        ax[2].axis('off')

        plt.tight_layout()
        plt.show()

In [None]:
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchvision import transforms as T
import torch

class DataModule(pl.LightningDataModule):
    """
    PyTorch Lightning DataModule for semantic segmentation.

    Builds one CCPDataset per split ('train', 'valid', 'test') from a single
    DataFrame and exposes the matching DataLoaders.

    Args:
        df: DataFrame with columns 'image_path', 'mask_path' and 'set'
            (values 'train', 'valid' or 'test').
        class_dict_path: Path to the class-dictionary CSV (one row per class).
        bs: Batch size for the train/validation loaders.
        ps: Patch size used for cropping.
        num_workers: Number of DataLoader worker processes.
    """

    # ImageNet statistics used to normalise images for the pretrained encoders.
    MEAN = [0.485, 0.456, 0.406]
    STD = [0.229, 0.224, 0.225]

    def __init__(self, df, class_dict_path='clothes/class_dict.csv', bs=16, ps=256, num_workers=8):
        super().__init__()
        self.df = df
        self.class_dict_path = class_dict_path
        self.bs = bs
        self.ps = ps
        self.num_workers = num_workers

        # One CSV row per class -> number of model output channels.
        self.num_classes = len(pd.read_csv(class_dict_path))

    def _make_dataset(self, split, transforms, mode):
        """Build a CCPDataset from the rows whose 'set' column equals ``split``."""
        return CCPDataset(
            df=self.df[self.df['set'] == split].reset_index(drop=True),
            patch_size=self.ps,
            transforms=transforms,
            mode=mode,
            class_dict_path=self.class_dict_path,
        )

    def setup(self, stage=None):
        """Create the train/valid/test datasets with their transforms."""
        normalize_transform = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=self.MEAN, std=self.STD),
        ])

        # Transforms are [joint_transform, image_transform]. No joint
        # augmentation yet: a callable (image, mask) -> (image, mask) can be
        # plugged into slot 0 of train_transforms.
        train_transforms = [None, normalize_transform]
        eval_transforms = [None, normalize_transform]

        self.train_ds = self._make_dataset('train', train_transforms, mode='train')
        # Validation deliberately also uses random crops ('train' mode), which
        # pads undersized images so batches stack. As a consequence validation
        # metrics vary from epoch to epoch.
        self.valid_ds = self._make_dataset('valid', eval_transforms, mode='train')
        # Whole images for test (hence batch size 1 in test_dataloader).
        self.test_ds = self._make_dataset('test', eval_transforms, mode='full')

    def _loader(self, dataset, batch_size, shuffle):
        """DataLoader with the worker/memory settings shared by all splits."""
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            # Pinned memory only speeds up host->GPU copies.
            pin_memory=torch.cuda.is_available(),
            persistent_workers=self.num_workers > 0,
        )

    def train_dataloader(self):
        return self._loader(self.train_ds, self.bs, shuffle=True)

    def val_dataloader(self):
        return self._loader(self.valid_ds, self.bs, shuffle=False)

    def test_dataloader(self):
        # Batch size 1: full images may differ in size and cannot be stacked.
        return self._loader(self.test_ds, 1, shuffle=False)

    @staticmethod
    def patch_origins(h, w, ps, overlap=2):
        """
        Top-left coordinates of overlapping ps x ps patches covering an h x w image.

        Args:
            h, w: Image height and width.
            ps: Patch size.
            overlap: Overlap factor; the stride is ``ps // overlap`` (overlap=2
                means 50% overlap). Values larger than ``ps`` fall back to stride 1.

        Returns:
            Integer array of shape (N, 2) with (row, col) origins. Every pixel
            is covered by at least one patch. A dimension no larger than ``ps``
            gets the single origin 0, so slicing yields a truncated patch there.
        """
        stride = max(1, ps // overlap)

        def axis_starts(n):
            # Regular grid plus a final start flush with the far edge, so the
            # last rows/columns are always covered.
            if n <= ps:
                return [0]
            starts = list(range(0, n - ps + 1, stride))
            if starts[-1] != n - ps:
                starts.append(n - ps)
            return starts

        return np.array([[x, y] for x in axis_starts(h) for y in axis_starts(w)])

In [None]:
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch. nn.functional as F
import segmentation_models_pytorch as smp
from collections import defaultdict

class SegmentationModel(pl.LightningModule):
    """
    PyTorch Lightning module for multiclass semantic segmentation.

    Args:
        num_classes: Number of segmentation classes (model output channels).
        encoder_name: Encoder name (e.g. 'resnet18', 'resnet50', 'efficientnet-b0').
        encoder_weights: Pretrained encoder weights (e.g. 'imagenet').
        learning_rate: Learning rate for the Adam optimiser.
        architecture: One of 'unet', 'unet++', 'deeplabv3+'.

    Raises:
        ValueError: if ``architecture`` is not supported.
    """

    # Supported architecture names -> segmentation_models_pytorch model class.
    ARCHITECTURES = {
        'unet': smp.Unet,
        'unet++': smp.UnetPlusPlus,
        'deeplabv3+': smp.DeepLabV3Plus,
    }

    def __init__(self, num_classes, encoder_name='resnet18', encoder_weights='imagenet',
                 learning_rate=1e-3, architecture='unet'):
        super().__init__()
        self.save_hyperparameters()

        self.num_classes = num_classes
        self.learning_rate = learning_rate

        model_cls = self.ARCHITECTURES.get(architecture)
        if model_cls is None:
            raise ValueError(f"Architecture {architecture} not supported")
        self.model = model_cls(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=3,
            classes=num_classes,
        )

        # Running sums for the per-epoch console summary; epoch-level values
        # are also logged through self.log in _shared_step.
        self.training_step_outputs = defaultdict(float)
        self.validation_step_outputs = defaultdict(float)

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)

        # Halve the LR when val_loss has not improved for 5 epochs. The
        # deprecated `verbose` flag is omitted.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=0.5,
            patience=5,
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss",
            },
        }

    def _shared_step(self, batch, stage):
        """Forward one batch, log ``{stage}_loss`` / ``{stage}_acc`` and return (loss, acc)."""
        x, y = batch
        logits = self(x)

        # Defensive: bring logits to the mask resolution if they differ.
        if logits.shape[-2:] != y.shape[-2:]:
            logits = F.interpolate(logits, size=y.shape[-2:], mode='bilinear', align_corners=False)

        loss = F.cross_entropy(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()  # pixel accuracy

        self.log(f'{stage}_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log(f'{stage}_acc', acc, on_step=False, on_epoch=True, prog_bar=True)
        return loss, acc

    @staticmethod
    def _accumulate(outputs, loss, acc):
        """Add one step's loss/accuracy to the running sums in ``outputs``."""
        outputs['loss'] += loss.detach().item()
        outputs['acc'] += acc.detach().item()
        outputs['steps'] += 1

    @staticmethod
    def _epoch_means(outputs):
        """Return (mean loss, mean acc) from running sums; requires steps > 0."""
        steps = outputs['steps']
        return outputs['loss'] / steps, outputs['acc'] / steps

    def training_step(self, batch, batch_idx):
        loss, acc = self._shared_step(batch, 'train')
        self._accumulate(self.training_step_outputs, loss, acc)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, acc = self._shared_step(batch, 'val')
        self._accumulate(self.validation_step_outputs, loss, acc)
        return loss

    def on_train_epoch_end(self):
        if self.training_step_outputs['steps'] > 0:
            avg_loss, avg_acc = self._epoch_means(self.training_step_outputs)
            print(f"\nEpoch {self.current_epoch} - Train Loss: {avg_loss:.4f}, Train Acc: {avg_acc:.4f}")
        self.training_step_outputs.clear()

    def on_validation_epoch_end(self):
        if self.validation_step_outputs['steps'] > 0:
            avg_loss, avg_acc = self._epoch_means(self.validation_step_outputs)
            print(f"Epoch {self.current_epoch} - Val Loss: {avg_loss:.4f}, Val Acc: {avg_acc:.4f}")
        self.validation_step_outputs.clear()

In [None]:
import glob
import pandas as pd
import os

# ============================================
# 1. PREPARE DATA
# ============================================

DATA_DIR = 'clothes'
CLASS_DICT_PATH = os.path.join(DATA_DIR, 'class_dict.csv')


def _mask_path_for(image_path):
    """Map <root>/<split>/images/<name>.<ext> to <root>/<split>/labels/<name>.png."""
    split_dir = os.path.dirname(os.path.dirname(image_path))
    stem = os.path.splitext(os.path.basename(image_path))[0]
    return os.path.join(split_dir, 'labels', stem + '.png')


def _split_of(image_path):
    """Return the split folder name (e.g. 'train') of <root>/<split>/images/<file>."""
    return os.path.basename(os.path.dirname(os.path.dirname(image_path)))


# Expected layout: <DATA_DIR>/<split>/images/* with masks in <DATA_DIR>/<split>/labels/.
# Sorted for a reproducible row order.
images = sorted(glob.glob(os.path.join(DATA_DIR, '*', 'images', '*')))
df = pd.DataFrame(images, columns=['image_path'])
df['mask_path'] = df.image_path.apply(_mask_path_for)
df['set'] = df.image_path.apply(_split_of)

print(f"Total images: {len(df)}")
print(df['set'].value_counts())
print(df.head())

# Load the class dictionary.
class_dict = pd.read_csv(CLASS_DICT_PATH)
class_names = class_dict['class_name'].tolist()
if class_names and class_names[0] == 'null':
    class_names[0] = 'background'  # same renaming CCPDataset applies
class_rgb_values = class_dict[['r', 'g', 'b']].values.tolist()

print('\n' + '='*50)
print('Dataset Classes:')
print('='*50)
for i, (name, rgb) in enumerate(zip(class_names, class_rgb_values)):
    print(f"{i:2d}. {name:20s} - RGB: {rgb}")
print('='*50 + '\n')

# ============================================
# 2. CREATE DATAMODULE
# ============================================

dm = DataModule(
    df=df,
    class_dict_path=CLASS_DICT_PATH,
    bs=16,
    ps=256,
    num_workers=4  # adjust to your CPU
)

dm.setup()

print(f"Train samples: {len(dm.train_ds)}")
print(f"Valid samples: {len(dm.valid_ds)}")
print(f"Test samples: {len(dm.test_ds)}")
print(f"Number of classes: {dm.num_classes}")

# ============================================
# 3. VISUALISE SAMPLES
# ============================================

# ImageNet mean/std used by the DataModule's Normalize transform; needed to
# undo the normalisation for display.
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)


def denormalize(t):
    """Invert (t - mean) / std for a (3, H, W) image tensor."""
    return t * IMAGENET_STD + IMAGENET_MEAN


for i in range(min(3, len(dm.train_ds))):
    x, y = dm.train_ds[i]
    print(f"\nSample {i}:")
    print(f"Image shape: {x.shape}")
    print(f"Mask shape: {y.shape}")
    print(f"Unique classes in mask: {torch.unique(y).tolist()}")
    dm.train_ds.__show_item__(x, y, denormalize=denormalize)

# ============================================
# 4. CREATE MODEL
# ============================================

model = SegmentationModel(
    num_classes=dm.num_classes,
    encoder_name='resnet18',
    encoder_weights='imagenet',
    learning_rate=1e-3,
    architecture='unet'
)

print(f"\nModel created with {dm.num_classes} classes")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")

# ============================================
# 5. TRAIN
# ============================================

from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

# Keep the 3 best checkpoints by validation loss.
checkpoint_callback = ModelCheckpoint(
    dirpath='checkpoints',
    filename='segmentation-{epoch:02d}-{val_loss:.4f}',
    save_top_k=3,
    monitor='val_loss',
    mode='min'
)

early_stop_callback = EarlyStopping(
    monitor='val_loss',
    patience=10,
    mode='min',
    verbose=True
)

logger = TensorBoardLogger('logs', name='segmentation')

use_gpu = torch.cuda.is_available()
trainer = pl.Trainer(
    max_epochs=50,
    accelerator='gpu' if use_gpu else 'cpu',
    devices=1,
    callbacks=[checkpoint_callback, early_stop_callback],
    logger=logger,
    log_every_n_steps=10,
    precision='16-mixed' if use_gpu else 32  # mixed precision on GPU only
)

print("\n" + "="*50)
print("Starting training...")
print("="*50 + "\n")

trainer.fit(model, dm)

print("\n" + "="*50)
print("Training completed!")
print(f"Best model: {checkpoint_callback.best_model_path}")
print("="*50 + "\n")

In [None]:
from torchvision import transforms as T
from torchvision.transforms. functional import to_pil_image
from sklearn.metrics import accuracy_score, jaccard_score, f1_score
import torch. nn as nn
from tqdm import tqdm

def rgb_to_class(mask, class_rgb_values):
    """
    Convert a colour mask to a 2-D array of class indices.

    Args:
        mask: Array of shape (H, W, C) with C >= 3. Only the first three
            channels are compared, so an alpha channel is ignored.
        class_rgb_values: Sequence of (r, g, b) triplets; position i is class i.

    Returns:
        int64 array of shape (H, W). Pixels matching no colour stay 0; if
        several entries share a colour, the last one wins.

    Raises:
        ValueError: if ``mask`` is not (H, W, C) with C >= 3.
    """
    mask = np.asarray(mask)
    if mask.ndim != 3 or mask.shape[2] < 3:
        raise ValueError(f"expected an (H, W, C>=3) mask, got shape {mask.shape}")

    rgb = mask[..., :3]
    class_mask = np.zeros(rgb.shape[:2], dtype=np.int64)
    for idx, colour in enumerate(class_rgb_values):
        class_mask[np.all(rgb == np.asarray(colour), axis=-1)] = idx

    return class_mask

# ============================================
# CONFIGURATION
# ============================================

ps = 256
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the best checkpoint saved by ModelCheckpoint during training.
best_model_path = checkpoint_callback.best_model_path
if not best_model_path:
    raise RuntimeError("No checkpoint available: train the model first.")
model = SegmentationModel.load_from_checkpoint(best_model_path, map_location=device)
model.eval()
model = model.to(device)

# Inference transforms; must match the training normalisation (ImageNet mean/std).
inf_transforms = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def _as_rgb_uint8(img):
    """Return ``img`` as an (H, W, 3) uint8 array (grey promoted, alpha dropped).

    uint8 matters: to_pil_image treats float arrays as [0, 1] data, so a
    float32 copy of 0-255 pixels would be corrupted.
    """
    img = np.asarray(img)
    if img.ndim == 2:
        img = img[..., None]
    if img.shape[2] == 1:
        img = np.repeat(img, 3, axis=2)
    return img[..., :3].astype(np.uint8, copy=False)


def _mask_to_ids(mask, color_to_class):
    """Map an RGB(A) colour mask to int64 class ids; unknown colours -> 0."""
    rgb = mask[..., :3].astype(np.uint32)
    packed = ((rgb[..., 0] << 16) | (rgb[..., 1] << 8) | rgb[..., 2]).ravel()
    # Look up each distinct colour once, then scatter back via the inverse index.
    colours, inverse = np.unique(packed, return_inverse=True)
    lut = np.array([color_to_class.get(int(c), 0) for c in colours], dtype=np.int64)
    return lut[inverse].reshape(mask.shape[:2])


# ============================================
# INFERENCE ON THE TEST SET
# ============================================

metric_dict = []
test_df = df.query('set == "test"').reset_index(drop=True)
color_to_class = dm.train_ds.color_to_class

print(f"\nRunning inference on {len(test_df)} test images...\n")

for _, r in tqdm(test_df.iterrows(), total=len(test_df), desc="Processing images"):
    x = _as_rgb_uint8(imread(r.image_path))
    y = np.asarray(imread(r.mask_path))

    # Ground truth as class ids: colour masks (RGB or RGBA) are decoded with the
    # training palette; single-channel masks are assumed to already hold ids.
    if y.ndim == 3 and y.shape[2] >= 3:
        y = _mask_to_ids(y, color_to_class)
    else:
        y = y.astype(np.int64)

    h, w = y.shape
    coords = DataModule.patch_origins(h, w, ps=ps, overlap=2)

    # Sum of softmax probabilities and per-pixel patch counts (float, so the
    # count cannot overflow and the division below is a plain float divide).
    y_hat = torch.zeros((dm.num_classes, h, w), dtype=torch.float32)
    y_wei = torch.zeros((1, h, w), dtype=torch.float32)

    with torch.no_grad():
        for coord_x, coord_y in coords:
            xi = inf_transforms(to_pil_image(x[coord_x:coord_x + ps, coord_y:coord_y + ps]))
            probs = torch.softmax(model(xi.unsqueeze(0).to(device)), dim=1)[0].cpu()
            y_hat[:, coord_x:coord_x + ps, coord_y:coord_y + ps] += probs
            y_wei[:, coord_x:coord_x + ps, coord_y:coord_y + ps] += 1

    # Average overlapping patches; clamp avoids 0/0 on any uncovered pixel.
    y_hat /= y_wei.clamp_min(1)

    y_true = y.ravel()
    y_pred = y_hat.argmax(0).numpy().ravel()

    metric_dict.append({
        'image': r.image_path,
        'pixel_acc': accuracy_score(y_true, y_pred),
        'iou_macro': jaccard_score(y_true, y_pred, average='macro', zero_division=0),
        'iou_weighted': jaccard_score(y_true, y_pred, average='weighted', zero_division=0),
        'dice_macro': f1_score(y_true, y_pred, average='macro', zero_division=0),
        'dice_weighted': f1_score(y_true, y_pred, average='weighted', zero_division=0)
    })

# ============================================
# SHOW RESULTS
# ============================================

results_df = pd.DataFrame(metric_dict)

print("\n" + "="*80)
print("INFERENCE RESULTS")
print("="*80)
print(results_df.describe())
print("\n" + "="*80)
print("MEAN METRICS:")
print("="*80)
for col in ['pixel_acc', 'iou_macro', 'iou_weighted', 'dice_macro', 'dice_weighted']:
    print(f"{col:20s}: {results_df[col].mean():.4f} ± {results_df[col].std():.4f}")
print("="*80 + "\n")

# Save results.
results_df.to_csv('test_results.csv', index=False)
print("Results saved to 'test_results.csv'")

In [None]:
def visualize_prediction(image_path, model, transforms, class_names, ps=256, overlap=2):
    """
    Run tiled inference on one image, plot the result and print class percentages.

    The image is split into ps x ps patches (origins from
    DataModule.patch_origins with the given overlap). Softmax probabilities of
    overlapping patches are averaged before the per-pixel argmax.

    Args:
        image_path: Path of the image to segment.
        model: Segmentation model returning (1, num_classes, H, W) logits.
        transforms: Callable mapping a PIL image to a normalised (3, H, W) tensor.
        class_names: List of class names; its length sets the number of classes.
        ps: Patch size.
        overlap: Overlap factor passed to DataModule.patch_origins.
    """
    # The model expects 3-channel uint8 input: promote greyscale and drop
    # alpha. Keep uint8 because to_pil_image treats float arrays as [0, 1] data.
    x = np.asarray(imread(image_path))
    if x.ndim == 2:
        x = x[..., None]
    if x.shape[2] == 1:
        x = np.repeat(x, 3, axis=2)
    x = x[..., :3].astype(np.uint8, copy=False)
    h, w = x.shape[:2]

    coords = DataModule.patch_origins(h, w, ps=ps, overlap=overlap)

    num_classes = len(class_names)
    y_hat = torch.zeros((num_classes, h, w), dtype=torch.float32)
    y_wei = torch.zeros((1, h, w), dtype=torch.float32)

    model.eval()
    # Send inputs to wherever the model's weights live.
    device = next(model.parameters()).device
    with torch.no_grad():
        for coord_x, coord_y in tqdm(coords, desc="Predicting"):
            xi = transforms(to_pil_image(x[coord_x:coord_x + ps, coord_y:coord_y + ps]))
            probs = torch.softmax(model(xi.unsqueeze(0).to(device)), dim=1)[0].cpu()
            y_hat[:, coord_x:coord_x + ps, coord_y:coord_y + ps] += probs
            y_wei[:, coord_x:coord_x + ps, coord_y:coord_y + ps] += 1

    # Average overlapping patches (clamp avoids 0/0) and take the argmax.
    y_hat /= y_wei.clamp_min(1)
    y_pred = y_hat.argmax(0).numpy()

    mask_style = dict(cmap='tab20', vmin=0, vmax=num_classes - 1)
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    axes[0].imshow(x)
    axes[0].set_title('Original Image')
    axes[0].axis('off')

    axes[1].imshow(y_pred, **mask_style)
    axes[1].set_title('Prediction')
    axes[1].axis('off')

    axes[2].imshow(x)
    axes[2].imshow(y_pred, alpha=0.5, **mask_style)
    axes[2].set_title('Overlay')
    axes[2].axis('off')

    plt.tight_layout()
    plt.show()

    # Share of pixels assigned to each predicted class.
    unique, counts = np.unique(y_pred, return_counts=True)
    print("\nClass distribution:")
    for cls, count in zip(unique, counts):
        percentage = 100 * count / y_pred.size
        print(f"  {class_names[cls]:20s}: {percentage:6.2f}%")

# Usage:
# visualize_prediction(
#     test_df.iloc[0].image_path,
#     model,
#     inf_transforms,
#     class_names,
#     ps=256,
#     overlap=2
# )