Kaplan-Meier survival curves

### Action items
This notebook is a work in progress (WIP). The following tasks are outstanding:

- [ ] Debug curve workflow
- [ ] Extend to all genes of interest
- [ ] Move code to common functions
- [ ] Move common functions to github

In [1]:
# Common imports and constants
import os
from pathlib import Path
import pandas as pd
from tqdm.notebook import tqdm
from cvfgaou import aou, survival, gctools, data

# Both variables are read with os.environ[...], so a missing one raises KeyError here.
# BigQuery dataset prefix used in the SQL below (e.g. `{DATASET}.condition_occurrence`)
DATASET = os.environ["WORKSPACE_CDR"]
# Workspace GCS bucket: holds the input data tables and receives the saved survival estimates
BUCKET = os.environ["WORKSPACE_BUCKET"]
DATA_DIR = f'{BUCKET}/data_v2'
# Version tag embedded in the output parquet filename of the survival estimates
CURVES_VERSION = '2025-10-19'

In [2]:
# Common functions

In [3]:
# Gene case pools definitions
# Taking them from the gene phenotypes defined for the project
#
# Each data.gene_phenotypes entry unpacks as (case phenotypes, control phenotypes, <ignored>).
# Every phenotype name p becomes a pool file name '<p> case pool.tsv.gz' or
# '<p> control pool.tsv.gz'; those files are later read from DATA_DIR.
# NOTE(review): the third element of each gene_phenotypes entry is discarded here — confirm intended.

cohorts = aou.CohortLoader(
    gene_cohort_map = {
        gene: (
            {f'{p} case pool.tsv.gz' for p in case_phenos},
            {f'{p} control pool.tsv.gz' for p in control_phenos}
        )
        for gene, (case_phenos, control_phenos, _) in data.gene_phenotypes.items()
    },
    # NOTE(review): ancestry is indexed by research_id while demographics use person_id;
    # presumably the same identifier — TODO confirm against CohortLoader.
    ancestry_df = pd.read_table(f'{DATA_DIR}/ancestry_pca.tsv.gz', index_col='research_id'),
    demo_df = pd.read_table(f'{DATA_DIR}/demo.tsv.gz', index_col='person_id'),
    data_dir = DATA_DIR
)

In [4]:
# Get event count statistics
#
# One row per person who has a whole-genome variant record (has_whole_genome_variant = 1)
# AND at least one condition_occurrence row. Columns:
#   n_events    - number of distinct condition start dates
#   first_event - earliest condition start date
#   last_event  - latest condition start date
# plus sex_at_birth, dob and age_at_cdr from cb_search_person.
# Because of the grouping + INNER JOIN, people with no condition_occurrence rows are absent.

count_events_sql = f"""
    SELECT *
    FROM (
        SELECT
            person_id,
            COUNT(DISTINCT condition_start_date) AS n_events,
            MIN(condition_start_date) AS first_event,
            MAX(condition_start_date) AS last_event
        FROM `{DATASET}.condition_occurrence`
        WHERE
            person_id IN (
                SELECT person_id FROM `{DATASET}.cb_search_person`
                WHERE has_whole_genome_variant = 1
            )
        GROUP BY person_id
    ) INNER JOIN (
        SELECT
            person_id,
            sex_at_birth,
            dob,
            age_at_cdr
        FROM `{DATASET}.cb_search_person`
    ) USING (person_id)
"""

# NOTE(review): pd.read_gbq is deprecated since pandas 2.2 in favour of pandas_gbq.read_gbq.
count_events_df = pd.read_gbq(
    count_events_sql,
    index_col='person_id',
    dialect='standard',
    # Only use the BigQuery Storage API when the environment advertises it
    use_bqstorage_api=('BIGQUERY_STORAGE_API_ENABLED' in os.environ),
    progress_bar_type='tqdm_notebook'
).assign(
    # Convert to pandas datetimes so the dates can be subtracted when computing ages
    first_event = lambda df: pd.to_datetime(df.first_event),
    last_event = lambda df: pd.to_datetime(df.last_event),
    dob = lambda df: pd.to_datetime(df.dob)
)

