**Install Necessary Dependencies**

In [None]:
# Install PyTorch Geometric only when it is not already importable.
try:
    import torch_geometric
except ImportError:
    !pip install torch-geometric

In [None]:
# umap-learn provides `umap.UMAP`, used for the 2-D embedding plots.
!pip install umap-learn

In [None]:
# Optional version pin for numba/llvmlite (presumably for compatibility with
# umap-learn — confirm). Without `!`, this line relies on IPython automagic (%pip).
pip install --upgrade "numba==0.59.1" "llvmlite==0.42.0"    # if necessary

In [None]:
!pip install scanpy

In [None]:
# scikit-misc is required by scanpy's highly_variable_genes(flavor='seurat_v3').
!pip install scikit-misc

In [None]:
# python-igraph + leidenalg back scanpy's Leiden clustering (sc.tl.leiden).
!pip install python-igraph leidenalg

**Import Necessary Libraries**

In [None]:
import pandas as pd
import numpy as np
import scanpy as sc
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, SAGEConv, GATConv
from sklearn.metrics import silhouette_score, adjusted_rand_score, normalized_mutual_info_score, adjusted_mutual_info_score, confusion_matrix
from sklearn.decomposition import PCA
import umap
import networkx as nx
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import ttest_ind, ranksums
from sklearn.cluster import KMeans
from sklearn.neighbors import NearestNeighbors

**Pathway-Aware GNN Framework**

In [None]:
# ---- DEVICE SETUP ----
# Use the GPU whenever PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("[INFO] Using device:", device)

# ---- 1. SCTransform-like Preprocessing----
def preprocess_sctransform(adata, regress_vars=None, hvg_n=2000):
    """Select highly variable genes, optionally regress covariates, then scale.

    This approximates an SCTransform-style workflow using standard Scanpy
    steps. It does not fit SCTransform's regularised negative-binomial model.

    Args:
        adata: AnnData object. Scanpy documents that ``flavor='seurat_v3'``
            HVG selection expects count data in ``adata.X``.
        regress_vars: Optional ``obs`` key(s) passed to ``sc.pp.regress_out``.
        hvg_n: Number of highly variable genes to keep.

    Returns:
        A new (non-view) AnnData restricted to the selected genes, with
        scaled ``X``. The HVG annotation columns are written to the input's
        ``var`` (or to its materialised copy if the input was a view).
    """
    print("[INFO] Preprocessing with SCTransform-like workflow...")
    # HVG selection writes columns into .var; on an AnnData view that would
    # trigger an implicit (warned) copy, so materialise it explicitly.
    if adata.is_view:
        adata = adata.copy()
    sc.pp.highly_variable_genes(adata, flavor='seurat_v3', n_top_genes=hvg_n)
    # regress_out/scale modify X in place; work on a real object, not a view.
    adata = adata[:, adata.var['highly_variable']].copy()
    if regress_vars is not None:
        sc.pp.regress_out(adata, regress_vars)
    sc.pp.scale(adata)
    print("[INFO] Preprocessing completed.")
    return adata

def construct_pathway_bipartite_graphs(
    adata,
    pathway_genes,
    k=5,
    min_expression=0.25,          # absolute floor for expression to count as an edge
    gene_percentile=75,           # per-gene percentile threshold
    min_genes_required=5,         # skip pathways with < this many genes in data
    min_edges_per_gene=2,         # require at least this many edges per gene on avg
    fallback_topk=True            # if too sparse, rescue with top-k per cell
):
    """Build one bipartite graph (genes ↔ cells) per pathway using selective edges.

    For each pathway (restricted to genes present in ``adata``, de-duplicated):
      - add an edge gene↔cell when the cell's value exceeds
        ``max(gene_percentile-th percentile of the gene's positive values, min_expression)``;
      - if the graph has fewer than ``len(genes) * min_edges_per_gene`` edges and
        ``fallback_topk`` is set, also link every cell to its ``k`` highest
        positive pathway genes;
      - keep the pathway only if the edge target is then met.

    Every cell of ``adata`` is added as a node, even without edges. Edge
    weights are the raw values from ``adata.X``.

    Args:
        adata: AnnData; ``X`` may be dense or scipy-sparse.
        pathway_genes: dict pathway name -> iterable of gene names.
        k: Genes per cell in the top-k fallback.
        min_expression, gene_percentile: Edge threshold parameters (see above).
        min_genes_required: Minimum number of genes present for a pathway to be built.
        min_edges_per_gene: Density target multiplier.
        fallback_topk: Whether to apply the top-k rescue for sparse graphs.

    Returns:
        dict pathway name -> ``nx.Graph`` with node attribute ``node_type``
        ('gene' / 'cell') and edge attribute ``weight``.
    """
    print("[INFO] Building bipartite graphs (thresholded + optional top-k fallback)...")
    X = adata.X.toarray() if hasattr(adata.X, "toarray") else np.asarray(adata.X)
    expr = pd.DataFrame(X, index=adata.obs_names, columns=adata.var_names)
    cell_names = expr.index.to_numpy()
    graphs = {}

    for pathway_name, genes in pathway_genes.items():
        # Order-preserving de-duplication: a gene listed twice would otherwise
        # be processed twice and inflate both the edge count and the target.
        available_genes = list(dict.fromkeys(g for g in genes if g in expr.columns))
        if len(available_genes) < min_genes_required:
            continue

        G = nx.Graph()
        G.add_nodes_from(available_genes, bipartite=0, node_type='gene')
        G.add_nodes_from(expr.index, bipartite=1, node_type='cell')

        pathway_expr = expr[available_genes].to_numpy()  # rows = cells, cols = available_genes

        # ---- selective edges by threshold ----
        for gi, gene in enumerate(available_genes):
            gene_vals = pathway_expr[:, gi]
            positive = gene_vals[gene_vals > 0]
            # percentile over strictly positive values; no positives -> no edges
            per_gene_thr = np.percentile(positive, gene_percentile) if positive.size else float('inf')
            thr = max(per_gene_thr, min_expression)
            # `> 0` keeps weights positive even if min_expression is set negative
            keep = (gene_vals > thr) & (gene_vals > 0)
            for cell, val in zip(cell_names[keep], gene_vals[keep]):
                # weight can be raw or normalized; keep raw for interpretability
                G.add_edge(gene, cell, weight=float(val))

        edges_added = G.number_of_edges()

        # ---- density check + fallback ----
        needed_edges = len(available_genes) * min_edges_per_gene
        if edges_added < needed_edges and fallback_topk:
            # Rank raw values within each cell: only the positive entries can
            # be selected, and ordering within a row is all that matters.
            for i, cell in enumerate(cell_names):
                row = pathway_expr[i]
                k_eff = min(k, int((row > 0).sum()))
                if k_eff <= 0:
                    continue
                for gi in np.argpartition(row, -k_eff)[-k_eff:]:
                    gene = available_genes[gi]
                    if not G.has_edge(gene, cell):
                        G.add_edge(gene, cell, weight=float(row[gi]))
            edges_added = G.number_of_edges()

        if edges_added >= needed_edges:
            graphs[pathway_name] = G
            print(f"  [INFO] {pathway_name}: {len(available_genes)} genes, {edges_added} edges")
        # otherwise: too sparse even after fallback -> pathway is dropped

    print(f"[INFO] Graph building completed. Kept pathways: {len(graphs)}")
    return graphs

# ---- 3. Node Features: [mean, std, degree, type] + PCA profile (cells) + pathway membership (genes) ----
def node_features_bipartite_enhanced(G, expr, pathway_genes=None, pathway=None, n_pca=5):
    """Build the node-feature matrix for one gene↔cell pathway graph.

    Rows: all gene nodes first, then all cell nodes (in ``G`` node order).
    Columns: ``[mean, std, degree, is_gene, is_cell, pathway_size, *pca]``.

    - Gene rows: mean/std (population std) of the gene's column over all rows
      of ``expr``; ``pathway_size`` is the number of genes listed for
      ``pathway`` in ``pathway_genes`` (0 if either is missing); PCA slots are 0.
    - Cell rows: mean/std over the graph's genes; ``pathway_size`` is 0; PCA
      slots hold the cell's projection on the leading principal components of
      the cells × genes sub-matrix.

    Genes/cells absent from ``expr`` are treated as all-zero.

    Args:
        G: Graph whose nodes carry ``node_type`` 'gene' or 'cell'.
        expr: cells × genes DataFrame.
        pathway_genes: Optional dict pathway -> gene list.
        pathway: Optional pathway name used to look up ``pathway_genes``.
        n_pca: Requested number of PCA components (capped by data size).

    Returns:
        (x, genes, cells): float32 array of shape
        ``(len(genes) + len(cells), 6 + pca_dim)``, plus the gene and cell node
        lists that define the row order.
    """
    genes = [n for n, d in G.nodes(data=True) if d['node_type'] == 'gene']
    cells = [n for n, d in G.nodes(data=True) if d['node_type'] == 'cell']
    n_genes = len(genes)

    gene_cols = expr.reindex(columns=genes, fill_value=0.0).to_numpy(dtype=np.float64)
    cell_rows = expr.reindex(index=cells, columns=genes, fill_value=0.0).to_numpy(dtype=np.float64)

    # PCA for cells; n_components must not exceed min(n_samples, n_features).
    if n_genes > 1 and len(cells) > 1:
        n_comp = min(n_pca, n_genes, len(cells))
        cell_pca = PCA(n_components=n_comp).fit_transform(cell_rows)
    else:
        cell_pca = np.zeros((len(cells), n_pca))
    pca_dim = cell_pca.shape[1]

    pathway_score = len(pathway_genes.get(pathway, ())) if pathway_genes and pathway else 0

    x = np.zeros((n_genes + len(cells), 6 + pca_dim), dtype=np.float32)
    # --- Genes: [mean, std, degree, 1, 0, pathway_size, 0...] ---
    x[:n_genes, 0] = gene_cols.mean(axis=0)
    x[:n_genes, 1] = gene_cols.std(axis=0)
    x[:n_genes, 2] = [G.degree(g) for g in genes]
    x[:n_genes, 3] = 1.0
    x[:n_genes, 5] = pathway_score
    # --- Cells: [mean, std, degree, 0, 1, 0, pca...] ---
    x[n_genes:, 0] = cell_rows.mean(axis=1)
    x[n_genes:, 1] = cell_rows.std(axis=1)
    x[n_genes:, 2] = [G.degree(c) for c in cells]
    x[n_genes:, 4] = 1.0
    x[n_genes:, 6:] = cell_pca
    return x, genes, cells

