# Test

In [None]:
import tempfile
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm
from ultralytics import YOLO

from autoencoder import DNIAnomalyDetector, DNIDataset

def get_reconstruction_errors(detector, data_dir, batch_size=32):
    """Compute one mean-squared reconstruction error per image in a directory.

    Args:
        detector: DNIAnomalyDetector instance; its ``transform``, ``device``,
            ``encoder`` and ``decoder`` are used.
        data_dir: Directory loaded through ``DNIDataset``.
        batch_size: Number of images per forward pass.

    Returns:
        1-D numpy array with one MSE value per image, in dataset order
        (empty if the dataset yields no images).
    """
    dataset = DNIDataset(data_dir, detector.transform)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    # Inference mode, so dropout / batch-norm layers (if the model has any)
    # behave deterministically even when the caller did not switch modes.
    detector.encoder.eval()
    detector.decoder.eval()

    batch_errors = []
    with torch.no_grad():
        for batch in tqdm(loader, desc=f"Processing {data_dir}"):
            imgs = batch.to(detector.device)
            reconstructed = detector.decoder(detector.encoder(imgs))

            # Average the squared error over every non-batch dimension -> one value per image.
            per_image = torch.nn.functional.mse_loss(
                reconstructed, imgs, reduction='none'
            ).flatten(start_dim=1).mean(dim=1)

            batch_errors.append(per_image.cpu().numpy())

    if not batch_errors:
        return np.empty(0, dtype=np.float32)
    return np.concatenate(batch_errors)

def preprocess_dataset_with_yolo(yolo_model, input_dir, output_dir, invalid_labels=None, pattern='*.jpg'):
    """
    Crop every image of a directory to its best YOLO detection and save the crops.

    For each image, the highest-confidence detection is kept. The image is
    skipped when there is no detection, or when that detection's class name is
    in ``invalid_labels``. Otherwise the detection's bounding box is cropped and
    written to ``output_dir`` under the original file name.

    Args:
        yolo_model: Loaded YOLO model, called as ``yolo_model(path, verbose=False)``.
        input_dir: Directory with the original images.
        output_dir: Directory where the cropped images are written (created if missing).
        invalid_labels: Class names whose best detection causes the image to be
            skipped. Defaults to ``['no_match']``.
        pattern: Glob pattern (non-recursive, relative to ``input_dir``) selecting
            the images to process. Defaults to ``'*.jpg'``.

    Returns:
        ``output_dir`` exactly as passed in.
    """
    if invalid_labels is None:
        invalid_labels = ['no_match']  # adjust to the label set of your YOLO model

    output_root = Path(output_dir)
    output_root.mkdir(parents=True, exist_ok=True)

    # Sorted so processing order (and the log output) is reproducible across runs.
    image_paths = sorted(Path(input_dir).glob(pattern))
    processed_count = 0
    skipped_count = 0

    for img_path in tqdm(image_paths, desc="Preprocessing images with YOLO"):
        # Best-effort per image: a single bad file must not abort the whole dataset.
        try:
            results = yolo_model(str(img_path), verbose=False)[0]

            # No detections: nothing to crop.
            if len(results.boxes) == 0:
                print(f"\nNo detection in {img_path.name}")
                skipped_count += 1
                continue

            # Keep only the highest-confidence detection.
            best_idx = results.boxes.conf.cpu().numpy().argmax()
            box = results.boxes[best_idx]
            cls_name = results.names[int(box.cls.item())]

            if cls_name in invalid_labels:
                print(f"\nInvalid class {cls_name} in {img_path.name}")
                skipped_count += 1
                continue

            bbox = box.xyxy[0].cpu().numpy()
            # The context manager releases the file handle; crop() returns an
            # independent, fully loaded image, so it stays usable after closing.
            with Image.open(img_path) as original_image:
                cropped_image = crop_image(original_image, bbox)

            cropped_image.save(output_root / img_path.name)
            processed_count += 1

        except Exception as e:
            # Include the exception type so programming errors (e.g. NameError)
            # are not mistaken for per-image data problems.
            print(f"\nError processing {img_path.name}: {type(e).__name__}: {e}")
            skipped_count += 1

    print(f"\nPreprocessing complete:")
    print(f"Processed: {processed_count} images")
    print(f"Skipped: {skipped_count} images")
    return output_dir

def crop_image(image, bbox):
    """Return ``image.crop`` applied to the integer-truncated ``bbox``.

    ``bbox`` must contain exactly four coordinates ``(x1, y1, x2, y2)``; each is
    converted with ``int()`` (truncation toward zero) before cropping.
    """
    left, top, right, bottom = map(int, bbox)
    return image.crop((left, top, right, bottom))

