In [None]:
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import matplotlib.pyplot as plt
from umap import UMAP
from scipy.stats import spearmanr, pearsonr
from sklearn import metrics
import glob, pymol, lmdb, os
import pickle as pkl
from skgstat.models import spherical
from scipy.linalg import solve
from math import dist
from scipy.spatial.distance import pdist
from scipy.spatial.distance import squareform
from sklearn import metrics
from venn import venn
import scipy.cluster.hierarchy as sch
import radialtree as rt
from adjustText import adjust_text
import plotly.offline as go_offline
import plotly.graph_objects as go
from scipy.interpolate import LinearNDInterpolator


In [None]:
# Batch-1 patient antibodies. Four of the original 23 (7_8, 51_53, 116_117,
# 118_119) have been removed from the analysis.
antibody_list_batch1 = [
    '1_2', '11_12', '13_14', '15_16', '17_18', '19_20', '25_26',
    '31_32', '43_44', '51_52', '86_87', '88_89', '92_93', '94_95',
    '96_97', '104_105', '106_107', '108_109', '110_111',
]

# Clinical antibodies, identified by short code.
clinical_antibody_list = [
    'ADI', 'AMU', 'BAM', 'BEB', 'C135', 'C144', 'CAS',
    'CIL', 'ETE', 'IMD', 'REG', 'ROM', 'SOT', 'TIX',
]

# Residue positions treated as ACE2-binding (presumably spike RBD numbering -- confirm).
ace2_binding_residue = np.array([
    417, 446, 447, 449, 453, 455, 456, 473, 475, 476, 477, 484, 486,
    487, 489, 490, 493, 494, 495, 496, 498, 500, 501, 502, 503, 505,
])

# NOTE(review): allow_pickle=True unpickles the file's contents; only load trusted files.
seq_dict_all = np.load('seq_dict_all_filtered.npy', allow_pickle=True).item()
antibody_list_all_with_clinical = np.load('antibody_list_all_with_clinical.npy')
print(len(antibody_list_all_with_clinical))

In [None]:
# Experimental results.

# Clinical antibodies: per-variant fold-change values (log10-transformed later in
# the notebook). 'NA' marks a missing measurement; 1000 appears to act as an
# upper cap -- confirm against the source data.
clinical_antibody_fc = {
    name: dict(zip(('Delta', 'Omicron BA1', 'Omicron BA5'), values))
    for name, values in (
        ('ADI',  (1.5,  108,  935)),
        ('BAM',  (1000, 1000, 686)),
        ('BEB',  (1,    1,    1)),
        ('CAS',  (0.7,  1000, 1000)),
        ('CIL',  (2.1,  1000, 9.4)),
        ('ETE',  (0.5,  414,  444)),
        ('IMD',  (2.1,  1000, 633)),
        ('SOT',  (1.3,  3.8,  16)),
        ('TIX',  (1,    306,  1000)),
        ('REG',  (28,   1000, 1000)),
        ('AMU',  (0.6,  136,  116)),
        ('ROM',  ('NA', 0.8,  64)),
        ('C135', (0.4,  1000, 1000)),
        ('C144', (2.5,  1000, 1000)),
    )
}

# Patient antibodies: EC50 against WT and each variant. The value 250 is used for
# many entries and is presumably the assay ceiling (no neutralisation) -- confirm.
# Removed antibodies (kept here for the record):
#   7_8     (250, 20.33, 250, 250)
#   118_119 (250, 250,   250, 10.17)
#   51_53   (250, 202.8, 250, 250)
#   116_117 (250, 250,   250, 250)
patient_antibody_ec50 = {
    antibody: dict(zip(('WT', 'Delta', 'Omicron BA1', 'Omicron BA5'), ec50s))
    for antibody, ec50s in (
        ('1_2',     (5.698,  0.5436,  250,   250)),
        ('13_14',   (250,    250,     250,   250)),
        ('15_16',   (250,    250,     250,   250)),
        ('17_18',   (250,    250,     250,   250)),
        ('19_20',   (250,    136.7,   250,   250)),
        ('25_26',   (69.64,  133.8,   250,   250)),
        ('31_32',   (0.4067, 0.05577, 250,   250)),
        ('43_44',   (250,    250,     250,   250)),
        ('51_52',   (250,    250,     250,   250)),
        ('86_87',   (250,    250,     250,   250)),
        ('88_89',   (17.89,  7.167,   250,   5.296)),
        ('92_93',   (250,    250,     250,   250)),
        ('94_95',   (250,    250,     250,   250)),
        ('96_97',   (250,    250,     250,   250)),
        ('104_105', (250,    250,     250,   250)),
        ('108_109', (3.858,  0.4388,  31.27, 250)),
        ('110_111', (250,    250,     250,   250)),
        ('106_107', (5.542,  1.341,   10.23, 146.6)),
        ('11_12',   (1.587,  250,     2.276, 250)),
    )
}

