In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torchvision.models as models
from PIL import Image
import os
import random
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.callbacks import RichProgressBar

In [2]:

# --- 0. Configuraci√≥n ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DATA_PATH = "/lustre/home/atorres/compartido/datasets/all_medmnist_images" # <- Tu carpeta con todas las imgs
MODEL_PATH = "/lustre/home/atorres/MEDA_Challenge/models/221025MG_backbone.ssl.pth" # <- Tu modelo
IMG_SIZE = 28
BATCH_SIZE = 256
N_EPOCHS = 10
LR = 1e-4 # Learning rate para el fine-tuning
JIGSAW_N = 4 # Rejilla de 4x4 para Jigsaw

print(f"Usando dispositivo: {DEVICE}")


Usando dispositivo: cuda


In [3]:

# --- 1. Cargar el Backbone SSL Pre-entrenado ---

# Definimos una clase 'dummy' solo para cargar la estructura que guardaste
# (Asumiendo que guardaste solo el state_dict de 'encoder_q[0]')
class MoCoLightning(nn.Module):
    def __init__(self, backbone):
        super().__init__()
        self.encoder_q = nn.Sequential(backbone)

print("Cargando backbone SSL...")
# Cargar ResNet18 sin pesos (solo la arquitectura)
resnet = models.resnet18(weights=None) 
# Tu backbone (quitando la capa FC final)
backbone_structure = nn.Sequential(*list(resnet.children())[:-1])

# Cargar el estado
encoder_wrapper = MoCoLightning(backbone=backbone_structure)
state_dict = torch.load(MODEL_PATH, map_location=DEVICE)

# Cargar los pesos en la estructura
encoder_wrapper.encoder_q[0].load_state_dict(state_dict)

# --- Este es tu backbone listo para usar ---
ssl_backbone = encoder_wrapper.encoder_q[0].to(DEVICE)
print("Backbone cargado y movido a GPU.")



Cargando backbone SSL...
Backbone cargado y movido a GPU.


In [4]:

# --- 2. Dataset (Versi√≥n Carpeta Unificada) ---

class MedMNISTUnifiedFolder(Dataset):
    """
    Dataset que lee todas las im√°genes de una sola carpeta ra√≠z.
    """
    def __init__(self, root, transform=None):
        self.root = root
        self.files = [os.path.join(root, f) for f in os.listdir(root)
                      if f.endswith(('.png', '.jpg', '.jpeg'))]
        self.transform = transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        img_path = self.files[idx]
        img = Image.open(img_path).convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img


In [5]:

# Transformaci√≥n base (para el dataloader)
transform = T.Compose([
    T.Resize((IMG_SIZE, IMG_SIZE)),
    T.ToTensor(),
    # Nota: No normalizamos aqu√≠ porque las tareas (color, patch)
    # esperan la imagen en rango [0, 1]
])

dataset = MedMNISTUnifiedFolder(DATA_PATH, transform)
loader = DataLoader(dataset, 
                    batch_size=BATCH_SIZE, 
                    shuffle=True, 
                    num_workers=2, 
                    pin_memory=True,
                    drop_last=True) # drop_last=True es importante para Jigsaw si el batch no es divisible
print(f"Dataset cargado con {len(dataset)} im√°genes.")


Dataset cargado con 600338 im√°genes.




In [6]:


# --- 3. Funciones de Pretext-Task (VERSI√ìN TENSOR) ---
# Estas funciones operan sobre BATCHES en la GPU

def colorization_pair_tensor(imgs):
    """
    Input: Batch de imgs RGB [B, 3, H, W]
    Output: (Input para modelo [B, 3, H, W] (gris), Target [B, 3, H, W] (color))
    """
    # 1. Convertir a escala de grises (usando la transformaci√≥n de torchvision)
    gray = T.Grayscale()(imgs) # Shape: [B, 1, H, W]
    
    # 2. REPETIR el canal 1 -> 3 (Esta es la correcci√≥n para el RuntimeError)
    gray_repeated = gray.repeat(1, 3, 1, 1) # Shape: [B, 3, H, W]
    
    # Input: gris repetido, Target: imagen original a color
    return gray_repeated, imgs


