# LIANA tumor vs normal core atlas

## Libraries

In [1]:
import numpy as  np
import pandas as pd
import scanpy as sc
import decoupler as dc

# import liana
import liana as li
from liana.method import singlecellsignalr, connectome, cellphonedb, natmi, logfc, cellchat, geometric_mean
import sc_atlas_helpers as ah
#from scanpy_helper_submodule import scanpy_helpers as sh

In [2]:
from tqdm.auto import tqdm
import contextlib
import os
import statsmodels.stats.multitest
import numpy as np
from anndata import AnnData
import scipy.sparse



## Define paths

In [3]:
# Core atlas
adata =sc.read_h5ad("/data/projects/2022/CRCA/results/v1/final/h5ads/mui_innsbruck-adata.h5ad")

In [4]:
adata

AnnData object with n_obs × n_vars = 126991 × 19793
    obs: 'study_id', 'dataset', 'sample_id', 'sample_type', 'tumor_source', 'sample_tissue', 'anatomic_region', 'anatomic_location', 'tumor_stage_TNM', 'tumor_stage_TNM_T', 'tumor_stage_TNM_N', 'tumor_stage_TNM_M', 'tumor_grade', 'histological_type', 'microsatellite_status', 'patient_id', 'sex', 'age', 'treatment_status', 'platform', 'reference_genome', 'matrix_type', 'enrichment_cell_types', 'tissue_cell_state', 'tissue_processing_lab', 'hospital_location', 'country', 'KRAS_status_driver_mut', 'NRAS_status_driver_mut', 'BRAF_status_driver_mut', 'HER2_status_driver_mut', 'panTRK_status_driver_mut', 'AKT1_status_driver_mut', 'TP53_status_driver_mut', 'CTNNB1_status_driver_mut', 'ABL1_status_driver_mut', 'RET_status_driver_mut', 'Tumor budding', 'Cell_Type_Experimental', 'Sample_Tag', 'Sample_Name', 'BD-Rhapsody File ID', 'n_counts', 'n_genes', '_scvi_batch', '_scvi_labels', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts'

In [5]:
set(adata.obs.cell_type_fine)

{'B cell activated',
 'B cell activated naive',
 'B cell memory',
 'B cell naive',
 'BN1',
 'BN2',
 'BN3',
 'CD4',
 'CD4 cycling',
 'CD4 naive',
 'CD8',
 'CD8 cycling',
 'CD8 naive',
 'Cancer BEST4',
 'Cancer Colonocyte-like',
 'Cancer Crypt-like',
 'Cancer Goblet-like',
 'Cancer TA-like',
 'Cancer cell circulating',
 'Colonocyte',
 'Colonocyte BEST4',
 'Crypt cell',
 'DC mature',
 'DC3',
 'Endothelial arterial',
 'Endothelial lymphatic',
 'Endothelial venous',
 'Enteroendocrine',
 'Eosinophil',
 'Fibroblast S1',
 'Fibroblast S2',
 'Fibroblast S3',
 'GC B cell',
 'Goblet',
 'Granulocyte progenitor',
 'Macrophage',
 'Macrophage cycling',
 'Mast cell',
 'Monocyte',
 'NK',
 'NKT',
 'Neutrophil',
 'Pericyte',
 'Plasma IgA',
 'Plasma IgG',
 'Plasma IgM',
 'Plasmablast',
 'Platelet',
 'Schwann cell',
 'TA progenitor',
 'TAN1',
 'TAN2',
 'TAN3',
 'TAN4',
 'Treg',
 'Tuft',
 'cDC progenitor',
 'cDC1',
 'cDC2',
 'gamma-delta',
 'pDC'}

In [6]:
set(adata.obs.cell_type_middle)

{'B cell',
 'CD4',
 'CD8',
 'Cancer cell',
 'Cancer cell circulating',
 'Dendritic cell',
 'Endothelial cell',
 'Enteroendocrine',
 'Eosinophil',
 'Epithelial cell',
 'Epithelial progenitor',
 'Fibroblast',
 'Goblet',
 'Macrophage',
 'Mast cell',
 'Monocyte',
 'NK',
 'NKT',
 'Neutrophil',
 'Pericyte',
 'Plasma cell',
 'Platelet',
 'Schwann cell',
 'Treg',
 'Tuft',
 'gamma-delta'}

In [7]:
# Build `cell_type_sub`: keep the fine-grained label for the neutrophil subsets
# listed in `specific_values` and fall back to `cell_type_middle` for every other cell.
specific_values = {'BN1', 'BN2', 'BN3', 'TAN1', 'TAN2', 'TAN3', 'TAN4'}
# Cast both columns to object first so that combining labels from two (possibly
# categorical) columns cannot fail on categories unknown to the other column.
fine_labels = adata.obs['cell_type_fine'].astype(object)
middle_labels = adata.obs['cell_type_middle'].astype(object)
# Vectorised selection instead of a row-wise `apply`, which runs one Python call per cell.
adata.obs['cell_type_sub'] = fine_labels.where(
    fine_labels.isin(specific_values), middle_labels
)

In [None]:
#adata.obs['cell_type_sub_new'] = adata.obs['cell_type_sub'].replace(
 #   {'TAN1': 'TAN', 'TAN2': 'TAN', 'TAN3': 'TAN', 'TAN4': 'TAN'}
#)

In [None]:
set(adata.obs.cell_type_sub)

In [None]:
#set(adata.obs.cell_type_sub_new)

In [None]:
set(adata.obs.sample_type)

## Define comparison: tumor vs blood

