# In [None]:  (Jupyter cell prompt left over from the notebook export)
import os
import sys
import re
import gzip
import h5py
import numpy as np
import pandas as pd
import scanpy as sc
import anndata as ad
from scipy import sparse
import urllib.request
import tarfile
import shutil
from pathlib import Path

def _check_tar_members(tar, extract_dir):
    """Raise ``ValueError`` if any archive member would land outside *extract_dir*."""
    root = Path(extract_dir).resolve()
    for member in tar.getmembers():
        # Joining with an absolute member name discards ``root``, so absolute
        # paths are rejected by the same containment test as ``..`` paths.
        target = (root / member.name).resolve()
        if target != root and root not in target.parents:
            raise ValueError(
                f"Refusing to extract {member.name!r}: path escapes {root}"
            )


def download_dataset(output_dir):
    """
    Download and extract the GSE254179 dataset if not already present.

    The archive is downloaded to a temporary ``.part`` file and only renamed
    to its final name once the transfer finished, so an interrupted download
    is never mistaken for a complete archive on the next run.

    Args:
        output_dir (str): Directory to save the downloaded data

    Returns:
        pathlib.Path: Path to the extracted data directory

    Raises:
        ValueError: If the archive contains a member that would be written
            outside the extraction directory.
    """
    output_dir = Path(output_dir)
    tar_file = output_dir / "GSE254179_RAW.tar"
    extract_dir = output_dir / "GSE254179"

    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    # Download the dataset if it doesn't exist
    if not tar_file.exists():
        print(f"Downloading GSE254179 dataset to {tar_file}...")
        url = "https://www.ncbi.nlm.nih.gov/geo/download/?acc=GSE254179&format=file"
        partial_file = tar_file.with_name(tar_file.name + ".part")
        try:
            urllib.request.urlretrieve(url, partial_file)
            os.replace(partial_file, tar_file)
        finally:
            # Only still present when the download or the rename failed.
            if partial_file.exists():
                partial_file.unlink()
        print("Download complete.")
    else:
        print(f"Dataset already downloaded at {tar_file}")

    # Extract the dataset if it hasn't been extracted
    if not extract_dir.exists() or not any(extract_dir.iterdir()):
        print(f"Extracting dataset to {extract_dir}...")
        with tarfile.open(tar_file, 'r') as tar:
            # Validate before touching the filesystem so nothing is written
            # when the archive is rejected.
            _check_tar_members(tar, extract_dir)
            os.makedirs(extract_dir, exist_ok=True)
            if hasattr(tarfile, 'data_filter'):
                # Extraction filters exist only on newer Python releases.
                tar.extractall(path=extract_dir, filter='data')
            else:
                tar.extractall(path=extract_dir)
        print("Extraction complete.")
    else:
        print(f"Dataset already extracted at {extract_dir}")

    return extract_dir

def find_paired_data(data_dir):
    """
    Pair each gene expression matrix with the CITE-seq matrix of the same sample.

    Files are recognised by suffix: ``*_filtered_feature_bc_matrix.h5`` files
    are sorted into gene expression (name contains ``GEX``) or CITE-seq (name
    contains ``Cite``), and ``*_feature_reference.csv.gz`` files are indexed
    by the sample identifier found in their name.

    Args:
        data_dir (str): Directory containing the dataset files

    Returns:
        list: List of tuples containing (gene_file, cite_file, feature_ref_file, sample_id);
            feature_ref_file is None when the sample has no feature reference
    """
    root = Path(data_dir)
    sample_re = re.compile(r'GSM\d+_(.+?)_(?:GEX|Cite|feature_reference)')

    def sample_of(filename):
        # Sample identifier embedded in the file name, or None if absent.
        found = sample_re.search(filename)
        return found.group(1) if found else None

    gex_names = []
    cite_by_sample = {}
    reference_by_sample = {}

    for entry in os.listdir(root):
        if entry.endswith('_filtered_feature_bc_matrix.h5'):
            if 'GEX' in entry:
                gex_names.append(entry)
            elif 'Cite' in entry:
                key = sample_of(entry)
                if key is not None:
                    # Keep directory-listing order within each sample.
                    cite_by_sample.setdefault(key, []).append(entry)
        elif entry.endswith('_feature_reference.csv.gz'):
            key = sample_of(entry)
            if key is not None:
                reference_by_sample[key] = entry

    pairs = []
    for gex_name in gex_names:
        key = sample_of(gex_name)
        if key is None:
            continue
        reference_name = reference_by_sample.get(key)
        reference_path = str(root / reference_name) if reference_name else None
        pairs.extend(
            (str(root / gex_name), str(root / cite_name), reference_path, key)
            for cite_name in cite_by_sample.get(key, ())
        )

    return pairs

