In [1]:
import sqlite3
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import matplotlib as mpl

# Global matplotlib defaults for every figure in this notebook:
# 8 pt text in the sans-serif family, with Helvetica as the sans-serif face.
# fonttype 42 avoids Type 3 fonts in PDF output (same setting applied to PS).
mpl.rcParams.update({
    'font.size': 8,
    'font.family': 'sans-serif',
    'font.sans-serif': ['Helvetica'],
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
})

In [2]:
# Location of the SQLite database, relative to the notebook's working directory.
nerd_sqlite = '../../Core_nerd_analysis/nerd.sqlite'
rg_ids = [123, 124, 129, 130]  # update if needed

# One "?" per rg_id so the IN (...) list is bound as parameters, not pasted as text.
placeholders = ",".join(["?"] * len(rg_ids))
query = """
SELECT *
FROM probe_tc_fits_view
WHERE fit_kind = 'round3_constrained'
  AND rg_id IN ({})
  AND rt_protocol = 'MRT'
ORDER BY rg_id
""".format(placeholders)

conn = sqlite3.connect(nerd_sqlite)
try:
    conn.row_factory = sqlite3.Row
    df = pd.read_sql_query(query, conn, params=rg_ids)
finally:
    # Close the handle even when the query raises, so a failed run does not
    # leave the database connection open in the kernel.
    conn.close()
df

OperationalError: unable to open database file

In [None]:
import numpy as np

def calculate_dG(logkobs, logkadd, T=298.15):
    """Convert a log rate pair into a free energy in kcal/mol.

    ``exp(logkobs - logkadd)`` is interpreted as ``K / (K + 1)``; that
    relation is solved for ``K`` and ``-R * T * ln(K)`` is returned.

    Parameters
    ----------
    logkobs : float, numpy array or pandas Series
        Natural log of the observed rate (it is exponentiated with ``np.exp``).
    logkadd : float, numpy array or pandas Series
        Natural log of the reference rate that ``logkobs`` is compared against.
    T : float, optional
        Temperature in kelvin. Defaults to 298.15, the value that was
        previously fixed inside the function.

    Returns
    -------
    Same shape as the broadcast inputs
        dG in kcal/mol. No exception is raised for out-of-range inputs:
        ``logkobs == logkadd`` makes the ratio exactly 1, so ``K`` is infinite
        and the result is ``-inf``; ``logkobs > logkadd`` makes ``K`` negative
        and the result is ``NaN``.
    """
    R = 1.9872036e-3  # gas constant, kcal/(mol*K)

    logKKp1 = logkobs - logkadd
    KKp1 = np.exp(logKKp1)  # K / (K + 1)
    K = KKp1 / (1 - KKp1)
    dG = -R * T * np.log(K)

    return dG

# Reference log rate and three example log k_obs values (named after the r2 of their fits)
logkadd = -2.676278119425211 # logkadd
logkobs_r2_99 = -3.4806204679354953
logkobs_r2_55 = -7.240355294404981
logkobs_r2_39 = -7.503658411920693

# Same conversion applied to each example
dG_r2_99, dG_r2_55, dG_r2_39 = (
    calculate_dG(example_logkobs, logkadd)
    for example_logkobs in (logkobs_r2_99, logkobs_r2_55, logkobs_r2_39)
)

for r2_label, dG_value in (("0.99", dG_r2_99), ("0.55", dG_r2_55), ("0.39", dG_r2_39)):
    print(f"dG for r2={r2_label}: {dG_value:.2f} kcal/mol")

In [None]:
# R^2 versus log(k_obs) for C nucleotides, with an inset drawing the same points
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

df_c = df.loc[df['nt_base'] == 'C']
c_scatter = dict(data=df_c, x='log_kobs', y='r2')

fig, ax = plt.subplots(figsize=(4, 4))
sns.scatterplot(ax=ax, **c_scatter)

# Inset sized as a percentage of the parent axes; the first bbox_to_anchor
# value shifts it to the right of the upper-left corner.
inset_ax = inset_axes(
    ax,
    width="40%",
    height="60%",
    loc='upper left',
    bbox_to_anchor=(0.2, -0.05, 1, 1),
    bbox_transform=ax.transAxes,
    borderpad=0,
)
sns.scatterplot(ax=inset_ax, **c_scatter)

# Inset: x restricted to [-10, 0], tick labels at size 10, no axis labels
inset_ax.set_xlim(-10, 0)
inset_ax.tick_params(axis='both', labelsize=10)
inset_ax.set(xlabel='', ylabel='')

# Main axes labels
ax.set(xlabel=r'$\log(k_{obs})$', ylabel=r'$R^2$')

plt.tight_layout()
#plt.savefig("exports/r2_v_log_kobs_p4p6_C.pdf")
plt.show()

