In [3]:
import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Input tables (JSON files read by the loading cell).
f_hx = "deduplicated.json"
f_hb = "hbonds.json"

# Destination paths used by the saving cells further down.
output_param_table = "../results/param_table.json"
output_df = "../results/df_cooperativities.json"

In [None]:
# Read both input tables in one pass; the order matches (f_hx, f_hb).
df_hx, df_hbonds = (pd.read_json(json_path) for json_path in (f_hx, f_hb))

In [None]:
# Left-join the H-bond count and PF label onto df_hx by sequence;
# every row of df_hx is kept, unmatched rows get missing values.
hbond_columns = df_hbonds[["sequence", "n_hb_bb_all", "PF"]]
df = df_hx.merge(hbond_columns, on="sequence", how="left")

In [None]:
# Populate dataframe with important metrics

def _descending(values):
    """Return the entries of ``values`` as a list sorted from largest to smallest."""
    return sorted(values, reverse=True)


# Store each row's free-energy list in descending order.
df["free_energy"] = df["free_energy"].map(_descending).to_numpy()

# Count the number of exchangeable residues as the number of free-energy
# entries per row (original note: len(seq) - 2 - n_P).
df["n_exch_res"] = df["free_energy"].map(len)


def sum_over_threshold(x, threshold=2):
    """Sum the amounts by which the entries of ``x`` exceed ``threshold``.

    Entries at or below ``threshold`` contribute nothing; every other entry
    contributes ``entry - threshold``.
    """
    excess = np.array(x) - threshold
    return np.sum(excess[excess > 0])

# Row-wise sum of the entries of each free_energy list.
df["free_energy_integrated"] = df.apply(lambda row: np.sum(row["free_energy"]), axis=1)

# NOTE(review): "free_energy_integrated_0" and "free_energy_integrated_measurable"
# are not created in this notebook; presumably they come from the input
# tables -- confirm. The "free_energy_integrated" column computed just above is
# not used by the lines below.
n_exch_res = df["n_exch_res"]
n_hb_bb_all = df["n_hb_bb_all"]

df["free_energy_integrated_per_res"] = df["free_energy_integrated_0"].div(n_exch_res)
df["free_energy_integrated_measurable_per_hb"] = df["free_energy_integrated_measurable"].div(n_hb_bb_all)

# Ratios relative to the number of exchangeable residues.
df["ratio_n_hb_n_exch"] = n_hb_bb_all.div(n_exch_res)
df["ratio_n_measurable_obs_rates_n_exch"] = df["n_measurable_obs_rates"].div(n_exch_res)
df['ratio_n_measurable_intrinsic_rates_n_exch'] = df['n_measurable_intrinsic_rates'].div(n_exch_res)


# Implement 4-param model to derive cooperativities

In [None]:
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive
import jax.numpy, jax.scipy
from jax import random
import jax.numpy as jnp

import json

from sklearn.metrics import mean_squared_error

In [None]:
def summary_plots(model_name, df, label='0'):
    """Draw diagnostic plots for a fitted model.

    First shows a clustered heatmap of the correlations between the observed
    integrated free energies and the model prediction. Then shows a 2x3 figure:
    panel 0 is predicted vs. observed integrated free energy per residue, and
    panels 1-5 plot the model residual against one covariate each, with a
    LOWESS trend line and the Pearson correlation printed in the corner.

    Parameters
    ----------
    model_name : str
        Prefix of the model columns; ``<model_name>_pred`` and
        ``<model_name>_residual`` are read from ``df``.
    df : pandas.DataFrame
        Rows to plot. Besides the model columns it must contain
        ``free_energy_integrated_per_res``, ``free_energy_integrated_per_hb``,
        ``dg_mean``, ``ratio_n_hb_n_exch``,
        ``ratio_n_measurable_obs_rates_n_exch``,
        ``ratio_n_measurable_intrinsic_rates_n_exch`` and ``netcharge``.
    label : str, optional
        Only echoed in the printed header line.
    """
    print(model_name, label)

    pred_col = model_name + '_pred'
    residual_col = model_name + '_residual'
    observed_col = 'free_energy_integrated_per_res'

    def add_pearson(x, y, ax):
        # Series.corr computes the Pearson coefficient by default. It is used
        # here because scipy's pearsonr is never imported in this notebook.
        pc = df[x].corr(df[y])
        ax.text(0.01, 0.9, f"PC: {pc:.3f}", transform=ax.transAxes, size=18)

    # NOTE(review): "free_energy_integrated_per_hb" is not created in the
    # visible notebook cells; presumably it comes from the input data -- confirm.
    sns.clustermap(
        df[["free_energy_integrated_per_hb", observed_col, pred_col]].corr(),
        annot=True)

    plt.show()

    fig, ax = plt.subplots(2, 3, figsize=(15, 8), constrained_layout=True)
    ax = ax.flatten()

    # Panel 0: predicted vs. observed, with the identity line for reference.
    ax[0].scatter(df[pred_col], df[observed_col], label=model_name, alpha=0.2, s=10)
    ax[0].set_xlabel('predicted_int_fe_per_res')
    ax[0].set_ylabel('int_fe_per_res')
    ax[0].grid()
    ax[0].plot([0, 5], [0, 5], color='black')
    # Title reports the Pearson correlation and the RMSE of the prediction.
    rmse = mean_squared_error(df[observed_col], df[pred_col]) ** 0.5
    ax[0].set_title('%s corr %.3f %.3f' % (model_name,
                                      df[observed_col].corr(df[pred_col]),
                                      rmse))

    # Panels 1-5: residual vs. covariate, all drawn the same way.
    covariates = [
        'dg_mean',
        'ratio_n_hb_n_exch',
        'ratio_n_measurable_obs_rates_n_exch',
        'ratio_n_measurable_intrinsic_rates_n_exch',
        'netcharge',
    ]
    for panel, covariate in zip(ax[1:], covariates):
        sns.regplot(data=df, x=covariate, y=residual_col, label=model_name,
                    scatter_kws={'alpha': 0.1, 's': 10}, line_kws={'color': 'black'},
                    lowess=True, ax=panel)
        add_pearson(x=covariate, y=residual_col, ax=panel)
        panel.grid()
        panel.set_ylim(-1, 1)

    plt.show()
    

