# SCENTInEL
#### sc ElasticNeT integrative ensemble learning

# SCENTEL
#### sc ElasticNeT ensemble learning

# LR multi-tissue cross-comparison

##### Ver:: A1_V5
##### Author(s) : Issac Goh
##### Date : 220823;YYMMDD
### Author notes
    - Current defaults scrpae data from web, so leave as default and run
    - slices model and anndata to same feature shape, scales anndata object
    - added some simple benchmarking
    - creates dynamic cutoffs for probability score (x*sd of mean) in place of more memory intensive confidence scoring
    - Does not have majority voting set on as default, but module does exist
    - Multinomial logistic relies on the (not always realistic) assumption of independence of irrelevant alternatives whereas a series of binary logistic predictions does not. collinearity is assumed to be relatively low, as it becomes difficult to differentiate between the impact of several variables if this is not the case
    
### Features to add
    - Add ability to consume anndata zar format for sequential learning
       - Feature assessment weighted by classifications made in query data based on bayes factor of variable expression 
    - Bayesian sampling (KNN-based)
    - Bayesian optimisation layer
    - Bayesian scoring of pribabilities in match samples
    - Weighted R2 per-class model performance scoring, compared to global.
        - In joint latent representation; 
        - Single Bayesian optimisded model
        - R2(hat) * Prob(hat) computed per model
        i.e for 3 data case:
        - Three probs: 
            - local model self-projected probs
            - Global model projection probs
            - Every other model inductive projection across
        - Does global agree with local/other model? 
        - Aggregate the global and each projection? -- Does the model perform well in global space & in each individual model? 
        - Probability of harmonisation
            Aggregate? concat?
### Modes to run in
    - Run in training mode
    - Run in projection mode

In [1]:
import sys
import subprocess

# import pkg_resources
# required = {'harmonypy','sklearn','scanpy','pandas', 'numpy', 'scipy', 'matplotlib', 'seaborn' ,'scipy'}
# installed = {pkg.key for pkg in pkg_resources.working_set}
# missing = required - installed
# if missing:
#    print("Installing missing packages:" )
#    print(missing)
#    python = sys.executable
#    subprocess.check_call([python, '-m', 'pip', 'install', *missing], stdout=subprocess.DEVNULL)

from collections import Counter
from collections import defaultdict
import scanpy as sc
import pandas as pd
import pickle as pkl
import numpy as np
import scipy
import matplotlib.pyplot as plt
import re
import glob
import os
import sys
#from geosketch import gs
from numpy import cov
import scipy.cluster.hierarchy as spc
import seaborn as sns; sns.set(color_codes=True)
from sklearn.linear_model import LogisticRegression
import sklearn
from pathlib import Path
import requests
import psutil
import random
import threading
import tracemalloc
import itertools
import math
import warnings
import sklearn.metrics as metrics

In [1]:
models = {
'pan_fetal':'/nfs/team205/ig7/resources/scripts_dont_modify/logit_regression_models/adifa_lr/celltypist_model.Pan_Fetal_Human.pkl',
'pan_fetal_wget':'https://celltypist.cog.sanger.ac.uk/models/Pan_Fetal_Suo/v2/Pan_Fetal_Human.pkl',
'adata_scvi':'/nfs/team205/ig7/mount/gdrive/g_cloud/projects/amniontic_fluid/scvi_low_dim_model.sav',
'adata_ldvae':'/nfs/team205/ig7/mount/gdrive/g_cloud/projects/amniontic_fluid/ldvae_low_dim_model.sav',
'adata_harmony':'/nfs/team205/ig7/work_backups/backup_210306/projects/amiotic_fluid/train_low_dim_model/organ_low_dim_model.sav',
'test_low_dim_ipsc_ys':'/nfs/team205/ig7/work_backups/backup_210306/projects/YS/YS_030522_notebooks/Integrating_HM_data_030522/YS_logit/lr_model.sav',
'YS_X':'/nfs/team205/ig7/work_backups/backup_210306/projects/YS/resources/YS_X_model_080922.sav',
'YS_X_V3':'/nfs/team205/ig7/work_backups/backup_210306/projects/YS/rebuttal_figs_010922/train_YS_full_X_model/YS_X_A2_V12_lvl3_ELASTICNET_YS.sav',
'SK_model':'/nfs/team205/ig7/resources/scripts_dont_modify/logit_regression_models/LR_app_format/hudaa_skin/for_hudaa_A1_V2',
'Hudaa_model_trained':'/nfs/team298/hg6/Fetal_skin/LR_15012023/train-all_model.pkl',

}

adatas_dict = {
'Fetal_skin_raw': '/nfs/team298/hg6/Fetal_skin/data/FS_raw_sub.h5ad',
'vascular_organoid': '/nfs/team298/hg6/Fetal_skin/data/vasc_org_raw.h5ad',
'YS':'/nfs/team205/ig7/work_backups/backup_210306/projects/YS/YS_data/Submission_2_data/A2_V5_scvi_YS_integrated/A2_V5_scvi_YS_integrated_raw_qc_scr_umap.h5ad',
'YS_test':'/nfs/team205/ig7/resources/scripts_dont_modify/logit_regression_models/LR_app_format/ys_test_data.h5ad',
'YS_A2_V10_X_raw':'/nfs/team205/ig7/work_backups/backup_210306/projects/YS/YS_data/Submission_2_data/A2_V10_scvi_YS_integrated/A2_V10_raw_counts_full_no_obs.h5ad',
'YS_A2_V10_X':'/nfs/team205/ig7/work_backups/backup_210306/projects/YS/YS_data/Submission_2_data/A2_V10_scvi_YS_integrated/A2_V10_qc_raw.h5ad'
}

# Variable assignment
train_model = False
feat_use = 'joint_annotation_20220202'
adata_key = 'Fetal_skin_raw'#'fliv_wget_test' # key for dictionary entry containing local or web path to adata/s can be either url or local 
data_merge = False # read and merge multiple adata (useful, but keep false for now)
model_key = 'SK_model'#'test_low_dim_ipsc_ys'# key for model of choice can be either url or local 
train_x_partition = 'X' # what partition was the data trained on? To keep simple, for now only accepts 'X'
dyn_std = 1.96 # Dynamic cutoffs using std of the mean for each celltype probability, gives a column notifying user of uncertain labels 1 == 68Ci, 1.96 = 95CI
freq_redist = 'joint_annotation_20220202'#'cell.labels'#'False#'cell.labels'#False # False or key of column in anndata object which contains labels/clusters // not currently implemented
partial_scale = True # should data be scaled in batches?
QC_normalise = True # should data be normalised?

# training variables
penalty='elasticnet' # can be ["l1","l2","elasticnet"]
sparcity=0.5 # C penalty for degree of regularisation
thread_num = -1
l1_ratio = 0.5 # ratio between L1 and L2 regulrisatiuon depending on penatly method

# Partial scaling ver
- scale across 10 mini bulks/every 100,000 cells
- sequential learning for scaling
- sequential application of scaling

In [391]:
from collections import Counter
from collections import defaultdict
import scanpy as sc
import pandas as pd
import pickle as pkl
import numpy as np
import scipy
import matplotlib.pyplot as plt
import re
import glob
import os
import sys
#from geosketch import gs
from numpy import cov
import scipy.cluster.hierarchy as spc
import seaborn as sns; sns.set(color_codes=True)
from sklearn.linear_model import LogisticRegression
import sklearn
from pathlib import Path
import requests
import psutil
import random
import threading
import tracemalloc
import itertools
import math
import warnings
import sklearn.metrics as metrics
import numpy as np
from sklearn.metrics import log_loss
import mygene
import gseapy as gp
import mygene
import scipy.sparse as sparse
from sklearn.metrics.pairwise import cosine_similarity

def load_models(model_dict,model_run):
    if (Path(model_dict[model_run])).is_file():
        # Load data (deserialize)
        model = pkl.load(open(model_dict[model_run], "rb"))
        return model
    elif 'http' in model_dict[model_run]:
        print('Loading model from web source')
        r_get = requests.get(model_dict[model_run])
        fpath = './model_temp.sav'
        open(fpath , 'wb').write(r_get.content)
        model = pkl.load(open(fpath, "rb"))
        return model

def load_adatas(adatas_dict,data_merge, data_key_use,QC_normalise):
    if data_merge == True:
        # Read
        gene_intersect = {} # unused here
        adatas = {}
        for dataset in adatas_dict.keys():
            if 'https' in adatas_dict[dataset]:
                print('Loading anndata from web source')
                adatas[dataset] = sc.read('./temp_adata.h5ad',backup_url=adatas_dict[dataset])
            adatas[dataset] = sc.read(data[dataset])
            adatas[dataset].var_names_make_unique()
            adatas[dataset].obs['dataset_merge'] = dataset
            adatas[dataset].obs['dataset_merge'] = dataset
            gene_intersect[dataset] = list(adatas[dataset].var.index)
        adata = list(adatas.values())[0].concatenate(list(adatas.values())[1:],join='inner')
        return adatas, adata
    elif data_merge == False:
        if 'https' in adatas_dict[data_key_use]:
            print('Loading anndata from web source')
            adata = sc.read('./temp_adata.h5ad',backup_url=adatas_dict[data_key_use])
        else: 
            adata = sc.read(adatas_dict[data_key_use])
    if QC_normalise == True:
        print('option to apply standardisation to data detected, performing basic QC filtering')
        sc.pp.filter_cells(adata, min_genes=200)
        sc.pp.filter_genes(adata, min_cells=3)
        sc.pp.normalize_per_cell(adata, counts_per_cell_after=1e4)
        sc.pp.log1p(adata)
        
    return adata

# resource usage logger
class DisplayCPU(threading.Thread):
    def run(self):
        tracemalloc.start()
        starting, starting_peak = tracemalloc.get_traced_memory()
        self.running = True
        self.starting = starting
        currentProcess = psutil.Process()
        cpu_pct = []
        peak_cpu = 0
        while self.running:
            peak_cpu = 0
#           time.sleep(3)
#             print('CPU % usage = '+''+ str(currentProcess.cpu_percent(interval=1)))
#             cpu_pct.append(str(currentProcess.cpu_percent(interval=1)))
            cpu = currentProcess.cpu_percent()
        # track the peak utilization of the process
            if cpu > peak_cpu:
                peak_cpu = cpu
                peak_cpu_per_core = peak_cpu/psutil.cpu_count()
        self.peak_cpu = peak_cpu
        self.peak_cpu_per_core = peak_cpu_per_core
        
    def stop(self):
#        cpu_pct = DisplayCPU.run(self)
        self.running = False
        current, peak = tracemalloc.get_traced_memory()
        tracemalloc.stop()
        return current, peak
    
