In [34]:
import scanpy as sc
import pandas as pd
import numpy as np
import seaborn as sns
import infercnvpy as cnv
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import math
import scipy.stats
from scipy.sparse import issparse
from scipy.cluster.hierarchy import linkage, fcluster, to_tree
from scipy.spatial.distance import pdist, squareform
import os
import warnings
# Hide every Python warning to keep cell output readable. Note that this also hides
# library deprecation notices; re-enable warnings when debugging.
warnings.simplefilter("ignore")

# Global figure defaults: 150 dpi, frameless scanpy plots, vector-friendly scatter output.
sc.settings.set_figure_params(vector_friendly=True, frameon=False, dpi=150)
sns.set_style(style="white")

# Inline figures are the default in modern Jupyter, so no `%matplotlib inline` magic is needed.

# Loading Data

In [None]:
data_dir = "./raw_data"
matrix_path = os.path.join(data_dir, "GSM3828672_Smartseq2_GBM_IDHwt_processed_TPM.tsv")
meta_path = os.path.join(data_dir, "GSE131928_single_cells_tumor_name_and_adult_or_peidatric.xlsx")

# The TSV is read with genes as rows and cells as columns; AnnData wants cells x genes.
df = pd.read_csv(matrix_path, sep='\t', index_col=0)
adata = sc.AnnData(df.T)
print(f"data loaded. Matrix shape: {adata.shape} (cells x genes)")

# Per-cell metadata sheet, indexed by cell name.
meta_df = pd.read_excel(meta_path, index_col=0)
common_cells = adata.obs_names.intersection(meta_df.index)
print(f"{len(common_cells)} cells found in both expression matrix and metadata.")
if len(common_cells) == 0:
    # Fail early with a clear message instead of building an empty AnnData.
    raise ValueError("No overlapping cell names between the expression matrix and the metadata sheet.")

adata = adata[common_cells].copy()
adata.obs = meta_df.loc[common_cells]

# Largest expression value; a quick check of which scale the matrix is on.
max_val = np.max(adata.X)
sc.pp.calculate_qc_metrics(adata, percent_top=None, log1p=False, inplace=True)

# Nothing is written to disk in this cell (the old message referenced an undefined
# `save_path`, which raised NameError); report what was loaded instead.
print(f"Data ready: {adata.n_obs} cells x {adata.n_vars} genes (max expression value {max_val:.2f}).")

In [None]:
# The metadata sheet has unnamed columns. Give the two we use meaningful names and
# keep only those two columns in adata.obs.
rename_dict = {
    'Unnamed: 7': 'tumor_name',
    'Unnamed: 8': 'age_group',
}
adata.obs = adata.obs.rename(columns=rename_dict).loc[:, ['tumor_name', 'age_group']]
adata.obs.head()

# Keep Only Highly Expressed Genes (Ea > 4)

In [None]:
# Keep genes with high aggregate expression: Ea = log2(mean TPM + 1) > 4.
# The back-transform assumes adata.X holds log2(TPM/10 + 1) values (TODO confirm
# against the GEO processing description), i.e. TPM = 10 * (2**X - 1).
tpm_approx = (np.power(2, adata.X) - 1) * 10
avg_tpm = np.mean(tpm_approx, axis=0)
Ea = np.log2(1 + avg_tpm)
gene_mask = Ea > 4

print(f"original gene count: {adata.n_vars}")
print(f"filtered gene count (Ea > 4): {np.count_nonzero(gene_mask)}")

adata = adata[:, gene_mask].copy()
print(f"{adata.shape} (cells x genes)")

# Score Non-malignant Cell Markers and Flag Normal Cells

In [None]:
print(f"Data max value: {adata.X.max()}")

# Marker scores are computed on per-gene z-scored expression (clipped at 10).
# The unscaled `adata` is kept for the later CNV and DE steps.
adata_scaled = adata.copy()
sc.pp.scale(adata_scaled, max_value=10, zero_center=True)

marker_genes = {
    'Macrophage': ['CD14', 'AIF1', 'FCER1G', 'FCGR3A', 'TYROBP', 'CSF1R'],
    'Oligodendrocyte': ['MBP', 'TF', 'PLP1', 'MAG', 'MOG', 'CLDN11'],
    'T cell' : ['CD2', 'CD3D', 'CD3E', 'CD3G']
}
# Per-cell-type score cutoffs. This dict is the single source of truth for both the
# diagnostic histograms and the non-malignant call below.
thresholds = {'Macrophage': 1.5, 'Oligodendrocyte': 3.0, 'T cell': 4.0}
# Reference line drawn on every histogram for comparison with the published cutoff.
PAPER_MARKER_THRESHOLD = 4.0

score_df = pd.DataFrame(index=adata.obs_names)

for ct, genes in marker_genes.items():
    valid = [g for g in genes if g in adata.var_names]
    if not valid:
        raise ValueError(f"None of the {ct} marker genes survived gene filtering: {genes}")
    sc.tl.score_genes(adata_scaled, gene_list=valid, score_name=f'score_{ct}', ctrl_size=100, n_bins=30)
    adata.obs[f'score_{ct}'] = adata_scaled.obs[f'score_{ct}']
    score_df[f'score_{ct}'] = adata.obs[f'score_{ct}']

# One histogram per cell type, sized from the marker dictionary rather than a fixed 3.
fig, axes = plt.subplots(1, len(marker_genes), figsize=(12, 4), squeeze=False)
for ax, ct in zip(axes[0], marker_genes):
    col = f'score_{ct}'
    sns.histplot(score_df[col], kde=True, ax=ax)
    ax.axvline(x=PAPER_MARKER_THRESHOLD, color='r', linestyle='--',
               label=f'Paper Threshold ({PAPER_MARKER_THRESHOLD})')
    ax.axvline(x=thresholds[ct], color='g', linestyle='--', label='Custom Threshold')
    ax.set_title(f'{col} Score Distribution')
    ax.legend()
fig.tight_layout()
plt.show()

# A cell is flagged non-malignant if it exceeds the cutoff for ANY normal cell type.
above_cutoff = pd.DataFrame(
    {ct: adata.obs[f'score_{ct}'] > cutoff for ct, cutoff in thresholds.items()},
    index=adata.obs_names,
)
adata.obs['is_normal_marker'] = above_cutoff.any(axis=1)

print(f"Non-malignant cell counts based on marker genes: {adata.obs['is_normal_marker'].sum()}")

# Detect CNVs

In [None]:
# Marker-defined normal cells serve as the inferCNV reference; everything else is queried.
adata.obs['cnv_reference'] = adata.obs['is_normal_marker'].map({True: 'Reference', False: 'Query'})

# Annotate genes with genomic positions (gene symbols are looked up in Ensembl BioMart),
# then keep only genes that received a chromosome.
cnv.io.genomic_position_from_biomart(adata, species="hsapiens", biomart_gene_id="external_gene_name")
has_position = adata.var['chromosome'].notna()
adata_cnv = adata[:, has_position].copy()

cnv.tl.infercnv(
    adata_cnv,
    reference_key="cnv_reference",
    reference_cat="Reference",
    window_size=100,
    step=1,
    n_jobs=8,
)

# Work with a dense CNV matrix from here on.
cnv_mtx = adata_cnv.obsm['X_cnv']
cnv_mtx = cnv_mtx.toarray() if issparse(cnv_mtx) else cnv_mtx
adata.obsm['X_cnv'] = cnv_mtx

# CNA signal: per-cell mean of squared CNV values (overall magnitude of copy-number change).
cna_signal = (cnv_mtx ** 2).mean(axis=1)
adata.obs['cna_signal'] = cna_signal

def fast_correlation(X, ref_vec):
    """Pearson correlation of every row of ``X`` with ``ref_vec``.

    Parameters
    ----------
    X : np.ndarray
        2-D array; each row is correlated independently.
    ref_vec : np.ndarray
        1-D reference profile with one entry per column of ``X``.

    Returns
    -------
    np.ndarray
        One correlation per row. Undefined results (a zero-variance row or a
        zero-variance reference, i.e. 0/0) are reported as 0 instead of NaN.
    """
    row_dev = X - np.mean(X, axis=1, keepdims=True)
    ref_dev = ref_vec - np.mean(ref_vec)
    scale = np.linalg.norm(row_dev, axis=1) * np.linalg.norm(ref_dev)
    with np.errstate(divide='ignore', invalid='ignore'):
        corr = row_dev.dot(ref_dev) / scale
    corr[np.isnan(corr)] = 0
    return corr

# For each tumor, correlate every cell's CNV profile with the tumor's average profile
# over its suspected-malignant (not marker-flagged) cells. Tumors with fewer than 5
# such cells are skipped and keep cna_corr = 0.
adata.obs['cna_corr'] = 0.0
tumor_ids = adata.obs['tumor_name'].unique()

for tumor in tumor_ids:
    in_tumor = (adata.obs['tumor_name'] == tumor).to_numpy()
    suspected_local = ~adata.obs['is_normal_marker'].to_numpy()[in_tumor]

    if suspected_local.sum() < 5:
        print(f"Warning: Tumor {tumor} has too few suspected malignant cells (<5). Skipping correlation.")
        continue

    tumor_cnv = cnv_mtx[in_tumor]
    avg_profile = tumor_cnv[suspected_local].mean(axis=0)
    adata.obs.loc[in_tumor, 'cna_corr'] = fast_correlation(tumor_cnv, avg_profile)

# Set Thresholds for CNV-based Malignant Cell Calls

In [None]:
# A cell is called CNV-malignant only if it has both a strong CNA signal and a CNV
# profile correlated with its tumor's average malignant profile.
THRESHOLD_SIGNAL = 0.01
THRESHOLD_CORR = 0.2

strong_signal = adata.obs['cna_signal'] > THRESHOLD_SIGNAL
tumor_like_profile = adata.obs['cna_corr'] > THRESHOLD_CORR
adata.obs['is_malignant_cnv'] = strong_signal & tumor_like_profile

print(f"Malignant cells identified by CNA: {adata.obs['is_malignant_cnv'].sum()}")

fig, ax = plt.subplots(figsize=(8, 6))
sns.scatterplot(
    data=adata.obs,
    x='cna_signal',
    y='cna_corr',
    hue='is_normal_marker',
    palette={True: 'black', False: 'blue'},
    style='is_normal_marker',
    markers={True: 'X', False: 'o'},
    s=15,
    alpha=0.6,
    linewidth=0,
    ax=ax,
)

