# In [None]:
import os
import sys
import glob
import gzip
import urllib.request
import pandas as pd
import numpy as np
import h5py
import scanpy as sc
import anndata as ad
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Union

# Constants
# GEO series accession that every file name and URL below is derived from.
ACCESSION = "GSE270795"
# The series directory is the accession with its last three digits replaced by
# "nnn" (GSE270795 -> GSE270nnn); derive it so it cannot drift from ACCESSION.
BASE_URL = f"https://ftp.ncbi.nlm.nih.gov/geo/series/{ACCESSION[:-3]}nnn/{ACCESSION}/suppl/"
# Glob patterns for the supplementary files the pipeline works with.
FILE_PATTERNS = [
    f"{ACCESSION}_BC*_Tet*_sample_feature_bc_matrix.h5",
    f"{ACCESSION}_BC*_Tet*_filtered_contig_annotations.csv.gz",
    f"{ACCESSION}_TS_TotalSeq_01_features.csv.gz",
]

# Metadata mapping
# Scalar entries are copied onto every cell; "condition" maps the condition
# token parsed from a file name to its harmonized label.
METADATA_MAPPING = {
    "organism": "Homo sapiens",  # Based on the dataset description
    "cell_type": "MAIT cells",   # Based on the dataset description
    "crispr_type": "None",       # Not applicable for this dataset
    "cancer_type": "Non-Cancer", # Based on the dataset description (healthy donors)
    "condition": {
        "Tet": "Control",
        "Tet_IL12": "IL-12 stimulated",
        "Tet_IL23": "IL-23 stimulated",
    },
    "perturbation_name": "None",  # No CRISPR perturbations in this dataset
}

def clean_protein_name(name: str) -> str:
    """
    Normalise an antibody/protein feature name.

    Steps, applied in this order:
      1. Names containing 'Ctrl' are truncated at their first underscore.
      2. The first matching prefix among 'anti-mouse-', 'antihuman', 'anti-'
         is dropped (at most one prefix is removed).
      3. A trailing 'isotype' (any letter case) is dropped.
      4. Surrounding whitespace is trimmed.
    """
    known_prefixes = ("anti-mouse-", "antihuman", "anti-")
    trailing_tag = "isotype"

    cleaned = name.partition("_")[0] if "Ctrl" in name else name

    # The tuple is ordered most-specific first so "anti-mouse-" is preferred
    # over the generic "anti-"; an empty match leaves the name untouched.
    matched_prefix = next(
        (candidate for candidate in known_prefixes if cleaned.startswith(candidate)),
        "",
    )
    cleaned = cleaned[len(matched_prefix):]

    if cleaned.lower().endswith(trailing_tag):
        cleaned = cleaned[:-len(trailing_tag)]

    return cleaned.strip()

def _download_to(file_url: str, file_path: str) -> bool:
    """
    Fetch ``file_url`` into ``file_path`` and report whether it succeeded.

    The payload is written to a ``.part`` sibling and renamed into place only
    once the transfer finished and produced a non-empty file, so an
    interrupted download never leaves something a later run would mistake
    for a complete file.

    Args:
        file_url: URL to download
        file_path: Final destination path

    Returns:
        True if ``file_path`` now holds the downloaded content, else False
    """
    import http.client
    import shutil
    import subprocess

    partial_path = f"{file_path}.part"
    try:
        try:
            with urllib.request.urlopen(file_url, timeout=60) as response, \
                    open(partial_path, "wb") as out:
                shutil.copyfileobj(response, out)
        except (OSError, http.client.HTTPException) as e:
            print(f"Error downloading {os.path.basename(file_path)}: {e}")
            # Fall back to wget when it is installed; arguments are passed as
            # a list so paths containing spaces or shell metacharacters are safe.
            wget = shutil.which("wget")
            if wget is None:
                return False
            completed = subprocess.run(
                [wget, "-q", "-O", partial_path, file_url], check=False
            )
            if completed.returncode != 0:
                return False

        if os.path.getsize(partial_path) == 0:
            return False
        os.replace(partial_path, file_path)
        return True
    finally:
        # Only reached with the temp file still present when the download failed.
        if os.path.exists(partial_path):
            os.remove(partial_path)


