- STELLA GRN Tutorial

    - extract gene programs from gene embedding network.

        - identify cell type marker gene programs.

        - investigate the correlations between genes within gene programs of interest.

    - refer to https://github.com/biomap-research/scFoundation/blob/main/genemodule/plot_geneemb.ipynb

In [None]:
import sys
sys.path.append("../../src")

import torch
import random
import pickle
import warnings
import collections
import numpy as np
import pandas as pd
import scanpy as sc
import networkx as nx
import matplotlib as mpl
import matplotlib.pyplot as plt

from tqdm.auto import tqdm
from matplotlib import rcParams
from stella import STELLADataCollatorV1
from torch.utils.data import DataLoader
from stella.tokenizer import Preprocessor
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics.pairwise import cosine_similarity
from stella.models.modeling_stella import STELLAForMaskedLM

# Seed Python's `random` module so the random cell sampling below is repeatable.
random.seed(42)

# Silence all warnings for the rest of the session.
warnings.filterwarnings("ignore")

### Step 1: Load Pretrained STELLA.

In [None]:
# Load the pretrained STELLA masked-LM weights; the path is relative to the
# directory this notebook is run from.
model = STELLAForMaskedLM.from_pretrained("../../pretrained_models/B100_L2048")

### Step 2: Load adata, and select celltypes you are interested in.

- For each cell type, we randomly select n (default: n=100) cells for forward propagation.

In [None]:
# Read the annotated expression data (.h5ad). The absolute path is specific to the
# original author's machine; replace it with the location of your own copy.
adata = sc.read_h5ad("/fse/home/wupengpeng/process_data/zheng68k/zheng68k.h5ad")

In [None]:
# Cell types of interest, and the number of cells (n) to draw from each of them.
n = 100
query_celltypes = ["CD8+ Cytotoxic T", "CD14+ Monocyte", "CD19+ B"]

# Sampling: for every queried cell type found in adata.obs["celltype"], draw up to
# n barcodes without replacement (all of them when the type has fewer than n cells).
selected_samples = {
    ct: random.sample(list(barcodes), k=min(n, len(barcodes)))
    for ct, barcodes in adata.obs.groupby("celltype", observed=True).groups.items()
    if ct in query_celltypes
}

# Fail early with a clear message instead of continuing with an empty selection.
if not selected_samples:
    raise ValueError(
        f"None of the requested cell types {query_celltypes} were found in "
        "adata.obs['celltype']."
    )

# Flatten the per-cell-type barcode lists into a single list (linear time, unlike
# repeated list concatenation through sum(..., [])).
selected_samples_barcodes = [
    barcode for barcodes in selected_samples.values() for barcode in barcodes
]

# extract selected samples
selected_adata = adata[selected_samples_barcodes].copy()

# filter all-zero expression genes
# For a sparse X, .sum(0) returns a 2-D (1 x n_genes) matrix; flatten it so the mask
# is always a plain 1-D boolean array.
selected_gene = np.asarray(selected_adata.X.sum(0)).ravel() > 0
selected_adata = selected_adata[:, selected_gene].copy()

# filter genes not in vocab
# NOTE(review): pickle.load can execute arbitrary code; only load a gene2id.pkl that
# comes from a trusted source (here: the one shipped with the stella package).
with open("../../src/stella/gene2id.pkl", "rb") as f:
    gene2idx = pickle.load(f)

selected_adata = selected_adata[:, selected_adata.var_names.isin(gene2idx.keys())].copy()

In [None]:
# Sanity check: show how many selected cells remain for each cell type.
selected_adata.obs["celltype"].value_counts()

### Step 3: Retrieve gene embeddings from STELLA.

In [None]:
# Configure the STELLA preprocessor: count-based gene/cell filtering is switched off
# and gene selection runs in "hvg" mode with nhvgs=300.
# NOTE(review): presumably this keeps the 300 most highly variable genes — confirm
# against stella.tokenizer.Preprocessor.
preprocessor = Preprocessor(
    filter_gene_by_counts=False,
    filter_cell_by_counts=False,
    select_genes_mode="hvg",
    nhvgs=300
)

# Back up a copy of the data before the full preprocessing call below.
unprocessed_selected_adata = selected_adata.copy()
# Apply only normalization and log1p to the backup (no expression binning); this
# un-binned copy is what the gene-program scoring in Step 5 runs on.
preprocessor.normalize_and_log1p(unprocessed_selected_adata)

# The fully preprocessed (binned) data is used to extract gene embeddings.
selected_adata = preprocessor(selected_adata)
# Convert the preprocessed AnnData into the dataset object consumed by the DataLoader.
ds = preprocessor.get_hf_dataset_from_adata(selected_adata)