ax.axvline(THRESHOLD_SIGNAL, color='red', linestyle='--', label=f'Signal={THRESHOLD_SIGNAL}')
ax.axhline(THRESHOLD_CORR, color='green', linestyle='--', label=f'Corr={THRESHOLD_CORR}')

ax.set(
    title="Diagnostic Plot: CNA Signal vs Correlation (Per Tumor)",
    xlabel="CNV Signal",
    ylabel="CNV Correlation",
    xlim=(0, 0.1),
    ylim=(-0.2, 1.0),
)
ax.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
ax.grid(True, linestyle=':', alpha=0.3)
plt.show()

In [None]:
# Heatmap groups: marker-defined normal cells first, then adult tumors, then pediatric tumors.
adata_cnv.obs['groups'] = adata.obs['tumor_name'].astype(str)
adata_cnv.obs.loc[adata.obs['is_normal_marker'], 'groups'] = 'Non-malignant'

# Tumor names are compared as strings so they match the string-typed 'groups' column.
adult_tumors = sorted(adata.obs.loc[adata.obs['age_group'] == 'adult', 'tumor_name'].astype(str).unique())
pediatric_tumors = sorted(adata.obs.loc[adata.obs['age_group'] == 'pediatric', 'tumor_name'].astype(str).unique())

# dict.fromkeys de-duplicates while preserving order (a tumor listed under both age
# groups would otherwise make pd.Categorical reject the categories).
category_order = list(dict.fromkeys(['Non-malignant'] + adult_tumors + pediatric_tumors))

# Groups not covered above (an unexpected age_group label, or a missing tumor name
# rendered as 'nan') would become NaN in the Categorical and vanish from the
# heatmap; append them at the end instead.
leftover_groups = sorted(set(adata_cnv.obs['groups']) - set(category_order))
if leftover_groups:
    print(f"Warning: groups without an adult/pediatric label are appended at the end: {leftover_groups}")
category_order += leftover_groups

adata_cnv.obs['groups'] = pd.Categorical(
    adata_cnv.obs['groups'],
    categories=category_order,
    ordered=True
).remove_unused_categories()

adata_cnv_sorted = adata_cnv[adata_cnv.obs.sort_values('groups').index].copy()
cnv.pl.chromosome_heatmap(
    adata_cnv_sorted,
    groupby='groups',
    dendrogram=False,
    cmap='RdBu_r',
    vmin=-0.6, vmax=0.6,
    figsize=(12, 12),
    save=None
)
plt.show()

In [None]:
# Global embedding of all cells on the z-scored expression.
sc.tl.pca(adata_scaled, n_comps=50, svd_solver='arpack')
sc.tl.tsne(adata_scaled,
           random_state=42,
           n_pcs=50,
           perplexity=30,
           early_exaggeration=4.0
           )
adata.obsm['X_tsne'] = adata_scaled.obsm['X_tsne']

# Fig. 1B labels. Start with the CNV-malignant call, then overwrite with marker-defined
# normal types. Later assignments win, so a cell passing several marker cutoffs ends up
# with the last matching label.
adata.obs['fig1b_class'] = 'Other'
adata.obs.loc[adata.obs['is_malignant_cnv'], 'fig1b_class'] = 'Malignant (CNAs)'

# Reuse the `thresholds` cutoffs from the marker-scoring step so both classifications agree.
for ct, label in [('Macrophage', 'Macrophages'),
                  ('Oligodendrocyte', 'Oligodendrocytes'),
                  ('T cell', 'T cells')]:
    adata.obs.loc[adata.obs[f'score_{ct}'] > thresholds[ct], 'fig1b_class'] = label

palette_fig1b = {
    'Malignant (CNAs)': '#0000CD',
    'Macrophages': '#00FFFF',
    'Oligodendrocytes': '#FF00FF',
    'T cells': '#32CD32',
    'Other': '#D3D3D3'
}

# Hand scanpy an explicit Axes: without `ax`, sc.pl.tsne opens its own figure, so a
# preceding plt.figure() only left behind an empty figure and its size was ignored.
fig, ax = plt.subplots(figsize=(8, 8))
sc.pl.tsne(
    adata,
    color='fig1b_class',
    palette=palette_fig1b,
    title='Figure 1B: Cell Classification',
    frameon=False,
    s=20,
    legend_fontsize=10,
    ax=ax,
    show=False,
)
plt.show()

In [None]:
# Leiden clustering needs a kNN graph, and neither PCA nor t-SNE builds one. Without
# this call sc.tl.leiden fails on a fresh kernel, so build the graph from the 50 PCs
# computed in the previous cell.
sc.pp.neighbors(adata_scaled, n_pcs=50)
sc.tl.leiden(adata_scaled, resolution=0.8, key_added='global_clusters')
adata.obs['global_clusters'] = adata_scaled.obs['global_clusters']

# Mean marker score per cluster. A cluster whose mean exceeds the cutoff for any
# normal cell type is treated as a non-malignant cluster.
MARKER_CLUSTER_CUTOFF = 1.0
cluster_scores = adata.obs.groupby('global_clusters', observed=True)[
    ['score_Macrophage', 'score_Oligodendrocyte', 'score_T cell']
].mean()

normal_clusters = cluster_scores[(cluster_scores > MARKER_CLUSTER_CUTOFF).any(axis=1)].index.tolist()
print(f"Identified normal clusters (Marker High): {normal_clusters}")

