In [1]:
from collections import OrderedDict
import json,pickle

dataset = 'Stitch'
# load dataset
dataset_path = 'data/' + dataset + '/'

# train_fold_origin = json.load(open(dataset_path + 'folds/train_fold_setting1.txt'))
# train_fold_origin = [e for e in train_fold_origin]  # for 5 folds

# Read the JSON files with context managers so the handles are closed right away
# (json.load(open(...)) leaves closing the file to the garbage collector).
with open(dataset_path + 'ligands_can.txt') as ligands_file:
    ligands = json.load(ligands_file, object_pairs_hook=OrderedDict)
with open(dataset_path + 'proteins.txt') as proteins_file:
    proteins = json.load(proteins_file, object_pairs_hook=OrderedDict)

In [2]:
# Number of ligand entries loaded from ligands_can.txt.
len(ligands)

165366

In [3]:
# drugs = []
# drug_smiles = []
from rdkit import RDLogger
from rdkit import Chem
from rdkit.Chem import MolFromSmiles
# RDLogger.DisableLog('rdApp.*')
# # smiles
# for d in ligands.keys():
#     lg = Chem.MolToSmiles(Chem.MolFromSmiles(ligands[d]), isomericSmiles=True)
#     drugs.append(lg)
#     drug_smiles.append(ligands[d])

In [4]:
# one-hot encoding helpers
def one_of_k_encoding(x, allowable_set):
    """One-hot encode ``x`` against ``allowable_set``.

    Returns a list with one boolean per element of ``allowable_set``; the
    entry is True where the element equals ``x``.

    Raises:
        Exception: if ``x`` is not a member of ``allowable_set``.
    """
    if x in allowable_set:
        return [x == candidate for candidate in allowable_set]
    raise Exception('input {0} not in allowable set{1}:'.format(x, allowable_set))


def one_of_k_encoding_unk(x, allowable_set):
    """One-hot encode ``x``, sending values outside ``allowable_set`` to the last slot.

    Unknown inputs are treated as ``allowable_set[-1]`` (the catch-all entry),
    so this never raises for unseen values.
    """
    key = x if x in allowable_set else allowable_set[-1]
    return [key == candidate for candidate in allowable_set]
    
# mol atom feature for mol graph
def atom_features(atom):
    """Build the 78-slot boolean feature vector for one RDKit atom.

    Concatenated in this order:
      * 44 slots: element symbol (unlisted symbols land in the trailing 'X' slot)
      * 11 slots: degree 0..10 (one_of_k_encoding raises outside this range)
      * 11 slots: total hydrogen count 0..10 (larger values -> last slot)
      * 11 slots: implicit valence 0..10 (larger values -> last slot)
      *  1 slot : aromatic flag
    """
    element_symbols = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe', 'As',
                       'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb', 'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se',
                       'Ti', 'Zn', 'H', 'Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In', 'Mn', 'Zr', 'Cr',
                       'Pt', 'Hg', 'Pb', 'X']
    zero_to_ten = list(range(11))

    parts = []
    parts += one_of_k_encoding_unk(atom.GetSymbol(), element_symbols)
    parts += one_of_k_encoding(atom.GetDegree(), zero_to_ten)
    parts += one_of_k_encoding_unk(atom.GetTotalNumHs(), zero_to_ten)
    parts += one_of_k_encoding_unk(atom.GetImplicitValence(), zero_to_ten)
    parts.append(atom.GetIsAromatic())
    return np.array(parts)

In [5]:
# mol smile to mol graph edge index
import networkx as nx
import numpy as np
def smile_to_graph(smile):
    """Convert a SMILES string into the graph tuple consumed by ``DTADataset``.

    Returns:
        ``(c_size, features, edge_index, mol_edge_weight)`` where
          * ``c_size``: number of atoms in the RDKit molecule,
          * ``features``: per-atom vectors from ``atom_features``, each divided
            by its own sum,
          * ``edge_index``: ``[src, dst]`` pairs covering every bond in both
            directions plus one self-loop per atom, in row-major order,
          * ``mol_edge_weight``: a ``[1]`` entry per ``edge_index`` pair.

    Raises:
        ValueError: if RDKit cannot parse ``smile``.
    """
    mol = Chem.MolFromSmiles(smile)
    if mol is None:
        # MolFromSmiles signals a parse failure by returning None; fail with a
        # clear message instead of an AttributeError on mol.GetNumAtoms().
        raise ValueError('could not parse SMILES: {!r}'.format(smile))

    c_size = mol.GetNumAtoms()

    features = []
    for atom in mol.GetAtoms():
        feature = atom_features(atom)
        features.append(feature / sum(feature))

    # Symmetric adjacency (bonds are undirected) with self-loops on the diagonal.
    # Equivalent to the former networkx Graph(...).to_directed() round-trip and
    # avoids the deprecated np.matrix.
    mol_adj = np.eye(c_size)
    for bond in mol.GetBonds():
        begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        mol_adj[begin, end] = 1
        mol_adj[end, begin] = 1

    index_row, index_col = np.where(mol_adj >= 0.5)
    edge_index = [[i, j] for i, j in zip(index_row, index_col)]
    mol_edge_weight = [[1] for _ in edge_index]
    return c_size, features, edge_index, mol_edge_weight

