In [19]:
# Installations

# !pip install torch_sparse
# !pip install torch_scatter

In [20]:
# Imports

import os, json 

import networkx as nx

import pandas as pd
import numpy as np
import random
import pickle

import torch
import torch.nn as nn

from torch_geometric.data import DataLoader
import torch.nn.functional as F

from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.utils import dropout_adj
from torch_geometric.nn import global_mean_pool as gep
from torch_geometric import data as DATA
from torch_geometric.data import InMemoryDataset, Batch

from collections import OrderedDict

from rdkit import Chem
from rdkit.Chem import MolFromSmiles

In [21]:
# Hardcoded feature information 
# Reference: https://www.bioinfor.com/amino-acid/ and https://www.sigmaaldrich.com/life-science/metabolomics/learning-center/amino-acid-reference-chart.html

residue_table = ['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y', 'X']

aliphatic_residues_table = ['A', 'I', 'L', 'M', 'V']
aromatic_residues_table = ['F', 'W', 'Y']
polar_neutral_residues_table = ['C', 'N', 'Q', 'S', 'T']
acidic_charged_residues_table = ['D', 'E']
basic_charged_residues_table = ['H', 'K', 'R']

weight_table = {'A': 71.08, 'C': 103.15, 'D': 115.09, 'E': 129.12, 'F': 147.18, 'G': 57.05, 'H': 137.14,
                    'I': 113.16, 'K': 128.18, 'L': 113.16, 'M': 131.20, 'N': 114.11, 'P': 97.12, 'Q': 128.13,
                    'R': 156.19, 'S': 87.08, 'T': 101.11, 'V': 99.13, 'W': 186.22, 'Y': 163.18}

hydrophobic_ph2_table = {'A': 47, 'C': 52, 'D': -18, 'E': 8, 'F': 92, 'G': 0, 'H': -42, 'I': 100,
                             'K': -37, 'L': 100, 'M': 74, 'N': -41, 'P': -46, 'Q': -18, 'R': -26, 'S': -7,
                             'T': 13, 'V': 79, 'W': 84, 'Y': 49}

hydrophobic_ph7_table = {'A': 41, 'C': 49, 'D': -55, 'E': -31, 'F': 100, 'G': 0, 'H': 8, 'I': 99,
                             'K': -23, 'L': 97, 'M': 74, 'N': -28, 'P': -46, 'Q': -10, 'R': -14, 'S': -5,
                             'T': 13, 'V': 76, 'W': 97, 'Y': 63}

pl_table = {'A': 6.00, 'C': 5.07, 'D': 2.77, 'E': 3.22, 'F': 5.48, 'G': 5.97, 'H': 7.59,
                'I': 6.02, 'K': 9.74, 'L': 5.98, 'M': 5.74, 'N': 5.41, 'P': 6.30, 'Q': 5.65,
                'R': 10.76, 'S': 5.68, 'T': 5.60, 'V': 5.96, 'W': 5.89, 'Y': 5.96}

pka_table = {'A': 2.34, 'C': 1.96, 'D': 1.88, 'E': 2.19, 'F': 1.83, 'G': 2.34, 'H': 1.82, 'I': 2.36,
                 'K': 2.18, 'L': 2.36, 'M': 2.28, 'N': 2.02, 'P': 1.99, 'Q': 2.17, 'R': 2.17, 'S': 2.21,
                 'T': 2.09, 'V': 2.32, 'W': 2.83, 'Y': 2.32}

pkb_table = {'A': 9.69, 'C': 10.28, 'D': 9.60, 'E': 9.67, 'F': 9.13, 'G': 9.60, 'H': 9.17,
                 'I': 9.60, 'K': 8.95, 'L': 9.60, 'M': 9.21, 'N': 8.80, 'P': 10.60, 'Q': 9.13,
                 'R': 9.04, 'S': 9.15, 'T': 9.10, 'V': 9.62, 'W': 9.39, 'Y': 9.62}