# Single-antibody binding-inhibition readout; 'undetectable' when no signal.
# Removed antibodies 7_8, 51_53, 116_117 and 118_119 were all 'undetectable'.
single_antibody_binding_inhibition = dict([
    ('1_2',     32.14),
    ('11_12',   'undetectable'),
    ('13_14',   'undetectable'),
    ('15_16',   'undetectable'),
    ('17_18',   'undetectable'),
    ('19_20',   32.84),
    ('25_26',   34.76),
    ('31_32',   5.667),
    ('43_44',   'undetectable'),
    ('51_52',   'undetectable'),
    ('86_87',   'undetectable'),
    ('88_89',   30.68),
    ('92_93',   'undetectable'),
    ('94_95',   'undetectable'),
    ('96_97',   'undetectable'),
    ('104_105', 'undetectable'),
    ('106_107', 2.997),
    ('108_109', 104.5),
    ('110_111', 'undetectable'),
])

In [None]:
# Build one language-model embedding per antibody (heavy chain followed by light
# chain), a 2-D UMAP projection of them, and the pairwise embedding-distance matrix
# that the kriging step reads back from disk.
language_model_name = ['esm1v_t33_650M_UR90S_1',
                       'esm1v_t33_650M_UR90S_2',
                       'esm1v_t33_650M_UR90S_3',
                       'esm1v_t33_650M_UR90S_4',
                       'esm1v_t33_650M_UR90S_5',
                       'esm2_650m',
                       'esm2_3b',
                       'esm2_15b',
                       'ProtT5'][8]
umap_2d = UMAP(n_components=2, init='random', random_state=0, metric='euclidean')
# NOTE(review): allow_pickle=True unpickles the file's contents; only load trusted files.
chain_embedding = np.load('antibody_language_embedding/covid_' + language_model_name + '_embedding.npy', allow_pickle = True).item()

# Per-chain embedding width, read from the data. The former constant 1024 only fit
# ProtT5 and may not match the other models in the list above.
dimension = np.asarray(chain_embedding[antibody_list_all_with_clinical[0] + '_H']).size
language_model_embedding = np.zeros((len(antibody_list_all_with_clinical), dimension * 2))
for i, antibody_i in enumerate(antibody_list_all_with_clinical):
    language_model_embedding[i, :dimension] = chain_embedding[antibody_i + '_H']
    language_model_embedding[i, dimension:] = chain_embedding[antibody_i + '_L']
proj_2d = umap_2d.fit_transform(language_model_embedding)

# Pairwise Euclidean distances between embeddings (despite the variable name,
# larger means less similar). Only the strict upper triangle (i < j) is kept, with
# zeros on and below the diagonal, matching the format of the cached file.
chain_similarity = np.triu(squareform(pdist(language_model_embedding, metric='euclidean')))
np.save('language model ' + language_model_name + ' chain similarity.npy', chain_similarity)
#np.save('umap_projection_2d.npy', proj_2d)

In [None]:
# Ordinary kriging: predict each non-clinical antibody's log10 fold change against a
# variant from the measured clinical-antibody fold changes. Distance between
# language-model embeddings plays the role of the spatial lag.

# NOTE(review): `clinical_tested_boolean` was used here (and in the validation cell)
# but never defined in this notebook, so Restart & Run All raised NameError. It is
# reconstructed as "entry is one of the clinical antibodies"; confirm this matches
# the intended definition.
clinical_tested_boolean = np.isin(antibody_list_all_with_clinical, clinical_antibody_list)

# The cached matrix only fills the strict upper triangle (i < j). Mirror it so a
# lookup is correct whatever the relative order of the two antibodies. Previously,
# any clinical antibody listed before the target was silently read as distance 0.
chain_similarity = np.load('language model ' + language_model_name + ' chain similarity.npy')
chain_distance = np.triu(chain_similarity) + np.triu(chain_similarity, k=1).T