In [8]:
comparison="new" #//immune_type
subset = "neutrophil" #//neutrophil_subclusters

In [9]:
cell_type_oi = "new"
n_top_ligands = 30

In [10]:
resDir = f"/data/projects/2022/CRCA/results/v1/final/liana_cell2cell/{subset}/{comparison}"

In [11]:
resDir

'/data/projects/2022/CRCA/results/v1/final/liana_cell2cell/neutrophil/new'

In [12]:
# Only the two "<perturbation>_<baseline>" comparisons configure the plot labels;
# for any other value of `comparison` none of the names below are (re)assigned.
if comparison in ("TAN1_blood", "tumor_normal"):
    perturbation, baseline = (part.upper() for part in comparison.split("_")[:2])
    title_plot = f"{perturbation} vs {baseline}: {cell_type_oi}, top {n_top_ligands} DE ligands"
    # Spaces are stripped only after the title was built, so the title keeps the readable name.
    cell_type_oi = cell_type_oi.replace(" ","")
    save_name_plot =  f"{perturbation}_vs_{baseline}_{cell_type_oi}_top_{n_top_ligands}_DE_ligands"

## Pseudobulk

In [None]:
adata

In [13]:
set(adata.obs.cell_type_sub)

{'B cell',
 'BN1',
 'BN2',
 'BN3',
 'CD4',
 'CD8',
 'Cancer cell',
 'Cancer cell circulating',
 'Dendritic cell',
 'Endothelial cell',
 'Enteroendocrine',
 'Eosinophil',
 'Epithelial cell',
 'Epithelial progenitor',
 'Fibroblast',
 'Goblet',
 'Macrophage',
 'Mast cell',
 'Monocyte',
 'NK',
 'NKT',
 'Neutrophil',
 'Pericyte',
 'Plasma cell',
 'Platelet',
 'Schwann cell',
 'TAN1',
 'TAN2',
 'TAN3',
 'TAN4',
 'Treg',
 'Tuft',
 'gamma-delta'}

In [14]:
adata_original = adata.copy()

In [15]:
# Restrict `adata` to the `cell_type_sub` labels listed below. Despite the variable
# name, the list is not limited to blood cells (it also contains e.g. 'Cancer cell'
# and 'Fibroblast'); labels that are not listed (e.g. 'Goblet', 'NK', 'Pericyte')
# are dropped.
imuune_cells_blood = ['B cell',
 'BN1',
 'BN2',
 'BN3',
 'CD4',
 'CD8',
 'Cancer cell',
 'Cancer cell circulating',
 'Dendritic cell',
 'Endothelial cell',
 'Enteroendocrine',
 'Eosinophil',
 'Epithelial cell',
 'Epithelial progenitor',
 'Fibroblast',
 'Macrophage',
 'Mast cell',
 'Monocyte',
 'Plasma cell',
 'Platelet',
 'TAN1',
 'TAN2',
 'TAN3',
 'TAN4',
 'Treg',] 


adata = adata[adata.obs.cell_type_sub.isin(imuune_cells_blood)]

In [None]:
# Filter for only blood cell type in cell_type_sub 
#imuune_cells_blood = ['B cell',
# 'BN1',
# 'BN2',
# 'BN3',
# 'CD4',
# 'CD8',
# 'Cancer cell circulating',
# 'Dendritic cell',
# 'Eosinophil',
# 'Macrophage',
# 'Monocyte',
# 'Platelet',
# 'Treg','TAN1','TAN2','TAN3','TAN4'] 
#
#
#adata = adata[adata.obs.cell_type_sub.isin(imuune_cells_blood)]

In [None]:
# Filter for only blood cell type in cell_type_sub 
#imuune_cells_blood = ['B cell',
# 'BN1',
# 'BN2',
# 'BN3',
# 'CD4',
# 'CD8',
# 'Cancer cell circulating',
# 'Dendritic cell',
# 'Eosinophil',
# 'Macrophage',
# 'Monocyte',
# 'Platelet',
# 'Treg','TAN']
#adata = adata[adata.obs.cell_type_sub_new.isin(imuune_cells_blood)]

In [16]:
set(adata.obs.cell_type_sub)

{'B cell',
 'BN1',
 'BN2',
 'BN3',
 'CD4',
 'CD8',
 'Cancer cell',
 'Cancer cell circulating',
 'Dendritic cell',
 'Endothelial cell',
 'Enteroendocrine',
 'Eosinophil',
 'Epithelial cell',
 'Epithelial progenitor',
 'Fibroblast',
 'Macrophage',
 'Mast cell',
 'Monocyte',
 'Plasma cell',
 'Platelet',
 'TAN1',
 'TAN2',
 'TAN3',
 'TAN4',
 'Treg'}

In [17]:
set(adata.obs.cell_type_middle)

{'B cell',
 'CD4',
 'CD8',
 'Cancer cell',
 'Cancer cell circulating',
 'Dendritic cell',
 'Endothelial cell',
 'Enteroendocrine',
 'Eosinophil',
 'Epithelial cell',
 'Epithelial progenitor',
 'Fibroblast',
 'Macrophage',
 'Mast cell',
 'Monocyte',
 'Neutrophil',
 'Plasma cell',
 'Platelet',
 'Treg'}

In [None]:
#set(adata.obs.cell_type_sub_new)

In [18]:
adata.obs.sample_type.value_counts() #only neutro 

sample_type
tumor     53322
normal    32248
blood     26235
Name: count, dtype: int64

In [19]:
# Neutrophils are taken from blood and tumor samples; every other cell type is
# taken from normal and tumor samples.
is_neutrophil = adata.obs['cell_type_middle'] == 'Neutrophil'
sample_types = adata.obs['sample_type']

