# COMPROBACIONES PREVIAS

In [None]:
import os, sys, torch, subprocess, json
print("PY:", sys.executable)
print("CWD:", os.getcwd())
print("torch:", torch.__version__, " HIP:", getattr(torch.version, "hip", None), " CUDA:", getattr(torch.version, "cuda", None))
print("GPU? ", torch.cuda.is_available())

# Do the GPU device nodes exist?
print("\n/dev/kfd exists?", os.path.exists("/dev/kfd"))
print("/dev/dri/renderD128 exists?", os.path.exists("/dev/dri/renderD128"))

# Am I in the right venv?
# Prefer a site-packages entry (as before); fall back to dist-packages, which Debian/Ubuntu
# based interpreters use, and report None instead of raising StopIteration when neither
# is on sys.path.
_site = [p for p in sys.path if p.endswith("site-packages")]
_dist = [p for p in sys.path if p.endswith("dist-packages")]
SITE_PACKAGES = (_site or _dist or [None])[0]
print("\nsite-packages:", SITE_PACKAGES)


In [None]:
# Quick GPU / cache sanity check
import torch
import os
import transformers

_gpu_ok = torch.cuda.is_available()
print("GPU?", _gpu_ok, getattr(torch.version, "hip", None))
if _gpu_ok:
    # Name of the first visible accelerator
    print(torch.cuda.get_device_name(0))
print("HF_HOME:", os.environ.get("HF_HOME"))

In [None]:
# Bootstrap
# === Environment setup so every later cell can run (local Docker + ROCm environment) ===
import os
import sys
from pathlib import Path

# 1) Environment variables (Docker normally provides them; only fill in missing ones)
os.environ.setdefault("HF_HOME", "/root/.cache/huggingface")   # local cache persisted through a volume
os.environ["TOKENIZERS_PARALLELISM"] = "false"                 # avoids tokenizer deadlocks
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
Path(os.environ["HF_HOME"]).mkdir(parents=True, exist_ok=True)


# 2) Locate the repo (no Drive, no git clone): the root is the directory holding src/ and requirements.txt
def _find_repo_root(start: Path) -> Path:
    """Return `start` or its nearest ancestor containing both ``src/`` and ``requirements.txt``.

    Falls back to `start` itself when no such directory is found.
    """
    for candidate in (start, *start.parents):
        if (candidate / "src").exists() and (candidate / "requirements.txt").exists():
            return candidate
    return start  # fallback: cwd


REPO_ROOT = _find_repo_root(Path.cwd())
DEST = REPO_ROOT  # keep the name used throughout the rest of the notebook

# 3) Add src/ to the import path.
# Insert it only once (re-running the cell must not pile up duplicates) and at the front,
# mirroring how PYTHONPATH entries precede site-packages, so the repo's modules win over
# same-named installed packages.
_SRC_DIR = str(DEST / "src")
if _SRC_DIR not in sys.path:
    sys.path.insert(0, _SRC_DIR)

print("Entorno listo.")
print("HF_HOME:", os.environ["HF_HOME"])
print("Repo root:", DEST)
print("SRC en sys.path:", str(DEST / "src") in sys.path)

In [None]:
# (Local) The repo is already in place: nothing is cloned or pulled here.
# Relies on DEST, defined by the Bootstrap cell.
from pathlib import Path

assert 'DEST' in globals(), "Asegúrate de ejecutar primero la celda Bootstrap que define DEST."
_src_dir = DEST / "src"
assert _src_dir.exists(), f"No se encontró src/ en {DEST}. Revisa tu ruta de trabajo o el montaje del volumen."

print("Repo listo en:", DEST)

In [None]:
# --- 4) Make sure src/ is importable and show the repo layout ---
import sys

# Add src/ only if a previous cell has not already done so (avoids duplicate entries).
_src_path = str(DEST / "src")
if _src_path not in sys.path:
    sys.path.append(_src_path)

# List the actual repo root (DEST). The former `!ls -la /content/repo` pointed at a
# Colab-only path that does not exist in the local Docker environment.
for _entry in sorted(DEST.iterdir()):
    print(("d " if _entry.is_dir() else "- ") + _entry.name)

In [None]:
# Local artifact preparation (Docker + ROCm environment)
import os, sys, zipfile
from pathlib import Path

# Refuse to continue if TensorFlow got imported by accident (it is not used here)
assert "tensorflow" not in sys.modules, "TensorFlow está importado; desactívalo."

# The repo root is DEST from the Bootstrap cell
assert 'DEST' in globals(), "Ejecuta primero la celda Bootstrap (define DEST)."
REPO_ROOT = DEST

# Directory where artifacts are written to / read from
ART_DIR = Path("/workspace/BERTolto/artifacts")
ART_DIR.mkdir(parents=True, exist_ok=True)

