In [8]:
import os
print(os.getcwd())
os.chdir('/Users/boyuren/Documents/multi_head_graph_rag/MH-GRAG-V1')
print(os.getcwd())

/Users/boyuren/Documents/multi_head_graph_rag/MH-GRAG-V1/tests
/Users/boyuren/Documents/multi_head_graph_rag/MH-GRAG-V1


In [13]:
import torch
import pandas as pd
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_mutual_info_score
from src.gnn_clustering.train import train_model_multi_head
from src.gnn_clustering.evaluate import (
    get_embeddings,
    get_embeddings_list, 
    kmeans_clustering, 
    leiden_clustering, 
    compute_modularity, 
    format_communities, 
    random_clustering
)
from src.gnn_clustering.data_loader import load_random_data
from src.gnn_clustering.model import get_multi_head_model, get_model
from src.gnn_clustering.utils import get_device, get_dense_adj

# 设置设备
device = get_device()

def test_model_performance( num_nodes, num_edges):
    results = []

    # 加载随机数据
    data = load_random_data(num_nodes, num_edges)
    data = data.to(device)

    # 使用Leiden算法进行聚类，获取簇数和模块度
    communities_leiden, modularity_leiden = leiden_clustering(data)
    num_clusters_leiden = len(communities_leiden)

    # 初始模型和优化器
    num_heads = 3
    model = get_multi_head_model(data=data, device=device, num_heads=num_heads)
    single_head_model = get_model(data, device=device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    # 获取密集邻接矩阵
    adj = get_dense_adj(data.edge_index, device=device)

    # 初始嵌入和KMeans聚类模块度
    initial_embeddings = get_embeddings(single_head_model, data, device=device)
    clusters_kmeans_initial = kmeans_clustering(initial_embeddings, n_clusters=7)
    communities_kmeans_initial = format_communities(clusters_kmeans_initial, n_clusters=7)
    modularity_kmeans_initial = compute_modularity(data, communities_kmeans_initial)

    # 随机聚类模块度
    communities_random = random_clustering(data.num_nodes, n_clusters=7)
    modularity_random = compute_modularity(data, communities_random)

    # 模型训练
    model = train_model_multi_head(model, data, adj, optimizer, num_heads)
    embeddings_list = get_embeddings_list(model, data, device)

    embeddings=embeddings_list[0]
    from sklearn.metrics.pairwise import cosine_similarity
    # 计算相似度矩阵
    similarity_matrix = cosine_similarity(embeddings)
    from sklearn.neighbors import NearestNeighbors
    K = 10  # 设置 K 值，可根据数据集大小和实验需求调整
    nbrs = NearestNeighbors(n_neighbors=K, metric='cosine').fit(embeddings)
    distances, indices = nbrs.kneighbors(embeddings)
    edges = []
    weights = []
    num_nodes = embeddings.shape[0]
    for i in range(num_nodes):
        for j_idx in range(1, K):  # 从 1 开始，避免自环
            j = indices[i][j_idx]
            if i != j:
                edges.append((i, j))
                weights.append(similarity_matrix[i][j])
    import networkx as nx
    G = nx.Graph()
    for edge, weight in zip(edges, weights):
        G.add_edge(edge[0], edge[1], weight=weight)
    import igraph as ig
    # 将 NetworkX 图转换为 igraph 图
    G_ig = ig.Graph.TupleList(G.edges(data=True), weights=True, directed=False)
    import leidenalg
    # 使用模块度作为优化目标
    partition = leidenalg.find_partition(G_ig, leidenalg.ModularityVertexPartition)
    # 获取社区划分结果
    communities_leiden_embeddings = partition
    # 模块度已经包含在 partition 对象中
    modularity_leiden_embeddings = partition.modularity
    print(f'基于嵌入的 Leiden 算法的模块度: {modularity_leiden_embeddings:.4f}')
    print(f'基于嵌入的 Leiden 算法得到的簇数: {len(partition)}')


# 执行测试
df_report = test_model_performance(num_nodes=2048, num_edges=2048)

Epoch 1, Loss: 5.1786
Epoch 20, Loss: 3.4455
Epoch 40, Loss: 3.3657
Epoch 60, Loss: 3.3348
Epoch 80, Loss: 3.3161
Epoch 100, Loss: 3.3031
Epoch 120, Loss: 3.2933
Epoch 140, Loss: 3.2861
Epoch 160, Loss: 3.2806
Epoch 180, Loss: 3.2763
Epoch 200, Loss: 3.2729
基于嵌入的 Leiden 算法的模块度: 0.8440
基于嵌入的 Leiden 算法得到的簇数: 21


In [10]:
df_report

Unnamed: 0,num_nodes,num_edges,modularity_leiden,num_clusters_leiden,modularity_kmeans_initial,modularity_random,head_0_modularity,head_1_modularity,head_2_modularity,head_0_vs_1_mutual_info,head_0_vs_2_mutual_info,head_1_vs_2_mutual_info,avg_modularity,avg_mutual_info,modularity_percent
0,1024,1024,0.809295,188,0.370805,0.002599,0.706762,0.675788,0.700118,0.080159,0.078988,0.122128,0.694223,0.093758,85.781199
