# In [None]:
# --------------------------------------------------------------------------------
# 1. Imports and setup
# --------------------------------------------------------------------------------
import os
import re
import glob
import h5py
import numpy as np
import pandas as pd
from scipy import sparse
import anndata as ad
import logging
from pathlib import Path
import requests
import gzip
import tarfile
import time
import subprocess
import sys

# Use GEOparse when it is installed; otherwise only the other download methods below are used.
try:
    # Try to import GEOparse for Python
    import GEOparse
    have_geoparse = True
except ImportError:
    have_geoparse = False
    print("GEOparse package not available. Will use alternative download methods.")

# Logging setup
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# --------------------------------------------------------------------------------
# 2. Constants and metadata
# --------------------------------------------------------------------------------
# GEO series accession handled by this script; also used to name the output
# .h5ad files written by process_dataset.
ACCESSION_NUMBER = "GSE288020"
# Copied into every sample's metadata record by get_sample_metadata.
ORGANISM = "Homo sapiens"

# Per-sample annotations, keyed by GEO sample (GSM) accession:
#   sample_type - 'MGUS' or 'MM'; get_sample_metadata maps 'MM' to the
#                 cancer_type 'Multiple Myeloma' and any other value to 'MGUS'.
#   immune_age  - 'Old' / 'Young' group label recorded for the sample.
#   sample_name - the submitter's sample label; get_h5_files also uses it to
#                 match supplementary h5 file names, and create_anndata uses it
#                 as the cell-name prefix.
SAMPLE_METADATA = {
    'GSM8757538': {'sample_type': 'MGUS', 'immune_age': 'Old', 'sample_name': 'R001'},
    'GSM8757539': {'sample_type': 'MGUS', 'immune_age': 'Young', 'sample_name': 'R003'},
    'GSM8757540': {'sample_type': 'MGUS', 'immune_age': 'Young', 'sample_name': 'R005'},
    'GSM8757541': {'sample_type': 'MGUS', 'immune_age': 'Old', 'sample_name': 'R006'},
    'GSM8757542': {'sample_type': 'MGUS', 'immune_age': 'Young', 'sample_name': 'R008'},
    'GSM8757543': {'sample_type': 'MGUS', 'immune_age': 'Young', 'sample_name': 'R009'},
    'GSM8757544': {'sample_type': 'MGUS', 'immune_age': 'Young', 'sample_name': 'R010'},
    'GSM8757545': {'sample_type': 'MGUS', 'immune_age': 'Old', 'sample_name': 'R013'},
    'GSM8757546': {'sample_type': 'MGUS', 'immune_age': 'Old', 'sample_name': 'R014'},
    'GSM8757547': {'sample_type': 'MGUS', 'immune_age': 'Young', 'sample_name': 'R015'},
    'GSM8757548': {'sample_type': 'MGUS', 'immune_age': 'Old', 'sample_name': 'R016'},
    'GSM8757549': {'sample_type': 'MGUS', 'immune_age': 'Old', 'sample_name': 'R020'},
    'GSM8757550': {'sample_type': 'MGUS', 'immune_age': 'Young', 'sample_name': 'R023'},
    'GSM8757551': {'sample_type': 'MGUS', 'immune_age': 'Old', 'sample_name': 'R024'},
    'GSM8757552': {'sample_type': 'MM', 'immune_age': 'Young', 'sample_name': 'E2228'},
    'GSM8757553': {'sample_type': 'MM', 'immune_age': 'Old', 'sample_name': 'E2238'},
    'GSM8757554': {'sample_type': 'MM', 'immune_age': 'Young', 'sample_name': 'E2242'},
    'GSM8757555': {'sample_type': 'MM', 'immune_age': 'Old', 'sample_name': 'E2243'},
    'GSM8757556': {'sample_type': 'MM', 'immune_age': 'Young', 'sample_name': 'E2263'},
    'GSM8757557': {'sample_type': 'MM', 'immune_age': 'Old', 'sample_name': 'E2324'},
    'GSM8757558': {'sample_type': 'MM', 'immune_age': 'Young', 'sample_name': 'E2326'},
    'GSM8757559': {'sample_type': 'MM', 'immune_age': 'Old', 'sample_name': 'E2328'},
    'GSM8757560': {'sample_type': 'MM', 'immune_age': 'Old', 'sample_name': 'E2329'}
}

# --------------------------------------------------------------------------------
# 3. Download Functions - Multiple approaches
# --------------------------------------------------------------------------------

