In [1]:
import torch
import anndata
import numpy as np
import random
from src.sctfbridge.model import scTFBridge
import scanpy as sc

def set_seed(seed):
    """Seed every RNG used by the notebook and fix torch numeric settings.

    Seeds Python's ``random`` module, NumPy's global RNG and torch (CPU plus
    every CUDA device). It also exports ``PYTHONHASHSEED``, turns off cuDNN
    autotuning and selects 'high' float32 matmul precision.

    Caveats:
      * ``PYTHONHASHSEED`` only affects interpreters started after this call,
        such as subprocesses. String hashing in the running kernel, and so the
        iteration order of a ``set`` of strings, is unchanged.
      * ``torch.backends.cudnn.deterministic`` is not modified here, so cuDNN
        may still choose non-deterministic kernels.
      * ``set_float32_matmul_precision('high')`` lets float32 matmuls use
        reduced-precision internals, trading bitwise reproducibility for speed.
    """
    import os

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.backends.cudnn.benchmark = False
    torch.set_float32_matmul_precision('high')


SEED = 3407  # global seed for random / numpy / torch
set_seed(SEED)



In [2]:
from pathlib import Path

dataset_name = 'BMMC'
cell_key = 'cell_type'
batch_key = 'batch'

# All filtered inputs for one dataset live under DATA_ROOT / dataset_name.
# Change the root or the dataset name here instead of editing every path.
DATA_ROOT = Path('/data2/wfa/project/single_cell_multimodal/data/filter_data')
data_dir = DATA_ROOT / dataset_name

gex_adata = anndata.read_h5ad(data_dir / 'RNA_filter.h5ad')
atac_adata = anndata.read_h5ad(data_dir / 'ATAC_filter.h5ad')
TF_adata = anndata.read_h5ad(data_dir / 'TF_filter.h5ad')
# Kept as a plain string, matching the type the original f-string produced.
mask_path = str(data_dir / 'TF_binding' / 'TF_binding.txt')

In [3]:
# The checkpoint name follows dataset_name, so switching datasets cannot
# silently load the BMMC model.
new_model = scTFBridge.load(f'sctfbridge_{dataset_name}_model', device=torch.device('cuda:7'))


🚀 Loading model from sctfbridge_BMMC_model...
  - Using device: cuda:7
✅ Model loaded and ready for inference.


In [None]:
# Embed the three modalities in one call. The result is indexed by key in the
# next cell to pull out the RNA-private representation.
output = new_model.get_embeddings([gex_adata, atac_adata, TF_adata])

In [None]:
# Store the RNA-private latent per cell so later cells can use it via
# use_rep='rna_private' and obsm['rna_private'].
gex_adata.obsm['rna_private'] = output['RNA_private_representations']


In [None]:
import pandas as pd

# One float column per batch label found in the data (columns 'batch_<label>').
# NOTE(review): the decoder presumably expects the same batch columns, in the
# same order, as during training. Confirm this for the loaded model.
batches_info = gex_adata.obs[batch_key].values.tolist()
one_hot_encoded_batches = pd.get_dummies(
    batches_info,
    prefix='batch',
    dtype=float,
)

In [None]:
from torch.utils.data import Dataset
class scOmicsDataset(Dataset):
    """Map-style dataset pairing per-cell feature rows with batch-covariate rows.

    Item ``idx`` is ``(features[idx, :], batch_info[idx, :])``. Both are
    float32 tensors copied out of the inputs, so later edits to the source
    arrays do not leak into the dataset.

    Args:
        input_data: 2-D array-like (cells x features), e.g. an ``obsm`` matrix.
        batch_info: one row per cell, e.g. a one-hot ``pd.DataFrame``. Any
            object accepted by ``np.asarray`` works.

    Raises:
        ValueError: if ``input_data`` and ``batch_info`` have different row counts.
    """

    def __init__(self,
                 input_data,
                 batch_info: pd.DataFrame):
        self.input_tensor = torch.tensor(np.asarray(input_data), dtype=torch.float32)
        self.batch_info = torch.tensor(np.asarray(batch_info), dtype=torch.float32)
        # Rows are paired by position; a length mismatch would silently
        # mis-assign batches or fail later inside the DataLoader.
        if self.input_tensor.shape[0] != self.batch_info.shape[0]:
            raise ValueError(
                f"input_data has {self.input_tensor.shape[0]} rows but "
                f"batch_info has {self.batch_info.shape[0]}"
            )

    def __len__(self):
        return self.input_tensor.shape[0]

    def __getitem__(self, idx):
        return self.input_tensor[idx, :], self.batch_info[idx, :]

