# In [None]:
import numpy as np
import pandas as pd
import os
import scipy.linalg
from cmdstanpy import CmdStanModel

# ------------------------------------------------------------------------------
# Helper function: prior network preprocessing
# ------------------------------------------------------------------------------

def _shrunk_partial_corr(expr_sub, label=""):
    """Estimate a partial-correlation matrix (zero diagonal) from expression data.

    Parameters:
      expr_sub: pd.DataFrame
          Expression matrix laid out as samples x genes.
      label: str
          Optional suffix inserted after "Warning" in diagnostic messages.

    Returns:
      np.ndarray: genes x genes matrix. If the inversion step fails, the
      shrunk correlation matrix (zero diagonal) is returned instead.
    """
    corr_matrix = expr_sub.corr().values
    n_genes = corr_matrix.shape[0]
    n_samples = expr_sub.shape[0]

    # Blend the correlation matrix with the identity so that it is well
    # conditioned before inversion; the blend weight is capped at 0.2.
    shrinkage = min(0.2, 1 / np.sqrt(n_samples))
    shrunk_corr = (1 - shrinkage) * corr_matrix + shrinkage * np.eye(n_genes)

    try:
        # Pseudo-inverse of the symmetric matrix gives the precision matrix.
        precision_mat = scipy.linalg.pinvh(shrunk_corr)

        # Standard precision -> partial correlation conversion:
        # pcor_ij = -omega_ij / sqrt(omega_ii * omega_jj)
        diag_precision = np.sqrt(np.diag(precision_mat))
        partial_corr = -precision_mat / np.outer(diag_precision, diag_precision)
    except Exception:
        print(f"Warning{label}: Matrix inversion failed. Using correlation instead of partial correlation.")
        # Copy so that zeroing the diagonal does not modify shrunk_corr.
        partial_corr = shrunk_corr.copy()

    np.fill_diagonal(partial_corr, 0)
    return partial_corr


def _pairwise_corr(expr, genes):
    """Last-resort fallback: Pearson correlation computed one gene pair at a time.

    Pairs whose correlation cannot be computed are left at 0, as is the
    diagonal. Each pair is evaluated once and mirrored, since Pearson
    correlation is symmetric.
    """
    corr_values = np.zeros((len(genes), len(genes)))
    for i, gene_i in enumerate(genes):
        for j in range(i + 1, len(genes)):
            try:
                correlation = np.corrcoef(expr.loc[gene_i], expr.loc[genes[j]])[0, 1]
            except Exception:
                correlation = 0
            corr_values[i, j] = correlation
            corr_values[j, i] = correlation
    return corr_values


def prior_pp(prior, expr, edge_type=None, enforce_negative=False):
    """
    Filter low-confidence edges in the prior network using partial correlation.
    Similar to the R implementation using GeneNet.

    An edge whose prior sign disagrees with the sign of the partial
    correlation between regulator and target is "blurred": its weight is
    replaced by 1e-6. Edges with zero partial correlation are left untouched.

    Parameters:
      prior: pd.DataFrame
          Prior regulatory network (adjacency matrix), rows are TFs, columns are target genes.
      expr: pd.DataFrame
          Normalized and log-transformed gene expression matrix, rows are genes.
      edge_type: str or None
          Optional label for this network (e.g. 'tf_gene'). It is only used to
          tag warning messages and has no effect on the result.
      enforce_negative: bool
          If True, every non-zero prior edge is set to -abs(weight) before the
          sign-consistency check. Blurred edges still receive the positive
          1e-6 marker.

    Returns:
      pd.DataFrame: Filtered prior network (float64), restricted to regulators
      and targets present in ``expr`` and with all-zero rows/columns removed.
    """
    label = f" [{edge_type}]" if edge_type else ""

    # Filter TFs and target genes that exist in the expression matrix
    tf = prior.index.intersection(expr.index)
    tg = prior.columns.intersection(expr.index)
    all_genes = np.unique(np.concatenate([tf, tg]))

    # Extract expression data for these genes (transposed like in R GeneNet)
    expr_sub = expr.loc[all_genes].T  # Samples x genes

    try:
        partial_corr = _shrunk_partial_corr(expr_sub, label)
    except Exception as e:
        print(f"Warning{label}: Correlation calculation failed: {str(e)}. Using simpler approach.")
        partial_corr = _pairwise_corr(expr, all_genes)

    # Convert partial correlation coefficients to DataFrame
    coexp = pd.DataFrame(partial_corr, index=all_genes, columns=all_genes)

    # Take the part of the prior matrix that corresponds to the partial correlation submatrix
    P_ij = prior.loc[tf, tg].copy().astype(np.float64)
    if enforce_negative:
        # Keep zeros as they are; flip every other weight to negative.
        P_ij = P_ij.where(P_ij == 0, -P_ij.abs())

    # Keep the SIGN of the partial correlation here: taking its absolute value
    # would make every negative prior edge look inconsistent and would never
    # flag a positive prior edge that the data contradict.
    C_ij = coexp.loc[tf, tg] * P_ij.abs()

    # For edges with inconsistent signs, adjust the weight to a very small value (fuzzy)
    inconsistent = (np.sign(P_ij) * np.sign(C_ij)) < 0
    P_ij = P_ij.mask(inconsistent, 1e-6)

    # Remove all-zero TFs and genes
    keep_rows = (P_ij != 0).any(axis=1)
    keep_cols = (P_ij != 0).any(axis=0)
    return P_ij.loc[keep_rows, keep_cols]

