# In [None]:
import os
import re
import gzip
import shutil
import urllib.request
import pandas as pd
import numpy as np
import anndata as ad
from scipy import sparse
from pathlib import Path
from tqdm import tqdm

# Constants
# GEO series accession; it is also the prefix of every supplementary file name.
ACCESSION = "GSE229505"
# Supplementary-file directory of the series on the NCBI GEO FTP server.
# NOTE(review): the "GSE229nnn" path segment is hard-coded and has to be kept
# in sync with ACCESSION by hand.
BASE_URL = f"https://ftp.ncbi.nlm.nih.gov/geo/series/GSE229nnn/{ACCESSION}/suppl/"
# (sample_id, description) pairs.
# - sample_id is the middle part of each downloaded file name
#   ("{ACCESSION}_{sample_id}_{file_type}").
# - description is parsed later for the condition ("interferon-gamma
#   stimulated"), the replicate and the channel, so its wording is significant.
SAMPLE_PAIRS = [
    # Pilot experiments
    ("BUT290A1_BUT290A2", "Pilot experiment, 2 sgRNAs"),
    ("BUT290A19_BUT290A20", "Pilot experiment, 24 sgRNAs"),
    # Main screen - unstimulated
    ("BUT290A21_BUT290A22", "Dual perturb-seq screen, unstimulated, replicate 1, channel 1"),
    ("BUT290A23_BUT290A24", "Dual perturb-seq screen, unstimulated, replicate 1, channel 2"),
    ("BUT290A25_BUT290A26", "Dual perturb-seq screen, unstimulated, replicate 2, channel 1"),
    ("BUT290A27_BUT290A28", "Dual perturb-seq screen, unstimulated, replicate 2, channel 2"),
    ("BUT290A37_BUT290A38", "Dual perturb-seq screen, unstimulated, replicate 3, channel 1"),
    ("BUT290A39_BUT290A40", "Dual perturb-seq screen, unstimulated, replicate 3, channel 2"),
    ("BUT290A41_BUT290A42", "Dual perturb-seq screen, unstimulated, replicate 3, channel 3"),
    ("BUT290A43_BUT290A44", "Dual perturb-seq screen, unstimulated, replicate 3, channel 4"),
    # Main screen - IFNγ stimulated
    ("BUT290A29_BUT290A30", "Dual perturb-seq screen, interferon-gamma stimulated, replicate 1, channel 1"),
    ("BUT290A31_BUT290A32", "Dual perturb-seq screen, interferon-gamma stimulated, replicate 1, channel 2"),
    ("BUT290A33_BUT290A34", "Dual perturb-seq screen, interferon-gamma stimulated, replicate 1, channel 3"),
    ("BUT290A35_BUT290A36", "Dual perturb-seq screen, interferon-gamma stimulated, replicate 1, channel 4"),
    ("BUT290A45_BUT290A46", "Dual perturb-seq screen, interferon-gamma stimulated, replicate 2, channel 1"),
    ("BUT290A47_BUT290A48", "Dual perturb-seq screen, interferon-gamma stimulated, replicate 2, channel 2"),
    ("BUT290A49_BUT290A50", "Dual perturb-seq screen, interferon-gamma stimulated, replicate 2, channel 3"),
    ("BUT290A51_BUT290A52", "Dual perturb-seq screen, interferon-gamma stimulated, replicate 2, channel 4"),
]

def download_file(url, dest_path, force=False):
    """Download ``url`` to ``dest_path`` unless the file is already present.

    The payload is streamed into a temporary ``<dest_path>.part`` file and
    moved into place only after the transfer finished. As a consequence:

    * an interrupted download never leaves a truncated ``dest_path`` behind
      that a later run would mistake for a complete file, and
    * a failed ``force=True`` re-download keeps the previously downloaded copy
      instead of deleting it.

    Args:
        url: Address to fetch; anything ``urllib.request.urlopen`` accepts.
        dest_path: Where the downloaded file should end up.
        force: Re-download even when ``dest_path`` already exists.

    Raises:
        Exception: Whatever the download or the file system raised; the error
            is printed and then re-raised unchanged.
    """
    if os.path.exists(dest_path) and not force:
        print(f"File already exists: {dest_path}")
        return

    # f-string (not "+") so that path-like objects work as well as strings.
    tmp_path = f"{dest_path}.part"
    try:
        with urllib.request.urlopen(url) as response, open(tmp_path, 'wb') as out_file:
            shutil.copyfileobj(response, out_file)
        # Same directory as the destination, so this is a plain rename.
        os.replace(tmp_path, dest_path)
    except Exception as e:
        print(f"Error downloading {url}: {e}")
        raise
    finally:
        # Also runs on KeyboardInterrupt; a no-op after a successful replace.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

