# Reproducing the gradient interpolation results on WOT

<div style="margin-top: 20px; margin-bottom: 20px; text-align: center;">
    <img src="images/gradient.png" alt="Image description" style="max-width: 50%; height: auto;">
</div>

## Before running the notebook, ensure you have run the following commands

## You have trained the VAE

```bash
echo "Training Autoencoder, this might take a long time"
CUDA_VISIBLE_DEVICES=0 python /path/to/VAE_train.py --data_dir '/path/where/you/saved/data/WOT/filted_data.h5ad' --num_genes 19423 --state_dict '/path/to/pretrained/VAE/annotation_model_v1' --save_dir '/dir/where/to/save/the/VAE/model' --max_steps 50000 --max_minutes 300
echo "Training Autoencoder done"
```

Before running this command, make the following change at the top of **VAE_train.py** so that the WOT data loader is used:

```python
# from guided_diffusion.cell_datasets import load_data
# from guided_diffusion.cell_datasets_sapiens import load_data
from guided_diffusion.cell_datasets_WOT import load_data
# from guided_diffusion.cell_datasets_muris import load_data
```

Note that in the data-loading function the original dataset is filtered to take out data points D3.5 and D4 for training the VAE. The VAE converges very fast: 50'000 steps are more than enough.

## You have trained the diffusion model

```bash
echo "Training diffusion backbone"
CUDA_VISIBLE_DEVICES=0 python /path/to/cell_train.py --data_dir '/path/where/you/saved/data/WOT/filted_data.h5ad'  --vae_path '/path/where/you/saved/VAE/model.pt' \
 --save_dir '/dir/where/to/save/the/diffusion/model' --model_name 'choose_any_name' --lr_anneal_steps 80000
echo "Training diffusion backbone done" 
```

Here we changed the import at the top of **cell_train.py** as follows:

```python
# from guided_diffusion.cell_datasets import load_data
from guided_diffusion.cell_datasets_WOT import load_data
# from guided_diffusion.cell_datasets_sapiens import load_data
# from guided_diffusion.cell_datasets_muris import load_data
```

The model converges relatively fast; anything above 50000 steps is more than sufficient.

## You have trained the classifier 

```bash
echo "Training classifier"
CUDA_VISIBLE_DEVICES=0 python /path/to/classifier_train.py --data_dir '/path/where/you/saved/data/WOT/filted_data.h5ad' --model_path '/dir/where/to/save/the/classifier/model'  \
  --iterations 100000 --vae_path '/path/where/you/saved/VAE/model.pt' 
echo "Training classifier done" 
```

Here we changed the import at the top of **classifier_train.py** as follows:

```python
# from guided_diffusion.cell_datasets import load_data
from guided_diffusion.cell_datasets_WOT import load_data
# from guided_diffusion.cell_datasets_sapiens import load_data
# from guided_diffusion.cell_datasets_muris import load_data
```

Also ensure that the default number of classes is set to 15 on line 235:

```python
# In classifier_train.py line 235 ensure:
defaults['num_class']= 15
```

## Generating latent spaces

In classifier_sample.py, we need to change the following lines:

```python
#line 319 of classifier_sample.py
defaults['num_class'] = 15
```

And change the main function as follows:

```python
if __name__ == "__main__":
    # for Gradient Interpolation, run
    parser = create_argparser()
    args = parser.parse_args()
    save_dir = args.sample_dir
    for i in range(0,11):
        path_to_save = save_dir + f"{i}"
        to_save = main(cell_type=[6,7], inter=True, weight=[10-i,i])
        save_data(to_save, i, path_to_save)
```

We can then run the following command (adapt the paths to your file system):

```bash
# Conditional sampling 
python /workspace/projects/001_scDiffusion/scripts/scDiffusion/classifier_sample.py --model_path "/workspace/projects/001_scDiffusion/results/waddington/diffusion_model/WOT_diffusion_model/model080000.pt" --classifier_path "/workspace/projects/001_scDiffusion/results/waddington/classifier_model/model099999.pt" --sample_dir "/workspace/projects/001_scDiffusion/results/waddington/generated_latent_spaces/" --ae_dir '/workspace/projects/001_scDiffusion/results/waddington/VAE_model/model_seed=0_step=49999.pt' --init_cell_path '/workspace/projects/001_scDiffusion/data/data_in/data/WOT/filted_data.h5ad' --num_gene 19423 
```

