In [1]:
import os
import random

import numpy as np
import torch

from sctfbridge.model import scTFBridge

def set_seed(seed: int, deterministic: bool = False, matmul_precision: str = 'high') -> None:
    """Seed every RNG used in this notebook so runs are reproducible.

    Args:
        seed: Seed applied to Python's ``random``, NumPy's legacy global RNG,
            and torch (CPU generator and all CUDA devices).
        deterministic: If True, additionally force cuDNN to use deterministic
            kernels (``torch.backends.cudnn.deterministic = True``). Defaults
            to False, which leaves that flag untouched as before; enabling it
            may slow GPU execution.
        matmul_precision: Forwarded to ``torch.set_float32_matmul_precision``.
            ``'high'`` (the previous hard-coded value) allows float32 matmuls
            to use reduced-precision internal formats such as TF32, so results
            can differ slightly from strict float32; pass ``'highest'`` for
            full float32 precision.
    """
    random.seed(seed)
    np.random.seed(seed)
    # torch.manual_seed seeds the CPU generator and every CUDA device;
    # manual_seed_all makes the CUDA seeding explicit (it also subsumes
    # torch.cuda.manual_seed, which only seeds the current device).
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Only affects Python subprocesses started after this point; the hash seed
    # of the already-running interpreter is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    # With benchmark=True cuDNN autotunes and may pick different kernels from
    # run to run, which breaks reproducibility.
    torch.backends.cudnn.benchmark = False
    if deterministic:
        torch.backends.cudnn.deterministic = True
    torch.set_float32_matmul_precision(matmul_precision)


SEED = 3407
set_seed(SEED)



In [2]:
import anndata

# Dataset to analyse and the keys handed to explain_TF2TG for grouping cells.
dataset_name = 'human_PBMC'
cell_key = 'cell_type'
batch_key = ''  # NOTE(review): empty string presumably means "no batch column" — confirm in explain_TF2TG

# All pre-filtered inputs for the dataset live under one directory.
data_dir = f'filter_data/{dataset_name}'
gex_data = anndata.read_h5ad(f'{data_dir}/RNA_filter.h5ad')
atac_adata = anndata.read_h5ad(f'{data_dir}/ATAC_filter.h5ad')
TF_adata = anndata.read_h5ad(f'{data_dir}/TF_filter.h5ad')
# TF binding mask file; not consumed by any of the cells shown here.
mask_path = f'{data_dir}/TF_binding/TF_binding.txt'

In [3]:
# Restore the trained scTFBridge model; the GPU index is hard-coded to cuda:7.
model_dir = 'sctfbridge_model'
new_model = scTFBridge.load(model_dir, device=torch.device('cuda:7'))


🚀 Loading model from sctfbridge_model...
  - Using device: cuda:7
✅ Model loaded and ready for inference.


In [4]:
# Import from the same `sctfbridge.model` module that provided scTFBridge in
# In [1]. `src.sctfbridge.model` is a different module name to Python's import
# system, so mixing the two paths can load two separate copies (possibly of
# different versions) of the package into one session.
from sctfbridge.model import explain_TF2TG

In [5]:
# TF -> target-gene attribution scores (mean absolute SHAP, Genes x TFs)
# restricted to one cell type.
target_cell_type = 'CD14 Mono'
modalities = [gex_data, atac_adata, TF_adata]  # RNA, ATAC, TF inputs, in that order
output = explain_TF2TG(
    new_model,
    modalities,
    cell_type=target_cell_type,
    cell_key=cell_key,
    batch_key=batch_key,
    device=torch.device('cuda:7'),
)

🧬 Starting TF-to-Target Gene (TF2TG) explanation for cell type: 'CD14 Mono'...
  - Using device: cuda:7
  - Filtering data for cell type: 'CD14 Mono'
  - Preparing data loaders...
  - Initializing explanation model and background samples...


  4%|▎         | 19/511 [00:00<00:08, 61.03batch/s]


  - Calculating attributions for 3000 target genes. This may take a while...
  - Aggregating results...
✅ TF2TG explanation complete for 'CD14 Mono'. Returning mean absolute SHAP values (Genes x TFs).


In [6]:
# The next cell labels rows with gene names and columns with TF names, so check
# that orientation (Genes x TFs) up front with a readable failure message.
assert output.shape == (gex_data.n_vars, TF_adata.n_vars), (
    f'expected (n_genes, n_TFs) = {(gex_data.n_vars, TF_adata.n_vars)}, got {output.shape}'
)
output.shape

(3000, 134)

In [10]:
import pandas as pd

# Attach labels to the score matrix: one row per target gene, one column per TF.
trans_df = pd.DataFrame(
    data=output,
    index=gex_data.var_names,
    columns=TF_adata.var_names,
)

In [12]:
# Genes x TFs score table (rows = gex_data.var_names, columns = TF_adata.var_names);
# the rendered view is truncated to the first/last rows and columns.
trans_df

genes,BACH2,AIRE,SOX5,CTCF,IRF8,NFIB,GLIS3,HMX2,NR3C2,CTCFL,...,HOXB7,NFATC4,HNF1A,LHX3,POU6F2,MAF,GFI1B,RUNX1,SIX6,HEY1
genes,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
ISG15,0.002680,0.000901,0.006917,0.008976,0.001245,0.002525,0.002029,0.006060,0.003846,0.001055,...,0.004735,0.005145,0.004021,0.002092,0.002343,0.004000,0.002754,0.002946,0.003722,0.003841
SKI,0.000132,0.000365,0.001249,0.000673,0.000497,0.000344,0.000437,0.000532,0.000216,0.000284,...,0.000395,0.001196,0.000264,0.000831,0.000545,0.000908,0.000441,0.000483,0.000548,0.001366
CEP104,0.000312,0.000096,0.000418,0.000604,0.000386,0.000385,0.000111,0.000391,0.000474,0.000187,...,0.000137,0.000219,0.000232,0.000543,0.000229,0.000518,0.000264,0.000446,0.000183,0.001061
NOL9,0.000115,0.000105,0.000231,0.000401,0.000259,0.000503,0.000090,0.000405,0.000228,0.000123,...,0.000053,0.000132,0.000097,0.000191,0.000358,0.000484,0.000119,0.000513,0.000073,0.000827
DNAJC11,0.000472,0.000224,0.000303,0.000479,0.000289,0.000503,0.000065,0.000396,0.000493,0.000270,...,0.000136,0.000314,0.000339,0.000575,0.000375,0.000672,0.000135,0.000375,0.000257,0.000207
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
MT-ND2,0.035183,0.013598,0.033657,0.029002,0.031747,0.039710,0.015021,0.018838,0.034966,0.018921,...,0.030565,0.019697,0.022923,0.008681,0.020817,0.026375,0.021538,0.050100,0.013455,0.029946
MT-CO1,0.046711,0.017919,0.036469,0.047829,0.039823,0.056997,0.025881,0.021401,0.044552,0.023251,...,0.047945,0.021213,0.029412,0.017130,0.028057,0.032440,0.037039,0.062605,0.018765,0.042281
MT-CO2,0.050151,0.021082,0.056372,0.046511,0.054511,0.060222,0.019511,0.022033,0.059353,0.028175,...,0.042913,0.033873,0.030746,0.018962,0.033484,0.039065,0.033358,0.075183,0.022234,0.052585
MT-ND4,0.040023,0.015869,0.042825,0.041761,0.038980,0.046259,0.015783,0.021200,0.039497,0.022093,...,0.033787,0.024328,0.023906,0.012435,0.023494,0.028112,0.024771,0.056260,0.016277,0.037600
