# In [None]:
import os
import sys
import gzip
import pandas as pd
import numpy as np
import scipy.io as sio
import scipy.sparse as sparse
import urllib.request
import tarfile
import re
from pathlib import Path
import warnings
import time
from datetime import datetime
import traceback
warnings.filterwarnings('ignore')

def _tar_member_is_safe(base_dir, member_name):
    """Return True if extracting *member_name* under *base_dir* stays inside it."""
    base = os.path.realpath(base_dir)
    target = os.path.realpath(os.path.join(base, member_name))
    try:
        return os.path.commonpath([base, target]) == base
    except ValueError:
        # commonpath raises when the two paths cannot be compared
        # (e.g. different drives); treat that as unsafe.
        return False


def download_data(data_dir):
    """
    Download the GSE272093 dataset if not already present.

    Creates ``data_dir`` if needed, fetches the raw tar archive and the sgRNA
    assignment table when they are missing, and extracts the archive unless
    the marker file ``GSM8392917_EL1_Day14_matrix.mtx.gz`` already exists.

    Download and extraction errors are reported with ``print`` and do not
    propagate, so callers must check for the files they need.

    Parameters
    ----------
    data_dir : str
        Directory that receives the downloaded and extracted files.
    """
    os.makedirs(data_dir, exist_ok=True)

    files_to_download = {
        "GSE272093_RAW.tar": "https://www.ncbi.nlm.nih.gov/geo/download/?acc=GSE272093&format=file",
        "GSE272093_sgRNA_assignments.txt.gz": "https://ftp.ncbi.nlm.nih.gov/geo/series/GSE272nnn/GSE272093/suppl/GSE272093_sgRNA_assignments.txt.gz"
    }

    for filename, url in files_to_download.items():
        filepath = os.path.join(data_dir, filename)
        if os.path.exists(filepath):
            print(f"{filename} already exists, skipping download")
            continue

        print(f"Downloading {filename}...")
        # Download to a temporary name and rename on success, so an
        # interrupted transfer is never mistaken for a complete file on the
        # next run (the existence check above is the only completeness test).
        partial_path = filepath + ".part"
        try:
            urllib.request.urlretrieve(url, partial_path)
            os.replace(partial_path, filepath)
            print(f"Downloaded {filename}")
        except Exception as e:
            print(f"Error downloading {filename}: {e}")
        finally:
            if os.path.exists(partial_path):
                os.remove(partial_path)

    tar_filepath = os.path.join(data_dir, "GSE272093_RAW.tar")
    if not os.path.exists(tar_filepath):
        return

    sample_file = os.path.join(data_dir, "GSM8392917_EL1_Day14_matrix.mtx.gz")
    if os.path.exists(sample_file):
        print("Files already extracted, skipping extraction")
        return

    print("Extracting GSE272093_RAW.tar...")
    try:
        with tarfile.open(tar_filepath) as tar:
            members = tar.getmembers()
            # The archive comes from the network: refuse to write anything
            # outside data_dir (e.g. member names containing "../").
            unsafe = [m.name for m in members if not _tar_member_is_safe(data_dir, m.name)]
            if unsafe:
                raise ValueError(f"archive members escape {data_dir}: {unsafe}")
            # Use the stdlib 'data' extraction filter where this Python has it.
            extract_kwargs = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}
            tar.extractall(path=data_dir, members=members, **extract_kwargs)
        print("Extracted GSE272093_RAW.tar")
    except Exception as e:
        print(f"Error extracting {tar_filepath}: {e}")

