# In [None]:
import os
import h5py
import numpy as np
import pandas as pd
from scipy import sparse
import anndata as ad
import urllib.request
from pathlib import Path
import gzip
import re

def download_file(url, output_path):
    """
    Download a file from a URL to a specified path, unless it already exists.

    The data is first written to ``<output_path>.part`` and only renamed to
    ``output_path`` once the transfer has finished. An interrupted or failed
    download therefore never leaves a truncated file under the final name
    (which the existence check would otherwise skip on every later run).

    Parameters:
    url (str): Source URL (any scheme supported by urllib.request).
    output_path (str): Destination file path.

    Raises:
    urllib.error.URLError / OSError: propagated from the download; the
    temporary ``.part`` file is removed first.
    """
    if os.path.exists(output_path):
        print(f"File {output_path} already exists, skipping download")
        return

    print(f"Downloading {url} to {output_path}")
    tmp_path = f"{output_path}.part"
    try:
        urllib.request.urlretrieve(url, tmp_path)
        # Atomic on the same filesystem: readers see either no file or the whole file.
        os.replace(tmp_path, output_path)
    finally:
        # Only still present if the download or rename failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    print(f"Downloaded {output_path}")

def download_gse266225_files(data_dir):
    """
    Download GSE266225 supplementary files from NCBI's FTP,
    using the file names from the GEO page.

    The .h5 count matrices are required: any download error propagates.
    The metadata CSVs are optional: if the server answers with an HTTP error
    (e.g. 404) they are skipped with a message instead of aborting the run.
    Each downloaded ``.gz`` file is decompressed next to the original; the
    decompressed copy is written through a temporary ``.part`` file so an
    interrupted run does not leave a truncated file that would be skipped later.

    Parameters:
    data_dir (str): Destination directory (expected to exist already).
    """
    # Local imports keep this block self-contained; both are standard library.
    import shutil
    import urllib.error

    base_url = "https://ftp.ncbi.nlm.nih.gov/geo/series/GSE266nnn/GSE266225/suppl/"
    # NOTE: Updated 2wPC and 8wPC Citeseq file names
    #       to remove "_filtered_" where GEO does not have it.
    required_files = [
        # Multiome data
        "GSE266225_liver_naive_JVG28_feature_bc_matrix.h5",
        "GSE266225_liver_8wPC_JVG29_feature_bc_matrix.h5",
        # CITE-seq data
        "GSE266225_MF_naive_MMU1_filtered_feature_bc_matrix.h5",
        "GSE266225_MF_infected_MMU2_feature_bc_matrix.h5",
        "GSE266225_MF_2wPC_MMU3_feature_bc_matrix.h5",  # removed '_filtered_'
        "GSE266225_MF_8wPC_MMU4_feature_bc_matrix.h5",  # removed '_filtered_'
    ]
    # Metadata (may or may not exist on GEO)
    optional_files = [
        "GSE266225_Metadata_Citeseq_Macrophages.csv.gz",
        "GSE266225_adt_feature_reference.csv.gz",
    ]

    for file in required_files + optional_files:
        url = base_url + file
        output_path = os.path.join(data_dir, file)
        try:
            download_file(url, output_path)
        except urllib.error.HTTPError as err:
            if file not in optional_files:
                raise
            print(f"Optional file {file} not available (HTTP {err.code}), skipping")
            continue

        # Decompress gzipped files if needed
        if output_path.endswith('.gz'):
            uncompressed_path = output_path[:-3]
            if not os.path.exists(uncompressed_path):
                print(f"Decompressing {output_path}")
                tmp_path = uncompressed_path + ".part"
                try:
                    # Stream in chunks instead of loading the whole file into memory.
                    with gzip.open(output_path, 'rb') as f_in, open(tmp_path, 'wb') as f_out:
                        shutil.copyfileobj(f_in, f_out)
                    os.replace(tmp_path, uncompressed_path)
                finally:
                    if os.path.exists(tmp_path):
                        os.remove(tmp_path)
                print(f"Decompressed {output_path} to {uncompressed_path}")

