In [None]:

import warnings
warnings.filterwarnings('ignore')

import os
import os.path as osp
from time import gmtime, strftime
import argparse
import random
from datetime import datetime
import math
import csv
from csv import DictWriter
from tqdm import tqdm
import pandas as pd
import ast
import scipy.sparse as ssp
import numpy as np
import networkx as nx
import urllib
import io
import zipfile
import matplotlib.pyplot as plt

import tensorflow as tf

from sklearn.preprocessing import OneHotEncoder
from sklearn.preprocessing import normalize
from sklearn.model_selection import KFold
from sklearn.datasets import make_classification
from sklearn import metrics
from sklearn.metrics import (precision_recall_curve, average_precision_score,
                             PrecisionRecallDisplay, precision_score, recall_score, accuracy_score, f1_score)


import torch
from torch.nn import (ModuleList, Linear, Conv1d, MaxPool1d, Embedding, ReLU, BCEWithLogitsLoss,
                      Sequential, BatchNorm1d as BN)
import torch.nn.functional as F
from torch.utils.data import DataLoader

from torch_geometric.nn import Node2Vec

import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv, SAGEConv, GAE
from torch_geometric.utils import (negative_sampling, add_self_loops,
                                   train_test_split_edges)
from torch_geometric.datasets import KarateClub, RelLinkPredDataset, LastFM
from torch_geometric.data import Data

from ogb.linkproppred import PygLinkPropPredDataset, Evaluator

from python_utils.logger import Logged

In [None]:
def dataset():
    """Load the link-prediction split for ``args.dataset_name``, building and caching it on first use.

    First call (no ``data/<dataset_name>/graph_info.csv`` yet): the raw graph is split by
    ``data_loader``, the split is serialised to ``graph_info.csv`` and the enclosing subgraph of
    every train/test supervision edge is written by ``subgraph_primitive``.
    Later calls parse the cached ``graph_info.csv`` instead.

    Reads the global ``args`` (``dataset_name``, ``network_type``, ``feature_type``).

    Returns:
        (data_train, data_test, A, A_train, A_test, edge_pos_train, edge_neg_train,
         edge_pos_test, edge_neg_test, num_subgraph_nodes, num_nodes)
    """
    path = "data/%s" % (args.dataset_name)
    info_file = path + "/graph_info.csv"

    if not os.path.exists(info_file):
        # The directory can already exist without graph_info.csv (e.g. an interrupted run),
        # so creating it must not fail in that case.
        os.makedirs(path, exist_ok=True)

        # subgraph_primitive appends to these files; leftovers from an earlier run belong to a
        # different random split and must not be mixed with the new one.
        for split_name in ('train', 'test'):
            stale_file = path + '/subgraphs_%s_info.csv' % (split_name)
            if os.path.exists(stale_file):
                os.remove(stale_file)

        A, A_train, A_test, data_x, n_v, n_e, train_message_edges, test_message_edges, \
            edge_pos_train, edge_neg_train, edge_pos_test, edge_neg_test = \
            data_loader(args.dataset_name, args.network_type, args.feature_type)

        data_train = Data(x=data_x, edge_index=train_message_edges, num_nodes=n_v)
        data_test = Data(x=data_x, edge_index=test_message_edges, num_nodes=n_v)

        # Subgraph size heuristic (PLACN): 2|E|/|V| times (1 + 2|E| / (|V|(|V|-1))),
        # i.e. average degree scaled by (1 + density) for an undirected graph.
        num_subgraph_nodes = math.ceil((2*n_e/n_v)*(1+((2*n_e)/((n_v)*(n_v-1)))))
        num_nodes = num_subgraph_nodes + math.ceil(0.5*num_subgraph_nodes) + 2

        # 'edge_index' is kept in the header for file-format compatibility; no value is
        # written for it, so DictWriter leaves that column empty.
        field_names = ['adj_matrix', 'adj_matrix_train', 'adj_matrix_test', 'data_x', 'num_vertices', 'num_edges',
                       'edge_index', 'train_message_edges', 'test_message_edges', 'edge_pos_train', 'edge_neg_train',
                       'edge_pos_test', 'edge_neg_test', 'num_subgraph_nodes']

        row = {'adj_matrix': A.toarray().tolist(),
               'adj_matrix_train': A_train.toarray().tolist(),
               'adj_matrix_test': A_test.toarray().tolist(),
               'data_x': data_x.tolist(),
               'num_vertices': n_v,
               'num_edges': n_e,
               'train_message_edges': train_message_edges.tolist(),
               'test_message_edges': test_message_edges.tolist(),
               'edge_pos_train': edge_pos_train.tolist(),
               'edge_neg_train': edge_neg_train.tolist(),
               'edge_pos_test': edge_pos_test.tolist(),
               'edge_neg_test': edge_neg_test.tolist(),
               'num_subgraph_nodes': num_subgraph_nodes}

        # newline='' is what the csv module requires for files it writes.
        with open(info_file, mode='w', newline='') as f_object:
            dictwriter_object = DictWriter(f_object, fieldnames=field_names)
            dictwriter_object.writeheader()
            dictwriter_object.writerow(row)

        # Candidate links = positive supervision edges followed by negative ones.
        train_edges_x = np.concatenate([edge_pos_train.T[0], edge_neg_train.T[0]])
        train_edges_y = np.concatenate([edge_pos_train.T[1], edge_neg_train.T[1]])
        subgraph_primitive(A_train, train_edges_x, train_edges_y, num_nodes=num_nodes, type_data='train')

        test_edges_x = np.concatenate([edge_pos_test.T[0], edge_neg_test.T[0]])
        test_edges_y = np.concatenate([edge_pos_test.T[1], edge_neg_test.T[1]])
        subgraph_primitive(A_test, test_edges_x, test_edges_y, num_nodes=num_nodes, type_data='test')

    else:
        df = pd.read_csv(info_file)

        def _cell(column):
            # graph_info.csv holds a single data row; list-valued cells are Python literals.
            return df[column].dropna().values[0]

        A = ssp.csr_matrix(ast.literal_eval(_cell('adj_matrix')))
        A_train = ssp.csr_matrix(ast.literal_eval(_cell('adj_matrix_train')))
        A_test = ssp.csr_matrix(ast.literal_eval(_cell('adj_matrix_test')))
        data_x = torch.tensor(ast.literal_eval(_cell('data_x')))
        train_message_edges = torch.tensor(ast.literal_eval(_cell('train_message_edges')))
        test_message_edges = torch.tensor(ast.literal_eval(_cell('test_message_edges')))
        edge_pos_train = torch.tensor(ast.literal_eval(_cell('edge_pos_train')))
        edge_neg_train = torch.tensor(ast.literal_eval(_cell('edge_neg_train')))
        edge_pos_test = torch.tensor(ast.literal_eval(_cell('edge_pos_test')))
        edge_neg_test = torch.tensor(ast.literal_eval(_cell('edge_neg_test')))
        n_v = _cell('num_vertices')
        n_e = _cell('num_edges')
        num_subgraph_nodes = _cell('num_subgraph_nodes')
        num_nodes = num_subgraph_nodes + math.ceil(0.5*num_subgraph_nodes) + 2

        data_train = Data(x=data_x, edge_index=train_message_edges, num_nodes=n_v)
        data_test = Data(x=data_x, edge_index=test_message_edges, num_nodes=n_v)

    print("number of nodes:", n_v)
    print("number of edges:", n_e)
    print("number of message edges in train:", len(train_message_edges))
    print("number of message edges in test:", len(test_message_edges))
    print("number of supervision edges in train:", len(edge_pos_train))
    print("number of supervision edges in test:", len(edge_pos_test))

    print("number of subgraph nodes :", num_subgraph_nodes)

    return(data_train, data_test, A, A_train, A_test, edge_pos_train,
           edge_neg_train, edge_pos_test, edge_neg_test, num_subgraph_nodes, num_nodes)
            
            
            
            

In [None]:
def node2vec_model(graph, embedding_dim=256, walk_length=40, context_size=10,
                   walks_per_node=100, epochs=100, batch_size=64, lr=0.01):
    """Train PyG ``Node2Vec`` on a NetworkX graph and return the learned node embeddings.

    Args:
        graph: ``nx.Graph`` or ``nx.DiGraph`` whose nodes are integer ids.
        embedding_dim, walk_length, context_size, walks_per_node: Node2Vec hyper-parameters.
        epochs: number of passes over the start nodes.
        batch_size: number of start nodes per optimisation step.
        lr: Adam learning rate.
        (All defaults equal the values previously hard-coded here.)

    Returns:
        torch.Tensor holding the embedding matrix, with one row per node id known to Node2Vec.
        ``num_nodes`` is not passed, so PyG infers the node count from ``edge_index``.
    """
    edge_index = torch.tensor(list(graph.edges), dtype=torch.long).t().contiguous()
    if not graph.is_directed():
        # nx.Graph lists each undirected edge once, but Node2Vec follows edge_index as
        # directed edges; add the reverse direction so walks can traverse both ways.
        edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)

    node2vec = Node2Vec(edge_index, embedding_dim=embedding_dim, walk_length=walk_length,
                        context_size=context_size, walks_per_node=walks_per_node)
    loader = DataLoader(range(graph.number_of_nodes()), batch_size=batch_size, shuffle=True)
    optimizer = torch.optim.Adam(node2vec.parameters(), lr=lr)

    node2vec.train()
    for epoch in range(epochs):
        total_loss = 0
        for batch in loader:
            optimizer.zero_grad()
            pos_rw = node2vec.pos_sample(batch)
            neg_rw = node2vec.neg_sample(batch)
            loss = node2vec.loss(pos_rw, neg_rw)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}")

    node2vec.eval()
    # Detached CPU copy, independent of the model's parameter storage.
    data_x = node2vec.embedding.weight.detach().cpu().clone()
    print(data_x)
    return(data_x)

