# SigBERT *(P. Minchella et al., 2025)*

This notebook is an application companion to the paper:

**SigBERT: Combining Narrative Medical Reports and Rough Path Signature Theory for Survival Prediction in Oncology**

---

## Summary of the Method

The pipeline performs time-to-event prediction using longitudinal narrative data, following these steps:

1. **Sentence Embeddings**: Each medical report is transformed into a high-dimensional vector, typically using a language model such as OncoBERT (a RoBERTa-based architecture).
2. **Dimensionality Reduction**: A linear compression (e.g., Johnson-Lindenstrauss mapping or PCA) reduces the embedding size for computational efficiency.
3. **Signature Extraction**: Path signature theory (up to order 2 or 3) is applied to capture the time dynamics of these compressed embeddings.
4. **Survival Modeling**: A LASSO-regularized Cox model is trained on these signature features to estimate risk scores and survival times.
5. **Evaluation**: C-index, time-dependent AUC, Brier Score, and Integrated Brier Score (IBS) are used for evaluation, with extensive validation over varying observation depths.

---

## Requirements for Using This Notebook with Your Data

Ensure your column names follow the naming convention described in the README.

**Note**: If your dataset only provides the event duration (in days) without an explicit `date_start`, you can create `date_start` by subtracting the duration from a fixed reference date such as `"1970-01-01"` or `"2000-01-01"`. Then compute `date_end` by adding the duration back to `date_start` (so `date_end` equals the reference date).

Once your data is properly preprocessed, you're ready to run the notebook.

In [1]:
import types
import sys
from numbers import Real, Integral

# Compatibility shim for `sklearn.utils._param_validation`.
#
# skglm imports `Interval` and `StrOptions` from this private scikit-learn module,
# which does not exist in older scikit-learn releases. On such installs we register
# a minimal stand-in module under that name so the skglm import succeeds.
#
# The stand-in is installed ONLY when the real module cannot be imported: it provides
# just these two classes, so shadowing a real module would break scikit-learn itself,
# whose own code imports other names from it.
param_validation = types.ModuleType("sklearn.utils._param_validation")


class Interval:
    """Minimal stand-in for scikit-learn's ``Interval`` constraint.

    Only stores its arguments; no validation is performed.
    """

    def __init__(self, dtype, left, right, closed="neither"):
        self.dtype = dtype
        self.left = left
        self.right = right
        self.closed = closed


class StrOptions:
    """Minimal stand-in for scikit-learn's ``StrOptions`` constraint.

    Stores the allowed strings as a set; no validation is performed.
    """

    def __init__(self, options):
        self.options = set(options)


# Expose the stand-in classes on the fake module.
param_validation.Interval = Interval
param_validation.StrOptions = StrOptions

try:
    # Real module available (recent scikit-learn): keep it, do not shadow it.
    import sklearn.utils._param_validation  # noqa: F401
except ImportError:
    # Must run before skglm is imported so its import resolves to the stand-in.
    sys.modules["sklearn.utils._param_validation"] = param_validation

In [2]:
import pandas as pd
import torch
import numpy as np
import time
import warnings
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

In [3]:
import os

# Add the project's src/sigbert directory to the Python path so its modules can be imported.
# `__file__` is not defined inside a notebook, so the notebook directory is taken to be the
# kernel's current working directory (typically the notebook's own folder under Jupyter).
notebook_dir = os.getcwd()
src_path = os.path.abspath(os.path.join(notebook_dir, '..', 'src', 'sigbert'))
# Prepend (rather than append) so these local modules take precedence; skip if already present
# to keep re-runs of this cell from duplicating the entry.
if src_path not in sys.path:
    sys.path.insert(0, src_path)

# Now import our custom modules
from _utils import *
from descriptive_stats_pkg import *
from compression_pkg import *
from survival_analysis_pkg import *
from metrics_plot_results_pkg import *

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Wall-clock start time; the final cell subtracts it to report the notebook's total run time.
start_notebook = time.time()

# I) Data Import

In [5]:
# Load the raw cohort; its columns must follow the naming convention described in the README.
# NOTE(review): the path is relative to the notebook's working directory. The saved output of
# this cell is a FileNotFoundError, and every later cell depends on df_OG.
# nrows=None presumably loads all rows — confirm in global_data_import.
df_OG = global_data_import(path_import = "../data/data_real.csv", nrows=None)

FileNotFoundError: [Errno 2] No such file or directory: '../data/data_real.csv'

In [None]:
# Parse the date columns, then move every death date DEATH_DATE_SHIFT_DAYS days earlier.
# NOTE(review): the reason for this shift is not documented in this notebook — confirm it is
# intended before reusing the pipeline on other data.
# NOTE(review): df_OG is modified by this cell, so re-running it without first re-running the
# data-import cell shifts the death dates a second time.
DEATH_DATE_SHIFT_DAYS = 100
df_OG = convert_date_columns(df_OG)
df_OG['date_death'] = df_OG['date_death'] - pd.to_timedelta(DEATH_DATE_SHIFT_DAYS, unit='D')

In [None]:
# Preview the first three rows of the cohort.
# NOTE(review): rows may contain report text or other patient data — clear this output before sharing.
df_OG.iloc[:3]

In [None]:
# Print descriptive statistics for the imported cohort (helper from the project's descriptive_stats_pkg).
print_dataset_statistics(df_OG)

In [None]:
# Plot the distribution of report counts per patient and save the figure under ../results/.
plot_report_distribution_per_patient(df_OG, export_path='../results/reports_per_patients.png')

# II) Training

In [None]:
# Smallest (date_end - date_start) gap over all rows of the cohort, in whole days.
# NOTE(review): Ndays is only printed; no later cell in this notebook reads it.
Ndays = int(
    (df_OG['date_end'] - df_OG['date_start']).min().days
)
print("Ndays = {}".format(Ndays))

# Cap on the number of known reports, passed as the first argument of global_sigbert_process.
max_reports = 221

### Train-Test Split

In [None]:
# Split the cohort into a training frame and a collection of test groups.
# NOTE(review): no random seed is set anywhere in this notebook; if make_train_test samples
# randomly, the split (and every downstream metric) may differ between runs.
df_train_new_OG, test_groups = make_train_test(df_OG)

In [None]:
# Compute the linear compression from the training frame only.
# k_comp is the target compressed dimension (presumably the number of PCA components — confirm).
k_comp = 25
# Only the second output is kept (R_comp, presumably the projection matrix — confirm);
# it is passed to global_sigbert_process below.
_, R_comp = pca_compression(df_train_new_OG, k_comp, verbose = True)

In [None]:
# Instance of results lists
# Fresh result containers and the L1 penalty used by the SigBERT run below.
# NOTE(review): c_index_test_results is never read later in this notebook, and
# df_survival_test_list is reassigned by the global_sigbert_process unpacking.
c_index_test_results, df_survival_test_list = [], []
# L1 penalty strength passed to global_sigbert_process (presumably the LASSO-Cox lambda — confirm).
lambda_l1_CV = 0.7

In [None]:
# Independent (deep) copy of the full cohort so the pipeline cannot alter df_OG through it.
df_all = df_OG.copy(deep=True)

In [None]:
# Run the full SigBERT pipeline for the current max_reports setting and unpack its 13 outputs.
# Inputs: the full cohort copy, the training frame and test groups from the split, the
# compression R_comp from the PCA cell, and the L1 penalty lambda_l1_CV.
(
    df_results,               # Summary DataFrame with metrics for the current max_reports setup
    cph,                      # Trained Cox proportional hazards model
    df_survival,              # Survival data (event, time, risk_score) for the training set
    w_sk,                     # Risk scores for training patients
    scores,                   # Signature feature importance scores
    X,                        # Design matrix (features) used to train the model
    y_train,                  # Target array for survival analysis (event, time) before preprocess
    y_cox,                    # Target array for survival analysis (event, time)
    c_index_train,            # C-index on the training set
    c_index_test_list,        # List of C-index values on each test group
    c_index_test_mean,        # Mean C-index across test groups
    c_index_test_std,         # Standard deviation of C-index across test groups
    df_survival_test_list     # List of survival DataFrames for each test group
) = global_sigbert_process(
    max_reports,
    df_all,
    df_train_new_OG,
    test_groups,
    R_comp,
    lambda_l1_CV
)

In [None]:
# Display the summary metrics table returned by global_sigbert_process.
df_results

# III) Results: Plots and Metrics

In [None]:
# Headline discrimination: training C-index, then mean (sd) over the test groups.
print(
    f"Training c-index: {c_index_train:.3f}.\n"
    f"Validation c-index: {c_index_test_mean:.3f} (sd {c_index_test_std:.4f})."
)

In [None]:
lower_bound, upper_bound = jackknife_confidence_interval(c_index_test_list)

# Report the jackknife interval computed from the per-group test C-indices.
# NOTE(review): the "95%" in the label is not set here — confirm it matches the
# confidence level used by jackknife_confidence_interval.
print("Jackknife Confidence Interval (95%): [{:.4f}, {:.4f}]".format(lower_bound, upper_bound))

In [None]:
# Pool the per-group test survival frames into a single frame for cohort-level plots.
# Exactly one frame per test group is expected; fail loudly instead of silently dropping
# extra frames (or hitting an IndexError) when the two collections disagree.
assert len(df_survival_test_list) == len(test_groups), (
    "expected one survival DataFrame per test group"
)
# pd.concat keeps each frame's original index, so index labels may repeat across groups.
df_survival_test_overall = pd.concat(df_survival_test_list, axis=0)

In [None]:
# Stack training and pooled test survival rows (row-wise; original indices are kept).
df_survival_all = pd.concat((df_survival, df_survival_test_overall), axis="index")

In [None]:
# Plot the risk-score distribution over training + test patients, grouped by event
# (per the function name), and save it under ../results/.
# use_ttest=False presumably disables a t-test annotation — confirm in the helper.
df_label, _ = plot_risk_score_distribution_by_event(
    df_survival_all,
    export_plot="../results/risk_score_distribution.png",
    use_ttest=False
)

In [None]:
# Correlation analysis over the per-group test survival frames; see evaluate_correlation for
# which quantities are correlated. Returns per-group results and a summary.
results_corr, summary_corr = evaluate_correlation(df_survival_test_list, verbose=True)

In [None]:
# Kaplan-Meier curves by risk-score quartile on the pooled test set, saved under ../results/.
# time_max_days=3650 presumably truncates the plot at 10 years — confirm in the helper.
results_KM_pairwise, results_KM_global, quartile_groups = plot_km_by_risk_quartiles(
    df_survival_test_overall,
    export_fig=True,
    path_export_fig="../results/KM_by_quartiles.png",
    time_max_days=3650
)

In [None]:
# Assign quartiles to the DataFrame
# Adds a "Quartile" column to df_survival_test_overall by binning the pooled test risk
# scores into four equal-frequency groups (pd.qcut with q=4).
# NOTE(review): pd.qcut raises ValueError when quantile edges are not unique (e.g. many tied
# risk scores). These quartiles are computed independently of `quartile_groups` returned by
# plot_km_by_risk_quartiles — confirm both use the same cut points.
df_survival_test_overall["Quartile"] = pd.qcut(
    df_survival_test_overall["risk_score"],
    q=4,
    labels=["Q1 (Low)", "Q2", "Q3", "Q4"]
)

# Now call the plotting function
# Box plots of log survival time per quartile; returns two p-values (ANOVA and Kruskal-Wallis,
# per the variable names).
anova_pval, kruskal_pval = plot_boxplot_log_time_by_quartile(
    df_survival=df_survival_test_overall, 
    quartile_groups=quartile_groups, 
    export_fig=True, 
    path_export_fig="../results/boxplot_log_time.png", 
    print_median_time=True
)

### td-AUC

In [None]:
# Time-dependent AUC for each test group, with mean/std curves; the figure is saved under ../results/.
# y_train (the raw training target) is presumably used to estimate censoring weights — confirm.
mean_auc_list, mean_auc_per_time, std_auc_per_time, times = plot_dynamic_auc(
    y_train,
    df_survival_test_list,
    test_groups,
    export_fig = '../results/time-dep-AUC-curves.png'
)

### Brier Score

In [None]:
# Evaluation grid for the Brier score: 100 evenly spaced times from 100 to 3650 days,
# merged with clinical milestones at 1, 2, 3, 5 and 10 years (in days).
# np.union1d returns the sorted, de-duplicated union, so 3650 appears only once.
evaluation_times = np.union1d(
    np.linspace(100, 3650, 100),
    [365, 730, 1095, 1825, 3650],
)

# Compute Brier scores on every test group at evaluation_times using the fitted Cox model and
# the IPCW Brier function, then plot them (saved under ../results/). Returns the per-group
# scores plus their mean, std and upper/lower band.
brier_scores_array, bs_mean, bs_std, bs_upper, bs_lower = evaluate_brier_score_multiple_tests(
    df_survival_test_list=df_survival_test_list,
    cph=cph,
    evaluation_times=evaluation_times,
    brier_score_function=brier_score_ipcw_with_cph,
    export_fig=True,
    path_export_fig="../results/brier_score_tests.png",
    verbose=True
)

In [None]:
# Clinical milestones at which Brier scores are reported: 1, 2, 3, 5 and 10 years (in days).
# They match the milestones merged into evaluation_times, so each one is an exact grid point
# (2 years = 730 days; 701 was not on the grid and matched no milestone).
times_of_interest = [365, 730, 1095, 1825, 3650]

# Tabulate the mean Brier score and its lower/upper band at each time in times_of_interest.
# NOTE(review): for a time that is not an exact point of evaluation_times, check how
# summarize_brier_scores_at_times matches it (nearest point vs. interpolation).
bs_results = summarize_brier_scores_at_times(
    evaluation_times=evaluation_times,
    bs_mean=bs_mean,
    bs_lower=bs_lower,
    bs_upper=bs_upper,
    times_of_interest=times_of_interest,
    verbose=True
)

In [None]:
# Integrated Brier Score with confidence band over evaluation_times, plus a plot (per the
# function name). The milestones are 1, 2, 3, 5 and 10 years in days; 730 is the 2-year
# milestone also merged into evaluation_times (701 was a typo and not on the grid).
# NOTE(review): plot_baseline=0.25 is presumably the reference Brier score of an
# uninformative 0.5 prediction — confirm in compute_and_plot_ibs_with_ci.
ibs_results = compute_and_plot_ibs_with_ci(
    evaluation_times=evaluation_times,
    bs_mean=bs_mean,
    bs_lower=bs_lower,
    bs_upper=bs_upper,
    times_of_interest=[365, 730, 1095, 1825, 3650],
    plot_baseline=0.25,
    verbose=True
)

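## Validation C-index by number of known reports on the real cohort

These results are loaded from a precomputed CSV (`../results/c-index_by_nbr_reports_on_true_data.csv`) rather than computed in this run.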

In [None]:
# Load precomputed C-index results by number of known reports from ../results/.
# NOTE(review): no cell in this notebook writes this file — document how it was generated.
df_combined = pd.read_csv('../results/c-index_by_nbr_reports_on_true_data.csv')

In [None]:
# Plot the smoothed mean C-index against the number of known reports and save the figure.
plot_smoothed_cindex_by_report_count(df_combined, export_path ='../results/Mean_Cindex_by_reports.png')

In [None]:
# Total wall-clock run time since start_notebook was recorded, shown in seconds and minutes.
duration_notebook = time.time() - start_notebook
print(" NoteBook total duration: {:.2f}s i.e. {:.2f}min.".format(duration_notebook, duration_notebook / 60))