In [1]:
from typing import Dict, Iterable, Optional

import numpy as np
import pandas as pd
import scanpy as sc
import scvi
import torch
from torch import tensor
from torch.distributions import Normal, Poisson
from torch.distributions import kl_divergence as kld

from complementary_models import HALOVI
from complementary_models import HALOVAE
from complementary_models.infer_nonsta_dir import infer_nonsta_dir
# Autograd anomaly detection is a debugging aid that slows every backward pass.
# Keep it off by default and set DETECT_ANOMALY = True only while tracking down
# NaN/Inf gradients. The trailing `;` suppresses the returned object's repr.
DETECT_ANOMALY = False
torch.autograd.set_detect_anomaly(DETECT_ANOMALY);


Global seed set to 0


<torch.autograd.anomaly_mode.set_detect_anomaly at 0x7fd0b823eb20>

In [2]:
## load the data
# Input locations and the feature-filter threshold, kept together so they are easy to change.
MULTIOME_PATH = "halo/E18_mouse_Brain/multiomic.h5ad"
METADATA_PATH = "halo/E18_mouse_Brain/RNA/metadata.tsv"
MIN_CELL_FRACTION = 0.01  # keep features detected in at least 1% of cells

adata_multi = sc.read_h5ad(MULTIOME_PATH)
adata_multi.obs["batch_id"] = 1
# Copy the feature types into var["modality"]; later cells count genes vs. peaks from this column.
adata_multi.var["modality"] = adata_multi.var["feature_types"]
adata_mvi = scvi.data.organize_multiome_anndatas(adata_multi)

# Threshold is computed on the cell count at this point (before any metadata filtering).
min_cells = int(adata_mvi.shape[0] * MIN_CELL_FRACTION)
sc.pp.filter_genes(adata_mvi, min_cells=min_cells)

df_meta = pd.read_csv(METADATA_PATH, sep="\t", index_col=0)


In [3]:
## Merge Data
# Give every cell an integer `time_key`: each distinct df_meta["binned"] value is
# mapped, in sorted order, to 0, 1, 2, ...
bin_to_time = {b: i for i, b in enumerate(sorted(df_meta["binned"].unique()))}
df_meta["time_key"] = df_meta["binned"].map(bin_to_time)

# obs names in adata_mvi carry a "_paired" suffix (see the obs table output),
# so build matching IDs for the metadata rows.
df_meta["Id"] = df_meta.index.astype(str) + "_paired"
df_meta_sub = df_meta.set_index("Id")[["time_key"]]

# Keep only cells that have metadata (what the inner join intended), then attach
# time_key by label. Assigning the column directly instead of `obs.join` keeps
# this cell re-runnable: a second join fails because `time_key` already exists.
has_meta = adata_mvi.obs_names.isin(df_meta_sub.index)
if not has_meta.all():
    print(f"Dropping {int((~has_meta).sum())} cells without metadata")
    adata_mvi = adata_mvi[has_meta].copy()
adata_mvi.obs["time_key"] = df_meta_sub["time_key"].reindex(adata_mvi.obs_names).to_numpy()
adata_mvi.obs

Unnamed: 0,celltype,batch_id,modality,time_key
AAACAGCCAACCGCCA-1_paired,Upper Layer,1,paired,12
AAACAGCCAAGGTCGA-1_paired,"RG, Astro, OPC",1,paired,5
AAACAGCCAGGAACAT-1_paired,Deeper Layer,1,paired,17
AAACAGCCATATTGAC-1_paired,Deeper Layer,1,paired,19
AAACAGCCATGGTTAT-1_paired,Subplate,1,paired,10
...,...,...,...,...
TTTGTGGCATAATCGT-1_paired,Ependymal cells,1,paired,9
TTTGTGGCATTTGCTC-1_paired,Upper Layer,1,paired,11
TTTGTGTTCAATGACC-1_paired,IPC,1,paired,5
TTTGTTGGTGGAGCAA-1_paired,Deeper Layer,1,paired,15


