# Prueba del modelo preentrenado con SSL

In [35]:
import torch
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import DataLoader
import torch.optim as optim

# Importa las librerías de MedMNIST
import medmnist
from medmnist import INFO  
import torchvision.transforms as transforms 

# --- 1. Configuración ---
SSL_WEIGHTS_PATH = "models/221025MG_backbone.ssl.pth"
DATASET_NAME = 'breastmnist'
BATCH_SIZE = 128                          
NUM_EPOCHS = 5                            
LEARNING_RATE = 1e-4                      # LR bajo es bueno para finetuning
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Usando dispositivo: {DEVICE}")

# --- 2. Definir el Modelo de Clasificación ---
# Este modelo combina el backbone + una cabeza para clasificar

class SslClassifier(nn.Module):
    def __init__(self, backbone, num_classes):
        super().__init__()
        self.backbone = backbone
        # La salida de ResNet18 (antes de la última capa) es 512
        self.fc = nn.Linear(512, num_classes) # Esta es la nueva cabeza

    def forward(self, x):
        # Pasa por el backbone (que ya sabe de SSL)
        features = self.backbone(x).flatten(start_dim=1)
        # Pasa por la cabeza de clasificación
        output = self.fc(features)
        return output

# --- 3. Cargar Datos (BreastMNIST) ---
info = INFO[DATASET_NAME]
num_classes = len(info['label'])
DataClass = getattr(medmnist, info['python_class'])


mean = [0.5, 0.5, 0.5]
std = [0.5, 0.5, 0.5]

data_transform = transforms.Compose([
    # Convierte la imagen (de 1 canal) a 3 canales duplicando el canal.
    transforms.Grayscale(num_output_channels=3),
    
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std) 
])

# Usamos el mismo transform para todos
train_transform = data_transform
val_transform = data_transform
test_transform = data_transform

# Crear los datasets
train_dataset = DataClass(split='train', transform=train_transform, download=True)
val_dataset = DataClass(split='val', transform=val_transform, download=True)
test_dataset = DataClass(split='test', transform=test_transform, download=True)

# Crear los DataLoaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

print(f"Datos de {DATASET_NAME} cargados. Clases: {num_classes}")

# --- 4. Cargar Modelo y Pesos SSL ---

# 1. Crea la arquitectura del backbone (la misma que en SSL)
resnet = models.resnet18()
backbone = nn.Sequential(*list(resnet.children())[:-1])

# 2. Carga los pesos que entrenaste con SSL
# Esto carga los pesos en el objeto 'backbone'
backbone.load_state_dict(torch.load(SSL_WEIGHTS_PATH, map_location=DEVICE))
print("¡Pesos del backbone SSL cargados exitosamente!")

# 3. Crea el modelo de clasificación completo
model = SslClassifier(backbone, num_classes).to(DEVICE)

# --- 5. Configurar Entrenamiento (Finetuning) ---
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

print(f"--- Iniciando Finetuning (5 épocas) en {DATASET_NAME} ---")

for epoch in range(NUM_EPOCHS):
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        labels = labels.squeeze().long()
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    
    print(f"Epoch [{epoch+1}/{NUM_EPOCHS}], Train Loss: {running_loss / len(train_loader):.4f}")

    # Validación después de cada época
    model.eval()
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            labels = labels.squeeze().long()
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            val_total += labels.size(0)
            val_correct += (predicted == labels).sum().item()
    
    val_acc = 100 * val_correct / val_total
    print(f"Epoch [{epoch+1}/{NUM_EPOCHS}], Val Accuracy: {val_acc:.2f}%")

print("--- Finetuning completado ---")

# --- 6. Prueba Final (Clasificación) ---
model.eval()
test_correct = 0
test_total = 0
with torch.no_grad():
    for images, labels in test_loader:
        labels = labels.squeeze().long()
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        test_total += labels.size(0)
        test_correct += (predicted == labels).sum().item()

test_acc = 100 * test_correct / test_total
print(f"==================================================")
print(f"✅ Resultado Final - Test Accuracy: {test_acc:.2f}%")
print(f"==================================================")

Usando dispositivo: cuda
Datos de breastmnist cargados. Clases: 2
¡Pesos del backbone SSL cargados exitosamente!
--- Iniciando Finetuning (5 épocas) en breastmnist ---
Epoch [1/5], Train Loss: 0.6521
Epoch [1/5], Val Accuracy: 26.92%
Epoch [2/5], Train Loss: 0.6246
Epoch [2/5], Val Accuracy: 73.08%
Epoch [3/5], Train Loss: 0.5765
Epoch [3/5], Val Accuracy: 70.51%
Epoch [4/5], Train Loss: 0.5633
Epoch [4/5], Val Accuracy: 67.95%
Epoch [5/5], Train Loss: 0.5379
Epoch [5/5], Val Accuracy: 73.08%
--- Finetuning completado ---
✅ Resultado Final - Test Accuracy: 71.15%
