In [None]:
# ============================================================
# One-Class SVM on CNN (ResNet18) embeddings for cell images
# Expected directories (see the path constants below):
#   ./sampled/sanos     -> healthy-cell images (used to train and to validate)
#   ./sampled/nosanos   -> infected-cell images (validation only)
#
# Outputs:
#   ./out_oneclass/
#     - model_ocsvm.joblib         (fitted model plus basic metadata)
#     - embeddings_train_sanos.npy (cache of the training embeddings)
#     - val_results.csv            (per-image scores and predictions)
#     - metrics.json               (global metrics)
#     - figures: val_roc_curve.png, val_pr_curve.png,
#                val_hist_decision_function.png, val_confusion_matrix.png
# ============================================================

import os
from pathlib import Path
import json
import numpy as np
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
from sklearn.svm import OneClassSVM
from sklearn.metrics import (
    classification_report, roc_auc_score, roc_curve, precision_recall_curve,
    average_precision_score, confusion_matrix, ConfusionMatrixDisplay,
    f1_score, accuracy_score, precision_score, recall_score
)
import matplotlib.pyplot as plt
import joblib
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

# ------------------- Paths -------------------
# Healthy-cell images used to fit the One-Class SVM.
DIR_TRAIN_SANOS = Path("./sampled/sanos")
# NOTE(review): this is the same folder as DIR_TRAIN_SANOS, so the healthy
# validation scores are computed on the images the model was fitted on.
DIR_VAL_SANOS   = Path("./sampled/sanos")     # if you have a separate folder, change it
# Infected-cell images; only read during validation.
DIR_VAL_INFECT  = Path("./sampled/nosanos")

# Every artifact (model, cache, CSV, JSON, figures) is written under OUT_DIR.
# The directory is created at import time as a module-level side effect.
OUT_DIR = Path("./out_oneclass")
OUT_DIR.mkdir(parents=True, exist_ok=True)

MODEL_PATH = OUT_DIR / "model_ocsvm.joblib"
EMB_CACHE  = OUT_DIR / "embeddings_train_sanos.npy"
VAL_CSV    = OUT_DIR / "val_results.csv"
METRICS_JSON = OUT_DIR / "metrics.json"

# ------------------- Configuration -------------------
# Run on the GPU when torch reports one, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
IMG_SIZE = 224   # images are resized to IMG_SIZE x IMG_SIZE before the CNN
BATCH = 32       # number of images embedded per forward pass

# One-Class SVM hyper-parameters:
OCSV_NU = 0.05      # expected fraction of outliers among the healthy images
OCSV_GAMMA = "scale"

# Threshold: target FPR on healthy images (a percentile of the healthy
# decision_function scores)
FPR_TARGET = 0.05   # 5% of healthy images flagged as "anomalous" (conservative threshold)

# ------------------- Utilidades -------------------
def list_images(root: Path):
    """Return the image files found under *root*, searched recursively.

    A path qualifies when it is a regular file whose extension (compared
    case-insensitively) is one of the supported image extensions.
    Directories that merely carry an image-like suffix (e.g. ``scans.png/``)
    are skipped, because they cannot be opened as images later on.

    Args:
        root: Directory to search.

    Returns:
        A sorted list of ``Path`` objects; empty when nothing matches.
    """
    exts = {".png", ".jpg", ".jpeg", ".tif", ".tiff", ".bmp"}
    return sorted(
        p for p in root.rglob("*")
        # Cheap suffix test first; the filesystem check only runs on candidates.
        if p.suffix.lower() in exts and p.is_file()
    )

