## About the dataset

The dataset contains about **41,000 molecules**, each labelled ``0`` or ``1`` according to whether it inhibits **HIV**.

We split our project into 4 modules:

1) Building the **dataset** - converting raw data into useful node and edge features
2) Building the **GNN**
3) **Generative GNN** - to generate arbitrary molecules that are potential HIV inhibitors
4) **Explainable AI** on graphs

### Step 1: Building the dataset - converting raw data into useful node and edge features

In [None]:
# imports
import pandas as pd

# Load the raw CSV (one molecule per row) and preview its first rows
path = "data/raw/HIV.csv"
dataset = pd.read_csv(filepath_or_buffer=path)
dataset.head(n=5)

The first column holds each molecule as a SMILES string. We want to convert these strings into graphs so that we can pass them as input to GNNs.

In [None]:
# General info about dataset: overall shape, then how many rows carry each label
label_counts = dataset["HIV_active"].value_counts()
print(dataset.shape)
print(label_counts)

There are **41,127 molecules**, and of these only **1,443** are actually HIV inhibitors. We keep this class imbalance in mind when selecting training and testing data: we may have to balance the dataset, e.g. by under-sampling the negative class or over-sampling the positive class.

In [None]:
# Convert the smiles column to molecule structure
import rdkit
from rdkit import Chem
from rdkit.Chem import Draw

# Depict a small sample of molecules (rows 4-29) as a grid, 8 molecules per row
sample_smiles = dataset["smiles"][4:30].values
sample_mols = list(map(Chem.MolFromSmiles, sample_smiles))
grid = Draw.MolsToGridImage(
    sample_mols,
    molsPerRow=8,
    subImgSize=(200, 200),
)

grid

Now that we have seen a few of the molecules as graphs, let us create a custom dataset with PyTorch Geometric. We follow the approach described in the PyG documentation for datasets that are stored on disk rather than held entirely in RAM.

In [None]:
import torch
import torch_geometric
from torch_geometric.data import Dataset, Data
import numpy as np 
import os
from tqdm import tqdm

