In [2]:
import os
# Expose only GPU 0 to this process; must happen before torch initialises CUDA.
os.environ.update(CUDA_VISIBLE_DEVICES="0")

In [3]:
from src.model import MVAE

import numpy as np
import pandas as pd
import scanpy as sc
import muon as mu
from muon import MuData
from tqdm import tqdm, trange
import matplotlib.pyplot as plt

In [4]:
# Preprocessed MuData with modalities 'rna' and 'msi'. Absolute path: adjust for your machine.
MDATA_PATH = '/media/yob/gabriele/OtF-prostate/data/mdata_preprocessed.h5mu'
mdata = mu.read(MDATA_PATH)
mdata

In [7]:
# kNN graph + UMAP on the RNA modality alone (results stored inside mdata.mod['rna']).
adata_rna = mdata.mod['rna']
sc.pp.neighbors(adata_rna)
sc.tl.umap(adata_rna)

In [8]:
# kNN graph + UMAP on the MSI modality alone (results stored inside mdata.mod['msi']).
adata_msi = mdata.mod['msi']
sc.pp.neighbors(adata_msi)
sc.tl.umap(adata_msi)

In [9]:
from scipy import sparse

# Convert sparse .X matrices to dense NumPy arrays in both modalities.
# `.toarray()` works for both scipy sparse matrices and sparse arrays; the `.A`
# shorthand was deprecated in SciPy 1.11 and removed in 1.14.
# NOTE(review): the msi modality is 42475 x 31975, so densifying it needs several GB
# of RAM (exact size depends on dtype).
for mod_name in ('rna', 'msi'):
    mod_adata = mdata.mod[mod_name]
    if sparse.issparse(mod_adata.X):
        mod_adata.X = mod_adata.X.toarray()

In [10]:
# Constant (all-zero) extra categorical covariate column, presumably read by the model.
mdata.obs['extra_categorical_covs'] = 0

# Encode each distinct `sample` as an integer code (0, 1, ... in order of first
# appearance). Store it as a categorical under both 'batch_id' and 'batch'.
batch_codes = pd.factorize(mdata.obs['sample'])[0]
mdata.obs['batch_id'] = pd.Categorical(batch_codes)
mdata.obs['batch'] = pd.Categorical(batch_codes)

# Propagate the codes to each modality and record the batch count.
# Series assignment aligns on the obs index, so this assumes both modalities
# share obs names with mdata.obs.
for mod_name in ('rna', 'msi'):
    mod_adata = mdata.mod[mod_name]
    mod_adata.obs['batch_id'] = mdata.obs['batch_id']
    mod_adata.uns['n_batch'] = len(mod_adata.obs['batch_id'].cat.categories)

In [11]:
# (n_obs, n_vars) of the RNA modality.
(mdata.mod['rna'].n_obs, mdata.mod['rna'].n_vars)

(42475, 9048)

In [12]:
from src.model import MVAEParams
from src.train import TrainParams

# Architecture settings (n_layers, z_dim, n_hidden) presumably must match the checkpoint
# passed to load_state_dict, which is strict by default.
params = MVAEParams(
    beta=0.1,
    n_layers=2,
    z_dim=100,
    n_hidden=300,
)
train_params = TrainParams(batch_size=1028)
model = MVAE(mdata, params, use_cuda=True)


N batches for mod1:  16
N batches for mod2:  16
(42475, 9048)
(42475, 31975)


In [13]:
import torch

# The checkpoint is a plain state_dict (it is passed straight to load_state_dict).
# weights_only=True restricts unpickling to tensors and basic containers, so a
# tampered .pt file cannot execute arbitrary code. Requires torch >= 1.13.
# map_location='cuda' places the tensors directly on the visible GPU.
model_dict = torch.load('mvae_params.pt', map_location='cuda', weights_only=True)

In [14]:
# strict=True is PyTorch's default. It is spelled out so missing/unexpected keys raise.
model.load_state_dict(model_dict, strict=True)

<All keys matched successfully>

In [15]:
# Cap PyTorch intra-op CPU parallelism.
N_TORCH_THREADS = 16
torch.set_num_threads(N_TORCH_THREADS)

In [16]:
from src.train import to_latent

# Encode all observations. Each returned value is presumably a sequence of per-batch
# chunks (batches of train_params.batch_size); they are np.vstack'ed in later cells.
(
    mvae_emb,
    z1_p, z2_p,
    z1_p_mod, z2_p_mod,
    z1_s, z2_s,
) = to_latent(model, mdata, train_params)

100%|███████████████████████████████████████████| 42/42 [00:34<00:00,  1.21it/s]


In [None]:
from src.train import predict

# Model reconstructions/predictions, each presumably a sequence of per-batch chunks
# (they are np.vstack'ed in the next cell).
(
    x1_poe, x2_poe,
    x1, x2,
    x1_2, x2_1,
    x1_batch_free, x2_batch_free,
) = predict(model, mdata, train_params)


In [None]:
# Store each predict() output as a layer of the matching modality. Each output is
# stacked row-wise with np.vstack. The insertion order below matches the original order.
# NOTE(review): 'loss_msi_rna' / 'loss_rna_msi' store the x2_1 / x1_2 outputs of
# predict(); confirm the 'loss_' naming is intended.
_prediction_layers = (
    ('rna', 'rna_poe', x1_poe),
    ('msi', 'msi_poe', x2_poe),
    ('rna', 'rna', x1),
    ('msi', 'msi', x2),
    ('rna', 'loss_msi_rna', x2_1),
    ('msi', 'loss_rna_msi', x1_2),
    ('rna', 'rna_batch_free', x1_batch_free),
    ('msi', 'msi_batch_free', x2_batch_free),
)
for _mod_name, _layer_name, _chunks in _prediction_layers:
    mdata.mod[_mod_name].layers[_layer_name] = np.vstack(_chunks)
