# Training scDGD
## An example using the 10x mouse brain 5k dataset



### Imports and installation

In [1]:
import torch
import numpy as np
#!pip install anndata
import anndata as ad
#!pip install scanpy
import scanpy as sc

In [None]:
# install scDGD from the repository
!pip install git+https://github.com/Center-for-Health-Data-Science/scDGD
from scDGD.classes import GaussianMixture
from scDGD.models import DGD
from scDGD.functions import prepate_data, dgd_train

### Hyperparameters

In [None]:
# --- Hyperparameters (adjust as needed) ------------------------------------
# latent_dim: size of the latent space (passed to both the mixture model and
#             DGD below); n_epochs: number of epochs given to dgd_train.
latent_dim, n_epochs = 20, 500

# Column of adata.obs holding the feature to observe in clustering
# (e.g. cell type or disease state).
adata_label_column_name = 'cell_type'

## Prepare the data

In [3]:
# load the example anndata file
data_path = './data/'
!wget -P data https://zenodo.org/record/7993711/files/adata.h5ad
adata = ad.read(data_path+'adata.h5ad')

In [None]:
# Hand the AnnData object to scDGD's data preparation. It returns the
# (possibly updated) AnnData plus train / validation / test data loaders.
# `label_column` names the obs column whose labels are read back later via
# get_labels().
adata, trainloader, validationloader, testloader = prepate_data(
    adata, label_column=adata_label_column_name)

In [None]:
# The number of distinct labels in the training set serves as the initial
# guess for the number of mixture components (clusters).
labels = trainloader.dataset.get_labels()
unique_labels = np.unique(labels)
n_celltypes = len(unique_labels)

## Set up the model

In [None]:
# Gaussian mixture over the latent space: one component per observed label,
# each of dimension latent_dim.
gmm = GaussianMixture(Nmix=n_celltypes, dim=latent_dim)
# DGD model whose output size equals the number of genes in the training data.
n_genes = trainloader.dataset.n_genes
model = DGD(out=n_genes, latent=latent_dim)

## Train

In [None]:
# When running in the notebook, training performance can be analyzed from the
# returned `history` dataframe.
model, rep, test_rep, gmm, history = dgd_train(
    model, gmm, trainloader, validationloader, n_epochs=n_epochs,
    export_dir='./', export_name='scDGD'
)

# Training can also be logged to Weights & Biases (https://wandb.ai/) for
# better monitoring and project organization. To do so, run the following
# instead of the call above. It is kept as comments rather than a bare string,
# so that Jupyter does not echo it as this cell's output.
#
# import wandb
# wandb.init(project="project_name", entity="your_userID", name="model_name")
#
# model, rep, test_rep, gmm, history = dgd_train(
#     model, gmm, trainloader, validationloader, n_epochs=n_epochs,
#     export_dir='./', export_name='scDGD',
#     wandb_logging=True
# )

In [None]:
# Plot the per-epoch reconstruction loss of the training and validation sets.
import matplotlib.pyplot as plt

epochs = history['epoch']
for loss_key, curve_label in (('train_recon_loss', 'train'),
                              ('test_recon_loss', 'validation')):
    plt.plot(epochs, history[loss_key], label=curve_label)
plt.xlabel('epoch')
plt.ylabel('reconstruction loss')
plt.legend()
plt.show()

## Downstream use

The learned latent representation can be added to the AnnData object, after which you can continue as usual with scanpy tools such as UMAP visualization.

In [None]:
# Attach the learned training-set representation to an AnnData object, so the
# usual scanpy visualization and analysis tools can be applied to it.
#
# Slice first, then copy. `adata.copy()[mask]` duplicated the whole dataset and
# returned a *view*, so the `.obsm` assignment below then triggered anndata's
# "initializing view as actual" warning and made another copy.
train_mask = adata.obs['train_val_test'] == 'train'
adata_train = adata[train_mask].copy()
# NOTE(review): assumes the rows of rep.z follow the same order as the 'train'
# cells in adata (as set up by prepate_data) — confirm.
adata_train.obsm['Latent'] = rep.z.detach().cpu().numpy()

In [None]:
# Build the neighbor graph from the scDGD latent space stored in .obsm['Latent'],
# compute a UMAP embedding, and color cells by the chosen label column.
sc.pp.neighbors(adata=adata_train, use_rep='Latent')
sc.tl.umap(adata=adata_train)
sc.pl.umap(adata_train, color=adata_label_column_name)