# Protein Structural Change Prediction Example
In this notebook, we use Graphein to preprocess the PSCDB database into graphs. We then perform graph classification on the graphs of the unbound (ligand-free) proteins to predict the class of structural rearrangement each protein undergoes upon ligand binding.

In [None]:
import os

import dgl
import numpy as np
import pandas as pd
import torch
from dgllife.model.model_zoo import GCNPredictor
from graphein.construct_graphs import ProteinGraph
from matplotlib import pyplot as plt
from sklearn.metrics import f1_score, precision_score, recall_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from tqdm import tqdm

# Target device for the model and for the feature/label tensors moved with
# `.to(device)` in the training and evaluation cells.
device = 'cpu'

## Data
1.  We load the dataset. The script for parsing the data from a web server into the `structural_rearrangement_data.csv` file is available in `process_data.ipynb` and `make_rearrangement_data.py`.

In [None]:
# Load the structural rearrangement table (one row per protein).
df = pd.read_csv('structural_rearrangement_data.csv')

# Fail fast, with a readable message, if a column that later cells index is
# missing -- otherwise the error only surfaces deep inside graph construction.
required_columns = {'Free PDB', 'Free Chains', 'motion_type'}
missing_columns = required_columns - set(df.columns)
assert not missing_columns, f"CSV is missing expected columns: {sorted(missing_columns)}"

df.head()

2. We create one-hot encodings of the labels, indicating the rearrangement motion class

In [None]:
# Create labels: one-hot encode the `motion_type` column (one indicator column
# per distinct value) and wrap every row in its own torch.Tensor.
one_hot_rows = pd.get_dummies(df.motion_type).values.tolist()
labels = list(map(torch.Tensor, one_hot_rows))

3. We split the data into training and testing data, and construct the Graphs using Graphein. 

Graphein will automatically download the relevant `.PDB` files from the PDB and compute the intramolecular contacts using `GetContacts` if the files are not found in the `pdb_dir` and `contacts_dir` directories.

We select the relevant chains of each PDB structure using the `Free Chains` column of the dataframe.

In [None]:
# Split datasets. The fixed random_state makes the train/test partition -- and
# therefore every downstream metric -- reproducible across kernel restarts.
x_train, x_test, y_train, y_test = train_test_split(
    df, labels, test_size=0.15, random_state=42)

# Initialise Graph Constructor.
# The location of the GetContacts checkout is machine specific: set the
# GETCONTACTS_PATH environment variable to override the default used below.
pg = ProteinGraph(granularity='CA', insertions=False, keep_hets=True,
                  node_featuriser='meiler',
                  get_contacts_path=os.environ.get(
                      'GETCONTACTS_PATH', '/Users/arianjamasb/github/getcontacts'),
                  pdb_dir='../../examples/pdbs/',
                  contacts_dir='../../examples/contacts/',
                  exclude_waters=True, covalent_bonds=False, include_ss=True)


def build_graphs(frame):
    """Build one DGL graph per row of ``frame``.

    Each row contributes its 'Free PDB' value as the PDB code and its
    'Free Chains' value as the chain selection. ``list(chains)`` splits that
    value into its elements (one entry per character if the cell is a string
    of chain IDs -- TODO confirm against the CSV format).

    Returns the graphs as a list, in the row order of ``frame``.
    """
    rows = zip(frame['Free PDB'], frame['Free Chains'])
    return [pg.dgl_graph_from_pdb_code(pdb_code=pdb_code,
                                       chain_selection=list(chains))
            for pdb_code, chains in tqdm(rows, total=len(frame))]


# Build Graphs
train_graphs = build_graphs(x_train)
test_graphs = build_graphs(x_test)

In [None]:
def collate(samples):
    """Merge a list of ``(graph, label)`` pairs into a single mini-batch.

    The graphs are combined with ``dgl.batch`` (keeping node attribute 'h'),
    zero initialisers are installed for node and edge features, and the
    labels are stacked along a new leading dimension.

    Returns a ``(batched_graph, stacked_labels)`` tuple.
    """
    # Transpose [(g0, y0), (g1, y1), ...] into ([g0, g1, ...], [y0, y1, ...]).
    graphs, targets = (list(column) for column in zip(*samples))
    merged = dgl.batch(graphs, node_attrs='h')
    for install_initializer in (merged.set_n_initializer, merged.set_e_initializer):
        install_initializer(dgl.init.zero_initializer)
    return merged, torch.stack(targets)

# Pair every graph with its one-hot label tensor.
train_data = list(zip(train_graphs, y_train))
test_data = list(zip(test_graphs, y_test))

#Create dataloaders
# Training batches are reshuffled on every pass over the data.
train_loader = DataLoader(train_data, batch_size=32, shuffle=True,
                         collate_fn=collate)

# Evaluation gains nothing from shuffling; a fixed order keeps the collected
# predictions aligned with `test_data` and makes runs comparable.
test_loader = DataLoader(test_data, batch_size=32, shuffle=False,
                         collate_fn=collate)

# Model
Here, we define a simple GCN for graph classification. We then train it.