def download_files(data_dir: str) -> None:
    """
    Download every dataset file that is not already present in ``data_dir``.

    Each expected file is checked individually, so a directory that holds only
    part of the dataset (or zero-byte leftovers of failed transfers) is
    completed rather than skipped. Failures are reported but do not raise;
    downloading stays best-effort.

    Args:
        data_dir: Directory to save the downloaded files
    """
    os.makedirs(data_dir, exist_ok=True)

    # Per-sample files: 3 donors x 3 conditions x 2 file kinds, plus the
    # shared TotalSeq feature reference.
    files_to_download = [
        f"{ACCESSION}_{donor}_{condition}_{suffix}"
        for donor in ("BC2", "BC3", "BC4")
        for condition in ("Tet_IL12", "Tet_IL23", "Tet")
        for suffix in ("filtered_contig_annotations.csv.gz", "sample_feature_bc_matrix.h5")
    ]
    files_to_download.append(f"{ACCESSION}_TS_TotalSeq_01_features.csv.gz")

    def is_missing(file_name: str) -> bool:
        file_path = os.path.join(data_dir, file_name)
        return not os.path.isfile(file_path) or os.path.getsize(file_path) == 0

    missing = [file_name for file_name in files_to_download if is_missing(file_name)]
    if not missing:
        print(f"Found {len(files_to_download)} files in {data_dir}")
        return

    print(f"{len(missing)} of {len(files_to_download)} files missing in {data_dir}. Downloading...")
    failed = []
    for file_name in missing:
        print(f"Downloading {file_name}...")
        if _download_to(f"{BASE_URL}{file_name}", os.path.join(data_dir, file_name)):
            print(f"Downloaded {file_name}")
        else:
            print(f"Failed to download {file_name}")
            failed.append(file_name)

    if failed:
        print(f"{len(failed)} file(s) could not be downloaded: {', '.join(failed)}")

def _build_modality_adata(counts, cell_metadata: pd.DataFrame,
                          feature_metadata: pd.DataFrame,
                          feature_type: str) -> Optional[ad.AnnData]:
    """
    Build an AnnData holding only the features of one ``feature_type``.

    Args:
        counts: Sparse cells x features matrix covering all features
        cell_metadata: Per-cell annotations (copied, so modalities never share obs)
        feature_metadata: Per-feature table with a default positional index,
            aligned with the columns of ``counts``
        feature_type: Value of the 'feature_type' column to keep

    Returns:
        The AnnData with var_names set to the feature names, or None when no
        feature has the requested type
    """
    selected = feature_metadata[feature_metadata['feature_type'] == feature_type]
    if selected.empty:
        return None

    adata = ad.AnnData(
        # The positional index of feature_metadata doubles as the column index.
        X=counts[:, selected.index.to_numpy()],
        obs=cell_metadata.copy(),
        var=selected.reset_index(drop=True),
    )
    adata.var_names = adata.var['name'].values
    adata.var.index.name = None
    return adata