In [None]:
# R^2 versus log(k_obs) for U nucleotides
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

# Named df_u because it holds the U rows; df_c keeps the C rows selected for the previous figure.
df_u = df[df['nt_base'] == 'U']

fig, ax = plt.subplots(figsize=(4, 4))

# Main scatter plot
sns.scatterplot(data=df_u, x='log_kobs', y='r2', ax=ax)

# set main axes labels
ax.set_xlabel(r'$\log(k_{obs})$')
ax.set_ylabel(r'$R^2$')

# pin the R^2 axis to [0, 1]
ax.set_ylim(0, 1)

plt.tight_layout()
#plt.savefig("exports/r2_v_log_kobs_p4p6_U.pdf")
plt.show()

In [None]:
# Largest log_kobs observed for each nucleotide identity, keyed by nt_base
max_logkobs = {
    base: df.loc[df['nt_base'] == base, 'log_kobs'].max()
    for base in df['nt_base'].unique()
}
max_logkobs

def calc_dG(logkobs, logkadd, T=298.15):
    """Convert a log rate pair into a free energy in kcal/mol.

    Same computation as ``calculate_dG`` defined earlier in this notebook:
    ``exp(logkobs - logkadd)`` is interpreted as ``K / (K + 1)``, solved for
    ``K``, and ``-R * T * ln(K)`` is returned.

    Parameters
    ----------
    logkobs : float, numpy array or pandas Series
        Natural log of the observed rate (it is exponentiated with ``np.exp``).
    logkadd : float, numpy array or pandas Series
        Natural log of the reference rate that ``logkobs`` is compared against.
    T : float, optional
        Temperature in kelvin. Defaults to 298.15, the value that was
        previously fixed inside the function.

    Returns
    -------
    Same shape as the broadcast inputs
        dG in kcal/mol. ``logkobs == logkadd`` gives a ratio of exactly 1, an
        infinite ``K`` and therefore ``-inf``; ``logkobs > logkadd`` gives a
        negative ``K`` and therefore ``NaN``. Neither case raises.
    """
    R = 1.9872036e-3  # gas constant, kcal/(mol*K)

    logKKp1 = logkobs - logkadd
    KKp1 = np.exp(logKKp1)  # K / (K + 1)
    K = KKp1 / (1 - KKp1)
    dG = -R * T * np.log(K)

    return dG


# dG per row, using the largest log_kobs of the row's own nt_base as logkadd.
# calc_dG works element-wise, so the whole column is evaluated in one call
# instead of one Python-level call per row.
# The row that holds its base's maximum has a ratio of exactly 1 and gets dG = -inf.
df['dG'] = calc_dG(df['log_kobs'], df['nt_base'].map(max_logkobs))
df

In [None]:
## Analysis of correlation

def keep_higher_r2(df):
    """Keep, for every ``nt_site``, only the row with the highest ``r2``.

    Works for any number of rows per site (previously only sites with exactly
    two rows were reduced; sites with three or more were left untouched).

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns ``nt_site`` and ``r2``.

    Returns
    -------
    pandas.DataFrame
        The surviving rows of ``df`` in their original order, with their
        original index labels.

    Notes
    -----
    * Ties are resolved in favour of the later row, which is what the
      two-row comparison did before.
    * A missing ``r2`` loses against any real value.
    * Rows whose ``nt_site`` is missing are never treated as duplicates of
      each other and are all kept.
    """
    ranked = (
        df[['nt_site', 'r2']]
        .reset_index(drop=True)  # index now equals the row position in df
        # Missing r2 ranks below every real value so it never beats an actual fit.
        .assign(r2=lambda frame: frame['r2'].fillna(float('-inf')))
        # Stable ascending sort: the best row of each site ends up last,
        # and equal r2 values keep their original relative order.
        .sort_values('r2', kind='stable')
    )
    is_loser = ranked['nt_site'].duplicated(keep='last') & ranked['nt_site'].notna()
    keep_positions = sorted(ranked.index[(~is_loser).to_numpy()])
    return df.iloc[keep_positions]


# Minimum R^2 a fit must exceed in BOTH buffers to enter the comparison
R2_MIN = 0.3

# One row per site for each buffer. buffer_id 2 feeds the "_nomg" columns and
# buffer_id 3 the "_mg" columns (labelled "No Mg" / "5 mM Mg" in later plots).
df_nomg = keep_higher_r2(df[df['buffer_id'] == 2])
df_mg = keep_higher_r2(df[df['buffer_id'] == 3])

# Pair the two buffers by site. nt_base is taken from the no-Mg side only, so it
# stays unsuffixed; the measurement columns receive the _nomg / _mg suffixes.
df_merged = pd.merge(
    df_nomg[['nt_site', 'nt_base', 'log_kobs', 'log_kobs_err', 'dG', 'r2']],
    df_mg[['nt_site', 'log_kobs', 'log_kobs_err', 'dG', 'r2']],
    on=['nt_site'],
    suffixes=('_nomg', '_mg'),
)

