In [None]:
# Import necessary libraries
import os
import gzip
import shutil
import requests
import tarfile
from pathlib import Path
from tqdm.notebook import tqdm  # Use notebook-friendly progress bar
import numpy as np
import pandas as pd
import scanpy as sc
import anndata as ad
from scipy import sparse, io

def download_file(url, destination, timeout=60, block_size=1024 * 1024):
    """Download ``url`` to ``destination`` unless that path already exists.

    The payload is streamed into a sibling ``<destination>.part`` file that is
    renamed into place only after the transfer finishes, so a failed or
    interrupted download never leaves a truncated file behind that a later
    call would mistake for a completed one.

    Parameters:
    - url (str): Address to fetch.
    - destination (str | os.PathLike): Target file path.
    - timeout (float): Seconds passed to ``requests.get`` as its timeout.
    - block_size (int): Bytes requested per streamed chunk.

    Raises:
    - requests.HTTPError: if the server answers with an error status.
    """
    if os.path.exists(destination):
        print(f"File already exists at {destination}")
        return

    print(f"Downloading {url} to {destination}")
    partial_path = f"{os.fspath(destination)}.part"
    with requests.get(url, stream=True, timeout=timeout) as response:
        # Without this check an HTML error page would be saved as the data file.
        response.raise_for_status()
        total_size = int(response.headers.get('content-length', 0))
        try:
            with open(partial_path, 'wb') as file, tqdm(
                    desc=str(destination),
                    total=total_size,
                    unit='iB',
                    unit_scale=True,
                    unit_divisor=1024,
                ) as bar:
                for data in response.iter_content(block_size):
                    bar.update(file.write(data))
        except BaseException:
            # Drop the incomplete file so the next run retries the download.
            if os.path.exists(partial_path):
                os.remove(partial_path)
            raise
    # Atomic on the same filesystem: destination appears only when complete.
    os.replace(partial_path, destination)

def extract_tar(tar_path, extract_dir):
    """Extract a tar archive into ``extract_dir`` without writing outside it.

    ``extract_dir`` (and any missing parents) is created first. Members that
    would escape ``extract_dir`` (e.g. ``../x`` names or links pointing
    outside) raise a ``tarfile.TarError`` subclass instead of being written.
    When the running Python provides tarfile extraction filters (3.12+ and
    security backports of 3.8-3.11) the stdlib ``"data"`` filter enforces
    this; otherwise an equivalent path check runs before anything is written.

    Parameters:
    - tar_path (str | os.PathLike): Archive to read.
    - extract_dir (str | os.PathLike): Destination directory.
    """
    os.makedirs(extract_dir, exist_ok=True)

    print(f"Extracting {tar_path} to {extract_dir}")
    with tarfile.open(tar_path) as tar:
        if hasattr(tarfile, "data_filter"):
            tar.extractall(path=extract_dir, filter="data")
            return

        # Fallback for Pythons without extraction filters: validate every
        # member name and link target before extracting anything.
        root = os.path.realpath(extract_dir)
        for member in tar.getmembers():
            targets = [member.name]
            if member.issym():
                # Symlink targets are relative to the link's own directory.
                targets.append(os.path.join(os.path.dirname(member.name), member.linkname))
            elif member.islnk():
                # Hard-link targets are relative to the archive root.
                targets.append(member.linkname)
            for target in targets:
                resolved = os.path.realpath(os.path.join(root, target))
                if os.path.commonpath([root, resolved]) != root:
                    raise tarfile.TarError(
                        f"Refusing to extract {member.name!r}: path escapes {extract_dir}"
                    )
        tar.extractall(path=extract_dir)

def _find_10x_files(data_dir):
    """Return one ``(matrix, features, barcodes)`` path triple sharing a prefix.

    Files are matched on the text preceding ``matrix.mtx`` (e.g. ``GSM1_A_``)
    so that, when a directory holds several samples, the three files always
    come from the same sample instead of being picked independently.
    Candidates are scanned in sorted filename order. ``genes.tsv`` (older
    Cell Ranger naming) is accepted when no ``features.tsv`` exists for the
    prefix.

    Raises:
    - FileNotFoundError: if no complete triple exists (or ``data_dir`` is missing).
    """
    names = sorted(entry.name for entry in data_dir.iterdir() if entry.is_file())

    def starting_with(head):
        return [name for name in names if name.startswith(head)]

    for mtx_name in (name for name in names if "matrix.mtx" in name):
        prefix = mtx_name.split("matrix.mtx", 1)[0]
        features = starting_with(prefix + "features.tsv") or starting_with(prefix + "genes.tsv")
        barcodes = starting_with(prefix + "barcodes.tsv")
        if features and barcodes:
            return data_dir / mtx_name, data_dir / features[0], data_dir / barcodes[0]

    raise FileNotFoundError("Could not find required 10X files in the data directory")