def load_10x_data(data_dir, sample_id):
    """
    Read the matrix / features / barcodes triplet for one sample.

    The three files are expected at ``<data_dir>/<sample_id>_matrix.mtx.gz``,
    ``..._features.tsv.gz`` and ``..._barcodes.tsv.gz``.

    Returns a dict with keys ``'X'`` (CSR matrix, transposed relative to the
    file so rows are cells and columns are genes), ``'gene_ids'`` (first
    features column), ``'gene_symbols'`` (second features column) and
    ``'barcodes'`` (first barcodes column).

    Raises FileNotFoundError if any of the three files is missing.
    """
    paths = {
        suffix: os.path.join(data_dir, f"{sample_id}_{suffix}")
        for suffix in ("matrix.mtx.gz", "features.tsv.gz", "barcodes.tsv.gz")
    }
    # Fail early, before any of the (potentially slow) reads start.
    for candidate in paths.values():
        if not os.path.exists(candidate):
            raise FileNotFoundError(f"File {candidate} not found")

    matrix_path = paths["matrix.mtx.gz"]
    print(f"Loading matrix from {matrix_path}...")
    t0 = time.time()
    with gzip.open(matrix_path, 'rb') as handle:
        # The file is genes x cells; flip it to cells x genes.
        cells_by_genes = sio.mmread(handle).T.tocsr()
    print(f"Matrix loaded in {time.time() - t0:.2f} seconds")

    features_path = paths["features.tsv.gz"]
    print(f"Loading features from {features_path}...")
    feature_table = pd.read_csv(features_path, sep='\t', header=None)

    barcodes_path = paths["barcodes.tsv.gz"]
    print(f"Loading barcodes from {barcodes_path}...")
    barcode_table = pd.read_csv(barcodes_path, sep='\t', header=None)

    return {
        'X': cells_by_genes,
        'gene_ids': feature_table[0].values,
        'gene_symbols': feature_table[1].values,
        'barcodes': barcode_table[0].values,
    }

def load_sgRNA_assignments(data_dir):
    """
    Read ``GSE272093_sgRNA_assignments.txt.gz`` from ``data_dir``.

    The file is parsed as tab-separated text with a header row and returned
    as a pandas DataFrame.
    """
    assignments_path = os.path.join(data_dir, "GSE272093_sgRNA_assignments.txt.gz")
    print(f"Loading sgRNA assignments from {assignments_path}...")
    return pd.read_csv(assignments_path, sep='\t')

def load_mapped_data(data_dir, sample_id):
    """
    Load mapped data (CSV with sgRNA assignments) for a sample.

    Reads ``<data_dir>/<sample_id>.csv.gz`` in chunks of 100000 rows, printing
    the cumulative number of rows read after each chunk, and returns the
    chunks concatenated into a single DataFrame with a fresh integer index.

    Raises FileNotFoundError if the file does not exist.
    """
    mapped_file = os.path.join(data_dir, f"{sample_id}.csv.gz")
    if not os.path.exists(mapped_file):
        raise FileNotFoundError(f"File {mapped_file} not found")

    print(f"Loading mapped data from {mapped_file}...")
    start_time = time.time()
    chunk_size = 100000
    chunks = []
    rows_loaded = 0
    for chunk in pd.read_csv(mapped_file, chunksize=chunk_size):
        chunks.append(chunk)
        # Count the rows actually read: the final chunk is usually shorter
        # than chunk_size, so len(chunks) * chunk_size would overstate it.
        rows_loaded += len(chunk)
        print(f"Loaded {rows_loaded} rows...")
    mapped_data = pd.concat(chunks, ignore_index=True)
    print(f"Mapped data loaded in {time.time() - start_time:.2f} seconds")
    return mapped_data

def extract_perturbation_info(guide_name):
    """
    Derive perturbation metadata from a single guide name.

    Returns a dict with keys ``perturbation_name``, ``crispr_type`` (always
    "CRISPRi") and ``targeting``:

    - missing value or the literal string "None" -> "Non-targeting"
    - name starting with ``guide_<gene>`` -> that gene, "Targeting"
    - anything else -> "Unknown"
    """
    def _info(name, targeting):
        return {"perturbation_name": name, "crispr_type": "CRISPRi", "targeting": targeting}

    if pd.isna(guide_name) or guide_name == "None":
        return _info("Non-targeting", "Non-targeting")

    # The gene is the run of non-underscore characters right after "guide_".
    found = re.match(r'guide_([^_]+)', guide_name)
    if found is None:
        return _info("Unknown", "Unknown")
    return _info(found.group(1), "Targeting")