# ---- 4. Different Models ----
class UnifiedGNN(nn.Module):
    """Two-layer graph encoder with a selectable convolution family.

    Args:
        in_dim: Width of the input node features.
        hidden_dim: Width after the first convolution (followed by ReLU).
        out_dim: Width of the output node embeddings.
        gnn_type: 'gcn', 'sage' or 'gat' (case-insensitive).
        heads: Number of GAT attention heads. Heads are averaged
            (``concat=False``), so layer widths match the other types.

    Raises:
        ValueError: for an unknown ``gnn_type``.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, gnn_type='gcn', heads=2):
        super().__init__()
        self.gnn_type = gnn_type.lower()
        self.heads = heads

        builders = {
            'gcn': lambda d_in, d_out: GCNConv(d_in, d_out),
            'sage': lambda d_in, d_out: SAGEConv(d_in, d_out),
            'gat': lambda d_in, d_out: GATConv(d_in, d_out, heads=heads, concat=False),
        }
        build = builders.get(self.gnn_type)
        if build is None:
            raise ValueError("gnn_type must be one of: 'gcn', 'sage', 'gat'")
        self.conv1 = build(in_dim, hidden_dim)
        self.conv2 = build(hidden_dim, out_dim)

    def forward(self, x, edge_index):
        """Run both convolutions and return one ``out_dim`` embedding per node."""
        hidden = F.relu(self.conv1(x, edge_index))
        return self.conv2(hidden, edge_index)

# ---- 5. Contrastive Loss (with dropout-based augmentation) ----
def contrastive_loss_dropout(cell_emb, drop_prob=0.2, temperature=0.5):
    """NT-Xent (SimCLR-style) loss between embeddings and a dropout-augmented view.

    The second view zeroes each entry of ``cell_emb`` independently with
    probability ``drop_prob``. Both views are L2-normalised and stacked into
    2N rows. For each row, the positive is the same cell in the other view.
    The negatives are the remaining 2N-2 rows: the row itself and its positive
    are excluded, so the positive appears exactly once in the softmax
    denominator.

    Args:
        cell_emb: (N, D) tensor of cell embeddings.
        drop_prob: Probability of zeroing an entry in the augmented view.
        temperature: Temperature dividing the cosine similarities.

    Returns:
        (loss as a Python float, loss tensor carrying the autograd graph).
    """
    # Dropout-based augmentation
    keep = (torch.rand_like(cell_emb) > drop_prob).float()
    n = cell_emb.shape[0]
    two_n = 2 * n
    dev = cell_emb.device

    z = F.normalize(torch.cat([cell_emb, cell_emb * keep], dim=0), dim=1)
    sim = z @ z.T

    rows = torch.arange(two_n, device=dev)
    partner = (rows + n) % two_n              # row i <-> row i+N (mod 2N)
    positives = sim[rows, partner]

    negative_mask = torch.ones(two_n, two_n, dtype=torch.bool, device=dev)
    negative_mask[rows, rows] = False         # drop self-similarity
    negative_mask[rows, partner] = False      # drop the positive from the negatives
    negatives = sim[negative_mask].view(two_n, two_n - 2)

    logits = torch.cat([positives.unsqueeze(1), negatives], dim=1) / temperature
    labels = torch.zeros(two_n, dtype=torch.long, device=dev)  # positive sits in column 0
    loss = F.cross_entropy(logits, labels)
    return loss.item(), loss

# ---- 6. Train GNNs only (early stopping) ----
def train_pathway_gnns(graphs, expr, pathway_genes, model_type='sage', hidden_dim=32, out_dim=16,
                       epochs=100, early_stop_patience=15, early_stop_delta=1e-4, device=None):
    """Train one contrastive GNN per pathway graph and collect cell embeddings.

    For each pathway a fresh ``UnifiedGNN`` is trained with
    ``contrastive_loss_dropout`` on the cell nodes. Training stops early once
    the loss has failed to improve by more than ``early_stop_delta`` for more
    than ``early_stop_patience`` consecutive epochs. Cell embeddings are then
    computed with the final parameters.

    Args:
        graphs: dict pathway name -> bipartite ``nx.Graph`` (genes ↔ cells).
        expr: cells × genes DataFrame used for node features; its index fixes
            the output cell order.
        pathway_genes: dict pathway name -> gene list (used for node features).
        model_type: 'gcn', 'sage' or 'gat'.
        hidden_dim, out_dim: GNN layer widths.
        epochs: Maximum number of training epochs per pathway.
        early_stop_patience, early_stop_delta: Early-stopping settings.
        device: Torch device. ``None`` selects CUDA when available, else CPU.

    Returns:
        (pathway_embs_tensor, pathway_names, cells_list). The tensor has shape
        (n_cells, n_pathways, out_dim); ``pathway_names`` lists exactly the
        pathways present along axis 1, in order. Cells missing from a graph
        get zero vectors.

    Raises:
        ValueError: if no pathway graph yielded an embedding.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"[INFO] Training GNNs ({model_type}) with contrastive loss and early stopping on {device}...")
    cells_list = list(expr.index)
    pathway_names = []
    pathway_embs = []

    for pathway, G in graphs.items():
        x, genes, cells = node_features_bipartite_enhanced(G, expr, pathway_genes, pathway)
        node_to_idx = {n: i for i, n in enumerate(genes + cells)}
        edges = [(node_to_idx[a], node_to_idx[b]) for a, b in G.edges()
                 if a in node_to_idx and b in node_to_idx]
        if not edges:
            continue  # nothing to propagate over; pathway is skipped (and not named)

        one_way = torch.tensor(edges, dtype=torch.long).t()
        # nx.Graph reports each undirected edge once, while PyG convolutions
        # pass messages source -> target only; add the reverse direction so
        # genes and cells exchange information both ways.
        edge_index = torch.cat([one_way, one_way.flip(0)], dim=1).contiguous().to(device)
        x_tensor = torch.tensor(x, dtype=torch.float32).to(device)
        n_genes = len(genes)  # rows [0, n_genes) are genes, the rest are cells

        model = UnifiedGNN(x.shape[1], hidden_dim, out_dim, gnn_type=model_type).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

        best_loss = float("inf")
        patience = 0

        for epoch in range(epochs):
            optimizer.zero_grad()
            emb = model(x_tensor, edge_index)
            loss_val, loss = contrastive_loss_dropout(emb[n_genes:], drop_prob=0.2, temperature=0.5)
            loss.backward()
            optimizer.step()

            if loss_val + early_stop_delta < best_loss:
                best_loss = loss_val
                patience = 0
            else:
                patience += 1

            if patience > early_stop_patience:
                print(f"[INFO] Early stopped at epoch {epoch+1} for pathway '{pathway}' (best loss: {best_loss:.4f})")
                break

            if (epoch+1) % 50 == 0 or epoch == 0:
                print(f"[INFO] Pathway '{pathway}' epoch {epoch+1}, contrastive loss: {loss_val:.4f}")

        # Embed with the trained parameters (also well-defined when epochs == 0).
        with torch.no_grad():
            cell_emb = model(x_tensor, edge_index)[n_genes:].cpu().numpy()
        row_of = {cell: i for i, cell in enumerate(cells)}
        zeros = np.zeros(out_dim)
        emb_reordered = np.array([cell_emb[row_of[c]] if c in row_of else zeros for c in cells_list])
        pathway_embs.append(emb_reordered)
        pathway_names.append(pathway)

    if not pathway_embs:
        raise ValueError("No pathway graph produced embeddings; check graph construction thresholds.")

    stacked = np.stack(pathway_embs, axis=1)  # (cells, pathways, dim)
    pathway_embs_tensor = torch.tensor(stacked, dtype=torch.float32).to(device)

    print(f"[INFO] Finished training GNNs. Pathway embeddings shape = {pathway_embs_tensor.shape}")
    return pathway_embs_tensor, pathway_names, cells_list


# ---- 7. Pathway Attention ----
class PathwayAttention(nn.Module):
    """Softmax attention pooling over per-pathway cell embeddings.

    One shared bias-free linear layer scores every (cell, pathway) embedding.
    The scores are divided by ``temperature``, softmax-normalised over
    pathways, and used to take a weighted sum of the embeddings.

    Args:
        pathway_dim: Width of each pathway embedding.
        num_pathways: Accepted for call compatibility; not used, because the
            scoring layer is shared across pathways.
        temperature: Softmax temperature (lower = sharper weights).
        entropy_reg: Factor applied to the returned mean attention entropy.
    """

    def __init__(self, pathway_dim, num_pathways, temperature=1.0, entropy_reg=0.01):
        super().__init__()
        self.attn_layer = nn.Linear(pathway_dim, 1, bias=False)
        self.temperature = temperature
        self.entropy_reg = entropy_reg

    def forward(self, pathway_embs):
        """Pool a (cells, pathways, dim) tensor.

        Returns:
            (pooled (cells, dim), weights (cells, pathways),
             entropy_reg * mean per-cell attention entropy).
        """
        scores = self.attn_layer(pathway_embs)[:, :, 0]
        weights = torch.softmax(scores / self.temperature, dim=1)
        pooled = (pathway_embs * weights[..., None]).sum(dim=1)
        # 1e-10 guards log(0) for vanishing weights
        per_cell_entropy = -(weights * (weights + 1e-10).log()).sum(dim=1)
        return pooled, weights, per_cell_entropy.mean() * self.entropy_reg


