In [29]:
from pathlib import Path
import pandas as pd
import re

# =========================================================
# Paths
# =========================================================
base = Path("/home/manager/Documents/WCS_Project_Cancer")
flat = base / "data" / "flatfile"
gz   = base / "data" / "gzip"

counts_path = gz / "GSE211781_raw_counts_GRCh38.p13_NCBI.tsv.gz"
series_path = gz / "GSE211781_series_matrix.txt.gz"
annot_path  = gz / "Human.GRCh38.p13.annot.tsv.gz"

if not counts_path.exists():
    counts_path = flat / "GSE211781_raw_counts_GRCh38.p13_NCBI.tsv"
if not series_path.exists():
    series_path = flat / "GSE211781_series_matrix.txt"
if not annot_path.exists():
    annot_path = flat / "Human.GRCh38.p13.annot.tsv"

# =========================================================
# 1) Leer counts (genes x muestras)
# =========================================================
counts_raw = pd.read_csv(
    counts_path,
    sep="\t",
    header=0,
    index_col=0
)

# =========================================================
# 2) Leer anotaciones y agregar Gene_Name
#     -> dejamos índice como MultiIndex (Gene_ID, Gene_Name)
# =========================================================
gene_annotations = pd.read_csv(
    annot_path,
    sep="\t",
    header=0,
    comment="#",
    low_memory=False
)

def pick_col(df, candidates):
    for c in candidates:
        if c in df.columns:
            return c
    return None

gene_id_col = pick_col(gene_annotations, [
    "Gene_ID", "gene_id", "GeneID", "geneID", "Ensembl_ID", "ensembl_gene_id", "gene"
])
gene_name_col = pick_col(gene_annotations, [
    "Gene_Name", "gene_name", "GeneName", "gene_symbol", "Symbol", "HGNC_symbol", "gene_name_hgnc"
])

# Mapa mínimo robusto
if gene_id_col is not None and gene_name_col is not None:
    ann_map = (
        gene_annotations[[gene_id_col, gene_name_col]]
        .dropna(subset=[gene_id_col])
        .drop_duplicates()
        .rename(columns={gene_id_col: "Gene_ID", gene_name_col: "Gene_Name"})
        .set_index("Gene_ID")
    )

    counts_with_names = (
        counts_raw
        .reset_index()
        .rename(columns={counts_raw.index.name or "index": "Gene_ID"})
    )

    # Si el índice de counts no tenía nombre, pandas pone "index"
    if "Gene_ID" not in counts_with_names.columns:
        counts_with_names = counts_raw.reset_index().rename(columns={"index": "Gene_ID"})

    counts_with_names["Gene_Name"] = counts_with_names["Gene_ID"].map(ann_map["Gene_Name"])

    counts_with_names = counts_with_names.set_index(["Gene_ID", "Gene_Name"])
else:
    # Fallback: si no encontramos columnas esperadas en anotación
    counts_with_names = counts_raw.copy()
    counts_with_names.index.name = "Gene_ID"

# =========================================================
# 3) Parser ligero del series_matrix local
# =========================================================
def parse_geo_series_matrix(path: Path) -> pd.DataFrame:
    if path.suffix == ".gz":
        import gzip
        with gzip.open(path, "rt", encoding="utf-8", errors="replace") as f:
            lines = f.readlines()
    else:
        with open(path, "r", encoding="utf-8", errors="replace") as f:
            lines = f.readlines()

    sample_lines = [ln.rstrip("\n") for ln in lines if ln.startswith("!Sample_")]

    keys, values_matrix = [], []
    for ln in sample_lines:
        parts = ln.split("\t")
        key = parts[0].replace("!Sample_", "")
        vals = [v.strip().strip('"') for v in parts[1:]]
        keys.append(key)
        values_matrix.append(vals)

    if not values_matrix:
        return pd.DataFrame()

    n_samples = max(len(v) for v in values_matrix)

    norm = []
    for vals in values_matrix:
        if len(vals) < n_samples:
            vals = vals + [None] * (n_samples - len(vals))
        norm.append(vals)

    tmp = pd.DataFrame(norm, index=keys)

    if "geo_accession" in tmp.index:
        sample_ids = tmp.loc["geo_accession"].tolist()
    else:
        sample_ids = [f"Sample_{i+1}" for i in range(n_samples)]

    tmp.columns = sample_ids
    pData = tmp.T
    pData.index.name = "Sample_ID"
    return pData

