In [1]:
import os
import torch
import pandas as pd
import anndata as ad
from collections import defaultdict
from deepsc.scgpt_utils.grn import GeneEmbedding
import seaborn as sns
import matplotlib.pyplot as plt
import scanpy as sc
from deepsc.utils.utils import extract_state_dict, sample_weight_norms, report_loading_result
from deepsc.utils.utils import report_loading_result
from deepsc.models.deepsc_new.model import DeepSC
import tqdm
import networkx as nx
import gseapy as gp
from gears import PertData
from pathlib import Path
import numpy as np

  return module_get_attr_redirect(attr_name, deprecated_mapping=_DEPRECATED)
  return module_get_attr_redirect(attr_name, deprecated_mapping=_DEPRECATED)
  return module_get_attr_redirect(attr_name, deprecated_mapping=_DEPRECATED)
  return module_get_attr_redirect(attr_name, deprecated_mapping=_DEPRECATED)
  return module_get_attr_redirect(attr_name, deprecated_mapping=_DEPRECATED)
  return module_get_attr_redirect(attr_name, deprecated_mapping=_DEPRECATED)
  return module_get_attr_redirect(attr_name, deprecated_mapping=_DEPRECATED)
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Project root. It can be overridden with the DEEPSC_ROOT environment variable
# so the notebook is not tied to one machine; the default is the original
# hard-coded location, so behaviour is unchanged when the variable is unset.
DEEPSC_ROOT = os.environ.get("DEEPSC_ROOT", "/home/angli/baseline/DeepSC")
# Checkpoint file whose weights are loaded into the model.
pretrained_model_path = os.path.join(DEEPSC_ROOT, "results", "latest_checkpoint_epoch3.ckpt")
# Data directory; the empty last component keeps the trailing separator of the
# original value.
data_path_dir = os.path.join(DEEPSC_ROOT, "data", "")
# CSV that supplies the `feature_name` and `id` columns of the gene mapping.
gene_map_path = os.path.join(DEEPSC_ROOT, "data", "gene_map.csv")

In [3]:
def load_pretrained_model(model, pretrained_model_path):
    """Load checkpoint weights from disk into ``model`` and return it.

    The checkpoint is read onto the CPU, reduced to a state dict by
    ``extract_state_dict`` and applied with ``strict=False``; the result of
    that load is passed to ``report_loading_result``.

    Args:
        model: Module that receives the weights (updated in place).
        pretrained_model_path: Filesystem path of the checkpoint file.

    Returns:
        The same ``model`` object that was passed in.

    Raises:
        AssertionError: If ``pretrained_model_path`` does not exist.
    """
    ckpt_path = pretrained_model_path
    assert os.path.exists(ckpt_path), f"找不到 ckpt: {ckpt_path}"
    print(f"[LOAD] 读取 checkpoint: {ckpt_path}")

    checkpoint = torch.load(ckpt_path, map_location="cpu")
    weights = extract_state_dict(checkpoint)
    # Called before the weights are applied; presumably compares the norms of
    # k sampled parameters between the model and the checkpoint (helper is
    # defined outside this notebook -- confirm).
    sample_weight_norms(model, weights, k=5)
    report_loading_result(model.load_state_dict(weights, strict=False))
    return model

In [4]:
import types  # SimpleNamespace exposes the nested MoE settings as attributes

# Keyword arguments for the DeepSC constructor. Insertion order matches the
# original literal.
model_config = dict(
    embedding_dim=256,
    num_genes=34683,
    num_layers=10,
    num_heads=8,
    attn_dropout=0.1,
    ffn_dropout=0.1,
    fused=False,
    num_bins=5,  # converted from the Hydra variable ${num_bin}
    mask_layer_start=100,
    enable_l0=False,  # converted from the Hydra variable ${enable_l0}
    enable_mse=True,  # converted from the Hydra variable ${enable_mse}
    enable_ce=True,  # converted from the Hydra variable ${enable_ce}
    num_layers_ffn=2,
    use_moe_regressor=True,
    number_of_experts=3,
    use_M_matrix=False,
    gene_embedding_participate_til_layer=3,
    # Mixture-of-experts settings, accessed by attribute rather than by key.
    moe=types.SimpleNamespace(
        n_moe_layers=4,
        use_moe_ffn=True,
        dim=256,
        moe_inter_dim=512,
        n_routed_experts=2,
        n_activated_experts=2,
        n_shared_experts=1,
        score_func="softmax",
        route_scale=1.0,
        world_size=1,
        rank=0,
    ),
)

