In [1]:
import logging
logging.basicConfig(level=logging.INFO)

import pickle 
import numpy as np
import os
from tqdm import tqdm

from openfold.data import data_pipeline, feature_pipeline
from openfold.utils.tensor_utils import tensor_tree_map, dict_multimap
from openfold.config import model_config
from openfold.utils.loss import compute_contact_ca

INFO:numexpr.utils:NumExpr defaulting to 8 threads.


In [2]:
# Load the secondary-structure annotations and build the PDB data pipeline.
# NOTE(review): absolute, machine-specific path — consider a configurable
# data directory so the notebook runs on other machines.
ss_file = '/Users/chenceshi/Downloads/Chrome Downloads/ss_annotation_31885.pkl'

# NOTE(review): pickle.load can execute arbitrary code while unpickling;
# only load annotation files from a trusted source.
with open(ss_file, 'rb') as fin:
    second_structure_data = pickle.load(fin)
logging.warning(f"get {len(second_structure_data)} second structure data")

# Map every record's 'tag' to its 'ss3' entry (a later duplicate tag wins).
ss_dict = {ss['tag']: ss['ss3'] for ss in second_structure_data}

# The original cell rebound the imported module name `data_pipeline` to a
# DataPipeline instance, so running the cell a second time failed (the
# instance has no `DataPipeline` attribute). Fetch the module under its own
# alias so the cell is re-runnable, while keeping the global name
# `data_pipeline` bound to the instance for the cells that follow.
from openfold.data import data_pipeline as data_pipeline_module
data_pipeline = data_pipeline_module.DataPipeline(ss_dict)



In [3]:
# Build the model configuration and the feature pipeline driven by its
# `data` section.
config = model_config(name='initial_training', train=True)

# The original cell rebound the imported module name `feature_pipeline` to a
# FeaturePipeline instance, which made the cell fail when re-run. Import the
# module under a separate alias so the cell is idempotent; the global name
# `feature_pipeline` still ends up bound to the instance used by later cells.
from openfold.data import feature_pipeline as feature_pipeline_module
feature_pipeline = feature_pipeline_module.FeaturePipeline(config.data)

In [4]:
# Run one PDB file through the data pipeline, then hand the result to the
# feature pipeline with 'train' as the second argument.
# NOTE(review): absolute, machine-specific path — will not resolve on other
# machines; consider a configurable data directory.
path = '/Users/chenceshi/Downloads/Chrome Downloads/1alo006.pdb'
data = data_pipeline.process_pdb(pdb_path=path)
# `feats` is consumed by the cells below (shape inspection, contact map).
feats = feature_pipeline.process_features(data, 'train')



In [5]:
# Display the shape of the atom-position feature; the recorded output below
# is torch.Size([256, 37, 3, 4]).
feats['all_atom_positions'].shape

torch.Size([256, 37, 3, 4])

In [5]:
# Index into the trailing axis of every feature tensor that will be kept.
cycle_no = 0


def fetch_cur_batch(t):
    """Return the slice of ``t`` at position ``cycle_no`` of its last axis."""
    return t[..., cycle_no]


# Apply the slice to every tensor in the feature tree.
cur_feats = tensor_tree_map(fetch_cur_batch, feats)

# Contact map from atom positions and their mask, using a cutoff of 12.
contact = compute_contact_ca(
    cur_feats["all_atom_positions"],
    cur_feats["all_atom_mask"],
    cutoff=12,
)

# Keep only the top-left 63 x 63 corner, converted to a NumPy array.
# NOTE(review): 63 is a hard-coded size — confirm what it corresponds to.
contact_ = contact[:63, :63].numpy()

In [None]:
# Turn every row of `contact_` into one string by concatenating the str()
# of each entry, then print the rows one per line.
beautiful_out = [''.join(map(str, row)) for row in contact_]
print('\n'.join(beautiful_out))


In [18]:
# Push every .pdb file in the debug dataset through both pipelines, printing
# the shape of each structure's atom-position array along the way.
# NOTE(review): absolute, machine-specific path — consider a configurable
# data directory.
data_dir = '/Users/chenceshi/Downloads/Chrome Downloads/debug_dataset'

# os.listdir returns entries in arbitrary order; sort them so the processing
# order (and therefore the `data`/`feats` left over after the loop) is
# reproducible between runs.
filenames = sorted(x for x in os.listdir(data_dir) if x.endswith('.pdb'))

for filename in tqdm(filenames):
    path = os.path.join(data_dir, filename)
    data = data_pipeline.process_pdb(pdb_path=path)
    # tqdm.write emits the line without fragmenting the progress bar, which
    # a plain print() inside a tqdm loop does.
    tqdm.write(str(data["all_atom_positions"].shape))
    # Only the features of the last file remain bound to `feats` afterwards.
    feats = feature_pipeline.process_features(data, 'train')
    

 14%|████████████▋                                                                                | 3/22 [00:00<00:00, 21.56it/s]

(100, 37, 3)
(94, 37, 3)
(144, 37, 3)




(319, 37, 3)
(57, 37, 3)
(233, 37, 3)
(159, 37, 3)




(245, 37, 3)
(126, 37, 3)
(63, 37, 3)
(205, 37, 3)




(334, 37, 3)
(313, 37, 3)
(79, 37, 3)
(139, 37, 3)




(323, 37, 3)
(52, 37, 3)




(71, 37, 3)
(418, 37, 3)
(79, 37, 3)
(96, 37, 3)


100%|████████████████████████████████████████████████████████████████████████████████████████████| 22/22 [00:01<00:00, 14.31it/s]

(290, 37, 3)





In [17]:
# Display the complete `feats` dictionary.
# NOTE(review): this cell's execution count (17) is lower than that of the
# dataset loop above (18), so the recorded output may not reflect the `feats`
# left by that loop — restart the kernel and run all cells to confirm.
# NOTE(review): dumping every tensor produces very long output; displaying
# only the shapes would be easier to read.
feats

{'aatype': tensor([[10, 10, 10, 10],
         [ 7,  7,  7,  7],
         [15, 15, 15, 15],
         ...,
         [19, 19, 19, 19],
         [ 1,  1,  1,  1],
         [17, 17, 17, 17]]),
 'sstype': tensor([[0, 0, 0, 0],
         [0, 0, 0, 0],
         [2, 2, 2, 2],
         ...,
         [1, 1, 1, 1],
         [1, 1, 1, 1],
         [1, 1, 1, 1]]),
 'residue_index': tensor([[ 18,  18,  18,  18],
         [ 19,  19,  19,  19],
         [ 20,  20,  20,  20],
         ...,
         [271, 271, 271, 271],
         [272, 272, 272, 272],
         [273, 273, 273, 273]]),
 'chain_index': tensor([[0, 0, 0, 0],
         [0, 0, 0, 0],
         [0, 0, 0, 0],
         ...,
         [0, 0, 0, 0],
         [0, 0, 0, 0],
         [0, 0, 0, 0]]),
 'seq_length': tensor([256, 256, 256, 256]),
 'all_atom_positions': tensor([[[[ 11.4757,  11.4757,  11.4757,  11.4757],
           [-10.7133, -10.7133, -10.7133, -10.7133],
           [ -7.5520,  -7.5520,  -7.5520,  -7.5520]],
 
          [[ 12.8737,  12.8737,