In [None]:
import os
import glob
import json
import pandas as pd
import subprocess

import os, glob, json, re
import pandas as pd
import numpy as np
from PIL import Image
import torch
import torchvision.transforms as T
import skimage.metrics
import lpips
from scipy.spatial.distance import jensenshannon
import torchvision.models as models
import torch.nn as nn

from pathlib import Path
import sys

# Project root, assuming the notebook runs from <project>/notebooks.
PROJECT_ROOT = Path.cwd().parent
STYLEGAN2_ADA_DIR = PROJECT_ROOT / "external" / "stylegan2-ada"

# dnnlib / legacy are imported from the StyleGAN2-ADA checkout. Only add the
# path once so re-running this cell does not keep growing sys.path.
if str(STYLEGAN2_ADA_DIR) not in sys.path:
    sys.path.append(str(STYLEGAN2_ADA_DIR))
import dnnlib
import legacy

# Compute device: GPU when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# String-based project paths consumed by os.path.join below. Assumes the
# notebook's working directory is <project>/notebooks, so the root is "..".
# NOTE: this re-binds PROJECT_ROOT (a pathlib.Path in the import cell) to a str.
PROJECT_ROOT = os.path.abspath("..")
REPO_DIR, DATA_DIR = (
    os.path.join(PROJECT_ROOT, *parts)
    for parts in (("external", "stylegan2-ada"), ("data", "processed"))
)


In [None]:
# ---- Dataset and output locations ----
DATASET_ZIP = os.path.join(DATA_DIR, "frames_extraidos.zip")
assert os.path.isfile(DATASET_ZIP)

RESULTS_DIR = os.path.join(REPO_DIR, "results", "experiment_E2")
PROGRESS_DIR = os.path.join(REPO_DIR, "progress", "experiment_E2")
os.makedirs(PROGRESS_DIR, exist_ok=True)

# ---- Choose the snapshot to resume from ----
# Run directories and snapshot names are zero-padded, so the lexicographically
# last path is the highest-kimg snapshot of the latest run that has one.
snapshots = sorted(glob.glob(os.path.join(RESULTS_DIR, "0*/network-snapshot-*.pkl")))
use_resume = True

resume_path = snapshots[-1] if (use_resume and snapshots) else None
print(f"Resuming from: {resume_path}" if resume_path else "Training from scratch")

## Train

In [None]:
# Launch StyleGAN2-ADA training as a subprocess.
# - sys.executable: run train.py with this kernel's interpreter/environment
#   (a bare "python" may resolve to a different installation).
# - cwd=REPO_DIR: "train.py" is relative to the StyleGAN2-ADA checkout, while
#   the notebook runs from notebooks/. All other paths passed here are absolute.
cmd = [
    sys.executable, "train.py",
    f"--outdir={RESULTS_DIR}",
    f"--data={DATASET_ZIP}",
    "--gpus=1",
    "--batch=32",
    "--cfg=11gb-gpu",
    "--mirror=0",
    "--gamma=32",
    "--aug=ada",
    "--target=0.6",
    "--lrate=0.006",
    "--snap=4"
]

if resume_path:
    cmd.append(f"--resume={resume_path}")

subprocess.run(cmd, check=True, cwd=REPO_DIR)


## Metric Calculation

In [None]:
# ================== KID HELPER ==================
# InceptionV3 feature extractor for KID. The classification layer is replaced by
# an identity, so the eval-mode forward pass returns the 2048-d pooled features.
# NOTE: this is torchvision's ImageNet InceptionV3, not the TF network used by the
# official FID/KID code, so values are only comparable within this notebook.
inception = models.inception_v3(weights=models.Inception_V3_Weights.IMAGENET1K_V1, transform_input=False)
inception.fc = nn.Identity()  # drop the classification head
inception.eval().to(device)


