In [1]:
from utils_new import *

In [2]:
class get_node_emb(nn.Module):
    def __init__(self, NUM_U, NUM_V, 
                 U_CLUSTER_ARC, V_CLUSTER_ARC, IN_DIM = 64):
        super(get_node_emb, self).__init__()
        
        get_ini_emb = get_ini_emds(IN_DIM, IN_DIM, NUM_V, NUM_U)
        V_embedding, U_embedding = get_ini_emb()
        self.V_emb = Parameter(V_embedding)
        self.U_emb = Parameter(U_embedding)
        
        self.U_CLUSTER_NETWORK = MLP(U_CLUSTER_ARC)
        self.V_CLUSTER_NETWORK = MLP(V_CLUSTER_ARC)
        
        self.STRUCTURE_CLUSTERS = torch.nn.Embedding(
            num_embeddings = U_CLUSTER_ARC[0], embedding_dim = IN_DIM)
        nn.init.normal_(self.STRUCTURE_CLUSTERS.weight, std = 1)

        
    def forward(self, idx_U, idx_V, ALPHA):
        
        lhs = self.U_emb[idx_U]     # (batch_size, out_features)
        rhs = self.V_emb[idx_V]     # (batch_size, 1 + neg_edges, out_features)
        
        U_DIS = F.cosine_similarity(lhs.unsqueeze(1), self.STRUCTURE_CLUSTERS.weight.unsqueeze(0), dim=2)
        U_numerator = self.U_CLUSTER_NETWORK(U_DIS)       #soft assignment
        U_soft_assignments = (U_numerator.t() / torch.sum(U_numerator, 1)).t()       #soft assignment
        
#         rhs_flat = rhs.view(-1, rhs.shape[-1])  # shape: (15, 8) - 将5*3个二维张量视为15个向量
#         CLUSTERS_flat = self.STRUCTURE_CLUSTERS.weight.unsqueeze(0).expand(rhs_flat.shape[0], -1, -1).contiguous()  # shape: (15, 4, 8)
#         cos_sim = F.cosine_similarity(rhs_flat.unsqueeze(1), CLUSTERS_flat, dim=-1)  # shape: (15, 4)
#         V_DIS = cos_sim.view(rhs.shape[0], rhs.shape[1], self.STRUCTURE_CLUSTERS.weight.shape[0])  # shape: (5, 3, 4)

        V_DIS = F.cosine_similarity(rhs.unsqueeze(1), self.STRUCTURE_CLUSTERS.weight.unsqueeze(0), dim=2)
        V_assignments = self.V_CLUSTER_NETWORK(V_DIS) # (batch_size, 1 + neg_edges, cluster_numbers)
        
#         PROB_CLUSTER = torch.sum(U_soft_assignments.unsqueeze(1) * V_assignments, dim=2)
        CLUSTER_SIM = torch.mul(U_soft_assignments, V_assignments)
        PROB_CLUSTER = torch.sum(CLUSTER_SIM, dim=1)
        
#         distance_U_V = F.cosine_similarity(lhs.unsqueeze(1), rhs, dim=2) + 1
#         PROB_U_V = distance_U_V / 2.0  # (batch_size, 1 + neg_edges)
        
        distance_U_V = torch.cosine_similarity(lhs, rhs) + 1.0
        PROB_U_V = distance_U_V / 2.0
        
        LINK_PROB = ALPHA*PROB_CLUSTER + (1-ALPHA)*PROB_U_V

        return LINK_PROB

In [3]:
class update_nodes_embedding(nn.Module):
    def __init__(self, NODES_U_NUMBER, NODES_V_NUMBER,
                 U_CLUSTERS_DIM, V_CLUSTERS_DIM):
        super(update_nodes_embedding, self).__init__()
        self.prediction_module = get_node_emb(NODES_U_NUMBER,
                                              NODES_V_NUMBER,
                                              U_CLUSTERS_DIM, 
                                              V_CLUSTERS_DIM)
        self.optimizer = torch.optim.Adam(params=self.prediction_module.parameters(), lr=5e-4)
        self.loss_function = torch.nn.MSELoss(reduction='sum')

    def pairwise_distances(self, clusters):
        norm_squared = torch.sum(clusters**2, dim=1, keepdim=True)  # 每行的范数平方
        distances = torch.sqrt(torch.clamp(norm_squared - 2 * torch.matmul(clusters, clusters.transpose(0, 1)) + norm_squared.transpose(0, 1), min=1e-12))
        return distances

    def forward(self, DATA_TRAIN, DATA_TEST, SET_ALPHA):
        
        best_predict_ap = 0
        best_predict_auc = 0
        
        self.prediction_module.train()
        
        for epoch in range(40):
            total_loss = 0
            
