In [None]:
import numpy as np

In [None]:
#!pip install plotly

In [None]:
import plotly.io as pio
# Figure export settings shared by the plotting cells.
# Plotly figures are sized width*dpi x height*dpi pixels; `scale` multiplies that again on PNG export.
width, height = 2.0, 2.0
dpi = 600
scale = 6
format = ".tif"  # NOTE(review): shadows the built-in format() and is not read anywhere in this notebook
text_color = "black"

In [None]:
# Structure identifiers, one per variant. Each is used as the basename of
# Data/PDB/<id>.pdb and, lower-cased, in the adjacency-matrix file name.
spikes = ["SPWT","Epsl","Zeta","Beta", "Alfa","Delt","Kapp","Gamm","Iota","Iot2","Eta1","Ihu1","Omi1","Omi5"]

# Thresholds embedded in the adjacency-matrix file names
# (presumably the min/max contact distance used to build the PCN -- TODO confirm units).
th_min = 4.0
th_max = 8.0

# NOTE(review): not referenced anywhere in this notebook; PDB paths are hard-coded as "Data/PDB/...".
pdbFilesPath = "PDB/"

# Variant display name -> array of residue positions (presumably the Spike mutation sites -- verify).
# The keys are paired POSITIONALLY with `spikes` (i-th key <-> i-th structure id), so both
# collections must list the variants in the same order. Only the keys are read in this notebook.
spike_muts_variants = {'Wild Type': np.array([]),
                         'Epsilon': np.array([ 13, 152, 452, 614]),
                         'Zeta': np.array([ 484,  614, 1176]),
                         'Beta': np.array([ 80, 215, 246, 417, 484, 501, 614, 701]),
                         'Alpha': np.array([ 501,  570,  614,  681,  716,  982, 1118]),
                         'Delta': np.array([158, 452, 478, 614, 681, 950]),
                         'Kappa': np.array([ 154,  484,  614,  681, 1071]),
                         'Gamma': np.array([ 138,  190,  417,  484,  501,  614,  655, 1027]),
                         'Iota1': np.array([  5,  95, 253, 477, 614]),
                         'Iota2': np.array([  5,  95, 253, 484, 614]),
                         'Eta': np.array([ 52,  67, 484, 614, 677, 888]),
                         'Ihu': np.array([  96,  190,  210,  346,  394,  449,  490,  501,  614,  681,  859,
                                 936, 1191]),
                         'Omicron1': np.array([ 67,  95, 142, 212, 339, 371, 373, 375, 417, 440, 446, 477, 478,
                                484, 493, 496, 498, 501, 505, 547, 614, 655, 679, 681, 764, 796,
                                856, 954, 969, 981]),
                         'Omicron5': np.array([142, 213, 339, 371, 373, 375, 376, 405, 408, 417, 440, 452, 477,
                                478, 484, 486, 498, 501, 505, 614, 655, 679, 681, 764, 796, 954,
                                969])
}
                

In [None]:
import pandas as pd 

# Load each variant's protein contact network (PCN) adjacency matrix, keyed by variant name.
# spikes[i] (structure id) is paired positionally with the i-th key of spike_muts_variants,
# so the two collections must have the same length.
A_s = dict()

if len(spikes) != len(spike_muts_variants):
    raise ValueError("spikes ({}) and spike_muts_variants ({}) must list the same number of variants"
                     .format(len(spikes), len(spike_muts_variants)))

for pdb, name in zip(spikes, spike_muts_variants.keys()):
    # Forward slashes resolve on both Windows and POSIX; a backslash path only works on Windows.
    adj_path = "Data/Adj/{}_adj_mat_{}_{}.txt".format(pdb.lower(), str(th_min), str(th_max))
    A_s[name] = np.loadtxt(adj_path)

In [None]:
# Sanity check: variant names for which an adjacency matrix was loaded.
A_s.keys()

In [None]:
def compute_contact_similarity(contact_sim_matrix):
    """Return the percentage of zero entries in a contact-difference matrix.

    Parameters
    ----------
    contact_sim_matrix : array_like
        Matrix where a non-zero entry marks a contact that differs between two
        variants (as produced by ``compute_contact_similarity_matrix``). Any
        shape is accepted; 2-D input behaves exactly as before.

    Returns
    -------
    float
        ``100 * (zero entries) / (total entries)``. 100.0 means identical contacts.

    Raises
    ------
    ValueError
        If the matrix has no entries, because the percentage would be 0/0.
    """
    diff = np.asarray(contact_sim_matrix)
    total = diff.size
    if total == 0:
        raise ValueError("contact_sim_matrix is empty; contact similarity is undefined")
    n_diff_contacts = np.count_nonzero(diff)
    return ((total - n_diff_contacts) / total) * 100.0

