Import modules

In [1]:
import os
import pickle
import numpy as np
import pandas as pd
import geometricus as gm
from sklearn.decomposition import PCA
from variables import CAZY_DATA, STRUCTURE_CLUSTERING
from numba.core.errors import NumbaDeprecationWarning, NumbaPendingDeprecationWarning
import warnings
import umap as UMAP
import plotly.graph_objects as go

# Suppress numba deprecation warnings so they do not clutter cell output
for numba_warning in (NumbaDeprecationWarning, NumbaPendingDeprecationWarning):
    warnings.simplefilter('ignore', category=numba_warning)

  @numba.jit()
  @numba.jit()
  @numba.jit()
  @numba.jit()


Handling arguments

In [2]:
# Arguments
structures_folder = '../../Data/AF_core'  # directory of input structure files
resolution = 10  # passed to Geometricus.from_invariants; presumably sets how finely invariants are binned into shapemers — confirm in geometricus docs
n_threads = 60   # worker threads for invariant computation
kmer = 8         # k-mer fragment size; a falsy value disables k-mer fragmentation
radius = 5       # radius fragment size; a falsy value disables radius fragmentation

# Collect the enabled fragmentation strategies (k-mer first, then radius);
# at least one of them must be enabled.
split = []
if kmer:
    split.append(gm.SplitInfo(gm.SplitType.KMER, kmer))
if radius:
    split.append(gm.SplitInfo(gm.SplitType.RADIUS, radius))
if not split:
    raise ValueError('Fragmentation types cannot be both null')

Run Geometricus experiment

In [3]:
# Run Geometricus: compute moment invariants for the structures in the folder
invariants, _ = gm.get_invariants_for_structures(
    structures_folder,
    n_threads=n_threads,
    split_infos=split,
    moment_types=["O_3", "O_4", "O_5", "F"],
)

# NOTE(review): protein_keys comes straight from os.listdir, whose order is arbitrary;
# this assumes it lines up with the order of `invariants` and that the folder holds
# only structure files — confirm against geometricus' own file ordering.
shapemer_class = gm.Geometricus.from_invariants(
    invariants,
    protein_keys=list(os.listdir(structures_folder)),
    resolution=resolution,
)

shapemer_count_matrix = shapemer_class.get_count_matrix()

# Normalization for protein length: divide each protein's row by its total shapemer
# count so each row sums to 1 (a row with zero total would become NaN).
shapemer_sum = np.sum(shapemer_count_matrix, axis=1)
normalized_matrix = shapemer_count_matrix / shapemer_sum[:, np.newaxis]

Found 3184 protein structures


  0%|          | 0/3184 [00:00<?, ?it/s]@> 954 atoms and 1 coordinate set(s) were parsed in 0.16s.
@> 1164 atoms and 1 coordinate set(s) were parsed in 0.14s.
@> 1088 atoms and 1 coordinate set(s) were parsed in 0.18s.
@> 1068 atoms and 1 coordinate set(s) were parsed in 0.18s.
@> 1075 atoms and 1 coordinate set(s) were parsed in 0.19s.
@> 1078 atoms and 1 coordinate set(s) were parsed in 0.18s.
@> 1382 atoms and 1 coordinate set(s) were parsed in 0.22s.
@> 1004 atoms and 1 coordinate set(s) were parsed in 0.20s.
@> 1121 atoms and 1 coordinate set(s) were parsed in 0.21s.
@> 1405 atoms and 1 coordinate set(s) were parsed in 0.19s.
@> 1466 atoms and 1 coordinate set(s) were parsed in 0.17s.
@> 1266 atoms and 1 coordinate set(s) were parsed in 0.23s.
@> 1404 atoms and 1 coordinate set(s) were parsed in 0.23s.
@> 1543 atoms and 1 coordinate set(s) were parsed in 0.18s.
@> 1178 atoms and 1 coordinate set(s) were parsed in 0.23s.
@> 1093 atoms and 1 coordinate set(s) were parsed in 0.24s.
@

Computed invariants in 56.03 seconds


100%|██████████| 3184/3184 [00:02<00:00, 1332.82it/s]


Store the normalized shapemer matrix (rows labelled by protein, columns by shapemer)