# A and C only, with both fits above R2_MIN.
# .copy() makes df_merged_AC an independent frame: later cells add columns to
# it, which would otherwise be writes into a filtered slice of df_merged.
is_AC = df_merged['nt_base'].isin(['A', 'C'])
passes_r2 = (df_merged['r2_nomg'] > R2_MIN) & (df_merged['r2_mg'] > R2_MIN)
df_merged_AC = df_merged[is_AC & passes_r2].copy()

fig, axs = plt.subplots(1, 2, figsize=(10, 5))
sns.scatterplot(data=df_merged_AC, x='log_kobs_nomg', y='log_kobs_mg', ax=axs[0])
sns.scatterplot(data=df_merged_AC, x='dG_nomg', y='dG_mg', ax=axs[1])
plt.show()

In [None]:
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from scipy import stats

# Least-squares line: log_kobs with Mg (y) against log_kobs without Mg (x)
fit = stats.linregress(
    df_merged_AC['log_kobs_nomg'],
    df_merged_AC['log_kobs_mg']
)
slope, intercept, r_value, p_value, std_err = (
    fit.slope, fit.intercept, fit.rvalue, fit.pvalue, fit.stderr
)

# Fitted values at the observed x positions
x = df_merged_AC['log_kobs_nomg']
y_pred = intercept + slope * x

# Scatter of the observations around the line (σ = std of the residuals)
residuals = df_merged_AC['log_kobs_mg'] - y_pred
sigma = residuals.std()

# Squared correlation coefficient of the fit
r2 = r_value**2

slope, intercept, sigma

In [None]:
fig, ax = plt.subplots(figsize=(5, 5))
sns.scatterplot(data=df_merged_AC, x='log_kobs_nomg', y='log_kobs_mg', ax=ax)

# Fitted line sampled on 200 points across the observed x range
x_line = np.linspace(x.min(), x.max(), 200)
y_line = intercept + slope * x_line
ax.plot(x_line, y_line, color='black', label='Linear fit')

# Nested gray bands at ±1σ, ±2σ and ±3σ around the line; wider bands are fainter
for n_sigma, band_alpha in ((1, 0.2), (2, 0.1), (3, 0.05)):
    half_width = n_sigma * sigma
    ax.fill_between(
        x_line,
        y_line - half_width,
        y_line + half_width,
        color='gray',
        alpha=band_alpha,
        label=f'±{n_sigma}σ',
    )

ax.set_xlabel(r'$ln(k_{\text{obs}})$ (No Mg)')
ax.set_ylabel(r'$ln(k_{\text{obs}})$ (5 mM Mg)')

# Fit summary in the upper-left corner (axes coordinates)
fit_summary = f'Slope: {slope:.2f}\nIntercept: {intercept:.2f}\n$R^2$: {r2:.2f}'
ax.text(0.05, 0.95, fit_summary, transform=ax.transAxes, verticalalignment='top')

ax.legend(loc='lower right')
plt.tight_layout()
plt.savefig("exports/log_kobs_nomg_v_log_kobs_mg_p4p6_AC.pdf")
plt.show()

In [None]:
# Signed and absolute distance of every point from the fitted line
df_merged_AC['residual'] = df_merged_AC['log_kobs_mg'] - y_pred
df_merged_AC['abs_residual'] = df_merged_AC['residual'].abs()

# Bin by absolute residual: <= 1σ, else <= 2σ, else outside 2σ
abs_res = df_merged_AC['abs_residual']
df_merged_AC['sigma_class'] = np.select(
    [abs_res <= sigma, abs_res <= 2*sigma],
    ['within_1sigma', 'within_2sigma'],
    default='outside_2sigma',
)
df_merged_AC.to_csv("exports/log_kobs_nomg_v_log_kobs_mg_p4p6_AC_with_residuals.csv", index=False)

In [None]:
# Rows of the log_kobs comparison classified 'within_1sigma'
in_1sigma_class = df_merged_AC['sigma_class'].eq('within_1sigma')
df_within_1sigma = df_merged_AC.loc[in_1sigma_class]
df_within_1sigma

In [None]:
# Rows classified 'within_2sigma' (absolute residual above 1σ but not above 2σ)
in_2sigma_class = df_merged_AC['sigma_class'].eq('within_2sigma')
df_within_2sigma = df_merged_AC.loc[in_2sigma_class]
df_within_2sigma

In [None]:
# Rows classified 'outside_2sigma'
outside_2sigma_class = df_merged_AC['sigma_class'].eq('outside_2sigma')
df_outside_2sigma = df_merged_AC.loc[outside_2sigma_class]
df_outside_2sigma

