### Train and validate GCNs

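Train and validate graph convolutional networks (GCNs) that classify ligand–protein target pairs, using disjoint subsets of the DUD-E targets for training and validation.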

In [1]:
%load_ext nb_black

<IPython.core.display.Javascript object>

In [2]:
import os
import time
import torch
import torch.nn.functional as F

import numpy as np
import pandas as pd

from progressbar import progressbar
from torch.nn import Linear, Softmax, Dropout
from LigandDataset import LigandDataset
from torch_geometric.nn.pool import topk_pool
from torch_geometric.data import (
    Data,
)
from torch_geometric.nn import GCNConv, global_mean_pool
from sklearn.metrics import roc_auc_score, precision_score, recall_score

<IPython.core.display.Javascript object>

#### Hyperparameters

In [3]:
# --- Training schedule ---
BATCH_SIZE = 32  # examples whose gradients are accumulated per optimizer step
NUM_EPOCHS = 10  # passes over the full training set

# --- Architecture: conv/pool blocks per branch, and channels per block ---
NUM_BLOCKS_LIGANDS = 1
NUM_BLOCKS_PROTEINS = 1
NUM_CHANNELS_LIGANDS = 16
NUM_CHANNELS_PROTEINS = 16

# --- Graph re-densification ---
# Hop distance used by add_bonds to connect atoms after each pooling step.
GRAPH_POWER_AFTER_POOLING = 2

<IPython.core.display.Javascript object>

#### Data

In [4]:
# Read the DUD-E target names (lower-cased) and split them into training and
# validation targets. The fixed seed reproduces the exact split that was used
# when the LigandDataset objects were generated, so it must not be changed.
all_targets = [
    name.lower()
    for name in pd.read_csv("../data/dud-e_targets.csv").target_name.tolist()
]

np.random.seed(1)
np.random.shuffle(all_targets)

# The first 60 shuffled targets are for training, the next 20 for validation.
targets = {
    "training": all_targets[:60],
    "validation": all_targets[60:80],
}
training_targets = targets["training"]
validation_targets = targets["validation"]

<IPython.core.display.Javascript object>

In [5]:
# Map every target name to its pickled target Data object.
# NOTE(review): the literal 10 and 5 are fixed parameters baked into the raw
# file names (presumably the settings used to build the target graphs — confirm).
# pd.read_pickle can execute arbitrary code, so only load trusted files.
target_dict = {
    target: pd.read_pickle(f"../data/raw/{10}_bond_cutoff_{5}_{target}.pkl")
    for target in all_targets
}

<IPython.core.display.Javascript object>

In [6]:
# Load the ligand Dataset objects.
# Unpickle the ligand Dataset objects for both splits (trusted files only:
# unpickling can execute arbitrary code).
training_set, validation_set = (
    pd.read_pickle("../data/training_set.pkl"),
    pd.read_pickle("../data/validation_set.pkl"),
)

<IPython.core.display.Javascript object>

#### Network

In [7]:
def convert_edge_index(edge_index):
    """Turn an edge-index tensor into a list of ``[source, target]`` pairs.

    Args:
        edge_index: Tensor whose row 0 holds source atom indices and row 1 the
            matching target atom indices (the PyG 2 x num_edges layout).

    Returns:
        A list with one ``[source, target]`` list per edge, in column order.
    """
    rows = edge_index.numpy().tolist()
    sources, destinations = rows[0], rows[1]
    return [[src, dst] for src, dst in zip(sources, destinations)]


def convert_bonds(bonds):
    """Convert a list of ``[source, target]`` bond pairs to a PyG edge index.

    Args:
        bonds: Sequence of ``[source, target]`` integer pairs.

    Returns:
        A contiguous ``torch.long`` tensor of shape (2, len(bonds)) whose
        columns are the bonds in input order. An empty input yields an empty
        (2, 0) tensor rather than a malformed 1-D float tensor.
    """
    if len(bonds) == 0:
        return torch.empty((2, 0), dtype=torch.long)
    # Build as int64 explicitly: np.array on Python ints can be int32 on some
    # platforms (e.g. Windows with NumPy < 2), but edge indices must be long.
    return torch.tensor(bonds, dtype=torch.long).t().contiguous()

<IPython.core.display.Javascript object>

In [8]:
def add_bonds(num_atoms, edge_index, graph_power=GRAPH_POWER_AFTER_POOLING):
    """Densify a graph by bonding each atom to every atom within `graph_power` hops.

    The input bonds are kept first, in their original column order. For each
    atom in ``range(num_atoms)``, a bond ``[atom, other]`` is then appended for
    every atom ``other`` reachable in at most `graph_power` hops that is not
    already bonded to it. Bonds are appended in ascending order of ``other``.
    Self bonds are never added, and existing bonds are never duplicated
    (duplicates would be counted twice by GCNConv's aggregation).

    Args:
        num_atoms: Number of atoms; only atoms ``0..num_atoms-1`` gain bonds.
        edge_index: Edge-index tensor of directed bonds (row 0 sources,
            row 1 targets).
        graph_power: Maximum hop distance to connect. Values <= 1 add no
            new bonds.

    Returns:
        An edge-index tensor whose leading columns are the input bonds,
        followed by the new graph-power bonds.
    """
    original_bonds = convert_edge_index(edge_index)

    # Adjacency list: atom index -> indices of the atoms it is bonded to.
    neighbors = {}
    for source, target in original_bonds:
        neighbors.setdefault(source, []).append(target)

    existing_bonds = {(source, target) for source, target in original_bonds}
    # Copy so the original list is never mutated while appending new bonds.
    processed_bonds = list(original_bonds)

    for atom_index in range(num_atoms):
        # Breadth-first search limited to `graph_power` hops. Pruning already
        # seen atoms reaches exactly the atoms at distance <= graph_power while
        # avoiding repeated exploration.
        seen = {atom_index}
        frontier = set(neighbors.get(atom_index, ())) - seen
        for _ in range(graph_power):
            if not frontier:
                break
            seen |= frontier
            frontier = {
                nbr for atom in frontier for nbr in neighbors.get(atom, ())
            } - seen
        seen.discard(atom_index)

        # Sorted for a deterministic edge order across runs.
        for neighbor_index in sorted(seen):
            if (atom_index, neighbor_index) not in existing_bonds:
                processed_bonds.append([atom_index, neighbor_index])

    if not processed_bonds:
        # No bonds at all: return a well-formed empty (2, 0) edge index.
        return torch.empty((2, 0), dtype=torch.long)
    return convert_bonds(processed_bonds)

<IPython.core.display.Javascript object>

In [9]:
class GCN(torch.nn.Module):
    """Two-branch graph convolutional network for (ligand, protein target) pairs.

    Each branch runs a stack of GCNConv -> ReLU -> TopKPooling blocks on its
    graph. After each pooling step it re-densifies the pooled graph with
    `add_bonds`, and at the end it mean-pools the node embeddings into one
    vector. The ligand and protein vectors are concatenated, optionally passed
    through dropout, and mapped by a linear layer to log-probabilities for two
    classes.

    Args:
        ligand_num_layers: Number of conv/pool blocks in the ligand branch.
        ligand_hidden_channels: Output channels of every ligand GCNConv.
        protein_num_layers: Number of conv/pool blocks in the protein branch.
        protein_hidden_channels: Output channels of every protein GCNConv.
        dropout: Dropout probability applied to the concatenated embedding
            before the linear layer. The default of 0.0 disables dropout.
    """

    def __init__(
        self,
        ligand_num_layers=NUM_BLOCKS_LIGANDS,
        ligand_hidden_channels=NUM_CHANNELS_LIGANDS,
        protein_num_layers=NUM_BLOCKS_PROTEINS,
        protein_hidden_channels=NUM_CHANNELS_PROTEINS,
        dropout=0.0,
    ):
        super(GCN, self).__init__()
        # Seed before creating layers so weight initialisation is reproducible.
        torch.manual_seed(1)
        # Layers are held in ModuleLists, not plain Python lists, so their
        # weights are registered with this module. Otherwise:
        #   - model.parameters() omits them, so the optimizer never trains them;
        #   - they are missing from state_dict();
        #   - train()/eval()/.to() never reach them.
        # The first ligand conv expects 1 input feature per node; the first
        # protein conv expects 3.
        self.ligand_conv = torch.nn.ModuleList(
            [GCNConv(1, ligand_hidden_channels, improved=True)]
            + [
                GCNConv(ligand_hidden_channels, ligand_hidden_channels, improved=True)
                for _ in range(ligand_num_layers - 1)
            ]
        )
        self.ligand_pool = torch.nn.ModuleList(
            [
                topk_pool.TopKPooling(ligand_hidden_channels)
                for _ in range(ligand_num_layers)
            ]
        )
        self.protein_conv = torch.nn.ModuleList(
            [GCNConv(3, protein_hidden_channels, improved=True)]
            + [
                GCNConv(protein_hidden_channels, protein_hidden_channels, improved=True)
                for _ in range(protein_num_layers - 1)
            ]
        )
        self.protein_pool = torch.nn.ModuleList(
            [
                topk_pool.TopKPooling(protein_hidden_channels)
                for _ in range(protein_num_layers)
            ]
        )

        self.dropout = Dropout(dropout)
        # Fully-connected layer from the concatenated embedding to 2 classes.
        self.fc = Linear(ligand_hidden_channels + protein_hidden_channels, 2)

        self.ligand_num_layers = ligand_num_layers
        self.protein_num_layers = protein_num_layers

    @staticmethod
    def _edge_weight_vector(edge_attr):
        """Squeeze edge attributes to a weight vector, keeping single-edge graphs 1-D."""
        if edge_attr is None:
            return None
        edge_weight = torch.squeeze(edge_attr)
        # torch.squeeze turns a 1-edge (1, 1) tensor into a 0-d scalar, which
        # GCNConv cannot use as an edge-weight vector.
        return edge_weight.view(1) if edge_weight.dim() == 0 else edge_weight

    @staticmethod
    def _pad_edge_weight(edge_weight, num_pooled_edges, num_edges):
        """Extend pooled edge weights to the edge list returned by add_bonds.

        add_bonds keeps the pooled bonds as the leading columns of its output
        and appends new graph-power bonds after them. The pooled weights
        therefore stay aligned, and only the appended bonds need a weight.
        Returns None (GCNConv then treats every edge as unweighted) when the
        weights cannot be aligned.
        """
        if (
            edge_weight is None
            or edge_weight.dim() != 1
            or edge_weight.size(0) != num_pooled_edges
            or num_edges < num_pooled_edges
        ):
            return None
        # NOTE(review): new bonds get weight 1.0; confirm this is a sensible
        # default for whatever edge_attr encodes.
        padding = edge_weight.new_ones(num_edges - num_pooled_edges)
        return torch.cat([edge_weight, padding])

    def _encode(self, graph, conv_layers, pool_layers, num_layers):
        """Run the conv/pool blocks on one graph and mean-pool it to a 1-D vector."""
        x = graph.x
        edge_index = graph.edge_index
        edge_weight = self._edge_weight_vector(graph.edge_attr)
        for block in range(num_layers):
            # Update node representations through convolution, then ReLU.
            x = conv_layers[block](x.float(), edge_index, edge_weight).relu()
            # Pool the graph (TopKPooling keeps 50% of the nodes by default).
            x, edge_index, edge_weight, _, _, _ = pool_layers[block](
                x, edge_index, edge_weight
            )
            num_pooled_edges = edge_index.size(1)
            # Add new bonds between the remaining nodes with graph power, and
            # keep the edge weights the same length as the new edge list.
            edge_index = add_bonds(x.shape[0], edge_index)
            edge_weight = self._pad_edge_weight(
                edge_weight, num_pooled_edges, edge_index.shape[-1]
            )
        # One graph: every node belongs to batch 0, so this is a per-channel mean.
        batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        return global_mean_pool(x, batch=batch)[0]

    def forward(self, ligand, protein):
        """Return log-probabilities over the 2 classes for a ligand/protein pair."""
        x_l = self._encode(
            ligand, self.ligand_conv, self.ligand_pool, self.ligand_num_layers
        )
        x_p = self._encode(
            protein, self.protein_conv, self.protein_pool, self.protein_num_layers
        )
        # Concatenate ligand and protein representations.
        x = torch.cat([x_l, x_p]).float()
        x = self.dropout(x)
        # Fully-connected layer followed by log-softmax over the 2 classes.
        x = self.fc(x)
        return F.log_softmax(x, dim=0)

<IPython.core.display.Javascript object>

#### Training

In [10]:
def get_pred(out):
    """Return 1 if the positive class scores strictly higher than the negative, else 0.

    Ties resolve to class 0.
    """
    negative_score = out[0].item()
    positive_score = out[1].item()
    return 1 if positive_score > negative_score else 0


def get_accuracy(all_truths, all_preds):
    """Return the fraction of predictions that equal their ground-truth labels."""
    truths = np.array(all_truths)
    preds = np.array(all_preds)
    is_correct = truths == preds
    return np.mean(is_correct)


def get_prob(out):
    """Turn the log-probability at index 1 (positive class) back into a probability."""
    positive_log_prob = out[1].item()
    return np.exp(positive_log_prob)


def train():
    """Run one training epoch over ``training_set`` and print its metrics.

    Uses the module-level ``model``, ``optimizer``, ``criterion``,
    ``training_set`` and ``target_dict``. Gradients are accumulated (summed,
    not averaged) over ``BATCH_SIZE`` consecutive examples before each
    optimizer step. A trailing partial batch is also applied at the end of the
    epoch, so its gradients never leak into the next epoch.
    """
    model.train()
    # Begin from clean gradients regardless of what ran before.
    optimizer.zero_grad()

    all_preds = []  # hard 0/1 predictions
    all_probs = []  # predicted probability of the positive class
    all_truths = []  # ground-truth labels
    num_examples = training_set.len()
    # Pass through every example in the training set.
    for i in range(num_examples):
        curr_ligand = training_set.get(i)
        curr_target = target_dict[curr_ligand.target]

        # Pass the current ligand and target data through the model.
        out = model(curr_ligand, curr_target)
        # NLLLoss on a (1, 2) input needs a target of shape (1,). reshape(1)
        # handles a label stored either as a scalar or as a 1-element tensor.
        label = curr_ligand.y[0].reshape(1)
        loss = criterion(out.unsqueeze(0), label)
        all_preds.append(get_pred(out))
        all_probs.append(get_prob(out))
        all_truths.append(label.item())
        # Accumulate this example's gradients.
        loss.backward()
        # Step after every BATCH_SIZE examples (i + 1 counts examples seen) and
        # after the last example, so the final partial batch is not dropped.
        if (i + 1) % BATCH_SIZE == 0 or i + 1 == num_examples:
            optimizer.step()
            optimizer.zero_grad()
    print(
        f"Training AUC: {roc_auc_score(all_truths, all_probs)},"
        f" precision: {precision_score(all_truths, all_preds)},"
        f" recall: {recall_score(all_truths, all_preds)},"
        f" accuracy: {get_accuracy(all_truths, all_preds)}"
    )


def test():
    """Evaluate the model on ``validation_set`` and print its metrics.

    Uses the module-level ``model``, ``validation_set`` and ``target_dict``.
    Runs under ``torch.no_grad()`` because no gradients are needed here;
    otherwise an autograd graph would be built for every validation example.
    """
    model.eval()

    all_preds = []  # hard 0/1 predictions
    all_probs = []  # predicted probability of the positive class
    all_truths = []  # ground-truth labels
    with torch.no_grad():
        # Pass through every example in the validation set.
        for i in range(validation_set.len()):
            curr_ligand = validation_set.get(i)
            curr_target = target_dict[curr_ligand.target]

            # Perform a single forward pass.
            out = model(curr_ligand, curr_target)
            all_preds.append(get_pred(out))
            all_probs.append(get_prob(out))
            all_truths.append(curr_ligand.y[0].item())
    print(
        f"Validation AUC: {roc_auc_score(all_truths, all_probs)},"
        f" precision: {precision_score(all_truths, all_preds)},"
        f" recall: {recall_score(all_truths, all_preds)},"
        f" accuracy: {get_accuracy(all_truths, all_preds)}"
    )

<IPython.core.display.Javascript object>

#### Application

In [11]:
# Optimizer settings: Adam with L2 weight decay.
LEARNING_RATE = 0.01
WEIGHT_DECAY = 5e-4

model = GCN()
optimizer = torch.optim.Adam(
    model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY
)
# NLLLoss expects log-probabilities, which GCN.forward returns via log_softmax.
criterion = torch.nn.NLLLoss()

<IPython.core.display.Javascript object>

In [None]:
def get_time_left(time_taken, epoch):
    """Estimate the remaining run time, in whole minutes.

    Args:
        time_taken: Seconds spent on the epoch that just finished.
        epoch: 1-based index of that epoch.

    Returns:
        The rounded number of minutes needed for the remaining
        ``NUM_EPOCHS - epoch`` epochs, assuming each takes ``time_taken``.
    """
    epochs_remaining = NUM_EPOCHS - epoch
    return round(epochs_remaining * time_taken / 60)

# Train and evaluate for NUM_EPOCHS epochs, reporting a time-left estimate.
for epoch in range(1, NUM_EPOCHS + 1):
    print(f"Epoch {epoch}")
    # perf_counter is monotonic, unlike time.time(), so durations can't jump.
    start = time.perf_counter()
    train()
    test()
    epoch_seconds = time.perf_counter() - start
    # Extrapolate the remaining run time from this epoch's duration.
    print(f"Estimated time left: {get_time_left(epoch_seconds, epoch)} min")
    print("")
    print("")

Epoch 1
