# In [None]:
import os
import glob
import gzip
import tarfile
import re
import urllib.request
from typing import Dict, List, Tuple, Set

import numpy as np
import pandas as pd
import scanpy as sc
from scipy import sparse
from anndata import AnnData

# Constants
# GEO series accession of the dataset this notebook downloads and processes;
# also used to name the output .h5ad files and the default data directory.
GEO_ACCESSION = "GSE253865"
# GEO download endpoint; download_data() saves its response to TAR_FILENAME.
DOWNLOAD_URL = f"https://www.ncbi.nlm.nih.gov/geo/download/?acc={GEO_ACCESSION}&format=file"
# Local file name of the downloaded archive (opened with tarfile in download_data()).
TAR_FILENAME = f"{GEO_ACCESSION}_RAW.tar"


def download_data(data_dir: str) -> str:
    """
    Download and extract the dataset archive if not already present.

    Extraction is skipped entirely when ``GSM*`` files already exist in
    ``data_dir``. The archive is downloaded to a temporary ``.part`` file and
    renamed into place only after the transfer completes, so an interrupted
    download is never mistaken for a finished one on the next run.
    Extraction rejects archive members that would land outside ``data_dir``.
    Extraction failures are reported and tolerated (best effort).

    Args:
        data_dir: Directory to save the data (created if missing).

    Returns:
        Path to the tar file.
    """
    os.makedirs(data_dir, exist_ok=True)
    tar_path = os.path.join(data_dir, TAR_FILENAME)

    # Check if files are already extracted
    if glob.glob(os.path.join(data_dir, "GSM*")):
        print("Files already extracted, skipping download and extraction.")
        return tar_path

    if not os.path.exists(tar_path):
        print(f"Downloading {GEO_ACCESSION} dataset...")
        part_path = tar_path + ".part"
        try:
            urllib.request.urlretrieve(DOWNLOAD_URL, part_path)
        except BaseException:
            # Don't leave a truncated archive behind; re-raise the original error.
            if os.path.exists(part_path):
                os.remove(part_path)
            raise
        os.replace(part_path, tar_path)
        print(f"Download complete: {tar_path}")
    else:
        print(f"Dataset already downloaded: {tar_path}")

    # No GSM* files exist at this point (see the early return above), so extract.
    print("Extracting tar file...")
    try:
        with tarfile.open(tar_path, 'r') as tar:
            if hasattr(tarfile, "data_filter"):
                # Python 3.12+ (and security backports): refuse absolute paths,
                # '..' traversal, and other unsafe members.
                tar.extractall(path=data_dir, filter="data")
            else:
                # Older Pythons: check every member stays inside data_dir ourselves.
                root = os.path.realpath(data_dir)
                for member in tar.getmembers():
                    target = os.path.realpath(os.path.join(root, member.name))
                    if os.path.commonpath([root, target]) != root:
                        raise tarfile.TarError(f"Unsafe path in archive: {member.name}")
                tar.extractall(path=data_dir)
        print("Extraction complete.")
    except (tarfile.TarError, OSError) as e:
        # Deliberately best effort: report and let the caller work with whatever exists.
        print(f"Error extracting tar file: {e}")
        print("Continuing with existing files...")

    return tar_path


def get_sample_files(data_dir: str) -> Dict[str, Dict[str, List[str]]]:
    """
    Group the ``GSM*`` files in ``data_dir`` by assay and by 10x file role.

    The assay is the first of ``_cell_``, ``_nucleus_``, ``_CITE_`` found in
    the file name (checked in that order). The role is the first of
    ``barcodes``, ``features``, ``matrix`` found as a substring (checked in
    that order). Files matching no assay or no role are skipped.

    Args:
        data_dir: Directory containing the data files.

    Returns:
        ``{assay: {role: sorted list of paths}}`` for the assays ``cell``,
        ``nucleus`` and ``CITE``.
    """
    assays = ("cell", "nucleus", "CITE")
    roles = ("barcodes", "features", "matrix")
    grouped = {assay: {role: [] for role in roles} for assay in assays}

    for path in glob.glob(os.path.join(data_dir, "GSM*")):
        base = os.path.basename(path)
        assay = next((a for a in assays if f"_{a}_" in base), None)
        if assay is None:
            continue
        role = next((r for r in roles if r in base), None)
        if role is not None:
            grouped[assay][role].append(path)

    # Sort in place so callers see a stable, reproducible order.
    for per_role in grouped.values():
        for paths in per_role.values():
            paths.sort()

    return grouped


