# Preprocess Ramachandran dataset

This notebook preprocesses a Raw.TAR file that consists of healthy and diseased liver samples and converts them into H5AD files. The goal is to use the processed H5AD files to generate embeddings using scGPT.

The data used can be found [here](https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc=GSE136103).

# Install and Imports

In [None]:
!pip install -q scanpy anndata

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/2.1 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.4/2.1 MB[0m [31m10.8 MB/s[0m eta [36m0:00:01[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m2.1/2.1 MB[0m [31m31.6 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m22.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m169.9/169.9 kB[0m [31m8.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.2/58.2 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m276.4/276.4 kB[0m [31m10.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.8/8.8 MB[0m [31m78.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
import scanpy as sc
import anndata as ad
import os
import scipy
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

## Load Data

This part lists the sample folders extracted from the TAR file, skips blood and mouse samples, and groups the remaining liver samples into two dictionaries (healthy and diseased) that map each sample name to its folder path.

In [None]:
# --- Group liver sample folders into healthy / diseased ---

# Path to the main folder containing all sample folders
main_path = '/content/drive/MyDrive/projects/scGPT-MAFLD/reference/GSE136103/'

# List all sample folders. os.listdir() returns entries in arbitrary order, so
# sort them to keep the sample order (and the later merge order) reproducible.
all_samples = sorted(
    f for f in os.listdir(main_path) if os.path.isdir(os.path.join(main_path, f))
)

# Map sample folder name -> full folder path, one dictionary per condition
healthy_samples = {}
diseased_samples = {}
# Liver folders matching neither condition keyword; reported instead of dropped silently
unclassified_samples = []

for sample in all_samples:
    sample_path = os.path.join(main_path, sample)
    sample_lower = sample.lower()

    # Skip blood and mouse samples
    if "blood" in sample_lower or "mouse" in sample_lower:
        continue

    if "healthy" in sample_lower:
        healthy_samples[sample] = sample_path
    elif "cirrhotic" in sample_lower or "fibrotic" in sample_lower:
        diseased_samples[sample] = sample_path
    else:
        unclassified_samples.append(sample)

print("Healthy samples:", healthy_samples)
print("Diseased samples:", diseased_samples)
if unclassified_samples:
    print("Unclassified samples (ignored):", unclassified_samples)


Healthy samples: {'GSM4041150_healthy1_cd45+': '/content/drive/MyDrive/projects/scGPT-MAFLD/reference/GSE136103/GSM4041150_healthy1_cd45+', 'GSM4041151_healthy1_cd45-A': '/content/drive/MyDrive/projects/scGPT-MAFLD/reference/GSE136103/GSM4041151_healthy1_cd45-A', 'GSM4041152_healthy1_cd45-B': '/content/drive/MyDrive/projects/scGPT-MAFLD/reference/GSE136103/GSM4041152_healthy1_cd45-B', 'GSM4041153_healthy2_cd45+': '/content/drive/MyDrive/projects/scGPT-MAFLD/reference/GSE136103/GSM4041153_healthy2_cd45+', 'GSM4041154_healthy2_cd45-': '/content/drive/MyDrive/projects/scGPT-MAFLD/reference/GSE136103/GSM4041154_healthy2_cd45-', 'GSM4041155_healthy3_cd45+': '/content/drive/MyDrive/projects/scGPT-MAFLD/reference/GSE136103/GSM4041155_healthy3_cd45+', 'GSM4041156_healthy3_cd45-A': '/content/drive/MyDrive/projects/scGPT-MAFLD/reference/GSE136103/GSM4041156_healthy3_cd45-A', 'GSM4041157_healthy3_cd45-B': '/content/drive/MyDrive/projects/scGPT-MAFLD/reference/GSE136103/GSM4041157_healthy3_cd45-B'

In [None]:
# Number of healthy samples, then number of diseased samples (one per line)
print(len(healthy_samples), len(diseased_samples), sep="\n")

11
9


## Create Functions

The functions created do the following:
1. Load and merge samples into one AnnData object per group (healthy and cirrhotic)
2. Calculate QC metrics
3. Analyse QC and set thresholds
4. Apply filtering based on the thresholds
5. Process data, which runs steps 2–4 on a merged AnnData object; the filtered result is then saved as an H5AD file in a later cell

In [None]:
def _find_features_file(sample_path):
    """Return the path of the gene/feature table inside a 10X sample folder."""
    # Older Cell Ranger output names the table genes.tsv, newer output names it
    # features.tsv; try the name this pipeline originally expected first.
    for filename in ("genes.tsv.gz", "features.tsv.gz"):
        candidate = os.path.join(sample_path, filename)
        if os.path.exists(candidate):
            return candidate
    raise FileNotFoundError(
        f"No genes.tsv.gz or features.tsv.gz found in {sample_path}"
    )


def load_and_merge_samples(sample_dirs, group_name, unique_obs_names=False):
    """
    Load and merge multiple 10X samples into one AnnData object with Ensembl IDs as index
    and gene symbols as metadata.

    Each sample folder must contain matrix.mtx.gz, barcodes.tsv.gz and either
    genes.tsv.gz (2 columns) or features.tsv.gz (3 columns). In both layouts the
    first column is used as the gene ID and the second as the gene symbol.

    Parameters:
    -----------
    sample_dirs : dict
        Mapping of sample ID -> path of the folder holding that sample's 10X files
    group_name : str
        Label written to the 'group' column of .obs for every cell
    unique_obs_names : bool
        If True, prefix every barcode with '<sample_id>_' so cell names stay unique
        after merging. The default (False) keeps the raw barcodes, which can repeat
        across samples and then trigger anndata's duplicate-obs-names warning.

    Returns:
    --------
    merged_adata : AnnData
        Cells x genes object with 'sample_id', 'batch' and 'group' in .obs

    Raises:
    -------
    ValueError
        If sample_dirs is empty or a gene table has fewer than 2 columns
    FileNotFoundError
        If a sample folder contains neither genes.tsv.gz nor features.tsv.gz
    """
    if not sample_dirs:
        raise ValueError(f"No samples provided for group '{group_name}'")

    # A bare `import scipy` is not guaranteed to load the scipy.io submodule.
    import scipy.io

    adatas = []
    for sample_id, sample_path in sample_dirs.items():
        print(f"Loading sample: {sample_id}")

        # Load raw 10X files; the matrix is transposed so rows line up with the
        # barcodes (obs) and columns with the genes (var).
        matrix = scipy.io.mmread(os.path.join(sample_path, "matrix.mtx.gz")).T.tocsr()
        barcodes = pd.read_csv(os.path.join(sample_path, "barcodes.tsv.gz"),
                               header=None, sep="\t")[0].values
        genes = pd.read_csv(_find_features_file(sample_path),
                            header=None, sep="\t")

        # 2-column and 3-column tables share the same first two columns
        if genes.shape[1] < 2:
            raise ValueError(f"Unexpected genes.tsv.gz format in {sample_id}")
        gene_ids = genes[0].values
        gene_symbols = genes[1].values

        if unique_obs_names:
            barcodes = [f"{sample_id}_{barcode}" for barcode in barcodes]

        # Build AnnData
        adata = sc.AnnData(matrix)
        adata.obs_names = barcodes
        adata.var_names = gene_ids
        adata.var["gene_symbols"] = gene_symbols  # keep only symbols as metadata

        # Metadata
        adata.obs["sample_id"] = sample_id
        adata.obs["batch"] = sample_id

        adatas.append(adata)

    # Merge samples over the union of genes; var columns are kept when they agree
    merged_adata = sc.concat(adatas, join="outer", merge="same")
    merged_adata.obs["group"] = group_name

    print(f"Merged {group_name} dataset: {merged_adata.n_obs} cells, {merged_adata.n_vars} genes")
    print(f"Var columns: {merged_adata.var.columns.tolist()}")
    return merged_adata

In [None]:
def calculate_qc_metrics(adata):
    """
    Flag mitochondrial, ribosomal and hemoglobin genes, then compute scanpy QC metrics.

    The flags are derived from adata.var["gene_symbols"] and stored as the boolean
    columns 'mt', 'ribo' and 'hb' in adata.var. The object is modified in place and
    also returned.

    Parameters:
    -----------
    adata : AnnData
        AnnData object with gene_symbols in adata.var

    Returns:
    --------
    adata : AnnData
        The same AnnData object with QC metrics added
    """
    print("Calculating QC metrics...")

    # Identify gene types using gene symbols
    symbols = adata.var["gene_symbols"]
    gene_flags = {
        "mt": symbols.str.startswith("MT-"),
        "ribo": symbols.str.startswith(("RPS", "RPL")),
        # "HB" followed by any character other than "(", "P" or ")"
        "hb": symbols.str.contains("^HB[^(P)]"),
    }
    for flag_name, flag_values in gene_flags.items():
        adata.var[flag_name] = flag_values

    # Per-cell and per-gene metrics, including pct_counts_<flag> for every flag
    sc.pp.calculate_qc_metrics(
        adata, qc_vars=list(gene_flags), inplace=True, log1p=True
    )

    # Print summary
    print(f"MT genes detected: {adata.var['mt'].sum()}")
    print(f"Ribosomal genes detected: {adata.var['ribo'].sum()}")
    print(f"Hemoglobin genes detected: {adata.var['hb'].sum()}")
    print(f"Mean MT%: {adata.obs['pct_counts_mt'].mean():.2f}%")
    print(f"Mean ribosomal%: {adata.obs['pct_counts_ribo'].mean():.2f}%")

    return adata

In [None]:
def _quantile_to_int(value):
    """Truncate a quantile to int after removing floating-point interpolation noise."""
    # Interpolated quantiles can come back as e.g. 10772.999999999998; a bare int()
    # would turn that into 10772 instead of the intended 10773.
    return int(np.round(value, 6))


def _plot_qc_overview(adata, thresholds):
    """Show a 2x3 figure: four QC histograms with threshold lines and two scatter plots."""
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))

    # (axis, values, [(threshold key, label prefix, label suffix)], xlabel, ylabel, title)
    histograms = [
        (axes[0, 0], adata.obs['n_genes_by_counts'],
         [('min_genes', 'Min', ''), ('max_genes', 'Max', '')],
         'Number of genes', 'Number of cells', 'Genes per cell'),
        (axes[0, 1], adata.obs['total_counts'],
         [('min_counts', 'Min', ''), ('max_counts', 'Max', '')],
         'Total UMI counts', 'Number of cells', 'UMI counts per cell'),
        (axes[0, 2], adata.obs['pct_counts_mt'],
         [('max_mt_pct', 'Max', '%')],
         'Mitochondrial gene %', 'Number of cells', 'MT% per cell'),
        (axes[1, 0], adata.var['n_cells_by_counts'],
         [('min_cells_per_gene', 'Min', '')],
         'Number of cells', 'Number of genes', 'Cells per gene'),
    ]
    for ax, values, threshold_lines, xlabel, ylabel, title in histograms:
        ax.hist(values, bins=50, alpha=0.7)
        for key, prefix, suffix in threshold_lines:
            ax.axvline(thresholds[key], color='red', linestyle='--',
                       label=f'{prefix}: {thresholds[key]}{suffix}')
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend()
    # The cells-per-gene histogram is shown on a log y-axis
    axes[1, 0].set_yscale('log')

    # Scatter plots
    axes[1, 1].scatter(adata.obs['total_counts'], adata.obs['n_genes_by_counts'], alpha=0.5, s=1)
    axes[1, 1].set_xlabel('Total UMI counts')
    axes[1, 1].set_ylabel('Number of genes')
    axes[1, 1].set_title('Total counts vs Genes detected')

    axes[1, 2].scatter(adata.obs['total_counts'], adata.obs['pct_counts_mt'], alpha=0.5, s=1)
    axes[1, 2].set_xlabel('Total UMI counts')
    axes[1, 2].set_ylabel('Mitochondrial gene %')
    axes[1, 2].set_title('Total counts vs MT%')

    plt.tight_layout()
    plt.show()


def analyze_qc_and_get_thresholds(adata, plot=True):
    """
    Analyze QC metrics and suggest filtering thresholds

    Thresholds are derived from the data itself:
    - genes per cell and UMI counts per cell: 5th and 95th percentiles (truncated to int)
    - MT%: the smallest of the 95th percentile, median + 3 * MAD, and a hard cap of 20
    - genes: expressed in at least max(3, 0.1% of the cells) cells

    Parameters:
    -----------
    adata : AnnData
        AnnData object with calculated QC metrics ('n_genes_by_counts', 'total_counts',
        'pct_counts_mt', 'pct_counts_ribo' in .obs and 'n_cells_by_counts' in .var)
    plot : bool
        Whether to create QC plots

    Returns:
    --------
    thresholds : dict
        Dictionary containing recommended filtering thresholds, with keys 'min_genes',
        'max_genes', 'min_counts', 'max_counts', 'max_mt_pct', 'min_cells_per_gene'

    Raises:
    -------
    ValueError
        If adata contains no cells (percentiles are undefined)
    """
    if adata.n_obs == 0:
        raise ValueError("Cannot derive QC thresholds from an AnnData object with 0 cells")

    obs = adata.obs

    print(" ")
    print("=== QC METRICS ANALYSIS ===")
    print(f"Total cells: {adata.n_obs}")
    print(f"Total genes: {adata.n_vars}")

    # Basic statistics
    print(" ")
    print("=== PER-CELL METRICS ===")
    for metric in ('n_genes_by_counts', 'total_counts', 'pct_counts_mt', 'pct_counts_ribo'):
        values = obs[metric]
        print(f"{metric}:")
        print(f"  Mean: {values.mean():.1f}")
        print(f"  Median: {values.median():.1f}")
        print(f"  5th-95th percentile: {values.quantile(0.05):.1f} - {values.quantile(0.95):.1f}")

    # MT threshold: 95th percentile vs. median + 3 * MAD, capped at 20%
    mt_pct = obs['pct_counts_mt']
    mt_median = np.median(mt_pct)
    mt_mad = mt_median + 3 * np.median(np.abs(mt_pct - mt_median))
    mt_threshold = min(mt_pct.quantile(0.95), mt_mad, 20)

    thresholds = {
        'min_genes': _quantile_to_int(obs['n_genes_by_counts'].quantile(0.05)),
        'max_genes': _quantile_to_int(obs['n_genes_by_counts'].quantile(0.95)),
        'min_counts': _quantile_to_int(obs['total_counts'].quantile(0.05)),
        'max_counts': _quantile_to_int(obs['total_counts'].quantile(0.95)),
        'max_mt_pct': round(float(mt_threshold), 1),
        # Gene filtering threshold: at least 3 cells, or 0.1% of all cells if larger
        'min_cells_per_gene': max(3, int(0.001 * adata.n_obs)),
    }

    print(" ")
    print("=== SUGGESTED FILTERING THRESHOLDS ===")
    print(f"Genes per cell: {thresholds['min_genes']} - {thresholds['max_genes']}")
    print(f"UMI counts: {thresholds['min_counts']} - {thresholds['max_counts']}")
    # The filter below keeps cells at or below the cutoff, so report "<="
    print(f"MT percentage: <= {thresholds['max_mt_pct']}%")
    print(f"Genes expressed in >= {thresholds['min_cells_per_gene']} cells")

    # Estimate retention
    cell_filter = (
        (obs['n_genes_by_counts'] >= thresholds['min_genes']) &
        (obs['n_genes_by_counts'] <= thresholds['max_genes']) &
        (obs['total_counts'] >= thresholds['min_counts']) &
        (obs['total_counts'] <= thresholds['max_counts']) &
        (obs['pct_counts_mt'] <= thresholds['max_mt_pct'])
    )
    gene_filter = adata.var['n_cells_by_counts'] >= thresholds['min_cells_per_gene']
    kept_cells = cell_filter.sum()
    kept_genes = gene_filter.sum()
    # Guard the gene percentage against an object without genes
    gene_pct = kept_genes / adata.n_vars * 100 if adata.n_vars else 0.0

    print("Estimated retention:")
    print(f"Cells: {kept_cells} / {adata.n_obs} ({kept_cells/adata.n_obs*100:.1f}%)")
    print(f"Genes: {kept_genes} / {adata.n_vars} ({gene_pct:.1f}%)")

    if plot:
        _plot_qc_overview(adata, thresholds)

    return thresholds

In [None]:
def apply_filtering(adata, thresholds):
    """
    Apply filtering based on provided thresholds

    The input object is left untouched: the gene mask is computed without in-place
    modification and the result is returned as an independent copy (not a view), so
    the function can be re-run on the same input and the result can be modified or
    written to disk without view-related warnings.

    Parameters:
    -----------
    adata : AnnData
        AnnData object to filter; .obs must contain 'n_genes_by_counts',
        'total_counts' and 'pct_counts_mt'
    thresholds : dict
        Dictionary with filtering thresholds from analyze_qc_and_get_thresholds()

    Returns:
    --------
    adata : AnnData
        Filtered AnnData object (a new object; the per-gene cell count used for the
        gene filter is stored in .var['n_cells'])
    """
    print(" ")
    print("=== APPLYING FILTERING ===")
    print(f"Before filtering: {adata.n_obs} cells, {adata.n_vars} genes")

    # Filter genes first (expressed in at least min_cells_per_gene cells).
    # inplace=False returns the mask and the per-gene cell counts instead of
    # subsetting the caller's object.
    min_cells = thresholds['min_cells_per_gene']
    gene_mask, n_cells_per_gene = sc.pp.filter_genes(adata, min_cells=min_cells, inplace=False)
    gene_mask = np.asarray(gene_mask, dtype=bool)
    print(f"After gene filtering (>= {min_cells} cells): {int(gene_mask.sum())} genes")

    # Cell filters, applied cumulatively in this order so the progress report
    # shows how many cells survive each successive step.
    obs = adata.obs
    cell_steps = [
        (f"min genes filtering (>= {thresholds['min_genes']})",
         obs['n_genes_by_counts'] >= thresholds['min_genes']),
        (f"max genes filtering (<= {thresholds['max_genes']})",
         obs['n_genes_by_counts'] <= thresholds['max_genes']),
        (f"min UMI filtering (>= {thresholds['min_counts']})",
         obs['total_counts'] >= thresholds['min_counts']),
        (f"max UMI filtering (<= {thresholds['max_counts']})",
         obs['total_counts'] <= thresholds['max_counts']),
        (f"MT% filtering (<= {thresholds['max_mt_pct']}%)",
         obs['pct_counts_mt'] <= thresholds['max_mt_pct']),
    ]
    keep_cells = np.ones(adata.n_obs, dtype=bool)
    for description, passes in cell_steps:
        # Combine positionally; cell names may be duplicated across samples
        keep_cells &= passes.to_numpy()
        print(f"After {description}: {int(keep_cells.sum())} cells")

    # One subset + copy, so the result owns its data
    filtered = adata[keep_cells, gene_mask].copy()
    # Same column the in-place gene filter would have added
    filtered.var['n_cells'] = np.asarray(n_cells_per_gene)[gene_mask]

    print(f"Final dataset: {filtered.n_obs} cells, {filtered.n_vars} genes")
    return filtered

In [None]:
def process_adata(adata, thresholds=None, analyze_qc=False, return_thresholds=False):
    """
    Process an existing AnnData object: calculate QC metrics, analyze QC, and apply filtering.

    Note that the QC metrics are added to the input object itself, because
    calculate_qc_metrics() annotates adata.obs / adata.var in place.

    Parameters:
    -----------
    adata : AnnData
        Pre-loaded AnnData object
    thresholds : dict, optional
        Pre-defined filtering thresholds. If None, thresholds will be calculated from the data
    analyze_qc : bool
        Whether to draw the QC plots while thresholds are calculated (ignored when
        thresholds are provided)
    return_thresholds : bool
        If True, return (adata, thresholds) so the same thresholds can be applied to
        other samples. The default (False) returns only the AnnData object.

    Returns:
    --------
    adata : AnnData
        Processed and filtered AnnData object
    thresholds : dict
        Filtering thresholds used; only returned when return_thresholds is True
    """
    print(f"\n{'='*50}")
    print(f"Processing AnnData object with {adata.n_obs} cells and {adata.n_vars} genes")
    print(f"{'='*50}")

    # Calculate QC metrics
    adata = calculate_qc_metrics(adata)

    # Analyze QC and get thresholds (or use provided ones)
    if thresholds is None:
        thresholds = analyze_qc_and_get_thresholds(adata, plot=analyze_qc)
    else:
        print("Using provided thresholds:")
        for key, value in thresholds.items():
            print(f"  {key}: {value}")

    # Apply filtering
    adata = apply_filtering(adata, thresholds)

    print("Successfully processed AnnData object")
    if return_thresholds:
        return adata, thresholds
    return adata

In [None]:
# Merge every healthy liver sample into one AnnData object; each cell is tagged
# with its sample_id / batch and with group="healthy" in .obs.
healthy_adata = load_and_merge_samples(healthy_samples, group_name="healthy")

Loading sample: GSM4041150_healthy1_cd45+
Loading sample: GSM4041151_healthy1_cd45-A
Loading sample: GSM4041152_healthy1_cd45-B
Loading sample: GSM4041153_healthy2_cd45+
Loading sample: GSM4041154_healthy2_cd45-
Loading sample: GSM4041155_healthy3_cd45+
Loading sample: GSM4041156_healthy3_cd45-A
Loading sample: GSM4041157_healthy3_cd45-B
Loading sample: GSM4041158_healthy4_cd45+
Loading sample: GSM4041159_healthy4_cd45-
Loading sample: GSM4041160_healthy5_cd45+
Merged healthy dataset: 35685 cells, 33694 genes
Var columns: ['gene_symbols']


  utils.warn_names_duplicates("obs")


In [None]:
# Merge every cirrhotic/fibrotic liver sample into one AnnData object.
# Note: the label stored in .obs["group"] is "diseased", not "cirrhotic".
cirrhotic_adata = load_and_merge_samples(diseased_samples, group_name="diseased")

Loading sample: GSM4041161_cirrhotic1_cd45+
Loading sample: GSM4041162_cirrhotic1_cd45-A
Loading sample: GSM4041163_cirrhotic1_cd45-B
Loading sample: GSM4041164_cirrhotic2_cd45+
Loading sample: GSM4041165_cirrhotic2_cd45-
Loading sample: GSM4041166_cirrhotic3_cd45+
Loading sample: GSM4041167_cirrhotic3_cd45-
Loading sample: GSM4041168_cirrhotic4_cd45+
Loading sample: GSM4041169_cirrhotic5_cd45+
Merged diseased dataset: 26525 cells, 33694 genes
Var columns: ['gene_symbols']


  utils.warn_names_duplicates("obs")


In [None]:
# Inspect the merged healthy object: dimensions plus the available obs/var columns
healthy_adata

AnnData object with n_obs × n_vars = 35685 × 33694
    obs: 'sample_id', 'batch', 'group'
    var: 'gene_symbols'

In [None]:
# Inspect the per-gene table: Ensembl IDs as index, gene symbols as a column
healthy_adata.var

Unnamed: 0,gene_symbols
ENSG00000243485,RP11-34P13.3
ENSG00000237613,FAM138A
ENSG00000186092,OR4F5
ENSG00000238009,RP11-34P13.7
ENSG00000239945,RP11-34P13.8
...,...
ENSG00000277856,AC233755.2
ENSG00000275063,AC233755.1
ENSG00000271254,AC240274.1
ENSG00000277475,AC213203.1


In [None]:
# QC + filtering for the healthy group. No thresholds are passed, so they are
# derived from this group's own QC distribution. QC columns are added to
# healthy_adata itself (calculate_qc_metrics annotates its input in place).
healthy_adata_processed = process_adata(healthy_adata)


Processing AnnData object with 35685 cells and 17934 genes
Calculating QC metrics...
MT genes detected: 13
Ribosomal genes detected: 100
Hemoglobin genes detected: 7
Mean MT%: 4.89%
Mean ribosomal%: 27.21%
 
=== QC METRICS ANALYSIS ===
Total cells: 35685
Total genes: 17934
 
=== PER-CELL METRICS ===
n_genes_by_counts:
  Mean: 1237.9
  Median: 1106.0
  5th-95th percentile: 458.0 - 2673.0
total_counts:
  Mean: 3936.6
  Median: 2992.0
  5th-95th percentile: 866.0 - 10773.0
pct_counts_mt:
  Mean: 4.9
  Median: 3.1
  5th-95th percentile: 1.1 - 8.4
pct_counts_ribo:
  Mean: 27.2
  Median: 27.0
  5th-95th percentile: 11.6 - 43.8
 
=== SUGGESTED FILTERING THRESHOLDS ===
Genes per cell: 458 - 2673
UMI counts: 866 - 10772
MT percentage: < 6.2%
Genes expressed in >= 35 cells
Estimated retention:
Cells: 28306 / 35685 (79.3%)
Genes: 17934 / 17934 (100.0%)
 
=== APPLYING FILTERING ===
Before filtering: 35685 cells, 17934 genes
After gene filtering (>= 35 cells): 17934 genes
After min genes filtering

  utils.warn_names_duplicates("obs")


In [None]:
# Save the filtered healthy dataset as an .h5ad file
healthy_adata_processed.write("/content/drive/MyDrive/projects/scGPT-MAFLD/data/healthy_samples.h5ad")

  df[key] = c
  utils.warn_names_duplicates("obs")
  df[key] = c
  utils.warn_names_duplicates("obs")
  df[key] = c
  utils.warn_names_duplicates("obs")
  df[key] = c
  utils.warn_names_duplicates("obs")


In [None]:
# QC + filtering for the diseased group. Thresholds are derived independently
# from this group's own QC distribution, so they differ from the healthy ones.
cirrhotic_adata_processed = process_adata(cirrhotic_adata)


Processing AnnData object with 26525 cells and 33694 genes
Calculating QC metrics...
MT genes detected: 13
Ribosomal genes detected: 103
Hemoglobin genes detected: 12
Mean MT%: 6.15%
Mean ribosomal%: 27.68%
 
=== QC METRICS ANALYSIS ===
Total cells: 26525
Total genes: 33694
 
=== PER-CELL METRICS ===
n_genes_by_counts:
  Mean: 1611.5
  Median: 1384.0
  5th-95th percentile: 805.2 - 3129.0
total_counts:
  Mean: 5699.9
  Median: 4151.0
  5th-95th percentile: 2216.0 - 14030.8
pct_counts_mt:
  Mean: 6.2
  Median: 3.6
  5th-95th percentile: 1.6 - 15.5
pct_counts_ribo:
  Mean: 27.7
  Median: 27.2
  5th-95th percentile: 11.1 - 45.0
 
=== SUGGESTED FILTERING THRESHOLDS ===
Genes per cell: 805 - 3129
UMI counts: 2216 - 14030
MT percentage: < 7.0%
Genes expressed in >= 26 cells
Estimated retention:
Cells: 19605 / 26525 (73.9%)
Genes: 18916 / 33694 (56.1%)
 
=== APPLYING FILTERING ===
Before filtering: 26525 cells, 33694 genes
After gene filtering (>= 26 cells): 18916 genes
After min genes filter

  utils.warn_names_duplicates("obs")


In [None]:
# Save the filtered cirrhotic dataset as an .h5ad file
cirrhotic_adata_processed.write("/content/drive/MyDrive/projects/scGPT-MAFLD/data/cirrhotic_samples.h5ad")

  df[key] = c
  utils.warn_names_duplicates("obs")
  df[key] = c
  utils.warn_names_duplicates("obs")
  df[key] = c
  utils.warn_names_duplicates("obs")
  df[key] = c
  utils.warn_names_duplicates("obs")