def load_h5_data(file_path: str) -> Tuple[Optional[ad.AnnData], Optional[ad.AnnData]]:
    """
    Load data from a 10x h5 file and split into gene expression and protein data.

    Donor and condition are parsed from the file name
    (``<accession>_<donor>_<condition>_...``) and attached to every cell along
    with the harmonized fields from METADATA_MAPPING.

    Args:
        file_path: Path to the h5 file

    Returns:
        Tuple of (gene_expression_adata, protein_adata); an element is None
        when the file has no feature of that type
    """
    from scipy import sparse

    print(f"Processing {os.path.basename(file_path)}...")

    # Extract sample info from filename
    filename = os.path.basename(file_path)
    parts = filename.split('_')
    donor = parts[1]  # BC2, BC3, BC4

    # Handle different condition formats
    if len(parts) > 3 and parts[3] in ["IL12", "IL23"]:
        condition = f"{parts[2]}_{parts[3]}"  # Tet_IL12, Tet_IL23
    else:
        condition = parts[2]  # Tet

    # Read everything into memory first so the file is closed before processing.
    with h5py.File(file_path, 'r') as f:
        matrix = f['matrix']
        features = matrix['features']
        feature_metadata = pd.DataFrame({
            'id': [x.decode('utf-8') for x in features['id'][:]],
            'name': [x.decode('utf-8') for x in features['name'][:]],
            'feature_type': [x.decode('utf-8') for x in features['feature_type'][:]],
        })
        barcodes = [x.decode('utf-8') for x in matrix['barcodes'][:]]
        data = matrix['data'][:]
        indices = matrix['indices'][:]
        indptr = matrix['indptr'][:]

    # indptr delimits one slice per barcode and indices hold feature positions,
    # so the three arrays are directly a CSR matrix of cells x features. This
    # replaces a Python loop over every non-zero entry.
    counts = sparse.csr_matrix(
        (data, indices, indptr),
        shape=(len(barcodes), len(feature_metadata)),
    )

    # Create cell metadata
    cell_metadata = pd.DataFrame(index=barcodes)
    cell_metadata['donor'] = donor
    cell_metadata['condition_raw'] = condition
    cell_metadata['sample'] = f"{donor}_{condition}"

    # Add harmonized metadata
    cell_metadata['organism'] = METADATA_MAPPING['organism']
    cell_metadata['cell_type'] = METADATA_MAPPING['cell_type']
    cell_metadata['crispr_type'] = METADATA_MAPPING['crispr_type']
    cell_metadata['cancer_type'] = METADATA_MAPPING['cancer_type']
    cell_metadata['condition'] = cell_metadata['condition_raw'].map(METADATA_MAPPING['condition'])
    cell_metadata['perturbation_name'] = METADATA_MAPPING['perturbation_name']

    # Gene expression: var_names are gene symbols, de-duplicated if needed.
    gene_adata = _build_modality_adata(counts, cell_metadata, feature_metadata, 'Gene Expression')
    if gene_adata is None:
        print("No gene expression features found")
    else:
        if gene_adata.var_names.duplicated().any():
            gene_adata.var_names_make_unique()
        print(f"Created gene expression AnnData: {gene_adata.shape[0]} cells, {gene_adata.shape[1]} genes")

    # Protein: var_names are cleaned antibody names, de-duplicated if needed
    # (cleaning can map distinct raw names to the same string).
    protein_adata = _build_modality_adata(counts, cell_metadata, feature_metadata, 'Antibody Capture')
    if protein_adata is None:
        print("No protein expression features found")
    else:
        protein_adata.var_names = [clean_protein_name(name) for name in protein_adata.var_names]
        if protein_adata.var_names.duplicated().any():
            protein_adata.var_names_make_unique()
        print(f"Created protein expression AnnData: {protein_adata.shape[0]} cells, {protein_adata.shape[1]} proteins")

    return gene_adata, protein_adata

def load_vdj_data(file_path: str) -> pd.DataFrame:
    """
    Read a gzip-compressed contig annotation CSV and tag it with its sample.

    Donor and condition come from the underscore-separated file name: the
    second token is the donor, the third is the condition, and a fourth token
    of 'IL12' or 'IL23' is appended to the condition.

    Args:
        file_path: Path to the filtered contig annotations file

    Returns:
        DataFrame with the file's rows plus 'donor', 'condition' and 'sample'
        columns
    """
    contigs = pd.read_csv(file_path, compression='gzip')

    tokens = os.path.basename(file_path).split('_')
    donor = tokens[1]  # BC2, BC3, BC4

    has_cytokine = len(tokens) > 3 and tokens[3] in ("IL12", "IL23")
    # Tet_IL12 / Tet_IL23 when a cytokine token follows, otherwise plain Tet
    condition = f"{tokens[2]}_{tokens[3]}" if has_cytokine else tokens[2]

    sample_fields = (
        ('donor', donor),
        ('condition', condition),
        ('sample', f"{donor}_{condition}"),
    )
    for column, value in sample_fields:
        contigs[column] = value

    return contigs

# VDJ columns summarised per cell, in the order they are added to obs.
_VDJ_OBS_COLS = ['chain', 'v_gene', 'j_gene', 'c_gene', 'cdr3',
                 'productive', 'full_length', 'raw_clonotype_id']
# Columns holding per-contig flags; every other VDJ column is treated as text.
_VDJ_FLAG_COLS = ('productive', 'full_length')


def _original_barcode(obs_name: str) -> str:
    """Strip the per-file ``_<i>`` suffix that process_dataset appends to barcodes."""
    return obs_name.rsplit('_', 1)[0]


def _join_unique(values: pd.Series):
    """Comma-join the sorted distinct values, or NaN if any value is empty."""
    if (values == '').any():
        return np.nan
    return ','.join(sorted(set(values)))