# We extend the functionalites in Dataset class
class MoleculeDataset(Dataset):
    """On-disk PyG dataset that turns a CSV of SMILES strings into molecular graphs.

    Each CSV row becomes one ``Data`` object (node features ``x``, bond features
    ``edge_attr``, bidirectional ``edge_index``, label ``y`` and the source
    ``smiles``). It is saved as ``data_{i}.pt`` (or ``data_test_{i}.pt`` when
    ``test=True``) in ``processed_dir`` and loaded lazily in :meth:`get`.

    The CSV must contain a ``smiles`` and an ``HIV_active`` column.
    """

    # Number of per-bond features produced by _get_edge_features.
    NUM_EDGE_FEATURES = 2

    def __init__(self, root, filename, test=False, transform=None, pre_transform=None):
        """
        Args:
            root: where the dataset should be stored. This folder is split into
                raw_dir (downloaded dataset) and processed_dir (processed data).
            filename: name of the CSV file inside raw_dir.
            test: if True, processed files are named ``data_test_{i}.pt`` so a
                test split can share processed_dir with a training split.
            transform, pre_transform: forwarded to the PyG ``Dataset`` base class.
        """
        self.test = test
        self.filename = filename
        super().__init__(root, transform, pre_transform)

    def _processed_name(self, idx):
        """Return the processed file name for row ``idx`` (prefix depends on ``self.test``)."""
        prefix = "data_test" if self.test else "data"
        return f"{prefix}_{idx}.pt"

    @property
    def raw_file_names(self):
        """If this file exists in raw_dir, the download is not triggered.

        download() is a no-op here, so the CSV must already be in place.
        """
        return self.filename

    @property
    def processed_file_names(self):
        """If all of these files are found in processed_dir, processing is skipped.

        Side effect: (re)loads the CSV into ``self.data``, which :meth:`len` uses.
        """
        self.data = pd.read_csv(self.raw_paths[0]).reset_index()
        return [self._processed_name(i) for i in self.data.index]

    def download(self):
        pass

    def process(self):
        """Construct one graph per CSV row and save it to processed_dir.

        Raises:
            ValueError: if RDKit cannot parse the SMILES string of a row.
        """
        self.data = pd.read_csv(self.raw_paths[0])

        # tqdm shows a progress bar over the rows
        for index, row in tqdm(self.data.iterrows(), total=self.data.shape[0]):
            smiles = row["smiles"]
            mol_obj = Chem.MolFromSmiles(smiles)
            if mol_obj is None:
                # MolFromSmiles signals a parse failure by returning None; fail here
                # with the offending row instead of an AttributeError further down.
                raise ValueError(
                    f"RDKit could not parse the SMILES in row {index}: {smiles!r}"
                )

            data = Data(
                x=self._get_node_features(mol_obj),
                edge_index=self._get_adjacency_info(mol_obj),
                edge_attr=self._get_edge_features(mol_obj),
                y=self._get_labels(row["HIV_active"]),
                smiles=smiles,
            )
            torch.save(data, os.path.join(self.processed_dir, self._processed_name(index)))

    # Based on domain knowledge we decide to get the node features
    def _get_node_features(self, mol):
        """Return a float tensor of shape [number of atoms, 9] with per-atom features.

        Features, in order: atomic number, degree, formal charge, hybridization,
        aromaticity, total number of Hs, radical electrons, ring membership and
        chirality. Booleans and RDKit enum values are converted to numbers.
        """
        all_node_feats = [
            [
                atom.GetAtomicNum(),
                atom.GetDegree(),
                atom.GetFormalCharge(),
                atom.GetHybridization(),
                atom.GetIsAromatic(),
                atom.GetTotalNumHs(),
                atom.GetNumRadicalElectrons(),
                atom.IsInRing(),
                atom.GetChiralTag(),
            ]
            for atom in mol.GetAtoms()
        ]
        return torch.tensor(np.asarray(all_node_feats), dtype=torch.float)

    # Based on domain knowledge we decide to get the edge features
    def _get_edge_features(self, mol):
        """Return a float tensor of shape [2 * number of bonds, NUM_EDGE_FEATURES].

        Features: bond type as a double and ring membership. Every bond appears
        twice (once per direction), in the same order as ``_get_adjacency_info``.
        """
        all_edge_feats = []
        for bond in mol.GetBonds():
            edge_feats = [bond.GetBondTypeAsDouble(), bond.IsInRing()]
            all_edge_feats += [edge_feats, edge_feats]

        # reshape keeps a 2-D [0, NUM_EDGE_FEATURES] shape for molecules without
        # bonds (e.g. single atoms or disconnected ions) instead of a 1-D empty array.
        all_edge_feats = np.asarray(all_edge_feats, dtype=np.float32).reshape(
            -1, self.NUM_EDGE_FEATURES
        )
        return torch.tensor(all_edge_feats, dtype=torch.float)

    def _get_adjacency_info(self, mol):
        """Return ``edge_index`` as a long tensor of shape [2, 2 * number of bonds].

        rdmolops.GetAdjacencyMatrix(mol) could be used instead, but building it
        from GetBonds() guarantees the edge order matches _get_edge_features.
        """
        edge_indices = []
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            edge_indices += [[i, j], [j, i]]

        # reshape(-1, 2) also turns the no-bond case into [0, 2] before transposing
        return torch.tensor(edge_indices, dtype=torch.long).reshape(-1, 2).t().contiguous()

    def _get_labels(self, label):
        """Wrap a scalar label into an int64 tensor of shape [1]."""
        label = np.asarray([label])
        return torch.tensor(label, dtype=torch.int64)

    def len(self):
        return self.data.shape[0]

    def get(self, idx):
        """Load graph ``idx`` from processed_dir (equivalent to __getitem__ in PyTorch)."""
        path = os.path.join(self.processed_dir, self._processed_name(idx))
        try:
            # The files are pickled Data objects written by process(), not plain
            # tensors, so torch versions that default weights_only=True would refuse
            # them. Only load processed files this class created itself.
            return torch.load(path, weights_only=False)
        except TypeError:
            # Older torch releases do not accept the weights_only argument.
            return torch.load(path)

In [None]:
# Process the dataset (graph files are only built when they are missing from processed_dir)
dataset = MoleculeDataset(root="data/", filename='HIV.csv')

# Inspect the first graph: edge pairs, node features, label, then bond features
first_graph = dataset[0]
for tensor in (first_graph.edge_index.t(), first_graph.x, first_graph.y, first_graph.edge_attr):
    print(tensor)

