In [None]:
import os
import pickle

import pandas as pd
import numpy as np
from tqdm import tqdm

import dgl
import torch
from torch.nn import NLLLoss
from torch.utils.data import DataLoader
from Bio.PDB import *
from sklearn.metrics import f1_score, precision_score, recall_score, matthews_corrcoef

from graphein.construct_graphs import ProteinGraph

## Data pre-processing
Here, we load the dataset provided in DeepPPISP. We first load the node labels:

In [None]:
# Get DeepPPISP Data
# NOTE: pickle.load can execute arbitrary code - only open trusted dataset files.
def _read_pickle(path):
    """Return the object stored in the pickle file at `path`."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Index of all examples across the three datasets
index = _read_pickle('all_dset_list.pkl')

# Node labels, one pickle per dataset
dset186_labels = _read_pickle('dset186_label.pkl')
dset164_labels = _read_pickle('dset164_label.pkl')
dset72_labels = _read_pickle('dset72_label.pkl')

# Concatenate in the order 186 -> 164 -> 72
labels = dset186_labels + dset164_labels + dset72_labels

In [None]:
# Get PSSMs
# NOTE: pickle.load can execute arbitrary code - only open trusted dataset files.
pssm_files = ('dset186_pssm_data.pkl', 'dset164_pssm_data.pkl', 'dset72_pssm_data.pkl')

loaded_pssms = []
for pssm_path in pssm_files:
    with open(pssm_path, 'rb') as f:
        loaded_pssms.append(pickle.load(f))

dset_186_pssms, dset_164_pssms, dset_72_pssms = loaded_pssms

# Concatenate in the order 186 -> 164 -> 72
pssms = dset_186_pssms + dset_164_pssms + dset_72_pssms

In [None]:
# write labels
#pickle.dump(labels, open('ppisp_node_labels.p', 'wb'))

# Build one metadata row per example from the dataset index.
df = pd.DataFrame(index)
df.columns = ['pos_index', 'example_index', 'res_position', 'dataset', 'pdb', 'length']

# Keep only the rows with res_position == 0. `.copy()` gives an independent
# frame so the column assignments below do not write into a slice of the original.
df = df.loc[df['res_position'] == 0].copy()

# Get PDB accession and chains (entries are expected to look like "<code>_<chains>")
df[['pdb_code', 'chains']] = df.pdb.str.split("_", expand=True)

# dset164 entries don't follow that format: the first four characters are the
# accession and the last character is the chain.
is_dset164 = df['dataset'] == 'dset164'
df.loc[is_dset164, 'pdb_code'] = df.loc[is_dset164, 'pdb'].str.slice(stop=4)
df.loc[is_dset164, 'chains'] = df.loc[is_dset164, 'pdb'].str.slice(-1)
df['chains'] = df['chains'].fillna('all')

# Normalise accession case once, after every source of `pdb_code` has been filled in.
df['pdb_code'] = df['pdb_code'].str.lower()

# Remove Obsolete structures
obsolete = ['3NW0', '3VDO']
# Replacement accessions for the obsolete entries; currently not applied.
replacements = ['', '4NQW']
# `pdb_code` is lower-case at this point, so compare against lower-cased codes;
# comparing against the upper-case list would never match and nothing would be removed.
obsolete_codes = [code.lower() for code in obsolete]
df = df.loc[~df['pdb_code'].isin(obsolete_codes)].copy()

# Assign training/test status
# NOTE: pickle.load can execute arbitrary code - only open trusted dataset files.
with open('training_list.pkl', 'rb') as f:
    train = pickle.load(f)

with open('testing_list.pkl', 'rb') as f:
    test = pickle.load(f)

# train == 1 -> training example, train == 0 -> test example, NaN -> in neither list
df.loc[df['pos_index'].isin(train), 'train'] = 1
df.loc[df['pos_index'].isin(test), 'train'] = 0

# Re-number rows 0..n-1 (the old index is kept as an 'index' column) so that
# positional and label-based lookups agree downstream.
df = df.reset_index()
#Write Dataframe
#df.to_csv('deepppisp_clean.csv')
df.head()

## Build Graphs

In [None]:
# Initialise Protein Graph Class
# The GetContacts checkout location is machine-specific: allow overriding it with
# the GETCONTACTS_PATH environment variable, falling back to the original path.
get_contacts_path = os.environ.get('GETCONTACTS_PATH',
                                   '/home/arj39/Documents/github/getcontacts')

pg = ProteinGraph(granularity='CA',
                  insertions=False,
                  keep_hets=False,
                  node_featuriser='meiler',
                  get_contacts_path=get_contacts_path,
                  pdb_dir='ppisp_pdbs/',
                  contacts_dir='ppisp_contacts/',
                  exclude_waters=True,
                  covalent_bonds=False,
                  include_ss=True,
                  include_ligand=False,
                  edge_distance_cutoff=None)

In [None]:
# Build one DGL graph per row of `df`, attach the concatenated node features
# under 'feats', and record which positions in `graph_list` are train / test.
graph_list = []
label_list = []
test_indices = []
train_indices = []
idx_counter = 0  # position of the next accepted graph in graph_list / label_list

# Iterate over the rows of `df` (not over `labels`): rows may have been
# filtered out of `df`, and labels/PSSMs are looked up through `example_index`.
for example in tqdm(range(len(df))):
    # Create Protein Graph
    try:
        # 'all' is the fill value for a missing chain id and must be passed
        # through as a string; list('all') would become ['a', 'l', 'l'].
        # NOTE(review): assumes ProteinGraph accepts chain_selection='all' - confirm.
        chains = df['chains'][example]
        chain_selection = 'all' if chains == 'all' else list(chains)

        # Construct graph using Graphein
        g = pg.dgl_graph_from_pdb_code(pdb_code=df['pdb_code'][example],
                                       chain_selection=chain_selection,
                                       edge_construction=['contacts']
                                       )
        # Create PSSM Feats and label
        df_index = df.iloc[example]['example_index']
        label = labels[df_index]
        pssm = pssms[df_index]

    except Exception as err:
        # Report what went wrong instead of hiding it; `Exception` (rather than a
        # bare except) lets KeyboardInterrupt stop the cell normally.
        # NOTE(review): a failure stops the whole loop, as before - switch to
        # `continue` if single failures should only skip that example.
        print(f'Failed on example {example}: {err!r}')
        break

    # Ensure node labels match number of nodes. There are a few cases (~5) where this doesn't hold. We skip these.
    if g.number_of_nodes() != len(label):
        print('label length does not match ', example)
        print(g.number_of_nodes())
        print(len(label))
        continue
    if g.number_of_nodes() != len(pssm):
        print(g.number_of_nodes())
        print(len(pssm))
        print('pssm length does not match', example)
        continue

    # Track training and test indices
    if df['train'][example] == 0:
        test_indices.append(idx_counter)
    if df['train'][example] == 1:
        train_indices.append(idx_counter)
    idx_counter += 1

    # Concatenate graph features and store graph and labels as 'feats'
    g.ndata['feats'] = torch.cat((g.ndata['h'],
                               g.ndata['ss'],
                               g.ndata['asa'],
                               g.ndata['rsa'],
                               g.ndata['coords'],
                               torch.Tensor(pssm)), dim=1)
    graph_list.append(g)

    label = torch.Tensor(label).long()
    label_list.append(label)

### Normalise graph features

In [None]:
# Split graphs and labels into train / test using the indices recorded while building.
test_graphs = [graph_list[i] for i in test_indices]
train_graphs = [graph_list[i] for i in train_indices]

print(f"Train graphs: {len(train_graphs)}")
print(f"Test graphs: {len(test_graphs)}")

train_labels = [label_list[i] for i in train_indices]
test_labels = [label_list[i] for i in test_indices]

# Compute feature min/maxes for normalisation (statistics from the training set only)
train_feats = torch.cat([graph.ndata['feats'] for graph in train_graphs], dim=0)
max_feats = torch.max(train_feats, dim=0)[0]
min_feats = torch.min(train_feats, dim=0)[0]

# Min-max scaling divides by the feature range (max - min), not by max alone:
# dividing by max mis-scales any feature whose minimum is non-zero and flips the
# sign when the maximum is negative.
range_feats = max_feats - min_feats
# Constant features have zero range; divide by 1 so they map to 0 instead of NaN.
range_feats[range_feats == 0] = 1

# Normalise train and test graph features
# NOTE: this rewrites 'feats' on the graphs, so re-running the cell would
# normalise already-normalised features; rebuild the graphs first.
for g in train_graphs + test_graphs:
    g.ndata['feats'] = (g.ndata['feats'] - min_feats) / range_feats

### Create dataloaders

In [None]:
# Collate function for the DataLoaders: merges the graphs of a mini-batch into one
# batched DGL graph and joins the per-graph label tensors into a single tensor.
def collate(samples):
    """Turn a list of (graph, label) pairs into (batched_graph, concatenated_labels)."""
    # Transpose the list of pairs into a list of graphs and a list of labels.
    graphs, labels = (list(column) for column in zip(*samples))

    merged = dgl.batch(graphs, node_attrs='feats')
    # Use zero initialisers for node and edge features.
    merged.set_n_initializer(dgl.init.zero_initializer)
    merged.set_e_initializer(dgl.init.zero_initializer)

    return merged, torch.cat(labels)

# Pair every graph with its label tensor.
train_data = list(zip(train_graphs, train_labels))
test_data = list(zip(test_graphs, test_labels))

#Create dataloaders
# Training batches are reshuffled every epoch.
train_loader = DataLoader(train_data, batch_size=32, shuffle=True,
                         collate_fn=collate)

# Evaluation does not benefit from shuffling; a fixed order keeps test
# predictions aligned with `test_data` and reproducible between runs.
test_loader = DataLoader(test_data, batch_size=32, shuffle=False,
                         collate_fn=collate)

## Define Model
Adapted from: https://docs.dgl.ai/en/0.2.x/tutorials/basics/1_first.html

In [None]:
import torch.nn as nn
import torch.nn.functional as F

# Message and reduce functions for the GCN layer.
# NOTE: the GCN normalization constant c_ij is ignored here.
def gcn_message(edges):
    """Build the message for a batch of edges: the source node's 'h' feature, keyed 'msg'."""
    source_h = edges.src['h']
    return dict(msg=source_h)