#             PRED_U_EMDS = self.prediction_module.U_emb.detach()
#             self.prediction_module.STRUCTURE_CLUSTERS = PRED_U_EMDS[self.structure_centers] + torch.sparse.mm(CLUSTER_GRAPH, PRED_U_EMDS)

            ##################### 批训练 #################
            for batch_id, (edge_U, edge_V, edge_label) in enumerate(DATA_TRAIN):

                nodes_similar = self.prediction_module(edge_U, edge_V, SET_ALPHA)
                
                pair_loss = self.loss_function(nodes_similar, edge_label)
                
                cluster_distances = self.pairwise_distances(self.prediction_module.STRUCTURE_CLUSTERS.weight)
                cluster_loss = torch.sum(cluster_distances)
                
                loss = pair_loss - cluster_loss*5e-3
                
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                total_loss += loss.item()
#                 if train_step%100 == 0:
#                     print('iter loss: ', loss.item())
            
            print(f'Epoch: {epoch:02d}, Loss: {total_loss:.4f}')
            
            FINAL_U_EMDS = self.prediction_module.U_emb.detach()
            FINAL_V_EMDS = self.prediction_module.V_emb.detach()

            all_predict_edges = []
            all_true_edges = []
            for _, (test_idx_U, test_idx_V, test_edge_label) in enumerate(DATA_TEST):
        
                test_lhs = FINAL_U_EMDS[test_idx_U]
                test_rhs = FINAL_V_EMDS[test_idx_V]

                predict_test_edges = predict_edges(test_lhs, test_rhs, self.prediction_module)

                predict_label = np.array(predict_test_edges)
                true_label = np.array(test_edge_label)
                all_predict_edges = np.concatenate((all_predict_edges, predict_label), axis=0)
                all_true_edges = np.concatenate((all_true_edges, true_label),axis=0)

            predict_ap, predict_auc = computer_prediction(all_true_edges, all_predict_edges)
            print('epoch: ', epoch, 'predict_auc_roc = ', predict_auc, 'predict_auc_pr = ', predict_ap)
            
            if best_predict_ap < predict_ap:
                best_predict_ap = predict_ap
            if best_predict_auc < predict_auc:
                best_predict_auc = predict_auc
                
        return best_predict_ap, best_predict_auc

In [4]:
set_random_seed(42)
NUMBER_U, NUMBER_V, train_data_loader, test_data_loader = get_LP_train_test_data()
################ 超参数 #####################
cluster_number = 32
U_hidden_dim, U_output_dim, V_hidden_dim, V_output_dim = 12, 32, 12, 32
U_network_arc = [cluster_number, U_hidden_dim, U_output_dim]
V_network_arc = [cluster_number, V_hidden_dim, V_output_dim]

ALL_BEST_AP = []
ALL_BEST_AUC = []

if __name__ == '__main__':
    for set_alpha in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]:
        print('############################## alpha: ', set_alpha)
        
        UPDATE_NODES_MODULE = update_nodes_embedding(NUMBER_U,
                                                     NUMBER_V,
                                                     U_network_arc,
                                                     V_network_arc)
        best_ap, best_auc = UPDATE_NODES_MODULE(train_data_loader, 
                                                test_data_loader, 
                                                set_alpha)
        ALL_BEST_AP.append(best_ap)
        ALL_BEST_AUC.append(best_auc)

