In [None]:
# Standard
import os, sys, json, glob, re, math, random, pickle, time, datetime, subprocess, logging, argparse
from typing import Type, Any, Callable, Union, List, Optional
from pathlib import Path

# Scientific
import numpy as np
import pandas as pd
from scipy import linalg
from scipy.spatial.distance import jensenshannon
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr

# Torch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torchvision.utils as vutils
from torch import Tensor
from torchvision.utils import save_image
from torchvision.models import inception_v3, Inception_V3_Weights
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.kid import KernelInceptionDistance

# === Librerías para métricas ===
import lpips
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
from scipy.spatial.distance import jensenshannon


from src.medgan.dataset import dstget
from src.medgan.dcgan import DCGAN_G, DCGAN_D
from src.medgan.mlp import MLP_G, MLP_D

# Select the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

In [None]:
# Project layout, resolved relative to the parent of the current working directory.
PROJECT_ROOT = Path.cwd().parent

# Top-level MedGAN data / experiment / result folders.
DATA_DIR = PROJECT_ROOT.joinpath("data", "processed", "medgan")
EXPERIMENTS_DIR = PROJECT_ROOT.joinpath("experiments", "medgan")
RESULTS_DIR = PROJECT_ROOT.joinpath("results", "medgan")

# Sub-folders for checkpoints, generated samples and metric outputs.
CHECKPOINT_ROOT = EXPERIMENTS_DIR.joinpath("checkpoints")
GENERATED_ROOT = RESULTS_DIR.joinpath("generated")
METRICS_ROOT = RESULTS_DIR.joinpath("metrics")

## Utils

In [None]:
#toolsf.py
def sg_fk_img_gnrt(model, noise, dir, begin_idx=0):
    """Generate fake images and save each one as an individual PNG.

    Args:
        model: generator, called as ``model(noise)`` under ``torch.no_grad()``.
            Its output is assumed to lie in [-1, 1] (TODO confirm for the
            generator in use).
        noise: latent batch fed to ``model``.
        dir: output directory (created if missing).
        begin_idx: sample ``i`` is written to ``{begin_idx + i + 1}.png``.
    """
    os.makedirs(dir, exist_ok=True)

    with torch.no_grad():
        fake = model(noise)
    # Fixed [-1, 1] -> [0, 1] mapping. Do not pass normalize=True to
    # save_image: it would re-stretch every image to its own min/max,
    # discarding this mapping and the real output intensities.
    fake = fake.mul(0.5).add(0.5).clamp(0.0, 1.0)

    for i in range(fake.shape[0]):
        image_path = os.path.join(dir, '{}.png'.format(begin_idx + i + 1))
        vutils.save_image(fake[i, :, :, :], image_path)

def bc_rl_img_gnrt(real_data, dir, iter):
    """Save a batch of real images as one PNG grid, ``{iter}_real_samples.png``.

    ``real_data`` is assumed to lie in [-1, 1] (TODO confirm against the data
    loader's normalization). ``dir`` is created if missing.
    """
    os.makedirs(dir, exist_ok=True)

    # Fixed [-1, 1] -> [0, 1] mapping. normalize=True is deliberately not
    # used: it would stretch the batch to its own min/max and discard this
    # mapping.
    real_cpu = real_data.mul(0.5).add(0.5).clamp(0.0, 1.0)
    image_path = os.path.join(dir, '{}_real_samples.png'.format(iter))
    vutils.save_image(real_cpu, image_path)

def bc_fk_img_gnrt(model, noise, dir, iter):
    """Generate a batch of fake images and save it as one PNG grid.

    The grid is written to ``{iter}_fake_samples.png`` inside ``dir`` (created
    if missing). ``model(noise)`` runs under ``torch.no_grad()`` and its output
    is assumed to lie in [-1, 1] (TODO confirm for the generator in use).
    """
    os.makedirs(dir, exist_ok=True)

    with torch.no_grad():
        fake = model(noise)
    # Fixed [-1, 1] -> [0, 1] mapping. normalize=True is deliberately not
    # used: it would stretch the batch to its own min/max and discard this
    # mapping.
    fake = fake.mul(0.5).add(0.5).clamp(0.0, 1.0)
    fake_image_path = os.path.join(dir, '{}_fake_samples.png'.format(iter))
    vutils.save_image(fake, fake_image_path)

def timage_gnrt(model, real_data, noise, root, iter):
    """Save, for one training step, individual fakes plus real/fake batch grids.

    Individual fakes go to ``root/sg_fk_img/<iter>/``; both the real and the
    fake grids go to ``root/bc_img/``.
    """
    single_dir = os.path.join(root, 'sg_fk_img', str(iter))
    grid_dir = os.path.join(root, 'bc_img')

    sg_fk_img_gnrt(model, noise, single_dir)
    bc_rl_img_gnrt(real_data, grid_dir, iter)
    bc_fk_img_gnrt(model, noise, grid_dir, iter)

def model_save(netG, netD, iter, dir):
    """Write generator and discriminator state dicts to ``<dir>/save_model/``.

    Files are named ``netG_iter<iter>.pth`` and ``netD_iter<iter>.pth``; the
    generator is saved first.
    """
    target_dir = os.path.join(dir, 'save_model')
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)

    for name_template, net in (('netG_iter{}.pth', netG), ('netD_iter{}.pth', netD)):
        ckpt_path = os.path.join(target_dir, name_template.format(iter))
        torch.save(net.state_dict(), ckpt_path)

def isok(sub_list):
    """Return True once every process in ``sub_list`` has exited (``poll()`` is not None)."""
    return all(proc.poll() is not None for proc in sub_list)

def execute_command(cmdstring_list, cwd=None, timeout=None, shell=True):
    """Launch shell commands and block until all of them have exited.

    The commands are started one after another with a ~1.1 s stagger, but they
    run concurrently: a command is not awaited before the next one starts.

    Args:
        cmdstring_list: commands passed one by one to ``subprocess.Popen``.
        cwd: working directory for every command.
        timeout: optional overall limit in seconds, counted from the call.
            When it elapses, every command still running is killed.
        shell: forwarded to ``subprocess.Popen``.
    """
    deadline = time.monotonic() + timeout if timeout else None

    sub_list = []
    for item in cmdstring_list:
        sub = subprocess.Popen(item, cwd=cwd, stdin=subprocess.PIPE,
                               shell=shell, bufsize=4096)
        time.sleep(1.1)
        sub_list.append(sub)

    print('Comenzando ejecuci√≥n...')
    try:
        while not isok(sub_list):
            if deadline is not None and time.monotonic() >= deadline:
                # Timeout reached: stop whatever is still running and reap it.
                for sub in sub_list:
                    if sub.poll() is None:
                        sub.kill()
                for sub in sub_list:
                    sub.wait()
                break
            time.sleep(0.5)
    finally:
        # The stdin pipes are never written to; close them so their file
        # descriptors are not leaked.
        for sub in sub_list:
            if sub.stdin is not None:
                sub.stdin.close()
    print('Ejecuci√≥n finalizada.')