pData_raw = parse_geo_series_matrix(series_path)

# =========================================================
# 4) Construir etiquetas robustas (buscando en TODAS las characteristics)
# =========================================================
def safe_series(df, col, default=""):
    return df[col] if col in df.columns else pd.Series([default]*len(df), index=df.index)

title = safe_series(pData_raw, "title")
description = safe_series(pData_raw, "description")
source_name = safe_series(pData_raw, "source_name_ch1")

# Combina todas las columnas de characteristics disponibles
char_cols = [c for c in pData_raw.columns if c.startswith("characteristics_ch1")]
if char_cols:
    chars_joined = (
        pData_raw[char_cols]
        .fillna("")
        .astype(str)
        .agg(" | ".join, axis=1)
    )
else:
    chars_joined = pd.Series([""]*len(pData_raw), index=pData_raw.index)

# Intento de extraer la porción de tratamiento si aparece
# (si no aparece explícitamente, usamos el texto completo)
treat_clean = chars_joined.str.replace(r".*treatment:\s*", "", regex=True).fillna(chars_joined)

def parse_drug(x: str) -> str:
    if re.search(r"Apalutamide|ARN", x, flags=re.I): return "Apalutamide"
    if re.search(r"Bicalutamide|BIC", x, flags=re.I): return "Bicalutamide"
    if re.search(r"Enzalutamide|ENZ", x, flags=re.I): return "Enzalutamide"
    if re.search(r"Vehicle|VEH|Control", x, flags=re.I): return "Control"
    return "Unknown"

def parse_dht(x: str) -> str:
    if re.search(r"DHT\+", x): return "Plus_DHT"
    if re.search(r"DHT-|No DHT|without DHT|VEH\+", x, flags=re.I): return "No_DHT"
    return "No_DHT"

def parse_rep(t: str, d: str) -> str:
    s = f"{t} {d}"
    if re.search(r"rep\s*1|rep1|biol\s*rep\s*1", s, flags=re.I): return "Rep1"
    if re.search(r"rep\s*2|rep2|biol\s*rep\s*2", s, flags=re.I): return "Rep2"
    if re.search(r"rep\s*3|rep3|biol\s*rep\s*3", s, flags=re.I): return "Rep3"
    return "Rep1"

sample_info = pd.DataFrame({
    "Sample_ID": pData_raw.index,
    "title": title.values,
    "description": description.values,
    "source_name_ch1": source_name.values,
    "Treatment_Clean": treat_clean.values,
})

sample_info["Drug"] = sample_info["Treatment_Clean"].map(parse_drug)
sample_info["DHT"] = sample_info["Treatment_Clean"].map(parse_dht)
sample_info["Replicate"] = [
    parse_rep(t, d) for t, d in zip(sample_info["title"], sample_info["description"])
]

sample_info["GSM"] = sample_info["Sample_ID"].str.extract(r"(GSM\d+)", expand=False)

# =========================================================
# 5) Alinear metadata con columnas de counts por GSM
# =========================================================
count_colnames = list(counts_raw.columns)

def extract_gsm(name):
    m = re.search(r"(GSM\d+)", str(name))
    return m.group(1) if m else None

counts_gsm = [extract_gsm(c) for c in count_colnames]

if all(g is not None for g in counts_gsm) and sample_info["GSM"].notna().all():
    sample_info = (
        sample_info.set_index("GSM")
        .loc[counts_gsm]
        .reset_index()
    )
else:
    # Si no hay GSM en columnas, intentamos alineación por tamaño
    if counts_raw.shape[1] == sample_info.shape[0]:
        counts_raw.columns = sample_info["Sample_ID"].tolist()

# =========================================================
# 6) Acortar nombres de columnas al contexto del problema
#     (CTL, BIC, ENZ, ARN) + (DHT+ / DHT-) + (R1/R2/...)
# =========================================================
drug_short_map = {
    "Control": "CTL",
    "Bicalutamide": "BIC",
    "Enzalutamide": "ENZ",
    "Apalutamide": "ARN",
    "Unknown": "UNK"
}
dht_short_map = {
    "Plus_DHT": "DHT+",
    "No_DHT": "DHT-"
}

