# Reproducing results conditional tabula muris single cell diffusion

## Before running this notebook, make sure you have run the following commands, as described in detail in the README in the `Jean_Instructions` folder

## You have trained the VAE

```bash
echo "Training Autoencoder, this might take a long time"
CUDA_VISIBLE_DEVICES=0 python /path/to/VAE_train.py --data_dir '/path/where/you/saved/tabula_muris/all.h5ad' --num_genes 18996 --state_dict '/path/where/you/saved/scimilarity/pretrained/weights/annotation_model_v1' --save_dir '/dir/where/to/save/the/trained/VAE/model/' --max_steps 200000 --max_minutes 600
echo "Training Autoencoder done"
```

## You have trained the diffusion model

```bash
echo "Training diffusion backbone"
CUDA_VISIBLE_DEVICES=0 python /path/to/cell_train.py --data_dir '/path/where/you/saved/tabula_muris/all.h5ad' --vae_path '/path/where/you/saved/VAE/model.pt'   \
    --save_dir '/dir/where/to/save/the/trained/diffusion/model/' --model_name 'name_you_want_to_give' --lr_anneal_steps 80000
echo "Training diffusion backbone done"
```

## You have trained the classifier 

```bash
echo "Training classifier"
CUDA_VISIBLE_DEVICES=0 python /path/to/classifier_train.py --data_dir '/path/where/you/saved/tabula_muris/all.h5ad' --vae_path '/path/where/you/saved/the/VAE/model.pt' --model_path '/path/where/you/want/to/save/the/classifier_model' --iterations 400_000 
echo "Training classifier, done"
```

## You have generated the conditional latent samples

```bash
# Conditional sampling to get the .npz
python /path/to/classifier_sample.py --classifier_path "/path/to/classifier/model.pt" --model_path "/path/to/diffusion/model.pt" --sample_dir "/dir/where/to/save/the/generated/latent/spaces" --num_gene 18996
```

## Then run the cells below

In [1]:
### CHANGE ACCORDING TO YOUR FILE SYSTEM ###

# Tabula Muris AnnData file (the same file passed as --data_dir to the training commands above).
path_to_anndata = "/path/where/you/saved/tabula_muris/all.h5ad"

# Trained VAE checkpoint; it is used to decode the generated latent samples.
path_to_saved_VAE_model = "/path/where/you/saved/the/VAE/model.pt"

# Directory expected to contain one tissue_<Tissue>.npz file per tissue (the --sample_dir of classifier_sample.py).
path_to_conditional_sample = "/dir/where/you/saved/the/generated/latent/spaces"

In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import anndata as ad
import scanpy as sc
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import random
from scipy import stats
import torch
import sys
sys.path.append('/workspace/projects/001_scDiffusion/scripts/scDiffusion') ### CHANGE ACCORDING TO YOUR SYSTEM
from VAE.VAE_model import VAE

In [None]:
def load_VAE(vae_path=None, num_genes=18996, hidden_dim=128, device='cpu'):
    """Build the scDiffusion VAE and load trained weights into it.

    All arguments are optional. The defaults reproduce the original hard-coded
    configuration, so ``load_VAE()`` behaves exactly as before.

    Parameters
    ----------
    vae_path : str or None
        Path to the saved ``state_dict`` checkpoint. When ``None``, the
        notebook-level ``path_to_saved_VAE_model`` is read at call time.
    num_genes : int
        Number of genes passed to ``VAE``. It must match ``--num_genes`` used
        when training the VAE (18996 in the commands above).
    hidden_dim : int
        ``hidden_dim`` passed to ``VAE``. It must match the trained checkpoint.
    device : str
        Used both as the ``VAE`` device and as ``map_location`` when loading
        the checkpoint.

    Returns
    -------
    VAE
        The autoencoder with the checkpoint weights loaded.
    """
    if vae_path is None:
        vae_path = path_to_saved_VAE_model
    autoencoder = VAE(
        num_genes=num_genes,
        device=device,
        seed=0,
        loss_ae='mse',
        hidden_dim=hidden_dim,
        decoder_activation='ReLU',
    )
    # torch.load unpickles the file: only load checkpoints you trust.
    state_dict = torch.load(vae_path, map_location=torch.device(device))
    autoencoder.load_state_dict(state_dict)
    return autoencoder

In [None]:
 # Load generated data
gen_data = []
gen_data_cell_type = []
tissues = ['Bladder', 'Heart_and_Aorta', 'Kidney', 'Limb_Muscle', 'Liver',
       'Lung', 'Mammary_Gland', 'Marrow', 'Spleen', 'Thymus', 'Tongue',
       'Trachea']
for tissue in tissues:
	npzfile=np.load(f'{path_to_conditional_sample}/tissue_{tissue}.npz',allow_pickle=True)
	number_of_cells = adata[adata.obs['celltype']==tissue].X.shape[0]
	gen_data.append(npzfile['cell_gen'][:number_of_cells])#.squeeze(1)
	gen_data_cell_type+=['gen '+tissue]*number_of_cells
	gen_data = np.concatenate(gen_data,axis=0)

autoencoder = load_VAE()
gen_data = autoencoder(torch.tensor(gen_data),return_decoded=True).cpu().detach().numpy()

In [None]:
# Concatenate real and generated cells, keep highly variable genes and compute a joint UMAP.
# Cast explicitly rather than passing dtype= to AnnData, because that argument is deprecated
# in recent anndata releases. copy=False avoids duplicating an already-float32 matrix.
combined_X = np.concatenate((real_data, gen_data), axis=0).astype(np.float32, copy=False)
adata = ad.AnnData(combined_X)
adata.obs['cell_type'] = np.concatenate((real_data_cell_type, gen_data_cell_type))
adata.obs['data_type'] = ['real'] * real_data.shape[0] + ['generated'] * gen_data.shape[0]

# Select HVGs on the full gene set and keep the unfiltered matrix in .raw.
sc.pp.highly_variable_genes(adata, min_mean=0.0125, max_mean=3, min_disp=0.5)
adata.raw = adata
# .copy() turns the view into a real object, so sc.pp.scale does not need an implicit copy.
adata = adata[:, adata.var.highly_variable].copy()
sc.pp.scale(adata)
sc.tl.pca(adata, svd_solver='arpack')
sc.pp.neighbors(adata, n_neighbors=10, n_pcs=20)
sc.tl.umap(adata)

In [None]:
# One UMAP per tissue. Real cells of that tissue are orange, the matching generated
# cells are blue, and every other category is black.
for tissue in tissues:
    real_label, gen_label = tissue, 'gen ' + tissue
    highlight = {real_label: 'tab:orange', gen_label: 'tab:blue'}
    color_dict = {
        category: highlight.get(category, 'black')
        for category in np.unique(adata.obs['cell_type'].values)
    }
    sc.pl.umap(
        adata=adata,
        color="cell_type",
        groups=[real_label, gen_label],
        size=8,
        palette=color_dict,
        show=True,
        title=tissue,
    )