class AverageMeter:
    """Track the latest value and the running (weighted) mean of a metric.

    ``str(meter)`` renders ``"<name> <val> (<avg>)"``, using ``fmt`` (e.g.
    ``':.4f'``) as the format spec for both numbers.
    """

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Zero the latest value, average, running sum and count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        template = f"{{name}} {{val{self.fmt}}} ({{avg{self.fmt}}})"
        return template.format(**vars(self))

def accuracy(output, target, topk=(1,)):
    """Return top-k accuracy (in percent) for every k in ``topk``.

    ``output`` holds per-class scores along dim 1 and ``target`` the true class
    indices. Each entry of the returned list is a 1-element tensor.
    """
    with torch.no_grad():
        n_samples = target.size(0)
        top_idx = output.topk(max(topk), dim=1, largest=True, sorted=True)[1].t()
        hits = top_idx.eq(target.view(1, -1).expand_as(top_idx))
        scale = 100.0 / n_samples
        return [hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(scale)
                for k in topk]

def get_logger(file_path):
    """Return the shared 'ecogan' logger, writing to ``file_path`` and to stderr.

    Any handlers already attached to the 'ecogan' logger (e.g. from an earlier
    call when a notebook cell is re-run) are removed and closed first, so
    messages are not emitted once per call and old log files are not held open.
    """
    logger = logging.getLogger('ecogan')
    log_format = '%(asctime)s | %(message)s'
    formatter = logging.Formatter(log_format, datefmt='%m/%d %I:%M:%S %p')

    # Drop handlers left over from previous calls.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    file_handler = logging.FileHandler(file_path)
    file_handler.setFormatter(formatter)
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(formatter)

    logger.addHandler(file_handler)
    logger.addHandler(stream_handler)
    logger.setLevel(logging.INFO)

    return logger

def lgwt_construct(logpath):
    """Build the project logger for ``logpath`` (thin wrapper over ``get_logger``)."""
    return get_logger(logpath)

## Modelos de Evaluación

In [None]:
# resnet.py
# Public API of the ResNet section. Only names that are actually defined are
# listed: exporting undefined names (such as 'resnet101' / 'resnet152') makes
# `from <module> import *` fail with AttributeError.
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50']


def conv3x3(in_planes: int, out_planes: int, stride: int = 1,
            groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """Build a bias-free 3x3 convolution with ``padding == dilation``.

    With ``stride == 1`` this padding keeps the spatial size unchanged.
    """
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=dilation,
                       groups=groups, bias=False, dilation=dilation)
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)


def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """Build a bias-free pointwise (1x1) convolution, e.g. for channel projection."""
    return nn.Conv2d(in_channels=in_planes, out_channels=out_planes,
                     kernel_size=(1, 1), stride=stride, bias=False)


###############################################################################
# BasicBlock y Bottleneck (sin cambios)
###############################################################################
class BasicBlock(nn.Module):
    """Residual block with two 3x3 convolutions (ResNet-18/34 style).

    ``expansion`` is 1, so the block outputs ``planes`` channels. When the
    shortcut must change shape, the caller supplies ``downsample`` to project
    the input.
    """
    expansion: int = 1

    def __init__(self, inplanes: int, planes: int, stride: int = 1,
                 downsample: Optional[nn.Module] = None, groups: int = 1,
                 base_width: int = 64, dilation: int = 1,
                 norm_layer: Optional[Callable[..., nn.Module]] = None) -> None:
        super().__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")

        # Residual path: conv -> BN -> ReLU -> conv -> BN.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        # Optional projection applied to the shortcut.
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        shortcut = x if self.downsample is None else self.downsample(x)
        y += shortcut
        return self.relu(y)


class Bottleneck(nn.Module):
    """Residual block 1x1 -> 3x3 -> 1x1 (ResNet-50 style).

    The block outputs ``planes * expansion`` (= 4 * planes) channels; its inner
    width is ``int(planes * (base_width / 64.)) * groups``.
    """
    expansion: int = 4

    def __init__(self, inplanes: int, planes: int, stride: int = 1,
                 downsample: Optional[nn.Module] = None, groups: int = 1,
                 base_width: int = 64, dilation: int = 1,
                 norm_layer: Optional[Callable[..., nn.Module]] = None) -> None:
        super().__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        inner = int(planes * (base_width / 64.)) * groups
        out_channels = planes * self.expansion

        self.conv1 = conv1x1(inplanes, inner)
        self.bn1 = norm_layer(inner)
        self.conv2 = conv3x3(inner, inner, stride, groups, dilation)
        self.bn2 = norm_layer(inner)
        self.conv3 = conv1x1(inner, out_channels)
        self.bn3 = norm_layer(out_channels)
        self.relu = nn.ReLU(inplace=True)
        # Optional projection applied to the shortcut.
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        y = x
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            y = self.relu(bn(conv(y)))
        y = self.bn3(self.conv3(y))
        shortcut = x if self.downsample is None else self.downsample(x)
        y += shortcut
        return self.relu(y)


###############################################################################
# ResNet adaptado para escala de grises y GAN Discriminator
###############################################################################
class ResNet(nn.Module):
    """ResNet backbone adapted to single-channel (grayscale) inputs.

    Unlike the standard ImageNet ResNet, the stem here is a 3x3 stride-1
    convolution on 1 input channel with no max-pool. The head is a single
    ``nn.Linear`` producing ``num_classes`` outputs (default 1, e.g. a
    real/fake score).

    Args:
        block: residual block class (``BasicBlock`` or ``Bottleneck``).
        layers: number of blocks in each of the four stages.
        num_classes: size of the output vector.
        zero_init_residual: if True, zero the weight of the last norm layer
            of every residual branch, so each block starts as an identity
            mapping.
        groups, width_per_group: forwarded to the blocks (``BasicBlock`` only
            accepts the defaults).
        replace_stride_with_dilation: three flags (stages 2-4) that replace the
            stage's stride 2 with dilation.
        norm_layer: normalization constructor (default ``nn.BatchNorm2d``).
    """

    def __init__(self, block: Type[Union[BasicBlock, Bottleneck]],
                 layers: List[int],
                 num_classes: int = 1,  # GAN discriminator: real/fake
                 zero_init_residual: bool = False,
                 groups: int = 1,
                 width_per_group: int = 64,
                 replace_stride_with_dilation: Optional[List[bool]] = None,
                 norm_layer: Optional[Callable[..., nn.Module]] = None) -> None:
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None or a 3-element tuple")
        self.groups = groups
        self.base_width = width_per_group
        # 1 input channel (grayscale) instead of 3.
        self.conv1 = nn.Conv2d(1, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # Output head: num_classes scores (1 = real/fake score by default).
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        if zero_init_residual:
            # Zero the last norm weight of each residual branch so every block
            # initially passes its shortcut through unchanged. This flag was
            # previously accepted but had no effect.
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    last_norm = m.bn3
                elif isinstance(m, BasicBlock):
                    last_norm = m.bn2
                else:
                    continue
                if getattr(last_norm, 'weight', None) is not None:
                    nn.init.constant_(last_norm.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one stage of ``blocks`` residual blocks with ``planes`` base width.

        The first block may downsample (``stride``) and gets a 1x1 projection
        shortcut when the channel count or stride changes. With ``dilate``, the
        stride is replaced by an increased dilation.
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes,
                                groups=self.groups, base_width=self.base_width,
                                dilation=self.dilation, norm_layer=norm_layer))
        return nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x  # (batch, num_classes)