sample_info["Drug_short"] = sample_info["Drug"].map(lambda x: drug_short_map.get(x, "UNK"))
sample_info["DHT_short"] = sample_info["DHT"].map(lambda x: dht_short_map.get(x, "DHT-"))
sample_info["Rep_short"] = sample_info["Replicate"].str.replace("Rep", "R", regex=False)

# etiqueta compacta final
sample_info["Short_Label"] = (
    sample_info["Drug_short"] + "_" +
    sample_info["DHT_short"] + "_" +
    sample_info["Rep_short"]
)

# =========================================================
# 7) Crear dataframe final:
#     - genes en filas con (Gene_ID, Gene_Name)
#     - muestras en columnas con nombres cortos
#     - metadata guardada aparte
# =========================================================
# Reaplicamos las columnas al objeto con nombres de genes
counts_final = counts_with_names.copy()

# Asegurar que tenemos una fila de metadata por columna
if counts_raw.shape[1] == sample_info.shape[0]:
    counts_final.columns = sample_info["Short_Label"].tolist()

# =========================================================
# 8) Guardar objetos útileGene_IDs
# =========================================================
counts_final.to_pickle(base / "data" / "counts_final_shortlabels.pkl")
sample_info.to_csv(base / "data" / "sample_metadata_parsed_shortlabels.csv", index=False)

print("Counts final shape:", counts_final.shape)
print("Ejemplo de columnas cortas:", list(counts_final.columns[:8]))
print("Ejemplo de índice (genes):", counts_final.index[:3])


Counts final shape: (39376, 16)
Ejemplo de columnas cortas: ['ARN_DHT+_R1', 'ARN_DHT+_R2', 'ARN_DHT-_R1', 'ARN_DHT-_R2', 'BIC_DHT+_R1', 'BIC_DHT+_R2', 'BIC_DHT-_R1', 'BIC_DHT-_R2']
Ejemplo de índice (genes): MultiIndex([(100287102,   'DDX11L1'),
            (   653635,    'WASH7P'),
            (102466751, 'MIR6859-1')],
           names=['Gene_ID', 'Gene_Name'])


In [30]:
counts_final.columns = (
    counts_final.columns
    # 1) Quitar el marcador de No_DHT
    .str.replace(r"_DHT-_", "_", regex=True)
    # 2) Quitar el signo + en los que sí tienen DHT
    .str.replace(r"_DHT\+_", "_DHT_", regex=True)
)

In [31]:
counts_final.columns

Index(['ARN_DHT_R1', 'ARN_DHT_R2', 'ARN_R1', 'ARN_R2', 'BIC_DHT_R1',
       'BIC_DHT_R2', 'BIC_R1', 'BIC_R2', 'CTL_DHT_R1', 'CTL_DHT_R2',
       'ENZ_DHT_R1', 'ENZ_DHT_R2', 'ENZ_R1', 'ENZ_R2', 'CTL_R1', 'CTL_R2'],
      dtype='object')

In [37]:
import numpy as np

# =========================================================
# Normalización estilo DESeq2 (median-of-ratios)
# =========================================================
def deseq2_size_factors(counts: pd.DataFrame) -> pd.Series:
    """
    Calcula size factors siguiendo la idea de DESeq2:
    1) Media geométrica por gen a través de muestras (ignorando ceros)
    2) Para cada muestra, mediana de (count / geo_mean) sobre genes válidos
    """
    X = counts.to_numpy(dtype=float)

    # Máscara de valores > 0
    positive = X > 0

    # Log solo donde es positivo
    logX = np.zeros_like(X, dtype=float)
    logX[positive] = np.log(X[positive])

    # Media geométrica por gen:
    # mean(log(count)) sobre muestras positivas
    n_pos = positive.sum(axis=1)
    with np.errstate(divide="ignore", invalid="ignore"):
        mean_log = logX.sum(axis=1) / n_pos
    geo_means = np.exp(mean_log)

    # Genes válidos: al menos 1 positivo y geo_mean finito y > 0
    valid_genes = (n_pos > 0) & np.isfinite(geo_means) & (geo_means > 0)

    # Ratios por muestra sobre genes válidos
    ratios = X[valid_genes, :] / geo_means[valid_genes][:, None]

    # Mediana por columna (muestra)
    size_factors = np.median(ratios, axis=0)

    return pd.Series(size_factors, index=counts.columns, name="size_factor")


