# 导入库函数

In [10]:
from sklearn.cluster import KMeans
import torch
from torch.nn import Embedding
from torch.utils.data import DataLoader
import torch.nn as nn
from torch.nn import Parameter
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, average_precision_score
from sklearn import metrics
#from torch_geometric.datasets import Planetoid
EPS = 1e-15

# evaluating indicator

In [11]:
def predict_acc(recons_edges, true_edges):
    predict_graph = recons_edges
    predict_edges = np.array(predict_graph)
    
    ap = average_precision_score(true_edges, predict_edges)
    #print("AP： ", ap)
    fpr, tpr, _ = metrics.roc_curve(true_edges, predict_edges, pos_label=1)
    auc = metrics.auc(fpr, tpr)
    #print("AUC SCORE: ",auc)
    
    #predict_edges[predict_edges>gama] = 1
    #predict_edges[predict_edges<=gama] = 0
    
    #print("ACC： ",accuracy_score(true_edges, predict_edges))
    #print("AP： ",precision_score(true_edges, predict_edges, average='macro'))
    #print("RECALL： ",recall_score(true_edges, predict_edges, average='macro'))
    #print("F1 SCORE： ",f1_score(true_edges, predict_edges, average='macro'))
    
    return ap, auc

# model

In [12]:
"""  Given the cluster centroids, update the node representation  """
class reconstruction_graph(nn.Module):
    def __init__(self, NE, alpha=1.0):
        super(reconstruction_graph, self).__init__()
        self.alpha = alpha
        self.nodes_embedding = Parameter(NE)
        
    def forward(self, CC):
        # Calculate the attention score (in nodes) for each node to each cluster
        norm_squared = torch.sum((self.nodes_embedding.unsqueeze(1) - CC)**2, 2)   # Distance from the node to the center of mass
        numerator = 1.0 / (1.0 + (norm_squared / DEC))
        power = float(self.alpha + 1) / 2   
        numerator = numerator**power    
        soft_assignments = (numerator.t() / torch.sum(numerator, 1)).t()       #soft assignment using t-distribution
        
        # Calculate the cluster similarity between nodes (cosine similarity): the larger the value, the more conducive to edge formation
        prod = torch.mm(soft_assignments, soft_assignments.t())       #  numerator
        norm = torch.norm(soft_assignments,p=2,dim=1).unsqueeze(0)      #    denominator 
        clusters_similar = prod.div(torch.mm(norm.t(),norm))
        
        # Calculate the distance between nodes: smaller values favor edge formation
        nodes_distance = torch.norm(self.nodes_embedding[:, None]-self.nodes_embedding, dim=2, p=2)
        nodes_distance = torch.div(nodes_distance, torch.max(nodes_distance))
        
        # The formation probability of the edges is calculated
        distance_similar = torch.div(beta*nodes_distance, clusters_similar)      
        nodes_similar = torch.exp(-distance_similar)
        
        return nodes_similar

class update_nodes_embedding(nn.Module):
    def __init__(self, ne):
        super(update_nodes_embedding, self).__init__()
        NE = ne
        self.reconstruction_module = reconstruction_graph(NE)     
        self.optimizer = torch.optim.SGD(params=self.reconstruction_module.parameters(), lr=LR, momentum=0.9)
        self.loss_function = torch.nn.MSELoss(reduction='sum')
        
    def forward(self, g, CC, edge_train, edge_test):
        self.reconstruction_module.train()
        for epoch in range(5):
            self.optimizer.zero_grad()
            graph_reconstruction = self.reconstruction_module(CC)
            graph_train = torch.take(graph_reconstruction, edge_train)
            loss = self.loss_function(g, graph_train)
            loss.backward()
            self.optimizer.step()
            #print(f'Epoch: {epoch:02d}, Loss: {loss.item():.4f}')
        recons_test_edges = torch.take(graph_reconstruction, edge_test).detach()
        test_ap, test_auc = predict_acc(recons_test_edges, test_edge)
        return self.reconstruction_module.nodes_embedding.detach(), test_ap, test_auc
    
