# In [None]:
# %% [code]
import os
import sys
import pandas as pd
import numpy as np
import scanpy as sc
import scipy.io
import scipy.sparse
import anndata
import gzip
import urllib.request
import warnings
import gc
from pathlib import Path

# Silence all warnings process-wide for the rest of the run.
warnings.filterwarnings('ignore')

# Base URL of the GSE283614 supplementary-file directory on the NCBI GEO FTP site.
GEO_URL = (
    "https://ftp.ncbi.nlm.nih.gov/geo/series/"
    "GSE283nnn/GSE283614/suppl"
)
# Each prefix is paired with every suffix to build the expected file names.
FILE_PREFIXES = [
    "GSE283614_ESC_perturb-seq",
    "GSE283614_DE_perturb-seq",
]
FILE_SUFFIXES = [
    "_barcodes.tsv.gz",
    "_features.tsv.gz",
    "_matrix.mtx.gz",
]

def download_files(data_dir):
    """
    Download the GSE283614 supplementary files from GEO if they are missing.

    Every ``FILE_PREFIXES`` x ``FILE_SUFFIXES`` combination is fetched from
    ``GEO_URL`` into ``data_dir``; files that already exist are left untouched.

    Each download is written to a ``<name>.part`` file and only renamed to its
    final name once the transfer has completed. An interrupted or failed
    download therefore never leaves a truncated file behind that the
    existence check below would mistake for a finished one on the next run.

    Download errors are printed and skipped (best effort), so later steps that
    need a missing file will fail when they try to open it.

    Args:
        data_dir: Directory to store the files in; created if it does not exist.
    """
    os.makedirs(data_dir, exist_ok=True)
    print(f"Checking for required files in {data_dir}...")
    for prefix in FILE_PREFIXES:
        for suffix in FILE_SUFFIXES:
            filename = f"{prefix}{suffix}"
            filepath = os.path.join(data_dir, filename)
            if os.path.exists(filepath):
                print(f"File {filename} already exists")
                continue

            print(f"Downloading {filename}...")
            url = f"{GEO_URL}/{filename}"
            partial_path = f"{filepath}.part"
            try:
                urllib.request.urlretrieve(url, partial_path)
                # Atomic rename: the final name only ever refers to a complete file.
                os.replace(partial_path, filepath)
                print(f"Downloaded {filename}")
            except Exception as e:
                # Do not leave a half-written file around.
                if os.path.exists(partial_path):
                    os.remove(partial_path)
                print(f"Error downloading {filename}: {e}")

def read_mtx_files(data_dir, prefix):
    """
    Read a 10X-style mtx/features/barcodes triplet into an AnnData object.

    Expects ``<prefix>_matrix.mtx.gz``, ``<prefix>_features.tsv.gz`` and
    ``<prefix>_barcodes.tsv.gz`` inside ``data_dir``.

    Args:
        data_dir: Directory containing the three files.
        prefix: Common file-name prefix, e.g. ``"GSE283614_ESC_perturb-seq"``.

    Returns:
        AnnData with one observation per barcode and one variable per feature.
        ``var`` has the columns ``gene_id``, ``gene_name`` and ``feature_type``
        (the constant ``'Gene Expression'`` when the features file has only two
        columns). ``var_names`` are the gene names, made unique with
        ``var_names_make_unique`` because symbols are not guaranteed unique.
    """
    print(f"Reading {prefix} files...")
    matrix_file = os.path.join(data_dir, f"{prefix}_matrix.mtx.gz")
    features_file = os.path.join(data_dir, f"{prefix}_features.tsv.gz")
    barcodes_file = os.path.join(data_dir, f"{prefix}_barcodes.tsv.gz")

    # Read identifiers as plain strings. With pandas' default NA handling,
    # symbols such as "NA" or "NULL" would silently become NaN.
    print(f"Reading features from {features_file}...")
    with gzip.open(features_file, 'rt') as f:
        feature_df = pd.read_csv(f, sep='\t', header=None, dtype=str, keep_default_na=False)

    print(f"Reading barcodes from {barcodes_file}...")
    with gzip.open(barcodes_file, 'rt') as f:
        barcodes = pd.read_csv(f, sep='\t', header=None, dtype=str, keep_default_na=False)[0].values

    var = pd.DataFrame({
        'gene_id': feature_df[0].values,
        'gene_name': feature_df[1].values,
        'feature_type': feature_df[2].values if feature_df.shape[1] > 2 else 'Gene Expression'
    })
    var.index = var['gene_name'].values

    # The matrix is read as features x barcodes; transpose so rows are cells,
    # matching the barcode obs_names and feature-level var built above.
    print(f"Reading matrix from {matrix_file}...")
    matrix = scipy.io.mmread(matrix_file).T.tocsr()

    print(f"Creating AnnData object with shape {matrix.shape}...")
    adata = anndata.AnnData(X=matrix, var=var)
    adata.obs_names = barcodes
    # Duplicate gene symbols would make ``adata[:, gene]`` select several
    # columns; suffix repeats ("-1", "-2", ...) so each name is one feature.
    adata.var_names_make_unique()

    gc.collect()
    return adata