# 1) Calcular size factors
size_factors = deseq2_size_factors(counts_final)
print(size_factors)

# 1) Dejar solo Gene_Name en el índice
if isinstance(counts_final.index, pd.MultiIndex):
    # Si los niveles tienen nombre explícito
    if "Gene_ID" in counts_final.index.names:
        counts_final = counts_final.droplevel("Gene_ID")
    else:
        # Fallback: asumimos que Gene_ID es el primer nivel
        counts_final = counts_final.droplevel(0)

    counts_final.index.name = "Gene_Name"

elif "Gene_ID" in counts_final.columns and "Gene_Name" in counts_final.columns:
    counts_final = counts_final.set_index("Gene_Name").drop(columns=["Gene_ID"])

else:
    # Si ya está en Gene_Name o no hay Gene_ID explícito, solo renombramos por claridad
    counts_final.index.name = counts_final.index.name or "Gene_Name"


# 2) Si hay Gene_Name duplicados, sumarlos
if counts_final.index.duplicated().any():
    counts_final = counts_final.groupby(counts_final.index).sum()
    counts_final.index.name = "Gene_Name"

# 2) Normalizar counts
counts_norm = counts_final.div(size_factors, axis=1)

print("Raw shape:", counts_final.shape)
print("Norm shape:", counts_norm.shape)

# 3) Log2 de counts normalizados
counts_norm_log2 = np.log2(counts_norm + 1)


ARN_DHT_R1    1.035209
ARN_DHT_R2    0.655103
ARN_R1        0.983639
ARN_R2        1.270946
BIC_DHT_R1    1.002442
BIC_DHT_R2    0.674639
BIC_R1        0.863534
BIC_R2        1.034649
CTL_DHT_R1    1.000000
CTL_DHT_R2    0.551358
ENZ_DHT_R1    1.151764
ENZ_DHT_R2    0.888918
ENZ_R1        0.978750
ENZ_R2        1.117474
CTL_R1        0.948608
CTL_R2        1.069803
Name: size_factor, dtype: float64
Raw shape: (39374, 16)
Norm shape: (39374, 16)


In [38]:
counts_norm_log2

Unnamed: 0_level_0,ARN_DHT_R1,ARN_DHT_R2,ARN_R1,ARN_R2,BIC_DHT_R1,BIC_DHT_R2,BIC_R1,BIC_R2,CTL_DHT_R1,CTL_DHT_R2,ENZ_DHT_R1,ENZ_DHT_R2,ENZ_R1,ENZ_R2,CTL_R1,CTL_R2
Gene_Name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1
A1BG,8.744608,7.910834,8.276387,7.898561,8.120616,7.754419,8.255323,8.208410,7.813781,8.044324,8.912893,7.910596,8.240338,8.348020,8.379663,8.021959
A1BG-AS1,8.866062,8.585510,8.734549,8.528375,8.686488,8.472747,8.734702,8.547178,8.588715,8.322884,8.886669,8.616502,8.982211,8.598374,8.860583,8.651082
A1CF,5.273275,4.907038,5.546848,5.122386,4.640479,5.133084,4.904523,4.707596,3.000000,3.331781,5.159037,4.971498,5.456540,4.370671,5.718489,5.227099
A2M,4.344812,3.723757,3.929115,3.586815,2.582031,3.983929,3.186853,2.957094,3.459432,4.114646,3.717569,3.614667,3.027151,3.850313,3.878166,4.078164
A2M-AS1,6.179573,6.154169,6.308899,6.287397,5.777906,6.202251,5.880235,6.218294,5.727920,5.838570,5.706690,5.999954,6.159833,6.163804,5.957502,6.053820
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
ZYG11A,6.833151,6.602410,6.131982,6.599787,6.581481,6.172912,6.115281,6.218294,6.303781,6.010773,6.290439,6.560287,6.007781,6.163804,5.907705,6.631525
ZYG11B,10.219245,11.092666,10.809238,11.143580,10.653802,11.024860,10.857211,11.050229,10.710806,10.895871,10.053722,10.912936,10.639223,10.904581,10.765196,11.117624
ZYX,11.478060,10.957445,11.422000,11.001869,11.194314,10.741084,11.211579,10.790877,11.661333,11.003026,11.418734,10.958506,11.422757,11.173539,11.519033,10.968118
ZZEF1,11.625168,11.721769,11.828327,11.706078,11.876066,11.626938,11.880057,11.642617,12.298349,12.027958,11.582895,11.647580,12.010048,11.674277,11.944533,11.713255


