In [95]:
import altair as alt
import numpy as np
import pandas as pd
from sklearn.mixture import GaussianMixture
from sklearn.model_selection import GridSearchCV

from plot_config import init_json_transformer

# Project-local helper imported from plot_config. Presumably it registers and
# enables a JSON data transformer for Altair so large frames are not embedded
# inline in the chart specs -- TODO confirm against plot_config.
init_json_transformer()

In [4]:
# Load the tab-separated DMS fitness table (one row per variant) and preview it.
ne_dms_data = pd.read_csv('data/dms_fitness.txt', sep='\t')
ne_dms_data

Unnamed: 0,position,wt_aa,variant_aa,fitness,sigma
0,45,S,N,-0.157877,0.051725
1,45,S,C,-0.088456,0.062313
2,45,S,*,-1.272181,0.108602
3,45,S,Y,-0.010628,0.058444
4,45,S,V,-0.764939,0.093056
...,...,...,...,...,...
5709,333,Y,C,-1.053999,0.096956
5710,333,Y,W,-0.497903,0.086452
5711,333,Y,F,-0.064831,0.053773
5712,333,Y,P,-1.287754,0.117872


In [41]:
# ClinVar control variants: read only the identifying columns plus the clinical
# significance call, rename that call, and drop the 'VUS' rows.
clinvar_raw = pd.read_table(
    'data/af6_clinvar_20231024.txt',
    usecols=['wt_aa', 'position', 'variant_aa', 'significance'],
)
clinvar_annotated = clinvar_raw.rename(columns={'significance': 'clinvar_annotation'})
control_variants = clinvar_annotated[clinvar_annotated.clinvar_annotation != 'VUS']
control_variants

Unnamed: 0,wt_aa,position,variant_aa,clinvar_annotation
0,V,280,I,B
1,G,28,A,LB
2,A,47,T,LB
3,K,299,T,LB
4,K,331,R,LB
5,M,1,T,LP
6,G,240,V,LP
7,P,281,A,LP
8,M,1,L,P
9,W,10,*,P


In [6]:
def gmm_bic_score(estimator, X):
    """Scorer for GridSearchCV: the negated BIC of ``estimator`` on ``X``.

    GridSearchCV maximizes its score, whereas a lower BIC is better, so the
    sign is flipped before returning.
    """
    bic = estimator.bic(X)
    return -bic


# Candidate models: 1-6 mixture components, each with a full covariance matrix.
param_grid = {
    "n_components": range(1, 7),
    "covariance_type": ["full"],
}
# Seed the estimator so that the EM initialisation -- and therefore the BIC
# ranking of the candidates -- is reproducible across kernel restarts. The seed
# matches the one used for the final GaussianMixture fit in this notebook.
grid_search = GridSearchCV(
    GaussianMixture(random_state=0), param_grid=param_grid, scoring=gmm_bic_score
)
# scikit-learn expects a 2-D (n_samples, n_features) array, hence the reshape.
grid_search.fit(ne_dms_data.fitness.to_numpy().reshape(-1, 1))

In [61]:
# Tabulate the grid-search outcome. The scorer returned -BIC, so the sign is
# flipped back before the columns get their human-readable names.
reported_columns = {
    "param_n_components": "Number of components",
    "param_covariance_type": "Type of covariance",
    "mean_test_score": "BIC score",
}
grid_search_results = (
    pd.DataFrame(grid_search.cv_results_)
    .loc[:, list(reported_columns)]
    .assign(mean_test_score=lambda scores: -scores.mean_test_score)
    .rename(columns=reported_columns)
)
grid_search_results.sort_values(by="BIC score")

Unnamed: 0,Number of components,Type of covariance,BIC score
3,4,full,991.993092
2,3,full,992.372917
4,5,full,1007.26621
5,6,full,1029.270507
1,2,full,1171.739984
0,1,full,1830.051362