def extract_sample_metadata(file_path: str) -> Dict[str, str]:
    """
    Extract sample metadata from a file path.

    The file name is expected to look like ``GSM<digits>_<type>_<id>_...``
    where ``<type>`` is one of ``cell``, ``nucleus``, ``CITE``. Any field that
    cannot be parsed is reported as ``"Unknown"``.

    Patient IDs are derived from the sample ID:
      * ``KDP<suffix>`` -> ``NB00<suffix>`` (``"Unknown"`` if there is no suffix)
      * ``CS...``       -> ``NB003``
      * ``NB0...`` / ``NBO...`` -> the sample ID itself

    Args:
        file_path: Path to the file.

    Returns:
        Dictionary with keys ``gsm_accession``, ``sample_id``,
        ``sample_type`` and ``patient_id``.
    """
    file_name = os.path.basename(file_path)

    # Extract GSM accession
    gsm_match = re.search(r'(GSM\d+)', file_name)
    gsm_accession = gsm_match.group(1) if gsm_match else "Unknown"

    # Extract sample ID (alphanumeric token right after the sample type)
    sample_id_match = re.search(r'_(cell|nucleus|CITE)_([A-Za-z0-9]+)', file_name)
    sample_id = sample_id_match.group(2) if sample_id_match else "Unknown"

    # Extract sample type
    sample_type_match = re.search(r'_(cell|nucleus|CITE)_', file_name)
    sample_type = sample_type_match.group(1) if sample_type_match else "Unknown"

    # Map sample ID to patient ID
    patient_id = "Unknown"
    if sample_id.startswith("KDP"):
        patient_id = f"NB00{sample_id[3:]}" if len(sample_id) > 3 else "Unknown"
    elif sample_id.startswith("CS"):
        patient_id = "NB003"
    elif sample_id.startswith(("NBO", "NB0")):
        # Patient IDs produced above use the digit zero ("NB00..."); the original
        # check only matched the letter O, so zero-based IDs fell through to
        # "Unknown". Accept both spellings.
        patient_id = sample_id

    return {
        "gsm_accession": gsm_accession,
        "sample_id": sample_id,
        "sample_type": sample_type,
        "patient_id": patient_id
    }


def load_10x_data(matrix_file: str, features_file: str, barcodes_file: str) -> Tuple[sparse.csr_matrix, pd.DataFrame, pd.DataFrame]:
    """
    Load 10x data from matrix, features, and barcodes files.

    Args:
        matrix_file: Path to matrix.mtx.gz file.
        features_file: Path to features.tsv.gz file (2 columns, or 3+ columns
            whose first three are id, name, feature type).
        barcodes_file: Path to barcodes.tsv.gz file (barcode in column 1).

    Returns:
        Tuple of (matrix, features, barcodes):
          * matrix: CSR matrix of shape (n_barcodes, n_features).
          * features: columns ``id``, ``name``, ``feature_type``. Two-column
            files get ``feature_type = 'Gene Expression'``.
          * barcodes: column ``barcode``, the metadata from
            ``extract_sample_metadata(matrix_file)``, and ``full_barcode``
            (``<barcode>_<sample_id>``).

    Raises:
        ValueError: If the features file has fewer than 2 columns, or the
            matrix shape does not match the barcode/feature counts.
    """
    # The .mtx is transposed so rows are barcodes and columns are features
    # (validated against the TSV lengths below).
    matrix = sparse.csr_matrix(sc.read_mtx(matrix_file).X.T)

    # Read every field as literal text. pandas' default NA handling would turn
    # tokens such as "NA" or "nan" into NaN, corrupting feature/barcode names.
    read_opts = dict(sep='\t', header=None, dtype=str, keep_default_na=False)

    # Load features
    with gzip.open(features_file, 'rt') as f:
        features_df = pd.read_csv(f, **read_opts)
    n_cols = features_df.shape[1]
    if n_cols >= 3:
        features_df = features_df.iloc[:, :3].copy()
        features_df.columns = ['id', 'name', 'feature_type']
    elif n_cols == 2:
        features_df.columns = ['id', 'name']
        features_df['feature_type'] = 'Gene Expression'
    else:
        raise ValueError(
            f"{os.path.basename(features_file)}: expected at least 2 columns, found {n_cols}"
        )

    # Load barcodes (only the first column is used)
    with gzip.open(barcodes_file, 'rt') as f:
        barcodes_df = pd.read_csv(f, **read_opts)
    barcodes_df = barcodes_df.iloc[:, [0]].copy()
    barcodes_df.columns = ['barcode']

    # Catch mismatched file triples early instead of silently misaligning cells/features.
    if matrix.shape != (len(barcodes_df), len(features_df)):
        raise ValueError(
            f"Shape mismatch for {os.path.basename(matrix_file)}: matrix is {matrix.shape}, "
            f"but found {len(barcodes_df)} barcodes and {len(features_df)} features"
        )

    # Extract sample metadata
    sample_metadata = extract_sample_metadata(matrix_file)

    # Add sample metadata to barcodes
    for key, value in sample_metadata.items():
        barcodes_df[key] = value

    # Create full barcode with sample ID so barcodes stay unique across samples
    barcodes_df['full_barcode'] = barcodes_df['barcode'] + '_' + barcodes_df['sample_id']

    return matrix, features_df, barcodes_df