In [None]:
def readPDBFile(pdbFilePath):
    """Parse the ATOM records of a PDB file into an array of fixed-column string fields.

    Each row holds, in order: record name, atom serial, atom name, residue name,
    chain ID, residue sequence number, x, y, z, occupancy, temperature factor.
    All fields are the raw (unstripped) column slices defined by the PDB format.
    HETATM and every other record type are skipped.

    Returns
    -------
    numpy.ndarray
        Shape (n_atoms, 11) of strings, or an empty array if the file has no ATOM records.
    """
    atoms = []
    with open(pdbFilePath) as pdbfile:
        for line in pdbfile:
            if not line.startswith('ATOM'):
                continue
            atoms.append([
                line[:6],      # record name       (cols 1-6)
                line[6:11],    # atom serial       (cols 7-11)
                line[12:16],   # atom name         (cols 13-16)
                line[17:20],   # residue name      (cols 18-20)
                line[21],      # chain identifier  (col 22)
                line[22:26],   # residue seq. no.  (cols 23-26)
                line[30:38],   # x                 (cols 31-38)
                line[38:46],   # y                 (cols 39-46)
                line[46:54],   # z                 (cols 47-54)
                line[54:60],   # occupancy         (cols 55-60)
                line[60:66],   # temperature factor (cols 61-66)
            ])
    return np.array(atoms)

def getModelledResidues(pdbFilePath):
    """Return one ``"<resSeq> <chain>"`` label per residue present in a PDB file, in file order.

    A new residue starts whenever the (residue number, chain) pair changes between
    consecutive ATOM records. Tracking the chain as well as the number means a chain
    that begins with the same number the previous chain ended on is not merged into it.
    Starting from None (rather than 0) means a residue numbered 0 is not dropped.

    NOTE(review): the returned list is used to label adjacency-matrix nodes by index,
    so it must enumerate residues the same way the adjacency matrices were built -- confirm.
    """
    modelled_residues = []
    last_residue = None
    for atom in readPDBFile(pdbFilePath):
        resi = int(atom[5])  # residue sequence number
        chain = atom[4]
        if (resi, chain) != last_residue:
            modelled_residues.append(str(resi) + " " + chain)
            last_residue = (resi, chain)
    return modelled_residues

def map_Graph(G, residue_names):
    """Relabel the integer nodes of ``G`` with residue names.

    Node ``k`` is renamed to ``residue_names[k]``. Nodes without a matching name keep
    their label. Returns the result of ``nx.relabel_nodes`` with its default ``copy=True``.
    """
    index_to_name = dict(enumerate(residue_names))
    return nx.relabel_nodes(G, index_to_name)
    
def removeUnmodelled(G1, G2):
    """Restrict two graphs to the nodes they have in common.

    Parameters
    ----------
    G1, G2 : graph
        Graphs whose node labels are comparable (e.g. ``"<resSeq> <chain>"`` strings).

    Returns
    -------
    (graph, graph)
        Copies of ``G1`` and ``G2`` from which every node absent in the other graph
        has been removed. The inputs are left untouched, so the same graph can be
        compared against several others without its node set shrinking between calls.
    """
    nodes1 = set(G1.nodes())
    nodes2 = set(G2.nodes())

    G1_common = G1.copy()
    G2_common = G2.copy()
    # Set membership keeps this linear instead of O(n1 * n2) list scans.
    G1_common.remove_nodes_from([node for node in G1.nodes() if node not in nodes2])
    G2_common.remove_nodes_from([node for node in G2.nodes() if node not in nodes1])
    return G1_common, G2_common

