In [None]:
# Import dependencies
%matplotlib inline
import os
import numpy as np
import scanpy as sc
import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
import anndata as ad
import warnings

# Ignore all warnings
warnings.simplefilter("ignore")

matplotlib.rcParams['font.family'] = 'sans-serif'

# Initialize random seed
import random
random.seed(111)

# Print date and time:
import datetime
e = datetime.datetime.now()
print ("Current date and time = %s" % e)

# wdir = "/ceph/project/tendonhca/akurjan/analysis/"
wdir = "/mnt/da8aa2c4-0136-465b-87a2-d12a59afec55/akurjan/analysis/notebooks/"
os.chdir( wdir )

# folder structures
INPUT_FOLDERNAME = "developmental/scVI/results/"
RESULTS_FOLDERNAME = "adult/integration/results/"
FIGURES_FOLDERNAME = "adult/integration/figures/"

if not os.path.exists(RESULTS_FOLDERNAME):
    os.makedirs(RESULTS_FOLDERNAME)
if not os.path.exists(FIGURES_FOLDERNAME):
    os.makedirs(FIGURES_FOLDERNAME)

# Set folder for saving figures into
sc.settings.figdir = FIGURES_FOLDERNAME

def savesvg(fname: str, fig, folder: str=FIGURES_FOLDERNAME) -> None:
    """
    Write ``fig`` to ``folder/fname`` in SVG (vector) format.

    Parameters
    ----------
    fname : output file name, including the ``.svg`` extension.
    fig : any object with a matplotlib-style ``savefig`` method.
    folder : destination directory; defaults to FIGURES_FOLDERNAME.
    """
    target_path = os.path.join(folder, fname)
    fig.savefig(target_path, format='svg')

def plot_umaps(anndata, parameters: list, filename: str, width: float = 10):
    """
    Draw one UMAP panel per entry of ``parameters`` (stacked vertically) and save as SVG.

    Parameters
    ----------
    anndata : AnnData with a computed UMAP embedding.
    parameters : values passed one at a time to ``sc.pl.umap(color=...)``.
    filename : SVG file name, written via ``savesvg``.
    width : figure width in inches; each panel is 4 inches tall.

    Raises
    ------
    ValueError : if ``parameters`` is empty.
    """
    n_plots = len(parameters)
    if n_plots == 0:
        raise ValueError("plot_umaps needs at least one parameter to colour by")
    # squeeze=False keeps `axs` 2-D, so a single panel (n_plots == 1) can still be indexed
    fig, axs = plt.subplots(n_plots, 1, figsize=(width, 4 * n_plots), squeeze=False)
    for ax, param in zip(axs[:, 0], parameters):
        sc.pl.umap(anndata, color=param, ax=ax, show=False, frameon=False, s=2)
        ax.set_title(param)
    fig.tight_layout()
    savesvg(filename, fig)
    plt.show()
    
# Set other settings
sc.settings.verbosity = 3 # verbosity: errors (0), warnings (1), info (2), hints (3)
# Record package versions in the notebook output for provenance
sc.logging.print_versions()
# dpi for inline display; dpi_save for figures written through scanpy's `save=` argument
sc.set_figure_params(dpi=150, fontsize=10, dpi_save=600)

# Loading and preparing developmental data

In [None]:
# Load the developmental dataset and make gene names unique strings.
adata = sc.read_h5ad(os.path.join(INPUT_FOLDERNAME, 'concat.h5ad'))
adata.var_names = adata.var_names.astype(str)
adata.var_names_make_unique()
adata

In [None]:
# Drop per-cell QC / model columns not needed for the adult-developmental integration.
# NOTE: DataFrame.drop raises KeyError if any listed column is absent.
adata.obs.drop(columns=['latent_RT_efficiency', 'latent_cell_probability', 'latent_scale',
                       'initial_size_unspliced', 'initial_size_spliced', 'initial_size',
       'log1p_n_genes_by_counts', 
       'log1p_total_counts', 'pct_counts_in_top_20_genes', 'total_counts_mt',
       'log1p_total_counts_mt', 'pct_counts_mt', 'total_counts_ribo',
       'log1p_total_counts_ribo', 'pct_counts_ribo', 'total_counts_hb',
       'log1p_total_counts_hb', 'pct_counts_hb', 'outlier', 'mt_outlier', 'samplename',
       'n_counts', 'n_genes', 'C_scANVI_orig', 'tendon_spmarker_score', 'sample_stage', 'norm_sample_stage', 
       'hospital_id', 'tissue', 'agegroup'
                       ], inplace=True)

In [None]:
# Rename the HiSeq 4000 library batch to "Embryonic" and show cells per batch.
adata.obs.libbatch = adata.obs.libbatch.str.replace("Illumina-C HiSeq 4000 Paired end sequencing_3' v2", "Embryonic")
adata.obs.libbatch = adata.obs.libbatch.astype('str').astype('category')
adata.obs.libbatch.value_counts()

In [None]:
# group = "<libbatch>_<type>", built before 'Quad/Pat' is renamed to 'Quad' on the next line.
# NOTE(review): 'group' is overwritten further down with Embryonic/Foetal labels.
adata.obs['group'] = adata.obs['libbatch'].astype('str') + '_' + adata.obs['type'].astype('str')
adata.obs['group'] = adata.obs['group'].astype('category')
adata.obs['type'] = adata.obs['type'].str.replace('Quad/Pat', 'Quad')

In [None]:
# List the age labels present (used to write the ordered category list in the next cell).
adata.obs['age'] = adata.obs['age'].astype("category")
list(adata.obs['age'].cat.categories)

In [None]:
adata.obs['age'] = pd.Categorical(adata.obs['age'], categories=['6.5w', '7.2w', '8.4w', '9.0w', '9.3w', '12w', '17w', '20w'], ordered=True)
adata = adata[adata.obs['age'].argsort()]
adata.obs 

In [None]:
adata.raw = adata

# Loading and preparing adult tendon data

In [None]:
# Load the annotated adult quadriceps dataset.
adataquad = sc.read_h5ad('/mnt/da8aa2c4-0136-465b-87a2-d12a59afec55/akurjan/analysis/files/AdultData/20231130_quads_int_labelled.h5ad')
adataquad.var_names_make_unique()
adataquad

