In [4]:
import torch as t
from torch import nn, optim
import numpy as np
import pandas as pd
# NOTE: the star imports below shadow one another, so their relative order is significant.
from model import *
from multiprocess import Process, Queue
import argparse
from tqdm import tqdm
from torch.nn import BCEWithLogitsLoss
from dgl import NID, EID
from dgl.dataloading import GraphDataLoader
from torch.utils.data import Dataset, DataLoader
from utils import parse_arguments
from utils import *#load_ogb_dataset,load_hmdd_dataset, evaluate_hits,standardize_dir
from sampler import SEALData,SEALData_case_study
from model import AE, GCN, DGCNN
from logger import LightLogging
import pdb
import os.path as osp
from eval_irgat import *

# Seed every RNG the pipeline touches: numpy (sampling) and torch (weight
# init, DataLoader shuffling) so repeated runs are reproducible.
SEED = 1337
np.random.seed(SEED)
t.manual_seed(SEED)
# Fall back to CPU instead of failing on machines without a GPU.
device = t.device('cuda:0' if t.cuda.is_available() else 'cpu')

def save2File(input, out_path):
    """Write ``input`` to ``out_path`` as a CSV with no header row and no index column.

    Args:
        input: list of scalars, list of tuples, or a numpy array; each element /
            row becomes one CSV line.
        out_path: destination file path.
    """
    # np.asarray leaves ndarrays as they are and converts lists/tuples; this
    # replaces the fragile check on the *string* form of type(input).
    data = np.asarray(input)
    pd.DataFrame(data).to_csv(out_path, header=False, index=False)
    
def adjMatrix2list(csv_path, dis_id):
    """Turn a header-less miRNA x disease adjacency CSV into edge lists.

    Args:
        csv_path: path to the matrix CSV (rows = miRNAs, columns = diseases).
        dis_id: column index of the disease whose candidate pairs are built.

    Returns:
        pos_assoc: list of (row, col) tuples for every cell equal to 1.
        test_assoc: list of (miRNA_index, dis_id) for every row of the matrix.
        test_lbl: column ``dis_id`` of the matrix as a Python list.
        n_miRNA: number of rows.
        n_disease: number of columns.
    """
    adj = pd.read_csv(csv_path, header=None).values
    n_miRNA, n_disease = adj.shape

    # Known (positive) associations, in row-major order.
    rows, cols = np.nonzero(adj == 1)
    pos_assoc = list(zip(rows, cols))

    # Candidate pairs: every miRNA paired with the queried disease.
    test_lbl = adj[:, dis_id].tolist()
    test_assoc = [(mirna_idx, dis_id) for mirna_idx in range(n_miRNA)]

    return pos_assoc, test_assoc, test_lbl, n_miRNA, n_disease

def gen_train_test(dis_id, args):
    """Write the case-study split for disease ``dis_id`` into ``args.fold_dir``.

    Reads ``<args.data_dir>m-d.csv``. Training edges are every known (==1)
    association in the matrix; the test set pairs every miRNA with ``dis_id``.
    Writes ``train.csv``, ``test.csv`` and ``test_lbl.csv`` (existing files are
    regenerated each call).
    """
    pos_pairs, query_pairs, query_labels, _, _ = adjMatrix2list(args.data_dir + 'm-d.csv', dis_id)
    print('len(train_pos_assoc:', len(pos_pairs), 'len(test_assoc):', len(query_pairs))

    outputs = (
        (pos_pairs, 'train.csv'),
        (query_pairs, 'test.csv'),
        (query_labels, 'test_lbl.csv'),
    )
    for rows, file_name in outputs:
        save2File(rows, args.fold_dir + file_name)
def _case_study_features(args):
    """Build the node-feature dict for the case study.

    Returns ``(features, n_miRNA, n_disease)``; ``features`` is None when
    ``args.sim_type == 'none'`` (sizes then come from the adjacency matrix).
    """
    if args.sim_type == 'none':
        m_d = pd.read_csv(args.data_dir + 'm-d.csv', header=None).values
        return None, m_d.shape[0], m_d.shape[1]

    miRNA_sim, disease_sim = cal_faulty_sim(args)
    features = {
        'd': t.FloatTensor(disease_sim).to(device),
        'm': t.FloatTensor(miRNA_sim).to(device),
    }
    features = train_features(features, args)
    return features, miRNA_sim.shape[0], disease_sim.shape[0]


def _case_study_model(args, features, num_nodes):
    """Instantiate the GCN or DGCNN selected by ``args.method``; None if the name is unknown."""
    common = dict(num_layers=args.num_layers,
                  hidden_units=args.hidden_units,
                  gcn_type=args.gcn_type,
                  node_attributes=features,
                  edge_weights=None,
                  node_embedding=None,
                  use_embedding=False,
                  num_nodes=num_nodes,
                  dropout=args.dropout)
    if 'irgat-gcn' in args.method:
        return GCN(pooling_type=args.pooling, **common)
    if 'irgat-dgcnn' in args.method:
        return DGCNN(k=args.sort_k, **common)
    return None