# Places where an existing ZIP may live, checked in this order (adjust if needed)
_zip_name = "hf_distilroberta.zip"
ZIP_CANDIDATES = [
    REPO_ROOT / "data" / _zip_name,
    REPO_ROOT / "artifacts" / _zip_name,
    Path.home() / "BERTolto" / "data" / _zip_name,
]

zip_path = None
for _candidate in ZIP_CANDIDATES:
    if _candidate.exists():
        zip_path = _candidate
        break

# Extract only when a ZIP was found and the unpacked folder is not there yet
BASE_ARTI = ART_DIR / "hf_distilroberta"
_needs_extract = bool(zip_path) and not BASE_ARTI.exists()
if _needs_extract:
    print(f"Descomprimiendo: {zip_path} -> {ART_DIR}")
    with zipfile.ZipFile(zip_path, "r") as zf:
        zf.extractall(ART_DIR)

# The unpacked folder must exist at this point
assert BASE_ARTI.exists(), (
    f"No se encontró {BASE_ARTI}. Coloca 'hf_distilroberta.zip' en "
    f"'{REPO_ROOT}/data' o '{REPO_ROOT}/artifacts', o extrae manualmente la carpeta aquí."
)

print("Artifacts en:", BASE_ARTI)
_contents = [entry.name for entry in BASE_ARTI.iterdir()]
print("Contenido:", _contents)


In [None]:
# Lightweight stack check (nothing is installed from inside the notebook)
from importlib.metadata import version as _ver, PackageNotFoundError


def v(pkg):
    """Return the installed version string of distribution `pkg`, or None when it is absent."""
    try:
        found = _ver(pkg)
    except PackageNotFoundError:
        found = None
    return found


_STACK_PKGS = (
    "transformers", "datasets", "accelerate", "huggingface_hub",
    "tokenizers", "evaluate", "peft", "scikit-learn",
)
print("Stack de librerías:", {name: v(name) for name in _STACK_PKGS})
print("OK. Continúa con la celda de entrenamiento.")

In [None]:
# MODEL TRAINING

#*** Load the tokenized dataset and run supervised fine-tuning of DistilRoBERTa (binary 0/1) ***#
from pathlib import Path
import json, sys
import numpy as np
import torch
import torch.nn as nn
from datasets import load_from_disk

# Try to drop a previously imported transformers from the module cache.
# NOTE(review): only the top-level "transformers" key is removed; any "transformers.*"
# submodules stay in sys.modules, so this is not a full reload. Restarting the kernel is
# the reliable way to switch transformers versions.
sys.modules.pop("transformers", None)

import transformers  # import the base package first so its version is the one recorded in sys.modules
print("Transformers version:", transformers.__version__)

from transformers import (
    AutoTokenizer, AutoConfig, AutoModelForSequenceClassification,
    TrainingArguments, Trainer, DataCollatorWithPadding, set_seed
)
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Output directory (persistent volume mounted by Docker)
OUT_DIR = "/checkpoints/run_distilroberta"

# Locate the artifacts.
# Reuse BASE_ARTI when the artifact-preparation cell already defined it;
# otherwise fall back to the standard location.
if "BASE_ARTI" in globals():
    BASE_ARTI = Path(BASE_ARTI)
else:
    BASE_ARTI = Path("/workspace/BERTolto/artifacts/hf_distilroberta")

# Handle a nested hf_distilroberta/hf_distilroberta layout (e.g. from how the ZIP was built)
if not (BASE_ARTI / "dataset").exists() and (BASE_ARTI / "hf_distilroberta").exists():
    BASE_ARTI = BASE_ARTI / "hf_distilroberta"

assert (BASE_ARTI / "dataset").exists() and (BASE_ARTI / "tokenizer").exists(), \
    f"Faltan dataset/ o tokenizer/ en {BASE_ARTI}. Revisa la celda de preparación de artifacts."
print("OK artifacts:", BASE_ARTI)

# Device and GPU info (on ROCm builds the AMD GPU is exposed through torch.cuda)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("GPU disponible:", torch.cuda.is_available(),
      "| HIP:", getattr(torch.version, "hip", None))
if torch.cuda.is_available():
    print("Device 0:", torch.cuda.get_device_name(0))

# Seed Python/NumPy/torch RNGs for reproducibility
set_seed(42)

# --- ROBUST PATCH V2 for loading from disk with datasets==4.x ---
# Guarded by hasattr checks so re-running the cell does not stack patches, and the new
# __post_init__ does not call itself. Works around failures while DatasetInfo rebuilds
# `features` from the saved JSON.
import datasets

DI = datasets.info.DatasetInfo
Feat = datasets.Features