def _decode_h5_strings(values):
    """Return *values* as a list, decoding every ``bytes`` item as UTF-8."""
    return [v.decode('utf-8') if isinstance(v, bytes) else v for v in values]


def read_10x_h5(file_path):
    """
    Read a 10x Genomics h5 file and return an AnnData object.

    scanpy's reader is tried first. Because it keeps only 'Gene Expression'
    features by default, a file that holds no such features (for example an
    antibody-capture-only matrix) is re-read with ``gex_only=False`` instead
    of being returned with zero variables. If scanpy fails for any reason the
    file is parsed manually with h5py.

    Args:
        file_path (str): Path to the h5 file

    Returns:
        anndata.AnnData: AnnData object containing the data (cells x features)
    """
    try:
        # Try using scanpy's read_10x_h5 function first
        adata = sc.read_10x_h5(file_path)
        if adata.shape[1] == 0:
            # Everything was filtered out as non gene-expression; keep all
            # feature types rather than returning an empty matrix.
            adata = sc.read_10x_h5(file_path, gex_only=False)
        print(f"  Successfully read {file_path} using scanpy.read_10x_h5")

        # Add feature metadata if not already present
        if 'feature_types' not in adata.var:
            with h5py.File(file_path, 'r') as f:
                if 'matrix' in f and 'features' in f['matrix'] and 'feature_type' in f['matrix']['features']:
                    # Stored under the same column name the guard above checks.
                    adata.var['feature_types'] = _decode_h5_strings(
                        f['matrix']['features']['feature_type'][:]
                    )

        return adata

    except Exception as e:
        # Deliberately broad: any scanpy failure triggers the manual reader.
        print(f"  Error using scanpy.read_10x_h5: {e}")
        print("  Falling back to manual h5 reading...")

        with h5py.File(file_path, 'r') as f:
            group = f['matrix']
            features = group['features']

            # Read barcodes
            barcodes = _decode_h5_strings(group['barcodes'][:])

            # Read features
            feature_ids = _decode_h5_strings(features['id'][:])
            feature_names = _decode_h5_strings(features['name'][:])
            feature_types = _decode_h5_strings(features['feature_type'][:])

            # Read genome if available
            if 'genome' in features:
                genome = _decode_h5_strings(features['genome'][:])
            else:
                genome = ['Unknown'] * len(feature_ids)

            # Read sparse matrix data
            data = group['data'][:]
            indices = group['indices'][:]
            indptr = group['indptr'][:]
            n_features, n_barcodes = (int(dim) for dim in group['shape'][:])

        # The file stores features x barcodes in column-compressed form, which
        # is the same buffer layout as a row-compressed barcodes x features
        # matrix, so no transpose is needed to get cells x features.
        matrix = sparse.csr_matrix(
            (data, indices, indptr), shape=(n_barcodes, n_features)
        )

        # Create feature metadata
        var = pd.DataFrame({
            'feature_id': feature_ids,
            'feature_name': feature_names,
            'feature_type': feature_types,
            'genome': genome
        })
        var.index = var['feature_name']

        # Create AnnData object
        adata = ad.AnnData(X=matrix, var=var)
        adata.obs_names = barcodes

        return adata