def parse_multi_guide_name(guide_str):
    """
    Turn a comma-separated guide string into a " + "-joined list of gene names.

    Each entry contributes its second underscore-delimited field (the piece
    after 'guide_'); an entry with no underscore contributes "Unknown".
    Repeated genes are kept once, in order of first appearance.
    A missing value yields None.

    Example:
        "guide_CDK8_-_26828453.23-P1P2,guide_UQCRQ_-_132202364.23-P1P2"
        -> "CDK8 + UQCRQ"
    """
    if pd.isna(guide_str):
        return None

    # A dict serves as an insertion-ordered set of gene names.
    ordered_genes = {}
    for entry in guide_str.split(','):
        fields = entry.split('_')
        gene = fields[1] if len(fields) > 1 else "Unknown"
        ordered_genes.setdefault(gene, None)

    return " + ".join(ordered_genes)

def _collect_guides_by_cell(mapped_data):
    """Map each cell barcode to a comma-joined string of its unique guides, in first-seen order."""
    guides_by_cell = {}
    for cell_barcode, gRNA in zip(mapped_data['cell'], mapped_data['gRNA']):
        if pd.isna(gRNA):
            # A row without a guide contributes nothing for that cell.
            continue
        # Inner dict is used as an insertion-ordered set (O(1) de-duplication).
        cell_guides = guides_by_cell.setdefault(cell_barcode, {})
        for guide in str(gRNA).split(','):
            guide = guide.strip()
            if guide:
                cell_guides.setdefault(guide, None)
    return {cell: ','.join(guides) for cell, guides in guides_by_cell.items() if guides}


def _targeting_label(perturbation_name):
    """Classify a ' + '-joined perturbation name as Unknown / Non-targeting / Targeting."""
    if perturbation_name is None:
        return "Unknown"
    # Non-targeting only when *every* gene token is a NonTargeting control;
    # a cell carrying a control guide plus a real guide is still perturbed.
    genes = perturbation_name.split(" + ")
    if all(gene.startswith("NonTargeting") for gene in genes):
        return "Non-targeting"
    return "Targeting"


def process_sample(data_dir, sample_info):
    """
    Build the harmonized X / obs / var / uns bundle for one sample.

    Loads the expression matrix via ``load_10x_data`` and, when
    ``sample_info`` has a non-empty ``'mapped_id'``, the per-cell guide
    assignments via ``load_mapped_data`` (columns ``'cell'`` and ``'gRNA'``).

    obs columns written:
      - ``guide``: comma-joined unique guides of the cell (None if unassigned)
      - ``perturbation_name``: ``parse_multi_guide_name`` of ``guide``
      - ``crispr_type``: always "CRISPRi"
      - ``targeting``: "Non-targeting" if all genes are NonTargeting,
        "Targeting" otherwise, "Unknown" when the cell has no guide
      - ``organism``, ``cell_type``, ``condition``, ``cancer_type``,
        ``sample_id``, ``culture_type``: constants / values from ``sample_info``

    If there is no mapped data, or processing it raises, the guide columns
    fall back to ``guide=None``, ``perturbation_name="Unknown"``,
    ``targeting="Unknown"`` (the error is printed, not raised).

    Parameters
    ----------
    data_dir : str
        Directory holding the sample files.
    sample_info : dict
        Needs ``sample_id``, ``expression_id``, ``cell_type``, ``condition``
        and ``culture_type``; ``mapped_id`` is optional.

    Returns
    -------
    dict
        Keys ``'X'``, ``'obs'``, ``'var'``, ``'uns'``.
    """
    print(f"Processing sample {sample_info['sample_id']}...")

    data = load_10x_data(data_dir, sample_info['expression_id'])

    obs = pd.DataFrame(index=data['barcodes'])
    var = pd.DataFrame(index=data['gene_symbols'])
    var['gene_ids'] = data['gene_ids']

    # Fallback values, used when there is no mapped data or it cannot be processed.
    guides = None
    perturbations = "Unknown"
    targeting = "Unknown"

    if sample_info.get('mapped_id'):
        try:
            mapped_data = load_mapped_data(data_dir, sample_info['mapped_id'])
            print("Processing sgRNA assignments...")
            cell_to_guide = _collect_guides_by_cell(mapped_data)

            # Compute all three columns before touching obs so a failure
            # part-way through cannot leave obs half-populated.
            guides = [cell_to_guide.get(cell) for cell in obs.index]
            perturbations = [parse_multi_guide_name(raw) for raw in guides]
            targeting = [_targeting_label(name) for name in perturbations]
        except Exception as e:
            print(f"Error processing mapped data: {e}")
            guides = None
            perturbations = "Unknown"
            targeting = "Unknown"

    obs['guide'] = guides
    obs['perturbation_name'] = perturbations
    obs['crispr_type'] = "CRISPRi"
    obs['targeting'] = targeting

    # Standard metadata
    obs['organism'] = "Homo sapiens"
    obs['cell_type'] = sample_info['cell_type']
    obs['condition'] = sample_info['condition']
    obs['cancer_type'] = "Non-Cancer"
    obs['sample_id'] = sample_info['sample_id']
    obs['culture_type'] = sample_info['culture_type']

    harmonized_data = {
        'X': data['X'],
        'obs': obs,
        'var': var,
        'uns': {
            'sample_info': sample_info,
            'dataset_id': 'GSE272093',
            'harmonization_date': datetime.now().strftime('%Y-%m-%d'),
            'description': 'CRISPRi-based screens in iAssembloids to elucidate neuron-glia interactions'
        }
    }
    return harmonized_data

