In [8]:
from pathlib import Path
import pandas as pd
import re

base = Path("./")
flat = base / "data" / "flatfile"
gz   = base / "data" / "gzip"

counts_path = gz / "GSE211781_raw_counts_GRCh38.p13_NCBI.tsv.gz"
series_path = gz / "GSE211781_series_matrix.txt.gz"
annot_path  = gz / "Human.GRCh38.p13.annot.tsv.gz"

if not counts_path.exists():
    counts_path = flat / "GSE211781_raw_counts_GRCh38.p13_NCBI.tsv"
if not series_path.exists():
    series_path = flat / "GSE211781_series_matrix.txt"
if not annot_path.exists():
    annot_path = flat / "Human.GRCh38.p13.annot.tsv"

# 1) Leer counts (genes x muestras)
counts_raw = pd.read_csv(
    counts_path,
    sep="\t",
    header=0,
    index_col=0
)

# 2) Leer anotaciones y agregar Gene_Name
#     -> dejamos índice como MultiIndex (Gene_ID, Gene_Name)
gene_annotations = pd.read_csv(
    annot_path,
    sep="\t",
    header=0,
    comment="#",
    low_memory=False
)

def pick_col(df, candidates):
    for c in candidates:
        if c in df.columns:
            return c
    return None

gene_id_col = pick_col(gene_annotations, [
    "Gene_ID", "gene_id", "GeneID", "geneID", "Ensembl_ID", "ensembl_gene_id", "gene"
])
gene_name_col = pick_col(gene_annotations, [
    "Gene_Name", "gene_name", "GeneName", "gene_symbol", "Symbol", "HGNC_symbol", "gene_name_hgnc"
])

# Mapa mínimo robusto
if gene_id_col is not None and gene_name_col is not None:
    ann_map = (
        gene_annotations[[gene_id_col, gene_name_col]]
        .dropna(subset=[gene_id_col])
        .drop_duplicates()
        .rename(columns={gene_id_col: "Gene_ID", gene_name_col: "Gene_Name"})
        .set_index("Gene_ID")
    )

    counts_with_names = (
        counts_raw
        .reset_index()
        .rename(columns={counts_raw.index.name or "index": "Gene_ID"})
    )

    # Si el índice de counts no tenía nombre, pandas pone "index"
    if "Gene_ID" not in counts_with_names.columns:
        counts_with_names = counts_raw.reset_index().rename(columns={"index": "Gene_ID"})

    counts_with_names["Gene_Name"] = counts_with_names["Gene_ID"].map(ann_map["Gene_Name"])

    counts_with_names = counts_with_names.set_index(["Gene_ID", "Gene_Name"])
else:
    # Fallback: si no encontramos columnas esperadas en anotación
    counts_with_names = counts_raw.copy()
    counts_with_names.index.name = "Gene_ID"

# 3) Parser ligero del series_matrix local
def parse_geo_series_matrix(path: Path) -> pd.DataFrame:
    if path.suffix == ".gz":
        import gzip
        with gzip.open(path, "rt", encoding="utf-8", errors="replace") as f:
            lines = f.readlines()
    else:
        with open(path, "r", encoding="utf-8", errors="replace") as f:
            lines = f.readlines()

    sample_lines = [ln.rstrip("\n") for ln in lines if ln.startswith("!Sample_")]

    keys, values_matrix = [], []
    for ln in sample_lines:
        parts = ln.split("\t")
        key = parts[0].replace("!Sample_", "")
        vals = [v.strip().strip('"') for v in parts[1:]]
        keys.append(key)
        values_matrix.append(vals)

    if not values_matrix:
        return pd.DataFrame()

    n_samples = max(len(v) for v in values_matrix)

    norm = []
    for vals in values_matrix:
        if len(vals) < n_samples:
            vals = vals + [None] * (n_samples - len(vals))
        norm.append(vals)

    tmp = pd.DataFrame(norm, index=keys)

    if "geo_accession" in tmp.index:
        sample_ids = tmp.loc["geo_accession"].tolist()
    else:
        sample_ids = [f"Sample_{i+1}" for i in range(n_samples)]

    tmp.columns = sample_ids
    pData = tmp.T
    pData.index.name = "Sample_ID"
    return pData

pData_raw = parse_geo_series_matrix(series_path)

# 4) Construir etiquetas robustas (buscando en TODAS las characteristics)
def safe_series(df, col, default=""):
    return df[col] if col in df.columns else pd.Series([default]*len(df), index=df.index)

title = safe_series(pData_raw, "title")
description = safe_series(pData_raw, "description")
source_name = safe_series(pData_raw, "source_name_ch1")

# Combina todas las columnas de characteristics disponibles
char_cols = [c for c in pData_raw.columns if c.startswith("characteristics_ch1")]
if char_cols:
    chars_joined = (
        pData_raw[char_cols]
        .fillna("")
        .astype(str)
        .agg(" | ".join, axis=1)
    )
else:
    chars_joined = pd.Series([""]*len(pData_raw), index=pData_raw.index)

# Intento de extraer la porción de tratamiento si aparece
# (si no aparece explícitamente, usamos el texto completo)
treat_clean = chars_joined.str.replace(r".*treatment:\s*", "", regex=True).fillna(chars_joined)

def parse_drug(x: str) -> str:
    if re.search(r"Apalutamide|ARN", x, flags=re.I): return "Apalutamide"
    if re.search(r"Bicalutamide|BIC", x, flags=re.I): return "Bicalutamide"
    if re.search(r"Enzalutamide|ENZ", x, flags=re.I): return "Enzalutamide"
    if re.search(r"Vehicle|VEH|Control", x, flags=re.I): return "Control"
    return "Unknown"

def parse_dht(x: str) -> str:
    if re.search(r"DHT\+", x): return "Plus_DHT"
    if re.search(r"DHT-|No DHT|without DHT|VEH\+", x, flags=re.I): return "No_DHT"
    return "No_DHT"

def parse_rep(t: str, d: str) -> str:
    s = f"{t} {d}"
    if re.search(r"rep\s*1|rep1|biol\s*rep\s*1", s, flags=re.I): return "Rep1"
    if re.search(r"rep\s*2|rep2|biol\s*rep\s*2", s, flags=re.I): return "Rep2"
    if re.search(r"rep\s*3|rep3|biol\s*rep\s*3", s, flags=re.I): return "Rep3"
    return "Rep1"

sample_info = pd.DataFrame({
    "Sample_ID": pData_raw.index,
    "title": title.values,
    "description": description.values,
    "source_name_ch1": source_name.values,
    "Treatment_Clean": treat_clean.values,
})

sample_info["Drug"] = sample_info["Treatment_Clean"].map(parse_drug)
sample_info["DHT"] = sample_info["Treatment_Clean"].map(parse_dht)
sample_info["Replicate"] = [
    parse_rep(t, d) for t, d in zip(sample_info["title"], sample_info["description"])
]

sample_info["GSM"] = sample_info["Sample_ID"].str.extract(r"(GSM\d+)", expand=False)

# 5) Alinear metadata con columnas de counts por GSM
count_colnames = list(counts_raw.columns)

def extract_gsm(name):
    m = re.search(r"(GSM\d+)", str(name))
    return m.group(1) if m else None

counts_gsm = [extract_gsm(c) for c in count_colnames]

if all(g is not None for g in counts_gsm) and sample_info["GSM"].notna().all():
    sample_info = (
        sample_info.set_index("GSM")
        .loc[counts_gsm]
        .reset_index()
    )
else:
    # Si no hay GSM en columnas, intentamos alineación por tamaño
    if counts_raw.shape[1] == sample_info.shape[0]:
        counts_raw.columns = sample_info["Sample_ID"].tolist()

# 6) Acortar nombres de columnas al contexto del problema
#     (CTL, BIC, ENZ, ARN) + (DHT+ / DHT-) + (R1/R2/...)
drug_short_map = {
    "Control": "CTL",
    "Bicalutamide": "BIC",
    "Enzalutamide": "ENZ",
    "Apalutamide": "ARN",
    "Unknown": "UNK"
}
dht_short_map = {
    "Plus_DHT": "DHT+",
    "No_DHT": "DHT-"
}

