In [1]:
import torch

def format_pytorch_version(version):
    """Return the part of a PyTorch version string that precedes the first '+'.

    A string without '+' is returned unchanged.
    """
    release, _, _ = version.partition('+')
    return release

# Version string of the installed PyTorch build; TORCH is the same string with any
# '+<suffix>' removed so it can be interpolated into the wheel-index URL below.
TORCH_version = torch.__version__
TORCH = format_pytorch_version(TORCH_version)

def format_cuda_version(version):
    """Turn a CUDA version string into the tag used in the PyG wheel-index URL.

    Args:
        version: CUDA version such as '11.3', or None. `torch.version.cuda` is
            None when PyTorch was built without CUDA support.

    Returns:
        'cu' followed by the version with the dots removed (e.g. 'cu113'),
        or 'cpu' when `version` is None.
    """
    if version is None:
        # CPU-only PyTorch build: the wheel index uses the 'cpu' tag instead of 'cuXYZ'.
        return 'cpu'
    return 'cu' + version.replace('.', '')

# CUDA version the installed PyTorch reports, converted to the tag that is interpolated
# into the wheel-index URL below ('cu113' in the recorded run of this cell).
CUDA_version = torch.version.cuda
CUDA = format_cuda_version(CUDA_version)

!pip install torch-scatter     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-sparse      -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install pip install git+https://github.com/pyg-team/pytorch_geometric.git

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://pytorch-geometric.com/whl/torch-1.12.1+cu113.html
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://pytorch-geometric.com/whl/torch-1.12.1+cu113.html
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/pyg-team/pytorch_geometric.git
  Cloning https://github.com/pyg-team/pytorch_geometric.git to /tmp/pip-req-build-zdpu5pll
  Running command git clone -q https://github.com/pyg-team/pytorch_geometric.git /tmp/pip-req-build-zdpu5pll


In [2]:
import torch
from torch import nn
import numpy as np
import torch_geometric
from torch_geometric.nn import GATConv
from torch_geometric.nn import global_max_pool as gmp
from torch_geometric.data import InMemoryDataset
from torch_geometric.loader import DataLoader
from torch_geometric.utils import to_dense_batch
from math import sqrt
from scipy import stats

