# Probando Dominio Adversarial

chestmnist + pathmnist -> SSL
chestmnist(etiquetado) + breastmnist -> DANN
bloodmnist -> inferencia

# SSL

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torchvision import transforms
from torch.utils.data import DataLoader
from lightly.data import LightlyDataset
from lightly.transforms import SimCLRTransform
from lightly.models.modules import SimCLRProjectionHead
from lightly.loss import NTXentLoss
import pytorch_lightning as pl
from sklearn.model_selection import StratifiedShuffleSplit

In [7]:
# Cargar datos
torch.set_float32_matmul_precision("high")

color_jitter = transforms.ColorJitter(
    0.5 * 0.8,  # brillo
    0.5 * 0.8,  # contraste
    0.5 * 0.8,  # saturaci√≥n
    0.2 * 0.8,  # tono
)

transform = transforms.Compose([
    transforms.RandomResizedCrop(28, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomApply([color_jitter], p=0.8),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

dataset = LightlyDataset(
    input_dir='/lustre/proyectos/p032/datasets/images/tmp',
    transform=transform)

dataloader = DataLoader(
    dataset,
    batch_size=256,
    shuffle=True,
    drop_last=True,
    num_workers=1,
)

In [8]:
# Define Modelo

# --- 2. Backbone ---
resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
backbone = nn.Sequential(*list(resnet.children())[:-1])  # Quitar la capa final

from copy import deepcopy

class SimCLRProjectionHead(nn.Module):
    def __init__(self, input_dim, output_dim): # <-- Recibe 2048
        super().__init__()
        hidden_dim = input_dim // 4 # Ej: 2048 // 4 = 512
        
        # ¬°CORRECTO! Usa el 'input_dim' (2048)
        self.head = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )
    def forward(self, x):
        return self.head(x)

class MoCoLightning(pl.LightningModule):
    def __init__(self, backbone, 
                 lr=0.0003, 
                 temperature=0.1, 
                 momentum=0.999, 
                 queue_size=65536,
                 input_dim=512, 
                 output_dim=128):
        
        super().__init__()
        self.save_hyperparameters('lr', 'temperature', 'momentum', 'queue_size', 'input_dim', 'output_dim')

        # 1. Crear los encoders de Consulta (q) y Clave (k)
        # El encoder_q es el que se entrena con backprop
        self.encoder_q = nn.Sequential(
            backbone,
            nn.Flatten(start_dim=1), # <-- APLANA a (B, 2048)
            SimCLRProjectionHead(self.hparams.input_dim, self.hparams.output_dim)
        )
        
        # El encoder_k es el encoder de momentum
        self.encoder_k = deepcopy(self.encoder_q)

        # Congelar los par√°metros del encoder_k. No se entrenan con el optimizador.
        for param in self.encoder_k.parameters():
            param.requires_grad = False

        # 2. Crear la fila (queue)
        # 
        self.register_buffer("queue", torch.randn(self.hparams.output_dim, self.hparams.queue_size))
        self.queue = F.normalize(self.queue, dim=0)
        
        # Puntero para saber d√≥nde insertar en la fila
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """ Actualizaci√≥n de momentum para el encoder_k """
        # 
        m = self.hparams.momentum
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data = param_k.data * m + param_q.data * (1. - m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        """ Saca el batch m√°s antiguo de la fila y a√±ade el nuevo batch de 'keys' """
        batch_size = keys.shape[0]
        ptr = int(self.queue_ptr)
        
        # Asegurarse de que el batch cabe
        assert self.hparams.queue_size % batch_size == 0 

        # Reemplazar las claves en la fila
        self.queue[:, ptr:ptr + batch_size] = keys.T
        ptr = (ptr + batch_size) % self.hparams.queue_size  # Mover el puntero
        self.queue_ptr[0] = ptr

    def forward(self, x):
        # El forward ahora solo se usa para inferencia (ej. clasificaci√≥n lineal)
        # Devuelve solo las caracter√≠sticas del backbone
        return self.encoder_q[0](x).flatten(start_dim=1)

    def training_step(self, batch, batch_idx):
        (im_q, im_k), _, _ = batch # (x0, x1) ahora son im_q (consulta) e im_k (clave)
        
        # 1. Computar features de consulta (q)
        q = self.encoder_q(im_q)
        q = F.normalize(q, dim=1)

        # 2. Computar features de clave (k)
        with torch.no_grad():
            # Actualizar el encoder de clave (momentum)
            self._momentum_update_key_encoder()
            
            # Obtener las claves (sin gradiente)
            k = self.encoder_k(im_k)
            k = F.normalize(k, dim=1)

        # 3. Calcular la p√©rdida
        loss = self.moco_loss(q, k)
        
        # 4. Actualizar la fila
        self._dequeue_and_enqueue(k)
        
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def moco_loss(self, q, k):
        # q: NxC (consultas)
        # k: NxC (claves positivas)
        # queue: CxK (claves negativas)

        # Logits positivos (N, 1)
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        
        # Logits negativos (N, K)
        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])

        # Logits totales (N, 1+K)
        logits = torch.cat([l_pos, l_neg], dim=1)
        
        # Aplicar temperatura
        logits /= self.hparams.temperature

        # Etiquetas (siempre es la primera columna, la positiva)
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=self.device)
        
        loss = F.cross_entropy(logits, labels)
        return loss

    def configure_optimizers(self):
        # IMPORTANTE: El optimizador SOLO debe entrenar el encoder_q
        # Los par√°metros del encoder_k se actualizan por momentum.
        
        # El paper us√≥ AdamW [cite: 735]
        optimizer = torch.optim.AdamW(
            self.encoder_q.parameters(),
            lr=self.hparams.lr,
            weight_decay=1e-5 # El paper prob√≥ 1e-5 [cite: 736]
        )
        return optimizer

