In [1]:
"""
======================================================================
SCRIPT DE ENTRENAMIENTO: CLASIFICADOR DE DEFECTOS (PyTorch VGG16)
======================================================================
Versión Modular.
Estructura:
1. Configuración y Utils
2. Dataset y Modelo
3. Preparación de Datos
4. Bucle de Entrenamiento
5. Guardado
"""

import os
import cv2
import numpy as np
import json
import traceback
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import models, transforms
from sklearn.preprocessing import LabelEncoder
from PIL import Image
import matplotlib.pyplot as plt
import time
import sys

In [2]:
# --- BLOCK 1: CONFIGURATION ---

# Project-local path configuration. The notebook cannot run without it, so a
# failed import is reported and then re-raised instead of being swallowed.
try:
    from dataset_paths import (
        DATASET_BASE_PATH,
        DatasetPaths,
        AVAILABLE_CATEGORIES,
        discover_categories,
        CLASSIFIER_MODEL_PATH
    )
except ImportError:
    print("Error: No se pudo importar 'dataset_paths.py'.")
    raise

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Spatial size (in pixels) every image is resized to before entering the network.
IMG_WIDTH = 256
IMG_HEIGHT = 256
# Number of samples per batch used by the DataLoaders.
BATCH_SIZE = 32
# Default number of training epochs.
EPOCHS = 100

def get_next_model_path_torch(base_dir, prefix, category):
    """Return the next free, extension-less model path inside ``base_dir``.

    Existing files named ``{prefix}_{category}_{N}.pth`` are scanned; the
    result uses the highest positive ``N`` found plus one, zero-padded to three
    digits. Files whose suffix is not an integer are ignored. ``base_dir`` is
    created when it does not exist yet.

    Args:
        base_dir: Directory that holds the saved models.
        prefix: Leading part of the file name (e.g. ``"classifier"``).
        category: Dataset category appended to the prefix.

    Returns:
        ``base_dir`` joined with ``{prefix}_{category}_{NNN}`` (no extension).
    """
    os.makedirs(base_dir, exist_ok=True)

    stem = f"{prefix}_{category}"
    head = stem + "_"
    tail = ".pth"

    def _sequence_number(filename):
        # Only "<stem>_<something>.pth" names are candidates.
        if not (filename.startswith(head) and filename.endswith(tail)):
            return None
        try:
            return int(filename[len(head):-len(tail)])
        except ValueError:
            return None

    found = [_sequence_number(entry) for entry in os.listdir(base_dir)]
    # The leading 0 is both the "no model yet" default and a floor that
    # discards non-positive sequence numbers.
    highest = max([0] + [number for number in found if number is not None])

    return os.path.join(base_dir, f"{stem}_{highest + 1:03d}")

def load_and_mask_image(img_path, mask_path, target_size=(IMG_HEIGHT, IMG_WIDTH)):
    """Load an image, resize it and black out every pixel outside its mask.

    Args:
        img_path: Path of the image, read with ``cv2.imread``.
        mask_path: Path of the grayscale mask that belongs to the image.
        target_size: Output size as ``(height, width)``.

    Returns:
        A ``PIL.Image`` in RGB channel order where pixels outside the binarised
        mask are zero, or ``None`` when the image cannot be decoded or the mask
        file does not exist.
    """
    img = cv2.imread(img_path)
    if img is None:
        return None
    # Without a mask file the sample is unusable; bail out before resizing.
    if not os.path.exists(mask_path):
        return None

    height, width = target_size
    # cv2.resize expects dsize as (width, height) while NumPy shapes are
    # (rows, cols). Passing target_size unchanged to both only works for square
    # sizes, so the two orderings are spelled out explicitly.
    dsize = (width, height)
    img = cv2.resize(img, dsize, interpolation=cv2.INTER_AREA)

    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        # The file exists but cannot be decoded: an all-zero mask is used, which
        # blanks the whole image (kept from the original behaviour).
        mask = np.zeros((height, width), dtype=np.uint8)
    else:
        # Nearest-neighbour keeps the mask binary-like; the threshold then
        # forces it to exactly 0 / 255.
        mask = cv2.resize(mask, dsize, interpolation=cv2.INTER_NEAREST)
        _, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)

    img_masked = cv2.bitwise_and(img, img, mask=mask)
    # OpenCV decodes to BGR; PIL expects RGB.
    img_rgb = cv2.cvtColor(img_masked, cv2.COLOR_BGR2RGB)
    return Image.fromarray(img_rgb)

