In [2]:
import os
import numpy as np
from PIL import Image
from typing import List, Tuple, Dict, Any
import pytorch_lightning as pl

# --- PyTorch Imports ---
# Import Dataset to inherit from it
from torch.utils.data import Dataset, DataLoader
# Import for the demonstration code
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, models
# --- End PyTorch Imports ---

class ImageDatasetWrapper(Dataset):
    """
    Un Dataset compatible con PyTorch que escanea subdirectorios de clases.
    Hereda de torch.utils.data.Dataset.
    Devuelve etiquetas como vectores one-hot (np.ndarray).
    
    ¬°NUEVO! Tambi√©n crea una lista 'self.targets' con etiquetas enteras
    (ej. 0, 1, 2) para ser usada por 'sklearn.model_selection.train_test_split'.
    """

    def __init__(self, root_dir: str, transform: Any = None):
        """
        Inicializa el dataset, escanea el directorio y crea el mapa de √≠ndices.
        """
        self.root_dir = root_dir
        self.transform = transform
        
        # data_index almacenar√° (filepath, one_hot_label)
        self.data_index: List[Tuple[str, np.ndarray]] = []
        
        # --- ¬°CORRECCI√ìN A√ëADIDA AQU√ç! ---
        # self.targets almacenar√° el √≠ndice entero (0, 1, 2...) para la estratificaci√≥n
        self.targets: List[int] = []
        # --- FIN DE LA CORRECCI√ìN ---
        
        self.class_names: List[str] = []
        self.class_to_label: Dict[str, np.ndarray] = {}
        self._build_index()

    def _build_index(self):
        """
        Escanea el directorio ra√≠z en busca de carpetas de clases y rellena 
        data_index (para los datos) y targets (para la divisi√≥n).
        """
        print(f"Escaneando directorio: {self.root_dir}")

        # 1. Descubrir nombres de clases (subdirectorios)
        subdirs = [d for d in os.listdir(self.root_dir)
                   if os.path.isdir(os.path.join(self.root_dir, d))]
        self.class_names = sorted(subdirs)
        num_classes = len(self.class_names)

        if num_classes == 0:
            raise ValueError(f"No se encontraron subdirectorios de clases en {self.root_dir}")

        # 2. Crear mapeo class_to_label (para arrays one-hot)
        for i, class_name in enumerate(self.class_names):
            one_hot = np.zeros(num_classes, dtype=np.float32)
            one_hot[i] = 1.0
            self.class_to_label[class_name] = one_hot

        print(f"Se encontraron {num_classes} clases: {self.class_names}")

        # 3. Rellenar la lista de √≠ndices maestros
        image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff')
        for class_index, class_name in enumerate(self.class_names):
            class_path = os.path.join(self.root_dir, class_name)
            one_hot_label = self.class_to_label[class_name]

            # Listar archivos en el directorio de la clase
            for filename in os.listdir(class_path):
                if filename.lower().endswith(image_extensions):
                    filepath = os.path.join(class_path, filename)
                    # Almacenar (filepath, one_hot_label)
                    self.data_index.append((filepath, one_hot_label))
                    
                    # --- ¬°CORRECCI√ìN A√ëADIDA AQU√ç! ---
                    # Almacenar el √≠ndice entero (0, 1, 2...)
                    self.targets.append(class_index)
                    # --- FIN DE LA CORRECCI√ìN ---

        print(f"Total de im√°genes indexadas: {len(self.data_index)}")

    def __len__(self) -> int:
        """Devuelve el n√∫mero total de items (im√°genes) en el dataset."""
        return len(self.data_index)

    def __getitem__(self, idx: int) -> Tuple[Any, np.ndarray]:
        """
        Recupera la imagen y su etiqueta one-hot correspondiente.
        Aplica transformaciones si se proporcionan.
        """
        if idx >= len(self.data_index) or idx < 0:
            raise IndexError("√çndice fuera de rango")

        filepath, label_vector = self.data_index[idx]

        # 1. Cargar la imagen con PIL
        try:
            image = Image.open(filepath).convert('RGB')
        except Exception as e:
            print(f"Error al cargar la imagen {filepath}: {e}")
            raise

        # 2. Aplicar transformaciones (ej. ToTensor, Normalize)
        if self.transform:
            image = self.transform(image)

        # Devuelve la imagen transformada y el vector one-hot
        return image, label_vector

In [3]:
from sklearn.model_selection import train_test_split