In [6]:
# 14min
# compound_iso_smiles = drugs

# create smile graph
# smile_graph = {}
# i = 0
# for smile in compound_iso_smiles:
#     g = smile_to_graph(smile)
#     smile_graph[smile] = g
#     if i%1000==0:
#         print(i)
#     i+=1
# Load the cached SMILES -> graph-tuple dict (the result of the commented-out
# smile_to_graph loop above). NOTE(review): the path is hard-coded to Stitch
# rather than derived from `dataset`, and pickle.load can execute arbitrary
# code -- only load this file from a trusted source.
with open('data/Stitch/temp/smile_graph.pickle','rb') as f:
    smile_graph = pickle.load(f)

In [7]:
import torch
import pandas as pd

# esm_emb_path = 'data/' + dataset + '/protein_emb/UniRef50_'
# protein_dict = json.load(open('data/Stitch/proteins.txt'))
# target_reps_dict = {}
# i=0
# for key in proteins.keys():
#     target_reps_dict[protein_dict[key]] = torch. load(esm_emb_path+ key + '.pt')['mean_representations'][36]
#     if i%1000==0:
#         print(i)
#     i+=1

# with open('data/Stitch/temp/protein_rep.pickle','wb') as handle:
#     pickle.dump(target_reps_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

# Load the cached protein representations, built by the commented-out loop above
# from ESM `mean_representations[36]` and keyed by the protein_dict values. The
# dataset looks these up by the `target_sequence` CSV column. NOTE(review):
# pickle.load can execute arbitrary code -- only load trusted files.
with open('data/Stitch/temp/protein_rep.pickle', 'rb') as handle:
    target_reps_dict = pickle.load(handle)

In [8]:
# import pickle




In [9]:
# target_reps_dict

In [10]:
import os
from torch_geometric.data import InMemoryDataset, DataLoader, Batch
from torch_geometric import data as DATA
import torch
import numpy as np
import torchvision.transforms as T

# initialize the dataset
class DTADataset(InMemoryDataset):
    """In-memory drug-target affinity dataset of ``(molecule graph, protein tensor)`` pairs.

    Item ``i`` is ``(mol_graph, protein_tensor)``:
      * ``mol_graph`` is a ``torch_geometric`` ``Data`` built from the SMILES
        ``xd[i]``, with node features ``x``, ``edge_index`` of shape (2, E),
        label ``y`` and atom count ``c_size``;
      * ``protein_tensor`` is ``torch.Tensor(target_rep[target_key[i]])``.

    ``_process`` is overridden to only create ``processed_dir``. No cached files
    are written or read; ``process`` is called directly from ``__init__`` with
    the in-memory inputs.
    """

    def __init__(self, root='/tmp', dataset='davis',
                 xd=None, y=None, transform=None,
                 pre_transform=None, smile_graph=None, target_key=None, target_rep=None,
                 pre_filter=None):
        # pre_filter is appended as an optional keyword and forwarded to the
        # base class; before, it could not be set and was always None.
        super(DTADataset, self).__init__(root, transform, pre_transform, pre_filter=pre_filter)
        self.dataset = dataset
        self.process(xd, target_key, y, smile_graph, target_rep)

    @property
    def raw_file_names(self):
        # No raw files: all inputs are passed to __init__ directly.
        return []

    @property
    def processed_file_names(self):
        return [self.dataset + '_data_mol.pt', self.dataset + '_data_pro.pt']

    def _process(self):
        # Replaces the base-class processing pipeline: only make sure the
        # directory exists; the real work happens in process() from __init__.
        if not os.path.exists(self.processed_dir):
            os.makedirs(self.processed_dir)

    def process(self, xd, target_key, y, smile_graph, target_rep):
        """Build the parallel molecule-graph and protein-tensor lists.

        Args:
            xd: sequence of SMILES strings, one per sample.
            target_key: sequence of keys into ``target_rep``, one per sample.
            y: sequence of affinity labels, one per sample.
            smile_graph: dict mapping SMILES -> ``(c_size, features, edge_index,
                edge_weight)``; SMILES missing from it are converted on the fly
                with ``smile_to_graph``.
            target_rep: mapping from target key to a value accepted by
                ``torch.Tensor``.
        """
        assert (len(xd) == len(target_key) and len(xd) == len(y)), 'The three lists must be the same length!'
        data_list_mol = []
        data_list_pro = []
        for i, (entity1, key, labels) in enumerate(zip(xd, target_key, y)):
            # convert SMILES to molecular representation (cached graph if available)
            if entity1 in smile_graph:
                c_size, features, edge_index, edge_weight = smile_graph[entity1]
            else:
                c_size, features, edge_index, edge_weight = smile_to_graph(entity1)
            # make the graph ready for PyTorch Geometric: edge_index as (2, E)
            GCNData_mol = DATA.Data(x=torch.Tensor(np.array(features)),
                                    edge_index=torch.LongTensor(edge_index).transpose(1, 0),
                                    y=torch.FloatTensor([labels]))
            GCNData_mol.__setitem__('c_size', torch.LongTensor([c_size]))
            data_list_mol.append(GCNData_mol)

            data_list_pro.append(torch.Tensor(target_rep[key]))
            if i % 10000 == 0:
                print(i)

        # Filter molecules and proteins together. Filtering data_list_mol alone
        # would pair every later protein with the wrong molecule.
        if self.pre_filter is not None:
            kept = [(mol, pro) for mol, pro in zip(data_list_mol, data_list_pro) if self.pre_filter(mol)]
            data_list_mol = [mol for mol, _ in kept]
            data_list_pro = [pro for _, pro in kept]
        if self.pre_transform is not None:
            data_list_mol = [self.pre_transform(data) for data in data_list_mol]
        self.data_mol = data_list_mol
        self.data_pro = data_list_pro

    def __len__(self):
        return len(self.data_mol)

    def __getitem__(self, idx):
        return self.data_mol[idx], self.data_pro[idx]
        