In [39]:
counts_norm_log2_transpose = counts_norm_log2.T
counts_norm_log2_transpose["Label"] = ["ARN" if "ARN" in lab_index else "BIC" if "BIC" in lab_index else "ENZ" if "ENZ" in lab_index else "CTL" for lab_index in counts_norm_log2_transpose.index]
counts_norm_log2_transpose

Gene_Name,A1BG,A1BG-AS1,A1CF,A2M,A2M-AS1,A2ML1,A2MP1,A3GALT2,A4GALT,A4GNT,...,ZWINT,ZXDA,ZXDB,ZXDC,ZYG11A,ZYG11B,ZYX,ZZEF1,ZZZ3,Label
ARN_DHT_R1,8.744608,8.866062,5.273275,4.344812,6.179573,6.86936,2.543481,2.764671,7.207797,0.975255,...,9.645367,8.084684,9.483476,10.744522,6.833151,10.219245,11.47806,11.625168,11.040213,ARN
ARN_DHT_R2,7.910834,8.58551,4.907038,3.723757,6.154169,7.249078,3.881498,2.018974,7.305853,0.0,...,10.469205,8.727356,10.08118,10.933998,6.60241,11.092666,10.957445,11.721769,11.538469,ARN
ARN_R1,8.276387,8.734549,5.546848,3.929115,6.308899,6.398554,3.020846,3.020846,7.111089,1.600872,...,8.811648,8.771937,10.046144,10.929174,6.131982,10.809238,11.422,11.828327,11.360297,ARN
ARN_R2,7.898561,8.528375,5.122386,3.586815,6.287397,6.8367,2.516241,2.052159,7.082633,2.516241,...,10.308976,8.822195,10.129092,10.936764,6.599787,11.14358,11.001869,11.706078,11.544085,ARN
BIC_DHT_R1,8.120616,8.686488,4.640479,2.582031,5.777906,7.362825,2.319114,1.997362,6.764698,0.0,...,10.325159,8.500317,10.034032,10.936797,6.581481,10.653802,11.194314,11.876066,11.20898,BIC
BIC_DHT_R2,7.754419,8.472747,5.133084,3.983929,6.202251,7.608647,1.311663,1.987157,7.038263,2.445416,...,11.248766,8.57722,10.106341,10.845885,6.172912,11.02486,10.741084,11.626938,11.426319,BIC
BIC_R1,8.255323,8.734702,4.904523,3.186853,5.880235,7.254577,2.161596,1.729472,7.164283,0.0,...,9.39869,8.641557,10.107801,11.012465,6.115281,10.857211,11.211579,11.880057,11.282375,BIC
BIC_R2,8.20841,8.547178,4.707596,2.957094,6.218294,7.474692,2.544129,1.963302,7.208573,0.975639,...,10.728143,8.845782,10.055503,10.977406,6.218294,11.050229,10.790877,11.642617,11.418482,BIC
CTL_DHT_R1,7.813781,8.588715,3.0,3.459432,5.72792,7.754888,3.169925,1.584963,6.882643,1.584963,...,11.385323,8.491853,10.139551,10.883407,6.303781,10.710806,11.661333,12.298349,11.181773,CTL
CTL_DHT_R2,8.044324,8.322884,3.331781,4.114646,5.83857,7.760427,2.687311,1.492471,6.399691,1.492471,...,12.002037,8.787022,10.09388,10.913614,6.010773,10.895871,11.003026,12.027958,11.402505,CTL


In [None]:
import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold, cross_val_score

# ---------------------------------------------------------
# 1) Construir matriz de ML (muestras x genes) + Label multiclass
# ---------------------------------------------------------
X_df = counts_norm_log2.T  # ahora filas=muestras, columnas=genes

def infer_label(sample_name: str) -> str:
    s = str(sample_name)
    if "ARN" in s: return "ARN"
    if "BIC" in s: return "BIC"
    if "ENZ" in s: return "ENZ"
    return "CTL"

labels = pd.Series([infer_label(idx) for idx in X_df.index], index=X_df.index, name="Label")

