In [1]:
# NBVAL_SKIP
import sys
sys.path.append('../')

import logging
logging.getLogger('matplotlib').setLevel(logging.CRITICAL)
logging.getLogger('graphein').setLevel(logging.INFO)

# PSCDB - Baselines

In [2]:
# NBVAL_SKIP
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import pytorch_lightning as pl
from tqdm.notebook import tqdm
import networkx as nx
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.utils import from_networkx
from sklearn.preprocessing import LabelBinarizer
from sklearn.metrics import f1_score

import warnings
warnings.filterwarnings("ignore")

[23:36:27] Enabling RDKit 2019.09.3 jupyter extensions


## Load dataset


In [3]:
# NBVAL_SKIP
df = pd.read_csv("../datasets/pscdb/structural_rearrangement_data.csv")
pdbs = df["Free PDB"]
y = [torch.argmax(torch.Tensor(lab)).type(torch.LongTensor) for lab in LabelBinarizer().fit_transform(df.motion_type)]

## Transformation from Raw Structure to ML-ready Datasets Construction with Graphein

In [4]:
# NBVAL_SKIP
from graphein.protein.config import ProteinGraphConfig
from graphein.protein.edges.distance import add_hydrogen_bond_interactions, add_peptide_bonds, add_k_nn_edges
from graphein.protein.graphs import construct_graph

from functools import partial

# Override config with constructors
constructors = {
    "edge_construction_functions": [partial(add_k_nn_edges, k=3, long_interaction_threshold=0)],
    #"edge_construction_functions": [add_hydrogen_bond_interactions, add_peptide_bonds],
    #"node_metadata_functions": [add_dssp_feature]
}

config = ProteinGraphConfig(**constructors)
print(config.dict())

# Make graphs
graph_list = []
y_list = []
for idx, pdb in enumerate(tqdm(pdbs)):
    try:
        graph_list.append(
            construct_graph(pdb_code=pdb,
                        config=config
                       )
            )
        y_list.append(y[idx])
    except:
        print(str(idx) + ' processing error...')
        pass

{'granularity': 'CA', 'keep_hets': False, 'insertions': False, 'pdb_dir': PosixPath('../examples/pdbs'), 'verbose': False, 'exclude_waters': True, 'deprotonate': False, 'protein_df_processing_functions': None, 'edge_construction_functions': [functools.partial(<function add_k_nn_edges at 0x7fb4b22bbf70>, k=3, long_interaction_threshold=0)], 'node_metadata_functions': [<function meiler_embedding at 0x7fb4b22c7430>], 'edge_metadata_functions': None, 'graph_metadata_functions': None, 'get_contacts_config': None, 'dssp_config': None}


  0%|          | 0/891 [00:00<?, ?it/s]

URL Error [Errno 110] Connection timed out
274 processing error...
666 processing error...
677 processing error...


In [44]:
# NBVAL_SKIP
pdbs[274]
#pdbs[266]
#pdbs[677]

'3e59'

## Convert Nx graphs to PyTorch Geometric

In [8]:
# NBVAL_SKIP
from graphein.ml.conversion import GraphFormatConvertor

format_convertor = GraphFormatConvertor('nx', 'pyg', 
                                        verbose = 'gnn', 
                                        columns = None)

Using backend: pytorch


In [9]:
# NBVAL_SKIP
pyg_list = [format_convertor(graph) for graph in tqdm(graph_list)]

  0%|          | 0/888 [00:00<?, ?it/s]

In [10]:
# NBVAL_SKIP
for idx, g in enumerate(pyg_list):
    g.y = y_list[idx] 
    g.coords = torch.FloatTensor(g.coords[0])

In [11]:
# NBVAL_SKIP
for i in pyg_list:
    if i.coords.shape[0] == len(i.node_id):
        pass
    else:
        print(i)
        pyg_list.remove(i)

Data(coords=[10112, 3], dist_mat=[1], edge_index=[2, 120], name=[1], node_id=[1264], y=1)
Data(coords=[820, 3], dist_mat=[1], edge_index=[2, 1431], name=[1], node_id=[808], y=2)
Data(coords=[668, 3], dist_mat=[1], edge_index=[2, 1166], name=[1], node_id=[666], y=4)
Data(coords=[2720, 3], dist_mat=[1], edge_index=[2, 3], name=[1], node_id=[340], y=5)


## Model Configuration

In [32]:
# NBVAL_SKIP
config_default = dict(
    n_hid = 8,
    n_out = 8,
    batch_size = 4,
    dropout = 0.5,
    lr = 0.001,
    num_heads = 32,
    num_att_dim = 64,
    model_name = 'GCN'
)

class Struct:
    def __init__(self, **entries):
        self.__dict__.update(entries)
        
config = Struct(**config_default)

global model_name
model_name = config.model_name

## Construct DataLoaders

In [33]:
# NBVAL_SKIP
import numpy as np
np.random.seed(42)
idx_all = np.arange(len(pyg_list))
np.random.shuffle(idx_all)

