In [None]:
import numpy as np
import scanpy as sc
import anndata as ad
import torch
import random

import biolord
import os
from icecream import ic
import gc

AttributeError: `np.unicode_` was removed in the NumPy 2.0 release. Use `np.str_` instead.

In [None]:
def total_to_median_norm(_adata, data_median, target_sum=1e4):
    """Rescale log1p-transformed, total-normalized values to a new library size.

    The transform applied to ``_adata.X`` is, in order: ``expm1`` (undo the
    log1p), divide by ``target_sum``, multiply by ``data_median``, then
    ``sc.pp.log1p`` again.

    Parameters
    ----------
    _adata
        Object whose ``X`` is rewritten. It is modified IN PLACE, so calling
        this twice on the same object applies the rescaling twice.
    data_median
        Factor the values are multiplied by after division by ``target_sum``.
    target_sum
        Divisor applied after ``expm1``. Defaults to ``1e4``, the value that
        used to be hard-coded here; it should equal the ``target_sum`` that
        was passed to ``sc.pp.normalize_total`` when the data was normalized.

    Returns
    -------
    The same ``_adata`` object that was passed in.
    """
    # Undo the log1p transform.
    _adata.X = np.expm1(_adata.X)
    # Same left-to-right arithmetic as before: (X / target_sum) * data_median.
    _adata.X = _adata.X / target_sum * data_median

    sc.pp.log1p(_adata)
    return _adata

In [None]:
def set_seed(seed):
    """Seed the random number generators used in this notebook.

    Logs the seed through ``ic``, then seeds Python's ``random`` module,
    NumPy's global generator and torch (CPU and all CUDA devices). Finally
    sets cuDNN to deterministic mode with benchmarking turned off.

    Parameters
    ----------
    seed
        Value handed to every seeding call.
    """
    ic('Setting seed to', seed)

    # Python / NumPy generators.
    random.seed(seed)
    np.random.seed(seed)

    # torch generators (CPU, then every CUDA device).
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # cuDNN flags.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

In [None]:
# --- Input file ---
adata_path = '/data/Experiments/Benchmark/SCDISENTANGLE_REPRODUCE/Datasets/preprocessed_datasets/kang.h5ad'

# --- Names of the obs columns used below ---
cond_key = "condition"
cov_key = "cell_type"

# --- Values looked up in those columns / used to build the split column name ---
control_name = "control"
stim_name = "stimulated"
ood_cov = "B"

# --- Lists derived from the names above (same contents as the literal lists) ---
vars_to_predict = [stim_name, control_name]
categorical_attributes = [cond_key, cov_key]

# --- Run settings ---
seed_nb = 1
device_nb = 1

In [None]:
# Load the .h5ad file at adata_path; later cells filter, normalize and train on this object.
adata = sc.read_h5ad(adata_path)

In [None]:
# Set seed: seed the Python, NumPy and torch generators with the configured seed_nb
# before any model is built or trained.
set_seed(seed_nb)

In [None]:
# Look up the split column once, then take an independent copy of the cells
# carrying each of the three labels.
_split_labels = adata.obs[f'split_{stim_name}_{ood_cov}']
_train_adata = adata[_split_labels == 'train'].copy()
_val_adata = adata[_split_labels == 'val'].copy()
_ood_adata = adata[_split_labels == 'ood'].copy()

In [None]:
# Compute train median: per-cell total of the training matrix, then the median
# of those totals. np.asarray(...) is used instead of sum(..., keepdims=True)
# because scipy sparse matrices' .sum() may not accept `keepdims`; this form
# works for both dense and sparse X and keeps the (n_cells, 1) shape.
_sums = np.asarray(_train_adata.X.sum(axis=1)).reshape(-1, 1)
data_median = np.median(_sums)

# Replace val by test
adata.obs[f'split_{stim_name}_{ood_cov}'] = [x.replace('val', 'test') for x in adata.obs[f'split_{stim_name}_{ood_cov}']]

