# Integration of published wound healing datasets using Scanorama #
This notebook contains the code used to generate the majority of Figure 1: the integration of several published *mouse* wound healing datasets, which span different conditions, wound sizes, and times post-wounding:

1. [Haensel et al. (2020)](https://www.sciencedirect.com/science/article/pii/S2211124720302655): Uninjured, small wound PWD4 (multiple samples of each), epidermis + dermis
2. [Abbasi et al. (2020)](https://doi.org/10.1016/j.stem.2020.07.008): Large wound PWD14, wound centre and periphery (this one is optional)
3. [Phan et al. (2020)](https://elifesciences.org/articles/60066): Uninjured and small wound PWD7.
4. [Guerrero-Juarez et al. (2019)](https://doi.org/10.1038/s41467-018-08247-x): Large wound PWD12, mainly dermis
5. [Gay et al. (2020)](https://doi.org/10.1126/sciadv.aay3704): Fibrotic and regenerative large wounds PWD18, mostly dermis and immune cells.

For data integration, we will use the package, [Scanorama](https://github.com/brianhie/scanorama) (see [Hie et al., 2019](https://www.nature.com/articles/s41587-019-0113-3)).

In [None]:
# Paths used throughout the notebook. The data are loaded paper by paper from sub-folders of `data_directory`.
data_directory = '../data/' # Where we loaded the data

# Where we will save the h5ad file
results_directory = '../data/'

In [1]:
# Load the relevant packages.
import numpy as np
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt
import seaborn as sns
import scanorama as scrama
from scipy import sparse

# Scanpy logging level used for the whole notebook.
sc.settings.verbosity = 3 # Possible values: (0) errors, (1) warnings, (2) info, (3) hints
# Default appearance of the figures produced by Scanpy.
sc.settings.set_figure_params(dpi = 100, facecolor='white', fontsize=18, transparent=True)

Load the data from each study.

### UW P21: Phan et al. (2020)

In [None]:
# Phan et al. (2020): unwounded P21 skin, three loom files.
paper_name = 'Phan2020/'
conditions = ['UnwoundedP21/P21_1.loom', 'UnwoundedP21/P21_2.loom', 'UnwoundedP21/P21_3.loom']
samples = ['UW P21', 'UW P21', 'UW P21']
sub_samples = ['UWP21_1', 'UWP21_2', 'UWP21_3']

# The list has to be called `uwp21_datasets`: the doublet-removal and merging
# cells further down refer to it by that name.
uwp21_datasets = []
# We now add the data to the list
for condition, sample, sub_sample in zip(conditions, samples, sub_samples):

    directory = data_directory + paper_name + condition
    data = sc.read_loom(directory, obs_names = 'CellID', var_names = 'Gene') # Gene symbols will be used as variable names
    data.var_names_make_unique()
    # Keep only the part of each cell name before the first '-'
    data.obs.index = data.obs.index.str.split('-').str[0]
    data.obs.index.name = None
    data.obs['sample'] = sample
    data.obs['sub_sample'] = sub_sample
    data.var.rename(columns={'Accession':'gene_ids'}, inplace=True)
    data.obs.drop(['Clusters', '_X', '_Y'], axis=1, inplace=True) # Drop the unnecessary labels
    data.var.drop(['Chromosome', 'Start', 'End', 'Strand'], axis=1, inplace=True)
    del data.layers['matrix'], data.layers['ambiguous'] # Don't need these from velocyto
    uwp21_datasets.append(data)

### UW P49 and SW PWD4: Haensel et al. (2020)

In [None]:
# Haensel et al. (2020): unwounded P49 skin and small wounds at PWD4.
paper_name = 'Haensel2020/'
conditions = ['Unwounded/1/', 'Unwounded/2/', 'Wounded/1/', 'Wounded/2/', 'Wounded/3/']
samples = ['UW P49', 'UW P49', 'SW PWD4', 'SW PWD4', 'SW PWD4']
sub_samples = ['UWP49_1', 'UWP49_2', 'SWPWD4_1', 'SWPWD4_2', 'SWPWD4_3']

uwp49_datasets = []
swpwd4_datasets = []
# We now add the data to the lists
for condition, sample, sub_sample in zip(conditions, samples, sub_samples):

    # Get the published counts
    directory = data_directory + paper_name + condition
    data = sc.read_10x_mtx(directory, # Directory with relevant .mtx file
        var_names = 'gene_symbols', # Gene symbols will be used as variable names
        cache=True)
    # Keep only the part of each barcode before the first '-'
    data.obs.index = data.obs.index.str.split('-').str[0]
    data.obs['sample'] = sample
    data.obs['sub_sample'] = sub_sample

    # Get the RNA velocity data obtained with kallisto
    directory = data_directory + paper_name + condition + 'counts_unfiltered/adata.h5ad'
    data_kb = sc.read_h5ad(directory)
    data_kb.obs['sample'] = sample
    data_kb.obs['sub_sample'] = sub_sample
    data_kb.obs.index.name = None

    data_kb.obs.index = data_kb.obs.index.str.split('-').str[0]
    # Drop everything after the first '.' in the gene identifiers
    data_kb.var.index = data_kb.var.index.str.split('.').str[0]
    # Swap the roles of the two gene labels: the `gene_name` column becomes the
    # index (so it can be matched against the published gene symbols) and the
    # previous index is kept in the `gene_ids` column.
    data_kb.var.rename({'gene_name':'gene_ids'}, axis=1, inplace=True)
    gene_ids = data_kb.var.index
    data_kb.var.index = pd.Index(data_kb.var['gene_ids'].tolist())
    data_kb.var['gene_ids'] = gene_ids
    data_kb.var.index.name = None
    data_kb.var_names_make_unique()

    # Get the intersections of the cells and gene sets
    var_intersect = data.var.index.intersection(data_kb.var.index)
    obs_intersect = data.obs.index.intersection(data_kb.obs.index)

    # Subset the dataframes. The published counts are copied so that the layers
    # are written onto a real object rather than onto a view of `data`.
    data_subset = data[obs_intersect, var_intersect].copy()
    data_kb_subset = data_kb[obs_intersect, var_intersect]

    # Set the spliced/unspliced layers
    data_subset.layers['spliced'] = data_kb_subset.layers['spliced']
    data_subset.layers['unspliced'] = data_kb_subset.layers['unspliced']

    # Route by sample label rather than by position in the list
    if sample == 'UW P49':
        uwp49_datasets.append(data_subset)
    else:
        swpwd4_datasets.append(data_subset)
        

### SW PWD7: Phan et al. (2020)

In [None]:
# Phan et al. (2020): small wounds at PWD7, three loom files.
paper_name = 'Phan2020/'
conditions = ['WoundedP21/P21_1_Wound.loom', 'WoundedP21/P21_2_Wound.loom', 'WoundedP21/P21_3_Wound.loom']
samples = ['SW PWD7', 'SW PWD7', 'SW PWD7']
sub_samples = ['SWPWD7_1', 'SWPWD7_2', 'SWPWD7_3']

swpwd7_datasets = []
# We now add the data to the list
for condition, sample, sub_sample in zip(conditions, samples, sub_samples):

    directory = data_directory + paper_name + condition
    data = sc.read_loom(directory, obs_names = 'CellID', var_names = 'Gene') # Gene symbols will be used as variable names
    data.var_names_make_unique()
    # Keep only the part of each cell name before the first '-'
    data.obs.index = data.obs.index.str.split('-').str[0]
    data.obs.index.name = None
    data.obs['sample'] = sample
    data.obs['sub_sample'] = sub_sample
    data.var.rename(columns={'Accession':'gene_ids'}, inplace=True)
    data.obs.drop(['Clusters', '_X', '_Y'], axis=1, inplace=True) # Drop the unnecessary labels
    data.var.drop(['Chromosome', 'Start', 'End', 'Strand'], axis=1, inplace=True)
    del data.layers['matrix'], data.layers['ambiguous'] # Don't need these from velocyto
    # Also need to remove the genes whose name starts with 'Rp' in this set... they do weird things.
    # NOTE(review): the prefix is presumably meant to catch the ribosomal genes,
    # but it removes every gene starting with 'Rp' - confirm this is intended.
    non_ribo_genes = [name for name in data.var_names if not name.startswith('Rp')]
    # Copy so that the list stores a real object rather than a view of the full data
    data = data[:, non_ribo_genes].copy()
    swpwd7_datasets.append(data)

### LW PWD12: Guerrero-Juarez et al. (2019)

In [None]:
# Guerrero-Juarez et al. (2019): large wound at PWD12 (one sample).
paper_name = 'GuerreroJuarez2019/'
conditions = ['Shortened/']
samples = ['LW PWD12']
sub_samples = ['LWPWD12']

lwpwd12_datasets = []

# We now add the data to the list
for condition, sample, sub_sample in zip(conditions, samples, sub_samples):

    # Get the published counts
    directory = data_directory + paper_name + condition
    data = sc.read_10x_mtx(directory, # Directory with relevant .mtx file
        var_names = 'gene_symbols', # Gene symbols will be used as variable names
        cache=True)
    data.obs['sample'] = sample
    data.obs['sub_sample'] = sub_sample
    # Keep only the part of each barcode before the first '-'
    data.obs.index = data.obs.index.str.split('-').str[0]
    # (A former `wound_datasets.append(data)` was removed here: no list of that
    # name is created before this cell, so the call raised a NameError.)

    # Get the RNA velocity data obtained with kallisto
    directory = data_directory + paper_name + condition + 'counts_unfiltered/adata.h5ad'
    data_kb = sc.read_h5ad(directory)
    data_kb.obs['sample'] = sample
    data_kb.obs['sub_sample'] = sub_sample
    data_kb.obs.index.name = None

    data_kb.obs.index = data_kb.obs.index.str.split('-').str[0]
    # Drop everything after the first '.' in the gene identifiers
    data_kb.var.index = data_kb.var.index.str.split('.').str[0]
    # Swap the roles of the two gene labels: the `gene_name` column becomes the
    # index and the previous index is kept in the `gene_ids` column.
    data_kb.var.rename({'gene_name':'gene_ids'}, axis=1, inplace=True)
    gene_ids = data_kb.var.index
    data_kb.var.index = pd.Index(data_kb.var['gene_ids'].tolist())
    data_kb.var['gene_ids'] = gene_ids
    data_kb.var.index.name = None
    data_kb.var_names_make_unique()

    # Get the intersections of the cells and gene sets
    var_intersect = data.var.index.intersection(data_kb.var.index)
    obs_intersect = data.obs.index.intersection(data_kb.obs.index)

    # Subset the dataframes. The published counts are copied so that the layers
    # are written onto a real object rather than onto a view of `data`.
    data_subset = data[obs_intersect, var_intersect].copy()
    data_kb_subset = data_kb[obs_intersect, var_intersect]

    # Set the spliced/unspliced layers
    data_subset.layers['spliced'] = data_kb_subset.layers['spliced']
    data_subset.layers['unspliced'] = data_kb_subset.layers['unspliced']

    lwpwd12_datasets.append(data_subset)

### LW PWD14: Abbasi et al. (2020)

In [None]:
# Abbasi et al. (2020): large wound at PWD14, centre and periphery samples.
paper_name = 'Abbasi2020/'
conditions = ['LargeWoundCentre_PWD14/', 'LargeWoundPeriphery_PWD14/']
samples = ['LW PWD14', 'LW PWD14']
sub_samples = ['LWPWD14_C', 'LWPWD14_P']

lwpwd14_datasets = []
# We now add the data to the list
for condition, sample, sub_sample in zip(conditions, samples, sub_samples):

    # Get the published counts
    directory = data_directory + paper_name + condition
    data = sc.read_10x_mtx(directory, # Directory with relevant .mtx file
        var_names = 'gene_symbols', # Gene symbols will be used as variable names
        cache=True)
    data.obs['sample'] = sample
    data.obs['sub_sample'] = sub_sample
    # Keep only the part of each barcode before the first '-'
    data.obs.index = data.obs.index.str.split('-').str[0]
    # (A former `wound_datasets.append(data)` was removed here: no list of that
    # name is created before this cell, so the call raised a NameError.)

    # Get the RNA velocity data obtained with kallisto
    directory = data_directory + paper_name + condition + 'counts_unfiltered/adata.h5ad'
    data_kb = sc.read_h5ad(directory)
    data_kb.obs['sample'] = sample
    data_kb.obs['sub_sample'] = sub_sample
    data_kb.obs.index.name = None

    data_kb.obs.index = data_kb.obs.index.str.split('-').str[0]
    # Drop everything after the first '.' in the gene identifiers
    data_kb.var.index = data_kb.var.index.str.split('.').str[0]
    # Swap the roles of the two gene labels: the `gene_name` column becomes the
    # index and the previous index is kept in the `gene_ids` column.
    data_kb.var.rename({'gene_name':'gene_ids'}, axis=1, inplace=True)
    gene_ids = data_kb.var.index
    data_kb.var.index = pd.Index(data_kb.var['gene_ids'].tolist())
    data_kb.var['gene_ids'] = gene_ids
    data_kb.var.index.name = None
    data_kb.var_names_make_unique()

    # Get the intersections of the cells and gene sets
    var_intersect = data.var.index.intersection(data_kb.var.index)
    obs_intersect = data.obs.index.intersection(data_kb.obs.index)

    # Subset the dataframes. The published counts are copied so that the layers
    # are written onto a real object rather than onto a view of `data`.
    data_subset = data[obs_intersect, var_intersect].copy()
    data_kb_subset = data_kb[obs_intersect, var_intersect]

    # Set the spliced/unspliced layers
    data_subset.layers['spliced'] = data_kb_subset.layers['spliced']
    data_subset.layers['unspliced'] = data_kb_subset.layers['unspliced']

    lwpwd14_datasets.append(data_subset)

### LW FIB/REG PWD18: Gay et al. (2020)

In [None]:
# Gay et al. (2020): fibrotic and regenerative large wounds at PWD18.
paper_name = 'Gay2020/'
conditions = ['Fibrotic/', 'Regenerative/']
samples = ['FIB PWD18', 'REG PWD18']
sub_samples = ['FIBPWD18', 'REGPWD18']

lwpwd18_datasets = []

# We now add the data to the list
for condition, sample, sub_sample in zip(conditions, samples, sub_samples):

    # Get the published counts
    directory = data_directory + paper_name + condition
    data = sc.read_10x_mtx(directory, # Directory with relevant .mtx file
        var_names = 'gene_symbols', # Gene symbols will be used as variable names
        cache=True)
    data.obs['sample'] = sample
    data.obs['sub_sample'] = sub_sample
    # Keep only the part of each barcode before the first '-'
    data.obs.index = data.obs.index.str.split('-').str[0]
    # (A former `wound_datasets.append(data)` was removed here: no list of that
    # name is created before this cell, so the call raised a NameError.)

    # Get the RNA velocity data obtained with kallisto
    directory = data_directory + paper_name + condition + 'counts_unfiltered/adata.h5ad'
    data_kb = sc.read_h5ad(directory)
    data_kb.obs['sample'] = sample
    data_kb.obs['sub_sample'] = sub_sample
    data_kb.obs.index.name = None

    data_kb.obs.index = data_kb.obs.index.str.split('-').str[0]
    # Drop everything after the first '.' in the gene identifiers
    data_kb.var.index = data_kb.var.index.str.split('.').str[0]
    # Swap the roles of the two gene labels: the `gene_name` column becomes the
    # index and the previous index is kept in the `gene_ids` column.
    data_kb.var.rename({'gene_name':'gene_ids'}, axis=1, inplace=True)
    gene_ids = data_kb.var.index
    data_kb.var.index = pd.Index(data_kb.var['gene_ids'].tolist())
    data_kb.var['gene_ids'] = gene_ids
    data_kb.var.index.name = None
    data_kb.var_names_make_unique()

    # Get the intersections of the cells and gene sets
    var_intersect = data.var.index.intersection(data_kb.var.index)
    obs_intersect = data.obs.index.intersection(data_kb.obs.index)

    # Subset the dataframes. The published counts are copied so that the layers
    # are written onto a real object rather than onto a view of `data`.
    data_subset = data[obs_intersect, var_intersect].copy()
    data_kb_subset = data_kb[obs_intersect, var_intersect]

    # Set the spliced/unspliced layers
    data_subset.layers['spliced'] = data_kb_subset.layers['spliced']
    data_subset.layers['unspliced'] = data_kb_subset.layers['unspliced']

    lwpwd18_datasets.append(data_subset)

Remove doublets via Scrublet.

In [None]:
# Run Scrublet on every sample of every condition, in the same order as the
# conditions were loaded, and keep only the cells not predicted to be doublets.
for dataset_group in (uwp21_datasets, uwp49_datasets, swpwd4_datasets, swpwd7_datasets,
                      lwpwd12_datasets, lwpwd14_datasets, lwpwd18_datasets):
    for i in range(len(dataset_group)):
        data = dataset_group[i]
        sc.external.pp.scrublet(data)  # Adds the `predicted_doublet` column to `data.obs`
        data = data[data.obs['predicted_doublet'] == False]
        dataset_group[i] = data  # Replace the sample with its doublet-free subset

Preprocess each condition separately. First, for each condition, we merge the list of samples for QC filtering.

In [None]:
def _merge_samples(sample_datasets):
    """Concatenate the samples of one condition and make the cell names unique."""
    merged = sample_datasets[0].concatenate(sample_datasets[1:])
    merged.obs_names_make_unique(join='_')
    return merged


# One merged object per condition, used for the QC filtering below.
uwp21_merged = _merge_samples(uwp21_datasets)
uwp49_merged = _merge_samples(uwp49_datasets)
swpwd4_merged = _merge_samples(swpwd4_datasets)
swpwd7_merged = _merge_samples(swpwd7_datasets)
lwpwd12_merged = _merge_samples(lwpwd12_datasets)
lwpwd14_merged = _merge_samples(lwpwd14_datasets)
lwpwd18_merged = _merge_samples(lwpwd18_datasets)

For each condition, we filter cells based on the total UMI counts (`n_counts`), the number of detected genes (`n_genes`), and the percentage of mitochondrial gene expression (`pct_counts_mt`).

### UW P21

In [None]:
# Per-cell totals: summed counts, their logarithm, and the number of genes with a non-zero count.
uwp21_merged.obs['n_counts'] = uwp21_merged.X.sum(axis=1)
uwp21_merged.obs['log_counts'] = np.log(uwp21_merged.obs['n_counts'])
uwp21_merged.obs['n_genes'] = (uwp21_merged.X > 0).sum(axis=1)

# Percentage of each cell's counts that comes from genes whose name matches 'mt-'
mito_genes = uwp21_merged.var_names.str.match('mt-')
mt_sum = uwp21_merged[:, mito_genes].X.sum(axis=1)
total_sum = uwp21_merged.X.sum(axis=1)
# For a sparse X the sums come back as matrices; `.A1` flattens them to 1-D arrays
if sparse.issparse(uwp21_merged.X):
    mt_sum, total_sum = mt_sum.A1, total_sum.A1
uwp21_merged.obs['pct_counts_mt'] = 100.0 * mt_sum / total_sum


In [None]:
# Wide figures for the per-sub-sample violin plots
plt.rcParams['figure.figsize'] = (12, 4)
sc.pl.violin(uwp21_merged, 'pct_counts_mt', groupby='sub_sample', multi_panel=False, cut=5)
plt.show()

sc.pl.violin(uwp21_merged, 'n_counts', groupby='sub_sample', multi_panel=False, log=True, cut=0)
plt.show()

# Square figures for the scatter plots
plt.rcParams['figure.figsize'] = (5, 5)

# Data quality summary plots: all cells, then only the cells with fewer than 5000 counts
p1 = sc.pl.scatter(uwp21_merged, 'n_counts', 'n_genes', color='pct_counts_mt')
low_count_cells = uwp21_merged[uwp21_merged.obs['n_counts'] < 5000]
p2 = sc.pl.scatter(low_count_cells, 'n_counts', 'n_genes', color='pct_counts_mt')

In [None]:
plt.rcParams['figure.figsize'] = (5, 5)  # rescale figures

n_counts = uwp21_merged.obs['n_counts']
n_genes = uwp21_merged.obs['n_genes']

# Thresholding decision: counts (full range, then the lower and upper tails)
p3 = sns.distplot(n_counts, kde=False)
plt.show()

p4 = sns.distplot(n_counts[n_counts < 5000], kde=False, bins=60)
plt.show()

p5 = sns.distplot(n_counts[n_counts > 10000], kde=False, bins=60)
plt.show()

# Thresholding decision: genes (upper tail, then lower tail)
p6 = sns.distplot(n_genes[n_genes > 2000], kde=False, bins=60)
plt.show()

p7 = sns.distplot(n_genes[n_genes < 1500], kde=False, bins=60)
plt.show()

In [None]:
# Apply the QC thresholds chosen from the plots above, reporting the size after each step.
print('Total number of cells: {:d}'.format(uwp21_merged.n_obs))

# Count thresholds
sc.pp.filter_cells(uwp21_merged, min_counts=300)
print('Number of cells after min count filter: {:d}'.format(uwp21_merged.n_obs))

sc.pp.filter_cells(uwp21_merged, max_counts=25000)
print('Number of cells after max count filter: {:d}'.format(uwp21_merged.n_obs))

# Mitochondrial threshold (should be either 5 or 10)
low_mito = uwp21_merged.obs['pct_counts_mt'] < 5.0
uwp21_merged = uwp21_merged[low_mito]
print('Number of cells after MT filter: {:d}'.format(uwp21_merged.n_obs))

# Gene-number thresholds: upper bound first, then lower bound
sc.pp.filter_cells(uwp21_merged, max_genes=5500)
print('Number of cells after gene filter: {:d}'.format(uwp21_merged.n_obs))

sc.pp.filter_cells(uwp21_merged, min_genes=250)
print('Number of cells after gene filter: {:d}'.format(uwp21_merged.n_obs))

# Finally drop the genes found in fewer than 3 cells
sc.pp.filter_genes(uwp21_merged, min_cells=3)
print('Number of genes after cell filter: {:d}'.format(uwp21_merged.n_vars))

### UW P49

In [None]:
# Per-cell totals: summed counts, their logarithm, and the number of genes with a non-zero count.
uwp49_merged.obs['n_counts'] = uwp49_merged.X.sum(axis=1)
uwp49_merged.obs['log_counts'] = np.log(uwp49_merged.obs['n_counts'])
uwp49_merged.obs['n_genes'] = (uwp49_merged.X > 0).sum(axis=1)

# Percentage of each cell's counts that comes from genes whose name matches 'mt-'
mito_genes = uwp49_merged.var_names.str.match('mt-')
mt_sum = uwp49_merged[:, mito_genes].X.sum(axis=1)
total_sum = uwp49_merged.X.sum(axis=1)
# For a sparse X the sums come back as matrices; `.A1` flattens them to 1-D arrays
if sparse.issparse(uwp49_merged.X):
    mt_sum, total_sum = mt_sum.A1, total_sum.A1
uwp49_merged.obs['pct_counts_mt'] = 100.0 * mt_sum / total_sum


In [None]:
# Wide figures for the per-sub-sample violin plots
plt.rcParams['figure.figsize'] = (12, 4)
sc.pl.violin(uwp49_merged, 'pct_counts_mt', groupby='sub_sample', multi_panel=False, cut=5)
plt.show()

sc.pl.violin(uwp49_merged, 'n_counts', groupby='sub_sample', multi_panel=False, log=True, cut=0)
plt.show()

# Square figures for the scatter plots
plt.rcParams['figure.figsize'] = (5, 5)

# Data quality summary plots: all cells, then only the cells with fewer than 5000 counts
p1 = sc.pl.scatter(uwp49_merged, 'n_counts', 'n_genes', color='pct_counts_mt')
low_count_cells = uwp49_merged[uwp49_merged.obs['n_counts'] < 5000]
p2 = sc.pl.scatter(low_count_cells, 'n_counts', 'n_genes', color='pct_counts_mt')

In [None]:
plt.rcParams['figure.figsize'] = (5, 5)  # rescale figures

n_counts = uwp49_merged.obs['n_counts']
n_genes = uwp49_merged.obs['n_genes']

# Thresholding decision: counts (full range, then the lower and upper tails)
p3 = sns.distplot(n_counts, kde=False)
plt.show()

p4 = sns.distplot(n_counts[n_counts < 5000], kde=False, bins=60)
plt.show()

p5 = sns.distplot(n_counts[n_counts > 17500], kde=False, bins=60)
plt.show()

# Thresholding decision: genes (upper tail, then lower tail)
p6 = sns.distplot(n_genes[n_genes > 3000], kde=False, bins=60)
plt.show()

p7 = sns.distplot(n_genes[n_genes < 1000], kde=False, bins=60)
plt.show()

In [None]:
# Apply the QC thresholds chosen from the plots above, reporting the size after each step.
print('Total number of cells: {:d}'.format(uwp49_merged.n_obs))

# Count thresholds
sc.pp.filter_cells(uwp49_merged, min_counts=300)
print('Number of cells after min count filter: {:d}'.format(uwp49_merged.n_obs))

sc.pp.filter_cells(uwp49_merged, max_counts=25000)
print('Number of cells after max count filter: {:d}'.format(uwp49_merged.n_obs))

# Mitochondrial threshold (should be either 5 or 10)
low_mito = uwp49_merged.obs['pct_counts_mt'] < 10.0
uwp49_merged = uwp49_merged[low_mito]
print('Number of cells after MT filter: {:d}'.format(uwp49_merged.n_obs))

# Gene-number thresholds: upper bound first, then lower bound
sc.pp.filter_cells(uwp49_merged, max_genes=5000)
print('Number of cells after gene filter: {:d}'.format(uwp49_merged.n_obs))

sc.pp.filter_cells(uwp49_merged, min_genes=500)
print('Number of cells after gene filter: {:d}'.format(uwp49_merged.n_obs))

# Finally drop the genes found in fewer than 3 cells
sc.pp.filter_genes(uwp49_merged, min_cells=3)
print('Number of genes after cell filter: {:d}'.format(uwp49_merged.n_vars))

### SW PWD7

In [None]:
# Per-cell totals: summed counts, their logarithm, and the number of genes with a non-zero count.
swpwd7_merged.obs['n_counts'] = swpwd7_merged.X.sum(axis=1)
swpwd7_merged.obs['log_counts'] = np.log(swpwd7_merged.obs['n_counts'])
swpwd7_merged.obs['n_genes'] = (swpwd7_merged.X > 0).sum(axis=1)

# Percentage of each cell's counts that comes from genes whose name matches 'mt-'
mito_genes = swpwd7_merged.var_names.str.match('mt-')
mt_sum = swpwd7_merged[:, mito_genes].X.sum(axis=1)
total_sum = swpwd7_merged.X.sum(axis=1)
# For a sparse X the sums come back as matrices; `.A1` flattens them to 1-D arrays
if sparse.issparse(swpwd7_merged.X):
    mt_sum, total_sum = mt_sum.A1, total_sum.A1
swpwd7_merged.obs['pct_counts_mt'] = 100.0 * mt_sum / total_sum


In [None]:
# Violin plots of the mitochondrial percentage and the total counts per sub-sample.
plt.rcParams['figure.figsize']=(12, 4) #rescale figures
sc.pl.violin(swpwd7_merged, 'pct_counts_mt', groupby='sub_sample', multi_panel=False, cut=5)
plt.show()  # A stray backtick in front of this call made the whole cell a SyntaxError

plt.rcParams['figure.figsize']=(12, 4) #rescale figures
sc.pl.violin(swpwd7_merged, 'n_counts', groupby='sub_sample', multi_panel=False, log=True, cut=0)
plt.show()

plt.rcParams['figure.figsize']=(5, 5) #rescale figures

# Data quality summary plots: all cells, then only the cells with fewer than 5000 counts
p1 = sc.pl.scatter(swpwd7_merged, 'n_counts', 'n_genes', color='pct_counts_mt')
p2 = sc.pl.scatter(swpwd7_merged[swpwd7_merged.obs['n_counts']<5000], 'n_counts', 'n_genes', color='pct_counts_mt')

In [None]:
plt.rcParams['figure.figsize'] = (5, 5)  # rescale figures

n_counts = swpwd7_merged.obs['n_counts']
n_genes = swpwd7_merged.obs['n_genes']

# Thresholding decision: counts (full range, then the lower and upper tails)
p3 = sns.distplot(n_counts, kde=False)
plt.show()

p4 = sns.distplot(n_counts[n_counts < 5000], kde=False, bins=60)
plt.show()

p5 = sns.distplot(n_counts[n_counts > 10000], kde=False, bins=60)
plt.show()

# Thresholding decision: genes (upper tail, then lower tail)
p6 = sns.distplot(n_genes[n_genes > 2000], kde=False, bins=60)
plt.show()

p7 = sns.distplot(n_genes[n_genes < 1500], kde=False, bins=60)
plt.show()

In [None]:
# Filter cells according to identified QC thresholds, reporting the size after each step:
print('Total number of cells: {:d}'.format(swpwd7_merged.n_obs))

sc.pp.filter_cells(swpwd7_merged, min_counts = 300)
print('Number of cells after min count filter: {:d}'.format(swpwd7_merged.n_obs))

sc.pp.filter_cells(swpwd7_merged, max_counts = 30000)
print('Number of cells after max count filter: {:d}'.format(swpwd7_merged.n_obs))

# The filtered result must be assigned back to `swpwd7_merged`; it was previously
# stored under an unrelated name, so the mitochondrial filter was never applied.
swpwd7_merged = swpwd7_merged[swpwd7_merged.obs['pct_counts_mt'] < 5.0] # Should be either 5 or 10
print('Number of cells after MT filter: {:d}'.format(swpwd7_merged.n_obs))

sc.pp.filter_cells(swpwd7_merged, max_genes = 4500)
print('Number of cells after gene filter: {:d}'.format(swpwd7_merged.n_obs))

sc.pp.filter_cells(swpwd7_merged, min_genes = 400)
print('Number of cells after gene filter: {:d}'.format(swpwd7_merged.n_obs))

# Finally drop the genes found in fewer than 3 cells
sc.pp.filter_genes(swpwd7_merged, min_cells=3)
print('Number of genes after cell filter: {:d}'.format(swpwd7_merged.n_vars))

### LW PWD12

In [None]:
# Per-cell totals: summed counts, their logarithm, and the number of genes with a non-zero count.
lwpwd12_merged.obs['n_counts'] = lwpwd12_merged.X.sum(axis=1)
lwpwd12_merged.obs['log_counts'] = np.log(lwpwd12_merged.obs['n_counts'])
lwpwd12_merged.obs['n_genes'] = (lwpwd12_merged.X > 0).sum(axis=1)

# Percentage of each cell's counts that comes from genes whose name matches 'mt-'
mito_genes = lwpwd12_merged.var_names.str.match('mt-')
mt_sum = lwpwd12_merged[:, mito_genes].X.sum(axis=1)
total_sum = lwpwd12_merged.X.sum(axis=1)
# For a sparse X the sums come back as matrices; `.A1` flattens them to 1-D arrays
if sparse.issparse(lwpwd12_merged.X):
    mt_sum, total_sum = mt_sum.A1, total_sum.A1
lwpwd12_merged.obs['pct_counts_mt'] = 100.0 * mt_sum / total_sum


In [None]:
# Wide figures for the per-sub-sample violin plots
plt.rcParams['figure.figsize'] = (12, 4)
sc.pl.violin(lwpwd12_merged, 'pct_counts_mt', groupby='sub_sample', multi_panel=False, cut=5)
plt.show()

sc.pl.violin(lwpwd12_merged, 'n_counts', groupby='sub_sample', multi_panel=False, log=True, cut=0)
plt.show()

# Square figures for the scatter plots
plt.rcParams['figure.figsize'] = (5, 5)

# Data quality summary plots: all cells, then only the cells with fewer than 5000 counts
p1 = sc.pl.scatter(lwpwd12_merged, 'n_counts', 'n_genes', color='pct_counts_mt')
low_count_cells = lwpwd12_merged[lwpwd12_merged.obs['n_counts'] < 5000]
p2 = sc.pl.scatter(low_count_cells, 'n_counts', 'n_genes', color='pct_counts_mt')


In [None]:
plt.rcParams['figure.figsize'] = (5, 5)  # rescale figures

n_counts = lwpwd12_merged.obs['n_counts']
n_genes = lwpwd12_merged.obs['n_genes']

# Thresholding decision: counts (full range, then the lower and upper tails)
p3 = sns.distplot(n_counts, kde=False)
plt.show()

p4 = sns.distplot(n_counts[n_counts < 5000], kde=False, bins=60)
plt.show()

p5 = sns.distplot(n_counts[n_counts > 5000], kde=False, bins=60)
plt.show()

# Thresholding decision: genes (upper tail, then lower tail)
p6 = sns.distplot(n_genes[n_genes > 2000], kde=False, bins=60)
plt.show()

p7 = sns.distplot(n_genes[n_genes < 1000], kde=False, bins=60)
plt.show()

In [None]:
# Apply the QC thresholds chosen from the plots above, reporting the size after each step.
print('Total number of cells: {:d}'.format(lwpwd12_merged.n_obs))

# Count thresholds
sc.pp.filter_cells(lwpwd12_merged, min_counts=300)
print('Number of cells after min count filter: {:d}'.format(lwpwd12_merged.n_obs))

sc.pp.filter_cells(lwpwd12_merged, max_counts=10000)
print('Number of cells after max count filter: {:d}'.format(lwpwd12_merged.n_obs))

# Mitochondrial threshold (should be either 5 or 10)
low_mito = lwpwd12_merged.obs['pct_counts_mt'] < 10.0
lwpwd12_merged = lwpwd12_merged[low_mito]
print('Number of cells after MT filter: {:d}'.format(lwpwd12_merged.n_obs))

# Gene-number thresholds: upper bound first, then lower bound
sc.pp.filter_cells(lwpwd12_merged, max_genes=3000)
print('Number of cells after gene filter: {:d}'.format(lwpwd12_merged.n_obs))

sc.pp.filter_cells(lwpwd12_merged, min_genes=450)
print('Number of cells after gene filter: {:d}'.format(lwpwd12_merged.n_obs))

# Finally drop the genes found in fewer than 3 cells
sc.pp.filter_genes(lwpwd12_merged, min_cells=3)
print('Number of genes after cell filter: {:d}'.format(lwpwd12_merged.n_vars))

### LW PWD14

In [None]:
# Per-cell totals: summed counts, their logarithm, and the number of genes with a non-zero count.
lwpwd14_merged.obs['n_counts'] = lwpwd14_merged.X.sum(axis=1)
lwpwd14_merged.obs['log_counts'] = np.log(lwpwd14_merged.obs['n_counts'])
lwpwd14_merged.obs['n_genes'] = (lwpwd14_merged.X > 0).sum(axis=1)

# Percentage of each cell's counts that comes from genes whose name matches 'mt-'
mito_genes = lwpwd14_merged.var_names.str.match('mt-')
mt_sum = lwpwd14_merged[:, mito_genes].X.sum(axis=1)
total_sum = lwpwd14_merged.X.sum(axis=1)
# For a sparse X the sums come back as matrices; `.A1` flattens them to 1-D arrays
if sparse.issparse(lwpwd14_merged.X):
    mt_sum, total_sum = mt_sum.A1, total_sum.A1
lwpwd14_merged.obs['pct_counts_mt'] = 100.0 * mt_sum / total_sum


In [None]:
# Wide figures for the per-sub-sample violin plots
plt.rcParams['figure.figsize'] = (12, 4)
sc.pl.violin(lwpwd14_merged, 'pct_counts_mt', groupby='sub_sample', multi_panel=False, cut=5)
plt.show()

plt.rcParams['figure.figsize'] = (12, 4)
sc.pl.violin(lwpwd14_merged, 'n_counts', groupby='sub_sample', multi_panel=False, log=True, cut=0)
plt.show()

# Square figures for the scatter plots
plt.rcParams['figure.figsize'] = (5, 5)

# Data quality summary plots: all cells, then only the cells with fewer than 5000 counts
p1 = sc.pl.scatter(lwpwd14_merged, 'n_counts', 'n_genes', color='pct_counts_mt')
low_count_cells = lwpwd14_merged[lwpwd14_merged.obs['n_counts'] < 5000]
p2 = sc.pl.scatter(low_count_cells, 'n_counts', 'n_genes', color='pct_counts_mt')

In [None]:
plt.rcParams['figure.figsize'] = (5, 5)  # rescale figures

n_counts = lwpwd14_merged.obs['n_counts']
n_genes = lwpwd14_merged.obs['n_genes']

# Thresholding decision: counts (full range, then the lower and upper tails)
p3 = sns.distplot(n_counts, kde=False)
plt.show()

p4 = sns.distplot(n_counts[n_counts < 5000], kde=False, bins=60)
plt.show()

p5 = sns.distplot(n_counts[n_counts > 5000], kde=False, bins=60)
plt.show()

# Thresholding decision: genes (upper tail, then lower tail)
p6 = sns.distplot(n_genes[n_genes > 2000], kde=False, bins=60)
plt.show()

p7 = sns.distplot(n_genes[n_genes < 1500], kde=False, bins=60)
plt.show()

In [None]:
# Filter cells according to identified QC thresholds, reporting the size after each step:
print('Total number of cells: {:d}'.format(lwpwd14_merged.n_obs))

sc.pp.filter_cells(lwpwd14_merged, min_counts = 300)
print('Number of cells after min count filter: {:d}'.format(lwpwd14_merged.n_obs))

sc.pp.filter_cells(lwpwd14_merged, max_counts = 35000)
# Report on `lwpwd14_merged` itself; the name printed here before was not
# defined anywhere before this cell and raised a NameError.
print('Number of cells after max count filter: {:d}'.format(lwpwd14_merged.n_obs))

lwpwd14_merged = lwpwd14_merged[lwpwd14_merged.obs['pct_counts_mt'] < 5.0] # Should be either 5 or 10
print('Number of cells after MT filter: {:d}'.format(lwpwd14_merged.n_obs))

sc.pp.filter_cells(lwpwd14_merged, max_genes = 5000)
print('Number of cells after gene filter: {:d}'.format(lwpwd14_merged.n_obs))

sc.pp.filter_cells(lwpwd14_merged, min_genes = 800)
print('Number of cells after gene filter: {:d}'.format(lwpwd14_merged.n_obs))

# Finally drop the genes found in fewer than 3 cells
sc.pp.filter_genes(lwpwd14_merged, min_cells=3)
print('Number of genes after cell filter: {:d}'.format(lwpwd14_merged.n_vars))

### LW PWD18

In [None]:
# Per-cell totals: summed counts, their logarithm, and the number of genes with a non-zero count.
lwpwd18_merged.obs['n_counts'] = lwpwd18_merged.X.sum(axis=1)
lwpwd18_merged.obs['log_counts'] = np.log(lwpwd18_merged.obs['n_counts'])
lwpwd18_merged.obs['n_genes'] = (lwpwd18_merged.X > 0).sum(axis=1)

# Percentage of each cell's counts that comes from genes whose name matches 'mt-'
mito_genes = lwpwd18_merged.var_names.str.match('mt-')
mt_sum = lwpwd18_merged[:, mito_genes].X.sum(axis=1)
total_sum = lwpwd18_merged.X.sum(axis=1)
# For a sparse X the sums come back as matrices; `.A1` flattens them to 1-D arrays
if sparse.issparse(lwpwd18_merged.X):
    mt_sum, total_sum = mt_sum.A1, total_sum.A1
lwpwd18_merged.obs['pct_counts_mt'] = 100.0 * mt_sum / total_sum


In [None]:
# Wide figures for the per-sub-sample violin plots
plt.rcParams['figure.figsize'] = (12, 4)
sc.pl.violin(lwpwd18_merged, 'pct_counts_mt', groupby='sub_sample', multi_panel=False, cut=5)
plt.show()

sc.pl.violin(lwpwd18_merged, 'n_counts', groupby='sub_sample', multi_panel=False, log=True, cut=0)
plt.show()

# Square figures for the scatter plots
plt.rcParams['figure.figsize'] = (5, 5)

# Data quality summary plots: all cells, then only the cells with fewer than 5000 counts
p1 = sc.pl.scatter(lwpwd18_merged, 'n_counts', 'n_genes', color='pct_counts_mt')
low_count_cells = lwpwd18_merged[lwpwd18_merged.obs['n_counts'] < 5000]
p2 = sc.pl.scatter(low_count_cells, 'n_counts', 'n_genes', color='pct_counts_mt')

In [None]:
plt.rcParams['figure.figsize'] = (5, 5)  # rescale figures

n_counts = lwpwd18_merged.obs['n_counts']
n_genes = lwpwd18_merged.obs['n_genes']

# Thresholding decision: counts (full range, then the lower and upper tails)
p3 = sns.distplot(n_counts, kde=False)
plt.show()

p4 = sns.distplot(n_counts[n_counts < 5000], kde=False, bins=60)
plt.show()

p5 = sns.distplot(n_counts[n_counts > 10000], kde=False, bins=60)
plt.show()

# Thresholding decision: genes (upper tail, then lower tail)
p6 = sns.distplot(n_genes[n_genes > 3000], kde=False, bins=60)
plt.show()

p7 = sns.distplot(n_genes[n_genes < 1500], kde=False, bins=60)
plt.show()

In [None]:
# Filter cells according to identified QC thresholds:
# counts in [300, 30000], mitochondrial fraction < 10%, genes in [250, 5000];
# finally drop genes detected in fewer than 3 cells.
print('Total number of cells: {:d}'.format(lwpwd18_merged.n_obs))

sc.pp.filter_cells(lwpwd18_merged, min_counts = 300)
print('Number of cells after min count filter: {:d}'.format(lwpwd18_merged.n_obs))

sc.pp.filter_cells(lwpwd18_merged, max_counts = 30000)
print('Number of cells after max count filter: {:d}'.format(lwpwd18_merged.n_obs))

# Boolean indexing returns a view of the AnnData object; take an explicit copy so that the
# in-place filters below operate on a real object rather than triggering an implicit copy of the view.
lwpwd18_merged = lwpwd18_merged[lwpwd18_merged.obs['pct_counts_mt'] < 10.0].copy() # Should be either 5 or 10
print('Number of cells after MT filter: {:d}'.format(lwpwd18_merged.n_obs))

# Upper bound on the number of detected genes
sc.pp.filter_cells(lwpwd18_merged, max_genes = 5000)
print('Number of cells after gene filter: {:d}'.format(lwpwd18_merged.n_obs))

# Lower bound on the number of detected genes
sc.pp.filter_cells(lwpwd18_merged, min_genes = 250)
print('Number of cells after gene filter: {:d}'.format(lwpwd18_merged.n_obs))

sc.pp.filter_genes(lwpwd18_merged, min_cells=3)
print('Number of genes after cell filter: {:d}'.format(lwpwd18_merged.n_vars))

# Merge datasets and normalise

In [None]:
# Join the per-study datasets into a single annotated data frame, keeping the union of genes
datasets_to_append = [
    uwp49_merged,
    swpwd4_merged,
    swpwd7_merged,
    lwpwd12_merged,
    lwpwd14_merged,
    lwpwd18_merged,
]
wound_full_merged = uwp21_merged.concatenate(datasets_to_append, join='outer')
wound_full_merged.obs_names_make_unique(join='_')

# Keep only the part of each cell name before the first '-', then de-duplicate again,
# because truncated names can collide
truncated_obs_names = wound_full_merged.obs.index.str.split('-').str[0]
wound_full_merged.obs.index = truncated_obs_names
wound_full_merged.obs_names_make_unique(join='_')


In [None]:
# Store the raw counts as a layer (a copy, so the normalisation below does not touch it)
wound_full_merged.layers['counts'] = wound_full_merged.X.copy()

# Normalise the data: scale every cell to a total of 1e4, then log(1 + x) transform
sc.pp.normalize_total(wound_full_merged, target_sum=1e4)
sc.pp.log1p(wound_full_merged)

In [None]:
# Identify the highly-variable genes. We use the CellRanger routine provided in Scanpy.
# HVGs are called per batch ('sample'); the next cell reads the per-gene
# 'highly_variable_nbatches' and 'dispersions_norm' columns from .var.
target_genes = 4000
sc.pp.highly_variable_genes(wound_full_merged, flavor='cell_ranger', n_top_genes=target_genes, batch_key='sample')

In [None]:
# Build the final HVG list of (at most) `target_genes` genes.
# Start with the genes that are highly variable in every batch, then top up tier by tier with
# genes that are highly variable in one fewer batch each time (method from the Theis lab).
n_batches = len(wound_full_merged.obs['sample'].cat.categories)
hvg_dispersions = wound_full_merged.var['dispersions_norm']
hvg_nbatches = wound_full_merged.var['highly_variable_nbatches']

# These are the genes that are variable across all batches, most dispersed first
nbatch1_dispersions = hvg_dispersions[hvg_nbatches > n_batches - 1].sort_values(ascending=False)
print(len(nbatch1_dispersions))

# Never take more than the target, even if the all-batch tier alone exceeds it
hvg = nbatch1_dispersions.index[:target_genes]

# We'll go down one by one, until we're selecting HVGs from just a single batch (or none).
# The range is bounded, so the loop always terminates even if the target cannot be reached.
for n_shared_batches in range(n_batches - 1, -1, -1):

    target_genes_diff = target_genes - len(hvg) # Get the number of genes we still need to fill up
    if target_genes_diff <= 0:
        break

    tmp_dispersions = hvg_dispersions[hvg_nbatches == n_shared_batches]

    if len(tmp_dispersions) < target_genes_diff:
        # The whole tier fits: take all of it and move on to the next tier
        hvg = hvg.append(tmp_dispersions.index)
    else:
        # The tier is larger than what is still needed: take only its most dispersed genes
        tmp_dispersions = tmp_dispersions.sort_values(ascending=False)
        hvg = hvg.append(tmp_dispersions.index[:target_genes_diff])
        break

if len(hvg) < target_genes:
    print('Only {:d} of the {:d} requested HVGs could be selected'.format(len(hvg), target_genes))

# Data integration via Scanorama

Subset the data to the highly variable genes (HVGs) to speed things up.

In [None]:
# Restrict the merged data to the selected HVGs; columns follow the order of `hvg`
wound_full_merged_hvg = wound_full_merged[:, hvg] # Filter out genes that do not vary much across cells

In [None]:
# Split the data into batches (marked by 'sample'), one independent copy per sample,
# in order of first appearance of each sample label
sample_labels = wound_full_merged_hvg.obs['sample']
wound_full_split = []

for sample in sample_labels.unique():
    cells_in_sample = sample_labels == sample
    wound_full_split.append(wound_full_merged_hvg[cells_in_sample].copy())

Now we run Scanorama on the split data to obtain an integrated reduced dimension embedding.

In [None]:
%%time
# Integrate the per-sample batches with Scanorama; batch names are passed in the same order as the
# split list. The next cell reads the resulting embedding from each batch's .obsm['X_scanorama'].
scrama.integrate_scanpy(wound_full_split, ds_names = list(wound_full_merged_hvg.obs['sample'].unique()))

In [None]:
# Collect the Scanorama embedding of every batch and attach it to the full merged object.
embeddings = [adata.obsm['X_scanorama'] for adata in wound_full_split]
embeddings_joined = np.concatenate(embeddings, axis=0)

# The batches were built by sample label, so their stacked row order only matches the merged object
# if every sample occupies one contiguous run of rows. Align by cell name instead of by position,
# so that every cell is guaranteed to get its own embedding.
split_obs_names = pd.Index(np.concatenate([adata.obs_names.to_numpy() for adata in wound_full_split]))
embedding_rows = split_obs_names.get_indexer(wound_full_merged.obs_names)
if (embedding_rows < 0).any():
    raise ValueError('Some cells in wound_full_merged have no Scanorama embedding.')
wound_full_merged.obsm['X_SC'] = embeddings_joined[embedding_rows]

In [None]:
# Generate the kNN graph on the Scanorama embedding ('X_SC') and calculate the UMAP
sc.pp.neighbors(wound_full_merged, use_rep = "X_SC", n_neighbors=30)
sc.tl.umap(wound_full_merged)

In [None]:
# Visual check of the integration: UMAP coloured by sample of origin
plt.rcParams['figure.figsize']=(6, 6) #rescale figures
sc.pl.umap(wound_full_merged, color='sample')

# Clustering and cell type identification

In [None]:
# Cluster the data now (Leiden clustering, labels stored in .obs['leiden'])
sc.tl.leiden(wound_full_merged, resolution = 0.3, key_added = 'leiden')

In [None]:
# Rank marker genes for each Leiden cluster with a Wilcoxon test.
# NOTE(review): key_added='leiden' reuses the clustering's key name for the results — confirm this
# does not overwrite anything under that key that is needed downstream.
sc.tl.rank_genes_groups(wound_full_merged, 'leiden', key_added = 'leiden', method='wilcoxon')

In [None]:
# Rename the clusters: one cell-type label per Leiden category, in category order
new_cluster_names = [
    'Epidermal 1', 'Fibroblast 1', 'Fibroblast 2', 'Epidermal 2', 'Immune 1',
    'Fibroblast 3', 'Pericyte', 'Immune 2', 'Epidermal 3', 'Endothelial',
    'Epidermal 4', 'Schwann', 'Fibroblast 4', 'Lymphatic endothelial', 'Epidermal 5',
    'Immune 3', 'Melanocyte', 'Langerhans cell', 'Skeletal muscle', 'Smooth muscle',
]
wound_full_merged.rename_categories('leiden', new_cluster_names)

In [None]:
# Colour palette for the clusters (23 entries in total), stored where the plots look up 'leiden' colours
new_leiden_colours = np.array([
    '#023fa5', '#8e063b', '#d33e6a', '#4b68af', '#11c638', '#bb7784',
    '#ef9708', '#8dd593', '#7d87b9', '#f0b98d', '#8595e1', '#bec1d4',
    '#e07b91', '#ead3c6', '#b5bbe3', '#c6dec7', '#0fcfc0', '#9cded6',
    '#a58c88', '#f6c4e1', '#f3e1eb', '#d5eae7', '#d33e6a',
])
wound_full_merged.uns['leiden_colors'] = new_leiden_colours

# Plot the integration metrics

In [None]:
# Load the integration metrics, name the value column, and drop the metrics that weren't calculated
integration_metrics = (
    pd.read_csv(results_directory + 'integratedskin_integrationmetrics.csv', index_col=0)
    .rename(columns={'0': 'Value'})
    .dropna(subset=['Value'])
)

In [None]:
# Category of each metric, listed in the row order of `integration_metrics`
# NOTE(review): this relies on the CSV row order — confirm against the file
metrics_types = (
    ['Bio conservation'] * 3
    + ['Batch correction'] * 2
    + ['Bio conservation'] * 3
    + ['Batch correction'] * 2
)
integration_metrics['Type'] = pd.Series(data=metrics_types, index=integration_metrics.index, dtype='category')

In [None]:
# Order in which the metrics are arranged for plotting
metrics_order = [
    'PCR_batch',
    'ASW_label/batch',
    'kBET',
    'graph_conn',
    'NMI_cluster/label',
    'ARI_cluster/label',
    'ASW_label',
    'isolated_label_F1',
    'cell_cycle_conservation',
    'isolated_label_silhouette',
]
integration_metrics = integration_metrics.reindex(metrics_order)

# Some of these are irrelevant for our needs, so we don't plot them
# (positions 5 and 8, i.e. 'ARI_cluster/label' and 'cell_cycle_conservation', are left out).
rows_to_plot = [0, 1, 2, 3, 4, 6, 7, 9]
integration_metrics_subset = integration_metrics.iloc[rows_to_plot, :]


In [None]:
# Human-readable labels for the metrics shown on the plot (metrics not listed keep their name)
metric_display_names = {
    'PCR_batch': 'PCR batch',
    'ASW_label/batch': 'ASW batch',
    'graph_conn': 'Graph connectivity',
    'NMI_cluster/label': 'NMI label',
    'ASW_label': 'ASW label',
    'isolated_label_F1': 'Isolated label F1',
    'isolated_label_silhouette': 'Isolated label silhouette',
}
integration_metrics_subset.rename(index=metric_display_names, inplace=True)

In [None]:
# Bar plot of the selected integration metrics, coloured by metric type.
metric_colours = {'Batch correction': 'blue', "Bio conservation": "orange"}

# Take the colours from the rows that are actually plotted. Using the full `integration_metrics`
# table would shift the colours of every bar that follows a dropped row.
bar_colours = [metric_colours[metric_type] for metric_type in integration_metrics_subset['Type']]

ax = integration_metrics_subset.plot.bar(y='Value', rot=90, width=0.8, figsize=(6, 12), linewidth=1.0, color=bar_colours)
plt.legend(loc='center left', bbox_to_anchor=(1.0, 0.5))


# Plot the cell type composition across time 

In [None]:
# Per-sample cluster composition: each row is a sample and sums to 1 across the Leiden clusters
wound_merged_df = wound_full_merged.obs

# Samples ordered by time from wounding
sample_order = [
    'UW P21',
    'UW P49',
    'SW PWD4',
    'SW PWD7',
    'LW PWD12',
    'LW PWD14',
    'LW FIB PWD18',
    'LW REG PWD18',
]

tmp = pd.crosstab(
    wound_merged_df['sample'],
    wound_merged_df['leiden'],
    normalize='index',
).reindex(sample_order)

In [None]:
# Stacked bar chart of the cluster proportions per sample, using the cluster colour palette
axis = tmp.plot.bar(stacked=True, width=0.9, grid=False, figsize=(6,10), linewidth=1.0, color=new_leiden_colours)
# axis.invert_yaxis()
# Put the legend outside the axes, to the right
plt.legend(loc='center left', bbox_to_anchor=(1.0, 0.5))
fig = axis.get_figure()

# Plot discriminatory cell type marker genes

In [None]:
# Marker genes per cluster. Sub-clusters of the same lineage share one marker set;
# the insertion order of the dictionary determines the order of the genes in the plot.
_lineage_markers = [
    ('Epidermal', 5, ['Lgals7', 'Apoc1', 'Krt5']),
    ('Fibroblast', 4, ['Dcn', 'Lum', 'Crabp1']),
    ('Immune', 3, ['Cd52', 'Srgn', 'Tyrobp']),
]

# Numbered sub-clusters first ('Epidermal 1' ... 'Immune 3'), each with its own copy of the list
marker_genes_dict = {
    '{} {:d}'.format(lineage, sub_cluster): list(markers)
    for lineage, n_sub_clusters, markers in _lineage_markers
    for sub_cluster in range(1, n_sub_clusters + 1)
}

# Then the cell types that form a single cluster
marker_genes_dict.update({
    'Pericyte': ['Rgs5', 'Col4a1'],
    'Endothelial': ['Pecam1', 'Cdh5'],
    'Schwann': ['Plp1', 'Pmp22'],
    'Lymphatic endothelial': ['Cldn5', 'Lyve1'],
    'Melanocyte': ['Dct', 'Mlana'],
    'Langerhans cell': ['Cd74', 'Cd207'],
    'Skeletal muscle': ['Ckm', 'Acta1'],
    'Smooth muscle': ['Mylk', 'Actg2'],
})

In [None]:
# Cluster labels ordered from the largest to the smallest cluster.
# The clustering (and its renamed labels) is stored in .obs['leiden']; this notebook never creates
# a 'leiden_sub' column, so reading that key fails on a fresh kernel.
categories = wound_full_merged.obs['leiden'].value_counts().index

In [None]:
# Columns to extract for the tracksplot: the cluster label first, followed by every marker gene
# exactly once, in order of first appearance in `marker_genes_dict`
keys = ['leiden']

for category, genes in marker_genes_dict.items():
    for gene in genes:
        if gene in keys:
            continue
        keys.append(gene)

This code is a hacked version of Scanpy's `tracksplot` function, combined with code that I am grateful to have received from Suoqin Jin (Wuhan University) and that was used in [Guerrero-Juarez et al. (2019)](https://www.nature.com/articles/s41467-018-08247-x). Note that the resulting plot was then heavily edited in Illustrator afterwards.

In [None]:
# Prepare the data for the tracksplot
# Local import: `gridspec` is not brought in by the notebook's import cell
from matplotlib import gridspec

# Expression values are read from .raw when it has been set, otherwise from .X
obs_tidy = sc.get.obs_df(wound_full_merged, keys=keys, use_raw=wound_full_merged.raw is not None)
# Order the clusters as in `categories` so that sorting groups the cells cluster by cluster
obs_tidy['leiden'] = pd.Categorical(obs_tidy['leiden'], categories=categories, ordered=True)
obs_tidy = obs_tidy.sort_values('leiden',ascending=True)
tracksplot_genes = list(obs_tidy.columns[1:])

# One colour per plotted category, looked up by cluster name so that the colours stay attached to
# the right cluster even if `categories` is ordered differently from the stored palette
leiden_colour_lookup = dict(zip(wound_full_merged.obs['leiden'].cat.categories, wound_full_merged.uns['leiden_colors']))
groupby_colors = [leiden_colour_lookup[category] for category in categories]

# Create the trackplot
# Each cluster is summarised by `nbins` averaged expression values per gene
nbins = 10

# obtain the start and end of each category and make
# a list of ranges that will be used to plot a different
# color
cumsum = [0] + list(np.cumsum(obs_tidy['leiden'].value_counts(sort=False)))
x_values = [(x, y) for x, y in zip(cumsum[:-1], cumsum[1:])]

dendro_height = 0

groupby_height = 0.24
# +2 because of dendrogram on top and categories at bottom
num_rows = len(tracksplot_genes) + 2
width = 12
track_height = 0.25

height_ratios = [dendro_height] + [track_height] * len(tracksplot_genes) + [groupby_height]
height = 2*sum(height_ratios)

# After transposing, row 0 holds the cluster labels and rows 1.. hold one gene each
obs_tidy = obs_tidy.T

fig = plt.figure(figsize=(width, height))
axs = gridspec.GridSpec(
    ncols=2,
    nrows=num_rows,
    wspace=1.0 / width,
    hspace=0,
    height_ratios=height_ratios,
    width_ratios=[width, 0.14],
)
axs_list = []
first_ax = None
for idx, var in enumerate(tracksplot_genes):
    ax_idx = idx + 1  # this is because of the dendrogram
    if first_ax is None:
        ax = fig.add_subplot(axs[ax_idx, 0])
        first_ax = ax
    else:
        ax = fig.add_subplot(axs[ax_idx, 0], sharex=first_ax)
    axs_list.append(ax)
    for cat_idx, category in enumerate(categories):
        x_start, x_end = x_values[cat_idx]
        expression_values = np.sort(obs_tidy.iloc[idx + 1, x_start:x_end].to_numpy()) # Get the expression_values
        average_expressions = np.zeros(nbins)

        # Sorted values are cut into `nbins` consecutive chunks of `num` cells;
        # the last chunk also takes the remainder
        num = int(np.floor(np.size(expression_values)/nbins))

        for ave_idx in range(nbins):
            if ave_idx < nbins - 1:
                average_expressions[ave_idx] = np.mean(expression_values[num*(ave_idx):num*(1 +ave_idx)])
            else:
                average_expressions[ave_idx] = np.mean(expression_values[num*(ave_idx):])

        # Every category occupies `nbins` consecutive x positions
        ax.fill_between(
            range(cat_idx*nbins, (cat_idx + 1)*nbins),
            0,
            average_expressions,
            lw=0.1,
            color=groupby_colors[cat_idx],
        )

    # remove the xticks labels except for the last processed plot.
    # Because the plots share the x axis it is redundant and less compact
    # to plot the axis for each plot
    if idx < len(tracksplot_genes) - 1:
        ax.tick_params(labelbottom=False, labeltop=False, bottom=False, top=False)
        ax.set_xlabel('')
    ax.spines['left'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.spines['bottom'].set_visible(False)
    ax.grid(False)
    # Show a single y tick at the (integer-truncated) upper axis limit
    ymin, ymax = ax.get_ylim()
    ymax = int(ymax)
    ax.set_yticks([ymax])
    ax.set_yticklabels([str(ymax)], ha='left', va='top')
    ax.spines['right'].set_position(('axes', 1.01))
    ax.tick_params(
        axis='y',
        labelsize='x-small',
        right=True,
        left=False,
        length=2,
        which='both',
        labelright=True,
        labelleft=False,
        direction='in',
    )
    ax.set_ylabel(var, rotation=0, fontsize='small', ha='right', va='bottom')
    ax.yaxis.set_label_coords(-0.005, 0.1)
ax.set_xlim(0, len(categories)*nbins)
ax.tick_params(axis='x', bottom=False, labelbottom=False)

# the ax to plot the groupby categories is split to add a small space
# between the rest of the plot and the categories
axs2 = gridspec.GridSpecFromSubplotSpec(
    2, 1, subplot_spec=axs[num_rows - 1, 0], height_ratios=[1, 1]
)

groupby_ax = fig.add_subplot(axs2[1])

# Save the results

In [None]:
# Save so we don't lose all the good work from Scanorama.
# The integrated object is written as a gzip-compressed h5ad file into `results_directory`.
init_results_name = 'integratedskindata.h5ad'
wound_full_merged.write(results_directory + init_results_name, compression='gzip')