def _quantile_bin_labels(values, low_q=0.33, high_q=0.67):
    """Label each value 'low' / 'mid' / 'high' relative to two quantile cut-offs.

    Bins are right-inclusive, like ``pd.cut`` with edges
    ``[-inf, low, high, inf]``: ``v <= low`` -> 'low', ``low < v <= high`` ->
    'mid', ``v > high`` -> 'high'. Unlike ``pd.cut``, equal cut-offs (common
    for sparse genes where both quantiles are 0) are allowed; the 'mid' bin is
    then simply empty.

    Returns:
        ``(labels, low_threshold, high_threshold)``, where ``labels`` is an
        object array with one label per input value.
    """
    values = np.asarray(values, dtype=float).ravel()
    low_threshold = np.quantile(values, low_q)
    high_threshold = np.quantile(values, high_q)
    labels = np.full(values.shape, 'mid', dtype=object)
    labels[values <= low_threshold] = 'low'
    labels[values > high_threshold] = 'high'
    return labels, low_threshold, high_threshold


def _gene_expression_vector(adata, gene):
    """Return ``gene``'s values in ``adata.X`` as a 1-D float array (sparse or dense X).

    Uses the first variable whose name equals ``gene``.
    """
    col = int(np.flatnonzero(np.asarray(adata.var_names == gene))[0])
    column = adata.X[:, col]
    if scipy.sparse.issparse(column):
        column = column.toarray()
    return np.asarray(column, dtype=float).ravel()