# projection module
def reference_projection(adata, model, dyn_std,partial_scale):
    
    class adata_temp:
        pass
    from sklearn.preprocessing import StandardScaler
    print('Determining model flavour')
    try:
        model_lr =  model['Model']
        print('Consuming celltypist model')
    except:# hasattr(model, 'coef_'):
        print('Consuming non-celltypist model')
        model_lr =  model
    print(model_lr)
    if train_x_partition == 'X':
        print('Matching reference genes in the model')
        k_x = np.isin(list(adata.var.index), list(model_lr.features))
        if k_x.sum() == 0:
            raise ValueError(f"🛑 No features overlap with the model. Please provide gene symbols")
        print(f"🧬 {k_x.sum()} features used for prediction")
        #slicing adata
        k_x_idx = np.where(k_x)[0]
        # adata_temp = adata[:,k_x_idx]
        adata_temp.var = adata[:,k_x_idx].var
        adata_temp.X = adata[:,k_x_idx].X
        adata_temp.obs = adata[:,k_x_idx].obs
        lr_idx = pd.DataFrame(model_lr.features, columns=['features']).reset_index().set_index('features').loc[list(adata_temp.var.index)].values
        # adata_arr = adata_temp.X[:,list(lr_idexes['index'])]
        # slice and reorder model
        ni, fs, cf = model_lr.n_features_in_, model_lr.features, model_lr.coef_
        model_lr.n_features_in_ = lr_idx.size
        model_lr.features = np.array(model_lr.features)[lr_idx]
        model_lr.coef_ = np.squeeze(model_lr.coef_[:,lr_idx]) #model_lr.coef_[:, lr_idx]
        if partial_scale == True:
            print('scaling input data, default option is to use incremental learning and fit in mini bulks!')
            # Partial scaling alg
            scaler = StandardScaler(with_mean=False)
            n = adata_temp.X.shape[0]  # number of rows
            # set dyn scale packet size
            x_len = len(adata_temp.var)
            y_len = len(adata.obs)
            if y_len < 100000:
                dyn_pack = int(x_len/10)
                pack_size = dyn_pack
            else:
                # 10 pack for every 100,000
                dyn_pack = int((y_len/100000)*10)
                pack_size = int(x_len/dyn_pack)
            batch_size =  1000#pack_size#500  # number of rows in each call to partial_fit
            index = 0  # helper-var
            while index < n:
                partial_size = min(batch_size, n - index)  # needed because last loop is possibly incomplete
                partial_x = adata_temp.X[index:index+partial_size]
                scaler.partial_fit(partial_x)
                index += partial_size
            adata_temp.X = scaler.transform(adata_temp.X)
    # model projections
    print('Starting reference projection!')
    if train_x_partition == 'X':
        train_x = adata_temp.X
        pred_out = pd.DataFrame(model_lr.predict(train_x),columns = ['predicted'],index = list(adata.obs.index))
        proba =  pd.DataFrame(model_lr.predict_proba(train_x),columns = model_lr.classes_,index = list(adata.obs.index))
        pred_out = pred_out.join(proba)
        
    elif train_x_partition in list(adata.obsm.keys()): 
        print('{low_dim: this partition modality is still under development!}')
        train_x = adata.obsm[train_x_partition]
        pred_out = pd.DataFrame(model_lr.predict(train_x),columns = ['predicted'],index = list(adata.obs.index))
        proba =  pd.DataFrame(model_lr.predict_proba(train_x),columns = model_lr.classes_,index = list(adata.obs.index))
        pred_out = pred_out.join(proba)
    
    else:
        print('{this partition modality is still under development!}')
    ## insert modules for low dim below

    # Simple dynamic confidence calling
    pred_out['confident_calls'] = pred_out['predicted']
    pred_out.loc[pred_out.max(axis=1)<(pred_out.mean(axis=1) + (1*pred_out.std(axis=1))),'confident_calls'] = pred_out.loc[pred_out.max(axis=1)<(pred_out.mean(axis=1) + (1*pred_out.std(axis=1))),'confident_calls'].astype(str) + '_uncertain'
    # means_ = self.model.scaler.mean_[lr_idx] if self.model.scaler.with_mean else 0
    return(pred_out,train_x,model_lr,adata_temp)

def freq_redist_68CI(adata,clusters_reassign):
    if freq_redist != False:
        print('Frequency redistribution commencing')
        cluster_prediction = "consensus_clus_prediction"
        lr_predicted_col = 'predicted'
        pred_out[clusters_reassign] = adata.obs[clusters_reassign].astype(str)
        reassign_classes = list(pred_out[clusters_reassign].unique())
        lm = 1 # lambda value
        pred_out[cluster_prediction] = pred_out[clusters_reassign]
        for z in pred_out[clusters_reassign][pred_out[clusters_reassign].isin(reassign_classes)].unique():
            df = pred_out
            df = df[(df[clusters_reassign].isin([z]))]
            df_count = pd.DataFrame(df[lr_predicted_col].value_counts())
            # Look for classificationds > 68CI
            if len(df_count) > 1:
                df_count_temp = df_count[df_count[lr_predicted_col]>int(int(df_count.mean()) + (df_count.std()*lm))]
                if len(df_count_temp >= 1):
                    df_count = df_count_temp
            #print(df_count)     
            freq_arranged = df_count.index
            cat = freq_arranged[0]
        #Make the cluster assignment first
            pred_out[cluster_prediction] = pred_out[cluster_prediction].astype(str)
            pred_out.loc[pred_out[clusters_reassign] == z, [cluster_prediction]] = cat
        # Create assignments for any classification >68CI
            for cats in freq_arranged:
                #print(cats)
                cats_assignment = cats#.replace(data1,'') + '_clus_prediction'
                pred_out.loc[(pred_out[clusters_reassign] == z) & (pred_out[lr_predicted_col] == cats),[cluster_prediction]] = cats_assignment
        min_counts = pd.DataFrame((pred_out[cluster_prediction].value_counts()))
        reassign = list(min_counts.index[min_counts[cluster_prediction]<=2])
        pred_out[cluster_prediction] = pred_out[cluster_prediction].str.replace(str(''.join(reassign)),str(''.join(pred_out.loc[pred_out[clusters_reassign].isin(list(pred_out.loc[(pred_out[cluster_prediction].isin(reassign)),clusters_reassign])),lr_predicted_col].value_counts().head(1).index.values)))
        return pred_out

### Feature importance notes
#- If we increase the x feature one unit, then the prediction will change e to the power of its weight. We can apply this rule to the all weights to find the feature importance.
#- We will calculate the Euler number to the power of its coefficient to find the importance.
#- To sum up an increase of x feature by one unit increases the odds of being versicolor class by a factor of x[importance] when all other features remain the same.
#- For low-dim, we look at the distribution of e^coef per class, we extract the 
# class coef_extract:
#     def __init__(self, model,features, pos):
# #         self.w = list(itertools.chain(*(model.coef_[pos]).tolist())) #model.coef_[pos]
#         self.w = model.coef_[class_pred_pos]
#         self.features = features 
def long_format_features(top_loadings):
    p = top_loadings.loc[:, top_loadings.columns.str.endswith("_e^coef")]
    p = pd.melt(p)
    n = top_loadings.loc[:, top_loadings.columns.str.endswith("_feature")]
    n = pd.melt(n)
    l = top_loadings.loc[:, top_loadings.columns.str.endswith("_coef")]
    l = pd.melt(l)
    n = n.replace(regex=r'_feature', value='')
    n = n.rename(columns={"variable": "class", "value": "feature"})
    p = (p.drop(["variable"],axis = 1)).rename(columns={ "value": "e^coef"})
    l = (l.drop(["variable"],axis = 1)).rename(columns={ "value": "coef"})
    concat = pd.concat([n,p,l],axis=1)
    return concat

def model_feature_sf(long_format_feature_importance, coef_use):
        long_format_feature_importance[str(coef_use) + '_pval'] = 'NaN'
        for class_lw in long_format_feature_importance['class'].unique():
            df_loadings = long_format_feature_importance[long_format_feature_importance['class'].isin([class_lw])]
            comps = coef_use #'e^coef'
            U = np.mean(df_loadings[comps])
            std = np.std(df_loadings[comps])
            med =  np.median(df_loadings[comps])
            mad = np.median(np.absolute(df_loadings[comps] - np.median(df_loadings[comps])))
            # Survival function scaled by 1.4826 of MAD (approx norm)
            pvals = scipy.stats.norm.sf(df_loadings[comps], loc=med, scale=1.4826*mad) # 95% CI of MAD <10,000 samples
            #pvals = scipy.stats.norm.sf(df_loadings[comps], loc=U, scale=1*std)
            df_loadings[str(comps) +'_pval'] = pvals
            long_format_feature_importance.loc[long_format_feature_importance.index.isin(df_loadings.index)] = df_loadings
        long_format_feature_importance['is_significant_sf'] = False
        long_format_feature_importance.loc[long_format_feature_importance[coef_use+ '_pval']<0.05,'is_significant_sf'] = True
        return long_format_feature_importance
# Apply SF to e^coeff mat data
#         pval_mat = pd.DataFrame(columns = mat.columns)
#         for class_lw in mat.index:
#             df_loadings = mat.loc[class_lw]
#             U = np.mean(df_loadings)
#             std = np.std(df_loadings)
#             med =  np.median(df_loadings)
#             mad = np.median(np.absolute(df_loadings - np.median(df_loadings)))
#             pvals = scipy.stats.norm.sf(df_loadings, loc=med, scale=1.96*U)

class estimate_important_features: # This calculates feature effect sizes of the model
    def __init__(self, model, top_n):
        print('Estimating feature importance')
        classes =  list(model.classes_)
         # get feature names
        try:
            model_features = list(itertools.chain(*list(model.features)))
        except:
            warnings.warn('no features recorded in data, naming features by position')
            print('if low-dim lr was submitted, run linear decoding function to obtain true feature set')
            model_features = list(range(0,model.coef_.shape[1]))
            model.features = model_features
        print('Calculating the Euler number to the power of coefficients')
        impt_ = pow(math.e,model.coef_)
        try:
            self.euler_pow_mat = pd.DataFrame(impt_,columns = list(itertools.chain(*list(model.features))),index = list(model.classes_))
        except:
            self.euler_pow_mat = pd.DataFrame(impt_,columns = list(model.features),index = list(model.classes_))
        self.top_n_features = pd.DataFrame(index = list(range(0,top_n)))
        # estimate per class feature importance
        
        print('Estimating feature importance for each class')
        mat = self.euler_pow_mat
        for class_pred_pos in list(range(0,len(mat.T.columns))):
            class_pred = list(mat.T.columns)[class_pred_pos]
            #     print(class_pred)
            temp_mat =  pd.DataFrame(mat.T[class_pred])
            temp_mat['coef'] = model.coef_[class_pred_pos]
            temp_mat = temp_mat.sort_values(by = [class_pred], ascending=False)
            temp_mat = temp_mat.reset_index()
            temp_mat.columns = ['feature','e^coef','coef']
            temp_mat = temp_mat[['feature','e^coef','coef']]
            temp_mat.columns =str(class_pred)+ "_" + temp_mat.columns
            self.top_n_features = pd.concat([self.top_n_features,temp_mat.head(top_n)], join="inner",ignore_index = False, axis=1)
            self.to_n_features_long = model_feature_sf(long_format_features(self.top_n_features),'e^coef')
            
    
    # plot class-wise features
def model_class_feature_plots(top_loadings, classes, comps):
    import matplotlib.pyplot as plt
    for class_temp in classes:
        class_lw = class_temp
        long_format = top_loadings
        df_loadings = long_format[long_format['class'].isin([class_lw])]
        plt.hist(df_loadings[comps])
        for i in ((df_loadings[comps][df_loadings[str(comps) +'_pval']<0.05]).unique()):
            plt.axvline(x=i,color='red')
        med = np.median(df_loadings[comps])
        plt.axvline(x=med,color='blue')
        plt.xlabel('feature_importance', fontsize=12)
        plt.title(class_lw)
        #plt.axvline(x=med,color='pink')
        df_loadings[comps][df_loadings[str(comps) +'_pval']<0.05]
        print(len(df_loadings[comps][df_loadings[str(comps) +'_pval']<0.05]))
        #Plot feature ranking
        plot_loading = pd.DataFrame(pd.DataFrame(df_loadings[comps][df_loadings[str(comps) +'_pval']<0.05]).iloc[:,0].sort_values(ascending=False))
        table = plt.table(cellText=plot_loading.values,colWidths = [1]*len(plot_loading.columns),
        rowLabels= list(df_loadings['feature'][df_loadings.index.isin(plot_loading.index)].reindex(plot_loading.index)), #plot_loading.index,
        colLabels=plot_loading.columns,
        cellLoc = 'center', rowLoc = 'center',
        loc='right', bbox=[1.4, -0.05, 0.5,1])
        table.scale(1, 2)
        table.set_fontsize(10)
        
def report_f1(model,train_x, train_label):
    ## Report accuracy score
    from sklearn.model_selection import cross_val_score
    from sklearn.model_selection import RepeatedStratifiedKFold
    from sklearn import metrics
    import seaborn as sn
    import pandas as pd
    from sklearn.metrics import confusion_matrix
    import matplotlib.pyplot as plt
    
    # cv = RepeatedStratifiedKFold(n_splits=2, n_repeats=2, random_state=1)
    # # evaluate the model and collect the scores
    # n_scores = cross_val_score(lr, train_x, train_label, scoring='accuracy', cv=cv, n_jobs=-1)
    # # report the model performance
    # print('Mean Accuracy: %.3f (%.3f)' % (np.mean(n_scores), np.std(n_scores)))

    # Report Precision score
    metric = pd.DataFrame((metrics.classification_report(train_label, model.predict(train_x), digits=2,output_dict=True))).T
    cm = confusion_matrix(train_label, model.predict(train_x))
    #cm = confusion_matrix(train_label, model.predict_proba(train_x))
    df_cm = pd.DataFrame(cm, index = model.classes_,columns = model.classes_)
    df_cm = (df_cm / df_cm.sum(axis=0))*100
    plt.figure(figsize = (20,15))
    sn.set(font_scale=1) # for label size
    pal = sns.diverging_palette(240, 10, n=10)
    #plt.suptitle(('Mean Accuracy 5 fold: %.3f std: %.3f' % (np.mean(n_scores),  np.std(n_scores))), y=1.05, fontsize=18)
    #Plot precision recall and recall
    table = plt.table(cellText=metric.values,colWidths = [1]*len(metric.columns),
    rowLabels=metric.index,
    colLabels=metric.columns,
    cellLoc = 'center', rowLoc = 'center',
    loc='bottom', bbox=[0.25, -0.6, 0.5, 0.3])
    table.scale(1, 2)
    table.set_fontsize(10)

    sn.heatmap(df_cm, annot=True, annot_kws={"size": 16},cmap=pal) # font size
    print(metrics.classification_report(train_label, model.predict(train_x), digits=2))

