In [9]:
"""
======================================================================
SCRIPT DE ENTRENAMIENTO: DETECTOR DE ANOMALÍAS (PyTorch VGG16)
======================================================================
Versión Modular para Notebooks.
Estructura:
1. Configuración
2. Dataset y Dataloaders
3. Modelo VGG16 Autoencoder
4. Funciones de Entrenamiento
5. Funciones de Evaluación (Cálculo de Umbral con Filtro de Outliers)
6. Orquestador Principal
"""

import os
import cv2
import numpy as np
import json
import traceback
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from PIL import Image
import matplotlib.pyplot as plt
import time
import sys

# Configurar Matplotlib para evitar errores de GUI en servidores
plt.switch_backend('Agg')

In [10]:
# --- BLOQUE 1: CONFIGURACIÓN ---

try:
    from dataset_paths import (
        DATASET_BASE_PATH,
        DatasetPaths,
        AVAILABLE_CATEGORIES,
        discover_categories,
        DETECTOR_MODEL_PATH
    )
except ImportError:
    print("Error: No se pudo importar 'dataset_paths.py'.")
    raise

# Compute device: CUDA when available, otherwise CPU. Models and batches are moved here.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Target size passed to transforms.Resize((IMG_HEIGHT, IMG_WIDTH)).
# NOTE(review): the decoder upsamples x2 four times (x16 total), so presumably
# sizes should stay multiples of 16 for the reconstruction to match the input
# shape — confirm before changing.
IMG_WIDTH = 256
IMG_HEIGHT = 256
# Default batch size for the training, anomaly and evaluation DataLoaders.
BATCH_SIZE = 32
# Default number of training epochs per category.
EPOCHS = 50
# Threshold sensitivity: multiplier of the std in threshold = mean + SENSE * std.
SENSE = 1.0 

def get_device():
    """Return the torch device chosen in the configuration cell (module-level DEVICE)."""
    selected_device = DEVICE
    return selected_device

def get_next_model_version_path_torch(base_dir, prefix, category):
    """Return the next free versioned base path (without extension) in ``base_dir``.

    Existing files named ``{prefix}_{category}_<int>.pth`` are scanned. The result
    is ``{base_dir}/{prefix}_{category}_NNN``, where NNN is the highest version found
    plus one, zero-padded to at least three digits. Names whose version part is not
    an integer are ignored. ``base_dir`` is created if it does not exist.
    """
    os.makedirs(base_dir, exist_ok=True)
    stem = f"{prefix}_{category}"
    head = stem + "_"
    seen_versions = [0]  # baseline so an empty directory yields version 1
    for entry in os.listdir(base_dir):
        if not (entry.startswith(head) and entry.endswith(".pth")):
            continue
        try:
            seen_versions.append(int(entry[len(head):-4]))
        except ValueError:
            pass
    next_version = max(seen_versions) + 1
    return os.path.join(base_dir, f"{stem}_{next_version:03d}")

In [11]:
# --- BLOQUE 2: DATASET Y DATALOADERS ---

