In [1]:
!pip install -q git+https://github.com/PeptoneLtd/nerfax.git

[?25l[K     |▏                               | 10 kB 20.8 MB/s eta 0:00:01[K     |▎                               | 20 kB 23.3 MB/s eta 0:00:01[K     |▌                               | 30 kB 28.7 MB/s eta 0:00:01[K     |▋                               | 40 kB 16.8 MB/s eta 0:00:01[K     |▉                               | 51 kB 6.3 MB/s eta 0:00:01[K     |█                               | 61 kB 7.4 MB/s eta 0:00:01[K     |█                               | 71 kB 7.2 MB/s eta 0:00:01[K     |█▎                              | 81 kB 7.6 MB/s eta 0:00:01[K     |█▍                              | 92 kB 8.4 MB/s eta 0:00:01[K     |█▋                              | 102 kB 9.2 MB/s eta 0:00:01[K     |█▊                              | 112 kB 9.2 MB/s eta 0:00:01[K     |█▉                              | 122 kB 9.2 MB/s eta 0:00:01[K     |██                              | 133 kB 9.2 MB/s eta 0:00:01[K     |██▏                             | 143 kB 9.2 MB/s eta 0:00:01[K 

In [2]:
import nerfax
import mdtraj as md

path = 'model.pdb'
!wget https://alphafold.ebi.ac.uk/files/AF-Q93WI9-F1-model_v4.pdb -q -O {path}

# Load pdb file with mdtraj into scnet internal coordinate format
internal_coords = nerfax.parser.load_pdb(path)
# Reconstruct scnet cartesian (L,14,3) from internal coordinate format
reconstructed_cartesian_scnet_coords = nerfax.plugin.protein_fold(**internal_coords)

# Copy the resulting scnet coords back into 'normal' mdtraj format
t = nerfax.parser.load_traj(path)
_, _restrict_to_scnet_atoms, _scnet_to_list = nerfax.parser.get_scnet_loader_fns(t)
t = _restrict_to_scnet_atoms(t)
t.xyz = _scnet_to_list(reconstructed_cartesian_scnet_coords)[None]

# Check RMSD to original positions
t_pre = _restrict_to_scnet_atoms(nerfax.parser.load_traj(path))
print(f'RMSD between reconstructed and original : {md.rmsd(t,t_pre)[0]:.1g}')



RMSD between reconstructed and original : 0.0006


If the reconstruction is to be done on the same protein many times over, we can compile the reconstruction function, specialised to that exact protein.

In [3]:
from jax import jit, block_until_ready
fold = jit(nerfax.plugin.get_jax_protein_fold(internal_coords))
inputs = {k:internal_coords[k] for k in ['angles_mask', 'bond_mask']}
# compile
_ = fold(**inputs)
# time
%time output = block_until_ready(fold(**inputs))

CPU times: user 558 µs, sys: 0 ns, total: 558 µs
Wall time: 582 µs


By default the fold uses an associative scan; however, this is not optimal when running on CPU. We can switch to the sequential version instead, which should be faster in this benchmark.

In [4]:
from functools import partial
fold = jit(nerfax.plugin.get_jax_protein_fold(internal_coords, 
                  reconstruct_fn=partial(nerfax.reconstruct.reconstruct_from_internal_coordinates, mode='sequential')
          ))
# compile
_ = fold(**inputs)
# time
%time output = block_until_ready(fold(**inputs))

CPU times: user 413 µs, sys: 0 ns, total: 413 µs
Wall time: 423 µs