# ---- 8. Helper: Run Attention with multiple temps ----
def run_attention(pathway_embs_tensor, pathway_names, cells,
                  temperatures=(0.33, 0.7, 0.9, 1.2, 1.5), entropy_reg=0.01, device=None):
    """Pool per-pathway cell embeddings with softmax attention at several temperatures.

    One ``PathwayAttention`` projection is created and reused for every
    temperature, so differences between the results come from the
    temperature alone rather than from a different random initialisation.

    NOTE(review): the projection is randomly initialised and only evaluated
    (``eval()`` + ``no_grad``); it is never trained here. The attention
    scores are therefore a random linear function of the embeddings. Seed
    torch beforehand for reproducible weights.

    Args:
        pathway_embs_tensor: (n_cells, n_pathways, dim) tensor.
        pathway_names, cells: Not used; kept for call compatibility.
        temperatures: Iterable of softmax temperatures.
        entropy_reg: Factor applied to the reported mean attention entropy.
        device: Torch device; ``None`` selects CUDA when available, else CPU.

    Returns:
        dict temperature -> {"final_emb": (n_cells, dim) ndarray,
        "attn_weights": (n_cells, n_pathways) ndarray, "entropy_loss": float}.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    results = {}
    pathway_embs_tensor = pathway_embs_tensor.to(device)

    attn_model = PathwayAttention(
        pathway_dim=pathway_embs_tensor.shape[-1],
        num_pathways=pathway_embs_tensor.shape[1],
        temperature=1.0,
        entropy_reg=entropy_reg
    ).to(device)
    attn_model.eval()

    for T in temperatures:
        print(f"\n[INFO] Running attention with temperature={T} on {device}...")
        attn_model.temperature = T
        with torch.no_grad():
            final_emb_tensor, attn_weights_tensor, entropy_loss = attn_model(pathway_embs_tensor)
        results[T] = {
            "final_emb": final_emb_tensor.cpu().numpy(),
            "attn_weights": attn_weights_tensor.cpu().numpy(),
            "entropy_loss": entropy_loss.item()
        }
        print(f"[INFO] Attention done (T={T}). Entropy={entropy_loss.item():.4f}")
    return results

# ---- 8. Leiden and KMeans Clustering, Evaluation ----
def cluster_and_eval_scanpy(final_emb, cells, adata, n_clusters=None, leiden_res_list=None):
    """Cluster embeddings with a Leiden resolution sweep and KMeans, then evaluate.

    Side effects on ``adata``: sets ``obsm['X_gnn']``, builds the neighbour
    graph (``sc.pp.neighbors``) and adds one ``obs['gnn_leiden_<res>']``
    column per resolution. When ``obs['assigned_cluster']`` exists, it prints
    ARI/NMI/AMI/silhouette and shows a KMeans-vs-reference crosstab heatmap.

    Args:
        final_emb: (n_cells, dim) embedding; stored in ``adata.obsm``, so its
            rows must follow ``adata.obs_names``.
        cells: Cell identifiers used to align labels with ``adata.obs``.
        adata: AnnData object.
        n_clusters: KMeans k. ``None`` selects the k in 3..14 (bounded by the
            number of cells) with the highest silhouette. ``"unique"`` uses the
            number of distinct ``assigned_cluster`` labels. Any other value is
            used as given.
        leiden_res_list: Leiden resolutions to try; the best is chosen by silhouette.

    Returns:
        (best_leiden_labels, kmeans_labels, best_res, best_leiden_silhouette,
         kmeans_silhouette)

    Raises:
        ValueError: if ``n_clusters == "unique"`` but no ``assigned_cluster``
            column exists, or there are too few cells for the k sweep.
    """
    print("[INFO] Running clustering (Leiden + KMeans)...")
    wants_unique = isinstance(n_clusters, str) and n_clusters == "unique"
    if wants_unique and 'assigned_cluster' not in adata.obs:
        raise ValueError("n_clusters='unique' requires adata.obs['assigned_cluster'].")

    # Default Leiden resolutions to test
    if leiden_res_list is None:
        leiden_res_list = [0.2, 0.4, 0.6, 0.8, 1.0, 1.2, 1.5, 2.0]

    # Add embeddings to AnnData
    adata.obsm['X_gnn'] = final_emb
    sc.pp.neighbors(adata, use_rep='X_gnn', n_neighbors=15)

    n_samples = len(final_emb)
    leiden_silhouettes = []
    leiden_label_sets = {}

    #  Sweep Leiden resolutions
    for res in leiden_res_list:
        key = f'gnn_leiden_{res}'
        sc.tl.leiden(adata, resolution=res, key_added=key)
        leiden_labels = adata.obs[key].astype(str).reindex(cells).fillna('NA').tolist()
        leiden_labels_num = pd.factorize(leiden_labels)[0]

        # silhouette is only defined for 2 <= n_labels <= n_samples - 1
        n_labels = len(set(leiden_labels_num))
        sil = silhouette_score(final_emb, leiden_labels_num) if 1 < n_labels < n_samples else -1
        leiden_silhouettes.append(sil)
        leiden_label_sets[res] = leiden_labels
        print(f"[INFO] Leiden res={res:.2f}, clusters={len(set(leiden_labels))}, silhouette={sil:.3f}")

    # Pick best Leiden resolution
    best_idx = int(np.argmax(leiden_silhouettes))
    best_res = leiden_res_list[best_idx]
    best_leiden_labels = leiden_label_sets[best_res]
    best_leiden_silhouette = leiden_silhouettes[best_idx]
    print(f"[INFO] Best Leiden resolution={best_res}, silhouette={best_leiden_silhouette:.3f}")

    #  KMeans backup
    if n_clusters is None:
        candidate_ks = list(range(3, min(15, n_samples)))  # silhouette needs k <= n_samples - 1
        if not candidate_ks:
            raise ValueError(f"Need at least 4 cells to choose n_clusters automatically, got {n_samples}.")
        sil_scores = [
            silhouette_score(final_emb, KMeans(n_clusters=k).fit_predict(final_emb))
            for k in candidate_ks
        ]
        n_clusters = candidate_ks[int(np.argmax(sil_scores))]
        print(f"[INFO] Using silhouette-optimal k: n_clusters={n_clusters}")

    elif wants_unique:
        n_clusters = adata.obs['assigned_cluster'].nunique()
        print(f"[INFO] Using unique assigned clusters: n_clusters={n_clusters}")

    kmeans_labels = KMeans(n_clusters=n_clusters).fit_predict(final_emb)
    kmeans_silhouette = silhouette_score(final_emb, kmeans_labels)
    print(f"[INFO] KMeans n={n_clusters}, silhouette={kmeans_silhouette:.3f}")

    #  Evaluation (if reference labels exist)
    if 'assigned_cluster' in adata.obs:
        true_labels = adata.obs.loc[cells, 'assigned_cluster']
        metrics = {
            'Leiden ARI': adjusted_rand_score(true_labels, best_leiden_labels),
            'Leiden NMI': normalized_mutual_info_score(true_labels, best_leiden_labels),
            'Leiden AMI': adjusted_mutual_info_score(true_labels, best_leiden_labels),
            'Leiden Silhouette': best_leiden_silhouette,
            'KMeans ARI': adjusted_rand_score(true_labels, kmeans_labels),
            'KMeans NMI': normalized_mutual_info_score(true_labels, kmeans_labels),
            'KMeans AMI': adjusted_mutual_info_score(true_labels, kmeans_labels),
            'KMeans Silhouette': kmeans_silhouette
        }
        print("[INFO] Clustering metrics:")
        for name, value in metrics.items():
            print(f"  {name}: {value:.3f}")

        # Crosstab (rectangular confusion matrix) of reference vs KMeans labels
        cm_df = pd.crosstab(
            pd.Series(true_labels.astype(str).values, name="True"),
            pd.Series(kmeans_labels.astype(str), name="Predicted"),
        )

        plt.figure(figsize=(12,8))
        sns.heatmap(cm_df, annot=True, fmt="d", cmap="Blues")
        plt.xlabel("Predicted Clusters")
        plt.ylabel("True Labels (Assigned Clusters)")
        plt.title("Confusion Matrix : True vs Predicted")
        plt.tight_layout()
        plt.show()

    else:
        print("[INFO] No reference cluster labels found.")

    return best_leiden_labels, kmeans_labels, best_res, best_leiden_silhouette, kmeans_silhouette

# ---- 9. Pathway Attribution per Cluster ----
def _welch_pvalue(a, b):
    """Two-sided Welch t-test p-value, or 1.0 when the test is undefined."""
    if len(a) < 2 or len(b) < 2:
        return 1.0  # variance (ddof=1) undefined for a group of size < 2
    try:
        _, pval = ttest_ind(a, b, equal_var=False)
    except (ValueError, ZeroDivisionError, FloatingPointError):
        return 1.0
    pval = float(pval)
    return pval if np.isfinite(pval) else 1.0  # e.g. zero variance in both groups


def pathway_attribution(attn_weights, pathway_names, cluster_labels, cells, top_n=5):
    """Rank pathways by mean attention per cluster and test their enrichment.

    For each cluster (sorted label order), pathways are ranked by the mean
    attention weight of the cluster's cells. The top ``top_n`` are reported
    with a Welch t-test p-value comparing in-cluster versus out-of-cluster
    attention. p-values that cannot be computed (single cluster, tiny groups,
    zero variance) are reported as 1.0 instead of NaN.

    Args:
        attn_weights: (n_cells, n_pathways) array of attention weights.
        pathway_names: Column labels for ``attn_weights``.
        cluster_labels: One label per cell, aligned with ``cells``.
        cells: Cell identifiers (row labels of ``attn_weights``).
        top_n: Number of pathways reported per cluster.

    Returns:
        dict cluster -> list of {'pathway', 'importance', 'pval', 'significance'}
        where significance is '***' (<0.001), '**' (<0.01), '*' (<0.05) or ''.
    """
    print("[INFO] Calculating pathway attribution per cluster...")
    attention_df = pd.DataFrame(attn_weights, index=cells, columns=pathway_names)
    labels = np.asarray(cluster_labels)
    results = {}
    for cluster in sorted(set(cluster_labels)):
        in_cluster = labels == cluster
        cluster_df = attention_df[in_cluster]
        other_df = attention_df[~in_cluster]
        cluster_attention = cluster_df.mean()

        results[cluster] = []
        # Only the reported pathways need a significance test.
        for pathway in cluster_attention.sort_values(ascending=False).head(top_n).index:
            pval = _welch_pvalue(cluster_df[pathway], other_df[pathway])
            significance = (
                "***" if pval < 0.001 else
                "**" if pval < 0.01 else
                "*" if pval < 0.05 else ""
            )
            results[cluster].append({
                "pathway": pathway,
                "importance": cluster_attention[pathway],
                "pval": pval,
                "significance": significance
            })
    print("[INFO] Pathway attribution completed.")
    return results

def print_cluster_pathways(cluster_results, cluster_labels, cells):
    """Print the per-cluster pathway attribution table.

    ``cluster_results`` maps a cluster label to a list of dicts with keys
    'pathway', 'importance', 'pval' and 'significance'. The cell count per
    cluster is taken from ``cluster_labels`` paired element-wise with ``cells``.
    """
    for cluster in cluster_results:
        n_cells = sum(1 for _cell, label in zip(cells, cluster_labels) if label == cluster)
        print(f"\nCluster {cluster} ({n_cells} cells):")
        rows = (
            f"  {entry['pathway']}: {entry['importance']:.3f} "
            f"(p={entry['pval']:.3f}) {entry['significance']}"
            for entry in cluster_results[cluster]
        )
        for row in rows:
            print(row)

# ---- 10. Visualizations ----
def _scatter_2d(coords, labels, title, legend_title, xlabel, ylabel):
    """Show one 2-D scatter plot of ``coords`` coloured by ``labels``."""
    plt.figure(figsize=(6,5))
    sns.scatterplot(x=coords[:,0], y=coords[:,1], hue=labels, palette='tab20', s=15)
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.legend(title=legend_title)
    plt.tight_layout()
    plt.show()


def visualize(final_emb, cluster_labels, attn_weights, pathway_names, cells, leiden_labels=None):
    """Plot PCA/UMAP projections of the embeddings and a pathway-attention heatmap.

    Figures:
      1. PCA coloured by ``cluster_labels``.
      2. UMAP coloured by ``leiden_labels`` (only when provided).
      3. UMAP coloured by ``cluster_labels`` (drawn once).
      4. Heatmap of mean attention per cluster for the (up to) 20 pathways
         with the highest average attention across clusters.

    Args:
        final_emb: (n_cells, dim) embedding.
        cluster_labels: One label per cell (e.g. KMeans labels).
        attn_weights: (n_cells, n_pathways) attention weights.
        pathway_names: Pathway labels matching the columns of ``attn_weights``.
        cells: Cell identifiers matching the rows.
        leiden_labels: Optional second labelling shown on its own UMAP.
    """
    print("[INFO] Generating visualizations (PCA, UMAP, heatmap)...")
    # PCA
    pca = PCA(n_components=2)
    pca_proj = pca.fit_transform(final_emb)
    _scatter_2d(pca_proj, cluster_labels,
                f'PCA of Cell Embeddings (Var explained: {pca.explained_variance_ratio_.sum():.2f})',
                "Cluster", "PCA1", "PCA2")

    # UMAP (computed once, reused for both colourings)
    reducer = umap.UMAP(n_components=2, random_state=42)
    umap_proj = reducer.fit_transform(final_emb)
    if leiden_labels is not None:
        _scatter_2d(umap_proj, leiden_labels, "UMAP of Cell Embeddings (Leiden)",
                    "Leiden", "UMAP1", "UMAP2")
    _scatter_2d(umap_proj, cluster_labels, "UMAP of Cell Embeddings",
                "Cluster", "UMAP1", "UMAP2")

    # Pathway-cluster heatmap
    attention_df = pd.DataFrame(attn_weights, index=cells, columns=pathway_names)
    labels = np.asarray(cluster_labels)
    clusters = sorted(set(cluster_labels))
    cluster_attention_matrix = np.array(
        [attention_df[labels == cluster].mean().to_numpy() for cluster in clusters]
    )
    top_k = min(20, len(pathway_names))
    avg_attention = cluster_attention_matrix.mean(axis=0)
    top_k_indices = np.argsort(avg_attention)[-top_k:]
    heatmap_matrix = cluster_attention_matrix[:, top_k_indices]
    heatmap_labels = [pathway_names[i][:30] for i in top_k_indices]  # truncate long names
    plt.figure(figsize=(10,6))
    sns.heatmap(heatmap_matrix, xticklabels=heatmap_labels,
                yticklabels=[f"Cluster {c}" for c in clusters], cmap='viridis')
    plt.title("Top Pathway Attention per Cluster")
    plt.xlabel("Pathway")
    plt.ylabel("Cluster")
    plt.tight_layout()
    plt.show()
    print("[INFO] Visualizations completed.")

# ---- 9. Pipeline ----
def _load_expression_adata(expr_csv, metadata_cols, rare_cell_thresh):
    """Load a cells-by-columns CSV into AnnData, de-duplicate names, drop rare labels."""
    print("[INFO] Loading expression matrix...")
    df = pd.read_csv(expr_csv, index_col=0)
    metadata = df[metadata_cols].copy()
    expr_df = df.drop(columns=metadata_cols).apply(pd.to_numeric, errors='coerce').fillna(0)
    adata = sc.AnnData(expr_df)
    adata.obs = metadata

    # Fix the duplicate issue
    print(f"[INFO] Before fixing: {adata.obs_names.duplicated().sum()} duplicate cell names")
    print(f"[INFO] Before fixing: {adata.var_names.duplicated().sum()} duplicate gene names")
    if adata.obs_names.duplicated().any():
        print("[WARNING] Found duplicate cell names. Making unique...")
        adata.obs_names_make_unique()
    if adata.var_names.duplicated().any():
        print("[WARNING] Found duplicate gene names. Making unique...")
        adata.var_names_make_unique()
    print(f"[INFO] After fixing: {adata.obs_names.duplicated().sum()} duplicate cell names")
    print(f"[INFO] After fixing: {adata.var_names.duplicated().sum()} duplicate gene names")

    # Remove rare cell clusters
    if 'assigned_cluster' in adata.obs:
        counts = adata.obs['assigned_cluster'].value_counts()
        keep = counts[counts >= rare_cell_thresh].index
        before = adata.n_obs
        # .copy(): downstream preprocessing modifies the object in place, so avoid a view
        adata = adata[adata.obs['assigned_cluster'].isin(keep)].copy()
        print(f"[INFO] Removed rare cell types (<{rare_cell_thresh} cells). {before} → {adata.n_obs} cells kept.")
    return adata


def _parse_pathway_gmt(pathway_file, genes_in_data, min_genes, max_genes):
    """Parse tab-separated pathway lines (name, description, gene1, gene2, ...).

    Genes not in ``genes_in_data`` are dropped. A pathway is kept only if
    ``min_genes <= n_remaining <= max_genes``.
    """
    print("[INFO] Parsing pathway file...")
    pathway_genes = {}
    with open(pathway_file, 'r', encoding='utf-8') as f:
        for line in f:
            parts = line.strip().split('\t')
            pathway, genes = parts[0], parts[2:]
            filtered = [g for g in genes if g in genes_in_data]
            if min_genes <= len(filtered) <= max_genes:
                pathway_genes[pathway] = filtered
    print(f"[INFO] Pathway parsing completed. Filtered pathways: {len(pathway_genes)}")
    return pathway_genes


def _report_marker_genes(adata, cluster_labels, top_n_marker_genes):
    """Rank marker genes per cluster (Wilcoxon), plot them and print the top hits."""
    print("[INFO] Running marker gene analysis with Scanpy (Wilcoxon test)...")
    adata.obs["cluster"] = pd.Categorical(np.asarray(cluster_labels).astype(str))
    # NOTE(review): adata.X is scaled (z-scored) at this point, so the reported
    # 'logfoldchanges' are not fold changes of expression; interpret with care.
    sc.tl.rank_genes_groups(adata, groupby="cluster", method="wilcoxon", use_raw=False)
    sc.pl.rank_genes_groups(adata, n_genes=top_n_marker_genes, sharey=False)
    markers_df = sc.get.rank_genes_groups_df(adata, group=None)
    for c in sorted(adata.obs["cluster"].unique()):
        top_genes = markers_df[markers_df["group"] == c].head(top_n_marker_genes)
        print(f"\nCluster {c} top marker genes:")
        for _, row in top_genes.iterrows():
            sig = (
                "***" if row["pvals_adj"] < 0.001 else
                "**" if row["pvals_adj"] < 0.01 else
                "*" if row["pvals_adj"] < 0.05 else ""
            )
            print(f"  {row['names']}: logFC={row['logfoldchanges']:.2f}, adj.p={row['pvals_adj']:.3g} {sig}")


def run_pipeline(
    expr_csv, pathway_file, metadata_cols=None,
    regress_vars=None, hvg_n=2000, min_genes=10, max_genes=300, k=5,
    model_type='sage', hidden_dim=32, out_dim=16, epochs=100,
    early_stop_patience=15, early_stop_delta=1e-4,
    top_n_pathways=5, top_n_marker_genes=5, n_clusters=None, rare_cell_thresh=50,
    temperatures=None, entropy_reg=0.01, device=None
):
    """Run the full pathway-aware GNN workflow end to end.

    Steps: load CSV -> HVG/scale preprocessing -> pathway parsing -> per-pathway
    bipartite graphs -> per-pathway contrastive GNNs -> attention pooling (one
    run per temperature) -> clustering/evaluation -> pathway attribution ->
    marker genes -> plots.

    Args:
        expr_csv: CSV with cells as rows (first column = index), gene columns,
            plus the ``metadata_cols`` columns.
        pathway_file: Tab-separated pathway file (name, description, genes...).
        metadata_cols: Non-expression columns; default ``['barcode', 'assigned_cluster']``.
        temperatures: Attention temperatures; default ``[0.7]``.
        device: Torch device used for training and attention; ``None`` selects
            CUDA when available, otherwise CPU.
        Other arguments are forwarded to the corresponding steps.
    """
    if metadata_cols is None:
        metadata_cols = ['barcode', 'assigned_cluster']
    if temperatures is None:
        temperatures = [0.7]
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    adata = _load_expression_adata(expr_csv, metadata_cols, rare_cell_thresh)
    adata = preprocess_sctransform(adata, regress_vars=regress_vars, hvg_n=hvg_n)
    pathway_genes = _parse_pathway_gmt(pathway_file, set(adata.var_names), min_genes, max_genes)

    # Graph construction
    graphs = construct_pathway_bipartite_graphs(adata, pathway_genes, k=k)
    expr = pd.DataFrame(adata.X, index=adata.obs_names, columns=adata.var_names)

    # ---- Train GNNs once (on the requested device) ----
    pathway_embs_tensor, pathway_names, cells = train_pathway_gnns(
        graphs, expr, pathway_genes, model_type=model_type,
        hidden_dim=hidden_dim, out_dim=out_dim, epochs=epochs,
        early_stop_patience=early_stop_patience, early_stop_delta=early_stop_delta,
        device=device
    )

    # ---- Run Attention with multiple temperatures ----
    attn_results = run_attention(pathway_embs_tensor, pathway_names, cells,
                                 temperatures=temperatures, entropy_reg=entropy_reg, device=device)

    # ---- Cluster + Evaluate for each temp ----
    for T, res in attn_results.items():
        print(f"\n====== Results for Temperature={T} ======")
        leiden_labels, kmeans_labels, *_ = cluster_and_eval_scanpy(
            res["final_emb"], cells, adata, n_clusters=n_clusters
        )
        cluster_results = pathway_attribution(res["attn_weights"], pathway_names, kmeans_labels, cells, top_n=top_n_pathways)
        print_cluster_pathways(cluster_results, kmeans_labels, cells)
        _report_marker_genes(adata, kmeans_labels, top_n_marker_genes)
        visualize(res["final_emb"], kmeans_labels, res["attn_weights"], pathway_names, cells, leiden_labels)

# Example usage:
# Replace both placeholder paths before running: an expression CSV containing
# the metadata columns expected by run_pipeline, and a tab-separated pathway file.
run_pipeline(
    'your_scRNA seq_csv_dataset',
    'your_pathway_dataset',
    model_type='gcn',
    epochs=150,
    early_stop_patience=15,
    n_clusters="unique",
)

**Simple Baseline Bipartite**



In [None]:
# ---- DEVICE SETUP ----
# Same device choice as the main method: GPU when visible to PyTorch, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("[INFO] Using device:", device)

# ---- 1. Same preprocessing as main method ----
def preprocess_sctransform(adata, regress_vars=None, hvg_n=2000):
    """Select highly variable genes, optionally regress covariates, then scale.

    This mirrors the main method's preprocessing so the baseline sees
    identical inputs. It uses standard Scanpy steps and does not fit
    SCTransform's regularised negative-binomial model.

    Args:
        adata: AnnData object. Scanpy documents that ``flavor='seurat_v3'``
            HVG selection expects count data in ``adata.X``.
        regress_vars: Optional ``obs`` key(s) passed to ``sc.pp.regress_out``.
        hvg_n: Number of highly variable genes to keep.

    Returns:
        A new (non-view) AnnData restricted to the selected genes, with scaled ``X``.
    """
    print("[INFO] Preprocessing with SCTransform-like workflow (highly variable genes + regression)...")
    # HVG selection writes into .var; materialise a view first to avoid an implicit copy.
    if adata.is_view:
        adata = adata.copy()
    sc.pp.highly_variable_genes(adata, flavor='seurat_v3', n_top_genes=hvg_n)
    # regress_out/scale modify X in place; operate on a real object, not a view.
    adata = adata[:, adata.var['highly_variable']].copy()
    if regress_vars is not None:
        sc.pp.regress_out(adata, regress_vars)
    sc.pp.scale(adata)
    print("[INFO] Preprocessing completed.")
    return adata

# ---- 2. Simple Bipartite Graph Construction (NO pathway grouping) ----
def construct_simple_bipartite_graph(
    adata,
    k=5,
    min_expression=0.25,
    gene_percentile=75,
    min_edges_per_gene=2,
    fallback_topk=True
):
    """
    Build a SINGLE bipartite graph (genes ↔ cells) using ALL genes together
    WITHOUT pathway-specific grouping. This is the fair baseline.

    Uses the same edge construction logic as the pathway-aware method:
    - an edge gene↔cell when the value exceeds
      max(gene_percentile-th percentile of the gene's positive values, min_expression);
    - if there are fewer than len(genes) * min_edges_per_gene edges and
      fallback_topk is set, each cell is also linked to its k highest positive genes.
    It treats all genes as one big "pathway" to isolate the effect of pathway information.

    Args:
        adata: AnnData; ``X`` may be dense or scipy-sparse.
        k, min_expression, gene_percentile, min_edges_per_gene, fallback_topk:
            Edge-construction parameters (see above).

    Returns:
        (G, genes, cells): the ``nx.Graph`` (node attribute ``node_type``, edge
        attribute ``weight`` = raw value), plus the gene and cell name lists.
    """
    print("[INFO] Building simple bipartite graph (all genes, no pathway grouping)...")
    X = adata.X.toarray() if hasattr(adata.X, "toarray") else np.asarray(adata.X)

    genes = list(adata.var_names)
    cells = list(adata.obs_names)

    G = nx.Graph()
    G.add_nodes_from(genes, bipartite=0, node_type='gene')
    G.add_nodes_from(cells, bipartite=1, node_type='cell')

    # ---- selective edges by threshold (same logic as pathway method) ----
    for gi, gene in enumerate(genes):
        gene_vals = X[:, gi]
        positive = gene_vals[gene_vals > 0]
        # percentile over strictly positive values; no positives -> no edges
        per_gene_thr = np.percentile(positive, gene_percentile) if positive.size else float('inf')
        thr = max(per_gene_thr, min_expression)
        # `> 0` keeps weights positive even if min_expression is set negative
        for ci in np.flatnonzero((gene_vals > thr) & (gene_vals > 0)):
            G.add_edge(gene, cells[ci], weight=float(gene_vals[ci]))

    edges_added = G.number_of_edges()

    # ---- density check + fallback ----
    needed_edges = len(genes) * min_edges_per_gene
    if edges_added < needed_edges and fallback_topk:
        print(f"[INFO] Graph sparse ({edges_added} edges), applying top-k fallback...")
        # Ranking raw values within a cell; only positive entries can be picked.
        for i, cell in enumerate(cells):
            row = X[i]
            k_eff = min(k, int((row > 0).sum()))
            if k_eff <= 0:
                continue
            for gi in np.argpartition(row, -k_eff)[-k_eff:]:
                gene = genes[gi]
                if not G.has_edge(gene, cell):
                    G.add_edge(gene, cell, weight=float(row[gi]))
        edges_added = G.number_of_edges()

    print(f"[INFO] Bipartite graph constructed: {len(genes)} genes, {len(cells)} cells, {edges_added} edges")
    return G, genes, cells

# ---- 3. Node Features for Bipartite Graph (same as pathway method) ----
def node_features_bipartite_simple(G, expr, n_pca=5):
    """
    Build a dense node-feature matrix for a gene/cell bipartite graph.

    Same layout as the pathway method, but without pathway-specific
    information. Every row has width ``5 + pca_dim``:

    * gene rows: ``[mean, std, degree, 1, 0]`` followed by ``pca_dim`` zeros
    * cell rows: ``[mean, std, degree, 0, 1]`` followed by the cell's PCA coordinates

    Gene rows come first (in ``G.nodes`` order), then cell rows.

    Args:
        G: networkx graph whose nodes carry a ``node_type`` attribute equal to
            ``'gene'`` or ``'cell'``.
        expr: DataFrame indexed by cell names with gene names as columns.
        n_pca: maximum number of PCA components used for cell features.

    Returns:
        ``(x, genes, cells)``. ``x`` is a float32 array with one row per gene
        followed by one row per cell. ``genes`` and ``cells`` are the node
        lists in row order.
    """
    genes = [n for n, d in G.nodes(data=True) if d['node_type'] == 'gene']
    cells = [n for n, d in G.nodes(data=True) if d['node_type'] == 'cell']

    # PCA profile for cells. scikit-learn requires
    # n_components <= min(n_samples, n_features), so clamp by BOTH the number
    # of genes (features) and the number of cells (samples).
    if len(genes) > 1 and len(cells) > 1:
        pca = PCA(n_components=min(n_pca, len(genes), len(cells)))
        cell_pca = pca.fit_transform(expr.loc[cells, genes])
        pca_dim = cell_pca.shape[1]
    else:
        cell_pca = np.zeros((len(cells), n_pca))
        pca_dim = n_pca

    features = []
    # ---- Genes ----
    for gene in genes:
        vals = expr[gene] if gene in expr.columns else np.zeros(len(cells))
        # Base features: [mean, std, degree, type_gene, type_cell]
        vec = [np.mean(vals), np.std(vals), G.degree(gene), 1, 0]
        # Genes get no PCA profile; zero-pad so all rows share one width.
        vec += [0.0] * pca_dim
        features.append(vec)

    # ---- Cells ----
    for i, cell in enumerate(cells):
        vals = expr.loc[cell, genes] if cell in expr.index else np.zeros(len(genes))
        vec = [np.mean(vals), np.std(vals), G.degree(cell), 0, 1]
        vec += list(cell_pca[i])
        features.append(vec)

    x = np.array(features, dtype=np.float32)
    return x, genes, cells

# ---- 4. Simple GNN Model ----
class SimpleGNN(nn.Module):
    """
    Two-layer graph encoder used by the pathway-free bipartite baseline.

    ``forward`` computes ``conv2(dropout(relu(conv1(x))))`` for every node.

    Args:
        in_dim: node-feature width.
        hidden_dim: width after the first convolution.
        out_dim: embedding width after the second convolution.
        gnn_type: ``'gcn'``, ``'sage'`` or ``'gat'`` (case-insensitive).
        heads: attention heads per GAT layer. Heads are averaged
            (``concat=False``), so layer widths stay ``hidden_dim`` and
            ``out_dim``. Ignored for GCN/SAGE.
        dropout: dropout probability applied between the two layers.
            Defaults to 0.2, the value that was previously hard-coded.

    Raises:
        ValueError: if ``gnn_type`` is not one of the supported names.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, gnn_type='gcn', heads=2, dropout=0.2):
        super().__init__()
        self.gnn_type = gnn_type.lower()
        if self.gnn_type == 'gcn':
            self.conv1 = GCNConv(in_dim, hidden_dim)
            self.conv2 = GCNConv(hidden_dim, out_dim)
        elif self.gnn_type == 'sage':
            self.conv1 = SAGEConv(in_dim, hidden_dim)
            self.conv2 = SAGEConv(hidden_dim, out_dim)
        elif self.gnn_type == 'gat':
            self.conv1 = GATConv(in_dim, hidden_dim, heads=heads, concat=False)
            self.conv2 = GATConv(hidden_dim, out_dim, heads=heads, concat=False)
        else:
            raise ValueError("gnn_type must be one of: 'gcn', 'sage', 'gat'")

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index):
        """Return one embedding row per node; ``edge_index`` uses PyG's edge_index format."""
        x = F.relu(self.conv1(x, edge_index))
        x = self.dropout(x)
        x = self.conv2(x, edge_index)
        return x