count_events_df

In [5]:
# Compute age at last event for all cohorts
# Whole days from date of birth to the latest condition date, expressed in 365.25-day years.

days_dob_to_last_event = (count_events_df['last_event'] - count_events_df['dob']).dt.days
age_at_last_event = days_dob_to_last_event / 365.25

In [6]:
# Get combined dataframes: Take earliest event time for each person (separately for conditions and surveys)
#
# For every gene, each pool group (case files first, then control files) is read and stacked,
# then collapsed to one row per person_id by taking the column-wise minimum of the numeric
# columns. The collapsed group is tagged with `case`, and age_at_last_event is left-joined on
# person_id. People absent from age_at_last_event get NaN.
# NOTE(review): a person listed in both a case and a control pool appears twice (once per
# label), giving a duplicated index — confirm this cannot happen or dedupe upstream.

def _collapse_pool_group(pool_files, is_case):
    """Read one group of pool files and reduce it to one row per person, tagged with `case`."""
    stacked = pd.concat(
        (pd.read_table(f'{DATA_DIR}/{pool_file}') for pool_file in pool_files),
        ignore_index=True
    )
    return stacked.groupby('person_id').min(numeric_only=True).assign(case=is_case)


last_event_ages = age_at_last_event.rename('age_at_last_event')

cohort_dfs = {
    gene: pd.concat(
        _collapse_pool_group(pool_files, is_case)
        for is_case, pool_files in zip((True, False), pool_groups)
    ).join(last_event_ages, how='left')
    for gene, pool_groups in tqdm(cohorts.gene_cohort_map.items())
}

In [7]:
# Prepare timepoints for survival curves

def get_cohort_timepoints(cohort_df, use_survey_times):
    """Attach a per-person ``timepoints`` column (age in years) for survival fitting.

    - Cases (``case`` is True) use ``condition_onset_age``.
    - Cases without an onset age fall back to the midpoint of their survey age window,
      but only when survey times are enabled. Each bound of that window is capped at
      ``age_at_last_event``.
    - Controls use ``age_at_last_event``.

    Rows whose timepoint is still missing are dropped.

    Parameters
    ----------
    cohort_df : pandas.DataFrame
        Must contain ``case``, ``condition_onset_age`` and ``age_at_last_event``.
        It may also contain ``survey_low_age`` and ``survey_high_age``.
    use_survey_times : bool
        Whether to use survey age windows as the fallback for cases. Ignored unless
        BOTH survey columns are present.

    Returns
    -------
    pandas.DataFrame
        A copy of ``cohort_df`` with survey helper column(s) and a float ``timepoints``
        column added. The input frame is not modified.
    """
    cohort_df = cohort_df.copy()

    # Both bounds are used below. Checking for only one of them (.any()) raised a
    # KeyError when a frame had just one of the two survey columns.
    if use_survey_times and {'survey_high_age', 'survey_low_age'}.issubset(cohort_df.columns):
        cohort_df['survey_upper_limit'] = cohort_df[['survey_high_age', 'age_at_last_event']].min(axis='columns')
        cohort_df['survey_lower_limit'] = cohort_df[['survey_low_age', 'age_at_last_event']].min(axis='columns')
        cohort_df['survey_midpoint'] = cohort_df[['survey_upper_limit', 'survey_lower_limit']].mean(axis='columns')
    else:
        # NaN (rather than None) keeps the column float, so the masks below stay numeric
        cohort_df['survey_midpoint'] = float('nan')

    return cohort_df.assign(
        timepoints=cohort_df.age_at_last_event.mask( # Use last event time for controls
            cohort_df.case, # For cases:
            cohort_df.condition_onset_age.mask( # Use condition onset
                cohort_df.condition_onset_age.isna(), # except where there is no condition onset, then
                cohort_df.survey_midpoint # Use survey midpoint
            )
        )
    ).dropna(axis='index', subset=['timepoints']).astype({'timepoints': float})


