In [1]:
# Standard imports
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
import numpy as np
%matplotlib notebook

# Load mavenn and check path
import mavenn

# Other imports
import logomaker
from scipy.stats import norm
import matplotlib.gridspec as gridspec

# Import helper functions
from helper_functions import my_rsquared, save_fig_with_date_stamp, set_xticks

# Set random seed
# Seeds NumPy's legacy global RNG; the np.random.randn call used later to
# jitter phi_sh draws from this stream, so the seed fixes that noise.
np.random.seed(0)

# Set figure name
# Used as the stem of the style-sheet file name and passed to
# save_fig_with_date_stamp when the figure is written out.
fig_name = 'fig5'

In [2]:
# Write a small matplotlib style sheet next to the notebook and activate it.
# The sheet fixes all text at 7 pt, thins the axes edges, tightens legend
# padding and rounds line caps.
style_file_name = f'{fig_name}.style'
s = """
axes.linewidth:     0.5     # edge linewidth
font.size:          7.0
axes.labelsize:     7.0  # fontsize of the x any y labels
axes.titlesize:     7.0
xtick.labelsize:    7.0  # fontsize of the tick labels
ytick.labelsize:    7.0  # fontsize of the tick labels
legend.fontsize:      7.0
legend.borderpad:     0.2  # border whitespace
legend.labelspacing:  0.2  # the vertical space between the legend entries
legend.borderaxespad: 0.2  # the border between the axes and legend edge
legend.framealpha:    1.0 
lines.dash_capstyle:   round        # {butt, round, projecting}
lines.solid_capstyle:   round        # {butt, round, projecting}
"""

# Persist the sheet (overwriting any previous copy) ...
with open(style_file_name, mode='w') as style_file:
    style_file.write(s)

# ... then apply it to every figure created from here on
plt.style.use(style_file_name)

In [3]:
# Load trained models
# One saved MAVE-NN model per entry; the dict key is the short name
# used to refer to the model throughout the rest of the notebook.
model_paths = {
    'additive': '../models/mpsa_additive_ge_2021.12.30.21h.07m',
    'neighbor': '../models/mpsa_neighbor_ge_2021.12.30.21h.07m',
    'pairwise': '../models/mpsa_pairwise_ge_2021.12.30.21h.07m',
    'blackbox': '../models/mpsa_blackbox_ge_2021.12.30.21h.07m',
}
models = {
    model_key: mavenn.load(model_path, verbose=False)
    for model_key, model_path in model_paths.items()
}

Instructions for updating:
The `validate_indices` argument has no effect. Indices are always validated on CPU and never validated on GPU.


In [4]:
# Load MPSA data
data_df = pd.read_csv('../datasets/mpsa_data.csv.gz', compression='gzip')

# Split into training+validation and test sets
# (mavenn.split_dataset prints the split summary shown in the cell output)
trainval_df, test_df = mavenn.split_dataset(data_df)

# Get test_df x and y values
# x_test / y_test are reused by every downstream evaluation and plot
x_test = test_df['x'].values
y_test = test_df['y'].values

Training set   :   18,215 observations (  59.75%)
Validation set :    6,019 observations (  19.75%)
Test set       :    6,249 observations (  20.50%)
-------------------------------------------------
Total dataset  :   30,483 observations ( 100.00%)



In [5]:
# Load replicate dataset
rep_data_df = pd.read_csv('../datasets/mpsa_replicate_data.csv.gz', compression='gzip')

# Inner-join the two datasets on sequence. Both frames carry a 'y' column,
# so the suffixes name them 'y1' (primary data) and 'y2' (replicate) directly.
cols = ['x','y']
intersection_df = data_df[cols].merge(rep_data_df[cols], on='x', how='inner', suffixes=('1', '2'))
print(f'Num x in both mpsa_data.csv and mpsa_replicate_data.csv: {len(intersection_df):,}')

