# In [None]:
import os
import pandas as pd
import numpy as np
import scanpy as sc
import anndata as ad
import gzip
import urllib.request
from scipy import sparse
from scipy.io import mmread
import re

def download_files(accession, output_dir):
    """Download the GSE220974 supplementary files that are not already present.

    Args:
        accession: GEO series accession (e.g. ``"GSE220974"``). It is used to
            build the FTP directory URL; the file names themselves are the
            fixed GSE220974 supplementary files listed below.
        output_dir: Directory the files are written to; created if missing.

    Raises:
        urllib.error.URLError (or a subclass): if a download fails. No file is
        left under its final name in that case.
    """
    os.makedirs(output_dir, exist_ok=True)

    # List of files to download
    files = [
        "GSE220974_D1_protospacer_calls_per_cell.csv.gz",
        "GSE220974_D2_protospacer_calls_per_cell.csv.gz",
        "GSE220974_D3_protospacer_calls_per_cell.csv.gz",
        "GSE220974_K562_cell_metadata.csv.gz",
        "GSE220974_RNA_barcodes.tsv.gz",
        "GSE220974_RNA_features.tsv.gz",
        "GSE220974_RNA_matrix.mtx.gz",
        "GSE220974_S1Sa_protospacer_calls_per_cell.csv.gz",
        "GSE220974_S2Sp_protospacer_calls_per_cell.csv.gz",
        "GSE220974_S3SaSp_protospacer_calls_per_cell.csv.gz",
        "GSE220974_features.csv.gz",
        "GSE220974_gRNA_barcodes.tsv.gz",
        "GSE220974_gRNA_features.tsv.gz",
        "GSE220974_gRNA_matrix.mtx.gz",
    ]

    # GEO groups series into directories named after the accession with its
    # last three digits replaced by "nnn" (GSE220974 -> GSE220nnn). Derive it
    # from the accession instead of hard-coding it.
    prefix, digits = accession[:3], accession[3:]
    series_dir = f"{prefix}{digits[:-3]}nnn"
    base_url = f"https://ftp.ncbi.nlm.nih.gov/geo/series/{series_dir}/{accession}/suppl/"

    for file in files:
        file_path = os.path.join(output_dir, file)
        if os.path.exists(file_path):
            print(f"File {file} already exists, skipping download.")
            continue

        print(f"Downloading {file}...")
        # Download to a temporary name and rename only on success, so an
        # interrupted transfer is never mistaken for a complete file by the
        # existence check on the next run.
        partial_path = file_path + ".part"
        try:
            urllib.request.urlretrieve(base_url + file, partial_path)
            os.replace(partial_path, file_path)
        finally:
            if os.path.exists(partial_path):
                os.remove(partial_path)

    print("All files downloaded or already exist.")

def load_expression_data(data_dir):
    """Load the RNA count matrix, features and barcodes into an AnnData object.

    Reads ``GSE220974_RNA_matrix.mtx.gz``, ``GSE220974_RNA_features.tsv.gz``
    and ``GSE220974_RNA_barcodes.tsv.gz`` from ``data_dir``.

    Args:
        data_dir: Directory containing the three files.

    Returns:
        AnnData whose observations are indexed by the first barcode column and
        whose variables are indexed by the first feature column; the second
        feature column is stored as ``var['gene_symbol']``.
    """
    print("Loading gene expression data...")

    matrix_file = os.path.join(data_dir, "GSE220974_RNA_matrix.mtx.gz")
    features_file = os.path.join(data_dir, "GSE220974_RNA_features.tsv.gz")
    barcodes_file = os.path.join(data_dir, "GSE220974_RNA_barcodes.tsv.gz")

    # Read the matrix in sparse format; the context manager makes sure the
    # gzip handle is closed once parsing is done.
    with gzip.open(matrix_file, 'rb') as matrix_handle:
        matrix = mmread(matrix_handle)

    # Read features and barcodes (no header row in either file)
    features = pd.read_csv(features_file, sep='\t', header=None)
    barcodes = pd.read_csv(barcodes_file, sep='\t', header=None)

    # With header=None the columns are named 0, 1, ...; building an index from
    # column 0 would carry the integer name 0 along, so clear it on both axes.
    obs_df = pd.DataFrame(index=barcodes[0])
    obs_df.index.name = None
    var_df = pd.DataFrame(index=features[0])
    var_df.index.name = None

    # The matrix is transposed so rows line up with barcodes and columns with
    # features (assumes the file is stored features x barcodes -- TODO confirm).
    adata = ad.AnnData(X=matrix.T.tocsr(), obs=obs_df, var=var_df)

    # Keep the second feature column alongside the index as gene symbols
    adata.var['gene_symbol'] = features[1].values

    print(f"Expression data loaded: {adata.shape[0]} cells, {adata.shape[1]} genes")
    return adata

