# What this code does: top-gene selection --> imputation --> clustering
# Imputation methods included: scVI, gimVI and Tangram (more imputation methods can easily be added to the 'ImputationEvaluator' class).
1. Gene selection criteria: all genes, top 2000 genes, top 5000 genes.
2. Leiden clustering is used to cluster the (imputed) expression data.
3. gimVI and Tangram are designed for spatial transcriptomics (ST) data but require a matching scRNA-seq reference dataset. The ST datasets used here have no corresponding scRNA-seq data, so each ST dataset is also used in place of its scRNA-seq reference.

The code can be run in three modes:
1) Full Batch Mode: looks in a directory for '.h5ad' files, runs each of them with all genes, the top 2000 genes and the top 5000 genes, and then saves the results to a CSV file.
2) Semi Batch Mode: looks in a directory for '.h5ad' files; an external parameter sets how many datasets are processed per run, each with all genes, the top 2000 genes and the top 5000 genes.
3) Multi-batch Mode: looks in a directory for '.h5ad' files; two external parameters can be set. The first decides how many datasets are processed per run, and the second decides how many of the gene-selection settings (all genes, top 2000 genes, top 5000 genes) are processed for each dataset.

These modes exist for convenience: depending on the machine, running every dataset and gene-selection setting in one go can take a very long time.

In [None]:
# @title mount drive
# Google Drive mounting only exists inside Google Colab. The dataset paths used
# further down point at a local Windows drive, so when the notebook runs outside
# Colab the mount is skipped instead of aborting the whole notebook.
try:
    from google.colab import drive
except ImportError:
    drive = None
    print("google.colab is not available; skipping the Google Drive mount.")
else:
    drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# @title installing packages
# Installs the third-party packages imported in the next cell. The leading `!`
# is IPython/Jupyter shell syntax, so this cell only runs inside a notebook
# (Jupyter/Colab), not as a plain Python script.
# igraph and leidenalg are not imported directly; they back scanpy's Leiden
# clustering (sc.tl.leiden).
# NOTE(review): magic-impute and fancyimpute are imported, but the active
# imputation methods are scVI, gimVI and tangram. They may only be needed for
# the disabled MAGIC/SoftImpute experiments; confirm before removing.
!pip install scanpy
!pip install magic-impute
!pip install igraph
!pip3 install leidenalg
!pip install fancyimpute
!pip install -U scvi-tools
!pip install tangram-sc

Collecting scanpy
  Downloading scanpy-1.11.1-py3-none-any.whl.metadata (9.9 kB)
Collecting anndata>=0.8 (from scanpy)
  Downloading anndata-0.11.4-py3-none-any.whl.metadata (9.3 kB)
Collecting legacy-api-wrap>=1.4 (from scanpy)
  Downloading legacy_api_wrap-1.4.1-py3-none-any.whl.metadata (2.1 kB)
Collecting scikit-learn<1.6.0,>=1.1 (from scanpy)
  Downloading scikit_learn-1.5.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (13 kB)
Collecting session-info2 (from scanpy)
  Downloading session_info2-0.1.2-py3-none-any.whl.metadata (2.5 kB)
Collecting array-api-compat!=1.5,>1.4 (from anndata>=0.8->scanpy)
  Downloading array_api_compat-1.11.2-py3-none-any.whl.metadata (1.9 kB)
