In [None]:
from IPython.display import display, HTML
# Widen the notebook's main container to 95% of the browser window.
wide_container_css = "<style>.container { width:95% !important; }</style>"
display(HTML(wide_container_css))

In [None]:
import os 
import sys

import pandas as pd
import numpy as np
import scanpy as sc
import glob
import json 
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.pyplot import rc_context
from sklearn.metrics import precision_recall_curve, auc

sys.path.append('../..')
from data.load_data import load_datasets
from data.constants import BASE_PATH_EXPERIMENTS, BASE_PATH_DATA

from signaturescoring import score_signature
from signaturescoring.utils.utils import get_mean_and_variance_gene_expression, check_signature_genes

# Global figure style: pdf.fonttype 42 (TrueType) keeps text editable in exported PDFs; Arial, 14 pt.
plt.rcParams.update({'pdf.fonttype':42, 'font.family':'sans-serif', 'font.sans-serif':'Arial', 'font.size':14})

### Helper functions

In [None]:
def get_sig_from_emtome_sig_file(filepath):
    """Read the gene list of an EMTome signature text file.

    The first two lines are skipped as a header. Every remaining line yields
    one entry: the line without its first character and without its trailing
    newline.

    Args:
        filepath: Path to the signature ``.txt`` file.

    Returns:
        list of str: One entry per line after the two header lines.

    Raises:
        AssertionError: If ``filepath`` does not exist.
    """
    assert os.path.exists(filepath)
    with open(filepath, 'r') as f:
        lines = f.readlines()[2:]

    # The former slice x[1:-1] assumed every line ends with '\n'; for a final
    # line without a newline it cut off the last character of the gene name.
    # Remove the newline only when it is actually present.
    return [line[1:-1] if line.endswith('\n') else line[1:] for line in lines]

In [None]:
def get_path(base, add_folder=(), fn=None):
    """Build a path below ``base`` and make sure its folder exists.

    Args:
        base: Root directory.
        add_folder: Sequence of sub-folder names appended to ``base``.
        fn: Optional file name appended after the sub-folders. Only the
            enclosing folder is created; the file itself is never created.

    Returns:
        str: ``base/<add_folder...>`` or, if ``fn`` is given,
        ``base/<add_folder...>/fn``.
    """
    folder = os.path.join(base, *add_folder)
    # Previously the full path *including* `fn` was passed to makedirs, which
    # created a directory named after the file.
    if not os.path.exists(folder):
        os.makedirs(folder, exist_ok=True)
        print(f"Creating folder {folder}")
    return folder if fn is None else os.path.join(folder, fn)

In [None]:
def get_cancer_emt_barcodes(dataset, mode=1):
    """Load the barcodes of cancer cells annotated as expressing EMT.

    The breast dataset has one file per ``mode``
    (``barcodes_cancer_emt_<mode>.csv``); every other dataset has a single
    ``barcodes_cancer_emt.csv``. The resolved path is printed before reading.

    Returns:
        pandas.Series: Column ``'0'`` of the CSV, renamed ``'cancer_emt_cells'``.
    """
    mode_suffix = f'_{mode}' if dataset == 'breast' else ''
    fn = os.path.join(
        BASE_PATH_EXPERIMENTS,
        f'EMT_signature_scoring_case_study/{dataset}/barcodes_cancer_emt{mode_suffix}.csv',
    )
    print(fn)
    barcodes = pd.read_csv(fn)['0']
    barcodes.name = 'cancer_emt_cells'
    return barcodes

In [None]:
def creat_celltype_emt_col(adata, celltype_col, cancer_emt_label, caf_label, cancer_no_emt_label, barcodes_cancer_emt, mal_cells_name, fibro_name):
    """Add a categorical ``obs['celltype_emt']`` column splitting cancer cells by EMT status.

    Starting from ``adata.obs[celltype_col]`` (converted to str), labels are
    overwritten in this order, so later steps win:

    1. cells listed in ``barcodes_cancer_emt`` -> ``cancer_emt_label``;
    2. cells of type ``mal_cells_name`` that are not listed -> ``cancer_no_emt_label``;
    3. cells of type ``fibro_name`` -> ``caf_label``.

    Args:
        adata: Object with an ``obs`` DataFrame indexed by cell barcode.
        celltype_col: Column of ``obs`` holding the original cell-type labels.
        cancer_emt_label: Label for listed (EMT-expressing) cells.
        caf_label: Label for fibroblasts.
        cancer_no_emt_label: Label for malignant cells that are not listed.
        barcodes_cancer_emt: Iterable of barcodes of cancer cells expressing EMT.
        mal_cells_name: Value of ``celltype_col`` marking malignant cells.
        fibro_name: Value of ``celltype_col`` marking fibroblasts.

    Returns:
        The same ``adata`` object, with ``obs['celltype_emt']`` set in place.

    Raises:
        KeyError: If a barcode in ``barcodes_cancer_emt`` is not in ``adata.obs``.
    """
    obs = adata.obs
    celltypes = obs[celltype_col]

    missing = pd.Index(barcodes_cancer_emt).difference(obs.index)
    if len(missing) > 0:
        raise KeyError(f'{len(missing)} EMT barcodes not found in adata.obs, e.g. {list(missing[:5])}')

    is_cancer_emt = obs.index.isin(barcodes_cancer_emt)
    # These masks previously read `adata.obs.celltype`, silently ignoring
    # `celltype_col`; use the requested column consistently.
    is_malignant = (celltypes == mal_cells_name).to_numpy()
    is_fibro = (celltypes == fibro_name).to_numpy()

    labels = celltypes.astype(str)
    labels[is_cancer_emt] = cancer_emt_label
    labels[is_malignant & ~is_cancer_emt] = cancer_no_emt_label
    labels[is_fibro] = caf_label
    obs['celltype_emt'] = labels.astype('category')

    return adata