In [None]:
# Width of the node feature matrix 'h', read from the first training graph
# (index 0 so that a single-graph training set also works).
n_feats = train_graphs[0].ndata['h'].shape[1]

# Number of output classes = width of a one-hot label, so the classifier head
# always matches the classes actually present in the CSV.
n_classes = len(y_train[0])

# Instantiate model
gcn_net = GCNPredictor(in_feats=n_feats,
                       hidden_feats=[32, 32],
                       batchnorm=[True, True],
                       dropout=[0, 0],
                       classifier_hidden_feats=32,
                       n_tasks=n_classes
                       )
gcn_net.to(device)
loss_fn = CrossEntropyLoss()
optimizer = torch.optim.Adam(gcn_net.parameters(), lr=0.005)

In [None]:
epochs = 200

# Training loop
gcn_net.train()
epoch_losses = []

epoch_f1_scores = []
epoch_precision_scores = []
epoch_recall_scores = []

for epoch in range(epochs):
    epoch_loss = 0

    preds = []
    labs = []
    # Train on batch.
    # The loop variable is deliberately NOT called `labels`: that name holds
    # the full label list created earlier in the notebook, and overwriting it
    # would break re-running the train/test split cell.
    for bg, batch_labels in train_loader:
        batch_labels = batch_labels.to(device)
        graph_feats = bg.ndata.pop('h').to(device)
        y_pred = gcn_net(bg, graph_feats)

        # .cpu() keeps the NumPy conversion valid if `device` is not the CPU.
        preds.append(y_pred.detach().cpu().numpy())
        labs.append(batch_labels.detach().cpu().numpy())

        # CrossEntropyLoss takes class indices, so collapse the one-hot rows.
        target = batch_labels.argmax(dim=1)

        loss = loss_fn(y_pred, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach().item()

    # Mean loss per batch.
    epoch_loss /= len(train_loader)

    preds = np.vstack(preds)
    labs = np.vstack(labs)
    true_classes = labs.argmax(axis=1)
    pred_classes = preds.argmax(axis=1)

    # With exactly one label per sample, micro-averaged precision, recall and
    # F1 are all equal to accuracy, which is why the three curves used to
    # coincide. Macro averaging scores each class separately and then averages,
    # so the three metrics are informative individually.
    f1 = f1_score(true_classes, pred_classes, average='macro')
    precision = precision_score(true_classes, pred_classes, average='macro')
    recall = recall_score(true_classes, pred_classes, average='macro')

    if epoch % 5 == 0:
        print(f"epoch: {epoch}, LOSS: {epoch_loss:.3f}, F1: {f1:.3f}, Precision: {precision:.3f}, Recall: {recall:.3f}")

    epoch_losses.append(epoch_loss)
    epoch_f1_scores.append(f1)
    epoch_precision_scores.append(precision)
    epoch_recall_scores.append(recall)

In [None]:
# Plot the per-epoch training loss and classification metrics on one axis.
fig, ax = plt.subplots()
ax.plot(epoch_losses, label="Loss")
ax.plot(epoch_f1_scores, label='F1')
ax.plot(epoch_precision_scores, label="Precision")
ax.plot(epoch_recall_scores, label="Recall")
ax.set_xlabel("Epoch")
ax.set_ylabel("Value")
ax.set_title("Training loss and classification metrics per epoch")
ax.legend();  # trailing semicolon hides the Legend repr in the cell output

In [None]:
# Evaluate
gcn_net.eval()
test_loss = 0

preds = []
labs = []
# Inference only: disabling autograd avoids building a graph for every batch.
with torch.no_grad():
    # The loop variable is deliberately NOT called `labels`, so the full label
    # list created earlier in the notebook is not overwritten.
    for bg, batch_labels in test_loader:
        batch_labels = batch_labels.to(device)
        graph_feats = bg.ndata.pop('h').to(device)
        y_pred = gcn_net(bg, graph_feats)

        # CrossEntropyLoss takes class indices, so collapse the one-hot rows.
        test_loss += loss_fn(y_pred, batch_labels.argmax(dim=1)).item()

        # .cpu() keeps the NumPy conversion valid if `device` is not the CPU.
        preds.append(y_pred.cpu().numpy())
        labs.append(batch_labels.cpu().numpy())

# Mean loss per batch.
test_loss /= len(test_loader)

labs = np.vstack(labs)
preds = np.vstack(preds)
true_classes = labs.argmax(axis=1)
pred_classes = preds.argmax(axis=1)

# With exactly one label per sample, micro-averaged precision, recall and F1
# are all equal to accuracy; macro averaging scores each class separately and
# then averages, so the three numbers carry distinct information.
f1 = f1_score(true_classes, pred_classes, average='macro')
precision = precision_score(true_classes, pred_classes, average='macro')
recall = recall_score(true_classes, pred_classes, average='macro')

print(f"TEST LOSS: {test_loss:.3f}")
print(f"TEST F1: {f1:.3f}, Precision: {precision:.3f}, Recall: {recall:.3f}")