adata.obs['is_in_malignant_cluster'] = ~adata.obs['global_clusters'].isin(normal_clusters)

# Strict malignant call: CNV-malignant, outside marker-high clusters, and not itself marker-positive.
is_strict_malignant = (
    adata.obs['is_malignant_cnv']
    & adata.obs['is_in_malignant_cluster']
    & ~adata.obs['is_normal_marker']
)

adata.obs['final_cell_class'] = 'Non-malignant'
adata.obs.loc[is_strict_malignant, 'final_cell_class'] = 'Malignant'

print(f"Final strict malignant cell count: {sum(is_strict_malignant)}")

# Re-embed the malignant cells on their own (PCA + t-SNE on the already-scaled data).
adata_mal = adata[adata.obs['final_cell_class'] == 'Malignant'].copy()
adata_mal_scaled = adata_scaled[adata_mal.obs_names].copy()
sc.tl.pca(adata_mal_scaled, n_comps=50, svd_solver='arpack')

sc.tl.tsne(
    adata_mal_scaled,
    random_state=42,
    n_pcs=50,
    perplexity=30
)

adata_mal.obsm['X_tsne'] = adata_mal_scaled.obsm['X_tsne']

# Explicit Axes so the requested figure size is actually used (see Fig. 1B).
fig, ax = plt.subplots(figsize=(8, 8))
sc.pl.tsne(
    adata_mal,
    color='tumor_name',
    title='Figure 1C: Malignant Cells by Tumor (Strict Intersection)',
    frameon=False,
    palette='tab20',
    s=20,
    legend_loc='on data',
    legend_fontsize=8,
    ax=ax,
    show=False,
)
plt.show()

In [None]:
# Align the scaled object (used for clustering) and the unscaled object (used for DE)
# so both hold exactly the same cells in the same order.
common_cells = adata_mal_scaled.obs_names.intersection(adata_mal.obs_names)
adata_mal = adata_mal[common_cells].copy()
adata_mal_scaled = adata_mal_scaled[common_cells].copy()

# Only tumors contributing more than 50 malignant cells are used for signature discovery.
tumor_counts = adata_mal_scaled.obs['tumor_name'].value_counts()
valid_tumors = tumor_counts.index[(tumor_counts > 50).to_numpy()]

# Filled by the per-tumor loop: signature name -> gene list, plus per-signature stats records.
qualified_signatures_all = dict()
qualified_signatures_meta = list()

def recover_all_clusters_from_linkage(Z, n_leaf_nodes, min_size, max_ratio):
    """Collect the leaves under every dendrogram node whose size is in a window.

    Builds the tree from the linkage matrix ``Z`` and returns, for each internal
    node with ``min_size <= size <= max_ratio * n_leaf_nodes``, the list of leaf
    (observation) indices beneath it.

    The walk is an explicit-stack post-order traversal: left subtree, then right
    subtree, then the node. It yields exactly the clusters, in the same order, as a
    recursive walk would. Unlike a recursive walk, it cannot hit Python's recursion
    limit on deep, chain-like dendrograms (e.g. average linkage on >1000 cells).

    Parameters
    ----------
    Z : np.ndarray
        Linkage matrix as returned by ``scipy.cluster.hierarchy.linkage``.
    n_leaf_nodes : int
        Number of clustered observations; scales the upper size bound.
    min_size : int
        Smallest cluster size to keep.
    max_ratio : float
        Largest cluster size to keep, as a fraction of ``n_leaf_nodes``.

    Returns
    -------
    list[list[int]]
        One list of leaf indices per qualifying node. Leaves are in left-to-right
        dendrogram order; nodes are in post-order.
    """
    clusters = []
    max_size = max_ratio * n_leaf_nodes
    # node id -> its leaf indices; each entry is dropped once the parent consumes it.
    leaves_below = {}
    stack = [(to_tree(Z), False)]

    while stack:
        node, children_done = stack.pop()
        if node.is_leaf():
            leaves_below[node.id] = [node.id]
        elif not children_done:
            # Revisit this node after both subtrees. Push right before left so the
            # left subtree is processed first, matching left-then-right recursion.
            stack.append((node, True))
            stack.append((node.right, False))
            stack.append((node.left, False))
        else:
            members = leaves_below.pop(node.left.id) + leaves_below.pop(node.right.id)
            if min_size <= len(members) <= max_size:
                clusters.append(members)
            leaves_below[node.id] = members

    return clusters

# Signature-extraction criteria: each HCL cluster is tested against the rest of its tumor.
LOGFC_THRESHOLD = math.log2(3)     # at least 3-fold up-regulation
P_ADJ_N_SIG1 = 0.05
P_ADJ_N_SIG2 = 0.005
N_SIG1_REQUIRED = 50
N_SIG2_REQUIRED = 10
CLUSTER_SIZE_MIN = 5
CLUSTER_MAX_RATIO = 0.8
# Log base of the expression values. The gene-filtering step back-transforms with 2**X,
# i.e. it treats the data as log2(TPM/10 + 1). scanpy's `logfoldchanges` assumes
# natural-log data unless adata.uns['log1p']['base'] says otherwise.
EXPRESSION_LOG_BASE = 2