In [None]:
def get_hue_order(adata, cancer_emt_label, caf_label,cancer_no_emt_label):
    """Return the display order of the ``celltype_emt`` categories.

    The three given labels come first, in argument order, followed by every
    other label present in ``adata.obs['celltype_emt']`` in sorted order.
    """
    leading = [cancer_emt_label, caf_label, cancer_no_emt_label]
    remaining = set(adata.obs['celltype_emt'].unique()) - set(leading)
    return leading + sorted(remaining)

In [None]:
def get_and_plot_celltype_proportions(ds, adata, order_hue, sample_col='sample_id', figsize=(20, 10)):
    """Plot per-sample cell-type composition as stacked percentage bars.

    Args:
        ds: Dataset name; shown upper-cased in the title.
        adata: Object whose ``obs`` has ``sample_col`` and ``'celltype_emt'`` columns.
        order_hue: Order of the ``celltype_emt`` categories in the stacks/legend;
            every entry must be a crosstab column, otherwise a KeyError is raised.
        sample_col: ``obs`` column identifying the sample of each cell.
        figsize: Figure size forwarded to ``DataFrame.plot``.

    Returns:
        tuple: ``(figure, counts)`` where ``counts`` is the sample x cell-type
        table of absolute cell counts (the plot itself shows percentages).
    """
    obs_sub = adata.obs[[sample_col, 'celltype_emt']]
    per_sample = obs_sub[sample_col]
    per_type = obs_sub['celltype_emt']

    # Row-normalised shares -> percentages (2 decimals), columns in plotting order.
    shares = pd.crosstab(index=per_sample, columns=per_type, normalize="index")[order_hue]
    percentages = np.round(shares * 100, decimals=2)

    plt.rcParams.update({'pdf.fonttype':42, 'font.family':'sans-serif', 'font.sans-serif':'Arial', 'font.size':14})

    ax = percentages.plot(kind='bar', stacked=True, colormap='tab20', figsize=figsize)

    plt.legend(loc='center left', bbox_to_anchor=(1.0, 0.5), ncol=1, fontsize=18)
    plt.xlabel("Sample ID", fontsize=22)
    plt.ylabel("Celltype proportions (%)", fontsize=22)
    plt.title(f"Celltype proportions per sample for {ds.upper()} cancer.", fontsize=24)
    plt.xticks(fontsize=18)
    plt.yticks(np.arange(0, 101, 5), fontsize=18)
    plt.tight_layout()

    counts = pd.crosstab(index=per_sample, columns=per_type)
    return ax.figure, counts
    

In [None]:
def create_violin_plot(adatas, hue_orders, row_order, col_order):
    """Draw a grid of violin plots of per-cell values grouped by ``celltype_emt``.

    Args:
        adatas: dict mapping lower-case dataset name -> AnnData with ``obs['celltype_emt']``.
        hue_orders: dict mapping lower-case dataset name -> order of ``celltype_emt`` groups.
        row_order: Keys passed to ``sc.pl.violin`` (e.g. score columns), one per
            row; also used as the y-axis label of the first column.
        col_order: Dataset names, one per column; lower-cased to index
            ``adatas``/``hue_orders`` and used verbatim as column titles.

    Returns:
        matplotlib.figure.Figure: The figure holding the grid.
    """
    plt.rcParams.update({'pdf.fonttype':42, 'font.family':'sans-serif', 'font.sans-serif':'Arial', 'font.size':14})
    n_rows, n_cols = len(row_order), len(col_order)
    # squeeze=False keeps `axes` 2-D, so axes[i, j] also works for a single row
    # or column (by default plt.subplots returns a 1-D array or a bare Axes then).
    fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, sharex='col', sharey=False,
                             figsize=(n_cols*5.7, n_rows*4.4), squeeze=False)
    for i, row in enumerate(row_order):
        for j, col in enumerate(col_order):
            ax = axes[i, j]
            ds_key = col.lower()
            sc.pl.violin(adatas[ds_key], keys=row, groupby='celltype_emt', order=hue_orders[ds_key],
                         rotation=90, stripplot=False, ax=ax, show=False)
            # Only the first column carries a y-label; only the top row a title.
            if j != 0:
                ax.set_ylabel(None)
            else:
                ax.set_ylabel(row, fontsize=20)
            if i == 0:
                ax.set_title(col, fontsize=26)
            # x axes are shared per column, so enlarge x tick labels on the bottom row only.
            if i == n_rows - 1:
                ax.tick_params(axis='x', labelsize=20)
            ax.tick_params(axis='y', labelsize=16)
    fig.tight_layout()
    return fig

### Global variables 

In [None]:
# Whether figures/tables are written to disk.
save = True
pl_size = 6  # NOTE(review): not referenced in the visible cells — confirm it is still needed.
# Barcode-file variant for the breast dataset; mode 1 also removes the "gray area" cells on load.
mode = 1