# ---- 5. Contrastive Loss (same as main method) ----
def contrastive_loss_dropout(cell_emb, drop_prob=0.2, temperature=0.5):
    """
    NT-Xent contrastive loss between embeddings and a feature-dropout view.

    A second view is built by zeroing each embedding entry independently with
    probability ``drop_prob``. Both views are L2-normalised and stacked into
    ``2N`` rows. Each row's positive is its counterpart in the other view
    (index ``i + N`` mod ``2N``). Every other row except itself is a negative.

    Earlier versions appended the full off-diagonal row as "negatives", so the
    positive pair was counted a second time in the denominator. Here it
    appears exactly once.

    Args:
        cell_emb: ``(N, D)`` tensor of embeddings.
        drop_prob: probability of dropping each entry in the augmented view.
        temperature: softmax temperature applied to cosine similarities.

    Returns:
        ``(loss_value, loss)``: the loss as a Python float and the
        differentiable scalar tensor.
    """
    mask = (torch.rand_like(cell_emb) > drop_prob).float()
    cell_emb_aug = cell_emb * mask
    N = cell_emb.shape[0]
    z1 = F.normalize(cell_emb, dim=1)
    z2 = F.normalize(cell_emb_aug, dim=1)
    representations = torch.cat([z1, z2], dim=0)
    logits = torch.matmul(representations, representations.T) / temperature
    # A row must never be scored against itself.
    self_mask = torch.eye(2 * N, dtype=torch.bool, device=cell_emb.device)
    logits = logits.masked_fill(self_mask, -9e15)
    # Row i's positive is the same sample in the other view.
    targets = (torch.arange(2 * N, device=cell_emb.device) + N) % (2 * N)
    loss = F.cross_entropy(logits, targets)
    return loss.item(), loss

