In [None]:
#%% [code]
import os
import sys
import urllib.request
import tarfile
import h5py
import numpy as np
import pandas as pd
import scipy.sparse as sp
import anndata
import logging
from pathlib import Path

# Configure logging
# Root logging is configured at INFO level; each record shows the timestamp,
# logger name, level and message.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Logger shared by every function in this script.
logger = logging.getLogger('GSE289235_Harmonizer')

# Constants
# GEO series accession; used to name the working directory, the downloaded
# archive and the harmonized output file, and stored in ``adata.uns``.
GEO_ACCESSION = "GSE289235"
# URL of the archive fetched by ``download_dataset``.
DATASET_URL = "https://ftp.ncbi.nlm.nih.gov/geo/series/GSE289nnn/GSE289235/suppl/GSE289235_RAW.tar"
# Sample accession recorded in ``adata.uns['sample_id']``.
SAMPLE_ID = "GSM8787485"
# Name of the archive member extracted and loaded by this script.
FILE_NAME = "GSM8787485_EXP224.h5ad"


def download_dataset(data_path):
    """
    Download the dataset archive if it doesn't exist yet.

    The archive is first written to a temporary ``.part`` file next to its
    final location and only renamed once the transfer has finished. This way
    an interrupted or failed download is never mistaken for a complete
    archive on the next run.

    Args:
        data_path: Directory where the dataset will be downloaded (created
            if missing)

    Returns:
        Path to the downloaded tar file

    Raises:
        urllib.error.URLError: If the download fails (propagated from
            ``urllib.request.urlretrieve``).
    """
    os.makedirs(data_path, exist_ok=True)
    tar_path = os.path.join(data_path, f"{GEO_ACCESSION}_RAW.tar")

    if os.path.exists(tar_path):
        logger.info(f"Dataset already downloaded: {tar_path}")
        return tar_path

    logger.info(f"Downloading {GEO_ACCESSION} dataset...")
    partial_path = f"{tar_path}.part"
    try:
        urllib.request.urlretrieve(DATASET_URL, partial_path)
        # Rename only after a complete transfer so tar_path is never partial.
        os.replace(partial_path, tar_path)
    finally:
        # After a successful rename the partial file no longer exists; after
        # a failure (including Ctrl-C) remove the truncated leftovers.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    logger.info(f"Download complete: {tar_path}")

    return tar_path


def extract_dataset(tar_path, data_path):
    """
    Extract the h5ad member from the tar file if it is not on disk yet.

    Args:
        tar_path: Path to the tar file
        data_path: Directory where the member will be extracted

    Returns:
        Path to the extracted h5ad file

    Raises:
        KeyError: If the archive does not contain ``FILE_NAME`` (raised by
            ``tarfile``).
    """
    h5ad_path = os.path.join(data_path, FILE_NAME)

    if os.path.exists(h5ad_path):
        logger.info(f"Dataset already extracted: {h5ad_path}")
        return h5ad_path

    logger.info(f"Extracting {tar_path}...")
    with tarfile.open(tar_path) as tar:
        if hasattr(tarfile, "data_filter"):
            # The archive comes from the network: use the 'data' extraction
            # filter where the runtime provides it, so the member cannot
            # escape data_path. Passing it explicitly also avoids the
            # DeprecationWarning on Python 3.12/3.13.
            tar.extract(FILE_NAME, data_path, filter="data")
        else:
            # Older Python without extraction filters: fall back to the
            # original unfiltered call.
            tar.extract(FILE_NAME, data_path)
    logger.info(f"Extraction complete: {h5ad_path}")

    return h5ad_path


def load_h5ad_with_h5py(h5ad_path):
    """
    Read an h5ad file directly through h5py and assemble an AnnData object.

    Only the pieces used downstream are read: the sparse matrix components
    under ``X`` (interpreted as CSR), the observation index, the ``guide``
    and ``sample_id`` categorical columns, and the variable index.

    Args:
        h5ad_path: Path to the h5ad file

    Returns:
        AnnData object with ``guide`` and ``sample_id`` in ``obs`` and an
        otherwise empty ``var`` indexed by the stored variable names
    """
    logger.info(f"Loading {h5ad_path} with h5py...")

    def _decode_all(raw_values):
        # Stored strings are decoded from UTF-8 bytes.
        return [raw.decode('utf-8') for raw in raw_values]

    with h5py.File(h5ad_path, 'r') as h5_file:

        def _read_categorical(categories_key, codes_key):
            # Rebuild a pandas Categorical from its stored labels and codes.
            labels = _decode_all(h5_file[categories_key][:])
            codes = h5_file[codes_key][:]
            return pd.Categorical.from_codes(codes, categories=labels)

        # Expression matrix: shape attribute plus the three CSR arrays.
        matrix_shape = h5_file['X'].attrs['shape']
        csr_parts = (
            h5_file['X/data'][:],
            h5_file['X/indices'][:],
            h5_file['X/indptr'][:],
        )
        X = sp.csr_matrix(csr_parts, shape=matrix_shape)

        # Per-cell annotations.
        cell_names = _decode_all(h5_file['obs/_index'][:])
        obs = pd.DataFrame(
            {
                'guide': _read_categorical(
                    'obs/guide/categories', 'obs/guide/codes'
                ),
                'sample_id': _read_categorical(
                    'obs/sample_id/categories', 'obs/sample_id/codes'
                ),
            },
            index=cell_names,
        )

        # Per-gene annotations: only the index (gene names) is loaded.
        var = pd.DataFrame(index=_decode_all(h5_file['var/_index'][:]))

    # The file is closed here; everything needed is already in memory.
    adata = anndata.AnnData(X=X, obs=obs, var=var)
    logger.info(f"Successfully loaded data with shape {adata.shape}")

    return adata