Now that the pre-processing is done, we can move on to building the GNN model to do graph-level predictions.

### Step 2: Building the GNN model

#### Task 1: Graph classification

Given a molecule, we classify whether it is an HIV inhibitor (``1``) or not (``0``). For this task we need a feature vector for the whole graph rather than per-node or per-edge features. There are many approaches for this; we discuss a few of them below:

1) **Naive pooling** - Apply mean/max/sum pooling to all node features to get the graph representation
2) **Hierarchical pooling** - similar to pooling in CNNs on images. After each message-passing step, in which nodes share features with their neighbours, the graph is coarsened by dropping or merging nodes. This repeats until the few remaining nodes (possibly a single node) summarize the whole graph. The coarsening is done either by **Differentiable pooling** (DiffPool) or by **Top-K pooling**. DiffPool learns to cluster the nodes, pools the feature vectors within each cluster and forms a new graph with the clusters as nodes. Top-K pooling projects each node's feature vector onto a learnable vector to get a score and keeps only the top-k scoring nodes.
3) **Super/virtual/dummy node** - an extra node connected to all other nodes. During message passing every node passes its vector to the super node, but the super node does not share its vector with the others. The vector this node ends up with is the graph representation.

First we do the DATASET changes

In [None]:
# Dataset changes
# Instead of manually assigning features here, we use a FEATURIZER from DeepChem

# import deepchem as dc

# class MoleculeDataset(Dataset):
#     def __init__(self, root, filename, test=False, transform=None, pre_transform=None):
#         """
#         root = Where the dataset should be stored. This folder is split
#         into raw_dir (downloaded dataset) and processed_dir (processed data). 
#         """
#         self.test = test
#         self.filename = filename
#         super(MoleculeDataset, self).__init__(root, transform, pre_transform)
        
#     @property
#     def raw_file_names(self):
#         """ If this file exists in raw_dir, the download is not triggered.
#             (The download func. is not implemented here)  
#         """
#         return self.filename

#     @property
#     def processed_file_names(self):
#         """ If these files are found in raw_dir, processing is skipped"""
#         self.data = pd.read_csv(self.raw_paths[0]).reset_index()

#         if self.test:
#             return [f'data_test_{i}.pt' for i in list(self.data.index)]
#         else:
#             return [f'data_{i}.pt' for i in list(self.data.index)]
        

#     def download(self):
#         pass

#     # Important changes compared to previous version of the class definition
#     def process(self):
#         # Reading the dataset
#         self.data = pd.read_csv(self.raw_paths[0]).reset_index()

#         # KEY STEP - using Featurizer
#         featurizer = dc.feat.MolGraphConvFeaturizer(use_edges=True)

#         # Iterating over each object to get the required parameters to preprocess
#         # tqdm gives process bar which tells us how far our process is done
#         for index, row in tqdm(self.data.iterrows(), total=self.data.shape[0]):
#             # Featurize molecule
#             mol = Chem.MolFromSmiles(row["smiles"])
#             f = featurizer._featurize(mol)
#             data = f.to_pyg_graph()
#             data.y = self._get_label(row["HIV_active"])
#             data.smiles = row["smiles"]

#             # Naming the processed data files
#             if self.test:
#                 torch.save(data, 
#                     os.path.join(self.processed_dir, 
#                                  f'data_test_{index}.pt'))
#             else:
#                 torch.save(data, 
#                     os.path.join(self.processed_dir, 
#                                  f'data_{index}.pt'))
            

#     def _get_label(self, label):
#         label = np.asarray([label])
#         return torch.tensor(label, dtype=torch.int64)

#     def len(self):
#         return self.data.shape[0]

#     def get(self, idx):
#         """ - Equivalent to __getitem__ in pytorch
#             - Is not needed for PyG's InMemoryDataset
#         """
#         if self.test:
#             data = torch.load(os.path.join(self.processed_dir, 
#                                  f'data_test_{idx}.pt'))
#         else:
#             data = torch.load(os.path.join(self.processed_dir, 
#                                  f'data_{idx}.pt'))        
#         return data

In [None]:
# Oversampling to balance the dataset
# Load raw dataset
import pandas as pd

