In [1]:
import torchdrug as td
import pandas as pd
from torchdrug import data, utils
import os

df = pd.read_csv("uniprot_to_pdb_id.csv")
uniprots = list(df["From"])
pdb = list(df["To"])

uniprot_to_pdb_map = {}
for u, p in zip(uniprots, pdb):
    uniprot_to_pdb_map[u] = p
    

uniprot = "Q96BR1"
pdb_id = uniprot_to_pdb_map[uniprot]
url = "https://files.rcsb.org/download/"
url += pdb_id + ".pdb"
pdb_file = utils.download(url, "./")
protein = data.Protein.from_pdb(pdb_file, atom_feature="position", bond_feature="length", residue_feature="symbol")
print(protein)
print(protein.residue_feature.shape)
print(protein.atom_feature.shape)
print(protein.bond_feature.shape)


18:37:31   Downloading https://files.rcsb.org/download/6EDX.pdb to ./6EDX.pdb
Protein(num_atom=979, num_bond=1968, num_residue=136)
torch.Size([136, 21])
torch.Size([979, 3])
torch.Size([1968, 1])




In [2]:
for residue_id, chain_id in zip(protein.residue_type.tolist()[:10], protein.chain_id.tolist()[:10]):
    print(data.Protein.id2residue[residue_id], chain_id)
    
for atom, position in zip(protein.atom_name.tolist()[:10], protein.node_position.tolist()[:10]):
    print(data.Protein.id2atom_name[atom], position)

SER 1
CYS 1
PRO 1
SER 1
VAL 1
SER 1
ILE 1
PRO 1
SER 1
SER 1
N [-9.36400032043457, 24.341999053955078, -28.27400016784668]
CA [-8.71500015258789, 23.16699981689453, -28.834999084472656]
C [-9.642999649047852, 21.950000762939453, -28.81800079345703]
O [-10.770999908447266, 22.020000457763672, -28.31599998474121]
CB [-7.426000118255615, 22.858999252319336, -28.06599998474121]
OG [-6.497000217437744, 23.920000076293945, -28.19700050354004]
N [-9.166000366210938, 20.847000122070312, -29.395000457763672]
CA [-9.835000038146973, 19.55699920654297, -29.277999877929688]
C [-9.65999984741211, 19.024999618530273, -27.85300064086914]
O [-8.53499984741211, 18.756000518798828, -27.42099952697754]


In [3]:
# Construct protein from sequence
import time

aa_seq = "MGKKHKKHKSDKHLYEEYVEKPLKLVLKVGGNEVTELSTGSSGHDSSLFEDKNDHDKHKDRKRKKRKKGEKQIPGEEKGRKRRRVKEDKKKRDRDRVENEAEKDLQCHAPVRLDLPPEKPLTSSLAKQEEVEQTPLQEALNQLMRQLQRKDPSAFFSFPVTDFMSSRAIGAPQEGADMKNKAARLGTYTNWPVQFLEPSRMAASGFYYLGRGDEVRCAFCKVEITNWVRGDDPETDHKRWAPQCPFVRNNAHDTPHDRAPPARSAAAHPQYATEAARLRTFAEWPRGLKQRPEELAEAGFFYTGQGDKTRCFCCDGGLKDWEPDDAPWQQHARWYDRCEYVLLVKGRDFVQRVMTEACVVRDADNEPHIERPAVEAEVADDRLCKICLGAEKTVCFVPCGHVVACGKCAAGVTTCPVCRGQLDKAVRMYQVGYSMIIKHPMDFSTMKEKIKNNDYQSIEELKDNFKLMCTNAMIYNKPETIYYKAAKKLLHSGMKILSQERIQSLKQSIDFMADLQKTRKQKDGTDTSQSGEDGGCWQREREDSGDAEAHAFKSPSKENKKKDKDMLEDKFKSNNLEREQEQLDRIVKESGGKLTRRLVNSQCEFERRKPDGTTTLGLLHPVDPIVGEPGYCPVRLGMTTGRLQSGVNTLQGFKEDKRNKVTPVLYLNYGPYSSYAPHYDSTFANISKDDSDLIYSTYGEDSDLPSDFSIHEFLATCQDYPYVMADSLLDVLTKGGHSRTLQEMEMSLPEDEGHTRTLDTAKEMEITEVEPPGRLDSSTQDRLIALKAVTNFGVPVEVFDSEEAEIFQKKLDETTRLLRELQEAQNERLSTRPPPNMICLLGPSYREMHLAEQVTNNLKELAQQVTPGDIVSTYGVRKAMGISIPSPVMENNFVDLTEDTEEPKKTDVAECGPGGS"
print(aa_seq)

