In [21]:
# Common imports and constants
import os

import numpy as np
import pandas as pd
from IPython.display import display, Markdown
from tqdm.notebook import tqdm
#from cvfgaou import aou, survival, gctools

# Workspace settings come from environment variables (an unset variable raises KeyError).
DATASET, BUCKET = (os.environ[var] for var in ("WORKSPACE_CDR", "WORKSPACE_BUCKET"))
DATA_DIR = BUCKET + '/data_v1'

In [2]:
# Survival-curve estimates stored in the workspace bucket (dated 2025-10-19 per the filename).
SURVIVAL_ESTIMATES_PATH = f'{BUCKET}/survival-estimates/survival-estimates_2025-10-19.parquet'
source_df = pd.read_parquet(SURVIVAL_ESTIMATES_PATH)

In [3]:
# Turn the index into ordinary column(s); the explicit copy keeps curves_df independent of source_df.
curves_df = (
    source_df
    .reset_index()
    .copy()
)

In [4]:
# Keep only the estimates computed without survey data.
# `.copy()` makes the filtered result an independent frame: later cells assign columns on
# curves_df in place, and doing that on the result of a boolean filter is what pandas flags
# with SettingWithCopyWarning.
curves_df = curves_df[
    #(curves_df['Gene'] == 'GCK') &
    (curves_df['Using surveys'] == False)
].copy()

In [5]:
# Inspect the survey-free curve table (Jupyter renders the cell's last expression).
curves_df

In [12]:
# Number of curve rows per classification label.
curves_df['Classification'].value_counts()

In [23]:
import matplotlib.pyplot as plt
from cvfgaou import plot
from pathlib import Path
import re

In [31]:
# Export one survival-curve figure per (gene, dataset, classifier) as both PDF and SVG.
pdfs_dir = Path('survival-plot-pdfs_2025-10-19')
pdfs_dir.mkdir(exist_ok=True, parents=True)
svgs_dir = Path('survival-plot-svgs_2025-10-19')
svgs_dir.mkdir(exist_ok=True, parents=True)

# Fixed colors for the reference curve and the +-1 curves; every other classification takes
# the next color from the active matplotlib color cycle.
OVERRIDE_COLORS = {
    'Full cohort': 'gray',
    '≤ -1': '#63a1c4',
    '≥ +1': '#e6b1b8',
}

for (gene, dataset, classifier), plot_df in tqdm(curves_df.groupby(['Gene', 'Dataset', 'Classifier'])):

    if dataset == 'N/A':
        continue

    # Calibrated and points-based classifiers are reduced to their +-1 curves only.
    keep_only_pm1 = classifier.startswith('Calibrated') or classifier.endswith('points')

    # Remove curves with 2 or fewer points, and remove all curves except +-1 from calibrated sets.
    # observed=True: if Classification is categorical, do not iterate empty category combinations.
    display_curves = [
        curve_df
        for (classification, _surveys), curve_df in plot_df.groupby(
            ['Classification', 'Using surveys'], observed=True
        )
        if curve_df.shape[0] > 2 and (not keep_only_pm1 or classification in {'≤ -1', '≥ +1'})
    ]
    if not display_curves:
        continue

    # Overlay the gene's baseline curve(s), relabelled to this figure's dataset/classifier.
    baseline_df = curves_df[
        (curves_df['Gene'] == gene) &
        (curves_df['Classifier'] == 'Baseline')
    ].assign(Dataset=dataset, Classifier=classifier)
    plot_df = pd.concat(display_curves + [baseline_df])

    fig = plt.figure(layout='constrained')
    fig.suptitle(f'{gene}: {dataset}, {classifier}')
    ax = fig.subplots(1, 1)

    cycle_colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
    for i, (classification, ax_df) in enumerate(plot_df.groupby('Classification', observed=True)):
        # Wrap around the color cycle so more classifications than cycle colors cannot raise IndexError.
        selected_color = OVERRIDE_COLORS.get(classification, cycle_colors[i % len(cycle_colors)])

        ax.fill_between(
            ax_df['Age'],
            ax_df['Survival_LI'],
            ax_df['Survival_UI'],
            facecolor=selected_color,
            alpha=0.5,
            label=classification
        )
        ax.plot(
            ax_df['Age'],
            ax_df['Survival to onset'],
            color=selected_color,
            label=classification
        )

    ax.set_xlabel("Age")
    ax.set_ylabel("Proportion unaffected")

    # Each classification's ribbon and line share a label: merge them into one legend entry
    # (a tuple of handles). Grouping in first-seen order keeps the legend order deterministic,
    # unlike iterating over set(labels), whose order depends on per-process string hashing.
    handles_by_label = {}
    for handle, label in zip(*ax.get_legend_handles_labels()):
        handles_by_label.setdefault(label, []).append(handle)
    ax.legend(
        [tuple(label_handles) for label_handles in handles_by_label.values()],
        list(handles_by_label),
        loc='lower left',
    )

    # Raw string: '\-' in a plain string literal is an invalid escape (SyntaxWarning on Python 3.12+).
    filename = re.sub(r'[^0-9a-zA-Z()\-]+', '_', f'{gene}_{dataset}_{classifier}')
    for out_dir, fmt in ((pdfs_dir, 'pdf'), (svgs_dir, 'svg')):
        fig.savefig(out_dir / f'{filename}.{fmt}', format=fmt, dpi=600)

    # Free the figure; many are created in this loop.
    plt.close(fig)

