In [1]:
import os
from tqdm.notebook import tqdm
import json

import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
import torch.nn.functional as F

import numpy as np
from sklearn.metrics import confusion_matrix, classification_report

from rdkit import Chem
from rdkit.Chem import Draw

import json
import networkx as nx
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns

from src.utils import create_edge_index, PLIDataset, set_all_seeds, GCN, save_model
from src.edgeshaper import Edgeshaper

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
# Fix the random seed before any stochastic step (sampling, training, explanation).
# NOTE(review): set_all_seeds lives in src.utils; presumably it seeds random/numpy/torch — confirm there.
set_all_seeds(42)

In [3]:
# Notebook-wide configuration.
DATA_PATH = 'data/pdbbind/dataset/'    # one sub-directory per complex
SAVE_FOLDER = 'results/explanations/'  # where explanation figures/statistics are written
CLASSIFICATION = False                 # False: affinity regression, True: pseudo-label classification
TRAIN = False                          # False: load a saved checkpoint instead of training

# A single output unit for regression, four quantile classes for classification.
NUM_CLASSES = 4 if CLASSIFICATION else 1

## EDA

Read the ligand mol2 file

In [4]:
# target_mol_name = "1a0t"
# mol = Chem.MolFromMol2File(DATA_PATH + target_mol_name + "/" + target_mol_name + "_ligand.mol2")
# test_mol = Draw.PrepareMolForDrawing(mol)
# test_mol


In [5]:
# num_bonds = len(test_mol.GetBonds())
# num_atoms = len(test_mol.GetAtoms())

# print("Number of bonds: ", num_bonds)
# print("Number of atoms: ", num_atoms)

# rdkit_bonds = {}

# for i in range(num_bonds):
#     init_atom = test_mol.GetBondWithIdx(i).GetBeginAtomIdx()
#     end_atom = test_mol.GetBondWithIdx(i).GetEndAtomIdx()
#     bond_type = test_mol.GetBondWithIdx(i).GetBondType()
#     print("Bond: ", i, " " , init_atom, "-" , end_atom, " ", bond_type)
#     rdkit_bonds[(init_atom, end_atom)] = i
#     #CNC(=O)CN1CN(c2ccccc2)C2(CCN(Cc3cc4c(cc3Cl)OCO4)CC2)C1=O
# # rdkit_bonds

Read the JSON interaction graph

In [6]:
# with open(DATA_PATH + target_mol_name + "/" + target_mol_name + "_interaction_graph.json", 'r') as f:
#   data = json.load(f)

# print("Number of atoms: ", len(data['nodes']))
# print("Number of bonds: ", len(data['edges']))

In [7]:
# print(data['nodes'])

Visualize the interaction graph: protein atoms are drawn in red and ligand atoms in light blue.

In [8]:
# G = nx.Graph()

# for edge in data['edges']:
#     if edge["id1"] != None and edge["id2"] != None:
#         G.add_edge(edge["id1"], edge["id2"], weight=edge["length"])

# for node in data['nodes']:
#     nx.set_node_attributes(G, {node["id"]: node["attype"]}, "atom_type")
#     nx.set_node_attributes(G, {node["id"]: node["pl"]}, "from")

# print("Number of nodes: ", G.number_of_nodes())
# print("Number of edges: ", G.number_of_edges())

# colors = ["red" if G.nodes[node]["from"] == "P" else "lightblue" for node in G.nodes]

# # nx.draw_networkx_edge_labels(G, pos=nx.spring_layout(G), edge_labels=nx.get_edge_attributes(G, 'weight'))
# plt.figure(figsize=(10,10))
# pos = nx.spring_layout(G)
# nx.draw(G, pos=pos, with_labels=True, font_weight='bold', labels=nx.get_node_attributes(G, 'atom_type'), node_color=colors)
# # nx.draw_networkx_edge_labels(G, pos=pos, edge_labels=nx.get_edge_attributes(G, 'weight'))

#### Gather affinity values

## Load dataset

#### Load affinities
The affinities have already been saved to a JSON file, so they can be loaded directly.

In [9]:
# directory = os.fsencode(DATA_PATH)

# data_versions = [2020, 2019, 2016]
# interaction_affinities = {}

# for file in tqdm(os.listdir(directory)):
#     interaction_name = os.fsdecode(file)
#     if os.path.isdir(DATA_PATH + interaction_name):
#         for version in data_versions:
#             INDEX_PATH = "data/pdbbind/PDBbind_v" + str(version) + "_plain_text_index/index/INDEX_general_PL_data." + str(version)
#             with open(INDEX_PATH, 'r') as f:
#                 interaction_info = f.readlines()[5:]
#                 for line in interaction_info:
#                     row = line.strip().split(" ")
#                     if row[0] == interaction_name:
#                         row = list(filter(lambda a: a != "", row)) #remove empty strings
                        
#                         affinity = float(row[3])
                        
#                         interaction_affinities[interaction_name] = affinity
#                         break

# with open(DATA_PATH + '/interaction_affinities.json', 'w+') as fp:
#     json.dump(interaction_affinities, fp, sort_keys=True, indent=4)            

In [10]:
# Load the cached mapping from complex ID to binding-affinity value.
with open(DATA_PATH + '/interaction_affinities.json', 'r') as affinity_file:
    interaction_affinities = json.load(affinity_file)

In [11]:
# Display the loaded mapping (complex ID -> affinity value).
interaction_affinities

{'10gs': 6.4,
 '11gs': 5.82,
 '13gs': 4.62,
 '16pk': 5.22,
 '184l': 4.72,
 '185l': 3.54,
 '186l': 4.85,
 '187l': 3.37,
 '188l': 3.33,
 '1a07': 6.4,
 '1a0t': 1.3,
 '1a28': 8.29,
 '1a2c': 6.3,
 '1a30': 4.3,
 '1a42': 9.89,
 '1a4g': 8.4,
 '1a4h': 5.92,
 '1a4k': 8.0,
 '1a4m': 13.0,
 '1a4q': 5.44,
 '1a4r': 6.66,
 '1a4w': 5.92,
 '1a50': 6.7,
 '1a52': 9.86,
 '1a5g': 10.15,
 '1a5h': 6.3,
 '1a5v': 3.1,
 '1a61': 8.62,
 '1a69': 5.3,
 '1a7t': 1.64,
 '1a7x': 9.7,
 '1a85': 4.52,
 '1a86': 4.0,
 '1a8i': 5.52,
 '1a8t': 5.8,
 '1a94': 7.85,
 '1a99': 5.7,
 '1a9m': 6.92,
 '1a9q': 6.17,
 '1a9u': 7.32,
 '1aaq': 8.4,
 '1abf': 5.42,
 '1abt': 5.85,
 '1acj': 8.09,
 '1add': 6.74,
 '1adl': 5.36,
 '1ado': 6.0,
 '1af6': 1.82,
 '1afk': 6.62,
 '1afl': 6.28,
 '1ag9': 9.0,
 '1agm': 12.0,
 '1agw': 4.7,
 '1ai4': 2.5,
 '1ai5': 3.72,
 '1ai6': 3.97,
 '1ai7': 4.09,
 '1aid': 4.82,
 '1aj7': 3.87,
 '1ajn': 2.63,
 '1ajp': 2.23,
 '1ajq': 4.31,
 '1ajv': 7.72,
 '1ajx': 7.91,
 '1al7': 7.8,
 '1al8': 5.32,
 '1alw': 6.52,
 '1amk': 4.3,
 

In [12]:
# Flat list of all affinity values, used for the histogram.
vals = list(interaction_affinities.values())

In [13]:
# min(vals), max(vals)
# Distribution of the affinity values over 100 equal-width bins.
plt.hist(vals, bins=100)

(array([  4.,   1.,   3.,   0.,   7.,   3.,  12.,  11.,  18.,  19.,  31.,
         34.,  69.,  48.,  50.,  76.,  82., 147., 102., 162., 136., 183.,
        190., 164., 252., 223., 249., 269., 277., 331., 269., 353., 326.,
        353., 286., 349., 342., 409., 372., 425., 398., 432., 460., 407.,
        337., 428., 422., 434., 425., 450., 334., 464., 326., 357., 267.,
        154., 282., 135., 239., 128., 140.,  86.,  74.,  49.,  71.,  29.,
         30.,  22.,  35.,  21.,  22.,  23.,  10.,  27.,   6.,  10.,   9.,
          6.,   9.,   1.,   5.,   0.,   2.,   1.,   0.,   5.,   0.,   1.,
          0.,   0.,   0.,   1.,   0.,   0.,   1.,   0.,   0.,   0.,   2.,
          1.]),
 array([ 0.4   ,  0.5482,  0.6964,  0.8446,  0.9928,  1.141 ,  1.2892,
         1.4374,  1.5856,  1.7338,  1.882 ,  2.0302,  2.1784,  2.3266,
         2.4748,  2.623 ,  2.7712,  2.9194,  3.0676,  3.2158,  3.364 ,
         3.5122,  3.6604,  3.8086,  3.9568,  4.105 ,  4.2532,  4.4014,
         4.5496,  4.6978,  4.846 ,

In [14]:
# One row per complex, ordered by increasing affinity.
affinities_df = (
    pd.DataFrame.from_dict(interaction_affinities, orient='index', columns=['affinity'])
    .sort_values(by="affinity", ascending=True)
)

# Quantile-based pseudo labels: NUM_CLASSES equally populated bins labelled 0..NUM_CLASSES-1.
# (Fixed-threshold alternative: 0 if x < 4, 2 if x > 8, else 1.)
affinities_df["pseudo_label"] = pd.qcut(
    x=affinities_df["affinity"],
    q=NUM_CLASSES,
    labels=list(range(NUM_CLASSES)),
)

# From here on each entry is {'affinity': ..., 'pseudo_label': ...} instead of a bare float.
interaction_affinities = affinities_df.to_dict(orient='index')
interaction_affinities

{'3zzf': {'affinity': 0.4, 'pseudo_label': 0},
 '3gww': {'affinity': 0.45, 'pseudo_label': 0},
 '3fqa': {'affinity': 0.49, 'pseudo_label': 0},
 '1w8l': {'affinity': 0.49, 'pseudo_label': 0},
 '1zsb': {'affinity': 0.6, 'pseudo_label': 0},
 '4obv': {'affinity': 0.75, 'pseudo_label': 0},
 '3k41': {'affinity': 0.82, 'pseudo_label': 0},
 '1wkm': {'affinity': 0.82, 'pseudo_label': 0},
 '1p0y': {'affinity': 1.0, 'pseudo_label': 0},
 '2d2v': {'affinity': 1.0, 'pseudo_label': 0},
 '2b1r': {'affinity': 1.0, 'pseudo_label': 0},
 '1aw1': {'affinity': 1.05, 'pseudo_label': 0},
 '1maw': {'affinity': 1.1, 'pseudo_label': 0},
 '3fl9': {'affinity': 1.11, 'pseudo_label': 0},
 '5eb2': {'affinity': 1.11, 'pseudo_label': 0},
 '4fci': {'affinity': 1.26, 'pseudo_label': 0},
 '4fck': {'affinity': 1.26, 'pseudo_label': 0},
 '4clp': {'affinity': 1.27, 'pseudo_label': 0},
 '2fah': {'affinity': 1.3, 'pseudo_label': 0},
 '1a0t': {'affinity': 1.3, 'pseudo_label': 0},
 '1b74': {'affinity': 1.3, 'pseudo_label': 0},
 

In [15]:
def generate_pli_dataset_dict(data_path):
    """Build the graph and tensors for every complex directory under ``data_path``.

    Each sub-directory ``<name>`` is expected to contain
    ``<name>_interaction_graph.json`` with ``"edges"`` (``id1``, ``id2``,
    ``length``) and ``"nodes"`` (``id``, ``attype``, ``pl``) lists.
    Entries of ``data_path`` that are not directories are skipped.

    Labels are read from the module-level ``interaction_affinities`` mapping,
    which must already hold ``{"affinity": ..., "pseudo_label": ...}`` per complex.

    Args:
        data_path: Directory holding one sub-directory per complex. A trailing
            separator is optional.

    Returns:
        dict mapping complex name to a dict with keys ``networkx_graph``,
        ``edge_index``, ``edge_weight``, ``x``, ``y`` and ``pseudo_label``.

    Raises:
        KeyError: if a complex directory has no entry in ``interaction_affinities``.
        FileNotFoundError: if a complex directory lacks its interaction-graph JSON.
    """
    dataset_dict = {}

    for interaction_name in tqdm(os.listdir(data_path)):
        # os.path.join keeps this working whether or not data_path ends with a separator.
        complex_dir = os.path.join(data_path, interaction_name)
        if not os.path.isdir(complex_dir):
            continue

        graph_file = os.path.join(complex_dir, interaction_name + "_interaction_graph.json")
        with open(graph_file, 'r') as f:
            data = json.load(f)

        G = nx.Graph()

        # Edges with a missing endpoint are dropped.
        for edge in data['edges']:
            if edge["id1"] is not None and edge["id2"] is not None:
                G.add_edge(edge["id1"], edge["id2"], weight=edge["length"])

        # set_node_attributes ignores ids that are not in G, so nodes that never
        # appear in a kept edge are not added to the graph.
        nx.set_node_attributes(G, {node["id"]: node["attype"] for node in data['nodes']}, "atom_type")
        nx.set_node_attributes(G, {node["id"]: node["pl"] for node in data['nodes']}, "from")

        edge_index, edge_weight = create_edge_index(G, weighted=True)
        labels = interaction_affinities[interaction_name]

        dataset_dict[interaction_name] = {
            "networkx_graph": G,
            "edge_index": edge_index,
            "edge_weight": edge_weight,
            # Constant dummy feature: one scalar 1.0 per node.
            "x": torch.full((G.number_of_nodes(), 1), 1.0, dtype=torch.float),
            # Regression target.
            "y": torch.FloatTensor([labels["affinity"]]),
            # Classification target.
            "pseudo_label": torch.LongTensor([labels["pseudo_label"]]),
        }

    return dataset_dict
        
# Build the per-complex graph/tensor dictionary for everything under DATA_PATH.
pli_dataset_dict = generate_pli_dataset_dict(DATA_PATH)     

  0%|          | 0/14219 [00:00<?, ?it/s]

### create torch dataset

In [16]:
# Turn every dictionary entry into a torch_geometric Data object.
data_list = []
for interaction_name in tqdm(pli_dataset_dict):
    entry = pli_dataset_dict[interaction_name]
    sample = Data(
        x=entry["x"],
        edge_index=entry["edge_index"],
        edge_weight=entry["edge_weight"],
        y=entry["y"],
        pseudo_label=entry["pseudo_label"],
        networkx_graph=entry["networkx_graph"],
        interaction_name=interaction_name,
    )
    data_list.append(sample)

  0%|          | 0/14215 [00:00<?, ?it/s]

In [17]:
# Wrap the Data objects in the project's PLIDataset (from src.utils).
# NOTE(review): "." is presumably the dataset root directory — confirm in PLIDataset.
dataset = PLIDataset(".", data_list = data_list)

In [18]:
def _read_interaction_ids(csv_path):
    """Return the whitespace-stripped lines of ``csv_path`` as a list, one ID per line."""
    with open(csv_path, 'r') as f:
        return [line.strip() for line in f]


# Complex IDs that make up each predefined split.
train_interactions = _read_interaction_ids("data/pdbbind/pdb_ids/training_set.csv")
val_interactions = _read_interaction_ids("data/pdbbind/pdb_ids/validation_set.csv")
core_set_interactions = _read_interaction_ids("data/pdbbind/pdb_ids/core_set.csv")
hold_out_interactions = _read_interaction_ids("data/pdbbind/pdb_ids/hold_out_set.csv")

In [19]:
def _select_split(samples, interaction_ids):
    """Return the samples whose ``interaction_name`` is in ``interaction_ids``, in dataset order."""
    wanted = set(interaction_ids)  # O(1) membership instead of scanning a list per sample
    return [sample for sample in samples if sample.interaction_name in wanted]


# Fetch every sample from the dataset exactly once and reuse it for all four splits.
all_samples = [dataset[i] for i in range(len(dataset))]

train_data = _select_split(all_samples, train_interactions)
val_data = _select_split(all_samples, val_interactions)
core_set_data = _select_split(all_samples, core_set_interactions)
hold_out_data = _select_split(all_samples, hold_out_interactions)

In [20]:
BATCH_SIZE = 32

# Reshuffle the training samples every epoch so mini-batches differ between epochs.
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
# Evaluation loaders keep a fixed order: the evaluation cells pair predictions
# with labels by iterating the same loader a second time.
val_loader = DataLoader(val_data, batch_size=BATCH_SIZE)
core_set_loader = DataLoader(core_set_data, batch_size=BATCH_SIZE)
hold_out_loader = DataLoader(hold_out_data, batch_size=BATCH_SIZE)


In [21]:
# Regression uses a single output unit (NUM_CLASSES == 1); classification uses NUM_CLASSES outputs.
# The input feature size is taken from the first sample's node-feature matrix.

model = GCN(node_features_dim = dataset[0].x.shape[1], num_classes = NUM_CLASSES, hidden_channels=256).to(device)

In [22]:
# model.load_state_dict(torch.load("models/model_2022_11_08-15_47_33.ckpt"))
# model.to(device)
# model

### Train the network

In [23]:
# Optimisation hyper-parameters.
EPOCHS = 100
epochs = EPOCHS
lr = 1e-3

optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# Cross-entropy for the pseudo-label classifier, mean squared error for affinity regression.
criterion = torch.nn.CrossEntropyLoss() if CLASSIFICATION else torch.nn.MSELoss()

In [24]:
def train():
    """Run one regression training epoch over ``train_loader``.

    Uses the module-level ``model``, ``optimizer``, ``criterion`` and ``device``.
    The optimised loss is the square root of ``criterion`` (RMSE when
    ``criterion`` is ``MSELoss``), computed per mini-batch.
    """
    model.train()

    for data in train_loader:  # Iterate in batches over the training dataset.
        data = data.to(device)
        optimizer.zero_grad()  # Start from clean gradients, even on the very first step.
        out = model(data.x, data.edge_index, data.batch, edge_weight = data.edge_weight)  # Single forward pass.
        # Flatten to one value per graph. Unlike torch.squeeze, reshape(-1) keeps the
        # batch dimension when a batch holds a single graph, so prediction and target
        # always have the same shape.
        loss = torch.sqrt(criterion(out.reshape(-1), data.y.reshape(-1)))
        loss.backward()  # Derive gradients.
        optimizer.step()  # Update parameters based on gradients.

def test(loader):
    model.eval()

    sum_loss = 0
    for data in loader:  # Iterate in batches over the training/test dataset.
        data = data.to(device)
        
        out = model(data.x, data.edge_index, data.batch, edge_weight = data.edge_weight)  
        loss = torch.sqrt(criterion(torch.squeeze(out), data.y)) 
        sum_loss += loss.item()
    return sum_loss / len(loader.dataset)  # Derive ratio of correct predictions.

In [25]:
def train_classification():
    """Run one classification training epoch over ``train_loader``.

    Relies on the module-level ``model``, ``optimizer``, ``criterion`` and
    ``device``; the targets are the per-graph ``pseudo_label`` values.
    """
    model.train()

    for graph_batch in train_loader:
        graph_batch = graph_batch.to(device)

        # Forward pass: one row of class scores per graph in the batch.
        logits = model(
            graph_batch.x,
            graph_batch.edge_index,
            graph_batch.batch,
            edge_weight=graph_batch.edge_weight,
        )
        batch_loss = criterion(logits, graph_batch.pseudo_label)

        # Backpropagate, apply the update, then reset gradients for the next batch.
        batch_loss.backward()
        optimizer.step()
        optimizer.zero_grad()

def test_classification(loader):
    model.eval()
    predictions = []
    correct = 0
    for data in loader:  # Iterate in batches over the training/test dataset.
        data = data.to(device)
        
        out = model(data.x, data.edge_index, data.batch, edge_weight = data.edge_weight)  
        pred = out.argmax(dim=1)  # Use the class with highest probability.
        predictions += pred.tolist()
        correct += int((pred == data.pseudo_label).sum())  # Check against ground-truth labels.
    return correct / len(loader.dataset), predictions  # Derive ratio of correct predictions.

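To train the network, set `TRAIN = True` in the configuration cell; with `TRAIN = False` the cell below loads a previously saved checkpoint instead.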

In [26]:
if TRAIN:
    if not CLASSIFICATION:
        # Regression: optimise and report RMSE on the affinity values.
        for epoch in range(epochs):
            train()
            train_rmse = test(train_loader)
            val_rmse = test(val_loader)
            print(f'Epoch: {epoch:03d}, Train RMSE: {train_rmse:.4f}, Val RMSE: {val_rmse:.4f}')

        core_set_rmse = test(core_set_loader)
        print(f'Core set 2016 RMSE: {core_set_rmse:.4f}')

        hold_out_set_rmse = test(hold_out_loader)
        print(f'Hold out set 2019 RMSE: {hold_out_set_rmse:.4f}')

        save_model(model, "models")

    else:
        # Classification: optimise and report accuracy on the pseudo labels.
        for epoch in range(epochs):
            train_classification()
            train_acc, _ = test_classification(train_loader)
            val_acc, _ = test_classification(val_loader)
            print(f'Epoch: {epoch:03d}, Train Accuracy: {train_acc:.4f}, Val Accuracy: {val_acc:.4f}')

        core_set_acc, _ = test_classification(core_set_loader)
        print(f'Core set 2016 Accuracy: {core_set_acc:.4f}')

        hold_out_set_acc, _ = test_classification(hold_out_loader)
        print(f'Hold out set 2019 Accuracy: {hold_out_set_acc:.4f}')

        save_model(model, "models_classification")
else:
    # map_location remaps the stored tensors onto the active device, so a checkpoint
    # saved on a GPU also loads on a CPU-only machine.
    # NOTE(review): the path sits under "models", the folder the regression branch saves
    # to; with CLASSIFICATION = True the output layer size would presumably not match.
    model.load_state_dict(torch.load("models/model_2022_11_08-15_47_33.ckpt", map_location=device))
    model.to(device)

In [27]:
%matplotlib inline
# Evaluate the model on the core set.
if CLASSIFICATION:
    core_set_acc, predictions = test_classification(core_set_loader)
    print(f'Core set 2016 Accuracy: {core_set_acc:.4f}')

    # Ground-truth labels gathered in loader order, i.e. the order in which
    # test_classification appended its predictions.
    actual_labels = []
    for data in core_set_loader:
        actual_labels.extend(data.pseudo_label.tolist())

    print(classification_report(actual_labels, predictions))

    # Row-normalised confusion matrix shown as a heat map.
    cm = confusion_matrix(actual_labels, predictions, normalize='true')
    plt.figure(figsize=(9, 9))
    sns.heatmap(cm, annot=True, fmt=".3f", linewidths=.5, square=True, cmap='Blues')
else:
    core_set_rmse = test(core_set_loader)
    print(f'Core set 2016 RMSE: {core_set_rmse:.4f}')


Core set 2016 RMSE: 0.0839


  return F.mse_loss(input, target, reduction=self.reduction)


In [28]:
# Evaluate the model on the hold-out set.
if CLASSIFICATION:
    hold_out_set_acc, predictions = test_classification(hold_out_loader)
    print(f'Hold out set 2019 Accuracy: {hold_out_set_acc:.4f}')

    # Collect the true pseudo labels batch by batch, in the loader's iteration order.
    actual_labels = []
    for data in hold_out_loader:
        batch_labels = data.pseudo_label.tolist()
        actual_labels = actual_labels + batch_labels

    report = classification_report(actual_labels, predictions)
    print(report)

    # Confusion matrix normalised over the true labels, drawn as a heat map.
    cm = confusion_matrix(actual_labels, predictions, normalize='true')
    plt.figure(figsize=(9, 9))
    sns.heatmap(cm, annot=True, fmt=".3f", linewidths=.5, square=True, cmap='Blues')
else:
    hold_out_set_rmse = test(hold_out_loader)
    print(f'Hold out set 2019 RMSE: {hold_out_set_rmse:.4f}')

Hold out set 2019 RMSE: 0.0563


## Explainability

In [29]:
# Draw the indices from hold_out_data, the list that is actually indexed in the
# explanation loop. It can be shorter than hold_out_interactions when some listed
# IDs are missing from the dataset, which would otherwise yield out-of-range indices.
num_test_interactions = len(hold_out_data)

# Explain at most 10 distinct complexes, fewer if the hold-out set is smaller.
test_interaction_indices = np.random.choice(num_test_interactions, min(10, num_test_interactions), replace=False)

# Per-complex counts of relevant edges, filled in by the explanation loop.
num_edge_in_protein_list = []
num_edge_in_ligand_list = []
num_edge_in_between_list = []

In [30]:
# Regression models have no target class. In classification mode TARGET_CLASS is
# set per sample (to the predicted class) inside the loop.
if not CLASSIFICATION:
    TARGET_CLASS = None

model.eval()  # Inference only from here on.

for index in tqdm(test_interaction_indices):
    test_interaction = hold_out_data[index]
    print("Interaction: " + test_interaction.interaction_name)

    # Single graph: every node belongs to batch element 0.
    batch = torch.zeros(test_interaction.x.shape[0], dtype=int, device=test_interaction.x.device)

    with torch.no_grad():  # Plain prediction, no gradients needed.
        out = model(test_interaction.x.to(device), test_interaction.edge_index.to(device), batch=batch.to(device), edge_weight=test_interaction.edge_weight.to(device))

    if not CLASSIFICATION:
        print(f"Predicted: {out.item()} Actual: {test_interaction.y.item()}")
    else:
        pred_prob = F.softmax(out, dim=1)
        pred = pred_prob.argmax(dim=1)
        print(f"Predicted: {pred.item()} Actual: {test_interaction.pseudo_label.item()}")
        print(f"Predicted probability: {pred_prob[0][pred.item()].item()}")

        # Only correctly classified samples are explained.
        if pred.item() != test_interaction.pseudo_label.item():
            continue
        TARGET_CLASS = pred.item()

    # --- explainability ---
    edgeshaper_explainer = Edgeshaper(model, test_interaction.x, test_interaction.edge_index, edge_weight = test_interaction.edge_weight, device = device)

    phi_edges = edgeshaper_explainer.explain(M = 100, target_class = TARGET_CLASS, deviation = 1e-3, seed = 42)

    print(f"Sum of phi edges: {sum(phi_edges)}")

    # --- map the per-edge values in edge_index onto the graph's undirected edges ---
    G = test_interaction.networkx_graph
    bonds = list(G.edges())
    num_bonds = len(bonds)
    rdkit_bonds_phi = [0]*num_bonds
    # (atom, atom) -> position of that edge in G.edges(), the order nx.draw colours edges in.
    rdkit_bonds = {(init_atom, end_atom): i for i, (init_atom, end_atom) in enumerate(bonds)}

    # NOTE(review): assumes edge_index refers to nodes by the same identifiers as
    # networkx_graph — confirm in create_edge_index.
    for i in range(len(phi_edges)):
        phi_value = phi_edges[i]
        init_atom = test_interaction.edge_index[0][i].item()
        end_atom = test_interaction.edge_index[1][i].item()

        # An edge may appear in edge_index in either direction; both directions
        # contribute to the same undirected edge.
        if (init_atom, end_atom) in rdkit_bonds:
            bond_index = rdkit_bonds[(init_atom, end_atom)]
            rdkit_bonds_phi[bond_index] += phi_value
        if (end_atom, init_atom) in rdkit_bonds:
            bond_index = rdkit_bonds[(end_atom, init_atom)]
            rdkit_bonds_phi[bond_index] += phi_value

    # --- plotting ---
    print("Number of nodes: ", G.number_of_nodes())
    print("Number of edges: ", G.number_of_edges())

    # Protein atoms in red, ligand atoms in light blue.
    colors = ["red" if G.nodes[node]["from"] == "P" else "lightblue" for node in G.nodes]

    fig = plt.figure(figsize=(10,10))
    pos = nx.spring_layout(G)
    nx.draw(G, pos=pos, with_labels=True, font_weight='bold', labels=nx.get_node_attributes(G, 'atom_type'), node_color=colors,edge_color=rdkit_bonds_phi, width=3, edge_cmap=plt.cm.bwr)

    SAVE_PATH = os.path.join(SAVE_FOLDER, test_interaction.interaction_name)
    os.makedirs(SAVE_PATH, exist_ok=True)

    # Save BEFORE showing: with the inline backend plt.show() finalises the current
    # figure, so a later savefig would write an empty image.
    fig.savefig(os.path.join(SAVE_PATH, test_interaction.interaction_name + "_EdgeSHAPer.png"))

    plt.show()

    plt.close(fig)

    # --- statistics on important edges: those at or above the mean value ---
    mean_phi = np.mean(rdkit_bonds_phi)
    top_edges = {i for i in range(len(rdkit_bonds_phi)) if rdkit_bonds_phi[i] >= mean_phi}

    num_edge_in_protein = 0
    num_edge_in_ligand = 0
    num_edge_in_between = 0

    atoms_origin = nx.get_node_attributes(G, 'from')

    for bond_index, (init_atom, end_atom) in enumerate(bonds):
        if bond_index not in top_edges:
            continue
        if atoms_origin[init_atom] == "P" and atoms_origin[end_atom] == "P":
            num_edge_in_protein += 1
        elif atoms_origin[init_atom] == "L" and atoms_origin[end_atom] == "L":
            num_edge_in_ligand += 1
        else:
            num_edge_in_between += 1

    with open(os.path.join(SAVE_PATH, test_interaction.interaction_name + "_statistics.txt"), "w") as f:
        f.write("Number of relevant edges connecting protein atoms: " + str(num_edge_in_protein) + "\n")
        f.write("Number of relevant edges connecting ligand atoms: " + str(num_edge_in_ligand) + "\n")
        f.write("Number of relevant edges connecting protein and ligand atoms: " + str(num_edge_in_between) + "\n")

    num_edge_in_protein_list.append(num_edge_in_protein)
    num_edge_in_ligand_list.append(num_edge_in_ligand)
    num_edge_in_between_list.append(num_edge_in_between)

    print("Number of relevant edges connecting protein atoms: ", num_edge_in_protein)
    print("Number of relevant edges connecting ligand atoms: ", num_edge_in_ligand)
    print("Number of relevant edges connecting protein and ligand atoms: ", num_edge_in_between)



  0%|          | 0/10 [00:00<?, ?it/s]

Interaction: 4z22
Predicted: 6.536458969116211 Actual: 6.239999771118164
No target class specified. Regression model assumed.


  0%|          | 0/100 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Inspect the per-edge values returned by the most recent explain() call.
phi_edges

[0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 5.960464477539063e-10,
 0.0,
 0.0,
 5.960464477539063e-10,
 2.9802322387695313e-10,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 2.9802322387695313e-10,
 0.0,
 -2.9802322387695313e-10,
 2.9802322387695313e-10,
 2.9802322387695313e-10,
 2.9802322387695313e-10,
 2.9802322387695313e-10,
 5.960464477539063e-10,
 -5.960464477539063e-10,
 5.960464477539063e-10,
 0.0,
 -5.960464477539063e-10,
 0.0,
 0.0,
 0.0,
 0.0,
 2.9802322387695313e-10,
 0.0,
 -2.9802322387695313e-10,
 0.0,
 0.0,
 -2.9802322387695313e-10,
 -5.960464477539063e-10,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 -5.960464477539063e-10,
 0.0,
 0.0,
 0.0,
 -2.9802322387695313e-10,
 -2.9802322387695313e-10,
 0.0,
 0.0,
 -2.9802322387695313e-10,
 0.0,
 0.0,
 -2.9802322387695313e-10,
 2.9802322387695313e-10,
 0.0,
 8.940696716308593e-10,
 0.0,
 -5.960464477539063e-10,
 -1.1920928955078125e-09,
 0.0,
 -8.940696716308593e-10,
 0.0,
 -2.9802322387695313e-10,
 -5.960464477539063e-10,
 0.0,
 2.9802322387695313

In [None]:
# Average number of relevant edges per category across the explained complexes.
for category, counts in (
    ("protein", num_edge_in_protein_list),
    ("ligand", num_edge_in_ligand_list),
    ("between", num_edge_in_between_list),
):
    average = round(np.mean(counts), 3)
    print("Avg number of edges in " + category + ": ", average)

Avg number of edges in protein:  2.1
Avg number of edges in ligand:  23.0
Avg number of edges in between:  11.1


In [None]:
# Persist the per-category averages next to the per-complex explanation folders.
summary_lines = [
    "Avg number of edges in " + category + ": " + str(round(np.mean(counts), 3)) + "\n"
    for category, counts in (
        ("protein", num_edge_in_protein_list),
        ("ligand", num_edge_in_ligand_list),
        ("between", num_edge_in_between_list),
    )
]

with open(SAVE_FOLDER + "/statistics.txt", "w") as f:
    f.writelines(summary_lines)