# In [None]:
import os
import tarfile
import urllib.request
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
from scipy import sparse
import anndata

# Constants
GEO_ACCESSION = "GSE273164"
# Direct download link for the series' supplementary RAW archive.
GEO_URL = f"https://www.ncbi.nlm.nih.gov/geo/download/?acc={GEO_ACCESSION}&format=file"
# GEO's FTP tree groups series under a stub directory formed by replacing the
# last three digits of the accession with "nnn" (GSE273164 -> GSE273nnn), so
# the stub must be derived with [:-3], not a fixed-length prefix.
README_URL = f"https://ftp.ncbi.nlm.nih.gov/geo/series/{GEO_ACCESSION[:-3]}nnn/{GEO_ACCESSION}/suppl/{GEO_ACCESSION}_feature_README.txt"


def download_dataset(data_dir: Path) -> Path:
    """
    Download the dataset archive if it is not already present.

    The archive is first written to a temporary ``.part`` file and only moved
    to its final name once the download has finished, so an interrupted
    download is never mistaken for a complete archive on the next run.

    Args:
        data_dir: Directory to store the dataset (created if missing)

    Returns:
        Path to the downloaded tar file

    Raises:
        Whatever ``urllib.request.urlretrieve`` raises on network failure;
        the partial file is removed before the error propagates.
    """
    tar_path = data_dir / f"{GEO_ACCESSION}_RAW.tar"

    if tar_path.exists():
        print(f"Dataset already exists at {tar_path}")
        return tar_path

    data_dir.mkdir(parents=True, exist_ok=True)
    partial_path = tar_path.with_name(tar_path.name + ".part")

    print(f"Downloading dataset from {GEO_URL}...")
    try:
        urllib.request.urlretrieve(GEO_URL, partial_path)
        # Same-directory rename: the final name appears only when complete.
        partial_path.replace(tar_path)
    finally:
        # After a successful rename this is a no-op; after a failure it
        # removes the truncated download.
        if partial_path.exists():
            partial_path.unlink()
    print(f"Downloaded dataset to {tar_path}")

    return tar_path


def extract_files(tar_path: Path, data_dir: Path) -> Path:
    """
    Extract files from the tar archive into ``data_dir / "extracted"``.

    The archive comes from the network, so every member is validated before
    anything is written: a member (or link target) that would land outside
    the extraction directory aborts the extraction. If extraction fails for
    any reason the partially populated directory is removed, so a later run
    does not mistake it for a finished extraction.

    Args:
        tar_path: Path to the tar file
        data_dir: Directory to extract the files to

    Returns:
        Path to the extracted directory

    Raises:
        ValueError: If an archive member would be written outside the
            extraction directory.
        tarfile.TarError, OSError: If the archive cannot be read or written.
    """
    import shutil  # local import: only needed for failure cleanup

    extract_dir = data_dir / "extracted"

    if extract_dir.exists():
        print(f"Files already extracted to {extract_dir}")
        return extract_dir

    print(f"Extracting files from {tar_path}...")
    extract_dir.mkdir(parents=True, exist_ok=True)
    root = extract_dir.resolve()

    def _inside_root(candidate: Path) -> bool:
        resolved = candidate.resolve()
        return resolved == root or root in resolved.parents

    try:
        with tarfile.open(tar_path) as tar:
            members = tar.getmembers()
            for member in members:
                destination = root / member.name
                if not _inside_root(destination):
                    raise ValueError(
                        f"Refusing to extract {member.name!r}: path escapes {extract_dir}"
                    )
                # Symlink targets are relative to the link's own directory,
                # hard-link targets are relative to the archive root.
                if member.issym() and not _inside_root(destination.parent / member.linkname):
                    raise ValueError(
                        f"Refusing to extract {member.name!r}: link target escapes {extract_dir}"
                    )
                if member.islnk() and not _inside_root(root / member.linkname):
                    raise ValueError(
                        f"Refusing to extract {member.name!r}: link target escapes {extract_dir}"
                    )

            # Use the standard-library "data" extraction filter when this
            # Python version provides it; older versions rely on the checks above.
            extra_kwargs = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}
            tar.extractall(path=extract_dir, members=members, **extra_kwargs)
    except BaseException:
        # Leave no half-extracted directory behind (it would short-circuit
        # the "already extracted" check on the next run).
        shutil.rmtree(extract_dir, ignore_errors=True)
        raise

    print(f"Extracted files to {extract_dir}")
    return extract_dir


