### Importing libraries

In [1]:
### Author: Andrea Mastropietro © All rights reserved

import os

import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import Linear, GraphConv, global_add_pool
import torch.nn.functional as F

import random
import numpy as np
from sklearn.preprocessing import RobustScaler

import json
import networkx as nx
import pandas as pd
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

from src.utils import create_edge_index, ChemicalDataset 
from src.edgeshaper import Edgeshaper

In [2]:
# Run on a CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print("Working on device: ", device)

Working on device:  cuda


### Set random seeds

In [3]:
SEED = 42

# Seed every RNG this notebook draws from (torch CPU, torch CUDA, numpy legacy global,
# Python `random`), in that order, so sampling and initialisation are repeatable.
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed, np.random.seed, random.seed):
    _seed_fn(SEED)
del _seed_fn

### Load data and dataset definition

In [4]:
DATA_PATH = 'data/interaction_affinity_data/'

# Raw labels keyed by interaction id (values presumably one affinity number each —
# the displayed frame below shows one 'affinity' column).
with open(DATA_PATH + 'interaction_affinities.json', 'r') as affinity_file:
    raw_affinities = json.load(affinity_file)

affinities_df = pd.DataFrame.from_dict(raw_affinities, orient='index', columns=['affinity'])
display(affinities_df.head())

# Re-key the labels as {interaction_id: {"affinity": value}}, ordered by increasing affinity.
affinities_df = affinities_df.sort_values(by="affinity", ascending=True)
interaction_affinities = affinities_df.to_dict(orient='index')

# One-hot node descriptors: the position of an atom type in this tuple is the index
# of its single 1 entry. Any atom type not listed here has no descriptor.
_ATOM_TYPES = ("CA", "NZ", "N", "OG", "O", "CZ", "OD1", "ZN")
descriptors_interaction_dict = {
    atom_type: [1 if column == position else 0 for column in range(len(_ATOM_TYPES))]
    for position, atom_type in enumerate(_ATOM_TYPES)
}

num_node_features = len(descriptors_interaction_dict["CA"])

def generate_pli_dataset_dict(data_path):
    """Build per-interaction graph tensors from ``<id>/<id>_interaction_graph.json`` files.

    Only sub-directories of ``data_path`` whose name is a key of the global
    ``interaction_affinities`` are loaded. Each loaded interaction maps to a dict with:

    - ``"networkx_graph"``: ``nx.Graph`` whose nodes carry ``atom_type`` (JSON ``attype``)
      and ``origin`` (JSON ``pl``), and whose edges carry ``weight`` (float of JSON ``length``);
    - ``"edge_index"`` / ``"edge_weight"``: result of ``create_edge_index(G, weighted=True)``;
    - ``"x"``: float tensor of shape (num_nodes, num_node_features) holding each node's
      descriptor from the global ``descriptors_interaction_dict``;
    - ``"y"``: one-element float tensor with the interaction's affinity.

    Args:
        data_path: directory containing one sub-directory per interaction.

    Returns:
        dict mapping interaction name -> per-interaction dict described above.

    Raises:
        KeyError: if a node's atom type has no entry in ``descriptors_interaction_dict``.
    """
    dataset_dict = {}
    for interaction_name in tqdm(os.listdir(data_path)):
        if interaction_name not in interaction_affinities:
            continue
        interaction_dir = os.path.join(data_path, interaction_name)
        if not os.path.isdir(interaction_dir):
            continue

        graph_path = os.path.join(interaction_dir, interaction_name + "_interaction_graph.json")
        with open(graph_path, 'r') as f:
            data = json.load(f)

        # add_node stores the attributes directly, so no separate attribute pass is needed.
        G = nx.Graph()
        for node in data['nodes']:
            G.add_node(node["id"], atom_type=node["attype"], origin=node["pl"])
        for edge in data['edges']:
            # Edges with a missing endpoint are skipped.
            if edge["id1"] is not None and edge["id2"] is not None:
                G.add_edge(edge["id1"], edge["id2"], weight=float(edge["length"]))

        edge_index, edge_weight = create_edge_index(G, weighted=True)

        # NOTE(review): rows are addressed by node id, which assumes ids are the
        # integers 0..num_nodes-1 — confirm against the graph JSON files.
        x = torch.zeros((G.number_of_nodes(), num_node_features), dtype=torch.float)
        for node, atom_type in G.nodes(data="atom_type"):
            x[node] = torch.tensor(descriptors_interaction_dict[atom_type], dtype=torch.float)

        dataset_dict[interaction_name] = {
            "networkx_graph": G,
            "edge_index": edge_index,
            "edge_weight": edge_weight,
            "x": x,
            ## gather label
            "y": torch.FloatTensor([interaction_affinities[interaction_name]["affinity"]]),
        }

    return dataset_dict

