In [None]:
import torch
import numpy as np
from torch.utils.data import DataLoader
from vae import DNIVAEDetector, DNIDataset
from tqdm import tqdm
import matplotlib.pyplot as plt
from pathlib import Path

def _print_distribution_stats(values):
    """Print mean, median, standard deviation, minimum and maximum of ``values``.

    Args:
        values: Non-empty NumPy array of numbers.
    """
    print(f"Media: {values.mean():.6f}")
    print(f"Mediana: {np.median(values):.6f}")
    print(f"Desviación estándar: {values.std():.6f}")
    print(f"Mínimo: {values.min():.6f}")
    print(f"Máximo: {values.max():.6f}")


def test_autoencoder(model_path, test_dir, batch_size=32, threshold_option='95', num_images_to_plot=5):
    """
    Evaluate a saved detector on every ``*.jpg`` image of a directory.

    Each image is passed to ``detector.predict``; the returned confidence and
    reconstruction error are collected. The function then:

    * saves and shows a histogram of the errors (``reconstruction_errors.png``),
    * saves and shows original-vs-reconstruction pairs for a random sample of
      images (``reconstructions.png``),
    * prints how many errors exceed the selected threshold, plus summary
      statistics of the errors and the confidence scores.

    Args:
        model_path: Path passed to ``DNIVAEDetector.load_model``.
        test_dir: Directory whose ``*.jpg`` files (non-recursive) are evaluated.
        batch_size: Currently unused (images are processed one at a time);
            kept so existing callers keep working.
        threshold_option: Key into ``detector.thresholds``
            (documented options: '90', '95', or '99').
        num_images_to_plot: Maximum number of randomly chosen images to plot.

    Raises:
        ValueError: If ``test_dir`` contains no ``*.jpg`` files.
    """
    # Collect the images first so a wrong/empty directory fails fast, before
    # the model is loaded. Sorting makes the processing order deterministic.
    test_dir = Path(test_dir)
    image_paths = sorted(test_dir.glob('*.jpg'))
    if not image_paths:
        raise ValueError(f"No se encontraron imágenes '*.jpg' en {test_dir}")

    # Load the model
    detector = DNIVAEDetector(device='cuda' if torch.cuda.is_available() else 'cpu')
    detector.load_model(model_path)

    # Choose up front which images will be plotted: a uniform sample without
    # replacement, so exactly min(num_images_to_plot, n) images are kept and
    # only those image tensors are retained in memory.
    sample_size = max(0, min(num_images_to_plot, len(image_paths)))
    plot_indices = set(
        np.random.choice(len(image_paths), size=sample_size, replace=False).tolist()
    )

    # Compute reconstruction errors and confidence scores
    print("\nProcesando imágenes...")
    confidences = []
    errors = []
    images_to_plot = []

    for img_idx, img_path in enumerate(tqdm(image_paths)):
        confidence, error, original, reconstructed = detector.predict(
            str(img_path),
            threshold_option=threshold_option,
            return_images=True
        )
        confidences.append(confidence)
        errors.append(error)

        if img_idx in plot_indices:
            images_to_plot.append((original, reconstructed, confidence, error))

    # Convert to arrays
    confidences = np.array(confidences)
    errors = np.array(errors)
    threshold = detector.thresholds[threshold_option]

    # Histogram of reconstruction errors with the threshold marked
    fig = plt.figure(figsize=(12, 6))
    plt.hist(errors, bins=50, alpha=0.75, color='blue', label=f'Images (n={len(errors)})')
    plt.axvline(x=threshold, color='red', linestyle='--',
                label=f'Threshold ({threshold_option}%): {threshold:.6f}')
    plt.title('Distribution of Reconstruction Errors')
    plt.xlabel('Reconstruction Error')
    plt.ylabel('Number of Images')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.savefig('reconstruction_errors.png')
    plt.show()
    plt.close(fig)  # release the figure once it has been saved and shown

    # Original vs reconstructed images
    if images_to_plot:
        # squeeze=False keeps `axes` two-dimensional even for a single row;
        # otherwise axes[idx, 0] raises IndexError when only one image is plotted.
        fig, axes = plt.subplots(
            len(images_to_plot), 2,
            figsize=(12, 4 * len(images_to_plot)),
            squeeze=False,
        )
        for idx, (original, reconstructed, confidence, error) in enumerate(images_to_plot):
            # Convert tensors to image arrays.
            # NOTE(review): assumes channel-first (C, H, W) tensors — confirm
            # against DNIVAEDetector.predict. detach() guards against tensors
            # that still require grad, for which .numpy() would raise.
            original = original.detach().cpu().permute(1, 2, 0).numpy()
            reconstructed = reconstructed.detach().cpu().permute(1, 2, 0).numpy()

            # Clamp to the [0, 1] range expected by imshow for float images
            original = np.clip(original, 0, 1)
            reconstructed = np.clip(reconstructed, 0, 1)

            axes[idx, 0].imshow(original)
            axes[idx, 0].set_title('Original')
            axes[idx, 0].axis('off')

            axes[idx, 1].imshow(reconstructed)
            axes[idx, 1].set_title(f'Reconstructed\nConf: {confidence:.2f}, Error: {error:.4f}')
            axes[idx, 1].axis('off')

        plt.tight_layout()
        plt.savefig('reconstructions.png')
        plt.show()
        plt.close(fig)

    # Print statistics
    print("\nEstadísticas:")
    print(f"Threshold {threshold_option}%: {threshold:.6f}")
    above_threshold = np.sum(errors > threshold)
    below_threshold = len(errors) - above_threshold

    print(f"\nImágenes anómalas (error > threshold): {above_threshold}/{len(errors)} ({100*above_threshold/len(errors):.2f}%)")
    print(f"Imágenes normales (error ≤ threshold): {below_threshold}/{len(errors)} ({100*below_threshold/len(errors):.2f}%)")

    print(f"\nEstadísticas de errores de reconstrucción:")
    _print_distribution_stats(errors)

    print(f"\nEstadísticas de confidence scores:")
    _print_distribution_stats(confidences)

# Example usage: evaluate the saved model on the cats test set using the
# 95% threshold and plot up to five sampled reconstructions.
if __name__ == "__main__":
    test_autoencoder(
        'vae_model.pt',
        'archive/test_set/test_set/cats',
        batch_size=32, threshold_option='95', num_images_to_plot=5,
    )

In [None]:
# Inspect the detector's `decoder` attribute (a notebook cell displays the
# value of its last expression).
# NOTE(review): in the visible code `detector` is only assigned inside
# test_autoencoder(), so this cell presumably relies on a `detector` defined
# elsewhere in the session — confirm, otherwise this raises NameError.
detector.decoder