In [1]:
# This notebook is from the github repo of *Repeated Omicron infection alleviates SARS-CoV-2 immune imprinting*

import pandas as pd
import numpy as np
from plotnine import *
import time
# Per-variant table of candidate amino acids per position (`pos`, `mut1` columns; presumably the
# residues reachable by one nucleotide change from that variant's codons — TODO confirm).
neut2codon = {
    "BA5_IC50": "mut1_BA.4_BA.5_EPI_ISL_11207535.csv",
    "XBB1_5_IC50": "mut1_XBB_1_5_EPI_ISL_17054053.csv",
    # "XBB1_5_10_IC50": "mut1_XBB_1_5_10_EPI_ISL_16382307.csv",
}

# Per-variant continuous mutation weights (`pos`, `mut`, `weight` columns), used instead of
# `neut2codon` when do_calc(..., use_codon_weights=True).
codon_weights = {
    "BA5_IC50": "mutation_weights_BA.4_BA.5_EPI_ISL_11207535.csv",
    "XBB1_5_IC50": "mutation_weights_XBB_1_5_EPI_ISL_17054053.csv",
    # "XBB1_5_10_IC50": "mutation_weights_XBB_1_5_10_EPI_ISL_16382307.csv",
}

# Expression/binding DMS table per neutralization column; every strain starts from the BA.2 table
# and, for strains in `mut_for_bind_expr`, is redirected below to a rebased copy.
neut2se = dict.fromkeys(["BA2_IC50", "BA5_IC50", "XBB1_5_IC50"], "bind_expr_BA2.csv")
# (XBB1_5_10_IC50 -> "bind_expr_BA2.csv" is disabled in this analysis.)

# (site, residue) differences that turn the BA.2 background into each strain's RBD.
mut_for_bind_expr = {
    'BA5_IC50': [(452, 'R'), (486, 'V'), (493, 'Q')],
    'XBB1_5_IC50': [(339, 'H'), (346, 'T'), (368, 'I'), (445, 'P'), (446, 'S'),
                    (460, 'K'), (486, 'P'), (490, 'S'), (493, 'Q')],
    # 'XBB1_5_10_IC50': [(339, 'H'), (346, 'T'), (368, 'I'), (445, 'P'), (446,'S'), (456,'L'), (460,'K'), (486, 'P'), (490, 'S'), (493, 'Q')],
}


def _rebase_bind_expr(dms, background):
    """Approximate DMS expr/bind scores on another variant background.

    For each (site, aa) in `background`, the expr_avg/bind_avg of `aa` at that site is
    subtracted from every mutation at the site and the site's `wildtype` becomes `aa`
    (so, for a site listed once, `aa` itself scores 0). Adds a `mutant` column
    (wildtype + site + mutation) and a `mutation_RBD` column numbered as site - 330.
    """
    shifted = dms.assign(bias_e=0.0, bias_b=0.0)
    for site, aa in background:
        ref_row = shifted.query('site == @site and mutation == @aa')
        on_site = shifted['site'] == site
        shifted.loc[on_site, 'bias_e'] += ref_row['expr_avg'].item()
        shifted.loc[on_site, 'bias_b'] += ref_row['bind_avg'].item()
        shifted.loc[on_site, 'wildtype'] = aa

    shifted['expr_avg'] -= shifted['bias_e']
    shifted['bind_avg'] -= shifted['bias_b']
    return shifted.drop(columns=['bias_e', 'bias_b']).assign(
        mutant=lambda x: x['wildtype'] + x['site'].astype(str) + x['mutation'],
        mutation_RBD=lambda x: x['wildtype'] + (x['site'] - 330).astype(str) + x['mutation'],
    )


for strain, background in mut_for_bind_expr.items():
    approx_path = "mut_approx_" + strain + ".csv"
    _rebase_bind_expr(pd.read_csv(neut2se[strain]), background).to_csv(approx_path, index=None)
    # Later cells (do_calc) read the rebased table for this strain.
    neut2se[strain] = approx_path

In [2]:
# Raw per-antibody DMS escape scores; the antibodies listed here define which metadata rows are kept.
scores_r = pd.read_csv("antibody_dms_merge_no_filter_clean.csv")
use_abs = np.unique(scores_r['antibody'])

# Antibody metadata (index = antibody name): cohort source, IC50 columns and reactivity annotation.
data = (
    pd.read_csv("../antibody_info.csv", index_col=0)
    .loc[:, ["source", "BA5_IC50", "XBB1_5_IC50", "XBB1_5_10_IC50", "paper_reactivity"]]
    .assign(antibody=lambda frame: frame.index)
    .query('antibody in @use_abs')
)