def load_10x_mtx(data_dir, sample_size=None, random_state=42):
    """Load a 10X Matrix Market triple into a cells x genes AnnData of raw counts.

    Parameters:
    - data_dir (str | Path): Directory containing ``*matrix.mtx[.gz]``,
      ``*features.tsv[.gz]`` (or ``*genes.tsv[.gz]``) and ``*barcodes.tsv[.gz]``
      files. If several samples are present, only the first complete triple
      in sorted filename order is loaded.
    - sample_size (int | None): If given, > 0 and smaller than the number of
      cells, load a random subset of that many cells.
    - random_state (int): Seed for the subsampling. The default 42 reproduces
      the draw of ``np.random.seed(42)`` without touching NumPy's global RNG.

    Returns:
    - anndata.AnnData: ``X`` is a CSR matrix (cells x genes); ``var`` holds
      ``gene_ids`` and ``feature_types``; ``var_names`` are made unique.

    Raises:
    - FileNotFoundError: if no complete file triple is found.
    - ValueError: if the matrix shape disagrees with the feature/barcode counts.
    """
    data_dir = Path(data_dir)
    print(f"Loading 10X data from {data_dir}")

    mtx_file, features_file, barcodes_file = _find_10x_files(data_dir)
    print(f"Using files: \n{mtx_file}\n{features_file}\n{barcodes_file}")

    # pandas infers gzip compression from a ".gz" suffix, so files are read in
    # place instead of being decompressed to disk. dtype=str together with
    # keep_default_na=False keeps identifiers such as "NA" from becoming NaN.
    read_tsv_kwargs = dict(sep='\t', header=None, dtype=str, keep_default_na=False)

    print(f"Reading features from {features_file}")
    features = pd.read_csv(features_file, **read_tsv_kwargs)

    print(f"Reading barcodes from {barcodes_file}")
    barcodes = pd.read_csv(barcodes_file, **read_tsv_kwargs)

    # mmread also accepts ".gz" files directly. The matrix is genes x cells;
    # CSC makes the per-cell (column) subsetting below cheap.
    print(f"Reading matrix from {mtx_file}")
    matrix = sparse.csc_matrix(io.mmread(str(mtx_file)))
    if matrix.shape != (len(features), len(barcodes)):
        raise ValueError(
            f"Matrix shape {matrix.shape} does not match "
            f"{len(features)} features x {len(barcodes)} barcodes"
        )

    if sample_size is not None and 0 < sample_size < len(barcodes):
        print(f"Sampling {sample_size} cells from {len(barcodes)} total cells")
        rng = np.random.RandomState(random_state)
        sampled_indices = rng.choice(len(barcodes), sample_size, replace=False)
        barcodes = barcodes.iloc[sampled_indices]
        print("Filtering matrix to include only sampled cells")
        matrix = matrix[:, sampled_indices]

    print("Creating AnnData object")
    adata = ad.AnnData(X=sparse.csr_matrix(matrix.T))  # Transpose to cells x genes

    # Column 2 (gene symbols) becomes var_names when present; column 1
    # (typically Ensembl IDs) is kept in var['gene_ids'].
    adata.var_names = features[1].values if features.shape[1] > 1 else features[0].values
    adata.var['gene_ids'] = features[0].values
    adata.var['feature_types'] = features[2].values if features.shape[1] > 2 else 'Gene Expression'

    adata.obs_names = barcodes[0].values

    # Gene symbols are not guaranteed unique; AnnData needs unique var_names.
    adata.var_names_make_unique()

    print(f"Loaded data with {adata.n_obs} cells and {adata.n_vars} genes")
    return adata

