### Imports
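We handle data with numpy and with scanpy, a library used for many computational biology datasets, particularly single-cell data. The model, its training loop, and the evaluation helpers are imported from the `markermap` package.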

In [None]:
import numpy as np
import scanpy as sc

from markermap.vae_models import MarkerMap, train_model
from markermap.utils import (
    new_model_metrics,
    plot_confusion_matrix,
    split_data,
)

### Set Parameters
Define some parameters that we will use when creating the MarkerMap:
* `z_size` is the dimension of the latent space in the variational auto-encoder. We always use 16.
* `hidden_layer_size` is the dimension of the hidden layers that come before and after the latent-space layer in the auto-encoder. It depends on the data; a good rule of thumb is about 10% of the input dimension. For the CITE-seq data, which has 500 columns, we use 64.
* `k` is the number of markers to extract.
* `batch_size` is the batch size used by the model during training.
* `file_path` should point to wherever your data is located.

In [None]:
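z_size = 16             # dimension of the VAE latent space
hidden_layer_size = 64  # roughly 10% of the 500 input columns
k = 50                  # number of markers to select
batch_size = 64

file_path = 'data/cite_seq/CITEseq.h5ad'  # change this to your own data location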
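The `hidden_layer_size` rule of thumb above can be written as a small calculation. The sketch below uses a hypothetical helper, `suggest_hidden_layer_size`, which is not part of markermap. Rounding up to a power of two and the floor of 16 (so the hidden layer is never narrower than the 16-dimensional latent space) are assumptions, chosen here because they reproduce the value of 64 used for the 500-column CITE-seq data.

```python
import math


def suggest_hidden_layer_size(n_features: int, fraction: float = 0.1, minimum: int = 16) -> int:
    """Hypothetical helper: about `fraction` of the input dimension,
    never below `minimum`, rounded up to the next power of two (assumed convention)."""
    if n_features <= 0:
        raise ValueError("n_features must be positive")
    target = max(fraction * n_features, minimum)
    # Round up to the next power of two, e.g. 50 -> 64.
    return 2 ** math.ceil(math.log2(target))


print(suggest_hidden_layer_size(500))  # 64
```

Treat the result as a starting point for tuning rather than a fixed rule.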

### Data
Once `file_path` points to your data, we read it in with scanpy, which returns an AnnData object. The gene data is stored in `adata.X`, and the cell labels are stored in `adata.obs['names']`. For consistency across different datasets, we copy the labels into `adata.obs['annotation']`; the name of this column is controlled by the `group_by` variable.

We then split the data into training, validation, and test sets in a 70%/10%/20% ratio. `MarkerMap.prepareData` builds the `train_dataloader` and `val_dataloader` used during training. The training and validation indices are also concatenated into `train_val_indices`, which we reuse when evaluating the model.

In [None]:
# get data
group_by = 'annotation'
adata = sc.read_h5ad(file_path)
adata.obs[group_by] = adata.obs['names']  # copy the labels into the group_by column

# we will use 70% training data and 10% validation data while training the MarkerMap, then 20% for testing
train_indices, val_indices, test_indices = split_data(
    adata.X,
    adata.obs[group_by],
    [0.7, 0.1, 0.2],
)
train_val_indices = np.concatenate([train_indices, val_indices])

train_dataloader, val_dataloader = MarkerMap.prepareData(
    adata,
    train_indices,
    val_indices,
    group_by,
    None,  # layer: None means use adata.X
    batch_size=batch_size,
)
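The notebook does not show how `split_data` assigns cells to each set. As a minimal sketch, assuming a split stratified by label (so each set keeps roughly the same cell-type proportions), a 70/10/20 split of indices could look like this. `stratified_split` is a hypothetical helper, not markermap's implementation, and the toy labels are made up for illustration.

```python
import numpy as np


def stratified_split(labels, fractions=(0.7, 0.1, 0.2), seed=0):
    """Hypothetical stand-in for split_data: split indices label by label
    so each subset keeps roughly the same class proportions."""
    labels = np.asarray(labels)
    fractions = np.asarray(fractions, dtype=float)
    if not np.isclose(fractions.sum(), 1.0):
        raise ValueError("fractions must sum to 1")
    rng = np.random.default_rng(seed)
    splits = [[] for _ in fractions]
    for label in np.unique(labels):
        # Shuffle the cells of this class, then cut at the cumulative fractions.
        idx = rng.permutation(np.flatnonzero(labels == label))
        cuts = np.round(np.cumsum(fractions)[:-1] * len(idx)).astype(int)
        for part, chunk in zip(splits, np.split(idx, cuts)):
            part.extend(chunk.tolist())
    return [np.sort(np.array(part, dtype=int)) for part in splits]


# Toy labels: 50 "T", 30 "B", and 20 "NK" cells.
labels = np.array(["T"] * 50 + ["B"] * 30 + ["NK"] * 20)
train_idx, val_idx, test_idx = stratified_split(labels)
print(len(train_idx), len(val_idx), len(test_idx))  # 70 10 20
```

Stratifying matters most for rare cell types, which could otherwise end up missing from the validation or test set.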

### Define and Train the Model
Now it is time to define the MarkerMap. There are many hyperparameters that can be tuned here, but the most important are `k` and `loss_tradeoff`. Choosing `k` may require some domain knowledge, but it is fairly easy to benchmark several values of `k`, as we will see in the later examples. `loss_tradeoff` is also important; see the paper for further discussion. In general, we use three levels: 0 (supervised only), 0.5 (mixed supervised-unsupervised), and 1 (unsupervised only). Here we set `loss_tradeoff=0`, so the model is trained in the supervised setting. This step may take a couple of minutes.

In [None]:
supervised_marker_map = MarkerMap(
    adata.X.shape[1],                   # input dimension (number of columns in adata.X)
    hidden_layer_size,
    z_size,
    len(adata.obs[group_by].unique()),  # number of distinct labels
    k,
    loss_tradeoff=0,                    # 0 = supervised only
)
train_model(supervised_marker_map, train_dataloader, val_dataloader)
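The three `loss_tradeoff` levels can be read as weights on two objectives. The sketch below assumes a convex combination of an unsupervised reconstruction loss and a supervised classification loss. This is an illustration of the weighting only, not MarkerMap's actual loss: the real objective is described in the paper and may include additional terms. `combined_loss` and the toy loss values are hypothetical.

```python
def combined_loss(reconstruction_loss, classification_loss, loss_tradeoff):
    # Assumed weighting: 0 -> classification (supervised) only,
    # 1 -> reconstruction (unsupervised) only, 0.5 -> equal mix.
    if not 0.0 <= loss_tradeoff <= 1.0:
        raise ValueError("loss_tradeoff must lie in [0, 1]")
    return loss_tradeoff * reconstruction_loss + (1.0 - loss_tradeoff) * classification_loss


# Toy loss values, chosen only to show how the weighting shifts.
for tradeoff in (0.0, 0.5, 1.0):
    print(tradeoff, combined_loss(2.0, 0.4, tradeoff))
```

Under this reading, `loss_tradeoff=0` ignores reconstruction entirely, which matches the "supervised only" level used above.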

### Evaluate the model
Finally, we test the model. The `new_model_metrics` function trains a simple classifier, such as a RandomForestClassifier, on the training and validation data restricted to the k selected markers, and then evaluates it on the test data. We then print the misclassification rate and the weighted-average F1 score, and plot a confusion matrix.

In [None]:
misclass_rate, test_rep, cm = new_model_metrics(
    adata[train_val_indices, :].X,
    adata[train_val_indices, :].obs[group_by],
    adata[test_indices, :].X,
    adata[test_indices, :].obs[group_by],
    markers=supervised_marker_map.markers().clone().cpu().detach().numpy(),
)

print('Misclassification rate:', misclass_rate)
print('Weighted F1 score:', test_rep['weighted avg']['f1-score'])
plot_confusion_matrix(cm, adata.obs[group_by].unique())
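To make the evaluation step concrete, here is a minimal sketch of a marker-restricted evaluation using numpy alone. `evaluate_on_markers` is a hypothetical helper, not markermap's `new_model_metrics`, and it swaps the random forest for a nearest-centroid classifier so that it has no extra dependencies. It keeps only the marker columns, fits on the training cells, and reports the misclassification rate, the support-weighted F1 score, and a confusion matrix (rows are true labels, columns are predictions). The toy data are synthetic.

```python
import numpy as np


def evaluate_on_markers(X_train, y_train, X_test, y_test, markers):
    """Hypothetical sketch: nearest-centroid classifier on the marker columns only."""
    X_train = np.asarray(X_train, dtype=float)[:, markers]
    X_test = np.asarray(X_test, dtype=float)[:, markers]
    y_train, y_test = np.asarray(y_train), np.asarray(y_test)
    classes = np.unique(np.concatenate([y_train, y_test]))

    # Fit: one centroid per class seen in training; predict the closest centroid.
    train_classes = np.unique(y_train)
    centroids = np.stack([X_train[y_train == c].mean(axis=0) for c in train_classes])
    dists = ((X_test[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
    y_pred = train_classes[dists.argmin(axis=1)]

    misclass_rate = float(np.mean(y_pred != y_test))

    # Confusion matrix: rows are true labels, columns are predicted labels.
    index = {c: i for i, c in enumerate(classes)}
    cm = np.zeros((len(classes), len(classes)), dtype=int)
    for t, p in zip(y_test, y_pred):
        cm[index[t], index[p]] += 1

    # Per-class F1 (0 where undefined), averaged with weights equal to class support.
    tp = np.diag(cm).astype(float)
    col, row = cm.sum(axis=0), cm.sum(axis=1)
    precision = np.divide(tp, col, out=np.zeros_like(tp), where=col > 0)
    recall = np.divide(tp, row, out=np.zeros_like(tp), where=row > 0)
    denom = precision + recall
    f1 = np.divide(2 * precision * recall, denom, out=np.zeros_like(tp), where=denom > 0)
    weighted_f1 = float((f1 * row).sum() / row.sum())
    return misclass_rate, weighted_f1, cm, classes


# Synthetic data: 40 cells, 5 features; only column 2 separates classes "A" and "B".
rng = np.random.default_rng(0)
y = np.repeat(np.array(["A", "B"]), 20)
X = rng.normal(scale=0.1, size=(40, 5))
X[:, 2] += np.where(y == "A", 0.0, 10.0)
train, test = np.arange(0, 40, 2), np.arange(1, 40, 2)
rate, f1, cm, classes = evaluate_on_markers(X[train], y[train], X[test], y[test], markers=[2])
print(rate, f1)  # 0.0 1.0
```

Because the toy data put all of the class signal in column 2, choosing that column as the only marker is enough for a perfect score; choosing an uninformative column instead would show how poor markers degrade these metrics.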