In [None]:
import os
import gzip
import shutil
import tempfile
import urllib.request
from pathlib import Path
import pandas as pd
import numpy as np
from scipy import sparse
from scipy.io import mmread
import anndata

# Dataset constants.
# ACCESSION is the series identifier; BASE_URL is the folder that holds its
# supplementary files on the NCBI FTP server.
ACCESSION = "GSE255832"
BASE_URL = f"https://ftp.ncbi.nlm.nih.gov/geo/series/GSE255nnn/{ACCESSION}/suppl/"

# One entry per sample in the series.
SAMPLE_IDS = [
    "MMA198_1",
    "MMA198_2",
    "MMA200_1",
    "MMA202_1",
]

# File-name suffixes expected for every sample.
FILE_TYPES = [
    "barcodes.tsv.gz",
    "features.tsv.gz",
    "matrix.mtx.gz",
    "protospacer_calls_per_cell.csv.gz",
]

def ensure_directory(directory):
    """
    Make sure ``directory`` exists, creating missing parent folders as needed.

    Args:
        directory: Path of the directory to create.

    Returns:
        The ``directory`` argument, unchanged.
    """
    target = Path(directory)
    # exist_ok makes repeated calls on an existing directory a no-op.
    target.mkdir(exist_ok=True, parents=True)
    return directory

def download_file(url, destination):
    """
    Download a file from URL to destination if it doesn't exist.

    The payload is first written to ``<destination>.part`` and only renamed to
    ``destination`` once the transfer has finished. An interrupted or failed
    download therefore never leaves a truncated file at ``destination`` that a
    later run would mistake for a complete one.

    Args:
        url: URL to download from
        destination: Path to save the file to

    Returns:
        True if the file exists or was successfully downloaded, False otherwise.
    """
    if os.path.exists(destination):
        return True

    print(f"Downloading {url} to {destination}")
    partial_path = f"{destination}.part"
    try:
        urllib.request.urlretrieve(url, partial_path)
        # Rename within the same directory so the final file appears all at once.
        os.replace(partial_path, destination)
    except Exception as e:
        # Deliberately broad: this function reports any failure via its return
        # value instead of raising, and callers rely on that contract.
        print(f"Error downloading {url}: {e}")
        try:
            os.remove(partial_path)
        except OSError:
            # Nothing was written (or it is already gone); nothing to clean up.
            pass
        return False
    return True

def download_dataset(data_dir):
    """
    Fetch every expected dataset file into ``data_dir``, skipping files already present.

    Args:
        data_dir: Directory to save the files to (created if missing).

    Returns:
        True if all files exist or were successfully downloaded, False otherwise.
    """
    ensure_directory(data_dir)

    # One file name per (sample, file type) pair, samples in the outer loop.
    filenames = [
        f"{ACCESSION}_{sample_id}_{file_type}"
        for sample_id in SAMPLE_IDS
        for file_type in FILE_TYPES
    ]

    # Build the full list first so every file is attempted even after a failure.
    outcomes = [
        download_file(f"{BASE_URL}{name}", os.path.join(data_dir, name))
        for name in filenames
    ]
    return all(outcomes)

def _read_gzipped_tsv(path):
    """Read a headerless, gzip-compressed TSV file into a DataFrame of strings."""
    # dtype=str / keep_default_na=False keep identifiers such as "NA" or "nan"
    # as literal text instead of letting pandas turn them into missing values.
    return pd.read_csv(
        path,
        sep='\t',
        header=None,
        compression='gzip',
        dtype=str,
        keep_default_na=False,
    )


def load_10x_data(matrix_file, features_file, barcodes_file):
    """
    Load 10X Genomics data into an AnnData object.

    The gzip files are decompressed on the fly; no temporary copies are
    written to disk.

    Args:
        matrix_file: Path to the matrix.mtx.gz file.
        features_file: Path to the features.tsv.gz file.
        barcodes_file: Path to the barcodes.tsv.gz file.

    Returns:
        AnnData object with barcodes as ``obs`` and features as ``var``.
        ``var_names`` are the (uniquified) values of the second features
        column; ``var`` carries ``gene_symbols``, ``feature_types`` (when the
        features file has a third column) and ``gene_ids`` (the first
        features column).
    """
    # Read the mtx file straight from the gzip stream.
    with gzip.open(matrix_file, 'rb') as matrix_stream:
        # Transpose so rows line up with barcodes (obs) and columns with
        # features (var), as the AnnData constructor below requires.
        X = sparse.csr_matrix(mmread(matrix_stream).T)

    # Read gene information
    gene_info = _read_gzipped_tsv(features_file)

    # Read cell barcodes (first column only)
    barcodes = _read_gzipped_tsv(barcodes_file)[0].values

    # Create observation and variable dataframes
    gene_ids = gene_info[0].values
    obs = pd.DataFrame(index=barcodes)
    var = pd.DataFrame(index=gene_ids)  # Gene IDs as index

    # Add gene information
    var['gene_symbols'] = gene_info[1].values      # Gene symbols
    if gene_info.shape[1] > 2:
        var['feature_types'] = gene_info[2].values  # Feature types
    # The index is replaced by gene symbols below, so keep the original IDs
    # in a column; otherwise they would be lost.
    var['gene_ids'] = gene_ids

    # Create AnnData object
    adata = anndata.AnnData(X=X, obs=obs, var=var)

    # Use gene symbols as var_names; duplicates are disambiguated by suffix.
    adata.var_names = adata.var['gene_symbols'].values
    adata.var_names_make_unique()

    return adata

