# In [None]:
import os
import os
import re
import glob
import urllib.request
from typing import Dict, List, Tuple, Optional

import numpy as np
import pandas as pd
import anndata as ad
from scipy import io
from scipy import sparse

# Base URL for GSE272457 dataset (GEO "suppl" directory; the trailing slash
# lets a file name be appended directly).
GSE272457_BASE_URL = "https://ftp.ncbi.nlm.nih.gov/geo/series/GSE272nnn/GSE272457/suppl/"

# List of files to download: every sample contributes a triplet of
# barcodes / features / matrix files, listed sample by sample in that order.
GSE272457_FILES = [
    f"GSE272457_{sample}_{part}"
    for sample in (
        "293T_LRB100_NTlib1",
        "293T_LRB100_NTlib1-NIH3T3_LRB100_NTlib2_0hr_mix",
        "293T_LRB100_NTlib1-NIH3T3_LRB100_NTlib2_72hr_mix",
        "293T_MCH2_NTlib1",
        "293T_MCH2_NTlib1-NIH3T3_MCH2_NTlib2_0hr_mix",
        "293T_MCH2_NTlib1-NIH3T3_MCH2_NTlib2_72hr_mix",
        "NIH3T3_LRB100_NTlib2",
        "NIH3T3_MCH2_NTlib2",
    )
    for part in ("barcodes.tsv.gz", "features.tsv.gz", "matrix.mtx.gz")
]

def download_file(url: str, output_path: str) -> bool:
    """
    Download a file from a URL to a local path.

    The payload is first written to ``<output_path>.part``. It is renamed to
    ``output_path`` only after the transfer completes, so an interrupted
    download never leaves a truncated file that a later run would treat as
    "already exists".

    Args:
        url: Source URL (anything ``urllib.request.urlretrieve`` accepts).
        output_path: Destination file path; missing parent directories are
            created.

    Returns:
        True if the file already existed or was downloaded successfully.
        False if the download failed; the error is printed, not raised.
    """
    if os.path.exists(output_path):
        print(f"File already exists: {output_path}")
        return True

    tmp_path = f"{output_path}.part"
    try:
        print(f"Downloading {url} to {output_path}")
        parent_dir = os.path.dirname(output_path)
        # dirname() is '' for a bare file name, and os.makedirs('') raises.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        urllib.request.urlretrieve(url, tmp_path)
        os.replace(tmp_path, output_path)
        return True
    except Exception as e:
        # Best-effort by design: report the error and signal it via the return value.
        print(f"Error downloading {url}: {str(e)}")
        try:
            os.remove(tmp_path)
        except OSError:
            pass  # no partial file was written (or it cannot be removed)
        return False

def _contains_any(values: np.ndarray, patterns: Tuple[str, ...]) -> np.ndarray:
    """Return a bool mask: True where a string value contains any pattern (case-insensitive)."""
    lowered = [p.lower() for p in patterns]
    return np.array(
        [isinstance(v, str) and any(p in v.lower() for p in lowered) for v in values],
        dtype=bool,
    )


def _clean_feature_name(name):
    """
    Normalize a raw feature name to a plain gene symbol.

    The following transformations are applied in order:

    * A leading reference prefix ("GRCh38_", "mm10_", "hg38_", "hg19_" or
      "mm9_") is stripped.
    * If the name contains a "(...)" pair, only the text inside the first
      pair is kept.
    * The trailing version is dropped from Ensembl-style IDs
      ("ENSG00000141510.16" -> "ENSG00000141510").

    Other dotted names such as "AC000061.1" are left intact. Their dot is part
    of the symbol, and truncating it would collapse distinct genes onto one
    name. Non-string values are returned unchanged.
    """
    if not isinstance(name, str):
        return name
    for prefix in ("GRCh38_", "mm10_", "hg38_", "hg19_", "mm9_"):
        if name.startswith(prefix):
            name = name[len(prefix):]
            break
    symbol_match = re.search(r'\((.*?)\)', name)
    if symbol_match:
        name = symbol_match.group(1)
    version_match = re.fullmatch(r'(ENS[A-Z]*\d+)\.\d+', name)
    if version_match:
        name = version_match.group(1)
    return name


