# scGen batch correction
Based on the scGen batch-removal tutorial at https://scgen.readthedocs.io/en/stable/tutorials/scgen_batch_removal.html.  
**Prerequisites**: Perform preprocessing and data conversion in `preprocessing.ipynb` and `object-conversion.ipynb` respectively.

## Installation
```
# Install this environment from file
conda env create -f scgen.yml

# Link this env to jupyter
conda activate scgen; python -m ipykernel install --user --name scgen --display-name "scgen"; conda deactivate
```
See the scGen installation guide (conda prerequisites): https://scgen.readthedocs.io/en/stable/installation.html#conda-prerequisites  

## Results
Batch correction appears moderately successful. The trained model parameters are saved in `out/scgen-batch-model.pt`.  
Expression of Neurod1 (and Grin2b) is compared between gnp, pnc and tumor cells.  
Outputs:
- before/after-correction UMAPs saved to figures/umap_scgen*.png
- Neurod1 and Grin2b expression in the hyperplastic cells of each sample saved to figures/neurod1_gex.svg and figures/Grin2b_gex.svg
- batch-corrected data saved to out/shiraishi_merge.h5ad
- statistical test results saved to out/scgen.log

In [None]:
import scanpy as sc
import anndata as ad
import pandas as pd
import scgen
import scvi
import warnings
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats as stats
import math
from statsmodels.stats.multitest import fdrcorrection

# Send print statements to a logfile.
# The root logger gets a stderr handler (from basicConfig) plus a handler writing to
# out/scgen.log (opened with mode 'w', so truncated each time this cell runs), and
# `print` is rebound to logger.info: every print(...) below goes to both places.
import logging
# force=True removes and closes any handlers already attached to the root logger.
# Without it, basicConfig is a no-op when an imported library has already configured
# the root logger (the INFO level would then never be applied and logger.info calls
# would be dropped), and re-running this cell would stack another FileHandler on the
# root logger, duplicating every log line and leaking the old open file.
logging.basicConfig(level=logging.INFO, format='%(message)s', force=True)
logger = logging.getLogger()
logger.addHandler(logging.FileHandler('out/scgen.log', 'w'))
print = logger.info


In [None]:
# Read data
def read_shiraishi_data():
    '''
    Read the per-sample h5ad files produced by the object-conversion notebook and merge them.

    Reads out/gnp_anndata.h5ad, out/pnc_anndata.h5ad and out/tumor_anndata.h5ad,
    tags every cell with its sample name in obs['sample'], and suffixes cell names
    with '_<sample>' so they stay unique after concatenation. Only genes present in
    all samples are kept (inner join).

    Returns
    -------
    AnnData
        The merged object with obsm removed and the per-sample clustering / SCT
        columns dropped from obs, restricted to cells whose obs['annotation'] is not
        missing. It is returned as an actual AnnData rather than a view, so later
        in-place steps (e.g. sc.pp.pca) do not trigger an implicit copy.
    '''
    # Read and add sample-specific annotations
    samples = []
    for s in ['gnp','pnc','tumor']:
        obj = ad.read_h5ad(f'out/{s}_anndata.h5ad')
        obj.obs['sample'] = s
        # suffix cell names with the sample so they remain unique after the merge
        obj.obs.index = obj.obs.index.map(lambda x:x+'_'+s)
        samples.append(obj)
    # cat; join="inner" keeps only the genes shared by every sample
    cells = ad.concat(samples, join="inner")
    # delete data that are no longer meaningful after the merge
    cells.obs = cells.obs.drop(['mito_regressed_cluster','seurat_clusters','SCT_snn_res.0.8','nCount_SCT','nFeature_SCT'],axis='columns')
    del cells.obsm
    # scGen requires labels; drop unlabelled cells. Boolean indexing returns a view,
    # so materialise it here instead of leaving that to the first in-place operation.
    cells = cells[~cells.obs.annotation.isna()].copy()
    return cells

In [None]:
# Load and merge the three samples; the trailing bare expression is rendered as the
# cell output by Jupyter (a summary of the merged AnnData).
cells = read_shiraishi_data()
cells

In [None]:
# embeddings before batch correction
# PCA -> kNN graph on the first 30 PCs -> UMAP; results are stored on `cells` itself
# (nothing is reassigned) and serve as the uncorrected baseline for the UMAP below.
sc.pp.pca(cells)
sc.pp.neighbors(cells,n_pcs=30)
sc.tl.umap(cells)

In [None]:
# Pre-correction UMAP coloured by sample, annotation and two genes (Neurod1, Mki67).
# The save suffix makes scanpy write figures/umap_scgen_preintegration.png.
mpl.rcParams['figure.figsize'] = (6, 6)
sc.pl.umap(cells, color=["sample", "annotation",'Neurod1','Mki67'], frameon=False, save='_scgen_preintegration.png')

