# Single-Cell Report: Harmony Batch Correction

In [None]:
# Import packages
import scanpy as sc
import numpy as np
import matplotlib.pyplot as plt
import random
import json
import os

In [None]:
# To run outside VSN framework
# PRE_BEC_H5AD=os.path.join("../data/intermediate", PRE_BEC_H5AD)
# POST_BEC_H5AD=os.path.join("../data/intermediate", POST_BEC_H5AD)

In [None]:
# Paths of the AnnData files before and after batch effect correction.
# NOTE(review): FILE1 and FILE2 are not defined in this notebook; presumably
# they are injected by the workflow runner at execution time - confirm.
PRE_BEC_H5AD=FILE1
POST_BEC_H5AD=FILE2

In [None]:
# Decode the workflow parameters (a JSON string) and pick out the Harmony settings.
params = json.loads(WORKFLOW_PARAMETERS)
bec_params = params["tools"]["harmony"]

# The batch variable is the first entry of "varsUse" when that key is present;
# otherwise "sample_id" is used.
if "varsUse" in bec_params:
    batch = bec_params["varsUse"][0]
else:
    batch = "sample_id"

#### Plotting settings and functions

In [None]:
# Scanpy-wide figure defaults used by every plot in this report.
sc.set_figure_params(
    dpi=150,
    dpi_save=600,
    fontsize=10,
)

In [None]:
def barPlotByAnnotation(obs, axis1, axis2, title, clustering_algorithm, annotation):
    """Draw two bar plots describing cluster membership per annotation value.

    Args:
        obs: Per-cell table (e.g. ``adata.obs``) that contains the columns
            named by ``clustering_algorithm`` and ``annotation``.
        axis1: Axes receiving the grouped bar plot: for each annotation value,
            the percentage of its cells that fall in each cluster.
        axis2: Axes receiving the stacked bar plot: for each cluster, the
            percentage of its cells contributed by each annotation value.
        title: Title set on both axes.
        clustering_algorithm: Name of the ``obs`` column holding the cluster
            labels. Labels must be convertible with ``int()``; they are
            displayed shifted by one (label 0 is shown as 1).
        annotation: Name of the ``obs`` column used to break clusters down.
    """
    legend_kwargs = dict(fancybox=True, framealpha=0.5, loc='right', bbox_to_anchor=(1.15, 0.5))

    # Cell counts with clusters as rows and annotation values as columns.
    cluster_by_anno = obs.groupby(by=[clustering_algorithm, annotation]).size().unstack()
    # Append the total number of cells per annotation value to the legend labels.
    # The vectorised Series.sum() avoids iterating over every cell in Python.
    cluster_by_anno.columns = [
        f"{c} (n={(obs[annotation] == c).sum()})" for c in cluster_by_anno.columns
    ]
    # Show clusters 1-based.
    cluster_by_anno.index = [int(n) + 1 for n in cluster_by_anno.index]

    # Normalise each column: percentage of an annotation value found in each cluster.
    cluster_by_anno_norm = (cluster_by_anno / cluster_by_anno.sum()) * 100
    cluster_by_anno_norm.plot(kind='bar', stacked=False, fontsize=8, width=.7, grid=False, ax=axis1)
    axis1.set_ylabel(f'Percentage of {annotation.capitalize()} (%)')
    axis1.set_title(title)
    axis1.legend(**legend_kwargs)

    # Normalise each row: composition of every cluster by annotation value.
    percent_of_cluster = cluster_by_anno.divide(cluster_by_anno.sum(axis=1), axis='rows') * 100
    percent_of_cluster.plot(kind='bar', stacked=True, fontsize=8, width=.75, grid=False, ax=axis2)
    axis2.set_ylabel('Percent of Cluster (%)')
    axis2.set_title(title)
    axis2.legend(**legend_kwargs)

#### Read Data

In [None]:
# adata1 holds the data before batch effect correction, adata2 the data after it.
adata1 = sc.read_h5ad(PRE_BEC_H5AD)
adata2 = sc.read_h5ad(POST_BEC_H5AD)

In [None]:
# Randomise the cell order of both datasets so that no batch is systematically
# drawn on top of the others in the embedding plots.
cellID1 = list(adata1.obs_names)
cellID2 = list(adata2.obs_names)

# Shuffle in place; the pre-correction IDs are shuffled first, then the post-correction IDs.
random.shuffle(cellID1)
random.shuffle(cellID2)

# Re-index each AnnData object by its shuffled cell IDs.
adata1 = adata1[cellID1]
adata2 = adata2[cellID2]