In [None]:
# Rows stay in gex_adata order (no shuffling), so decoded outputs can be
# written back to gex_adata cell-for-cell.
rna_private_dataset = scOmicsDataset(
    gex_adata.obsm['rna_private'],
    one_hot_encoded_batches,
)
data_loader = torch.utils.data.DataLoader(
    rna_private_dataset,
    batch_size=128,
    shuffle=False,
)

In [None]:
# Decode RNA expression from the RNA-private latent, one batch at a time.
# Each decoded batch moves to the CPU right away, so GPU memory is bounded
# by one batch instead of growing with the number of cells.
device = new_model.device
recon_batches = []
with torch.no_grad():
    for rna_private, batch_data in data_loader:
        rna_recon = new_model.rna_generation_from_rna_private(
            rna_private.to(device),
            batch_data.to(device),
        )
        recon_batches.append(rna_recon.cpu())
all_recon = torch.cat(recon_batches, dim=0).numpy()

In [None]:
# Keep the decoded expression next to the measured data so later cells can
# compare it with gex_adata.X.
gex_adata.layers['rna_recon_from_rna_private'] = all_recon

In [None]:
# kNN graph and UMAP built on the RNA-private embedding, coloured by cell type.
sc.pp.neighbors(gex_adata, use_rep='rna_private')
sc.tl.umap(gex_adata)
sc.pl.umap(gex_adata, color=[cell_key])

In [None]:
# Presumably read by predict_cross_omics to find the batch column. TODO confirm.
new_model.batch_key = batch_key

# The prediction is made on atac_adata but stored row-for-row in gex_adata,
# so both objects must list the same cells in the same order.
assert atac_adata.obs_names.equals(gex_adata.obs_names), \
    "atac_adata and gex_adata do not share the same cells in the same order"

gex_adata.layers['rna_recon'] = new_model.predict_cross_omics(
    atac_adata,
    'RNA',
    device=torch.device('cuda:6')
)

In [None]:
# Wrap the decoded matrix in its own AnnData, sharing gex_adata's cell and
# gene annotations, so scanpy plotting functions can be used on it directly.
recon_adata = sc.AnnData(
    X=gex_adata.layers['rna_recon_from_rna_private'],
    var=gex_adata.var.copy(),
    obs=gex_adata.obs.copy(),
)

In [None]:
N_TOP_MARKERS = 3  # marker genes shown per cell type

sc.tl.rank_genes_groups(gex_adata, groupby=cell_key, method='t-test')
# Top-N marker genes per cell type, ranked on the measured expression. The
# same dict drives both heatmaps so the panels are directly comparable.
marker_genes_dict = {
    cell_type: gex_adata.uns['rank_genes_groups']['names'][cell_type][:N_TOP_MARKERS]
    for cell_type in gex_adata.obs[cell_key].cat.categories
}

# --- Heatmaps: measured vs. reconstructed expression of the same markers ---
# standard_scale='var' min-max scales each gene to [0, 1] across groups
# (subtract the minimum, divide by the maximum). It is not a z-score.
for title, adata_to_plot in (
    ("--- 原始数据的Marker基因表达 (Heatmap) ---", gex_adata),
    ("\n--- 重构数据的Marker基因表达 (Heatmap) ---", recon_adata),
):
    print(title)
    sc.pl.matrixplot(adata_to_plot,
                     marker_genes_dict,
                     groupby=cell_key,
                     standard_scale='var',
                     cmap='bwr',
                     show=True)

In [None]:
import itertools

