# ESM for Protein Residue Embedding

## Protein Construction

In [1]:
import pickle
import torch
from tqdm import tqdm
from torchdrug import utils, transforms, data, models, layers

In [2]:
# Read the pickled protein list from gearnet_Protein.pkl into `output_list`
# (the recorded output shows one PackedProtein per entry).
protein_pkl = '../../../data/dta-datasets/PDBbind/gearnet_Protein.pkl'
with utils.smart_open(protein_pkl, "rb") as fin:
    output_list = pickle.load(fin)

In [3]:
output_list

[PackedProtein(batch_size=1, num_atoms=[1161], num_bonds=[23749], num_residues=[1161]),
 PackedProtein(batch_size=1, num_atoms=[501], num_bonds=[9852], num_residues=[501]),
 PackedProtein(batch_size=1, num_atoms=[165], num_bonds=[3709], num_residues=[165]),
 PackedProtein(batch_size=1, num_atoms=[852], num_bonds=[19860], num_residues=[852]),
 PackedProtein(batch_size=1, num_atoms=[257], num_bonds=[5615], num_residues=[257]),
 PackedProtein(batch_size=1, num_atoms=[611], num_bonds=[14390], num_residues=[611]),
 PackedProtein(batch_size=1, num_atoms=[939], num_bonds=[20431], num_residues=[939]),
 PackedProtein(batch_size=1, num_atoms=[295], num_bonds=[6629], num_residues=[295]),
 PackedProtein(batch_size=1, num_atoms=[296], num_bonds=[6538], num_residues=[296]),
 PackedProtein(batch_size=1, num_atoms=[1017], num_bonds=[23912], num_residues=[1017]),
 PackedProtein(batch_size=1, num_atoms=[364], num_bonds=[7158], num_residues=[364]),
 PackedProtein(batch_size=1, num_atoms=[430], num_bonds=

In [4]:
len(output_list)

19404

In [5]:
# Maximum number of residues allowed per chunk (the "1022 length limit").
CHUNK_SIZE = 1022
# Proteins needing more than this many chunks (> 5110 residues) are rejected.
MAX_CHUNKS = 5

# Build a list parallel to `output_list` in which every entry holds at most
# CHUNK_SIZE residues per packed graph: short proteins are kept as they are,
# longer ones become one PackedProtein whose batch elements are the chunks.
protein_within_1022_list = []
for curProtein in tqdm(output_list, "ESM Combining ..."):
    curLength = curProtein.num_residue.tolist()
    if curLength <= CHUNK_SIZE:
        # Already within the limit: keep the original object untouched.
        protein_within_1022_list.append(curProtein)
        continue
    if curLength > CHUNK_SIZE * MAX_CHUNKS:
        raise ValueError("Error with too long residue")
    # Consecutive, non-overlapping residue windows [start, start + CHUNK_SIZE);
    # the last window is truncated at the protein length.
    parts = [
        curProtein.residue_mask(
            list(range(start, min(start + CHUNK_SIZE, curLength))),
            compact=True,
        )
        for start in range(0, curLength, CHUNK_SIZE)
    ]
    protein_within_1022_list.append(data.Protein.pack(parts))

ESM Combining ...: 100%|██████████| 19404/19404 [30:52<00:00, 10.47it/s] 


In [6]:
# Persist the chunked protein list so later sections can reload it.
# The path variable is named `protein_pkl` like in the other cells: this file
# holds proteins, not drugs.
protein_pkl = '../../../data/dta-datasets/PDBbind/protein_for_esm.pkl'
with utils.smart_open(protein_pkl, "wb") as fout:
    pickle.dump(protein_within_1022_list, fout)

#### 1022 Length Limit Protein Construction

In [1]:
import pickle
import torch
from tqdm import tqdm
from torchdrug import utils, transforms, data, models, layers

In [2]:
# Reload the pickled list from protein_for_esm.pkl into
# `protein_within_1022_list`.
protein_pkl = '../../../data/dta-datasets/PDBbind/protein_for_esm.pkl'
with utils.smart_open(protein_pkl, "rb") as fin:
    protein_within_1022_list = pickle.load(fin)

KeyboardInterrupt: 

In [3]:
# Take a 5000-entry slice (indices 5000..9999) of the chunked list and display it.
# NOTE(review): the recorded output below starts with a [1022, 139] split, which
# looks like the first protein of the full list (1161 residues) -- it may come
# from an earlier run with a different slice; confirm before relying on it.
list1 = protein_within_1022_list[5000:10000]
list1

[PackedProtein(batch_size=2, num_atoms=[1022, 139], num_bonds=[20059, 2123], num_residues=[1022, 139]),
 PackedProtein(batch_size=1, num_atoms=[501], num_bonds=[9852], num_residues=[501]),
 PackedProtein(batch_size=1, num_atoms=[165], num_bonds=[3709], num_residues=[165]),
 PackedProtein(batch_size=1, num_atoms=[852], num_bonds=[19860], num_residues=[852]),
 PackedProtein(batch_size=1, num_atoms=[257], num_bonds=[5615], num_residues=[257]),
 PackedProtein(batch_size=1, num_atoms=[611], num_bonds=[14390], num_residues=[611]),
 PackedProtein(batch_size=1, num_atoms=[939], num_bonds=[20431], num_residues=[939]),
 PackedProtein(batch_size=1, num_atoms=[295], num_bonds=[6629], num_residues=[295]),
 PackedProtein(batch_size=1, num_atoms=[296], num_bonds=[6538], num_residues=[296]),
 PackedProtein(batch_size=1, num_atoms=[1017], num_bonds=[23912], num_residues=[1017]),
 PackedProtein(batch_size=1, num_atoms=[364], num_bonds=[7158], num_residues=[364]),
 PackedProtein(batch_size=1, num_atoms=[

In [4]:
len(list1)

5000

In [5]:
# Save the selected slice to its own pickle for the ESM step.
# The path variable is named `protein_pkl` like in the other cells: this file
# holds proteins, not drugs.
protein_pkl = '../../../data/dta-datasets/PDBbind/esm/protein_5000.pkl'
with utils.smart_open(protein_pkl, "wb") as fout:
    pickle.dump(list1, fout)

***

#### ESM for each Protein

In [1]:
import pickle
import torch
from tqdm import tqdm
from torchdrug import utils, transforms, data, models, layers

In [2]:
# Load the pickled slice from esm/protein_5000.pkl into
# `protein_within_1022_list`.
protein_pkl = '../../../data/dta-datasets/PDBbind/esm/protein_5000.pkl'
with utils.smart_open(protein_pkl, "rb") as fin:
    protein_within_1022_list = pickle.load(fin)

In [3]:
# Build the torchdrug ESM model ("ESM-2_650M"); `path` is the local model
# directory handed to the constructor (presumably where the weights live --
# confirm against the torchdrug docs).
sequence_model = models.ESM(path="../../../result/model_pth/ESM/", model="ESM-2_650M")

In [4]:
def esmProcess(esmModel, curProtein, pklName):
    """
    Embed one packed protein with ESM and pickle the result.

    esmModel: callable invoked as esmModel(graph, input); its return value is
        indexed with 'residue_feature' and 'graph_feature'
    curProtein: a single PackedProtein (not a list); its residue_feature and
        mol_feature attributes are overwritten in place
    pklName: suffix of the output file, written as gearnetesm_<pklName>.pkl
    """
    # Inference only: run without autograd so the stored features do not keep
    # the model's computation graph alive. The features stay referenced from
    # the caller's list, so memory would otherwise grow with every protein
    # (a likely cause of out-of-memory failures on long runs).
    with torch.no_grad():
        output = esmModel(curProtein, curProtein.residue_feature)
    curProtein.residue_feature = output['residue_feature']
    curProtein.mol_feature = output['graph_feature']
    protein_pkl = '../../../data/dta-datasets/PDBbind/esm/gearnetesm_' + pklName + '.pkl'
    with utils.smart_open(protein_pkl, "wb") as fout:
        pickle.dump(curProtein, fout)

In [5]:
protein_within_1022_list

[PackedProtein(batch_size=2, num_atoms=[1022, 139], num_bonds=[20059, 2123], num_residues=[1022, 139]),
 PackedProtein(batch_size=1, num_atoms=[501], num_bonds=[9852], num_residues=[501]),
 PackedProtein(batch_size=1, num_atoms=[165], num_bonds=[3709], num_residues=[165]),
 PackedProtein(batch_size=1, num_atoms=[852], num_bonds=[19860], num_residues=[852]),
 PackedProtein(batch_size=1, num_atoms=[257], num_bonds=[5615], num_residues=[257]),
 PackedProtein(batch_size=1, num_atoms=[611], num_bonds=[14390], num_residues=[611]),
 PackedProtein(batch_size=1, num_atoms=[939], num_bonds=[20431], num_residues=[939]),
 PackedProtein(batch_size=1, num_atoms=[295], num_bonds=[6629], num_residues=[295]),
 PackedProtein(batch_size=1, num_atoms=[296], num_bonds=[6538], num_residues=[296]),
 PackedProtein(batch_size=1, num_atoms=[1017], num_bonds=[23912], num_residues=[1017]),
 PackedProtein(batch_size=1, num_atoms=[364], num_bonds=[7158], num_residues=[364]),
 PackedProtein(batch_size=1, num_atoms=[

In [6]:
len(protein_within_1022_list)

5000

In [8]:
# Embed entries 528..599 of the chunked list; every chunk gets its own pickle.
for index in tqdm(range(528, 600), "ESM Computing ......"):
    packed = protein_within_1022_list[index]
    if packed.batch_size == 1:
        # Single-chunk protein: the file is named after the index only.
        esmProcess(sequence_model, packed, str(index))
        continue
    # Multi-chunk protein: re-pack each chunk on its own and append the
    # chunk number to the file name ("<index>-<chunk>").
    for i in range(packed.batch_size):
        single_chunk = data.Protein.pack([packed[i]])
        esmProcess(sequence_model, single_chunk, str(index) + '-' + str(i))

ESM Computing ......:  29%|██▉       | 29/100 [02:19<05:42,  4.82s/it]


RuntimeError: [enforce fail at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\c10\core\impl\alloc_cpu.cpp:72] data. DefaultCPUAllocator: not enough memory: you tried to allocate 20971520 bytes.

Read all pkl files and aggregate them into a new list

In [None]:
import pickle
import torch
from tqdm import tqdm
from torchdrug import utils, transforms, data, models, layers

In [None]:
# Read the pickled list from gearnet_pocket_Protein.pkl into `output_list`
# (the recorded output shows one PackedProtein per entry).
protein_pkl = '../../../data/dta-datasets/PDBbind/gearnet_pocket_Protein.pkl'
with utils.smart_open(protein_pkl, "rb") as fin:
    output_list = pickle.load(fin)

In [None]:
len(output_list)

9018

In [None]:
output_list

[PackedProtein(batch_size=1, num_atoms=[39], num_bonds=[501], num_residues=[39]),
 PackedProtein(batch_size=1, num_atoms=[22], num_bonds=[246], num_residues=[22]),
 PackedProtein(batch_size=1, num_atoms=[29], num_bonds=[347], num_residues=[29]),
 PackedProtein(batch_size=1, num_atoms=[50], num_bonds=[713], num_residues=[50]),
 PackedProtein(batch_size=1, num_atoms=[73], num_bonds=[1106], num_residues=[73]),
 PackedProtein(batch_size=1, num_atoms=[46], num_bonds=[588], num_residues=[46]),
 PackedProtein(batch_size=1, num_atoms=[73], num_bonds=[1131], num_residues=[73]),
 PackedProtein(batch_size=1, num_atoms=[34], num_bonds=[552], num_residues=[34]),
 PackedProtein(batch_size=1, num_atoms=[27], num_bonds=[340], num_residues=[27]),
 PackedProtein(batch_size=1, num_atoms=[53], num_bonds=[834], num_residues=[53]),
 PackedProtein(batch_size=1, num_atoms=[44], num_bonds=[653], num_residues=[44]),
 PackedProtein(batch_size=1, num_atoms=[43], num_bonds=[632], num_residues=[43]),
 PackedProtein

In [None]:
# For every entry, load the per-index ESM pickle and copy its residue
# features onto the matching element of `output_list`.
for index, curProtein in enumerate(tqdm(output_list, "ESM Combining ...")):
    esm_pkl = '../../../data/dta-datasets/PDBbind/esm/gearnetesm_' + str(index) + '.pkl'
    with utils.smart_open(esm_pkl, "rb") as fin:
        embedded = pickle.load(fin)
    curProtein.residue_feature = embedded.residue_feature

ESM Combining ...: 100%|██████████| 9018/9018 [02:35<00:00, 57.88it/s] 


In [None]:
output_list[8866].residue_feature.shape

torch.Size([29, 1280])

In [None]:
len(output_list)

9018

In [None]:
# Write `output_list` to gearnetesm_pocket_Protein.pkl.
protein_pkl = '../../../data/dta-datasets/PDBbind/gearnetesm_pocket_Protein.pkl'
with utils.smart_open(protein_pkl, "wb") as fout:
    pickle.dump(output_list, fout)

## Pocket Construction

#### GearNet Construction

In [1]:
import pickle
import torch
from tqdm import tqdm
from torchdrug import utils, transforms, data, models, layers

No CUDA runtime is found, using CUDA_HOME='C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7'


In [2]:
# Read the pickled list from gearnet_pocket_Protein.pkl into `output_list`.
protein_pkl = '../../../data/dta-datasets/PDBbind/gearnet_pocket_Protein.pkl'
with utils.smart_open(protein_pkl, "rb") as fin:
    output_list = pickle.load(fin)

In [3]:
len(output_list)

19404

In [4]:
# Build the torchdrug ESM model ("ESM-2_650M"); `path` is the local model
# directory handed to the constructor (presumably where the weights live --
# confirm against the torchdrug docs).
sequence_model = models.ESM(path="../../../result/model_pth/ESM/", model="ESM-2_650M")

In [5]:
def esmProcess(esmModel, curProtein, pklName):
    """
    Embed one packed protein with ESM and pickle the result.

    esmModel: callable invoked as esmModel(graph, input); its return value is
        indexed with 'residue_feature' and 'graph_feature'
    curProtein: a single PackedProtein (not a list); its residue_feature and
        mol_feature attributes are overwritten in place
    pklName: suffix of the output file, written as gearnetesm_<pklName>.pkl
    """
    # Inference only: run without autograd so the stored features do not keep
    # the model's computation graph alive. The features stay referenced from
    # the caller's list, so memory would otherwise grow with every protein
    # (a likely cause of out-of-memory failures on long runs).
    with torch.no_grad():
        output = esmModel(curProtein, curProtein.residue_feature)
    curProtein.residue_feature = output['residue_feature']
    curProtein.mol_feature = output['graph_feature']
    protein_pkl = '../../../data/dta-datasets/PDBbind/esm/gearnetesm_' + pklName + '.pkl'
    with utils.smart_open(protein_pkl, "wb") as fout:
        pickle.dump(curProtein, fout)

In [6]:
# Run ESM over entries 19000..end of `output_list`, one pickle per index.
for index in tqdm(range(19000, len(output_list)), "ESM Computing ......"):
    esmProcess(sequence_model, output_list[index], str(index))

ESM Computing ......: 100%|██████████| 404/404 [20:06<00:00,  2.99s/it]


Read all pkl files and aggregate them into a new list

In [1]:
import pickle
import torch
from tqdm import tqdm
from torchdrug import utils, transforms, data, models, layers

No CUDA runtime is found, using CUDA_HOME='C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7'


In [2]:
# Read the pickled list from gearnet_pocket_Protein.pkl into `output_list`.
protein_pkl = '../../../data/dta-datasets/PDBbind/gearnet_pocket_Protein.pkl'
with utils.smart_open(protein_pkl, "rb") as fin:
    output_list = pickle.load(fin)

In [3]:
len(output_list)

19404

In [4]:
# For every entry, load the per-index ESM pickle and copy its residue
# features onto the matching element of `output_list`.
for index, target in enumerate(tqdm(output_list, "ESM Combining ...")):
    esm_pkl = '../../../data/dta-datasets/PDBbind/esm/gearnetesm_' + str(index) + '.pkl'
    with utils.smart_open(esm_pkl, "rb") as fin:
        embedded = pickle.load(fin)
    target.residue_feature = embedded.residue_feature

ESM Combining ...: 100%|██████████| 19404/19404 [12:44<00:00, 25.39it/s]


In [5]:
output_list[0].residue_feature.shape

torch.Size([49, 1280])

In [6]:
len(output_list)

19404

In [7]:
# Write `output_list` to gearnetesm_pocket_Protein.pkl.
protein_pkl = '../../../data/dta-datasets/PDBbind/gearnetesm_pocket_Protein.pkl'
with utils.smart_open(protein_pkl, "wb") as fout:
    pickle.dump(output_list, fout)

***

In [None]:
import pickle
import torch
from tqdm import tqdm
from torchdrug import utils, transforms, data, models, layers

In [None]:
# Read the pickled protein list from gearnet_Protein.pkl into `output_list`
# (the recorded output shows one PackedProtein per entry).
protein_pkl = '../../../data/dta-datasets/PDBbind/gearnet_Protein.pkl'
with utils.smart_open(protein_pkl, "rb") as fin:
    output_list = pickle.load(fin)

In [None]:
output_list

[PackedProtein(batch_size=1, num_atoms=[1161], num_bonds=[23749], num_residues=[1161]),
 PackedProtein(batch_size=1, num_atoms=[501], num_bonds=[9852], num_residues=[501]),
 PackedProtein(batch_size=1, num_atoms=[165], num_bonds=[3709], num_residues=[165]),
 PackedProtein(batch_size=1, num_atoms=[852], num_bonds=[19860], num_residues=[852]),
 PackedProtein(batch_size=1, num_atoms=[257], num_bonds=[5615], num_residues=[257]),
 PackedProtein(batch_size=1, num_atoms=[611], num_bonds=[14390], num_residues=[611]),
 PackedProtein(batch_size=1, num_atoms=[939], num_bonds=[20431], num_residues=[939]),
 PackedProtein(batch_size=1, num_atoms=[295], num_bonds=[6629], num_residues=[295]),
 PackedProtein(batch_size=1, num_atoms=[296], num_bonds=[6538], num_residues=[296]),
 PackedProtein(batch_size=1, num_atoms=[1017], num_bonds=[23912], num_residues=[1017]),
 PackedProtein(batch_size=1, num_atoms=[364], num_bonds=[7158], num_residues=[364]),
 PackedProtein(batch_size=1, num_atoms=[430], num_bonds=

In [None]:
len(output_list)

19404

In [None]:
# Build the torchdrug ESM model ("ESM-2_650M"); `path` is the local model
# directory handed to the constructor (presumably where the weights live --
# confirm against the torchdrug docs).
sequence_model = models.ESM(path="../../../result/model_pth/ESM/", model="ESM-2_650M")

In [None]:
def esmProcess(esmModel, curProtein, pklName):
    """
    Embed one packed protein with ESM and pickle the result.

    esmModel: callable invoked as esmModel(graph, input); its return value is
        indexed with 'residue_feature' and 'graph_feature'
    curProtein: a single PackedProtein (not a list); its residue_feature and
        mol_feature attributes are overwritten in place
    pklName: suffix of the output file, written as gearnetesm_<pklName>.pkl
    """
    # Inference only: run without autograd so the stored features do not keep
    # the model's computation graph alive. The features stay referenced from
    # the caller's list, so memory would otherwise grow with every protein
    # (a likely cause of out-of-memory failures on long runs).
    with torch.no_grad():
        output = esmModel(curProtein, curProtein.residue_feature)
    curProtein.residue_feature = output['residue_feature']
    curProtein.mol_feature = output['graph_feature']
    protein_pkl = '../../../data/dta-datasets/PDBbind/esm/gearnetesm_' + pklName + '.pkl'
    with utils.smart_open(protein_pkl, "wb") as fout:
        pickle.dump(curProtein, fout)

In [None]:
# Run ESM over the first 1000 entries of `output_list`, one pickle per index.
for index in tqdm(range(0, 1000), "ESM Computing ......"):
    esmProcess(sequence_model, output_list[index], str(index))

ESM Computing ......:   0%|          | 2/1000 [00:34<4:45:57, 17.19s/it]


KeyboardInterrupt: 

Read all pkl files and aggregate them into a new list

In [None]:
import pickle
import torch
from tqdm import tqdm
from torchdrug import utils, transforms, data, models, layers

In [None]:
# Read the pickled list from gearnet_pocket_Protein.pkl into `output_list`
# (the recorded output shows one PackedProtein per entry).
protein_pkl = '../../../data/dta-datasets/PDBbind/gearnet_pocket_Protein.pkl'
with utils.smart_open(protein_pkl, "rb") as fin:
    output_list = pickle.load(fin)

In [None]:
len(output_list)

9018

In [None]:
output_list

[PackedProtein(batch_size=1, num_atoms=[39], num_bonds=[501], num_residues=[39]),
 PackedProtein(batch_size=1, num_atoms=[22], num_bonds=[246], num_residues=[22]),
 PackedProtein(batch_size=1, num_atoms=[29], num_bonds=[347], num_residues=[29]),
 PackedProtein(batch_size=1, num_atoms=[50], num_bonds=[713], num_residues=[50]),
 PackedProtein(batch_size=1, num_atoms=[73], num_bonds=[1106], num_residues=[73]),
 PackedProtein(batch_size=1, num_atoms=[46], num_bonds=[588], num_residues=[46]),
 PackedProtein(batch_size=1, num_atoms=[73], num_bonds=[1131], num_residues=[73]),
 PackedProtein(batch_size=1, num_atoms=[34], num_bonds=[552], num_residues=[34]),
 PackedProtein(batch_size=1, num_atoms=[27], num_bonds=[340], num_residues=[27]),
 PackedProtein(batch_size=1, num_atoms=[53], num_bonds=[834], num_residues=[53]),
 PackedProtein(batch_size=1, num_atoms=[44], num_bonds=[653], num_residues=[44]),
 PackedProtein(batch_size=1, num_atoms=[43], num_bonds=[632], num_residues=[43]),
 PackedProtein

In [None]:
# For every entry, load the per-index ESM pickle and copy its residue
# features onto the matching element of `output_list`.
for index, curProtein in enumerate(tqdm(output_list, "ESM Combining ...")):
    esm_pkl = '../../../data/dta-datasets/PDBbind/esm/gearnetesm_' + str(index) + '.pkl'
    with utils.smart_open(esm_pkl, "rb") as fin:
        embedded = pickle.load(fin)
    curProtein.residue_feature = embedded.residue_feature

ESM Combining ...: 100%|██████████| 9018/9018 [02:35<00:00, 57.88it/s] 


In [None]:
output_list[8866].residue_feature.shape

torch.Size([29, 1280])

In [None]:
len(output_list)

9018

In [None]:
# Write `output_list` to gearnetesm_pocket_Protein.pkl.
protein_pkl = '../../../data/dta-datasets/PDBbind/gearnetesm_pocket_Protein.pkl'
with utils.smart_open(protein_pkl, "wb") as fout:
    pickle.dump(output_list, fout)

In [None]:
# Smoke run over the first five entries; every chunk gets its own pickle.
# NOTE(review): `transform_list` is not defined anywhere in the visible
# notebook -- this cell raises NameError unless it is created elsewhere.
for index in tqdm(range(5), "ESM Computing ......"):
    packed = transform_list[index]
    if packed.batch_size == 1:
        # Single-chunk protein: the file is named after the index only.
        esmProcess(sequence_model, packed, str(index))
        continue
    # Multi-chunk protein: re-pack each chunk on its own and append the
    # chunk number to the file name ("<index>-<chunk>").
    for i in range(packed.batch_size):
        single_chunk = data.Protein.pack([packed[i]])
        esmProcess(sequence_model, single_chunk, str(index) + '-' + str(i))

In [7]:
# Run ESM over the first 1000 entries of `output_list`, one pickle per index.
for index in tqdm(range(0, 1000), "ESM Computing ......"):
    esmProcess(sequence_model, output_list[index], str(index))

ESM Computing ......:   0%|          | 2/1000 [00:34<4:45:57, 17.19s/it]


KeyboardInterrupt: 

In [None]:
# Residues per chunk; must equal the chunk size used when the proteins were
# split, because it determines how many "<index>-<part>" files are expected.
CHUNK_SIZE = 1022
# NOTE(review): these paths point at the Davis dataset while the other cells
# of this notebook read and write PDBbind -- confirm this is intended.
esm_prefix = '../../../data/dta-datasets/Davis/esm/gearnetesm_'

# Rebuild full-length residue features for the first 1000 proteins by loading
# the per-chunk ESM pickles and concatenating them along the residue axis.
for index in tqdm(range(0, 1000), "ESM Combining ..."):
    curProtein = output_list[index]
    curLength = curProtein.num_residue.tolist()
    if curLength <= CHUNK_SIZE:
        # Unsplit protein: a single file named after the index.
        names = [str(index)]
    else:
        # Split protein: ceil(curLength / CHUNK_SIZE) files, "<index>-<part>".
        n_parts = (curLength + CHUNK_SIZE - 1) // CHUNK_SIZE
        names = [str(index) + '-' + str(i) for i in range(n_parts)]

    features = []
    for name in names:
        with utils.smart_open(esm_prefix + name + '.pkl', "rb") as fin:
            features.append(pickle.load(fin).residue_feature)

    if len(features) == 1:
        curProtein.residue_feature = features[0]
    else:
        # Parts were saved in residue order, so concatenating in part order
        # restores the original residue sequence.
        curProtein.residue_feature = torch.cat(features, dim=0)

Read all pkl files and aggregate them into a new list

In [1]:
import pickle
import torch
from tqdm import tqdm
from torchdrug import utils, transforms, data, models, layers

In [2]:
# Read the pickled list from gearnet_pocket_Protein.pkl into `output_list`
# (the recorded output shows one PackedProtein per entry).
protein_pkl = '../../../data/dta-datasets/PDBbind/gearnet_pocket_Protein.pkl'
with utils.smart_open(protein_pkl, "rb") as fin:
    output_list = pickle.load(fin)

In [3]:
len(output_list)

9018

In [4]:
output_list

[PackedProtein(batch_size=1, num_atoms=[39], num_bonds=[501], num_residues=[39]),
 PackedProtein(batch_size=1, num_atoms=[22], num_bonds=[246], num_residues=[22]),
 PackedProtein(batch_size=1, num_atoms=[29], num_bonds=[347], num_residues=[29]),
 PackedProtein(batch_size=1, num_atoms=[50], num_bonds=[713], num_residues=[50]),
 PackedProtein(batch_size=1, num_atoms=[73], num_bonds=[1106], num_residues=[73]),
 PackedProtein(batch_size=1, num_atoms=[46], num_bonds=[588], num_residues=[46]),
 PackedProtein(batch_size=1, num_atoms=[73], num_bonds=[1131], num_residues=[73]),
 PackedProtein(batch_size=1, num_atoms=[34], num_bonds=[552], num_residues=[34]),
 PackedProtein(batch_size=1, num_atoms=[27], num_bonds=[340], num_residues=[27]),
 PackedProtein(batch_size=1, num_atoms=[53], num_bonds=[834], num_residues=[53]),
 PackedProtein(batch_size=1, num_atoms=[44], num_bonds=[653], num_residues=[44]),
 PackedProtein(batch_size=1, num_atoms=[43], num_bonds=[632], num_residues=[43]),
 PackedProtein

In [5]:
# For every entry, load the per-index ESM pickle and copy its residue
# features onto the matching element of `output_list`.
for index, curProtein in enumerate(tqdm(output_list, "ESM Combining ...")):
    esm_pkl = '../../../data/dta-datasets/PDBbind/esm/gearnetesm_' + str(index) + '.pkl'
    with utils.smart_open(esm_pkl, "rb") as fin:
        embedded = pickle.load(fin)
    curProtein.residue_feature = embedded.residue_feature

ESM Combining ...: 100%|██████████| 9018/9018 [02:35<00:00, 57.88it/s] 


In [8]:
output_list[8866].residue_feature.shape

torch.Size([29, 1280])

In [9]:
len(output_list)

9018

In [10]:
# Write `output_list` to gearnetesm_pocket_Protein.pkl.
protein_pkl = '../../../data/dta-datasets/PDBbind/gearnetesm_pocket_Protein.pkl'
with utils.smart_open(protein_pkl, "wb") as fout:
    pickle.dump(output_list, fout)

#### GearNet Pocket Construction

In [None]:
import pickle
from tqdm import tqdm
from torchdrug import utils, transforms, data, models, layers

In [None]:
# Read the pickled list from gearnet_pocket_Protein.pkl into `output_list`
# (the recorded output shows one PackedProtein per entry).
protein_pkl = '../../../data/dta-datasets/PDBbind/gearnet_pocket_Protein.pkl'
with utils.smart_open(protein_pkl, "rb") as fin:
    output_list = pickle.load(fin)

In [None]:
output_list

[PackedProtein(batch_size=1, num_atoms=[39], num_bonds=[501], num_residues=[39]),
 PackedProtein(batch_size=1, num_atoms=[22], num_bonds=[246], num_residues=[22]),
 PackedProtein(batch_size=1, num_atoms=[29], num_bonds=[347], num_residues=[29]),
 PackedProtein(batch_size=1, num_atoms=[50], num_bonds=[713], num_residues=[50]),
 PackedProtein(batch_size=1, num_atoms=[73], num_bonds=[1106], num_residues=[73]),
 PackedProtein(batch_size=1, num_atoms=[46], num_bonds=[588], num_residues=[46]),
 PackedProtein(batch_size=1, num_atoms=[73], num_bonds=[1131], num_residues=[73]),
 PackedProtein(batch_size=1, num_atoms=[34], num_bonds=[552], num_residues=[34]),
 PackedProtein(batch_size=1, num_atoms=[27], num_bonds=[340], num_residues=[27]),
 PackedProtein(batch_size=1, num_atoms=[53], num_bonds=[834], num_residues=[53]),
 PackedProtein(batch_size=1, num_atoms=[44], num_bonds=[653], num_residues=[44]),
 PackedProtein(batch_size=1, num_atoms=[43], num_bonds=[632], num_residues=[43]),
 PackedProtein

In [None]:
len(output_list)

9018

In [None]:
# Build the torchdrug ESM model ("ESM-2_650M"); `path` is the local model
# directory handed to the constructor (presumably where the weights live --
# confirm against the torchdrug docs).
sequence_model = models.ESM(path="../../../result/model_pth/ESM/", model="ESM-2_650M")

In [None]:
def esmProcess(esmModel, curProtein, pklName):
    """
    Embed one packed protein with ESM and pickle the result.

    esmModel: callable invoked as esmModel(graph, input); its return value is
        indexed with 'residue_feature' and 'graph_feature'
    curProtein: a single PackedProtein (not a list); its residue_feature and
        mol_feature attributes are overwritten in place
    pklName: suffix of the output file, written as gearnetesm_<pklName>.pkl
    """
    # Imported here because this section's import cell does not import torch.
    import torch

    # Inference only: run without autograd so the stored features do not keep
    # the model's computation graph alive. The features stay referenced from
    # the caller's list, so memory would otherwise grow with every protein
    # (a likely cause of out-of-memory failures on long runs).
    with torch.no_grad():
        output = esmModel(curProtein, curProtein.residue_feature)
    curProtein.residue_feature = output['residue_feature']
    curProtein.mol_feature = output['graph_feature']
    protein_pkl = '../../../data/dta-datasets/PDBbind/esm/gearnetesm_' + pklName + '.pkl'
    with utils.smart_open(protein_pkl, "wb") as fout:
        pickle.dump(curProtein, fout)

In [None]:
# Run ESM over entries 8000..9017 of `output_list`, one pickle per index.
for index in tqdm(range(8000, 9018), "ESM Computing ......"):
    esmProcess(sequence_model, output_list[index], str(index))

ESM Computing ......: 100%|██████████| 1018/1018 [34:12<00:00,  2.02s/it]


Read all pkl files and aggregate them into a new list

In [None]:
import pickle
import torch
from tqdm import tqdm
from torchdrug import utils, transforms, data, models, layers

In [None]:
# Read the pickled list from gearnet_pocket_Protein.pkl into `output_list`
# (the recorded output shows one PackedProtein per entry).
protein_pkl = '../../../data/dta-datasets/PDBbind/gearnet_pocket_Protein.pkl'
with utils.smart_open(protein_pkl, "rb") as fin:
    output_list = pickle.load(fin)

In [None]:
len(output_list)

9018

In [None]:
output_list

[PackedProtein(batch_size=1, num_atoms=[39], num_bonds=[501], num_residues=[39]),
 PackedProtein(batch_size=1, num_atoms=[22], num_bonds=[246], num_residues=[22]),
 PackedProtein(batch_size=1, num_atoms=[29], num_bonds=[347], num_residues=[29]),
 PackedProtein(batch_size=1, num_atoms=[50], num_bonds=[713], num_residues=[50]),
 PackedProtein(batch_size=1, num_atoms=[73], num_bonds=[1106], num_residues=[73]),
 PackedProtein(batch_size=1, num_atoms=[46], num_bonds=[588], num_residues=[46]),
 PackedProtein(batch_size=1, num_atoms=[73], num_bonds=[1131], num_residues=[73]),
 PackedProtein(batch_size=1, num_atoms=[34], num_bonds=[552], num_residues=[34]),
 PackedProtein(batch_size=1, num_atoms=[27], num_bonds=[340], num_residues=[27]),
 PackedProtein(batch_size=1, num_atoms=[53], num_bonds=[834], num_residues=[53]),
 PackedProtein(batch_size=1, num_atoms=[44], num_bonds=[653], num_residues=[44]),
 PackedProtein(batch_size=1, num_atoms=[43], num_bonds=[632], num_residues=[43]),
 PackedProtein

In [None]:
# For every entry, load the per-index ESM pickle and copy its residue
# features onto the matching element of `output_list`.
for index, curProtein in enumerate(tqdm(output_list, "ESM Combining ...")):
    esm_pkl = '../../../data/dta-datasets/PDBbind/esm/gearnetesm_' + str(index) + '.pkl'
    with utils.smart_open(esm_pkl, "rb") as fin:
        embedded = pickle.load(fin)
    curProtein.residue_feature = embedded.residue_feature

ESM Combining ...: 100%|██████████| 9018/9018 [02:35<00:00, 57.88it/s] 


In [None]:
output_list[8866].residue_feature.shape

torch.Size([29, 1280])

In [None]:
len(output_list)

9018

In [None]:
# Write `output_list` to gearnetesm_pocket_Protein.pkl.
protein_pkl = '../../../data/dta-datasets/PDBbind/gearnetesm_pocket_Protein.pkl'
with utils.smart_open(protein_pkl, "wb") as fout:
    pickle.dump(output_list, fout)