In [None]:
import scanpy as sc
import pandas as pd

import os
import pandas as pd
import anndata as ad
import numpy as np
from scipy.io import mmread
import scipy.sparse as sp
import matplotlib.pyplot as plt
from IPython.display import Image
import scanpy as sc
from cnmf import cNMF, Preprocess

from multiprocessing import Process

import seaborn as sns

from cnmf import cNMF

## Find the best number of components (K)

In [None]:
# Load the integrated AnnData object used throughout this notebook
# (the file name indicates a Harmony-integrated dataset).
integrated_h5ad_path = "./all_integrated_harmony.h5ad"
adata = sc.read_h5ad(integrated_h5ad_path)

In [None]:
# Batch-correct on 'dataset', keep 2000 HVGs, and write the cNMF input files.
# NOTE(review): cNMF's Preprocess treats `save_output_base` as a filename
# *prefix* (it appends '.Corrected.HVG.Varnorm.h5ad', '.TP10K.h5ad',
# '.Corrected.HVGs.txt'). A bare directory ('./all_integrated/') would produce
# hidden files like './all_integrated/.TP10K.h5ad' instead of the
# './all_integrated/all_integrated.*' paths the prepare() cells read.
# Confirm against the installed cnmf version.
os.makedirs("./all_integrated", exist_ok=True)  # writers do not create parent dirs
p = Preprocess(random_seed=14)
(adata_c, adata_tp10k, hvgs) = p.preprocess_for_cnmf(
    adata,
    harmony_vars='dataset',
    n_top_rna_genes=2000,
    max_scaled_thresh=None,
    quantile_thresh=.9999,
    makeplots=True,
    save_output_base='./all_integrated/all_integrated',
)


In [None]:
# K sweep: prepare cNMF replicates for K = 5..39 (np.arange excludes the stop
# value 40), 20 iterations per K, on the batch-corrected HVG matrix.
cnmf_obj = cNMF(output_dir="./all_integrated", name="all_integrated_cNMF")
cnmf_obj.prepare(
    counts_fn="./all_integrated/all_integrated.Corrected.HVG.Varnorm.h5ad",
    tpm_fn='./all_integrated/all_integrated.TP10K.h5ad',
    genes_file='./all_integrated/all_integrated.Corrected.HVGs.txt',
    components=np.arange(5, 40),
    n_iter=20,
    seed=14,
    num_highvar_genes=2000,
)

In [None]:
# Run every prepared replicate serially in this kernel (worker 0 of 1),
# merge the per-iteration results, then draw the K-selection plot that is
# used to pick K.
cnmf_obj.factorize(worker_i=0, total_workers=1)
cnmf_obj.combine()
cnmf_obj.k_selection_plot()

## Run cNMF with the selected number of components (K = 15)

In [None]:
# Re-prepare the same run with only the chosen K = 15 (20 iterations).
# NOTE(review): K=15 was already included in the K sweep with the same
# n_iter; re-preparing here makes the next cell factorize it again.
# Confirm this repeat is intended.
cnmf_obj = cNMF(output_dir="./all_integrated", name="all_integrated_cNMF")
cnmf_obj.prepare(
    counts_fn="./all_integrated/all_integrated.Corrected.HVG.Varnorm.h5ad",
    tpm_fn='./all_integrated/all_integrated.TP10K.h5ad',
    genes_file='./all_integrated/all_integrated.Corrected.HVGs.txt',
    components=15,
    n_iter=20,
    seed=14,
    num_highvar_genes=2000,
)

In [None]:
def run_worker(worker_i, total_workers):
    """Factorize this worker's share of the prepared cNMF replicates.

    A fresh cNMF object is built inside the child process, pointing at the
    same output directory and run name used by ``prepare``.

    Parameters
    ----------
    worker_i : int
        Index of this worker, in ``range(total_workers)``.
    total_workers : int
        Total number of workers the replicates are split across.
    """
    cnmf_obj = cNMF(output_dir="./all_integrated", name="all_integrated_cNMF")
    cnmf_obj.factorize(worker_i=worker_i, total_workers=total_workers)


# NOTE(review): Process(target=run_worker) must be able to locate run_worker in
# the child. Under the "spawn" start method (default on macOS/Windows), functions
# defined in a notebook's __main__ usually cannot be found there. Verify on the
# target platform, or move run_worker into a .py module.
total_workers = 8
processes = []
for i in range(total_workers):
    # `proc`, not `p`: `p` already names the Preprocess object from earlier.
    proc = Process(target=run_worker, args=(i, total_workers))
    proc.start()
    processes.append(proc)

for proc in processes:
    proc.join()

# A crashed worker leaves its replicates unwritten. Fail here with the
# offending worker indices rather than later inside combine().
failed_workers = [i for i, proc in enumerate(processes) if proc.exitcode != 0]
if failed_workers:
    raise RuntimeError(f"cNMF factorize workers exited with errors: {failed_workers}")

cnmf_obj = cNMF(output_dir="./all_integrated", name="all_integrated_cNMF")
cnmf_obj.combine()

In [None]:
# Build the consensus solution for the chosen K with a local-density
# threshold of 0.25, then load its usages, spectra and per-program top genes.
selected_k, density_thr = 15, 0.25
cnmf_obj.consensus(k=selected_k, density_threshold=density_thr)
usage, spectra_scores, spectra_tpm, top_genes = cnmf_obj.load_results(
    K=selected_k,
    density_threshold=density_thr,
)