# Input signatures and experiment output roots.
base_path_emt_signatures = os.path.join(BASE_PATH_DATA, 'annotations/emt')
base_emt_exp_path = os.path.join(BASE_PATH_EXPERIMENTS, 'EMT_signature_scoring_case_study')
if save:
    base_storing_path = get_path(base_emt_exp_path, add_folder=['results'])

In [None]:
# Labels written into obs['celltype_emt'] by creat_celltype_emt_col.
cancer_emt_label = 'Cancer expr. EMT'
caf_label = 'Fibroblast'
cancer_no_emt_label = 'Cancer not expr. EMT'

# Column names of the three AUCPR comparisons in the performance table.
aucpr_lbl_emt_rest = 'AUCPR cancer expr. EMT vs. rest'
aucpr_lbl_emt_cafs = 'AUCPR cancer expr. EMT vs. CAFs'
aucpr_lbl_emt_mal = 'AUCPR cancer expr. EMT vs. cancer not expr. EMT'


# Datasets analysed (keys of `adatas`); 'breast' is loaded as 'breast_small'.
datasets = ['crc', 'escc', 'luad_xing', 'breast']

### Load data and mark cancer cells expressing EMT

In [None]:
adatas = {}

for ds in datasets:
    print(f'Loading dataset {ds}.')
    # The breast data is loaded from its 'breast_small' variant.
    load_name = 'breast_small' if ds == 'breast' else ds
    adata = load_datasets(load_name, preprocessed=True, norm_method='mean')

    if ds == 'luad_xing':
        # Exclude granulocytes from the LUAD (Xing) data.
        adata = adata[adata.obs.celltype != 'Granulocytes'].copy()
    elif ds == 'breast' and mode == 1:
        # Drop the barcodes listed in the second column of barcodes_to_remove.csv.
        remove_csv = os.path.join(base_emt_exp_path, ds, 'barcodes_to_remove.csv')
        barcodes_to_remove = pd.read_csv(remove_csv).iloc[:, 1].tolist()
        print(adata.shape, f'removing {len(barcodes_to_remove)} cells in gray area')
        adata = adata[~adata.obs.index.isin(barcodes_to_remove)].copy()
        print(adata.shape)

    # Ensure uns['log1p']['base'] is None, keeping any other log1p metadata.
    if 'log1p' in adata.uns_keys():
        adata.uns['log1p']['base'] = None
    else:
        adata.uns['log1p'] = {'base': None}

    adatas[ds] = adata

In [None]:
# Malignant and fibroblast labels in each dataset's `celltype` column. An explicit
# mapping replaces the former positional zip with adatas.items(), which silently
# paired the wrong labels with a dataset if `datasets` was reordered.
celltype_names = {
    'crc': ('Epi', 'Fibro'),
    'escc': ('Epi', 'Fibroblasts'),
    'luad_xing': ('Malignant', 'Fibroblast'),
    'breast': ('Cancer Epithelial', 'CAFs'),
}
for ds, adata in adatas.items():
    mal_name, caf_name = celltype_names[ds]
    barcodes_cancer_emt = get_cancer_emt_barcodes(ds, mode=mode)
    adatas[ds] = creat_celltype_emt_col(adata, 
                                        'celltype', 
                                        cancer_emt_label,
                                        caf_label,
                                        cancer_no_emt_label,
                                        barcodes_cancer_emt, 
                                        mal_name, 
                                        caf_name)

    print(ds, adatas[ds].obs['celltype_emt'].value_counts().sort_index())
    print()

In [None]:
# get_hue_order's parameters are (adata, cancer_emt_label, caf_label, cancer_no_emt_label).
# The labels are passed by keyword because the previous positional call swapped
# caf_label and cancer_no_emt_label. NOTE: with the fix the plotting order becomes
# [cancer EMT, fibroblast, cancer non-EMT, ...sorted rest].
hue_orders = {}
for ds, adata in adatas.items():
    hue_orders[ds] = get_hue_order(adata,
                                   cancer_emt_label=cancer_emt_label,
                                   caf_label=caf_label,
                                   cancer_no_emt_label=cancer_no_emt_label)

### Get cell-type proportions per dataset

In [None]:
plt.rcParams.update({'pdf.fonttype':42, 'font.family':'sans-serif', 'font.sans-serif':'Arial', 'font.size':14})
curr_storing_path = get_path(base_emt_exp_path, add_folder=['results', 'dataset_composition'])
for ds in datasets:
    # The second return value holds absolute cell counts per sample and cell type.
    fig, celltype_counts = get_and_plot_celltype_proportions(ds, adatas[ds], hue_orders[ds], sample_col='sample_id', figsize=(20, 10))

    if save:
        fig.savefig(os.path.join(curr_storing_path, f'celltype_proportions_{ds}.svg'))
        celltype_counts.to_csv(os.path.join(curr_storing_path, f'celltype_proportions_{ds}.csv'))
    # pyplot.show() does not take a figure argument; show the open figure(s) without one.
    plt.show()

## Get all signatures to compare

In [None]:
# EMTome pan-cancer signatures: one .txt file per signature, keyed by the file
# name up to its first '.'.
emtome_files = sorted(glob.glob(base_path_emt_signatures+"/sigs_from_emtome/pan_cancer/*.txt"))
sig_paths = {path.split('/')[-1].split('.')[0]: path for path in emtome_files}
sig_list = {name: get_sig_from_emtome_sig_file(path) for name, path in sig_paths.items()}