def load_10x_data(barcodes_file: str, features_file: str, matrix_file: str) -> Tuple[Optional[ad.AnnData], Optional[ad.AnnData]]:
    """
    Load 10x data from barcodes, features, and matrix files.

    The features are split into gene-expression and CRISPR-guide AnnData objects.

    How features are classified as CRISPR guides:

    * A feature is a guide when its type (features column 2) contains one of
      the CRISPR type patterns, case-insensitively.
    * If no feature matches by type, the feature names are searched for
      'guide' / 'grna' / 'sgrna' / 'crispr' instead.
    * Every other feature is treated as gene expression.

    Args:
        barcodes_file: Barcodes TSV; column 0 holds the cell barcodes.
        features_file: Features TSV. Column 0 is the feature ID and column 1
            the feature name. Column 2, if present, is the feature type;
            otherwise every feature is typed 'Gene Expression'.
        matrix_file: Matrix Market file. It is transposed so rows are cells
            (matching the barcodes) and columns are features.

    Returns:
        ``(adata_gene, adata_crispr)``; either is None when no features of
        that kind exist. Gene var_names are cleaned with
        ``_clean_feature_name`` and may contain duplicates. The raw names are
        kept in ``var['original_names']``.

    Raises:
        FileNotFoundError: If any of the three input files is missing.
    """
    for file_path in (barcodes_file, features_file, matrix_file):
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"File not found: {file_path}")

    # Read all columns as text and disable NA parsing. Otherwise identifiers
    # such as "NA" or "None" would silently become NaN.
    read_opts = dict(header=None, sep='\t', dtype=str, keep_default_na=False)
    barcodes = pd.read_csv(barcodes_file, **read_opts)[0].to_numpy()
    features = pd.read_csv(features_file, **read_opts)
    matrix = io.mmread(matrix_file).T.tocsr()

    feature_ids = features[0].to_numpy()
    feature_names = features[1].to_numpy()
    # Feature types are optional (e.g. two-column gene tables).
    if features.shape[1] > 2:
        feature_types = features[2].to_numpy()
    else:
        feature_types = np.full(len(feature_ids), 'Gene Expression', dtype=object)

    crispr_type_patterns = (
        'CRISPR Guide Capture', 'CRISPR Guide', 'Guide Capture',
        'gRNA', 'sgRNA', 'guide', 'CRISPR',
    )
    crispr_mask = _contains_any(feature_types, crispr_type_patterns)
    if not crispr_mask.any():
        # No guide-typed features: fall back to recognizing guides by name.
        crispr_mask = _contains_any(feature_names, ('guide', 'grna', 'sgrna', 'crispr'))

    gene_indices = np.flatnonzero(~crispr_mask)
    crispr_indices = np.flatnonzero(crispr_mask)

    adata_gene = None
    if gene_indices.size > 0:
        gene_names = feature_names[gene_indices]
        adata_gene = ad.AnnData(X=matrix[:, gene_indices])
        adata_gene.obs_names = barcodes
        adata_gene.var['original_names'] = gene_names
        adata_gene.var_names = [_clean_feature_name(n) for n in gene_names]
        adata_gene.var['gene_ids'] = feature_ids[gene_indices]
        adata_gene.var['feature_types'] = feature_types[gene_indices]

    adata_crispr = None
    if crispr_indices.size > 0:
        adata_crispr = ad.AnnData(X=matrix[:, crispr_indices])
        adata_crispr.obs_names = barcodes
        adata_crispr.var_names = feature_names[crispr_indices]
        adata_crispr.var['guide_ids'] = feature_ids[crispr_indices]
        adata_crispr.var['feature_types'] = feature_types[crispr_indices]

    return adata_gene, adata_crispr