def get_file_paths(extract_dir: Path) -> Dict[str, Dict[str, Path]]:
    """
    Group the extracted ``*.csv.gz`` files by pool and modality.

    A file belongs to a modality when its name contains ``_RNA_``, ``_ADT_``
    or ``_HTO_`` (checked in that order); the pool ID is everything before
    that marker. Files matching none of the markers are ignored.

    Args:
        extract_dir: Directory containing the extracted files

    Returns:
        Dictionary mapping pool IDs to ``{"rna" | "adt" | "hto": path}``
    """
    # Marker -> modality key, in priority order.
    markers = (("_RNA_", "rna"), ("_ADT_", "adt"), ("_HTO_", "hto"))
    pools = {}

    for csv_path in extract_dir.glob("*.csv.gz"):
        name = csv_path.name
        hit = next(((marker, kind) for marker, kind in markers if marker in name), None)
        if hit is None:
            continue

        marker, kind = hit
        pools.setdefault(name.split(marker)[0], {})[kind] = csv_path

    print(f"Found {len(pools)} pools:")
    for pool_id, by_kind in pools.items():
        print(f"  {pool_id}: {', '.join(by_kind.keys())}")

    return pools


def load_adt_feature_metadata(data_dir: Path) -> Dict[str, str]:
    """
    Load ADT feature metadata from the README file.

    The README is downloaded into ``data_dir`` if it is not already there.
    Loading is best-effort: any download or read failure is reported and an
    empty mapping is returned so the rest of the pipeline can continue.

    Each comma-separated line that contains ``Antibody Capture`` and has at
    least five fields contributes three lookup keys for the same protein
    name (field 1): the ADT id (field 0), the barcode sequence (field 4),
    and ``"<id>-<sequence>"``.

    Args:
        data_dir: Directory containing the dataset

    Returns:
        Dictionary mapping ADT feature IDs to protein names (empty on failure)
    """
    readme_path = data_dir / f"{GEO_ACCESSION}_feature_README.txt"

    # Download README file if it doesn't exist
    if not readme_path.exists():
        print(f"Downloading README file from {README_URL}...")
        try:
            urllib.request.urlretrieve(README_URL, readme_path)
        except Exception as e:
            # Remove any truncated file so the next run retries the download
            # instead of parsing partial content.
            try:
                if readme_path.exists():
                    readme_path.unlink()
            except OSError:
                pass
            print(
                f"WARNING: Could not download {README_URL} ({e}). "
                "Skipping ADT feature metadata loading."
            )
            return {}
        print(f"Downloaded README file to {readme_path}")

    # Parse README file
    feature_to_protein = {}
    n_features = 0
    try:
        # Explicit encoding keeps parsing independent of the platform locale;
        # undecodable bytes are replaced rather than aborting the load.
        with open(readme_path, 'r', encoding='utf-8', errors='replace') as f:
            # Skip header
            f.readline()

            for line in f:
                parts = line.strip().split(',')
                if len(parts) >= 5 and 'Antibody Capture' in line:
                    adt_id = parts[0]
                    protein_name = parts[1]
                    sequence = parts[4]
                    feature_id = f"{adt_id}-{sequence}"

                    feature_to_protein[adt_id] = protein_name
                    feature_to_protein[sequence] = protein_name
                    feature_to_protein[feature_id] = protein_name
                    n_features += 1
    except OSError as e:
        print(f"Error loading ADT feature metadata: {e}")
        return {}

    # Report features, not dictionary keys (each feature adds up to 3 keys).
    print(
        f"Loaded metadata for {n_features} ADT features "
        f"({len(feature_to_protein)} lookup keys)"
    )
    return feature_to_protein