def infer_perturbation_groups(adata, target_genes=('POU5F1', 'NANOG')):
    """
    Infer perturbation groups from the expression of target genes.

    This is a heuristic stand-in for guide-RNA assignments. Each target gene's
    normalised, log1p expression is split at its 33rd/67th percentiles into
    'low' / 'mid' / 'high'. A cell with any 'low' target is labelled
    'Perturbed', otherwise 'Control'.

    ``adata`` is modified in place and also returned:
        * ``layers['counts']``: sparse copy of the original ``X``.
        * ``X``: normalised to 1e4 total per cell, then log1p-transformed.
        * ``obs['inferred_perturbation']``: e.g. ``"POU5F1_low+NANOG_high"``,
          or ``'Unknown'`` when none of ``target_genes`` is present.
        * ``obs['condition']``: ``'Perturbed'`` / ``'Control'`` (or ``'Unknown'``).
        * ``uns['perturbation_note']``: caveat about the heuristic (set only
          when at least one target gene is found).

    Args:
        adata: AnnData with raw counts in ``X``.
        target_genes: Gene names, matched against ``var_names``.

    Returns:
        The same ``adata`` object.
    """
    print("Inferring perturbation groups based on expression patterns...")

    # Preserve the raw counts before X is normalised in place.
    if scipy.sparse.issparse(adata.X):
        adata.layers['counts'] = adata.X.copy()
    else:
        adata.layers['counts'] = scipy.sparse.csr_matrix(adata.X.copy())

    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)

    # Accept any iterable (the default is an immutable tuple); a list keeps the
    # printed messages identical to before.
    target_genes = list(target_genes)
    available_targets = [gene for gene in target_genes if gene in adata.var_names]
    if not available_targets:
        print(f"Warning: None of the target genes {target_genes} found in dataset")
        adata.obs['inferred_perturbation'] = 'Unknown'
        adata.obs['condition'] = 'Unknown'
        return adata

    print(f"Found target genes: {available_targets}")

    labels = None                                  # "<gene>_<status>" joined by '+'
    any_low = np.zeros(adata.n_obs, dtype=bool)    # True where some target is 'low'
    for gene in available_targets:
        print(f"Processing gene {gene}...")
        gene_expr = _gene_expression_vector(adata, gene)
        status, low_threshold, high_threshold = _quantile_bin_labels(gene_expr)
        print(f"Gene {gene} expression quantiles: low < {low_threshold:.2f}, high > {high_threshold:.2f}")
        # Object-array arithmetic concatenates the strings element-wise.
        tagged = f"{gene}_" + status
        labels = tagged if labels is None else labels + '+' + tagged
        any_low |= status == 'low'
        del gene_expr
        gc.collect()

    print("Creating perturbation labels...")
    adata.obs['inferred_perturbation'] = labels

    print("Assigning condition labels...")
    adata.obs['condition'] = np.where(any_low, 'Perturbed', 'Control')

    pert_counts = adata.obs['inferred_perturbation'].value_counts()
    print("\nInferred perturbation groups:")
    for pert, count in pert_counts.items():
        print(f"  {pert}: {count} cells")

    adata.uns['perturbation_note'] = (
        "Perturbation assignments are inferred from gene expression patterns "
        "and should be considered preliminary. For accurate perturbation information, "
        "the guide RNA readout data should be integrated when available."
    )

    gc.collect()
    return adata

def extract_perturbation_info(adata, cell_type):
    """
    Annotate ``adata`` with study metadata and heuristic perturbation labels.

    Args:
        adata: AnnData with raw counts in ``X``. It is passed to
            ``infer_perturbation_groups``, which normalises ``X`` in place.
        cell_type: ``'ESC'`` gives "Embryonic Stem Cells"; any other value
            is labelled "Definitive Endoderm".

    Returns:
        The annotated AnnData, with:
          * constant ``obs`` columns ``organism``, ``cell_type``,
            ``crispr_type`` and ``cancer_type``;
          * ``obs['condition']`` as a categorical in which ``'Perturbed'`` is
            renamed to ``'Test'``;
          * ``obs['perturbation_name']``: the inferred label with its status
            suffixes rendered as " KO"/" OE" (or dropped for "mid"), those tags
            then stripped and the spacing normalised; ``'Control'`` cells get
            ``'Non-Targeting'``;
          * dataset identifiers in ``uns``.
    """
    print(f"Adding metadata for {cell_type} dataset...")
    constant_columns = {
        'organism': 'Homo sapiens',
        'cell_type': 'Embryonic Stem Cells' if cell_type == 'ESC' else 'Definitive Endoderm',
        'crispr_type': 'CRISPR KO',
        'cancer_type': 'Non-Cancer',
    }
    for column, value in constant_columns.items():
        adata.obs[column] = value

    adata = infer_perturbation_groups(adata, target_genes=['POU5F1', 'NANOG'])

    print("Creating perturbation_name field...")
    status_to_tag = (('_low', ' KO'), ('_mid', ''), ('_high', ' OE'), ('+', ' + '))

    def to_readable(label):
        # e.g. "POU5F1_low+NANOG_high" -> "POU5F1 KO + NANOG OE"
        for token, tag in status_to_tag:
            label = label.replace(token, tag)
        return label

    adata.obs['perturbation_name'] = adata.obs['inferred_perturbation'].apply(to_readable)

    # ---- Ad hoc modifications start here ----
    # Rename the 'Perturbed' condition to 'Test' and store the column as categorical.
    condition = adata.obs["condition"].astype(str).replace("Perturbed", "Test")
    adata.obs["condition"] = condition.astype("category")

    # Strip the whole-word KO/OE tags again and tidy the whitespace left behind.
    cleaned = (
        adata.obs['perturbation_name']
        .str.replace(r'\b(KO|OE)\b', '', regex=True)
        .str.replace(r'\s+', ' ', regex=True)
        .str.strip()
    )
    # Cells in the 'Control' condition are labelled "Non-Targeting".
    is_control = (condition == 'Control').to_numpy()
    adata.obs['perturbation_name'] = cleaned.mask(is_control, 'Non-Targeting')
    # ---- Ad hoc modifications end here ----

    dataset_fields = {
        'dataset_id': 'GSE283614',
        'dataset_name': 'Enhancer perturbations of OCT4 and NANOG in human ESCs',
        'dataset_description': (
            "Single cell perturb-seq of OCT4 and NANOG enhancer perturbation library on "
            "embryonic stem cell and definitive endoderm"
        ),
    }
    for key, value in dataset_fields.items():
        adata.uns[key] = value
    return adata

