### Ophiuchus Encode

Encode a custom set of PDB entries through a pre-trained Ophiuchus model.

This notebook should be used solely to generate embeddings and sliced datums (and, optionally, t-SNE values) for a specific list of PDB ids. It also applies the selected transforms to each datum.

Checkpoint around 4 o'clock on Saturday the 18th, 2024

In [2]:
# Reload imported modules automatically before running code, so edits to the
# project's .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [16]:
import os
import sys

# Expose only GPU 3 (numbered in PCI bus order) and let XLA allocate GPU
# memory on demand instead of preallocating it up front.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": '3',
    "XLA_PYTHON_CLIENT_PREALLOCATE": "false",
    "XLA_PYTHON_CLIENT_ALLOCATOR": "platform",
})

# Fix up the path
module_path = os.path.abspath('..')
module_path1 = os.path.abspath('../../../ophiuchus')
module_path2 = os.path.abspath('..')  # same location as module_path
# Append each location once, in the same order as before.
for extra_path in (module_path, module_path1, module_path2):
    if extra_path not in sys.path:
        sys.path.append(extra_path)

from collections import defaultdict
from functools import reduce, partial
from typing import List
from tqdm import tqdm
import numpy as np 
import jax
import haiku as hk
from sklearn.manifold import TSNE


from kheiron.pipeline.registry import Registry
# from absl import app, flags

from moleculib.protein.datum import ProteinDatum
from moleculib.assembly.datum import AssemblyDatum
from moleculib.protein.transform import (
    ProteinCrop,
    TokenizeSequenceBoundaries,
    ProteinPad,
    MaybeMirror,
    BackboneOnly,
    DescribeChemistry
)

from model.base.sequence_convolution import moving_window



In [17]:


# print(f"Saving embeddings to {path}")
# print(f"Do TSNE: {FLAGS.tsne}")
# print(f"Save HTML: {FLAGS.html}")

trained_model = "denim-energy-1008"

# Open the run read-only from the registry rooted at $ALLAN_REGISTRY and
# keep its stored configuration.
registry = Registry('ophiuchus', base_path=os.getenv("ALLAN_REGISTRY"))
platform = registry.get_platform(trained_model, read_only=True)
cfg = platform.cfg

# Autoencoder hyperparameters needed to rebuild the per-level slices.
autoencoder_cfg = cfg['trainer']['model']['autoencoder']
stride = autoencoder_cfg['stride']
kernel_size = autoencoder_cfg['kernel_size']
depth = len(autoencoder_cfg['layers'])


max_chain_len = 253  # crop size for denim energy
chain_len = max_chain_len

# Level 0: one [i:i+1] slice per position of the chain.
positions = np.arange(chain_len)
slices = np.stack([positions, positions + 1], axis=-1)

slices_per_level = {0: slices}
# Coarsen the slices the way the convolutions coarsen the sequence, but track
# the covered (start, end) range instead of feature values.
for i in range(depth - 1):
    # presumably one row of input indices per output position — TODO confirm
    # against model.base.sequence_convolution.moving_window
    windows = moving_window(np.arange(len(slices)), kernel_size, stride)
    slices_ = slices[windows]
    # A window covers from the smallest start to the largest end it contains.
    window_starts = slices_[:, :, 0].min(axis=-1)
    window_ends = slices_[:, :, 1].max(axis=-1)
    slices = np.stack([window_starts, window_ends], axis=-1)
    slices_per_level[i + 1] = slices

# Preprocessing steps, applied in list order by `transform` below.
# Crop and pad both use max_chain_len, so they target the same length.
protein_transform = [
    ProteinCrop(crop_size=max_chain_len),
    TokenizeSequenceBoundaries(),
    MaybeMirror(hand='left'),
    ProteinPad(pad_size=max_chain_len, random_position=False),
    BackboneOnly(filter=True),
    DescribeChemistry(),
]


def transform(datum):
    """Run `datum` through every step of `protein_transform`, in order.

    Each step's `.transform(...)` receives the result of the previous step;
    the result of the last step is returned.
    """
    result = datum
    for step in protein_transform:
        result = step.transform(result)
    return result

rng_seq = hk.PRNGSequence(42)
premodel = platform.instantiate_model()