In [None]:
# Add sample_id by default to the annotations plotted on top of the cell embeddings.
# It is skipped when it is the batch variable (the batch gets its own plots) or when
# it is already requested, so that the same annotation is never plotted twice.
# NOTE(review): annotations_to_plot is not defined in this notebook; presumably it is
# injected by the workflow runner - confirm.
if "sample_id" in adata1.obs.keys() and "sample_id" in adata2.obs.keys():
    # Only add sample_id when both datasets have fewer than 256 distinct values.
    if len(np.unique(adata1.obs.sample_id)) < 256 and len(np.unique(adata2.obs.sample_id)) < 256:
        if batch != "sample_id" and "sample_id" not in annotations_to_plot:
            # Build a new list rather than mutating the one that was passed in.
            annotations_to_plot = annotations_to_plot + ["sample_id"]

In [None]:
# Check that every requested annotation is present in both AnnData files;
# fail on the first one that is missing.
for annotation in annotations_to_plot:
    if annotation not in adata1.obs.keys():
        raise Exception(f"The annotation {annotation} is not present in {PRE_BEC_H5AD}.")
    if annotation not in adata2.obs.keys():
        raise Exception(f"The annotation {annotation} is not present in {POST_BEC_H5AD}.")

---
## Batch effect correction

In [None]:
# Detect which clustering was run by looking for its entry in the unstructured
# annotations of the pre-correction data (only adata1 is inspected).
if 'louvain' in adata1.uns:
    clustering_algorithm = 'louvain'
elif 'leiden' in adata1.uns:
    clustering_algorithm = 'leiden'
else:
    # Every later cell needs a valid clustering name, so stop here with a clear
    # message instead of continuing with an empty name.
    raise ValueError(
        "Invalid clustering algorithm! Neither 'louvain' nor 'leiden' was found "
        f"in the unstructured annotations (uns) of {PRE_BEC_H5AD}."
    )

In [None]:
# Report the resolution stored with the detected clustering of the pre-correction data.
clustering_resolution = adata1.uns[clustering_algorithm]['params']['resolution']
print(f"{clustering_algorithm.capitalize()} resolution: {clustering_resolution}")

### t-SNE

In [None]:
a = 0.6  # alpha (point transparency) used by the t-SNE plots
number_of_subplots = len(annotations_to_plot)

# One figure per annotation: pre-correction t-SNE on the left, post-correction on the right.
for annotation_to_plot in annotations_to_plot:
    fig, axs = plt.subplots(1, 2, figsize=(10, 5), dpi=150)

    ax1 = sc.pl.tsne(adata1, color=annotation_to_plot, alpha=a, ax=axs[0], show=False, wspace=0.5)
    ax1.legend(fancybox=True, framealpha=0.5, loc='right', bbox_to_anchor=(1.15, 0.5))
    ax1.set_title(f"Pre-batch correction ({annotation_to_plot})")

    ax2 = sc.pl.tsne(adata2, color=annotation_to_plot, alpha=a, ax=axs[1], show=False, wspace=0.5)
    ax2.legend(fancybox=True, framealpha=0.5, loc='right', bbox_to_anchor=(1.15, 0.5))
    ax2.set_title(f"Post-batch correction ({annotation_to_plot})")

    plt.tight_layout()

In [None]:
# 2x2 t-SNE overview: the top row is coloured by batch, the bottom row by cluster;
# the left column is pre-correction, the right column post-correction.
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(10, 10), dpi=150)

# Top-left: pre-correction, coloured by batch.
ax1 = sc.pl.tsne(adata1, color=batch, alpha=a, ax=ax1, show=False, wspace=0.5, title='batch')
ax1.legend(fancybox=True, framealpha=0.5, loc='right', bbox_to_anchor=(1.15, 0.5))
ax1.set_title('Pre-batch correction (batch)')

# Top-right: post-correction, coloured by batch.
ax2 = sc.pl.tsne(adata2, color=batch, alpha=a, ax=ax2, show=False, wspace=0.5, title='batch')
ax2.legend(fancybox=True, framealpha=0.5, loc='right', bbox_to_anchor=(1.15, 0.5))
ax2.set_title('Post-batch correction (batch)')

# Bottom row: coloured by the detected clustering.
sc.pl.tsne(
    adata1, color=clustering_algorithm, alpha=a, palette=sc.pl.palettes.godsnot_102,
    ax=ax3, show=False, wspace=0.5,
)
sc.pl.tsne(
    adata2, color=clustering_algorithm, alpha=a, palette=sc.pl.palettes.godsnot_102,
    ax=ax4, show=False, wspace=0.5,
)
ax3.set_title(f'Pre-batch correction ({clustering_algorithm.capitalize()})')
ax4.set_title(f'Post-batch correction ({clustering_algorithm.capitalize()})')