In [1]:
# Bundle the SVGs written by the export loop into a single zip archive (the PDFs are not included).
!zip survival-plot-svgs_2025-10-19.zip survival-plot-svgs_2025-10-19/*

# Old stuff: earlier exploratory plots (not used by the export loop above)

In [3]:
# Phenotype shown under each gene's heading in the exploratory plots below.
gene_phenotypes = dict([
    ('APP', "Alzheimer's"),
    ('BAP1', "Cancer: melanoma, mesothelioma, liver, or kidney"),
    ('BARD1', 'Cancer: ovarian'),
    ('BRCA1', "Cancer: breast, ovarian, pancreatic, or prostate"),
    ('BRCA2', "Cancer: breast, ovarian, pancreatic, or prostate"),
    ('BRIP1', 'Cancer: ovarian'),
    ('CALM1', 'Long QT'),
    ('CALM2', 'Long QT'),
    ('CALM3', 'Long QT'),
    ('GCK', "MODY"),
    ('KCNH2', 'Long QT'),
    ('KCNQ4', "Nonsyndromic genetic hearing loss"),
    ('MSH2', "Cancer: colorectal, ovarian, stomach, small bowel, urinary tract, brain, breast, or endometrial"),
    ('OTC', "Ornithine carbamoyltransferase deficiency (X-linked)"),
    ('PALB2', "Cancer: breast or ovarian"),
    ('PRKN', "Parkinson's (Autosomal recessive)"),
    ('PTEN', "Cancer: liver, thyroid, colorectal, breast, endometrial or melanoma"),
    ('RAD51C', "Cancer: breast, ovarian, or prostate"),
    ('RAD51D', "Cancer: breast or ovarian"),
    ('SCN5A', 'Long QT'),
    ('SNCA', "Parkinson's"),
    ('TARDBP', "ALS"),
    ('TP53', "Any cancer"),
    ('VWF', "von Willebrand disorder"),
])

In [8]:
from cvfgaou.notation import DOE_TABLE

def _classification_labels(point_values, doe_rows):
    """Build '<inequality> <sign><points>' labels, grouped by point value.

    The outer loop runs over ``point_values``; for each value every row of
    ``doe_rows`` contributes one label. Each row must unpack into exactly six
    fields, of which only the first two (inequality, sign) are used.
    """
    labels = []
    for points in point_values:
        for inequality, sign, _, _, _, _ in doe_rows:
            labels.append(f'{inequality} {sign}{points}')
    return labels


cls_seq = _classification_labels((8, 4, 3, 2, 1), DOE_TABLE)

cls_seq

In [9]:
# For each calibrated classifier, record classifications whose curve exactly duplicates an
# earlier classification in cls_seq, so the plotting loop below can skip them.
CURVE_COLUMNS = ['Age', 'Survival to onset', 'Survival_LI', 'Survival_UI']

drop_list = []
for (gene, dataset, classifier), df in curves_df.groupby(['Gene', 'Dataset', 'Classifier']):

    if not classifier.startswith('Calibrated'):
        continue

    # Extract each classification's curve once, rather than once per (i, j) pair.
    curves = [
        df.loc[df['Classification'] == cls, CURVE_COLUMNS].to_numpy()
        for cls in cls_seq
    ]
    dropped = set()  # labels of this group already recorded in drop_list

    for i, cls_i in enumerate(cls_seq):
        # A curve already recorded as a duplicate does not start new drops.
        if cls_i in dropped:
            continue
        for j in range(i + 1, len(cls_seq)):
            # np.array_equal returns False for curves with different numbers of points.
            # A bare `==` raises ValueError on mismatched shapes (NumPy >= 1.25) and would
            # broadcast a 1-row curve against a longer one.
            if np.array_equal(curves[i], curves[j]):
                dropped.add(cls_seq[j])
                drop_list.append((gene, dataset, classifier, cls_seq[j]))

drop_list

In [12]:
# Per gene: a heading, its phenotype, and one figure per (dataset, classifier), skipping
# curves that drop_list marks as exact duplicates.
dropped_curves = set(drop_list)

for gene, gene_df in curves_df.groupby('Gene'):

    display(Markdown(f"## {gene}"))
    display(Markdown(f"**Phenotype: {gene_phenotypes[gene]}**"))

    for (dataset, classifier), plot_df in gene_df.groupby(['Dataset', 'Classifier']):

        if dataset == 'N/A':
            continue

        # Remove curves with 2 or fewer points, and duplicate curves found earlier.
        # observed=True: if Classification is categorical, skip empty category combinations.
        display_curves = [
            curve_df
            for (classification, _surveys), curve_df in plot_df.groupby(
                ['Classification', 'Using surveys'], observed=True
            )
            if (gene, dataset, classifier, classification) not in dropped_curves
            and curve_df.shape[0] > 2
        ]
        if not display_curves:
            continue

        # Overlay the gene's baseline curve(s), relabelled to this figure's dataset/classifier.
        baseline_df = curves_df[
            (curves_df['Gene'] == gene) &
            (curves_df['Classifier'] == 'Baseline')
        ].assign(Dataset=dataset, Classifier=classifier)
        plot_df = pd.concat(display_curves + [baseline_df])

        fig = plt.figure(layout='constrained')
        # Include the classifier: several classifiers can share a dataset, and a title of the
        # dataset alone made their figures indistinguishable.
        fig.suptitle(f'{dataset}, {classifier}')
        ax = fig.subplots(1, 1)

        for classification, ax_df in plot_df.groupby('Classification', observed=True):
            # Classifications without an entry in plot.points_colors are drawn in gray.
            selected_color = plot.points_colors.get(classification)
            if selected_color is None:
                selected_color = 'gray'

            ax.fill_between(
                ax_df['Age'],
                ax_df['Survival_LI'],
                ax_df['Survival_UI'],
                facecolor=selected_color,
                alpha=0.5,
                label=classification
            )
            ax.plot(
                ax_df['Age'],
                ax_df['Survival to onset'],
                color=selected_color,
                label=classification
            )

        ax.set_xlabel("Age")
        ax.set_ylabel("Proportion unaffected")

        # Merge each classification's ribbon and line into one legend entry, in first-seen
        # order (deterministic, unlike iterating over set(labels)).
        handles_by_label = {}
        for handle, label in zip(*ax.get_legend_handles_labels()):
            handles_by_label.setdefault(label, []).append(handle)
        fig.legend(
            [tuple(label_handles) for label_handles in handles_by_label.values()],
            list(handles_by_label),
            loc='outside center right',
        )

        display(fig)
        # Close after the explicit display. Otherwise the inline backend shows every
        # still-open figure a second time when the cell finishes, and figures pile up in memory.
        plt.close(fig)

In [5]:
# Make classification styling consistent
# Do this by replacing underscores with spaces and capitalizing
# Explicitly manage labels containing acronyms
explicit_map = {
    'possiblyLOF': 'Possibly LOF',
    'possiblyWT': 'Possibly WT',
    'moderate LOF': 'Moderate LOF',
    'severe LOF': 'Severe LOF'
}

# Labels kept verbatim by the capitalization step. This covers the acronym/threshold labels
# and explicit_map's keys (remapped below). It also covers explicit_map's outputs: re-running
# this cell must be a no-op, not turn 'Possibly LOF' into 'Possibly lof'.
preserved_labels = (
    {'FUNC', 'INT', 'LOF', 'AF < 0.001', 'GOF', 'WT'}
    | set(explicit_map)
    | set(explicit_map.values())
)

raw_classification = curves_df['Classification']
curves_df['Classification'] = raw_classification.where(
    raw_classification.isin(preserved_labels),
    raw_classification.str.replace('_', ' ').str.capitalize()
).mask(
    raw_classification.isin(set(explicit_map)),
    raw_classification.map(explicit_map)
)

In [6]:
# Make rare variant classification labels distinct for our different data sources:
# each rare-variant row's Classification becomes "<Classifier> <Classification>".
rv_rows = curves_df['Dataset'].eq('Rare Variants')
rv_sources = curves_df.loc[rv_rows, 'Classifier']
rv_labels = curves_df.loc[rv_rows, 'Classification']
curves_df.loc[rv_rows, 'Classification'] = rv_sources.str.cat(rv_labels, sep=' ')

curves_df.loc[rv_rows]

In [7]:
# Order Classifications: turn the labels into an ordered set of categories so that groupby
# and plot legends follow this curated order.
classification_order = [
    # Baselines
    'Full cohort', # Added for curves
    'All of Us AF < 0.001',
    'gnomAD AF < 0.001',
    # Clinvar and AM
    'Benign',
    'Likely benign',
    'Ambiguous',
    'Uncertain significance',
    'Likely pathogenic',
    'Pathogenic',
    # Calibrations
    'Benign very strong',
    'Benign strong',
    'Benign moderate',
    'Benign supporting',
    'Pathogenic supporting',
    'Pathogenic moderate',
    'Pathogenic strong',
    'Pathogenic very strong',
    # Author category groups
    # Depleted - enriched
    'Depleted',
    'Unchanged',
    'Enriched',
    # GOF - LOF
    'GOF',
    'FUNC',
    'INT',
    'WT',
    'Possibly WT',
    'Possibly LOF',
    'Moderate LOF',
    'Severe LOF',
    'LOF',
    # Other one-offs
    'Deleterious',
    'Functional',
    'Uncertain',
    'Normal',
    'Intermediate',
    'Abnormal',
    'Amorphic',
    'Hypomorphic',
    'Unimpaired',
    '0.0',
    '1.0',
    'Fast depleted',
    'Slow depleted',
    'Functionally abnormal',
    'Functionally normal'
]

# pd.Categorical silently replaces any value missing from `categories` with NaN, which would
# drop those curves (e.g. points labels such as '≤ -1') from every later groupby and plot.
# Keep them by appending any unlisted labels, sorted, after the curated order.
listed_labels = set(classification_order)
unlisted_labels = sorted(
    (label for label in curves_df['Classification'].dropna().unique() if label not in listed_labels),
    key=str,
)

curves_df['Classification'] = pd.Categorical(
    curves_df['Classification'],
    categories=classification_order + unlisted_labels,
)

In [8]:
# Inspect the table after label normalization, rare-variant prefixing and categorical ordering.
curves_df

In [19]:
# Don't use surveys, and rename the survival column to the y-axis label used by the plots below.
# Chained (not inplace=True), so the rename applies to the new filtered frame rather than
# mutating a slice of the previous curves_df in place.
curves_df = (
    curves_df[~curves_df['Using surveys']]
    .rename(columns={'Survival to onset': 'Proportion unaffected'})
)

In [12]:
#%%capture --no-stderr --no-stdout cap

import plotnine as p9

# One plotnine figure per (gene, dataset, classifier): a line plus confidence ribbon per
# classification, faceted by classifier and gene/dataset, with the gene's baseline overlaid.
for gene, gene_df in curves_df.groupby('Gene'):

    if gene == 'VWF': continue

    display(Markdown(f"## {gene}"))
    display(Markdown(f"**Phenotype: {gene_phenotypes[gene]}**"))

    for (dataset, classifier), plot_df in gene_df.groupby(['Dataset', 'Classifier']):

        # Overlay the gene's baseline curve(s), relabelled to this figure's dataset/classifier.
        baseline_df = curves_df[
            (curves_df['Gene'] == gene) &
            (curves_df['Classifier'] == 'Baseline')
        ].assign(Dataset=dataset, Classifier=classifier)
        plot_df = pd.concat([plot_df, baseline_df])

        # Remove curves with 2 or fewer points. Classification is categorical here, so
        # observed=True avoids iterating every empty category combination.
        kept_curves = [
            curve_df
            for _, curve_df in plot_df.groupby(['Classification', 'Using surveys'], observed=True)
            if curve_df.shape[0] > 2
        ]
        # pd.concat raises ValueError on an empty list; skip figures with nothing left to draw.
        if not kept_curves:
            continue
        plot_df = pd.concat(kept_curves)

        # Not named `plot`: that would shadow the `cvfgaou.plot` module used by earlier cells.
        survival_plot = (
            p9.ggplot(
                data=plot_df,
                mapping=p9.aes(
                    x='Age',
                    y='Proportion unaffected',
                    ymin='Survival_LI',
                    ymax='Survival_UI',
                    color='Classification',
                    fill='Classification',
                )
            )
            + p9.geom_ribbon(alpha=0.3, color='none')
            + p9.geom_line()
            + p9.facet_grid('Classifier ~ Gene + Dataset')
            + p9.ggtitle(f'{gene}\n{dataset}\n{classifier}')
        )
        display(survival_plot)

In [11]:
import pickle

# `cap` is the IPython capture object that the `%%capture --no-stderr --no-stdout cap` magic
# at the top of the plotnine cell would create. That line is commented out, so on a fresh
# kernel `cap` does not exist. Check BEFORE opening the file: open(..., 'wb') truncates it,
# so a failing dump would otherwise replace a previously saved capture with an empty file.
if 'cap' not in globals():
    raise NameError(
        "`cap` is undefined: enable the `%%capture --no-stderr --no-stdout cap` line in the "
        "plotnine cell and run that cell before pickling its output."
    )

with open('curve-cap.pickle', 'wb') as outstream:
    pickle.dump(cap, outstream)