def save_h5ad(harmonized_data, output_file):
    """
    Write a harmonized bundle to disk, preferring the h5ad format.

    ``harmonized_data`` must provide ``'X'``, ``'obs'``, ``'var'`` and
    ``'uns'``. When ``anndata`` can be imported, an AnnData object is built,
    written to ``output_file`` and returned.

    If an ImportError is raised, the data is instead saved with ``np.savez``
    (with ``X`` converted to a dense array) to the path obtained by replacing
    '.h5ad' with '.npz' in ``output_file``, and None is returned.
    """
    try:
        import anndata as ad

        print("Creating AnnData object...")
        t0 = time.time()
        components = {key: harmonized_data[key] for key in ('X', 'obs', 'var', 'uns')}
        adata = ad.AnnData(**components)
        print(f"AnnData object created in {time.time() - t0:.2f} seconds")

        print(f"Saving to {output_file}...")
        t0 = time.time()
        adata.write(output_file)
        print(f"Saved harmonized data to {output_file} in {time.time() - t0:.2f} seconds")
        return adata
    except ImportError:
        print("Warning: anndata package not available, saving in numpy format instead")
        fallback_path = output_file.replace('.h5ad', '.npz')
        # obs/var/uns are plain dicts here, stored by numpy as object arrays.
        np.savez(
            fallback_path,
            X=harmonized_data['X'].toarray(),
            obs=harmonized_data['obs'].to_dict('list'),
            var=harmonized_data['var'].to_dict('list'),
            uns=harmonized_data['uns'],
        )
        print(f"Saved harmonized data to {fallback_path}")
        return None

