### Notebook for scoring the confidence of annotations using `Annotatability`

#### Environment: Annotability

- **Developed by:** Alexandra Cirnu
- **Modified by:** Alexandra Cirnu
- **Würzburg Institute for Systems Immunology & Julius-Maximilian-Universität Würzburg**
- **Date of creation:** 240415
- **Date of modification:** 240415

### Load in required modules

In [None]:
import matplotlib.pyplot as plt
import muon as mu
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
import torch
import torch.optim as optim
from Annotatability import metrics, models
from muon import atac as ac
from muon import prot as pt
from scipy.sparse import csr_matrix
from torch.utils.data import TensorDataset, DataLoader , WeightedRandomSampler

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
# Verbosity levels: errors (0), warnings (1), info (2), hints (3).
sc.settings.verbosity = 3
sc.logging.print_versions()

# High-resolution, vector-friendly SVG figures with a reversed magma colour map.
sc.settings.set_figure_params(
    dpi=300,
    dpi_save=300,
    color_map='magma_r',
    vector_friendly=True,
    format='svg',
)
%matplotlib inline

In [None]:
# Font-size presets shared by all matplotlib figures in this notebook.
SMALL_SIZE, MEDIUM_SIZE, BIGGER_SIZE = 14, 18, 20
sc.set_figure_params(scanpy=True, fontsize=16)

plt.rc('font', size=MEDIUM_SIZE)                                # default text size
plt.rc('axes', titlesize=MEDIUM_SIZE, labelsize=MEDIUM_SIZE)    # axes title and x/y labels
for _tick_group in ('xtick', 'ytick'):
    plt.rc(_tick_group, labelsize=MEDIUM_SIZE)                  # tick labels
plt.rc('legend', fontsize=MEDIUM_SIZE)                          # legend text
plt.rc('figure', titlesize=BIGGER_SIZE)                         # figure title

### Load in the data set

In [None]:
# Read the MuData (.h5mu) file and display its summary.
# NOTE(review): `input` shadows the Python builtin of the same name; consider
# renaming it once it is confirmed that no other cell relies on this variable.
input = '/home/acirnu/data/ACM_cardiac_leuco/5_Leiden_clustering_and_annotation/ACM_myeloids_clustered_muon_ac240415.raw.h5mu'
mdata = mu.read_h5mu(input)
mdata

In [None]:
# Work on the "rna" modality of the MuData container from here on.
adata = mdata.mod["rna"]

In [None]:
# Inspect the expression matrix as a sparse DataFrame (cells x genes).
X_data = adata.X.copy()
X_data_sparse = csr_matrix(X_data)
X_data_df = pd.DataFrame.sparse.from_spmatrix(
    X_data_sparse,
    index=adata.obs.index,
    columns=adata.var.index,
)
print("Shape of counts DataFrame:", X_data_df.shape)
print(X_data_df)

### Standard preprocessing

In [None]:
# Keep an untouched copy of the object before any preprocessing
# (the next cell repeats this assignment before it modifies `adata`).
adata_raw = adata.copy()

In [None]:
# Snapshot the object, then stash the current matrix in a 'counts' layer
# so that the HVG selection below can read from it.
adata_raw = adata.copy()
adata.layers['counts'] = adata.X.copy()

# Keep the 7000 most variable genes (seurat_v3 flavour, batch-aware per donor);
# subset=True drops all other genes from `adata` in place.
sc.pp.highly_variable_genes(
    adata,
    layer="counts",
    flavor="seurat_v3",
    batch_key="donor",
    n_top_genes=7000,
    span=1,
    subset=True,
)

adata

In [None]:
# Scale every cell to a total of 1e6 (highly expressed genes are excluded from
# the size-factor computation), then apply log(1 + x).
sc.pp.normalize_total(
    adata,
    target_sum=1e6,
    exclude_highly_expressed=True,
)
sc.pp.log1p(adata)

