# Probando Dominio Adversarial

In [20]:
import torch
from tqdm import tqdm
import torch.utils.data as data
import torchvision.transforms as transforms
import numpy as np

import medmnist
from medmnist import INFO

In [21]:
# Dump MedMNIST's metadata registry (one entry per dataset flag).
metadata_registry = INFO
print(metadata_registry)

{'pathmnist': {'python_class': 'PathMNIST', 'description': 'The PathMNIST is based on a prior study for predicting survival from colorectal cancer histology slides, providing a dataset (NCT-CRC-HE-100K) of 100,000 non-overlapping image patches from hematoxylin & eosin stained histological images, and a test dataset (CRC-VAL-HE-7K) of 7,180 image patches from a different clinical center. The dataset is comprised of 9 types of tissues, resulting in a multi-class classification task. We resize the source images of 3√ó224√ó224 into 3√ó28√ó28, and split NCT-CRC-HE-100K into training and validation set with a ratio of 9:1. The CRC-VAL-HE-7K is treated as the test set.', 'url': 'https://zenodo.org/records/10519652/files/pathmnist.npz?download=1', 'MD5': 'a8b06965200029087d5bd730944a56c1', 'url_64': 'https://zenodo.org/records/10519652/files/pathmnist_64.npz?download=1', 'MD5_64': '55aa9c1e0525abe5a6b9d8343a507616', 'url_128': 'https://zenodo.org/records/10519652/files/pathmnist_128.npz?downlo

In [25]:
data_flag = 'breastmnist'  # dataset to use
download = True  # download the dataset if it is not already present under `root`
root = '/lustre/proyectos/p032/datasets'  # directory holding the datasets

dataset_names = ['pathmnist', 'chestmnist', 'breastmnist', 'bloodmnist']
datasets = {}

NUM_EPOCHS = 3
BATCH_SIZE = 1
lr = 0.001

data_transform = transforms.Compose([
    transforms.ToTensor(),
    # A length-1 mean/std is broadcast by torchvision over every channel of the image.
    transforms.Normalize(mean=[.5], std=[.5])
])

for dataset in dataset_names:
    # Resolve the medmnist dataset class (e.g. PathMNIST) from its flag.
    DataClass = getattr(medmnist, INFO[dataset]['python_class'])

    # 1. Build the three splits.
    train_dataset = DataClass(split='train', transform=data_transform, download=download, root=root)
    test_dataset = DataClass(split='test', transform=data_transform, download=download, root=root)
    val_dataset = DataClass(split='val', transform=data_transform, download=download, root=root)

    # 2. Wrap them in DataLoaders. Only training data is shuffled; evaluation loaders
    #    keep a fixed order so metrics and per-sample inspection are reproducible.
    train_loader = data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = data.DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)
    val_loader = data.DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    # 3. Register every object under the dataset's flag.
    datasets[dataset] = {
        'train_dataset': train_dataset,
        'test_dataset': test_dataset,
        'val_dataset': val_dataset,
        'train_loader': train_loader,
        'test_loader': test_loader,
        'val_loader': val_loader,
    }

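- chestmnist + pathmnist -> SSL (pre-entrenamiento auto-supervisado, sin etiquetas)
- chestmnist (etiquetado, dominio fuente) + breastmnist (sin etiquetas, dominio objetivo) -> DANN
- bloodmnist -> inferencia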

# PRUEBA

In [26]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, ConcatDataset
import torchvision.models as models
from torch.autograd import Function

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger

import medmnist
from medmnist import INFO
import torchvision.transforms as transforms

import itertools
import os

# --- 0. GLOBAL CONFIGURATION ---
# ==================================
print(f"PyTorch version: {torch.__version__}")
print(f"MedMNIST version: {medmnist.__version__}")

# Experiment paths
ROOT_PATH = './medmnist_data'  # directory the MedMNIST datasets are downloaded to
SSL_BACKBONE_PATH = 'ssl_backbone.pth'  # written by the SSL stage, read by DANN
DANN_MODEL_PATH = 'dann_model.pth'  # written by the DANN stage, read by inference

# Hyperparameters
BATCH_SIZE = 128
# DataLoader worker processes: at most 24, but never more than the CPUs available.
# os.cpu_count() may return None, hence the fallback to 1.
NUM_WORKERS = min(24, os.cpu_count() or 1)
# Epoch counts are kept small so the example runs quickly.
NUM_EPOCHS_SSL = 5
NUM_EPOCHS_DANN = 10
LR_SSL = 0.01
LR_DANN = 0.001

# Use the GPU when one is available (ROCm builds also expose it through torch.cuda).
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Allow faster, slightly lower-precision float32 matmuls where the hardware supports them.
torch.set_float32_matmul_precision("high")

PyTorch version: 2.8.0+rocm6.4
MedMNIST version: 3.0.2


In [27]:
def prepare_datasets():
    """Download and build every MedMNIST split used by the pipeline.

    All images are converted to 3 channels because the ResNet18 backbone expects
    3-channel input (colour datasets lose their colour information here).

    Returns:
        dict with keys
        - ``'ssl'``: PathMNIST train + ChestMNIST train. Samples are 3-channel **PIL
          images** (no ToTensor/Normalize), because ``run_ssl_stage`` wraps this
          dataset with lightly's SimCLR augmentations, which operate on PIL images and
          apply their own ToTensor/Normalize. Labels are present but not used.
        - ``'dann_source'``: ChestMNIST train (labelled source domain), tensors
          normalized with mean=std=0.5.
        - ``'dann_target'``: BreastMNIST train (target domain, labels unused), same
          transform.
        - ``'inference'``: BloodMNIST test, same transform.
    """
    print("\n--- Preparando Datasets ---")
    data_transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=3),  # ResNet expects 3 channels
        transforms.ToTensor(),
        transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])
    ])
    # The SSL splits stay PIL. Feeding normalized tensors into lightly's SimCLRTransform
    # fails (e.g. its GaussianBlur uses PIL's ImageFilter) and would normalize twice.
    ssl_transform = transforms.Grayscale(num_output_channels=3)

    # SSL datasets (labels ignored): ChestMNIST + PathMNIST combined.
    PathMNISTClass = getattr(medmnist, INFO['pathmnist']['python_class'])
    ChestMNISTClass = getattr(medmnist, INFO['chestmnist']['python_class'])

    path_train = PathMNISTClass(split='train', transform=ssl_transform, download=True, root=ROOT_PATH)
    chest_train_ssl = ChestMNISTClass(split='train', transform=ssl_transform, download=True, root=ROOT_PATH)

    ssl_dataset = ConcatDataset([path_train, chest_train_ssl])
    print(f"Tama√±o del dataset SSL (PathMNIST + ChestMNIST): {len(ssl_dataset)} im√°genes")

    # DANN datasets.
    # Source: ChestMNIST (labelled). It is a separate instance from the SSL copy
    # because it needs the tensor transform.
    # Target: BreastMNIST (labels ignored).
    BreastMNISTClass = getattr(medmnist, INFO['breastmnist']['python_class'])
    chest_train_dann = ChestMNISTClass(split='train', transform=data_transform, download=True, root=ROOT_PATH)
    breast_train_dann = BreastMNISTClass(split='train', transform=data_transform, download=True, root=ROOT_PATH)
    print(f"Tama√±o del dataset DANN Source (ChestMNIST): {len(chest_train_dann)} im√°genes")
    print(f"Tama√±o del dataset DANN Target (BreastMNIST): {len(breast_train_dann)} im√°genes")

    # Final inference dataset.
    BloodMNISTClass = getattr(medmnist, INFO['bloodmnist']['python_class'])
    blood_test = BloodMNISTClass(split='test', transform=data_transform, download=True, root=ROOT_PATH)
    print(f"Tama√±o del dataset de Inferencia (BloodMNIST): {len(blood_test)} im√°genes")

    datasets = {
        'ssl': ssl_dataset,
        'dann_source': chest_train_dann,
        'dann_target': breast_train_dann,
        'inference': blood_test
    }
    return datasets

# SSL

In [28]:
# --- 2. ETAPA SSL: PRE-ENTRENAMIENTO CON SimCLR ---
# ==================================
# Usaremos Lightly para simplificar el proceso de SSL
from lightly.data import LightlyDataset
from lightly.transforms import SimCLRTransform
from lightly.models.modules import SimCLRProjectionHead

class SimCLRLightning(pl.LightningModule):
    """Contrastive (SimCLR-style) pre-training of an image backbone.

    Args:
        backbone: module whose output flattens to 512 features per image
            (the ResNet18 trunk built in ``run_ssl_stage``).
        temperature: value the cosine-similarity matrix is divided by in the loss.
            Defaults to 0.1.
    """

    def __init__(self, backbone, temperature=0.1):
        super().__init__()
        self.backbone = backbone
        # 512 -> 512 -> 128 projection head; 512 matches ResNet18's pooled feature width.
        self.projection_head = SimCLRProjectionHead(512, 512, 128)
        self.criterion = nn.CrossEntropyLoss()
        self.temperature = temperature

    def forward(self, x):
        """Return the (un-normalized) projection ``z`` for a batch of images."""
        h = self.backbone(x).flatten(start_dim=1)
        z = self.projection_head(h)
        return z

    def training_step(self, batch, batch_idx):
        """Symmetric cross-view InfoNCE loss on two augmented views of one batch.

        For anchor i in view 0, the positive is sample i in view 1 (the diagonal of
        the similarity matrix) and the negatives are the other samples of view 1.
        The loss is averaged with the mirrored view-1 -> view-0 direction. This is
        a simplified NT-Xent: SimCLR's original loss also uses same-view negatives.
        """
        # Lightly batches are ((view0, view1), labels, filenames); labels are unused.
        (x0, x1), _, _ = batch
        z0 = F.normalize(self.forward(x0), dim=1)
        z1 = F.normalize(self.forward(x1), dim=1)

        # Cosine similarity of every view-0 / view-1 pair, scaled by the temperature.
        sim_matrix = torch.matmul(z0, z1.T) / self.temperature

        batch_size = z0.shape[0]
        labels = torch.arange(batch_size, device=self.device)

        loss = (self.criterion(sim_matrix, labels) + self.criterion(sim_matrix.T, labels)) / 2

        self.log("train_loss_ssl", loss)
        return loss

    def configure_optimizers(self):
        # SGD with momentum and weight decay; learning rate from the global LR_SSL.
        return torch.optim.SGD(self.parameters(), lr=LR_SSL, momentum=0.9, weight_decay=5e-4)

def run_ssl_stage(ssl_dataset):
    """Pre-train a ResNet18 trunk with SimCLR and save its weights.

    Args:
        ssl_dataset: dataset of ``(image, label)`` pairs. Images should be PIL
            images, because lightly's SimCLR augmentations operate on PIL input and
            end with their own ToTensor/Normalize. Labels are ignored by the loss.

    Side effects:
        Trains for ``NUM_EPOCHS_SSL`` epochs, logs to ``lightning_logs/SSL`` and writes
        the trunk's ``state_dict`` to ``SSL_BACKBONE_PATH``.
    """
    print("\n--- ETAPA 1: Iniciando Pre-entrenamiento SSL con SimCLR ---")

    # ResNet18 without its final fc layer -> 512-d pooled features.
    resnet = models.resnet18()
    backbone = nn.Sequential(*list(resnet.children())[:-1])

    # SimCLR augmentations for 28x28 images. Normalize with mean=std=0.5, the same
    # statistics as the tensor transform used for DANN and inference, so the backbone
    # sees identically scaled inputs in every stage (SimCLRTransform otherwise defaults
    # to ImageNet mean/std).
    transform = SimCLRTransform(
        input_size=28,
        vf_prob=0.5,
        rr_prob=0.5,
        normalize={"mean": [0.5, 0.5, 0.5], "std": [0.5, 0.5, 0.5]},
    )
    dataset_lightly = LightlyDataset.from_torch_dataset(ssl_dataset, transform=transform)

    # drop_last keeps every batch full, so each anchor has the same number of negatives.
    dataloader = DataLoader(
        dataset_lightly,
        batch_size=BATCH_SIZE,
        shuffle=True,
        drop_last=True,
        num_workers=NUM_WORKERS,
    )

    model = SimCLRLightning(backbone)
    trainer = pl.Trainer(
        max_epochs=NUM_EPOCHS_SSL,
        accelerator="auto",
        devices=1,
        logger=TensorBoardLogger("lightning_logs", name="SSL"),
        log_every_n_steps=10
    )
    trainer.fit(model, dataloader)

    print(f"‚úÖ Pre-entrenamiento SSL finalizado. Guardando backbone en '{SSL_BACKBONE_PATH}'")
    torch.save(model.backbone.state_dict(), SSL_BACKBONE_PATH)
# ==================================

# DANN

In [None]:
# --- 3. ETAPA DANN: ADAPTACIÓN DE DOMINIO ---
# ==================================
# Componente clave de DANN: la capa de inversión de gradiente (GRL)
class GradientReversalFunction(Function):
    """Autograd op that is the identity going forward and flips the gradient going back.

    Backward multiplies the incoming gradient by ``-lambda_``. ``lambda_`` itself
    receives no gradient (``None``).
    """

    @staticmethod
    def forward(ctx, x, lambda_):
        # Keep the scale for the backward pass. view_as returns a new view of the same
        # data so autograd records this op as its own node.
        ctx.lambda_ = lambda_
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        flipped = (-grad_output) * ctx.lambda_
        no_grad_for_lambda = None
        return flipped, no_grad_for_lambda

class GradientReversalLayer(nn.Module):
    """``nn.Module`` wrapper around ``GradientReversalFunction``.

    Args:
        lambda_: gradient scale. Backward multiplies incoming gradients by
            ``-lambda_``. It is a plain attribute, so it can be changed between steps.
    """

    def __init__(self, lambda_=1.0):
        super().__init__()
        self.lambda_ = lambda_

    def forward(self, x):
        scale = self.lambda_
        return GradientReversalFunction.apply(x, scale)

# Modelo DANN completo
class DANN(pl.LightningModule):
    """Domain-Adversarial Neural Network on a pre-trained ResNet18 trunk.

    A shared feature extractor feeds two heads:

    * ``label_classifier`` is trained on labelled source images only.
    * ``domain_classifier`` learns to tell source (label 0) from target (label 1)
      images. It sits behind a gradient reversal layer, so the extractor receives
      the negated domain gradient and is pushed towards domain-invariant features.

    Args:
        backbone_path: path of the ResNet18-trunk ``state_dict`` saved by the SSL stage.
        source_num_classes: number of outputs of ``label_classifier``.
    """

    def __init__(self, backbone_path, source_num_classes):
        super().__init__()
        self.save_hyperparameters()

        # 1. Feature extractor: ResNet18 without its final fc layer (512-d pooled features).
        resnet = models.resnet18()
        self.feature_extractor = nn.Sequential(*list(resnet.children())[:-1])
        # Load the SSL weights. map_location="cpu" lets a checkpoint written from GPU
        # tensors load on any machine; Lightning moves the module to its device later.
        self.feature_extractor.load_state_dict(torch.load(backbone_path, map_location="cpu"))
        print(f"Backbone cargado desde '{backbone_path}' para DANN.")

        # 2. Label classifier (source domain only).
        self.label_classifier = nn.Sequential(
            nn.Linear(512, 100),
            nn.ReLU(),
            nn.Linear(100, source_num_classes)
        )

        # 3. Domain classifier: a single logit, 0 = source, 1 = target.
        self.domain_classifier = nn.Sequential(
            nn.Linear(512, 100),
            nn.ReLU(),
            nn.Linear(100, 1)
        )

        # 4. Gradient reversal layer (lambda = 1).
        self.grl = GradientReversalLayer()

        # Single-label sources: each label is one class index.
        self.class_criterion = nn.CrossEntropyLoss()
        # Multi-label sources (one 0/1 column per class, which is how MedMNIST labels
        # ChestMNIST): one independent sigmoid per class.
        self.multilabel_criterion = nn.BCEWithLogitsLoss()
        self.domain_criterion = nn.BCEWithLogitsLoss()

    def forward(self, x):
        """Return ``(label_logits, domain_logits)`` with shapes (B, C) and (B,)."""
        features = self.feature_extractor(x).flatten(1)
        # The GRL is applied only on the domain branch; the label branch sees raw features.
        reversed_features = self.grl(features)

        label_preds = self.label_classifier(features)
        domain_preds = self.domain_classifier(reversed_features)

        # squeeze(1) instead of squeeze(): the result stays (B,) even when B == 1.
        return label_preds, domain_preds.squeeze(1)

    def _classification_loss(self, logits, labels):
        """Return the source classification loss, chosen from the layout of ``labels``.

        Expects MedMNIST-style (B, k) integer labels. k == 1 holds a class index
        (cross-entropy). k > 1 holds one 0/1 flag per class (multi-label, BCE). 1-D
        labels are treated as class indices.
        """
        if labels.dim() > 1 and labels.size(1) > 1:
            return self.multilabel_criterion(logits, labels.float())
        return self.class_criterion(logits, labels.reshape(-1).long())

    def training_step(self, batch, batch_idx):
        """Run one DANN step on a ``(source_batch, target_batch)`` pair.

        Total loss = source classification loss + mean of the source/target domain
        losses. Target labels are never used.
        """
        source_batch, target_batch = batch
        s_imgs, s_labels = source_batch
        t_imgs, _ = target_batch  # target labels are not used

        # Domain labels: 0 for source, 1 for target.
        s_domain_labels = torch.zeros(s_imgs.size(0), device=self.device)
        t_domain_labels = torch.ones(t_imgs.size(0), device=self.device)

        # One forward pass per domain, so BatchNorm batch statistics are per domain.
        s_label_preds, s_domain_preds = self(s_imgs)
        # For the target batch only the domain prediction matters.
        _, t_domain_preds = self(t_imgs)

        # 1. Label classification loss (source only).
        loss_class = self._classification_loss(s_label_preds, s_labels)

        # 2. Domain classification loss (both domains).
        loss_domain_s = self.domain_criterion(s_domain_preds, s_domain_labels)
        loss_domain_t = self.domain_criterion(t_domain_preds, t_domain_labels)
        loss_domain = (loss_domain_s + loss_domain_t) / 2

        total_loss = loss_class + loss_domain

        self.log_dict({
            'train_loss_dann': total_loss,
            'train_loss_class': loss_class,
            'train_loss_domain': loss_domain
        }, prog_bar=True)

        return total_loss

    def configure_optimizers(self):
        # Adam over all parameters (extractor + both heads), learning rate from LR_DANN.
        return optim.Adam(self.parameters(), lr=LR_DANN)

def run_dann_stage(dann_source_dataset, dann_target_dataset):
    """Adapt the SSL backbone to the target domain with DANN and save the result.

    Each epoch makes one full pass over the longer of the two loaders. The shorter
    loader is restarted (and reshuffled) whenever it runs out, so every training
    step receives a ``[source_batch, target_batch]`` pair.

    Args:
        dann_source_dataset: labelled source-domain dataset (ChestMNIST).
        dann_target_dataset: target-domain dataset (BreastMNIST); its labels are ignored.

    Side effects:
        Trains for ``NUM_EPOCHS_DANN`` epochs, logs to ``lightning_logs/DANN`` and saves
        the feature-extractor and label-classifier state_dicts to ``DANN_MODEL_PATH``.
    """
    # CombinedLoader is re-iterable: each epoch starts fresh passes over both loaders.
    from pytorch_lightning.utilities.combined_loader import CombinedLoader

    print("\n--- ETAPA 2: Iniciando Adaptaci√≥n de Dominio con DANN ---")

    source_loader = DataLoader(
        dann_source_dataset, batch_size=BATCH_SIZE, shuffle=True,
        num_workers=NUM_WORKERS, drop_last=True
    )
    target_loader = DataLoader(
        dann_target_dataset, batch_size=BATCH_SIZE, shuffle=True,
        num_workers=NUM_WORKERS, drop_last=True
    )

    # Pair the loaders. zip(..., itertools.cycle(...)) is a single-use iterator with no
    # len(): with limit_train_batches set to the shorter loader's length, each epoch
    # resumed the same stream for only that many steps, and cycle() replayed cached
    # batches instead of reshuffling. "max_size_cycle" covers the whole longer loader
    # every epoch and restarts the shorter one.
    combined_loader = CombinedLoader([source_loader, target_loader], mode="max_size_cycle")

    # The label classifier's width comes from the source dataset (ChestMNIST).
    n_classes_source = INFO['chestmnist']['n_classes']

    model = DANN(backbone_path=SSL_BACKBONE_PATH, source_num_classes=n_classes_source)

    trainer = pl.Trainer(
        max_epochs=NUM_EPOCHS_DANN,
        accelerator="auto",
        devices=1,
        logger=TensorBoardLogger("lightning_logs", name="DANN"),
    )

    trainer.fit(model, train_dataloaders=combined_loader)

    print(f"‚úÖ Adaptaci√≥n de Dominio DANN finalizada. Guardando modelo en '{DANN_MODEL_PATH}'")
    # Only the parts needed for inference: feature extractor + label classifier.
    final_model_state = {
        'feature_extractor': model.feature_extractor.state_dict(),
        'label_classifier': model.label_classifier.state_dict()
    }
    torch.save(final_model_state, DANN_MODEL_PATH)
# ==================================


# INFERENCIA

In [None]:
# --- 4. ETAPA DE INFERENCIA ---
# ==================================
class InferenceModel(nn.Module):
    """Plain ``nn.Module`` rebuilding the DANN feature extractor and label classifier.

    Args:
        model_path: checkpoint dict with ``'feature_extractor'`` and
            ``'label_classifier'`` state_dicts (as saved by the DANN stage).
        source_num_classes: number of label-classifier outputs used during training.
    """

    def __init__(self, model_path, source_num_classes):
        super().__init__()
        # map_location="cpu": a checkpoint saved from GPU tensors loads on any machine;
        # the module can be moved to the desired device afterwards with .to().
        state = torch.load(model_path, map_location="cpu")

        # ResNet18 without its final fc layer, as in training.
        resnet = models.resnet18()
        self.feature_extractor = nn.Sequential(*list(resnet.children())[:-1])
        self.feature_extractor.load_state_dict(state['feature_extractor'])

        # Must match the label-classifier architecture used in training.
        self.label_classifier = nn.Sequential(
            nn.Linear(512, 100),
            nn.ReLU(),
            nn.Linear(100, source_num_classes)
        )
        self.label_classifier.load_state_dict(state['label_classifier'])

    def forward(self, x):
        """Return label logits of shape (B, source_num_classes)."""
        features = self.feature_extractor(x).flatten(1)
        return self.label_classifier(features)

def run_inference_stage(inference_dataset):
    """Evaluate the DANN label classifier on the inference dataset (BloodMNIST).

    The classifier was trained on ChestMNIST's label space, so its argmax is compared
    to BloodMNIST's class indices as-is. This only demonstrates that the pipeline runs
    end to end; the accuracy is not a meaningful transfer metric.

    Args:
        inference_dataset: dataset of ``(image_tensor, label)`` pairs with a single
            integer class index per sample.

    Returns:
        Accuracy in percent (also printed).

    Raises:
        ValueError: if the dataset yields no samples.
    """
    print("\n--- ETAPA 3: Iniciando Inferencia en BloodMNIST ---")

    dataloader = DataLoader(inference_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

    # The model classifies into ChestMNIST's classes, while BloodMNIST has its own.
    # See the caveat in the docstring.
    n_classes_source = INFO['chestmnist']['n_classes']
    model = InferenceModel(DANN_MODEL_PATH, source_num_classes=n_classes_source)
    model.to(DEVICE)
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for imgs, labels in dataloader:
            imgs = imgs.to(DEVICE)
            # reshape(-1) instead of squeeze(): a final batch of one sample must stay
            # 1-D, otherwise labels.size(0) fails on a 0-d tensor.
            labels = labels.reshape(-1).long().to(DEVICE)
            predicted = model(imgs).argmax(dim=1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError("inference dataset is empty; cannot compute accuracy")

    accuracy = 100 * correct / total
    print(f"üéØ Exactitud final en el dataset de inferencia (BloodMNIST): {accuracy:.2f}%")
    print("\nNota: La exactitud puede ser baja debido a la diferencia de clases entre el dominio fuente y el de inferencia.")
    print("El objetivo principal era demostrar el pipeline completo.")
    return accuracy
# ==================================

In [None]:
# --- SCRIPT PRINCIPAL ---
# ==================================
if __name__ == '__main__':
    # Build every dataset split the three stages need.
    all_datasets = prepare_datasets()

    # Stage 1: SSL pre-training, skipped when its output file already exists.
    if os.path.exists(SSL_BACKBONE_PATH):
        print(f"\nSaltando Etapa 1 (SSL), ya existe el archivo '{SSL_BACKBONE_PATH}'")
    else:
        run_ssl_stage(all_datasets['ssl'])

    # Stage 2: DANN domain adaptation, skipped when its output file already exists.
    if os.path.exists(DANN_MODEL_PATH):
        print(f"\nSaltando Etapa 2 (DANN), ya existe el archivo '{DANN_MODEL_PATH}'")
    else:
        run_dann_stage(all_datasets['dann_source'], all_datasets['dann_target'])

    # Stage 3: always evaluate on the inference split.
    run_inference_stage(all_datasets['inference'])

    print("\nüéâ Pipeline completado.")
# ==================================