# NOTE: this cell rewrites adata in place (split labels, normalization, log1p),
# so re-running it without reloading adata applies the transforms again.
# The target_sum here must match the divisor used by total_to_median_norm.
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)

# Setup Biolord data
biolord.Biolord.setup_anndata(
    adata,
    ordered_attributes_keys=None,
    categorical_attributes_keys=categorical_attributes,
)

In [None]:
# Params: keyword settings handed to biolord.Biolord as `module_params`.
module_params = dict(
    decoder_width=1024,
    decoder_depth=4,
    attribute_nn_width=512,
    attribute_nn_depth=2,
    n_latent_attribute_categorical=4,
    gene_likelihood="normal",
    reconstruction_penalty=1e2,
    unknown_attribute_penalty=1e1,
    unknown_attribute_noise_param=1e-1,
    attribute_dropout_rate=0.1,
    use_batch_norm=False,
    use_layer_norm=False,
    seed=seed_nb,  # same seed as the one given to set_seed
)

In [None]:
# Init Biolord model
# The model name and the split column name are both built from the config values.
_model_name = f'{ood_cov}_{seed_nb}'
_split_key = f'split_{stim_name}_{ood_cov}'

model = biolord.Biolord(
    adata=adata,
    n_latent=32,
    model_name=_model_name,
    module_params=module_params,
    train_classifiers=False,
    split_key=_split_key,
)

In [None]:
# Trainer params: passed to model.train() as `plan_kwargs`.
trainer_params = dict(
    n_epochs_warmup=0,
    latent_lr=1e-4,
    latent_wd=1e-4,
    decoder_lr=1e-4,
    decoder_wd=1e-4,
    attribute_nn_lr=1e-2,
    attribute_nn_wd=4e-8,
    step_size_lr=45,
    cosine_scheduler=True,
    scheduler_final_lr=1e-5,
)

In [None]:
# Train
# Collect the training options first, then hand them to model.train in one call.
_train_kwargs = dict(
    max_epochs=500,
    batch_size=512,
    plan_kwargs=trainer_params,
    early_stopping=True,
    early_stopping_patience=20,
    check_val_every_n_epoch=10,
    num_workers=1,
    enable_checkpointing=False,
)
model.train(**_train_kwargs)

In [None]:
# Select the cells that are control, belong to the ood_cov covariate value,
# and are labelled 'train' in the split column; keep an independent copy.
_is_control = adata.obs[cond_key] == control_name
_is_ood_cov = adata.obs[cov_key] == ood_cov
_in_train = adata.obs[f'split_{stim_name}_{ood_cov}'] == 'train'
adata_subset = adata[_is_control & _is_ood_cov & _in_train].copy()

In [None]:
# Independent copy of the selected cells, used as the source for the prediction below.
adata_source = adata_subset.copy()

In [None]:
# Predict from the source cells, targeting the condition attribute.
_target_attributes = [cond_key]
# Extra names passed as `add_attributes` alongside the prediction.
_extra_attributes = [cov_key, f'split_{stim_name}_{ood_cov}', 'sc_cell_ids']

adata_preds = model.compute_prediction_adata(
    adata,
    adata_source,
    target_attributes=_target_attributes,
    add_attributes=_extra_attributes,
)

In [None]:
# Rescale the predictions with the train median computed earlier.
# NOTE: total_to_median_norm rewrites adata_preds.X in place and returns the same
# object, so re-running this cell applies the rescaling a second time.
adata_preds = total_to_median_norm(adata_preds, data_median)

In [None]:
# Maximum value of X over the predicted cells whose condition equals stim_name.
# Uses the configured column/value names rather than repeating the literals.
adata_preds[adata_preds.obs[cond_key] == stim_name].X.max()

In [None]:
# Maximum value of X over the predicted cells whose condition equals control_name.
# Uses the configured column/value names rather than repeating the literals.
adata_preds[adata_preds.obs[cond_key] == control_name].X.max()