# 3.0 — Baselines clasificación binaria (Malignant vs nonMalignant)

- `y=1`: `Class_group == "Malignant"`
- `y=0`: `Class_group == "nonMalignant"`

Este notebook:
1) Carga `train/test`.
2) Construye `gene_cols` (columnas que empiezan por `ENSG`).
3) Lanza experimentos baseline con MLflow.
4) Selecciona un candidato (minimizando FN) y lo guarda en `models/`.


In [1]:
from __future__ import annotations
from pathlib import Path

from dataclasses import replace
import pandas as pd
import numpy as np


from genomics_dl.models.train_binary import BinaryTrainConfig, run_training


In [2]:
# Paths
DATA_PROCESSED = Path("../data/processed")
TRAIN_PATH = DATA_PROCESSED / "gse183635_tep_tpm_train.parquet"
TEST_PATH  = DATA_PROCESSED / "gse183635_tep_tpm_test.parquet"

df_train = pd.read_parquet(TRAIN_PATH)
df_test  = pd.read_parquet(TEST_PATH)

df_train.shape, df_test.shape

((1880, 5452), (471, 5452))

In [3]:
metadata_cols = [
    "Sample ID",
    "Patient_group",
    "Stage",
    "Sex",
    "Age",
    "Sample-supplying institution",
    "Training series",
    "Evaluation series",
    "Validation series",
    "lib.size",
    "classificationScoreCancer",
    "Class_group",
]

# Genes: columnas ENSG...
gene_cols = [c for c in df_train.columns if str(c).startswith("ENSG")]

# Checks
assert "Class_group" in df_train.columns
assert len(gene_cols) > 0
assert set(gene_cols).isdisjoint(set(metadata_cols))

len(gene_cols), gene_cols[:5]

(5440,
 ['ENSG00000000419',
  'ENSG00000000460',
  'ENSG00000000938',
  'ENSG00000001036',
  'ENSG00000001461'])

In [4]:
# Sanity: etiquetas
df_train["Class_group"].value_counts(dropna=False)

Class_group
Malignant       1302
nonMalignant     578
Name: count, dtype: int64

## Experimentos baseline

Estrategia:
- Barrido pequeño de `pos_weight` (penaliza más errores en Malignant).
- Comparación sin PCA / con PCA.
- Selección de umbral con constraint de recall (sensibilidad).

Para no llenar `models/` durante el barrido: `save_local_bundle=False`.
Luego re-entrenamos el candidato final con `save_local_bundle=True`.


In [5]:
base_cfg = BinaryTrainConfig(
    train_path=TRAIN_PATH,
    test_path=TEST_PATH,
    label_col="Class_group",
    positive_label="Malignant",
    experiment_name="gse183635_binary",
    model_name="logreg_binary",
    model_version="v0.1.0",

    # FN focus
    min_recall_for_threshold=0.95,
    
    save_local_bundle=False,
    save_plots=False,
    mlflow_log_artifacts=False,
    mlflow_log_model=False,
)


sweep = []
classifiers = ["logreg", "sgd_logloss", "linear_svc_calibrated", "rf", "extratrees"]
for use_pca in [False, True]:
    for pos_weight in [1.0, 2.0, 5.0]:
        sweep.append((use_pca, pos_weight))

results = []
for clf_name in classifiers:
    for use_pca, pos_weight in sweep:
        cfg = replace(
            base_cfg,
            use_pca=use_pca,
            pos_weight=pos_weight,
            # nombre distinto en MLflow para que sea más legible
            model_name=f"{clf_name}_pca{int(use_pca)}_pw{int(pos_weight)}",
        )
        out = run_training(cfg, feature_cols=gene_cols)
        results.append({
            "model_name": cfg.model_name,
            "use_pca": use_pca,
            "pos_weight": pos_weight,
            "threshold": out["chosen_threshold"],
            "test_fn": out["test_metrics"]["fn"],
            "test_fnr": out["test_metrics"]["fnr"],
            "test_recall": out["test_metrics"]["recall_sensitivity"],
            "test_specificity": out["test_metrics"]["specificity"],
            "test_pr_auc": out["test_metrics"]["pr_auc"],
            "test_roc_auc": out["test_metrics"]["roc_auc"],
            "mlflow_run_id": out["mlflow_run_id"],
        })

res_df = (
    pd.DataFrame(results)
      .sort_values(["test_fn", "test_fnr", "test_recall"], ascending=[True, True, False])
      .reset_index(drop=True)
)

2025/12/21 14:45:59 INFO mlflow.store.db.utils: Creating initial MLflow database tables...
2025/12/21 14:46:00 INFO mlflow.store.db.utils: Updating database tables
2025/12/21 14:46:00 INFO alembic.runtime.migration: Context impl SQLiteImpl.
2025/12/21 14:46:00 INFO alembic.runtime.migration: Will assume non-transactional DDL.
2025/12/21 14:46:01 INFO alembic.runtime.migration: Context impl SQLiteImpl.
2025/12/21 14:46:01 INFO alembic.runtime.migration: Will assume non-transactional DDL.