# Hallmark EMT gene set, read from a JSON export (MSigDB-style, per the file name).
with open(base_path_emt_signatures+'/HALLMARK_EPITHELIAL_MESENCHYMAL_TRANSITION.v7.5.1.json', 'r') as f:
    hemt = json.load(f)
sig_list['hallmark_emt'] = hemt['HALLMARK_EPITHELIAL_MESENCHYMAL_TRANSITION']['geneSymbols']

# pEMT column of the Barkley et al. 2022 gene-module table (empty cells dropped).
GM_B_22 = pd.read_csv(base_path_emt_signatures+'/gene_modules_from_Barkley_et_al_2022.csv')
pEMT_gm = GM_B_22['pEMT']
sig_list['pEMT_gm'] = pEMT_gm.dropna().tolist()

In [None]:
# ESOPHAG_CANCER_EMT_SIGNATURE_1 = pd.read_csv(os.path.join(base_emt_exp_path,'escc', 'dataset_specific_emt_sig', 'ESOPHAG_CANCER_EMT_SIGNATURE_1.csv'))
# sig_list['ESCC EMT signature 1'] = ESOPHAG_CANCER_EMT_SIGNATURE_1.iloc[:,1].tolist()

# ESOPHAG_CANCER_EMT_SIGNATURE_2 = pd.read_csv(os.path.join(base_emt_exp_path, 'escc', 'dataset_specific_emt_sig',  'ESOPHAG_CANCER_EMT_SIGNATURE_2.csv'))
# sig_list['ESCC EMT signature 2'] = ESOPHAG_CANCER_EMT_SIGNATURE_2.iloc[:,1].tolist()

# LUNG_CANCER_EMT_SIGNATURE_1 = pd.read_csv(os.path.join(base_emt_exp_path, 'luad_xing', 'dataset_specific_emt_sig','LUNG_CANCER_EMT_SIGNATURE_1.csv'))
# sig_list['LUAD EMT signature 1'] = LUNG_CANCER_EMT_SIGNATURE_1.iloc[:,1].tolist()

# LUNG_CANCER_EMT_SIGNATURE_2 = pd.read_csv(os.path.join(base_emt_exp_path, 'luad_xing', 'dataset_specific_emt_sig','LUNG_CANCER_EMT_SIGNATURE_2.csv'))
# sig_list['LUAD EMT signature 2'] = LUNG_CANCER_EMT_SIGNATURE_2.iloc[:,1].tolist()

# LUNG1_ESCC2_CANCER_EMT_SIGNATURE_1 = pd.read_csv(os.path.join(base_emt_exp_path, 'escc', 'union_emt_sigs','LUNG1_ESCC2_CANCER_EMT_SIGNATURE_1.csv'))
# sig_list['ESCC and LUAD EMT signature 1'] = LUNG1_ESCC2_CANCER_EMT_SIGNATURE_1.iloc[:,1].tolist()

# LUNG1_ESCC2_CANCER_EMT_SIGNATURE_2 = pd.read_csv(os.path.join(base_emt_exp_path, 'escc', 'union_emt_sigs','LUNG1_ESCC2_CANCER_EMT_SIGNATURE_2.csv'))
# sig_list['ESCC and LUAD EMT signature 2'] = LUNG1_ESCC2_CANCER_EMT_SIGNATURE_2.iloc[:,1].tolist()

# LUNG1_ESCC2_CANCER_EMT_SIGNATURE_3 = pd.read_csv(os.path.join(base_emt_exp_path, 'escc', 'union_emt_sigs','LUNG1_ESCC2_CANCER_EMT_SIGNATURE_3.csv'))
# sig_list['ESCC and LUAD EMT signature 3'] = LUNG1_ESCC2_CANCER_EMT_SIGNATURE_3.iloc[:,1].tolist()


In [None]:
# Combined ESCC/LUAD cancer-EMT signature (variant 3); gene symbols are in the CSV's second column.
escc_luad_csv = os.path.join(base_emt_exp_path, 'escc', 'union_emt_sigs', 'LUNG1_ESCC2_CANCER_EMT_SIGNATURE_3.csv')
LUNG1_ESCC2_CANCER_EMT_SIGNATURE_3 = pd.read_csv(escc_luad_csv)
# Unused alternative: the same list extended with ['FOSL1', 'MLLT11'].
sig_list['ESCC and LUAD EMT signature'] = LUNG1_ESCC2_CANCER_EMT_SIGNATURE_3.iloc[:, 1].tolist()

In [None]:
# sig_list['ESCC and LUAD EMT signature (small)'] = ['LAMC2', 'SERPINE1', 'PRSS8', 'SERPINE2', 'TNC', 'FLNA',
#                                                    'TGFBI', 'ANGPTL4', 'AREG', 'BMP2', 'CAV1', 'CRLF1', 'ITGA2',
#                                                    'LAMA3', 'LAMA5', 'LCN2', 'PLEK2', 'TMPRSS4', 'TNFRSF12A',
#                                                    'VEGFA', 'CBLC', 'FLNB', 'DSG2', 'PPL', 'ANXA3', 'EDN1',
#                                                    'FGFBP1', 'ITGA3', 'ITGB4', 'L1CAM', 'LAMB3', 'MET', 'MISP',
#                                                    'MMP10', 'S100A2', 'SCEL', 'TRIM29', 'UCHL1', 'KDR', 'KRT14',
#                                                    'CDH3', 'EVPL', 'TMC5', 'ADGRF1', 'B3GNT3', 'CARD10', 'CDCP1',
#                                                    'CX3CL1', 'CXCL14', 'FAM83A', 'ITGB8', 'MUC16', 'PHLDA2',
#                                                    'PLOD3', 'S100A10', 'SAA1', 'SERINC2', 'SLC2A1', 'UPP1',
#                                                    'HHLA2', 'KISS1', 'SPINK1', 'C19orf33', 'COBL', 'CYP27B1',
#                                                    'DFNA5', 'MT2A', 'PPP1R14C', 'ABHD11-AS1', 'AGRN', 'BCYRN1',
#                                                    'ERO1A', 'FBXO2', 'METTL7B', 'PITX1', 'PPP1R14B', 'SFN',
#                                                    'SLC6A14', 'SNCG', 'TNNT1', 'UBE2C', 'VSIG1', 'VSTM2L',
#                                                    'WDR66', 'WNT7B']