def case_study(dis_id,args):
    """Train on all known associations and score every miRNA against ``dis_id``.

    Steps: write the split files, build the SEAL subgraph datasets, train the
    model chosen by ``args.method`` for ``args.epochs`` epochs, evaluate on the
    test pairs and (if ``args.save_score``) save the scores to
    ``<result_dir><randseed>/case_<dis_id>_<method>_<sim_type>.csv``.
    Prints a message and returns early on an unknown ``args.method``.
    """
    gen_train_test(dis_id, args)
    graph, split_edge = load_hmdd_dataset_case_study(args.data_dir, args.fold_dir, 0.)
    features, n_miRNA, n_disease = _case_study_features(args)

    # NOTE(review): `prefix` does not include dis_id; if SEALData_case_study
    # caches processed subgraphs by prefix, runs for different diseases may
    # reuse stale test data — confirm against the sampler.
    seal_data = SEALData_case_study(g=graph, split_edge=split_edge, n_miRNA=n_miRNA, n_disease=n_disease, hop=args.hop,
                                    neg_samples=args.neg_samples, subsample_ratio=args.subsample_ratio, use_coalesce=False,
                                    random_seed=args.randseed, prefix=args.dataset+str(args.randseed),
                                    save_dir=args.save_dir, num_workers=args.num_workers, print_fn=print)
    num_nodes = graph.num_nodes()

    train_data = seal_data('train')
    test_data = seal_data('test')  # every miRNA paired with dis_id
    print("len of train_data: ", len(train_data), ",len of test data: ", len(test_data))

    # Shuffle training batches so each mini-batch mixes positive and negative
    # samples instead of following the dataset's construction order.
    train_loader = GraphDataLoader(train_data, batch_size=args.batch_size, shuffle=True)
    test_loader = GraphDataLoader(test_data, batch_size=args.batch_size)

    model = _case_study_model(args, features, num_nodes)
    if model is None:
        print('Invalid method name, please input the right name!')
        return

    model = model.to(device)
    optimizer = t.optim.Adam(model.parameters(), lr=args.lr)
    scheduler = t.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.1, patience=4, verbose=True)
    loss_fn = BCEWithLogitsLoss()

    model.train()
    for epoch in range(args.epochs):
        total_loss = 0.0
        pbar = tqdm(train_loader, ncols=100)
        for g, labels in pbar:
            g = g.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()

            logits = model(g, g.ndata['z'], g.ndata[NID], g.edata[EID])
            loss = loss_fn(logits, labels)
            loss.backward()
            optimizer.step()
            # Weight by the real batch size: the last batch may be smaller
            # than args.batch_size.
            total_loss += loss.item() * labels.shape[0]
            pbar.set_postfix({'loss:': loss.item()})

        # The plateau scheduler tracks the summed epoch training loss.
        scheduler.step(total_loss)

    test_pred_lbl, test_pair_lbl = evaluate(model=model, dataloader=test_loader, device=device)

    if args.save_score:
        score_save_dir = standardize_dir(args.result_dir + str(args.randseed))
        score_path = score_save_dir + 'case_' + str(dis_id) + '_' + args.method + '_' + args.sim_type + '.csv'
        save_scores(test_pred_lbl.tolist(), test_pair_lbl.tolist(), score_path)

def _rank_top_k(y_pred, y_true, n_candidates, k):
    """Return the ``k`` best-scoring candidates as ``(index, is_known)`` pairs.

    Only the first ``n_candidates`` entries are ranked; ties keep index order
    (stable sort). ``is_known`` is True when the candidate's label is > 0.5.
    ``k`` larger than ``n_candidates`` simply returns every candidate.
    """
    order = sorted(range(n_candidates), key=lambda idx: y_pred[idx], reverse=True)
    return [(idx, y_true[idx] > 0.5) for idx in order[:k]]


def query_topK(args,dis_id=236, k=50):
    """Print and save the top-``k`` miRNAs predicted for disease ``dis_id``.

    Reads the score CSV written by ``case_study`` (column 0 = score, column 1 =
    label, one row per miRNA in ``mirna_name.csv`` order). Writes the top-k names
    to ``<result_dir>case_dis_<dis_id>_predK.csv``. Prints each name, suffixed
    with ``unknown`` when it is not a known association.

    Returns:
        Number of top-k miRNAs that are already known associations.
    """
    score_save_dir = standardize_dir(args.result_dir + str(args.randseed))
    score_path = score_save_dir + 'case_' + str(dis_id) + '_' + args.method + '_' + args.sim_type + '.csv'
    score = pd.read_csv(score_path, header=None)
    y_pred, y_true = score[0].tolist(), score[1].tolist()

    mirna_name = pd.read_csv(osp.join(args.data_dir, 'mirna_name.csv'), header=None)[1].tolist()

    # Known-ness is decided per index. The previous code checked whether the
    # *score value* appeared among positive scores, so ties mis-labelled unknowns.
    top_k = _rank_top_k(y_pred, y_true, len(mirna_name), k)

    pred_K = [mirna_name[idx] for idx, _ in top_k]
    pred_K_path = args.result_dir + 'case_dis_' + str(dis_id) + '_predK.csv'
    save2File(pred_K, pred_K_path)

    print("disease no: ", dis_id)
    matched = 0
    for idx, is_known in top_k:
        if is_known:
            matched += 1
            print(mirna_name[idx])
        else:
            print(mirna_name[idx], 'unknown')
    print("test disease no:{},matched:{}/{}".format(dis_id, matched, len(top_k)))
    return matched