def read_feature_reference(file_path):
    """
    Load a gzip-compressed feature reference table.

    The delimiter is sniffed from the first line: a semicolon anywhere in it
    selects ``;``, otherwise ``,`` is used.

    Args:
        file_path (str): Path to the feature reference file, or None

    Returns:
        pandas.DataFrame: The parsed table, or None when file_path is None
    """
    if file_path is None:
        return None

    with gzip.open(file_path, 'rt') as handle:
        header = handle.readline()
        handle.seek(0)  # rewind so pandas sees the header line too
        delimiter = ';' if ';' in header else ','
        table = pd.read_csv(handle, sep=delimiter)

    # Drop anything after a stray semicolon left inside a column name
    if any(';' in name for name in table.columns):
        table.columns = [name.split(';')[0] for name in table.columns]

    return table

def _protein_adata_from_reference(cite_file, feature_ref):
    """Build a protein AnnData from the raw CITE-seq matrix, labelled by *feature_ref*.

    Returns None (after printing a warning) when the number of rows in
    *feature_ref* differs from the feature count recorded in the h5 file,
    because the reference could then not be matched to the matrix columns.
    """
    with h5py.File(cite_file, 'r') as f:
        # Read barcodes
        barcodes = f['matrix']['barcodes'][:]
        barcodes = [bc.decode('utf-8') if isinstance(bc, bytes) else bc for bc in barcodes]

        # Read sparse matrix data
        data = f['matrix']['data'][:]
        indices = f['matrix']['indices'][:]
        indptr = f['matrix']['indptr'][:]
        shape = f['matrix']['shape'][:]

    n_cells = len(barcodes)
    n_features = len(feature_ref)

    # shape[0] is the feature count stored in the file; the reference must
    # describe exactly those features for the column labels to be meaningful.
    if int(shape[0]) != n_features:
        print(
            f"  Warning: feature reference has {n_features} features but the "
            f"matrix has {int(shape[0])}; not rebuilding protein data"
        )
        return None

    matrix = sparse.csr_matrix((data, indices, indptr), shape=(n_cells, n_features))

    # Create feature metadata from feature reference
    var = pd.DataFrame({
        'feature_id': feature_ref['id'].values,
        'feature_name': feature_ref['name'].values,
        'feature_type': feature_ref['feature_type'].values if 'feature_type' in feature_ref.columns else ['Antibody Capture'] * len(feature_ref)
    })
    var.index = var['feature_name']

    protein_adata = ad.AnnData(X=matrix, var=var)
    protein_adata.obs_names = barcodes
    return protein_adata


def process_paired_data(gene_file, cite_file, feature_ref_file, sample_id):
    """
    Process paired gene expression and protein expression data.

    Both matrices are read, restricted to the barcodes present in both, and
    annotated with the sample identifier and sample-derived metadata. Shared
    barcodes are kept in the order they appear in the gene expression data so
    that the output is identical from run to run.

    Args:
        gene_file (str): Path to the gene expression h5 file
        cite_file (str): Path to the CITE-seq h5 file
        feature_ref_file (str): Path to the feature reference file, or None
        sample_id (str): Sample identifier

    Returns:
        tuple: (gene_adata, protein_adata) - AnnData objects for gene and protein expression
    """
    print(f"Processing sample: {sample_id}")

    # Read gene expression data
    gene_adata = read_10x_h5(gene_file)
    print(f"  Gene expression data: {gene_adata.shape[0]} cells, {gene_adata.shape[1]} genes")

    # Read protein expression data
    protein_adata = read_10x_h5(cite_file)

    # Read feature reference data
    feature_ref = read_feature_reference(feature_ref_file)
    if feature_ref is not None:
        print(f"  Feature reference data: {feature_ref.shape[0]} features")

        # If protein_adata has no features, create it from the feature reference
        if protein_adata.shape[1] == 0:
            print("  Creating protein expression data from feature reference")
            rebuilt = _protein_adata_from_reference(cite_file, feature_ref)
            if rebuilt is not None:
                protein_adata = rebuilt

    print(f"  Protein expression data: {protein_adata.shape[0]} cells, {protein_adata.shape[1]} proteins")

    # Find common barcodes, preserving gene expression order (a set would
    # yield a different cell order on every run)
    shared_mask = gene_adata.obs_names.isin(protein_adata.obs_names)
    common_barcodes = list(gene_adata.obs_names[shared_mask])
    print(f"  Common barcodes: {len(common_barcodes)}")

    # Filter to keep only common barcodes
    gene_adata = gene_adata[common_barcodes].copy()
    protein_adata = protein_adata[common_barcodes].copy()

    # Add sample_id to obs
    gene_adata.obs['sample_id'] = sample_id
    protein_adata.obs['sample_id'] = sample_id

    # Add metadata based on sample_id
    add_metadata(gene_adata, sample_id)
    add_metadata(protein_adata, sample_id)

    return gene_adata, protein_adata