condition_1 = is_neutrophil & sample_types.isin(['blood', 'tumor'])
condition_2 = ~is_neutrophil & sample_types.isin(['normal', 'tumor'])
combined_condition = condition_1 | condition_2

# Subset the AnnData object and copy it so that it no longer refers to the parent object.
adata = adata[combined_condition].copy()

In [20]:
adata.obs.sample_type.value_counts() #only neutro 

sample_type
tumor     53322
normal    32090
blood     14365
Name: count, dtype: int64

In [21]:
set(adata.obs.sample_type)

{'blood', 'normal', 'tumor'}

In [22]:
adata.obs.groupby(["patient_id", "sample_type"]).count()



Unnamed: 0_level_0,Unnamed: 1_level_0,study_id,dataset,sample_id,tumor_source,sample_tissue,anatomic_region,anatomic_location,tumor_stage_TNM,tumor_stage_TNM_T,tumor_stage_TNM_N,...,total_counts,pct_counts_in_top_20_genes,pct_counts_mito,S_score,G2M_score,phase,cell_type_coarse,cell_type_middle,cell_type_fine,cell_type_sub
patient_id,sample_type,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1
P3,blood,645,645,645,645,645,645,645,645,645,645,...,645,645,645,645,645,645,645,645,645,645
P3,normal,1895,1895,1895,1895,1895,1895,1895,1895,1895,1895,...,1895,1895,1895,1895,1895,1895,1895,1895,1895,1895
P3,tumor,3125,3125,3125,3125,3125,3125,3125,3125,3125,3125,...,3125,3125,3125,3125,3125,3125,3125,3125,3125,3125
P4,blood,747,747,747,747,747,747,747,747,747,747,...,747,747,747,747,747,747,747,747,747,747
P4,normal,2661,2661,2661,2661,2661,2661,2661,2661,2661,2661,...,2661,2661,2661,2661,2661,2661,2661,2661,2661,2661
P4,tumor,4460,4460,4460,4460,4460,4460,4460,4460,4460,4460,...,4460,4460,4460,4460,4460,4460,4460,4460,4460,4460
P5,blood,1600,1600,1600,1600,1600,1600,1600,1600,1600,1600,...,1600,1600,1600,1600,1600,1600,1600,1600,1600,1600
P5,normal,1966,1966,1966,1966,1966,1966,1966,1966,1966,1966,...,1966,1966,1966,1966,1966,1966,1966,1966,1966,1966
P5,tumor,6360,6360,6360,6360,6360,6360,6360,6360,6360,6360,...,6360,6360,6360,6360,6360,6360,6360,6360,6360,6360
P7,blood,1052,1052,1052,1052,1052,1052,1052,1052,1052,1052,...,1052,1052,1052,1052,1052,1052,1052,1052,1052,1052


In [23]:
pdata = dc.get_pseudobulk(
    adata,
    sample_col='sample_id',
    groups_col=['sample_type',"cell_type_sub"],
    layer='counts',
    mode='sum',
    min_cells=0,
    min_counts=0
)



In [26]:
set(pdata.obs.cell_type_sub)

{'B cell',
 'BN1',
 'BN2',
 'BN3',
 'CD4',
 'CD8',
 'Cancer cell',
 'Dendritic cell',
 'Endothelial cell',
 'Enteroendocrine',
 'Eosinophil',
 'Epithelial cell',
 'Epithelial progenitor',
 'Fibroblast',
 'Macrophage',
 'Mast cell',
 'Monocyte',
 'Plasma cell',
 'TAN1',
 'TAN2',
 'TAN3',
 'TAN4',
 'Treg'}

In [27]:
# Drop the pseudobulk profiles of the TAN1-4 clusters that come from blood samples.
is_tan = pdata.obs['cell_type_sub'].isin(['TAN1', 'TAN2', 'TAN3', 'TAN4'])
is_blood = pdata.obs['sample_type'] == 'blood'
# `.copy()` turns the slice into a standalone object, so the later in-place edits
# of `pdata.obs` do not operate on a view of the unfiltered pseudobulk data.
pdata = pdata[~(is_tan & is_blood)].copy()

In [28]:
#pdata = pdata[~((pdata.obs['cell_type_sub'].isin(['TAN'])) & (pdata.obs['sample_type'] == 'blood'))]

In [29]:
pdata.obs['sample_type'] = pdata.obs['sample_type'].replace('blood', 'normal')



In [30]:
pdata.obs.sample_type.value_counts()

sample_type
tumor     208
normal    204
Name: count, dtype: int64

In [31]:
pdata.obs.cell_type_sub

P11_blood_blood_BN1     BN1
P12_blood_blood_BN1     BN1
P13_blood_blood_BN1     BN1
P14_blood_blood_BN1     BN1
P15_blood_blood_BN1     BN1
                       ... 
P4_tumor_tumor_Treg    Treg
P5_tumor_tumor_Treg    Treg
P7_tumor_tumor_Treg    Treg
P8_tumor_tumor_Treg    Treg
P9_tumor_tumor_Treg    Treg
Name: cell_type_sub, Length: 412, dtype: object

In [32]:
resDir

'/data/projects/2022/CRCA/results/v1/final/liana_cell2cell/neutrophil/new'

