In [1]:
import pathlib
import pickle
import math
import numpy as np
import scipy.sparse
import scipy.io
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.function as fn
from dgl.nn.pytorch import edge_softmax
import time
import argparse
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score

Using backend: pytorch


In [3]:
# Execute each helper notebook with IPython's %run so that its top-level
# definitions become available in this kernel's namespace.
# NOTE(review): the training cell below calls load_LastFM_data, index_generator,
# parse_minibatch_LastFM, EarlyStopping and MAGNN_lp, none of which are defined in
# this notebook -- presumably they come from these helper notebooks; confirm.
# The listed order is kept as-is in case later notebooks rely on earlier ones.
%run utils.ipynb
%run semantic-attention.ipynb
%run read_dataset.ipynb
%run multi-metapath-fusion.ipynb
%run model.ipynb
%run data2graph.ipynb

In [None]:
# ---- model / optimiser hyper-parameters ----
num_ntype = 3  # node types are 0 .. num_ntype-1 in type_mask
dropout_rate = 0.5
lr = 0.005
weight_decay = 0.001

# ---- metapath definitions handed to MAGNN_lp ----
# One inner list per metapath, three metapaths for each endpoint type.
# NOTE(review): entries are presumably edge-type ids along each metapath, with
# None as a placeholder understood by MAGNN_lp -- confirm against model.ipynb.
etypes_lists = [
    [[0, 1], [0, 2, 3, 1], [4, None, 5]],
    [[1, 0], [2, 3], [2, None, 3]],
]

# Masking flags for parse_minibatch_LastFM; the training code uses no_masks.
# NOTE(review): use_masks has 2 flags per side while etypes_lists has 3 metapaths per side.
use_masks = [[True, True],
             [True, False]]
no_masks = [[False] * 3 for _side in range(2)]

# ---- node counts per type ----
num_mir = 1296
num_disease = 11783
num_gene = 10116

# Node-type sequences of the metapaths, one list per endpoint type.
expected_metapaths = [
    [(0, 1, 0), (0, 1, 2, 1, 0), (0, 2, 2, 0)],
    [(1, 0, 1), (1, 2, 1), (1, 2, 2, 1)],
]

# ---- filesystem locations ----
dataset_path = "dataset/"
log_info_path = 'log_info.txt'
    
def _build_features(feats_type, type_mask, device):
    """Build one input-feature tensor per node type.

    Args:
        feats_type: 0 -> sparse one-hot (identity) features of shape (n_i, n_i);
            1 -> dense all-zero features of shape (n_i, 10), where n_i is the
            number of entries of ``type_mask`` equal to node type i.
        type_mask: per-node node-type ids, compared against 0 .. num_ntype-1.
        device: torch device the feature tensors are moved to.

    Returns:
        (features_list, in_dims): the feature tensors and their input widths,
        both ordered by node type.

    Raises:
        ValueError: if ``feats_type`` is not 0 or 1.
    """
    if feats_type not in (0, 1):
        raise ValueError('feats_type must be 0 or 1, got {!r}'.format(feats_type))
    features_list = []
    in_dims = []
    for node_type in range(num_ntype):
        num_nodes = int((type_mask == node_type).sum())
        if feats_type == 0:
            # Identity matrix in COO form: entry (k, k) = 1 for every node k of this type.
            diag = torch.LongTensor(np.vstack((np.arange(num_nodes), np.arange(num_nodes))))
            values = torch.FloatTensor(np.ones(num_nodes))
            features_list.append(
                torch.sparse_coo_tensor(diag, values, (num_nodes, num_nodes)).to(device))
            in_dims.append(num_nodes)
        else:
            features_list.append(torch.zeros((num_nodes, 10)).to(device))
            in_dims.append(10)
    return features_list, in_dims


def _score_pairs(net, pairs, adjlists, edge_metapath_indices_list, features_list, type_mask,
                 device, batch_size, neighbor_samples):
    """Return the dot-product link score (logit) of every pair in ``pairs``.

    ``pairs`` is indexed in order, ``batch_size`` rows at a time, via a
    non-shuffling index_generator. Each batch is turned into subgraphs by
    parse_minibatch_LastFM and embedded by ``net``. The score of a pair is the
    dot product of its two embeddings.
    NOTE(review): assumes index_generator(shuffle=False) visits every index exactly
    once over num_iterations() batches, so the result has one score per row, in order.

    Gradients are disabled here; the caller is responsible for ``net.eval()``.

    Returns:
        1-D float tensor on ``device`` (empty when ``pairs`` is empty).
    """
    idx_generator = index_generator(batch_size=batch_size, num_data=len(pairs), shuffle=False)
    scores = []
    with torch.no_grad():
        for _ in range(idx_generator.num_iterations()):
            batch = pairs[idx_generator.next()].tolist()
            g_lists, indices_lists, idx_batch_mapped_lists, node_lists = parse_minibatch_LastFM(
                adjlists, edge_metapath_indices_list, batch, device, neighbor_samples, no_masks, num_mir)
            [left_emb, right_emb], _ = net(
                (g_lists, features_list, type_mask, indices_lists, idx_batch_mapped_lists, node_lists))
            left_emb = left_emb.view(-1, 1, left_emb.shape[1])
            right_emb = right_emb.view(-1, right_emb.shape[1], 1)
            scores.append(torch.bmm(left_emb, right_emb).flatten())
    if not scores:
        return torch.zeros(0, device=device)
    return torch.cat(scores)