def parse_cell_barcode(barcode: str) -> Tuple[str, str, str]:
    """
    Split a full cell barcode into sample ID, perturbation label and cell barcode.

    The barcode is split on ``_``. With fewer than three fields nothing can
    be parsed and ``("", "", barcode)`` is returned. Otherwise the sample ID
    is the second field up to its first ``-`` and the cell barcode is the
    third field. The perturbation is ``<sample>-sgGAP45.3`` or
    ``<sample>-sgGAP45.4`` when the matching guide name (written with ``.``
    or ``-``) occurs anywhere in the barcode, else ``<sample>-sgNeg``.

    Args:
        barcode: Cell barcode string,
            e.g. ``PZ-2975_PDX17c-pool1_AAACCCAAGATACAGT-1``

    Returns:
        Tuple of (sample_id, perturbation, cell_barcode)
    """
    fields = barcode.split('_')
    if len(fields) < 3:
        return "", "", barcode

    sample_id = fields[1].split('-')[0]

    # Guide variants are tested in order; the first one found wins.
    guide = "sgNeg"
    for variant in ("3", "4"):
        if f"sgGAP45.{variant}" in barcode or f"sgGAP45-{variant}" in barcode:
            guide = f"sgGAP45.{variant}"
            break

    return sample_id, f"{sample_id}-{guide}", fields[2]


def read_csv_data(file_path: Path) -> pd.DataFrame:
    """
    Load a gzip-compressed CSV expression table.

    The first column is used as the row index. Reading is best-effort: on
    any failure the error is printed and an empty DataFrame is returned.

    Args:
        file_path: Path to the CSV file

    Returns:
        DataFrame containing the expression data (empty if reading failed)
    """
    try:
        return pd.read_csv(file_path, index_col=0, compression='gzip')
    except Exception as err:
        print(f"Error reading {file_path}: {err}")

    return pd.DataFrame()


def _build_modality_adata(counts_df: pd.DataFrame, pool_id: str) -> anndata.AnnData:
    """
    Build a cells x features AnnData from a features x cells count table.

    Shared by the RNA and ADT branches of ``process_expression_data`` so both
    modalities get identical ``obs`` annotations.

    Args:
        counts_df: Table with features (genes/proteins) as rows and cell
            barcodes as columns
        pool_id: Pool ID recorded in ``obs['Sequencing.Lane']`` and
            ``obs['pool_id']``

    Returns:
        AnnData with a sparse, transposed matrix and barcode-derived metadata
    """
    cell_barcodes = counts_df.columns.tolist()
    feature_names = counts_df.index.tolist()

    # Input is features x cells; AnnData expects cells x features.
    adata = anndata.AnnData(
        X=sparse.csr_matrix(counts_df.values.T),
        obs=pd.DataFrame(index=cell_barcodes),
        var=pd.DataFrame(index=feature_names)
    )

    # Keep the original barcodes as a regular column as well as the index
    adata.obs['Barcode'] = cell_barcodes

    # Extract metadata from barcodes
    parsed = [parse_cell_barcode(barcode) for barcode in cell_barcodes]
    adata.obs['sample_id'] = [sample_id for sample_id, _, _ in parsed]
    adata.obs['perturbation_name'] = [perturbation for _, perturbation, _ in parsed]
    adata.obs['cell_id'] = [cell_id for _, _, cell_id in parsed]

    # Add sequencing lane information
    adata.obs['Sequencing.Lane'] = pool_id
    adata.obs['pool_id'] = pool_id

    return adata


