# Visualización de Resultados en Test Set

Este notebook permite visualizar las predicciones del mejor modelo entrenado (checkpoint `best_model.pth`) sobre el dataset de test.

## Funcionalidades:
- Cargar el mejor modelo entrenado
- Visualizar predicciones con bounding boxes
- Comparar predicciones vs ground truth
- Analizar métricas de detección
- Visualizar casos de éxito y error

In [None]:
import sys
import json
from pathlib import Path
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import torch
from PIL import Image
import mlflow

# Put the directory two levels above the current working directory on
# sys.path so the `src` package below can be imported.
# NOTE(review): presumably that directory is the project root, which assumes
# the notebook is launched from its own folder - confirm.
sys.path.append(str(Path.cwd().parent.parent))

# Project-local modules from the `src` package.
from src.data.taco_dataloader import create_dataloader
from src.models.detector import TrashDetector
from src.models.evaluate import ObjectDetectionEvaluator, calculate_iou

print("Libraries imported successfully!")

## 1. Configuración y Carga del Modelo

In [None]:
# --- Paths, all derived from the directory two levels above the CWD ---
PROJECT_ROOT = Path.cwd().parent.parent
DATA_DIR = PROJECT_ROOT.joinpath('data', 'processed')
CHECKPOINT_PATH = PROJECT_ROOT.joinpath('models', 'checkpoints', 'best_model.pth')
MLFLOW_DIR = PROJECT_ROOT.joinpath('mlruns')

# --- Compute device: use the GPU when one is available, else the CPU ---
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# --- Model / evaluation parameters ---
NUM_CLASSES = 61  # 60 classes + background
SCORE_THRESHOLD = 0.5
IOU_THRESHOLD = 0.5

In [None]:
# Restore the best checkpoint and put the detector in inference mode.
print(f"Loading model from: {CHECKPOINT_PATH}")

# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)

model = TrashDetector(num_classes=NUM_CLASSES, backbone='resnet50')
model = model.to(device)

# Copy the saved weights into the freshly built model, then switch to eval mode.
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

print("Model loaded successfully!")
print(f"Trained for {checkpoint['epoch']} epochs")
print(f"Best validation loss: {checkpoint['val_loss']:.4f}")

## 2. Cargar Dataset de Test

In [None]:
# Build the dataloader for the test split (fixed order, batches of 4).
test_loader = create_dataloader(
    split='test',
    processed_dir=str(DATA_DIR),
    shuffle=False,
    batch_size=4,
    num_workers=2,
)

print(f"Test dataset loaded: {len(test_loader)} batches")

## 3. Cargar Métricas de Test

In [None]:
# Load previously saved test metrics, if an evaluation run has written them.
test_metrics_path = PROJECT_ROOT / 'models' / 'checkpoints' / 'test_metrics.json'

if test_metrics_path.exists():
    # Read as UTF-8 explicitly: without `encoding`, open() falls back to the
    # platform's locale encoding, which is not guaranteed to be UTF-8.
    with open(test_metrics_path, 'r', encoding='utf-8') as f:
        test_metrics = json.load(f)

    print("Test Metrics:")
    print(f"  mAP@0.5: {test_metrics['mAP']:.4f}")
    print(f"  Precision: {test_metrics['precision']:.4f}")
    print(f"  Recall: {test_metrics['recall']:.4f}")
    print(f"  F1-Score: {test_metrics['f1_score']:.4f}")
    print(f"  True Positives: {test_metrics['true_positives']}")
    print(f"  False Positives: {test_metrics['false_positives']}")
    print(f"  False Negatives: {test_metrics['false_negatives']}")
else:
    # No metrics file: leave a None sentinel so later code can check for it.
    print("Test metrics not found. Run evaluation first.")
    test_metrics = None

## 4. Función de Visualización

In [None]:
def denormalize_image(
    image_tensor,
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
):
    """
    Undo per-channel normalization and return an image ready for plotting.

    Args:
        image_tensor: Normalized image tensor laid out as [C, H, W].
        mean: Per-channel mean used for normalization. The default is the
            commonly used ImageNet mean (assumes the dataloader normalized
            with these statistics - confirm against the data pipeline).
        std: Per-channel standard deviation used for normalization
            (default: the matching ImageNet values).

    Returns:
        NumPy array laid out as [H, W, C] with values clamped to [0, 1].
    """
    # detach() is required because .numpy() raises on tensors that track
    # gradients; cpu() makes the function usable with GPU tensors.
    image = image_tensor.detach().cpu()

    # Reshape the statistics to [C, 1, 1] so they broadcast over H and W.
    mean_t = torch.tensor(mean, dtype=torch.float32).view(-1, 1, 1)
    std_t = torch.tensor(std, dtype=torch.float32).view(-1, 1, 1)

    # Invert (x - mean) / std, then clip anything outside the displayable range.
    image = torch.clamp(image * std_t + mean_t, 0, 1)

    # Channels-last layout, as expected by matplotlib's imshow.
    return image.permute(1, 2, 0).numpy()