def plot_misclassifications(detector, data_dir, errors, is_valid=True, batch_size=32, max_images=16):
    """
    Plot misclassified images next to their autoencoder reconstructions.

    False positives are valid images whose error is above ``detector.threshold``.
    False negatives are invalid images whose error is at or below it. Up to
    ``max_images`` of them are drawn as (original, reconstruction) panel pairs,
    two pairs per row. The figure is saved to ``falsos_positivos.png`` or
    ``falsos_negativos.png`` and then shown.

    Args:
        detector: DNIAnomalyDetector instance
        data_dir: Directory with images. It must be the directory ``errors`` was
            computed from, because positions in ``errors`` are used as dataset indices.
        errors: Array of reconstruction errors, one per image in dataset order
        is_valid: Whether these are valid images (for FP) or invalid images (for FN)
        batch_size: Batch size for processing
        max_images: Maximum number of misclassified images to draw
    """
    dataset = DNIDataset(data_dir, detector.transform)

    # Select the misclassified images.
    if is_valid:
        # False positives: valid images with error > threshold
        misclassified_idx = np.where(errors > detector.threshold)[0]
        title = "Falsos Positivos (Válidas clasificadas como Inválidas)"
    else:
        # False negatives: invalid images with error <= threshold
        misclassified_idx = np.where(errors <= detector.threshold)[0]
        title = "Falsos Negativos (Inválidas clasificadas como Válidas)"

    if len(misclassified_idx) == 0:
        print(f"No hay {title}")
        return

    misclassified_idx = misclassified_idx[:max_images]
    n_images = len(misclassified_idx)

    # Each image needs two panels (original + reconstruction). A fixed 4x4 grid
    # only fits 8 pairs, so the grid is sized from the number of images instead.
    pairs_per_row = 2
    n_rows = (n_images + pairs_per_row - 1) // pairs_per_row
    _, axes = plt.subplots(
        n_rows, 2 * pairs_per_row, figsize=(15, 3.75 * n_rows), squeeze=False
    )
    axes = axes.ravel()
    for ax in axes:
        ax.axis('off')  # also hides the unused trailing panels when n_images is odd

    detector.encoder.eval()
    detector.decoder.eval()

    # Load only the selected images, in the same order as misclassified_idx.
    loader = DataLoader(
        Subset(dataset, misclassified_idx.tolist()), batch_size=batch_size, shuffle=False
    )

    position = 0
    with torch.no_grad():
        for batch in loader:
            imgs = batch.to(detector.device)
            reconstructed = detector.decoder(detector.encoder(imgs))

            for original, recon in zip(imgs.cpu(), reconstructed.cpu()):
                global_idx = misclassified_idx[position]
                ax_orig = axes[2 * position]
                ax_recon = axes[2 * position + 1]

                # CHW tensors -> HWC arrays for imshow
                ax_orig.imshow(original.permute(1, 2, 0).numpy())
                ax_orig.set_title(f'Original\nError: {errors[global_idx]:.6f}')

                ax_recon.imshow(recon.permute(1, 2, 0).numpy())
                ax_recon.set_title('Reconstrucción')

                position += 1

    plt.suptitle(title)
    plt.tight_layout()
    save_name = 'falsos_positivos.png' if is_valid else 'falsos_negativos.png'
    plt.savefig(save_name)
    plt.show()

def _print_error_stats(label, errors):
    """Print mean, median, standard deviation, min and max of ``errors`` under a ``label`` header."""
    print(f"\n{label}:")
    print(f"Media: {errors.mean():.6f}")
    print(f"Mediana: {np.median(errors):.6f}")
    print(f"Desviación estándar: {errors.std():.6f}")
    print(f"Mínimo: {errors.min():.6f}")
    print(f"Máximo: {errors.max():.6f}")


def _overlap_percentage(errors_a, errors_b):
    """Return how much the value ranges of two non-empty error arrays overlap, in percent.

    The result is the length of the intersection of the two ``[min, max]``
    ranges, divided by the span from the overall minimum to the overall maximum.
    Disjoint ranges give 0. If every value in both arrays is identical, the
    ranges coincide and 100 is returned.
    """
    low = max(errors_a.min(), errors_b.min())
    high = min(errors_a.max(), errors_b.max())
    total_range = max(errors_a.max(), errors_b.max()) - min(errors_a.min(), errors_b.min())
    if total_range == 0:
        return 100.0
    # Disjoint ranges yield a negative intersection length; clamp it to zero.
    return max(high - low, 0.0) / total_range * 100