# ---------------------------------------------------------------
# 2. Un NUEVO Dataset Wrapper (m√°s simple)
# ---------------------------------------------------------------
class PreSplitDataset(Dataset):
    """
    Un Dataset que acepta una lista de datos (filepath, label) 
    pre-dividida en su constructor.
    """
    def __init__(self, data_list: List[Tuple[str, np.ndarray]], transform: Any = None):
        self.data_list = data_list
        self.transform = transform

    def __len__(self) -> int:
        return len(self.data_list)

    def __getitem__(self, idx: int) -> Tuple[Any, np.ndarray]:
        from PIL import Image
        
        # Obtener el filepath y la etiqueta de la lista
        filepath, label_vector = self.data_list[idx]

        # Cargar la imagen
        try:
            image = Image.open(filepath).convert('RGB')
        except Exception as e:
            print(f"Error loading image {filepath}: {e}")
            raise
            
        # Aplicar transformaciones
        if self.transform:
            image = self.transform(image)
            
        return image, label_vector

# ---------------------------------------------------------------
# 3. Configuraci√≥n y Proceso de Divisi√≥n
# ---------------------------------------------------------------
# --- Configuraci√≥n ---
dataset_root = "/lustre/proyectos/p032/datasets/images/3kvasir"
BATCH_SIZE = 64
SEED = 42

# Definir los ratios
TRAIN_RATIO = 0.70
VAL_RATIO = 0.15
TEST_RATIO = 0.15 # (debe sumar 1.0)

# Transformaciones
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])
])

# --- 1. Cargar el dataset COMPLETO ---
print("Cargando el dataset completo para indexar...")
# (Necesitamos la clase 'ImageDatasetWrapper' original para esto)
# (He a√±adido .targets a la clase para que esto funcione)
full_dataset = ImageDatasetWrapper(root_dir=dataset_root)

# Extraer los datos y las etiquetas para sklearn
# data_index es List[Tuple[str, np.ndarray]]
# targets es List[int] (ej. 0, 1, 2, 0, 1...)
all_data = full_dataset.data_index 
all_targets = full_dataset.targets 

if len(all_data) == 0:
    raise RuntimeError("Error: No se encontraron datos en el dataset.")

print(f"Total de im√°genes encontradas: {len(all_data)}")

# --- 2. Primera Divisi√≥n (Train+Val vs Test) ---
# Dividimos el 85% para (train+val) y el 15% para test
print("Realizando primera divisi√≥n (estratificada)...")
train_val_data, test_data, train_val_targets, test_targets = train_test_split(
    all_data,
    all_targets,
    test_size=TEST_RATIO,
    stratify=all_targets, # ¬°La clave es esta!
    random_state=SEED
)

# --- 3. Segunda Divisi√≥n (Train vs Val) ---
# Dividimos (train+val) en train y val
# El ratio debe recalcularse: VAL_RATIO / (TRAIN_RATIO + VAL_RATIO)
val_split_ratio = VAL_RATIO / (TRAIN_RATIO + VAL_RATIO)

print("Realizando segunda divisi√≥n (estratificada)...")
train_data, val_data, train_targets, val_targets = train_test_split(
    train_val_data,
    train_val_targets,
    test_size=val_split_ratio,
    stratify=train_val_targets, # Estratificar de nuevo
    random_state=SEED
)

print("\n--- ¬°Divisi√≥n completada! ---")
print(f"Total:      {len(all_data)}")
print(f"Set Train:  {len(train_data)}")
print(f"Set Val:    {len(val_data)}")
print(f"Set Test:   {len(test_data)}")

# --- 4. Crear los Datasets y DataLoaders ---

# Aplicar la transformaci√≥n a cada set
train_dataset = PreSplitDataset(train_data, transform=transform)
val_dataset = PreSplitDataset(val_data, transform=transform)
test_dataset = PreSplitDataset(test_data, transform=transform)

# Crear los DataLoaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

print("\nDataLoaders estratificados (train, val, test) creados.")

# --- 5. (Opcional) Verificar la distribuci√≥n de clases ---
print("\nVerificando distribuci√≥n (ejemplo):")

def get_class_counts(targets_list):
    counts = np.bincount(targets_list)
    return [f"{count/len(targets_list)*100:.2f}%" for count in counts]
    
print(f"  Train: {get_class_counts(train_targets)}")
print(f"  Val:   {get_class_counts(val_targets)}")
print(f"  Test:  {get_class_counts(test_targets)}")


Cargando el dataset completo para indexar...
Escaneando directorio: /lustre/proyectos/p032/datasets/images/3kvasir
Se encontraron 3 clases: ['normal-cecum', 'normal-pylorus', 'normal-z-line']
Total de im√°genes indexadas: 1500
Total de im√°genes encontradas: 1500
Realizando primera divisi√≥n (estratificada)...
Realizando segunda divisi√≥n (estratificada)...

--- ¬°Divisi√≥n completada! ---
Total:      1500
Set Train:  1049
Set Val:    226
Set Test:   225