# ---- 6. Training (modified for bipartite graph) ----
def train_simple_bipartite_gnn(G, expr, model_type='sage', hidden_dim=32, out_dim=16,
                               epochs=100, early_stop_patience=15, early_stop_delta=1e-4, device="cuda"):
    """
    Train ``SimpleGNN`` on a gene/cell bipartite graph with a contrastive loss.

    Node features come from ``node_features_bipartite_simple``. Gene rows come
    first and cell rows after, so cell embeddings are ``emb[len(genes):]``.
    Training uses Adam (lr 0.01). It stops early once the loss has not
    improved by more than ``early_stop_delta`` for more than
    ``early_stop_patience`` consecutive epochs.

    PyG treats ``edge_index`` as directed (source -> target). An undirected
    networkx graph lists each edge once, so the reverse direction is added.
    Without it, messages would flow only one way between genes and cells.

    Args:
        G: networkx gene/cell graph (nodes carry ``node_type``).
        expr: DataFrame indexed by cell names with gene columns.
        model_type: ``'gcn'``, ``'sage'`` or ``'gat'``.
        hidden_dim, out_dim: encoder widths.
        epochs: maximum number of epochs.
        early_stop_patience, early_stop_delta: early-stopping settings.
        device: torch device. Falls back to CPU when CUDA is requested but
            not available.

    Returns:
        ``(cell_emb, cells)``: a NumPy array of cell embeddings in ``cells``
        order, or ``(None, cells)`` when the graph has no usable edges.
    """
    print(f"[INFO] Training simple bipartite GNN ({model_type}) with contrastive loss...")

    if str(device).startswith("cuda") and not torch.cuda.is_available():
        print("[INFO] CUDA not available; using CPU instead.")
        device = "cpu"

    # Get node features
    x, genes, cells = node_features_bipartite_simple(G, expr, n_pca=5)

    all_nodes = genes + cells
    node_to_idx = {n: i for i, n in enumerate(all_nodes)}

    # Convert graph to PyTorch format
    edges = [(node_to_idx[a], node_to_idx[b]) for a, b in G.edges() if a in node_to_idx and b in node_to_idx]
    if not edges:
        print("[ERROR] No edges in graph!")
        return None, cells

    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    if not G.is_directed():
        # Make message passing bidirectional (gene <-> cell).
        edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)
    edge_index = edge_index.to(device)
    x_tensor = torch.tensor(x, dtype=torch.float32).to(device)

    # Model setup
    model = SimpleGNN(x.shape[1], hidden_dim, out_dim, gnn_type=model_type).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    best_loss = float("inf")
    patience = 0

    model.train()
    for epoch in range(epochs):
        optimizer.zero_grad()
        emb = model(x_tensor, edge_index)

        # Extract cell embeddings (same as pathway method)
        cell_emb = emb[len(genes):]

        loss_val, loss = contrastive_loss_dropout(cell_emb, drop_prob=0.2, temperature=0.5)
        loss.backward()
        optimizer.step()

        if loss_val + early_stop_delta < best_loss:
            best_loss = loss_val
            patience = 0
        else:
            patience += 1

        if patience > early_stop_patience:
            print(f"[INFO] Early stopped at epoch {epoch+1} (best loss: {best_loss:.4f})")
            break

        if (epoch+1) % 50 == 0 or epoch == 0:
            print(f"[INFO] Epoch {epoch+1}, contrastive loss: {loss_val:.4f}")

    # Get final cell embeddings (dropout disabled)
    model.eval()
    with torch.no_grad():
        final_emb = model(x_tensor, edge_index)
        cell_emb = final_emb[len(genes):].cpu().numpy()

    print(f"[INFO] Training completed. Cell embeddings shape: {cell_emb.shape}")
    return cell_emb, cells