In [None]:
def plot_contact_sim(contact_sim, var1, var2):
    """Build (but do not show) a heatmap of a pairwise contact-difference matrix.

    Parameters
    ----------
    contact_sim : 2-D array
        Difference matrix; a non-zero entry marks a contact that differs between the two variants.
    var1, var2 : str
        Variant names, shown in the figure title.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    n_rows, n_cols = contact_sim.shape[0], contact_sim.shape[1]
    fig = go.Figure(data=go.Heatmap(
        z=contact_sim,
        x=np.arange(0, n_cols),  # columns index the x axis; matters when the matrix is not square
        y=np.arange(0, n_rows),
        colorscale='OrRd',
    ))
    fig.update_layout(title="Contact differences: {} vs {}".format(var1, var2))
    return fig

In [None]:
import plotly.graph_objects as go

In [None]:
def compute_contact_similarity_matrix(A1, A2):
    """Return a float mask of the entries where two adjacency matrices differ.

    Parameters
    ----------
    A1, A2 : array_like
        Adjacency matrices of equal shape whose rows/columns refer to the same residues.

    Returns
    -------
    numpy.ndarray
        float64 array of A1's shape: 1.0 where ``A1 != A2``, 0.0 elsewhere.

    Raises
    ------
    ValueError
        If the shapes differ. Such input would otherwise be broadcast silently or fail with an obscure error.
    """
    A1 = np.asarray(A1)
    A2 = np.asarray(A2)
    if A1.shape != A2.shape:
        raise ValueError("adjacency matrices must have the same shape, got {} and {}"
                         .format(A1.shape, A2.shape))
    # Vectorized replacement for an element-by-element Python double loop.
    return (A1 != A2).astype(float)

In [None]:
import networkx as nx 

# Pairwise contact similarity between all variants.
# contacts_sim[i, j] is the percentage of identical adjacency entries between variants
# i and j, restricted to the residues modelled in both structures.
n = len(spikes)
contacts_sim_mat = {}
figures = {}
contacts_sim = np.zeros((n, n), dtype = float)

variant_names = list(A_s.keys())  # same order as `spikes`; A_s was filled by iterating over it

# Build each variant's residue-labelled graph once instead of once per pair.
# nx.from_numpy_array replaces nx.from_numpy_matrix, which was removed in NetworkX 3.0.
labelled_graphs = {}
for pdb, name in zip(spikes, variant_names):
    residues = getModelledResidues("Data/PDB/{}.pdb".format(pdb))
    labelled_graphs[name] = map_Graph(nx.from_numpy_array(A_s[name]), residues)

for i, var1 in enumerate(variant_names):
    # Lower triangle only; each unordered pair is computed once and mirrored.
    for j in range(i):
        var2 = variant_names[j]
        print(var1, var2)

        # Pass fresh copies so that pruning to common residues for this pair
        # can never leak into the comparison of the next pair.
        G_variant1_new, G_variant2_new = removeUnmodelled(labelled_graphs[var1].copy(),
                                                          labelled_graphs[var2].copy())
        # One explicit node order for both matrices, so entry (r, c) refers to the
        # same residue pair in A1_copy and A2_copy.
        common_nodes = list(G_variant1_new.nodes())
        A1_copy = np.array(nx.adjacency_matrix(G_variant1_new, nodelist=common_nodes).todense())
        A2_copy = np.array(nx.adjacency_matrix(G_variant2_new, nodelist=common_nodes).todense())

        contact_sim_matrix = compute_contact_similarity_matrix(A1_copy, A2_copy)
        contacts_sim_mat[var1+"!="+var2] = contact_sim_matrix
        contact_sim = compute_contact_similarity(contact_sim_matrix)
        contacts_sim[i, j] = contact_sim
        contacts_sim[j, i] = contact_sim
        print("Contact similarity between {} and {} Spike SARS-CoV-2 variants: {} %".format(var1, var2, contact_sim))

    print("i == j, Contact similarity between {} and {} Spike SARS-CoV-2 variants: {} %".format(var1, var1, 100.0))
    contacts_sim[i, i] = 100.0

In [None]:
import pandas as pd
# Persist the pairwise similarity matrix, labelled by variant name on both axes.
# A forward-slash path works on Windows and POSIX alike. A backslash path is Windows-only,
# and "\c" is an invalid escape sequence that newer Pythons warn about.
df_cs = pd.DataFrame(contacts_sim, columns = list(A_s.keys()), index = list(A_s.keys()))
df_cs.to_csv("Data/contact_sim_variants.csv")

In [None]:
#!pip install -U kaleido

In [None]:
#!pip install --upgrade nbformat

In [None]:
#plot contact sim 
import plotly.figure_factory as ff
from scipy.spatial.distance import pdist, squareform

import os

ticksuffix = "                             "

# Rows/columns of contacts_sim follow the insertion order of spike_muts_variants
# (A_s was filled in that order), so the labels must follow the same order. Sorting
# them alphabetically would attach the wrong variant name to each leaf.
variant_names = list(spike_muts_variants.keys())
i_no_mutant = [i for i, name in enumerate(variant_names) if "Mutant" not in name]
labels_no_mutant = np.array([variant_names[i] for i in i_no_mutant])
n_no_mut = labels_no_mutant.shape[0]
n = len(variant_names)
contacts_sim_nomutant = contacts_sim[np.ix_(i_no_mutant, i_no_mutant)]
print(contacts_sim_nomutant)

# Top dendrogram, drawn on the y2 axis above the heatmap.
fig = ff.create_dendrogram(contacts_sim_nomutant, orientation='bottom', labels=labels_no_mutant)
for trace in fig['data']:
    trace['yaxis'] = 'y2'

# Side dendrogram, drawn on the x2 axis to the left of the heatmap.
dendro_side = ff.create_dendrogram(contacts_sim_nomutant, orientation='right')
for trace in dendro_side['data']:
    trace['xaxis'] = 'x2'
    fig.add_trace(trace)

# Heatmap: reorder rows and columns by the side dendrogram's leaf order and
# place them at the dendrograms' leaf positions.
dendro_leaves = list(map(int, dendro_side['layout']['yaxis']['ticktext']))
heat_data = contacts_sim_nomutant[dendro_leaves, :][:, dendro_leaves]
leaf_tickvals_x = fig['layout']['xaxis']['tickvals']
leaf_tickvals_y = dendro_side['layout']['yaxis']['tickvals']
fig.add_trace(go.Heatmap(
    x=leaf_tickvals_x,
    y=leaf_tickvals_y,
    z=heat_data,
    colorscale='Blues',
))

# One bold label per leaf, in the same order as heat_data.
leaf_ticktext = ["<b>{}</b>{}".format(label, ticksuffix) for label in labels_no_mutant[dendro_leaves]]

# Layout
fig.update_layout({'width': width*dpi, 'height': height*dpi,
                   'showlegend': False, 'hovermode': 'closest',
                   'title_font_color': text_color, 'font_color': text_color})
axis_base = {'mirror': False, 'showgrid': False, 'showline': False, 'zeroline': False}
fig.update_layout(
    xaxis={**axis_base, 'domain': [.15, 1], 'tickmode': "array",
           'tickvals': leaf_tickvals_x, 'ticktext': leaf_ticktext},
    xaxis2={**axis_base, 'domain': [0, .15], 'showticklabels': False, 'ticks': ""},
    yaxis={**axis_base, 'domain': [0, .85], 'showticklabels': True, 'tickmode': "array",
           'tickvals': leaf_tickvals_y, 'ticktext': leaf_ticktext},
    yaxis2={**axis_base, 'domain': [.825, .975], 'showticklabels': False, 'ticks': ""},
)

# Plot!
fig.show()

# Save. engine="orca" is deprecated in plotly; the default engine uses kaleido when it is installed.
os.makedirs("Figures/ContactSimilarity", exist_ok=True)
fig.write_html("Figures/ContactSimilarity/contactDistNormNoMut.html")
pio.write_image(fig, "Figures/ContactSimilarity/contactDistNormNoMut.png", width=width*dpi, height=height*dpi, scale=scale)

In [None]:
#plot contact sim 
import plotly.figure_factory as ff
from scipy.spatial.distance import pdist, squareform

import os

ticksuffix = "                             "

# Labels in the row/column order of contacts_sim (insertion order of spike_muts_variants).
# Sorting them alphabetically would attach the wrong variant name to each leaf.
labels = np.array(list(spike_muts_variants.keys()))
n = labels.shape[0]

# Top dendrogram, drawn on the y2 axis above the heatmap.
fig = ff.create_dendrogram(contacts_sim, orientation='bottom', labels=labels)
for trace in fig['data']:
    trace['yaxis'] = 'y2'

# Side dendrogram, drawn on the x2 axis to the left of the heatmap.
dendro_side = ff.create_dendrogram(contacts_sim, orientation='right')
for trace in dendro_side['data']:
    trace['xaxis'] = 'x2'
    fig.add_trace(trace)

# Heatmap: reorder rows and columns by the side dendrogram's leaf order and
# place them at the dendrograms' leaf positions.
dendro_leaves = list(map(int, dendro_side['layout']['yaxis']['ticktext']))
heat_data = contacts_sim[dendro_leaves, :][:, dendro_leaves]
leaf_tickvals_x = fig['layout']['xaxis']['tickvals']
leaf_tickvals_y = dendro_side['layout']['yaxis']['tickvals']
fig.add_trace(go.Heatmap(
    x=leaf_tickvals_x,
    y=leaf_tickvals_y,
    z=heat_data,
    colorscale='Blues',
))

# One bold label per leaf, in the same order as heat_data.
leaf_ticktext = ["<b>{}</b>{}".format(label, ticksuffix) for label in labels[dendro_leaves]]

# Layout
fig.update_layout({'width': width*dpi, 'height': height*dpi,
                   'showlegend': False, 'hovermode': 'closest',
                   'title_font_color': text_color, 'font_color': text_color})
axis_base = {'mirror': False, 'showgrid': False, 'showline': False, 'zeroline': False}
fig.update_layout(
    xaxis={**axis_base, 'domain': [.15, 1], 'tickmode': "array",
           'tickvals': leaf_tickvals_x, 'ticktext': leaf_ticktext},
    xaxis2={**axis_base, 'domain': [0, .15], 'showticklabels': False, 'ticks': ""},
    yaxis={**axis_base, 'domain': [0, .85], 'showticklabels': True, 'tickmode': "array",
           'tickvals': leaf_tickvals_y, 'ticktext': leaf_ticktext},
    yaxis2={**axis_base, 'domain': [.825, .975], 'showticklabels': False, 'ticks': ""},
)

# Plot!
fig.show()

# Save. engine="orca" is deprecated in plotly; the default engine uses kaleido when it is installed.
os.makedirs("Figures/ContactSimilarity", exist_ok=True)
fig.write_html("Figures/ContactSimilarity/contactDistNorm.html")
pio.write_image(fig, "Figures/ContactSimilarity/contactDistNorm.png", width=width*dpi, height=height*dpi, scale=scale)

In [None]:
import os

# Plain heatmap of the pairwise contact-similarity matrix, rows/columns in variant order.
N = contacts_sim.shape[0]
labels = list(spike_muts_variants.keys())
labels = ["<b>{}</b>".format(label) for label in labels]
assert len(labels) == N, "one label per row/column of contacts_sim is required"

fig = go.Figure(data=go.Heatmap(
    z=contacts_sim,
    x=np.arange(0, N),
    y=np.arange(0, N),
    colorscale='OrRd',
))

# Both axes share identical styling: one bold tick label per variant.
axis_style = {'mirror': False,
              'showgrid': False,
              'showline': False,
              'zeroline': False,
              'showticklabels': True,
              'ticklabelposition': "outside bottom",
              'tickmode': "array",
              'tickvals': np.arange(0, N, 1),
              'ticktext': labels,
              'tickfont_size': 20}
fig.update_layout(yaxis=dict(axis_style), xaxis=dict(axis_style))
fig.update_layout(width=width*dpi, height=height*dpi, title_font_color = text_color, font_color = text_color)
fig.show()

# Save. engine="orca" is deprecated in plotly; the default engine uses kaleido when it is installed.
os.makedirs("Figures/ContactSimilarity", exist_ok=True)
fig.write_html("Figures/ContactSimilarity/contactSim.html")
pio.write_image(fig, "Figures/ContactSimilarity/contactSim.png", width=width*dpi, height=height*dpi, scale=scale)

In [None]:
# Show the contact-difference matrix for one chosen pair of variants.
var1 = "Omicron1"
var2 = "Omicron5"
# Each pair is stored once, under whichever order the pairwise loop produced; try both.
for key in (var1+"!="+var2, var2+"!="+var1):
    if key in contacts_sim_mat:
        break
else:
    raise KeyError("No contact-difference matrix stored for {} vs {}".format(var1, var2))
fig = plot_contact_sim(contacts_sim_mat[key], var1, var2)
fig.show()

In [None]:
# Inspect the difference matrix left over from the last pair processed by the pairwise loop.
contact_sim_matrix

In [None]:
# Echo the variant pair most recently assigned to var1/var2.
var1, var2

In [None]:
# Locate the least similar pair: the smallest entry of contacts_sim strictly between 0 and 100.
# Ties resolve to the first cell in row-major order. If no entry qualifies, the result
# falls back to 100.0 at (0, 0).
candidates = [
    (contacts_sim[row][col], row, col)
    for row in range(n)
    for col in range(n)
    if 0.0 < contacts_sim[row][col] < 100.0
]
curr_min, curr_i_min, curr_j_min = min(candidates, key=lambda cell: cell[0], default=(100.0, 0, 0))

print("Min {} found for i = {} and j = {}".format(curr_min, curr_i_min, curr_j_min))
print(curr_i_min, spikes[curr_i_min])
print(curr_j_min, spikes[curr_j_min])