In [5]:
from sklearn.metrics import roc_curve
from sklearn.metrics import auc
import matplotlib.pyplot as plt
import json
import ast
import numpy as np
from embedding import load_graph_network
from train import fit_transform
from embedData import embedData
import torch
from scipy.sparse import identity, csr_matrix
import pandas as pd
import torch_geometric
from torch_scatter import scatter_add
from tqdm import tqdm
import os
from RWR import RWRPyG, get_top_nodes
from utils import classify_element
from matplotlib.ticker import MaxNLocator
from adjustText import adjust_text 
import seaborn as sns
import matplotlib
matplotlib.use('Agg')  # 强制使用Agg渲染器（支持PNG和PDF）

## Random walk with restart (RWR)

In [6]:
# Path to the tab-separated edge table (columns node1, node2, relationship) that is read below.
adj_path = '../Data/merged_df_long_convert.txt'

In [7]:
# Load the edge table as TSV. Node identifiers are forced to str so that numeric-looking
# ids are not parsed as integers; the edge weight is read as float.
# NOTE(review): the displayed frame also contains an 'Unnamed: 0' column, i.e. the file
# was written with its index - confirm RWRPyG addresses columns by name, not position.
relationship = pd.read_csv(adj_path, sep='\t',
    dtype={'node1': str, 'node2': str, 'relationship': float} )
relationship.head(2)

Unnamed: 0,node1,node2,relationship
0,1002133,I802,1.0
1,1002133,M1997,1.0


In [8]:
# Load the node feature matrix from a .npy file; it is handed to RWRPyG as `feature_matrix`.
# NOTE(review): presumably one row per entry of keys.json, in the same order - confirm.
feature_matrix = np.load('../Data/UKB_feature_all_GAT.npy')

In [9]:
# Read the graph node names (passed to RWRPyG as `node_names`).
# JSON text is UTF-8 by specification, so the encoding is stated explicitly instead of
# relying on the platform's locale default.
with open('../Data/keys.json', 'r', encoding='utf-8') as f:
    loaded_keys = json.load(f)

In [10]:
# Initialise RWRPyG from the edge table, the feature matrix and the node names.
print("初始化RWRPyG...")
# Use the GPU when one is visible and fall back to the CPU otherwise, so the notebook
# does not fail at this cell on a machine without CUDA.
# NOTE(review): assumes RWRPyG accepts 'cpu' as a device string - confirm in RWR.py.
rwr = RWRPyG(
    relationship=relationship,
    feature_matrix=feature_matrix,
    node_names=loaded_keys,
    device='cuda' if torch.cuda.is_available() else 'cpu',
)

# Build the graph (k=10 neighbours, Minkowski metric with p=2, symmetrised).
# Left as the last expression on purpose: the call's return value is displayed as the cell output.
rwr.build_graph(k_neighbors=10, metric='minkowski', p=2, symmetrize=True)

初始化RWRPyG...
构建相似性图 (k=10, metric=minkowski)...
图构建完成: 49373 个节点, 3836410 条边


<RWR.RWRPyG at 0x7fb99866f680>

In [11]:
# Resolve the seed node name to its integer index in the graph.
seed_name = "C920"
try:
    select_index = rwr.name_to_idx[seed_name]
except KeyError:
    # Raise instead of calling exit(): inside a notebook exit() asks the kernel to shut
    # down rather than cleanly aborting the run, and the original message did not say
    # which node was missing.
    raise KeyError(f"错误：未找到节点 {seed_name}") from None
print(f"找到节点，索引为 {select_index}")

找到节点，索引为 2273


In [12]:
# Run the random walk with restart from the single seed node resolved above.
# NOTE(review): alpha is presumably the restart probability - confirm in RWR.py.
rwr_scores = rwr.compute_rwr(seeds=[select_index], alpha=0.5)

计算RWR (alpha=0.5, num of seeds=1)...
RWR在第 15 次迭代后收敛 (diff=3.829719e-07)


In [13]:
# Fetch the ranked ICD-10 nodes together with the per-node-type statistics.
top_nodes, scores, stats = get_top_nodes(
    rwr,
    rwr_scores,
    select_index,
    top_k=18903,
    node_type='icd10',
    return_stats=True,
)

# Print the statistics, one node type per line.
print(f"\n从节点 {rwr.node_names[select_index]} 出发的可达节点类型统计：")
for node_kind in ('icd10', 'eid', 'protein', 'metabolite', 'total'):
    print(f"{node_kind}: {stats[node_kind]}")


