In [None]:
import os
import gzip
import urllib.request
import numpy as np
import pandas as pd
import scipy.io as sio
import anndata as ad

# Root of the GEO supplementary-file directory for series GSE284526;
# file names from FILES are appended directly to this URL.
BASE_URL = "https://ftp.ncbi.nlm.nih.gov/geo/series/GSE284nnn/GSE284526/suppl/"

# Every archive the pipeline expects: four files (matrix, features, barcodes,
# feature reference) for each of the libraries LBA001..LBA004, grouped by library.
FILES = [
    f"GSE284526_LBA{library:03d}_{suffix}"
    for library in range(1, 5)
    for suffix in (
        "matrix.mtx.gz",
        "features.tsv.gz",
        "barcodes.tsv.gz",
        "feature_reference.csv.gz",
    )
]


def download_files(data_dir):
    """
    Download every file listed in FILES into data_dir, skipping files already present.

    Each file is first written to "<name>.part" and only moved to its final
    name once the transfer has finished. An interrupted or truncated download
    therefore never leaves a file that a later run would mistake for a
    complete one and skip.

    Args:
        data_dir: Path to the data directory (created if it does not exist)

    Raises:
        Exception: whatever the failed download or rename raised, re-raised
            after printing a hint to fetch the file manually.
    """
    os.makedirs(data_dir, exist_ok=True)

    for file_name in FILES:
        file_path = os.path.join(data_dir, file_name)
        if os.path.exists(file_path):
            print(f"File {file_name} already exists, skipping download")
            continue

        partial_path = f"{file_path}.part"
        try:
            print(f"Downloading {file_name}...")
            url = f"{BASE_URL}{file_name}"
            urllib.request.urlretrieve(url, partial_path)
            # Publish the file under its real name only after a complete transfer
            os.replace(partial_path, file_path)
            print(f"Downloaded {file_name}")
        except Exception as e:
            # Best-effort cleanup of the incomplete file; the original error
            # is the one worth reporting, so a failed removal is ignored.
            try:
                os.remove(partial_path)
            except OSError:
                pass
            print(f"Error downloading {file_name}: {e}")
            print("Please download the file manually and place it in the data directory.")
            raise


def make_names_unique(names):
    """
    Make a list of names unique by appending a numeric suffix to duplicates.

    The first occurrence of a name is kept unchanged. Each later occurrence
    becomes "<name>_<k>", using the smallest k >= 1 not yet used for that name
    and not colliding with any name already present in the input or already
    generated. For example ["a", "a", "a_1"] becomes ["a", "a_2", "a_1"]
    rather than producing "a_1" twice.

    Args:
        names: Iterable of hashable names (typically strings)

    Returns:
        List of the same length as the input, with all entries distinct
    """
    names = list(names)
    # Names that must never be generated: every input name plus every
    # suffixed name handed out so far.
    taken = set(names)
    seen = set()          # original names already emitted once
    last_suffix = {}      # name -> last numeric suffix used for it
    unique_names = []

    for name in names:
        if name not in seen:
            seen.add(name)
            unique_names.append(name)
            continue

        suffix = last_suffix.get(name, 0) + 1
        candidate = f"{name}_{suffix}"
        while candidate in taken:
            suffix += 1
            candidate = f"{name}_{suffix}"

        last_suffix[name] = suffix
        taken.add(candidate)
        unique_names.append(candidate)

    return unique_names


def read_mtx_data(data_dir, prefix):
    """
    Load one sample stored as a gzipped matrix/features/barcodes triplet.

    The matrix is transposed after reading so that rows match the barcodes
    file and columns match the features file. Barcodes get "_<prefix>"
    appended so they stay distinct when several samples are combined.

    Args:
        data_dir: Path to the data directory
        prefix: Prefix of the files to read (e.g. GSE284526_LBA001)

    Returns:
        ad.AnnData with a CSR count matrix, var columns 'original_name'
        (plus 'feature_id' and 'feature_type' when the features file has at
        least three columns) and obs column 'sample_id'
    """
    def _path(suffix):
        return os.path.join(data_dir, f"{prefix}_{suffix}")

    # Count matrix, transposed and converted to CSR
    counts = sio.mmread(_path("matrix.mtx.gz")).T.tocsr()

    # Feature table: one tab-separated record per line, no header
    with gzip.open(_path("features.tsv.gz"), 'rt') as handle:
        feature_rows = [line.strip().split('\t') for line in handle]
    feature_df = pd.DataFrame(feature_rows)

    # Barcodes, tagged with the sample prefix
    with gzip.open(_path("barcodes.tsv.gz"), 'rt') as handle:
        barcodes = [f"{line.strip()}_{prefix}" for line in handle]

    adata = ad.AnnData(X=counts, obs=pd.DataFrame(index=barcodes))

    # With three or more columns the layout is taken to be
    # (col0 = ID, col1 = name, col2 = feature type); otherwise col0 is the
    # only usable name column.
    has_full_layout = feature_df.shape[1] >= 3
    name_column = 1 if has_full_layout else 0

    raw_names = pd.Series(feature_df[name_column].values).astype(str)
    adata.var['original_name'] = raw_names.values
    # var_names must be unique, so duplicated names receive a suffix
    adata.var_names = pd.Series(make_names_unique(raw_names))

    if has_full_layout:
        adata.var['feature_id'] = feature_df[0].values
        adata.var['feature_type'] = feature_df[2].values

    # Record which sample every cell came from
    adata.obs['sample_id'] = prefix

    return adata