pli_dataset_dict = generate_pli_dataset_dict(DATA_PATH + "/dataset/")

# Fit one RobustScaler on the pooled edge weights of every graph.
# NOTE(review): the pool includes hold-out graphs, so the hold-out split influences the
# scaling statistics. Kept as-is because the loaded checkpoint was presumably trained
# with this exact scaling — confirm against the training notebook before changing it.
all_edge_weights = [weight for graph in pli_dataset_dict.values() for weight in graph["edge_weight"]]
transformer = RobustScaler().fit(np.array(all_edge_weights).reshape(-1, 1))

# Replace each graph's raw weights with their scaled float32 values.
for key in tqdm(pli_dataset_dict):
    weight_column = np.array(pli_dataset_dict[key]["edge_weight"]).reshape(-1, 1)
    pli_dataset_dict[key]["edge_weight"] = torch.tensor(transformer.transform(weight_column)[:, 0], dtype=torch.float)

# Wrap every interaction as a PyG Data object; weights are dropped when EDGE_WEIGHT is False.
EDGE_WEIGHT = True
data_list = []
for interaction_name, entry in tqdm(pli_dataset_dict.items()):
    data_list.append(
        Data(
            x=entry["x"],
            edge_index=entry["edge_index"],
            edge_weight=entry["edge_weight"] if EDGE_WEIGHT else None,
            y=entry["y"],
            networkx_graph=entry["networkx_graph"],
            interaction_name=interaction_name,
        )
    )

dataset = ChemicalDataset(".", data_list = data_list)

Unnamed: 0,affinity
10gs,6.4
11gs,5.82
13gs,4.62
16pk,5.22
184l,4.72


  0%|          | 0/14215 [00:00<?, ?it/s]

  0%|          | 0/14215 [00:00<?, ?it/s]

  0%|          | 0/14215 [00:00<?, ?it/s]

### Gather test data split

In [5]:
# One interaction id per line; surrounding whitespace/newlines are stripped.
with open(DATA_PATH + "data_splits/hold_out_set.csv", 'r') as f:
    hold_out_interactions = [line.strip() for line in f]

# Set lookup is O(1); testing membership against the list cost O(len(list)) per sample.
hold_out_ids = set(hold_out_interactions)

# Fetch each sample only once (dataset[i] may build a fresh Data object per call) and
# keep dataset order, so the seeded shuffle below stays reproducible.
hold_out_data = []
for i in range(len(dataset)):
    sample = dataset[i]
    if sample.interaction_name in hold_out_ids:
        hold_out_data.append(sample)

assert hold_out_data, "no dataset sample matched the hold-out interaction list"

rng = np.random.default_rng(seed = SEED)
rng.shuffle(hold_out_data)

BATCH_SIZE = 32
hold_out_loader = DataLoader(hold_out_data, batch_size=BATCH_SIZE)

### Load model

