In [None]:
# ==========================================
# 1. INSTALACIÓN DE DEPENDENCIAS
# ==========================================
!pip install tensorflow keras scikit-learn matplotlib pandas numpy opencv-python-headless
!pip install torch torchvision fastprogress kaggle

import os
import numpy as np
import cv2
import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.autograd as autograd
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from collections import Counter
from google.colab import drive
import random

# Seeding for reproducibility
def seed_everything(seed=42):
    """Seed every random number generator used by this notebook.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    every visible CUDA device), and configures cuDNN for deterministic kernels.

    Args:
        seed: Integer seed applied to every generator (default 42).
    """
    random.seed(seed)
    # NOTE: PYTHONHASHSEED only affects Python interpreters launched after this
    # point (e.g. subprocesses); the current process's hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Explicitly cover all GPUs, not only the current device.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode auto-tunes kernels per input shape and may pick
    # non-deterministic algorithms; it must be off for repeatable runs.
    torch.backends.cudnn.benchmark = False


seed_everything()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Librerías listas. Usando dispositivo: {device}")

In [None]:
# ==========================================
# 2. CONEXIÓN CON GOOGLE DRIVE
# ==========================================
drive.mount('/content/drive')

# Project root on Google Drive (ADJUST THIS IF NEEDED).
BASE_PATH = '/content/drive/MyDrive/proyecto_completo/'

# Output folders, all nested under BASE_PATH.
DATA_PATH, MODELS_PATH, CHECKPOINT_PATH = (
    os.path.join(BASE_PATH, sub_dir)
    for sub_dir in ('preprocesamiento', 'models', 'checkpoints')
)

# Create any missing folder; existing folders are left untouched.
for output_dir in (DATA_PATH, MODELS_PATH, CHECKPOINT_PATH):
    os.makedirs(output_dir, exist_ok=True)

print(f"Directorio de trabajo configurado en: {BASE_PATH}")

In [None]:
# ==========================================
# 3. BIBLIOTECA DE PREPROCESAMIENTO (UTILS)
# ==========================================

# Intermediate working resolution (width, height).
IMG_WIDTH = 256
IMG_HEIGHT = 256
# Final side length of the preprocessed images.
RESNET_SIZE = 224