In [None]:
# Map raw signature keys (EMTome file stems plus the keys added above) to display names.
# NOTE(review): 'Tuan_et_a_2014' is kept as spelled — presumably it must match the file name; verify.
sig_name_mapping = {
    'Foroutan_et_al_2017' : 'EMT signature Foroutan et al. 2017',
    'Groeger_et_al_2012' : 'EMT signature Groeger et al. 2012',
    'Hollern_et_al_2018' : 'EMT signature Hollern et al. 2018',
    'Mak_et_al_2016' : 'EMT signature Mak et al. 2016',
    'Tuan_et_a_2014' : 'EMT signature Tuan et al. 2014',
    'hallmark_emt' : 'Hallmark EMT signature',
    'pEMT_gm' : 'pEMT gene module'}

In [None]:
# Rename signature keys to their display names. Keys already renamed are skipped,
# so the cell can be re-run (a second pop(old_key) used to raise KeyError); a key
# that is missing under both names still raises.
for old_key, new_key in sig_name_mapping.items():
    if old_key in sig_list:
        sig_list[new_key] = sig_list.pop(old_key)
    elif new_key not in sig_list:
        raise KeyError(f'Signature {old_key!r} not found in sig_list')

In [None]:
# Display the signature names that will be scored.
sig_list.keys()

## Score each dataset with every signature

In [None]:
# Per-dataset gene expression statistics for the scorer (estim_var=False: no variance estimate requested).
df_means = {ds: get_mean_and_variance_gene_expression(adatas[ds], estim_var=False) for ds in datasets}

In [None]:
for ds in datasets:
    print(f'Scoring {ds.upper()}')
    target = adatas[ds]
    mean_var = df_means[ds]
    for score_name, gene_list in sig_list.items():
        print(f'> signature {score_name}')
        # Produces obs[score_name], which later cells read back.
        score_signature(
            method="adjusted_neighborhood_scoring",
            adata=target,
            gene_list=gene_list,
            ctrl_size=100,
            df_mean_var=mean_var,
            score_name=score_name,
        )
    print()

In [None]:
if not save:
    print('Not storing dataset scores.')
else:
    curr_storing_path = get_path(base_emt_exp_path, add_folder=['results', 'sig_scores'])
    # Metadata columns followed by one column per signature score.
    score_columns = ['sample_id', 'celltype', 'celltype_emt'] + list(sig_list.keys())
    for ds in datasets:
        out_csv = os.path.join(curr_storing_path, f'sig_scores_{ds}.csv')
        adatas[ds].obs[score_columns].to_csv(out_csv)

## Create violin plots of scores

In [None]:
# Column (dataset) and row (signature) order of the violin-plot grids. Dataset
# names are lower-cased inside create_violin_plot to index `adatas`/`hue_orders`.
col_order = ['ESCC', 'LUAD_XING', 'CRC', 'BREAST']
row_order = ['Hallmark EMT signature',
             'pEMT gene module',
             'ESCC and LUAD EMT signature', 
             'EMT signature Foroutan et al. 2017',
             'EMT signature Groeger et al. 2012',
             'EMT signature Hollern et al. 2018',
             'EMT signature Mak et al. 2016',
             'EMT signature Tuan et al. 2014',
             ]
# row_order = list(sig_list.keys())

In [None]:
# Output folder for the violin-plot figures (created if missing, even when save is False).
curr_storing_path = get_path(base_emt_exp_path, add_folder=['results', 'violinplots'])

In [None]:
# Violin plots of every signature in row_order across all datasets.
fig = create_violin_plot(adatas, hue_orders, row_order, col_order)
if save:
    fig.savefig(os.path.join(curr_storing_path, 'all_sigs.svg'))
    fig.savefig(os.path.join(curr_storing_path, 'all_sigs.png'), dpi=600)
# pyplot.show() does not take a figure argument; show the open figure(s) without one.
plt.show()

In [None]:
# Violin plots restricted to the Hallmark, pEMT and ESCC/LUAD signatures.
fig = create_violin_plot(adatas, hue_orders, row_order=['Hallmark EMT signature', 'pEMT gene module', 'ESCC and LUAD EMT signature'], col_order=col_order)
if save:
    fig.savefig(os.path.join(curr_storing_path, 'hallmark_pemt_escc_luad.svg'))
    fig.savefig(os.path.join(curr_storing_path, 'hallmark_pemt_escc_luad.png'), dpi=600)
# pyplot.show() does not take a figure argument; show the open figure(s) without one.
plt.show()