In [6]:
class GC_GNN(torch.nn.Module):
    """Graph-level regressor built from stacked max-aggregation GraphConv layers.

    Architecture: conv1 (node_features_dim -> hidden_channels), conv2..conv7
    (hidden_channels -> hidden_channels), ReLU after every convolution except the
    last, sum pooling per graph, dropout, and a linear head with ``num_classes`` outputs.

    Submodules keep the attribute names ``conv1`` ... ``conv7`` and ``lin``, so the
    state_dict keys do not depend on how the layers are constructed.
    """

    _NUM_CONV_LAYERS = 7

    def __init__(self, node_features_dim, hidden_channels, num_classes):
        super().__init__()
        # Only the first layer reads the raw node descriptors; the rest map hidden -> hidden.
        layer_input_dims = [node_features_dim] + [hidden_channels] * (self._NUM_CONV_LAYERS - 1)
        for layer_number, in_dim in enumerate(layer_input_dims, start=1):
            # nn.Module.__setattr__ registers the submodule under exactly this name.
            setattr(self, f"conv{layer_number}", GraphConv(in_dim, hidden_channels, aggr='max'))
        self.lin = Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index, batch, edge_weight = None):
        """Predict one output vector per graph.

        Args:
            x: node feature matrix.
            edge_index: graph connectivity passed to every GraphConv.
            batch: graph assignment of each node, consumed by global_add_pool.
            edge_weight: optional per-edge weights forwarded to every GraphConv.

        Returns:
            Tensor with ``num_classes`` columns and one row per graph in ``batch``.
        """
        for layer_number in range(1, self._NUM_CONV_LAYERS + 1):
            conv = getattr(self, f"conv{layer_number}")
            x = conv(x, edge_index, edge_weight=edge_weight)
            # The final convolution feeds the pooling step without a non-linearity.
            if layer_number < self._NUM_CONV_LAYERS:
                x = F.relu(x)

        graph_embedding = global_add_pool(x, batch)
        # Dropout is only active in training mode (default p=0.5).
        graph_embedding = F.dropout(graph_embedding, training=self.training)
        return self.lin(graph_embedding)

In [7]:
MODEL_PATH = "models/gc_gnn_model.ckpt"
HIDDEN_CHANNELS = 256
criterion = torch.nn.MSELoss()

# Input width follows the node descriptor length; one output for affinity regression.
model = GC_GNN(node_features_dim = dataset[0].x.shape[1], num_classes = 1, hidden_channels=HIDDEN_CHANNELS).to(device)

# map_location remaps stored tensors onto `device`, so a checkpoint saved from a GPU also
# loads on a CPU-only machine (without it torch.load restores tensors onto the device they
# were saved from and fails when that device is unavailable).
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (consider weights_only=True if the installed PyTorch supports it).
checkpoint_state = torch.load(MODEL_PATH, map_location=device)
model.load_state_dict(checkpoint_state)
model.eval()
model

GC_GNN(
  (conv1): GraphConv(8, 256)
  (conv2): GraphConv(256, 256)
  (conv3): GraphConv(256, 256)
  (conv4): GraphConv(256, 256)
  (conv5): GraphConv(256, 256)
  (conv6): GraphConv(256, 256)
  (conv7): GraphConv(256, 256)
  (lin): Linear(256, 1, bias=True)
)

In [8]:
def test(loader):
    model.eval()

    sum_loss = 0
    for data in loader: 
        data = data.to(device)
        
        out = model(data.x, data.edge_index, data.batch, edge_weight = data.edge_weight)  
        
        if  data.y.shape[0] == 1:
            loss = torch.sqrt(criterion(torch.squeeze(out, 1), data.y))
        else:
            loss = torch.sqrt(criterion(torch.squeeze(out), data.y)) * data.y.shape[0]
        sum_loss += loss.item()
        
    return sum_loss / len(loader.dataset) 

In [9]:
# Score the restored checkpoint on the hold-out split.
hold_out_set_rmse = test(hold_out_loader)
rmse_report = f'Hold-out set RMSE with loaded model: {hold_out_set_rmse:.4f}'
print(rmse_report)

Hold-out set RMSE with loaded model: 1.6386


### Explaining the predictions using EdgeSHAPer

In [10]:
num_all_test_interactions = len(hold_out_data)

# Seeded random visiting order over the hold-out samples; the explanation loop
# uses only the first SAMPLES_TO_EXPLAIN positions.
all_test_interaction_indices = np.arange(num_all_test_interactions)
rng = np.random.default_rng(seed=SEED)
rng.shuffle(all_test_interaction_indices)

display(all_test_interaction_indices)

SAMPLES_TO_EXPLAIN = 5

array([1103,   77, 3258, ..., 1633, 3021, 1672])