In [None]:
# Inspect X (and its maximum below) to judge whether it holds raw counts or normalised values.
print(adataquad.X[0:20,0:20])

In [None]:
adataquad.X.max()

In [None]:
adataquad.obs

In [None]:
sc.pl.umap(adataquad, color=['age', 'cluster_id'], frameon=False)

In [None]:
sc.pl.umap(adataquad, color=['cluster_id'], legend_loc='on data', 
           legend_fontsize=5, frameon=False)

In [None]:
adataquad.obs['sequencing_date'].value_counts()

In [None]:
adataquad.obs['sex'].value_counts()

In [None]:
adataquad.obs['age'].value_counts()

In [None]:
adataquad.obs['sample'].value_counts()

In [None]:
adataquad.obs['tendon_disease'].value_counts()

In [None]:
# Use var['rownames(so)'] as the gene index.
# NOTE(review): this replaces the names made unique above; confirm 'rownames(so)' has no duplicates.
adataquad.var.index = adataquad.var['rownames(so)'].copy()
adataquad.var.index.name = 'Gene'
adataquad.var

In [None]:
# Barcode = first '_'-separated token of the cell name (the Achilles cells use the last token).
adataquad.obs['barcode'] = adataquad.obs.index.str.split('_').str[0]
print(adataquad.obs['barcode'].head())

In [None]:
# Load the annotated adult Achilles dataset.
adataach = sc.read_h5ad('/mnt/da8aa2c4-0136-465b-87a2-d12a59afec55/akurjan/analysis/files/AdultData/Achilles_integrated_annotated.h5ad')
adataach.var_names_make_unique()
adataach

In [None]:
# Inspect X (and its maximum below) to judge whether it holds raw counts or normalised values.
print(adataach.X[0:20,0:20])

In [None]:
adataach.X.max()

In [None]:
adataach.obs

In [None]:
sc.pl.umap(adataach, color=['age', 'cell_annotation_update'], frameon=False)

In [None]:
sc.pl.umap(adataach, color=['cell_annotation_update'], legend_loc='on data', 
           legend_fontsize=5, frameon=False)

In [None]:
adataach.obs['sequencing_date'].value_counts()

In [None]:
adataach.obs['sex'].value_counts()

In [None]:
adataach.obs['age'].value_counts()

In [None]:
adataach.var.index = adataach.var['rownames(so)'].copy()
adataach.var.index.name = 'Gene'
adataach.var

In [None]:
adataach.obs['barcode'] = adataach.obs.index.str.split('_').str[-1]
print(adata.obs['barcode'].head())

In [None]:
adataquad.obs['sex'] = adataquad.obs['sex'].replace({'Male': 'male'})

date_to_batch = {
    '20230822': 'Aug2023',
    '20220808': 'Aug2022',
    '11102021': 'Oct2021',
    '20211213': 'Dec2021'
}

In [None]:
adataach.obs['libbatch'] = adataach.obs['sequencing_date'].map(date_to_batch)
adataquad.obs['libbatch'] = adataquad.obs['sequencing_date'].map(date_to_batch)

print(adataach.obs[['sequencing_date', 'libbatch']].head())
print(adataquad.obs[['sequencing_date', 'libbatch']].head())

In [None]:
adataach.obs['ageint'] = adataach.obs['age'].copy()

age_categories = ['45yr', '50yr', '51yr', '58yr', '74yr', '76yr']

age_with_yr = {age: f"{age}yr" for age in sorted(adataach.obs['age'].unique())}

adataach.obs['age'] = adataach.obs['age'].map(age_with_yr)
adataach.obs['age'] = pd.Categorical(adataach.obs['age'], categories=age_categories, ordered=True)

print(adataach.obs[['ageint', 'age']].head())

In [None]:
adataquad.obs['ageint'] = adataquad.obs['age'].copy()

age_categories = ['25yr', '29yr', '44yr', '67yr', '69yr', '75yr']
age_with_yr = {age: f"{age}yr" for age in sorted(adataquad.obs['age'].unique())}
adataquad.obs['age'] = adataquad.obs['age'].map(age_with_yr)
adataquad.obs['age'] = pd.Categorical(adataquad.obs['age'], categories=age_categories, ordered=True)

print(adataquad.obs[['ageint', 'age']].head())

In [None]:
adata.obs

In [None]:
# Harmonise obs metadata so the three datasets share the same columns after concatenation.
adataach.obs['type'] = 'Ach'
adataquad.obs['type'] = 'Quad'

# sampletype = "<sample>_<type>"
adataach.obs['sampletype'] = adataach.obs['sample'].astype('str') + '_' + adataach.obs['type'].astype('str')
adataquad.obs['sampletype'] = adataquad.obs['sample'].astype('str') + '_' + adataquad.obs['type'].astype('str')

# Only the quad data has a disease column; Achilles and developmental cells are all labelled 'Healthy'.
adataquad.obs['tendon_status'] = adataquad.obs['tendon_disease'].copy()
adataach.obs['tendon_status'] = 'Healthy'
adata.obs['tendon_status'] = 'Healthy'

# Developmental cells: the batch renamed 'Embryonic' earlier vs everything else ('Foetal').
# This replaces the earlier "<libbatch>_<type>" 'group' column on adata.
adataach.obs['group'] = 'Adult'
adataquad.obs['group'] = 'Adult'
adata.obs['group'] = np.where(adata.obs['libbatch'] == 'Embryonic', 'Embryonic', 'Foetal')

# grouptype = "<group>_<type>"
adataach.obs['grouptype'] = adataach.obs['group'].astype('str') + '_' + adataach.obs['type'].astype('str') 
adataquad.obs['grouptype'] = adataquad.obs['group'].astype('str') + '_' + adataquad.obs['type'].astype('str') 
adata.obs['grouptype'] = adata.obs['group'].astype('str') + '_' + adata.obs['type'].astype('str') 

# Every quad cell is assigned the 'MB' microanatomical site.
adataquad.obs['microanat'] = 'MB'

In [None]:
# Keep each dataset's original labels; for the adult data also prefix them with grouptype
# so labels from different sources stay distinct once the datasets are concatenated.
adataach.obs['annotations_orig'] = adataach.obs['cell_annotation_update'].copy()
adataach.obs['annotations_orig_full'] = adataach.obs['grouptype'].astype('str') + '_' + adataach.obs['annotations_orig'].astype('str')