# Mutual information between the paired measurements; reported below as a
# lower bound on the intrinsic information
y_primary = intersection_df['y1']
y_replicate = intersection_df['y2']
I_intr, dI_intr = mavenn.src.entropy.mi_continuous(y_primary, y_replicate)
print(f'I_intr >= {I_intr:.3f} +- {dI_intr:.3f} bits')

Num x in both mpsa_data.csv and mpsa_replicate_data.csv: 29,593
I_intr >= 0.462 +- 0.009 bits


In [6]:
# Load results of Sailer-Harms inference. See folder fig5_SH_data for inference script.
# Both sheets live in the same workbook: 'test' holds per-sequence values,
# 'line' holds the fitted curve evaluated on a grid.
sh_workbook = 'fig5_SH_data/Sailer_Harms_powerlaw_fit_MPSA_data.xlsx'
SH_df = pd.read_excel(sh_workbook, sheet_name='test')
SH_line = pd.read_excel(sh_workbook, sheet_name='line')

# Test-set scatter: additive-model value vs. observed measurement
phi_sh, y_sh = SH_df['y_add_fixed_test'], SH_df['y_obs_test']

# Fitted nonlinearity on a grid
phi_grid_sh, yhat_grid_sh = SH_line['y_add_line_fixed'], SH_line['y_obs_line']

In [7]:
# Estimate I_pred for the S&H epistasis model
from mavenn.src.entropy import mi_continuous

# Relative jitter amplitude (per the original note: the default value used
# in Model.I_predictive)
knn_fuzz = 0.01

# Add Gaussian noise scaled to a small fraction of the spread of phi_sh,
# leaving phi_sh itself untouched, then estimate mutual information with y
noise_scale = knn_fuzz * phi_sh.std(ddof=1)
phi_sh_fuzz = phi_sh + noise_scale * np.random.randn(len(phi_sh))
I_pred_SH, dI_pred_SH = mi_continuous(phi_sh_fuzz, y_sh)

In [8]:
# Create dict to hold information values
# NOTE(review): info_dict is never read in the visible notebook; kept only
# so that nothing outside this view breaks.
info_dict = {}

# Collect one record per (model, metric) pair. Rows are kept in the order
# name-major / metric-minor (I_var first, then I_pre).
# The epistasis-package I_var entry is a zero placeholder so that every
# model contributes a pair of rows.
row_dicts = [
    {'name':'additive\n(epistasis package)',
     'metric':'$I_{\\rm var}$',
     'I':0.000000,
     'dI':0.000000},
    {'name':'additive\n(epistasis package)',
     'metric':'$I_{\\rm pre}$',
     'I':I_pred_SH,
     'dI':dI_pred_SH},
]

# Compute variational and predictive information on test data
for model_name, model in models.items():

    # Compute and record likelihood information
    I_var, dI_var = model.I_variational(x=x_test, y=y_test)
    row_dicts.append({'name':model_name,
                      'metric':'$I_{\\rm var}$',
                      'I':I_var,
                      'dI':dI_var})

    # Compute and record predictive information
    I_pred, dI_pred = model.I_predictive(x=x_test, y=y_test, num_subsamples=1000)
    row_dicts.append({'name':model_name,
                      'metric':'$I_{\\rm pre}$',
                      'I':I_pred,
                      'dI':dI_pred})

# Build the dataframe in one shot. DataFrame.append was deprecated in
# pandas 1.4 and removed in 2.0, and growing a frame row by row is quadratic
# and leaves the numeric columns with object dtype.
info_df = pd.DataFrame(row_dicts, columns=['name','metric','I','dI'])

# Put space in black box (literal replacement, independent of the pandas
# version's default for `regex`)
info_df['name'] = info_df['name'].str.replace('blackbox', 'black box', regex=False)
info_df

