In [None]:
import torch
import scvi
import anndata as ad
import scanpy as sc
import pandas as pd
import numpy as np
import scipy

In [None]:
def malignant_cell_collection(adata, malignant_cell_incices, label_key):
    """Subset ``adata`` to cells whose label is one of the selected unique labels.

    Args:
        adata: AnnData-like object; ``adata.obs[label_key]`` holds per-cell labels.
        malignant_cell_incices: Positions into ``adata.obs[label_key].unique()``
            (labels in order of first appearance) selecting the labels to keep.
        label_key: Column of ``adata.obs`` with the cell labels.

    Returns:
        ``adata`` indexed with the boolean mask of matching cells (for an AnnData
        this is presumably a view — call ``.copy()`` if an independent object is needed).
    """
    labels = adata.obs[label_key]
    chosen_labels = labels.unique()[list(malignant_cell_incices)]
    keep_mask = labels.isin(chosen_labels)
    return adata[keep_mask]

In [None]:
def model_preprocessing(adata, batch_key, target_sum=10e4, n_top_genes=2000):
    """Prepare ``adata`` in place for scVI training.

    Steps:
      1. Keep the raw matrix in ``adata.layers["counts"]`` (created from ``adata.X``
         if absent) and round it to integers if it holds non-integer values.
      2. Normalize ``adata.X`` to ``target_sum`` counts per cell and log1p it.
      3. Freeze the log-normalized data in ``adata.raw``.
      4. Subset ``adata`` to ``n_top_genes`` highly variable genes selected on the
         ``"counts"`` layer, per ``batch_key``.

    Args:
        adata: AnnData object, modified in place.
        batch_key: ``adata.obs`` column forwarded to HVG selection.
        target_sum: Counts per cell after normalization. Defaults to ``10e4``
            (i.e. 1e5) to preserve the historical value; the other preprocessing
            helpers in this notebook use ``1e4``.
        n_top_genes: Number of highly variable genes to keep (default 2000).
    """
    if "counts" not in adata.layers.keys():
        adata.layers["counts"] = adata.X.copy()

    counts = adata.layers["counts"]
    if scipy.sparse.issparse(counts):
        # Only stored entries can be fractional (implicit zeros are integral), so
        # check and round ``.data`` directly. This keeps the layer sparse instead of
        # densifying it into an ``np.matrix`` with a Python-level loop.
        if counts.format not in ("csr", "csc"):
            counts = counts.tocsr()
        if np.any(counts.data % 1 != 0):
            rounded = counts.copy()
            rounded.data = np.round(rounded.data)
            adata.layers["counts"] = rounded
    elif np.any(np.asarray(counts) % 1 != 0):
        adata.layers["counts"] = np.round(counts)

    sc.pp.normalize_total(adata, target_sum=target_sum)
    sc.pp.log1p(adata)
    adata.raw = adata
    # NOTE(review): scanpy's default ``flavor="seurat"`` expects log-normalized data,
    # but selection runs on the raw ``"counts"`` layer here; ``flavor="seurat_v3"``
    # is the variant designed for counts — confirm the intent.
    sc.pp.highly_variable_genes(
        adata,
        n_top_genes=n_top_genes,
        subset=True,
        layer="counts",
        batch_key=batch_key,
    )

In [None]:
def model_train(adata, layer_key, batch_key, n_layers, n_hidden, n_latent, lr,
                max_epochs=100, early_stopping_patience=20):
    """Register a copy of ``adata`` with scVI and train an SCVI model on it.

    ``adata`` itself is not modified: the model is set up on ``adata.copy()``.

    Args:
        adata: AnnData with the counts in ``adata.layers[layer_key]``.
        layer_key: Layer used by scVI.
        batch_key: ``adata.obs`` column with the batch labels.
        n_layers: Number of hidden layers.
        n_hidden: Hidden units per layer.
        n_latent: Latent dimensionality.
        lr: Learning rate (forwarded via ``plan_kwargs``).
        max_epochs: Maximum training epochs (default 100, as before).
        early_stopping_patience: Forwarded to scVI's early stopping (default 20, as
            before). NOTE: validation runs every 5 epochs here and Lightning's
            patience counts validation checks. With the defaults, 20 checks equal
            100 epochs, so early stopping likely never fires.

    Returns:
        The trained ``scvi.model.SCVI`` instance.
    """
    adata = adata.copy()
    scvi.model.SCVI.setup_anndata(adata, layer=layer_key, batch_key=batch_key)
    model = scvi.model.SCVI(adata, n_layers=n_layers, n_hidden=n_hidden, n_latent=n_latent)
    model.train(
        max_epochs=max_epochs,
        validation_size=0.1,
        check_val_every_n_epoch=5,
        early_stopping=True,
        early_stopping_monitor="elbo_validation",
        early_stopping_patience=early_stopping_patience,
        plan_kwargs={"lr": lr},
    )
    return model

In [None]:
def plot_reconstruction_loss_and_elbo(model):
    """Overlay scVI training curves on a single axis.

    Plots, in order, ``reconstruction_loss_train``, ``elbo_train``,
    ``elbo_validation`` and ``reconstruction_loss_validation`` from
    ``model.history``. Each is drawn via its own ``.plot`` method onto the axis
    created by the first one.

    Args:
        model: Trained scVI model exposing a ``history`` mapping.

    Returns:
        The shared axis, so callers can add labels or save the figure. Previously
        the axis was discarded and ``None`` was returned.
    """
    history_keys = (
        "reconstruction_loss_train",
        "elbo_train",
        "elbo_validation",
        "reconstruction_loss_validation",
    )
    # Look every curve up first so a missing key fails before anything is drawn.
    curves = [model.history[key] for key in history_keys]
    ax = curves[0].plot()
    for curve in curves[1:]:
        curve.plot(ax=ax)
    return ax

In [None]:
def get_latent_UMAP(model, adata, batch_key, label_key, added_latent_key, print_UMAP,
                    n_neighbors=20, min_dist=0.3):
    """Store the model's latent space in ``adata`` and compute a UMAP from it.

    Args:
        model: Trained model exposing ``get_latent_representation()``.
        adata: AnnData modified in place. ``obsm[added_latent_key]`` receives the
            latent space; the neighbour graph and UMAP are then computed on it.
        batch_key: ``obs`` column used to colour the plot.
        label_key: ``obs`` column used to colour the plot.
        added_latent_key: ``obsm`` key for the latent representation.
        print_UMAP: If true, draw the UMAP coloured by ``label_key`` and ``batch_key``.
        n_neighbors: Neighbourhood size for ``sc.pp.neighbors`` (default 20, as before).
        min_dist: ``min_dist`` for ``sc.tl.umap`` (default 0.3, as before).

    Returns:
        ``added_latent_key``.
    """
    adata.obsm[added_latent_key] = model.get_latent_representation()
    sc.pp.neighbors(adata, use_rep=added_latent_key, n_neighbors=n_neighbors)
    sc.tl.umap(adata, min_dist=min_dist)
    if print_UMAP:
        sc.pl.umap(adata, color=[label_key, batch_key])
    return added_latent_key

In [None]:
from typing import Optional, List

import numpy as np
import pandas as pd
import scanpy as sc
from anndata import AnnData
from scipy import sparse


def save_latent(adata: AnnData, latent_key: str, dataset_name: str) -> None:
    """Write ``adata.obsm[latent_key]`` to ``{dataset_name}_latent.csv``.

    Rows are indexed by ``adata.obs_names``. Nothing is written when
    ``latent_key`` is not among ``adata.obsm_keys()``.
    """
    if latent_key not in adata.obsm_keys():
        return
    latent_df = pd.DataFrame(adata.obsm[latent_key], index=adata.obs_names)
    latent_df.to_csv(f"{dataset_name}_latent.csv")

def plot_integration(
        adata: AnnData, dataset_name: str, batch_key: str, group_key: str
) -> None:
    """Compute a UMAP and save it coloured by batch and group.

    scanpy's figure directory is set to the current directory. The plot is saved
    with the ``save`` suffix ``_{dataset_name}.png`` and is not shown.
    """
    colour_keys = [batch_key, group_key]
    png_suffix = f"_{dataset_name}.png"
    sc.settings.figdir = "."
    sc.tl.umap(adata)
    sc.pl.umap(adata, color=colour_keys, save=png_suffix, show=False)


def get_partition(gpu: bool) -> str:
    """Return the scheduler partition: ``"gpu"`` when ``gpu`` is truthy, else ``"compute"``."""
    return "gpu" if gpu else "compute"


def get_gres(gpu: bool) -> Optional[str]:
    """Return the generic-resource string requesting one GPU, or ``None`` when ``gpu`` is falsy."""
    return "gpu:rtx2080ti:1" if gpu else None


def diffusion_nn(adata, k, max_iterations=26):
    """
    Diffusion neighbourhood score
    This function generates a nearest neighbour list from a connectivities matrix
    as supplied by BBKNN or Conos. This allows us to select a consistent number
    of nearest neighbours across all methods.

    The row-normalised connectivities ``T`` are diffused as the matrix-power sum
    ``M = T + T^2 + T^3 + ...``. Diffusion stops once every row of ``M`` has at
    least ``k + 1`` positive entries (``+ 1`` for the self-loop on the diagonal)
    or once ``max_iterations`` is reached.

    Args:
        adata: AnnData-like object with ``adata.uns["neighbors"]`` and a scipy
            sparse ``adata.obsp["connectivities"]``.
        k: Number of neighbours to return per cell.
        max_iterations: Upper bound on the diffusion power.

    Return:
       `k_indices` a numpy.ndarray of shape ``(n_obs, k)`` with the indices of the
       k-nearest neighbors of each cell (unordered within a row, self excluded).

    Raises:
        ValueError: if the neighbour graph is missing, or if ``k`` neighbours
            cannot be reached within ``max_iterations`` diffusion steps.
    """
    if "neighbors" not in adata.uns:
        raise ValueError(
            "`neighbors` not in adata object. " "Please compute a neighbourhood graph!"
        )

    if "connectivities" not in adata.obsp:
        raise ValueError(
            "`connectivities` not in `adata.obsp`. "
            "Please pass an object with connectivities computed!"
        )

    T = adata.obsp["connectivities"]

    # Row-normalize T. ``np.asarray`` handles both the np.matrix returned by
    # sparse matrices and the ndarray returned by sparse arrays.
    T = sparse.diags(1 / np.asarray(T.sum(1)).ravel()) @ T

    # ``@`` is a matrix product for scipy sparse matrices *and* sparse arrays,
    # whereas ``**`` / ``*`` are element-wise on sparse arrays.
    T_agg = T @ T @ T
    M = T + T @ T + T_agg
    i = 4

    while ((M > 0).sum(1).min() < (k + 1)) and (i < max_iterations):
        # note: k+1 is used as diag is non-zero (self-loops)
        print(f"Adding diffusion to step {i}")
        T_agg = T_agg @ T
        M = M + T_agg
        i += 1

    if (M > 0).sum(1).min() < (k + 1):
        raise ValueError(
            f"could not find {k} nearest neighbors in {max_iterations} "
            "diffusion steps.\n Please increase max_iterations or reduce"
            " k.\n"
        )

    M.setdiag(0)
    # ``toarray()`` replaces the ``.A`` shortcut, which recent SciPy releases
    # removed from sparse matrices.
    k_indices = np.argpartition(M.toarray(), -k, axis=1)[:, -k:]

    return k_indices