In [62]:
# Bar chart of the BIC score per number of components (full-covariance rows only).
full_covariance_results = grid_search_results[
    grid_search_results['Type of covariance'] == 'full'
]
alt.Chart(full_covariance_results).mark_bar().encode(
    x='Number of components:O',
    y='BIC score',
)

In [115]:
# Fit the final mixture model on the fitness scores (a single feature column).
X = ne_dms_data.fitness.to_numpy().reshape(-1, 1)
n_components = 3

gmm = GaussianMixture(
    n_components=n_components, covariance_type='full', random_state=0
)
gmm.fit(X)

# Evaluate the fitted model on a regular grid of fitness values for plotting.
x = np.linspace(-2, 0.5, 1000)
x_column = x.reshape(-1, 1)
pdf = np.exp(gmm.score_samples(x_column))  # score_samples returns log-density
responsibilities = gmm.predict_proba(x_column)
# Per-component curve = component responsibility * total mixture density.
pdf_individual = responsibilities * pdf[:, None]

combined_model = pd.DataFrame({'x': x, 'pdf': pdf})
columns = list(range(n_components))
# Long format: one row per (grid point, component) pair.
individual_models = (
    pd.DataFrame(pdf_individual, columns=columns, index=pd.Index(x, name='x'))
    .reset_index()
    .melt(id_vars='x', var_name='component', value_name='pdf')
)

In [116]:
# Tick positions used for the fitness axis of the figures.
axis_values = [-2, -1.5, -1, -0.5, 0, 0.5]

# Histogram of fitness scores; nonsense ('*') variants are drawn in red.
fitness_histogram = (
    alt.Chart(ne_dms_data.assign(nonsense=lambda df: df.variant_aa == '*'))
    .mark_bar(size=7)
    .encode(
        x=alt.X(
            'fitness',
            bin=alt.Bin(step=0.05),
            axis=alt.Axis(values=axis_values),
        ),
        y=alt.Y('count()'),
        color=alt.Color(
            'nonsense',
            scale=alt.Scale(domain=[False, True], range=['lightgrey', '#d62728']),
        ),
    )
)

# Total mixture density (solid black line).
mixture_density_line = (
    alt.Chart(combined_model)
    .mark_line(color='black')
    .encode(
        x=alt.X('x', title=None),
        y=alt.Y('pdf', title=None),
    )
)

# One dashed line per mixture component.
component_density_lines = (
    alt.Chart(individual_models)
    .mark_line(strokeDash=(4, 4))
    .encode(
        x=alt.X('x', title=None),
        y=alt.Y('pdf', title=None),
        color=alt.Color('component:N', scale=alt.Scale(scheme='set1')),
    )
)

# The density curves share one y scale; the histogram counts keep their own.
gmm_fit = (
    fitness_histogram
    + (mixture_density_line + component_density_lines).resolve_scale(y='shared')
).resolve_scale(y='independent', color='independent')

# ClinVar control variants that also have a DMS measurement (inner merge on
# wild-type residue, position and variant residue), drawn as one circle per
# variant at its fitness value and coloured by ClinVar annotation.
control_variant_scores = (
    alt.Chart(
        control_variants.merge(
            ne_dms_data, on=['wt_aa', 'position', 'variant_aa']
        ).assign(variant=lambda x: x.wt_aa + x.position.astype(str) + x.variant_aa)
    )
    .mark_circle()
    .encode(
        alt.X(
            'fitness',
            axis=alt.Axis(values=axis_values),
        ),
        color=alt.Color(
            'clinvar_annotation',
            scale=alt.Scale(
                domain=['P', 'LP', 'LB', 'B'],
                # Every entry must be a valid 6-digit hex colour: dark/light
                # red for P/LP and light/dark blue for LB/B.
                range=['#e41a1c', '#ff9896', '#6baed6', '#3182bd'],
            ),
        ),
    )
)