In [None]:
# Re-inspect the matrix after HVG selection, normalisation and log1p.
# NOTE(review): the printed label still says "counts" although adata.X has been
# normalised and log-transformed by the preceding cell.
X_data = adata.X.copy()
X_data_sparse = csr_matrix(X_data)
X_data_df = pd.DataFrame.sparse.from_spmatrix(
    X_data_sparse,
    index=adata.obs.index,
    columns=adata.var.index,
)
print("Shape of counts DataFrame:", X_data_df.shape)
print(X_data_df)

### Visualization

In [None]:
# UMAP coloured by the existing annotation.
sc.set_figure_params(dpi=300, figsize=(10, 10))
sc.pl.umap(
    adata,
    color=['classification'],
    size=20,
    legend_fontsize=10,
    frameon=False,
)

In [None]:
# Display the summary of the preprocessed object.
adata

### Train the neural network and monitor the training dynamics

In [None]:
# Number of training epochs; reused by the downstream Annotatability calls.
epoch_num = 50

# Train on the 'classification' labels and collect the per-epoch probabilities.
prob_list = models.follow_training_dyn_neural_net(
    adata,
    label_key='classification',
    iterNum=epoch_num,
    device=device,
)

In [None]:
# Collapse the recorded probabilities into one confidence and one variability
# value per cell.
all_conf, all_var = models.probability_list_to_confidence_and_var(
    prob_list,
    n_obs=adata.n_obs,
    epoch_num=epoch_num,
)

#### Visualize the data map

In [None]:
# Data map: variability (x) against confidence (y), one point per cell.
# .cpu() is a no-op for CPU tensors and avoids the "can't convert cuda tensor
# to numpy" error in case the tensors live on the GPU selected in `device`.
plt.scatter(
    all_var.detach().cpu().numpy(),
    all_conf.detach().cpu().numpy(),
)
plt.xlabel('variability')
plt.ylabel('confidence')
plt.show()

In [None]:
# Attach the per-cell variability and confidence to the observation table.
# .cpu() is a no-op for CPU tensors and keeps the conversion working in case
# the tensors live on the GPU selected in `device`.
adata.obs["var"] = list(all_var.detach().cpu().numpy())
adata.obs["conf"] = list(all_conf.detach().cpu().numpy())

In [None]:
# UMAP coloured by confidence, variability and the original annotation.
sc.pl.umap(
    adata,
    color=['conf', 'var', 'classification'],
)

In [None]:
# One variability-vs-confidence scatter panel per 'classification' value,
# wrapped after four columns.
g = sns.FacetGrid(
    adata.obs,
    col="classification",
    col_wrap=4,
    height=5,
)
g.map_dataframe(sns.scatterplot, x="var", y="conf")

# Main title, with the panels pushed down slightly to make room for it.
g.fig.suptitle('Scatter Plots by classification', fontsize=16)
g.fig.subplots_adjust(top=0.97)

# Render the figure.
plt.show()


Find the cutoffs: cells whose confidence and variability are both below their respective cutoffs are classified as hard-to-learn.

In [None]:
# Estimate the confidence and variability cutoffs used below to flag
# hard-to-learn cells.
cutoff_conf, cutoff_var = models.find_cutoff_paramter(
    adata,
    'classification',
    device=device,
    probability=0.05,
    percentile=50,
    epoch_num=epoch_num,
)

In [None]:
# Display the two cutoff values.
cutoff_conf, cutoff_var

In [None]:
# True when a cell exceeds at least one cutoff; False marks cells that are at
# or below both cutoffs.
above_conf_cutoff = adata.obs['conf'] > cutoff_conf
above_var_cutoff = adata.obs['var'] > cutoff_var
adata.obs['conf_binaries'] = pd.Categorical(above_conf_cutoff | above_var_cutoff)
adata.obs['conf_binaries'].value_counts()

Mark which cells are correctly annotated. Cells that are neither correctly annotated nor erroneously annotated are labelled as ambiguous annotations.