def _forward(*args, **kwargs):
    # premodel() is called inside the transformed function
    # (presumably it builds the Haiku module — TODO confirm).
    return premodel()(*args, **kwargs)


forward_ = hk.transform(_forward)


def _apply_forward(params, rng, datum):
    """Apply the transformed model to one datum with explicit params and rng."""
    return forward_.apply(params, rng, datum)


_autoencoder = jax.jit(_apply_forward)

# presumably -1 selects the most recent checkpoint — TODO confirm
base_params = platform.get_params(-1)


def autoencoder(batch):
    """Run the jitted model on `batch` with the loaded parameters.

    Every call draws a fresh key from `rng_seq`.
    """
    step_rng = next(rng_seq)
    return _autoencoder(base_params, step_rng, batch)



In [44]:
# Prepare dataset
import pandas as pd

# This list focuses on pre-defined pairs of interest, as well as beta helices
beta_helix_and_friends = ['2jp7', '1prp', '3nxq', '1gca', '1pcl', '1xiq', '2pqe', '1kzq', '4mzu', '1wpc', '1fnu', '4g6r', '4jj2', '3hno', '1lxa', '6ria', '1hg9', '1dcq', '1cb7', '3a1m', '4zu7', '1acc', '1l5j', '6rib', '2jer', '1air', '2d40', '2fla', '1qte', '2kl8', '1dbv', '2obg', '7jvi', '2z0q', '1yox', '1f6w', '3i48', '3zds', '4puq', '1qre', '6e5c', '1cts', '1hin', '2qnz', '3ub3', '1idj', '3obw', '1dab', '3uxh', '4osd', '4aq6', '4aq2', '4fl6', '2ln3', '1znp']

cd20s = ['6PE9', '6PE8', '1QSC', '6BRB', '3LKJ',
         '6PE7', '1ALY']
ms_related = ['6H24', '1PY9', '5HIU', '6FG1', '6FG2', '4Q6R', '4GMV']

custom_dataset_pdbids = cd20s + ms_related + beta_helix_and_friends

# The ids above mix upper and lower case, and FetchPDBids lower-cases them
# before fetching. Check uniqueness on the lower-cased ids so that e.g.
# '1ALY' and '1aly' are recognised as the same entry.
unique_pdbids = pd.Series(custom_dataset_pdbids).str.lower().unique()
print(f"Custom dataset has {len(custom_dataset_pdbids)} samples")
print(unique_pdbids.shape)

Custom dataset has 69 samples
(69,)


In [45]:
class FetchPDBids:
    """Fetch PDB ids as AssemblyDatums.

    Attributes:
        pdb_ids: the requested ids, lower-cased.
        assemblies: one fetched assembly per id, in request order.
        datums: every item of each assembly's ``protein_data``, flattened.
        transformed: ``transform(datum)`` for each entry of ``datums``;
            filled by :meth:`transform` (or lazily by :meth:`togrid`).
    """
    def __init__(self, pdb_ids: List[str]):
        self.pdb_ids = [pdb_id.lower() for pdb_id in pdb_ids]
        self.datums = []
        self.assemblies = []
        self.transformed = []  # list of transformed ProteinDatums

    def __call__(self, fetch_fn=None):
        """Fetch every id in ``pdb_ids`` and collect its protein datums.

        Args:
            fetch_fn: optional callable mapping a lower-cased PDB id to an
                assembly object exposing ``protein_data``. Defaults to
                ``AssemblyDatum.fetch_pdb_id``.
        """
        if fetch_fn is None:
            fetch_fn = AssemblyDatum.fetch_pdb_id
        # Start from a clean slate so that calling the fetcher again (e.g.
        # re-running a cell or retrying after a failed fetch) does not append
        # duplicates or leave `transformed` out of sync with `datums`.
        self.datums = []
        self.assemblies = []
        self.transformed = []
        print(f"Fetching {len(self.pdb_ids)} PDB IDs...", end=" ")
        for pdb_id in self.pdb_ids:
            assembly = fetch_fn(pdb_id)
            self.assemblies.append(assembly)
            self.datums.extend(assembly.protein_data)
            print(f"{pdb_id}, ", end="")
        print("\nDone")

    def togrid(self, k=None, num_columns=3, use_transformed=False):
        """Return the first ``k`` datums as rows of ``num_columns`` items.

        Args:
            k: number of datums to include; ``None`` means all of them.
            num_columns: row length of the grid.
            use_transformed: use the transformed datums instead of the raw
                ones, computing them first if they are not available yet.
        """
        if k is None:
            k = len(self.datums)
        if use_transformed:
            if not self.transformed:
                self.transform()
            source = self.transformed
        else:
            source = self.datums
        return self.make_grid(source[:k], num_columns)

    @staticmethod
    def make_grid(datums: List["ProteinDatum"], num_columns=3):
        """Split ``datums`` into consecutive rows of at most ``num_columns`` items."""
        return [datums[i:i + num_columns] for i in range(0, len(datums), num_columns)]

    def transform(self):
        """Apply the notebook-level ``transform`` to every fetched datum."""
        self.transformed = [transform(datum) for datum in self.datums]

