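# Pseudo-bulk Augmentation of GSE154600 scRNA-seq Data

This notebook loads the processed GSE154600 single-cell AnnData, writes a dense cell-by-gene matrix with cell-type labels for CIBERSORTx, and generates pseudo-bulk profiles for each sample/stimulation combination. TODO: extend this description.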

In [None]:
import pathlib
import yaml
import subprocess
import pickle

import pandas as pd
import numpy as np
import scanpy as sc
import anndata as ad

from buddi.preprocessing import sc_preprocess


## Notebook Parameters

In [None]:
# Column names this notebook expects to find in the AnnData object.
# TODO consider whether to move these into config.yml
CELL_TYPE_COL: str = 'cellType'  # adata.obs column with the cell-type label of each cell
SAMPLE_ID_COL: str = 'sample_id'  # adata.obs column used to group cells by sample
STIM_COL: str = 'stim'  # adata.obs column used to group cells by stimulation condition

GENE_ID_COL: str = 'gene_ids'  # adata.var column that this notebook fills from the var index

## Load config

In [None]:
# Get the root directory of the analysis repository.
# check=True makes this fail loudly (CalledProcessError) when the notebook is not
# running inside a git work tree. Without it the empty stdout would silently turn
# into pathlib.Path(''), i.e. the current working directory.
REPO_ROOT = pathlib.Path(
    subprocess.run(
        ["git", "rev-parse", "--show-toplevel"],
        capture_output=True,
        text=True,
        check=True,
    ).stdout.strip()
)

CONFIG_FILE = REPO_ROOT / 'config.yml'
assert CONFIG_FILE.exists(), f"Config file not found at {CONFIG_FILE}"

# yaml.safe_load only constructs plain Python objects (dicts, lists, scalars).
with open(CONFIG_FILE, 'r', encoding='utf-8') as file:
    config_dict = yaml.safe_load(file)

## Retrieve Paths to the Processed Single-Cell RNA-seq Data and Relevant Metadata

In [None]:
# TODO consider whether to move this into config.yml as well
STUDY_GEO_ID = 'GSE154600'
SC_DATA_PATH = pathlib.Path(config_dict['data_path']['sc_data_path'])

# Expected layout under the configured sc_data_path:
#   <sc_data_path>/<STUDY_GEO_ID>_processed/<STUDY_GEO_ID>_processed.h5ad
#   <sc_data_path>/<STUDY_GEO_ID>_metadata/
SC_ADATA_PATH = SC_DATA_PATH.joinpath(f'{STUDY_GEO_ID}_processed')
assert SC_ADATA_PATH.exists(), f"Processed Single-cell Data path {SC_ADATA_PATH} does not exist"

SC_ADATA_FILE = SC_ADATA_PATH.joinpath(f'{STUDY_GEO_ID}_processed.h5ad')
assert SC_ADATA_FILE.exists(), f"Processed Single-cell Data file {SC_ADATA_FILE} does not exist"

SC_METADATA_PATH = SC_DATA_PATH.joinpath(f'{STUDY_GEO_ID}_metadata')
assert SC_METADATA_PATH.exists(), f"Single-cell Metadata path {SC_METADATA_PATH} does not exist"

## Define the Path for Writing Preprocessing Outputs

In [None]:
# Outputs go under <repo root>/processed_data, which must already exist;
# its sc_augmented subdirectory is created on demand.
PREPROCESSING_OUTPUT_PATH = REPO_ROOT.joinpath('processed_data')
assert PREPROCESSING_OUTPUT_PATH.exists(), f"Preprocessing output path {PREPROCESSING_OUTPUT_PATH} does not exist"

SC_AUGMENTED_DATA_PATH = PREPROCESSING_OUTPUT_PATH.joinpath('sc_augmented')
SC_AUGMENTED_DATA_PATH.mkdir(parents=True, exist_ok=True)

## Preprocess the scRNA-seq AnnData Before Generating Pseudo-bulks

### Load and Preprocess Anndata

In [None]:
# Read the processed AnnData, make the gene names (var index) unique, and copy
# those unique names into the GENE_ID_COL column of adata.var.
# NOTE(review): this overwrites any GENE_ID_COL column already present in
# adata.var (e.g. Ensembl IDs, if the file carries them) — confirm intended.
adata = sc.read_h5ad(SC_ADATA_FILE)
adata.var_names_make_unique()
adata.var[GENE_ID_COL] = adata.var_names.tolist()

In [None]:
# Fail early if any of the configured metadata columns is missing from adata.obs.
for _required_col in (CELL_TYPE_COL, SAMPLE_ID_COL, STIM_COL):
    assert _required_col in adata.obs.columns, f"Column {_required_col} not found in adata.obs"
del _required_col

Print some basic information about the AnnData object: its shape and the first rows of `var` and `obs`.

In [None]:
# Shape (cells, genes), then a preview of the gene and cell annotation tables,
# each on its own line.
print(adata.shape, adata.var.head(), adata.obs.head(), sep="\n")

## Some Stats

### Stimulation

In [None]:
# Number of cells for every (sample, stim) pair, shown with samples as rows
# and stimulation levels as columns.
tab = (
    adata.obs
    .groupby(by=[SAMPLE_ID_COL, STIM_COL])
    .size()
)
tab.unstack(level=-1)

### Cell type

In [None]:
# Number of cells per cell-type label, most frequent first.
adata.obs[CELL_TYPE_COL].value_counts(sort=True, ascending=False)

## Write the Dense Expression Matrix and Cell-Type Column for Use in CIBERSORTx


In [None]:
# Export a dense table with one row per cell and one column per gene
# (adata.var[GENE_ID_COL]), preceded by a cell-type label column, for CIBERSORTx.
sc_profile_file = SC_AUGMENTED_DATA_PATH / f'{STUDY_GEO_ID}_sig.pkl'

# adata.X may be a scipy sparse matrix or already a dense array. A numpy
# ndarray has no .todense(), so only densify when .toarray() is available.
if hasattr(adata.X, "toarray"):
    dense_matrix = adata.X.toarray()
else:
    dense_matrix = np.asarray(adata.X)
dense_df = pd.DataFrame(dense_matrix, columns=adata.var[GENE_ID_COL])
dense_df.insert(loc=0, column=CELL_TYPE_COL, value=adata.obs[CELL_TYPE_COL].to_list())

# Context manager so the file is flushed and closed deterministically.
with open(sc_profile_file, "wb") as fh:
    pickle.dump(dense_df, fh)

# free up memory
del dense_matrix
del dense_df

## Make Pseudobulks

In [None]:
# Save the gene identifiers (the adata.var[GENE_ID_COL] Series, in var order).
# gene_ids is also used below to size the per-cell-type noise vectors.
gene_out_file = SC_AUGMENTED_DATA_PATH / f'{STUDY_GEO_ID}_genes.pkl'
gene_ids = adata.var[GENE_ID_COL]
with open(gene_out_file, "wb") as fh:
    pickle.dump(gene_ids, fh)

In [None]:
ADD_PER_CELL_TYPE_NOISE = True
# Sigma of the lognormal per-gene multiplier drawn once per cell type.
# NOTE(review): with sigma = 0 every draw is exactly exp(0) = 1.0, so the
# "noise" is a no-op even though ADD_PER_CELL_TYPE_NOISE is True. Kept at 0 to
# preserve current outputs; set a positive value if real noise is intended.
PER_CELL_TYPE_NOISE_SIGMA = 0
N_CELLS_PER_PSEUDO_BULK = 5_000
N_PSEUDO_BULKS_PER_CONDITION = 1_000
# Passed as num_samp to get_single_celltype_prop_matrix ("100 per cell type").
N_SINGLE_CELLTYPE_PBS_PER_CELL_TYPE = 100

# Unique values for experiment/perturbation/cell type
samples = adata.obs[SAMPLE_ID_COL].unique()
stims = adata.obs[STIM_COL].unique()
cell_types = adata.obs[CELL_TYPE_COL].unique()

n_samples = len(samples)
n_genes = len(gene_ids)
n_cell_types = len(cell_types)

# Define cell-type level noise for the generated pseudo-bulk profiles
if ADD_PER_CELL_TYPE_NOISE:
    # this produces a list of numpy arrays, each of length n_genes,
    # one per cell type, used as that cell type's expression noise
    per_cell_type_noise = [
        np.random.lognormal(0, PER_CELL_TYPE_NOISE_SIGMA, n_genes)
        for _ in range(n_cell_types)
    ]
else:
    per_cell_type_noise = None

# Generate pseudo-bulk profiles grouping by sample_id and stim
for _sample in samples:
    for _stim in stims:

        print(f"Generating pseudo-bulk profiles for sample {_sample} and stim {_stim}")

        ## Subset adata to the current sample and stim
        subset_idx = np.logical_and(
            adata.obs[SAMPLE_ID_COL] == _sample,
            adata.obs[STIM_COL] == _stim,
        )
        # subset_idx is a boolean mask over *all* cells, so its length is never 0;
        # check whether any cell matches instead. A sample may have no cells for
        # some stim level, and an empty subset cannot be pseudo-bulked.
        if not subset_idx.any():
            print(f"No cells for sample {_sample} and stim {_stim}; skipping")
            continue
        subset_adata = adata[subset_idx, :]

        print("Generating random prop pseudo-bulk profiles ...")
        random_prop_pb_outputs = sc_preprocess.make_prop_and_sum(
            in_adata=subset_adata,
            # the number of pseudo-bulk profiles to generate
            num_samples=N_PSEUDO_BULKS_PER_CONDITION,
            # the number of cells included/sampled when generating each pseudo-bulk profile
            num_cells=N_CELLS_PER_PSEUDO_BULK,
            # pseudo-bulk profiles will be generated with random proportions
            use_true_prop=False,
            # apply the per cell type noise
            cell_noise=per_cell_type_noise,
            # no sample noise
            useSampleNoise=False,
        )

        count_df, pb_df, test_count_df, test_pb_df = random_prop_pb_outputs
        # divide the count matrix by the sum of each row to get proportions
        prop_df = count_df.div(count_df.sum(axis=1), axis=0)
        test_prop_df = test_count_df.div(test_count_df.sum(axis=1), axis=0)

        del count_df
        del test_count_df

        n_random_prop_pbs = len(pb_df)

        print("Generating single cell type dominant pseudo-bulk profiles ...")

        # Generate pseudo-bulk profiles where a single cell type dominates;
        # presumably num_samp rows per entry of cell_order — confirm in sc_preprocess.
        # NOTE(review): cell_types comes from the full adata, so it can include
        # cell types absent from this sample/stim subset — confirm the helpers cope.
        ct_prop_df = sc_preprocess.get_single_celltype_prop_matrix(
            num_samp=N_SINGLE_CELLTYPE_PBS_PER_CELL_TYPE,
            cell_order=cell_types
        )

        # Use proportion matrix to generate pseudo-bulk profiles
        sc_prop_df, sc_pb_df, _ = sc_preprocess.use_prop_make_sum(
            in_adata=subset_adata,
            num_cells=N_CELLS_PER_PSEUDO_BULK,
            # use the generated single cell type dominant proportion matrix
            props_vec=ct_prop_df,
            # apply the same per cell type noise used for random prop pseudo-bulk profiles
            cell_noise=per_cell_type_noise,
            # no sample noise
            sample_noise=None,
            useSampleNoise=False
        )

        n_single_celltype_pbs = len(sc_pb_df)

        print('Concatenating the two types of pseudo-bulk profiles ...')
        prop_df = pd.concat([prop_df, sc_prop_df])
        pb_df = pd.concat([pb_df, sc_pb_df])

        n_total_pbs = n_random_prop_pbs + n_single_celltype_pbs

        # NOTE(review): metadata_df gets a fresh RangeIndex while prop_df/pb_df keep
        # their concatenated indexes; rows correspond by position only.
        metadata_df = pd.DataFrame(
            data = {"sample_id":[_sample]*n_total_pbs,
                    "stim":[_stim]*n_total_pbs,
                    "cell_prop_type":['random']*n_random_prop_pbs + ['single_celltype']*n_single_celltype_pbs,
                    "samp_type":['sc_ref']*n_total_pbs}
                    )

        print("Writing the pseudo-bulk profiles ...")
        pseudobulk_file = SC_AUGMENTED_DATA_PATH / f'{STUDY_GEO_ID}_{_sample}_{_stim}_pseudo_splits.pkl'
        prop_file = SC_AUGMENTED_DATA_PATH / f'{STUDY_GEO_ID}_{_sample}_{_stim}_prop_splits.pkl'
        meta_file = SC_AUGMENTED_DATA_PATH / f'{STUDY_GEO_ID}_{_sample}_{_stim}_meta_splits.pkl'

        # Context managers so each file is flushed and closed before the next iteration.
        with open(prop_file, "wb") as fh:
            pickle.dump(prop_df, fh)
        with open(pseudobulk_file, "wb") as fh:
            pickle.dump(pb_df, fh)
        with open(meta_file, "wb") as fh:
            pickle.dump(metadata_df, fh)
