In [None]:
import numpy as np
import matplotlib.pyplot as plt
from umap import UMAP
from scipy.stats import spearmanr, pearsonr
from sklearn import metrics
import glob, pymol, lmdb, os
import pickle as pkl
from skgstat.models import spherical
from scipy.linalg import solve
from math import dist
from scipy.spatial.distance import squareform
from sklearn import metrics
from venn import venn
import scipy.cluster.hierarchy as sch
import radialtree as rt
from adjustText import adjust_text

import plotly.offline as go_offline
import plotly.graph_objects as go

from scipy.interpolate import LinearNDInterpolator

def get_auroc(preds, obs):
    """Return the area under the ROC curve for scores `preds` against labels `obs`.

    The curve is built with sklearn's roc_curve (drop_intermediate=False, so
    every threshold is kept) and integrated with metrics.auc.
    """
    false_positive_rate, true_positive_rate, _ = metrics.roc_curve(
        obs, preds, drop_intermediate=False
    )
    return metrics.auc(false_positive_rate, true_positive_rate)

def get_auprc(preds, obs):
    """Return the area under the precision-recall curve for `preds` against `obs`.

    precision_recall_curve supplies the curve; metrics.auc integrates it with
    recall on the x axis and precision on the y axis.
    """
    precision_values, recall_values, _ = metrics.precision_recall_curve(obs, preds)
    return metrics.auc(recall_values, precision_values)


In [None]:
# Load the saved pmi_score object; .item() unwraps the 0-d object array that np.save wrote.
# NOTE(review): allow_pickle=True unpickles arbitrary objects - only load files from a trusted source.
pmi_score = np.load('pmi_score.npy', allow_pickle = True).item()
# Show the entry stored for antibody '106_107' as a quick sanity check.
pmi_score['106_107']

In [None]:
# Original batch-1 list, kept for reference:
#antibody_list_batch1 = ['1_2', '7_8', '11_12', '13_14', '15_16', '17_18', '19_20', '25_26', '31_32', '43_44', '51_52', '51_53', '86_87', '88_89', '92_93', '94_95', '96_97', '104_105', '106_107', '108_109', '110_111', '116_117', '118_119']
#Four antibodies (7/8, 51/53, 116/117, 118/119) are removed from the list used below.
antibody_list_batch1 = ['1_2', '11_12', '13_14', '15_16', '17_18', '19_20', '25_26', '31_32', '43_44', '51_52', '86_87', '88_89', '92_93', '94_95', '96_97', '104_105', '106_107', '108_109', '110_111']

# Identifiers of the clinical antibodies; they are the keys of clinical_antibody_fc defined in a later cell.
clinical_antibody_list = ['ADI', 'AMU', 'BAM', 'BEB', 'C135', 'C144', 'CAS', 'CIL', 'ETE', 'IMD', 'REG', 'ROM', 'SOT', 'TIX']
# Residue numbers treated as the ACE2-binding residues when scoring interface coverage.
ace2_binding_residue = np.array([417, 446, 447, 449, 453, 455, 456, 473, 475, 476, 477, 484, 486, 487, 489, 490, 493, 494, 495, 496, 498, 500, 501, 502, 503, 505])
# NOTE(review): allow_pickle=True unpickles arbitrary objects - only load files from a trusted source.
seq_dict_all = np.load('seq_dict_all_filtered.npy', allow_pickle = True).item()
# Array of all antibody identifiers; every boolean mask in this notebook is aligned to it.
antibody_list_all_with_clinical = np.load('antibody_list_all_with_clinical.npy')
print(len(antibody_list_all_with_clinical))

In [None]:
#Selected ACE2-binding residues (same numbers as ace2_binding_residue), written as a PyMOL selection command:
#select resi 417+446+447+449+453+455+456+473+475+476+477+484+486+487+489+490+493+494+495+496+498+500+501+502+503+505

In [None]:
# Partition every antibody into exactly one of three groups, as boolean masks
# aligned to antibody_list_all_with_clinical:
#   experimental_tested_boolean - listed in antibody_list_batch1
#   clinical_tested_boolean     - listed in clinical_antibody_list (and not in batch 1)
#   round2_boolean              - everything else
experimental_tested_boolean = np.isin(antibody_list_all_with_clinical, antibody_list_batch1)
clinical_tested_boolean = np.logical_and(
    np.isin(antibody_list_all_with_clinical, clinical_antibody_list),
    np.invert(experimental_tested_boolean),
)
round2_boolean = np.invert(np.logical_or(experimental_tested_boolean, clinical_tested_boolean))


#calculate the WT neutralization based on the ACE2 binding residues coverage
#running this section takes time
def neutralization_calculation(antibody_binding, unmutated_ace2_binding, ace2_binding_residue):
    """Return the share of ACE2-binding residues covered by an antibody interface.

    The numerator is the number of distinct values present in both
    `antibody_binding` and `unmutated_ace2_binding`; the denominator is
    len(ace2_binding_residue).
    """
    shared_residues = np.intersect1d(antibody_binding, unmutated_ace2_binding)
    n_shared = len(shared_residues)
    n_reference = len(ace2_binding_residue)
    return n_shared / n_reference

# Extract, for every non-clinical antibody and every docked conformation, the
# residue numbers of chain B that lie within 5 A of the antibody chains (H+L).
# PyMOL writes the interface atoms to a scratch PDB file, which is parsed and removed.
path = 'wt_haddock_results_balrefined_processed/'
bal_binding_residue_list = {}
pmi_score = {}
conformation_list = os.listdir(path)
interface_path = os.path.join(path, 'ligand_interface.pdb')  # scratch file reused for every conformation
for antibody_i in antibody_list_all_with_clinical[np.invert(clinical_tested_boolean)]:
    bal_binding_residue_list[antibody_i] = {}
    pmi_score[antibody_i] = {}
    # Require the separator after the id: a bare prefix test would let '1_2'
    # also match the files of an antibody such as '1_20'.
    file_prefix = antibody_i + '_'
    for conformation_i in conformation_list:
        if not conformation_i.startswith(file_prefix):
            continue
        cluster_i = conformation_i.split(file_prefix)[1].split('.pdb')[0]
        pymol.cmd.load(os.path.join(path, conformation_i))
        pymol.cmd.select('interface', '(chain B and not H*) within 5 of (chain H+L and not name H*)')
        pymol.cmd.save(interface_path, 'interface')
        pymol.cmd.delete('all')
        # 'with' guarantees the handle is closed before the file is removed.
        with open(interface_path) as interface_file:
            interface_input = interface_file.readlines()
        # NOTE(review): rows are parsed by whitespace and only 12-field rows are kept,
        # with field 5 taken as the residue number; PDB is a fixed-column format, so
        # rows whose columns run together would be skipped - confirm this is acceptable.
        interface_resi_temp = []
        for row_i in interface_input:
            fields = row_i.split()
            if len(fields) == 12:
                interface_resi_temp.append(fields[5])
        bal_binding_residue_list[antibody_i]['cluster' + cluster_i] = np.array(np.unique(np.array(interface_resi_temp)), dtype = int)
        os.remove(interface_path)
        #pmi_score[antibody_i]['cluster' + conformation_i.split('_')[2].split('.pdb')[0]] = float(open(path + '../haddock_results_balrefined/' + conformation_i.split('.pdb')[0] + '/' + 'PMI_log').readlines()[0].split('\n')[0])

#np.save('wt_bal_binding_residue_list.npy', bal_binding_residue_list)

# Reload the cached interface residues and per-cluster PMI scores; these replace
# the dictionaries built in the loop above.
# NOTE(review): allow_pickle=True unpickles arbitrary objects - only load trusted files.
bal_binding_residue_list = np.load('wt_bal_binding_residue_list.npy', allow_pickle = True).item()
pmi_score = np.load('pmi_score.npy', allow_pickle = True).item()

# WT score per non-clinical antibody: sum over clusters of
# (cluster PMI score) * (fraction of ACE2-binding residues covered by that cluster's interface).
wt_neutralization_score = []
for antibody_i in antibody_list_all_with_clinical[np.invert(clinical_tested_boolean)]:
    cluster_pmi = pmi_score[antibody_i]
    cluster_residues = bal_binding_residue_list[antibody_i]
    pmi_score_temp = np.array([cluster_pmi[cluster_i] for cluster_i in cluster_pmi], dtype = float)
    neutralization_score_temp = np.array(
        [neutralization_calculation(cluster_residues[cluster_i], ace2_binding_residue, ace2_binding_residue)
         for cluster_i in cluster_pmi],
        dtype = float,
    )
    wt_neutralization_score.append(np.sum(pmi_score_temp * neutralization_score_temp))

wt_neutralization_score = np.array(wt_neutralization_score, dtype = float)
#np.save('wt_neutralization_score.npy', wt_neutralization_score)

In [None]:
#experiment results
# Per-variant values for each clinical antibody; the kriging cells take log10 of
# these and use them as labels. 'NA' marks a missing measurement and is skipped there.
# NOTE(review): "fc" presumably stands for fold change and 1000 looks like an upper cap - confirm.
clinical_antibody_fc = {'ADI': {'Delta': 1.5, 'Omicron BA1': 108, 'Omicron BA5': 935},
                        'BAM': {'Delta': 1000, 'Omicron BA1': 1000, 'Omicron BA5': 686}, 
                        'BEB': {'Delta': 1, 'Omicron BA1': 1, 'Omicron BA5': 1}, 
                        'CAS': {'Delta': 0.7, 'Omicron BA1': 1000, 'Omicron BA5': 1000}, 
                        'CIL': {'Delta': 2.1, 'Omicron BA1': 1000, 'Omicron BA5': 9.4}, 
                        'ETE': {'Delta': 0.5, 'Omicron BA1': 414, 'Omicron BA5': 444}, 
                        'IMD': {'Delta': 2.1, 'Omicron BA1': 1000, 'Omicron BA5': 633}, 
                        'SOT': {'Delta': 1.3, 'Omicron BA1': 3.8, 'Omicron BA5': 16}, 
                        'TIX': {'Delta': 1, 'Omicron BA1': 306, 'Omicron BA5': 1000}, 
                        'REG': {'Delta': 28, 'Omicron BA1': 1000, 'Omicron BA5': 1000}, 
                        'AMU': {'Delta': 0.6, 'Omicron BA1': 136, 'Omicron BA5': 116}, 
                        'ROM': {'Delta': 'NA', 'Omicron BA1': 0.8, 'Omicron BA5': 64}, 
                        'C135': {'Delta': 0.4, 'Omicron BA1': 1000, 'Omicron BA5': 1000}, 
                        'C144': {'Delta': 2.5, 'Omicron BA1': 1000, 'Omicron BA5': 1000}
                       }