In [None]:
# Make sure `cells` is an actual AnnData (not a view) before registering it with scGen.
cells = cells.copy()
# Register 'sample' as the batch to remove and 'annotation' as the cell-type label.
scgen.SCGEN.setup_anndata(cells, batch_key="sample", labels_key="annotation")

In [None]:
# Build the (still untrained) scGen model on the registered AnnData.
model = scgen.SCGEN(cells)
# NOTE: the model is saved after training in a later cell; calling
# model.save("out/scgen-batch-model.pt", overwrite=True) here would only store untrained weights.

In [None]:
# warnings suggest changing settings to speed up training
# Number of DataLoader worker processes scvi-tools should use (hard-coded; adjust to the machine).
scvi.settings.dl_num_workers=11

# accelerator='gpu' requests a GPU explicitly; presumably this fails on a machine
# without one (NOTE(review): 'auto' would fall back to CPU — confirm desired behaviour).
# Training runs for at most 100 epochs and stops early after 25 epochs without
# improvement of the early-stopping metric.
model.train(
    accelerator='gpu',
    max_epochs=100,
    batch_size=32,
    early_stopping=True,
    early_stopping_patience=25,
)
#TODO: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` 
# in the `DataLoader` to improve performance.
# NOTE(review): this warning was recorded although dl_num_workers is set above —
# verify whether scGen's train() actually honours scvi.settings.dl_num_workers.

In [None]:
# Persist the trained model (replacing any previous save) so later sessions can reload it.
model.save("out/scgen-batch-model.pt", overwrite=True)

In [None]:
# checkpoint: when the notebook is resumed without training, `model` does not exist
# yet, so fall back to the weights saved by the training cells.
try:
    model
except NameError:
    model = scgen.SCGEN.load("out/scgen-batch-model.pt",adata=cells)

In [None]:
# NB: scGen is bugged, need to change scvi code. See
# https://github.com/theislab/scgen/issues/101
# The FutureWarnings raised by this call are silenced only inside the `with` block.
# The result is kept in `latent` (previously it was computed and discarded) so the
# latent representation can be inspected after this cell.
with warnings.catch_warnings():
    warnings.simplefilter(action='ignore', category=FutureWarning)
    latent = model.get_latent_representation()

In [None]:
# Produce the batch-corrected AnnData with the trained model; the trailing bare
# expression is rendered as the cell output by Jupyter.
# A bunch of concerning futurewarnings...
corrected_cells = model.batch_removal()
corrected_cells

In [None]:
# Save the corrected data (reloaded in the gene-expression section). This runs before
# the PCA/UMAP below, so the saved file does not contain the post-correction embedding.
corrected_cells.write('out/shiraishi_merge.h5ad')

In [None]:
# embeddings after batch correction, with the same settings as before correction
# (PCA -> kNN graph on the first 30 PCs -> UMAP), stored on `corrected_cells`.
sc.pp.pca(corrected_cells)
sc.pp.neighbors(corrected_cells,n_pcs=30)
sc.tl.umap(corrected_cells)

In [None]:
# Post-correction UMAP with the same colouring as the pre-correction plot.
# The save suffix makes scanpy write figures/umap_scgen_postintegration.png.
mpl.rcParams['figure.figsize'] = (6, 6)
sc.pl.umap(corrected_cells, color=['sample', 'annotation','Neurod1','Mki67'], frameon=False, save='_scgen_postintegration.png')

## Plot gene expression

In [None]:
# Reload the corrected data written earlier, so this section does not depend on the
# in-memory results of the training cells.
path='out/shiraishi_merge.h5ad'
corrected_cells=ad.read_h5ad(path)

In [None]:
# Show the tab10 colours as hex codes (presumably used to choose the hard-coded
# violin colours in plot_gex). The palette is returned as the cell output instead of
# being passed to `print`, which is rebound to the logger at the top of the notebook
# and would otherwise write this palette into out/scgen.log next to the statistics.
asdf = sns.color_palette("tab10")
asdf.as_hex()