In [None]:
# dG version

# Keep rows whose dG is finite in both buffers (np.isfinite is False for NaN
# and for ±inf, so no separate dropna is needed).
# .copy() makes this an independent frame so later cells can add columns to it.
finite_dG = np.isfinite(df_merged_AC['dG_nomg']) & np.isfinite(df_merged_AC['dG_mg'])
df_merged_AC_dG = df_merged_AC[finite_dG].copy()

# Fit linear regression
slope, intercept, r_value, p_value, std_err = stats.linregress(
    df_merged_AC_dG['dG_nomg'],
    df_merged_AC_dG['dG_mg']
)

# Refresh sigma and r2 for THIS fit. Without this, both names would still hold
# the values computed for the log_kobs fit, and every later use of them
# (bands, annotation, sigma classes) would describe the wrong regression.
residuals = df_merged_AC_dG['dG_mg'] - (intercept + slope * df_merged_AC_dG['dG_nomg'])
sigma = residuals.std()
r2 = r_value**2

slope, intercept, sigma

In [None]:
fig, ax = plt.subplots(figsize=(5, 5))
sns.scatterplot(
    data=df_merged_AC_dG,
    x='dG_nomg',
    y='dG_mg',
    ax=ax
)

# Residual spread and R^2 of the line drawn here, derived from the plotted
# data and the current slope / intercept / r_value, so the bands and the
# annotation describe this figure rather than whatever `sigma` / `r2`
# were last bound to.
dG_residuals = df_merged_AC_dG['dG_mg'] - (intercept + slope * df_merged_AC_dG['dG_nomg'])
dG_sigma = dG_residuals.std()
dG_r2 = r_value**2

# Regression line
x_line = np.linspace(df_merged_AC_dG['dG_nomg'].min(), df_merged_AC_dG['dG_nomg'].max(), 200)
y_line = intercept + slope * x_line
ax.plot(x_line, y_line, color='black', label='Linear fit')

# Nested gray bands at ±1σ, ±2σ and ±3σ around the line; wider bands are fainter
for n_sigma, band_alpha in ((1, 0.2), (2, 0.1), (3, 0.05)):
    ax.fill_between(
        x_line,
        y_line - n_sigma * dG_sigma,
        y_line + n_sigma * dG_sigma,
        color='gray',
        alpha=band_alpha,
        label=f'±{n_sigma}σ'
    )

# The plotted quantities are dG values, so the axes are labelled as ΔG
ax.set_xlabel(r'$\Delta G_{\text{No Mg}}$ (kcal/mol)')
ax.set_ylabel(r'$\Delta G_{\text{5 mM Mg}}$ (kcal/mol)')

# annotate slope, intercept, r2
ax.text(
    0.05, 0.95,
    f'Slope: {slope:.2f}\nIntercept: {intercept:.2f}\n$R^2$: {dG_r2:.2f}',
    transform=ax.transAxes,
    verticalalignment='top'
    )

# legend bottom right
ax.legend(loc='lower right')
plt.tight_layout()
plt.savefig("exports/dG_nomg_v_dG_mg_p4p6_AC.pdf")
plt.show()

In [None]:
# Fitted dG (Mg) at each observed dG (no Mg)
y_pred = intercept + slope * df_merged_AC_dG['dG_nomg']

# Signed and absolute distance of every point from the fitted line
df_merged_AC_dG['residual'] = df_merged_AC_dG['dG_mg'] - y_pred
df_merged_AC_dG['abs_residual'] = df_merged_AC_dG['residual'].abs()

# Bin by absolute residual: <= 1σ, else <= 2σ, else outside 2σ
abs_res_dG = df_merged_AC_dG['abs_residual']
df_merged_AC_dG['sigma_class'] = np.select(
    [abs_res_dG <= sigma, abs_res_dG <= 2*sigma],
    ['within_1sigma', 'within_2sigma'],
    default='outside_2sigma',
)
df_merged_AC_dG.to_csv("exports/dG_nomg_v_dG_mg_p4p6_AC_with_residuals.csv", index=False)

In [None]:
# Histogram of the signed dG residuals, using pandas' default hist settings
df_merged_AC_dG['residual'].hist()

In [None]:
# filter to within 1 sigma
df_within_1sigma = df_merged_AC_dG[df_merged_AC_dG['sigma_class'] == 'within_1sigma']

# linear regression on within 1 sigma
slope, intercept, r_value, p_value, std_err = stats.linregress(
    df_within_1sigma['dG_nomg'],
    df_within_1sigma['dG_mg']
)

# Predicted y-values for ALL finite-dG points under the refitted line.
# The fit is in dG space, so the prediction and residuals must use the dG
# columns as well (the log_kobs columns belong to a different regression).
x = df_merged_AC_dG['dG_nomg']
y_pred = intercept + slope * x