def get_inception_features(x):
    """Return InceptionV3 pooled features for a batch of images.

    Args:
        x: image tensor indexed as [B, C, H, W] (3 channels assumed), on
            ``device``. Values are expected in either [-1, 1] or [0, 1]. The
            range is guessed from ``x.min()``: any negative value means
            [-1, 1], otherwise [0, 1] is assumed.

    Returns:
        Tensor of shape [B, 2048].
    """
    # With transform_input=False the IMAGENET1K_V1 weights expect inputs scaled
    # to [-1, 1] (torchvision's transform_input converts ImageNet-normalised
    # input to exactly that range). Bring [0, 1] inputs into [-1, 1].
    if x.min() >= 0:
        x = x * 2 - 1
    if x.shape[2] != 299 or x.shape[3] != 299:
        x = torch.nn.functional.interpolate(x, size=(299, 299), mode='bilinear', align_corners=False)
    with torch.no_grad():
        feats = inception(x)  # [B, 2048]
    return feats


def polynomial_mmd(x, y, degree=3, gamma=None, coef0=1, unbiased=True):
    """Squared MMD between feature sets ``x`` [n, d] and ``y`` [m, d].

    Uses the polynomial kernel k(a, b) = (gamma * a.b + coef0) ** degree, as in
    KID (Binkowski et al., 2018).

    Args:
        x, y: 2-D feature tensors with the same number of columns.
        degree, coef0: polynomial kernel parameters.
        gamma: kernel scale; defaults to 1 / d.
        unbiased: when True (the KID definition), the diagonal self-similarity
            terms of K_xx and K_yy are excluded. If either set has fewer than
            2 rows, the unbiased estimator is undefined and the biased one
            (plain means, diagonal included) is returned instead.

    Returns:
        0-d tensor with the MMD^2 estimate (the unbiased one can be negative).
    """
    if gamma is None:
        gamma = 1.0 / x.shape[1]
    K_xx = (gamma * (x @ x.t()) + coef0) ** degree
    K_yy = (gamma * (y @ y.t()) + coef0) ** degree
    K_xy = (gamma * (x @ y.t()) + coef0) ** degree

    n, m = x.shape[0], y.shape[0]
    if unbiased and n > 1 and m > 1:
        # Drop k(x_i, x_i) / k(y_j, y_j): they bias the estimate upwards.
        term_xx = (K_xx.sum() - K_xx.diagonal().sum()) / (n * (n - 1))
        term_yy = (K_yy.sum() - K_yy.diagonal().sum()) / (m * (m - 1))
    else:
        term_xx = K_xx.mean()
        term_yy = K_yy.mean()
    return term_xx + term_yy - 2 * K_xy.mean()

def _inception_features_in_batches(imgs, batch_size):
    """Concatenate ``imgs`` in chunks of ``batch_size`` and return their stacked Inception features."""
    feats = [
        get_inception_features(torch.cat(imgs[i:i + batch_size]).to(device))
        for i in range(0, len(imgs), batch_size)
    ]
    return torch.cat(feats, dim=0)


def compute_kid(real_imgs, fake_imgs, batch_size=16):
    """Kernel Inception Distance between two lists of image tensors.

    Args:
        real_imgs, fake_imgs: lists of image tensors, each indexed
            [N_i, C, H, W] (the metrics loop passes single [1, 3, H, W]
            images). The two lists may have different lengths.
        batch_size: number of list items concatenated per Inception pass.

    Returns:
        float: polynomial-kernel MMD^2 between the two feature sets.

    Raises:
        ValueError: if either list is empty.
    """
    if len(real_imgs) == 0 or len(fake_imgs) == 0:
        raise ValueError("compute_kid needs at least one real and one fake image")
    # Batch each side independently so unequal list lengths are handled.
    real_feats = _inception_features_in_batches(real_imgs, batch_size)
    fake_feats = _inception_features_in_batches(fake_imgs, batch_size)
    return polynomial_mmd(real_feats, fake_feats).item()


