In [5]:
from torchdrug import transforms, datasets
import torch

truncate_transform = transforms.TruncateProtein(max_length=350, random=False)
protein_view_transform = transforms.ProteinView(view="residue")
transform = transforms.Compose([truncate_transform, protein_view_transform])

dataset = datasets.AlphaFoldDB(path = '/Users/harper.h/Documents/cs224w/final_project/yeast_alphafold', transform=transform, atom_feature=None, 
                            bond_feature=None)
lengths = [int(0.8 * len(dataset)), int(0.1 * len(dataset))]
lengths += [len(dataset) - sum(lengths)]
train_set, valid_set, test_set = torch.utils.data.random_split(dataset, lengths)


Loading /Users/harper.h/Documents/cs224w/final_project/yeast_alphafold/UP000002311_559292_YEAST_v2_0.pkl.gz: 100%|██████████| 6026/6026 [00:22<00:00, 263.03it/s]


In [8]:
dataset

AlphaFoldDB(
  #sample: 6026
)

In [7]:
len(train_set)

4820

In [3]:
train_set, valid_set, test_set = dataset.split()

AttributeError: 'AlphaFoldDB' object has no attribute 'split'

In [12]:
from torchdrug import transforms

truncate_transform = transforms.TruncateProtein(max_length=350, random=False)
protein_view_transform = transforms.ProteinView(view="residue")
transform = transforms.Compose([truncate_transform, protein_view_transform])

In [13]:
from torchdrug import datasets

class EnzymeCommissionToy(datasets.EnzymeCommission):
    url = "https://miladeepgraphlearningproteindata.s3.us-east-2.amazonaws.com/data/EnzymeCommission.tar.gz"
    md5 = "728e0625d1eb513fa9b7626e4d3bcf4d"
    processed_file = "enzyme_commission_toy.pkl.gz"
    test_cutoffs = [0.3, 0.4, 0.5, 0.7, 0.95]

In [14]:
import time

start_time = time.time()
dataset = EnzymeCommissionToy("~/protein-datasets/", transform=transform, atom_feature=None, 
                            bond_feature=None)
end_time = time.time()
print("Duration of first instantiation: ", end_time - start_time)

start_time = time.time()
dataset = EnzymeCommissionToy("~/protein-datasets/", transform=transform, atom_feature=None, 
                            bond_feature=None)
end_time = time.time()
print("Duration of second instantiation: ", end_time - start_time)

train_set, valid_set, test_set = dataset.split()
print("Shape of function labels for a protein: ", dataset[0]["targets"].shape)
print("train samples: %d, valid samples: %d, test samples: %d" % (len(train_set), len(valid_set), len(test_set)))

00:37:48   Extracting /Users/harper.h/protein-datasets/EnzymeCommission.tar.gz to /Users/harper.h/protein-datasets


Loading /Users/harper.h/protein-datasets/EnzymeCommission/enzyme_commission_toy.pkl.gz: 100%|██████████| 1169/1169 [00:01<00:00, 655.63it/s]


Duration of first instantiation:  2.1469919681549072
00:37:50   Extracting /Users/harper.h/protein-datasets/EnzymeCommission.tar.gz to /Users/harper.h/protein-datasets


Loading /Users/harper.h/protein-datasets/EnzymeCommission/enzyme_commission_toy.pkl.gz: 100%|██████████| 1169/1169 [00:01<00:00, 622.62it/s]


Duration of second instantiation:  2.274559259414673
Shape of function labels for a protein:  torch.Size([538])
train samples: 974, valid samples: 97, test samples: 98


In [15]:
from torchdrug import data

protein = dataset[0]["graph"]
is_first_two = (protein.residue_number == 1) | (protein.residue_number == 2)
first_two = protein.residue_mask(is_first_two, compact=True)
first_two.visualize()

  fig.show()


In [16]:
from torchdrug import layers
from torchdrug.layers import geometry

graph_construction_model = layers.GraphConstruction(node_layers=[geometry.AlphaCarbonNode()])

_protein = data.Protein.pack([protein])
protein_ = graph_construction_model(_protein)
print("Graph before: ", _protein)
print("Graph after: ", protein_)

Graph before:  PackedProtein(batch_size=1, num_atoms=[2639], num_bonds=[5368], num_residues=[350])
Graph after:  PackedProtein(batch_size=1, num_atoms=[350], num_bonds=[0], num_residues=[350])


In [17]:
graph_construction_model = layers.GraphConstruction(node_layers=[geometry.AlphaCarbonNode()], 
                                                    edge_layers=[geometry.SpatialEdge(radius=10.0, min_distance=5),
                                                                 geometry.KNNEdge(k=10, min_distance=5),
                                                                 geometry.SequentialEdge(max_distance=2)])