for i, tumor in enumerate(valid_tumors):
    adata_clus_er = adata_mal_scaled[adata_mal_scaled.obs['tumor_name'] == tumor].copy()
    adata_de_eij = adata_mal[adata_mal.obs['tumor_name'] == tumor].copy()

    # Use the same gene set for clustering and DE. filter_genes counts cells with X > 0;
    # NOTE(review): on z-scored data that means "above the gene's mean", not "expressed".
    genes_to_keep = adata_clus_er.var_names.intersection(adata_de_eij.var_names)
    adata_clus_er = adata_clus_er[:, genes_to_keep].copy()
    sc.pp.filter_genes(adata_clus_er, min_cells=3)
    adata_de_eij = adata_de_eij[:, adata_clus_er.var_names].copy()

    n_cells_orig = adata_clus_er.n_obs
    if n_cells_orig < 50:
        continue

    print(f"[{i+1}/{len(valid_tumors)}] Processing {tumor} (Cells: {n_cells_orig})...", end=" ")

    try:
        # Correlation distance is undefined for (near-)constant cells, so drop them.
        cell_variances = np.var(adata_clus_er.X, axis=1)
        finite_variance_cells = adata_clus_er.obs_names[cell_variances > 1e-6]
        adata_clus_filtered = adata_clus_er[finite_variance_cells].copy()
        n_cells_filtered = adata_clus_filtered.n_obs

        if n_cells_filtered < CLUSTER_SIZE_MIN:
            print("-> HCL Skipped (too few cells after variance filter).")
            continue

        D = pdist(adata_clus_filtered.X, metric='correlation')
        Z = linkage(D, method='average')

        potential_clusters_indices = recover_all_clusters_from_linkage(
            Z, n_cells_filtered, min_size=CLUSTER_SIZE_MIN, max_ratio=CLUSTER_MAX_RATIO
        )

        print(f"Recovered {len(potential_clusters_indices)} potential clusters. -> ", end="")

    except Exception as e:
        print(f"-> HCL/Clustering Error: {e}. Skipping tumor.")
        continue

    current_tumor_signatures = []
    # Restrict DE to the cells that entered the clustering, so the cluster indices line up.
    adata_de_eij = adata_de_eij[adata_clus_filtered.obs_names].copy()
    # Tell scanpy the log base so `logfoldchanges` is on the intended scale.
    adata_de_eij.uns['log1p'] = {'base': EXPRESSION_LOG_BASE}

    for cluster_id, indices in enumerate(potential_clusters_indices):
        # Label this cluster 'Target' and the rest of the tumor 'Other'. The label column
        # is rebuilt every iteration, so one AnnData can be reused without a full copy.
        in_cluster = np.zeros(adata_de_eij.n_obs, dtype=bool)
        in_cluster[indices] = True
        adata_de_eij.obs['temp_cluster'] = np.where(in_cluster, 'Target', 'Other')

        try:
            sc.tl.rank_genes_groups(
                adata_de_eij, 'temp_cluster', groups=['Target'], reference='Other',
                method='wilcoxon', use_raw=False
            )
        except Exception as e:
            # Skip this cluster but do not hide the failure.
            print(f"[DE failed for cluster {cluster_id}: {e}]", end=" ")
            continue

        df_stats = sc.get.rank_genes_groups_df(adata_de_eij, group='Target')

        # Up-regulated genes passing the fold-change cut (LOGFC_THRESHOLD > 0, so this
        # also implies logFC > 0), counted at two significance levels.
        df_up_fc = df_stats[df_stats['logfoldchanges'] > LOGFC_THRESHOLD]
        df_Nsig1 = df_up_fc[df_up_fc['pvals_adj'] < P_ADJ_N_SIG1]
        N_sig1 = len(df_Nsig1)
        N_sig2 = int((df_up_fc['pvals_adj'] < P_ADJ_N_SIG2).sum())

        if N_sig1 > N_SIG1_REQUIRED and N_sig2 > N_SIG2_REQUIRED:
            current_tumor_signatures.append({
                'name': f"{tumor}_C{cluster_id}",
                'genes': df_Nsig1['names'].tolist(),
                'N_sig1': N_sig1,
            })

    for sig in current_tumor_signatures:
        qualified_signatures_all[sig['name']] = sig['genes']
        qualified_signatures_meta.append({'Signature': sig['name'], 'N_sig1': sig['N_sig1']})

    print(f"Extracted {len(current_tumor_signatures)} strictly filtered signatures.")

total_raw_signatures = len(qualified_signatures_all)
# "\n" (the old "\o" was an invalid escape that printed a literal backslash).
print(f"\noriginal signature number: {total_raw_signatures}")

# Redundancy filter: drop a signature when another signature overlaps it with
# Jaccard index > JACCARD_REDUNDANCY and has more significant genes (N_sig1).
# Ties on N_sig1 are broken by discovery order (the earlier signature is kept);
# otherwise two near-identical signatures with equal N_sig1 would both survive.
JACCARD_REDUNDANCY = 0.75

final_signatures = {}
sig_names = list(qualified_signatures_all.keys())
sig_sets = {name: set(genes) for name, genes in qualified_signatures_all.items()}
Nsig1_map = {item['Signature']: item['N_sig1'] for item in qualified_signatures_meta}
discovery_rank = {name: pos for pos, name in enumerate(sig_names)}