# === PPL helper ===
def _slerp(a, b, t):
    """Spherical interpolation between the rows of ``a`` and ``b``; returns unit-norm rows.

    ``t`` must broadcast against the rows (e.g. shape [N, 1]). Values outside
    [0, 1] extrapolate along the same great circle.
    """
    a = a / a.norm(dim=-1, keepdim=True)
    b = b / b.norm(dim=-1, keepdim=True)
    # Clamp keeps acos defined when rounding pushes the cosine past +-1.
    cos_ab = (a * b).sum(dim=-1, keepdim=True).clamp(-1.0, 1.0)
    angle = t * torch.acos(cos_ab)
    ortho = b - cos_ab * a
    ortho = ortho / ortho.norm(dim=-1, keepdim=True)
    out = a * torch.cos(angle) + ortho * torch.sin(angle)
    return out / out.norm(dim=-1, keepdim=True)


def compute_ppl(G, device, n_samples=64, eps=1e-4, lpips_model=None, truncation_psi=0.7):
    """Approximate perceptual path length (PPL) of generator ``G`` in Z space.

    Steps:
        1. Draw ``n_samples`` random latent pairs.
        2. Pick a random position ``t`` on the slerp path between each pair.
        3. Render images at ``t`` and ``t + eps``.
        4. Average the LPIPS distance between each image pair, divided by ``eps**2``.

    NOTE: the official StyleGAN2 PPL uses a VGG16-based LPIPS, no truncation and
    1st/99th-percentile filtering. Values here are only comparable across
    snapshots evaluated with this same function (and the same RNG seed).

    Args:
        G: generator exposing ``z_dim``, ``c_dim`` and callable as
            ``G(z, c, truncation_psi=..., noise_mode=...)``. Its output is
            clamped to [-1, 1].
        device: torch device used for sampling.
        n_samples: number of latent pairs.
        eps: interpolation step along the slerp path.
        lpips_model: optional pre-built LPIPS module. When None, a new
            ``lpips.LPIPS(net='alex')`` is created, which is slow; pass one
            in when calling repeatedly.
        truncation_psi: truncation passed to ``G``.

    Returns:
        float: mean of LPIPS(img_t, img_{t+eps}) / eps**2.
    """
    if lpips_model is None:
        lpips_model = lpips.LPIPS(net='alex').to(device)

    with torch.no_grad():
        z0 = torch.randn([n_samples, G.z_dim], device=device)
        z1 = torch.randn([n_samples, G.z_dim], device=device)
        t = torch.rand([n_samples, 1], device=device)
        # slerp yields unit vectors; rescale to the typical norm of an N(0, I) sample.
        scale = G.z_dim ** 0.5
        z_t0 = _slerp(z0, z1, t) * scale
        z_t1 = _slerp(z0, z1, t + eps) * scale
        c = torch.zeros([n_samples, G.c_dim], device=device)

        imgs0 = G(z_t0, c, truncation_psi=truncation_psi, noise_mode='const').clamp(-1, 1)
        imgs1 = G(z_t1, c, truncation_psi=truncation_psi, noise_mode='const').clamp(-1, 1)

        # Compare at 256x256 (cheaper for LPIPS).
        if imgs0.shape[-2:] != (256, 256):
            imgs0 = torch.nn.functional.interpolate(imgs0, size=(256, 256), mode='bilinear', align_corners=False)
            imgs1 = torch.nn.functional.interpolate(imgs1, size=(256, 256), mode='bilinear', align_corners=False)

        # Images stay in [-1, 1], the range lpips.LPIPS expects by default.
        d = lpips_model(imgs0, imgs1)

    return (d / (eps ** 2)).mean().item()

### Metrics CSV

In [None]:
# ================== CONFIGURATION ==================
# Dataset size, used to convert kimg (thousands of images shown) into epochs.
N_images = 40120
mirror = 0          # 1 if training used --mirror=1, 0 otherwise
N_eff = N_images * (2 if mirror == 1 else 1)

# Fixed seed: every snapshot is rendered from the same latent and PPL is sampled
# from the same RNG state, so per-snapshot metrics are comparable over training.
sample_seed = 0

