In [1]:

import os
import sys
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns
from scipy.cluster.hierarchy import dendrogram, linkage, fcluster
from scipy.spatial.distance import squareform
from scipy.stats import pearsonr, ttest_ind
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler, quantile_transform
from numba import jit, prange
from collections import defaultdict
import warnings

# RNA-seq specific
from pydeseq2.dds import DeseqDataSet
from pydeseq2.ds import DeseqStats

# Microarray specific  
from patsy import dmatrix
from inmoose.limma import lmFit, eBayes, topTable, makeContrasts, contrasts_fit
from statsmodels.stats.multitest import multipletests

warnings.filterwarnings('ignore')

# Increase recursion limit for large dendrograms
sys.setrecursionlimit(15000)

# ============================================================================
# OPTIMIZED FUNCTIONS WITH NUMBA (unchanged)
# ============================================================================

@jit(nopython=True, parallel=True, fastmath=True)
def fast_correlation_numba(X):
    """Pairwise Pearson correlation between the columns (genes) of ``X``.

    Parameters
    ----------
    X : 2-D numeric array, shape (n_samples, n_genes)

    Returns
    -------
    (n_genes, n_genes) float64 array of Pearson r. Any pair involving a
    zero-variance gene is set to 0.0, including that gene's diagonal entry.

    Raises
    ------
    ValueError
        If fewer than two samples are given (the sample variance is undefined).
    """
    n_samples, n_genes = X.shape
    if n_samples < 2:
        raise ValueError("fast_correlation_numba needs at least 2 samples")

    # Gene-major float64 buffer. It is explicitly float64: np.empty_like(X)
    # would inherit an integer dtype and truncate the centred values. It is
    # gene-major so the inner covariance loop below walks contiguous memory.
    centered = np.empty((n_genes, n_samples))
    std_devs = np.empty(n_genes)
    for j in prange(n_genes):
        mean_j = 0.0
        for i in range(n_samples):
            mean_j += X[i, j]
        mean_j /= n_samples

        var_sum = 0.0
        for i in range(n_samples):
            d = X[i, j] - mean_j
            centered[j, i] = d
            var_sum += d * d
        std_devs[j] = np.sqrt(var_sum / (n_samples - 1))

    corr_matrix = np.empty((n_genes, n_genes))
    # Each outer iteration i writes only the pairs (i, j) and (j, i) with j >= i,
    # so parallel iterations never write the same element.
    for i in prange(n_genes):
        for j in range(i, n_genes):
            if std_devs[i] == 0 or std_devs[j] == 0:
                corr = 0.0
            else:
                cov = 0.0
                for k in range(n_samples):
                    cov += centered[i, k] * centered[j, k]
                cov /= (n_samples - 1)
                corr = cov / (std_devs[i] * std_devs[j])
            corr_matrix[i, j] = corr
            corr_matrix[j, i] = corr

    return corr_matrix


@jit(nopython=True, parallel=True, fastmath=True)
def compute_TOM_numba(A):
    """Topological Overlap Matrix of a weighted adjacency matrix.

    With ``a`` = ``A`` with its diagonal zeroed, ``k[i] = sum_u a[i, u]`` and
    ``shared = a @ a`` (weight shared through common neighbours):

        TOM[i, j] = (shared[i, j] + a[i, j]) / (min(k[i], k[j]) + 1 - a[i, j])

    The denominator is floored at 1e-10 and ``TOM[i, i] = 1``. ``A`` itself is
    not modified (a copy is zeroed).

    Assumes ``A`` is a square float array, presumably symmetric (e.g.
    |cor|**beta). Neither property is checked here.
    """
    n_nodes = A.shape[0]
    adj = A.copy()
    # Self-loops must not count towards connectivity or shared neighbours.
    for d in range(n_nodes):
        adj[d, d] = 0.0

    connectivity = np.empty(n_nodes)
    for r in prange(n_nodes):
        connectivity[r] = np.sum(adj[r, :])

    # The u = i and u = j terms of the matrix product vanish because the
    # diagonal is zero.
    shared = np.dot(adj, adj)

    tom = np.empty((n_nodes, n_nodes))
    for r in prange(n_nodes):
        for c in range(n_nodes):
            if r == c:
                tom[r, c] = 1.0
                continue
            a_rc = adj[r, c]
            denom = min(connectivity[r], connectivity[c]) + 1.0 - a_rc
            if denom < 1e-10:
                denom = 1e-10
            tom[r, c] = (shared[r, c] + a_rc) / denom

    return tom


@jit(nopython=True, parallel=True)
def compute_coexp_metrics_fast(tom_matrix, threshold):
    """Per-gene co-expression summaries from a square TOM matrix.

    The diagonal is ignored. For each row g, the function returns:

    * ``n_coexp[g]`` (int32): number of partners with TOM > ``threshold``.
    * ``mean_tom[g]``: mean TOM over those partners (0.0 when there are none).
    * ``max_tom[g]``: largest off-diagonal TOM, floored at 0.0.
    * ``total_tom[g]``: sum of all off-diagonal TOM values.

    Returns the tuple ``(n_coexp, mean_tom, max_tom, total_tom)``.
    """
    n_genes = tom_matrix.shape[0]
    n_coexp = np.zeros(n_genes, dtype=np.int32)
    mean_tom = np.zeros(n_genes)
    max_tom = np.zeros(n_genes)
    total_tom = np.zeros(n_genes)

    for g in prange(n_genes):
        n_strong = 0
        strong_sum = 0.0
        peak = 0.0
        running_total = 0.0

        for h in range(n_genes):
            if h == g:
                continue
            w = tom_matrix[g, h]
            running_total += w
            if w > threshold:
                n_strong += 1
                strong_sum += w
            if w > peak:
                peak = w

        n_coexp[g] = n_strong
        # mean_tom is pre-zeroed, so a row with no strong partners stays 0.0.
        if n_strong > 0:
            mean_tom[g] = strong_sum / n_strong
        max_tom[g] = peak
        total_tom[g] = running_total

    return n_coexp, mean_tom, max_tom, total_tom


