In [16]:
import torch
import json


class Residue:
    """One residue, filled in from a single fixed-width PDB coordinate record."""

    def __init__(self, line: str) -> None:
        """Cut the residue name, number, chain ID and x/y/z out of ``line``.

        Args:
            line: One record of a PDB file; fields are read by character column.

        Raises:
            ValueError: If the number or a coordinate column is not numeric.
            IndexError: If ``line`` has no character at index 21 (chain ID).
        """

        def field(start: int, stop: int) -> str:
            # Fixed-width columns are space padded; drop the padding.
            return line[start:stop].strip()

        self.name = field(17, 20)
        self.num = int(field(22, 26))
        self.chainID = line[21].strip()
        # Three consecutive 8-character columns hold the coordinates.
        self.x = float(field(30, 38))
        self.y = float(field(38, 46))
        self.z = float(field(46, 54))


class PDBStructure:
    """Residue-level graph built from the C-alpha atoms of a PDB file.

    Every ``ATOM`` record whose atom name is ``CA`` becomes one node. Nodes
    live in ``self.residues``, a dict keyed by residue number that keeps the
    order in which numbers first appear in the file.

    Note: the key is the residue number alone, so a later ``CA`` record with
    the same number replaces the earlier one.
    """

    def __init__(self, filename: str, node_feats: str = "label") -> None:
        """Parse ``filename`` and collect its C-alpha residues.

        Args:
            filename: Path of the PDB file to read.
            node_feats: Currently unused; kept so existing callers still work.
        """
        self.residues = {}
        self.parse_file(filename)

    def parse_file(self, filename: str) -> None:
        """Read ``filename`` and store one ``Residue`` per C-alpha ATOM record.

        Args:
            filename: Path of the PDB file to read.
        """
        # Context manager so the handle is closed even if a record fails to parse.
        with open(filename, "r") as handle:
            for line in handle:
                if line.startswith("ATOM") and line[12:16].strip() == "CA":
                    res = Residue(line)
                    self.residues[res.num] = res

    def get_coords(self) -> torch.Tensor:
        """Return the stored residue coordinates as an ``(N, 3)`` tensor.

        Rows follow the iteration order of ``self.residues``. With no residues
        the result has shape ``(0, 3)`` rather than ``(0,)``.
        """
        coords = [[res.x, res.y, res.z] for res in self.residues.values()]
        # Explicit reshape keeps the second dimension when the list is empty.
        return torch.tensor(coords).reshape(len(coords), 3)

    def get_edges(self, threshold: float = 5) -> torch.Tensor:
        """Return index pairs of distinct residues closer than ``threshold``.

        Args:
            threshold: Distance cutoff (strict ``<``), in coordinate units.

        Returns:
            A ``(2, E)`` integer tensor; row 0 holds start indices and row 1
            end indices into ``get_coords()``. Each close pair appears in both
            directions. ``E`` is 0 when there are fewer than two residues.
        """
        coords = self.get_coords()
        count = coords.shape[0]
        if count < 2:
            # No pair of distinct residues exists; skip cdist on degenerate input.
            return torch.empty((2, 0), dtype=torch.long)
        dist = torch.cdist(coords, coords)
        # Drop the diagonal so a residue is never linked to itself.
        close = (dist < threshold) & ~torch.eye(count, dtype=torch.bool)
        return close.nonzero().t()

    def store_graph(self, output_file: str, threshold: float = 4) -> None:
        """Write the residue graph to ``output_file`` as JSON.

        The document has a ``"nodes"`` list (``name``, ``x``, ``y``, ``z``,
        ``activity``) and an ``"edges"`` dict with parallel ``"start"`` and
        ``"end"`` index lists.

        Args:
            output_file: Path of the JSON file to create or overwrite.
            threshold: Distance cutoff passed to ``get_edges``. Note that the
                default here is 4, while ``get_edges`` itself defaults to 5.
        """
        edges = self.get_edges(threshold)
        positions = self.get_coords().tolist()
        nodes = [
            {
                "name": f"{i}_{res.name}",
                "x": x,
                "y": y,
                "z": z,
                # The first ten nodes are flagged active, all others inactive.
                "activity": 1 if i < 10 else 0,
            }
            for i, (res, (x, y, z)) in enumerate(
                zip(self.residues.values(), positions)
            )
        ]
        output = {
            "nodes": nodes,
            "edges": {
                "start": edges[0].tolist(),
                "end": edges[1].tolist(),
            },
        }
        # Close the file deterministically so the JSON is fully flushed to disk.
        with open(output_file, "w") as handle:
            json.dump(output, handle)