data = pd.read_csv("data/raw/HIV_train.csv")
data.index = data["index"]
start_index = data.iloc[0]["index"]

# Apply oversampling

# Check how many additional copies of the positive class (label 1) are needed
# so that positives roughly match the negatives (label 0).
class_counts = data["HIV_active"].value_counts()
neg_class = class_counts[0]
pos_class = class_counts[1]
multiplier = int(neg_class / pos_class) - 1

# Replicate the positive rows `multiplier` times (an empty list when multiplier <= 0)
replicated_pos = [data[data["HIV_active"] == 1]] * multiplier

# Append replicated data. Each list element is already a DataFrame, so they are
# concatenated directly. Wrapping the list in pd.DataFrame([...]) would add a single
# junk row of DataFrame objects (NaN smiles) and no extra positives at all.
data = pd.concat([data, *replicated_pos], axis=0)
print(data.shape)

# Shuffle dataset (fixed seed so the written file is reproducible)
data = data.sample(frac=1, random_state=42).reset_index(drop=True)

# Re-assign index (This is our ID later)
index = range(start_index, start_index + data.shape[0])
data.index = index
data["index"] = data.index

# %% Save
data.to_csv("data/raw/HIV_train_oversampled.csv")


We have already built these features manually instead of using DeepChem, so the DeepChem version of the dataset class is left commented out for now. Next, we build the GNN architecture.

In [None]:
# Building the model
import torch.nn.functional as F
from torch.nn import Sequential, Linear, BatchNorm1d, ReLU
from torch_geometric.nn import TransformerConv, GATConv, TopKPooling, BatchNorm
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
from torch_geometric.nn.conv.x_conv import XConv
torch.manual_seed(42)


class GNN(torch.nn.Module):
    """Graph-level classifier built from three GAT + Top-K pooling blocks.

    Each block applies a multi-head ``GATConv``, projects the concatenated heads
    back to ``embedding_size`` with a linear layer followed by ReLU, then keeps the
    top-scoring nodes with ``TopKPooling`` (ratios 0.8, 0.5 and 0.2). After every
    block the remaining nodes of each graph are read out as
    ``[global max pool || global mean pool]``. The three readouts are summed and
    passed through a two-layer MLP that outputs one logit per class.

    Args:
        feature_size: number of input features per node.
        embedding_size: width of the node embeddings (default 1024).
        num_classes: number of output logits (default 2).
        heads: attention heads per ``GATConv`` layer (default 3).
    """

    def __init__(self, feature_size, embedding_size=1024, num_classes=2, heads=3):
        super().__init__()

        # GNN layers. Molecules are small graphs, so three blocks are used.
        # The attention heads are concatenated (heads * embedding_size features);
        # head_transformN projects them back to embedding_size.
        self.conv1 = GATConv(feature_size, embedding_size, heads=heads, dropout=0.3)
        self.head_transform1 = Linear(embedding_size * heads, embedding_size)
        self.pool1 = TopKPooling(embedding_size, ratio=0.8)
        self.conv2 = GATConv(embedding_size, embedding_size, heads=heads, dropout=0.3)
        self.head_transform2 = Linear(embedding_size * heads, embedding_size)
        self.pool2 = TopKPooling(embedding_size, ratio=0.5)
        self.conv3 = GATConv(embedding_size, embedding_size, heads=heads, dropout=0.3)
        self.head_transform3 = Linear(embedding_size * heads, embedding_size)
        self.pool3 = TopKPooling(embedding_size, ratio=0.2)

        # Output MLP. Each readout is [max || mean] -> 2 * embedding_size features.
        self.linear1 = Linear(embedding_size * 2, 1024)
        self.linear2 = Linear(1024, num_classes)

    def forward(self, x, edge_attr, edge_index, batch_index):
        """Compute class logits for a batch of graphs.

        Args:
            x: node feature matrix.
            edge_attr: accepted for call compatibility but not used; the GATConv
                layers are constructed without edge features.
            edge_index: graph connectivity.
            batch_index: graph id of every node (``batch.batch`` in PyG).

        Returns:
            Logits with one row per graph and ``num_classes`` columns.
        """
        blocks = (
            (self.conv1, self.head_transform1, self.pool1),
            (self.conv2, self.head_transform2, self.pool2),
            (self.conv3, self.head_transform3, self.pool3),
        )
        readouts = []
        for conv, head_transform, pool in blocks:
            x = conv(x, edge_index)
            # Merge the attention heads back to embedding_size; the ReLU gives a
            # non-linearity between consecutive blocks.
            x = head_transform(x).relu()
            # Coarsen the graph: keep the top-scoring nodes. edge_index and
            # batch_index are returned filtered to the kept nodes.
            x, edge_index, _, batch_index, _, _ = pool(x, edge_index, None, batch_index)
            # Per-graph readout of this block: [max pool || mean pool]
            readouts.append(torch.cat([gmp(x, batch_index), gap(x, batch_index)], dim=1))

        # Sum (not concatenate) the per-block readouts; width stays 2 * embedding_size.
        x = sum(readouts)

        # Output block
        x = self.linear1(x).relu()
        x = F.dropout(x, p=0.5, training=self.training)
        return self.linear2(x)