fetcher = FetchPDBids(custom_dataset_pdbids)
fetcher()            # fetch every assembly and collect its protein datums
fetcher.transform()  # cache the transformed version of each datum
n_datums, n_assemblies = len(fetcher.datums), len(fetcher.assemblies)
print(f"Number of fetched datums: {n_datums}, and {n_assemblies} assemblies")



Fetching 69 PDB IDs... 6pe9, 6pe8, 1qsc, 6brb, 3lkj, 6pe7, 1aly, 6h24, 1py9, 5hiu, 6fg1, 6fg2, 4q6r, 4gmv, 2jp7, 1prp, 3nxq, 1gca, 

1pcl, 1xiq, 2pqe, 1kzq, 4mzu, 1wpc, 1fnu, 4g6r, 4jj2, 3hno, 1lxa, 6ria, 1hg9, 1dcq, 1cb7, 3a1m, 4zu7, 1acc, 1l5j, 6rib, 2jer, 1air, 2d40, 2fla, 1qte, 2kl8, 1dbv, 2obg, 7jvi, 2z0q, 1yox, 1f6w, 3i48, 3zds, 4puq, 1qre, 6e5c, 1cts, 1hin, 2qnz, 3ub3, 1idj, 3obw, 1dab, 3uxh, 4osd, 4aq6, 4aq2, 4fl6, 2ln3, 1znp, 
Done
Number of fetched datums: 287, and 69 assemblies


In [69]:
from helpers.utils import residue_map

# Sanity check on one fetched datum: token dtype, then the residue names
# that residue_map returns for its tokens.
example_datum = fetcher.datums[5]
print(example_datum.residue_token.dtype)
residue_map(example_datum.residue_token)

int64
[ 6 12 22 15 19  8 18 17  6 18 13  3 22 18 13 10  9  4  3 19 12  5  7 14
 18 18  8 18 13 13  5 13 10  5  8 14  5 21 13 19 20 16  8  8 14 17 10  8
 17 17 14 13 13 12 21 20  3 18 19  4  9 18 10 22 17  6  4 16 18 10 18 10
 18 10 19  6 16 19 13 19 12 18 18 13  8  3  9  6 22  3 22 21 21  7  8  5
  6 21 19 21 17 13 19 16 10  8 10 19 14 13  9 12 14  4 19 22  3  3 17 18
 22 16 12 16 17 17 18  6  9  8 13 14 18 10 19  3 18 22 22  7 13 13  5  5
 16 21 17  4  9  3 14 22  8 20 14 22  6  5  3 13  8 18 10  5 18  8  9 18
 22 19  9  8  6 18 14  6 18 19 21 18 13 18 18 19 13 19 13 18 14  3  6 21
  9 14 11 14 22 21  3  7  9 22 19 11  8 10 13 18 18 17 22 19 14 18 16  5
  4 10  9  7]


