# In [None]:
import os
import sys
import glob
import tarfile
import urllib.request
import numpy as np
import pandas as pd
import scipy.io
import scanpy as sc
import anndata as ad

# Dataset constants
GEO_ACCESSION = "GSE243244"
# Download URL for the accession; built from GEO_ACCESSION so the two cannot drift apart.
GEO_URL = f"https://www.ncbi.nlm.nih.gov/geo/download/?acc={GEO_ACCESSION}&format=file"
# Values written verbatim into the 'organism' and 'crispr_type' obs columns.
ORGANISM = "Mus musculus"
CRISPR_TYPE = "CRISPR KO"

def download_dataset(data_dir):
    """Download the GEO archive if needed and unpack its per-sample archives.

    The top-level ``{GEO_ACCESSION}_RAW.tar`` is fetched from ``GEO_URL`` only
    when it is not already present in ``data_dir``. Members whose names start
    with ``GSM`` and end with ``.tar.gz`` are extracted from it only when no
    ``GSM*.tar.gz`` file already exists directly inside ``data_dir``.

    Args:
        data_dir: Directory that holds (or will hold) the downloaded archive
            and the extracted ``GSM*.tar.gz`` files. Created if missing.

    Returns:
        List of paths matching ``GSM*.tar.gz`` directly inside ``data_dir``
        (empty when the archive contains no such member).

    Raises:
        OSError: On download or filesystem failure (``urllib.error.URLError``
            is an ``OSError`` subclass).
        tarfile.TarError: If the archive is unreadable or a member is rejected
            by the extraction filter.
    """
    os.makedirs(data_dir, exist_ok=True)
    tar_path = os.path.join(data_dir, f"{GEO_ACCESSION}_RAW.tar")

    if not os.path.exists(tar_path):
        print(f"Downloading {GEO_ACCESSION} dataset...")
        # Download to a temporary name and rename only on success, so that an
        # interrupted transfer never leaves a truncated file at tar_path that
        # a later run would mistake for a complete download.
        partial_path = tar_path + ".part"
        try:
            urllib.request.urlretrieve(GEO_URL, partial_path)
            os.replace(partial_path, tar_path)
        finally:
            # After a successful rename the partial file no longer exists.
            if os.path.exists(partial_path):
                os.remove(partial_path)
        print(f"Downloaded to {tar_path}")
    else:
        print(f"Dataset already downloaded at {tar_path}")

    # Extract GSM files if needed
    gsm_pattern = os.path.join(data_dir, "GSM*.tar.gz")
    gsm_files = glob.glob(gsm_pattern)
    if not gsm_files:
        # The archive comes from the network: use the "data" extraction filter
        # where the running Python provides it, so members cannot escape
        # data_dir or create special files.
        extract_kwargs = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}
        with tarfile.open(tar_path) as tar:
            for member in tar.getmembers():
                if member.name.startswith("GSM") and member.name.endswith(".tar.gz"):
                    tar.extract(member, path=data_dir, **extract_kwargs)
                    print(f"Extracted {member.name}")
        gsm_files = glob.glob(gsm_pattern)

    return gsm_files

def extract_tar_gz(tar_gz_path, extract_dir):
    """Extract a tar archive into ``extract_dir`` and return that directory.

    The archive is treated as untrusted input: members that would be written
    outside ``extract_dir`` are rejected instead of being extracted.

    Args:
        tar_gz_path: Path of the archive to unpack (compression is detected
            by ``tarfile.open``).
        extract_dir: Destination directory. Created if missing.

    Returns:
        ``extract_dir``, unchanged.

    Raises:
        tarfile.TarError: If the archive is unreadable or contains a member
            that would land outside ``extract_dir``.
        OSError: On filesystem failure.
    """
    os.makedirs(extract_dir, exist_ok=True)
    with tarfile.open(tar_gz_path) as tar:
        if hasattr(tarfile, "data_filter"):
            # Extraction filters are available: "data" blocks absolute paths,
            # parent-directory escapes, and special files.
            tar.extractall(path=extract_dir, filter="data")
        else:
            # Older Python without extraction filters: best-effort check that
            # every member resolves inside the destination before extracting.
            root = os.path.realpath(extract_dir)
            for member in tar.getmembers():
                target = os.path.realpath(os.path.join(root, member.name))
                if os.path.commonpath([root, target]) != root:
                    raise tarfile.ExtractError(
                        f"Refusing to extract {member.name!r} outside {extract_dir}"
                    )
            tar.extractall(path=extract_dir)
    return extract_dir

