## Sequence Generation

We provide a sequence-screening demo that visualizes the screening process and its results. BA.2.1 is used as the parent node, and the generation scale is set to 1 million sequences. For a quick start, we provide our generated sequences together with their predicted expression classification results. The calculation of the quantified antibody barrier score is also included.

In [None]:
# Fetch the GenPreMut repository; the later cells read their data and checkpoints from the GenPreMut/ directory.
!git clone https://github.com/Kevinatil/GenPreMut.git

In [None]:
# dmslogo is imported by the visualization cell at the end of the notebook.
!pip install dmslogo

In [None]:
# Unpack the bundled archives next to where they are stored (-d sets the extraction directory).
# NOTE(review): presumably these contain pred_BA.2.1.csv and the model2 escape_antibody_BA.2.1.csv read by later cells - confirm.
!unzip -d GenPreMut/data/predicts/df/ GenPreMut/data/predicts/df/pred_BA.2.1.zip
!unzip -d GenPreMut/ckpt/antibody_barrier_model/model2/ GenPreMut/ckpt/antibody_barrier_model/model2/escape_antibody_BA.2.1.zip

In [None]:
import os
import numpy as np
import pandas as pd

import pickle
from tqdm import tqdm

In [None]:
# Name of the parent RBD variant; it is substituted into the data and model file names below.
rbd_name = "BA.2.1"
# Root directory of the model checkpoints inside the cloned repository.
model_root = "GenPreMut/ckpt"
# Root directory of the data files inside the cloned repository.
data_root = "GenPreMut/data"

In [None]:
## calculate quantified antibody score

def get_escape_score(group_mean, seq, group, start_site=331):
    """Sum the per-residue escape scores of ``seq`` for one antibody group.

    Args:
        group_mean: Nested lookup indexed as ``group_mean[group][site][residue]``.
        seq: Sequence to score; the character at index ``i`` is looked up at
            site ``start_site + i``.
        group: Antibody group whose scores are used.
        start_site: Site number of the first character of ``seq``.

    Returns:
        The sum of the looked-up scores. A group/site/residue combination that
        is missing from ``group_mean`` contributes 0.
    """
    ret = 0
    for i, residue in enumerate(seq):
        try:
            s = group_mean[group][start_site + i][residue]
        except (KeyError, IndexError):
            # Only a missing entry counts as "no escape". A bare ``except``
            # would also swallow KeyboardInterrupt and make the long scoring
            # loop impossible to stop.
            s = 0
        ret += s
    return ret

def get_group_weight(group_dict, counter, key):
    """Return the size of group ``key`` divided by ``counter``.

    ``group_dict[key]`` is the collection of members of the group; the caller
    passes the total member count over all groups as ``counter``.
    """
    members = group_dict[key]
    group_size = len(members)
    return group_size / counter

def antibody_barrier(df_path, rbd_name, version):
    """Compute the quantified antibody barrier score of every sequence in a table.

    For each sequence the per-group escape score (``get_escape_score``) is
    weighted by the share of antibodies that belong to that group
    (``get_group_weight``) and the weighted values are summed over all groups.

    Args:
        df_path: Path of a CSV file with a ``sequence`` column.
        rbd_name: Variant name used in the model file names.
        version: Model version; selects the ``model{version}`` directory under
            the module-level ``model_root``.

    Returns:
        A list with one score per row of the CSV, in row order.
    """
    root = os.path.join(model_root, 'antibody_barrier_model/model{}'.format(version))

    df_ab = pd.read_csv(os.path.join(root, 'escape_antibody_{}.csv'.format(rbd_name)))
    groups = df_ab['group'].unique().tolist()
    groupby = df_ab[['antibody', 'group']].groupby('group')

    # Unique antibodies per group; only the group sizes are used below.
    group_dict = {group: groupby.get_group(group)['antibody'].unique() for group in groups}

    # NOTE: unpickling can execute arbitrary code - only load checkpoint files
    # from a trusted source.
    with open(os.path.join(root, 'group_mean_{}.pkl'.format(rbd_name)), 'rb') as f:
        group_mean = pickle.load(f)

    seqs = pd.read_csv(df_path)['sequence'].values

    # Total number of antibodies over all groups (denominator of the weights).
    counter = sum(len(antibodies) for antibodies in group_dict.values())
    print(counter)

    # The weights do not depend on the sequence, so compute them once.
    weights = [get_group_weight(group_dict, counter, group) for group in groups]

    scores = []
    for seq in tqdm(seqs):
        score_ = 0
        for group, weight in zip(groups, weights):
            score_ += weight * get_escape_score(group_mean, seq, group)
        scores.append(score_)

    return scores