def download_dataset(data_dir):
    """Fetch the barcodes, features and matrix file of every listed sample.

    Creates ``data_dir`` when needed and hands each
    ``{ACCESSION}_{sample_id}_{file_type}`` file below ``BASE_URL`` to
    ``download_file``, which skips files that are already present.

    Args:
        data_dir: Directory that receives the downloaded files.
    """
    os.makedirs(data_dir, exist_ok=True)

    file_types = ("barcodes.tsv.gz", "features.tsv.gz", "matrix.mtx.gz")
    progress = tqdm(SAMPLE_PAIRS, desc="Downloading samples")
    for sample_id, _description in progress:
        for file_type in file_types:
            remote_name = f"{ACCESSION}_{sample_id}_{file_type}"
            local_path = os.path.join(data_dir, remote_name)
            download_file(f"{BASE_URL}{remote_name}", local_path)

def read_10x_mtx(matrix_path):
    """Parse a gzip-compressed coordinate-format matrix file without scanpy.

    Leading lines that start with ``%`` are skipped; the first other line
    supplies the number of rows (genes) and columns (cells). Every following
    line with at least three whitespace-separated fields is read as
    ``row column count`` using 1-based indices and integer values; shorter
    lines are ignored.

    Args:
        matrix_path: Path to the ``.mtx.gz`` file.

    Returns:
        A ``scipy.sparse.csr_matrix`` of shape (genes, cells).
    """
    counts, gene_rows, cell_cols = [], [], []

    with gzip.open(matrix_path, 'rt') as handle:
        # Advance past the '%' comment header to the dimensions line.
        size_line = handle.readline()
        while size_line.startswith('%'):
            size_line = handle.readline()

        size_fields = size_line.strip().split()
        n_genes = int(size_fields[0])
        n_cells = int(size_fields[1])

        for record in handle:
            fields = record.strip().split()
            if len(fields) < 3:
                continue
            gene_idx, cell_idx, count = (int(token) for token in fields[:3])
            # The file is 1-based, scipy expects 0-based indices.
            gene_rows.append(gene_idx - 1)
            cell_cols.append(cell_idx - 1)
            counts.append(count)

    # Assemble the sparse matrix (genes x cells).
    return sparse.csr_matrix(
        (counts, (gene_rows, cell_cols)),
        shape=(n_genes, n_cells),
    )