pkx_table = {'A': 0.00, 'C': 8.18, 'D': 3.65, 'E': 4.25, 'F': 0.00, 'G': 0, 'H': 6.00,
                 'I': 0.00, 'K': 10.53, 'L': 0.00, 'M': 0.00, 'N': 0.00, 'P': 0.00, 'Q': 0.00,
                 'R': 12.48, 'S': 0.00, 'T': 0.00, 'V': 0.00, 'W': 0.00, 'Y': 0.00}

atom_table = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe', 'As', 'Al', 'I', 'B', 'V', 
              'K', 'Tl', 'Yb', 'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se','Ti', 'Zn', 'H', 'Li', 'Ge', 'Cu', 'Au', 'Ni', 
              'Cd', 'In', 'Mn', 'Zr', 'Cr', 'Pt', 'Hg', 'Pb', 'X']

count_table = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

# Normalize

def normalize_dict(dic):
    max_ = dic[max(dic, key=dic.get)]
    min_ = dic[min(dic, key=dic.get)]
    interval = float(max_) - float(min_)
    
    for key in dic.keys():
        dic[key] = (dic[key] - min_) / interval
        
    dic['X'] = (max_ + min_) / 2.0 # For unknown 
    return dic

weight_table = normalize_dict(weight_table)
pka_table = normalize_dict(pka_table)
pkb_table = normalize_dict(pkb_table)
pkx_table = normalize_dict(pkx_table)
pl_table = normalize_dict(pl_table)
hydrophobic_ph2_table = normalize_dict(hydrophobic_ph2_table)
hydrophobic_ph7_table = normalize_dict(hydrophobic_ph7_table)


In [22]:
# Parameter settings 

dataset = 'davis' 

TRAIN_BATCH_SIZE = 256
TEST_BATCH_SIZE = 256
LR = 0.001
NUM_EPOCHS = 1000

NUM_PROT_FEATURES = 33
NUM_MOL_FEATURES = 78
NUM_RES_PROPERTIES = 12

DROPOUT = 0.2

print('Learning rate: ', LR)
print('Epochs: ', NUM_EPOCHS)

Learning rate:  0.001
Epochs:  1000


In [23]:
# Directory creation for models/results

models_dir = 'models'
results_dir = 'results'

if not os.path.exists(models_dir):
    os.makedirs(models_dir)

if not os.path.exists(results_dir):
    os.makedirs(results_dir)

In [24]:
# CUDA 

USE_CUDA = torch.cuda.is_available()
device = torch.device('cuda:0' if USE_CUDA else 'cpu')

In [25]:
# Model architecture definition 

# 3-layer, GCN-based model