def _all_true(values: pd.Series):
    """Return whether every value is truthy, or NaN if any value is missing."""
    # NOTE(review): if a flag column is read as strings rather than booleans,
    # every non-empty string counts as truthy here - confirm the column dtype.
    if pd.isna(values).any():
        return np.nan
    return all(values)


def _summarize_vdj(vdj_data: pd.DataFrame) -> pd.DataFrame:
    """Collapse contig rows into one row per (sample, barcode)."""
    vdj_data = vdj_data.copy()
    text_cols = [col for col in _VDJ_OBS_COLS if col not in _VDJ_FLAG_COLS]
    for col in text_cols:
        vdj_data[col] = vdj_data[col].fillna('').astype(str)

    aggregations = {
        col: (_all_true if col in _VDJ_FLAG_COLS else _join_unique)
        for col in _VDJ_OBS_COLS
    }
    # Group on sample as well as barcode: the same barcode sequence can occur
    # in several samples and must not be merged across them.
    return vdj_data.groupby(['sample', 'barcode']).agg(aggregations)


def _attach_vdj(adata: ad.AnnData, vdj_summary: pd.DataFrame) -> None:
    """Add ``vdj_*`` columns to ``adata.obs``, matched on (sample, original barcode)."""
    cell_keys = pd.MultiIndex.from_arrays([
        adata.obs['sample'].astype(str).to_numpy(),
        [_original_barcode(obs_name) for obs_name in adata.obs_names],
    ])
    aligned = vdj_summary.reindex(cell_keys)
    for col in _VDJ_OBS_COLS:
        values = aligned[col].to_numpy()
        if col in _VDJ_FLAG_COLS:
            # Nullable boolean keeps missing values without producing an
            # object column mixing bool and NaN, which is presumably not
            # serialisable when the AnnData is written to h5ad.
            values = pd.array(values, dtype="boolean")
        adata.obs[f'vdj_{col}'] = values