In [11]:
EDGE_WEIGHT = True
TARGET_CLASS = None  # regression problem: no class index to explain
TOP_K_EDGES = 25
EXPLANATION_FOLDER = "results/protein_ligand_affinity_explanations_edgeshaper/"
EDGESHAPER_M = 100  # value of M passed to Edgeshaper.explain (presumably Monte Carlo samples)


def _aggregate_bond_phi(edge_index, phi_edges, bond_to_index):
    """Map directed-edge Shapley values onto the undirected bonds of the networkx graph.

    Args:
        edge_index: 2-row index tensor; column i is the directed edge scored by phi_edges[i].
        phi_edges: one Shapley value per column of edge_index.
        bond_to_index: maps each (u, v) edge tuple of the networkx graph to its position.

    Returns:
        list with one summed value per bond: bond (u, v) receives the values of both
        u->v and v->u. Directed edges matching no bond are ignored.
    """
    bond_phi = [0] * len(bond_to_index)
    for i, phi_value in enumerate(phi_edges):
        init_atom = edge_index[0][i].item()
        end_atom = edge_index[1][i].item()
        # elif: a self-loop (u, u) matches both lookups and must be counted only once.
        if (init_atom, end_atom) in bond_to_index:
            bond_phi[bond_to_index[(init_atom, end_atom)]] += phi_value
        elif (end_atom, init_atom) in bond_to_index:
            bond_phi[bond_to_index[(end_atom, init_atom)]] += phi_value
    return bond_phi


def _style_bonds(bonds, top_bond_indices, atoms_origin):
    """Pick a draw colour/width for every bond and count top-k bonds by kind.

    A top-k bond is "protein" when both atoms have origin "P", "ligand" when both have
    origin "L", and "interaction" otherwise. Other bonds are drawn thin and light grey.

    Returns:
        (edge_colors, edge_widths, num_protein, num_ligand, num_interaction); the two
        lists are aligned with ``bonds``.
    """
    edge_colors, edge_widths = [], []
    num_protein = num_ligand = num_interaction = 0
    for bond_index, (init_atom, end_atom) in enumerate(bonds):
        if bond_index not in top_bond_indices:
            edge_colors.append("lightgrey")
            edge_widths.append(1.5)
            continue
        if atoms_origin[init_atom] == "P" and atoms_origin[end_atom] == "P":
            num_protein += 1
            edge_colors.append("darkblue")
        elif atoms_origin[init_atom] == "L" and atoms_origin[end_atom] == "L":
            num_ligand += 1
            edge_colors.append("darkred")
        else:
            num_interaction += 1
            edge_colors.append("darkgreen")
        edge_widths.append(3)
    return edge_colors, edge_widths, num_protein, num_ligand, num_interaction


protein_edges_important = []
ligand_edges_important = []
interaction_edges_important = []