# Measured EC50 per patient antibody against WT and three variants.
# The commented-out entries are the four antibodies removed from antibody_list_batch1.
# NOTE(review): 250 presumably marks "no neutralization within the tested range" - confirm the cap and units.
patient_antibody_ec50 = {
                        '1_2': {'WT': 5.698, 'Delta': 0.5436, 'Omicron BA1': 250, 'Omicron BA5': 250},
                        #'7_8': {'WT': 250, 'Delta': 20.33, 'Omicron BA1': 250, 'Omicron BA5': 250},
                        '13_14': {'WT': 250, 'Delta': 250, 'Omicron BA1': 250, 'Omicron BA5': 250},
                        '15_16': {'WT': 250, 'Delta': 250, 'Omicron BA1': 250, 'Omicron BA5': 250},
                        '17_18': {'WT': 250, 'Delta': 250, 'Omicron BA1': 250, 'Omicron BA5': 250},
                        '19_20': {'WT': 250, 'Delta': 136.7, 'Omicron BA1': 250, 'Omicron BA5': 250},
                        '25_26': {'WT': 69.64, 'Delta': 133.8, 'Omicron BA1': 250, 'Omicron BA5': 250},
                        '31_32': {'WT': 0.4067, 'Delta': 0.05577, 'Omicron BA1': 250, 'Omicron BA5': 250},
                        '43_44': {'WT': 250, 'Delta': 250, 'Omicron BA1': 250, 'Omicron BA5': 250},
                        #'118_119': {'WT': 250, 'Delta': 250, 'Omicron BA1': 250, 'Omicron BA5': 10.17},
                        '51_52': {'WT': 250, 'Delta': 250, 'Omicron BA1': 250, 'Omicron BA5': 250},
                        #'51_53': {'WT': 250, 'Delta': 202.8, 'Omicron BA1': 250, 'Omicron BA5': 250},
                        '86_87': {'WT': 250, 'Delta': 250, 'Omicron BA1': 250, 'Omicron BA5': 250},
                        '88_89': {'WT': 17.89, 'Delta': 7.167, 'Omicron BA1': 250, 'Omicron BA5': 5.296},
                        '92_93': {'WT': 250, 'Delta': 250, 'Omicron BA1': 250, 'Omicron BA5': 250},
                        '94_95': {'WT': 250, 'Delta': 250, 'Omicron BA1': 250, 'Omicron BA5': 250},
                        '96_97': {'WT': 250, 'Delta': 250, 'Omicron BA1': 250, 'Omicron BA5': 250},
                        '104_105': {'WT': 250, 'Delta': 250, 'Omicron BA1': 250, 'Omicron BA5': 250},
                        '108_109': {'WT': 3.858, 'Delta': 0.4388, 'Omicron BA1': 31.27, 'Omicron BA5': 250},
                        '110_111': {'WT': 250, 'Delta': 250, 'Omicron BA1': 250, 'Omicron BA5': 250},
                        #'116_117': {'WT': 250, 'Delta': 250, 'Omicron BA1': 250, 'Omicron BA5': 250},
                        '106_107': {'WT': 5.542, 'Delta': 1.341, 'Omicron BA1': 10.23, 'Omicron BA5': 146.6},
                        '11_12': {'WT': 1.587, 'Delta': 250, 'Omicron BA1': 2.276, 'Omicron BA5': 250}
}

# Binding-inhibition readout per patient antibody: a number where a value was
# measured, the string 'undetectable' otherwise (so the values have mixed types).
# The commented-out keys are the four antibodies removed from antibody_list_batch1.
# NOTE(review): units of the numeric values are not stated in this notebook - confirm.
single_antibody_binding_inhibition = {
        '1_2':     32.14, 
        #'7_8':     'undetectable', 
        '11_12':   'undetectable', 
        '13_14':   'undetectable', 
        '15_16':   'undetectable', 
        '17_18':   'undetectable', 
        '19_20':   32.84, 
        '25_26':   34.76, 
        '31_32':   5.667,  
        '43_44':   'undetectable', 
        '51_52':   'undetectable',
        #'51_53':   'undetectable',
        '86_87':   'undetectable', 
        '88_89':   30.68, 
        '92_93':   'undetectable', 
        '94_95':   'undetectable', 
        '96_97':   'undetectable',
        '104_105': 'undetectable', 
        '106_107': 2.997, 
        '108_109': 104.5, 
        '110_111': 'undetectable',
        #'116_117': 'undetectable',
        #'118_119': 'undetectable'
        }

In [None]:
# Candidate language-model runs; the trailing [4] selects the first "seqIndiv" run.
language_model_name = ['antibody_mlm_seqConcate_transformer_21-09-04-23-40-31_779165', 
                       'antibody_mlm_seqConcate_transformer_21-09-04-23-41-19_756331', 
                       'antibody_mlm_seqConcate_transformer_21-09-04-23-41-50_667473', 
                       'antibody_mlm_seqConcate_transformer_21-09-04-23-44-28_263658', 
                       'antibody_mlm_seqIndiv_transformer_21-09-04-23-44-49_442174', 
                       'antibody_mlm_seqIndiv_transformer_21-09-04-23-45-49_733608', 
                       'antibody_mlm_seqIndiv_transformer_21-09-04-23-53-52_523451', 
                       'antibody_mlm_seqIndiv_transformer_21-09-05-00-19-31_090981'][4]
# 2-D UMAP with a fixed random_state so the projection is repeatable.
umap_2d = UMAP(n_components=2, init='random', random_state=0, metric='euclidean')
# NOTE: this lambda reads clinical_antibody_label, which is only assigned in the
# kriging cells further down; calling model() before those cells run raises NameError.
model = lambda h: spherical(h, len(clinical_antibody_label) + 2, 1, 0.0)
# Returns the non-zero entries of the upper triangle of `a` (zero-valued entries are dropped).
unsquareform = lambda a: a[np.nonzero(np.triu(a))]
# NOTE(review): allow_pickle=True unpickles arbitrary objects - only load trusted files.
heavy_chain_embedding = np.load('embedding_h_' + language_model_name + '.npy', allow_pickle = True).item()
light_chain_embedding = np.load('embedding_l_' + language_model_name + '.npy', allow_pickle = True).item()
# One row per antibody: columns 0-767 hold the heavy-chain embedding, 768-1535 the light-chain embedding.
language_model_embedding = np.zeros((len(antibody_list_all_with_clinical), 768 * 2))
for i in range(len(antibody_list_all_with_clinical)):
    antibody_i = antibody_list_all_with_clinical[i]
    language_model_embedding[i, 0:768] = heavy_chain_embedding[antibody_i]
    language_model_embedding[i, 768:] = light_chain_embedding[antibody_i]
# Rows of proj_2d follow the order of antibody_list_all_with_clinical.
proj_2d = umap_2d.fit_transform(language_model_embedding)
# Commented-out code that produced the cached pairwise-distance file loaded by the
# kriging cells; only the upper triangle (i < j) is filled.
#chain_similarity = np.zeros((len(antibody_list_all_with_clinical), len(antibody_list_all_with_clinical)))
#for i in range(len(antibody_list_all_with_clinical) - 1):
#    for j in range(i + 1, len(antibody_list_all_with_clinical)):
#        chain_similarity[i, j] = dist(language_model_embedding[i], language_model_embedding[j])
#np.save('language model chain similarity.npy', chain_similarity)
#np.save('umap_projection_2d.npy', proj_2d)

In [None]:
#Figure S2A
# UMAP projection of all antibodies: clinical antibodies as labelled black stars,
# patient antibodies as gray dots.
# The original cell opened a 16x12 in, 600 dpi figure and immediately replaced it
# with the 8x8 one below; that first figure stayed empty, so it is no longer created.

# Not used within this cell; kept in case other cells read them.
clinical_shift_x = np.zeros(np.sum(experimental_tested_boolean))
clinical_shift_y = np.zeros(len(clinical_shift_x))
s = 40

plt.figure(figsize=(8, 8), dpi=300)
# Patient antibodies = tested batch + round 2; the inverse of this mask is the clinical set.
boolean_temp = np.logical_or(round2_boolean, experimental_tested_boolean)
plt.scatter(x = proj_2d[np.invert(boolean_temp), 0], y = proj_2d[np.invert(boolean_temp), 1], s = s * 4, label = 'Clinical antibodies (14)', c = 'black', marker = '*')
plt.scatter(x = proj_2d[boolean_temp, 0], y = proj_2d[boolean_temp, 1], s = s / 2, label = 'Patient antibodies (1366)', c = 'gray')
plt.xlabel('UMAP 1', fontsize = 25)
plt.ylabel('UMAP 2', fontsize = 25)

# Label each clinical antibody with the part of its id before '_', offset by (+0.5, +0.5);
# the coordinates are sliced once instead of on every iteration.
clinical_xy = proj_2d[clinical_tested_boolean]
texts = []
for i, txt in enumerate(antibody_list_all_with_clinical[clinical_tested_boolean]):
    point_x, point_y = clinical_xy[i, 0], clinical_xy[i, 1]
    texts.append(plt.annotate(txt.split('_')[0], xy=(point_x, point_y), xytext=(point_x + 0.5, point_y + 0.5),  fontsize = 10, c = 'black'))

# Nudge the labels so they do not overlap.
adjust_text(texts)
plt.legend()
plt.title('All antibodies', fontsize = 30)
#plt.savefig('FigureS2A_version1.eps', dpi = 1000, format = 'eps')
plt.savefig('FigureS2A_version1.png', dpi = 600)


In [None]:
#Figure S2A version 2
# Same projection as Figure S2A, with the experimentally tested patient antibodies
# highlighted and labelled in red.
# The original cell opened an 8x6 in, 600 dpi figure and immediately replaced it
# with the 8x8 one below; that first figure stayed empty, so it is no longer created.

# Not used within this cell; kept in case other cells read them.
clinical_shift_x = np.zeros(np.sum(experimental_tested_boolean))
clinical_shift_y = np.zeros(len(clinical_shift_x))

s = 40

plt.figure(figsize=(8, 8), dpi=300)

# Patient antibodies = tested batch + round 2; the inverse of this mask is the clinical set.
boolean_temp = np.logical_or(round2_boolean, experimental_tested_boolean)
plt.scatter(x = proj_2d[np.invert(boolean_temp), 0], y = proj_2d[np.invert(boolean_temp), 1], s = s * 4, label = 'Clinical antibodies (14)', c = 'black', marker = '*')
plt.scatter(x = proj_2d[boolean_temp, 0], y = proj_2d[boolean_temp, 1], s = s / 2, label = 'Patient antibodies (1366)', c = 'gray')
plt.scatter(x = proj_2d[experimental_tested_boolean, 0], y = proj_2d[experimental_tested_boolean, 1], s = s / 4, label = 'Patient antibodies (19 tested)', c = 'r', marker = 'o')

plt.xlabel('UMAP 1', fontsize = 25)
plt.ylabel('UMAP 2', fontsize = 25)

# Label both groups with the part of the id before '_', offset by (+0.5, +0.5);
# coordinates are sliced once per group instead of on every iteration.
texts = []
for group_mask, label_color in ((clinical_tested_boolean, 'black'), (experimental_tested_boolean, 'red')):
    group_xy = proj_2d[group_mask]
    for i, txt in enumerate(antibody_list_all_with_clinical[group_mask]):
        point_x, point_y = group_xy[i, 0], group_xy[i, 1]
        texts.append(plt.annotate(txt.split('_')[0], xy=(point_x, point_y), xytext=(point_x + 0.5, point_y + 0.5),  fontsize = 10, c = label_color))

# Nudge the labels so they do not overlap.
adjust_text(texts)
plt.legend()
plt.title('All antibodies', fontsize = 30)
#plt.savefig('FigureS2A_version2.eps', dpi = 600, format = 'eps')
plt.savefig('FigureS2A_version2.png', dpi = 600)

# Export one CSV row per non-clinical antibody: UMAP coordinates, WT score and the
# kriging predictions for the three variants.
# NOTE: kriging_prediction_results_delta / _omicron_ba1 / _omicron_ba5 are produced
# by the kriging cells further down, so those cells must have been run first.
export_mask = np.invert(clinical_tested_boolean)
export_ids = antibody_list_all_with_clinical[export_mask]
# Taken from proj_2d directly; the original read `x` and `y`, which are only
# assigned in a later cell (as these same two columns).
export_x = proj_2d[export_mask, 0]
export_y = proj_2d[export_mask, 1]

# The context manager closes the file even if a row fails to format.
with open('covid antibody ml prediction.csv', 'w+') as output:
    output.write('Antibody ID,X,Y,WT neutralization,Fold improvement(Delta),Fold improvement(Omicron BA1),Fold improvement(Omicron BA5)\n')
    for i in range(len(export_ids)):
        output.write(export_ids[i] + ',' +
                     str(export_x[i]) + ',' +
                     str(export_y[i]) + ',' +
                     str(wt_neutralization_score[i]) + ',' +
                     str(kriging_prediction_results_delta[i]) + ',' +
                     str(kriging_prediction_results_omicron_ba1[i]) + ',' +
                     str(kriging_prediction_results_omicron_ba5[i]) + '\n')

In [None]:
# Reload the cached WT scores (one entry per non-clinical antibody, in the order of
# antibody_list_all_with_clinical[np.invert(clinical_tested_boolean)]).
wt_neutralization_score = np.load('wt_neutralization_score.npy')
# Ids, UMAP coordinates and WT scores of the experimentally tested antibodies.
antibody_list_wt_neutralization_visualization = antibody_list_all_with_clinical[experimental_tested_boolean]
x_wt_neutralization_experimental_tested = proj_2d[experimental_tested_boolean, 0]
y_wt_neutralization_experimental_tested = proj_2d[experimental_tested_boolean, 1]
# The tested mask is first restricted to non-clinical rows so it lines up with wt_neutralization_score.
z_wt_neutralization_experimental_tested = wt_neutralization_score[experimental_tested_boolean[np.invert(clinical_tested_boolean)]]
    