sample_info["Drug_short"] = sample_info["Drug"].map(lambda x: drug_short_map.get(x, "UNK"))
sample_info["DHT_short"] = sample_info["DHT"].map(lambda x: dht_short_map.get(x, "DHT-"))
sample_info["Rep_short"] = sample_info["Replicate"].str.replace("Rep", "R", regex=False)

# etiqueta compacta final
sample_info["Short_Label"] = (
    sample_info["Drug_short"] + "_" +
    sample_info["DHT_short"] + "_" +
    sample_info["Rep_short"]
)

# 7) Crear dataframe final:
#     - genes en filas con (Gene_ID, Gene_Name)
#     - muestras en columnas con nombres cortos
#     - metadata guardada aparte
# Reaplicamos las columnas al objeto con nombres de genes
counts_final = counts_with_names.copy()

# Asegurar que tenemos una fila de metadata por columna
if counts_raw.shape[1] == sample_info.shape[0]:
    counts_final.columns = sample_info["Short_Label"].tolist()

print("Counts final shape:", counts_final.shape)
print("Ejemplo de columnas cortas:", list(counts_final.columns[:8]))
print("Ejemplo de índice (genes):", counts_final.index[:3])

Counts final shape: (39376, 16)
Ejemplo de columnas cortas: ['ARN_DHT+_R1', 'ARN_DHT+_R2', 'ARN_DHT-_R1', 'ARN_DHT-_R2', 'BIC_DHT+_R1', 'BIC_DHT+_R2', 'BIC_DHT-_R1', 'BIC_DHT-_R2']
Ejemplo de índice (genes): MultiIndex([(100287102,   'DDX11L1'),
            (   653635,    'WASH7P'),
            (102466751, 'MIR6859-1')],
           names=['Gene_ID', 'Gene_Name'])


In [9]:
counts_final.columns = (
    counts_final.columns
    # 1) Quitar el marcador de No_DHT
    .str.replace(r"_DHT-_", "_", regex=True)
    # 2) Quitar el signo + en los que sí tienen DHT
    .str.replace(r"_DHT\+_", "_DHT_", regex=True)
)

In [10]:
counts_final.columns

Index(['ARN_DHT_R1', 'ARN_DHT_R2', 'ARN_R1', 'ARN_R2', 'BIC_DHT_R1',
       'BIC_DHT_R2', 'BIC_R1', 'BIC_R2', 'CTL_DHT_R1', 'CTL_DHT_R2',
       'ENZ_DHT_R1', 'ENZ_DHT_R2', 'ENZ_R1', 'ENZ_R2', 'CTL_R1', 'CTL_R2'],
      dtype='object')

In [11]:
import numpy as np

# =========================================================
# Normalización estilo DESeq2 (median-of-ratios)
# =========================================================
def deseq2_size_factors(counts: pd.DataFrame) -> pd.Series:
    """
    Calcula size factors siguiendo la idea de DESeq2:
    1) Media geométrica por gen a través de muestras (ignorando ceros)
    2) Para cada muestra, mediana de (count / geo_mean) sobre genes válidos
    """
    X = counts.to_numpy(dtype=float)

    # Máscara de valores > 0
    positive = X > 0

    # Log solo donde es positivo
    logX = np.zeros_like(X, dtype=float)
    logX[positive] = np.log(X[positive])

    # Media geométrica por gen:
    # mean(log(count)) sobre muestras positivas
    n_pos = positive.sum(axis=1)
    with np.errstate(divide="ignore", invalid="ignore"):
        mean_log = logX.sum(axis=1) / n_pos
    geo_means = np.exp(mean_log)

    # Genes válidos: al menos 1 positivo y geo_mean finito y > 0
    valid_genes = (n_pos > 0) & np.isfinite(geo_means) & (geo_means > 0)

    # Ratios por muestra sobre genes válidos
    ratios = X[valid_genes, :] / geo_means[valid_genes][:, None]

    # Mediana por columna (muestra)
    size_factors = np.median(ratios, axis=0)

    return pd.Series(size_factors, index=counts.columns, name="size_factor")