In [3]:
class att_GATNet(torch.nn.Module):
    """Regression network scoring a (drug graph, target sequence) pair.

    The drug graph is encoded by two GATConv layers, the target token sequence by
    an embedding followed by three unpadded Conv1d layers. A pairwise attention
    block re-weights both feature maps; each is then max-pooled, the two vectors
    are concatenated and an MLP maps them to a single output value.

    Args:
        n_features_drug: Input feature size per drug-graph node.
        drug_output_dim: Feature size per node after the second GAT layer.
        n_gat_heads: Number of attention heads in the first GAT layer.
        target_length_max: Target sequence length the final max-pool is sized
            for (see the note on `target_maxpool` below).
        n_features_target: Number of target token types; the embedding table has
            one extra row (presumably for a padding index -- TODO confirm).
        target_embed_dim: Embedding size per target token.
        n_cnn_filters: Channels of the first Conv1d; later layers use 2x and 4x.
        target_kernel: Three kernel sizes, one per Conv1d layer.
        dropout: Dropout probability used for the shared Dropout module and
            inside both GATConv layers.

    Note:
        The drug and target attention projections are added together in
        `forward`, so `drug_output_dim` has to be compatible with
        `n_cnn_filters * 4` (they are equal with the defaults).
    """

    def __init__(self, n_features_drug=78, drug_output_dim=128, n_gat_heads=10, target_length_max=1000, n_features_target=25, target_embed_dim=128, n_cnn_filters=32, target_kernel=(4, 8, 12), dropout=0.2):
        super().__init__()

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.dropout = nn.Dropout(dropout)

        # graph layers: gat2 consumes n_features_drug * n_gat_heads features per node
        self.gat1 = GATConv(n_features_drug, n_features_drug, heads=n_gat_heads, dropout=dropout)
        self.gat2 = GATConv(n_features_drug * n_gat_heads, drug_output_dim, dropout=dropout)

        # target embedding
        self.target_embedding = nn.Embedding(num_embeddings=n_features_target + 1, embedding_dim=target_embed_dim)

        # 1D convolutions on the target sequence embeddings (no padding, default stride)
        self.Target_CNNs = nn.Sequential(
            nn.Conv1d(in_channels=target_embed_dim, out_channels=n_cnn_filters, kernel_size=target_kernel[0], padding=0),
            nn.ReLU(),
            nn.Conv1d(in_channels=n_cnn_filters, out_channels=n_cnn_filters * 2, kernel_size=target_kernel[1], padding=0),
            nn.ReLU(),
            nn.Conv1d(in_channels=n_cnn_filters * 2, out_channels=n_cnn_filters * 4, kernel_size=target_kernel[2], padding=0),
            nn.ReLU(),
        )

        # attention layers
        self.drug_attention_layer = nn.Linear(drug_output_dim, drug_output_dim)
        self.target_attention_layer = nn.Linear(n_cnn_filters * 4, n_cnn_filters * 4)
        self.attention_layer = nn.Linear(n_cnn_filters * 4, n_cnn_filters * 4)

        # The pool kernel equals the length left after the three unpadded convolutions
        # of a sequence of exactly target_length_max tokens (each conv removes k - 1
        # positions), so for such an input the pool reduces the sequence axis to 1.
        self.target_maxpool = nn.MaxPool1d(kernel_size=target_length_max - target_kernel[0] - target_kernel[1] - target_kernel[2] + 3)

        # head prediction MLP; the same Dropout module instance is reused between layers
        self.Head_MLP = nn.Sequential(
            self.dropout,
            nn.Linear(n_cnn_filters * 4 + drug_output_dim, 1024),
            nn.LeakyReLU(),
            self.dropout,
            nn.Linear(1024, 1024),
            nn.LeakyReLU(),
            self.dropout,
            nn.Linear(1024, 512),
            nn.LeakyReLU(),
            self.dropout,
            nn.Linear(512, 1),
        )

    def forward(self, data):
        """Compute the prediction for a batch.

        Args:
            data: Batch object providing `x`, `edge_index` and `batch` for the
                drug graphs and `target` for the target token indices
                (assumed to be integer tensors of shape (batch, sequence) --
                TODO confirm against the dataset).

        Returns:
            Output of the head MLP, whose last layer has one output feature.
        """
        # drug input processed by GATs
        x_drug, edge_index, batch = data.x, data.edge_index, data.batch

        x_drug = self.dropout(x_drug)
        x_drug = nn.functional.elu(self.gat1(x_drug, edge_index))
        x_drug = self.dropout(x_drug)
        x_drug = self.relu(self.gat2(x_drug, edge_index))
        # Regroup the flat node features into one row of nodes per graph.
        # NOTE(review): the second return value (presumably the padding mask) is
        # discarded, so any padded node rows take part in the mean over drug nodes
        # that produces target_weight below -- confirm this is intended.
        x_drug, _ = to_dense_batch(x=x_drug, batch=batch)

        # target input processed by CNNs: move the embedding axis in front of the
        # sequence axis, as Conv1d convolves over the last axis
        x_target = data.target
        x_target = self.target_embedding(x_target)
        x_target = x_target.permute(0, 2, 1)
        x_target = self.Target_CNNs(x_target)

        # pairwise attention between every drug node and every target position
        drug_att = self.drug_attention_layer(x_drug)
        target_att = self.target_attention_layer(x_target.permute(0, 2, 1))
        # Broadcasting builds the (batch, node, position, channel) sum directly;
        # it yields the same values as repeating both operands to full size first,
        # without allocating the two expanded copies.
        pairwise_sum = torch.unsqueeze(drug_att, 2) + torch.unsqueeze(target_att, 1)
        Atten_matrix = self.attention_layer(self.relu(pairwise_sum))
        # average over target positions -> per-node weights; over drug nodes -> per-position weights
        drug_weight = self.sigmoid(torch.mean(Atten_matrix, 2))
        target_weight = self.sigmoid(torch.mean(Atten_matrix, 1).permute(0, 2, 1))

        # re-weight the features: half of the original signal plus the attention-gated signal
        x_drug = x_drug * 0.5 + x_drug * drug_weight
        x_target = x_target * 0.5 + x_target * target_weight

        # pool over drug nodes and over target positions, then concatenate
        x_drug, _ = torch.max(x_drug, dim=1)
        x_target = self.target_maxpool(x_target).squeeze(2)
        x = torch.cat([x_drug, x_target], dim=1)

        # head MLP layers to make predictions
        y = self.Head_MLP(x)

        return y

In [4]:
# Run configuration.
TRAIN_BATCH_SIZE = 50  # intended mini-batch size for the training loader
TEST_BATCH_SIZE = 50   # mini-batch size passed to the data loaders
LR = 0.0005            # learning rate given to the Adam optimizer
LOG_INTERVAL = 50      # train() prints the loss every LOG_INTERVAL batches
NUM_EPOCHS = 100       # number of passes of the training loop
dataset = 'davis'      # dataset tag, used in the result/model file names and log lines
model_st = 'att_GAT'   # model tag, used in the result/model file names and log lines