adataquad.obs['annotations_orig'] = adataquad.obs['cluster_id'].copy()
adataquad.obs['annotations_orig_full'] = adataquad.obs['grouptype'].astype('str') + '_' + adataquad.obs['annotations_orig'].astype('str')

adata.obs['annotations_orig'] = adata.obs['C_scANVI'].copy()
adata.obs['annotations_orig_full'] = adata.obs['C_scANVI'].copy()

In [None]:
# Achilles microanatomical site = last '-'-separated token of the sample name, with spellings normalised.
adataach.obs['microanat'] = adataach.obs['sample'].str.split('-').str[-1]
rename_map = {
    'Enth': 'ENTH',
    'MB2': 'MB',
    'muscle': 'MUSCLE'
}
adataach.obs['microanat'] = adataach.obs['microanat'].replace(rename_map)
print(adataach.obs['microanat'].value_counts())

In [None]:
# Developmental microanatomical site, mapped from the last '-'-separated token of the sample name.
adata.obs['microanat'] = adata.obs['sample'].str.split('-').str[-1]
rename_map = {
    'DEV16127': 'FULL',           
    'DEV16135DEV16171': 'FULL',
    'DEV15985': 'MB',
    'DEV16569': 'FULL',
    'DEV15984': 'MB',
    'DEV16134': 'FULL',
    'DEV16136': 'FULL',
    'DEV15983': 'MB',
    'BRC2172': 'A-FULL',
    'BRC2181': 'A-FULL',
    'BRC2173': 'A-FULL',
    'BRC2083': 'A-FULL',
    'BRC2092': 'A-FULL',
    'BRC2114': 'A-FULL',
}
adata.obs['microanat'] = adata.obs['microanat'].replace(rename_map)
print(adata.obs['microanat'].value_counts())

In [None]:
# megagrouptype = "<grouptype>_<microanat>"
adataach.obs['megagrouptype'] = adataach.obs['grouptype'].astype('str') + '_' + adataach.obs['microanat'].astype('str') 
adataquad.obs['megagrouptype'] = adataquad.obs['grouptype'].astype('str') + '_' + adataquad.obs['microanat'].astype('str') 
adata.obs['megagrouptype'] = adata.obs['grouptype'].astype('str') + '_' + adata.obs['microanat'].astype('str') 
print(adata.obs['megagrouptype'].value_counts())

In [None]:
sc.pl.umap(adataach, color='annotations_orig_full', frameon=False)

In [None]:
sc.pl.umap(adataquad, color='annotations_orig_full', frameon=False)

In [None]:
adataquad.obs.annotations_orig_full

# Concatenation and Preparation

In [None]:
# Outer join keeps the union of genes across the three datasets.
concat = ad.concat([adata, adataach, adataquad], join='outer') 
concat

In [None]:
# Drop source-specific obs columns (QC metrics, clustering results, old annotations)
# and embeddings/layers that are not needed after concatenation.
concat.obs.drop(columns=[
       'n_genes_by_counts', 'total_counts', 
       '_scvi_batch', '_scvi_labels', 'cell_type', 'S_score', 'G2M_score', 'phase', 
       'C_scANVI', 'seq_protocol', 'kit', 'modality',
       'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'sum',
       'detected', 'subsets_mito_sum', 'subsets_mito_detected',
       'subsets_mito_percent', 'total', 'log10GenesPerUMI', 'patient',
       'ethnicity', 'surgical_procedure', 'disease_status', 'tendon_disease',
       'anatomical_site', 'time_to_freezing', 'sequencing_date',
       'seurat_clusters', 'decontX_contamination', 'decontX_clusters',
       'nCount_decontXcounts', 'nFeature_decontXcounts', 'RNA_snn_res.0.3',
       'nCount_SoupXcounts', 'nFeature_SoupXcounts', 'scDblFinder.class',
       'scDblFinder.score', 'SoupXcounts_snn_res.0.3', 'RNA_snn_res.0.1',
       'SoupXcounts_snn_res.0', 'SoupXcounts_snn_res.0.1',
       'SoupXcounts_snn_res.0.2', 'SoupXcounts_snn_res.0.4',
       'SoupXcounts_snn_res.0.5', 'SoupXcounts_snn_res.0.6',
       'SoupXcounts_snn_res.0.7', 'SoupXcounts_snn_res.0.8', 'cluster_id',
       'cluster_idbackup', 'ident_diseasestatus', 'affected_side',
       'microanatomical_site', 'sizeFactor', 'scDblFinder.cluster',
       'scDblFinder.weighted', 'scDblFinder.difficulty',
       'scDblFinder.cxds_score', 'scDblFinder.mostLikelyOrigin',
       'scDblFinder.originAmbiguous', 'nCount_soupX', 'nFeature_soupX',
       'soupX_snn_res.0.1', 'soupX_snn_res.0.2', 'soupX_snn_res.0.3',
       'soupX_snn_res.0.4', 'soupX_snn_res.0.5', 'soupX_snn_res.0.6',
       'soupX_snn_res.0.7', 'soupX_snn_res.0.8', 'soupX_snn_res.0.9',
       'soupX_snn_res.1', 'cell_annotation_0.1', 'cell_annotation_0.2',
       'cell_annotation_0.3', 'cell_annotation_update',
       'cell_annotation_spatial'], inplace=True)

del concat.obsm['pca'], concat.obsm['umap']
del concat.layers['decontX'], concat.layers['logcounts'], concat.layers['soupX']

In [None]:
# Free the per-dataset objects; only `concat` is used from here on.
del adataach, adataquad, adata

In [None]:
# Keep genes with >= 30 total counts and cells with >= 200 detected genes.
sc.pp.filter_genes(concat, min_counts=30, inplace=True)
sc.pp.filter_cells(concat, min_genes=200)

In [None]:
# Rebuild 'sampletype' as "<sample>-<type>" with tissue codes normalised (QUAD -> Quad, ACH -> Ach).
conversion_mapping = {
    'QUAD': 'Quad',
    'ACH': 'Ach'
}
type_mapping = concat.obs.set_index('sample')['type'].to_dict()