class DefectDataset(Dataset):
    """Flat folder of images served as autoencoder ``(input, target)`` pairs.

    Only files whose lowercase name ends in .png/.jpg/.jpeg/.bmp are used. File
    names are sorted, so the sample order is reproducible across runs and
    platforms. This includes the order of per-image errors computed with
    ``shuffle=False``. ``os.listdir`` itself guarantees no order.

    Args:
        folder_path: Directory to scan. A missing path, or a path that is not a
            directory, yields an empty dataset instead of raising.
        transform: Optional callable applied to the RGB ``PIL.Image``.
    """

    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp')

    def __init__(self, folder_path, transform=None):
        self.folder_path = folder_path
        self.transform = transform
        self.image_files = []
        if os.path.isdir(folder_path):
            self.image_files = sorted(
                f for f in os.listdir(folder_path)
                if f.lower().endswith(self._IMAGE_EXTENSIONS)
            )

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB, apply ``transform`` and return it twice."""
        img_path = os.path.join(self.folder_path, self.image_files[idx])
        # The context manager closes the file handle immediately; convert()
        # returns an independent, fully loaded image.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform:
            image = self.transform(image)
        # Autoencoder: the target is the input itself.
        return image, image

# Preprocessing used by prepare_dataloaders for both the training and anomaly
# datasets: resize to (IMG_HEIGHT, IMG_WIDTH), then convert to a float tensor in [0, 1].
# NOTE(review): no ImageNet mean/std normalization is applied, even though the
# encoder is ImageNet-pretrained; the reconstruction target is this same tensor.
data_transforms = transforms.Compose(
    [transforms.Resize((IMG_HEIGHT, IMG_WIDTH)), transforms.ToTensor()]
)

def prepare_dataloaders(category, batch_size=BATCH_SIZE):
    """Build the loaders for one category and return ``(train_loader, anomaly_loader)``.

    ``train_loader`` shuffles the good training images. ``anomaly_loader`` walks
    every non-empty defect folder under the test path in order, without shuffling.
    It is ``None`` when there are no defect images. Returns ``(None, None)`` when
    the training folder has no images.
    """
    print(f"--- Preparando datos para: {category} ---")
    paths = DatasetPaths(DATASET_BASE_PATH, category)

    # Good images used for training.
    good_ds = DefectDataset(paths.train_path, transform=data_transforms)
    if not len(good_ds):
        print(f"Error: No hay imágenes en {paths.train_path}")
        return None, None
    train_loader = DataLoader(good_ds, batch_size=batch_size, shuffle=True)

    # Defect images (test split), skipping empty folders.
    candidate_sets = (
        DefectDataset(os.path.join(paths.test_path, folder), transform=data_transforms)
        for folder in paths.get_test_defect_folders()
    )
    defect_sets = [ds for ds in candidate_sets if len(ds) > 0]

    anomaly_loader = None
    anomaly_count = 0
    if defect_sets:
        anomaly_ds = torch.utils.data.ConcatDataset(defect_sets)
        anomaly_count = len(anomaly_ds)
        anomaly_loader = DataLoader(anomaly_ds, batch_size=batch_size, shuffle=False)

    print(f"  Train Images: {len(good_ds)}")
    print(f"  Anomaly Images: {anomaly_count}")
    return train_loader, anomaly_loader

In [12]:
# --- BLOQUE 3: MODELO ---

class VGG16Autoencoder(nn.Module):
    """Convolutional autoencoder: frozen VGG16 feature encoder plus a trainable decoder.

    The encoder is the first 24 modules of torchvision's ImageNet-pretrained
    VGG16 ``features`` stack, with gradients disabled. The decoder has four
    stages of (bilinear x2 upsample -> 3x3 conv -> BatchNorm -> ReLU), going
    512 -> 512 -> 256 -> 128 -> 64 channels. A final 3x3 conv to 3 channels and
    a Sigmoid follow, so reconstructions lie in [0, 1].

    NOTE(review): the 24-module slice presumably ends right after VGG16's fourth
    max-pool (512 channels at 1/16 resolution), which the four x2 upsamples undo.
    If so, input height and width should be multiples of 16 — confirm.
    NOTE(review): ``data_transforms`` feeds [0, 1] tensors without the ImageNet
    mean/std normalization the pretrained encoder was trained with. Changing
    that would change what saved checkpoints expect, so it is only flagged here.
    """
    def __init__(self):
        super(VGG16Autoencoder, self).__init__()
        # Encoder (frozen): pretrained VGG16 feature layers; requires_grad=False
        # keeps them out of the optimizer, which filters on requires_grad.
        vgg = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
        self.encoder = nn.Sequential(*list(vgg.features.children())[:24])
        for param in self.encoder.parameters():
            param.requires_grad = False
            
        # Decoder (trainable). The layer order determines the state_dict keys
        # ("decoder.<index>.*") stored in saved .pth files. Reordering or
        # regrouping these layers breaks loading of existing checkpoints.
        self.decoder = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1), nn.BatchNorm2d(512), nn.ReLU(True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(512, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256), nn.ReLU(True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(256, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.ReLU(True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(128, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(True),
            nn.Conv2d(64, 3, kernel_size=3, padding=1),
            # Squash the output to [0, 1] to match the ToTensor-scaled input target.
            nn.Sigmoid()
        )

    def forward(self, x):
        """Encode ``x`` with the frozen VGG features and return the decoded reconstruction."""
        features = self.encoder(x)
        return self.decoder(features)

def get_model(weights_path=None):
    """Build a VGG16Autoencoder on DEVICE and optionally load saved weights.

    Args:
        weights_path: Path to a state_dict saved with
            ``torch.save(model.state_dict(), ...)``. If falsy, the freshly built
            model is returned. If the path does not exist, a warning is printed
            and the model is returned without the saved weights (best-effort,
            as before).

    Returns:
        The model, on DEVICE.
    """
    import inspect

    model = VGG16Autoencoder().to(DEVICE)
    if not weights_path:
        return model
    if not os.path.exists(weights_path):
        print(f"Advertencia: No se encontró {weights_path}")
        return model

    print(f"Cargando pesos desde: {weights_path}")
    load_kwargs = {"map_location": DEVICE}
    # torch.load is pickle-based. weights_only=True restricts loading to tensors
    # and plain containers, which is all a state_dict needs, so a tampered .pth
    # cannot execute code. Only pass it where this torch version supports it.
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = True
    model.load_state_dict(torch.load(weights_path, **load_kwargs))
    return model

In [13]:
# --- BLOQUE 4: ENTRENAMIENTO ---

def train_model(model, train_loader, epochs=EPOCHS, lr=1e-3):
    """Train the trainable parameters of ``model`` to reconstruct their own inputs.

    Only parameters with ``requires_grad=True`` are optimized; for
    VGG16Autoencoder that is the decoder, because the encoder is frozen. The loss
    is MSE between the reconstruction and the input batch. A one-line progress
    bar is rewritten per batch, followed by a summary line per epoch.

    Args:
        model: Autoencoder already placed on DEVICE.
        train_loader: DataLoader yielding ``(inputs, _)`` batches. The second
            element is ignored.
        epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate. The default 1e-3 is the previously hard-coded value.

    Returns:
        The same ``model`` instance, trained in place and left in train mode.

    Raises:
        ValueError: If ``train_loader`` contains no samples and ``epochs > 0``.
            Otherwise the per-epoch average would divide by zero.
    """
    criterion = nn.MSELoss()
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(trainable_params, lr=lr)
    model.train()

    total_batches = len(train_loader)
    n_samples = len(train_loader.dataset)
    if epochs > 0 and n_samples == 0:
        raise ValueError("train_loader no contiene muestras; no se puede entrenar.")

    print(f"--- Iniciando Entrenamiento ({epochs} épocas) ---")
    start_time = time.time()

    for epoch in range(epochs):
        running_loss = 0.0
        epoch_start = time.time()

        for i, (inputs, _) in enumerate(train_loader):
            inputs = inputs.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, inputs)
            loss.backward()
            optimizer.step()
            # The loss is a per-element mean; weight it by the batch size so the
            # epoch average is per image even when the last batch is smaller.
            running_loss += loss.item() * inputs.size(0)

            # Progress bar (overwritten in place with '\r').
            percent = (i + 1) / total_batches
            bar = '=' * int(20 * percent) + '-' * (20 - int(20 * percent))
            sys.stdout.write(f"\r    Epoch {epoch+1:03d}/{epochs} [{bar}] Loss: {loss.item():.6f}")
            sys.stdout.flush()

        epoch_loss = running_loss / n_samples
        duration = time.time() - epoch_start
        sys.stdout.write(f"\r    Epoch {epoch+1:03d}/{epochs} [{'='*20}] Avg Loss: {epoch_loss:.6f} | Time: {duration:.1f}s   \n")

    print(f"  Entrenamiento completado en {(time.time() - start_time)/60:.1f} min.")
    return model

In [7]:
# --- BLOQUE 5: EVALUACIÓN Y UMBRAL (Con Filtro de Outliers) ---

def calculate_errors(model, dataloader):
    """Return one reconstruction MSE per image, in ``dataloader`` order.

    Puts ``model`` in eval mode (and leaves it there) and runs without gradient
    tracking. Each batch is moved to DEVICE; the loader's second element is
    ignored. The squared error is averaged over dims 1-3 of every sample.

    Returns:
        1-D ``np.ndarray`` with one error per image.
    """
    per_element_mse = nn.MSELoss(reduction='none')
    model.eval()
    collected = []
    print("  Calculando errores...", end="", flush=True)
    with torch.no_grad():
        for batch, _unused in dataloader:
            batch = batch.to(DEVICE)
            reconstruction = model(batch)
            image_err = per_element_mse(batch, reconstruction).mean(dim=[1, 2, 3])
            collected.extend(image_err.cpu().numpy())
    print(" Listo.")
    return np.array(collected)

def remove_outliers(data, k=1.5):
    """Drop upper outliers using the interquartile-range (Tukey fence) rule.

    Values above ``Q3 + k * (Q3 - Q1)`` are removed; low values are never removed.
    A message is printed when anything is filtered.

    Args:
        data: Array-like of numbers. Lists are accepted; they are converted with
            ``np.asarray`` so boolean masking works.
        k: Fence multiplier. The default 1.5 is the previously hard-coded value.

    Returns:
        ``np.ndarray`` with the kept values, in their original order. Empty input
        is returned as an empty array.
    """
    data = np.asarray(data)
    if data.size == 0:
        return data

    q1, q3 = np.percentile(data, [25, 75])
    # Upper limit for an error to still count as "normal".
    upper_bound = q3 + k * (q3 - q1)

    clean_data = data[data <= upper_bound]

    removed = data.size - clean_data.size
    if removed > 0:
        print(f"  [FILTRO] Se eliminaron {removed} outliers (> {upper_bound:.6f})")

    return clean_data

def compute_threshold(good_errors, sensitivity=SENSE, margin=1.05):
    """Derive an anomaly threshold from reconstruction errors of normal images.

    Steps:
      1. Drop upper IQR outliers with ``remove_outliers``. If that leaves nothing,
         the original errors are used.
      2. Compute the mean and std of the remaining errors.
      3. Set ``threshold = mean + sensitivity * std``.
      4. If that is below the largest clean error, raise it to
         ``max_clean * margin``, so no clean normal sample exceeds the threshold.

    NOTE(review): for small ``sensitivity`` values, step 4 probably decides the
    threshold in most runs, which makes ``sensitivity`` ineffective. Check the
    "[AJUSTE]" log lines.

    Args:
        good_errors: 1-D array-like of per-image errors for normal samples.
        sensitivity: Multiplier of the std. Defaults to the SENSE constant.
        margin: Multiplicative headroom over the max clean error used in step 4.
            The default 1.05 is the previously hard-coded value.

    Returns:
        ``(threshold, mean, std)``; mean and std come from the filtered errors.

    Raises:
        ValueError: If ``good_errors`` is empty.
    """
    good_errors = np.asarray(good_errors)
    if good_errors.size == 0:
        # Previously this surfaced as NaN warnings followed by a cryptic
        # ValueError from np.max on an empty array.
        raise ValueError("compute_threshold: 'good_errors' está vacío; no se puede calcular un umbral.")

    # 1. Clean
    clean_errors = remove_outliers(good_errors)
    if len(clean_errors) == 0:
        print("  [ADVERTENCIA] Se filtraron todos los datos. Usando originales.")
        clean_errors = good_errors

    # 2. Statistics
    mean = np.mean(clean_errors)
    std = np.std(clean_errors)

    # 3. Threshold
    threshold = mean + sensitivity * std

    # 4. Safety check: never leave a clean normal sample above the threshold.
    max_clean = np.max(clean_errors)
    if threshold < max_clean:
        print(f"  [AJUSTE] Umbral ({threshold:.6f}) < Max Limpio ({max_clean:.6f}). Ajustando.")
        threshold = max_clean * margin

    return threshold, mean, std

def save_results(model, category, threshold, mean, std):
    """Write the model weights and threshold metadata under a new versioned name.

    Both files share a base path from get_next_model_version_path_torch in
    DETECTOR_MODEL_PATH:
      * ``<base>.pth`` holds ``model.state_dict()``;
      * ``<base>_threshold.json`` holds the category, the threshold and the
        mean/std statistics, cast to Python floats for JSON.
    """
    version_base = get_next_model_version_path_torch(DETECTOR_MODEL_PATH, "detector", category)
    weights_file = f"{version_base}.pth"
    torch.save(model.state_dict(), weights_file)

    metadata = {
        "category": category,
        "threshold": float(threshold),
        "stats": {"mean": float(mean), "std": float(std)},
    }
    with open(f"{version_base}_threshold.json", "w") as handle:
        json.dump(metadata, handle, indent=4)
    print(f"  Modelo guardado en: {weights_file}")

def plot_results(good_errors, anomaly_errors, threshold, category):
    """Save a density histogram of reconstruction errors for one category.

    Draws three histograms: all good errors (gray), good errors after
    ``remove_outliers`` (blue) and anomaly errors (red, only if any). A dashed
    vertical line marks the threshold. The figure is written to
    ``DETECTOR_MODEL_PATH/distribucion_errores_<category>.png``. This is
    best-effort: any error is printed, not raised. The figure is always closed,
    so repeated calls in the category loop do not accumulate open figures.
    """
    fig = None
    try:
        fig, ax = plt.subplots(figsize=(10, 6))

        # Histogram of the original data (gray)
        ax.hist(good_errors, bins=50, alpha=0.3, label='Buenas (Original)', color='gray', density=True)

        # Histogram of the cleaned data (blue)
        clean_errors = remove_outliers(good_errors)
        ax.hist(clean_errors, bins=50, alpha=0.7, label='Buenas (Limpias)', color='blue', density=True)

        if len(anomaly_errors) > 0:
            ax.hist(anomaly_errors, bins=50, alpha=0.7, label='Anomalias', color='red', density=True)

        ax.axvline(threshold, color='black', linestyle='dashed', label=f'Umbral: {threshold:.5f}')
        ax.set_title(f"Distribución de Errores - {category}")
        ax.set_xlabel("Error de reconstrucción (MSE)")
        ax.set_ylabel("Densidad")
        ax.legend()

        plot_path = os.path.join(DETECTOR_MODEL_PATH, f"distribucion_errores_{category}.png")
        # Make sure the output directory exists. os.makedirs("") would raise,
        # so skip it when the path has no directory part.
        plot_dir = os.path.dirname(plot_path)
        if plot_dir:
            os.makedirs(plot_dir, exist_ok=True)
        fig.savefig(plot_path)
        print(f"  Gráfico guardado en: {plot_path}")
    except Exception as e:
        print(f"  Error al graficar: {e}")
    finally:
        # Close even on failure; otherwise figures leak across categories.
        if fig is not None:
            plt.close(fig)

In [None]:
# --- BLOQUE 6: ORQUESTADOR ---

def run_full_pipeline():
    """Train, calibrate and save one anomaly detector per discovered category.

    For each category returned by ``discover_categories(DATASET_BASE_PATH)``:
      1. build the loaders;
      2. train a fresh model;
      3. compute per-image errors on the training images (not shuffled) and, if
         any exist, on the defect images;
      4. derive the threshold;
      5. save the histogram plot and the versioned weights/metadata.
    A failure in one category is printed with its traceback and the loop
    continues. The per-category model is released afterwards.

    NOTE(review): the threshold is calibrated on the same images the model was
    trained on, which likely underestimates errors on unseen normal images.
    Consider held-out good images if the dataset provides them.
    """
    print(f"Dispositivo: {DEVICE}")
    categories = discover_categories(DATASET_BASE_PATH)

    for category in categories:
        model = None
        try:
            # 1. Prepare
            train_dl, anomaly_dl = prepare_dataloaders(category)
            if not train_dl:
                continue

            # 2. Train
            model = get_model()
            model = train_model(model, train_dl, epochs=EPOCHS)

            # 3. Evaluate (training images without shuffle, for a stable order)
            eval_dl = DataLoader(train_dl.dataset, batch_size=BATCH_SIZE, shuffle=False)
            good_errors = calculate_errors(model, eval_dl)

            anomaly_errors = np.array([])
            if anomaly_dl:
                anomaly_errors = calculate_errors(model, anomaly_dl)

            # 4. Threshold computed on outlier-filtered errors
            threshold, mean, std = compute_threshold(good_errors, SENSE)
            print(f"  RESULTADO: Umbral={threshold:.6f} (SENSE={SENSE})")

            # 5. Save
            plot_results(good_errors, anomaly_errors, threshold, category)
            save_results(model, category, threshold, mean, std)

        except Exception as e:
            print(f"Error en categoría {category}: {e}")
            traceback.print_exc()
        finally:
            # Drop this category's model before the next one is built, so two
            # models never coexist on the device, and return cached CUDA blocks.
            del model
            if DEVICE.type == "cuda":
                torch.cuda.empty_cache()

# Inside a notebook kernel __name__ is '__main__', so executing this cell starts
# the full training run for every category.
if __name__ == '__main__':
    run_full_pipeline()