In [1]:
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
from tqdm import tqdm

import medmnist
from medmnist import INFO, Evaluator

In [2]:
# Sanity check: is a CUDA device visible to this kernel? (displayed as the cell output)
cuda_available = torch.cuda.is_available()
cuda_available

True

In [3]:
# Report the installed MedMNIST version and its project homepage.
medmnist_banner = f"MedMNIST v{medmnist.__version__} @ {medmnist.HOMEPAGE}"
print(medmnist_banner)

MedMNIST v3.0.2 @ https://github.com/MedMNIST/MedMNIST/


In [4]:
# --- Dataset selection and hyperparameters ---
# Alternative subset used previously: 'breastmnist'.
data_flag = 'pathmnist'
download = True

# NOTE(review): these three are not used by the SSL pretraining cell in this
# notebook, which sets its own epochs / batch size / learning rate — confirm
# which cells consume them.
NUM_EPOCHS = 3
BATCH_SIZE = 128
lr = 0.001

# Metadata for the chosen subset, read from MedMNIST's INFO registry.
info = INFO[data_flag]
task, n_channels = info['task'], info['n_channels']
n_classes = len(info['label'])  # number of entries in info['label']

# Dataset class looked up on the medmnist module by the name stored in the registry.
DataClass = getattr(medmnist, info['python_class'])

In [5]:
import torch
from torch.utils.data import DataLoader
from lightly.data import LightlyDataset
from lightly.transforms import SimCLRTransform
from lightly.models.modules import SimCLRProjectionHead
from lightly.loss import NTXentLoss
import torchvision.models as models

# --- 1. Dataset and DataLoader ---
INPUT_DIR = "/lustre/proyectos/p032/datasets/all_medmnist_images"
SSL_SEED = 42
SSL_BATCH_SIZE = 256

# Image decoding and the two random views per sample are produced on the CPU; with a
# single worker the loader is a likely bottleneck (~2.9 it/s was observed). Use the
# CPUs this process may actually run on (on Linux this respects the CPU affinity a
# cluster scheduler assigns), capped at 8.
if hasattr(os, "sched_getaffinity"):
    _available_cpus = len(os.sched_getaffinity(0))
else:
    _available_cpus = os.cpu_count() or 1
SSL_NUM_WORKERS = min(8, _available_cpus)

# Seed torch before any weights are initialised or batches are shuffled, so runs
# are repeatable (up to cuDNN nondeterminism).
torch.manual_seed(SSL_SEED)

# IMPORTANT NOTE: MedMNIST images are 28x28.
# Lightly's default augmentations can be very aggressive at this size.
# Make sure input_size=28!
transform = SimCLRTransform(input_size=28)

dataset = LightlyDataset(input_dir=INPUT_DIR, transform=transform)
dataloader = DataLoader(
    dataset,
    batch_size=SSL_BATCH_SIZE,
    shuffle=True,
    drop_last=True,  # every batch has exactly SSL_BATCH_SIZE samples
    num_workers=SSL_NUM_WORKERS,
    pin_memory=torch.cuda.is_available(),  # faster host->GPU copies
    persistent_workers=SSL_NUM_WORKERS > 0,  # don't respawn workers every epoch
)

# --- 2. Backbone ---
# Stock torchvision ResNet-18 (lightweight). No layer is modified here; per the
# original note, SimCLRTransform outputs 3-channel images by default, matching conv1.
# NOTE(review): the stock stem (stride-2 conv1 + maxpool) shrinks 28x28 inputs to 1x1
# by layer4; a CIFAR-style stem (3x3 stride-1 conv1, no maxpool) may work better, but
# would change the saved state_dict shapes — confirm downstream loaders first.
resnet = models.resnet18()
backbone = torch.nn.Sequential(*list(resnet.children())[:-1])  # drop the final fc layer

# SimCLR model (backbone + projection head)
class SimCLRModel(torch.nn.Module):
    """SimCLR network: a feature backbone followed by a projection head.

    Args:
        backbone: module whose output, flattened from dim 1, has ``num_ftrs``
            features per sample (512 for the headless ResNet-18 used here).
        num_ftrs: input width of the projection head. Change it when using
            another backbone.
        out_dim: size of the projected embedding fed to the contrastive loss.
    """

    def __init__(self, backbone, num_ftrs=512, out_dim=128):
        super().__init__()
        self.backbone = backbone
        self.projection_head = SimCLRProjectionHead(input_dim=num_ftrs, output_dim=out_dim)

    def forward(self, x):
        """Return projected embeddings of shape ``(batch, out_dim)`` for images ``x``."""
        features = self.backbone(x).flatten(start_dim=1)
        return self.projection_head(features)

# --- 3. Model, loss and optimizer ---
SSL_EPOCHS = 10
SSL_LR = 0.01
SSL_BACKBONE_PATH = "mi_backbone_ssl_xd.pth"

model = SimCLRModel(backbone)
criterion = NTXentLoss(temperature=0.1)
optimizer = torch.optim.SGD(model.parameters(), lr=SSL_LR, momentum=0.9, weight_decay=5e-4)

# Fall back to CPU instead of crashing when no GPU is visible.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Fail early with a clear message instead of a ZeroDivisionError after the epoch.
if len(dataloader) == 0:
    raise ValueError(
        f"DataLoader yields no batches (dataset size {len(dataset)}, "
        f"batch size {dataloader.batch_size})."
    )

# --- 4. Training loop ---
print("Iniciando entrenamiento SSL con pipeline de Python...")
model.train()
for epoch in range(SSL_EPOCHS):
    total_loss = 0.0
    for (x0, x1), _, _ in tqdm(dataloader, desc=f"Epoch {epoch+1}"):
        x0 = x0.to(device, non_blocking=True)
        x1 = x1.to(device, non_blocking=True)

        # Embed both augmented views of the batch and contrast them with NT-Xent.
        z0 = model(x0)
        z1 = model(x1)
        loss = criterion(z0, z1)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    avg_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch+1}/{SSL_EPOCHS} - Loss: {avg_loss:.4f}")

    # Checkpoint the backbone after every completed epoch, so an interrupted run
    # (e.g. KeyboardInterrupt) still leaves the latest finished weights on disk.
    torch.save(model.backbone.state_dict(), SSL_BACKBONE_PATH)

Iniciando entrenamiento SSL con pipeline de Python...


Epoch 1:  71%|█████████████████████████████████████▍               | 1655/2345 [09:31<03:58,  2.89it/s]


KeyboardInterrupt: 