# Paths
# NOTE(review): these are the Colab/Drive locations, not the RESULTS_DIR /
# PROGRESS_DIR used by the local training cell. Confirm which one holds the runs.
results_dir = "/content/drive/MyDrive/Proyecto_Grado/colab-sg2-ada-pytorch/resultsE2"
progreso_dir = "/content/drive/MyDrive/Proyecto_Grado/colab-sg2-ada-pytorch/progresoE2"
data_dir = "/content/drive/MyDrive/Proyecto_Grado/Data"
os.makedirs(progreso_dir, exist_ok=True)

# Snapshot file names encode the run-local kimg.
SNAPSHOT_RE = re.compile(r'network-snapshot-(\d+)\.pkl$')

# Real reference image: the first frame (by file name) of the dataset.
real_image_candidates = sorted(glob.glob(f'{data_dir}/frames_extraidos/*.png'))
if not real_image_candidates:
    raise FileNotFoundError(f"No PNG frames found in {data_dir}/frames_extraidos")
real_image_path = real_image_candidates[0]
target_size = (512, 512)

# Image -> float CHW array in [0, 1]
real_img = Image.open(real_image_path).convert('RGB').resize(target_size)
transform = T.ToTensor()
x_real = transform(real_img).numpy()

# Output CSV
out_csv = os.path.join(progreso_dir, "metrics_snapshots.csv")

# ================== METRICS ==================
# LPIPS (AlexNet backbone). Called with normalize=True below because our
# tensors are in [0, 1] while LPIPS expects [-1, 1] by default.
loss_fn_lpips = lpips.LPIPS(net='alex').to(device)


def compute_jsd(p, q, bins=256):
    """Jensen-Shannon distance (scipy's sqrt of the divergence) between the
    pixel-value histograms of ``p`` and ``q`` over the range [0, 1]."""
    p_hist, _ = np.histogram(p.flatten(), bins=bins, range=(0, 1), density=True)
    q_hist, _ = np.histogram(q.flatten(), bins=bins, range=(0, 1), density=True)
    return jensenshannon(p_hist, q_hist)


def _stat_mean(entry, key):
    """Return ``entry[key]['mean']`` when the stat is a dict, else the raw value."""
    value = entry[key]
    return value['mean'] if isinstance(value, dict) else value


def _iter_jsonl(path):
    """Yield parsed objects from a .jsonl file, skipping blank or truncated lines."""
    with open(path, 'r') as f:
        for line in f:
            if not line.strip():
                continue
            try:
                yield json.loads(line)
            except json.JSONDecodeError:
                continue  # e.g. last line cut off by an interrupted run


def _load_losses(run_dir):
    """List of (kimg, G_loss, D_loss) from run_dir/stats.jsonl; [] if the file is missing."""
    stats_path = os.path.join(run_dir, 'stats.jsonl')
    if not os.path.exists(stats_path):
        return []
    losses = []
    for e in _iter_jsonl(stats_path):
        try:
            losses.append((
                float(_stat_mean(e, 'Progress/kimg')),
                _stat_mean(e, 'Loss/G/loss'),
                _stat_mean(e, 'Loss/D/loss'),
            ))
        except KeyError:
            continue
    return losses


def _load_fid(run_dir):
    """Map snapshot file name -> FID from metric-fid50k_full.jsonl; {} if the file is missing."""
    fid_path = os.path.join(run_dir, "metric-fid50k_full.jsonl")
    fids = {}
    if not os.path.exists(fid_path):
        return fids
    for e in _iter_jsonl(fid_path):
        try:
            fids[os.path.basename(e["snapshot_pkl"])] = e["results"]["fid50k_full"]
        except (KeyError, TypeError):
            continue
    return fids


def _snapshot_kimg(path):
    """Run-local kimg parsed from a network-snapshot-XXXXXX.pkl path, or None."""
    match = SNAPSHOT_RE.search(os.path.basename(path))
    return int(match.group(1)) if match else None