Downloading scanpy-1.11.1-py3-none-any.whl (2.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m13.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading anndata-0.11.4-py3-none-any.whl (144 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m144.

In [None]:
# @title Importing packages

# System & utility
import os
import time
import psutil
import tracemalloc
from tqdm import tqdm
import torch
import gc

# Core scientific packages
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Machine learning and preprocessing
from sklearn import metrics
from sklearn.impute import KNNImputer, SimpleImputer
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import (
    normalized_mutual_info_score,
    adjusted_rand_score,
    adjusted_mutual_info_score,
    homogeneity_score
)

# Matrix and sparse operations
from scipy import sparse
from scipy.sparse import issparse, csr_matrix
from scipy.sparse.linalg import svds

# Imputation methods
import magic
from fancyimpute import SoftImpute

# Single-cell packages
import scanpy as sc
from scvi.external import GIMVI
import tangram as tg
import scvi

# Uncomment these if needed later
# import scarches as sca
# import dca


In [3]:
# @title Imputation Evaluator

class ImputationEvaluator:
    """Measure how imputation affects Leiden clustering of one spatial dataset.

    For a single ``.h5ad`` file this class:
      1. loads the data, densifies ``X`` and optionally keeps only the top
         highly-variable genes (``load_data``);
      2. clusters the un-imputed data and scores the clusters against the
         ground-truth labels (baseline);
      3. for each registered imputation method, imputes, re-clusters and
         scores the data, and records runtime, traced peak memory and the
         share of zero entries (``run_full_evaluation``).

    Args:
        dataset_path: Path to an ``.h5ad`` file whose ``obs`` has an
            ``'annotation'`` or ``'CellType'`` column (ground-truth labels).
        n_top_genes: Number of highly-variable genes to keep, or ``'all'``
            (or ``None``) to keep every gene. Numeric strings such as
            ``'2000'`` are accepted.
    """

    # Number of Leiden runs whose scores are averaged per configuration.
    # Each run uses its own random seed (see _mean_clustering_scores).
    N_CLUSTERING_RUNS = 5

    def __init__(self, dataset_path, n_top_genes=2000):
        self.dataset_path = dataset_path
        self.dataset_name = os.path.basename(dataset_path)

        # 'all' / None keep every gene. Anything else is parsed as an integer
        # count (the Multi-batch mode passes numeric strings such as '2000').
        if n_top_genes is None or n_top_genes == 'all':
            self.n_top_genes = None
        else:
            self.n_top_genes = int(n_top_genes)

        self.load_data()

    def load_data(self):
        """Load the dataset, densify ``X`` and optionally keep the top HVGs.

        Sets ``self.adata``, ``self.annotation_key``,
        ``self.size_after_preprocessing`` and ``self.size_after_top_genes``.

        Raises:
            ValueError: if ``obs`` has neither an ``'annotation'`` nor a
                ``'CellType'`` column.
        """
        print(f"\nLoading dataset: {self.dataset_path}")
        self.adata = sc.read_h5ad(self.dataset_path)

        # Densify X so every later step works on a plain ndarray.
        if issparse(self.adata.X):
            self.adata.X = self.adata.X.toarray()

        print(f"In the Imputation Evaluator the n_top_genes: {self.n_top_genes}\n\n")
        print(f"Original adata shape: {self.adata.shape}")

        # Determine the ground-truth label key.
        if 'annotation' in self.adata.obs.columns:
            self.annotation_key = 'annotation'
        elif 'CellType' in self.adata.obs.columns:
            self.annotation_key = 'CellType'
        else:
            raise ValueError(f"No 'annotation' or 'CellType' found in obs columns for {self.dataset_name}")

        # No filtering/normalisation happens here, so this is the loaded shape.
        self.size_after_preprocessing = self.adata.shape

        if self.n_top_genes is not None:
            print(f"top genes the class got: {self.n_top_genes}\n")
            print(f"type of this var: {type(self.n_top_genes)}\n")
            # NOTE(review): scanpy documents flavor="seurat" as expecting
            # log-normalised data, but X is used here exactly as loaded.
            # Confirm that this is intended.
            sc.pp.highly_variable_genes(self.adata, flavor="seurat", n_top_genes=self.n_top_genes)
            # .copy() materialises the subset. A view would be copied
            # implicitly (with a warning) the first time PCA/Leiden write to it.
            self.adata = self.adata[:, self.adata.var['highly_variable']].copy()
            self.size_after_top_genes = self.adata.shape
            print(f"Adata shape after selecting top {self.n_top_genes} genes: {self.adata.shape}")
        else:
            self.size_after_top_genes = self.adata.shape

    @staticmethod
    def calculate_sparsity(X):
        """Return the percentage (0-100) of entries of ``X`` that are exactly zero.

        Accepts dense arrays, array-likes (e.g. DataFrames) and scipy sparse
        matrices. Sparse input is handled without densifying it, and
        explicitly stored zeros count as zeros. Returns NaN for an empty
        matrix, where the percentage is undefined.
        """
        if issparse(X):
            total_elements = X.shape[0] * X.shape[1]
            zero_elements = total_elements - X.count_nonzero()
        else:
            dense = np.asarray(X)
            total_elements = dense.size
            zero_elements = int(np.count_nonzero(dense == 0))

        if total_elements == 0:
            return float('nan')
        return 100.0 * float(zero_elements) / float(total_elements)

    @staticmethod
    def scVI_impute(adata_imputed, max_epochs=350, latent_dim=32, use_gpu=True):
        """
        Train scVI and return its imputed (normalized) expression matrix.

        Parameters:
            adata_imputed : AnnData
                AnnData object containing raw counts or approximated counts.
                It is copied and never modified.
            max_epochs : int
                Maximum number of training epochs.
            latent_dim : int or None
                Dimension of the VAE latent space (None = scVI default).
            use_gpu : bool
                Whether to train on CUDA when it is available.

        Returns:
            imputed_expression : np.ndarray
                Dense cells x genes matrix of normalized expression.
        """
        # Leave one core free, but never request fewer than one thread:
        # os.cpu_count() may return None, and n - 1 is 0 on a single core.
        torch.set_num_threads(max(1, (os.cpu_count() or 1) - 1))

        adata_scVI = adata_imputed.copy()
        SCVI = scvi.model.SCVI

        print("Setting up scVI...")
        SCVI.setup_anndata(adata_scVI)

        print("Creating scVI model...")
        if latent_dim is not None:
            vae = SCVI(adata_scVI, n_latent=latent_dim)
        else:
            vae = SCVI(adata_scVI)

        device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
        print(f"Training on {device}...")
        vae.train(max_epochs=max_epochs, accelerator=device, devices=1)

        print("Fetching normalized expression...")
        imputed_values = vae.get_normalized_expression()
        print("Finished fetching.")

        # get_normalized_expression returns a pandas DataFrame by default.
        # Hand back a plain array so the caller can assign it to adata.X.
        return np.asarray(imputed_values)

    @staticmethod
    def gimVI_impute(adata_imputed, max_epochs=400, n_latent=40, use_gpu=True):
        """
        Train gimVI and return the imputed expression as a CSR matrix.

        gimVI needs a scRNA-seq reference. No matching reference exists for
        these datasets, so two copies of the same AnnData are used as the
        "sequencing" and the "spatial" input. The input is never modified.

        Returns:
            scipy.sparse.csr_matrix built from the second array returned by
            ``GIMVI.get_imputed_values``. Given the (seq, spatial) order passed
            to ``GIMVI``, this is presumably the spatial imputation.
        """
        adata_seq = adata_imputed.copy()
        spatial_adata = adata_imputed.copy()

        print("Setting up gimVI...")
        GIMVI.setup_anndata(adata_seq)
        GIMVI.setup_anndata(spatial_adata)

        print("Creating gimVI model...")
        model = GIMVI(adata_seq, spatial_adata, n_latent=n_latent)

        accelerator = "gpu" if use_gpu and torch.cuda.is_available() else "cpu"
        print(f"Training on {accelerator}.....")
        model.train(max_epochs=max_epochs, accelerator=accelerator, devices=1)

        print("Fetching normalized expression....")
        _, imputed_values = model.get_imputed_values()
        imputed_values = csr_matrix(imputed_values)
        print("Finished fetching!")

        return imputed_values

    @staticmethod
    def tangram_impute(adata_imputed, max_epochs=200, use_gpu=True):
        """
        Map cells to space with Tangram and return the projected expression.

        As with gimVI, two copies of the same AnnData stand in for the
        single-cell and the spatial data. Both copies are densified because
        Tangram is given dense matrices here. The input is never modified.

        Returns:
            scipy.sparse.csr_matrix of the expression projected by
            ``tg.project_genes``.
        """
        adata_seq = adata_imputed.copy()
        spatial_adata = adata_imputed.copy()

        if issparse(adata_seq.X):
            adata_seq.X = adata_seq.X.toarray()
        if issparse(spatial_adata.X):
            spatial_adata.X = spatial_adata.X.toarray()

        # Both objects are copies of the same data, so every gene is shared.
        markers = list(set(adata_seq.var_names) & set(spatial_adata.var_names))

        print("Setting up tangram...")
        tg.pp_adatas(adata_seq, spatial_adata, genes=markers)

        device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
        print(f"Training on {device}.....")
        ad_map = tg.map_cells_to_space(
            adata_seq,
            spatial_adata,
            mode="cells",
            density_prior="rna_count_based",
            num_epochs=max_epochs,
            device=device,
        )

        print("Projecting gene expression...")
        ad_ge = tg.project_genes(adata_map=ad_map, adata_sc=adata_seq)
        print("Done.")

        return csr_matrix(ad_ge.X)

    # To add another imputation method: define a @staticmethod that takes an
    # AnnData and returns the imputed cells x genes matrix, then register it
    # in the `imputation_methods` dict in run_full_evaluation.

    def perform_clustering(self, adata, cluster_key, random_state=0):
        """Run PCA, neighbours, UMAP and Leiden on ``adata`` in place.

        The Leiden labels are written to ``adata.obs[cluster_key]``.
        ``random_state`` seeds Leiden. The default 0 is also scanpy's default.
        """
        sc.pp.pca(adata)
        sc.pp.neighbors(adata)
        sc.tl.umap(adata)
        sc.tl.leiden(adata, key_added=cluster_key, directed=False, n_iterations=2,
                     random_state=random_state)

    def perform_clustering_with_plot(self, adata, cluster_key, dataset_name=None, n_top_genes=2000, random_state=0):
        """Cluster ``adata`` like ``perform_clustering`` and show a UMAP coloured by cluster.

        Args:
            adata: AnnData modified in place (PCA/UMAP embeddings, cluster labels).
            cluster_key: ``obs`` column that receives the Leiden labels.
            dataset_name: Optional name used in the plot title.
            n_top_genes: Gene-selection setting shown in the title (None = full data).
            random_state: Seed passed to Leiden.
        """
        # Step 1: dimensionality reduction and clustering.
        sc.pp.pca(adata)
        sc.pp.neighbors(adata)
        sc.tl.umap(adata)
        sc.tl.leiden(adata, key_added=cluster_key, directed=False, n_iterations=2,
                     random_state=random_state)

        # Step 2: title for the UMAP plot.
        if n_top_genes is not None:
            title = f"UMAP - {dataset_name} | Top {n_top_genes} genes" if dataset_name and n_top_genes else "UMAP"
        else:
            title = f"UMAP - {dataset_name} | Full Data" if dataset_name else "UMAP"

        adata.obs[cluster_key] = adata.obs[cluster_key].astype("category")

        # Step 3: show the UMAP plot.
        sc.pl.umap(
            adata,
            color=cluster_key,
            title=title,
            legend_loc='on data',
            frameon=False,
            show=True
        )

        # Disabled for the Visium datasets: spatial plot of the predicted labels.
        # if n_top_genes is not None:
        #     title = f"Predicted Spatial Plot - {dataset_name} | Top {n_top_genes} genes" if dataset_name and n_top_genes else "Predicted Spatial Plot"
        # else:
        #     title = f"Predicted Spatial Plot - {dataset_name} | Full Data" if dataset_name else "Predicted Spatial Plot"
        # sc.pl.spatial(adata, color=cluster_key, title=title, spot_size=2)

    def plot_spatial_with_predicted_labels(self, adata, annotation_key=None, dataset_name=None, s=4, n_top_genes=2000):
        """
        Scatter-plot the 2D spatial coordinates coloured by ``adata.obs[annotation_key]``.

        Parameters:
        - adata: AnnData object with adata.obsm['spatial'] and adata.obs[annotation_key]
        - annotation_key: obs column to colour by. Pass the cluster key to show
          predicted labels. None falls back to the ground-truth key.
        - dataset_name: Optional dataset name for the plot title
        - s: Marker size
        - n_top_genes: Gene-selection setting shown in the title (None = full data)
        """
        key = self.annotation_key if annotation_key is None else annotation_key

        # Step 1: spatial coordinates (first two columns).
        spatial_coords = adata.obsm['spatial']
        x = spatial_coords[:, 0]
        y = spatial_coords[:, 1]

        # Step 2: encode categorical/string labels as integers for colouring.
        labels = adata.obs[key]
        if labels.dtype.name == 'category' or labels.dtype == object:
            le = LabelEncoder()
            color_labels = le.fit_transform(labels)
            label_names = le.classes_
        else:
            color_labels = labels
            label_names = np.unique(labels)

        # Step 3: plot.
        genes_desc = "Full Data" if n_top_genes is None else f"{n_top_genes} top genes"
        plt.figure(figsize=(8, 6))
        scatter = plt.scatter(x, y, c=color_labels, cmap='tab20', s=s, alpha=0.8)
        plt.gca().invert_yaxis()  # Optional: if spatial coordinates are top-down
        plt.title(f"Spatial Plot (Predicted Labels) - {dataset_name}|{genes_desc}" if dataset_name else "Spatial Plot (Predicted Labels)")
        plt.xlabel("Spatial 1")
        plt.ylabel("Spatial 2")

        # Legend entries use the same colour mapping as the scatter.
        handles = [plt.Line2D([0], [0], marker='o', color='w', label=label,
                              markerfacecolor=scatter.cmap(scatter.norm(i)), markersize=8)
                   for i, label in enumerate(label_names)]
        plt.legend(handles=handles, bbox_to_anchor=(1.05, 1), loc='upper left')

        plt.tight_layout()
        plt.show()

    def plot_spatial_with_true_labels(self, adata, annotation_key=None, dataset_name=None, spot_size=2):
        """
        Plot spatial coordinates with scanpy's ``sc.pl.spatial``, coloured by labels.

        Parameters:
        - adata: AnnData object with adata.obsm['spatial'] and adata.obs[annotation_key]
        - annotation_key: obs column with the labels. None uses the dataset's
          ground-truth key ('annotation' or 'CellType').
        - dataset_name: Optional dataset name for title
        - spot_size: Marker size for plotting

        Raises:
            ValueError: if ``annotation_key`` is not an ``obs`` column.
        """
        if annotation_key is None:
            annotation_key = self.annotation_key
        if annotation_key not in adata.obs:
            raise ValueError(f"'{annotation_key}' not found in adata.obs")

        sc.pl.spatial(
            adata,
            color=annotation_key,
            title=f"Spatial Plot - {dataset_name}" if dataset_name else "Spatial Plot (True Labels)",
            spot_size=spot_size
        )

    def evaluate_clustering(self, adata, cluster_key):
        """Return (ARI, NMI, AMI, homogeneity) of ``obs[cluster_key]`` against the ground truth."""
        true_labels = adata.obs[self.annotation_key]
        predicted_labels = adata.obs[cluster_key]

        ari = adjusted_rand_score(true_labels, predicted_labels)
        nmi = normalized_mutual_info_score(true_labels, predicted_labels)
        ami = adjusted_mutual_info_score(true_labels, predicted_labels)
        homo = homogeneity_score(true_labels, predicted_labels)

        return ari, nmi, ami, homo

    def _mean_clustering_scores(self, adata, cluster_key, run_label, plot_predicted=False):
        """Cluster ``adata`` N_CLUSTERING_RUNS times and return the mean (ARI, NMI, AMI, HOMO).

        Run ``i`` seeds Leiden with ``random_state=i``. scanpy seeds Leiden
        with 0 by default, so without a varying seed every repetition would
        reproduce the same partition and the mean would just repeat one score.
        """
        scores = []
        for run in range(self.N_CLUSTERING_RUNS):
            print(f"{run_label} {run + 1}...")
            self.perform_clustering_with_plot(adata, cluster_key=cluster_key, dataset_name=self.dataset_name,
                                              n_top_genes=self.n_top_genes, random_state=run)
            if plot_predicted:
                self.plot_spatial_with_predicted_labels(adata, annotation_key=cluster_key,
                                                        dataset_name=self.dataset_name,
                                                        n_top_genes=self.n_top_genes)
            scores.append(self.evaluate_clustering(adata, cluster_key))
        ari, nmi, ami, homo = np.mean(scores, axis=0)
        return ari, nmi, ami, homo

    def run_full_evaluation(self):
        """
        Run baseline clustering, then every imputation method followed by clustering.

        Returns:
            dict with the raw zero percentage, the mean baseline scores
            ('Base ARI', ...) and, per method, the mean scores, the zero
            percentage after imputation, the runtime (s) and the traced peak
            memory (MB).
        """
        results = {}

        # Registered imputation methods: name -> callable(AnnData) -> matrix.
        imputation_methods = {
            'scVI': self.scVI_impute,
            'gimVI': self.gimVI_impute,
            'tangram': self.tangram_impute
        }

        results['Raw zero Exp val (%)'] = self.calculate_sparsity(self.adata.X)

        # Ground-truth spatial plot.
        print(f"True Labels plot for the dataset {self.dataset_name}..\n")
        self.plot_spatial_with_true_labels(self.adata, annotation_key=self.annotation_key,
                                           dataset_name=self.dataset_name)

        # Baseline clustering on the un-imputed data.
        print("performing baseline clustering..\n")
        base_ari, base_nmi, base_ami, base_homo = self._mean_clustering_scores(
            self.adata, "clusters_original", run_label="Base Clustering Run")
        results.update({
            'Base ARI': base_ari,
            'Base NMI': base_nmi,
            'Base AMI': base_ami,
            'Base HOMO': base_homo,
        })

        for method_name, imputation_function in tqdm(imputation_methods.items(), desc="Running Imputations", leave=False):
            adata_imputed = self.adata.copy()

            # float32, not float16: float16 represents integers exactly only
            # up to 2048 and overflows to inf above 65504, which would corrupt
            # the count input of the imputation models.
            X = adata_imputed.X
            X = X.toarray() if issparse(X) else np.asarray(X)
            adata_imputed.X = X.astype(np.float32)

            # tracemalloc only traces allocations made through Python's memory
            # allocators. Memory held by native libraries (e.g. PyTorch/GPU
            # tensors) is likely not included, so treat this as a lower bound.
            tracemalloc.start()
            start_time = time.perf_counter()
            try:
                adata_imputed.X = imputation_function(adata_imputed)
                runtime = time.perf_counter() - start_time
                _, peak = tracemalloc.get_traced_memory()
            finally:
                tracemalloc.stop()
            memory = peak / (1024 ** 2)  # MB

            cluster_key = f"clusters_{method_name}"
            print(f"performing clustering on the {method_name} imputation..\n")
            ari, nmi, ami, homo = self._mean_clustering_scores(
                adata_imputed, cluster_key, run_label="Clustering Run", plot_predicted=True)

            results.update({
                f'ARI_{method_name}': ari,
                f'NMI_{method_name}': nmi,
                f'AMI_{method_name}': ami,
                f'HOMO_{method_name}': homo,
                f'{method_name} zero Exp val (%)': self.calculate_sparsity(adata_imputed.X),
                f'{method_name} Runtime (s)': runtime,
                f'{method_name} Memory (MB)': memory,
            })

        return results


In [14]:
# @title run_on_multiple_datasets

def run_on_multiple_datasets(folder_path, n_top_genes=2000):
    """Evaluate every ``.h5ad`` dataset in ``folder_path`` with one gene-selection setting.

    Datasets are processed in sorted filename order, so repeated runs produce
    rows in the same order (``os.listdir`` order is arbitrary). Each dataset
    is evaluated through ``run_on_single_dataset``, so both entry points
    produce identical rows. Nothing is written to disk; the caller saves the
    returned DataFrame.

    Args:
        folder_path: Directory searched (non-recursively) for ``.h5ad`` files.
        n_top_genes: Number of top highly-variable genes, or ``'all'``.

    Returns:
        pandas.DataFrame with one row per dataset (empty if none were found).
    """
    datasets = sorted(f for f in os.listdir(folder_path) if f.endswith('.h5ad'))

    rows = []
    for dataset_name in tqdm(datasets, desc="Datasets"):
        dataset_path = os.path.join(folder_path, dataset_name)
        rows.append(run_on_single_dataset(dataset_path, n_top_genes))

    if not rows:
        return pd.DataFrame()
    return pd.concat(rows, ignore_index=True)

In [4]:
# @title run_on_single_dataset

def run_on_single_dataset(dataset_path, n_top_genes):
    """Evaluate one dataset at one gene-selection setting and return a one-row DataFrame.

    The row starts with fixed bookkeeping columns (name, matrix sizes, baseline
    scores, raw zero percentage, clustering algorithm, ``top_genes``). These
    are followed by every remaining entry of ``run_full_evaluation``'s result
    dict, in the order it produced them. Nothing is written to disk; callers
    save or append the row themselves.
    """
    evaluator = ImputationEvaluator(dataset_path, n_top_genes)
    results = evaluator.run_full_evaluation()

    n_obs_pre, n_vars_pre = evaluator.size_after_preprocessing
    n_obs_top, n_vars_top = evaluator.size_after_top_genes

    row = {
        'Dataset Name': evaluator.dataset_name,
        'Size After pre-processing': f"{n_obs_pre}x{n_vars_pre}",
        'Size After selecting top genes': f"{n_obs_top}x{n_vars_top}",
    }
    for metric in ('Base ARI', 'Base NMI', 'Base AMI', 'Base HOMO', 'Raw zero Exp val (%)'):
        row[metric] = results[metric]
    row['Cluster Algo'] = 'Leiden'
    row['top_genes'] = n_top_genes

    # Per-imputation-method entries follow the fixed columns.
    row.update({key: value for key, value in results.items() if key not in row})

    return pd.DataFrame([row])



In [5]:
# @title post processing after evaluating imputations

def reorder_columns(df):
    """Return ``df`` with its columns grouped in a fixed, readable order.

    Order: the bookkeeping columns (those present), then the columns whose
    name contains 'ARI', 'NMI', 'AMI', 'HOMO', 'zero', 'Runtime' and
    'Memory' (one group per substring, each in ``df``'s own column order),
    then everything else. A column matching several substrings is placed
    only once, in its first group, so the result never has duplicate columns.
    """
    base_keys = ['Dataset Name', 'Size After pre-processing', 'Size After selecting top genes', 'Cluster Algo', 'top_genes']
    metric_patterns = ['ARI', 'NMI', 'AMI', 'HOMO', 'zero', 'Runtime', 'Memory']

    ordered_cols = [col for col in base_keys if col in df.columns]
    placed = set(ordered_cols)

    for pattern in metric_patterns:
        for col in df.columns:
            if col not in placed and pattern in col:
                ordered_cols.append(col)
                placed.add(col)

    # Anything not matched above keeps its original relative order at the end.
    ordered_cols += [col for col in df.columns if col not in placed]
    return df[ordered_cols]


In [None]:
# @title main function to run in Full Batch Mode

if __name__ == "__main__":
    dataset_path = 'D:/VM Data/thesis/Analysis/visium/'
    n_top = ['all', 2000, 5000]
    # Results file for the scVI / gimVI / tangram evaluation (same name as the
    # Multi-batch mode). The previous name, "visium_magic_knn_soft_simple_
    # imputation_results.csv", appears to belong to an earlier MAGIC/KNN/
    # SoftImpute/SimpleImputer experiment, and writing here overwrote it.
    output_csv = "visium_scVI_GIMVI_tangram_imputation_results.csv"

    # One DataFrame per gene-selection setting, merged below.
    all_results = []
    for n in n_top:
        all_results.append(run_on_multiple_datasets(dataset_path, n_top_genes=n))

    final_df = pd.concat(all_results, ignore_index=True)

    if final_df.empty:
        print(f"No '.h5ad' files found in {dataset_path}; nothing to save.")
    else:
        # Sort by dataset name, then by top_genes value.
        final_df = final_df.sort_values(by=["Dataset Name", "top_genes"]).reset_index(drop=True)
        df_results_reordered = reorder_columns(final_df)

        output_path = os.path.join(dataset_path, output_csv)
        df_results_reordered.to_csv(output_path, index=False)
        print(f"\nAll results saved to {output_path}")



In [None]:
# @title main function to run in Semi Batch Mode


if __name__ == "__main__":

    dataset_folder = 'D:/VM Data/thesis/Analysis/visium/'
    # Same results file as the Multi-batch mode. The previous name referred to
    # an earlier MAGIC/KNN/SoftImpute/SimpleImputer run. Reading it made this
    # mode treat datasets as already processed, and appending to it mixed
    # unrelated columns.
    output_csv = os.path.join(dataset_folder, "visium_scVI_GIMVI_tangram_imputation_results.csv")
    n_top = ['all', 2000, 5000]

    batch_size = 1  # Change this to control how many datasets to process per run

    all_datasets = sorted(f for f in os.listdir(dataset_folder) if f.endswith('.h5ad'))

    # top_genes settings already stored per dataset, compared as strings
    # because the CSV mixes 'all' with numbers.
    done_tops = {}
    if os.path.exists(output_csv):
        df_done = pd.read_csv(output_csv)
        for name, group in df_done.groupby('Dataset Name'):
            done_tops[name] = set(group['top_genes'].astype(str))

    # Dataset -> settings it still needs. A dataset counts as processed only
    # when every setting in n_top is present.
    missing_tops = {}
    for fname in all_datasets:
        todo = [n for n in n_top if str(n) not in done_tops.get(fname, set())]
        if todo:
            missing_tops[fname] = todo

    processed_datasets = {f for f in all_datasets if f not in missing_tops}
    print(processed_datasets)

    remaining_datasets = list(missing_tops)
    print(remaining_datasets)

    if not remaining_datasets:
        print("✅ All datasets are already processed and saved in the CSV.\n")
    else:
        print(f"🟡 Found {len(remaining_datasets)} remaining datasets.\n")

        for dataset_filename in tqdm(remaining_datasets[:batch_size], desc="Processing Datasets in Batch-mode"):
            dataset_path = os.path.join(dataset_folder, dataset_filename)

            all_results = []
            for n in missing_tops[dataset_filename]:
                print(f"Analyzing dataset {dataset_filename} for top genes {n}\n")
                all_results.append(run_on_single_dataset(dataset_path, n_top_genes=n))
            print(f"✔ Analyzing finished for {dataset_filename}\n")
            print(f"save the results for the dataset {dataset_filename} in the csv\n")

            new_rows = pd.concat(all_results, ignore_index=True)

            # Merge with the existing rows by column name and rewrite the file.
            # A header-less append would silently misalign columns whenever
            # their order differs from the existing header.
            if os.path.exists(output_csv):
                new_rows = pd.concat([pd.read_csv(output_csv), new_rows], ignore_index=True)
            new_rows['top_genes'] = new_rows['top_genes'].astype(str)
            merged = new_rows.sort_values(by=["Dataset Name", "top_genes"]).reset_index(drop=True)
            reorder_columns(merged).to_csv(output_csv, index=False)

            print(f"Results saved for the dataset {dataset_filename} in the csv!!\n")



In [None]:
# @title main function to run in Multi-batch Mode

if __name__ == "__main__":

    dataset_folder = 'D:/VM Data/thesis/Analysis/visium/'
    output_csv = os.path.join(dataset_folder, "visium_scVI_GIMVI_tangram_imputation_results.csv")

    all_n_top = ['all', 2000, 5000]
    batch_size = 1            # Number of datasets per batch
    n_top_batch_size = 1      # Number of top_genes values per dataset per batch

    all_datasets = sorted(f for f in os.listdir(dataset_folder) if f.endswith('.h5ad'))

    def _pending_top_genes(csv_path):
        """Map each dataset in all_datasets to its top_genes settings (as str) missing from csv_path."""
        done = {}
        if os.path.exists(csv_path):
            df = pd.read_csv(csv_path)
            for name, group in df.groupby('Dataset Name'):
                done[name] = set(group['top_genes'].astype(str))
        pending = {}
        for fname in all_datasets:
            missing = [str(t) for t in all_n_top if str(t) not in done.get(fname, set())]
            if missing:
                pending[fname] = missing
        return pending

    work_dict = _pending_top_genes(output_csv)

    if not work_dict:
        print("✅ All datasets and top_genes combinations are already processed.\n")
    else:
        print(f"🟡 {len(work_dict)} datasets still need processing:")
        for k, v in work_dict.items():
            print(f"  🔸 {k}: Missing top_genes → {v}")
        print()

    selected_datasets = list(work_dict)[:batch_size]

    for dataset_filename in tqdm(selected_datasets, desc="Processing Datasets in Batch-mode"):
        dataset_path = os.path.join(dataset_folder, dataset_filename)
        # Slicing already copes with fewer remaining settings than n_top_batch_size.
        top_genes_list = work_dict[dataset_filename][:n_top_batch_size]

        all_results = []
        for n in top_genes_list:
            print(f"🔍 Analyzing dataset {dataset_filename} for top genes {n}")
            all_results.append(run_on_single_dataset(dataset_path, n_top_genes=n))

        if all_results:
            new_rows = pd.concat(all_results, ignore_index=True)

            # Merge with the existing rows by column name and rewrite the file.
            # A header-less append would silently misalign columns whenever
            # their order differs from the existing header.
            if os.path.exists(output_csv):
                new_rows = pd.concat([pd.read_csv(output_csv), new_rows], ignore_index=True)
            new_rows['top_genes'] = new_rows['top_genes'].astype(str)
            full_df = new_rows.sort_values(by=["Dataset Name", "top_genes"]).reset_index(drop=True)
            reorder_columns(full_df).to_csv(output_csv, index=False)

            print(f"✅ Results saved for {dataset_filename}\n")

    # Show what is still left after this run.
    print("🔄 Remaining datasets and top_genes still to process:")
    if os.path.exists(output_csv):
        for fname, still_pending in _pending_top_genes(output_csv).items():
            print(f"  🔸 {fname}: Missing top_genes → {still_pending}")
    else:
        print("⚠ CSV not found after run, no updates made.")