In [None]:
# Prediction table of the chosen variant.
df_path = os.path.join(data_root, 'predicts/df/pred_{}.csv'.format(rbd_name))

# Load the table before adding columns to it: no earlier cell defines `df`,
# so assigning into it directly raises NameError on a fresh kernel.
df = pd.read_csv(df_path)

# antibody_barrier() returns one score per row of the same file, in row order.
df['antibody_barrier_model1'] = antibody_barrier(df_path, rbd_name, version=1)
df['antibody_barrier_model2'] = antibody_barrier(df_path, rbd_name, version=2)

# Write the scores back into the same CSV (the file is overwritten).
df.to_csv(df_path, index=False)

In [None]:
# Show the first rows of the table as a quick sanity check.
df.head()

In [None]:
## sequence screening
# Reference sequence that candidate sequences are compared against;
# index 0 corresponds to site 331 (see get_muts).
ori='NITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFSTFKCYGVSPTKLNDLCFTNVYADSFVIRGDEVRQIAPGQTGKIADYNYKLPDDFTGCVIAWNSNNLDSKVGGNYNYLYRLFRKSNLKPFERDISTEIYQAGSTPCNGVEGFNCYFPLQSYGFQPTNGVGYQPYRVVVLSFELLHAPATVCGPKKST'

def get_muts(seq):
    """List the substitutions of ``seq`` relative to ``ori``.

    Each substitution is formatted as ``<reference residue><site><new residue>``
    (e.g. ``G339D``), where the site is the 0-based index plus 331. The result
    follows sequence order.
    """
    return [
        '{}{}{}'.format(ori[idx], 331 + idx, residue)
        for idx, residue in enumerate(seq)
        if ori[idx] != residue
    ]

from collections import Counter

# A sequence passes the expression filter when its `expr_cls` value exceeds this.
thres=0.5