class DynamicTreeCut:
    """Simplified, WGCNA-inspired module detection on a hierarchical-clustering tree.

    NOTE: despite the name, this is a *static* cut. The tree is cut once at a
    single height (see ``_get_cut_height``). Clusters smaller than
    ``min_cluster_size`` are dissolved into label 0 ("unassigned"). An optional
    PAM-like stage then attaches unassigned genes to nearby modules.
    ``deep_split`` is stored and printed but does not influence the cut.

    Parameters
    ----------
    min_cluster_size : int
        Clusters with fewer members than this are dissolved into label 0.
    deep_split : int
        Recorded for reporting only (see note above).
    detect_cut_height : float
        Fraction of the way from the 5th percentile of merge heights to the
        highest merge at which the tree is cut.
    pam_stage : bool
        Whether ``cut_tree`` runs the reassignment stage. This requires a
        dissimilarity matrix.
    pam_respects_dendro : bool
        If True, a gene joins its nearest gene's module only when its *mean*
        dissimilarity to that module is also below the PAM limit. This is a
        distance check, not a dendrogram-branch check.
    max_pam_dist : float or None
        Distance limit for reassignment. None means the 95th percentile of the
        positive dissimilarities.
    verbose : bool
        Print progress information.
    """

    def __init__(self, min_cluster_size=20, deep_split=2, 
                 detect_cut_height=0.995, pam_stage=True, 
                 pam_respects_dendro=True, max_pam_dist=None,
                 verbose=True):
        self.min_cluster_size = min_cluster_size
        self.deep_split = deep_split  # reported only; the cut does not use it
        self.detect_cut_height = detect_cut_height
        self.pam_stage = pam_stage
        self.pam_respects_dendro = pam_respects_dendro
        self.max_pam_dist = max_pam_dist
        self.verbose = verbose

    def _get_cut_height(self, linkage_matrix):
        """Return ``(cut_height, reference_height, max_height)`` for the tree.

        ``reference_height`` is the 5th percentile of all merge heights. The cut
        lies ``detect_cut_height`` of the way from it to the highest merge.
        """
        merge_heights = linkage_matrix[:, 2]
        percentile_5 = np.percentile(merge_heights, 5)
        max_height = np.max(merge_heights)
        cut_height = percentile_5 + self.detect_cut_height * (max_height - percentile_5)
        return cut_height, percentile_5, max_height

    def _get_clusters_at_height(self, linkage_matrix, cut_height):
        """Flat cluster labels (1..K) from cutting the tree at ``cut_height``."""
        return fcluster(linkage_matrix, t=cut_height, criterion='distance')

    def _filter_small_clusters(self, labels):
        """Set clusters smaller than ``min_cluster_size`` to 0 and renumber the rest 1..M.

        Surviving labels keep their relative (ascending) order. Label 0 on input
        is always treated as unassigned.
        """
        labels = np.asarray(labels)
        uniq, inverse, counts = np.unique(labels, return_inverse=True,
                                          return_counts=True)
        keep = (counts >= self.min_cluster_size) & (uniq != 0)
        new_ids = np.zeros(len(uniq), dtype=int)
        new_ids[keep] = np.arange(1, int(keep.sum()) + 1)
        return new_ids[inverse]

    def _pam_reassignment(self, dissim_matrix, labels, linkage_matrix=None):
        """Attach unassigned (label 0) genes to the module of their nearest assigned gene.

        A gene is attached only if that nearest distance is below the PAM limit.
        With ``pam_respects_dendro`` set, its mean distance to all members of
        that module must also be below the limit. Module membership is read
        from the input ``labels`` (before any reassignment).

        ``linkage_matrix`` is accepted for interface compatibility and is unused.
        Returns a new label array; ``labels`` is not modified.
        """
        reassigned_labels = labels.copy()

        unassigned_idx = np.flatnonzero(labels == 0)
        assigned_idx = np.flatnonzero(labels != 0)
        if unassigned_idx.size == 0 or assigned_idx.size == 0:
            return reassigned_labels

        if self.max_pam_dist is not None:
            max_pam_dist = self.max_pam_dist
        else:
            positive = dissim_matrix[dissim_matrix > 0]
            # With no positive dissimilarity, every gene is identical to some
            # module member, so no gene is too far to attach. np.percentile
            # would fail on the empty selection.
            max_pam_dist = np.percentile(positive, 95) if positive.size else np.inf

        members_of = {}  # module label -> member indices (computed lazily, cached)
        for gene_idx in unassigned_idx:
            distances = dissim_matrix[gene_idx, assigned_idx]
            nearest = np.argmin(distances)
            # Written as "not <" so a NaN distance leaves the gene unassigned.
            if not distances[nearest] < max_pam_dist:
                continue

            module = labels[assigned_idx[nearest]]
            if self.pam_respects_dendro:
                if module not in members_of:
                    members_of[module] = np.flatnonzero(labels == module)
                mean_dist = np.mean(dissim_matrix[gene_idx, members_of[module]])
                if not mean_dist < max_pam_dist:
                    continue

            reassigned_labels[gene_idx] = module

        if self.verbose:
            n_reassigned = np.sum(reassigned_labels != labels)
            print(f"  PAM stage: Reassigned {n_reassigned} genes")

        return reassigned_labels

    def cut_tree(self, linkage_matrix, dissim_matrix=None):
        """Cut the tree and return one integer module label per leaf (0 = unassigned).

        Parameters
        ----------
        linkage_matrix : ndarray
            SciPy linkage matrix.
        dissim_matrix : ndarray, optional
            Square dissimilarity matrix in leaf order. It is needed for the PAM
            stage; without it that stage is skipped.
        """
        cut_height, ref_height, max_height = self._get_cut_height(linkage_matrix)

        if self.verbose:
            print(f"  Min module size: {self.min_cluster_size}")
            print(f"  deepSplit: {self.deep_split}")
            print(f"  Cut height: {cut_height:.4f} (detectCutHeight: {self.detect_cut_height})")
            print(f"  PAM stage: {self.pam_stage}")

        initial_labels = self._get_clusters_at_height(linkage_matrix, cut_height)
        n_initial = len(np.unique(initial_labels))

        if self.verbose:
            print(f"  Initial clusters: {n_initial}")

        filtered_labels = self._filter_small_clusters(initial_labels)
        n_after_filter = len(np.unique(filtered_labels[filtered_labels != 0]))

        if self.verbose:
            print(f"  Modules after filtering: {n_after_filter}")
            print(f"  Unassigned genes: {np.sum(filtered_labels == 0)}")

        if self.pam_stage and dissim_matrix is not None:
            final_labels = self._pam_reassignment(dissim_matrix, filtered_labels, linkage_matrix)
        else:
            final_labels = filtered_labels

        return final_labels


def merge_close_modules(module_labels, module_eigengenes, merge_cut_height=0.15, 
                       dissim_matrix=None, verbose=True):
    """Merge modules whose eigengenes are highly correlated.

    Eigengene dissimilarity is ``1 - |Pearson r|``. Because of the absolute
    value, eigengenes of opposite sign count as similar. The non-grey
    eigengenes are clustered with average linkage and cut at
    ``merge_cut_height``. Every module in a resulting group is renamed to the
    group's first module, in eigengene column order.

    Parameters
    ----------
    module_labels : array-like
        Module name per gene ('grey' = unassigned). Must support ``.copy()``
        and item assignment (e.g. a NumPy object array).
    module_eigengenes : DataFrame
        One column per module, named ``'ME_<module>'``.
    merge_cut_height : float
        Distance threshold for the flat cut; higher values merge more.
    dissim_matrix : unused
        Accepted for interface compatibility.
    verbose : bool
        Print a short report.

    Returns
    -------
    (merged_labels, merge_map)
        ``merge_map`` maps every non-grey module name to the name it was merged
        into (itself if untouched). With at most one non-grey eigengene, the
        input ``module_labels`` object is returned as-is together with ``{}``.
    """
    if verbose:
        print(f"\n=== Merging Close Modules (cutHeight={merge_cut_height}) ===")

    me_cols = [c for c in module_eigengenes.columns if c != 'ME_grey']
    if len(me_cols) <= 1:
        if verbose:
            print("  No modules to merge")
        return module_labels, {}

    abs_r = module_eigengenes[me_cols].corr().abs()
    tree = linkage(squareform((1 - abs_r).values, checks=False), method='average')
    group_ids = fcluster(tree, t=merge_cut_height, criterion='distance')

    # group id -> member module names (without the 'ME_' prefix), column order kept
    members = {}
    for col, gid in zip(me_cols, group_ids):
        members.setdefault(gid, []).append(col.replace('ME_', ''))

    merge_map_final = {}
    for gid in sorted(members):
        head = members[gid][0]
        merge_map_final.update({name: head for name in members[gid]})

    label_values = np.array(list(module_labels))
    merged_module_labels = module_labels.copy()
    for pos, name in enumerate(label_values):
        if name in merge_map_final:
            merged_module_labels[pos] = merge_map_final[name]

    n_after = len(np.unique([m for m in merged_module_labels if m != 'grey']))

    if verbose:
        print(f"  Modules before merging: {len(me_cols)}")
        print(f"  Modules after merging: {n_after}")

    return merged_module_labels, merge_map_final