Unnamed: 0,name,metric,I,dI
0,additive\n(epistasis package),$I_{\rm var}$,0.0,0.0
1,additive\n(epistasis package),$I_{\rm pre}$,0.179863,0.011398
2,additive,$I_{\rm var}$,0.195117,0.02981
3,additive,$I_{\rm pre}$,0.256769,0.012866
4,neighbor,$I_{\rm var}$,0.305341,0.027773
5,neighbor,$I_{\rm pre}$,0.35185,0.016126
6,pairwise,$I_{\rm var}$,0.328253,0.026102
7,pairwise,$I_{\rm pre}$,0.374056,0.01404
8,black box,$I_{\rm var}$,0.408747,0.026136
9,black box,$I_{\rm pre}$,0.457523,0.015222


In [9]:
# Create figure: a 3x3 grid holding the bar chart (top row), three
# scatter panels (middle row), and the logo + heatmap (bottom row)
fig = plt.figure(figsize=[6.5, 7])
gs = gridspec.GridSpec(3, 3, figure=fig)

# Set lims and ticks
# The y-limits and y-ticks are defined by each panel immediately before use;
# the initial values that used to live here were overwritten before being read.
philim = [-3,5]
# NOTE(review): phiticks is not used anywhere in the visible notebook.
phiticks = [-2,0,2,4]

#
# Panel A: Bar chart
#
ax = fig.add_subplot(gs[0,:])

# Plot I_int region: shade from the intrinsic-information estimate up to
# the top of the axes
xlim = [-.5, 4.5]
ylim = [0, 0.6]
ax.fill_between(xlim, [I_intr, I_intr], [ylim[1], ylim[1]], color='gray', alpha=.2, zorder=-100, linewidth=0,
                label=r'$I_\mathrm{int}$')

# Plot bars
sns.barplot(ax=ax, data=info_df, hue='metric', x='name', y='I')

# Plot error bars of half-width dI on top of the bars.
# The original comment described these as 95% confidence intervals -- TODO confirm
# what dI represents.
# Bar centres are hard-coded at k-0.2 / k+0.2 for the five model groups, which
# assumes seaborn's default layout for two hue levels and that info_df rows are
# ordered name-major, metric-minor. [1:] skips the first row (the zero
# placeholder for the epistasis-package I_var).
x = np.array([[k-.2, k+.2] for k in range(5)]).ravel()
yerr = info_df['dI'].values
y = info_df['I'].values
ax.errorbar(x=x[1:], y=y[1:], yerr=yerr[1:], color='k', capsize=3, linestyle='none',
            elinewidth=1, capthick=1, solid_capstyle='round')
ax.set_ylabel('information (bits)')
ax.set_xlabel('')
ax.set_xlim(xlim)
ax.set_ylim(ylim)

# Reorder legend so the two bar metrics come first and the shaded
# region comes last
handles, labels = ax.get_legend_handles_labels()
handles = [handles[1], handles[2], handles[0]]
labels = [labels[1], labels[2], labels[0]]
ax.legend(handles, labels, loc='upper left')

# Remind user of intrinsic information
print(f'I_intr >= {I_intr:.3f} +- {dI_intr:.3f} bits')

# Compute a p-value for additive epistasis vs. MAVE-NN I_pred
# Rows are selected by metric and model name rather than by row position,
# so the comparison stays correct if rows are added or reordered.
pre_df = info_df[info_df['metric'] == '$I_{\\rm pre}$'].set_index('name')
I_epistasis = pre_df.loc['additive\n(epistasis package)', 'I']
dI_epistasis = pre_df.loc['additive\n(epistasis package)', 'dI']
print(f'I_epistasis: {I_epistasis:.3f} +- {dI_epistasis:.3f} bits')
I_mavenn = pre_df.loc['additive', 'I']
dI_mavenn = pre_df.loc['additive', 'dI']
print(f'I_mavenn: {I_mavenn:.3f} +- {dI_mavenn:.3f} bits')