# Residuals
residuals = df_merged_AC_dG['dG_mg'] - y_pred

# Standard deviation of residuals (σ)
sigma = residuals.std()

# R2 value (of the fit to the within-1σ subset)
r2 = r_value**2

slope, intercept, sigma

In [None]:
# Display the rows used for the refit (sigma_class == 'within_1sigma')
df_within_1sigma

In [None]:
# Site, base and the two dG values for every row of the dG comparison
df_merged_AC_dG.loc[:, ['nt_site', 'nt_base', 'dG_nomg', 'dG_mg']]

In [None]:
# dG rows classified 'within_2sigma' (absolute residual above 1σ but not above 2σ)
in_2sigma_class = df_merged_AC_dG['sigma_class'].eq('within_2sigma')
df_within_2sigma = df_merged_AC_dG.loc[in_2sigma_class]
df_within_2sigma[['nt_site', 'nt_base', 'dG_nomg', 'dG_mg']]

In [None]:
# dG rows classified 'outside_2sigma'
outside_2sigma_class = df_merged_AC_dG['sigma_class'].eq('outside_2sigma')
df_outside_2sigma = df_merged_AC_dG.loc[outside_2sigma_class]
df_outside_2sigma[['nt_site', 'nt_base', 'dG_nomg', 'dG_mg']]

In [None]:
fig, ax = plt.subplots(figsize=(5, 5))
sns.scatterplot(data=df_merged_AC_dG, x='dG_nomg', y='dG_mg', ax=ax)

# Current fit line sampled on 200 points across the observed dG (no Mg) range
dG_nomg_values = df_merged_AC_dG['dG_nomg']
x_line = np.linspace(dG_nomg_values.min(), dG_nomg_values.max(), 200)
y_line = intercept + slope * x_line
ax.plot(x_line, y_line, color='black', label='Linear fit')

# Nested gray bands at ±1σ, ±2σ and ±3σ around the line; wider bands are fainter
for n_sigma, band_alpha in ((1, 0.2), (2, 0.1), (3, 0.05)):
    half_width = n_sigma * sigma
    ax.fill_between(
        x_line,
        y_line - half_width,
        y_line + half_width,
        color='gray',
        alpha=band_alpha,
        label=f'±{n_sigma}σ',
    )

ax.set_xlabel(r'$\Delta G_{\text{No Mg}}$ (kcal/mol)')
ax.set_ylabel(r'$\Delta G_{\text{5 mM Mg}}$ (kcal/mol)')

# Fit summary in the upper-left corner (axes coordinates)
fit_summary = f'Slope: {slope:.2f}\nIntercept: {intercept:.2f}\n$R^2$: {r2:.2f}'
ax.text(0.05, 0.95, fit_summary, transform=ax.transAxes, verticalalignment='top')

ax.legend(loc='lower right')
plt.tight_layout()
plt.savefig("exports/dG_nomg_v_dG_mg_p4p6_AC_1sigma.pdf")
plt.show()

In [None]:
# Site, base and the two dG values for every row of the dG comparison
df_merged_AC_dG.loc[:, ['nt_site', 'nt_base', 'dG_nomg', 'dG_mg']]

In [None]:
# Load the residual table that carries a `canonical` annotation column
# (the classifier itself is fitted in a later cell).
# NOTE(review): no visible cell writes this *_canonical_annot.csv — presumably it is
# exports/dG_nomg_v_dG_mg_p4p6_AC_with_residuals.csv annotated by hand; confirm.

df_canonical_annot = pd.read_csv("exports/dG_nomg_v_dG_mg_p4p6_AC_with_residuals_canonical_annot.csv")
df_canonical_annot

In [None]:
from sklearn.tree import DecisionTreeClassifier

# Depth-1 decision tree (a single split) predicting `canonical` from abs_residual
tree_features = df_canonical_annot[['abs_residual']]
tree_labels = df_canonical_annot['canonical']
clf = DecisionTreeClassifier(max_depth=1).fit(tree_features, tree_labels)

# Split value stored at the root node of the fitted tree
threshold = clf.tree_.threshold[0]
threshold

In [None]:
# Sum of the `canonical` column over the number of annotated rows
# (the fraction of canonical rows, assuming a 0/1 flag)
n_annotated_rows = len(df_canonical_annot)
df_canonical_annot['canonical'].sum() / n_annotated_rows

In [None]:
# Sum of `canonical` over the rows with abs_residual <= threshold, divided by 54.
# NOTE(review): 54 is hard-coded and nothing in this notebook shows where it comes
# from — presumably a row count copied from an earlier run (total canonical rows?
# rows under the threshold?). Confirm and derive it from the data instead.
df_canonical_annot[df_canonical_annot['abs_residual'] <= threshold]['canonical'].sum() / 54

