- STELLA GRN Tutorial

    - Harness STELLA gene embeddings to predict cell-type-specific transcription factors (TFs).

    - refer to https://github.com/biomap-research/scFoundation/blob/main/genemodule/plot_geneemb.ipynb

In [None]:
import sys
sys.path.append("../../src")

import torch
import random
import pickle
import numpy as np
import pandas as pd
import scanpy as sc

from tqdm.auto import tqdm
from stella import STELLADataCollatorV1
from torch.utils.data import DataLoader
from stella.tokenizer import Preprocessor
from sklearn.metrics.pairwise import cosine_similarity
from stella.models.modeling_stella import STELLAForMaskedLM

random.seed(42)

### Step 1: Load Pretrained STELLA.

In [None]:
# Directory that holds the pretrained STELLA checkpoint.
pretrained_dir = "../../pretrained_models/B100_L2048"
model = STELLAForMaskedLM.from_pretrained(pretrained_dir)

### Step 2: Load adata, and select celltypes you are interested in.

- For each cell type, we randomly select n (default: n=100) cells for forward propagation.

In [None]:
# Location of the zheng68k .h5ad file; this path is machine-specific, so adjust it
# to wherever the dataset lives on your system.
h5ad_path = "/fse/home/wupengpeng/process_data/zheng68k/zheng68k.h5ad"
adata = sc.read_h5ad(h5ad_path)

In [None]:
# Randomly sample up to `n` cells from each cell type of interest
# ('CD8+ Cytotoxic T', 'CD14+ Monocyte', 'CD19+ B').
n = 100
query_celltypes = ["CD8+ Cytotoxic T", "CD14+ Monocyte", "CD19+ B"]

# sample procedure: `.groups` maps every cell type to the barcodes of its cells;
# min(...) keeps all cells of a type that has fewer than `n` of them
selected_samples = {
    ct: random.sample(list(barcodes), k=min(n, len(barcodes)))
    for ct, barcodes in adata.obs.groupby("celltype", observed=True).groups.items()
    if ct in query_celltypes
}

# flatten the per-cell-type lists into one list of barcodes
# (a comprehension avoids the repeated list copying of `sum(lists, [])`)
selected_samples_barcodes = [
    barcode for barcodes in selected_samples.values() for barcode in barcodes
]

# extract selected samples
selected_adata = adata[selected_samples_barcodes].copy()

# filter all-zero expression genes; the per-gene sums are flattened to a 1-D
# boolean mask because a sparse matrix returns them as a 2-D (1, n_genes) matrix
selected_gene = np.asarray(selected_adata.X.sum(axis=0)).ravel() > 0
selected_adata = selected_adata[:, selected_gene].copy()

# filter genes not in vocab
with open("../../src/stella/gene2id.pkl", "rb") as f:
    gene2idx = pickle.load(f)

selected_adata = selected_adata[:, selected_adata.var_names.isin(gene2idx.keys())].copy()

# save the expression matrix and the cell-type labels of the subset
selected_adata.to_df().to_csv("./grn_data/adata_subset.csv")
selected_adata.obs.celltype.to_csv("./grn_data/adata_subset_celltype.csv")

In [None]:
# show how many cells were sampled for each cell type
selected_adata.obs.celltype.value_counts()

### Step 3: Find transcription factor (TF) genes in `selected_adata`.

In [None]:
# Read the TF list: a headerless file whose first column is used as the TF names.
TF = pd.read_csv('./grn_data/allTFs_hg38.txt', header=None).iloc[:, 0].to_numpy()

# keep the TFs that are present among the genes of the subset, in TF-list order
subset_genes = selected_adata.var_names
tfs_in_selected_adata = [gene for gene in TF if gene in subset_genes]
print(f"There are {len(tfs_in_selected_adata)} tfs in all {len(selected_adata.var_names)} genes. (Total tfs: {len(TF)})")

### Step 4: Retrieve gene embeddings from STELLA.

In [None]:
# Preprocess selected_adata. Both count-based filters are switched off and the
# genes already present in the subset are passed as the gene list.
preprocessor_options = dict(
    filter_gene_by_counts=False,
    filter_cell_by_counts=False,
    select_genes_mode="specify",
    gene_list=selected_adata.var_names,
)
preprocessor = Preprocessor(**preprocessor_options)

selected_adata = preprocessor(selected_adata)
# dataset that is fed to the DataLoader in the next cell
ds = preprocessor.get_hf_dataset_from_adata(selected_adata)

