# Run with scib-pipeline-R4.0 conda environment

In [None]:
# Import dependencies
%matplotlib inline
import os
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
import scib
import anndata

import matplotlib.pyplot as plt
from typing import List

# Print the date and time at which the notebook was run.
import datetime
e = datetime.datetime.now()
print(f"Current date and time = {e}")

# Set the working directory; the relative folder names below are resolved against it.
wdir = "/mnt/da8aa2c4-0136-465b-87a2-d12a59afec55/akurjan/analysis/notebooks"
os.chdir(wdir)

# Folder structure: QC results are read, normalisation results and figures are written.
QC_FOLDERNAME = "foetal/results/QC/"
RESULTS_FOLDERNAME = "foetal/results/Normalisation/"
FIGURES_FOLDERNAME = "foetal/figures/Normalisation/"

# exist_ok=True replaces the separate "exists?" check, so creating the output
# folders is a single idempotent call.
os.makedirs(RESULTS_FOLDERNAME, exist_ok=True)
os.makedirs(FIGURES_FOLDERNAME, exist_ok=True)

# Folder that scanpy saves figures into.
sc.settings.figdir = FIGURES_FOLDERNAME

# Other scanpy settings.
sc.settings.verbosity = 3  # verbosity: errors (0), warnings (1), info (2), hints (3)
sc.logging.print_versions()
sc.set_figure_params(dpi=150, fontsize=10, dpi_save=600)

In [None]:
def savesvg(fname: str, fig, folder: str=FIGURES_FOLDERNAME) -> None:
    """
    Write ``fig`` to ``folder``/``fname`` in SVG (vector) format.

    ``fig`` must provide a ``savefig`` method; ``folder`` defaults to the
    notebook's figure folder.
    """
    destination = os.path.join(folder, fname)
    fig.savefig(destination, format='svg')

In [None]:
# Load the filtered, concatenated AnnData object from the QC results folder
# and display its summary.
qc_h5ad_path = os.path.join(QC_FOLDERNAME, 'adata_concat_filtered.h5ad')
adata = sc.read_h5ad(qc_h5ad_path)
adata

In [None]:
# Remove the 'ambiguous' and 'matrix' layers from the loaded object.
for unused_layer in ('ambiguous', 'matrix'):
    del adata.layers[unused_layer]

In [None]:
# Split adata into separate objects that can be integrated separately.
# Slicing an AnnData returns a view of the parent; the cells below write layers
# and annotations into these subsets, so materialise them as independent
# objects up front instead of relying on an implicit view-to-copy conversion.
is_adult = adata.obs['agegroup'] == 'Adult'
adult_adata = adata[is_adult, :].copy()
dev_adata = adata[~is_adult, :].copy()

# Keyed by name so that every later step can loop over both datasets.
adata_dict = {'adult_adata': adult_adata, 'dev_adata': dev_adata}
adata_dict

# NORMALISATION

## Shifted logarithm normalisation

In [None]:
for key, adata in adata_dict.items():
    # Size-normalise without modifying X; with inplace=False the result is
    # returned, and the normalised matrix is read from its "X" entry.
    normalised = sc.pp.normalize_total(adata, target_sum=None, inplace=False)
    # Shifted logarithm (log1p) of the normalised values, kept as its own layer.
    log1p_values = sc.pp.log1p(normalised["X"], copy=True)
    adata.layers["log1p_norm"] = log1p_values
    print(adata.layers["log1p_norm"][1:10, 1:10])
    adata_dict[key] = adata

In [None]:
# Print the top-left 5x5 corner of the normalised layer of each dataset.
for adata in adata_dict.values():
    log1p_layer = adata.layers['log1p_norm']
    print(log1p_layer[0:5, 0:5])

In [None]:
# For each dataset, compare the per-cell total counts (left) with the per-cell
# sum of the shifted-log normalised layer (right).
for adata in adata_dict.values():
    fig, (ax_total, ax_log1p) = plt.subplots(1, 2, figsize=(10, 5))

    sns.histplot(adata.obs["total_counts"], bins=100, kde=False, ax=ax_total)
    ax_total.set_title("Total counts")

    per_cell_log1p = adata.layers["log1p_norm"].sum(1)
    sns.histplot(per_cell_log1p, bins=100, kde=False, ax=ax_log1p)
    ax_log1p.set_title("Shifted logarithm")

    plt.show()

# FEATURE SELECTION AND SCALING

In [None]:
# Replacement dtype for each unsigned dtype, as (preferred, wider fallback).
# The fallback is used only when the preferred signed integer type cannot hold
# the largest value present; None means there is no fallback.
_FRAME_TARGETS = {
    'uint32': ('float32', None),
    'uint64': ('float64', None),
    'uint16': ('int16', 'int32'),
}
_MATRIX_TARGETS = {
    'uint32': ('int32', 'int64'),
    'uint64': ('int64', None),
    'uint16': ('int16', 'int32'),
}


def _pick_target(values, preferred, fallback):
    """Return ``preferred`` unless it would overflow for ``values``, else ``fallback``."""
    if fallback is None or 0 in values.shape:
        return preferred
    if values.max() <= np.iinfo(preferred).max:
        return preferred
    return fallback