# ---- 7. Clustering and Evaluation (same as main method) ----
def cluster_and_eval_simple(final_emb, cells, adata, n_clusters=None, leiden_res_list=None):
    """
    Cluster GNN cell embeddings with Leiden and KMeans and report metrics.

    The embeddings are stored in ``adata.obsm['X_simple_gnn']`` and a 15-NN
    graph is built from them. Leiden runs at each resolution in
    ``leiden_res_list``; the resolution with the highest silhouette score is
    kept.

    KMeans uses ``n_clusters``:

    * ``None``: pick k in 3..14 by silhouette.
    * ``"unique"``: use the number of distinct ``adata.obs['assigned_cluster']`` labels.
    * an int: use it as given.

    KMeans is seeded with ``random_state=42`` (as in the MLP baseline) so
    results are reproducible. When ``assigned_cluster`` exists, ARI/NMI/AMI
    are printed and a KMeans confusion matrix is plotted.

    NOTE(review): the ``obsm`` assignment assumes ``final_emb`` rows follow
    ``adata.obs_names`` order. Confirm that ``cells`` matches it.

    Returns:
        ``(best_leiden_labels, kmeans_labels, best_res,
        best_leiden_silhouette, kmeans_silhouette)``
    """
    print("[INFO] Running clustering (Leiden + KMeans)...")

    if leiden_res_list is None:
        leiden_res_list = [0.2, 0.4, 0.6, 0.8, 1.0, 1.2, 1.5, 2.0]

    adata.obsm['X_simple_gnn'] = final_emb
    sc.pp.neighbors(adata, use_rep='X_simple_gnn', n_neighbors=15)

    leiden_silhouettes = []
    leiden_label_sets = {}

    for res in leiden_res_list:
        sc.tl.leiden(adata, resolution=res, key_added=f'simple_leiden_{res}')
        leiden_labels = adata.obs[f'simple_leiden_{res}'].astype(str).reindex(cells).fillna('NA').tolist()
        leiden_labels_num = pd.factorize(leiden_labels)[0]

        # Silhouette is undefined for a single cluster; -1 ranks it last.
        if len(set(leiden_labels_num)) > 1:
            sil = silhouette_score(final_emb, leiden_labels_num)
        else:
            sil = -1
        leiden_silhouettes.append(sil)
        leiden_label_sets[res] = leiden_labels
        print(f"[INFO] Leiden res={res:.2f}, clusters={len(set(leiden_labels))}, silhouette={sil:.3f}")

    best_res = leiden_res_list[np.argmax(leiden_silhouettes)]
    best_leiden_labels = leiden_label_sets[best_res]
    best_leiden_silhouette = max(leiden_silhouettes)

    # KMeans (seeded for reproducibility)
    if n_clusters is None:
        sil_scores = []
        for k in range(3, 15):
            labels = KMeans(n_clusters=k, random_state=42).fit_predict(final_emb)
            sil_scores.append(silhouette_score(final_emb, labels))
        n_clusters = np.argmax(sil_scores) + 3
    elif n_clusters == "unique":
        n_clusters = adata.obs['assigned_cluster'].nunique()

    kmeans_labels = KMeans(n_clusters=n_clusters, random_state=42).fit_predict(final_emb)
    kmeans_silhouette = silhouette_score(final_emb, kmeans_labels)

    # Evaluation against ground-truth labels, when present
    if 'assigned_cluster' in adata.obs:
        true_labels = adata.obs.loc[cells, 'assigned_cluster']
        metrics = {
            'Leiden ARI': adjusted_rand_score(true_labels, best_leiden_labels),
            'Leiden NMI': normalized_mutual_info_score(true_labels, best_leiden_labels),
            'Leiden AMI': adjusted_mutual_info_score(true_labels, best_leiden_labels),
            'Leiden Silhouette': best_leiden_silhouette,
            'KMeans ARI': adjusted_rand_score(true_labels, kmeans_labels),
            'KMeans NMI': normalized_mutual_info_score(true_labels, kmeans_labels),
            'KMeans AMI': adjusted_mutual_info_score(true_labels, kmeans_labels),
            'KMeans Silhouette': kmeans_silhouette
        }
        print("[INFO] Simple GNN Clustering metrics:")
        for name, value in metrics.items():
            print(f"  {name}: {value:.3f}")

        # Confusion matrix
        true_labels = adata.obs.loc[cells, 'assigned_cluster'].astype(str).values
        pred_labels = kmeans_labels.astype(str)
        cm_df = pd.crosstab(pd.Series(true_labels, name="True"), pd.Series(pred_labels, name="Predicted"))
        plt.figure(figsize=(12,8))
        sns.heatmap(cm_df, annot=True, fmt="d", cmap="Blues")
        plt.xlabel("Predicted Clusters")
        plt.ylabel("True Labels")
        plt.title("Simple GNN: Confusion Matrix")
        plt.tight_layout()
        plt.show()

    return best_leiden_labels, kmeans_labels, best_res, best_leiden_silhouette, kmeans_silhouette

# ---- 8. Visualization ----
def visualize_simple(final_emb, cluster_labels, cells, leiden_labels=None):
    """
    Plot 2-D PCA and UMAP projections of the embeddings, coloured by cluster.

    Labels are cast to strings so seaborn treats them as categories, giving a
    discrete palette and one legend entry per cluster. Integer KMeans labels
    would otherwise be mapped as a continuous numeric hue with an abbreviated
    legend.

    Args:
        final_emb: embedding matrix, one row per cell.
        cluster_labels: KMeans labels, one per row.
        cells: accepted for interface compatibility; not used.
        leiden_labels: optional Leiden labels; if given, an extra UMAP plot is drawn.
    """
    print("[INFO] Generating visualizations (PCA, UMAP)...")

    def _categorical(labels):
        # String labels plus a legend order that sorts integer-like labels numerically.
        as_str = np.asarray(labels).astype(str)
        order = sorted(pd.unique(as_str), key=lambda s: (0, int(s), s) if s.isdigit() else (1, 0, s))
        return as_str, order

    kmeans_hue, kmeans_order = _categorical(cluster_labels)

    # PCA
    pca = PCA(n_components=2)
    pca_proj = pca.fit_transform(final_emb)
    plt.figure(figsize=(6,5))
    sns.scatterplot(x=pca_proj[:,0], y=pca_proj[:,1], hue=kmeans_hue, hue_order=kmeans_order, palette='tab20', s=15)
    plt.title(f'Simple GNN: PCA (Var explained: {pca.explained_variance_ratio_.sum():.2f})')
    plt.xlabel("PCA1")
    plt.ylabel("PCA2")
    plt.legend(title="Cluster")
    plt.tight_layout()
    plt.show()

    # UMAP
    reducer = umap.UMAP(n_components=2, random_state=42)
    umap_proj = reducer.fit_transform(final_emb)

    if leiden_labels is not None:
        leiden_hue, leiden_order = _categorical(leiden_labels)
        plt.figure(figsize=(6,5))
        sns.scatterplot(x=umap_proj[:,0], y=umap_proj[:,1], hue=leiden_hue, hue_order=leiden_order, palette='tab20', s=15)
        plt.title("Simple GNN: UMAP (Leiden)")
        plt.xlabel("UMAP1")
        plt.ylabel("UMAP2")
        plt.legend(title="Leiden")
        plt.tight_layout()
        plt.show()

    plt.figure(figsize=(6,5))
    sns.scatterplot(x=umap_proj[:,0], y=umap_proj[:,1], hue=kmeans_hue, hue_order=kmeans_order, palette='tab20', s=15)
    plt.title("Simple GNN: UMAP (KMeans)")
    plt.xlabel("UMAP1")
    plt.ylabel("UMAP2")
    plt.legend(title="Cluster")
    plt.tight_layout()
    plt.show()

# ---- 9. Main Pipeline (updated for bipartite graph) ----
def run_simple_baseline_pipeline(
    expr_csv, metadata_cols=None,
    regress_vars=None, hvg_n=2000,
    k=5, min_expression=0.25, gene_percentile=75,  # same params as pathway method
    model_type='sage', hidden_dim=32, out_dim=16, epochs=100,
    early_stop_patience=15, early_stop_delta=1e-4,
    top_n_marker_genes=5, n_clusters=None, rare_cell_thresh=40,
    device="cuda"
):
    """
    Run the pathway-free bipartite GNN baseline end to end.

    Steps:

    1. Read ``expr_csv`` (first column is the index) and split
       ``metadata_cols`` into ``adata.obs``. Remaining columns are coerced to
       numbers, with NaN replaced by 0.
    2. Make names unique and drop clusters with fewer than
       ``rare_cell_thresh`` cells.
    3. Preprocess, build the bipartite graph and train the GNN.
    4. Cluster and evaluate, rank marker genes (Wilcoxon on KMeans clusters)
       and plot.

    Args:
        expr_csv: path to the expression CSV.
        metadata_cols: metadata column names. ``None`` (default) means
            ``['barcode', 'assigned_cluster']``. The None sentinel avoids a
            shared mutable default list.
        device: torch device string. Falls back to ``"cpu"`` when CUDA is
            requested but not available.
        The remaining arguments are forwarded to the corresponding steps.

    Returns:
        None. Results are printed and plotted.
    """
    if metadata_cols is None:
        metadata_cols = ['barcode', 'assigned_cluster']
    if str(device).startswith("cuda") and not torch.cuda.is_available():
        print("[INFO] CUDA not available; using CPU instead.")
        device = "cpu"

    print("[INFO] === SIMPLE BIPARTITE GNN BASELINE (No Pathway Information) ===")
    print("[INFO] Loading expression matrix...")

    df = pd.read_csv(expr_csv, index_col=0)
    # Copy so later obs edits never touch a slice of df.
    metadata = df[metadata_cols].copy()
    expr_df = df.drop(columns=metadata_cols)
    expr_df = expr_df.apply(pd.to_numeric, errors='coerce').fillna(0)

    adata = sc.AnnData(expr_df)
    adata.obs = metadata

    # Fix duplicate names
    if adata.obs_names.duplicated().any():
        print("[WARNING] Found duplicate cell names. Making unique...")
        adata.obs_names_make_unique()
    if adata.var_names.duplicated().any():
        print("[WARNING] Found duplicate gene names. Making unique...")
        adata.var_names_make_unique()

    # Remove rare cell clusters
    if 'assigned_cluster' in adata.obs:
        counts = adata.obs['assigned_cluster'].value_counts()
        keep = counts[counts >= rare_cell_thresh].index
        before = adata.n_obs
        adata = adata[adata.obs['assigned_cluster'].isin(keep)]
        print(f"[INFO] Removed rare cell types (<{rare_cell_thresh} cells). {before} → {adata.n_obs} cells kept.")

    # Preprocess (same as pathway method)
    adata = preprocess_sctransform(adata, regress_vars=regress_vars, hvg_n=hvg_n)

    # Build simple bipartite graph (all genes together, no pathway grouping)
    G, genes, cells = construct_simple_bipartite_graph(
        adata, k=k, min_expression=min_expression,
        gene_percentile=gene_percentile, fallback_topk=True
    )

    expr = pd.DataFrame(adata.X, index=adata.obs_names, columns=adata.var_names)

    # Train GNN
    final_emb, cells_ordered = train_simple_bipartite_gnn(
        G, expr, model_type=model_type, hidden_dim=hidden_dim, out_dim=out_dim,
        epochs=epochs, early_stop_patience=early_stop_patience, early_stop_delta=early_stop_delta,
        device=device
    )

    if final_emb is None:
        print("[ERROR] Training failed!")
        return

    # Clustering and evaluation
    leiden_labels, kmeans_labels, best_res, leiden_sil, kmeans_sil = cluster_and_eval_simple(
        final_emb, cells_ordered, adata, n_clusters=n_clusters
    )

    # Marker gene analysis
    print("[INFO] Running marker gene analysis...")
    adata.obs["cluster"] = kmeans_labels.astype(str)
    sc.tl.rank_genes_groups(adata, groupby="cluster", method="wilcoxon", use_raw=False)
    sc.pl.rank_genes_groups(adata, n_genes=top_n_marker_genes, sharey=False)

    markers_df = sc.get.rank_genes_groups_df(adata, group=None)
    for c in sorted(adata.obs["cluster"].unique()):
        top_genes = markers_df[markers_df["group"] == c].head(top_n_marker_genes)
        print(f"\nCluster {c} top marker genes:")
        for _, row in top_genes.iterrows():
            sig = (
                "***" if row["pvals_adj"] < 0.001 else
                "**" if row["pvals_adj"] < 0.01 else
                "*" if row["pvals_adj"] < 0.05 else ""
            )
            print(f"  {row['names']}: logFC={row['logfoldchanges']:.2f}, adj.p={row['pvals_adj']:.3g} {sig}")

    # Visualizations
    visualize_simple(final_emb, kmeans_labels, cells_ordered, leiden_labels)

    print("[INFO] Simple bipartite GNN baseline completed!")


# Example usage: replace the placeholder with the path to your own CSV file.
simple_gnn_example_kwargs = dict(model_type='gcn', epochs=150, n_clusters="unique")
run_simple_baseline_pipeline('your_scRNA seq_csv_dataset', **simple_gnn_example_kwargs)

**Simple Baseline MLP**

In [None]:
# ---- DEVICE SETUP ----
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("[INFO] Using device:", device)