# Stack the GMM fit above the control variants on a shared fitness axis.
(gmm_fit & control_variant_scores).resolve_scale(
    x='shared', y='independent', color='independent'
)

# Clinical Analysis

In [119]:
# Labels for the fitted mixture components, in component-index order:
#   component 0 = strong functional impact
#   component 1 = low functional impact
#   component 2 = intermediate functional impact
# NOTE: this mapping is tied to the component order of the current fit
# (n_components=3, random_state=0); re-check gmm.means_ if the model changes.
component_labels = ['strong', 'low', 'intermediate']

# Total mixture density at every observed fitness value.
total_likelihood = np.exp(gmm.score_samples(X))
# Posterior probability of each component given the observed fitness.
component_posteriors = pd.DataFrame(gmm.predict_proba(X), columns=component_labels)
# Per-component likelihoods (posterior * total density), kept for inspection.
individual_likelihoods = component_posteriors.mul(total_likelihood, axis=0)

ne_dms_data_gmm = ne_dms_data.assign(
    # Use the posteriors directly. Scaling them by the total likelihood and then
    # dividing it back out is a mathematical no-op that becomes 0/0 = NaN
    # wherever the density underflows. `.to_numpy()` makes the assignment
    # positional, matching how X was built from ne_dms_data.fitness.
    p_abnormal=(component_posteriors.strong + component_posteriors.intermediate).to_numpy(),
    p_normal=component_posteriors.low.to_numpy(),
    # p_abnormal <= 0.25 -> normal, (0.25, 0.75] -> uncertain, > 0.75 -> abnormal
    functional_impact=lambda x: pd.cut(x.p_abnormal, bins=[-np.inf, 0.25, 0.75, np.inf], labels=['normal', 'uncertain', 'abnormal'])
)
ne_dms_data_gmm.to_csv('data/dms_fitness_gmm.txt', sep='\t', index=False)
ne_dms_data_gmm

Unnamed: 0,position,wt_aa,variant_aa,fitness,sigma,p_abnormal,p_normal,functional_impact
0,45,S,N,-0.157877,0.051725,0.058232,9.417675e-01,normal
1,45,S,C,-0.088456,0.062313,0.027069,9.729312e-01,normal
2,45,S,*,-1.272181,0.108602,1.000000,4.691821e-19,abnormal
3,45,S,Y,-0.010628,0.058444,0.014139,9.858609e-01,normal
4,45,S,V,-0.764939,0.093056,0.999996,3.565786e-06,abnormal
...,...,...,...,...,...,...,...,...
5709,333,Y,C,-1.053999,0.096956,1.000000,2.020757e-12,abnormal
5710,333,Y,W,-0.497903,0.086452,0.981205,1.879500e-02,abnormal
5711,333,Y,F,-0.064831,0.053773,0.021688,9.783121e-01,normal
5712,333,Y,P,-1.287754,0.117872,1.000000,1.520080e-19,abnormal


In [142]:
# Collapse the per-component density curves into the two curves that are
# plotted: 'a' = component 0 + component 2, 'b' = component 1. The result is
# melted to long format for Altair.
clinical_models = pd.DataFrame(
    {
        'x': x,
        'a': pdf_individual[:, 0] + pdf_individual[:, 2],
        'b': pdf_individual[:, 1],
    }
).melt(id_vars='x', var_name='component', value_name='pdf')

# Histogram of fitness scores coloured by the GMM-derived functional class.
classified_histogram = (
    alt.Chart(ne_dms_data_gmm)
    .mark_bar(size=7.5)
    .encode(
        x=alt.X(
            'fitness',
            bin=alt.Bin(step=0.05),
            axis=alt.Axis(values=axis_values),
        ),
        y='count()',
        color=alt.Color(
            'functional_impact',
            scale=alt.Scale(
                domain=['abnormal', 'normal', 'uncertain'],
                range=['#e41a1c', '#377eb8', '#9467bd'],
            ),
        ),
        opacity=alt.value(0.6),
    )
)