# Apply conversion to all sample types
# NOTE: the MSK branch here is superseded by the masked assignment below.
concat.obs['sampletype'] = concat.obs['sample'].apply(lambda x: x.split('-')[0] if x.startswith('MSK') else x + '-' + type_mapping.get(x, ''))
concat.obs['sampletype'] = concat.obs['sampletype'].apply(lambda x: '-'.join([conversion_mapping.get(part, part) for part in x.split('-')]))

# For samples starting with 'MSK', extract the first two parts and apply conversion
mask = concat.obs['sample'].str.startswith('MSK')
concat.obs.loc[mask, 'sampletype'] = concat.obs.loc[mask, 'sample'].apply(lambda x: '-'.join([conversion_mapping.get(part, part) for part in x.split('-')[:2]]))

print(concat.obs['sampletype'].value_counts())

In [None]:
concat

In [None]:
concat.write(os.path.join(RESULTS_FOLDERNAME, 'adultdev_combined.h5ad'))

In [None]:
# Quick unintegrated overview.
# NOTE: normalize_total / log1p / scale overwrite concat.X in place, so X no longer holds
# counts after this cell.
# NOTE(review): flavor="seurat_v3" expects raw counts in X; confirm concat.X is still counts here.
sc.pp.highly_variable_genes(concat, n_top_genes=7000, flavor="seurat_v3", batch_key='sampletype', subset=False, span=1)
sc.pp.normalize_total(concat)
sc.pp.log1p(concat)
sc.pp.scale(concat)
sc.pp.pca(concat)
sc.pp.neighbors(concat)
sc.tl.umap(concat)

In [None]:
sc.pl.umap(concat, color='annotations_orig_full', frameon=False)

In [None]:
concat.obs

In [None]:
#concat.raw = concat

In [None]:
sc.pp.highly_variable_genes(concat, n_top_genes=7000, flavor="seurat_v3", batch_key='sampletype', subset=False, span=1)
sc.pl.highly_variable_genes(concat)

In [None]:
# Human G2/M- and S-phase marker genes for cell-cycle scoring.
# NOTE(review): some symbols look like older gene names (e.g. FAM64A, HN1, MLF1IP); genes
# absent from concat.var_names may be skipped — check they are present.
g2m_genes = [
    'HMGB2', 'CDK1', 'NUSAP1', 'UBE2C', 'BIRC5', 'TPX2', 'TOP2A', 'NDC80', 'CKS2',
    'NUF2', 'CKS1B', 'MKI67', 'TMPO', 'CENPF', 'TACC3', 'FAM64A', 'SMC4', 'CCNB2',
    'CKAP2L', 'CKAP2', 'AURKB', 'BUB1', 'KIF11', 'ANP32E', 'TUBB4B', 'GTSE1', 'KIF20B',
    'HJURP', 'CDCA3', 'HN1', 'CDC20', 'TTK', 'CDC25C', 'KIF2C', 'RANGAP1', 'NCAPD2',
    'DLGAP5', 'CDCA2', 'CDCA8', 'ECT2', 'KIF23', 'HMMR', 'AURKA', 'PSRC1', 'ANLN', 'LBR',
    'CKAP5', 'CENPE', 'CTCF', 'NEK2', 'G2E3', 'GAS2L3', 'CBX5', 'CENPA'
]

s_genes = [
    'MCM5', 'PCNA', 'TYMS', 'FEN1', 'MCM2', 'MCM4', 'RRM1', 'UNG', 'GINS2', 'MCM6',
    'CDCA7', 'DTL', 'PRIM1', 'UHRF1', 'MLF1IP', 'HELLS', 'RFC2', 'RPA2', 'NASP', 'RAD51AP1',
    'GMNN', 'WDR76', 'SLBP', 'CCNE2', 'UBR7', 'POLD3', 'MSH2', 'ATAD2', 'RAD51', 'RRM2',
    'CDC45', 'CDC6', 'EXO1', 'TIPIN', 'DSCC1', 'BLM', 'CASP8AP2', 'USP1', 'CLSPN', 'POLA1',
    'CHAF1B', 'BRIP1', 'E2F8'
]

# Produces S_score / G2M_score / phase in concat.obs (the scores later become scVI covariates).
# NOTE(review): concat.X holds the scaled matrix from the quick-look cell at this point;
# confirm scoring on scaled rather than log-normalised values is intended.
sc.tl.score_genes_cell_cycle(concat, s_genes, g2m_genes)

In [None]:
sc.pl.violin(concat, ['S_score', 'G2M_score'],
            jitter=0.4, groupby = 'sampletype', rotation=90, 
            )

In [None]:
print(concat.X[0:10, 0:10])

In [None]:
sc.pp.normalize_total(concat, target_sum=None, inplace=True)
sc.pp.log1p(concat)
print(concat.X[0:10, 0:10])

In [None]:
concat.layers["log1p_norm"] = concat.X.copy()

In [None]:
sc.pp.scale(concat)
print(concat.X[0:5,0:5])

In [None]:
concat.layers['scaled'] = concat.X.copy()

In [None]:
# 40 principal components of the scaled matrix.
sc.pp.pca(concat, n_comps=40, svd_solver="arpack")

In [None]:
#explained_var = concat.uns['pca']['variance']
#cumulative_var = np.cumsum(explained_var) / np.sum(explained_var)
#num_pcs_90_var = np.argmax(cumulative_var >= 0.9) + 1
#num_pcs_90_var

In [None]:
sc.pl.pca_loadings(concat, components='1,2,3,4,5,6,7,8')

In [None]:
# PCs 1-8 coloured by each metadata column, to look for batch vs biological structure.
for var in ['grouptype', 'megagrouptype', 'age', 'libbatch', "phase", "sex", 'tendon_status', 'microanat', 'sampletype']:
    sc.pl.pca(concat, components=['1,2', '3,4', '5,6', '7,8'], ncols=4, color=var)

In [None]:
# NOTE(review): disabled cell-cycle regression; it refers to `adata`, not `concat`.
#adata.X = adata.layers['log1p_norm'].copy()
#print(adata.X[0:10,0:10])
#sc.pp.regress_out(adata, ['S_score', 'G2M_score'], n_jobs=20)
#print(adata.X[0:5,0:5])
#adata.layers['regressed_cc'] = adata.X.copy()