def read_h5_to_anndata(file_path):
    """
    Load a 10x-style ``.h5`` feature-barcode matrix into an AnnData object.

    Reads the ``matrix`` group: barcodes become ``obs_names``, feature names
    become ``var_names``. Feature ids and types are stored in
    ``var["gene_ids"]`` and ``var["feature_types"]``. When the file has no
    ``feature_type`` dataset, every feature is labelled "Gene Expression".
    The stored CSC matrix (features x barcodes) is transposed so that rows
    are cells.
    """
    print(f"Reading {file_path}")

    def as_str(dataset):
        # HDF5 string datasets come back as bytes; decode each entry.
        return [raw.decode('utf-8') for raw in dataset[:]]

    with h5py.File(file_path, 'r') as h5:
        grp = h5['matrix']
        feat = grp['features']

        cell_barcodes = as_str(grp['barcodes'])
        ids = as_str(feat['id'])
        names = as_str(feat['name'])
        if 'feature_type' in feat:
            types = as_str(feat['feature_type'])
        else:
            types = ['Gene Expression'] * len(ids)

        counts = sparse.csc_matrix(
            (grp['data'][:], grp['indices'][:], grp['indptr'][:]),
            shape=grp['shape'][:],
        ).T

    var_frame = pd.DataFrame({"gene_ids": ids, "feature_types": types}, index=names)
    obs_frame = pd.DataFrame(index=cell_barcodes)
    return ad.AnnData(X=counts, obs=obs_frame, var=var_frame)

def clean_protein_names_final(adata):
    """
    Collapse protein features sharing a base name, keeping the first of each.

    The base name is the feature name with one trailing ``_<digits>`` or
    ``-<digits>`` suffix removed. For each base name only its earliest column
    is kept (original column order is preserved), and that column is renamed
    to the base name, so the resulting var_names are unique.

    Parameters:
    adata (AnnData): AnnData object with protein features

    Returns:
    AnnData: New AnnData object (a copy) with unique protein names
    """
    bases = adata.var_names.str.replace(r'[_-][0-9]+$', '', regex=True)
    # duplicated() flags every repeat after the first, so its negation marks
    # exactly the first occurrence of each base name.
    first_seen = ~bases.duplicated(keep='first')

    filtered_adata = adata[:, first_seen].copy()
    filtered_adata.var_names = pd.Index(bases[first_seen].tolist())

    print(f"Reduced from {adata.shape[1]} to {filtered_adata.shape[1]} protein features")
    return filtered_adata

# Alternative to clean_protein_names_final: same first-occurrence rule, no logging.
def simplify_protein_names(adata):
    """
    Keep one feature per base protein name and rename it to that base name.

    The base name is the feature name minus a trailing ``_<digits>`` or
    ``-<digits>`` suffix. Features are scanned left to right; the first feature
    seen for each base name is kept and later ones are dropped.
    WARNING: This will remove data if there are true duplicates.

    Parameters:
    adata (AnnData): AnnData object with protein features

    Returns:
    AnnData: A new AnnData object with simplified var_names
    """
    suffix = re.compile(r'[_-][0-9]+$')

    seen = set()
    kept_positions = []
    kept_names = []
    for pos, name in enumerate(adata.var_names):
        base = suffix.sub('', name)
        if base in seen:
            continue
        seen.add(base)
        kept_positions.append(pos)
        kept_names.append(base)

    filtered_adata = adata[:, kept_positions].copy()
    filtered_adata.var_names = pd.Index(kept_names)
    return filtered_adata