In [None]:
# Piecewise-linear interpolation of the tested antibodies' WT scores over a
# 50 x 50 grid spanning their UMAP bounding box.
grid_axis_x = np.linspace(min(x_wt_neutralization_experimental_tested), max(x_wt_neutralization_experimental_tested), num=50)
grid_axis_y = np.linspace(min(y_wt_neutralization_experimental_tested), max(y_wt_neutralization_experimental_tested), num=50)
X, Y = np.meshgrid(grid_axis_x, grid_axis_y)  # 2D grid for interpolation

# One (x, y) row per tested antibody.
sample_points = np.column_stack((x_wt_neutralization_experimental_tested, y_wt_neutralization_experimental_tested))
interp = LinearNDInterpolator(sample_points, z_wt_neutralization_experimental_tested)
Z = interp(X, Y)

In [None]:
# 3D scatter of the experimentally tested antibodies: UMAP x/y on the horizontal
# axes, WT score on the vertical axis and as the colour.
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.scatter(x_wt_neutralization_experimental_tested, y_wt_neutralization_experimental_tested, z_wt_neutralization_experimental_tested, c=z_wt_neutralization_experimental_tested, cmap = 'rainbow', marker='o', edgecolors='k')

In [None]:
# Same 3D scatter as the previous cell, with the linearly interpolated surface
# (X, Y, Z from the interpolation cell) drawn over it.
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.scatter(x_wt_neutralization_experimental_tested, y_wt_neutralization_experimental_tested, z_wt_neutralization_experimental_tested, c=z_wt_neutralization_experimental_tested, cmap = 'rainbow', marker='o', edgecolors='k')
surf = ax.plot_surface(X, Y, Z, cmap='rainbow',
                       linewidth=0, antialiased=True)

In [None]:
# Sanity check: list the ids selected by the experimentally-tested and clinical masks.
print(antibody_list_all_with_clinical[experimental_tested_boolean])
print(antibody_list_all_with_clinical[clinical_tested_boolean])


In [None]:
#Compare the WT neutralization prediction to the experimental results.
#For a clearer colour scale, the scores are also truncated to their 5th-95th percentile range below.
wt_neutralization_score = np.load('wt_neutralization_score.npy')

# wt_neutralization_score is indexed like the non-clinical antibody list.
non_clinical_ids = antibody_list_all_with_clinical[np.invert(clinical_tested_boolean)]

# Collect prediction, WT EC50 and binding inhibition for every antibody with an EC50 record.
antibody_list_tested = list(patient_antibody_ec50.keys())
our_prediction = [wt_neutralization_score[np.where(non_clinical_ids == antibody_i)[0][0]]
                  for antibody_i in antibody_list_tested]
ec50 = [patient_antibody_ec50[antibody_i]['WT'] for antibody_i in antibody_list_tested]
binding_inhibition = [single_antibody_binding_inhibition[antibody_i] for antibody_i in antibody_list_tested]

antibody_list_tested = np.array(antibody_list_tested)
wt_prediction = np.array(our_prediction)
ec50 = np.array(ec50)
binding_inhibition = np.array(binding_inhibition)

# Sort everything by predicted score, highest first.
order = np.argsort(-wt_prediction)
antibody_list_tested = antibody_list_tested[order]
neutralization_prediction = wt_prediction[order]
ec50 = ec50[order]
binding_inhibition = binding_inhibition[order]

lower_bound = np.percentile(wt_neutralization_score, 5)
upper_bound = np.percentile(wt_neutralization_score, 95)
print(lower_bound)

# Clamp a copy of the scores to the [5th, 95th] percentile range (lower bound first, then upper).
wt_neutralization_score_truncated = wt_neutralization_score.copy()
wt_neutralization_score_truncated[wt_neutralization_score_truncated <= lower_bound] = lower_bound
wt_neutralization_score_truncated[wt_neutralization_score_truncated >= upper_bound] = upper_bound
print(np.max(wt_neutralization_score_truncated))
print(np.min(wt_neutralization_score_truncated))

In [None]:
# Overlay of patient antibodies with a predicted WT score above 0.6 (red scale)
# on the clinical antibodies (black stars); saved as EPS.
s = 40
plt.figure(figsize=(8, 6), dpi=600)

# wt_neutralization_score has one entry per non-clinical antibody, so the threshold
# is scattered back onto the full antibody list through the non-clinical mask.
# The original appended 14 zeros and sliced [:1366], which is only correct when the
# clinical antibodies are exactly the last 14 rows; this form gives the same result
# in that layout and stays aligned in any other.
non_clinical_mask = np.invert(clinical_tested_boolean)
subset = np.zeros(len(antibody_list_all_with_clinical), dtype = bool)
subset[non_clinical_mask] = wt_neutralization_score > 0.6
print(len(subset))

print(non_clinical_mask)
intersect = np.logical_and(subset, non_clinical_mask)
print(intersect)

# Coordinates are sliced once and reused for the scatter and the labels.
highlight_xy = proj_2d[intersect]
clinical_xy = proj_2d[clinical_tested_boolean]

plt.scatter(x = highlight_xy[:, 0], y = highlight_xy[:, 1], s = s / 3,
            label = 'Patients antibody', c = wt_neutralization_score_truncated[intersect[non_clinical_mask]], marker = 'o', cmap = 'Reds', alpha = 1, vmin=0.6, vmax=0.8)

plt.scatter(x = clinical_xy[:, 0], y = clinical_xy[:, 1], s = s, label = 'Clinical antibodies', c = 'black', marker = '*')
plt.xlabel('UMAP 1', fontsize = 25, fontname='Arial')
plt.ylabel('UMAP 2', fontsize = 25, fontname='Arial')

texts = []

for i, txt in enumerate(antibody_list_all_with_clinical[clinical_tested_boolean]):
    texts.append(plt.annotate(txt.split('_')[0], xy=(clinical_xy[i, 0], clinical_xy[i, 1]), xytext=(clinical_xy[i, 0]+0.5, clinical_xy[i, 1] + 0.5),  fontsize = 5, c = 'black'))

# The list is reset here, so only the blue patient labels are passed to adjust_text;
# the clinical labels keep their fixed (+0.5, +0.5) offset, as in the original.
texts = []
for i, txt in enumerate(antibody_list_all_with_clinical[intersect]):
    texts.append(plt.annotate(txt.split('_')[0], xy=(highlight_xy[i, 0], highlight_xy[i, 1]), xytext=(highlight_xy[i, 0], highlight_xy[i, 1]+.1), fontsize = 5, c = 'blue'))

adjust_text(texts)

plt.savefig('FigS2-PredH-vs-Clinical.eps', format='eps', dpi = 600)

In [None]:
# Overlay of patient antibodies with a predicted WT score above 0.6 (red scale)
# on the clinical antibodies (black stars); saved as PNG.
s = 40
plt.figure(figsize=(8, 6), dpi=600)

# wt_neutralization_score has one entry per non-clinical antibody, so the threshold
# is scattered back onto the full antibody list through the non-clinical mask.
# The original appended 14 zeros and sliced [:1366], which is only correct when the
# clinical antibodies are exactly the last 14 rows; this form gives the same result
# in that layout and stays aligned in any other.
non_clinical_mask = np.invert(clinical_tested_boolean)
subset = np.zeros(len(antibody_list_all_with_clinical), dtype = bool)
subset[non_clinical_mask] = wt_neutralization_score > 0.6
print(len(subset))

print(non_clinical_mask)
intersect = np.logical_and(subset, non_clinical_mask)
print(intersect)

# Coordinates are sliced once and reused for the scatter and the labels.
highlight_xy = proj_2d[intersect]
clinical_xy = proj_2d[clinical_tested_boolean]

plt.scatter(x = highlight_xy[:, 0], y = highlight_xy[:, 1], s = s / 3,
            label = 'Patients antibody', c = wt_neutralization_score_truncated[intersect[non_clinical_mask]], marker = 'o', cmap = 'Reds', alpha = 1, vmin=0.6, vmax=0.8)

plt.scatter(x = clinical_xy[:, 0], y = clinical_xy[:, 1], s = s, label = 'Clinical antibodies', c = 'black', marker = '*')
plt.xlabel('UMAP 1', fontsize = 25, fontname='Arial')
plt.ylabel('UMAP 2', fontsize = 25, fontname='Arial')

texts = []

for i, txt in enumerate(antibody_list_all_with_clinical[clinical_tested_boolean]):
    texts.append(plt.annotate(txt.split('_')[0], xy=(clinical_xy[i, 0], clinical_xy[i, 1]), xytext=(clinical_xy[i, 0]+0.5, clinical_xy[i, 1] + 0.5),  fontsize = 5, c = 'black'))

# The list is reset here, so only the blue patient labels are passed to adjust_text;
# the clinical labels keep their fixed (+0.5, +0.5) offset, as in the original.
texts = []
for i, txt in enumerate(antibody_list_all_with_clinical[intersect]):
    texts.append(plt.annotate(txt.split('_')[0], xy=(highlight_xy[i, 0], highlight_xy[i, 1]), xytext=(highlight_xy[i, 0], highlight_xy[i, 1]+.1), fontsize = 5, c = 'blue'))

adjust_text(texts)

plt.savefig('FigS2-PredH-vs-Clinical.png', dpi = 600)

In [None]:
#Figure S2b
# UMAP of the non-clinical antibodies coloured by the truncated WT score;
# only the tested antibodies whose id starts with '106' or '88' are labelled.
s = 40
plt.figure(figsize=(8, 6), dpi=600)
non_clinical_xy = proj_2d[np.invert(clinical_tested_boolean)]
plt.scatter(x = non_clinical_xy[:, 0], y = non_clinical_xy[:, 1], s = s / 3,
            label = 'Patients antibody', c = wt_neutralization_score_truncated, marker = 'o', cmap = 'rainbow', alpha = 1)

#plt.scatter(x = proj_2d[clinical_tested_boolean, 0], y = proj_2d[clinical_tested_boolean, 1], s = s, label = 'Clinical antibody', c = 'black', marker = '+')
plt.xlabel('UMAP 1', fontsize = 25)
plt.ylabel('UMAP 2', fontsize = 25)

tested_xy = proj_2d[experimental_tested_boolean]
texts = []
for i, txt in enumerate(antibody_list_all_with_clinical[experimental_tested_boolean]):
    # '106' is checked before '88', the order the original used; an id matches at most one.
    for labelled_prefix in ('106', '88'):
        if txt.split('_')[0] == labelled_prefix:
            texts.append(plt.annotate(txt.replace('_', '/'), (tested_xy[i, 0], tested_xy[i, 1]), fontsize = 10, c = 'black'))

adjust_text(texts)
plt.title('Predicted WT neutralization', fontsize = 25)
plt.colorbar()
plt.tight_layout()
#plt.savefig('FigureS2b-WT_prediction.eps', dpi = 600, format = 'eps')
plt.savefig('FigureS2b-WT_prediction.png', dpi = 600)

In [None]:
def distance(x1,y1,x2,y2):
    """Return the Euclidean distance between the points (x1, y1) and (x2, y2)."""
    delta_x = x1 - x2
    delta_y = y1 - y2
    return np.sqrt(delta_x**2 + delta_y**2)