In [None]:
# kNN graph (correlation distance; scanpy uses X_pca by default) and UMAP of the unintegrated data.
sc.pp.neighbors(concat, metric='correlation')
sc.tl.umap(concat)

In [None]:
sc.pl.umap(concat, color='annotations_orig_full', frameon=False)

In [None]:
def plot_umaps(anndata, parameters: list, filename: str, width: float = 8):
    """
    Draw one UMAP panel per entry of ``parameters`` (stacked vertically) and save as SVG.

    NOTE: this redefinition replaces the earlier ``plot_umaps`` in this notebook; it
    defaults to a narrower (8-inch) figure.

    Parameters
    ----------
    anndata : AnnData with a computed UMAP embedding.
    parameters : values passed one at a time to ``sc.pl.umap(color=...)``.
    filename : SVG file name, written via ``savesvg``.
    width : figure width in inches; each panel is 4 inches tall.

    Raises
    ------
    ValueError : if ``parameters`` is empty.
    """
    n_plots = len(parameters)
    if n_plots == 0:
        raise ValueError("plot_umaps needs at least one parameter to colour by")
    # squeeze=False keeps `axs` 2-D, so a single panel (n_plots == 1) can still be indexed
    fig, axs = plt.subplots(n_plots, 1, figsize=(width, 4 * n_plots), squeeze=False)
    for ax, param in zip(axs[:, 0], parameters):
        sc.pl.umap(anndata, color=param, ax=ax, show=False, frameon=False, s=2)
        ax.set_title(param)
    fig.tight_layout()
    savesvg(filename, fig)
    plt.show()

In [None]:
plot_umaps(concat, ["group", 
                   'grouptype', 
                   'megagrouptype', 
                   'age', 'libbatch', "phase", 
                   "sex", 'tendon_status', 'microanat',
                   "sampletype"],
           filename='unintegrated_fulldevadult.svg'
          )

In [None]:
# NOTE: scanpy appends the `save=` string to the plot-type prefix ("umap") when naming the file.
sc.pl.umap(concat, color='annotations_orig_full', frameon=False, save='unintegrated_orig_full.svg')

In [None]:
concat

In [None]:
# Overwrites the file written earlier; it now also carries normalised/scaled layers,
# HVG flags, cell-cycle scores and embeddings.
concat.write(os.path.join(RESULTS_FOLDERNAME, 'adultdev_combined.h5ad'))

In [None]:
all_annotations = {}
if 'annotations_orig_full_colors' in concat.uns:
    cell_types = concat.obs['annotations_orig_full'].cat.categories
    colors = concat.uns['annotations_orig_full_colors']
    print("Cell Types and Their Colors:")
    for cell_type, color in zip(cell_types, colors):
        all_annotations[cell_type] = color
        #print(f"'{cell_type}': '{color}',")
else:
    print("Color palette for 'annotations_orig' not found. Run a plot first.")

all_annotations

In [None]:
group_annotations = concat.obs[concat.obs['group'] == 'Embryonic']['annotations_orig_full']
unique_annotations = list(pd.unique(group_annotations))
#unique_annotations = ['Skeletal Myocytes']
highlighted_clusters = {annotation: all_annotations[annotation] for annotation in unique_annotations 
                        if annotation in all_annotations}
unique_clusters = concat.obs['annotations_orig_full'].cat.categories
color_palette = [highlighted_clusters.get(cluster, 'lightgray') for cluster in unique_clusters]
sc.pl.umap(concat, color=['annotations_orig_full'], palette=color_palette, s=10,
                frameon=False, save='_merged_notintegrated_embryonicct.svg')

In [None]:
group_annotations = concat.obs[concat.obs['group'] == 'Foetal']['annotations_orig_full']
unique_annotations = list(pd.unique(group_annotations))
#unique_annotations = ['Skeletal Myocytes']
highlighted_clusters = {annotation: all_annotations[annotation] for annotation in unique_annotations 
                        if annotation in all_annotations}
unique_clusters = concat.obs['annotations_orig_full'].cat.categories
color_palette = [highlighted_clusters.get(cluster, 'lightgray') for cluster in unique_clusters]
sc.pl.umap(concat, color=['annotations_orig_full'], palette=color_palette, s=10,
                frameon=False, save='_merged_notintegrated_foetalct.svg')

In [None]:
group_annotations = concat.obs[concat.obs['group'] == 'Adult']['annotations_orig_full']
unique_annotations = list(pd.unique(group_annotations))
#unique_annotations = ['Skeletal Myocytes']
highlighted_clusters = {annotation: all_annotations[annotation] for annotation in unique_annotations 
                        if annotation in all_annotations}
unique_clusters = concat.obs['annotations_orig_full'].cat.categories
color_palette = [highlighted_clusters.get(cluster, 'lightgray') for cluster in unique_clusters]
sc.pl.umap(concat, color=['annotations_orig_full'], palette=color_palette, s=10,
                frameon=False, save='_merged_notintegrated_adultct.svg')

# scVI and scANVI Integration

In [None]:
# Reload the combined object for scVI integration.
adata = sc.read_h5ad(os.path.join(RESULTS_FOLDERNAME, 'adultdev_combined.h5ad'))
adata.var_names_make_unique()
adata

In [None]:
# Put raw counts back in X.
adata.X = adata.layers['counts'].copy()
print(adata.X[0:10,0:10])

In [None]:
# Freeze the full-gene count matrix in .raw before subsetting genes.
adata.raw = adata

In [None]:
# Keep only genes flagged as highly variable before the file was saved.
adata = adata[:, adata.var.highly_variable].copy()
adata

In [None]:
import ray
import hyperopt
import scvi
from ray import tune
from scvi import autotune

# Register the data with scVI: counts layer, libbatch as batch, sampletype as an extra
# categorical covariate and the cell-cycle scores as continuous covariates.
model_cls = scvi.model.SCVI
model_cls.setup_anndata(adata, layer="counts", labels_key='annotations_orig_full',
                        batch_key='libbatch',
                        categorical_covariate_keys=["sampletype"],
                        continuous_covariate_keys=["G2M_score", "S_score"])

scvi_tuner = autotune.ModelTuner(model_cls)
scvi_tuner.info()