def process_dataset(data_dir: str) -> Tuple[Optional[ad.AnnData], Optional[ad.AnnData]]:
    """
    Process the entire dataset and create harmonized AnnData objects.

    Loads every per-sample h5 file, concatenates gene and protein modalities
    across samples, annotates cells with a per-cell VDJ summary, and keeps
    only cells present in both modalities (when at least one such cell exists).

    Args:
        data_dir: Directory containing the dataset files

    Returns:
        Tuple of (gene_expression_adata, protein_adata); an element is None
        when no file provided that modality
    """
    # Sorted so the per-file suffix assigned below is reproducible across runs.
    h5_files = sorted(glob.glob(os.path.join(data_dir, f"{ACCESSION}_BC*_Tet*_sample_feature_bc_matrix.h5")))
    vdj_files = sorted(glob.glob(os.path.join(data_dir, f"{ACCESSION}_BC*_Tet*_filtered_contig_annotations.csv.gz")))

    # Load protein features (only reported; not used further in this function)
    protein_features_file = os.path.join(data_dir, f"{ACCESSION}_TS_TotalSeq_01_features.csv.gz")
    if os.path.exists(protein_features_file):
        protein_features = pd.read_csv(protein_features_file, compression='gzip')
        print(f"Loaded {len(protein_features)} protein features")
    else:
        print("Protein features file not found")

    # Process each h5 file. Barcodes get a per-file suffix so observation names
    # stay unique after concatenation; using the file index for both
    # modalities keeps the names of one cell identical across them.
    gene_adatas = []
    protein_adatas = []
    for i, h5_file in enumerate(h5_files):
        gene_adata, protein_adata = load_h5_data(h5_file)
        for adata, bucket in ((gene_adata, gene_adatas), (protein_adata, protein_adatas)):
            if adata is not None:
                adata.obs_names = [f"{obs_name}_{i}" for obs_name in adata.obs_names]
                bucket.append(adata)

    # Process VDJ data
    vdj_frames = [load_vdj_data(vdj_file) for vdj_file in vdj_files]
    if vdj_frames:
        vdj_data_combined = pd.concat(vdj_frames, ignore_index=True)
        n_vdj_cells = len(vdj_data_combined[['sample', 'barcode']].drop_duplicates())
        print(f"Loaded VDJ data for {n_vdj_cells} cells")
    else:
        vdj_data_combined = None

    # Combine gene expression data
    combined_gene_adata = None
    if gene_adatas:
        combined_gene_adata = ad.concat(gene_adatas, join='outer', merge='same')
        print(f"Combined gene expression data: {combined_gene_adata.shape[0]} cells, {combined_gene_adata.shape[1]} genes")

        if combined_gene_adata.var_names.duplicated().any():
            print(f"Found {combined_gene_adata.var_names.duplicated().sum()} duplicate gene names")
            # 'id' is only kept by merge='same' when it agrees across samples.
            if 'id' in combined_gene_adata.var.columns:
                combined_gene_adata.var_names = (
                    combined_gene_adata.var_names.astype(str)
                    + '_'
                    + combined_gene_adata.var['id'].astype(str).to_numpy()
                )
            if combined_gene_adata.var_names.duplicated().any():
                combined_gene_adata.var_names_make_unique()

    # Combine protein data
    combined_protein_adata = None
    if protein_adatas:
        combined_protein_adata = ad.concat(protein_adatas, join='outer', merge='same')
        print(f"Combined protein data: {combined_protein_adata.shape[0]} cells, {combined_protein_adata.shape[1]} proteins")

        if combined_protein_adata.var_names.duplicated().any():
            print(f"Found {combined_protein_adata.var_names.duplicated().sum()} duplicate protein names")
            combined_protein_adata.var_names_make_unique()

    # Add VDJ data to gene expression and protein data. The lookup uses the
    # original barcode: observation names already carry the per-file suffix
    # and would never match the barcodes in the contig files.
    if vdj_data_combined is not None:
        vdj_summary = _summarize_vdj(vdj_data_combined)
        for adata in (combined_gene_adata, combined_protein_adata):
            if adata is not None:
                _attach_vdj(adata, vdj_summary)

    # Filter to keep only paired data (cells present in both gene and protein data)
    if combined_gene_adata is not None and combined_protein_adata is not None:
        def cell_keys(adata):
            return list(zip(
                adata.obs['sample'].astype(str),
                (_original_barcode(obs_name) for obs_name in adata.obs_names),
            ))

        protein_name_by_key = dict(zip(cell_keys(combined_protein_adata), combined_protein_adata.obs_names))
        # Walk the gene cells in order so the output order is deterministic.
        paired = [
            (gene_name, protein_name_by_key[key])
            for key, gene_name in zip(cell_keys(combined_gene_adata), combined_gene_adata.obs_names)
            if key in protein_name_by_key
        ]
        print(f"Found {len(paired)} cells with both gene expression and protein data")

        if paired:
            gene_common_barcodes = [gene_name for gene_name, _ in paired]
            protein_common_barcodes = [protein_name for _, protein_name in paired]

            combined_gene_adata = combined_gene_adata[gene_common_barcodes].copy()
            combined_protein_adata = combined_protein_adata[protein_common_barcodes].copy()

            combined_gene_adata.obs['original_barcode'] = [_original_barcode(bc) for bc in combined_gene_adata.obs.index]
            combined_protein_adata.obs['original_barcode'] = [_original_barcode(bc) for bc in combined_protein_adata.obs.index]

    return combined_gene_adata, combined_protein_adata

def main(data_dir: Optional[str] = None):
    """
    Run the full pipeline: fetch missing files, build the AnnData objects and
    write each available modality to ``<data_dir>/processed`` as gzip h5ad.

    Args:
        data_dir: Working directory for inputs and outputs; the current
            working directory is used when omitted
    """
    workdir = os.getcwd() if data_dir is None else data_dir

    download_files(workdir)
    gene_adata, protein_adata = process_dataset(workdir)

    output_dir = os.path.join(workdir, "processed")
    os.makedirs(output_dir, exist_ok=True)

    # A modality that could not be built is None and is simply not written.
    for modality, adata in (("gene", gene_adata), ("protein", protein_adata)):
        if adata is None:
            continue
        output_file = os.path.join(output_dir, f"{ACCESSION}_{modality}_expression_harmonized.h5ad")
        adata.write_h5ad(output_file, compression="gzip")
        print(f"Saved {modality} expression data to {output_file}")

# Run the pipeline when executed as a script or as a Jupyter Notebook cell
# (both run with __name__ == "__main__"), but not when the module is imported.
if __name__ == "__main__":
    main()
