In [32]:
import plotly.graph_objects as go

def plotly_edges(graph, start_l, end_l):
    """Build Scatter3d keyword arguments drawing one edge type as line segments.

    Args:
        graph: HeteroData-like object. ``graph[node_type].pos`` must be a
            tensor/array of node coordinates (at least 3 columns; columns 0-2
            are used as x, y, z), and ``graph[start_l, end_l].edge_index`` a
            2-row integer tensor/array of (source, target) node indices.
        start_l: Node type of the edge sources (indexes ``edge_index[0]``).
        end_l: Node type of the edge targets (indexes ``edge_index[1]``).

    Returns:
        Dict with keys ``x``, ``y``, ``z`` (plain Python floats, each edge
        contributing ``[start, end, None]`` so plotly breaks the line between
        edges), ``mode='lines'`` and ``name='<start_l>-<end_l>'``; intended to
        be splatted into ``go.Scatter3d(**...)``.
    """
    start_pos = graph[start_l].pos
    end_pos = graph[end_l].pos
    edges = graph[start_l, end_l].edge_index
    # One vectorised gather per endpoint, then a single conversion to nested
    # Python lists: avoids creating a 0-d tensor per coordinate and hands
    # plotly plain floats it can serialise directly.
    src_xyz = start_pos[edges[0]].tolist()
    dst_xyz = end_pos[edges[1]].tolist()
    xe, ye, ze = [], [], []
    for s_pos, e_pos in zip(src_xyz, dst_xyz):
        # The trailing None separates consecutive segments in a 'lines' trace.
        xe += [s_pos[0], e_pos[0], None]
        ye += [s_pos[1], e_pos[1], None]
        ze += [s_pos[2], e_pos[2], None]
    return {'x': xe, 'y': ye, 'z': ze, 'mode': 'lines', 'name': f'{start_l}-{end_l}'}

def plotly_nodes(graph,l):
    """Build Scatter3d keyword arguments drawing all nodes of one type as markers.

    Args:
        graph: HeteroData-like object; ``graph[l].pos`` holds one row of
            coordinates per node (columns 0-2 are used as x, y, z).
        l: Node type to draw; also used as the trace name.

    Returns:
        Dict with keys ``x``, ``y``, ``z`` (lists of plain Python numbers),
        ``mode='markers'`` and ``name=l``; intended to be splatted into
        ``go.Scatter3d(**...)``.
    """
    poses = graph[l].pos
    # Tensors/ndarrays convert to nested Python lists in one call (no per-
    # element 0-d tensors); plain sequences of rows are copied row by row.
    rows = poses.tolist() if hasattr(poses, 'tolist') else [list(p) for p in poses]
    xn = [row[0] for row in rows]
    yn = [row[1] for row in rows]
    zn = [row[2] for row in rows]
    return {'x': xn, 'y': yn, 'z': zn, 'mode': 'markers', 'name': l}

def make_fig(graph, width=500, height=500):
    """Build an interactive 3-D plotly figure of a heterogeneous graph.

    Draws the 'ligand', 'receptor' and 'atom' node types as markers, and the
    (ligand, ligand), (atom, receptor), (atom, atom) and (ligand, atom) edge
    types as line segments. Receptor-receptor and ligand-receptor edges are
    not drawn.

    Args:
        graph: HeteroData-like object accepted by ``plotly_nodes`` /
            ``plotly_edges`` that contains the node and edge types above.
        width: Figure width in pixels (default 500, unchanged).
        height: Figure height in pixels (default 500, unchanged).

    Returns:
        A ``go.Figure``; call ``.show()`` on it to render.
    """
    traces = [
        go.Scatter3d(**plotly_edges(graph, 'ligand', 'ligand'), line=dict(color='red', width=5)),
        go.Scatter3d(**plotly_nodes(graph, 'ligand'), marker=dict(symbol='circle', size=6, color='blue')),
        go.Scatter3d(**plotly_nodes(graph, 'receptor'), marker=dict(symbol='circle', size=3, color='green')),
        go.Scatter3d(**plotly_edges(graph, 'atom', 'receptor'), line=dict(color='yellow', width=3)),
        go.Scatter3d(**plotly_nodes(graph, 'atom'), marker=dict(symbol='circle', size=3, color='orange')),
        go.Scatter3d(**plotly_edges(graph, 'atom', 'atom'), line=dict(color='pink', width=3)),
        # Use the CSS colour name 'lightblue'. 'light blue' (with a space) can get
        # past plotly's Python-side validation, which ignores spaces, but that
        # raw string is what is sent to the browser, and it is not a CSS name.
        go.Scatter3d(**plotly_edges(graph, 'ligand', 'atom'), line=dict(color='lightblue', width=3)),
    ]
    layout = go.Layout(width=width, height=height)
    return go.Figure(data=traces, layout=layout)