def center_crop(img, crop_size=None):
    """Cut a square window out of the middle of ``img`` to drop irrelevant borders.

    Args:
        img: Array indexed as (rows, cols[, channels]).
        crop_size: Side of the square window in pixels. Defaults to 95% of the
            shorter image side (truncated to int).

    Returns:
        A slice (view) of ``img``. If ``crop_size`` exceeds an image dimension,
        the window is clipped to the image and may not be square.
    """
    height, width = img.shape[:2]
    side = int(min(height, width) * 0.95) if crop_size is None else crop_size

    top = max(0, height // 2 - side // 2)
    left = max(0, width // 2 - side // 2)
    return img[top:top + side, left:left + side]

def apply_clahe(img):
    """Apply CLAHE (clip limit 2.0, 8x8 tiles) to an image.

    A 3-channel image is treated as RGB: it is converted to LAB, only the
    lightness channel is equalized, and the result is converted back to RGB.
    Anything else is treated as grayscale; for a 3-D array only its first
    channel is kept. OpenCV's CLAHE expects 8-bit (or 16-bit) integer input.
    """
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    is_rgb = len(img.shape) == 3 and img.shape[2] == 3

    if not is_rgb:
        # Grayscale path: drop a trailing channel axis if there is one.
        gray = img[:, :, 0] if len(img.shape) == 3 else img
        return clahe.apply(gray)

    lightness, chan_a, chan_b = cv2.split(cv2.cvtColor(img, cv2.COLOR_RGB2LAB))
    equalized_lab = cv2.merge((clahe.apply(lightness), chan_a, chan_b))
    return cv2.cvtColor(equalized_lab, cv2.COLOR_LAB2RGB)

def apply_gaussian_filter(img, kernel_size=(3, 3)):
    """Smooth ``img`` with a Gaussian blur to reduce noise.

    ``kernel_size`` is the (width, height) of the kernel. A sigma of 0 lets
    OpenCV derive the standard deviation from the kernel size.
    """
    smoothed = cv2.GaussianBlur(img, ksize=kernel_size, sigmaX=0)
    return smoothed

def preprocess_image(img, target_size=(224, 224), intermediate_size=(256, 256)):
    """Run the full preprocessing pipeline on a single image.

    Steps: center crop (95% of the shorter side) -> resize to
    ``intermediate_size`` -> CLAHE -> 3x3 Gaussian blur -> resize to
    ``target_size`` -> divide by 255.

    Args:
        img: Image array. CLAHE (via ``apply_clahe``) needs integer 8/16-bit
            input, so presumably raw uint8 images are passed here (TODO confirm
            with callers).
        target_size: Final (width, height) given to ``cv2.resize``.
        intermediate_size: (width, height) used before CLAHE. The default is the
            previously hard-coded 256x256, which matches IMG_WIDTH/IMG_HEIGHT.

    Returns:
        float32 array divided by 255 (i.e. in [0, 1] for 8-bit input), or
        ``None`` if any step raised. The error is printed so a batch loop can
        skip bad images instead of aborting.
    """
    try:
        cropped = center_crop(img)
        resized = cv2.resize(cropped, intermediate_size)
        equalized = apply_clahe(resized)
        smoothed = apply_gaussian_filter(equalized)
        final = cv2.resize(smoothed, target_size)

        # Normalization to [0, 1] for 8-bit intensities.
        return final.astype(np.float32) / 255.0
    except Exception as e:
        print(f"Error preprocesando: {e}")
        return None

def analyze_dataset(labels):
    """Report the class distribution and return the Imbalance Ratio (IR).

    IR = count of the most frequent class / count of the least frequent class.
    Prints the class ``Counter`` and the IR (2 decimals), then returns the IR
    as a float. Empty ``labels`` raise ValueError (no class counts to compare).
    """
    class_counts = Counter(labels)
    smallest = min(class_counts.values())
    largest = max(class_counts.values())
    imbalance_ratio = largest / smallest
    print(f"Distribución: {class_counts}")
    print(f"Imbalance Ratio (IR): {imbalance_ratio:.2f}")
    return imbalance_ratio

In [None]:
# ==========================================
# 4. PREPARACIÓN DE DATOS PARA WGAN
# ==========================================

# Load the original (unbalanced) training arrays.
print("Cargando arrays .npy originales...")
try:
    X_train = np.load(os.path.join(DATA_PATH, 'X_train_unbalanced.npy'))
    y_train = np.load(os.path.join(DATA_PATH, 'y_train_unbalanced.npy'))
    print(f"Datos cargados. Shape: {X_train.shape}")
except FileNotFoundError as err:
    print("No se encontraron los archivos .npy en la ruta especificada.")
    # Stop here: every following step needs X_train / y_train, so carrying on
    # would only fail later with a confusing NameError.
    raise FileNotFoundError(
        f"Faltan los arrays de entrenamiento en {DATA_PATH}: {err.filename}"
    ) from err

# Images and labels are indexed together below (boolean masks), so their lengths must agree.
if len(X_train) != len(y_train):
    raise ValueError(
        f"X_train ({len(X_train)}) y y_train ({len(y_train)}) tienen longitudes distintas")


def ensure_grayscale_clahe_and_norm(X, target_channels=1):
    """Convert a batch of images to single-channel, CLAHE-equalized float32 in [-1, 1].

    Each image is cast to uint8 (CLAHE needs 8-bit input). It is converted to
    grayscale if it has 3 channels (assumed RGB order) or squeezed if it has 1.
    It is then CLAHE-equalized (clip limit 2.0, 8x8 tiles). The images are
    stacked as (N, H, W, 1) and mapped to [-1, 1], matching a Tanh generator.

    Args:
        X: Array or sequence of images, each (H, W), (H, W, 1) or (H, W, 3), all
            with the same H and W. Non-uint8 data is treated as [0, 1] when the
            maximum over the WHOLE batch is <= 1.0, otherwise as [0, 255].
        target_channels: Unused; kept for backward compatibility (the output
            always has exactly one channel).

    Returns:
        float32 array of shape (N, H, W, 1) with values in [-1, 1].
    """
    print("Aplicando preprocesamiento específico para WGAN (CLAHE + Norm [-1,1])...")
    # One CLAHE object reused for every image instead of one per image.
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    # Decide the input scale once for the whole batch. A per-image test would
    # rescale only the images whose own max is <= 1.0, so a very dark image
    # stored in [0, 255] would be multiplied by 255 while its neighbours were not.
    batch_is_unit_range = len(X) > 0 and float(np.max(X)) <= 1.0

    X_processed = []
    for img in X:
        # Make sure the data is uint8 for CLAHE.
        if img.dtype != np.uint8:
            scaled = img * 255.0 if batch_is_unit_range else img
            # Round (not truncate), then clip so values outside [0, 255]
            # saturate instead of wrapping around in the uint8 cast.
            img = np.clip(np.rint(scaled), 0, 255).astype(np.uint8)

        # RGB -> grayscale; a single trailing channel is squeezed away.
        if img.ndim == 3 and img.shape[2] == 3:
            img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        elif img.ndim == 3 and img.shape[2] == 1:
            img = img[:, :, 0]

        X_processed.append(clahe.apply(img))

    X_processed = np.array(X_processed)

    # Add the channel axis: (N, H, W) -> (N, H, W, 1).
    if X_processed.ndim == 3:
        X_processed = np.expand_dims(X_processed, axis=-1)

    # Map [0, 255] -> [-1, 1] (the generator ends in Tanh).
    return (X_processed.astype('float32') - 127.5) / 127.5

# Preprocess every training image: grayscale + CLAHE + [-1, 1] scaling.
X_train_wgan = ensure_grayscale_clahe_and_norm(X_train)

# Pick the class with the fewest samples (labeling note: Normal=0, Pneumonia=1).
counts = np.bincount(y_train.astype(int))
minority_class = counts.argmin()
print(f"Clase minoritaria detectada: {minority_class} (Cantidad: {counts[minority_class]})")

# Only minority-class images are used to train the WGAN.
is_minority = y_train == minority_class
X_minority = X_train_wgan[is_minority]
print(f"Dataset para WGAN listo: {X_minority.shape}, Rango: [{X_minority.min():.1f}, {X_minority.max():.1f}]")

In [None]:
# ==========================================
# 5. ARQUITECTURA WGAN-GP (PyTorch)
# ==========================================
class Generator(nn.Module):
    """Maps a latent vector to a (channels, img_size, img_size) image in [-1, 1].

    A linear layer (+ BatchNorm1d + ReLU) projects ``z`` to a 512-channel
    feature map of side ``img_size // 32``. Five spectral-normalized stride-2
    transposed convolutions then upsample it 32x, and a final Tanh bounds the
    output to [-1, 1].

    Args:
        latent_dim: Length of the input noise vector.
        channels: Number of output image channels.
        img_size: Output side length. It must be a positive multiple of 32;
            otherwise the five 2x upsamplings cannot land exactly on it and the
            generator would silently emit a smaller image.

    Raises:
        ValueError: If ``img_size`` is not a positive multiple of 32.
    """
    def __init__(self, latent_dim=128, channels=1, img_size=224):
        super().__init__()
        if img_size < 32 or img_size % 32 != 0:
            raise ValueError(
                f"img_size debe ser un múltiplo positivo de 32 (recibido {img_size})")
        # Side of the seed feature map; 5 doublings bring it back to img_size.
        self.init_size = img_size // 32
        # NOTE: attribute names (inp, main) and Sequential indices are the
        # state_dict keys of saved checkpoints; keep them stable.
        self.inp = nn.Sequential(
            nn.Linear(latent_dim, 512 * self.init_size ** 2),
            nn.BatchNorm1d(512 * self.init_size ** 2),
            nn.ReLU(True)
        )

        def sn_conv_t(in_c, out_c):
            # Spectral-normalized 4x4 transposed conv, stride 2, padding 1: doubles H and W.
            return nn.utils.spectral_norm(
                nn.ConvTranspose2d(in_c, out_c, 4, stride=2, padding=1, bias=False))

        self.main = nn.Sequential(
            sn_conv_t(512, 256), nn.BatchNorm2d(256), nn.ReLU(True),
            sn_conv_t(256, 128), nn.BatchNorm2d(128), nn.ReLU(True),
            sn_conv_t(128, 64),  nn.BatchNorm2d(64),  nn.ReLU(True),
            sn_conv_t(64, 32),   nn.BatchNorm2d(32),  nn.ReLU(True),
            sn_conv_t(32, channels),
            nn.Tanh()
        )

    def forward(self, z):
        """z: (batch, latent_dim) -> images: (batch, channels, img_size, img_size)."""
        x = self.inp(z)
        # Reshape the flat projection into the 512-channel seed feature map.
        x = x.view(x.size(0), 512, self.init_size, self.init_size)
        return self.main(x)

class Critic(nn.Module):
    """WGAN critic: (channels, img_size, img_size) image -> one unbounded score.

    Five spectral-normalized 4x4 stride-2 convolutions each halve (floor) the
    spatial size, so img_size becomes img_size // 32, while the width grows
    32 -> 512 channels. Every block after the first uses InstanceNorm2d.
    A linear head maps the flattened features to one score per sample.
    Layer order, Sequential indices and construction order match the original
    definition, so saved state_dicts stay loadable.
    """
    def __init__(self, channels=1, img_size=224):
        super().__init__()

        def sn_conv(in_c, out_c):
            # Spectral-normalized 4x4 conv, stride 2, padding 1.
            return nn.utils.spectral_norm(
                nn.Conv2d(in_c, out_c, kernel_size=4, stride=2, padding=1, bias=False))

        # First block: no normalization layer.
        layers = [sn_conv(channels, 32), nn.LeakyReLU(0.2, inplace=True)]
        widths = [32, 64, 128, 256, 512]
        for in_c, out_c in zip(widths[:-1], widths[1:]):
            layers += [sn_conv(in_c, out_c),
                       nn.InstanceNorm2d(out_c),
                       nn.LeakyReLU(0.2, inplace=True)]
        self.main = nn.Sequential(*layers)

        final_side = img_size // 32
        self.out = nn.Sequential(
            nn.Flatten(),
            nn.Linear(512 * final_side * final_side, 1),
        )

    def forward(self, x):
        features = self.main(x)
        return self.out(features)

In [None]:
# ==========================================
# 6. ENTRENAMIENTO PROGRESIVO (3 ETAPAS)
# ==========================================
import torch.autograd as autograd
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import torch.nn.functional as F

# --- Progressive configuration ---
STAGES = [64, 128, 224]  # The 3 training resolutions
EPOCHS_PER_STAGE = [500, 500, 4000]  # Per Table 12
BATCH_SIZES = [64, 32, 32]  # Adjust to available memory (64 for small, 32 for large)

# Every per-stage list needs one entry per stage; a mismatch would otherwise
# only surface as an IndexError partway through a long training run.
assert len(STAGES) == len(EPOCHS_PER_STAGE) == len(BATCH_SIZES), \
    "STAGES, EPOCHS_PER_STAGE y BATCH_SIZES deben tener la misma longitud"
# The Generator upsamples 32x from an img_size // 32 seed map, so every stage
# resolution must be a multiple of 32.
assert all(size % 32 == 0 for size in STAGES), \
    "Cada resolución de STAGES debe ser múltiplo de 32"

# WGAN-GP hyperparameters
LR_G = 1e-4
LR_C = 5e-5  # Scaled for the final stage inside the training loop
N_CRITIC = 5  # Critic steps per generator step
LAMBDA_GP = 10.0  # Gradient-penalty weight
LATENT_DIM = 128

# Legacy tensor type used as a constructor by the training loop.
Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

# --- Dataset that serves images at the current stage resolution ---
class ProgressiveDataset(Dataset):
    """Images resized to one progressive-training stage resolution.

    All images are resized once, at construction time, instead of on every
    ``__getitem__`` call. Per-item resizing repeated the same interpolation
    every epoch of every stage.

    Args:
        x_data: NumPy array of shape (N, H, W, C), channels-last (e.g. the output
            of ``ensure_grayscale_clahe_and_norm``), or (N, H, W), treated as C=1.
        current_size: Side length in pixels of the square output images.

    Items are tensors of shape (C, current_size, current_size), bilinearly
    resized with ``align_corners=False``.
    """
    _RESIZE_CHUNK = 256  # images interpolated per call; bounds peak memory

    def __init__(self, x_data, current_size):
        self.x_data = x_data
        self.current_size = current_size

        array = np.asarray(x_data)
        if array.ndim == 3:
            array = array[..., np.newaxis]
        # (N, H, W, C) -> (N, C, H, W), the layout F.interpolate expects.
        images = torch.from_numpy(np.ascontiguousarray(array)).permute(0, 3, 1, 2)
        size = (current_size, current_size)
        resized = [
            F.interpolate(chunk, size=size, mode='bilinear', align_corners=False)
            for chunk in images.split(self._RESIZE_CHUNK)
        ]
        self.images = (torch.cat(resized) if resized
                       else images.new_empty((0, images.shape[1]) + size))

    def __len__(self):
        return len(self.x_data)

    def __getitem__(self, idx):
        return self.images[idx]

# --- Gradient penalty (WGAN-GP) ---
def compute_gradient_penalty(D, real_samples, fake_samples):
    """WGAN-GP gradient penalty: batch mean of (||grad_x D(x_hat)||_2 - 1)^2.

    ``x_hat`` is a random per-sample interpolation between real and fake samples.

    Args:
        D: Critic; called on a batch shaped like ``real_samples``.
        real_samples: Real batch, shape (batch, ...).
        fake_samples: Generated batch, same shape as ``real_samples``.

    Returns:
        Scalar tensor. It stays differentiable w.r.t. D's parameters
        (``create_graph=True``) so it can be added to the critic loss.
    """
    batch_size = real_samples.size(0)
    # One mixing weight per sample, broadcast over the remaining dims. It is
    # created on the samples' own device/dtype rather than through the global
    # legacy Tensor type and NumPy's RNG.
    alpha_shape = (batch_size,) + (1,) * (real_samples.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real_samples.device, dtype=real_samples.dtype)
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
    d_interpolates = D(interpolates)
    gradients = autograd.grad(
        outputs=d_interpolates, inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True, retain_graph=True, only_inputs=True,
    )[0]
    gradients = gradients.reshape(batch_size, -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()

# --- MAIN STAGE LOOP ---
def _load_matching_weights(model, path):
    """Copy into ``model`` every checkpoint tensor whose key and shape match.

    Resolution-dependent layers (the generator's input projection, the
    critic's linear head) differ in shape between stages. They are skipped and
    keep their fresh initialization. ``load_state_dict(strict=False)`` is not
    enough here, because it still raises on size mismatches. Returns the
    number of tensors loaded.
    """
    checkpoint = torch.load(path, map_location=device)
    own_state = model.state_dict()
    compatible = {k: v for k, v in checkpoint.items()
                  if k in own_state and v.shape == own_state[k].shape}
    own_state.update(compatible)
    model.load_state_dict(own_state)
    return len(compatible)


LAST_STAGE = len(STAGES) - 1

for i, img_size in enumerate(STAGES):
    print(f"\n" + "="*40)
    print(f" INICIANDO ETAPA {i+1}: Resolución {img_size}x{img_size}")
    print("="*40)

    # 1. Data for this stage: the minority-class images, resized to img_size.
    dataset = ProgressiveDataset(X_minority, current_size=img_size)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZES[i], shuffle=True, drop_last=True)
    if len(dataloader) == 0:
        # With drop_last=True, fewer images than one batch yields no batches,
        # and the loss logging below would hit undefined d_loss / g_loss.
        raise ValueError(
            f"Etapa {i+1}: {len(dataset)} imágenes < batch size {BATCH_SIZES[i]}")

    # 2. Models for this resolution.
    generator = Generator(latent_dim=LATENT_DIM, img_size=img_size).to(device)
    critic = Critic(img_size=img_size).to(device)

    # 3. Warm start from the previous stage (progressive transfer learning).
    if i > 0:
        prev_gen_path = os.path.join(CHECKPOINT_PATH, f'generator_stage{i}_best.pth')
        prev_crit_path = os.path.join(CHECKPOINT_PATH, f'critic_stage{i}_best.pth')

        if os.path.exists(prev_gen_path):
            print(f"Cargando pesos de etapa previa: {prev_gen_path}")
            _load_matching_weights(generator, prev_gen_path)
        if os.path.exists(prev_crit_path):
            # Same shape-filtered load: the critic's Linear head changes size per stage.
            _load_matching_weights(critic, prev_crit_path)

    # 4. Learning rates (rescaled for the final stage: 7.5e-5 vs 5e-5 for the critic).
    is_final_stage = i == LAST_STAGE
    current_lr_g = LR_G * 0.5 if is_final_stage else LR_G
    current_lr_c = LR_C * 1.5 if is_final_stage else LR_C

    optimizer_G = optim.Adam(generator.parameters(), lr=current_lr_g, betas=(0.0, 0.9))
    optimizer_C = optim.Adam(critic.parameters(), lr=current_lr_c, betas=(0.0, 0.9))

    # 5. Train this stage.
    epochs = EPOCHS_PER_STAGE[i]

    for epoch in range(1, epochs + 1):
        for batch_idx, imgs in enumerate(dataloader):

            real_imgs = imgs.type(Tensor).to(device)

            # --- Critic step ---
            optimizer_C.zero_grad()
            z = Tensor(np.random.normal(0, 1, (imgs.shape[0], LATENT_DIM)))
            # Detached: the critic loss must not backpropagate into the generator
            # (that work was wasted and zeroed before the generator step anyway).
            fake_imgs = generator(z).detach()

            real_validity = critic(real_imgs)
            fake_validity = critic(fake_imgs)
            gp = compute_gradient_penalty(critic, real_imgs.data, fake_imgs.data)

            d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + LAMBDA_GP * gp
            d_loss.backward()
            optimizer_C.step()

            # --- Generator step (every N_CRITIC batches) ---
            if batch_idx % N_CRITIC == 0:
                optimizer_G.zero_grad()
                fake_imgs = generator(z)  # fresh, non-detached pass so gradients reach G
                g_loss = -torch.mean(critic(fake_imgs))
                g_loss.backward()
                optimizer_G.step()

        # Simple log (losses of the last batch of the epoch).
        if epoch % 100 == 0:
            print(f"[Etapa {img_size}x{img_size}] Epoch {epoch}/{epochs} | D Loss: {d_loss.item():.4f} | G Loss: {g_loss.item():.4f}")

    # 6. Save this stage's weights (last-epoch weights, despite the '_best' suffix).
    torch.save(generator.state_dict(), os.path.join(CHECKPOINT_PATH, f'generator_stage{i+1}_best.pth'))
    torch.save(critic.state_dict(), os.path.join(CHECKPOINT_PATH, f'critic_stage{i+1}_best.pth'))
    print(f"Modelo de Etapa {img_size}x{img_size} guardado.")

# Save the final-stage generator under the name the generation step loads.
torch.save(generator.state_dict(), os.path.join(CHECKPOINT_PATH, 'best_generator_stage3.pth'))
print("Entrenamiento progresivo completo.")

In [None]:
# ==========================================
# 7. GENERACIÓN DE IMÁGENES SINTÉTICAS
# ==========================================
LATENT_DIM = 128
IMG_SIZE = 224
NUM_IMGS_TO_GEN = 2179
GEN_BATCH_SIZE = 32  # images produced per forward pass

best_gen_path = os.path.join(CHECKPOINT_PATH, 'best_generator_stage3.pth')

if os.path.exists(best_gen_path):
    generator = Generator(latent_dim=LATENT_DIM, img_size=IMG_SIZE).to(device)
    generator.load_state_dict(torch.load(best_gen_path, map_location=device))
    generator.eval()  # BatchNorm layers use their running statistics

    print(f"Generando {NUM_IMGS_TO_GEN} imágenes sintéticas...")
    synthetic_imgs = []
    # Count images, not list entries: the list holds batches. Counting its
    # length made every batch full-sized and overshot the target (2208 vs 2179).
    n_generated = 0

    with torch.no_grad():
        while n_generated < NUM_IMGS_TO_GEN:
            curr = min(GEN_BATCH_SIZE, NUM_IMGS_TO_GEN - n_generated)
            noise = torch.randn(curr, LATENT_DIM, device=device)
            gen = generator(noise).cpu().numpy()
            # De-normalize [-1, 1] -> [0, 255] (inverse of (x - 127.5) / 127.5).
            gen = (gen + 1) / 2.0 * 255.0
            synthetic_imgs.append(gen)
            n_generated += curr

    # Round and clip before the uint8 cast (no truncation bias, no wraparound),
    # then drop only the channel axis: (N, 1, H, W) -> (N, H, W), even when N == 1.
    synthetic_data = np.clip(np.rint(np.concatenate(synthetic_imgs, axis=0)), 0, 255)
    synthetic_data = synthetic_data.astype(np.uint8).squeeze(axis=1)
    assert len(synthetic_data) == NUM_IMGS_TO_GEN

    save_path = os.path.join(CHECKPOINT_PATH, f'generated_data_{NUM_IMGS_TO_GEN}.npy')
    np.save(save_path, synthetic_data)
    print(f"Guardado: {save_path}")
else:
    print(f"No se encontró el generador entrenado: {best_gen_path}")