def read_10x_data(data_dir, sample_id):
    """Read the barcodes, features and count matrix of one sample.

    The three files are expected in ``data_dir`` under the names
    ``{ACCESSION}_{sample_id}_{barcodes.tsv|features.tsv|matrix.mtx}.gz``.
    Features are split by their type column into "Gene Expression" and
    "CRISPR Guide Capture" rows.

    Args:
        data_dir: Directory containing the downloaded files.
        sample_id: Sample identifier used in the file names.

    Returns:
        A dict with the keys ``'barcodes'`` (list of str),
        ``'gene_features'`` / ``'crispr_features'`` (DataFrames with the
        columns ``id``, ``name``, ``type``) and ``'gene_mtx'`` /
        ``'crispr_mtx'`` (sparse matrices, features x cells).

    Raises:
        ValueError: If the number of features or barcodes does not match the
            shape of the matrix.
    """
    barcodes_path = os.path.join(data_dir, f"{ACCESSION}_{sample_id}_barcodes.tsv.gz")
    features_path = os.path.join(data_dir, f"{ACCESSION}_{sample_id}_features.tsv.gz")
    matrix_path = os.path.join(data_dir, f"{ACCESSION}_{sample_id}_matrix.mtx.gz")

    # Read barcodes; blank lines (e.g. a trailing newline) are not cells.
    with gzip.open(barcodes_path, 'rt') as f:
        barcodes = [barcode for barcode in (line.strip() for line in f) if barcode]

    # Read features. Every non-blank line must yield exactly one entry because
    # the position in this list is the row index into the matrix: dropping a
    # short line would shift all following features by one row.
    features = []
    with gzip.open(features_path, 'rt') as f:
        for line in f:
            stripped = line.strip()
            if not stripped:
                continue
            parts = stripped.split('\t')
            features.append({
                'id': parts[0],
                'name': parts[1] if len(parts) > 1 else parts[0],
                # Files without a type column are treated as gene expression.
                'type': parts[2] if len(parts) > 2 else "Gene Expression",
            })

    # Explicit columns keep the lookups below valid for an empty feature list.
    features_df = pd.DataFrame(features, columns=['id', 'name', 'type'])

    # Separate gene expression and CRISPR guide features
    gene_expr_idx = features_df['type'] == "Gene Expression"
    crispr_idx = features_df['type'] == "CRISPR Guide Capture"

    gene_features = features_df[gene_expr_idx]
    crispr_features = features_df[crispr_idx]

    # Convert boolean masks to integer row positions
    gene_expr_indices = np.where(gene_expr_idx)[0]
    crispr_indices = np.where(crispr_idx)[0]

    # Read matrix (features x cells)
    mtx = read_10x_mtx(matrix_path)

    # Fail loudly instead of silently attaching names to the wrong rows.
    if mtx.shape[0] != len(features_df):
        raise ValueError(
            f"{features_path} lists {len(features_df)} features but "
            f"{matrix_path} has {mtx.shape[0]} rows"
        )
    if mtx.shape[1] != len(barcodes):
        raise ValueError(
            f"{barcodes_path} lists {len(barcodes)} barcodes but "
            f"{matrix_path} has {mtx.shape[1]} columns"
        )

    # Create separate matrices for gene expression and CRISPR guides
    gene_mtx = mtx[gene_expr_indices, :]
    crispr_mtx = mtx[crispr_indices, :]

    return {
        'barcodes': barcodes,
        'gene_features': gene_features,
        'crispr_features': crispr_features,
        'gene_mtx': gene_mtx,
        'crispr_mtx': crispr_mtx
    }

# Recognises control guides on the lower-cased feature name:
#   * any "non-targeting" / "nontargeting" / "non_targeting" spelling,
#   * the word "control",
#   * "NT" or "NTC" as a stand-alone token, optionally prefixed by "sg" and
#     followed by a number ("NT", "NT_2", "sgNT1", "NTC-3").
# A bare substring test for "nt" would also hit real genes such as WNT5A,
# NT5E, CNTN1 or INTS1.
_NON_TARGETING_RE = re.compile(
    r'non[-_ ]?target'
    r'|control'
    r'|(?<![a-z0-9])(?:sg[-_]?)?ntc?[-_]?\d*(?![a-z0-9])'
)

# Tried in this order; group 1 is the candidate gene name.
_GUIDE_NAME_PATTERNS = tuple(
    re.compile(pattern)
    for pattern in (
        r'sg([A-Za-z0-9]+)',   # sgGeneA
        r'sg_([A-Za-z0-9]+)',  # sg_GeneA
        r'sg-([A-Za-z0-9]+)',  # sg-GeneA
        r'([A-Za-z0-9]+)_sg',  # GeneA_sg
        r'([A-Za-z0-9]+)-sg',  # GeneA-sg
    )
)


def extract_perturbation_name(feature_name):
    """Derive the perturbation (target gene) name from a guide feature name.

    Args:
        feature_name: Name of a CRISPR guide feature, e.g. ``"sgSTAT1"``,
            ``"STAT1_sg2"`` or ``"NT_3"``.

    Returns:
        ``"Non-targeting"`` for control guides; otherwise the gene name found
        by the ``sg`` patterns, else the text before the first underscore,
        else ``feature_name`` unchanged.
    """
    # Check for non-targeting controls
    if _NON_TARGETING_RE.search(feature_name.lower()):
        return "Non-targeting"

    # Try to extract the gene name using the regex patterns. A purely numeric
    # capture is a guide index ("GeneA_sg1" -> "1" for the first pattern), not
    # a gene, so it is only kept as a last resort.
    numeric_fallback = None
    for pattern in _GUIDE_NAME_PATTERNS:
        match = pattern.search(feature_name)
        if match is None:
            continue
        candidate = match.group(1)
        if not candidate.isdigit():
            return candidate
        if numeric_fallback is None:
            numeric_fallback = candidate
    if numeric_fallback is not None:
        return numeric_fallback

    # If no pattern matches, take the part before the first underscore
    if '_' in feature_name:
        return feature_name.split('_')[0]

    # If all else fails, return the original name
    return feature_name