In [None]:
# Hyperparameter search space for the tuner.
search_space = {
    "n_latent": tune.choice([10, 30, 50]),
    "n_hidden": tune.choice([60, 128, 256]),
    "n_layers": tune.choice([1, 2, 3]),
    "lr": tune.loguniform(1e-4, 1e-2),
    "gene_likelihood": tune.choice(["nb", "zinb"])
}

In [None]:
ray.init(log_to_driver=False)

In [None]:
# 100 hyperopt trials of at most 30 epochs each, minimising validation loss (requesting 1 GPU).
results = scvi_tuner.fit(
    adata,
    metric="validation_loss",
    search_space=search_space,
    searcher='hyperopt',
    num_samples=100,
    max_epochs=30,
    resources={"gpu": 1},
)

In [None]:
print(results.model_kwargs)
print(results.train_kwargs)

In [None]:
best_vl = 10000
best_i = 0
for i, res in enumerate(results.results):
    vl = res.metrics['validation_loss']

    if vl < best_vl:
        best_vl = vl
        best_i = i
        
results.results[best_i]

In [None]:
ray.shutdown()

In [None]:
import scvi

# Final scVI registration (same keys as used for tuning).
scvi.model.SCVI.setup_anndata(adata,
                              layer="counts", labels_key='annotations_orig_full',
                              batch_key='libbatch',
                              categorical_covariate_keys=["sampletype"],
                              continuous_covariate_keys=["G2M_score", "S_score"]
                             )

In [None]:
# NOTE(review): architecture presumably chosen from the tuning results above — confirm.
vae = scvi.model.SCVI(adata, n_hidden = 256, n_latent=50, n_layers=1, 
                      dropout_rate=0.1, dispersion='gene-batch',
                      gene_likelihood='zinb')
vae

In [None]:
vae.view_anndata_setup(adata)

In [None]:
# Epoch heuristic capped at 400.
# NOTE(review): not used — vae.train below hard-codes max_epochs = 80.
max_epochs_scvi = np.min([round((20000 / adata.n_obs) * 400), 400])
max_epochs_scvi

In [None]:
%%time

# 90/10 train/validation split; early stopping on validation ELBO, validated every 4 epochs.
vae.train(max_epochs = 80, train_size = 0.9, validation_size = 0.1, 
          use_gpu=True, accelerator='gpu', 
          check_val_every_n_epoch=4,
          early_stopping=True,
          early_stopping_patience=5,
          early_stopping_monitor="elbo_validation",
          plan_kwargs = {'lr': 0.0025}
         )

In [None]:
# Plot training vs validation ELBO. join() returns a new frame, so vae.history is not
# mutated; an outer join keeps all rows because validation is logged only every
# `check_val_every_n_epoch` epochs (markers make the sparse validation points visible).
train_test_results = vae.history["elbo_train"].join(vae.history["elbo_validation"], how='outer')
ax = train_test_results.plot(marker='o')
ax.set_xlabel('epoch')
ax.set_ylabel('ELBO')
plt.show()

In [None]:
# Reconstruction loss per epoch; the black line marks the minimum validation loss.
y = vae.history['reconstruction_loss_validation']['reconstruction_loss_validation'].min()
plt.plot(vae.history['reconstruction_loss_train']['reconstruction_loss_train'], label='train')
plt.plot(vae.history['reconstruction_loss_validation']['reconstruction_loss_validation'], label='validation')
plt.axhline(y, c = 'k')
plt.legend()
plt.show()

In [None]:
vae.save(os.path.join(RESULTS_FOLDERNAME, "DevAdult_LibbatchSampletype_G2MSScores_256_50_1_01_Zinb_GeneBatch/"), overwrite=True)

In [None]:
# scVI latent space -> kNN graph (correlation distance) -> UMAP.
adata.obsm["X_scVI"] = vae.get_latent_representation()
sc.pp.neighbors(adata, use_rep="X_scVI", metric='correlation')
sc.tl.umap(adata)

In [None]:
plot_umaps(adata, ["group", 
                   'grouptype', 
                   'megagrouptype', 
                   'age', 'libbatch', "phase", 
                   "sex", 'tendon_status', 'microanat',
                   "sample", "sampletype"],
           filename='scVIintegrated_fulldevadult_libbatchsampletype.svg'
          )

In [None]:
sc.pl.umap(adata, color=['annotations_orig_full'], frameon=False)

In [None]:
adata.write(os.path.join(RESULTS_FOLDERNAME, 'adultdev_combined_scVI.h5ad'))

In [None]:
all_annotations = {}
if 'annotations_orig_full_colors' in adata.uns:
    cell_types = adata.obs['annotations_orig_full'].cat.categories
    colors = adata.uns['annotations_orig_full_colors']
    print("Cell Types and Their Colors:")
    for cell_type, color in zip(cell_types, colors):
        all_annotations[cell_type] = color
        #print(f"'{cell_type}': '{color}',")
else:
    print("Color palette for 'annotations_orig' not found. Run a plot first.")

all_annotations

In [None]:
group_annotations = adata.obs[adata.obs['group'] == 'Embryonic']['annotations_orig_full']
unique_annotations = list(pd.unique(group_annotations))
#unique_annotations = ['Skeletal Myocytes']
highlighted_clusters = {annotation: all_annotations[annotation] for annotation in unique_annotations 
                        if annotation in all_annotations}
unique_clusters = adata.obs['annotations_orig_full'].cat.categories
color_palette = [highlighted_clusters.get(cluster, 'lightgray') for cluster in unique_clusters]
sc.pl.umap(adata, color=['annotations_orig_full'], palette=color_palette, s=10, 
                frameon=False, save='_scVI_embryonicct.svg'
          )

In [None]:
group_annotations = adata.obs[adata.obs['group'] == 'Foetal']['annotations_orig_full']
unique_annotations = list(pd.unique(group_annotations))
#unique_annotations = ['Skeletal Myocytes']
highlighted_clusters = {annotation: all_annotations[annotation] for annotation in unique_annotations 
                        if annotation in all_annotations}
unique_clusters = adata.obs['annotations_orig_full'].cat.categories
color_palette = [highlighted_clusters.get(cluster, 'lightgray') for cluster in unique_clusters]
sc.pl.umap(adata, color=['annotations_orig_full'], palette=color_palette, s=10,
           frameon=False, save='_scVI_foetalct.svg'
          )