In [None]:
def data_loader(data_name, network_type, feature_type, negative_injection = True):
    """Read ``dataset/<data_name>.txt`` and build the message/supervision split for link prediction.

    The edge list is shuffled: the first 54% of edges become positive supervision edges and the
    remaining 46% become message-passing edges. An equal number of negative supervision edges
    is drawn from the node pairs without an edge. The last ``int(n_pos * args.test_ratio)``
    supervision edges go to the test split. Test message edges = train message edges + train
    positive supervision edges.

    Args:
        data_name: file stem under ``dataset/``; the first two columns are node ids (1-based ids
            are shifted to 0-based).
        network_type: 0 for undirected (symmetric adjacency), anything else for directed.
        feature_type: "node2vec" or "onehot".
        negative_injection: accepted for compatibility; not used.

    Reads the global ``args.test_ratio``.

    Returns:
        (A, A_train, A_test, data_x, nodes_size, edge_size, train_message_edges,
         test_message_edges, train_pos, train_neg, test_pos, test_neg) — adjacency matrices
        (CSR, with self-loops) and torch tensors of (E, 2) node-id pairs.

    Raises:
        ValueError: if ``feature_type`` is not supported.
    """
    print("load data...")
    file_path = "dataset/" + data_name + ".txt"

    # sample positive; ndmin=2 keeps a single-edge file as a (1, 2) array
    positive_all = np.loadtxt(file_path, dtype=int, usecols=(0, 1), ndmin=2)
    if np.min(positive_all) == 1:
        positive_all -= 1
    np.random.shuffle(positive_all)
    n = int(len(positive_all)*0.54)
    supervision_edges_pos = np.asarray(positive_all[:n])
    message_edges = np.asarray(positive_all[n:])

    G = nx.Graph() if network_type == 0 else nx.DiGraph()
    G.add_edges_from(positive_all)
    nodes_size = len(G.nodes())  # nodes size in the network
    edge_size = len(G.edges())

    # sample negative. G already uses 0-based ids, so the non-edges need no re-indexing.
    negative_all = list(nx.non_edges(G))
    np.random.shuffle(negative_all)
    supervision_edges_neg = np.asarray(negative_all[:len(supervision_edges_pos)])

    test_size = int(len(supervision_edges_pos) * args.test_ratio)

    def _split(edges):
        # Explicit cut index: slicing with [:-test_size] would put *everything* into the
        # test split when test_size == 0.
        cut = max(len(edges) - test_size, 0)
        return edges[:cut], edges[cut:]

    train_pos, test_pos = _split(supervision_edges_pos)
    train_neg, test_neg = _split(supervision_edges_neg)

    train_message_edges = message_edges
    test_message_edges = np.concatenate([message_edges, train_pos])

    def _adjacency(edges):
        # Dense 0/1 matrix (symmetric when undirected) plus self-loops, returned as CSR.
        adj = np.zeros([nodes_size, nodes_size], dtype=np.uint8)
        adj[edges[:, 0], edges[:, 1]] = 1
        if network_type == 0:
            adj[edges[:, 1], edges[:, 0]] = 1
        np.fill_diagonal(adj, 1)
        return ssp.csr_matrix(adj)

    A = _adjacency(positive_all)
    A_train = _adjacency(train_message_edges)
    A_test = _adjacency(test_message_edges)

    # nodes feature
    if feature_type == "node2vec":
        data_x = node2vec_model(G)
    elif feature_type == "onehot":
        data_x = torch.diag(torch.ones(nodes_size))  # ONE_HOT
    else:
        raise ValueError("unsupported feature_type: %r (expected 'node2vec' or 'onehot')" % (feature_type,))

    return(A, A_train, A_test, data_x, nodes_size, edge_size,
           torch.from_numpy(train_message_edges), torch.from_numpy(test_message_edges),
           torch.from_numpy(train_pos), torch.from_numpy(train_neg),
           torch.from_numpy(test_pos), torch.from_numpy(test_neg))



In [None]:

class Logger(object):
    """Collects (train, valid, test) scores per epoch for several independent runs.

    Scores are stored as given and multiplied by 100 when printed.
    """

    def __init__(self, runs, info=None):
        self.info = info
        # results[run] is the list of (train, valid, test) triples recorded for that run.
        self.results = [list() for _ in range(runs)]

    def add_result(self, run, result):
        """Append one (train, valid, test) triple to run ``run`` (0-based)."""
        assert len(result) == 3
        assert 0 <= run < len(self.results)
        self.results[run].append(result)

    def print_statistics(self, run=None):
        """Print statistics for one run, or mean ± std across all runs when ``run`` is None.

        "Final" numbers are taken at the epoch with the highest validation score.
        """
        if run is None:
            per_run = []
            for scores in 100 * torch.tensor(self.results):
                best_epoch = scores[:, 1].argmax()
                per_run.append((scores[:, 0].max().item(),
                                scores[:, 1].max().item(),
                                scores[best_epoch, 0].item(),
                                scores[best_epoch, 2].item()))
            summary = torch.tensor(per_run)

            print('All runs:')
            labels = ('Highest Train', 'Highest Valid', '  Final Train', '   Final Test')
            for column, label in enumerate(labels):
                values = summary[:, column]
                print(f'{label}: {values.mean():.2f} ± {values.std():.2f}')
            return

        scores = 100 * torch.tensor(self.results[run])
        best_epoch = scores[:, 1].argmax().item()
        print(f'Run {run + 1:02d}:')
        print(f'Highest Train: {scores[:, 0].max():.2f}')
        print(f'Highest Valid: {scores[:, 1].max():.2f}')
        print(f'  Final Train: {scores[best_epoch, 0]:.2f}')
        print(f'   Final Test: {scores[best_epoch, 2]:.2f}')

In [None]:
def accuracy(pred, label):
    """Fraction of predictions equal to ``label`` after thresholding ``pred`` at 0.5.

    Args:
        pred: tensor of scores; values >= 0.5 count as positive.
        label: tensor of 0/1 labels with the same length as ``pred``.

    Returns:
        Accuracy as a Python float rounded to 4 decimals.
    """
    hits = (pred >= 0.5) == label
    return round((hits.sum() / label.shape[0]).item(), 4)

In [None]:
def recall(pred, label):
    """Recall of ``pred`` thresholded at 0.5 against binary ``label``, rounded to 4 decimals.

    Args:
        pred: scores; values >= 0.5 count as positive (presumably probabilities — TODO confirm
            that callers apply a sigmoid to the logits first).
        label: binary ground-truth labels (tensor, array or list).
    """
    # sklearn converts its inputs with np.asarray, which fails for CUDA tensors (test() builds
    # its labels on the data device), so move both inputs to host memory first.
    y_pred = (torch.as_tensor(pred).detach() >= 0.5).cpu().numpy()
    y_true = torch.as_tensor(label).detach().cpu().numpy()
    return round(float(recall_score(y_true, y_pred)), 4)

In [None]:
def precision(pred, label):
    """Precision of ``pred`` thresholded at 0.5 against binary ``label``, rounded to 4 decimals.

    Args:
        pred: scores; values >= 0.5 count as positive (presumably probabilities — TODO confirm).
        label: binary ground-truth labels (tensor, array or list).
    """
    # sklearn calls np.asarray on its inputs, which fails for CUDA tensors; use host copies.
    y_pred = (torch.as_tensor(pred).detach() >= 0.5).cpu().numpy()
    y_true = torch.as_tensor(label).detach().cpu().numpy()
    return round(float(precision_score(y_true, y_pred)), 4)

In [None]:
def f1(pred, label):
    """F1 score of ``pred`` thresholded at 0.5 against binary ``label``, rounded to 4 decimals.

    Args:
        pred: scores; values >= 0.5 count as positive (presumably probabilities — TODO confirm).
        label: binary ground-truth labels (tensor, array or list).
    """
    # sklearn calls np.asarray on its inputs, which fails for CUDA tensors; use host copies.
    y_pred = (torch.as_tensor(pred).detach() >= 0.5).cpu().numpy()
    y_true = torch.as_tensor(label).detach().cpu().numpy()
    return round(float(f1_score(y_true, y_pred)), 4)

In [None]:
def precision_recall(pred, label):
    """Return the (precision, recall) arrays of the precision-recall curve of ``pred`` vs ``label``.

    ``pred`` is used as a continuous score (no thresholding); ``label`` holds binary labels.
    """
    # sklearn calls np.asarray on its inputs, which fails for CUDA tensors; use host copies.
    scores = torch.as_tensor(pred).detach().cpu().numpy()
    y_true = torch.as_tensor(label).detach().cpu().numpy()
    precision, recall, _ = precision_recall_curve(y_true, scores)
    return (precision, recall)


In [None]:
def precision_recall_AUC(pred, label):
    """Area under the precision-recall curve (trapezoidal rule), rounded to 4 decimals.

    ``pred`` is used as a continuous score; ``label`` holds binary labels.
    """
    # sklearn calls np.asarray on its inputs, which fails for CUDA tensors; use host copies.
    scores = torch.as_tensor(pred).detach().cpu().numpy()
    y_true = torch.as_tensor(label).detach().cpu().numpy()
    precision, recall, _ = precision_recall_curve(y_true, scores)
    pr_auc = metrics.auc(recall, precision)
    return round(float(pr_auc), 4)

In [None]:
def average_precision(pred, label):
    """sklearn average precision of score ``pred`` against binary ``label``, rounded to 4 decimals."""
    # sklearn calls np.asarray on its inputs, which fails for CUDA tensors; use host copies.
    scores = torch.as_tensor(pred).detach().cpu().numpy()
    y_true = torch.as_tensor(label).detach().cpu().numpy()
    return round(float(average_precision_score(y_true, scores)), 4)

