In [None]:
import torch
import numpy as np
from torch.utils.data import DataLoader
from autoencoder import DNIAnomalyDetector, DNIDataset
from tqdm import tqdm
import matplotlib.pyplot as plt

def test_autoencoder(model_path, test_dir, batch_size=32):
    """
    Test autoencoder model on a directory of images.

    Loads the detector from ``model_path``, computes one reconstruction
    error per image in ``test_dir``, plots a histogram of those errors
    against the model's stored threshold and prints summary statistics.

    Side effects: writes ``reconstruction_errors.png`` to the current
    working directory, displays the figure and prints to stdout.

    Args:
        model_path: Path to the autoencoder model
        test_dir: Directory containing test images
        batch_size: Batch size for testing

    Returns:
        None.
    """
    # Load the model and put both halves of the autoencoder in eval mode.
    detector = DNIAnomalyDetector(device='cuda' if torch.cuda.is_available() else 'cpu')
    detector.load_model(model_path)
    detector.encoder.eval()
    detector.decoder.eval()

    # Compute the reconstruction errors.
    print("\nCalculando errores de reconstrucción...")
    errors = compute_reconstruction_errors(detector, test_dir, batch_size)

    total = len(errors)
    if total == 0:
        # Without this guard the percentage lines divide by zero and
        # min()/max() raise on an empty array.
        print(f"\nNo se encontraron imágenes en '{test_dir}'; no hay errores que analizar.")
        return

    threshold = detector.threshold

    # Plot the histogram with the model threshold as a vertical line.
    plt.figure(figsize=(10, 6))
    plt.hist(errors, bins=50, alpha=0.75, color='blue', label=f'Images (n={total})')
    plt.axvline(x=threshold, color='red', linestyle='--',
                label=f'Model Threshold: {threshold:.6f}')
    plt.title('Distribution of Reconstruction Errors')
    plt.xlabel('MSE Reconstruction Error')
    plt.ylabel('Number of Images')
    plt.legend()
    plt.grid(True, alpha=0.3)
    # Save before show(): some backends clear the figure once it is shown.
    plt.savefig('reconstruction_errors.png')
    plt.show()

    # Print statistics.
    above_threshold = np.sum(errors > threshold)
    below_threshold = total - above_threshold

    print("\nEstadísticas de errores de reconstrucción:")
    print(f"Imágenes por encima del threshold: {above_threshold}/{total} ({100*above_threshold/total:.2f}%)")
    print(f"Imágenes por debajo del threshold: {below_threshold}/{total} ({100*below_threshold/total:.2f}%)")
    print(f"\nMedia: {errors.mean():.6f}")
    print(f"Mediana: {np.median(errors):.6f}")
    print(f"Desviación estándar: {errors.std():.6f}")
    print(f"Mínimo: {errors.min():.6f}")
    print(f"Máximo: {errors.max():.6f}")

def compute_reconstruction_errors(detector, data_dir, batch_size=32):
    """Compute the reconstruction error for every image in a directory.

    Each batch is passed through ``detector.encoder`` and then
    ``detector.decoder``; the element-wise squared error between the
    reconstruction and the input is averaged over dimensions 1-3, giving
    one value per image.

    On the first batch, shape diagnostics for the encoder's ``conv`` stage
    and ``fc`` layer are printed to stdout.

    Args:
        detector: Object exposing ``transform``, ``device``, ``encoder``
            (with ``conv`` and ``fc`` attributes) and ``decoder``.
        data_dir: Directory passed to ``DNIDataset``.
        batch_size: Number of images per batch.

    Returns:
        A NumPy array with one reconstruction error per image, in dataset
        order (the loader does not shuffle). Empty if the loader yields no
        batches.

    Raises:
        RuntimeError: Re-raised unchanged if the forward pass or the loss
            computation fails, after printing diagnostics.
    """
    dataset = DNIDataset(data_dir, detector.transform)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    errors = []
    first_batch = True
    with torch.no_grad():
        for batch in tqdm(loader, desc="Procesando imágenes"):
            imgs = batch.to(detector.device)

            if first_batch:
                # One-off debug output: compare the flattened conv output
                # with the input size the FC layer expects.
                x = detector.encoder.conv(imgs)
                flattened = x.view(x.size(0), -1)
                print(f"\nDimensiones:")
                print(f"Input images: {imgs.shape}")
                print(f"After conv layers: {x.shape}")
                print(f"Flattened: {flattened.shape}")
                print(f"FC layer expects input: {detector.encoder.fc.in_features}")
                first_batch = False

            try:
                latent = detector.encoder(imgs)
                reconstructed = detector.decoder(latent)
                batch_errors = torch.nn.functional.mse_loss(
                    reconstructed, imgs, reduction='none'
                ).mean(dim=[1,2,3]).cpu().numpy()
                errors.extend(batch_errors)
            except RuntimeError:
                # NOTE: this catches any RuntimeError, not only shape
                # mismatches; the message below is a best guess.
                print("\nError de dimensiones detectado.")
                # Report the batch that actually failed. The conv output
                # captured above belongs to the first batch only and would
                # be stale for any later failure.
                print(f"Último tensor shape: {imgs.shape}")
                print(f"FC layer input size: {detector.encoder.fc.in_features}")
                print(f"FC layer output size: {detector.encoder.fc.out_features}")
                # Bare raise keeps the original exception and traceback.
                raise

    return np.array(errors)

# Usage: evaluate the saved detector on the given image directory,
# with the model path and the image directory passed positionally.
test_autoencoder(
    'models/dni_anomaly_detector_without_yolo.pt',
    'autoencoder_data/train_images_without_YOLO',
    batch_size=32,
)

  checkpoint = torch.load(path)



Calculando errores de reconstrucción...


Procesando imágenes:   2%|▏         | 1/63 [00:02<02:27,  2.38s/it]


Dimensiones:
Input images: torch.Size([32, 3, 224, 224])
After conv layers: torch.Size([32, 128, 28, 28])
Flattened: torch.Size([32, 100352])
FC layer expects input: 100352


Procesando imágenes:  65%|██████▌   | 41/63 [01:25<00:44,  2.03s/it]