In [None]:
# Device selection. GPU index 3 is preferred (as in the original setup); when that
# index does not exist on this machine, fall back to GPU 0, and to the CPU when CUDA
# is unavailable. If you do not have enough GPU memory, try using the CPU instead.
preferred_gpu = 3
if torch.cuda.is_available():
    gpu_index = preferred_gpu if preferred_gpu < torch.cuda.device_count() else 0
    device = torch.device(f"cuda:{gpu_index}")
else:
    device = torch.device("cpu")
model = model.to(device)

# shuffle=False keeps the batches in dataset order.
dl = DataLoader(ds, batch_size=128, shuffle=False, collate_fn=STELLADataCollatorV1)

# Inference only: evaluation mode and no gradient tracking.
model.eval()
all_outputs = []
with torch.no_grad():
    for batch in tqdm(dl):
        # Move every tensor of the batch onto the model's device.
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model.stella(**batch, output_hidden_states=True)
        # Keep the last hidden state on the host as a NumPy array.
        all_outputs.append(outputs.last_hidden_state.detach().cpu().numpy())
        # Release cached GPU memory between batches; meaningless on the CPU.
        if device.type == "cuda":
            torch.cuda.empty_cache()

In [None]:
# Stack the per-batch hidden states along the first axis, then average over that
# axis to obtain a single embedding matrix.
# NOTE(review): assumes the first axis indexes cells and that the remaining rows
# line up one-to-one, in order, with selected_adata.var_names — confirm against the
# model/collator output layout.
gene_embs = np.concatenate(all_outputs, axis=0).mean(0)
gene_embs = pd.DataFrame(gene_embs, index=selected_adata.var_names)

### Step 4: Construct gene_adata (genes x embeddings), then cluster these genes to get gene programs.

In [None]:
# Wrap the embedding table in an AnnData object: one observation per gene, with the
# embedding dimensions as variables.
gene_adata = sc.AnnData(gene_embs)
# Build the neighbor graph directly on the embedding matrix (use_rep="X").
sc.pp.neighbors(gene_adata, use_rep="X")
sc.tl.umap(gene_adata)
# Leiden clustering of the genes; the cluster labels are stored in .obs["leiden"]
# and each cluster is a candidate gene program.
sc.tl.leiden(gene_adata, resolution=5)
sc.pl.umap(gene_adata, color="leiden")

In [None]:
# The function is modified from https://github.com/bowang-lab/scGPT
def get_metagenes(gdata):
    """Group gene names by their Leiden cluster label.

    Args:
        gdata: Object whose ``.obs`` table has a ``"leiden"`` column and is indexed
            by gene name.

    Returns:
        A ``collections.defaultdict(list)`` mapping each cluster label to the list
        of index entries (genes) assigned to it, in ``.obs`` order.
    """
    programs = collections.defaultdict(list)
    # Series.items() yields (index entry, value) pairs, i.e. (gene, cluster label).
    for gene, cluster in gdata.obs["leiden"].items():
        programs[cluster].append(gene)
    return programs

# Group the genes of gene_adata by Leiden cluster.
metagenes = get_metagenes(gene_adata)

# Gene programs = clusters that contain at least 5 genes.
mgs = {
    cluster: members
    for cluster, members in metagenes.items()
    if len(members) >= 5
}

### Step 5: Calculate the score of all gene programs in each cell.

In [None]:
# The function is modified from https://github.com/bowang-lab/scGPT
def score_metagenes(adata, metagenes):
    """Score every gene program in every cell and store min-max scaled scores.

    For each ``program -> genes`` entry, ``sc.tl.score_genes`` is run on ``adata``
    and the resulting per-cell score, rescaled to [0, 1] across cells, is written to
    ``adata.obs["<program>_SCORE"]``. ``adata`` is modified in place and nothing is
    returned.

    Scoring is best-effort: when a program cannot be scored, the failure is reported
    and its score column is filled with 0.0, so every program still gets a column.

    Args:
        adata: AnnData object to score; receives one ``*_SCORE`` column per program.
        metagenes: Mapping from program id to the list of gene names in it.
    """
    for program, genes in tqdm(metagenes.items()):
        column = f"{program}_SCORE"
        try:
            sc.tl.score_genes(adata, score_name=column, gene_list=genes)
            # MinMaxScaler expects a 2-D (n_cells, 1) array.
            raw = np.asarray(adata.obs[column], dtype=np.float64).reshape(-1, 1)
            adata.obs[column] = MinMaxScaler().fit_transform(raw).ravel()
        except Exception as err:
            # Keep the best-effort fallback, but make the failure visible. print is
            # used rather than warnings.warn because this notebook installs a global
            # "ignore" warning filter, which would hide a warning.
            print(
                f"score_metagenes: could not score program {program!r} "
                f"({err}); filling {column} with 0.0"
            )
            adata.obs[column] = 0.0

# Here, unprocessed_selected_adata (normalized and log1p-transformed, not binned) is used,
# because binning the data would significantly reduce the heterogeneity between genes.
score_metagenes(unprocessed_selected_adata,mgs)