def subset_top_hvgs(adata_lognorm, n_top_genes):
    dispersion_norm = adata_lognorm.var['dispersions_norm'].values.astype('float32')

    dispersion_norm = dispersion_norm[~np.isnan(dispersion_norm)]
    dispersion_norm[
                ::-1
            ].sort()  # interestingly, np.argpartition is slightly slower

    disp_cut_off = dispersion_norm[n_top_genes - 1]
    gene_subset = adata_lognorm.var['dispersions_norm'].values >= disp_cut_off
    return(adata_lognorm[:,gene_subset])

def prep_scVI(adata, 
              n_hvgs = 5000,
              remove_cc_genes = True,
              remove_tcr_bcr_genes = False
             ):
    ## Remove cell cycle genes
    if remove_cc_genes:
        adata = panfetal_utils.remove_geneset(adata,genes.cc_genes)

    ## Remove TCR/BCR genes
    if remove_tcr_bcr_genes:
        adata = panfetal_utils.remove_geneset(adata, genes.IG_genes)
        adata = panfetal_utils.remove_geneset(adata, genes.TCR_genes)
        
    ## HVG selection
    adata = subset_top_hvgs(adata, n_top_genes=n_hvgs)
    return(adata)

# Modified LR train module, does not work with low-dim by default anymore, please use low-dim adapter
def LR_train(adata, train_x, train_label, penalty='elasticnet', sparcity=0.2,max_iter=200,l1_ratio =0.2,tune_hyper_params =False,n_splits=5, n_repeats=3,l1_grid = [0.01,0.2,0.5,0.8], c_grid = [0.01,0.2,0.4,0.6]):
    if tune_hyper_params == True:
        train_labels=train_label
        results = tune_lr_model(adata, train_x_partition = train_x, random_state = 42,  train_labels = train_labels, n_splits=n_splits, n_repeats=n_repeats,l1_grid = l1_grid, c_grid = c_grid)
        print('hyper_params tuned')
        sparcity = results.best_params_['C']
        l1_ratio = results.best_params_['l1_ratio']
    
    lr = LogisticRegression(penalty = penalty, C = sparcity, max_iter =  max_iter, n_jobs=thread_num)
    if (penalty == "l1"):
        lr = LogisticRegression(penalty = penalty, C = sparcity, max_iter =  max_iter, dual = True, solver = 'liblinear',multi_class = 'ovr', n_jobs=thread_num ) # one-vs-rest
    if (penalty == "elasticnet"):
        lr = LogisticRegression(penalty = penalty, C = sparcity, max_iter =  max_iter, dual=False,solver = 'saga',l1_ratio=l1_ratio,multi_class = 'ovr', n_jobs=thread_num)
    if train_x == 'X':
        subset_train = adata.obs.index
        # Define training parameters
        train_label = adata.obs[train_label].values
#        predict_label = train_label[subset_predict]
#        train_label = train_label[subset_train]
        train_x = adata.X#[adata.obs.index.isin(list(adata.obs[subset_train].index))]
#        predict_x = adata.X[adata.obs.index.isin(list(adata.obs[subset_predict].index))]
    elif train_x in adata.obsm.keys():
        # Define training parameters
        train_label = adata.obs[train_label].values
#        predict_label = train_label[subset_predict]
#         train_label = train_label[subset_train]
        train_x = adata.obsm[train_x]
#        predict_x = train_x
#        train_x = train_x[subset_train, :]
        # Define prediction parameters
#        predict_x = predict_x[subset_predict]
#        predict_x = pd.DataFrame(predict_x)
#        predict_x.index = adata.obs[subset_predict].index
    # Train predictive model using user defined partition labels (train_x ,train_label, predict_x)
    model = lr.fit(train_x, train_label)
    model.features = np.array(adata.var.index)
    return model

def tune_lr_model(adata, train_x_partition = 'X', random_state = 42, use_bayes_opt=True, train_labels = None, n_splits=5, n_repeats=3,l1_grid = [0.1,0.2,0.5,0.8], c_grid = [0.1,0.2,0.4,0.6]):
    import bless as bless
    from sklearn.gaussian_process.kernels import RBF
    from numpy import arange
    from sklearn.model_selection import RepeatedKFold
    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import f1_score
    from sklearn.model_selection import GridSearchCV
    from skopt import BayesSearchCV

    # If latent rep is provided, randomly sample data in spatially aware manner for initialisation
    r = np.random.RandomState(random_state)
    if train_x_partition in adata.obsm.keys():
        lvg = bless.bless(tune_train_x, RBF(length_scale=20), lam_final = 2, qbar = 2, random_state = r, H = 10, force_cpu=True)
    #     try:
    #         import cupy
    #         lvg_2 = bless(adata.obsm[train_x_partition], RBF(length_scale=10), 10, 10, r, 10, force_cpu=False)
    #     except ImportError:
    #         print("cupy not found, defaulting to numpy")
        adata_tuning = adata[lvg.idx]
        tune_train_x = adata_tuning.obsm[train_x_partition][:]
    else:
        print('no latent representation provided, random sampling instead')
        prop = 0.1
        random_vertices = []
        n_ixs = int(len(adata.obs) * prop)
        random_vertices = random.sample(list(range(len(adata.obs))), k=n_ixs)
        adata_tuning = adata[random_vertices]
        tune_train_x = adata_tuning.X
        
    if not train_labels == None:
        tune_train_label = adata_tuning.obs[train_labels]
    elif train_labels == None:
        try:
            print('no training labels provided, defaulting to unsuperived leiden clustering, updates will change this to voronoi greedy sampling')
            sc.tl.leiden(adata_tuning)
        except:
            print('no training labels provided, no neighbors, defaulting to unsuperived leiden clustering, updates will change this to voronoi greedy sampling')
            sc.pp.neighbors(adata_hm, n_neighbors=15, n_pcs=50)
            sc.tl.leiden(adata_tuning)
        tune_train_label = adata_tuning.obs['leiden']
    ## tune regularization for multinomial logistic regression
    print('starting tuning loops')
    X = tune_train_x
    y = tune_train_label
    grid = dict()
    # define model
    cv = RepeatedKFold(n_splits=n_splits, n_repeats=n_repeats, random_state=random_state)
    #model = LogisticRegression(penalty = penalty, max_iter =  200, dual=False,solver = 'saga', multi_class = 'multinomial',)
    model = LogisticRegression(penalty = penalty, C = sparcity, max_iter =  100, n_jobs=4)
    if (penalty == "l1"):
        model = LogisticRegression(penalty = penalty, C = sparcity, max_iter =  100, dual = True, solver = 'liblinear',multi_class = 'multinomial', n_jobs=4 ) # one-vs-rest
    if (penalty == "elasticnet"):
        model = LogisticRegression(penalty = penalty, C = sparcity, max_iter =  100, dual=False,solver = 'saga',l1_ratio=l1_ratio,multi_class = 'multinomial', n_jobs=4) # use multinomial class if probabilities are descrete
        grid['l1_ratio'] = l1_grid
    grid['C'] = c_grid
    
    if use_bayes_opt == True:
        # define search space
        search_space = {'C': (np.min(c_grid), np.max(c_grid), 'log-uniform'), 
                        'l1_ratio': (np.min(l1_grid), np.max(l1_grid), 'uniform') if 'elasticnet' in penalty else None}
        # define search
        search = BayesSearchCV(model, search_space, scoring='neg_mean_absolute_error', cv=cv, n_jobs=-1)
        # perform the search
        results = search.fit(X, y)
    else:
        # define search
        search = GridSearchCV(model, grid, scoring='neg_mean_absolute_error', cv=cv, n_jobs=-1)
        # perform the search
        results = search.fit(X, y)
    # summarize
    print('MAE: %.3f' % results.best_score_)
    print('Config: %s' % results.best_params_)
    return results

def prep_training_data(adata_temp,feat_use,batch_key, model_key, batch_correction=False, var_length = 7500,penalty='elasticnet',sparcity=0.2,max_iter = 200,l1_ratio = 0.1,partial_scale=True,train_x_partition ='X',theta = 3,tune_hyper_params=False ):
    model_name = model_key + '_lr_model'
    print('performing highly variable gene selection')
    sc.pp.highly_variable_genes(adata_temp, batch_key = batch_key, subset=False)
    adata_temp = subset_top_hvgs(adata_temp,var_length)
    #scale the input data
    if partial_scale == True:
        print('scaling input data, default option is to use incremental learning and fit in mini bulks!')
        # Partial scaling alg
        #adata_temp.X = (adata_temp.X)
        scaler = StandardScaler(with_mean=False)
        n = adata_temp.X.shape[0]  # number of rows
        # set dyn scale packet size
        x_len = len(adata_temp.var)
        y_len = len(adata_temp.obs)
        if y_len < 100000:
            dyn_pack = int(x_len/10)
            pack_size = dyn_pack
        else:
            # 10 pack for every 100,000
            dyn_pack = int((y_len/100000)*10)
            pack_size = int(x_len/dyn_pack)
        batch_size =  1000#pack_size#500  # number of rows in each call to partial_fit
        index = 0  # helper-var
        while index < n:
            partial_size = min(batch_size, n - index)  # needed because last loop is possibly incomplete
            partial_x = adata_temp.X[index:index+partial_size]
            scaler.partial_fit(partial_x)
            index += partial_size
        adata_temp.X = scaler.transform(adata_temp.X)
#     else:
#         sc.pp.scale(adata_temp, zero_center=True, max_value=None, copy=False, layer=None, obsm=None)
    if (train_x_partition != 'X') & (train_x_partition in adata_temp.obsm.keys()):
        print('train partition is not in OBSM, defaulting to PCA')
        # Now compute PCA
        sc.pp.pca(adata_temp, n_comps=100, use_highly_variable=True, svd_solver='arpack')
        sc.pl.pca_variance_ratio(adata_temp, log=True,n_pcs=100)
        
        # Batch correction options
        # The script will test later which Harmony values we should use 
        if(batch_correction == "Harmony"):
            print("Commencing harmony")
            adata_temp.obs['lr_batch'] = adata_temp.obs[batch_key]
            batch_var = "lr_batch"
            # Create hm subset
            adata_hm = adata_temp[:]
            # Set harmony variables
            data_mat = np.array(adata_hm.obsm["X_pca"])
            meta_data = adata_hm.obs
            vars_use = [batch_var]
            # Run Harmony
            ho = hm.run_harmony(data_mat, meta_data, vars_use,theta=theta)
            res = (pd.DataFrame(ho.Z_corr)).T
            res.columns = ['X{}'.format(i + 1) for i in range(res.shape[1])]
            # Insert coordinates back into object
            adata_hm.obsm["X_pca_back"]= adata_hm.obsm["X_pca"][:]
            adata_hm.obsm["X_pca"] = np.array(res)
            # Run neighbours
            #sc.pp.neighbors(adata_hm, n_neighbors=15, n_pcs=50)
            adata_temp = adata_hm[:]
            del adata_hm
        elif(batch_correction == "BBKNN"):
            print("Commencing BBKNN")
            sc.external.pp.bbknn(adata_temp, batch_key=batch_var, approx=True, metric='angular', copy=False, n_pcs=50, trim=None, n_trees=10, use_faiss=True, set_op_mix_ratio=1.0, local_connectivity=15) 
        print("adata1 and adata2 are now combined and preprocessed in 'adata' obj - success!")


    # train model
#    train_x = adata_temp.X
    #train_label = adata_temp.obs[feat_use]
    print('proceeding to train model')
    model = LR_train(adata_temp, train_x = train_x_partition, train_label=feat_use, penalty=penalty, sparcity=sparcity,max_iter=max_iter,l1_ratio = l1_ratio,tune_hyper_params = tune_hyper_params)
    model.features = list(adata_temp.var.index)
    return model