def idw_npoint(xz,yz,n_point,p,xs=None,ys=None,zs=None):
    """Inverse-distance-weighted estimate at (xz, yz) from nearby sample points.

    A square search block centred on (xz, yz) starts at half-width 20 and grows
    by 10 per step until it holds MORE than `n_point` samples, or every usable
    sample. The estimate is the mean of the block's values weighted by
    1 / distance**p; if a sample coincides with the query point, the value of
    the first such sample is returned directly.

    Parameters
    ----------
    xz, yz : float
        Coordinates of the interpolation point.
    n_point : int
        The search block is enlarged until it contains more than this many samples.
    p : float
        Power applied to the distance in the weights.
    xs, ys, zs : array-like, optional
        Sample coordinates and values. When omitted, the notebook-level
        variables `x`, `y` and `z` are used, as in the original implementation.

    Returns
    -------
    Interpolated value at (xz, yz).

    Raises
    ------
    ValueError
        If the query point is not finite or no sample has finite coordinates.
    """
    # Fall back to the notebook globals only when no explicit samples are given.
    xs = np.asarray(x if xs is None else xs)
    ys = np.asarray(y if ys is None else ys)
    zs = np.asarray(z if zs is None else zs)

    if not (np.isfinite(xz) and np.isfinite(yz)):
        raise ValueError('idw_npoint: the interpolation point must be finite')
    # Only samples with finite coordinates can ever fall inside a search block.
    n_reachable = np.count_nonzero(np.isfinite(xs) & np.isfinite(ys))
    if n_reachable == 0:
        raise ValueError('idw_npoint: no sample points with finite coordinates')

    r=10 #block half-width; incremented before first use, so the first block uses 20
    while True:
        r +=10 # add 10 units each iteration
        in_block=(xs>=xz-r) & (xs<=xz+r) & (ys>=yz-r) & (ys<=yz+r)
        nf=np.count_nonzero(in_block) #number of points in the block
        # Stop once more than n_point samples are found. The second condition
        # prevents an endless loop when there are not that many samples at all.
        if nf>n_point or nf==n_reachable:
            break

    x_block=xs[in_block]
    y_block=ys[in_block]
    z_block=zs[in_block]

    d=np.sqrt((xz-x_block)**2+(yz-y_block)**2)

    # A sample located exactly at the query point would get an infinite weight;
    # return its value instead (first match, as in the original).
    coincident=np.flatnonzero(~(d>0))
    if coincident.size>0:
        return z_block[coincident[0]]

    w_list=1/(d**p)
    z_idw=np.dot(z_block,w_list)/np.sum(w_list) # weighted mean via dot product
    return z_idw

In [None]:
# Sample set for the IDW surface: UMAP coordinates and WT score of every
# non-clinical antibody. idw_npoint reads x, y and z as notebook-level variables.
non_clinical_mask = np.invert(clinical_tested_boolean)
x = proj_2d[non_clinical_mask, 0]
y = proj_2d[non_clinical_mask, 1]
z = wt_neutralization_score
n=200 #number of interpolation points along each axis

x_min, x_max = min(x), max(x)
y_min, y_max = min(y), max(y)
w=x_max-x_min #width
h=y_max-y_min #length
wn=w/n #x interval
hn=h/n #y interval

#grid origin
y_init=y_min
x_init=x_min

#interpolation coordinates: n points per axis starting at the minimum, spaced by wn / hn
x_idw_list=[x_init+wn*step for step in range(n)]
y_idw_list=[y_init+hn*step for step in range(n)]

#z_head[row][col] is the IDW estimate at (x_idw_list[col], y_idw_list[row])
#min. point=5, p=1.5
z_head=[[idw_npoint(grid_x,grid_y,5,1.5) for grid_x in x_idw_list]
        for grid_y in y_idw_list]

In [None]:
#Figure 2D
# IDW surface of the predicted WT score over the UMAP plane, with the
# experimentally tested antibodies overlaid as 3D markers coloured by their score.
# The figure is written as a PNG (the earlier "write as eps" note did not match the code).
fig=go.Figure()
fig.add_trace(go.Surface(z=z_head,x=x_idw_list,y=y_idw_list, colorscale = 'Rainbow'))
fig.add_scatter3d(x = x_wt_neutralization_experimental_tested, 
                  y = y_wt_neutralization_experimental_tested, 
                  z = z_wt_neutralization_experimental_tested, 
                  mode = 'markers', marker = {'size': 4, 
                                              'color': z_wt_neutralization_experimental_tested})
# Fix the axis ranges to the sampled extent, flatten the z axis relative to x/y and blank the axis titles.
fig.update_layout(scene=dict(aspectratio=dict(x=5, y=5, z=2), 
                             xaxis = dict(range=[x_min,x_max],), 
                             yaxis = dict(range=[y_min,y_max]), 
                             zaxis = dict(range=[0.05, 0.7]), 
                             xaxis_title = '', 
                             yaxis_title = '', 
                             zaxis_title = ''))

# Camera position used for the exported view.
camera = dict(eye = dict(x = 4, y = -8, z = 3))
fig.update_layout(scene_camera = camera)
fig.write_image('Figure2D.png', height = 1200, width = 1600, scale = 2)
#go_offline.plot()
#go_offline.plot(fig,filename='3d wt neutralization.html',validate=True, auto_open=False)

In [None]:
# Kriging-style prediction for one variant: each non-clinical antibody receives a
# weighted combination of the clinical antibodies' log10 values, with weights
# obtained from embedding distances.
variant_i = 'Delta'

# Clinical antibodies with a measurement for this variant ('NA' entries are skipped).
clinical_antibody_labeled_list = [antibody_i for antibody_i in clinical_antibody_list
                                  if not clinical_antibody_fc[antibody_i][variant_i] == 'NA']
clinical_antibody_label = np.array([clinical_antibody_fc[antibody_i][variant_i]
                                    for antibody_i in clinical_antibody_labeled_list], dtype = float)
clinical_antibody_label = np.log10(clinical_antibody_label)
# Row of each labelled clinical antibody in antibody_list_all_with_clinical.
clinical_antibody_labeled_index = [np.where(antibody_list_all_with_clinical == antibody_i)[0][0]
                                   for antibody_i in clinical_antibody_labeled_list]

#language embedding based chain similarity
chain_similarity = np.load('language model chain similarity.npy')
# Spherical model with range = (number of labelled antibodies + 2), sill 1, nugget 0.
model = lambda h: spherical(h, len(clinical_antibody_label) + 2, 1, 0.0)
# Non-zero upper-triangle entries of `a`, flattened.
# NOTE(review): a genuine zero distance between two clinical antibodies, or clinical
# indices that are not in increasing order, would shorten this vector and make
# squareform fail - confirm neither can occur.
unsquareform = lambda a: a[np.nonzero(np.triu(a))]

# The clinical-vs-clinical system matrix does not depend on the antibody being
# predicted, so it is assembled once here instead of once per antibody.
labeled_index = np.array(clinical_antibody_labeled_index, dtype = int)
n_labeled = len(labeled_index)
kriging_similarity_matrix = chain_similarity[labeled_index][:, labeled_index]
kriging_similarity_matrix = squareform(model(unsquareform(kriging_similarity_matrix)))
# Extra column and row of ones (with a zero corner) add the constraint that the weights sum to 1.
kriging_similarity_matrix = np.concatenate((kriging_similarity_matrix, np.zeros((n_labeled, 1)) + 1), axis = 1)
last_row = np.zeros((1, n_labeled + 1)) + 1
last_row[0, -1] = 0
kriging_similarity_matrix = np.concatenate((kriging_similarity_matrix, last_row), axis = 0)

kriging_prediction_results = []
for antibody_i in antibody_list_all_with_clinical[np.invert(clinical_tested_boolean)]:
    query_index = np.where(antibody_i == antibody_list_all_with_clinical)[0][0]
    # Right-hand side: stored distances from this antibody to each labelled clinical
    # antibody, plus the 1 that belongs to the sum-to-one constraint.
    # NOTE(review): these distances are used raw, whereas the matrix entries are passed
    # through model(); kept exactly as in the original - confirm this is intended.
    variance = np.concatenate((chain_similarity[query_index, labeled_index], [1]))
    weights = solve(kriging_similarity_matrix, variance)
    # The last weight is the Lagrange multiplier and is not applied to the labels.
    kriging_prediction_results.append(clinical_antibody_label.dot(weights[:-1]))
kriging_prediction_results = np.array(kriging_prediction_results, dtype = float)
kriging_prediction_results_delta = kriging_prediction_results
#kriging_prediction_results_delta_standardized = ( kriging_prediction_results - np.min(kriging_prediction_results) ) / (np.max(kriging_prediction_results) - np.min(kriging_prediction_results))


In [None]:
# Kriging-style prediction for one variant: each non-clinical antibody receives a
# weighted combination of the clinical antibodies' log10 values, with weights
# obtained from embedding distances.
variant_i = 'Omicron BA1'

# Clinical antibodies with a measurement for this variant ('NA' entries are skipped).
clinical_antibody_labeled_list = [antibody_i for antibody_i in clinical_antibody_list
                                  if not clinical_antibody_fc[antibody_i][variant_i] == 'NA']
clinical_antibody_label = np.array([clinical_antibody_fc[antibody_i][variant_i]
                                    for antibody_i in clinical_antibody_labeled_list], dtype = float)
clinical_antibody_label = np.log10(clinical_antibody_label)
# Row of each labelled clinical antibody in antibody_list_all_with_clinical.
clinical_antibody_labeled_index = [np.where(antibody_list_all_with_clinical == antibody_i)[0][0]
                                   for antibody_i in clinical_antibody_labeled_list]

#language embedding based chain similarity
chain_similarity = np.load('language model chain similarity.npy')
# Spherical model with range = (number of labelled antibodies + 2), sill 1, nugget 0.
model = lambda h: spherical(h, len(clinical_antibody_label) + 2, 1, 0.0)
# Non-zero upper-triangle entries of `a`, flattened.
# NOTE(review): a genuine zero distance between two clinical antibodies, or clinical
# indices that are not in increasing order, would shorten this vector and make
# squareform fail - confirm neither can occur.
unsquareform = lambda a: a[np.nonzero(np.triu(a))]

# The clinical-vs-clinical system matrix does not depend on the antibody being
# predicted, so it is assembled once here instead of once per antibody.
labeled_index = np.array(clinical_antibody_labeled_index, dtype = int)
n_labeled = len(labeled_index)
kriging_similarity_matrix = chain_similarity[labeled_index][:, labeled_index]
kriging_similarity_matrix = squareform(model(unsquareform(kriging_similarity_matrix)))
# Extra column and row of ones (with a zero corner) add the constraint that the weights sum to 1.
kriging_similarity_matrix = np.concatenate((kriging_similarity_matrix, np.zeros((n_labeled, 1)) + 1), axis = 1)
last_row = np.zeros((1, n_labeled + 1)) + 1
last_row[0, -1] = 0
kriging_similarity_matrix = np.concatenate((kriging_similarity_matrix, last_row), axis = 0)

kriging_prediction_results = []
for antibody_i in antibody_list_all_with_clinical[np.invert(clinical_tested_boolean)]:
    query_index = np.where(antibody_i == antibody_list_all_with_clinical)[0][0]
    # Right-hand side: stored distances from this antibody to each labelled clinical
    # antibody, plus the 1 that belongs to the sum-to-one constraint.
    # NOTE(review): these distances are used raw, whereas the matrix entries are passed
    # through model(); kept exactly as in the original - confirm this is intended.
    variance = np.concatenate((chain_similarity[query_index, labeled_index], [1]))
    weights = solve(kriging_similarity_matrix, variance)
    # The last weight is the Lagrange multiplier and is not applied to the labels.
    kriging_prediction_results.append(clinical_antibody_label.dot(weights[:-1]))
kriging_prediction_results = np.array(kriging_prediction_results, dtype = float)
kriging_prediction_results_omicron_ba1 = kriging_prediction_results
#kriging_prediction_results_omicron_ba1_standardized = ( kriging_prediction_results - np.min(kriging_prediction_results) ) / (np.max(kriging_prediction_results) - np.min(kriging_prediction_results))


In [None]:
variant_i = 'Omicron BA5'
# Keep only the clinical antibodies that have a recorded (non-'NA') value for this variant.
clinical_antibody_labeled_list = []
clinical_antibody_label = []
for antibody_i in clinical_antibody_list:
    if(not clinical_antibody_fc[antibody_i][variant_i] == 'NA'):
        clinical_antibody_labeled_list.append(antibody_i)
        clinical_antibody_label.append(clinical_antibody_fc[antibody_i][variant_i])

# Labels are used on a log10 scale from here on.
clinical_antibody_label = np.array(clinical_antibody_label, dtype = float)
clinical_antibody_label = np.log10(clinical_antibody_label)
# Position of every labelled clinical antibody inside antibody_list_all_with_clinical.
clinical_antibody_labeled_index = []
for antibody_i in clinical_antibody_labeled_list:
    clinical_antibody_labeled_index.append(np.where(antibody_list_all_with_clinical == antibody_i)[0][0])