for name1 in sig_names:
    Nsig1_1 = Nsig1_map.get(name1)
    if Nsig1_1 is None:
        continue
    genes1 = sig_sets[name1]

    is_redundant = False
    for name2 in sig_names:
        if name2 == name1:
            continue
        Nsig1_2 = Nsig1_map.get(name2)
        if Nsig1_2 is None:
            continue

        genes2 = sig_sets[name2]
        union = len(genes1 | genes2)
        jaccard_index = len(genes1 & genes2) / union if union > 0 else 0

        other_wins = Nsig1_2 > Nsig1_1 or (
            Nsig1_2 == Nsig1_1 and discovery_rank[name2] < discovery_rank[name1]
        )
        if jaccard_index > JACCARD_REDUNDANCY and other_wins:
            is_redundant = True
            break

    if not is_redundant:
        final_signatures[name1] = qualified_signatures_all[name1]

total_final_signatures = len(final_signatures)
print(f"filtered signature number: {total_final_signatures}")

In [None]:
# Hierarchical clustering needs at least two signatures (with one, the condensed
# distance vector is empty and linkage() raises).
if total_final_signatures > 1:
    # Binary signature x gene membership matrix (1 = gene belongs to the signature).
    all_genes = sorted(list(set(g for sig in final_signatures.values() for g in sig)))
    sig_names = list(final_signatures.keys())
    sig_matrix = np.zeros((len(sig_names), len(all_genes)))
    sig_df = pd.DataFrame(sig_matrix, index=sig_names, columns=all_genes)
    for sig_name, genes in final_signatures.items():
        sig_df.loc[sig_name, genes] = 1

    # Pairwise Pearson correlation between signature membership vectors.
    corr_matrix = sig_df.T.corr(method='pearson')

    # Use 1 - r as the distance. Force an exact-zero diagonal and clip floating-point
    # noise below 0 so the matrix is a valid distance matrix for squareform/linkage.
    corr_distance = 1 - corr_matrix.values
    np.fill_diagonal(corr_distance, 0.0)
    corr_distance = np.clip(corr_distance, 0.0, None)
    distance_vector = squareform(corr_distance, checks=False)
    Z_meta = linkage(distance_vector, method='average')

    # clustermap draws its own figure (sized via `figsize`); a preceding plt.figure()
    # only produced an extra empty figure.
    sns.clustermap(corr_matrix,
                   row_linkage=Z_meta,
                   col_linkage=Z_meta,
                   method='average',
                   metric='correlation',
                   cmap='vlag',
                   vmin=-0.6, vmax=0.6,
                   figsize=(14, 14),
                   xticklabels=False,
                   yticklabels=False)

    plt.suptitle("Signature Pearson Correlation", y=1.02)
    plt.show()

    NUM_GROUPS = 8
    clusters = fcluster(Z_meta, t=NUM_GROUPS, criterion='maxclust')
    print(f"\n--- Meta-Module Identification (Cut into {NUM_GROUPS} Groups) ---")

    for cid in np.unique(clusters):
        indices = np.where(clusters == cid)[0]
        current_sigs_names = corr_matrix.index[indices]

        # Count how many signatures in this group contain each gene.
        gene_counter = {}
        for sig_name in current_sigs_names:
            for gene in final_signatures[sig_name]:
                gene_counter[gene] = gene_counter.get(gene, 0) + 1

        top_genes = sorted(gene_counter.items(), key=lambda x: x[1], reverse=True)[:10]
        genes_only = [g[0] for g in top_genes]
        print(f"Group {cid} (Size: {len(indices)} signatures): Core Genes: {genes_only}")

elif total_final_signatures == 1:
    print("Only one signature remained after filtering; skipping Meta-Clustering.")
else:
    print("No signatures remained after filtering; skipping Meta-Clustering.")

In [None]:
# Meta-module group (from the 8-group cut above) to discard, presumed to be the
# cell-cycle program. NOTE(review): fcluster ids are not stable when anything upstream
# changes; confirm against the core genes printed by the previous cell.
REMOVE_GROUP_ID = 1

indices_to_remove = np.where(clusters == REMOVE_GROUP_ID)[0]
sigs_to_remove = set(corr_matrix.index[indices_to_remove])
sigs_to_keep = [sig for sig in sig_names if sig not in sigs_to_remove]

print(f"Original signatures: {len(sig_names)}")
print(f"Remaining signatures: {len(sigs_to_keep)}")