def compute_label_log_losses(df, true_label, pred_columns):
    """
    Compute log loss (cross-entropy loss).
    
    Parameters:
    df : dataframe containing the predicted probabilities and original labels as columns
    true_label : column or array-like of shape (n_samples,) containg cateogrical labels
    pred_columns : columns or array-like of shape (n_samples, n_clases) containg predicted probabilities

    converts to:
    y_true : array-like of shape (n_samples,) True labels. The binary labels in a one-vs-rest fashion.
    y_pred : array-like of shape (n_samples, n_classes) Predicted probabilities. 
        
    Returns:
    log_loss : dictionary of celltype key and float
    weights : float
    """
    log_losses = {}
    y_true = (pd.get_dummies(df[true_label]))
    y_pred = df[pred_col]
    loss = log_loss(np.array(y_true), np.array(y_pred))
    for label in range(y_true.shape[1]):
        log_loss_label = log_loss(np.array(y_true)[:, label], np.array(y_pred)[:, label])
        log_losses[list(y_true.columns)[label]] = (log_loss_label)
    weights = 1/np.array(list(log_losses.values()))
    weights /= np.sum(weights)
    weights = np.array(weights)
    return loss, log_losses, weights

def regression_results(df, true_label, pred_label, pred_columns):
    # Regression metrics
    y_true = df[true_label]
    y_pred = df[pred_label]
    loss, log_losses, weights = compute_label_log_losses(df, true_label, pred_columns)
    mean_absolute_error=metrics.mean_absolute_error(y_true, y_pred) 
    mse=metrics.mean_squared_error(y_true, y_pred) 
    mean_squared_log_error=metrics.mean_squared_log_error(y_true, y_pred)
    median_absolute_error=metrics.median_absolute_error(y_true, y_pred)
#     r2=metrics.r2_score(y_true, y_pred)
    print('Cross entropy loss: ', round(loss,4))    
    print('mean_squared_log_error: ', round(mean_squared_log_error,4))
    print('MAE: ', round(mean_absolute_error,4))
    print('MSE: ', round(mse,4))
    print('RMSE: ', round(np.sqrt(mse),4))
    print('label Cross entropy loss: ')
    print(log_losses)  
    return loss, log_losses, weights


# ENSDB-HGNC Option 1 
#from gseapy.parser import Biomart
#bm = Biomart()
## view validated marts#
#marts = bm.get_marts()
## view validated dataset
#datasets = bm.get_datasets(mart='ENSEMBL_MART_ENSEMBL')
## view validated attributes
#attrs = bm.get_attributes(dataset='hsapiens_gene_ensembl')
## view validated filters
#filters = bm.get_filters(dataset='hsapiens_gene_ensembl')
## query results
#queries = ['ENSG00000125285','ENSG00000182968'] # need to be a python list
#results = bm.query(dataset='hsapiens_gene_ensembl',
#                       attributes=['ensembl_gene_id', 'external_gene_name', 'entrezgene_id', 'go_id'],
#                       filters={'ensemble_gene_id': queries}
                      
# ENSDB-HGNC Option 2
def convert_hgnc(input_gene_list):
    import mygene
    mg = mygene.MyGeneInfo()
    geneList = input_gene_list 
    geneSyms = mg.querymany(geneList , scopes='ensembl.gene', fields='symbol', species='human')
    return(pd.DataFrame(geneSyms))
# Example use: convert_hgnc(['ENSG00000148795', 'ENSG00000165359', 'ENSG00000150676'])

# Scanpy_degs_to_long_format
def convert_scanpy_degs(input_dataframe):
    if 'concat' in locals() or 'concat' in globals():
        del(concat)
    degs = input_dataframe
    n = degs.loc[:, degs.columns.str.endswith("_n")]
    n = pd.melt(n)
    p = degs.loc[:, degs.columns.str.endswith("_p")]
    p = pd.melt(p)
    l = degs.loc[:, degs.columns.str.endswith("_l")]
    l = pd.melt(l)
    n = n.replace(regex=r'_n', value='')
    n = n.rename(columns={"variable": "cluster", "value": "gene"})
    p = (p.drop(["variable"],axis = 1)).rename(columns={ "value": "p_val"})
    l = (l.drop(["variable"],axis = 1)).rename(columns={ "value": "logfc"})
    return(pd.concat([n,p,l],axis=1))
#Usage: convert_scanpy_degs(scanpy_degs_file)

# Clean convert gene list to list
def as_gene_list(input_df,gene_col):
    gene_list = input_df[gene_col]
    glist = gene_list.squeeze().str.strip().tolist()
    return(glist)

# No ranking enrichr
def enrich_no_rank(input_gene_list,database,species="Human",description="enr_no_rank",outdir = "./enr",cutoff=0.5):
    # list, dataframe, series inputs are supported
    enr = gp.enrichr(gene_list=input_gene_list,
                     gene_sets=database,
                     organism=species, # don't forget to set organism to the one you desired! e.g. Yeast
                     #description=description,
                     outdir=outdir,
                     # no_plot=True,
                     cutoff=cutoff # test dataset, use lower value from range(0,1)
                    )
    return(enr)
    #Usge: enrich_no_rank(gene_as_list)
    
# Custom genelist test #input long format degs or dictionary of DEGS
def custom_local_GO_enrichment(input_gene_list,input_gmt,input_gmt_key_col,input_gmt_values,description="local_go",Background='hsapiens_gene_ensembl',Cutoff=0.5):
    
    #Check if GMT input is a dictionary or long-format input
    if isinstance(input_gmt, dict):
        print("input gmt is a dictionary, proceeding")
        dic = input_gmt
    else:
        print("input gmt is not a dictionary, if is pandas df,please ensure it is long-format proceeding to convert to dictionary")
        dic =  input_gmt.groupby([input_gmt_key_col])[input_gmt_values].agg(lambda grp: list(grp)).to_dict()
        
    enr_local = gp.enrichr(gene_list=input_gene_list,
                 description=description,
                 gene_sets=dic,
                 background=Background, # or the number of genes, e.g 20000
                 cutoff=Cutoff, # only used for testing.
                 verbose=True)
    return(enr_local)
    #Example_usage: custom_local_GO_enrichment(input_gene_list,input_gmt,input_gmt_key_col,input_gmt_values) #input gmt can be long-format genes and ontology name or can be dictionary of the same   

    
# Pre-ranked GS enrichment
def pre_ranked_enr(input_gene_list,gene_and_ranking_columns,database ='GO_Biological_Process_2018',permutation_num = 1000, outdir = "./enr_ranked",cutoff=0.25,min_s=5,max_s=1000):
    glist = input_gene_list[gene_and_ranking_columns]
    pre_res = gp.prerank(glist, gene_sets=database,
                     threads=4,
                     permutation_num=permutation_num, # reduce number to speed up testing
                     outdir=outdir,
                     seed=6,
                     min_size=min_s,
                     max_size=max_s)
    return(pre_res)
    #Example usage: pre_ranked_hyper_geom(DE, gene_and_ranking_columns = ["gene","logfc"],database=['KEGG_2016','GO_Biological_Process_2018'])

    
# GSEA module for permutation test of differentially regulated genes
# gene set enrichment analysis (GSEA), which scores ranked genes list (usually based on fold changes) and computes permutation test to check if a particular gene set is more present in the Up-regulated genes, 
# among the DOWN_regulated genes or not differentially regulated.
#NES = normalised enrichment scores accounting for geneset size
def permutation_ranked_enr(input_DE,cluster_1,cluster_2,input_DE_clust_col,input_ranking_col ,input_gene_col ,database = "GO_Biological_Process_2018"):
    input_DE = input_DE[input_DE[input_DE_clust_col].isin([cluster_1,cluster_2])]
    #Make set2 negative values to represent opposing condition
    input_DE[input_ranking_col].loc[input_DE[input_DE_clust_col].isin([cluster_2])] = -input_DE[input_ranking_col].loc[input_DE[input_DE_clust_col].isin([cluster_2])]
    enr_perm = pre_ranked_enr(input_DE,[input_gene_col,input_ranking_col],database,permutation_num = 100, outdir = "./enr_ranked_perm",cutoff=0.5)
    return(enr_perm)
    #Example usage:permutation_ranked_enr(input_DE = DE, cluster_1 = "BM",cluster_2 = "YS",input_DE_clust_col = "cluster",input_ranking_col = "logfc",input_gene_col = "gene",database = "GO_Biological_Process_2018")
    #input long-format list of genes and with a class for permutaion, the logfc ranking should have been derived at the same time


#Creating similarity matrix from nested gene lists
def create_sim_matrix_from_enr(input_df,nested_gene_column="Genes",seperator=";",term_col="Term"):
#    input_df[gene_col] = input_df[gene_col].astype(str)
#    input_df[gene_col] = input_df[gene_col].str.split(";")
#    uni_val = list(input_df.index.unique())
#    sim_mat = pd.DataFrame(index=uni_val, columns=uni_val)
#    exploded_df = input_df.explode(gene_col)
#    # Ugly loop for cosine gs similarity matrix (0-1)
#    for i in (uni_val):
#        row = exploded_df[exploded_df.index.isin([i])]
#        for z in (uni_val):
#            col = exploded_df[exploded_df.index.isin([z])]
#            col_ls = col[gene_col]
#            row_ls = row[gene_col]
#            sim = len(set(col_ls) & set(row_ls)) / float(len(set(col_ls) | set(row_ls)))
#            sim_mat.loc[i , z] = sim

#    Check term col in columns else, check index as it\s sometimes heree
    if not term_col in list(input_df.columns):
        input_df[term_col] = input_df.index

#    Create a similarity matrix by cosine similarity
    input_df = input_df.copy()
    gene_col = nested_gene_column #"ledge_genes"
    input_df[gene_col] = input_df[gene_col].astype(str)
    input_df[gene_col] = input_df[gene_col].str.split(seperator)
    term_vals = list(input_df[term_col].unique())
    uni_val = list(input_df[term_col].unique())
    sim_mat = pd.DataFrame(index=uni_val, columns=uni_val)
    exploded_df = input_df.explode(gene_col)
    arr = np.array(input_df[gene_col])
    vals = list(exploded_df[nested_gene_column])
    import scipy.sparse as sparse
    def binarise(sets, full_set):
        """Return sparse binary matrix of given sets."""
        return sparse.csr_matrix([[x in s for x in full_set] for s in sets])
    # Turn the matrix into a sparse boleen matrix of binarised values
    sparse_matrix = binarise(arr, vals)
    from sklearn.metrics.pairwise import cosine_similarity
    similarities = cosine_similarity(sparse_matrix)
    sim_mat = pd.DataFrame(similarities)
    sim_mat.index = uni_val
    sim_mat.columns = uni_val
    return(sim_mat)
#Example usage : sim_mat = create_sim_matrix_from_enr(enr.res2d)


#Creating similarity matrix from GO terms
def create_sim_matrix_from_term(input_df,nested_gene_column="Term",seperator=" ",term_col="Term"):

#    Check term col in columns else, check index as it\s sometimes heree
    if not term_col in list(input_df.columns):
        input_df[term_col] = input_df.index

#    Create a similarity matrix by cosine similarity
    input_df = input_df.copy()
    gene_col = nested_gene_column #"ledge_genes"
    #input_df[gene_col] = input_df[gene_col].astype(str)
    input_df[gene_col] = input_df[gene_col].str.split(seperator)
    term_vals = list(input_df[term_col].unique())
    uni_val = list(input_df.index.unique())
    sim_mat = pd.DataFrame(index=uni_val, columns=uni_val)
    exploded_df = input_df.explode(gene_col)
    arr = np.array(input_df[gene_col])
    vals = list(exploded_df[nested_gene_column])
    import scipy.sparse as sparse
    def binarise(sets, full_set):
        """Return sparse binary matrix of given sets."""
        return sparse.csr_matrix([[x in s for x in full_set] for s in sets])
    sparse_matrix = binarise(arr, vals)
    from sklearn.metrics.pairwise import cosine_similarity
    similarities = cosine_similarity(sparse_matrix)
    sim_mat = pd.DataFrame(similarities)
    sim_mat.index = uni_val
    sim_mat.columns = uni_val
    return(sim_mat)