#language embedding based chain similarity
chain_similarity = np.load('language model chain similarity.npy')
# Spherical model from skgstat, called as spherical(h, range, sill, nugget) with
# range = number of labelled antibodies + 2, sill = 1, nugget = 0.
model = lambda h: spherical(h, len(clinical_antibody_label) + 2, 1, 0.0)
# Condensed (row-major, strict upper triangle) vector that squareform() can rebuild.
# Index-based extraction keeps exact-zero entries and ignores the diagonal, whereas
# np.nonzero(np.triu(a)) dropped zeros and picked up any non-zero diagonal value.
unsquareform = lambda a: a[np.triu_indices_from(a, k = 1)]

# The left-hand kriging matrix only involves the labelled clinical antibodies, so it
# is identical for every target antibody and is built once, outside the loop.
labeled_index_array = np.array(clinical_antibody_labeled_index, dtype = int)
n_labeled = len(clinical_antibody_labeled_index)
kriging_similarity_matrix = chain_similarity[labeled_index_array][:, labeled_index_array]
kriging_similarity_matrix = squareform(model(unsquareform(kriging_similarity_matrix)))
# Border the matrix with a column/row of ones and a zero corner (sum-of-weights constraint).
kriging_similarity_matrix = np.concatenate((kriging_similarity_matrix, np.ones((n_labeled, 1))), axis = 1)
last_row = np.ones((1, n_labeled + 1))
last_row[0, -1] = 0
kriging_similarity_matrix = np.concatenate((kriging_similarity_matrix, last_row), axis = 0)

kriging_prediction_results = []
for antibody_i in antibody_list_all_with_clinical[np.invert(clinical_tested_boolean)]:
    target_index = np.where(antibody_i == antibody_list_all_with_clinical)[0][0]
    # Right-hand side: raw chain_similarity between the target and each labelled antibody,
    # followed by the constraint value 1.
    # NOTE(review): these values are not passed through model(), unlike the matrix above - confirm intended.
    variance = np.concatenate((chain_similarity[target_index, labeled_index_array], [1]))
    weights = solve(kriging_similarity_matrix, variance)
    # The last weight belongs to the constraint row, so it is excluded from the prediction.
    kriging_prediction_results.append(clinical_antibody_label.dot(weights[:-1]))
kriging_prediction_results = np.array(kriging_prediction_results, dtype = float)
kriging_prediction_results_omicron_ba5 = kriging_prediction_results
#kriging_prediction_results_omicron_ba5_standardized = ( kriging_prediction_results - np.min(kriging_prediction_results) ) / (np.max(kriging_prediction_results) - np.min(kriging_prediction_results))


In [None]:
# Fraction of predictions that already lie inside [-1, 3], printed for
# Delta, Omicron BA1 and Omicron BA5 in that order.
for variant_predictions in (kriging_prediction_results_delta,
                            kriging_prediction_results_omicron_ba1,
                            kriging_prediction_results_omicron_ba5):
    within_bounds = (variant_predictions >= -1) & (variant_predictions <= 3)
    print(np.mean(within_bounds))


In [None]:
# Clamp every prediction to [-1, 3]. The clamp is written back into the same arrays
# (out = ...), so any other name bound to these arrays sees the clamped values too.
for variant_predictions in (kriging_prediction_results_delta,
                            kriging_prediction_results_omicron_ba1,
                            kriging_prediction_results_omicron_ba5):
    np.clip(variant_predictions, -1, 3, out = variant_predictions)

#to have the fold improvement
# The sign flip rebinds each name to a new array. Re-running this cell flips the
# sign again, so it must be executed exactly once per kriging run.
kriging_prediction_results_delta = np.negative(kriging_prediction_results_delta)
kriging_prediction_results_omicron_ba1 = np.negative(kriging_prediction_results_omicron_ba1)
kriging_prediction_results_omicron_ba5 = np.negative(kriging_prediction_results_omicron_ba5)

#kriging_prediction_results_delta_standardized = 1 - ((kriging_prediction_results_delta + 1) / 4)
#kriging_prediction_results_omicron_ba1_standardized = 1 - ((kriging_prediction_results_omicron_ba1 + 1) / 4)
#kriging_prediction_results_omicron_ba5_standardized = 1 - ((kriging_prediction_results_omicron_ba5 + 1) / 4)


In [None]:
variant_i = 'Delta'
#for all 3 variants, sizes are the same, as 21, 1359
language_model_name = 'antibody_mlm_seqIndiv_transformer_21-09-04-23-44-49_442174'
antibody_list_all_with_clinical = np.load('antibody_list_all_with_clinical.npy')
# One row per antibody: heavy-chain embedding in columns 0..767, light-chain in 768..1535.
# The row counts (21 tested, 1359 untested) are hard-coded; rows beyond the number of
# antibodies actually found stay zero, and more antibodies than rows raises IndexError.
all_antibodies_tested_embedding = np.zeros((21, 768 * 2))
all_antibodies_tested_fc = []
all_antibodies_tested_list = []
all_antibodies_untested_embedding = np.zeros((1359, 768 * 2))
all_antibodies_untested_fc = []
all_antibodies_untested_list = []
heavy_chain_embedding = np.load('embedding_h_' + language_model_name + '.npy', allow_pickle = True).item()
light_chain_embedding = np.load('embedding_l_' + language_model_name + '.npy', allow_pickle = True).item()
count_i = 0
# Clinical antibodies with a recorded (non-'NA') value for this variant.
for antibody_i in list(clinical_antibody_fc):
    if(clinical_antibody_fc[antibody_i][variant_i] != 'NA'):
        all_antibodies_tested_list.append(antibody_i)
        all_antibodies_tested_embedding[count_i][0:768] = heavy_chain_embedding[antibody_i]
        all_antibodies_tested_embedding[count_i][768:] = light_chain_embedding[antibody_i]
        all_antibodies_tested_fc.append(clinical_antibody_fc[antibody_i][variant_i])
        count_i += 1
# Patient antibodies where the WT or the variant EC50 differs from 250
# (presumably 250 marks "not detectable" - TODO confirm). Fold change = variant / WT.
for antibody_i in list(patient_antibody_ec50):
    if((patient_antibody_ec50[antibody_i]['WT'] != 250) or (patient_antibody_ec50[antibody_i][variant_i] != 250)):
        all_antibodies_tested_list.append(antibody_i)
        all_antibodies_tested_embedding[count_i][0:768] = heavy_chain_embedding[antibody_i]
        all_antibodies_tested_embedding[count_i][768:] = light_chain_embedding[antibody_i]
        all_antibodies_tested_fc.append(patient_antibody_ec50[antibody_i][variant_i] / patient_antibody_ec50[antibody_i]['WT'])
        count_i += 1
# Everything that was not collected above counts as untested.
count_i = 0
for antibody_i in antibody_list_all_with_clinical:
    if(not antibody_i in all_antibodies_tested_list):
        all_antibodies_untested_list.append(antibody_i)
        all_antibodies_untested_embedding[count_i][0:768] = heavy_chain_embedding[antibody_i]
        all_antibodies_untested_embedding[count_i][768:] = light_chain_embedding[antibody_i]
        count_i += 1

all_antibodies_tested_fc = np.log10(np.array(all_antibodies_tested_fc))
all_antibodies_untested_list = np.array(all_antibodies_untested_list)
print(len(all_antibodies_tested_fc))
fc_threshold = 1
# Fraction of tested antibodies whose log10 fold change is at most the threshold.
print(np.mean(all_antibodies_tested_fc <= fc_threshold))


In [None]:
print(len(clinical_tested_boolean))
print(np.sum(clinical_tested_boolean))

# Table with one row per untested antibody and four columns:
# Delta, Omicron BA1, Omicron BA5 predictions and the WT neutralization score.
# The last entry of all_antibodies_untested_list is deliberately left out (the "- 1").
n_rows = len(all_antibodies_untested_list) - 1
kriging_prediction_all = np.zeros((n_rows, 4))
score_columns = (kriging_prediction_results_delta,
                 kriging_prediction_results_omicron_ba1,
                 kriging_prediction_results_omicron_ba5,
                 wt_neutralization_score)
for i in range(n_rows):
    # Position of this antibody among the non-clinical antibodies, which is how the
    # score arrays are indexed.
    index = np.where(all_antibodies_untested_list[i] == antibody_list_all_with_clinical[np.invert(clinical_tested_boolean)])[0][0]
    for column, scores in enumerate(score_columns):
        kriging_prediction_all[i, column] = scores[index]

np.savetxt('scenario b kriging prediction.csv', kriging_prediction_all, delimiter = ',')

In [None]:
#Figure s2c 1
s = 40  # base marker size; the other UMAP scatter cells read this value as well

# Split the 2D projection into clinical and patient (non-clinical) antibodies.
clinical_points = proj_2d[clinical_tested_boolean]
patient_points = proj_2d[np.invert(clinical_tested_boolean)]

plt.figure(figsize=(6, 6), dpi=600)
#plt.colorbar()
plt.scatter(x = clinical_points[:, 0], y = clinical_points[:, 1],
            s = s * 4, label = 'Clinical antibodies', c = 'black', marker = '*')
# Patient antibodies are coloured by their Delta kriging prediction.
plt.scatter(x = patient_points[:, 0], y = patient_points[:, 1],
            s = s / 2, label = 'Patient antibodies', c = kriging_prediction_results_delta,
            marker = 'o', cmap = 'rainbow')

plt.xlabel('UMAP 1', fontsize = 25)
plt.ylabel('UMAP 2', fontsize = 25)
# Offsets for the per-antibody text annotations, which are currently disabled.
experimental_shift_x = np.zeros(19)
experimental_shift_y = np.zeros(19)

plt.legend()
plt.title('WT -> Delta', fontsize = 30)
# Rebuild the legend with the patient entry first, then the clinical entry.
legend_order = [1, 0]
handles, labels = plt.gca().get_legend_handles_labels()
ordered_handles = [handles[idx] for idx in legend_order]
ordered_labels = [labels[idx] for idx in legend_order]
plt.legend(ordered_handles, ordered_labels, loc = 4)
plt.tight_layout()
plt.savefig('Figures2c1.png', dpi = 600)

In [None]:
# Scattered input: UMAP coordinates of the non-clinical antibodies and their Delta prediction.
x = proj_2d[np.invert(clinical_tested_boolean), 0]
y = proj_2d[np.invert(clinical_tested_boolean), 1]
z = kriging_prediction_results_delta
n = 200 #number of interpolation point for x and y axis
x_min, x_max = min(x), max(x)
y_min, y_max = min(y), max(y)
w = x_max - x_min #width
h = y_max - y_min #length
wn = w / n #x interval
hn = h / n #y interval

#list to store interpolation point and elevation
y_init = y_min
x_init = x_min
# Grid nodes start at the minimum and advance in n steps of wn / hn,
# so the last node lies one interval short of the maximum.
x_idw_list = [x_init + wn * i for i in range(n)]
y_idw_list = [y_init + hn * i for i in range(n)]
# z_head[row][col] is the interpolated value at (x node col, y node row).
z_head = [
    [idw_npoint(x_init + wn * j, yz, 5, 1.5) for j in range(n)] #min. point=5, p=1.5
    for yz in y_idw_list
]

In [None]:
#Figure 2D Delta
# Experimentally tested antibodies, expressed in the index space of x / y / z
# (i.e. restricted to the non-clinical antibodies).
tested_among_patients = experimental_tested_boolean[np.invert(clinical_tested_boolean)]

fig = go.Figure()
fig.add_trace(go.Surface(z = z_head, x = x_idw_list, y = y_idw_list, colorscale = 'Rainbow'))
# Overlay the tested antibodies as markers coloured by the same values as the surface.
marker_style = {'size': 4,
                'color': z[tested_among_patients],
                'colorscale': 'Rainbow'}
fig.add_scatter3d(x = x[tested_among_patients],
                  y = y[tested_among_patients],
                  z = z[tested_among_patients],
                  mode = 'markers', marker = marker_style)
scene_settings = dict(aspectratio = dict(x = 5, y = 5, z = 2),
                      xaxis = dict(range = [x_min, x_max]),
                      yaxis = dict(range = [y_min, y_max]),
                      zaxis = dict(range = [-3, 1]),
                      xaxis_title = '',
                      yaxis_title = '',
                      zaxis_title = '')