从节点 C920 出发的可达节点类型统计：
icd10: 18903
eid: 27295
protein: 2923
metabolite: 251
total: 49372


In [14]:
# Tab-separated ranking table: rank, node name, RWR score and node type.
print("\n最可及的top n nodes：")
print("排名\t节点名称\tRWR得分\t\t类型")
for rank, (name, score) in enumerate(zip(top_nodes, scores), start=1):
    node_type = classify_element(name)
    print(f"{rank}\t{name}\t{score:.6f}\t\t{node_type}")


最可及的top n nodes：
排名	节点名称	RWR得分		类型
1	C942	0.015074		icd10
2	C925	0.011169		icd10
3	T0220	0.010845		icd10
4	C926	0.010808		icd10
5	C944	0.010718		icd10
6	D720	0.010713		icd10
7	C932	0.010546		icd10
8	C940	0.010462		icd10
9	D46	0.010303		icd10
10	C928	0.010211		icd10
11	C910	0.010138		icd10
12	C930	0.009717		icd10
13	C933	0.009653		icd10
14	C945	0.009221		icd10
15	C92	0.009110		icd10
16	C927	0.008943		icd10
17	C924	0.008708		icd10
18	C901	0.008415		icd10
19	C922	0.008355		icd10
20	C950	0.008354		icd10
21	C93	0.007772		icd10
22	L908	0.007411		icd10
23	C943	0.007223		icd10
24	C921	0.006321		icd10
25	C929	0.006001		icd10
26	Z856	0.003218		icd10
27	M706	0.002185		icd10
28	C937	0.001667		icd10
29	C912	0.001612		icd10
30	C946	0.001541		icd10
31	C94	0.001486		icd10
32	D466	0.001135		icd10
33	C91	0.001061		icd10
34	C90	0.001015		icd10
35	C424	0.000962		icd10
36	C913	0.000945		icd10
37	C931	0.000876		icd10
38	D463	0.000852		icd10
39	C939	0.000836		icd10
40	C914	0.000750		icd10
41	E002	0.000744		

In [15]:
def normalize_scores(scores):
    """Min-max normalise a collection of scores into the [0, 1] range.

    Parameters
    ----------
    scores : array-like
        Numeric scores (list, tuple or NumPy array).

    Returns
    -------
    numpy.ndarray
        Float array with the same shape as ``scores``. An empty input yields an
        empty array; an input whose values are all equal yields all zeros.
    """
    # Coerce once to a float array so lists and integer arrays behave the same and the
    # constant-input branch no longer returns integer zeros.
    values = np.asarray(scores, dtype=float)
    # np.min / np.max raise ValueError on empty input, so short-circuit.
    if values.size == 0:
        return values
    min_val = np.min(values)
    max_val = np.max(values)
    # Avoid a division by zero when every score is identical.
    if max_val == min_val:
        return np.zeros_like(values)
    return (values - min_val) / (max_val - min_val)