# 1) Patch from_dict once: if building from the saved dict fails, retry with features=None
if not hasattr(DI, "_orig_from_dict"):
    # Keep the original (class-bound) from_dict so the wrapper can delegate to it.
    DI._orig_from_dict = DI.from_dict

    @classmethod
    def _safe_from_dict(cls, dataset_info_dict: dict):
        """Call the original from_dict; on any exception retry once with 'features' set to None."""
        try:
            return cls._orig_from_dict(dataset_info_dict)
        except Exception as e:
            # NOTE(review): any exception (not only a features problem) triggers the retry;
            # if the retry fails too, that second exception propagates to the caller.
            dd = dict(dataset_info_dict or {})
            if "features" in dd:
                dd["features"] = None
            obj = cls._orig_from_dict(dd)
            print("[patch] DatasetInfo.from_dict: ignorando 'features' corruptas -> usando None.",
                  f"({type(e).__name__})")
            return obj

    DI.from_dict = _safe_from_dict

# 2) Replace __post_init__ with a safe, non-recursive version (applied only once)
if not hasattr(DI, "_post_init_patched"):
    def _safe_post_init(self):
        """Convert a non-Features `features` value via Features.from_dict; use None if that fails."""
        try:
            f = getattr(self, "features", None)
            if f is not None and not isinstance(f, Feat):
                try:
                    self.features = Feat.from_dict(f)
                except Exception:
                    self.features = None
        except Exception:
            self.features = None
        # The original __post_init__ is deliberately NOT called.
        # NOTE(review): any other conversions the original performs (presumably turning
        # other fields such as splits back into their classes) are therefore skipped —
        # confirm against the installed datasets version.

    DI.__post_init__ = _safe_post_init
    DI._post_init_patched = True
# --- END ROBUST PATCH V2 ---

# Dataset and tokenizer
ds = load_from_disk(str(BASE_ARTI / "dataset"))

# Split check, with an optional fallback
print("Splits disponibles:", list(ds.keys()))
needed = {"train", "validation", "test"}
avail = set(ds.keys())

# A validation split accidentally saved as "dev" is renamed to "validation"
if "dev" in avail and "validation" not in avail:
    ds["validation"] = ds.pop("dev")
    avail = set(ds.keys())
    print("Renombrado 'dev' -> 'validation'")

# Stop with a clear message when a required split is missing
missing = needed - avail
assert not missing, f"Faltan splits: {missing}. Revisa el paso de tokenización/guardado."

tok = AutoTokenizer.from_pretrained(str(BASE_ARTI / "tokenizer"), use_fast=True)

# Metrics (sklearn; avoids evaluate.load)
def compute_metrics(eval_pred):
    """Binary classification metrics for Trainer evaluation.

    Args:
        eval_pred: ``(predictions, label_ids)`` as passed by ``Trainer`` (an
            ``EvalPrediction`` or plain tuple). ``predictions`` holds the logits; when the
            model returns extra outputs Trainer hands over a tuple whose first element is
            the logits.

    Returns:
        dict with ``accuracy``, ``precision``, ``recall`` and ``f1`` for positive class 1;
        precision/recall/F1 are reported as 0 when they would involve a division by zero.
    """
    logits, labels = eval_pred[0], eval_pred[1]
    if isinstance(logits, (tuple, list)):
        logits = logits[0]
    logits = np.asarray(logits)
    # Class 1 wins only when its logit is strictly larger (ties go to class 0).
    preds = (logits[:, 1] > logits[:, 0]).astype(int)

    acc = accuracy_score(labels, preds)
    pr  = precision_score(labels, preds, zero_division=0)
    rec = recall_score(labels, preds, zero_division=0)
    f1  = f1_score(labels, preds, zero_division=0)

    return {"accuracy": acc, "precision": pr, "recall": rec, "f1": f1}

# Class weights (only if preprocess_meta.json provides them)
cw = None
meta_path = BASE_ARTI / "preprocess_meta.json"
if meta_path.exists():
    # JSON is UTF-8 by definition; do not depend on the container's locale encoding.
    meta = json.loads(meta_path.read_text(encoding="utf-8"))
    if "class_weights" in meta and isinstance(meta["class_weights"], dict):
        # Accepts "0"/"1" or 0/1 as keys; a missing class defaults to weight 1.0
        cw_map = {int(k): float(v) for k, v in meta["class_weights"].items()}
        cw = np.array([cw_map.get(0, 1.0), cw_map.get(1, 1.0)], dtype=np.float32)
        print("Class weights:", cw.tolist())

# Fresh DistilRoBERTa with a 2-way classification head
config = AutoConfig.from_pretrained("distilroberta-base", num_labels=2)
model  = AutoModelForSequenceClassification.from_pretrained("distilroberta-base", config=config)