fig.update_layout(scene = scene_settings)
camera = dict(eye = dict(x = 4, y = -8, z = 3))
fig.update_layout(scene_camera = camera)
fig.write_image('Figure2D_delta.png', height = 1200, width = 1600, scale = 2)

#go_offline.plot(fig,filename='3d wt to delta.html',validate=True, auto_open=False)

In [None]:
#write supplementary file
wt_neutralization_score = np.load('wt_neutralization_score.npy')
# Non-clinical antibodies, in the order used by the score / prediction arrays.
patient_antibody_ids = antibody_list_all_with_clinical[np.invert(clinical_tested_boolean)]
# The context manager closes the file even if one of the row writes raises.
with open('Robustness prediction.csv', 'w+') as output:
    output.write('Antibody ID,WT neutralization,Robustness (Delta),Robustness (Omicron BA1),Robustness (Omicron BA5),Heavy chain,Light chain\n')
    for i in range(np.sum(np.invert(clinical_tested_boolean))):
        antibody_i = patient_antibody_ids[i]
        neutralization_i = str(round(wt_neutralization_score[i], 3)) + ','
        robustness_i_delta = str(round(kriging_prediction_results_delta[i], 3)) + ','
        robustness_i_omicron_ba1 = str(round(kriging_prediction_results_omicron_ba1[i], 3)) + ','
        robustness_i_omicron_ba5 = str(round(kriging_prediction_results_omicron_ba5[i], 3)) + ','
        robustness_i = robustness_i_delta + robustness_i_omicron_ba1 + robustness_i_omicron_ba5
        # IDs are stored as 'heavy_light' and written as 'heavy/light'.
        output.write(antibody_i.split('_')[0] + '/' + antibody_i.split('_')[1] + ',' + neutralization_i + robustness_i + seq_dict_all[antibody_i]['H'] + ',' + seq_dict_all[antibody_i]['L'] + '\n')

In [None]:
# UMAP scatter for WT -> Omicron BA1: clinical antibodies as black stars, patient
# antibodies coloured by their kriging prediction (marker size `s` comes from an earlier cell).
plt.figure(figsize=(6, 6), dpi=600)
plt.scatter(x = proj_2d[clinical_tested_boolean, 0], y = proj_2d[clinical_tested_boolean, 1], s = s * 4, label = 'Clinical antibodies', c = 'black', marker = '*')
plt.scatter(x = proj_2d[np.invert(clinical_tested_boolean), 0], y = proj_2d[np.invert(clinical_tested_boolean), 1], s = s / 2, label = 'Patient antibodies', c = kriging_prediction_results_omicron_ba1, marker = 'o', cmap = 'rainbow')
plt.colorbar()

#plt.colorbar()
plt.xlabel('UMAP 1', fontsize = 25)
plt.ylabel('UMAP 2', fontsize = 25)
#for i, txt in enumerate(antibody_list_all_with_clinical[experimental_tested_boolean]):
#    plt.annotate(txt.split('_')[0], (proj_2d[experimental_tested_boolean, 0][i] + experimental_shift_x[i], proj_2d[experimental_tested_boolean, 1][i] + experimental_shift_y[i]), fontsize = 10, c = 'black')

#for i, txt in enumerate(antibody_list_all_with_clinical[clinical_tested_boolean]):
#    plt.annotate(txt.split('_')[0], (proj_2d[clinical_tested_boolean, 0][i] + clinical_shift_x[i], proj_2d[clinical_tested_boolean, 1][i] + clinical_shift_y[i]), fontsize = 10, c = 'black')

# An intermediate plt.legend() whose first handle was recoloured through
# Legend.legendHandles used to sit here. That legend was replaced by the call below
# before anything was rendered, and the attribute no longer exists in recent
# Matplotlib releases, so only the final legend is built.
plt.title('WT -> Omicron BA1', fontsize = 30)
# Show the patient entry first, then the clinical entry.
legend_order = [1, 0]
handles, labels = plt.gca().get_legend_handles_labels()
plt.legend([handles[idx] for idx in legend_order],[labels[idx] for idx in legend_order], loc = 4)
plt.tight_layout()
#plt.savefig('All antibody clustering ' + language_model_name + ' all sequence with chopping kriging omicron ba1 colored.png', pmi = 3000, facecolor = 'white')

In [None]:
# Scattered input: UMAP coordinates of the non-clinical antibodies and their Omicron BA1 prediction.
x = proj_2d[np.invert(clinical_tested_boolean), 0]
y = proj_2d[np.invert(clinical_tested_boolean), 1]
z = kriging_prediction_results_omicron_ba1
n = 200 #number of interpolation point for x and y axis
x_min, x_max = min(x), max(x)
y_min, y_max = min(y), max(y)
w = x_max - x_min #width
h = y_max - y_min #length
wn = w / n #x interval
hn = h / n #y interval

#list to store interpolation point and elevation
y_init = y_min
x_init = x_min
# Grid nodes start at the minimum and advance in n steps of wn / hn,
# so the last node lies one interval short of the maximum.
x_idw_list = [x_init + wn * i for i in range(n)]
y_idw_list = [y_init + hn * i for i in range(n)]
# z_head[row][col] is the interpolated value at (x node col, y node row).
z_head = [
    [idw_npoint(x_init + wn * j, yz, 5, 1.5) for j in range(n)] #min. point=5, p=1.5
    for yz in y_idw_list
]

In [None]:
#Figure2D omicron ba1
# Experimentally tested antibodies, expressed in the index space of x / y / z
# (i.e. restricted to the non-clinical antibodies).
tested_among_patients = experimental_tested_boolean[np.invert(clinical_tested_boolean)]

fig = go.Figure()
fig.add_trace(go.Surface(z = z_head, x = x_idw_list, y = y_idw_list, colorscale = 'Rainbow'))
# Overlay the tested antibodies as markers coloured by the same values as the surface.
marker_style = {'size': 4,
                'color': z[tested_among_patients],
                'colorscale': 'Rainbow'}
fig.add_scatter3d(x = x[tested_among_patients],
                  y = y[tested_among_patients],
                  z = z[tested_among_patients],
                  mode = 'markers', marker = marker_style)
scene_settings = dict(aspectratio = dict(x = 5, y = 5, z = 2),
                      xaxis = dict(range = [x_min, x_max]),
                      yaxis = dict(range = [y_min, y_max]),
                      zaxis = dict(range = [-3, 1]),
                      xaxis_title = '',
                      yaxis_title = '',
                      zaxis_title = '')
fig.update_layout(scene = scene_settings)

camera = dict(eye = dict(x = 4, y = -8, z = 3))
fig.update_layout(scene_camera = camera)
fig.write_image('Figure2D_omicronba1.png', height = 1200, width = 1600, scale = 2)
#go_offline.plot(fig,filename='3d wt to omicron ba1.html',validate=True, auto_open=False)

In [None]:
# UMAP scatter for WT -> Omicron BA5. Patient antibodies are drawn first here, so
# they are legend handle 0 and the clinical antibodies are handle 1.
clinical_points = proj_2d[clinical_tested_boolean]
patient_points = proj_2d[np.invert(clinical_tested_boolean)]

plt.figure(figsize=(6, 6), dpi=600)
plt.scatter(x = patient_points[:, 0], y = patient_points[:, 1],
            s = s / 2, label = 'Patient antibodies', c = kriging_prediction_results_omicron_ba5,
            marker = 'o', cmap = 'rainbow')
#plt.colorbar()
plt.scatter(x = clinical_points[:, 0], y = clinical_points[:, 1],
            s = s * 4, label = 'Clinical antibodies', c = 'black', marker = '*')

plt.xlabel('UMAP 1', fontsize = 25)
plt.ylabel('UMAP 2', fontsize = 25)

# Per-antibody text annotations are disabled for this figure.

plt.title('WT -> Omicron BA5', fontsize = 30)
# With the drawing order above, [1, 0] lists the clinical entry before the patient entry.
legend_order = [1, 0]
handles, labels = plt.gca().get_legend_handles_labels()
ordered_handles = [handles[idx] for idx in legend_order]
ordered_labels = [labels[idx] for idx in legend_order]
plt.legend(ordered_handles, ordered_labels, loc = 4, fontsize = 16)
plt.tight_layout()
#plt.savefig('All antibody clustering ' + language_model_name + ' all sequence with chopping kriging omicron ba5 colored.png', pmi = 3000, facecolor = 'white')

In [None]:
# Scattered input: UMAP coordinates of the non-clinical antibodies and their Omicron BA5 prediction.
x = proj_2d[np.invert(clinical_tested_boolean), 0]
y = proj_2d[np.invert(clinical_tested_boolean), 1]
z = kriging_prediction_results_omicron_ba5
n = 200 #number of interpolation point for x and y axis
x_min, x_max = min(x), max(x)
y_min, y_max = min(y), max(y)
w = x_max - x_min #width
h = y_max - y_min #length
wn = w / n #x interval
hn = h / n #y interval

#list to store interpolation point and elevation
y_init = y_min
x_init = x_min
# Grid nodes start at the minimum and advance in n steps of wn / hn,
# so the last node lies one interval short of the maximum.
x_idw_list = [x_init + wn * i for i in range(n)]
y_idw_list = [y_init + hn * i for i in range(n)]
# z_head[row][col] is the interpolated value at (x node col, y node row).
z_head = [
    [idw_npoint(x_init + wn * j, yz, 5, 1.5) for j in range(n)] #min. point=5, p=1.5
    for yz in y_idw_list
]

In [None]:
#Figure 2D Omicron BA5
# Experimentally tested antibodies, expressed in the index space of x / y / z
# (i.e. restricted to the non-clinical antibodies).
tested_among_patients = experimental_tested_boolean[np.invert(clinical_tested_boolean)]

fig=go.Figure()
fig.add_trace(go.Surface(z=z_head,x=x_idw_list,y=y_idw_list, colorscale = 'Rainbow'))
# The markers use the same 'Rainbow' colorscale as the surface, matching the Delta
# and Omicron BA1 versions of this figure (the entry was missing here, so the
# markers fell back to plotly's default colorscale).
fig.add_scatter3d(x = x[tested_among_patients],
                  y = y[tested_among_patients],
                  z = z[tested_among_patients],
                  mode = 'markers', marker = {'size': 4,
                                              'color': z[tested_among_patients],
                                              'colorscale': 'Rainbow'})
fig.update_layout(scene=dict(aspectratio=dict(x=5, y=5, z=2),
                             xaxis = dict(range=[x_min,x_max]),
                             yaxis = dict(range=[y_min,y_max]),
                             zaxis = dict(range=[-3, 1]),
                             xaxis_title = '',
                             yaxis_title = '',
                             zaxis_title = ''))

camera = dict(eye = dict(x = 4, y = -8, z = 3))
fig.update_layout(scene_camera = camera)

fig.write_image('Figure2D_OmicronBA5.png', height = 1200, width = 1600, scale = 2)
#go_offline.plot(fig,filename='3d wt to omicron ba5.html',validate=True, auto_open=False)

In [None]:
# Collect kriging predictions for hand-picked antibody/variant pairs and pair them,
# by position, with hand-entered experimental ratios. The order of the loops below
# (Delta, then Omicron BA1, then Omicron BA5) must match the order of the values in
# the corresponding experiment_results_* list.
susceptibility_prediction_all_available = []
for antibody_i in ['1_2', '25_26', '31_32', '88_89', '106_107', '108_109']:
    index = np.where(antibody_list_all_with_clinical[np.invert(clinical_tested_boolean)] == antibody_i)[0][0]
    susceptibility_prediction_all_available.append(kriging_prediction_results_delta[index])
for antibody_i in ['11_12', '106_107', '108_109']:
    index = np.where(antibody_list_all_with_clinical[np.invert(clinical_tested_boolean)] == antibody_i)[0][0]
    susceptibility_prediction_all_available.append(kriging_prediction_results_omicron_ba1[index])
for antibody_i in ['88_89', '106_107']:
    index = np.where(antibody_list_all_with_clinical[np.invert(clinical_tested_boolean)] == antibody_i)[0][0]
    susceptibility_prediction_all_available.append(kriging_prediction_results_omicron_ba5[index])