In [7]:

def patch_prediction_pair_tensor(imgs, mask_size_ratio=0.25):
    """
    Input: Batch de imgs RGB [B, 3, H, W]
    Output: (Input para modelo [B, 3, H, W] (con m√°scara), Target [B, 3, H, W] (original))
    """
    B, C, H, W = imgs.shape
    mask_size_h = int(H * mask_size_ratio)
    mask_size_w = int(W * mask_size_ratio)
    
    # Calcular centro
    x = (W - mask_size_w) // 2
    y = (H - mask_size_h) // 2
    
    masked_imgs = imgs.clone()
    # Enmascarar con 0.0 (negro)
    masked_imgs[:, :, y:y+mask_size_h, x:x+mask_size_w] = 0.0 
    
    return masked_imgs, imgs


In [8]:

def jigsaw_pair_tensor(imgs, n=JIGSAW_N):
    """
    Input: Batch de imgs RGB [B, 3, H, W]
    Output: (Input para modelo [B, 3, H, W] (desordenado), Target [B, n*n] (orden))
    """
    B, C, H, W = imgs.shape
    patch_h, patch_w = H // n, W // n
    
    if H % n != 0 or W % n != 0:
        # Esto no deber√≠a pasar si IMG_SIZE es 128 y n=3. Ajusta si es necesario.
        raise ValueError(f"El tama√±o de la imagen ({H}x{W}) no es divisible por n={n}")

    # 1. Cortar el batch en parches
    # imgs shape: [B, C, H, W] -> [B, C, n, patch_h, n, patch_w]
    patches = imgs.unfold(2, patch_h, patch_h).unfold(3, patch_w, patch_w)
    # -> [B, C, n, n, patch_h, patch_w]
    patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(B, n*n, C, patch_h, patch_w)
    # -> [B, n*n, C, patch_h, patch_w] (Batch, NumPatches, C, H_patch, W_patch)

    # 2. Generar permutaci√≥n aleatoria para cada imagen en el batch
    # 'order' es el target: [B, n*n]. 
    # Cada fila es [0, 1, ..., 8] desordenado.
    order = torch.stack([torch.randperm(n*n, device=imgs.device) for _ in range(B)])
    
    # 3. Desordenar los parches usando el 'order'
    # 'order' [B, 9] -> expand a [B, 9, C, pH, pW] para gather
    order_expanded = order.view(B, n*n, 1, 1, 1).expand_as(patches)
    shuffled_patches = torch.gather(patches, 1, order_expanded)
    # -> [B, n*n, C, patch_h, patch_w]

    # 4. Reensamblar el batch desordenado
    shuffled_patches = shuffled_patches.view(B, n, n, C, patch_h, patch_w)
    shuffled_patches = shuffled_patches.permute(0, 3, 1, 4, 2, 5) 
    # -> [B, C, n, patch_h, n, patch_w]
    shuffled_imgs = shuffled_patches.reshape(B, C, H, W)
    
    # Input: imagen desordenada, Target: el orden (permutaci√≥n)
    return shuffled_imgs, order


In [9]:
# --- 4. Modelo Multi-Pretexto (Versi√≥n LightningModule) ---