def process_multiome_data(data_dir):
    """
    Process multiome data from GSE266225 (two .h5 files: naive + 8wPC).

    Only features whose type is "Gene Expression" are kept. Each cell gets the
    condition, organism, cell_type, crispr_type, cancer_type,
    perturbation_name and batch annotations.

    Parameters:
    data_dir (str): Directory containing the downloaded .h5 files.

    Returns:
    AnnData: Cells of both conditions concatenated. A gene present in both
    files is a single column. var annotations (gene_ids, feature_types) are
    kept where they agree across files.
    """
    file_paths = {
        "control": os.path.join(data_dir, "GSE266225_liver_naive_JVG28_feature_bc_matrix.h5"),
        "8we treated": os.path.join(data_dir, "GSE266225_liver_8wPC_JVG29_feature_bc_matrix.h5"),
    }
    condition_labels = {
        "control": "Control",
        "8we treated": "8 weeks post curing",
    }

    gene_adatas = []
    for condition, file_path in file_paths.items():
        print(f"Processing multiome condition: {condition}")
        adata = read_h5_to_anndata(file_path)

        # Add condition
        adata.obs["condition"] = condition_labels.get(condition, condition)

        # Add consistent metadata
        adata.obs["organism"] = "Mus musculus"
        adata.obs["cell_type"] = "Liver cells"
        adata.obs["crispr_type"] = "None"
        adata.obs["cancer_type"] = "Non-Cancer"
        adata.obs["perturbation_name"] = "None"

        # Keep only gene expression features
        gene_mask = adata.var["feature_types"] == "Gene Expression"
        gene_adata = adata[:, gene_mask].copy()
        # Only duplicates *within* one file need disambiguating. The same gene
        # in several files must keep one shared name so concat aligns it into
        # a single column; the previous per-file "_<i>" renaming made every
        # shared gene a separate, mostly-empty column.
        gene_adata.var_names_make_unique()

        # Mark batch
        gene_adata.obs["batch"] = condition

        gene_adatas.append(gene_adata)

    # Concatenate
    print("Concatenating multiome gene data...")
    # merge="same" keeps var columns (gene_ids, feature_types) that agree across
    # inputs; concat's default drops all var annotations.
    multi_adata = ad.concat(gene_adatas, join='outer', merge='same')
    # Barcodes can repeat across samples; suffix the repeats.
    multi_adata.obs_names_make_unique()
    multi_adata.var_names_make_unique()

    return multi_adata

def _attach_citeseq_metadata(adata, metadata):
    # Copy all metadata columns onto adata.obs aligned by barcode; return match count.
    """
    Copy per-cell metadata onto ``adata.obs`` in place.

    Parameters:
    adata (AnnData): Object whose obs_names are cell barcodes.
    metadata (pd.DataFrame): Per-cell table indexed by barcode (index must be unique).

    Returns:
    int: Number of ``adata`` barcodes found in ``metadata``. If 0, ``adata``
    is left unchanged. Otherwise every metadata column is added, with NaN for
    cells missing from ``metadata``.
    """
    n_matched = int(adata.obs_names.isin(metadata.index).sum())
    if n_matched:
        # One vectorized alignment instead of a per-cell .loc assignment.
        aligned = metadata.reindex(adata.obs_names)
        for col in metadata.columns:
            adata.obs[col] = aligned[col].to_numpy()
    return n_matched


def _concat_citeseq_modality(adatas):
    # Concatenate per-condition objects of one modality, aligning features by name.
    """
    Concatenate per-condition AnnData objects of the same modality.

    Features with the same name across conditions share one column.
    var annotations that agree across inputs are kept (``merge="same"``).
    Repeated barcodes are suffixed by ``obs_names_make_unique``.
    """
    combined = ad.concat(adatas, join='outer', merge='same')
    combined.obs_names_make_unique()
    combined.var_names_make_unique()
    return combined