In [None]:
def fit_4_param(dg_mean, fxn_hb, netcharge, free_energy_integrated_per_res=None):
    """NumPyro model for the integrated free energy per residue.

    The mean of the observation is

        scaling_factor_dg * (dg_mean - offset) ** scaling_exp * fxn_hb ** hb_exp
        + scaling_factor_nc * netcharge

    and observations are Normal around that mean with standard deviation
    ``sigma``. Both exponents are sampled on the log scale and exposed as the
    deterministic sites ``scaling_exp`` and ``hb_exp``.

    When ``free_energy_integrated_per_res`` is None the ``obs`` site is not
    conditioned on data, which is how the notebook uses this model with
    ``Predictive``.
    """
    stability = jnp.array(dg_mean)
    hb_fraction = jnp.array(fxn_hb)

    # Observation noise.
    noise_sd = numpyro.sample("sigma", dist.Exponential(1))

    # Exponent on the shifted stability term (sampled in log space).
    log_dg_power = numpyro.sample("log_scaling_exp", dist.Normal(-0.3, 2))
    dg_power = numpyro.deterministic("scaling_exp", jnp.exp(log_dg_power))

    # Exponent on the H-bond fraction term (sampled in log space).
    log_hb_power = numpyro.sample("log_hb_exp", dist.Normal(-0.3, 2))
    hb_power = numpyro.deterministic("hb_exp", jnp.exp(log_hb_power))

    shift = numpyro.sample("offset", dist.Normal(0, 5))

    dg_weight = numpyro.sample("scaling_factor_dg", dist.Normal(1, 3))

    charge_weight = numpyro.sample("scaling_factor_nc", dist.Normal(-1, 3))

    structure_term = dg_weight * jnp.power(stability - shift, dg_power) * jnp.power(hb_fraction, hb_power)
    charge_term = charge_weight * netcharge
    expected = structure_term + charge_term

    numpyro.sample("obs", dist.Normal(expected, noise_sd), obs=free_energy_integrated_per_res)
    

In [None]:
# Subset used to derive cooperativities

# Row filter shared by the fitting cells: the two listed groups with dg_mean above 2.
query = "group in ['group_1: measurable unmerged','group_2: measurable merged'] & dg_mean > 2"

df_subset = (
    df
    .query(query)
    .reset_index(drop=True)
)

In [None]:
# Fit model using all data

model_name = "fit"
model = fit_4_param

# Columns passed positionally to the model, in the order of its signature.
feature_cols = ['dg_mean', 'ratio_n_hb_n_exch', 'netcharge']
target_col = 'free_energy_integrated_per_res'

# Fit on the filtered subset only.
mcmc = MCMC(NUTS(model), num_warmup=1000, num_samples=1000, num_chains=1)
mcmc.run(
    random.PRNGKey(0),
    *(df_subset[col].values for col in feature_cols),
    df_subset[target_col].values,
)
mcmc.print_summary(exclude_deterministic=False)