def process_expression_data(
    pool_id: str,
    rna_path: Optional[Path] = None,
    adt_path: Optional[Path] = None,
    adt_feature_metadata: Optional[Dict[str, str]] = None
) -> Tuple[Optional[anndata.AnnData], Optional[anndata.AnnData]]:
    """
    Process expression data for a single pool.

    RNA is handled first. If an RNA path is given but the table is empty or
    fails to process, ``(None, None)`` is returned without attempting ADT.
    If ADT is empty or fails, the RNA result is still returned with ``None``
    in the ADT slot.

    Args:
        pool_id: Pool ID
        rna_path: Path to RNA expression data
        adt_path: Path to ADT expression data
        adt_feature_metadata: Dictionary mapping ADT feature IDs to protein names

    Returns:
        Tuple of (RNA AnnData, ADT AnnData)
    """
    # Read RNA data
    rna_adata = None
    if rna_path:
        print(f"Reading RNA data from {rna_path}")
        try:
            # Read the data (genes as rows, cells as columns)
            rna_df = read_csv_data(rna_path)

            if rna_df.empty:
                print("RNA data is empty")
                return None, None

            rna_adata = _build_modality_adata(rna_df, pool_id)
            print(f"RNA data shape: {rna_adata.shape}")

        except Exception as e:
            print(f"Error processing RNA data: {e}")
            return None, None

    # Read ADT data
    adt_adata = None
    if adt_path:
        print(f"Reading ADT data from {adt_path}")
        try:
            # Read the data (proteins as rows, cells as columns)
            adt_df = read_csv_data(adt_path)

            if adt_df.empty:
                print("ADT data is empty")
                return rna_adata, None

            adt_adata = _build_modality_adata(adt_df, pool_id)

            # Add protein names if available; unknown IDs keep their raw ID
            if adt_feature_metadata:
                adt_adata.var['protein_name'] = [
                    adt_feature_metadata.get(protein_id, protein_id)
                    for protein_id in adt_df.index.tolist()
                ]

            print(f"ADT data shape: {adt_adata.shape}")

        except Exception as e:
            print(f"Error processing ADT data: {e}")
            return rna_adata, None

    return rna_adata, adt_adata


def process_pool(
    pool_id: str,
    file_paths: Dict[str, Path],
    adt_feature_metadata: Optional[Dict[str, str]] = None,
    output_dir: Optional[Path] = None
) -> Tuple[Optional[anndata.AnnData], Optional[anndata.AnnData]]:
    """
    Process data for a single pool.

    Loads the pool's RNA and ADT tables, restricts both to the barcodes they
    share (when both modalities are present and overlap), and optionally
    writes each modality to ``<output_dir>/pools/<pool_id>_{rna,adt}.h5ad``.
    Any exception is reported and turned into ``(None, None)`` so one bad
    pool does not stop the others.

    Args:
        pool_id: Pool ID
        file_paths: Dictionary mapping file types to file paths
        adt_feature_metadata: Dictionary mapping ADT feature IDs to protein names
        output_dir: Directory to save intermediate results

    Returns:
        Tuple of (RNA AnnData, ADT AnnData)
    """
    print(f"\nProcessing pool: {pool_id}")

    try:
        # Process expression data
        rna_adata, adt_adata = process_expression_data(
            pool_id=pool_id,
            rna_path=file_paths.get('rna'),
            adt_path=file_paths.get('adt'),
            adt_feature_metadata=adt_feature_metadata
        )

        # Check if data was processed successfully
        if rna_adata is None and adt_adata is None:
            print(f"Failed to process data for pool {pool_id}")
            return None, None

        # Find common barcodes
        if rna_adata is not None and adt_adata is not None:
            # Keep the RNA file's barcode order. Going through a set would
            # make the row order depend on per-process string hashing and
            # therefore differ from run to run.
            shared_mask = rna_adata.obs_names.isin(adt_adata.obs_names)
            common_barcodes = rna_adata.obs_names[shared_mask].tolist()
            print(f"Found {len(common_barcodes)} common barcodes across RNA and ADT data")

            # Filter data to include only common barcodes
            if common_barcodes:
                rna_adata = rna_adata[common_barcodes]
                adt_adata = adt_adata[common_barcodes]

                print(f"Filtered RNA data shape: {rna_adata.shape}")
                print(f"Filtered ADT data shape: {adt_adata.shape}")

        # Save intermediate results if output directory is provided
        if output_dir is not None:
            pool_dir = output_dir / "pools"
            pool_dir.mkdir(parents=True, exist_ok=True)

            if rna_adata is not None:
                rna_output_path = pool_dir / f"{pool_id}_rna.h5ad"
                print(f"Saving RNA data to {rna_output_path}")
                rna_adata.write(rna_output_path)

            if adt_adata is not None:
                adt_output_path = pool_dir / f"{pool_id}_adt.h5ad"
                print(f"Saving ADT data to {adt_output_path}")
                adt_adata.write(adt_output_path)

        return rna_adata, adt_adata

    except Exception as e:
        print(f"Error processing pool {pool_id}: {e}")
        return None, None