## Evaluate signature scores (AUCPR)

In [None]:
GTs = {}
CAFS_MAL_EMT = {}
MAL_MAL_EMT = {}

for ds in datasets:
    print(ds.upper())
    labels = adatas[ds].obs.celltype_emt.astype(str)

    # Binary ground truth: keep the cancer-EMT label, everything else becomes 'Rest'.
    GTs[ds] = labels.where(labels == cancer_emt_label, 'Rest')

    # Barcodes for the restricted comparisons: cancer EMT vs. CAFs, cancer EMT vs. non-EMT cancer.
    CAFS_MAL_EMT[ds] = labels.index[labels.isin([cancer_emt_label, caf_label])].tolist()
    MAL_MAL_EMT[ds] = labels.index[labels.isin([cancer_emt_label, cancer_no_emt_label])].tolist()

    print(len(CAFS_MAL_EMT[ds]), len(MAL_MAL_EMT[ds]))

In [None]:
plt.rcParams.update({'pdf.fonttype':42, 'font.family':'sans-serif', 'font.sans-serif':'Arial', 'font.size':14})
curr_storing_path = get_path(base_emt_exp_path, add_folder=['results', 'distribution_plots'])
pal = sns.color_palette('tab10') + sns.color_palette('Set3')
performance_aucprs = []


def _aucpr(y_true, y_score):
    """Area under the precision-recall curve with `cancer_emt_label` as the positive class."""
    precision, recall, _ = precision_recall_curve(y_true, y_score, pos_label=cancer_emt_label)
    return auc(recall, precision)


for ds in datasets:
    gt = GTs[ds]
    adata = adatas[ds]
    order_hue = hue_orders[ds]
    caf_and_cancer_emt = CAFS_MAL_EMT[ds]
    cancer_and_cancer_emt = MAL_MAL_EMT[ds]
    # The grouping does not depend on the signature: build it once per dataset.
    # observed=False states the categorical-grouping default explicitly.
    grouped = adata.obs.groupby('celltype_emt', observed=False)

    for score_name in sig_list.keys():
        scores = adata.obs[score_name]

        # Cancer-EMT cells vs. all other cells, vs. CAFs only, vs. non-EMT cancer cells only.
        lr_auc = _aucpr(gt, scores)
        lr_auc_caf_and_emt = _aucpr(gt.loc[caf_and_cancer_emt], scores.loc[caf_and_cancer_emt])
        lr_auc_cancer_and_emt = _aucpr(gt.loc[cancer_and_cancer_emt], scores.loc[cancer_and_cancer_emt])

        performance_aucprs.append(
          {'dataset':ds.upper(),
           'sig_name':score_name,
           aucpr_lbl_emt_rest:lr_auc,
           aucpr_lbl_emt_cafs:lr_auc_caf_and_emt,
           aucpr_lbl_emt_mal:lr_auc_cancer_and_emt,
          }
        )

        # Overlaid density histograms of the score, one per celltype_emt label.
        fig, ax = plt.subplots(figsize=(12,10))
        for i, curr_type in enumerate(order_hue):
            group = grouped.get_group(curr_type)
            group[score_name].hist(ax=ax, bins=100, density=True, alpha=0.5, label=curr_type, color=pal[i])

        ax.set_title(f'{score_name} on {ds.upper()}'+\
                     f'\nAUCPR {cancer_emt_label} vs. Rest '+ str(np.round(lr_auc, decimals=3))+\
                     f'\nAUCPR {cancer_emt_label} vs. {caf_label} '+str(np.round(lr_auc_caf_and_emt, decimals=3))+\
                     f'\nAUCPR {cancer_emt_label} vs. {cancer_no_emt_label} '+str(np.round(lr_auc_cancer_and_emt, decimals=3)), fontsize=18)
        ax.legend(fontsize=16)
        # ax.set_ylim([0,20])
        ax.set_xlabel('Score',fontsize=16)
        ax.set_ylabel('Density',fontsize=16)
        ax.tick_params(axis='both', labelsize=14)
        fig.tight_layout()
        if save:
            fig.savefig(os.path.join(curr_storing_path, f'{ds}_{score_name}.svg'), format='svg')
            # Saved figures are closed instead of being displayed inline.
            plt.close(fig)

In [None]:
# One row per (dataset, signature) holding the three AUCPR values.
performance_aucprs = pd.DataFrame(performance_aucprs)
if save:
    curr_storing_path = get_path(base_emt_exp_path, add_folder=['results', 'performance'])
    performance_csv = os.path.join(curr_storing_path, 'performances_scored_sigs.csv')
    performance_aucprs.to_csv(performance_csv)

In [None]:
# Long format: one row per (dataset, signature, AUCPR type).
aucpr_columns = [aucpr_lbl_emt_rest, aucpr_lbl_emt_cafs, aucpr_lbl_emt_mal]
performance_aucprs_melted = pd.melt(
    performance_aucprs,
    id_vars=['dataset', 'sig_name'],
    value_vars=aucpr_columns,
    var_name='AUCPR_type',
    value_name='aucpr',
)

In [None]:
# Wide table: rows = signatures, columns = (AUCPR type, dataset).
# pivot() states the one-value-per-cell assumption and raises on duplicates,
# unlike crosstab with an identity aggfunc, which only worked by accident.
performance_table = performance_aucprs_melted.pivot(index='sig_name',
                                                    columns=['AUCPR_type', 'dataset'],
                                                    values='aucpr')


