In [16]:
import os
import torch
import pandas as pd
import anndata as ad
from collections import defaultdict
from deepsc.scgpt_utils.grn import GeneEmbedding
import seaborn as sns
import matplotlib.pyplot as plt
import scanpy as sc
from deepsc.utils.utils import extract_state_dict, sample_weight_norms, report_loading_result
from deepsc.utils.utils import report_loading_result
from deepsc.models.deepsc_new.model import DeepSC
import tqdm
import networkx as nx
import gseapy as gp
from gears import PertData
from pathlib import Path
import numpy as np

In [17]:
pretrained_model_path="/home/angli/baseline/DeepSC/results/latest_checkpoint_epoch3.ckpt"
data_path_dir="/home/angli/baseline/DeepSC/data/"
gene_map_path="/home/angli/baseline/DeepSC/data/gene_map.csv"

In [18]:
def load_pretrained_model( model, pretrained_model_path):
    ckpt_path = pretrained_model_path
    assert os.path.exists(ckpt_path), f"找不到 ckpt: {ckpt_path}"
    print(f"[LOAD] 读取 checkpoint: {ckpt_path}")
    raw = torch.load(ckpt_path, map_location="cpu")
    state_dict = extract_state_dict(raw)
    sample_weight_norms(model, state_dict, k=5)
    load_info = model.load_state_dict(state_dict, strict=False)
    report_loading_result(load_info)
    return model

In [19]:
import types # 导入 types 模块
model_config = {
    "embedding_dim": 256,
    "num_genes": 34683,
    "num_layers": 10,
    "num_heads": 8,
    "attn_dropout": 0.1,
    "ffn_dropout": 0.1,
    "fused": False,
    "num_bins": 5,  # 从 Hydra 变量 ${num_bin} 转换
    "mask_layer_start": 100,
    "enable_l0": False,       # 从 Hydra 变量 ${enable_l0} 转换
    "enable_mse": True,     # 从 Hydra 变量 ${enable_mse} 转换
    "enable_ce": True,       # 从 Hydra 变量 ${enable_ce} 转换
    "num_layers_ffn": 2,
    "use_moe_regressor": True,
    "number_of_experts": 3,
    "use_M_matrix": False,
    "gene_embedding_participate_til_layer": 3,
    "moe": types.SimpleNamespace(**{
        "n_moe_layers": 4,
        "use_moe_ffn": True,
        "dim": 256,
        "moe_inter_dim": 512,
        "n_routed_experts": 2,
        "n_activated_experts": 2,
        "n_shared_experts": 1,
        "score_func": "softmax",
        "route_scale": 1.0,
        "world_size": 1,
        "rank": 0,
    })
}

In [20]:
model = DeepSC(**model_config)

In [21]:
model = load_pretrained_model(model, pretrained_model_path)
device = next(model.parameters()).device

[LOAD] 读取 checkpoint: /home/angli/baseline/DeepSC/results/latest_checkpoint_epoch3.ckpt
[LOAD] 抽样参数范数对比（加载前 -> 加载后）：
  - expression_layers.2.expr_attn.W_V.bias: 0.558314 -> 0.608517
  - expression_layers.4.norm_expr2.weight: 16.000000 -> 16.453827
  - expression_layers.4.ffn_expr.shared_experts.w3.bias: 0.817375 -> 0.816993
  - layers.9.ffn_gene.layers.0.0.bias: 1.171525 -> 1.164386
  - layers.5.norm_gene2.bias: 0.000000 -> 0.180650
[LOAD] missing_keys: 0 | unexpected_keys: 0


In [22]:
# 1) 读 CSV，构建映射
df = pd.read_csv(gene_map_path)
# feature_name -> id（允许重复 id）
gene2idx = dict(zip(df["feature_name"], df["id"]))
# id -> [feature_name, ...]（可选：用于反查）
idx2genes = defaultdict(list)
for name, i in zip(df["feature_name"], df["id"]):
    idx2genes[int(i)].append(name)