def _target_from_guide(guide):
    """Map a single guide name (e.g. 'Elovl1_gRNA1') to its target label."""
    # Convert NonTargetingControlGuideForMouse to Non-targeting
    if 'NonTargetingControlGuideForMouse' in guide:
        return 'Non-targeting'
    # Everything before the first underscore; the whole name if there is none.
    return guide.split('_')[0]


def _extract_target_gene(pert_string):
    """
    Extract target gene(s) from a perturbation string.

    Args:
        pert_string: Perturbation string (e.g., 'Elovl1_gRNA1'); several
            guides may be joined with '|'.

    Returns:
        Target gene(s) (e.g., 'Elovl1' or 'Elovl1+Tgds'), or 'Unknown' when
        the input is missing or already 'Unknown'.
    """
    if pd.isna(pert_string) or pert_string == 'Unknown':
        return 'Unknown'
    # A string without '|' splits into one element, so single and multiple
    # perturbations share this code path.
    return '+'.join(_target_from_guide(guide) for guide in pert_string.split('|'))


def load_perturbation_data(perturbation_file, adata):
    """
    Load perturbation information and add it to the AnnData object.

    Three columns are written to ``adata.obs`` (the object passed in is
    modified in place): ``perturbation`` (raw feature call, 'Unknown' for
    barcodes absent from the file), ``perturbation_name`` (target gene(s))
    and ``perturbation_type``.

    Cells with an "Unknown" perturbation_name are removed.
    Also, perturbations labeled as "NonTargetingControlGuideForMouse" are converted to "Non-targeting"
    and assigned a perturbation type of "Control".

    Args:
        perturbation_file: Path to the protospacer_calls_per_cell.csv.gz file.
            It must provide ``cell_barcode`` and ``feature_call`` columns.
        adata: AnnData object to add the perturbation information to.

    Returns:
        A filtered copy of ``adata`` without the "Unknown" cells.
    """
    # pandas decompresses the gzip itself; no temporary file is needed.
    perturbations = pd.read_csv(perturbation_file, compression='gzip')

    # Barcode -> feature call lookup (the last row wins for a repeated barcode).
    pert_dict = dict(zip(perturbations['cell_barcode'], perturbations['feature_call']))

    # Map perturbation info to each cell in adata
    adata.obs['perturbation'] = [
        pert_dict.get(barcode, 'Unknown') for barcode in adata.obs.index
    ]

    # Add target gene information
    adata.obs['perturbation_name'] = adata.obs['perturbation'].apply(_extract_target_gene)

    # Set perturbation type based on the extracted name:
    # If the perturbation is "Non-targeting", mark as "Control", else "targeting".
    adata.obs['perturbation_type'] = adata.obs['perturbation_name'].apply(
        lambda name: 'Control' if name == 'Non-targeting' else 'targeting'
    )

    # Exclude cells with an "Unknown" perturbation name.
    return adata[adata.obs['perturbation_name'] != 'Unknown'].copy()

def harmonize_data(adata, sample_id):
    """
    Attach the standard per-cell annotation columns to ``adata.obs``.

    Args:
        adata: AnnData object to harmonize; its ``obs`` must already contain
            a ``perturbation_type`` column.
        sample_id: Sample identifier.

    Returns:
        The same AnnData object, with the columns added in place.
    """
    # Columns holding one constant value for every cell, in insertion order.
    # The fixed values come from the study description.
    constant_columns = (
        ('sample_id', sample_id),
        ('organism', 'Mus musculus'),
        ('cell_type', 'CD8+ T cells'),
        ('crispr_type', 'CRISPR KO'),
        ('cancer_type', 'Pancreatic cancer'),
    )
    for column, value in constant_columns:
        adata.obs[column] = value

    # Treatment arm is decided by the sample ID; unlisted samples get 'Unknown'.
    arms = (
        ('IgG control', ['MMA198_1', 'MMA200_1']),
        ('anti-PD-1', ['MMA198_2', 'MMA202_1']),
    )
    adata.obs['condition'] = next(
        (label for label, members in arms if sample_id in members),
        'Unknown',
    )

    # Non-targeting cells are relabelled 'Control' regardless of treatment arm.
    is_control = adata.obs['perturbation_type'] == 'Control'
    adata.obs.loc[is_control, 'condition'] = 'Control'

    return adata