def add_standardized_metadata(adata: anndata.AnnData, data_type: str) -> anndata.AnnData:
    """
    Annotate ``adata.obs`` in place with the dataset's standard columns.

    Adds the constant columns ``organism``, ``cell_type``, ``crispr_type``
    and ``cancer_type``; derives ``condition`` from ``perturbation_name``
    (``'ARHGAP45 KO'`` when the name contains ``sgGAP45``, else
    ``'Control'``); converts those columns and ``perturbation_name`` to
    categoricals; and copies ``perturbation_name`` into ``Label``.

    Args:
        adata: AnnData object; ``obs`` must already have ``perturbation_name``
        data_type: Type of data ('rna' or 'adt'); currently not used

    Returns:
        The same AnnData object, with standardized metadata
    """
    # Dataset-wide constants, in the order the columns are added.
    constant_columns = (
        ('organism', 'Homo sapiens'),
        ('cell_type', 'AML'),
        ('crispr_type', 'CRISPR KO'),
        ('cancer_type', 'Acute Myeloid Leukemia'),
    )
    for column, value in constant_columns:
        adata.obs[column] = value

    # Condition follows from the guide named in the perturbation label
    adata.obs['condition'] = [
        'ARHGAP45 KO' if 'sgGAP45' in name else 'Control'
        for name in adata.obs['perturbation_name']
    ]

    categorical_columns = [column for column, _ in constant_columns]
    categorical_columns += ['condition', 'perturbation_name']
    for column in categorical_columns:
        adata.obs[column] = adata.obs[column].astype('category')

    # Label column for visualization mirrors the perturbation name
    adata.obs['Label'] = adata.obs['perturbation_name']

    return adata


def combine_pools(
    rna_adatas: List[anndata.AnnData],
    adt_adatas: List[anndata.AnnData],
    output_dir: Path
) -> Tuple[anndata.AnnData, anndata.AnnData]:
    """
    Concatenate the per-pool AnnData objects of each modality and save them.

    Both modalities are concatenated with an outer join on features, missing
    values filled with 0, and ``-`` as the separator used to make observation
    names unique. Results are written to ``<GEO_ACCESSION>_rna.h5ad`` and
    ``<GEO_ACCESSION>_adt.h5ad`` in ``output_dir``.

    Args:
        rna_adatas: List of RNA AnnData objects
        adt_adatas: List of ADT AnnData objects
        output_dir: Directory to save intermediate results

    Returns:
        Tuple of (combined RNA AnnData, combined ADT AnnData)
    """
    print("\nCombining data from all pools...")

    # Identical concatenation settings for both modalities
    concat_options = dict(join='outer', fill_value=0, index_unique='-')
    combined_rna = anndata.concat(rna_adatas, **concat_options)
    combined_adt = anndata.concat(adt_adatas, **concat_options)

    print(f"Combined RNA data shape: {combined_rna.shape}")
    print(f"Combined ADT data shape: {combined_adt.shape}")

    # Save intermediate results (RNA first, then ADT)
    for label, suffix, combined in (
        ("RNA", "rna", combined_rna),
        ("ADT", "adt", combined_adt),
    ):
        destination = output_dir / f"{GEO_ACCESSION}_{suffix}.h5ad"
        print(f"Saving combined {label} data to {destination}")
        combined.write(destination)

    return combined_rna, combined_adt


def harmonize_data(
    rna_adata: anndata.AnnData,
    adt_adata: anndata.AnnData,
    output_dir: Path
) -> Tuple[anndata.AnnData, anndata.AnnData]:
    """
    Apply the standardized metadata to both modalities and save the results.

    Outputs are written to ``<GEO_ACCESSION>_rna_updated.h5ad`` and
    ``<GEO_ACCESSION>_adt_updated.h5ad`` in ``output_dir``.

    Args:
        rna_adata: RNA AnnData object
        adt_adata: ADT AnnData object
        output_dir: Directory to save intermediate results

    Returns:
        Tuple of (harmonized RNA AnnData, harmonized ADT AnnData)
    """
    print("\nHarmonizing data...")

    # Add standardized metadata
    rna_adata = add_standardized_metadata(rna_adata, 'rna')
    adt_adata = add_standardized_metadata(adt_adata, 'adt')

    # Save intermediate results (RNA first, then ADT)
    for label, suffix, harmonized in (
        ("RNA", "rna", rna_adata),
        ("ADT", "adt", adt_adata),
    ):
        destination = output_dir / f"{GEO_ACCESSION}_{suffix}_updated.h5ad"
        print(f"Saving harmonized {label} data to {destination}")
        harmonized.write(destination)

    return rna_adata, adt_adata