class PyWGCNA:

    
    def __init__(self, data_type='rna_seq', output_dir='output_wgcna'):
        
        self.data_type = data_type.lower()
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)
        
        # Analysis parameters
        self.wgcna_input = None
        self.normalized_data = None
        self.deg_results = None
        self.metadata = None
        self.soft_power = None
        self.module_colors = None
        self.module_eigengenes = None
        

        
        
    
    def get_platform_parameters(self):
        
        if self.data_type == 'microarray':
            return {
                # Soft power selection
                'power_range': (1, 20),           # Broader range for microarray
                'r2_threshold': 0.80,             # Standard threshold
                'min_r2_accept': 0.80,                                # Minimum acceptable
                'max_mean_k': 200,                # Avoid over-connection
                
                # Module detection
                'min_module_size': 30,            # Larger to avoid noise modules
                'deep_split': 3,                  # Higher sensitivity (0-4 scale)
                'detect_cut_height': 0.990,       # Lower = more modules
                'pam_stage': True,
                'pam_respects_dendro': False,     # More flexible reassignment
                
                # Module merging
                'merge_cut_height': 0.25,         # Higher = less aggressive merging
                
                # Data filtering
                'variance_percentile': 50,        # Keep top 50% variable genes
                'min_samples': 10                 # Minimum recommended samples
            }
        else:  # rna_seq
            return {
                # Soft power selection
                'power_range': (1, 20),           # Full range
                'r2_threshold': 0.80,             # Stricter for RNA-seq
                'min_r2_accept': 0.80,
                'max_mean_k': 200,
                
                # Module detection
                'min_module_size': 30,            # Standard size
                'deep_split': 2,                  # Moderate sensitivity
                'detect_cut_height': 0.995,       # Higher = fewer, larger modules
                'pam_stage': True,
                'pam_respects_dendro': True,      # Respect hierarchy
                
                # Module merging
                'merge_cut_height': 0.15,         # Standard merging
                
                # Data filtering
                'variance_percentile': 25,        # Keep top 75% (broader for RNA-seq)
                'min_samples': 10
            }
    
    def load_data(self, count_matrix_path, metadata_path):
        """Load expression data and metadata"""
        print("\n=== LOADING DATA ===")
        count = pd.read_csv(count_matrix_path, index_col=0)
        self.metadata = pd.read_csv(metadata_path, index_col=0)
        
        # Align samples
        common_samples = list(set(count.columns) & set(self.metadata.index))
        self.count_matrix = count[common_samples]
        self.metadata = self.metadata.loc[common_samples]
        
        params = self.get_platform_parameters()
        n_samples = len(common_samples)
        
        print(f"  Initial dimensions: {self.count_matrix.shape}")
        print(f"  Aligned samples: {n_samples}")
        
        if n_samples < params['min_samples']:
            print(f"  ⚠ WARNING: Only {n_samples} samples detected!")
            print(f"  Recommended minimum: {params['min_samples']} for robust WGCNA")
            print(f"  Results may be less reliable with small sample sizes")
        
        return self
    
    def preprocess_rnaseq(self, group_column, case_label, control_label, 
                         min_count=10, padj_threshold=0.05, lfc_threshold= 1):
        
        print("\n=== PREPROCESSING RNA-SEQ DATA ===")
        print("  Method: DESeq2 with variance stabilizing transformation (VST)")
        
        # Filter low-count genes
        self.count_matrix = self.count_matrix.loc[:, (self.count_matrix.sum(axis=0) >= min_count)]
        print(f"  After low-count filtering: {self.count_matrix.shape}")
        
        # Run DESeq2
        print("  Running DESeq2...")
        count_t = self.count_matrix.T.astype(int)
        metadata_aligned = self.metadata.copy()
        metadata_aligned.index = count_t.index
        
        dds = DeseqDataSet(counts=count_t, metadata=metadata_aligned, 
                          design_factors=group_column, n_cpus=4)
        dds.deseq2()
        dds.vst()
        
        # Get normalized data (VST)
        self.normalized_data = pd.DataFrame(dds.layers['vst_counts'].T, 
                                           index=self.count_matrix.index, 
                                           columns=self.count_matrix.columns)
        
        print(f"  ✓ VST normalization complete")
        print(f"  Normalized data range: [{self.normalized_data.min().min():.2f}, {self.normalized_data.max().max():.2f}]")
        
        # Get DE results
        stat_res = DeseqStats(dds, contrast=[group_column, case_label, control_label], 
                            alpha=padj_threshold, n_cpus=4)
        stat_res.summary()
        self.deg_results = stat_res.results_df
        
        sig_degs = self.deg_results[
            (self.deg_results['padj'] < padj_threshold) & 
            (abs(self.deg_results['log2FoldChange']) > lfc_threshold)
        ]
        print(f"  DEGs found: {len(sig_degs)} (padj<{padj_threshold}, |log2FC|>{lfc_threshold})")
        
        return self
    
    def preprocess_microarray(self, group_column, case_label, control_label,
                             covariates=None, normalize=True, 
                             padj_threshold=0.05, lfc_threshold=1):
        
        print("  Method: limma with empirical Bayes moderation")
        
        expr_data = self.count_matrix.copy()
        expr_data = expr_data.fillna(0)
        
        # Log2 transform if needed
        max_val = expr_data.max().max()
        if max_val > 100:
            print(f"  Data appears non-log scale (max={max_val:.1f})")
            print("  Applying log2(x+1) transformation...")
            expr_data = expr_data.mask(expr_data <= 0, 1e-6)
            expr_data = np.log2(expr_data + 1)
        else:
            print(f"  Data appears log-scale (max={max_val:.2f})")
        
        # Quantile normalization
        if normalize:
            print("  Applying quantile normalization...")
            expr_array = quantile_transform(expr_data.values, axis=0, 
                                          copy=True, output_distribution='normal')
            expr_data = pd.DataFrame(expr_array, index=expr_data.index, 
                                    columns=expr_data.columns)
            print(f"  ✓ Normalization complete")
        
        self.normalized_data = expr_data
        print(f"  Normalized data range: [{expr_data.min().min():.2f}, {expr_data.max().max():.2f}]")
        
        # Create design matrix
        print("  Creating design matrix...")
        self.metadata['Label'] = self.metadata[group_column].apply(
            lambda x: 'Disease' if x == case_label else 'Control'
        )
        
        if covariates:
            formula = f"~ 0 + Label + {' + '.join(covariates)}"
            print(f"  Including covariates: {', '.join(covariates)}")
        else:
            formula = "~ 0 + Label"
        
        design_mat = dmatrix(formula, data=self.metadata)
        
        # Fit limma
        print("  Fitting linear model...")
        fit = lmFit(expr_data, design_mat)
        
        print("  Applying empirical Bayes moderation...")
        contrast_matrix = makeContrasts("Label[Disease] - Label[Control]", design_mat)
        fit2 = contrasts_fit(fit, contrast_matrix)
        fit2 = eBayes(fit2)
        
        print("  Extracting differential expression results...")
        results = topTable(fit2, coef='Label[Disease] - Label[Control]', 
                          adjust_method="fdr_bh", number=np.inf)
        self.deg_results = pd.DataFrame(results)
        
        sig_degs = self.deg_results[
            (self.deg_results['adj_pvalue'] < padj_threshold) & 
            (abs(self.deg_results['log2FoldChange']) > lfc_threshold)
        ]
        print(f"  DEGs found: {len(sig_degs)} (FDR<{padj_threshold}, |log2FC|>{lfc_threshold})")
        
        return self
    
    def prepare_wgcna_input(self, custom_variance_percentile=None):
        
        params = self.get_platform_parameters()
        variance_percentile = custom_variance_percentile or params['variance_percentile']
        
        gene_var = self.normalized_data.var(axis=1)
        top_var_threshold = gene_var.quantile(variance_percentile / 100)
        filtered_data = self.normalized_data[gene_var > top_var_threshold]
        
        self.wgcna_input = filtered_data.T  # Samples x Genes
        
        print(f"  Variance filtering: Keep top {100-variance_percentile}% variable genes")
        print(f"  Variance threshold: {top_var_threshold:.3f}")
        print(f"  WGCNA input: {self.wgcna_input.shape[0]} samples × {self.wgcna_input.shape[1]} genes")
        
        if self.wgcna_input.shape[1] < 1000:
            print(f"  ⚠ WARNING: Only {self.wgcna_input.shape[1]} genes in network")
            print(f"  Consider reducing variance_percentile for more genes")
        
        return self
    
    def select_soft_power(self, custom_power_range=None, custom_r2_threshold=None):
      
        params = self.get_platform_parameters()
        power_range = custom_power_range or params['power_range']
        r2_threshold = custom_r2_threshold or params['r2_threshold']
        
        print(f"  Platform: {self.data_type.upper()}")
        print(f"  Target R² threshold: {r2_threshold}")
        print(f"  Power range: {power_range}")
        
        # Compute correlation matrix
        print("  Computing correlation matrix...")
        X = self.wgcna_input.values
        cor_matrix_np = fast_correlation_numba(X)
        
        # Use subset for large datasets (computational efficiency)
        n_genes = cor_matrix_np.shape[0]
        max_genes_for_selection = 5000
        
        if n_genes > max_genes_for_selection:
            print(f"  Using {max_genes_for_selection} random genes for power selection (dataset has {n_genes})")
            np.random.seed(42)
            subset_idx = np.random.choice(n_genes, min(max_genes_for_selection, n_genes), replace=False)
            cor_subset = cor_matrix_np[np.ix_(subset_idx, subset_idx)]
        else:
            cor_subset = cor_matrix_np
        
        cor_abs = np.abs(cor_subset)
        powers = range(*power_range)
        mean_k = []
        median_k = []
        r_squared = []
        slope_values = []
        
        print("  Testing soft powers...")
        for power in powers:
            adj_temp = cor_abs ** power
            k_temp = adj_temp.sum(axis=1)
            mean_k.append(k_temp.mean())
            median_k.append(np.median(k_temp))
            
            # Calculate scale-free topology fit (R²)
            k_hist, k_bins = np.histogram(k_temp[k_temp > 1], bins=30)
            k_centers = (k_bins[:-1] + k_bins[1:]) / 2
            k_centers = k_centers[k_hist > 0]
            k_hist = k_hist[k_hist > 0]
            
            if len(k_centers) > 5:
                log_k = np.log10(k_centers)
                log_p = np.log10(k_hist / k_hist.sum())
                mask = np.isfinite(log_k) & np.isfinite(log_p)
                
                if mask.sum() > 5:
                    slope, intercept = np.polyfit(log_k[mask], log_p[mask], 1)
                    r2 = np.corrcoef(log_k[mask], log_p[mask])[0, 1] ** 2
                    slope_values.append(slope)
                else:
                    r2 = 0
                    slope_values.append(0)
            else:
                r2 = 0
                slope_values.append(0)
            
            r_squared.append(r2)
        
        # Plot soft power selection
        fig, axes = plt.subplots(2, 2, figsize=(16, 12))
        
        # R² plot
        ax1 = axes[0, 0]
        ax1.plot(powers, r_squared, 'o-', linewidth=2.5, markersize=8, color='#2E86AB')
        ax1.axhline(y=r2_threshold, color='red', linestyle='--', linewidth=2, 
                   label=f'Target R² = {r2_threshold}')
        ax1.axhline(y=params['min_r2_accept'], color='orange', linestyle=':', linewidth=1.5,
                   label=f'Min acceptable R² = {params["min_r2_accept"]}')
        ax1.set_xlabel('Soft Threshold Power (β)', fontsize=12, fontweight='bold')
        ax1.set_ylabel('Scale Free Topology Fit (R²)', fontsize=12, fontweight='bold')
        ax1.set_title(f'Scale Independence ({self.data_type.upper()})', fontsize=14, fontweight='bold')
        ax1.legend(fontsize=10)
        ax1.grid(alpha=0.3)
        ax1.set_xlim([powers[0]-0.5, powers[-1]+0.5])
        
        # Mean connectivity plot
        ax2 = axes[0, 1]
        ax2.plot(powers, mean_k, 'o-', linewidth=2.5, markersize=8, color='#06A77D')
        if params['max_mean_k']:
            ax2.axhline(y=params['max_mean_k'], color='red', linestyle='--', linewidth=2,
                       label=f'Max mean k = {params["max_mean_k"]}')
            ax2.legend(fontsize=10)
        ax2.set_xlabel('Soft Threshold Power (β)', fontsize=12, fontweight='bold')
        ax2.set_ylabel('Mean Connectivity', fontsize=12, fontweight='bold')
        ax2.set_title('Mean Connectivity', fontsize=14, fontweight='bold')
        ax2.grid(alpha=0.3)
        ax2.set_xlim([powers[0]-0.5, powers[-1]+0.5])
        
        # Median connectivity plot
        ax3 = axes[1, 0]
        ax3.plot(powers, median_k, 'o-', linewidth=2.5, markersize=8, color='#F18F01')
        ax3.set_xlabel('Soft Threshold Power (β)', fontsize=12, fontweight='bold')
        ax3.set_ylabel('Median Connectivity', fontsize=12, fontweight='bold')
        ax3.set_title('Median Connectivity', fontsize=14, fontweight='bold')
        ax3.grid(alpha=0.3)
        ax3.set_xlim([powers[0]-0.5, powers[-1]+0.5])
        
        # Network density (slope) plot
        ax4 = axes[1, 1]
        ax4.plot(powers, [-s for s in slope_values], 'o-', linewidth=2.5, markersize=8, color='#A23B72')
        ax4.set_xlabel('Soft Threshold Power (β)', fontsize=12, fontweight='bold')
        ax4.set_ylabel('Scale Free Topology Slope (-slope)', fontsize=12, fontweight='bold')
        ax4.set_title('Network Density', fontsize=14, fontweight='bold')
        ax4.grid(alpha=0.3)
        ax4.set_xlim([powers[0]-0.5, powers[-1]+0.5])
        
        plt.tight_layout()
        plt.savefig(f'{self.output_dir}/01_Soft_Power_Selection.pdf', dpi=300, bbox_inches='tight')
        plt.close()
        
        # Select optimal power using multi-criteria decision
        print("\n  Selecting optimal power...")
        
        # Criterion 1: R² >= threshold
        suitable_powers = []
        for i, (p, r2, mk) in enumerate(zip(powers, r_squared, mean_k)):
            if r2 >= r2_threshold and mk <= params['max_mean_k']:
                suitable_powers.append((p, r2, mk, abs(r2 - r2_threshold)))
        
        # If no power meets strict threshold, relax to minimum acceptable
        if not suitable_powers:
            print(f"  No power achieved R² >= {r2_threshold}")
            print(f"  Relaxing to minimum acceptable R² >= {params['min_r2_accept']}")
            for i, (p, r2, mk) in enumerate(zip(powers, r_squared, mean_k)):
                if r2 >= params['min_r2_accept'] and mk <= params['max_mean_k']:
                    suitable_powers.append((p, r2, mk, abs(r2 - r2_threshold)))
        
        # If still none, pick highest R²
        if not suitable_powers:
            print(f"  Using power with highest R² (no threshold met)")
            best_idx = np.argmax(r_squared)
            self.soft_power = list(powers)[best_idx]
            selected_r2 = r_squared[best_idx]
            selected_mk = mean_k[best_idx]
        else:
            # Sort by: closest to target R², then lowest mean k
            suitable_powers_sorted = sorted(suitable_powers, key=lambda x: (x[3], x[2]))
            self.soft_power = suitable_powers_sorted[0][0]
            selected_r2 = suitable_powers_sorted[0][1]
            selected_mk = suitable_powers_sorted[0][2]
        
        print(f"\n  ✓ Selected soft power: β = {self.soft_power}")
        print(f"    Scale-free R²: {selected_r2:.3f}")
        print(f"    Mean connectivity: {selected_mk:.2f}")
        print(f"    Interpretation: {'Excellent' if selected_r2 >= 0.85 else 'Good' if selected_r2 >= 0.80 else 'Acceptable'} scale-free fit")
        
        self.cor_matrix_np = cor_matrix_np
        self.soft_power_metrics = {
            'power': self.soft_power,
            'r_squared': selected_r2,
            'mean_connectivity': selected_mk,
            'all_r2': dict(zip(powers, r_squared)),
            'all_mean_k': dict(zip(powers, mean_k))
        }
        
        return self
    
    def construct_network(self, custom_params=None):
        
        
        params = self.get_platform_parameters()
        if custom_params:
            params.update(custom_params)
        
        print(f"  Platform-optimized parameters ({self.data_type.upper()}):")
        print(f"    min_module_size = {params['min_module_size']}")
        print(f"    deep_split = {params['deep_split']} (0=strict, 4=very sensitive)")
        print(f"    detect_cut_height = {params['detect_cut_height']}")
        print(f"    merge_cut_height = {params['merge_cut_height']}")
        
        # Compute adjacency
        print(f"\n  Applying soft power: β = {self.soft_power}")
        adjacency_np = np.abs(self.cor_matrix_np) ** self.soft_power
        
        # Compute TOM
        print("  Computing Topological Overlap Matrix (TOM)...")
        TOM_np = compute_TOM_numba(adjacency_np)
        self.TOM_df = pd.DataFrame(TOM_np, 
                                   index=self.wgcna_input.columns, 
                                   columns=self.wgcna_input.columns)
        print("  ✓ TOM calculation complete")
        
        # Hierarchical clustering
        print("  Performing hierarchical clustering...")
        dissim_TOM = 1 - TOM_np
        np.fill_diagonal(dissim_TOM, 0)
        condensed_dist = squareform(dissim_TOM, checks=False)
        self.gene_linkage = linkage(condensed_dist, method='average')
        
        # Dynamic tree cut
        print("\n=== DETECTING MODULES (Dynamic Tree Cut) ===")
        dtc = DynamicTreeCut(
            min_cluster_size=params['min_module_size'],
            deep_split=params['deep_split'],
            detect_cut_height=params['detect_cut_height'],
            pam_stage=params['pam_stage'],
            pam_respects_dendro=params['pam_respects_dendro'],
            verbose=True
        )
        
        module_labels_numeric = dtc.cut_tree(self.gene_linkage, dissim_matrix=dissim_TOM)
        
        # Convert to color names
        unique_modules = np.unique(module_labels_numeric[module_labels_numeric != 0])
        n_modules = len(unique_modules)
        
        if n_modules > 0:
            color_names = [f'module_{i}' for i in range(1, n_modules + 1)]
            module_colors = pd.Series(index=self.wgcna_input.columns, dtype='object')
            module_colors[:] = 'grey'
            
            for i, mod_num in enumerate(unique_modules):
                genes_in_mod = self.wgcna_input.columns[module_labels_numeric == mod_num]
                module_colors[genes_in_mod] = color_names[i]
        else:
            module_colors = pd.Series(['grey'] * len(self.wgcna_input.columns), 
                                     index=self.wgcna_input.columns)
        
        self.module_colors = module_colors
        
        # Calculate module eigengenes
        print("\n=== CALCULATING MODULE EIGENGENES ===")
        print("  Method: First principal component (PCA) of module expression")
        ME_dict = {}
        
        for mod in module_colors.unique():
            if mod != 'grey':
                genes_in_mod = module_colors[module_colors == mod].index
                if len(genes_in_mod) > 0:
                    pca_mod = PCA(n_components=1)
                    me = pca_mod.fit_transform(self.wgcna_input[genes_in_mod])
                    ME_dict[f'ME_{mod}'] = me.flatten()
                    
                    var_explained = pca_mod.explained_variance_ratio_[0]
                    print(f"    {mod}: {len(genes_in_mod)} genes, variance explained = {var_explained:.2%}")
        
        if 'grey' in module_colors.values:
            genes_in_grey = module_colors[module_colors == 'grey'].index
            if len(genes_in_grey) > 0:
                pca_grey = PCA(n_components=1)
                me_grey = pca_grey.fit_transform(self.wgcna_input[genes_in_grey])
                ME_dict['ME_grey'] = me_grey.flatten()
        
        self.module_eigengenes = pd.DataFrame(ME_dict, index=self.wgcna_input.index)
        print(f"\n  Calculated eigengenes for {len(ME_dict)} modules")
        
        # Merge close modules
        if n_modules > 1:
            merged_colors, merge_info = merge_close_modules(
                self.module_colors.values,
                self.module_eigengenes,
                merge_cut_height=params['merge_cut_height'],
                dissim_matrix=dissim_TOM,
                verbose=True
            )
            
            self.module_colors = pd.Series(merged_colors, index=self.module_colors.index)
            
            # Recalculate eigengenes after merging
            ME_dict_merged = {}
            for mod in self.module_colors.unique():
                if mod != 'grey':
                    genes_in_mod = self.module_colors[self.module_colors == mod].index
                    if len(genes_in_mod) > 0:
                        pca_mod = PCA(n_components=1)
                        me = pca_mod.fit_transform(self.wgcna_input[genes_in_mod])
                        ME_dict_merged[f'ME_{mod}'] = me.flatten()
            
            if 'grey' in self.module_colors.values:
                genes_in_grey = self.module_colors[self.module_colors == 'grey'].index
                if len(genes_in_grey) > 0:
                    pca_grey = PCA(n_components=1)
                    me_grey = pca_grey.fit_transform(self.wgcna_input[genes_in_grey])
                    ME_dict_merged['ME_grey'] = me_grey.flatten()
            
            self.module_eigengenes = pd.DataFrame(ME_dict_merged, index=self.wgcna_input.index)
        
        # Final summary
        module_counts = self.module_colors.value_counts()
        n_modules_final = len([m for m in self.module_colors.unique() if m != 'grey'])
        
        print(f"\n=== FINAL MODULE SUMMARY ===")
        print(f"  Total modules: {n_modules_final}")
        print(f"  Genes in modules: {np.sum(self.module_colors != 'grey')}")
        print(f"  Unassigned (grey): {np.sum(self.module_colors == 'grey')}")
        print(f"\n  Module size distribution:")
        for mod, count in module_counts.head(10).items():
            print(f"    {mod}: {count} genes")
        if len(module_counts) > 10:
            print(f"    ... and {len(module_counts)-10} more modules")
        
        # Quality metrics
        avg_module_size = np.mean([count for mod, count in module_counts.items() if mod != 'grey'])
        print(f"\n  Quality metrics:")
        print(f"    Average module size: {avg_module_size:.1f} genes")
        print(f"    Assignment rate: {100*np.sum(self.module_colors != 'grey')/len(self.module_colors):.1f}%")
        
        # Save results
        gene_module_df = pd.DataFrame({
            'Gene': self.module_colors.index,
            'Module': self.module_colors.values
        })
        gene_module_df.to_csv(f'{self.output_dir}/Gene_Module_Assignment.csv', index=False)
        
        self.module_eigengenes.to_csv(f'{self.output_dir}/Module_Eigengenes.csv')
        
        return self
    
    def analyze_deg_modules(self, padj_threshold=0.05, lfc_threshold=1, coexp_threshold=0.1):
        """Analyze DEGs in WGCNA modules with biological interpretation.

        Significant DEGs are rows of ``self.deg_results`` whose adjusted p-value
        (column ``padj`` if present, otherwise ``adj_pvalue``) is strictly below
        ``padj_threshold`` and whose |log2FoldChange| is strictly above
        ``lfc_threshold``. The DEGs that are also present in
        ``self.module_colors`` are then processed as follows:

        * summarised per module (count and fold enrichment);
        * scored with TOM-based co-expression metrics computed by
          ``compute_coexp_metrics_fast`` on the DEG sub-matrix of
          ``self.TOM_df``;
        * given a module membership (kME), i.e. the Pearson correlation between
          the gene's column in ``self.wgcna_input`` and its module eigengene
          ``ME_<module>`` in ``self.module_eigengenes``.

        Parameters
        ----------
        padj_threshold : float, default 0.05
            Adjusted p-value cut-off (strict ``<``).
        lfc_threshold : float, default 1
            Absolute log2 fold-change cut-off (strict ``>``).
        coexp_threshold : float, default 0.1
            TOM threshold forwarded to ``compute_coexp_metrics_fast``.

        Returns
        -------
        self
            For chaining. The method returns early, without writing any file,
            when there are no significant DEGs or none of them is in the network.

        Notes
        -----
        Writes ``DEG_Module_Assignment.csv`` and ``Gene_Coexpression_Metrics.csv``
        into ``self.output_dir``.
        """
        print("\n=== ANALYZING DEGS IN MODULES ===")

        # pydeseq2 names the adjusted p-value 'padj'; the limma path uses 'adj_pvalue'.
        padj_col = 'padj' if 'padj' in self.deg_results.columns else 'adj_pvalue'
        is_sig = (
            (self.deg_results[padj_col] < padj_threshold)
            & (self.deg_results['log2FoldChange'].abs() > lfc_threshold)
        )
        # Only the columns used below must be non-missing; NaNs in unrelated
        # columns should not silently drop an otherwise significant gene.
        sig_genes = self.deg_results[is_sig].dropna(subset=[padj_col, 'log2FoldChange'])
        # One row per gene so the per-gene lookups below return scalars.
        sig_genes = sig_genes[~sig_genes.index.duplicated(keep='first')]

        # Preserve the DEG table's order instead of iterating a set: set order of
        # strings changes between interpreter runs, making outputs non-reproducible.
        network_genes = set(self.module_colors.index)
        deg_in_wgcna = [g for g in sig_genes.index if g in network_genes]

        print(f"  Total DEGs: {len(sig_genes)}")
        # Guarded: with zero significant DEGs this percentage used to divide by zero.
        pct_in_network = 100 * len(deg_in_wgcna) / len(sig_genes) if len(sig_genes) else 0.0
        print(f"  DEGs in WGCNA network: {len(deg_in_wgcna)} ({pct_in_network:.1f}%)")

        if len(deg_in_wgcna) == 0:
            print("  No DEGs found in WGCNA network!")
            return self

        # DEG distribution across modules
        deg_modules = self.module_colors.loc[deg_in_wgcna]
        deg_module_counts = deg_modules.value_counts()
        n_network_genes = len(self.module_colors)

        print("\n  DEG distribution across modules:")
        for mod, count in deg_module_counts.head(10).items():
            total_in_mod = int((self.module_colors == mod).sum())
            # Fold enrichment = (share of DEGs in this module) / (share of network genes in it).
            enrichment = (count / len(deg_in_wgcna)) / (total_in_mod / n_network_genes)
            print(f"    {mod}: {count} DEGs / {total_in_mod} genes (enrichment: {enrichment:.2f}x)")

        # Create DEG-module table
        deg_lfc = sig_genes.loc[deg_in_wgcna, 'log2FoldChange'].to_numpy()
        deg_module_df = pd.DataFrame({
            'Gene': deg_in_wgcna,
            'Module': deg_modules.to_numpy(),
            'Log2FC': deg_lfc,
            'Adj_PValue': sig_genes.loc[deg_in_wgcna, padj_col].to_numpy(),
            'Regulation': ['Up' if value > 0 else 'Down' for value in deg_lfc],
        })
        deg_module_df.to_csv(f'{self.output_dir}/DEG_Module_Assignment.csv', index=False)

        # Co-expression analysis restricted to the DEG x DEG block of the TOM.
        print(f"\n  Computing co-expression metrics (TOM threshold = {coexp_threshold})...")
        deg_tom = self.TOM_df.loc[deg_in_wgcna, deg_in_wgcna]

        gene_coexp_metrics = deg_module_df.copy()

        # The fourth return value (total TOM) is not reported here.
        n_coexp, mean_tom, max_tom, _total_tom = compute_coexp_metrics_fast(
            deg_tom.values, coexp_threshold
        )

        gene_coexp_metrics['N_CoExpressed_Genes'] = n_coexp
        gene_coexp_metrics['Mean_TOM_Score'] = mean_tom
        gene_coexp_metrics['Max_TOM_Score'] = max_tom

        # Module membership (kME). Each Gene occurs once, so pair it with its Module
        # directly instead of re-filtering the frame for every gene (was O(n^2)).
        print("  Calculating module membership (kME)...")
        mm_values = []
        for gene, module in zip(gene_coexp_metrics['Gene'], gene_coexp_metrics['Module']):
            me_col = f"ME_{module}"
            if me_col in self.module_eigengenes.columns:
                corr, _ = pearsonr(self.wgcna_input[gene], self.module_eigengenes[me_col])
                mm_values.append(corr)
            else:
                # No eigengene for this module (e.g. removed during merging).
                mm_values.append(np.nan)

        gene_coexp_metrics['Module_Membership'] = mm_values

        gene_coexp_metrics = gene_coexp_metrics.sort_values(
            ['N_CoExpressed_Genes', 'Mean_TOM_Score'],
            ascending=[False, False]
        ).reset_index(drop=True)

        gene_coexp_metrics.to_csv(f'{self.output_dir}/Gene_Coexpression_Metrics.csv', index=False)

        print("\n Top 10 hub DEGs (highest co-expression):")
        print(gene_coexp_metrics[['Gene', 'Module', 'N_CoExpressed_Genes',
                                  'Module_Membership', 'Log2FC']].head(10).to_string(index=False))

        return self
    
    def create_visualizations(self):
        """Write the summary figures for the constructed network to ``self.output_dir``.

        The gene dendrogram is attempted first. If drawing it exceeds the
        recursion limit (very large trees), a module-size bar chart is saved in
        its place. The module eigengene heatmap is attempted afterwards in
        either case.

        Returns
        -------
        self
        """
        dendrogram_drawn = True
        try:
            self._plot_dendrogram()
        except RecursionError:
            dendrogram_drawn = False

        if not dendrogram_drawn:
            # Fallback that needs no recursive tree traversal.
            print("  ⚠ Dendrogram plotting failed (too many genes)")
            print("  Creating module size plot instead...")
            self._plot_module_sizes()

        self._plot_eigengene_heatmap()

        print("  ✓ All visualizations complete!")
        return self
    
    def _plot_dendrogram(self, max_genes=5000):
        """Plot the gene dendrogram with a module colour bar and save it as a PDF.

        Draws ``self.gene_linkage`` with scipy's ``dendrogram``. When the network
        has more than ``max_genes`` genes, the tree is truncated to its last 100
        merged clusters (``truncate_mode='lastp'``). In that case the colour bar
        is replaced by a text summary of the eight largest modules. Output:
        ``<output_dir>/02_Module_Dendrogram.pdf``.

        Parameters
        ----------
        max_genes : int, default 5000
            Above this number of genes, the truncated layout is used.

        Raises
        ------
        RecursionError
            Propagated from ``dendrogram`` on very deep trees (the caller falls
            back to a module-size plot). The figure is closed before the error
            propagates.
        """
        n_genes = len(self.wgcna_input.columns)

        fig = plt.figure(figsize=(22, 12))
        try:
            gs = fig.add_gridspec(2, 1, height_ratios=[4, 1], hspace=0.05)
            ax1 = fig.add_subplot(gs[0])
            ax2 = fig.add_subplot(gs[1])

            if n_genes > max_genes:
                print(f"  Creating truncated dendrogram ({n_genes} genes)...")
                dend = dendrogram(
                    self.gene_linkage,
                    ax=ax1,
                    no_labels=True,
                    truncate_mode='lastp',
                    p=100,
                    above_threshold_color='gray'
                )
                ax1.set_title(f'Gene Dendrogram (Truncated) - {n_genes} genes total',
                              fontsize=16, fontweight='bold')
            else:
                dend = dendrogram(
                    self.gene_linkage,
                    labels=self.wgcna_input.columns,
                    ax=ax1,
                    no_labels=True,
                    above_threshold_color='gray'
                )
                ax1.set_title(f'Gene Dendrogram and Module Colors - {n_genes} genes',
                              fontsize=16, fontweight='bold')

            ax1.set_ylabel('Height (1-TOM)', fontsize=13, fontweight='bold')
            ax1.spines['bottom'].set_visible(False)
            ax1.set_xticks([])

            # One colour per module; unassigned (grey) genes always get light grey.
            module_counts = self.module_colors.value_counts()
            colors = plt.cm.tab20c(np.linspace(0, 1, 20))
            color_map = {}
            for i, mod in enumerate(self.module_colors.unique()):
                if mod == 'grey':
                    color_map[mod] = (0.7, 0.7, 0.7)
                else:
                    color_map[mod] = tuple(colors[i % 20][:3])

            if n_genes <= max_genes:
                # dend['leaves'] are positions in the linkage input. This assumes that
                # order matches self.module_colors (both follow wgcna_input columns) — confirm.
                leaf_order = dend['leaves']
                ordered_colors = np.array([color_map[self.module_colors.iloc[i]] for i in leaf_order])
                ax2.imshow(ordered_colors.reshape(1, -1, 3), aspect='auto', interpolation='nearest')
                ax2.set_yticks([])
                ax2.set_xticks([])
                ax2.set_ylabel('Module', fontsize=11, fontweight='bold')
            else:
                # A per-leaf colour bar does not line up with a truncated tree; show sizes instead.
                ax2.axis('off')
                stats_text = "Module Distribution:\n" + "\n".join(
                    [f"{mod}: {count}" for mod, count in module_counts.head(8).items()]
                )
                ax2.text(0.1, 0.5, stats_text, transform=ax2.transAxes,
                         fontsize=10, family='monospace')

            # Legend for (at most) the first 12 modules in order of appearance.
            legend_elements = [
                mpatches.Patch(color=color_map[mod], label=f'{mod} (n={module_counts[mod]})')
                for mod in list(self.module_colors.unique())[:12]
                if mod in color_map
            ]
            ax1.legend(handles=legend_elements, loc='upper right', fontsize=9, ncol=2)

            fig.tight_layout()
            fig.savefig(f'{self.output_dir}/02_Module_Dendrogram.pdf', dpi=300, bbox_inches='tight')
        finally:
            # Always release the figure. create_visualizations() catches RecursionError
            # from dendrogram(); without this the half-drawn figure stayed open and was
            # displayed at the end of the notebook cell.
            plt.close(fig)
        print("  Dendrogram saved")
    
    def _plot_module_sizes(self):
        """Save a bar chart of genes per module, excluding the grey pseudo-module.

        Used as the fallback when the dendrogram cannot be drawn. Bars are
        ordered from largest to smallest module. Output:
        ``<output_dir>/02_Module_Sizes.pdf``.
        """
        sizes = self.module_colors.value_counts()
        sizes = sizes[sizes.index != 'grey']
        sizes = sizes.sort_values(ascending=False)

        n_mods = len(sizes)
        positions = range(n_mods)
        palette = plt.cm.tab20(np.linspace(0, 1, n_mods))

        fig, ax = plt.subplots(figsize=(14, 8))
        ax.bar(positions, sizes.values, color=palette, edgecolor='black', linewidth=1.2)

        ax.set_xticks(positions)
        ax.set_xticklabels(sizes.index, rotation=45, ha='right', fontsize=9)
        ax.set_xlabel('Module', fontsize=13, fontweight='bold')
        ax.set_ylabel('Number of Genes', fontsize=13, fontweight='bold')
        ax.set_title(f'Module Size Distribution - {n_mods} modules',
                     fontsize=15, fontweight='bold')
        ax.grid(axis='y', alpha=0.3)

        fig.tight_layout()
        fig.savefig(f'{self.output_dir}/02_Module_Sizes.pdf', dpi=300, bbox_inches='tight')
        plt.close(fig)
        print("  ✓ Module size plot saved")
    
    
    
    def _plot_eigengene_heatmap(self):
       
        if len(self.module_eigengenes.columns) < 2:
            print("  Skipping eigengene heatmap (need ≥2 modules)")
            return
        
        # Filter out grey if present
        me_cols = [col for col in self.module_eigengenes.columns if 'grey' not in col.lower()]
        if len(me_cols) < 2:
            return
        
        me_data = self.module_eigengenes[me_cols]
        
        plt.figure(figsize=(14, max(6, len(me_cols)*0.3)))
        
        sns.clustermap(me_data.T, cmap='RdBu_r', center=0, 
                      figsize=(14, max(8, len(me_cols)*0.5)),
                      yticklabels=True, xticklabels=False,
                      cbar_kws={'label': 'Eigengene Expression'})
        
        plt.suptitle('Module Eigengene Expression Heatmap', 
                    fontsize=15, fontweight='bold', y=1.02)
        plt.tight_layout()
        plt.savefig(f'{self.output_dir}/04_Eigengene_Heatmap.pdf', dpi=300, bbox_inches='tight')
        plt.close()
        
        print("  ✓ Eigengene heatmap saved")
    
    def run_complete_analysis(self, count_matrix_path, metadata_path,
                              group_column, case_label, control_label,
                              covariates=None, normalize_microarray=True,
                              variance_percentile=None, custom_wgcna_params=None,
                              deg_params=None):
        """Run the whole pipeline: load, preprocess/DE, WGCNA, DEG-module analysis, plots.

        Parameters
        ----------
        count_matrix_path, metadata_path : str
            Paths passed to ``load_data``.
        group_column, case_label, control_label : str
            Metadata grouping column and its case/control values, forwarded to
            the platform-specific preprocessing step.
        covariates : list of str, optional
            Forwarded to ``preprocess_microarray`` (not used for RNA-seq).
        normalize_microarray : bool, default True
            Forwarded as ``normalize`` to ``preprocess_microarray``.
        variance_percentile : optional
            Forwarded as ``custom_variance_percentile`` to ``prepare_wgcna_input``.
        custom_wgcna_params : optional
            Forwarded as ``custom_params`` to ``construct_network``.
        deg_params : dict, optional
            Keyword arguments for ``analyze_deg_modules`` (``padj_threshold``,
            ``lfc_threshold``, ``coexp_threshold``). That method's defaults apply
            when omitted.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If ``self.data_type`` is neither ``'rna_seq'`` nor ``'microarray'``.
            This is checked before any data is loaded.
        """
        # Fail fast: validate configuration before the (potentially slow) data load.
        if self.data_type not in ('rna_seq', 'microarray'):
            raise ValueError("data_type must be 'rna_seq' or 'microarray'")

        self.load_data(count_matrix_path, metadata_path)

        # Platform-specific normalisation and differential expression
        if self.data_type == 'rna_seq':
            self.preprocess_rnaseq(group_column, case_label, control_label)
        else:
            self.preprocess_microarray(group_column, case_label, control_label,
                                       covariates=covariates,
                                       normalize=normalize_microarray)

        self.prepare_wgcna_input(custom_variance_percentile=variance_percentile)
        self.select_soft_power()
        self.construct_network(custom_params=custom_wgcna_params)
        self.analyze_deg_modules(**(deg_params or {}))
        self.create_visualizations()

        return self