# z-score of the difference, with the two uncertainties combined in
# quadrature; p is the one-sided lower-tail probability of -z under a
# standard normal
z = (I_mavenn - I_epistasis)/np.sqrt(dI_epistasis**2 + dI_mavenn**2)
print(f'z: {z:.2f}')
p = norm.cdf(-z)
print(f'p-value: {p:.4e}')

# Plot significance line between the two I_pre bars being compared
ax.plot([0.2, 1.2], [.3, .3], '-k')

# Plot text
ax.text(x=.7, y=.31, s=f'p = {p:.1e}', ha='center', va='bottom')


#
# Panel B: GE plot for Sailer & Harms model
#

# y-axis settings shared by the three scatter panels in this row
ylim = [-1.5, 2.5]
yticks = [-1, 0, 1, 2]

# Show results from Sailer & Harms
# Set ax
ax = fig.add_subplot(gs[1,0])

# Test-set points: additive-model value vs. observed measurement
ax.scatter(phi_sh,
           y_sh,
           s=1,
           alpha=.2,
           color='C0')

# Fitted nonlinearity
ax.plot(phi_grid_sh,
        yhat_grid_sh,
        color='C1',
        alpha=1,
        linestyle='-',
        linewidth=2,
        label=r'$\hat{y}$')

# Style plot
# (raw strings keep the LaTeX backslashes from being parsed as Python
# escape sequences)
ax.set_xlim(philim)
ax.set_ylim(ylim)
ax.set_yticks(yticks)
ax.set_xlabel(r'$\phi$, additive (epistasis package)', labelpad=-1)
ax.set_ylabel(r'$\log_{10}$ PSI ($y$)', labelpad=-2)

# Show predictive information
ax.set_title(rf'$I_\mathrm{{pre}}: {I_epistasis:.3f}\pm{dI_epistasis:.3f}$ bits', fontsize=7)

#
# Panel C,D: GE plots for MAVE-NN models
#

# Choose phi grid for plotting curves (identical for both panels, so it is
# built once outside the loop)
phi_grid = np.linspace(philim[0], philim[1], 1000)

for i, name in enumerate(['additive', 'pairwise']):

    # Set ax (columns 1 and 2 of the middle row)
    ax = fig.add_subplot(gs[1,i+1])
    mpsa_model = models[name]

    # Latent phenotype of each test sequence
    phi = mpsa_model.x_to_phi(x_test)

    # GE curve and its 2.5% / 97.5% quantiles on the grid
    yhat_grid = mpsa_model.phi_to_yhat(phi_grid)
    yqs_grid = mpsa_model.yhat_to_yq(yhat_grid, q=[.025,.975])

    # Predictive information for this model, selected by name and metric
    # label instead of by row position
    is_pre = (info_df['name'] == name) & (info_df['metric'] == '$I_{\\rm pre}$')
    I_pred = info_df.loc[is_pre, 'I'].iloc[0]
    dI_pred = info_df.loc[is_pre, 'dI'].iloc[0]

    # Draw scatter plot
    ax.scatter(phi,
               y_test,
               s=1,
               alpha=.2,
               color='C0')

    # Draw GE curve & confidence intervals
    ax.plot(phi_grid,
            yhat_grid,
            color='C1',
            alpha=1,
            linestyle='-',
            linewidth=2,
            label=r'$\hat{y}$')
    ax.plot(phi_grid,
            yqs_grid[:,0],
            color='C1',
            alpha=1,
            linestyle=':',
            linewidth=1,
            label='95% CI')
    ax.plot(phi_grid,
            yqs_grid[:,1],
            color='C1',
            alpha=1,
            linestyle=':',
            linewidth=1)

    # Show predictive information
    ax.set_title(rf'$I_\mathrm{{pre}}: {I_pred:.3f}\pm{dI_pred:.3f}$ bits', fontsize=7)

    # Style plot (y tick labels are blanked on these two panels)
    ax.set_xlim(philim)
    ax.set_ylim(ylim)
    ax.set_yticks(yticks)
    ax.set_yticklabels([])
    ax.set_xlabel(rf'$\phi$, {name}', labelpad=0)

