In [8]:
import os
import re
from pathlib import Path

import plotnine as p9
import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
from matplotlib import colors, ticker
from cvfgaou import plot

from IPython.display import Markdown

BUCKET = os.environ["WORKSPACE_BUCKET"]

## Predictor figures

In [3]:
# Load DF

# SpliceAI filtered version
predictor_or_df = pd.read_parquet(f'{BUCKET}/or-estimates/or-estimates-2025-12-11_predictors.parquet')

In [5]:
# Rebind rather than mutate in place, so the frame loaded by the previous cell
# is never silently altered.
predictor_or_df = predictor_or_df.reset_index()

In [30]:
predictor_or_df

### MSH2 figure

In [10]:
# Keep only the rows for gene MSH2 scored with the REVEL dataset.
is_msh2 = predictor_or_df['Gene'].eq('MSH2')
is_revel = predictor_or_df['Dataset'].eq('REVEL')
msh2_revel_predictor_or = predictor_or_df.loc[is_msh2 & is_revel]

In [23]:
# Point estimate with its interval per classification; classifiers are dodged
# sideways so their ranges do not overlap.
msh2_revel_aes = p9.aes(
    x="Classification",
    y="LogOR",
    color="Classifier",
    ymin="LogOR_LI",
    ymax="LogOR_UI",
)
msh2_revel_plot = (
    p9.ggplot(msh2_revel_predictor_or, msh2_revel_aes)
    + p9.geom_pointrange(position=p9.position_dodge(width=0.3))
)

In [24]:
msh2_revel_plot

In [37]:
msh2_revel_predictor_or

In [40]:
np.exp(0.668224	)

### Predictor summary

In [32]:
# Exclude rows whose 'Few samples' flag is set, then draw one panel per
# (Gene, Dataset) pair with an independent y scale.
adequately_sampled_df = predictor_or_df.loc[~predictor_or_df['Few samples']]
predictor_summary_aes = p9.aes(
    x="Classification",
    y="LogOR",
    color="Classifier",
    ymin="LogOR_LI",
    ymax="LogOR_UI",
)
predictor_summary = (
    p9.ggplot(adequately_sampled_df, predictor_summary_aes)
    + p9.geom_pointrange(position=p9.position_dodge(width=0.5))
    + p9.facet_grid(('Gene', 'Dataset'), scales='free_y')
    + p9.theme(figure_size=(6,20))
)
predictor_summary

## Common result dataframe

In [2]:
!gsutil ls $WORKSPACE_BUCKET/or-estimates

In [3]:
or_estimates_df = pd.read_parquet(f'{BUCKET}/or-estimates/or-estimates-2025-11-12.parquet')
or_estimates_df

## Combined points

In [4]:
# Load the combined estimates and keep only datasets whose name starts with
# 'Combined'.
combined_or_estimates_df = (
    pd.read_parquet(f'{BUCKET}/or-estimates/combined-or-estimates-2025-10-17.parquet')
    .reset_index()
    .loc[lambda frame: frame['Dataset'].str.startswith('Combined')]
)

In [5]:
# Drop the +/-16 point thresholds (labels ending in '16').
# The explicit copy makes the result an independent frame, so later cells can
# assign columns on it without pandas' SettingWithCopyWarning.
combined_or_estimates_df = combined_or_estimates_df[
    ~combined_or_estimates_df['Classification'].str.endswith('16')
].copy()

In [6]:
list(combined_or_estimates_df['Classification'].drop_duplicates().sort_values())

In [7]:
# Order the point thresholds from strongest benign ('≤ -12') through
# strongest pathogenic ('≥ +12'). The ±16 levels are deliberately left out.
# Any label not listed here becomes NaN in the categorical column.
point_levels = range(1, 13)
classification_order = (
    [f'≤ -{points}' for points in reversed(point_levels)]
    + [f'≥ +{points}' for points in point_levels]
)

# assign() builds a new frame instead of writing into a filtered slice, which
# avoids SettingWithCopyWarning and keeps the cell safe to re-run.
combined_or_estimates_df = combined_or_estimates_df.assign(
    Classification=pd.Categorical(
        combined_or_estimates_df['Classification'],
        categories=classification_order,
        ordered=True,
    )
)

## Combined points

### Compare naive log-odds

In [None]:
# Sanity check: scatter the estimated ln(OR) against the naive ln(OR).
import seaborn as sns
sns.relplot(data=combined_or_estimates_df, x='LogOR', y='naive_logOR')

### Detailed plots showing broken up odds

In [None]:
# One row of four panels per gene: ln(OR) with its interval, the naive ln(OR),
# and the log of the naive carrier / non-carrier odds.
gene_dfs = list(combined_or_estimates_df.groupby('Gene'))

combined_fig = plt.figure(layout='constrained', figsize=(10, len(gene_dfs)*3.5 + 3))

# squeeze=False keeps axs two-dimensional even when there is a single gene,
# so the axs[row, col] indexing below always works.
axs = combined_fig.subplots(len(gene_dfs), 4, squeeze=False)

# Formatting setup
for ax in axs.flatten():
    ax.grid(True, which='major', axis='x', linestyle='--')
    ax.grid(True, which='minor', axis='x', linestyle=':')
    ax.xaxis.set_minor_locator(ticker.AutoMinorLocator())

# Establish Y positions: one integer tick per classification present in the data
y_labels = combined_or_estimates_df.Classification.cat.remove_unused_categories().cat.categories
y_map = {label: index for index, label in enumerate(y_labels)}

# Draw subplots
for row, (gene, plot_df) in enumerate(gene_dfs):

    # Map through object dtype so the positions are plain numbers rather than
    # a categorical (Categorical.map keeps the categorical dtype when the
    # mapping is one-to-one).
    y_vals = plot_df.Classification.astype(object).map(y_map)

    # Log OR panel: rows whose interval spans 50 or more log units are not drawn.
    ax = axs[row, 0]
    logOR_df = plot_df[plot_df['LogOR_UI'] - plot_df['LogOR_LI'] < 50]
    ax.errorbar(
        x=logOR_df.LogOR,
        y=y_vals.loc[logOR_df.index],
        # errorbar wants (lower, upper) distances from the point, shape (2, N)
        xerr=np.absolute(logOR_df[['LogOR_LI', 'LogOR_UI']].to_numpy() - logOR_df[['LogOR']].to_numpy()).transpose(),
        fmt='.',
        capsize=4,
        capthick=1.5,
        elinewidth=1,
        color='black'
    )
    ax.set_xlabel("ln(OR)")
    ax.set_ylabel(gene)

    # Naive log OR, carrier odds and non-carrier odds panels share one recipe.
    naive_panels = (
        (plot_df.naive_logOR, "naive ln(OR)"),
        (np.log(plot_df.variant_odds), "ln(naive carrier odds)"),
        (np.log(plot_df.non_variant_odds), "ln(naive non-carrier odds)"),
    )
    for ax, (x_vals, xlabel) in zip(axs[row, 1:], naive_panels):
        ax.plot(x_vals, y_vals, '+')
        ax.set_xlabel(xlabel)

    # Common decoration: classification tick labels and a reference line at 0.
    for ax in axs[row]:
        ax.set_yticks(range(len(y_labels)), y_labels)
        ax.axvline(x=0, color='red', linestyle=':')


In [None]:
# Color codes

# Bar color for each ClinVar significance label; looked up by the plotting
# functions below when drawing the ClinVar breakdown panel.
clinvar_colors = {
    'Benign': '#1D7AAB',
    'Likely benign': '#64A1C4',
    'Uncertain significance': '#A0A0A0',
    'Conflicting': 'grey',
    'Likely pathogenic': '#E6B1B8',
    'Pathogenic': '#CA7682',
}

# Bar color for each cohort label used in the variant / participant count panels.
cohort_colors = {
    'cases': '#CA7682',
    'overlap': 'grey',
    'controls': '#1D7AAB'
}

def grouped_bar_plot(ax, groups, y_values, bar_colors=None, xlabel=None, log_scale=False, **style_args):
    """Draw horizontal bars grouped around shared y ticks.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    groups : iterable of (label, widths)
        One entry per bar within a group; ``widths`` holds the bar lengths,
        one per row of ``y_values``.
    y_values : array-like
        Tick position of each row. Converted to a float array, so a list or a
        categorical Series of numbers is accepted.
    bar_colors : iterable of colors, optional
        One color per entry of ``groups``. Defaults to the rcParams property
        cycle, wrapping around when there are more groups than cycle colors.
    xlabel : str, optional
        Label for the x axis; left untouched when None.
    log_scale : bool
        Use a logarithmic x axis and draw the grid beneath the bars.
    **style_args
        Forwarded to ``ax.barh``.
    """
    # Grouped bar positioning:
    # With n bars in a group each bar gets height w = 1/(n+1), which leaves one
    # bar's worth of space between neighbouring groups. The i-th (0-indexed)
    # bar is centred at tick - w*(i - (n - 1)/2), so the first entry of
    # `groups` sits at the highest y and later entries step downwards.
    groups = list(groups)  # allow generators; we need both len() and iteration
    nbars = len(groups)
    bar_size = 1/(nbars+1)

    # Plain float positions: arithmetic on a categorical Series would raise.
    y_positions = np.asarray(y_values, dtype=float)

    if bar_colors is None:
        prop_cycle = list(plt.rcParams['axes.prop_cycle'])
        bar_colors = [
            prop_cycle[n % len(prop_cycle)]['color']
            for n in range(nbars)
        ]
    else:
        bar_colors = list(bar_colors)

    # Axis-level settings only need to be applied once, not per group.
    if log_scale:
        ax.set_axisbelow(True)
        ax.set_xscale('log')

    for n, (label, series) in enumerate(groups):
        color = bar_colors[n]
        ax.barh(
            y=y_positions - bar_size * (n - (nbars - 1)/2),
            width=series,
            height=bar_size,
            label=label,
            facecolor=colors.to_rgba(color, 0.5),  # translucent fill, solid edge
            edgecolor=color,
            **style_args
        )

    if xlabel is not None:
        ax.set_xlabel(xlabel)