def process_citeseq_data(data_dir):
    """
    Process CITE-seq data from GSE266225 (four .h5 files: naive, infected, 2wPC, 8wPC).

    Each file is split into "Gene Expression" and "Antibody Capture" features
    and annotated with the condition, organism, cell_type, crispr_type,
    cancer_type, perturbation_name and batch fields. If the decompressed
    macrophage metadata CSV exists in ``data_dir``, its columns are copied onto
    cells whose barcodes appear in its index.

    Protein feature names are returned as read from the files (made unique
    within each file). No suffix stripping or de-duplication is applied here.

    Parameters:
    data_dir (str): Directory containing the downloaded files.

    Returns:
    tuple: (gene_adata_combined, protein_adata_combined). protein_adata_combined
    is None when no file contains "Antibody Capture" features.
    """
    file_paths = {
        "naive": os.path.join(data_dir, "GSE266225_MF_naive_MMU1_filtered_feature_bc_matrix.h5"),
        "infected": os.path.join(data_dir, "GSE266225_MF_infected_MMU2_feature_bc_matrix.h5"),
        "2wPC": os.path.join(data_dir, "GSE266225_MF_2wPC_MMU3_feature_bc_matrix.h5"),
        "8wPC": os.path.join(data_dir, "GSE266225_MF_8wPC_MMU4_feature_bc_matrix.h5"),
    }
    condition_labels = {
        "naive": "Control",
        "infected": "Infected",
        "2wPC": "2 weeks post curing",
        "8wPC": "8 weeks post curing",
    }

    # Attempt to load macrophage metadata if it exists
    meta_csv = os.path.join(data_dir, "GSE266225_Metadata_Citeseq_Macrophages.csv")
    metadata = None
    if os.path.exists(meta_csv):
        print(f"Loading macrophage metadata from {meta_csv}")
        metadata = pd.read_csv(meta_csv, index_col=0)
        # reindex() requires unique labels; keep the first row for a repeated barcode.
        metadata = metadata[~metadata.index.duplicated(keep='first')]

    gene_list = []
    prot_list = []

    for condition, file_path in file_paths.items():
        print(f"Processing Citeseq condition: {condition}")
        adata = read_h5_to_anndata(file_path)

        adata.obs["condition"] = condition_labels.get(condition, condition)
        adata.obs["organism"] = "Mus musculus"
        adata.obs["cell_type"] = "Macrophages"
        adata.obs["crispr_type"] = "None"
        adata.obs["cancer_type"] = "Non-Cancer"
        adata.obs["perturbation_name"] = "None"
        adata.obs["batch"] = condition

        # Split gene vs protein
        gene_mask = adata.var["feature_types"] == "Gene Expression"
        prot_mask = adata.var["feature_types"] == "Antibody Capture"

        gene_adata = adata[:, gene_mask].copy()
        gene_adata.var_names_make_unique()

        prot_adata = adata[:, prot_mask].copy() if prot_mask.sum() > 0 else None
        if prot_adata is not None:
            prot_adata.var_names_make_unique()

        # NOTE(review): metadata is matched on raw per-file barcodes, so if the
        # same barcode occurs in several conditions all of them receive the
        # same metadata row - confirm the CSV's index format.
        if metadata is not None:
            n_matched = _attach_citeseq_metadata(gene_adata, metadata)
            if n_matched:
                print(f"  Found {n_matched} cells in metadata for {condition}")
                if prot_adata is not None:
                    # Same file, so protein barcodes equal gene barcodes.
                    _attach_citeseq_metadata(prot_adata, metadata)

        gene_list.append(gene_adata)
        if prot_adata is not None:
            prot_list.append(prot_adata)

    # Shared features keep a single name across conditions so concat aligns
    # them. Renaming them per condition produced disjoint, block-diagonal
    # matrices, and later first-occurrence cleaning then kept only condition 0.
    print("Concatenating CITE-seq gene data...")
    gene_adata_combined = _concat_citeseq_modality(gene_list)

    if prot_list:
        print("Concatenating CITE-seq protein data...")
        # NOTE(review): obs_names_make_unique numbers repeats positionally; gene
        # and protein names only line up if every condition has protein data.
        protein_adata_combined = _concat_citeseq_modality(prot_list)
    else:
        protein_adata_combined = None

    return gene_adata_combined, protein_adata_combined