# Flatten the per-cell-type marker lists, then de-duplicate while keeping
# first-seen order. With list(set(...)) the gene order would vary from run to
# run: string hashing is randomized per interpreter, and setting
# PYTHONHASHSEED inside a running kernel does not change that. The varying
# order would also carry into the file written later.
all_marker_genes = list(itertools.chain.from_iterable(marker_genes_dict.values()))
unique_marker_genes = list(dict.fromkeys(all_marker_genes))

# NOTE(review): this rebinds gex_adata to the marker-gene subset. On a
# top-to-bottom run, every later cell that reads gex_adata sees only these
# genes: the PCC summaries, the h5ad write and explain_DisLatent. The
# recorded explain_DisLatent output (3000 features) came from a session in
# which this cell had not been run.
gex_adata = gex_adata[:, unique_marker_genes].copy()

# Print the filtered object to confirm the reduced gene count.
print("\n过滤后的gex_adata信息：")
print(gex_adata)

In [None]:
# Inspect which layers gex_adata currently holds.
gex_adata.layers

In [None]:
# Inspect the ranked marker-gene names stored by sc.tl.rank_genes_groups.
gex_adata.uns['rank_genes_groups']['names']

In [None]:
# The output path follows dataset_name, matching the input paths.
# NOTE(review): on a top-to-bottom run, gex_adata has already been narrowed to
# the marker genes by the subsetting cell. Confirm that is what should be saved.
gex_adata.write(f'/data2/wfa/project/single_cell_multimodal/data/filter_data/{dataset_name}/RNA_recon.h5ad')


In [None]:
def calculate_pcc_torch(x, y):
    """
    Calculates the Pearson Correlation Coefficient between two PyTorch tensors.

    Singleton dimensions are squeezed first, so ``(n, 1)`` and ``(n,)`` inputs
    are compatible. Any remaining dimensions are not reduced separately: the
    statistic is pooled over all elements. For a cells x genes matrix this is
    one global correlation, not a per-cell or per-gene one (see
    ``calculate_pcc_per_cell`` for the per-row version).

    Integer or boolean inputs are promoted to float32 before averaging.

    Args:
        x (torch.Tensor): The first tensor.
        y (torch.Tensor): The second tensor.

    Returns:
        torch.Tensor: 0-d tensor holding the Pearson Correlation Coefficient,
        on the inputs' device and in their (promoted) floating dtype. It is 0
        when either input has zero variance.

    Raises:
        ValueError: if the squeezed shapes differ.
    """
    if x.ndim > 1:
        x = x.squeeze()
    if y.ndim > 1:
        y = y.squeeze()

    if x.shape != y.shape:
        raise ValueError("Input tensors must have the same shape")

    # torch.mean is undefined for integer tensors.
    if not x.is_floating_point():
        x = x.float()
    if not y.is_floating_point():
        y = y.float()

    vx = x - torch.mean(x)
    vy = y - torch.mean(y)

    numerator = torch.sum(vx * vy)
    denominator = torch.sqrt(torch.sum(vx ** 2)) * torch.sqrt(torch.sum(vy ** 2))

    if denominator == 0:
        # A constant input has undefined correlation. By convention return 0,
        # keeping the dtype and device of the regular result so callers can
        # mix the two paths freely.
        return torch.zeros((), dtype=numerator.dtype, device=numerator.device)

    return numerator / denominator

