In [None]:
import numpy as np
import os
import json
import pickle
import pandas as pd
from pathlib import Path
from scipy.linalg import solve as solve_dense, pinv
import matplotlib.pyplot as plt
import networkx as nx
from datetime import datetime
from matplotlib.lines import Line2D
from matplotlib.patches import Patch 
from collections import  Counter
import warnings
import functools

In [None]:
## 告警装饰器 ## 
def limit_warnings(max_count=10):
    """Decorator factory that attaches a rate-limited ``warn`` helper to a function.

    The decorated function behaves exactly like the original one. In addition it
    gains a ``.warn(message, category=None)`` attribute that forwards to
    ``warnings.warn`` at most ``max_count`` times and silently drops every
    later message.

    Args:
        max_count: Maximum number of warnings emitted through ``.warn`` for
            this decorated function.

    Returns:
        A decorator that wraps the target function.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)

        def controlled_warn(message, category=None):
            """Emit ``message`` unless the per-function budget is exhausted."""
            if wrapper._warning_count < wrapper._max_warnings:
                # stacklevel=2 attributes the warning to the code calling .warn()
                warnings.warn(message, category or UserWarning, stacklevel=2)
                wrapper._warning_count += 1
            # Beyond the budget: drop the message silently.

        # Initialise the counters at decoration time so that ``.warn`` is usable
        # even before the wrapped function has been called for the first time.
        wrapper._warning_count = 0
        wrapper._max_warnings = max_count
        wrapper.warn = controlled_warn
        return wrapper
    return decorator

In [None]:
def find_maximal_clique_blocks(R: np.ndarray, snp_ids: np.ndarray, r_min: float = 0.8) :
    """
    Build candidate LD blocks as the maximal cliques of the graph "R >= r_min".

    Two SNPs are connected when their (signed) LD entry is at least ``r_min``;
    every maximal clique with two or more members becomes one candidate block.

    Args:
        R: LD matrix (p x p); only the strict upper triangle is read.
        snp_ids: Sequence of p SNP identifiers (array, list or pandas Index).
        r_min: Minimum signed R for an edge (positive phase only).
    Returns:
        blocks: list of dict with keys
            'snps'    - sorted integer index array of the clique members,
            'snp_ids' - list of the matching identifiers,
            'size'    - number of members.
    """
    R = np.asarray(R)
    snp_ids = np.asarray(snp_ids)
    p = R.shape[0]

    G = nx.Graph()
    G.add_nodes_from(range(p))

    # Vectorised replacement of the i < j double loop. np.nonzero walks the
    # upper triangle in row-major order, i.e. the same edge insertion order.
    rows, cols = np.nonzero(np.triu(R >= r_min, k=1))
    G.add_edges_from(zip(rows.tolist(), cols.tolist()))

    blocks = []
    for clique in nx.find_cliques(G):
        if len(clique) < 2:
            # Isolated SNPs are not blocks.
            continue
        members = sorted(clique)
        blocks.append({
            'snps': np.array(members),             # plain integer indices
            'snp_ids': snp_ids[members].tolist(),  # identifiers looked up by index
            'size': len(members)
        })

    return blocks

In [None]:
def resolve_block_overlap(blocks, R):
    """
    Greedily pick a set of pairwise disjoint blocks, favouring cohesive ones.

    Each candidate with at least two SNPs is scored by the mean of the
    off-diagonal entries of its LD sub-matrix ('mean_r'). Candidates are then
    visited from highest to lowest score and kept only if they share no SNP
    with a block that was already kept.

    Args:
        blocks: candidate block dicts; 'snps' holds integer indices into R.
        R: LD matrix (p, p).
    Returns:
        List of copies of the kept blocks, each with an extra 'mean_r' key,
        ordered by decreasing 'mean_r'.
    """
    if not blocks:
        return []

    def _mean_pairwise_r(members):
        # Mean over the strict upper triangle of the block's LD sub-matrix.
        sub = R[np.ix_(members, members)]
        pairs = sub[np.triu_indices_from(sub, k=1)]
        return pairs.mean() if len(pairs) > 0 else 0.0

    scored = []
    for candidate in blocks:
        members = candidate['snps']
        if len(members) < 2:
            continue
        annotated = candidate.copy()
        annotated['mean_r'] = _mean_pairwise_r(members)
        scored.append(annotated)
    if not scored:
        return []

    # Stable descending sort: ties keep their input order.
    scored.sort(key=lambda item: item['mean_r'], reverse=True)

    kept = []
    claimed = set()
    for candidate in scored:
        members = set(candidate['snps'])
        if claimed.isdisjoint(members):
            kept.append(candidate)
            claimed.update(members)

    return kept

In [None]:
def evaluate_and_prune_block(
    block: dict,
    R: np.ndarray,
    pve_min: float = 0.7,
    min_size: int = 2
):
    """
    Iteratively prune one block until PC1 of its LD sub-matrix explains at
    least ``pve_min`` of the variance (largest eigenvalue / block size).

    Z scores are not used for pruning. On every round the SNP with the
    smallest absolute PC1 loading is removed. If the eigen-decomposition
    itself fails, the SNP whose removal gives the highest PVE is dropped
    instead.

    Args:
        block: dict with 'snps' (integer indices into R) and 'snp_ids'.
        R: LD matrix (p, p).
        pve_min: required proportion of variance explained by PC1.
        min_size: smallest block size that may be returned.
    Returns:
        dict with 'snps', 'snp_ids', 'loadings', 'pve', 'size', 'mean_r' on
        success, or None when the block shrinks below ``min_size`` (or cannot
        be repaired) before reaching ``pve_min``.
    """
    curr = np.asarray(block['snps']).tolist()
    idx_to_id = dict(zip(curr, block['snp_ids']))  # index -> SNP id lookup

    while len(curr) >= min_size:
        n = len(curr)
        R_sub = R[np.ix_(curr, curr)]

        try:
            # A single decomposition yields both the PVE and the PC1 loadings.
            eigvals, eigvecs = np.linalg.eigh(R_sub)
        except np.linalg.LinAlgError:
            # Repair path: drop the SNP whose removal maximises the PVE.
            best_score = -1
            best_idx = None
            for i in range(n):
                trial = curr[:i] + curr[i + 1:]
                if len(trial) < min_size:
                    continue
                try:
                    trial_vals = np.linalg.eigvalsh(R[np.ix_(trial, trial)])
                except np.linalg.LinAlgError:
                    continue
                pve_candidate = trial_vals[-1] / len(trial)
                if pve_candidate > best_score:
                    best_score = pve_candidate
                    best_idx = i
            if best_idx is None:
                return None  # nothing can be repaired
            curr.pop(best_idx)
            continue

        pc1 = eigvecs[:, -1]  # eigh sorts ascending, so the last column is PC1
        pve = eigvals[-1] / n

        if pve >= pve_min:
            w = pc1 / np.linalg.norm(pc1)  # L2 normalisation
            return {
                'snps': np.array(curr),
                'snp_ids': [idx_to_id[i] for i in curr],
                'loadings': w,           # unit-norm PC1 loadings (not the PVE)
                'pve': float(pve),
                'size': n,
                'mean_r': R_sub[np.triu_indices(n, k=1)].mean() if n > 1 else 0.0
            }

        # PVE too low: remove the SNP contributing least to PC1 and retry.
        curr.pop(int(np.argmin(np.abs(pc1))))

    return None

In [None]:
def enrich_block_with_pca1_info(
    block: dict,
    R_full: np.ndarray,
    z_gwas_full: np.ndarray,
    z_qtl_full: np.ndarray
) -> dict:
    """
    Attach the block-level statistics used by the later selection steps.

    The unit-norm PC1 loadings stored in ``block['loadings']`` are rescaled so
    that w^T R_block w = 1, flipped if needed so that the block GWAS Z is
    non-negative, and then used to project the Z vectors and the LD rows.

    Args:
        block: dict holding at least
               - 'snps': integer SNP indices into R_full
               - 'loadings': PC1 loadings, one per SNP in 'snps'
        R_full: (p, p) LD matrix.
        z_gwas_full: (p,) GWAS marginal Z scores.
        z_qtl_full: (p,) QTL marginal Z scores.
    Returns:
        A shallow copy of ``block`` extended with
            - 'loading_weights': loadings w with w^T R_block w = 1
            - 'z_gwas_block': w . z_gwas over the block members
            - 'z_qtl_block': w . z_qtl over the block members
            - 'r_to_others': (p,) vector w @ R_full[members, :]
    Raises:
        ValueError: empty block, non-square R_full, length mismatches, or a
            PC1 variance below 1e-10.
    """
    member_source = block['snps']
    if len(member_source) == 0:
        raise ValueError("Block has no SNPs.")
    n_snps = R_full.shape[0]
    if n_snps != R_full.shape[1]:
        raise ValueError("R_full must be square.")
    if not (len(z_gwas_full) == n_snps and len(z_qtl_full) == n_snps):
        raise ValueError("z_gwas_full and z_qtl_full must have length p.")

    # Block-internal LD and the unit-norm loadings from the pruning step.
    members = np.array(member_source)
    ld_block = R_full[np.ix_(members, members)]
    unit_loadings = np.array(block['loadings'])
    if len(unit_loadings) != len(members):
        raise ValueError("Length of 'loadings' does not match number of SNPs in block.")

    # Rescale so that the PC1 score has unit variance under ld_block.
    pc1_variance = unit_loadings @ ld_block @ unit_loadings
    if pc1_variance < 1e-10:
        raise ValueError(
            f"PC1 variance ({pc1_variance:.2e}) too small. "
            "Likely due to near-collinear SNPs or numerical instability."
        )
    weights = unit_loadings / np.sqrt(pc1_variance)

    var_pc1_normalized = weights @ ld_block @ weights
    assert np.isclose(var_pc1_normalized, 1.0, atol=1e-5), \
        f"Normalization failed: Var(PC1) = {var_pc1_normalized:.3f} ≠ 1.0"

    # Orient the component so that it points along the GWAS signal.
    gwas_members = z_gwas_full[members]
    if weights @ gwas_members < 0:
        weights = -weights

    # r_block,j = sum_k w_k * r_k,j for every SNP j.
    ld_to_all = weights @ R_full[members, :]

    return {
        **block,
        'loading_weights': weights,
        'z_gwas_block': float(weights @ gwas_members),
        'z_qtl_block': float(weights @ z_qtl_full[members]),
        'r_to_others': ld_to_all,
    }

In [None]:
def build_enriched_blocks_pipeline(
    R: np.ndarray,
    z_gwas: np.ndarray,
    z_qtl: np.ndarray,
    snp_ids: np.ndarray,
    r_min: float = 0.8,
    pve_min: float = 0.7,
    min_size: int = 2
) -> dict:
    """
    Build LD blocks end to end: candidates -> de-overlap -> prune -> enrich.

    Args:
        R: (p, p) LD matrix.
        z_gwas: (p,) GWAS Z scores.
        z_qtl: (p,) QTL Z scores.
        snp_ids: p SNP identifiers.
        r_min: LD threshold passed to find_maximal_clique_blocks.
        pve_min: PC1 variance threshold passed to evaluate_and_prune_block.
        min_size: minimum block size passed to evaluate_and_prune_block.
    Returns:
        dict: {
            'blocks': list of enriched block dicts,
            'remaining_snp_idx': integer array of SNPs not covered by a block,
            'block_block_r_matrix': (B, B) block-to-block correlations,
            'R_extended': (p+B, p+B) LD matrix over SNPs followed by blocks,
            'block_positions_in_extended': list of B row/column positions of
                the blocks inside R_extended,
        }
    """
    p = R.shape[0]

    assert len(z_gwas) == p and len(z_qtl) == p and len(snp_ids) == p, "输入维度不匹配"

    def _result_without_blocks(r_extended):
        # Shared shape of every "no block survived" return value.
        return {
            'blocks': [],
            'remaining_snp_idx': np.arange(p),  # integer dtype even when p == 0
            'block_block_r_matrix': np.zeros((0, 0)),
            'R_extended': r_extended,
            'block_positions_in_extended': []
        }

    if p == 0:
        return _result_without_blocks(np.zeros((0, 0)))

    # Step 1: raw candidate blocks.
    candidate_blocks = find_maximal_clique_blocks(R, snp_ids, r_min=r_min)
    if not candidate_blocks:
        return _result_without_blocks(R.copy())

    # Step 2: enforce strictly non-overlapping blocks.
    valid_candidates = resolve_block_overlap(candidate_blocks, R)

    # Step 3: prune each block until PC1 explains enough variance.
    pruned_blocks = []
    for blk in valid_candidates:
        pruned = evaluate_and_prune_block(
            block=blk,
            R=R,
            pve_min=pve_min,
            min_size=min_size
        )
        if pruned is not None:
            pruned_blocks.append(pruned)

    if not pruned_blocks:
        return _result_without_blocks(R.copy())

    # Step 4: add block Z scores, scaled loadings and LD to every SNP.
    enriched_blocks = [
        enrich_block_with_pca1_info(blk, R, z_gwas, z_qtl)
        for blk in pruned_blocks
    ]

    # Step 5: SNPs not covered by any surviving block. The mask keeps the
    # result sorted and integer-typed even when every SNP is covered.
    covered = np.zeros(p, dtype=bool)
    for blk in enriched_blocks:
        covered[blk['snps']] = True
    remaining_snp_idx = np.flatnonzero(~covered)

    # Step 6: block-block correlations r_ij = w_i^T R[idx_i, idx_j] w_j.
    n_blocks = len(enriched_blocks)
    block_block_r_matrix = np.eye(n_blocks)
    for i in range(n_blocks):
        w_i = enriched_blocks[i]['loading_weights']
        idx_i = enriched_blocks[i]['snps']
        for j in range(i + 1, n_blocks):
            w_j = enriched_blocks[j]['loading_weights']
            idx_j = enriched_blocks[j]['snps']
            r_ij = w_i @ R[np.ix_(idx_i, idx_j)] @ w_j
            block_block_r_matrix[i, j] = r_ij
            block_block_r_matrix[j, i] = r_ij

    # Step 7: assemble the extended (p + B) x (p + B) matrix.
    R_extended = np.zeros((p + n_blocks, p + n_blocks))
    R_extended[:p, :p] = R  # SNP-SNP part

    block_positions = []
    for b_idx, block in enumerate(enriched_blocks):
        pos_in_extended = p + b_idx  # block b_idx sits right after the SNPs
        block_positions.append(pos_in_extended)
        r_to_others = block['r_to_others']  # (p,) block-to-SNP LD
        R_extended[:p, pos_in_extended] = r_to_others
        R_extended[pos_in_extended, :p] = r_to_others  # keep symmetry

    R_extended[p:, p:] = block_block_r_matrix  # block-block part

    return {
        'blocks': enriched_blocks,
        'remaining_snp_idx': remaining_snp_idx,
        'block_block_r_matrix': block_block_r_matrix,
        'R_extended': R_extended,
        'block_positions_in_extended': block_positions,
    }

In [None]:
def estimate_sigma_ire(z_cond, tol=1e-2, max_iter=10, return_diagnostics=False):
    """
    Estimate the background variance of a Z vector by an iteratively
    re-weighted median of z**2.

    Each round down-weights large |z| with a Gaussian kernel exp(-z^2 / 2s^2),
    takes a weighted median of z**2 and divides it by 0.454936448 (the median
    of a chi-square variable with one degree of freedom).

    Args:
        z_cond: array-like of Z scores (flattened internally).
        tol: stop when two successive estimates differ by less than this.
        max_iter: maximum number of re-weighting rounds.
        return_diagnostics: if True return ``(sigma2, diagnostics)`` instead
            of ``sigma2`` alone.
    Returns:
        sigma2 clipped to [0.8, 5.0] (1.0 for empty input). With
        ``return_diagnostics=True`` a tuple ``(sigma2, dict)`` whose dict holds
        'sigma2_initial', 'n_iter' and 'converged' (empty dict for empty input).
    """
    chi2_1_median = 0.454936448
    lower, upper = 0.8, 5.0

    z = np.asarray(z_cond).flatten()
    if len(z) == 0:
        return (1.0, {}) if return_diagnostics else 1.0

    # The sort order of z**2 never changes, so compute it once.
    z2 = z ** 2
    z2_sorted = z2[np.argsort(z2)]

    sigma2_initial = np.median(z2) / chi2_1_median
    sigma2 = sigma2_initial
    n_iter = 0
    converged = False

    for n_iter in range(1, max_iter + 1):
        # Gaussian kernel weights: the larger |z|, the smaller the weight.
        w_sorted = np.exp(-z2_sorted / (2 * sigma2 + 1e-8))
        cumw = np.cumsum(w_sorted)
        target = 0.5 * cumw[-1]

        # Last position whose cumulative weight does not exceed half the
        # total. Clamp at 0: when the first weight alone exceeds the target
        # the raw index is -1, which would wrap around to the LARGEST z**2.
        pos = max(int(np.searchsorted(cumw, target, side='right')) - 1, 0)
        sigma2_new = np.clip(z2_sorted[pos] / chi2_1_median, lower, upper)

        if abs(sigma2_new - sigma2) < tol:
            sigma2 = sigma2_new
            converged = True
            break
        sigma2 = sigma2_new

    # No-op after at least one round; enforces the bounds when max_iter <= 0.
    sigma2 = float(np.clip(sigma2, lower, upper))

    if return_diagnostics:
        return sigma2, {
            'sigma2_initial': float(sigma2_initial),
            'n_iter': n_iter,
            'converged': converged,
        }
    return sigma2

In [None]:
@limit_warnings()
def compute_conditional_z(
    z: np.ndarray,
    R: np.ndarray,
    selected_indices: list,
    U_trunc: np.ndarray,      # eigenvectors kept by the global spectral truncation
    Lambda_trunc: np.ndarray  # matching eigenvalues
):
    """
    Conditional Z scores given the selected variables (COJO-style):
    ``z_cond = z - R[:, S] @ beta``.

    ``beta`` is computed as ``U_S diag(1/Lambda) U_S^T z_S`` where ``U_S`` are
    the rows of ``U_trunc`` for the selected indices. If that fails, the
    function falls back to ``np.linalg.solve(R[S, S], z_S)`` and finally to the
    pseudo-inverse of ``R[S, S]``.

    NOTE(review): ``U_S diag(1/Lambda) U_S^T`` is the (S, S) sub-block of the
    truncated global pseudo-inverse, which in general differs from the inverse
    of ``R[S, S]`` used by textbook COJO. Kept as-is; confirm it is intended.

    Args:
        z: (p,) Z scores.
        R: (p, p) LD matrix.
        selected_indices: indices S of the already selected variables.
        U_trunc: (p, k) truncated eigenvectors of R.
        Lambda_trunc: (k,) truncated eigenvalues of R.
    Returns:
        (p,) conditional Z vector; a copy of ``z`` when S is empty.
    """
    if len(selected_indices) == 0:
        return z.copy()

    S = selected_indices
    R_full_sub = R[:, S]      # (p, |S|) LD between every SNP and the selected ones
    z_selected = z[S]         # (|S|,) Z scores of the selected SNPs

    try:
        U_global_S = U_trunc[S, :]  # (|S|, k) selected SNPs in eigen-space
        R_sub_inv = U_global_S @ np.diag(1.0 / Lambda_trunc) @ U_global_S.T
        beta = R_sub_inv @ z_selected

    except Exception as e:
        print(f"compute_conditional_z 谱截断求逆失败: {e}")
        R_sub = R[np.ix_(S, S)]   # (|S|, |S|) only needed on the fallback path
        try:
            beta = np.linalg.solve(R_sub, z_selected)
        except np.linalg.LinAlgError:
            # Singular sub-matrix: use the Moore-Penrose pseudo-inverse.
            beta = np.linalg.pinv(R_sub) @ z_selected

    # COJO core step: remove what the selected variables already explain.
    return z - R_full_sub @ beta

In [None]:
def apply_spectral_truncation(R: np.ndarray, threshold: float = None):
    """
    Eigen-decompose the LD matrix and keep the components whose eigenvalue is
    strictly above ``threshold``.

    Args:
        R: symmetric LD matrix.
        threshold: eigenvalue cut-off; when None it defaults to
            ``max(0.2, 1e-6 * largest eigenvalue)``.
    Returns:
        (U_trunc, Lambda_trunc): kept eigenvectors (as columns) and their
        eigenvalues, in decreasing eigenvalue order. At least the leading
        component is always kept. Diagnostics are printed to stdout.
    """
    raw_values, raw_vectors = np.linalg.eigh(R)

    # Reorder from largest to smallest eigenvalue.
    descending = np.argsort(raw_values)[::-1]
    spectrum = raw_values[descending]
    basis = raw_vectors[:, descending]

    cutoff = threshold if threshold is not None else max(0.2, 1e-6 * spectrum[0])

    retained = spectrum > cutoff
    if not retained.any():
        retained[0] = True  # never return an empty basis

    U_trunc, Lambda_trunc = basis[:, retained], spectrum[retained]

    print(f"   apply_spectral_truncation 诊断:")
    print(f"     - 原始特征值范围: [{spectrum.min():.6f}, {spectrum.max():.6f}]")
    print(f"     - 截断阈值: {cutoff:.6f}")
    print(f"     - 保留特征值数: {len(Lambda_trunc)}/{len(spectrum)}")
    print(f"     - Lambda_trunc 范围: [{Lambda_trunc.min():.6f}, {Lambda_trunc.max():.6f}]")
    return U_trunc, Lambda_trunc

In [None]:
@limit_warnings(max_count=30)
def forward_selection(
    z_raw: np.ndarray,
    R: np.ndarray,
    snp_ids: list,
    U_trunc: np.ndarray,
    sigma2: float,
    Lambda_trunc: np.ndarray,
    target_completion: float = 0.9,
    max_candidates: int = 20,
    z_stop: float = 1.645):
    """
    Greedy forward selection driven by conditional Z scores and an
    energy-based completion ratio.

    With ``P = U diag(1/Lambda) U^T`` and ``z = z_raw / sqrt(sigma2)`` the
    total energy is ``z^T P z``. At each step the ``max_candidates`` remaining
    variables with the largest conditional |z| are tried, and the one giving
    the highest completion ``1 - E_residual / E_total`` is added.

    Args:
        z_raw: (p,) Z scores.
        R: (p, p) LD matrix.
        snp_ids: p identifiers; ``SNP_<i>`` names are generated when None.
        U_trunc: (p, k) truncated eigenvectors of R.
        sigma2: background variance used to rescale ``z_raw``.
        Lambda_trunc: (k,) truncated eigenvalues of R.
        target_completion: stop once this completion ratio is reached.
        max_candidates: number of top-ranked remaining variables tried per step.
        z_stop: stop when the largest remaining conditional |z| is below this.
    Returns:
        dict with 'selected_indices', 'selected_snp_ids', 'n_selected' and
        'final_completion'.
    """
    p = len(z_raw)
    print(f"🔍 forward_selection 诊断:")
    print(f"   - 输入维度: p = {p}")
    print(f"   - U_trunc 形状: {U_trunc.shape}")
    print(f"   - Lambda_trunc 形状: {Lambda_trunc.shape}")
    print(f"   - Lambda_trunc 范围: [{Lambda_trunc.min():.6f}, {Lambda_trunc.max():.6f}]")

    if snp_ids is None:
        snp_ids = [f"SNP_{i}" for i in range(p)]

    # === 1. Background variance correction ===
    z = z_raw / np.sqrt(sigma2)
    print(f"   - 背景方差估计: {sigma2:.4f}")
    print(f"   - 校正后 Z 范围: [{z.min():.4f}, {z.max():.4f}]")
    print(f"   - 校正后 Z 均值: {z.mean():.6f}")

    # === 2. Global energy operator ===
    P = U_trunc @ np.diag(1.0 / Lambda_trunc) @ U_trunc.T
    print(f"   - 能量算子 P 形状: {P.shape}")

    # === 3. Total signal energy ===
    E_total = z @ P @ z
    print(f"   - 总信号能量 E_total: {E_total:.6f}")

    def _completion(z_conditional):
        # Share of the total energy no longer present in the residual.
        E_residual = z_conditional @ P @ z_conditional
        return (E_total - E_residual) / E_total if E_total > 1e-8 else 1.0

    # === 4. State ===
    selected_indices = []
    remaining_mask = np.ones(p, dtype=bool)
    step = 0
    completion = 0.0
    print(f"🚀 开始前向选择迭代...")

    # === 5. Main loop ===
    while True:
        step += 1
        if step > p + 10:  # hard guard against an endless loop
            print(f"⚠️  迭代次数过多 ({step})，强制退出")
            break

        # --- 5.1 conditional Z for the current selection ---
        try:
            z_cond = compute_conditional_z(z, R, selected_indices, U_trunc, Lambda_trunc)
            print(f"   - z_cond 计算成功，范围: [{z_cond.min():.4f}, {z_cond.max():.4f}]")
        except Exception as e:
            print(f"❌ z_cond 计算失败: {e}")
            break
        try:
            completion = _completion(z_cond)
        except Exception as e:
            print(f"❌ 能量计算失败: {e}")
            break

        # --- 5.2 stopping rules ---
        if completion >= target_completion:
            print(f"✅ 达到目标完成度 {target_completion}，停止")
            break
        # Checked before the |z| rule: np.max on an empty selection raises.
        if len(selected_indices) == p:
            print(f"⏹️  所有 SNP 已选择，停止")
            break
        if np.max(np.abs(z_cond[remaining_mask])) < z_stop:
            print(f"⏹️  最大条件 |z| < {z_stop}，停止")
            break

        # --- 5.3 evaluate the top-ranked remaining candidates ---
        print(f"   - 寻找候选 SNP...")
        best_idx = None
        best_completion = completion

        candidate_count = 0
        valid_candidates = 0
        remaining_indices = np.where(remaining_mask)[0]
        sorted_remaining = remaining_indices[np.argsort(np.abs(z_cond[remaining_indices]))[::-1]]
        # Iterate over the ranked SNP indices themselves; np.where() on this
        # array would return positions 0..n-1 rather than SNP indices.
        for idx in sorted_remaining[:max_candidates].tolist():
            candidate_count += 1
            temp_selected = selected_indices + [idx]
            try:
                z_cond_temp = compute_conditional_z(z, R, temp_selected, U_trunc, Lambda_trunc)
                completion_temp = _completion(z_cond_temp)

                if completion_temp > best_completion:
                    best_completion = completion_temp
                    best_idx = idx
                    valid_candidates += 1

            except Exception as e:
                print(f"     候选 SNP {idx} 失败: {e}")
                continue
        print(f"   - 候选 SNP 检查: {candidate_count} 个, 本次迭代有效: {valid_candidates} 个")
        if best_idx is None:
            print("⚠️  没有找到能提升 completion 的 SNP")
            break

        selected_snp_id = snp_ids[best_idx]
        print(f"➡️  选择 SNP: {selected_snp_id} (index {best_idx})")
        selected_indices.append(best_idx)
        remaining_mask[best_idx] = False

    print(f"\n🏁 前向选择完成:")
    print(f"   - 总步数: {step}")
    print(f"   - 最终选中 SNP 数: {len(selected_indices)}")
    print(f"   - 最终 completion: {completion:.6f}")

    # === 6. Result ===
    return {
        'selected_indices': selected_indices,
        'selected_snp_ids': [snp_ids[i] for i in selected_indices],
        'n_selected': len(selected_indices),
        'final_completion': float(completion)
    }

In [None]:
def bootstrap_selection_paths(
    blocks: list,           # enriched block dicts (z_gwas_block, z_qtl_block, snp_ids, ...)
    Z: np.ndarray,          # (p,) raw Z vector
    R_extended: np.ndarray, # (p + B, p + B) extended LD matrix
    snp_list: list,         # p SNP identifiers
    analysis_type: str,     # "gwas" or "qtl"
    remaining_snp_idx: np.ndarray,  # indices of SNPs not covered by a block
    n_bootstraps: int = 100,
    z_perturb_sd: float = 0.01,
    target_completion: float = 0.9,
    stable_threshold: float = 0.8,
    seed=None
):
    '''
    Repeat forward selection on noise-perturbed Z scores and report how often
    each free SNP / block is selected.

    The analysis runs in the "clean" space made of the free SNPs followed by
    one entry per block. Each replicate adds N(0, z_perturb_sd) noise to the
    clean Z vector before calling forward_selection.

    Parameters:
    -----------
    target_completion : completion target forwarded to forward_selection.
    stable_threshold  : an id is "stable" when its selection frequency is
                        strictly greater than this value.
    seed              : when given, noise is drawn from
                        ``np.random.default_rng(seed)``; when None the global
                        ``np.random`` state is used.

    Returns:
    --------
    dict with the fields
        - 'R_clean': LD matrix of the clean space
        - 'Z_clean': Z vector of the clean space
        - 'Z_extend': extended Z vector (SNPs followed by blocks)
        - 'sigma2': background variance estimate
        - 'U_trunc', 'Lambda_trunc': spectral truncation of R_clean
        - 'all_selected_paths': selected ids of every successful replicate
        - 'selection_counter': Counter of selections per id
        - 'selection_frequency': selections / n_bootstraps per id
        - 'stable_snp_id': ids with frequency > stable_threshold
        - 'snp_list_clean': ids of the clean space
        - 'clean_indices': positions of the clean space inside R_extended
        - 'n_free_snps': number of free SNPs
        - 'n_blocks': number of blocks
        - 'avg_completion': mean final completion (failed replicates count as 0)

    Raises:
    -------
    ValueError if analysis_type is neither "gwas" nor "qtl".
    '''

    # === Step 1: basic validation ===
    p = len(Z)
    B = len(blocks)
    assert len(snp_list) == p
    assert R_extended.shape == (p + B, p + B)

    # Validate up front so a bad value is reported even when there are no blocks.
    z_key = {'gwas': 'z_gwas_block', 'qtl': 'z_qtl_block'}.get(analysis_type.lower())
    if z_key is None:
        raise ValueError(f"Unknown analysis_type: {analysis_type}")

    # === Step 2: extended Z vector (SNPs, then one entry per block) ===
    Z_extended = np.zeros(p + B)
    Z_extended[:p] = Z
    for i, block in enumerate(blocks):
        Z_extended[p + i] = block[z_key]
    # A block is named after the first SNP id it contains.
    block_ids = [f"block|{block['snp_ids'][0]}" for block in blocks]

    # === Step 3: clean space = free SNPs + blocks ===
    # Coerce to int: an empty np.array([]) is float64 by default.
    free_idx = np.asarray(remaining_snp_idx, dtype=int).tolist()
    M = len(free_idx)
    N_clean = M + B
    clean_indices = free_idx + [p + i for i in range(B)]
    assert len(clean_indices) == N_clean

    R_clean = R_extended[np.ix_(clean_indices, clean_indices)]
    Z_clean = Z_extended[clean_indices]
    snp_list_clean = [snp_list[idx] for idx in free_idx] + block_ids

    # === Step 4: preprocessing in the clean space ===
    estimate_sigma = estimate_sigma_ire(Z_clean)
    U_trunc, Lambda_trunc = apply_spectral_truncation(R_clean)

    # === Step 5: bootstrap ===
    print(f"🚀 开始 Bootstrap 分析:")
    print(f"   - Clean 空间大小: {len(clean_indices)} (自由SNP: {M}, Blocks: {B})")
    print(f"   - 背景方差估计: {estimate_sigma:.4f}")
    print(f"   - Bootstrap 次数: {n_bootstraps}")
    all_selected = []
    selection_counter = Counter()
    successful_bootstraps = 0
    completion_rates = []

    rng = np.random.default_rng(seed) if seed is not None else None

    for boot_idx in range(n_bootstraps):
        if boot_idx % 50 == 0:
            print(f"   - Bootstrap 进度: {boot_idx}/{n_bootstraps}")
        # Perturb the clean Z scores with small Gaussian noise.
        if rng is not None:
            noise = rng.normal(0, z_perturb_sd, size=Z_clean.shape)
        else:
            noise = np.random.normal(0, z_perturb_sd, size=Z_clean.shape)
        Z_clean_perturbed = Z_clean + noise
        try:
            result = forward_selection(
                z_raw=Z_clean_perturbed,
                R=R_clean,
                snp_ids=snp_list_clean,
                U_trunc=U_trunc,
                sigma2=estimate_sigma,
                Lambda_trunc=Lambda_trunc,
                target_completion=target_completion
            )
            selected_ids = result['selected_snp_ids']
            if selected_ids:
                successful_bootstraps += 1
            all_selected.append(selected_ids)
            selection_counter.update(selected_ids)
            completion_rates.append(result['final_completion'])

        except Exception as e:
            print(f"Bootstrap {boot_idx} failed: {e}")
            completion_rates.append(0.0)  # a failed replicate counts as 0
            continue

    print(f"   - 成功的 Bootstrap 次数: {successful_bootstraps}/{n_bootstraps}")

    avg_completion = np.mean(completion_rates) if completion_rates else 0.0
    print(f"   - 平均信号完成度: {avg_completion:.4f}")

    selection_freq = {
        snp_id: count / n_bootstraps
        for snp_id, count in selection_counter.items()
    }
    # Ids that were never selected get an explicit frequency of 0.
    for snp_id in snp_list_clean:
        selection_freq.setdefault(snp_id, 0.0)

    stable_snp_id = [snp_id for snp_id, freq in selection_freq.items() if freq > stable_threshold]
    return {
        'R_clean': R_clean,
        'Z_clean': Z_clean,
        'Z_extend': Z_extended,
        'sigma2': estimate_sigma,
        'U_trunc': U_trunc,
        'Lambda_trunc': Lambda_trunc,
        'all_selected_paths': all_selected,
        'selection_counter': selection_counter,
        'selection_frequency': selection_freq,
        'stable_snp_id': stable_snp_id,
        'snp_list_clean': snp_list_clean,
        'clean_indices': clean_indices,
        'n_free_snps': M,
        'n_blocks': B,
        'avg_completion': float(avg_completion)
    }

In [None]:
@limit_warnings()
def compute_stable_square_beta(
    result: dict,
    frequency_threshold: float = 0.9,
    min_beta_weight: float = 1e-8
):
    """
    Joint effect estimate for the variables selected in at least
    ``frequency_threshold`` of the bootstrap replicates.

    beta is obtained as ``U_S diag(1/Lambda) U_S^T z_S`` from the global
    spectral truncation stored in ``result``; the squared betas are then
    normalised to sum to one.

    Parameters
    ----------
    result : dict
        Output of bootstrap_selection_paths. Keys read here:
        'selection_frequency', 'snp_list_clean', 'R_clean', 'Z_clean',
        'U_trunc', 'Lambda_trunc'. The dict is updated in place.
    frequency_threshold : float, default=0.9
        Minimum selection frequency (inclusive).
    min_beta_weight : float
        If the sum of beta**2 is below this value, uniform weights are used.

    Returns
    -------
    dict
        The same ``result`` object with the added keys 'stable_snps',
        'stable_snp_indices', 'beta_square', 'beta_multivar', 'beta_sorted'
        and 'R_sub_condition_number'.
    """
    # === Step 1: pull the required fields ===
    selection_frequency = result['selection_frequency']
    snp_list_clean = result['snp_list_clean']
    R_clean = result['R_clean']
    z_clean = result['Z_clean']
    eigvecs_global = result['U_trunc']
    eigvals_global = result['Lambda_trunc']

    print(f"📈 compute_stable_square_beta 输入诊断:")
    print(f"   - selection_frequency 中的 SNP 数: {len(selection_frequency)}")
    print(f"   - snp_list_clean 长度: {len(snp_list_clean)}")
    print(f"   - R_clean 形状: {R_clean.shape}")
    print(f"   - 全局主成分数量: {len(eigvals_global)}")

    # === Step 2: variables whose frequency reaches the threshold ===
    stable_snp_ids = []
    for candidate_id, frequency in selection_frequency.items():
        if frequency >= frequency_threshold:
            stable_snp_ids.append(candidate_id)

    if not stable_snp_ids:
        print(f"⚠️ No variable selected with frequency ≥ {frequency_threshold}. Skipping pseudobeta.")
        result.update({
            'stable_snps': [],
            'stable_snp_indices': [],
            'beta_square': {},
            'beta_multivar': {},
            'beta_sorted': [],
            'R_sub_condition_number': 0.0,
        })
        return result

    # === Step 3: id -> position in the clean space ===
    position_of = {name: position for position, name in enumerate(snp_list_clean)}
    stable_indices = [position_of[name] for name in stable_snp_ids]

    # === Step 4: restrict to the stable variables ===
    ld_stable = R_clean[np.ix_(stable_indices, stable_indices)]
    z_stable = z_clean[stable_indices]
    n_stable = len(stable_indices)

    # === Step 5: solve inside the shared truncated eigen-space ===
    eigvecs_stable = eigvecs_global[stable_indices, :]  # (k, n_components)
    inverse_approx = eigvecs_stable @ np.diag(1.0 / eigvals_global) @ eigvecs_stable.T
    beta = np.asarray(inverse_approx @ z_stable).flatten()

    # Condition number of the plain sub-matrix, reported for diagnostics only.
    cond_num = 1.0 if n_stable <= 1 else np.linalg.cond(ld_stable)

    # === Step 6: normalised squared-beta weights ===
    weights = beta ** 2
    total_weight = weights.sum()
    if total_weight < min_beta_weight or total_weight == 0:
        # Degenerate case: fall back to equal weights and tell the caller.
        beta_squared_weight = np.ones_like(weights) / len(weights) if len(weights) > 0 else np.array([])
        compute_stable_square_beta.warn(
            f"Sum of beta^2 too small ({total_weight:.2e}). Using uniform weights."
        )
    else:
        beta_squared_weight = weights / total_weight

    # === Step 7: attach the values back to the ids ===
    beta_squared_dict = dict(zip(stable_snp_ids, map(float, beta_squared_weight)))
    beta_dict = dict(zip(stable_snp_ids, map(float, beta)))
    beta_sorted = sorted(beta_squared_dict.items(), key=lambda pair: pair[1], reverse=True)

    # === Step 8: publish into result ===
    result.update({
        'stable_snps': stable_snp_ids,
        'stable_snp_indices': stable_indices,
        'beta_square': beta_squared_dict,
        'beta_multivar': beta_dict,
        'beta_sorted': beta_sorted,
        'R_sub_condition_number': float(cond_num),
    })

    return result

In [None]:
def plot_causal_discovery(
    result: dict, 
    title: str = "Causal Discovery by SNP Importance", 
    figsize: tuple = (12, 6),
    show_legend: bool = True,
    bar_alpha: float = 0.8,
    bar_colors: dict = None
):
    """
    Bar chart of the squared-beta weight of every stable SNP / block.

    Bars are ordered by decreasing weight. Labels starting with ``block|``
    use the 'block' colour, all others the 'snp' colour. When there is
    nothing to plot, a placeholder figure with a grey message is returned.

    Parameters
    ----------
    result : dict
        Reads 'beta_square' (id -> weight) and 'stable_snps' (list of ids).
    title : str, default="Causal Discovery by SNP Importance"
        Figure title.
    figsize : tuple, default=(12, 6)
        Figure size (width, height).
    show_legend : bool, default=True
        Whether to draw the SNP / Block legend.
    bar_alpha : float, default=0.8
        Bar transparency.
    bar_colors : dict, optional
        Colour map ``{'block': color, 'snp': color}``; defaults to red and
        skyblue.

    Returns
    -------
    tuple
        (fig, ax) matplotlib objects.
    """
    weights_by_id = result.get('beta_square', {})
    stable_ids = result.get('stable_snps', [])

    # Nothing to show: return a placeholder figure.
    if not stable_ids or not weights_by_id:
        fig, ax = plt.subplots(figsize=figsize)
        ax.text(0.5, 0.5, 'No stable SNPs/blocks identified',
                transform=ax.transAxes, fontsize=14, color='gray', alpha=0.7,
                ha='center', va='center')
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        ax.axis('off')
        ax.set_ylabel("")
        ax.set_title(title, pad=20)
        fig.tight_layout()
        return fig, ax

    # Rank the variables by decreasing weight.
    ranked = sorted(weights_by_id.items(), key=lambda pair: pair[1], reverse=True)
    labels = [name for name, _ in ranked]
    heights = [weight for _, weight in ranked]
    positions = np.arange(len(labels))

    if bar_colors is None:
        bar_colors = {'block': 'red', 'snp': 'skyblue'}

    palette = []
    for name in labels:
        kind = 'block' if name.startswith('block|') else 'snp'
        palette.append(bar_colors[kind])

    fig, ax = plt.subplots(figsize=figsize, dpi=300)

    ax.bar(positions, heights, color=palette, alpha=bar_alpha, width=0.6)
    ax.set_xlabel("Stable SNPs (ordered by $\\beta^2$ weight)")
    ax.set_ylabel("$\\beta^2$ Weight", color='black')
    ax.tick_params(axis='y', labelcolor='black')
    ax.set_xticks(positions)
    ax.set_xticklabels(labels, rotation=45, ha='right', fontsize=9)
    ax.set_xlim(-0.6, len(labels) - 0.4)

    if show_legend:
        ax.legend(
            handles=[
                Patch(facecolor=bar_colors['snp'], label='SNP'),
                Patch(facecolor=bar_colors['block'], label='Block'),
            ],
            loc='upper right',
            frameon=True,
            fontsize=9
        )

    ax.set_title(title, pad=20)
    fig.tight_layout()
    return fig, ax

In [None]:
def run_fine_mapping_for_signal(
        gene_df, ld_df, beta_col_gwas, se_col_gwas,
        beta_col_qtl, se_col_qtl):
    """
    Run the complete fine-mapping workflow for one signal.

    The block structure is built once from the LD matrix and both z-score
    vectors, then the bootstrap selection and the stable-SNP summary are run
    twice on that shared structure: once with the GWAS z-scores and once
    with the QTL z-scores.

    Args:
        gene_df: DataFrame indexed by SNP ID containing the four columns
            named by the ``*_col_*`` arguments.
        ld_df: LD DataFrame whose index and columns are SNP IDs.
        beta_col_gwas: name of the GWAS effect-size column in ``gene_df``.
        se_col_gwas: name of the GWAS standard-error column in ``gene_df``.
        beta_col_qtl: name of the QTL effect-size column in ``gene_df``.
        se_col_qtl: name of the QTL standard-error column in ``gene_df``.

    Returns:
        tuple: ``(result_raw_gwas, result_stable_gwas,
        result_raw_qtl, result_stable_qtl, blocks_result)`` -- five items.
        ``blocks_result`` is the dict returned by
        ``build_enriched_blocks_pipeline``.
    """
    # === Data preparation ===
    # Keep only SNPs present in the LD rows, the LD columns and gene_df.
    snps_list = ld_df.index.intersection(ld_df.columns).intersection(gene_df.index)
    snps_list = snps_list.astype(str)

    ld_matrix_raw = ld_df.loc[snps_list, snps_list].values

    def _z_scores(beta_col, se_col):
        # z = beta / se, aligned to snps_list. No guard against se == 0 here.
        return gene_df.loc[snps_list, beta_col].values / gene_df.loc[snps_list, se_col].values

    z_gwas = _z_scores(beta_col_gwas, se_col_gwas)
    z_qtl = _z_scores(beta_col_qtl, se_col_qtl)

    # === Block construction (shared by both traits) ===
    blocks_result = build_enriched_blocks_pipeline(ld_matrix_raw, z_gwas, z_qtl, snps_list)
    r_extended = blocks_result["R_extended"]
    blocks = blocks_result['blocks']
    remaining_snp = blocks_result['remaining_snp_idx']

    def _bootstrap_and_stabilise(z_scores, trait):
        # NOTE(review): 100 and 0.01 are passed positionally; presumably the
        # bootstrap replicate count and a selection threshold -- confirm
        # against the signature of bootstrap_selection_paths.
        raw = bootstrap_selection_paths(
            blocks, z_scores, r_extended, snps_list, trait, remaining_snp, 100, 0.01
        )

        # Diagnostics on the raw bootstrap result.
        print(f"   - Bootstrap 选择路径数: {len(raw.get('all_selected_paths', []))}")
        print(f"   - 稳定 SNP 数: {len(raw.get('stable_snp_id', []))}")
        if raw.get('stable_snp_id'):
            print(f"   - 稳定 SNP ID: {raw['stable_snp_id'][:5]}...")

        # A shallow copy of the raw result dict is handed to the summary step.
        stable = compute_stable_square_beta(raw.copy(), frequency_threshold=0.9)
        return raw, stable

    # === GWAS analysis ===
    print("📊 GWAS 分析诊断:")
    print(f"   - 总 SNP 数: {len(snps_list)}")
    print(f"   - Block 数: {len(blocks)}")
    print(f"   - 剩余自由 SNP: {len(remaining_snp)}")
    result_raw_gwas, result_stable_gwas = _bootstrap_and_stabilise(z_gwas, "gwas")

    # === QTL analysis ===
    print("📊 QTL 分析诊断:")
    print(f"   - 剩余自由 SNP: {len(remaining_snp)}")
    result_raw_qtl, result_stable_qtl = _bootstrap_and_stabilise(z_qtl, "qtl")

    return (result_raw_gwas, result_stable_gwas,
            result_raw_qtl, result_stable_qtl, blocks_result)

In [None]:
# === 新增辅助函数 ===
def is_subset_np(a_arr, b_arr):
    """
    Element-wise allele-set compatibility test for two parallel sequences.

    Each element is upper-cased and split on ',' into a set of alleles.
    A pair is reported as True when one set is a subset of the other.
    A pair is reported as False when either value is missing, or when any
    allele in either set is longer than one character.

    Args:
        a_arr: sequence of allele strings (e.g. "A" or "A,T").
        b_arr: sequence of allele strings, paired positionally with ``a_arr``.

    Returns:
        np.ndarray of dtype bool with one entry per zipped pair. The dtype is
        forced to bool so that an empty input still yields an array usable
        with the ``&`` / ``|`` / ``~`` mask operators.
    """
    def _compatible(a, b):
        if pd.isna(a) or pd.isna(b):
            return False
        alleles_a = set(str(a).upper().split(','))
        alleles_b = set(str(b).upper().split(','))
        # Any multi-character allele disqualifies the pair. A lone '-' has
        # length 1 and is therefore still accepted.
        if any(len(allele) > 1 for allele in alleles_a | alleles_b):
            return False
        return alleles_a <= alleles_b or alleles_b <= alleles_a

    flags = [_compatible(a, b) for a, b in zip(a_arr, b_arr)]
    # Without an explicit dtype, np.array([]) is float64 and breaks `&`.
    return np.array(flags, dtype=bool)

def classify_and_adjust_beta_vectorized(df):
    """
    Align the sign of ``beta_QTL`` with the GWAS allele coding and drop rows
    whose alleles cannot be matched.

    A row is "same direction" when GWAS alt/ref are compatible with QTL
    alt/ref, and "opposite direction" when GWAS alt/ref are compatible with
    QTL ref/alt (compatibility as defined by ``is_subset_np``). Rows matching
    neither are discarded. ``beta_QTL`` is negated for every
    opposite-direction row.

    The input frame is not modified; the adjustment is applied to the
    returned copy only, so calling the function again on the same input does
    not flip the betas back.

    Args:
        df: DataFrame with columns 'REF_GWAS', 'ALF_GWAS', 'REF_QTL',
            'ALT_QTL' and 'beta_QTL'.

    Returns:
        A new DataFrame holding only the classifiable rows, with
        ``beta_QTL`` sign-adjusted.

    Raises:
        ValueError: if any required column is missing.
    """
    required_cols = {'REF_GWAS', 'ALF_GWAS', 'REF_QTL', 'ALT_QTL', 'beta_QTL'}
    missing = required_cols - set(df.columns)
    if missing:
        raise ValueError(f"缺少必需列: {missing}")

    ref_qtl = df['REF_QTL'].values
    alt_qtl = df['ALT_QTL'].values
    ref_gwas = df['REF_GWAS'].values
    alt_gwas = df['ALF_GWAS'].values    # the GWAS alt-allele column really is named ALF_GWAS
    beta_qtl = df['beta_QTL'].values

    def _mask(left, right):
        # Coerce to bool so the mask operators below also work for an empty frame.
        return np.asarray(is_subset_np(left, right), dtype=bool)

    same_direction = _mask(alt_gwas, alt_qtl) & _mask(ref_gwas, ref_qtl)
    opposite_direction = _mask(alt_gwas, ref_qtl) & _mask(ref_gwas, alt_qtl)
    valid_mask = same_direction | opposite_direction

    invalid_count = (~valid_mask).sum()
    print(f"共 {invalid_count} 行被丢弃（无法归类）")

    # Rows satisfying both conditions are flipped as well (the opposite-direction
    # test takes precedence), which keeps the original classification rule.
    adjusted_beta = np.where(opposite_direction, -beta_qtl, beta_qtl)

    result = df[valid_mask].copy()
    result['beta_QTL'] = adjusted_beta[valid_mask]
    return result

def pivot_ld_to_matrix(ld_df):
    """
    Convert a long-format LD table into a square, symmetric DataFrame.

    Args:
        ld_df: DataFrame with columns 'ID_A', 'ID_B' and 'R', one row per
            SNP pair.

    Returns:
        DataFrame indexed and labelled by the sorted union of all SNP IDs.
        The diagonal defaults to 1, pairs absent from the input are 0, and
        each listed pair is written to both (i, j) and (j, i). Exact
        duplicate (ID_A, ID_B) rows keep the first occurrence; if a pair also
        appears in swapped order, the later row wins.
    """
    ld_df = ld_df.drop_duplicates(subset=['ID_A', 'ID_B'])
    all_snps = sorted(set(ld_df['ID_A']) | set(ld_df['ID_B']))
    snp_map = {snp: pos for pos, snp in enumerate(all_snps)}
    matrix = np.eye(len(all_snps))  # diagonal defaults to 1

    # Walk the three columns in row order. This avoids building a Series per
    # row as iterrows() does, while keeping the same assignment order.
    for id_a, id_b, r_value in zip(ld_df['ID_A'], ld_df['ID_B'], ld_df['R']):
        i = snp_map[id_a]
        j = snp_map[id_b]
        matrix[i, j] = matrix[j, i] = r_value

    return pd.DataFrame(matrix, index=all_snps, columns=all_snps)


# === 日志记录函数 ===

def init_log_file(log_path):
    """
    Create the analysis log as a header-only CSV if it does not exist yet.

    Args:
        log_path: path object (must provide ``.exists()``) of the log file.
            An existing file is left untouched.
    """
    if log_path.exists():
        return

    log_columns = [
        'timestamp', 'csv_file', 'parquet_file', 'gene', 'status',
        'common_snp_count', 'message'
    ]
    empty_log = pd.DataFrame(columns=log_columns)
    empty_log.to_csv(log_path, index=False)

def log_analysis(log_path, csv_file, parquet_file, gene, status, common_snp_count, message=""):
    """
    Append a single record to the analysis log CSV.

    The row is written without a header, with fields in the order
    timestamp, csv_file, parquet_file, gene, status, common_snp_count,
    message. The timestamp is the current local time formatted as
    ``YYYY-MM-DD HH:MM:SS``.

    Args:
        log_path: path of the CSV log to append to.
        csv_file: identifier of the CSV input being processed.
        parquet_file: identifier of the matching parquet input.
        gene: gene label for the record.
        status: status label for the record.
        common_snp_count: number of SNPs shared by both inputs.
        message: optional free-text detail.
    """
    stamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    field_names = ('timestamp', 'csv_file', 'parquet_file', 'gene',
                   'status', 'common_snp_count', 'message')
    field_values = (stamp, csv_file, parquet_file, gene,
                    status, common_snp_count, message)
    record = dict(zip(field_names, field_values))

    pd.DataFrame([record]).to_csv(log_path, mode='a', header=False, index=False)


In [None]:
# Batch driver: for every "<prefix>.csv" / "<prefix>_LD_matrix.parquet" pair in
# folder_path, harmonise the QTL betas, build the LD matrix, run fine-mapping,
# plot both traits and pickle the results. Every outcome is appended to the log.
# NOTE(review): machine-specific absolute path -- consider moving it to a config cell.
folder_path = Path(r"D:\desk\study5_COPDxLC_SMR\结果文件2")
log_file_path = folder_path / "log_analysis.csv"

# Create the header-only log file on first run.
init_log_file(log_file_path)

# Index the inputs by prefix (one CSV <-> one parquet).
# The analysis log lives in the same folder and also matches "*.csv", so it is
# excluded here; otherwise it would be processed as an input file and fail on
# the missing 'SNP' column.
csv_files = {
    f.stem: f
    for f in folder_path.glob("*.csv")
    if f.name != log_file_path.name
}
parquet_files = {f.stem.replace('_LD_matrix', ''): f for f in folder_path.glob("*_LD_matrix.parquet")}

print(f"🔍 找到 {len(csv_files)} 个CSV文件，{len(parquet_files)} 个Parquet文件")

for csv_prefix, csv_path in csv_files.items():
    print(f"\n👉 正在处理: {csv_prefix}")

    # Locate the matching parquet file; report its absence explicitly instead
    # of relying on read_parquet(None) to raise.
    pq_path = parquet_files.get(csv_prefix)
    pq_prefix = pq_path.stem if pq_path else ""
    if pq_path is None:
        error_msg = "未找到对应的Parquet文件"
        print(f"❌ {error_msg}")
        log_analysis(log_file_path, csv_prefix, pq_prefix, "", "ERROR", 0, error_msg)
        continue

    # Read the CSV and de-duplicate. Guarded so that one unreadable or
    # malformed file is logged and skipped instead of aborting the whole batch.
    try:
        df_full = pd.read_csv(csv_path)
        original_rows = len(df_full)

        # With a P_GWAS column present, keep the row with the smallest P_GWAS per SNP.
        if 'P_GWAS' in df_full.columns:
            df_full = df_full.loc[df_full.groupby('SNP')['P_GWAS'].idxmin()]
            if len(df_full) < original_rows:
                print(f"🧹 基于P_GWAS去重: {original_rows} → {len(df_full)} 行 (保留P值最小)")

        csv_snp_count = len(df_full['SNP'].unique())
    except Exception as e:
        error_msg = f"读取CSV失败: {e}"
        print(f"❌ {error_msg}")
        log_analysis(log_file_path, csv_prefix, pq_prefix, "", "ERROR", 0, error_msg)
        continue

    # Read the long-format LD table and compute how many SNPs it covers.
    try:
        ld_long = pd.read_parquet(pq_path, engine='fastparquet')
        pq_snp_set = set(ld_long['ID_A']) | set(ld_long['ID_B'])
        pq_snp_count = len(pq_snp_set)

        coverage_ratio = (pq_snp_count / csv_snp_count) * 100 if csv_snp_count > 0 else 0
    except Exception as e:
        print(f"❌ 读取Parquet失败: {e}")
        log_analysis(log_file_path, csv_prefix, pq_prefix, "", "ERROR", 0, f"读取Parquet失败: {e}")
        continue

    # Statistics suffix shared by every log message below.
    stats_msg = f"CSV SNP数: {csv_snp_count}, Parquet SNP数: {pq_snp_count}, 覆盖率: {coverage_ratio:.2f}%"

    output_pkl = folder_path / f"{pq_prefix}.pkl"
    output_gwas_png = folder_path / f"{pq_prefix}_gwas.png"
    output_qtl_png = folder_path / f"{pq_prefix}_qtl.png"
    output_blank = folder_path / f"{pq_prefix}_blank"

    # A "_blank" marker means an earlier run found too few shared SNPs.
    if output_blank.exists():
        print("⏭️  已知问题（SNP不足），跳过")
        log_analysis(log_file_path, csv_prefix, pq_prefix, "", "SKIPPED", 0,
                    f"已知问题：SNP不足 | {stats_msg}")
        continue

    # Skip pairs whose three outputs already exist.
    if output_pkl.exists() and output_gwas_png.exists() and output_qtl_png.exists():
        print("✅ 已完成，跳过")
        log_analysis(log_file_path, csv_prefix, pq_prefix, "", "SKIPPED", 0,
                    f"已完成，跳过 | {stats_msg}")
        continue

    # Align the beta_QTL direction on the full table and drop unclassifiable rows.
    try:
        df_adjusted = classify_and_adjust_beta_vectorized(df_full)
        print(f"📊 数据过滤：原始 {len(df_full)} 行 → 过滤后 {len(df_adjusted)} 行")
    except Exception as e:
        print(f"❌ 调整 beta 失败: {e}")
        log_analysis(log_file_path, csv_prefix, pq_prefix, "", "ERROR", 0,
                    f"调整beta失败: {e} | {stats_msg}")
        continue

    df_adjusted = df_adjusted.set_index('SNP')

    # Pivot the long LD table into a square matrix.
    try:
        ld_df = pivot_ld_to_matrix(ld_long)
    except Exception as e:
        error_msg = f"转换Parquet失败: {e}"
        print(f"❌ {error_msg}")
        log_analysis(log_file_path, csv_prefix, pq_prefix, "", "ERROR", 0,
                    f"{error_msg} | {stats_msg}")
        continue

    # Restrict to SNPs present in both the LD matrix and the adjusted table.
    snps_common = ld_df.index.intersection(ld_df.columns).intersection(df_adjusted.index)
    common_snp_count = len(snps_common)

    # Too few shared SNPs: log it and leave a marker so later runs skip this pair.
    if common_snp_count < 3:
        print(f"⚠️ 共同SNP数量不足（{common_snp_count} < 3），跳过分析")
        log_analysis(log_file_path, csv_prefix, pq_prefix, "", "SKIPPED", common_snp_count,
                    f"共同SNP数量不足 | {stats_msg}")
        output_blank.touch()
        continue

    ld_df = ld_df.loc[snps_common, snps_common]
    df_sub = df_adjusted.loc[snps_common]

    # Run the fine-mapping pipeline on the cleaned LD matrix and table.
    try:
        (result_raw_gwas, result_stable_gwas,
         result_raw_qtl, result_stable_qtl, blocks_result) = run_fine_mapping_for_signal(
            df_sub, ld_df, 'BETA_GWAS', 'SE_GWAS', 'beta_QTL', 'SE_QTL'
        )
        print(f"✅ 分析完成，共同SNP数: {common_snp_count}")
        log_analysis(log_file_path, csv_prefix, pq_prefix, "", "SUCCESS", common_snp_count,
                    f"分析完成 | {stats_msg}")
    except Exception as e:
        error_msg = f"分析失败: {e}"
        print(f"❌ {error_msg}")
        log_analysis(log_file_path, csv_prefix, pq_prefix, "", "ERROR", common_snp_count,
                    f"{error_msg} | {stats_msg}")
        continue

    # Plot both traits. Figures are closed in `finally` so that a failing
    # savefig does not leave open figures accumulating across iterations.
    fig_gwas = fig_qtl = None
    try:
        fig_gwas, _ = plot_causal_discovery(result_stable_gwas, f"{csv_prefix}_gwas")
        fig_qtl, _ = plot_causal_discovery(result_stable_qtl, f"{csv_prefix}_qtl")
        fig_gwas.savefig(output_gwas_png, dpi=300, bbox_inches='tight')
        fig_qtl.savefig(output_qtl_png, dpi=300, bbox_inches='tight')
        print("✅ 绘图完成")
    except Exception as e:
        error_msg = f"绘图失败: {e}"
        print(f"❌ {error_msg}")
        log_analysis(log_file_path, csv_prefix, pq_prefix, "", "ERROR", common_snp_count,
                    f"{error_msg} | {stats_msg}")
        continue
    finally:
        for open_fig in (fig_gwas, fig_qtl):
            if open_fig is not None:
                plt.close(open_fig)

    # Persist all results for this pair in one pickle.
    try:
        combined = {
            "gwas_bootstrap": result_raw_gwas,
            "gwas_stable": result_stable_gwas,
            "qtl_bootstrap": result_raw_qtl,
            "qtl_stable": result_stable_qtl,
            "block": blocks_result
        }
        with open(output_pkl, 'wb') as f:
            pickle.dump(combined, f)
        print(f"✅ 结果已保存至 {output_pkl}")
    except Exception as e:
        error_msg = f"保存失败: {e}"
        print(f"❌ {error_msg}")
        log_analysis(log_file_path, csv_prefix, pq_prefix, "", "ERROR", common_snp_count,
                    f"{error_msg} | {stats_msg}")
        continue

print("✅ 全部任务完成")