def gcn_reduce(nodes):
    """Aggregate a batch of nodes: sum the 'msg' entries in each mailbox over dim 1 into 'h'."""
    inbox = nodes.mailbox['msg']
    summed = torch.sum(inbox, dim=1)
    return dict(h=summed)

# Define the GCNLayer module
class GCNLayer(nn.Module):
    """One graph-convolution step: sum neighbour features, then apply a linear map."""

    def __init__(self, in_feats, out_feats):
        super().__init__()
        self.linear = nn.Linear(in_feats, out_feats)

    def forward(self, g, inputs):
        """Propagate `inputs` over graph `g` and return the linearly transformed result."""
        # Store the inputs as node feature 'h' so the message function can read them.
        g.ndata['h'] = inputs
        # Send a message along every edge ...
        g.send(g.edges(), gcn_message)
        # ... and aggregate the received messages at every node.
        g.recv(g.nodes(), gcn_reduce)
        # Take the aggregated features back out of the graph.
        aggregated = g.ndata.pop('h')
        return self.linear(aggregated)

In [None]:
class GCN(nn.Module):
    """Two stacked GCNLayers with a ReLU in between; outputs one score per class per node."""

    def __init__(self, in_feats, hidden_size, num_classes):
        super().__init__()
        self.gcn1 = GCNLayer(in_feats, hidden_size)
        self.gcn2 = GCNLayer(hidden_size, num_classes)

    def forward(self, g, inputs):
        """Return the unnormalised class scores for every node of `g`."""
        hidden = torch.relu(self.gcn1(g, inputs))
        return self.gcn2(g, hidden)