DataLoaders estratificados (train, val, test) creados.

Verificando distribuci√≥n (ejemplo):
  Train: ['33.37%', '33.37%', '33.27%']
  Val:   ['33.19%', '33.19%', '33.63%']
  Test:  ['33.33%', '33.33%', '33.33%']




In [4]:
# --- 1. CONFIGURACI√ìN INICIAL ---
# ==========================================================
# PATH_MODELO_SSL = "/lustre/proyectos/p032/models/multi_pretext_model2.ckpt" # No se usa
# MODEL_PATH = "/lustre/home/opacheco/MEDA_Challenge/models/221025MG_backbone.ssl.pth" # No se usa

# ¬øCu√°ntas clases tiene tu dataset de PRUEBA?
NUM_CLASES = 3 # Esto sigue siendo correcto para tu 3kvasir

# Par√°metros (¬°Importante usar los mismos!)
BATCH_SIZE = 64
EPOCHS_DE_PRUEBA = 10
LEARNING_RATE = 0.001 # Este LR se us√≥ para el Linear Probe, ¬°mantenerlo!
# JIGSAW_N = 2 # No aplica aqu√≠
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Usando dispositivo: {DEVICE}")
print(f"N√∫mero de clases: {NUM_CLASES}")

Usando dispositivo: cuda
N√∫mero de clases: 3


In [5]:
# --- Cargar el Backbone Baseline (ResNet-18 ImageNet) ---

print("Cargando ResNet-18 pre-entrenado en ImageNet...")
try:
    # Volvemos a resnet18
    resnet_imagenet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
except AttributeError:
    print("...usando fallback 'pretrained=True' por versi√≥n de torchvision.")
    resnet_imagenet = models.resnet18(pretrained=True)

baseline_backbone = nn.Sequential(*list(resnet_imagenet.children())[:-1])
print("¬°Backbone ResNet-18 (ImageNet) cargado!")

# Mover a GPU
baseline_backbone = baseline_backbone.to(DEVICE)

Cargando ResNet-18 pre-entrenado en ImageNet...
¬°Backbone ResNet-18 (ImageNet) cargado!


In [8]:
# --- Crear el Modelo para Linear Probing (ResNet-18 Baseline) ---

# Congelar todo el backbone baseline (el ResNet-18 de ImageNet)
for param in baseline_backbone.parameters():
    param.requires_grad = False

# --- ¬°CAMBIO IMPORTANTE! ---
# La salida de ResNet-18 es 512
in_features = 512
# Crear la cabeza lineal CORRECTA
linear_head = nn.Linear(in_features, NUM_CLASES) # <-- Debe ser 512

# Clase para el modelo combinado (Backbone + Cabeza)
class LinearProbingModel(nn.Module):
    def __init__(self, backbone, linear_head):
        super().__init__()
        self.backbone = backbone
        self.linear_head = linear_head # <-- Ahora s√≠ recibe la cabeza correcta (512 -> 3)

    def forward(self, x):
        # Asegurar que el backbone est√© en modo eval
        self.backbone.eval()
        with torch.no_grad(): # No calcular gradientes para el backbone
            feats = self.backbone(x)          # [B, 512, 1, 1]

        feats = feats.view(feats.size(0), -1)  # Flatten -> [B, 512]
        out = self.linear_head(feats)        # [B, NUM_CLASES] - ¬°Ahora s√≠ funciona!
        return out

# Crear la instancia del modelo final
# (baseline_backbone debe ser tu ResNet-18 cargado en la celda anterior)
model = LinearProbingModel(baseline_backbone, linear_head).to(DEVICE)

# Configurar Loss y Optimizador
# Para 3kvasir (NUM_CLASES=3), CrossEntropyLoss es correcto
criterion = nn.CrossEntropyLoss()
print("Usando CrossEntropyLoss para multi-clase.")

# Optimizador SOLO para la cabeza lineal, usando el LEARNING_RATE definido (0.001)
optimizer = optim.Adam(model.linear_head.parameters(), lr=LEARNING_RATE)

print("Modelo Linear Probing (ResNet-18 Baseline) creado CORRECTAMENTE.")

Usando CrossEntropyLoss para multi-clase.
Modelo Linear Probing (ResNet-18 Baseline) creado CORRECTAMENTE.


In [11]:
# --- ENTRENAR LA CABEZA LINEAL CON VALIDACI√ìN ---

print("Iniciando entrenamiento de la cabeza lineal (Linear Probing)...")