def read_10x_mtx(path):
    """Read an uncompressed 10X-style matrix directory into an AnnData object.

    Expects ``matrix.mtx``, ``features.tsv`` and ``barcodes.tsv`` directly
    inside ``path``.

    Args:
        path: Directory containing the three files.

    Returns:
        AnnData whose rows (obs) are indexed by the first column of
        ``barcodes.tsv`` and whose columns (var) are indexed by the second
        column of ``features.tsv`` (the first column when the file has only
        one). Variable names are made unique.

    Raises:
        FileNotFoundError: If any of the three files is missing.
    """
    mtx_file = os.path.join(path, "matrix.mtx")
    features_file = os.path.join(path, "features.tsv")
    barcodes_file = os.path.join(path, "barcodes.tsv")

    # Transpose so that rows line up with barcodes (obs) and columns with
    # features (var), matching the AnnData constructed below.
    matrix = scipy.io.mmread(mtx_file).T.tocsr()

    # Read features; prefer the second column (gene symbols) and fall back to
    # the first when the file carries a single column, instead of raising
    # IndexError.
    features = pd.read_csv(features_file, sep='\t', header=None)
    name_column = 1 if features.shape[1] > 1 else 0
    var_names = features.iloc[:, name_column].values

    # Read barcodes
    barcodes = pd.read_csv(barcodes_file, sep='\t', header=None).iloc[:, 0].values

    adata = ad.AnnData(
        X=matrix,
        obs=pd.DataFrame(index=barcodes),
        var=pd.DataFrame(index=var_names),
    )

    # Gene symbols are not guaranteed to be unique; duplicated variable names
    # break label-based operations such as an outer-join concatenation.
    adata.var_names_make_unique()

    return adata

def process_perturbation_data(gdo_path, adata):
    """Annotate cells with the perturbations found in a guide (GDO) matrix.

    Both objects are restricted to the barcodes they share. Every guide with
    a count above zero in a cell contributes its target name (the text before
    the first '-', with anything starting with 'Random' collapsed to
    'Random'); the unique, sorted targets are joined with " + ". Cells with
    no guide counts are labelled "Non-targeting".

    Args:
        gdo_path: Directory with the guide matrix files, read by ``read_10x_mtx``.
        adata: Expression AnnData to annotate.

    Returns:
        A new AnnData (subset copy of ``adata``) with ``obs['perturbation_name']``,
        ``obs['condition']`` ("Control" for "Non-targeting", otherwise "Test")
        and the dense guide count matrix in ``obsm['X_perturbation']``.
    """
    guides = read_10x_mtx(gdo_path)

    # np.intersect1d yields the shared barcodes sorted and de-duplicated;
    # both objects are re-ordered to that common ordering.
    shared = np.intersect1d(adata.obs.index, guides.obs.index)
    adata = adata[shared].copy()
    guides = guides[shared].copy()

    # Dense cells-by-guides counts; also kept as the stored obsm matrix.
    counts = guides.X.toarray()

    def target_of(guide_name):
        # Drop the guide suffix after the first '-' (a no-op without one).
        target = guide_name.split('-')[0]
        return 'Random' if target.startswith('Random') else target

    labels = []
    for row in counts:
        detected = np.flatnonzero(row > 0)
        if detected.size == 0:
            labels.append("Non-targeting")
            continue
        targets = {target_of(guides.var_names[j]) for j in detected}
        labels.append(" + ".join(sorted(targets)))

    adata.obs['perturbation_name'] = labels
    is_control = adata.obs['perturbation_name'] == "Non-targeting"
    adata.obs['condition'] = np.where(is_control, "Control", "Test")

    adata.obsm['X_perturbation'] = counts

    return adata