def populate_classifier_fig(fig, plot_df):
    """Fill ``fig`` with a four-panel summary of one classifier's estimates.

    Panels, left to right, share the y axis (one row per Classification):
    ln(OR) with its interval, variant counts per cohort, participant counts
    on a log x axis, and cohort variant counts by ClinVar significance.

    Parameters
    ----------
    fig : matplotlib Figure or SubFigure
        Container to draw into; its ``subplots`` and ``legend`` methods are used.
    plot_df : pandas.DataFrame
        Needs a categorical ``Classification`` column, ``LogOR``,
        ``LogOR_LI``, ``LogOR_UI``, ``case_only_variant_count``,
        ``overlap_variant_count``, ``control_only_variant_count``,
        ``cases_with_variants``, ``controls_with_variants`` and one
        ``Cohort ClinVar <significance>`` column per plotted category.
    """
    axs = fig.subplots(1, 4, sharey=True)
    or_ax, variant_ax, participant_ax, clinvar_ax = axs

    for ax in axs:
        ax.grid(True, which='major', axis='x', linestyle='--')
        ax.grid(True, which='minor', axis='x', linestyle=':')
        ax.xaxis.set_minor_locator(ticker.AutoMinorLocator())

    # Establish Y positions: one integer tick per classification present.
    y_labels = plot_df.Classification.cat.remove_unused_categories().cat.categories
    y_map = {label: index for index, label in enumerate(y_labels)}
    # Map through object dtype so the result is numeric; mapping a categorical
    # one-to-one would return another categorical, which breaks the bar
    # offset arithmetic in grouped_bar_plot.
    y_pos = plot_df.Classification.astype(object).map(y_map)

    # --- Log OR panel ---
    or_ax.errorbar(
        x=plot_df.LogOR,
        y=y_pos,
        # (lower, upper) distances from the point estimate, shape (2, N)
        xerr=np.absolute(plot_df[['LogOR_LI', 'LogOR_UI']].to_numpy() - plot_df[['LogOR']].to_numpy()).transpose(),
        fmt='.',
        capsize=4,
        capthick=1.5,
        elinewidth=1,
        color='black'
    )
    or_ax.axvline(x=0, color='red', linestyle=':')
    or_ax.set_xlabel("ln(OR)")
    or_ax.set_yticks(range(len(y_labels)), y_labels)

    # --- Variant count panel ---
    variant_cohorts = ('cases', 'overlap', 'controls')
    grouped_bar_plot(
        variant_ax,
        [
            ('cases', plot_df.case_only_variant_count),
            ('overlap', plot_df.overlap_variant_count),
            ('controls', plot_df.control_only_variant_count)
        ],
        y_pos,
        bar_colors=[cohort_colors[c] for c in variant_cohorts],
        xlabel='Variant count'
    )
    cohort_handles, cohort_labels = variant_ax.get_legend_handles_labels()
    # Legends belong to the figure passed in (previously referenced undefined
    # names `subf` / `subfigs[row]`, which raised NameError).
    fig.legend(
        cohort_handles,
        cohort_labels,
        loc='outside lower center',
        title='Cohort'
    )

    # --- Participant count panel ---
    grouped_bar_plot(
        participant_ax,
        [
            ('cases', plot_df.cases_with_variants),
            ('controls', plot_df.controls_with_variants)
        ],
        y_pos,
        bar_colors=[cohort_colors[c] for c in ('cases', 'controls')],
        xlabel='Participant count',
        log_scale=True
    )

    # --- ClinVar breakdown panel (cohort variants only) ---
    clinvar_categories = (
        'Pathogenic',
        'Likely pathogenic',
        'Conflicting',
        'Uncertain significance',
        'Likely benign',
        'Benign',
    )
    cohort_clinvar_bars = [
        (significance, plot_df[f'Cohort ClinVar {significance}'])
        for significance in clinvar_categories
    ]
    grouped_bar_plot(
        clinvar_ax,
        cohort_clinvar_bars,
        y_pos,
        bar_colors=[clinvar_colors[significance] for significance in clinvar_categories],
        xlabel='Variant count by ClinVar significance'
    )
    clinvar_handles, clinvar_labels = clinvar_ax.get_legend_handles_labels()
    fig.legend(
        clinvar_handles,
        clinvar_labels,
        loc='outside lower right',
        title='ClinVar significance'
    )

### Prettier summary plot

In [9]:

# Summary of all positive-evidence thresholds ('≥ +1' and above), leaving out GCK.
is_positive_threshold = combined_or_estimates_df['Classification'] >= '≥ +1'
is_not_gck = combined_or_estimates_df['Gene'] != 'GCK'
combined_plot_df = combined_or_estimates_df[is_positive_threshold & is_not_gck].copy()
# Drop the now-unused negative thresholds from the categorical.
combined_plot_df['Classification'] = combined_plot_df['Classification'].cat.remove_unused_categories()

combined_fig = plt.figure(layout='constrained', figsize=(12, 6))
plot.summary_fig(combined_plot_df, combined_fig=combined_fig)

In [13]:
# Summary restricted to the '≥ +1' threshold, keeping only rows whose lower
# interval bound is above zero. (A GCK exclusion was previously applied here
# and is currently disabled.)
at_first_threshold = combined_or_estimates_df['Classification'] == '≥ +1'
lower_bound_positive = combined_or_estimates_df['LogOR_LI'] > 0
combined_plot_df = combined_or_estimates_df[at_first_threshold & lower_bound_positive].copy()
combined_plot_df['Classification'] = combined_plot_df['Classification'].cat.remove_unused_categories()

combined_fig = plt.figure(layout='constrained', figsize=(12, 6))
plot.summary_fig(combined_plot_df, combined_fig=combined_fig)

In [7]:
combined_fig.savefig(
    'combined-summary.svg',
    dpi=600,
    format='svg'
)

In [26]:
# Threshold '≥ +1' rows, across all genes, whose lower interval bound is above
# zero and whose point estimate is below an odds ratio of 50.
at_first_threshold = combined_or_estimates_df['Classification'] == '≥ +1'
lower_bound_positive = combined_or_estimates_df['LogOR_LI'] > 0
below_or_cap = combined_or_estimates_df['LogOR'] < np.log(50)

selected_slice = combined_or_estimates_df[
    at_first_threshold & lower_bound_positive & below_or_cap
]


In [27]:
selected_slice

In [33]:
# Forest plot of the odds ratio (not its log) per gene for the selected rows.

# Sort into a new frame: sorting a filtered slice in place triggers
# SettingWithCopyWarning. Descending order puts the alphabetically first gene
# at the top of the y axis.
selected_slice = selected_slice.sort_values('Gene', ascending=False)

slice_figure, ax = plt.subplots()

ax.errorbar(
    x=np.exp(selected_slice.LogOR),
    y=selected_slice.Gene,
    # Interval bounds are exponentiated first, so the error bars are the
    # (lower, upper) distances on the odds-ratio scale and are asymmetric.
    xerr=np.absolute(
        np.exp(selected_slice[['LogOR_LI', 'LogOR_UI']].to_numpy())
        - np.exp(selected_slice[['LogOR']].to_numpy())
    ).transpose(),
    fmt='.',
    capsize=4,
    capthick=1.5,
    elinewidth=1,
    color='black'
)

# An odds ratio of 1 means no association.
ax.axvline(x=1, color='red', linestyle=':')

ax.set_xlabel("Odds ratio")


In [None]:
# Gene-specific detailed plots

In [48]:
# Odds ratio per classification for a few hand-picked genes, one panel each.
select_genes = ['BRCA1', 'KCNQ4', 'TP53']

selected_figure, axs = plt.subplots(1,3, figsize=(10, 6), layout='constrained')