## You have downloaded GSM3195672_D6_Dox_C1_gene_bc_mat.h5

Go to: https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc=GSM3195672. At the bottom of the page, download the file listed under "Supplementary file".

## Then you can run the following cells to produce the UMAP

In [None]:
# Restrict CUDA to GPU 0 for this kernel.
# The variable must be spelt CUDA_VISIBLE_DEVICES (plural) and has to be set in
# the kernel's own environment: a `!VAR=value` shell line only affects a
# short-lived subshell and is lost immediately.
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [None]:
### CHANGE IN FUNCTION OF YOUR FILE SYSTEM ###
# Checkpoint of the trained VAE; load_VAE() reads it with torch.load.
vae_path = '/path/where/you/saved/the/VAE/model.pt'
# Location of the WOT .h5ad file used as the real-cell reference dataset.
path_to_wot_data = '/path/where/you/saved/WOT/filted_data.h5ad'


In [None]:
%config InlineBackend.figure_format = 'retina'
%matplotlib inline
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import anndata as ad
import scanpy as sc
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
import random
from scipy import stats
import torch
import sys
sys.path.append('/workspace/projects/001_scDiffusion/scripts/scDiffusion/') ### CHANGE ACCORDING TO YOUR FILE SYSTEM
from VAE.VAE_model import VAE
import celltypist

In [None]:
def load_VAE():
    """Build the VAE with the WOT configuration and restore its trained weights.

    The checkpoint is read from the notebook-level ``vae_path`` and mapped onto
    the CPU, so no GPU is required to load it.

    Returns:
        The ``VAE`` instance after ``load_state_dict`` has been applied.
    """
    vae_config = dict(
        num_genes=19423,
        device='cpu',
        seed=0,
        loss_ae='mse',
        hidden_dim=128,
        decoder_activation='ReLU',
    )
    model = VAE(**vae_config)

    cpu = torch.device('cpu')
    weights = torch.load(vae_path, map_location=cpu)
    model.load_state_dict(weights)
    return model

In [None]:
# Load the real WOT dataset from the path configured in `path_to_wot_data`
# (the variable, not a string literal with the same spelling).
real_adata = sc.read_h5ad(path_to_wot_data)

# Keep only cells whose `period` is one of the half-day time points D0 .. D8.
periods_to_keep = ["D0","D0.5","D1","D1.5","D2","D2.5","D3","D3.5","D4","D4.5","D5","D5.5","D6","D6.5","D7","D7.5","D8"]
real_adata = real_adata[real_adata.obs.period.isin(periods_to_keep)]

# Subsample every 5th cell and materialise the result: the steps below modify
# the object in place, which should not be done on a view of another AnnData.
real_adata = real_adata[::5].copy()

# 10x file used only to look up gene names. ### CHANGE ACCORDING TO YOUR FILE SYSTEM
tmp_ = sc.read_10x_h5('/workspace/projects/001_scDiffusion/data/data_in/data/WOT/GSM3195672_D6_Dox_C1_gene_bc_mat.h5')
real_adata.var_names_make_unique()
# The WOT var_names are cast to int32 and used as positional indices into the
# 10x file's var_names to recover the gene names.
gene_names = tmp_.var_names[real_adata.var_names.astype(np.int32)]

# Library-size normalisation to 1e4 counts per cell followed by log1p.
sc.pp.normalize_total(real_adata, target_sum=1e4)
sc.pp.log1p(real_adata)
real_adata_period = real_adata.obs['period']

In [None]:
# Load the generated latent cells for every Gradient Interpolation step
# (files 0.npz .. 10.npz) and decode them to gene space with the VAE.
length = 500  # maximum number of generated cells kept per interpolation step
gen_data = []
cell_stage = []
cell_dis = []