_srcs = list(data['source'])

def src_rename(x):
    """Collapse a free-text antibody `source` description into a coarse cohort label.

    Rules are checked in order and the first match wins, so e.g. a source containing both
    "BA.1" and "reinfection" maps to "BA.1_reinfect", not "BA.1".

    Returns one of "mouse", "WT", "SARS", "BA.1_reinfect", "BA.2_reinfect", "BA.1",
    "BA.2", "BA.5", "BF.7", or the sentinel "???" when nothing matches (or `x` is not a
    string, e.g. a missing value). The caller filters "???" out.
    """
    if not isinstance(x, str):
        # Missing/NaN sources cannot be classified; substring tests would raise TypeError.
        return "???"
    if "mouse" in x:
        return "mouse"
    elif "WT" in x:
        return "WT"
    elif "SARS" in x:
        return "SARS"
    elif "BA.1" in x and "reinfection" in x:
        return "BA.1_reinfect"
    elif "BA.2" in x and "reinfection" in x:
        return "BA.2_reinfect"
    elif "BA.1" in x:
        return "BA.1"
    elif "BA.2" in x:
        return "BA.2"
    elif "BA.5" in x:
        return "BA.5"
    elif "BF.7" in x:
        return "BF.7"
    # Previously fell through and returned None, which the "???" filter never removed.
    return "???"
    
# Label each antibody's cohort, then drop mouse-derived antibodies and unclassifiable sources ("???").
data = data.assign(Usrc=list(map(src_rename, _srcs))).query('Usrc not in ["???", "mouse"]')

In [3]:
import logomaker
from matplotlib import rcParams
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
# Embed TrueType (Type 42) fonts in saved PDFs instead of matplotlib's default Type 3,
# so text in the exported logo PDFs stays editable in vector-graphics tools.
rcParams['pdf.fonttype'] = 42

def plot_res_logo(res, prefix, shownames={}, rownames=None, site_thres=0.0, force_plot_sites = None, force_ylim = None, width=None):
    flat_res = res.pivot(index=['antibody', 'site'], columns='mutation', values='mut_escape').fillna(0)
    sites_mean_score = flat_res.mean(axis=1)
    sites_total_score = flat_res.sum(axis=1)
    _ = sites_total_score[sites_total_score>site_thres].index
    strong_sites = np.unique(np.array(sorted([i[1] for i in _])))
    print(strong_sites)

    plot_sites = strong_sites
    plot_sites = plot_sites[plot_sites < 520].astype(int)
    print(plot_sites)
    
    if force_plot_sites is not None:
        plot_sites = force_plot_sites
    
    flat_res = flat_res[flat_res.index.isin(plot_sites, level=1)]

    _ = pd.DataFrame(sites_total_score)
    _.columns = ['value']
    _['site'] = [i[1] for i in _.index]
    _['antibody'] = [i[0] for i in _.index]

    if rownames is not None:
        Abs = rownames
    else:
        Abs = np.unique([i[0] for i in flat_res.index])
    print(Abs)
    Npages = len(Abs)//10 + 1
    if width is None:
        width=30
    with PdfPages(prefix+'_aa_logo.pdf') as pdf:
        for p in range(Npages):
            Abs_p = Abs[p*10:min(len(Abs),(p+1)*10)]
            fig = plt.figure(figsize=(width,len(Abs_p)*4.6)).subplots_adjust(wspace=0.2,hspace=0.5)
            site2pos = {}
            for i in range(len(plot_sites)):
                site2pos[plot_sites[i]] = i

            for i in range(len(Abs_p)):
                ab = Abs_p[i]
                _ = flat_res.query('antibody == @ab').droplevel(0)
                add_sites = np.setdiff1d(plot_sites, _.index)
                for _site in add_sites:
                    _.loc[_site,:] = 0.0
                _ = _.sort_index()
                _.index = range(len(_))
                ax = plt.subplot(len(Abs_p), 1, i+1)
                logo = logomaker.Logo(_,
                               ax=ax, 
                               color_scheme='dmslogo_funcgroup', 
                               vpad=.1, 
                               width=.8)
                logo.style_xticks(anchor=1, spacing=1, rotation=90, fontsize=16)
                _max = np.sum(_.to_numpy(), axis=1).max()
                # ax.set_xticklabels(plot_sites[1::2])
                ax.set_xticklabels(plot_sites)
                
                # ax.set_yticks([])
                ax.tick_params(axis='both', which='both', length=0)
                if force_ylim is not None:
                    ax.set_ylim(0.0,force_ylim)
                ax.yaxis.set_tick_params(labelsize=20)
                if ab in shownames:
                    ax.set_title(shownames[ab], fontsize=8, fontweight="bold")
                else:
                    ax.set_title(ab, fontsize=8, fontweight="bold")
            pdf.savefig()
            plt.close()

