In [1]:
import dataparser as data
import os
import numpy as np
import math
import torch
import pandas as pd
import torch.optim as optim
import util as util
import model as model
# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('device is: ', DEVICE)

# Root directory handed to data.data_select() further down.
PATH = '/home/jhbang/HDD/jeho/deepDTA_convert/data/'

# Two sub-networks from the project-local `model` module; CombineFCNet is
# constructed from both of them and is the network that gets trained/evaluated.
SmilesConvNet = model.SmilesConv1dNet().to(DEVICE)
ProteinConvNet = model.ProteinConv1dNet().to(DEVICE)
CombineNet = model.CombineFCNet(SmilesConvNet,ProteinConvNet).to(DEVICE)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(CombineNet.parameters(), lr=0.001)

# Restore previously saved weights. map_location remaps the stored tensors
# onto DEVICE, so a checkpoint written from CUDA tensors still loads when this
# runs on a CPU-only machine (the fallback DEVICE explicitly allows for).
load = torch.load('/home/jhbang/HDD/jeho/deepDTA_convert/Models/davis_DTAmodel.pt',
                  map_location=DEVICE)
CombineNet.load_state_dict(load)

def train(model, train_loader, optimizer, epoch):
    """Run one optimisation pass over every batch in ``train_loader``.

    Each batch is unpacked as ``(drug, protein, affinity)``; the three tensors
    are moved to the module-level ``DEVICE`` as ``torch.float`` before the
    forward pass. The loss is the module-level ``criterion`` applied to the
    model output (with dimension 1 squeezed) and ``affinity``.

    Args:
        model: Module called as ``model(drug, protein)``.
        train_loader: Iterable of ``(drug, protein, affinity)`` batches that
            also exposes ``len()`` and a ``dataset`` attribute (used only for
            the progress message).
        optimizer: Optimizer stepped once per batch.
        epoch: Epoch number; only used in the progress message.
    """
    model.train()
    for step, (drug, protein, affinity) in enumerate(train_loader):
        drug = drug.to(DEVICE, dtype=torch.float)
        protein = protein.to(DEVICE, dtype=torch.float)
        affinity = affinity.to(DEVICE, dtype=torch.float)

        optimizer.zero_grad()
        prediction = model(drug, protein)
        loss = criterion(torch.squeeze(prediction, 1), affinity)
        loss.backward()
        optimizer.step()

        # Only report progress on every 20th batch.
        if step % 20 != 0:
            continue

        samples_seen = step * len(drug)
        samples_total = len(train_loader.dataset)
        percent_done = 100. * step / len(train_loader)
        print('train epoch : {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, samples_seen, samples_total, percent_done, loss.item()))
        
def evaludate(model, test_loader):
    """Run ``model`` over ``test_loader`` and collect labels and predictions.

    Gradients are disabled and the model is switched to eval mode. Each batch
    is unpacked as ``(drug, protein, affinity)`` and moved to the module-level
    ``DEVICE`` as ``torch.float`` before the forward pass.

    Per-batch results are gathered in lists and concatenated once at the end,
    instead of re-concatenating a growing tensor on every batch (which copies
    all previous results each iteration).

    Args:
        model: Module called as ``model(drug, protein)``.
        test_loader: Iterable of ``(drug, protein, affinity)`` batches that
            also exposes a ``dataset`` attribute (used for the status message).

    Returns:
        Tuple ``(labels, predictions)`` of flat 1-D NumPy arrays, in loader
        order. Both are empty float32 arrays when the loader yields no batch.
    """
    model.eval()
    label_chunks = []
    pred_chunks = []

    print('Make prediction for {} sample...'.format(len(test_loader.dataset)))

    with torch.no_grad():
        for drug, protein, affinity in test_loader:
            drug = drug.to(DEVICE, dtype=torch.float)
            protein = protein.to(DEVICE, dtype=torch.float)
            affinity = affinity.to(DEVICE, dtype=torch.float)
            output = model(drug, protein)
            # Flatten per batch: concatenating flattened chunks gives the same
            # element order as stacking along dim 0 and flattening afterwards.
            pred_chunks.append(output.reshape(-1))
            label_chunks.append(affinity.reshape(-1))

    if not pred_chunks:
        # torch.cat rejects an empty list; keep the empty float32 result that
        # an empty loader produced before.
        return torch.Tensor().numpy(), torch.Tensor().numpy()

    total_labels = torch.cat(label_chunks, 0)
    total_preds = torch.cat(pred_chunks, 0)
    return total_labels.cpu().numpy().flatten(), total_preds.cpu().numpy().flatten()


# Build the train/test loaders for the 'davis' option from the data under PATH.
train_loader, test_loader = data.data_select(PATH, data_opt='davis')

# Training loop, currently disabled: only the evaluation below is executed.
# for epoch in range(100):
#     print('model train starts...')
#     train(CombineNet,train_loader, optimizer, epoch)
#     Y,P = evaludate(CombineNet, test_loader)
#     re = [epoch, util.mse(Y,P), util.ci(Y,P)]
#     print("evaluate test set: epoch: {}, mse: {}, ci {}".format(re[0],re[1],re[2]))
# print('end...')
# torch.save(CombineNet.state_dict(), '/home/jhbang/HDD/jeho/deepDTA_convert/Models/demo_model.pt')

# Evaluate on the test loader: Y holds the labels, P the predictions.
Y, P = evaludate(CombineNet, test_loader)
# re = [util.mse result, util.ci result] for the test set.
re = [util.mse(Y, P), util.ci(Y, P)]
print("evaluate test set: mse: {}, ci {}".format(*re))

device is:  cuda
davis   train  dataset loading...
checking csv file exists...
exists data file, load data...
davis   test  dataset loading...
checking csv file exists...
exists data file, load data...
data loads success!
Make prediction for 5010 sample...
torch.Size([128, 192])
torch.Size([128, 192])
torch.Size([128, 192])
torch.Size([128, 192])
torch.Size([128, 192])
torch.Size([128, 192])
torch.Size([128, 192])
torch.Size([128, 192])
torch.Size([128, 192])
torch.Size([128, 192])
torch.Size([128, 192])
torch.Size([128, 192])
torch.Size([128, 192])
torch.Size([128, 192])
torch.Size([128, 192])
torch.Size([128, 192])
torch.Size([128, 192])
torch.Size([128, 192])
torch.Size([128, 192])
torch.Size([128, 192])
torch.Size([128, 192])
torch.Size([128, 192])
torch.Size([128, 192])
torch.Size([128, 192])
torch.Size([128, 192])
torch.Size([128, 192])
torch.Size([128, 192])
torch.Size([128, 192])
torch.Size([128, 192])
torch.Size([128, 192])
torch.Size([128, 192])
torch.Size([128, 192])
torch.S