# Clustering needs at least two signatures.
if len(sigs_to_keep) > 1:
    sig_df_filtered = sig_df.loc[sigs_to_keep, :]
    # Drop genes that no longer belong to any remaining signature.
    sig_df_filtered = sig_df_filtered.loc[:, (sig_df_filtered != 0).any(axis=0)]
    corr_matrix_new = sig_df_filtered.T.corr(method='pearson')

    # Use 1 - r as the distance. Force an exact-zero diagonal and clip floating-point
    # noise below 0 so the matrix is a valid distance matrix for squareform/linkage.
    corr_distance_new = 1 - corr_matrix_new.values
    np.fill_diagonal(corr_distance_new, 0.0)
    corr_distance_new[corr_distance_new < 0] = 0
    distance_vector_new = squareform(corr_distance_new, checks=False)
    Z_meta_new = linkage(distance_vector_new, method='average')

    # clustermap creates its own figure (sized via `figsize`), so no plt.figure() first.
    sns.clustermap(corr_matrix_new,
                   row_linkage=Z_meta_new,
                   col_linkage=Z_meta_new,
                   method='average',
                   metric='correlation',
                   cmap='vlag',
                   vmin=-0.4, vmax=0.4,
                   figsize=(14, 14),
                   xticklabels=False,
                   yticklabels=False)

    plt.suptitle("Figure 2B: Meta-modules (Cell Cycle Removed)", y=1.02)
    plt.show()

    NUM_GROUPS_NEW = 6
    clusters_new = fcluster(Z_meta_new, t=NUM_GROUPS_NEW, criterion='maxclust')

    print(f"\n--- Final Meta-Module Identification (Cut into {NUM_GROUPS_NEW} Groups) ---")

    # label -> top-50 genes, where the label is "Group {cid}" plus an optional state tag.
    meta_modules = {}

    for cid in np.unique(clusters_new):
        indices = np.where(clusters_new == cid)[0]
        current_sigs_names = corr_matrix_new.index[indices]

        # Count how many signatures in this group contain each gene.
        gene_counter = {}
        for sig_name in current_sigs_names:
            for gene in final_signatures[sig_name]:
                gene_counter[gene] = gene_counter.get(gene, 0) + 1

        ranked_genes = sorted(gene_counter.items(), key=lambda x: x[1], reverse=True)
        genes_only = [g[0] for g in ranked_genes[:20]]

        # Heuristic state tag from canonical markers among the top-20 genes; first match wins.
        label = f"Group {cid}"
        if any(g in genes_only for g in ['OLIG1', 'PDGFRA', 'TNR', 'BCAN']):
            label += " (OPC-like)"
        elif any(g in genes_only for g in ['SLC1A3', 'MLC1', 'AQP4', 'S100B', 'CLU']):
            label += " (AC-like)"
        elif any(g in genes_only for g in ['HILPDA', 'CHI3L1', 'VIM', 'CD44', 'NDRG1']):
            label += " (MES-like)"
        elif any(g in genes_only for g in ['DCX', 'SOX11', 'DLL3', 'STMN2']):
            label += " (NPC-like)"

        print(f"{label} | Size: {len(indices)} signatures")
        print(f"Core Genes: {genes_only}\n")

        meta_modules[label] = [g[0] for g in ranked_genes[:50]]

elif len(sigs_to_keep) == 1:
    print(f"Error: Only one signature left after removing Group {REMOVE_GROUP_ID}; cannot cluster.")
else:
    print(f"Error: No signatures left after removing Group {REMOVE_GROUP_ID}.")

In [None]:
# (fcluster id, state name) in display order. NOTE(review): these ids are tied to the
# specific 6-group cut above; re-check them against the printed labels if anything changes.
group_order = [
    (5, 'MES-like'),
    (4, 'AC-like'),
    (3, 'OPC-like'),
    (1, 'NPC-like')
]

ordered_genes = []
ordered_sigs = []
group_boundaries_genes = [0]
group_boundaries_sigs = [0]

for gid, name in group_order:
    indices = np.where(clusters_new == gid)[0]

    current_group_sigs = corr_matrix_new.index[indices].tolist()
    ordered_sigs.extend(current_group_sigs)
    group_boundaries_sigs.append(len(ordered_sigs))

    # meta_modules keys are "Group {cid}" optionally followed by " (STATE)". Match the id
    # exactly so that e.g. "Group 1" does not also match "Group 10 ...".
    target_key = next(
        (k for k in meta_modules if k == f"Group {gid}" or k.startswith(f"Group {gid} ")),
        None,
    )
    if target_key is not None:
        ordered_genes.extend(meta_modules[target_key])
    else:
        print(f"Warning: No genes found for Group {gid}")
    # Always record a boundary (possibly an empty span) so the boundary list stays
    # aligned with group_order when labelling below.
    group_boundaries_genes.append(len(ordered_genes))

print(f"Total Signatures to plot: {len(ordered_sigs)}")
print(f"Total Genes to plot: {len(ordered_genes)}")

# Gene x signature membership matrix (1 = gene is in that signature).
fig2c_matrix = pd.DataFrame(0, index=ordered_genes, columns=ordered_sigs)
ordered_gene_set = set(ordered_genes)
for sig in ordered_sigs:
    intersect_genes = [g for g in final_signatures[sig] if g in ordered_gene_set]
    fig2c_matrix.loc[intersect_genes, sig] = 1

calc_height = len(ordered_genes) * 0.12 + 2
fig_height = max(10, calc_height)
fig, ax = plt.subplots(figsize=(10, fig_height))

custom_colors = ["#f7f7f7", "#f1a8b0"]
cmap_binary = mcolors.ListedColormap(custom_colors)

# Pin the color range to [0, 1] so 0/1 always map to Absent/Present, even when the
# matrix happens to be all 0s or all 1s.
sns.heatmap(fig2c_matrix,
            cmap=cmap_binary,
            vmin=0, vmax=1,
            cbar_kws={
                'label': 'Gene Membership',
                'shrink': 0.2,
                'aspect': 10,
                'ticks': [0.25, 0.75]
            },
            xticklabels=False,
            yticklabels=True,
            ax=ax)

cbar = ax.collections[0].colorbar
cbar.set_ticklabels(['Absent', 'Present'])
cbar.ax.tick_params(labelsize=10)