train_idx, valid_idx, test_idx = np.split(idx_all, [int(.8*len(pyg_list)), int(.9*len(pyg_list))])
train, valid, test = [pyg_list[i] for i in train_idx], [pyg_list[i] for i in valid_idx], [pyg_list[i] for i in test_idx]

from torch_geometric.data import DataLoader
train_loader = DataLoader(train, batch_size=config.batch_size, shuffle = True, drop_last = True)
valid_loader = DataLoader(valid, batch_size=32)
test_loader = DataLoader(test, batch_size=32)

In [34]:
# NBVAL_SKIP
pyg_list[0]

Data(coords=[635, 3], dist_mat=[1], edge_index=[2, 1118], name=[1], node_id=[635], y=1)

## Define Model

In [35]:
# NBVAL_SKIP
from torch_geometric.nn import GCNConv, GATConv, SAGEConv, global_add_pool
from torch.nn.functional import mse_loss, nll_loss, relu, softmax, cross_entropy
from torch.nn import functional as F
from pytorch_lightning.metrics.functional import accuracy

In [36]:
# NBVAL_SKIP
class GraphNets(pl.LightningModule):
    def __init__(self):
        super().__init__()
        
        if model_name == 'GCN':
            self.layer1 = GCNConv(in_channels=3, out_channels=config.n_hid)
            self.layer2 = GCNConv(in_channels=config.n_hid, out_channels=config.n_out)

        elif model_name == 'GAT':
            self.layer1 = GATConv(3, config.num_att_dim, heads=config.num_heads, dropout=config.dropout)
            self.layer2 = GATConv(config.num_att_dim * config.num_heads, out_channels = config.n_out, heads=1, concat=False,
                                 dropout=config.dropout)
            
        elif model_name == 'GraphSAGE':
            self.layer1 = SAGEConv(3, config.n_hid)
            self.layer2 = SAGEConv(config.n_hid, config.n_out)  
            
        self.decoder = nn.Linear(config.n_out, 7)
        
    def forward(self, g):
        x = g.coords
        x = F.dropout(x, p=config.dropout, training=self.training)
        x = F.elu(self.layer1(x, g.edge_index))
        x = F.dropout(x, p=config.dropout, training=self.training)
        x = self.layer2(x, g.edge_index)
        x = global_add_pool(x, batch=g.batch)
        x = self.decoder(x)
        return softmax(x)

    def training_step(self, batch, batch_idx):
        x = batch   
        y = x.y
        y_hat = self(x)
        loss = cross_entropy(y_hat, y)
        acc = accuracy(y_hat, y)

        self.log("train_loss", loss)
        self.log("train_acc", acc)
        return loss

    def validation_step(self, batch, batch_idx):
        x = batch   
        y = x.y
        y_hat = self(x)
        loss = cross_entropy(y_hat, y)
        acc = accuracy(y_hat, y)
        self.log("valid_loss", loss)
        self.log("valid_acc", acc)

    def test_step(self, batch, batch_idx):
        x = batch   
        y = x.y
        y_hat = self(x)
        loss = cross_entropy(y_hat, y)
        acc = accuracy(y_hat, y)

        y_pred_softmax = torch.log_softmax(y_hat, dim = 1)
        y_pred_tags = torch.argmax(y_pred_softmax, dim = 1) 
        f1 = f1_score(y.detach().cpu().numpy(), y_pred_tags.detach().cpu().numpy(), average = 'weighted')

        self.log("test_loss", loss)
        self.log("test_acc", acc)
        self.log("test_f1", f1)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=config.lr)
        return optimizer

In [37]:
# NBVAL_SKIP
GraphNets()

GraphNets(
  (layer1): GCNConv(3, 8)
  (layer2): GCNConv(8, 8)
  (decoder): Linear(in_features=8, out_features=7, bias=True)
)

In [40]:
# NBVAL_SKIP
from pytorch_lightning.callbacks import ModelCheckpoint
import os

file_path = './graphein_model'
if not os.path.exists(file_path):
    os.mkdir(file_path)

checkpoint_callback = ModelCheckpoint(
    monitor="valid_loss",
    dirpath=file_path,
    filename="model-{epoch:02d}-{val_loss:.2f}",
    save_top_k=1,
    mode="min",
)

## Train!

In [41]:
# NBVAL_SKIP
# Train Model
model = GraphNets()
trainer = pl.Trainer(max_epochs=200, gpus=-1, callbacks=[checkpoint_callback])
trainer.fit(model, train_loader, valid_loader)

# evaluate on the model with the best validation set
best_model = GraphNets.load_from_checkpoint(checkpoint_callback.best_model_path)
out_best_test = trainer.test(best_model, test_loader)[0]

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/home/atj39/anaconda3/envs/graphein-dev/lib/python3.8/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/home/atj39/anaconda3/envs/graphein-dev/lib/python3.8/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'GraphNets' on <module '__main__' (built-in)>


KeyboardInterrupt: 