def process_cite_seq_data(data_dir: str, sample_files: Dict[str, Dict[str, List[str]]]) -> Tuple[AnnData, AnnData]:
    """
    Process CITE-seq data to create gene expression and protein expression AnnData objects.

    Each sample's features are split by ``feature_type`` into
    'Gene Expression' and 'Antibody Capture'. Only feature names present in
    every sample are kept. Both outputs are restricted to the barcodes they
    share, in the order of the concatenated barcode tables, so results are
    reproducible across runs.

    Args:
        data_dir: Directory containing the data files. Not used by this
            function; kept for interface compatibility.
        sample_files: Dictionary of sample files organized by sample type and
            file type. The ``"CITE"`` matrix/features/barcodes lists are
            paired positionally.

    Returns:
        Tuple of (gene_expression_adata, protein_expression_adata), as
        independent (non-view) AnnData objects.

    Raises:
        ValueError: If no CITE-seq matrix files are listed, the three file
            lists differ in length, or a paired matrix/features/barcodes
            triple does not share a single GSM accession.
    """
    # Get CITE-seq files
    barcodes_files = sample_files["CITE"]["barcodes"]
    features_files = sample_files["CITE"]["features"]
    matrix_files = sample_files["CITE"]["matrix"]

    # The lists are paired by position. A missing file would silently shift
    # later samples onto the wrong files (and zip() would truncate), so
    # validate the pairing up front.
    if not matrix_files:
        raise ValueError("No CITE-seq matrix files found.")
    if not (len(matrix_files) == len(features_files) == len(barcodes_files)):
        raise ValueError(
            f"Unequal CITE-seq file counts: {len(matrix_files)} matrix, "
            f"{len(features_files)} features, {len(barcodes_files)} barcodes files."
        )
    for trio in zip(matrix_files, features_files, barcodes_files):
        accessions = {extract_sample_metadata(path)["gsm_accession"] for path in trio}
        if len(accessions) != 1:
            raise ValueError(
                f"Misaligned CITE-seq files {[os.path.basename(p) for p in trio]}: "
                f"expected one GSM accession, found {sorted(accessions)}"
            )

    # Initialize lists to store data
    gene_matrices = []
    protein_matrices = []
    gene_features_list = []
    protein_features_list = []
    barcodes_list = []

    # Process each CITE-seq sample
    for i, (matrix_file, features_file, barcodes_file) in enumerate(zip(matrix_files, features_files, barcodes_files)):
        print(f"Processing CITE-seq sample {i+1}/{len(matrix_files)}: {os.path.basename(matrix_file)}")

        # Load data
        matrix, features_df, barcodes_df = load_10x_data(matrix_file, features_file, barcodes_file)

        # Split gene expression and protein expression
        gene_mask = (features_df['feature_type'] == 'Gene Expression').values
        protein_mask = (features_df['feature_type'] == 'Antibody Capture').values

        gene_matrices.append(matrix[:, gene_mask])
        protein_matrices.append(matrix[:, protein_mask])
        gene_features_list.append(features_df[gene_mask].copy())
        protein_features_list.append(features_df[protein_mask].copy())
        barcodes_list.append(barcodes_df)

    # Find common genes and proteins across all samples
    common_genes = set.intersection(*(set(df['name']) for df in gene_features_list))
    print(f"Number of common genes across CITE-seq samples: {len(common_genes)}")

    common_proteins = set.intersection(*(set(df['name']) for df in protein_features_list))
    print(f"Number of common proteins across CITE-seq samples: {len(common_proteins)}")

    # Unified feature tables: first sample's rows for the common names, one row per name
    unified_gene_features = gene_features_list[0][gene_features_list[0]['name'].isin(common_genes)].copy()
    unified_gene_features = unified_gene_features.drop_duplicates(subset=['name']).reset_index(drop=True)

    unified_protein_features = protein_features_list[0][protein_features_list[0]['name'].isin(common_proteins)].copy()
    unified_protein_features = unified_protein_features.drop_duplicates(subset=['name']).reset_index(drop=True)

    # Concatenate all barcodes and remove duplicates
    all_barcodes = pd.concat(barcodes_list, axis=0)
    all_barcodes = all_barcodes.drop_duplicates(subset=['full_barcode']).reset_index(drop=True)

    # Create AnnData objects for gene and protein expression
    gene_adata = create_anndata_from_samples(
        matrices=gene_matrices,
        features_list=gene_features_list,
        barcodes_list=barcodes_list,
        common_features=common_genes,
        unified_features=unified_gene_features,
        all_barcodes=all_barcodes,
        feature_type='gene'
    )

    protein_adata = create_anndata_from_samples(
        matrices=protein_matrices,
        features_list=protein_features_list,
        barcodes_list=barcodes_list,
        common_features=common_proteins,
        unified_features=unified_protein_features,
        all_barcodes=all_barcodes,
        feature_type='protein'
    )

    # Keep only paired cells. Iterate gene_adata's obs order rather than a set,
    # whose iteration order varies between runs (string hash randomization).
    protein_obs = set(protein_adata.obs_names)
    common_barcodes = [bc for bc in gene_adata.obs_names if bc in protein_obs]
    # .copy() materializes real AnnData objects instead of views, so later
    # metadata writes don't trigger implicit view-to-copy conversions.
    gene_adata = gene_adata[common_barcodes, :].copy()
    protein_adata = protein_adata[common_barcodes, :].copy()

    print(f"Final gene expression data shape: {gene_adata.shape}")
    print(f"Final protein expression data shape: {protein_adata.shape}")

    return gene_adata, protein_adata