In [4]:
def mut_agg_func(x, l):
    """Arbitrary heuristic weight for an antibody facing `l` additional mutations.

    `x` holds the antibody's escape scores at those of the `l` mutations it has data for;
    the remaining `l - len(x)` mutations count as score 0. A soft maximum (log-mean-exp with
    sharpness 10) of the scores is taken, and max(0, 0.5 - soft_max) ** 4 is returned.
    Only reached from do_calc when `muts` is non-empty, which this notebook never does.
    """
    sharpness = 10
    exponent = 4
    values = np.array(x)
    soft_max = np.log((np.sum(np.exp(sharpness * values)) + l - len(x)) / l) / sharpness
    return max(0, 0.5 - soft_max) ** exponent

def _single_mut_coef(use_neut, E, B):
    """Return {'<site><aa>': coef} from the expression/binding table `neut2se[use_neut]`.

    coef = (tanh(E * expr_avg if expr_avg < 0 else 0) + 1) * (tanh(B * bind_avg) + 1), so
    expression gains earn no bonus while binding changes count in both directions.
    NaN scores give a NaN coefficient (do_calc maps those to 0).
    """
    effects = pd.read_csv(neut2se[use_neut])
    expr, bind = effects['expr_avg'], effects['bind_avg']
    # Multiplying by (expr < 0) zeroes expression gains: tanh(0) + 1 == 1, i.e. no bonus.
    coef = (np.tanh(expr * (expr < 0) * E) + 1) * (np.tanh(bind * B) + 1)
    keys = effects['site'].astype('str') + effects['mutation']
    return dict(zip(keys, coef))


def _codon_weight_table(use_neut, use_codon_weights):
    """Return {'<site><aa>': weight} describing which mutations are accessible for `use_neut`.

    By default every amino acid listed (as characters of the `mut1` string) for a position `pos`
    in `neut2codon[use_neut]` gets weight 1.0. With `use_codon_weights`, the continuous `weight`
    column of `codon_weights[use_neut]` (keyed by `pos` + `mut`) is used instead.
    """
    if not use_codon_weights:
        codon = pd.read_csv(neut2codon[use_neut])
        return {str(pos) + aa: 1.0 for pos, aas in zip(codon['pos'], codon['mut1']) for aa in aas}
    weights = pd.read_csv(codon_weights[use_neut])
    return dict(zip(weights['pos'].astype(str) + weights['mut'], weights['weight']))