In [33]:
#pdata.var_names.name = "gene_id"
#
#colData = pdata.obs
#colData.index.name = "sample_col"
#
#colData.to_csv(f"{resDir}/02_pseudobulk/{comparison}_colData_tan.csv")
#rowData = pdata.var[["Geneid", "GeneSymbol", "Chromosome", "Class", "Length"]]
#rowData.to_csv(f"{resDir}/02_pseudobulk/{comparison}_rowData_tan.csv")
#count_mat = pdata.to_df().T
#count_mat.index.name = "gene_id"
#count_mat.to_csv(f"{resDir}/02_pseudobulk/{comparison}_count_mat_tan.csv")

In [34]:
# Export the pseudobulk data as the three tables listed as DESeq2 input below:
# sample annotation (colData), gene annotation (rowData) and the count matrix.
pseudobulk_dir = f"{resDir}/02_pseudobulk/all_cells"
# `to_csv` does not create missing directories, so make sure the target exists.
os.makedirs(pseudobulk_dir, exist_ok=True)

pdata.var_names.name = "gene_id"

# NOTE: `colData` is the `pdata.obs` frame itself, so naming its index also
# renames the index of `pdata.obs`.
colData = pdata.obs
colData.index.name = "sample_col"
colData.to_csv(f"{pseudobulk_dir}/{comparison}_colData.csv")

rowData = pdata.var[["Geneid", "GeneSymbol", "Chromosome", "Class", "Length"]]
rowData.to_csv(f"{pseudobulk_dir}/{comparison}_rowData.csv")

# Transpose so that genes are rows and pseudobulk samples are columns.
count_mat = pdata.to_df().T
count_mat.index.name = "gene_id"
count_mat.to_csv(f"{pseudobulk_dir}/{comparison}_count_mat.csv")

## TUMOR vs NORMAL

DESeq2 script: "/data/scratch/kvalem/projects/2022/differential_gene_expression/bin/03_DESeq2_DGEA_studio.R"

### Parameters for DESeq2
- input: colData, count_mat, rowData
- covariate_formula = "patient_id +"
- sample_col="sample_col" 
- cond_col="sample_type"
- sum2zero=FALSE 
- c1="TAN1" 
- c2="blood"
- cpus=8

In [None]:
# DESEQ2 output path 
deseq2_path_prefix = "/data/projects/2022/CRCA/results/v1/final/liana_cell2cell/neutrophil/new/03_deseq2/"

In [None]:
file_name_deseq2_out = "new_tumor_vs_normal_DESeq2_result.tsv"

In [None]:
# Load the DESeq2 result table, replace missing values with 1, add an "fdr" column
# and rename the "comparison" column to "group".
# NOTE: `fdr_correction` is defined in the "Functions" section at the end of this
# notebook; that cell has to be executed before this one.
# NOTE(review): `.fillna(1)` is applied to every column, not only to the p-values;
# confirm this is intended for columns such as the fold change.
de_res = (
    pd.read_csv(f"{deseq2_path_prefix}/{file_name_deseq2_out}",
        sep="\t",
    )
    .fillna(1)
    .pipe(fdr_correction)
    .rename(columns={"comparison": "group"})
)

## LIANA - rank aggregate

### NEUTROPHILS TAN1

In [None]:
#Only blood cells 
set(adata.obs.cell_type_middle)

In [None]:
# Run rank_aggregate for neutrophil
#li.mt.rank_aggregate(adata, groupby='cell_type_sub', expr_prop=0.1,resource_name='consensus',  verbose=True,key_added='rank_aggregate', layer = "log1p_norm", use_raw = False)

In [None]:
#adata.write_h5ad("adata_rank_agregate_neutrophil_TAN1.h5ad")

In [None]:
# rank agregate for core atlas 
adata = sc.read_h5ad(f"adata_rank_agregate_neutrophil_TAN.h5ad") 

In [None]:
# rank agregate for neutrophils
#adata_n = sc.read_h5ad(f"/data/projects/2022/CRCA/results/v0.1/crc-atlas-dataset/latest/ds_analyses/liana_cell2cell/neutrophil_subclusters/adata_rank_agregate_neutrophil.h5ad") 

In [None]:
# Run rank_aggregate for neutrophil
#li.mt.rank_aggregate(adata, groupby='cell_type_sub', expr_prop=0.1,resource_name='consensus',  verbose=True,key_added='rank_aggregate', layer = "log1p_norm", use_raw = False)

In [None]:
#adata.write_h5ad("adata_rank_agregate_neutrophil_cell_type_sub.h5ad")

###  CORE ATLAS 

In [None]:
# Run rank_aggregate
#li.mt.rank_aggregate(adata, groupby='cell_type_middle', expr_prop=0.1,resource_name='consensus',  verbose=True,key_added='rank_aggregate', layer = "log1p_norm", use_raw = False)

In [None]:
#adata.write_h5ad("adata_rank_agregate.h5ad")

In [None]:
cell_type_oi = "TAN4"
n_top_ligands = 30

In [None]:
immune_cells =['B cell',
 'DC mature',
 'Macrophage',
 'Mast cell',
 'Monocyte',
 'NK',
 'Neutrophil',
 'Plasma cell',
 'T cell CD4',
 'T cell CD8',
 'T cell regulatory',
 'cDC',
 'pDC']

In [None]:
immune_cells_cancer =['Cancer cell','B cell',
 'DC mature',
 'Macrophage',
 'Mast cell',
 'Monocyte',
 'NK',
 'Neutrophil',
 'Plasma cell',
 'T cell CD4',
 'T cell CD8',
 'T cell regulatory',
 'cDC',
 'pDC']

In [None]:
# Take the table stored under `adata.uns['rank_aggregate']` and keep only the
# interactions with a specificity rank of at most 0.01.
cpdb_res = adata.uns['rank_aggregate'].loc[
        lambda x: x["specificity_rank"] <= 0.01
    ]

