### Importing libraries

In [1]:
###Author: Andrea Mastropietro © All rights reserved

import os

import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import Linear, GraphConv, global_add_pool
import torch.nn.functional as F

import random
import numpy as np
from sklearn.preprocessing import RobustScaler

import json
import networkx as nx
import pandas as pd
from tqdm.auto import tqdm

from src.utils import create_edge_index, ChemicalDataset 

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Working on device: ", device)

Working on device:  cuda


### Set random seeds

In [3]:
SEED = 42

torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

### Load data and affinity information

In [4]:
DATA_PATH = 'data/interaction_affinity_data/'
interaction_affinities = None

with open(DATA_PATH + 'interaction_affinities.json', 'r') as fp:
    interaction_affinities = json.load(fp)

affinities_df = pd.DataFrame.from_dict(interaction_affinities, orient='index', columns=['affinity'])

display(affinities_df.head())

affinities_df = affinities_df.sort_values(by = "affinity", ascending=True)
interaction_affinities = affinities_df.to_dict(orient='index')

Unnamed: 0,affinity
10gs,6.4
11gs,5.82
13gs,4.62
16pk,5.22
184l,4.72


### Define node features

In [5]:
descriptors_interaction_dict = None
num_node_features = 0

descriptors_interaction_dict = {}
descriptors_interaction_dict["CA"] = [1, 0, 0, 0, 0, 0, 0, 0]
descriptors_interaction_dict["NZ"] = [0, 1, 0, 0, 0, 0, 0, 0]
descriptors_interaction_dict["N"] = [0, 0, 1, 0, 0, 0, 0, 0]
descriptors_interaction_dict["OG"] = [0, 0, 0, 1, 0, 0, 0, 0]
descriptors_interaction_dict["O"] = [0, 0, 0, 0, 1, 0, 0, 0]
descriptors_interaction_dict["CZ"] = [0, 0, 0, 0, 0, 1, 0, 0]
descriptors_interaction_dict["OD1"] = [0, 0, 0, 0, 0, 0, 1, 0]
descriptors_interaction_dict["ZN"] = [0, 0, 0, 0, 0, 0, 0, 1]

num_node_features = len(descriptors_interaction_dict["CA"])

### Dataset generation

In [6]:
def generate_pli_dataset_dict(data_path):

    directory = os.fsencode(data_path)

    dataset_dict = {}
    dirs = os.listdir(directory)
    for file in tqdm(dirs):
        interaction_name = os.fsdecode(file)

        if interaction_name in interaction_affinities:
            if os.path.isdir(data_path + interaction_name):
                dataset_dict[interaction_name] = {}
                G = None
                with open(data_path + interaction_name + "/" + interaction_name + "_interaction_graph.json", 'r') as f:
                    data = json.load(f)
                    G = nx.Graph()

                    for node in data['nodes']:
                        G.add_node(node["id"], atom_type=node["attype"], origin=node["pl"]) 

                    for edge in data['edges']:
                        if edge["id1"] != None and edge["id2"] != None:
                            G.add_edge(edge["id1"], edge["id2"], weight= float(edge["length"]))
                            

                    for node in data['nodes']:
                        nx.set_node_attributes(G, {node["id"]: node["attype"]}, "atom_type")
                        nx.set_node_attributes(G, {node["id"]: node["pl"]}, "origin")

                    
                    
                dataset_dict[interaction_name]["networkx_graph"] = G
                edge_index, edge_weight = create_edge_index(G, weighted=True)

                dataset_dict[interaction_name]["edge_index"] = edge_index
                dataset_dict[interaction_name]["edge_weight"] = edge_weight
                

                num_nodes = G.number_of_nodes()
                
                
                
                dataset_dict[interaction_name]["x"] = torch.zeros((num_nodes, num_node_features), dtype=torch.float)
                for node in G.nodes:
                    dataset_dict[interaction_name]["x"][node] = torch.tensor(descriptors_interaction_dict[G.nodes[node]["atom_type"]], dtype=torch.float)
                    
                ## gather label
                dataset_dict[interaction_name]["y"] = torch.FloatTensor([interaction_affinities[interaction_name]["affinity"]])

    
    return dataset_dict

In [7]:
pli_dataset_dict = generate_pli_dataset_dict(DATA_PATH + "/dataset/")

  0%|          | 0/14215 [00:00<?, ?it/s]

### Scaling edge weights (distance in Angstrom - Å)

In [8]:
first_level = [pli_dataset_dict[key]["edge_weight"] for key in pli_dataset_dict]
second_level = [item for sublist in first_level for item in sublist]

transformer = RobustScaler().fit(np.array(second_level).reshape(-1, 1))

for key in tqdm(pli_dataset_dict):
    scaled_weights = transformer.transform(np.array(pli_dataset_dict[key]["edge_weight"]).reshape(-1, 1))
    scaled_weights = [x[0] for x in scaled_weights]
    pli_dataset_dict[key]["edge_weight"] = torch.FloatTensor(scaled_weights)
    

  0%|          | 0/14215 [00:00<?, ?it/s]

### Define data list

In [9]:
data_list = []
EDGE_WEIGHT = True
for interaction_name in tqdm(pli_dataset_dict):
    edge_weight_sample = None
    if EDGE_WEIGHT:
        edge_weight_sample = pli_dataset_dict[interaction_name]["edge_weight"]
    data_list.append(Data(x = pli_dataset_dict[interaction_name]["x"], edge_index = pli_dataset_dict[interaction_name]["edge_index"], edge_weight = edge_weight_sample, y = pli_dataset_dict[interaction_name]["y"], networkx_graph = pli_dataset_dict[interaction_name]["networkx_graph"], interaction_name = interaction_name))

  0%|          | 0/14215 [00:00<?, ?it/s]

### Instantiate dataset

In [10]:
dataset = ChemicalDataset(".", data_list = data_list)

### Gather train/val/test splits

In [11]:
train_interactions = []
val_interactions = []
core_set_interactions = []
hold_out_interactions = []

with open(DATA_PATH + "data_splits/training_set.csv", 'r') as f:
    train_interactions = f.readlines()

train_interactions = [interaction.strip() for interaction in train_interactions]

with open(DATA_PATH + "data_splits/validation_set.csv", 'r') as f:
    val_interactions = f.readlines()

val_interactions = [interaction.strip() for interaction in val_interactions]

with open(DATA_PATH + "data_splits/core_set.csv", 'r') as f:
    core_set_interactions = f.readlines()

core_set_interactions = [interaction.strip() for interaction in core_set_interactions]

with open(DATA_PATH + "data_splits/hold_out_set.csv", 'r') as f:
    hold_out_interactions = f.readlines()

hold_out_interactions = [interaction.strip() for interaction in hold_out_interactions]

train_data = [dataset[i] for i in range(len(dataset)) if dataset[i].interaction_name in train_interactions]
val_data = [dataset[i] for i in range(len(dataset)) if dataset[i].interaction_name in val_interactions]
core_set_data = [dataset[i] for i in range(len(dataset)) if dataset[i].interaction_name in core_set_interactions]
hold_out_data = [dataset[i] for i in range(len(dataset)) if dataset[i].interaction_name in hold_out_interactions]

rng = np.random.default_rng(seed = SEED)
rng.shuffle(train_data)
rng.shuffle(val_data)
rng.shuffle(core_set_data)
rng.shuffle(hold_out_data)

print("Number of training samples: ", len(train_data))
print("Number of validation samples: ", len(val_data))
print("Number of core set samples: ", len(core_set_data))
print("Number of hold-out samples: ", len(hold_out_data))

BATCH_SIZE = 32

train_loader = DataLoader(train_data, batch_size=BATCH_SIZE)
val_loader = DataLoader(val_data, batch_size=BATCH_SIZE)
core_set_loader = DataLoader(core_set_data, batch_size=BATCH_SIZE)
hold_out_loader = DataLoader(hold_out_data, batch_size=BATCH_SIZE)

Number of training samples:  9662
Number of validation samples:  903
Number of core set samples:  257
Number of hold-out samples:  3393


### Define the GNN model - GraphConv

In [12]:
class GC_GNN(torch.nn.Module):
    def __init__(self, node_features_dim, hidden_channels, num_classes):
        super().__init__()
        self.conv1 = GraphConv(node_features_dim, hidden_channels, aggr='max')
        self.conv2 = GraphConv(hidden_channels, hidden_channels, aggr='max')
        self.conv3 = GraphConv(hidden_channels, hidden_channels, aggr='max')
        self.conv4 = GraphConv(hidden_channels, hidden_channels, aggr='max')
        self.conv5 = GraphConv(hidden_channels, hidden_channels, aggr='max')
        self.conv6 = GraphConv(hidden_channels, hidden_channels, aggr='max')
        self.conv7 = GraphConv(hidden_channels, hidden_channels, aggr='max')
        self.lin = Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index, batch, edge_weight = None):

        x = F.relu(self.conv1(x, edge_index, edge_weight = edge_weight))
        x = F.relu(self.conv2(x, edge_index, edge_weight = edge_weight))
        x = F.relu(self.conv3(x, edge_index, edge_weight = edge_weight))
        x = F.relu(self.conv4(x, edge_index, edge_weight = edge_weight))
        x = F.relu(self.conv5(x, edge_index, edge_weight = edge_weight))
        x = F.relu(self.conv6(x, edge_index, edge_weight = edge_weight))
        x = self.conv7(x, edge_index, edge_weight = edge_weight)
        
        x = global_add_pool(x, batch)
        
        x = F.dropout(x, training=self.training)
        x = self.lin(x)

        return x