# Posterior-predictive draws for every row of df (not just the subset).
predictive = Predictive(model, mcmc.get_samples())
predictions = predictive(random.PRNGKey(1), *(df[col].values for col in feature_cols))
obs_draws = predictions['obs']
df[model_name + '_pred'] = obs_draws.mean(axis=0)
df[model_name + '_pred_std'] = obs_draws.std(axis=0)

df[model_name + '_residual'] = df[target_col] - df[model_name + '_pred']

# Report the mean squared error and the diagnostic plots on the filtered rows.
df_fit = df.query(query)
print(model_name, mean_squared_error(df_fit[target_col], df_fit[model_name + '_pred']))

summary_plots('fit', df_fit)

In [None]:
models = ['fit']

# Posterior means stored per fit, in the column order of the output table.
param_names = ['sigma', 'offset', 'scaling_exp', 'scaling_factor_dg', 'scaling_factor_nc', 'hb_exp']
input_cols = ['dg_mean', 'ratio_n_hb_n_exch', 'netcharge']
observed_col = 'free_energy_integrated_per_res'

param_rows = []
for model_name in models:
    print (model_name)
    # One fit per PF value, plus a final fit on all filtered rows ('all').
    for pf in [*df['PF'].unique(), 'all']:
        print(pf)
        subdf = df.query(query) if pf == 'all' else df.query(f"PF == @pf & {query}")
        # Skip groups with fewer than 50 rows.
        if len(subdf) < 50:
            continue

        model = fit_4_param
        inputs = [subdf[col].values for col in input_cols]
        observed = subdf[observed_col]

        mcmc = MCMC(NUTS(model), num_warmup=1000, num_samples=1000, num_chains=1)
        mcmc.run(random.PRNGKey(0), *inputs, observed.values)
        samples = mcmc.get_samples()

        # Posterior means of the parameters.
        row = [model_name, pf]
        for name in param_names:
            row.append(float(samples[name].mean()))

        # Predictions from this group's own fit vs. the earlier whole-data fit.
        predictive = Predictive(model, samples)
        predictions = predictive(random.PRNGKey(1), *inputs)['obs'].mean(axis=0)
        global_pred = subdf[model_name + '_pred']

        row.append(np.corrcoef(observed, predictions)[0, 1])
        row.append(np.corrcoef(observed, global_pred)[0, 1])
        row.append(mean_squared_error(observed, predictions) ** 0.5)
        row.append(mean_squared_error(observed, global_pred) ** 0.5)
        param_rows.append(row)

param_table = pd.DataFrame(
    param_rows,
    columns=['model', 'PF', *param_names, 'pf_corr', 'global_corr', 'pf_rmse', 'global_rmse'],
)

# Make sure the output directory exists, then write the table.
os.makedirs(os.path.dirname(output_param_table), exist_ok=True)

param_table.to_json(output_param_table)

# Use param_table to derive cooperativity

In [None]:
def _predict_per_res(row, terms):
    """Evaluate the fitted model for one dataframe row with one parameter set."""
    dg_term = jnp.power(row['dg_mean'] - terms['offset'], terms['scaling_exp'])
    hb_term = jnp.power(row['ratio_n_hb_n_exch'], terms['hb_exp'])
    return float(
        terms['scaling_factor_dg'] * dg_term * hb_term
        + terms['scaling_factor_nc'] * row['netcharge']
    )


def apply_param_table_to_filtered_data(df_, param_table, query="group in ['group_1: measurable unmerged','group_2: measurable merged'] & dg_mean > 2"):
    """Add model predictions and cooperativity columns to ``df_``.

    For every row two predictions are computed from the parameters stored in
    ``param_table``: one with the parameters of the row's own PF (falling back
    to the ``'all'`` parameters when the PF has no row in ``param_table``) and
    one with the ``'all'`` parameters. The cooperativity columns are the
    observed ``free_energy_integrated_per_res`` minus each prediction.

    Parameters
    ----------
    df_ : pandas.DataFrame
        Must contain ``PF``, ``dg_mean``, ``ratio_n_hb_n_exch``, ``netcharge``
        and ``free_energy_integrated_per_res``. Modified in place.
    param_table : pandas.DataFrame
        One row per PF (including ``'all'``) with columns ``model``, ``PF``,
        ``offset``, ``scaling_exp``, ``scaling_factor_dg``,
        ``scaling_factor_nc`` and ``hb_exp``.
    query : str, optional
        Kept for backward compatibility; it does not affect the result.

    Returns
    -------
    pandas.DataFrame
        The same ``df_`` object with ``<model>_pred_pf``, ``<model>_pred``,
        ``cooperativity_model_pf`` and ``cooperativity_model_global`` added.

    Raises
    ------
    KeyError
        If ``param_table`` has no ``'all'`` row.
    """
    model = param_table.model.iloc[0]

    term_names = ['offset', 'scaling_exp', 'scaling_factor_dg', 'scaling_factor_nc', 'hb_exp']

    # Parameter set per PF; the first matching row of param_table wins.
    fit_model_terms = {}
    for pf in param_table['PF'].unique():
        params = param_table.loc[param_table['PF'] == pf].iloc[0]
        fit_model_terms[pf] = {name: params[name] for name in term_names}

    if 'all' not in fit_model_terms:
        raise KeyError("param_table has no row with PF == 'all'; it is required for the global prediction")
    all_terms = fit_model_terms['all']

    # PFs without their own parameters (including missing PF labels) use 'all'.
    df_[f'{model}_pred_pf'] = df_.apply(
        lambda row: _predict_per_res(row, fit_model_terms.get(row['PF'], all_terms)),
        axis=1,
    )
    df_[f'{model}_pred'] = df_.apply(lambda row: _predict_per_res(row, all_terms), axis=1)

    # Cooperativity = observed minus predicted.
    observed = df_['free_energy_integrated_per_res']
    df_['cooperativity_model_pf'] = observed - df_[f'{model}_pred_pf']
    df_['cooperativity_model_global'] = observed - df_[f'{model}_pred']

    return df_

