# In [None]:
import os
import sys
import glob
import gzip
import shutil
import urllib.request
from pathlib import Path
import numpy as np
import pandas as pd
import scanpy as sc
import anndata as ad
from scipy import sparse

# Configure scanpy's console output before any processing runs.
# verbosity = 1 sets scanpy's logging level (believed to be the "warnings"
# level in scanpy's 0-3 scale — confirm against the installed version).
sc.settings.verbosity = 1
# presumably prints the versions of scanpy and its dependencies, which is
# useful provenance for the harmonized output — verify in scanpy docs.
sc.logging.print_header()

class DatasetHarmonizer:
    """Base class for harmonizing scRNA-seq datasets into standardized h5ad format.

    Subclasses implement ``download_data`` and ``harmonize_dataset``;
    ``process`` drives the download -> harmonize -> save pipeline and prints a
    summary of the result.
    """

    def __init__(self, accession, root_path):
        """
        Initialize the harmonizer with the accession number and root path.

        Parameters:
        -----------
        accession : str
            GEO accession number (e.g., GSE216673)
        root_path : str or path-like
            Path to the directory where data will be stored. A sub-directory
            named after ``accession`` is created inside it.
        """
        self.accession = accession
        self.root_path = Path(root_path)
        self.dataset_path = self.root_path / self.accession

        # Ensure the dataset directory (and any missing parents) exists.
        self.dataset_path.mkdir(parents=True, exist_ok=True)

    def download_data(self):
        """Download the dataset if it doesn't exist locally.

        Raises:
        -------
        NotImplementedError
            Always, in the base class; subclasses must override it.
        """
        raise NotImplementedError("Subclasses must implement download_data method")

    def harmonize_dataset(self):
        """
        Process and harmonize the dataset according to the specified standards.

        Returns:
        --------
        anndata.AnnData
            Harmonized AnnData object

        Raises:
        -------
        NotImplementedError
            Always, in the base class; subclasses must override it.
        """
        raise NotImplementedError("Subclasses must implement harmonize_dataset method")

    @staticmethod
    def _print_counts(adata, column, title):
        """Print per-value cell counts of ``adata.obs[column]``; skip if the column is absent."""
        if column not in adata.obs:
            return
        print(f"\n  Cells by {title}:")
        for value, count in adata.obs[column].value_counts().items():
            print(f"    {value}: {count} cells")

    def print_dataset_summary(self, adata):
        """
        Print a summary of the harmonized dataset.

        Every per-category breakdown is optional: a missing obs column is
        skipped rather than raising KeyError, so subclasses that do not
        produce all standard columns can still use this summary.

        Parameters:
        -----------
        adata : anndata.AnnData
            The harmonized dataset
        """
        print("\nDataset Summary:")
        print(f"  Total cells: {adata.n_obs}")
        print(f"  Total genes: {adata.n_vars}")

        # (obs column, label used in the printed heading)
        for column, title in (
            ('condition', 'condition'),
            ('perturbation_name', 'perturbation'),
            ('experiment', 'experiment'),
            ('sample', 'sample'),
        ):
            self._print_counts(adata, column, title)

    def process(self):
        """
        Main processing function to download, harmonize, and save the dataset.

        Returns:
        --------
        str
            Path to the saved h5ad file
        """
        # Download and extract data if needed
        self.download_data()

        # Harmonize the dataset
        adata = self.harmonize_dataset()

        # Save the harmonized dataset next to the raw data
        output_file = self.dataset_path / f"{self.accession}_harmonized.h5ad"
        adata.write_h5ad(output_file)

        print(f"Harmonized dataset saved to: {output_file}")
        print(f"Dataset shape: {adata.shape}")
        print(f"Observations metadata: {list(adata.obs.columns)}")

        self.print_dataset_summary(adata)

        return str(output_file)