# One ratio per prediction above: 6 for Delta, 3 for Omicron BA1, 2 for Omicron BA5.
# presumably each ratio is variant EC50 / WT EC50 - TODO confirm
experiment_results_all_available = [0.5436 / 5.698, 133.8 / 69.64, 0.05577 / 0.4067, 7.167 / 17.89, 1.341 / 5.542, 0.4388 / 3.858, 
                                   2.276 / 1.587, 10.23 / 5.542, 31.27 / 3.858, 
                                   5.295 / 17.89, 146.6 / 5.542]

# NOTE(review): the inverted ratios are stored under `experimental_...` (with "al"),
# a different name from the `experiment_...` list above, so the list above stays
# un-inverted. The "only one available" data below is inverted under its own name.
# Confirm which variable later cells are meant to read.
experimental_results_all_available = 1 / np.array(experiment_results_all_available)


# Second set: starts with pairs whose ratios involve the value 250, then repeats the
# eleven pairs from the first set.
susceptibility_prediction_only_one_available = []
for antibody_i in ['11_12', '19_20']:
    index = np.where(antibody_list_all_with_clinical[np.invert(clinical_tested_boolean)] == antibody_i)[0][0]
    susceptibility_prediction_only_one_available.append(kriging_prediction_results_delta[index])
for antibody_i in ['1_2', '25_26', '31_32', '88_89']:
    index = np.where(antibody_list_all_with_clinical[np.invert(clinical_tested_boolean)] == antibody_i)[0][0]
    susceptibility_prediction_only_one_available.append(kriging_prediction_results_omicron_ba1[index])
for antibody_i in ['1_2', '11_12', '25_26', '31_32', '108_109']:
    index = np.where(antibody_list_all_with_clinical[np.invert(clinical_tested_boolean)] == antibody_i)[0][0]
    susceptibility_prediction_only_one_available.append(kriging_prediction_results_omicron_ba5[index])

# Same antibody/variant pairs, in the same order, as the "all available" set.
for antibody_i in ['1_2', '25_26', '31_32', '88_89', '106_107', '108_109']:
    index = np.where(antibody_list_all_with_clinical[np.invert(clinical_tested_boolean)] == antibody_i)[0][0]
    susceptibility_prediction_only_one_available.append(kriging_prediction_results_delta[index])
for antibody_i in ['11_12', '106_107', '108_109']:
    index = np.where(antibody_list_all_with_clinical[np.invert(clinical_tested_boolean)] == antibody_i)[0][0]
    susceptibility_prediction_only_one_available.append(kriging_prediction_results_omicron_ba1[index])
for antibody_i in ['88_89', '106_107']:
    index = np.where(antibody_list_all_with_clinical[np.invert(clinical_tested_boolean)] == antibody_i)[0][0]
    susceptibility_prediction_only_one_available.append(kriging_prediction_results_omicron_ba5[index])
# Rows: 2 Delta, 4 Omicron BA1, 5 Omicron BA5 entries, followed by the eleven ratios
# of the first set. presumably 250 stands in for an undetectable EC50 - TODO confirm
experiment_results_only_one_available = [250 / 1.587, 136.7 / 416, 
                                         250 / 5.698, 250 / 69.64, 250 / 0.4067, 250 / 17.89, 
                                         250 / 5.698, 250 / 1.587, 250 / 69.64, 250 / 0.4067, 250 / 3.858,
                                         0.5436 / 5.698, 133.8 / 69.64, 0.05577 / 0.4067, 7.167 / 17.89, 1.341 / 5.542, 0.4388 / 3.858, 
                                         2.276 / 1.587, 10.23 / 5.542, 31.27 / 3.858, 
                                         5.295 / 17.89, 146.6 / 5.542]

# Here the inversion overwrites the original name, unlike the first set.
susceptibility_prediction_only_one_available = np.array(susceptibility_prediction_only_one_available)
experiment_results_only_one_available = 1 / np.array(experiment_results_only_one_available)

In [None]:
# "All available" set: predictions are scored as-is against ratio > 1.
# experiment_results_all_available is the un-inverted list at this point.
labels_all_available = np.array(experiment_results_all_available) > 1
print(get_auroc(susceptibility_prediction_all_available, labels_all_available))
print(get_auprc(susceptibility_prediction_all_available, labels_all_available))
print(np.mean(labels_all_available))

# "Only one available" set: negated predictions are scored against inverted ratio >= 1.
labels_only_one_available = np.array(experiment_results_only_one_available) >= 1
print(get_auroc(-susceptibility_prediction_only_one_available, labels_only_one_available))
print(get_auprc(-susceptibility_prediction_only_one_available, labels_only_one_available))
print(np.mean(labels_only_one_available))

In [None]:
# Split the predictions by sign: negative means predicted worse, zero or positive
# means predicted better.
susceptibility_predicted_getting_worse_boolean = np.less(susceptibility_prediction_only_one_available, 0)
susceptibility_predicted_getting_better_boolean = np.greater_equal(susceptibility_prediction_only_one_available, 0)

In [None]:
# Box plot of the experimental log10 values, grouped by the sign of the prediction.
plt.figure(figsize=(8, 6), dpi=300)
log_results_predicted_better = np.log10(experiment_results_only_one_available[susceptibility_predicted_getting_better_boolean])
log_results_predicted_worse = np.log10(experiment_results_only_one_available[susceptibility_predicted_getting_worse_boolean])
boxplot_list = [log_results_predicted_better, log_results_predicted_worse]
plt.boxplot(boxplot_list)

# Overlay the individual points with a small random horizontal jitter around
# box positions 1 and 2 (no seed is set, so the jitter differs between runs).
for i, y in enumerate(boxplot_list):
    x = np.random.normal(1 + i, 0.04, size=len(y))
    plt.scatter(x, y, c = 'red', alpha=0.6, s = 20)

plt.xticks([1, 2], ['Predicted better (Positive)', 'Predicted worse (Negative)'], fontsize = 15)
plt.ylabel('log10 EC50 fold improvement', fontsize = 20)
plt.axhline(y = 0, linestyle = '--')
plt.tight_layout()
#plt.savefig('Antibody robustness prediction validation.png', facecolor = 'white')

In [None]:
fig, axes = plt.subplots(1, 2, figsize = (10, 4), dpi = 600)
#pred_* is the x coordiate of a bar
pred_positive = 1.1
pred_negative = -0.1
colors = ['red', 'blue']

# Both panels are drawn from the same hard-coded percentages; only the title differs.
# NOTE(review): with bins [-0.5, 0.5, 1.5] the pred_negative entries (-0.1) land in the
# bar at tick 0, which is labelled 'Predicted positive', and the pred_positive entries
# (1.1) land in the bar at tick 1, labelled 'Predicted negative' - confirm the naming.
for panel_ax, panel_title in zip(axes, ['Binding inhibition', 'SARS-CoV-2 infection']):
    #wt_neutralization_list, each element is for one color (high, medium or low neutralization)
    wt_neutralization_prediction_hist = [[pred_positive] * round(100 * 1/3) + [pred_negative] * round(100 * 10/15),
                                         [pred_positive] * round(100 * 2/3) + [pred_negative] * round(100 * 5/15)]
    panel_ax.hist(wt_neutralization_prediction_hist, bins = [-0.5, 0.5, 1.5], rwidth = 0.2, color = colors, stacked = True, label = ['Tested good', 'Tested bad'])
    panel_ax.set_xticks([0, 1], ['Predicted positive', 'Predicted negative'], fontsize = 10)
    panel_ax.set_title(panel_title, fontsize = 25)

# Only the left panel carries the y-axis label.
axes[0].set_ylabel('Experimental validated\npercentages', fontsize = 15)
#axes[1].set_ylabel('Experimental validated\npercentages', fontsize = 15)

plt.legend()

plt.tight_layout()
#plt.savefig('WT neutralization prediction validation histogram.png', facecolor = 'white')


In [None]:
# Placeholder value plotted for measurements that were undetectable.
undetectable_value = 350

# Each list holds two groups: index 0 is plotted as "Predicted better",
# index 1 as "Predicted worse". Measured values come first, followed by one
# placeholder per undetectable measurement (1 in the first group, 10 in the second).
binding_inhibition_boxplot = [
    [30.8, 2.997, *([undetectable_value] * 1)],
    [32.84, 32.14, 5.667, 34.7, 104.5, *([undetectable_value] * 10)],
]

ec50_neutralization_boxplot = [
    [17.89, 5.54, *([undetectable_value] * 1)],
    [5.7, 0.41, 69.64, 3.86, 1.587, *([undetectable_value] * 10)],
]



In [None]:
# Box plot of log10 binding inhibition for the two prediction groups.
plt.figure(figsize=(8, 6), dpi=600)
boxplot_list = [np.log10(group_values) for group_values in binding_inhibition_boxplot[0:2]]
plt.boxplot(boxplot_list)

# Overlay the individual points with a small random horizontal jitter around
# box positions 1 and 2 (no seed is set, so the jitter differs between runs).
for i, y in enumerate(boxplot_list):
    x = np.random.normal(1 + i, 0.04, size=len(y))
    plt.scatter(x, y, c = 'red', alpha=0.6, s = 20)

plt.xticks([1, 2], ['Predicted better (Positive)', 'Predicted worse (Negative)'], fontsize = 15)
plt.ylabel('log10 binding inhibition', fontsize = 20)
# Dashed reference line at a raw value of 250.
plt.axhline(y = np.log10(250), linestyle = '--')
# Ticks sit at log10 positions but are labelled with the raw values.
tick_values = [10, 50, 100, 250, undetectable_value]
plt.yticks(np.log10(tick_values), ['10', '50', '100', '250', str(undetectable_value)])
plt.tight_layout()
#plt.savefig('Binding inhibition experimental validation.png', facecolor = 'white')

In [None]:
# Box plot of log10 neutralization EC50 for the two prediction groups.
plt.figure(figsize=(8, 6), dpi=600)
boxplot_list = [np.log10(group_values) for group_values in ec50_neutralization_boxplot[0:2]]
plt.boxplot(boxplot_list)

# Overlay the individual points with a small random horizontal jitter around
# box positions 1 and 2 (no seed is set, so the jitter differs between runs).
for i, y in enumerate(boxplot_list):
    x = np.random.normal(1 + i, 0.04, size=len(y))
    plt.scatter(x, y, c = 'red', alpha=0.6, s = 20)

plt.xticks([1, 2], ['Predicted better (Positive)', 'Predicted worse (Negative)'], fontsize = 15)
plt.ylabel('log10 neutralization EC50', fontsize = 20)
# Dashed reference line at a raw value of 250.
plt.axhline(y = np.log10(250), linestyle = '--')
# Ticks sit at log10 positions but are labelled with the raw values.
tick_values = [10, 50, 100, 250, undetectable_value]
plt.yticks(np.log10(tick_values), ['10', '50', '100', '250', str(undetectable_value)])
plt.tight_layout()
#plt.savefig('EC50 neutralization experimental validation.png', facecolor = 'white')

#tested antibody assessment table
# 1.0 where a value was measured, 0.0 where the entry is the string 'undetectable'.
binding_inhibition_prediction_accuracy = np.array(binding_inhibition != 'undetectable', dtype = float)
# Running precision: mean of the accuracy flags over the first i+1 antibodies.
binding_inhibition_prediction_precision = np.zeros(len(binding_inhibition_prediction_accuracy))
for i in range(len(binding_inhibition)):
    binding_inhibition_prediction_precision[i] = np.mean(binding_inhibition_prediction_accuracy[0:(i+1)])
# Running precision relative to the overall detection rate.
binding_inhibition_prediction_fold = binding_inhibition_prediction_precision / np.mean(binding_inhibition_prediction_accuracy)

# Same three quantities for the neutralization EC50 measurements.
neutralization_prediction_accuracy = np.array(ec50 != 'undetectable', dtype = float)
neutralization_prediction_precision = np.zeros(len(neutralization_prediction_accuracy))
for i in range(len(ec50)):
    neutralization_prediction_precision[i] = np.mean(neutralization_prediction_accuracy[0:(i+1)])
neutralization_prediction_fold = neutralization_prediction_precision / np.mean(neutralization_prediction_accuracy)