def diffusion_conn(adata, min_k=50, copy=True, max_iterations=26):
    """
    Diffusion for connectivites matrix extension
    This function performs graph diffusion on the connectivities matrix until a
    minimum number `min_k` of entries per row are non-zero.

    Rows belonging to connected components with fewer than ``min_k`` cells are
    ignored by the stopping criterion, since they can never reach ``min_k``.

    Note:
    Due to self-loops min_k-1 non-zero connectivies entries is actually the stopping
    criterion. This is equivalent to `sc.pp.neighbors`.

    Returns:
       If ``copy`` is True, a copy of the AnnData object whose
       ``adata.uns["neighbors"]["diffusion_connectivities"]`` holds the
       diffusion-enhanced connectivities matrix; otherwise the matrix itself
       (with a zero diagonal).

    Raises:
        ValueError: if the neighbour graph is missing or ``min_k`` non-zero
            entries cannot be reached within ``max_iterations`` steps.
    """
    if "neighbors" not in adata.uns:
        raise ValueError(
            "`neighbors` not in adata object. " "Please compute a neighbourhood graph!"
        )

    if "connectivities" not in adata.obsp:
        raise ValueError(
            "`connectivities` not in `adata.obsp`. "
            "Please pass an object with connectivities computed!"
        )

    T = adata.obsp["connectivities"]

    # Normalize T with max row sum
    # Note: This keeps the matrix symmetric and ensures |M| doesn't keep growing
    max_row_sum = T.sum(1).max()
    T = sparse.diags(1 / np.full(T.shape[0], max_row_sum)) @ T

    M = T

    # Check for disconnected component
    n_comp, labs = sparse.csgraph.connected_components(
        adata.obsp["connectivities"], connection="strong"
    )

    if n_comp > 1:
        # ``pd.value_counts`` is deprecated (removed in pandas 3); use a Series.
        tab = pd.Series(labs).value_counts()
        small_comps = tab.index[tab < min_k]
        large_comp_mask = np.array(~pd.Series(labs).isin(small_comps))
    else:
        large_comp_mask = np.array([True] * M.shape[0])

    T_agg = T
    i = 2
    # ``@`` is a matrix product for both sparse matrices and sparse arrays.
    while ((M[large_comp_mask, :][:, large_comp_mask] > 0).sum(1).min() < min_k) and (
            i < max_iterations
    ):
        print(f"Adding diffusion to step {i}")
        T_agg = T_agg @ T
        M = M + T_agg
        i += 1

    if (M[large_comp_mask, :][:, large_comp_mask] > 0).sum(1).min() < min_k:
        raise ValueError(
            "could not create diffusion connectivities matrix "
            f"with at least {min_k} non-zero entries in "
            f"{max_iterations} iterations.\n Please increase the "
            "value of max_iterations or reduce min_k.\n"
        )

    M.setdiag(0)

    if copy:
        adata_tmp = adata.copy()
        adata_tmp.uns["neighbors"].update({"diffusion_connectivities": M})
        return adata_tmp

    else:
        return M


def split_batches(adata: AnnData, batch_key: str) -> List[AnnData]:
    """Split ``adata`` into one independent copy per batch category.

    ``adata.obs[batch_key]`` must be categorical. The result follows the order of
    its ``.cat.categories``; a category without cells yields an empty object.
    """
    batch_labels = adata.obs[batch_key]
    return [
        adata[batch_labels == category].copy()
        for category in batch_labels.cat.categories
    ]

In [None]:
from typing import Literal  # pytype: disable=not-supported-yet
from typing import Optional

import anndata as an  # pytype: disable=import-error
import numpy as np
import pydantic  # pytype: disable=import-error
import scanpy as sc  # pytype: disable=import-error
from anndata import AnnData

# Distance metrics accepted for ``NeighborsGraphConfig.metric``. The field is typed
# with this ``Literal``, so pydantic rejects any other string during validation.
# NOTE(review): "kulsinski" was deprecated and later removed from
# ``scipy.spatial.distance``; confirm the neighbour-search backend still accepts it.
_SupportedMetric = Literal[
    "cityblock",
    "cosine",
    "euclidean",
    "l1",
    "l2",
    "manhattan",
    "braycurtis",
    "canberra",
    "chebyshev",
    "correlation",
    "dice",
    "hamming",
    "jaccard",
    "kulsinski",
    "mahalanobis",
    "minkowski",
    "rogerstanimoto",
    "russellrao",
    "seuclidean",
    "sokalmichener",
    "sokalsneath",
    "sqeuclidean",
    "yule",
]


class NeighborsGraphConfig(pydantic.BaseModel):
    """Settings for neighborhood graph computation.
    For description, see
    https://scanpy.readthedocs.io/en/stable/generated/scanpy.pp.neighbors.html

    The field names match keyword arguments of ``scanpy.pp.neighbors``.
    NOTE(review): ``LeidenNCluster`` in this notebook does not read these settings.
    It clusters whatever graph is already stored in ``adata``; confirm where they apply.
    """

    n_neighbors: int = pydantic.Field(default=15)
    n_pcs: Optional[int] = pydantic.Field(default=None)
    knn: bool = pydantic.Field(default=True)
    # TODO(Pawel): Check whether we can support other methods as well.
    # Only "umap" is accepted for now (enforced by the Literal type).
    method: Literal["umap"] = pydantic.Field(default="umap")
    # Must be one of the names listed in ``_SupportedMetric``.
    metric: _SupportedMetric = pydantic.Field(default="euclidean")


class _LeidenBaseConfig(pydantic.BaseModel):
    """Options shared by the Leiden clustering configurations.

    ``random_state``, ``directed`` and ``use_weights`` are forwarded to
    ``sc.tl.leiden`` by ``LeidenNCluster``.
    NOTE(review): ``nngraph`` and ``n_iterations`` are not read by the clustering
    code visible here; confirm they are used elsewhere.
    """

    nngraph: NeighborsGraphConfig = pydantic.Field(default_factory=NeighborsGraphConfig)
    # Base seed; LeidenNCluster adds fixed offsets to it when retrying.
    random_state: int = pydantic.Field(default=0)
    directed: bool = pydantic.Field(default=True)
    use_weights: bool = pydantic.Field(default=True)
    n_iterations: int = pydantic.Field(default=-1)


class BinSearchSettings(pydantic.BaseModel):
    """Resolution interval and tolerance for the Leiden resolution binary search."""

    start: pydantic.PositiveFloat = pydantic.Field(
        default=1e-3, description="The minimal resolution."
    )
    end: pydantic.PositiveFloat = pydantic.Field(
        default=10.0, description="The maximal resolution."
    )
    epsilon: pydantic.PositiveFloat = pydantic.Field(
        default=1e-3,
        description="Controls the maximal number of iterations before throwing lookup "
        "error.",
    )

    @pydantic.validator("end")
    def validate_end_greater_than_start(cls, v, values, **kwargs) -> float:
        """Reject configurations where ``end <= start``.

        ``values`` only contains fields that already validated successfully. If
        ``start`` was invalid (e.g. non-positive), it is missing here and pydantic
        already reports that error. In that case the comparison is skipped instead
        of crashing with a raw ``KeyError``.
        """
        start = values.get("start")
        if start is not None and v <= start:
            raise ValueError("In binary search end must be greater than start.")
        return v


class LeidenNClusterConfig(_LeidenBaseConfig):
    """Configuration for ``LeidenNCluster``: the target cluster count plus the
    resolution binary-search settings."""

    clusters: int = pydantic.Field(
        default=5, description="The number of clusters to be returned."
    )
    binsearch: BinSearchSettings = pydantic.Field(default_factory=BinSearchSettings)


class LeidenNCluster:
    """Leiden clustering that searches for a resolution giving a fixed cluster count."""

    def __init__(self, settings: LeidenNClusterConfig) -> None:
        self._settings = settings

    def fit_predict(self, adata: AnnData, key_added: str) -> np.ndarray:
        """Cluster ``adata`` into ``settings.clusters`` Leiden clusters.

        The resolution binary search is tried with up to four seeds
        (``random_state`` plus 0, 20 000, 30 000 and 40 000). The first seed that
        reaches the requested count wins. The search writes the Leiden labels to
        ``adata.obs[key_added]``.

        Returns:
            The cluster labels as an integer numpy array.

        Raises:
            ValueError: if no seed yields exactly ``settings.clusters`` clusters.
        """
        cfg = self._settings
        search = cfg.binsearch
        for seed_offset in (0, 20_000, 30_000, 40_000):
            clustered = _binary_search_leiden_resolution(
                adata,
                k=cfg.clusters,
                key_added=key_added,
                random_state=cfg.random_state + seed_offset,
                directed=cfg.directed,
                use_weights=cfg.use_weights,
                start=search.start,
                end=search.end,
                _epsilon=search.epsilon,
            )
            if clustered is not None:
                return clustered.obs[key_added].astype(int).values

        # None of the seeds produced the requested number of clusters.
        raise ValueError(
            f"No resolution for the number of clusters {cfg.clusters}"
            f" found."
        )