# ---------------------------------------------------------
# 2) Función: entrenar modelo 1 vs resto y devolver:
#    - dataframe de importancias ordenado
#    - métrica de precisión/accuracy (CV)
# ---------------------------------------------------------
def train_one_vs_rest_rf(
    drug: str,
    X: pd.DataFrame,
    y_multiclass: pd.Series,
    n_estimators: int = 2000,
    random_state: int = 42
):
    # y binaria: 1 si es la droga, 0 si es el resto
    y_bin = (y_multiclass == drug).astype(int)

    # Validación cruzada estratificada (pocas muestras -> folds pequeños)
    skf = StratifiedKFold(n_splits=4, shuffle=True, random_state=random_state)

    rf = RandomForestClassifier(
        n_estimators=n_estimators,
        random_state=random_state,
        n_jobs=-1,
        class_weight="balanced_subsample"
    )

    # Accuracy como "precisión" global de clasificación
    cv_acc = cross_val_score(rf, X, y_bin, cv=skf, scoring="accuracy")
    acc_mean = float(np.mean(cv_acc))
    acc_std  = float(np.std(cv_acc))
    err_mean = 1.0 - acc_mean

    # Entrenar en todo el set para ranking de genes
    rf.fit(X, y_bin)

    imp_df = pd.DataFrame({
        "Gene_Name": X.columns,
        "Importance": rf.feimp_ARNature_importances_
    }).sort_values("Importance", ascending=False).reset_index(drop=True)

    metrics = {
        "drug": drug,
        "cv_accuracy_mean": acc_mean,
        "cv_accuracy_std": acc_std,
        "cv_error_mean": err_mean,
        "n_samples": int(X.shape[0]),
        "n_features": int(X.shape[1]),
        "positive_samples": int(y_bin.sum()),
        "negative_samples": int((1 - y_bin).sum())
    }

    return imp_df, metrics


# ---------------------------------------------------------
# 3) Entrenar 3 modelos: ARN vs resto, BIC vs resto, ENZ vs resto
# ---------------------------------------------------------
drugs_to_model = ["ARN", "BIC", "ENZ"]

importance_tables = {}
metrics_list = []

for d in drugs_to_model:
    imp_df, met = train_one_vs_rest_rf(d, X_df, labels)
    importance_tables[d] = imp_df
    metrics_list.append(met)

# ---------------------------------------------------------
# 4) Dataframes finales (los 3 que quieres)
# ---------------------------------------------------------
imp_ARN = importance_tables["ARN"]
imp_BIC = importance_tables["BIC"]
imp_ENZ = importance_tables["ENZ"]

# ---------------------------------------------------------
# 5) Resumen de precisión/error por modelo
# ---------------------------------------------------------
metrics_df = pd.DataFrame(metrics_list)
print(metrics_df[["drug", "cv_accuracy_mean", "cv_accuracy_std", "cv_error_mean",
                  "positive_samples", "negative_samples"]])

# ---------------------------------------------------------
print("\nTop 20 ARN vs resto:")
print(imp_ARN.head(20))

print("\nTop 20 BIC vs resto:")
print(imp_BIC.head(20))

print("\nTop 20 ENZ vs resto:")
print(imp_ENZ.head(20))

  drug  cv_accuracy_mean  cv_accuracy_std  cv_error_mean  positive_samples  \
0  ARN            0.7500         0.000000         0.2500                 4   
1  BIC            0.7500         0.000000         0.2500                 4   
2  ENZ            0.6875         0.108253         0.3125                 4   

   negative_samples  
0                12  
1                12  
2                12  

Top 20 ARN vs resto:
       Gene_Name  Importance
0      LINC00578    0.004536
1          HOXA7    0.003528
2        PPP2R2A    0.003528
3           FDX1    0.003024
4         UBE2D3    0.003024
5             C3    0.003024
6          HSDL2    0.002868
7        MIR9983    0.002520
8            GCA    0.002520
9           OAZ2    0.002520
10     LINC02327    0.002520
11      SNORD149    0.002520
12       ARL6IP5    0.002520
13    ZNF516-AS1    0.002520
14     LINC02021    0.002443
15      C14orf28    0.002443
16  LOC105376154    0.002443
17        SLC51A    0.002436
18    PIK3CD-AS2    0.0023