In [9]:
# --- 4. Inicializar modelo Lightning ---
from pytorch_lightning.loggers import CSVLogger

logger = CSVLogger(save_dir="logs", name="mo_co_run")

model = MoCoLightning(
    backbone=backbone,
    lr=0.0003,          # El LR que ten√≠as
    temperature=0.1,    # La temperatura que ten√≠as
    queue_size=8192     # Un valor m√°s peque√±o si 65536 da OOM
)

# --- 5. Entrenador Lightning ---
trainer = pl.Trainer(
    max_epochs=10,
    accelerator="gpu",  # detecta GPU autom√°ticamente
    devices=2,           # cambia a 4 si quieres usar todas tus GPUs
    log_every_n_steps=10,
    logger=logger,
)

# --- 6. Entrenamiento ---
trainer.fit(model, dataloader)

# --- 7. Guardar backbone al final ---
torch.save(model.encoder_q[0].state_dict(), "MG_backbone_ssl.pth")
print(f"El log de p√©rdidas por √©poca se guard√≥ en: {logger.log_dir}/metrics.csv")

/lustre/proyectos/p032/env/lib64/python3.9/site-packages/lightning_fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /lustre/proyectos/p032/env/lib64/python3.9/site-pack ...
üí° Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
W1021 07:47:14.261246 316819 torch/multiprocessing/spawn.py:169] Terminating process 316912 via signal SIGTERM


ProcessRaisedException: 

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/lustre/proyectos/p032/env/lib64/python3.9/site-packages/torch/multiprocessing/spawn.py", line 90, in _wrap
    fn(i, *args)
  File "/lustre/proyectos/p032/env/lib64/python3.9/site-packages/pytorch_lightning/strategies/launchers/multiprocessing.py", line 173, in _wrapping_function
    results = function(*args, **kwargs)
  File "/lustre/proyectos/p032/env/lib64/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 598, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/lustre/proyectos/p032/env/lib64/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 967, in _run
    self.strategy.setup_environment()
  File "/lustre/proyectos/p032/env/lib64/python3.9/site-packages/pytorch_lightning/strategies/ddp.py", line 153, in setup_environment
    super().setup_environment()
  File "/lustre/proyectos/p032/env/lib64/python3.9/site-packages/pytorch_lightning/strategies/strategy.py", line 129, in setup_environment
    self.accelerator.setup_device(self.root_device)
  File "/lustre/proyectos/p032/env/lib64/python3.9/site-packages/pytorch_lightning/accelerators/cuda.py", line 46, in setup_device
    _check_cuda_matmul_precision(device)
  File "/lustre/proyectos/p032/env/lib64/python3.9/site-packages/lightning_fabric/accelerators/cuda.py", line 161, in _check_cuda_matmul_precision
    if not torch.cuda.is_available() or not _is_ampere_or_later(device):
  File "/lustre/proyectos/p032/env/lib64/python3.9/site-packages/lightning_fabric/accelerators/cuda.py", line 155, in _is_ampere_or_later
    major, _ = torch.cuda.get_device_capability(device)
  File "/lustre/proyectos/p032/env/lib64/python3.9/site-packages/torch/cuda/__init__.py", line 600, in get_device_capability
    prop = get_device_properties(device)
  File "/lustre/proyectos/p032/env/lib64/python3.9/site-packages/torch/cuda/__init__.py", line 616, in get_device_properties
    _lazy_init()  # will define _get_device_properties
  File "/lustre/proyectos/p032/env/lib64/python3.9/site-packages/torch/cuda/__init__.py", line 398, in _lazy_init
    raise RuntimeError(
RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method