In [None]:
group_annotations = adata.obs[adata.obs['group'] == 'Adult']['annotations_orig_full']
unique_annotations = list(pd.unique(group_annotations))
#unique_annotations = ['Skeletal Myocytes']
highlighted_clusters = {annotation: all_annotations[annotation] for annotation in unique_annotations 
                        if annotation in all_annotations}
unique_clusters = adata.obs['annotations_orig_full'].cat.categories
color_palette = [highlighted_clusters.get(cluster, 'lightgray') for cluster in unique_clusters]
sc.pl.umap(adata, color=['annotations_orig_full'], palette=color_palette, s=10,
                frameon=False, save='_scVI_adultct.svg')

In [None]:
# Initialise scANVI from the trained scVI model, using the original annotations as labels.
# NOTE(review): confirm whether any cells carry the "Unknown" label; if none do, every cell is treated as labelled.
lvae = scvi.model.SCANVI.from_scvi_model(
    vae,
    adata=adata,
    labels_key="annotations_orig_full",
    unlabeled_category="Unknown",
)

In [None]:
lvae.train(max_epochs=10, train_size = 0.9, validation_size = 0.1, 
          use_gpu=True, accelerator='gpu', 
          check_val_every_n_epoch=1,
          early_stopping=True,
          early_stopping_patience=2,
          early_stopping_monitor="elbo_validation")

In [None]:
# Reconstruction loss per epoch; the black line marks the minimum validation loss.
y = lvae.history['reconstruction_loss_validation']['reconstruction_loss_validation'].min()
plt.plot(lvae.history['reconstruction_loss_train']['reconstruction_loss_train'], label='train')
plt.plot(lvae.history['reconstruction_loss_validation']['reconstruction_loss_validation'], label='validation')
plt.axhline(y, c = 'k')
plt.legend()
plt.show()

In [None]:
# scANVI label predictions and latent representation.
adata.obs["C_scANVI"] = lvae.predict(adata)
adata.obsm["X_scANVI"] = lvae.get_latent_representation(adata)

In [None]:
sc.pp.neighbors(adata, use_rep="X_scANVI", metric='correlation')
sc.tl.umap(adata)
plot_umaps(adata, ["group", 
                   'grouptype', 
                   'megagrouptype', 
                   'age', 'libbatch', "phase", 
                   "sex", 'tendon_status', 'microanat',
                   "sampletype"],
           filename='scANVIintegrated_fulldevadult_libbatchsampletype.svg'
          )

In [None]:
# Embryonic annotations in colour, all others light grey (colours captured from the scVI plots).
group_annotations = adata.obs[adata.obs['group'] == 'Embryonic']['annotations_orig_full']
unique_annotations = list(pd.unique(group_annotations))
#unique_annotations = ['Skeletal Myocytes']
highlighted_clusters = {annotation: all_annotations[annotation] for annotation in unique_annotations 
                        if annotation in all_annotations}
unique_clusters = adata.obs['annotations_orig_full'].cat.categories
color_palette = [highlighted_clusters.get(cluster, 'lightgray') for cluster in unique_clusters]
sc.pl.umap(adata, color=['annotations_orig_full'], palette=color_palette, s=10, 
                frameon=False, save='_scANVI_embryonicct.svg'
          )

In [None]:
# Foetal annotations in colour, all others light grey.
group_annotations = adata.obs[adata.obs['group'] == 'Foetal']['annotations_orig_full']
unique_annotations = list(pd.unique(group_annotations))
#unique_annotations = ['Skeletal Myocytes']
highlighted_clusters = {annotation: all_annotations[annotation] for annotation in unique_annotations 
                        if annotation in all_annotations}
unique_clusters = adata.obs['annotations_orig_full'].cat.categories
color_palette = [highlighted_clusters.get(cluster, 'lightgray') for cluster in unique_clusters]
sc.pl.umap(adata, color=['annotations_orig_full'], palette=color_palette, s=10, 
                frameon=False, save='_scANVI_foetalct.svg'
          )

In [None]:
# Same foetal-highlight palette (from the previous cell), with labels drawn on the embedding.
sc.pl.umap(adata, color=['annotations_orig_full'], palette=color_palette, s=10, 
           legend_loc='on data', legend_fontsize=3,
           frameon=False, save='_scANVI_foetalct_annotated.svg'
          )

In [None]:
# tendon_status has its own categories; the annotation highlight palette built above
# (mostly 'lightgray', ordered by annotations_orig_full categories) does not apply here,
# so let scanpy choose the colours.
sc.pl.umap(adata, color=['tendon_status'], s=10, 
           legend_loc='on data', legend_fontsize=3,
           frameon=False, 
          )

In [None]:
# Adult annotations in colour, all others light grey.
group_annotations = adata.obs[adata.obs['group'] == 'Adult']['annotations_orig_full']
unique_annotations = list(pd.unique(group_annotations))
#unique_annotations = ['Skeletal Myocytes']
highlighted_clusters = {annotation: all_annotations[annotation] for annotation in unique_annotations 
                        if annotation in all_annotations}
unique_clusters = adata.obs['annotations_orig_full'].cat.categories
color_palette = [highlighted_clusters.get(cluster, 'lightgray') for cluster in unique_clusters]
sc.pl.umap(adata, color=['annotations_orig_full'], palette=color_palette, s=10, 
                frameon=False, save='_scANVI_adultct.svg'
          )

In [None]:
# Row-normalised confusion matrix: observed annotation (rows) vs scANVI prediction (columns);
# each row shows how the cells of one original annotation were predicted.
df = adata.obs.groupby(["annotations_orig_full", "C_scANVI"]).size().unstack(fill_value=0)
conf_mat = df.div(df.sum(axis=1), axis=0)

fig, ax = plt.subplots(figsize=(8, 8))
mesh = ax.pcolormesh(conf_mat, edgecolors='k', linewidths=0.5, cmap='viridis')
ax.set_xticks(np.arange(0.5, len(df.columns), 1))
ax.set_xticklabels(df.columns, rotation=90)
ax.set_yticks(np.arange(0.5, len(df.index), 1))
ax.set_yticklabels(df.index)
ax.grid(False)
ax.set_xlabel("Predicted")
ax.set_ylabel("Observed")
# Colour bar to read off the proportions
fig.colorbar(mesh, ax=ax, label='Proportion')
savesvg('scANVI_prediction_matrix.svg', fig)
plt.show()