In [2]:

if __name__ == "__main__":
    
    # Example 1: RNA-seq analysis
    wgcna_rnaseq = PyWGCNA(data_type='rna_seq', output_dir='output_rnaseq_wgcna')
    wgcna_rnaseq.run_complete_analysis(
        count_matrix_path="/backup/as36275d/abhishek/Kamini/backup/GSE169755_counts_sym.csv",
        metadata_path="/backup/as36275d/abhishek/Kamini/backup/METADATA_GSE169755.csv",
        group_column="GSM_Label",
        case_label="Disease",
        control_label="Control"
    )
    
   
    


=== LOADING DATA ===
  Initial dimensions: (19278, 12)
  Aligned samples: 12

=== PREPROCESSING RNA-SEQ DATA ===
  Method: DESeq2 with variance stabilizing transformation (VST)
  After low-count filtering: (19278, 12)
  Running DESeq2...
Using None as control genes, passed at DeseqDataSet initialization


Fitting size factors...
... done in 0.01 seconds.

Fitting dispersions...
... done in 4.01 seconds.

Fitting dispersion trend curve...
... done in 0.49 seconds.

Fitting MAP dispersions...
... done in 4.33 seconds.

Fitting LFCs...
... done in 2.94 seconds.

Calculating cook's distance...
... done in 0.02 seconds.