# Label for cells whose perturbation cannot be determined.
_UNASSIGNED_PERTURBATION = "Unknown"

# Genome prefix of a feature name plus its underscore padding, e.g. "GRCh38_"
# or "TGA4___". Anchored so that only a leading prefix is removed.
_GENOME_PREFIX_RE = re.compile(r'^(?:GRCh38|TGA4)_+')


def _assign_perturbations(crispr_df):
    """Return one perturbation label per row of a cells x guides count table."""
    top_guide = crispr_df.idxmax(axis=1)
    top_count = crispr_df.max(axis=1)
    # idxmax returns the first column for an all-zero row, so the count has to
    # be checked. A cell without any guide read is NOT a non-targeting control
    # and must not be mixed into the control population.
    return [
        extract_perturbation_name(guide) if count > 0 else _UNASSIGNED_PERTURBATION
        for guide, count in zip(top_guide, top_count)
    ]


def _parse_sample_field(sample_name, keyword):
    """Return the text between `keyword` and the next comma, or None if absent."""
    _, found, tail = sample_name.partition(keyword)
    if not found:
        return None
    return tail.split(",")[0].strip()


def process_sample(data_dir, sample_id, sample_name):
    """Process a single sample and return an AnnData object.

    The gene-expression counts become ``X`` (cells x genes). The raw guide
    counts are stored in ``obsm['crispr_counts']`` and the guide with the
    highest count decides ``obs['perturbation_name']``; cells without any
    guide read are labelled ``"Unknown"``. Condition, replicate and channel
    are parsed from ``sample_name``, genes are annotated with their species
    and the genome prefix is removed from the gene names.

    NOTE(review): only the single top guide is used per cell although the
    sample descriptions say "Dual perturb-seq" - confirm whether the second
    guide should be reported as well.

    Args:
        data_dir: Directory containing the downloaded 10x files.
        sample_id: Sample identifier used in the file names.
        sample_name: Human-readable description of the sample.

    Returns:
        The AnnData object, or ``None`` when anything went wrong; the error is
        printed so that the remaining samples can still be processed.
    """
    print(f"Processing sample: {sample_name}")

    try:
        # Read 10x data
        data = read_10x_data(data_dir, sample_id)

        # Create AnnData object for gene expression
        adata = ad.AnnData(
            X=data['gene_mtx'].T,  # Transpose to cells x genes
            obs=pd.DataFrame(index=data['barcodes']),
            var=pd.DataFrame(index=data['gene_features']['name'].values)  # Use gene symbols
        )

        # Keep the gene IDs as an additional column
        adata.var['gene_id'] = data['gene_features']['id'].values

        # Add CRISPR guide information to obs
        if data['crispr_mtx'].shape[0] > 0:
            crispr_df = pd.DataFrame(
                data['crispr_mtx'].T.toarray(),  # Transpose to cells x guides
                index=data['barcodes'],
                columns=data['crispr_features']['name'].values
            )
            adata.obs['perturbation_name'] = _assign_perturbations(crispr_df)

            # Add raw guide counts to the object
            adata.obsm['crispr_counts'] = crispr_df
        else:
            adata.obs['perturbation_name'] = _UNASSIGNED_PERTURBATION

        # Add sample metadata
        adata.obs['sample_id'] = sample_id
        adata.obs['sample_name'] = sample_name

        # Parse condition from sample name
        if "interferon-gamma stimulated" in sample_name:
            adata.obs['condition'] = "IFNγ stimulated"
        else:
            adata.obs['condition'] = "Unstimulated"

        # Add replicate and channel information
        replicate = _parse_sample_field(sample_name, "replicate")
        if replicate is None:
            # Descriptions without a replicate are the pilot experiments.
            adata.obs['replicate'] = "pilot"
            adata.obs['channel'] = "pilot"
        else:
            channel = _parse_sample_field(sample_name, "channel")
            adata.obs['replicate'] = replicate
            # A description without a channel must not discard the sample.
            adata.obs['channel'] = channel if channel is not None else "Unknown"

        # Separate human and Toxoplasma genes by genome prefix. dtype=bool
        # keeps the masks valid even when there are no genes at all.
        is_human = np.array(
            [name.startswith('GRCh38_') for name in adata.var_names], dtype=bool
        )
        is_toxo = np.array(
            [name.startswith('TGA4_') for name in adata.var_names], dtype=bool
        )

        # Add species information
        adata.var['species'] = "Unknown"
        adata.var.loc[is_human, 'species'] = "Homo sapiens"
        adata.var.loc[is_toxo, 'species'] = "Toxoplasma gondii"

        # Clean up gene names: drop the leading genome prefix only, using the
        # same prefixes that were used for the species detection above.
        adata.var_names = pd.Index(
            [_GENOME_PREFIX_RE.sub('', name) for name in adata.var_names]
        )

        # Make variable names unique
        adata.var_names_make_unique()

        # Add standardized metadata
        adata.obs['organism'] = "Homo sapiens"  # Host cells are human
        adata.obs['cell_type'] = "Human foreskin fibroblast"  # HFFs as mentioned in the study
        adata.obs['crispr_type'] = "CRISPR KO"  # CRISPR knockout as mentioned in the study
        adata.obs['cancer_type'] = "Non-Cancer"  # HFFs are not cancer cells

        return adata

    except Exception as e:
        # Deliberate best effort: report the failure and let the caller skip
        # this sample.
        print(f"Error processing sample {sample_id}: {e}")
        return None