def filter_paired_data(
    rna_adata: anndata.AnnData,
    adt_adata: anndata.AnnData,
    output_dir: Path
) -> Tuple[anndata.AnnData, anndata.AnnData]:
    """
    Filter data to include only cells with both RNA and ADT data.

    Both objects are subset to their shared observation names, in the order
    those names appear in the RNA object, so the two outputs are row-aligned
    and identical from run to run. If nothing is shared, a warning is printed
    and the inputs are saved unfiltered. Results are written to
    ``<GEO_ACCESSION>_rna_final.h5ad`` and ``<GEO_ACCESSION>_adt_final.h5ad``.

    Args:
        rna_adata: RNA AnnData object
        adt_adata: ADT AnnData object
        output_dir: Directory to save final results

    Returns:
        Tuple of (filtered RNA AnnData, filtered ADT AnnData)
    """
    print("\nFiltering paired data...")

    # Find common barcodes, preserving RNA order. Going through a set would
    # make the row order depend on per-process string hashing.
    shared_mask = rna_adata.obs_names.isin(adt_adata.obs_names)
    common_barcodes = rna_adata.obs_names[shared_mask].tolist()
    print(f"Found {len(common_barcodes)} common barcodes across RNA and ADT data")

    # Filter data to include only common barcodes
    if common_barcodes:
        rna_adata = rna_adata[common_barcodes]
        adt_adata = adt_adata[common_barcodes]

        print(f"Filtered RNA data shape: {rna_adata.shape}")
        print(f"Filtered ADT data shape: {adt_adata.shape}")
    else:
        # Make it visible that the "final" files are not actually paired.
        print("WARNING: No common barcodes found; saving unfiltered RNA and ADT data.")

    # Save final results
    rna_output_path = output_dir / f"{GEO_ACCESSION}_rna_final.h5ad"
    adt_output_path = output_dir / f"{GEO_ACCESSION}_adt_final.h5ad"

    print(f"Saving final RNA data to {rna_output_path}")
    rna_adata.write(rna_output_path)

    print(f"Saving final ADT data to {adt_output_path}")
    adt_adata.write(adt_output_path)

    return rna_adata, adt_adata


def _make_symbols_unique(symbols: List[str]) -> List[str]:
    """
    Return ``symbols`` with repeated entries renamed to ``<symbol>_<n>``.

    The first occurrence keeps its name. A generated name is never allowed to
    equal another name in the list (original or generated), so the result is
    guaranteed to contain no duplicates.
    """
    taken = set(symbols)   # every name that is, or will be, in use
    seen = set()           # symbols whose first occurrence has been emitted
    next_suffix = {}
    unique_symbols = []

    for symbol in symbols:
        if symbol not in seen:
            seen.add(symbol)
            unique_symbols.append(symbol)
            continue

        n = next_suffix.get(symbol, 0) + 1
        candidate = f"{symbol}_{n}"
        # Skip suffixes that would collide with a real symbol such as "A_1"
        while candidate in taken:
            n += 1
            candidate = f"{symbol}_{n}"

        next_suffix[symbol] = n
        taken.add(candidate)
        unique_symbols.append(candidate)

    return unique_symbols


