In [None]:
# Setup: working directory, imports and global plotting/warning defaults used by every later cell.
import os
# Work from the project root. Later cells use relative paths (the input pickle name passed to
# load_from_pickle and the './….h5ad' output), which are resolved from here. Presumably this is
# also what makes the local `scripts` package importable below, so keep it before that import.
# NOTE(review): hard-coded absolute path; consider a configurable project-root constant.
os.chdir('/lustre/scratch/kiviaho/prostate_spatial')

import numpy as np
import anndata as ad
import scanpy as sc
import pandas as pd
import infercnvpy as cnv
import matplotlib.pyplot as plt
from pathlib import Path
# Project-local pickle helpers (resolved relative to the directory chosen above)
from scripts.utils import load_from_pickle, save_to_pickle
from scipy.stats import chi2_contingency
import warnings
# NOTE(review): this silences *all* warnings, including library deprecation warnings that
# would flag behaviour changes in later cells; consider narrowing it to specific categories.
warnings.filterwarnings('ignore')

from matplotlib import rcParams
# Global scanpy/matplotlib figure defaults; later cells change them again (rcParams, set_figure_params)
sc.set_figure_params(figsize=(6,6),dpi=80)

In [None]:

def add_cluster_phenotype_column(adata,cluster_col='VI_clusters',aggregate_col='phenotype',added_col='cluster_phenotype'):
    '''
    Chi-squared test-based annotation of clusters.

    For one cluster at a time, the cells are split into "in this cluster" vs. "all other cells".
    That split is cross-tabulated against the levels of `aggregate_col` and tested with a
    chi-squared test of independence. If p < 0.05, the cluster is relabelled
    '<phenotype>_<cluster>'. Here <phenotype> is the level with the largest positive difference
    between observed and expected counts inside the cluster. Non-significant clusters keep their
    original label.

    Parameters
    ----------
    adata : AnnData-like object exposing a pandas DataFrame as `.obs`; `.obs` is replaced.
    cluster_col : obs column with cluster ids (categorical, or any values castable to categorical).
    aggregate_col : obs column with the phenotype/condition of each cell.
    added_col : name of the categorical obs column that receives the labels.

    Returns
    -------
    The same `adata` object, with `added_col` added to `.obs`.
    '''
    meta = adata.obs.copy()

    clusters = meta[cluster_col]
    if not isinstance(clusters.dtype, pd.CategoricalDtype):
        # The label column is manipulated through the .cat accessor below
        clusters = clusters.astype('category')
    meta[added_col] = clusters.copy()

    for cluster in np.unique(clusters):
        # Boolean "in this cluster" row variable: no temporary obs column is created (so an
        # existing column is never clobbered) and a cluster literally named 'other' is harmless.
        in_cluster = clusters == cluster
        tbl = pd.crosstab(in_cluster, meta[aggregate_col])
        # Tuple-unpacked rather than .pvalue/.expected_freq so older scipy releases work too
        _, pvalue, _, expected = chi2_contingency(tbl)
        if pvalue < 0.05:
            diff = tbl - expected
            phtype = diff.loc[True].idxmax()
            # f-string so non-string phenotype/cluster ids (e.g. integers) are accepted
            new_label = f'{phtype}_{cluster}'
            meta[added_col] = meta[added_col].cat.add_categories(new_label)
            # Select by the original cluster ids, never by the (partially relabelled) label column
            meta.loc[in_cluster, added_col] = new_label

    meta[added_col] = meta[added_col].cat.remove_unused_categories()
    adata.obs = meta
    return adata
    