# Censor timepoints for AoU compliance

def censor_timepoints(cohort_with_timepoints, left=20, right=20):
    """ Censor 'left' earliest timepoints and 'right' latest timepoints

    The 'left' ('right') earliest (latest) timepoints are all re-assigned the
    largest (smallest) value within that group. This is done for AoU compliance.
    Both replacement values are computed before anything is written. If the two
    groups overlap (fewer than left + right rows), the right-hand value wins for
    the shared rows.

    Rows are selected by position, not by index label. A non-unique index (e.g. a
    person_id present in both the case and control pools) therefore cannot pull
    extra rows into either group.

    Parameters
    ----------
    cohort_with_timepoints : pandas.DataFrame
        Frame with a numeric 'timepoints' column.
    left, right : int
        Number of earliest / latest timepoints to collapse.

    Returns
    -------
    pandas.DataFrame
        A censored copy; the input frame is not modified.
    """
    result = cohort_with_timepoints.copy()
    if result.empty:
        return result

    # Positional view of the column: after reset_index its labels are row positions
    timepoints = result['timepoints'].reset_index(drop=True)
    column_position = result.columns.get_loc('timepoints')

    left_positions = timepoints.nsmallest(left).index.tolist()
    right_positions = timepoints.nlargest(right).index.tolist()

    left_time = timepoints.loc[left_positions].max()
    right_time = timepoints.loc[right_positions].min()

    result.iloc[left_positions, column_position] = left_time
    result.iloc[right_positions, column_position] = right_time

    return result

## Curve generation

For each gene we want to have
 - The full cohort curve
 - A curve for every classification
 
When presenting, we will group plots by dataset and include the full cohort curve in every group.

For processing the data, we want to first load all exposure tables into one big dataframe, and then iterate over it.

In [8]:
# NOTE(review): this listing runs before the `mkdir`/copy cell further down, so on a
# fresh kernel the directory may not exist yet.
!ls classes_2025-10-16/

In [9]:
!gsutil ls $WORKSPACE_BUCKET