def convert_uint_to_int_single(adata):
    """
    Replace unsigned integer dtypes in an AnnData-like object, in place.

    Conversions (a message is printed for each one performed):

    - ``adata.var`` / ``adata.obs`` columns: uint32 -> float32,
      uint64 -> float64, uint16 -> int16.
    - ``adata.X`` and every entry of ``adata.layers``: uint32 -> int32,
      uint64 -> int64, uint16 -> int16.

    When the largest value does not fit the signed integer target
    (for example a uint16 value above 32767), the next wider signed type is
    used instead (int16 -> int32, int32 -> int64) so that values are never
    wrapped around to negative numbers. uint64 -> int64 has no wider
    fallback.

    Only DataFrame-valued ``var`` / ``obs`` are processed. Assigning a
    converted field back into a NumPy structured array cannot change that
    field's dtype, so such inputs are left untouched.

    Parameters
    ----------
    adata
        Object exposing ``var``, ``obs``, ``X`` and ``layers``.

    Returns
    -------
    None
    """
    # Per-column conversion of the annotation tables.
    for attr in ('var', 'obs'):
        frame = getattr(adata, attr)
        if not isinstance(frame, pd.DataFrame):
            continue
        for col in frame.columns:
            source = str(frame[col].dtype)
            if source not in _FRAME_TARGETS:
                continue
            target = _pick_target(frame[col], *_FRAME_TARGETS[source])
            frame[col] = frame[col].astype(target)
            print(f"Converted {attr}.{col} from {source} to {target}.")

    # Main data matrix.
    if adata.X is not None:
        source = str(adata.X.dtype)
        if source in _MATRIX_TARGETS:
            target = _pick_target(adata.X, *_MATRIX_TARGETS[source])
            adata.X = adata.X.astype(target)
            print(f"Converted X from {source} to {target}.")

    # Layers: iterate over a snapshot because entries are reassigned in the loop.
    for layer_key, layer_val in list(adata.layers.items()):
        source = str(layer_val.dtype)
        if source not in _MATRIX_TARGETS:
            continue
        target = _pick_target(layer_val, *_MATRIX_TARGETS[source])
        adata.layers[layer_key] = layer_val.astype(target)
        print(f"Converted layer {layer_key} from {source} to {target}.")

In [None]:
# Apply the unsigned-dtype conversion to every dataset (modifies them in place).
for adata in adata_dict.values():
    convert_uint_to_int_single(adata)

# Selection based on deviance (needs more work, not used)

In [None]:
# import anndata2ri
# import logging
# import rpy2.rinterface_lib.callbacks as rcb
# import rpy2.robjects as robjects

# rcb.logger.setLevel(logging.ERROR)
# robjects.pandas2ri.activate()
# anndata2ri.activate()

# #Loading the rpy2 extension enables cell/line magic to be used
# %load_ext rpy2.ipython

In [None]:
# %%R
# BiocManager::install("scry")

In [None]:
# for adata in adata_dict.values():
#     print(adata.X[0:5,0:5])

In [None]:
# def split_and_select_deviant(anndata_obj: anndata.AnnData, obs_var: str, target_genes: int) -> anndata.AnnData:
#     """
    
#     """
    
#     # First, make a copy of the input anndata object
#     anndata_copy = anndata_obj.copy()
    
#     n_batches = len(anndata_copy.obs[obs_var].cat.categories)
#     print(f'splitting into {n_batches} batches for deviance calculation')
    
#     # Split the data by the provided observation variable for batch-aware deviance selection
#     groups = anndata_copy.obs[obs_var].unique() 
#     split_data = [anndata_copy[anndata_copy.obs[obs_var] == group] for group in groups]
    
#     # Calculate deviance for each group and store highly deviant genes separately
#     highly_deviant_genes = []
#     df = []
#     gene_list = anndata_copy.var_names
#     highly_deviant_genes_per_batch = {}

#     for i, data in enumerate(split_data):
#         name=data.obs['sampletype'][0]
#         print(f'Counting deviance for {name}')
#         temp_counts = data.X.T
#         robjects.globalenv['temp_counts'] = temp_counts
#         robjects.r('library(SingleCellExperiment)')
#         robjects.r('library(scry)')
#         robjects.r('sce <- SingleCellExperiment(list(counts=temp_counts))')
#         robjects.r('sce <- devianceFeatureSelection(sce, assay="counts")')
#         binomial_deviance = robjects.r("rowData(sce)$binomial_deviance").T
#         idx = binomial_deviance.argsort()[-target_genes:]
#         mask = np.zeros(data.var_names.shape, dtype=bool)
#         mask[idx] = True

#         data.var["highly_deviant"] = mask
#         data.var["binomial_deviance"] = binomial_deviance
#         split_data[i] = data
        
#         highly_deviant_genes.append(data.var_names[mask])
#         highly_deviant_genes_per_batch[name] = data.var_names[mask]
    
        
#     # Merge the split data back together using `anndata.concat`
#     print('Merging data')
#     merged_data = anndata.concat(split_data, join='outer', index_unique=None)
    
#     # Calculate overall deviance across all groups and rank genes accordingly
#     print('Calculating overall deviance and ranking genes')
#     binomial_deviance = merged_data.var["binomial_deviance"].values
#     gene_names = merged_data.var_names.values
#     overall_deviance_df = pd.DataFrame({"gene_name": gene_names, "binomial_deviance": binomial_deviance})
#     overall_deviance_df.sort_values(by="binomial_deviance", ascending=False, inplace=True)