###############################################################################
# Helpers: versiones de ResNet para GAN
###############################################################################
def resnet18(**kwargs: Any) -> ResNet:
    """Grayscale ResNet-18: BasicBlock with two blocks per stage."""
    blocks_per_stage = [2] * 4
    return ResNet(BasicBlock, blocks_per_stage, **kwargs)

def resnet34(**kwargs: Any) -> ResNet:
    """Grayscale ResNet-34: BasicBlock with stage depths 3-4-6-3."""
    stage_depths = [3, 4, 6, 3]
    return ResNet(BasicBlock, stage_depths, **kwargs)

def resnet50(**kwargs: Any) -> ResNet:
    """Grayscale ResNet-50: Bottleneck blocks with stage depths 3-4-6-3."""
    stage_depths = [3, 4, 6, 3]
    return ResNet(Bottleneck, stage_depths, **kwargs)


## ADA


In [None]:
# Locate the most recent checkpoint so training can resume.
# Outputs (module-level globals used by later cells): `last_checkpoint_dir`
# (str or None) and `last_epoch` (int, 0 when nothing usable was found).

# Base directory holding the checkpoint session folders.
checkpoint_base = "/content/drive/MyDrive/Proyecto_Grado/MedGAN/ADA/CheckpointsE1"

# Pick the newest session folder: names ending in "_checkpoint", sorted
# lexicographically, last one wins (assumes zero-padded name prefixes such as
# "00000_checkpoint" so lexicographic order matches creation order — TODO confirm).
checkpoint_dirs = sorted(glob.glob(os.path.join(checkpoint_base, "*_checkpoint")))
if not checkpoint_dirs:
    print("‚ö†Ô∏è No se encontr√≥ ning√∫n checkpoint. El entrenamiento empezar√° desde cero.")
    last_epoch = 0
    last_checkpoint_dir = None
else:
    last_checkpoint_dir = checkpoint_dirs[-1]
    print(f"‚úÖ √öltimo checkpoint encontrado: {last_checkpoint_dir}")

    # Collect the generator / discriminator weight files in that folder.
    # NOTE(review): only this newest folder is searched; checkpoints in older
    # session folders are ignored even if this one holds none.
    netG_files = glob.glob(os.path.join(last_checkpoint_dir, "netG_epoch*.pth"))
    netD_files = glob.glob(os.path.join(last_checkpoint_dir, "netD_epoch*.pth"))

    # Parse the epoch number N out of "<prefix>_epoch<N>.pth" file names and
    # return the numbers sorted ascending; non-matching names are skipped.
    def extract_epochs(files, prefix):
        epochs = []
        for f in files:
            m = re.search(rf"{prefix}_epoch(\d+)\.pth", os.path.basename(f))
            if m:
                epochs.append(int(m.group(1)))
        return sorted(epochs)

    G_epochs = extract_epochs(netG_files, "netG")
    D_epochs = extract_epochs(netD_files, "netD")

    if not G_epochs or not D_epochs:
        print("‚ö†Ô∏è No se encontraron checkpoints completos de G y D.")
        last_epoch = 0
    else:
        # Resume from the latest epoch for which BOTH netG and netD were saved.
        common_epochs = sorted(set(G_epochs).intersection(D_epochs))
        if not common_epochs:
            print("‚ö†Ô∏è No hay √©pocas coincidentes entre G y D.")
            last_epoch = 0
        else:
            last_epoch = common_epochs[-1]
            print(f"üîÑ √öltimo checkpoint coincidente: epoch {last_epoch}")
            print(f"   - netG: netG_epoch{last_epoch}.pth")
            print(f"   - netD: netD_epoch{last_epoch}.pth")

# `last_checkpoint_dir` and `last_epoch` are meant to be used for loading in the
# training loop. Note that `last_checkpoint_dir` stays set (not None) when
# `last_epoch` falls back to 0 because no complete G/D pair was found.


✅ Último checkpoint encontrado: /content/drive/MyDrive/Proyecto_Grado/MedGAN/ADA/CheckpointsE1/00000_checkpoint
🔄 Último checkpoint coincidente: epoch 70
   - netG: netG_epoch70.pth
   - netD: netD_epoch70.pth


In [None]:
# Continuation: checkpoint handling and improved logging.
# Constant gradient seeds for .backward() in the WGAN steps: `one` is +1 and
# `mone` is -1, both float32 tensors of shape (1,) on `device`.
one = torch.ones(1, dtype=torch.float32, device=device)
mone = torch.neg(one)

# LPIPS perceptual metric with an AlexNet backbone (fed 3-channel inputs, presumably in [-1, 1]).
lpips_fn = lpips.LPIPS(net='alex').to(device)


def argsget():
    """Return the experiment configuration as an ``argparse.Namespace``.

    Options are parsed from an empty argument list (``parse_args(args=[])``),
    so the real command line is ignored and every option keeps its default.
    This makes the function usable inside a notebook.
    """
    option_table = (
        ('--dataset', {'default': "other"}),
        ('--dataroot', {'default': "/content/drive/MyDrive/Proyecto_Grado/Data/frames_extraidos_MedGAN"}),
        ('--workers', {'type': int, 'default': 4}),
        ('--batchSize', {'type': int, 'default': 20}),
        ('--img_size', {'type': int, 'default': 128}),
        ('--nc', {'type': int, 'default': 1}),
        ('--nz', {'type': int, 'default': 100}),
        ('--ngf', {'type': int, 'default': 64}),
        ('--ndf', {'type': int, 'default': 64}),
        ('--niter', {'type': int, 'default': 2000}),
        ('--lrD', {'type': float, 'default': 0.00005}),
        ('--lrG', {'type': float, 'default': 0.00005}),
        ('--beta1', {'type': float, 'default': 0.5}),
        ('--ngpu', {'type': int, 'default': 1}),
        ('--netG', {'default': ''}),
        ('--netD', {'default': ''}),
        ('--clamp_lower', {'type': float, 'default': -0.01}),
        ('--clamp_upper', {'type': float, 'default': 0.01}),
        ('--Diters', {'type': int, 'default': 5}),
        ('--noBN', {'action': 'store_true'}),
        ('--mlp_G', {'action': 'store_true'}),
        ('--mlp_D', {'action': 'store_true'}),
        ('--n_extra_layers', {'type': int, 'default': 0}),
        ('--experiment', {'default': '/content/drive/MyDrive/Proyecto_Grado/MedGAN/ADA'}),
        ('--adam', {'action': 'store_true'}),
        ('--class_name', {'type': str, 'default': 'ecocardio'}),
        ('--T1', {'type': float, 'default': 0.3}),
        ('--T2', {'type': float, 'default': 0.5}),
    )
    parser = argparse.ArgumentParser()
    for flag, spec in option_table:
        parser.add_argument(flag, **spec)
    return parser.parse_args(args=[])


def mkdir(path):
    """Create directory ``path`` including missing parents; no-op if it already exists.

    ``exist_ok=True`` removes the check-then-create race of testing
    ``os.path.exists`` first. Another process could create the directory in
    between, which made ``os.makedirs`` raise. If ``path`` exists but is not a
    directory, ``os.makedirs`` still raises ``FileExistsError``.
    """
    os.makedirs(path, exist_ok=True)