### Train the model

In [13]:
model = GC_GNN(node_features_dim = dataset[0].x.shape[1], num_classes = 1, hidden_channels=256).to(device)

lr = 1e-3
WEIGHT_DECAY = 5e-4

optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=WEIGHT_DECAY)

criterion = torch.nn.MSELoss()
    
epochs = 100

In [14]:
def train():
        model.train()

        for data in train_loader:  # Iterate in batches over the training dataset.
            data = data.to(device)
            out = model(data.x, data.edge_index, data.batch, edge_weight = data.edge_weight)  # Perform a single forward pass.
            
            loss = torch.sqrt(criterion(torch.squeeze(out), data.y))  # Compute the loss.
        
            loss.backward()  # Derive gradients.
            optimizer.step()  # Update parameters based on gradients.
            optimizer.zero_grad()  # Clear gradients.

def test(loader):
    model.eval()

    sum_loss = 0
    for data in loader:  # Iterate in batches over the training/test dataset.
        data = data.to(device)
        
        out = model(data.x, data.edge_index, data.batch, edge_weight = data.edge_weight)  
        
        if  data.y.shape[0] == 1:
            loss = torch.sqrt(criterion(torch.squeeze(out, 1), data.y))
        else:
            loss = torch.sqrt(criterion(torch.squeeze(out), data.y)) * data.y.shape[0]
        sum_loss += loss.item()
        
    return sum_loss / len(loader.dataset) 

### Training loop saving the best model

In [15]:
best_epoch = 0
best_val_loss = 100000

MODEL_SAVE_FOLDER = "models/"
for epoch in tqdm(range(epochs)):
    train()
    train_rmse = test(train_loader)
    val_rmse = test(val_loader)
    if val_rmse < best_val_loss:
        best_val_loss = val_rmse
        best_epoch = epoch
        
        if not os.path.exists(MODEL_SAVE_FOLDER):
            os.makedirs(MODEL_SAVE_FOLDER)

        torch.save(model.state_dict(), MODEL_SAVE_FOLDER + "gc_gnn_model.ckpt")
        
    print(f'Epoch: {epoch:03d}, Train RMSE: {train_rmse:.4f}, Val RMSE: {val_rmse:.4f}')


print(f'Best model at epoch: {best_epoch:03d}')
print("Best val loss: ", best_val_loss)

model = GC_GNN(node_features_dim = dataset[0].x.shape[1], num_classes = 1, hidden_channels=256).to(device)
model.load_state_dict(torch.load(MODEL_SAVE_FOLDER + "gc_gnn_model.ckpt"))
model.to(device)

core_set_rmse = test(core_set_loader)    
print(f'Core set RMSE with best model: {core_set_rmse:.4f}')

hold_out_set_rmse = test(hold_out_loader)    
print(f'Hold-out set RMSE with best model: {hold_out_set_rmse:.4f}')

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch: 000, Train RMSE: 2.0754, Val RMSE: 2.1572
Epoch: 001, Train RMSE: 1.9837, Val RMSE: 2.0637
Epoch: 002, Train RMSE: 1.8993, Val RMSE: 1.9833
Epoch: 003, Train RMSE: 1.8570, Val RMSE: 1.9218
Epoch: 004, Train RMSE: 1.8598, Val RMSE: 1.9150
Epoch: 005, Train RMSE: 1.8208, Val RMSE: 1.8557
Epoch: 006, Train RMSE: 1.7732, Val RMSE: 1.8406
Epoch: 007, Train RMSE: 1.7464, Val RMSE: 1.8101
Epoch: 008, Train RMSE: 1.7407, Val RMSE: 1.8078
Epoch: 009, Train RMSE: 1.7134, Val RMSE: 1.7679
Epoch: 010, Train RMSE: 1.6805, Val RMSE: 1.7406
Epoch: 011, Train RMSE: 1.6835, Val RMSE: 1.7471
Epoch: 012, Train RMSE: 1.6684, Val RMSE: 1.7225
Epoch: 013, Train RMSE: 1.6537, Val RMSE: 1.7174
Epoch: 014, Train RMSE: 1.6857, Val RMSE: 1.7308
Epoch: 015, Train RMSE: 1.6117, Val RMSE: 1.6947
Epoch: 016, Train RMSE: 1.6117, Val RMSE: 1.6928
Epoch: 017, Train RMSE: 1.6335, Val RMSE: 1.7052
Epoch: 018, Train RMSE: 1.6518, Val RMSE: 1.6989
Epoch: 019, Train RMSE: 1.5921, Val RMSE: 1.6918
Epoch: 020, Train RM