In [10]:
# Row labels: protein keys with '.pdb' stripped; column labels: shapemer keys
proteins = list(map(lambda key: key.replace('.pdb', ''), shapemer_class.protein_keys))
shapemers = shapemer_class.shapemer_keys
matrix = pd.DataFrame(data=normalized_matrix, index=proteins, columns=shapemers)
matrix.to_pickle(f'{STRUCTURE_CLUSTERING}/default_core.pkl')

Dimensionality reduction

In [11]:
# UMAP is stochastic; fix its seed so the embedding (and every plot built from it)
# is reproducible on Restart & Run All.
# NOTE: umap-learn may disable some parallelism when random_state is set, so this can run slower.
UMAP_SEED = 42
umap_space = UMAP.UMAP(
    metric="cosine",
    n_components=2,
    random_state=UMAP_SEED,
).fit_transform(normalized_matrix)

# PCA with the default n_components (all components kept), so the scree plot covers the full spectrum
pca = PCA()
pca_space = pca.fit_transform(normalized_matrix)

# PCA scree plot data: 1-based component numbers and cumulative explained variance ratio
PC_values = np.arange(1, pca.n_components_ + 1)
PC_cumulative = np.cumsum(pca.explained_variance_ratio_)

Add annotations to embeddings and prepare them for plotting

In [17]:
# Add family: load the protein -> family mapping
# NOTE(review): pickle.load can execute arbitrary code; only load this file from a trusted source.
with open(f'{CAZY_DATA}/uniprot_family.pkl', 'rb') as family_file:
    dic = pickle.load(family_file)

# Add annotations to embeddings (first two components only, rows indexed by protein)
umap = pd.DataFrame(umap_space[:, 0:2], index = proteins)
pca = pd.DataFrame(pca_space[:, 0:2], index = proteins)
families = [dic[protein] for protein in proteins]
umap['Family'] = families
pca['Family'] = families

# Disregard AA0
umap = umap[umap['Family'] != 'AA0']
pca = pca[pca['Family'] != 'AA0']

# Translate family to color
family_color = {
    'AA9' : '#1f77b4',
    'AA10' : '#ff7f0e',
    'AA11' : '#2ca02c',
    'AA13' : '#d62728',
    'AA14' : '#9467bd',
    'AA15' : '#8c564b',
    'AA16' : '#e377c2',
    'AA17' : '#bcbd22'
}

# Every remaining family needs a color; fail loudly rather than plot an unknown family
unknown_families = sorted(set(umap['Family']) - set(family_color))
if unknown_families:
    raise KeyError(f'No color defined for families: {unknown_families}')

# One color column per family: the family's color for its members, 'grey' for all other points.
# Built directly with np.where instead of get_dummies(...).replace(1, color): since pandas 2.0
# get_dummies returns boolean columns, which replace(1, ...)/replace(0, ...) may not match,
# leaving True/False where colors were expected.
dummies = pd.DataFrame(
    {family: np.where(umap['Family'] == family, color, 'grey')
     for family, color in family_color.items()},
    index = umap.index,
)

# Add all colors: each point in its own family's color
dummies['all'] = umap['Family'].map(family_color)

Interactive Plotly figure: embeddings, scree plot and shapemer lookup on selection

In [19]:
# Scatter plots: both embeddings start with every point grey; the color dropdown
# recolors them. Hover labels are the protein identifiers from each frame's index.
scatter_umap = go.Scatter(
    x = umap[0],
    y = umap[1],
    name = "Default UMAP",
    mode = 'markers',
    visible = True,
    hovertemplate = umap.index.to_list(),
    # size the color list from this trace's own data, not from an unrelated frame
    marker = dict(color = ['grey']*len(umap))
)

scatter_pca = go.Scatter(
    x = pca[0],
    y = pca[1],
    name = "Default PCA",
    mode = 'markers',
    visible = False,
    hovertemplate = pca.index.to_list(),
    marker = dict(color = ['grey']*len(pca))
)

# Scree plot: cumulative explained variance ratio per principal component
scree_plot = go.Scatter(
    x = PC_values,
    y = PC_cumulative,
    name = "Cumulative explained variance",
    visible = False
)

# Assemble the interactive figure: UMAP visible initially, PCA and scree hidden
fig = go.FigureWidget([scatter_umap, scatter_pca, scree_plot])

button_layer_1_height = 1.11