# Total mixture density (solid black line).
clinical_mixture_line = (
    alt.Chart(combined_model)
    .mark_line(color='black')
    .encode(
        x=alt.X('x', title=None),
        y=alt.Y('pdf', title=None),
    )
)

# Density curves of the merged clinical classes (dashed lines).
clinical_class_lines = (
    alt.Chart(clinical_models)
    .mark_line(strokeDash=(4, 4))
    .encode(
        x=alt.X('x', title=None),
        y=alt.Y('pdf', title=None),
        color=alt.Color('component:N', scale=alt.Scale(scheme='set1')),
    )
)

# The density curves share one y scale; the histogram counts keep their own.
functional_classifications = (
    classified_histogram
    + (clinical_mixture_line + clinical_class_lines).resolve_scale(
        y='shared', color='independent'
    )
).resolve_scale(y='independent')

# Rug plot (one red tick per variant) of the nonsense ('*') variants' fitness.
nonsense_only = ne_dms_data[ne_dms_data.variant_aa == '*']
nonsense_variant_scores = (
    alt.Chart(nonsense_only, title='nonsense variants')
    .mark_tick()
    .encode(
        x=alt.X(
            'fitness',
            axis=alt.Axis(values=axis_values),
        ),
        color=alt.value('#e41a1c'),
    )
)

# Stack the three panels vertically on a shared fitness axis.
stacked_clinical_panels = (
    functional_classifications & control_variant_scores & nonsense_variant_scores
)
stacked_clinical_panels.resolve_scale(x='shared', y='independent', color='independent')

In [139]:
# Compact histogram of all variants, coloured by functional class.
functional_impact_color = alt.Color(
    'functional_impact',
    scale=alt.Scale(scheme='set1', domain=['abnormal', 'normal', 'uncertain']),
)
dms_distribution = (
    alt.Chart(ne_dms_data_gmm)
    .mark_bar()
    .encode(
        x=alt.X('fitness', bin=alt.Bin(step=0.05), axis=alt.Axis(values=axis_values)),
        y=alt.Y('count()', title='SNVs'),
        color=functional_impact_color,
    )
    .properties(height=50)
)

# Grey histogram restricted to the nonsense ('*') variants.
nonsense_rows = ne_dms_data_gmm[ne_dms_data_gmm.variant_aa == '*']
nonsense_variant_histogram = (
    alt.Chart(nonsense_rows)
    .mark_bar()
    .encode(
        x=alt.X(
            'fitness',
            bin=alt.Bin(step=0.05),
            axis=alt.Axis(values=axis_values),
        ),
        y=alt.Y('count()', title='SNVs'),
        color=alt.value('lightgray'),
    )
)

# Long-format table with one row per (variant, probability category) pair.
class_probabilities = ne_dms_data_gmm.melt(
    id_vars=['position', 'wt_aa', 'variant_aa', 'fitness'],
    value_vars=['p_abnormal', 'p_normal'],
    var_name='category',
    value_name='probability',
)

# Dashed curves of p_abnormal / p_normal as a function of fitness.
p_abnormal = (
    alt.Chart(class_probabilities)
    .mark_line(strokeDash=(4, 4))
    .encode(
        x=alt.X('fitness', axis=alt.Axis(values=axis_values)),
        y='probability',
        color=alt.Color('category', scale=alt.Scale(scheme='set1')),
    )
)

# Overlay the probability curves on the nonsense histogram; counts and
# probabilities each get their own y scale.
nonsense_with_probabilities = (
    (nonsense_variant_histogram + p_abnormal)
    .properties(height=150)
    .resolve_scale(y='independent', color='independent')
)

# Final figure: three vertically stacked panels sharing the fitness axis.
(
    dms_distribution & nonsense_with_probabilities & control_variant_scores
).resolve_scale(x='shared', color='independent')