In [None]:
# Rows: the ESCC and LUAD signature (row_order[2]) first, then the other rows in row_order's order.
curr_row_order = [row_order[2], *row_order[:2], *row_order[3:]]
# Columns: grouped by AUCPR type (sorted); datasets inside each group follow col_order.
curr_col_order = sorted(performance_table.columns.tolist(), key=lambda column: (column[0], col_order.index(column[1])))

In [None]:
# Apply the row/column order; .loc raises KeyError if a signature or column is missing.
performance_table = performance_table.loc[curr_row_order, curr_col_order]

In [None]:
# Highlight the best (maximum AUCPR) signature in every column.
performance_table.style.highlight_max(color = 'lightgreen', axis = 0)

In [None]:
if save:
    curr_storing_path = get_path(base_emt_exp_path, add_folder=['results', 'performance'])
    table_csv = os.path.join(curr_storing_path, 'performances_scored_sigs_table.csv')
    performance_table.to_csv(table_csv)

In [None]:
# Deliberate stop: "Run All" halts here, so the signature-overlap cells below
# only run when executed by hand.
raise ValueError('Deliberate stop before the "Overlap signatures" section; run the cells below manually.')

## Overlap signatures 

In [None]:
from  matplotlib_venn import venn3

In [None]:
# Re-apply the figure style (pdf.fonttype 42 = TrueType, Arial 14 pt) for the Venn diagram.
plt.rcParams.update({'pdf.fonttype':42, 'font.family':'sans-serif', 'font.sans-serif':'Arial', 'font.size':14})

In [None]:
# Gene sets compared in the overlap section: A = ESCC/LUAD, B = Hallmark EMT, C = pEMT module.
A = set(sig_list['ESCC and LUAD EMT signature'])
B = set(sig_list['Hallmark EMT signature'])
C = set(sig_list['pEMT gene module'])

In [None]:
def get_string_from_list(gene_list, n_words_width=1):
    """Format genes as lines of ``n_words_width`` comma-separated names.

    Example: ``get_string_from_list(['A', 'B', 'C'], 2)`` returns ``'A, B\\nC'``.

    Args:
        gene_list: Sequence of gene names.
        n_words_width: Number of genes per line.

    Returns:
        str: Lines joined by newlines with no trailing newline; ``''`` for an empty list.
    """
    lines = [', '.join(gene_list[i:i + n_words_width])
             for i in range(0, len(gene_list), n_words_width)]
    # The old `new_string[0:-2]` removed the trailing newline AND the last
    # character of the final gene name; joining avoids the trailing newline.
    return '\n'.join(lines)

In [None]:
# Sorted genes of A shared only with B (D), only with C (E), and with both (F).
D = sorted((A & B) - C)
E = sorted((A & C) - B)
F = sorted(A & B & C)

In [None]:
# Three-set Venn diagram of the ESCC/LUAD, Hallmark EMT and pEMT gene sets
# (the same sets as A, B and C).
g = venn3(
    subsets=(
        set(sig_list['ESCC and LUAD EMT signature']), 
        set(sig_list['Hallmark EMT signature']), 
        set(sig_list['pEMT gene module']),
    ),
    set_labels=(
        'ESCC and LUAD EMT signature', 
         'Hallmark EMT signature', 
         'pEMT gene module',
    )
)

# Style of the text boxes that list the shared genes
text_box = {
    'facecolor': 'ivory',        # Box background color
    'edgecolor': 'grey',        # Box border color
    'boxstyle': 'square',         # Box style
    'pad': 0.5,                  # Padding around the text
}

# Hand-tuned box positions (data coordinates) and their gene lists:
# box 1 = D (A and B only), box 2 = E (A and C only), box 3 = F (all three); `width` genes per line.
width=2
(x1, y1) = (-0.3, 0.75)
text1 = get_string_from_list(D, n_words_width=width)
(x2, y2) = (-1.6, -0.5)
text2 = get_string_from_list(E, n_words_width=width)
(x3, y3) = (0.8, -0.3)
text3 = get_string_from_list(F, n_words_width=width)

plt.text(x1, y1, text1, bbox=text_box, fontstyle='italic',linespacing=1.5);
plt.text(x2, y2, text2, bbox=text_box, fontstyle='italic',linespacing=1.5);
plt.text(x3, y3, text3, bbox=text_box, fontstyle='italic',linespacing=1.5);

# Arrow targets: label positions of the regions '110' (A&B), '101' (A&C) and '111' (A&B&C),
# nudged slightly. NOTE(review): confirm these regions are non-empty; get_label_by_id may
# return no label for an empty region, which would fail here.
t1 = g.get_label_by_id('110').get_position()
t1 = (t1[0], t1[1]+0.07)

t2 = g.get_label_by_id('101').get_position()
t2 = (t2[0]-0.05, t2[1])

t3 = g.get_label_by_id('111').get_position()
t3 = (t3[0]+0.07, t3[1])

# Add arrows from text boxes to the intersection parts of the Venn diagram
arrow_style = dict(arrowstyle='-|>', color='black')
plt.annotate('', xy=t1, xytext=(x1+0.25, y1-0.05), arrowprops=arrow_style);
plt.annotate('', xy=t2, xytext=(x2+0.75, y2+0.325), arrowprops=arrow_style);
plt.annotate('', xy=t3, xytext=(x3-0.03, y3+0.21), arrowprops=arrow_style);