# ---- 1. Same preprocessing as main method ----
def preprocess_sctransform(adata, regress_vars=None, hvg_n=2000):
    """
    Select highly variable genes, optionally regress out covariates, then scale.

    Despite the name, no SCTransform model is fitted. The workflow is:

    1. ``sc.pp.highly_variable_genes`` (``flavor='seurat_v3'``, top ``hvg_n``
       genes), which per the scanpy docs expects count data.
    2. ``sc.pp.regress_out`` when ``regress_vars`` is given.
    3. ``sc.pp.scale``.

    The HVG subset is materialised with ``.copy()`` so the in-place steps act
    on a real AnnData rather than a view. This avoids scanpy's implicit
    view-to-copy conversion.

    Args:
        adata: AnnData to preprocess. Its ``.var`` receives the HVG annotations.
        regress_vars: ``adata.obs`` keys to regress out, or ``None`` to skip.
        hvg_n: number of highly variable genes to keep.

    Returns:
        A new AnnData restricted to the highly variable genes, scaled.
    """
    print("[INFO] Preprocessing with SCTransform-like workflow (highly variable genes + regression)...")
    sc.pp.highly_variable_genes(adata, flavor='seurat_v3', n_top_genes=hvg_n)
    adata = adata[:, adata.var['highly_variable']].copy()
    if regress_vars is not None:
        sc.pp.regress_out(adata, regress_vars)
    sc.pp.scale(adata)
    print("[INFO] Preprocessing completed.")
    return adata

# ---- 2. MLP Autoencoder Model ----
class MLPAutoencoder(nn.Module):
    """
    Fully connected autoencoder that maps each input row to a latent code.

    The encoder has one block per width in ``hidden_dims``
    (Linear -> ReLU -> Dropout -> BatchNorm1d), followed by a Linear to
    ``latent_dim``. The decoder mirrors the hidden widths in reverse and ends
    with a Linear back to ``input_dim``.

    Args:
        input_dim: number of input features (genes, as used in this file).
        hidden_dims: encoder hidden widths, as any iterable of ints. The
            default is an immutable tuple, avoiding a shared mutable default list.
        latent_dim: width of the latent code.
        dropout: dropout probability inside each hidden block.
    """

    def __init__(self, input_dim, hidden_dims=(512, 256, 128), latent_dim=32, dropout=0.2):
        super().__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        # Materialise once so generators work and reversed() is available.
        hidden_dims = list(hidden_dims)

        def hidden_block(in_features, out_features):
            # Layer order is kept as originally designed (BatchNorm after Dropout).
            return [
                nn.Linear(in_features, out_features),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.BatchNorm1d(out_features),
            ]

        # Encoder
        encoder_layers = []
        prev_dim = input_dim
        for hidden_dim in hidden_dims:
            encoder_layers.extend(hidden_block(prev_dim, hidden_dim))
            prev_dim = hidden_dim
        encoder_layers.append(nn.Linear(prev_dim, latent_dim))
        self.encoder = nn.Sequential(*encoder_layers)

        # Decoder
        decoder_layers = []
        prev_dim = latent_dim
        for hidden_dim in reversed(hidden_dims):
            decoder_layers.extend(hidden_block(prev_dim, hidden_dim))
            prev_dim = hidden_dim
        decoder_layers.append(nn.Linear(prev_dim, input_dim))
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Return ``(latent, reconstructed)`` for a batch ``x``."""
        latent = self.encoder(x)
        reconstructed = self.decoder(latent)
        return latent, reconstructed

    def encode(self, x):
        """Return only the latent code for ``x``."""
        return self.encoder(x)

# ---- 3. Alternative: Simple MLP for dimensionality reduction ----
class SimpleMLP(nn.Module):
    """
    Feed-forward encoder used as a non-autoencoder MLP baseline.

    The network has one block per width in ``hidden_dims``
    (Linear -> ReLU -> Dropout -> BatchNorm1d), followed by a final Linear
    to ``output_dim``.

    Args:
        input_dim: number of input features.
        hidden_dims: hidden widths, as any iterable of ints. The default is an
            immutable tuple, avoiding a shared mutable default list.
        output_dim: embedding width.
        dropout: dropout probability inside each hidden block.
    """

    def __init__(self, input_dim, hidden_dims=(512, 256), output_dim=32, dropout=0.2):
        super().__init__()
        layers = []
        prev_dim = input_dim

        for hidden_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.BatchNorm1d(hidden_dim)
            ])
            prev_dim = hidden_dim

        layers.append(nn.Linear(prev_dim, output_dim))
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Return the embedding for a batch ``x``."""
        return self.network(x)

# ---- 4. Contrastive Loss for MLP ----
def contrastive_loss_mlp(embeddings, drop_prob=0.2, temperature=0.5):
    """
    NT-Xent contrastive loss for MLP embeddings using dropout augmentation.

    A second view is built by zeroing each embedding entry independently with
    probability ``drop_prob``. Both views are L2-normalised and stacked into
    ``2N`` rows. Each row's positive is its counterpart in the other view
    (index ``i + N`` mod ``2N``). All other rows except itself are negatives.
    The positive is counted exactly once; the previous implementation also
    included it among the negatives.

    Args:
        embeddings: ``(N, D)`` tensor.
        drop_prob: probability of dropping each entry in the augmented view.
        temperature: softmax temperature applied to cosine similarities.

    Returns:
        ``(loss_value, loss)``: the loss as a Python float and the
        differentiable scalar tensor.
    """
    # Dropout-based augmentation
    mask = (torch.rand_like(embeddings) > drop_prob).float()
    embeddings_aug = embeddings * mask
    N = embeddings.shape[0]
    z1 = F.normalize(embeddings, dim=1)
    z2 = F.normalize(embeddings_aug, dim=1)
    representations = torch.cat([z1, z2], dim=0)
    logits = torch.matmul(representations, representations.T) / temperature
    # A row must never be scored against itself.
    self_mask = torch.eye(2 * N, dtype=torch.bool, device=embeddings.device)
    logits = logits.masked_fill(self_mask, -9e15)
    # Row i's positive is the same sample in the other view.
    targets = (torch.arange(2 * N, device=embeddings.device) + N) % (2 * N)
    loss = F.cross_entropy(logits, targets)
    return loss.item(), loss

# ---- 5. Training Functions ----
def train_mlp_autoencoder(data, model_type='autoencoder', hidden_dims=(512, 256, 128),
                         latent_dim=32, epochs=100, batch_size=256, learning_rate=0.001,
                         early_stop_patience=15, early_stop_delta=1e-4, device="cuda"):
    """
    Train an MLP autoencoder or a plain MLP encoder with a contrastive loss.

    With ``model_type == 'autoencoder'``, ``MLPAutoencoder`` is trained on
    ``MSE reconstruction + 0.1 * contrastive`` (the contrastive term is on the
    latent code). Any other value trains ``SimpleMLP`` on the contrastive loss
    alone. Training stops early once the mean epoch loss has not improved by
    more than ``early_stop_delta`` for more than ``early_stop_patience``
    consecutive epochs.

    Args:
        data: 2-D array-like with one row per cell. Objects exposing
            ``toarray()`` (e.g. scipy sparse) are densified first.
        model_type: ``'autoencoder'`` or anything else for ``SimpleMLP``.
        hidden_dims: hidden widths (an immutable tuple by default).
        latent_dim: embedding width.
        epochs, batch_size, learning_rate: optimisation settings (Adam).
        early_stop_patience, early_stop_delta: early-stopping settings.
        device: torch device. Falls back to CPU when CUDA is requested but
            not available.

    Returns:
        NumPy array of shape ``(n_cells, latent_dim)``, computed in eval mode
        over all rows.
    """
    print(f"[INFO] Training MLP {model_type} with contrastive loss...")

    if str(device).startswith("cuda") and not torch.cuda.is_available():
        print("[INFO] CUDA not available; using CPU instead.")
        device = "cpu"

    # torch.tensor cannot consume scipy sparse matrices directly.
    if hasattr(data, "toarray"):
        data = data.toarray()
    n_cells, n_genes = data.shape
    data_tensor = torch.tensor(data, dtype=torch.float32)

    # BatchNorm1d raises on a training batch holding a single sample, so drop
    # a trailing singleton batch when the data would otherwise produce one.
    drop_last = n_cells > batch_size and n_cells % batch_size == 1
    dataset = torch.utils.data.TensorDataset(data_tensor)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=drop_last)

    # Initialize model
    if model_type == 'autoencoder':
        model = MLPAutoencoder(n_genes, hidden_dims, latent_dim).to(device)
    else:  # simple MLP
        model = SimpleMLP(n_genes, hidden_dims, latent_dim).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    best_loss = float("inf")
    patience = 0

    model.train()
    for epoch in range(epochs):
        epoch_loss = 0.0
        n_batches = 0

        for (batch_data,) in dataloader:
            batch_data = batch_data.to(device)
            optimizer.zero_grad()

            if model_type == 'autoencoder':
                latent, reconstructed = model(batch_data)
                recon_loss = F.mse_loss(reconstructed, batch_data)
                _, cont_loss = contrastive_loss_mlp(latent, drop_prob=0.2, temperature=0.5)
                # Contrastive term is down-weighted relative to reconstruction.
                total_loss = recon_loss + 0.1 * cont_loss
            else:
                embeddings = model(batch_data)
                _, total_loss = contrastive_loss_mlp(embeddings, drop_prob=0.2, temperature=0.5)

            total_loss.backward()
            optimizer.step()

            epoch_loss += total_loss.item()
            n_batches += 1

        avg_loss = epoch_loss / n_batches

        # Early stopping
        if avg_loss + early_stop_delta < best_loss:
            best_loss = avg_loss
            patience = 0
        else:
            patience += 1

        if patience > early_stop_patience:
            print(f"[INFO] Early stopped at epoch {epoch+1} (best loss: {best_loss:.4f})")
            break

        if (epoch+1) % 20 == 0 or epoch == 0:
            print(f"[INFO] Epoch {epoch+1}, average loss: {avg_loss:.4f}")

    # Get final embeddings (eval mode: dropout off, BatchNorm uses running stats)
    model.eval()
    with torch.no_grad():
        data_tensor = data_tensor.to(device)
        if model_type == 'autoencoder':
            final_emb, _ = model(data_tensor)
        else:
            final_emb = model(data_tensor)
        final_emb = final_emb.cpu().numpy()

    print(f"[INFO] Training completed. Final embeddings shape: {final_emb.shape}")
    return final_emb