class MultiPretextSSL_Lightning(pl.LightningModule):
    def __init__(self, backbone, learning_rate=1e-4):
        super().__init__()
        self.save_hyperparameters('learning_rate') # Guarda lr
        self.backbone = backbone
        self.lr = learning_rate
        
        num_features = 512 # Salida de ResNet18
        
        # --- DECODER CORREGIDO PARA 28x28 ---
        decoder_layers_28x28 = [
            nn.ConvTranspose2d(num_features, 256, kernel_size=4, stride=1, padding=0), # 1x1 -> 4x4
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1), # 4x4 -> 7x7
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # 7x7 -> 14x14
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),    # 14x14 -> 28x28
            nn.Sigmoid()
        ]
        
        self.color_head = nn.Sequential(*decoder_layers_28x28)
        self.patch_head = nn.Sequential(*decoder_layers_28x28)
        
        # --- JIGSAW HEAD CORREGIDO PARA N=4 (16 patches) ---
        self.n_patches = JIGSAW_N * JIGSAW_N # 16
        self.jigsaw_head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(num_features, 512),
            nn.ReLU(),
            nn.Linear(512, self.n_patches * self.n_patches) # 16*16 = 256
        )

    def forward(self, x, task="color"):
        feats = self.backbone(x)
        if task == "color":
            return self.color_head(feats)
        elif task == "patch":
            return self.patch_head(feats)
        elif task == "jigsaw":
            return self.jigsaw_head(feats)

    def training_step(self, batch, batch_idx):
        # batch es lo que retorna __getitem__, en este caso, 'imgs'
        imgs = batch
        
        # Elegir una tarea al azar
        task = random.choice(["color", "patch", "jigsaw"])
        loss = 0.0

        if task == "color":
            inp, target = colorization_pair_tensor(imgs)
            pred = self(inp, "color")
            loss = F.mse_loss(pred, target)
        
        elif task == "patch":
            inp, target = patch_prediction_pair_tensor(imgs)
            pred = self(inp, "patch")
            loss = F.mse_loss(pred, target)
        
        elif task == "jigsaw":
            inp, target = jigsaw_pair_tensor(imgs, n=JIGSAW_N) # n=4
            pred = self(inp, "jigsaw") # Shape: [B, 256]
            
            # [B, 256] -> [B, 16, 16]
            pred_reshaped = pred.view(-1, self.n_patches, self.n_patches) 
            target_reshaped = target.view(-1) # [B*16]
            
            loss = F.cross_entropy(pred_reshaped.view(-1, self.n_patches), target_reshaped)

        # Loggear la p√©rdida. 'prog_bar=True' la muestra en la barra de progreso
        self.log(f'loss_{task}', loss, prog_bar=True)
        self.log('train_loss', loss)
        
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

In [11]:
# --- 5. Bucle de Entrenamiento (¬°Ahora con el Trainer!) ---

print("--- Iniciando entrenamiento con Lightning ---")

# Instanciar el modelo de Lightning
model = MultiPretextSSL_Lightning(ssl_backbone, learning_rate=LR)

# Instanciar el Trainer
trainer = pl.Trainer(
    max_epochs=N_EPOCHS,
    accelerator='gpu',  # Usa 'gpu' (Lightning 2.0+)
    devices=-1,          # Usa 1 GPU
    callbacks=[RichProgressBar()], # Una barra de progreso m√°s bonita
    logger=None # Puedes a√±adir un logger si quieres (ej. TensorBoardLogger)
)

# ¬°A entrenar!
# Aqu√≠ es donde ver√°s la barra de progreso con el ETA (tiempo restante)
trainer.fit(model, loader)

print("--- Entrenamiento finalizado ---")

üí° Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


--- Iniciando entrenamiento con Lightning ---


RuntimeError: Lightning can't create new processes if CUDA is already initialized. Did you manually call `torch.cuda.*` functions, have moved the model to the device, or allocated memory on the GPU any other way? Please remove any such calls, or change the selected strategy. You will have to restart the Python kernel.

In [None]:


# # --- 6. Exportaci√≥n a ONNX ---
# # Exportamos solo el backbone adaptado, que es lo que usar√°s 
# # para la inferencia downstream (clustering, clasificaci√≥n, etc.)

# print("Exportando backbone adaptado a ONNX...")
# model.eval()
# dummy_input = torch.randn(1, 3, IMG_SIZE, IMG_SIZE).to(DEVICE)

# # Exportamos solo 'model.backbone'
# torch.onnx.export(
#     model.backbone, 
#     dummy_input, 
#     "ssl_adapted_backbone.onnx", # Nombre del archivo
#     input_names=['input'], 
#     output_names=['features'],
#     opset_version=17,
#     dynamic_axes={'input': {0: 'batch_size'}, 'features': {0: 'batch_size'}}
# )
                      
# print("Backbone adaptado guardado en 'ssl_adapted_backbone.onnx'")