In [21]:
def plot_grouped_heatmap(groups,
                         group_labels,
                         top_k=20,
                         x_label="Diseases",
                         y_label="Scores",
                         normalize=True,
                         row_normalize=True,
                         cluster_rows=True,
                         cluster_cols=False,
                         df_file = None,
                         figsize=20):
    """Draw a clustered heatmap over the union of every group's top_k nodes.

    Parameters
    ----------
    groups : list of (nodes, scores) tuples
        One entry per group; ``nodes[i]`` is the (hashable) node name whose score
        is ``scores[i]``.
    group_labels : list
        Column labels, one per entry of ``groups``.
    top_k : int
        How many of the highest-scoring nodes to take from each group (>= 1).
    x_label : str
        Label placed on the heatmap's x axis.
    y_label : str
        Label of the colour bar (suffixed with ' (Row-Normalized)' when
        ``row_normalize`` is true).
    normalize : bool
        Min-max normalise each group's scores (via ``normalize_scores``) before
        filling the matrix.
    row_normalize : bool
        Min-max normalise each row of the assembled matrix.
    cluster_rows, cluster_cols : bool
        Whether to hierarchically cluster rows / columns. Clustering is skipped
        for an axis with fewer than two entries.
    df_file : str or None
        If given, the matrix is written there as a tab-separated table;
        ``None`` skips the export.
    figsize : number
        Upper bound for the figure height in inches; the actual height is
        ``min(figsize, 0.6 * number_of_rows)`` and the width is fixed at 12.

    Returns
    -------
    The object returned by ``sns.clustermap`` (a seaborn ClusterGrid).

    Raises
    ------
    ValueError
        If ``top_k`` is smaller than 1 or ``group_labels`` and ``groups``
        differ in length.
    """
    # A slice like [-0:] would silently select *every* node, so reject non-positive top_k.
    if top_k < 1:
        raise ValueError(f"top_k must be a positive integer, got {top_k}")
    if len(group_labels) != len(groups):
        raise ValueError(
            f"group_labels has {len(group_labels)} entries but groups has {len(groups)}"
        )

    # 1. Union of each group's top_k nodes. A dict keeps first-seen order, so the row
    #    order (and the exported table) is reproducible across runs, and it doubles as
    #    an O(1) node -> row lookup.
    row_of = {}
    for nodes, scores in groups:
        # argsort is ascending: take the last top_k and reverse so the best comes first.
        for idx in np.argsort(scores)[-top_k:][::-1]:
            row_of.setdefault(nodes[idx], len(row_of))
    all_top_nodes = list(row_of)

    # 2. Fill the (node x group) matrix; nodes missing from a group stay at 0.
    data_matrix = np.zeros((len(all_top_nodes), len(groups)))
    for group_idx, (nodes, scores) in enumerate(groups):
        group_scores = normalize_scores(scores) if normalize else scores
        for node_idx, node in enumerate(nodes):
            row = row_of.get(node)
            if row is not None:
                data_matrix[row, group_idx] = group_scores[node_idx]

    # 3. Optional per-row min-max normalisation.
    if row_normalize:
        row_min = np.min(data_matrix, axis=1, keepdims=True)
        row_max = np.max(data_matrix, axis=1, keepdims=True)
        row_range = row_max - row_min
        row_range[row_range == 0] = 1  # constant rows: avoid dividing by zero
        data_matrix = (data_matrix - row_min) / row_range

    # 4. Wrap in a labelled DataFrame and export it only when a path was supplied.
    df = pd.DataFrame(data_matrix, index=all_top_nodes, columns=group_labels)
    if df_file is not None:
        df.to_csv(df_file, sep='\t', na_rep='nan')

    # 5. Clustered heatmap. Linkage needs at least two observations, so clustering is
    #    disabled for an axis that has fewer.
    g = sns.clustermap(
        df,
        method='average',
        metric='euclidean',
        cmap=sns.diverging_palette(240, 10, as_cmap=True),
        annot=False,
        linewidths=0,
        row_cluster=cluster_rows and len(all_top_nodes) > 1,
        col_cluster=cluster_cols and len(groups) > 1,
        # Height grows with the number of rows but is capped by `figsize`.
        figsize=(12, min(figsize, len(all_top_nodes) * 0.6)),
        cbar_kws={'label': y_label + (' (Row-Normalized)' if row_normalize else '')}
    )
    # Only the ordering from the clustering is wanted, not the dendrogram axes.
    g.ax_row_dendrogram.remove()
    g.ax_col_dendrogram.remove()

    # 6. Readability: axis label, horizontal row labels, slanted column labels, title.
    g.ax_heatmap.set_xlabel(x_label)
    plt.setp(g.ax_heatmap.get_yticklabels(), rotation=0, fontsize=9)
    plt.setp(g.ax_heatmap.get_xticklabels(), rotation=45, ha='right', fontsize=10)
    g.fig.suptitle(f'Top {top_k} Nodes Clustered Heatmap Across Groups', fontsize=14, y=0.98)

    # Leave room for the title and the colour bar.
    plt.tight_layout(rect=[0, 0, 0.9, 0.9])

    return g

In [22]:
# Seed disease codes (node names, despite the variable name); reused as heatmap column labels.
select_indices = ["C900", "C910", "C920", "C924"]
groups = []
    
for idx in select_indices:
    select_index = rwr.name_to_idx[idx]
    rwr_scores = rwr.compute_rwr(seeds=[select_index], alpha=0.5)
    # Ranked protein nodes and their scores for the current seed.
    # NOTE(review): top_k=2923 equals the protein count printed in the statistics
    # output earlier in the notebook, i.e. presumably "all proteins" - confirm.
    top_nodes, scores = get_top_nodes(
        rwr, rwr_scores, select_index, top_k=2923, 
        node_type='protein', return_stats=False
    )
    # Collect one (nodes, scores) tuple per seed.
    groups.append((top_nodes, scores))