In [None]:
import statistics
def plot_gex(anndata,genes):
    '''
    Plot gene expression in hyperplastic cells (ProliferativeCells, DifferentiatedCells).

    Draws one violin per value of obs['sample'] on a new figure.

    Parameters
    ----------
    anndata : AnnData
        Needs obs columns 'sample' and 'annotation'; `genes` must be in var_names.
    genes : list of str
        Genes to plot. The values of all listed genes are pooled into the same
        violin per sample (there is no per-gene split).

    Returns
    -------
    matplotlib.axes.Axes
        The axes holding the violin plot.

    Side effects: applies seaborn's 'ticks' theme and sets global matplotlib
    rcParams (figure size, Arial font, svg.fonttype='none'). The rcParams stay global
    on purpose: svg.fonttype is read when the caller later saves the figure as SVG.
    '''
    # data formatting
    df = anndata.obs.copy()
    df = df.merge(anndata[:,genes].to_df(),left_index=True,right_index=True)
    df = df[df.annotation.isin(['ProliferativeCells','DifferentiatedCells'])]
    data_long = pd.melt(df, id_vars=['sample'],value_vars=genes, var_name='gene',value_name='expression')
    # plot formatting: set_theme() rewrites rcParams (including font.family), so it
    # must run before the custom overrides or it silently discards them.
    sns.set_theme(style='ticks')
    plt.rcParams['figure.figsize'] = (5,5)
    plt.rcParams['svg.fonttype'] = 'none'
    plt.rcParams['font.family'] = 'Arial'
    # Fresh figure per call, so repeated calls never draw over existing axes.
    _, ax = plt.subplots()
    # hue='sample' makes the palette assignment explicit (passing a palette without hue
    # is deprecated in seaborn 0.13); colours follow the order of the sample categories.
    ax = sns.violinplot(x='sample',y='expression',hue='sample',data=data_long,inner='box',legend=False,palette=['#1f77b4','#bcbd22','#d62728'],ax=ax)
    sns.despine(ax=ax)
    return ax

def run_statistics(anndata,genes):
    '''
    Pairwise Mann-Whitney U tests of gene expression between samples, restricted to
    hyperplastic cells (ProliferativeCells, DifferentiatedCells).

    For each gene in `genes` and each unordered pair of samples in obs['sample'],
    scipy.stats.mannwhitneyu is run with its default (two-sided) alternative. The
    p-values of all genes and pairs are then corrected together with
    Benjamini-Hochberg (statsmodels fdrcorrection).

    Results are reported with `print`, which this notebook rebinds to logger.info,
    so they also reach the log file. If either sample of a pair has fewer than two
    cells, the pair is not tested: a warning is issued and p = 1.0 is recorded. That
    placeholder is still included in the BH correction.

    Parameters
    ----------
    anndata : AnnData
        Needs obs columns 'sample' and 'annotation'; `genes` must be in var_names.
    genes : list of str
        Genes to test.

    Returns
    -------
    None
    '''
    print(f'Running Mann-Whitney U test to compare medians of {genes} gene expression...')
    df = anndata.obs.copy()
    df = df.merge(anndata[:,genes].to_df(),left_index=True,right_index=True)
    df = df[df.annotation.isin(['ProliferativeCells','DifferentiatedCells'])]
    samples = list(df['sample'].unique())
    # Each p-value is stored together with its (gene, sample, sample) labels, so the
    # corrected values can be reported by position instead of re-deriving indices
    # from the loop structure.
    comparisons = []  # (gene, sample_a, sample_b), aligned with p_values
    p_values = []
    for gene in genes:
        for i, si in enumerate(samples):
            class1_values = df[df['sample'] == si][gene].values
            print(f'{si}: n = {len(class1_values)} observations, median {statistics.median(class1_values)}.')
            for sj in samples[i+1:]:
                class2_values = df[df['sample'] == sj][gene].values
                if len(class1_values) > 1 and len(class2_values) > 1:  # Ensure there are enough values for the test
                    _, p_value = stats.mannwhitneyu(class1_values, class2_values)
                else:
                    warnings.warn(f'not enough data for test: {gene} {si} {sj}')
                    p_value = 1.0  # If there's not enough data, assign a non-significant p-value
                comparisons.append((gene, si, sj))
                p_values.append(p_value)
    if not p_values:
        # No genes, or fewer than two samples: nothing was tested, nothing to correct.
        return
    # Apply Benjamini-Hochberg correction across all genes and sample pairs
    _, p_values_corrected = fdrcorrection(p_values)
    for (gene, si, sj), p_value, p_corrected in zip(comparisons, p_values, p_values_corrected):
        print(f'Mann-Whitney U test for {gene}, {si} vs {sj}: p-value: {p_value}, Corrected p-value: {p_corrected}')

In [None]:
# Neurod1 in hyperplastic cells: violin plot saved as SVG, then pairwise tests
# between samples (reported through `print`, i.e. the log).
ax = plot_gex(corrected_cells,['Neurod1'])
ax.figure.savefig('figures/neurod1_gex.svg')
run_statistics(corrected_cells,['Neurod1'])

In [None]:
# Same analysis for Grin2b: violin plot saved as SVG, then pairwise tests between samples.
ax = plot_gex(corrected_cells,['Grin2b'])
ax.figure.savefig('figures/Grin2b_gex.svg')
run_statistics(corrected_cells,['Grin2b'])