In [1]:
import warnings
warnings.filterwarnings('ignore')

# Imports

In [2]:
import torch
import scanpy as sc
import numpy as np
import anndata as ad
from tqdm import tqdm
import os

from scdisentangle.train.tools import get_trainer, set_seed

# Params

In [3]:
seed_nb = 42
yaml_path = '../configs/seurat_disentangle.yaml'
weights_path =  '../weights/MIG_BINNED_dis_latent_stack_donor_train'

counterfactual_dict = {
    'donor': 'P7'
}

condition_name = 'donor'
get_recs = False

n_latent_size = 16

# Set seed

In [4]:
set_seed(seed_nb)

ic| 'Setting seed to', seed: 42


# Get trainer

In [5]:
trainer = get_trainer(yaml_path, wandb_log=False)
trainer.load_weights(weights_path)

Global seed set to 0
ic| 'Setting seed to', seed: 42
ic| 'Creating cell mappings'
ic| 'Creating inputs'
ic| 'Creating inputs'


Wandb is off
Loading weights from ../weights/MIG_BINNED_dis_latent_stack_donor_train


In [6]:
counterfactual_dict_int = {
    k:trainer.dataset.reverse_label_mapping[k][v] for k,v in counterfactual_dict.items()
}

# Input adata to predict

In [7]:
adata = sc.read_h5ad(
    '../../../Datasets/preprocessed_datasets/seurat.h5ad'
)

# Predict to get latent

In [8]:
adata = trainer.predict(
    trainer.dataset.data.copy(), 
    counterfactual_dict=counterfactual_dict, 
    bs=256
)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 366/366 [00:02<00:00, 128.86it/s]


In [9]:
adata

AnnData object with n_obs × n_vars = 93542 × 5000
    obs: 'nCount_ADT', 'nFeature_ADT', 'nCount_RNA', 'nFeature_RNA', 'orig.ident', 'lane', 'donor', 'time', 'celltype.l1', 'celltype.l2', 'celltype.l3', 'Phase', 'nCount_SCT', 'nFeature_SCT', 'X_index', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'Protein log library size', 'Number proteins detected', 'RNA log library size', 'n_counts', 'sc_cell_ids', 'cell_ids', 'split_anndata', 'donor_org', 'donor_pred'
    var: 'mt', 'n_counts', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'hvg', 'log1p'
    obsm: 'protein_counts', 'dis_latent_stack', 'cat_latent_stack', 'cat_latent_stack_collapse', 'map_latent_summed_collapse', 'map_latent_summed'
    layers: 'counts'

In [None]:
from scipy.sparse import csr_matrix

adata.X = csr_matrix(adata.X)

In [None]:
adata.obs['celltype.l1'].value_counts()

# Optionally balance cell types

In [10]:
from utils import downsample_balance_by_cell_type
adata = downsample_balance_by_cell_type(adata, key='celltype.l1')
adata

AnnData object with n_obs × n_vars = 855 × 5000
    obs: 'nCount_ADT', 'nFeature_ADT', 'nCount_RNA', 'nFeature_RNA', 'orig.ident', 'lane', 'donor', 'time', 'celltype.l1', 'celltype.l2', 'celltype.l3', 'Phase', 'nCount_SCT', 'nFeature_SCT', 'X_index', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'Protein log library size', 'Number proteins detected', 'RNA log library size', 'n_counts', 'sc_cell_ids', 'cell_ids', 'split_anndata', 'donor_org', 'donor_pred'
    var: 'mt', 'n_counts', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'hvg', 'log1p'
    obsm: 'protein_counts', 'dis_latent_stack', 'cat_latent_stack', 'cat_latent_stack_collapse', 'map_latent_summed_collapse', 'map_latent_summed'
    layers: 'counts'

In [11]:
get_recs = True

In [12]:
adata_structured = None