for ax, gene in zip(axs, select_genes):
    gene_rows = combined_or_estimates_df.Gene == gene
    df = combined_or_estimates_df[gene_rows].sort_values('Classification')

    odds_ratio = np.exp(df.LogOR)
    # Distances from the point estimate to each interval bound, on the OR scale.
    bound_distance = np.absolute(
        np.exp(df[['LogOR_LI', 'LogOR_UI']].to_numpy())
        - np.exp(df[['LogOR']].to_numpy())
    ).transpose()

    ax.errorbar(
        x=odds_ratio,
        y=df.Classification,
        xerr=bound_distance,
        fmt='.',
        capsize=4,
        capthick=1.5,
        elinewidth=1,
        color='black'
    )
    ax.axvline(x=1, color='red', linestyle=':')
    ax.set_xlabel("Odds ratio")
    ax.set_title(gene)

## Assays and VEPs

In [2]:
or_estimates_df = pd.read_parquet(f'{BUCKET}/or-estimates/or-estimates-2025-10-16.parquet')
or_estimates_df

In [14]:
# Rebind rather than mutate in place, so the frame displayed by the previous
# cell is never silently altered.
or_estimates_df = or_estimates_df.reset_index()

In [15]:
#########################

In [16]:
# Add dataset category: named baselines and predictors are tagged explicitly,
# every other dataset counts as an assay.
dataset_names = or_estimates_df.Dataset
category_labels = np.select(
    [
        dataset_names.isin({'ClinVar', 'Rare Variants'}),
        dataset_names.isin({'REVEL', 'AlphaMissense', 'MutPred2'}),
    ],
    ['Baseline', 'Predictor'],
    default='Assay',
)
or_estimates_df['Category'] = pd.Categorical(category_labels, categories=['Baseline', 'Assay', 'Predictor'])

In [17]:
or_estimates_df['Category'].value_counts()

In [18]:
from datetime import date
or_estimates_df.to_csv(f'cvfg_or_estimates_{date.today().isoformat()}.csv')

In [19]:
list(or_estimates_df['Classification'].drop_duplicates())

In [20]:
# Order Classifications
# Display order for every classification label across datasets. Labels that
# are not listed become NaN in the categorical column.
points_labels = [
    '≤ -8', '≤ -4', '≤ -3', '≤ -2', '≤ -1',
    '≥ +1', '≥ +2', '≥ +3', '≥ +4', '≥ +8',
]
# AM
alphamissense_labels = ['Likely benign', 'Ambiguous', 'Likely pathogenic']
# assays
assay_labels = [
    'GOF', 'possiblyWT', 'WT', 'possiblyLOF', 'moderate LOF', 'LOF', 'severe LOF',
    'normal',
    'Amorphic', 'Hypomorphic', 'Unimpaired',
    'functionally_normal', 'indeterminate', 'functionally_abnormal',
    'FUNC', 'INT',
    'Abnormal', 'Intermediate', 'Normal',
    'Functionally normal', 'Functionally abnormal',
    'enriched', 'unchanged', 'slow depleted', 'depleted', 'fast depleted',
    '0.0', '1.0', '2.0', '3.0', '4.0',
]

or_estimates_df['Classification'] = pd.Categorical(
    or_estimates_df['Classification'],
    categories=points_labels + alphamissense_labels + assay_labels,
    ordered=True
)

In [21]:
or_estimates_df

In [22]:
or_estimates_df.Classification.cat.categories

In [23]:
or_estimates_df['Classification'].drop_duplicates()

### QC: Log odds with vs without controlling for covariates

In [None]:
# Scatter the estimated ln(OR) against the naive ln(OR) for every estimate.
import seaborn as sns
sns.relplot(data=or_estimates_df, x='LogOR', y='naive_logOR')

Looks like there are 3 groups separated by offsets, as well as a group with LogOR @ 0. Why?

In [None]:
### Assay summaries cross-linked to odds ratios

### Grouped plots

In [15]:
# Color codes

# Bar color for each ClinVar significance label; looked up by the plotting
# functions below when drawing the ClinVar breakdown panel.
clinvar_colors = {
    'Benign': '#1D7AAB',
    'Likely benign': '#64A1C4',
    'Uncertain significance': '#A0A0A0',
    'Conflicting': 'grey',
    'Likely pathogenic': '#E6B1B8',
    'Pathogenic': '#CA7682',
}

# Bar color for each cohort label used in the variant / participant count panels.
cohort_colors = {
    'cases': '#CA7682',
    'overlap': 'grey',
    'controls': '#1D7AAB'
}

def grouped_bar_plot(ax, groups, y_values, bar_colors=None, xlabel=None, log_scale=False, **style_args):
    """Draw horizontal bars grouped around shared y ticks.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    groups : iterable of (label, widths)
        One entry per bar within a group; ``widths`` holds the bar lengths,
        one per row of ``y_values``.
    y_values : array-like
        Tick position of each row. Converted to a float array, so a list or a
        categorical Series of numbers is accepted.
    bar_colors : iterable of colors, optional
        One color per entry of ``groups``. Defaults to the rcParams property
        cycle, wrapping around when there are more groups than cycle colors.
    xlabel : str, optional
        Label for the x axis; left untouched when None.
    log_scale : bool
        Use a logarithmic x axis and draw the grid beneath the bars.
    **style_args
        Forwarded to ``ax.barh``.
    """
    # Grouped bar positioning:
    # With n bars in a group each bar gets height w = 1/(n+1), which leaves one
    # bar's worth of space between neighbouring groups. The i-th (0-indexed)
    # bar is centred at tick - w*(i - (n - 1)/2), so the first entry of
    # `groups` sits at the highest y and later entries step downwards.
    groups = list(groups)  # allow generators; we need both len() and iteration
    nbars = len(groups)
    bar_size = 1/(nbars+1)

    # Plain float positions: arithmetic on a categorical Series would raise.
    y_positions = np.asarray(y_values, dtype=float)

    if bar_colors is None:
        prop_cycle = list(plt.rcParams['axes.prop_cycle'])
        bar_colors = [
            prop_cycle[n % len(prop_cycle)]['color']
            for n in range(nbars)
        ]
    else:
        bar_colors = list(bar_colors)

    # Axis-level settings only need to be applied once, not per group.
    if log_scale:
        ax.set_axisbelow(True)
        ax.set_xscale('log')

    for n, (label, series) in enumerate(groups):
        color = bar_colors[n]
        ax.barh(
            y=y_positions - bar_size * (n - (nbars - 1)/2),
            width=series,
            height=bar_size,
            label=label,
            facecolor=colors.to_rgba(color, 0.5),  # translucent fill, solid edge
            edgecolor=color,
            **style_args
        )

    if xlabel is not None:
        ax.set_xlabel(xlabel)

def populate_classifier_row(subfigs, row, classifier, plot_df):
    """Draw the four-panel classifier summary into ``subfigs[row]``.

    Panels, left to right, share the y axis (one row per Classification):
    ln(OR) with its interval, variant counts per cohort, participant counts
    on a log x axis, and cohort variant counts by ClinVar significance.
    For every row after the first, the ln(OR) panel shares its x axis with
    the first axes of the previous subfigure.

    Parameters
    ----------
    subfigs : sequence of matplotlib SubFigure
        Stack of subfigures; entries ``row`` and ``row - 1`` are accessed.
    row : int
        Index of the subfigure to populate.
    classifier : str
        Text used as the subfigure's suptitle.
    plot_df : pandas.DataFrame
        Needs a categorical ``Classification`` column, ``LogOR``,
        ``LogOR_LI``, ``LogOR_UI``, ``case_only_variant_count``,
        ``overlap_variant_count``, ``control_only_variant_count``,
        ``cases_with_variants``, ``controls_with_variants`` and one
        ``Cohort ClinVar <significance>`` column per plotted category.
    """
    subfig = subfigs[row]
    subfig.suptitle(classifier)

    axs = subfig.subplots(1, 4, sharey=True)
    or_ax, variant_ax, participant_ax, clinvar_ax = axs

    for ax in axs:
        ax.grid(True, which='major', axis='x', linestyle='--')
        ax.grid(True, which='minor', axis='x', linestyle=':')
        ax.xaxis.set_minor_locator(ticker.AutoMinorLocator())

    if row > 0:
        # Share x across OR plots so rows are directly comparable
        or_ax.sharex(subfigs[row-1].axes[0])

    # Establish Y positions: one integer tick per classification present.
    y_labels = plot_df.Classification.cat.remove_unused_categories().cat.categories
    y_map = {label: index for index, label in enumerate(y_labels)}
    # Map through object dtype so the result is numeric; mapping a categorical
    # one-to-one would return another categorical, which breaks the bar
    # offset arithmetic in grouped_bar_plot.
    y_pos = plot_df.Classification.astype(object).map(y_map)

    # --- Log OR panel ---
    or_ax.errorbar(
        x=plot_df.LogOR,
        y=y_pos,
        # (lower, upper) distances from the point estimate, shape (2, N)
        xerr=np.absolute(plot_df[['LogOR_LI', 'LogOR_UI']].to_numpy() - plot_df[['LogOR']].to_numpy()).transpose(),
        fmt='.',
        capsize=4,
        capthick=1.5,
        elinewidth=1,
        color='black'
    )
    or_ax.axvline(x=0, color='red', linestyle=':')
    or_ax.set_xlabel("ln(OR)")
    or_ax.set_yticks(range(len(y_labels)), y_labels)

    # --- Variant count panel ---
    grouped_bar_plot(
        variant_ax,
        [
            ('cases', plot_df.case_only_variant_count),
            ('overlap', plot_df.overlap_variant_count),
            ('controls', plot_df.control_only_variant_count)
        ],
        y_pos,
        bar_colors=[cohort_colors[c] for c in ('cases', 'overlap', 'controls')],
        xlabel='Variant count'
    )
    cohort_handles, cohort_labels = variant_ax.get_legend_handles_labels()
    subfig.legend(
        cohort_handles,
        cohort_labels,
        loc='outside lower center',
        title='Cohort'
    )

    # --- Participant count panel ---
    grouped_bar_plot(
        participant_ax,
        [
            ('cases', plot_df.cases_with_variants),
            ('controls', plot_df.controls_with_variants)
        ],
        y_pos,
        bar_colors=[cohort_colors[c] for c in ('cases', 'controls')],
        xlabel='Participant count',
        log_scale=True
    )

    # --- ClinVar breakdown panel (cohort variants only) ---
    clinvar_categories = (
        'Pathogenic',
        'Likely pathogenic',
        'Conflicting',
        'Uncertain significance',
        'Likely benign',
        'Benign',
    )
    # Only the 'Cohort ClinVar ...' columns are plotted, so only those are read.
    cohort_clinvar_bars = [
        (significance, plot_df[f'Cohort ClinVar {significance}'])
        for significance in clinvar_categories
    ]
    grouped_bar_plot(
        clinvar_ax,
        cohort_clinvar_bars,
        y_pos,
        bar_colors=[clinvar_colors[significance] for significance in clinvar_categories],
        xlabel='Variant count by ClinVar significance'
    )
    clinvar_handles, clinvar_labels = clinvar_ax.get_legend_handles_labels()
    subfig.legend(
        clinvar_handles,
        clinvar_labels,
        loc='outside lower right',
        title='ClinVar significance'
    )