计算RWR (alpha=0.5, num of seeds=1)...
RWR在第 15 次迭代后收敛 (diff=5.864812e-07)
计算RWR (alpha=0.5, num of seeds=1)...
RWR在第 13 次迭代后收敛 (diff=8.383125e-07)
计算RWR (alpha=0.5, num of seeds=1)...
RWR在第 15 次迭代后收敛 (diff=3.764993e-07)
计算RWR (alpha=0.5, num of seeds=1)...
RWR在第 13 次迭代后收敛 (diff=7.376855e-07)


In [23]:
# Draw the protein heatmap: union of each seed's top 50 proteins, min-max normalised
# per group and per row, rows clustered. The matrix is also written to df_file.
g = plot_grouped_heatmap(groups=groups, 
                            group_labels=select_indices, 
                            top_k=50, 
                            normalize=True, 
                            row_normalize=True, 
                            cluster_rows=True, 
                            cluster_cols=False,
                            df_file="../Results/RiskFactor_protein_plot.txt",
                            figsize=50)
# Save as PDF, then close the figure to release its memory.
g.savefig("../Results/RiskFactor_protein_plot.pdf", dpi=300, bbox_inches='tight')
plt.close(g.fig)

In [33]:
# Re-read the matrix exported by plot_grouped_heatmap; the node names arrive as the
# first column, which pandas labels 'Unnamed: 0'.
RiskFactor_protein_plot = pd.read_csv("../Results/RiskFactor_protein_plot.txt", sep='\t', header=0)

In [35]:
# Use the first column (the node names) as the index.
# Guarded and non-inplace so that re-running this cell is a no-op instead of a KeyError
# ('Unnamed: 0' is no longer a column once it has become the index).
if 'Unnamed: 0' in RiskFactor_protein_plot.columns:
    RiskFactor_protein_plot = RiskFactor_protein_plot.set_index('Unnamed: 0')

In [36]:
# Display the protein matrix (rows: proteins, columns: seed disease codes).
RiskFactor_protein_plot

Unnamed: 0_level_0,C900,C910,C920,C924
Unnamed: 0,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
slamf1,1.000000,0.000000,0.488958,0.009261
bmp6,0.717139,0.000000,1.000000,0.051329
asah1,0.351819,1.000000,0.837645,0.000000
ccl13,0.354414,0.121063,1.000000,0.000000
ubxn1,1.000000,0.101893,0.129224,0.000000
...,...,...,...,...
amot,1.000000,0.000000,0.050204,0.004007
rtbdn,1.000000,0.000000,0.081786,0.033231
ankmy2,0.000000,1.000000,0.805250,0.087074
lbr,0.035435,0.283282,0.000000,1.000000


In [55]:
# Reshape wide -> long: one row per (molecule, disease) pair. The former index column
# 'Unnamed: 0' is kept as the identifier and then given the clearer name 'Molecule'.
RiskFactor_protein_plot_long = (
    RiskFactor_protein_plot
    .reset_index()
    .melt(id_vars='Unnamed: 0', var_name='Disease', value_name='expression')
    .rename(columns={'Unnamed: 0': 'Molecule'})
)

In [56]:
# Tag every row with its molecule type.
# NOTE(review): lower-case 'protein' here vs. capitalised 'Metabolite' for the metabolite
# table - confirm which casing downstream consumers expect.
RiskFactor_protein_plot_long['type'] = 'protein'

In [57]:
# Export the long-format protein table as TSV (the integer index is written as an
# unnamed first column; missing values are written as 'nan').
RiskFactor_protein_plot_long.to_csv('../Results/RiskFactor_protein_plot_long.txt', sep='\t', na_rep='nan')

In [40]:
# Same seeds as for the protein analysis; this cell rebuilds `groups` for metabolites.
select_indices = ["C900","C910", "C920", "C924"]
groups = []
    
for idx in select_indices:
    select_index = rwr.name_to_idx[idx]
    rwr_scores = rwr.compute_rwr(seeds=[select_index], alpha=0.5)
    # Ranked metabolite nodes and their scores for the current seed.
    # NOTE(review): top_k=251 equals the metabolite count printed in the statistics
    # output earlier in the notebook, i.e. presumably "all metabolites" - confirm.
    top_nodes, scores = get_top_nodes(
        rwr, rwr_scores, select_index, top_k=251, 
        node_type='metabolite', return_stats=False
    )
    # Collect one (nodes, scores) tuple per seed.
    groups.append((top_nodes, scores))

计算RWR (alpha=0.5, num of seeds=1)...
RWR在第 15 次迭代后收敛 (diff=6.209037e-07)
计算RWR (alpha=0.5, num of seeds=1)...
RWR在第 13 次迭代后收敛 (diff=8.398424e-07)
计算RWR (alpha=0.5, num of seeds=1)...
RWR在第 15 次迭代后收敛 (diff=3.901599e-07)
计算RWR (alpha=0.5, num of seeds=1)...
RWR在第 13 次迭代后收敛 (diff=7.362805e-07)