In [17]:
# Build the C-alpha graph for 3nir.pdb and save it (default cutoff of
# store_graph) to tmp.json, the file that visualize() reads back.
prot = PDBStructure("3nir.pdb")
prot.store_graph("tmp.json")

In [20]:
import plotly.graph_objs as go
from plotly.graph_objs import Data, Line


def _edge_line_coords(nodes, starts, ends):
    """Return x/y/z lists that draw every undirected edge as its own segment."""
    xs, ys, zs = [], [], []
    for s, t in zip(starts, ends):
        # Only the s < t direction is drawn, so a pair listed both ways
        # produces a single segment.
        if s < t:
            a, b = nodes[s], nodes[t]
            # The trailing None breaks the line; without it plotly would also
            # join the end of this edge to the start of the next one.
            xs += [a["x"], b["x"], None]
            ys += [a["y"], b["y"], None]
            zs += [a["z"], b["z"], None]
    return xs, ys, zs


def visualize(graph_file: str = "tmp.json") -> None:
    """Show an interactive 3D plot of a stored residue graph.

    Nodes are drawn as markers coloured by their ``activity`` value
    (``rgb(255,0,0)`` for 1, ``rgb(0,0,0)`` for 0) with the node name as hover
    text; edges are drawn as black line segments.

    Args:
        graph_file: Path of the JSON graph to load. It must contain a
            ``"nodes"`` list (``name``, ``x``, ``y``, ``z``, ``activity``) and
            an ``"edges"`` dict with parallel ``"start"``/``"end"`` index lists.
    """
    with open(graph_file, "r") as handle:
        mol = json.load(handle)
    nodes = mol['nodes']

    colors = [255 * n['activity'] for n in nodes]
    acids = go.Scatter3d(
        x=[n['x'] for n in nodes],
        y=[n['y'] for n in nodes],
        z=[n['z'] for n in nodes],
        mode='markers',
        name='actors',
        marker={
            'symbol': 'circle',
            'size': 20,
            'color': [f'rgb({c},0,0)' for c in colors],
        },
        text=[f"{n['name']}" for n in nodes],
        hoverinfo='text',
    )

    X, Y, Z = _edge_line_coords(nodes, mol['edges']['start'], mol['edges']['end'])
    edges = go.Scatter3d(
        x=X,
        y=Y,
        z=Z,
        mode='lines',
        line={
            'color': 'rgb(0,0,0)',
            'width': 5,
        },
        hoverinfo='none',
    )

    # Same settings for all three axes: hide every axis decoration.
    axis = {
        'showbackground': False,
        'zeroline': False,
        'showgrid': False,
        'showticklabels': False,
        'showspikes': False,
        'title': '',
    }
    layout = go.Layout(
        title="Protein",
        width=1000,
        height=1000,
        showlegend=False,
        scene={
            'xaxis': axis,
            'yaxis': axis,
            'zaxis': axis,
        },
        margin={
            't': 100,
        },
        hovermode='closest',
        annotations=[],
    )
    fig = go.Figure(data=[edges, acids], layout=layout)
    # Nothing is returned: a returned Figure as the cell's last expression
    # would be rendered a second time by the notebook.
    fig.show()
# Render the graph that was written to tmp.json by the earlier cell.
visualize()