def do_calc(use_ab_src, use_neut, A_adv = True, A_codon = True, A_neut = True,
            E=1.0, B=1.0, use_log=False, use_max=False, use_norm=False,
            logo=False, return_df=False, use_codon_weights=False, specific_weight = None,
            muts = (), title=None
           ):
    """Aggregate antibody DMS escape into weighted per-site (or per-mutation) escape scores.

    Every escape score in `scores_r` belonging to an antibody whose cohort label (`data['Usrc']`)
    is in `use_ab_src` is multiplied by:
      * adv_weight   - `_single_mut_coef` coefficient if `A_adv` (missing/NaN -> 0.0),
      * codon_weight - `_codon_weight_table` weight if `A_codon` (unlisted mutation -> 0.0),
      * neut_weight  - from the antibody's IC50 in `data[use_neut]` if `A_neut`:
                       max(0, log10(1 / min(1, IC50))) when `use_log`, else 1 / IC50;
                       NaN IC50 -> 0.0,
      * other_weight - `specific_weight[Usrc]` for antibodies whose `paper_reactivity` contains
                       "specific"; 1.0 otherwise or when `specific_weight` is None,
      * mut_weight   - `mut_agg_func` of the antibody's escape at the extra mutations in `muts`;
                       1.0 when `muts` is empty or the antibody has no score at them.

    Parameters
    ----------
    use_ab_src : cohort labels (values of `data['Usrc']`) selecting the antibodies.
    use_neut : IC50 column of `data`; also the key into `neut2se`, `neut2codon`, `codon_weights`.
    E, B : scale factors inside tanh for the expression and binding scores.
    use_log : log-scale neutralization weight (see above).
    use_norm : first divide each antibody's escape scores by that antibody's maximum.
    use_max : per site, keep each antibody's maximum weighted score before summing antibodies.
    use_codon_weights : use continuous weights from `codon_weights` instead of the 0/1 table.
    specific_weight : optional {Usrc label: multiplier} for "specific" antibodies.
    muts : iterable of (site, aa) extra mutations; aa "X" stands for the per-(antibody, site)
        maximum escape. None is treated as empty.
    logo : return per-(site, mutation) sums scaled to max 1, labelled with the title as 'antibody'.
    return_df : return the per-site table (scaled to max 1) with the settings as extra columns.
    title : overrides the auto-generated title.

    Returns
    -------
    DataFrame when `logo` or `return_df` is set, otherwise a plotnine ggplot of per-site scores.
    """
    muts = [] if muts is None else list(muts)
    neut_data = data[use_neut].to_dict()

    mut_expand = {str(_site) + _m for _site, _m in muts}
    XTag = any(_m == "X" for _, _m in muts)
    print("out mut:", mut_expand)

    single_mut_effects = _single_mut_coef(use_neut, E, B)
    _umuts = _codon_weight_table(use_neut, use_codon_weights)

    _uabs = set(data.query('Usrc in @use_ab_src').index.to_list())
    src_dict = data['Usrc'].to_dict()
    spc_dict = data['paper_reactivity'].to_dict()

    def _adv(key):
        coef = single_mut_effects.get(key)
        return 0.0 if coef is None or np.isnan(coef) else coef

    def _neut_weight(ab):
        if not A_neut:
            return 1.0
        ic50 = neut_data[ab]
        if np.isnan(ic50):
            return 0.0
        if use_log:
            # Only IC50 < 1 earns weight; smaller IC50 earns more (log scale).
            return max(0.0, np.log10(1 / min(1, ic50)))
        return 1.0 / ic50

    def _other_weight(ab):
        if specific_weight is None:
            return 1.0
        src, reactivity = src_dict[ab], spc_dict[ab]
        # A non-string reactivity (e.g. NaN for a missing annotation) counts as not specific
        # instead of raising TypeError on the substring test.
        if src in specific_weight and isinstance(reactivity, str) and "specific" in reactivity:
            return specific_weight[src]
        return 1.0

    scores = scores_r.assign(site_mut = lambda x: x['site'].astype(str)+x['mutation']).query('antibody in @_uabs').assign(
        adv_weight = (lambda x: [_adv(y) for y in x['site_mut']]) if A_adv else 1.0,
        codon_weight = (lambda x: [_umuts.get(y, 0.0) for y in x['site_mut']]) if A_codon else 1.0
    )

    if use_norm:
        scores = scores.assign(escape_max = lambda x: x.groupby('antibody')['mut_escape'].transform('max')).assign(
            mut_escape = lambda x: x['mut_escape']/x['escape_max']).drop(columns=['escape_max'])
    if XTag:
        # Pseudo-mutation "X": the antibody's maximum escape at each site.
        _scX = scores.groupby(['antibody','site']).max().reset_index().assign(
            mutation = 'X',
            site_mut = lambda x: x['site'].astype(str)+'X'
        )
        scores = pd.concat([scores, _scX])

    if len(muts) == 0:
        _ab_mut_coef = {}
    else:
        _ab_mut_coef = (
            scores.query('site_mut in @mut_expand')
                  .groupby('antibody')['mut_escape']
                  .agg(coef=lambda x: mut_agg_func(x, len(mut_expand)))
        )['coef'].to_dict()

    scores = scores.assign(
        mut_weight = lambda x: [_ab_mut_coef.get(y, 1.0) for y in x['antibody']]
    ).query('mutation != "X"').assign(
        neut_weight = lambda x: [_neut_weight(y) for y in x['antibody']],
        other_weight = lambda x: [_other_weight(y) for y in x['antibody']],
    ).assign(
        mut_escape_adj = lambda x: x['mut_escape'] * x['neut_weight'] * x['adv_weight'] * x['codon_weight'] * x['other_weight'] * x['mut_weight']
    )

    _title = ("src: "+'+'.join(use_ab_src)+
              ' weight: '+use_neut+' expr_bind:'+str(A_adv)+
              ' codon:'+str(A_codon)+' log:'+str(use_log)+
              ' norm:'+str(use_norm)+' max:'+str(use_max)+
              ' Expr:'+str(E)+' Bind:'+str(B)) if title is None else title

    # Select the score column before aggregating: aggregating the whole frame would also
    # aggregate text columns (recent pandas string-concatenates them; mixed types raise).
    if logo:
        per_mut = scores.groupby(['site','mutation'])['mut_escape_adj'].sum().reset_index().assign(antibody=_title)
        per_mut['mut_escape_adj'] = per_mut['mut_escape_adj']/per_mut['mut_escape_adj'].max()
        return per_mut

    if use_max:
        per_ab = scores.groupby(['site', 'antibody'])['mut_escape_adj'].max()
        site_avg = per_ab.groupby(level='site').sum().reset_index()
    else:
        per_mut = scores.groupby(['site', 'mutation'])['mut_escape_adj'].sum()
        site_avg = per_mut.groupby(level='site').sum().reset_index()
    site_avg['mut_escape_adj'] = site_avg['mut_escape_adj']/site_avg['mut_escape_adj'].max()

    if return_df:
        return site_avg.assign(
            absrc = '+'.join(use_ab_src), weight = use_neut, is_expr_bind = A_adv, is_codon = A_codon,
            is_neut_log = use_log, is_norm = use_norm, is_max = use_max, expr_coef = E, bind_coef = B
        )
    p = (
        ggplot(site_avg, aes('site', 'mut_escape_adj')) +
        geom_line() + geom_point()+ theme_classic() + theme(
            axis_text_y=element_blank(),
            axis_ticks_major_y=element_blank(),figure_size=(12,3),
            axis_text_x=element_text(angle=90)
        )+scale_x_continuous(breaks=range(331,531,2))+
        ylab('weighted escape score')+xlab('RBD residues')+ggtitle(_title)+
        geom_text(site_avg.query('mut_escape_adj > 0.2'), aes(label='site'),
                                adjust_text={'expand_points': (2, 2), 'arrowprops': {'arrowstyle': '-'}})
    )
    return p