def process_adt_data(adt_path, adata):
    """Attach ADT (protein) counts to ``adata`` as ``obsm['X_adt']``.

    Args:
        adt_path: Directory with the ADT matrix files, read by ``read_10x_mtx``.
        adata: AnnData to extend.

    Returns:
        A copy of ``adata`` restricted to the barcodes present in both
        objects (in the sorted order produced by ``np.intersect1d``), with
        the matching ADT matrix stored under ``obsm['X_adt']``.
    """
    protein = read_10x_mtx(adt_path)

    # Align the two modalities on the barcodes they have in common.
    shared = np.intersect1d(adata.obs.index, protein.obs.index)
    adata, protein = adata[shared].copy(), protein[shared].copy()

    adata.obsm['X_adt'] = protein.X
    return adata

def process_hto_data(hto_path, adata):
    """Attach HTO (cell hashing) counts to ``adata`` as ``obsm['X_hto']``.

    Args:
        hto_path: Directory with the HTO matrix files, read by ``read_10x_mtx``.
        adata: AnnData to extend.

    Returns:
        A copy of ``adata`` restricted to the barcodes present in both
        objects (in the sorted order produced by ``np.intersect1d``), with
        the matching HTO matrix stored under ``obsm['X_hto']``.
    """
    hashing = read_10x_mtx(hto_path)

    # Align the two modalities on the barcodes they have in common.
    shared = np.intersect1d(adata.obs.index, hashing.obs.index)
    adata, hashing = adata[shared].copy(), hashing[shared].copy()

    adata.obsm['X_hto'] = hashing.X
    return adata

def _find_modality_archive(data_dir, experiment_name, modality):
    """Return the first (sorted) ``*.tar.gz`` for a modality, or None if absent."""
    pattern = os.path.join(data_dir, f"*{experiment_name}*{modality}*.tar.gz")
    # NOTE(review): the pattern also matches experiments whose name merely
    # contains experiment_name; sorting at least makes the pick deterministic.
    matches = sorted(glob.glob(pattern))
    return matches[0] if matches else None


def _extract_matrix_dir(archive_path, data_dir, experiment_name, modality):
    """Unpack an archive and return the directory holding matrix.mtx, or None."""
    extract_dir = os.path.join(data_dir, f"extracted_{experiment_name}_{modality}")
    extract_tar_gz(archive_path, extract_dir)
    hits = sorted(glob.glob(os.path.join(extract_dir, "**/matrix.mtx"), recursive=True))
    return os.path.dirname(hits[0]) if hits else None


def _optional_matrix_dir(data_dir, experiment_name, modality):
    """Return the extracted matrix directory of an optional modality, or None."""
    archive = _find_modality_archive(data_dir, experiment_name, modality)
    if archive is None:
        return None
    return _extract_matrix_dir(archive, data_dir, experiment_name, modality)


def process_experiment(data_dir, experiment_name):
    """Process a single experiment with all its modalities.

    The cDNA archive is mandatory; GDO, ADT and HTO archives are used when
    present. Each archive is unpacked to
    ``{data_dir}/extracted_{experiment_name}_{modality}``.

    Args:
        data_dir: Directory containing the per-sample ``*.tar.gz`` archives.
        experiment_name: Name fragment used to find this experiment's archives.

    Returns:
        AnnData with expression data plus the obs columns
        ``perturbation_name``, ``condition``, ``organism``, ``crispr_type``,
        ``cancer_type``, ``cell_type``, ``experiment_type`` and ``cell_id``,
        or None when no cDNA archive or no cDNA ``matrix.mtx`` is found.
    """
    print(f"Processing experiment: {experiment_name}")

    cdna_archive = _find_modality_archive(data_dir, experiment_name, "cDNA")
    if cdna_archive is None:
        print(f"No cDNA file found for experiment {experiment_name}. Skipping.")
        return None

    cdna_matrix_dir = _extract_matrix_dir(cdna_archive, data_dir, experiment_name, "cDNA")
    if cdna_matrix_dir is None:
        print(f"No matrix.mtx file found in cDNA extraction for {experiment_name}.")
        return None

    # Read gene expression data
    adata = read_10x_mtx(cdna_matrix_dir)
    print(f"Read gene expression data: {adata.shape[0]} cells, {adata.shape[1]} genes")

    # Guide data: fall back to the default labels both when there is no GDO
    # archive and when the archive holds no matrix.mtx, so the perturbation
    # columns always exist.
    gdo_matrix_dir = _optional_matrix_dir(data_dir, experiment_name, "GDO")
    if gdo_matrix_dir is not None:
        adata = process_perturbation_data(gdo_matrix_dir, adata)
    else:
        adata.obs['perturbation_name'] = "Non-targeting"
        adata.obs['condition'] = "Control"

    # Process ADT data if available
    adt_matrix_dir = _optional_matrix_dir(data_dir, experiment_name, "ADT")
    if adt_matrix_dir is not None:
        adata = process_adt_data(adt_matrix_dir, adata)

    # Process HTO data if available
    hto_matrix_dir = _optional_matrix_dir(data_dir, experiment_name, "HTO")
    if hto_matrix_dir is not None:
        adata = process_hto_data(hto_matrix_dir, adata)

    # Add standardized metadata
    adata.obs['organism'] = ORGANISM
    adata.obs['crispr_type'] = CRISPR_TYPE
    adata.obs['cancer_type'] = "Non-Cancer"

    # Experiment-specific metadata; unknown experiments use the default cell type.
    cell_types = {"TCell": "T Cells", "Splenocytes": "Splenocytes"}
    adata.obs['cell_type'] = cell_types.get(experiment_name, "Bone marrow dendritic cells")
    if experiment_name in ("Invivo", "Invivo_B16"):
        adata.obs['cancer_type'] = "Melanoma"

    # Add experiment type
    adata.obs['experiment_type'] = experiment_name

    # Positional IDs, prefixed with the experiment so they stay unique when
    # several experiments are combined.
    adata.obs['cell_id'] = [f"{experiment_name}_{i}" for i in range(adata.n_obs)]

    return adata