def _kriging_predictions(variant, distance_matrix, antibody_names, target_mask,
                         variogram_on_target=True):
    """Ordinary-kriging estimate of log10 fold change against `variant`.

    Parameters
    ----------
    variant : str
        Key of the per-antibody entries in ``clinical_antibody_fc`` (e.g. 'Delta').
        Clinical antibodies whose value is 'NA' are left out of the data set.
    distance_matrix : np.ndarray
        Symmetric pairwise distance matrix aligned with ``antibody_names``.
    antibody_names : np.ndarray
        Antibody identifiers, one per row/column of ``distance_matrix``.
    target_mask : np.ndarray of bool
        Selects the antibodies to predict.
    variogram_on_target : bool
        If True (default), the spherical variogram is applied to the target-to-data
        distances as well as to the data-to-data distances, as the ordinary-kriging
        system requires. False reproduces the earlier notebook, which put raw
        distances on the right-hand side.

    Returns
    -------
    np.ndarray
        One prediction per selected antibody, in ``antibody_names`` order.
    """
    labeled = [ab for ab in clinical_antibody_list
               if clinical_antibody_fc[ab][variant] != 'NA']
    labels = np.log10(np.array([clinical_antibody_fc[ab][variant] for ab in labeled], dtype=float))
    data_index = np.array([np.where(antibody_names == ab)[0][0] for ab in labeled], dtype=int)
    n_data = len(labeled)

    def variogram(h):
        # Spherical model: range = n_data + 2 (as before), sill 1, nugget 0.
        return np.asarray(spherical(h, n_data + 2, 1, 0.0), dtype=float)

    # Kriging matrix: data-to-data semivariances, bordered by the unbiasedness
    # constraint (weights sum to 1) through a Lagrange multiplier. It does not
    # depend on the target, so it is built once per variant.
    data_distances = distance_matrix[np.ix_(data_index, data_index)]
    kriging_matrix = np.ones((n_data + 1, n_data + 1))
    kriging_matrix[:n_data, :n_data] = squareform(variogram(squareform(data_distances, checks=False)))
    kriging_matrix[n_data, n_data] = 0.0

    predictions = []
    for target in np.flatnonzero(target_mask):
        target_distances = distance_matrix[target, data_index]
        rhs = variogram(target_distances) if variogram_on_target else target_distances
        weights = solve(kriging_matrix, np.append(rhs, 1.0))
        predictions.append(labels.dot(weights[:-1]))
    return np.array(predictions, dtype=float)


_kriging_targets = np.invert(clinical_tested_boolean)
kriging_prediction_results_delta = _kriging_predictions(
    'Delta', chain_distance, antibody_list_all_with_clinical, _kriging_targets)
kriging_prediction_results_omicron_ba1 = _kriging_predictions(
    'Omicron BA1', chain_distance, antibody_list_all_with_clinical, _kriging_targets)
kriging_prediction_results_omicron_ba5 = _kriging_predictions(
    'Omicron BA5', chain_distance, antibody_list_all_with_clinical, _kriging_targets)


In [None]:
# Window applied to predicted log10 fold change before plotting.
LOG10_FC_MIN, LOG10_FC_MAX = -1, 3

# Fraction of the raw kriging predictions already inside the window (Delta, BA.1, BA.5).
for predictions in (kriging_prediction_results_delta,
                    kriging_prediction_results_omicron_ba1,
                    kriging_prediction_results_omicron_ba5):
    print(np.mean((predictions >= LOG10_FC_MIN) & (predictions <= LOG10_FC_MAX)))

# Clip into the window, then negate so that positive values mean a predicted fold
# *improvement*. np.clip returns new arrays, so the arrays produced by the kriging
# cell are not modified in place. Previously the BA.5 result was also altered
# through its `kriging_prediction_results` alias.
# WARNING: not idempotent -- running this cell twice flips the sign back.
kriging_prediction_results_delta = -np.clip(kriging_prediction_results_delta, LOG10_FC_MIN, LOG10_FC_MAX)
kriging_prediction_results_omicron_ba1 = -np.clip(kriging_prediction_results_omicron_ba1, LOG10_FC_MIN, LOG10_FC_MAX)
kriging_prediction_results_omicron_ba5 = -np.clip(kriging_prediction_results_omicron_ba5, LOG10_FC_MIN, LOG10_FC_MAX)


