In [65]:
import os

import numpy as np
import pandas as pd
import py3Dmol
import torch
from biopandas.pdb import PandasPdb

In [66]:
# Root directory of the predicted structures. Defaults to the original project location;
# set the PROTEIN_STRUCTURES_DIR environment variable to run this notebook elsewhere.
DATA_DIR = os.environ.get(
    "PROTEIN_STRUCTURES_DIR",
    "/lustre/isaac/proj/UTK0196/deep-surface-protein-data/protein_analysis/protein_structures",
)
# Kept as a plain str (not pathlib.Path) because it is handed to open() and to
# PandasPdb.read_pdb further down.
PATH = os.path.join(
    DATA_DIR,
    "SRR7066492_k141_1124833_3",
    "SRR7066492_k141_1124833_3",
    "relaxed_model_1_pred_0.pdb",
)

In [67]:
# TODO: replace this placeholder with the real per-residue gradients read from disk.
# The placeholder is drawn from a *seeded* generator so Restart & Run All reproduces
# the same coloring. N_RESIDUES must equal the residue count of the structure at PATH:
# colors are assigned per residue and then expanded to that residue's atoms.
N_RESIDUES = 177
GRADIENT_SEED = 0
GRADIENTS = np.random.default_rng(GRADIENT_SEED).uniform(low=-1.0, high=5.0, size=N_RESIDUES)

In [68]:
# The default 'spectrum' cartoon below colors by position along the chain, not by gradient value, so it cannot visualize gradients; kept for reference only.

# with open(PATH) as ifile:
#     system = "".join([x for x in ifile])

# view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js', width=800, height=400)
# view.addModelsAsFrames(system)
# view.setStyle({'model': -1}, {"cartoon": {'color': 'spectrum'}})

In [69]:
import numpy as np
import matplotlib as mpl

# Adapted from https://databasecamp.de/en/ml/integrated-gradients-nlp
def colorize(gradients, cmap='RdYlBu'):
    """Map attribution values to hex color strings on a diverging colormap.

    The color scale is symmetric around zero: it spans [-m, m], where m is the
    largest absolute value in ``gradients``. Zero therefore always lands on the
    colormap's midpoint, and +v / -v receive mirror-image colors.

    Args:
        gradients: 1-D sequence of numbers (one per residue in this notebook).
        cmap: name of a registered Matplotlib colormap, or a Colormap instance.

    Returns:
        list of '#rrggbb' strings, one per entry of ``gradients`` (empty list for
        empty input).
    """
    values = np.asarray(gradients, dtype=float)
    if values.size == 0:
        return []

    bound = np.abs(values).max()
    if bound == 0:
        # Normalize(vmin=0, vmax=0) maps every value to 0 (the low end of the map);
        # widen the range so all-zero input still lands on the neutral midpoint.
        bound = 1.0
    norm = mpl.colors.Normalize(vmin=-bound, vmax=bound)

    if isinstance(cmap, str):
        # mpl.cm.get_cmap is deprecated (Matplotlib 3.7) and removed in 3.9; use the
        # colormap registry when available, falling back for very old Matplotlib.
        registry = getattr(mpl, "colormaps", None)
        cmap = registry[cmap] if registry is not None else mpl.cm.get_cmap(cmap)

    # One vectorized lookup -> (N, 4) RGBA array; rgb2hex drops alpha by default.
    return [mpl.colors.rgb2hex(rgba) for rgba in cmap(norm(values))]

In [70]:
def get_colors(path, gradients):
    """Return one hex color per ATOM record of the PDB file at ``path``.

    Residue colors come from ``colorize(gradients)``. Every atom of a residue
    receives that residue's color, so ``gradients`` must hold exactly one value
    per residue, in the residue order of the file.

    Raises:
        ValueError: if ``len(gradients)`` differs from the number of residues.
            Previously this either raised IndexError (too few gradients) or
            silently gave the last residue the wrong color (too many).
    """
    ppdb = PandasPdb()
    structure = ppdb.read_pdb(path)
    # amino3to1() keeps one row per residue. Its index is treated here as the
    # df['ATOM'] row label of each residue's first atom, so consecutive
    # differences are per-residue atom counts. This assumes read_pdb's default
    # 0-based RangeIndex — TODO confirm if the frame is ever re-indexed.
    residue_starts = ppdb.amino3to1().index.to_numpy()
    n_atoms = len(structure.df['ATOM'])
    n_residues = len(residue_starts)

    if len(gradients) != n_residues:
        raise ValueError(
            f"expected {n_residues} gradients (one per residue), got {len(gradients)}"
        )
    if n_residues == 0:
        return []

    residue_colors = colorize(gradients)
    # Atom count of residue k = start of residue k+1 (or total atoms for the last) - start of k.
    atoms_per_residue = np.diff(np.append(residue_starts, n_atoms))
    return np.repeat(residue_colors, atoms_per_residue).tolist()