def process_guide_information(adata):
    """
    Process guide information to extract perturbation targets.

    Derives ``perturbation_name`` from the guide name (text before the first
    ``_i``), labels non-targeting controls, drops cells without a usable
    perturbation name, and assigns a ``condition`` based on the guide.

    Note that ``adata.obs['guide']`` is converted to strings and
    ``perturbation_name`` is added on the object passed in; the filtered
    result is returned as a new copy.

    Args:
        adata: AnnData object with a ``guide`` column in ``obs``

    Returns:
        Updated AnnData object with processed perturbation information and
        conditions.
    """
    logger.info("Processing guide information...")

    # Record genuinely missing guides BEFORE stringification. Depending on
    # the pandas version/string dtype, astype(str) either turns missing
    # values into the literal 'nan' or keeps them missing; relying on the
    # literal alone would let unassigned cells slip through as 'Test'.
    guide_missing = adata.obs['guide'].isna().to_numpy()

    # Convert guide to string type
    adata.obs['guide'] = adata.obs['guide'].astype(str)
    guide = adata.obs['guide']

    # Extract gene names from guide names (format: GENE_i1, GENE_i2, etc.).
    # .str[0] (instead of expand=True + [0]) also works when obs is empty.
    adata.obs['perturbation_name'] = guide.str.split('_i').str[0]

    # Identify non-targeting controls using patterns 'NTC' or 'non-targeting'
    is_non_targeting = (
        guide.str.contains('NTC', case=False, na=False)
        | guide.str.contains('non-targeting', case=False, na=False)
    ).to_numpy(dtype=bool)

    # Set perturbation_name for controls to 'Non-targeting'
    adata.obs.loc[is_non_targeting, 'perturbation_name'] = 'Non-targeting'

    # Exclude cells whose perturbation name is missing or the literal 'nan'
    # (case insensitive).
    names = adata.obs['perturbation_name']
    is_nan_label = (
        names.str.lower().eq('nan').fillna(False).to_numpy(dtype=bool)
    )
    drop_mask = guide_missing | names.isna().to_numpy() | is_nan_label

    before_drop = adata.n_obs
    adata = adata[~drop_mask].copy()
    after_drop = adata.n_obs
    logger.info(f"Dropped {before_drop - after_drop} cells with NaN perturbation names.")

    # Set condition: 'Control' for non-targeting and 'Test' for all others
    adata.obs['condition'] = np.where(
        adata.obs['perturbation_name'] == 'Non-targeting', 'Control', 'Test'
    )

    logger.info(f"Identified {adata.obs['perturbation_name'].nunique()} unique perturbation targets")

    return adata


def harmonize_metadata(adata):
    """
    Attach the fixed, dataset-wide annotations to an AnnData object.

    Every cell receives the same organism, cell type, CRISPR type and cancer
    type in ``obs``; dataset-level identifiers and a study description go
    into ``uns``. An existing ``condition`` column is left untouched.

    Args:
        adata: AnnData object (modified in place)

    Returns:
        The same AnnData object with the harmonized metadata added
    """
    logger.info("Harmonizing metadata...")

    # Constant per-cell annotations, applied in this order.
    constant_obs_columns = (
        ('organism', 'Homo sapiens'),
        ('cell_type', 'iPSC-derived neurons'),
        ('crispr_type', 'CRISPRi'),
        ('cancer_type', 'Non-Cancer'),
    )
    for column_name, constant_value in constant_obs_columns:
        adata.obs[column_name] = constant_value

    # Dataset-level information stored in uns.
    study_description = (
        "A Massively Parallel CRISPR-Based Screening Platform for Modifiers of Neuronal Activity. "
        "The study used CRISPR interference (CRISPRi) and the fluorescent calcium integrator CaMPARI2 "
        "to evaluate 1343 genes for their effect on excitability in human iPSC-derived neurons."
    )
    study_entries = (
        ('dataset_id', GEO_ACCESSION),
        ('sample_id', SAMPLE_ID),
        ('study_description', study_description),
    )
    for uns_key, uns_value in study_entries:
        adata.uns[uns_key] = uns_value

    return adata


def save_harmonized_data(adata, output_path):
    """
    Write the harmonized AnnData object to disk as an h5ad file.

    Progress is logged before and after the write.

    Args:
        adata: AnnData object to persist
        output_path: Destination path of the h5ad file
    """
    start_message = f"Saving harmonized data to {output_path}..."
    logger.info(start_message)

    adata.write_h5ad(output_path)

    done_message = "Harmonization complete!"
    logger.info(done_message)


def main(data_root_path='.'):
    """
    Run the whole pipeline: download, extract, load, process, harmonize, save.

    All files live in a sub-directory of ``data_root_path`` named after the
    GEO accession.

    Args:
        data_root_path: Root path where data will be stored

    Returns:
        Path to the harmonized h5ad file
    """
    # Working directory and final output location.
    dataset_dir = Path(data_root_path) / GEO_ACCESSION
    harmonized_path = dataset_dir / f"{GEO_ACCESSION}_harmonized.h5ad"

    # Fetch the archive and pull out the raw h5ad file.
    archive_path = download_dataset(dataset_dir)
    raw_h5ad_path = extract_dataset(archive_path, dataset_dir)

    # Load, annotate perturbations, then add the dataset-wide metadata.
    harmonized = harmonize_metadata(
        process_guide_information(
            load_h5ad_with_h5py(raw_h5ad_path)
        )
    )

    # Persist the result.
    save_harmonized_data(harmonized, harmonized_path)

    return harmonized_path

# Execute the full pipeline in the current directory and report where the
# harmonized file was written.
harmonized_file = main(data_root_path='.')
print(f"Harmonized data saved to: {harmonized_file}")
