In [1]:
# Reload imported modules automatically before each cell executes, so edits to
# library code (e.g. dis2p) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import shutil
# Select the pure-Python protobuf backend. The variable is read when protobuf is
# first imported (presumably pulled in by the dis2p/scvi stack below), so this
# cell must run before those imports.
os.environ.update(PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION='python')

In [3]:
import dis2p.dis2pvae as dvae
import dis2p.dis2pvi as dvi

[2023-11-09 11:41:24,194] [INFO] [real_accelerator.py:133:get_accelerator] Setting ds_accelerator to cuda (auto detect)


Global seed set to 0


In [4]:
import scvi
import scanpy as sc
import pandas as pd
from datetime import datetime
def create_cats_idx(adata, cats):
    """Add an integer-coded categorical column ``<cat>_idx`` for each name in ``cats``.

    Each distinct value of ``adata.obs[cat]`` is mapped to an integer code. Codes
    follow the sorted order of the distinct values (compared via ``str``), so a
    given dataset gets the same codes in every kernel session.

    Parameters
    ----------
    adata : AnnData-like
        Object whose ``.obs`` DataFrame contains the columns named in ``cats``.
        Modified in place: one ``<cat>_idx`` column is added per entry.
    cats : sequence of str
        Names of ``adata.obs`` columns to encode.

    Returns
    -------
    The same ``adata`` object, returned for convenience.
    """
    for cat in cats:
        column = adata.obs[cat]
        # Iterating a plain ``set`` of strings follows the per-process hash seed,
        # which would reshuffle the codes between sessions; sort instead.
        # ``key=str`` also keeps the sort from raising on mixed-type labels.
        values = sorted(set(column), key=str)
        val_to_idx = {v: i for i, v in enumerate(values)}
        adata.obs[cat + '_idx'] = pd.Categorical([val_to_idx[v] for v in column])
    return adata

# Preprocessing knobs, kept together so every tunable value is visible at a glance.
MIN_GENE_COUNTS = 3        # genes with fewer total counts are dropped
N_TOP_GENES = 1200         # number of highly variable genes to keep
SAVE_PREPROCESSED = False  # set True to write the preprocessed AnnData to disk

adata = scvi.data.heart_cell_atlas_subsampled()

# preprocess dataset
sc.pp.filter_genes(adata, min_counts=MIN_GENE_COUNTS)
# Keep unnormalised counts in a layer: the HVG selection below and the model
# setup in the training cell both read layer "counts", not the normalised X.
adata.layers["counts"] = adata.X.copy()
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)
# Freeze the full log-normalised matrix in .raw before subsetting to HVGs.
adata.raw = adata
sc.pp.highly_variable_genes(
    adata,
    n_top_genes=N_TOP_GENES,
    subset=True,
    layer="counts",
    flavor="seurat_v3",
)

# specify name of dataset
data_name = 'HeartAtlas'

# specify attributes
cats = ['cell_type', 'cell_source', 'gender', 'region']

# create numerical index for each attr in cats (adds '<cat>_idx' columns to adata.obs)
create_cats_idx(adata, cats)
assert all(f'{cat}_idx' in adata.obs for cat in cats), "missing '<cat>_idx' columns"

# Optionally cache the preprocessed data; the file name records the HVG count.
if SAVE_PREPROCESSED:
    adata.write_h5ad(f'data/heart_preprocessed{N_TOP_GENES}.h5ad')

# Date stamp used to build the model name in the training cell.
today = datetime.today().strftime('%Y-%m-%d')

INFO     File data/hca_subsampled_20k.h5ad already downloaded


In [5]:
# train params
epochs = 400
batch_size = 128
cf_weight = 1
beta = 1
clf_weight = 50
adv_clf_weight = 10
adv_period = 1
n_cf = 1

# architecture params
n_layers = 1

train_dict = {'max_epochs': epochs, 'batch_size': batch_size, 'cf_weight': cf_weight,
              'beta': beta, 'clf_weight': clf_weight, 'adv_clf_weight': adv_clf_weight,
              'adv_period': adv_period, 'n_cf': n_cf}

# Output layout: models/<module_name>/...
models_root = 'models'
# WARNING: when True this deletes *every* model previously saved under
# `models_root` (all modules, dates and hyper-parameter settings) before
# training. Set to False to keep earlier runs side by side.
RESET_MODELS_DIR = True
if RESET_MODELS_DIR:
    # Pure-Python, cross-platform equivalent of `rm -rf models`.
    shutil.rmtree(models_root, ignore_errors=True)

module_name = 'dis2p'
pre_path = f'{models_root}/{module_name}'
os.makedirs(pre_path, exist_ok=True)

# Name encodes date, module, dataset and every hyper-parameter so runs are distinguishable.
model_name = (f'{today},{module_name},{data_name},n_layers={n_layers},'
              + ','.join(f'{k}={v}' for k, v in train_dict.items()))

# Register the data with the model: counts layer plus the attributes listed in `cats`.
dvi.Dis2pVI.setup_anndata(
    adata,
    layer='counts',
    categorical_covariate_keys=cats,
    continuous_covariate_keys=[]
)
model = dvi.Dis2pVI(adata, n_layers=n_layers)
model.train(**train_dict)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A100-SXM4-80GB MIG 1c.7g.79gb') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [MIG-GPU-8391e223-da74-0458-e121-783edc78bf21/0/0]
  rank_zero_warn(


Epoch 13/400:   3%|▎         | 12/400 [01:48<58:13,  9.00s/it, v_num=1, loss_validation=2.19e+3, x_0_validation=294, x_1_validation=303, x_2_validation=302, x_3_validation=301, x_4_validation=302, rec_x_cf_validation=497, z_1_validation=23.4, z_2_validation=23.6, z_3_validation=28.2, z_4_validation=25.8, ce_validation=0.916, acc_validation=0.996, f1_validation=0.996, adv_ce_validation=1.37, adv_acc_validation=0.531, adv_f1_validation=0.531, loss_train=2.05e+3, x_0_train=283, x_1_train=292, x_2_train=291, x_3_train=290, x_4_train=292, rec_x_cf_train=468, z_1_train=24.4, z_2_train=24, z_3_train=26.4, z_4_train=25.8, ce_train=0.92, acc_train=0.992, f1_train=0.992, adv_ce_train=1.35, adv_acc_train=0.517, adv_f1_train=0.517]  

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