def plot_stacked_bar(data,sum_variable='phenotype',plot_variable='VI_clusters',filter_kw='',plot_legend=False):
    '''
    Horizontal stacked bar chart of the composition of each `plot_variable` level.

    Each bar is one level of `plot_variable` (e.g. a cluster). Its segments are the fractions of
    that level's cells falling into each level of `sum_variable` (e.g. a phenotype), so every bar
    sums to 1.

    Parameters
    ----------
    data : AnnData-like object exposing a pandas DataFrame as `.obs`.
    sum_variable : obs column whose levels become the stacked segments (table columns).
    plot_variable : obs column whose levels become the bars (table rows).
    filter_kw : if non-empty, keep only bars whose label contains this substring.
    plot_legend : if True, draw the legend to the right of the axes.

    Returns
    -------
    pandas.DataFrame of row-normalised fractions, in the order the bars are drawn.
    '''
    counts = data.obs[[sum_variable, plot_variable]]
    counts = counts.groupby(sum_variable)[plot_variable].value_counts()
    plot_data = counts.unstack(sum_variable)
    plot_data = plot_data.div(plot_data.sum(axis=1), axis=0)

    # Order bars by their single largest fraction, most homogeneous bar first
    cat_order = plot_data.max(axis=1).sort_values(ascending=False).index
    plot_data = plot_data.reindex(cat_order)

    if filter_kw:
        keep = [name for name in plot_data.index if filter_kw in str(name)]
        plot_data = plot_data.loc[keep]
        # Only a CategoricalIndex carries now-unused categories that need dropping
        if isinstance(plot_data.index, pd.CategoricalIndex):
            plot_data.index = plot_data.index.remove_unused_categories()

    if plot_legend:
        plot_data.plot.barh(stacked=True, figsize=(12, 8), grid=False).legend(loc='center left', bbox_to_anchor=(1.0, 0.5))
    else:
        # No `sort_columns`: deprecated in pandas 1.5 and removed in 2.0 (it raises a TypeError there)
        plot_data.plot.barh(stacked=True, figsize=(12, 8), grid=False, legend=False)
    return plot_data


In [None]:
## MARKERS
# Marker gene sets. The plotting dict is the dotplot subset. The scoring dict reuses it as
# independent list copies, prepends 'SCGB1A3' to Club and appends an EMT signature.
epithelial_markers_plotting = {
    'Basal': ['TP63', 'KRT14', 'KRT5'],
    'Club': ['WFDC2', 'LCN2', 'MMP7', 'KRT4', 'TACSTD2', 'SCGB3A1'],
    'Hillock': ['KRT13', 'S100A16', 'S100A14', 'KRT19'],
    'Luminal': ['KLK4', 'KLK3', 'KLK2', 'ACPP', 'AR'],
    'Tumor': ['AMACR', 'CACNA1D', 'PCA3', 'ERG', 'FABP5', 'COL9A2', 'GCNT1', 'PHGR1'],
}

epithelial_markers = {name: list(genes) for name, genes in epithelial_markers_plotting.items()}
# NOTE(review): 'SCGB1A3' is not an obvious HGNC symbol (SCGB1A1?) — confirm it exists in var_names
epithelial_markers['Club'] = ['SCGB1A3'] + epithelial_markers['Club']
# NOTE(review): 'CHD11' looks like a typo of 'CDH11' (also listed), and 'PDGFRD' is not an obvious
# gene symbol. Genes missing from var_names are presumably just skipped by score_genes — confirm intent.
epithelial_markers['EMT'] = [
    'COL5A2', 'ECM1', 'FSTL1', 'MMP1', 'MMP2', 'TAGLN', 'VIM', 'SERPINH1', 'COL1A1',
    'FN1', 'TNC', 'HTRA1', 'CD44', 'S100A4', 'MYL9', 'ACTG2', 'ACTA2', 'MYH11', 'ERG', 'CDH2',
    'HIF1A', 'TGFBR1', 'SDC1', 'ENOPH1', 'CAMK2N1', 'EMP3', 'MKI67', 'ITGAM', 'ANXA5', 'BMP1',
    'CHD11', 'FAP', 'LEF1', 'IGFBP4', 'BGN', 'TWIST1', 'MCM7', 'PRRX1', 'COL3A1', 'COL1A2',
    'POSTN', 'DCN', 'FBN1', 'SNAI2', 'PDGFRB', 'SPARC', 'INHBA', 'COL6A2', 'TNFAIP6', 'GREM1',
    'CDH11', 'SPOCK1', 'COPZ2', 'THY1', 'PCOLCE', 'PDGFRD',
]