del _prediction_layers, _mod_name, _layer_name, _chunks

In [None]:
import torch

# Stack the per-batch chunks of the joint embedding row-wise into one array.
# detach().cpu() makes this safe for CUDA or grad-tracking tensors, where .numpy()
# would raise. Non-tensor chunks go through np.asarray, so re-running this cell
# on an already-stacked ndarray (whose rows are 1-D arrays) reproduces the same array.
mvae_emb = np.vstack([
    chunk.detach().cpu().numpy() if torch.is_tensor(chunk) else np.asarray(chunk)
    for chunk in mvae_emb
])

In [None]:
# Joint embedding (already stacked into a single array) plus the per-modality
# latents, each stacked row-wise from its chunks. Insertion order matches the original.
mdata.obsm['z_mvae'] = mvae_emb
for _key, _chunks in (
    ('z1_s', z1_s),
    ('z2_s', z2_s),
    ('z1_p', z1_p),
    ('z1_p_mod', z1_p_mod),
    ('z2_p', z2_p),
    ('z2_p_mod', z2_p_mod),
):
    mdata.obsm[_key] = np.vstack(_chunks)
del _key, _chunks

In [None]:
# Tag each observation by which modalities it appears in: rna contributes 1 and msi
# contributes 2, so 1 -> 'rna', 2 -> 'msi', 3 -> 'multiome'.
# Assumes obsm['rna'] / obsm['msi'] are per-observation presence masks (MuData
# convention); a code of 0 would map to NaN.
d = {1: 'rna', 2: 'msi', 3: 'multiome'}
mdata.obs['mod_id'] = mdata.obsm['rna'].astype(int) + 2 * mdata.obsm['msi'].astype(int)
mdata.obs['mod_id'] = mdata.obs['mod_id'].map(d)
mdata.obs

In [None]:
# kNN graph on the joint embedding, stored under the default 'neighbors' key
# (no key_added). Calls that omit neighbors_key, such as leiden, use this graph.
sc.pp.neighbors(mdata, n_neighbors=10, use_rep='z_mvae')
sc.tl.umap(mdata)
# sc.tl.umap writes obsm['X_umap'], which later UMAP calls overwrite; keep it under its own key.
mdata.obsm['X_mvae'] = mdata.obsm['X_umap']

In [None]:
# One kNN graph + UMAP per latent sub-space. Each graph goes under its own
# neighbors key, so the default 'neighbors' graph (built on z_mvae) stays intact.
# sc.tl.umap always writes obsm['X_umap'], so each result is saved to
# obsm['X_<rep>'] before the next iteration overwrites it.
LATENT_NEIGHBORS_K = 5
for rep, graph_key in [
    ('z2_p', 'msi_p'),
    ('z2_p_mod', 'msi_p_mod'),
    ('z1_p', 'rna_p'),
    ('z1_p_mod', 'rna_p_mod'),
    ('z2_s', 'msi_s'),
    ('z1_s', 'rna_s'),
]:
    sc.pp.neighbors(mdata, n_neighbors=LATENT_NEIGHBORS_K, use_rep=rep, key_added=graph_key)
    sc.tl.umap(mdata, neighbors_key=graph_key)
    mdata.obsm[f'X_{rep}'] = mdata.obsm['X_umap']

In [None]:
# Make seurat_clusters categorical on the RNA modality, then copy patient, tissue
# and seurat_clusters onto the shared mdata.obs table (aligned on the obs index).
rna_obs = mdata.mod['rna'].obs
rna_obs['seurat_clusters'] = rna_obs['seurat_clusters'].astype('category')
for col in ('patient', 'tissue', 'seurat_clusters'):
    mdata.obs[col] = rna_obs[col]

In [None]:
# Leiden clustering on the default 'neighbors' graph; labels are stored in obs['r0.8'].
LEIDEN_RESOLUTION = 0.8
sc.tl.leiden(mdata, key_added=f'r{LEIDEN_RESOLUTION}', resolution=LEIDEN_RESOLUTION)

In [None]:
# Copy Leiden labels onto the RNA modality, looked up by obs name so row order does
# not matter. .loc raises KeyError if an RNA observation is missing from mdata.obs.
rna_adata = mdata.mod['rna']
rna_adata.obs['r0.8'] = mdata.obs.loc[rna_adata.obs_names, 'r0.8'].values

In [None]:
# UMAP of the joint MVAE embedding, coloured by metadata, Leiden clusters and sample.
sc.pl.embedding(
    mdata,
    basis='X_mvae',
    color=['tissue', 'r0.8', 'seurat_clusters', 'sample'],
    size=15,
    wspace=0.35,
)

In [None]:
# UMAP of the z1_p latent (built on the 'rna_p' neighbors graph), same colourings as above.
sc.pl.embedding(
    mdata,
    basis='X_z1_p',
    color=['tissue', 'r0.8', 'seurat_clusters', 'sample'],
    size=15,
    wspace=0.35,
)