# Color dropdown: "None" greys every point, "All" colors by family, and each family
# entry shows that family's color with all other points grey (columns of `dummies`).
colour_buttons = [
    dict(label=label, args=[{"marker": {"color": colours}}], method="restyle")
    for label, colours in (
        [("None", ['grey'] * len(dummies)), ("All", dummies['all'])]
        + [(family, dummies[family])
           for family in ('AA9', 'AA10', 'AA11', 'AA13', 'AA14', 'AA15', 'AA16', 'AA17')]
    )
]

# View dropdown: choose which of the three traces is visible and retitle the figure
view_buttons = [
    dict(label=label, method="update", args=[{"visible": visibility}, {"title": title}])
    for label, visibility, title in (
        ("UMAP", [True, False, False], "UMAP"),
        ("PCA", [False, True, False], "PCA"),
        ("Scree", [False, False, True], "Scree plot"),
    )
]


def _dropdown(buttons, x):
    """Build a drop-down menu placed at horizontal position ``x`` on the shared button row."""
    return dict(
        buttons=buttons,
        direction="down",
        pad={"r": 10, "t": 10},
        showactive=True,
        x=x,
        xanchor="left",
        y=button_layer_1_height,
        yanchor="top",
    )


fig.update_layout(updatemenus=[_dropdown(colour_buttons, 0.20), _dropdown(view_buttons, 0)])

# Fixed 600x600 canvas
fig.update_layout(autosize=False, width=600, height=600)

# Create our callback function
def selection_fn(trace, points, selector):
    """Report the dominant shapemer among the points selected on a scatter trace.

    Registered via ``on_selection`` on the UMAP and PCA traces. For a non-empty
    selection it prints the selected point indices, then the shapemer holding the
    single largest normalized frequency in any selected protein, then a
    ``{protein: residues}`` dict from ``shapemer_class.map_shapemer_to_residues``
    restricted to the selected proteins present in that mapping.

    Args:
        trace: trace on which the selection happened (unused).
        points: plotly selection points; ``points.point_inds`` are positions within
            the trace's plotted data.
        selector: selector object (unused).
    """
    point_indices = points.point_inds
    if not point_indices:
        return
    print(point_indices)

    # Both scatter traces plot the AA0-filtered `umap`/`pca` frames (same filter, same
    # row order), so point indices are positions in that filtered order. Translate them
    # to protein names rather than indexing `normalized_matrix` positionally, which still
    # contains the AA0 rows and would select the wrong proteins.
    selected_names = umap.index[point_indices]
    # `matrix` already holds normalized_matrix labelled by protein and shapemer
    selected = matrix.loc[selected_names]

    # "Most frequent" = the largest single normalized value among the selected rows
    values = selected.to_numpy()
    _, column = np.unravel_index(np.argmax(values), values.shape)
    most_frequent_shapemer = selected.columns[column]
    print(most_frequent_shapemer)

    # Map shapemer to residues; the mapping is keyed by the original '.pdb' protein keys
    name_residue = dict(shapemer_class.map_shapemer_to_residues(most_frequent_shapemer))
    selectedname_residue = {
        name: name_residue[name + '.pdb']
        for name in set(selected_names)
        if name + '.pdb' in name_residue
    }
    print(selectedname_residue)

# Attach the selection callback to the two scatter traces (UMAP, PCA), not the scree plot
for scatter_trace in fig.data[:2]:
    scatter_trace.on_selection(selection_fn)

fig

FigureWidget({
    'data': [{'hovertemplate': [A0A1B3XTW0, A0A6S6VTC2, A0A1H0DH62, ...,
                                A0A084VRN9, A0A6G7P3W3, K1PFV1],
              'marker': {'color': [grey, grey, grey, ..., grey, grey, grey]},
              'mode': 'markers',
              'name': 'Default UMAP',
              'type': 'scatter',
              'uid': 'fd437fe8-727e-4155-a5c6-a0ff3c52e99b',
              'visible': True,
              'x': array([ -1.7209421 ,   5.0150385 ,   4.5605326 , ..., -10.248204  ,
                           -0.94421047, -10.1088915 ], dtype=float32),
              'y': array([ -3.8707855,   4.663015 , -11.406788 , ...,  18.201656 ,  16.358072 ,
                           17.909603 ], dtype=float32)},
             {'hovertemplate': [A0A1B3XTW0, A0A6S6VTC2, A0A1H0DH62, ...,
                                A0A084VRN9, A0A6G7P3W3, K1PFV1],
              'marker': {'color': [grey, grey, grey, ..., grey, grey, grey]},
              'mode': 'markers',
             