class GSE216673Harmonizer(DatasetHarmonizer):
    """Harmonize GEO series GSE216673 (TREX1 WT/KO microglia) into h5ad format.

    The GEO RAW tarball holds one 10x-style file triplet per sample
    (``<prefix>_matrix.mtx.gz``, ``<prefix>_features.tsv.gz``,
    ``<prefix>_barcodes.tsv.gz``). Each triplet is loaded, annotated from its
    file-name prefix, and all samples are concatenated into one AnnData.
    """

    def __init__(self, root_path):
        """
        Initialize the harmonizer with the root path for data storage.

        Parameters:
        -----------
        root_path : str
            Path to the directory where data will be stored
        """
        super().__init__("GSE216673", root_path)
        self.raw_data_url = "https://ftp.ncbi.nlm.nih.gov/geo/series/GSE216nnn/GSE216673/suppl/GSE216673_RAW.tar"
        self.raw_tar_path = self.dataset_path / "GSE216673_RAW.tar"

    # ------------------------------------------------------------------
    # Sample-name / gene-name helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _labels_for_sample(sample_name):
        """Return ``(condition, perturbation_name)`` for a sample file prefix.

        Names containing ``"_KO_"`` are TREX1 knockouts -> ``("Test", "TREX1")``;
        all others are wild-type controls -> ``("Control", "Non-targeting")``.
        Used both per sample and after concatenation so the two never disagree.
        """
        if '_KO_' in sample_name:
            return 'Test', 'TREX1'
        return 'Control', 'Non-targeting'

    @staticmethod
    def _experiment_from_sample(sample_prefix):
        """Return the experiment tag from a sample prefix.

        The tag is the second ``-``-separated piece of the last ``_``-separated
        field (``..._microglia-Exp1`` -> ``"Exp1"``). Returns ``"unknown"`` if
        that field contains no ``-`` instead of raising IndexError.
        """
        pieces = sample_prefix.split('_')[-1].split('-')
        return pieces[1] if len(pieces) > 1 else 'unknown'

    @staticmethod
    def _unique_symbols(gene_symbols):
        """Make gene symbols usable as a unique index.

        Every symbol that occurs more than once becomes ``"<symbol>_<row>"``
        (row = 0-based position in the features file); symbols that occur once
        are kept unchanged. One vectorized duplicate scan replaces the former
        per-gene ``list.count`` call, which was O(n^2) in the number of genes.
        """
        is_dup = pd.Series(gene_symbols, dtype=object).duplicated(keep=False).tolist()
        return [
            f"{symbol}_{row}" if dup else symbol
            for row, (symbol, dup) in enumerate(zip(gene_symbols, is_dup))
        ]

    # ------------------------------------------------------------------
    # Pipeline steps
    # ------------------------------------------------------------------
    def download_data(self):
        """Download and extract the GEO RAW tarball if not already present.

        The download goes to a ``.part`` file that is renamed only after it
        completes, so an interrupted transfer is never mistaken for a finished
        archive on the next run.
        """
        if not self.raw_tar_path.exists():
            print(f"Downloading {self.accession} dataset...")
            partial_path = self.raw_tar_path.with_name(self.raw_tar_path.name + ".part")
            try:
                urllib.request.urlretrieve(self.raw_data_url, partial_path)
            except BaseException:
                # Covers KeyboardInterrupt too; don't leave a truncated file behind.
                partial_path.unlink(missing_ok=True)
                raise
            partial_path.replace(self.raw_tar_path)
            print(f"Download complete: {self.raw_tar_path}")
        else:
            print(f"Dataset already downloaded: {self.raw_tar_path}")

        # Extract the tar file only if no per-sample files are present yet.
        sample_files = list(self.dataset_path.glob("GSM*.gz"))
        if not sample_files:
            print("Extracting tar file...")
            # NOTE(security): the archive comes from a remote server; on
            # Python >= 3.12 consider tarfile's filter='data' to reject
            # path-traversal members.
            shutil.unpack_archive(self.raw_tar_path, self.dataset_path)
            print("Extraction complete")

    def load_10x_data(self, sample_prefix):
        """
        Load a 10X dataset from the specified sample prefix.

        Parameters:
        -----------
        sample_prefix : str
            Prefix of the sample files (e.g., "GSM6685596_WT_TREX1_microglia-Exp1")

        Returns:
        --------
        anndata.AnnData
            Cells x genes AnnData with obs columns ``sample``, ``condition``,
            ``experiment`` and ``perturbation_name`` and var columns
            ``gene_ids`` and ``gene_symbols``.
        """
        print(f"Loading {sample_prefix}...")

        matrix_file = self.dataset_path / f"{sample_prefix}_matrix.mtx.gz"
        features_file = self.dataset_path / f"{sample_prefix}_features.tsv.gz"
        barcodes_file = self.dataset_path / f"{sample_prefix}_barcodes.tsv.gz"

        # The code assumes the Matrix Market file is genes x cells (10x layout);
        # transpose to AnnData's cells x genes convention and store as CSR.
        counts = sparse.csr_matrix(sc.read_mtx(str(matrix_file)).X.T)

        # Features: column 0 = gene ID, column 1 = gene symbol.
        gene_ids = []
        gene_symbols = []
        with gzip.open(features_file, 'rt') as f:
            for line in f:
                stripped = line.strip()
                if not stripped:
                    continue  # tolerate blank (e.g. trailing) lines
                fields = stripped.split('\t')
                gene_ids.append(fields[0])
                # Fall back to the ID if a row has no symbol column.
                gene_symbols.append(fields[1] if len(fields) > 1 else fields[0])

        # Cell barcodes, one per non-blank line.
        with gzip.open(barcodes_file, 'rt') as f:
            barcodes = [line.strip() for line in f if line.strip()]

        var_df = pd.DataFrame({
            'gene_ids': gene_ids,
            'gene_symbols': gene_symbols
        })
        var_df.index = pd.Index(self._unique_symbols(gene_symbols))

        adata = ad.AnnData(
            X=counts,
            obs=pd.DataFrame(index=barcodes),
            var=var_df
        )

        # WT samples -> Control / Non-targeting; KO samples -> Test / TREX1.
        condition, perturbation = self._labels_for_sample(sample_prefix)
        adata.obs['sample'] = sample_prefix
        adata.obs['condition'] = condition
        adata.obs['experiment'] = self._experiment_from_sample(sample_prefix)
        adata.obs['perturbation_name'] = perturbation

        return adata

    def harmonize_dataset(self):
        """
        Process and harmonize the dataset according to the specified standards.

        Returns:
        --------
        anndata.AnnData
            Harmonized AnnData object

        Raises:
        -------
        FileNotFoundError
            If no ``GSM*_barcodes.tsv.gz`` sample files are present.
        """
        suffix = "_barcodes.tsv.gz"
        # Sorted so cell order in the output is reproducible across runs.
        sample_prefixes = sorted(
            f.name[: -len(suffix)] for f in self.dataset_path.glob(f"GSM*{suffix}")
        )
        if not sample_prefixes:
            raise FileNotFoundError(
                f"No GSM*{suffix} files found in {self.dataset_path}; "
                "was the raw archive downloaded and extracted?"
            )

        adatas = [self.load_10x_data(prefix) for prefix in sample_prefixes]

        if len(adatas) > 1:
            combined = ad.concat(
                adatas,
                join='outer',
                merge='same',
                label='sample',
                keys=sample_prefixes,
                index_unique='-'
            )
        else:
            combined = adatas[0]

        # Standardize additional metadata columns
        combined.obs['organism'] = 'Homo sapiens'
        combined.obs['cell_type'] = 'Microglia'  # Based on dataset description
        combined.obs['crispr_type'] = 'CRISPR KO'  # Based on dataset description (TREX1 KO)
        combined.obs['cancer_type'] = 'Non-Cancer'  # Based on dataset description

        # Ensure all metadata columns are strings
        for col in combined.obs.columns:
            combined.obs[col] = combined.obs[col].astype(str)

        # Re-derive labels from the sample name with the same rule used per
        # sample, so the final columns are guaranteed consistent.
        combined.obs['condition'] = combined.obs['sample'].apply(
            lambda name: self._labels_for_sample(name)[0]
        )
        combined.obs['perturbation_name'] = combined.obs['sample'].apply(
            lambda name: self._labels_for_sample(name)[1]
        )

        # Keep the raw symbols for reference (merge='same' may drop the column
        # if samples disagree, so guard instead of raising KeyError).
        if 'gene_symbols' in combined.var:
            combined.var['original_gene_symbols'] = combined.var['gene_symbols']
        else:
            print("Warning: 'gene_symbols' not retained after concatenation.")

        print("Ensuring gene symbols are unique in the final dataset...")
        if not combined.var_names.is_unique:
            print("Warning: gene names were not unique after concatenation; making them unique.")
            combined.var_names_make_unique()

        return combined