### Assay selected example

In [17]:
select_genes = ['RAD51C', 'BAP1', 'MSH2', 'KCNQ4']

# One figure per selected gene, with one subfigure row per
# (Category, Dataset, Classifier) combination.
for gene, gene_df in or_estimates_df.groupby('Gene'):

    if gene not in select_genes:
        continue

    # Figure height grows with the number of rows to plot for this gene.
    gene_fig = plt.figure(layout='constrained', figsize=(10, gene_df.shape[0]))

    plot_dfs = list(gene_df.groupby(['Category', 'Dataset', 'Classifier']))

    # squeeze=False always returns a 2-D array, so a single row needs no
    # special casing; take the only column to get a 1-D stack of subfigures.
    subfigs = gene_fig.subfigures(len(plot_dfs), 1, squeeze=False)[:, 0]

    for row, ((category, dataset, classifier), classifier_df) in enumerate(plot_dfs):

        # Skip estimates whose interval spans 10 or more log units.
        classifier_df = classifier_df[classifier_df.LogOR_UI - classifier_df.LogOR_LI < 10]
        subplot_title = f'{gene}\n{category}: {dataset}\n{classifier}'
        populate_classifier_row(subfigs, row, subplot_title, classifier_df)

    display(gene_fig)
    # Close after displaying: otherwise the inline backend renders the still
    # open figure a second time at the end of the cell and keeps it in memory.
    plt.close(gene_fig)

### Individual figure files per classifier

In [16]:
# Generate figure files

def build_fig(title, plot_df):
    """Build a standalone four-panel summary figure for one classifier.

    Panels, left to right, share the y axis (one row per Classification):
    ln(OR) with its interval, variant counts per cohort, participant counts
    on a log x axis, and cohort variant counts by ClinVar significance.

    Parameters
    ----------
    title : str
        Figure suptitle.
    plot_df : pandas.DataFrame
        Needs a categorical ``Classification`` column, ``LogOR``,
        ``LogOR_LI``, ``LogOR_UI``, ``case_only_variant_count``,
        ``overlap_variant_count``, ``control_only_variant_count``,
        ``cases_with_variants``, ``controls_with_variants`` and one
        ``Cohort ClinVar <significance>`` column per plotted category.

    Returns
    -------
    matplotlib.figure.Figure
        The new figure; the caller is responsible for saving / closing it.
    """
    # Height scales with the number of rows so the bars stay legible.
    fig, axs = plt.subplots(1, 4, sharey=True, layout='constrained', figsize=(10, plot_df.shape[0]/2 + 3))
    fig.suptitle(title)
    or_ax, variant_ax, participant_ax, clinvar_ax = axs

    for ax in axs:
        ax.grid(True, which='major', axis='x', linestyle='--')
        ax.grid(True, which='minor', axis='x', linestyle=':')
        ax.xaxis.set_minor_locator(ticker.AutoMinorLocator())

    # Establish Y positions: one integer tick per classification present.
    y_labels = plot_df.Classification.cat.remove_unused_categories().cat.categories
    y_map = {label: index for index, label in enumerate(y_labels)}
    # Map through object dtype so the result is numeric; mapping a categorical
    # one-to-one would return another categorical, which breaks the bar
    # offset arithmetic in grouped_bar_plot.
    y_pos = plot_df.Classification.astype(object).map(y_map)

    # --- Log OR panel: only rows with at least one case carrying a variant ---
    valid_bars_df = plot_df[plot_df.cases_with_variants > 0]
    or_ax.errorbar(
        x=valid_bars_df.LogOR,
        y=valid_bars_df.Classification.astype(object).map(y_map),
        # (lower, upper) distances from the point estimate, shape (2, N)
        xerr=np.absolute(valid_bars_df[['LogOR_LI', 'LogOR_UI']].to_numpy() - valid_bars_df[['LogOR']].to_numpy()).transpose(),
        fmt='.',
        capsize=4,
        capthick=1.5,
        elinewidth=1,
        color='black'
    )
    or_ax.axvline(x=0, color='red', linestyle=':')
    or_ax.set_xlabel("ln(OR)")
    or_ax.set_yticks(range(len(y_labels)), y_labels)

    # --- Variant count panel ---
    grouped_bar_plot(
        variant_ax,
        [
            ('cases', plot_df.case_only_variant_count),
            ('overlap', plot_df.overlap_variant_count),
            ('controls', plot_df.control_only_variant_count)
        ],
        y_pos,
        bar_colors=[cohort_colors[c] for c in ('cases', 'overlap', 'controls')],
        xlabel='Variant count'
    )
    cohort_handles, cohort_labels = variant_ax.get_legend_handles_labels()
    fig.legend(
        cohort_handles,
        cohort_labels,
        loc='outside lower center',
        title='Cohort'
    )

    # --- Participant count panel ---
    grouped_bar_plot(
        participant_ax,
        [
            ('cases', plot_df.cases_with_variants),
            ('controls', plot_df.controls_with_variants)
        ],
        y_pos,
        bar_colors=[cohort_colors[c] for c in ('cases', 'controls')],
        xlabel='Participant count',
        log_scale=True
    )

    # --- ClinVar breakdown panel (cohort variants only) ---
    clinvar_categories = (
        'Pathogenic',
        'Likely pathogenic',
        'Conflicting',
        'Uncertain significance',
        'Likely benign',
        'Benign',
    )
    # Only the 'Cohort ClinVar ...' columns are plotted, so only those are read.
    cohort_clinvar_bars = [
        (significance, plot_df[f'Cohort ClinVar {significance}'])
        for significance in clinvar_categories
    ]
    grouped_bar_plot(
        clinvar_ax,
        cohort_clinvar_bars,
        y_pos,
        bar_colors=[clinvar_colors[significance] for significance in clinvar_categories],
        xlabel='Variant count by ClinVar sig.'
    )
    clinvar_handles, clinvar_labels = clinvar_ax.get_legend_handles_labels()
    fig.legend(
        clinvar_handles,
        clinvar_labels,
        loc='outside lower right',
        title='ClinVar significance'
    )

    return fig

In [17]:
#select_genes = ['BRCA1']

# Output directory is stamped with today's date.
figs_dir = Path(f'plots_{date.today().isoformat()}')
figs_dir.mkdir(exist_ok=True, parents=True)

# Any run of characters other than alphanumerics, parentheses and '-' is
# collapsed to a single underscore in file names. Raw string: '\-' is not a
# valid escape in a normal string literal and warns on recent Python versions.
unsafe_filename_chars = re.compile(r'[^0-9a-zA-Z()\-]+')

# Make and save plots: one SVG per (Gene, Category, Dataset, Classifier).
for (gene, category, dataset, classifier), plot_df in or_estimates_df.groupby(['Gene', 'Category', 'Dataset', 'Classifier']):

    plot_title = f'{gene}\n{category}: {dataset}\n{classifier}'
    filename = unsafe_filename_chars.sub('_', f'{gene}_{category}_{dataset}_{classifier}')+'.svg'
    fig = build_fig(plot_title, plot_df)
    fig.savefig(
        figs_dir/filename,
        format='svg',
        dpi=600
    )

    plt.close(fig) # Keeps memory usage down and prevents plots from showing in output here