In [10]:
display(res_df)

Unnamed: 0,model_name,use_pca,pos_weight,threshold,test_fn,test_fnr,test_recall,test_specificity,test_pr_auc,test_roc_auc,mlflow_run_id
0,logreg_pca1_pw1,True,1.0,0.016757,20.0,0.06135,0.93865,0.331034,0.911972,0.839454,de84a4cc17664e4baa1b1471962a7697
1,logreg_pca1_pw2,True,2.0,0.015687,20.0,0.06135,0.93865,0.337931,0.91132,0.837148,45308ceae0fb4e55a8ec0541a38c4a5c
2,sgd_logloss_pca1_pw1,True,1.0,0.016757,20.0,0.06135,0.93865,0.331034,0.911972,0.839454,c2fa46e3b65c4568bf8c7435adbccd3a
3,sgd_logloss_pca1_pw2,True,2.0,0.015687,20.0,0.06135,0.93865,0.337931,0.91132,0.837148,c211d85aefe44ba4aaa5eb2c392b03fa
4,linear_svc_calibrated_pca1_pw1,True,1.0,0.016757,20.0,0.06135,0.93865,0.331034,0.911972,0.839454,d7b1e2700d5b4b5998242071f24802f9
5,linear_svc_calibrated_pca1_pw2,True,2.0,0.015687,20.0,0.06135,0.93865,0.337931,0.91132,0.837148,64a8831c919747ecaa6c44ab92f9b588
6,rf_pca1_pw1,True,1.0,0.016757,20.0,0.06135,0.93865,0.331034,0.911972,0.839454,35288f1a8ef348a68b07e7ee8a0ab417
7,rf_pca1_pw2,True,2.0,0.015687,20.0,0.06135,0.93865,0.337931,0.91132,0.837148,974f0a397da04167a22affab7ab97d01
8,extratrees_pca1_pw1,True,1.0,0.016757,20.0,0.06135,0.93865,0.331034,0.911972,0.839454,271289e0a47c4fc4b0169f4694f646ed
9,extratrees_pca1_pw2,True,2.0,0.015687,20.0,0.06135,0.93865,0.337931,0.91132,0.837148,a7caa929a58d4af994da6d328a56a0a7


## Entrenamiento final (guardar en `models/`)

Criterio por defecto aquí: menor FN en test (y luego menor FNR). Ajusta si quieres imponer un mínimo de especificidad.


In [14]:
# Elegimos el mejor del sweep
best = res_df.iloc[0].to_dict()
display(best)


{'model_name': 'logreg_pca1_pw1',
 'use_pca': True,
 'pos_weight': 1.0,
 'threshold': 0.016757279223890992,
 'test_fn': 20.0,
 'test_fnr': 0.06134969325153374,
 'test_recall': 0.9386503067484663,
 'test_specificity': 0.3310344827586207,
 'test_pr_auc': 0.9119715161788899,
 'test_roc_auc': 0.8394541992807277,
 'mlflow_run_id': 'de84a4cc17664e4baa1b1471962a7697'}

In [15]:
final_cfg = replace(
    base_cfg,
    use_pca=bool(best["use_pca"]),
    pos_weight=float(best["pos_weight"]),
    model_name="logreg_binary",
    model_version="v0.2.0",

    save_local_bundle=True,

    save_plots=False,
    mlflow_log_artifacts=False,

    mlflow_log_model=True,
)
final_out = run_training(final_cfg, feature_cols=gene_cols)



In [8]:
# DataFrame de métricas (CV vs Test)
cv = final_out["cv_metrics"].copy()
test = final_out["test_metrics"].copy()

chosen_thr = final_out["chosen_threshold"]
cv["threshold"] = chosen_thr
test["threshold"] = chosen_thr

metrics_df = (
    pd.DataFrame({"cv": cv, "test": test})
      .reset_index()
      .rename(columns={"index": "metric"})
)

priority_order = [
    "threshold",
    "fn", "fnr",
    "recall_sensitivity",
    "specificity",
    "fp", "tn", "tp",
    "precision", "npv",
    "f1",
    "balanced_accuracy",
    "pr_auc",
    "roc_auc",
    "fpr",
]

metrics_df["metric"] = pd.Categorical(metrics_df["metric"], categories=priority_order, ordered=True)
metrics_df = metrics_df.sort_values("metric").reset_index(drop=True)

In [9]:
display(metrics_df)

Unnamed: 0,metric,cv,test
0,threshold,0.016757,0.016757
1,fn,65.0,20.0
2,fnr,0.049923,0.06135
3,recall_sensitivity,0.950077,0.93865
4,specificity,0.299308,0.331034
5,fp,405.0,97.0
6,tn,173.0,48.0
7,tp,1237.0,306.0
8,precision,0.75335,0.759305
9,npv,0.726891,0.705882