_protein = data.Protein.pack([protein])
protein_ = graph_construction_model(_protein)
print("Graph before: ", _protein)
print("Graph after: ", protein_)

degree = protein_.degree_in + protein_.degree_out
print("Average degree: ", degree.mean())
print("Maximum degree: ", degree.max())
print("Minimum degree: ", degree.min())
print("Number of zero-degree nodes: ", (degree == 0).sum())

Graph before:  PackedProtein(batch_size=1, num_atoms=[2639], num_bonds=[5368], num_residues=[350])
Graph after:  PackedProtein(batch_size=1, num_atoms=[350], num_bonds=[7276], num_residues=[350])
Average degree:  tensor(41.5771)
Maximum degree:  tensor(76.)
Minimum degree:  tensor(12.)
Number of zero-degree nodes:  tensor(0)


In [18]:
from torchdrug import models

gearnet = models.GearNet(input_dim=21, hidden_dims=[512, 512, 512], num_relation=7,
                         batch_norm=True, concat_hidden=True, short_cut=True, readout="sum")
gearnet_edge = models.GearNet(input_dim=21, hidden_dims=[512, 512, 512], 
                              num_relation=7, edge_input_dim=59, num_angle_bin=8,
                              batch_norm=True, concat_hidden=True, short_cut=True, readout="sum")

In [19]:
from torchdrug import tasks

In [20]:
graph_construction_model = layers.GraphConstruction(node_layers=[geometry.AlphaCarbonNode()], 
                                                    edge_layers=[geometry.SpatialEdge(radius=10.0, min_distance=5),
                                                                 geometry.KNNEdge(k=10, min_distance=5),
                                                                 geometry.SequentialEdge(max_distance=2)],
                                                    edge_feature="gearnet")

task = tasks.MultipleBinaryClassification(gearnet, graph_construction_model=graph_construction_model, num_mlp_layer=3,
                                          task=[_ for _ in range(len(dataset.tasks))], criterion="bce", metric=["auprc@micro", "f1_max"])

In [21]:
from torchdrug import core
import torch

optimizer = torch.optim.Adam(task.parameters(), lr=1e-4)
solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
                      batch_size=4)
solver.train(num_epoch=10)


00:37:52   Preprocess training set
00:37:53   {'batch_size': 4,
 'class': 'core.Engine',
 'gpus': None,
 'gradient_interval': 1,
 'log_interval': 100,
 'logger': 'logging',
 'num_worker': 0,
 'optimizer': {'amsgrad': False,
               'betas': (0.9, 0.999),
               'capturable': False,
               'class': 'optim.Adam',
               'differentiable': False,
               'eps': 1e-08,
               'foreach': None,
               'fused': False,
               'lr': 0.0001,
               'maximize': False,
               'weight_decay': 0},
 'scheduler': None,
 'task': {'class': 'tasks.MultipleBinaryClassification',
          'criterion': 'bce',
          'graph_construction_model': {'class': 'layers.GraphConstruction',
                                       'edge_feature': 'gearnet',
                                       'edge_layers': [{'class': 'layers.geometry.SpatialEdge',
                                                        'max_distance': None,
           

ValueError: not enough values to unpack (expected 3, got 2)

In [14]:
solver.evaluate("valid")

22:04:41   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
22:04:41   Evaluate on valid
torch.Size([1209, 1209, 7]) torch.Size([1209, 21])
torch.Size([1209, 8463])
7 21
torch.Size([8463, 21])
(tensor(1209), 147)
torch.Size([1209, 1209, 7]) torch.Size([1209, 512])
torch.Size([1209, 8463])
7 512
torch.Size([8463, 512])
(tensor(1209), 3584)
torch.Size([1209, 1209, 7]) torch.Size([1209, 512])
torch.Size([1209, 8463])
7 512
torch.Size([8463, 512])
(tensor(1209), 3584)
torch.Size([973, 973, 7]) torch.Size([973, 21])
torch.Size([973, 6811])
7 21
torch.Size([6811, 21])
(tensor(973), 147)
torch.Size([973, 973, 7]) torch.Size([973, 512])
torch.Size([973, 6811])
7 512
torch.Size([6811, 512])
(tensor(973), 3584)
torch.Size([973, 973, 7]) torch.Size([973, 512])
torch.Size([973, 6811])
7 512
torch.Size([6811, 512])
(tensor(973), 3584)
torch.Size([797, 797, 7]) torch.Size([797, 21])
torch.Size([797, 5579])
7 21
torch.Size([5579, 21])
(tensor(797), 147)
torch.Size([797, 797, 7]) torch.Size([797, 512])
torch.Size([797,

{'auprc@micro': tensor(0.1886), 'f1_max': tensor(0.2697)}