class GCNNet(torch.nn.Module):
    def __init__(self, n_output=1, num_features_pro=NUM_PROT_FEATURES, num_features_mol=NUM_MOL_FEATURES, output_dim=128, dropout=DROPOUT):
        super(GCNNet, self).__init__()
        print('GCNNet Loaded')
        
        # Output layer
        self.n_output = n_output
        
        # Mol convolutional layers
        self.mol_conv1 = GCNConv(num_features_mol, num_features_mol)
        # self.mol_conv2 = GCNConv(num_features_mol, num_features_mol*2)
        # self.mol_conv3 = GCNConv(num_features_mol*2, num_features_mol*4)
        
        # Mol fully connected layers
        self.mol_fc1 = torch.nn.Linear(num_features_mol, 1024)
        self.mol_fc2 = torch.nn.Linear(1024, output_dim)

        # Protein convolutional layers
        self.pro_conv1 = GCNConv(num_features_pro, num_features_pro)
        #self.pro_conv2 = GCNConv(num_features_pro, num_features_pro*2)
        # self.pro_conv3 = GCNConv(num_features_pro*2, num_features_pro*4)
        
        # Protein fully connected layers
        self.pro_fc1 = torch.nn.Linear(num_features_pro, 1024)
        self.pro_fc2 = torch.nn.Linear(1024, output_dim)

        # Other
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

        # Final layers
        self.fc1 = nn.Linear(2*output_dim, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.out = nn.Linear(512, self.n_output)

    def forward(self, data_mol, data_pro):
        # Molecule info
        mol_data, mol_edge_index, mol_batch = data_mol.x, data_mol.edge_index, data_mol.batch
        
        # Protein info
        target_data, target_edge_index, target_batch = data_pro.x, data_pro.edge_index, data_pro.batch
        
        # Protein
        
        x_p = self.pro_conv1(target_data, target_edge_index)
        x_p = self.relu(x_p)

        x_p = self.pro_conv2(x_p, target_edge_index)
        x_p = self.relu(x_p)

        x_p = self.pro_conv3(x_p, target_edge_index)
        x_p = self.relu(x_p)

        x_p = gep(x_p, target_batch) # Pooling

        x_p = self.relu(self.pro_fc1(x_p))
        x_p = self.dropout(x_p)
        
        x_p = self.pro_fc2(x_p)
        x_p = self.dropout(x_p)
        
        # Molecule
        
        x_m = self.mol_conv1(mol_data, mol_edge_index)
        x_m = self.relu(x_m)
        
        x_m = self.mol_conv2(x_m, mol_edge_index)
        x_m = self.relu(x_m)
        
        x_m = self.mol_conv3(x_m, mol_edge_index)
        x_m = self.relu(x_m)
        
        x_m = gep(x_m, mol_batch) # Pooling

        x_m = self.relu(self.mol_fc1(x_m))
        x_m = self.dropout(x_m)
        
        x_m = self.mol_fc2(x_m)
        x_m = self.dropout(x_m)
        
        # Concatenation

        x_c = torch.cat((x_m, x_p), 1)
        
        x_c = self.fc1(x_c)
        x_c = self.relu(x_c)
        x_c = self.dropout(x_c)
        
        x_c = self.fc2(x_c)
        x_c = self.relu(x_c)
        x_c = self.dropout(x_c)
        
        out = self.out(x_c)
        
        return out

In [26]:
# Helper functions for training and featurization 

# Create one hot encoding with k choices
def one_k_encoding(x, choices):
    return list(map(lambda s: x == s, choices))

# If not in choices, set as last choice (for X residue)
def one_not_k_encoding(x, choices):
    if x not in choices:
        x = choices[-1]
    return list(map(lambda s: x == s, choices))

def sequence_feature(protein_seq):
    protein_property = np.zeros((len(protein_seq), NUM_RES_PROPERTIES))
    protein_one_hot = np.zeros((len(protein_seq), len(residue_table)))
    
    for i in range(len(protein_seq)):
        protein_property[i,] = residue_features(protein_seq[i]) # Get the chemical properties
        protein_one_hot[i,] = one_k_encoding(protein_seq[i], residue_table) # Get the residue identity 
        
    return np.concatenate((protein_property, protein_one_hot), axis=1)

def protein_to_feature(protein_key, protein_seq):
    feature = sequence_feature(protein_seq)
    return feature

def protein_to_graph(protein_key, protein_seq, contact_dir):
    protein_edges = []
    protein_len = len(protein_seq)
    
    contact_file = os.path.join(contact_dir, protein_key + '.npy')
    contact_map = np.load(contact_file)
    contact_map += np.matrix(np.eye(contact_map.shape[0]))
    
    row, col = np.where(contact_map >= 0.5) # Apply the contact map 0.5 threshold
    for i, j in zip(row, col):
        protein_edges.append([i, j]) # Get all the edges
    protein_feature = protein_to_feature(protein_key, protein_seq)
    protein_edges = np.array(protein_edges)
    
    return protein_len, protein_feature, protein_edges

def residue_features(residue):
    numeric_properties = [weight_table[residue],
                          hydrophobic_ph2_table[residue], 
                          hydrophobic_ph7_table[residue],
                          pl_table[residue],
                          pka_table[residue], 
                          pkb_table[residue], 
                          pkx_table[residue]]
    binary_properties = [1 if residue in acidic_charged_residues_table else 0,
                         1 if residue in basic_charged_residues_table else 0, 
                         1 if residue in aromatic_residues_table else 0,
                         1 if residue in aliphatic_residues_table else 0, 
                         1 if residue in polar_neutral_residues_table else 0]
    return np.array(binary_properties + numeric_properties)

def smile_to_graph(smile):
    mol = Chem.MolFromSmiles(smile)
    mol_size = mol.GetNumAtoms()

    node_features = [] # Gather atom features
    for atom in mol.GetAtoms():
        feature = atom_features(atom)
        node_features.append(feature/sum(feature))

    edges = []
    for bond in mol.GetBonds(): # Getting the bonds and adding indices
        edges.append([bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()])
        
    g = nx.Graph(edges).to_directed()
    
    edge_index = []
    mol_adj = np.zeros((mol_size, mol_size)) # Adjacency matrix
    for edge_1, edge_2 in g.edges:
        mol_adj[edge_1, edge_2] = 1 # Fill in edges
        
    mol_adj += np.matrix(np.eye(mol_adj.shape[0]))
    
    row, col = np.where(mol_adj == 1) 
    for i, j in zip(row, col):
        edge_index.append([i, j])
    return mol_size, node_features, edge_index

def atom_features(atom):
    return np.array(one_not_k_encoding(atom.GetSymbol(), atom_table) +
                    one_k_encoding(atom.GetDegree(), count_table) +
                    one_not_k_encoding(atom.GetTotalNumHs(), count_table) +
                    one_not_k_encoding(atom.GetImplicitValence(), count_table) +
                    [atom.GetIsAromatic()])

# Metrics - MSE
def get_mse(y_true, y_preds):
    mse = ((y_true - y_pred)**2).mean(axis=0)
    return mse

In [93]:
# Dataset generation 

def create_dataset(dataset):
    dataset_path = 'data/' + dataset + '/'
    training = json.load(open(dataset_path + 'split/train_split.txt')) 
    training = [e for e in training]

    proteins_ = json.load(open(dataset_path + 'proteins.txt'), object_pairs_hook=OrderedDict)
    drugs_ = json.load(open(dataset_path + 'drugs.txt'), object_pairs_hook=OrderedDict)
    
    # Load contact map
    contact_path = 'data/' + dataset + '/contact_maps'
    contact_list = []
    for key in proteins_:
        contact_list.append(os.path.join(contact_path, key + '.npy'))

    train = []
    valid = training[0]
    for i in range(len(training)):
        if i != 0:
            train += training[i]

    interaction_scores = pickle.load(open(dataset_path + 'Y', 'rb'), encoding='latin1')
    
    # Lists 
    
    drugs = []
    proteins = []
    protein_keys = []
    drug_smiles = []
    
    # Convert SMILES
    
    for d in drugs_.keys():
        mol = Chem.MolToSmiles(Chem.MolFromSmiles(drugs_[d]), isomericSmiles=True)
        drugs.append(mol)
        drug_smiles.append(drugs_[d])
        
    # Protein sequences 
    
    for p in proteins_.keys():
        proteins.append(proteins_[p])
        protein_keys.append(p)
        
    # Need to transform the Davis scores 
    
    if dataset == 'davis':
        interaction_scores = [-np.log10(label/1e9) for label in interaction_scores]
    interaction_scores = np.asarray(interaction_scores)
    
    train_count = 0
    valid_count = 0
    
    # Training data
    rows, cols = np.where(np.isnan(interaction_scores) == False) # Get rows with data
    rows, cols = rows[train], cols[train]
    train_entries = []
    for pair_ind in range(len(rows)):
        ls = []
        ls += [drugs[rows[pair_ind]]]
        ls += [proteins[cols[pair_ind]]]
        ls += [protein_keys[cols[pair_ind]]]
        ls += [interaction_scores[rows[pair_ind], cols[pair_ind]]]
        train_entries.append(ls)
        train_count += 1

    csv_file = 'data/' + dataset + '_train'  + '.csv'
    data_to_csv(csv_file, train_entries)
            
    # Validation data
    rows, cols = np.where(np.isnan(interaction_scores) == False)
    rows, cols = rows[valid], cols[valid]
    valid_entries = []
    for pair_ind in range(len(rows)):
        ls = []
        ls += [drugs[rows[pair_ind]]]
        ls += [proteins[cols[pair_ind]]]
        ls += [protein_keys[cols[pair_ind]]]
        ls += [interaction_scores[rows[pair_ind], cols[pair_ind]]]
        valid_entries.append(ls)
        valid_count += 1

    csv_file = 'data/' + dataset + '_valid'  + '.csv'
    data_to_csv(csv_file, valid_entries)
            
    print('Dataset: ', dataset)
    print('Working train entries:', train_count)
    print('Working validation entries: ', valid_count)

    compound_smiles = drugs
    protein_key = protein_keys

    # SMILES to graph
    
    smile_graph = {}
    for smile in compound_smiles:
        graph = smile_to_graph(smile)
        smile_graph[smile] = graph

    # Protein to graph 
    
    protein_graph = {}
    for key in protein_key:
        # print(key)
        graph = protein_to_graph(key, proteins_[key], contact_path)
        protein_graph[key] = graph

    train_csv = 'data/' + dataset + '_' + 'train' + '.csv'
    df_train = pd.read_csv(train_csv)
    
    train_drugs, train_prot_keys, train_Y = list(df_train['compound_smiles']), list(df_train['protein_key']), list(df_train['interaction_scores'])
    train_drugs, train_prot_keys, train_Y = np.asarray(train_drugs), np.asarray(train_prot_keys), np.asarray(train_Y)
    train_dataset = ProtDrugDataset(root='data', dataset=dataset + '_' + 'train', xd=train_drugs, protein_key=train_prot_keys, label=train_Y, smile_graph=smile_graph, protein_graph=protein_graph)


    df_valid = pd.read_csv('data/' + dataset + '_' + 'valid' + '.csv')
    valid_drugs, valid_prots_keys, valid_Y = list(df_valid['compound_smiles']), list(df_valid['protein_key']), list(df_valid['interaction_scores'])
    valid_drugs, valid_prots_keys, valid_Y = np.asarray(valid_drugs), np.asarray(valid_prots_keys), np.asarray(valid_Y)
    valid_dataset = ProtDrugDataset(root='data', dataset=dataset + '_' + 'train', xd=valid_drugs,
                               protein_key=valid_prots_keys, label=valid_Y, smile_graph=smile_graph,
                               protein_graph=protein_graph)
    
    return train_dataset, valid_dataset

IndentationError: unexpected indent (<ipython-input-93-f7669728a7bf>, line 60)

In [89]:
# Create dataset object for PyTorch
class ProtDrugDataset(InMemoryDataset):
    def __init__(self, root, dataset, xd=None, label=None, smile_graph=None, protein_key=None, protein_graph=None):

        super(ProtDrugDataset, self).__init__(root)
        self.dataset = dataset
        self.preprocess(xd, protein_key, label, smile_graph, protein_graph)

    def preprocess(self, xd, protein_key, label, smile_graph, protein_graph):
        data_list_mol = []
        data_list_pro = []
        data_len = len(xd)
        
        for i in range(data_len):
            smiles = xd[i]
            prot_key = protein_key[i]
            labels = label[i]
            mol_size, features, edge_index = smile_graph[smiles]
            protein_size, protein_features, protein_edges = protein_graph[prot_key]

            # Processing for GCN algorithms:
            GCNData_mol = DATA.Data(x=torch.Tensor(features),
                                    edge_index=torch.LongTensor(edge_index).transpose(1, 0),
                                    y=torch.FloatTensor([labels]))
            
            GCNData_mol.__setitem__('mol_size', torch.LongTensor([mol_size]))

            GCNData_pro = DATA.Data(x=torch.Tensor(protein_features),
                                    edge_index=torch.LongTensor(protein_edges).transpose(1, 0),
                                    y=torch.FloatTensor([labels]))
            GCNData_pro.__setitem__('protein_size', torch.LongTensor([protein_size]))
            
            data_list_mol.append(GCNData_mol)
            data_list_pro.append(GCNData_pro)

        self.data_mol = data_list_mol
        self.data_pro = data_list_pro

    def __len__(self):
        return len(self.data_mol)

    def __getitem__(self, idx):
        return self.data_mol[idx], self.data_pro[idx]

# Training function 
def train(model, device, train_loader, optimizer, epoch):
    print('Beginning training.')
    model.train()
    
    VERBOSE_INT = 10
    
    for batch_idx, data in enumerate(train_loader):
        # To device 
        data_mol = data[0].to(device)
        data_pro = data[1].to(device)
        
        optimizer.zero_grad()
        
        output = model(data_mol, data_pro)
        
        loss = loss_fn(output, data_mol.y.view(-1, 1).float().to(device))
        loss.backward()
        optimizer.step()
        
        if batch_idx % VERBOSE_INT == 0:
            print('Train epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch,
                                                                           batch_idx*TRAIN_BATCH_SIZE,
                                                                           len(train_loader.dataset),
                                                                           100.*batch_idx/len(train_loader),
                                                                           loss.item()))      