def parse_metadata(series_matrix_file, sample_index=0):
    """Parse per-sample ``!Sample_*`` fields from a gzipped GEO series matrix.

    Each ``!Sample_<key>`` line carries one tab-separated, double-quoted value
    per sample. The value in column ``sample_index`` (0 = first sample) is
    stored under ``<key>``. Keys that appear on several lines (GEO repeats,
    e.g., ``!Sample_characteristics_ch1``) have their values joined with
    ``"; "`` rather than each line overwriting the previous one. Lines with
    no value for the requested sample are skipped.

    Parameters:
    - series_matrix_file (str | os.PathLike): Path to ``*_series_matrix.txt.gz``.
    - sample_index (int): Which sample column to read.

    Returns:
    - dict[str, str]: key (without the ``!Sample_`` prefix) -> value.
    """
    print(f"Parsing metadata from {series_matrix_file}")
    prefix = '!Sample_'
    column = sample_index + 1  # column 0 holds the "!Sample_<key>" label
    collected = {}

    with gzip.open(series_matrix_file, 'rt', encoding='utf-8') as f:
        for line in f:
            if not line.startswith(prefix):
                continue
            fields = line.rstrip('\r\n').split('\t')
            if len(fields) <= column:
                continue
            key = fields[0][len(prefix):]
            collected.setdefault(key, []).append(fields[column].strip('"'))

    return {key: '; '.join(values) for key, values in collected.items()}

def infer_perturbation_info(adata, n_top_genes=2000, n_pcs=30, n_neighbors=10, resolution=0.5):
    """Cluster cells and attach heuristic perturbation labels to ``adata``.

    ``adata.X`` is copied into ``adata.layers['raw']`` *before* any
    transformation. Normalisation, log1p, HVG selection, scaling, PCA,
    neighbours and Leiden clustering all run on a separate working object,
    so ``adata.X`` is returned unchanged. (NOTE(review): callers in this file
    pass raw counts in ``X``; confirm for other callers.)

    WARNING: the perturbation labels are NOT derived from guide or genotype
    data. Leiden clusters are ranked by size: the largest is labelled ``'WT'``,
    the next four ``'HHEX_KO'``, ``'FOXA1_KO'``, ``'OTUD5_KO'``,
    ``'CCDC6_KO'``, and any further clusters ``'WT'``. Treat them as
    placeholders, not ground truth.

    Parameters:
    - adata (anndata.AnnData): Cells x genes object; modified in place.
    - n_top_genes (int): Number of highly variable genes used for clustering.
    - n_pcs (int): Requested PCA components (capped for small inputs).
    - n_neighbors (int): Neighbourhood size for the kNN graph.
    - resolution (float): Leiden resolution.

    Returns:
    - anndata.AnnData: the same object, with ``layers['raw']``,
      ``var['highly_variable']`` (plus scanpy's other HVG statistics),
      ``obs['leiden']``, ``obs['perturbation_name']``, ``obs['condition']``
      ('Control' for WT, otherwise 'Test') and ``obs['crispr_type']`` added.
    """
    import warnings

    print("Inferring perturbation information")

    # Keep an untouched copy of the input matrix before anything transforms it.
    adata.layers['raw'] = adata.X.copy()

    # Work on a lightweight copy (X only) so adata.X is never normalised in place.
    work = ad.AnnData(
        X=adata.X.copy(),
        obs=pd.DataFrame(index=adata.obs_names.copy()),
        var=pd.DataFrame(index=adata.var_names.copy()),
    )
    sc.pp.normalize_total(work, target_sum=1e4)
    sc.pp.log1p(work)
    sc.pp.highly_variable_genes(work, min_mean=0.0125, max_mean=3, min_disp=0.5, n_top_genes=n_top_genes)
    for column in ('highly_variable', 'means', 'dispersions', 'dispersions_norm'):
        if column in work.var.columns:
            adata.var[column] = work.var[column].to_numpy()

    # Materialise the HVG subset so scaling does not operate on a view.
    hvg = work[:, work.var['highly_variable'].to_numpy()].copy()
    sc.pp.scale(hvg, max_value=10)

    # ARPACK requires n_comps < min(n_obs, n_vars); cap it for small inputs.
    n_comps = min(n_pcs, min(hvg.shape) - 1)
    sc.tl.pca(hvg, svd_solver='arpack', n_comps=n_comps)
    sc.pp.neighbors(hvg, n_neighbors=n_neighbors, n_pcs=n_comps)
    sc.tl.leiden(hvg, resolution=resolution, flavor="igraph", n_iterations=2, directed=False)

    # Same cell order as adata, so positional transfer is safe.
    adata.obs['leiden'] = hvg.obs['leiden'].values

    warnings.warn(
        "perturbation_name is assigned by Leiden cluster size rank, not from "
        "guide/genotype data; treat it as a placeholder annotation.",
        stacklevel=2,
    )
    cluster_counts = adata.obs['leiden'].value_counts().sort_values(ascending=False)
    perturbation_types = ['WT', 'HHEX_KO', 'FOXA1_KO', 'OTUD5_KO', 'CCDC6_KO']
    perturbation_mapping = {
        cluster: perturbation_types[rank] if rank < len(perturbation_types) else 'WT'
        for rank, cluster in enumerate(cluster_counts.index)
    }
    adata.obs['perturbation_name'] = adata.obs['leiden'].map(perturbation_mapping)

    # Every non-WT label is treated as a test condition.
    adata.obs['condition'] = 'Control'
    adata.obs.loc[adata.obs['perturbation_name'] != 'WT', 'condition'] = 'Test'

    adata.obs['crispr_type'] = 'CRISPR KO'

    return adata