def split_gene_protein_data(adata):
    """
    Separate a combined AnnData into its gene-expression and protein parts.

    Features are routed by adata.var['feature_type']: 'Gene Expression'
    columns go to the first object, 'Antibody Capture' columns to the second.
    Antibody features whose 'original_name' contains 'Hashtag' are removed
    from the protein object.

    Args:
        adata: AnnData object with gene expression and protein data

    Returns:
        (gene_adata, protein_adata), both independent copies
    """
    feature_types = adata.var['feature_type']

    gene_adata = adata[:, feature_types == 'Gene Expression'].copy()
    protein_adata = adata[:, feature_types == 'Antibody Capture'].copy()

    # Without a name column there is nothing to recognise hashtags by
    if 'original_name' not in protein_adata.var:
        return gene_adata, protein_adata

    is_hashtag = protein_adata.var['original_name'].str.contains('Hashtag', na=False)
    protein_adata = protein_adata[:, ~is_hashtag].copy()

    return gene_adata, protein_adata


def add_metadata(adata, sample_id):
    """
    Add metadata to AnnData object based on sample ID and hashtags.

    For the hashtag-multiplexed libraries (GSE284526_LBA001 / GSE284526_LBA002)
    each cell is assigned to the hashtag feature with the highest count and
    that hashtag is mapped to a donor label in obs['sample']. Cells with no
    hashtag counts at all are labelled 'Unknown' instead of defaulting to the
    first hashtag column, which is what a bare idxmax would do. Every other
    library, or a multiplexed library without hashtag features, gets
    obs['sample'] = sample_id.

    The columns 'organism', 'cell_type', 'crispr_type', 'cancer_type',
    'condition' and 'perturbation_name' are then filled in.

    Args:
        adata: AnnData object; modified in place
        sample_id: Sample ID (e.g. GSE284526_LBA001, GSE284526_LBA002, etc.)

    Returns:
        The same adata object with added obs columns
    """
    # Hashtag -> donor label for each multiplexed library
    donors_by_library = {
        'GSE284526_LBA001': {'Hashtag1': 'ATRIP', 'Hashtag2': 'HC1'},
        'GSE284526_LBA002': {'Hashtag3': 'HC2', 'Hashtag4': 'HC3'},
    }
    donor_map = donors_by_library.get(sample_id)

    # Column positions of hashtag features (empty when not applicable)
    hashtag_positions = np.array([], dtype=int)
    if donor_map is not None and 'original_name' in adata.var:
        is_hashtag = adata.var['original_name'].str.contains('Hashtag', na=False)
        hashtag_positions = np.flatnonzero(is_hashtag.to_numpy(dtype=bool))

    if hashtag_positions.size > 0:
        # Slice the count matrix directly; works for sparse and dense X
        counts = adata.X[:, hashtag_positions]
        if hasattr(counts, 'toarray'):
            counts = counts.toarray()

        hashtag_counts = pd.DataFrame(
            np.asarray(counts),
            index=adata.obs_names,
            columns=adata.var['original_name'].values[hashtag_positions]
        )

        # Each cell assigned to the hashtag with the highest count; a cell
        # whose hashtag counts are all zero carries no evidence for any
        # donor and must not fall through to the first column.
        assigned = hashtag_counts.idxmax(axis=1)
        assigned = assigned.where(hashtag_counts.max(axis=1) > 0, 'Unknown')
        adata.obs['hashtag'] = assigned

        adata.obs['sample'] = adata.obs['hashtag'].map(donor_map).fillna('Unknown')
    else:
        # Not multiplexed (the original notes LBA003/LBA004 as B-cell
        # purified) or no hashtag features found: one sample per library
        adata.obs['sample'] = sample_id

    # Standard metadata
    adata.obs['organism'] = 'Homo sapiens'
    adata.obs['cell_type'] = 'Unknown'
    adata.obs['crispr_type'] = 'None'
    adata.obs['cancer_type'] = 'Non-Cancer'

    # Add condition mapping
    adata.obs['condition'] = adata.obs['sample'].map({
        'ATRIP': 'ATRIP-deficient',
        'HC1': 'Healthy Control',
        'HC2': 'Healthy Control',
        'HC3': 'Healthy Control',
        'GSE284526_LBA003': 'B cells',
        'GSE284526_LBA004': 'B cells'
    }).fillna('Unknown')

    adata.obs['perturbation_name'] = 'None'

    # Refine cell_type from the library the cells came from
    if 'LBA003' in sample_id or 'LBA004' in sample_id:
        adata.obs['cell_type'] = 'B cells'
    elif 'LBA001' in sample_id or 'LBA002' in sample_id:
        adata.obs['cell_type'] = 'PBMCs'

    return adata