def _binary_search_leiden_resolution(
    adata: an.AnnData,
    k: int,
    start: float,
    end: float,
    key_added: str,
    random_state: int,
    directed: bool,
    use_weights: bool,
    _epsilon: float,
) -> Optional[an.AnnData]:
    """Binary search to get the resolution corresponding
    to the right k."""
    # We try the resolution which is in the middle of the interval
    res = 0.5 * (start + end)

    # Run Leiden clustering
    sc.tl.leiden(
        adata,
        resolution=res,
        key_added=key_added,
        random_state=random_state,
        directed=directed,
        use_weights=use_weights,
    )

    # Get the number of clusters found
    selected_k = adata.obs[key_added].nunique()
    if selected_k == k:
        return adata

    # If the start and the end are too close (and there is no point in doing another
    # iteration), we raise an error that one can't find the required number of clusters
    if abs(end - start) < _epsilon * res:
        return None

    if selected_k > k:
        return _binary_search_leiden_resolution(
            adata,
            k=k,
            start=start,
            end=res,
            key_added=key_added,
            random_state=random_state,
            directed=directed,
            _epsilon=_epsilon,
            use_weights=use_weights,
        )
    else:
        return _binary_search_leiden_resolution(
            adata,
            k=k,
            start=res,
            end=end,
            key_added=key_added,
            random_state=random_state,
            directed=directed,
            _epsilon=_epsilon,
            use_weights=use_weights,
        )

In [None]:
import warnings
from dataclasses import dataclass, field
from pathlib import Path
from timeit import default_timer as timer
from typing import Tuple, List, Optional, Union

import bbknn
import numpy as np
import pandas as pd
import scanpy as sc
import scanpy.external as sce
import scvi
from anndata import AnnData
from cansig.integration.model import CanSig
from omegaconf import MISSING

#from utils import split_batches


@dataclass
class ModelConfig:
    """Base configuration shared by every integration method dispatched by ``run_model``."""
    # Method identifier; ``MISSING`` is omegaconf's placeholder for a required value.
    name: str = MISSING
    # Whether the method should run with GPU resources (e.g. run_desc passes it as use_GPU).
    gpu: bool = False
    # NOTE(review): presumably restricts the input to malignant cells — confirm where it is read.
    malignant_only: bool = True
    # adata.obs column holding the batch / sample labels.
    batch_key: str = "sample_id"
    # adata.obsm key under which the run_* functions store the embedding.
    latent_key: str = "latent"
    # Number of highly variable genes selected during preprocessing.
    n_top_genes: int = 2000



def run_model(adata: AnnData, cfg) -> Tuple[AnnData, float]:
    """Run the integration method named by ``cfg.name`` and time it.

    Args:
        adata: Input passed to the selected ``run_*`` function.
        cfg: Method configuration. ``cfg.name`` selects the method and the whole
            object is forwarded as ``config``.

    Returns:
        ``(result, run_time)``: the selected method's return value and the elapsed
        wall-clock seconds measured with ``timeit.default_timer``.

    Raises:
        NotImplementedError: if ``cfg.name`` is not a registered method.
    """
    started_at = timer()
    # Lambdas defer resolving each runner until it is actually selected.
    runners = {
        "bbknn": lambda: run_bbknn(adata, config=cfg),
        "scvi": lambda: run_scvi(adata, config=cfg),
        "scanorama": lambda: run_scanorama(adata, config=cfg),
        "harmony": lambda: run_harmony(adata, config=cfg),
        "cansig": lambda: run_cansig(adata, config=cfg),
        "nmm": lambda: run_mnn(adata, config=cfg),  # MNN is registered under "nmm"
        "combat": lambda: run_combat(adata, config=cfg),
        "desc": lambda: run_desc(adata, config=cfg),
        "dhaka": lambda: run_dhaka(adata, config=cfg),
        "scanvi": lambda: run_scanvi(adata, config=cfg),
        "trvaep": lambda: run_trvaep(adata, config=cfg),
        "scgen": lambda: run_scgen(adata, config=cfg),
    }
    runner = runners.get(cfg.name)
    if runner is None:
        raise NotImplementedError(f"{cfg.name} is not implemented.")
    result = runner()
    return result, timer() - started_at


@dataclass
class DhakaConfig(ModelConfig):
    """Hyper-parameters for ``run_dhaka``.

    The Dhaka-specific fields (``n_latent`` through ``scale_reconstruction_loss``)
    are copied one-to-one into ``dhaka.api.DhakaConfig``.
    """
    name: str = "dhaka"
    gpu: bool = True

    n_latent: int = 3
    # Data preprocessing
    n_genes: int = 5000
    total_expression: float = 1e6
    pseudocounts: int = 1
    # Training
    epochs: int = 5
    batch_size: int = 50
    learning_rate: float = 1e-4
    clip_norm: float = 2.0
    # Magic flag
    # NOTE(review): forwarded verbatim to dhaka.api.DhakaConfig; its exact effect is
    # defined by the dhaka package — confirm there.
    scale_reconstruction_loss: bool = True


def run_dhaka(adata: AnnData, config: DhakaConfig) -> AnnData:
    """Embed ``adata`` with Dhaka, storing the latent space under ``config.latent_key``.

    ``dhaka.api`` is imported lazily, so the package is only needed when this
    method runs. The hyper-parameters named in ``forwarded`` are copied unchanged
    from ``config`` into ``dhaka.api.DhakaConfig``.
    """
    import dhaka.api as dh

    forwarded = (
        "n_latent",
        "n_genes",
        "total_expression",
        "pseudocounts",
        "epochs",
        "batch_size",
        "learning_rate",
        "clip_norm",
        "scale_reconstruction_loss",
    )
    dhaka_config = dh.DhakaConfig(**{field_name: getattr(config, field_name) for field_name in forwarded})
    return dh.run_dhaka(adata, config=dhaka_config, key_added=config.latent_key)


@dataclass
class ScanVIConfig(ModelConfig):
    """Configuration placeholder for scANVI; ``run_scanvi`` currently always raises."""
    name: str = "scanvi"
    malignant_only: bool = False


def run_scanvi(adata: AnnData, config: ScanVIConfig) -> AnnData:
    """Placeholder for scANVI integration.

    Raises:
        NotImplementedError: always, because scANVI needs several cell types.
    """
    message = "This method requires several celltypes to run."
    raise NotImplementedError(message)


@dataclass
class TrVAEpConfig(ModelConfig):
    """Hyper-parameters for ``run_trvaep`` (trVAE, PyTorch implementation)."""
    name: str = "trvaep"
    n_top_genes: int = 3000
    n_latent: int = 10  # latent_dim of trvaep.CVAE
    alpha: float = 1e-4  # forwarded as CVAE(alpha=...)
    layer1: int = 64  # first encoder / last decoder layer width
    layer2: int = 32  # second encoder / first decoder layer width
    seed: int = 42  # Random seed
    # Training params
    epochs: int = 300
    batch_size: int = 1024
    early_patience: int = 50
    learning_rate: float = 1e-3


def _trvaep_normalize(adata: AnnData, n_top_genes: int) -> AnnData:
    sc.pp.normalize_per_cell(adata)
    sc.pp.log1p(adata)
    sc.pp.highly_variable_genes(adata, n_top_genes=n_top_genes)
    adata = adata[:, adata.var['highly_variable']]
    return adata


def run_trvaep(adata: AnnData, config: TrVAEpConfig) -> AnnData:
    """trVAE (PyTorch version) wrapper function. It's a slightly modified scIB code.

    Normalizes and HVG-subsets ``adata`` via ``_trvaep_normalize``, then trains a
    conditional VAE (MMD-regularized) conditioned on ``config.batch_key``. It stores
    in ``obsm[config.latent_key]`` the ``model.get_y`` output obtained by
    conditioning every cell on the most frequent batch.

    Returns:
        The object returned by ``_trvaep_normalize`` with the latent representation added.
    """
    import trvaep
    from scipy.sparse import issparse

    # Number of condition classes, counted before preprocessing.
    n_batches = adata.obs[config.batch_key].nunique()

    adata = _trvaep_normalize(adata, n_top_genes=config.n_top_genes)

    # Densify the data matrix. ``toarray()`` replaces the ``.A`` shortcut, which
    # recent SciPy releases removed from sparse matrices.
    if issparse(adata.X):
        adata.X = adata.X.toarray()

    model = trvaep.CVAE(
        adata.n_vars,
        num_classes=n_batches,
        encoder_layer_sizes=[config.layer1, config.layer2],  # Originally [64, 32]
        decoder_layer_sizes=[config.layer2, config.layer1],  # Originally [32, 64]
        latent_dim=config.n_latent,
        alpha=config.alpha,
        use_mmd=True,
        beta=1,
        output_activation="ReLU",
    )

    # Note: set seed for reproducibility of results
    trainer = trvaep.Trainer(
        model,
        adata,
        condition_key=config.batch_key,
        seed=config.seed,
        learning_rate=config.learning_rate
    )

    trainer.train_trvae(
        n_epochs=config.epochs,
        batch_size=config.batch_size,
        early_patience=config.early_patience
    )

    # Get the dominant batch covariate
    main_batch = adata.obs[config.batch_key].value_counts().idxmax()

    # Get latent representation
    latent_y = model.get_y(
        adata.X,
        c=model.label_encoder.transform(np.tile(np.array([main_batch]), len(adata))),
    )
    adata.obsm[config.latent_key] = latent_y

    return adata


@dataclass
class ScGENConfig(ModelConfig):
    """Configuration placeholder for scGEN; ``run_scgen`` currently always raises.

    ``@dataclass`` is required here, as on every other config. Without it the class
    reuses ``ModelConfig``'s generated ``__init__``, which assigns ``name=MISSING``
    and ``malignant_only=True`` to each instance. Those instance values shadow the
    defaults declared below.
    """
    name: str = "scgen"
    malignant_only: bool = False  # Probably -- hard to be 100% sure


def run_scgen(adata: AnnData, config: ScGENConfig) -> AnnData:
    """Placeholder for scGEN integration.

    Raises:
        NotImplementedError: always. scIB's scGEN yields no low-dimensional
            embedding and needs other cell types.
    """
    reason = (
        "scGEN model in scIB doesn't add low-dimensional representations, "
        "so that the implementation is tricky. Moreover, it requires other cell types."
    )
    raise NotImplementedError(reason)


@dataclass
class BBKNNConfig(ModelConfig):
    """Hyper-parameters for ``run_bbknn``."""
    name: str = "bbknn"
    # Forwarded to ``bbknn.bbknn(neighbors_within_batch=...)``.
    neighbors_within_batch: int = 3