In [None]:

def neighbors(fringe, A, outgoing=True):
    """Return the set of neighbours of the nodes in ``fringe`` in adjacency matrix ``A``.

    Args:
        fringe: iterable of node ids.
        A: scipy sparse (CSR or CSC) or dense adjacency matrix.
        outgoing: True -> successors (column ids of non-zeros in rows ``fringe``);
            False -> predecessors (row ids of non-zeros in columns ``fringe``).

    Uses ``nonzero()`` rather than ``.indices``: ``.indices`` stores column ids for CSR but
    row ids for CSC, so it only gave the right answer for one of the two directions per format.
    Explicitly stored zeros are not treated as edges.
    """
    ids = list(fringe)
    if not ids:
        return set()
    if outgoing:
        return set(A[ids, :].nonzero()[1])
    return set(A[:, ids].nonzero()[0])

In [None]:
def k_hop_subgraph(num_hops, A, src, dst=None, num_nodes=0, sample_ratio=1.0,
                   max_nodes_per_hop=None, node_features=None,
                   y=1, directed=False, A_csc=None):
    """Extract the (up to) ``num_hops``-hop enclosing subgraph around ``src`` (and ``dst``).

    Nodes are collected breadth-first, hop by hop, from the root(s): ``nodes[0]`` is ``src``
    and, when ``dst`` is given, ``nodes[1]`` is ``dst``.

    Args:
        num_hops: maximum number of BFS hops.
        A: scipy sparse adjacency matrix (indexed as ``A[nodes, :][:, nodes]``).
        src, dst: root node ids (anything ``int()`` accepts, e.g. NumPy/PyTorch scalars).
            ``dst=None`` extracts a single-root subgraph.
        num_nodes: if non-zero, cap on the number of returned nodes; the hop that would
            exceed it is truncated and expansion stops.
        sample_ratio: when < 1.0, fraction of each new hop kept by random sampling.
        max_nodes_per_hop: optional cap on the nodes added per hop (random sample).
        node_features, y: accepted for call compatibility; not used.
        directed: if True, expand along out-edges of ``A`` and in-edges of ``A_csc``.
        A_csc: matrix used for the in-neighbour lookup when ``directed`` is True.

    Returns:
        (nodes, subgraph): list of node ids (Python ints) and the induced adjacency submatrix.
        When ``dst`` is given, the target link between positions 0 and 1 is zeroed.
    """
    roots = [int(src)] if dst is None else [int(src), int(dst)]
    nodes = list(roots)
    visited = set(roots)
    fringe = set(roots)

    for _ in range(num_hops):
        if directed:
            fringe = neighbors(fringe, A).union(neighbors(fringe, A_csc, False))
        else:
            fringe = neighbors(fringe, A)
        fringe = fringe - visited
        visited = visited.union(fringe)
        # random.sample needs a sequence; passing a set raises TypeError since Python 3.11.
        if sample_ratio < 1.0:
            fringe = random.sample(list(fringe), int(sample_ratio*len(fringe)))
        if max_nodes_per_hop is not None and max_nodes_per_hop < len(fringe):
            fringe = random.sample(list(fringe), max_nodes_per_hop)
        if len(fringe) == 0:
            break

        new_nodes = [int(v) for v in fringe]
        if num_nodes != 0:
            overflow = len(nodes) + len(new_nodes) - num_nodes
            if overflow > 0:
                # Keep only as many nodes of this hop as still fit, then stop expanding.
                nodes = nodes + new_nodes[:-overflow]
                break
            if len(nodes) == num_nodes:
                break
        nodes = nodes + new_nodes

    subgraph = A[nodes, :][:, nodes]

    if dst is not None:
        # Remove the target link (src, dst) from the extracted subgraph. In single-root mode
        # position 1 is an ordinary neighbour, so nothing is removed there.
        subgraph[0, 1] = 0
        subgraph[1, 0] = 0

    return nodes, subgraph


In [None]:
def subgraph_primitive(A, nodes1, nodes2, num_nodes, type_data):
    """Extract the enclosing subgraph of every link (nodes1[i], nodes2[i]) and append it to
    ``data/<args.dataset_name>/subgraphs_<type_data>_info.csv``.

    One CSV row is written per link with columns ``nodes1``, ``nodes2``, ``subgraph_nodes``
    (list of node ids as a Python literal) and ``subgraph_A`` (str() of the sparse submatrix).
    The header is written only when the file does not exist yet; rows are always appended.

    Args:
        A: adjacency matrix the subgraphs are taken from.
        nodes1, nodes2: equally long sequences of link endpoints.
        num_nodes: node cap forwarded to ``k_hop_subgraph``.
        type_data: split name used in the output file name (e.g. 'train' / 'test').

    Reads the global ``args`` (``max_hop``, ``dataset_name``).
    """
    max_hop = args.max_hop

    path = "data/%s" % (args.dataset_name)
    os.makedirs(path, exist_ok=True)
    out_file = path + '/subgraphs_%s_info.csv' % (type_data)

    field_names = ['nodes1', 'nodes2', 'subgraph_nodes', 'subgraph_A']

    if len(nodes1) == 0:
        return  # nothing to write; do not create an empty file

    write_header = not os.path.exists(out_file)
    # Open once for the whole batch instead of reopening the file for every row.
    with open(out_file, mode='a', newline='') as f_object:
        dictwriter_object = DictWriter(f_object, fieldnames=field_names)
        if write_header:
            dictwriter_object.writeheader()
        for index in range(len(nodes1)):
            subgraph_nodes, subgraph_A = k_hop_subgraph(max_hop, A, nodes1[index], nodes2[index],
                                                        num_nodes=num_nodes)
            # Store plain ints: with NumPy >= 2 a list holding np.int32 values prints as
            # "[..., np.int32(7)]", which ast.literal_eval (used when reading back) rejects.
            node_ids = [int(v) for v in subgraph_nodes]
            dictwriter_object.writerow({'nodes1': node_ids[0], 'nodes2': node_ids[1],
                                        'subgraph_nodes': node_ids, 'subgraph_A': subgraph_A})


            
            
           

In [None]:
def subgraph2vec(A, dff, embed, nodes1, nodes2, num_nodes):
    """Build a subgraph representation for every candidate link (nodes1[i], nodes2[i]).

    For each link, the stored enclosing-subgraph node list is looked up in ``dff``. Every
    subgraph node is scored by its distance to the mean embedding of the two endpoints (the
    endpoints themselves get 0). The distance vector is L2-normalised and the nodes are sorted
    by ascending distance.

    Args:
        A: unused; kept for the callers' signature.
        dff: DataFrame with columns 'nodes1', 'nodes2' and 'subgraph_nodes' (a stringified list),
            as written by subgraph_primitive.
        embed: node-embedding matrix indexed by node id (e.g. the GNN output).
        nodes1, nodes2: 1-D tensors holding the endpoints of each link.
        num_nodes: subgraph budget; at most ``num_nodes + 2`` nodes are kept per link.

    Reads the global ``args`` (``dist_type``, ``subgraph_type``, ``subgraph_feature_type``).

    Returns:
        'NDP': (len(nodes1), dim) tensor — weighted sum of subgraph node embeddings with
               weights log((1 - d) / d), normalised to sum to 1.
        'cnn': (len(nodes1), num_nodes + 2, dim) tensor — embeddings sorted by distance and
               zero-padded to num_nodes + 2 rows.
        any other subgraph_feature_type: an empty tensor.
    """
    z_embed = torch.tensor(()).to(embed.device)

    if args.dist_type == 'cos':
        cos = torch.nn.CosineSimilarity(dim=0)

    for index in range(len(nodes1)):
        dis = []

        node1 = nodes1[index].item()
        node2 = nodes2[index].item()

        # Stored enclosing-subgraph node list of this link (first matching row is used).
        subgraph_nodes = (dff[(dff['nodes1'] == node1)&(dff['nodes2'] == node2)]['subgraph_nodes'])
        if args.subgraph_type == 'DIS':
            subgraph_nodes = (ast.literal_eval(subgraph_nodes.dropna().values[0]))
        elif args.subgraph_type == 'hhop':
            subgraph_nodes = (ast.literal_eval(subgraph_nodes.dropna().values[0]))[:num_nodes+2]

        # Midpoint of the two endpoint embeddings.
        embed_node_index = ((embed[node1]+embed[node2])/2)

        d_list = []
        for i in range(len(subgraph_nodes)):
            if subgraph_nodes[i] != nodes1[index] and subgraph_nodes[i] != nodes2[index]:
                if args.dist_type == 'norm':
                    d_list.append((torch.norm((embed[subgraph_nodes[i]])-(embed_node_index))).item())
                elif args.dist_type == 'cos':
                    d_list.append(1-(cos(embed[subgraph_nodes[i]], embed_node_index).item()))
            else:
                d_list.append(0)

        # L2-normalise the distance vector (sklearn.preprocessing.normalize).
        d_list = torch.tensor((normalize([d_list])[0]).tolist())

        for i in range(len(subgraph_nodes)):
            dis.append([subgraph_nodes[i], d_list[i]])

        # node id -> distance, sorted so that the closest nodes come first.
        df = pd.DataFrame.from_dict(dict(dis), orient='index', columns=['distance'])
        df = df.sort_values('distance', ascending=True)

        subgraph = df.dropna().index.values[:(num_nodes+2)]
        distances = df.dropna().values[:(num_nodes+2)]

        subgraph_embed = embed[subgraph].to(embed.device)

        if args.subgraph_feature_type == 'NDP':
            w_list = torch.zeros(len(subgraph))

            for i in range(len(subgraph)):
                d = distances[i,0]
                # Clamp d = 0 / d = 1 to avoid log of 0 or division by zero.
                if d==0:
                    w_list[i] = math.log((1-0.0001)/0.0001)
                elif(d==1):
                    w_list[i] = math.log((1-0.9999)/0.9999)
                else:
                    w_list[i] = math.log((1-d)/d) #adaboost
            w_list = w_list / (w_list.sum())

            for i in range(len(subgraph)):
                subgraph_embed[i] = subgraph_embed[i]*w_list[i]

            s = (subgraph_embed.sum(dim=0))
            s = s.reshape([1,len(s)])
            z_embed = torch.cat((z_embed,s), 0)

        elif args.subgraph_feature_type == 'cnn':
            subgraph_embed = subgraph_embed.unsqueeze(0)

            # Zero-pad to num_nodes + 2 rows with torch.cat so the result stays attached to
            # the autograd graph; the former tolist()/torch.tensor round trip ran for every
            # link and silently detached the subgraph embeddings from `embed`.
            pad_rows = int((num_nodes+2) - len(subgraph))
            if pad_rows > 0:
                padding = torch.zeros(1, pad_rows, subgraph_embed.size(-1),
                                      dtype=subgraph_embed.dtype, device=embed.device)
                subgraph_embed = torch.cat((subgraph_embed, padding), 1)

            z_embed = torch.cat((z_embed,subgraph_embed),0)

    return(z_embed)
        

