In [5]:
# ============================================================
# PIPELINE: L2 → StandardScaler → PCA(whiten) → One-Class SVM
# - Embeddings: ResNet18 (ImageNet) 512-D por imagen
# - Train: solo 'sanos' (One-Class)
# - Validación: 'sanos' vs 'no_sanos'
# - Selección de hiperparámetros por AUC en validación (anomaly score)
# - Umbral: mejor F1 (opción alternativa: FPR objetivo en sanos)
# - Salidas: ./out_oneclass_pca/*
# Requisitos: torch, torchvision, pillow, numpy, pandas, scikit-learn, matplotlib, joblib
# ============================================================

import os, json, warnings
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

import torch
import torch.nn as nn
from torchvision import models, transforms

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler, Normalizer
from sklearn.decomposition import PCA
from sklearn.svm import OneClassSVM
from sklearn.metrics import (
    roc_auc_score, precision_recall_curve, average_precision_score,
    roc_curve, confusion_matrix, ConfusionMatrixDisplay,
    f1_score, accuracy_score, precision_score, recall_score
)
import joblib
warnings.filterwarnings("ignore", category=UserWarning)

# ------------------- Paths -------------------
DIR_TRAIN_SANOS = Path("./sampled/sanos")
DIR_VAL_SANOS   = Path("./val/sanos")     # healthy validation images; change this if your folder differs
DIR_VAL_INFECT  = Path("./sampled/nosanos")
# NOTE(review): healthy validation images come from ./val while non-healthy ones
# come from ./sampled — confirm both validation sets are drawn comparably and
# that ./val/sanos does not overlap the training folder ./sampled/sanos.

OUT_DIR = Path("./out_oneclass_pca")
OUT_DIR.mkdir(parents=True, exist_ok=True)  # import-time side effect: create the output folder

# Cached embedding matrices (.npy) and run artifacts, all under OUT_DIR.
CACHE_TRAIN_EMB = OUT_DIR / "emb_train_sanos.npy"
CACHE_VALS_EMB  = OUT_DIR / "emb_val_sanos.npy"
CACHE_VALI_EMB  = OUT_DIR / "emb_val_nosanos.npy"
MODEL_PATH      = OUT_DIR / "pipeline_ocsvm_pca.joblib"
VAL_CSV         = OUT_DIR / "val_results.csv"
METRICS_JSON    = OUT_DIR / "metrics.json"

# ------------------- Config -------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
IMG_SIZE = 224  # side length (pixels) of the square Resize applied to every image
BATCH = 32      # images per forward pass when extracting embeddings

# Hyperparameter grid
GRID_PCA = [64, 128]                    # PCA components (clamped by make_pca_safe)
GRID_NU  = [0.01, 0.02, 0.05, 0.1]      # OneClassSVM nu
GRID_GAM = ["scale", 1e-3, 1e-4, 1e-5]  # OneClassSVM RBF gamma

# Target false-positive rate on healthy images for the conservative threshold
FPR_TARGET = 0.05

# ------------------- Utilidades -------------------
def list_images(root: Path):
    """Return the image files found recursively under *root*, sorted by path.

    A path counts as an image when it is a regular file whose extension,
    compared case-insensitively, is .png, .jpg, .jpeg, .tif, .tiff or .bmp.
    Directories are skipped even if their name ends in one of those
    extensions (e.g. a folder called ``x.jpg``), because they cannot be
    opened as images. *root* may be a ``Path`` or a ``str``.
    """
    exts = {".png", ".jpg", ".jpeg", ".tif", ".tiff", ".bmp"}
    # Suffix test first so is_file() (a stat call) only runs on candidates.
    return sorted(
        p for p in Path(root).rglob("*")
        if p.suffix.lower() in exts and p.is_file()
    )

def ensure_non_empty(paths, msg):
    """Raise ``RuntimeError(msg)`` if *paths* has length zero; otherwise return None."""
    if len(paths) != 0:
        return
    raise RuntimeError(msg)