def extract_metadata_from_filename(filename: str) -> Dict[str, object]:
    """
    Extract metadata from the filename for GSE272457 dataset.

    Only the basename of ``filename`` is inspected, so directory names cannot
    leak into the annotations. Recognized tokens:

    * '293T' / 'NIH3T3' set organism and cell type. If both are present, the
      sample is 'Mixed' and ``mix`` is True.
    * '<N>hr' sets the time point, e.g. '0hr' or '72hr'. A preceding digit is
      not allowed, so '10hr' is read as '10hr', not '0hr'.
    * 'NTlib1' / 'NTlib2' set perturbation_name to 'Non-targeting'.

    Args:
        filename: Dataset name or file path.

    Returns:
        Metadata dict. Every value is a string except 'mix', which is a bool.
        'dataset' is the basename without its final extension.
    """
    base_name = os.path.basename(filename)
    metadata = {
        'dataset': os.path.splitext(base_name)[0],
        'organism': 'Unknown',
        'cell_type': 'Unknown',
        'cancer_type': 'Non-Cancer',
        'crispr_type': 'CRISPRi',
        'condition': 'Control',
        'perturbation_name': 'Unknown',
        'time_point': 'Unknown',
        'mix': False
    }

    has_human = '293T' in base_name
    has_mouse = 'NIH3T3' in base_name
    if has_human and has_mouse:
        metadata['organism'] = 'Mixed'
        metadata['cell_type'] = 'Mixed'
        metadata['mix'] = True
    elif has_human:
        metadata['organism'] = 'Homo sapiens'
        metadata['cell_type'] = 'HEK293T'
    elif has_mouse:
        metadata['organism'] = 'Mus musculus'
        metadata['cell_type'] = 'NIH3T3'

    time_match = re.search(r'(?<!\d)(\d+)hr', base_name)
    if time_match:
        metadata['time_point'] = f"{time_match.group(1)}hr"

    if 'NTlib1' in base_name or 'NTlib2' in base_name:
        metadata['perturbation_name'] = 'Non-targeting'

    return metadata

def find_dataset_files(data_dir: str) -> List[Tuple[str, str, str, str]]:
    """
    Find all dataset files in the data directory.

    A dataset is a complete triplet ``<prefix>_barcodes.tsv.gz``,
    ``<prefix>_features.tsv.gz`` and ``<prefix>_matrix.mtx.gz``. Barcodes
    files without both companions are skipped.

    Args:
        data_dir: Directory to scan (not recursive). Glob metacharacters in
            the path, such as '[', are escaped.

    Returns:
        ``(prefix, barcodes_file, features_file, matrix_file)`` tuples sorted
        by path, so the processing order is reproducible across runs.
    """
    suffix = "_barcodes.tsv.gz"
    pattern = os.path.join(glob.escape(data_dir), f"*{suffix}")
    dataset_files = []
    for barcodes_file in sorted(glob.glob(pattern)):
        # Slice off the suffix rather than str.replace(), which could also
        # match inside the directory part of the path.
        prefix = barcodes_file[:-len(suffix)]
        features_file = f"{prefix}_features.tsv.gz"
        matrix_file = f"{prefix}_matrix.mtx.gz"
        if os.path.exists(features_file) and os.path.exists(matrix_file):
            dataset_files.append((prefix, barcodes_file, features_file, matrix_file))
    return dataset_files

def process_datasets(data_dir: str) -> List[Tuple[ad.AnnData, Optional[ad.AnnData], Dict[str, str]]]:
    """
    Load every complete 10x triplet found in ``data_dir`` and attach its metadata.

    Each dataset is handled independently. If loading one fails, its error is
    printed and it is left out of the result; the remaining datasets are
    still processed.

    Args:
        data_dir: Directory scanned by ``find_dataset_files``.

    Returns:
        One ``(adata_gene, adata_crispr, metadata)`` tuple per dataset that
        loaded, in the order ``find_dataset_files`` returned them.
    """
    results = []
    for entry in find_dataset_files(data_dir):
        prefix, barcodes_path, features_path, matrix_path = entry
        try:
            name = os.path.basename(prefix)
            print(f"Processing dataset: {name}")
            gene_data, guide_data = load_10x_data(barcodes_path, features_path, matrix_path)
            info = extract_metadata_from_filename(name)
            results.append((gene_data, guide_data, info))
        except Exception as err:
            print(f"Error processing dataset {prefix}: {str(err)}")
    return results