['ASP',
 'ILE',
 'VAL',
 'MET',
 'THR',
 'GLN',
 'SER',
 'PRO',
 'ASP',
 'SER',
 'LEU',
 'ALA',
 'VAL',
 'SER',
 'LEU',
 'GLY',
 'GLU',
 'ARG',
 'ALA',
 'THR',
 'ILE',
 'ASN',
 'CYS',
 'LYS',
 'SER',
 'SER',
 'GLN',
 'SER',
 'LEU',
 'LEU',
 'ASN',
 'LEU',
 'GLY',
 'ASN',
 'GLN',
 'LYS',
 'ASN',
 'TYR',
 'LEU',
 'THR',
 'TRP',
 'PHE',
 'GLN',
 'GLN',
 'LYS',
 'PRO',
 'GLY',
 'GLN',
 'PRO',
 'PRO',
 'LYS',
 'LEU',
 'LEU',
 'ILE',
 'TYR',
 'TRP',
 'ALA',
 'SER',
 'THR',
 'ARG',
 'GLU',
 'SER',
 'GLY',
 'VAL',
 'PRO',
 'ASP',
 'ARG',
 'PHE',
 'SER',
 'GLY',
 'SER',
 'GLY',
 'SER',
 'GLY',
 'THR',
 'ASP',
 'PHE',
 'THR',
 'LEU',
 'THR',
 'ILE',
 'SER',
 'SER',
 'LEU',
 'GLN',
 'ALA',
 'GLU',
 'ASP',
 'VAL',
 'ALA',
 'VAL',
 'TYR',
 'TYR',
 'CYS',
 'GLN',
 'ASN',
 'ASP',
 'TYR',
 'THR',
 'TYR',
 'PRO',
 'LEU',
 'THR',
 'PHE',
 'GLY',
 'GLN',
 'GLY',
 'THR',
 'LYS',
 'LEU',
 'GLU',
 'ILE',
 'LYS',
 'ARG',
 'THR',
 'VAL',
 'ALA',
 'ALA',
 'PRO',
 'SER',
 'VAL',
 'PHE',
 'ILE',
 'PHE',
 'PRO',


In [47]:

def encode_assemblies(assemblies: List[AssemblyDatum], start_level: int = 0):
    """Encode a list of assemblies through ophiuchus.

    Every protein datum of every assembly is transformed, passed through the
    autoencoder, and its encoder internals are collected per level together
    with the matching slices of the original datum.

    Args:
        assemblies: assemblies whose ``protein_data`` entries are encoded.
        start_level: first encoder level to collect.

    Returns:
        A tuple ``(encoded_dataset, sliced_dataset)``:
        ``encoded_dataset[idcode][level]`` is an array of masked '0e'
        features, and ``sliced_dataset[idcode][level]`` is a list of
        ``datum[start:end]`` slices. When several datums share an idcode,
        both structures accumulate their entries in encounter order.
    """
    encoded_dataset = defaultdict(dict)
    # partial(...) rather than a lambda so the nested defaultdict stays picklable.
    sliced_dataset = defaultdict(partial(defaultdict, list))

    # Wrap the sequence itself so tqdm can report progress against its length.
    for assembly in tqdm(assemblies):
        for datum in assembly.protein_data:
            # Read the key before idcode is cleared below, in case a
            # transform hands back the same object.
            key = datum.idcode
            datum_len = len(datum)

            datum_input = transform(datum)
            # presumably cleared so only array fields reach the jitted call — TODO confirm
            datum_input.idcode = None
            output = autoencoder(datum_input)

            for level in range(start_level, len(output.encoder_internals)):
                tc = output.encoder_internals[level]
                mask = tc.mask_coord[0]
                # Keep only the '0e' irreps of the first entry, at masked-in positions.
                embeddings = np.array(tc.irreps_array.filter('0e').array)[0][mask]

                # Accumulate instead of overwriting: slices below are appended
                # per (idcode, level), so plain assignment would drop the
                # embeddings of earlier datums sharing this idcode and leave
                # the two datasets misaligned.
                per_level = encoded_dataset[key]
                if level in per_level:
                    per_level[level] = np.concatenate([per_level[level], embeddings], axis=0)
                else:
                    per_level[level] = embeddings

                for start, end in slices_per_level[level]:
                    # NOTE(review): `<=` also admits start == len(datum), which
                    # yields an empty slice; confirm whether `<` was intended.
                    if start <= datum_len:
                        sliced_dataset[key][level].append(datum[start:end])
    return encoded_dataset, sliced_dataset


In [48]:
%%time

# This takes about 3 mins (on ~55 assemblies)...most of it due to initial compilation
# NOTE(review): the recorded output of this cell shows ~20 s wall time for 69
# assemblies, so the 3-minute figure presumably refers to an earlier run.
encoded_dataset, sliced_dataset = encode_assemblies(fetcher.assemblies)

8it [00:03,  4.89it/s]

69it [00:19,  3.53it/s]

CPU times: user 20.8 s, sys: 3.08 s, total: 23.9 s
Wall time: 19.8 s





In [51]:
import os
import pickle

path = '../data'
# Create the output directory if needed, so the freshly computed results are
# not lost to a FileNotFoundError at save time.
os.makedirs(path, exist_ok=True)

# Save the encoded dataset
with open(f'{path}/encoded_dataset_custom.pkl', 'wb') as f:
    pickle.dump(encoded_dataset, f)

# Save the sliced dataset
with open(f'{path}/sliced_dataset_custom.pkl', 'wb') as f:
    pickle.dump(sliced_dataset, f)


In [None]:
from moleculib.graphics.py3Dmol import plot_py3dmol_grid

# Preview the first 12 fetched datums, laid out with togrid's default of 3 per row.
preview_grid = fetcher.togrid(k=12)
plot_py3dmol_grid(preview_grid)

NOTE: it does not really make sense to compute t-SNE on these embeddings alone, since they are not part of the full dataset.

In [28]:
def make_tsne(encoded_dataset, n_layers, start_level=0, perplexity=3, random_state=None):
    """Using the encoded dataset, get tsne coords and colors for each level.

    For every level in ``range(start_level, n_layers)`` the embeddings of all
    entries are concatenated and embedded with t-SNE twice: once into 2
    components (plot positions) and once into 3 components that are rescaled
    to 0..255 and rendered as ``rgb(r, g, b)`` strings. The results are then
    split back per entry.

    Args:
        encoded_dataset: mapping ``key -> {level: array}``; every entry must
            hold every requested level.
        n_layers: exclusive upper bound of the levels to process.
        start_level: first level to process.
        perplexity: t-SNE perplexity; 3 is the previously hard-coded value.
        random_state: forwarded to both ``TSNE`` instances. Pass an int for
            reproducible layouts; ``None`` keeps the unseeded behaviour.

    Returns:
        Nested defaultdict where ``result[key][level]['pos']`` is a list of
        2-D positions and ``result[key][level]['colors']`` a list of
        ``'rgb(r, g, b)'`` strings, one per embedding of that entry.
    """
    encoded_dataset_tsne = defaultdict(lambda : defaultdict(lambda : defaultdict(dict)))
    keys = list(encoded_dataset.keys())
    for level in range(start_level, n_layers):
        per_key = [encoded_dataset[key][level] for key in keys]
        level_data = np.concatenate(per_key)

        print(f'computing position tsne for level {level}: {level_data.shape}')
        position = TSNE(
            n_components=2, perplexity=perplexity, learning_rate='auto',
            init='random', random_state=random_state,
        ).fit_transform(level_data)
        print(f'computing color tsne for level {level}: {level_data.shape}')
        colors = TSNE(
            n_components=3, perplexity=perplexity, learning_rate='auto',
            init='random', random_state=random_state,
        ).fit_transform(level_data)
        # Rescale to 0..255 using the global min/max over all three components.
        colors = colors - colors.min()
        peak = colors.max()
        if peak > 0:  # avoid 0/0 when every coordinate is identical
            colors = colors * 255 / peak
        colors = colors.astype(np.int32)
        colors = [f'rgb({r}, {g}, {b})' for r, g, b in colors]

        # Hand each entry back its own contiguous chunk of rows.
        cumsum = 0
        for key, chunk in zip(keys, per_key):
            len_ = len(chunk)
            encoded_dataset_tsne[key][level]['pos'] = position[cumsum:cumsum+len_].tolist()
            encoded_dataset_tsne[key][level]['colors'] = colors[cumsum:cumsum+len_]
            cumsum += len_
    return encoded_dataset_tsne


In [30]:
# NOTE(review): n_layers=5 is hard-coded; presumably it should match the number
# of encoder levels stored per entry in encoded_dataset — TODO confirm.
encoded_dataset_tsne = make_tsne(encoded_dataset, n_layers=5)

computing position tsne for level 0: (40648, 24)


KeyboardInterrupt: 

In [None]:
import json

# Write the per-entry t-SNE positions and colors next to the pickled datasets.
tsne_json_path = f'{path}/encoded_dataset_tsne.json'
with open(tsne_json_path, 'w') as tsne_file:
    json.dump(encoded_dataset_tsne, tsne_file)