# Hypertune: grid search for hyperparameter tuning

The number of highly variable genes (HVGs) and the number of latent dimensions are very important for the scANVI algorithm. Therefore, we use [grid search](https://www.dremio.com/wiki/grid-search/) for hyperparameter tuning.

In [1]:
import scanpy as sc
import torch
import scarches as sca
from scarches.dataset.trvae.data_handling import remove_sparsity
import matplotlib.pyplot as plt
import numpy as np
import argparse
import anndata
import pandas as pd

2024-10-18 16:14:16.478401: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
INFO:pytorch_lightning.utilities.seed:Global seed set to 0
  new_rank_zero_deprecation(
  return new_rank_zero_deprecation(*args, **kwargs)
 captum (see https://github.com/pytorch/captum).


We use a subset of the whole dataset to speed up the tuning.

In [2]:
# Load the pre-processed AnnData object that all tuning runs start from.
adata_whole = sc.read(
    "../../process/pre-intergration/big_data/20241008_core_pp_log1p_half_gene_small_whole.h5ad"
)

In [3]:
# Display the AnnData summary (dimensions and the obs/var/uns/layers fields).
adata_whole

AnnData object with n_obs × n_vars = 50766 × 19640
    obs: 'Age', 'Core_datasets', 'Cre', 'Data Source', 'Data location', 'Development stage', 'Disease', 'Dissociation_enzyme', 'FACs', 'Gene Type', 'Histology', 'Journal', 'Knownout_gene', 'Machine', 'Mandibular_Maxillary', 'Molar_Incisor', 'Project', 'Related assay', 'Sample', 'Sex', 'Species', 'Stage', 'Strain', 'Tooth position', 'Treatment', 'coarse_anno_1', 'compl', 'log10_total_counts', 'log1p_n_genes_by_counts', 'log1p_total_counts', 'mito_frac', 'nCount_RNA', 'nFeature_RNA', 'n_genes_by_counts', 'n_genes_detected', 'orig.ident', 'pct_counts_in_top_100_genes', 'pct_counts_in_top_200_genes', 'pct_counts_in_top_500_genes', 'pct_counts_in_top_50_genes', 'ribo_frac', 'size_factors', 'total_counts'
    var: 'gene_symbol', 'gene_symbols', 'n_cells', 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'highly_variable_nbatches', 'highly_variable_intersection'
    uns: 'X_name', 'hvg', 'log1p'
    layers: 'counts'

In [4]:
# Keys shared by every tuning run.
# obs column passed to scanpy/scVI as the batch key.
condition_key: str = "Sample"
# obs column holding the cell-type annotation passed to scVI as labels_key.
labels_key: str = "coarse_anno_1"
# Label value handed to scANVI as the category marking unannotated cells.
unlabeled_category: str = "Unlabeled"

In [5]:
# Display the annotation column selected by labels_key.
adata_whole.obs[labels_key]

CGATCGGTACATCGGTTCGACATGGCT_3_8    Epithelium
TGGTACATCTGATGGT-1_2_1                 Immune
CTCTCAGTCAACTTTC-1_3_3             Mesenchyme
ATAGACCAGGAGAATG.2_16              Mesenchyme
GTCTACCCATGTAACC-1_2_1                 Immune
                                      ...    
ACCACAATCGCAGTCG.3_16              Epithelium
GGTCACGTCGCGCCAA-1_3_1             Mesenchyme
GTTGTCCCACCATAAC-1_3_1             Mesenchyme
AAAGGATTCTGAGAGG-1_3_3             Mesenchyme
GACTTCCGTGCAAGAC-1_2_1             Mesenchyme
Name: coarse_anno_1, Length: 50766, dtype: category
Categories (8, object): ['Endothelium', 'Epithelium', 'Immune', 'Mesenchyme', 'Muscle', 'Neuron', 'Perivascular', 'RBC']

We save the integration results in one folder and then use [scib](https://github.com/theislab/scib) to evaluate them.

In [12]:
def hypertune(hvg_n, nlatent):
    """Run one scVI -> scANVI integration for a single hyper-parameter setting.

    The function selects ``hvg_n`` highly variable genes, trains an scVI model
    with ``nlatent`` latent dimensions on the ``counts`` layer, converts it to
    a scANVI model, and stores the resulting latent embedding (with a UMAP and
    Leiden clustering) on disk.

    It relies on the notebook-level globals ``adata_whole``, ``condition_key``,
    ``labels_key`` and ``unlabeled_category``.

    Parameters
    ----------
    hvg_n
        Number of highly variable genes, passed to scanpy as ``n_top_genes``.
    nlatent
        Latent dimensionality, passed to scVI as ``n_latent``.

    Returns
    -------
    None
        Results are written as side effects: a UMAP figure saved through
        scanpy's ``save`` option and an ``.h5ad`` file named
        ``gene_{hvg_n}_latent_{nlatent}_umap.h5ad`` in the hypertune folder.

    Notes
    -----
    ``sc.pp.highly_variable_genes`` is called on ``adata_whole`` itself, so
    every call overwrites the HVG columns in ``adata_whole.var``.
    """
    # HVG selection is batch-aware (one batch per sample) and is recomputed for
    # every setting because the number of genes is one of the tuned parameters.
    sc.pp.highly_variable_genes(
        adata_whole,
        batch_key=condition_key,
        min_mean=0.035,
        flavor="cell_ranger",
        n_top_genes=hvg_n,
    )
    adata = adata_whole[:, adata_whole.var.highly_variable].copy()

    # Use the 'counts' layer as the model input and keep it in .raw as well.
    # A single assignment is enough: rebuilding .raw through raw.to_adata()
    # and re-assigning the same count matrix leaves identical contents.
    adata.X = adata.layers["counts"]
    adata.raw = adata

    sca.models.SCVI.setup_anndata(adata, batch_key=condition_key, labels_key=labels_key)
    vae = sca.models.SCVI(
        adata,
        n_latent=nlatent,
        encode_covariates=True,
        deeply_inject_covariates=False,
        use_layer_norm="both",
        use_batch_norm="none",
    )
    vae.train(max_epochs=40)

    # Continue training with label information via scANVI.
    scanvae = sca.models.SCANVI.from_scvi_model(vae, unlabeled_category=unlabeled_category)
    scanvae.train(max_epochs=20)

    # Wrap the latent space in its own AnnData and copy over the metadata
    # needed for plotting and later evaluation.
    # NOTE(review): the new object gets a default obs index; the cell barcodes
    # of `adata` are not carried over, so downstream code must match cells by
    # position - confirm this is intended.
    reference_latent = sc.AnnData(scanvae.get_latent_representation())
    for obs_column in (labels_key, condition_key, "Project"):
        reference_latent.obs[obs_column] = adata.obs[obs_column].tolist()

    sc.pp.neighbors(reference_latent, n_neighbors=8)
    sc.tl.leiden(reference_latent)
    sc.tl.umap(reference_latent)

    # One shared stem keeps the figure and the .h5ad file names in sync.
    run_tag = f"gene_{hvg_n}_latent_{nlatent}_umap"
    sc.pl.umap(
        reference_latent,
        color=[labels_key, condition_key],
        frameon=False,
        wspace=0.6,
        save=run_tag,
    )
    reference_latent.write(f"../../process/pre-intergration/hypertune/{run_tag}.h5ad")

In [13]:
# Load the headerless parameter table; column 0 is used as the HVG count and
# column 1 as the latent size when the grid is run.
hyperPara = pd.read_table(
    "../../data/hypertune/hyperpara",
    header=None,
)

In [None]:
# Run every (HVG count, latent size) pair listed in the parameter table.
# Iterating over the rows themselves avoids a hard-coded row count, which
# would skip extra rows or fail with a KeyError on a shorter table.
for hvg_n, nlatent in zip(hyperPara[0], hyperPara[1]):
    hypertune(hvg_n=hvg_n, nlatent=nlatent)

In [15]:
# Display the AnnData summary of adata_whole again.
adata_whole

AnnData object with n_obs × n_vars = 50766 × 19640
    obs: 'Age', 'Core_datasets', 'Cre', 'Data Source', 'Data location', 'Development stage', 'Disease', 'Dissociation_enzyme', 'FACs', 'Gene Type', 'Histology', 'Journal', 'Knownout_gene', 'Machine', 'Mandibular_Maxillary', 'Molar_Incisor', 'Project', 'Related assay', 'Sample', 'Sex', 'Species', 'Stage', 'Strain', 'Tooth position', 'Treatment', 'coarse_anno_1', 'compl', 'log10_total_counts', 'log1p_n_genes_by_counts', 'log1p_total_counts', 'mito_frac', 'nCount_RNA', 'nFeature_RNA', 'n_genes_by_counts', 'n_genes_detected', 'orig.ident', 'pct_counts_in_top_100_genes', 'pct_counts_in_top_200_genes', 'pct_counts_in_top_500_genes', 'pct_counts_in_top_50_genes', 'ribo_frac', 'size_factors', 'total_counts'
    var: 'gene_symbol', 'gene_symbols', 'n_cells', 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'highly_variable_nbatches', 'highly_variable_intersection'
    uns: 'X_name', 'hvg', 'log1p'
    layers: 'counts'