In [None]:
# Add the per-PF and global model predictions to df, together with the
# cooperativity columns (observed free energy per residue minus prediction).
df = apply_param_table_to_filtered_data(df, param_table)

In [None]:
# Derive normalized cooperativities dictionary

combined_dict_file = '../results/cooperativity_std_mean_dict.json'

if os.path.isfile(combined_dict_file):
    # Reuse the statistics stored by a previous run instead of recomputing them.
    with open(combined_dict_file, 'r') as json_file:
        combined_dict = json.load(json_file)
else:
    filtered = df.query(query)

    # Per-PF standard deviation and mean of the PF-specific cooperativity.
    pf_groups = filtered.groupby('PF')['cooperativity_model_pf']
    std_pf = pf_groups.std().to_dict()
    mean_pf = pf_groups.mean().to_dict()

    # One {'std', 'mean'} entry per PF ...
    combined_dict = label_dict = {
        pf: {'std': std_pf[pf], 'mean': mean_pf[pf]}
        for pf in std_pf
    }

    # ... plus a 'global' entry computed from the global-model cooperativity.
    global_std = filtered['cooperativity_model_global'].std()
    global_mean = filtered['cooperativity_model_global'].mean()
    combined_dict['global'] = {'std': global_std, 'mean': global_mean}

    # Persist the statistics so later runs take the branch above.
    with open(combined_dict_file, 'w') as json_file:
        json.dump(combined_dict, json_file, indent=4)  # indent=4 for pretty printing

In [None]:
# Normalize 'cooperativity_model_pf' with the statistics of each row's PF
pf_column = 'cooperativity_model_pf'
normalized_pf_column = 'normalized_cooperativity_model_pf'


def _normalize_within_pf(row):
    """Z-score the row's PF cooperativity, or return it unchanged.

    The raw value is returned when the row's PF has no entry in
    ``combined_dict`` or when that entry's std is 0.
    """
    pf_stats = combined_dict.get(row['PF'])
    if pf_stats is None or pf_stats['std'] == 0:
        return row[pf_column]
    return (row[pf_column] - pf_stats['mean']) / pf_stats['std']


df[normalized_pf_column] = df.apply(_normalize_within_pf, axis=1)


# Normalize 'cooperativity_model_global' with the single 'global' entry
global_column = 'cooperativity_model_global'
normalized_global_column = 'normalized_cooperativity_model_global'

global_mean = combined_dict['global']['mean']
global_std = combined_dict['global']['std']

if global_std != 0:
    df[normalized_global_column] = (df[global_column] - global_mean) / global_std
else:
    # A zero std leaves the values unnormalized.
    df[normalized_global_column] = df[global_column]


In [None]:
# Apply final quality filter to remove outliers

# Same group / dg_mean filter as before, plus a bound on the normalized
# cooperativities: a row is kept when at least one of the two lies in (-3, 3).
# Note that this rebinds `query` and replaces `df` with the filtered rows.
query = (
    "group in ['group_1: measurable unmerged','group_2: measurable merged'] & dg_mean > 2"
    " & ((-3 < normalized_cooperativity_model_global < 3) | (-3 < normalized_cooperativity_model_pf < 3))"
)

df = (
    df
    .query(query)
    .reset_index(drop=True)
)

In [None]:
# Save final dataframe with cooperativities

# Create the output directory if needed (no error when it already exists).
os.makedirs(os.path.dirname(output_df), exist_ok=True)

# Write the final dataframe itself; param_table has its own output file
# (output_param_table) and must not be written to output_df.
df.to_json(output_df)