def verify_case(pred_K, bench_in_one_dis):
    '''Check predicted miRNAs for one disease against an external benchmark list.

    pred_K: path to a header-less CSV whose first column holds the top-K
        predicted miRNA names for one disease.
    bench_in_one_dis: path to a header-less CSV whose first column holds the
        dbDEMC/miR2Disease verified miRNAs for the same disease.

    Matching is case-insensitive. A prediction is confirmed when a benchmark
    name equals it or extends it with a '-' suffix (e.g. 'hsa-mir-21' matches
    'hsa-miR-21-5p'). Plain substring matching is avoided because it also
    "confirmed" unrelated names ('hsa-mir-1' inside 'hsa-mir-122').

    Prints "confirmed"/"unknown" per prediction plus a summary line, and
    returns the number of confirmed predictions.
    '''
    pred_rna = [str(rna).lower() for rna in pd.read_csv(pred_K, header=None)[0].tolist()]
    bench_rna = [str(rna).lower() for rna in pd.read_csv(bench_in_one_dis, header=None)[0].tolist()]

    cc = 0
    for rna in pred_rna:
        prefix = rna + '-'
        if any(brna == rna or brna.startswith(prefix) for brna in bench_rna):
            print("confirmed")
            cc += 1
        else:
            print("unknown")
    print("{}/{} miRNA confirmed".format(cc, len(pred_rna)))
    return cc

In [5]:
hmdd = 'HMDD3.2/'


def _str2bool(value):
    """Parse a command-line boolean; a bare str/bool type would treat "False" as True."""
    return str(value).strip().lower() in ('1', 'true', 't', 'yes', 'y')


parser = argparse.ArgumentParser(description='Subgraph neural networks learning for miRNA-disease association prediction')
# --- data / experiment ---
parser.add_argument('--epochs', type=int, default=130, metavar='N', help='number of epochs to train')
parser.add_argument('--data_dir', default='../data/'+hmdd, help='dataset directory')
parser.add_argument('--dataset', default=hmdd, help='dataset name')
parser.add_argument('--fold_dir', default='../data/'+hmdd+'folds/', help='fold directory')
parser.add_argument('--result_dir', default='../data/'+hmdd+'results_subsample_1.0/', help='saved result directory')
parser.add_argument('--method', default='irgat-dgcnn', help='method')
parser.add_argument('--save_score', type=_str2bool, default=True, help='whether to save the predicted score or not')
parser.add_argument('--sim_type', default='functional1', help='the miRNA and disease sim, pass in "functional2" for miRNA functional + disease semantic(with phenotype info added),'
                                                              '"none" for none miRNA and disease additional info used,'
                                                              '"functional1" for miRNA functional and disease semantic only,'
                                                              '"gip" for miRNA and disease GIP kernel similarity,'
                                                              '"seq" for miRNA sequence and disease semantic')
# NOTE: --randseed is the seed passed to the SEAL sampler and used in result
# paths; --random_seed is a separate option.
parser.add_argument('--randseed', type=int, default=112, help='the random seed')

# --- model / training ---
parser.add_argument('--num_layers', type=int, default=3)
parser.add_argument('--hidden_units', type=int, default=32)
parser.add_argument('--gcn_type', type=str, default='gcn')
parser.add_argument('--pooling', type=str, default='sum')
parser.add_argument('--dropout', type=float, default=0.5)
parser.add_argument('--lr', type=float, default=0.0001)
parser.add_argument('--decay', type=float, default=0.0001)
parser.add_argument('--hop', type=int, default=1)
parser.add_argument('--sort_k', type=int, default=20)
parser.add_argument('--neg_samples', type=int, default=1,help='negative samples per positive sample')
parser.add_argument('--subsample_ratio', type=float, default=1.)
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--num_workers', type=int, default=8)
parser.add_argument('--random_seed', type=int, default=123)
parser.add_argument('--save_dir', type=str, default='./processed')
# Parse an empty argv so the notebook uses the defaults above.
args = parser.parse_args(args=[])

   


In [None]:
dis_id = 1  # 1: colon neoplasms
args.sim_type = 'functional2'
case_study(dis_id, args)

query_topK(args, dis_id=dis_id, k=50)

# Check the top-K predictions against the dbDEMC list for this disease.
# (The same check was previously run twice on the same file.)
pred_K = args.result_dir + 'case_dis_' + str(dis_id) + '_predK.csv'
bench_in_one_dis = args.result_dir + 'case_dis_' + str(dis_id) + '_dbdemc.csv'
verify_case(pred_K, bench_in_one_dis)