def _plot_error_histogram(valid_errors, invalid_errors, threshold):
    """Overlay histograms of both error sets plus the threshold line; save and show the figure."""
    plt.figure(figsize=(12, 7))

    # Shared bin edges so both distributions are directly comparable.
    all_errors = np.concatenate([valid_errors, invalid_errors])
    bins = np.linspace(all_errors.min(), all_errors.max(), 50)

    plt.hist(valid_errors, bins=bins, alpha=0.5, color='green',
             label=f'Valid Images (n={len(valid_errors)})')
    plt.hist(invalid_errors, bins=bins, alpha=0.5, color='red',
             label=f'Invalid Images (n={len(invalid_errors)})')

    plt.axvline(x=threshold, color='black', linestyle='--',
                label=f'Model Threshold: {threshold:.6f}')

    plt.title('Distribution of Reconstruction Errors: Valid vs Invalid Images')
    plt.xlabel('MSE Reconstruction Error')
    plt.ylabel('Number of Images')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.savefig('reconstruction_errors_comparison.png')
    plt.show()


def compare_reconstruction_errors(model_path, valid_dir, invalid_dir, batch_size=32,
                                  yolo_model_path="api/best.pt"):
    """
    Compare reconstruction errors between valid and invalid images.

    Both directories are first cropped with YOLO (``preprocess_dataset_with_yolo``)
    into a temporary directory, and the autoencoder is evaluated on those crops.
    The function then:
      - prints threshold-based classification counts and error statistics,
      - saves and shows a comparative histogram,
      - plots the misclassified images.
    The temporary crops are deleted when the function returns.

    Args:
        model_path: Path to the saved autoencoder model
        valid_dir: Directory with valid images
        invalid_dir: Directory with invalid images
        batch_size: Batch size for testing
        yolo_model_path: Path to the YOLO weights used for cropping
    """
    detector = DNIAnomalyDetector(device='cuda' if torch.cuda.is_available() else 'cpu')
    detector.load_model(model_path)
    detector.encoder.eval()
    detector.decoder.eval()
    threshold = detector.threshold

    yolo_model = YOLO(yolo_model_path)

    print("Preprocesando imágenes con YOLO...")
    # The crops are written to a throwaway directory that is removed on exit, so
    # crops left over from an earlier run can never leak into these metrics.
    with tempfile.TemporaryDirectory() as tmp_root:
        valid_processed_dir = preprocess_dataset_with_yolo(
            yolo_model, valid_dir, str(Path(tmp_root) / "valid")
        )
        invalid_processed_dir = preprocess_dataset_with_yolo(
            yolo_model, invalid_dir, str(Path(tmp_root) / "invalid")
        )

        # Evaluate the autoencoder on the YOLO crops, not on the raw images.
        valid_errors = get_reconstruction_errors(detector, valid_processed_dir, batch_size)
        invalid_errors = get_reconstruction_errors(detector, invalid_processed_dir, batch_size)

        n_valid = len(valid_errors)
        n_invalid = len(invalid_errors)
        if n_valid == 0 or n_invalid == 0:
            print(
                "\nNo hay suficientes imágenes procesadas para comparar "
                f"(válidas: {n_valid}, inválidas: {n_invalid})."
            )
            return

        valid_above_threshold = np.sum(valid_errors > threshold)
        valid_proportion = valid_above_threshold / n_valid * 100

        invalid_below_threshold = np.sum(invalid_errors <= threshold)
        invalid_proportion = invalid_below_threshold / n_invalid * 100

        print("\nMétricas de clasificación:")
        print(f"Imágenes válidas por encima del threshold: {valid_above_threshold}/{n_valid} ({valid_proportion:.2f}%)")
        print(f"Imágenes inválidas por debajo del threshold: {invalid_below_threshold}/{n_invalid} ({invalid_proportion:.2f}%)")

        # The complementary (correctly classified) counts, for completeness.
        print(f"\nImágenes válidas correctamente clasificadas: {n_valid - valid_above_threshold}/{n_valid} ({100 - valid_proportion:.2f}%)")
        print(f"Imágenes inválidas correctamente clasificadas: {n_invalid - invalid_below_threshold}/{n_invalid} ({100 - invalid_proportion:.2f}%)")

        _plot_error_histogram(valid_errors, invalid_errors, threshold)

        print("\nEstadísticas de errores de reconstrucción:")
        _print_error_stats("Imágenes Válidas", valid_errors)
        _print_error_stats("Imágenes Inválidas", invalid_errors)

        overlap_percentage = _overlap_percentage(valid_errors, invalid_errors)
        print(f"\nSolapamiento entre distribuciones: {overlap_percentage:.2f}%")

        print("\nGenerando visualizaciones de casos mal clasificados...")
        # Plot from the same directories the errors were computed on, so the
        # indices into the error arrays line up with the dataset indices.
        plot_misclassifications(detector, valid_processed_dir, valid_errors, is_valid=True, batch_size=batch_size)
        plot_misclassifications(detector, invalid_processed_dir, invalid_errors, is_valid=False, batch_size=batch_size)