def harmonize_dataset(data_dir, sample_size=10000):
    """Download, load, annotate and save GSE247598 as an ``.h5ad`` file.

    Steps:
    1. Fetch the GEO RAW tar and the series matrix into ``data_dir``
       (skipped when the files already exist).
    2. Unpack the tar into ``data_dir/raw_data``.
    3. Load the 10X counts.
    4. Attach the fixed study annotations plus ``!Sample_*`` metadata from
       the series matrix.
    5. Run ``infer_perturbation_info``.
    6. Put raw counts back into ``X`` and write ``GSE247598_harmonized.h5ad``.

    Parameters:
    - data_dir (str | Path): Working directory; created if missing.
    - sample_size (int | None): Cells to subsample (None loads all cells).

    Returns:
    - Path: location of the written ``.h5ad`` file.
    """
    data_dir = Path(data_dir)
    # Downloads below write straight into data_dir, so it must exist.
    data_dir.mkdir(parents=True, exist_ok=True)

    geo_accession = "GSE247598"
    geo_base_url = f"https://ftp.ncbi.nlm.nih.gov/geo/series/GSE247nnn/{geo_accession}"
    tar_file = data_dir / f"{geo_accession}_RAW.tar"
    series_matrix_file = data_dir / f"{geo_accession}_series_matrix.txt.gz"
    extract_dir = data_dir / "raw_data"
    output_file = data_dir / f"{geo_accession}_harmonized.h5ad"

    if not tar_file.exists():
        download_file(f"{geo_base_url}/suppl/{geo_accession}_RAW.tar", tar_file)

    if not series_matrix_file.exists():
        download_file(f"{geo_base_url}/matrix/{geo_accession}_series_matrix.txt.gz", series_matrix_file)

    if not extract_dir.exists() or not any(extract_dir.iterdir()):
        extract_tar(tar_file, extract_dir)

    # The archive is unpacked into extract_dir, so load from there; fall back
    # to data_dir for count files that were placed there by hand.
    source_dir = extract_dir if any(extract_dir.glob("*matrix.mtx*")) else data_dir
    adata = load_10x_mtx(source_dir, sample_size=sample_size)

    # Keep raw counts so X can be restored after downstream processing.
    adata.layers['raw_counts'] = adata.X.copy()

    metadata = parse_metadata(series_matrix_file)

    print("Adding standard metadata fields")
    adata.obs['organism'] = 'Homo sapiens'
    adata.obs['cell_type'] = 'Pancreatic cells'
    adata.obs['cancer_type'] = 'Non-Cancer'

    adata = infer_perturbation_info(adata)

    adata.uns['geo_accession'] = geo_accession
    adata.uns['title'] = "Pancreatic Differentiation Clones and Pooled Single-cell RNA-seq"
    adata.uns['summary'] = "scRNA-seq identifies novel cell-type dependent gene functions in regulating pancreatic cell differentiation."
    adata.uns['perturbation_method'] = "CRISPR KO"
    adata.uns['cell_line'] = "H1, HUES8"

    # Series-matrix fields never override the curated values set above.
    for key, value in metadata.items():
        if key not in adata.uns:
            adata.uns[key] = value

    required_fields = ['organism', 'cell_type', 'crispr_type', 'cancer_type', 'condition', 'perturbation_name']
    for field in required_fields:
        if field not in adata.obs.columns:
            print(f"Warning: Required field '{field}' is missing. Adding placeholder values.")
            adata.obs[field] = 'Unknown'

    # The saved file stores raw counts in X (no normalization or log transform).
    adata.X = adata.layers['raw_counts'].copy()

    print(f"Saving harmonized dataset to {output_file}")
    adata.write(output_file)

    print(f"Harmonization complete. Dataset contains {adata.n_obs} cells and {adata.n_vars} genes.")
    print(f"Perturbation types: {adata.obs['perturbation_name'].unique()}")

    return output_file