def run_bbknn(adata: AnnData, config: BBKNNConfig) -> AnnData:
    """Build a batch-balanced kNN graph on the PCA of scaled, log-normalized HVGs.

    ``adata`` is modified in place: normalized to 1e4 counts per cell, log1p,
    subset to ``config.n_top_genes`` HVGs, scaled and reduced with PCA.
    ``bbknn.bbknn`` then runs on it with ``config.batch_key``. The same object is
    returned.
    """
    pp = sc.pp
    pp.normalize_total(adata, target_sum=1e4)
    pp.log1p(adata)
    pp.highly_variable_genes(adata, n_top_genes=config.n_top_genes, subset=True)
    pp.scale(adata)
    sc.tl.pca(adata)

    bbknn.bbknn(
        adata,
        batch_key=config.batch_key,
        neighbors_within_batch=config.neighbors_within_batch,
    )
    return adata


@dataclass
class SCVIConfig(ModelConfig):
    """Hyper-parameters for ``run_scvi``."""
    name: str = "scvi"
    gpu: bool = True
    # adata.obs columns passed to scVI as continuous covariates
    # (``S_score`` / ``G2M_score`` by default).
    covariates: Optional[List] = field(
        default_factory=lambda: ["S_score", "G2M_score"]
    )
    n_latent: int = 4
    n_hidden: int = 128
    n_layers: int = 1
    # Also used as the KL warm-up length (``n_epochs_kl_warmup``) in run_scvi.
    max_epochs: int = 400


def run_scvi(adata: AnnData, config: SCVIConfig) -> AnnData:
    """Integrate with scVI trained on the raw ``"counts"`` layer of the HVGs.

    ``adata`` is normalized and log-transformed in place only to select
    ``config.n_top_genes`` HVGs. scVI is set up on a copy restricted to those genes,
    using the ``"counts"`` layer, ``config.batch_key`` and ``config.covariates`` as
    continuous covariates. Training uses all cells (``train_size=1.0``), with KL
    warm-up over all ``config.max_epochs``.

    Returns:
        ``adata`` (all genes) with the scVI latent space in ``obsm[config.latent_key]``.
    """
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)
    sc.pp.highly_variable_genes(adata, n_top_genes=config.n_top_genes)
    hvg_mask = adata.var["highly_variable"]
    hvg_data = adata[:, hvg_mask].copy()

    scvi.model.SCVI.setup_anndata(
        hvg_data,
        layer="counts",
        batch_key=config.batch_key,
        continuous_covariate_keys=config.covariates,
    )
    vae = scvi.model.SCVI(
        hvg_data,
        n_latent=config.n_latent,
        n_hidden=config.n_hidden,
        n_layers=config.n_layers,
    )
    vae.train(
        max_epochs=config.max_epochs,
        # TODO: add this to cansig!
        train_size=1.0,
        plan_kwargs={"n_epochs_kl_warmup": config.max_epochs},
    )

    adata.obsm[config.latent_key] = vae.get_latent_representation()
    return adata


@dataclass
class ScanoramaConfig(ModelConfig):
    """Hyper-parameters for ``run_scanorama``; ``knn``, ``sigma``, ``approx`` and
    ``alpha`` are forwarded to ``sce.pp.scanorama_integrate``."""
    name: str = "scanorama"
    knn: int = 20
    sigma: float = 15.0
    approx: bool = True
    alpha: float = 0.1


def run_scanorama(adata: AnnData, config: ScanoramaConfig) -> AnnData:
    """Integrate with Scanorama on the PCA of Zheng17-preprocessed data.

    Returns:
        A copy of ``adata`` with cells grouped by batch, Zheng17-preprocessed, with
        the Scanorama-adjusted embedding in ``obsm[config.latent_key]``.
    """
    # scanorama requires that cells from the same batch must
    # be contiguously stored in adata.
    # Sort by integer category codes with a stable sort. This yields a plain
    # positional index (``np.argsort`` on a pandas Series returns a Series), keeps
    # the original cell order inside each batch, and handles missing labels
    # (code -1) without comparing mixed types.
    batch_codes = pd.Categorical(adata.obs[config.batch_key]).codes
    idx = np.argsort(batch_codes, kind="stable")
    adata = adata[idx, :].copy()
    sc.pp.recipe_zheng17(adata, n_top_genes=config.n_top_genes)
    sc.tl.pca(adata)
    sce.pp.scanorama_integrate(
        adata,
        config.batch_key,
        adjusted_basis=config.latent_key,
        knn=config.knn,
        sigma=config.sigma,
        approx=config.approx,
        alpha=config.alpha,
    )
    return adata


@dataclass
class HarmonyConfig(ModelConfig):
    """Hyper-parameters for ``run_harmony``; every Harmony-specific field is
    forwarded to ``sce.pp.harmony_integrate``."""
    name: str = "harmony"
    max_iter_harmony: int = 100
    max_iter_kmeans: int = 100
    theta: float = 2.0
    lamb: float = 1.0
    epsilon_cluster: float = 1e-5
    epsilon_harmony: float = 1e-4
    random_state: int = 0


def run_harmony(adata: AnnData, config: HarmonyConfig) -> AnnData:
    """Integrate with Harmony on the PCA of Zheng17-preprocessed data.

    ``adata`` is preprocessed in place and returned. The Harmony-adjusted basis
    is requested under ``config.latent_key`` (``adjusted_basis``).
    """
    sc.pp.recipe_zheng17(adata, n_top_genes=config.n_top_genes)
    sc.tl.pca(adata)

    harmony_kwargs = dict(
        theta=config.theta,
        lamb=config.lamb,
        adjusted_basis=config.latent_key,
        max_iter_harmony=config.max_iter_harmony,
        max_iter_kmeans=config.max_iter_kmeans,
        epsilon_cluster=config.epsilon_cluster,
        epsilon_harmony=config.epsilon_harmony,
        random_state=config.random_state,
    )
    sce.pp.harmony_integrate(adata, config.batch_key, **harmony_kwargs)
    return adata


@dataclass
class CanSigConfig(ModelConfig):
    """Hyper-parameters for ``run_cansig``: model sizes, per-stage training epochs,
    loss weights and the ``adata.obs`` keys CanSig relies on."""
    name: str = "cansig"
    gpu: bool = True
    malignant_only: bool = False
    n_latent: int = 4
    n_layers: int = 1
    n_hidden: int = 128
    n_latent_batch_effect: int = 5
    n_latent_cnv: int = 10
    # Main model epochs; also the KL warm-up length in run_cansig's plan_kwargs.
    max_epochs: int = 400
    cnv_max_epochs: int = 400
    batch_effect_max_epochs: int = 400
    beta: float = 1.0
    batch_effect_beta: float = 1.0
    # Continuous covariates passed to CanSig.setup_anndata.
    covariates: Optional[List] = field(
        default_factory=lambda: ["S_score", "G2M_score"]
    )
    annealing: str = "linear"
    # NOTE(review): presumably the adata.obs column / categories marking malignant
    # vs non-malignant cells — confirm against CanSig's documentation.
    malignant_key: str = "malignant_key"
    malignant_cat: str = "malignant"
    non_malignant_cat: str = "non-malignant"
    subclonal_key: str = "subclonal"
    celltype_key: str = "program"


def run_cansig(adata: AnnData, config: CanSigConfig) -> AnnData:
    """Train CanSig and return the malignant cells with their latent representation.

    Steps:
      1. Preprocess a copy of ``adata`` with ``CanSig.preprocessing``.
      2. Register it with ``CanSig.setup_anndata`` (``"counts"`` layer).
      3. Train on all cells.
      4. Write the training histories and the batch-effect / CNV latent spaces to
         CSV files in the working directory.
      5. Subset ``adata`` to the malignant cells reported by ``model.get_index``.

    Returns:
        A copy of that malignant subset with ``obsm[config.latent_key]`` filled in.
    """
    prepared = CanSig.preprocessing(
        adata.copy(),
        n_highly_variable_genes=config.n_top_genes,
        malignant_key=config.malignant_key,
        malignant_cat=config.malignant_cat,
    )
    CanSig.setup_anndata(
        prepared,
        celltype_key=config.celltype_key,
        malignant_key=config.malignant_key,
        malignant_cat=config.malignant_cat,
        non_malignant_cat=config.non_malignant_cat,
        continuous_covariate_keys=config.covariates,
        layer="counts",
    )
    model = CanSig(
        prepared,
        n_latent=config.n_latent,
        n_layers=config.n_layers,
        n_hidden=config.n_hidden,
        n_latent_cnv=config.n_latent_cnv,
        n_latent_batch_effect=config.n_latent_batch_effect,
        sample_id_key=config.batch_key,
        subclonal_key=config.subclonal_key,
    )

    training_plan = {
        "n_epochs_kl_warmup": config.max_epochs,
        "beta": config.beta,
        "annealing": config.annealing,
    }
    model.train(
        max_epochs=config.max_epochs,
        cnv_max_epochs=config.cnv_max_epochs,
        batch_effect_max_epochs=config.batch_effect_max_epochs,
        train_size=1.0,
        plan_kwargs=training_plan,
        batch_effect_plan_kwargs={"beta": config.batch_effect_beta},
    )

    # Persist diagnostics as CSV files in the working directory.
    save_model_history(model)
    save_latent_spaces(model, adata)

    malignant_idx = model.get_index(malignant_cells=True)
    malignant = adata[malignant_idx, :].copy()
    malignant.obsm[config.latent_key] = model.get_latent_representation()
    return malignant


@dataclass
class MNNConfig(ModelConfig):
    """Hyper-parameters for ``run_mnn``.

    ``run_model`` registers this method under the key ``"nmm"``.
    NOTE(review): ``k`` and ``sigma`` are presumably meant for
    ``sce.pp.mnn_correct``'s parameters of the same name; verify that ``run_mnn``
    passes them through.
    """
    name: str = "nmm"
    k: int = 20
    sigma: float = 1.