def build_model_resnet18():
    """Build the ResNet18 embedding extractor and its input transform.

    The network is an ImageNet-pretrained ResNet18 with its last child module
    (the classifier) removed, moved to ``DEVICE`` and switched to eval mode.

    Returns:
        A ``(emb_net, tfm)`` tuple: the embedding network and the transform
        to apply to each PIL image before feeding it to the network.
    """
    # Step 1: load the pretrained weights with whichever API this torchvision
    # version offers. `weights` stays None only if the enum lookup itself fails.
    weights = None
    try:
        weights = models.ResNet18_Weights.IMAGENET1K_V1  # torchvision >= 0.13
        backbone = models.resnet18(weights=weights)
    except Exception:
        # Older torchvision: fall back to the legacy `pretrained=True` flag.
        backbone = models.resnet18(pretrained=True)

    # Step 2: drop the final child (the classifier) to obtain an embedding extractor.
    layers = list(backbone.children())
    emb_net = nn.Sequential(*layers[:-1])
    emb_net = emb_net.to(DEVICE).eval()

    # Step 3: input transform. Prefer the preset shipped with the weights,
    # preceded by a fixed-size resize.
    # NOTE(review): the preset presumably applies its own resize/crop as well,
    # so images may be resized twice — confirm this is intended.
    fixed_resize = transforms.Resize((IMG_SIZE, IMG_SIZE))
    tfm = None
    if weights is not None:
        try:
            tfm = transforms.Compose([fixed_resize, weights.transforms()])
        except Exception:
            tfm = None

    if tfm is None:
        # Fallback: standard ImageNet mean/std normalization.
        imagenet_norm = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225])
        tfm = transforms.Compose([fixed_resize, transforms.ToTensor(), imagenet_norm])

    return emb_net, tfm

@torch.no_grad()
def embeddings_from_paths(paths, emb_net, tfm, batch=BATCH):
    """Compute one embedding vector per image path.

    Images are opened, converted to RGB, transformed with *tfm*, and pushed
    through *emb_net* in chunks of *batch* images, with gradients disabled.

    Args:
        paths: Sequence of image paths (must support ``len`` and slicing).
        emb_net: Network mapping an image batch to per-image features.
        tfm: Callable turning a PIL image into a tensor.
        batch: Number of images per forward pass.

    Returns:
        A float32 NumPy array with one row per path. When *paths* is empty
        the result has shape ``(0, 512)``.
    """
    vecs = []
    for start in range(0, len(paths), batch):
        tensors = []
        for p in paths[start:start + batch]:
            # Image.open keeps the file open until it is closed; release the
            # handle right away so large folders do not exhaust descriptors.
            with Image.open(p) as im:
                tensors.append(tfm(im.convert("RGB")))
        x = torch.stack(tensors, dim=0).to(DEVICE)
        feat = emb_net(x)
        # Collapse everything after the batch axis into one feature vector;
        # unlike view(), flatten() also accepts non-contiguous tensors.
        feat = torch.flatten(feat, 1)
        vecs.append(feat.cpu().numpy())
    if not vecs:
        # 512 is the feature width assumed for the extractor used in this file.
        return np.zeros((0, 512), dtype=np.float32)
    return np.vstack(vecs).astype(np.float32)

def ensure_non_empty(paths, msg):
    """Raise ``RuntimeError(msg)`` when *paths* has length zero.

    Args:
        paths: Any sized collection.
        msg: Message carried by the raised error.

    Raises:
        RuntimeError: If ``len(paths)`` is 0.
    """
    if len(paths) > 0:
        return
    raise RuntimeError(msg)

def save_confusion(y_true, y_pred, out_path, title="Confusion matrix"):
    """Render the 2x2 confusion matrix and save it as an image.

    Args:
        y_true: Ground-truth labels (0 = healthy, 1 = infected).
        y_pred: Predicted labels, same encoding.
        out_path: Destination file for the figure.
        title: Title drawn above the matrix.
    """
    # Fix the label order so the matrix is always 2x2, even if a class is absent.
    matrix = confusion_matrix(y_true, y_pred, labels=[0, 1])
    fig, ax = plt.subplots(figsize=(4.6, 4.6))
    display = ConfusionMatrixDisplay(
        matrix, display_labels=["sano(0)", "infectado(1)"]
    )
    display.plot(ax=ax, cmap="Blues", values_format="d", colorbar=False)
    ax.set_title(title)
    plt.tight_layout()
    plt.savefig(out_path, dpi=160)
    plt.close()

