In [1]:
import numpy as np 
import pandas as pd 
import os 
import sys 
import plotly.graph_objects as go 

import pyKVFinder

In [2]:
sys.path.append('/data/jlu/OR_learning/utils/')

import BindingCavity_functions as bc 
import color_function as cf 


In [167]:
import importlib 

importlib.reload(bc)

<module 'BindingCavity_functions' from '/data/jlu/OR_learning/utils/BindingCavity_functions.py'>

In [8]:
# Define paths of AF files 
AF2_PATH = '/data/jlu/AF_files/AF_tmaligned_pdb'
pdb_files = os.listdir(AF2_PATH)


In [None]:
# Simple workflow of using pyKVFinder and visualizing binding cavity
results = pyKVFinder.run_workflow(os.path.join(AF2_PATH, pdb_files[0]))

manual_color = cf.distinct_colors(np.unique(results.cavities))
manual_color[0] = '#D3D3D3' # Define protein space as grey 

bc.plot_3d_array(results.cavities, color_dict=manual_color)

In [None]:
"""
Conceptual pipeline for batch analysis. 

Utilizes pyKVFinder standard workflow 
- extract cavity grid 
- convert cavity grid to coordinates 
- super impose original pdb protein with cavity

"""


TEST_Olfr = ['Olfr224', 'Olfr330', 'Olfr1377']

# sulfur_Olfr = ['Olfr224', 'Olfr328', 'Olfr330', 'Olfr329'] # Sulfur responding Olfr

pdb_to_read = []
for _pdb in pdb_files: 
    if _pdb.split('_')[0] in TEST_Olfr: 
        pdb_to_read.append(_pdb)

# Running pyKVFinder standard workflow for cavity grid
bc_results = []
# bc_atomic = []

for _pdb in pdb_to_read: 
    bc_results.append(pyKVFinder.run_workflow(os.path.join(AF2_PATH, _pdb)))
    # bc_atomic.append(pyKVFinder.read_pdb(os.path.join(AF2_PATH, _pdb)))

# For testing the grid
vertices = []
for _atomic in bc_atomic: 
    vertices.append(pyKVFinder.get_vertices(_atomic))

# Extracting cav coordinates from cavity grid 
bc_cav_coords = []
for _results in bc_results: 
    bc_cav_coords.append(bc.grid2coords(_results))
    
# Read in original pdb to visualize 
pdb_coords = []
for _pdb in pdb_to_read: 
    coords, backbone, _ = bc.load_pdb_coordinates(os.path.join(AF2_PATH, _pdb))
    pdb_coords.append([coords, backbone])


# Figure plotting
fig = go.Figure()

color_map = cf.distinct_colors(list(range(3)))
for i, _cv in enumerate(bc_cav_coords): 
    for _cv_x in _cv[0]: # Iterate and plot individual cavities
        fig.add_trace(go.Scatter3d(
            x=_cv[0][_cv_x].T[0],
            y=_cv[0][_cv_x].T[1],
            z=_cv[0][_cv_x].T[2],
            mode='markers',
            marker=dict(
                size=5,
                color=color_map[i],  # Color for this value
                opacity=0.2
            ),
            name=f"{pdb_to_read[i].split('.')[0]} cav",  # Add the value to the legend
            legendgroup=f"cavity_{i}",  # Group legend entries for this value
            showlegend=_cv_x == list(_cv[0].keys())[0]  # Only show legend for the first trace in this group
        ))
        
for i, _pdb in enumerate(pdb_coords): # Iterate and plot individual cavities
        fig.add_trace(go.Scatter3d(
            y=_pdb[1].T[0],
            x=_pdb[1].T[1],
            z=_pdb[1].T[2],
            mode='markers',
            marker=dict(
                size=5,
                color=color_map[i],  # Color for this value
                opacity=0.3
            ),
            name=f"{pdb_to_read[i].split('.')[0]} prot"  # Add the value to the legend
        ))
    