# In [None]:
def prepare_network_data(prior_tf_mirna, prior_tf_gene, prior_mirna_gene, 
                        expr_gene, expr_mirna, TFexpressed=True, signed=True):
    """
    Prepare multi-layer network data suitable for the RegInsight model.

    Parameters:
      prior_tf_mirna: pd.DataFrame
          TF -> miRNA prior network (rows are TFs, columns are miRNAs).
      prior_tf_gene: pd.DataFrame
          TF -> gene prior network (rows are TFs, columns are genes).
      prior_mirna_gene: pd.DataFrame
          miRNA -> gene prior network (rows are miRNAs, columns are genes).
      expr_gene: pd.DataFrame
          Gene expression matrix, rows are genes, columns are samples.
      expr_mirna: pd.DataFrame
          miRNA expression matrix, rows are miRNAs, columns are samples.
      TFexpressed: bool
          If True, keep only TFs that appear in ``expr_gene``.
      signed: bool
          If True, each prior is passed through ``prior_pp`` (called with the
          ``edge_type`` / ``enforce_negative`` keywords). If False, the priors
          are only subset, and non-zero miRNA-gene edges are forced negative.

    Returns:
      dict with keys 'TF_names', 'miRNA_names', 'gene_names', 'P_mz', 'P_gz',
      'P_gm' and 'sample_names'.

    Raises:
      ValueError: if the two expression matrices do not have identical sample
          columns, or if no TF / miRNA / gene name can be matched.
    """
    # Samples must match in content and order. Index.equals also copes with
    # differing sample counts, where an element-wise == would raise its own
    # unrelated error instead of the message below.
    if not expr_gene.columns.equals(expr_mirna.columns):
        raise ValueError("基因和miRNA表达矩阵样本必须一致")

    sample_name = expr_gene.columns.tolist()

    # Collect all candidate TFs, miRNAs and genes
    all_TFs = set(prior_tf_mirna.index).union(prior_tf_gene.index)
    if TFexpressed:
        all_TFs = all_TFs.intersection(expr_gene.index)
    all_miRNAs = sorted(set(prior_tf_mirna.columns).intersection(expr_mirna.index))
    all_genes = sorted(set(prior_tf_gene.columns).intersection(expr_gene.index))

    # Make sure something is left to work with
    if not all_TFs or not all_miRNAs or not all_genes:
        raise ValueError("输入数据中基因/miRNA名称不匹配")

    # Subset each prior to the names it actually contains; this step is the
    # same for the signed and unsigned modes.
    tf_mz_common = sorted(all_TFs.intersection(prior_tf_mirna.index))
    tf_gz_common = sorted(all_TFs.intersection(prior_tf_gene.index))
    mirna_common = sorted(set(all_miRNAs).intersection(prior_mirna_gene.index))
    gene_common_for_mirna = sorted(set(all_genes).intersection(prior_mirna_gene.columns))

    tf_mirna_net = prior_tf_mirna.loc[tf_mz_common, all_miRNAs].copy()
    tf_gene_net = prior_tf_gene.loc[tf_gz_common, all_genes].copy()
    mirna_gene_net = prior_mirna_gene.loc[mirna_common, gene_common_for_mirna].copy()

    if signed:
        # Build the combined gene + miRNA expression matrix once and reuse it
        expr_all = pd.concat([expr_gene, expr_mirna])
        P_mz = prior_pp(tf_mirna_net, expr_all, edge_type='tf_mirna')
        P_gz = prior_pp(tf_gene_net, expr_all, edge_type='tf_gene')
        P_gm = prior_pp(mirna_gene_net, expr_all, 
                      edge_type='mirna_gene', enforce_negative=True)
    else:
        P_mz = tf_mirna_net
        P_gz = tf_gene_net
        # Ensure non-zero miRNA-gene edges are negative; zeros stay zero
        P_gm = mirna_gene_net.where(mirna_gene_net == 0, -mirna_gene_net.abs())

    # Names after processing: TFs are the UNION over both TF priors, while
    # miRNAs and genes are the INTERSECTION of the priors that mention them.
    # NOTE(review): P_mz / P_gz / P_gm are returned as-is and are not
    # re-indexed to these name lists; confirm that the consumer aligns them.
    TF_names = sorted(set(P_mz.index).union(P_gz.index))
    miRNA_names = sorted(set(P_mz.columns).intersection(P_gm.index))
    gene_names = sorted(set(P_gz.columns).intersection(P_gm.columns))

    network_data = {
        'TF_names': TF_names,
        'miRNA_names': miRNA_names,
        'gene_names': gene_names,
        'P_mz': P_mz,
        'P_gz': P_gz,
        'P_gm': P_gm,
        'sample_names': sample_name,
    }
    return network_data