# 1) Calcular size factors
size_factors = deseq2_size_factors(counts_final)
print(size_factors)

# 1) Dejar solo Gene_Name en el índice
if isinstance(counts_final.index, pd.MultiIndex):
    # Si los niveles tienen nombre explícito
    if "Gene_ID" in counts_final.index.names:
        counts_final = counts_final.droplevel("Gene_ID")
    else:
        # Fallback: asumimos que Gene_ID es el primer nivel
        counts_final = counts_final.droplevel(0)

    counts_final.index.name = "Gene_Name"

elif "Gene_ID" in counts_final.columns and "Gene_Name" in counts_final.columns:
    counts_final = counts_final.set_index("Gene_Name").drop(columns=["Gene_ID"])

else:
    # Si ya está en Gene_Name o no hay Gene_ID explícito, solo renombramos por claridad
    counts_final.index.name = counts_final.index.name or "Gene_Name"


# 2) Si hay Gene_Name duplicados, sumarlos
if counts_final.index.duplicated().any():
    counts_final = counts_final.groupby(counts_final.index).sum()
    counts_final.index.name = "Gene_Name"

# 2) Normalizar counts
counts_norm = counts_final.div(size_factors, axis=1)

print("Raw shape:", counts_final.shape)
print("Norm shape:", counts_norm.shape)

# 3) Log2 de counts normalizados
counts_norm_log2 = np.log2(counts_norm + 1)


ARN_DHT_R1    1.035198
ARN_DHT_R2    0.655074
ARN_R1        0.983638
ARN_R2        1.270939
BIC_DHT_R1    1.002430
BIC_DHT_R2    0.674614
BIC_R1        0.863511
BIC_R2        1.034636
CTL_DHT_R1    1.000000
CTL_DHT_R2    0.551351
ENZ_DHT_R1    1.151725
ENZ_DHT_R2    0.888903
ENZ_R1        0.978747
ENZ_R2        1.117453
CTL_R1        0.948579
CTL_R2        1.069801
Name: size_factor, dtype: float64
Raw shape: (39374, 16)
Norm shape: (39374, 16)


In [14]:
# Eliminamos de counts_norm_log2 las columnas de ARN_R1, ARN_R2, BIC_R1, BIC_R2, ENZ_R1, ENZ_R2
cols_to_remove = [col for col in counts_norm_log2.columns if re.match(r'(ARN|BIC|ENZ|CTL)_R[12]', col)]
counts_norm_log2 = counts_norm_log2.drop(columns=cols_to_remove)
counts_norm_log2

Unnamed: 0_level_0,ARN_DHT_R1,ARN_DHT_R2,BIC_DHT_R1,BIC_DHT_R2,CTL_DHT_R1,CTL_DHT_R2,ENZ_DHT_R1,ENZ_DHT_R2
Gene_Name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
A1BG,8.744624,7.910898,8.120633,7.754471,7.813781,8.044341,8.912943,7.910619
A1BG-AS1,8.866078,8.585575,8.686505,8.472798,8.588715,8.322901,8.886718,8.616526
A1CF,5.273291,4.907101,4.640495,5.133134,3.000000,3.331796,5.159085,4.971521
A2M,4.344827,3.723817,2.582046,3.983977,3.459432,4.114662,3.717614,3.614688
A2M-AS1,6.179589,6.154233,5.777922,6.202302,5.727920,5.838587,5.706739,5.999977
...,...,...,...,...,...,...,...,...
ZYG11A,6.833167,6.602474,6.581498,6.172963,6.303781,6.010789,6.290488,6.560310
ZYG11B,10.219262,11.092731,10.653820,11.024912,10.710806,10.895888,10.053771,10.912960
ZYX,11.478077,10.957510,11.194331,10.741136,11.661333,11.003043,11.418783,10.958529
ZZEF1,11.625184,11.721834,11.876083,11.626989,12.298349,12.027975,11.582944,11.647604


In [15]:
counts_norm_log2.to_csv(base / "data" / "counts_normalized_log2_shortlabels.csv")

