## Magnet-DTI Training Script

In [2]:
# Check if the output directory exists, if not create it
import os
if not os.path.exists("/content/output"):
    os.makedirs("/content/output")

In [3]:
# GPU memory management
import os
import torch
torch.cuda.empty_cache()
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,garbage_collection_threshold:0.8"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT = 100000
torch.cuda.memory._record_memory_history(max_entries=MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT)

## Load Data

In [4]:
import pandas as pd

# Read the XLSX file
df1=pd.read_csv("/content/MAGNET-DTI/data-celegans/02protein_ESM(1).csv")
df2 = pd.read_excel("/content/MAGNET-DTI/data-celegans/02protein_ESM(2).xlsx",header=None)
df2.columns = df1.columns

df=pd.concat([df1,df2])
# Convert to CSV
df.to_csv("/content/MAGNET-DTI/data-celegans/02protein_ESM.csv", index=False)

In [5]:
df1.shape, df2.shape, df.shape

((937, 1281), (939, 1281), (1876, 1281))

In [6]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

#GAT
class GraphAttentionLayer(nn.Module):
    """
    Simple GAT layer, similar to https://arxiv.org/abs/1710.10903. This part of code refers to the implementation of https://github.com/Diego999/pyGAT.git

    """

    def __init__(self, in_features, out_features, alpha, concat=True):
        super(GraphAttentionLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat

        self.W = nn.Parameter(torch.empty(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)
        self.a = nn.Parameter(torch.empty(size=(2 * out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)
        self.leakyrelu = nn.LeakyReLU(self.alpha)

    def forward(self, h, adj):
        #print("Here GAL")
        Wh = torch.mm(h, self.W)  # h.shape: (N, in_features), Wh.shape: (N, out_features)
        a_input = self._prepare_attentional_mechanism_input(Wh)
        e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2))  # （N，N）

        zero_vec = -9e15 * torch.ones_like(e)
        attention = torch.where(adj > 0, e, zero_vec)
        attention = F.softmax(attention, dim=1)

        h_prime = torch.matmul(attention, Wh)
        return F.relu(h_prime)

    def _prepare_attentional_mechanism_input(self, Wh):
        N = Wh.size()[0]  # number of nodes

        # Below, two matrices are created that contain embeddings in their rows in different orders.
        # (e stands for embedding)
        # These are the rows of the first matrix (Wh_repeated_in_chunks):
        # e1, e1, ..., e1,            e2, e2, ..., e2,            ..., eN, eN, ..., eN
        # '-------------' -> N times  '-------------' -> N times       '-------------' -> N times
        #
        # These are the rows of the second matrix (Wh_repeated_alternating):
        # e1, e2, ..., eN, e1, e2, ..., eN, ..., e1, e2, ..., eN
        # '----------------------------------------------------' -> N times
        #

        Wh_repeated_in_chunks = Wh.repeat_interleave(N, dim=0)
        Wh_repeated_alternating = Wh.repeat(N, 1)
        # Wh_repeated_in_chunks.shape == Wh_repeated_alternating.shape == (N * N, out_features)

        # The all_combination_matrix, created below, will look like this (|| denotes concatenation):
        # e1 || e1
        # e1 || e2
        # e1 || e3
        # ...
        # e1 || eN
        # e2 || e1
        # e2 || e2
        # e2 || e3
        # ...
        # e2 || eN
        # ...
        # eN || e1
        # eN || e2
        # eN || e3
        # ...
        # eN || eN

        all_combinations_matrix = torch.cat([Wh_repeated_in_chunks, Wh_repeated_alternating], dim=1)
        # all_combinations_matrix.shape == (N * N, 2 * out_features)

        return all_combinations_matrix.view(N, N, 2 * self.out_features)

    def __repr__(self):
        return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'


class selfattention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        batch_size, seq_len, _ = x.size()

        q = self.query(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.key(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.value(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attn = F.softmax(scores, dim=-1)

        output = torch.matmul(attn, v)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim)

        return output

#GCN
class GraphConvolutionLayer(nn.Module):
    def __init__(self, in_features, out_features):
        super(GraphConvolutionLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features

        self.W = nn.Parameter(torch.empty(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)

    def forward(self, h, adj):
        Wh = torch.mm(h, self.W)  # h.shape: (N, in_features), Wh.shape: (N, out_features)
        output = torch.matmul(adj, Wh)
        return F.relu(output)

#GAC
class GraphAttentionLayer_GAC(nn.Module):
    def __init__(self, in_features, out_features, alpha=0.2):
        super(GraphAttentionLayer_GAC, self).__init__()
        self.alpha = alpha
        self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)
        self.a = nn.Parameter(torch.zeros(size=(2*out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)

    def forward(self, input, adj):
        h = torch.matmul(input, self.W)
        N = h.size()[0]
        a_input = torch.cat([h.repeat(1, N).view(N*N, -1), h.repeat(N, 1)], dim=1).view(N, -1, 2*h.size(1))
        e = F.leaky_relu(torch.matmul(a_input, self.a).squeeze(2))
        attention = F.softmax(e, dim=1)
        h_prime = torch.matmul(attention, h)
        return F.elu(h_prime)

class GAC(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(GAC, self).__init__()
        self.gc1 = GraphAttentionLayer_GAC(input_dim, hidden_dim)
        self.gc2 = GraphAttentionLayer_GAC(hidden_dim, output_dim)

    def forward(self, input, adj):
        x = F.relu(self.gc1(input, adj))
        x = self.gc2(x, adj)
        return x

In [None]:
import numpy as np
import torch
from sklearn.decomposition import PCA
from sklearn.preprocessing import scale


def load_data():

    # protein_fea
    DC = np.genfromtxt("/content/MAGNET-DTI/data-celegans/02protein_DC.csv", delimiter=',', skip_header=1, dtype=np.dtype(str))
    protein = DC[:, 0]
    protein_number = len(protein)
    DC = np.array(DC)
    DC = scale(np.array(DC[:,1:],dtype=float))
    pca = PCA(n_components=128)
    DC = pca.fit_transform(DC)
    DC_adj = sim_graph(DC, protein_number)

    ESM = np.genfromtxt("/content/MAGNET-DTI/data-celegans/02protein_ESM.csv", delimiter=',', skip_header=1, dtype=np.dtype(str))
    protein = ESM[:, 0]
    protein_number = len(protein)
    ESM = np.array(ESM)
    ESM = scale(np.array(ESM[:,1:],dtype=float))
    pca = PCA(n_components=128)
    ESM = pca.fit_transform(ESM)
    ESM_adj = sim_graph(ESM, protein_number)

    fusion_protein_fea = np.concatenate((DC, ESM), axis=1)
    fusion_protein_adj = np.logical_or(DC_adj, ESM_adj).astype(int)

    # drug_fea
    PC = np.genfromtxt("/content/MAGNET-DTI/data-celegans/02drug_phy_chem.csv", delimiter=',',skip_header=1,dtype=np.dtype(str))
    drug = PC[:,0]
    drug_number = len(drug)
    PC = np.array(PC)
    PC = scale(np.array(PC[:, 1:], dtype=float))
    pca = PCA(n_components=128)
    PC = pca.fit_transform(PC)
    PC_adj = sim_graph(PC, drug_number)

    MACCS = np.genfromtxt("/content/MAGNET-DTI/data-celegans/02drug_MACCS.csv", delimiter=',',skip_header=1,dtype=np.dtype(str))
    drug = MACCS[:, 0]
    drug_number = len(drug)
    MACCS = np.array(MACCS)
    MACCS = scale(np.array(MACCS[:,1:], dtype=float))
    pca = PCA(n_components=128)
    MACCS = pca.fit_transform(MACCS)
    MACCS_adj = sim_graph(MACCS, drug_number)

    fusion_drug_fea = np.concatenate((PC, MACCS), axis=1)
    fusion_drug_adj = np.logical_or(PC_adj, MACCS_adj).astype(int)

    #label
    labellist = []
    with open('/content/MAGNET-DTI/data-celegans/03label_unique.txt', 'r') as file:
        lines = file.readlines()
    for line in lines:
        line = line.strip()  # Remove whitespace characters from the beginning and end of the line
        elements = line.split(" ")  # Use spaces to separate elements
        processed_elements = [int(elements[1]), int(elements[0]), int(elements[2])]  # swap places
        labellist.append(processed_elements)
    labellist = torch.Tensor(labellist)
    print("drug protein lable:", labellist.shape)

    protein_feat, protein_adj = torch.FloatTensor(fusion_protein_fea), torch.FloatTensor(fusion_protein_adj)
    drug_feat, drug_adj = torch.FloatTensor(fusion_drug_fea), torch.FloatTensor(fusion_drug_adj)
    return protein_feat, protein_adj, drug_feat, drug_adj, labellist

def sim_graph(omics_data, protein_number):
    sim_matrix = np.zeros((protein_number, protein_number), dtype=float)
    adj_matrix = np.zeros((protein_number, protein_number), dtype=float)

    for i in range(protein_number):
        for j in range(i + 1):
            sim_matrix[i, j] = np.dot(omics_data[i], omics_data[j]) / (
                        np.linalg.norm(omics_data[i]) * np.linalg.norm(omics_data[j]))
            sim_matrix[j, i] = sim_matrix[i, j]

    for i in range(protein_number):
        topindex = np.argsort(sim_matrix[i])[-10:]
        for j in topindex:
            adj_matrix[i, j] = 1
    return adj_matrix

In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MAGNETDTI(nn.Module):

    def __init__(self, nprotein, ndrug, nproteinfeat, ndrugfeat, nhid, nheads, alpha):
        """Dense version of GAT."""
        super(MAGNETDTI, self).__init__()
        #print("Here MAGNETDTI")
        self.protein_attentions1 = [GraphAttentionLayer(nproteinfeat, nhid, alpha=alpha, concat=True) for _ in range(nheads)]
        for i, attention in enumerate(self.protein_attentions1):
            self.add_module("Attention_Protein1_{}".format(i), attention)
        self.protein_MultiHead1 = [selfattention(64,nheads) for _ in range(nheads)]
        #self.protein_MultiHead1 = [selfattention(nprotein, nhid, nprotein) for _ in range(nheads)]
        for i, attention in enumerate(self.protein_MultiHead1):
            self.add_module("Self_Attention_Protein1_{}".format(i), attention)

        self.protein_prolayer1 = nn.Linear((nhid * nheads), (nhid * nheads), bias=False)
        self.protein_LNlayer1 = nn.LayerNorm(nhid * nheads)
        self.protein_attentions2 = [GraphAttentionLayer((nhid * nheads), nhid, alpha=alpha, concat=True) for _ in range(nheads)]
        for i, attention in enumerate(self.protein_attentions2):
            self.add_module("Attention_Protein2_{}".format(i), attention)

        self.protein_MultiHead2 = [selfattention(64,nheads) for _ in range(nheads)]
        #self.protein_MultiHead2 = [selfattention(nprotein, nhid, nprotein) for _ in range(nheads)]
        for i, attention in enumerate(self.protein_MultiHead2):
            self.add_module("Self_Attention_Protein2_{}".format(i), attention)

        self.protein_prolayer2 = nn.Linear((nhid * nheads), (nhid * nheads), bias=False)
        self.protein_LNlayer2 = nn.LayerNorm(nhid * nheads)
        self.drug_attentions1 = [GraphAttentionLayer(ndrugfeat, nhid, alpha=alpha, concat=True) for _ in range(nheads)]
        for i, attention in enumerate(self.drug_attentions1):
            self.add_module("Attention_Drug1_{}".format(i), attention)

        self.drug_MultiHead1 = [selfattention(64,nheads) for _ in range(nheads)]
        #self.drug_MultiHead1 = [selfattention(ndrug, nhid, ndrug) for _ in range(nheads)]
        for i, attention in enumerate(self.drug_MultiHead1):
            self.add_module("Self_Attention_Drug1_{}".format(i), attention)

        self.drug_prolayer1 = nn.Linear((nhid * nheads), (nhid * nheads), bias=False)
        self.drug_LNlayer1 = nn.LayerNorm(nhid * nheads)
        self.drug_attentions2 = [GraphAttentionLayer((nhid * nheads), nhid, alpha=alpha, concat=True) for _ in range(nheads)]
        for i, attention in enumerate(self.drug_attentions2):
            self.add_module("Attention_Drug2_{}".format(i), attention)

        self.drug_MultiHead2 = [selfattention(64,nheads) for _ in range(nheads)]
        #self.drug_MultiHead2 = [selfattention(ndrug, nhid, ndrug) for _ in range(nheads)]
        for i, attention in enumerate(self.drug_MultiHead2):
            self.add_module("Self_Attention_Drug2_{}".format(i), attention)

        self.drug_prolayer2 = nn.Linear((nhid * nheads), (nhid * nheads), bias=False)
        self.drug_LNlayer2 = nn.LayerNorm(nhid * nheads)
        self.FClayer1 = nn.Linear(nhid * nheads * 2, nhid * nheads * 2)
        self.FClayer2 = nn.Linear(nhid * nheads * 2, nhid * nheads * 2)
        self.FClayer3 = nn.Linear(nhid * nheads * 2, 1)
        self.output = nn.Sigmoid()

    def forward(self, protein_features, protein_adj, drug_features, drug_adj, idx_protein_drug, device):
        proteinx = torch.cat([att(protein_features, protein_adj) for att in self.protein_attentions1], dim=1)
        proteinx = self.protein_prolayer1(proteinx)
        #print("Here")
        proteinayer = proteinx
        temp = torch.zeros_like(proteinx)
        for selfatt in self.protein_MultiHead1:
            #print(f"Here1, proteinx.shape: {proteinx.shape}")
            temp = temp + selfatt(proteinx.unsqueeze(0))
        #print(f"{proteinx.shape}")
        proteinx = temp + proteinayer
        #print("Here2")
        proteinx = self.protein_LNlayer1(proteinx)
        #print(f"{proteinx.shape}, {protein_adj.shape}")
        proteinx = torch.cat([att(proteinx[0], protein_adj) for att in self.protein_attentions2], dim=1)
        proteinx = self.protein_prolayer2(proteinx)
        proteinayer = proteinx
        temp = torch.zeros_like(proteinx)
        for selfatt in self.protein_MultiHead2:
            temp = temp + selfatt(proteinx.unsqueeze(0))
        proteinx = temp + proteinayer
        #print("Here3")
        proteinx = self.protein_LNlayer2(proteinx)
        drugx = torch.cat([att(drug_features, drug_adj) for att in self.drug_attentions1], dim=1)
        drugx = self.drug_prolayer1(drugx)
        druglayer = drugx
        temp = torch.zeros_like(drugx)
        for selfatt in self.drug_MultiHead1:
            temp = temp + selfatt(drugx.unsqueeze(0))

        drugx = temp + druglayer
        drugx = self.drug_LNlayer1(drugx)
        drugx = torch.cat([att(drugx[0], drug_adj) for att in self.drug_attentions2], dim=1)
        drugx = self.drug_prolayer2(drugx.unsqueeze(0))
        druglayer = drugx
        temp = torch.zeros_like(drugx)
        #print(f"Here4: drugx.shape {drugx.shape}")
        for selfatt in self.drug_MultiHead2:
            temp = temp + selfatt(drugx)
        drugx = temp + druglayer
        drugx = self.drug_LNlayer2(drugx)
        #print("drug:", idx_protein_drug[:, 1])
        #print("protein:", idx_protein_drug[:, 0])
        #print("Here5")
        #print(f"proteinx.shape: {proteinx.shape}, drugx.shape: {drugx.shape}")
        proteinx = proteinx.squeeze(0)
        drugx = drugx.squeeze(0)
        protein_drug_x = torch.cat((proteinx[idx_protein_drug[:, 0]], drugx[idx_protein_drug[:, 1]]), dim=1)#------------>Error Here
        protein_drug_x = protein_drug_x.to(device)
        protein_drug_x = self.FClayer1(protein_drug_x)
        protein_drug_x = F.relu(protein_drug_x)
        protein_drug_x = self.FClayer2(protein_drug_x)
        protein_drug_x = F.relu(protein_drug_x)
        protein_drug_x = self.FClayer3(protein_drug_x)
        protein_drug_x = protein_drug_x.squeeze(-1)
        return protein_drug_x

In [9]:
import os
import time
import psutil
import random
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import random
from torch.autograd import Variable
from sklearn.metrics import r2_score
from scipy.stats import pearsonr

from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error
from sklearn.metrics import roc_auc_score, average_precision_score

#from utils import load_data
#from model import MAGNET-DTI
import torch.utils.data as Dataset

from datetime import datetime
import csv

# Training settings
# parser = argparse.ArgumentParser()
# parser.add_argument('--seed', type=int, default=0, help='Random seed.')
# parser.add_argument('--epochs', type=int, default=30, help='Number of epochs to train.')
# parser.add_argument('--lr', type=float, default=0.0001, help='Initial learning rate.')
# parser.add_argument('--batch', type=int, default=128, help='Number of batch size')
# parser.add_argument('--hidden', type=int, default=16, help='Number of hidden units.')
# parser.add_argument('--nb_heads', type=int, default=4, help='Number of head attentions.')
# parser.add_argument('--alpha', type=float, default=0.2, help='Alpha for the leaky_relu.')
# parser.add_argument('--patience', type=int, default=5, help='Patience')

seed_value=0
epochs=10
lr=0.0001
batch=64
hidden=16
nb_heads=4
alpha=0.2
patience=5

#args = parser.parse_args()

random.seed(seed_value)
np.random.seed(seed_value)
torch.manual_seed(seed_value)

if torch.cuda.is_available():
    device = torch.device("cuda:0")
    torch.cuda.manual_seed(seed_value)
    used_memory = torch.cuda.memory_allocated()
    cached_memory = torch.cuda.memory_reserved()
    print(f"GPU success，server GPU assigned：{used_memory / 1024 ** 3:.2f} GB，Cached：{cached_memory / 1024 ** 3:.2f} GB")
else:
    device = torch.device("cpu")

GPU success，server GPU assigned：0.00 GB，Cached：0.00 GB


In [10]:
import time
import numpy as np
import csv
from tqdm.notebook import tqdm
import wandb
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score, roc_auc_score, average_precision_score
from scipy.stats import pearsonr


def train(epoch, index_tra, y_tra, index_val, y_val):
    time_begin = time.time()

    output_train = [model_date]
    output_valid = [model_date]

    print("training set size", len(y_tra))

    tra_dataset = Dataset.TensorDataset(index_tra, y_tra)
    train_dataset = Dataset.DataLoader(tra_dataset, batch_size=batch, shuffle=True)
    print("batch size", len(train_dataset))

    model.train()

    # Training loop with tqdm progress bar
    for index_trian, y_train in tqdm(train_dataset, desc=f"Training Epoch {epoch+1}", unit="batch"):
        y_train = y_train.to(device)

        y_tpred = model(protein_features, protein_adj, drug_features, drug_adj, index_trian.numpy().astype(int), device)
        loss_train = loss_func(y_tpred, y_train)

        optimizer.zero_grad()
        loss_train.backward()
        optimizer.step()

        # Log training metrics to W&B
        wandb.log({"train_loss": loss_train.item(), "epoch": epoch + 1})

    model.eval()

    loss_valid, RMSE_valid, PCC_valid, R2_valid = [], [], [], []
    val_dataset = Dataset.TensorDataset(index_val, y_val)
    valid_dataset = Dataset.DataLoader(val_dataset, batch_size=batch, shuffle=True)

    pred_valid, true_valid = [], []

    # Validation loop with tqdm progress bar
    for index_valid, y_valid in tqdm(valid_dataset, desc=f"Validation Epoch {epoch+1}", unit="batch"):
        y_valid = y_valid.to(device)

        y_vpred = model(protein_features, protein_adj, drug_features, drug_adj, index_valid.numpy().astype(int), device)
        loss = loss_func(y_vpred, y_valid)

        pred_valid.extend(y_vpred.cpu().detach().numpy())
        true_valid.extend(y_valid.cpu().detach().numpy())

    time_over = time.time()

    # Compute validation metrics
    loss_valid = mean_squared_error(true_valid, pred_valid)
    RMSE_valid = np.sqrt(loss_valid)
    MAE_valid = mean_absolute_error(true_valid, pred_valid)
    PCC_valid = pearsonr(true_valid, pred_valid)[0]
    R2_valid = r2_score(true_valid, pred_valid)
    AUC_valid = roc_auc_score(true_valid ,pred_valid)
    AUPR_valid = average_precision_score(true_valid,pred_valid)

    # Log validation metrics to W&B
    wandb.log({
        "val_loss": loss_valid,
        "val_RMSE": RMSE_valid,
        "val_MAE": MAE_valid,
        "val_PCC": PCC_valid,
        "val_R2": R2_valid,
        "val_AUC": AUC_valid,
        "val_AUPR": AUPR_valid,
        "epoch": epoch + 1,
        "time_elapsed": time_over - time_begin,
    })

    # Save validation results to CSV
    output_valid.extend([epoch+1, loss_valid, RMSE_valid, MAE_valid, PCC_valid,
                         R2_valid, AUC_valid, AUPR_valid, time_over-time_begin])

    with open('/content/output/records_valid.csv', mode='a', newline='') as file:
         writer = csv.writer(file)
         writer.writerow(output_valid)

    # Compute training metrics for the last batch
    pred_train = y_tpred.cpu().detach().numpy()
    true_train = y_train.cpu().detach().numpy()

    RMSE_train = np.sqrt(loss_train.item())
    MAE_train = mean_absolute_error(true_train, pred_train)
    PCC_train = pearsonr(true_train, pred_train)[0]
    R2_train = r2_score(true_train, pred_train)

    AUC_train = roc_auc_score(true_train,pred_train) if len(np.unique(true_train)) > 1 else float('nan')
    AUPR_train = average_precision_score(true_train,pred_train) if len(np.unique(true_train)) > 1 else float('nan')

    output_train.extend([epoch+1, loss_train.item(), RMSE_train,
                         MAE_train,PCC_train,R2_train,AUC_train,AUPR_train,time_over-time_begin])

    with open('/content/output/records_train.csv', mode='a', newline='') as file:
         writer = csv.writer(file)
         writer.writerow(output_train)

    # Log training metrics to W&B for the last batch of the epoch
    wandb.log({
         "train_RMSE": RMSE_train,
         "train_MAE": MAE_train,
         "train_PCC": PCC_train,
         "train_R2": R2_train,
         "train_AUC": AUC_train,
         "train_AUPR": AUPR_train,
         "epoch": epoch + 1,
     })

    print('Epoch: {:04d}'.format(epoch + 1),
           '\n loss_train: {:.4f}'.format(loss_train.item()),
           'RMSE_train: {:.4f}'.format(RMSE_train),
           'MAE_train: {:.4f}'.format(MAE_train),
           'PCC_train: {:.4f}'.format(PCC_train),
           'R2_train: {:.4f}'.format(R2_train),
           'AUC_train: {:.4f}'.format(AUC_train),
           'AUPR_train: {:.4f}'.format(AUPR_train),
           '\n loss_valid: {:.4f}'.format(loss_valid),
           'RMSE_valid: {:.4f}'.format(RMSE_valid),
           'MAE_valid: {:.4f}'.format(MAE_valid),
           'PCC_valid: {:.4f}'.format(PCC_valid),
           'R2_valid: {:.4f}'.format(R2_valid),
           'AUC_valid: {:.4f}'.format(AUC_valid),
           'AUPR_valid: {:.4f}'.format(AUPR_valid ),
           'time: {:.4f}s'.format(time_over-time_begin))

    if AUC_valid >= best_value[0] and AUPR_valid >= best_value[1]:
         best_value[0] = AUC_valid
         best_value[1] = AUPR_valid
         best_value[2] = epoch + 1

         torch.save(model.state_dict(), f"/content/output/models_{model_date}.pkl")

    return best_value[2], AUC_valid

# End of function train()

In [11]:
def append_results_to_df(df, model_date, RMSE_test, MAE_test, PCC_test, R2_test, AUC_test, AUPR_test):
    new_row = pd.DataFrame({
        'model_date': [model_date],
        'RMSE_test': [RMSE_test],
        'MAE_test': [MAE_test],
        'PCC_test': [PCC_test],
        'R2_test': [R2_test],
        'AUC_test': [AUC_test],
        'AUPR_test': [AUPR_test]
    })
    return pd.concat([df, new_row], ignore_index=True)

def compute_test(index_test, y_test):
    eval = {}
    model.eval()
    loss_test, PCC_test, RMSE_test, R2_test = [], [], [], []
    pred_test, true_test = [], []
    dataset = Dataset.TensorDataset(index_test, y_test)
    test_dataset = Dataset.DataLoader(dataset, batch_size=batch, shuffle=True)
    for index_test, y_test in test_dataset:
        y_test = y_test.to(device)
        #print(f"protein_features: {protein_features.shape}, protein_adj: {protein_adj.shape}, drug_features: {drug_features.shape}, drug_adj:{drug_adj.shape}")
        y_pred = model(protein_features, protein_adj, drug_features, drug_adj, index_test.numpy().astype(int), device)
        loss_test = loss_func(y_pred, y_test)
        pred_test.extend(y_pred.cpu().detach().numpy())
        true_test.extend(y_test.cpu().detach().numpy())

    loss_test = mean_squared_error(true_test, pred_test)
    RMSE_test = np.sqrt(loss_test)
    MAE_test = mean_absolute_error(true_test, pred_test)
    PCC_test = pearsonr(true_test, pred_test)[0]
    R2_test = r2_score(true_test, pred_test)
    AUC_test = roc_auc_score(true_test,pred_test)
    AUPR_test = average_precision_score(true_test,pred_test)

    eval_df = pd.DataFrame(columns=['model_date', 'RMSE_test', 'MAE_test', 'PCC_test', 'R2_test', 'AUC_test', 'AUPR_test'])
    eval_df = append_results_to_df(eval_df, model_date, RMSE_test, MAE_test, PCC_test, R2_test, AUC_test, AUPR_test)

    print("Test set results:",
          "\n loss_test: {:.4f}".format(loss_test),
          "RMSE_test: {:.4f}".format(RMSE_test),
          'MAE_test: {:.4f}'.format(MAE_test),
          "PCC_test: {:.4f}".format(PCC_test),
          "R2_test: {:.4f}".format(R2_test),
          "AUC_test: {:.4f}".format(AUC_test),
          "AUPR_test: {:.4f}".format(AUPR_test)
          )
    return eval_df

In [12]:
# Load data
protein_features, protein_adj, drug_features, drug_adj, sample_set = load_data()
protein_features, protein_adj = Variable(protein_features), Variable(protein_adj)
drug_features, drug_adj = Variable(drug_features), Variable(drug_adj)
protein_features, protein_adj = protein_features.to(device), protein_adj.to(device)
drug_features, drug_adj = drug_features.to(device), drug_adj.to(device)

used_memory = torch.cuda.memory_allocated()  # Amount of GPU memory used
cached_memory = torch.cuda.memory_reserved()   # Amount of GPU memory cached
print(f"Data uploaded successfully，server GPU assigned：{used_memory / 1024**3:.2f} GB，Cached：{cached_memory / 1024**3:.2f} GB")



drug protein lable: torch.Size([6552, 3])
Data uploaded successfully，server GPU assigned：0.03 GB，Cached：0.04 GB


In [13]:
# Model and optimizer
model = MAGNETDTI(nprotein=protein_features.shape[0],
                  ndrug=drug_features.shape[0],
                  nproteinfeat=protein_features.shape[1],
                  ndrugfeat=drug_features.shape[1],
                  nhid= hidden,
                  nheads= nb_heads,
                  alpha= alpha)
optimizer = optim.Adam(model.parameters(), lr=lr)

loss_func = nn.MSELoss()
loss_func.to(device)
used_memory = torch.cuda.memory_allocated()  # Amount of GPU memory used
cached_memory = torch.cuda.memory_reserved()   #Amount of GPU memory cached
print(f"loss function uploaded successfully，server GPU assigned：{used_memory / 1024**3:.2f} GB，Cached：{cached_memory / 1024**3:.2f} GB")
best_value = [0, 0, 1]
model_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

loss function uploaded successfully，server GPU assigned：0.03 GB，Cached：0.04 GB


In [14]:
protein_features.shape, protein_adj.shape, drug_features.shape, drug_adj.shape

(torch.Size([1876, 256]),
 torch.Size([1876, 1876]),
 torch.Size([1767, 256]),
 torch.Size([1767, 1767]))

In [15]:
time_begin = time.time()

train_set, test_set = train_test_split(np.arange(sample_set.shape[0]), test_size=0.1,
                                       random_state=np.random.randint(0, 1000))
train_set, valid_set = train_test_split(train_set, test_size=1 / 9, random_state=np.random.randint(0, 1000))

index_train, y_train = sample_set[train_set[:], :2], sample_set[train_set[:], 2]
index_valid, y_valid = sample_set[valid_set[:], :2], sample_set[valid_set[:], 2]
index_test, y_test = sample_set[test_set[:], :2], sample_set[test_set[:], 2]
y_train, y_test, y_valid = Variable(y_train, requires_grad=True), Variable(y_test, requires_grad=True), Variable(
    y_valid, requires_grad=True)

index_train.shape, y_train.shape, index_valid.shape, y_valid.shape

(torch.Size([5240, 2]),
 torch.Size([5240]),
 torch.Size([656, 2]),
 torch.Size([656]))

In [16]:
model.to(device)
auc_valid = [0]
bad_counter = 0
# Initialize W&B project
wandb.init(project="magnet_dti", config={
    "batch_size": batch,
    "learning_rate": optimizer.param_groups[0]['lr'],
    "epochs": epochs,
})

wandb.watch(model, log="all")  # Log gradients and model parameters

for epoch in range(epochs):
    best_epoch, avg_auc_valid = train(epoch, index_train, y_train, index_valid, y_valid)
    auc_valid.append(avg_auc_valid)

    if abs(auc_valid[-1] - auc_valid[-2]) < 0.0005:
        bad_counter += 1
    else:
        bad_counter = 0

    if bad_counter >= patience:
        break

print("Optimization Finished. Total time: {:.4f}s".format(time.time() - time_begin))
print('Loading {}th epoch'.format(best_epoch))
model.load_state_dict(torch.load('./output/models_{}.pkl'.format(model_date)))
# Testing
eval_test = compute_test(index_test, y_test)
time_total = time.time() - time_begin
print("Total time: {:.4f}s".format(time_total))

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


training set size 5240
batch size 82


Training Epoch 1:   0%|          | 0/82 [00:00<?, ?batch/s]

Validation Epoch 1:   0%|          | 0/11 [00:00<?, ?batch/s]

Epoch: 0001 
 loss_train: 0.1808 RMSE_train: 0.4252 MAE_train: 0.3719 PCC_train: 0.5271 R2_train: 0.2758 AUC_train: 0.8008 AUPR_train: 0.8435 
 loss_valid: 0.1717 RMSE_valid: 0.4144 MAE_valid: 0.3644 PCC_valid: 0.5661 R2_valid: 0.2803 AUC_valid: 0.8222 AUPR_valid: 0.7885 time: 28.8703s
training set size 5240
batch size 82


Training Epoch 2:   0%|          | 0/82 [00:00<?, ?batch/s]

Validation Epoch 2:   0%|          | 0/11 [00:00<?, ?batch/s]

Epoch: 0002 
 loss_train: 0.0717 RMSE_train: 0.2678 MAE_train: 0.1996 PCC_train: 0.8604 R2_train: 0.7072 AUC_train: 0.9622 AUPR_train: 0.9700 
 loss_valid: 0.0984 RMSE_valid: 0.3137 MAE_valid: 0.2255 PCC_valid: 0.7685 R2_valid: 0.5877 AUC_valid: 0.9223 AUPR_valid: 0.9191 time: 27.3022s
training set size 5240
batch size 82


Training Epoch 3:   0%|          | 0/82 [00:00<?, ?batch/s]

Validation Epoch 3:   0%|          | 0/11 [00:00<?, ?batch/s]

Epoch: 0003 
 loss_train: 0.0608 RMSE_train: 0.2465 MAE_train: 0.1792 PCC_train: 0.8742 R2_train: 0.7490 AUC_train: 0.9605 AUPR_train: 0.9685 
 loss_valid: 0.0879 RMSE_valid: 0.2964 MAE_valid: 0.2155 PCC_valid: 0.8034 R2_valid: 0.6317 AUC_valid: 0.9433 AUPR_valid: 0.9363 time: 27.2956s
training set size 5240
batch size 82


Training Epoch 4:   0%|          | 0/82 [00:00<?, ?batch/s]

Validation Epoch 4:   0%|          | 0/11 [00:00<?, ?batch/s]

Epoch: 0004 
 loss_train: 0.1028 RMSE_train: 0.3206 MAE_train: 0.2166 PCC_train: 0.7057 R2_train: 0.4962 AUC_train: 0.9047 AUPR_train: 0.8467 
 loss_valid: 0.0796 RMSE_valid: 0.2821 MAE_valid: 0.1799 PCC_valid: 0.8190 R2_valid: 0.6666 AUC_valid: 0.9535 AUPR_valid: 0.9466 time: 27.2631s
training set size 5240
batch size 82


Training Epoch 5:   0%|          | 0/82 [00:00<?, ?batch/s]

Validation Epoch 5:   0%|          | 0/11 [00:00<?, ?batch/s]

Epoch: 0005 
 loss_train: 0.0764 RMSE_train: 0.2764 MAE_train: 0.1891 PCC_train: 0.8259 R2_train: 0.6796 AUC_train: 0.9626 AUPR_train: 0.9544 
 loss_valid: 0.0720 RMSE_valid: 0.2684 MAE_valid: 0.1716 PCC_valid: 0.8355 R2_valid: 0.6980 AUC_valid: 0.9581 AUPR_valid: 0.9514 time: 27.4618s
training set size 5240
batch size 82


Training Epoch 6:   0%|          | 0/82 [00:00<?, ?batch/s]

Validation Epoch 6:   0%|          | 0/11 [00:00<?, ?batch/s]

Epoch: 0006 
 loss_train: 0.0653 RMSE_train: 0.2555 MAE_train: 0.1772 PCC_train: 0.8595 R2_train: 0.7263 AUC_train: 0.9666 AUPR_train: 0.9742 
 loss_valid: 0.0719 RMSE_valid: 0.2682 MAE_valid: 0.1677 PCC_valid: 0.8384 R2_valid: 0.6985 AUC_valid: 0.9584 AUPR_valid: 0.9532 time: 27.3155s
training set size 5240
batch size 82


Training Epoch 7:   0%|          | 0/82 [00:00<?, ?batch/s]

Validation Epoch 7:   0%|          | 0/11 [00:00<?, ?batch/s]

Epoch: 0007 
 loss_train: 0.0423 RMSE_train: 0.2056 MAE_train: 0.1476 PCC_train: 0.9135 R2_train: 0.8197 AUC_train: 0.9891 AUPR_train: 0.9858 
 loss_valid: 0.0683 RMSE_valid: 0.2613 MAE_valid: 0.1677 PCC_valid: 0.8457 R2_valid: 0.7139 AUC_valid: 0.9608 AUPR_valid: 0.9493 time: 27.3065s
training set size 5240
batch size 82


Training Epoch 8:   0%|          | 0/82 [00:00<?, ?batch/s]

Validation Epoch 8:   0%|          | 0/11 [00:00<?, ?batch/s]

Epoch: 0008 
 loss_train: 0.0653 RMSE_train: 0.2555 MAE_train: 0.1563 PCC_train: 0.8634 R2_train: 0.7386 AUC_train: 0.9719 AUPR_train: 0.9750 
 loss_valid: 0.0638 RMSE_valid: 0.2527 MAE_valid: 0.1505 PCC_valid: 0.8559 R2_valid: 0.7324 AUC_valid: 0.9652 AUPR_valid: 0.9563 time: 27.3626s
training set size 5240
batch size 82


Training Epoch 9:   0%|          | 0/82 [00:00<?, ?batch/s]

Validation Epoch 9:   0%|          | 0/11 [00:00<?, ?batch/s]

Epoch: 0009 
 loss_train: 0.0507 RMSE_train: 0.2251 MAE_train: 0.1302 PCC_train: 0.8971 R2_train: 0.7930 AUC_train: 0.9909 AUPR_train: 0.9889 
 loss_valid: 0.0623 RMSE_valid: 0.2496 MAE_valid: 0.1423 PCC_valid: 0.8607 R2_valid: 0.7390 AUC_valid: 0.9655 AUPR_valid: 0.9568 time: 27.2958s
training set size 5240
batch size 82


Training Epoch 10:   0%|          | 0/82 [00:00<?, ?batch/s]

Validation Epoch 10:   0%|          | 0/11 [00:00<?, ?batch/s]

Epoch: 0010 
 loss_train: 0.0391 RMSE_train: 0.1977 MAE_train: 0.1327 PCC_train: 0.9220 R2_train: 0.8404 AUC_train: 0.9948 AUPR_train: 0.9929 
 loss_valid: 0.0603 RMSE_valid: 0.2456 MAE_valid: 0.1425 PCC_valid: 0.8645 R2_valid: 0.7472 AUC_valid: 0.9688 AUPR_valid: 0.9617 time: 27.2902s
Optimization Finished. Total time: 359.3205s
Loading 10th epoch


  model.load_state_dict(torch.load('./output/models_{}.pkl'.format(model_date)))


Test set results: 
 loss_test: 0.0685 RMSE_test: 0.2617 MAE_test: 0.1600 PCC_test: 0.8473 R2_test: 0.7178 AUC_test: 0.9695 AUPR_test: 0.9642
Total time: 361.0499s


  return pd.concat([df, new_row], ignore_index=True)


In [18]:
eval_test.head()

Unnamed: 0,model_date,RMSE_test,MAE_test,PCC_test,R2_test,AUC_test,AUPR_test
0,2024-12-13 07:02:21,0.261734,0.159984,0.847308,0.717755,0.969526,0.964155