def harmonize_data(data_dir):
    """
    Build harmonized gene and protein AnnData objects for GSE284526.

    Steps: download any missing files, read and annotate the four libraries,
    concatenate them, split the result by feature type, keep the cells
    present in both modalities, then rebuild var_names from the original
    gene symbols and from the protein names truncated at the first underscore.

    Args:
        data_dir: Path to the data directory

    Returns:
        (gene_adata, protein_adata)
    """
    # 1. Make sure all input files are on disk
    download_files(data_dir)

    # 2. Load and annotate every library
    per_sample = []
    for sample_prefix in (
        'GSE284526_LBA001',
        'GSE284526_LBA002',
        'GSE284526_LBA003',
        'GSE284526_LBA004',
    ):
        print(f"\nProcessing {sample_prefix}...")
        loaded = read_mtx_data(data_dir, sample_prefix)
        per_sample.append(add_metadata(loaded, sample_prefix))

    # 3. Stack the libraries into one object
    print("\nMerging samples...")
    merged_adata = ad.concat(per_sample, join='outer', merge='same')

    # 4. One object per modality
    print("Splitting gene expression and protein data...")
    gene_adata, protein_adata = split_gene_protein_data(merged_adata)

    # 5. Restrict both objects to the cells they share.
    #    np.intersect1d returns sorted unique names, so the cells end up in
    #    sorted order.
    print("Filtering for paired data...")
    common_cells = np.intersect1d(gene_adata.obs_names, protein_adata.obs_names)
    print(f"Found {len(common_cells)} cells with both gene and protein data.")

    if len(common_cells) == 0:
        print("Warning: No overlapping cells found. Keeping all cells in separate objects.")
    else:
        gene_adata = gene_adata[common_cells].copy()
        protein_adata = protein_adata[common_cells].copy()

        # Copy the gene object's values into every obs column both share
        for column in gene_adata.obs.columns:
            if column in protein_adata.obs.columns:
                protein_adata.obs[column] = gene_adata.obs[column]

    # 6. Gene var_names: original symbol, else feature ID, else current name
    print("Cleaning gene names...")
    if 'feature_id' in gene_adata.var:

        def _symbol_for(var_name, record):
            if 'original_name' in record and pd.notna(record['original_name']):
                return record['original_name']
            fallback_id = record['feature_id']
            return fallback_id if isinstance(fallback_id, str) else var_name

        gene_symbols = [
            _symbol_for(var_name, record)
            for var_name, record in gene_adata.var.iterrows()
        ]

        # Keep the original identifiers under an explicit column name
        gene_adata.var['ensembl_id'] = gene_adata.var['feature_id']

        gene_adata.var_names = make_names_unique(gene_symbols)

    # 7. Protein var_names: drop everything from the first underscore on
    print("Cleaning protein names...")
    if 'original_name' in protein_adata.var:
        trimmed_names = [
            raw.split("_", 1)[0] if "_" in raw else raw
            for raw in protein_adata.var['original_name'].tolist()
        ]
        protein_adata.var_names = make_names_unique(trimmed_names)

    return gene_adata, protein_adata


def main(data_dir):
    """
    Run the harmonization and write both results to data_dir as .h5ad files.

    Args:
        data_dir: Directory holding the input files; the outputs
            GSE284526_gene_expression.h5ad and
            GSE284526_protein_expression.h5ad are written there as well.
    """
    gene_adata, protein_adata = harmonize_data(data_dir)

    # Output locations sit next to the inputs
    gene_output_file = os.path.join(data_dir, "GSE284526_gene_expression.h5ad")
    protein_output_file = os.path.join(data_dir, "GSE284526_protein_expression.h5ad")

    print(f"\nSaving gene expression data to {gene_output_file}...")
    gene_adata.write_h5ad(gene_output_file)

    print(f"Saving protein expression data to {protein_output_file}...")
    protein_adata.write_h5ad(protein_output_file)

    # Report the dimensions of what was written
    gene_cells, gene_count = gene_adata.shape
    protein_cells, protein_count = protein_adata.shape
    print("Done!\n")
    print(f"Gene expression data: {gene_cells} cells, {gene_count} genes")
    print(f"Protein expression data: {protein_cells} cells, {protein_count} proteins")


# Notebook entry point: runs the whole pipeline as soon as this cell/module is
# executed. main() downloads any missing inputs into the given directory and
# writes the two .h5ad outputs there.
# NOTE(review): "/content/..." looks like a Google Colab path — adjust for
# other environments.
main("/content/GSE284526")

# To get the in-memory AnnData objects without writing the .h5ad files, call
# harmonize_data directly instead:
# gene_adata, protein_adata = harmonize_data("/path/to/data_dir")