def get_paired_data(gene_adata, protein_adata):
    """
    Restrict gene and protein AnnData objects to their shared cells.

    Parameters:
    gene_adata (AnnData): Gene-expression object.
    protein_adata (AnnData or None): Protein object for (some of) the same cells.

    Returns:
    tuple: (gene_paired, prot_paired). Both are subset to barcodes present in
    both objects, in gene_adata's cell order. obs columns present in both are
    overwritten in prot_paired with gene_paired's values. If protein_adata is
    None, returns (gene_adata, None). If no barcodes overlap, the inputs are
    returned unchanged.
    """
    if protein_adata is None:
        return gene_adata, None

    # Preserve gene_adata's order. Building this list from a set made the cell
    # order depend on Python's per-process string hash randomization, so
    # outputs differed between runs.
    in_both = gene_adata.obs_names.isin(protein_adata.obs_names)
    common_barcodes = gene_adata.obs_names[in_both].tolist()
    print(f"Found {len(common_barcodes)} overlapping barcodes in gene & protein data.")
    if not common_barcodes:
        return gene_adata, protein_adata

    # Subset both
    gene_paired = gene_adata[common_barcodes].copy()
    prot_paired = protein_adata[common_barcodes].copy()

    # Sync metadata columns (aligned on the now-identical obs index)
    for col in gene_paired.obs.columns:
        if col in prot_paired.obs.columns:
            prot_paired.obs[col] = gene_paired.obs[col]

    return gene_paired, prot_paired

def run_pipeline(data_dir: str):
    """
    Run the whole GSE266225 workflow.

    Creates ``data_dir``, downloads the GEO files, builds the multiome and
    CITE-seq objects, pairs CITE-seq gene/protein cells, collapses duplicate
    protein features with clean_protein_names_final, and writes gzip-compressed
    ``.h5ad`` files under ``<data_dir>/harmonized``.
    """
    os.makedirs(data_dir, exist_ok=True)
    # File names in the downloader must match GEO
    download_gse266225_files(data_dir)

    print("\n--- Processing Multiome data ---")
    multiome_gene = process_multiome_data(data_dir)

    print("\n--- Processing CITE-seq data ---")
    citeseq_gene, citeseq_protein = process_citeseq_data(data_dir)

    print("\n--- Getting paired gene-protein from CITE-seq ---")
    gene_paired, protein_paired = get_paired_data(citeseq_gene, citeseq_protein)

    harmonized_dir = os.path.join(data_dir, "harmonized")
    os.makedirs(harmonized_dir, exist_ok=True)

    multiome_path = os.path.join(harmonized_dir, "GSE266225_multiome_gene_expression.h5ad")
    print(f"\nSaving multiome gene to {multiome_path}")
    multiome_gene.write_h5ad(multiome_path, compression="gzip")

    gene_path = os.path.join(harmonized_dir, "GSE266225_citeseq_gene_expression.h5ad")
    print(f"Saving Citeseq gene to {gene_path}")
    gene_paired.write_h5ad(gene_path, compression="gzip")

    protein_unique = None
    if protein_paired is not None:
        print("\n--- Cleaning protein names (removing duplicates) ---")
        # NOTE(review): the cleaner strips one trailing "_<n>"/"-<n>" and keeps the
        # first feature per stripped name, so antibodies differing only by such a
        # numeric suffix collapse to one - confirm the panel has none.
        protein_unique = clean_protein_names_final(protein_paired)

        # Show at most 20 names, with an ellipsis when there are more.
        preview = ", ".join(f"'{name}'" for name in protein_unique.var_names[:20])
        ellipsis = ", ..." if len(protein_unique.var_names) > 20 else ""
        print("Final protein features:")
        print(preview + ellipsis)

        protein_path = os.path.join(harmonized_dir, "GSE266225_citeseq_protein_expression.h5ad")
        print(f"Saving Citeseq protein to {protein_path}")
        protein_unique.write_h5ad(protein_path, compression="gzip")

    print("\n--- Pipeline done! ---")
    print("Multiome gene shape:", multiome_gene.shape)
    print("Citeseq gene shape:", gene_paired.shape)
    if protein_unique is not None:
        print("Original protein shape:", protein_paired.shape)
        print("Cleaned protein shape:", protein_unique.shape)

# --------------------------
# Entry point. A notebook cell also runs with __name__ == "__main__", so the
# pipeline still runs there. Merely importing this module no longer starts the
# multi-gigabyte download.
data_directory = "./GSE266225"  # or another path
if __name__ == "__main__":
    run_pipeline(data_directory)