## Get cancer EMT cells for the breast dataset
Cancer EMT cells are the cancer epithelial cells assigned to gene module 3 (GM3) in the publication that introduced the dataset [[1]](https://www.nature.com/articles/s41588-021-00911-1).

In [None]:
from IPython.display import display, HTML
# Widen the Jupyter notebook container to 90% of the browser width.
_notebook_css = "<style>.container { width:90% !important; }</style>"
display(HTML(_notebook_css))

In [None]:
import os 
import sys
from glob import glob

sys.path.append('../../..')

import pandas as pd
import numpy as np
import scanpy as sc
import matplotlib.pyplot as plt
from matplotlib_venn import venn2
import seaborn as sns

sys.path.append('../../..')
from data.load_data import load_datasets
from data.constants import BASE_PATH_DATA, BASE_PATH_EXPERIMENTS

from signaturescoring import score_signature
from signaturescoring.utils.utils import get_mean_and_variance_gene_expression, check_signature_genes

In [None]:
# Notebook configuration.
sc.settings.verbosity = 2  # scanpy verbosity 2 == info-level messages

# When True, the selected barcode lists are written to disk at the end of the notebook.
save = True

# `dataset_long` is the name passed to load_datasets; `dataset` names the output sub-directory.
dataset_long, dataset = 'breast_large', 'breast'

In [None]:
# Output directory for this case study:
# <BASE_PATH_EXPERIMENTS>/EMT_signature_scoring_case_study/<dataset>
storing_path = os.path.join(BASE_PATH_EXPERIMENTS, 'EMT_signature_scoring_case_study', dataset)
if not os.path.isdir(storing_path):
    sc.logging.info('Creating new directory to store the results.')
    # exist_ok avoids a crash if the directory appears between the check and this call;
    # a regular *file* at this path still raises, instead of being silently ignored.
    os.makedirs(storing_path, exist_ok=True)

Load preprocessed data

In [None]:
adata = load_datasets(dataset_long, preprocessed=True, norm_method='mean')

# Make sure adata.uns['log1p'] exists and carries 'base' = None. scanpy's log1p
# records its logarithm base under this key (None denotes the natural log).
if 'log1p' not in adata.uns_keys():
    adata.uns['log1p'] = {'base': None}
else:
    adata.uns['log1p']['base'] = None

In [None]:
# Number of cells per malignant_key label, ordered by label.
adata.obs['malignant_key'].value_counts().sort_index()

In [None]:
#sc.tl.pca(adata)
#sc.pp.neighbors(adata)
#sc.tl.umap(adata)

In [None]:
#plt.rcParams.update({'pdf.fonttype':42, 'font.family':'sans-serif', 'font.sans-serif':'Arial', 'font.size':14})

#sc.pl.umap(adata, color=['sample_id','malignant_key', 'celltype'], ncols=1)

### Select cancer EMT cells and store barcodes

In [None]:
def create_pivot_table(data, index, column) -> pd.DataFrame:
    """Count cells for every combination of two ``.obs`` columns.

    Parameters
    ----------
    data
        Object exposing an ``obs`` DataFrame (e.g. an AnnData object).
    index
        Name of the ``obs`` column whose values become the table rows.
    column
        Name of the ``obs`` column whose values become the table columns.

    Returns
    -------
    pandas.DataFrame
        Number of cells per (``index``, ``column``) pair. Combinations that
        never occur are reported as 0. Cells with a missing value in either
        column are not counted (``value_counts`` drops NaN by default).
    """
    counts = data.obs[[index, column]].value_counts().sort_index().reset_index()
    counts.columns = [index, column, 'count']
    table = pd.pivot(counts, columns=column, index=index, values='count')
    # value_counts omits combinations that never occur, and pivot turns those
    # holes into NaN (forcing a float dtype). Report them as integer zeros.
    return table.fillna(0).astype(int)

In [None]:
# Cells per (gene_module, celltype) pair.
create_pivot_table(data=adata, index='gene_module', column='celltype')

In [None]:
# Number of cells per cell type, ordered by cell-type name.
adata.obs['celltype'].value_counts().sort_index()

In [None]:
## get gene module signatures 

In [None]:
# Gene-module signatures: files matching '*_GM*.csv' in the dataset's annotation
# folder, with the genes in the first CSV column. Each file becomes an entry
# '<module>_scores' -> list of genes, e.g. 'x_GM3.csv' -> 'GM3_scores'.
signature_dir = os.path.join(BASE_PATH_DATA, 'annotations', dataset)
fns = sorted(glob(os.path.join(signature_dir, '*_GM*.csv')))
if not fns:
    # Fail here rather than later with an empty signature dict.
    raise FileNotFoundError(f"No gene-module signature files matching '*_GM*.csv' in {signature_dir}")

sigs = {}
for fn in fns:
    # Drop everything from the first '.', then everything up to the first '_'.
    module_name = os.path.basename(fn).split('.')[0].split('_', 1)[1]
    sigs[f'{module_name}_scores'] = pd.read_csv(fn).iloc[:, 0].tolist()

In [None]:
# Cancer epithelial cells only; .copy() gives an independent object whose .obs can receive scores.
cancer_cells = adata[adata.obs['celltype'].eq('Cancer Epithelial')].copy()

In [None]:
# Score every gene-module signature on the cancer epithelial cells only; each
# score ends up in cancer_cells.obs under its '<module>_scores' name.
for score_name, genes in sigs.items():
    score_signature(
        adata=cancer_cells,
        gene_list=genes,
        method="adjusted_neighborhood_scoring",
        score_name=score_name,
        ctrl_size=100,
    )

In [None]:
# Copy the signature scores back onto the full object. Newly created score
# columns are NaN for every cell outside cancer_cells.
score_columns = list(sigs)
adata.obs.loc[cancer_cells.obs_names, score_columns] = cancer_cells.obs[score_columns]

### Keep only basal samples (filter out all non-basal samples)

In [None]:
# Sample IDs kept for the analysis (basal samples); cells from all other samples are dropped below.
basal_samples = [
    'CID3586', 'CID3963', 'CID4465', 'CID4495',
    'CID44971', 'CID4513', 'CID4515', 'CID4523',
]

In [None]:
# Keep only cells from basal samples (copy so adata is a standalone object, not a view).
in_basal_sample = adata.obs['sample_id'].isin(basal_samples)
adata = adata[in_basal_sample].copy()

### Compare GM3 scores with the scores of the other gene modules

In [None]:
from scipy.stats import pearsonr
import matplotlib.pyplot as plt 

def corrfunc(x, y, ax=None, **kws):
    """Annotate the top-left corner of a plot with the Pearson r of ``x`` and ``y``.

    Parameters
    ----------
    x, y
        Equal-length sequences of numbers. Pairs where either value is NaN or
        infinite are ignored.
    ax
        Axes to annotate; defaults to the current axes (``plt.gca()``).
    **kws
        Accepted and ignored (presumably so the function can be used as a
        seaborn grid ``map`` callback, which passes e.g. ``color``/``label``).

    When fewer than two finite pairs remain, the annotation reads ``r = n/a``.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    # Pearson r is undefined with NaN/inf entries: use only fully finite pairs.
    finite = np.isfinite(x) & np.isfinite(y)
    ax = ax or plt.gca()
    if finite.sum() < 2:
        text = 'r = n/a'
    else:
        r, _ = pearsonr(x[finite], y[finite])
        text = f'r = {r:.2f}'
    ax.annotate(text, xy=(.1, .9), xycoords=ax.transAxes)

In [None]:
## Score ranges of the other gene modules for cancer epithelial cells labelled GM3
is_cancer_gm3 = (adata.obs['celltype'] == 'Cancer Epithelial') & adata.obs['gene_module'].isin(['3'])
cancer_gm3 = adata[is_cancer_gm3].obs

In [None]:
# One scatter plot per gene-module score (y) against the GM3 score (x) for GM3-labelled cancer cells.
score_keys = list(sigs.keys())
# squeeze=False always returns a 2-D array of Axes, so ax[i] also works with a single signature.
f, ax = plt.subplots(nrows=len(score_keys), ncols=1, figsize=(8, len(score_keys) * 5), squeeze=False)
ax = ax[:, 0]
for i, val in enumerate(score_keys):
    pair = cancer_gm3[['GM3_scores', val]]
    # Identity line spanning the joint range of both scores.
    line_plot = np.linspace(pair.min().min(), pair.max().max(), 100)

    sns.scatterplot(cancer_gm3,
                    x='GM3_scores',
                    y=val,
                    hue='gene_module',
                    alpha=0.5,
                    ax=ax[i])
    ax[i].plot(line_plot, line_plot, color='r', ls=':')
    # 0.25 is the cutoff applied to the other modules' scores for barcodes_cancer_emt_2.
    ax[i].axhline(0.25, color='g', ls=':')
    ax[i].set_title(f'GM3_scores vs. {val}')

In [None]:
# Every gene-module score column except GM3's, in sorted order.
rem_gms = sorted(name for name in sigs if name != 'GM3_scores')
rem_gms

In [None]:
# True for GM3-labelled cancer cells whose scores for all other gene modules are <= 0.25.
# Vectorized form of the per-module comparison; unlike pd.concat([...]) it does not
# raise when rem_gms is empty (every cell then passes vacuously).
good_remaining_scores = (cancer_gm3[rem_gms] <= 0.25).all(axis=1)
good_remaining_scores.value_counts()

In [None]:
# Three alternative sets of cancer EMT cell barcodes (all GM3-labelled cancer cells):
#   1) every GM3-labelled cell,
#   2) those whose other gene-module scores are all <= 0.25,
#   3) those whose highest gene-module score is the GM3 score.
is_top_gm3 = cancer_gm3[list(sigs.keys())].idxmax(axis=1).eq('GM3_scores')
barcodes_cancer_emt_1 = cancer_gm3.index.tolist()
barcodes_cancer_emt_2 = cancer_gm3.loc[good_remaining_scores].index.tolist()
barcodes_cancer_emt_3 = cancer_gm3.loc[is_top_gm3].index.tolist()

In [None]:
## Range of GM3 scores for cancer epithelial cells labelled with a gene module other than GM3
is_cancer_non_gm3 = (adata.obs['celltype'] == 'Cancer Epithelial') & ~adata.obs['gene_module'].isin(['3'])
cancer_non_gm3 = adata[is_cancer_non_gm3].obs
cancer_non_gm3['gene_module'].value_counts()

In [None]:
# One scatter plot per gene-module score (y) against the GM3 score (x) for non-GM3 cancer cells.
score_keys = list(sigs.keys())
# GM3 score above which a non-GM3 cancer cell is discarded (barcodes_to_remove).
# Keep in sync with the cutoff used in the next cell; the guide line previously
# showed 0.25, which did not match the cutoff actually applied.
gm3_cutoff = 0.15
# squeeze=False always returns a 2-D array of Axes, so ax[i] also works with a single signature.
f, ax = plt.subplots(nrows=len(score_keys), ncols=1, figsize=(8, len(score_keys) * 5), squeeze=False)
ax = ax[:, 0]
for i, val in enumerate(score_keys):
    pair = cancer_non_gm3[['GM3_scores', val]]
    # Identity line spanning the joint range of both scores.
    line_plot = np.linspace(pair.min().min(), pair.max().max(), 100)

    sns.scatterplot(cancer_non_gm3,
                    x='GM3_scores',
                    y=val,
                    hue='gene_module',
                    alpha=0.5,
                    ax=ax[i])
    ax[i].plot(line_plot, line_plot, color='r', ls=':')
    ax[i].axvline(gm3_cutoff, color='g', ls=':')
    ax[i].set_title(f'GM3_scores vs. {val}')

In [None]:
# Non-GM3 cancer cells whose GM3 score exceeds 0.15 are marked for removal.
gm3_too_high = cancer_non_gm3['GM3_scores'] > 0.15
barcodes_to_remove = cancer_non_gm3.loc[gm3_too_high].index.tolist()
gm3_too_high.value_counts()

In [None]:
# Violin plot of every gene-module score per gene_module label, with a red reference
# line at 0.25. With show=False scanpy returns the axes instead of displaying them;
# iterating assumes one Axes per key (NOTE(review): confirm for a single key).
g = sc.pl.violin(adata, keys=list(sigs), groupby='gene_module', rotation=90, show=False)
for violin_ax in g:
    violin_ax.axhline(0.25, color='r')

In [None]:
# barcodes_cancer_emt = adata[(adata.obs.celltype=='Cancer Epithelial')&(adata.obs.gene_module=='3')].obs_names.tolist()

In [None]:
# Sizes of the three EMT barcode sets and of the removal set.
tuple(map(len, (barcodes_cancer_emt_1, barcodes_cancer_emt_2, barcodes_cancer_emt_3, barcodes_to_remove)))

In [None]:
## Write each barcode list to its own CSV file in storing_path
if save:
    barcode_files = {
        'barcodes_cancer_emt_1.csv': barcodes_cancer_emt_1,
        'barcodes_cancer_emt_2.csv': barcodes_cancer_emt_2,
        'barcodes_cancer_emt_3.csv': barcodes_cancer_emt_3,
        'barcodes_to_remove.csv': barcodes_to_remove,
    }
    for file_name, barcodes in barcode_files.items():
        # Series.to_csv with default arguments also writes the positional index and a header row.
        pd.Series(barcodes).to_csv(os.path.join(storing_path, file_name))