Replacing 0 outlier genes.

Fitting size factors...
... done in 0.01 seconds.



Fit type used for VST : parametric
Using None as control genes, passed at DeseqDataSet initialization


Fitting dispersions...
... done in 4.60 seconds.



  ✓ VST normalization complete
  Normalized data range: [5.61, 20.68]


Running Wald tests...
... done in 1.14 seconds.



Log2 fold change & Wald test p-value: GSM_Label Disease vs Control
               baseMean  log2FoldChange     lfcSE      stat    pvalue  \
Gene_symbol                                                             
A1BG         149.330525        0.689565  0.474934  1.451916  0.146525   
NAT2           0.835028        0.884032  2.552286  0.346369  0.729066   
ADA            6.026300       -0.562863  1.728799 -0.325580  0.744742   
CDH2         534.411398        0.760583  0.207163  3.671417  0.000241   
AKT3         772.284063       -0.328675  0.167813 -1.958580  0.050162   
...                 ...             ...       ...       ...       ...   
PTBP3        317.668567        0.371334  0.310368  1.196432  0.231528   
KCNE2         10.398638       -0.795574  0.746697 -1.065458  0.286669   
DGCR2        689.784859       -0.359165  0.279178 -1.286509  0.198266   
CASP8AP2     234.654325        0.170190  0.278357  0.611409  0.540929   
SCO2         278.175891        0.108776  0.350396  0.3104