#     # Select the top `target_genes` highly deviant genes
#     top_highly_deviant_genes = overall_deviance_df["gene_name"].iloc[:target_genes].values

#     # Update the "highly_deviant" column in the merged Anndata object
#     merged_data.var["highly_deviant"] = np.isin(gene_names, top_highly_deviant_genes)
    
#     # Assign batch-specific highly deviant genes to the merged data
#     merged_data.var["highly_deviant"] = np.isin(merged_data.var_names, np.concatenate(highly_deviant_genes))
    
#     del anndata_copy
#     del split_data
#     del groups
    
#     return merged_data
    
    # Select deviances of genes that are highly variable in all batches except one
    #all_deviances = np.concatenate(highly_deviant_genes)
    #highly_variable_deviances = all_deviances[
    #    np.sum(all_deviances.mask, axis=0) == n_batches - 1
    #]
    #highly_variable_deviances.sort_values(ascending=False, inplace=True)
    
    #return highly_variable_deviances

In [None]:
# def split_and_select_deviant(anndata_obj: anndata.AnnData, obs_var: str, target_genes: int) -> anndata.AnnData:
#     """
    
#     """
#     anndata_copy = anndata_obj.copy()
    
#     n_batches = len(anndata_copy.obs[obs_var].cat.categories)
#     print(f'splitting into {n_batches} batches for deviance calculation')
    
#     # Split the data by the provided observation variable for batch-aware deviance selection
#     groups = anndata_copy.obs[obs_var].unique() 
#     split_data = [anndata_copy[anndata_copy.obs[obs_var] == group] for group in groups]
    
#     # Calculate deviance for each group and store highly deviant genes separately
#     df = []
#     gene_list = anndata_copy.var_names
#     highly_deviant_genes_per_batch = []
#     highly_variable_nbatches = np.zeros(len(gene_list), dtype=int)

#     for i, data in enumerate(split_data):
#         name=data.obs['sampletype'][0]
#         print(f'Counting deviance for {name}')
#         temp_counts = data.X.T
#         robjects.globalenv['temp_counts'] = temp_counts
#         robjects.r('library(SingleCellExperiment)')
#         robjects.r('library(scry)')
#         robjects.r('sce <- SingleCellExperiment(list(counts=temp_counts))')
#         robjects.r('sce <- devianceFeatureSelection(sce, assay="counts")')
#         binomial_deviance = robjects.r("rowData(sce)$binomial_deviance").T
#         idx = binomial_deviance.argsort()[-target_genes:]
#         mask = np.zeros(data.var_names.shape, dtype=bool)
#         mask[idx] = True

#         data.var["highly_deviant"] = mask
#         data.var["binomial_deviance"] = binomial_deviance
#         split_data[i] = data
        
#         highly_deviant_genes_per_batch.append(data.var_names[mask])
#         highly_variable_nbatches[mask] += 1
    
        
#     # Merge the split data back together using `anndata.concat`
#     print('Merging data')
#     merged_data = anndata.concat(split_data, join='outer', index_unique=None)
    
#     # Create a new variable 'highly_variable_nbatches' in the merged Anndata object
#     merged_data.var["highly_variable_nbatches"] = highly_variable_nbatches
        
#     del anndata_copy
#     del split_data
#     del groups
    
#     # Create 'nbatch1_deviances' variable to retain highly deviant genes present in all batches
#     nbatch1_deviances = merged_data.var["binomial_deviance"][
#         merged_data.var["highly_variable_nbatches"] >= len(merged_data.obs[obs_var].cat.categories) - 1
#     ]

#     # Sort the deviances in descending order
#     nbatch1_deviances.sort_values(ascending=False, inplace=True)

#     if len(nbatch1_deviances) > target_genes:
#         hvg = nbatch1_deviances.index[:target_genes]

#     else:
#         enough = False
#         print(f"Using {len(nbatch1_deviances)} HVGs from full intersect set")
#         hvg = nbatch1_deviances.index[:]
#         not_n_batches = 1

#         while not enough:
#             target_genes_diff = target_genes - len(hvg)

#             tmp_dispersions = merged_data.var["binomial_deviance"][
#                 merged_data.var.highly_variable_nbatches == (n_batches - not_n_batches)
#             ]

#             if len(tmp_dispersions) < target_genes_diff:
#                 print(
#                     f"Using {len(tmp_dispersions)} HVGs from n_batch-{not_n_batches} set"
#                 )
#                 hvg = hvg.append(tmp_dispersions.index)
#                 not_n_batches += 1

#             else:
#                 print(
#                     f"Using {target_genes_diff} HVGs from n_batch-{not_n_batches} set"
#                 )
#                 tmp_dispersions.sort_values(ascending=False, inplace=True)
#                 hvg = hvg.append(tmp_dispersions.index[:target_genes_diff])
#                 enough = True

#     print(f"Using {len(hvg)} HVGs")

#     return merged_data

In [None]:
# deviant_adata = {}
# for i, adata in adata_dict.items():
#     deviant_adata[i] = select_deviant(adata, 4000)
# deviant_adata

In [None]:
#assert adata_dict.keys() == deviant_adata.keys()

In [None]:
#for i in adata_dict.keys():
#    adata_dict[i].var['highly_deviant'] = deviant_adata[i].var['highly_deviant']

### scIB:
Batch-aware highly variable gene selection