# Prediction function 
def predict(model, device, loader):
    model.eval()
    total_preds = torch.Tensor()
    total_labels = torch.Tensor()
    
    print('Beginning prediction.')
    with torch.no_grad():
        for data in loader:
            data_mol = data[0].to(device)
            data_pro = data[1].to(device)
            output = model(data_mol, data_pro)
            total_preds = torch.cat((total_preds, output.cpu()), 0)
            total_labels = torch.cat((total_labels, data_mol.y.view(-1, 1).cpu()), 0)
            
    return total_labels.numpy().flatten(), total_preds.numpy().flatten()


def collate(data_list):
    batchA = Batch.from_data_list([data[0] for data in data_list])
    batchB = Batch.from_data_list([data[1] for data in data_list])
    return batchA, batchB

def data_to_csv(csv_file, datalist):
    with open(csv_file, 'w') as f:
        f.write('compound_smiles,protein_sequence,protein_key,interaction_scores\n')
        for data in datalist:
            f.write(','.join(map(str, data)) + '\n')

In [90]:
# Training 

model = GCNNet()
model.to(device)
model_name = 'Davis_GCN_1'

optimizer = torch.optim.Adam(model.parameters(), lr=LR)

loss_fn = nn.MSELoss()

train_data, valid_data = create_dataset(dataset)
# train_data