In [None]:
# Thresholds chosen so that only cells that are in any case correctly labelled
# pass: confidence above 0.95 together with variability below 0.2.
high_confidence = adata.obs['conf'] > 0.95
low_variability = adata.obs['var'] < 0.2
adata.obs['conf_correct'] = pd.Categorical(high_confidence & low_variability)
adata.obs['conf_correct'].value_counts()

In [None]:
# Three-way label per cell:
#   conf_binaries is False          -> 'Erroneously annotated'
#   otherwise conf_correct is False -> 'Ambiguous annotation'
#   otherwise                       -> 'Correctly annotated'
# The two columns are walked in lockstep instead of using `series[i]`: integer
# keys on a label-indexed Series are a deprecated positional fallback in pandas.
corr_classified_list = [
    'Erroneously annotated' if not passed_cutoff
    else ('Correctly annotated' if is_correct else 'Ambiguous annotation')
    for passed_cutoff, is_correct in zip(
        adata.obs['conf_binaries'],
        adata.obs['conf_correct'],
    )
]

adata.obs['Annotation'] = corr_classified_list
adata.obs['Annotation'].value_counts()

In [None]:
# Duplicate the score columns under display-friendly names for plotting.
for _pretty_name, _score_key in (('Confidence', 'conf'), ('Variability', 'var')):
    adata.obs[_pretty_name] = adata.obs[_score_key]

In [None]:
# Joint distribution of variability and confidence, coloured by annotation class.
fig = sns.jointplot(
    data=adata.obs,
    x="Variability",
    y="Confidence",
    hue='Annotation',
    s=25,
)
# plt.show() takes no figure argument (only keyword `block`); the grid object
# returned by jointplot must not be passed to it.
plt.show()

In [None]:
# UMAP panels (two per row): confidence, batch, cutoff flag and annotation class.
sc.set_figure_params(dpi=300, figsize=(10, 10))
sc.pl.umap(
    adata,
    color=['conf', 'batch', 'conf_binaries', 'Annotation'],
    ncols=2,
    size=15,
    frameon=False,
)

In [None]:
# UMAP coloured by the original annotation.
# NOTE(review): 'classification' is listed twice, so the same panel is drawn
# two times - presumably the second entry was meant to be another key; confirm.
sc.set_figure_params(dpi=300, figsize=(10, 10))
sc.pl.umap(
    adata,
    color=['classification', 'classification'],
    size=15,
    legend_fontsize=10,
    frameon=False,
)

### Predict true labels

In [None]:
# Predict corrected labels; reuse the notebook-wide `epoch_num` so this step
# stays in sync with the training and cutoff cells instead of a hard-coded 50.
hdata = models.predict_true_labels(
    adata,
    label='classification',
    device=device,
    epoch_num=epoch_num,
)

In [None]:
# Display the object returned by predict_true_labels.
hdata

In [None]:
# Split the cells by the cutoff flag: False = at or below both cutoffs,
# True = above at least one cutoff.
flagged_cells = adata.obs['conf_binaries'].isin([False])
adata_false_annotation = adata[flagged_cells]

unflagged_cells = adata.obs['conf_binaries'].isin([True])
adata_true_annotation = adata[unflagged_cells]

In [None]:
# NOTE: despite its name, 'changed_anno' is True where the original
# 'classification' equals 'CorrectedCellType', i.e. where the label was NOT changed.
hdata.obs['changed_anno'] = (hdata.obs['classification']==hdata.obs['CorrectedCellType'])
hdata.obs['changed_anno'].value_counts()

In [None]:
# Cells whose corrected label equals the original one (changed_anno is True) ...
label_kept = hdata.obs['changed_anno'].isin([True])
adata_did_not_changed = hdata[label_kept]

# ... narrowed down to those with conf_binaries == False.
below_cutoffs = adata_did_not_changed.obs['conf_binaries'].isin([False])
adata_did_not_changed = adata_did_not_changed[below_cutoffs]
adata_did_not_changed

