In [3]:
import torch
import numpy as np
import random
from sctfbridge.model import scTFBridge

def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Besides seeding, this exports ``PYTHONHASHSEED``, switches off the cuDNN
    autotuner (``cudnn.benchmark``) and sets the float32 matmul precision to
    ``'high'``. cuDNN deterministic mode is not touched.

    Args:
        seed: Value handed to every seeding call and, as a string, to
            ``PYTHONHASHSEED``.
    """
    import os

    # Same generators, same order: Python, NumPy, torch CPU, current CUDA
    # device, all CUDA devices.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    # NOTE(review): the interpreter reads PYTHONHASHSEED at start-up, so this
    # presumably only matters for processes launched afterwards — confirm.
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Deterministic cuDNN kernels stay disabled, as in the original:
    # torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.set_float32_matmul_precision('high')


# Seed every RNG once, up front, before any data is loaded or a model is built.
set_seed(3407)

In [4]:
from sctfbridge.utils.data_processing import multiomics_processing, preload_TF_binding
import anndata


# Dataset identifier and the (relative) directory used for the filtered outputs.
dataset = 'human_PBMC'
output_path = f'filter_data/{dataset}/'

# The RNA and ATAC input files live in the same per-dataset directory.
input_dir = f'/data2/wfa/project/single_cell_multimodal/data/filter_data/{dataset}'
rna_adata = anndata.read_h5ad(f'{input_dir}/10x-Multiome-Pbmc10k-RNA.h5ad')
atac_data = anndata.read_h5ad(f'{input_dir}/10x-Multiome-Pbmc10k-ATAC.h5ad')

# Load the TF list: every line of the file is treated as one TF name, with its
# trailing newline removed. The `with` block closes the file handle (the
# previous `open(...).readlines()` left it open), and the path keeps its own
# name instead of being overwritten by the list it is used to load.
TF_name_path = '/data2/wfa/project/single_cell_multimodal/data/GRN/data_bulk/TFName.txt'
with open(TF_name_path, 'r') as tf_file:
    TF_name = [line.rstrip('\n') for line in tf_file]


# Pre-process the RNA/ATAC pair against the TF list, targeting `output_path`.
# NOTE(review): the two trailing positional values are passed through as-is;
# 3000 presumably is the number of RNA features kept and 0.01 a filtering
# threshold — confirm against the multiomics_processing signature.
modalities = [rna_adata, atac_data]
multiomics_processing(modalities, output_path, TF_name, 3000, 0.01)

# Swap the in-memory RNA object for the filtered file found under the same
# directory as `output_path`.
rna_adata = anndata.read_h5ad(f'filter_data/{dataset}/RNA_filter.h5ad')

# Build the TF-binding files from the GRN reference directory.
GRNdir = '/data2/wfa/project/single_cell_multimodal/data/GRN/data_bulk/'
tf_binding_output_path = f'filter_data/{dataset}/TF_binding/'
preload_TF_binding(output_path, GRNdir, tf_binding_output_path)

Found 514 TFs from TF_list in the original RNA data.
Final TF set includes 128 HVG TFs and 6 essential TFs, for a total of 134 unique TFs.




Finish data pre-processing
filter_data/human_PBMC/TF_binding/Region.bed
                                0
0            chr1\t816881\t817647
1            chr1\t819912\t823500
2            chr1\t826612\t827979
3            chr1\t841243\t843059
4            chr1\t843966\t845044
...                           ...
79358  chrX\t155611306\t155613309
79359  chrX\t155632352\t155633090
79360  chrX\t155820122\t155820523
79361  chrX\t155841301\t155841724
79362  chrX\t155880572\t155882091

[79363 rows x 1 columns]
filter_data/human_PBMC/TF_binding/Region.bed
Index(['chr1:816881-817647', 'chr1:819912-823500', 'chr1:826612-827979',
       'chr1:841243-843059', 'chr1:843966-845044', 'chr1:857960-858997',
       'chr1:865576-866044', 'chr1:869449-870383', 'chr1:898356-899127',
       'chr1:903617-907386',
       ...
       'GL000219.1:125214-125653', 'KI270721.1:2341-2736',
       'KI270726.1:27352-27794', 'KI270726.1:41529-42186',
       'KI270713.1:4147-4624', 'KI270713.1:20444-22615',
       'KI27071

100%|██████████| 23/23 [07:24<00:00, 19.32s/it]


finish load TF_binding


In [5]:
import anndata
# Run configuration shared by the training and explanation cells.
dataset_name = 'human_PBMC'
cell_key = 'cell_type'
batch_key = ''

# Load the three filtered AnnData files (RNA, ATAC, TF) in that order.
filtered_dir = f'filter_data/{dataset_name}'
gex_data, atac_adata, TF_adata = (
    anndata.read_h5ad(f'{filtered_dir}/{stem}_filter.h5ad')
    for stem in ('RNA', 'ATAC', 'TF')
)
mask_path = f'{filtered_dir}/TF_binding/TF_binding.txt'

In [7]:
# Build the model from the filtered RNA / ATAC / TF AnnData objects and the
# TF-peak mask file at `mask_path`.
model = scTFBridge.from_anndata(
    gex_data,
    atac_adata,
    TF_adata,
    mask_path,
    # Use the value from the config cell rather than repeating the '' literal,
    # so changing `batch_key` there actually reaches the model.
    batch_key=batch_key,
    # NOTE(review): hard-coded GPU index; this needs a host with at least
    # seven CUDA devices.
    device=torch.device('cuda:6'),
)

# Train for up to 150 epochs (the logged run stopped early at epoch 85).
model.fit(
    [gex_data, atac_adata, TF_adata],
    epochs=150,
)

🔧 Initializing model from AnnData objects...
  - RNA input features: 3000
  - ATAC input features: 79386
  - Latent dimension (from TFs): 134
  - Loading TF-peak mask from: filter_data/human_PBMC/TF_binding/TF_binding.txt




✅ Model ready.
🔥 Starting model training for up to 150 epochs...
  - Using device: cuda:6
  - Preparing datasets and data loaders...




  - Splitting data: 7705 training samples, 1926 validation samples.


Epoch 85/150:  56%|█████▌    | 84/150 [06:27<05:04,  4.61s/epoch, train_loss=623428.7573, val_loss=510759.4165]



🛑 Early stopping triggered at epoch 85.
✅ Training finished successfully.


In [8]:
# Persist the trained model under 'sctfbridge_model'.
# NOTE(review): overwrite=True presumably replaces an earlier save at the same
# path — confirm before re-running if a previous model must be kept.
model.save('sctfbridge_model', overwrite=True)


💾 Saving model to sctfbridge_model...
✅ Model saved successfully.


In [None]:
# Reload the saved model from 'sctfbridge_model'. It is placed on cuda:7,
# whereas training ran on cuda:6.
new_model = scTFBridge.load('sctfbridge_model', device=torch.device('cuda:7'))


In [None]:
# Run the reloaded model over the same three AnnData inputs to obtain its
# embeddings. The structure of `output` is not shown here — TODO document it.
output = new_model.get_embeddings(
    [gex_data, atac_adata, TF_adata],
)

In [None]:
from sctfbridge.model import explain_TF2TG, explain_RE2TG, explain_DisLatent


In [None]:
# Run the RE2TG explanation for the 'CD14 Mono' cells, limited to the first 10
# entries of `gex_data.var.index`.
# The dataset name, cell-type column and batch key now come from the config
# cell instead of repeated literals, so they cannot drift apart from the
# `dataset_name` already used for `tf_binding_path`.
# NOTE: this rebinds `output`, which previously held the embeddings.
output = explain_RE2TG(
    new_model,
    [gex_data, atac_adata, TF_adata],
    dataset_name,
    use_gene_list=gex_data.var.index[:10],
    cell_type='CD14 Mono',
    cell_key=cell_key,
    batch_key=batch_key,
    device=torch.device('cuda:7'),
    tf_binding_path=f'filter_data/{dataset_name}/TF_binding/',
)

In [None]:
# Run the TF2TG explanation for the 'CD14 Mono' cells on the reloaded model.
# The cell-type column and batch key are taken from the config cell rather
# than repeated as literals, so a change there reaches this call too.
TF_output = explain_TF2TG(
    new_model,
    [gex_data, atac_adata, TF_adata],
    cell_type='CD14 Mono',
    cell_key=cell_key,
    batch_key=batch_key,
    device=torch.device('cuda:7'),
)