def run_mnn(adata: AnnData, config: MNNConfig) -> AnnData:
    """Integrate with mutual nearest neighbours (``scanpy.external.pp.mnn_correct``).

    HVGs are selected on a normalized, log-transformed copy of ``adata`` and passed
    as ``var_subset``. The per-batch objects corrected by MNN are split from
    ``adata`` itself. ``config.k`` and ``config.sigma`` are now forwarded to the
    correction; previously they were ignored. Their defaults equal
    ``mnn_correct``'s, so default runs are unchanged.

    Returns:
        The corrected AnnData, with its corrected matrix copied to
        ``obsm[config.latent_key]``.
    """
    split = split_batches(adata, config.batch_key)

    bdata = adata.copy()
    sc.pp.normalize_total(bdata, target_sum=1e4)
    sc.pp.log1p(bdata)
    sc.pp.highly_variable_genes(bdata, n_top_genes=config.n_top_genes)
    hvg = bdata.var.index[bdata.var["highly_variable"]].tolist()
    # NOTE(review): ``split`` holds the un-normalized batches; only the HVG list comes
    # from the normalized copy — confirm this is intended.

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        corrected, _, _ = sce.pp.mnn_correct(
            *split, var_subset=hvg, k=config.k, sigma=config.sigma
        )

    # ``mnn_correct`` concatenates its outputs by default (``do_concatenate=True``)
    # and returns one object. Indexing that object with ``[0]`` would select its
    # first cell, so only merge when a per-batch sequence is returned.
    if isinstance(corrected, (list, tuple)):
        corrected = corrected[0].concatenate(corrected[1:])

    corrected.obsm[config.latent_key] = corrected.X

    return corrected


@dataclass
class CombatConfig(ModelConfig):
    """Hyper-parameters for ``run_combat``."""
    name: str = "combat"
    # Add ``G2M_score`` and ``S_score`` (adata.obs columns) as ComBat covariates.
    cell_cycle: bool = False
    # Add ``log_counts`` (adata.obs column) as a ComBat covariate.
    log_counts: bool = False


def run_combat(adata: AnnData, config: CombatConfig) -> AnnData:
    """Batch-correct the log-normalized HVG expression with ComBat.

    Optional covariates are ``G2M_score``/``S_score`` (``config.cell_cycle``) and
    ``log_counts`` (``config.log_counts``); ``None`` is passed if neither is set.
    ``adata`` is normalized and log-transformed in place.

    Returns:
        A copy restricted to the HVGs, whose ``obsm[config.latent_key]`` holds the
        corrected expression matrix returned by ``sc.pp.combat``.
    """
    cell_cycle_covs = ["G2M_score", "S_score"] if config.cell_cycle else []
    depth_covs = ["log_counts"] if config.log_counts else []
    covariates = cell_cycle_covs + depth_covs

    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)
    sc.pp.highly_variable_genes(adata, n_top_genes=config.n_top_genes)
    hvg_adata = adata[:, adata.var["highly_variable"]].copy()

    corrected = sc.pp.combat(
        hvg_adata, config.batch_key, covariates=covariates or None, inplace=False
    )
    hvg_adata.obsm[config.latent_key] = corrected
    return hvg_adata


@dataclass
class DescConfig(ModelConfig):
    """Hyper-parameters for ``run_desc`` (forwarded to ``desc.train``)."""
    name: str = "desc"
    gpu: bool = False  # TODO: add GPU acceleration
    res: float = 0.8  # louvain_resolution; also selects the "X_Embeded_z<res>" obsm key
    n_top_genes: int = 2000
    n_neighbors: int = 10
    batch_size: int = 256
    tol: float = 0.005
    learning_rate: float = 500
    save_dir: Union[str, Path] = "."


def run_desc(adata: AnnData, config: DescConfig) -> AnnData:
    """Embed with DESC after global and per-batch scaling.

    Preprocessing and parameters follow https://github.com/eleozzr/desc/issues/28:
    normalize to 1e4 counts per cell, log1p, flag (not subset) HVGs, scale
    globally and then per ``config.batch_key`` (both clipped at 6), and train
    DESC on all genes.

    Returns:
        DESC's output AnnData, with the embedding for ``config.res`` copied to
        ``obsm[config.latent_key]``.
    """
    import desc

    sc.pp.normalize_per_cell(adata, counts_per_cell_after=1e4)
    sc.pp.log1p(adata)
    sc.pp.highly_variable_genes(adata, n_top_genes=config.n_top_genes, inplace=True)
    sc.pp.scale(adata, zero_center=True, max_value=6)
    scaled = desc.scale_bygroup(adata, groupby=config.batch_key, max_value=6)

    train_kwargs = dict(
        dims=[scaled.shape[1], 128, 32],  # or set 256
        tol=config.tol,  # suggest 0.005 when the dataset less than 5000
        n_neighbors=config.n_neighbors,
        batch_size=config.batch_size,
        louvain_resolution=config.res,
        save_dir=config.save_dir,
        do_tsne=False,
        use_GPU=config.gpu,
        num_Cores=8,
        save_encoder_weights=False,
        save_encoder_step=2,
        use_ae_weights=False,
        do_umap=False,
        num_Cores_tsne=4,
        learning_rate=config.learning_rate,
    )
    desc_out = desc.train(scaled, **train_kwargs)

    # DESC names its embedding after the resolution, e.g. "X_Embeded_z0.8".
    embedding_key = "X_Embeded_z" + str(config.res)
    desc_out.obsm[config.latent_key] = desc_out.obsm[embedding_key]
    return desc_out


def save_model_history(model: CanSig, name: str = ""):
    """Write each CanSig sub-module's training history to ``{key}_{name}.csv``.

    Covers ``combined`` (``model.module``), ``batch_effect`` and ``cnv``. For each,
    the frames in ``module.history`` are concatenated column-wise before saving.
    """
    modules = (
        ("combined", model.module),
        ("batch_effect", model.module_batch_effect),
        ("cnv", model.module_cnv),
    )
    for label, submodule in modules:
        history = pd.concat(list(submodule.history.values()), axis=1)
        history.to_csv(f"{label}_{name}.csv")


def save_latent_spaces(model: CanSig, adata: AnnData, name: str = ""):
    """Write CanSig's batch-effect and CNV latent spaces to CSV files.

    ``{name}_batch_effect_latent.csv`` covers the cells from
    ``model.get_index(malignant_cells=False)``. ``{name}_cnv_latent.csv`` covers
    those from ``model.get_index(malignant_cells=True)``. Both are indexed by the
    matching ``adata.obs_names``.
    """
    spaces = (
        ("batch_effect", model.get_batch_effect_latent_representation, False),
        ("cnv", model.get_cnv_latent_representation, True),
    )
    for label, get_latent, malignant in spaces:
        latent = get_latent()
        cell_idx = model.get_index(malignant_cells=malignant)
        frame = pd.DataFrame(latent, index=adata.obs_names[cell_idx])
        frame.to_csv(f"{name}_{label}_latent.csv")

In [None]:
from dataclasses import dataclass
from typing import Dict, Optional, Tuple

import numpy as np
import pandas as pd
import scanpy as sc
import scipy
from anndata import AnnData
from scETM.eval_utils import (
    calculate_kbet,
    _get_knn_indices,
)
from sklearn.metrics import (
    adjusted_rand_score,
    normalized_mutual_info_score,
    silhouette_score,
    calinski_harabasz_score,
    davies_bouldin_score,
)

#from _cluster import LeidenNClusterConfig, LeidenNCluster
#from models import ModelConfig
#from utils import diffusion_nn, diffusion_conn


@dataclass
class MetricsConfig:
    """Settings for ``run_metrics``."""
    # Neighbourhood size passed to compute_neighbors.
    n_neighbors: int = 50
    # adata.obs column with the reference labels used by the conservation metrics and kBET.
    group_key: str = "program"
    # adata.obs key under which Leiden labels are written in compute_ari_nmi.
    cluster_key: str = "leiden"
    # Leiden random seeds tried per cluster count in compute_ari_nmi.
    n_random_seeds: int = 10
    # Cluster counts evaluated by compute_ari_nmi (2, 3, 4, 5).
    # NOTE(review): ``Tuple[int]`` describes a 1-tuple; ``Tuple[int, ...]`` matches the
    # default. Left unchanged in case a structured-config library relies on it.
    clustering_range: Tuple[int] = tuple(range(2, 6))


def run_metrics(adata: AnnData, config: ModelConfig, metric_config: MetricsConfig):
    """Compute biological-conservation and batch-mixing metrics for an embedding.

    First builds a neighbour graph on ``adata.obsm[config.latent_key]`` via
    ``compute_neighbors``. It then merges, in this order: average silhouette width,
    Davies-Bouldin, Calinski-Harabasz, and ARI/NMI of Leiden clusterings (all
    against ``metric_config.group_key``), followed by kBET over ``config.batch_key``.

    Returns:
        Dict of metric name to value; on key collisions later metrics win.
    """
    latent_key = config.latent_key
    group_key = metric_config.group_key

    compute_neighbors(
        adata,
        latent_key=latent_key,
        n_neighbors=metric_config.n_neighbors,
    )

    partial_results = [
        # Biological conservation metrics
        compute_asw(adata, group_key, latent_key),
        compute_davies_bouldin(adata, group_key, latent_key),
        compute_calinski_harabasz(adata, group_key, latent_key),
        compute_ari_nmi(adata, metric_config),
        # Batch effect metrics
        kbet(
            adata,
            latent_key=latent_key,
            label_key=group_key,
            batch_key=config.batch_key,
        ),
    ]

    metrics = {}
    for partial in partial_results:
        metrics.update(partial)
    return metrics