In [16]:
counts_norm_log2_transpose = counts_norm_log2.T
counts_norm_log2_transpose["Label"] = ["ARN" if "ARN" in lab_index else "BIC" if "BIC" in lab_index else "ENZ" if "ENZ" in lab_index else "CTL" for lab_index in counts_norm_log2_transpose.index]
counts_norm_log2_transpose

Gene_Name,A1BG,A1BG-AS1,A1CF,A2M,A2M-AS1,A2ML1,A2MP1,A3GALT2,A4GALT,A4GNT,...,ZWINT,ZXDA,ZXDB,ZXDC,ZYG11A,ZYG11B,ZYX,ZZEF1,ZZZ3,Label
ARN_DHT_R1,8.744624,8.866078,5.273291,4.344827,6.179589,6.869377,2.543495,2.764685,7.207813,0.975263,...,9.645383,8.084701,9.483492,10.744538,6.833167,10.219262,11.478077,11.625184,11.040229,ARN
ARN_DHT_R2,7.910898,8.585575,4.907101,3.723817,6.154233,7.249142,3.881558,2.019023,7.305918,0.0,...,10.46927,8.727421,10.081245,10.934063,6.602474,11.092731,10.95751,11.721834,11.538534,ARN
BIC_DHT_R1,8.120633,8.686505,4.640495,2.582046,5.777922,7.362842,2.319128,1.997375,6.764715,0.0,...,10.325177,8.500334,10.034049,10.936814,6.581498,10.65382,11.194331,11.876083,11.208997,BIC
BIC_DHT_R2,7.754471,8.472798,5.133134,3.983977,6.202302,7.608698,1.311694,1.987196,7.038314,2.445458,...,11.248817,8.577271,10.106393,10.845936,6.172963,11.024912,10.741136,11.626989,11.42637,BIC
CTL_DHT_R1,7.813781,8.588715,3.0,3.459432,5.72792,7.754888,3.169925,1.584963,6.882643,1.584963,...,11.385323,8.491853,10.139551,10.883407,6.303781,10.710806,11.661333,12.298349,11.181773,CTL
CTL_DHT_R2,8.044341,8.322901,3.331796,4.114662,5.838587,7.760444,2.687325,1.492482,6.399708,1.492482,...,12.002054,8.787039,10.093897,10.913631,6.010789,10.895888,11.003043,12.027975,11.402522,CTL
ENZ_DHT_R1,8.912943,8.886718,5.159085,3.717614,5.706739,6.716875,4.505039,2.16126,7.273167,1.452346,...,9.188847,8.086922,9.468898,10.780731,6.290488,10.053771,11.418783,11.582944,10.979313,ENZ
ENZ_DHT_R2,7.910619,8.616526,4.971521,3.614688,5.999977,7.256185,2.954176,2.129265,6.962873,0.0,...,10.486057,8.632972,10.033916,10.911275,6.56031,10.91296,10.958529,11.647604,11.442467,ENZ


In [19]:
import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold, cross_val_score

# ---------------------------------------------------------
# 1) Construir matriz de ML (muestras x genes) + Label multiclass
# ---------------------------------------------------------
X_df = counts_norm_log2.T  # ahora filas=muestras, columnas=genes

def infer_label(sample_name: str) -> str:
    s = str(sample_name)
    if "ARN" in s: return "ARN"
    if "BIC" in s: return "BIC"
    if "ENZ" in s: return "ENZ"
    return "CTL"

labels = pd.Series([infer_label(idx) for idx in X_df.index], index=X_df.index, name="Label")