def create_anndata_from_samples(
    matrices: List[sparse.csr_matrix],
    features_list: List[pd.DataFrame],
    barcodes_list: List[pd.DataFrame],
    common_features: Set[str],
    unified_features: pd.DataFrame,
    all_barcodes: pd.DataFrame,
    feature_type: str
) -> AnnData:
    """
    Create an AnnData object from multiple samples.

    Each sample's nonzeros are remapped from local (cell, feature) indices to
    global ones:
      * rows index ``all_barcodes`` by ``full_barcode``;
      * columns index ``unified_features`` by ``name``.
    Entries landing on the same global cell are summed, as by the sparse
    COO -> CSR conversion.

    Args:
        matrices: Per-sample cells x features sparse matrices.
        features_list: Per-sample feature DataFrames (must have ``name``).
        barcodes_list: Per-sample barcode DataFrames (must have ``barcode``
            and ``full_barcode``).
        common_features: Feature names to keep.
        unified_features: Output feature table (``id``, ``name``,
            ``feature_type``); defines the column order.
        all_barcodes: Output cell table; defines the row order. Every column
            except ``full_barcode`` is copied into ``obs``.
        feature_type: 'gene' or 'protein'. Selects whether ``id`` is stored
            as ``var['gene_id']`` or ``var['protein_id']``.

    Returns:
        AnnData object with all samples concatenated.
    """
    print(f"  Creating {feature_type} AnnData object from {len(matrices)} samples.")

    # Global index lookups
    barcode_to_idx = {bc: i for i, bc in enumerate(all_barcodes['full_barcode'])}
    feature_to_idx = {feature: i for i, feature in enumerate(unified_features['name'])}

    n_cells = len(all_barcodes)
    n_features = len(unified_features)

    # Fix one iteration order over the set so per-sample and unified column
    # indices stay aligned position-by-position.
    common_list = list(common_features)
    unified_all = np.array([feature_to_idx.get(name, -1) for name in common_list], dtype=np.int64)

    data_parts = []
    row_parts = []
    col_parts = []

    for sample_idx, (matrix, features_df, barcodes_df) in enumerate(zip(matrices, features_list, barcodes_list)):
        print(f"  Processing sample {sample_idx+1}/{len(matrices)}...")

        # Per-sample lookups (for repeated names/barcodes the last occurrence wins)
        feature_name_to_idx = {name: i for i, name in enumerate(features_df['name'])}
        barcode_to_global_idx = {
            bc: barcode_to_idx[full_bc]
            for bc, full_bc in zip(barcodes_df['barcode'], barcodes_df['full_barcode'])
            if full_bc in barcode_to_idx
        }

        if not barcode_to_global_idx:
            continue

        # Local column indices of the common features present in this sample,
        # and the matching unified (output) column indices.
        feature_indices = np.array([feature_name_to_idx.get(name, -1) for name in common_list], dtype=np.int64)
        valid_feature_mask = feature_indices >= 0
        feature_indices = feature_indices[valid_feature_mask]
        unified_indices = unified_all[valid_feature_mask]

        # Local row indices of the kept cells, and the matching global row indices.
        barcode_to_local_idx = {bc: i for i, bc in enumerate(barcodes_df['barcode'])}
        valid_barcodes = [bc for bc in barcodes_df['barcode'] if bc in barcode_to_global_idx]
        cell_indices = np.array([barcode_to_local_idx[bc] for bc in valid_barcodes], dtype=np.int64)
        global_rows = np.array([barcode_to_global_idx[bc] for bc in valid_barcodes], dtype=np.int64)

        if len(cell_indices) == 0 or len(feature_indices) == 0:
            continue

        # Extract submatrix with only these cells and features
        coo = matrix[cell_indices, :][:, feature_indices].tocoo()

        # Vectorized local -> global remap (replaces a per-nonzero Python loop).
        row_parts.append(global_rows[coo.row])
        col_parts.append(unified_indices[coo.col])
        data_parts.append(coo.data)

    if data_parts:
        data = np.concatenate(data_parts)
        row_indices = np.concatenate(row_parts)
        col_indices = np.concatenate(col_parts)
    else:
        data = np.array([], dtype=np.float64)
        row_indices = np.array([], dtype=np.int64)
        col_indices = np.array([], dtype=np.int64)

    # Construct the final sparse matrix (duplicate coordinates are summed)
    X = sparse.csr_matrix((data, (row_indices, col_indices)), shape=(n_cells, n_features))

    # Create AnnData object
    adata = AnnData(X=X)
    adata.obs_names = all_barcodes['full_barcode'].values
    adata.var_names = unified_features['name'].values

    # Add feature metadata
    if feature_type == 'gene':
        adata.var['gene_id'] = unified_features['id'].values
        adata.var['feature_type'] = unified_features['feature_type'].values
    else:
        adata.var['protein_id'] = unified_features['id'].values
        adata.var['feature_type'] = unified_features['feature_type'].values

    # Add cell metadata
    for col in all_barcodes.columns:
        if col != 'full_barcode':
            adata.obs[col] = all_barcodes[col].values

    return adata