def run_model_OURS(feats_type, hidden_dim, num_heads, attn_vec_dim,
                   num_epochs, patience, batch_size, neighbor_samples, repeat, save_postfix):
    """Train MAGNN_lp for link prediction on the dataset's mir-disease pairs and report AUC / AP.

    Each of the ``repeat`` runs works as follows:
      1. Train a fresh model, sampling one negative pair for every positive pair.
      2. Stop early on the validation loss.
      3. Reload the best checkpoint and score every test positive and test negative.
    Per-run metrics and the mean/std summary are printed and appended to ``log_info_path``.

    Args:
        feats_type: 0 -> one-hot (sparse identity) node features; 1 -> all-zero
            10-dimensional node features.
        hidden_dim: passed to MAGNN_lp (as both of its hidden/output size arguments).
        num_heads: attention heads, passed to MAGNN_lp.
        attn_vec_dim: attention vector size, passed to MAGNN_lp.
        num_epochs: maximum number of epochs per run.
        patience: passed to EarlyStopping.
        batch_size: positive pairs per training batch (each gets one sampled
            negative), and pairs per evaluation batch.
        neighbor_samples: forwarded to parse_minibatch_LastFM.
        repeat: number of independent train/evaluate runs.
        save_postfix: suffix of the checkpoint file ``checkpoint01/checkpoint_<postfix>.pt``.

    Side effects:
        Writes the checkpoint file and appends to ``log_info_path``. Writes
        ``MEAHNE_prediction_result.npz``, which is overwritten by every run.

    Raises:
        ValueError: if ``feats_type`` is not 0 or 1.
    """
    adjlists_ua, edge_metapath_indices_list_ua, _, type_mask, train_val_test_pos, train_val_test_neg = \
        load_LastFM_data(dataset_path)
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    features_list, in_dims = _build_features(feats_type, type_mask, device)

    train_pos = train_val_test_pos['train_pos_mir_disease']
    val_pos = train_val_test_pos['val_pos_mir_disease']
    test_pos = train_val_test_pos['test_pos_mir_disease']
    train_neg = train_val_test_neg['train_neg_mir_disease']
    val_neg = train_val_test_neg['val_neg_mir_disease']
    test_neg = train_val_test_neg['test_neg_mir_disease']
    # Positives first, then negatives: the same order in which test scores are concatenated below.
    y_true_test = np.array([1] * len(test_pos) + [0] * len(test_neg))

    checkpoint_path = 'checkpoint01/checkpoint_{}.pt'.format(save_postfix)
    # The best model is reloaded from this path after training, so make sure its
    # directory exists before EarlyStopping is handed the path.
    pathlib.Path(checkpoint_path).parent.mkdir(parents=True, exist_ok=True)

    def score(net, pairs):
        # Everything except the model and the pairs is fixed for this call of run_model_OURS.
        return _score_pairs(net, pairs, adjlists_ua, edge_metapath_indices_list_ua, features_list,
                            type_mask, device, batch_size, neighbor_samples)

    auc_list = []
    ap_list = []
    for _ in range(repeat):
        # NOTE(review): [3, 3] and 6 are positional MAGNN_lp arguments. Presumably they are
        # the metapath count per endpoint type (each etypes_lists entry has 3) and the
        # number of edge types -- confirm.
        net = MAGNN_lp(
            [3, 3], 6, etypes_lists, in_dims, hidden_dim, hidden_dim, num_heads, attn_vec_dim, dropout_rate)
        net.to(device)
        optimizer = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=weight_decay)

        early_stopping = EarlyStopping(patience=patience, log_path=log_info_path, verbose=True,
                                       save_path=checkpoint_path)
        dur1 = []
        dur2 = []
        dur3 = []
        train_pos_idx_generator = index_generator(batch_size=batch_size, num_data=len(train_pos))
        for epoch in range(num_epochs):
            t_start = time.time()
            # ---- training ----
            net.train()
            for iteration in range(train_pos_idx_generator.num_iterations()):
                t0 = time.time()

                train_pos_idx_batch = train_pos_idx_generator.next()
                train_pos_idx_batch.sort()
                train_pos_batch = train_pos[train_pos_idx_batch].tolist()
                # One negative per positive, drawn uniformly with replacement.
                train_neg_idx_batch = np.random.choice(len(train_neg), len(train_pos_idx_batch))
                train_neg_idx_batch.sort()
                train_neg_batch = train_neg[train_neg_idx_batch].tolist()

                # Attach labels as an extra column, then shuffle the rows so positives and
                # negatives are interleaved. Each label stays with its pair.
                num_pos = train_pos_idx_batch.shape[0]
                train_batch = np.concatenate([train_pos_batch, train_neg_batch], axis=0)
                y_label = np.zeros((train_batch.shape[0], 1), dtype=int)
                y_label[:num_pos] = 1
                train_data = np.concatenate([train_batch, y_label], axis=1)
                np.random.shuffle(train_data)
                train_batch = train_data[:, :-1]
                y_label = train_data[:, -1]

                train_g_lists, train_indices_lists, train_idx_batch_mapped_lists, node_lists = parse_minibatch_LastFM(
                    adjlists_ua, edge_metapath_indices_list_ua, train_batch, device, neighbor_samples, no_masks, num_mir)
                t1 = time.time()
                dur1.append(t1 - t0)

                [left_emb, right_emb], _ = net(
                    (train_g_lists, features_list, type_mask, train_indices_lists, train_idx_batch_mapped_lists, node_lists))
                # Batched dot product of the two embeddings of each pair -> shape (batch, 1, 1).
                out = torch.bmm(left_emb.view(-1, 1, left_emb.shape[1]),
                                right_emb.view(-1, right_emb.shape[1], 1))
                # +1 for positive pairs, -1 for negatives, so loss = -mean(log sigmoid(sign * score)).
                class_op = torch.LongTensor([1 if l == 1 else -1 for l in y_label]).view(-1, 1, 1).to(device)
                train_loss = -torch.mean(F.logsigmoid(out * class_op))

                t2 = time.time()
                dur2.append(t2 - t1)

                optimizer.zero_grad()
                train_loss.backward()
                optimizer.step()

                t3 = time.time()
                dur3.append(t3 - t2)

                if iteration % 100 == 0:
                    print(
                        'Epoch {:05d} | Iteration {:05d} | Train_Loss {:.4f} | Time1(s) {:.4f} | Time2(s) {:.4f} | Time3(s) {:.4f}'.format(
                            epoch, iteration, train_loss.item(), np.mean(dur1), np.mean(dur2), np.mean(dur3)))
                    with open(log_info_path,"a") as file:
                        file.write('epoch:'+str(epoch)+'iteration:'+str(iteration)+'train_loss:'+str(train_loss.item())+"time1:"+ str(np.mean(dur1))+'time2'
                               +str(np.mean(dur2))+'time3'+str(np.mean(dur3))+"\n")

            # ---- validation ----
            # Positives and negatives are scored independently, so the two sets may differ in size.
            net.eval()
            val_pos_out = score(net, val_pos)
            val_neg_out = score(net, val_neg)
            val_loss = -(F.logsigmoid(val_pos_out).mean() + F.logsigmoid(-val_neg_out).mean())
            val_loss = val_loss.cpu()
            t_end = time.time()
            print('Epoch {:05d} | Val_Loss {:.4f} | Time(s) {:.4f}'.format(
                epoch, val_loss.item(), t_end - t_start))
            with open(log_info_path,"a") as file:
                file.write('epoch:'+str(epoch)+'val_loss:'+str(val_loss.item())+"time:"+str(t_end-t_start)+"\n")

            early_stopping(val_loss, net)
            if early_stopping.early_stop:
                print('Early stopping!')
                with open(log_info_path,"a") as file:
                    file.write('epoch:'+str(epoch)+'early stopping!'+"\n")
                break

        # ---- test: reload the best checkpoint and score every test pair ----
        net.load_state_dict(torch.load(checkpoint_path, map_location=device))
        net.eval()
        test_pos_out = score(net, test_pos)
        test_neg_out = score(net, test_neg)
        y_proba_test = torch.sigmoid(torch.cat([test_pos_out, test_neg_out])).cpu().numpy()
        np.savez('MEAHNE_prediction_result.npz', y_true=y_true_test,
                 y_pred=y_proba_test)
        auc = roc_auc_score(y_true_test, y_proba_test)
        ap = average_precision_score(y_true_test, y_proba_test)
        print('Link Prediction Test')
        print('AUC = {}'.format(auc))
        print('AP = {}'.format(ap))
        auc_list.append(auc)
        ap_list.append(ap)
        with open(log_info_path, "a") as file:
            file.write('Link Prediction Test:\nAUC = {}|AP = {}\n'.format(auc, ap))

    print('----------------------------------------------------------------')
    print('Link Prediction Tests Summary')
    print('AUC_mean = {}, AUC_std = {}'.format(np.mean(auc_list), np.std(auc_list)))
    print('AP_mean = {}, AP_std = {}'.format(np.mean(ap_list), np.std(ap_list)))
    with open(log_info_path, "a") as file:
        file.write('Link Prediction Tests Summary:\nAUC_mean = {}, AUC_std = {}\nAP_mean = {}, AP_std = {}\n'.format(
            np.mean(auc_list), np.std(auc_list), np.mean(ap_list), np.std(ap_list)))

if __name__ == '__main__':
    # Single experiment configuration; keyword arguments make each value's role explicit.
    run_model_OURS(
        feats_type=0,
        hidden_dim=64,
        num_heads=1,
        attn_vec_dim=128,
        num_epochs=100,
        patience=3,
        batch_size=16,
        neighbor_samples=60,
        repeat=3,
        save_postfix='MEAHNE',
    )
    