In [None]:
# changed_anno is False where 'classification' differs from 'CorrectedCellType',
# so this selects the cells whose label was changed.
adata_annotation_changed= hdata[hdata.obs['changed_anno'].isin([False])]
adata_annotation_changed

In [None]:
# List the category names of the 'classification' column.
adata.obs["classification"].cat.categories

In [None]:
# Keep only the changed cells whose corrected label is one of these cell types.
kept_corrected_types = [
    'DC',
    'DOCK4+MØ',
    'LYVE1+MØ',
    'Mast',
    'Monocytes',
    'MØ_general',
    'Neutrophils',
]
corrected_is_kept = adata_annotation_changed.obs['CorrectedCellType'].isin(kept_corrected_types)
adata_annotation_changed_no_per = adata_annotation_changed[corrected_is_kept]

In [None]:
# Cells of `hdata` with conf_binaries == False (at or below both cutoffs).
adata_false_annotated= hdata[hdata.obs['conf_binaries'].isin([False])]

In [None]:
# NOTE(review): this cell looks like it was carried over from a different data
# set: it reads obs['CellType'] and filters on 'Excitatory' / 'Inhibitory' and
# lists neuronal marker genes, whereas the rest of this notebook uses
# 'classification' as label column. Confirm the intended column and cell types
# before relying on the results.
#
# Rows are walked with zip() instead of `series[i]`: integer keys on a
# label-indexed Series are a deprecated positional fallback in pandas.

# Corrected types that get a leading space in the generated labels.
padded_types = ('Inhibitory', 'Excitatory')

# Cells annotated as 'Excitatory', tagged by whether they passed the cutoffs.
tmp_tdata2 = hdata[hdata.obs['CellType'].isin(['Excitatory'])]
False_or_pos = [
    '  Correct annotation' if passed_cutoff else '  Erroneous annotation'
    for passed_cutoff in tmp_tdata2.obs['conf_binaries']
]
tmp_tdata2.obs['Celltype_to_corrected'] = False_or_pos


# "<original>_<corrected>" label for every flagged cell.
Celltype_to_corrected = [
    (" " if corrected in padded_types else "") + str(original) + "_" + str(corrected)
    for original, corrected in zip(
        adata_false_annotated.obs['CellType'],
        adata_false_annotated.obs['CorrectedCellType'],
    )
]
adata_false_annotated.obs['Celltype_to_corrected'] = Celltype_to_corrected

# Flagged cells originally annotated as 'Excitatory', with a
# "Corrected from <original> to <corrected>" label. Only rows with
# conf_binaries == False produce a label, so the assignment below requires
# every row of tmp_tdata to be flagged (as in the original loop).
tmp_tdata = adata_false_annotated[adata_false_annotated.obs['CellType'].isin(['Excitatory'])]
Celltype_to_corrected = [
    (" " if corrected in padded_types else "")
    + "Corrected from " + str(original) + " to " + str(corrected)
    for passed_cutoff, original, corrected in zip(
        tmp_tdata.obs['conf_binaries'],
        tmp_tdata.obs['CellType'],
        tmp_tdata.obs['CorrectedCellType'],
    )
    if not passed_cutoff
]

tmp_tdata.obs['Celltype_to_corrected'] = Celltype_to_corrected

# Marker genes per cell type ('Gad2' and 'Slc32a1' were left out for the
# inhibitory neurons).
marker_genes_dict = {
    'Inhibitory neurons': ['Gad1'],
    'Excitatory neurons': ['Slc17a6'],
    'Astrocytes': ['Aqp4'],
    'Endothelial': ['Fn1'],
    'Ependymal': ['Cd24a'],
    'Microglia': ['Selplg'],
    'OD Immature': ['Pdgfra'],
    'OD Mature': ['Ttyh2', 'Mbp'],
    'Pericytes': ['Myh11'],
}


# NOTE(review): AnnData.concatenate is deprecated in favour of anndata.concat;
# kept as is because the two differ in their defaults - verify before migrating.
tmp_tdata3 = tmp_tdata2.concatenate(tmp_tdata)