def harmonize_dataset(data_dir=None):
    """
    Produce one harmonized h5ad file per GSE272093 sample.

    ``data_dir`` defaults to the current working directory. The raw data is
    fetched with ``download_data`` first. For each sample, an existing
    ``<sample_id>_harmonized.h5ad`` is loaded with scanpy; otherwise the
    sample is processed and saved. Failures for a sample are printed and the
    sample is skipped.

    Returns ``(output_files, adatas)``: the h5ad paths that were loaded or
    written, and the AnnData objects that are available in memory.
    """
    if data_dir is None:
        data_dir = os.getcwd()

    download_data(data_dir)

    mixed_cells = 'Mixed (Neurons, Astrocytes, Microglia)'
    sample_fields = (
        'sample_id', 'expression_id', 'mapped_id', 'description',
        'cell_type', 'condition', 'culture_type',
    )
    sample_rows = (
        ('GSM8392917', 'GSM8392917_EL1_Day14', 'GSM8392920_EL1_Monoculture_1_mapped',
         '2D monoculture neuron CROP-seq rep 1', 'Neurons', 'Control', '2D monoculture'),
        ('GSM8392918', 'GSM8392918_EL2_Day14_3D_Culture1', 'GSM8392921_EL2_iAssembloid_1_mapped',
         'iAssembloids neuron CROP-seq rep 1', mixed_cells, 'Control', '3D iAssembloid'),
        ('GSM8392919', 'GSM8392919_EL3_Day14_3D_Culture2', 'GSM8392922_EL3_iAssembloid_2_mapped',
         'iAssembloids neuron CROP-seq rep 2', mixed_cells, 'Control', '3D iAssembloid'),
        ('GSM8392923', 'GSM8392923_EL7_Lane1_2D_CROP', 'GSM8392924_EL7_Monoculture_2_mapped',
         '2D monoculture neuron CROP-seq rep 2', 'Neurons', 'Control', '2D monoculture'),
        ('GSM8392925', 'GSM8392925_EL8_Lane2_2D_CROP', 'GSM8392926_EL8_Monoculture_3_mapped',
         '2D monoculture neuron CROP-seq rep 3', 'Neurons', 'Control', '2D monoculture'),
    )
    samples = [dict(zip(sample_fields, row)) for row in sample_rows]

    output_files = []
    adatas = []
    for sample_info in samples:
        output_file = os.path.join(data_dir, f"{sample_info['sample_id']}_harmonized.h5ad")

        if os.path.exists(output_file):
            # Reuse a previous run's result instead of reprocessing.
            print(f"File {output_file} already exists, loading...")
            try:
                import scanpy as sc
                adatas.append(sc.read_h5ad(output_file))
                output_files.append(output_file)
            except Exception as e:
                print(f"Error reading {output_file}: {e}")
            continue

        try:
            harmonized_data = process_sample(data_dir, sample_info)
            adata = save_h5ad(harmonized_data, output_file)
            # save_h5ad returns None when it fell back to the .npz format.
            if adata is not None:
                adatas.append(adata)
            output_files.append(output_file)
        except Exception as e:
            print(f"Error processing sample {sample_info['sample_id']}: {e}")
            traceback.print_exc()

    return output_files, adatas

def validate_harmonized_data(output_files):
    """
    Print a short sanity report for each harmonized h5ad file.

    For every path in ``output_files`` that exists, the report shows the
    cell/gene counts, whether all var_names match the pattern
    ``^[A-Za-z0-9\\-\\.]+$``, and up to five unique values of selected obs
    columns. Missing files are skipped with a message; errors while reading
    or summarising a file are printed with a traceback and do not stop the
    loop.
    """
    import scanpy as sc

    print("\nValidating harmonized data...")
    summary_columns = (
        'organism', 'cell_type', 'crispr_type', 'cancer_type', 'condition', 'perturbation_name',
    )

    for path in output_files:
        if not os.path.exists(path):
            print(f"File {path} does not exist, skipping validation")
            continue

        print(f"Validating {path}...")
        try:
            adata = sc.read_h5ad(path)
            print(f"  AnnData object with {adata.n_obs} cells and {adata.n_vars} genes")

            symbol_like = adata.var_names.str.match(r'^[A-Za-z0-9\-\.]+$')
            is_gene_symbols = all(symbol_like)
            print(f"  Gene symbols used as var_names: {is_gene_symbols}")

            for col in summary_columns:
                if col not in adata.obs.columns:
                    print(f"  {col}: Not available")
                    continue
                unique_values = adata.obs[col].unique()
                # Only the first five values are shown; flag truncation.
                truncated = '...' if len(unique_values) > 5 else ''
                print(f"  {col}: {unique_values[:5]} {truncated}")

            print(f"  Validation complete for {path}")
        except Exception as e:
            print(f"  Error validating {path}: {e}")
            traceback.print_exc()

def _is_non_targeting_name(perturbation_name):
    """Return True if a perturbation name denotes a pure non-targeting control."""
    if pd.isna(perturbation_name):
        return False
    name = str(perturbation_name)
    if name == 'Non-targeting':
        return True
    # process_sample stores gene tokens joined by " + ", with control guides
    # appearing as "NonTargeting"; a cell is a control only if every token is one.
    return all(token.startswith('NonTargeting') for token in name.split(' + '))


