# CellAssign Cell Type Assignment

Load the data

In [None]:
import os


def _require_env(name: str) -> str:
    """Return the value of the environment variable ``name``.

    Raises:
        ValueError: if the variable is unset or empty. Failing here gives a
            clear message instead of the ``TypeError`` that
            ``os.path.exists(None)`` would raise further down.
    """
    value = os.getenv(name)
    if not value:
        raise ValueError(f"Required environment variable is not set: {name}")
    return value


# Inputs: marker matrix (CSV) and the AnnData file to annotate.
marker_list = _require_env("SNAKEMAKE_MARKER_LIST")
if not os.path.exists(marker_list):
    raise FileNotFoundError(f"Marker list file does not exist: {marker_list}")
data = _require_env("SNAKEMAKE_H5AD_INPUT")
if not os.path.exists(data):
    raise FileNotFoundError(f"Input data file does not exist: {data}")
# Output csv mapping barcodes to cell type data. Validated up front so that a
# missing value is reported before the model is trained rather than at the
# final save step.
output_file = _require_env("SNAKEMAKE_OUTPUT_FILE")
n_gpus = int(os.getenv("SNAKEMAKE_NUM_GPUS", "0"))  # Number of GPUs to use, default is 0 (CPU only)

print(f"Marker list file: {marker_list}")
print(f"Input data file: {data}")
print(f"Output file: {output_file}")
print(f"Number of GPUs: {n_gpus}")

In [None]:
import scanpy as sc
# Read the input AnnData from the path held in `data` (validated in the
# configuration cell).
adata = sc.read_h5ad(data)
# Display the object summary as the cell output.
adata

In [None]:
import pandas as pd
# Read the marker matrix, using the first CSV column as the row index, and
# append an all-zero 'Unassigned' cell type column in the same expression.
marker_df = pd.read_csv(marker_list, index_col=0).assign(Unassigned=0)
marker_df

Performing cell type assignment

In [None]:
import scvi
from scvi.external import CellAssign
# Pin scvi-tools' seed setting to a fixed value (intended to make training
# repeatable between runs).
scvi.settings.seed = 0
# Record the library version in the notebook output for provenance.
print("Last run with scvi-tools version:", scvi.__version__)

In [None]:
import numpy as np

# Strip genes not in the marker list
shared_genes = adata.var_names.isin(marker_df.index.values)
if not shared_genes.any():
    # Fail early with a clear message instead of training on an empty matrix.
    raise ValueError("None of the marker genes are present in adata.var_names")
bdata = adata[:, shared_genes].copy()
# Reorder/restrict the marker matrix so its rows match bdata's genes exactly.
marker_df = marker_df.loc[bdata.var_names]

# Size factors are derived from each cell's total counts over ALL genes in
# `adata`, not just the marker subset, following the scvi-tools CellAssign
# usage example (library size is computed before subsetting to markers).
# NOTE(review): this assumes adata.X holds raw counts -- confirm upstream.
# np.asarray(...).ravel() flattens the (n, 1) np.matrix that sparse matrices
# return from .sum(axis=1) into a plain 1-D array.
lib_size = np.asarray(adata.X.sum(axis=1)).ravel()
bdata.obs["size_factor"] = lib_size / np.mean(lib_size)
bdata

In [None]:
# Register bdata with CellAssign, pointing it at the per-cell size factor
# stored in bdata.obs["size_factor"] by the previous cell.
CellAssign.setup_anndata(bdata, size_factor_key="size_factor")
bdata

In [None]:
# Force CPU training when no GPUs were requested; otherwise let the training
# backend choose the accelerator automatically.
train_accelerator = 'auto' if n_gpus != 0 else 'cpu'

model = CellAssign(bdata, marker_df)
model.train(accelerator=train_accelerator)

In [None]:
# Plot the 'elbo_validation' entry of the training history to inspect
# convergence.
model.history['elbo_validation'].plot()

Extract assignments

In [None]:
# Per-cell prediction table from the trained model; downstream cells treat
# each column as a cell type and each value as that type's probability.
predictions = model.predict()
predictions.head()

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
# A clustered heatmap of the full prediction table is intentionally skipped:
# the table is usually too large to render. To enable it on small inputs:
# sns.clustermap(predictions, cmap="viridis")

In [None]:
# Initial cell type assignments
THRESHOLD = 0.9  # minimum top probability required to accept a label

# Use only the probability columns. On a re-run of this cell the 'cell_type'
# label column already exists and must not be fed into idxmax/max.
probability_columns = [col for col in predictions.columns if col != 'cell_type']
probabilities = predictions[probability_columns]
max_idx = probabilities.idxmax(axis=1)
max_prob = probabilities.max(axis=1)
# Cells whose best probability is below the threshold are left 'Unassigned'.
predictions['cell_type'] = np.where(max_prob >= THRESHOLD, max_idx, 'Unassigned')
# Index rows by cell barcode. Setting the index directly (instead of adding a
# temporary 'cell' column) keeps this cell safe to re-run and raises if the
# number of predictions does not match the number of cells in adata.
predictions.index = adata.obs.index.rename('cell')

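Initial predictions completed. We will now refine uncertain assignments using the precomputed clusters stored in the AnnData object: within each cluster that contains unassigned cells, those cells are relabeled with the cell type that has the highest mean probability across the cluster.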

In [None]:
# The prepared anndata should already have a cluster key in obs
cluster_key = "leiden"
if cluster_key not in adata.obs.columns:
    raise KeyError(
        f"Cluster key '{cluster_key}' not found in adata.obs; "
        "the input AnnData must be clustered beforehand"
    )

for cluster in adata.obs[cluster_key].unique():
    # Barcodes of the cells in this cluster, read straight from obs so no
    # AnnData view has to be built per cluster.
    in_cluster = (adata.obs[cluster_key] == cluster).to_numpy()
    clustered_cells = adata.obs.index[in_cluster]

    # If none are unassigned, skip
    is_unassigned = predictions.loc[clustered_cells, 'cell_type'].eq('Unassigned')
    if not is_unassigned.any():
        continue

    # Mean probability of every cell type over all cells in the cluster
    cluster_avg = predictions.loc[clustered_cells].drop(columns='cell_type').mean(axis=0)

    # Bar plot of the average probabilities with a horizontal line at the threshold
    fig, ax = plt.subplots(figsize=(10, 5))
    sns.barplot(x=cluster_avg.index, y=cluster_avg.values, ax=ax)
    ax.axhline(y=THRESHOLD, color='r', linestyle='--', label='Threshold')
    ax.set_title(f'Cluster {cluster} Mean Probabilities')
    ax.set_xlabel('Cell Type')
    ax.set_ylabel('Average Probability')
    ax.tick_params(axis='x', rotation=90)
    ax.legend()
    plt.show()
    # Close (not just clear) the figure so figures do not accumulate across
    # clusters and no empty figure is left behind.
    plt.close(fig)

    # Select the cell type that is most confident in the cluster
    most_confident_type = cluster_avg.idxmax()
    most_confident_prob = cluster_avg.max()
    print(f"Cluster {cluster}: Most confident type is {most_confident_type} with mean probability {most_confident_prob:.2f}")

    # Reassign only the unassigned cells in the cluster to the most confident
    # type; cells that already have a label are left untouched.
    unassigned_cells = is_unassigned[is_unassigned].index
    predictions.loc[unassigned_cells, 'cell_type'] = most_confident_type

In [None]:
# Plot the distribution of cell types
sns.countplot(data=predictions, x='cell_type', order=predictions['cell_type'].value_counts().index)
plt.xticks(rotation=90)
plt.show()

In [None]:
# Number of cells per final cell type label, shown as the cell output.
predictions['cell_type'].value_counts()

Plot a UMAP of the cell type assignments (the UMAP embedding must already be computed and stored in the AnnData object).

In [None]:
# Copy the final labels onto the AnnData object. The assignment is positional;
# predictions was indexed from adata.obs in the same row order earlier on.
adata.obs['cell_type'] = predictions['cell_type'].values

# Draw the precomputed UMAP coloured by the assigned cell type.
umap_options = dict(
    color='cell_type',
    frameon=False,
    legend_loc='on data',
    title='Cell Type Assignments',
    size=20,
    wspace=0.4,
)
sc.pl.umap(adata, **umap_options)

Save the results

In [None]:
# Write the prediction table (probability columns plus 'cell_type') to
# output_file, keeping the barcode index as the first CSV column.
# NOTE(review): if output_file is None, pandas returns the CSV text instead of
# writing a file -- make sure SNAKEMAKE_OUTPUT_FILE is set.
predictions.to_csv(output_file, index=True)