#
# Panel E: Logo
#

# Make axes
ax = fig.add_subplot(gs[2,0])

# Draw logo of the pairwise model's additive parameters (missing entries
# are drawn as zero)
logo_df = models['pairwise'].get_theta()['logomaker_df'].fillna(0)
logo = logomaker.Logo(df=logo_df, ax=ax, fade_below=.5, shade_below=.5, width=.9, 
                      font_name='Arial Rounded MT Bold')

# Style logo
# Positions 3-4 are masked with an opaque white box and a faint gray overlay;
# a gray 'G' spanning the full y-range is drawn at position 3 and 'U'/'C' are
# drawn in black at position 4.
# NOTE(review): presumably these are the fixed/restricted splice-site
# positions of the library -- confirm.
ylim = ax.get_ylim()
logo.highlight_position_range(pmin=3, pmax=4, color='w', alpha=1, zorder=10)
logo.highlight_position_range(pmin=3, pmax=4, color='gray', alpha=.1, zorder=11)
logo.style_single_glyph(p=3, c='G', floor=ylim[0], 
                        ceiling=ylim[1], color='gray', zorder=30, alpha=.5, flip=False)
logo.style_single_glyph(p=4, c='U', color='k', zorder=30)
logo.style_single_glyph(p=4, c='C', color='k', zorder=30)
logo.style_spines(visible=False)

# Style axes
# Dotted line between array positions 2 and 3; tick labels run -3..+6 with
# no position 0. The raw string keeps the LaTeX backslashes literal.
ax.axvline(2.5, linestyle=':', color='k', zorder=30)
ax.set_ylabel(r'additive effect ($\Delta \phi$)', labelpad=-1)
ax.set_xticks([0,1,2,3,4,5,6,7,8])
ax.set_xticklabels([f'{x:+d}' for x in range(-3,7) if x!=0])
ax.set_xlabel('nucleotide position', labelpad=5)

#
# Panel F: Pairwise heatmap
#

# Make axes (last two columns of the bottom row)
ax = fig.add_subplot(gs[2,-2:])

# Draw heatmap of the pairwise model's pairwise parameters
theta = models['pairwise'].get_theta()['theta_lclc']
ax, cb = mavenn.heatmap_pairwise(values=theta,
                          alphabet='rna',
                          ax=ax,
                          gpmap_type='pairwise',
                          cmap_size='3%')

# Style axes
# Same position labelling as the logo panel: -3..+6 with no position 0.
# The raw string keeps the LaTeX backslashes literal.
ax.set_xticks([0,1,2,3,4,5,6,7,8])
ax.set_xticklabels([f'{x:+d}' for x in range(-3,7) if x!=0])
ax.set_xlabel('nucleotide position', labelpad=5)
cb.set_label(r'pairwise effect ($\Delta \Delta \phi$)', labelpad=5, ha='center', va='center', rotation=-90)
cb.outline.set_visible(False)
cb.ax.tick_params(direction='in', size=10, color='white')

# Tighten figure
# h_pad=2 sets the vertical padding between the rows of panels
fig.tight_layout(h_pad=2)

# Save figure
# save_fig_with_date_stamp comes from the local helper_functions module; per the
# recorded output it writes a time-stamped PNG under png/ -- see the helper for details.
save_fig_with_date_stamp(fig, fig_name)

<IPython.core.display.Javascript object>

I_intr >= 0.462 +- 0.009 bits
I_epistasis: 0.180 +- 0.011 bits
I_mavenn: 0.257 +- 0.013 bits
z: 4.47
p-value: 3.8354e-06
Figure saved figure to png/fig5_ipynb_2022.03.18.08h.00m.57s.png.