# ------------------- Entrenamiento -------------------
def train_oneclass():
    """Fit the One-Class SVM on healthy-image embeddings and save it.

    Embeddings are loaded from ``EMB_CACHE`` when that file exists and holds
    one row per training image; otherwise they are extracted with the CNN
    and the cache is (re)written. The fitted model is stored together with
    basic metadata in ``MODEL_PATH``.

    Raises:
        RuntimeError: If ``DIR_TRAIN_SANOS`` contains no images.
    """
    print(f"[INFO] Dispositivo: {DEVICE}")
    paths_train_sanos = list_images(DIR_TRAIN_SANOS)
    ensure_non_empty(paths_train_sanos, f"No hay imágenes en {DIR_TRAIN_SANOS}")

    # Reuse the cache only if it matches the current image list; a cache left
    # over from a different image set would otherwise be trained on silently.
    X_train = None
    if EMB_CACHE.exists():
        cached = np.load(EMB_CACHE)
        if cached.ndim == 2 and cached.shape[0] == len(paths_train_sanos):
            X_train = cached
            print(f"[OK] Embeddings de train cargados: {X_train.shape}")
        else:
            print(f"[WARN] Cache de embeddings {cached.shape} no coincide con "
                  f"{len(paths_train_sanos)} imágenes; se recalcula")

    if X_train is None:
        # The CNN is only built when embeddings actually have to be computed.
        emb_net, tfm = build_model_resnet18()
        print(f"[EMB] Extrayendo embeddings de sanos (train): {len(paths_train_sanos)}")
        X_train = embeddings_from_paths(paths_train_sanos, emb_net, tfm)
        np.save(EMB_CACHE, X_train)
        print(f"[OK] Guardado cache embeddings en {EMB_CACHE} -> {X_train.shape}")

    # One-Class SVM, fitted on healthy images only.
    print(f"[OCSVM] Entrenando One-Class SVM (nu={OCSV_NU}, gamma={OCSV_GAMMA})")
    ocsvm = OneClassSVM(kernel="rbf", nu=OCSV_NU, gamma=OCSV_GAMMA)
    ocsvm.fit(X_train)

    # Persist the model along with basic metadata about how it was built.
    joblib.dump({
        "ocsvm": ocsvm,
        "cnn": "resnet18-imagenet",
        "img_size": IMG_SIZE,
        "nu": OCSV_NU,
        "gamma": OCSV_GAMMA
    }, MODEL_PATH)
    print(f"[OK] Modelo guardado en {MODEL_PATH}")

# ------------------- Validación -------------------
def _best_f1_threshold(scores, y_true, seed_thr):
    """Return ``(threshold, f1)`` maximizing F1 for the rule ``score < threshold``.

    Candidates are *seed_thr* plus the midpoints between consecutive distinct
    scores; on ties the earliest candidate wins.
    """
    vals = np.unique(scores)  # np.unique returns the distinct values sorted
    if len(vals) > 1:
        mids = (vals[1:] + vals[:-1]) / 2.0
    else:
        mids = np.array([vals[0] + 1e-8])
    best_thr, best_f1 = seed_thr, -1.0
    for thr in np.concatenate([[seed_thr], mids]):
        f1 = f1_score(y_true, (scores < thr).astype(int), zero_division=0)
        if f1 > best_f1:
            best_thr, best_f1 = float(thr), float(f1)
    return best_thr, best_f1


def _save_curve(x, y, label, xlabel, ylabel, title, out_path, chance_line=False):
    """Plot a single labelled curve and save it; the figure is always closed."""
    fig, ax = plt.subplots(figsize=(5, 5))
    try:
        ax.plot(x, y, lw=2, label=label)
        if chance_line:
            ax.plot([0, 1], [0, 1], "--", lw=1)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend()
        fig.tight_layout()
        fig.savefig(out_path, dpi=160)
    finally:
        plt.close(fig)


