In [1]:

import scanpy as sc
import anndata
import numpy as np
import pandas as pd
from sklearn.metrics import silhouette_score, adjusted_rand_score
from sklearn.cluster import KMeans
from scipy.stats import zscore
import matplotlib.pyplot as plt
from scipy.sparse import csr_matrix
from io import BytesIO
import base64
import warnings
from typing import List, Literal, Optional
from anndata import AnnData
from insitupy.utils.preprocessing import sctransform_anndata
from pathlib import Path
from insitupy.datasets.download import download_url
import shutil
import os
from insitupy import read_xenium
import scanpy as sc
# Suppress warnings for cleaner output
warnings.filterwarnings('ignore')

In [2]:
out_dir = Path("demo_dataset") # output directory
data_dir = out_dir / "output-XETG00000__0001879__Replicate 1" # directory of xenium data
image_dir = out_dir / "unregistered_images" # directory of images

In [3]:
out_dir = Path("demo_dataset") # output directory
data_dir = out_dir / "output-XETG00000__0001879__Replicate 1" # directory of xenium data
image_dir = out_dir / "unregistered_images" # directory of images

In [4]:
xd = read_xenium(data_dir)

In [13]:
def compare_transformations_anndata(
    adata: AnnData,
    transformation_methods: List[Literal["log1p", "sqrt_1", "sqrt_2", "pearson_residuals", "sctransform"]],
    verbose: bool = True,
    output_path: str = "normalization_results.html",
    true_labels: Optional[pd.Series] = None
) -> pd.DataFrame:
    """
    Normalize and transform the data in an AnnData object based on specified methods,
    and then compare the transformed results using various metrics.

    Args:
        adata (AnnData): The AnnData object to be normalized and transformed.
        transformation_methods (List[str]): List of transformation methods to apply.
            Options are ["log1p", "sqrt_1", "sqrt_2", "pearson_residuals", "sctransform"].
        verbose (bool, optional): If True, prints progress messages. Default is True.
        output_path (str, optional): The path where the HTML report will be saved.
            Default is 'normalization_results.html'.
        true_labels (Optional[pd.Series], optional): True labels for the cells,
            if available, for computing ARI.

    Returns:
        pd.DataFrame: A DataFrame with comparison metrics for each transformation method.
    """

    # Import necessary libraries
    import scanpy as sc
    import numpy as np
    import pandas as pd
    from sklearn.metrics import silhouette_score, adjusted_rand_score
    from sklearn.cluster import KMeans
    from scipy.stats import zscore
    import matplotlib.pyplot as plt
    from scipy.sparse import csr_matrix
    from io import BytesIO
    import base64
    import warnings

    # Suppress warnings for cleaner output
    warnings.filterwarnings('ignore')

    # Step 1: Normalize and transform the data using the specified methods
    if verbose:
        print("Storing raw counts in adata.layers['counts']...")

    # Store raw counts for comparison
    adata.layers['counts'] = adata.X.copy()

    # Normalize total counts
    sc.pp.normalize_total(adata, target_sum=1e4)
    adata.layers['norm_counts'] = adata.X.copy()

    # Dictionary to store different transformations
    transformed_data = {}

    for method in transformation_methods:
        if verbose:
            print(f"Applying transformation: {method}")

        # Copy the normalized AnnData object for each transformation
        adata_copy = adata.copy()

        # Apply the selected transformation method
        if method == "log1p":
            sc.pp.log1p(adata_copy)

        elif method == "sqrt_1":
            X = adata_copy.X.toarray()
            adata_copy.X = csr_matrix(np.sqrt(X + 1))

        elif method == "sqrt_2":
            X = adata_copy.X.toarray()
            adata_copy.X = csr_matrix(np.sqrt(X))

        elif method == "pearson_residuals":
            # Applying the Pearson residuals transformation
            sc.experimental.pp.normalize_pearson_residuals(adata_copy, layer="counts", inplace=True)

        elif method == "sctransform":
            # Applying SCTransform using your custom function
            adata_copy = sctransform_anndata(adata_copy)

        else:
            raise ValueError(f'`transformation_method` {method} is not valid.')

        # Store the transformed AnnData object in the results dictionary
        transformed_data[method] = adata_copy

    # Step 2: Compare the transformations and generate the plots
    results = {}
    plots = []


    for method, transformed_adata in transformed_data.items():
        if verbose:
            print(f"Processing {method}...")

        # Compute PCA for dimensionality reduction
        sc.pp.pca(transformed_adata, n_comps=10)
        X_pca = transformed_adata.obsm['X_pca']

        # Clustering with KMeans
        kmeans = KMeans(n_clusters=5, random_state=0).fit(X_pca)
        labels = kmeans.labels_

        # Compute Silhouette Score
        sil_score = silhouette_score(X_pca, labels)

        # Mean-Variance Relationship
        X_array = transformed_adata.X.toarray()
        mean_expression = np.mean(X_array, axis=0)
        variance_expression = np.var(X_array, axis=0)
        variance_stabilization = np.corrcoef(mean_expression, variance_expression)[0, 1]

        # Z-Score Mean
        z_scores = zscore(X_array, axis=0, ddof=1)
        z_score_mean = np.nanmean(z_scores)

        # Coefficient of Variation (CV)
        cv = np.std(X_array) / np.mean(X_array)

        # Adjusted Rand Index (ARI)
        if true_labels is not None:
            ari = adjusted_rand_score(true_labels, labels)
        else:
            ari = np.nan

        # Store results for comparison
        results[method] = {
            "Silhouette Score": sil_score,
            "Variance Stabilization": variance_stabilization,
            "Z-Score Mean": z_score_mean,
            "Coefficient of Variation": cv,
            "Adjusted Rand Index": ari
        }

        # Generate Mean-Variance Plot
        plt.figure(figsize=(6, 4))
        plt.scatter(mean_expression, variance_expression, alpha=0.5, s=10)
        plt.xlabel('Mean Expression')
        plt.ylabel('Variance')
        plt.title(f'Mean-Variance Plot - {method}')
        plt.tight_layout()
        buf = BytesIO()
        plt.savefig(buf, format='png')
        plt.close()
        buf.seek(0)
        image_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
        plots.append(f'<h3>{method} - Mean-Variance Plot</h3><img src="data:image/png;base64,{image_base64}" alt="{method} Mean-Variance Plot"/>')

    # Convert results dictionary to a DataFrame for easier comparison
    results_df = pd.DataFrame(results).T  # Transpose to have methods as rows

    # Highlight the best methods and create HTML table
    def highlight_best_method(results_df):
        # Copy the DataFrame to avoid modifying the original
        highlighted_df = results_df.copy()

        # For metrics where higher is better
        metrics_to_maximize = ['Silhouette Score', 'Adjusted Rand Index']

        # For metrics where lower is better
        metrics_to_minimize = ['Variance Stabilization', 'Z-Score Mean', 'Coefficient of Variation']

        # Highlight the best values
        for metric in metrics_to_maximize:
            if metric in highlighted_df.columns:
                best_value_index = highlighted_df[metric].idxmax()
                highlighted_df.loc[best_value_index, metric] = (
                    f'<div style="background-color:lightgreen">{results_df.loc[best_value_index, metric]}</div>'
                )

        for metric in metrics_to_minimize:
            if metric in highlighted_df.columns:
                best_value_index = highlighted_df[metric].abs().idxmin()
                highlighted_df.loc[best_value_index, metric] = (
                    f'<div style="background-color:lightgreen">{results_df.loc[best_value_index, metric]}</div>'
                )

        # Convert the DataFrame to HTML with escape=False to allow HTML tags
        return highlighted_df.to_html(escape=False)

    results_html = highlight_best_method(results_df)

    # Generate the final HTML report
    full_html = f"""
    <html>
    <head>
        <title>Transformation Results</title>
    </head>
    <body>
        <h1>Transformation Comparison Results</h1>
        <h2>Summary Table</h2>
        {results_html}
        <h2>Transformation Method Plots</h2>
        {"<br>".join(plots)}
    </body>
    </html>
    """

    # Save the HTML file to the specified output path
    with open(output_path, "w") as file:
        file.write(full_html)

    if verbose:
        print(f"HTML report created and saved as '{output_path}'")
    return results_df