def _experiment_names(gsm_files):
    """Derive the sorted, de-duplicated experiment names from GSM archive paths."""
    modalities = {'cDNA', 'GDO', 'ADT', 'HTO'}
    names = set()
    for gsm_file in gsm_files:
        parts = os.path.basename(gsm_file).split('_')
        if len(parts) < 2:
            continue
        name = parts[1].split('.')[0]
        if name in modalities:
            # The second field is a modality, so fall back to the first
            # field with its "GSM" prefix removed.
            name = parts[0].replace("GSM", "")
        names.add(name)
    # Sorted so that processing order, and therefore the cell order of the
    # combined output, is the same on every run.
    return sorted(names)


def main(data_dir=None):
    """Download, process and harmonize the dataset, then save it as .h5ad.

    Args:
        data_dir: Working directory for downloads, extraction and output.
            Defaults to the current working directory.

    Returns:
        Path of the written ``{GEO_ACCESSION}_harmonized.h5ad`` file, or None
        when no experiment could be processed.
    """
    if data_dir is None:
        # By default, use the current working directory
        data_dir = os.getcwd()

    # Download and extract data
    gsm_files = download_dataset(data_dir)

    experiments = _experiment_names(gsm_files)
    print(f"Found {len(experiments)} experiments: {experiments}")

    # Process each experiment, dropping those that could not be read.
    all_adatas = []
    for experiment in experiments:
        adata = process_experiment(data_dir, experiment)
        if adata is not None:
            all_adatas.append(adata)

    if not all_adatas:
        print("No experiments processed. Exiting.")
        return None

    # Use cell_id as index to avoid duplicate index issues
    for adata in all_adatas:
        adata.obs_names = adata.obs['cell_id']

    if len(all_adatas) > 1:
        combined_adata = ad.concat(all_adatas, join='outer', merge='same')
    else:
        combined_adata = all_adatas[0]

    print(f"Combined dataset: {combined_adata.shape[0]} cells, {combined_adata.shape[1]} genes")

    # Convert categorical columns
    categorical_columns = [
        'perturbation_name', 'condition', 'organism', 'cell_type',
        'crispr_type', 'cancer_type', 'experiment_type',
    ]
    for col in categorical_columns:
        if col in combined_adata.obs.columns:
            combined_adata.obs[col] = combined_adata.obs[col].astype('category')

    # Save harmonized data
    output_path = os.path.join(data_dir, f"{GEO_ACCESSION}_harmonized.h5ad")
    combined_adata.write(output_path)
    print(f"Harmonized data saved to {output_path}")

    return output_path

# Notebook entry point: executing this cell runs the whole pipeline.
# Keep data_dir as None to work in the current directory, or point it at a
# specific location, e.g. data_dir = "/path/to/data".
data_dir = None
output_file = main(data_dir=data_dir)
