In [1]:
import warnings
warnings.filterwarnings('ignore')

# Imports

In [2]:
import torch
import scanpy as sc
import numpy as np
import anndata as ad
from tqdm import tqdm
import os

from scdisentangle.train.tools import get_trainer, set_seed

# Params

In [3]:
# --- Reproducibility ---
seed_nb = 42

# --- Model config and trained weights ---
yaml_path = '../configs/leukemia_disentangle.yaml'
weights_path = '../weights/MIG_BINNED_dis_latent_stack_Sample_id_train'

# --- Counterfactual: covariate to intervene on and the label assigned to every cell ---
condition_name = 'Sample_id'
counterfactual_dict = {condition_name: 'patient1_IP'}

# Whether the latent-level loop also decodes counterfactual reconstructions.
get_recs = False

# Number of disentangled latent dimensions iterated over in the latent-level loop.
n_latent_size = 16

# Set seed

In [4]:
# Seed via scdisentangle's `set_seed` helper (it logs "Setting seed to 42").
set_seed(seed_nb)

ic| 'Setting seed to', seed: 42


# Get trainer

In [5]:
# Build the trainer from the YAML config with W&B logging off, then load trained weights.
# Per the captured output, building the trainer re-seeds ("Global seed set to 0",
# then "Setting seed to 42") and creates the cell mappings / inputs.
trainer = get_trainer(yaml_path, wandb_log=False)
trainer.load_weights(weights_path)

Global seed set to 0
ic| 'Setting seed to', seed: 42
ic| 'Creating cell mappings'
ic| 'Creating inputs'
ic| 'Creating inputs'


Wandb is off
Loading weights from ../weights/MIG_BINNED_dis_latent_stack_Sample_id_train


In [6]:
# Translate each human-readable counterfactual label into the (presumably integer)
# code the trainer's dataset uses for that covariate.
counterfactual_dict_int = dict(
    (covariate, trainer.dataset.reverse_label_mapping[covariate][label])
    for covariate, label in counterfactual_dict.items()
)

# Load the preprocessed dataset from disk (note: prediction below uses the trainer's own dataset)

In [7]:
# Load the preprocessed leukemia dataset from disk.
# NOTE(review): the prediction step below runs on `trainer.dataset.data`, not on this
# object, and reassigns `adata`. This object is kept under its own name so that
# reassignment does not silently discard it. Either pass it to `trainer.predict`
# or drop this read if it is not needed.
adata_input = sc.read_h5ad(
    '../../../Datasets/preprocessed_datasets/leukemia.h5ad'
)

# Predict on the trainer's dataset to get the disentangled latents

In [8]:
# Run inference in batches of 256 on a copy of the trainer's own AnnData,
# applying `counterfactual_dict`. The annotated result becomes `adata`.
adata = trainer.predict(
    trainer.dataset.data.copy(),
    bs=256,
    counterfactual_dict=counterfactual_dict,
)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 145/145 [00:01<00:00, 133.99it/s]


In [9]:
# Inspect the predicted AnnData. obs now shows `Sample_id_org` / `Sample_id_pred`
# (presumably added by `trainer.predict` — confirm).
adata

AnnData object with n_obs × n_vars = 37100 × 5000
    obs: 'nCount_RNA', 'nFeature_RNA', 'percent.ribo', 'percent.mito', 'Sample_id', 'Transduction', 'Phase', 'Timepoint', 'Condition', 'CARexpresion', 'cloneType', 'Frequency', 'author_cell_type', 'tissue_ontology_term_id', 'assay_ontology_term_id', 'disease_ontology_term_id', 'cell_type_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'development_stage_ontology_term_id', 'sex_ontology_term_id', 'donor_id', 'suspension_type', 'is_primary_data', 'tissue_type', 'cell_type', 'assay', 'disease', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'observation_joinid', 'n_counts', 'sc_cell_ids', 'cell_ids', 'split_anndata', 'Sample_id_org', 'Sample_id_pred'
    var: 'feature_is_filtered', 'feature_name', 'feature_reference', 'feature_biotype', 'feature_length', 'feature_type', 'n_counts', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'hvg', 'log1p'
    obsm: 'HARMONY', 'X_UMAP', 'dis_laten

In [10]:
from scipy import sparse

# Store the expression matrix in compressed sparse row (CSR) format.
adata.X = sparse.csr_matrix(adata.X)

In [12]:
# Drop cells annotated as 'Ribosomal/Mitochondrial/Degraded cells'.
# Boolean-mask indexing returns an AnnData *view*. `.copy()` materialises it so later
# cells that modify `adata` (or build new AnnData objects from its obs/obsm) work on
# an independent object rather than a view of the unfiltered data.
adata = adata[adata.obs['cell_type'] != 'Ribosomal/Mitochondrial/Degraded cells'].copy()

In [13]:
# Number of cells per annotated cell type after filtering.
adata.obs.cell_type.value_counts()

cell_type
CD8+ Cytotoxic T cells                           6407
Early prolif: MCM3/5/7+ PCNA+ T cells            5160
CD4+ Naive T cells                               5119
Late prolif: histones enriched MKI67+ T cells    5061
CD8+ Effector T cells (E)                        4929
CD8+ Eff/Mem T cells (EM)                        3260
CD4+ Central/Effector memory T cells (CM/EM)     1967
Late prolif: CCNB1/2+ CDK1+ T cells              1708
Late prolif: STMN1+ BIRC5+                       1576
gamma-delta T cells                               988
Name: count, dtype: int64

# Optionally balance cell types

In [15]:
from utils import downsample_balance_by_cell_type