def load_metadata(data_dir):
    """Read the K562 cell metadata CSV, indexed by its first column.

    Args:
        data_dir: Directory containing ``GSE220974_K562_cell_metadata.csv.gz``.

    Returns:
        The metadata table as a pandas DataFrame.
    """
    print("Loading metadata...")

    cell_table = pd.read_csv(
        os.path.join(data_dir, "GSE220974_K562_cell_metadata.csv.gz"),
        index_col=0,
    )

    n_cells = cell_table.shape[0]
    print(f"Metadata loaded for {n_cells} cells")
    return cell_table

def load_guide_info(data_dir):
    """Read the guide RNA feature table.

    Args:
        data_dir: Directory containing ``GSE220974_features.csv.gz``.

    Returns:
        The table as a pandas DataFrame with a default integer index.
    """
    print("Loading guide RNA information...")

    guide_table = pd.read_csv(os.path.join(data_dir, "GSE220974_features.csv.gz"))
    n_guides = guide_table.shape[0]

    print(f"Guide RNA information loaded for {n_guides} guides")
    return guide_table

def load_protospacer_calls(data_dir):
    """Read the six per-cell protospacer call tables and stack them.

    Args:
        data_dir: Directory containing the
            ``GSE220974_*_protospacer_calls_per_cell.csv.gz`` files.

    Returns:
        One DataFrame with the rows of all six files, in file order, under a
        fresh 0..n-1 index.
    """
    print("Loading protospacer calls...")

    call_tables = [
        pd.read_csv(os.path.join(data_dir, name))
        for name in (
            "GSE220974_D1_protospacer_calls_per_cell.csv.gz",
            "GSE220974_D2_protospacer_calls_per_cell.csv.gz",
            "GSE220974_D3_protospacer_calls_per_cell.csv.gz",
            "GSE220974_S1Sa_protospacer_calls_per_cell.csv.gz",
            "GSE220974_S2Sp_protospacer_calls_per_cell.csv.gz",
            "GSE220974_S3SaSp_protospacer_calls_per_cell.csv.gz",
        )
    ]

    # Stack every table, discarding the per-file row numbers
    stacked = pd.concat(call_tables, ignore_index=True)
    n_unique_cells = stacked['cell_barcode'].nunique()
    print(f"Protospacer calls loaded for {n_unique_cells} cells")
    return stacked

def determine_perturbation_info(metadata, guide_info, protospacer_calls):
    """Derive per-cell perturbation labels from the ``guide_group`` column.

    Each ``guide_group`` value is a ``|``-separated list of guide names. Guides
    whose name starts with ``NTC`` are treated as non-targeting. For the other
    guides the gene is the part of the name before the first ``-a<digit>`` /
    ``-i<digit>`` suffix, and the suffix letter selects CRISPRa or CRISPRi.

    Args:
        metadata: DataFrame indexed by cell id with a ``guide_group`` column.
        guide_info: Accepted for interface compatibility; not used.
        protospacer_calls: Accepted for interface compatibility; not used.

    Returns:
        DataFrame indexed by cell id with columns ``perturbation_name``
        (sorted unique genes joined by ``" + "``, or ``"Non-Targeting"``),
        ``crispr_type`` (``CRISPRa``, ``CRISPRi``, ``CRISPRai`` or
        ``unknown``) and ``condition`` (``test`` or ``control``). The columns
        are present even when ``metadata`` has no rows.
    """
    print("Processing perturbation information...")

    columns = ['perturbation_name', 'crispr_type', 'condition']
    control_entry = {
        'perturbation_name': 'Non-Targeting',
        'crispr_type': 'unknown',
        'condition': 'control'
    }
    type_by_letter = {'a': 'CRISPRa', 'i': 'CRISPRi'}

    # The gene capture is lazy and accepts any characters, so symbols that
    # contain '-' or '.' are kept whole instead of failing to match and
    # silently turning a targeted cell into a control. For plain alphanumeric
    # symbols it captures exactly what a strict [A-Za-z0-9]+ would.
    guide_pattern = re.compile(r'(.+?)-([ai])[0-9]')

    perturbation_info = {}

    # Only one column is needed, so iterate it directly rather than building
    # a Series per row with iterrows().
    for cell_id, guide_group in metadata['guide_group'].items():
        if pd.isna(guide_group):
            perturbation_info[cell_id] = dict(control_entry)
            continue

        # Strip stray whitespace around separators and ignore empty pieces
        guides = [g.strip() for g in guide_group.split('|')]
        targeting_guides = [g for g in guides if g and not g.startswith('NTC')]

        if not targeting_guides:
            perturbation_info[cell_id] = dict(control_entry)
            continue

        genes = set()
        crispr_types = set()
        for guide in targeting_guides:
            parsed = guide_pattern.match(guide)
            if parsed:
                genes.add(parsed.group(1))
                crispr_types.add(type_by_letter[parsed.group(2)])
            # Names without the full "-a<digit>"/"-i<digit>" suffix still get
            # a type from the bare marker, but contribute no gene.
            elif '-a' in guide:
                crispr_types.add('CRISPRa')
            elif '-i' in guide:
                crispr_types.add('CRISPRi')

        perturbation_name = ' + '.join(sorted(genes)) if genes else 'Non-Targeting'

        if len(crispr_types) == 2:
            crispr_type = 'CRISPRai'
        elif crispr_types:
            crispr_type = next(iter(crispr_types))
        else:
            crispr_type = 'unknown'

        perturbation_info[cell_id] = {
            'perturbation_name': perturbation_name,
            'crispr_type': crispr_type,
            'condition': 'test' if genes else 'control'
        }

    # Passing the column list keeps the three columns present for an empty input
    perturbation_df = pd.DataFrame.from_dict(
        perturbation_info, orient='index', columns=columns
    )
    print(f"Perturbation information processed for {perturbation_df.shape[0]} cells")
    return perturbation_df