fig.update_layout(
    scene=dict(
            xaxis=dict(visible=False, showbackground=False),
            yaxis=dict(visible=False, showbackground=False),
            zaxis=dict(visible=False, showbackground=False)
        ),
    margin=dict(r=10, l=10, b=10, t=10),
    legend_title="Values"
)

fig.show()
# fig.write_html('/data/jlu/OR_learning/output/bining_cavity/misc/TEST_cav_alignment.html')

In [82]:
arr = bc_results[0].cavities

In [83]:
x, y, z = np.indices(arr.shape)
values = arr.flatten()

# Filter out the -1 values (empty space)
mask = values == 0
prot_x = x.flatten()[mask]
prot_y = y.flatten()[mask]
prot_z = z.flatten()[mask]

# Add the trace to the figure
fig.add_trace(go.Scatter3d(
    x=prot_x,
    y=prot_y,
    z=prot_z,
    mode='markers',
    marker=dict(
        size=trace_size,
        color=color_map[val],  # Color for this value
        opacity=trace_opacity
    ),
    name=str(val)  # Add the value to the legend
))

# Filter for cavities 

In [85]:
values

array([-1, -1, -1, ..., -1, -1, -1], shape=(3815700,), dtype=int32)

In [None]:
color_map = cf.distinct_colors(list(range(len(aligned_results))))

fig = go.Figure()

fig.add_trace(go.Scatter3d(
    x = coords1.T[0], 
    y = coords1.T[1], 
    z = coords1.T[2], 
    mode='markers', 
    marker=dict(
        size=5, 
        color=color_map[0], 
        opacity=0.5
    )
))

for i, _result in enumerate(aligned_results):
    fig.add_trace(go.Scatter3d(
        x = _result[0].T[0], 
        y = _result[0].T[1], 
        z = _result[0].T[2], 
        mode='markers', 
        marker=dict(
            size=5, 
            color=color_map[i], 
            opacity=0.5
        )
    ))

# Add layout details
fig.update_layout(
    width=600, height=600,
    scene=dict(
        xaxis=dict( title='X', visible=False, showbackground=False),
        yaxis=dict( title='Y', visible=False, showbackground=False),
        zaxis=dict( title='Z', visible=False, showbackground=False)
    ),
    margin=dict(r=10, l=10, b=10, t=10),
    legend_title="Values"
)
fig.show()

In [None]:
color_map = cf.distinct_colors([0,1,2,3])

fig = go.Figure()

fig.add_trace(go.Scatter3d(
    x = coords1.T[0], 
    y = coords1.T[1], 
    z = coords1.T[2], 
    mode='markers', 
    marker=dict(
        size=5, 
        color=color_map[0], 
        opacity=0.8
    )
))
fig.add_trace(go.Scatter3d(
    x = coords2.T[0], 
    y = coords2.T[1], 
    z = coords2.T[2], 
    mode='markers', 
    marker=dict(
        size=5, 
        color=color_map[1],
        opacity=0.8
    )
))

coords3 = np.dot(coords2, aligned_result.u)
fig.add_trace(go.Scatter3d(
    x = coords3.T[0], 
    y = coords3.T[1], 
    z = coords3.T[2], 
    mode='markers', 
    marker=dict(
        size=5, 
        color=color_map[2],
        opacity=0.8
    )
))

fig.add_trace(go.Scatter3d(
    x = backbone4.T[0], 
    y = backbone4.T[1], 
    z = backbone4.T[2], 
    mode='markers', 
    marker=dict(
        size=5, 
        color=color_map[3],
        opacity=0.8
    )
))
# Add layout details
fig.update_layout(
    width=600, height=600,
    scene=dict(
        xaxis=dict( title='X', visible=False, showbackground=False),
        yaxis=dict( title='Y', visible=False, showbackground=False),
        zaxis=dict( title='Z', visible=False, showbackground=False)
    ),
    margin=dict(r=10, l=10, b=10, t=10),
    legend_title="Values"
)
fig.show()