def collate(batch):
    """Collate ``(mol_graph, protein_tensor)`` samples into one mini-batch.

    Returns ``(Batch of molecule graphs, stacked protein tensor)``, both in
    sample order.
    """
    graph_batch = Batch.from_data_list([sample[0] for sample in batch])
    protein_stack = torch.stack([sample[1] for sample in batch])
    return graph_batch, protein_stack

In [11]:

import warnings
# NOTE(review): this silences *every* warning for the rest of the kernel session,
# including deprecation warnings from torch / torch_geometric. Consider narrowing
# it with category=... or a warnings.catch_warnings() block.
warnings.filterwarnings("ignore")

In [12]:
# Load the training split CSV (columns used: compound_iso_smiles, target_sequence,
# affinity) and build the in-memory training dataset. Molecule graphs come from
# the smile_graph cache; protein tensors come from target_reps_dict, keyed by the
# target_sequence column.
df_train_fold = pd.read_csv('data/' + dataset + '/'+ dataset+'_' + 'train' + '.csv')
train_drugs, train_prot_keys, train_Y = list(df_train_fold['compound_iso_smiles']), list(df_train_fold['target_sequence']), list(df_train_fold['affinity'])
train_drugs, train_prot_keys, train_Y = np.asarray(train_drugs), np.asarray(train_prot_keys), np.asarray(train_Y)

train_dataset = DTADataset(root='data', dataset=dataset + '_' + 'train', xd=train_drugs, target_key=train_prot_keys,
                            y=train_Y, smile_graph=smile_graph, target_rep=target_reps_dict)

0
10000
20000
30000
40000
50000
60000
70000
80000
90000
100000
110000
120000
130000
140000
150000
160000
170000
180000
190000
200000
210000
220000
230000
240000
250000
260000
270000
280000
290000
300000
310000
320000
330000
340000
350000
360000
370000
380000
390000
400000
410000
420000
430000
440000
450000
460000
470000
480000
490000
500000
510000
520000
530000
540000
550000
560000
570000
580000
590000
600000
610000
620000
630000
640000
650000
660000
670000
680000
690000
700000
710000
720000
730000
740000
750000
760000
770000
780000
790000
800000
810000


In [13]:
# Display the dataset repr (shows the number of samples).
train_dataset

DTADataset(818602)

In [14]:
# NOTE(review): `device` is reassigned in the model-setup cell (using cuda_name),
# so this value is only used until that cell runs.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [29]:
datasets = ['davis', 'kiba','Stitch']

cuda_name = 'cuda:0'
print('cuda_name:', cuda_name)
fold = 0  # only fold 0 is used in this notebook
cross_validation_flag = True

TRAIN_BATCH_SIZE = 512
TEST_BATCH_SIZE = 512
LR = 0.001
NUM_EPOCHS = 300
# Fixed seed so the 80/20 train/validation split is reproducible across runs.
SPLIT_SEED = 42
from sklearn.model_selection import train_test_split
train_data, valid_data = train_test_split(train_dataset, shuffle=True, test_size=0.2, random_state=SPLIT_SEED)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=TRAIN_BATCH_SIZE, shuffle=True, num_workers=8,
                                           collate_fn=collate)
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=TEST_BATCH_SIZE, shuffle=False, num_workers=4,
                                           collate_fn=collate)
                                            
# next(iter(train_loader))

cuda_name: cuda:0


In [30]:
# Model

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric
from torch_geometric.nn import TransformerConv,GATConv, GCNConv,global_max_pool as gmp, global_add_pool as gap,global_mean_pool as gep,global_sort_pool
from torch_geometric.utils import dropout_adj