# start_time = time.time()
# seq_protein = data.Protein.from_sequence(aa_seq, atom_feature="symbol", bond_feature="length", residue_feature="symbol")
# end_time = time.time()
# print("Duration of construction: ", end_time - start_time)
# print(seq_protein)

start_time = time.time()
seq_protein = data.Protein.from_sequence(aa_seq, atom_feature=None, bond_feature=None, residue_feature="default")
end_time = time.time()
print("Duration of construction:, ", end_time - start_time)
print(seq_protein)
                                

MGKKHKKHKSDKHLYEEYVEKPLKLVLKVGGNEVTELSTGSSGHDSSLFEDKNDHDKHKDRKRKKRKKGEKQIPGEEKGRKRRRVKEDKKKRDRDRVENEAEKDLQCHAPVRLDLPPEKPLTSSLAKQEEVEQTPLQEALNQLMRQLQRKDPSAFFSFPVTDFMSSRAIGAPQEGADMKNKAARLGTYTNWPVQFLEPSRMAASGFYYLGRGDEVRCAFCKVEITNWVRGDDPETDHKRWAPQCPFVRNNAHDTPHDRAPPARSAAAHPQYATEAARLRTFAEWPRGLKQRPEELAEAGFFYTGQGDKTRCFCCDGGLKDWEPDDAPWQQHARWYDRCEYVLLVKGRDFVQRVMTEACVVRDADNEPHIERPAVEAEVADDRLCKICLGAEKTVCFVPCGHVVACGKCAAGVTTCPVCRGQLDKAVRMYQVGYSMIIKHPMDFSTMKEKIKNNDYQSIEELKDNFKLMCTNAMIYNKPETIYYKAAKKLLHSGMKILSQERIQSLKQSIDFMADLQKTRKQKDGTDTSQSGEDGGCWQREREDSGDAEAHAFKSPSKENKKKDKDMLEDKFKSNNLEREQEQLDRIVKESGGKLTRRLVNSQCEFERRKPDGTTTLGLLHPVDPIVGEPGYCPVRLGMTTGRLQSGVNTLQGFKEDKRNKVTPVLYLNYGPYSSYAPHYDSTFANISKDDSDLIYSTYGEDSDLPSDFSIHEFLATCQDYPYVMADSLLDVLTKGGHSRTLQEMEMSLPEDEGHTRTLDTAKEMEITEVEPPGRLDSSTQDRLIALKAVTNFGVPVEVFDSEEAEIFQKKLDETTRLLRELQEAQNERLSTRPPPNMICLLGPSYREMHLAEQVTNNLKELAQQVTPGDIVSTYGVRKAMGISIPSPVMENNFVDLTEDTEEPKKTDVAECGPGGS
Duration of construction:,  0.007983684539794922
Protein(num_atom=0, num_bond=0, nu

In [4]:
from tqdm import tqdm as tqdm

#batch proteins
protac_df = pd.read_csv("filtered_protacs.csv")
proteins = list(protac_df["Target Protein"])

my_proteins = []
for p in tqdm(proteins):
    seq_protein = data.Protein.from_sequence(aa_seq, atom_feature=None, bond_feature=None, residue_feature="default")
    my_proteins.append(seq_protein)
    
proteins = data.Protein.pack(my_proteins)

print(proteins)
prot_list = [0, 10, 19, 22, 487, 1821, 1453, 267]
sub_proteins = proteins[prot_list]
print(sub_proteins)

100%|█████████████████████████████████████████████████████████████████████████████| 5388/5388 [00:36<00:00, 148.65it/s]


PackedProtein(batch_size=5388, num_atoms=[0, 0, 0, ..., 0, 0, 0], num_bonds=[0, 0, 0, ..., 0, 0, 0], num_residues=[916, 916, 916, ..., 916, 916, 916])
PackedProtein(batch_size=8, num_atoms=[0, 0, 0, 0, 0, 0, 0, 0], num_bonds=[0, 0, 0, 0, 0, 0, 0, 0], num_residues=[916, 916, 916, 916, 916, 916, 916, 916])


In [5]:
# atom to residue and residue to atom
for atom_id, (atom, residue_id) in enumerate(zip(protein.atom_name.tolist()[:20], protein.atom2residue.tolist()[:20])):
    print("[atom ", atom_id, "] ", data.Protein.id2atom_name[atom], ": ", data.Protein.id2residue[residue_id])
    
for residue_id in [0, 1]:
    atom_ids = protein.residue2atom(residue_id).sort()[0]
    for atom, position in zip(protein.atom_name[atom_ids].tolist(), protein.node_position[atom_ids].tolist()):
         print("[residue ", residue_id, "] ", data.Protein.id2atom_name[atom], ": ", position)
    

[atom  0 ]  N :  GLY
[atom  1 ]  CA :  GLY
[atom  2 ]  C :  GLY
[atom  3 ]  O :  GLY
[atom  4 ]  CB :  GLY
[atom  5 ]  OG :  GLY
[atom  6 ]  N :  ALA
[atom  7 ]  CA :  ALA
[atom  8 ]  C :  ALA
[atom  9 ]  O :  ALA
[atom  10 ]  CB :  ALA
[atom  11 ]  SG :  ALA
[atom  12 ]  N :  SER
[atom  13 ]  CA :  SER
[atom  14 ]  C :  SER
[atom  15 ]  O :  SER
[atom  16 ]  CB :  SER
[atom  17 ]  CG :  SER
[atom  18 ]  CD :  SER
[atom  19 ]  N :  PRO
[residue  0 ]  N :  [-9.36400032043457, 24.341999053955078, -28.27400016784668]
[residue  0 ]  CA :  [-8.71500015258789, 23.16699981689453, -28.834999084472656]
[residue  0 ]  C :  [-9.642999649047852, 21.950000762939453, -28.81800079345703]
[residue  0 ]  O :  [-10.770999908447266, 22.020000457763672, -28.31599998474121]
[residue  0 ]  CB :  [-7.426000118255615, 22.858999252319336, -28.06599998474121]
[residue  0 ]  OG :  [-6.497000217437744, 23.920000076293945, -28.19700050354004]
[residue  1 ]  N :  [-9.166000366210938, 20.847000122070312, -29.3950004

In [6]:
#subprotein and masking
first_two = protein[:2]

is_first_two_ = (protein.atom2residue == 0) | (protein.atom2residue == 1)
first_two_ = protein.node_mask(is_first_two_, compact=True)
assert first_two == first_two_

In [7]:
# atom and residue views
protein.view = "atom"
print(protein.node_feature.shape)
protein.view = "residue"
print(protein.node_feature.shape)

torch.Size([979, 3])
torch.Size([136, 21])


In [8]:
import torch

# customised residue and atom attributes
from torch_scatter import scatter_add

next_residue_type = torch.cat([protein.residue_type[1:], torch.full((1,), -1, dtype=protein.residue_type.dtype)])
followed_by_GLY = next_residue_type == data.Protein.residue2id["GLY"]
with protein.residue():
    protein.followed_by_GLY = followed_by_GLY
    

atom_in, atom_out = protein.edge_list.t()[:2]
attached_to_N = scatter_add(protein.atom_type[atom_in] == td.NITROGEN, atom_out, dim_size=protein.num_node)
with protein.atom():
    protein.attached_to_N = attached_to_N
    

In [9]:
# link residue/atom to another residue/atom
from torch_scatter import scatter_max

range_ = torch.arange(protein.num_node)
calpha = torch.where(protein.atom_name == protein.atom_name2id["CA"], range_, -1)
residue2calpha = scatter_max(calpha, protein.atom2residue, dim_size=protein.num_residue)[0]
with protein.residue(), protein.atom_reference():
    protein.residue2calpha = residue2calpha
    
sub_protein = protein[3:10]
for calpha_index in sub_protein.residue2calpha.tolist():
    atom_name = data.Protein.id2atom_name[sub_protein.atom_name[calpha_index].item()]
    print("New index ", calpha_index, ": ", atom_name)

New index  1 :  CA
New index  7 :  CA
New index  14 :  CA
New index  20 :  CA
New index  28 :  CA
New index  35 :  CA
New index  41 :  CA


In [10]:
# two layer 1D CNN for protein sequence representation
from torchdrug import models
from torchdrug import transforms

model = models.ProteinCNN(input_dim=21,
                          hidden_dims=[1024, 1024],
                          kernel_size=5, padding=2, readout="max")

# protein transformations
truncate_transform = transforms.TruncateProtein(max_length=200, random=False)
protein_view_transform = transforms.ProteinView(view="residue")
transform = transforms.Compose([truncate_transform, protein_view_transform])

In [11]:
from torchdrug import datasets

dataset = datasets.BetaLactamase("~/protein-datasets/", atom_feature=None, bond_feature=None, residue_feature="default", transform=transform)
train_set, valid_set, test_set = dataset.split()
print("The label of first sample:", dataset[0][dataset.target_fields[0]])
print("train samples: ", len(train_set), "valid samples: ", len(valid_set), "test samples: ", len(test_set))


18:38:09   Extracting C:\Users\antre/protein-datasets/beta_lactamase.tar.gz to C:\Users\antre/protein-datasets


Constructing proteins from sequences: 100%|███████████████████████████████████████| 5198/5198 [00:16<00:00, 319.46it/s]

The label of first sample: 0.9426838159561157
train samples:  4158 valid samples:  520 test samples:  520





In [12]:
from torchdrug import tasks

# append MLP prediction task on top of CNN
task = tasks.PropertyPrediction(model, task=dataset.tasks,
                                criterion="mse", metric=("mae", "rmse", "spermanr"),
                                normalization=False, num_mlp_layer=2)

In [14]:
# # train model and evaluate on validation set
# from torchdrug import core

# optimizer = torch.optim.Adam(task.parameters(), lr=1e-4)
# solver = core.Engine(task, train_set, valid_set, test_set, optimizer, 
#                      gpus=[0], batch_size=64)
# solver.train(num_epoch=10)
# solver.evaluate("valid")

In [19]:
dataset = datasets.SecondaryStructure("~/protein-datasets/", atom_feature=None, bond_feature=None, residue_feature="default", transform=protein_view_transform)
train_set, valid_set, test_set = dataset.split(["train", "valid", "cb513"])
print("SS3 label: ", dataset[0]["graph"].target[:10])
print("Valid mask: ", dataset[0]["graph"].mask[:10])
print("train samples: ", len(train_set), ", valid samples: ", len(valid_set), ", test samples: ", len(test_set))



18:23:10   Downloading http://s3.amazonaws.com/songlabdata/proteindata/data_pytorch/secondary_structure.tar.gz to C:\Users\antre/protein-datasets/secondary_structure.tar.gz
18:24:15   Extracting C:\Users\antre/protein-datasets/secondary_structure.tar.gz to C:\Users\antre/protein-datasets


Constructing proteins from sequences: 100%|█████████████████████████████████████| 11497/11497 [00:26<00:00, 437.39it/s]

SS3 label:  tensor([2, 2, 2, 0, 0, 0, 0, 0, 2, 2])
Valid mask:  tensor([True, True, True, True, True, True, True, True, True, True])
train samples:  8678 , valid samples:  2170 , test samples:  513





In [20]:
# append 1D CNN with task-specific MLP
task = tasks.NodePropertyPrediction(model, criterion="ce",
                                    metric=("micro_acc", "macro_acc"),
                                    num_mlp_layer=2, num_class=3)

In [15]:
# optimizer = torch.optim.Adam(task.parameters(), lr=1e-4)
# solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
#                      gpus=[0], batch_size=128)
# solver.train(num_epoch=5)
# solver.evaluate("valid")

In [18]:
truncate_transform = transforms.TruncateProtein(max_length=350, random=False)
protein_view_transform = transforms.ProteinView(view="residue")
transform = transforms.Compose([truncate_transform, protein_view_transform])

class EnzymeCommissionToy(datasets.EnzymeCommission):
    url = "https://miladeepgraphlearningproteindata.s3.us-east-2.amazonaws.com/data/EnzymeCommission.tar.gz"
    md5 = "728e0625d1eb513fa9b7626e4d3bcf4d"
    processed_file = "enzyme_commission_toy.pkl.gz"
    test_cutoffs = [0.3, 0.4, 0.5, 0.7, 0.95]
    

In [22]:
import warnings
warnings.filterwarnings("ignore")

start_time = time.time()
dataset = EnzymeCommissionToy("~/protein-datasets/", transform=transform, atom_feature=None,
                              bond_feature=None)
end_time = time.time()
print("Duration of first instantiation: ", end_time - start_time)

start_time = time.time()
dataset = EnzymeCommissionToy("~/protein-datasets/", transform=transform, atom_feature=None, 
                            bond_feature=None)
end_time = time.time()
print("Duration of second instantiation: ", end_time - start_time)

train_set, valid_set, test_set = dataset.split()
print("train samples: ", len(train_set), ", valid samples: ", len(valid_set), "test samples: ", len(test_set))



19:13:12   Extracting C:\Users\antre/protein-datasets/EnzymeCommission.tar.gz to C:\Users\antre/protein-datasets


Loading C:\Users\antre/protein-datasets\EnzymeCommission\enzyme_commission_toy.pkl.gz: 100%|█| 1151/1151 [00:04<00:00, 


Duration of first instantiation:  5.375838041305542
19:13:17   Extracting C:\Users\antre/protein-datasets/EnzymeCommission.tar.gz to C:\Users\antre/protein-datasets


Loading C:\Users\antre/protein-datasets\EnzymeCommission\enzyme_commission_toy.pkl.gz: 100%|█| 1151/1151 [00:05<00:00, 


Duration of second instantiation:  5.651062488555908
train samples:  959 , valid samples:  97 test samples:  95


In [24]:
protein = dataset[0]["graph"]

ImportError: DLL load failed while importing torch_ext: The specified module could not be found.