def _sample_annotations(sample_id):
    """Return ``(cell_type, condition, perturbation_name)`` for *sample_id*.

    Rules are checked in order and the first match wins; anything that
    matches no rule, including a 'Pool' sample of unrecognised subtype,
    yields 'Unknown' for all three values.
    """
    if 'Covid_T' in sample_id:
        return 'T Cells', 'SARS-CoV-2 specific', 'SARS-CoV-2'
    if 'CMV' in sample_id:
        return 'T Cells', 'CMV specific', 'CMV'
    if 'EBV' in sample_id:
        return 'T Cells', 'EBV specific', 'EBV'
    if 'plasmablasts' in sample_id:
        return 'Plasmablasts', 'MIS-C', 'SARS-CoV-2'
    if 'Pool' in sample_id:
        if 'PosMP' in sample_id:
            return ('Mixed Immune Cells', 'Post-Methylprednisolone',
                    'SARS-CoV-2 + Methylprednisolone')
        if 'PreMP' in sample_id:
            return 'Mixed Immune Cells', 'Pre-Methylprednisolone', 'SARS-CoV-2'
        if any(pool in sample_id for pool in ('Pool1', 'Pool2', 'Pool3')):
            return 'Mixed Immune Cells', 'MIS-C', 'SARS-CoV-2'
        if 'Pool4' in sample_id:
            return 'Mixed Immune Cells', 'Post-SARS-CoV-2', 'SARS-CoV-2'
    return 'Unknown', 'Unknown', 'Unknown'


def add_metadata(adata, sample_id):
    """
    Add standardized metadata to the AnnData object based on the sample ID.

    The columns organism, cell_type, condition, perturbation_name,
    cancer_type and crispr_type are always written to ``adata.obs``, so every
    sample ends up with the same set of columns.

    Args:
        adata (anndata.AnnData): AnnData object to add metadata to (modified in place)
        sample_id (str): Sample identifier
    """
    cell_type, condition, perturbation_name = _sample_annotations(sample_id)

    # Set organism
    adata.obs['organism'] = 'Homo sapiens'

    # Sample-dependent annotations
    adata.obs['cell_type'] = cell_type
    adata.obs['condition'] = condition
    adata.obs['perturbation_name'] = perturbation_name

    # Set cancer_type (all non-cancer in this dataset)
    adata.obs['cancer_type'] = 'Non-Cancer'

    # Set CRISPR type (not applicable for this dataset)
    adata.obs['crispr_type'] = 'None'