In [5]:
# Build a DeepSC instance from the hyperparameters collected in model_config.
model = DeepSC(**model_config)

In [6]:
# Copy the checkpoint weights into the model built above.
model = load_pretrained_model(model, pretrained_model_path)
# Device of the model's first parameter.
device = next(model.parameters()).device

[LOAD] 读取 checkpoint: /home/angli/baseline/DeepSC/results/latest_checkpoint_epoch3.ckpt
[LOAD] 抽样参数范数对比（加载前 -> 加载后）：
  - layers.4.gene_attn.out_proj.bias: 0.602360 -> 0.559752
  - regressor.experts.1.4.bias: 0.000000 -> 0.147406
  - layers.0.norm_gene2.bias: 0.000000 -> 0.591487
  - layers.2.ffn_expr.layers.1.0.bias: 0.286975 -> 0.297348
  - layers.4.norm_gene2.bias: 0.000000 -> 0.234143
[LOAD] missing_keys: 0 | unexpected_keys: 0


In [7]:
# 1) Read the CSV and build the lookup tables.
df = pd.read_csv(gene_map_path)
# feature_name -> id (several names may share one id)
gene2idx = {name: gene_id for name, gene_id in zip(df["feature_name"], df["id"])}
# id -> [feature_name, ...] (optional: used for reverse lookup)
idx2genes = defaultdict(list)
for gene_id, name in zip(df["id"], df["feature_name"]):
    idx2genes[int(gene_id)].append(name)

In [8]:
# Root folder of the perturbation datasets.
# NOTE(review): this is a separate hard-coded path; `data_path_dir` defined
# earlier in the notebook is not used here -- confirm which one is intended.
data_dir = Path("/home/angli/baseline/DeepSC/src/deepsc/finetune/data_path_dir")
# obs column copied (as strings) into "str_batch" below.
ori_batch_col = "control"

# Load the "adamson" dataset through GEARS.
pert_data = PertData(data_dir)
pert_data.load(data_name="adamson")

# Read the processed AnnData file and derive the two obs columns.
adata = sc.read(data_dir / "adamson/perturb_processed.h5ad")
adata.obs["celltype"] = adata.obs["condition"].astype("category")
adata.obs["str_batch"] = adata.obs[ori_batch_col].astype(str)

Found local copy...
Found local copy...
Found local copy...
These perturbations are not in the GO graph and their perturbation can thus not be predicted
['SRPR+ctrl' 'SLMO2+ctrl' 'TIMM23+ctrl' 'AMIGO3+ctrl' 'KCTD16+ctrl']
Local copy of pyg dataset is detected. Loading...
Done!


In [9]:
# Flag every gene: 1 when its name is a key of gene2idx, -1 otherwise.
# The stored value is a membership flag, not the vocabulary id itself.
adata.var["id_in_vocab"] = [1 if gene in gene2idx else -1 for gene in adata.var["gene_name"]]
# Captured before the column filter below, so it has one entry per gene of the
# unfiltered object.
gene_ids_in_vocab = np.array(adata.var["id_in_vocab"])
# Keep only the genes flagged as present (flag >= 0).
adata = adata[:, adata.var["id_in_vocab"] >= 0]

In [10]:
# Name of the perturbed gene to analyse (the `TF` prefix suggests a
# transcription factor).
TF_name = 'BHLHE40'
# Keep the cells perturbed for that gene together with the control cells, and
# materialise the selection as an independent object.
is_selected = adata.obs["condition"].isin(['{}+ctrl'.format(TF_name), 'ctrl'])
adata = adata[is_selected].copy()
# Display the conditions that remain after filtering.
np.unique(adata.obs["condition"])