def calculate_r_squared_torch(y_true, y_pred, force_finite=False):
    """Column-wise coefficient of determination, reduced over dim 0.

    For each column ``j``: ``R^2_j = 1 - RSS_j / TSS_j``, where
    ``TSS_j = sum_i (y_true[i, j] - mean_i y_true[:, j])^2`` and
    ``RSS_j = sum_i (y_true[i, j] - y_pred[i, j])^2``.

    Args:
        y_true (torch.Tensor): observed values; rows are samples.
        y_pred (torch.Tensor): predictions, broadcastable to ``y_true``.
        force_finite (bool): a constant column (``TSS_j == 0``) makes the raw
            formula return NaN (perfect prediction) or -inf (otherwise), and
            either one poisons a later mean. If True, such columns get 1.0
            when ``RSS_j == 0`` and 0.0 otherwise, following scikit-learn's
            ``r2_score(force_finite=True)`` convention. The default False
            keeps the raw IEEE result.

    Returns:
        torch.Tensor: R^2 per column (0-d for 1-D input).
    """
    # Total sum of squares (TSS) per column
    tss = torch.sum((y_true - torch.mean(y_true, dim=0)) ** 2, dim=0)
    # Residual sum of squares (RSS) per column
    rss = torch.sum((y_true - y_pred) ** 2, dim=0)
    r_squared = 1 - (rss / tss)
    if force_finite:
        fallback = (rss == 0).to(r_squared.dtype)
        r_squared = torch.where(tss == 0, fallback, r_squared)
    return r_squared

def calculate_pcc_per_cell(X, Y, eps=1e-8):
    """
    Calculates the Pearson Correlation Coefficient for each corresponding row
    (cell) of two matrices, correlating along the last dimension (genes).

    Works for a single 1-D vector (returns a 0-d tensor), for 2-D cells x genes
    matrices (returns one value per cell), and for any leading batch dims.
    Integer or boolean inputs are promoted to float32.

    Args:
        X (torch.Tensor): The first matrix (cells x genes).
        Y (torch.Tensor): The second matrix (cells x genes).
        eps (float): added to the denominator so that constant rows give a PCC
            of 0 instead of NaN.

    Returns:
        torch.Tensor: PCC per row, with the last dimension reduced away.

    Raises:
        ValueError: if ``X`` and ``Y`` differ in shape.
    """
    if X.shape != Y.shape:
        raise ValueError("Input matrices must have the same shape")

    # torch.mean is undefined for integer tensors.
    if not X.is_floating_point():
        X = X.float()
    if not Y.is_floating_point():
        Y = Y.float()

    # Centre every row on its own mean; keepdim lets the mean broadcast back.
    vx = X - X.mean(dim=-1, keepdim=True)
    vy = Y - Y.mean(dim=-1, keepdim=True)

    numerator = (vx * vy).sum(dim=-1)
    denominator = vx.pow(2).sum(dim=-1).sqrt() * vy.pow(2).sum(dim=-1).sqrt()

    return numerator / (denominator + eps)

import numpy as np
from scipy.spatial.distance import pdist, squareform
from scipy.stats import pearsonr

def mantel_test(dist_x, dist_y, perms=1000, rng=None):
    """
    Performs a two-sided Mantel test between two distance matrices.

    The statistic is the Pearson correlation between the strict upper
    triangles (diagonal excluded) of ``dist_x`` and ``dist_y``. The null
    distribution is built by relabelling the objects of ``dist_y``: one random
    permutation is applied to both its rows and its columns, and the permuted
    upper triangle is correlated with that of ``dist_x`` again.

    Args:
        dist_x: square ``(n, n)`` distance matrix, e.g. ``squareform(pdist(data))``.
        dist_y: square distance matrix with the same shape as ``dist_x``.
        perms: number of permutations used for the p-value.
        rng: optional object with a ``permutation`` method, such as
            ``np.random.Generator`` or ``np.random.RandomState``. Defaults to
            NumPy's global RNG, so ``np.random.seed`` makes runs reproducible.

    Returns:
        tuple ``(observed_corr, p_value)`` with
        ``p_value = (#{|r_perm| >= |r_obs|} + 1) / (perms + 1)``.

    Raises:
        ValueError: if the inputs are not square 2-D arrays of the same shape.
    """
    dist_x = np.asarray(dist_x)
    dist_y = np.asarray(dist_y)
    # A cells x genes data matrix is not a distance matrix; reject it instead
    # of silently correlating its upper triangle.
    if dist_x.ndim != 2 or dist_x.shape[0] != dist_x.shape[1]:
        raise ValueError("dist_x must be a square (n, n) distance matrix")
    if dist_y.shape != dist_x.shape:
        raise ValueError("dist_y must have the same (n, n) shape as dist_x")
    if rng is None:
        rng = np.random

    n = dist_x.shape[0]
    # Flatten the upper triangles (diagonal excluded).
    rows, cols = np.triu_indices(n, k=1)
    v_x = dist_x[rows, cols]
    v_y = dist_y[rows, cols]

    # Observed correlation
    observed_corr, _ = pearsonr(v_x, v_y)

    # Permute object labels, not individual distances. Distances that share an
    # object are dependent, and shuffling the flattened vector ignores that,
    # which makes the p-value anti-conservative.
    perm_corrs = np.empty(perms)
    for i in range(perms):
        order = rng.permutation(n)
        # Equals the upper triangle of dist_y[order][:, order], without
        # building the permuted matrix.
        perm_corrs[i] = pearsonr(v_x, dist_y[order[rows], order[cols]])[0]

    p_value = (np.sum(np.abs(perm_corrs) >= np.abs(observed_corr)) + 1) / (perms + 1)

    return observed_corr, p_value