def kbet(
        adata: AnnData, latent_key: str, label_key: str, batch_key: str
) -> Dict[str, float]:
    """This implementation of kBet is taken from scib and combined with the
    kbet_single implementation from scETM.

    For every label in ``adata.obs[label_key]``, the kBET acceptance rate of
    ``batch_key`` is computed on a diffusion-extended neighbour graph:

    * labels with fewer than 10 cells or a single batch are skipped (NaN);
    * if the label's subgraph is connected, kBET runs on all of its cells;
    * otherwise kBET runs on components with at least ``3 * k0`` cells, provided
      they cover at least 75% of the label's cells; if not, the score is 0.

    Side effect: ``adata.strings_to_categoricals()`` is applied to ``adata`` in place.

    Returns:
        ``{"k_bet_acceptance_rate": <mean over labels, ignoring NaN>}``.
    """
    adata.strings_to_categoricals()
    if latent_key in adata.obsm_keys():
        adata_tmp = sc.pp.neighbors(adata, n_neighbors=50, use_rep=latent_key,
                                    copy=True)
    else:
        adata_tmp = adata.copy()
    # Extend the neighbour graph by diffusion until rows reach min_k non-zero
    # connectivities (see ``diffusion_conn``).
    connectivities = diffusion_conn(adata_tmp, min_k=50, copy=False)
    adata_tmp.obsp["connectivities"] = connectivities

    # set upper bound for k0
    size_max = 2 ** 31 - 1

    # prepare call of kBET per cluster
    kBET_scores = {"cluster": [], "kBET": []}
    for clus in adata_tmp.obs[label_key].unique():

        # subset by label
        adata_sub = adata_tmp[adata_tmp.obs[label_key] == clus, :].copy()

        # check if neighborhood size too small or only one batch in subset
        if np.logical_or(
                adata_sub.n_obs < 10, len(adata_sub.obs[batch_key].cat.categories) == 1
        ):
            print(f"{clus} consists of a single batch or is too small. Skip.")
            score = np.nan
        else:
            # k0: a quarter of the mean batch size, clamped to [10, 70].
            quarter_mean = np.floor(
                np.mean(adata_sub.obs[batch_key].value_counts()) / 4
            ).astype("int")
            k0 = np.min([70, np.max([10, quarter_mean])])
            # check k0 for reasonability
            if k0 * adata_sub.n_obs >= size_max:
                k0 = np.floor(size_max / adata_sub.n_obs).astype("int")

            n_comp, labs = scipy.sparse.csgraph.connected_components(
                adata_sub.obsp["connectivities"], connection="strong"
            )

            if n_comp == 1:  # a single component to compute kBET on
                adata_sub.obsm["knn_indices"] = diffusion_nn(adata_sub, k=k0)
                adata_sub.uns["neighbors"]["params"]["n_neighbors"] = k0

                score = calculate_kbet(
                    adata_sub,
                    use_rep="",
                    batch_col=batch_key,
                    calc_knn=False,
                    n_neighbors=adata_sub.uns["neighbors"]["params"]["n_neighbors"],
                )[2]

            else:
                # check the number of components where kBET can be computed upon
                # (``pd.value_counts`` is deprecated, so count via a Series)
                comp_size = pd.Series(labs).value_counts()
                # check which components are small
                comp_size_thresh = 3 * k0
                # ``np.isin`` replaces the deprecated ``np.in1d``.
                idx_nonan = np.flatnonzero(
                    np.isin(labs, comp_size[comp_size >= comp_size_thresh].index)
                )

                # check if 75% of all cells can be used for kBET run
                if len(idx_nonan) / len(labs) >= 0.75:
                    # create another subset of components, assume they are not visited
                    # in a diffusion process
                    adata_sub_sub = adata_sub[idx_nonan, :].copy()
                    adata_sub_sub.obsm["knn_indices"] = diffusion_nn(
                        adata_sub_sub, k=k0
                    )
                    adata_sub_sub.uns["neighbors"]["params"]["n_neighbors"] = k0

                    score = calculate_kbet(
                        adata_sub_sub,
                        use_rep="",
                        batch_col=batch_key,
                        calc_knn=False,
                        n_neighbors=adata_sub_sub.uns["neighbors"]["params"][
                            "n_neighbors"
                        ],
                    )[2]

                else:  # if there are too many too small connected components,
                    score = 0  # i.e. 100% rejection

        kBET_scores["cluster"].append(clus)
        kBET_scores["kBET"].append(score)

    kBET_scores = pd.DataFrame.from_dict(kBET_scores)
    kBET_scores = kBET_scores.reset_index(drop=True)

    final_score = np.nanmean(kBET_scores["kBET"]).item()

    return {"k_bet_acceptance_rate": final_score}


def compute_ari(adata: AnnData, group_key: str, cluster_key: str) -> float:
    """Adjusted Rand index between two label columns of ``adata.obs``.

    :param group_key: column holding the reference labels.
    :param cluster_key: column holding the predicted cluster labels.
    """
    obs = adata.obs
    reference, predicted = obs[group_key], obs[cluster_key]
    return adjusted_rand_score(reference, predicted)


def compute_nmi(adata: AnnData, group_key: str, cluster_key: str) -> float:
    """Normalized mutual information between two label columns of ``adata.obs``.

    :param group_key: column holding the reference labels.
    :param cluster_key: column holding the predicted cluster labels.
    """
    obs = adata.obs
    reference, predicted = obs[group_key], obs[cluster_key]
    return normalized_mutual_info_score(reference, predicted)

def compute_asw(
        adata: AnnData, group_key: str, latent_key: str
) -> Dict[str, Optional[float]]:
    """Cell-type silhouette width of ``group_key`` labels in the ``latent_key`` embedding.

    The raw silhouette score in [-1, 1] is rescaled to [0, 1] via ``(s + 1) / 2``.
    Returns ``{"average_silhouette_width": nan}`` when the embedding is missing
    from ``adata.obsm``.
    """
    if latent_key in adata.obsm_keys():
        raw = silhouette_score(X=adata.obsm[latent_key], labels=adata.obs[group_key])
        return {"average_silhouette_width": (raw + 1) / 2}
    return {"average_silhouette_width": np.nan}


def compute_calinski_harabasz(
        adata: AnnData, group_key: str, latent_key: str
) -> Dict[str, Optional[float]]:
    """Calinski-Harabasz index of ``group_key`` labels in the ``latent_key`` embedding.

    Returns ``{"calinski_harabasz_score": nan}`` when the embedding is missing
    from ``adata.obsm``.
    """
    result_key = "calinski_harabasz_score"
    if latent_key in adata.obsm_keys():
        embedding = adata.obsm[latent_key]
        return {result_key: calinski_harabasz_score(embedding, adata.obs[group_key])}
    return {result_key: np.nan}


def compute_davies_bouldin(
        adata: AnnData, group_key: str, latent_key: str
) -> Dict[str, Optional[float]]:
    """Davies-Bouldin index of ``group_key`` labels in the ``latent_key`` embedding.

    Returns ``{"davies_bouldin": nan}`` when the embedding is missing from
    ``adata.obsm``.
    """
    result_key = "davies_bouldin"
    if latent_key in adata.obsm_keys():
        embedding = adata.obsm[latent_key]
        return {result_key: davies_bouldin_score(embedding, adata.obs[group_key])}
    return {result_key: np.nan}


def compute_ari_nmi(
        adata: AnnData, metric_config: MetricsConfig
) -> Dict[str, Optional[float]]:
    """ARI and NMI of Leiden clusterings for every (cluster count, seed) pair.

    For each ``k`` in ``metric_config.clustering_range`` and each seed in
    ``range(metric_config.n_random_seeds)``, a Leiden clustering with exactly
    ``k`` clusters is searched and written to
    ``adata.obs[metric_config.cluster_key]`` (overwritten on every run). Its
    agreement with ``metric_config.group_key`` is stored under the keys
    ``ari_<k>_<seed>`` and ``nmi_<k>_<seed>``. When the search raises
    ``ValueError``, the error is printed and both scores are NaN.
    """
    metrics = {}
    for n_clusters in metric_config.clustering_range:
        for seed in range(metric_config.n_random_seeds):
            try:
                # Config construction stays inside the try so that validation
                # errors (ValueError subclasses) are handled the same way.
                settings = LeidenNClusterConfig(random_state=seed, clusters=n_clusters)
                LeidenNCluster(settings).fit_predict(
                    adata, key_added=metric_config.cluster_key
                )
            except ValueError as err:
                print(err)
                ari = nmi = np.nan
            else:
                ari = compute_ari(
                    adata, metric_config.group_key, metric_config.cluster_key
                )
                nmi = compute_nmi(
                    adata, metric_config.group_key, metric_config.cluster_key
                )

            metrics[f"ari_{n_clusters}_{seed}"] = ari
            metrics[f"nmi_{n_clusters}_{seed}"] = nmi

    return metrics


def compute_neighbors(adata: AnnData, latent_key: str, n_neighbors: int):
    """Store kNN indices of the ``latent_key`` embedding in ``adata.obsm["knn_indices"]``.

    Does nothing when ``latent_key`` is not present in ``adata.obsm``.
    """
    if latent_key not in adata.obsm.keys():
        return
    adata.obsm["knn_indices"] = _get_knn_indices(
        adata,
        use_rep=latent_key,
        n_neighbors=n_neighbors,
        calc_knn=True,
    )

In [None]:
from typing import Literal  # pytype: disable=not-supported-yet
from typing import Optional

import anndata as an  # pytype: disable=import-error
import numpy as np
import pydantic  # pytype: disable=import-error
import scanpy as sc  # pytype: disable=import-error
from anndata import AnnData

# Distance-metric names accepted by NeighborsGraphConfig.metric.
# NOTE(review): presumably mirrors the metrics scanpy.pp.neighbors accepts; some
# entries (e.g. "kulsinski") have been removed from recent SciPy releases — verify
# against the installed versions before relying on them.
_SupportedMetric = Literal[
    "cityblock",
    "cosine",
    "euclidean",
    "l1",
    "l2",
    "manhattan",
    "braycurtis",
    "canberra",
    "chebyshev",
    "correlation",
    "dice",
    "hamming",
    "jaccard",
    "kulsinski",
    "mahalanobis",
    "minkowski",
    "rogerstanimoto",
    "russellrao",
    "seuclidean",
    "sokalmichener",
    "sokalsneath",
    "sqeuclidean",
    "yule",
]


class NeighborsGraphConfig(pydantic.BaseModel):
    """Settings for neighborhood graph computation.
    For description, see
    https://scanpy.readthedocs.io/en/stable/generated/scanpy.pp.neighbors.html

    Field names match the keyword arguments of ``scanpy.pp.neighbors``.
    NOTE(review): in this file no code passes these settings to
    ``scanpy.pp.neighbors``; confirm where (if anywhere) they are consumed.
    """

    n_neighbors: int = pydantic.Field(default=15)
    # None defers the choice of the number of PCs to scanpy.
    n_pcs: Optional[int] = pydantic.Field(default=None)
    knn: bool = pydantic.Field(default=True)
    # TODO(Pawel): Check whether we can support other methods as well.
    method: Literal["umap"] = pydantic.Field(default="umap")
    metric: _SupportedMetric = pydantic.Field(default="euclidean")


