# In [None]:
import os
import gzip
import urllib.request
from pathlib import Path
import numpy as np
import pandas as pd
import scipy.sparse as sp
import anndata as ad

class CITEseqHarmonizer:
    """Harmonize a GEO CITE-seq dataset into gene and protein h5ad files.

    Works on the triplet ``<accession>_cite_barcodes.tsv.gz``,
    ``<accession>_cite_features.tsv.gz`` and ``<accession>_cite_matrix.mtx.gz``
    stored in ``data_dir`` (downloaded from GEO when missing). It writes
    ``<accession>_gene_expression.h5ad`` and ``<accession>_protein_expression.h5ad``
    to a ``processed`` directory that sits next to ``data_dir``.
    """

    # Values of the third (feature-type) column of the features file that
    # select each modality; features of any other type are ignored.
    GENE_FEATURE_TYPE = 'Gene Expression'
    PROTEIN_FEATURE_TYPE = 'Antibody Capture'

    def __init__(self, data_dir, accession=None):
        """Initialize the harmonizer and create its input/output directories.

        Args:
            data_dir (str): Path to the directory holding (or receiving) the raw files.
            accession (str, optional): GEO accession number. If None, it is taken
                from the basename of ``data_dir``, which must look like ``GSE<digits>``.

        Raises:
            ValueError: If ``accession`` is None and cannot be inferred from ``data_dir``.
        """
        self.data_dir = os.path.abspath(data_dir)

        if accession is None:
            self.accession = os.path.basename(self.data_dir)
            # A GEO series accession is 'GSE' followed only by digits.
            if not (self.accession.startswith('GSE') and self.accession[3:].isdigit()):
                raise ValueError("Could not infer GEO accession from data_dir. Please provide accession.")
        else:
            self.accession = accession

        os.makedirs(self.data_dir, exist_ok=True)

        # Local paths of the raw input files.
        self.barcodes_file = os.path.join(self.data_dir, f'{self.accession}_cite_barcodes.tsv.gz')
        self.features_file = os.path.join(self.data_dir, f'{self.accession}_cite_features.tsv.gz')
        self.matrix_file = os.path.join(self.data_dir, f'{self.accession}_cite_matrix.mtx.gz')

        # GEO download URLs. The ``file`` parameter is the URL-encoded file name
        # ('_' -> %5F, '.' -> %2E). It must be built from this instance's
        # accession (not a hard-coded one) so every series fetches its own files.
        # Keys must match the ``<key>_file`` attributes above (see download_files).
        base_url = (
            'https://www.ncbi.nlm.nih.gov/geo/download/'
            f'?acc={self.accession}&format=file&file={self.accession}'
        )
        self.urls = {
            'barcodes': f'{base_url}%5Fcite%5Fbarcodes%2Etsv%2Egz',
            'features': f'{base_url}%5Fcite%5Ffeatures%2Etsv%2Egz',
            'matrix': f'{base_url}%5Fcite%5Fmatrix%2Emtx%2Egz',
        }

        # Outputs go to a 'processed' directory that is a sibling of data_dir.
        self.output_dir = os.path.join(os.path.dirname(self.data_dir), 'processed')
        os.makedirs(self.output_dir, exist_ok=True)

        self.gene_output = os.path.join(self.output_dir, f'{self.accession}_gene_expression.h5ad')
        self.protein_output = os.path.join(self.output_dir, f'{self.accession}_protein_expression.h5ad')

    def download_files(self):
        """Download any dataset files that are not already present.

        Each file is first written to a ``.part`` sibling and only moved into
        place once the download finished. A failed or interrupted download
        therefore never leaves a truncated file that a later run would skip
        as "already exists".

        Returns:
            bool: True if every file is present afterwards, False if a
            download failed (the error is printed).
        """
        for name, url in self.urls.items():
            filepath = getattr(self, f'{name}_file')
            filename = os.path.basename(filepath)

            if os.path.exists(filepath):
                print(f"File {filename} already exists")
                continue

            print(f"Downloading {filename}...")
            tmp_path = filepath + '.part'
            try:
                urllib.request.urlretrieve(url, tmp_path)
                os.replace(tmp_path, filepath)
            except Exception as e:
                # Drop the partial download so the next run retries it.
                if os.path.exists(tmp_path):
                    try:
                        os.remove(tmp_path)
                    except OSError:
                        pass
                print(f"Error downloading {filename}: {e}")
                return False
            print(f"Downloaded {filename}")

        return True

    def _read_barcodes(self):
        """Return the cell barcodes, one per non-empty line of the barcodes file."""
        with gzip.open(self.barcodes_file, 'rt') as f:
            return [line.strip() for line in f if line.strip()]

    def _read_features(self):
        """Return the features file as a list of tab-separated field lists.

        Raises:
            ValueError: If a non-empty line has fewer than three fields
                (id, name, feature type).
        """
        features = []
        with gzip.open(self.features_file, 'rt') as f:
            for lineno, line in enumerate(f, start=1):
                line = line.rstrip('\r\n')
                if not line:
                    continue
                fields = line.split('\t')
                if len(fields) < 3:
                    raise ValueError(
                        f"{self.features_file}:{lineno}: expected at least 3 "
                        f"tab-separated columns, got {len(fields)}"
                    )
                features.append(fields)
        return features

    @staticmethod
    def _orient_matrix(mtx, n_cells, n_features):
        """Return ``mtx`` as a CSR (cells x features) matrix.

        The (features x cells) layout is checked first because it is the MTX
        convention; it also wins when the matrix is square.

        Raises:
            ValueError: If the shape matches neither orientation.
        """
        if mtx.shape == (n_features, n_cells):
            print("Transposing matrix to match barcodes...")
            return mtx.T.tocsr()
        if mtx.shape == (n_cells, n_features):
            return mtx.tocsr()
        raise ValueError(
            f"Matrix shape {mtx.shape} does not match {n_cells} barcodes x {n_features} features"
        )

    def read_mtx_files(self):
        """Read the MTX triplet and build gene and protein AnnData objects.

        Returns:
            tuple: ``(gene_adata, protein_adata)``. ``gene_adata`` also carries
            the dense protein matrix in ``obsm['protein_expression']`` and the
            protein names in ``uns['protein_names']``.

        Raises:
            ValueError: If the features file is malformed or the matrix shape
                does not match the barcode/feature counts.
        """
        barcodes = self._read_barcodes()
        features = self._read_features()

        # Column indices of each modality in the combined matrix. Selecting by
        # index (instead of assuming all genes come first) keeps every column
        # aligned with its feature row. Feature types other than the two
        # modalities are dropped rather than mixed into the protein block.
        gene_idx = [i for i, feat in enumerate(features) if feat[2] == self.GENE_FEATURE_TYPE]
        protein_idx = [i for i, feat in enumerate(features) if feat[2] == self.PROTEIN_FEATURE_TYPE]
        gene_features = [features[i][:3] for i in gene_idx]
        protein_features = [features[i][:3] for i in protein_idx]

        gene_df = pd.DataFrame(gene_features, columns=['gene_id', 'gene_name', 'feature_type'])
        protein_df = pd.DataFrame(protein_features, columns=['protein_id', 'gene_name', 'feature_type'])

        # Normalize protein names: uppercase and drop suffixes such as '_TotalA'.
        protein_df['gene_name'] = protein_df['gene_name'].apply(lambda x: x.split('_')[0].upper())

        # Read the sparse matrix; the context manager closes the gzip handle.
        from scipy import io as spio
        with gzip.open(self.matrix_file, 'rb') as fh:
            mtx = spio.mmread(fh)
        mtx = sp.csr_matrix(mtx)

        print(f"Matrix shape: {mtx.shape}")
        print(f"Number of barcodes: {len(barcodes)}")
        print(f"Number of gene features: {len(gene_features)}")
        print(f"Number of protein features: {len(protein_features)}")

        # Cells must be rows and features columns.
        mtx = self._orient_matrix(mtx, n_cells=len(barcodes), n_features=len(features))

        gene_mtx = mtx[:, gene_idx]
        protein_mtx = mtx[:, protein_idx]

        # Per-cell metadata; the fixed values describe this study as a whole.
        obs = pd.DataFrame(index=barcodes)
        obs['barcode'] = barcodes
        obs['organism'] = 'Homo sapiens'  # Based on the dataset description
        obs['cell_type'] = 'Cardiomyocyte'  # Sinoatrial node pacemaker cardiomyocytes (dataset description)
        obs['crispr_type'] = 'None'  # No CRISPR perturbation in this dataset
        obs['cancer_type'] = 'Non-Cancer'  # This is a normal cell study
        obs['condition'] = 'Control'  # Default condition
        obs['perturbation_name'] = 'None'  # No perturbation in this dataset
        obs['source'] = self.accession  # Source dataset
        obs['study_accession'] = self.accession  # GEO accession
        obs['data_type'] = 'CITE-seq'  # Data type

        # Give each object its own obs frame so edits to one cannot leak into the other.
        gene_adata = ad.AnnData(
            X=gene_mtx,
            obs=obs.copy(),
            var=gene_df.set_index('gene_name')
        )

        protein_adata = ad.AnnData(
            X=protein_mtx,
            obs=obs.copy(),
            var=protein_df.set_index('gene_name')
        )

        gene_adata.obsm['protein_expression'] = protein_mtx.toarray()
        gene_adata.uns['protein_names'] = protein_df['gene_name'].values

        return gene_adata, protein_adata

    def check_for_duplicates(self, adata):
        """Report duplicated variable names and make them unique in place.

        Uniqueness is enforced with ``adata.var_names_make_unique()``, which
        appends a suffix to repeated names.

        Returns:
            The same ``adata`` object (modified in place).
        """
        duplicated = adata.var_names.duplicated()
        if duplicated.any():
            print(f"Found {duplicated.sum()} duplicate gene names")
            duplicates = adata.var_names[duplicated].unique()
            print(f"Duplicate genes: {duplicates}")
            adata.var_names_make_unique()

        return adata

    def save_anndata(self, adata, output_file):
        """Write ``adata`` to ``output_file`` as an h5ad file."""
        print(f"Saving {output_file}...")
        adata.write(output_file)
        print(f"Saved {output_file}")

    @staticmethod
    def _column_values(adata, name):
        """Return variable ``name`` across all cells as a flat numpy array (sparse or dense X)."""
        column = adata[:, name].X
        if sp.issparse(column):
            column = column.toarray()
        return np.asarray(column).ravel()

    @staticmethod
    def _print_expression_stats(values, title, label):
        """Print mean, median, max and the number of positive cells of ``values``."""
        print(title)
        if values.size == 0:
            # max() raises on an empty array; report instead of crashing.
            print("  No cells to summarize")
            return
        print(f"  Mean: {values.mean():.4f}")
        print(f"  Median: {np.median(values):.4f}")
        print(f"  Max: {values.max():.4f}")
        print(f"  Number of cells with {label} > 0: {np.sum(values > 0)}")

    def analyze_cd34(self, gene_adata, protein_adata):
        """Print summary statistics of CD34 RNA and CD34 protein expression.

        Looks for a gene named exactly 'CD34' and for every protein whose name
        contains 'CD34' case-insensitively. Nothing is printed for a modality
        in which CD34 is absent.
        """
        if 'CD34' in gene_adata.var_names:
            cd34_idx = np.where(gene_adata.var_names == 'CD34')[0][0]
            print(f"\nCD34 gene found at index {cd34_idx}")
            self._print_expression_stats(
                self._column_values(gene_adata, 'CD34'), "CD34 gene expression stats:", "CD34"
            )

        # Case-insensitive match (also covers spellings such as 'Cd34').
        cd34_proteins = [p for p in protein_adata.var_names if 'CD34' in p.upper()]
        for protein in cd34_proteins:
            print(f"\nCD34 protein found: {protein}")
            self._print_expression_stats(
                self._column_values(protein_adata, protein),
                "CD34 protein expression stats:",
                "CD34 protein",
            )

    def harmonize(self):
        """Run the full pipeline: download, read, deduplicate, summarize, save.

        Returns:
            bool: False if downloading failed (nothing is processed), True otherwise.
        """
        if not self.download_files():
            print("Error downloading files. Exiting.")
            return False

        print("Processing data...")
        gene_adata, protein_adata = self.read_mtx_files()

        gene_adata = self.check_for_duplicates(gene_adata)
        protein_adata = self.check_for_duplicates(protein_adata)

        print("\nGene Expression AnnData:")
        print(f"Shape: {gene_adata.shape}")
        print(f"Observations (cells): {gene_adata.n_obs}")
        print(f"Variables (genes): {gene_adata.n_vars}")
        # Row sums are total counts; genes detected are the nonzero entries per row.
        print(f"Counts per cell (mean): {gene_adata.X.sum(axis=1).mean():.2f}")
        print(f"Genes per cell (mean): {(gene_adata.X > 0).sum(axis=1).mean():.2f}")

        print("\nProtein Expression AnnData:")
        print(f"Shape: {protein_adata.shape}")
        print(f"Observations (cells): {protein_adata.n_obs}")
        print(f"Variables (proteins): {protein_adata.n_vars}")

        self.analyze_cd34(gene_adata, protein_adata)

        self.save_anndata(gene_adata, self.gene_output)
        self.save_anndata(protein_adata, self.protein_output)

        print("\nProcessing complete!")
        print(f"Gene expression data saved to: {self.gene_output}")
        print(f"Protein expression data saved to: {self.protein_output}")

        return True

# ----- Run the Harmonization Process in Jupyter -----
# Guarded so that importing this module (e.g. from tests or another notebook)
# does not start a download. In Jupyter and with `python script.py`,
# __name__ is "__main__", so the pipeline still runs there.
if __name__ == "__main__":
    # Data directory: a folder named 'GSE279632' in the current working
    # directory; the accession is inferred from the folder name.
    data_dir = os.path.join(os.getcwd(), 'GSE279632')

    # Create the harmonizer and run it. harmonize() returns False on a failed
    # download; raise so the failure is not mistaken for a successful run.
    harmonizer = CITEseqHarmonizer(data_dir)
    if not harmonizer.harmonize():
        raise RuntimeError(f"Harmonization of {data_dir} failed; see messages above.")