- construct genescoreadata (cells x gene_programs_score), and find cell type marker gene programs.

In [None]:
# Collect every .obs column whose name contains "SCORE" (the per-program scores).
obs_table = unprocessed_selected_adata.obs
scorelist = [column for column in obs_table.columns if "SCORE" in column]

# AnnData of cells x gene-program scores, annotated with each cell's type.
genescoreadata = sc.AnnData(obs_table[scorelist])
genescoreadata.obs["celltype"] = obs_table["celltype"]

# Rank the gene programs per cell type to find cell type marker programs.
sc.tl.rank_genes_groups(genescoreadata, groupby="celltype")

In [None]:
# Plot the 3 top-ranked gene programs per cell type as a matrix plot.
# Font type 42 applies to the PDF/PS backends; it does not affect the PNG saved below.
rcParams["pdf.fonttype"] = 42
rcParams["ps.fonttype"] = 42
# show=False keeps the figure open so it can be saved before being displayed.
sc.pl.rank_genes_groups_matrixplot(
    genescoreadata,
    n_genes=3,
    standard_scale="var",
    cmap="Blues",
    show=False
)
plt.savefig("matrixplot_genemodule.png", bbox_inches="tight", dpi=300)
plt.show()

In [None]:
# Print the member genes of three gene programs.
# NOTE(review): the ids '23', '17' and '24' are hard-coded — presumably the marker
# programs read off the matrix plot of one particular run. A KeyError here means that
# cluster id is not among the kept programs (>= 5 genes) in your run; pick the ids
# from your own plot instead.
print(mgs['23'],'\n',mgs['17'],'\n',mgs['24'])

### Step 6: Visualize network connectivity within desired gene program

We can further visualize the connectivity between genes within any gene program of interest from Step 5

In [None]:
# Pick one gene program (by cluster id) and pull out the embeddings of its genes.
# NOTE(review): "23" is hard-coded; replace it with a program id from your own run.
gene_programs = mgs["23"]
gene_programs_embs = gene_embs.loc[gene_programs].copy()

In [None]:
# The function is modified from https://github.com/bowang-lab/scGPT
# Build an undirected graph over the genes of the selected program, with edges
# weighted by the cosine similarity of their embeddings, then draw and save it.
G = nx.Graph()
similarities = cosine_similarity(gene_programs_embs)
genes = list(gene_programs_embs.index.values)
# Zero out (near-)perfect similarities; this removes the diagonal (self-similarity)
# so that no self-loops are created.
similarities[similarities > 0.9999] = 0

# One weighted edge per non-zero entry. Both (i, j) and (j, i) are visited; because
# the graph is undirected they collapse into a single edge. The index names i/j are
# local to the comprehension, so the sample-size variable `n` is left untouched.
edges = [
    (genes[i], genes[j], {"weight": similarities[i, j]})
    for i, j in zip(*similarities.nonzero())
]
G.add_nodes_from(genes)
G.add_edges_from(edges)

widths = nx.get_edge_attributes(G, "weight")
weightvalue = np.array(list(widths.values()))

# Rescale the weights to line widths in [0, 3]. When there are no edges, or all
# weights are equal, min-max scaling is undefined; use a constant width instead.
weight_span = weightvalue.max() - weightvalue.min() if weightvalue.size else 0.0
if weight_span > 0:
    scaled_weightvalue = (weightvalue - weightvalue.min()) / weight_span * 3
else:
    scaled_weightvalue = np.full(weightvalue.shape, 1.5)

# Edges sorted by descending weight, and the (up to) 3 heaviest gene pairs.
widsorted = sorted(widths.items(), key=lambda x: x[1], reverse=True)
toppair = np.array(list(widths))[np.argsort(weightvalue)[::-1][:3]]
pos = nx.spring_layout(G, k=0.4, iterations=15, seed=42)

nx.draw_networkx_edges(
    G,
    pos,
    edgelist=list(widths),
    edge_color=list(widths.values()),
    width=scaled_weightvalue,
    edge_cmap=mpl.colormaps["cool"],
    alpha=1,
)

nx.draw_networkx_labels(G, pos, font_size=15, font_family="sans-serif")

# Label the (up to) 5 strongest edges with their rank; slicing avoids an IndexError
# when the program has fewer than 5 edges.
edge_labels = {
    edge: f"rank{rank}" for rank, (edge, _) in enumerate(widsorted[:5], start=1)
}
nx.draw_networkx_edge_labels(G, pos, edge_labels, font_size=10)

ax = plt.gca()
ax.margins(0.08)
plt.axis("off")
plt.tight_layout()
rcParams["pdf.fonttype"] = 42
rcParams["ps.fonttype"] = 42
plt.savefig("genemodule.png", bbox_inches="tight", dpi=300)