# Set to False to keep the natural cell-type composition.
# NOTE: the save cell's filename ('..._balanced.h5ad') assumes balancing is on.
balance_cell_types = True

if balance_cell_types:
    # Judging from the outputs, this appears to downsample every cell type to the size
    # of the smallest class: 10 types with min count 988 -> 9880 cells (confirm in utils).
    adata = downsample_balance_by_cell_type(adata, key='cell_type')
adata

AnnData object with n_obs × n_vars = 9880 × 5000
    obs: 'nCount_RNA', 'nFeature_RNA', 'percent.ribo', 'percent.mito', 'Sample_id', 'Transduction', 'Phase', 'Timepoint', 'Condition', 'CARexpresion', 'cloneType', 'Frequency', 'author_cell_type', 'tissue_ontology_term_id', 'assay_ontology_term_id', 'disease_ontology_term_id', 'cell_type_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'development_stage_ontology_term_id', 'sex_ontology_term_id', 'donor_id', 'suspension_type', 'is_primary_data', 'tissue_type', 'cell_type', 'assay', 'disease', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'observation_joinid', 'n_counts', 'sc_cell_ids', 'cell_ids', 'split_anndata', 'Sample_id_org', 'Sample_id_pred'
    var: 'feature_is_filtered', 'feature_name', 'feature_reference', 'feature_biotype', 'feature_length', 'feature_type', 'n_counts', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'hvg', 'log1p'
    obsm: 'HARMONY', 'X_UMAP', 'dis_latent

In [16]:
# Build one AnnData per "level" of the disentangled latent space.
# At level `idx`:
#   - latent dimensions 0..idx keep their per-cell values;
#   - every later dimension is collapsed to its mean over cells;
#   - the result goes through `trainer.get_counterfactuals` and is tagged with `obs['level'] = idx`.
# All levels are concatenated into `adata_structured`. obs_names repeat once per level,
# so use `obs['level']` to tell the copies apart.
dis_latent_stack = np.asarray(adata.obsm['dis_latent_stack'])
assert dis_latent_stack.shape[1] == n_latent_size, (
    f"n_latent_size={n_latent_size} but obsm['dis_latent_stack'] has "
    f"{dis_latent_stack.shape[1]} dimensions"
)

# Every cell is assigned the counterfactual label of `condition_name`.
# (Multiplying a zeros tensor by the label, as before, always yields label 0.)
covariates_batch = {
    condition_name: torch.full(
        (adata.n_obs,),
        int(counterfactual_dict_int[condition_name]),
        dtype=torch.long,
        device=trainer.device,
    )
}

# Loop-invariant per-cell log library size, shape (n_cells, 1). Only needed for reconstructions.
# For a scipy sparse `adata.X`, `.sum(axis=1)` is an (n, 1) np.matrix. np.asarray + reshape
# gives a plain (n, 1) array for both sparse and dense X; adding `.unsqueeze(1)` on top of
# the matrix would produce (n, 1, 1).
library = None
if get_recs:
    cell_totals = np.asarray(adata.X.sum(axis=1)).reshape(-1, 1)
    library = torch.log(torch.tensor(cell_totals, device=trainer.device))

level_adatas = []
with torch.no_grad():  # inference only: no autograd graph needed
    for idx in tqdm(range(n_latent_size)):
        dis_latent_stack_progressive = dis_latent_stack.copy()
        if idx < n_latent_size - 1:
            # Collapse every dimension after `idx` to its mean over cells.
            dis_latent_stack_progressive[:, idx + 1:] = (
                dis_latent_stack_progressive[:, idx + 1:].mean(axis=0)
            )

        x_inp = torch.tensor(dis_latent_stack_progressive, device=trainer.device)
        # One (n_cells, 1) tensor per latent dimension.
        dis_latent_prog = [
            x_inp[:, dim:dim + 1].clone() for dim in range(x_inp.shape[1])
        ]

        counterfactual_latent = trainer.get_counterfactuals(
            x_inp=x_inp,
            variables=covariates_batch,
            dis_latent=dis_latent_prog,
            counterfactual_dict=counterfactual_dict,
            suffixe='',
        )
        map_latent_summed = counterfactual_latent['map_latent_summed']

        # obs is copied so the per-level `level` column can never leak into `adata.obs`.
        # var is passed so gene names survive into the saved file.
        # obsm (including 'dis_latent_stack') is carried over from `adata`.
        level_adata = sc.AnnData(
            adata.X,
            obs=adata.obs.copy(),
            var=adata.var.copy(),
            obsm=dict(adata.obsm),
        )
        level_adata.obsm['map_latent_summed'] = map_latent_summed.detach().cpu().numpy()

        if get_recs:
            counterfactual_recs = trainer.get_recs(
                decoder_name='decoder',
                decoder_input=map_latent_summed,
                px_name='px_r',
                library=library,
                suffixe='_collapse',
            )
            level_adata.obsm['reconstructed'] = (
                counterfactual_recs['reconstructed_collapse'].detach().cpu().numpy()
            )

        level_adata.obsm['dis_latent_progressive'] = dis_latent_stack_progressive
        level_adata.obs['level'] = idx
        level_adatas.append(level_adata)

# Concatenate once: repeated pairwise concatenation inside the loop is quadratic.
# merge='same' keeps the var annotations, which are identical for every level.
adata_structured = ad.concat(level_adatas, merge='same')

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:06<00:00,  2.46it/s]


# Save adata with latent levels

In [17]:
# Persist every latent level (each cell appears once per level; see obs['level']).
output_h5ad = 'adata_structured_balanced.h5ad'
adata_structured.write_h5ad(output_h5ad)