# ESM for Protein Residue Embedding

#### GearNet Protein Construction

In [13]:
import pickle
from tqdm import tqdm
from torchdrug import utils, transforms, data, models, layers

In [14]:
protein_pkl = '../../../data/dta-datasets/Davis/pdb_Protein.pkl'
with utils.smart_open(protein_pkl, "rb") as fin:
    protein_list = pickle.load(fin)

In [15]:
protein_list

[Protein(num_atom=7306, num_bond=14932, num_residue=961),
 Protein(num_atom=8961, num_bond=18354, num_residue=1167),
 Protein(num_atom=8961, num_bond=18354, num_residue=1167),
 Protein(num_atom=8961, num_bond=18354, num_residue=1167),
 Protein(num_atom=8961, num_bond=18354, num_residue=1167),
 Protein(num_atom=8961, num_bond=18354, num_residue=1167),
 Protein(num_atom=8961, num_bond=18354, num_residue=1167),
 Protein(num_atom=8961, num_bond=18354, num_residue=1167),
 Protein(num_atom=8961, num_bond=18354, num_residue=1167),
 Protein(num_atom=8961, num_bond=18354, num_residue=1167),
 Protein(num_atom=8961, num_bond=18354, num_residue=1167),
 Protein(num_atom=8961, num_bond=18354, num_residue=1167),
 Protein(num_atom=8961, num_bond=18354, num_residue=1167),
 Protein(num_atom=8961, num_bond=18354, num_residue=1167),
 Protein(num_atom=8961, num_bond=18354, num_residue=1167),
 Protein(num_atom=8961, num_bond=18354, num_residue=1167),
 Protein(num_atom=9023, num_bond=18462, num_residue=1182)

As the ESM model limit the input length with max 1022. The excess part will be truncated, but we can not tell which part is more import. So if the length is between [1023, 2044] two embedding will be added, [2044, 3066] will get three and so on.

In [16]:
graph_construction_model = layers.GraphConstruction(node_layers=[layers.geometry.AlphaCarbonNode()], 
                                                    edge_layers=[layers.geometry.SpatialEdge(radius=10.0, min_distance=5),
                                                                 layers.geometry.KNNEdge(k=10, min_distance=5),
                                                                 layers.geometry.SequentialEdge(max_distance=2)],
                                                    edge_feature="gearnet")

For loop for all 442 protein with different length.

In [6]:
output_list = []
for protein in protein_list:
    curLength = protein.num_residue.tolist()
    if curLength <= 1022:
        graph = data.Protein.pack([protein])
        transGraph = graph_construction_model(graph)
        output_list.append(transGraph)
    elif curLength <= 2044:
        mask1 = list(range(1022))
        mask2 = list(range(1022, curLength))
        part1 = protein.subresidue(mask1)
        part2 = protein.subresidue(mask2)
        graph = data.Protein.pack([part1, part2])
        transGraph = graph_construction_model(graph)
        output_list.append(transGraph)
    elif curLength <= 3066:
        mask1 = list(range(1022))
        mask2 = list(range(1022, 2044))
        mask3 = list(range(2044, curLength))
        part1 = protein.subresidue(mask1)
        part2 = protein.subresidue(mask2)
        part3 = protein.subresidue(mask3)
        graph = data.Protein.pack([part1, part2, part3])
        transGraph = graph_construction_model(graph)
        output_list.append(transGraph)
    else :
        raise ValueError("Error with to long residue")

In [7]:
output_list

[PackedProtein(batch_size=1, num_atoms=[961], num_bonds=[11102], num_residues=[961]),
 PackedProtein(batch_size=2, num_atoms=[1022, 145], num_bonds=[13112, 1909], num_residues=[1022, 145]),
 PackedProtein(batch_size=2, num_atoms=[1022, 145], num_bonds=[13112, 1909], num_residues=[1022, 145]),
 PackedProtein(batch_size=2, num_atoms=[1022, 145], num_bonds=[13112, 1909], num_residues=[1022, 145]),
 PackedProtein(batch_size=2, num_atoms=[1022, 145], num_bonds=[13112, 1909], num_residues=[1022, 145]),
 PackedProtein(batch_size=2, num_atoms=[1022, 145], num_bonds=[13112, 1909], num_residues=[1022, 145]),
 PackedProtein(batch_size=2, num_atoms=[1022, 145], num_bonds=[13112, 1909], num_residues=[1022, 145]),
 PackedProtein(batch_size=2, num_atoms=[1022, 145], num_bonds=[13112, 1909], num_residues=[1022, 145]),
 PackedProtein(batch_size=2, num_atoms=[1022, 145], num_bonds=[13112, 1909], num_residues=[1022, 145]),
 PackedProtein(batch_size=2, num_atoms=[1022, 145], num_bonds=[13112, 1909], num_r

In [8]:
len(output_list)

442

In [9]:
protein_view = transforms.ProteinView(view="residue", keys="graph1")

In [10]:
transform_list = []
for protein in output_list:
    item = {"graph1": protein}
    transform_list.append(protein_view(item)['graph1'])

In [11]:
transform_list

[PackedProtein(batch_size=1, num_atoms=[961], num_bonds=[11102], num_residues=[961]),
 PackedProtein(batch_size=2, num_atoms=[1022, 145], num_bonds=[13112, 1909], num_residues=[1022, 145]),
 PackedProtein(batch_size=2, num_atoms=[1022, 145], num_bonds=[13112, 1909], num_residues=[1022, 145]),
 PackedProtein(batch_size=2, num_atoms=[1022, 145], num_bonds=[13112, 1909], num_residues=[1022, 145]),
 PackedProtein(batch_size=2, num_atoms=[1022, 145], num_bonds=[13112, 1909], num_residues=[1022, 145]),
 PackedProtein(batch_size=2, num_atoms=[1022, 145], num_bonds=[13112, 1909], num_residues=[1022, 145]),
 PackedProtein(batch_size=2, num_atoms=[1022, 145], num_bonds=[13112, 1909], num_residues=[1022, 145]),
 PackedProtein(batch_size=2, num_atoms=[1022, 145], num_bonds=[13112, 1909], num_residues=[1022, 145]),
 PackedProtein(batch_size=2, num_atoms=[1022, 145], num_bonds=[13112, 1909], num_residues=[1022, 145]),
 PackedProtein(batch_size=2, num_atoms=[1022, 145], num_bonds=[13112, 1909], num_r

In [12]:
sequence_model = models.ESM(path="../../../result/model_pth/ESM/", model="ESM-2_650M")

In [17]:
def esmProcess(esmModel, curProtein, pklName):
    """
    curProtein: list of packedProtein 
    pklName: the pkl name to store the protein
    """
    output = esmModel(curProtein, curProtein.residue_feature)
    curProtein.residue_feature = output['residue_feature']
    curProtein.mol_feature = output['graph_feature']
    # print(curProtein.residue_feature.shape)
    protein_pkl = '../../../data/dta-datasets/Davis/esm/gearnetesm_' + pklName + '.pkl'
    with utils.smart_open(protein_pkl, "wb") as fout:
        pickle.dump(curProtein, fout)

In [18]:
indexes = range(5)   # len(transform_list)
indexes = tqdm(indexes, "ESM Computing ......")
for index in indexes:
    curProtein = transform_list[index]
    if curProtein.batch_size == 1:
        esmProcess(sequence_model, curProtein, str(index))
    else:
        for i in range(curProtein.batch_size):
            protein = data.Protein.pack([curProtein[i]])
            esmProcess(sequence_model, protein, str(index) + '-' + str(i))

ESM Computing ......: 100%|██████████| 5/5 [00:45<00:00,  9.19s/it]


Read all pkl and aggreate into the new list

In [20]:
import pickle
import torch
from tqdm import tqdm
from torchdrug import utils, transforms, data, models, layers

In [21]:
protein_pkl = '../../../data/dta-datasets/Davis/pdb_Protein.pkl'
with utils.smart_open(protein_pkl, "rb") as fin:
    protein_list = pickle.load(fin)

In [22]:
protein_list

[Protein(num_atom=7306, num_bond=14932, num_residue=961),
 Protein(num_atom=8961, num_bond=18354, num_residue=1167),
 Protein(num_atom=8961, num_bond=18354, num_residue=1167),
 Protein(num_atom=8961, num_bond=18354, num_residue=1167),
 Protein(num_atom=8961, num_bond=18354, num_residue=1167),
 Protein(num_atom=8961, num_bond=18354, num_residue=1167),
 Protein(num_atom=8961, num_bond=18354, num_residue=1167),
 Protein(num_atom=8961, num_bond=18354, num_residue=1167),
 Protein(num_atom=8961, num_bond=18354, num_residue=1167),
 Protein(num_atom=8961, num_bond=18354, num_residue=1167),
 Protein(num_atom=8961, num_bond=18354, num_residue=1167),
 Protein(num_atom=8961, num_bond=18354, num_residue=1167),
 Protein(num_atom=8961, num_bond=18354, num_residue=1167),
 Protein(num_atom=8961, num_bond=18354, num_residue=1167),
 Protein(num_atom=8961, num_bond=18354, num_residue=1167),
 Protein(num_atom=8961, num_bond=18354, num_residue=1167),
 Protein(num_atom=9023, num_bond=18462, num_residue=1182)

In [23]:
graph_construction_model = layers.GraphConstruction(node_layers=[layers.geometry.AlphaCarbonNode()], 
                                                    edge_layers=[layers.geometry.SpatialEdge(radius=10.0, min_distance=5),
                                                                 layers.geometry.KNNEdge(k=10, min_distance=5),
                                                                 layers.geometry.SequentialEdge(max_distance=2)],
                                                    edge_feature="gearnet")

In [24]:
output_list = []
for protein in protein_list:
    graph = data.Protein.pack([protein])
    transGraph = graph_construction_model(graph)
    output_list.append(transGraph)

In [26]:
len(output_list)

442

In [25]:
output_list

[PackedProtein(batch_size=1, num_atoms=[961], num_bonds=[11102], num_residues=[961]),
 PackedProtein(batch_size=1, num_atoms=[1167], num_bonds=[15017], num_residues=[1167]),
 PackedProtein(batch_size=1, num_atoms=[1167], num_bonds=[15017], num_residues=[1167]),
 PackedProtein(batch_size=1, num_atoms=[1167], num_bonds=[15017], num_residues=[1167]),
 PackedProtein(batch_size=1, num_atoms=[1167], num_bonds=[15017], num_residues=[1167]),
 PackedProtein(batch_size=1, num_atoms=[1167], num_bonds=[15017], num_residues=[1167]),
 PackedProtein(batch_size=1, num_atoms=[1167], num_bonds=[15017], num_residues=[1167]),
 PackedProtein(batch_size=1, num_atoms=[1167], num_bonds=[15017], num_residues=[1167]),
 PackedProtein(batch_size=1, num_atoms=[1167], num_bonds=[15017], num_residues=[1167]),
 PackedProtein(batch_size=1, num_atoms=[1167], num_bonds=[15017], num_residues=[1167]),
 PackedProtein(batch_size=1, num_atoms=[1167], num_bonds=[15017], num_residues=[1167]),
 PackedProtein(batch_size=1, num_a

In [27]:
protein_view = transforms.ProteinView(view="residue", keys="graph1")

In [28]:
transform_list = []
for protein in output_list:
    item = {"graph1": protein}
    transform_list.append(protein_view(item)['graph1'])

In [29]:
transform_list

[PackedProtein(batch_size=1, num_atoms=[961], num_bonds=[11102], num_residues=[961]),
 PackedProtein(batch_size=1, num_atoms=[1167], num_bonds=[15017], num_residues=[1167]),
 PackedProtein(batch_size=1, num_atoms=[1167], num_bonds=[15017], num_residues=[1167]),
 PackedProtein(batch_size=1, num_atoms=[1167], num_bonds=[15017], num_residues=[1167]),
 PackedProtein(batch_size=1, num_atoms=[1167], num_bonds=[15017], num_residues=[1167]),
 PackedProtein(batch_size=1, num_atoms=[1167], num_bonds=[15017], num_residues=[1167]),
 PackedProtein(batch_size=1, num_atoms=[1167], num_bonds=[15017], num_residues=[1167]),
 PackedProtein(batch_size=1, num_atoms=[1167], num_bonds=[15017], num_residues=[1167]),
 PackedProtein(batch_size=1, num_atoms=[1167], num_bonds=[15017], num_residues=[1167]),
 PackedProtein(batch_size=1, num_atoms=[1167], num_bonds=[15017], num_residues=[1167]),
 PackedProtein(batch_size=1, num_atoms=[1167], num_bonds=[15017], num_residues=[1167]),
 PackedProtein(batch_size=1, num_a

Add the 1022 + other residue feature into one protein.

In [30]:
transform_list[1].edge_list[-10:].tolist()

[[1155, 1157, 6],
 [1156, 1158, 6],
 [1157, 1159, 6],
 [1158, 1160, 6],
 [1159, 1161, 6],
 [1160, 1162, 6],
 [1161, 1163, 6],
 [1162, 1164, 6],
 [1163, 1165, 6],
 [1164, 1166, 6]]

In [31]:
import torch

In [32]:
indexes = range(5) 
indexes = tqdm(indexes, "ESM Combining ...")
for index in indexes: 
    # read each esm protein from pkl file
    curProtein = transform_list[index]
    curLength = curProtein.num_residue.tolist()
    if curLength <= 1022:
        protein_pkl = '../../../data/dta-datasets/Davis/esm/gearnetesm_'+str(index)+'.pkl'
        with utils.smart_open(protein_pkl, "rb") as fin:
            output_protein = pickle.load(fin)
        transform_list[index].residue_feature = output_protein.residue_feature
    elif curLength <= 2044:
        for i in range(2):
            protein_pkl = '../../../data/dta-datasets/Davis/esm/gearnetesm_'+str(index)+'-'+str(i)+'.pkl'
            with utils.smart_open(protein_pkl, "rb") as fin:
                output_protein = pickle.load(fin)
            if i == 0: # the 1st part [0, 1022]
                residue_feature = output_protein.residue_feature
            else: # the 2nd part [1022, 2044]
                residue_feature = torch.cat((residue_feature, output_protein.residue_feature), dim=0)
        transform_list[index].residue_feature = residue_feature
    elif curLength <= 3066:
        for i in range(3):
            protein_pkl = '../../../data/dta-datasets/Davis/esm/gearnetesm_'+str(index)+'-'+str(i)+'.pkl'
            with utils.smart_open(protein_pkl, "rb") as fin:
                output_protein = pickle.load(fin)
            if i == 0: # the 1st part [0, 1022]
                residue_feature = output_protein.residue_feature
                # print(residue_feature.shape)
            elif i == 1: # the 2nd part [1022, 2044]
                residue_feature = torch.cat((residue_feature, output_protein.residue_feature), dim=0)
                # print(residue_feature.shape)
            else: # the 3rd part [2044, length]
                residue_feature = torch.cat((residue_feature, output_protein.residue_feature), dim=0)
        print(residue_feature.shape)
        transform_list[index].residue_feature = residue_feature

ESM Combining ...: 100%|██████████| 5/5 [00:00<00:00, 43.86it/s]


In [33]:
transform_list[1].residue_feature.shape

torch.Size([1167, 1280])

In [34]:
transform_list[1].mol_feature.shape

AttributeError: 'PackedProtein' object has no attribute 'mol_feature'

In [44]:
len(transform_list)

442

In [45]:
protein_pkl = '../../../data/dta-datasets/Davis/gearnetesm650m_Protein.pkl'
with utils.smart_open(protein_pkl, "wb") as fout:
    pickle.dump(transform_list, fout)