--- Configuracion de Rutas Cargada ---
Categoria: capsule
Ruta de Entrenamiento: datasets\capsule\train\good
Ruta de Prueba: datasets\capsule\test
Ruta de Mascaras: datasets\capsule\ground_truth
-------------------------------------


In [3]:
# --- BLOQUE 2: DATASET Y MODELO ---

class MaskedDefectDataset(Dataset):
    """Dataset of defect images masked by their ground-truth masks.

    Every image found under ``category_paths.test_path/<defect_type>`` that has
    an existing mask file becomes one sample; the defect folder name is the
    class label, integer-encoded with a ``LabelEncoder``.

    Attributes:
        samples: List of ``(image_path, mask_path)`` tuples.
        encoded_labels: Integer label per sample, aligned with ``samples``.
        classes: Class names indexed by encoded label.
        label_encoder: The ``LabelEncoder`` used for the encoding.
        transform: Optional callable applied to the masked PIL image.
    """

    def __init__(self, category_paths, defect_folders, transform=None):
        """Scan the defect folders and collect the usable samples.

        Args:
            category_paths: Object exposing ``test_path`` and
                ``get_ground_truth_mask_path(defect_type, filename)``.
            defect_folders: Iterable of defect folder names to include.
            transform: Optional callable applied to each image in
                ``__getitem__``.
        """
        self.samples = []
        self.transform = transform
        self.label_encoder = LabelEncoder()
        # Defined up front so an empty dataset still exposes both attributes
        # instead of raising AttributeError on access.
        self.encoded_labels = np.empty(0, dtype=np.int64)
        self.classes = np.empty(0, dtype=object)
        all_labels = []

        for defect_type in defect_folders:
            folder_path = os.path.join(category_paths.test_path, defect_type)
            if not os.path.isdir(folder_path):
                continue

            # os.listdir returns entries in arbitrary order; sorting keeps the
            # sample indices stable between runs.
            for f in sorted(os.listdir(folder_path)):
                if not f.lower().endswith(('.png', '.jpg', '.jpeg')):
                    continue
                img_path = os.path.join(folder_path, f)
                mask_path = category_paths.get_ground_truth_mask_path(defect_type, f)

                # Images without a mask file are skipped.
                if os.path.exists(mask_path):
                    self.samples.append((img_path, mask_path))
                    all_labels.append(defect_type)

        if self.samples:
            self.encoded_labels = self.label_encoder.fit_transform(all_labels)
            self.classes = self.label_encoder.classes_

    def __len__(self):
        """Return the number of collected samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, encoded_label)`` for the sample at ``idx``."""
        img_path, mask_path = self.samples[idx]
        label_idx = self.encoded_labels[idx]
        image = load_and_mask_image(img_path, mask_path)
        # If the image could not be loaded, a blank RGB image of the configured
        # size is substituted so that batching does not fail.
        if image is None:
            image = Image.new('RGB', (IMG_WIDTH, IMG_HEIGHT))
        if self.transform:
            image = self.transform(image)
        return image, label_idx