"""  Type-centroid update aimed at maximizing the prediction effect   """
class ClusteringLayer(nn.Module):
    """  Given the node representation, update the cluster centroids  """
    def __init__(self, cc, alpha=1.0):
        super(ClusteringLayer, self).__init__()
        self.alpha = alpha
        self.cluster_centers = Parameter(cc)
    
    def forward(self, NE):
        norm_squared = torch.sum((NE.unsqueeze(1) - self.cluster_centers)**2, 2) 
        numerator = 1.0 / (1.0 + (norm_squared / DEC))
        power = float(self.alpha + 1) / 2   
        numerator = numerator**power   
        soft_assignments = (numerator.t() / torch.sum(numerator, 1)).t() 
        
        prod = torch.mm(soft_assignments, soft_assignments.t())
        norm = torch.norm(soft_assignments,p=2,dim=1).unsqueeze(0)
        clusters_similar = prod.div(torch.mm(norm.t(),norm))
        
        nodes_distance = torch.norm(NE[:, None]-NE, dim=2, p=2)
        nodes_distance = torch.div(nodes_distance, torch.max(nodes_distance))
        
        distance_similar = torch.div(beta*nodes_distance, clusters_similar)      
        nodes_similar = torch.exp(-distance_similar)
        
        return nodes_similar

def find_cluster_centers(NE, CC, g, edge_train, edge_test, num_epochs=5):        
    clusteringlayer = ClusteringLayer(CC)
    optimizer = torch.optim.SGD(params=clusteringlayer.parameters(), lr=LR, momentum=0.9)
    loss_function = torch.nn.MSELoss(reduction='sum')
    
    for epoch in range(num_epochs):
        optimizer.zero_grad()
        graph_recons = clusteringlayer(NE)
        graph_train = torch.take(graph_recons, edge_train)
        loss = loss_function(g, graph_train)
        loss.backward()
        optimizer.step()
        #print(f'Epoch: {epoch:02d}, Loss: {loss.item():.4f}')     
    recons_test_edges = torch.take(graph_recons, edge_test).detach()
    test_ap, test_auc = predict_acc(recons_test_edges, test_edge)
    return clusteringlayer.cluster_centers.detach(), test_ap, test_auc

In [13]:
def get_cora_graph(filename_adj, nodes_number):
    edge_index = pd.read_csv(filename_adj, header=None, sep=' ')
    edges = torch.from_numpy(edge_index.values).t().contiguous()
    
    #print("edges number ： ", edges.shape[1])
    
    graph_np = np.zeros((nodes_number, nodes_number))
    for i in range(edges.shape[1]):
        graph_np[edges[0,i], edges[1,i]]=1
        graph_np[edges[1,i], edges[0,i]]=1

    posi_edge = np.argwhere(graph_np == 1)
    posi_edge = posi_edge[posi_edge[:,0]<posi_edge[:,1]]         # Only the upper-left corner matrix is taken
    posi_edge_number = posi_edge.shape[0]
    
    nega_edge = np.argwhere(graph_np == 0) 
    nega_edge = nega_edge[nega_edge[:,0]<nega_edge[:,1]]     # Only the upper-left corner matrix is taken
    
    positive_index = np.random.choice(range(posi_edge_number),int(posi_edge_number*0.9),replace=False)
    choose_positive = posi_edge[positive_index]
    posi_not_choose = np.setdiff1d(range(posi_edge_number), positive_index)
    test_positive = posi_edge[posi_not_choose]
    
    negative_index = np.random.choice(range(nega_edge.shape[0]),int(posi_edge_number*0.9)*6,replace=False)
    choose_negative = nega_edge[negative_index]
    nega_not_choose = np.setdiff1d(range(nega_edge.shape[0]), negative_index)
    test_nega_index = np.random.choice(nega_not_choose,len(posi_not_choose),replace=False)
    test_negative = nega_edge[test_nega_index]
    
    train_posi = [list(choose_positive[i]) for i in range(len(choose_positive))]
    train_nega = [list(choose_negative[i]) for i in range(len(choose_negative))]
    train_index = train_posi + train_nega
    train_mask = [train_index[i][0]*nodes_number+train_index[i][1] for i in range(len(train_index))]
    train_mask = torch.tensor(train_mask)
    
    test_posi = [list(test_positive[i]) for i in range(len(test_positive))]
    test_nega = [list(test_negative[i]) for i in range(len(test_negative))]
    test_index = test_posi + test_nega
    test_mask = [test_index[i][0]*nodes_number+test_index[i][1] for i in range(len(test_index))]
    test_mask = torch.tensor(test_mask)
    
    graph_tensor = torch.from_numpy(graph_np).float()
    train_edge = torch.take(graph_tensor, train_mask)
    test_edge = np.array(torch.take(graph_tensor, test_mask))
    
    return train_edge, test_edge, train_mask, test_mask