In [None]:
# Rows annotated canonical == 1
is_canonical = df_canonical_annot['canonical'].eq(1)
df_canonical = df_canonical_annot.loc[is_canonical]

In [None]:
# Rows whose absolute residual does not exceed the tree's split value
under_threshold = df_canonical_annot['abs_residual'].le(threshold)
df_threshold = df_canonical_annot.loc[under_threshold]
df_threshold

In [None]:
# Display the current residual σ (whatever the most recently run fit cell assigned)
sigma

In [None]:
fig, ax = plt.subplots(figsize=(5, 5))
sns.scatterplot(data=df_merged_AC_dG, x='dG_nomg', y='dG_mg', ax=ax)

# Overlay the rows annotated canonical == 1 in red
sns.scatterplot(
    data=df_canonical,
    x='dG_nomg',
    y='dG_mg',
    ax=ax,
    color='red',
    label='Canonical WCF',
)

# Current fit line sampled on 200 points across the observed dG (no Mg) range
dG_nomg_values = df_merged_AC_dG['dG_nomg']
x_line = np.linspace(dG_nomg_values.min(), dG_nomg_values.max(), 200)
y_line = intercept + slope * x_line
ax.plot(x_line, y_line, color='black', label='Linear fit')

# Bands around the line, drawn in this order: the decision-tree threshold in
# red, then ±1σ, ±2σ and ±3σ in progressively fainter gray.
bands = (
    (threshold, 'red', 0.1, 'calc. threshold'),
    (sigma, 'gray', 0.2, '±1σ'),
    (2*sigma, 'gray', 0.1, '±2σ'),
    (3*sigma, 'gray', 0.05, '±3σ'),
)
for half_width, band_color, band_alpha, band_label in bands:
    ax.fill_between(
        x_line,
        y_line - half_width,
        y_line + half_width,
        color=band_color,
        alpha=band_alpha,
        label=band_label,
    )

ax.set_xlabel(r'$\Delta G_{\text{No Mg}}$ (kcal/mol)')
ax.set_ylabel(r'$\Delta G_{\text{5 mM Mg}}$ (kcal/mol)')

# Fit summary in the upper-left corner (axes coordinates)
fit_summary = f'Slope: {slope:.2f}\nIntercept: {intercept:.2f}\n$R^2$: {r2:.2f}'
ax.text(0.05, 0.95, fit_summary, transform=ax.transAxes, verticalalignment='top')

ax.legend(loc='lower right')
plt.tight_layout()
plt.savefig("exports/dG_nomg_v_dG_mg_p4p6_AC_1sigma_calc.pdf")
plt.show()

In [None]:
# NOTE(review): this reads the same file that was loaded into df_canonical_annot,
# and no visible cell uses fourU_correlation afterwards — confirm whether a
# different CSV was intended or whether this cell can be removed.
fourU_correlation = pd.read_csv("exports/dG_nomg_v_dG_mg_p4p6_AC_with_residuals_canonical_annot.csv")


In [None]:
# Column holding the variant dG that is compared against WT
other_var = 'dG_a35g'

# Site label, WT dG and the chosen variant's dG; rows missing any of them are
# dropped, then ddG = variant - WT is added.
hiv_data = (
    pd.read_csv("exports/hiv_calc_dG_manual_mean.csv")
    .loc[:, ['site_nt', 'dG_wt', other_var]]
    .dropna()
    .assign(ddG=lambda frame: frame[other_var] - frame['dG_wt'])
)

# Least-squares line: variant dG (y) against WT dG (x)
fit = stats.linregress(hiv_data['dG_wt'], hiv_data[other_var])
slope, intercept, r_value, p_value, std_err = (
    fit.slope, fit.intercept, fit.rvalue, fit.pvalue, fit.stderr
)

# Fitted values, residuals and their standard deviation (σ)
x = hiv_data['dG_wt']
y_pred = intercept + slope * x
residuals = hiv_data[other_var] - y_pred
sigma = residuals.std()

# Squared correlation coefficient of the fit
r2 = r_value**2

# Fixed half-width of the second band drawn in the next figure
sigma2 = 0.5

In [None]:
fig, ax = plt.subplots(figsize=(5, 5))
sns.scatterplot(
    data=hiv_data,
    x='dG_wt',
    y=other_var,
    ax=ax
)

# Regression line
x_line = np.linspace(hiv_data['dG_wt'].min(), hiv_data['dG_wt'].max(), 200)
y_line = intercept + slope * x_line
ax.plot(x_line, y_line, color='black', label='True fit')

