<a href="https://colab.research.google.com/github/PeptoneLtd/nerfax/blob/main/jax_foldcomp_example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install -q git+https://github.com/PeptoneLtd/nerfax.git foldcomp
!wget https://mmseqs.com/foldcomp/foldcomp-linux-x86_64.tar.gz && tar -xvf foldcomp-linux-x86_64.tar.gz && chmod +x foldcomp 

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/PeptoneLtd/nerfax.git
  Cloning https://github.com/PeptoneLtd/nerfax.git to /tmp/pip-req-build-njj489l6
  Running command git clone --filter=blob:none --quiet https://github.com/PeptoneLtd/nerfax.git /tmp/pip-req-build-njj489l6
  Resolved https://github.com/PeptoneLtd/nerfax.git to commit 5fb1a40b75a3d8825d1865078a424c7317fa992a
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting foldcomp
  Downloading foldcomp-0.0.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (266 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m266.5/266.5 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting mp_nerf==1.0.3 (from nerfax==1.0.0)
  Downloading mp_nerf-1.0.3-py3-none-any.whl (23 kB)
Collecting mdtraj (from nerfax==1.0.0)
  Downloading mdtraj-1.9.7.tar.gz (2.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0

In [2]:
import requests
from os.path import basename
import tarfile
def download(url, nmax=3, outfolder='.', file_tag='.pdb.gz'):
    response = requests.get(url, stream=True)
    paths = []
    with tarfile.open(fileobj=response.raw, mode='r|gz') as archive:
      while (len(paths)<nmax):
          tarinfo = archive.next()
          if tarinfo.isreg() & (file_tag in tarinfo.name):
              print(tarinfo.name)
              handle = archive.extractfile(tarinfo)
              data = handle.read()
              handle.close()

              outpath = f'{outfolder}/{basename(tarinfo.name)}'
              open(outpath,'wb').write(data)
              paths.append(outpath)
    return paths

# Download a small part of example data
url = 'https://foldcomp.steineggerlab.workers.dev/afdb_swissprot_foldcompdb.tar.gz'
!mkdir -p /content/example_data/
paths = download(url, nmax=5, outfolder='/content/example_data/', file_tag ='afdb_swissprot.')
# !/content/foldcomp compress /content/example_data/

./afdb_swissprot/afdb_swissprot.95
./afdb_swissprot/afdb_swissprot.191
./afdb_swissprot/afdb_swissprot.121
./afdb_swissprot/afdb_swissprot.179
./afdb_swissprot/afdb_swissprot.148


In [3]:
from functools import partial
from glob import glob
from time import time
import numpy as np
import jax
from jax import jit, numpy as jnp
import foldcomp

from nerfax.foldcomp_utils import decompress, reconstruct, load_data
paths = glob('/content/example_data/afdb_swissprot.*')

compute_rmsd = lambda a,b: ((a-b)**2).sum(-1).mean()**0.5

inputs = []
for path in paths:
    (tag, nResidue, nAtom, idxResidue, idxAtom, nAnchor, chain, firstResidue, lastResidue, strTitle), \
        (anchorIndices, anchorCoords), (hasOXT, oxtCoords), aas, (angles_torsions_discretizers, angles_torsions_body, angles_torsions_end), \
        sideChainAnglesDiscretized, (tempFactorsDisc_min, tempFactorsDisc_cont_f, tempFactorsDisc) = load_data(path)
    inputs.append((angles_torsions_discretizers, angles_torsions_body, angles_torsions_end, anchorCoords, aas, sideChainAnglesDiscretized, hasOXT, oxtCoords))

for i in np.random.choice(np.arange(len(paths)), 4):
    path = paths[i]
    (angles_torsions_discretizers, angles_torsions_body, angles_torsions_end, anchorCoords, aas, sideChainAnglesDiscretized, hasOXT, oxtCoords) = inputs[i]
        
    foldcomp_coords = np.array(foldcomp.get_data(open(path,'rb').read())['coordinates'])
    
    '''
    The aas and hasOXT have to be known for static shapes. 
    So here we fold them in, and compile time eval tags in the codebase does the rest
    '''
    @partial(jit, backend='cpu')
    def fold(angles_torsions_discretizers, angles_torsions_body, angles_torsions_end, anchorCoords, sideChainAnglesDiscretized, oxtCoords):
        return reconstruct(angles_torsions_discretizers, angles_torsions_body, angles_torsions_end, anchorCoords, aas, sideChainAnglesDiscretized, hasOXT, oxtCoords)
    
    start_uncompiled = time()
    coords = jax.block_until_ready(fold(angles_torsions_discretizers, angles_torsions_body, angles_torsions_end, anchorCoords, sideChainAnglesDiscretized, oxtCoords))
    end_uncompiled = time()
    
    rmsd = compute_rmsd(foldcomp_coords, coords)
    
    start_compiled = time()
    coords = jax.block_until_ready(fold(angles_torsions_discretizers, angles_torsions_body, angles_torsions_end, anchorCoords, sideChainAnglesDiscretized, oxtCoords))
    end_compiled = time()
    
    print(f'{aas.shape[0]} residues, {end_uncompiled-start_uncompiled:.2f} seconds with compilation, {rmsd:.3f} Angstrom RMSD, {(end_compiled-start_compiled)*1000:.2f} ms when compiled (with some compile time eval)')



138 residues, 14.45 seconds with compilation, 0.036 Angstrom RMSD, 0.87 ms when compiled (with some compile time eval)
360 residues, 5.98 seconds with compilation, 0.030 Angstrom RMSD, 0.89 ms when compiled (with some compile time eval)
192 residues, 6.00 seconds with compilation, 0.038 Angstrom RMSD, 0.61 ms when compiled (with some compile time eval)
138 residues, 5.32 seconds with compilation, 0.036 Angstrom RMSD, 0.31 ms when compiled (with some compile time eval)


  1 - input shapes differ, it would be much faster if it could be compiled with dynamic shape. Normally this can be dealt with by pre-compiling over a range of shapes and padding up, but the (uncompiled) padding operation then becomes the bottleneck as these are such small computations


In [4]:
display(jax.tree_map(jnp.shape, inputs))

[((2, 6), (14, 24, 6), (24, 6), (16, 3, 3), (360,), (1759,), (), (3,)),
 ((2, 6), (7, 24, 6), (24, 6), (9, 3, 3), (192,), (916,), (), (3,)),
 ((2, 6), (12, 23, 6), (32, 6), (14, 3, 3), (308,), (1525,), (), (3,)),
 ((2, 6), (5, 23, 6), (23, 6), (7, 3, 3), (138,), (643,), (), (3,)),
 ((2, 6), (13, 23, 6), (34, 6), (15, 3, 3), (333,), (1562,), (), (3,))]

 2 - It's slightly worse than just dynamic shape compilation, as some shapes are computed from values in the input arrays, rather than shapes. Currently to deal with this I have used ensure_compile_time_eval.
    # taken out so booleans known
    with jax.ensure_compile_time_eval():
        atom_mask = jnp.array(AA_REF_ATOM_MASK).at[aas].get()
    ...
    for i in range(11):
        with jax.ensure_compile_time_eval():
            # taken out so booleans known, we have no way of inferring shape here without concrete values
            level_mask = atom_mask[:, i]