GCNNet Loaded
Dataset:  davis
Working train entries: 20036
Working validation entries:  5010


In [91]:
train_loader = torch.utils.data.DataLoader(train_data, batch_size=TRAIN_BATCH_SIZE, shuffle=True, collate_fn=collate)
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=TEST_BATCH_SIZE, shuffle=False, collate_fn=collate)

In [92]:
UPPER_BOUND = 99999
LOWER_BOUND = -1

best_mse = UPPER_BOUND
best_test_mse =  UPPER_BOUND

best_epoch = LOWER_BOUND

model_file_name = 'models/model_' + model_name + '_' + dataset + '.model'

for epoch in range(NUM_EPOCHS):
        train(model, device, train_loader, optimizer, epoch + 1)
        
        print('Validation.')
        G, P = predict(model, device, valid_loader)
        val = get_mse(G, P)
        
        print('Validation result: ', val, best_mse)
        
        if val < best_mse:
            best_mse = val
            best_epoch = epoch + 1
            torch.save(model.state_dict(), model_file_name)
            print('MSE improved! Epoch: ', best_epoch, '. MSE: ', best_mse)
        else:
            print('No improvement. Last improvement at epoch: ', best_epoch, '. MSE: ', best_mse)

Beginning training.


KeyboardInterrupt: 