In [4]:
# Inspect the combined feature matrix (the recorded output shows a sparse CSR matrix).
adata_mvi.X

<3365x138466 sparse matrix of type '<class 'numpy.float32'>'
	with 42912160 stored elements in Compressed Sparse Row format>

In [5]:
# Register the AnnData fields HALOVI reads: "modality" as the batch key and "time_key" as the time key.
HALOVI.setup_anndata(adata_mvi, batch_key="modality", time_key='time_key')

feature_modality = adata_mvi.var['modality']
halo_hparams = dict(alpha=0.01, beta_1=1e5, beta_2=1e5, beta_3=1e6)
# NOTE(review): the recorded output of this cell prints "alpha: 0.025", not the 0.01
# passed here. The output looks stale; re-run to confirm which value the model uses.
mvi_p = HALOVI(
    adata_mvi,
    n_genes=(feature_modality == 'Gene Expression').sum(),
    n_regions=(feature_modality == 'Peaks').sum(),
    **halo_hparams,
)

time key in registry : True
cell type key in registry: False
alpha: 0.025, beta1: 100000.0, beta2: 100000.0, beta3: 1000000.0


In [6]:
# Train for 20 epochs. Use the GPU only when CUDA is available, so the notebook
# also runs on CPU-only machines (a hard-coded use_gpu=True presumably requires CUDA).
mvi_p.train(use_gpu=torch.cuda.is_available(), batch_size=512, max_epochs=20)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_deprecation(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Epoch 1/20:   0%|          | 0/20 [00:00<?, ?it/s]

  x = torch.where(mask_expr.T, x_expr.T, x_acc.T).T


coupled  ATAC->RNA 0.01130282635659841, RNA->ATAC 0.011302880666233077
Lagging ATAC->RNA score 0.011302921511170043, RNA->ATAC 0.011302948000883526
a2rscore_coupled_loss: 1.469717364340159 , r2ascore_coupled_loss: 1.4697119333766926 /n a2rscore_lagging_loss: 0.0,  a2r_r2a_score_loss: 0.0999735102865165
independent distance ATAC-RNA -2.6489713483499044e-08
coupled  ATAC->RNA 0.01145673093545698, RNA->ATAC 0.011456415637174667
Lagging ATAC->RNA score 0.011456858533772535, RNA->ATAC 0.01145630549800588
a2rscore_coupled_loss: 1.454326906454302 , r2ascore_coupled_loss: 1.4543584362825335 /n a2rscore_lagging_loss: 0.0,  a2r_r2a_score_loss: 0.10055303576665563
independent distance ATAC-RNA 5.530357666556213e-07
coupled  ATAC->RNA 0.009946996776855968, RNA->ATAC 0.009947041744789949
Lagging ATAC->RNA score 0.009947109025225834, RNA->ATAC 0.009947361428767697
a2rscore_coupled_loss: 1.6053003223144033 , r2ascore_coupled_loss: 1.6052958255210052 /n a2rscore_lagging_loss: 0.0,  a2r_r2a_score_loss:

In [8]:
# Persist the trained model. overwrite=True lets Restart & Run All re-save into the
# already-existing directory instead of raising on every run after the first.
# NOTE(review): the directory name advertises alpha=0.025 and 100 epochs, but the
# setup/train cells use alpha=0.01 and max_epochs=20. Rename once the intended
# configuration is confirmed (the load cell reads this same path).
mvi_p.save("./models/alpha025beta100_100epoch.pt", overwrite=True)

In [6]:
## load models of HALO
# `load` is presumably a classmethod (scvi-tools BaseModelClass.load). Calling it on
# the class means this cell no longer depends on an `mvi_p` instance already
# existing in the kernel, e.g. after a restart.
# NOTE(review): the recorded output reports alpha 0.01 / beta 1e4 for the loaded
# model, which differs from the setup cell's arguments. Confirm what was saved here.
mvi_p = HALOVI.load("models/alpha025beta100_100epoch.pt", adata=adata_mvi)

[34mINFO    [0m File models/alpha025beta100_100epoch.pt/model.pt already downloaded                 
time key in registry : True
cell type key in registry: False
alpha: 0.01, beta1: 10000.0, beta2: 10000.0, beta3: 10000.0


In [7]:
# Train the loaded model for another 20 epochs (presumably resuming from the loaded
# weights). Use the GPU only when CUDA is available, so this also runs on CPU-only machines.
mvi_p.train(use_gpu=torch.cuda.is_available(), batch_size=512, max_epochs=20)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_deprecation(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(


Epoch 1/20:   0%|          | 0/20 [00:00<?, ?it/s]

  x = torch.where(mask_expr.T, x_expr.T, x_acc.T).T


coupled  ATAC->RNA 0.010294326900044178, RNA->ATAC 0.01029448154105592
Lagging ATAC->RNA score 0.010294368026266883, RNA->ATAC 0.010294321507133665
a2rscore_coupled_loss: 7.056730999558226 , r2ascore_coupled_loss: 7.0551845894408025 /n a2rscore_lagging_loss: 3.9436802626688245,  a2r_r2a_score_loss: 1.000465191332174
independent distance ATAC-RNA 4.651913321739032e-08
coupled  ATAC->RNA 0.01012786053782926, RNA->ATAC 0.0101279280146607
Lagging ATAC->RNA score 0.01012786424322838, RNA->ATAC 0.010128030175755155
a2rscore_coupled_loss: 8.72139462170741 , r2ascore_coupled_loss: 8.720719853393009 /n a2rscore_lagging_loss: 2.2786424322838026,  a2r_r2a_score_loss: 0.9983406747322556
independent distance ATAC-RNA -1.6593252677443715e-07
coupled  ATAC->RNA 0.00990263617071881, RNA->ATAC 0.009902818959073021
Lagging ATAC->RNA score 0.00990257867696763, RNA->ATAC 0.009903061446583336
a2rscore_coupled_loss: 10.973638292811902 , r2ascore_coupled_loss: 10.97181040926979 /n a2rscore_lagging_loss: 0.02

In [8]:

# Unpack the seven values returned by the model: ATAC/RNA latents, their "dep" and
# "indep" parts (names suggest time-dependent / time-independent; TODO confirm),
# and `times`. NOTE: this rebinds the name `times`.
(
    latent_atac,
    latent_expr,
    latent_atac_dep,
    latent_expr_dep,
    latent_atac_indep,
    latent_expr_indep,
    times,
) = mvi_p.get_latent_representation()

In [9]:
# Inspect the model's `train_statics` flag (recorded output: False).
# NOTE(review): its meaning is not visible in this notebook and it looks like
# leftover debugging. Confirm the attribute name and purpose, or remove this cell.
mvi_p.train_statics

False

In [9]:
# Direction score on the "dep" latents, ATAC latent passed first.
# infer_nonsta_dir returns a 3-tuple; only the first element is kept.
# (infer_nonsta_dir is imported in the top imports cell.)
score1, _, _ = infer_nonsta_dir(latent_atac_dep, latent_expr_dep, times)
score1

0.009121775391051882

In [10]:
# Same score on the "dep" latents with the argument order swapped (RNA latent first).
score2, _, _ = infer_nonsta_dir(
    latent_expr_dep,
    latent_atac_dep,
    times,
)
score2

0.009121665715569438

In [11]:
# Direction score on the "indep" latents, ATAC latent passed first.
score3, _, _ = infer_nonsta_dir(
    latent_atac_indep,
    latent_expr_indep,
    times,
)
score3

0.009121652947142582

In [12]:
# Same score on the "indep" latents with the argument order swapped (RNA latent first).
score4, _, _ = infer_nonsta_dir(
    latent_expr_indep,
    latent_atac_indep,
    times,
)
score4

0.009121531946050367

In [13]:
# Difference between the two "indep" scores, which differ only in argument order
# (ATAC-first minus RNA-first).
score3 - score4

1.2100109221535593e-07