# Broad lineage markers used to assign cell types to clusters
broad_marker_genes = {
    'Epithelial': ['S100A16', 'S100A14', 'TACSTD2', 'KLK4', 'KLK3', 'KLK2', 'ACPP', 'AR', 'TMPRSS2'],
    'Fibroblast': ['DCN', 'LUM', 'PTN', 'IGF1', 'APOD', 'COL1A2', 'FBLN1', 'MEG3', 'CXCL12'],
    'Pericyte': ['RGS5', 'ACTA2', 'MYH11', 'MT1M', 'FRZB', 'MT1A', 'NDUFA4L2', 'PPP1R14A', 'MYLK', 'PHLDA1'],
    'Endothelial': ['VWF', 'ENG', 'CLDN5'],
    'Mast': ['MS4A2', 'TPSAB1', 'CPA3'],
    'Monocytic': ['LYZ', 'FCGR3A', 'CSF1R', 'CD68', 'CD14', 'CD163', 'C1QA', 'C1QB', 'C1QC', 'GPR34', 'MS4A4A'],
    'T_cell': ['CD4', 'PTPRC', 'IL7R', 'CD7', 'CD2', 'CD3G', 'CD3E', 'CD3D'],
    'B_cell': ['CD79A', 'CD79B', 'VPREB3', 'BANK1'],  # also considered: 'MS4A1', 'CD19', 'IGLL5'
    'Plasma_cell': ['MZB1', 'DNAJB9'],  # also considered: 'IGJ', 'MGP1', 'SEC11C', 'XBP1', 'PRDX4', 'SPCS2', 'SSR3', 'SDF2L1', 'MANF', 'TMEM258'
    # 'MDC': ['PKIB', 'INSIG1', 'CLEC10A', 'C15orf48', 'PPA1']  # also 'CD1C'
}


## All cell types in the integration

In [None]:
# Load the integrated single-cell reference and label each cluster with its over-represented phenotype.
# NOTE: unpickling can execute arbitrary code — only load trusted files.
adata = add_cluster_phenotype_column(
    load_from_pickle('all-scvi-integrated-6-sc-datasets-with-infercnv.pickle')
)

In [None]:
# Broad lineage markers per annotated cluster, ordered by a dendrogram built on the scVI representation
sc.tl.dendrogram(adata, groupby='cluster_phenotype', use_rep='X_scVI')
sc.pl.dotplot(
    adata,
    broad_marker_genes,
    groupby='cluster_phenotype',
    dendrogram=True,
    log=False,
    swap_axes=True,
    vmax=4,
)

In [None]:

# Find marker genes for unresolved clusters: which samples they come from, and their
# top-ranked genes against all other cells.
unidentified_clusters = ['CRPC_33','CRPC_37','CRPC_41']

# Compute the ranking here. Plotting alone only works if a 'rank_genes_groups' result is
# already stored in adata.uns, and such a result may come from a different run.
sc.tl.rank_genes_groups(adata, groupby='cluster_phenotype', groups=unidentified_clusters,
                        method='t-test', n_genes=50)

for s in unidentified_clusters:
    print(s+': ')
    # Read the column straight from obs instead of building an AnnData view per cluster
    print(adata.obs.loc[adata.obs['cluster_phenotype'] == s, 'sample'].value_counts())

# Scope the figure settings to this plot instead of changing them for every later cell
with plt.rc_context({'figure.figsize': (4, 4), 'axes.grid': True}):
    sc.pl.rank_genes_groups(adata)