def validate_oneclass():
    """Score healthy and infected images with the saved model and report results.

    Loads the model from ``MODEL_PATH``, embeds the validation images, picks a
    decision threshold, and writes ``VAL_CSV``, ``METRICS_JSON`` and the
    ROC / PR / histogram / confusion-matrix figures under ``OUT_DIR``.
    Labels are 0 = healthy, 1 = infected; an image is predicted infected when
    its ``decision_function`` value is below the threshold.

    Raises:
        RuntimeError: If either validation folder contains no images.
    """
    import pandas as pd

    # NOTE: joblib.load unpickles the file; only load model files you trust.
    bundle = joblib.load(MODEL_PATH)
    ocsvm = bundle["ocsvm"]
    emb_net, tfm = build_model_resnet18()

    # Validation sets
    paths_val_sanos = list_images(DIR_VAL_SANOS)
    paths_val_infec = list_images(DIR_VAL_INFECT)
    ensure_non_empty(paths_val_sanos, f"No hay imágenes en {DIR_VAL_SANOS}")
    ensure_non_empty(paths_val_infec, f"No hay imágenes en {DIR_VAL_INFECT}")
    print(f"[VAL] Sanos: {len(paths_val_sanos)} | Infectados: {len(paths_val_infec)}")

    # Validation embeddings
    print("[EMB] Extrayendo embeddings de validación…")
    X_sanos = embeddings_from_paths(paths_val_sanos, emb_net, tfm)
    X_infs  = embeddings_from_paths(paths_val_infec, emb_net, tfm)

    # decision_function: >0 inlier (healthy), <0 outlier; larger = more healthy
    df_sanos = ocsvm.decision_function(X_sanos).reshape(-1)
    df_infs  = ocsvm.decision_function(X_infs).reshape(-1)

    # One row per image: (file name, true label, normality score).
    # NOTE(review): only the base name is stored, so files with the same name
    # in different sub-folders are indistinguishable in the CSV.
    rows = [(p.name, 0, float(s)) for p, s in zip(paths_val_sanos, df_sanos)]
    rows += [(p.name, 1, float(s)) for p, s in zip(paths_val_infec, df_infs)]
    df = pd.DataFrame(rows, columns=["image", "y_true", "decision_function"])
    df["set"] = np.where(df["y_true"] == 0, "sanos", "no_sanos")

    scores = df["decision_function"].to_numpy()
    y_true = df["y_true"].to_numpy().astype(int)

    # --------- Threshold ----------
    # a) Conservative calibration: the FPR_TARGET quantile of the healthy scores,
    #    so roughly that fraction of healthy images falls below the threshold.
    thr_conserv = float(np.quantile(scores[y_true == 0], FPR_TARGET))
    # round() rather than int(): 0.29 * 100 is 28.999..., which int() shows as 28.
    print(f"[THR] Umbral conservador (FPR≈{round(FPR_TARGET*100)}% en sanos): {thr_conserv:.6f}")

    # b) Threshold with the best F1 on this validation data.
    #    The metrics below are computed on the same data the threshold was
    #    tuned on, so they are optimistic estimates.
    thr_f1, best_f1 = _best_f1_threshold(scores, y_true, thr_conserv)
    print(f"[THR] Umbral por mejor F1 en validación: {thr_f1:.6f} (F1={best_f1:.3f})")

    # Choose which threshold to apply (conservative or F1). F1 is used here.
    thr_used = thr_f1

    # --------- Predictions and metrics ----------
    # anomaly_score = -decision_function (larger = more anomalous)
    anomaly = -scores
    y_pred = (scores < thr_used).astype(int)  # 1 = infected
    df["anomaly_score"] = anomaly
    df["y_pred"] = y_pred
    df["correct"] = (df["y_pred"] == y_true).astype(int)

    metrics = {
        "threshold_used": float(thr_used),
        "f1": float(f1_score(y_true, y_pred, zero_division=0)),
        "accuracy": float(accuracy_score(y_true, y_pred)),
        "precision": float(precision_score(y_true, y_pred, zero_division=0)),
        # zero_division=0 keeps recall consistent with the metrics above.
        "recall": float(recall_score(y_true, y_pred, zero_division=0)),
    }
    try:
        auc = float(roc_auc_score(y_true, anomaly))
        ap = float(average_precision_score(y_true, anomaly))
    except ValueError as exc:
        # Ranking metrics are best-effort; report why they are missing.
        print(f"[WARN] No se pudo calcular AUC/AP: {exc}")
        auc, ap = None, None
    metrics["auc"] = auc
    metrics["average_precision"] = ap

    df.to_csv(VAL_CSV, index=False)
    with open(METRICS_JSON, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2, ensure_ascii=False)

    print("\n[VAL] Métricas globales:")
    for k, v in metrics.items():
        print(f"- {k}: {v:.4f}" if isinstance(v, float) else f"- {k}: {v}")

    # --------- Figures ----------
    # ROC / PR curves from anomaly_score. Still best-effort, but a failure is
    # now reported instead of being silently dropped.
    if auc is None or ap is None:
        print("[WARN] Curvas ROC/PR omitidas: AUC/AP no disponibles")
    else:
        try:
            fpr, tpr, _ = roc_curve(y_true, anomaly)
            _save_curve(fpr, tpr, f"AUC={auc:.3f}", "FPR", "TPR",
                        "ROC (One-Class, val)", OUT_DIR / "val_roc_curve.png",
                        chance_line=True)
            prec, rec, _ = precision_recall_curve(y_true, anomaly)
            _save_curve(rec, prec, f"AP={ap:.3f}", "Recall", "Precision",
                        "PR (One-Class, val)", OUT_DIR / "val_pr_curve.png")
        except Exception as exc:
            print(f"[WARN] No se pudieron generar las curvas ROC/PR: {exc}")

    # Histogram of decision_function per set, with the threshold marked.
    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        for name, sub in df.groupby("set"):
            ax.hist(sub["decision_function"], bins=30, alpha=0.6, label=name)
        ax.axvline(thr_used, color="k", ls="--", lw=1.5, label=f"thr={thr_used:.4f}")
        ax.set_xlabel("decision_function (mayor = más sano)")
        ax.set_ylabel("count")
        ax.set_title("Distribución de decision_function (val)")
        ax.legend()
        fig.tight_layout()
        fig.savefig(OUT_DIR / "val_hist_decision_function.png", dpi=160)
    finally:
        plt.close(fig)

    # Confusion matrix
    save_confusion(y_true, y_pred, OUT_DIR / "val_confusion_matrix.png",
                   title="Matriz de confusión (val)")

    # Console report
    print("\n[VAL] Classification report:")
    print(classification_report(y_true, y_pred, target_names=["sano", "infectado"]))

# ------------------- Main -------------------
def _run_pipeline():
    """Train on healthy images, then validate on healthy plus infected ones."""
    print("=== ENTRENAR One-Class (solo sanos) ===")
    train_oneclass()

    print("=== VALIDAR en sanos + no_sanos ===")
    validate_oneclass()

    print(f"Listo ✅  Artefactos en: {OUT_DIR}")


if __name__ == "__main__":
    _run_pipeline()


=== ENTRENAR One-Class (solo sanos) ===
[INFO] Dispositivo: cpu
[OK] Embeddings de train cargados: (250, 512)
[OCSVM] Entrenando One-Class SVM (nu=0.05, gamma=scale)
[OK] Modelo guardado en out_oneclass/model_ocsvm.joblib
=== VALIDAR en sanos + no_sanos ===
[VAL] Sanos: 250 | Infectados: 250
[EMB] Extrayendo embeddings de validación…