In [14]:
filename = 'datasets/Celegans.txt'    # The node index starts at 0, sep=None
nodes_number = 297
train_edge, test_edge, train_mask, test_mask = get_cora_graph(filename, nodes_number)

In [15]:
embedding_dim = 12
n_clusters = 24
alpha = 1
LR = 0.4

all_best_ap = []
all_best_auc = []

for beta in [3.0, 3.2, 3.4, 3.6, 3.8, 4.0, 4.2, 4.4, 4.6, 4.8, 5.0, 5.2, 5.4, 5.6, 5.8, 6.0, 6.2, 6.4, 6.6, 6.8, 7.0]:
    print("########################################################## beta :   ", beta)
    for DEC in [3.0, 3.2, 3.4, 3.6, 3.8, 4.0, 4.2, 4.4, 4.6, 4.8, 5.0, 5.2, 5.4, 5.6, 5.8, 6.0, 6.2, 6.4, 6.6, 6.8, 7.0]:
        print("############################ DEC :   ", DEC)
        ini_embedding = Embedding(nodes_number, embedding_dim, sparse=True)       
        raw_nodes_embedding = ini_embedding.weight.detach()        # Initial node embedding

        kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(raw_nodes_embedding)
        cluster_centers = kmeans.cluster_centers_
        raw_cluster_centers = torch.tensor(cluster_centers, dtype=torch.float)     # Initial cluster centroids

        best_ap = 0
        best_auc = 0
        for module_epoch in range(300):
            #print("######################### Module cycle ： %d ##########################"%module_epoch)
            update_nodes_module = update_nodes_embedding(raw_nodes_embedding)
            raw_nodes_embedding, nodes_ap, nodes_auc = update_nodes_module(train_edge, raw_cluster_centers, train_mask, test_mask)
            if nodes_ap > best_ap:
                best_ap = nodes_ap
                best_auc = nodes_auc
            raw_cluster_centers, cluster_ap, cluster_auc = find_cluster_centers(raw_nodes_embedding, raw_cluster_centers, train_edge, train_mask, test_mask)
            if cluster_ap > best_ap:
                best_ap = cluster_ap
                best_auc = cluster_auc
            
        print("best ap: ", best_ap)
        print("best auc: ", best_auc)
        
        all_best_ap.append(best_ap)
        all_best_auc.append(best_auc)

########################################################## beta :    3.0
############################ DEC :    3.0
best ap:  0.8895011067925632
best auc:  0.8873769605191996
############################ DEC :    3.2
best ap:  0.8537891220877826
best auc:  0.8740292049756625
############################ DEC :    3.4
best ap:  0.8728835910104435
best auc:  0.8880475932936721
############################ DEC :    3.6
best ap:  0.8660516558600111
best auc:  0.8636668469442943
############################ DEC :    3.8
best ap:  0.872481163020585
best auc:  0.8972633856138453
############################ DEC :    4.0
best ap:  0.8906257441131259
best auc:  0.8911627906976743
############################ DEC :    4.2
best ap:  0.871860618707133
best auc:  0.8871389940508383
############################ DEC :    4.4
best ap:  0.8875988472373688
best auc:  0.8906003244997296
############################ DEC :    4.6
best ap:  0.8891430849011421
best auc:  0.8860573282855597
####################

best ap:  0.8498086530331778
best auc:  0.8700054083288263
############################ DEC :    6.4
best ap:  0.8476796612301647
best auc:  0.8655922120064901
############################ DEC :    6.6
best ap:  0.8639900784321125
best auc:  0.8522660897782586
############################ DEC :    6.8
best ap:  0.8942521628282676
best auc:  0.8905570578691184
############################ DEC :    7.0
best ap:  0.8494524847796986
best auc:  0.848523526230395
########################################################## beta :    3.8
############################ DEC :    3.0
best ap:  0.8673101648499886
best auc:  0.8826608977825852
############################ DEC :    3.2
best ap:  0.8693208900188995
best auc:  0.8702001081665764
############################ DEC :    3.4
best ap:  0.8379825373926151
best auc:  0.8652677122769065
############################ DEC :    3.6
best ap:  0.8776205877037668
best auc:  0.8851487290427258
############################ DEC :    3.8
best ap:  0.8760745