In [None]:
class subg2vec_model(torch.nn.Module):
    """Small CNN that maps a stack of subgraph node embeddings to a fixed-size vector.

    ``forward`` takes x of shape (batch, input_num + 2, feature_dim), adds a channel axis and
    applies conv(3 x feature_dim) -> pool -> conv(3 x 2) -> pool -> two linear layers.
    Output shape: (batch, out_channels); the last axis is squeezed away when out_channels == 1.
    ``dropout`` is stored but not applied in ``forward``.
    """
    def __init__(self, input_num, feature_dim, hidden_channels, out_channels, dropout):

        super(subg2vec_model, self).__init__()

        self.conv1 = torch.nn.Conv2d(1, 8, (3, feature_dim), padding=1)
        self.pool = torch.nn.MaxPool2d(2, 2)
        self.conv2 = torch.nn.Conv2d(8, 16, (3, 2), padding=1)

        # Derive lin1's fan-in from a dry run of the conv stack. The former closed-form
        # int(0.5*x*y) only equalled the real flattened size (16 * floor((input_num+2)/4))
        # when ceil(ceil(feature_dim/2)/2) == 32, so other feature sizes crashed in forward.
        with torch.no_grad():
            probe = torch.zeros(1, 1, input_num + 2, feature_dim)
            flat_dim = self._conv_features(probe).shape[1]

        self.lin1 = torch.nn.Linear(flat_dim, hidden_channels)
        self.lin2 = torch.nn.Linear(hidden_channels, out_channels)

        self.dropout = dropout

    def _conv_features(self, x):
        # (batch, 1, H, W) -> flattened conv/pool features (batch, F).
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        return torch.flatten(x, 1)  # flatten all dimensions except batch

    def reset_parameters(self):
        self.conv1.reset_parameters()
        self.conv2.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, x):
        x = self._conv_features(x.unsqueeze(1))
        x = F.relu(self.lin1(x))
        x = self.lin2(x)
        return (x.squeeze(1))


In [None]:
class GCN(torch.nn.Module):
    """Stack of GCNConv layers with ReLU + dropout between them (none after the last layer).

    Layer widths: in -> hidden -> (num_layers - 2 more hidden) -> out, so at least two
    convolutions are built even when ``num_layers`` < 2.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout):
        super(GCN, self).__init__()
        widths = ([in_channels, hidden_channels]
                  + [hidden_channels] * (num_layers - 2)
                  + [out_channels])
        self.convs = torch.nn.ModuleList(
            GCNConv(width_in, width_out) for width_in, width_out in zip(widths[:-1], widths[1:]))
        self.dropout = dropout

    def reset_parameters(self):
        for layer in self.convs:
            layer.reset_parameters()

    def forward(self, x, adj):
        *hidden_layers, output_layer = self.convs
        for layer in hidden_layers:
            x = F.dropout(F.relu(layer(x, adj)), p=self.dropout, training=self.training)
        return output_layer(x, adj)
        

In [None]:

class SAGE(torch.nn.Module):
    """Stack of SAGEConv layers with ReLU + dropout between them (none after the last layer).

    Layer widths: in -> hidden -> (num_layers - 2 more hidden) -> out, so at least two
    convolutions are built even when ``num_layers`` < 2.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout):
        super(SAGE, self).__init__()
        widths = ([in_channels, hidden_channels]
                  + [hidden_channels] * (num_layers - 2)
                  + [out_channels])
        self.convs = torch.nn.ModuleList(
            SAGEConv(width_in, width_out) for width_in, width_out in zip(widths[:-1], widths[1:]))
        self.dropout = dropout

    def reset_parameters(self):
        for layer in self.convs:
            layer.reset_parameters()

    def forward(self, x, adj_t):
        *hidden_layers, output_layer = self.convs
        for layer in hidden_layers:
            x = F.dropout(F.relu(layer(x, adj_t)), p=self.dropout, training=self.training)
        return output_layer(x, adj_t)



In [None]:

class GAutoEncoder(torch.nn.Module):
    """GCNConv encoder stack with ReLU + dropout between layers (none after the last layer).

    Only the encoder is defined here. In autoencoder mode train() calls ``model.encode`` and
    ``model.recon_loss``, which presumably come from wrapping this module in PyG's ``GAE``
    (imported at the top of the notebook) — TODO confirm where the model is constructed.
    """

    def __init__(self, GNN_in_channels, GNN_hidden_channels, GNN_out_channels,
                 GNN_num_layers, dropout):
        super(GAutoEncoder, self).__init__()
        widths = ([GNN_in_channels, GNN_hidden_channels]
                  + [GNN_hidden_channels] * (GNN_num_layers - 2)
                  + [GNN_out_channels])
        self.encoders = torch.nn.ModuleList(
            GCNConv(width_in, width_out) for width_in, width_out in zip(widths[:-1], widths[1:]))
        self.dropout = dropout

    def reset_parameters(self):
        for layer in self.encoders:
            layer.reset_parameters()

    def forward(self, x, adj):
        *hidden_layers, output_layer = self.encoders
        for layer in hidden_layers:
            x = F.dropout(F.relu(layer(x, adj)), p=self.dropout, training=self.training)
        return output_layer(x, adj)


  
     

In [None]:
def predictor_simple(z, edge_index):
    """Dot-product link scores.

    Args:
        z: node embeddings, indexed along the first axis by node id.
        edge_index: (2, E) tensor of (source, target) ids; cast to long before indexing.

    Returns:
        (E,) tensor with logits[k] = <z[edge_index[0, k]], z[edge_index[1, k]]>.
    """
    source_ids = edge_index[0].long()
    target_ids = edge_index[1].long()
    return torch.sum(z[source_ids] * z[target_ids], dim=-1)

In [None]:
class predictor_model(torch.nn.Module):
    """MLP link classifier producing one logit per input row.

    The first layer expects ``3 * linear_in_channels`` features. Presumably this is the
    concatenation [z_src, z_dst, z_subgraph] built in train() — TODO confirm the widths match.
    ``forward`` returns the logits transposed to shape (1, batch).
    """

    def __init__(self, linear_in_channels, linear_hidden_channels, linear_num_layers, dropout):
        super(predictor_model, self).__init__()
        widths = ([3 * linear_in_channels, linear_hidden_channels]
                  + [linear_hidden_channels] * (linear_num_layers - 2)
                  + [1])
        self.lins = torch.nn.ModuleList(
            torch.nn.Linear(fan_in, fan_out) for fan_in, fan_out in zip(widths[:-1], widths[1:]))
        self.dropout = dropout

    def reset_parameters(self):
        for layer in self.lins:
            layer.reset_parameters()

    def forward(self, z):
        *hidden_layers, head = self.lins
        for layer in hidden_layers:
            z = F.dropout(F.relu(layer(z)), p=self.dropout, training=self.training)
        return head(z).T



In [None]:
class predictor_model_cnn(torch.nn.Module):
    """CNN link classifier over a (batch, input_num + 2, feature_dim) stack of embeddings.

    Two conv(3 x 5, padding 1) + 2x2 max-pool stages feed a 4-layer MLP ending in one logit per
    sample. ``forward`` returns shape (batch,). ``out_channels`` and ``dropout`` are accepted,
    but only ``dropout`` is stored and neither is used in ``forward``.
    """

    def __init__(self, input_num, feature_dim, hidden_channels, out_channels, dropout):
        super(predictor_model_cnn, self).__init__()

        self.conv1 = torch.nn.Conv2d(1, 8, (3, 5), padding=1)
        self.pool = torch.nn.MaxPool2d(2, 2)
        self.conv2 = torch.nn.Conv2d(8, 16, (3, 5), padding=1)

        def after_conv_pool(size, kernel):
            # ceil((size - kernel + 2) / 2): one stride-1, padding-1 conv then the 2x2 pool.
            return (size - kernel + 3) // 2

        rows = after_conv_pool(after_conv_pool(input_num + 2, 3), 3)
        cols = after_conv_pool(after_conv_pool(feature_dim, 5), 5)

        self.lin1 = torch.nn.Linear(16 * rows * cols, hidden_channels)
        self.lin2 = torch.nn.Linear(hidden_channels, 64)
        self.lin3 = torch.nn.Linear(64, 16)
        self.lin4 = torch.nn.Linear(16, 1)

        self.dropout = dropout

    def reset_parameters(self):
        for layer in (self.conv1, self.conv2, self.lin1, self.lin2, self.lin3, self.lin4):
            layer.reset_parameters()

    def forward(self, x):
        h = x.unsqueeze(1)
        for conv in (self.conv1, self.conv2):
            h = self.pool(F.relu(conv(h)))
        h = torch.flatten(h, 1)  # keep the batch axis, flatten the rest
        for lin in (self.lin1, self.lin2, self.lin3):
            h = F.relu(lin(h))
        return self.lin4(h).squeeze(1)