Now that we have built the model, we move on to TRAINING and OPTIMIZATION.

In [None]:
# Training
# imports
from torch_geometric.data import DataLoader
from tqdm import tqdm
import numpy as np
from sklearn.metrics import confusion_matrix, f1_score, accuracy_score, precision_score, recall_score, roc_auc_score
import mlflow.pytorch

if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')


def count_parameters(model):
    """Return how many scalar parameters of *model* are trainable (requires_grad=True)."""
    trainable = filter(lambda param: param.requires_grad, model.parameters())
    return sum(param.numel() for param in trainable)

# Loading dataset
train_dataset = MoleculeDataset(root="data/", filename="HIV_train_oversampled.csv")
# test=True names the test graphs data_test_{i}.pt. Both splits share processed_dir, so
# without it the test split would look for the training files data_{i}.pt, find them,
# skip processing and silently serve training graphs as the test set.
test_dataset = MoleculeDataset(root="data/", filename="HIV_test.csv", test=True)

In [None]:
# Loading GNN model: the first layer's input width is the node-feature size of the graphs
num_node_features = train_dataset[0].x.shape[1]
model = GNN(feature_size=num_node_features).to(device)
print(f"Number of parameter: {count_parameters(model)}")
model

In [None]:
# Loss and Optimizer
# Class weights for CrossEntropyLoss: label 0 -> 1.0, label 1 (HIV inhibitor) -> 10.0.
# NOTE(review): the training CSV is already oversampled, so this weighting stacks on
# top of that rebalancing; confirm the 10x factor is still intended.
weights = torch.tensor([1, 10], dtype=torch.float32, device=device)
loss_fn = torch.nn.CrossEntropyLoss(weight=weights)
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.1, momentum=0.9)
# Multiplies the learning rate by 0.95 on every scheduler.step() (once per epoch)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.95)

In [None]:
# Prepare training
NUM_GRAPHS_PER_BATCH = 256
train_loader = DataLoader(train_dataset, batch_size=NUM_GRAPHS_PER_BATCH, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=NUM_GRAPHS_PER_BATCH, shuffle=True)

In [None]:
# train function
def train(epoch):
    """Run one training epoch over ``train_loader`` and log its metrics.

    Uses the module-level ``model``, ``optimizer``, ``loss_fn``, ``train_loader``
    and ``device``. The caller puts the model in training mode.

    Args:
        epoch: epoch number, forwarded to ``calculate_metrics`` as the logging step.

    Returns:
        0-d tensor with the mean of the per-batch losses over the whole epoch.
    """
    all_preds = []
    all_labels = []
    running_loss = 0.0
    num_batches = 0
    for batch in tqdm(train_loader):
        batch = batch.to(device)
        # Reset gradients
        optimizer.zero_grad()
        # Passing node features and connection info
        pred = model(batch.x.float(), batch.edge_attr.float(), batch.edge_index, batch.batch)
        # NOTE(review): taking sqrt of a cross-entropy value is unusual (it looks carried
        # over from an RMSE setup); kept so train and test losses stay on the same scale.
        loss = torch.sqrt(loss_fn(pred, batch.y))
        loss.backward()
        # Update using gradients
        optimizer.step()

        running_loss += loss.detach()
        num_batches += 1
        all_preds.append(np.argmax(pred.detach().cpu().numpy(), axis=1))
        all_labels.append(batch.y.cpu().detach().numpy())

    all_preds = np.concatenate(all_preds).ravel()
    all_labels = np.concatenate(all_labels).ravel()
    calculate_metrics(all_preds, all_labels, epoch, "train")

    # Epoch mean instead of only the last batch's loss
    return running_loss / num_batches