def _draw_boxes(ax, boxes, labels, color, scores=None):
    """Draw one rectangle (plus an optional caption) per box on *ax*.

    Boxes are read as (x1, y1, x2, y2). When *labels* is None the boxes are
    drawn without captions; when *scores* is given, each caption also shows
    the score with two decimals.
    """
    if labels is None:
        labels = [None] * len(boxes)
    if scores is None:
        scores = [None] * len(boxes)

    for box, label, score in zip(boxes, labels, scores):
        # Convert to plain floats so matplotlib never receives tensor scalars.
        x1, y1, x2, y2 = (float(v) for v in box)

        ax.add_patch(
            patches.Rectangle(
                (x1, y1), x2 - x1, y2 - y1,
                linewidth=2, edgecolor=color, facecolor='none'
            )
        )

        if label is None:
            continue

        # Accept tensor / NumPy scalars as well as plain Python values.
        class_id = label.item() if hasattr(label, 'item') else label
        caption = f'Class {class_id}'
        if score is not None:
            caption += f' ({float(score):.2f})'

        # Place the caption slightly above the top-left corner of the box.
        ax.text(
            x1, y1 - 5,
            caption,
            color='white',
            fontsize=10,
            bbox=dict(facecolor=color, alpha=0.7, edgecolor='none', pad=2)
        )


def visualize_predictions(
    image,
    pred_boxes,
    pred_labels,
    pred_scores,
    gt_boxes=None,
    gt_labels=None,
    score_threshold=0.5,
    figsize=(15, 10)
):
    """
    Visualize predictions and, optionally, ground truth on an image.

    With ground-truth boxes the figure has two panels (ground truth in green
    on the left, predictions in red on the right); otherwise it has a single
    predictions panel.

    Args:
        image: Normalized image tensor [C, H, W]
        pred_boxes: Predicted boxes [N, 4] as (x1, y1, x2, y2)
        pred_labels: Predicted labels [N]
        pred_scores: Prediction scores [N]
        gt_boxes: Ground truth boxes [M, 4], or None to skip the GT panel
        gt_labels: Ground truth labels [M]; if None while gt_boxes is given,
            the ground-truth boxes are drawn without class captions
        score_threshold: Predictions scoring below this value are not drawn
        figsize: Figure size passed to matplotlib

    Returns:
        The matplotlib Figure that was created.
    """
    # Denormalize image
    img_np = denormalize_image(image)

    # Create the figure: two panels when ground truth is available, else one.
    if gt_boxes is not None:
        fig, (ax_gt, ax_pred) = plt.subplots(1, 2, figsize=figsize)
        panels = ((ax_gt, 'Ground Truth'), (ax_pred, 'Predictions'))
    else:
        fig, ax_pred = plt.subplots(1, 1, figsize=figsize)
        panels = ((ax_pred, 'Predictions'),)

    for ax, title in panels:
        ax.imshow(img_np)
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.axis('off')

    # Plot ground truth
    if gt_boxes is not None:
        _draw_boxes(ax_gt, gt_boxes, gt_labels, color='green')

    # Filter predictions by score, then plot the survivors
    keep = pred_scores >= score_threshold
    _draw_boxes(
        ax_pred,
        pred_boxes[keep],
        pred_labels[keep],
        color='red',
        scores=pred_scores[keep],
    )

    plt.tight_layout()
    return fig

## 5. Visualizar Ejemplos del Test Set

In [None]:
# Pull the first batch from the test loader and split it into its parts.
batch = next(iter(test_loader))
gt_boxes, gt_labels = batch['bboxes'], batch['labels']
images = batch['images'].to(device)

# Run the detector without tracking gradients (inference only).
with torch.no_grad():
    predictions = model(images)

print(f"Batch size: {len(images)}")
print("Predictions obtained!")

In [None]:
# Show at most four images from the batch, ground truth next to predictions.
num_images = min(len(images), 4)

for idx in range(num_images):
    image_preds = predictions[idx]
    fig = visualize_predictions(
        images[idx],
        image_preds['boxes'].cpu(),
        image_preds['labels'].cpu(),
        image_preds['scores'].cpu(),
        gt_boxes=gt_boxes[idx],
        gt_labels=gt_labels[idx],
        score_threshold=SCORE_THRESHOLD,
    )
    plt.show()

## 6. Análisis de Casos Específicos

### 6.1 Casos con Alta Confianza (> 0.8)

