<a href="https://colab.research.google.com/github/JosselinPerret/ENSxQRT-Data-Challenge/blob/main/ENSxQRT_Data_Challengeipynb.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<p align="center">
  <img src="https://upload.wikimedia.org/wikipedia/fr/8/86/Logo_CentraleSup%C3%A9lec.svg" alt="Logo 1" width="250"/>
  <img src="https://www.qube-rt.com/img/qrt.svg" alt="Logo 2" width="400" style="margin: 20px;"/>
</p>

# Data Challenge: Leukemia Risk Prediction


*GOAL OF THE CHALLENGE and WHY IT IS IMPORTANT:*

The goal of the challenge is to **predict disease risk for patients with blood cancer**, in the context of specific subtypes of adult myeloid leukemias.

The risk is measured through the **overall survival** of patients, i.e. the duration of survival from the diagnosis of the blood cancer to the time of death or last follow-up.

Estimating patient prognosis is critical for optimal clinical management.
For example, patients with low-risk disease will be offered supportive care to improve blood counts and quality of life, while patients with high-risk disease will be considered for hematopoietic stem cell transplantation.

The performance metric used in the challenge is the **IPCW-C-Index**.
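
To make the metric more concrete, here is a minimal sketch of a plain (unweighted) concordance index in pure Python. It is a simplification: the IPCW variant used in the challenge additionally reweights patient pairs by the inverse probability of censoring, which this sketch leaves out. The function name `harrell_c_index` and the toy data are illustrative assumptions, not part of the challenge material.

```python
def harrell_c_index(times, events, risk_scores):
    """Fraction of comparable patient pairs that the risk scores order correctly.

    A pair (i, j) is comparable when patient i died (event == 1) strictly
    before patient j's observed time. It is concordant when the patient who
    died first has the higher predicted risk. Tied risks count as half.
    """
    concordant = 0.0
    comparable = 0
    n = len(times)
    for i in range(n):
        if not events[i]:
            continue  # a censored patient cannot anchor a comparable pair
        for j in range(n):
            if times[i] < times[j]:
                comparable += 1
                if risk_scores[i] > risk_scores[j]:
                    concordant += 1.0
                elif risk_scores[i] == risk_scores[j]:
                    concordant += 0.5
    return concordant / comparable


# Toy data (hypothetical): survival in years, death indicator, predicted risk
times = [1.0, 2.5, 4.0, 6.0]
events = [1, 1, 0, 1]
risk = [0.9, 0.4, 0.6, 0.1]
print(harrell_c_index(times, events, risk))  # 4 of 5 comparable pairs are concordant -> 0.8
```

A value of 0.5 corresponds to random ordering and 1.0 to a perfect ranking of patients by risk.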

*THE DATASETS*

The **training set is made of 3,323 patients**.

The **test set is made of 1,193 patients**.

For each patient, you have access to CLINICAL data and MOLECULAR data.

The details of the data are as follows:

- OUTCOME:
  * OS_YEARS = Overall survival time in years
  * OS_STATUS = 1 (death), 0 (alive at the last follow-up)

- CLINICAL DATA, with one line per patient:
  
  * ID = unique identifier per patient
  * CENTER = clinical center
  * BM_BLAST = Bone marrow blasts in % (blasts are abnormal blood cells)
  * WBC = White Blood Cell count in Giga/L
  * ANC = Absolute Neutrophil count in Giga/L
  * MONOCYTES = Monocyte count in Giga/L
  * HB = Hemoglobin in g/dL
  * PLT = Platelet count in Giga/L
  * CYTOGENETICS = A description of the karyotype observed in the blood cells of the patient, as reported by a cytogeneticist. Cytogenetics is the science of chromosomes. The karyotype is obtained from the tumoral blood cells. The notation follows the ISCN convention (https://en.wikipedia.org/wiki/International_System_for_Human_Cytogenomic_Nomenclature); see also https://en.wikipedia.org/wiki/Cytogenetic_notation. A karyotype can be normal or abnormal. The notation 46,XX denotes a normal karyotype in females (23 pairs of chromosomes, including 2 X chromosomes) and 46,XY a normal karyotype in males (23 pairs of chromosomes, including 1 X chromosome and 1 Y chromosome). A common abnormality in cancerous blood cells is, for example, the loss of chromosome 7 (monosomy 7, or -7), which is typically associated with higher-risk disease.

- GENE MOLECULAR DATA, with one line per patient per somatic mutation. Mutations are detected from the sequencing of the blood tumoral cells.
We call somatic (= acquired) mutations the mutations that are found in the tumoral cells but not in other cells of the body.

  * ID = unique identifier per patient
  * CHR START END = position of the mutation on the human genome
  * REF ALT = reference and alternate (=mutant) nucleotide
  * GENE = the affected gene
  * PROTEIN_CHANGE = the consequence of the mutation on the protein that is expressed by a given gene
  * EFFECT = a broad categorization of the mutation consequences on a given gene.
  * VAF = Variant Allele Fraction = the **proportion** of sequencing reads carrying the mutant allele, used as a proxy for the proportion of cells carrying the deleterious mutation.
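
Because the molecular table has one line per patient per mutation, it has to be collapsed to one record per patient before it can be joined to the clinical table. The sketch below shows the idea in pure Python on hypothetical rows; the patient IDs, VAF values and the helper name `max_vaf_per_gene` are made up for illustration.

```python
from collections import defaultdict

# Hypothetical rows in the molecular table layout: (ID, GENE, VAF)
rows = [
    ("P1", "TP53", 0.45),
    ("P1", "TP53", 0.10),   # a second mutation in the same gene
    ("P1", "ASXL1", 0.30),
    ("P2", "SF3B1", 0.38),
]


def max_vaf_per_gene(rows):
    """Collapse one-row-per-mutation data into {patient: {gene: max VAF}}."""
    out = defaultdict(dict)
    for patient_id, gene, vaf in rows:
        previous = out[patient_id].get(gene, 0.0)
        out[patient_id][gene] = max(previous, vaf)
    return dict(out)


summary = max_vaf_per_gene(rows)
print(summary)
```

A patient with no row in the molecular table simply does not appear in the result, so a later join with the clinical data has to decide how to encode "no mutation detected".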

# Dependencies

In [None]:
pip install scikit-survival


In [None]:
pip install xgboost



In [None]:
# Standard library
import re

# Third-party libraries
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xgboost as xgb
from sentence_transformers import SentenceTransformer
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.tree import plot_tree
from sksurv.ensemble import RandomSurvivalForest
from sksurv.linear_model import CoxPHSurvivalAnalysis
from sksurv.metrics import concordance_index_censored, concordance_index_ipcw
from sksurv.util import Surv

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Data processing & feature engineering

In [None]:
# Clinical Data
clin_tr = pd.read_csv("/content/drive/My Drive/Colab Notebooks/qrt data challenge/X_train/clinical_train.csv")
clin_eval = pd.read_csv("/content/drive/My Drive/Colab Notebooks/qrt data challenge/X_test/clinical_test.csv")

# Molecular Data
mol_tr = pd.read_csv("/content/drive/My Drive/Colab Notebooks/qrt data challenge/X_train/molecular_train.csv")
mol_eval = pd.read_csv("/content/drive/My Drive/Colab Notebooks/qrt data challenge/X_test/molecular_test.csv")

y_tr = pd.read_csv("/content/drive/My Drive/Colab Notebooks/qrt data challenge/target_train.csv")
# y_eval = pd.read_csv("/content/drive/My Drive/target_test.csv")

In [None]:
numerical_cols = ['BM_BLAST', 'WBC', 'ANC', 'MONOCYTES', 'HB', 'PLT']

for col in numerical_cols:
    median_val = clin_tr[col].median()
    clin_tr[col] = clin_tr[col].fillna(median_val)
    clin_eval[col] = clin_eval[col].fillna(median_val)
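
The loop above fills missing values in both splits with the median computed on the training set, so no statistic is learned from the test data. A small pure-Python sketch of the same idea follows; the helper names `fit_median` and `impute` and the hemoglobin values are hypothetical.

```python
from statistics import median


def fit_median(values):
    """Median of the observed (non-missing) training values."""
    return median(v for v in values if v is not None)


def impute(values, fill_value):
    """Replace missing entries (None) with a precomputed fill value."""
    return [fill_value if v is None else v for v in values]


# Hypothetical hemoglobin values in g/dL; None marks a missing measurement
hb_train = [9.1, None, 11.4, 8.0, 10.2]
hb_test = [None, 12.5]

hb_median = fit_median(hb_train)               # learned on the training split only
hb_train_filled = impute(hb_train, hb_median)
hb_test_filled = impute(hb_test, hb_median)    # the test split reuses the training median
print(hb_median, hb_test_filled)
```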

In [None]:
def normalize_iscn(s: str) -> str:
    if pd.isna(s): return ""
    return re.sub(r"\s+", "", s.upper())

# convenience helpers
def has(pattern, s):
    return bool(re.search(pattern, s))

def any_has(patterns, s):
    return any(has(p, s) for p in patterns)

# core extractors
RE_COUNT = re.compile(r"^(\d+),")
RE_SEX   = re.compile(r"^(\d+),(XX|XY)")

ADVERSE = {
    "minus5_or_del5q":  r"(^|[\W])(-5|DEL\(5Q\))([\W]|$)",
    "minus7_or_del7q":  r"(^|[\W])(-7|DEL\(7Q\))([\W]|$)",
    "inv3_or_t3_3":     r"(INV\(3\(Q21\.3?Q26\.2?\)\)|T\(3;3\)\(Q21\.3?;Q26\.2?\))|INV\(3\)|T\(3;3\)",
    "t_6_9":            r"T\(6;9\)",
    "t_v_11q23":        r"T\([0-9XYV];11\)|T\(11;[0-9XYV]\)|11Q23|KMT2A|MLL",  # KMT2A(MLL)
    "t_9_22":           r"T\(9;22\)",
    "abn17p_or_minus17":r"ABN\(17P\)|DEL\(17P\)|-17|TP53|17P",  # abn17p often implies TP53 region loss
    "t_8_16":           r"T\(8;16\)",
    "t_3q26_2_v":       r"T\(3Q26\.2;[0-9XYV]\)|MECOM|EVI1",
}

FAVORABLE = {
    "t_8_21":           r"T\(8;21\)",
    "inv16_or_t16_16":  r"INV\(16\)|T\(16;16\)",
    "t_15_17":          r"T\(15;17\)",  # APL
}

INTERMEDIATE = {
    "t_9_11":           r"T\(9;11\)",
    "plus8":            r"(\+8)([\W]|$)",
}

def count_monosomies(s):
    # count autosomal monosomies: -1..-22
    return len(re.findall(r"-(?:[1-9]|1[0-9]|2[0-2])([\W]|$)", s))

def count_trisomies(s):
    return len(re.findall(r"\+(?:[1-9]|1[0-9]|2[0-2])([\W]|$)", s))

def count_structural(s):
    return len(re.findall(r"(DEL|DUP|DER|ADD|INS|INV|T)\(", s))

def is_monosomal_karyotype(s):
    autosomal_mono = count_monosomies(s)
    has_struct = count_structural(s) > 0
    return autosomal_mono >= 2 or (autosomal_mono >= 1 and has_struct)

def is_complex_karyotype(s):
    # VERY simple proxy: total abnormalities = trisomies + monosomies + structural
    n = count_trisomies(s) + count_monosomies(s) + count_structural(s)
    return n >= 3

def parse_row(cell):
    x = normalize_iscn(cell)
    d = {}

    # core
    m = RE_COUNT.search(x)
    d["chr_count"] = int(m.group(1)) if m else None
    m = RE_SEX.search(x)
    d["sex_karyotype"] = m.group(2) if m else None
    d["is_normal"] = int(x in ("46,XX", "46,XY") or bool(re.fullmatch(r"46,(XX|XY)\[\d+\]", x)))

    # ploidy
    if d["chr_count"] is None:
        d["hypodiploid"]=d["diploid"]=d["hyperdiploid"]=None
    else:
        d["hypodiploid"] = int(d["chr_count"] < 46)
        d["diploid"]     = int(d["chr_count"] == 46)
        d["hyperdiploid"]= int(d["chr_count"] > 46)

    # counts
    d["n_monosomies"] = count_monosomies(x)
    d["n_trisomies"]  = count_trisomies(x)
    d["n_structural"] = count_structural(x)

    # karyotype classes
    d["complex_karyotype"]  = int(is_complex_karyotype(x))
    d["monosomal_karyotype"]= int(is_monosomal_karyotype(x))

    # recurrent lesions
    for name, pat in {**ADVERSE, **FAVORABLE, **INTERMEDIATE}.items():
        d[name] = int(has(pat, x))

    # generic flags ("plus8" is already set by the loop over INTERMEDIATE above)
    d["failed_or_uninformative"] = int(has(r"FAILED|INSUFFICIENT|NOMETA|UNINFORM", x))
    # number of clones: ISCN separates clones with "/"
    d["clonal_count"] = x.count("/") + 1

    return d
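
To see what the counting helpers produce, the following self-contained sketch reimplements three of them in simplified form and applies them to a few karyotype strings. The example strings are hypothetical and the regular expressions are copied from the cell above with non-capturing groups; like the originals, they are a rough proxy and not a full ISCN parser.

```python
import re


def normalize(karyotype):
    """Upper-case the string and remove all whitespace."""
    return re.sub(r"\s+", "", karyotype.upper())


def count_monosomies(s):
    # whole-chromosome losses of an autosome, e.g. -5 or -17
    return len(re.findall(r"-(?:[1-9]|1[0-9]|2[0-2])(?:\W|$)", s))


def count_structural(s):
    # deletions, duplications, derivatives, additions, insertions, inversions, translocations
    return len(re.findall(r"(?:DEL|DUP|DER|ADD|INS|INV|T)\(", s))


def is_monosomal(s):
    # two or more monosomies, or one monosomy plus a structural abnormality
    mono = count_monosomies(s)
    return mono >= 2 or (mono >= 1 and count_structural(s) > 0)


samples = [
    "46,XX[20]",
    "45,XY,-7[18]/46,XY[2]",
    "44,XX,-5,-7,del(12)(p13)",
]
summary = {}
for raw in samples:
    s = normalize(raw)
    summary[raw] = (count_monosomies(s), count_structural(s), is_monosomal(s))
    print(raw, summary[raw])
```

Only the third string is flagged as a monosomal karyotype: it has two autosomal monosomies in addition to one structural abnormality.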

In [None]:
features_tr = clin_tr["CYTOGENETICS"].apply(parse_row).apply(pd.Series)
features_eval = clin_eval["CYTOGENETICS"].apply(parse_row).apply(pd.Series)

X_tr = pd.concat([clin_tr.drop(columns=["CYTOGENETICS"]), features_tr], axis=1)
X_eval = pd.concat([clin_eval.drop(columns=["CYTOGENETICS"]), features_eval], axis=1)

In [None]:
def build_genomic_features(gene_df: pd.DataFrame,
                           min_patients: int = 10,
                           min_pct: float = 0.02) -> pd.DataFrame:
    df = gene_df.copy()

    df["ID"] = df["ID"].astype(str)
    df["GENE"] = df["GENE"].str.upper().str.strip()
    df["EFFECT"] = df["EFFECT"].str.lower().str.replace(" ", "_", regex=False)
    df["VAF"] = pd.to_numeric(df["VAF"], errors="coerce")

    agg_gene = (df.groupby(["ID","GENE"])
                  .agg(
                      has_variant=("GENE","size"),
                      max_vaf=("VAF","max"),
                      mean_vaf=("VAF","mean"), # Add mean VAF calculation
                      any_trunc=("EFFECT", lambda s: int(s.str.contains(r"frameshift|stop_gained|stop_lost|splice").any())),
                      any_missense=("EFFECT", lambda s: int(s.str.contains(r"missense|non_synonymous").any())),
                  )
                  .reset_index())
    agg_gene["has_variant"] = (agg_gene["has_variant"] > 0).astype(int)

    mat_presence = (agg_gene.pivot(index="ID", columns="GENE", values="has_variant")
                           .fillna(0).astype(int))
    mat_presence.columns = [f"mut_{g}" for g in mat_presence.columns]

    mat_vaf_max = agg_gene.pivot(index="ID", columns="GENE", values="max_vaf")
    mat_vaf_max.columns = [f"vaf_{g}_max" for g in mat_vaf_max.columns]

    mat_vaf_mean = agg_gene.pivot(index="ID", columns="GENE", values="mean_vaf") # Pivot for mean VAF
    mat_vaf_mean.columns = [f"vaf_{g}_mean" for g in mat_vaf_mean.columns] # Rename columns for mean VAF


    per_patient = (df.groupby("ID")
                     .agg(
                         n_mut_total=("GENE","size"),
                         n_genes_mutated=("GENE", pd.Series.nunique),
                         median_vaf=("VAF","median"),
                         max_vaf=("VAF","max"),
                         n_truncating=("EFFECT", lambda s: int(s.str.contains(r"frameshift|stop_gained|stop_lost|splice").sum())),
                         n_missense=("EFFECT", lambda s: int(s.str.contains(r"missense|non_synonymous").sum())),
                     ))

    freq = mat_presence.sum().sort_values(ascending=False)
    thresh = max(min_patients, int(np.ceil(min_pct * mat_presence.shape[0])))
    keep_mut_cols = freq[freq >= thresh].index.tolist()

    keep_genes = [c.replace("mut_", "") for c in keep_mut_cols]
    keep_vaf_max_cols = [f"vaf_{g}_max" for g in keep_genes if f"vaf_{g}_max" in mat_vaf_max.columns]
    keep_vaf_mean_cols = [f"vaf_{g}_mean" for g in keep_genes if f"vaf_{g}_mean" in mat_vaf_mean.columns] # Columns to keep for mean VAF

    genomic = per_patient.join([mat_presence[keep_mut_cols], mat_vaf_max[keep_vaf_max_cols], mat_vaf_mean[keep_vaf_mean_cols]], how="left") # Join with mean VAF columns


    for c in keep_mut_cols:
        genomic[c] = genomic[c].fillna(0).astype(int)

    return genomic

In [None]:
mol_tr_processed = build_genomic_features(mol_tr)
mol_eval_processed = build_genomic_features(mol_eval)
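
Inside `build_genomic_features`, a gene is kept only if it is mutated in at least `max(min_patients, ceil(min_pct * n))` patients, where `n` is the number of patients present in the molecular table passed in. Since the function is called separately on the training and test tables, the two calls can use different thresholds and return different gene columns. The arithmetic can be checked with a short sketch; the patient counts below are hypothetical and `gene_threshold` is an illustrative helper, not a function from the notebook.

```python
import math


def gene_threshold(n_patients, min_patients=10, min_pct=0.02):
    """Minimum number of mutated patients required to keep a gene column."""
    return max(min_patients, math.ceil(min_pct * n_patients))


print(gene_threshold(3000))  # 2% of 3000 is 60, above the floor of 10 -> 60
print(gene_threshold(1000))  # 2% of 1000 is 20 -> 20
print(gene_threshold(200))   # 2% of 200 is 4, so the floor of 10 applies -> 10
```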

In [None]:
X_tr = X_tr.merge(mol_tr_processed,   how="left", left_on="ID", right_index=True, suffixes=("", "_mol"))
X_eval = X_eval.merge(mol_eval_processed, how="left", left_on="ID", right_index=True, suffixes=("", "_mol"))

In [None]:
X_tr["is_male"]   = (X_tr["sex_karyotype"] == "XY").astype(int)
X_eval["is_male"] = (X_eval["sex_karyotype"] == "XY").astype(int)
X_tr = X_tr.drop(columns=["sex_karyotype", "CENTER"])
X_eval = X_eval.drop(columns=["sex_karyotype", "CENTER"])

In [None]:
def align_vaf_with_flags(df):
    vaf_cols = [c for c in df.columns if c.startswith("vaf_") and c.endswith("_max")]
    for v in vaf_cols:
        gene = v.replace("vaf_", "").replace("_max", "")
        m = f"mut_{gene}"
        if m in df.columns:
            df.loc[df[m] == 0, v] = 0.0
            med = df.loc[df[m] == 1, v].median()
            df[v] = df[v].fillna(med if pd.notna(med) else 0.0)
        else:
            df[v] = df[v].fillna(0.0)
    return df

In [None]:
X_tr  = align_vaf_with_flags(X_tr)
X_eval = align_vaf_with_flags(X_eval)

In [None]:
# Give the test matrix exactly the training columns, in the same order.
# Columns absent from the test set are created and filled with 0;
# columns that exist only in the test set are dropped, so X_tr is left unchanged.
train_cols = X_tr.columns.tolist()
X_eval = X_eval.reindex(columns=train_cols, fill_value=0)
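
The alignment step above matters because a fitted model expects the test matrix to carry the same feature columns as the training matrix. As a rough pure-Python illustration (the helper `align_to_training` and the toy row are hypothetical, not part of the notebook), a test row can be rebuilt from the training column list, with absent features set to 0 and test-only features dropped:

```python
def align_to_training(row, train_columns, fill_value=0):
    """Return the row restricted to train_columns, filling absent features."""
    return {col: row.get(col, fill_value) for col in train_columns}


train_columns = ["HB", "mut_TP53", "mut_ASXL1"]
# mut_RUNX1 stands for a gene column that was never produced for the training set
test_row = {"HB": 9.8, "mut_TP53": 1, "mut_RUNX1": 1}

aligned = align_to_training(test_row, train_columns)
print(aligned)
```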

# Model training

In [None]:
random_state = 42

In [None]:
train_merged = pd.merge(X_tr, y_tr, on='ID', how='left')
train_merged_cleaned = train_merged.dropna(subset=['OS_STATUS', 'OS_YEARS'])

X = train_merged_cleaned.drop(['ID', 'OS_YEARS', 'OS_STATUS'], axis=1)
y = Surv.from_dataframe("OS_STATUS", "OS_YEARS", train_merged_cleaned)

In [None]:
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=random_state)

In [None]:
rsf = RandomSurvivalForest(
    n_estimators=200, min_samples_split=2, min_samples_leaf=15, n_jobs=-1, random_state=random_state
)

In [None]:
rsf.fit(X_train, y_train)

# Model evaluation

In [None]:
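# Truncate the metric at the latest death time observed in the validation set
eval_times = np.unique(y_val["OS_YEARS"][y_val["OS_STATUS"] == 1])
# concordance_index_ipcw returns (cindex, concordant, discordant, tied_risk, tied_time)
c_index, concordant_pairs, discordant_pairs, tied_risk_scores, tied_event_times = concordance_index_ipcw(
    y_train, y_val, rsf.predict(X_val), tau=eval_times[-1] if len(eval_times) > 0 else None
)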

print(f"IPCW-C-Index on the validation set: {c_index}")

IPCW-C-Index on the validation set: 0.6914321873252035
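
The "IPCW" part of the metric stands for inverse probability of censoring weights. As far as I understand the estimator, the censoring distribution G(t) is estimated with a Kaplan-Meier curve on the training labels (which is why `y_train` is passed to the metric), and each patient who died at time t is weighted by 1 / G(t)^2. The sketch below illustrates this on toy data in pure Python. It assumes no tied times and is not the scikit-survival implementation; the helper names `censoring_km` and `g_at` are hypothetical.

```python
def censoring_km(times, events):
    """Kaplan-Meier estimate of G(t), the probability of being uncensored beyond t.

    Roles are reversed compared with the usual survival curve: a censored
    patient (event == 0) counts as an "event" of the censoring process.
    Returns the step function as a list of (time, G) pairs.
    """
    steps, g = [], 1.0
    for t in sorted(set(times)):
        at_risk = sum(1 for u in times if u >= t)
        n_censored = sum(1 for u, e in zip(times, events) if u == t and not e)
        if n_censored:
            g *= 1.0 - n_censored / at_risk
            steps.append((t, g))
    return steps


def g_at(steps, t):
    """Value of the step function G at time t (1.0 before the first step)."""
    value = 1.0
    for step_time, g in steps:
        if step_time <= t:
            value = g
    return value


# Toy training labels (hypothetical, no tied times): years and death indicator
times = [1.0, 2.0, 3.0, 4.0, 5.0]
events = [1, 0, 1, 0, 1]

steps = censoring_km(times, events)
# One weight per observed death; censored patients get no weight of their own
weights = {t: 1.0 / g_at(steps, t) ** 2 for t, e in zip(times, events) if e}
print(steps)
print(weights)
```

Deaths observed late, when many patients have already been censored, receive larger weights, which compensates for the comparable pairs that censoring removed.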


# Submission file

In [None]:
# Select the test rows in their original order and drop the identifier column
X_eval_aligned = X_eval.loc[clin_eval.index].drop(columns=["ID"])

predictions_updated = rsf.predict(X_eval_aligned)
submission_df_updated = pd.DataFrame({'ID': clin_eval['ID'], 'risk_score': predictions_updated})
submission_df_updated.to_csv("submission_10.csv", index=False)