In [5]:
# Load the pre-processed (data, slices) pairs into bare InMemoryDataset objects.
# NOTE: torch.load unpickles the file, so only load files from a trusted source.
data_train = InMemoryDataset()
data_train.data, data_train.slices = torch.load('drive/MyDrive/dti_model/data/davis_train.pt')
# The training loader uses the training batch size and reshuffles every epoch.
train_loader = DataLoader(data_train, batch_size=TRAIN_BATCH_SIZE, shuffle=True)

data_test = InMemoryDataset()
data_test.data, data_test.slices = torch.load('drive/MyDrive/dti_model/data/davis_test.pt')
# Evaluation does not need shuffling; a fixed order keeps predictions reproducible.
test_loader = DataLoader(data_test, batch_size=TEST_BATCH_SIZE, shuffle=False)



In [6]:
# training function at each epoch
def train(model, device, train_loader, optimizer, epoch):
    """Run one optimisation pass over `train_loader`.

    Uses the module-level `loss_fn` as the criterion and prints the current
    loss every `LOG_INTERVAL` batches (also a module-level name).

    Args:
        model: Network to train; switched to training mode here.
        device: Device each batch is moved to.
        train_loader: Iterable of batches exposing `.to(device)` and `.y`.
        optimizer: Optimizer stepping the parameters of `model`.
        epoch: Epoch number, used only in the progress message.
    """
    print('Training on {} samples...'.format(len(train_loader.dataset)))
    model.train()
    for step, batch in enumerate(train_loader):
        batch = batch.to(device)
        optimizer.zero_grad()
        prediction = model(batch)
        # labels as a float column so they line up with the model output
        target = batch.y.view(-1, 1).float().to(device)
        loss = loss_fn(prediction, target)
        loss.backward()
        optimizer.step()
        if step % LOG_INTERVAL != 0:
            continue
        percent_done = 100. * step / len(train_loader)
        print('Train epoch: {} [({:.0f}%)]\tLoss: {:.6f}'.format(epoch, percent_done, loss.item()))

def predicting(model, device, loader):
    """Run `model` over every batch of `loader` without tracking gradients.

    Args:
        model: Network to evaluate; switched to evaluation mode here.
        device: Device each batch is moved to before the forward pass.
        loader: Iterable of batches exposing `.to(device)` and `.y`, with a
            `.dataset` attribute that supports `len()`.

    Returns:
        Tuple `(labels, predictions)` of flattened NumPy arrays, in the order
        the loader produced the batches. Both are empty if the loader yields
        no batches.
    """
    model.eval()
    # Collect per-batch results and concatenate once at the end instead of
    # re-copying the growing tensor on every batch. Each list starts with an
    # empty float tensor so the concatenation is valid for an empty loader and
    # the result dtype is promoted against the default float type.
    pred_chunks = [torch.Tensor()]
    label_chunks = [torch.Tensor()]
    print('Make prediction for {} samples...'.format(len(loader.dataset)))
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            output = model(data)
            pred_chunks.append(output.cpu())
            label_chunks.append(data.y.view(-1, 1).cpu())
    total_preds = torch.cat(pred_chunks, 0)
    total_labels = torch.cat(label_chunks, 0)
    return total_labels.numpy().flatten(),total_preds.numpy().flatten()


def rmse(y,f):
    """Root-mean-square error between `y` and `f`, with the mean taken along axis 0."""
    squared_error = (y - f) ** 2
    return sqrt(squared_error.mean(axis=0))
def mse(y,f):
    """Mean squared error between `y` and `f`, with the mean taken along axis 0."""
    residual = y - f
    return (residual ** 2).mean(axis=0)
def pearson(y,f):
    """Pearson correlation coefficient between `y` and `f`."""
    # off-diagonal entry of the 2x2 correlation matrix
    correlation_matrix = np.corrcoef(y, f)
    return correlation_matrix[0, 1]
def spearman(y,f):
    """Spearman rank correlation coefficient between `y` and `f`."""
    # spearmanr returns (correlation, p-value); only the correlation is needed
    result = stats.spearmanr(y, f)
    return result[0]