train: (15000, 3214) 38457
test: (15000, 3214) 25638
max_sample: 358932
initial neg_g: torch.Size([15000, 3214]) 358932
pos_g: torch.Size([15000, 3214]) 64095
lp_pos_train: (38457, 2)
lp_pos_test: (25638, 2)
lp_neg: torch.Size([357131, 2])
lp_neg_train: torch.Size([153828, 2])
lp_neg_test: torch.Size([25638, 2])
train_edge_index: torch.Size([192285, 2])
test_edge_index: torch.Size([51276, 2])
############################## alpha:  0.1
Epoch: 00, Loss: -4165.2088
epoch:  0 predict_auc_roc =  0.733781374173133 predict_auc_pr =  0.7829248301196684
Epoch: 01, Loss: -26587.2374
epoch:  1 predict_auc_roc =  0.8276850841111631 predict_auc_pr =  0.8719070922021832
Epoch: 02, Loss: -46982.4085
epoch:  2 predict_auc_roc =  0.8608828327115874 predict_auc_pr =  0.8956734921683304
Epoch: 03, Loss: -67414.3612
epoch:  3 predict_auc_roc =  0.8873655254788354 predict_auc_pr =  0.9148645317007713
Epoch: 04, Loss: -88037.9266
epoch:  4 predict_auc_roc =  0.9081490278385028 predict_auc_pr =  0.9290693559

Epoch: 27, Loss: -473295.4778
epoch:  27 predict_auc_roc =  0.9664281324817203 predict_auc_pr =  0.9687339052288779
Epoch: 28, Loss: -489295.6641
epoch:  28 predict_auc_roc =  0.9662413225256719 predict_auc_pr =  0.9685843751151426
Epoch: 29, Loss: -505294.7141
epoch:  29 predict_auc_roc =  0.9662848699062475 predict_auc_pr =  0.9686837190755113
Epoch: 30, Loss: -521290.6477
epoch:  30 predict_auc_roc =  0.9664477374747256 predict_auc_pr =  0.9687750677531317
Epoch: 31, Loss: -537290.6913
epoch:  31 predict_auc_roc =  0.9665151359613303 predict_auc_pr =  0.9688413554294371
Epoch: 32, Loss: -553287.5905
epoch:  32 predict_auc_roc =  0.9665558817592711 predict_auc_pr =  0.9689808313743432
Epoch: 33, Loss: -569283.8569
epoch:  33 predict_auc_roc =  0.9664218127563532 predict_auc_pr =  0.9688450880223893
Epoch: 34, Loss: -585278.0525
epoch:  34 predict_auc_roc =  0.9665180881585075 predict_auc_pr =  0.9689759647585292
Epoch: 35, Loss: -601276.8081
epoch:  35 predict_auc_roc =  0.9665272178

Epoch: 18, Loss: -329949.9969
epoch:  18 predict_auc_roc =  0.9664299215999274 predict_auc_pr =  0.9683942649740164
Epoch: 19, Loss: -345990.1028
epoch:  19 predict_auc_roc =  0.9665565869091767 predict_auc_pr =  0.9685853210486832
Epoch: 20, Loss: -362031.1618
epoch:  20 predict_auc_roc =  0.966511855911284 predict_auc_pr =  0.9685501623705121
Epoch: 21, Loss: -378067.2207
epoch:  21 predict_auc_roc =  0.9666792685398333 predict_auc_pr =  0.9687029257085847
Epoch: 22, Loss: -394093.2980
epoch:  22 predict_auc_roc =  0.9667402301564259 predict_auc_pr =  0.9688029194854423
Epoch: 23, Loss: -410124.8883
epoch:  23 predict_auc_roc =  0.9667620966496138 predict_auc_pr =  0.9688778855536523
Epoch: 24, Loss: -426154.3988
epoch:  24 predict_auc_roc =  0.9667880867864243 predict_auc_pr =  0.9689352484080516
Epoch: 25, Loss: -442188.2594
epoch:  25 predict_auc_roc =  0.966803008275688 predict_auc_pr =  0.9689234867180372
Epoch: 26, Loss: -458215.3488
epoch:  26 predict_auc_roc =  0.966937079560