In [96]:
# Helper function to load best saved model 

def load_model(path):
    model = torch.load(path)
    return model

# Metrics for evaluating testing data

def evaluate_metrics(y_true, y_pred):
    mse = get_mse(y_true, y_pred)
    ci = get_ci(y_true, y_pred)
    print('MSE: ', mse)
    print('CI: ', ci)

# Based on implementation/definition in: https://lifelines.readthedocs.io/en/latest/lifelines.utils.html#lifelines.utils.concordance_index
def get_ci(y_true, y_pred):
    admissible = 0
    paired = 0
    for i in range(1, len(y_true)):
        for j in range(0, i):
            if i != j:
                if (y_true[i] > y_true[j]):
                    paired += 1
                    admissible += 1*(y_pred[i] > y_pred[j]) + 0.5*(y_pred[i] == y_pred[j])
    if pair != 0:
        return total/admissible
    return 0

In [97]:
def create_test_data(dataset):
    
    dataset_path = 'data/' + dataset + '/'
    test = json.load(open(dataset_path + 'split/train_split.txt'))
    
    drugs_ = json.load(open(dataset_path + 'drugs.txt'), object_pairs_hook=OrderedDict)
    proteins_ = json.load(open(dataset_path + 'proteins.txt'), object_pairs_hook=OrderedDict)
    
    interaction_scores = pickle.load(open(dataset_path + 'Y', 'rb'), encoding='latin1')
    
    contact_path = 'data/' + dataset + '/contact_maps'
    contact_list = []
    for key in proteins_:
        contact_list.append(os.path.join(contact_path, key + '.npy'))

    # Lists 
    
    drugs = []
    proteins = []
    protein_keys = []
    drug_smiles = []
    
    # Convert SMILES
    
    for d in drugs_.keys():
        mol = Chem.MolToSmiles(Chem.MolFromSmiles(drugs_[d]), isomericSmiles=True)
        drugs.append(mol)
        drug_smiles.append(drugs_[d])
        
    # Protein sequences 
    
    for p in proteins_.keys():
        proteins.append(proteins_[p])
        protein_keys.append(p)
           
    # Need to transform the Davis scores 
    
    if dataset == 'davis':
        interaction_scores = [-np.log10(label/1e9) for label in interaction_scores]
    interaction_scores = np.asarray(interaction_scores)
    
    # Testing data
    rows, cols = np.where(np.isnan(interaction_scores) == False) # Get rows with data
    rows, cols = rows[test], cols[test]
    test_entries = []
    for pair_ind in range(len(rows)):
        ls = []
        ls += [drugs[rows[pair_ind]]]
        ls += [proteins[cols[pair_ind]]]
        ls += [protein_keys[cols[pair_ind]]]
        ls += [interaction_scores[rows[pair_ind], cols[pair_ind]]]
        test.append(ls)
        test_count += 1
        
    csv_file = 'data/' + dataset + '_test' + '.csv'
    data_to_csv(csv_file, test_entries)

    compound_smiles = drugs
    protein_key = protein_keys

    # SMILES to graph
    
    smile_graph = {}
    for smile in compound_smiles:
        graph = smile_to_graph(smile)
        smile_graph[smile] = graph

    # Protein to graph 
    
    protein_graph = {}
    for key in protein_key:
        # print(key)
        graph = protein_to_graph(key, proteins_[key], contact_path)
        protein_graph[key] = graph
        
    test_csv = 'data/' + dataset + '_' + 'test' + '.csv'
    df_test = pd.read_csv(test_csv)
    
    test_drugs, test_prot_keys, test_Y = list(df_test['compound_smiles']), list(df_test['protein_key']), list(df_test['interaction_scores'])
    test_drugs, test_prot_keys, test_Y = np.asarray(test_drugs), np.asarray(test_prot_keys), np.asarray(v_Y)
    test_dataset = ProtDrugDataset(root='data', dataset=dataset + '_' + 'train', xd=test_drugs, protein_key=test_prot_keys, label=test_Y, smile_graph=smile_graph, protein_graph=protein_graph)

    return test_dataset

In [100]:
TEST_BATCH_SIZE = 256

model = GCNNet()
model.to(device)

model.load_state_dict(torch.load(model_file_name, map_location='cuda:0'))

test_data = create_test_data(dataset)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=TEST_BATCH_SIZE, shuffle=False, collate_fn=collate)

y_true, y_preds = predicting(model, device, test_loader)
calculate_metrics(y_true, y_preds, dataset)

GCNNet Loaded


FileNotFoundError: [Errno 2] No such file or directory: 'models/model_Davis_GCN_3_davis.model'