def harmonize_metadata(adata: AnnData, data_type: str) -> AnnData:
    """
    Harmonize metadata according to a specified format, and clean var_names.

    Constant annotation columns are written into ``adata.obs``. The passed
    object is modified and also returned; the gene branch may return a new
    object after dropping duplicate genes.

    For gene data:
      * if every var_name is an Ensembl ID (``ENSG...``) and ``var['name']``
        exists, var_names are replaced by those names;
      * duplicate var_names are dropped, keeping the first.

    For protein data, everything from the first underscore onward is
    stripped from each var_name. A name is kept as-is if stripping would
    leave it empty. Collisions are made unique with
    ``AnnData.var_names_make_unique``.

    Args:
        adata: AnnData object to harmonize.
        data_type: 'gene' or 'protein'.

    Returns:
        Harmonized AnnData object.
    """
    # Add some relevant metadata fields
    adata.obs['organism'] = 'Homo sapiens'
    adata.obs['cell_type'] = 'Neuroblastoma'
    adata.obs['crispr_type'] = 'None'
    adata.obs['cancer_type'] = 'Neuroblastoma'
    adata.obs['condition'] = 'Control'
    adata.obs['perturbation_name'] = 'None'

    # Additional metadata
    adata.obs['data_type'] = data_type
    adata.obs['study_accession'] = GEO_ACCESSION

    if data_type == 'gene':
        # Heuristic: are var names Ensembl IDs? Guard against the vacuous
        # all() on an empty var table.
        looks_like_ensembl = len(adata.var_names) > 0 and all(
            str(name).startswith('ENSG') for name in adata.var_names
        )
        if looks_like_ensembl:
            if 'name' in adata.var.columns:
                adata.var_names = adata.var['name'].values
            else:
                # Upstream code stores ids in var['gene_id'] and names in
                # var_names, so there may be no 'name' column to map from.
                print("var_names look like Ensembl IDs but no 'name' column is available; keeping IDs.")

        # Remove duplicated gene names if any
        if adata.var_names.duplicated().any():
            print(f"Found {adata.var_names.duplicated().sum()} duplicate gene names, removing duplicates.")
            adata = adata[:, ~adata.var_names.duplicated()].copy()

    elif data_type == 'protein':
        # Remove suffixes (underscore + everything after); never produce an empty name.
        new_names = [re.sub(r"_.*$", "", n) or n for n in adata.var_names]
        adata.var_names = new_names
        # Different antibodies can collapse to the same stripped name.
        if adata.var_names.duplicated().any():
            adata.var_names_make_unique()

    return adata