def weights_init(m):
    """DCGAN-style initializer, intended for ``net.apply(weights_init)``.

    * modules whose class name contains "Conv": weight ~ N(0, 0.02)
    * modules whose class name contains "BatchNorm": weight ~ N(1, 0.02), bias = 0

    Modules are matched by class-name substring. A matching module whose
    ``weight``/``bias`` is missing or ``None`` (e.g. a wrapper class with
    "Conv" in its name, or BatchNorm with ``affine=False``) is skipped instead
    of raising ``AttributeError``.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    if 'Conv' in classname:
        if weight is not None:
            weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            weight.data.normal_(1.0, 0.02)
        bias = getattr(m, 'bias', None)
        if bias is not None:
            bias.data.fill_(0)


def discriminator_train(netD, netG, data, noise, optimizerD, opt):
    """Run one WGAN critic update with weight clipping.

    Every ``netD`` parameter gets ``requires_grad = True`` and is clipped to
    ``[opt.clamp_lower, opt.clamp_upper]`` before the update. Gradients are
    seeded with the module-level ``one`` (real batch) and ``mone`` (fake
    batch). ``noise`` is resized in place to ``(opt.batchSize, opt.nz, 1, 1)``
    and refilled with N(0, 1) samples.

    Returns:
        ``(errD_real, errD_fake, errD)`` with ``errD = errD_real - errD_fake``.
    """
    lo, hi = opt.clamp_lower, opt.clamp_upper
    for param in netD.parameters():
        param.requires_grad = True
        param.data.clamp_(lo, hi)

    netD.zero_grad()

    # Real batch.
    score_real = netD(data.to(device))
    score_real.backward(one)

    # Fake batch, detached so no gradient reaches the generator.
    noise.resize_(opt.batchSize, opt.nz, 1, 1).normal_(0, 1)
    score_fake = netD(netG(noise).detach())
    score_fake.backward(mone)

    gap = score_real - score_fake
    optimizerD.step()
    return score_real, score_fake, gap


def discriminator_infer(netD, netG, data, noise, opt):
    """Score a real batch and a freshly generated fake batch with the critic, without gradients.

    ``noise`` is resized in place to ``(opt.batchSize, opt.nz, 1, 1)`` and
    refilled with N(0, 1) samples.

    Returns:
        ``(c_errD_real, c_errD_fake, c_errD_real - c_errD_fake)``.
    """
    with torch.no_grad():
        score_real = netD(data.to(device))
        noise.resize_(opt.batchSize, opt.nz, 1, 1).normal_(0, 1)
        score_fake = netD(netG(noise))
        gap = score_real - score_fake
    return score_real, score_fake, gap


def generator_train(netG, netD, noise, optimizerG, opt):
    """Run one WGAN generator update and return the critic output on the fakes.

    ``noise`` is resized in place to ``(opt.batchSize, opt.nz, 1, 1)`` and
    refilled with N(0, 1) samples. The backward pass is seeded with the
    module-level ``one`` tensor. ``netD`` is not frozen here, so its parameters
    that require grad also receive gradients.
    """
    netG.zero_grad()
    noise.resize_(opt.batchSize, opt.nz, 1, 1).normal_(0, 1)
    critic_score = netD(netG(noise))
    critic_score.backward(one)
    optimizerG.step()
    return critic_score


# === Metric helpers ===
class InceptionV3_FID(nn.Module):
    """Frozen InceptionV3 (ImageNet weights) that returns 2048-D pooled features for FID/KID.

    ``forward`` expects float images in [0, 1] with shape (N, C, H, W), where
    C is 1 or 3. Single-channel inputs are replicated to 3 channels and every
    input is resized to 299x299.

    NOTE: this preprocessing must match the one used for any cached real-image
    features (e.g. the pickles written by ``compute_and_cache_real_feats``).
    Regenerate such caches after changing it.
    """

    def __init__(self):
        super().__init__()
        inception = inception_v3(weights=Inception_V3_Weights.IMAGENET1K_V1, transform_input=False)
        inception.fc = nn.Identity()   # the classifier head is not used
        inception.eval()
        for p in inception.parameters():
            p.requires_grad = False

        # Layers up to the final average pool (pool3 -> 2048-D).
        self.features = nn.Sequential(
            inception.Conv2d_1a_3x3,
            inception.Conv2d_2a_3x3,
            inception.Conv2d_2b_3x3,
            nn.MaxPool2d(3, stride=2),
            inception.Conv2d_3b_1x1,
            inception.Conv2d_4a_3x3,
            nn.MaxPool2d(3, stride=2),
            inception.Mixed_5b,
            inception.Mixed_5c,
            inception.Mixed_5d,
            inception.Mixed_6a,
            inception.Mixed_6b,
            inception.Mixed_6c,
            inception.Mixed_6d,
            inception.Mixed_6e,
            inception.Mixed_7a,
            inception.Mixed_7b,
            inception.Mixed_7c,
            inception.avgpool,   # AdaptiveAvgPool2d -> [N, 2048, 1, 1]
        )
        # nn.Module instances start in training mode; keep the wrapper in eval
        # mode as well so BatchNorm never uses batch statistics.
        self.eval()

    def forward(self, x):
        # InceptionV3 operates on 299x299 inputs.
        x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=False)
        if x.shape[1] == 1:  # grayscale -> RGB
            x = x.repeat(1, 3, 1, 1)
        # With transform_input=False, the torchvision ImageNet weights expect
        # pixels scaled as (v - 0.5) / 0.5, i.e. [0, 1] mapped to [-1, 1].
        # transform_input=True would produce that scaling from ImageNet-
        # normalized input. Apply it here, as pytorch-fid does.
        x = 2 * x - 1
        with torch.no_grad():
            feats = self.features(x)         # [N, 2048, 1, 1]
            feats = torch.flatten(feats, 1)  # [N, 2048]
        return feats


# === FID computation ===
fid_inception = InceptionV3_FID().to(device)

def calculate_fid(real, fake, eps=1e-6):
    """Fréchet Inception Distance between two image batches.

    Args:
        real, fake: image tensors in [-1, 1]. They are mapped to [0, 1] and
            embedded with the global ``fid_inception`` extractor.
        eps: diagonal offset added to both covariances, used only when the
            matrix square root of their product is not finite.

    Returns:
        The FID as a Python float.

    NOTE(review): when a batch has fewer samples than feature dimensions
    (2048), the covariance estimates are rank-deficient and FID is strongly
    biased. Only compare values computed with the same number of samples.
    """
    real = (real + 1) / 2  # [-1,1] -> [0,1]
    fake = (fake + 1) / 2
    with torch.no_grad():
        act_real = fid_inception(real.to(device)).cpu().numpy().astype(np.float64)
        act_fake = fid_inception(fake.to(device)).cpu().numpy().astype(np.float64)

    mu_real, sigma_real = np.mean(act_real, axis=0), np.cov(act_real, rowvar=False)
    mu_fake, sigma_fake = np.mean(act_fake, axis=0), np.cov(act_fake, rowvar=False)

    diff = mu_real - mu_fake
    covmean, _ = linalg.sqrtm(sigma_real.dot(sigma_fake), disp=False)
    if not np.isfinite(covmean).all():
        # The product is (near-)singular: regularize both covariances and
        # retry, following the reference pytorch-fid implementation.
        offset = np.eye(sigma_real.shape[0]) * eps
        covmean, _ = linalg.sqrtm((sigma_real + offset).dot(sigma_fake + offset), disp=False)
    if np.iscomplexobj(covmean):
        # Discard tiny imaginary parts produced by numerical error.
        covmean = covmean.real

    fid = diff.dot(diff) + np.trace(sigma_real + sigma_fake - 2 * covmean)
    return float(fid)


# === Metric helpers ===
def calculate_metrics(real, fake):
    """Compute FID, SSIM, PSNR, LPIPS and JSD between a real and a fake batch.

    Both inputs are image tensors of shape (N, C, H, W), presumably in [-1, 1]
    (TODO confirm against the data loader / generator).

    Returns:
        dict with keys "FID", "SSIM", "PSNR", "LPIPS" and "JSD". SSIM/PSNR are
        averaged over the paired images (0.0 if there are none). "JSD" is None
        when it cannot be computed.
    """
    real_dev = real.detach().to(device)
    fake_dev = fake.detach().to(device)

    # LPIPS works on 3-channel inputs: replicate grayscale images.
    if real_dev.shape[1] == 1:
        real_3 = real_dev.repeat(1, 3, 1, 1)
        fake_3 = fake_dev.repeat(1, 3, 1, 1)
    else:
        real_3, fake_3 = real_dev, fake_dev

    with torch.no_grad():
        lpips_score = lpips_fn(real_3, fake_3).mean().item()

    # FID (StyleGAN2-ADA-style Fréchet distance on Inception features).
    fid_val = calculate_fid(real_dev, fake_dev)

    # SSIM / PSNR on uint8 NHWC images. Round and clip before the cast:
    # casting out-of-range floats to uint8 wraps around, and plain truncation
    # biases every pixel downward.
    def _to_uint8_nhwc(t):
        arr = t.cpu().numpy().transpose(0, 2, 3, 1) * 127.5 + 127.5
        return np.clip(np.rint(arr), 0, 255).astype(np.uint8)

    real_np = _to_uint8_nhwc(real_dev)
    fake_np = _to_uint8_nhwc(fake_dev)

    ssim_vals, psnr_vals = [], []
    for r, f in zip(real_np, fake_np):
        r = r.squeeze()
        f = f.squeeze()
        if r.ndim == 3:
            # Multi-channel image: without channel_axis, skimage would treat
            # the channel dimension as a third spatial axis.
            ssim_vals.append(ssim(r, f, data_range=255, channel_axis=-1))
        else:
            ssim_vals.append(ssim(r, f, data_range=255))
        psnr_vals.append(psnr(r, f, data_range=255))

    ssim_mean = float(np.mean(ssim_vals)) if ssim_vals else 0.0
    psnr_mean = float(np.mean(psnr_vals)) if psnr_vals else 0.0

    # NOTE(review): this treats the flattened pixel arrays as probability
    # vectors over pixel positions, not as intensity distributions. For a JSD
    # of intensity distributions, compare normalized histograms instead.
    # Failures (e.g. batches of different sizes -> ValueError) are reported as
    # None on purpose so metric logging never aborts training.
    try:
        jsd_val = float(jensenshannon(real_np.flatten(), fake_np.flatten()))
    except Exception:
        jsd_val = None

    return {
        "FID": fid_val,
        "SSIM": ssim_mean,
        "PSNR": psnr_mean,
        "LPIPS": lpips_score,
        "JSD": jsd_val
    }

def compute_kid(feat_real, feat_fake, num_subsets=100, max_subset_size=1000, seed=None):
    """Kernel Inception Distance: unbiased MMD^2 with a cubic polynomial kernel.

    Mirrors NVIDIA's StyleGAN2-ADA implementation. The kernel is
    k(x, y) = (x . y / d + 1) ** 3, with d the feature dimension. The estimate
    is averaged over ``num_subsets`` random subsets of size
    ``m = min(N_real, N_fake, max_subset_size)``, drawn without replacement.

    Args:
        feat_real, feat_fake: feature arrays of shape (N, d).
        num_subsets: number of random subsets to average over.
        max_subset_size: upper bound on the subset size.
        seed: optional seed for a private ``np.random.RandomState`` (for
            reproducible results). ``None`` uses the global NumPy RNG.

    Raises:
        ValueError: if ``m < 2``. The unbiased estimator divides by ``m - 1``
            and previously returned inf/nan silently.
    """
    n = feat_real.shape[1]
    m = min(feat_real.shape[0], feat_fake.shape[0], max_subset_size)
    if m < 2:
        raise ValueError(f"KID needs at least 2 samples per subset, got m={m}")
    rng = np.random if seed is None else np.random.RandomState(seed)

    t = 0.0
    for _ in range(num_subsets):
        x = feat_fake[rng.choice(feat_fake.shape[0], m, replace=False)]
        y = feat_real[rng.choice(feat_real.shape[0], m, replace=False)]
        a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3
        b = (x @ y.T / n + 1) ** 3
        # Off-diagonal within-set kernel mean minus twice the cross-set mean.
        t += (a.sum() - np.diag(a).sum()) / (m - 1) - 2.0 * b.sum() / m
    return float(t / num_subsets / m)

def _images_to_uint8(images, value_range=None):
    """Convert an NCHW image tensor to an NHWC ``uint8`` numpy array in [0, 255].

    Args:
        images: tensor of shape (N, C, H, W) with values in [-1, 1] or [0, 1].
        value_range: optional ``(low, high)`` stating the input range
            explicitly. When ``None``, the range is guessed: any negative value
            means [-1, 1], otherwise [0, 1]. The guess misreads a [-1, 1] batch
            that happens to contain no negative values, so callers that know
            the range should pass it.

    Returns:
        numpy ``uint8`` array of shape (N, H, W, C).
    """
    imgs = images.detach().cpu()
    if value_range is None:
        value_range = (-1.0, 1.0) if imgs.min() < 0.0 else (0.0, 1.0)
    low, high = value_range
    imgs = (imgs - low) * (255.0 / (high - low))
    # Round to the nearest level instead of truncating, so quantization does
    # not bias every pixel downward.
    imgs = imgs.round().clamp(0, 255)
    return imgs.numpy().transpose(0, 2, 3, 1).astype(np.uint8)

def compute_and_cache_real_feats(dataloader, cache_file, max_reals=None):
    """Return Inception features (N, 2048) of the real dataset, cached on disk.

    If ``cache_file`` exists, it is unpickled and returned as-is. ``max_reals``
    is not checked against the cached contents. Otherwise features are
    extracted batch by batch from ``dataloader`` (which must yield
    ``(images, labels)`` pairs), truncated to ``max_reals`` rows if given, and
    pickled to ``cache_file``.

    Images pass through the same uint8 quantization as generated images
    (``_images_to_uint8``) before feature extraction.

    Raises:
        ValueError: if the dataloader yields no batches.
    """
    if os.path.isfile(cache_file):
        # SECURITY: pickle.load can execute arbitrary code; only load cache
        # files produced by this pipeline.
        with open(cache_file, "rb") as f:
            return pickle.load(f)

    feat_real = []
    n_feats = 0
    print(f"üóÇÔ∏è Calculando features reales y guardando en: {cache_file}")
    for imgs, _ in dataloader:
        imgs_uint8 = _images_to_uint8(imgs)  # numpy uint8, NHWC
        # Back to float [0, 1] NCHW tensors for fid_inception (which resizes).
        imgs_t = torch.from_numpy(imgs_uint8.astype(np.float32) / 255.0).permute(0, 3, 1, 2)
        with torch.no_grad():
            feats = fid_inception(imgs_t.to(device)).cpu().numpy()  # (N, 2048)
        feat_real.append(feats)
        n_feats += feats.shape[0]
        if max_reals is not None and n_feats >= max_reals:
            break
    if not feat_real:
        raise ValueError("dataloader yielded no images; cannot compute real features")
    feat_real = np.vstack(feat_real)
    if max_reals is not None and feat_real.shape[0] > max_reals:
        feat_real = feat_real[:max_reals]

    # Write to a temporary file, then atomically move it into place, so an
    # interrupted run cannot leave a truncated cache that later fails to load.
    tmp_file = f"{cache_file}.tmp"
    with open(tmp_file, "wb") as f:
        pickle.dump(feat_real, f)
    os.replace(tmp_file, cache_file)
    return feat_real

def compute_fake_feats_from_generator(netG, num_fakes, batch_size, nz, device):
    """Generate ``num_fakes`` images with ``netG`` and return their Inception features.

    Latents are drawn as ``torch.randn(bs, nz, 1, 1)`` in chunks of at most
    ``batch_size``. Generator output is assumed to be in [-1, 1] (TODO
    confirm). Images are quantized to uint8 like the real images before
    feature extraction.

    ``netG`` runs in eval mode and is restored to its previous mode afterwards
    (also when an exception is raised). Previously it was always switched to
    train mode on exit.

    Returns:
        numpy array of shape (num_fakes, 2048).
    """
    feat_fake = []
    was_training = netG.training
    netG.eval()
    try:
        with torch.no_grad():
            n_done = 0
            while n_done < num_fakes:
                cur_bs = min(batch_size, num_fakes - n_done)
                z = torch.randn(cur_bs, nz, 1, 1, device=device)
                fake_imgs = netG(z)
                imgs_uint8 = _images_to_uint8(fake_imgs)  # numpy uint8, NHWC
                imgs_t = torch.from_numpy(imgs_uint8.astype(np.float32) / 255.0).permute(0, 3, 1, 2)
                feats = fid_inception(imgs_t.to(device)).cpu().numpy()
                feat_fake.append(feats)
                n_done += cur_bs
    finally:
        netG.train(was_training)
    return np.vstack(feat_fake)

# Wrapper para integrarlo en calculate_metrics o en el loop de checkpoints
def calculate_kid_stylegan(real_dataloader, netG, opt, cache_name="inception_feats_real.pkl",
                           num_fakes=None, num_subsets=100, max_subset_size=1000):
    """Compute KID following the StyleGAN2-ADA recipe.

    1. Load (or compute and cache) Inception features for all real images.
       The cache lives at ``os.path.join(opt.checkpoint_dir, cache_name)``.
    2. Generate ``num_fakes`` images with ``netG`` (defaults to the number of
       real feature rows) and extract their features, batching by
       ``opt.batchSize`` with latent size ``opt.nz``.
    3. Hand both feature sets to ``compute_kid`` with ``num_subsets`` and
       ``max_subset_size``.

    Returns:
        Whatever ``compute_kid`` returns for the two feature sets.
    """
    real_feats = compute_and_cache_real_feats(
        real_dataloader,
        os.path.join(opt.checkpoint_dir, cache_name),
        max_reals=None,
    )

    n_fake = real_feats.shape[0] if num_fakes is None else num_fakes

    fake_feats = compute_fake_feats_from_generator(
        netG,
        num_fakes=n_fake,
        batch_size=opt.batchSize,
        nz=opt.nz,
        device=device,
    )

    return compute_kid(
        real_feats,
        fake_feats,
        num_subsets=num_subsets,
        max_subset_size=max_subset_size,
    )


# ================= MAIN =================
if __name__ == '__main__':
    opt = argsget()

    # === Directories ===
    base_dir = opt.experiment
    checkpoint_root = os.path.join(base_dir, "CheckpointsE1")
    generated_root = os.path.join(base_dir, "generatedE1")
    record_dir = os.path.join(base_dir, "record")

    mkdir(checkpoint_root)
    mkdir(generated_root)
    mkdir(record_dir)

    # Expose the output path on opt so it can be used further down.
    opt.generated_dir = generated_root

    # === Checkpoint sessions (every run creates a NEW numbered folder) ===
    # Only "<digits>_checkpoint" directories count as sessions. Sorting them
    # numerically keeps the order right, and a stray "*_checkpoint" name cannot
    # break the int() parsing of the newest session index below.
    existing_sessions = sorted(
        (d for d in os.listdir(checkpoint_root)
         if re.fullmatch(r"\d+_checkpoint", d)
         and os.path.isdir(os.path.join(checkpoint_root, d))),
        key=lambda d: int(d.split("_")[0]),
    )

    resume_dir = None
    start_epoch = 0

    # === Global loss history (shared by all sessions) ===
    loss_file = os.path.join(base_dir, "loss_history.json")
    if os.path.exists(loss_file):
        with open(loss_file, "r") as f:
            loss_history = json.load(f)
    else:
        loss_history = []

    # === Random seed (a fresh random seed is drawn on every run) ===
    opt.manualSeed = random.randint(1, 10000)
    print(f"üé≤ Random Seed: {opt.manualSeed}")

    random.seed(opt.manualSeed)
    np.random.seed(opt.manualSeed)  # seed numpy too, not only python/torch
    torch.manual_seed(opt.manualSeed)
    cudnn.benchmark = True

    dataloader = dstget(opt)

    # === Models ===
    if opt.noBN:
        netG = DCGAN_G_nobn(opt.img_size, opt.nz, opt.nc, opt.ngf, opt.ngpu, opt.n_extra_layers)
    elif opt.mlp_G:
        netG = MLP_G(opt.img_size, opt.nz, opt.nc, opt.ngf, opt.ngpu)
    else:
        netG = DCGAN_G(opt.img_size, opt.nz, opt.nc, opt.ngf, opt.ngpu, opt.n_extra_layers)
    netG.apply(weights_init)
    netG.to(device)

    if opt.mlp_D:
        netD = MLP_D(opt.img_size, opt.nz, opt.nc, opt.ndf, opt.ngpu)
    else:
        netD = DCGAN_D(opt.img_size, opt.nz, opt.nc, opt.ndf, opt.ngpu, opt.n_extra_layers)
    netD.apply(weights_init)
    netD.to(device)

    def extract_epochs(files, prefix):
        """Return the sorted epoch numbers parsed from '<prefix>_epoch<N>.pth' file names."""
        epochs = []
        for f in files:
            m = re.search(rf"{prefix}_epoch(\d+)\.pth", os.path.basename(f))
            if m:
                epochs.append(int(m.group(1)))
        return sorted(epochs)

    def common_checkpoint_epochs(session_dir):
        """Return, ascending, the epochs that have BOTH a netG and a netD checkpoint in session_dir."""
        G_epochs = extract_epochs(glob.glob(os.path.join(session_dir, "netG_epoch*.pth")), "netG")
        D_epochs = extract_epochs(glob.glob(os.path.join(session_dir, "netD_epoch*.pth")), "netD")
        return sorted(set(G_epochs).intersection(D_epochs))

    # === Resume from the newest consistent checkpoint (G and D of the same epoch) ===
    if existing_sessions:
        last_session = existing_sessions[-1]

        # Walk sessions newest -> oldest. Every run creates its (empty) session
        # folder up front, so if the previous run died before its first
        # checkpoint, the newest folder is empty and the progress lives in an
        # older one. Inspecting only the newest folder would restart from 0.
        for session in reversed(existing_sessions):
            candidate_dir = os.path.join(checkpoint_root, session)
            common_epochs = common_checkpoint_epochs(candidate_dir)
            if common_epochs:
                resume_dir = candidate_dir
                start_epoch = common_epochs[-1]
                break

        if resume_dir is not None:
            print(f"‚úÖ Carpeta de checkpoints encontrada: {resume_dir}")
            netG.load_state_dict(torch.load(os.path.join(resume_dir, f"netG_epoch{start_epoch}.pth"), map_location=device))
            netD.load_state_dict(torch.load(os.path.join(resume_dir, f"netD_epoch{start_epoch}.pth"), map_location=device))
            print(f"üîÑ Reanudando desde checkpoint epoch {start_epoch} en {resume_dir}")
        else:
            print("‚ö†Ô∏è No se encontraron checkpoints coincidentes. Entrenamiento iniciar√° desde epoch 0.")

        # The new session folder is numbered after the newest existing one.
        new_idx = int(last_session.split("_")[0]) + 1
    else:
        print("üÜï No hab√≠a checkpoints previos. Entrenamiento desde cero.")
        new_idx = 0

    session_name = f"{new_idx:05d}_checkpoint"
    opt.checkpoint_dir = os.path.join(checkpoint_root, session_name)
    mkdir(opt.checkpoint_dir)
    print(f"üìÇ Nueva carpeta de checkpoints creada: {opt.checkpoint_dir}")

    # Epochs after the resume point are trained (and logged) again. Drop their
    # stale entries so the global history keeps a single row per epoch.
    n_logged = len(loss_history)
    loss_history = [h for h in loss_history if h.get("epoch", 0) <= start_epoch]
    if len(loss_history) < n_logged:
        print(f"Historial de perdidas: se descartan {n_logged - len(loss_history)} entradas posteriores a la epoch {start_epoch}")

    # === Fixed inputs ===
    noise = torch.FloatTensor(opt.batchSize, opt.nz, 1, 1).to(device)
    fixed_noise = torch.FloatTensor(1, opt.nz, 1, 1).normal_(0, 1).to(device)  # a single image

    # === Optimizers ===
    if opt.adam:
        optimizerD = optim.Adam(netD.parameters(), lr=opt.lrD, betas=(opt.beta1, 0.999))
        optimizerG = optim.Adam(netG.parameters(), lr=opt.lrG, betas=(opt.beta1, 0.999))
    else:
        optimizerD = optim.RMSprop(netD.parameters(), lr=opt.lrD)
        optimizerG = optim.RMSprop(netG.parameters(), lr=opt.lrG)

    gen_iterations = 0
    total_D, total_G = 0, 0
    T1, T2 = opt.T1, opt.T2
    # Latest D / G losses. None until the first update of each network.
    # Explicit initialisation avoids picking up stale values from locals().
    errD = None
    errG = None

    # === Save the initial image ONLY when starting from epoch 0 ===
    if start_epoch == 0:
        with torch.no_grad():
            init_fake = netG(fixed_noise).detach()
        save_image((init_fake.cpu() * 0.5 + 0.5),
                  os.path.join(opt.generated_dir, "generated_0.png"),
                  nrow=1, normalize=False)
        print("üñºÔ∏è Imagen inicial generada guardada como generated_0.png")

    # === Training ===
    for epoch in range(start_epoch, opt.niter):
        data_iter = iter(dataloader)
        i = 0
        while i < len(dataloader):
            if gen_iterations < 2:  # WGAN warmup: up to 100 D steps per G step
                Diters = 100
                j = 0
                while j < Diters and i < len(dataloader):
                    j += 1
                    total_D += 1
                    imgs, _ = next(data_iter)
                    imgs = imgs.to(device)
                    i += 1
                    errD_real, errD_fake, errD = discriminator_train(netD, netG, imgs, noise, optimizerD, opt)

                total_G += 1
                # Freeze D while updating G.
                for p in netD.parameters():
                    p.requires_grad = False
                errG = generator_train(netG, netD, noise, optimizerG, opt)
                gen_iterations += 1
            else:  # AdaGAN mode: thresholds T1/T2 decide whether D or G is updated
                imgs, _ = next(data_iter)
                imgs = imgs.to(device)
                i += 1

                c_errD_real, c_errD_fake, c_errD = discriminator_infer(netD, netG, imgs, noise, opt)
                c_errG = c_errD_fake.item()

                if not (c_errD_real.item() < c_errD_fake.item() - T2 and c_errG > T1):
                    total_D += 1
                    errD_real, errD_fake, errD = discriminator_train(netD, netG, imgs, noise, optimizerD, opt)
                else:
                    total_G += 1
                    for p in netD.parameters():
                        p.requires_grad = False
                    errG = generator_train(netG, netD, noise, optimizerG, opt)
                    gen_iterations += 1

        # === Save losses every epoch + print progress ===
        epoch_losses = {
            "epoch": epoch + 1,
            "loss_D": float(errD.item()) if errD is not None else None,
            "loss_G": float(errG.item()) if errG is not None else None
        }
        loss_history.append(epoch_losses)
        with open(loss_file, "w") as f:
            json.dump(loss_history, f, indent=4)

        d_str = f"{epoch_losses['loss_D']:.4f}" if epoch_losses['loss_D'] is not None else "nan"
        g_str = f"{epoch_losses['loss_G']:.4f}" if epoch_losses['loss_G'] is not None else "nan"
        print(f"[√âpoca {epoch+1}/{opt.niter}] Loss_D: {d_str} | Loss_G: {g_str}")

        # === Save checkpoints, images and metrics every 5 epochs ===
        if (epoch + 1) % 5 == 0:
            # Save models
            torch.save(netG.state_dict(), os.path.join(opt.checkpoint_dir, f"netG_epoch{epoch+1}.pth"))
            torch.save(netD.state_dict(), os.path.join(opt.checkpoint_dir, f"netD_epoch{epoch+1}.pth"))

            # Generate ONE example image (fixed_noise)
            with torch.no_grad():
                fake_example = netG(fixed_noise).detach()
            save_image((fake_example.cpu() * 0.5 + 0.5),
                       os.path.join(opt.generated_dir, f"generated_{epoch+1}.png"),
                       nrow=1, normalize=False)

            # "Global" KID over the whole dataset and the generator
            kid_val = calculate_kid_stylegan(dataloader, netG, opt, cache_name="inception_feats_real.pkl", num_fakes=min(10000, len(dataloader)*opt.batchSize), num_subsets=100, max_subset_size=1000)

            # Remaining metrics on one real batch and an equally sized fake batch
            real_batch, _ = next(iter(dataloader))
            real_batch = real_batch.to(device)
            z = torch.randn(real_batch.size(0), opt.nz, 1, 1, device=device)
            # Evaluation only: no autograd graph is needed for the fake batch.
            with torch.no_grad():
                fake_batch = netG(z)
            # Computed once and reused for every metric below.
            other_metrics = calculate_metrics(real_batch, fake_batch)
            metrics = {
                "epoch": epoch + 1,
                "KID": kid_val,
                "FID": other_metrics["FID"],
                "LPIPS": other_metrics["LPIPS"],
                "SSIM": other_metrics["SSIM"],
                "PSNR": other_metrics["PSNR"],
                "JSD": other_metrics["JSD"]
            }

            # Append to this session's accumulated metrics file
            metrics_file = os.path.join(opt.checkpoint_dir, "metrics_history.json")
            if os.path.exists(metrics_file):
                with open(metrics_file, "r") as f:
                    metrics_history = json.load(f)
            else:
                metrics_history = []
            metrics_history.append(metrics)
            with open(metrics_file, "w") as f:
                json.dump(metrics_history, f, indent=4)

            print(f"‚úÖ Checkpoint, imagen y m√©tricas guardadas en √©poca {epoch+1} ‚Äî Carpeta: {opt.checkpoint_dir}")

Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.12/dist-packages/lpips/weights/v0.1/alex.pth
üé≤ Random Seed: 7582
‚úÖ Carpeta de checkpoints encontrada: /content/drive/MyDrive/Proyecto_Grado/MedGAN/ADA/CheckpointsE1/00000_checkpoint
üîÑ Reanudando desde checkpoint epoch 70 en /content/drive/MyDrive/Proyecto_Grado/MedGAN/ADA/CheckpointsE1/00000_checkpoint
üìÇ Nueva carpeta de checkpoints creada: /content/drive/MyDrive/Proyecto_Grado/MedGAN/ADA/CheckpointsE1/00001_checkpoint
[√âpoca 71/2000] Loss_D: -1.0943 | Loss_G: 1.1747
[√âpoca 72/2000] Loss_D: -0.8682 | Loss_G: 1.2210
[√âpoca 73/2000] Loss_D: -0.8029 | Loss_G: 0.8111
[√âpoca 74/2000] Loss_D: -0.3731 | Loss_G: 0.7124
[√âpoca 75/2000] Loss_D: -0.6022 | Loss_G: 1.1260
üóÇÔ∏è Calculando features reales y guardando en: /content/drive/MyDrive/Proyecto_Grado/MedGAN/ADA/CheckpointsE1/00001_checkpoint/inception_feats_real.pkl


  covmean, _ = linalg.sqrtm(sigma_real.dot(sigma_fake), disp=False)


‚úÖ Checkpoint, imagen y m√©tricas guardadas en √©poca 75 ‚Äî Carpeta: /content/drive/MyDrive/Proyecto_Grado/MedGAN/ADA/CheckpointsE1/00001_checkpoint
[√âpoca 76/2000] Loss_D: -1.1357 | Loss_G: 1.0560
[√âpoca 77/2000] Loss_D: -0.8697 | Loss_G: 0.8433
[√âpoca 78/2000] Loss_D: -0.4001 | Loss_G: 0.3830
[√âpoca 79/2000] Loss_D: -0.9265 | Loss_G: 0.6249
[√âpoca 80/2000] Loss_D: -0.9373 | Loss_G: 1.1962
‚úÖ Checkpoint, imagen y m√©tricas guardadas en √©poca 80 ‚Äî Carpeta: /content/drive/MyDrive/Proyecto_Grado/MedGAN/ADA/CheckpointsE1/00001_checkpoint
[√âpoca 81/2000] Loss_D: -0.7548 | Loss_G: 0.6167
[√âpoca 82/2000] Loss_D: -0.4359 | Loss_G: 0.3954
[√âpoca 83/2000] Loss_D: -0.7421 | Loss_G: 1.1746
[√âpoca 84/2000] Loss_D: -0.8215 | Loss_G: 1.0196
[√âpoca 85/2000] Loss_D: -0.4871 | Loss_G: 0.5437
‚úÖ Checkpoint, imagen y m√©tricas guardadas en √©poca 85 ‚Äî Carpeta: /content/drive/MyDrive/Proyecto_Grado/MedGAN/ADA/CheckpointsE1/00001_checkpoint
[√âpoca 86/2000] Loss_D: -0.5356 | Loss_G: 0.74

### Métricas

In [None]:
# Base path
base_dir = EXPERIMENTS_DIR

# Global loss history file
loss_file = os.path.join(base_dir, "loss_history.json")

# Root folder holding the numbered checkpoint sessions
checkpoint_root = os.path.join(base_dir, "CheckpointsE1")


def _one_row_per_epoch(frame):
    """Return ``frame`` with a single row per epoch (the last occurrence wins).

    A resumed run re-trains (and re-logs) the epochs after its checkpoint, so
    the raw histories can list the same epoch more than once. Merging such
    duplicates would multiply rows. A frame without an ``epoch`` column (e.g.
    built from an empty history) is replaced by an empty frame with an integer
    ``epoch`` column so the merge below does not raise ``KeyError``.
    """
    if "epoch" not in frame.columns:
        return pd.DataFrame({"epoch": pd.Series(dtype="int64")})
    return frame.drop_duplicates(subset="epoch", keep="last")


# --- Load losses ---
loss_history = []
if os.path.exists(loss_file):
    with open(loss_file, "r") as f:
        loss_history = json.load(f)
df_loss = _one_row_per_epoch(pd.DataFrame(loss_history))

# --- Load ALL metrics from every checkpoint session (oldest -> newest) ---
metrics_history = []
if os.path.exists(checkpoint_root):
    sessions = [d for d in os.listdir(checkpoint_root) if d.endswith("_checkpoint")]
    sessions.sort()
    for session in sessions:
        metrics_file = os.path.join(checkpoint_root, session, "metrics_history.json")
        if os.path.exists(metrics_file):
            with open(metrics_file, "r") as f:
                metrics = json.load(f)
                metrics_history.extend(metrics)

# Newer sessions come later, so keep="last" prefers the most recent values.
df_metrics = _one_row_per_epoch(pd.DataFrame(metrics_history))

# --- Merge losses + metrics ---
df = pd.merge(df_loss, df_metrics, on="epoch", how="outer").sort_values("epoch")

# --- Reorder columns (missing ones are added empty) ---
ordered_cols = ["epoch", "loss_G", "loss_D", "FID", "KID", "JSD", "SSIM", "PSNR", "LPIPS"]
for col in ordered_cols:
    if col not in df.columns:
        df[col] = None

df = df[ordered_cols]

# --- Replace NaN with empty strings ---
df = df.fillna("")

# --- Save CSV ---
csv_path = os.path.join(base_dir, "metrics_medganE4.csv")
df.to_csv(csv_path, index=False)

print(f"‚úÖ CSV exportado en: {csv_path}")
df.tail()