plt.tight_layout()

### UMAP

In [None]:
a = 0.6  # alpha (point transparency) used by the UMAP plots
number_of_subplots = len(annotations_to_plot)

# One figure per annotation: pre-correction UMAP on the left, post-correction on the right.
for annotation_to_plot in annotations_to_plot:
    fig, axs = plt.subplots(1, 2, figsize=(10, 5), dpi=150)

    ax1 = sc.pl.umap(adata1, color=annotation_to_plot, alpha=a, ax=axs[0], show=False, wspace=0.5)
    ax1.legend(fancybox=True, framealpha=0.5, loc='right', bbox_to_anchor=(1.15, 0.5))
    ax1.set_title(f"Pre-batch correction ({annotation_to_plot})")

    ax2 = sc.pl.umap(adata2, color=annotation_to_plot, alpha=a, ax=axs[1], show=False, wspace=0.5)
    ax2.legend(fancybox=True, framealpha=0.5, loc='right', bbox_to_anchor=(1.15, 0.5))
    ax2.set_title(f"Post-batch correction ({annotation_to_plot})")

    plt.tight_layout()

In [None]:
# 2x2 UMAP overview: the top row is coloured by batch, the bottom row by cluster;
# the left column is pre-correction, the right column post-correction.
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(10, 10), dpi=150)

# Top-left: pre-correction, coloured by batch.
ax1 = sc.pl.umap(adata1, color=batch, alpha=a, ax=ax1, show=False, wspace=0.5, title='batch')
ax1.legend(fancybox=True, framealpha=0.5, loc='right', bbox_to_anchor=(1.15, 0.5))
ax1.set_title('Pre-batch correction (batch)')

# Top-right: post-correction, coloured by batch.
ax2 = sc.pl.umap(adata2, color=batch, alpha=a, ax=ax2, show=False, wspace=0.5, title='batch')
ax2.legend(fancybox=True, framealpha=0.5, loc='right', bbox_to_anchor=(1.15, 0.5))
ax2.set_title('Post-batch correction (batch)')

# Bottom row: coloured by the detected clustering.
sc.pl.umap(
    adata1, color=clustering_algorithm, alpha=a, palette=sc.pl.palettes.godsnot_102,
    ax=ax3, show=False, wspace=0.5,
)
sc.pl.umap(
    adata2, color=clustering_algorithm, alpha=a, palette=sc.pl.palettes.godsnot_102,
    ax=ax4, show=False, wspace=0.5,
)
ax3.set_title(f'Pre-batch correction ({clustering_algorithm.capitalize()})')
ax4.set_title(f'Post-batch correction ({clustering_algorithm.capitalize()})')

plt.tight_layout()

#### Cluster membership by batch

The following plots show how the batches are distributed across the predicted clusters (Louvain or Leiden, depending on the clustering algorithm that was used). The top row (percentage of batch) shows the proportion of cells from each batch that belongs to a particular cluster, pre- and post-batch correction. The bottom row (percent of cluster) shows the batch composition of each cluster, pre- and post-batch correction. In the pre-batch correction plots, clusters tend to be driven by batch, whereas after batch correction the batches are more evenly distributed across the clusters.

In [None]:
# One 2x2 figure per annotation: the left column is pre-correction, the right column
# post-correction; the top row receives axis1 and the bottom row axis2 of the helper.
for annotation_to_plot in annotations_to_plot:
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 8), dpi=150)
    barPlotByAnnotation(
        adata1.obs,
        axis1=ax1,
        axis2=ax3,
        title=f"Pre-batch correction ({annotation_to_plot})",
        clustering_algorithm=clustering_algorithm,
        annotation=annotation_to_plot,
    )
    barPlotByAnnotation(
        adata2.obs,
        axis1=ax2,
        axis2=ax4,
        title=f"Post-batch correction ({annotation_to_plot})",
        clustering_algorithm=clustering_algorithm,
        annotation=annotation_to_plot,
    )
    plt.tight_layout()

In [None]:
# Same 2x2 layout for the batch variable itself: the left column is pre-correction,
# the right column post-correction.
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 8), dpi=150)
barPlotByAnnotation(
    adata1.obs,
    axis1=ax1,
    axis2=ax3,
    title="Pre-batch correction (batch)",
    clustering_algorithm=clustering_algorithm,
    annotation=batch,
)
barPlotByAnnotation(
    adata2.obs,
    axis1=ax2,
    axis2=ax4,
    title="Post-batch correction (batch)",
    clustering_algorithm=clustering_algorithm,
    annotation=batch,
)
plt.tight_layout()