def get_harmonizer(accession, root_path):
    """
    Return the harmonizer instance responsible for ``accession``.

    Parameters:
    -----------
    accession : str
        GEO accession number; only "GSE216673" is recognised at present.
    root_path : str
        Directory under which the dataset folder is created.

    Returns:
    --------
    DatasetHarmonizer
        A GSE216673Harmonizer bound to ``root_path``.

    Raises:
    -------
    ValueError
        If ``accession`` has no harmonizer.
    """
    # Guard clause: reject accessions we have no harmonizer for.
    if accession != "GSE216673":
        raise ValueError(f"Unsupported accession: {accession}. Currently only GSE216673 is supported.")
    return GSE216673Harmonizer(root_path)


def main(accession="GSE216673", root_path=None):
    """Main function to run the harmonization process.

    Parameters:
    -----------
    accession : str, optional
        GEO accession number (default is "GSE216673")
    root_path : str, optional
        Directory to store data. Defaults to the current working directory
        at *call* time (``None`` is resolved with ``os.getcwd()`` when called).

    Returns:
    --------
    str or None
        Path to the saved h5ad file, or None if the accession is unsupported.
    """
    # Resolve lazily: a signature default of os.getcwd() is evaluated once at
    # definition time and would ignore any later os.chdir() in the notebook.
    if root_path is None:
        root_path = os.getcwd()

    # Only the accession lookup gets the usage message; errors raised during
    # download/processing propagate with their traceback instead of being
    # mislabelled as usage errors.
    try:
        harmonizer = get_harmonizer(accession, root_path)
    except ValueError as e:
        print(f"Error: {e}")
        print("Usage: call main(accession, root_path)")
        return None

    harmonized_file = harmonizer.process()
    print(f"Harmonization complete. File saved at: {harmonized_file}")
    return harmonized_file

# Run the harmonization with the default accession and storage location.
# This call sits at module top level, so it also executes whenever this file
# is imported or run as a script, not only inside a Jupyter notebook.
main()