In [None]:
# Set the device. If you do not have enough GPU memory, try using the CPU!
# GPU 3 is preferred; when fewer than four GPUs are visible, fall back to GPU 0
# instead of failing on an invalid device index when the model is moved.
preferred_gpu = 3
if torch.cuda.is_available():
    gpu_index = preferred_gpu if preferred_gpu < torch.cuda.device_count() else 0
    device = torch.device(f"cuda:{gpu_index}")
else:
    device = torch.device("cpu")
model = model.to(device)

# shuffle=False keeps the outputs in the same order as the cells in the dataset
dl = DataLoader(ds, batch_size=8, shuffle=False, collate_fn=STELLADataCollatorV1)

model.eval()
all_outputs = []
with torch.no_grad():
    for batch in tqdm(dl):
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model.stella(**batch, output_hidden_states=True)
        # move the last hidden state to host memory so GPU memory does not accumulate
        all_outputs.append(outputs.last_hidden_state.detach().cpu().numpy())
        torch.cuda.empty_cache()

In [None]:
# Join the per-batch outputs along the first axis and average over it, then label
# the rows of the result with the gene names.
# NOTE(review): assumes the outputs are (cells, genes, hidden) with one position per
# gene in `selected_adata.var_names` order — confirm against the collator/model.
gene_embs = pd.DataFrame(
    np.concatenate(all_outputs, axis=0).mean(axis=0),
    index=selected_adata.var_names,
)

### Step 5: Calculate the cosine similarity between transcription factors and other genes.

In [None]:
n_candidate_target_genes = 1000

# For every TF, rank all genes by the cosine similarity between their embedding and
# the TF's embedding, and keep the top candidates as [TF, target, importance] rows.
coexplist = []
for tf in tqdm(tfs_in_selected_adata):
    tmpsim = cosine_similarity(gene_embs.loc[tf, :].values.reshape(1, -1), gene_embs)
    # zero out (near-)perfect matches so that the TF itself is not ranked first
    tmpsim[tmpsim > 0.9999] = 0
    tmpsimdf = pd.DataFrame(tmpsim, columns=selected_adata.var_names, index=['simi']).T
    # head() is capped at the number of available genes, so a subset with fewer than
    # n_candidate_target_genes genes does not raise an IndexError
    top_targets = tmpsimdf.sort_values('simi', ascending=False).head(n_candidate_target_genes)
    # importance is the similarity scaled by 100
    coexplist.extend(
        [tf, target, simi * 100]
        for target, simi in zip(top_targets.index, top_targets['simi'].tolist())
    )

In [None]:
# Assemble the TF-target table and write it as a tab-separated file; this file is
# the input of the `pyscenic ctx` step below.
grn_columns = ['TF', 'target', 'importance']
grndf = pd.DataFrame(coexplist, columns=grn_columns)
grndf.to_csv('./grn_data/scf_grn_1000.tsv', sep='\t', index=False)
grndf.head()

### Step 6: Predicting cell type specific transcription factors using pySCENIC.

- Because pySCENIC has compatibility issues, the subsequent code can be run in a separate, new conda environment.

- Requirements:

    - notebook

    - pyscenic

    - pandas

    - matplotlib

    - seaborn

    - adjustText

    - numpy == 1.23.5

    - numba == 0.59.0

In [None]:
# Download the cisTarget ranking database and the motif-to-TF annotation table.
# -c resumes a partial download and does nothing for a file that is already complete,
# so re-running this cell does not fetch the large files again as "*.1" copies.
!wget -c -P grn_data https://resources.aertslab.org/cistarget/databases/homo_sapiens/hg38/refseq_r80/mc_v10_clust/gene_based/hg38_10kbp_up_10kbp_down_full_tx_v10_clust.genes_vs_motifs.rankings.feather
!wget -c -P grn_data https://resources.aertslab.org/cistarget/motif2tf/motifs-v10nr_clust-nr.hgnc-m0.001-o0.0.tbl

In [None]:
# Run `pyscenic ctx` on the TF-target table written in Step 5, using the ranking
# database and the motif annotation table downloaded above and the expression matrix
# of the sampled cells; the result is written to ./grn_data/regulons_1000.csv.
!pyscenic ctx \
    ./grn_data/scf_grn_1000.tsv \
    ./grn_data/hg38_10kbp_up_10kbp_down_full_tx_v10_clust.genes_vs_motifs.rankings.feather \
    --annotations_fname ./grn_data/motifs-v10nr_clust-nr.hgnc-m0.001-o0.0.tbl \
    --expression_mtx_fname ./grn_data/adata_subset.csv \
    --output ./grn_data/regulons_1000.csv \
    --mask_dropouts \
    --num_workers 16

