In [1]:
# utils
import pandas as pd
import plotly.express as px
import numpy as np
import mdtraj as md

# smp_nerf (and helpers)
from functools import partial 
import jax
from jax import numpy as jnp, vmap, jit, pmap
from nerfax.plugin import convert_natural_to_scnet, convert_scnet_to_natural, get_jax_protein_fold
from nerfax.reconstruct import reconstruct_from_internal_coordinates
from nerfax.utils import get_align_rigid_bodies_fn
from nerfax.parser import load_pdb, load_to_sc_coord_format, get_scnet_loader_fns, load_traj

In [8]:
# Unpack the example archive into ../data (-x extract, -v list each file, -C target directory).
!tar -xvf ../data/biomolecular_condensate.tar.gz -C ../data

biomolecular_condenstate.pdb


In [10]:
# PDB file read by both loaders below.
path = '../data/biomolecular_condensate.pdb'

# The same file is read twice: once as a trajectory object (used later for the
# reference positions and frame/residue counts) and once as the `scaffolds`
# pytree that drives the reconstruction.
t = load_traj(path)
scaffolds = load_pdb(path, first_frame_only=False)

reconstruct_fn = partial(reconstruct_from_internal_coordinates, mode='associative')

# Build the fold function from index 0 of every leaf of `scaffolds`
# (presumably the first frame, since first_frame_only=False -- TODO confirm).
# `jax.tree_util.tree_map` is used instead of the deprecated top-level
# `jax.tree_map` alias, which newer JAX releases no longer provide.
_fold = get_jax_protein_fold(
    jax.tree_util.tree_map(lambda x: x[0], scaffolds),
    reconstruct_fn=reconstruct_fn,
)
def fold(scaffolds, ref):
    """Rebuild coordinates from internal coordinates and align them onto `ref`.

    Args:
        scaffolds: mapping providing the 'angles_mask' and 'bond_mask' entries
            that are passed, in that order, to `_fold`.
        ref: (3,3) array with positions of (N,CA,C) for the first residue.

    Returns:
        The output of `_fold` after applying the alignment function built from
        its own `[0, :3]` slice and `ref`.
    """
    local_xyz = _fold(scaffolds['angles_mask'], scaffolds['bond_mask'])
    # Presumably the (N,CA,C) atoms of the first residue, matching `ref` -- TODO confirm layout.
    first_backbone = local_xyz[0, :3]
    to_reference = get_align_rigid_bodies_fn(first_backbone, ref)
    return to_reference(local_xyz)

f = jit(vmap(fold))
scaffolds = load_pdb(path, first_frame_only=False)

inputs = (scaffolds, t.xyz[:,:3]*10)
inputs = jax.tree_map(jnp.array, inputs)

coords = jax.block_until_ready(f(*inputs)) # Compile once
timings = %timeit -n 10 -r 10 -q -o _ = jax.block_until_ready(f(*inputs))
logs = {'best':timings.best, 'average': timings.average, 'stdev': timings.stdev}
print(f"Timing: {logs['best']*1e3:.2f}ms for {t.n_frames} chains of {t.n_residues} residues reconstructed in global reference frame from internal coordinates")

3.57 ms ± 253 µs per loop (mean ± std. dev. of 10 runs, 10 loops each)
Timing: 3.18ms for 1000 chains of 163 residues reconstructed in global reference frame from internal coordinates


In [None]:
# Optional: uncomment to record a JAX profiler trace of one call into /tmp/tensorboard
# with jax.profiler.trace('/tmp/tensorboard'):
#     _ = jax.block_until_ready(f(*inputs))

In [None]:
def save(t, outpath, coords):
    """Write `coords` to `outpath` as a PDB file, using `t` as the template.

    Args:
        t: trajectory handed to `get_scnet_loader_fns` and then restricted by
            the helper it returns.
        outpath: destination path passed to `save_pdb`.
        coords: batched coordinates; each entry is converted with the
            `_scnet_to_list` helper before being stored as the trajectory's xyz.
    """
    # Only the last two of the three returned helpers are needed here.
    _, keep_scnet_atoms, to_atom_list = get_scnet_loader_fns(t)
    reduced_traj = keep_scnet_atoms(t)
    # NOTE(review): mdtraj expects xyz in nanometres -- confirm `coords` (or the
    # conversion helper) already provides that unit.
    flat_xyz = vmap(to_atom_list)(coords)
    reduced_traj.xyz = flat_xyz
    reduced_traj.save_pdb(outpath)
    
# Write the reconstructed coordinates and the originally loaded trajectory to
# /tmp as separate PDB files so they can be compared.
save(t, '/tmp/reconstructed.pdb', coords)
t.save_pdb('/tmp/ground_truth.pdb')

In [None]:
# Optional visualisation: overlay the reconstructed and ground-truth PDB files
# in a single NGL widget. nglview is imported here rather than at the top of the
# notebook so that it is only needed when this cell is run.
import nglview as nv
# Reload both structures from the PDB files at the /tmp paths below.
t_reconstructed = md.load('/tmp/reconstructed.pdb')
t_ground_truth = md.load('/tmp/ground_truth.pdb')


vw = nv.NGLWidget(height='700px')
# default=False: no default representation is requested; explicit ones are added below.
vw.add_trajectory(t_reconstructed, default=False, name='reconstructed')
vw.add_trajectory(t_ground_truth, default=False, name='ground truth')
selection='all'

# Presumably component indices follow the order the trajectories were added:
# 0 = reconstructed (blue), 1 = ground truth (red, semi-transparent).
vw.add_representation('licorice', selection, component=0, color='blue')
vw.add_representation('licorice', selection, component=1, color='red', opacity=0.5)
vw.center()
# Last expression of the cell, so its return value is what the notebook displays.
vw.display(gui=True)