def _render_snapshot_image(G):
    """Render the ``sample_seed`` latent with G and return a target_size RGB PIL image."""
    z = torch.from_numpy(np.random.RandomState(sample_seed).randn(1, G.z_dim)).to(device=device, dtype=torch.float32)
    label = torch.zeros([1, G.c_dim], device=device)
    with torch.no_grad():
        img = G(z, label, truncation_psi=0.7, noise_mode='const')
    img = (img.clamp(-1, 1) + 1) * 127.5
    img = img.permute(0, 2, 3, 1)[0].cpu().numpy().astype(np.uint8)
    if img.shape[-1] == 1:
        img = np.repeat(img, 3, axis=-1)  # grayscale -> RGB
    return Image.fromarray(img).resize(target_size)


# ================== MAIN LOOP ==================
all_rows = []
# Every run directory under results_dir. A narrower pattern such as '0000*'
# only matches run ids 00000-00009, silently dropping later (resumed) runs.
run_dirs = sorted(d for d in glob.glob(f'{results_dir}/[0-9]*') if os.path.isdir(d))

kimg_offset = 0  # cumulative kimg reached by all previous runs

for run_dir in run_dirs:
    run_id = os.path.basename(run_dir)
    stats_lines = _load_losses(run_dir)
    fid_dict = _load_fid(run_dir)
    snapshot_paths = sorted(glob.glob(f'{run_dir}/network-snapshot-*.pkl'))

    for snapshot in snapshot_paths:
        kimg_snap_local = _snapshot_kimg(snapshot)  # kimg relative to this run
        if kimg_snap_local is None:
            continue
        snap_name = os.path.basename(snapshot)
        kimg_snap = kimg_snap_local + kimg_offset   # absolute, continuous kimg
        epoch = (kimg_snap * 1000) / N_eff

        # Losses from the stats.jsonl entry closest to this snapshot's local kimg.
        g_loss, d_loss = None, None
        if stats_lines:
            _, g_loss, d_loss = min(stats_lines, key=lambda s: abs(s[0] - kimg_snap_local))

        # NOTE: unpickling executes code; only load snapshots produced by our own runs.
        with dnnlib.util.open_url(snapshot) as f_model:
            G = legacy.load_network_pkl(f_model)['G_ema'].to(device)

        img_pil = _render_snapshot_image(G)
        x_fake = transform(img_pil).numpy()

        # === SSIM / PSNR (inputs in [0, 1]) ===
        ssim_val = skimage.metrics.structural_similarity(
            x_real.transpose(1, 2, 0),
            x_fake.transpose(1, 2, 0),
            channel_axis=-1,
            data_range=1.0
        )
        psnr_val = skimage.metrics.peak_signal_noise_ratio(
            x_real, x_fake, data_range=1.0
        )

        # === LPIPS ===
        x_real_torch = torch.tensor(x_real).unsqueeze(0).to(device)
        x_fake_torch = torch.tensor(x_fake).unsqueeze(0).to(device)
        with torch.no_grad():
            lpips_val = loss_fn_lpips(x_real_torch, x_fake_torch, normalize=True).item()

        # === JSD ===
        jsd_val = compute_jsd(x_real, x_fake)

        # === FID (only if present in metric-fid50k_full.jsonl) ===
        fid_val = fid_dict.get(snap_name)

        # === KID ===
        # NOTE: with a single real/fake pair this is a kernel distance between two
        # images, not a distribution-level KID; treat it as a rough indicator only.
        kid_val = compute_kid([x_real_torch.cpu()], [x_fake_torch.cpu()])

        # === PPL === (re-seeded so every snapshot draws the same latents)
        torch.manual_seed(sample_seed)
        ppl_val = compute_ppl(G, device, n_samples=64, eps=1e-4)

        all_rows.append([
            epoch, g_loss, d_loss, fid_val, kid_val, ssim_val, psnr_val,
            lpips_val, jsd_val, ppl_val, kimg_snap, run_id
        ])

    # Advance the offset by this run's last snapshot, which is where a resumed run continues from.
    local_kimgs = [k for k in map(_snapshot_kimg, snapshot_paths) if k is not None]
    if local_kimgs:
        kimg_offset += max(local_kimgs)