class _LeidenBaseConfig(pydantic.BaseModel):
    """Settings shared by Leiden-based clustering configurations.

    NOTE(review): LeidenNCluster.fit_predict in this file forwards only
    ``random_state``, ``directed`` and ``use_weights``; ``nngraph`` and
    ``n_iterations`` are currently not read by it.
    """

    nngraph: NeighborsGraphConfig = pydantic.Field(default_factory=NeighborsGraphConfig)
    # Base seed; LeidenNCluster adds fixed offsets to it when retrying the search.
    random_state: int = pydantic.Field(default=0)
    directed: bool = pydantic.Field(default=True)
    use_weights: bool = pydantic.Field(default=True)
    # NOTE(review): -1 is presumably "iterate until convergence" in the Leiden
    # implementation — confirm against scanpy.tl.leiden.
    n_iterations: int = pydantic.Field(default=-1)


class BinSearchSettings(pydantic.BaseModel):
    """Bounds and stopping tolerance for the binary search over Leiden resolution."""

    start: pydantic.PositiveFloat = pydantic.Field(
        default=1e-3, description="The minimal resolution."
    )
    end: pydantic.PositiveFloat = pydantic.Field(
        default=10.0, description="The maximal resolution."
    )
    epsilon: pydantic.PositiveFloat = pydantic.Field(
        default=1e-3,
        description="Relative tolerance: the search gives up once the width of the "
        "resolution interval falls below epsilon times the current resolution.",
    )

    @pydantic.validator("end")
    def validate_end_greater_than_start(cls, v, values, **kwargs) -> float:
        """Reject intervals where ``end`` does not exceed ``start``."""
        # `start` is missing from `values` when it failed its own validation
        # (e.g. a non-positive value). Skip the comparison in that case so that
        # pydantic reports the `start` error instead of a bare KeyError here.
        start = values.get("start")
        if start is not None and v <= start:
            raise ValueError("In binary search end must be greater than start.")
        return v


class LeidenNClusterConfig(_LeidenBaseConfig):
    """Configuration for LeidenNCluster: target cluster count plus search bounds."""

    # Exact number of clusters the resolution search must produce.
    clusters: int = pydantic.Field(
        default=5, description="The number of clusters to be returned."
    )
    binsearch: BinSearchSettings = pydantic.Field(default_factory=BinSearchSettings)


class LeidenNCluster:
    """Leiden clustering that searches for a resolution giving a fixed number of clusters."""

    def __init__(self, settings: LeidenNClusterConfig) -> None:
        self._settings = settings

    def fit_predict(self, adata: AnnData, key_added: str) -> np.ndarray:
        """Cluster ``adata`` into exactly ``settings.clusters`` groups.

        The resolution binary search is retried with the seed shifted by fixed
        offsets. The labels are written to ``adata.obs[key_added]`` and returned
        as an integer array.

        :raises ValueError: if no seed yields the requested number of clusters.
        """
        cfg = self._settings
        search = cfg.binsearch
        for seed_offset in (0, 20_000, 30_000, 40_000):
            found = _binary_search_leiden_resolution(
                adata,
                k=cfg.clusters,
                key_added=key_added,
                random_state=cfg.random_state + seed_offset,
                directed=cfg.directed,
                use_weights=cfg.use_weights,
                start=search.start,
                end=search.end,
                _epsilon=search.epsilon,
            )
            if found is not None:
                return found.obs[key_added].astype(int).values

        # Every seed failed to hit the requested cluster count.
        raise ValueError(
            f"No resolution for the number of clusters {self._settings.clusters}"
            f" found."
        )


def _binary_search_leiden_resolution(
    adata: an.AnnData,
    k: int,
    start: float,
    end: float,
    key_added: str,
    random_state: int,
    directed: bool,
    use_weights: bool,
    _epsilon: float,
) -> Optional[an.AnnData]:
    """Bisect the Leiden resolution in ``[start, end]`` until exactly ``k`` clusters appear.

    Each probe runs ``sc.tl.leiden`` on ``adata`` in place and writes the labels
    to ``adata.obs[key_added]``. Returns ``adata`` on success. Returns ``None``
    once the interval is narrower than ``_epsilon`` times the probed resolution;
    in that case the last probe's labels remain in ``adata.obs[key_added]``.
    Too many clusters moves the upper bound down; too few moves the lower bound up.
    """
    low, high = start, end
    while True:
        mid = 0.5 * (low + high)
        sc.tl.leiden(
            adata,
            resolution=mid,
            key_added=key_added,
            random_state=random_state,
            directed=directed,
            use_weights=use_weights,
        )
        n_found = adata.obs[key_added].nunique()

        if n_found == k:
            return adata
        # Interval exhausted: the requested cluster count is not reachable here.
        if abs(high - low) < _epsilon * mid:
            return None

        if n_found > k:
            high = mid
        else:
            low = mid

In [None]:
from sklearn.metrics.cluster import silhouette_samples, silhouette_score

def silhouette_batch(
    adata,
    batch_key,
    group_key,
    latent_key,
    metric="euclidean",
    return_all=False,
    scale=True,
    verbose=True,
):
    """Batch ASW

    Modified average silhouette width (ASW) of batch labels, computed within each group.
    A silhouette width close to 0 is taken to mean the batches overlap perfectly, so the
    absolute value of the silhouette width is used to measure how well batches are mixed.
    For all cells :math:`i` of a group (e.g. cell type) :math:`C_j`, the batch ASW of that group is:

    .. math::
        batch \\, ASW_j = \\frac{1}{|C_j|} \\sum_{i \\in C_j} |silhouette(i)|

    The final score is the average of the per-group scores over the set of groups :math:`M`:

    .. math::
        batch \\, ASW = \\frac{1}{|M|} \\sum_{j \\in M} batch \\, ASW_j

    With ``scale=True`` (the default) each absolute silhouette is subtracted from 1 before
    averaging, so 0 indicates poor batch mixing and 1 indicates optimal batch mixing:

    .. math::
        batch \\, ASW_j = \\frac{1}{|C_j|} \\sum_{i \\in C_j} 1 - |silhouette(i)|

    Groups containing a single batch, or in which every cell comes from a different
    batch, are skipped because the silhouette is undefined for them.

    :param adata: AnnData with ``latent_key`` in ``.obsm`` and ``batch_key``/``group_key`` in ``.obs``
    :param batch_key: batch labels to be compared against
    :param group_key: group labels to be subset by e.g. cell type
    :param latent_key: name of the embedding in ``adata.obsm``
    :param metric: see sklearn silhouette score
    :param scale: if True, scale between 0 and 1
    :param return_all: if True, return all silhouette scores and label means
        default False: return average width silhouette (ASW)
    :param verbose: print silhouette score per group
    :return:
        if ``return_all`` is False: ``{"asw_batch_score": asw}``;
        otherwise the tuple ``(asw, mean silhouette per group as pd.DataFrame,
        per-cell silhouette scores as pd.DataFrame)``.
        ``asw`` (and the per-group means) are NaN when no group could be scored.
    :raises KeyError: if ``latent_key`` is not in ``adata.obsm``
    """
    if latent_key not in adata.obsm.keys():
        raise KeyError(
            f"{latent_key} not in obsm (available keys: {list(adata.obsm.keys())})"
        )

    sil_per_label = []
    for group in adata.obs[group_key].unique():
        adata_group = adata[adata.obs[group_key] == group]
        n_batches = adata_group.obs[batch_key].nunique()

        # silhouette_samples needs 2 <= n_labels <= n_samples - 1
        if n_batches == 1 or n_batches == adata_group.shape[0]:
            continue

        sil = np.abs(
            silhouette_samples(
                adata_group.obsm[latent_key], adata_group.obs[batch_key], metric=metric
            )
        )
        if scale:
            # flip so that the highest value means the best mixing
            sil = 1 - sil

        sil_per_label.extend((group, score) for score in sil)

    sil_df = pd.DataFrame.from_records(
        sil_per_label, columns=["group", "silhouette_score"]
    )

    if not sil_per_label:
        sil_means = np.nan
        asw = np.nan
    else:
        sil_means = sil_df.groupby("group").mean()
        asw = sil_means["silhouette_score"].mean()

    if verbose:
        print(f"mean silhouette per group: {sil_means}")

    if return_all:
        return asw, sil_means, sil_df

    return {"asw_batch_score": asw}

In [None]:
def kbet_rni_asw(adata, latent_key, batch_key, label_key, group_key, max_clusters):
    """Score one latent embedding with kBET, best-over-k ARI, cell-type ASW and batch ASW.

    Leiden clustering is searched (resolution in [0.1, 0.9]) on a copy of ``adata``
    for every cluster count from 2 to ``max_clusters``. A private ``obs`` key is used
    on the copy, so its label columns are never overwritten. Each successful
    clustering is written to ``adata.obs['cluster_<k>']`` and compared with
    ``group_key`` via ARI. The best ARI is reported and its clustering is plotted
    on the UMAP.

    NOTE(review): assumes ``adata`` already holds a neighbors graph and a UMAP
    embedding (needed by ``sc.tl.leiden`` and ``sc.pl.umap``) — confirm with caller.

    :param adata: AnnData to score; mutated by adding ``cluster_<k>`` columns to ``.obs``
    :param latent_key: embedding key in ``adata.obsm``
    :param batch_key: batch column in ``adata.obs``
    :param label_key: label column passed to kBET
    :param group_key: label column used for ARI and both ASW scores
    :param max_clusters: largest cluster count to try (inclusive)
    :return: ``[kbet_score, ari_score, asw_score, asw_batch_score]``, each a dict
        with a single entry. If no cluster count could be matched, the ARI entry
        is 0.0 and nothing is plotted.
    """
    bdata = adata.copy()
    ari_by_k = {}
    for n_clusters in range(2, max_clusters + 1):
        cdata = _binary_search_leiden_resolution(
            bdata,
            k=n_clusters,
            start=0.1,
            end=0.9,
            key_added="_leiden_tmp",
            random_state=0,
            directed=False,
            use_weights=False,
            _epsilon=1e-3,
        )
        if cdata is None:
            # No resolution in [0.1, 0.9] produced exactly `n_clusters` clusters.
            continue
        cluster_key = "cluster_{}".format(n_clusters)
        adata.obs[cluster_key] = cdata.obs["_leiden_tmp"]
        ari_by_k[n_clusters] = compute_ari(adata, group_key=group_key, cluster_key=cluster_key)

    # Failed cluster counts are excluded instead of being scored 0. ARI can be
    # negative, so a placeholder 0 could win the maximum and send sc.pl.umap to a
    # cluster column that was never written.
    if ari_by_k:
        best_k = max(ari_by_k, key=ari_by_k.get)  # first maximum on ties
        ari_score = {f"maximum ARI_score with {best_k} clusters": ari_by_k[best_k]}
        sc.pl.umap(adata, color=["cluster_{}".format(best_k)])
    else:
        print(f"No Leiden clustering with 2..{max_clusters} clusters was found; ARI set to 0.")
        ari_score = {"maximum ARI_score (no clustering found)": 0.0}

    kbet_score = kbet(adata, latent_key=latent_key, batch_key=batch_key, label_key=label_key)
    asw_score = compute_asw(adata, group_key=group_key, latent_key=latent_key)
    asw_batch_score = silhouette_batch(adata, batch_key=batch_key, group_key=group_key, latent_key=latent_key)

    return [kbet_score, ari_score, asw_score, asw_batch_score]