def main(data_dir: str = None):
    """
    Run the full pipeline and write the CITE-seq outputs as .h5ad files.

    Steps: fetch/extract the archive, group the sample files, build gene and
    protein AnnData objects, harmonize their metadata, and save both
    (gzip-compressed) into the working directory.

    Args:
        data_dir: Working directory for downloads and outputs. Defaults to
            ``<current working directory>/<GEO_ACCESSION>``.
    """
    workdir = os.path.join(os.getcwd(), GEO_ACCESSION) if data_dir is None else data_dir

    download_data(workdir)
    grouped = get_sample_files(workdir)

    # Guard clause: nothing to build without CITE-seq barcode files.
    if not grouped["CITE"]["barcodes"]:
        print("No CITE-seq data found.")
        return

    print("Processing CITE-seq data...")
    rna, adt = process_cite_seq_data(workdir, grouped)
    rna = harmonize_metadata(rna, 'gene')
    adt = harmonize_metadata(adt, 'protein')

    rna_path = os.path.join(workdir, f"{GEO_ACCESSION}_gene_expression.h5ad")
    adt_path = os.path.join(workdir, f"{GEO_ACCESSION}_protein_expression.h5ad")

    print(f"Saving gene expression data to: {rna_path}")
    rna.write(rna_path, compression='gzip')

    print(f"Saving protein expression data to: {adt_path}")
    adt.write(adt_path, compression='gzip')

    print("Processing complete!")
    print(f"Gene expression data shape: {rna.shape}")
    print(f"Protein expression data shape: {adt.shape}")


# In a Jupyter notebook, run all cells to define these functions,
# then simply call main() to use the default directory, or specify a path.
# The __main__ guard still fires in a notebook (where __name__ == "__main__")
# and when run as a script, but stops the download/processing from running
# as a side effect of importing this module.
if __name__ == "__main__":
    main("/content/GSE253865")