In [None]:
# Persist the per-program top-gene table (the row index is not written).
top_genes.to_csv(path_or_buf='./cnmf_top_genes.csv', index=False)

## Visualization

In [None]:
def classify_components(
    adata,
    usage,
    top_genes,
    gene_list_state_markers=None,
    ftest_thresh=1e-10,
    state_ratio_thresh=0.1,
    celltype_key="celltype",
):
    """Label each cNMF component as cell-type-specific, state-like, or other.

    For every component (each column of ``top_genes``), the per-cell usage is
    written to ``adata.obs[f"cNMF_{comp_id}"]``. This is a side effect that
    later plotting reads. The component is then labelled in this order:

    1. ``"Celltype-specific"`` if a one-way ANOVA F-test of usage across the
       groups in ``adata.obs[celltype_key]`` has p-value < ``ftest_thresh``.
    2. ``"State"`` if more than ``state_ratio_thresh`` of its top genes are in
       ``gene_list_state_markers``.
    3. ``"Other"`` otherwise.

    Parameters
    ----------
    adata : AnnData-like
        Object whose ``obs`` DataFrame holds ``celltype_key``; mutated in place.
    usage : pandas.DataFrame
        Cells x components usage table. A component's column is looked up by
        its label from ``top_genes``. If that label is absent, the column at
        the same position is used (previous behavior).
    top_genes : pandas.DataFrame
        One column per component listing its top genes; NaN padding allowed.
    gene_list_state_markers : iterable of str, optional
        Genes counted as state markers. Defaults to an empty set.
    ftest_thresh : float
        F-test p-value cutoff for "Celltype-specific".
    state_ratio_thresh : float
        Minimum (exclusive) fraction of top genes that are state markers for "State".
    celltype_key : str
        Column of ``adata.obs`` with the cell-type labels.

    Returns
    -------
    pandas.DataFrame
        One row per component, in ``top_genes`` column order, with columns
        ``Component``, ``Label``, ``F-test_p`` and ``StateGeneRatio``.
    """
    from scipy.stats import f_oneway
    import pandas as pd

    # None instead of a shared mutable `set()` default; accept any iterable.
    state_markers = set(gene_list_state_markers) if gene_list_state_markers is not None else set()

    celltypes = adata.obs[celltype_key]
    # Loop-invariant. NaN labels cannot form a meaningful group.
    celltype_levels = celltypes.dropna().unique()

    results = []
    for idx, comp_id in enumerate(top_genes.columns):
        obs_col = f'cNMF_{comp_id}'
        # Match usage to genes by component label, so a column-order difference
        # between `usage` and `top_genes` cannot pair the wrong columns.
        comp_usage = usage[comp_id] if comp_id in usage.columns else usage.iloc[:, idx]
        adata.obs[obs_col] = comp_usage  # aligns on the cell index
        usage_i = adata.obs[obs_col]

        # 1. One-way ANOVA across cell types. Drop NaN usage (cells missing
        #    from `usage`) and empty groups, which would make f_oneway return NaN.
        groups = [usage_i[celltypes == ct].dropna().to_numpy() for ct in celltype_levels]
        groups = [g for g in groups if g.size > 0]
        f_p = f_oneway(*groups).pvalue if len(groups) > 1 else 1.0

        # 2. Fraction of this component's top genes that are state markers.
        #    A component with no listed genes gets 0.0 instead of dividing by zero.
        top_genes_i = set(top_genes[comp_id].dropna().values)
        state_ratio = len(top_genes_i & state_markers) / len(top_genes_i) if top_genes_i else 0.0

        if f_p < ftest_thresh:
            label = "Celltype-specific"
        elif state_ratio > state_ratio_thresh:
            label = "State"
        else:
            label = "Other"

        results.append({
            "Component": comp_id,
            "Label": label,
            "F-test_p": f_p,
            "StateGeneRatio": state_ratio
        })

    return pd.DataFrame(results)

In [None]:
# Classify components and add the `cNMF_<id>` usage columns to `adata.obs`.
# `adata` (not the undefined `adata_all`) is the object the UMAP cell plots.
df_result = classify_components(adata, usage, top_genes)

In [None]:
# Rasterize dense scatter layers in vector exports (public settings handle,
# same object as `sc._settings.settings`).
sc.settings._vector_friendly = True
n_rows = 3
n_cols = 3
per_page = n_rows * n_cols

# Component ids exactly as used for the `cNMF_<id>` obs columns written by
# classify_components (one per top_genes column). This replaces a hard-coded
# count of 15 and a `+1` offset.
component_ids = list(top_genes.columns)

for page_no, start in enumerate(range(0, len(component_ids), per_page), start=1):
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 15))
    axes = axes.flatten()
    page_ids = component_ids[start:start + per_page]

    # One UMAP panel per component, coloured by its usage.
    for ax, comp_id in zip(axes, page_ids):
        sc.pl.umap(
            adata,
            color=f'cNMF_{comp_id}',
            ax=ax,
            show=False,
            title=f'Component {comp_id}'
        )

    # The last page may be partially filled; hide the unused empty panels.
    for ax in axes[len(page_ids):]:
        ax.axis("off")

    # Save the current page as an SVG (dpi applies to the rasterized layers).
    fig.tight_layout()
    fig.savefig(f"cNMF_{page_no}.svg", bbox_inches="tight", dpi=600)
    plt.close(fig)  # free this page's figure before starting the next