In [None]:
# If this raises "Equality comparisons are not supported for AnnData objects, instead compare the desired attributes.",
# the Leiden resolution search returned None (no resolution matched a requested cluster count); decrease max_clusters.
#kbet_rni_asw(adata = mdata, latent_key = 'X_VAE', batch_key = 'batch', label_key = 'final_annotation', group_key = 'final_annotation', max_clusters = 8)

In [None]:
def max_min_scale(dataset):
    """Min-max scale a 1-D collection of scores to [0, 1].

    NaN entries are ignored when computing the range and stay NaN in the output.
    When every value is identical (or NaN), or the input is empty, the values are
    returned unscaled, because no range exists to divide by.

    :param dataset: sequence or array of numbers.
    :return: float ``np.ndarray`` with the same length as ``dataset``.
    """
    values = np.asarray(dataset, dtype=float)
    observed = values[~np.isnan(values)]
    if observed.size == 0:
        return values

    lo = observed.min()
    span = observed.max() - lo
    if span == 0:
        return values
    return (values - lo) / span

In [None]:
def grid_search_list_generator(layers, num_latent, lr):
    """Cartesian product of the hyperparameter lists as ``[n_layers, n_latent, lr]`` triples.

    The order is layers-major, then latent size, then learning rate.
    """
    return [
        [n_layers, n_latent, rate]
        for n_layers in layers
        for n_latent in num_latent
        for rate in lr
    ]

In [None]:
# 3 x 3 x 3 = 27 scVI configurations: [n_layers, n_latent, learning rate].
grid_search_list = grid_search_list_generator(
    layers=[1, 3, 5],
    num_latent=[5, 10, 15],
    lr=[1e-3, 1e-4, 1e-5],
)

In [None]:
# One human-readable column label per grid configuration, in grid order.
col_name = [
    "VAE layers of %s, num latents of %s, lr of %s" % (params[0], params[1], params[2])
    for params in grid_search_list
]

In [None]:
def hyperparameter_tuning_general_scvi(adata, layer_key, batch_key, label_key, group_key, max_clusters, grid_search_list,
                                       n_hidden=512, bio_weight=0.6, batch_weight=0.4):
    """Grid-search scVI hyperparameters and score every trained model.

    Expects ``adata`` to be already subset (``malignant_cell_collection``) and
    preprocessed (``model_preprocessing``). For each ``[n_layers, n_latent, lr]``
    entry of ``grid_search_list`` a model is trained with ``model_train``, its latent
    embedding is obtained via ``get_latent_UMAP`` (stored under ``X_VAE_<params>``),
    and the embedding is scored with ``kbet_rni_asw``.

    Each raw metric is min-max scaled across the grid with ``max_min_scale``. Then:
    bio score = mean(ARI, cell-type ASW), batch score = mean(kBET, batch ASW),
    overall = ``bio_weight`` * bio + ``batch_weight`` * batch.

    ``adata`` is mutated: ``kbet_rni_asw`` adds ``cluster_<k>`` columns to ``.obs``.
    NOTE(review): ``get_latent_UMAP`` presumably also writes the latent embedding
    into ``adata.obsm`` — confirm.

    :param n_hidden: hidden units per layer for every model (default 512).
    :param bio_weight: weight of the bio-conservation score in the overall score.
    :param batch_weight: weight of the batch-removal score in the overall score.
    :return: ``[ari, asw_cell, kbet, asw_batch, bio_score, batch_score, overall_score]``,
        each a list with one entry per grid configuration.
    """
    ari_collection = []
    asw_batch_collection = []
    kbet_collection = []
    asw_collection = []

    for params in grid_search_list:
        n_layers, n_latent, lr = params[0], params[1], params[2]
        model = model_train(
            adata,
            layer_key=layer_key,
            batch_key=batch_key,
            n_layers=n_layers,
            n_hidden=n_hidden,
            n_latent=n_latent,
            lr=lr,
        )
        print("VAE layers of %s, num latents of %s, lr of %s" % (n_layers, n_latent, lr))
        latent_key = get_latent_UMAP(
            model, adata, batch_key, label_key,
            added_latent_key='X_VAE_{}'.format(params),
            print_UMAP=True,
        )
        kbet_score, ari_score, asw_score, asw_batch_score = kbet_rni_asw(
            adata,
            latent_key=latent_key,
            batch_key=batch_key,
            label_key=label_key,
            group_key=group_key,
            max_clusters=max_clusters,
        )
        # Each score dict holds a single entry for this configuration.
        ari_collection.extend(ari_score.values())
        asw_batch_collection.extend(asw_batch_score.values())
        kbet_collection.extend(kbet_score.values())
        asw_collection.extend(asw_score.values())

    if not ari_collection:
        # Empty grid: nothing was trained, so there is nothing to scale.
        return [ari_collection, asw_collection, kbet_collection, asw_batch_collection, [], [], []]

    ari_mn = np.asarray(max_min_scale(ari_collection), dtype=float)
    asw_batch_mn = np.asarray(max_min_scale(asw_batch_collection), dtype=float)
    kbet_mn = np.asarray(max_min_scale(kbet_collection), dtype=float)
    asw_mn = np.asarray(max_min_scale(asw_collection), dtype=float)

    # Batch removal score combines kBET and batch ASW; bio-conservation score
    # (cell-type keeping) combines ARI and cell-type ASW.
    bio_scores = (ari_mn + asw_mn) / 2
    batch_scores = (kbet_mn + asw_batch_mn) / 2
    overall_scores = bio_weight * bio_scores + batch_weight * batch_scores

    return [ari_collection, asw_collection, kbet_collection, asw_batch_collection,
            bio_scores.tolist(), batch_scores.tolist(), overall_scores.tolist()]

In [None]:
# Load the dataset stored in Immune_ALL_human.h5ad (path relative to the working directory).
idata = sc.read_h5ad("Immune_ALL_human.h5ad")

In [None]:
# Keep cells whose 'final_annotation' is one of the labels at positions 1 and 5 of
# idata.obs['final_annotation'].unique() (order of first appearance, so these
# positions depend on row order). NOTE(review): confirm these are the intended labels.
mdata = malignant_cell_collection(idata, malignant_cell_incices = [1,5], label_key = 'final_annotation')

In [None]:
# In place: keep counts in layers['counts'], normalize + log1p X, and subset to
# 2000 batch-aware highly variable genes.
model_preprocessing(mdata, 'batch')

In [None]:
# Train and score one scVI model per grid configuration (slow: one training run each).
score = hyperparameter_tuning_general_scvi(mdata, layer_key = 'counts', batch_key= 'batch', label_key= 'final_annotation', group_key= 'final_annotation', max_clusters = 8, grid_search_list = grid_search_list)

In [None]:
def convert_scorelist_into_df(scorelist, variable_name, store, csv_file_name):
    """Tabulate the seven score lists (rows) against grid configurations (columns).

    :param scorelist: the seven lists returned by the tuning function, in row order.
    :param variable_name: one column label per grid configuration.
    :param store: when truthy, also write the table to ``csv_file_name``.
    :return: the resulting ``pd.DataFrame``.
    """
    metric_rows = ["ari", "asw_cell", "kbet", "asw_batch", "bio_score", "batch_score", "overall_score"]
    table = pd.DataFrame(data=scorelist, index=metric_rows, columns=variable_name)
    if not store:
        return table
    table.to_csv(csv_file_name)
    return table

In [None]:
# Tabulate the immune-dataset scores (rows = metrics, columns = configurations) and write them to CSV.
score_pd = convert_scorelist_into_df(score, col_name, True,'grid_general_scvi_all.csv')

In [None]:
# Lung dataset: keep the labels at positions 0 and 2 of obs['cell_type'].unique()
# (order of first appearance), then preprocess in place with 'batch' as batch key.
ldata = sc.read_h5ad("Lung_atlas_public.h5ad")
mldata = malignant_cell_collection(ldata, malignant_cell_incices = [0,2], label_key = 'cell_type')
model_preprocessing(mldata, 'batch')

In [None]:
# Run the same scVI grid search on the lung subset.
score_lung = hyperparameter_tuning_general_scvi(mldata, layer_key = 'counts', batch_key= 'batch', label_key= 'cell_type', group_key= 'cell_type', max_clusters = 8, grid_search_list = grid_search_list)

In [None]:
# Tabulate the lung scores and write them to CSV.
score_lung_pd = convert_scorelist_into_df(score_lung, col_name, True,'grid_general_scvi_lung.csv')

In [None]:
# Load the dataset stored in human_pancreas_norm_complexBatch.h5ad.
pdata = sc.read_h5ad("human_pancreas_norm_complexBatch.h5ad")

In [None]:
# Pancreas dataset: labels live in obs['celltype'] and the batch is the sequencing
# technology column obs['tech']. Subset, preprocess, grid-search and save the scores.
mpdata = malignant_cell_collection(pdata, malignant_cell_incices = [0,2], label_key = 'celltype')
model_preprocessing(mpdata, 'tech')
score_pancreas = hyperparameter_tuning_general_scvi(mpdata, layer_key = 'counts', batch_key= 'tech', label_key= 'celltype', group_key= 'celltype', max_clusters = 8, grid_search_list = grid_search_list)
score_pancreas_pd = convert_scorelist_into_df(score_pancreas, col_name, True,'grid_general_scvi_pancreas.csv')