def harmonize_dataset(data_dir):
    """Process every sample in ``SAMPLE_PAIRS`` and concatenate the results.

    Samples for which ``process_sample`` returns ``None`` are skipped.

    Args:
        data_dir: Directory containing the downloaded 10x files.

    Returns:
        One AnnData object holding all successfully processed samples. The
        union of all genes is kept, ``obs['sample_id']`` holds the sample
        identifier and every observation name is suffixed with
        ``-<sample_id>``.

    Raises:
        ValueError: If no sample could be processed.
    """
    sample_ids = []
    all_samples = []

    for sample_id, sample_name in SAMPLE_PAIRS:
        adata = process_sample(data_dir, sample_id, sample_name)
        if adata is not None:
            sample_ids.append(sample_id)
            all_samples.append(adata)

    if not all_samples:
        raise ValueError("No samples were successfully processed")

    # Concatenate all samples
    print(f"Concatenating {len(all_samples)} samples...")
    combined = ad.concat(
        all_samples,
        join='outer',  # Keep all genes from all datasets
        merge='same',  # Only merge same fields
        label='sample_id',
        # Without explicit keys the label column is filled with "0", "1", ...
        # and overwrites the real sample_id set by process_sample.
        keys=sample_ids,
        index_unique='-'
    )

    # Convert categorical columns
    for col in ['organism', 'cell_type', 'crispr_type', 'cancer_type', 'condition']:
        combined.obs[col] = combined.obs[col].astype('category')

    return combined

def main(data_dir):
    """Download, harmonize and save the dataset, then print a short summary.

    The combined object is written to ``{ACCESSION}_harmonized.h5ad`` inside
    ``data_dir``.

    Args:
        data_dir: Working directory for the downloads and the output file; it
            is made absolute and created when missing.
    """
    data_dir = os.path.abspath(data_dir)
    os.makedirs(data_dir, exist_ok=True)

    print(f"Processing GSE229505 dataset in {data_dir}")

    # Fetch whatever is not on disk yet, then build the combined object.
    download_dataset(data_dir)
    adata = harmonize_dataset(data_dir)

    # Persist the result next to the raw files.
    output_file = os.path.join(data_dir, f"{ACCESSION}_harmonized.h5ad")
    print(f"Saving harmonized dataset to {output_file}")
    adata.write_h5ad(output_file)

    print("Processing complete!")
    n_cells, n_genes = adata.shape[0], adata.shape[1]
    print(f"Dataset shape: {n_cells} cells × {n_genes} genes")

    # Heading -> categorical obs column shown in the summary.
    summary_columns = (
        ("Organism", 'organism'),
        ("Cell types", 'cell_type'),
        ("CRISPR types", 'crispr_type'),
        ("Cancer types", 'cancer_type'),
        ("Conditions", 'condition'),
    )
    for heading, column in summary_columns:
        print(f"{heading}: {adata.obs[column].cat.categories}")

    n_perturbations = adata.obs['perturbation_name'].nunique()
    print(f"Number of unique perturbations: {n_perturbations}")

# For Jupyter, simply set the data directory and run main()
# NOTE: these statements sit at module level, so executing (or importing) this
# file immediately starts the download and the processing.
data_dir = "./data"  # Change this path as needed
main(data_dir)
