# Visualización de Resultados de Entrenamiento

Este notebook carga el modelo entrenado y visualiza las predicciones en el conjunto de validación.

In [None]:
import os
import sys
import torch
import matplotlib.pyplot as plt
import numpy as np
import cv2

# Add src to path
sys.path.append('../src')

from unet import UNet
from dataset import SegmentationDataset
from torch.utils.data import DataLoader

## Configuración

In [None]:
# Paths are relative to the notebook's working directory.
IMG_DIR = '../data/images'
MASK_DIR = '../data/masks'
CHECKPOINT_PATH = '../checkpoints/best_model.pth'

# Run on the GPU whenever PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

print(f'Using device: {DEVICE}')

## Cargar Modelo

In [None]:
# Build the network with the same configuration used for training
# (3 input channels, 1 output channel).
model = UNet(n_channels=3, n_classes=1)

if not os.path.exists(CHECKPOINT_PATH):
    # Without a checkpoint the model keeps its initial (untrained) weights.
    print('Checkpoint not found! Please train the model first.')
else:
    # NOTE(review): torch.load unpickles the file -- only load checkpoints you
    # trust. map_location remaps stored tensors onto DEVICE.
    state = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
    model.load_state_dict(state)
    print('Model loaded successfully!')

# Move to the target device and switch to inference mode. Module.to returns
# the module itself, so the chained .eval() stays the cell's final expression.
model.to(DEVICE).eval()

## Cargar Datos

In [None]:
# Pair every image in IMG_DIR with its mask in MASK_DIR.
dataset = SegmentationDataset(IMG_DIR, MASK_DIR)

# One sample per batch; shuffle=True makes DataLoader reshuffle on every pass,
# so each visualization run may show different examples.
loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
)

## Visualizar Predicciones

In [None]:
def _infer_device(model):
    """Return the device holding ``model``'s first parameter (CPU if it has none)."""
    first_param = next(model.parameters(), None)
    return first_param.device if first_param is not None else torch.device('cpu')


def _to_display_arrays(image, mask, pred, threshold=0.5, apply_sigmoid=False):
    """Convert one sample's tensors into numpy arrays ready for plotting.

    Args:
        image: Single image tensor, channels first (C, H, W).
        mask: Ground-truth mask whose last two dims are (H, W); leading
            singleton dims (e.g. a channel axis) are dropped.
        pred: Model output for this sample; same layout rule as ``mask``.
        threshold: Cut-off used to binarize ``pred`` for the difference map.
        apply_sigmoid: Apply ``torch.sigmoid`` to ``pred`` first. Use this if
            the model returns raw logits instead of probabilities.

    Returns:
        Tuple ``(img_np, mask_np, pred_np, diff)``. ``img_np`` is (H, W, C),
        or (H, W) for single-channel images. The others are (H, W) float
        arrays, and ``diff`` is ``|mask - binarized pred|``.
    """
    image = image.detach().cpu()
    mask = mask.detach().cpu().float()
    pred = pred.detach().cpu().float()
    if apply_sigmoid:
        pred = torch.sigmoid(pred)

    img_np = image.permute(1, 2, 0).numpy()
    if img_np.shape[-1] == 1:
        # matplotlib's imshow rejects (H, W, 1) arrays; drop the channel axis.
        img_np = img_np[..., 0]

    # reshape to the trailing (H, W) rather than squeeze(), so a mask that is
    # only one pixel tall is never collapsed by accident.
    mask_np = mask.reshape(mask.shape[-2:]).numpy()
    pred_np = pred.reshape(pred.shape[-2:]).numpy()

    pred_bin = (pred_np > threshold).astype(np.float32)
    diff = np.abs(mask_np - pred_bin)
    return img_np, mask_np, pred_np, diff


def _plot_sample(img_np, mask_np, pred_np, diff):
    """Show image, ground truth, prediction and difference map side by side."""
    panels = (
        (img_np, 'Original Image', 'gray' if img_np.ndim == 2 else None),
        (mask_np, 'Ground Truth', 'gray'),
        (pred_np, 'Prediction (Prob)', 'gray'),
        (diff, 'Difference', 'hot'),
    )
    plt.figure(figsize=(15, 5))
    for idx, (data, title, cmap) in enumerate(panels, start=1):
        plt.subplot(1, len(panels), idx)
        plt.imshow(data, cmap=cmap)
        plt.title(title)
        plt.axis('off')
    plt.tight_layout()
    plt.show()


def visualize_prediction(model, loader, num_samples=5, *, device=None,
                         threshold=0.5, apply_sigmoid=False):
    """Plot model predictions next to the ground truth for a few samples.

    Each sample is drawn as one figure with four panels: the input image, the
    ground-truth mask, the raw prediction and ``|mask - (pred > threshold)|``.

    Args:
        model: Segmentation model; it is switched to eval mode.
        loader: Iterable of ``(images, masks)`` batches. Any batch size works.
        num_samples: Number of individual images to plot (not batches).
        device: Device to run inference on. Defaults to the device of the
            model's parameters, so inputs always land next to the weights.
        threshold: Cut-off used to binarize predictions for the diff panel.
        apply_sigmoid: Set to True if the model outputs logits. The default
            (False) assumes it already outputs probabilities, which cannot be
            verified here -- check the UNet implementation.
    """
    if device is None:
        device = _infer_device(model)
    model.eval()

    shown = 0
    with torch.no_grad():
        for images, masks in loader:
            if shown >= num_samples:
                break
            preds = model(images.to(device))
            # Walk every sample in the batch so batch_size > 1 works and
            # num_samples counts images rather than batches.
            for image, mask, pred in zip(images, masks, preds):
                if shown >= num_samples:
                    break
                _plot_sample(*_to_display_arrays(
                    image, mask, pred,
                    threshold=threshold, apply_sigmoid=apply_sigmoid,
                ))
                shown += 1

# Guard clause: with an empty dataset there is nothing to plot.
if len(dataset) == 0:
    print("No images found in data directory. Please add images to visualize.")
else:
    visualize_prediction(model, loader)