if save:
    curr_storing_path = get_path(base_emt_exp_path, add_folder=['results', 'sig_overlap'])
    plt.savefig(os.path.join(curr_storing_path, f'overlap_hallmark_pemt_escc_luad.svg'), format='svg')

In [None]:
# One set per signature.
sets = [set(values) for values in sig_list.values()]

# Sorted union of all signature genes. set().union(...) also copes with an empty
# `sets` list, where set.union(*sets) raised TypeError.
all_genes = sorted(set().union(*sets))

In [None]:
# Gene x signature membership table over the union of all signature genes.
# NOTE(review): the index is named after the ESCC/LUAD signature although it holds
# all genes here; it is renamed to 'Gene code' later in the notebook.
df = pd.DataFrame(index=all_genes)
df.index.name = 'ESCC and LUAD EMT signature'

In [None]:
# One boolean column per signature: True where the gene belongs to that signature.
for key, values in sig_list.items():
    df[key] = df.index.isin(values)

In [None]:
# Keep only the pan-cancer signatures as columns (the ESCC/LUAD signature column is dropped).
df = df[['Hallmark EMT signature', 
         'pEMT gene module', 
         'EMT signature Foroutan et al. 2017',
         'EMT signature Groeger et al. 2012',
         'EMT signature Hollern et al. 2018', 
         'EMT signature Mak et al. 2016',
         'EMT signature Tuan et al. 2014', ]]

In [None]:
# Spot check: does the symbol 'ET-1' occur in the gene index?
'ET-1' in df.index

In [None]:
# Restrict rows to the ESCC and LUAD EMT signature genes, in signature order.
df = df.loc[sig_list['ESCC and LUAD EMT signature']].copy()

In [None]:
# Number of pan-cancer signatures containing each gene. The count column is
# excluded from the sum so that re-running this cell does not count itself.
count_col = 'Contained in nr. pancancer sigs'
df[count_col] = df.drop(columns=count_col, errors='ignore').sum(axis=1)

In [None]:
# How many ESCC/LUAD genes are contained in 0, 1, 2, ... pan-cancer signatures.
df['Contained in nr. pancancer sigs'].value_counts()

In [None]:
# Display the membership table.
df

In [None]:
if save:
    curr_storing_path = get_path(base_emt_exp_path, add_folder=['results', 'sig_overlap'])
    overlap_stem = os.path.join(curr_storing_path, 'overlap_pancancer_sigs')
    df.to_csv(overlap_stem + '.csv')
    df.to_excel(overlap_stem + '.xlsx')

In [None]:
# Folder holding the overlap outputs and the annotation workbook read below.
curr_storing_path = get_path(base_emt_exp_path, add_folder=['results', 'sig_overlap'])

In [None]:
# Per-gene annotation table for the new ESCC- and LUAD-specific EMT signature.
tmp = pd.read_excel(os.path.join(curr_storing_path, r'NEW ESCC- and LUAD-specific cancer EMT signature.xlsx'))

In [None]:
# Annotation columns to keep. Several names end with a trailing space; they must
# match the workbook headers exactly.
cols_of_interest = ['Gene code', 'Biotype', 'Transcription factor',
       'EMT related genes according to EMTome ',
       'EMT related genes according to EMTome: gene query link ',
       'EMT related genes according to dbEMT ',
       'EMT related genes according to dbEMT : gene query link ',
       'Literature relationship gene with EMT', 'Notes']

In [None]:
# Keep only the annotation columns of interest.
tmp = tmp[cols_of_interest].copy()

In [None]:
# Name the gene index like the annotation table's key column.
df.index.name = 'Gene code'

In [None]:
# Move the gene symbols into a 'Gene code' column for merging.
# Not idempotent: re-running adds an extra 'index' column.
df = df.reset_index()

In [None]:
# Outer merge with the annotations; the next cell keeps only genes present in df,
# so the net effect is a left merge.
tmp2 = pd.merge(df, tmp, on='Gene code', how='outer')

In [None]:
# Drop annotation-only rows (genes not in the ESCC/LUAD membership table).
df = tmp2[tmp2['Gene code'].isin(df['Gene code'].values)].copy()

In [None]:
# Inspect the merged column names.
df.columns

In [None]:
# Final column order of the extended overlap table (annotation names keep their trailing spaces).
df = df[['Gene code',
    'Biotype',
    'Transcription factor',
    'Contained in nr. pancancer sigs',
    'Hallmark EMT signature',
    'pEMT gene module',
    'EMT signature Foroutan et al. 2017',
    'EMT signature Groeger et al. 2012',
    'EMT signature Hollern et al. 2018',
    'EMT signature Mak et al. 2016',
    'EMT signature Tuan et al. 2014',
    'EMT related genes according to EMTome ',
    'EMT related genes according to EMTome: gene query link ',
    'EMT related genes according to dbEMT ',
    'EMT related genes according to dbEMT : gene query link ',
    'Literature relationship gene with EMT', 'Notes']].copy()

In [None]:
if save:
    curr_storing_path = get_path(base_emt_exp_path, add_folder=['results', 'sig_overlap'])
    extended_stem = os.path.join(curr_storing_path, 'overlap_pancancer_sigs_extended')
    # Same table written as CSV first, then as Excel.
    for writer, extension in ((df.to_csv, '.csv'), (df.to_excel, '.xlsx')):
        writer(extended_stem + extension)