# Usage: compare the valid and invalid test sets with the saved detector.
compare_reconstruction_errors(
    'dni_anomaly_detector.pt',
    valid_dir='test/valid',
    invalid_dir='test/invalid',
    batch_size=32,
)

In [None]:
import torch
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
from torchvision import transforms

def test_autoencoder(model, image_path, save_path=None):
    """
    Reconstruct one image with the autoencoder and plot original, reconstruction and error map.

    Args:
        model: DNIAnomalyDetector instance; its ``transform``, ``device``,
            ``encoder`` and ``decoder`` are used.
        image_path: Path to the test image
        save_path: Optional path to save the visualization

    Returns:
        Mean squared error (Python float) between the transformed input and its reconstruction.
    """
    # Load the image. The context manager closes the file; convert() returns a
    # new, fully loaded image that stays valid afterwards.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert('RGB')
    image_tensor = model.transform(image).unsqueeze(0).to(model.device)

    # Reconstruct in inference mode without tracking gradients.
    model.encoder.eval()
    model.decoder.eval()
    with torch.no_grad():
        latent = model.encoder(image_tensor)
        reconstructed = model.decoder(latent)

    mse_loss = torch.nn.functional.mse_loss(reconstructed, image_tensor).item()

    # CHW tensors -> HWC arrays for plotting
    original_img = image_tensor.squeeze(0).cpu().numpy().transpose(1, 2, 0)
    reconstructed_img = reconstructed.squeeze(0).cpu().numpy().transpose(1, 2, 0)

    _, axes = plt.subplots(1, 3, figsize=(15, 5))

    axes[0].imshow(original_img)
    axes[0].set_title('Original')
    axes[0].axis('off')

    axes[1].imshow(reconstructed_img)
    axes[1].set_title('Reconstructed')
    axes[1].axis('off')

    # Per-pixel absolute error averaged over channels (the title reports the overall MSE).
    error_map = np.abs(original_img - reconstructed_img).mean(axis=2)
    im = axes[2].imshow(error_map, cmap='hot')
    axes[2].set_title(f'Error Map\nMSE: {mse_loss:.6f}')
    axes[2].axis('off')
    plt.colorbar(im, ax=axes[2])

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path)
    plt.show()

    return mse_loss

def batch_test(model, test_dir, n_samples=5):
    """
    Reconstruct a random selection of images from a directory and report their scores.

    Args:
        model: DNIAnomalyDetector instance
        test_dir: Directory searched (non-recursively) for ``*.jpg`` images
        n_samples: Number of random samples to test; clamped to the number of images found

    Returns:
        List with one dict per tested image, with keys ``'image'`` (file name),
        ``'mse'`` (value returned by ``test_autoencoder``) and ``'confidence'``
        (value returned by ``model.predict``, which is printed with ``:.6f``).
    """
    import random
    from pathlib import Path

    candidates = list(Path(test_dir).glob('*.jpg'))
    chosen = random.sample(candidates, min(n_samples, len(candidates)))

    summary = []
    for path in chosen:
        print(f"\nTesting {path.name}")
        mse = test_autoencoder(model, path)
        score = model.predict(path)
        summary.append(dict(image=path.name, mse=mse, confidence=score))
        print(f"MSE: {mse:.6f}")
        print(f"Confidence Score: {score:.6f}")

    return summary

# Load the trained detector, on the GPU when CUDA is available.
detector = DNIAnomalyDetector(
    device='cuda' if torch.cuda.is_available() else 'cpu'
)
detector.load_model('dni_anomaly_detector_old.pt')

# Single-image check: reconstruct one valid sample and save the figure to resultado.png.
test_autoencoder(
    detector,
    'test/valid/0a1e89f6-0ae2-4cab-854e-61c897cbbe13.jpg',
    save_path='resultado.png',
)

# Alternative single-image check on an invalid sample:
# test_autoencoder(detector, 'test/invalid/DniFrente 1.jpg', save_path='resultado.png')

# Multi-image check (point it at a real folder before enabling):
# results = batch_test(detector, 'carpeta/con/imagenes', n_samples=5)

In [None]:
# Reload the detector from disk and display its decision threshold
# (as the cell's last expression, Jupyter shows its value).
detector = DNIAnomalyDetector(device='cpu' if not torch.cuda.is_available() else 'cuda')
detector.load_model('dni_anomaly_detector_old.pt')

detector.threshold