# Two gray bands around the line: ±1σ of the residuals, and the fixed ±sigma2
for half_width, band_alpha, band_label in ((sigma, 0.2, '±1σ'), (sigma2, 0.1, '±0.5 kcal/mol')):
    ax.fill_between(
        x_line,
        y_line - half_width,
        y_line + half_width,
        color='gray',
        alpha=band_alpha,
        label=band_label
    )

# annotate each scatter point with site_nt
for wt_dG, variant_dG, site_label in zip(hiv_data['dG_wt'], hiv_data[other_var], hiv_data['site_nt']):
    ax.text(wt_dG, variant_dG, str(site_label), fontsize=10)

ax.set_xlabel(r'$\Delta G_{\text{WT}}$ (kcal/mol)')
ax.set_ylabel(r'$\Delta G_{\text{Mutant}}$ (kcal/mol)')

# Title per variant column. An unlisted column falls back to a generic title
# instead of leaving `title` unbound (or silently reusing a previous run's value).
variant_titles = {
    'dG_a35g': 'WT vs. A35G',
    'dG_c30u': 'WT vs. C30U',
    'dG_gs': 'WT vs. UUCG (GS)',
    'dG_es2': 'WT vs. UUCG (ES2)',
}
title = variant_titles.get(other_var, f'WT vs. {other_var}')
ax.set_title(title)

ax.legend()
# save
plt.tight_layout()
#plt.savefig(f"exports/dG_wt_v_{other_var}_hiv.pdf")

In [None]:
# Named groups of site numbers; sites listed nowhere are labelled 'Other' downstream
regions = dict(
    lower_stem=[19, 20, 41, 44, 45],
    upper_stem=[27, 29, 37, 39],
    tribulge=[22, 24],
    apical_loop=[30, 33, 35],
)

In [None]:
# swarmplot of ddG values (v1): points coloured by mutation
hiv_data = pd.read_csv("exports/hiv_calc_dG_manual_mean.csv")

# ddG = variant dG minus WT dG, one column per variant
for variant in ('a35g', 'c30u', 'gs', 'es2'):
    hiv_data[f'ddG_{variant}'] = hiv_data[f'dG_{variant}'] - hiv_data['dG_wt']

# melt the dataframe to long format (ddG); rows without a ddG are dropped
hiv_data_melted = hiv_data.melt(
    id_vars=['site', 'site_nt', 'dG_wt'],
    value_vars=['ddG_a35g', 'ddG_c30u', 'ddG_gs', 'ddG_es2'],
    var_name='Mutation',
    value_name='ddG'
).dropna()

# display names for the variants
hiv_data_melted['Mutation'] = hiv_data_melted['Mutation'].replace({
    'ddG_a35g': 'A35G',
    'ddG_c30u': 'C30U',
    'ddG_gs': 'UUCG',
    'ddG_es2': 'UUCG(ES2)'
})

# annotate regions based on dict (first region listing the site, else 'Other')
hiv_data_melted['Region'] = hiv_data_melted['site'].apply(lambda x: next(
    (region for region, sites in regions.items() if x in sites), 'Other'))

# Sample standard deviation of all ddG values (1σ) and twice that (2σ).
# Note: the shaded band below is centred on 0, not on the mean ddG.
sigma_1 = np.std(hiv_data_melted['ddG'], ddof=1)   # 1σ
sigma_2 = 2 * sigma_1
print(f"1σ: {sigma_1:.2f} kcal/mol")
print(f"2σ: {sigma_2:.2f} kcal/mol")

# colors gs: de8452, a35g: 55a868, c30u: c54f53, uucg es2: 685ca8, wt: 4d73b1
palette = {
    'A35G': '#55a868',
    'C30U': '#c54f53',
    'UUCG': '#de8452',
    'UUCG(ES2)': '#685ca8'
}

fig, axs = plt.subplots(1, 2, figsize=(4.5, 3.5), sharey = True, gridspec_kw={'width_ratios': [3, 1]})
# one swarm per mutation so each can take its own colour
for mut in hiv_data_melted["Mutation"].unique():
    subset = hiv_data_melted[hiv_data_melted["Mutation"] == mut]
    sns.swarmplot(
        data=subset,
        x='Mutation',
        y='ddG',
        dodge=True,
        color = palette[mut],
        ax=axs[0]
    )

# annotate sites whose |ddG| exceeds 1 sigma
outliers = hiv_data_melted[hiv_data_melted['ddG'].abs() > sigma_1]
for mutation, ddG_value, site in zip(outliers['Mutation'], outliers['ddG'], outliers['site']):
    axs[0].text(
        x=mutation,
        y=ddG_value,
        s=str(site),
        fontsize=11,
        ha='left',
        va='bottom'
    )

# shade ±1 sigma
axs[0].axhspan(-sigma_1, sigma_1, color='gray', alpha=0.2, label='±1σ')

# drop xaxis label for axs[0]
axs[0].set_xlabel('')