def harmonize_dataset(data_dir):
    """Build the harmonized AnnData for GSE220974 and write it as ``.h5ad``.

    Loads the expression matrix, cell metadata, guide table and protospacer
    calls from ``data_dir``, derives per-cell perturbation labels, keeps only
    the cells that appear in both the expression data and the metadata, and
    replaces ``adata.obs`` with a table holding fixed descriptive fields,
    the perturbation labels, and every remaining metadata column.

    Args:
        data_dir: Directory containing the downloaded GSE220974 files. The
            result is written to ``GSE220974_harmonized.h5ad`` in the same
            directory.

    Returns:
        The harmonized AnnData object.
    """
    print(f"Harmonizing dataset from {data_dir}...")

    # Load data
    adata = load_expression_data(data_dir)
    metadata = load_metadata(data_dir)
    guide_info = load_guide_info(data_dir)
    protospacer_calls = load_protospacer_calls(data_dir)

    # Process perturbation information
    perturbation_df = determine_perturbation_info(metadata, guide_info, protospacer_calls)

    # Merge metadata with perturbation information
    merged_metadata = metadata.join(perturbation_df, how='left')

    # Cells without a perturbation row are treated as non-targeting controls
    fill_values = {
        'perturbation_name': 'Non-Targeting',
        'crispr_type': 'unknown',
        'condition': 'control',
    }
    for column, default in fill_values.items():
        merged_metadata[column] = merged_metadata[column].fillna(default)

    # Map cell barcodes from adata to metadata
    common_cells = set(adata.obs.index) & set(merged_metadata.index)
    print(f"Found {len(common_cells)} common cells between expression data and metadata")

    # Filter adata to keep only cells with metadata
    adata = adata[adata.obs.index.isin(common_cells)].copy()

    # Align the metadata rows to the order of the remaining cells in one step.
    # Every remaining cell is present in merged_metadata (guaranteed by the
    # filter above), so no row is missing after the reindex.
    aligned_metadata = merged_metadata.reindex(adata.obs.index)

    # Create standardized metadata
    standardized_metadata = pd.DataFrame(index=adata.obs.index)
    standardized_metadata.index.name = None  # Ensure index name is None to avoid errors

    # Add standardized metadata fields
    standardized_metadata['organism'] = 'Homo sapiens'
    standardized_metadata['cell_type'] = 'K562'
    standardized_metadata['cancer_type'] = 'Leukemia'

    # Copy whole columns rather than one scalar per cell: this is far faster
    # and keeps each column's dtype. Values are assigned positionally
    # (to_numpy) because the two frames are already in the same row order.
    for column in fill_values:
        standardized_metadata[column] = aligned_metadata[column].to_numpy()

    # Add original metadata as additional columns, skipping names already used
    for column in aligned_metadata.columns:
        if column not in standardized_metadata.columns:
            standardized_metadata[column] = aligned_metadata[column].to_numpy()

    # Update adata.obs with standardized metadata
    adata.obs = standardized_metadata

    # Save the harmonized dataset
    output_file = os.path.join(data_dir, "GSE220974_harmonized.h5ad")
    print(f"Saving harmonized dataset to {output_file}")
    adata.write(output_file)

    print("Harmonization complete!")
    return adata

def main(data_dir=None):
    """Download GSE220974 if needed, harmonize it, and return the AnnData.

    Args:
        data_dir: Working directory for the raw files and the output. When
            omitted, a ``GSE220974`` folder under the current working
            directory is used.

    Returns:
        The AnnData object produced by ``harmonize_dataset``.
    """
    # Fall back to <cwd>/GSE220974 when the caller gives no directory
    if data_dir is not None:
        work_dir = data_dir
    else:
        work_dir = os.path.join(os.getcwd(), "GSE220974")

    # Make sure the directory exists before anything is written into it
    os.makedirs(work_dir, exist_ok=True)

    # Fetch whatever raw files are still missing, then build the dataset
    download_files("GSE220974", work_dir)
    harmonized = harmonize_dataset(work_dir)

    n_cells, n_genes = harmonized.shape[0], harmonized.shape[1]
    print(f"Final dataset shape: {n_cells} cells, {n_genes} genes")
    print(f"Metadata fields: {list(harmonized.obs.columns)}")
    return harmonized

# Run the whole pipeline when this notebook cell executes, using the default
# data directory; `adata` holds the object returned by main().
adata = main()