Method to select HVGs based on the mean dispersions of genes that are highly variable in all batches, using the top target_genes per batch ranked by average normalised dispersion. If target_genes has still not been reached, HVGs found in all but one batch are used to fill up. This continues until HVGs found in only a single batch are considered.

Parameters:
- adata – anndata object
- batch – adata.obs column
- target_genes – maximum number of genes (intersection reduces the number of genes)
- flavor – parameter for scanpy.pp.highly_variable_genes
- n_bins – parameter for scanpy.pp.highly_variable_genes
- adataOut – whether to return an anndata object or a list of highly variable genes

In [None]:
# Make the shifted-log normalised layer the working matrix (X) of each dataset.
# NOTE(review): this overwrites whatever X held before; confirm that the
# previous X (presumably raw counts) is kept elsewhere if it is needed later.
for key, adata in adata_dict.items():
    adata.X = adata.layers['log1p_norm'].copy()
    print(adata.X[0:5, 0:5])

In [None]:
import scib

# Settings shared by both datasets.
hvg_settings = dict(
    batch_key="sampletype",
    target_genes=3500,
    flavor='cell_ranger',
    n_bins=20,
    adataOut=True,
)

for key, adata in adata_dict.items():
    # The return value is discarded; later cells read
    # adata.var["highly_variable"] and adata.var["highly_variable_nbatches"].
    # NOTE(review): this relies on hvg_batch annotating `adata` in place when
    # adataOut=True - confirm against the installed scib version.
    scib.preprocessing.hvg_batch(adata, **hvg_settings)

    sc.pl.highly_variable_genes(adata)

In [None]:
# For each dataset, tabulate how many genes carry each value of
# "highly_variable_nbatches"; results are keyed by dataset position.
n_batches = {}
for i, (name, adata) in enumerate(adata_dict.items()):
    nbatch_counts = adata.var["highly_variable_nbatches"].value_counts()
    n_batches[i] = nbatch_counts
    print(nbatch_counts)

In [None]:
# for i, adata in adata_dict.items():
#     ax = sns.scatterplot(
#         data=adata.var, x="means", y="dispersions", hue="highly_deviant", s=5
#     )
#     ax.set_xlim(None, 1.5)
#     ax.set_ylim(None, 3)
#     plt.show()

Most genes are not highly variable. By selecting the top 3500 genes, we capture HVGs that are variable in 2 or more batches.

In [None]:
#del deviant_adata

In [None]:
# Sanity check: no variable (gene) name may occur twice in any dataset.
for key, adata in adata_dict.items():
    assert adata.var_names.is_unique

In [None]:
# Query BioMart once: the annotation table does not depend on the dataset, so
# there is no need to repeat the network request inside the loop. `annot`
# stays indexed by Ensembl gene ID.
annot = sc.queries.biomart_annotations(
    "hsapiens",
    ["ensembl_gene_id", "external_gene_name", "start_position", "end_position", "chromosome_name"],
).set_index("ensembl_gene_id")

for i, adata in adata_dict.items():
    # Index-aligned assignment: var is still indexed by Ensembl gene ID here.
    adata.var[annot.columns] = annot

    adata.var.rename(columns={"external_gene_name": "Gene"}, inplace=True)
    # Keep the Ensembl ID as a column before the index is replaced.
    adata.var['ensembl_gene_id'] = adata.var.index
    # Genes without a symbol fall back to their Ensembl ID.
    adata.var['Gene'] = adata.var['Gene'].fillna(adata.var['ensembl_gene_id'])
    adata.obs.index.name = 'CellID'
    # Re-index var by gene symbol. Plain values are assigned so that the index
    # does not inherit the name "Gene": once var_names_make_unique() renames
    # duplicates, an index called "Gene" would disagree with the "Gene" column.
    # NOTE(review): anndata is expected to reject such a mismatch when writing
    # the h5ad file - confirm against the installed version.
    adata.var.index = adata.var["Gene"].to_numpy()
    adata.var_names_make_unique()
    adata_dict[i] = adata

In [None]:
# Display the variable (gene) annotation table of the 'dev_adata' dataset.
adata_dict['dev_adata'].var

In [None]:
# Display the variable (gene) annotation table of the 'adult_adata' dataset.
adata_dict['adult_adata'].var

In [None]:
for i, adata in adata_dict.items():
    print(f"Before filtering: {adata.n_vars} genes")
    # Only subset when MALAT1 is actually present.
    if 'MALAT1' in adata.var_names:
        # Boolean mask that is True for every gene except MALAT1.
        keep_genes = adata.var_names != 'MALAT1'
        # .copy() stores an independent object rather than a view of the
        # unfiltered data; later cells write annotations into this object.
        adata = adata[:, keep_genes].copy()
        adata_dict[i] = adata

    print(f"After MALAT1 filtering: {adata.n_vars} genes")