def _assign_guides(adata_gene: ad.AnnData, adata_crispr: ad.AnnData, metadata: Dict[str, str]) -> None:
    """Annotate ``adata_gene.obs`` with each cell's highest-count guide from ``adata_crispr``."""
    counts = adata_crispr.X
    # argmax/max work directly on sparse and dense matrices, so the guide
    # matrix is never densified.
    top_idx = np.asarray(counts.argmax(axis=1)).ravel()
    top_count = counts.max(axis=1)
    if sparse.issparse(top_count):
        top_count = top_count.toarray()
    top_count = np.asarray(top_count).ravel()
    # A row of zeros has argmax 0, which is not a detected guide; such cells
    # get 'Unknown' instead of the first guide in the library.
    has_guide = top_count > 0

    guide_names = np.asarray(adata_crispr.var_names, dtype=object)[top_idx]
    guide_names = np.where(has_guide, guide_names, 'Unknown')
    is_targeting = np.array(
        [bool(found) and 'non-targeting' not in str(g).lower()
         for g, found in zip(guide_names, has_guide)],
        dtype=bool,
    )
    adata_gene.obs['guide'] = guide_names
    adata_gene.obs['is_targeting'] = is_targeting
    # Fall back to guide-derived perturbations only when the file name did
    # not determine one.
    if metadata.get('perturbation_name', 'Unknown') == 'Unknown':
        adata_gene.obs['perturbation_name'] = np.where(
            is_targeting,
            guide_names,
            np.where(has_guide, 'Non-targeting', 'Unknown'),
        )


def harmonize_datasets(processed_datasets: List[Tuple[ad.AnnData, Optional[ad.AnnData], Dict[str, str]]]) -> ad.AnnData:
    """
    Harmonize multiple datasets into a single AnnData object.

    For every dataset that has gene-expression data:

    * Every metadata key is copied into ``obs`` as a constant column.
    * obs_names are prefixed with ``metadata['dataset']``, so cells from
      different datasets stay distinct.
    * The 'guide' and 'is_targeting' columns are added. They come from the
      highest-count guide per cell, or 'Unknown'/False without guide data.

    The datasets are then concatenated: inner join on genes when the datasets
    share any genes, otherwise outer join filled with 0.

    Note: the AnnData objects in ``processed_datasets`` are modified in place
    (names made unique and prefixed, obs columns added).

    Args:
        processed_datasets: ``(adata_gene, adata_crispr, metadata)`` tuples,
            e.g. as produced by ``process_datasets``.

    Returns:
        The concatenated AnnData. Missing required obs fields are filled with
        'Unknown'.

    Raises:
        ValueError: If ``processed_datasets`` is empty or none of the datasets
            has gene-expression data.
    """
    if not processed_datasets:
        raise ValueError("No datasets to harmonize")

    gene_datasets = []
    for adata_gene, adata_crispr, metadata in processed_datasets:
        if adata_gene is None:
            continue

        if not adata_gene.var_names.is_unique:
            print(f"Warning: Dataset {metadata['dataset']} has non-unique var_names. Making them unique...")
            adata_gene.var_names_make_unique()

        if not adata_gene.obs_names.is_unique:
            print(f"Warning: Dataset {metadata['dataset']} has non-unique obs_names. Making them unique...")
            adata_gene.obs_names_make_unique()

        for key, value in metadata.items():
            adata_gene.obs[key] = value

        adata_gene.obs_names = [f"{metadata['dataset']}_{bc}" for bc in adata_gene.obs_names]

        if adata_crispr is not None:
            if not adata_crispr.var_names.is_unique:
                print(f"Warning: CRISPR data for {metadata['dataset']} has non-unique var_names. Making them unique...")
                adata_crispr.var_names_make_unique()
            if not adata_crispr.obs_names.is_unique:
                print(f"Warning: CRISPR data for {metadata['dataset']} has non-unique obs_names. Making them unique...")
                adata_crispr.obs_names_make_unique()

            adata_crispr.obs_names = [f"{metadata['dataset']}_{bc}" for bc in adata_crispr.obs_names]
            _assign_guides(adata_gene, adata_crispr, metadata)
        else:
            adata_gene.obs['guide'] = 'Unknown'
            adata_gene.obs['is_targeting'] = False

        gene_datasets.append(adata_gene)

    if not gene_datasets:
        raise ValueError("No gene expression data found in any dataset")

    if len(gene_datasets) > 1:
        common_genes = set(gene_datasets[0].var_names)
        for adata in gene_datasets[1:]:
            common_genes &= set(adata.var_names)

        if common_genes:
            print(f"Found {len(common_genes)} common genes across all datasets")
            adata_harmonized = ad.concat(
                gene_datasets,
                join='inner',
                merge='same'
            )
        else:
            print("Warning: No common genes found across all datasets. Using outer join.")
            adata_harmonized = ad.concat(
                gene_datasets,
                join='outer',
                merge='same',
                fill_value=0
            )
    else:
        adata_harmonized = gene_datasets[0]

    required_fields = ['organism', 'cell_type', 'cancer_type', 'crispr_type', 'condition', 'perturbation_name']
    for field in required_fields:
        if field not in adata_harmonized.obs:
            adata_harmonized.obs[field] = 'Unknown'

    return adata_harmonized