def check_gene_symbols(rna_adata: anndata.AnnData) -> anndata.AnnData:
    """
    Check if var_names are gene symbols and fix if needed.

    If any of the first 100 var names starts with ``ENSG``, every name of the
    form ``ENSEMBL|SYMBOL`` is replaced by its ``SYMBOL`` part. Names without
    a ``|``, or with an empty symbol part, are left unchanged. Duplicate
    symbols are then disambiguated with a numeric ``_<n>`` suffix.

    Args:
        rna_adata: RNA AnnData object (``var_names`` is updated in place)

    Returns:
        RNA AnnData object with gene symbols as var_names
    """
    print("\nChecking gene symbols...")

    var_names = rna_adata.var_names.tolist()

    # Check if var_names contain ENSEMBL IDs
    if any(name.startswith('ENSG') for name in var_names[:100]):
        print("var_names contain ENSEMBL IDs. Converting to gene symbols...")

        # Try to extract gene symbols from var_names
        gene_symbols = []
        for name in var_names:
            if '|' in name:
                # Format might be ENSEMBL|SYMBOL; keep the full name when the
                # symbol part is empty so no var name becomes "".
                gene_symbols.append(name.split('|')[1] or name)
            else:
                gene_symbols.append(name)

        # Check for duplicate gene symbols
        if len(set(gene_symbols)) < len(gene_symbols):
            print("Warning: Duplicate gene symbols found. Adding suffix to make them unique.")
            gene_symbols = _make_symbols_unique(gene_symbols)

        # Update var_names
        rna_adata.var_names = pd.Index(gene_symbols)
        print(f"Updated var_names to gene symbols. Example: {rna_adata.var_names[:5].tolist()}")
    else:
        print(f"var_names already appear to be gene symbols. Example: {var_names[:5]}")

    return rna_adata


def main(data_dir: Union[str, Path]) -> None:
    """
    Main function to process the dataset.

    Downloads and extracts the archive into ``data_dir``, processes every
    pool, then combines, harmonizes and pair-filters the results, writing all
    outputs under ``data_dir / "processed"``.

    Only pools that yield both an RNA and an ADT object are combined. The
    combination step makes observation names unique by appending each
    object's position in its list, and the final pairing step matches RNA and
    ADT cells by those names. A pool present in only one list would shift the
    positions in that list and break the matching for every later pool, so
    such pools are reported and left out (their per-pool files are still
    written by ``process_pool``).

    Args:
        data_dir: Directory containing the dataset (created if missing)
    """
    # Convert to Path object
    data_dir = Path(data_dir)
    data_dir.mkdir(parents=True, exist_ok=True)

    # Create output directory
    output_dir = data_dir / "processed"
    output_dir.mkdir(exist_ok=True)

    # Download and extract dataset
    tar_path = download_dataset(data_dir)
    extract_dir = extract_files(tar_path, data_dir)

    # Get file paths for each pool
    file_paths_by_pool = get_file_paths(extract_dir)

    # Load ADT feature metadata
    adt_feature_metadata = load_adt_feature_metadata(data_dir)

    # Process each pool
    rna_adatas = []
    adt_adatas = []

    # Sorted pool order keeps the position-based name suffixes reproducible
    # regardless of the order the filesystem lists the files in.
    for pool_id in sorted(file_paths_by_pool):
        rna_adata, adt_adata = process_pool(
            pool_id,
            file_paths_by_pool[pool_id],
            adt_feature_metadata=adt_feature_metadata,
            output_dir=output_dir
        )

        if rna_adata is None or adt_adata is None:
            if rna_adata is not None or adt_adata is not None:
                print(
                    f"WARNING: Pool {pool_id} is missing RNA or ADT data; "
                    "excluding it from the combined data."
                )
            continue

        # Append as a pair so both lists stay index-aligned
        rna_adatas.append(rna_adata)
        adt_adatas.append(adt_adata)

    # Combine data from all pools
    if rna_adatas and adt_adatas:
        combined_rna, combined_adt = combine_pools(rna_adatas, adt_adatas, output_dir)

        # Check gene symbols
        combined_rna = check_gene_symbols(combined_rna)

        # Harmonize data
        harmonized_rna, harmonized_adt = harmonize_data(combined_rna, combined_adt, output_dir)

        # Filter paired data
        final_rna, final_adt = filter_paired_data(harmonized_rna, harmonized_adt, output_dir)

        print("\nProcessing complete!")
        print(f"Final RNA data shape: {final_rna.shape}")
        print(f"Final ADT data shape: {final_adt.shape}")
    else:
        print("No data was processed successfully.")


# ----------------------------------------------------------------------------
# Notebook entry point: call main() with the directory that should hold the
# downloaded archive, the extracted files, and the "processed" outputs.
# Either a relative path (as below) or an absolute path works.
# ----------------------------------------------------------------------------

# Example usage in a Jupyter Notebook:
data_dir = "./GSE273164_data"  # Change to your preferred folder
main(data_dir)