In [None]:
def get_link_labels(pos_edge_index, neg_edge_index):
    """Return float32 link labels: 1.0 per positive edge column, then 0.0 per negative one.

    Both arguments are (2, E) edge-index tensors; only their column counts are used.
    """
    num_pos = pos_edge_index.size(1)
    num_neg = neg_edge_index.size(1)
    return torch.cat([torch.ones(num_pos, dtype=torch.float),
                      torch.zeros(num_neg, dtype=torch.float)])

In [None]:

def train(model, model_predictor, data, A, data_split, optimizer,
          optimizer_predictor, num_subgraph_nodes, model_subgraph=None, optimizer_subgraph=None):
    """Run one training epoch over the train supervision edges and return the mean per-edge loss.

    Args:
        model: node encoder. With ``args.GNN_type == 'autoencoder'`` it must provide ``encode``
            and ``recon_loss`` (presumably a PyG ``GAE`` wrapper); otherwise it is called as
            ``model(x, edge_index)``.
        model_predictor: link classifier applied to the concatenated link features.
        data: PyG ``Data`` with ``x`` and an (E, 2) ``edge_index`` (transposed before use).
        A: adjacency matrix forwarded to ``subgraph2vec``.
        data_split: dict with (E, 2) tensors ``data_split['train']['edge']`` and ``['edge_neg']``.
        optimizer, optimizer_predictor: optimizers of ``model`` / ``model_predictor``.
        num_subgraph_nodes: subgraph budget forwarded to ``subgraph2vec``.
        model_subgraph, optimizer_subgraph: CNN subgraph encoder and its optimizer, used only
            when ``args.subgraph_feature_type == 'cnn'``.

    Reads the global ``args`` (``subgraph_feature_type``, ``GNN_type``, ``batch_size``,
    ``dataset_name``).
    """
    use_cnn = args.subgraph_feature_type == 'cnn'

    model.train()
    model_predictor.train()
    if use_cnn:
        model_subgraph.train()

    device = data.x.device
    pos_train_edge = data_split['train']['edge'].to(device)
    neg_train_edge = data_split['train']['edge_neg'].to(device)
    link_labels = get_link_labels(pos_train_edge.T, neg_train_edge.T).to(device).tolist()
    train_edge = torch.cat((pos_train_edge, neg_train_edge), dim=0).tolist()
    train_data = tuple(zip(train_edge, link_labels))

    # The precomputed subgraph table does not change during the epoch: read it once.
    df = pd.read_csv("data/%s" % (args.dataset_name) + '/subgraphs_%s_info.csv' % ('train'))

    total_examples = 0
    total_loss = 0

    train_loader = DataLoader(train_data, args.batch_size, shuffle=True)
    for perm in tqdm(train_loader, ncols=70):
        # Clear gradients every step; previously they were zeroed once per epoch, so each
        # optimizer.step() applied the accumulated gradients of all earlier batches.
        optimizer.zero_grad()
        optimizer_predictor.zero_grad()
        if use_cnn:
            optimizer_subgraph.zero_grad()

        # Re-encode every batch: the encoder weights change after each step.
        if args.GNN_type == 'autoencoder':
            z = model.encode(data.x, data.edge_index.T)
        else:
            z = model(data.x, data.edge_index.T)

        # Default collation turns the [u, v] pairs into [tensor_of_u, tensor_of_v].
        edge_index_x = perm[0][0].to(device)
        edge_index_y = perm[0][1].to(device)
        batch_labels = perm[1].to(device)

        z1 = z[edge_index_x.long()]
        z2 = z[edge_index_y.long()]
        z_nodes = torch.cat([z1, z2], dim=-1)

        if use_cnn:
            sub_embed = subgraph2vec(A, df, z, edge_index_x, edge_index_y, num_subgraph_nodes).to(device)
            z_sub = model_subgraph(sub_embed)
            z_embed = torch.cat((z_nodes, z_sub), -1)
        elif args.subgraph_feature_type == 'cnn2':
            z_embed = subgraph2vec(A, df, z, edge_index_x, edge_index_y, num_subgraph_nodes).to(device)
        else:
            z_sub = subgraph2vec(A, df, z, edge_index_x, edge_index_y, num_subgraph_nodes).to(device)
            z_embed = torch.cat((z_nodes, z_sub), -1)

        # reshape(-1) instead of squeeze(): a final batch holding one edge must stay 1-D.
        link_logits = model_predictor(z_embed).to(device).reshape(-1)
        # Python floats collate to float64; match the logits' dtype for the loss.
        batch_labels = batch_labels.to(link_logits.dtype)

        if args.GNN_type == 'autoencoder':
            loss_autoencoder = model.recon_loss(z, data.edge_index.T)
            loss_predict = BCEWithLogitsLoss()(link_logits, batch_labels)
            loss = loss_autoencoder + loss_predict
        else:
            loss = BCEWithLogitsLoss()(link_logits, batch_labels)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 2.0)
        torch.nn.utils.clip_grad_norm_(model_predictor.parameters(), 2.0)
        if use_cnn:
            torch.nn.utils.clip_grad_norm_(model_subgraph.parameters(), 2.0)

        num_examples = link_logits.size(0)
        total_loss += loss.item() * num_examples
        total_examples += num_examples

        optimizer.step()
        optimizer_predictor.step()
        if use_cnn:
            optimizer_subgraph.step()

    return (total_loss / total_examples)


In [None]:
def _encode_nodes(model, data):
    """Run the GNN over one graph split and return its node embeddings.

    GAE-wrapped models (``args.GNN_type == 'autoencoder'``) are queried through
    ``.encode``; GCN/SAGE models are called directly. Both receive the transposed
    ``data.edge_index`` (presumably stored as (E, 2) — TODO confirm).
    """
    if args.GNN_type == 'autoencoder':
        return model.encode(data.x, data.edge_index.T)
    return model(data.x, data.edge_index.T)


def _predict_edges(model_predictor, z, edge, A_split, subgraph_df, num_subgraph_nodes,
                   device, batch_size, model_subgraph=None):
    """Score candidate edges in mini-batches and concatenate the predictor outputs.

    Args:
        model_predictor: link predictor applied to each batch's edge representation.
        z: node embeddings for this split (output of ``_encode_nodes``).
        edge: candidate edges, source node ids in row 0 and target ids in row 1.
        A_split: adjacency forwarded to ``subgraph2vec`` for this split.
        subgraph_df: subgraph info table for this split, read once by the caller.
        num_subgraph_nodes: forwarded to ``subgraph2vec``.
        device: device the subgraph features are moved to.
        batch_size: number of edges scored per batch.
        model_subgraph: subgraph encoder; required when
            ``args.subgraph_feature_type == 'cnn'``.

    Returns:
        Outputs of all batches concatenated along the last dimension.
    """
    preds = []
    for perm in DataLoader(range(edge.size(1)), batch_size, shuffle=False):
        batch_edge = edge.T[perm].T
        edge_index_x, edge_index_y = batch_edge[0], batch_edge[1]
        z_nodes = torch.cat([z[edge_index_x.long()], z[edge_index_y.long()]], dim=-1)

        sub_features = subgraph2vec(A_split, subgraph_df, z, edge_index_x, edge_index_y,
                                    num_subgraph_nodes).to(device)
        if args.subgraph_feature_type == 'cnn':
            # A learned subgraph encoder transforms the features before concatenation.
            z_embed = torch.cat((z_nodes, model_subgraph(sub_features)), -1)
        elif args.subgraph_feature_type == 'cnn2':
            # The predictor consumes the subgraph features on their own.
            z_embed = sub_features
        else:
            z_embed = torch.cat((z_nodes, sub_features), -1)

        preds.append(model_predictor(z_embed))
    return torch.cat(preds, dim=-1)


def _split_metrics(logits, labels, split):
    """Compute all evaluation metrics for one split, keyed as '<metric>_<split>'."""
    # Accuracy receives the full logits tensor while the other metrics receive
    # logits[0]. NOTE(review): presumably the predictor output carries a leading
    # singleton dimension so logits[0] is the flat score vector — confirm.
    scores = logits[0]
    split_results = {
        f'accu_{split}': accuracy(logits, labels),
        f'recall_{split}': recall(scores, labels),
        f'precision_{split}': precision(scores, labels),
        f'f1_{split}': f1(scores, labels),
        f'avg_precision_{split}': average_precision(scores, labels),
        f'pr_auc_{split}': precision_recall_AUC(scores, labels),
    }
    precision_list, recall_list = precision_recall(scores, labels)
    split_results[f'precision_list_{split}'] = precision_list.tolist()
    split_results[f'recall_list_{split}'] = recall_list.tolist()
    return split_results