for index in tqdm(all_test_interaction_indices[:SAMPLES_TO_EXPLAIN]):
    model.eval()
    test_interaction = hold_out_data[index]
    interaction_id = test_interaction.interaction_name
    print("\nInteraction: " + interaction_id)

    # The same weights (or None) feed both the reported prediction and EdgeSHAPer, so the
    # explanation always refers to the prediction written to the statistics file.
    edge_weight_sample = test_interaction.edge_weight if EDGE_WEIGHT else None
    edge_weight_to_pass = None if edge_weight_sample is None else edge_weight_sample.to(device)

    # A single graph: every node belongs to graph 0 of the batch.
    batch = torch.zeros(test_interaction.x.shape[0], dtype=torch.long, device=device)
    with torch.no_grad():  # prediction only; no gradients needed
        out = model(test_interaction.x.to(device), test_interaction.edge_index.to(device), batch=batch, edge_weight=edge_weight_to_pass)

    edgeshaper_explainer = Edgeshaper(model, test_interaction.x, test_interaction.edge_index, edge_weight = edge_weight_sample, device = device)
    phi_edges = edgeshaper_explainer.explain(M = EDGESHAPER_M, target_class = TARGET_CLASS, seed = SEED)

    SAVE_PATH = EXPLANATION_FOLDER + interaction_id + "/"
    os.makedirs(SAVE_PATH, exist_ok=True)
    statistics_path = SAVE_PATH + interaction_id + "_statistics.txt"

    with open(statistics_path, "w+") as f:
        f.write("Interaction name: " + interaction_id + "\n\n")
        f.write("Affinity: " + str(test_interaction.y.item()) + "\n")
        f.write("Predicted value: " + str(out.item()) + "\n\n")

        f.write("Shapley values for edges: \n\n")
        for i in range(len(phi_edges)):
            f.write("(" + str(test_interaction.edge_index[0][i].item()) + "," + str(test_interaction.edge_index[1][i].item()) + "): " + str(phi_edges[i]) + "\n\n")

    #plotting
    print("Saving explanations visualization...")

    G = test_interaction.networkx_graph
    bonds = list(G.edges())
    bond_to_index = {bond: i for i, bond in enumerate(bonds)}
    bond_phi = _aggregate_bond_phi(test_interaction.edge_index, phi_edges, bond_to_index)

    # Indices of the TOP_K_EDGES bonds with the largest |phi| (a set for O(1) membership).
    top_edges = set(np.argsort(-np.abs(bond_phi))[:TOP_K_EDGES].tolist())

    colors = ["red" if G.nodes[node]["origin"] == "L" else "lightblue" for node in G.nodes]
    atoms_origin = nx.get_node_attributes(G, 'origin')
    (edges_colors, edges_widths, num_protein_edges_important,
     num_ligand_edges_important, num_interaction_edges_important) = _style_bonds(bonds, top_edges, atoms_origin)

    protein_edges_important.append(num_protein_edges_important)
    ligand_edges_important.append(num_ligand_edges_important)
    interaction_edges_important.append(num_interaction_edges_important)

    with open(statistics_path, "a+") as f:
        f.write("Number of important protein edges: " + str(num_protein_edges_important) + "\n")
        f.write("Number of important ligand edges: " + str(num_ligand_edges_important) + "\n")
        f.write("Number of important interaction edges: " + str(num_interaction_edges_important) + "\n")

    # Seeded layout: figures are reproducible, and both images share node positions.
    pos = nx.spring_layout(G, seed=SEED)
    atom_labels = nx.get_node_attributes(G, 'atom_type')

    #draw graph with important edges
    fig = plt.figure(figsize=(10, 10))
    nx.draw(G, pos=pos, node_size = 400, with_labels=True, font_weight='bold', labels=atom_labels, node_color=colors, edge_color=edges_colors, width=edges_widths)
    fig.savefig(SAVE_PATH + interaction_id + "_EdgeSHAPer_top_" + str(TOP_K_EDGES) + "_edges.png", dpi=300)
    plt.close(fig)

    #save original interaction graph
    fig = plt.figure(figsize=(10, 10))
    nx.draw(G, pos=pos, node_size = 400, with_labels=True, font_weight='bold', labels=atom_labels, node_color=colors)
    fig.savefig(SAVE_PATH + interaction_id + "_interaction_graph.png", dpi=300)
    plt.close(fig)

print("\nAverage number of important protein edges: ", np.mean(protein_edges_important))
print("Average number of important ligand edges: ", np.mean(ligand_edges_important))
print("Average number of important interaction edges: ", np.mean(interaction_edges_important))

  0%|          | 0/5 [00:00<?, ?it/s]


Interaction: 6gfz
No target class specified. Regression model assumed.


100%|██████████| 248/248 [05:03<00:00,  1.22s/it]


Saving explanations visualization...

Interaction: 1o4q
No target class specified. Regression model assumed.


100%|██████████| 248/248 [04:32<00:00,  1.10s/it]


Saving explanations visualization...

Interaction: 5a3r
No target class specified. Regression model assumed.


100%|██████████| 372/372 [08:44<00:00,  1.41s/it]


Saving explanations visualization...

Interaction: 5nf5
No target class specified. Regression model assumed.


100%|██████████| 198/198 [03:34<00:00,  1.09s/it]


Saving explanations visualization...

Interaction: 5oui
No target class specified. Regression model assumed.


100%|██████████| 108/108 [01:37<00:00,  1.11it/s]


Saving explanations visualization...
Average number of important protein edges:  2.6
Average number of important ligand edges:  9.2
Average number of important interaction edges:  13.2
