# Getting started with Starling (ST)


In [None]:
import os
import json
import argparse

import anndata as ad
import pandas as pd
import numpy as np
import scanpy as sc

import pytorch_lightning as pl
from lightning_lite import seed_everything
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
from sklearn.metrics import adjusted_rand_score
from sklearn.preprocessing import MinMaxScaler

from starling import starling, utility, label_mapper

In [None]:
# Identifiers of this run; both are interpolated into EXPERIMENT_DIR below.
INTEGRATION_METHOD = 'exprs'
DATASET = 'IMMUcan_2022_CancerExample'

# Ground-truth labels treated as "no label"; cells carrying one of these are
# filtered out wherever an ARI against `cell_type` is computed.
UNLABELED_CELL_TYPES = [
    'unlabeled',
    'undefined',
    'unknown',
    'BnTcell',
    "BnT cell",
]

# Columns of `obs` selected for the exported results table.
COLUMNS_OF_INTEREST = [
    'sample_id',
    'object_id',
    'cell_type',
    'init_label',
    'st_label',
    'doublet',
    'doublet_prob',
    'max_assign_prob',
    'st_prob_list',
]

EXPERIMENT_DIR = f"/home/dani/Documents/Thesis/Methods/IMCBenchmark/output/{DATASET}/starling/{INTEGRATION_METHOD}"
CONFIG_PATH = os.path.join(EXPERIMENT_DIR, 'config.json')

# load the params
with open(CONFIG_PATH) as config_file:
    config = json.loads(config_file.read())

In [None]:
parser = argparse.ArgumentParser(description='starling')

# Parse an explicit empty argument list (instead of sys.argv) to obtain an
# empty namespace, then copy the run parameters over from the loaded config.
args = parser.parse_args(args=[])
for param_name in (
    'dataset',
    'init_clustering_method',
    'error_free_cells_prop',
    'epochs',
    'lr',
    'num_classes',
    'seed',
):
    setattr(args, param_name, config[param_name])

In [None]:
# Display the namespace to double-check the values copied from the config.
args

## Setting seed for everything


In [None]:
# Seed the random number generators with the configured seed so runs are repeatable.
# NOTE(review): `workers=True` presumably also seeds dataloader workers — confirm in the Lightning docs.
seed_everything(args.seed, workers=True)

## Load data


In [None]:
# Read the AnnData object from the .h5ad path given in the config.
adata = ad.read_h5ad(args.dataset)

# Show the per-cell annotation table.
adata.obs

### Scale expression data

In [None]:
# Min-max scale the expression matrix into the [0, 1] range and write the
# scaled values back over the raw ones. The fitted scaler is kept around.
scaler = MinMaxScaler(feature_range=(0, 1)).fit(adata.X)
X_scaled_df = scaler.transform(adata.X)
adata.X = X_scaled_df

### Annotate initial clustering with the configured initial clustering method

In [None]:
print(f'Initial cluster annotation using `{args.init_clustering_method}` algorithm.')

# Use the user-provided initial labels when the column exists; the number of
# clusters is then the number of distinct labels. If the column is missing,
# `np.array(None)` would silently count as a single class (k = 1), so fall
# back to the number of classes given in the config instead.
user_init_labels = adata.obs.get('user_init_label')
if user_init_labels is None:
    labels = None
    num_classes = args.num_classes
else:
    labels = np.array(user_init_labels)
    num_classes = len(np.unique(labels))

adata = utility.init_clustering(args.init_clustering_method, adata,
                                k=num_classes,
                                labels=labels)

# Sanity checks on what the initial clustering is expected to attach:
# one centroid and one variance column per cluster for every feature ...
assert "init_exp_centroids" in adata.varm
assert adata.varm["init_exp_centroids"].shape == (adata.X.shape[1], num_classes)

# (the key checked here used to be "init_exp_centroids" a second time)
assert "init_exp_variances" in adata.varm
assert adata.varm["init_exp_variances"].shape == (adata.X.shape[1], num_classes)

# ... and one initial label per cell.
assert "init_label" in adata.obs
assert adata.obs["init_label"].shape == (adata.X.shape[0],)

# Agreement between the initial labels and the ground truth, restricted to
# cells whose ground-truth cell type is actually labeled.
labeled_obs = adata.obs[~adata.obs['cell_type'].isin(UNLABELED_CELL_TYPES)]
print("Init ARI:", adjusted_rand_score(labeled_obs['cell_type'], labeled_obs['init_label']))

## Setting initializations


In [None]:
# Build the STARLING model on the annotated data. Only the learning rate and
# the expected proportion of segmentation-error-free cells are taken from the
# config; all other constructor arguments are left at their defaults.
st = starling.ST(adata, learning_rate=args.lr, singlet_prop=args.error_free_cells_prop)


The constructor parameters are listed below:

- adata: AnnData object of the sample
- dist_option (default: 'T'): T for Student-T (df=2) and N for Normal (Gaussian)
- singlet_prop (default: 0.6): the anticipated proportion of segmentation-error-free cells
- model_cell_size (default: 'Y'): Y for incorporating cell size in the model and N otherwise
- cell_size_col_name (default: 'area'): area is the column name in anndata.obs dataframe
- model_zplane_overlap (default: 'Y'): Y for modeling z-plane overlap when cell size is modelled and N otherwise
  Note: if the user sets model_cell_size = 'N', then model_zplane_overlap is ignored
- model_regularizer (default: 1): Regularizer term imposed on the synthetic doublet loss (BCE)
- learning_rate (default: 1e-3): The learning rate of the ADAM optimizer for STARLING

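With every default written out explicitly, the constructor call is (the cell above overrides `learning_rate` and `singlet_prop` with the values from the config):
st = starling.ST(adata, dist_option='T', singlet_prop=0.6, model_cell_size='Y', cell_size_col_name='area', model_zplane_overlap='Y', model_regularizer=1, learning_rate=1e-3)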


## Setting training log


Once training starts, a new directory 'log' will be created.


In [None]:
## log training results via tensorboard
# The logger is configured to save under the relative directory "log".
log_tb = TensorBoardLogger(save_dir="log")


One could view the training information via tensorboard. Please refer to torch lightning (https://lightning.ai/docs/pytorch/stable/api_references.html#profiler) for other possible loggers.


## Setting early stopping criterion


In [None]:
## set early stopping criterion
# Monitor the metric named "train_loss" and treat lower as better (mode="min").
# Patience and the improvement threshold are left at the library defaults.
# NOTE(review): assumes the model logs a metric called "train_loss" — confirm in starling.ST.
cb_early_stopping = EarlyStopping(monitor="train_loss", mode="min", verbose=False)


Training loss is monitored.


## Training Starling


In [None]:
## train ST
# The epoch budget comes from the config and is stored on `args` as `epochs`;
# `args` has no `max_epochs` attribute, so reading it would raise AttributeError.
trainer = pl.Trainer(
    max_epochs=args.epochs,
    accelerator="auto",
    devices="auto",
    deterministic=True,
    # Training may end before the epoch budget if early stopping triggers.
    callbacks=[cb_early_stopping],
    logger=[log_tb],
)
trainer.fit(st)


## Appending STARLING results to annData object


In [None]:
## retrieve starling results
# NOTE(review): presumably this writes the fitted parameters and per-cell outputs
# into st.adata (the keys read by later cells) — confirm in starling.ST.result.
st.result()


## The following information can be retrieved from the AnnData object:

- st.adata.varm['init_exp_centroids'] -- initial expression cluster centroids (P x C matrix)
- st.adata.varm['st_exp_centroids'] -- ST expression cluster centroids (P x C matrix)
- st.adata.uns['init_cell_size_centroids'] -- initial cell size centroids if STARLING models cell size
- st.adata.uns['st_cell_size_centroids'] -- initial & ST cell size centroids if ST models cell size
- st.adata.obsm['assignment_prob_matrix'] -- cell assignment probability (N x C matrix)
- st.adata.obsm['gamma_prob_matrix'] -- gamma probability of two cells (N x C x C matrix)
- st.adata.obs['doublet'] -- doublet indicator
- st.adata.obs['doublet_prob'] -- doublet probabilities
- st.adata.obs['init_label'] -- initial assignments
- st.adata.obs['st_label'] -- ST assignments
- st.adata.obs['max_assign_prob'] -- ST max probabilities of assignments
  - N: # of cells; C: # of clusters; P: # of proteins


## Showing STARLING results


In [None]:
# Display the per-cell table of the AnnData object held by the model.
st.adata.obs


One could easily perform further analyses such as co-occurrence or enrichment analysis.


In [None]:
# Build the label mapper from the ground-truth cell types and the initial labels.
mapper = label_mapper.AutomatedLabelMapping(
    st.adata.obs['cell_type'],
    st.adata.obs['init_label'],
)

# Map the initial labels first, then the ST labels, to cell types (in place).
for label_col in ('init_label', 'st_label'):
    st.adata.obs[label_col] = mapper.get_pred_labels(st.adata.obs[label_col])

# Keep only the cells whose ground-truth cell type is actually labeled.
# NOTE(review): the subset is stored as an attribute on the model object
# (`st.labeled_obs`) and is not read again in this notebook — confirm this is intended.
is_unlabeled = st.adata.obs['cell_type'].isin(UNLABELED_CELL_TYPES)
st.labeled_obs = st.adata.obs[~is_unlabeled]

In [None]:
# Display the per-cell table again, now with the label columns mapped to cell types.
st.adata.obs


For each cell, Starling provides a doublet probability and the cluster assignment the cell would have if it were a singlet.


## Showing initial expression centroids:


In [None]:
## initial expression centroids (p x c) matrix
# Rows are indexed by the variable names of the AnnData object.
pd.DataFrame(st.adata.varm["init_exp_centroids"], index=st.adata.var_names)


There is one centroid (column) per initial cluster, i.e. `num_classes` columns as determined in the initial-clustering cell.


## Showing Starling expression centroids:


In [None]:
## starling expression centroids (p x c) matrix
# Rows are indexed by the variable names of the AnnData object.
pd.DataFrame(st.adata.varm["st_exp_centroids"], index=st.adata.var_names)


From here one could easily annotate the cluster centroids with cell types.


## Showing Assignment Distributions:


In [None]:
## assignment distributions (n x c matrix)
# Rows are indexed by the cell index of the per-cell table.
pd.DataFrame(st.adata.obsm["assignment_prob_matrix"], index=st.adata.obs.index)


Currently, each cell is labeled with the cluster that has the maximum assignment probability. However, a cell may be mislabeled when its highest and second-highest probabilities are very close, so users may want to inspect such cells.


## Analyzing the results

In [None]:
# Read the results from the AnnData object held by the model (`st.adata`), as
# every other post-training cell does: that is where the assignment matrix is
# documented to live and where the label columns were mapped to cell types.

# Serialize each cell's assignment distribution as a "[p1, p2, ...]" string so
# it fits into a single CSV column.
prob_matrix = st.adata.obsm['assignment_prob_matrix']
prob_vector = np.array([f"[{', '.join(map(str, row))}]" for row in prob_matrix])
st.adata.obs['st_prob_list'] = prob_vector

# Select the exported columns and rename them for the results file.
results_df = st.adata.obs[COLUMNS_OF_INTEREST]
results_df = results_df.rename(columns={
    'sample_id': 'image_id',
    'object_id': 'cell_id',
    'cell_type': 'label',
    'init_label': 'init_pred',
    'st_label': 'st_pred',
    'max_assign_prob': 'st_pred_prob'
})

# The obs index is not written to the CSV (index=False).
results_df.to_csv(os.path.join(EXPERIMENT_DIR, 'starling_results.csv'), index=False)

results_df

Calculate ARI score compared to ground truth labels

In [None]:
# Score only the cells whose ground-truth label is actually informative.
has_label = ~results_df['label'].isin(UNLABELED_CELL_TYPES)
labeled_results_df = results_df.loc[has_label]

# ARI of the initial clustering first, then of the Starling assignments.
for title, pred_col in (("Init ARI:", 'init_pred'), ("Starling ARI:", 'st_pred')):
    print(title, adjusted_rand_score(labeled_results_df['label'], labeled_results_df[pred_col]))

Let us draw a UMAP plot coloured by cell type. (This may take a while because it has to run UMAP first).

In [None]:
# Plot from the AnnData object held by the model (`st.adata`): that is where
# 'st_label' is documented to be stored and where both label columns were
# mapped to cell types, so all three colourings refer to the same table.
sc.pp.neighbors(st.adata)
sc.tl.umap(st.adata)
sc.pl.umap(st.adata, color=['cell_type', 'init_label', 'st_label'], size=14, ncols=3, wspace=0.3)