In [None]:
for i, adata in adata_dict.items():
    # mitochondrial genes
    adata.var["mt"] = adata.var_names.str.startswith("MT-")
    # ribosomal genes
    adata.var["ribo"] = adata.var_names.str.startswith(("RPS", "RPL"))
    # hemoglobin genes.
    # adata.var["hb"] = adata.var_names.str.contains(("^HB[^(P)]"))

    # sc.pp.calculate_qc_metrics(adata, qc_vars=["mt", "ribo", "hb"], inplace=True, percent_top=[20], log1p=False)
    # Filter out mitochondrial and ribosomal genes
    #print(f"Before filtering: {adata.n_vars} genes")
    #mt_genes = adata.var_names[adata.var['mt']]  # list of mitochondrial genes
    #ribo_genes = adata.var_names[adata.var['ribo']]  # list of ribosomal genes
    #genes_to_remove = np.concatenate([mt_genes, ribo_genes])
    #adata = adata[:, ~adata.var_names.isin(genes_to_remove)]
    #print(f"After filtering: {adata.n_vars} genes")

    # Per-cell totals computed from X. np.asarray(...).ravel() yields a plain
    # 1-D vector whether X is dense or sparse (sparse row sums come back as an
    # (n, 1) matrix).
    # NOTE(review): when the notebook is run in order, X holds the
    # 'log1p_norm' layer at this point, so 'n_counts' is a sum of normalised
    # values rather than of raw counts - confirm this is intended.
    adata.obs['n_counts'] = np.asarray(adata.X.sum(axis=1)).ravel()
    adata.obs['n_genes'] = np.asarray((adata.X > 0).sum(axis=1)).ravel()

    adata_dict[i] = adata

# SAMPLE SEX DETERMINATION

In [None]:
for i, adata in adata_dict.items():
    # Exact-name lookup; a prefix match could select several columns.
    xist_mask = np.asarray(adata.var_names == 'XIST')
    if not xist_mask.any():
        # No XIST in this dataset: nothing to compute or plot.
        continue

    # Select Y-linked genes through the per-gene 'chromosome_name' column that
    # the annotation cell added to var. The previous lookup intersected
    # var_names with the index of `annot`, but var is re-indexed by gene symbol
    # in the annotation cell while `annot` is indexed by Ensembl gene ID, so
    # the two could not match.
    chry_mask = (adata.var['chromosome_name'] == 'Y').to_numpy()

    # Flattened row sums work for both dense and sparse X.
    chry_total = np.asarray(adata[:, chry_mask].X.sum(axis=1)).ravel()
    cell_total = np.asarray(adata.X.sum(axis=1)).ravel()
    adata.obs['percent_chrY'] = chry_total / cell_total * 100

    # color inputs must be from either .obs or .var, so add XIST expression to obs.
    xist_values = adata[:, xist_mask].X
    if hasattr(xist_values, 'toarray'):
        xist_values = xist_values.toarray()
    adata.obs["XIST-counts"] = np.asarray(xist_values).ravel()

    sc.pl.violin(adata, ["XIST-counts", "percent_chrY"], jitter=0.4, groupby = 'sample', rotation= 90, save=f'{i}_XIST.svg')
    adata_dict[i] = adata
    

In [None]:
for i, adata in adata_dict.items():
    # Median XIST-counts and percent_chrY per sample. The columns are selected
    # with a list: selecting several columns with a bare tuple is no longer
    # accepted by recent pandas versions.
    sample_medians = adata.obs.groupby('sample')[['XIST-counts', 'percent_chrY']].median()

    # The call is made on the median XIST value only (percent_chrY is computed
    # above but not used in the criteria). A median of exactly 0.5, or a
    # missing median, leaves the sample as 'unknown'.
    female_criteria = (sample_medians['XIST-counts'] > 0.5)
    male_criteria = (sample_medians['XIST-counts'] < 0.5)

    # Create the 'sex' column; every cell starts as 'unknown'.
    adata.obs['sex'] = 'unknown'

    # Label all cells of a sample according to that sample's call.
    for sample in sample_medians.index:
        in_sample = adata.obs['sample'] == sample
        if female_criteria[sample]:
            adata.obs.loc[in_sample, 'sex'] = 'female'
        elif male_criteria[sample]:
            adata.obs.loc[in_sample, 'sex'] = 'male'

    # Names of the samples assigned to each sex.
    female_samples = adata.obs.loc[adata.obs['sex'] == 'female', 'sample'].unique()
    male_samples = adata.obs.loc[adata.obs['sex'] == 'male', 'sample'].unique()

    adata_dict[i] = adata

    # map(str, ...) keeps the join working when sample labels are not strings.
    print(f"Female samples: {', '.join(map(str, female_samples))}")
    print(f"Male samples: {', '.join(map(str, male_samples))}")

# CELL CYCLE PHASE DETERMINATION

In [None]:
for key, adata in adata_dict.items():
    # Number of genes before filtering.
    n_genes_before = adata.shape[1]
    print(n_genes_before)
    # Drop genes below min_counts=5, then cells with fewer than 200 genes.
    # NOTE(review): both filters act on X, which holds the 'log1p_norm' layer
    # when the notebook is run in order - confirm that a threshold of 5 on
    # normalised values is intended.
    sc.pp.filter_genes(adata, min_counts=5, inplace=True)
    sc.pp.filter_cells(adata, min_genes=200)
    # Number of genes after filtering.
    print(adata.shape[1])
    adata_dict[key] = adata

In [None]:
# Store a copy of each dataset, as it stands now, in its .raw attribute.
for key, adata in adata_dict.items():
    snapshot = adata.copy()
    adata.raw = snapshot
    adata_dict[key] = adata