def download_gse272457_dataset(output_dir: str) -> bool:
    """
    Fetch every file listed in ``GSE272457_FILES`` into ``output_dir``.

    Each file is requested even if an earlier one fails. Files already on
    disk are handled by ``download_file``.

    Args:
        output_dir: Target directory; created if it does not exist.

    Returns:
        True only if every file is present or was downloaded successfully.
    """
    os.makedirs(output_dir, exist_ok=True)
    # Build the full list first so that all downloads are attempted; all()
    # over a generator would stop at the first failure.
    outcomes = [
        download_file(GSE272457_BASE_URL + name, os.path.join(output_dir, name))
        for name in GSE272457_FILES
    ]
    return all(outcomes)

def run_pipeline(output_dir: str) -> ad.AnnData:
    """
    Run the entire pipeline to download, process, harmonize, and save the dataset.

    Raw files go to ``<output_dir>/GSE272457``. The harmonized object is
    written to ``<output_dir>/GSE272457/GSE272457_harmonized.h5ad``.

    Args:
        output_dir: Directory to save the harmonized dataset and downloads.

    Returns:
        The harmonized AnnData object (also written to disk).
    """
    dataset_dir = os.path.join(output_dir, 'GSE272457')
    # makedirs creates output_dir as an intermediate directory as well.
    os.makedirs(dataset_dir, exist_ok=True)

    missing = [
        name for name in GSE272457_FILES
        if not os.path.exists(os.path.join(dataset_dir, name))
    ]
    if missing:
        print("Downloading GSE272457 dataset files...")
        if not download_gse272457_dataset(dataset_dir):
            # Not fatal: datasets whose file triplet is complete can still be processed.
            print("Warning: some GSE272457 files could not be downloaded; "
                  "continuing with the files that are available.")

    print("Processing datasets...")
    processed_datasets = process_datasets(dataset_dir)

    print("Harmonizing datasets...")
    adata_harmonized = harmonize_datasets(processed_datasets)

    output_path = os.path.join(dataset_dir, 'GSE272457_harmonized.h5ad')
    adata_harmonized.write_h5ad(output_path)

    print("\nHarmonization complete!")
    print(f"Saved to: {output_path}")
    print(f"Number of cells: {adata_harmonized.n_obs}")
    print(f"Number of genes: {adata_harmonized.n_vars}")
    obs = adata_harmonized.obs
    for label, column in (
        ("Organisms", "organism"),
        ("Cell types", "cell_type"),
        ("CRISPR types", "crispr_type"),
        ("Perturbation types", "perturbation_name"),
    ):
        # str() keeps join() from failing on non-string values such as NaN.
        values = ', '.join(str(v) for v in obs[column].unique())
        print(f"{label}: {values}")

    return adata_harmonized

# Example call from a Jupyter cell. The guard keeps the download/processing
# pipeline from running when this file is imported as a module; in a notebook
# (or when run as a script) __name__ is "__main__", so the call still runs.
if __name__ == "__main__":
    run_pipeline('/content/GSE272457')