best ap:  0.8807647260852622
best auc:  0.8945159545700379
############################ DEC :    5.4
best ap:  0.8833526183569662
best auc:  0.8902541914548405
############################ DEC :    5.6
best ap:  0.8835045643474123
best auc:  0.8942779881016765
############################ DEC :    5.8
best ap:  0.8716469229490029
best auc:  0.8803244997295836
############################ DEC :    6.0
best ap:  0.8885665183616102
best auc:  0.8931963223363981
############################ DEC :    6.2
best ap:  0.8792383985424599
best auc:  0.8859491617090319
############################ DEC :    6.4
best ap:  0.8584047304349729
best auc:  0.8698107084910761
############################ DEC :    6.6
best ap:  0.8832806461451808
best auc:  0.8846727961060032
############################ DEC :    6.8
best ap:  0.8810896850255633
best auc:  0.8828339643050297
############################ DEC :    7.0
best ap:  0.871477018826319
best auc:  0.8874202271498107
#################################

best ap:  0.8875432653386921
best auc:  0.889821525148729
############################ DEC :    4.4
best ap:  0.8735293230200071
best auc:  0.8837425635478636
############################ DEC :    4.6
best ap:  0.8813276822370955
best auc:  0.8871822606814495
############################ DEC :    4.8
best ap:  0.8855666341665114
best auc:  0.8903407247160627
############################ DEC :    5.0
best ap:  0.8913285337165076
best auc:  0.8856679286100594
############################ DEC :    5.2
best ap:  0.902139421420353
best auc:  0.8959437533802055
############################ DEC :    5.4
best ap:  0.8909278448884822
best auc:  0.8859275283937262
############################ DEC :    5.6
best ap:  0.8680221332898594
best auc:  0.8799567333693888
############################ DEC :    5.8
best ap:  0.8995671215380018
best auc:  0.9009843158464035
############################ DEC :    6.0
best ap:  0.8904049597363697
best auc:  0.8845213628988642
############################ DEC :

best ap:  0.8838781810985871
best auc:  0.8798485667928609
############################ DEC :    3.4
best ap:  0.8796774419553459
best auc:  0.8729259058950783
############################ DEC :    3.6
best ap:  0.8826896076876691
best auc:  0.870827474310438
############################ DEC :    3.8
best ap:  0.8783642788740513
best auc:  0.8866414277988102
############################ DEC :    4.0
best ap:  0.8709389996326664
best auc:  0.8934559221200649
############################ DEC :    4.2
best ap:  0.8406333996617178
best auc:  0.841968631692807
############################ DEC :    4.4
best ap:  0.8578571450970383
best auc:  0.8459924283396432
############################ DEC :    4.6
best ap:  0.8654549095177878
best auc:  0.8795240670632775
############################ DEC :    4.8
best ap:  0.8927359430523905
best auc:  0.8872038939967549
############################ DEC :    5.0
best ap:  0.8872761257491022
best auc:  0.8813845321795565
############################ DEC :

best ap:  0.8884972830756201
best auc:  0.8938453217955652
############################ DEC :    6.8
best ap:  0.887551660399521
best auc:  0.8857111952406707
############################ DEC :    7.0
best ap:  0.8900671679656146
best auc:  0.8841752298539751
########################################################## beta :    6.8
############################ DEC :    3.0
best ap:  0.8974992322741346
best auc:  0.8897566252028124
############################ DEC :    3.2
best ap:  0.8842905134365366
best auc:  0.8848674959437534
############################ DEC :    3.4
best ap:  0.8805535395799555
best auc:  0.885494862087615
############################ DEC :    3.6
best ap:  0.8617400440284497
best auc:  0.8742671714440237
############################ DEC :    3.8
best ap:  0.865287037347161
best auc:  0.8688588426176311
############################ DEC :    4.0
best ap:  0.8839279103443913
best auc:  0.8820551649540292
############################ DEC :    4.2
best ap:  0.878074573

In [18]:
auc_data = np.array(all_best_auc).reshape(21, 21)
ap_data = np.array(all_best_ap).reshape(21, 21)

In [20]:
ap_data.max(), auc_data.max()

(0.9084502410404804, 0.9077988101676582)

In [35]:
ap_data.min()

0.8357985846927513

In [36]:
auc_data.min()

0.8343969713358571