In [23]:
!zip plots_2025-10-20.zip plots_2025-10-20/*

### Assay summary of calibrated strengths

In [24]:
# Assay rows whose classifier name starts with 'Calibrated'.
is_assay = or_estimates_df.Category == 'Assay'
is_calibrated = or_estimates_df.Classifier.str.startswith('Calibrated')
calibrated_assays = or_estimates_df[is_assay & is_calibrated]

In [25]:
sig_levels = calibrated_assays[calibrated_assays['LogOR_LI'] > 0].groupby('Dataset')['Classification'].min()

In [26]:
sig_reachable = calibrated_assays.assign(sig=calibrated_assays['LogOR_LI'] > 0).groupby('Dataset').sig.any()

In [27]:
pd.DataFrame.from_dict({'Min. points for significance': sig_levels, 'Reaches significance': sig_reachable}, orient='columns')

#### Prettier plot

In [28]:
def summary_fig(
    or_estimates_df,
    combined_fig = None,
    cols='Gene',
    saturation_threshold=0.7,
    frame_on=False,
    left_ticks=False,
    title_rotation=0
):
    """Draw a one-row grid of ln(OR) panels, one panel per value of ``cols``.

    Rows whose lower interval bound is above zero are highlighted with a band
    coloured via ``plot.points_colors``; on dark bands the error bar is drawn
    in white so it stays visible.

    Parameters
    ----------
    or_estimates_df : pandas.DataFrame
        Needs a categorical ``Classification`` column, ``LogOR``,
        ``LogOR_LI``, ``LogOR_UI``, ``cases_with_variants`` and the column
        named by ``cols``. Rows with no cases carrying a variant are ignored.
    combined_fig : matplotlib Figure, optional
        Figure to draw into; a 15x10 constrained-layout figure is created
        when omitted.
    cols : str
        Column whose distinct values define the panels.
    saturation_threshold : float
        A band counts as dark when the HSV *value* (third HSV component) of
        its colour is below this threshold.
    frame_on : bool
        Passed to ``Axes.set_frame_on`` for every panel.
    left_ticks : bool
        Whether to draw tick marks on the left axis of every panel.
    title_rotation : float
        Rotation, in degrees, of each panel title.

    Returns
    -------
    matplotlib.figure.Figure
        ``combined_fig``; returned without axes if no group has any rows.
    """
    if combined_fig is None:
        combined_fig = plt.figure(
            layout='constrained',
            figsize=(15, 10)
        )

    plot_df = or_estimates_df[
        (or_estimates_df.cases_with_variants > 0)
    ]

    # Grouping on a categorical column can yield empty groups; skip them.
    col_dfs = [
        (col_id, col_df)
        for col_id, col_df in plot_df.groupby(cols)
        if not col_df.empty
    ]

    # Nothing to draw: subplots() rejects a zero-column grid.
    if not col_dfs:
        return combined_fig

    # squeeze=False always returns a 2-D array, so a single panel is indexed
    # the same way as many; take the only row to get a 1-D array of axes.
    axs = combined_fig.subplots(
        1,
        len(col_dfs),
        sharex = True,
        sharey = True,
        squeeze = False
    )[0]

    class_tick_list = plot_df.Classification.cat.categories.to_list()
    class_tick_map = {val: ind for ind, val in enumerate(class_tick_list)}

    for col, (gene, col_df) in enumerate(col_dfs):
        ax = axs[col]
        significant = col_df.LogOR_LI > 0

        # Figure out which bars will get a dark background
        dark_rows=(
            significant & (
                col_df['Classification']
                .astype(str) # The categorical wrapper breaks things, so convert to str
                .map(plot.points_colors)
                .map(colors.to_rgb)
                .map(colors.rgb_to_hsv)
                .apply(lambda c: c[2] < saturation_threshold)
            )
        )

        # Need to draw separately for dark and light background rows
        for dark in True, False:
            subset_df = col_df[dark_rows == dark]
            color = 'white' if dark else 'black'
            ax.errorbar(
                x=subset_df.LogOR,
                y=subset_df.Classification.map(class_tick_map),
                # (lower, upper) distances from the point estimate, shape (2, N)
                xerr=np.absolute(
                    subset_df[['LogOR_LI', 'LogOR_UI']].to_numpy()
                    - subset_df[['LogOR']].to_numpy()
                ).transpose(),
                fmt='.',
                capsize=4,
                capthick=1.5,
                elinewidth=1,
                color=color
            )

        ax.axvline(x=0, color='black', linestyle=':')
        ax.set_yticks(range(len(class_tick_list)), class_tick_list)
        ax.set_title(gene, rotation=title_rotation)

        # Highlight rows with significant ORs
        for classification in col_df.loc[significant, 'Classification']:
            y_val = class_tick_map[classification]
            ax.axhspan(y_val-0.45, y_val+0.45, color=plot.points_colors[classification])

        ax.set_frame_on(frame_on)
        ax.tick_params(left=left_ticks)

    return combined_fig

In [29]:
# Calibrated assays at positive-evidence thresholds, excluding PTEN, with one
# panel per dataset and vertical panel titles.
is_positive_threshold = calibrated_assays['Classification'] >= '≥ +1'
is_not_pten = calibrated_assays['Gene'] != 'PTEN'
assay_sum_plot_df = calibrated_assays[is_positive_threshold & is_not_pten].copy()

assay_sum_fig = plt.figure(layout='constrained', figsize=(16, 6))
assay_summary_fig = summary_fig(
    assay_sum_plot_df,
    combined_fig=assay_sum_fig,
    cols='Dataset',
    title_rotation=90,
)

In [28]:
assay_summary_fig.savefig(f'assay_summary_{date.today().isoformat()}.svg', format='svg', dpi=600)

### VEP summary of calibrated strengths

In [24]:
# Predictor rows whose classifier name starts with 'Calibrated'.
is_predictor = or_estimates_df.Category == 'Predictor'
is_calibrated = or_estimates_df.Classifier.str.startswith('Calibrated')
calibrated_veps = or_estimates_df[is_predictor & is_calibrated]

In [31]:
# One summary figure per (Dataset, Classifier) pair, positive thresholds only.
for (dataset, classifier), plot_df in calibrated_veps.groupby(['Dataset', 'Classifier']):
    # Build the mask from plot_df itself: a mask taken from the full
    # calibrated_veps frame has a different index, which forces pandas to
    # realign it and emits a warning.
    c_sum_plot_df = plot_df[plot_df['Classification'] >= '≥ +1'].copy()

    c_sum_fig = plt.figure(
        layout='constrained',
        figsize=(15, 5)
    )
    plot.summary_fig(c_sum_plot_df, combined_fig=c_sum_fig)
    c_sum_fig.suptitle(f'{dataset} {classifier}')

In [37]:
# One summary figure per (Gene, Dataset) pair with a panel per classifier.
# Only positive thresholds with more than one case carrying a variant are shown.
for (gene, dataset), plot_df in calibrated_veps.groupby(['Gene', 'Dataset']):
    # Masks are built from plot_df itself so their index matches the frame
    # being filtered (a mask from calibrated_veps would need realigning).
    g_sum_plot_df = plot_df[
        (plot_df['Classification'] >= '≥ +1')
        & (plot_df['cases_with_variants'] > 1)
    ].copy()

    # Check for emptiness before creating the figure, so skipped groups do
    # not leave blank figures behind.
    if g_sum_plot_df.empty:
        continue

    g_sum_plot_df['Classification'] = g_sum_plot_df['Classification'].cat.remove_unused_categories()

    g_sum_fig = plt.figure(
        layout='constrained',
        figsize=(15, 5)
    )

    try:
        plot.summary_fig(g_sum_plot_df, combined_fig=g_sum_fig, cols='Classifier')
        g_sum_fig.suptitle(f'{gene} {dataset}')
    except Exception:
        # Show the offending data for debugging, then let the error propagate.
        display(g_sum_plot_df)
        raise

In [31]:
# REVEL, every calibrated classifier except the gene-specific one,
# positive thresholds only.
revel_gl_mask = (
    (calibrated_veps['Classification'] >= '≥ +1')
    & (calibrated_veps['Dataset'] == 'REVEL')
    & (calibrated_veps['Classifier'] != 'Calibrated (gene-specific)')
)
revel_gl_sum_plot_df = calibrated_veps[revel_gl_mask].copy()

revel_gl_sum_fig = plt.figure(layout='constrained', figsize=(15, 2))
plot.summary_fig(revel_gl_sum_plot_df, combined_fig=revel_gl_sum_fig)
revel_gl_sum_fig.suptitle('REVEL (Bergquist calibration)')

In [32]:
# MutPred2 with the gene-specific calibration only, restricted to
# classifications at or above '≥ +1'.
mp2_gs_sum_fig = plt.figure(figsize=(5, 2), layout='constrained')
mp2_gs_sum_plot_df = calibrated_veps.loc[
    (calibrated_veps['Dataset'] == 'MutPred2')
    & (calibrated_veps['Classifier'] == 'Calibrated (gene-specific)')
    & (calibrated_veps['Classification'] >= '≥ +1')
].copy()
plot.summary_fig(mp2_gs_sum_plot_df, combined_fig=mp2_gs_sum_fig)
mp2_gs_sum_fig.suptitle('MutPred2 (gene-specific calibration)')

In [33]:
# MutPred2, every calibration except the gene-specific one (titled
# "Bergquist"), with PTEN and TP53 left out and classifications at or
# above '≥ +1'.
mp2_gl_sum_fig = plt.figure(figsize=(12, 2), layout='constrained')
mp2_gl_sum_plot_df = calibrated_veps.loc[
    (calibrated_veps['Dataset'] == 'MutPred2')
    & (calibrated_veps['Classifier'] != 'Calibrated (gene-specific)')
    & (~calibrated_veps['Gene'].isin(['PTEN', 'TP53']))
    & (calibrated_veps['Classification'] >= '≥ +1')
].copy()
plot.summary_fig(mp2_gl_sum_plot_df, combined_fig=mp2_gl_sum_fig)
mp2_gl_sum_fig.suptitle('MutPred2 (Bergquist calibration)')

In [34]:
# AlphaMissense with the gene-specific calibration only, restricted to
# classifications at or above '≥ +1'.
am_gs_sum_fig = plt.figure(figsize=(5, 2), layout='constrained')
am_gs_sum_plot_df = calibrated_veps.loc[
    (calibrated_veps['Dataset'] == 'AlphaMissense')
    & (calibrated_veps['Classifier'] == 'Calibrated (gene-specific)')
    & (calibrated_veps['Classification'] >= '≥ +1')
].copy()
plot.summary_fig(am_gs_sum_plot_df, combined_fig=am_gs_sum_fig)
am_gs_sum_fig.suptitle('AlphaMissense (gene-specific calibration)')

In [35]:
# AlphaMissense, every calibration except the gene-specific one (titled
# "Bergquist"), with PTEN left out and classifications at or above '≥ +1'.
am_gl_sum_fig = plt.figure(figsize=(12, 2), layout='constrained')
am_gl_sum_plot_df = calibrated_veps.loc[
    (calibrated_veps['Dataset'] == 'AlphaMissense')
    & (calibrated_veps['Classifier'] != 'Calibrated (gene-specific)')
    & (calibrated_veps['Gene'] != 'PTEN')
    & (calibrated_veps['Classification'] >= '≥ +1')
].copy()
plot.summary_fig(am_gl_sum_plot_df, combined_fig=am_gl_sum_fig)
am_gl_sum_fig.suptitle('AlphaMissense (Bergquist calibration)')

In [36]:
# Save every summary figure as an SVG inside a directory stamped with
# today's date (the directory is created if it does not exist yet).
summaries_dir = Path(f'summary-plots_{date.today().isoformat()}')
summaries_dir.mkdir(exist_ok=True, parents=True)

assay_summary_fig.savefig(summaries_dir/'assay_summary.svg', dpi=600)
# NOTE(review): `revel_gs_sum_fig` is not created by any of the nearby cells
# that build the other figures saved here — confirm it is defined elsewhere,
# otherwise this line raises NameError on a fresh kernel.
revel_gs_sum_fig.savefig(summaries_dir/'revel_gene-specific_summary.svg', dpi=600)
revel_gl_sum_fig.savefig(summaries_dir/'revel_bergquist_summary.svg', dpi=600)
mp2_gs_sum_fig.savefig(summaries_dir/'mp2_gene-specific_summary.svg', dpi=600)
mp2_gl_sum_fig.savefig(summaries_dir/'mp2_bergquist_summary.svg', dpi=600)
am_gs_sum_fig.savefig(summaries_dir/'am_gene-specific_summary.svg', dpi=600)
am_gl_sum_fig.savefig(summaries_dir/'am_bergquist_summary.svg', dpi=600)


In [37]:
!zip summary-plots_2025-10-20.zip summary-plots_2025-10-20/*

In [None]:
# For each (gene, dataset, classifier), the lowest classification level whose
# lower log-OR interval bound is above zero.
cveps_sig_levels = (
    calibrated_veps.loc[calibrated_veps['LogOR_LI'] > 0]
    .groupby(['Gene', 'Dataset', 'Classifier'])['Classification']
    .min()
)

In [None]:
# For each (gene, dataset, classifier), whether any classification level has
# a lower log-OR interval bound above zero.
cveps_sig_reachable = (
    calibrated_veps
    .assign(sig=calibrated_veps['LogOR_LI'] > 0)
    .groupby(['Gene', 'Dataset', 'Classifier'])['sig']
    .any()
)

In [None]:
# Put the two per-group Series side by side, aligned on their shared
# (Gene, Dataset, Classifier) index.
cveps_summary = pd.DataFrame({
    'Min. points for significance': cveps_sig_levels,
    'Reaches significance': cveps_sig_reachable,
})

In [None]:
# Pivot Dataset and Classifier into the columns so each row is one gene.
cveps_summary_table = (
    cveps_summary
    .unstack(level=['Dataset', 'Classifier'])
    .sort_index(axis='columns')
)
cveps_summary_table

In [None]:
# Display the minimum-level table as text with the 'nan' placeholders blanked.
(
    cveps_summary_table['Min. points for significance']
    .astype(str)
    .apply(lambda column: column.str.replace('nan', ''))
)

## VEPs

In [None]:
# Load the 2025-09-02 estimates from the working directory and tag each row
# with its version.
vep_or_2025_09_02_df = (
    pd.read_parquet('or-estimates-2025-09-02.parquet')
    .reset_index()
    .assign(version='2025-09-02')
)

In [None]:
# List the OR-estimate files available in the workspace bucket.
!gsutil ls $WORKSPACE_BUCKET/or-estimates

In [None]:
# List the files in the notebook's working directory.
!ls

In [None]:
def annotate_category(df):
    """Return a copy of ``df`` with a categorical ``Category`` column.

    The category is derived from the ``Dataset`` column:

    * 'ClinVar' and 'Rare Variants' -> 'Baseline'
    * 'REVEL', 'AlphaMissense' and 'MutPred2' -> 'Predictor'
    * anything else -> 'Assay'

    The input frame is not modified. The resulting column uses the category
    order ['Baseline', 'Assay', 'Predictor'].
    """
    annotated = df.copy()
    datasets = annotated['Dataset']
    # Conditions are mutually exclusive; rows matching neither fall back to 'Assay'.
    labels = np.select(
        [
            datasets.isin(['ClinVar', 'Rare Variants']),
            datasets.isin(['REVEL', 'AlphaMissense', 'MutPred2']),
        ],
        ['Baseline', 'Predictor'],
        default='Assay',
    )
    annotated['Category'] = pd.Categorical(
        labels,
        categories=['Baseline', 'Assay', 'Predictor'],
    )
    return annotated

In [None]:
# Load each versioned CSV export, tag its rows with the version, and add the
# Category column to exports that do not already carry one.
or_estimates_dfs = {}
for version in ['2025-08-20', '2025-08-21', '2025-08-28', '2025-09-04']:
    df = pd.read_csv(f'cvfg_or_estimates_{version}.csv').assign(version=version)
    if 'Category' not in df.columns:
        df = annotate_category(df)
    or_estimates_dfs[version] = df

# Sanity check: category counts per version.
for version, df in or_estimates_dfs.items():
    print(version)
    print(df['Category'].value_counts())

In [None]:
# From notes:
# 2025-08-20 is unfiltered, still uses old scale
# 2025-08-28 is filtered by SpliceAI
# 2025-09-04 is filtered by SpliceAI and AF

In [None]:
# Will use BRCA1 as an example for predictors.
# Predictor rows of the 2025-09-04 export, taken as an independent copy.
veps_sa_af_filtered = (
    or_estimates_dfs['2025-09-04']
    .loc[lambda frame: frame['Category'] == 'Predictor']
    .copy()
)

In [None]:
# Order Classifications: convert the column to an ordered categorical so that
# levels can be compared and sorted. Labels outside the list become missing.
veps_sa_af_filtered['Classification'] = veps_sa_af_filtered['Classification'].astype(
    pd.CategoricalDtype(
        categories=[
            # Point-based levels, most negative to most positive
            '≤ -8', '≤ -4', '≤ -3', '≤ -2', '≤ -1',
            '≥ +1', '≥ +2', '≥ +3', '≥ +4', '≥ +8',
            # AM
            'Likely benign', 'Ambiguous', 'Likely pathogenic',
        ],
        ordered=True,
    )
)

In [None]:
## Selected example

select_genes = ['BRCA1', 'MSH2']
select_dataset = 'MutPred2'

# Plot test: one figure per selected gene, one subfigure row per
# (category, dataset, classifier) group of the chosen dataset.
for gene, gene_df in veps_sa_af_filtered.loc[
    veps_sa_af_filtered['Dataset'] == select_dataset
].groupby('Gene'):

    if gene in select_genes:

        gene_fig = plt.figure(figsize=(10, len(gene_df)), layout='constrained')

        plot_dfs = list(gene_df.groupby(['Category', 'Dataset', 'Classifier']))

        subfigs = gene_fig.subfigures(len(plot_dfs), 1)
        # A single row comes back as a bare subfigure; wrap it so it can be
        # passed on the same way as the multi-row result.
        if len(plot_dfs) == 1:
            subfigs = [subfigs]

        for row, ((category, dataset, classifier), classifier_df) in enumerate(plot_dfs):

            # Keep only rows whose log-OR interval is narrower than 10.
            classifier_df = classifier_df.loc[
                (classifier_df['LogOR_UI'] - classifier_df['LogOR_LI']) < 10
            ]
            subplot_title = f'{gene}\n{category}: {dataset}\n{classifier}'
            populate_classifier_row(subfigs, row, subplot_title, classifier_df)

        display(gene_fig)

### Another selected example

In [None]:
## Selected example

select_genes = ['BRCA2']
select_dataset = 'MutPred2'

# Plot test: one figure per selected gene, one subfigure row per
# (category, dataset, classifier) group of the chosen dataset.
for gene, gene_df in veps_sa_af_filtered.loc[
    veps_sa_af_filtered['Dataset'] == select_dataset
].groupby('Gene'):

    if gene in select_genes:

        gene_fig = plt.figure(figsize=(10, len(gene_df)), layout='constrained')

        plot_dfs = list(gene_df.groupby(['Category', 'Dataset', 'Classifier']))

        subfigs = gene_fig.subfigures(len(plot_dfs), 1)
        # A single row comes back as a bare subfigure; wrap it so it can be
        # passed on the same way as the multi-row result.
        if len(plot_dfs) == 1:
            subfigs = [subfigs]

        for row, ((category, dataset, classifier), classifier_df) in enumerate(plot_dfs):

            # Keep only rows whose log-OR interval is narrower than 10.
            classifier_df = classifier_df.loc[
                (classifier_df['LogOR_UI'] - classifier_df['LogOR_LI']) < 10
            ]
            subplot_title = f'{gene}\n{category}: {dataset}\n{classifier}'
            populate_classifier_row(subfigs, row, subplot_title, classifier_df)

        display(gene_fig)

### Effect of filtering

In [None]:
# Predictor rows of the 2025-08-20 export, taken as an independent copy.
veps_unfiltered_df = (
    or_estimates_dfs['2025-08-20']
    .loc[lambda frame: frame['Category'] == 'Predictor']
    .copy()
)

# Which classification labels this export uses, and how often.
veps_unfiltered_df['Classification'].value_counts()

In [None]:
from cvfgaou import notation

# Lookup from '<label> <strength>' strings to '<inequality> <sign><points>'
# strings, built for every combination of an SOE_TABLE row with a DOE_TABLE row.
class_map = dict(
    (f'{label} {strength}', f'{inequality} {sign}{points}')
    for points, strength in notation.SOE_TABLE
    for inequality, sign, _, label, _, _ in notation.DOE_TABLE
)

In [None]:
# Translate the classification labels through `class_map`; labels that are
# not keys of the map come out as missing values.
veps_unfiltered_df['Classification'] = (
    veps_unfiltered_df['Classification']
    .map(class_map, na_action='ignore')
)
veps_unfiltered_df['Classification'].value_counts()

In [None]:
# Order Classifications: convert the column to an ordered categorical so that
# levels can be compared and sorted. Labels outside the list become missing.
veps_unfiltered_df['Classification'] = veps_unfiltered_df['Classification'].astype(
    pd.CategoricalDtype(
        categories=[
            # Point-based levels, most negative to most positive
            '≤ -8', '≤ -4', '≤ -3', '≤ -2', '≤ -1',
            '≥ +1', '≥ +2', '≥ +3', '≥ +4', '≥ +8',
            # AM
            'Likely benign', 'Ambiguous', 'Likely pathogenic',
        ],
        ordered=True,
    )
)

In [None]:
## Selected example

select_genes = ['BRCA1', 'BRCA2','MSH2']
select_dataset = 'MutPred2'

# Plot test: same layout as the filtered examples, but drawn from the
# unfiltered predictor frame.
for gene, gene_df in veps_unfiltered_df.loc[
    veps_unfiltered_df['Dataset'] == select_dataset
].groupby('Gene'):

    if gene in select_genes:

        gene_fig = plt.figure(figsize=(10, len(gene_df)), layout='constrained')

        plot_dfs = list(gene_df.groupby(['Category', 'Dataset', 'Classifier']))

        subfigs = gene_fig.subfigures(len(plot_dfs), 1)
        # A single row comes back as a bare subfigure; wrap it so it can be
        # passed on the same way as the multi-row result.
        if len(plot_dfs) == 1:
            subfigs = [subfigs]

        for row, ((category, dataset, classifier), classifier_df) in enumerate(plot_dfs):

            # Keep only rows whose log-OR interval is narrower than 10.
            classifier_df = classifier_df.loc[
                (classifier_df['LogOR_UI'] - classifier_df['LogOR_LI']) < 10
            ]
            subplot_title = f'{gene}\n{category}: {dataset}\n{classifier}'
            populate_classifier_row(subfigs, row, subplot_title, classifier_df)

        display(gene_fig)

### Unfiltered VEPs summary

In [28]:
# Calibrated classifiers from the unfiltered predictor frame, leaving out
# the genes APP, SNCA and VWF.
calibrated_uf_veps = veps_unfiltered_df.loc[
    veps_unfiltered_df['Classifier'].str.startswith('Calibrated')
    & ~veps_unfiltered_df['Gene'].isin({'APP', 'SNCA', 'VWF'})
]

In [None]:
# For each (gene, dataset, classifier), the lowest classification level whose
# lower log-OR interval bound is above zero.
cufveps_sig_levels = (
    calibrated_uf_veps.loc[calibrated_uf_veps['LogOR_LI'] > 0]
    .groupby(['Gene', 'Dataset', 'Classifier'])['Classification']
    .min()
)

In [None]:
# For each (gene, dataset, classifier), whether any classification level has
# a lower log-OR interval bound above zero.
cufveps_sig_reachable = (
    calibrated_uf_veps
    .assign(sig=calibrated_uf_veps['LogOR_LI'] > 0)
    .groupby(['Gene', 'Dataset', 'Classifier'])['sig']
    .any()
)

In [None]:
# Put the two per-group Series side by side, aligned on their shared
# (Gene, Dataset, Classifier) index.
cufveps_summary = pd.DataFrame({
    'Min. points for significance': cufveps_sig_levels,
    'Reaches significance': cufveps_sig_reachable,
})

In [None]:
# Pivot Dataset and Classifier into the columns so each row is one gene.
cufveps_summary_table = (
    cufveps_summary
    .unstack(level=['Dataset', 'Classifier'])
    .sort_index(axis='columns')
)
cufveps_summary_table

In [None]:
# Display the minimum-level table as text with the 'nan' placeholders blanked.
(
    cufveps_summary_table['Min. points for significance']
    .astype(str)
    .apply(lambda column: column.str.replace('nan', ''))
)

### Detailed plots

In [None]:
# Detailed view of the combined estimates: one subfigure row per gene.
combined_gene_dfs = {
    gene: gene_df
    for gene, gene_df in combined_or_estimates_df.groupby('Gene')
}

# Figure height grows with the number of rows in the combined frame.
combined_fig = plt.figure(layout='constrained', figsize=(10, combined_or_estimates_df.shape[0]))

subfigs = combined_fig.subfigures(len(combined_gene_dfs), 1)
# With a single gene, subfigures() hands back one bare subfigure instead of
# an array; wrap it, as the other cells calling populate_classifier_row do,
# so the single-gene case is passed on in the same form.
if len(combined_gene_dfs) == 1:
    subfigs = [subfigs]

for row, (gene, gene_df) in enumerate(combined_gene_dfs.items()):

    # Only rows with cases_with_variants > 0 are drawn.
    gene_df = gene_df[gene_df.cases_with_variants > 0]
    subplot_title = gene
    populate_classifier_row(subfigs, row, subplot_title, gene_df)

display(combined_fig)

### Summary plot

In [5]:
# Inspect the combined-estimate rows at the '≥ +1' classification level.
combined_or_estimates_df.loc[combined_or_estimates_df['Classification'] == '≥ +1']

In [None]:
# Genes are cols: one panel per gene, classification levels on the y axis.
# (plt, np, colors and ticker come from the notebook's import cell.)

combined_fig = plt.figure(
    layout='constrained',
    figsize=(15, 10)
)

# All classification levels are kept; only rows with cases_with_variants > 0
# are plotted.
plot_df = combined_or_estimates_df[
    (combined_or_estimates_df.cases_with_variants > 0)
]

# Skip genes that end up with no rows.
col_dfs = [
    (gene, col_df)
    for gene, col_df in plot_df.groupby('Gene')
    if not col_df.empty
]

# squeeze=False always yields a 2-D Axes array; taking its only row gives a
# 1-D array even when there is a single gene, so `axs[col]` never fails on a
# bare Axes object.
axs = combined_fig.subplots(
    1,
    len(col_dfs),
    sharex=True,
    sharey=True,
    squeeze=False
)[0]

class_tick_list = plot_df.Classification.cat.categories.to_list()
# Classification label -> y position; identical for every panel, so it is
# built once rather than on each loop iteration.
class_positions = {label: ind for ind, label in enumerate(class_tick_list)}

for col, (gene, col_df) in enumerate(col_dfs):
    axs[col].errorbar(
        x=col_df.LogOR,
        y=col_df.Classification.map(class_positions),
        # Distances from the point estimate to the lower/upper bounds, shaped (2, n).
        xerr=np.absolute(col_df[['LogOR_LI', 'LogOR_UI']].to_numpy() - col_df[['LogOR']].to_numpy()).transpose(),
        fmt='.',
        capsize=4,
        capthick=1.5,
        elinewidth=1,
        color='black'
    )
    # Dotted reference line at log-OR = 0.
    axs[col].axvline(x=0, color='black', linestyle=':')
    axs[col].set_yticks(range(len(class_tick_list)), class_tick_list)
    axs[col].set_title(gene)

In [None]:
# Genes are rows: one panel per classification level, genes on the y axis.
# (plt, np, colors and ticker come from the notebook's import cell.)

combined_fig = plt.figure(
    layout='constrained',
    figsize=(10, 10)
)

# Levels at or above '≥ +1', and only rows with cases_with_variants > 0.
plot_df = combined_or_estimates_df[
    (combined_or_estimates_df['Classification'] >= '≥ +1') &
    (combined_or_estimates_df.cases_with_variants > 0)
]

# Skip classification levels that end up with no rows.
col_dfs = [
    (classification, col_df)
    for classification, col_df in plot_df.groupby('Classification')
    if not col_df.empty
]

# squeeze=False always yields a 2-D Axes array; taking its only row gives a
# 1-D array even when there is a single level, so `axs[col]` never fails on
# a bare Axes object.
axs = combined_fig.subplots(
    1,
    len(col_dfs),
    sharex=True,
    sharey=True,
    squeeze=False
)[0]

gene_tick_list = plot_df.Gene.drop_duplicates().sort_values().to_list()
# Gene -> y position; identical for every panel, so it is built once rather
# than on each loop iteration.
gene_positions = {gene: ind for ind, gene in enumerate(gene_tick_list)}

for col, (classification, col_df) in enumerate(col_dfs):
    axs[col].errorbar(
        x=col_df.LogOR,
        y=col_df.Gene.map(gene_positions),
        # Distances from the point estimate to the lower/upper bounds, shaped (2, n).
        xerr=np.absolute(col_df[['LogOR_LI', 'LogOR_UI']].to_numpy() - col_df[['LogOR']].to_numpy()).transpose(),
        fmt='.',
        capsize=4,
        capthick=1.5,
        elinewidth=1,
        color='black'
    )
    # Dotted reference line at log-OR = 0.
    axs[col].axvline(x=0, color='black', linestyle=':')
    axs[col].set_yticks(range(len(gene_tick_list)), gene_tick_list)
    axs[col].set_title(classification)


In [17]:
# Inspect the (label, frame) pairs left behind by the most recently run
# summary-plot cell.
col_dfs

## Main paper figure elements

**Subfig. B** Combined score classifications across all genes

In [None]:
# NOTE(review): `main_fig` is not assigned in the surrounding cells — confirm
# it is created elsewhere in the notebook before this cell is run.
main_fig

In [None]:
import plotnine as p9

# 'Pathogenic moderate' estimates from MutPred2 and AlphaMissense, dropping
# rows whose log-OR interval is 100 or wider.
summary_plot_df = or_estimates_df.loc[
    (or_estimates_df['Classification'] == 'Pathogenic moderate')
    & ((or_estimates_df['LogOR_UI'] - or_estimates_df['LogOR_LI']) < 100)
    & or_estimates_df['Dataset'].isin({'MutPred2', 'AlphaMissense'})
]

# Point estimate with horizontal interval per gene, one facet per
# dataset/classification pair; the dotted red line marks log-OR = 0.
gp = (
    p9.ggplot(
        summary_plot_df,
        p9.aes(x='LogOR', y='Gene', xmin='LogOR_LI', xmax='LogOR_UI'),
    )
    + p9.geom_point()
    + p9.geom_errorbarh(height=0.1)
    + p9.geom_vline(xintercept=0, colour='red', linetype='dotted')
    + p9.facet_wrap('~ Dataset + Classification', ncol=2)
    + p9.theme(
        # Height scales with the number of plotted rows.
        figure_size=(9, len(summary_plot_df) / 6),
        axis_title_y=p9.element_blank(),
        legend_title=p9.element_blank(),
    )
)

In [None]:
# Render the plotnine figure assembled in the preceding cell.
display(gp)

In [None]:
# 'Pathogenic moderate' estimates for the 'Calibrated (2025-03-04)'
# classifier, dropping rows whose log-OR interval is 100 or wider.
summary_plot_df = or_estimates_df.loc[
    (or_estimates_df['Classification'] == 'Pathogenic moderate')
    & ((or_estimates_df['LogOR_UI'] - or_estimates_df['LogOR_LI']) < 100)
    & (or_estimates_df['Classifier'] == 'Calibrated (2025-03-04)')
]

# Point estimate with horizontal interval per dataset, one facet per
# classification; the dotted red line marks log-OR = 0.
gp = (
    p9.ggplot(
        summary_plot_df,
        p9.aes(x='LogOR', y='Dataset', xmin='LogOR_LI', xmax='LogOR_UI'),
    )
    + p9.geom_point()
    + p9.geom_errorbarh(height=0.1)
    + p9.geom_vline(xintercept=0, colour='red', linetype='dotted')
    + p9.facet_wrap('~ Classification', ncol=1)
    + p9.theme(
        # Height scales with the number of plotted rows.
        figure_size=(6, len(summary_plot_df) / 4),
        axis_title_y=p9.element_blank(),
        legend_title=p9.element_blank(),
    )
)

In [None]:
# Render the plotnine figure assembled in the preceding cell.
display(gp)

In [None]:
%%capture --no-stderr --no-stdout cap

for gene, plot_df in or_estimates_df.groupby('Gene'):
    
    #if gene == 'VWF': continue # Some weirdness with VWF
    
    display(Markdown(f'## {gene}'))
    display(Markdown(f"**Phenotype: {gene_phenotypes[gene]}**"))
    
    with pd.option_context('display.max_rows', None):
        display(
            plot_df
            .drop(columns=['LogOR', 'LogOR_LI', 'LogOR_UI'])
            .set_index(['Gene', 'Category', 'Dataset', 'Classifier', 'Classification'])
        )
    
    plot_df = plot_df[plot_df['LogOR_UI'] - plot_df['LogOR_LI'] < 100]
    plot_df['Classification'] = pd.Categorical(
        plot_df['Classification'],
        categories=plot_df['Classification'].cat.categories[::-1]
    )
    
    gp = (
        p9.ggplot(
            data=plot_df,
            mapping=p9.aes(
                x='LogOR',
                y='Classification',
                color='Classifier',
                xmin='LogOR_LI',
                xmax='LogOR_UI'
            )
        )
        + p9.geom_point()
        + p9.geom_errorbarh(height=0.1)
        + p9.geom_vline(
            xintercept = 0,
            colour='red',
            linetype='dotted'
        )
        + p9.facet_wrap(
            '~ Category + Dataset + Classifier',
            scales='free_y',
            ncol=1
        )
        + p9.theme(
            figure_size=(6, plot_df.shape[0]/2),
            axis_title_y=p9.element_blank(),
            legend_title=p9.element_blank()
        )
    )
    
    display(gp)

In [None]:
# Replay the output collected in `cap` by the %%capture cell.
cap.show()

In [None]:
import pickle

# Persist the captured report output to 'cap.pickle' in the working directory.
Path('cap.pickle').write_bytes(pickle.dumps(cap))

## Data exploration cells

In [None]:
# BRIP1 / MutPred2 rows with the estimate, count and variant-tally columns.
# The per-group min/max allele-frequency columns are deliberately left out.
or_estimates_df.loc[
    (or_estimates_df['Gene'] == 'BRIP1')
    & (or_estimates_df['Dataset'] == 'MutPred2'),
    [
        'Dataset', 'Gene', 'Classifier', 'Classification',
        'LogOR', 'LogOR_LI', 'LogOR_UI',
        'cases_with_variants', 'controls_with_variants',
        'cases_without_variants', 'controls_without_variants',
        'variants_per_case', 'variants_per_control',
        'case_only_variant_count', 'control_only_variant_count',
        'overlap_variant_count',
        'Variants in class', 'Variants in cohort',
    ],
]