def process_dataset(data_dir, dataset_type):
    """
    Process a single dataset (ESC or DE) and save it as an h5ad file.

    The output goes to
    ``<data_dir>/harmonized/GSE283614_<dataset_type>_harmonized.h5ad``, and
    the dataset is skipped if that file already exists. The file is first
    written under a temporary name and only renamed once the write succeeds,
    so a crash mid-write cannot leave a truncated file that the existence
    check would treat as finished.

    Errors are printed with their traceback and the dataset is skipped, so one
    failing library does not prevent the other from being processed.

    Args:
        data_dir: Directory holding the downloaded GEO files.
        dataset_type: Library tag used in the file names, e.g. 'ESC' or 'DE'.
    """
    import traceback

    print(f"Processing {dataset_type} data...")
    output_dir = os.path.join(data_dir, "harmonized")
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(output_dir, f"GSE283614_{dataset_type}_harmonized.h5ad")
    # Written first, then atomically renamed to ``output_file``.
    tmp_file = os.path.join(output_dir, f"GSE283614_{dataset_type}_harmonized.tmp.h5ad")

    if os.path.exists(output_file):
        print(f"Harmonized {dataset_type} dataset already exists at {output_file}")
        return

    adata = None
    try:
        adata = read_mtx_files(data_dir, f"GSE283614_{dataset_type}_perturb-seq")
        adata = extract_perturbation_info(adata, dataset_type)
        adata.obs['dataset_origin'] = dataset_type

        print(f"Saving {dataset_type} dataset to {output_file}...")
        adata.write(tmp_file)
        os.replace(tmp_file, output_file)
        print(f"{dataset_type} dataset: {adata.shape[0]} cells, {adata.shape[1]} genes")
    except Exception as e:
        print(f"Error processing {dataset_type} dataset: {e}")
        traceback.print_exc()
        print(f"Skipping {dataset_type} dataset")
        if os.path.exists(tmp_file):
            os.remove(tmp_file)
    finally:
        # Release the matrix on success and failure alike before the next dataset.
        del adata
        gc.collect()