In [33]:
from datasets.fsmol_dock import FsDock


# Paths passed positionally to the FsDock loader.
FSDOCK_ROOT = 'data/smol'
FSDOCK_TASKS_CSV = 'data/tasks_smol.csv'
ds = FsDock(FSDOCK_ROOT, FSDOCK_TASKS_CSV)


[2024-Dec-23 19:47:36 IST] [fsmol_dock.py:68] INFO - loading data for FS-Dock
  0%|          | 0/5 [00:00<?, ?it/s]

100%|██████████| 5/5 [00:00<00:00,  7.19it/s]


In [45]:
from datasets.process_mols import hide_sidechains


# Pick one graph: item 3 of the dataset, entry 20 of its 'graphs' list.
ITEM_IDX, GRAPH_IDX = 3, 20
item_graphs = ds.get(ITEM_IDX)['graphs']
graph = item_graphs[GRAPH_IDX]


In [47]:
# Interactive 3-D view of the selected graph.
graph_fig = make_fig(graph)
graph_fig.show()

In [48]:
# NOTE(review): despite its name, `sub_graph` is not derived from `graph`; it is
# entry 30 of the same ds.get(3)['graphs'] list. Disabled alternative that
# (presumably) plots `graph` with side chains hidden instead:
#     sub_graph = hide_sidechains(graph)
other_graphs = ds.get(3)['graphs']
sub_graph = other_graphs[30]
make_fig(sub_graph).show()

In [37]:
from rdkit import Chem
# NOTE(review): this cell executed as In[37], before In[45] assigned `graph`,
# so its saved SMILES may come from an earlier `graph`; re-run to refresh.
graph_mol = graph.mol
Chem.MolToSmiles(graph_mol)


'CN1CCN(CCCn2c3ccc(O)cc3c3c4c(c(-c5ccccc5)cc32)C(=O)NC4=O)CC1'

In [38]:
# Display the HeteroData summary (metadata plus node/edge stores and shapes).
# NOTE(review): executed as In[38], before In[45] assigned `graph`; the saved
# output below may not correspond to ds.get(3)['graphs'][20] -- re-run to refresh.
graph

HeteroData(
  name='CHEMBL1000361_40',
  mol=<rdkit.Chem.rdchem.Mol object at 0x7f6a903ea2c0>,
  core=<rdkit.Chem.rdchem.Mol object at 0x7f6a903ea810>,
  core_smiles='[1*]N1CCN(CCCn2c3ccc([2*])cc3c3c4c(c(-c5ccccc5)cc32)C(=O)NC4=O)CC1',
  side_chains=<rdkit.Chem.rdchem.Mol object at 0x7f6a903ea860>,
  side_chains_smiles='[1*]C.[2*]O',
  core_indices=[33],
  side_chain_indices=[2],
  activity_type='B',
  label=1,
  ligand={
    x=[35, 16],
    pos=[35, 3]
  },
  receptor={
    x=[247, 1281],
    pos=[247, 3]
  },
  atom={
    x=[84, 4],
    pos=[84, 3]
  },
  (ligand, lig_bond, ligand)={
    edge_index=[2, 80],
    edge_attr=[80, 4]
  },
  (receptor, to, receptor)={ edge_index=[2, 44514] },
  (atom, to, atom)={ edge_index=[2, 780] },
  (atom, to, receptor)={ edge_index=[2, 84] },
  (ligand, to, receptor)={ edge_index=[2, 5329] },
  (ligand, to, atom)={ edge_index=[2, 350] }
)