In [9]:
# Per-mutation logos for each (antibody cohorts, IC50 weighting) combination, all on one PDF.
df = pd.concat(
    do_calc(
        use_ab_src, use_neut,
        A_adv=True, A_codon=True, A_neut=True, use_log=True, E=1.5, B=1.0,
        # weights for correcting the proportion of cross-reactive antibodies, as our DMS dataset enriches Omicron-specific mAbs
        specific_weight={"BA.5": 1/3.23, "BF.7": 1/3.02, "BA.1_reinfect": 1/1.98, "BA.2_reinfect": 1/1.33},
        use_norm=True, use_max=False, use_codon_weights=True, logo=True,
    ).rename(columns={'mut_escape_adj': 'mut_escape'})
    for use_ab_src, use_neut in [
        (['BA.5', 'BF.7'], 'BA5_IC50'),
        (['BA.5', 'BF.7'], 'XBB1_5_IC50'),
        (['BA.5', 'BF.7', 'BA.1_reinfect', 'BA.2_reinfect'], 'XBB1_5_IC50'),
        # (['BA.5','BF.7','BA.1_reinfect','BA.2_reinfect'],'XBB1_5_10_IC50'),
    ]
)

plot_res_logo(df, "logo", site_thres=0.3, width=10)

out mut: set()
out mut: set()
out mut: set()
[346 403 405 406 417 420 439 440 444 445 446 449 450 453 455 456 460 475
 478 479 484 490 493 504 505]
[346 403 405 406 417 420 439 440 444 445 446 449 450 453 455 456 460 475
 478 479 484 490 493 504 505]
['src: BA.5+BF.7 weight: BA5_IC50 expr_bind:True codon:True log:True norm:True max:False Expr:1.5 Bind:1.0'
 'src: BA.5+BF.7 weight: XBB1_5_IC50 expr_bind:True codon:True log:True norm:True max:False Expr:1.5 Bind:1.0'
 'src: BA.5+BF.7+BA.1_reinfect+BA.2_reinfect weight: XBB1_5_IC50 expr_bind:True codon:True log:True norm:True max:False Expr:1.5 Bind:1.0']


In [5]:
# Export per-site weighted escape scores to CSV for plotting in R.

xx = "sum"

df = pd.concat([
    do_calc(use_ab_src, use_neut, A_adv=True, A_codon=False, A_neut=False, use_log=True,
            E=1.5, B=1.0,
            specific_weight={"BA.5": 1/3.23, "BF.7": 1/3.02, "BA.1_reinfect": 1/1.98, "BA.2_reinfect": 1/1.33},
            use_norm=True, use_max=False, use_codon_weights=True, logo=False, return_df=True)
    for use_ab_src in [['BA.5', 'BF.7'], ['BA.5', 'BF.7', 'BA.1_reinfect', 'BA.2_reinfect']]
    for use_neut in ["BA5_IC50", "XBB1_5_IC50"]
])
df.to_csv("tmp_data-" + xx + ".csv", index=None)

# Same table restricted to RBD sites 331-520 (inclusive).
df = df[df['site'].between(331, 520)]
df.to_csv("tmp_data-" + xx + "331_520.csv", index=None)

out mut: set()
out mut: set()
out mut: set()
out mut: set()