# ---------------------------------------------------------
# 2) Función: entrenar modelo 1 vs resto y devolver:
#    - dataframe de importancias ordenado
#    - métrica de precisión/accuracy (CV)
# ---------------------------------------------------------
def train_one_vs_rest_rf(
    drug: str,
    X: pd.DataFrame,
    y_multiclass: pd.Series,
    n_estimators: int = 2000,
    random_state: int = 42
):
    # y binaria: 1 si es la droga, 0 si es el resto
    y_bin = (y_multiclass == drug).astype(int)

    # Validación cruzada estratificada (pocas muestras -> folds pequeños)
    skf = StratifiedKFold(n_splits=4, shuffle=True, random_state=random_state)

    rf = RandomForestClassifier(
        n_estimators=n_estimators,
        random_state=random_state,
        n_jobs=-1,
        class_weight="balanced_subsample"
    )

    # Accuracy como "precisión" global de clasificación
    cv_acc = cross_val_score(rf, X, y_bin, cv=skf, scoring="accuracy")
    acc_mean = float(np.mean(cv_acc))
    acc_std  = float(np.std(cv_acc))
    err_mean = 1.0 - acc_mean

    # Entrenar en todo el set para ranking de genes
    rf.fit(X, y_bin)

    imp_df = pd.DataFrame({
        "Gene_Name": X.columns,
        "Importance": rf.feature_importances_
    }).sort_values("Importance", ascending=False).reset_index(drop=True)

    metrics = {
        "drug": drug,
        "cv_accuracy_mean": acc_mean,
        "cv_accuracy_std": acc_std,
        "cv_error_mean": err_mean,
        "n_samples": int(X.shape[0]),
        "n_features": int(X.shape[1]),
        "positive_samples": int(y_bin.sum()),
        "negative_samples": int((1 - y_bin).sum())
    }

    return imp_df, metrics


# ---------------------------------------------------------
# 3) Entrenar 3 modelos: ARN vs resto, BIC vs resto, ENZ vs resto
# ---------------------------------------------------------
drugs_to_model = ["ARN", "BIC", "ENZ"]

importance_tables = {}
metrics_list = []

for d in drugs_to_model:
    imp_df, met = train_one_vs_rest_rf(d, X_df, labels)
    importance_tables[d] = imp_df
    metrics_list.append(met)

# ---------------------------------------------------------
# 4) Dataframes finales (los 3 que quieres)
# ---------------------------------------------------------
imp_ARN = importance_tables["ARN"]
imp_BIC = importance_tables["BIC"]
imp_ENZ = importance_tables["ENZ"]

# ---------------------------------------------------------
# 5) Resumen de precisión/error por modelo
# ---------------------------------------------------------
metrics_df = pd.DataFrame(metrics_list)
print(metrics_df[["drug", "cv_accuracy_mean", "cv_accuracy_std", "cv_error_mean",
                  "positive_samples", "negative_samples"]])

# ---------------------------------------------------------
print("\nTop 20 ARN vs resto:")
print(imp_ARN.head(20))

print("\nTop 20 BIC vs resto:")
print(imp_BIC.head(20))

print("\nTop 20 ENZ vs resto:")
print(imp_ENZ.head(20))



  drug  cv_accuracy_mean  cv_accuracy_std  cv_error_mean  positive_samples  \
0  ARN             0.625         0.216506          0.375                 2   
1  BIC             0.750         0.250000          0.250                 2   
2  ENZ             0.625         0.216506          0.375                 2   

   negative_samples  
0                 6  
1                 6  
2                 6  

Top 20 ARN vs resto:
       Gene_Name  Importance
0        WFDC10B    0.002186
1         HOXD13    0.001639
2        CHEK2P2    0.001639
3           RPRM    0.001639
4   LOC100507403    0.001639
5       MRNIP-DT    0.001639
6         GPR160    0.001639
7         EGFLAM    0.001639
8         PTGER3    0.001639
9         RWDD2B    0.001639
10          CDR2    0.001639
11        ZYG11A    0.001639
12  LOC105379364    0.001639
13          IL27    0.001639
14        PKNOX1    0.001639
15         HOXA9    0.001639
16          CER1    0.001639
17      SNORD119    0.001639
18      PRICKLE1    0.0010

In [20]:
imp_ARN.to_csv("results/RandomForestDE/imp_ARN.csv", index=False, header=True)
imp_BIC.to_csv("results/RandomForestDE/imp_BIC.csv", index=False, header=True)
imp_ENZ.to_csv("results/RandomForestDE/imp_ENZ.csv", index=False, header=True)