In [71]:
# Adapted from https://william-dawson.github.io/using-py3dmol.html
class Atom(dict):
    """One ATOM/HETATM record of a PDB file, parsed into a dict by fixed columns.

    Keys: "type" (record name), "idx" (serial number, kept as a string),
    "name" (atom name), "resname", "chain", "resid" (int), "x"/"y"/"z" (float),
    "sym" (element symbol).
    """

    def __init__(self, line):
        """Parse a single fixed-column PDB coordinate line."""
        super().__init__()
        self["type"] = line[0:6].strip()
        self["idx"] = line[6:11].strip()
        self["name"] = line[12:16].strip()
        self["resname"] = line[17:20].strip()
        # Chain ID is kept so multi-chain structures are not merged into one chain on output.
        self["chain"] = line[21:22].strip()
        self["resid"] = int(line[22:26])
        self["x"] = float(line[30:38])
        self["y"] = float(line[38:46])
        self["z"] = float(line[46:54])
        self["sym"] = line[76:78].strip()

    def __str__(self):
        """Format the atom as an 80-column PDB record terminated by a newline.

        Every field is written at its fixed column with fixed width: serial and
        residue number right-aligned, coordinates as %8.3f. Columns therefore never
        shift, even after coordinates are modified arithmetically (str(3.3000000000000003)
        would overflow the 8-character field). Columns this class does not track
        (altLoc, iCode, occupancy, B-factor, charge) are left blank.
        """
        return (
            f"{self['type']:<6.6}"                       # cols 1-6   record name
            f"{str(self['idx']):>5.5} "                  # cols 7-11  serial, col 12 blank
            f"{self['name']:<4.4} "                      # cols 13-16 atom name, col 17 altLoc
            f"{self['resname']:<3.3} "                   # cols 18-20 residue, col 21 blank
            f"{self.get('chain', ''):1.1}"               # col 22     chain ID
            f"{self['resid']:>4}    "                    # cols 23-26 resSeq, cols 27-30 blank
            f"{self['x']:8.3f}{self['y']:8.3f}{self['z']:8.3f}"  # cols 31-54
            f"{'':22}"                                   # cols 55-76 blank
            f"{self['sym']:>2.2}  \n"                    # cols 77-78 element, 79-80 blank
        )

class Molecule(list):
    """A list of Atom records parsed from the ATOM/HETATM lines of a PDB file.

    ``str(molecule)`` writes the atoms back out as PDB records, one per line.
    """

    def __init__(self, file):
        """Collect an Atom for every ATOM/HETATM line in ``file`` (any iterable of text lines)."""
        super().__init__()
        for line in file:
            # Match the record name at the start of the line only. A substring test
            # ("ATOM" in line) also matches lines such as REMARK text that merely
            # mention ATOM, which then fail to parse as coordinates.
            if line.startswith(("ATOM", "HETATM")):
                self.append(Atom(line))

    def __str__(self):
        """Return all atoms as concatenated PDB records (each ends with a newline)."""
        return "".join(str(atom) for atom in self)


In [72]:
# Parse every coordinate record of the structure file into a Molecule (list of Atom dicts).
with open(PATH, "r") as pdb_handle:
    mol = Molecule(pdb_handle)

In [73]:
# One hex color per ATOM record, derived from the per-residue gradients.
colors = get_colors(path=PATH, gradients=GRADIENTS)

  cmap = mpl.cm.get_cmap(cmap)


In [74]:
# Render the structure as a cartoon colored atom-by-atom from `colors`.
# Atoms are selected by their own serial number (Atom stores it as "idx" and writes it
# back into the serial column). The previous i+1 assumed serials run contiguously from 1,
# which breaks e.g. when TER records consume serial numbers.
# NOTE(review): assumes 3Dmol's 'serial' selector matches that column — confirm.
view = py3Dmol.view(width=400, height=300)
view.addModelsAsFrames(str(mol))
for at, c in zip(mol, colors):
    default = {"cartoon": {'color': c}}
    # An atom may carry its own style under "pymol"; otherwise use the gradient color.
    view.setStyle({'model': -1, 'serial': int(at["idx"])}, at.get("pymol", default))

view.zoomTo()
view.show()