#Creating similarity matrix from GO terms
def create_sim_matrix_from_term2(input_df,nested_gene_column="Term",seperator=" ",term_col="Term"):
#    Horrifically bad cosine similairty estimate for word frequency
#    Check term col in columns else, check index as it\s sometimes heree
    if not term_col in list(input_df.columns):
        input_df[term_col] = input_df.index
    input_df = input_df.copy()
    gene_col = nested_gene_column #"ledge_genes"
    #input_df[gene_col] = input_df[gene_col].astype(str)
    term_vals = list(input_df[term_col].unique())
    input_df[gene_col] = input_df[gene_col].str.split(seperator)
    uni_val = list(input_df.index.unique())
    sim_mat = pd.DataFrame(index=uni_val, columns=uni_val)
    exploded_df = input_df.explode(gene_col)

    nan_value = float("NaN")
    exploded_df.replace("", nan_value, inplace=True)
    exploded_df.dropna(subset = [gene_col], inplace=True)
    arr = np.array(input_df[gene_col])

    vals = list(exploded_df[nested_gene_column])

    import scipy.sparse as sparse
    def binarise(sets, full_set):
        """Return sparse binary matrix of given sets."""
        return sparse.csr_matrix([[x in s for x in full_set] for s in sets])
    sparse_matrix = binarise(arr, vals)
    from sklearn.metrics.pairwise import cosine_similarity
    similarities = cosine_similarity(sparse_matrix)
    sim_mat = pd.DataFrame(similarities)
    sim_mat.index = uni_val
    sim_mat.columns = uni_val
    return(sim_mat)
    #Example usage : sim_mat = create_sim_matrix_from_enr(enr.res2d)

In [395]:
l1_grid = [0.1,0.2,0.5,0.8]
np.min(l1_grid)

0.1

In [392]:
# define search space
search_space = {'C': (0.01, 1, 'log-uniform'), 
                'l1_ratio': (0.01, 1, 'uniform') if 'elasticnet' in penalty else None}
# define search
search = BayesSearchCV(model, search_space, scoring='neg_mean_absolute_error', cv=cv, n_jobs=-1)
# perform the search
results = search.fit(X, y)

In [11]:
adata =  load_adatas(adatas_dict, data_merge, adata_key,QC_normalise)

option to apply standardisation to data detected, performing basic QC filtering


In [12]:
adata.obs

Unnamed: 0_level_0,bh_doublet_pval,cell_caller,cluster_scrublet_score,doublet_pval,mt_prop,n_counts,n_genes,sanger_id,scrublet_score,chemistry,...,gender,pcw,sorting,sample,chemistry_sorting,annot,hierarchy1,joint_annotation_20220202,independent_annotation_refined_20220202,fig1b_annotation_20220202
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
AAACCTGGTCAGTGGA-1-4834STDY7002879,0.907861,Both,0.157082,0.500000,0.062532,5917.0,1776,4834STDY7002879,0.225806,SC3Pv2,...,male,8,CD45P,F16_male_8+1PCW,SC3Pv2_CD45P,fs_Macrophage,1.0,LYVE1++ macrophage,LYVE1++ macrophage,Macrophage
AAAGATGGTCGATTGT-1-4834STDY7002879,0.907861,Both,0.157082,0.500000,0.030894,10261.0,2750,4834STDY7002879,0.149606,SC3Pv2,...,male,8,CD45P,F16_male_8+1PCW,SC3Pv2_CD45P,fs_Monocyte,1.0,Monocyte (activated/differentiating),Monocyte (activated/differentiating),Monocyte
AAAGCAAAGATGTGGC-1-4834STDY7002879,0.882352,Both,0.225806,0.150885,0.012647,7749.0,2308,4834STDY7002879,0.201970,SC3Pv2,...,male,8,CD45P,F16_male_8+1PCW,SC3Pv2_CD45P,fs_Macrophage,1.0,LYVE1++ macrophage,LYVE1++ macrophage,Macrophage
AAAGTAGCAGATCGGA-1-4834STDY7002879,0.907861,Both,0.164557,0.455284,0.017443,14791.0,3099,4834STDY7002879,0.164557,SC3Pv2,...,male,8,CD45P,F16_male_8+1PCW,SC3Pv2_CD45P,fs_Mast cell,5.0,Eo/baso/mast cell progenitor,Eo/baso/mast cell progenitor,Progenitor
AAAGTAGTCCGCATCT-1-4834STDY7002879,0.882352,Both,0.201970,0.250000,0.041431,7434.0,2283,4834STDY7002879,0.181818,SC3Pv2,...,male,8,CD45P,F16_male_8+1PCW,SC3Pv2_CD45P,fs_Macrophage,1.0,LYVE1++ macrophage,LYVE1++ macrophage,Macrophage
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
TTTGTCAAGTGAACGC-1-FCAImmP7964510,0.851322,Both,0.109737,0.314770,0.027876,5094.0,1875,FCAImmP7964510,0.085937,SC5P-R2,...,female,14,CD45en,F71-GEX_5_SKI_45en,SC5P-R2_CD45en,,0.0,Pericytes,Pericytes,Mural cell
TTTGTCAGTGCGAAAC-1-FCAImmP7964510,0.851322,Both,0.125778,0.231265,0.038952,6906.0,2161,FCAImmP7964510,0.125778,SC5P-R2,...,female,14,CD45en,F71-GEX_5_SKI_45en,SC5P-R2_CD45en,,0.0,Pericytes,Pericytes,Mural cell
TTTGTCATCCATGAGT-1-FCAImmP7964510,0.851322,Both,0.024831,0.803208,0.012081,8526.0,658,FCAImmP7964510,0.034700,SC5P-R2,...,female,14,CD45en,F71-GEX_5_SKI_45en,SC5P-R2_CD45en,,8.0,Early erythroid,Early erythroid,Erythroid
TTTGTCATCGCAAGCC-1-FCAImmP7964510,0.851322,Both,0.044693,0.705651,0.053341,10911.0,3368,FCAImmP7964510,0.052632,SC5P-R2,...,female,14,CD45en,F71-GEX_5_SKI_45en,SC5P-R2_CD45en,,4.0,LE,LE,Lymphatic endothelium


# Read in query data for projection

In [None]:
if train_model == True:
    from sklearn.preprocessing import StandardScaler
    adata =  load_adatas(adatas_dict, data_merge, adata_key,QC_normalise)
    print('adata_loaded')
    import time
    t0 = time.time()
    display_cpu = DisplayCPU()
    display_cpu.start()
    try:
        model_trained = prep_training_data(feat_use = feat_use,
        adata_temp = adata,
        train_x_partition = train_x_partition,
        model_key = model_key + '_lr_model',
        batch_correction = 'Harmony',
        var_length = 7500,
        batch_key = 'donor',
        penalty='elasticnet', # can be ["l1","l2","elasticnet"],
        sparcity=sparcity, #If using LR without optimisation, this controls the sparsity in model
        max_iter = 1000, #Increase if experiencing max iter issues
        l1_ratio = l1_ratio, #If using elasticnet without optimisation, this controls the ratio between l1 and l2)
        partial_scale = False, #partial_scale,
        tune_hyper_params = True # Current implementation is very expensive, intentionally made rigid for now
        )
        filename =model_name
        pkl.dump(model_trained, open(filename, 'wb'))
    finally: #
        current, peak = display_cpu.stop()
        t1 = time.time()
        time_s = t1-t0
        print('training complete!')
        time.sleep(3)
        print('projection time was ' + str(time_s) + ' seconds')
        print(f"Current memory usage is {current / 10**6}MB; Peak was {peak / 10**6}MB")
        print(f"starting memory usage is" +'' + str(display_cpu.starting))
        print('peak CPU % usage = '+''+ str(display_cpu.peak_cpu))
        print('peak CPU % usage/core = '+''+ str(display_cpu.peak_cpu_per_core))
    model_lr= model_trained
    adata =  load_adatas(adatas_dict, data_merge, adata_key)
else:
    adata =  load_adatas(adatas_dict, data_merge, adata_key,QC_normalise)
    model = load_models(models,model_key)
    model_lr =  model
# run with usage logger
import time
t0 = time.time()
display_cpu = DisplayCPU()
display_cpu.start()
try: #code here ##
    pred_out,train_x,model_lr,adata_temp = reference_projection(adata, model_lr, dyn_std,partial_scale)
    if freq_redist != False:
        pred_out = freq_redist_68CI(adata,freq_redist)
        pred_out['orig_labels'] = adata.obs[freq_redist]
        adata.obs['consensus_clus_prediction'] = pred_out['consensus_clus_prediction']
    adata.obs['predicted'] = pred_out['predicted']
    adata_temp.obs = adata.obs
    
    # Estimate top model features for class descrimination
    feature_importance = estimate_important_features(model_lr, 100)
    mat = feature_importance.euler_pow_mat
    top_loadings = feature_importance.to_n_features_long
    
    # Estimate dataset specific feature impact
#     for classes in ['pDC precursor_ys_HL','AEC_ys_HL']:
#         model_class_feature_plots(top_loadings, [str(classes)], 'e^coef')
#         plt.show()
finally: #
    current, peak = display_cpu.stop()
t1 = time.time()
time_s = t1-t0
print('projection complete!')
time.sleep(3)
print('projection time was ' + str(time_s) + ' seconds')
print(f"Current memory usage is {current / 10**6}MB; Peak was {peak / 10**6}MB")
print(f"starting memory usage is" +'' + str(display_cpu.starting))
print('peak CPU % usage = '+''+ str(display_cpu.peak_cpu))
print('peak CPU % usage/core = '+''+ str(display_cpu.peak_cpu_per_core))

# regression summary
idx_map = dict(zip(  list(adata.obs[feat_use].unique()),list(range(0,len(list(adata.obs[feat_use].unique()))))))
regression_results(adata.obs[feat_use].map(idx_map), adata.obs['predicted'].map(idx_map))

In [14]:
pred_out

Unnamed: 0,predicted,ASDC,Adipocytes,Arterial,B cell,Basal,CD4 T cell,CD8 T cell,Capillary (venular tip),Capillary/postcapillary venule,...,Suprabasal IFE,TREM2+ macrophage,Tip cell (arterial),Treg,WNT2+ fibroblast,pDC,confident_calls,joint_annotation_20220202,consensus_clus_prediction,orig_labels
AAACCTGGTCAGTGGA-1-4834STDY7002879,LYVE1++ macrophage,3.121147e-08,2.942001e-08,5.584572e-09,5.133540e-10,3.132832e-09,2.463025e-10,7.582230e-10,4.643047e-07,1.149033e-08,...,8.762978e-08,1.004222e-04,1.232687e-06,1.720775e-09,2.245473e-12,2.170336e-07,LYVE1++ macrophage,LYVE1++ macrophage,LYVE1++ macrophage,LYVE1++ macrophage
AAAGATGGTCGATTGT-1-4834STDY7002879,Monocyte,8.021268e-05,1.136223e-09,8.952242e-09,7.051683e-09,8.112603e-09,1.610157e-11,6.927696e-07,4.949447e-11,1.130580e-08,...,5.018859e-09,1.562070e-07,8.123187e-09,1.345388e-06,2.181822e-15,9.321094e-07,Monocyte,Monocyte (activated/differentiating),Monocyte (activated/differentiating),Monocyte (activated/differentiating)
AAAGCAAAGATGTGGC-1-4834STDY7002879,LYVE1++ macrophage,4.388417e-06,2.971849e-11,7.536972e-09,2.891398e-09,3.302528e-09,1.790903e-11,3.914251e-10,1.995001e-08,2.178952e-07,...,2.377020e-08,2.140535e-04,1.696714e-07,1.792986e-08,4.258448e-15,1.052342e-06,LYVE1++ macrophage,LYVE1++ macrophage,LYVE1++ macrophage,LYVE1++ macrophage
AAAGTAGCAGATCGGA-1-4834STDY7002879,Eo/baso/mast cell progenitor,4.141562e-05,1.333213e-08,2.398914e-09,2.734619e-09,3.996713e-09,6.980059e-07,3.711175e-09,1.472187e-08,3.175144e-09,...,5.670171e-09,3.371311e-08,1.082911e-08,7.374140e-11,1.589416e-12,1.098155e-07,Eo/baso/mast cell progenitor,Eo/baso/mast cell progenitor,Eo/baso/mast cell progenitor,Eo/baso/mast cell progenitor
AAAGTAGTCCGCATCT-1-4834STDY7002879,LYVE1++ macrophage,8.686127e-08,3.933139e-10,1.384100e-07,2.455699e-10,2.286061e-08,1.976017e-12,5.729960e-11,5.051004e-08,6.295254e-09,...,1.143184e-08,1.045524e-04,2.108652e-08,1.197810e-09,2.699467e-14,4.509701e-09,LYVE1++ macrophage,LYVE1++ macrophage,LYVE1++ macrophage,LYVE1++ macrophage
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
TTTGTCAAGTGAACGC-1-FCAImmP7964510,Pericytes,8.901155e-08,3.238348e-06,8.571223e-06,1.924168e-08,3.902949e-07,1.215860e-06,5.714410e-07,3.698724e-04,1.986818e-07,...,4.664788e-07,3.326158e-07,1.879099e-07,4.614634e-08,1.732567e-17,3.238165e-07,Pericytes,Pericytes,Pericytes,Pericytes
TTTGTCAGTGCGAAAC-1-FCAImmP7964510,Pericytes,1.477536e-08,1.577190e-03,1.236769e-05,1.134152e-07,1.215027e-06,3.224812e-08,7.083651e-08,1.135012e-06,1.699295e-07,...,2.546215e-07,9.835366e-07,8.720760e-08,2.191711e-08,1.404760e-19,3.613958e-08,Pericytes,Pericytes,Pericytes,Pericytes
TTTGTCATCCATGAGT-1-FCAImmP7964510,Early erythroid,4.532674e-07,4.647544e-07,2.212564e-06,2.974463e-06,1.553380e-06,2.124671e-06,1.731951e-06,2.041023e-06,5.534186e-06,...,2.340904e-06,9.565093e-07,1.015629e-06,1.390459e-06,2.033909e-10,1.078917e-06,Early erythroid,Early erythroid,Early erythroid,Early erythroid
TTTGTCATCGCAAGCC-1-FCAImmP7964510,LE,2.256458e-07,2.015986e-06,1.393328e-01,6.879180e-10,4.974338e-07,2.818557e-11,2.924471e-08,4.774858e-06,7.361988e-04,...,7.698554e-08,7.992969e-09,3.424883e-05,8.645525e-09,9.385402e-16,2.426987e-09,LE,LE,LE,LE