@torch.no_grad()
def test(model, model_predictor, data_train, data_test, A_train, A_test, split_edge, num_subgraph_nodes,
         model_subgraph=None, evaluator=False):
    """Evaluate the link predictor on the train and test edge splits.

    For each split, the split's graph is encoded and positive and negative edges are
    concatenated and scored in batches of ``args.batch_size``. The metrics are then
    computed against the labels returned by ``get_link_labels``.

    Args:
        model: GNN encoder (GCN / SAGE / GAE).
        model_predictor: link predictor head.
        data_train, data_test: graph data for each split (``.x``, ``.edge_index``).
        A_train, A_test: adjacency of each split, forwarded to ``subgraph2vec``.
        split_edge: ``{'train'|'test': {'edge': ..., 'edge_neg': ...}}``.
        num_subgraph_nodes: forwarded to ``subgraph2vec``.
        model_subgraph: subgraph encoder; required when
            ``args.subgraph_feature_type == 'cnn'``.
        evaluator: unused; kept for call-site compatibility.

    Returns:
        ``{'train': {...'_train' metrics...}, 'test': {...'_test' metrics...}}``.
    """
    model.eval()
    model_predictor.eval()
    if args.subgraph_feature_type == 'cnn':
        model_subgraph.eval()

    data_dir = "data/%s" % (args.dataset_name)
    batch_size = args.batch_size

    results = {}
    for split, data_split, A_split in (('train', data_train, A_train),
                                       ('test', data_test, A_test)):
        pos_edge = split_edge[split]['edge'].T
        neg_edge = split_edge[split]['edge_neg'].T
        # Each split's tensors live on that split's own device.
        device = data_split.x.device

        labels = get_link_labels(pos_edge, neg_edge).to(device)
        edge = torch.cat((pos_edge, neg_edge), dim=1)
        z = _encode_nodes(model, data_split)

        # The subgraph table is loop-invariant, so it is read once per split rather
        # than once per batch. NOTE(review): assumes subgraph2vec does not modify
        # the DataFrame in place — confirm.
        subgraph_df = pd.read_csv(data_dir + '/subgraphs_%s_info.csv' % (split))

        logits = _predict_edges(model_predictor, z, edge, A_split, subgraph_df,
                                num_subgraph_nodes, device, batch_size, model_subgraph)
        results[split] = _split_metrics(logits, labels, split)

    return results


In [None]:
def parse_arguments(device=0, dataset_name='karate', subgraph_type='hhop', dist_type='norm', 
                    network_type=0, feature_type='onehot', subgraph_feature_type='CNN',
                    negative_injection=True, log_steps=1, GNN_type='gcn', GNN_num_layers=3, GNN_hidden_channels=128,
                    GNN_out_channels=128, linear_num_layers=5, linear_hidden_channels=32, n2v_dim=256,
                    subg2vec_hidden_channels=128, subg2vec_out_channels=128, max_hop=10, dropout=0.0, no_start_run=1,
                    batch_size=50, lr=0.0001 ,epochs=300, eval_steps=300, test_ratio=0.1, fold=True, kfolds=5, runs=10,
                    coefficient=50, graphlet_size=4, label='dist'):
    """Build the experiment configuration as an ``argparse.Namespace``.

    Each keyword argument becomes the default of the option with the same name.
    ``parse_args(args=[])`` is used, so the real command line (for example the
    Jupyter kernel's own argv) is ignored. The returned Namespace therefore holds
    exactly the values passed to this function.

    Notes:
        * ``type=bool`` would map any non-empty string (including "False") to
          True if command-line parsing were enabled. It is harmless here because
          argparse applies ``type`` only to string inputs and no strings are parsed.
        * NOTE(review): the default ``subgraph_feature_type`` is upper-case
          'CNN'. The comparisons in this notebook test lower-case 'cnn'/'cnn2',
          so the default falls into their ``else`` branches. Confirm this is intended.

    Returns:
        argparse.Namespace with one attribute per option.
    """
    parser = argparse.ArgumentParser(description='SGAE_subgraph')
    parser.add_argument('--device', type=int, default=device)
    # e.g. 'ogbl-collab', 'football', 'FB15k-237', 'karate', 'USAir', 'PB'
    parser.add_argument('--dataset_name', type=str, default=dataset_name)
    parser.add_argument('--subgraph_type', type=str, default=subgraph_type)  # 'hhop', 'DIS'
    parser.add_argument('--dist_type', type=str, default=dist_type)  # 'norm', 'cos'
    parser.add_argument('--network_type', type=int, default=network_type)  # directed -> 0, undirected -> 1
    parser.add_argument('--feature_type', type=str, default=feature_type)  # 'node2vec', 'onehot'
    parser.add_argument('--subgraph_feature_type', type=str, default=subgraph_feature_type)  # 'cnn', 'NDP'
    parser.add_argument('--n2v_dim', type=int, default=n2v_dim)
    parser.add_argument('--negative_injection', type=bool, default=negative_injection)
    parser.add_argument('--log_steps', type=int, default=log_steps)
    parser.add_argument('--GNN_type', type=str, default=GNN_type)  # 'gcn', 'sage', 'autoencoder'
    parser.add_argument('--GNN_num_layers', type=int, default=GNN_num_layers)
    parser.add_argument('--GNN_hidden_channels', type=int, default=GNN_hidden_channels)
    parser.add_argument('--GNN_out_channels', type=int, default=GNN_out_channels)
    parser.add_argument('--linear_num_layers', type=int, default=linear_num_layers)
    parser.add_argument('--linear_hidden_channels', type=int, default=linear_hidden_channels)
    parser.add_argument('--subg2vec_hidden_channels', type=int, default=subg2vec_hidden_channels)
    parser.add_argument('--subg2vec_out_channels', type=int, default=subg2vec_out_channels)
    parser.add_argument('--max_hop', type=int, default=max_hop)
    parser.add_argument('--dropout', type=float, default=dropout)
    parser.add_argument('--batch_size', type=int, default=batch_size)
    parser.add_argument('--lr', type=float, default=lr)
    parser.add_argument('--epochs', type=int, default=epochs)
    parser.add_argument('--eval_steps', type=int, default=eval_steps)
    parser.add_argument('--test_ratio', type=float, default=test_ratio)
    parser.add_argument('--fold', type=bool, default=fold)
    parser.add_argument('--kfolds', type=int, default=kfolds)
    parser.add_argument('--runs', type=int, default=runs)
    parser.add_argument('--coefficient', type=int, default=coefficient)
    parser.add_argument('--no_start_run', type=int, default=no_start_run)
    # Previously hard-coded to 4, which silently ignored the graphlet_size argument.
    parser.add_argument('--graphlet_size', type=int, default=graphlet_size, help='Maximal graphlet size.')
    parser.add_argument('--label', type=str, default=label)

    # An empty argv keeps notebook/kernel command-line flags out of the config.
    args = parser.parse_args(args=[])

    return args
    

In [None]:
def main():
    
    print(args)
    
#*******dataset*******
    
    data_train, data_test, A, A_train, A_test, edge_pos_train,\
        edge_neg_train, edge_pos_test, edge_neg_test, num_subgraph_nodes, num_nodes = dataset()
    

    
    split_edge = {'train': {'edge':edge_pos_train, 
                                'edge_neg':edge_neg_train},
                                'test':{'edge':edge_pos_test, 
                                'edge_neg':edge_neg_test} }
            
    

    device = 'cpu'
    device = torch.device(device)
    
    data_train = data_train.to(device)
    data_test = data_test.to(device)

#*******create models*******   
    if args.GNN_type == 'gcn':
        model = GCN(data_train.x.shape[1], args.GNN_hidden_channels, args.GNN_out_channels, 
                               args.GNN_num_layers, args.dropout).to(device)
    if args.GNN_type == 'sage':
        model = SAGE(data_train.x.shape[1], args.GNN_hidden_channels, args.GNN_out_channels, 
                           args.GNN_num_layers, args.dropout).to(device)
    if args.GNN_type == 'autoencoder':
        model = GAE(GAutoEncoder(data_train.x.shape[1], args.GNN_hidden_channels, args.GNN_out_channels,
                                   args.GNN_num_layers, args.dropout).to(device))
        
    if args.subgraph_feature_type == 'cnn2':
        model_predictor = predictor_model_cnn(num_subgraph_nodes, args.GNN_out_channels, args.linear_hidden_channels, 
                                          args.linear_num_layers, args.dropout).to(device)
    else:
        model_predictor = predictor_model(args.GNN_out_channels, args.linear_hidden_channels, 
                                          args.linear_num_layers, args.dropout).to(device)
    
    if args.subgraph_feature_type == 'cnn':
        model_subgraph = subg2vec_model(num_subgraph_nodes, args.GNN_out_channels, args.subg2vec_hidden_channels, 
                                          args.subg2vec_out_channels, args.dropout).to(device)
    
    coeff = args.coefficient

#*******train and test model*******    
    if args.fold==True:
        for run in range(args.no_start_run , 1 + args.runs):
        
            # Define the K-fold Cross Validator

            kfold = KFold(n_splits=args.kfolds, shuffle=False)

            train_ids=[]
            test_ids=[]

            for fold, (train_id, test_id) in enumerate(kfold.split(edge_pos)):
                train_ids.append(train_id)
                test_ids.append(test_id)

            # K-fold Cross Validation model evaluation
            results_folds=[]
            for fold in range(args.kfolds):
                print('--------------------------------------------------------------')
                print('--------------------------------------------------------------')
                print(f'FOLD {fold}')
                print('--------------------------------')

                split_edge = {'train': {'edge':edge_pos[train_ids[fold]], 
                                'edge_neg':edge_neg[train_ids[fold]]},
                                'test':{'edge':edge_pos[test_ids[fold]], 
                                'edge_neg':edge_neg[test_ids[fold]]} }




                model.reset_parameters()
                model_predictor.reset_parameters()
                if args.subgraph_feature_type == 'cnn':
                    model_subgraph.reset_parameters()

                optimizer = torch.optim.Adam(list(model.parameters()), lr=args.lr)
                optimizer_predictor = torch.optim.Adam(list(model_predictor.parameters()), lr=args.lr)
                if args.subgraph_feature_type == 'cnn':
                    optimizer_subgraph = torch.optim.Adam(list(model_subgraph.parameters()), lr=args.lr)


                max_accu = 0
                min_loss = 1000
                losses = []
                for epoch in range(1, 1 + args.epochs):
                    print(f'Epoch: {epoch:02d} and Fold: {fold:02d} and Run: {run:02d}' )