# The context manager closes the file even if a row fails to format or write.
with open('binding inhibition and ec50 accuracy.csv', 'w+') as output:
    for i in range(len(binding_inhibition_prediction_precision)):
        heavy_id, light_id = antibody_list_tested[i].split('_')[0], antibody_list_tested[i].split('_')[1]
        row_fields = [
            heavy_id + '/' + light_id,
            str(round(neutralization_prediction[i], 3)),
            # Fraction of all WT neutralization scores that this prediction reaches or exceeds.
            str(round(np.mean(neutralization_prediction[i] >= wt_neutralization_score), 3)),
            binding_inhibition[i],
            str(round(binding_inhibition_prediction_precision[i], 3)),
            str(round(binding_inhibition_prediction_fold[i], 3)),
            ec50[i],
            str(round(neutralization_prediction_precision[i], 3)),
            str(round(neutralization_prediction_fold[i], 3)),
        ]
        output.write(','.join(row_fields) + '\n')

In [None]:
# Non-clinical antibodies, in the order used by the score / prediction arrays.
patient_antibody_ids = antibody_list_all_with_clinical[np.invert(clinical_tested_boolean)]

# Print the WT score and the three variant predictions for one antibody.
antibody_id = '106_107'
index = np.where(patient_antibody_ids == antibody_id)[0][0]
print(wt_neutralization_score[index])
print(kriging_prediction_results_delta[index])
print(kriging_prediction_results_omicron_ba1[index])
print(kriging_prediction_results_omicron_ba5[index])


# Top antibodies by predicted WT neutralization score, highest first.
order = np.argsort(-wt_neutralization_score)
ranked_antibody_ids = patient_antibody_ids[order]
ranked_scores = wt_neutralization_score[order]
# The context manager closes the file even if a row write raises.
with open('WT neutralization prediction.csv', 'w+') as output:
    output.write('Antibody ID,Predicted neutralization score,Heavy chain,Light chain\n')
    # At most 100 rows; the bound also covers the case of fewer than 100 antibodies.
    for i in range(min(100, len(order))):
        antibody_i = ranked_antibody_ids[i]
        neutralization_i = str(round(ranked_scores[i], 3))
        output.write(antibody_i.split('_')[0] + '/' + antibody_i.split('_')[1] + ',' + neutralization_i + ',' + seq_dict_all[antibody_i]['H'] + ',' + seq_dict_all[antibody_i]['L'] + '\n')

# Full table: WT score plus the three variant predictions for every non-clinical antibody.
with open('All antibodies wt neutralization and robustness.csv', 'w+') as output:
    output.write('Antibody ID,WT neutralization,Fold improvement (Delta),Fold improvement (Omicron BA1),Fold improvement (Omicron BA5),Heavy chain,Light chain\n')
    for antibody_i in patient_antibody_ids:
        # First position of this ID among the non-clinical antibodies.
        index = np.where(antibody_i == patient_antibody_ids)[0][0]
        output.write(antibody_i.split('_')[0] + '/' + antibody_i.split('_')[1] + ',' + str(round(wt_neutralization_score[index], 4)) + ',' + str(round(kriging_prediction_results_delta[index], 4)) + ',' + str(round(kriging_prediction_results_omicron_ba1[index], 4)) + ',' + str(round(kriging_prediction_results_omicron_ba5[index], 4)) + ',' + seq_dict_all[antibody_i]['H'] + ',' + seq_dict_all[antibody_i]['L'] + '\n')

In [None]:
#redesign 106/107 for delta
# Hand-entered values for ten point mutations, in the order given by `label`.
# The ddG values are sign-flipped; each EC50 entry is a fixed value divided by the
# per-mutant value (presumably the unmutated antibody's EC50 - TODO confirm).
wt_ddg = -np.array([-0.59, -2.32, -1.65, -9.9, -7.22, -8.35, -7.48, -5.94, -1.56, -1.49])
delta_ddg = -np.array([-3.24, -2.01, -2.71, -11.51, -12.25, -11.82, -15.57, -5.02, -4.18, -4.04])
wt_ec50 = 6.454 / np.array([4.75, 47.37, 27.79, 3.459, 7.153, 3.902, 2.538, 5.465, 9.114, 5.512])
delta_ec50 = 4.958 / np.array([5.462, 9.361, 6.676, 3.671, 2.354, 5.513, 3.466, 4.101, 5.967, 4.81])

fontsize = 20
shift = 0.08
label = ['H F66K', 'H F66N', 'H F66C', 'L Y38S', 'L Y38A',
         'L Y38Q', 'L Y38G', 'L Y55N', 'L Y55A', 'L D56C']
color_hannah = ['#FF2600', '#4FE9D6', '#BD37E9', '#FF9300', '#000000', '#A67728', '#001993', '#761564', '#B21800', '#0233FF']
fig, axes = plt.subplots(1, 2, figsize = (10, 5), dpi = 600)

# Left panel: WT; right panel: Delta. Each panel gets the labelled scatter, a fixed
# x-range and a dashed first-degree polynomial fit.
# Axis labels, titles and y-limits are intentionally not set on either panel.
panel_settings = [(axes[0], wt_ddg, wt_ec50, (0, 11.5)),
                  (axes[1], delta_ddg, delta_ec50, (1, 17.5))]
for panel_ax, x, y, x_limits in panel_settings:
    panel_ax.scatter(x = x, y = y, color = color_hannah)
    for i, mutation_name in enumerate(label):
        # Nudge each text label slightly up and to the right of its point.
        panel_ax.text(x[i] + shift, y[i] + shift / 5, mutation_name, fontsize = fontsize * 0.5)
    panel_ax.set_xlim(*x_limits)
    a, b = np.polyfit(x, y, 1)
    # Fit line is drawn at unit steps starting from the smallest x value.
    x_temp = np.arange(np.min(x), np.max(x))
    panel_ax.plot(x_temp, a*x_temp+b, linestyle = '--', color = 'gray')
plt.tight_layout()
#plt.savefig('Redesign 106 for delta.png', pmi = 3000, facecolor = 'white')

In [None]:
# Spearman rank correlation between the ddG and EC50 arrays: WT first, then Delta.
for ddg_values, ec50_values in ((wt_ddg, wt_ec50), (delta_ddg, delta_ec50)):
    print(spearmanr(ddg_values, ec50_values))


In [None]:
#Structures used to calculate the 'neutralization score' for clinical antibody-RBD docking.
#Each row is (antibody, PDB id, RBD chain, antibody chains joined with '+').
#Rows for the 'cocktail' antibody structures are commented out.
_clinical_structure_rows = [
    ('ADI', '7u2d', 'A', 'H+L'),
    ('AMU', '7cdi', 'E', 'H+L'),
    ('BAM', '7kmg', 'C', 'A+B'),
    ('BEB', '7mmo', 'C', 'A+B'),
    ('C135', '7k8z', 'A', 'H+L'),
    ('C144', '7k90', 'B', 'H+L'),
    #('CAS', '6xdg', 'E', 'B+D'),
    #('CIL', '7l7e', 'b', 'C+D'),
    ('ETE', '7c01', 'A', 'H+L'),
    #('IMD', '6xdg', 'E', 'C+A'),
    ('REG', '7cm4', 'A', 'H+L'),
    #('ROM', '8gx9', 'A', 'M+P'),
    #('SOT', '7r6w', 'R', 'A+B'),
    ('TIX', '7l7d', 'E', 'H+L'),
]

#Expand the table into the nested mapping: antibody -> {'pdb', 'RBD', 'antibody'}.
clinical_antibody_structure_dict = {
    antibody_name: {'pdb': pdb_id, 'RBD': rbd_chain, 'antibody': antibody_chains}
    for antibody_name, pdb_id, rbd_chain, antibody_chains in _clinical_structure_rows
}

In [None]:
#calculate the WT neutralization based on the ACE2 binding residues coverage
#running this section takes time
def neutralization_calculation(antibody_binding, unmutated_ace2_binding, ace2_binding_residue):
    """Return the share of ACE2-binding residues covered by an antibody.

    Parameters
    ----------
    antibody_binding : array-like
        Residue identifiers contacted by the antibody.
    unmutated_ace2_binding : array-like
        Residue identifiers that count as covered when the antibody contacts them.
    ace2_binding_residue : sequence
        Full reference residue list; only its length is used, as the denominator.

    Returns
    -------
    float
        Number of distinct residues present in both `antibody_binding` and
        `unmutated_ace2_binding`, divided by `len(ace2_binding_residue)`.
    """
    # np.intersect1d returns the sorted, de-duplicated common values.
    covered_residues = np.intersect1d(antibody_binding, unmutated_ace2_binding)
    return covered_residues.size / len(ace2_binding_residue)

path = 'clinical_antibody_docking_structures/'
clinical_binding_residue_list = {}
# NOTE(review): this rebinds pmi_score, which is loaded from 'pmi_score.npy' at the top
# of the notebook, to an empty dict, and nothing in this cell reads it. Kept only because
# later cells may depend on it - confirm and remove if it is a leftover.
pmi_score = {}
# NOTE(review): conformation_list is not used in this cell; verify against later cells.
conformation_list = os.listdir(path)
# Scratch file that PyMOL writes the interface selection to; rewritten for every antibody.
interface_file = os.path.join(path, 'ligand_interface.pdb')

for antibody_i in clinical_antibody_structure_dict:
    # Structure entry for this antibody: PDB id, RBD chain, antibody chains.
    pdb_i = clinical_antibody_structure_dict[antibody_i]
    pymol.cmd.load(os.path.join(path, pdb_i['pdb'] + '.pdb'))
    # RBD atoms within 5 (PyMOL distance units) of antibody atoms, hydrogens excluded on
    # both sides. The RBD side previously read `not H*` without the `name` keyword,
    # unlike the antibody side.
    pymol.cmd.select('interface', '(chain ' + pdb_i['RBD'] + ' and not name H*) within 5 of (chain ' + pdb_i['antibody'] + ' and not name H*)')
    pymol.cmd.save(interface_file, 'interface')
    pymol.cmd.delete('all')

    # Read the scratch file through a context manager so the handle is closed
    # before the file is removed.
    with open(interface_file) as interface_handle:
        interface_input = interface_handle.readlines()
    os.remove(interface_file)

    interface_resi_temp = []
    for row_i in interface_input:
        # Only coordinate records carry a residue number. PDB columns 23-26 (slice 22:26)
        # hold resSeq, which keeps insertion codes and coordinates out of the int().
        if row_i.startswith(('ATOM', 'HETATM')):
            interface_resi_temp.append(int(row_i[22:26]))
    # De-duplicate as integers so the stored residues are unique and numerically sorted.
    clinical_binding_residue_list[antibody_i] = np.unique(np.array(interface_resi_temp, dtype = int))
    
# One score per clinical antibody, in the iteration order of clinical_binding_residue_list.
# The full ace2_binding_residue array is passed as both the covered set and the denominator.
clinical_antibody_neutralization_score = [
    neutralization_calculation(binding_residues, ace2_binding_residue, ace2_binding_residue)
    for binding_residues in clinical_binding_residue_list.values()
]


In [None]:
# Overlay the density-normalised score histograms of the clinical antibodies and the
# scores loaded from 'wt_neutralization_score.npy' on shared bin edges.
wt_neutralization_score = np.load('wt_neutralization_score.npy')
bins = np.arange(18) / 20  # bin edges 0.00, 0.05, ..., 0.85
histogram_series = (
    (clinical_antibody_neutralization_score, 'Clinical antibodies'),
    (wt_neutralization_score, 'Patient antibodies'),
)
for series_scores, legend_label in histogram_series:
    plt.hist(series_scores, density = True, bins = bins, label = legend_label, alpha = 0.5)
plt.ylim(0, 2.7)
plt.legend()
plt.xlabel('Predicted neutralization', fontsize = 20)
plt.title('WT neutralization', fontsize = 25)
#plt.savefig('WT neutralization histogram.png')

In [None]:
# Print "<antibody>: <score>" for each clinical antibody; scores are matched to the
# dictionary keys by position.
for i, antibody_name in enumerate(clinical_binding_residue_list):
    print(antibody_name + ': ' + str(clinical_antibody_neutralization_score[i]))

In [None]:
# Reload the saved scores so this cell does not depend on earlier cells having run,
# then show their raw-count distribution in 20 equal-width bins.
wt_neutralization_score = np.load('wt_neutralization_score.npy')
plt.hist(x = wt_neutralization_score, bins = 20)