# Collator that drops keys the model does not use and attaches the labels separately
class SafeCollator(DataCollatorWithPadding):
    """DataCollatorWithPadding that keeps only model inputs before padding.

    Extra columns (e.g. id/context_id kept because remove_unused_columns=False) are
    dropped. Labels are read from 'labels' or, failing that, 'label' and returned as a
    LongTensor under 'labels'. Padding honours this collator's own ``padding``,
    ``max_length``, ``pad_to_multiple_of`` and ``return_tensors`` settings.

    Raises:
        ValueError: if only some features of a batch carry a label.
    """

    _MODEL_KEYS = ("input_ids", "attention_mask", "token_type_ids")

    def __call__(self, features):
        labels = []
        cleaned = []
        for f in features:
            # Accept either 'labels' or 'label'
            if "labels" in f:
                labels.append(f["labels"])
            elif "label" in f:
                labels.append(f["label"])
            # Keep only the keys the model expects
            cleaned.append({k: f[k] for k in self._MODEL_KEYS if k in f})

        # A partially labelled batch would silently misalign labels with rows.
        if labels and len(labels) != len(cleaned):
            raise ValueError(
                f"Only {len(labels)} of {len(cleaned)} features in the batch carry a label."
            )

        batch = self.tokenizer.pad(
            cleaned,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors,
        )
        if labels:
            batch["labels"] = torch.tensor(labels, dtype=torch.long)
        return batch

collator = SafeCollator(tokenizer=tok, return_tensors="pt")