In [None]:
# rename columns in liana results 
cpdb_res=cpdb_res.rename(columns={"ligand_complex":"source_genesymbol","receptor_complex":"target_genesymbol"})

In [None]:
cpdb_res.columns

In [None]:
# Build the CpdbAnalysis helper (defined in the "Functions" section of this notebook):
# it pseudobulks `adata` per patient and cell type to derive, for every cell type,
# the mean expression and the fraction of cells expressing each gene.
cpdba = CpdbAnalysis(
    cpdb_res,
    adata,
    pseudobulk_group_by=["patient_id"],
    cell_type_column="cell_type_sub"
)

In [None]:
cpdba

In [None]:
cpdb_sig_int = cpdba.significant_interactions(
    de_res, max_pvalue=0.1
)

In [None]:
immune_cells = list(set(cpdb_sig_int.cell_type_sub))

In [None]:
cpdb_sig_int.columns

In [None]:
# Export the significant interactions; this CSV is the input for the Circos plot.
# NOTE(review): the row index is written as an unnamed first column, which a plain
# `pd.read_csv` reads back as a regular column; confirm the Circos script expects it.
cpdb_sig_int.to_csv(f"/data/projects/2022/CRCA/results/v1/final/liana_cell2cell/neutrophil/new/neutrophil_tan4.csv")

In [None]:
cpdb_sig_int = pd.read_csv(f"/data/projects/2022/CRCA/results/v1/final/liana_cell2cell/neutrophil/new/neutrophil_tan4.csv")


In [None]:
set(cpdb_sig_int.source)

In [None]:
set(cpdb_sig_int.target)

In [None]:
cpdb_sig_int = cpdb_sig_int.loc[lambda x: x["cell_type_sub"].isin(immune_cells)]

In [None]:
# Rank the ligands by ascending FDR and keep the first `n_top_ligands` entries.
# The cut-off uses the notebook-level `n_top_ligands` setting instead of a second,
# hard-coded copy of the same number.
top_genes = (
    cpdb_sig_int.loc[:, ["source_genesymbol", "fdr"]]
    .drop_duplicates()
    .sort_values("fdr")["source_genesymbol"][:n_top_ligands]
    .tolist()
)

In [None]:
perturbation="tumor"
baseline="blood"

In [None]:
title_plot = f"{perturbation} vs {baseline}: {cell_type_oi}, FDR<0.1"

In [None]:
save_name_plot =  f"{perturbation}_vs_{baseline}_{cell_type_oi}_fdr_0.1"

In [None]:
heatmap = cpdba.plot_result(
    cpdb_sig_int.loc[lambda x: x["source_genesymbol"].isin(top_genes)],
    title=title_plot,
    aggregate=False,
    cluster="heatmap",
    label_limit=110,
)
heatmap

In [None]:
heatmap = cpdba.plot_result(
    cpdb_sig_int.loc[lambda x: x["source_genesymbol"].isin(top_genes)],
    title=title_plot,
    aggregate=False,
    cluster="heatmap",
    label_limit=110,
)
heatmap

In [None]:
heatmap = cpdba.plot_result(
    cpdb_sig_int.loc[lambda x: x["source_genesymbol"].isin(top_genes)],
    title=title_plot,
    aggregate=False,
    cluster="heatmap",
    label_limit=110,
)
heatmap

In [None]:
# Save the heatmap as PNG, SVG and PDF below `{resDir}/figures`.
figures_dir = f'{resDir}/figures'
# Make sure the target directory exists before writing into it.
os.makedirs(figures_dir, exist_ok=True)
for extension in ("png", "svg", "pdf"):
    heatmap.save(f'{figures_dir}/{save_name_plot}.{extension}')

## CIRCOS PLOT 

## This is input for CIRCOS PLOT 
input = "/data/projects/2022/CRCA/results/v0.1/crc-atlas-dataset/latest/ds_analyses/liana_cell2cell/core_atlas/tumor_normal/epithelial_cancer.csv"

Circos plot script: "/data/scratch/kvalem//projects/2022/crc-atlas/analyses/05_Liana/circosplot.Rmd"

In [None]:
resDir

# Functions

In [None]:
def fdr_correction(df, pvalue_col="pvalue", *, key_added="fdr", inplace=False):
    """
    Add FDR-adjusted p-values to a data frame of test results.

    Parameters
    ----------
    df
        data frame with one test result per row
    pvalue_col
        column of `df` that holds the raw p-values
    key_added
        name of the column the adjusted values are written to
    inplace
        If True, the new column is written into `df` itself and nothing is returned.
        Otherwise a copy of `df` is modified and returned.

    Returns
    -------
    The modified copy of `df`, or None if `inplace` is True.
    """
    target = df if inplace else df.copy()
    raw_pvalues = target[pvalue_col].values
    # only the second element of the fdrcorrection result (the corrected p-values) is kept
    target[key_added] = statsmodels.stats.multitest.fdrcorrection(raw_pvalues)[1]
    return None if inplace else target

In [None]:
"""Plotting functions for group comparisons"""

import altair as alt
import pandas as pd
import numpy as np