In [41]:
# Draw the metabolite heatmap: union of each seed's top 10 metabolites, min-max
# normalised per group and per row, rows clustered. The matrix is also written to df_file.
g = plot_grouped_heatmap(groups=groups, 
                            group_labels=select_indices, 
                            top_k=10, 
                            normalize=True, 
                            row_normalize=True, 
                            cluster_rows=True, 
                            cluster_cols=False,
                            df_file="../Results/RiskFactor_metabolite_plot.txt",
                            figsize=10)
# Save as PDF, then close the figure to release its memory.
g.savefig("../Results/RiskFactor_metabolite_plot.pdf", dpi=300, bbox_inches='tight')
plt.close(g.fig)

In [42]:
# Re-read the matrix exported by plot_grouped_heatmap; the node names arrive as the
# first column, which pandas labels 'Unnamed: 0'.
RiskFactor_metabolite_plot = pd.read_csv("../Results/RiskFactor_metabolite_plot.txt", sep='\t', header=0)

In [43]:
# Use the first column (the node names) as the index.
# Guarded and non-inplace so that re-running this cell is a no-op instead of a KeyError
# ('Unnamed: 0' is no longer a column once it has become the index).
if 'Unnamed: 0' in RiskFactor_metabolite_plot.columns:
    RiskFactor_metabolite_plot = RiskFactor_metabolite_plot.set_index('Unnamed: 0')

In [51]:
# Reshape wide -> long: one row per (molecule, disease) pair. The former index column
# 'Unnamed: 0' is kept as the identifier and then given the clearer name 'Molecule'.
RiskFactor_metabolite_plot_long = (
    RiskFactor_metabolite_plot
    .reset_index()
    .melt(id_vars='Unnamed: 0', var_name='Disease', value_name='expression')
    .rename(columns={'Unnamed: 0': 'Molecule'})
)

In [52]:
# Tag every row with its molecule type.
# NOTE(review): capitalised 'Metabolite' here vs. lower-case 'protein' for the protein
# table - confirm which casing downstream consumers expect.
RiskFactor_metabolite_plot_long['type'] = 'Metabolite'

In [53]:
# Export the long-format metabolite table as TSV (the integer index is written as an
# unnamed first column; missing values are written as 'nan').
RiskFactor_metabolite_plot_long.to_csv('../Results/RiskFactor_metabolite_plot_long.txt', sep='\t', na_rep='nan')

In [54]:
# Preview the first rows of the long-format metabolite table.
RiskFactor_metabolite_plot_long.head(3)

Unnamed: 0,Molecule,Disease,expression,type
0,Triglycerides to Total Lipids in Small VLDL pe...,C900,0.247785,Metabolite
1,Cholesterol to Total Lipids in Large LDL perce...,C900,0.417874,Metabolite
2,Triglycerides to Total Lipids in Chylomicrons ...,C900,0.317222,Metabolite


In [58]:
# Preview the first rows of the long-format protein table.
RiskFactor_protein_plot_long.head(3)

Unnamed: 0,Molecule,Disease,expression,type
0,slamf1,C900,1.0,protein
1,bmp6,C900,0.717139,protein
2,asah1,C900,0.351819,protein


In [59]:
# Stack the metabolite rows on top of the protein rows and renumber the index from 0.
# NOTE(review): the 'type' column mixes 'Metabolite' (capitalised) and 'protein'
# (lower-case) - confirm this casing is what downstream consumers expect.
RiskFactor_merge_plot_long = pd.concat([RiskFactor_metabolite_plot_long, RiskFactor_protein_plot_long], ignore_index=True)

In [60]:
# Preview the first rows of the merged long-format table.
RiskFactor_merge_plot_long.head(3)

Unnamed: 0,Molecule,Disease,expression,type
0,Triglycerides to Total Lipids in Small VLDL pe...,C900,0.247785,Metabolite
1,Cholesterol to Total Lipids in Large LDL perce...,C900,0.417874,Metabolite
2,Triglycerides to Total Lipids in Chylomicrons ...,C900,0.317222,Metabolite


In [61]:
# Export the merged long-format table as TSV (the integer index is written as an
# unnamed first column; missing values are written as 'nan').
RiskFactor_merge_plot_long.to_csv('../Results/RiskFactor_merge_plot_long.txt', sep='\t', na_rep='nan')