# Trainer with an optionally class-weighted loss and model-input filtering
class WeightedTrainer(Trainer):
    """Trainer whose loss is (optionally class-weighted) cross-entropy over the logits.

    Args:
        class_weights: optional per-class weights (array-like of floats, one per label);
            None gives an unweighted loss. It is the FIRST positional parameter, so pass
            every Trainer argument by keyword.
        *args, **kwargs: forwarded unchanged to ``transformers.Trainer``.
    """

    def __init__(self, class_weights=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.class_weights = (
            torch.tensor(class_weights, dtype=torch.float32) if class_weights is not None else None
        )
        # compute_loss below returns a *mean* loss and ignores num_items_in_batch. Recent
        # Transformers releases skip dividing the loss by the gradient-accumulation steps
        # when they believe the model consumes num_items_in_batch itself; pin this flag so
        # the mean loss is always scaled correctly.
        self.model_accepts_loss_kwargs = False

    # Newer Transformers (e.g. 4.56.x) pass num_items_in_batch; it is accepted and ignored.
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the cross-entropy loss, plus the model outputs when ``return_outputs``.

        Raises:
            ValueError: if neither 'labels' nor 'label' is present in ``inputs``.
        """
        # Accept 'labels' or 'label' (popped so they are not forwarded to the model)
        labels = inputs.pop("labels", None)
        if labels is None:
            labels = inputs.pop("label", None)
        if labels is None:
            raise ValueError("Missing 'labels' in inputs")

        # Drop unexpected keys (e.g. id/context_id) before calling the model
        allowed = {"input_ids", "attention_mask", "token_type_ids"}
        model_inputs = {k: v for k, v in inputs.items() if k in allowed}

        outputs = model(**model_inputs)
        logits = outputs.logits  # [batch, num_labels]
        num_labels = logits.size(-1)

        # Cross-entropy with or without class weights (weight=None means unweighted)
        weight = (
            self.class_weights.to(logits.device) if self.class_weights is not None else None
        )
        loss_fct = nn.CrossEntropyLoss(weight=weight)

        loss = loss_fct(logits.view(-1, num_labels), labels.view(-1).long())
        return (loss, outputs) if return_outputs else loss

# --- TrainingArguments (compatible across Transformers 4.x) ---
from inspect import signature

TA = TrainingArguments
ta_params = set(signature(TA).parameters.keys())

# Newer releases call it 'eval_strategy'; older ones only know 'evaluation_strategy'
# (deprecated in newer releases). Prefer the new name whenever it exists.
if "eval_strategy" in ta_params:
    eval_key = "eval_strategy"
elif "evaluation_strategy" in ta_params:
    eval_key = "evaluation_strategy"
else:
    eval_key = None

# Recent versions expect 'report_to' as a list (an empty list is equivalent to "none")
report_to_val = []

ta_kwargs = dict(
    output_dir=OUT_DIR,
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    greater_is_better=True,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    gradient_accumulation_steps=2,
    num_train_epochs=3,
    learning_rate=2e-5,
    weight_decay=0.01,
    warmup_ratio=0.1,
    fp16=torch.cuda.is_available(),  # on ROCm builds AMP also goes through torch.cuda
    logging_steps=50,
    report_to=report_to_val,
    dataloader_num_workers=2,
    remove_unused_columns=False,   # keep id/context_id in the dataset; SafeCollator drops them
    save_total_limit=2,
    seed=42,
)

# Evaluate every epoch when the argument exists
if eval_key is not None:
    ta_kwargs[eval_key] = "epoch"
else:
    print("[warn] TrainingArguments sin evaluation_strategy/eval_strategy; usando configuración por defecto.")
    # load_best_model_at_end requires evaluation to run on the save schedule; without an
    # evaluation strategy TrainingArguments would reject the configuration.
    ta_kwargs["load_best_model_at_end"] = False

args = TA(**ta_kwargs)

# Trainer renamed 'tokenizer' to 'processing_class' (the old name is deprecated);
# pass the tokenizer under whichever name this version supports.
tok_key = (
    "processing_class"
    if "processing_class" in signature(Trainer.__init__).parameters
    else "tokenizer"
)

trainer = WeightedTrainer(
    class_weights=cw,
    model=model, args=args,
    train_dataset=ds["train"],
    eval_dataset=ds["validation"],
    data_collator=collator,
    compute_metrics=compute_metrics,
    **{tok_key: tok},
)

trainer.train()

SAVE_DIR = "/checkpoints/model_distilroberta_ft"
trainer.save_model(SAVE_DIR)
tok.save_pretrained(SAVE_DIR)
print("Modelo guardado en:", SAVE_DIR)


In [None]:
"""
## Calibración de umbral (en validación, con pooling por comentario)
### Obtenemos logits por ventana y agrupamos por comentario (id + context_id).
### Elegimos el umbral que maximiza recall con precisión ≥ objetivo (ajústalo).
import json
import numpy as np
import torch
from torch.utils.data import DataLoader
from collections import defaultdict
from transformers import DataCollatorWithPadding

try:
    from scipy.special import softmax
except Exception:
    def softmax(x, axis=-1):
        x = np.asarray(x)
        x = x - np.max(x, axis=axis, keepdims=True)
        expx = np.exp(x)
        return expx / np.sum(expx, axis=axis, keepdims=True)

TARGET_PREC = 0.90  # ajusta tu objetivo

collator_eval = DataCollatorWithPadding(tokenizer=tok, return_tensors="pt")

# DataLoader conservando metadatos y aplicando padding dinámico
def make_loader(split, bs=64):
    d = ds[split]

    def collate(ex):
        # features para el modelo (pad con collator)
        feats = []
        for e in ex:
            item = {
                "input_ids": e["input_ids"],
                "attention_mask": e["attention_mask"],
            }
            # estandariza a 'labels' si está disponible
            if "labels" in e:
                item["labels"] = e["labels"]
            elif "label" in e:
                item["labels"] = e["label"]
            feats.append(item)
        batch = collator_eval(feats)
        meta = {"id": [e.get("id") for e in ex], "context_id": [e.get("context_id") for e in ex]}
        return batch, meta

    return DataLoader(d, batch_size=bs, shuffle=False, collate_fn=collate)

# logits + metadatos
def logits_with_meta(split):
    loader = make_loader(split)
    mdl = trainer.model.to(DEVICE).eval()
    all_logits, all_labels, all_ids, all_ctx = [], [], [], []

    with torch.no_grad():
        for (batch, meta) in loader:
            labels = batch.pop("labels", None)
            for k in ("input_ids", "attention_mask", "token_type_ids"):
                if k in batch:
                    batch[k] = batch[k].to(DEVICE)
            out = mdl(**batch).logits.detach().cpu().numpy()
            all_logits.append(out)
            if labels is not None:
                all_labels.append(labels.numpy())
            all_ids.extend(meta["id"])
            all_ctx.extend(meta["context_id"])

    labels_arr = np.concatenate(all_labels) if all_labels else None
    return np.vstack(all_logits), labels_arr, np.array(all_ids), np.array(all_ctx)

# pooling por comentario (máx logit)
def pool_comment_metrics(split, threshold):
    logits, labels, ids, ctx = logits_with_meta(split)
    bucket = defaultdict(lambda: {"logits":[], "label":None})

    # Si no hay labels (no debería ocurrir en val/test), aborta con métricas vacías
    if labels is None:
        return dict(precision=0.0, recall=0.0, f1=0.0, accuracy=0.0, counts=dict(tp=0,fp=0,fn=0,tn=0))

    for i in range(len(ids)):
        key = (ids[i], ctx[i])   # o solo id, si prefieres
        bucket[key]["logits"].append(logits[i])
        bucket[key]["label"] = int(labels[i])

    tp=fp=fn=tn=0

    for _, v in bucket.items():
        L = np.stack(v["logits"], axis=0)   # [num_windows, 2]
        pooled = L.max(axis=0)              # máx por clase
        p1 = softmax(pooled)[1]
        pred = int(p1 >= threshold)
        y    = v["label"]
        tp += int(pred==1 and y==1); fp += int(pred==1 and y==0)
        fn += int(pred==0 and y==1); tn += int(pred==0 and y==0)

    prec = tp/(tp+fp+1e-9); rec = tp/(tp+fn+1e-9)
    f1   = 2*prec*rec/(prec+rec+1e-9); acc = (tp+tn)/(tp+tn+fp+fn+1e-9)

    return dict(precision=prec, recall=rec, f1=f1, accuracy=acc, counts=dict(tp=tp,fp=fp,fn=fn,tn=tn))

# Barrido de umbral en VALIDATION
best_t, best_rec = 0.5, -1.0

for t in np.linspace(0.05, 0.95, 181):
    m = pool_comment_metrics("validation", t)
    if m["precision"] >= TARGET_PREC and m["recall"] > best_rec:
        best_t, best_rec = t, m["recall"]

# guarda threshold
from pathlib import Path as _Path
thr_path = _Path(SAVE_DIR) / "threshold.json"
thr_path.write_text(json.dumps({"threshold": float(best_t),
                                "target_precision": TARGET_PREC,
                                "recall_at_threshold": float(best_rec)}, indent=2), encoding="utf-8")

print("threshold:", best_t, "recall@", TARGET_PREC, "=", best_rec)
"""

In [None]:
## Standalone calibration from disk (no retraining)
from pathlib import Path
from inspect import signature
import json, numpy as np, torch
from datasets import load_from_disk
from transformers import (AutoTokenizer, AutoConfig, AutoModelForSequenceClassification,
                          TrainingArguments, Trainer, DataCollatorWithPadding)

# PATHS — reuse the ones defined by earlier cells when present; otherwise fall back to the
# local Docker layout used by the training cell. Adjust if you changed them.
# NOTE(review): in a fresh kernel the datasets patch and the 'dev' -> 'validation' rename
# from the training cell are not applied here.
BASE_ARTI = Path(globals().get("BASE_ARTI", "/workspace/BERTolto/artifacts/hf_distilroberta"))  # holds dataset/, tokenizer/
SAVE_DIR  = str(globals().get("SAVE_DIR", "/checkpoints/model_distilroberta_ft"))  # final saved model
OUT_TMP   = "/tmp/tmp_calib"

# 1) Dataset and tokenizer
ds  = load_from_disk(str(BASE_ARTI / "dataset"))
tok = AutoTokenizer.from_pretrained(str(BASE_ARTI / "tokenizer"), use_fast=True)

# 2) Final model. If SAVE_DIR is missing, a checkpoint under OUT_DIR/checkpoint-XXXX/ also works.
config = AutoConfig.from_pretrained(SAVE_DIR)
model  = AutoModelForSequenceClassification.from_pretrained(SAVE_DIR, config=config)

# 3) "Light" Trainer used only for predict
args = TrainingArguments(output_dir=OUT_TMP,
                         per_device_eval_batch_size=128,
                         fp16=torch.cuda.is_available(),
                         dataloader_num_workers=2,
                         report_to=[])
# Older Transformers only accept 'tokenizer'; newer ones use 'processing_class'.
_tok_key = "processing_class" if "processing_class" in signature(Trainer.__init__).parameters else "tokenizer"
trainer = Trainer(model=model, args=args,
                  data_collator=DataCollatorWithPadding(tokenizer=tok, return_tensors="pt"),
                  **{_tok_key: tok})


# --- Fast calibration (a single predict call) ---
def _prepare_pred_dataset(split):
    """Return (pred_ds, labels, ids, ctx) for `split`.

    pred_ds keeps only the model inputs plus the label column. ids/ctx default to the row
    index / 0 when the split has no id / context_id columns.
    """
    d = ds[split]
    cols = d.column_names
    label_col = "labels" if "labels" in cols else ("label" if "label" in cols else None)
    assert label_col is not None, f"Split {split} sin columna de label."
    keep = ["input_ids", "attention_mask", label_col] + (["token_type_ids"] if "token_type_ids" in cols else [])
    ids = d["id"] if "id" in cols else list(range(len(d)))
    ctx = d["context_id"] if "context_id" in cols else [0] * len(d)
    pred_ds = d.remove_columns([c for c in cols if c not in keep])
    labels = np.array(d[label_col])
    return pred_ds, labels, np.array(ids), np.array(ctx)


def logits_labels_ids(split):
    """Run trainer.predict on `split`; return (window logits [N, 2], labels, ids, ctx)."""
    pred_ds, labels, ids, ctx = _prepare_pred_dataset(split)
    pred_out = trainer.predict(pred_ds)
    logits = pred_out.predictions  # [N, 2]
    return logits, labels, ids, ctx


def softmax_np(x):
    """Numerically stable softmax over the last axis."""
    x = x - x.max(axis=-1, keepdims=True)
    ex = np.exp(x)
    return ex / ex.sum(axis=-1, keepdims=True)


def _comment_keys(ids, ctx):
    """One string key per window, '<id>::<context_id>' (np.char works on NumPy 1.x and 2.x)."""
    return np.char.add(np.char.add(ids.astype(str), "::"), ctx.astype(str))


def _pool_max_per_comment(p1, labels, keys):
    """Max-pool window probabilities per comment key.

    Returns (pooled_p1, pooled_labels). Each comment's label is taken from its first window
    in stable key order; labels are presumably identical across a comment's windows.
    """
    order = np.argsort(keys, kind="stable")
    keys_s, p1_s, y_s = keys[order], p1[order], labels[order]
    grp_st = np.r_[0, 1 + np.flatnonzero(keys_s[1:] != keys_s[:-1])]  # first index of each group
    return np.maximum.reduceat(p1_s, grp_st), y_s[grp_st].astype(int)


def _pick_threshold(scores, y_true, target_prec):
    """Choose the threshold with maximal recall subject to precision >= target_prec.

    If no threshold reaches the target, the one with maximal precision is used. Only cut
    points at the end of a run of tied scores are considered, because predicting
    ``score >= threshold`` flags every tied comment at once. Among equal recalls the
    highest threshold wins.

    Returns:
        (threshold, precision_at_threshold, recall_at_threshold) as Python floats.
    """
    order = np.argsort(-scores, kind="stable")
    s, y = scores[order], y_true[order]
    tp = np.cumsum(y)
    fp = np.cumsum(1 - y)
    n_pos = y.sum()
    prec = tp / np.maximum(tp + fp, 1e-9)
    rec = tp / np.maximum(n_pos, 1e-9)
    valid = np.r_[s[1:] != s[:-1], True]  # last position of each tie group
    meets = valid & (prec >= target_prec)
    if meets.any():
        cand, metric = np.flatnonzero(meets), rec
    else:
        cand, metric = np.flatnonzero(valid), prec
    best = cand[np.argmax(metric[cand])]
    return float(s[best]), float(prec[best]), float(rec[best])


logits_val, labels_val, ids_val, ctx_val = logits_labels_ids("validation")
p1_val = softmax_np(logits_val)[:, 1]
if p1_val.size == 0:
    raise ValueError("El split 'validation' está vacío; no se puede calibrar.")

# Comment-level score = max over its windows of P(class 1)
pooled_p1, pooled_y = _pool_max_per_comment(p1_val, labels_val, _comment_keys(ids_val, ctx_val))

TARGET_PREC = 0.90
best_thr, best_prec, best_rec = _pick_threshold(pooled_p1, pooled_y, TARGET_PREC)

Path(SAVE_DIR, "threshold.json").write_text(json.dumps({
    "threshold": best_thr,
    "target_precision": TARGET_PREC,
    "precision_at_threshold": best_prec,
    "recall_at_threshold": best_rec
}, indent=2), encoding="utf-8")

print(f"[calibración] threshold={best_thr:.4f} | precision={best_prec:.4f} | recall={best_rec:.4f}")


In [None]:
"""
## Calibración de umbral (REEMPLAZADO — usa trainer.predict)
### Obtenemos logits por ventana y agrupamos por comentario (id + context_id).
### Elegimos el umbral que maximiza recall con precisión ≥ objetivo (ajústalo).

import json
import numpy as np
from collections import defaultdict

try:
    from scipy.special import softmax
except Exception:
    def softmax(x, axis=-1):
        x = np.asarray(x)
        x = x - np.max(x, axis=axis, keepdims=True)
        expx = np.exp(x)
        return expx / np.sum(expx, axis=axis, keepdims=True)

TARGET_PREC = 0.90  # ajusta tu objetivo

def _prepare_pred_dataset(split):

    #Devuelve (pred_ds, labels, ids, ctx) para usar con trainer.predict,
    #eliminando columnas no usadas por el modelo para evitar errores.

    d = ds[split]
    cols = d.column_names

    # Identifica columna de label
    if "labels" in cols:
        label_col = "labels"
    elif "label" in cols:
        label_col = "label"
    else:
        raise AssertionError(f"No se encontró columna de label en split='{split}'.")

    # Columnas necesarias para el modelo
    keep = ["input_ids", "attention_mask", label_col]
    if "token_type_ids" in cols:
        keep.append("token_type_ids")

    # Extrae metadatos antes de eliminar columnas
    ids = d["id"] if "id" in cols else list(range(len(d)))
    ctx = d["context_id"] if "context_id" in cols else [0] * len(d)
    labels = np.array(d[label_col])

    # Dataset solo con columnas esperadas por el modelo
    pred_ds = d.remove_columns([c for c in cols if c not in keep])
    return pred_ds, labels, np.array(ids), np.array(ctx)

def logits_with_meta(split):

    #Usa trainer.predict sobre un dataset reducido (sin columnas extra)
    #y devuelve (logits, labels, ids, ctx).

    pred_ds, labels, ids, ctx = _prepare_pred_dataset(split)
    pred_out = trainer.predict(pred_ds)
    logits = pred_out.predictions  # shape [N, 2]
    return logits, labels, ids, ctx

# pooling por comentario (máx logit)
def pool_comment_metrics(split, threshold):
    logits, labels, ids, ctx = logits_with_meta(split)
    bucket = defaultdict(lambda: {"logits": [], "label": None})

    for i in range(len(ids)):
        key = (ids[i], ctx[i])   # o solo id, si prefieres
        bucket[key]["logits"].append(logits[i])
        bucket[key]["label"] = int(labels[i])

    tp = fp = fn = tn = 0

    for _, v in bucket.items():
        L = np.stack(v["logits"], axis=0)   # [num_windows, 2]
        pooled = L.max(axis=0)              # máx por clase
        p1 = softmax(pooled)[1]
        pred = int(p1 >= threshold)
        y = v["label"]
        tp += int(pred == 1 and y == 1); fp += int(pred == 1 and y == 0)
        fn += int(pred == 0 and y == 1); tn += int(pred == 0 and y == 0)

    prec = tp / (tp + fp + 1e-9); rec = tp / (tp + fn + 1e-9)
    f1   = 2 * prec * rec / (prec + rec + 1e-9)
    acc  = (tp + tn) / (tp + tn + fp + fn + 1e-9)

    return dict(precision=prec, recall=rec, f1=f1, accuracy=acc,
                counts=dict(tp=tp, fp=fp, fn=fn, tn=tn))

# Barrido de umbral en VALIDATION
best_t, best_rec = 0.5, -1.0
for t in np.linspace(0.05, 0.95, 181):
    m = pool_comment_metrics("validation", t)
    if m["precision"] >= TARGET_PREC and m["recall"] > best_rec:
        best_t, best_rec = t, m["recall"]

# Guardado del threshold
from pathlib import Path as _Path
thr_path = _Path(SAVE_DIR) / "threshold.json"
thr_path.write_text(json.dumps({
    "threshold": float(best_t),
    "target_precision": TARGET_PREC,
    "recall_at_threshold": float(best_rec)
}, indent=2), encoding="utf-8")

print("threshold:", best_t, "recall@", TARGET_PREC, "=", best_rec)
"""

In [None]:
## Final TEST evaluation (comment-level pooling)
import json
from pathlib import Path

import numpy as np

# The calibration cell must have run first: it defines logits_labels_ids / softmax_np,
# sets SAVE_DIR and writes threshold.json.
for _required in ("logits_labels_ids", "softmax_np", "SAVE_DIR"):
    if _required not in globals():
        raise NameError(f"Falta '{_required}': ejecuta primero la celda de calibración.")


def pool_comment_metrics(split, threshold):
    """Comment-level metrics for `split` at `threshold`, scored exactly like the calibration.

    Each window gets p1 = softmax(logits)[:, 1]; windows sharing (id, context_id) are
    max-pooled, and a comment is predicted positive when its pooled p1 >= threshold.
    The comment label is taken from its first window (labels are presumably identical
    across a comment's windows).

    Returns:
        dict(precision, recall, f1, accuracy, counts=dict(tp, fp, fn, tn)) holding plain
        Python numbers, so the result is JSON-serialisable.
    """
    logits, labels, ids, ctx = logits_labels_ids(split)
    p1 = softmax_np(logits)[:, 1]
    keys = np.char.add(np.char.add(ids.astype(str), "::"), ctx.astype(str))
    uniq, first_idx, inv = np.unique(keys, return_index=True, return_inverse=True)

    pooled = np.full(uniq.shape[0], -np.inf)
    np.maximum.at(pooled, inv.reshape(-1), p1)
    y = np.asarray(labels)[first_idx].astype(int)
    pred = (pooled >= threshold).astype(int)

    tp = int(np.sum((pred == 1) & (y == 1))); fp = int(np.sum((pred == 1) & (y == 0)))
    fn = int(np.sum((pred == 0) & (y == 1))); tn = int(np.sum((pred == 0) & (y == 0)))

    prec = tp / (tp + fp + 1e-9); rec = tp / (tp + fn + 1e-9)
    f1   = 2 * prec * rec / (prec + rec + 1e-9)
    acc  = (tp + tn) / (tp + tn + fp + fn + 1e-9)
    return dict(precision=prec, recall=rec, f1=f1, accuracy=acc,
                counts=dict(tp=tp, fp=fp, fn=fn, tn=tn))


thr = json.loads((Path(SAVE_DIR) / "threshold.json").read_text(encoding="utf-8"))["threshold"]
val_metrics  = pool_comment_metrics("validation", thr)
test_metrics = pool_comment_metrics("test", thr)

(Path(SAVE_DIR) / "eval_comment_level.json").write_text(
    json.dumps({"validation": val_metrics, "test": test_metrics}, indent=2),
    encoding="utf-8"
)

(Path(SAVE_DIR) / "inference_meta.json").write_text(
    json.dumps({"pooling": "max", "threshold": float(thr),
                "pooled_score": "max over windows of softmax(logits)[:, 1]"}, indent=2),
    encoding="utf-8"
)

print("VAL:", val_metrics)
print("TEST:", test_metrics)
print("Guardado en:", SAVE_DIR)

In [None]:
## "Pack" final listo para reutilizar/re-entrenar
DEPLOY_DIR = "/content/drive/MyDrive/BERTolto/deploy_distilroberta"
!rm -rf "$DEPLOY_DIR"
!mkdir -p "$DEPLOY_DIR"
!cp -r "$SAVE_DIR"/* "$DEPLOY_DIR"/
print("Pack listo en:", DEPLOY_DIR)
!ls -la "$DEPLOY_DIR"