def process_sample(data_dir, sample_id):
    """
    Load, annotate and harmonize one sample of the dataset.

    Args:
        data_dir: Directory containing the data files.
        sample_id: Sample identifier.

    Returns:
        Processed AnnData object.

    Raises:
        FileNotFoundError: If any of the four expected input files is missing.
    """
    print(f"Processing sample {sample_id}...")

    # Input files, in the order they are checked for existence.
    suffixes = (
        "matrix.mtx.gz",
        "features.tsv.gz",
        "barcodes.tsv.gz",
        "protospacer_calls_per_cell.csv.gz",
    )
    paths = [
        os.path.join(data_dir, f"{ACCESSION}_{sample_id}_{suffix}")
        for suffix in suffixes
    ]
    matrix_file, features_file, barcodes_file, perturbation_file = paths

    # Fail fast on the first missing input.
    first_missing = next((p for p in paths if not os.path.exists(p)), None)
    if first_missing is not None:
        raise FileNotFoundError(f"File not found: {first_missing}")

    # Expression matrix -> perturbation labels (drops Unknown cells) -> metadata.
    expression = load_10x_data(matrix_file, features_file, barcodes_file)
    annotated = load_perturbation_data(perturbation_file, expression)
    return harmonize_data(annotated, sample_id)

def main_jupyter(data_dir=None):
    """
    Main function for running the pipeline in a Jupyter Notebook.

    Downloads the dataset if needed, processes every sample, concatenates the
    successfully processed ones and writes the result to
    ``<data_dir>/<ACCESSION>_harmonized.h5ad``.

    Args:
        data_dir: Directory holding the raw files and receiving the output.
            Defaults to ``./<ACCESSION>``.

    Returns:
        None. Progress and errors are reported with ``print``.
    """
    if data_dir is None:
        data_dir = os.path.join('.', ACCESSION)

    # Download dataset if needed
    if not download_dataset(data_dir):
        print("Error downloading dataset files. Please check your internet connection.")
        return

    # Process each sample
    adatas = []
    sample_keys = []
    for i, sample_id in enumerate(SAMPLE_IDS):
        try:
            adata = process_sample(data_dir, sample_id)
        except Exception as e:
            # Best effort: report the failure and carry on with the other samples.
            print(f"Error processing sample {sample_id}: {e}")
            continue
        # Overwrite sample_id with a numeric identifier for easier reference.
        # The number is the sample's position in SAMPLE_IDS.
        key = str(i)
        adata.obs['sample_id'] = key
        adatas.append(adata)
        sample_keys.append(key)

    if not adatas:
        print("No samples were successfully processed.")
        return

    # Combine all samples into a single AnnData object.
    # `label` makes concat write its own 'sample_id' column. Passing `keys`
    # keeps that column equal to the SAMPLE_IDS position even when an earlier
    # sample was skipped; the default keys would renumber the survivors.
    print("Combining samples...")
    combined_adata = anndata.concat(
        adatas, join='outer', label='sample_id', keys=sample_keys
    )

    # Ensure var_names are gene symbols
    # NOTE(review): concat is called without `merge`, so per-sample var columns
    # are presumably not carried over and this branch may never run - confirm.
    if 'gene_symbols' in combined_adata.var:
        combined_adata.var_names = combined_adata.var['gene_symbols']
        combined_adata.var_names_make_unique()

    # Make observation names unique (the same barcode can occur in several samples)
    combined_adata.obs_names_make_unique()

    # Save the combined data
    output_file = os.path.join(data_dir, f"{ACCESSION}_harmonized.h5ad")
    print(f"Saving harmonized data to {output_file}")
    combined_adata.write_h5ad(output_file)

    print("Processing complete!")
    print(f"Final dataset shape: {combined_adata.shape}")
    print(f"Number of genes: {combined_adata.n_vars}")
    print(f"Number of cells: {combined_adata.n_obs}")
    print(f"Harmonized data saved to: {output_file}")

# Run the pipeline in the Jupyter Notebook (or as a script).
# The guard keeps a plain `import` of this module from triggering downloads.
if __name__ == "__main__":
    main_jupyter()