array(['BHLHE40+ctrl', 'ctrl'], dtype=object)

In [11]:
import numpy as np
import torch
import scanpy as sc
from scipy.sparse import issparse

def discretize_expression(input_values, num_bins=5):
    """Bin each row's expression values into integer levels.

    Every row is scaled independently with the minimum and maximum of its
    valid entries. Entries equal to ``-1.0`` are treated as missing and keep
    the value 0; valid entries receive a level between 1 and ``num_bins``.

    Args:
        input_values: Tensor whose first dimension is iterated row by row.
        num_bins: Highest level that can be assigned to a valid entry.

    Returns:
        A ``torch.long`` tensor with the same shape as ``input_values``.
    """
    binned = torch.zeros_like(input_values, dtype=torch.long)
    top_level = num_bins - 1

    for row_idx in range(input_values.shape[0]):
        row = input_values[row_idx]
        observed = row != -1.0
        if not observed.any():
            # Nothing to bin: the row stays all zeros.
            continue

        values = row[observed]
        lo = values.min()
        hi = values.max()
        # The small constant keeps the division defined when hi == lo.
        scaled = (values - lo) / (hi - lo + 1e-8)
        levels = torch.floor(scaled * top_level).long()
        # Shift by one so that 0 stays reserved for missing entries.
        levels = torch.clamp(levels, 0, top_level) + 1
        binned[row_idx][observed] = levels

    return binned


def process_anndata(adata, num_bins=5):
    """Attach normalised and binned expression layers to ``adata``.

    Two layers are written:

    * ``"X_normed"`` (float32): per-row min-max scaling computed over the
      non-zero entries only; zero entries stay 0.
    * ``"X_binned"`` (int32): output of ``discretize_expression`` applied to
      the raw values of ``adata.X``, not to the normalised ones. That helper
      skips only entries equal to -1.0, so zeros are binned as valid values.

    Args:
        adata: Object exposing ``X`` (dense or SciPy sparse) and ``layers``.
        num_bins: Forwarded to ``discretize_expression``.

    Returns:
        The same ``adata`` object, modified in place.
    """
    # Densify the expression matrix when it is stored sparsely.
    matrix = adata.X
    dense = matrix.toarray() if issparse(matrix) else matrix
    expr = torch.tensor(dense, dtype=torch.float32)

    # Row-wise min-max scaling restricted to the non-zero entries.
    normed = torch.zeros_like(expr)
    for cell_idx in range(expr.shape[0]):
        cell = expr[cell_idx]
        expressed = cell != 0
        if not expressed.any():
            continue
        nonzero_vals = cell[expressed]
        lo = nonzero_vals.min()
        hi = nonzero_vals.max()
        normed[cell_idx, expressed] = (nonzero_vals - lo) / (hi - lo + 1e-8)
    adata.layers["X_normed"] = normed.numpy().astype(np.float32)

    # Binning works on the raw tensor.
    binned = discretize_expression(expr, num_bins=num_bins)
    adata.layers["X_binned"] = binned.numpy().astype(np.int32)

    return adata

In [12]:
# Add the "X_normed" and "X_binned" layers to adata (num_bins left at its default of 5).
adata=process_anndata(adata)

In [None]:
# Layer written by process_anndata that holds the binned expression values.
input_layer_key = "X_binned"
# Dense array of the selected layer. `.toarray()` is used instead of the `.A`
# shorthand, which is deprecated in SciPy and absent from recent releases; it
# also matches how process_anndata densifies `adata.X`.
all_counts = (
    adata.layers[input_layer_key].toarray()
    if issparse(adata.layers[input_layer_key])
    else adata.layers[input_layer_key]
)
genes = adata.var["gene_name"].tolist()
# The original called `vocab(genes)`, but no `vocab` object is defined in the
# visible notebook, so the cell raised NameError on a fresh kernel. The
# gene-name -> id table built from the gene-map CSV (gene2idx) is used instead;
# genes missing from it were already filtered out of `adata` earlier.
# NOTE(review): confirm that gene2idx ids are the indices the model expects.
gene_ids = np.array([gene2idx[gene] for gene in genes], dtype=int)