# Rotate the existing tick labels in place; this restyles the labels without
# re-installing them through set_xticklabels.
plt.setp(axs[0].get_xticklabels(), rotation=30, ha='right')

# on axs[1], horizontal histogram (with KDE) of all ddG values, sharing the y axis
sns.histplot(
    data=hiv_data_melted,
    y='ddG',
    bins=10,
    kde=True,
    color='gray',
    ax=axs[1]
)

# drop xaxis labels for axs[1]
axs[1].set_xlabel('')
# drop ticks for xaxis
axs[1].set_xticks([])

# set ylabel
axs[0].set_ylabel(r'$\Delta\Delta G_{DMS}$ (kcal/mol)')

plt.tight_layout()
plt.savefig("exports/hiv_ddG_swarmplot.pdf")

In [None]:
# swarmplot of ddG values (v2): points coloured and shaped by region
hiv_data = pd.read_csv("exports/hiv_calc_dG_manual_mean.csv")

# ddG = variant dG minus WT dG, one column per variant
ddG_columns = []
for variant in ('a35g', 'c30u', 'gs', 'es2'):
    ddG_name = f'ddG_{variant}'
    hiv_data[ddG_name] = hiv_data[f'dG_{variant}'] - hiv_data['dG_wt']
    ddG_columns.append(ddG_name)

# Long format: one row per (site, variant); rows without a ddG are dropped
hiv_data_melted = hiv_data.melt(
    id_vars=['site', 'site_nt', 'dG_wt'],
    value_vars=ddG_columns,
    var_name='Mutation',
    value_name='ddG',
).dropna()

# Display names for the variants
mutation_names = {
    'ddG_a35g': 'A35G',
    'ddG_c30u': 'C30U',
    'ddG_gs': 'UUCG(GS)',
    'ddG_es2': 'UUCG(ES2)',
}
hiv_data_melted['Mutation'] = hiv_data_melted['Mutation'].replace(mutation_names)


def _region_for(site):
    """Return the first region in `regions` that lists this site, else 'Other'."""
    for region_name, region_sites in regions.items():
        if site in region_sites:
            return region_name
    return 'Other'


hiv_data_melted['Region'] = hiv_data_melted['site'].apply(_region_for)

# Sample standard deviation of all ddG values (1σ) and twice that (2σ)
sigma_1 = np.std(hiv_data_melted['ddG'], ddof=1)   # 1σ
sigma_2 = 2 * sigma_1
print(f"1σ: {sigma_1:.2f} kcal/mol")
print(f"2σ: {sigma_2:.2f} kcal/mol")

# Marker shape per region
region_markers = {
    'lower_stem': 'o',
    'upper_stem': 's',
    'tribulge': 'D',
    'apical_loop': 'X',
    'Other': '^'
}

# One 'deep' palette colour per region, assigned in alphabetical region order
region_colors = sns.color_palette('deep', n_colors=hiv_data_melted['Region'].nunique())
palette = {
    region: color
    for region, color in zip(sorted(hiv_data_melted['Region'].unique()), region_colors)
}

fig, axs = plt.subplots(1, 2, figsize=(4, 3), sharey = True, gridspec_kw={'width_ratios': [3, 1]})
swarm_ax, hist_ax = axs[0], axs[1]

# One swarm per region so each gets its own marker and colour
for region in hiv_data_melted["Region"].unique():
    subset = hiv_data_melted[hiv_data_melted["Region"] == region]
    sns.swarmplot(
        data=subset,
        x='Mutation',
        y='ddG',
        dodge=True,
        marker=region_markers.get(region, 'o'),
        color=palette.get(region),
        ax=swarm_ax,
    )

# ±1σ band centred on zero, no x label, tick labels rotated 45°
swarm_ax.axhspan(-sigma_1, sigma_1, color='gray', alpha=0.2, label='±1σ')
swarm_ax.set_xlabel('')
swarm_ax.set_xticklabels(swarm_ax.get_xticklabels(), rotation=45, ha='right')

# Horizontal histogram (with KDE) of all ddG values, sharing the y axis
sns.histplot(
    data=hiv_data_melted,
    y='ddG',
    bins=10,
    kde=True,
    color='gray',
    ax=hist_ax,
)
hist_ax.set_xlabel('')
hist_ax.set_xticks([])

# Legend proxies, one per region (only used if the legend line below is enabled)
handles = []
for region in hiv_data_melted['Region'].unique():
    handles.append(
        plt.Line2D(
            [0], [0],
            marker=region_markers.get(region, 'o'),
            color='w',
            label=region,
            markerfacecolor=palette.get(region, 'gray'),
            markersize=6,
            linestyle='',
        )
    )
#plt.legend(title='Region', handles=handles)

In [None]:
# Display the long-format ddG table built by the most recently run swarmplot cell
hiv_data_melted