In [None]:
# Modify the labels according to the dotplot: map every chi-squared cluster label (from
# add_cluster_phenotype_column) to a broad cell type. One mapping applied once replaces a chain
# of Series.replace calls on a categorical; value replacement on categoricals is deprecated in pandas 2.x.
#
# A finer epithelial split considered earlier:
#   Epithelial_1: PCa_30, normal_8, PCa_9, PCa_55, PCa_28, PCa_54, normal_25, CRPC_34, PCa_40, PCa_19
#   Epithelial_2: CRPC_21, PCa_11, PCa_42, CRPC_47, PCa_44, PCa_49
#   Epithelial_3: PCa_61, PCa_22, PCa_62
#   Epithelial_4: CRPC_57, normal_32, CRPC_27, normal_59, PCa_18, PCa_2, CRPC_51
#   Epithelial_5: PCa_35, CRPC_50
celltype_clusters = {
    'Epithelial': ['PCa_30','normal_8','PCa_9','PCa_55','PCa_28','PCa_54','normal_25','CRPC_34','PCa_40','PCa_19',
                   'CRPC_21','PCa_11','PCa_42','CRPC_47','PCa_44','PCa_49',
                   'PCa_61','PCa_22','PCa_62',
                   'CRPC_57','normal_32','CRPC_27','normal_59','PCa_18','PCa_2','CRPC_51',
                   'PCa_35','CRPC_50'],
    'Fibroblast': ['PCa_56','CRPC_12'],
    'Pericyte': ['normal_36','PCa_4'],
    'Endothelial': ['PCa_0','PCa_60'],
    'Mast': ['PCa_15','CRPC_38'],
    'Monocytic': ['PCa_1','normal_29','CRPC_52','PCa_53'],
    'T_cell': ['PCa_43','normal_7','normal_16','normal_23','normal_14','normal_3','normal_13','PCa_5','normal_17',
               'PCa_20','normal_31','PCa_26','CRPC_6','normal_24','normal_39'],
    'B_cell': ['normal_10','PCa_46','CRPC_45'],
    'Plasma': ['normal_48','normal_58'],
    # Unclassified clusters (inspected in the previous cell)
    'Ribosomal_CRPC': ['CRPC_33'],
    'Fibroblast_MEG3': ['CRPC_37'],
    'Fibroblast_ASPSCR1': ['CRPC_41'],
}
cluster_to_celltype = {cluster: celltype
                       for celltype, clusters in celltype_clusters.items()
                       for cluster in clusters}
assert len(cluster_to_celltype) == sum(len(c) for c in celltype_clusters.values()), \
    'a cluster is listed under more than one cell type'

# Cluster labels depend on the upstream chi-squared run; report any listed label that no longer exists
missing_cluster_labels = sorted(set(cluster_to_celltype) - set(adata.obs['cluster_phenotype'].astype(str)))
if missing_cluster_labels:
    print('Cluster labels not found in cluster_phenotype:', missing_cluster_labels)

# Labels without an entry in the mapping (and missing values) are kept unchanged
detailed_celltypes = (adata.obs['cluster_phenotype']
                      .astype(object)
                      .map(lambda c: cluster_to_celltype.get(c, c))
                      .astype('category'))

adata.obs['detailed_celltypes'] = detailed_celltypes

sc.tl.dendrogram(adata,groupby='detailed_celltypes',use_rep='X_scVI')
sc.pl.dotplot(adata, broad_marker_genes, groupby='detailed_celltypes', dendrogram=True, log= False,
              swap_axes = True, vmax=4)

In [None]:
# Epithelial cells only (any label containing 'Epithelial'). This is an explicit copy because
# later cells add obs columns, and modifying an AnnData view otherwise triggers an implicit copy.
epithelial_subset = adata[adata.obs['detailed_celltypes'].str.contains('Epithelial', na=False)].copy()

In [None]:
# Score every epithelial cell against each marker set; scores are read back from obs as '<set>_score'
for marker_set, genes in epithelial_markers.items():
    sc.tl.score_genes(epithelial_subset, genes, score_name=marker_set + '_score')
    

In [None]:
# Mean lineage score per epithelial cluster; only positive means are displayed.
# observed=True keeps rows only for clusters present in the subset. An empty cluster category
# would otherwise become an all-NaN row and, downstream, an empty annotation.
lineage_score_cols = ['Basal_score', 'Luminal_score', 'Club_score', 'Hillock_score']
df_sorted = (
    epithelial_subset.obs[['cluster_phenotype'] + lineage_score_cols]
    .groupby('cluster_phenotype', observed=True)
    .mean()
)
df_sorted[df_sorted>0]

In [None]:
# Mean EMT / Tumor score per epithelial cluster; only positive means are displayed.
# observed=True restricts the rows to clusters present in the subset (see the lineage table above).
df_EMT_tumor = (
    epithelial_subset.obs[['cluster_phenotype', 'EMT_score', 'Tumor_score']]
    .groupby('cluster_phenotype', observed=True)
    .mean()
)
df_EMT_tumor[df_EMT_tumor>0]