# GNN-based model: TransformerConv drug encoder + precomputed protein embedding
class GNNNet(torch.nn.Module):
    """Drug-target affinity regressor.

    Drug: three ``TransformerConv`` layers (ReLU after each) over the molecular
    graph, global mean pooling, then ``mol_fc_g1`` (ReLU, dropout) and
    ``mol_fc_g2`` (dropout), giving an ``output_dim`` encoding per graph.
    Target: the precomputed embedding ``data_pro`` is concatenated to the drug
    encoding unchanged. A 2-layer MLP head then maps to ``n_output`` values.

    Args:
        n_output: number of regression outputs.
        num_features_pro: input width of ``pro_fc_g1``. ``pro_fc_g1``/``pro_fc_g2``
            are NOT used in ``forward``; they are kept so saved checkpoints
            still load with matching keys.
        num_features_mol: width of the per-atom feature vectors.
        output_dim: width of the drug encoding.
        dropout: dropout probability.
        protein_emb_dim: width of the ``data_pro`` vectors. ``fc1`` takes
            ``output_dim + protein_emb_dim`` inputs. The default (2560)
            reproduces the previous hard-coded 2688 = 128 + 2560; presumably
            the ESM-2 embedding size -- TODO confirm.
    """

    def __init__(self, n_output=1, num_features_pro=54, num_features_mol=78, output_dim=128, dropout=0.2,
                 protein_emb_dim=2560):
        super(GNNNet, self).__init__()

        print('GNNNet Loaded')
        self.n_output = n_output
        # Drug branch: graph-transformer layers widening 1x -> 2x -> 4x.
        self.mol_conv1 = TransformerConv(num_features_mol, num_features_mol)
        self.mol_conv2 = TransformerConv(num_features_mol, num_features_mol * 2)
        self.mol_conv3 = TransformerConv(num_features_mol * 2, num_features_mol * 4)
        self.mol_fc_g1 = torch.nn.Linear(num_features_mol * 4, 1024)
        self.mol_fc_g2 = torch.nn.Linear(1024, output_dim)

        # Unused by forward(); kept (same names, same creation order) so saved
        # state_dicts still match key-for-key.
        self.pro_fc_g1 = torch.nn.Linear(num_features_pro, 1024)
        self.pro_fc_g2 = torch.nn.Linear(1024, output_dim)

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

        # combined layers: drug encoding (output_dim) ++ protein embedding (protein_emb_dim)
        self.fc1 = nn.Linear(output_dim + protein_emb_dim, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.out = nn.Linear(512, self.n_output)

    def forward(self, data_mol, data_pro):
        """Predict affinities for a batch.

        Args:
            data_mol: batched molecule graphs exposing ``x``, ``edge_index`` and
                ``batch``.
            data_pro: protein embeddings, concatenated along dim 1 with the
                pooled drug encoding. Assumed to be (num_graphs, protein_emb_dim)
                -- TODO confirm.

        Returns:
            Tensor of shape (num_graphs, n_output).
        """
        mol_x, mol_edge_index, mol_batch = data_mol.x, data_mol.edge_index, data_mol.batch

        x = self.relu(self.mol_conv1(mol_x, mol_edge_index))
        x = self.relu(self.mol_conv2(x, mol_edge_index))
        x = self.relu(self.mol_conv3(x, mol_edge_index))
        x = gep(x, mol_batch)  # global mean pooling -> one row per graph

        x = self.dropout(self.relu(self.mol_fc_g1(x)))
        x = self.dropout(self.mol_fc_g2(x))

        xc = torch.cat((x, data_pro), 1)
        xc = self.dropout(self.relu(self.fc1(xc)))
        xc = self.dropout(self.relu(self.fc2(xc)))
        return self.out(xc)



In [31]:


print('Learning rate: ', LR)
print('Epochs: ', NUM_EPOCHS)

# Output folders for checkpoints and result files (created on first run).
models_dir = 'models'
results_dir = 'results'
for output_dir in (models_dir, results_dir):
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

# Model / loss / optimiser for the currently selected `dataset`.
result_str = ''
USE_CUDA = torch.cuda.is_available()
device = torch.device(cuda_name if USE_CUDA else 'cpu')
model = GNNNet().to(device)

model_st = GNNNet.__name__
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

Learning rate:  0.001
Epochs:  300
GNNNet Loaded


In [32]:
# training function at each epoch
# Mixed-precision loss scaler, shared by train() across epochs.
scaler = torch.cuda.amp.GradScaler()


def train(model, device, train_loader, optimizer, epoch, log_interval=100):
    """Run one training epoch with CUDA automatic mixed precision.

    Minimises MSE between ``model(data_mol, data_pro)`` and ``data_mol.y``.
    Uses the module-level ``scaler`` for loss scaling, logs every batch loss to
    wandb under ``"loss per batch"``, and prints progress every
    ``log_interval`` batches.

    Args:
        model: module called as ``model(data_mol, data_pro)``.
        device: device the batch is moved to.
        train_loader: loader yielding ``(mol_batch, protein_tensor)`` pairs.
        optimizer: optimiser stepped via ``scaler.step``.
        epoch: epoch number (used only in the progress message).
        log_interval: print a progress line every this many batches.
    """
    print('Training on {} samples...'.format(len(train_loader.dataset)))
    model.train()
    # Use the loader's real batch size (instead of a hard-coded 512) so the
    # progress counter stays correct when TRAIN_BATCH_SIZE changes.
    batch_size = train_loader.batch_size or 512
    loss_fn = torch.nn.MSELoss()

    for batch_idx, data in enumerate(train_loader):
        data_mol = data[0].to(device)
        data_pro = data[1].to(device)
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            output = model(data_mol, data_pro)
            labels = data_mol.y.view(-1, 1)
            loss = loss_fn(output, labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Read the scalar once per batch and log a plain float, not a tensor.
        loss_value = loss.item()
        wandb.log({"loss per batch": loss_value})
        if batch_idx % log_interval == 0:
            print('Train epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch,
                                                                           batch_idx * batch_size,
                                                                           len(train_loader.dataset),
                                                                           100. * batch_idx / len(train_loader),
                                                                           loss_value))

# predict
def predicting(model, device, loader):
    """Run ``model`` over ``loader`` in eval mode and collect labels and predictions.

    Returns:
        ``(labels, preds)``: two flattened NumPy arrays in loader order (empty
        arrays when the loader yields nothing).
    """
    model.eval()
    pred_chunks = []
    label_chunks = []
    print('Make prediction for {} samples...'.format(len(loader.dataset)))
    with torch.no_grad():
        for data in loader:
            data_mol = data[0].to(device)
            data_pro = data[1].to(device)
            output = model(data_mol, data_pro)
            pred_chunks.append(output.cpu())
            label_chunks.append(data_mol.y.view(-1, 1).cpu())
    # Concatenate once at the end. Re-concatenating the growing result on every
    # batch copied all previous batches each time (quadratic in dataset size).
    total_preds = torch.cat(pred_chunks, 0) if pred_chunks else torch.Tensor()
    total_labels = torch.cat(label_chunks, 0) if label_chunks else torch.Tensor()
    return total_labels.numpy().flatten(), total_preds.numpy().flatten()

In [36]:
import sys, os
import torch
import torch.nn as nn
from torch_geometric.data import DataLoader
import wandb
# Start a Weights & Biases run; train() and the training loop log metrics to it.
# NOTE(review): project/entity are hard-coded to a personal account; consider
# reading them from the WANDB_PROJECT / WANDB_ENTITY environment variables.
wandb.init(project="my-dgraphdta-resrt-project", entity="daga06")




0,1
loss per batch,▅█▅▅▄▆▆▇▆▃▄▄▄▂▅▆▃▅▇▁▅▅▆▅▆▃▅▇▃▄█▃▅▄▃▅▇▁▃▄
test_ci,▁█
test_mse,█▁
test_pc,▁█
val_ci,▁▅█
val_mse,█▅▁
val_pc,▁▅█

0,1
loss per batch,1.88305
test_ci,0.62117
test_mse,1.56598
test_pc,0.44533
val_ci,0.62084
val_mse,1.58015
val_pc,0.44559


In [34]:

# All metrics
import os
import sys
import torch
import numpy as np
from random import shuffle
import matplotlib.pyplot as plt
from torch_geometric.data import Batch

from emetrics import get_aupr, get_cindex, get_rm2, get_ci, get_mse, get_rmse, get_pearson, get_spearman
# from utils import *
from scipy import stats
# from gnn import GNNNet
from data_process import create_dataset_for_test

from lifelines.utils import concordance_index




def load_model(model_path, map_location=None):
    """Load an object saved with ``torch.save`` from ``model_path``.

    Args:
        model_path: path of the file to load.
        map_location: forwarded to ``torch.load``. Pass e.g. ``'cpu'`` to load
            a checkpoint saved from ``cuda:0`` on a machine without that GPU.

    Security: ``torch.load`` unpickles the file, which can execute arbitrary
    code -- only load checkpoints from trusted sources.
    """
    return torch.load(model_path, map_location=map_location)


import time
def calculate_metrics(Y, P, dataset='davis'):
    """Compute and print the concordance index, MSE and Pearson r of predictions.

    Args:
        Y: ground-truth affinities.
        P: predicted affinities, aligned with ``Y``.
        dataset: name used only in the printed report.

    Returns:
        ``(mse, cindex, pearson)`` -- note the order differs from the print order.
    """
    ci_value = concordance_index(Y, P)  # lifelines implementation (GraphDTA convention)
    mse_value = get_mse(Y, P)
    pearson_value = get_pearson(Y, P)

    print('metrics for ', dataset)
    print('cindex2', ci_value)
    print('mse:', mse_value)
    print('pearson', pearson_value)

    report = (dataset + '\r\n'
              + '  mse:' + str(mse_value)
              + '  pearson:' + str(pearson_value)
              + '  ci:' + str(ci_value))
    print(report)
    return mse_value, ci_value, pearson_value

def plot_density(Y, P, fold=0, dataset='davis', out_dir='results'):
    """Scatter measured vs. predicted affinities and save the figure as a PNG.

    The figure is saved to ``<out_dir>/<dataset>_<fold>.png``; ``out_dir`` is
    created if it is missing. A y = x reference line is drawn over [5, 11] for
    davis and over [6, 16] for every other dataset.
    """
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.grid(linestyle='--')
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    ax.scatter(P, Y, color='blue', s=40, label='samples')
    ax.set_title('density of ' + dataset, fontsize=30, fontweight='bold')
    ax.set_xlabel('predicted', fontsize=30, fontweight='bold')
    ax.set_ylabel('measured', fontsize=30, fontweight='bold')
    lo, hi = (5, 11) if dataset == 'davis' else (6, 16)
    ax.plot([lo, hi], [lo, hi], color='black', label='y = x')

    # The artists are labelled, so legend() has entries to show. Before, nothing
    # was labelled and the legend was empty (None on older matplotlib, which
    # crashed get_texts()).
    leg = ax.legend(loc=0, numpoints=1)
    plt.setp(leg.get_texts(), fontsize=12, fontweight='bold')

    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(os.path.join(out_dir, dataset + '_' + str(fold) + '.png'), dpi=500, bbox_inches='tight')


    # plot_density(Y, P, fold, dataset)
import numpy as np
import subprocess
from math import sqrt
from sklearn.metrics import average_precision_score
from scipy import stats


def get_aupr(Y, P, threshold=7.0):
    """Average precision after binarising labels and predictions at ``threshold``.

    Values ``>= threshold`` become 1 and everything else 0; the binary
    predictions are then scored with ``average_precision_score``. Before, the
    ``threshold`` argument was ignored and 7.0 was always used.
    """
    Y = np.where(np.asarray(Y) >= threshold, 1, 0)
    P = np.where(np.asarray(P) >= threshold, 1, 0)
    return average_precision_score(Y, P)


def get_cindex(Y, P):
    """Concordance index of predictions ``P`` against labels ``Y`` (pairwise, O(n^2)).

    Every unordered pair of samples with different labels is comparable. A
    pair scores 1 if the sample with the larger label also has the larger
    prediction, 0.5 if the two predictions tie, and 0 otherwise. Returns the
    mean score over comparable pairs, or 0 if there are none.
    """
    summ = 0.0
    pair = 0
    for i in range(1, len(Y)):
        for j in range(i):
            # Count both orientations. The old version only counted pairs whose
            # larger label came later (Y[i] > Y[j], i > j), so the result
            # depended on sample order.
            if Y[i] > Y[j]:
                hi, lo = i, j
            elif Y[j] > Y[i]:
                hi, lo = j, i
            else:
                continue  # equal (or NaN) labels: not comparable
            pair += 1
            if P[hi] > P[lo]:
                summ += 1
            elif P[hi] == P[lo]:
                summ += 0.5
    return summ / pair if pair else 0


def r_squared_error(y_obs, y_pred):
    """Squared Pearson correlation (r^2) between observations and predictions.

    Computed from centred sums as
    ``sum(dp * do)^2 / (sum(do^2) * sum(dp^2))``, where ``do``/``dp`` are the
    deviations from the respective means.
    """
    y_obs = np.asarray(y_obs, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    # Centre with scalar means (broadcast). The old per-element list of means
    # recomputed np.mean n times, which made this O(n^2).
    obs_dev = y_obs - y_obs.mean()
    pred_dev = y_pred - y_pred.mean()

    mult = np.sum(pred_dev * obs_dev) ** 2
    y_obs_sq = np.sum(obs_dev * obs_dev)
    y_pred_sq = np.sum(pred_dev * pred_dev)
    return mult / float(y_obs_sq * y_pred_sq)


def get_k(y_obs, y_pred):
    """Slope of the through-origin fit of ``y_obs`` on ``y_pred``.

    ``k = sum(y_obs * y_pred) / sum(y_pred ** 2)``.
    """
    obs = np.array(y_obs)
    pred = np.array(y_pred)
    numerator = sum(obs * pred)
    denominator = float(sum(pred * pred))
    return numerator / denominator


def squared_error_zero(y_obs, y_pred):
    """r0^2: coefficient of determination of ``y_obs`` against ``k * y_pred``.

    ``k`` is the through-origin slope from ``get_k``. The result is
    ``1 - SS_res / SS_tot`` with ``SS_res = sum((y_obs - k*y_pred)^2)`` and
    ``SS_tot = sum((y_obs - mean(y_obs))^2)``.
    """
    k = get_k(y_obs, y_pred)

    y_obs = np.asarray(y_obs, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    # Scalar mean (broadcast). The old list of n identical means was O(n^2).
    residual = y_obs - k * y_pred
    centred = y_obs - y_obs.mean()
    upp = np.sum(residual * residual)
    down = np.sum(centred * centred)
    return 1 - (upp / float(down))


def get_rm2(ys_orig, ys_line):
    """Modified r^2 (rm^2): ``r2 * (1 - sqrt(|r2^2 - r02^2|))``.

    ``r2`` comes from ``r_squared_error`` and ``r02`` from ``squared_error_zero``.
    """
    r2 = r_squared_error(ys_orig, ys_line)
    r02 = squared_error_zero(ys_orig, ys_line)
    gap = np.absolute((r2 * r2) - (r02 * r02))
    penalty = 1 - np.sqrt(gap)
    return r2 * penalty


def get_rmse(y, f):
    """Root-mean-squared error between ``y`` and ``f`` (mean over axis 0)."""
    squared = (y - f) ** 2
    return sqrt(squared.mean(axis=0))


def get_mse(y, f):
    """Mean squared error between ``y`` and ``f``, reduced over axis 0."""
    return ((y - f) ** 2).mean(axis=0)


def get_pearson(y, f):
    """Pearson correlation coefficient between ``y`` and ``f``."""
    corr_matrix = np.corrcoef(y, f)
    return corr_matrix[0, 1]


def get_spearman(y, f):
    """Spearman rank correlation between ``y`` and ``f``; the p-value is dropped."""
    result = stats.spearmanr(y, f)
    return result[0]


def get_ci(y, f):
    """Concordance index of predictions ``f`` against labels ``y`` (O(n^2)).

    Over all pairs with different labels, this is the fraction whose
    predictions are ordered the same way; prediction ties count 0.5. Returns 0
    when there is no such pair (fewer than two samples, or all labels equal),
    as ``get_cindex`` does. Before, that case raised ZeroDivisionError.
    Lists are accepted as well as NumPy arrays.
    """
    y = np.asarray(y)
    f = np.asarray(f)
    order = np.argsort(y)
    y = y[order]
    f = f[order]

    z = 0.0  # number of comparable pairs
    S = 0.0  # concordance score
    # After sorting, y[i] >= y[j] whenever i > j, so only strictly greater pairs count.
    for i in range(len(y) - 1, 0, -1):
        for j in range(i - 1, -1, -1):
            if y[i] > y[j]:
                z += 1
                u = f[i] - f[j]
                if u > 0:
                    S += 1
                elif u == 0:
                    S += 0.5
    if z == 0:
        return 0
    return S / z

In [None]:
# Build the held-out test set exactly like the training set (same CSV columns,
# same smile_graph / target_reps_dict caches). NOTE: this rebinds the global
# `dataset` to datasets[2] ('Stitch').
dataset = datasets[2]
df_test_fold = pd.read_csv('data/' + dataset + '/'+ dataset+'_' + 'test' + '.csv')
test_drugs, test_prot_keys, test_Y = list(df_test_fold['compound_iso_smiles']), list(df_test_fold['target_sequence']), list(df_test_fold['affinity'])
test_drugs, test_prot_keys, test_Y = np.asarray(test_drugs), np.asarray(test_prot_keys), np.asarray(test_Y)

test_data = DTADataset(root='data', dataset=dataset + '_' + 'test', xd=test_drugs, target_key=test_prot_keys,
                            y=test_Y, smile_graph=smile_graph, target_rep=target_reps_dict)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=TEST_BATCH_SIZE, shuffle=False,num_workers=4,
                                              collate_fn=collate)

In [37]:
#188 minutes
# Test-set evaluation starts at this (0-based) epoch index.
TEST_EVAL_START_EPOCH = 80
best_mse = float('inf')
best_test_mse = float('inf')
best_epoch = -1
# model_file_name = 'models/model_' + model_st + '_' + dataset + '_' + str(fold) + '.model'

for epoch in range(NUM_EPOCHS):
    train(model, device, train_loader, optimizer, epoch + 1)
    print('predicting for valid data')
    G, P = predicting(model, device, valid_loader)
    val_mse, val_ci, val_pc = calculate_metrics(G, P, dataset)
    # One wandb.log call per stage keeps the three metrics on the same step.
    wandb.log({"val_ci": val_ci, "val_mse": val_mse, "val_pc": val_pc})

    if val_mse < best_mse:
        best_mse = val_mse
        best_epoch = epoch + 1
        # torch.save(model.state_dict(), model_file_name)
        print('val mse improved at epoch ', best_epoch, '; best_val_mse', best_mse, model_st, dataset, fold)
    else:
        print('No improvement since epoch ', best_epoch, '; best_val_mse', best_mse, model_st, dataset, fold)

    # Evaluate on the test set only once training has had time to converge.
    if epoch >= TEST_EVAL_START_EPOCH:
        print('predicting for test data')
        G, P = predicting(model, device, test_loader)
        test_mse, test_ci, test_pc = calculate_metrics(G, P, dataset)
        wandb.log({"test_ci": test_ci, "test_mse": test_mse, "test_pc": test_pc})
        best_test_mse = min(best_test_mse, test_mse)

Training on 654881 samples...


In [None]:
# Display the dataset repr (shows the number of samples).
train_dataset

DTADataset(818602)

GNNNet(
  (mol_conv1): TransformerConv(78, 78, heads=1)
  (mol_conv2): TransformerConv(78, 156, heads=1)
  (mol_conv3): TransformerConv(156, 312, heads=1)
  (mol_fc_g1): Linear(in_features=312, out_features=1024, bias=True)
  (mol_fc_g2): Linear(in_features=1024, out_features=128, bias=True)
  (pro_fc_g1): Linear(in_features=54, out_features=1024, bias=True)
  (pro_fc_g2): Linear(in_features=1024, out_features=128, bias=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.2, inplace=False)
  (fc1): Linear(in_features=2688, out_features=1024, bias=True)
  (fc2): Linear(in_features=1024, out_features=512, bias=True)
  (out): Linear(in_features=512, out_features=1, bias=True)
)

In [24]:
# Display the model architecture.
model

GNNNet(
  (mol_conv1): TransformerConv(78, 78, heads=1)
  (mol_conv2): TransformerConv(78, 156, heads=1)
  (mol_conv3): TransformerConv(156, 312, heads=1)
  (mol_fc_g1): Linear(in_features=312, out_features=1024, bias=True)
  (mol_fc_g2): Linear(in_features=1024, out_features=128, bias=True)
  (pro_fc_g1): Linear(in_features=54, out_features=1024, bias=True)
  (pro_fc_g2): Linear(in_features=1024, out_features=128, bias=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.2, inplace=False)
  (fc1): Linear(in_features=2688, out_features=1024, bias=True)
  (fc2): Linear(in_features=1024, out_features=512, bias=True)
  (out): Linear(in_features=512, out_features=1, bias=True)
)

In [26]:
# Inspect the weights in a saved checkpoint (a dict with a 'model_state_dict' entry).
torch.load('saved_models/esm-2-stitch-0.05_validation_fully_optimized.pt')['model_state_dict']

OrderedDict([('mol_conv1.lin_key.weight',
              tensor([[ 0.1240,  0.1045, -0.0059,  ...,  0.0367,  0.0115, -0.0878],
                      [-0.2090, -0.1334, -0.0054,  ..., -0.0217, -0.0654,  0.0870],
                      [-0.2250, -0.0868,  0.0818,  ...,  0.0653, -0.1036, -0.0764],
                      ...,
                      [ 0.1865,  0.1649, -0.0717,  ...,  0.0680, -0.0225, -0.1060],
                      [-0.7012,  0.3572, -0.0622,  ..., -0.0217,  0.1060,  0.2105],
                      [ 0.2170, -0.0510, -0.3284,  ..., -0.0188,  0.0842,  0.0688]],
                     device='cuda:0')),
             ('mol_conv1.lin_key.bias',
              tensor([ 2.6778e-01, -1.8768e-02, -2.4345e-01, -1.6533e-01, -3.3262e-01,
                      -2.5063e-01, -2.2554e-01, -8.9362e-01,  2.8967e-01,  2.6792e-01,
                       3.0093e-01, -7.1495e-01,  4.9546e-01, -3.9335e-01, -6.7751e-02,
                      -5.5899e-02,  2.9631e-01,  1.9461e-01,  5.3106e-02, -6.1868e-01

In [None]:
# torch.save({
#     'epoch':epoch,
#     'model_state_dict':model.state_dict(),
#     'optimizer_state_dict':optimizer.state_dict()
# },'saved_models/esm-2-stitch.pt')

In [27]:
# Restore trained weights. map_location=device lets a checkpoint saved on cuda:0
# also load on a CPU-only machine (the stored tensors are tagged device='cuda:0').
model.load_state_dict(torch.load('saved_models/esm-2-stitch.pt', map_location=device)['model_state_dict'])


<All keys matched successfully>

In [28]:


# Score the loaded checkpoint on the test set.
Y, P = predicting(model, device, test_loader)
calculate_metrics(Y, P, dataset)

0
10000
20000
30000
40000
50000
60000
70000
80000
90000
100000
110000
120000
130000
140000
150000
160000
170000
180000
190000
200000
210000
220000
230000
240000
250000
260000
270000
280000
290000
300000
310000
320000
330000
340000
350000
360000
370000
380000
390000
400000
410000
420000
430000
440000
Make prediction for 440786 samples...
metrics for  Stitch
cindex2 0.7394952272944774
mse: 1.0202613
pearson 0.6929313903787505
Stitch
  mse:1.0202613  pearson:0.6929313903787505  ci:0.7394952272944774


(1.0202613, 0.7394952272944774, 0.6929313903787505)

In [None]:
# Display the model architecture.
model

In [None]:
if __name__ == '__main__':
    # Standalone evaluation: rebuild the test set with data_process.create_dataset_for_test
    # and score the in-memory `model` (on `device`) from the earlier cells.
    dataset = datasets[2]  # dataset selection
    model_st = GNNNet.__name__
    print('dataset:', dataset)

    cuda_name = 'cuda:0'
    print('cuda_name:', cuda_name)

    TEST_BATCH_SIZE = 512
    # Disabled checkpoint loading, kept for reference:
    # models_dir = 'models'
    # results_dir = 'results'
    #
    # device = torch.device(cuda_name if torch.cuda.is_available() else 'cpu')
    # model_file_name = 'models/model_' + model_st + '_' + dataset + '.model'
    # result_file_name = 'results/result_' + model_st + '_' + dataset + '.txt'
    #
    # model = GNNNet()
    # model.to(device)
    # model.load_state_dict(torch.load(model_file_name, map_location=cuda_name))
    test_data = create_dataset_for_test(dataset)
    test_loader = torch.utils.data.DataLoader(
        test_data,
        batch_size=TEST_BATCH_SIZE,
        shuffle=False,
        collate_fn=collate,
    )

    Y, P = predicting(model, device, test_loader)
    calculate_metrics(Y, P, dataset)
    # plot_density(Y, P, fold, dataset)
    # plot_density(Y, P, fold, dataset)

dataset: davis
cuda_name: cuda:0
dataset: davis
test entries: 5010 effective test entries 5010


KeyboardInterrupt: 