In [1]:
# load modules
import os
import json
import numpy as np
import esm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# import model
model, alphabet = esm.pretrained.esm_if1_gvp4_t16_142M_UR50()
model = model.eval()



In [3]:
# import labels
f = open("DeepTMHMM.partitions.json")
labels = json.load(f)

In [4]:
# get representation structure for all proteins
for cv in labels.keys():
    for i, protein in enumerate(labels[cv]):
        
        path = f"proteins/{cv}/{protein['id']}.pdb"  
        # if protein does not exists in alphafold-db, skip over 
        if not os.path.exists(path):
            print("this one didn't exist: ", protein['id'])
            continue
              
        data = {}
        
        # receive 3-D structure (Atom array (of amino acids) including 3-D coordinates and other info)        
        structure = esm.inverse_folding.util.load_structure(path)
        
        # get coordinates for each amino acid's N-terminal, alpha-carbon and C-terminal (first three in pdb)
        coords, native_seq = esm.inverse_folding.util.extract_coords_from_structure(structure)

        # get encoder output as structure representation shape: (amino acid, encoder dimension)
        rep = esm.inverse_folding.util.get_encoder_output(model, alphabet, coords)
        
        data["data"] = rep.detach().numpy()
        data["labels"] = protein["labels"]
        
        # create directory and save data as .npy file
        encoder_path = f"encoder_proteins/{cv}/{protein['id']}"
        os.makedirs(os.path.dirname(encoder_path), exist_ok=True)
        np.save(encoder_path, data)    

AssertionError: 

In [5]:
# example of loaded dataset
read_dictionary = np.load(encoder_path + ".npy", allow_pickle='TRUE').item()
print(read_dictionary) # displays "world"

{'data': array([[ 0.01676764,  0.10006791, -0.22470556, ...,  0.45647484,
         0.14616954, -0.46214253],
       [-1.4381772 , -0.7790984 , -0.802924  , ...,  0.30524823,
         0.01956351,  0.10650016],
       [-1.3221653 , -0.62276614, -0.8991547 , ...,  0.25747558,
        -0.02869022,  0.02802757],
       ...,
       [ 0.16744782,  0.8542608 , -0.05963039, ..., -0.27786946,
        -0.11583645, -0.28408888],
       [ 0.6397855 ,  0.04167555,  0.37690443, ...,  0.21222025,
        -0.20938833, -0.33905762],
       [ 1.0578748 , -0.11029857,  0.7357471 , ...,  0.21586058,
         0.34215224,  0.15475065]], dtype=float32), 'labels': 'SSSSSSSSSSSSSSSSSSSSSSSSSPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPBBBBBBBOOOOOOOOOOOOOOOOOOOOOOOOOOOOBBBBBBBPPPPPPPPBBBBBBBOOOOOOOOOOOOOOOOOOOOOOOOOBBBBBBBBPPPPPPPPBBBBBBBBOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOBBBBBBBBBPPPPPPBBBBBBBBOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOBBBBBBBBBBPPPPBBBBBBBBBOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOB