In [None]:
import os
import sys
import gzip
import shutil
import pandas as pd
import numpy as np
import scanpy as sc
import anndata as ad
from scipy import sparse
from scipy.io import mmread
import urllib.request
from pathlib import Path
import re

def download_file(url, output_path):
    """Download ``url`` to ``output_path`` without ever leaving a partial file there.

    The payload is first written to a sibling ``<name>.part`` file and only
    moved onto ``output_path`` once the transfer finished. This matters because
    ``ensure_files_exist`` treats any file that already exists as complete: an
    interrupted download written directly to ``output_path`` would be silently
    reused on the next run.

    Parameters
    ----------
    url : str
        Source URL (anything ``urllib.request.urlretrieve`` accepts).
    output_path : str or os.PathLike
        Destination file path. An existing file is overwritten.

    Raises
    ------
    Whatever ``urllib.request.urlretrieve`` raises (e.g. ``URLError``,
    ``ContentTooShortError``). In that case ``output_path`` is left untouched
    and the temporary ``.part`` file is removed.
    """
    target = Path(output_path)
    partial = target.with_name(target.name + ".part")

    print(f"Downloading {url} to {output_path}...")
    try:
        urllib.request.urlretrieve(url, partial)
        # Move into place only after the whole payload has been written.
        os.replace(partial, target)
    finally:
        # After a successful replace the temp file no longer exists; after a
        # failure this cleans up whatever was written so far.
        if partial.exists():
            partial.unlink()
    print(f"Downloaded {output_path}")

def ensure_files_exist(data_dir):
    """Create ``data_dir`` if needed and fetch any missing GSE236057 supplement file.

    A file is considered present as soon as a path with its name exists in
    ``data_dir``; only absent files are passed to ``download_file``.

    Returns
    -------
    pathlib.Path
        ``data_dir`` converted to a ``Path``.
    """
    target_dir = Path(data_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    geo_prefix = "https://ftp.ncbi.nlm.nih.gov/geo/series/GSE236nnn/GSE236057/suppl/"
    required_names = (
        "GSE236057_FinalCells_Barcodes.tsv.gz",
        "GSE236057_FinalCells_Counts.mtx.gz",
        "GSE236057_FinalCells_GeneNames.tsv.gz",
        "GSE236057_FinalCells_Metadata.csv.gz",
    )

    for name in required_names:
        local_copy = target_dir / name
        if local_copy.exists():
            # Already on disk: nothing to fetch.
            continue
        download_file(geo_prefix + name, local_copy)

    return target_dir

def load_data(data_dir):
    """Read the four GSE236057 files in ``data_dir`` and assemble an AnnData.

    The gene table becomes ``var`` (indexed by ``gene_symbol``), the metadata
    CSV becomes ``obs``, and the obs index is then overwritten with the
    barcode list and made unique.

    Returns
    -------
    anndata.AnnData
        Object whose ``X`` is the transposed count matrix in CSR format.
    """
    root = Path(data_dir)

    print("Loading gene names...")
    gene_table = pd.read_csv(
        root / "GSE236057_FinalCells_GeneNames.tsv.gz",
        sep='\t',
        header=None,
        names=['gene_symbol', 'ensembl_id', 'feature_type'],
    )

    print("Loading cell barcodes...")
    barcode_table = pd.read_csv(
        root / "GSE236057_FinalCells_Barcodes.tsv.gz",
        sep='\t',
        header=None,
        names=['barcode'],
    )

    print("Loading count matrix...")
    matrix_path = root / "GSE236057_FinalCells_Counts.mtx.gz"
    raw_matrix = mmread(str(matrix_path))
    # NOTE(review): the transpose presumably turns a genes x cells file into
    # the cells x genes layout AnnData expects -- confirm against the file.
    cells_by_genes = raw_matrix.T.tocsr()

    print("Loading metadata...")
    cell_metadata = pd.read_csv(
        root / "GSE236057_FinalCells_Metadata.csv.gz",
        index_col=0,
    )

    print("Creating AnnData object...")
    gene_annotations = gene_table.set_index('gene_symbol')
    adata = ad.AnnData(X=cells_by_genes, obs=cell_metadata, var=gene_annotations)
    # The barcode file, not the metadata index, provides the final cell names;
    # this relies on both files listing cells in the same order.
    adata.obs.index = barcode_table['barcode']
    adata.obs_names_make_unique()

    return adata


# A guide column carries a "_g<digits>" suffix, e.g. "RPL29_g1" or
# "Enh12_g3_chr...". A plain "_g" substring test would also catch unrelated
# metadata such as "n_genes".
_GUIDE_COLUMN_RE = re.compile(r'_g\d+')
_GUIDE_SUFFIX_RE = re.compile(r'_g\d+.*')
_ENHANCER_GUIDE_RE = re.compile(r'^Enh(\d+)_g\d+_chr')


def _guide_column_to_target(col):
    """Translate one guide column name into the perturbation target it denotes."""
    enhancer = _ENHANCER_GUIDE_RE.match(col)
    if enhancer:
        # 'Enh12_g3_chr...' -> 'Enhancer_12'
        return f"Enhancer_{enhancer.group(1)}"
    if col.startswith("Pos_"):
        # 'Pos_RPL29_g1' -> 'RPL29_g1'; the suffix is stripped below.
        col = col[len("Pos_"):]
    # 'RPL29_g1' -> 'RPL29'
    return _GUIDE_SUFFIX_RE.sub('', col)


def process_guide_information(adata):
    """Derive ``perturbation_name`` and ``condition`` from the guide columns of ``adata.obs``.

    A guide column is any string-named column containing ``_g<digits>``.
    Each one is mapped to a target name:

    - ``Enh<N>_g<k>_chr...``  -> ``Enhancer_<N>``
    - ``Pos_<X>_g<k>...``     -> ``<X>``
    - ``<X>_g<k>...``         -> ``<X>``

    Several guides for the same target collapse into one entry. A cell that
    carries guides for several targets gets the targets joined with ``+`` in
    sorted order. Cells with no guide are labelled ``'Non-targeting'``.
    ``condition`` is ``'Control'`` for ``'Non-targeting'`` cells and ``'Test'``
    for all others.

    Parameters
    ----------
    adata : object exposing a pandas ``obs`` DataFrame (e.g. ``anndata.AnnData``)
        Modified in place.

    Returns
    -------
    The same ``adata`` object, with the two columns added or overwritten.
    """
    print("Processing guide information (removing _g1, _g2, etc.)...")

    n_cells = len(adata.obs)
    guide_cols = [
        col for col in adata.obs.columns
        if isinstance(col, str) and _GUIDE_COLUMN_RE.search(col)
    ]

    # One boolean mask per target; guides sharing a target are OR-ed together.
    target_masks = {}
    for col in guide_cols:
        target = _guide_column_to_target(col)
        # `== True` is deliberate: it accepts both boolean flags and 0/1 ints.
        # NOTE(review): assumes guide columns are flags; values > 1 (e.g. raw
        # guide counts) would not match -- confirm against the metadata file.
        hit = np.asarray(adata.obs[col] == True, dtype=bool)  # noqa: E712
        if target in target_masks:
            target_masks[target] = target_masks[target] | hit
        else:
            target_masks[target] = hit

    labels = np.full(n_cells, 'Non-targeting', dtype=object)
    if target_masks:
        # Sorting the targets once means every per-cell subset is already in
        # sorted order when joined.
        targets = np.array(sorted(target_masks), dtype=object)
        hit_matrix = np.column_stack([target_masks[t] for t in targets])
        # Only visit cells that actually carry a guide.
        for row in np.flatnonzero(hit_matrix.any(axis=1)):
            labels[row] = '+'.join(targets[hit_matrix[row]])

    # Positional assignment: works even if the obs index has duplicates.
    adata.obs['perturbation_name'] = labels

    adata.obs['condition'] = np.where(
        adata.obs['perturbation_name'] == 'Non-targeting',
        'Control',
        'Test',
    )

    num_perturbed = (adata.obs['perturbation_name'] != 'Non-targeting').sum()
    print(f"Processed {num_perturbed} cells with a guide (perturbation).")

    return adata


def remove_enhancer_only_cells(adata):
    """Return a copy of ``adata`` without cells whose perturbations are all enhancers.

    ``perturbation_name`` is split on ``+``; a cell is dropped when every
    resulting part starts with ``'Enhancer_'``. Examples:

      - 'Non-targeting'            -> kept
      - 'Gene' / 'Gene+Gene2'      -> kept
      - 'Gene+Enhancer_123'        -> kept
      - 'Enhancer_123'             -> dropped
      - 'Enhancer_1+Enhancer_2'    -> dropped
    """
    def _only_enhancers(label):
        # The control label is checked first and is never treated as an enhancer.
        return label != "Non-targeting" and all(
            piece.startswith("Enhancer_") for piece in label.split('+')
        )

    enhancer_only = adata.obs['perturbation_name'].apply(_only_enhancers)
    print(f"Removing {enhancer_only.sum()} cells that are enhancer-only.")

    # These cells are removed outright. Relabelling them as 'Non-targeting' /
    # 'Control' would be the alternative, but it is not what this function does.
    return adata[~enhancer_only].copy()

def harmonize_metadata(adata):
    """Add study-level constants and snake_case aliases to ``adata.obs`` in place.

    Four constant columns are written for every cell (organism, cell type,
    CRISPR type, cancer type). Each original column listed below that is
    present is then copied under a standardized name; the original column is
    kept. ``condition`` is left as it is.

    Returns the same ``adata`` object.
    """
    print("Harmonizing metadata...")

    # Dataset-wide constants, identical for every cell.
    constant_fields = (
        ('organism', 'Homo sapiens'),
        ('cell_type', 'Astrocyte'),
        ('crispr_type', 'CRISPRi'),
        ('cancer_type', 'Non-Cancer'),
    )
    for field, value in constant_fields:
        adata.obs[field] = value

    # (original column, standardized alias) -- copied, not renamed.
    aliases = (
        ('UMI_Total', 'umi_count'),
        ('Gene_Total', 'gene_count'),
        ('Mito_Pct', 'mito_percent'),
        ('Library', 'library'),
        ('Cycle_Seurat_Phase', 'cell_cycle_phase'),
        ('Cluster', 'cluster'),
        ('MOI', 'multiplicity_of_infection'),
        ('TransductionPool', 'transduction_pool'),
    )
    for source, alias in aliases:
        if source not in adata.obs.columns:
            continue
        adata.obs[alias] = adata.obs[source]

    return adata


def filter_metadata_columns(adata, keep_guide_columns=False):
    """Drop every ``adata.obs`` column that is not on the keep-list.

    The keep-list is the harmonized fields, their standardized aliases, and
    the original source columns (plus ``AnyGuide``). With
    ``keep_guide_columns=True``, columns that look like guide columns
    (containing ``Enh_``/``Pos_`` or a ``_g<digits>`` suffix) are retained as
    well; otherwise they are dropped like any other unlisted column.

    ``adata.obs`` is replaced by the reduced frame and ``adata`` is returned.
    """
    print("Filtering metadata columns...")

    retained = frozenset((
        # required harmonized fields
        'organism', 'cell_type', 'crispr_type', 'cancer_type',
        'condition', 'perturbation_name',
        # standardized aliases
        'umi_count', 'gene_count', 'mito_percent', 'library',
        'cell_cycle_phase', 'cluster', 'multiplicity_of_infection',
        'transduction_pool',
        # original columns, kept when present
        'Library', 'UMI_Total', 'Gene_Total', 'Mito_Pct',
        'Cycle_Seurat_Phase', 'Cluster', 'MOI', 'TransductionPool',
        'AnyGuide',
    ))

    def _should_drop(col):
        if col in retained:
            return False
        if keep_guide_columns and re.search(r'(Enh|Pos)_|\_g\d+', col):
            # Guide-like column, and the caller asked to keep those.
            return False
        return True

    drop_cols = [col for col in adata.obs.columns.tolist() if _should_drop(col)]
    adata.obs = adata.obs.drop(columns=drop_cols)

    print(f"Kept {len(adata.obs.columns)} columns, dropped {len(drop_cols)} columns")
    return adata


def main(data_dir="GSE236057", keep_guide_columns=False):
    """
    Run the full pipeline: fetch, load, annotate, filter, and save the dataset.

    Parameters
    ----------
    data_dir : str
        Directory holding the input files (missing ones are downloaded there);
        the harmonized ``.h5ad`` file is written to the same directory.
    keep_guide_columns : bool
        If True, keep the original guide columns in adata.obs.

    Returns
    -------
    The processed AnnData object that was written to disk.
    """
    data_dir = ensure_files_exist(data_dir)
    adata = load_data(data_dir)

    # Order matters: perturbation names must exist before enhancer-only cells
    # can be removed, and harmonization runs on the surviving cells.
    pipeline = (
        process_guide_information,
        remove_enhancer_only_cells,
        harmonize_metadata,
    )
    for step in pipeline:
        adata = step(adata)

    adata = filter_metadata_columns(adata, keep_guide_columns=keep_guide_columns)

    # Write the result next to the inputs.
    output_file = os.path.join(data_dir, "GSE236057_harmonized_no_enhancer_only.h5ad")
    print(f"Saving harmonized data (no enhancer-only cells) to {output_file}...")
    adata.write(output_file)

    # Summary report.
    print(f"Harmonization complete. Output file: {output_file}")
    print(f"Dataset shape: {adata.shape}")
    print(f"Number of unique perturbations: {adata.obs['perturbation_name'].nunique()}")

    top_counts = adata.obs['perturbation_name'].value_counts().head(10)
    print("\nTop 10 perturbations by frequency:")
    print(top_counts)

    return adata

# Usage in Jupyter:
# Running this cell defines the functions above AND immediately executes the
# pipeline through the call below (missing input files are downloaded first).
adata = main(data_dir="GSE236057", keep_guide_columns=False)
# This version removes any cells that only had enhancers
# (e.g. 'Enhancer_123' or 'Enhancer_123+Enhancer_456'), so the final dataset
# contains only Non-targeting cells, gene-targeting cells, or gene+enhancer cells.