# ================== SAVE CSV ==================
df_all = pd.DataFrame(all_rows, columns=[
    'epoch', 'G_loss', 'D_loss', 'FID', 'KID', 'SSIM', 'PSNR',
    'LPIPS', 'JSD', 'PPL', 'kimg', 'run_id'
])
df_all = df_all.sort_values(['kimg'])
df_all.to_csv(out_csv, index=False)
print(f"CSV final con snapshots guardado en: {out_csv}")


## Generated Images

In [None]:
# Number of images in the dataset (used to convert kimg into epochs)
N_images = 40120   # change this if the dataset changes
mirror = 0         # set to 1 if training used --mirror=1
N_eff = N_images * (2 if mirror == 1 else 1)

# Results location
results_dir = "/content/drive/MyDrive/Proyecto_Grado/colab-sg2-ada-pytorch/resultsE2"
progreso_dir = "/content/drive/MyDrive/Proyecto_Grado/colab-sg2-ada-pytorch/progresoE2"
os.makedirs(progreso_dir, exist_ok=True)

# One DataFrame per run, concatenated at the end
all_runs_data = []

# Every run directory under results_dir. A narrower pattern such as '0000*'
# only matches run ids 00000-00009, silently dropping later (resumed) runs.
run_dirs = sorted(d for d in glob.glob(os.path.join(results_dir, '[0-9]*')) if os.path.isdir(d))

# Cumulative kimg offset across runs
kimg_offset = 0

for run_dir in run_dirs:
    run_id = os.path.basename(run_dir)

    # kimg of this run's last snapshot. Training resumes from that snapshot
    # (see the resume cell), so anything logged after it was discarded. Using it
    # as the offset (like the metrics CSV cell does) keeps both CSVs on the same
    # kimg/epoch axis. Assumes each run's kimg counter restarts at 0.
    snapshot_kimgs = [
        int(m.group(1))
        for m in (
            re.search(r'network-snapshot-(\d+)\.pkl$', p)
            for p in glob.glob(os.path.join(run_dir, 'network-snapshot-*.pkl'))
        )
        if m
    ]

    stats_path = os.path.join(run_dir, 'stats.jsonl')
    if not os.path.exists(stats_path):
        print(f"Stats no encontrado para {run_id}, se omite.")
    else:
        rows = []
        with open(stats_path, 'r') as f:
            for line in f:
                if not line.strip():
                    continue
                try:
                    entry = json.loads(line)
                    # Stats are either {'mean': ...} dicts or bare scalars.
                    kimg, g_loss, d_loss = (
                        entry[k]['mean'] if isinstance(entry[k], dict) else entry[k]
                        for k in ('Progress/kimg', 'Loss/G/loss', 'Loss/D/loss')
                    )
                except (KeyError, json.JSONDecodeError):
                    continue  # missing keys or a truncated line from an interrupted run

                kimg_total = kimg + kimg_offset  # cumulative kimg
                rows.append({
                    'epoch': (kimg_total * 1000) / N_eff,  # cumulative epochs
                    'G_loss': g_loss,
                    'D_loss': d_loss,
                    'kimg': kimg_total,
                    'run_id': run_id,                       # source run directory
                })

        if rows:
            all_runs_data.append(pd.DataFrame(rows))

    if snapshot_kimgs:
        kimg_offset += max(snapshot_kimgs)

# Combine all runs (an empty frame with the expected columns if nothing was found)
if all_runs_data:
    df_all = pd.concat(all_runs_data, ignore_index=True)
else:
    df_all = pd.DataFrame(columns=['epoch', 'G_loss', 'D_loss', 'kimg', 'run_id'])

# Save as CSV
out_csv = os.path.join(progreso_dir, 'losses_all_runs.csv')
df_all.to_csv(out_csv, index=False)
print(f"CSV combinado guardado en: {out_csv}")