# ---- 6. Clustering and Evaluation ----
def cluster_and_eval_mlp(final_emb, cells, adata, n_clusters=None, leiden_res_list=None):
    """
    Cluster MLP cell embeddings with Leiden and KMeans and report metrics.

    The embeddings are stored in ``adata.obsm['X_mlp']`` and a 15-NN graph is
    built from them. Leiden runs at each resolution in ``leiden_res_list``;
    the resolution with the highest silhouette score is kept.

    KMeans (``random_state=42``) uses ``n_clusters``:

    * ``None``: pick k in 3..14 by silhouette score (not an elbow criterion).
    * ``"unique"``: use the number of distinct ``adata.obs['assigned_cluster']`` labels.
    * an int: use it as given.

    When ``assigned_cluster`` exists, ARI/NMI/AMI are printed and a KMeans
    confusion matrix is plotted.

    NOTE(review): the ``obsm`` assignment assumes ``final_emb`` rows follow
    ``adata.obs_names`` order. Confirm that ``cells`` matches it.

    Returns:
        ``(best_leiden_labels, kmeans_labels, best_res,
        best_leiden_silhouette, kmeans_silhouette)``
    """
    print("[INFO] Running clustering (Leiden + KMeans)...")

    if leiden_res_list is None:
        leiden_res_list = [0.2, 0.4, 0.6, 0.8, 1.0, 1.2, 1.5, 2.0]

    adata.obsm['X_mlp'] = final_emb
    sc.pp.neighbors(adata, use_rep='X_mlp', n_neighbors=15)

    leiden_silhouettes = []
    leiden_label_sets = {}

    # Leiden clustering
    for res in leiden_res_list:
        sc.tl.leiden(adata, resolution=res, key_added=f'mlp_leiden_{res}')
        leiden_labels = adata.obs[f'mlp_leiden_{res}'].astype(str).reindex(cells).fillna('NA').tolist()
        leiden_labels_num = pd.factorize(leiden_labels)[0]

        # Silhouette is undefined for a single cluster; -1 ranks it last.
        if len(set(leiden_labels_num)) > 1:
            sil = silhouette_score(final_emb, leiden_labels_num)
        else:
            sil = -1
        leiden_silhouettes.append(sil)
        leiden_label_sets[res] = leiden_labels
        print(f"[INFO] Leiden res={res:.2f}, clusters={len(set(leiden_labels))}, silhouette={sil:.3f}")

    best_res = leiden_res_list[np.argmax(leiden_silhouettes)]
    best_leiden_labels = leiden_label_sets[best_res]
    best_leiden_silhouette = max(leiden_silhouettes)

    # KMeans clustering
    if n_clusters is None:
        sil_scores = []
        for k in range(3, 15):
            labels = KMeans(n_clusters=k, random_state=42).fit_predict(final_emb)
            sil_scores.append(silhouette_score(final_emb, labels))
        n_clusters = np.argmax(sil_scores) + 3
        print(f"[INFO] Using silhouette-selected n_clusters={n_clusters}")
    elif n_clusters == "unique":
        n_clusters = adata.obs['assigned_cluster'].nunique()
        print(f"[INFO] Using unique assigned clusters: n_clusters={n_clusters}")

    kmeans_labels = KMeans(n_clusters=n_clusters, random_state=42).fit_predict(final_emb)
    kmeans_silhouette = silhouette_score(final_emb, kmeans_labels)

    # Evaluation metrics
    if 'assigned_cluster' in adata.obs:
        true_labels = adata.obs.loc[cells, 'assigned_cluster']
        metrics = {
            'Leiden ARI': adjusted_rand_score(true_labels, best_leiden_labels),
            'Leiden NMI': normalized_mutual_info_score(true_labels, best_leiden_labels),
            'Leiden AMI': adjusted_mutual_info_score(true_labels, best_leiden_labels),
            'Leiden Silhouette': best_leiden_silhouette,
            'KMeans ARI': adjusted_rand_score(true_labels, kmeans_labels),
            'KMeans NMI': normalized_mutual_info_score(true_labels, kmeans_labels),
            'KMeans AMI': adjusted_mutual_info_score(true_labels, kmeans_labels),
            'KMeans Silhouette': kmeans_silhouette
        }
        print("[INFO] MLP Clustering metrics:")
        for name, value in metrics.items():
            print(f"  {name}: {value:.3f}")

        # Confusion matrix
        true_labels = adata.obs.loc[cells, 'assigned_cluster'].astype(str).values
        pred_labels = kmeans_labels.astype(str)
        cm_df = pd.crosstab(pd.Series(true_labels, name="True"), pd.Series(pred_labels, name="Predicted"))
        plt.figure(figsize=(12,8))
        sns.heatmap(cm_df, annot=True, fmt="d", cmap="Blues")
        plt.xlabel("Predicted Clusters")
        plt.ylabel("True Labels")
        plt.title("MLP: Confusion Matrix")
        plt.tight_layout()
        plt.show()

    return best_leiden_labels, kmeans_labels, best_res, best_leiden_silhouette, kmeans_silhouette

# ---- 7. Visualization ----
def visualize_mlp(final_emb, cluster_labels, cells, leiden_labels=None):
    """
    Plot 2-D PCA and UMAP projections of MLP embeddings, coloured by cluster.

    Labels are cast to strings so seaborn treats them as categories, giving a
    discrete palette and one legend entry per cluster. Integer KMeans labels
    would otherwise be mapped as a continuous numeric hue.

    Args:
        final_emb: embedding matrix, one row per cell.
        cluster_labels: KMeans labels, one per row.
        cells: accepted for interface compatibility; not used.
        leiden_labels: optional Leiden labels; if given, an extra UMAP plot is drawn.
    """
    print("[INFO] Generating visualizations (PCA, UMAP)...")

    def _categorical(labels):
        # String labels plus a legend order that sorts integer-like labels numerically.
        as_str = np.asarray(labels).astype(str)
        order = sorted(pd.unique(as_str), key=lambda s: (0, int(s), s) if s.isdigit() else (1, 0, s))
        return as_str, order

    kmeans_hue, kmeans_order = _categorical(cluster_labels)

    # PCA
    pca = PCA(n_components=2)
    pca_proj = pca.fit_transform(final_emb)
    plt.figure(figsize=(6,5))
    sns.scatterplot(x=pca_proj[:,0], y=pca_proj[:,1], hue=kmeans_hue, hue_order=kmeans_order, palette='tab20', s=15)
    plt.title(f'MLP: PCA (Var explained: {pca.explained_variance_ratio_.sum():.2f})')
    plt.xlabel("PCA1")
    plt.ylabel("PCA2")
    plt.legend(title="Cluster")
    plt.tight_layout()
    plt.show()

    # UMAP
    reducer = umap.UMAP(n_components=2, random_state=42)
    umap_proj = reducer.fit_transform(final_emb)

    if leiden_labels is not None:
        leiden_hue, leiden_order = _categorical(leiden_labels)
        plt.figure(figsize=(6,5))
        sns.scatterplot(x=umap_proj[:,0], y=umap_proj[:,1], hue=leiden_hue, hue_order=leiden_order, palette='tab20', s=15)
        plt.title("MLP: UMAP (Leiden)")
        plt.xlabel("UMAP1")
        plt.ylabel("UMAP2")
        plt.legend(title="Leiden")
        plt.tight_layout()
        plt.show()

    plt.figure(figsize=(6,5))
    sns.scatterplot(x=umap_proj[:,0], y=umap_proj[:,1], hue=kmeans_hue, hue_order=kmeans_order, palette='tab20', s=15)
    plt.title("MLP: UMAP (KMeans)")
    plt.xlabel("UMAP1")
    plt.ylabel("UMAP2")
    plt.legend(title="Cluster")
    plt.tight_layout()
    plt.show()


# ---- 8. Main Pipeline ----
def run_mlp_baseline_pipeline(
    expr_csv, metadata_cols=None,
    regress_vars=None, hvg_n=2000, model_type='autoencoder',
    hidden_dims=(512, 256, 128), latent_dim=32, epochs=100, batch_size=256,
    learning_rate=0.001, early_stop_patience=15, early_stop_delta=1e-4,
    top_n_marker_genes=5, n_clusters=None, rare_cell_thresh=40,
    device="cuda"
):
    """
    Run the MLP baseline (no graph, no pathway information) end to end.

    Steps:

    1. Read ``expr_csv`` (first column is the index) and split
       ``metadata_cols`` into ``adata.obs``. Remaining columns are coerced to
       numbers, with NaN replaced by 0.
    2. Make names unique and drop clusters with fewer than
       ``rare_cell_thresh`` cells.
    3. Preprocess, train the MLP and cluster/evaluate the embeddings.
    4. Rank marker genes (Wilcoxon on KMeans clusters) and plot.

    Args:
        expr_csv: path to the expression CSV.
        metadata_cols: metadata column names. ``None`` (default) means
            ``['barcode', 'assigned_cluster']``. The None sentinel avoids a
            shared mutable default list.
        model_type: ``'autoencoder'`` or any other value for the plain MLP.
        hidden_dims: MLP hidden widths (an immutable tuple by default).
        device: torch device string. Falls back to ``"cpu"`` when CUDA is
            requested but not available.
        The remaining arguments are forwarded to the corresponding steps.

    Returns:
        None. Results are printed and plotted.
    """
    if metadata_cols is None:
        metadata_cols = ['barcode', 'assigned_cluster']
    if str(device).startswith("cuda") and not torch.cuda.is_available():
        print("[INFO] CUDA not available; using CPU instead.")
        device = "cpu"

    print("[INFO] === MLP BASELINE (No Graph, No Pathway Information) ===")
    print("[INFO] Loading expression matrix...")

    df = pd.read_csv(expr_csv, index_col=0)
    # Copy so later obs edits never touch a slice of df.
    metadata = df[metadata_cols].copy()
    expr_df = df.drop(columns=metadata_cols)
    expr_df = expr_df.apply(pd.to_numeric, errors='coerce').fillna(0)

    adata = sc.AnnData(expr_df)
    adata.obs = metadata

    # Fix duplicate names
    if adata.obs_names.duplicated().any():
        print("[WARNING] Found duplicate cell names. Making unique...")
        adata.obs_names_make_unique()
    if adata.var_names.duplicated().any():
        print("[WARNING] Found duplicate gene names. Making unique...")
        adata.var_names_make_unique()

    # Remove rare cell clusters
    if 'assigned_cluster' in adata.obs:
        counts = adata.obs['assigned_cluster'].value_counts()
        keep = counts[counts >= rare_cell_thresh].index
        before = adata.n_obs
        adata = adata[adata.obs['assigned_cluster'].isin(keep)]
        print(f"[INFO] Removed rare cell types (<{rare_cell_thresh} cells). {before} → {adata.n_obs} cells kept.")

    # Preprocess
    adata = preprocess_sctransform(adata, regress_vars=regress_vars, hvg_n=hvg_n)

    cells = list(adata.obs_names)
    data = adata.X

    # Train MLP
    final_emb = train_mlp_autoencoder(
        data, model_type=model_type, hidden_dims=hidden_dims, latent_dim=latent_dim,
        epochs=epochs, batch_size=batch_size, learning_rate=learning_rate,
        early_stop_patience=early_stop_patience, early_stop_delta=early_stop_delta,
        device=device
    )

    # Clustering and evaluation
    leiden_labels, kmeans_labels, best_res, leiden_sil, kmeans_sil = cluster_and_eval_mlp(
        final_emb, cells, adata, n_clusters=n_clusters
    )

    # Marker gene analysis
    print("[INFO] Running marker gene analysis...")
    adata.obs["cluster"] = kmeans_labels.astype(str)
    sc.tl.rank_genes_groups(adata, groupby="cluster", method="wilcoxon", use_raw=False)
    sc.pl.rank_genes_groups(adata, n_genes=top_n_marker_genes, sharey=False)

    markers_df = sc.get.rank_genes_groups_df(adata, group=None)
    for c in sorted(adata.obs["cluster"].unique()):
        top_genes = markers_df[markers_df["group"] == c].head(top_n_marker_genes)
        print(f"\nMLP Cluster {c} top marker genes:")
        for _, row in top_genes.iterrows():
            sig = (
                "***" if row["pvals_adj"] < 0.001 else
                "**" if row["pvals_adj"] < 0.01 else
                "*" if row["pvals_adj"] < 0.05 else ""
            )
            print(f"  {row['names']}: logFC={row['logfoldchanges']:.2f}, adj.p={row['pvals_adj']:.3g} {sig}")

    # Visualizations
    visualize_mlp(final_emb, kmeans_labels, cells, leiden_labels)


# Example usage: replace the placeholder with the path to your own CSV file.
# train_mlp_autoencoder builds the autoencoder only for model_type == 'autoencoder';
# any other value selects SimpleMLP. Pass an explicit name instead of None so the
# choice is visible and the training log does not read "Training MLP None".
run_mlp_baseline_pipeline(
    'your_scRNA seq_csv_dataset',
    model_type='mlp', epochs=150, n_clusters="unique"
)