#*******train*******
                    print(datetime.now())
                    print(data)
                    if args.subgraph_feature_type == 'cnn':
                        loss = train(model, model_predictor, data, A, split_edge, optimizer, optimizer_predictor, 
                                        num_subgraph_nodes, model_subgraph, optimizer_subgraph)
                    else:
                        loss = train(model, model_predictor, data, A, split_edge, optimizer, optimizer_predictor, 
                                        num_subgraph_nodes)

                    print(datetime.now())


                    min_loss = loss

                    if epoch % coeff == 0 :

                        torch.save(model.state_dict(), 'models/%s_%s_%s_model_%s_%s_%s_epochs%d_fold%d_run%d.pth' 
                                    %(args.dataset_name, args.feature_type, args.GNN_type, args.subgraph_type,
                                        args.subgraph_feature_type, args.dist_type, epoch, fold, run))

                        torch.save(model_predictor.state_dict(), 'models/%s_%s_%s_model_predictor_%s_%s_%s_epochs%d_fold%d_run%d.pth' 
                                    %(args.dataset_name, args.feature_type, args.GNN_type, args.subgraph_type,
                                        args.subgraph_feature_type, args.dist_type, epoch, fold, run))

                        if args.subgraph_feature_type == 'cnn':
                            torch.save(model_subgraph.state_dict(), 'models/%s_%s_%s_model_subgraph_%s_%s_%s_epochs%d_fold%d_run%d.pth' 
                                        %(args.dataset_name, args.feature_type, args.GNN_type, args.subgraph_type,
                                            args.subgraph_feature_type, args.dist_type, epoch, fold, run))

                        losses.append(min_loss)

                    if epoch % args.eval_steps == 0:

                        results = []
                        for e in range(1,args.epochs+1):
                            if e%coeff == 0:
#*******test*******                               
                                model.load_state_dict(torch.load('models/%s_%s_%s_model_%s_%s_%s_epochs%d_fold%d_run%d.pth' 
                                    %(args.dataset_name, args.feature_type, args.GNN_type, args.subgraph_type,
                                        args.subgraph_feature_type, args.dist_type, e, fold, run)))

                                model_predictor.load_state_dict(torch.load('models/%s_%s_%s_model_predictor_%s_%s_%s_epochs%d_fold%d_run%d.pth' 
                                    %(args.dataset_name, args.feature_type, args.GNN_type, args.subgraph_type,
                                        args.subgraph_feature_type, args.dist_type, e, fold, run)))

                                if args.subgraph_feature_type == 'cnn':
                                    model_subgraph.load_state_dict(torch.load('models/%s_%s_%s_model_subgraph_%s_%s_%s_epochs%d_fold%d_run%d.pth' 
                                            %(args.dataset_name, args.feature_type, args.GNN_type, args.subgraph_type,
                                                args.subgraph_feature_type, args.dist_type, e, fold, run)))

                                if args.subgraph_feature_type == 'cnn':
                                    result = test(model, model_predictor, data, A, split_edge, num_subgraph_nodes, model_subgraph)
                                else:
                                    result = test(model, model_predictor, data, A, split_edge, num_subgraph_nodes)

                                results.append(result)

                                result_train = result['train']['accu_train']
                                result_valid = 0
                                result_test = result['test']['accu_test']
                                
                                
#*******print results*******                                
                                i = (int(e/coeff))-1
                                with open('results/%s_%s_%s_model_%s_%s_%s_epochs%d_fold%d_run%d.csv' 
                                    %(args.dataset_name, args.feature_type, args.GNN_type, args.subgraph_type,
                                        args.subgraph_feature_type, args.dist_type, e, fold, run), mode='w') as csv_file:
                                    fieldnames = ['loss', 'accu_train', 'accu_valid', 'accu_test']
                                    writer = csv.DictWriter(csv_file, fieldnames=fieldnames)

                                    writer.writeheader()
                                    writer.writerow({'loss':losses[i], 'accu_train': result_train, 'accu_valid': result_valid, 
                                                        'accu_test': result_test})


                                print(f'Best Result in epochs {e:02d} fold {fold:02d} and Run {run:02d} ')
                                print('---')


                                print(f'Loss_model: {losses[i]:.4f}, '
                                                f'Accu_Train_model: {result_train:.4f}, '
                                                f'Accu_Valid_model: {result_valid:.4f}, '
                                                f'Accu_Test_model: {result_test:.4f}')
                                print('---')

            result_fold = []  

            for e in range(1,args.epochs+1):
                if e%coeff == 0:
                    sum_folds = np.zeros(4)
                    for f in range(args.kfolds):   
                        df_model = pd.read_csv('results/%s_%s_%s_model_%s_%s_%s_epochs%d_fold%d_run%d.csv' 
                                    %(args.dataset_name, args.feature_type, args.GNN_type, args.subgraph_type,
                                        args.subgraph_feature_type, args.dist_type, e, f, run))
                        sum_folds = (sum_folds + df_model.values[0]).tolist()
                    result_fold.append (sum_folds)



            results_folds = (torch.tensor(result_fold))/args.kfolds


            print('--------------------------------------------------------------------')
            print(f'folds Results in Run {run:02d}')
            print('---')

            for e in range(1,args.epochs+1):
                if e%coeff == 0:
                    i = int(e/coeff)-1
                    print(f'folds Results in Epoch: {e:02d}')
                    print('---')
                    print(f'Loss_model: {results_folds[i][0].item():.4f}, '
                                f'Accu_Train_model: {results_folds[i][1].item():.4f}, '
                                f'Accu_Valid_model: {results_folds[i][2].item():.4f}, '
                                f'Accu_Test_model: {results_folds[i][3].item():.4f}')
                    print('---')


            max_results_model.append(results_folds)
            print(max_results_model)
            
        sum_results = torch.zeros(4)
        for run in range(args.runs):
            sum_results = sum_results + torch.tensor(max_results_model[run])


        final_result_model = sum_results/args.runs


        print('--------------------------------------------------------------------')
        print('--------------------------------------------------------------------')
        print(f'Final Results')
        print('---')

        for e in range(1,args.epochs+1):
            if e%coeff == 0:

                i = int(e/coeff)-1

                print(f'Final Results Epoch: {e:02d}')
                print('---')
                print(f'Loss_model: {final_result_model[i][0].item():.4f}, '
                                f'Accu_Train_model: {final_result_model[i][1].item():.4f}, '
                                f'Accu_Valid_model: {final_result_model[i][2].item():.4f}, '
                                f'Accu_Test_model: {final_result_model[i][3].item():.4f}')
                print('---')


                with open('results/%s_%s_%s_model_%s_%s_%s_epochs%d_fold%d_run%d_final.csv' 
                            %(args.dataset_name, args.feature_type, args.GNN_type, args.subgraph_type,
                            args.subgraph_feature_type, args.dist_type, args.epochs, args.kfolds, args.runs) , mode='w') as csv_file:
                    fieldnames = ['loss', 'accu_train', 'accu_valid', 'accu_test']
                    writer = csv.DictWriter(csv_file, fieldnames=fieldnames)

                    writer.writeheader()
                    writer.writerow({'loss':final_result_model[i][0].item(), 'accu_train': final_result_model[i][1].item(), 
                                        'accu_valid': final_result_model[i][2].item(), 'accu_test': final_result_model[i][3].item()})
        print(datetime.now())


        
    
    else:
        max_results_model=[]
        max_precision_recall_results=[]
        for run in range(args.no_start_run , 1 + args.runs):
            model.reset_parameters()
            model_predictor.reset_parameters()
            if args.subgraph_feature_type == 'cnn':
                model_subgraph.reset_parameters()

            optimizer = torch.optim.Adam(list(model.parameters()), lr=args.lr)
            optimizer_predictor = torch.optim.Adam(list(model_predictor.parameters()), lr=args.lr)
            if args.subgraph_feature_type == 'cnn':
                optimizer_subgraph = torch.optim.Adam(list(model_subgraph.parameters()), lr=args.lr)


            max_accu = 0
            min_loss = 1000
            losses = []
            for epoch in range(1, 1 + args.epochs):
                print(f'Epoch: {epoch:02d} and Run: {run:02d}' )
#*******train*******
                print(datetime.now())

                if args.subgraph_feature_type == 'cnn':
                    loss = train(model, model_predictor, data_train, A_train, split_edge, optimizer, optimizer_predictor, 
                                    num_subgraph_nodes, model_subgraph, optimizer_subgraph)
                else:
                    loss = train(model, model_predictor, data_train, A_train, split_edge, optimizer, optimizer_predictor, 
                                    num_subgraph_nodes)

                print(datetime.now())

                min_loss = loss

                if epoch % coeff == 0 :

                    torch.save(model.state_dict(), 'models/%s_%s_%s_model_%s_%s_%s_epochs%d_run%d.pth' 
                                %(args.dataset_name, args.feature_type, args.GNN_type, args.subgraph_type,
                                    args.subgraph_feature_type, args.dist_type, epoch, run))

                    torch.save(model_predictor.state_dict(), 'models/%s_%s_%s_model_predictor_%s_%s_%s_epochs%d_run%d.pth' 
                                %(args.dataset_name, args.feature_type, args.GNN_type, args.subgraph_type,
                                    args.subgraph_feature_type, args.dist_type, epoch, run))

                    if args.subgraph_feature_type == 'cnn':

                        torch.save(model_subgraph.state_dict(), 'models/%s_%s_%s_model_subgraph_%s_%s_%s_epochs%d_run%d.pth' 
                                    %(args.dataset_name, args.feature_type, args.GNN_type, args.subgraph_type,
                                        args.subgraph_feature_type, args.dist_type, epoch, run))

                    losses.append(min_loss)

                if epoch % args.eval_steps == 0:

                    results = []
                    precision_recall_results = []
                    
                    
                    for e in range(1,args.epochs+1):
                        if e%coeff == 0:
#*******test*******                               
                            model.load_state_dict(torch.load('models/%s_%s_%s_model_%s_%s_%s_epochs%d_run%d.pth' 
                                %(args.dataset_name, args.feature_type, args.GNN_type, args.subgraph_type,
                                    args.subgraph_feature_type, args.dist_type, e, run)))

                            model_predictor.load_state_dict(torch.load('models/%s_%s_%s_model_predictor_%s_%s_%s_epochs%d_run%d.pth' 
                                %(args.dataset_name, args.feature_type, args.GNN_type, args.subgraph_type,
                                    args.subgraph_feature_type, args.dist_type, e, run)))

                            if args.subgraph_feature_type == 'cnn':
                                model_subgraph.load_state_dict(torch.load('models/%s_%s_%s_model_subgraph_%s_%s_%s_epochs%d_run%d.pth' 
                                        %(args.dataset_name, args.feature_type, args.GNN_type, args.subgraph_type,
                                            args.subgraph_feature_type, args.dist_type, e, run)))

                            if args.subgraph_feature_type == 'cnn':
                                result = test(model, model_predictor, data_train, data_test,A_train, A_test, split_edge, num_subgraph_nodes, model_subgraph)
                            else:
                                result = test(model, model_predictor, data_train, data_test,A_train, A_test, split_edge, num_subgraph_nodes)

                            

                            
                            
                            accu_train = result['train']['accu_train']
                            accu_test = result['test']['accu_test']
                            
                            recall_train = result['train']['recall_train']
                            recall_test = result['test']['recall_test']

                            precision_train = result['train']['precision_train']
                            precision_test = result['test']['precision_test']

                            f1_train = result['train']['f1_train']
                            f1_test = result['test']['f1_test']

                            avg_precision_train = result['train']['avg_precision_train']
                            avg_precision_test = result['test']['avg_precision_test']

                            pr_auc_train = result['train']['pr_auc_train']
                            pr_auc_test = result['test']['pr_auc_test']
                            
                            
                            precision_list_train = result['train']['precision_list_train']
                            recall_list_train = result['train']['recall_list_train']
                            
                            precision_list_test = result['test']['precision_list_test']
                            recall_list_test = result['test']['recall_list_test']
                            
#*******print results*******                                
                            i = (int(e/coeff))-1
                            with open('results/%s_%s_%s_model_%s_%s_%s_epochs%d_d_run%d.csv' 
                                %(args.dataset_name, args.feature_type, args.GNN_type, args.subgraph_type,
                                    args.subgraph_feature_type, args.dist_type, e, run), mode='w') as csv_file:
                                fieldnames = ['loss', 'accu_train', 'accu_test', 'recall_train',
                                                'recall_test', 'precision_train', 'precision_test',
                                                'f1_train', 'f1_test', 'avg_precision_train',
                                                'avg_precision_test', 'pr_auc_train', 'pr_auc_test',
                                                'precision_list_train', 'recall_list_train',
                                                'precision_list_test', 'recall_list_test']
                                writer = csv.DictWriter(csv_file, fieldnames=fieldnames)

                                writer.writeheader()
                                writer.writerow({'loss':losses[i], 'accu_train': accu_train,  
                                                    'accu_test': accu_test, 'recall_train':recall_train, 
                                                    'recall_test':recall_test, 'precision_train':precision_train,
                                                    'precision_test':precision_test,
                                                    'f1_train':f1_train,'f1_test':f1_test,
                                                    'avg_precision_train':avg_precision_train,
                                                    'avg_precision_test':avg_precision_test, 'pr_auc_train':pr_auc_train,
                                                    'pr_auc_test':pr_auc_test,
                                                    'precision_list_train':precision_list_train, 
                                                    'recall_list_train':recall_list_train,
                                                    'precision_list_test':precision_list_test, 
                                                    'recall_list_test':recall_list_test})
                            
                            
                            print(f'Best Result in epochs {e:02d} and Run {run:02d} ')
                            print('---')


                            print(f'Loss_model: {losses[i]:.4f}')
                            
                            print(f'Accu_Train: {accu_train:.4f},     '
                                    f'Accu_Test: {accu_test:.4f}')
                            
                            print(f'Recall_Train: {recall_train:.4f},     '
                                    f'Recall_Test: {recall_test:.4f}')
                            
                            print(f'Precision_Train: {precision_train:.4f},     '
                                    f'Precision_Test: {precision_test:.4f}')
                            
                            print(f'F1_Score_Train: {f1_train:.4f},     ' 
                                    f'F1_Score_Test: {f1_test:.4f}')
                            
                            print(f'Avrage_precision_Train: {avg_precision_train:.4f},     ' 
                                    f'Avrage_precision_Test: {avg_precision_test:.4f}')
                            
                            print(f'Precision-Recall_AUC_Train: {pr_auc_train:.4f},     '
                                    f'Precision-Recall_AUC_Test: {pr_auc_test:.4f}')
                            
                            print('---')
                                
                    
                    
        results=[]

        for r in range(1,args.runs+1):
            result=[]
            for e in range(1,args.epochs+1):

                if (e)%args.coefficient==0:
                    df = pd.read_csv('results/%s_%s_%s_model_%s_%s_%s_epochs%d_d_run%d.csv' 
                            %(args.dataset_name, args.feature_type, args.GNN_type, args.subgraph_type,
                                args.subgraph_feature_type, args.dist_type, e, r))
                    result.append([df['loss'].item(), df['accu_train'].item(), df['accu_test'].item(), df['recall_train'].item(),
                                        df['recall_test'].item(), df['precision_train'].item(), df['precision_test'].item(),
                                        df['f1_train'].item(), df['f1_test'].item(), df['avg_precision_train'].item(),
                                        df['avg_precision_test'].item(), df['pr_auc_train'].item(), df['pr_auc_test'].item()])
            results.append(result)
        sum_results = torch.zeros([13])

        for r in range(args.runs):
            sum_results = sum_results+torch.tensor(results[r])
        
        final_result_model = sum_results/args.runs
        
        
        print('--------------------------------------------------------------------')
        print('--------------------------------------------------------------------')
        print(f'Final Results')
        print('---')

        for e in range(1,args.epochs+1):
            if e%coeff == 0:

                i = int(e/coeff)-1

                print(f'Final Results Epoch: {e:02d}')
                print('---')
                print(f'Loss_model: {final_result_model[i][0].item():.4f} ')
                
                print(f'Accu_Train: {final_result_model[i][1].item():.4f},     '
                        f'Accu_Test: {final_result_model[i][2].item():.4f}')
                
                print(f'Recall_Train: {final_result_model[i][3].item():.4f},     '
                        f'Recall_Test: {final_result_model[i][4].item():.4f}')

                print(f'Precision_Train: {final_result_model[i][5].item():.4f},     '
                        f'Precision_Test: {final_result_model[i][6].item():.4f}')

                print(f'F1_Score_Train: {final_result_model[i][7].item():.4f},     ' 
                        f'F1_Score_Test: {final_result_model[i][8].item():.4f}')

                print(f'Avrage_precision_Train: {final_result_model[i][9].item():.4f},     ' 
                        f'Avrage_precision_Test: {final_result_model[i][10].item():.4f}')

                print(f'Precision-Recall_AUC_Train: {final_result_model[i][11].item():.4f},     '
                        f'Precision-Recall_AUC_Test: {final_result_model[i][12].item():.4f}')
                print('---')


                with open('results/%s_%s_%s_model_%s_%s_%s_epochs%d_run%d_final.csv' 
                            %(args.dataset_name, args.feature_type, args.GNN_type, args.subgraph_type,
                            args.subgraph_feature_type, args.dist_type, e, args.runs) , mode='w') as csv_file:
                    fieldnames = ['loss', 'accu_train', 'accu_test', 'recall_train',
                                    'recall_test', 'precision_train', 'precision_test',
                                    'f1_train', 'f1_test', 'avg_precision_train',
                                    'avg_precision_test', 'pr_auc_train', 'pr_auc_test']
                    writer = csv.DictWriter(csv_file, fieldnames=fieldnames)

                    writer.writeheader()
                    writer.writerow({'loss':final_result_model[i][0].item(), 
                                        'accu_train': final_result_model[i][1].item(), 
                                        'accu_test': final_result_model[i][2].item(),  
                                        'recall_train':final_result_model[i][3].item(), 
                                        'recall_test':final_result_model[i][4].item(), 
                                        'precision_train':final_result_model[i][5].item(),
                                        'precision_test':final_result_model[i][6].item(),
                                        'f1_train':final_result_model[i][7].item(),
                                        'f1_test':final_result_model[i][8].item(),
                                        'avg_precision_train':final_result_model[i][9].item(),
                                        'avg_precision_test':final_result_model[i][10].item(), 
                                        'pr_auc_train':final_result_model[i][11].item(),
                                        'pr_auc_test':final_result_model[i][12].item(),
                                        })
        print(datetime.now())

In [None]:
# Experiment configuration.
# Builds the module-level `args` namespace that the notebook's functions
# (e.g. `dataset()`) read as a global, then launches the run.
args = parse_arguments(
    dataset_name='karate',
    dist_type='cos',
    subgraph_type='hhop',
    feature_type='node2vec',
    subgraph_feature_type='cnn',
    GNN_type='gcn',
    epochs=200,
    eval_steps=200,
    no_start_run=1,
    runs=10,
    coefficient=50,  # per-epoch results are collected every `coefficient` epochs
    fold=False,
    lr=0.01,
)

# Inside a Jupyter kernel `__name__` is "__main__", so executing this cell runs main().
if __name__ == "__main__":
    main()