In [None]:
import pandas as pd
import torch
from sklearn.preprocessing import StandardScaler # <-- 导入 StandardScaler
from sklearn.preprocessing import MinMaxScaler # <-- 导入 MinMaxScaler

from scipy import sparse

# Agreement between measured and reconstructed RNA for each cell type, scored
# with a Mantel test on cell-by-cell Euclidean distance matrices. mantel_test
# expects square (n, n) distance matrices, so they are built explicitly with
# pdist/squareform instead of passing the cells x genes data matrices.
MAX_CELLS_FOR_MANTEL = 1000  # distance matrices and each permutation cost O(n^2)
subsample_rng = np.random.default_rng(3407)

results_list = []

# One StandardScaler instance. Each fit_transform refits it from scratch, so
# every matrix is z-scored per gene using only its own statistics.
scaler = StandardScaler()

for cell_type in gex_adata.obs[cell_key].unique():
    # Cells of this type only
    gex_adata_cell_type = gex_adata[gex_adata.obs[cell_key] == cell_type].copy()
    cell_count = gex_adata_cell_type.shape[0]

    actual_rna = gex_adata_cell_type.X
    actual_rna = actual_rna.toarray() if sparse.issparse(actual_rna) else np.asarray(actual_rna)
    reconstructed_rna = np.asarray(gex_adata_cell_type.layers['rna_recon_from_rna_private'])

    actual_rna = scaler.fit_transform(actual_rna)
    reconstructed_rna = scaler.fit_transform(reconstructed_rna)

    if cell_count < 3:
        # Fewer than 3 pairwise distances: the correlation is undefined.
        mantel_r, mantel_p = float('nan'), float('nan')
    else:
        if cell_count > MAX_CELLS_FOR_MANTEL:
            # Reproducible subsample keeps memory and runtime bounded.
            keep = subsample_rng.choice(cell_count, size=MAX_CELLS_FOR_MANTEL, replace=False)
            actual_rna = actual_rna[keep]
            reconstructed_rna = reconstructed_rna[keep]
        dist_actual = squareform(pdist(actual_rna, metric='euclidean'))
        dist_recon = squareform(pdist(reconstructed_rna, metric='euclidean'))
        mantel_r, mantel_p = mantel_test(dist_actual, dist_recon)

    print(f'cell type: {cell_type}, mantel r: {mantel_r:.4f}, p: {mantel_p:.4g}, cell num: {cell_count}')

    # 'pcc' keeps the original column name; it holds the Mantel correlation r.
    results_list.append({
        'cell_type': cell_type,
        'pcc': mantel_r,
        'mantel_p': mantel_p,
        'cell_num': cell_count,
    })

# Build the DataFrame once from the collected records.
results_df = pd.DataFrame(results_list)

print("\n--- 所有结果已汇总到 DataFrame ---")
print(results_df)

In [None]:
from scipy import sparse
from sklearn.preprocessing import StandardScaler

# Pooled PCC over every (cell, gene) entry, after z-scoring each gene
# separately in the measured and in the reconstructed matrix. A fresh scaler
# lets this cell run without the per-cell-type loop having run first.
scaler = StandardScaler()