def create_combined_dataset(data_dir):
    """
    Combine the processed ESC and DE datasets and save the result.

    Reads ``GSE283614_ESC_harmonized.h5ad`` and/or ``GSE283614_DE_harmonized.h5ad``
    from ``<data_dir>/harmonized`` and writes
    ``GSE283614_combined_harmonized.h5ad`` there. Nothing happens if the
    combined file already exists or neither input exists.

    * Both inputs: cells are stacked with an outer join on genes, an ``obs``
      column ``batch`` ('ESC' / 'DE') is added, and ``uns`` entries that are
      identical in both inputs are kept.
    * One input: that dataset is copied, with a ``uns`` note explaining
      which dataset is missing.

    The output is written under a temporary name and renamed once complete,
    so an interrupted write cannot leave a truncated combined file behind.
    Errors are printed with their traceback rather than raised.

    Args:
        data_dir: Base directory that contains the ``harmonized`` sub-directory.
    """
    import traceback

    print("Creating combined dataset...")
    output_dir = os.path.join(data_dir, "harmonized")
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(output_dir, "GSE283614_combined_harmonized.h5ad")
    tmp_file = os.path.join(output_dir, "GSE283614_combined_harmonized.tmp.h5ad")

    if os.path.exists(output_file):
        print(f"Combined dataset already exists at {output_file}")
        return

    esc_file = os.path.join(output_dir, "GSE283614_ESC_harmonized.h5ad")
    de_file = os.path.join(output_dir, "GSE283614_DE_harmonized.h5ad")

    has_esc = os.path.exists(esc_file)
    has_de = os.path.exists(de_file)

    if not has_esc and not has_de:
        print("No processed datasets found. Cannot create combined dataset.")
        return

    try:
        if has_esc and has_de:
            print(f"Loading ESC dataset from {esc_file}...")
            esc_adata = anndata.read_h5ad(esc_file)
            print(f"Loading DE dataset from {de_file}...")
            de_adata = anndata.read_h5ad(de_file)
            print("Combining ESC and DE datasets...")
            combined_adata = anndata.AnnData.concatenate(
                esc_adata, de_adata,
                join='outer',
                batch_key='batch',
                batch_categories=['ESC', 'DE'],
                # The default (None) discards uns entirely; keep entries that
                # both inputs share, such as dataset_id / dataset_name.
                uns_merge='same',
            )
            del esc_adata, de_adata
            gc.collect()
        elif has_esc:
            print(f"Loading ESC dataset from {esc_file}...")
            combined_adata = anndata.read_h5ad(esc_file)
            combined_adata.uns['de_dataset_note'] = (
                "The DE dataset could not be processed. "
                "This combined dataset contains only the ESC data."
            )
        else:
            print(f"Loading DE dataset from {de_file}...")
            combined_adata = anndata.read_h5ad(de_file)
            combined_adata.uns['esc_dataset_note'] = (
                "The ESC dataset could not be processed. "
                "This combined dataset contains only the DE data."
            )

        print(f"Saving combined dataset to {output_file}...")
        combined_adata.write(tmp_file)
        os.replace(tmp_file, output_file)
        print(f"Combined dataset: {combined_adata.shape[0]} cells, {combined_adata.shape[1]} genes")

    except Exception as e:
        print(f"Error creating combined dataset: {e}")
        traceback.print_exc()
        if os.path.exists(tmp_file):
            os.remove(tmp_file)

def harmonize_dataset(data_dir):
    """
    Run the whole GSE283614 pipeline: download, per-library processing, combine.

    Args:
        data_dir: Working directory for the raw GEO downloads. Harmonised
            outputs are written to its ``harmonized`` sub-directory.
    """
    harmonized_dir = os.path.join(data_dir, 'harmonized')
    print(f"Processing data in {data_dir}")
    download_files(data_dir)
    for dataset_type in ("ESC", "DE"):
        process_dataset(data_dir, dataset_type)
    create_combined_dataset(data_dir)
    print("\nHarmonization complete!")
    print(f"Harmonized data saved to {harmonized_dir}")

# Directory for downloads and outputs. It defaults to "/content"; set the
# GSE283614_DATA_DIR environment variable to override it without editing
# this file.
data_dir = os.environ.get("GSE283614_DATA_DIR", "/content")  # e.g., "./data" or an absolute path

# Run the harmonization only when executed directly (a script run, or a
# notebook cell, where __name__ is "__main__"), not when this file is imported.
if __name__ == "__main__":
    harmonize_dataset(data_dir)