<Figure size 1400x600 with 0 Axes>

In [13]:
# Example 2: Microarray analysis (limma path), GSE9476 AML dataset.
if __name__ == "__main__":
    # The input directory defaults to the original author's location. Override it
    # with the WGCNA_MICROARRAY_DIR environment variable instead of editing the path.
    MICROARRAY_DIR = os.environ.get(
        "WGCNA_MICROARRAY_DIR",
        "/backup/as36275d/abhishek/Kamini/test/Data_extraction/AML/GSE9476",
    )

    wgcna_microarray = PyWGCNA(data_type='microarray', output_dir='output_microarray_wgcna')
    wgcna_microarray.run_complete_analysis(
        count_matrix_path=os.path.join(MICROARRAY_DIR, "GSE9476_expression.csv"),
        metadata_path=os.path.join(MICROARRAY_DIR, "GSE9476_metadata.csv"),
        group_column="Gsm_label",
        case_label="Disease",
        control_label="Control",
        covariates=None,  # Add covariates if needed, e.g. ['age', 'sex']
        # Set to False to skip the normalization step in preprocess_microarray.
        # NOTE(review): the logged normalized range [-5.20, 5.20] matches the clipping
        # of sklearn quantile_transform(output_distribution='normal'), i.e. a
        # rank-to-Gaussian transform rather than classic between-array quantile
        # normalization. Confirm this is intended before trusting the DEG count.
        normalize_microarray=True,
    )



=== SCIENTIFICALLY OPTIMIZED PyWGCNA PIPELINE (MICROARRAY) ===

=== LOADING DATA ===
  Initial dimensions: (14205, 64)
  Aligned samples: 64

=== PREPROCESSING MICROARRAY DATA ===
  Method: limma with empirical Bayes moderation
  Data appears log-scale (max=15.56)
  Applying quantile normalization...
  ✓ Normalization complete
  Normalized data range: [-5.20, 5.20]
  Creating design matrix...
  Fitting linear model...
  Applying empirical Bayes moderation...
  Extracting differential expression results...
  DEGs found: 9 (FDR<0.05, |log2FC|>1)

=== PREPARING WGCNA INPUT ===
  Variance filtering: Keep top 50% variable genes
  Variance threshold: 0.022
  WGCNA input: 64 samples × 7102 genes

=== SELECTING SOFT-THRESHOLDING POWER ===
  Platform: MICROARRAY
  Target R² threshold: 0.8
  Power range: (1, 20)
  Computing correlation matrix...
  Using 5000 random genes for power selection (dataset has 7102)
  Testing soft powers...

  Selecting optimal power...

  ✓ Selected soft power: β = 7

<Figure size 1400x600 with 0 Axes>