def plot_lm_result_altair(
    df,
    p_cutoff=0.1,
    p_col="fdr",
    x="variable",
    y="group",
    color="coef",
    title="heatmap",
    cluster=False,
    value_max=None,
    configure=lambda x: x.configure_mark(opacity=1),
    cmap="redblue",
    reverse=True,
    domain=lambda x: [-x, x],
    order=None,
):
    """
    Plot a results data frame of a comparison as a heatmap

    Each row of `df` becomes one heatmap cell at position (`x`, `y`), coloured by
    `color`. Cells with a value in `p_col` below 0.1 / 0.01 / 0.001 additionally get
    a point marker whose size encodes the significance class. These marker
    thresholds are fixed and independent of `p_cutoff`.

    Parameters
    ----------
    df
        results data frame. Must contain the columns named by `p_col`, `x`, `y`
        and `color`. It is not modified.
    p_cutoff
        an `x` entry is plotted if at least one of its rows has `p_col` < `p_cutoff`
    p_col
        column with the (adjusted) p-values
    x
        column shown on the x axis
    y
        column shown on the y axis
    color
        column with the values that determine the cell colour
    title
        chart title
    cluster
        If True (and `order` is None), order the x axis by hierarchical clustering
        (average linkage, euclidean distance) of the `color` values.
    value_max
        Symmetric limit passed to `domain`. Defaults to the largest absolute value
        of `color` among the plotted rows.
    configure
        callback applied to the final layered chart before it is returned
    cmap
        name of the colour scheme
    reverse
        whether to reverse the colour scheme
    domain
        function that maps `value_max` to the colour domain. The plotted values are
        clipped to this domain.
    order
        explicit sort order of the x axis. If None, "ascending" is used unless
        `cluster` is True.

    Returns
    -------
    The chart returned by `configure`, or None if there is nothing to plot.
    """
    df_filtered = df.loc[lambda _: _[p_col] < p_cutoff, :]
    # Explicit copy: columns are added/overwritten below, and doing that on a
    # selection of `df` triggers pandas' SettingWithCopyWarning.
    df_subset = df.loc[
        lambda _: _[x].isin(df_filtered[x].unique()) & _[y].isin(df[y].unique())
    ].copy()
    if not df_subset.shape[0]:
        print("No values to plot")
        return

    if order is None:
        order = "ascending"
        if cluster:
            from scipy.cluster.hierarchy import linkage, leaves_list

            values_df = df_subset.pivot(index=y, columns=x, values=color)
            order = values_df.columns.values[
                leaves_list(
                    linkage(values_df.values.T, method="average", metric="euclidean")
                )
            ]

    def _get_significance(fdr):
        """Map a p-value to its significance class label (NaN if not below 0.1)."""
        if fdr < 0.001:
            return "< 0.001"
        elif fdr < 0.01:
            return "< 0.01"
        elif fdr < 0.1:
            return "< 0.1"
        else:
            return np.nan

    df_subset["FDR"] = pd.Categorical([_get_significance(x) for x in df_subset[p_col]])

    if value_max is None:
        value_max = max(
            abs(np.nanmin(df_subset[color])), abs(np.nanmax(df_subset[color]))
        )
    # just setting the domain in altair will lead to "black" fields. Therefore, we constrain the values themselves.
    df_subset[color] = np.clip(df_subset[color], *domain(value_max))
    return configure(
        alt.Chart(df_subset, title=title)
        .mark_rect()
        .encode(
            x=alt.X(x, sort=order),
            y=y,
            color=alt.Color(
                color,
                scale=alt.Scale(scheme=cmap, reverse=reverse, domain=domain(value_max)),
            ),
        )
        # second layer: significance markers, only for rows with a significance class
        + alt.Chart(df_subset.loc[lambda x: ~x["FDR"].isnull()])
        .mark_point(color="white", filled=True, stroke="black", strokeWidth=0)
        .encode(
            x=alt.X(x, sort=order),
            y=y,
            size=alt.Size(
                "FDR:N",
                scale=alt.Scale(
                    domain=["< 0.001", "< 0.01", "< 0.1"],
                    range=4 ** np.array([3, 2, 1]),
                ),
            ),
        )
    )

In [None]:
from typing import Sequence, Union
from anndata import AnnData, ImplicitModificationWarning
import numpy as np
import pandas as pd
from operator import and_
from functools import reduce
import warnings


def pseudobulk(
    adata,
    *,
    groupby: Union[str, Sequence[str]],
    aggr_fun=np.sum,
    min_obs=10,
) -> AnnData:
    """
    Calculate Pseudobulk of groups

    Parameters
    ----------
    adata
        annotated data matrix
    groupby
        One or multiple columns to group by
    aggr_fun
        Callback function to calculate pseudobulk. Must be a numpy ufunc supporting
        the `axis` attribute.
    min_obs
        Exclude groups with less than `min_obs` observations

    Returns
    -------
    New anndata object with same vars as input, but reduced number of obs.
    Its `obs` holds one row per retained group with the `groupby` columns and
    `n_obs`, the number of observations aggregated into that group.

    Raises
    ------
    ValueError
        If no group has at least `min_obs` observations.
    """
    # always work with a list: a tuple passed to `.loc` would not select columns
    groupby = [groupby] if isinstance(groupby, str) else list(groupby)

    combinations = adata.obs.loc[:, groupby].drop_duplicates()

    if adata.is_view:
        # for whatever reason, the pseudobulk function is terribly slow when operating on a view.
        adata = adata.copy()

    # precompute one boolean mask per (column, value) pair
    masks = {
        col: {val: adata.obs[col] == val for val in combinations[col].unique()}
        for col in groupby
    }

    expr_agg = []
    obs = []

    # Plain tuples (name=None): named tuples would rename columns that are not valid
    # identifiers or start with an underscore to positional names (`_0`, `_1`, ...).
    for comb in combinations.itertuples(index=False, name=None):
        mask = reduce(and_, (masks[col][val] for col, val in zip(groupby, comb)))
        n_obs = np.sum(mask)
        if n_obs < min_obs:
            continue
        expr_row = aggr_fun(adata.X[mask, :], axis=0)
        obs_row = dict(zip(groupby, comb))
        obs_row["n_obs"] = n_obs
        # convert matrix to array if required (happens when aggregating sparse matrix)
        try:
            expr_row = expr_row.A1
        except AttributeError:
            pass
        obs.append(obs_row)
        expr_agg.append(expr_row)

    if not expr_agg:
        # np.vstack would fail on an empty list with a message that hides the cause
        raise ValueError(
            f"No group defined by {groupby} has at least min_obs={min_obs} observations."
        )

    with warnings.catch_warnings():
        warnings.simplefilter("ignore", ImplicitModificationWarning)
        return AnnData(
            X=np.vstack(expr_agg),
            var=adata.var,
            obs=pd.DataFrame.from_records(obs),
        )