# Label stability scoring for harmonisation
- here we present Log-loss function to compute

In [390]:
import numpy as np
from sklearn.metrics import log_loss
def compute_label_log_losses(df, true_label, pred_columns):
    """
    Compute log loss (cross-entropy loss).
    
    Parameters:
    df : dataframe containing the predicted probabilities and original labels as columns
    true_label : column or array-like of shape (n_samples,) containg cateogrical labels
    pred_columns : columns or array-like of shape (n_samples, n_clases) containg predicted probabilities

    converts to:
    y_true : array-like of shape (n_samples,) True labels. The binary labels in a one-vs-rest fashion.
    y_pred : array-like of shape (n_samples, n_classes) Predicted probabilities. 
        
    Returns:
    log_loss : dictionary of celltype key and float
    weights : float
    """
    log_losses = {}
    y_true = (pd.get_dummies(df[true_label]))
    y_pred = df[pred_col]
    loss = log_loss(np.array(y_true), np.array(y_pred))
    for label in range(y_true.shape[1]):
        log_loss_label = log_loss(np.array(y_true)[:, label], np.array(y_pred)[:, label])
        log_losses[list(y_true.columns)[label]] = (log_loss_label)
    weights = 1/np.array(list(log_losses.values()))
    weights /= np.sum(weights)
    weights = np.array(weights)
    return loss, log_losses, weights

def regression_results(df, true_label, pred_label, pred_columns):
    # Regression metrics
    y_true = df[true_label]
    y_pred = df[pred_label]
    loss, log_losses, weights = compute_label_log_losses(df, true_label, pred_columns)
    mean_absolute_error=metrics.mean_absolute_error(y_true, y_pred) 
    mse=metrics.mean_squared_error(y_true, y_pred) 
    mean_squared_log_error=metrics.mean_squared_log_error(y_true, y_pred)
    median_absolute_error=metrics.median_absolute_error(y_true, y_pred)
#     r2=metrics.r2_score(y_true, y_pred)
    print('Cross entropy loss: ', round(loss,4))    
    print('mean_squared_log_error: ', round(mean_squared_log_error,4))
    print('MAE: ', round(mean_absolute_error,4))
    print('MSE: ', round(mse,4))
    print('RMSE: ', round(np.sqrt(mse),4))
    print('label Cross entropy loss: ')
    print(log_losses)  
    return loss, log_losses, weights

In [377]:
# pred_col shape should match the pred_out original labels, so some self-projection works best here
pred_col = list(pred_out.columns[pred_out.columns.isin(set(pred_out['orig_labels']))])
log_losses, weights = compute_label_log_losses(pred_out, 'orig_labels', pred_col)

In [384]:
y_true = (pd.get_dummies(df[true_label]))
y_pred = df[pred_col]

In [385]:
from sklearn.metrics import log_loss
log_loss(np.array(y_true), np.array(y_pred))

15.211918064809838

In [379]:
log_losses