In [None]:
# Indices of images whose highest-scoring prediction exceeds 0.8.
# Images with no predictions at all are never selected.
high_confidence_images = [
    idx
    for idx in range(num_images)
    if len(predictions[idx]['scores']) > 0
    and predictions[idx]['scores'].max().item() > 0.8
]

print(f"Found {len(high_confidence_images)} images with high confidence predictions")

# Display only the first two such images, drawing just the boxes above 0.8.
for idx in high_confidence_images[:2]:
    image_preds = predictions[idx]
    fig = visualize_predictions(
        images[idx],
        image_preds['boxes'].cpu(),
        image_preds['labels'].cpu(),
        image_preds['scores'].cpu(),
        gt_boxes=gt_boxes[idx],
        gt_labels=gt_labels[idx],
        score_threshold=0.8,
    )
    plt.show()

### 6.2 Análisis de IoU por Imagen

In [None]:
# Calculate average IoU for each image
def calculate_image_metrics(
    pred_boxes,
    pred_labels,
    pred_scores,
    gt_boxes,
    gt_labels,
    threshold=0.5,
    iou_threshold=0.5,
):
    """
    Calculate matching metrics for a single image.

    Predictions scoring below ``threshold`` are discarded. Each remaining
    prediction, in the order given, is paired with the still-unmatched
    ground-truth box of the same label that has the highest IoU. The pair
    counts as a match when that IoU reaches ``iou_threshold``.

    Args:
        pred_boxes: Predicted boxes [N, 4]
        pred_labels: Predicted labels [N]
        pred_scores: Prediction scores [N]
        gt_boxes: Ground truth boxes [M, 4]
        gt_labels: Ground truth labels [M]
        threshold: Minimum score for a prediction to be considered
        iou_threshold: Minimum IoU for a prediction to count as matched
            (previously hard-coded to 0.5)

    Returns:
        Dict with:
            'avg_iou': mean of the best IoU over predictions whose best IoU
                is positive (0.0 when there is none)
            'num_predictions': number of predictions kept after filtering
            'num_ground_truths': number of ground-truth boxes
            'matched': number of predictions matched to a ground-truth box
    """
    # Filter predictions by score
    keep = pred_scores >= threshold
    pred_boxes = pred_boxes[keep]
    pred_labels = pred_labels[keep]

    num_predictions = len(pred_boxes)
    num_ground_truths = len(gt_boxes)

    # Nothing to compare when either side is empty.
    if num_predictions == 0 or num_ground_truths == 0:
        return {
            'avg_iou': 0.0,
            'num_predictions': num_predictions,
            'num_ground_truths': num_ground_truths,
            'matched': 0
        }

    ious = []
    matched = 0
    # Each ground-truth box may be claimed by at most one prediction.
    gt_matched = [False] * num_ground_truths

    for pred_box, pred_label in zip(pred_boxes, pred_labels):
        best_iou = 0.0
        best_idx = -1

        for gt_idx, (gt_box, gt_label) in enumerate(zip(gt_boxes, gt_labels)):
            # Only unclaimed ground truths of the same class are candidates.
            if gt_matched[gt_idx] or pred_label != gt_label:
                continue

            iou = calculate_iou(pred_box, gt_box)
            if iou > best_iou:
                best_iou = iou
                best_idx = gt_idx

        # Predictions without any overlap do not contribute to the average.
        if best_iou > 0:
            ious.append(best_iou)
            if best_iou >= iou_threshold:
                matched += 1
                gt_matched[best_idx] = True

    return {
        'avg_iou': np.mean(ious) if ious else 0.0,
        'num_predictions': num_predictions,
        'num_ground_truths': num_ground_truths,
        'matched': matched
    }

# Report matching statistics for every displayed image of the first batch.
for idx in range(num_images):
    image_preds = predictions[idx]
    metrics = calculate_image_metrics(
        image_preds['boxes'].cpu(),
        image_preds['labels'].cpu(),
        image_preds['scores'].cpu(),
        gt_boxes[idx],
        gt_labels[idx],
        threshold=SCORE_THRESHOLD,
    )

    # Images are numbered from 1 in the printed report.
    print(f"\nImage {idx + 1}:")
    print(f"  Predictions: {metrics['num_predictions']}, Ground Truths: {metrics['num_ground_truths']}")
    print(f"  Matched: {metrics['matched']}, Average IoU: {metrics['avg_iou']:.3f}")

## 7. Evaluación Completa en Todo el Test Set

In [None]:
# Run the project's evaluator over every batch of the test loader.
print("Evaluating on entire test set...")

evaluator = ObjectDetectionEvaluator(
    model=model,
    device=device,
    num_classes=NUM_CLASSES,
    score_threshold=SCORE_THRESHOLD,
    iou_threshold=IOU_THRESHOLD,
)
full_metrics = evaluator.evaluate(test_loader)