actual_rna = gex_adata.X
actual_rna = actual_rna.toarray() if sparse.issparse(actual_rna) else np.asarray(actual_rna)
reconstructed_rna = np.asarray(gex_adata.layers['rna_recon_from_rna_private'])

actual_rna = scaler.fit_transform(actual_rna)
reconstructed_rna = scaler.fit_transform(reconstructed_rna)

# For 2-D input, calculate_pcc_torch pools over all elements and returns a
# 0-d tensor, so there is nothing left to average; read the scalar directly.
all_pcc = calculate_pcc_torch(torch.from_numpy(actual_rna), torch.from_numpy(reconstructed_rna))
print(all_pcc.item())

In [None]:
# Mean per-cell Pearson correlation between measured and reconstructed
# expression (unscaled values), over cells whose type is not listed below.
exclude_cell_types = ['Erythroblast', 'Normoblast']

is_kept = ~gex_adata.obs[cell_key].isin(exclude_cell_types)
gex_adata_others = gex_adata[is_kept].copy()

actual_rna, reconstructed_rna = (
    gex_adata_others.X.toarray(),
    gex_adata_others.layers['rna_recon_from_rna_private'],
)
all_pcc = calculate_pcc_per_cell(*(torch.from_numpy(m) for m in (actual_rna, reconstructed_rna)))
print(all_pcc.mean(dim=0).item())

In [4]:
# NOTE(review): scTFBridge is imported from `src.sctfbridge.model` in the first
# cell, but this line imports from `sctfbridge.model`. If both resolve (e.g. a
# local checkout plus an installed package), they may be different module
# objects or versions. Confirm they are the same code, or import both from
# one path.
from sctfbridge.model import explain_DisLatent

In [None]:
# Private-vs-shared latent contribution of each RNA feature, Erythroblast cells.
output = explain_DisLatent(
    new_model,
    [gex_adata, atac_adata, TF_adata],
    'RNA',
    'Erythroblast',
    cell_key,
    batch_key,
    torch.device('cuda:6'),
)

In [None]:
# Display the explain_DisLatent result for Erythroblast cells.
output

In [None]:
# Private-vs-shared latent contribution of each RNA feature, Normoblast cells.
Normoblast_output = explain_DisLatent(
    new_model,
    [gex_adata, atac_adata, TF_adata],
    'RNA',
    'Normoblast',
    cell_key,
    batch_key,
    torch.device('cuda:6'),
)

In [None]:
# Display the explain_DisLatent result for Normoblast cells.
Normoblast_output

In [5]:
# Private-vs-shared latent contribution of each RNA feature, CD4+ T naive cells.
CD4_T_naive = explain_DisLatent(
    new_model,
    [gex_adata, atac_adata, TF_adata],
    'RNA',
    'CD4+ T naive',
    cell_key,
    batch_key,
    torch.device('cuda:6'),
)

🔍 Starting latent space contribution analysis for 'RNA' omics and cell type: 'CD4+ T naive'...
  - Using device: cuda:6
Calculating RNA private-share contribution value for: CD4+ T naive
  - Preparing data loaders...
  - Initializing explanation model and background samples...


  0%|          | 1/880 [00:00<04:57,  2.96batch/s]
  1%|          | 9/880 [00:00<00:19, 44.94batch/s]


  - Calculating private vs. shared contributions for 3000 features. This may take a while...
✅ Latent space contribution analysis complete for 'RNA' omics.


In [6]:
# Display the per-feature private/shared contribution table for CD4+ T naive cells.
CD4_T_naive

Unnamed: 0,feature_name,private_embedding_contribute,share_embedding_contribute
0,C1orf159,0.734662,0.265338
1,SLC35E2B,0.661319,0.338681
2,SLC35E2A,0.700663,0.299337
3,CEP104,0.621083,0.378917
4,C1orf174,0.740358,0.259642
...,...,...,...
2995,MT-ND4,0.613493,0.386507
2996,MT-ND5,0.690819,0.309181
2997,MT-ND6,0.633150,0.366850
2998,MT-CYB,0.741963,0.258037