def ci(y,f):
    """Concordance index of predictions `f` with respect to labels `y`.

    Every pair of samples with different labels is counted once. A pair scores
    1 if the sample with the larger label also has the larger prediction, 0.5
    if the two predictions are equal, and 0 otherwise. The result is the total
    score divided by the number of such pairs.

    Args:
        y: 1-D array-like of true values.
        f: 1-D array-like of predicted values, aligned with `y`.

    Returns:
        The concordance index as a float, or NaN when no pair of samples has
        different labels (including inputs with fewer than two samples), in
        which case the index is undefined.
    """
    y = np.asarray(y)
    f = np.asarray(f)
    # Sort by label so that, for position i, all candidates with a smaller label
    # are among positions 0..i-1.
    order = np.argsort(y)
    y = y[order]
    f = f[order]
    comparable_pairs = 0.0
    score = 0.0
    for i in range(1, len(y)):
        # Ties in the label are excluded: only strictly smaller labels form a pair.
        smaller_label = y[:i] < y[i]
        n_pairs = np.count_nonzero(smaller_label)
        if n_pairs == 0:
            continue
        prediction_gap = f[i] - f[:i][smaller_label]
        comparable_pairs += n_pairs
        score += np.count_nonzero(prediction_gap > 0) + 0.5 * np.count_nonzero(prediction_gap == 0)
    if comparable_pairs == 0:
        return float('nan')
    return float(score / comparable_pairs)

In [7]:
# Build the model, then train for NUM_EPOCHS epochs, evaluating after each epoch and
# keeping the weights and metrics of the epoch with the lowest MSE.
cuda_name = "cuda:0"
device = torch.device(cuda_name if torch.cuda.is_available() else "cpu")
model = att_GATNet().to(device)
# Rolling checkpoint: written once before training and overwritten after every epoch.
checkpoint_file_name = 'drive/MyDrive/dti_model/checkpoint/model_checkpoint.pt'
torch.save(model.state_dict(), checkpoint_file_name)
loss_fn = nn.MSELoss()  # also read by train() as a module-level name
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
# Start from +inf so the first finite MSE always counts as an improvement,
# whatever the scale of the labels.
best_mse = float('inf')
best_ci = 0
best_epoch = -1
model_file_name = 'drive/MyDrive/dti_model/results/model_' + model_st + '_' + dataset +  '.pt'
result_file_name = 'drive/MyDrive/dti_model/results/result_' + model_st + '_' + dataset +  '.csv'
for epoch in range(NUM_EPOCHS):
    train(model, device, train_loader, optimizer, epoch+1)
    torch.save(model.state_dict(), checkpoint_file_name)
    # G: labels, P: predictions
    # NOTE(review): the "best" epoch is selected on test_loader metrics; a separate
    # validation split would avoid selecting the model on the evaluation data.
    G,P = predicting(model, device, test_loader)
    # order: rmse, mse, pearson, spearman, ci
    ret = [rmse(G,P),mse(G,P),pearson(G,P),spearman(G,P),ci(G,P)]
    if ret[1]<best_mse:
        # New best MSE: persist the weights and overwrite the one-line metrics file.
        torch.save(model.state_dict(), model_file_name)
        with open(result_file_name,'w') as f:
            f.write(','.join(map(str,ret)))
        best_epoch = epoch+1
        best_mse = ret[1]
        best_ci = ret[-1]
        print('rmse improved at epoch ', best_epoch, '; best_mse,best_ci:', best_mse,best_ci,model_st,dataset)
    else:
        print(ret[1],'No improvement since epoch ', best_epoch, '; best_mse,best_ci:', best_mse,best_ci,model_st,dataset)

Training on 25046 samples...
Make prediction for 5010 samples...
rmse improved at epoch  1 ; best_mse,best_ci: 0.66669285 0.7260917601880592 att_GAT davis
Training on 25046 samples...
Make prediction for 5010 samples...
0.93626714 No improvement since epoch  1 ; best_mse,best_ci: 0.66669285 0.7260917601880592 att_GAT davis
Training on 25046 samples...
Make prediction for 5010 samples...
1.3259116 No improvement since epoch  1 ; best_mse,best_ci: 0.66669285 0.7260917601880592 att_GAT davis
Training on 25046 samples...
Make prediction for 5010 samples...
rmse improved at epoch  4 ; best_mse,best_ci: 0.65094954 0.7584716008265955 att_GAT davis
Training on 25046 samples...
Make prediction for 5010 samples...
0.7545202 No improvement since epoch  4 ; best_mse,best_ci: 0.65094954 0.7584716008265955 att_GAT davis
Training on 25046 samples...
Make prediction for 5010 samples...
rmse improved at epoch  6 ; best_mse,best_ci: 0.6431258 0.7546829104456696 att_GAT davis
Training on 25046 samples...