for kind in ['model1','model2']:
    df=pd.read_csv(os.path.join(data_root, 'predicts/df/pred_{}.csv'.format(rbd_name)))

    len_=len(df)

    # Filter 1: sequences classified as expressed.
    seq1=set(df[df['expr_cls']>thres]['sequence'].values)
    # Filter 2: the half of the table with the highest antibody barrier score.
    seq2=set(df.sort_values(['antibody_barrier_{}'.format(kind)],ascending=False)['sequence'].values[:len_//2])
    # Set iteration order changes between runs; sort so the outputs are reproducible.
    seqs=sorted(seq1.intersection(seq2))

    with open(os.path.join(data_root, 'predicts/df/seqs_final_{}_{}.txt'.format(rbd_name, kind)), 'w') as f:
        f.writelines('{}\n'.format(seq) for seq in seqs)
    print(len(seqs))

    # Collect the substitutions of every retained sequence and count each kind.
    all_muts=[]
    for seq in seqs:
        all_muts.extend(get_muts(seq))
    nums=Counter(all_muts)

    df=pd.DataFrame(list(nums.items()),columns=['mutation','frequency'])
    # Most frequent first; ties are ordered by mutation name so that the row
    # order (used later as a rank) does not vary between runs.
    df.sort_values(['frequency','mutation'],ascending=[False,True]).to_csv(os.path.join(data_root, 'predicts/df/mut_freq_{}_{}.csv'.format(rbd_name, kind)), index=False)

In [None]:
## give prediction
# Mutation frequency tables of the two models; the row order of each file is
# used as that model's ranking.
df1 = pd.read_csv(os.path.join(data_root, 'predicts/df/mut_freq_{}_{}.csv'.format(rbd_name, 'model1')))
df2 = pd.read_csv(os.path.join(data_root, 'predicts/df/mut_freq_{}_{}.csv'.format(rbd_name, 'model2')))

muts1=set(df1['mutation'].values)
muts2=set(df2['mutation'].values)

# Only mutations found by both models are ranked.
muts = muts1.intersection(muts2)

# Per-model order restricted to the shared mutations.
rank1 = [mut for mut in df1['mutation'].values if mut in muts]
rank2 = [mut for mut in df2['mutation'].values if mut in muts]

# 0-based position of each mutation (first occurrence), built once instead of
# scanning the whole list for every mutation.
pos1 = {}
for i, mut in enumerate(rank1):
    pos1.setdefault(mut, i)
pos2 = {}
for i, mut in enumerate(rank2):
    pos2.setdefault(mut, i)

# Final rank = mean of the two positions. Iterating in model1 order (not over
# the set) combined with the stable sort makes tie order reproducible.
rank_all = [[mut, (pos1[mut] + pos2[mut]) / 2] for mut in pos1]

rank_all.sort(key = lambda k: k[1])

pd.DataFrame(rank_all, columns = ['mutation', 'rank']).to_csv(os.path.join(data_root, 'predicts/df/mut_rank_{}.csv'.format(rbd_name)), index=False)

In [None]:
## visualize predicted mutation types
import dmslogo
import matplotlib.pyplot as plt
%matplotlib inline

# Target mutations per variant, written as '<site><new residue>'; their sites
# select which positions are drawn.
targets_all_dict = {
    'BA.2.1': ['452R','452Q','346T','486S','486P','460K','486V','446S','490S','339H','444T','445P','368I','478R'],
    'BA.5.1': ['346T','346S','444T','446R','460K','490S']
}

# rank top 100
# Subset of the targets that is highlighted in the logo.
targets_todraw_dict = {
    'BA.2.1': ['452R','452Q','346T','486S','486P','460K','486V','446S','490S','339H','444T'],
    'BA.5.1': ['346T','346S','444T','446R','460K','490S']
}

targets_all = targets_all_dict[rbd_name]
targets_todraw = targets_todraw_dict[rbd_name]

df = pd.read_csv(os.path.join(data_root, 'predicts/df/mut_freq_{}_model1.csv'.format(rbd_name)))
# Drop the first 16 rows of the frequency table.
# NOTE(review): presumably these are the parent variant's own substitutions
# relative to the reference sequence - confirm.
# .copy() makes the slice an independent frame so the column assignments below
# do not trigger SettingWithCopyWarning.
df = df.iloc[16:].copy()

# A mutation string is '<reference residue><site><new residue>'.
df['pos']=df['mutation'].apply(lambda x: int(x[1:-1]))
df['variable']=df['mutation'].apply(lambda x: x[-1])
df['wt_label']=df['mutation'].apply(lambda x: x[:-1])
df['color']='#808080'

# Total frequency per site; aggregate only the numeric column that is needed.
freq = df.groupby('pos')['frequency'].sum()
df['freq_sum'] = df['pos'].map(freq)


sites_todraw = [int(site[:-1]) for site in targets_all]

# Highlight the selected target mutations.
for site in targets_todraw:
    pos,mut = int(site[:-1]), site[-1]
    idx = df[(df['pos'] == pos)&(df['variable'] == mut)].index
    assert len(idx), 'target mutation {} not found in the frequency table'.format(site)
    df.loc[idx,'color'] = '#610345'



fig, ax = dmslogo.draw_logo(data=df[df["pos"].isin(sites_todraw)],
                        x_col='pos',
                        letter_col='variable',
                        letter_height_col='frequency',
                        color_col="color",
                        xtick_col="wt_label",
                        xlabel="Site",
                        ylabel="frequency",
                        axisfontscale= 2,
                        letterheightscale=1,
                )

# Uncomment to save the figure:
#fig.savefig('{}_freq.svg'.format(rbd_name), bbox_inches='tight')
fig.show()