In [None]:
"""Helper functions for cellphonedb analysis

Focuses on differential cellphonedb analysis between conditions.
"""
from typing import List, Literal
import pandas as pd
#from .pseudobulk import pseudobulk
import numpy as np
import scanpy as sc
import altair as alt
#from .compare_groups.pl import plot_lm_result_altair
#from .util import fdr_correction


class CpdbAnalysis:
    def __init__(
        self, cpdb, adata, *, pseudobulk_group_by: List[str], cell_type_column: str
    ):
        """
        Class that handles comparative cellphonedb analysis.

        Parameters
        ----------
        cpdb
            pandas data frame with cellphonedb interactions.
            Required columns: `source_genesymbols`, `target_genesymbol`.
            You can get this from omnipathdb:
            https://omnipathdb.org/interactions/?fields=sources,references&genesymbols=1&databases=CellPhoneDB
        adata
            Anndata object with the target cells. Will use this to derive mean fraction of expressed cells.
            Should contain counts in X.
        pseudobulk_group_by
            See :func:`scanpy_helper.pseudobulk.pseudobulk`. Pseudobulk is used to compute the mean fraction
            of expressed cells by patient
        cell_type_column
            Column in anndata that contains the cell-type annotation.
        """
        self.cpdb = cpdb
        self.cell_type_column = cell_type_column
        self._find_expressed_genes(adata, pseudobulk_group_by)

    def _find_expressed_genes(self, adata, pseudobulk_group_by):
        """Compute the mean expression and fraction of expressed cells per cell-type.
        This is performed on the pseudobulk level, i..e. the mean of means per patient is calculated.
        """
        pb_fracs = pseudobulk(
            adata,
            groupby=pseudobulk_group_by + [self.cell_type_column],
            aggr_fun=lambda x, axis: np.sum(x > 0, axis) / x.shape[axis],  # type: ignore
        )
        fractions_expressed = pseudobulk(
            pb_fracs, groupby=self.cell_type_column, aggr_fun=np.mean
        )
        fractions_expressed.obs.set_index(self.cell_type_column, inplace=True)

        pb = pseudobulk(
            adata,
            groupby=pseudobulk_group_by + [self.cell_type_column],
        )
        sc.pp.normalize_total(pb, target_sum=1e6)
        sc.pp.log1p(pb)
        pb_mean_cell_type = pseudobulk(
            pb, groupby=self.cell_type_column, aggr_fun=np.mean
        )
        pb_mean_cell_type.obs.set_index(self.cell_type_column, inplace=True)

        self.expressed_genes = (
            fractions_expressed.to_df()
            .melt(ignore_index=False, value_name="fraction_expressed")
            .reset_index()
            .merge(
                pb_mean_cell_type.to_df()
                .melt(ignore_index=False, value_name="expr_mean")
                .reset_index(),
                on=[self.cell_type_column, "variable"],
            )
        )

    def significant_interactions(
        self,
        de_res: pd.DataFrame,
        *,
        pvalue_col="pvalue",
        fc_col="log2FoldChange",
        gene_symbol_col="gene_id",
        max_pvalue=0.1,
        min_abs_fc=1,
        adjust_fdr=True,
        min_frac_expressed=0.1,
        de_genes_mode: Literal["ligand", "receptor"] = "ligand",
    ) -> pd.DataFrame:
        """
        Generates a data frame of differentiall cellphonedb interactions.

        This function will extract all known ligands (or receptors, respectively) from a list of differentially expressed
        and find all receptors (or ligands, respectively) that are expressed above a certain cutoff in all cell-types.

        Parameters:
        -----------
        de_res
            List of differentially expressed genes
        pvalue_col
            column in de_res that contains the pvalue or false discovery rate
        gene_id_col
            column in de_res that contains the gene symbol
        min_frac_expressed
            Minimum fraction cells that need to express the receptor (or ligand) to be considered a potential interaction
        de_genes_mode
            If the list of de genes provided are ligands (default) or receptors. In case of `ligand`, cell-types
            that express corresonding receptors above the threshold will be identified. In case of `receptor`,
            cell-types that express corresponding ligands above the threshold will be identified.
        adjust_fdr
            If True, calculate false discovery rate on the pvalue, after filtering for genes that are contained
            in the cellphonedb.
        """
        if de_genes_mode == "ligand":
            cpdb_de_col = "source_genesymbol"
            cpdb_expr_col = "target_genesymbol"
        elif de_genes_mode == "receptor":
            cpdb_de_col = "target_genesymbol"
            cpdb_expr_col = "source_genesymbol"
        else:
            raise ValueError("Invalud value for de_genes_mode!")

        de_res = de_res.loc[lambda x: x[gene_symbol_col].isin(self.cpdb[cpdb_de_col])]
        if adjust_fdr:
            de_res = fdr_correction(de_res, pvalue_col=pvalue_col, key_added="fdr")
            pvalue_col = "fdr"

        significant_genes = de_res.loc[
            lambda x: (x[pvalue_col] < max_pvalue) & (np.abs(x[fc_col]) >= min_abs_fc),
            gene_symbol_col,
        ].unique()  # type: ignore
        significant_interactions = self.cpdb.loc[
            lambda x: x[cpdb_de_col].isin(significant_genes)
        ]

        res_df = (
            self.expressed_genes.loc[
                lambda x: x["fraction_expressed"] >= min_frac_expressed
            ]  # type: ignore
            .merge(
                significant_interactions,
                left_on="variable",
                right_on=cpdb_expr_col,
            )
            .drop(columns=["variable"])
            .merge(de_res, left_on=cpdb_de_col, right_on=gene_symbol_col)
            .drop(columns=[gene_symbol_col])
        )

        return res_df

    def plot_result(
        self,
        cpdb_res,
        *,
        pvalue_col="fdr",
        group_col="group",
        fc_col="log2FoldChange",
        title="CPDB analysis",
        aggregate=True,
        clip_fc_at=(-5, 5),
        label_limit=100,
        cluster: Literal["heatmap", "dotplot"] = "dotplot",
        de_genes_mode: Literal["ligand", "receptor"] = "ligand",
    ):
        """
        Plot cpdb results as heatmap

        Parameters
        ----------
        cpdb_res
            result of `significant_interactions`. May be further filtered or modified.
        group_col
            column to be used for the y axis of the heatmap
        aggregate
            whether to merge multiple targets of the same ligand into a single column
        de_genes_mode
            If the list of de genes provided are ligands (default) or receptors. If receptor, will show the dotplot
            at the top (source are expressed ligands) and the de heatmap at the bottom (target are the DE receptors).
            Otherwise the other way round.
        """
        if de_genes_mode == "ligand":
            cpdb_de_col = "source_genesymbol"
            cpdb_expr_col = "target_genesymbol"
        elif de_genes_mode == "receptor":
            cpdb_de_col = "target_genesymbol"
            cpdb_expr_col = "source_genesymbol"
        else:
            raise ValueError("Invalud value for de_genes_mode!")

        cpdb_res[fc_col] = np.clip(cpdb_res[fc_col], *clip_fc_at)

        # aggregate if there are multiple receptors per ligand
        if aggregate:
            cpdb_res = (
                cpdb_res.groupby(
                    [
                        self.cell_type_column,
                        cpdb_de_col,
                        fc_col,
                        pvalue_col,
                        group_col,
                    ]
                )
                .agg(
                    n=(cpdb_expr_col, len),
                    fraction_expressed=("fraction_expressed", np.max),
                    expr_mean=("expr_mean", np.max),
                )
                .reset_index()
                .merge(
                    cpdb_res.groupby(cpdb_de_col).agg(
                        **{
                            cpdb_expr_col: (
                                cpdb_expr_col,
                                lambda x: "|".join(np.unique(x)),
                            )
                        }
                    ),
                    on=cpdb_de_col,
                )
            )

        cpdb_res["interaction"] = [
            f"{s}_{t}" for s, t in zip(cpdb_res[cpdb_de_col], cpdb_res[cpdb_expr_col])
        ]

        # cluster heatmap
        if cluster is not None:
            from scipy.cluster.hierarchy import linkage, leaves_list

            _idx = self.cell_type_column if cluster == "dotplot" else group_col
            _values = "fraction_expressed" if cluster == "dotplot" else fc_col
            _columns = "interaction"
            values_df = (
                cpdb_res.loc[:, [_idx, _values, _columns]]
                .drop_duplicates()
                .pivot(
                    index=_idx,
                    columns=_columns,
                    values=_values,
                )
                .fillna(0)
            )
            order = values_df.columns.values[
                leaves_list(
                    linkage(values_df.values.T, method="average", metric="euclidean")
                )
            ]
        else:
            order = "ascending"

        p1 = plot_lm_result_altair(
            cpdb_res,
            color=fc_col,
            p_col=pvalue_col,
            x="interaction",
            configure=lambda x: x,
            title="",
            order=order,
            p_cutoff=1,
        ).encode(
            x=alt.X(
                title=None,
                axis=alt.Axis(
                    labelExpr="split(datum.label, '_')[0]",
                    orient="top" if de_genes_mode == "receptor" else "bottom",
                ),
            )
        )

        p2 = (
            alt.Chart(cpdb_res)
            .mark_circle()
            .encode(
                x=alt.X(
                    "interaction",
                    axis=alt.Axis(
                        grid=True,
                        orient="bottom" if de_genes_mode == "receptor" else "top",
                        title=None,
                        labelExpr="split(datum.label, '_')[1]",
                        labelLimit=label_limit,
                    ),
                    sort=order,
                ),
                y=alt.Y(self.cell_type_column, axis=alt.Axis(grid=True), title=None),
                size=alt.Size("fraction_expressed"),
                color=alt.Color("expr_mean", scale=alt.Scale(scheme="cividis")),
            )
        )

        if de_genes_mode == "receptor":
            p1, p2 = p2, p1

        return (
            alt.vconcat(p1, p2, title=title)
            .resolve_scale(size="independent", color="independent", x="independent")
            .configure_mark(opacity=1)
            .configure_concat(spacing=label_limit - 130)
        )