In [None]:
for key, adata in adata_dict.items():
    # Make the grouping columns categorical.
    for column in ('libbatch', 'sample'):
        adata.obs[column] = adata.obs[column].astype('category')

    # Score cell cycle phase, then plot the S and G2M scores per sample.
    scib.preprocessing.score_cell_cycle(adata, organism='human')
    sc.pl.violin(
        adata,
        ['S_score', 'G2M_score'],
        jitter=0.4,
        groupby='sample',
        rotation=90,
        save=f'{key}_cell_cycle.svg',
    )
    adata_dict[key] = adata


# DIMENSIONALITY REDUCTION

In [None]:
def split_and_scale(anndata_obj: anndata.AnnData, obs_var: str) -> anndata.AnnData:
    """
    Scale an AnnData object separately within each group of ``obs_var``.

    The object is split by the values of ``obs_var``, each subset is scaled
    with ``sc.pp.scale`` and the subsets are merged again with
    ``anndata.concat``. The input object is not modified.

    Parameters:
    -----------
    anndata_obj: anndata.AnnData
        Annotated data matrix with normalized, log-transformed counts.
    obs_var: str
        Observation variable to split the data on.

    Returns:
    --------
    anndata.AnnData
        The merged, per-group scaled object. When the observation names of the
        input are unique, the rows are returned in the same order as in
        ``anndata_obj``, so that ``.X`` can be stored positionally as a layer
        of the input. Otherwise the rows stay grouped by ``obs_var`` and a
        warning is issued.

    Raises:
    -------
    ValueError
        If ``X`` does not look log-transformed (minimum >= 1) or normalised
        (maximum > 10). Both checks are heuristics on the value range.
    """
    import warnings

    # Check if anndata.X is log-transformed and normalised
    if np.min(anndata_obj.X) >= 1:
        raise ValueError("Anndata object X is not log-transformed.")
    if np.max(anndata_obj.X) > 10:
        raise ValueError("Anndata object X is not normalised.")

    # Split by the observation variable. Each subset is copied because
    # sc.pp.scale writes into it; this also leaves the input untouched.
    group_labels = anndata_obj.obs[obs_var]
    split_data = [
        anndata_obj[group_labels == group].copy() for group in group_labels.unique()
    ]

    # Scale each subset independently.
    for subset in split_data:
        sc.pp.scale(subset)

    # Merge the scaled subsets back together.
    merged_data = anndata.concat(split_data, join='outer', index_unique=None)

    # Concatenation orders cells group by group. Restore the input order so
    # that row k of the result is cell k of the input.
    if anndata_obj.obs_names.is_unique:
        present = anndata_obj.obs_names.isin(merged_data.obs_names)
        merged_data = merged_data[anndata_obj.obs_names[present], :].copy()
    else:
        warnings.warn(
            "obs_names are not unique; scaled cells are returned grouped by "
            f"'{obs_var}' instead of in the input order."
        )

    return merged_data

In [None]:
# Scale every dataset separately within each sample and keep the results by name.
scaled_adata = {}
for key, adata in adata_dict.items():
    scaled = split_and_scale(adata, 'sample')
    scaled_adata[key] = scaled
    print(scaled.X[1:10, 1:10])

In [None]:
assert adata_dict.keys() == scaled_adata.keys()
for i, adata in adata_dict.items():
    scaled = scaled_adata[i]
    # A layer is stored positionally, so line the scaled rows up with this
    # dataset's cells by name first; the per-sample split/merge does not
    # guarantee the original cell order. Name-based indexing needs unique names.
    if adata.obs_names.is_unique and scaled.obs_names.is_unique:
        scaled = scaled[adata.obs_names, :]
    adata.layers['scaled'] = scaled.X.copy()

In [None]:
# Print the top-left 5x5 corner of the scaled layer of each dataset.
for adata in adata_dict.values():
    scaled_layer = adata.layers['scaled']
    print(scaled_layer[0:5, 0:5])

In [None]:
# The scaled matrices now live in the 'scaled' layers; drop the temporary dict.
del scaled_adata

In [None]:
# PCA on the scaled values of the highly variable genes; the 50 components are
# stored under obsm["X_pca"].
for key, adata in adata_dict.items():
    hvg_scaled = adata[:, adata.var.highly_variable].layers["scaled"]
    adata.obsm["X_pca"] = sc.pp.pca(hvg_scaled, n_comps=50, svd_solver="arpack")
    adata_dict[key] = adata

In [None]:
def plot_pca(anndata, parameters: list, components: list, filename: str):
    """
    Draw one PCA scatter plot per entry of ``parameters``, stacked vertically,
    save the figure as SVG and show it.

    Parameters
    ----------
    anndata
        Object passed to ``sc.pl.pca`` (note: the name shadows the ``anndata``
        module inside this function).
    parameters : list
        Keys passed one at a time as ``color`` to ``sc.pl.pca``; each key is
        also used as the panel title.
    components : list
        Forwarded to ``sc.pl.pca`` as ``components``.
    filename : str
        File name handed to ``savesvg``.
    """
    n_plots = len(parameters)
    # squeeze=False always returns a 2-D array of Axes, so a single parameter
    # works as well (the default returns a bare Axes that cannot be indexed).
    fig, axs = plt.subplots(n_plots, 1, figsize=(4, 3*n_plots), squeeze=False)
    for ax, param in zip(axs.ravel(), parameters):
        sc.pl.pca(anndata, color=param, ax=ax, show=False, components=components, frameon=False)
        ax.set_title(param)
    plt.tight_layout()
    savesvg(filename, fig)
    plt.show()