def run_harmonization(data_dir=None, sample_size=10000):
    """
    Run the harmonization process.

    Parameters:
    - data_dir (str | None): Directory to store and process the data. When
      omitted, the current working directory at call time is used (the old
      ``os.getcwd()`` default was frozen when the function was defined).
    - sample_size (int | None): Number of cells to sample. Use 0 or None for all cells.

    Returns:
    - Path: location of the harmonized ``.h5ad`` file.
    """
    if data_dir is None:
        data_dir = os.getcwd()
    if sample_size == 0:
        sample_size = None
    output_file = harmonize_dataset(data_dir, sample_size=sample_size)
    print(f"Harmonization complete. Output file: {output_file}")
    return output_file

# Example usage. Point DATA_DIR at the directory that should hold the GEO
# downloads and the harmonized output, and set SAMPLE_SIZE to the number of
# cells to subsample (0 keeps every cell).
DATA_DIR = "/content/raw_data"
SAMPLE_SIZE = 0
run_harmonization(data_dir=DATA_DIR, sample_size=SAMPLE_SIZE)


In [None]:
import scanpy as sc
import numpy as np

# Load the harmonized h5ad file
adata = sc.read("/content/raw_data/GSE247598_harmonized.h5ad")
print(f"Total cells before filtering: {adata.n_obs}")

# Define QC thresholds based on the paper
min_genes = 200
max_genes = 7500
max_mt_pct = 15.0

# Flag mitochondrial genes (names starting with 'MT-' or 'mt-') on the full
# object and compute the per-cell QC metrics once. Per-cell metrics depend only
# on that cell's own counts, so they are identical before and after filtering,
# and this avoids mutating an AnnData view.
adata.var['mt'] = adata.var_names.str.startswith(('MT-', 'mt-'))
sc.pp.calculate_qc_metrics(adata, qc_vars=['mt'], inplace=True)

n_genes = adata.obs['n_genes_by_counts']

# Build one boolean mask step by step so each filter's effect can be reported.
keep = (n_genes >= min_genes).to_numpy()
print(f"Cells after min gene filter (≥ {min_genes} genes): {int(keep.sum())}")

keep &= (n_genes <= max_genes).to_numpy()
print(f"Cells after max gene filter (≤ {max_genes} genes): {int(keep.sum())}")

keep &= (adata.obs['pct_counts_mt'] <= max_mt_pct).to_numpy()
print(f"Cells after mitochondrial filter (≤ {max_mt_pct}% MT reads): {int(keep.sum())}")

# Filter multiplets if guide information is available
if 'guide_count' in adata.obs.columns:
    keep &= (adata.obs['guide_count'] == 1).to_numpy()
    print(f"Cells after filtering out multiplets: {int(keep.sum())}")

# Materialise the subset (rather than keeping a view) and refresh the
# gene-level QC metrics so they describe only the retained cells.
adata_filtered = adata[keep].copy()
if adata_filtered.n_obs:
    sc.pp.calculate_qc_metrics(adata_filtered, qc_vars=['mt'], inplace=True)

print(f"Final number of cells after QC: {adata_filtered.n_obs}")
print(f"Total cells removed: {adata.n_obs - adata_filtered.n_obs}")
if adata.n_obs:
    print(f"Percentage of cells kept: {adata_filtered.n_obs/adata.n_obs*100:.2f}%")
else:
    print("Percentage of cells kept: n/a (input contained no cells)")