for idx in tqdm(range(n_latent_size)):
    _adata = sc.AnnData(adata.X, obs=adata.obs)
    _adata.obsm = adata.obsm.copy()
    
    dis_latent_stack_progressive = adata.obsm['dis_latent_stack'].copy()
    dis_latent_stack = adata.obsm['dis_latent_stack'].copy()
        
    if idx != n_latent_size-1:
        dis_latent_stack_progressive[:, idx+1:] = dis_latent_stack_progressive[:, idx+1:].mean(axis=0)
        
    covariates_batch = {
        condition_name: torch.zeros(
            _adata.shape[0]
            ).long().to(trainer.device) * counterfactual_dict_int[condition_name]
    }
    
    x_inp = torch.tensor(
        dis_latent_stack_progressive
        ).to(trainer.device)

    dis_latent_prog = [ 
        torch.tensor(
            dis_latent_stack_progressive[:, x])
        .unsqueeze(1).to(
            trainer.device
        ) for x in range(
            dis_latent_stack_progressive.shape[1])
        ]
    
    counterfactual_latent = trainer.get_counterfactuals(
                        x_inp=x_inp,
                        variables=covariates_batch,
                        dis_latent=dis_latent_prog,
                        counterfactual_dict=counterfactual_dict,
                        suffixe='',
                        )

    map_latent_summed = counterfactual_latent['map_latent_summed'].cpu().detach().numpy()
    _adata.obsm['map_latent_summed'] = map_latent_summed
    
    if get_recs:
        library = torch.log(
            torch.tensor(
                adata.X.sum(1)).unsqueeze(1)
        ).to(trainer.device)
        counterfactual_recs = trainer.get_recs(
                            decoder_name='decoder',
                            decoder_input= counterfactual_latent['map_latent_summed'],
                            px_name='px_r',
                            library=library,
                            suffixe='_collapse',
                        )
        print(counterfactual_latent['map_latent_summed'].shape, library.shape)
        reconstructed = counterfactual_recs['reconstructed_collapse'].cpu().detach().numpy()
        _adata.X = reconstructed
        
    _adata.obsm['dis_latent_stack'] = adata.obsm['dis_latent_stack']
    _adata.obsm['dis_latent_progressive'] = dis_latent_stack_progressive
        
    _adata.obs['level'] = idx

    if not get_recs:
        _adata = sc.AnnData(
            _adata.obsm['map_latent_summed'],
            obs = _adata.obs.copy(),
            obsm = _adata.obsm.copy()
        )
        
    if adata_structured is None:
        adata_structured = _adata.copy()
    else:
        adata_structured = ad.concat([adata_structured, _adata])

 38%|███████████████████████████████████████████████████████████████████████████▊                                                                                                                              | 6/16 [00:00<00:00, 55.47it/s]

torch.Size([855, 16]) torch.Size([855, 1])
torch.Size([855, 16]) torch.Size([855, 1])
torch.Size([855, 16]) torch.Size([855, 1])
torch.Size([855, 16]) torch.Size([855, 1])
torch.Size([855, 16]) torch.Size([855, 1])
torch.Size([855, 16]) torch.Size([855, 1])
torch.Size([855, 16]) torch.Size([855, 1])
torch.Size([855, 16]) torch.Size([855, 1])
torch.Size([855, 16]) torch.Size([855, 1])
torch.Size([855, 16]) torch.Size([855, 1])


 75%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                                                  | 12/16 [00:00<00:00, 36.91it/s]

torch.Size([855, 16]) torch.Size([855, 1])
torch.Size([855, 16]) torch.Size([855, 1])
torch.Size([855, 16]) torch.Size([855, 1])
torch.Size([855, 16]) torch.Size([855, 1])
torch.Size([855, 16]) torch.Size([855, 1])


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:00<00:00, 33.02it/s]

torch.Size([855, 16]) torch.Size([855, 1])





In [13]:
adata_structured

AnnData object with n_obs × n_vars = 13680 × 5000
    obs: 'nCount_ADT', 'nFeature_ADT', 'nCount_RNA', 'nFeature_RNA', 'orig.ident', 'lane', 'donor', 'time', 'celltype.l1', 'celltype.l2', 'celltype.l3', 'Phase', 'nCount_SCT', 'nFeature_SCT', 'X_index', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'Protein log library size', 'Number proteins detected', 'RNA log library size', 'n_counts', 'sc_cell_ids', 'cell_ids', 'split_anndata', 'donor_org', 'donor_pred', 'level'
    obsm: 'protein_counts', 'dis_latent_stack', 'cat_latent_stack', 'cat_latent_stack_collapse', 'map_latent_summed_collapse', 'map_latent_summed', 'dis_latent_progressive'

# Save adata with latent levels

In [14]:
adata_structured.write_h5ad('adata_structured_balanced_count.h5ad')