print("\nFull Test Set Metrics:")

# Ratio-style metrics are printed with four decimals...
for metric_label, metric_key in (
    ("mAP@0.5", "mAP"),
    ("Precision", "precision"),
    ("Recall", "recall"),
    ("F1-Score", "f1_score"),
):
    print(f"  {metric_label}: {full_metrics[metric_key]:.4f}")

# ...while detection counts are printed as stored.
for metric_label, metric_key in (
    ("True Positives", "true_positives"),
    ("False Positives", "false_positives"),
    ("False Negatives", "false_negatives"),
):
    print(f"  {metric_label}: {full_metrics[metric_key]}")

## 8. Visualización de Métricas por Clase (10 Mejores y 10 Peores Clases)

In [None]:
# Keep only the per-class AP entries of the evaluation results.
per_class_ap = {}
for metric_name, metric_value in full_metrics.items():
    if metric_name.startswith('AP_class_'):
        per_class_ap[metric_name] = metric_value

# Rank classes from highest to lowest AP.
sorted_ap = sorted(per_class_ap.items(), key=lambda entry: entry[1], reverse=True)

top_10 = sorted_ap[:10]
bottom_10 = sorted_ap[-10:]

# One horizontal bar chart per ranking slice: best classes first, then worst.
for ranked, bar_color, chart_title in (
    (top_10, 'skyblue', 'Top 10 Classes by Average Precision'),
    (bottom_10, 'coral', 'Bottom 10 Classes by Average Precision'),
):
    class_names = [name.replace('AP_class_', 'Class ') for name, _ in ranked]
    ap_values = [ap for _, ap in ranked]

    plt.figure(figsize=(12, 6))
    plt.barh(class_names, ap_values, color=bar_color)
    plt.xlabel('Average Precision (AP)', fontsize=12)
    plt.title(chart_title, fontsize=14, fontweight='bold')
    plt.xlim(0, 1)  # AP axis fixed to [0, 1] so both charts are comparable
    plt.grid(axis='x', alpha=0.3)
    plt.tight_layout()
    plt.show()

## 9. Visualizar Más Ejemplos del Test Set

In [None]:
# Get another batch: skip the first batch and take the second one.
# next(..., None) avoids an unhandled StopIteration when the test loader
# yields fewer than two batches.
batch_iter = iter(test_loader)
next(batch_iter, None)  # Skip first batch
batch2 = next(batch_iter, None)

if batch2 is None:
    print("Test loader has fewer than two batches; no additional examples to show.")
else:
    images2 = batch2['images'].to(device)
    gt_boxes2 = batch2['bboxes']
    gt_labels2 = batch2['labels']

    # Get predictions (inference only, no gradient tracking)
    with torch.no_grad():
        predictions2 = model(images2)

    # Visualize at most four images from this batch
    num_images2 = min(4, len(images2))
    for idx in range(num_images2):
        fig = visualize_predictions(
            image=images2[idx],
            pred_boxes=predictions2[idx]['boxes'].cpu(),
            pred_labels=predictions2[idx]['labels'].cpu(),
            pred_scores=predictions2[idx]['scores'].cpu(),
            gt_boxes=gt_boxes2[idx],
            gt_labels=gt_labels2[idx],
            score_threshold=SCORE_THRESHOLD
        )
        plt.show()

## 10. Resumen y Conclusiones

In [None]:
# Final report: checkpoint info, aggregate test metrics, best/worst class AP.
separator = "=" * 60

print(separator)
print("RESUMEN DE EVALUACIÓN")
print(separator)

print("\nModelo: Faster R-CNN con ResNet-50")
print(f"Épocas entrenadas: {checkpoint['epoch']}")
print(f"Mejor pérdida de validación: {checkpoint['val_loss']:.4f}")

print("\nMétricas en Test Set:")
for metric_label, metric_key in (
    ("mAP@0.5", "mAP"),
    ("Precision", "precision"),
    ("Recall", "recall"),
    ("F1-Score", "f1_score"),
):
    print(f"  • {metric_label}: {full_metrics[metric_key]:.4f}")

print("\nDetecciones:")
for metric_label, metric_key in (
    ("True Positives", "true_positives"),
    ("False Positives", "false_positives"),
    ("False Negatives", "false_negatives"),
):
    print(f"  • {metric_label}: {full_metrics[metric_key]}")

# sorted_ap is ordered from highest to lowest AP, so the ends hold best/worst.
best_key, best_ap = sorted_ap[0]
print(f"\nClase con mejor AP: {best_key.replace('AP_class_', 'Class ')} ({best_ap:.4f})")
worst_key, worst_ap = sorted_ap[-1]
print(f"Clase con peor AP: {worst_key.replace('AP_class_', 'Class ')} ({worst_ap:.4f})")

print(separator)