for i in range(11):
    # NOTE: allow_pickle=True lets np.load unpickle arbitrary objects; only
    # load files you generated yourself.
    # The context manager closes the underlying .npz file handle.
    with np.load(f'/workspace/projects/001_scDiffusion/results/waddington/generated_latent_spaces/{i}.npz',allow_pickle=True) as npzfile:
        cells = npzfile['cell_gen'][:length]
    gen_data.append(cells)

    # Label exactly as many cells as were actually loaded, so labels and data
    # stay aligned even if a file holds fewer than `length` cells.
    n_cells = cells.shape[0]
    cell_stage += [f'gen {i}'] * n_cells
    cell_dis += [float(i)] * n_cells

gen_data = np.concatenate(gen_data, axis=0)
number_of_gen_cells = gen_data.shape[0]


autoencoder = load_VAE()
# Inference only: skip building the autograd graph.
# NOTE(review): confirm whether autoencoder.eval() is needed before decoding
# (depends on whether the VAE contains dropout / batch-norm layers).
with torch.no_grad():
    gen_data = autoencoder(torch.tensor(gen_data).cpu(),return_decoded=True).cpu().detach().numpy()

gen_data.shape

In [None]:
# Stack real and generated expression matrices into a single AnnData and
# annotate every cell with its origin, period label and interpolation step.
from scipy import sparse

real_X = real_adata.X
if sparse.issparse(real_X):
    # np.concatenate needs dense arrays; densify a sparse expression matrix.
    real_X = real_X.toarray()
n_real = real_X.shape[0]
n_gen = gen_data.shape[0]

# Cast explicitly instead of using AnnData's deprecated `dtype=` argument.
adata = np.concatenate((real_X, gen_data)).astype(np.float32)
adata = ad.AnnData(adata)
adata.obs['cell_gen_type'] = ["real_cell"] * n_real + ["gen_data"] * n_gen
adata.obs['cell_period'] = pd.Categorical(list(real_adata_period.values) + cell_stage)
# Real cells have no interpolation step, so they get NaN in `cell_dis`.
celldis = np.concatenate(([np.nan] * n_real, cell_dis)).astype(np.float32)
adata.obs['cell_dis'] = celldis

In [None]:
# Restrict the embedding to highly variable genes, so that PCA is computed on
# the genes flagged by the dispersion-based selection rather than on all genes.
# The full gene set stays available through `adata.raw`.
sc.pp.highly_variable_genes(adata, min_mean=0.0125, max_mean=3, min_disp=0.5)
adata.raw = adata
# .copy() turns the column subset into a standalone object, so the in-place
# scaling below does not operate on a view.
adata = adata[:, adata.var.highly_variable].copy()

sc.pp.scale(adata)
sc.tl.pca(adata, svd_solver='arpack')

In [None]:
# Build the neighbourhood graph from the first 20 principal components
# (10 neighbours per cell), then compute the UMAP embedding from that graph.
sc.pp.neighbors(adata, n_neighbors=10, n_pcs=20)
sc.tl.umap(adata)


In [None]:
# Assign one viridis colour to every distinct cell-period label, then override
# the four time points around the interpolation window with fixed colours.
period_labels = np.unique(adata.obs['cell_period'])
n_cell_periods = len(period_labels)
cmap = plt.get_cmap('viridis', n_cell_periods)  # Replace 'viridis' with your desired colormap
colors = list(map(cmap, range(cmap.N)))
cell_type_color_map = dict(zip(period_labels, colors))
cell_type_color_map.update({
    'D3': 'tab:red',
    'D3.5': 'tab:orange',
    'D4': 'tab:blue',
    'D4.5': 'tab:green',
})

# Font type 42 embeds TrueType fonts in PDF/PS output.
plt.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})

# UMAP coloured by interpolation step; `cell_type_color_map` is not passed to
# this call.
sc.pl.umap(adata=adata,color="cell_dis", groups=list(np.unique(adata.obs['cell_dis'])), size=20, show=True)