# ESM for Protein Residue Embedding

#### GearNet Protein Construction without ESM

In [1]:
import pickle
from tqdm import tqdm
from torchdrug import utils, transforms, data, models, layers

No CUDA runtime is found, using CUDA_HOME='C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7'


In [2]:
protein_pkl = '../../../data/dta-datasets/KIBA/pocket_Protein.pkl'
with utils.smart_open(protein_pkl, "rb") as fin:
    protein_list = pickle.load(fin)

In [3]:
protein_list

[Protein(num_atom=586, num_bond=1118, num_residue=73),
 Protein(num_atom=435, num_bond=836, num_residue=55),
 Protein(num_atom=2491, num_bond=4812, num_residue=302),
 Protein(num_atom=1079, num_bond=2072, num_residue=133),
 Protein(num_atom=541, num_bond=1038, num_residue=65),
 Protein(num_atom=719, num_bond=1384, num_residue=87),
 Protein(num_atom=917, num_bond=1754, num_residue=113),
 Protein(num_atom=545, num_bond=1040, num_residue=65),
 Protein(num_atom=598, num_bond=1132, num_residue=76),
 Protein(num_atom=1042, num_bond=1996, num_residue=131),
 Protein(num_atom=596, num_bond=1130, num_residue=74),
 Protein(num_atom=933, num_bond=1792, num_residue=114),
 Protein(num_atom=545, num_bond=1036, num_residue=69),
 Protein(num_atom=67, num_bond=128, num_residue=7),
 Protein(num_atom=580, num_bond=1116, num_residue=67),
 Protein(num_atom=1057, num_bond=2054, num_residue=125),
 Protein(num_atom=2841, num_bond=5518, num_residue=340),
 Protein(num_atom=1007, num_bond=1914, num_residue=122),


In [4]:
graph_construction_model = layers.GraphConstruction(node_layers=[layers.geometry.AlphaCarbonNode()], 
                                                    edge_layers=[layers.geometry.SpatialEdge(radius=10.0, min_distance=5),
                                                                 layers.geometry.KNNEdge(k=10, min_distance=5),
                                                                 layers.geometry.SequentialEdge(max_distance=2)],
                                                    edge_feature="gearnet")

In [5]:
output_list = []
for protein in protein_list:
    graph = data.Protein.pack([protein])
    transGraph = graph_construction_model(graph)
    output_list.append(transGraph)

In [6]:
len(output_list)

229

In [7]:
protein_view = transforms.ProteinView(view="residue", keys="graph1")

In [8]:
transform_list = []
for protein in output_list:
    item = {"graph1": protein}
    transform_list.append(protein_view(item)['graph1'])

In [9]:
transform_list

[PackedProtein(batch_size=1, num_atoms=[73], num_bonds=[1020], num_residues=[73]),
 PackedProtein(batch_size=1, num_atoms=[55], num_bonds=[826], num_residues=[55]),
 PackedProtein(batch_size=1, num_atoms=[302], num_bonds=[4540], num_residues=[302]),
 PackedProtein(batch_size=1, num_atoms=[133], num_bonds=[1951], num_residues=[133]),
 PackedProtein(batch_size=1, num_atoms=[65], num_bonds=[759], num_residues=[65]),
 PackedProtein(batch_size=1, num_atoms=[87], num_bonds=[1294], num_residues=[87]),
 PackedProtein(batch_size=1, num_atoms=[113], num_bonds=[1643], num_residues=[113]),
 PackedProtein(batch_size=1, num_atoms=[65], num_bonds=[956], num_residues=[65]),
 PackedProtein(batch_size=1, num_atoms=[76], num_bonds=[1075], num_residues=[76]),
 PackedProtein(batch_size=1, num_atoms=[131], num_bonds=[1919], num_residues=[131]),
 PackedProtein(batch_size=1, num_atoms=[74], num_bonds=[1115], num_residues=[74]),
 PackedProtein(batch_size=1, num_atoms=[114], num_bonds=[1679], num_residues=[114]

In [10]:
protein_pkl = '../../../data/dta-datasets/KIBA/gearnet_pocket_Protein.pkl'
with utils.smart_open(protein_pkl, "wb") as fout:
    pickle.dump(transform_list, fout)

#### GearNet Protein Construction

In [1]:
import pickle
from tqdm import tqdm
from torchdrug import utils, transforms, data, models, layers

In [2]:
protein_pkl = '../../../data/dta-datasets/KIBA/pdb_Protein.pkl'
with utils.smart_open(protein_pkl, "rb") as fin:
    protein_list = pickle.load(fin)

In [3]:
protein_list

[Protein(num_atom=3454, num_bond=7092, num_residue=431),
 Protein(num_atom=4481, num_bond=9122, num_residue=574),
 Protein(num_atom=8396, num_bond=17184, num_residue=1044),
 Protein(num_atom=5776, num_bond=11822, num_residue=725),
 Protein(num_atom=7660, num_bond=15664, num_residue=970),
 Protein(num_atom=3829, num_bond=7834, num_residue=476),
 Protein(num_atom=6066, num_bond=12346, num_residue=756),
 Protein(num_atom=3228, num_bond=6604, num_residue=403),
 Protein(num_atom=5783, num_bond=11792, num_residue=740),
 Protein(num_atom=5932, num_bond=12108, num_residue=745),
 Protein(num_atom=2964, num_bond=6066, num_residue=365),
 Protein(num_atom=4453, num_bond=9134, num_residue=556),
 Protein(num_atom=3702, num_bond=7522, num_residue=454),
 Protein(num_atom=2136, num_bond=4386, num_residue=272),
 Protein(num_atom=4620, num_bond=9458, num_residue=588),
 Protein(num_atom=5222, num_bond=10662, num_residue=661),
 Protein(num_atom=9192, num_bond=18812, num_residue=1132),
 Protein(num_atom=112

As the ESM model limit the input length with max 1022. The excess part will be truncated, but we can not tell which part is more import. So if the length is between [1023, 2044] two embedding will be added, [2044, 3066] will get three and so on.

In [4]:
graph_construction_model = layers.GraphConstruction(node_layers=[layers.geometry.AlphaCarbonNode()], 
                                                    edge_layers=[layers.geometry.SpatialEdge(radius=10.0, min_distance=5),
                                                                 layers.geometry.KNNEdge(k=10, min_distance=5),
                                                                 layers.geometry.SequentialEdge(max_distance=2)],
                                                    edge_feature="gearnet")

For loop for all 229 protein with different length.

In [5]:
output_list = []
for protein in protein_list:
    print(protein)
    curLength = protein.num_residue.tolist()
    if curLength <= 1022:
        graph = data.Protein.pack([protein])
        transGraph = graph_construction_model(graph)
        output_list.append(transGraph)
    elif curLength <= 2044:
        mask1 = list(range(1022))
        mask2 = list(range(1022, curLength))
        part1 = protein.subresidue(mask1)
        part2 = protein.subresidue(mask2)
        graph = data.Protein.pack([part1, part2])
        transGraph = graph_construction_model(graph)
        output_list.append(transGraph)
    elif curLength <= 3066:
        mask1 = list(range(1022))
        mask2 = list(range(1022, 2044))
        mask3 = list(range(2044, curLength))
        part1 = protein.subresidue(mask1)
        part2 = protein.subresidue(mask2)
        part3 = protein.subresidue(mask3)
        graph = data.Protein.pack([part1, part2, part3])
        transGraph = graph_construction_model(graph)
        output_list.append(transGraph)
    elif curLength <= 5110:
        mask1 = list(range(1022))
        mask2 = list(range(1022, 2044))
        mask3 = list(range(2044, 3066))
        mask4 = list(range(3066, 4088))
        mask5 = list(range(4088, curLength))
        part1 = protein.subresidue(mask1)
        part2 = protein.subresidue(mask2)
        part3 = protein.subresidue(mask3)
        part4 = protein.subresidue(mask4)
        part5 = protein.subresidue(mask5)
        graph = data.Protein.pack([part1, part2, part3, part4, part5])
        transGraph = graph_construction_model(graph)
        output_list.append(transGraph)
    else :
        raise ValueError("Error with to long residue")

Protein(num_atom=3454, num_bond=7092, num_residue=431)
Protein(num_atom=4481, num_bond=9122, num_residue=574)
Protein(num_atom=8396, num_bond=17184, num_residue=1044)
Protein(num_atom=5776, num_bond=11822, num_residue=725)
Protein(num_atom=7660, num_bond=15664, num_residue=970)
Protein(num_atom=3829, num_bond=7834, num_residue=476)
Protein(num_atom=6066, num_bond=12346, num_residue=756)
Protein(num_atom=3228, num_bond=6604, num_residue=403)
Protein(num_atom=5783, num_bond=11792, num_residue=740)
Protein(num_atom=5932, num_bond=12108, num_residue=745)
Protein(num_atom=2964, num_bond=6066, num_residue=365)
Protein(num_atom=4453, num_bond=9134, num_residue=556)
Protein(num_atom=3702, num_bond=7522, num_residue=454)
Protein(num_atom=2136, num_bond=4386, num_residue=272)
Protein(num_atom=4620, num_bond=9458, num_residue=588)
Protein(num_atom=5222, num_bond=10662, num_residue=661)
Protein(num_atom=9192, num_bond=18812, num_residue=1132)
Protein(num_atom=11278, num_bond=22884, num_residue=138

In [6]:
output_list

[PackedProtein(batch_size=1, num_atoms=[431], num_bonds=[7507], num_residues=[431]),
 PackedProtein(batch_size=1, num_atoms=[574], num_bonds=[8301], num_residues=[574]),
 PackedProtein(batch_size=2, num_atoms=[1022, 22], num_bonds=[19103, 220], num_residues=[1022, 22]),
 PackedProtein(batch_size=1, num_atoms=[725], num_bonds=[11806], num_residues=[725]),
 PackedProtein(batch_size=1, num_atoms=[970], num_bonds=[13888], num_residues=[970]),
 PackedProtein(batch_size=1, num_atoms=[476], num_bonds=[8081], num_residues=[476]),
 PackedProtein(batch_size=1, num_atoms=[756], num_bonds=[12719], num_residues=[756]),
 PackedProtein(batch_size=1, num_atoms=[403], num_bonds=[6331], num_residues=[403]),
 PackedProtein(batch_size=1, num_atoms=[740], num_bonds=[11870], num_residues=[740]),
 PackedProtein(batch_size=1, num_atoms=[745], num_bonds=[12311], num_residues=[745]),
 PackedProtein(batch_size=1, num_atoms=[365], num_bonds=[6814], num_residues=[365]),
 PackedProtein(batch_size=1, num_atoms=[556]

In [7]:
len(output_list)

229

In [8]:
protein_view = transforms.ProteinView(view="residue", keys="graph1")

In [9]:
transform_list = []
for protein in output_list:
    item = {"graph1": protein}
    transform_list.append(protein_view(item)['graph1'])

In [10]:
transform_list

[PackedProtein(batch_size=1, num_atoms=[431], num_bonds=[7507], num_residues=[431]),
 PackedProtein(batch_size=1, num_atoms=[574], num_bonds=[8301], num_residues=[574]),
 PackedProtein(batch_size=2, num_atoms=[1022, 22], num_bonds=[19103, 220], num_residues=[1022, 22]),
 PackedProtein(batch_size=1, num_atoms=[725], num_bonds=[11806], num_residues=[725]),
 PackedProtein(batch_size=1, num_atoms=[970], num_bonds=[13888], num_residues=[970]),
 PackedProtein(batch_size=1, num_atoms=[476], num_bonds=[8081], num_residues=[476]),
 PackedProtein(batch_size=1, num_atoms=[756], num_bonds=[12719], num_residues=[756]),
 PackedProtein(batch_size=1, num_atoms=[403], num_bonds=[6331], num_residues=[403]),
 PackedProtein(batch_size=1, num_atoms=[740], num_bonds=[11870], num_residues=[740]),
 PackedProtein(batch_size=1, num_atoms=[745], num_bonds=[12311], num_residues=[745]),
 PackedProtein(batch_size=1, num_atoms=[365], num_bonds=[6814], num_residues=[365]),
 PackedProtein(batch_size=1, num_atoms=[556]

In [11]:
sequence_model = models.ESM(path="../../../result/model_pth/ESM/", model="ESM-2_650M")

In [12]:
def esmProcess(esmModel, curProtein, pklName):
    """
    curProtein: list of packedProtein 
    pklName: the pkl name to store the protein
    """
    output = esmModel(curProtein, curProtein.residue_feature)
    curProtein.residue_feature = output['residue_feature']
    # print(curProtein.residue_feature.shape)
    protein_pkl = '../../../data/dta-datasets/KIBA/esm/gearnetesm_' + pklName + '.pkl'
    with utils.smart_open(protein_pkl, "wb") as fout:
        pickle.dump(curProtein, fout)

- P53779(index:121)
- P78527(index:128)
- Q5S007(index:130)

In [19]:
transform_list[178]

PackedProtein(batch_size=3, num_atoms=[1022, 1022, 483], num_bonds=[16697, 20175, 9679], num_residues=[1022, 1022, 483])

In [20]:
indexes = range(178, 179)  # len(transform_list)
indexes = tqdm(indexes, "ESM Computing ......")
for index in indexes:
    curProtein = transform_list[index]
    if curProtein.batch_size == 1:
        esmProcess(sequence_model, curProtein, str(index))
    else:
        for i in range(curProtein.batch_size):
            protein = data.Protein.pack([curProtein[i]])
            esmProcess(sequence_model, protein, str(index) + '_' + str(i))

ESM Computing ......: 100%|██████████| 1/1 [00:19<00:00, 19.05s/it]


Read all pkl and aggreate into the new list

In [1]:
import pickle
from tqdm import tqdm
from torchdrug import utils, transforms, data, layers

In [2]:
protein_pkl = '../../../data/dta-datasets/KIBA/pdb_Protein.pkl'
with utils.smart_open(protein_pkl, "rb") as fin:
    protein_list = pickle.load(fin)

In [3]:
protein_list

[Protein(num_atom=3454, num_bond=7092, num_residue=431),
 Protein(num_atom=4481, num_bond=9122, num_residue=574),
 Protein(num_atom=8396, num_bond=17184, num_residue=1044),
 Protein(num_atom=5776, num_bond=11822, num_residue=725),
 Protein(num_atom=7660, num_bond=15664, num_residue=970),
 Protein(num_atom=3829, num_bond=7834, num_residue=476),
 Protein(num_atom=6066, num_bond=12346, num_residue=756),
 Protein(num_atom=3228, num_bond=6604, num_residue=403),
 Protein(num_atom=5783, num_bond=11792, num_residue=740),
 Protein(num_atom=5932, num_bond=12108, num_residue=745),
 Protein(num_atom=2964, num_bond=6066, num_residue=365),
 Protein(num_atom=4453, num_bond=9134, num_residue=556),
 Protein(num_atom=3702, num_bond=7522, num_residue=454),
 Protein(num_atom=2136, num_bond=4386, num_residue=272),
 Protein(num_atom=4620, num_bond=9458, num_residue=588),
 Protein(num_atom=5222, num_bond=10662, num_residue=661),
 Protein(num_atom=9192, num_bond=18812, num_residue=1132),
 Protein(num_atom=112

In [4]:
graph_construction_model = layers.GraphConstruction(node_layers=[layers.geometry.AlphaCarbonNode()], 
                                                    edge_layers=[layers.geometry.SpatialEdge(radius=10.0, min_distance=5),
                                                                 layers.geometry.KNNEdge(k=10, min_distance=5),
                                                                 layers.geometry.SequentialEdge(max_distance=2)],
                                                    edge_feature="gearnet")

In [5]:
output_list = []
for protein in protein_list:
    graph = data.Protein.pack([protein])
    transGraph = graph_construction_model(graph)
    output_list.append(transGraph)

In [6]:
len(output_list)

229

In [7]:
output_list

[PackedProtein(batch_size=1, num_atoms=[431], num_bonds=[7507], num_residues=[431]),
 PackedProtein(batch_size=1, num_atoms=[574], num_bonds=[8301], num_residues=[574]),
 PackedProtein(batch_size=1, num_atoms=[1044], num_bonds=[19549], num_residues=[1044]),
 PackedProtein(batch_size=1, num_atoms=[725], num_bonds=[11806], num_residues=[725]),
 PackedProtein(batch_size=1, num_atoms=[970], num_bonds=[13888], num_residues=[970]),
 PackedProtein(batch_size=1, num_atoms=[476], num_bonds=[8081], num_residues=[476]),
 PackedProtein(batch_size=1, num_atoms=[756], num_bonds=[12719], num_residues=[756]),
 PackedProtein(batch_size=1, num_atoms=[403], num_bonds=[6331], num_residues=[403]),
 PackedProtein(batch_size=1, num_atoms=[740], num_bonds=[11870], num_residues=[740]),
 PackedProtein(batch_size=1, num_atoms=[745], num_bonds=[12311], num_residues=[745]),
 PackedProtein(batch_size=1, num_atoms=[365], num_bonds=[6814], num_residues=[365]),
 PackedProtein(batch_size=1, num_atoms=[556], num_bonds=[

In [8]:
protein_view = transforms.ProteinView(view="residue", keys="graph1")

In [9]:
transform_list = []
for protein in output_list:
    item = {"graph1": protein}
    transform_list.append(protein_view(item)['graph1'])

In [10]:
transform_list

[PackedProtein(batch_size=1, num_atoms=[431], num_bonds=[7507], num_residues=[431]),
 PackedProtein(batch_size=1, num_atoms=[574], num_bonds=[8301], num_residues=[574]),
 PackedProtein(batch_size=1, num_atoms=[1044], num_bonds=[19549], num_residues=[1044]),
 PackedProtein(batch_size=1, num_atoms=[725], num_bonds=[11806], num_residues=[725]),
 PackedProtein(batch_size=1, num_atoms=[970], num_bonds=[13888], num_residues=[970]),
 PackedProtein(batch_size=1, num_atoms=[476], num_bonds=[8081], num_residues=[476]),
 PackedProtein(batch_size=1, num_atoms=[756], num_bonds=[12719], num_residues=[756]),
 PackedProtein(batch_size=1, num_atoms=[403], num_bonds=[6331], num_residues=[403]),
 PackedProtein(batch_size=1, num_atoms=[740], num_bonds=[11870], num_residues=[740]),
 PackedProtein(batch_size=1, num_atoms=[745], num_bonds=[12311], num_residues=[745]),
 PackedProtein(batch_size=1, num_atoms=[365], num_bonds=[6814], num_residues=[365]),
 PackedProtein(batch_size=1, num_atoms=[556], num_bonds=[

Add the 1022 + other residue feature into one protein.

In [11]:
transform_list[43].edge_list[-10:].tolist()

[[2335, 2337, 6],
 [2336, 2338, 6],
 [2337, 2339, 6],
 [2338, 2340, 6],
 [2339, 2341, 6],
 [2340, 2342, 6],
 [2341, 2343, 6],
 [2342, 2344, 6],
 [2343, 2345, 6],
 [2344, 2346, 6]]

In [12]:
import torch

In [13]:
indexes = range(len(transform_list)) 
indexes = tqdm(indexes, "ESM Combining ...")
for index in indexes: 
    # read each esm protein from pkl file
    curProtein = transform_list[index]
    curLength = curProtein.num_residue.tolist()
    if curLength <= 1022:
        protein_pkl = '../../../data/dta-datasets/KIBA/esm/gearnetesm_'+str(index)+'.pkl'
        with utils.smart_open(protein_pkl, "rb") as fin:
            output_protein = pickle.load(fin)
        transform_list[index].residue_feature = output_protein.residue_feature
    elif curLength <= 2044:
        for i in range(2):
            protein_pkl = '../../../data/dta-datasets/KIBA/esm/gearnetesm_'+str(index)+'_'+str(i)+'.pkl'
            with utils.smart_open(protein_pkl, "rb") as fin:
                output_protein = pickle.load(fin)
            if i == 0: # the 1st part [0, 1022]
                residue_feature = output_protein.residue_feature
            else: # the 2nd part [1022, 2044]
                residue_feature = torch.cat((residue_feature, output_protein.residue_feature), dim=0)
        transform_list[index].residue_feature = residue_feature
    elif curLength <= 3066:
        for i in range(3):
            protein_pkl = '../../../data/dta-datasets/KIBA/esm/gearnetesm_'+str(index)+'_'+str(i)+'.pkl'
            with utils.smart_open(protein_pkl, "rb") as fin:
                output_protein = pickle.load(fin)
            if i == 0: # the 1st part [0, 1022]
                residue_feature = output_protein.residue_feature
            else: 
                residue_feature = torch.cat((residue_feature, output_protein.residue_feature), dim=0)
        transform_list[index].residue_feature = residue_feature
    elif curLength <= 5110:
        for i in range(5):
            protein_pkl = '../../../data/dta-datasets/KIBA/esm/gearnetesm_'+str(index)+'_'+str(i)+'.pkl'
            with utils.smart_open(protein_pkl, "rb") as fin:
                output_protein = pickle.load(fin)
            if i == 0: # the 1st part [0, 1022]
                residue_feature = output_protein.residue_feature
            else: 
                residue_feature = torch.cat((residue_feature, output_protein.residue_feature), dim=0)
        # print(residue_feature.shape)
        transform_list[index].residue_feature = residue_feature

ESM Combining ...: 100%|██████████| 229/229 [00:07<00:00, 29.57it/s]


In [14]:
transform_list[128].residue_feature.shape

torch.Size([4128, 1280])

In [15]:
len(transform_list)

229

In [16]:
protein_pkl = '../../../data/dta-datasets/KIBA/gearnetesm650m_Protein.pkl'
with utils.smart_open(protein_pkl, "wb") as fout:
    pickle.dump(transform_list, fout)