epoch:  8 predict_auc_roc =  0.9637314361414329 predict_auc_pr =  0.9655233375537059
Epoch: 09, Loss: -181319.2560
epoch:  9 predict_auc_roc =  0.9653364638216171 predict_auc_pr =  0.9671229707141394
Epoch: 10, Loss: -197987.2308
epoch:  10 predict_auc_roc =  0.965980918348412 predict_auc_pr =  0.9678971968762153
Epoch: 11, Loss: -214438.0113
epoch:  11 predict_auc_roc =  0.9658615821862393 predict_auc_pr =  0.9679214208480437
Epoch: 12, Loss: -230675.8977
epoch:  12 predict_auc_roc =  0.9657936284035928 predict_auc_pr =  0.9676743306641387
Epoch: 13, Loss: -246836.5736
epoch:  13 predict_auc_roc =  0.965744525019878 predict_auc_pr =  0.9676864707602726
Epoch: 14, Loss: -262944.7976
epoch:  14 predict_auc_roc =  0.965663995683576 predict_auc_pr =  0.9676417450895053
Epoch: 15, Loss: -279027.8570
epoch:  15 predict_auc_roc =  0.9657821733308551 predict_auc_pr =  0.9677373286912649
Epoch: 16, Loss: -295094.2177
epoch:  16 predict_auc_roc =  0.9658388629713208 predict_auc_pr =  0.96792923

Epoch: 39, Loss: -663329.5212
epoch:  39 predict_auc_roc =  0.9653191674909238 predict_auc_pr =  0.967830852154096
############################## alpha:  0.8
Epoch: 00, Loss: -18057.6958
epoch:  0 predict_auc_roc =  0.9080119549122008 predict_auc_pr =  0.9166254201871282
Epoch: 01, Loss: -46069.6400
epoch:  1 predict_auc_roc =  0.9249841722371683 predict_auc_pr =  0.9298477252265971
Epoch: 02, Loss: -63102.2440
epoch:  2 predict_auc_roc =  0.935262835095983 predict_auc_pr =  0.9426529558053685
Epoch: 03, Loss: -80029.8888
epoch:  3 predict_auc_roc =  0.9427358236252219 predict_auc_pr =  0.9503102646260702
Epoch: 04, Loss: -96862.0980
epoch:  4 predict_auc_roc =  0.94961810191707 predict_auc_pr =  0.9555686745979526
Epoch: 05, Loss: -113738.2172
epoch:  5 predict_auc_roc =  0.9542627737304454 predict_auc_pr =  0.9590672717477718
Epoch: 06, Loss: -130234.5550
epoch:  6 predict_auc_roc =  0.95724335566378 predict_auc_pr =  0.9612254891274753
Epoch: 07, Loss: -146500.8219
epoch:  7 predict

Epoch: 30, Loss: -514662.5784
epoch:  30 predict_auc_roc =  0.957890941908117 predict_auc_pr =  0.9600235104689958
Epoch: 31, Loss: -530692.5010
epoch:  31 predict_auc_roc =  0.9574287621661332 predict_auc_pr =  0.9596514085108108
Epoch: 32, Loss: -546733.0613
epoch:  32 predict_auc_roc =  0.9572342100140342 predict_auc_pr =  0.95939604800648
Epoch: 33, Loss: -562774.1406
epoch:  33 predict_auc_roc =  0.9572776760019022 predict_auc_pr =  0.9594212568971464
Epoch: 34, Loss: -578816.8741
epoch:  34 predict_auc_roc =  0.9568589789827354 predict_auc_pr =  0.9591293968958204
Epoch: 35, Loss: -594853.2318
epoch:  35 predict_auc_roc =  0.9569516236615898 predict_auc_pr =  0.9592596171640361
Epoch: 36, Loss: -610895.9033
epoch:  36 predict_auc_roc =  0.9567611358201116 predict_auc_pr =  0.9589213180367173
Epoch: 37, Loss: -626927.3403
epoch:  37 predict_auc_roc =  0.956574093096133 predict_auc_pr =  0.9587511165667543
Epoch: 38, Loss: -642956.1584
epoch:  38 predict_auc_roc =  0.95636374010925

In [5]:
ALL_BEST_AP

[0.968349718289502,
 0.9694800805251294,
 0.9697757100207427,
 0.9692023477331011,
 0.9693022124274668,
 0.9691141052074046,
 0.9696725454386171,
 0.9688470789056831,
 0.963047535502504]

In [6]:
ALL_BEST_AUC

[0.9658180629508055,
 0.9670065044974628,
 0.9672115608729123,
 0.9669406037888133,
 0.9667984845146435,
 0.9667135508439797,
 0.9668339572822227,
 0.966015243250611,
 0.961806858409386]