In [14]:
xd.load_cells()
adata = xd.cells.matrix

Loading cells...


In [15]:
compare_transformations_anndata(adata, transformation_methods=["log1p", "sqrt_1", "sqrt_2", "pearson_residuals", "sctransform"])

Storing raw counts in adata.layers['counts']...
Applying transformation: log1p
Applying transformation: sqrt_1
Applying transformation: sqrt_2
Applying transformation: pearson_residuals
Applying transformation: sctransform
Starting SCTransform...
AnnData object saved temporarily at: C:\Users\Aitana\AppData\Local\Temp\tmp2cpcoxh9.h5ad


R[write to console]: Running SCTransform on assay: RNA

R[write to console]: vst.flavor='v2' set. Using model with fixed slope and excluding poisson genes.

R[write to console]: Calculating cell attributes from input UMI matrix: log_umi

R[write to console]: Variance stabilizing transformation of count matrix of size 313 by 166363

R[write to console]: Model formula is y ~ log_umi

R[write to console]: Get Negative Binomial regression parameters per gene

R[write to console]: Using 312 genes, 5000 cells

R[write to console]: Second step: Get residuals using fitted parameters for 313 genes

R[write to console]: Computing corrected count matrix for 313 genes

R[write to console]: Calculating gene attributes

R[write to console]: Wall clock passed: Time difference of 18.82202 secs

R[write to console]: Determine variable features

R[write to console]: Centering data matrix

  |                                                                            
  |                                 

SCTransform applied to Seurat object.
Converted Seurat object to SingleCellExperiment.
SCTransform transformation completed and returned as AnnData.
Processing log1p...
Processing sqrt_1...
Processing sqrt_2...
Processing pearson_residuals...
Processing sctransform...


KeyError: nan