def combine_and_qc(data_dir=None, combined_output='combined_filtered.h5ad'):
    """
    Combine individual harmonized h5ad files, remove cells with NaN in 'guide',
    apply QC filtering, update the condition column, and save the combined dataset.

    Steps:
      1. Obtain per-sample AnnData objects via ``harmonize_dataset`` (falling
         back to reading the returned file paths if none are in memory).
      2. Make obs/var names unique per sample and prefix each cell ID with
         the sample ID.
      3. Concatenate with an outer join over genes.
      4. Drop cells without a guide assignment.
      5. Compute QC metrics (mitochondrial genes = names starting with 'MT-')
         and keep cells with < 20% mitochondrial counts and > 200 genes.
      6. Set ``condition`` to 'Control' for pure non-targeting cells and
         'Test' for everything else.
      7. Write the result to ``<data_dir>/<combined_output>``.

    Parameters
    ----------
    data_dir : str, optional
        Working directory; defaults to the current working directory.
    combined_output : str
        File name of the combined h5ad written inside ``data_dir``.

    Returns
    -------
    AnnData
        The combined, filtered dataset.

    Raises
    ------
    ValueError
        If no harmonized AnnData object could be obtained.
    """
    import scanpy as sc
    import anndata as ad
    if data_dir is None:
        data_dir = os.getcwd()

    # 1. Get the list of existing harmonized files and associated AnnData objects.
    output_files, adatas = harmonize_dataset(data_dir)
    if not adatas:
        adatas = []
        for f in output_files:
            try:
                adatas.append(sc.read_h5ad(f))
            except Exception as e:
                print(f"Error loading {f}: {e}")
    if not adatas:
        # Fail with an explicit message instead of an opaque concat error.
        raise ValueError("No harmonized AnnData objects available to combine")

    # 2. Make names unique and prefix each cell ID with its sample ID so that
    #    identical barcodes from different samples do not collide.
    for adata in adatas:
        adata.obs_names_make_unique()
        adata.var_names_make_unique()
        if adata.n_obs == 0:
            # Nothing to prefix; also avoids .iloc[0] on an empty column.
            continue
        sample_id = adata.obs['sample_id'].iloc[0]
        adata.obs_names = [f"{sample_id}_{cell_id}" for cell_id in adata.obs_names]

    # 3. Concatenate AnnData objects
    combined = ad.concat(adatas, join='outer', label='sample')

    # 4. Remove cells with NaN in 'guide'
    combined = combined[combined.obs['guide'].notna(), :].copy()
    print(f"Combined AnnData after removing cells with NaN in guide: {combined.n_obs} cells")

    # 5. Ensure final var names are unique across the combined dataset
    combined.var_names_make_unique()

    # 6. Identify mitochondrial genes (assuming gene names start with 'MT-')
    combined.var['mt'] = combined.var_names.str.startswith('MT-')

    # 7. Calculate QC metrics
    sc.pp.calculate_qc_metrics(combined, qc_vars=['mt'], inplace=True)

    # 8. Apply QC filtering: keep cells with <20% mitochondrial counts and >200 genes
    keep = (combined.obs['pct_counts_mt'] < 20) & (combined.obs['n_genes_by_counts'] > 200)
    combined_filtered = combined[keep, :].copy()
    print(f"After QC filtering: {combined_filtered.n_obs} cells")

    # 9. Update the condition column.
    #    Comparing perturbation_name to the literal "Non-targeting" alone never
    #    matches the "NonTargeting" tokens written by process_sample, which
    #    would label every cell "Test"; recognise both spellings.
    is_control = np.array(
        [_is_non_targeting_name(name) for name in combined_filtered.obs['perturbation_name']],
        dtype=bool,
    )
    combined_filtered.obs['condition'] = np.where(is_control, 'Control', 'Test')

    # 10. Save the combined, filtered AnnData object
    combined_path = os.path.join(data_dir, combined_output)
    combined_filtered.write(combined_path)
    print(f"Combined filtered AnnData saved to {combined_path}")

    return combined_filtered



def run_all(data_dir=None):
    """
    Execute the full pipeline: harmonize, validate, then combine with QC.

    ``data_dir`` defaults to the current working directory. Returns the
    combined, filtered object produced by ``combine_and_qc``.
    """
    working_dir = os.getcwd() if data_dir is None else data_dir
    harmonized_paths, _ = harmonize_dataset(working_dir)
    validate_harmonized_data(harmonized_paths)
    return combine_and_qc(working_dir)

# Run the entire process in Jupyter.
# NOTE: this call executes as soon as the cell/module is run. With no
# data_dir argument, run_all() works in the current working directory.
# The result is kept in a module-level name for interactive inspection.
combined_filtered = run_all()