In [23]:
data_dir = Path("data_path_dir")
pert_data = PertData(data_dir)
pert_data.load(data_name="adamson")
adata = sc.read(data_dir / "adamson/perturb_processed.h5ad")
ori_batch_col = "control"
adata.obs["celltype"] = adata.obs["condition"].astype("category")
adata.obs["str_batch"] = adata.obs["control"].astype(str)
data_is_raw = False

Found local copy...
Found local copy...
Found local copy...
These perturbations are not in the GO graph and their perturbation can thus not be predicted
['SRPR+ctrl' 'SLMO2+ctrl' 'TIMM23+ctrl' 'AMIGO3+ctrl' 'KCTD16+ctrl']
Local copy of pyg dataset is detected. Loading...
Done!


In [27]:
# 在 AnnData.var 里新建一列，把 gene_name 映射到 ID
adata.var["id_in_vocab"] = [1 if g in gene2idx else -1 for g in adata.var["gene_name"]]

# 取出所有基因的 id 数组
gene_ids_in_vocab = np.array(adata.var["id_in_vocab"])

# 过滤掉映射不到 id 的基因
adata = adata[:, adata.var["id_in_vocab"] >= 0]


In [28]:
TF_name = 'BHLHE40'
adata = adata[adata.obs.condition.isin(['{}+ctrl'.format(TF_name), 'ctrl'])].copy()
np.unique(adata.obs.condition)
print(adata)

AnnData object with n_obs × n_vars = 24767 × 4079
    obs: 'condition', 'cell_type', 'dose_val', 'control', 'condition_name', 'celltype', 'str_batch'
    var: 'gene_name', 'gene_id', 'id_in_vocab'
    uns: 'non_dropout_gene_idx', 'non_zeros_gene_idx', 'rank_genes_groups_cov_all', 'top_non_dropout_de_20', 'top_non_zero_de_20'


In [30]:
def discretize_expression(input_values, num_bins=5):
    """对表达值进行离散化分箱"""
    batch_size = input_values.shape[0]
    discrete_input_bins = torch.zeros_like(input_values, dtype=torch.long)

    for i in range(batch_size):
        row_vals = input_values[i]
        valid_mask = row_vals != -1.0
        if valid_mask.any():
            valid_vals = row_vals[valid_mask]
            min_val = valid_vals.min()
            max_val = valid_vals.max()
            norm = (valid_vals - min_val) / (max_val - min_val + 1e-8)
            bins = torch.floor(norm * (num_bins - 1)).long()
            bins = torch.clamp(bins, 0, num_bins - 1) + 1
            discrete_input_bins[i][valid_mask] = bins

    return discrete_input_bins

In [31]:
sc.pp.highly_variable_genes(
    adata,
    layer=None,
    n_top_genes=1200,
    batch_key="str_batch",
    flavor="seurat_v3" if data_is_raw else "cell_ranger",
    subset=False,
)
adata.var.loc[adata.var[adata.var.gene_name==TF_name].index, 'highly_variable'] = True
adata = adata[:, adata.var["highly_variable"]].copy()
print(adata)

AnnData object with n_obs × n_vars = 24767 × 1200
    obs: 'condition', 'cell_type', 'dose_val', 'control', 'condition_name', 'celltype', 'str_batch'
    var: 'gene_name', 'gene_id', 'id_in_vocab', 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'highly_variable_nbatches', 'highly_variable_intersection'
    uns: 'non_dropout_gene_idx', 'non_zeros_gene_idx', 'rank_genes_groups_cov_all', 'top_non_dropout_de_20', 'top_non_zero_de_20', 'hvg'


In [33]:
all_values  = adata.X.A if hasattr(adata.X, "A") else adata.X 
all_values = torch.tensor(all_values, dtype=torch.float)
all_discrete_bins = discretize_expression(all_values, 5)

TypeError: sparse array length is ambiguous; use getnnz() or shape[0]