def harmonize_gene_names(adata):
    """
    Ensure gene names are based on gene symbols and are unique.

    When var carries an Ensembl-style identifier column together with a
    symbol column, the symbols become the var_names. If symbols are then
    duplicated, only the duplicated entries are renamed to
    ``<symbol>_<identifier>``; names that were already unique are left
    untouched. Without an identifier column, or if identifiers do not resolve
    every clash, make_index_unique is used.

    Args:
        adata (anndata.AnnData): AnnData object to harmonize gene names (modified in place)

    Returns:
        anndata.AnnData: AnnData object with harmonized gene names
    """
    var = adata.var

    def starts_with_ensembl(column):
        # Guard against an empty var table before peeking at the first row.
        values = var[column]
        return len(values) > 0 and str(values.iloc[0]).startswith('ENSG')

    # Check if we have ENSEMBL IDs as feature_ids and gene symbols as feature_names
    if 'feature_id' in var and 'feature_name' in var:
        if starts_with_ensembl('feature_id'):
            adata.var_names = var['feature_name']
            print("  Using gene symbols as var_names")
    elif 'gene_ids' in var and 'gene_symbols' in var:
        if starts_with_ensembl('gene_ids'):
            adata.var_names = var['gene_symbols']
            print("  Using gene symbols as var_names")

    names = pd.Index(adata.var_names)
    if names.duplicated().any():
        print(f"  Found {names.duplicated().sum()} duplicate gene names")

        id_column = next(
            (column for column in ('feature_id', 'gene_ids') if column in var),
            None,
        )
        if id_column is not None:
            # keep=False flags every member of a clashing group, so each of
            # them gets its own identifier appended.
            clashing = names.duplicated(keep=False)
            identifiers = var[id_column].astype(str).to_numpy()
            names = pd.Index([
                f"{name}_{identifier}" if is_clash else name
                for name, identifier, is_clash in zip(names, identifiers, clashing)
            ])
        else:
            names = make_index_unique(names)

        if names.duplicated().any():
            print(f"  Still have {names.duplicated().sum()} duplicate gene names after first attempt")
            names = make_index_unique(names)

        adata.var_names = names

    return adata

def make_index_unique(index):
    """
    Make a pandas Index unique by appending a numeric suffix to duplicates.

    The first occurrence of a value is kept as is; later occurrences become
    ``<value>_<n>`` with n counting up from 1. A suffix is skipped when the
    resulting name already exists in the index or was already handed out, so
    the returned index never contains duplicates.

    Args:
        index (pandas.Index): Index to make unique

    Returns:
        pandas.Index: Unique index with the same length and order as the input
    """
    taken = set(index)  # every name that must not be produced as a new suffix
    counts = {}
    new_index = []

    for idx in index:
        if idx not in counts:
            counts[idx] = 0
            new_index.append(idx)
            continue

        suffix = counts[idx] + 1
        candidate = f"{idx}_{suffix}"
        while candidate in taken:
            suffix += 1
            candidate = f"{idx}_{suffix}"

        counts[idx] = suffix
        taken.add(candidate)
        new_index.append(candidate)

    return pd.Index(new_index)

def main(data_dir=None):
    """
    Run the full GSE254179 pipeline: download, pair, harmonize and save.

    For every paired sample, one gene expression and one protein expression
    ``.h5ad`` file is written to a ``harmonized`` directory located next to
    the extracted dataset directory.

    Args:
        data_dir (str, optional): Directory the dataset is downloaded to and
            extracted in. If None, the current working directory is used.
    """
    base_dir = os.getcwd() if data_dir is None else data_dir

    # Fetch and unpack the archive (skipped when already present)
    dataset_dir = download_dataset(base_dir)

    samples = find_paired_data(dataset_dir)
    print(f"Found {len(samples)} paired samples")

    if not samples:
        print("No paired samples found. Exiting.")
        return

    # Same destination for every sample: a sibling of the extracted directory
    harmonized_dir = os.path.join(os.path.dirname(str(dataset_dir)), "harmonized")

    for sample in samples:
        sample_id = sample[3]
        gene_adata, protein_adata = process_paired_data(*sample)

        # Only the gene expression matrix needs symbol harmonization
        gene_adata = harmonize_gene_names(gene_adata)

        os.makedirs(harmonized_dir, exist_ok=True)

        outputs = (("gene", gene_adata), ("protein", protein_adata))
        targets = [
            os.path.join(harmonized_dir, f"{sample_id}_{kind}_expression.h5ad")
            for kind, _ in outputs
        ]
        for (kind, table), target in zip(outputs, targets):
            print(f"Saving {kind} expression data to {target}")
            table.write_h5ad(target)

        print(f"Processed {sample_id} successfully")

# Entry point: the pipeline runs as soon as this cell (or module) is executed.
# Called without arguments, main() works in the current working directory; to use
# another location pass it explicitly, e.g. main('/path/to/data').
main()