In [None]:
sc.pl.umap(adata, color=['C_scANVI'], frameon=False, save='_scANVI_predicted.svg')

In [None]:
sc.pl.umap(adata, color=['ageint'], frameon=False, save='_scANVI_ageint.svg')

In [None]:
# Palette with every annotation in its captured colour (no group greyed out).
group_annotations = adata.obs['annotations_orig_full']
unique_annotations = list(pd.unique(group_annotations))
#unique_annotations = ['Skeletal Myocytes']
highlighted_clusters = {annotation: all_annotations[annotation] for annotation in unique_annotations 
                        if annotation in all_annotations}
unique_clusters = adata.obs['annotations_orig_full'].cat.categories
color_palette = [highlighted_clusters.get(cluster, 'lightgray') for cluster in unique_clusters]

In [None]:
sc.pl.umap(adata, color=['annotations_orig_full'], palette=color_palette, frameon=False)

In [None]:
adata.write(os.path.join(RESULTS_FOLDERNAME, 'adultdev_combined_scANVI.h5ad'))

In [None]:
adata = sc.read_h5ad(os.path.join(RESULTS_FOLDERNAME, 'adultdev_combined_scANVI.h5ad'))
adata

In [None]:
print(adata.X[0:10,0:10])

In [None]:
plt.figure(figsize=(30, 35))
sc.tl.dendrogram(adata, 'annotations_orig_full', use_rep='X_scANVI')
ax_list = sc.pl.correlation_matrix(adata, 'annotations_orig_full', cmap='PuOr_r', show=False)
for ax in ax_list:
    ax.grid(False)
plt.savefig(os.path.join(FIGURES_FOLDERNAME,'annotation_correlation.svg'), bbox_inches='tight')
plt.show()

In [None]:
plt.figure(figsize=(10, 15))
sc.tl.dendrogram(adata, 'megagrouptype', use_rep='X_scANVI')
ax_list = sc.pl.correlation_matrix(adata, 'megagrouptype', cmap='PuOr_r', show=False)
for ax in ax_list:
    ax.grid(False)
plt.savefig(os.path.join(FIGURES_FOLDERNAME,'megagrouptype_correlation.svg'), bbox_inches='tight')
plt.show()

In [None]:
adata.obs[['age', 'tendon_status']].value_counts()

In [None]:
plt.figure(figsize=(15, 20))
sc.tl.dendrogram(adata, 'age', use_rep='X_scANVI')
ax_list = sc.pl.correlation_matrix(adata, 'age', cmap='PuOr_r', show=False)
for ax in ax_list:
    ax.grid(False)
plt.savefig(os.path.join(FIGURES_FOLDERNAME,'age_correlation.svg'), bbox_inches='tight')
plt.show()

In [None]:
sc.pl.violin(adata, 'TPPP3', groupby='age', 
             use_raw=False, layer='log1p_norm',
             rotation=90)

# Extra: scGen integration and scran (scranPY) normalisation

In [None]:
import scgen

# scGen is set up on the log-normalised matrix rather than counts.
adata.X = adata.layers['log1p_norm'].copy()
print(adata.X[0:5,0:5])

In [None]:
adata.X.max()

In [None]:
scgen.SCGEN.setup_anndata(adata, 
                          batch_key="sampletype",
                          labels_key="annotations_orig_full")

model = scgen.SCGEN(adata)
model.view_anndata_setup()

In [None]:
model.train(
    max_epochs=100,
    use_gpu=True,
    batch_size=100,
    early_stopping=True,
    early_stopping_patience=50,
)

In [None]:
# Reconstruction loss per epoch; the black line marks the minimum validation loss.
y = model.history['reconstruction_loss_validation']['reconstruction_loss_validation'].min()
plt.plot(model.history['reconstruction_loss_train']['reconstruction_loss_train'], label='train')
plt.plot(model.history['reconstruction_loss_validation']['reconstruction_loss_validation'], label='validation')
plt.axhline(y, c = 'k')
plt.legend()
plt.show()

In [None]:
corrected_adata = model.batch_removal()
corrected_adata

In [None]:
print(corrected_adata.X[0:5, 0:5])

In [None]:
corrected_adata.layers['scgen_corrected'] = corrected_adata.X.copy()

In [None]:
# NOTE(review): assumes batch_removal() stored the corrected latent space in obsm['corrected_latent'].
sc.pp.neighbors(corrected_adata, n_neighbors=20, use_rep='corrected_latent')
sc.tl.umap(corrected_adata)
sc.pl.umap(corrected_adata,
           color=['libbatch', 'group', 'phase', 'ageint', 'annotations_orig'], 
           ncols=2, wspace=0.4, frameon=False,
          save='scgen_correctedcounts_integrated_sampletype.svg'
          )

In [None]:
sc.pl.umap(corrected_adata,
           color=['sampletype'], frameon=False)

In [None]:
corrected_adata.write(os.path.join(RESULTS_FOLDERNAME, 'devadult_scGen_sampletype.h5ad'))

In [None]:
import scranPY
from scipy.sparse import csr_matrix, issparse

# Size factors must be estimated from raw counts, but adata.X currently holds the
# log1p_norm layer (set up for scGen above), so start from the counts layer.
# scranPY receives a dense float64 matrix.
counts = adata.layers['counts']
adata.X = (counts.toarray() if issparse(counts) else np.asarray(counts)).astype(np.float64)

scranPY.compute_sum_factors(adata, clusters=None, parallelize=True, algorithm='CVXPY', sizes=np.arange(21, 102, 5), 
   max_size=3000, min_mean=None, plotting=True, lower_bound=0.1, normalize_counts=False, log1p=False, layer='scranPY', 
   save_plots_dir=FIGURES_FOLDERNAME, stopwatch=True)

# Divide each cell's counts by its size factor, log1p-transform, and store as a sparse layer.
scran = adata.X / adata.obs["size_factors"].values[:, None]
adata.layers["scranPY"] = csr_matrix(np.log1p(scran))
adata.X = adata.layers['counts'].copy()

print(adata.X[0:10, 0:10])
print(adata.layers["scranPY"][0:10, 0:10])