def build_model_resnet18():
    """Build a frozen ImageNet ResNet-18 trunk and the transform that feeds it.

    Returns
    -------
    emb_net : torch.nn.Module
        ResNet-18 with its last child (the classification head) removed,
        moved to DEVICE and set to eval mode. Its output is flattened to one
        vector per image by ``embeddings_from_paths``.
    tfm : callable
        PIL image -> normalized tensor.
    """
    weights = None
    try:
        weights = models.ResNet18_Weights.IMAGENET1K_V1
    except AttributeError:
        # Older torchvision without the Weights enum API: use the legacy flag.
        base = models.resnet18(pretrained=True)
    else:
        # Any error here (e.g. a failed weight download) is real and must
        # propagate instead of being masked by a retry via the legacy path.
        base = models.resnet18(weights=weights)
    emb_net = nn.Sequential(*list(base.children())[:-1]).to(DEVICE).eval()

    # Prefer the transform bundled with the weights; fall back to explicit
    # ImageNet mean/std normalization.
    tfm = None
    if weights is not None:
        try:
            # NOTE(review): per torchvision docs, weights.transforms() for
            # IMAGENET1K_V1 resizes to 256 and center-crops 224, so after this
            # square Resize the image is upscaled and its borders cropped —
            # unlike the fallback below (plain resize, no crop). Confirm this is
            # intended; changing it alters embeddings and invalidates the caches.
            tfm = transforms.Compose([
                transforms.Resize((IMG_SIZE, IMG_SIZE)),
                weights.transforms(),
            ])
        except Exception:
            tfm = None  # best effort: use the explicit fallback below
    if tfm is None:
        tfm = transforms.Compose([
            transforms.Resize((IMG_SIZE, IMG_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
    return emb_net, tfm

@torch.no_grad()
def embeddings_from_paths(paths, emb_net, tfm, batch=BATCH):
    """Embed a list of images in mini-batches with a frozen network.

    Parameters
    ----------
    paths : sequence of path-like
        Image files. Each is opened with PIL and converted to RGB. Must support
        slicing, e.g. a list.
    emb_net : torch.nn.Module
        Network whose per-batch output is flattened to ``(B, D)``. It is
        ``(B, 512, 1, 1)`` for the ResNet-18 trunk used in this file.
    tfm : callable
        PIL image -> tensor transform applied to every image.
    batch : int
        Images per forward pass.

    Returns
    -------
    np.ndarray
        ``float32`` array with one row per path. It is ``(0, 512)`` when
        *paths* is empty.
    """
    vecs = []
    for start in range(0, len(paths), batch):
        tensors = []
        for p in paths[start:start + batch]:
            # PIL keeps the file open lazily; the context manager closes it.
            with Image.open(p) as im:
                tensors.append(tfm(im.convert("RGB")))
        # The slice is never empty because start < len(paths).
        x = torch.stack(tensors, dim=0).to(DEVICE)
        feat = emb_net(x)
        vecs.append(feat.reshape(feat.size(0), -1).cpu().numpy())
    if not vecs:
        return np.zeros((0, 512), dtype=np.float32)
    return np.vstack(vecs).astype(np.float32)

def save_confusion(y_true, y_pred, out_path, title="Confusion matrix"):
    """Plot the 2x2 confusion matrix for labels 0 (sano) and 1 (infectado).

    The figure is 4.8x4.8 inches with a Blues colormap, integer cell values
    and no colorbar. It is written to *out_path* at 160 dpi and then closed.
    """
    matrix = confusion_matrix(y_true, y_pred, labels=[0, 1])
    fig, ax = plt.subplots(figsize=(4.8, 4.8))
    display = ConfusionMatrixDisplay(matrix, display_labels=["sano(0)", "infectado(1)"])
    display.plot(ax=ax, cmap="Blues", values_format="d", colorbar=False)
    ax.set_title(title)
    fig.tight_layout()
    fig.savefig(out_path, dpi=160)
    plt.close(fig)

def rnorm(x):
    """Rescale *x* to [0, 1] using its 1st and 99th percentiles as the range.

    Values outside the percentile range are clipped. The 1e-8 term keeps the
    division finite when both percentiles coincide.
    """
    lo, hi = np.percentile(x, [1, 99])
    span = hi - lo + 1e-8
    return np.clip((x - lo) / span, 0, 1)

# ------------------- Embeddings -------------------
def _load_cached_embeddings(cache_path, n_expected):
    """Return the embedding matrix cached at *cache_path*, or None if missing or stale.

    A cache is treated as stale when its row count differs from the number of
    images currently on disk. Reusing it would silently misalign embeddings
    with file names. A cache with the same count but different files is not
    detected.
    """
    if not cache_path.exists():
        return None
    X = np.load(cache_path)
    if X.shape[0] != n_expected:
        print(f"[CACHE] {cache_path.name}: {X.shape[0]} filas != {n_expected} imágenes -> se recalcula")
        return None
    return X


def build_or_load_embeddings():
    """Return ResNet-18 embeddings for the train and validation image folders.

    Embeddings are read from the .npy caches when they match the current
    image counts. Otherwise they are extracted and the caches are rewritten.
    The ResNet is only built if some extraction is actually needed.

    Returns
    -------
    (X_train, (X_val_sanos, X_val_nosanos), (val_sanos_paths, val_nosanos_paths))
        The validation path lists are row-aligned with their embedding arrays.

    Raises
    ------
    RuntimeError
        If any of the three image folders contains no images.
    """
    train_paths = list_images(DIR_TRAIN_SANOS)
    val_sanos_paths = list_images(DIR_VAL_SANOS)
    val_infs_paths = list_images(DIR_VAL_INFECT)

    ensure_non_empty(train_paths, f"No hay imágenes en {DIR_TRAIN_SANOS}")
    ensure_non_empty(val_sanos_paths, f"No hay imágenes en {DIR_VAL_SANOS}")
    ensure_non_empty(val_infs_paths,  f"No hay imágenes en {DIR_VAL_INFECT}")

    emb_net = tfm = None

    def _extract(paths):
        # Build the network on first use only: loading ResNet-18 (and possibly
        # downloading its weights) is wasted work when every cache is valid.
        nonlocal emb_net, tfm
        if emb_net is None:
            emb_net, tfm = build_model_resnet18()
        return embeddings_from_paths(paths, emb_net, tfm)

    X_train = _load_cached_embeddings(CACHE_TRAIN_EMB, len(train_paths))
    if X_train is not None:
        print(f"[CACHE] X_train sanos -> {X_train.shape}")
    else:
        print(f"[EMB] Extrayendo embeddings TRAIN sanos: {len(train_paths)}")
        X_train = _extract(train_paths)
        np.save(CACHE_TRAIN_EMB, X_train)

    X_vs = _load_cached_embeddings(CACHE_VALS_EMB, len(val_sanos_paths))
    X_vi = _load_cached_embeddings(CACHE_VALI_EMB, len(val_infs_paths))
    if X_vs is not None and X_vi is not None:
        print(f"[CACHE] X_val sanos -> {X_vs.shape} | X_val infectados -> {X_vi.shape}")
    else:
        print("[EMB] Extrayendo embeddings VALIDACIÓN…")
        X_vs = _extract(val_sanos_paths)
        X_vi = _extract(val_infs_paths)
        np.save(CACHE_VALS_EMB, X_vs)
        np.save(CACHE_VALI_EMB, X_vi)

    return X_train, (X_vs, X_vi), (val_sanos_paths, val_infs_paths)


def make_pca_safe(ncomp, X_train):
    """Clamp a requested PCA component count to what *X_train* can support.

    PCA cannot extract more components than ``min(n_samples, n_features)``,
    so the result is the smallest of *ncomp*, the row count and the column
    count of *X_train*.
    """
    rows, cols = X_train.shape[0], X_train.shape[1]
    return min(ncomp, rows, cols)

# ------------------- Grid Search (AUC) -------------------
def grid_search_auc(X_train, X_vs, X_vi):
    """Select OneClassSVM pipeline hyperparameters by validation ROC-AUC.

    For each (PCA size, nu, gamma) in GRID_PCA x GRID_NU x GRID_GAM, fit
    ``L2 -> StandardScaler -> PCA(whiten) -> OneClassSVM`` on *X_train*
    (healthy embeddings only). The anomaly score is
    ``-pipe.score_samples(X)``; AUC treats *X_vi* as the positive class.

    Parameters
    ----------
    X_train : np.ndarray
        Healthy training embeddings, shape (n_samples, n_features).
    X_vs, X_vi : np.ndarray
        Validation embeddings of healthy (label 0) and non-healthy (label 1)
        images.

    Returns
    -------
    (pipe, params, auc)
        The best fitted pipeline, its params ``{"pca", "nu", "gamma"}`` and
        its AUC. ``pca`` is the component count actually used after clamping.

    Note: the same validation set is later reused for threshold selection and
    metric reporting, so those reported metrics are optimistically biased.
    """
    # Clamp once (loop-invariant). On small training sets several requested
    # sizes can clamp to the same value; drop duplicates (order preserved) so
    # identical pipelines are not fitted twice.
    pca_sizes = list(dict.fromkeys(make_pca_safe(n, X_train) for n in GRID_PCA))
    y_true = np.r_[np.zeros(len(X_vs)), np.ones(len(X_vi))]

    best = {"auc": -1, "pipe": None, "params": None}
    for ncomp in pca_sizes:
        for nu in GRID_NU:
            for gamma in GRID_GAM:
                pipe = Pipeline([
                    ("l2", Normalizer(norm="l2")),
                    ("sc", StandardScaler()),  # center/scale the L2-normalized embeddings
                    ("pca", PCA(n_components=ncomp, whiten=True, random_state=42)),
                    ("oc", OneClassSVM(kernel="rbf", nu=nu, gamma=gamma)),
                ])
                pipe.fit(X_train)  # healthy samples only (one-class training)

                # OneClassSVM.score_samples = decision_function + offset_, with
                # higher meaning more normal. The constant offset does not change
                # the ranking, so AUC matches decision_function's.
                scores = -np.r_[pipe.score_samples(X_vs), pipe.score_samples(X_vi)]
                try:
                    auc = roc_auc_score(y_true, scores)
                except ValueError:
                    # e.g. non-finite scores: rank this grid point as chance level
                    auc = 0.5
                print(f"[GRID] PCA={ncomp:>3} nu={nu:<4} gamma={str(gamma):<7} | AUC={auc:.3f}")
                if auc > best["auc"]:
                    best.update({"auc": float(auc), "pipe": pipe,
                                 "params": {"pca": ncomp, "nu": nu, "gamma": gamma}})
    print(f"[BEST] AUC={best['auc']:.3f} con params={best['params']}")
    return best["pipe"], best["params"], best["auc"]

# ------------------- Umbrales y métricas -------------------
def pick_threshold(scores_sanos, scores_infs, fpr_target=None):
    """Compute two decision thresholds on "higher = more normal" scores.

    A sample is predicted anomalous (1) when its score is strictly below the
    threshold.

    Parameters
    ----------
    scores_sanos, scores_infs : array-like
        Scores of healthy (label 0) and non-healthy (label 1) samples.
    fpr_target : float, optional
        Quantile of the healthy scores used for the conservative threshold.
        Defaults to the module constant FPR_TARGET.

    Returns
    -------
    (thr_fpr, thr_f1, best_f1)
        thr_fpr : the *fpr_target* quantile of the healthy scores.
        thr_f1 : the candidate maximizing F1. Candidates are thr_fpr followed
            by the midpoints between consecutive distinct scores; ties go to
            the earliest candidate.
        best_f1 : that maximal F1.

    Raises
    ------
    ValueError
        If either score array is empty.
    """
    if fpr_target is None:
        fpr_target = FPR_TARGET
    neg = np.sort(np.asarray(scores_sanos, dtype=float))
    pos = np.sort(np.asarray(scores_infs, dtype=float))
    if neg.size == 0 or pos.size == 0:
        raise ValueError("pick_threshold necesita puntuaciones de sanos y de no_sanos")

    thr_fpr = float(np.quantile(neg, fpr_target))
    vals = np.unique(np.r_[neg, pos])  # sorted distinct scores
    mids = (vals[1:] + vals[:-1]) / 2.0 if vals.size > 1 else np.array([vals[0] + 1e-8])
    candidates = np.r_[thr_fpr, mids]

    # For each candidate, count scores strictly below it in each class with a
    # binary search (side="left"). This replaces one full F1 evaluation per
    # threshold (O(n^2) overall) with O(n log n).
    tp = np.searchsorted(pos, candidates, side="left")
    fp = np.searchsorted(neg, candidates, side="left")
    # F1 = 2TP / (2TP + FP + FN) with FN = P - TP; the denominator is >= P > 0.
    f1 = 2.0 * tp / (pos.size + tp + fp)

    best = int(np.argmax(f1))  # first maximum, same tie-breaking as a strict '>' scan
    return thr_fpr, float(candidates[best]), float(f1[best])

def _write_val_csv(csv_path, paths_s, s_sanos, paths_i, s_infs, thr_used):
    """Write one CSV row per validation image: name, set, label, score, prediction."""
    import pandas as pd  # only needed for this report

    rows = [(p.name, "sano", 0, float(sc)) for p, sc in zip(paths_s, s_sanos)]
    rows += [(p.name, "no_sano", 1, float(sc)) for p, sc in zip(paths_i, s_infs)]
    # NOTE: the "decision_function" column holds the scores passed in, which
    # the caller takes from pipe.score_samples (= OneClassSVM decision_function
    # + offset_). The column name is kept so the CSV schema stays stable.
    df = pd.DataFrame(rows, columns=["image", "set", "y_true", "decision_function"])
    df["anomaly_score"] = -df["decision_function"]
    df["y_pred"] = (df["decision_function"] < thr_used).astype(int)
    df["correct"] = (df["y_pred"] == df["y_true"]).astype(int)
    df.to_csv(csv_path, index=False)


def _plot_val_figures(out_dir, y_true, y_pred, anomaly, s_sanos, s_infs, thr_used, metrics):
    """Save ROC, PR, score-histogram and confusion-matrix figures into *out_dir*."""
    # ROC
    fpr, tpr, _ = roc_curve(y_true, anomaly)
    plt.figure(figsize=(5, 5))
    plt.plot(fpr, tpr, lw=2, label=f"AUC={metrics['auc']:.3f}")
    plt.plot([0, 1], [0, 1], "--", lw=1)
    plt.xlabel("FPR")
    plt.ylabel("TPR")
    plt.title("ROC (validation)")
    plt.legend()
    plt.tight_layout()
    plt.savefig(out_dir / "val_roc.png", dpi=160)
    plt.close()
    # Precision-recall
    prec, rec, _ = precision_recall_curve(y_true, anomaly)
    plt.figure(figsize=(5, 5))
    plt.plot(rec, prec, lw=2, label=f"AP={metrics['average_precision']:.3f}")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title("PR (validation)")
    plt.legend()
    plt.tight_layout()
    plt.savefig(out_dir / "val_pr.png", dpi=160)
    plt.close()
    # Score histogram with the threshold line (scores are score_samples values)
    plt.figure(figsize=(6, 4))
    plt.hist(s_sanos, bins=30, alpha=0.6, label="sanos")
    plt.hist(s_infs, bins=30, alpha=0.6, label="no_sanos")
    plt.axvline(thr_used, color="k", ls="--", lw=1.5, label=f"thr={thr_used:.4f}")
    plt.xlabel("decision_function (mayor = más sano)")
    plt.ylabel("count")
    plt.title("Distribución decision_function (val)")
    plt.legend()
    plt.tight_layout()
    plt.savefig(out_dir / "val_hist_decfun.png", dpi=160)
    plt.close()
    # Confusion matrix
    save_confusion(y_true, y_pred, out_dir / "val_cm.png", "Matriz de confusión (val)")


def evaluate_and_plot(pipe, X_vs, X_vi, val_paths, thr_used, out_dir: Path):
    """Score the validation sets, compute metrics at *thr_used* and save the reports.

    Parameters
    ----------
    pipe : fitted sklearn Pipeline ending in OneClassSVM
        Scored with ``pipe.score_samples`` (higher = more normal).
    X_vs, X_vi : np.ndarray
        Embeddings of healthy (label 0) and non-healthy (label 1) images.
    val_paths : tuple of two lists of Path
        Image paths, row-aligned with X_vs and X_vi.
    thr_used : float
        Threshold in score_samples units. A sample is predicted anomalous (1)
        when its score is strictly below it.
    out_dir : Path or str
        Destination for the CSV, the metrics JSON and the four PNG figures.

    Returns
    -------
    dict
        threshold_used, f1, accuracy, precision, recall, auc, average_precision.

    Raises
    ------
    ValueError
        If the number of paths does not match the number of embeddings.
    """
    paths_s, paths_i = val_paths
    out_dir = Path(out_dir)

    s_sanos = pipe.score_samples(X_vs)
    s_infs = pipe.score_samples(X_vi)
    # zip() in the CSV writer would silently truncate on a mismatch (e.g.
    # stale cached embeddings after images were added or removed), which
    # would attach scores to the wrong file names.
    if len(paths_s) != len(s_sanos) or len(paths_i) != len(s_infs):
        raise ValueError(
            f"Rutas/embeddings desalineados: sanos {len(paths_s)} vs {len(s_sanos)}, "
            f"no_sanos {len(paths_i)} vs {len(s_infs)}; borra la caché de embeddings"
        )

    y_true = np.r_[np.zeros_like(s_sanos), np.ones_like(s_infs)].astype(int)
    scores = np.r_[s_sanos, s_infs]
    anomaly = -scores  # higher = more anomalous, as ROC/PR expect
    y_pred = (scores < thr_used).astype(int)

    metrics = {
        "threshold_used": float(thr_used),
        "f1": float(f1_score(y_true, y_pred, zero_division=0)),
        "accuracy": float(accuracy_score(y_true, y_pred)),
        "precision": float(precision_score(y_true, y_pred, zero_division=0)),
        "recall": float(recall_score(y_true, y_pred, zero_division=0)),
        "auc": float(roc_auc_score(y_true, anomaly)),
        "average_precision": float(average_precision_score(y_true, anomaly)),
    }

    # Every artifact goes under out_dir; file names match VAL_CSV / METRICS_JSON.
    _write_val_csv(out_dir / VAL_CSV.name, paths_s, s_sanos, paths_i, s_infs, thr_used)
    _plot_val_figures(out_dir, y_true, y_pred, anomaly, s_sanos, s_infs, thr_used, metrics)
    with open(out_dir / METRICS_JSON.name, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2, ensure_ascii=False)

    print("\n[MÉTRICAS VALIDACIÓN]")
    for k, v in metrics.items():
        print(f"- {k}: {v:.4f}" if isinstance(v, float) else f"- {k}: {v}")
    return metrics

# ------------------- Main -------------------
if __name__ == "__main__":
    print(f"[INFO] Device: {DEVICE}")
    # 1) Embeddings (reused from the .npy caches in OUT_DIR when available)
    X_train, (X_vs, X_vi), val_paths = build_or_load_embeddings()
    print(f"[SHAPE] X_train={X_train.shape} | X_val_sanos={X_vs.shape} | X_val_infectados={X_vi.shape}")

    # 2) Hyperparameter grid search by validation AUC
    pipe, params, best_auc = grid_search_auc(X_train, X_vs, X_vi)

    # 3) Thresholds. Scores come from pipe.score_samples (higher = more normal),
    #    so thresholds are in score_samples units, not decision_function units
    #    (the two differ by the OneClassSVM offset_).
    s_sanos = pipe.score_samples(X_vs)
    s_infs = pipe.score_samples(X_vi)
    thr_fpr, thr_f1, best_f1 = pick_threshold(s_sanos, s_infs)
    # Format the percentage directly: int(FPR_TARGET * 100) truncates for some
    # values (e.g. 0.29 * 100 == 28.999999999999996 -> 28).
    print(f"[THR] FPR target {FPR_TARGET * 100:g}% -> {thr_fpr:.6f}")
    print(f"[THR] Mejor F1 en validación -> {thr_f1:.6f} (F1={best_f1:.3f})")

    # 4) Evaluate and plot (F1 threshold by default; use thr_fpr for a
    #    conservative operating point).
    # NOTE: hyperparameters, threshold and reported metrics all come from the
    # same validation set, so these metrics are optimistic; use a held-out
    # test split for an unbiased estimate.
    thr_used = thr_f1
    metrics = evaluate_and_plot(pipe, X_vs, X_vi, val_paths, thr_used, OUT_DIR)

    # 5) Save the pipeline and thresholds. "score_fn" tells consumers which
    #    pipeline method the thresholds apply to.
    joblib.dump({"pipeline": pipe, "params": params, "thr_used": float(thr_used),
                 "thr_fpr": float(thr_fpr), "thr_f1": float(thr_f1),
                 "score_fn": "score_samples",
                 "cnn": "resnet18", "img_size": IMG_SIZE},
                MODEL_PATH)
    print(f"[OK] Guardado: {MODEL_PATH}\nArtefactos en {OUT_DIR}")


[INFO] Device: cpu
[CACHE] X_train sanos -> (250, 512)
[CACHE] X_val sanos -> (250, 512) | X_val infectados -> (250, 512)
[SHAPE] X_train=(250, 512) | X_val_sanos=(250, 512) | X_val_infectados=(250, 512)
[GRID] PCA= 64 nu=0.01 gamma=scale   | AUC=0.807
[GRID] PCA= 64 nu=0.01 gamma=0.001   | AUC=0.789
[GRID] PCA= 64 nu=0.01 gamma=0.0001  | AUC=0.788
[GRID] PCA= 64 nu=0.01 gamma=1e-05   | AUC=0.786
[GRID] PCA= 64 nu=0.02 gamma=scale   | AUC=0.807
[GRID] PCA= 64 nu=0.02 gamma=0.001   | AUC=0.789
[GRID] PCA= 64 nu=0.02 gamma=0.0001  | AUC=0.788
[GRID] PCA= 64 nu=0.02 gamma=1e-05   | AUC=0.787
[GRID] PCA= 64 nu=0.05 gamma=scale   | AUC=0.807
[GRID] PCA= 64 nu=0.05 gamma=0.001   | AUC=0.790
[GRID] PCA= 64 nu=0.05 gamma=0.0001  | AUC=0.790
[GRID] PCA= 64 nu=0.05 gamma=1e-05   | AUC=0.791
[GRID] PCA= 64 nu=0.1  gamma=scale   | AUC=0.807
[GRID] PCA= 64 nu=0.1  gamma=0.001   | AUC=0.789
[GRID] PCA= 64 nu=0.1  gamma=0.0001  | AUC=0.789
[GRID] PCA= 64 nu=0.1  gamma=1e-05   | AUC=0.789
[GRID] PCA=1