In [3]:
from typing import Dict, Iterable, Optional
# import scvi
import numpy as np
import torch
import pandas as pd
from torch.distributions import Normal, Poisson
from torch.distributions import kl_divergence as kld
from torch import tensor
from models import HALOMASKVIR as HALOVI
from models import HALOMASKVAE as HALOVAE
import scanpy as sc
import scvi
from matplotlib import pyplot as plt

# torch.autograd.set_detect_anomaly(True)

## Load the dataset

In [3]:
### Load the paired multiome (Gene Expression + Peaks) hair dataset.
# NOTE(review): point `path` at the local copy of the dataset before running.
path = "/path/to/the hair/dataset"
# Keep only features (genes and peaks alike) detected in at least this
# fraction of cells.
MIN_CELL_FRACTION = 0.01

adata_multi = sc.read_h5ad(path)
adata_multi.obs["batch_id"] = 1
# organize_multiome_anndatas reads each feature's modality from var["modality"]
# (its default `modality_key`), so mirror the 10x "feature_types" column there.
adata_multi.var["modality"] = adata_multi.var["feature_types"]
adata_mvi = scvi.data.organize_multiome_anndatas(adata_multi)

min_cells_per_feature = int(adata_mvi.shape[0] * MIN_CELL_FRACTION)
sc.pp.filter_genes(adata_mvi, min_cells=min_cells_per_feature)

## Set up the model

In [4]:
# Register the AnnData with HALO. NOTE(review): `time_key` presumably names a
# per-cell obs column ("latent_time") that must already exist in the dataset.
HALOVI.setup_anndata(adata_mvi, batch_key="modality", time_key='latent_time')

# The model is told how many features belong to each modality.
feature_modality = adata_mvi.var['modality']
gene_count = (feature_modality == 'Gene Expression').sum()
peak_count = (feature_modality == 'Peaks').sum()

model = HALOVI(
    adata_mvi,
    n_genes=gene_count,
    n_regions=peak_count,
    n_latent=20,
    n_dependent=10,
)

gate decoder initialization n_input 20, n_output 112656,         n_hidden_local 20, n_hidden_global 128, n_cat_list [1], *cat_list 1


## Non-causal training stage

In [6]:
# Stage 1: train the expression and accessibility parts together.
# NOTE(review): the integer passed to `set_finetune_params` selects a mode that
# is defined in models.HALOMASKVIR; 0 is assumed to be the non-causal setting
# given this section's header — confirm there.
model.module.set_finetune_params(0)
model.module.set_train_params(acc_train=True, expr_train=True)
model.train(batch_size=512, max_epochs=800)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Epoch 568/800:  71%|███████   | 567/800 [37:16<15:17,  3.94s/it, loss=9.02e+04, v_num=1]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [7]:
# Persist the stage-1 (non-causal) weights so the causal stage can start from them.
NONCAUSAL_MODEL_DIR = "model_hair_align_nocausal_20"
model.save(NONCAUSAL_MODEL_DIR, overwrite=True)

In [5]:
## Load the starting point for the causal training stage.
# The causal checkpoint is only written by the final "Save the model" cell.
# Loading it unconditionally here breaks Restart & Run All on a fresh checkout,
# and can silently pick up a stale model from an earlier session. Resume from
# it when it exists; otherwise start from the non-causal checkpoint saved above.
import os

CAUSAL_MODEL_DIR = "model_hair_align_causal01_20"
NONCAUSAL_MODEL_DIR = "model_hair_align_nocausal_20"

if os.path.isdir(CAUSAL_MODEL_DIR):
    checkpoint_dir = CAUSAL_MODEL_DIR
else:
    checkpoint_dir = NONCAUSAL_MODEL_DIR
print(f"Loading checkpoint from {checkpoint_dir!r}")

# `load` is a classmethod on scvi-tools models. Call it on the class rather
# than on a possibly stale instance.
model = HALOVI.load(checkpoint_dir, adata_mvi)

[34mINFO    [0m File model_hair_align_causal01_20/model.pt already downloaded                       
gate decoder initialization n_input 20, n_output 112656,         n_hidden_local 20, n_hidden_global 128, n_cat_list [1], *cat_list 1


## Causal training stage

In [8]:
# Loss weights / settings for the causal stage, applied in this exact order.
# NOTE(review): the meaning of alpha and beta_1..beta_3 is defined in
# models.HALOMASKVIR — confirm there before tuning.
# n_latent_dep matches the n_dependent=10 used when the model was constructed.
causal_hparams = {
    "alpha": 0.1,
    "beta_2": 1e6,
    "beta_3": 1e6,
    "beta_1": 1e6,
    "n_latent_dep": 10,
}
for hparam_name, hparam_value in causal_hparams.items():
    setattr(model.module, hparam_name, hparam_value)
print(model.module.alpha)


0.1


In [9]:
## Causal fine-tuning (original note: "finetune without L0").
# NOTE(review): mode 2 of `set_finetune_params` is defined in
# models.HALOMASKVIR — confirm its meaning there.
# NOTE(review): scvi-tools' `train()` normally takes `plan_kwargs` as an
# argument. Assigning `model.plan_kwargs` as an attribute only takes effect if
# HALOVI.train reads it; verify, or the KL warm-up may be silently ignored.
model.module.set_finetune_params(2)
model.module.set_train_params(acc_train=True, expr_train=True)
model.plan_kwargs = dict(n_epochs_kl_warmup=300)
model.train(batch_size=600, max_epochs=800)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Epoch 458/800:  57%|█████▋    | 457/800 [40:39<31:01,  5.43s/it, loss=1.89e+05, v_num=1]  

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


## Save the model

In [10]:

# Persist the causal-stage weights. The load cell earlier in the notebook reads
# this same directory name.
CAUSAL_MODEL_DIR = "model_hair_align_causal01_20"
model.save(CAUSAL_MODEL_DIR, overwrite=True)