def test(epoch):
    """Evaluate the model on ``test_loader`` and log metrics under the "test" tag.

    Uses the module-level ``model``, ``loss_fn``, ``test_loader`` and ``device``.
    The caller puts the model in evaluation mode.

    Args:
        epoch: epoch number, forwarded to ``calculate_metrics`` as the logging step.

    Returns:
        0-d tensor with the mean of the per-batch losses over the test set.
    """
    all_preds = []
    all_labels = []
    running_loss = 0.0
    num_batches = 0
    # Evaluation needs no gradients; skipping autograd saves memory and time.
    with torch.no_grad():
        for batch in test_loader:
            batch = batch.to(device)
            # Passing node features and connection info
            pred = model(batch.x.float(), batch.edge_attr.float(), batch.edge_index, batch.batch)
            # Same (sqrt) loss definition as in train() so the two are comparable
            loss = torch.sqrt(loss_fn(pred, batch.y))
            running_loss += loss
            num_batches += 1
            all_preds.append(np.argmax(pred.cpu().numpy(), axis=1))
            all_labels.append(batch.y.cpu().numpy())

    all_preds = np.concatenate(all_preds).ravel()
    all_labels = np.concatenate(all_labels).ravel()
    # Log as "test" so the test ROC-AUC does not overwrite the train metric
    calculate_metrics(all_preds, all_labels, epoch, "test")

    return running_loss / num_batches

def calculate_metrics(y_pred, y_true, epoch, type, y_score=None):
    """Print classification metrics and log the ROC-AUC to MLflow.

    Args:
        y_pred: predicted class labels.
        y_true: ground-truth class labels.
        epoch: step under which the MLflow metric is logged.
        type: split tag used in the metric key, e.g. "train" -> "ROC-AUC-train".
            (The name shadows the builtin but is kept for keyword-call compatibility.)
        y_score: optional positive-class scores/probabilities. When given, the
            ROC-AUC is computed from them. Otherwise it is computed from the hard
            labels ``y_pred``, which for binary labels equals balanced accuracy.
    """
    print(f"\n Confusion matrix: \n {confusion_matrix(y_true, y_pred)}")
    print(f"F1 score: {f1_score(y_true, y_pred)}")
    print(f"Accuracy: {accuracy_score(y_true, y_pred)}")
    print(f"Precision: {precision_score(y_true, y_pred)}")
    print(f"Recall: {recall_score(y_true, y_pred)}")

    try:
        roc = roc_auc_score(y_true, y_pred if y_score is None else y_score)
    except ValueError:
        # roc_auc_score raises ValueError when y_true contains only one class
        mlflow.log_metric(key=f"ROC-AUC-{type}", value=float(0), step=epoch)
        print("ROC AUC: notdefined")
        return

    print(f"ROC AUC: {roc}")
    mlflow.log_metric(key=f"ROC-AUC-{type}", value=float(roc), step=epoch)

In [None]:
# Run the training
with mlflow.start_run() as run:
    for epoch in range(50):
        # set as training
        model.train()
        # call train fn
        loss = train(epoch=epoch)
        loss = loss.cpu().detach().numpy()
        print(f"Epoch {epoch}   | Train loss {loss}")
        mlflow.log_metric(key="Train loss", value=float(loss), step=epoch)

        # set as testing
        model.eval()
        if epoch % 5 == 0:
            # call test fn
            loss = test(epoch=epoch)
            print(f"Epoch {epoch}   | Test loss {loss}")
            mlflow.log_metric(key="Test loss", value=float(loss), step=epoch)
        
        scheduler.step()
    
    print('Done')

In [None]:
# Save the model
# Re-open the training run (``run`` from the training cell) so the model artifact is
# stored next to that run's metrics. Logging outside an active run would instead
# start a new, separate MLflow run.
with mlflow.start_run(run_id=run.info.run_id):
    mlflow.pytorch.log_model(model, "model")