class DefectClassifierVGG16(nn.Module):
    """Defect classifier built on the ImageNet-pretrained VGG16 feature extractor.

    The convolutional ``features`` of VGG16 are reused, with the parameters of
    its first 20 modules frozen, followed by a 7x7 adaptive average pool and a
    three-layer fully connected head that outputs ``num_classes`` logits.
    """

    def __init__(self, num_classes):
        """Build the network.

        Args:
            num_classes: Number of output logits of the final linear layer.
        """
        super().__init__()

        backbone = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
        self.features = backbone.features
        # Freeze the parameters of the first 20 feature modules; the remaining
        # ones stay trainable.
        for frozen_param in self.features[:20].parameters():
            frozen_param.requires_grad_(False)

        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))

        head_layers = [
            nn.Flatten(),
            nn.Linear(512 * 7 * 7, 512),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.Linear(512, 128),
            nn.ReLU(True),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Run features -> average pool -> classifier head and return the logits."""
        pooled = self.avgpool(self.features(x))
        return self.classifier(pooled)

# Training pipeline: resize to the configured size, then apply light random
# augmentation (horizontal flip, rotation of up to 15 degrees, and a 0.1 jitter
# on the first two ColorJitter arguments, brightness and contrast) before
# converting to a tensor.
# NOTE(review): neither pipeline applies ImageNet mean/std normalisation even
# though the backbone is ImageNet-pretrained. Confirm this matches the
# preprocessing used at inference time before changing it.
train_transforms = transforms.Compose([
    transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(0.1, 0.1),
    transforms.ToTensor(),
])

# Deterministic pipeline without augmentation: resize and convert to a tensor.
val_transforms = transforms.Compose([
    transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
    transforms.ToTensor(),
])

In [4]:
# --- BLOQUE 3: PREPARACIÓN DE DATOS ---

def prepare_classifier_data(category, seed=None):
    """Build the training and validation DataLoaders for one category.

    The masked defect images of the category are split 80/20. The training
    subset is augmented with ``train_transforms``; the validation subset reads
    the same samples through ``val_transforms`` so that validation accuracy is
    measured on non-augmented images.

    Args:
        category: Name of the dataset category to load.
        seed: Optional integer seed for the train/validation split. ``None``
            (the default) uses PyTorch's global random state.

    Returns:
        ``(train_loader, val_loader, classes)``, or ``(None, None, None)`` when
        the category has fewer than two defect types or no usable sample.
    """
    import copy

    print(f"--- Preparando datos Clasificador: {category} ---")
    paths = DatasetPaths(DATASET_BASE_PATH, category)
    defect_folders = paths.get_test_defect_folders()

    # A classifier needs at least two classes to be meaningful.
    if len(defect_folders) < 2:
        print(f"  Info: Menos de 2 tipos de defectos. No se puede entrenar clasificador.")
        return None, None, None

    full_dataset = MaskedDefectDataset(paths, defect_folders, transform=train_transforms)
    if len(full_dataset) == 0:
        print("  Error: No se encontraron datos enmascarados válidos.")
        return None, None, None

    # 80/20 split.
    train_size = int(0.8 * len(full_dataset))
    val_size = len(full_dataset) - train_size
    split_kwargs = {}
    if seed is not None:
        split_kwargs["generator"] = torch.Generator().manual_seed(seed)
    train_ds, val_split = random_split(full_dataset, [train_size, val_size], **split_kwargs)

    # random_split hands out views over one dataset, so both subsets would share
    # the augmenting training transform. A shallow copy shares the sample list
    # and labels (same indexing) but carries its own transform.
    eval_dataset = copy.copy(full_dataset)
    eval_dataset.transform = val_transforms
    val_ds = torch.utils.data.Subset(eval_dataset, val_split.indices)

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)

    print(f"  Clases: {full_dataset.classes}")
    print(f"  Train Samples: {len(train_ds)} | Val Samples: {len(val_ds)}")

    return train_loader, val_loader, full_dataset.classes

In [5]:
# --- BLOQUE 4: ENTRENAMIENTO ---

def run_classifier_training(model, train_loader, val_loader, epochs=EPOCHS):
    """Train ``model`` and return the weights of its best validation epoch.

    Uses cross-entropy loss and Adam (lr=1e-4) over the parameters that still
    require gradients. After every epoch the model is evaluated on
    ``val_loader``; whenever the validation accuracy matches or beats the best
    seen so far, a snapshot of the weights is kept.

    Args:
        model: The ``nn.Module`` to train; it is moved to ``DEVICE``.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        val_loader: Iterable of ``(inputs, labels)`` validation batches.
        epochs: Number of passes over ``train_loader``.

    Returns:
        ``(state_dict, best_val_accuracy)``. The state dict is the snapshot of
        the best epoch, or the model's current weights if no epoch was run.
    """
    import copy

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)
    model.to(DEVICE)

    print(f"--- Iniciando Entrenamiento ({epochs} épocas) ---")
    best_acc = 0.0
    best_wts = None

    total_batches = len(train_loader)

    for epoch in range(epochs):
        # Train
        model.train()
        correct, total = 0, 0

        for i, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            _, pred = torch.max(outputs, 1)
            correct += (pred == labels).sum().item()
            total += labels.size(0)

            # Progress bar, redrawn in place with a carriage return.
            percent = (i + 1) / total_batches
            bar = '=' * int(20 * percent) + '-' * (20 - int(20 * percent))
            sys.stdout.write(f"\r    Epoch {epoch+1:03d}/{epochs} [{bar}]")
            sys.stdout.flush()

        # Guarded like val_acc so an empty loader does not divide by zero.
        train_acc = correct / total if total > 0 else 0.0

        # Validate
        model.eval()
        val_correct, val_total = 0, 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
                outputs = model(inputs)
                _, pred = torch.max(outputs, 1)
                val_correct += (pred == labels).sum().item()
                val_total += labels.size(0)

        val_acc = val_correct / val_total if val_total > 0 else 0

        sys.stdout.write(f"\r    Epoch {epoch+1:03d}/{epochs} [{'='*20}] Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}\n")

        if val_acc >= best_acc:
            best_acc = val_acc
            # state_dict() returns references to the live parameter tensors, so
            # without a deep copy later epochs would overwrite this snapshot.
            best_wts = copy.deepcopy(model.state_dict())

    print(f"  Entrenamiento completado. Mejor Val Acc: {best_acc:.4f}")
    if best_wts is None:
        # No epoch ran (epochs == 0): fall back to the current weights.
        best_wts = model.state_dict()
    return best_wts, best_acc

In [6]:
# --- BLOQUE 5: GUARDADO ---

def save_classifier_model(model_weights, category, classes, accuracy):
    """Persist the trained weights together with a JSON label map.

    Two files are written next to each other under ``CLASSIFIER_MODEL_PATH``:
    ``<stem>.pth`` with the weights and ``<stem>_labels.json`` with the
    category, the index-to-class-name mapping and the accuracy.

    Args:
        model_weights: State dict (or any object) handed to ``torch.save``.
        category: Dataset category the model was trained on.
        classes: Iterable of class names, ordered by encoded label.
        accuracy: Accuracy stored in the metadata, converted to ``float``.
    """
    stem = get_next_model_path_torch(CLASSIFIER_MODEL_PATH, "classifier", category)
    weights_file = stem + ".pth"
    labels_file = stem + "_labels.json"

    # Weights first, then the metadata that describes them.
    torch.save(model_weights, weights_file)

    metadata = {
        "category": category,
        "labels": {index: str(name) for index, name in enumerate(classes)},
        "accuracy": float(accuracy),
    }
    with open(labels_file, "w") as handle:
        json.dump(metadata, handle, indent=4)

    print(f"  Modelo guardado en: {stem}.pth")


In [7]:
# --- BLOQUE 6: EJECUCIÓN PRINCIPAL ---

def main():
    """Train and save one defect classifier per discovered dataset category.

    Categories for which no training loader can be prepared are skipped. Any
    exception raised while processing a category is printed with its traceback
    and the loop moves on to the next category.
    """
    print(f"Dispositivo: {DEVICE}")
    all_categories = discover_categories(DATASET_BASE_PATH)

    for category in all_categories:
        try:
            # 1. Data
            loaders_and_classes = prepare_classifier_data(category)
            train_dl, val_dl, classes = loaders_and_classes
            if not train_dl:
                continue

            # 2. Model -> 3. Train
            classifier = DefectClassifierVGG16(len(classes))
            best_weights, best_acc = run_classifier_training(classifier, train_dl, val_dl)

            # 4. Save
            save_classifier_model(best_weights, category, classes, best_acc)

        except Exception as error:
            print(f"Error crítico en {category}: {error}")
            traceback.print_exc()


if __name__ == "__main__":
    main()

Dispositivo: cuda
--- Preparando datos Clasificador: capsule ---
--- Configuracion de Rutas Cargada ---
Categoria: capsule
Ruta de Entrenamiento: datasets\capsule\train\good
Ruta de Prueba: datasets\capsule\test
Ruta de Mascaras: datasets\capsule\ground_truth
-------------------------------------
  Clases: ['crack' 'faulty_imprint' 'poke' 'scratch' 'squeeze']
  Train Samples: 87 | Val Samples: 22
--- Iniciando Entrenamiento (100 épocas) ---
  Entrenamiento completado. Mejor Val Acc: 0.7273
  Modelo guardado en: out-models\m-models\classifier_capsule_002.pth
--- Preparando datos Clasificador: metal_nut ---
--- Configuracion de Rutas Cargada ---
Categoria: metal_nut
Ruta de Entrenamiento: datasets\metal_nut\train\good
Ruta de Prueba: datasets\metal_nut\test
Ruta de Mascaras: datasets\metal_nut\ground_truth
-------------------------------------
  Clases: ['bent' 'color' 'flip' 'scratch']
  Train Samples: 74 | Val Samples: 19
--- Iniciando Entrenamiento (100 épocas) ---
  Entrenamiento com