In [10]:
# Assemble the local classes directory from three bucket snapshots, copied in this order.
# Files with the same relative path are presumably overwritten by the later copies
# (gsutil cp overwrites unless -n is given).
# NOTE(review): the local directory name is also hard-coded in the exposure-loading cell;
# keep the two in sync.
!mkdir -p classes_2025-10-16
!gsutil -m cp -r $WORKSPACE_BUCKET/classes_2025-09-09/* classes_2025-10-16/
!gsutil -m cp -r $WORKSPACE_BUCKET/classes_2025-10-15/* classes_2025-10-16/
!gsutil -m cp -r $WORKSPACE_BUCKET/combined_classes_2025-10-17/* classes_2025-10-16/

In [11]:
# Load all exposure tables into our dataframe
#
# Each file in the local exposures directory is read as parquet and split by Gene.
# Only genes with a cohort definition are kept.
# Files are visited in sorted order because Path.iterdir() order is arbitrary; this
# keeps the row order of exposures_df reproducible.
import warnings


def _iter_gene_exposure_tables(exposure_dir):
    """Yield the per-gene slices of every exposure table, for genes in cohorts.gene_cohort_map."""
    for exposure_path in tqdm(sorted(Path(exposure_dir).iterdir())):
        exposure_table = pd.read_parquet(exposure_path).reset_index()
        if 'person_id' not in exposure_table.columns:
            # TODO: fix missing person_ids in some exposure tables upstream.
            # Skip those tables for now, but report it instead of dropping them silently.
            warnings.warn(f'Skipping {exposure_path.name}: no person_id column')
            continue
        for gene, gene_table in exposure_table.groupby('Gene'):
            if gene in cohorts.gene_cohort_map:
                yield gene_table


exposures_df = pd.concat(
    _iter_gene_exposure_tables('classes_2025-10-16/exposures'),
    ignore_index=True
)

In [12]:
# Kaplan-Meier estimates keyed by (gene, using surveys, dataset, classifier, classification);
# filled by the curve-generation loop and concatenated into one frame afterwards.
curve_sets = {}

In [13]:
# Survey-derived timepoints are deprecated, so only survey=False curves are produced.
# Add True here to re-enable them. Genes whose cohort has no survey_* columns are still
# skipped for survey=True by the check inside the loop.
SURVEY_MODES = (False,)

for gene, gene_df in tqdm(exposures_df.groupby('Gene')):
    # Iterate over genes first to get gene-wide tables
    cohort_df = cohort_dfs[gene]

    for survey in SURVEY_MODES:
        # Only produce survey-based curves when the cohort actually has survey columns
        if survey and not cohort_df.columns.str.startswith('survey_').any():
            continue

        # get_cohort_timepoints works row by row, so compute it once for the whole
        # cohort and subset it per classification below (same result, less work).
        cohort_timepoints_df = get_cohort_timepoints(cohort_df, survey)

        # Full population curve (guarded like the per-classification curves below)
        full_cohort_df = censor_timepoints(cohort_timepoints_df)
        if not full_cohort_df.empty:
            curve_sets[(gene, survey, 'N/A', 'Baseline', 'Full cohort')] = survival.get_kaplan_meier_estimate(
                full_cohort_df
            )

        # !Future TODO!
        # Curve without variants
        #curve_sets[(gene, survey, 'N/A', 'Baseline', 'No variants')] = survival.get_kaplan_meier_estimate(
        #    censor_timepoints(get_cohort_timepoints(cohort_df, survey))
        #)

        for (dataset, classifier, classification), class_df in gene_df.groupby(['Dataset', 'Classifier', 'Classification']):
            # Exposure tables may store person_id as non-integer; align with the cohort index
            carrier_ids = pd.to_numeric(class_df.person_id, downcast='integer')
            timepoints_df = censor_timepoints(
                cohort_timepoints_df[cohort_timepoints_df.index.isin(carrier_ids)]
            )

            if not timepoints_df.empty:
                curve_sets[(
                    gene, survey, dataset, classifier, classification
                )] = survival.get_kaplan_meier_estimate(timepoints_df)

In [14]:
# Combine into one big frame
# The 5-tuple dict keys become the outer MultiIndex levels named below;
# each curve frame's own index is kept as the innermost level.
curves_df = pd.concat(
    curve_sets,
    names = ['Gene', 'Using surveys', 'Dataset', 'Classifier', 'Classification']
)

In [15]:
# Save the survival estimates to the bucket, versioned by CURVES_VERSION

curves_df.to_parquet(f'{BUCKET}/survival-estimates/survival-estimates_{CURVES_VERSION}.parquet')

In [16]:
# Intentional hard stop so "Run All" ends right after the estimates are saved.
# The cells below are exploratory and rely on leftover kernel state.
raise RuntimeError('Stop here')

In [18]:
!gsutil ls -lh $WORKSPACE_BUCKET/survival-estimates

In [None]:
# Scratch inspection: gene_df / cohort_df hold whatever the last iteration of the
# curve-generation loop left behind.
gene_df.person_id

In [None]:
cohort_df.index

In [None]:
# Number of curve rows per Classification label
curves_df.reset_index().Classification.value_counts()

In [None]:
curve_sets

## Make plots

In [None]:
import plotnine as p9

# One figure per (Gene, Dataset) group, with one facet per classifier stacked vertically.
# Grouping on two keys yields (gene, dataset) tuples, so unpack both.
for (gene, dataset), plot_df in curves_df.groupby(['Gene', 'Dataset']):

    # Reset the index once, without inplace, so the MultiIndex levels
    # (Gene, Dataset, Classifier, Classification, ...) become plottable columns.
    plot_df = plot_df.reset_index()

    n_facets = plot_df['Classifier'].nunique()

    plot = (
        p9.ggplot(
            data=plot_df,
            mapping=p9.aes(
                x='Age',
                y='Survival to onset',
                ymin='Survival_LI',
                ymax='Survival_UI',
                color='Classification',
                fill='Classification'
            )
        )
        # Interval band without an outline, with the survival curve on top
        + p9.geom_ribbon(alpha=0.3, color='none')
        + p9.geom_line()
        + p9.facet_wrap('~ Gene + Dataset + Classifier', ncol=1)
        # Scale the figure height with the number of stacked facets
        + p9.theme(
            figure_size=(6, 4*n_facets)
        )
    )
    display(plot)

In [None]:
import plotnine as p9

# NOTE(review): plot_df is only defined by the per-(Gene, Dataset) plotting loop, so this
# single-panel plot shows that loop's last group and depends on kernel state.
plot = (
    p9.ggplot(
        data=plot_df,
        mapping=p9.aes(
            x='Age',
            y='Survival to onset',
            ymin='Survival_LI',
            ymax='Survival_UI'
        )
    )
    + p9.geom_ribbon(alpha=0.3)
    + p9.geom_line()
)
plot

In [None]:
!gsutil ls $WORKSPACE_BUCKET/prelim_exposures

In [None]:
# Load exposures of interest
# Preliminary BRCA1 ClinVar exposure table; its Classification and person_id columns are used below.

brca1_clinvar_df = pd.read_csv(f'{BUCKET}/prelim_exposures/clinvar_vat_BRCA1.csv')
brca1_clinvar_df

In [None]:
brca1_clinvar_df.Classification.value_counts()

In [None]:
def get_kaplan_meier_estimate(cohort_with_timepoints_df):
    """Legacy Kaplan-Meier fit used by the exploratory BRCA1 cells.

    Removes the 20 rows with the largest ``timepoints`` (they are dropped, not collapsed).
    It then fits ``kaplan_meier_estimator`` with ``case`` as the event indicator and
    ``timepoints`` as the exit time, using log-log confidence intervals.

    NOTE(review): ``kaplan_meier_estimator`` is never imported in this notebook.
    Presumably it is ``sksurv.nonparametric.kaplan_meier_estimator`` — TODO add the import,
    or switch to ``survival.get_kaplan_meier_estimate``.

    Returns
    -------
    pandas.DataFrame
        Columns 'Age', 'Survival to onset', 'Survival_LI', 'Survival_UI'.
    """
    # Censor last 20 timepoints by removing those rows outright
    latest_rows = cohort_with_timepoints_df.nlargest(20, 'timepoints').index
    trimmed_df = cohort_with_timepoints_df.drop(latest_rows)

    ages, survival_prob, conf_int = kaplan_meier_estimator(
        event=trimmed_df.case,
        time_exit=trimmed_df.timepoints,
        conf_type='log-log'
    )

    curve_columns = {}
    curve_columns['Age'] = ages
    curve_columns['Survival to onset'] = survival_prob
    curve_columns['Survival_LI'] = conf_int[0]
    curve_columns['Survival_UI'] = conf_int[1]
    return pd.DataFrame(curve_columns)

# Legacy per-classification KM curves for BRCA1 ClinVar carriers.
# NOTE(review): brca_cohort_with_timepoints_df is not defined anywhere in this notebook,
# so this fails on a fresh kernel (presumably a leftover from an earlier version).
# The single-element tuple in the inner `for` only binds cohort_subset, so the
# `if` can skip classifications with no matching people.
combined_df = pd.concat(
    {
        classification: get_kaplan_meier_estimate(cohort_subset)
        for classification in brca1_clinvar_df.Classification.drop_duplicates()
        for cohort_subset in (
            brca_cohort_with_timepoints_df[
                brca_cohort_with_timepoints_df.index.isin(
                    brca1_clinvar_df.person_id[brca1_clinvar_df.Classification == classification]
                )
            ],
        )
        if not cohort_subset.empty
    },
    names = ['Classification']
).reset_index()

In [None]:
# Show the curve tables for classifications whose label ends in 'Pathogenic'
for label, df in combined_df[combined_df.Classification.str.endswith('Pathogenic')].groupby('Classification'):
    display(df)

In [None]:
# NOTE(review): plot_df is the last group left behind by the plotting loop. Every row in it,
# whatever its original classification, is relabelled 'Full cohort' here — confirm intended.
big_plot_df = pd.concat(
    [combined_df, plot_df.assign(Classification='Full cohort')],
    ignore_index=True
)

In [None]:
# Overlay plot: one colour per Classification, interval ribbons drawn without an outline
plot = (
    p9.ggplot(
        data=big_plot_df,
        mapping=p9.aes(
            x='Age',
            y='Survival to onset',
            ymin='Survival_LI',
            ymax='Survival_UI',
            color='Classification',
            fill='Classification'
        )
    )
    + p9.geom_ribbon(alpha=0.3, color='none')
    + p9.geom_line()
)
plot

In [None]:
!gsutil ls $WORKSPACE_BUCKET

In [None]:
!gsutil ls $WORKSPACE_BUCKET/aux_data

In [None]:
# BRCA1 variant annotation table from the bucket's aux_data folder (tab-separated)
brca1_vat = pd.read_table(f'{BUCKET}/aux_data/brca1_vat.tsv')
brca1_vat

In [None]:
# Count distinct (contig, position) sites where at least one row's ClinVar classification
# contains "pathogenic".
# - Matching is case-insensitive: other tables in this notebook use capitalised labels
#   such as 'Pathogenic', which a case-sensitive match would miss.
# - Missing classifications count as not pathogenic.
# NOTE(review): a substring match also hits labels containing "pathogenicity" — confirm intended.
brca1_vat.assign(
    pathogenic = brca1_vat.clinvar_classification.str.contains('pathogenic', case=False, na=False)
).groupby(['contig', 'position']).pathogenic.any().sum()

In [None]:
# Intentional hard stop before the remaining scratch cells.
# `Error` is not a Python builtin (it raised NameError), so use RuntimeError as in the earlier stop cell.
raise RuntimeError('Stop here')

In [None]:
# NOTE(review): df_path is not defined anywhere in this notebook, so this cell fails on a
# fresh kernel. Set it to the CVFG variant table path before running.
cvfg_df = pd.read_csv(
    df_path,
    index_col='ID',
    usecols=[
        'ID', 'Dataset', 'Gene', 'Chrom', 'STRAND', 'ref_allele', 'alt_allele',
        'auth_reported_score', 'auth_reported_func_class',
        'gnomad_MAF',
        'clinvar_sig', 'clinvar_star'
    ],
    # Explicit dtypes for a few columns; the commented-out entries are not applied
    dtype={
        #'HGNC_id': str,
        'Chrom': str,
        'STRAND': str,
        #'auth_reported_score': float, # Contains some strings, will want to resolve in the future
        'auth_reported_func_class': str
        #'hg38_start': int,
        #'hg38_end': int
    }
)

In [None]:
# Counts of (Gene, Dataset, auth_reported_func_class) combinations for genes whose name contains 'BRCA'
cvfg_df[cvfg_df.Gene.str.contains('BRCA')][['Gene', 'Dataset', 'auth_reported_func_class']].value_counts()

In [None]:
cvfg_df

In [None]:
# NOTE(review): this path is only displayed, never read. It also differs from the
# 'clinvar_vat_BRCA1.csv' file loaded earlier — confirm which file is intended.
f'{BUCKET}/prelim_exposures/clinvar_BRCA1.csv'