ax.tick_params(axis='y', labelsize=6, rotation=0)

for y in group_boundaries_genes[1:-1]:
    ax.hlines(y, *ax.get_xlim(), colors='black', linewidth=1)

for x in group_boundaries_sigs[1:-1]:
    ax.vlines(x, *ax.get_ylim(), colors='black', linewidth=1)

# State name to the right of each gene block; empty blocks get no label.
for i, (gid, name) in enumerate(group_order):
    group_size = group_boundaries_genes[i + 1] - group_boundaries_genes[i]
    if group_size == 0:
        continue
    y_pos = (group_boundaries_genes[i] + group_boundaries_genes[i + 1]) / 2
    label_size = 12 if group_size > 10 else 10
    ax.text(len(ordered_sigs) + 2, y_pos, name,
            va='center', fontsize=label_size, fontweight='bold')

ax.set_title("Figure 2C: Meta-module Genes vs Signatures", fontsize=14, y=1.01)
ax.set_ylabel("Meta-module Genes", fontsize=12)
ax.set_xlabel("Signatures (Grouped by State)", fontsize=12)

fig.tight_layout()
plt.show()

In [None]:
adata_score = adata_mal.copy()

# Neftel et al. meta-module marker genes for the four malignant cellular states.
# Genes absent from the data are dropped before scoring.
neftel_genes = {
    'NPC-like': ['CDK4', 'SOX4', 'SOX11', 'DCX', 'DLL3', 'HES6', 'OLIG2', 'ASCL1', 'STMN2'],
    'OPC-like': ['PDGFRA', 'OLIG1', 'OMG', 'PLP1', 'CSPG4', 'BCAN', 'PTPRZ1', 'TNR'],
    'AC-like':  ['EGFR', 'S100B', 'GFAP', 'SLC1A3', 'HOPX', 'MLC1', 'AQP4', 'CLU', 'ALDOC'],
    'MES-like': ['CHI3L1', 'CD44', 'VIM', 'ANXA1', 'ANXA2', 'FOSL2', 'TIMP1', 'NAMPT', 'YKL-40']
}

print(f"Scoring cells using Neftel et al. meta-module genes...")
for state, genes in neftel_genes.items():
    present = [g for g in genes if g in adata_score.var_names]
    sc.tl.score_genes(adata_score, gene_list=present, score_name=f'score_{state}', ctrl_size=100)

print("Calculating 2D state coordinates (D-score and Relative Difference X-axis)...")


def _signed_log2(delta):
    """Symmetric log compression: sign(delta) * log2(|delta| + 1)."""
    return np.sign(delta) * np.log2(np.abs(delta) + 1)


sc_npc, sc_opc, sc_ac, sc_mes = (
    adata_score.obs[f'score_{state}'] for state in ('NPC-like', 'OPC-like', 'AC-like', 'MES-like')
)

# Y axis: D = max(NPC, OPC) - max(AC, MES); positive means the cell sits in the upper half.
D = np.maximum(sc_npc, sc_opc) - np.maximum(sc_ac, sc_mes)

# X axis: upper half uses NPC - OPC, lower half uses MES - AC; both are log-compressed.
X = np.zeros(len(D))
upper_half = D > 0
lower_half = D <= 0
X[upper_half] = _signed_log2(sc_npc[upper_half] - sc_opc[upper_half])
X[lower_half] = _signed_log2(sc_mes[lower_half] - sc_ac[lower_half])

adata_score.obs['Neftel_X'] = X
adata_score.obs['Neftel_Y'] = D

# Color each cell by its highest-scoring state.
scores = adata_score.obs[['score_NPC-like', 'score_OPC-like', 'score_AC-like', 'score_MES-like']]
adata_score.obs['max_state'] = scores.idxmax(axis=1).str.replace('score_', '')

fig, ax = plt.subplots(figsize=(10, 10))

state_colors = {
    'NPC-like': '#1f77b4',
    'OPC-like': '#2ca02c',
    'AC-like':  '#bcbd22',
    'MES-like': '#d62728'
}

sns.scatterplot(
    data=adata_score.obs,
    x='Neftel_X',
    y='Neftel_Y',
    hue='max_state',
    palette=state_colors,
    s=10,
    alpha=0.6,
    linewidth=0,
    legend=False,
    ax=ax,
)

ax.axhline(0, color='black', linestyle='--', linewidth=0.8, alpha=0.5)
ax.axvline(0, color='black', linestyle='--', linewidth=0.8, alpha=0.5)

# Quadrant labels: NPC top-right, OPC top-left, MES bottom-right, AC bottom-left.
quadrant_positions = {
    'NPC-like': (1.8, 2.5),
    'OPC-like': (-1.8, 2.5),
    'MES-like': (1.8, -2.5),
    'AC-like': (-1.8, -2.5),
}
for state, (x_pos, y_pos) in quadrant_positions.items():
    ax.text(x_pos, y_pos, state, fontsize=16, fontweight='bold', color=state_colors[state], ha='center')

ax.set_xlabel('Relative meta-module score (X)', fontsize=12)
ax.set_ylabel('Relative meta-module score (Y)', fontsize=12)
ax.set_title('Replication of Figure 3F: GBM Cellular States', fontsize=16)

sns.despine(ax=ax)
plt.show()