for epoch in range(EPOCHS_DE_PRUEBA):
    model.train() 
    running_loss = 0.0
    
    # ---- Train ----
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item() * inputs.size(0)

    epoch_loss = running_loss / len(train_loader.dataset)

    # ---- Validaci√≥n ----
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels_one_hot in val_loader:
            inputs = inputs.to(DEVICE)
            labels_one_hot = labels_one_hot.to(DEVICE)

            labels_indices = torch.argmax(labels_one_hot, dim=1)

            outputs = model(inputs)
            loss = criterion(outputs, labels_indices)
            val_loss += loss.item() * inputs.size(0)

            _, predicted = torch.max(outputs, 1)
            total += labels_indices.size(0)
            correct += (predicted == labels_indices).sum().item()

    val_loss /= len(val_loader.dataset)
    val_acc = 100 * correct / total

    print(f"Epoch {epoch+1}/{EPOCHS_DE_PRUEBA} - "
          f"Train Loss: {epoch_loss:.4f} - "
          f"Val Loss: {val_loss:.4f} - "
          f"Val Acc: {val_acc:.2f}%")

print("Entrenamiento de la cabeza finalizado.")


Iniciando entrenamiento de la cabeza lineal (Linear Probing)...
Epoch 1/10 - Train Loss: 0.0742 - Val Loss: 0.0716 - Val Acc: 99.12%
Epoch 2/10 - Train Loss: 0.0715 - Val Loss: 0.0721 - Val Acc: 98.23%
Epoch 3/10 - Train Loss: 0.0680 - Val Loss: 0.0700 - Val Acc: 98.67%
Epoch 4/10 - Train Loss: 0.0649 - Val Loss: 0.0671 - Val Acc: 99.12%
Epoch 5/10 - Train Loss: 0.0620 - Val Loss: 0.0651 - Val Acc: 98.67%
Epoch 6/10 - Train Loss: 0.0609 - Val Loss: 0.0637 - Val Acc: 98.67%
Epoch 7/10 - Train Loss: 0.0582 - Val Loss: 0.0644 - Val Acc: 97.79%
Epoch 8/10 - Train Loss: 0.0570 - Val Loss: 0.0641 - Val Acc: 97.79%
Epoch 9/10 - Train Loss: 0.0544 - Val Loss: 0.0613 - Val Acc: 98.67%
Epoch 10/10 - Train Loss: 0.0530 - Val Loss: 0.0600 - Val Acc: 98.67%
Entrenamiento de la cabeza finalizado.


In [None]:
from sklearn.metrics import f1_score

# --- 6. EVALUAR EL RENDIMIENTO CON F1 SCORE ---

print("Evaluando en el set de validaci√≥n...")

# Lista para almacenar todas las etiquetas verdaderas y predichas
all_labels = []
all_predicted = []

model.eval()
correct = 0
total = 0

with torch.no_grad():
    for inputs, labels_one_hot in test_loader:
        # Mover datos al dispositivo (CPU/GPU)
        inputs = inputs.to(DEVICE)
        labels_one_hot = labels_one_hot.to(DEVICE)

        labels_indices = torch.argmax(labels_one_hot, dim=1)
        
        # 1. Pase adelante (Forward Pass)
        outputs = model(inputs)
        
        # 2. Obtener la predicci√≥n de clase
        _, predicted_indices = torch.max(outputs.data, 1)
        
        # 3. Almacenar para el c√°lculo de F1 Score
        # Mover a CPU para Scikit-learn y convertir a numpy
        all_labels.extend(labels_indices.cpu().numpy())
        all_predicted.extend(predicted_indices.cpu().numpy())
        
        # 4. Actualizar contadores de Accuracy
        total += labels_one_hot.size(0)
        correct += (predicted_indices == labels_indices).sum().item()


# --- C√ÅLCULO DE M√âTRICAS ---

# 1. Calcular Accuracy
accuracy = 100 * correct / total

# 2. Calcular F1 Score
# 'average="macro"' se usa com√∫nmente en problemas multi-clase para dar
# igual peso a cada clase, independientemente del desequilibrio.
# Cambiar a 'average="weighted"' si se necesita considerar el desequilibrio de clases.
f1 = f1_score(all_labels, all_predicted, average='macro') 
f1_percentage = f1 * 100

# --- RESULTADO FINAL ---
print("\\n==========================================================") 
print(f"üéâ ¬°Prueba de Evaluaci√≥n Lineal (Linear Probing - ResNet50 Baseline) completa! üéâ")
print(f"   Accuracy en el set de test: {accuracy:.2f} %") 
print(f"   F1 Score (Macro) en el set de test: {f1_percentage:.2f} %") 
print("==========================================================")


Evaluando en el set de validaci√≥n...
üéâ ¬°Prueba de Evaluaci√≥n Lineal (Linear Probing - ResNet50 Baseline) completa! üéâ
   Accuracy en el set de test: 97.78 %
   F1 Score (Macro) en el set de test: 97.78 %