In [None]:
# PC1 vs PC2 of the adult dataset, coloured by each metadata column in turn.
adult_pca_colours = ['sampletype', 'age', 'libbatch', 'sample', 'type', 'phase', 'sex']
plot_pca(
    adata_dict['adult_adata'],
    adult_pca_colours,
    components=['1,2'],
    filename='adult_PC1vs2_plots.svg',
)

In [None]:
def plot_pca(anndata, parameters: list, components: list, filename: str):
    """
    Draw one PCA scatter plot per entry of ``parameters``, stacked vertically,
    save the figure as SVG and show it.

    This redefinition uses larger panels (7 x 4 per plot) than the previous
    definition of the same name.

    Parameters
    ----------
    anndata
        Object passed to ``sc.pl.pca`` (note: the name shadows the ``anndata``
        module inside this function).
    parameters : list
        Keys passed one at a time as ``color`` to ``sc.pl.pca``; each key is
        also used as the panel title.
    components : list
        Forwarded to ``sc.pl.pca`` as ``components``.
    filename : str
        File name handed to ``savesvg``.
    """
    n_plots = len(parameters)
    # squeeze=False always returns a 2-D array of Axes, so a single parameter
    # works as well (the default returns a bare Axes that cannot be indexed).
    fig, axs = plt.subplots(n_plots, 1, figsize=(7, 4*n_plots), squeeze=False)
    for ax, param in zip(axs.ravel(), parameters):
        sc.pl.pca(anndata, color=param, ax=ax, show=False, components=components, frameon=False)
        ax.set_title(param)
    plt.tight_layout()
    savesvg(filename, fig)
    plt.show()

In [None]:
# PC1 vs PC2 of the dev dataset, coloured by each metadata column in turn.
dev_pca_colours = ['sampletype', 'age', 'libbatch', 'sample', 'type', 'phase', 'sex']
plot_pca(
    adata_dict['dev_adata'],
    dev_pca_colours,
    components=['1,2'],
    filename='dev_PC1vs2_plots.svg',
)

In [None]:
# PC3 vs PC4, coloured by each metadata column in turn.
# NOTE(review): `adata` is not assigned in this cell; it is whatever object an
# earlier cell left bound to that name (the loops above reuse `adata` as their
# loop variable) - confirm which dataset is intended.
pc34_colours = ['sampletype', 'age', 'libbatch', 'sample', 'type', 'phase', 'sex']
plot_pca(
    adata,
    pc34_colours,
    components=['3,4'],
    filename='PC3vs4_plots.svg',
)

In [None]:
# Build the neighbourhood graph (30 neighbours, 15 PCs) and compute the UMAP
# embedding for each dataset.
for adata in adata_dict.values():
    sc.pp.neighbors(adata, n_neighbors=30, n_pcs=15)
    sc.tl.umap(adata)

In [None]:
def plot_umaps(anndata, parameters: list, filename: str):
    """
    Draw one UMAP plot per entry of ``parameters``, stacked vertically, save
    the figure as SVG and show it.

    Parameters
    ----------
    anndata
        Object passed to ``sc.pl.umap`` (note: the name shadows the
        ``anndata`` module inside this function).
    parameters : list
        Keys passed one at a time as ``color`` to ``sc.pl.umap``; each key is
        also used as the panel title.
    filename : str
        File name handed to ``savesvg``.
    """
    n_plots = len(parameters)
    # squeeze=False always returns a 2-D array of Axes, so a single parameter
    # works as well (the default returns a bare Axes that cannot be indexed).
    fig, axs = plt.subplots(n_plots, 1, figsize=(4, 3*n_plots), squeeze=False)
    for ax, param in zip(axs.ravel(), parameters):
        sc.pl.umap(anndata, color=param, ax=ax, show=False, frameon=False)
        ax.set_title(param)
    plt.tight_layout()
    savesvg(filename, fig)
    plt.show()

In [None]:
# UMAPs of the adult dataset, coloured by each metadata column in turn.
adult_umap_colours = ['sampletype', 'age', 'agegroup', 'libbatch', 'sample', 'type', 'phase', 'sex']
plot_umaps(
    adata_dict['adult_adata'],
    adult_umap_colours,
    filename='adult_UMAP_plots.svg',
)

In [None]:
def plot_umaps(anndata, parameters: list, filename: str):
    """
    Draw one UMAP plot per entry of ``parameters``, stacked vertically, save
    the figure as SVG and show it.

    This redefinition uses larger panels (7 x 4 per plot) than the previous
    definition of the same name.

    Parameters
    ----------
    anndata
        Object passed to ``sc.pl.umap`` (note: the name shadows the
        ``anndata`` module inside this function).
    parameters : list
        Keys passed one at a time as ``color`` to ``sc.pl.umap``; each key is
        also used as the panel title.
    filename : str
        File name handed to ``savesvg``.
    """
    n_plots = len(parameters)
    # squeeze=False always returns a 2-D array of Axes, so a single parameter
    # works as well (the default returns a bare Axes that cannot be indexed).
    fig, axs = plt.subplots(n_plots, 1, figsize=(7, 4*n_plots), squeeze=False)
    for ax, param in zip(axs.ravel(), parameters):
        sc.pl.umap(anndata, color=param, ax=ax, show=False, frameon=False)
        ax.set_title(param)
    plt.tight_layout()
    savesvg(filename, fig)
    plt.show()