def download_family_soft_file(geo_id, output_dir):
    """
    Download the family SOFT file for a GEO Series
    Args:
        geo_id (str): GEO accession ID (e.g., GSE288020)
        output_dir (str): Directory to save the downloaded file
    Returns:
        str: Path to the downloaded file
    """
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(output_dir, f"{geo_id}_family.soft.gz")
    
    if os.path.exists(output_file):
        logger.info(f"Family SOFT file already exists at {output_file}")
        return output_file
    
    # Construct URL for family SOFT file
    url = f"https://ftp.ncbi.nlm.nih.gov/geo/series/{geo_id[:-3]}nnn/{geo_id}/soft/{geo_id}_family.soft.gz"
    
    logger.info(f"Downloading family SOFT file from {url}")
    try:
        response = requests.get(url, stream=True)
        response.raise_for_status()
        
        with open(output_file, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
        
        logger.info(f"Downloaded family SOFT file to {output_file}")
        return output_file
    
    except Exception as e:
        logger.error(f"Failed to download family SOFT file: {e}")
        return None

def extract_supplementary_file_urls(soft_file):
    """
    Extract supplementary file URLs from a GEO SOFT file.

    Both series-level (``!Series_supplementary_file``) and sample-level
    (``!Sample_supplementary_file...``) entries are collected. Several values
    are skipped:
      - the placeholder ``NONE`` that GEO writes when a record has no
        supplementary file;
      - malformed lines without a `` = `` separator (these no longer abort
        the whole parse);
      - repeated URLs.

    Args:
        soft_file (str): Path to the GEO SOFT file; gzip-decompressed when the
            name ends in ``.gz``
    Returns:
        list: Supplementary file URLs, in file order and without duplicates.
        URLs read before an I/O or decompression error are still returned.
    """
    prefixes = ('!Series_supplementary_file', '!Sample_supplementary_file')
    opener = gzip.open if soft_file.endswith('.gz') else open

    urls = []
    seen = set()
    try:
        # errors='replace': only the ASCII URL lines matter, so stray bytes in
        # free-text fields must not abort the parse.
        with opener(soft_file, 'rt', encoding='utf-8', errors='replace') as f:
            for line in f:
                if not line.startswith(prefixes):
                    continue
                _, sep, value = line.partition(' = ')
                url = value.strip()
                if not sep or not url or url.upper() == 'NONE' or url in seen:
                    continue
                seen.add(url)
                urls.append(url)
    except (OSError, EOFError) as e:
        # OSError covers missing files and gzip.BadGzipFile; EOFError a truncated gzip.
        logger.error(f"Error parsing SOFT file: {e}")

    logger.info(f"Found {len(urls)} supplementary file URLs")
    return urls

def download_file(url, output_path, max_retries=3, retry_delay=5, timeout=60):
    """
    Download a file from a URL with retries.

    Data is streamed into ``<output_path>.part`` and renamed to
    *output_path* only on success. A failed or interrupted attempt therefore
    never leaves a truncated file behind; callers skip downloads whose target
    already exists, so a leftover partial file would never be repaired.

    URLs that ``requests`` cannot handle at all (unsupported scheme such as
    ``ftp://``, missing scheme, malformed URL) fail immediately instead of
    being retried.

    Args:
        url (str): URL to download
        output_path (str): Path to save the file. Its directory is created if
            needed; a bare file name saves into the current directory.
        max_retries (int): Maximum number of attempts
        retry_delay (int): Delay between retries in seconds
        timeout (float): Per-request connect/read timeout in seconds
    Returns:
        bool: True if download successful, False otherwise (also when
        ``max_retries <= 0``)
    """
    parent_dir = os.path.dirname(output_path)
    if parent_dir:  # os.makedirs('') raises for bare file names
        os.makedirs(parent_dir, exist_ok=True)
    partial_path = output_path + ".part"

    for attempt in range(1, max_retries + 1):
        try:
            logger.info(f"Downloading {url} to {output_path} (Attempt {attempt})")
            with requests.get(url, stream=True, timeout=timeout) as response:
                response.raise_for_status()
                with open(partial_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        f.write(chunk)
            os.replace(partial_path, output_path)
            logger.info(f"Successfully downloaded {url}")
            return True

        except (requests.exceptions.InvalidSchema,
                requests.exceptions.MissingSchema,
                requests.exceptions.InvalidURL) as e:
            # Retrying cannot fix a URL requests has no adapter for.
            logger.warning(f"Cannot download {url} with requests: {e}")
            if os.path.exists(partial_path):
                os.remove(partial_path)
            return False

        except requests.exceptions.RequestException as e:
            logger.warning(f"Attempt {attempt} failed: {e}")
            if os.path.exists(partial_path):
                os.remove(partial_path)
            if attempt < max_retries:
                logger.info(f"Retrying in {retry_delay} seconds...")
                time.sleep(retry_delay)
            else:
                logger.error(f"Failed to download {url} after {max_retries} attempts")

    return False

def try_alternative_url_format(url, output_path):
    """
    Try alternative URL format if original fails
    Args:
        url (str): Original URL
        output_path (str): Path to save the file
    Returns:
        bool: True if download successful, False otherwise
    """
    # Try direct FTP format
    if url.startswith('ftp://'):
        # Try HTTPS alternative
        https_url = url.replace('ftp://', 'https://')
        if download_file(https_url, output_path):
            return True
    
    # Try GEO direct URL format
    if 'geo/samples/' in url or 'geo/series/' in url:
        # Parse the accession from the URL
        match = re.search(r'(GSM\d+|GSE\d+)', url)
        if match:
            accession = match.group(1)
            filename = os.path.basename(url)
            
            if accession.startswith('GSM'):
                alternative_url = f"https://www.ncbi.nlm.nih.gov/geo/download/?acc={accession}&format=file&file={filename}"
            else:  # GSE
                alternative_url = f"https://www.ncbi.nlm.nih.gov/geo/download/?acc={accession}&format=file"
            
            logger.info(f"Trying alternative URL: {alternative_url}")
            if download_file(alternative_url, output_path):
                return True
    
    return False

def attempt_geoparse_download(geo_id, output_dir):
    """
    Attempt to download supplementary files using GEOparse.

    Only GEO-hosted supplementary files are requested (``download_sra=False``).
    With its default, GEOparse would also try to fetch SRA sequencing data,
    which this pipeline does not use.

    GEOparse places per-sample files in sub-directories of *output_dir*, so
    the result lists every regular file under *output_dir*, searched
    recursively.

    Args:
        geo_id (str): GEO accession ID
        output_dir (str): Directory to save the downloaded files
    Returns:
        list: Sorted paths of files found under *output_dir* after the
        download, or [] if GEOparse is unavailable or failed
    """
    if not have_geoparse:
        logger.warning("GEOparse not available, skipping this download method")
        return []

    try:
        logger.info(f"Attempting to download {geo_id} using GEOparse")
        gse = GEOparse.get_GEO(geo=geo_id, destdir=output_dir)
        gse.download_supplementary_files(directory=output_dir, download_sra=False)
    except Exception as e:
        # GEOparse surfaces network/parse problems as assorted exception types;
        # any failure just means "fall back to the next download method".
        logger.error(f"GEOparse download failed: {e}")
        return []

    pattern = os.path.join(glob.escape(output_dir), "**", "*")
    downloaded_files = sorted(
        path for path in glob.glob(pattern, recursive=True) if os.path.isfile(path)
    )
    logger.info(f"Downloaded {len(downloaded_files)} files using GEOparse")
    return downloaded_files

def attempt_r_geoquery_download(geo_id, output_dir):
    """
    Try to download GEO files using R's GEOquery package.

    Writes ``download_geo.R`` into *output_dir* (created if missing) and runs
    it with ``Rscript``. The script:
      - installs BiocManager/GEOquery if needed;
      - calls ``getGEOSuppFiles`` for the supplementary files;
      - calls ``getGEO`` for the Series Matrix file(s).

    Args:
        geo_id (str): GEO accession ID
        output_dir (str): Directory to save the downloaded files
    Returns:
        bool: True if at least one regular file whose name starts with
        *geo_id* exists anywhere under *output_dir* afterwards, False otherwise
    """
    # Check if R is installed
    try:
        r_version = subprocess.check_output(['R', '--version'], text=True, timeout=60)
    except (OSError, subprocess.SubprocessError):
        logger.warning("R is not available, skipping GEOquery download")
        return False
    version_lines = r_version.strip().splitlines()
    logger.info(f"R is available: {version_lines[0] if version_lines else 'unknown version'}")

    os.makedirs(output_dir, exist_ok=True)
    r_script = os.path.join(output_dir, "download_geo.R")

    # R string literals: use forward slashes (Windows paths) and escape quotes.
    r_output_dir = output_dir.replace('\\', '/').replace('"', '\\"')

    with open(r_script, "w", encoding="utf-8") as f:
        f.write(f'''
# Install BiocManager if not already installed
if (!requireNamespace("BiocManager", quietly = TRUE))
    install.packages("BiocManager", repos = "https://cran.r-project.org")

# Install GEOquery if not already installed
if (!requireNamespace("GEOquery", quietly = TRUE))
    BiocManager::install("GEOquery")

library(GEOquery)

# Set timeout to a larger value
options(timeout = 600)

# Set download methods
options(download.file.method.GEOquery = "auto")

# Download supplementary files
print("Downloading {geo_id} supplementary files...")
try({{
    files <- getGEOSuppFiles("{geo_id}", baseDir = "{r_output_dir}")
    print("Downloaded files:")
    print(files)
}}, silent = FALSE)

# Also try to get the Series Matrix file
print("Downloading {geo_id} Series Matrix files...")
try({{
    gse <- getGEO("{geo_id}", GSEMatrix = TRUE, destdir = "{r_output_dir}")
    print("Downloaded Series Matrix files")
}}, silent = FALSE)
''')

    # Run the R script
    logger.info(f"Running R script to download {geo_id} using GEOquery")
    try:
        result = subprocess.run(['Rscript', r_script], capture_output=True, text=True, timeout=1200)
    except subprocess.TimeoutExpired:
        logger.error("R script timed out after 20 minutes")
        return False
    except OSError as e:  # e.g. Rscript missing even though R is present
        logger.error(f"Error running R script: {e}")
        return False

    logger.info(f"R script output: {result.stdout}")
    if result.stderr:
        logger.warning(f"R script errors: {result.stderr}")
    if result.returncode != 0:
        logger.warning(f"Rscript exited with status {result.returncode}")

    # getGEOSuppFiles saves into a '<geo_id>' sub-directory by default, so
    # search recursively and count only files. A bare (possibly empty)
    # directory named after the accession is not a successful download.
    pattern = os.path.join(glob.escape(output_dir), "**", f"{geo_id}*")
    files = [path for path in glob.glob(pattern, recursive=True) if os.path.isfile(path)]
    if files:
        logger.info(f"GEOquery download successful, found {len(files)} files")
        return True
    logger.warning("GEOquery ran but no files were found")
    return False

def download_geo_supplementary_files(geo_id, output_dir, recursive_samples=True):
    """
    Download all supplementary files for a GEO Series.

    Order of attempts:
      1. GEOparse (attempt_geoparse_download). If it returns files, they are
         returned immediately.
      2. Manual route, GSE accessions only. Download the family SOFT file and
         fetch every supplementary URL it lists, falling back to alternative
         URL formats. That list already includes the per-sample
         ``!Sample_supplementary_file`` entries. download_family_soft_file
         builds a ``geo/series/`` URL, which only exists for GSE accessions,
         so for any other accession (e.g. the GSM IDs of the recursive pass)
         this step is skipped instead of requesting a URL that cannot exist.
      3. With *recursive_samples* and GEOparse installed, each sample listed
         in the SOFT file is also requested through this function into
         ``<output_dir>/<GSM id>/``.

    Args:
        geo_id (str): GEO accession ID
        output_dir (str): Directory to save the downloaded files
        recursive_samples (bool): Whether to also run the per-sample pass (step 3)
    Returns:
        list: Paths of downloaded (or already present) files
    """
    os.makedirs(output_dir, exist_ok=True)
    downloaded_files = []

    # 1. GEOparse
    geoparse_files = attempt_geoparse_download(geo_id, output_dir)
    if geoparse_files:
        downloaded_files.extend(geoparse_files)
        return downloaded_files

    # 2. Manual route via the series family SOFT file
    if not geo_id.startswith('GSE'):
        logger.warning(f"No series family SOFT file exists for {geo_id}; skipping manual download")
        return downloaded_files

    soft_file = download_family_soft_file(geo_id, output_dir)
    if not soft_file:
        logger.error("Failed to download family SOFT file, cannot proceed with manual download")
        return downloaded_files

    for url in extract_supplementary_file_urls(soft_file):
        url = url.strip()
        if not url or url.upper() == 'NONE':  # GEO placeholder for "no file"
            continue
        filename = os.path.basename(url)
        output_path = os.path.join(output_dir, filename)

        if os.path.exists(output_path):
            logger.info(f"File {filename} already exists, skipping download")
            downloaded_files.append(output_path)
            continue

        success = download_file(url, output_path) or try_alternative_url_format(url, output_path)
        if success:
            downloaded_files.append(output_path)

    # 3. Optional per-sample pass (only useful through GEOparse, see step 2)
    if recursive_samples:
        if not have_geoparse:
            logger.info("GEOparse not installed; sample files listed in the series SOFT file were already requested")
            return downloaded_files

        with gzip.open(soft_file, 'rt', encoding='utf-8', errors='replace') as f:
            sample_ids = [
                line.strip().partition(' = ')[2]
                for line in f
                if line.startswith('^SAMPLE = ')
            ]

        logger.info(f"Found {len(sample_ids)} samples, checking for supplementary files")
        for sample_id in sample_ids:
            sample_dir = os.path.join(output_dir, sample_id)
            downloaded_files.extend(
                download_geo_supplementary_files(sample_id, sample_dir, recursive_samples=False)
            )

    return downloaded_files

def _write_dummy_10x_h5(path):
    """Write a minimal 10x-style h5 file: 2 barcodes x 2 'Gene Expression' features."""
    with h5py.File(path, 'w') as f:
        matrix = f.create_group("matrix")
        # indptr runs over the 2 barcodes; indices are feature positions.
        matrix.create_dataset("data", data=np.array([1, 2, 3, 4]))
        matrix.create_dataset("indices", data=np.array([0, 1, 0, 1]))
        matrix.create_dataset("indptr", data=np.array([0, 2, 4]))
        matrix.create_dataset("barcodes", data=[f"CELL_{i}".encode('utf-8') for i in range(2)])

        features = matrix.create_group("features")
        features.create_dataset("id", data=[f"GENE_{i}".encode('utf-8') for i in range(2)])
        features.create_dataset("name", data=[f"Gene_{i}".encode('utf-8') for i in range(2)])
        features.create_dataset("feature_type", data=["Gene Expression".encode('utf-8')] * 2)


def get_h5_files(data_dir, sample_ids=None, create_dummy=False):
    """
    Get the h5 files for the given samples, optionally creating dummy files.

    For each sample, several file-name patterns are tried in order, directly
    inside *data_dir*. The first pattern with any match wins, which prefers
    ``*filtered_feature_bc_matrix.h5`` files over other h5 files.
    Samples without a SAMPLE_METADATA entry are matched by accession only.
    They are logged instead of raising KeyError.

    Args:
        data_dir (str): Directory containing h5 files
        sample_ids (list): Sample IDs to look for (default: all SAMPLE_METADATA keys)
        create_dummy (bool): Whether to write tiny placeholder h5 files for
            samples with no match (testing only). *data_dir* is created if needed.
    Returns:
        list: Paths to h5 files, without duplicates
    """
    if sample_ids is None:
        sample_ids = list(SAMPLE_METADATA.keys())

    search_root = glob.escape(data_dir)  # data_dir itself is not a pattern
    h5_files = []
    missing_files = []

    for sample_id in sample_ids:
        sample_name = SAMPLE_METADATA.get(sample_id, {}).get('sample_name')
        if sample_name is None:
            logger.warning(f"No metadata found for sample {sample_id}; matching files by accession only")
            patterns = [
                f"{sample_id}*filtered_feature_bc_matrix.h5",
                f"{sample_id}*.h5",
            ]
        else:
            patterns = [
                f"{sample_id}_{sample_name}*filtered_feature_bc_matrix.h5",
                f"{sample_id}*filtered_feature_bc_matrix.h5",
                f"{sample_name}*filtered_feature_bc_matrix.h5",
                f"{sample_id}*.h5",
            ]

        matches = []
        for pattern in patterns:
            matches = sorted(glob.glob(os.path.join(search_root, pattern)))
            if matches:
                break

        if not matches:
            missing_files.append(sample_id)
            continue
        for path in matches:
            if path not in h5_files:
                h5_files.append(path)

    if create_dummy and missing_files:
        logger.warning(f"Creating dummy h5 files for {len(missing_files)} samples")
        os.makedirs(data_dir or '.', exist_ok=True)
        for sample_id in missing_files:
            sample_name = SAMPLE_METADATA.get(sample_id, {}).get('sample_name', 'Unknown')
            dummy_file = os.path.join(data_dir, f"{sample_id}_{sample_name}_dummy_matrix.h5")
            _write_dummy_10x_h5(dummy_file)
            logger.info(f"Created dummy h5 file at {dummy_file}")
            h5_files.append(dummy_file)

    return h5_files

# --------------------------------------------------------------------------------
# 4. Reading and processing h5 files
# --------------------------------------------------------------------------------

def read_10x_h5(file_path):
    """
    Read a 10x h5 file and split it into gene- and protein-expression data.

    The ``matrix`` group stores counts in compressed-sparse form, with
    ``indptr`` running over barcodes and ``indices`` giving feature positions
    (features x barcodes in CSC layout). The matrix is loaded with SciPy,
    transposed to barcodes x features and column-sliced by feature type:
      - 'Gene Expression' features form the gene matrix;
      - 'Antibody Capture' features form the protein matrix.

    Fallbacks when there are no 'Antibody Capture' features (kept for the
    testing workflow):
      - a file whose name contains 'dummy' reuses feature 0 as a stand-in
        protein;
      - any other file gets one empty placeholder column named
        'dummy_protein'.

    Args:
        file_path (str): Path to h5 file; its name is expected to start with
            ``<GSM id>_<sample name>_``
    Returns:
        tuple: (gene_data, protein_data, metadata). Each data dict has keys
        'matrix' (CSR, barcodes x features), 'feature_ids', 'feature_names'
        and 'barcodes'. metadata has 'gsm_id' and 'sample_name' parsed from
        the file name. (None, None, None) on any read error.
    """
    logger.info(f"Reading {file_path}")

    try:
        with h5py.File(file_path, 'r') as f:
            matrix_group = f['matrix']
            features_group = matrix_group['features']
            barcodes = [b.decode('utf-8') for b in matrix_group['barcodes'][:]]
            feature_ids = [g.decode('utf-8') for g in features_group['id'][:]]
            feature_names = [g.decode('utf-8') for g in features_group['name'][:]]
            feature_types = [g.decode('utf-8') for g in features_group['feature_type'][:]]
            data = matrix_group['data'][:]
            indices = matrix_group['indices'][:]
            indptr = matrix_group['indptr'][:]

        # Extract sample_id and sample_name from the file name
        base_name = os.path.basename(file_path)
        parts = base_name.split('_')
        metadata = {
            'gsm_id': parts[0],
            'sample_name': parts[1] if len(parts) > 1 else "Unknown",
        }

        is_dummy = 'dummy' in base_name.lower()
        # Explicit integer dtype so empty selections are still valid indexers.
        gene_indices = np.array(
            [i for i, ft in enumerate(feature_types) if ft == 'Gene Expression'], dtype=np.intp)
        protein_indices = np.array(
            [i for i, ft in enumerate(feature_types) if ft == 'Antibody Capture'], dtype=np.intp)

        if protein_indices.size == 0 and is_dummy:
            logger.warning(f"No protein data found in {file_path}, creating dummy protein data")
            protein_indices = np.array([0], dtype=np.intp)  # first feature as a stand-in protein

        # features x barcodes (CSC) -> barcodes x features (CSR); one vectorised
        # build instead of filtering every cell in Python.
        counts = sparse.csc_matrix(
            (data, indices, indptr), shape=(len(feature_ids), len(barcodes))
        ).T.tocsr()

        gene_matrix = counts[:, gene_indices]
        gene_ids = [feature_ids[i] for i in gene_indices]
        gene_names = [feature_names[i] for i in gene_indices]

        if protein_indices.size > 0:
            protein_matrix = counts[:, protein_indices]
            protein_ids = [feature_ids[i] for i in protein_indices]
            protein_names = [feature_names[i] for i in protein_indices]
        else:
            # Single all-zero placeholder column so downstream code always has protein data.
            protein_matrix = sparse.csr_matrix((len(barcodes), 1), dtype=counts.dtype)
            protein_ids = ['dummy_protein']
            protein_names = ['Dummy Protein']

        gene_data = {
            'matrix': gene_matrix,
            'feature_ids': gene_ids,
            'feature_names': gene_names,
            'barcodes': barcodes
        }

        protein_data = {
            'matrix': protein_matrix,
            'feature_ids': protein_ids,
            'feature_names': protein_names,
            'barcodes': barcodes
        }

        return gene_data, protein_data, metadata

    except Exception as e:
        # Per-file boundary: the caller skips unreadable files instead of aborting.
        logger.error(f"Error reading h5 file {file_path}: {e}")
        return None, None, None


def get_sample_metadata(sample_id):
    """
    Build the full annotation record that is attached to every cell of a sample.

    Known accessions start from a copy of their SAMPLE_METADATA entry, which
    is never mutated, and gain derived and constant fields. Unknown accessions
    produce a warning and a record filled with 'Unknown' placeholders.

    Args:
        sample_id (str): GEO sample accession (e.g., GSM8757538)
    Returns:
        dict: Metadata fields, ready to be broadcast into AnnData ``obs``
    """
    known = SAMPLE_METADATA.get(sample_id) if sample_id in SAMPLE_METADATA else None

    if known is None:
        logger.warning(f"No metadata found for sample {sample_id}")
        return dict(
            gsm_id=sample_id,
            organism=ORGANISM,
            sample_name='Unknown',
            sample_type='Unknown',
            immune_age='Unknown',
            cell_type='Bone Marrow Cells',
            cancer_type='Unknown',
            condition='Control',
            crispr_type='None',
            perturbation_name='None',
        )

    record = dict(known)
    is_myeloma = record['sample_type'] == 'MM'
    record.update(
        gsm_id=sample_id,
        organism=ORGANISM,
        cell_type='Bone Marrow Cells',
        cancer_type='Multiple Myeloma' if is_myeloma else 'MGUS',
        condition='Control',
        crispr_type='None',
        perturbation_name='None',
    )
    return record


def create_anndata(data_dict, metadata, data_type):
    """
    Create an AnnData object for one sample from a read_10x_h5 data dict.

    Naming and annotation:
      - Cells are named ``<sample_name>_<barcode>``.
      - Every metadata entry becomes a constant column of ``obs``.
      - For protein data, any ``_(...)`` part is removed from feature names
        first, so names that only become equal after that cleanup are also
        de-duplicated.
      - A name that occurs more than once gets its feature position appended
        (``<name>_<i>``).

    Args:
        data_dict (dict): Data dictionary with matrix, feature_ids, feature_names, barcodes
        metadata (dict): Metadata dictionary (needs 'sample_name')
        data_type (str): Type of data ('gene' or 'protein'); stored in ``uns['data_type']``
    Returns:
        AnnData: AnnData object, or None if *data_dict* is None
    """
    from collections import Counter

    if data_dict is None:
        logger.error(f"Invalid data dictionary for {metadata.get('gsm_id')}")
        return None

    adata = ad.AnnData(X=data_dict['matrix'])
    adata.obs_names = [f"{metadata['sample_name']}_{bc}" for bc in data_dict['barcodes']]

    names = list(data_dict['feature_names'])
    if data_type == 'protein':
        names = [re.sub(r'_\(.*\)', '', name) for name in names]

    # Counter gives O(n) duplicate detection (list.count per name was O(n^2)).
    counts = Counter(names)
    if any(count > 1 for count in counts.values()):
        logger.warning(f"Duplicate feature names found in {metadata['sample_name']}. Making them unique.")
        names = [f"{name}_{i}" if counts[name] > 1 else name for i, name in enumerate(names)]

    adata.var_names = names
    adata.var['feature_id'] = list(data_dict['feature_ids'])

    # Add metadata to obs (scalar values broadcast to every cell)
    for key, value in metadata.items():
        adata.obs[key] = value

    adata.uns['data_type'] = data_type
    return adata


def _unpack_geo_archives(data_dir):
    """
    Unpack downloaded GEO archives so get_h5_files can find the h5 files.

    Two passes:
      - Every ``*_RAW.tar`` under *data_dir* (searched recursively) is
        extracted into *data_dir*. Member paths are flattened to their base
        name, so archive entries cannot write outside *data_dir*. The R
        download route yields such an archive rather than loose h5 files.
      - Every ``*.h5.gz`` directly in *data_dir* is decompressed next to
        itself.

    Existing targets are never overwritten. Outputs go through a ``.part``
    file, so an interrupted run leaves no truncated file behind.
    """
    import shutil

    def _copy_atomically(src, target):
        # Write to a side file and rename, so `target` only ever exists complete.
        partial = target + ".part"
        try:
            with open(partial, 'wb') as dst:
                shutil.copyfileobj(src, dst)
            os.replace(partial, target)
        finally:
            if os.path.exists(partial):
                os.remove(partial)

    root = glob.escape(data_dir)

    for tar_path in sorted(glob.glob(os.path.join(root, "**", "*_RAW.tar"), recursive=True)):
        try:
            with tarfile.open(tar_path) as tar:
                for member in tar.getmembers():
                    if not member.isfile():
                        continue
                    target = os.path.join(data_dir, os.path.basename(member.name))
                    if os.path.exists(target):
                        continue
                    src = tar.extractfile(member)
                    if src is None:
                        continue
                    with src:
                        _copy_atomically(src, target)
        except (tarfile.TarError, OSError) as e:
            logger.warning(f"Could not unpack {tar_path}: {e}")

    for gz_path in sorted(glob.glob(os.path.join(root, "*.h5.gz"))):
        target = gz_path[:-len('.gz')]
        if os.path.exists(target):
            continue
        try:
            with gzip.open(gz_path, 'rb') as src:
                _copy_atomically(src, target)
        except (OSError, EOFError) as e:
            logger.warning(f"Could not decompress {gz_path}: {e}")


def _acquire_h5_files(data_dir, use_dummy):
    """
    Return the dataset's h5 files, downloading and unpacking them if needed.

    Order of attempts:
      1. Use existing files in *data_dir* if any match.
      2. Otherwise try GEOquery, then the Python download chain.
      3. Unpack any downloaded archives and search again.
      4. Fall back to dummy files when *use_dummy* is set.

    Returns None when nothing usable is available and dummy mode is off.
    """
    logger.info(f"Checking for data in {data_dir}")
    h5_files = get_h5_files(data_dir)
    if h5_files:
        return h5_files

    logger.info(f"No h5 files found in {data_dir}. Attempting to download...")
    downloaded_files = []
    if not attempt_r_geoquery_download(ACCESSION_NUMBER, data_dir):
        downloaded_files = download_geo_supplementary_files(ACCESSION_NUMBER, data_dir)
        if not downloaded_files:
            logger.warning("All download methods failed. Using dummy data for testing if requested.")

    _unpack_geo_archives(data_dir)
    h5_files = get_h5_files(data_dir) or [f for f in downloaded_files if f.endswith('.h5')]
    if h5_files:
        return h5_files

    if use_dummy:
        logger.warning("No h5 files available after download. Using dummy data.")
        return get_h5_files(data_dir, create_dummy=True)
    logger.error("No h5 data available and dummy mode not enabled. Cannot proceed.")
    return None


def _infer_sample_id(file_path):
    """Return the GSM accession for an h5 file: file-name prefix, else a SAMPLE_METADATA match."""
    sample_id = os.path.basename(file_path).split('_')[0]
    if sample_id.startswith('GSM'):
        return sample_id
    for gsm_id, info in SAMPLE_METADATA.items():
        if gsm_id in file_path or info['sample_name'] in file_path:
            return gsm_id
    return sample_id


def process_dataset(data_dir, output_dir, use_dummy=False):
    """
    Process the dataset and create h5ad files.

    Steps:
      1. Acquire h5 files (download and unpack if needed).
      2. Build per-sample gene and protein AnnData objects.
      3. Concatenate the per-sample objects.
      4. Write four files to *output_dir*: ``<ACC>_gene.h5ad``,
         ``<ACC>_protein.h5ad``, ``<ACC>_gene_paired.h5ad`` and
         ``<ACC>_protein_paired.h5ad``. The paired files keep only cells
         present in both modalities, in the gene object's cell order.

    Args:
        data_dir (str): Directory containing h5 files
        output_dir (str): Directory to save output files
        use_dummy (bool): Whether to create dummy files if real ones are missing
    Returns:
        tuple: (gene_adata, protein_adata) - the paired AnnData objects, or
        (None, None) on failure
    """
    os.makedirs(output_dir, exist_ok=True)

    try:
        h5_files = _acquire_h5_files(data_dir, use_dummy)
    except Exception as e:
        # Acquisition touches network, R and the filesystem; any failure here
        # is reported and optionally replaced by dummy data.
        logger.error(f"Error during data acquisition: {e}")
        if not use_dummy:
            return None, None
        logger.warning("Using dummy data due to error in data acquisition.")
        h5_files = get_h5_files(data_dir, create_dummy=True)

    if h5_files is None:
        return None, None

    gene_adatas, protein_adatas = [], []
    for file_path in h5_files:
        sample_id = _infer_sample_id(file_path)

        gene_data, protein_data, _ = read_10x_h5(file_path)
        if gene_data is None:
            logger.warning(f"Skipping file {file_path} due to read error")
            continue

        metadata = get_sample_metadata(sample_id)
        gene_adata = create_anndata(gene_data, metadata, 'gene')
        protein_adata = create_anndata(protein_data, metadata, 'protein')

        if gene_adata is not None and protein_adata is not None:
            gene_adatas.append(gene_adata)
            protein_adatas.append(protein_adata)

    if not gene_adatas or not protein_adatas:
        logger.error("No valid data files were processed. Cannot create AnnData objects.")
        return None, None

    # uns_merge='same' keeps uns['data_type'], which concat drops by default.
    logger.info("Concatenating gene expression data")
    gene_adata_combined = ad.concat(gene_adatas, join='outer', merge='same', uns_merge='same')

    logger.info("Concatenating protein expression data")
    protein_adata_combined = ad.concat(protein_adatas, join='outer', merge='same', uns_merge='same')

    for label, combined in (("gene", gene_adata_combined), ("protein", protein_adata_combined)):
        if not combined.var_names.is_unique:
            logger.warning(f"Duplicate {label} names found. Making {label} names unique.")
            combined.var_names_make_unique()
        # Samples sharing a sample_name (e.g. 'Unknown') would give colliding
        # cell names. Both objects receive the same name sequence, so both
        # are renamed identically and stay aligned.
        if not combined.obs_names.is_unique:
            logger.warning(f"Duplicate cell names found in {label} data. Making cell names unique.")
            combined.obs_names_make_unique()

    # Ordered intersection: deterministic cell order, unlike list(set(...)).
    protein_barcodes = set(protein_adata_combined.obs_names)
    common_barcodes = [bc for bc in gene_adata_combined.obs_names if bc in protein_barcodes]

    logger.info(f"Found {len(common_barcodes)} cells with both gene and protein data")

    # .copy() materialises the subsets instead of writing AnnData views.
    gene_adata_paired = gene_adata_combined[common_barcodes].copy()
    protein_adata_paired = protein_adata_combined[common_barcodes].copy()

    # Save
    gene_output_path = os.path.join(output_dir, f"{ACCESSION_NUMBER}_gene.h5ad")
    protein_output_path = os.path.join(output_dir, f"{ACCESSION_NUMBER}_protein.h5ad")
    gene_paired_output_path = os.path.join(output_dir, f"{ACCESSION_NUMBER}_gene_paired.h5ad")
    protein_paired_output_path = os.path.join(output_dir, f"{ACCESSION_NUMBER}_protein_paired.h5ad")

    logger.info(f"Saving gene expression data to {gene_output_path}")
    gene_adata_combined.write(gene_output_path)

    logger.info(f"Saving protein expression data to {protein_output_path}")
    protein_adata_combined.write(protein_output_path)

    logger.info(f"Saving paired gene expression data to {gene_paired_output_path}")
    gene_adata_paired.write(gene_paired_output_path)

    logger.info(f"Saving paired protein expression data to {protein_paired_output_path}")
    protein_adata_paired.write(protein_paired_output_path)

    logger.info("Processing complete")

    return gene_adata_paired, protein_adata_paired

# --------------------------------------------------------------------------------
# 5. Jupyter-friendly functions for running the pipeline
# --------------------------------------------------------------------------------

def run_download_only(data_dir='GSE288020_data'):
    """
    Fetch the raw files for ACCESSION_NUMBER without running any processing.

    The R/GEOquery route (attempt_r_geoquery_download) is tried first. Only
    when it returns a falsy result is download_geo_supplementary_files used
    as a fallback.

    Args:
        data_dir (str): Directory that receives the downloaded files; it is
            created if it does not already exist.
    Returns:
        bool: True if either download route returned a truthy result,
            False if both returned falsy results.
    """
    os.makedirs(data_dir, exist_ok=True)

    # `or` short-circuits, so the fallback downloader runs only when the
    # GEOquery attempt did not succeed.
    fetched = (
        attempt_r_geoquery_download(ACCESSION_NUMBER, data_dir)
        or download_geo_supplementary_files(ACCESSION_NUMBER, data_dir)
    )

    if not fetched:
        logger.warning("All download methods failed")
        return False

    logger.info("Download completed")
    return True

def run_processing(data_dir='GSE288020_data', output_dir='GSE288020_processed', use_dummy=False):
    """
    Run process_dataset end to end and print a short summary of the results.

    Args:
        data_dir (str): Directory holding the input files (or where they are
            downloaded); created if missing.
        output_dir (str): Directory where processed files are written;
            created if missing.
        use_dummy (bool): Passed through to process_dataset; requests dummy
            data when real data is unavailable.
    Returns:
        tuple: (gene_adata, protein_adata) as returned by process_dataset, or
            (None, None) if either of them came back as None.
    """
    for directory in (data_dir, output_dir):
        os.makedirs(directory, exist_ok=True)

    logger.info("Starting data processing")
    gene_adata, protein_adata = process_dataset(data_dir, output_dir, use_dummy=use_dummy)

    # Identity checks only: `None in (...)` would fall back to `==`, which
    # AnnData objects do not support.
    if any(result is None for result in (gene_adata, protein_adata)):
        logger.error("Processing failed")
        return None, None

    # One summary section per modality: (heading, feature label, data).
    summary_sections = (
        ("Gene Expression Data Summary:", "genes", gene_adata),
        ("Protein Expression Data Summary:", "proteins", protein_adata),
    )
    for heading, feature_label, adata in summary_sections:
        print(f"\n{heading}")
        print(f"Number of cells: {adata.n_obs}")
        print(f"Number of {feature_label}: {adata.n_vars}")
        print("Metadata fields:", list(adata.obs.columns))

    return gene_adata, protein_adata

# When executed directly, only print usage hints; the pipeline itself is
# meant to be driven from a notebook via the functions above.
if __name__ == "__main__":
    usage_lines = (
        "This script is meant to be imported and used in a Jupyter notebook.",
        "Example usage:",
        "  from geo_data_download import run_processing",
        "  gene_adata, protein_adata = run_processing(use_dummy=True)",
    )
    for usage_line in usage_lines:
        print(usage_line)