# The first layer transforms input features of size 41 to a hidden size of 16.
# The second layer transforms the hidden layer and produces output features of
# size 2, corresponding to the two classification groups
net = GCN(41, 16, 2)

## Train the model

In [None]:
# Train `net` with Adam on a class-weighted negative log-likelihood loss and
# record per-epoch loss / F1 / precision / recall on the training batches.
optimizer = torch.optim.Adam(net.parameters(), lr=0.0001)
# Logits of every batch of every epoch (kept for later visualisation; grows with epochs).
train_logits = []

epochs = 50
epoch_losses = []

epoch_f1_scores = []
epoch_precision_scores = []
epoch_recall_scores = []

# Weight for class 1 relative to class 0 (hard-coded; see the class-weight cell).
loss_fn = nn.NLLLoss(weight=torch.Tensor([1, 5.84]))

# Training loop
net.train()
for epoch in range(epochs):
    epoch_loss = 0

    epoch_logits = []
    labs = []

    # Iterate over batches.
    # The batch labels get their own name so the dataset-level `labels` list
    # used by the graph-building cell is not overwritten.
    for bg, batch_labels in train_loader:
        logits = net(bg, bg.ndata['feats'])
        # we save the logits for visualization later
        batch_logits = logits.detach().numpy()
        train_logits.append(batch_logits)
        epoch_logits.append(batch_logits)
        labs.append(batch_labels.unsqueeze(1).detach().numpy())

        # NLLLoss expects log-probabilities.
        logp = F.log_softmax(logits, 1)
        loss = loss_fn(logp, batch_labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach().item()

    # Calculate accuracy metrics
    epoch_logits = np.vstack(epoch_logits)
    labs = np.vstack(labs)  # column vector, shape (n_nodes, 1)

    # Compute the predictions once and hand sklearn 1-D arrays.
    y_true = labs.ravel()
    y_pred = np.argmax(epoch_logits, axis=1)

    f1 = f1_score(y_true, y_pred, average='weighted')
    precision = precision_score(y_true, y_pred, average='weighted')
    recall = recall_score(y_true, y_pred, average='weighted')

    # Mean loss per batch (number of batches taken from the loader, not a leaked loop index).
    epoch_loss /= len(train_loader)
    if epoch % 5 == 0:
        print('Epoch %d | Loss: %.4f | F1: %.4f | Precision: %.4f | Recall: %.4f' % (epoch, epoch_loss, f1, precision, recall))

    epoch_losses.append(epoch_loss)
    epoch_f1_scores.append(f1)
    epoch_precision_scores.append(precision)
    epoch_recall_scores.append(recall)

In [None]:
# Here, we derive the class weights used above
# `labs` holds the labels seen in the last training epoch. Counts are computed
# from the data instead of being typed in by hand (the original run gave
# 74072 nodes, 10833 positives -> 63239 / 10833, i.e. roughly 5.84).
n_nodes = len(labs)
n_positive = int(np.sum(labs))
n_negative = n_nodes - n_positive

print(n_positive)
print(n_nodes)

# Ratio of negatives to positives = weight given to the positive class
n_negative / n_positive

In [None]:
from matplotlib import pyplot as plt
%matplotlib inline

# Plot the per-epoch training curves on a single set of axes.
fig, ax = plt.subplots()
curves = (
    (epoch_losses, "Loss"),
    (epoch_f1_scores, 'F1'),
    (epoch_precision_scores, "Precision"),
    (epoch_recall_scores, "Recall"),
)
for values, curve_name in curves:
    ax.plot(values, label=curve_name)
ax.legend()

In [None]:
# Second training loop: trains `gcn_net` for 200 epochs and records the
# per-epoch loss and micro-averaged F1 / precision / recall.
# NOTE(review): `gcn_net` and `device` are not defined in any cell visible here,
# so this cell raises NameError on a fresh kernel unless they are created
# elsewhere - confirm whether this cell is still needed.
# NOTE(review): `optimizer` and `loss_fn` are the ones created in the earlier
# training cell (the optimizer there was built from `net.parameters()`) - confirm
# that is intended for `gcn_net`.
epochs = 200

# Training loop
gcn_net.train()
epoch_losses = []

epoch_f1_scores = [] 
epoch_precision_scores = []
epoch_recall_scores = []

for epoch in range(epochs):
    epoch_loss = 0

    preds = []
    labs = []
    # Train on batch
    # NOTE: the loop variable `labels` overwrites the dataset-level `labels` list.
    for i, (bg, labels) in enumerate(train_loader):
        labels = labels.to(device)
        # Reads node feature 'h' (the other loops use 'feats') and removes it from the batched graph.
        graph_feats = bg.ndata.pop('h').to(device)
        graph_feats, labels = graph_feats.to(device), labels.to(device)
        y_pred = gcn_net(bg, graph_feats)
        
        # NOTE(review): .numpy() requires CPU tensors - presumably `device` is the CPU; confirm.
        preds.append(y_pred.detach().numpy())
        labs.append(labels.detach().numpy())

        # argmax over axis 1 needs 2-D labels (e.g. one-hot).
        # NOTE(review): the other loops pass the batch labels straight to NLLLoss - confirm the expected label layout.
        labels = np.argmax(labels, axis=1)
        
        loss = loss_fn(y_pred, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach().item()
        
    # Mean loss per batch (`i` is the index of the last batch).
    epoch_loss /= (i + 1)
    
    preds = np.vstack(preds)
    labs = np.vstack(labs)
    
    # There's some sort of issue going on here with the scoring. All three values are the same
    # (This is expected with average='micro': for single-label classification,
    # micro-averaged precision, recall and F1 all reduce to plain accuracy.)
    f1 = f1_score(np.argmax(labs, axis=1), np.argmax(preds, axis=1), average='micro')
    precision = precision_score(np.argmax(labs, axis=1), np.argmax(preds, axis=1), average='micro')
    recall = recall_score(np.argmax(labs, axis=1), np.argmax(preds, axis=1), average='micro')
    
    if epoch % 5 == 0:
        print(f"epoch: {epoch}, LOSS: {epoch_loss:.3f}, F1: {f1:.3f}, Precision: {precision:.3f}, Recall: {recall:.3f}")
        
    epoch_losses.append(epoch_loss)
    epoch_f1_scores.append(f1)
    epoch_precision_scores.append(precision)
    epoch_recall_scores.append(recall)

## Evaluate the model

In [None]:
# Evaluate `net` on the test set and report weighted F1 / precision / recall.
test_loss = 0  # not updated by this cell; kept so the name stays defined
test_logits = []
labs = []

net.eval()
# Inference only: no gradients are needed, so skip building the autograd graph.
with torch.no_grad():
    # The batch labels get their own name so the dataset-level `labels` list
    # is not overwritten.
    for bg, batch_labels in test_loader:
        logits = net(bg, bg.ndata['feats'])

        test_logits.append(logits.detach().numpy())
        labs.append(batch_labels.unsqueeze(1).detach().numpy())

test_logits = np.vstack(test_logits)
labs = np.vstack(labs)  # column vector, shape (n_nodes, 1)

# Predicted class per node, computed once and reused for every metric.
preds = np.argmax(test_logits, axis=1)
y_true = labs.ravel()

f1 = f1_score(y_true, preds, average='weighted')
precision = precision_score(y_true, preds, average='weighted')
recall = recall_score(y_true, preds, average='weighted')

print('Test: F1: %.4f | Precision: %.4f | Recall: %.4f' % (f1, precision, recall))

In [None]:
# Evaluate
# Runs `gcn_net` over the test loader and prints micro-averaged F1 / precision / recall.
# NOTE(review): `gcn_net` and `device` are not defined in any cell visible here,
# so this cell raises NameError on a fresh kernel unless they are created
# elsewhere - confirm whether this cell is still needed.
gcn_net.eval()
test_loss = 0

preds = []
labs = []
# NOTE: the loop variable `labels` overwrites the dataset-level `labels` list.
for i, (bg, labels) in enumerate(test_loader):
    labels = labels.to(device)
    # Reads node feature 'h' (the other loops use 'feats') and removes it from the batched graph.
    graph_feats = bg.ndata.pop('h').to(device)
    graph_feats, labels = graph_feats.to(device), labels.to(device)
    y_pred = gcn_net(bg, graph_feats)

    # NOTE(review): .numpy() requires CPU tensors - presumably `device` is the CPU; confirm.
    preds.append(y_pred.detach().numpy())
    labs.append(labels.detach().numpy())

labs = np.vstack(labs)
preds = np.vstack(preds)

# argmax over axis 1 needs 2-D labels (e.g. one-hot) - NOTE(review): confirm the label layout.
# With average='micro' on single-label data, the three scores below are all equal to accuracy.
f1 = f1_score(np.argmax(labs, axis=1), np.argmax(preds, axis=1), average='micro')
precision = precision_score(np.argmax(labs, axis=1), np.argmax(preds, axis=1), average='micro')
recall = recall_score(np.argmax(labs, axis=1), np.argmax(preds, axis=1), average='micro')

print(f"TEST F1: {f1:.3f}, Precision: {precision:.3f}, Recall: {recall:.3f}")