In [None]:
# Pair kriging predictions with measured patient-antibody EC50 shifts.
# Measured fold change = EC50(variant) / EC50(WT). Its reciprocal is the "fold
# improvement", which matches the sign convention of the negated predictions.
predicted_antibodies = antibody_list_all_with_clinical[np.invert(clinical_tested_boolean)]
kriging_results_by_variant = {
    'Delta': kriging_prediction_results_delta,
    'Omicron BA1': kriging_prediction_results_omicron_ba1,
    'Omicron BA5': kriging_prediction_results_omicron_ba5,
}


def _predicted_improvement(variant, antibody):
    """Kriging-predicted (clipped, negated) log10 fold improvement for one pair."""
    index = np.where(predicted_antibodies == antibody)[0][0]
    return kriging_results_by_variant[variant][index]


def _measured_fold_change(variant, antibody):
    """Measured EC50(variant) / EC50(WT) from ``patient_antibody_ec50``."""
    ec50 = patient_antibody_ec50[antibody]
    return ec50[variant] / ec50['WT']


# NOTE(review): ratios are now computed from `patient_antibody_ec50` instead of being
# retyped by hand. This corrects two transcription slips in the earlier literals:
# 19_20 Delta used 136.7 / 416 although its recorded WT EC50 is 250, and 88_89 BA.5
# used 5.295 instead of the recorded 5.296. Confirm 416 was not a deliberate value.

# (variant, antibody) pairs where neither the WT nor the variant EC50 is 250.
pairs_both_measured = (
    [('Delta', ab) for ab in ['1_2', '25_26', '31_32', '88_89', '106_107', '108_109']]
    + [('Omicron BA1', ab) for ab in ['11_12', '106_107', '108_109']]
    + [('Omicron BA5', ab) for ab in ['88_89', '106_107']]
)
# (variant, antibody) pairs where exactly one of the two EC50 values is 250.
pairs_one_censored = (
    [('Delta', ab) for ab in ['11_12', '19_20']]
    + [('Omicron BA1', ab) for ab in ['1_2', '25_26', '31_32', '88_89']]
    + [('Omicron BA5', ab) for ab in ['1_2', '11_12', '25_26', '31_32', '108_109']]
)

susceptibility_prediction_all_available = np.array(
    [_predicted_improvement(v, ab) for v, ab in pairs_both_measured])
experiment_results_all_available = 1 / np.array(
    [_measured_fold_change(v, ab) for v, ab in pairs_both_measured])
# Alias kept for backward compatibility with the previously used (misspelled) name.
experimental_results_all_available = experiment_results_all_available

# Despite the "only one available" name, this set holds the one-censored pairs
# followed by all fully measured pairs.
pairs_only_one_available = pairs_one_censored + pairs_both_measured
susceptibility_prediction_only_one_available = np.array(
    [_predicted_improvement(v, ab) for v, ab in pairs_only_one_available])
experiment_results_only_one_available = 1 / np.array(
    [_measured_fold_change(v, ab) for v, ab in pairs_only_one_available])

susceptibility_predicted_getting_worse_boolean = susceptibility_prediction_only_one_available < 0
susceptibility_predicted_getting_better_boolean = susceptibility_prediction_only_one_available >= 0

In [None]:
# Measured log10 EC50 fold improvement, grouped by the sign of the kriging
# prediction: position 1 = predicted worse (< 0), position 2 = predicted better (>= 0).
boxplot_list = [
    np.log10(experiment_results_only_one_available[susceptibility_predicted_getting_worse_boolean]),
    np.log10(experiment_results_only_one_available[susceptibility_predicted_getting_better_boolean]),
]

# Seeded so the strip-plot jitter (and thus the saved figure) is reproducible.
jitter_rng = np.random.default_rng(0)

fig, ax = plt.subplots(figsize=(8, 6), dpi=300)
ax.boxplot(boxplot_list)
for position, values in enumerate(boxplot_list, start=1):
    jittered_x = jitter_rng.normal(position, 0.04, size=len(values))
    ax.scatter(jittered_x, values, c='red', alpha=0.6, s=20)

ax.set_xticks([1, 2])
ax.set_xticklabels(['Predicted worse (Negative)', 'Predicted better (Positive)'], fontsize=15 * 1.2)
ax.set_ylabel('log10 EC50 fold improvement', fontsize=20 * 1.2)
ax.axhline(y=0, linestyle='--')
fig.tight_layout()
#fig.savefig('Antibody robustness prediction validation ' + language_model_name + '.png', facecolor = 'white')