In [None]:
# UMAPs of the dev dataset, coloured by each metadata column in turn.
dev_umap_colours = ['sampletype', 'age', 'agegroup', 'libbatch', 'sample', 'type', 'phase', 'sex']
plot_umaps(
    adata_dict['dev_adata'],
    dev_umap_colours,
    filename='dev_UMAP_plots.svg',
)

In [None]:
# UMAPs coloured by each metadata column in turn.
# NOTE(review): `adata` is not assigned in this cell; it is whatever object an
# earlier cell left bound to that name (the loops above reuse `adata` as their
# loop variable) - confirm which dataset is intended.
umap_colours = ['sampletype', 'age', 'agegroup', 'libbatch', 'sample', 'type', 'phase', 'sex']
plot_umaps(
    adata,
    umap_colours,
    filename='UMAP_plots.svg',
)

In [None]:
def plot_umaps2(anndata, parameters: list, filename: str):
    """
    Draw one UMAP plot per entry of ``parameters`` in a fixed-size (4 x 10)
    figure, save it as SVG and show it.

    Parameters
    ----------
    anndata
        Object passed to ``sc.pl.umap`` (note: the name shadows the
        ``anndata`` module inside this function).
    parameters : list
        Keys passed one at a time as ``color`` to ``sc.pl.umap``; each key is
        also used as the panel title.
    filename : str
        File name handed to ``savesvg``.
    """
    n_plots = len(parameters)
    # squeeze=False always returns a 2-D array of Axes, so a single parameter
    # works as well (the default returns a bare Axes that cannot be indexed).
    fig, axs = plt.subplots(n_plots, 1, figsize=(4, 10), squeeze=False)
    for ax, param in zip(axs.ravel(), parameters):
        sc.pl.umap(anndata, color=param, ax=ax, show=False, frameon=False)
        ax.set_title(param)
    plt.tight_layout()
    savesvg(filename, fig)
    plt.show()

In [None]:
# UMAPs of the adult dataset coloured by per-cell QC columns.
adult_qc_colours = ["n_counts", "n_genes", "pct_counts_mt"]
plot_umaps2(
    adata_dict['adult_adata'],
    adult_qc_colours,
    filename='adult_UMAPparameter_plots.svg',
)

In [None]:
# UMAPs of the dev dataset coloured by per-cell QC columns.
dev_qc_colours = ["n_counts", "n_genes", "pct_counts_mt"]
plot_umaps2(
    adata_dict['dev_adata'],
    dev_qc_colours,
    filename='dev_UMAPparameter_plots.svg',
)

In [None]:
# UMAPs coloured by per-cell QC columns.
# NOTE(review): `adata` is not assigned in this cell; it is whatever object an
# earlier cell left bound to that name (the loops above reuse `adata` as their
# loop variable) - confirm which dataset is intended.
qc_colours = ["n_counts", "n_genes", "pct_counts_mt"]
plot_umaps2(
    adata,
    qc_colours,
    filename='UMAPparameter_plots.svg',
)

In [None]:
# Leiden clustering at resolution 0.5; labels are stored under 'leiden_05'.
leiden_settings = dict(resolution=0.5, key_added='leiden_05')
for adata in adata_dict.values():
    sc.tl.leiden(adata, **leiden_settings)

In [None]:
# Show the UMAP of each dataset coloured by its 'leiden_05' labels.
cluster_key = 'leiden_05'
for adata in adata_dict.values():
    sc.pl.umap(adata, color=cluster_key)

In [None]:
# Print the name and summary of every dataset.
for key, adata in adata_dict.items():
    print(key, adata)

In [None]:
# Write each dataset to '<name>_normalized.h5ad' in the results folder.
for key, adata in adata_dict.items():
    output_path = os.path.join(RESULTS_FOLDERNAME, f'{key}_normalized.h5ad')
    adata.write(output_path)

# EXTRA (do not run)

In [None]:
# Make the variable names of both datasets unique.
for extra_adata in (dev_adata, adult_adata):
    extra_adata.var_names_make_unique()

In [None]:
# Remove the 'Gene' column from the variable annotation of dev_adata.
dev_adata.var = dev_adata.var.drop(['Gene'], axis=1)

In [None]:
# Reload both normalised datasets from the results folder.
dev_h5ad_path = os.path.join(RESULTS_FOLDERNAME, 'dev_adata_normalized.h5ad')
adult_h5ad_path = os.path.join(RESULTS_FOLDERNAME, 'adult_adata_normalized.h5ad')
dev_adata = sc.read_h5ad(dev_h5ad_path)
adult_adata = sc.read_h5ad(adult_h5ad_path)

In [None]:
# Display the variable annotation table of dev_adata.
dev_adata.var

In [None]:
# Name the variable index 'Genes', remove the 'Gene' column, and display the result.
adult_adata.var.index.name='Genes'
adult_adata.var = adult_adata.var.drop(['Gene'], axis=1)
adult_adata.var

In [None]:
# Write both datasets back to their '*_normalized.h5ad' files in the results folder.
dev_output_path = os.path.join(RESULTS_FOLDERNAME, 'dev_adata_normalized.h5ad')
adult_output_path = os.path.join(RESULTS_FOLDERNAME, 'adult_adata_normalized.h5ad')
dev_adata.write(dev_output_path)
adult_adata.write(adult_output_path)

This was run for scIB integration benchmarking.