# Alternative: the same command inside the aertslab/pyscenic:0.12.1 Docker image.
# The host path given to -v is machine-specific; replace it with your own directory.
# !docker run -it --rm -v /fse/home/wupengpeng/STELLA/tutorials/02_GRN/:/GRN/ aertslab/pyscenic:0.12.1 \
#     pyscenic ctx \
#     /GRN/grn_data/scf_grn_1000.tsv \
#     /GRN/grn_data/hg38_10kbp_up_10kbp_down_full_tx_v10_clust.genes_vs_motifs.rankings.feather \
#     --annotations_fname /GRN/grn_data/motifs-v10nr_clust-nr.hgnc-m0.001-o0.0.tbl \
#     --expression_mtx_fname /GRN/grn_data/adata_subset.csv \
#     --output /GRN/grn_data/regulons_1000.csv \
#     --mask_dropouts \
#     --num_workers 16

In [None]:
# Run `pyscenic aucell` on the expression matrix of the sampled cells and the
# regulons produced by the `pyscenic ctx` step; the result is written to
# ./grn_data/auc_mtx_1000.csv.
!pyscenic aucell \
    ./grn_data/adata_subset.csv \
    ./grn_data/regulons_1000.csv \
    -o ./grn_data/auc_mtx_1000.csv \
    --num_workers 16

# Alternative: the same command inside the aertslab/pyscenic:0.12.1 Docker image.
# The host path given to -v is machine-specific; replace it with your own directory.
# !docker run -it --rm -v /fse/home/wupengpeng/STELLA/tutorials/02_GRN/:/GRN/ aertslab/pyscenic:0.12.1 \
#     pyscenic aucell \
#     /GRN/grn_data/adata_subset.csv \
#     /GRN/grn_data/regulons_1000.csv \
#     -o /GRN/grn_data/auc_mtx_1000.csv \
#     --num_workers 16

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

from adjustText import adjust_text
from pyscenic.plotting import plot_rss
from pyscenic.rss import regulon_specificity_scores

- Calculate RSS score

In [None]:
# Load the AUCell matrix and the cell-type labels of the sampled cells; both files
# use their first column as the index.
auc_path = "./grn_data/auc_mtx_1000.csv"
labels_path = "./grn_data/adata_subset_celltype.csv"
auc_mtx = pd.read_csv(auc_path, index_col=0)
celltype = pd.read_csv(labels_path, index_col=0)

# compute the regulon specificity scores, save them, and display the table
rss_cellType = regulon_specificity_scores(auc_mtx, celltype["celltype"])
rss_cellType.to_csv("./grn_data/RSS.csv")
rss_cellType

- Plot

In [None]:
cats = sorted(set(celltype.celltype))

# One panel per cell type. The grid and the figure width follow the number of cell
# types (5 inches per panel, i.e. 15x5 for three types) instead of being fixed at
# three columns, so selecting more cell types does not break `add_subplot`.
n_panels = len(cats)
fig = plt.figure(figsize=(5 * max(n_panels, 1), 5))
for num, c in enumerate(cats, start=1):
    x = rss_cellType.T[c]
    ax = fig.add_subplot(1, n_panels, num)
    plot_rss(rss_cellType, c, top_n=3, max_n=None, ax=ax)
    # pad the y-range by 5% of the score span on both sides
    y_pad = (x.max() - x.min()) * 0.05
    ax.set_ylim(x.min() - y_pad, x.max() + y_pad)
    for t in ax.texts:
        t.set_fontsize(12)
    # per-panel axis labels are cleared; shared labels are added to the figure below
    ax.set_ylabel("")
    ax.set_xlabel("")
    # move the text labels apart and connect them to their points with grey lines
    adjust_text(
        ax.texts,
        autoalign="xy",
        ha="right",
        va="bottom",
        arrowprops=dict(arrowstyle="-", color="lightgrey"),
        precision=0.001,
    )

# shared axis labels for the whole figure
fig.text(0.5, 0.0, "Regulon", ha="center", va="center", size="x-large")
fig.text(
    0.00,
    0.5,
    "Regulon specificity score (RSS)",
    ha="center",
    va="center",
    rotation="vertical",
    size="x-large",
)
plt.tight_layout()
# NOTE(review): these settings are applied after the figure has been drawn, so they
# presumably only matter for figures created later in the session — confirm intent.
plt.rcParams.update(
    {
        "figure.autolayout": True,
        "figure.titlesize": "large",
        "axes.labelsize": "medium",
        "axes.titlesize": "large",
        "xtick.labelsize": "medium",
        "ytick.labelsize": "medium",
    }
)
# font type 42 (TrueType) for the PDF/PS backends
plt.rcParams["pdf.fonttype"] = 42
plt.rcParams["ps.fonttype"] = 42
plt.savefig("RSS.pdf", bbox_inches="tight")