{'ASDC': 0.002166725283596194,
 'Iron-recycling macrophage': 0.17990448424623018,
 'Adipocytes': 0.2534050368123189,
 'Arterial': 0.028345456842149636,
 'B cell': 0.17682387615268078,
 'Basal': 0.11000984482862723,
 'POSTN+ basal': 0.033250798140136956,
 'CD4 T cell': 0.3915087085051776,
 'CD8 T cell': 0.14488004682886188,
 'Capillary (venular tip)': 0.04270657904168145,
 'Capillary/postcapillary venule': 0.05345640026131761,
 'Companion layer': 0.018185981129859276,
 'Cuticle/cortex': 0.033050182304249914,
 'DC1': 0.26097126075650373,
 'DC2': 0.44995462121622004,
 'Dermal condensate': 0.42455929943498627,
 'Dermal papillia': 0.3954969581789894,
 'Early LE': 0.017147790328675704,
 'Early endothelial cell': 0.014492674947992435,
 'Early erythroid': 0.04130081987719835,
 'Early erythroid (embryonic)': 0.03961955728936468,
 'FRZB+ early fibroblast': 1.2125858973233294,
 'HOXC5+ early fibroblast': 1.4776061019723772,
 'Early myocytes': 0.4424495807573584,
 'Eo/baso/mast cell progenitor': 0

In [373]:
weights

array([1.11354782e-01, 1.34112956e-03, 9.52132696e-04, 8.51195387e-03,
       1.36449458e-03, 2.19321481e-03, 7.25622344e-03, 6.16270381e-04,
       1.66534472e-03, 5.64960309e-03, 4.51349548e-03, 1.32670995e-02,
       7.30026899e-03, 9.24527936e-04, 5.36221231e-04, 5.68295692e-04,
       6.10055820e-04, 1.40703389e-02, 1.66480806e-02, 5.84189906e-03,
       6.08980103e-03, 1.98975777e-04, 1.63287916e-04, 5.45316871e-04,
       3.78550306e-04, 4.52545761e-03, 1.95971811e-03, 8.03457428e-04,
       1.14890768e-04, 2.27153322e-02, 7.73858752e-03, 2.60156033e-03,
       2.25208664e-02, 2.09835627e-01, 1.98591248e-02, 3.44956112e-03,
       1.01745341e-02, 1.35463479e-02, 5.53186420e-03, 4.30501286e-04,
       9.61903563e-03, 3.06123412e-02, 6.49679689e-02, 4.41962773e-03,
       1.27208985e-03, 1.80095227e-04, 1.35015779e-02, 3.14139239e-03,
       2.92026151e-03, 1.15456887e-02, 1.26397677e-02, 3.46800192e-03,
       3.29915941e-03, 1.51388687e-03, 6.93025378e-04, 6.97158895e-04,
      

In [151]:
log_losses

{'ASDC': 0.002166725283596194,
 'Iron-recycling macrophage': 0.17990448424623018,
 'Adipocytes': 0.2534050368123189,
 'Arterial': 0.028345456842149636,
 'B cell': 0.17682387615268078,
 'Basal': 0.11000984482862723,
 'POSTN+ basal': 0.033250798140136956,
 'CD4 T cell': 0.3915087085051776,
 'CD8 T cell': 0.14488004682886188,
 'Capillary (venular tip)': 0.04270657904168145,
 'Capillary/postcapillary venule': 0.05345640026131761,
 'Companion layer': 0.018185981129859276,
 'Cuticle/cortex': 0.033050182304249914,
 'DC1': 0.26097126075650373,
 'DC2': 0.44995462121622004,
 'Dermal condensate': 0.42455929943498627,
 'Dermal papillia': 0.3954969581789894,
 'Early LE': 0.017147790328675704,
 'Early endothelial cell': 0.014492674947992435,
 'Early erythroid': 0.04130081987719835,
 'Early erythroid (embryonic)': 0.03961955728936468,
 'FRZB+ early fibroblast': 1.2125858973233294,
 'HOXC5+ early fibroblast': 1.4776061019723772,
 'Early myocytes': 0.4424495807573584,
 'Eo/baso/mast cell progenitor': 0

In [None]:
# input
df_i = pred_out.copy() # df which contains proba, orig and pred labels
orig_labels = 'orig_labels'
pred_labels = 'predicted'

# function
class_loss = {}
proba = df_i[df_i.columns[(df_i.columns.isin(list(df_i[pred_labels])))]]
idx_map = dict(zip(list((set(list(df_i[orig_labels]) + list(df_i[pred_labels])))),list(range(0,len(list(set(list(df_i[orig_labels]) + list(df_i[pred_labels]))))))))
for class_n in list(set(df_i[orig_labels])):
    class_n_df = df_i.loc[df_i[orig_labels].isin([class_n]),[orig_labels,pred_labels]]
    class_n_proba = proba
    #class_R2[class_n] = regression_results(class_n_df[orig_labels].map(idx_map),class_n_df[pred_labels].map(idx_map))
    class_loss[class_n] = compute_log_loss()

In [75]:
# proba = df_i[df_i.columns[(df_i.columns.isin(list(df_i[pred_labels])))]]
# compute_log_loss()

In [None]:
proba

In [49]:
class_n_df = df_i.loc[df_i[orig_labels].isin([class_n]),[orig_labels,pred_labels]]
class_n_df[orig_labels].map(idx_map)

CTAACTTCACCGAAAG-1-FCAImmP7241241    82
ATCGAGTCAAGTTGTC-1-FCAImmP7528290    82
CAGCGACGTTCTCATT-1-FCAImmP7528291    82
TCATTACCAGCGAACA-1-FCAImmP7555848    82
TCGTAGATCTAACTGG-1-FCAImmP7579213    82
                                     ..
GTGCATATCCTTGCCA-1-FCAImmP7964510    82
GTGGGTCAGGGAACGG-1-FCAImmP7964510    82
GTGTTAGAGCAGCCTC-1-FCAImmP7964510    82
TACTTACAGGCGTACA-1-FCAImmP7964510    82
TGGGCGTAGTCGTTTG-1-FCAImmP7964510    82
Name: orig_labels, Length: 65, dtype: category
Categories (83, int64): [57, 23, 42, 49, ..., 79, 68, 9, 40]

In [17]:
# get regression results per class

AAACCTGGTCAGTGGA-1-4834STDY7002879              LYVE1++ macrophage
AAAGATGGTCGATTGT-1-4834STDY7002879                        Monocyte
AAAGCAAAGATGTGGC-1-4834STDY7002879              LYVE1++ macrophage
AAAGTAGCAGATCGGA-1-4834STDY7002879    Eo/baso/mast cell progenitor
AAAGTAGTCCGCATCT-1-4834STDY7002879              LYVE1++ macrophage
                                                  ...             
TTTGTCAAGTGAACGC-1-FCAImmP7964510                        Pericytes
TTTGTCAGTGCGAAAC-1-FCAImmP7964510                        Pericytes
TTTGTCATCCATGAGT-1-FCAImmP7964510                  Early erythroid
TTTGTCATCGCAAGCC-1-FCAImmP7964510                               LE
TTTGTCATCTGCTTGC-1-FCAImmP7964510                  Early erythroid
Name: predicted, Length: 186533, dtype: object

In [15]:
regression_results()

explained_variance:  0.8923
mean_squared_log_error:  0.1743
r2:  0.8898
MAE:  1.5957
MSE:  50.5453
RMSE:  7.1095


In [None]:
# regression summary
idx_map = dict(zip(  list(adata.obs[feat_use].unique()),list(range(0,len(list(adata.obs[feat_use].unique()))))))
regression_results(adata.obs[feat_use].map(idx_map), adata.obs['predicted'].map(idx_map))

In [None]:
adata.obs['confident_calls'] = pred_out['confident_calls']
adata.obs[cluster_prediction] = adata.obs.index
for z in adata.obs[clusters_reassign].unique():
    df = adata.obs
    df = df[(df[clusters_reassign].isin([z]))]
    df_count = pd.DataFrame(df[lr_predicted_col].value_counts())
    freq_arranged = df_count.index
    cat = freq_arranged[0]
    df.loc[:,cluster_prediction] = cat
    adata.obs.loc[adata.obs[clusters_reassign] == z, [cluster_prediction]] = cat

# View by median probabilities per classification

In [None]:
model_mean_probs = pred_out.loc[:, pred_out.columns != 'predicted'].groupby('orig_labels').median()
model_mean_probs = model_mean_probs*100
model_mean_probs = model_mean_probs.dropna(axis=0, how='any', thresh=None, subset=None, inplace=False)
crs_tbl = model_mean_probs.copy()
# Sort df columns by rows
index_order = list(crs_tbl.max(axis=1).sort_values(ascending=False).index)
col_order = list(crs_tbl.max(axis=0).sort_values(ascending=False).index)
crs_tbl = crs_tbl.loc[index_order]
crs_tbl = crs_tbl[col_order]
# Plot_df_heatmap(crs_tbl, cmap='coolwarm', rotation=90, vmin=20, vmax=70)
pal = sns.diverging_palette(240, 10, n=10)
plt.figure(figsize=(20,15))
sns.set(font_scale=0.5)
g = sns.heatmap(crs_tbl, cmap='viridis_r',  annot=False,vmin=0, vmax=max(np.max(crs_tbl)), linewidths=1, center=max(np.max(crs_tbl))/2, square=True, cbar_kws={"shrink": 0.5})

plt.ylabel("Original labels")
plt.xlabel("Training labels")
#plt.savefig('./ldvae_ver5_lr_model_means_subclusters.pdf',dpi=300)
plt.show()

# View by label assignment

In [None]:
x=feat_use
y = 'predicted'

y_attr = adata_temp.obs[y]
x_attr = adata_temp.obs[x]
crs = pd.crosstab(x_attr, y_attr)
crs_tbl = crs
for col in crs_tbl :
    crs_tbl[col] = crs_tbl[col].div(crs_tbl[col].sum(axis=0)).multiply(100).round(2)
    
index_order = list(crs_tbl.max(axis=1).sort_values(ascending=False).index)
col_order = list(crs_tbl.max(axis=0).sort_values(ascending=False).index)
crs_tbl = crs_tbl.loc[index_order]
crs_tbl = crs_tbl[col_order]

#plot_df_heatmap(crs_tbl, cmap='coolwarm', rotation=90, vmin=20, vmax=70)
pal = sns.diverging_palette(240, 10, n=10)
plt.figure(figsize=(20,15))
sns.set(font_scale=0.8)
g = sns.heatmap(crs_tbl, cmap='viridis_r', vmin=0, vmax=100, linewidths=1, center=50, square=True, cbar_kws={"shrink": 0.3})
plt.xlabel("Original labels")
plt.ylabel("Predicted labels")
# plt.savefig(save_path + "/LR_predictions_consensus.pdf")
# crs_tbl.to_csv(save_path + "/post-freq_LR_predictions_consensus_supp_table.csv")
plt.show()

In [None]:
x='consensus_clus_prediction'
y = 'predicted'

y_attr = adata_temp.obs[y]
x_attr = adata_temp.obs[x]
crs = pd.crosstab(x_attr, y_attr)
crs_tbl = crs
for col in crs_tbl :
    crs_tbl[col] = crs_tbl[col].div(crs_tbl[col].sum(axis=0)).multiply(100).round(2)
    
index_order = list(crs_tbl.max(axis=1).sort_values(ascending=False).index)
col_order = list(crs_tbl.max(axis=0).sort_values(ascending=False).index)
crs_tbl = crs_tbl.loc[index_order]
crs_tbl = crs_tbl[col_order]

#plot_df_heatmap(crs_tbl, cmap='coolwarm', rotation=90, vmin=20, vmax=70)
pal = sns.diverging_palette(240, 10, n=10)
plt.figure(figsize=(20,15))
sns.set(font_scale=0.8)
g = sns.heatmap(crs_tbl, cmap='viridis_r', vmin=0, vmax=100, linewidths=1, center=50, square=True, cbar_kws={"shrink": 0.3})
plt.xlabel("Original labels")
plt.ylabel("Predicted labels")
# plt.savefig(save_path + "/LR_predictions_consensus.pdf")
# crs_tbl.to_csv(save_path + "/post-freq_LR_predictions_consensus_supp_table.csv")
plt.show()

# Save predicted output

In [None]:
pred_out.to_csv('./A1_V3_sk_sk_pred_outs.csv')

In [None]:
# # filter unlikely predictions
# filtered = pred_out[np.max(pred_out.loc[:,~pred_out.columns.isin(['predicted','confident_calls','annot_celltype', 'consensus_clus_prediction', 'orig_labels','clus_prediction_confident'])],axis = 1) > 0.3]
# adata_temp = adata[adata.obs.index.isin(filtered.index)]
# filtered['clus_prediction_confident'] = adata_temp.obs['clus_prediction_confident']

 # Significant contributors to feature effect size per class of model
     - Bear in mind these are only top features..
    - assess the positive descriminators (markers) of the model
    - “…provide information about the magnitude and direction of the difference between two groups or the relationship between two variables.”

In [None]:
list(top_loadings['class'].unique())

In [None]:
top_loadings[top_loadings['class'].isin(['Tip cell (arterial)','HSC','SPP1+ proliferating neuron proneitors'])].groupby(['class']).head(10)

In [None]:
for classes in ['Tip cell (arterial)','HSC','SPP1+ proliferating neuron proneitors']:
    model_class_feature_plots(top_loadings, [str(classes)], 'e^coef')
    plt.show()

In [None]:
plt.rcdefaults()
# plot_states = ['Tip cell (arterial)','HSC','SPP1+ proliferating neuron proneitors']
markers = top_loadings[top_loadings['class'].isin(adata_temp.obs['consensus_clus_prediction'])].groupby(['class']).head(5).groupby(['class'])['feature'].agg(lambda grp: list(grp)).to_dict()
sc.pl.dotplot(adata_temp, groupby = 'consensus_clus_prediction', var_names = markers,standard_scale='var')

In [None]:
top_loadings[top_loadings['class'].isin(['Lymphoid progenitor','Early erythroid (embryonic)','Pre-dermal condensate'])].groupby(['class']).head(15)

# Label confidence scoring
- Let's study label stability given K-neighborhood assignments

**Author notes:** 
-  Hey! If you're reading this, I've probably messed up somewhere and you're looking for an explanation why :) 
- Code blocks marked **Prototype** are usually incomplete or a irresponsible lift from another pipeline, if the source pipeline is already distributed/published, I will leave git links associated with the module.
- If there are no links, there should be some run notes

**Run mode 2 of prototype $alpha$ $beta$ sampling via leverage-score**
- Mode 2 was chosen as we want to define a sampling space which satisfies same KNN distribution and density instead of prioritising variability
- Neighborhood assignment is done via majority voting
- Posterior probability computed and sampling rate for X is determined

# Bayesian KNN label stability
For modelling label uncertainty given neighborhood membership and distances

#### Step 1: Generate Binary Neighborhood Membership Matrix
The first step is to generate a binary neighborhood membership matrix from the connectivity matrix. This is done with the function get_binary_neigh_matrix(connectivities), which takes a connectivity matrix as input and outputs a binary matrix indicating whether a cell is a neighbor of another cell.

The connectivity matrix represents the neighborhood relationships between cells, typically obtained from KNN analysis. In this matrix, each row and column represent a cell, and an entry indicates the 'connectivity' between the corresponding cells.

The function transforms the connectivity matrix into a binary matrix by setting all non-zero values to 1, indicating a neighborhood relationship, and all zero values remain as 0, indicating no neighborhood relationship.

#### Step 2: Calculate Label Counts
Next, the function get_label_counts(neigh_matrix, labels) is used to count the number of occurrences of each label in the neighborhood of each cell. The input to this function is the binary neighborhood membership matrix and a list of labels for each cell.

The function returns a matrix in which each row corresponds to a cell, and each column corresponds to a label. Each entry is the count of cells of a particular label in the neighborhood of a given cell.

#### Step 3: Compute Distance-Entropy Product
In the third step, the function compute_dist_entropy_product(neigh_membership, labels, dist_matrix) computes the product of the average neighborhood distance and the entropy of the label distribution in the neighborhood for each cell and each label.

The entropy of a label distribution in a neighborhood is a measure of the diversity or 'mix' of labels in that neighborhood, with higher entropy indicating a more diverse mix of labels. The average neighborhood distance for a cell is the average distance from that cell to all other cells in its neighborhood.

By multiplying the entropy with the average distance, this function captures two important aspects of the neighborhood:

Entropy: The diversity of labels in a neighborhood. High entropy means the neighborhood is a 'melting pot' of many different labels, while low entropy indicates a neighborhood dominated by a single label.
Distance: The spatial proximity of cells in a neighborhood. A high average distance means the cells in a neighborhood are widely dispersed, while a low average distance indicates a compact, closely-knit neighborhood.
Thus, the distance-entropy product for a cell provides a measure of the 'stability' of the cell's label, with lower values indicating a stable, consistent label and higher values indicating an unstable, inconsistent label.

#### Step 4: Bayesian Sampling and Weight Calculation
The final step is the compute_weights function, which uses Bayesian inference to compute a posterior distribution of the distance-entropy product for each label and calculates the weights.

In Bayesian inference, we start with a prior distribution that represents our initial belief about the parameter we're interested in, and we update this belief using observed data to get a posterior distribution.

In this case, the prior distribution is a normal distribution with mean and standard deviation equal to the mean and standard deviation of the distance-entropy product for the original labels. The observed data is the distance-entropy product for the predicted labels. A normal distribution is a reasonable choice for the prior because the distance-entropy product is a continuous variable that can theoretically take on any real value, and the normal distribution is the most common distribution for such variables.

After sampling from the posterior distribution, the weight for each label is calculated as one minus the ratio of the standard deviation of the posterior distribution to the maximum standard deviation across all labels. This means that labels with a larger standard deviation (indicating greater uncertainty about their stability) will have smaller weights, and labels with a smaller standard deviation (indicating less uncertainty) will have larger weights.

The weights are returned as a dictionary where each key-value pair corresponds to a label and its weight.

#### Step 5: Apply Weights to Probabilities
Finally, the weights are applied to the probability dataframe with the function apply_weights(prob_df, weights). The input to this function is a dataframe where each row corresponds to a cell and each column corresponds to a label, with each entry being the probability of the cell being of the label, and a dictionary of weights.

This function multiplies each column of the probability dataframe by the corresponding weight, effectively 'boosting' the probabilities of labels with larger weights and 'penalizing' the probabilities of labels with smaller weights. After applying the weights, the function normalizes the probabilities so that they sum to 1 for each cell, returning a dataframe of the same shape as the input but with the probabilities weighted and normalized.

Overall, this method provides a principled way to quantify label uncertainty and adjust the probabilities output by a logistic regression model accordingly. It combines the strengths of KNN, which can capture local structure and relationships in the data, and Bayesian inference, which provides a robust framework for dealing with uncertainty and incorporating prior knowledge. By weighting the probabilities according to the stability of the labels, this method can potentially improve the accuracy and interpretability of the logistic regression model's predictions.

In [365]:
import numpy as np
import pandas as pd
import pymc3 as pm
from scipy.sparse import csr_matrix
from scipy.stats import entropy

def get_binary_neigh_matrix(connectivities):
    """
    Converts the connectivities matrix to a binary neighborhood membership matrix.
    """
    return (connectivities > 0).astype(int)

def get_label_counts(neigh_matrix, labels):
    """
    Counts the number of occurrences of each label in the neighborhood of each cell.
    """
    return pd.DataFrame(neigh_matrix.T.dot(pd.get_dummies(labels)))

def compute_dist_entropy_product(neigh_membership, labels, dist_matrix):
    """
    Computes the product of the average neighborhood distance and the entropy
    of the label distribution in the neighborhood for each cell and each label.
    """
    # Count the occurrences of each label in the neighborhood of each cell
    label_counts = get_label_counts(neigh_membership, labels)

    # Compute the entropy of the label distribution in the neighborhood of each cell
    entropy_values = label_counts.apply(entropy, axis=1)

    # Compute the average neighborhood distance for each cell
    avg_distances = dist_matrix.multiply(neigh_membership).mean(axis=1).A1

    # Compute the product of the average distance and the entropy for each cell
    dist_entropy_product = avg_distances * entropy_values

    return dist_entropy_product

class WeightsOutput:
    def __init__(self, weights, rhats, means, sds):
        self.weights = weights
        self.rhats = rhats
        self.means = means
        self.sds = sds

def compute_weights(adata, use_rep, original_labels_col, predicted_labels_col):
    # Extract the necessary data from the anndata object
    obs_met = adata.obs
    neigh_membership = get_binary_neigh_matrix(adata.obsp[adata.uns[use_rep]['connectivities_key']])
    original_labels = obs_met[original_labels_col]
    predicted_labels = obs_met[predicted_labels_col]
    dist_matrix = adata.obsp[adata.uns[use_rep]['distances_key']]

    # Compute the 'distance-entropy' product for each cell and each label
    dist_entropy_product = compute_dist_entropy_product(neigh_membership, predicted_labels, dist_matrix)

    # Compute the 'distance-entropy' product for the original labels
    dist_entropy_product_orig = compute_dist_entropy_product(neigh_membership, original_labels, dist_matrix)

    weights = {}
    rhat_values = {}
    means = []  # Collect all posterior means
    sds = []  # Collect all posterior standard deviations
    for label in np.unique(predicted_labels):
        print("Sampling {} posterior distribution".format(label))
        # Perform Bayesian inference to compute the posterior distribution of the
        # 'distance-entropy' product for this label
        orig_pos = obs_met[original_labels_col].isin([label])
        pred_pos = obs_met[predicted_labels_col].isin([label])
        with pm.Model() as model:
            #priors
            mu = pm.Normal('mu', mu=dist_entropy_product_orig[orig_pos.values].mean(), sd=dist_entropy_product_orig[orig_pos.values].std())
            sd = pm.HalfNormal('sd', sd=dist_entropy_product_orig[orig_pos.values].std())
            #observations
            obs = pm.Normal('obs', mu=mu, sd=sd, observed=dist_entropy_product_orig[pred_pos.values])
            
            if len(orig_pos) > 10000:
                samp_rate = 0.1
                smp = int(len(orig_pos)*samp_rate)
                tne = smp = int(len(orig_pos)*samp_rate)/2
                trace = pm.sample(smp, tune=tne)
            else:
                trace = pm.sample(1000, tune=500)
        # Compute R-hat for this label
        rhat = pm.rhat(trace)
        rhat_values[label] = {var: rhat[var].data for var in rhat.variables}
        # Compute the mean and the standard deviation of the posterior distribution for this label
        mean_posterior = pm.summary(trace)['mean']['mu']
        sd_posterior = pm.summary(trace)['sd']['sd']
        sds.append(sd_posterior)
        means.append(mean_posterior)
        
    # Mean posterior probabilitty models the stability of a label given entropy_distance measures within it's neighborhood
    max_mean = max(means)
    # SD here models the uncertainty of label entropy_distance measures
    max_sd = max(sds)  # Compute the maximum standard deviation
    
    # Compute the weights as the sum of the normalized mean and the normalized standard deviation. This makes each weight relative to each other
    # shift all weights up by epiislon constant
    epsilon = 0.01
    for label, mean, sd in zip(np.unique(predicted_labels), means, sds):
        weights[label] = (1 - mean / max_mean) * (1 - sd / max_sd) + epsilon

    return WeightsOutput(weights, rhat_values, means, sds)

def apply_weights(prob_df, weights):
    """
    Applies the computed weights to the probability dataframe and normalizes the result.
    Parameters:
    prob_df (pd.DataFrame): A dataframe where each row corresponds to a cell and each column corresponds to a label. Each entry is the probability of the cell being of the label.
    weights (dict): A dictionary where each key-value pair corresponds to a label and its weight.

    Returns:
    norm_df (pd.DataFrame): A dataframe of the same shape as prob_df, but with the probabilities weighted and normalized.
    """
    # Apply the weights
    weighted_df = prob_df.mul(weights.weights)
    # Normalize the result
    norm_df = weighted_df.div(weighted_df.sum(axis=1), axis=0)
    return norm_df

weights = compute_weights(adata,use_rep = 'neighbors', original_labels_col ='cell.labels', predicted_labels_col = 'cell.labels')
apply_weights(adata.obsm['cell.labels'],weights)

In [354]:
weights = compute_weights(adata,use_rep = 'neighbors', original_labels_col ='cell.labels', predicted_labels_col = 'cell.labels')
apply_weights(adata.obsm['cell.labels'],weights)

Sampling CMP posterior distribution


  return wrapped_(*args_, **kwargs_)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [sd, mu]


Sampling 4 chains for 500 tune and 1_000 draw iterations (2_000 + 4_000 draws total) took 2 seconds.
The acceptance probability does not match the target. It is 0.8882359954302748, but should be close to 0.8. Try to increase the number of tuning steps.
The acceptance probability does not match the target. It is 0.9152261534843218, but should be close to 0.8. Try to increase the number of tuning steps.
The acceptance probability does not match the target. It is 0.8984155277447178, but should be close to 0.8. Try to increase the number of tuning steps.
Got error No model on context stack. trying to find log_likelihood in translation.
Got error No model on context stack. trying to find log_likelihood in translation.
Got error No model on context stack. trying to find log_likelihood in translation.


Sampling HSPC_1 posterior distribution


  return wrapped_(*args_, **kwargs_)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [sd, mu]


Sampling 4 chains for 500 tune and 1_000 draw iterations (2_000 + 4_000 draws total) took 7 seconds.
The acceptance probability does not match the target. It is 0.948912690284732, but should be close to 0.8. Try to increase the number of tuning steps.
The acceptance probability does not match the target. It is 0.9932458013873032, but should be close to 0.8. Try to increase the number of tuning steps.
The acceptance probability does not match the target. It is 0.8991231533910644, but should be close to 0.8. Try to increase the number of tuning steps.
The acceptance probability does not match the target. It is 0.8955961583272648, but should be close to 0.8. Try to increase the number of tuning steps.
Got error No model on context stack. trying to find log_likelihood in translation.
Got error No model on context stack. trying to find log_likelihood in translation.
Got error No model on context stack. trying to find log_likelihood in translation.


Sampling HSPC_2 posterior distribution


  return wrapped_(*args_, **kwargs_)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [sd, mu]


Sampling 4 chains for 500 tune and 1_000 draw iterations (2_000 + 4_000 draws total) took 2 seconds.
The acceptance probability does not match the target. It is 0.8822058084837736, but should be close to 0.8. Try to increase the number of tuning steps.
The acceptance probability does not match the target. It is 0.890976176435348, but should be close to 0.8. Try to increase the number of tuning steps.
The acceptance probability does not match the target. It is 0.9033886891848927, but should be close to 0.8. Try to increase the number of tuning steps.
The acceptance probability does not match the target. It is 0.9106460334686525, but should be close to 0.8. Try to increase the number of tuning steps.
Got error No model on context stack. trying to find log_likelihood in translation.
Got error No model on context stack. trying to find log_likelihood in translation.
Got error No model on context stack. trying to find log_likelihood in translation.


Sampling MOP posterior distribution


  return wrapped_(*args_, **kwargs_)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [sd, mu]


Sampling 4 chains for 500 tune and 1_000 draw iterations (2_000 + 4_000 draws total) took 2 seconds.
The acceptance probability does not match the target. It is 0.8925428920498023, but should be close to 0.8. Try to increase the number of tuning steps.
Got error No model on context stack. trying to find log_likelihood in translation.
Got error No model on context stack. trying to find log_likelihood in translation.
Got error No model on context stack. trying to find log_likelihood in translation.


Sampling Macrophage posterior distribution


  return wrapped_(*args_, **kwargs_)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [sd, mu]


Sampling 4 chains for 500 tune and 1_000 draw iterations (2_000 + 4_000 draws total) took 4 seconds.
The acceptance probability does not match the target. It is 0.9146022607460513, but should be close to 0.8. Try to increase the number of tuning steps.
The acceptance probability does not match the target. It is 0.9775637974453769, but should be close to 0.8. Try to increase the number of tuning steps.
The acceptance probability does not match the target. It is 0.8968273691723939, but should be close to 0.8. Try to increase the number of tuning steps.
The acceptance probability does not match the target. It is 0.9335216100559266, but should be close to 0.8. Try to increase the number of tuning steps.
Got error No model on context stack. trying to find log_likelihood in translation.
Got error No model on context stack. trying to find log_likelihood in translation.
Got error No model on context stack. trying to find log_likelihood in translation.


Sampling Pre_Macrophage posterior distribution


  return wrapped_(*args_, **kwargs_)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [sd, mu]


Sampling 4 chains for 500 tune and 1_000 draw iterations (2_000 + 4_000 draws total) took 2 seconds.
The acceptance probability does not match the target. It is 0.8793731411008892, but should be close to 0.8. Try to increase the number of tuning steps.
The acceptance probability does not match the target. It is 0.9069842863259212, but should be close to 0.8. Try to increase the number of tuning steps.
The acceptance probability does not match the target. It is 0.8885530536399899, but should be close to 0.8. Try to increase the number of tuning steps.
Got error No model on context stack. trying to find log_likelihood in translation.
Got error No model on context stack. trying to find log_likelihood in translation.
Got error No model on context stack. trying to find log_likelihood in translation.


The proposed method is a Bayesian Gaussian Process (GP) model that uses prior information from a probabilistic classifier and a K-nearest neighbors (KNN) graph to assess label stability in single-cell data. This approach takes advantage of the strengths of both the Bayesian statistical framework and machine learning to provide a flexible and robust way to evaluate label stability in high-dimensional data, such as single-cell transcriptomics.

The Bayesian GP model is designed to model the label stability as a latent function over the single-cell data, which is influenced by both the original labels and the predicted labels from the probabilistic classifier. The model uses a Gaussian Process to represent the latent function, which is a flexible non-parametric model that can capture complex patterns in the data. The Gaussian Process is defined by a covariance function, which determines the similarity between the cells based on their features.

The covariance function used in the model is the exponential quadratic function, also known as the squared exponential or Gaussian kernel. This function is a popular choice in Gaussian Process models due to its flexibility and its ability to model a wide range of patterns in the data. It is defined as:

$$[k(x, x') = \sigma^2 \exp\left(-\frac{(x - x')^2}{2l^2}\right)$$

where \(x\) and \(x'\) are two cells, \(\sigma^2\) is the signal variance (which determines the average distance of the function away from its mean), and \(l\) is the length-scale (which determines the smoothness of the function).

The length-scale and the noise term are hyperparameters of the model, which are learned from the data using Bayesian inference. The length-scale is modeled as a Gamma distribution, which is a common choice for positive-valued parameters. The noise term is modeled as a Half-Cauchy distribution, which is a heavy-tailed distribution that is often used for scale parameters in Bayesian models.

The model uses the log-odds of the labels as the observable variable, which is a common choice for binary or multi-class classification problems. The log-odds is defined as the logarithm of the odds ratio, which is the ratio of the probability of the label being true to the probability of the label being false. For a binary classification problem, the log-odds is defined as:


$$\log\left(\frac{p}{1 - p}\right)\$$

where \(p\) is the probability of the label being true.

The strength of this approach lies in its ability to model the uncertainty in the label assignments and to incorporate prior information from the probabilistic classifier and the KNN graph. By using Bayesian inference, the model is able to provide a measure of uncertainty for the label stability, which can be used to assess the reliability of the label assignments.

This model is designed to output a measure of label stability for each cell, which can be interpreted as the likelihood that a cell forms consistent neighborhoods in the KNN graph. This measure of label stability can then be used as weights for the output of the probabilistic classifier, to prioritize labels that are more likely to form consistent neighborhoods.

In terms of feasibility, this approach requires the computation of the KNN graph and the inference of the Bayesian GP model, which can be computationally intensive for large datasets. However, the use of modern Bayesian computational techniques, such as variational inference and Markov chain Monte Carlo (MCMC) methods, can make this approach feasible for large single-cell datasets. Furthermore, the model's output can be easily integrated with the output of the probabilistic classifier, providing a way to improve the classifier's performance by taking into account the label stability in the data.