In [None]:
def _strip_score_suffix(name):
    '''Remove a trailing '_score' from a column name.

    str.strip('_score') would strip a *character set* instead (e.g. 'Tumor_score' -> 'Tum').
    '''
    return name[:-len('_score')] if name.endswith('_score') else name


def _top_positive_score(row):
    '''Name (without '_score') of the largest positive value in `row`, or '' if none is positive.'''
    positive = row[row > 0]
    return _strip_score_suffix(positive.idxmax()) if len(positive) else ''


# Lineage part: the (up to) two highest-scoring lineages with a positive mean score,
# joined alphabetically, e.g. 'Basal_Luminal'.
annotations = {}
for cluster, scores in df_sorted.iterrows():
    top_two = scores.sort_values(ascending=False).iloc[:2]
    lineages = [_strip_score_suffix(col) for col, value in top_two.items() if value > 0]
    annotations[cluster] = '_'.join(sorted(lineages))

# EMT/Tumor part: the dominant positive signal, if any. It is matched by cluster label rather than
# by position, so the two tables do not need to share row order.
EMT_tumor_annot = [_top_positive_score(row) for _, row in df_EMT_tumor.iterrows()]
for cluster, extra in zip(df_EMT_tumor.index, EMT_tumor_annot):
    if extra != '':
        # Skip an empty lineage part so the label never starts with '_'
        annotations[cluster] = '_'.join(part for part in (annotations.get(cluster, ''), extra) if part)
annotations

In [None]:
# Each epithelial cell inherits the lineage annotation of its cluster; then inspect the markers per label
epithelial_celltype_labels = epithelial_subset.obs['cluster_phenotype'].map(annotations)
epithelial_subset.obs['epithelial_celltypes'] = pd.Categorical(epithelial_celltype_labels)

sc.tl.dendrogram(epithelial_subset, groupby='epithelial_celltypes', use_rep='X_scVI')
sc.pl.dotplot(
    epithelial_subset,
    epithelial_markers_plotting,
    groupby='epithelial_celltypes',
    dendrogram=True,
    log=False,
    swap_axes=True,
    vmax=4,
)

In [None]:
# One violin plot per marker score, with groups ordered by decreasing mean score
sc.set_figure_params(figsize=(16,8),dpi=80)
for marker_set in epithelial_markers:
    score_col = marker_set + '_score'
    sample_order = (
        epithelial_subset.obs[['epithelial_celltypes', score_col]]
        .groupby(['epithelial_celltypes'])
        .mean()
        .sort_values(score_col, ascending=False)
        .index
    )
    sc.pl.violin(epithelial_subset, keys=score_col, groupby='epithelial_celltypes',
                 order=sample_order, rotation=45)


In [None]:
# Replace the generic 'Epithelial' label with the per-cluster lineage annotations.
# pandas only accepts values that are already categories when assigning into a categorical,
# so the new labels are registered first.
celltypes = adata.obs['detailed_celltypes'].cat.remove_unused_categories()
new_labels = sorted(set(annotations.values()) - set(celltypes.cat.categories))
celltypes = celltypes.cat.add_categories(new_labels)
for cluster, label in annotations.items():
    celltypes.loc[adata.obs['cluster_phenotype'] == cluster] = label
# Drop categories emptied by the relabelling (e.g. 'Epithelial') so that the dendrogram and
# dotplot only see populated groups
adata.obs['detailed_celltypes'] = celltypes.cat.remove_unused_categories()

sc.tl.dendrogram(adata,groupby='detailed_celltypes',use_rep='X_scVI')
sc.pl.dotplot(adata, broad_marker_genes, groupby='detailed_celltypes', dendrogram=True, log= False,
              swap_axes = True, vmax=4)

In [None]:
# Number of cells per final cell-type label
celltype_counts = adata.obs['detailed_celltypes'].value_counts()
celltype_counts

In [None]:
# Persist the reference with the revised cell-type labels as .h5ad
output_h5ad = './single-cell-reference-with-revised-cell-types-20230322.h5ad'
adata.write(output_h5ad)