# In [None]:
import os
import gzip
import urllib.request
import pandas as pd
import numpy as np
import scanpy as sc
import anndata as ad
from scipy import sparse
from pathlib import Path
from typing import Dict, List, Tuple
import time
from concurrent.futures import ThreadPoolExecutor
import re

# Dataset constants
GEO_ACCESSION = "GSE261278"

# The series directory on the FTP server is the accession with its last three
# digits replaced by "nnn" (GSE261278 -> GSE261nnn).
BASE_URL = (
    "https://ftp.ncbi.nlm.nih.gov/geo/series/"
    f"{GEO_ACCESSION[:-3]}nnn/{GEO_ACCESSION}/suppl/"
)

# Sample identifiers used to build the supplementary matrix file names.
SAMPLE_IDS = [
    "D534",
    "D559",
    "D561",
    "D564",
    "D568",
    "D574",
    "D616",
    "FLD032FLD040FLD015FLD041FLD029FLD030",
]

def simplify_protein_name(name: str) -> str:
    """
    Simplify an antibody/protein label to a short marker name.

    The label is processed in this order:
      1. Parenthesised text (and any whitespace directly before it) is removed.
      2. If the label starts with "anti-", everything up to and including the
         first underscore is dropped.
      3. The first underscore-separated token that starts with "CD" or "Ig"
         (after removing trailing commas) is returned.
      4. If no such token exists, the remaining label is returned with
         surrounding whitespace and leftover underscores stripped.

    Examples:
        "anti-human_CD101_(BB27)"  -> "CD101"
        "anti-human_HLA-DR_(L243)" -> "HLA-DR"
    """
    # Remove any parentheses and content inside them
    name = re.sub(r'\s*\(.*?\)', '', name)

    # Drop the "anti-<species>_" prefix when there is something after it.
    if name.startswith("anti-"):
        _, separator, remainder = name.partition('_')
        if separator:
            name = remainder

    # Prefer the first token that looks like a CD marker or an immunoglobulin.
    for token in name.split('_'):
        token_clean = token.rstrip(',').strip()
        if token_clean.startswith(("CD", "Ig")):
            return token_clean

    # Removing a trailing "(clone)" leaves the underscore that preceded it;
    # strip that separator residue so the fallback name is clean.
    return name.strip().strip('_').strip()

def download_file(url: str, destination: str) -> None:
    """
    Download a file from a URL to a destination path using a custom user-agent.

    The payload is streamed to "<destination>.part" and renamed to
    `destination` only after the transfer completes. An interrupted download
    therefore never leaves a truncated file that a later run would skip as
    "already exists".

    Args:
        url: Location of the remote file.
        destination: Local path to write. If it already exists, nothing is
            downloaded.

    Errors are reported on stdout and not raised (best-effort download).
    """
    if os.path.exists(destination):
        print(f"File {destination} already exists, skipping download")
        return

    # Local import: only this function needs shutil.
    import shutil

    print(f"Downloading {url} to {destination}")
    partial_path = f"{destination}.part"
    try:
        req = urllib.request.Request(url, headers={'User-Agent': 'Mozilla/5.0'})
        with urllib.request.urlopen(req) as response, open(partial_path, 'wb') as out_file:
            # Copy in chunks instead of holding the whole file in memory.
            shutil.copyfileobj(response, out_file)
        # Publish the file only once it is complete.
        os.replace(partial_path, destination)
        print(f"Downloaded {destination}")
    except Exception as e:
        print(f"Error downloading {url}: {e}")
        # Do not leave partial data behind.
        try:
            os.remove(partial_path)
        except OSError:
            pass

def download_dataset(data_dir: str) -> None:
    """
    Fetch the GEX and ADT matrix files of every sample into `data_dir`.

    The directory is created when missing. Downloads are handed to a pool of
    four worker threads; skipping files that already exist is left to
    `download_file`.
    """
    os.makedirs(data_dir, exist_ok=True)

    # One (url, local path) pair per modality per sample: GEX first, then ADT.
    jobs = []
    for sample in SAMPLE_IDS:
        for modality in ("GEX", "ADT"):
            file_name = f"{GEO_ACCESSION}_{sample}.{modality}.matrix.txt.gz"
            jobs.append((f"{BASE_URL}{file_name}", os.path.join(data_dir, file_name)))

    # Leaving the `with` block waits for all submitted downloads to finish.
    with ThreadPoolExecutor(max_workers=4) as pool:
        for remote_url, local_path in jobs:
            pool.submit(download_file, remote_url, local_path)

def read_matrix_file(file_path: str) -> Tuple[sparse.csr_matrix, pd.DataFrame, List[str]]:
    """
    Read a gzipped, tab-separated count matrix into a sparse matrix.

    Layout consumed by this parser: a header row whose first two fields label
    the identifier columns and whose remaining fields are cell barcodes,
    followed by one row per feature holding an id, a name, and one value per
    cell.

    Args:
        file_path: Path to the ``.txt.gz`` matrix file.

    Returns:
        A tuple ``(matrix, var_df, cell_barcodes)`` where
        - ``matrix`` is a float32 CSR matrix of shape (features, cells) that
          stores only non-zero values,
        - ``var_df`` is indexed by feature name and has the columns
          ``gene_id`` and ``gene_name``,
        - ``cell_barcodes`` is the list of header fields after the first two.

    Notes:
        - Only line terminators are stripped, so an empty leading field
          (for example a missing id) does not shift the remaining columns.
        - Blank lines are skipped.
        - Values that cannot be parsed as numbers are treated as zero.
        - The file is read in a single pass.
    """
    print(f"Reading {file_path}")
    start_time = time.time()

    data = []
    row_indices = []
    col_indices = []
    gene_ids = []
    gene_names = []

    with gzip.open(file_path, 'rt') as f:
        header = f.readline().rstrip('\r\n').split('\t')
        cell_barcodes = header[2:]  # Skip the two identifier columns

        for line in f:
            if not line.strip():
                continue  # tolerate blank lines, e.g. at the end of the file

            # Keep leading/trailing tabs: they delimit (possibly empty) fields.
            parts = line.rstrip('\r\n').split('\t')
            row_idx = len(gene_names)
            gene_ids.append(parts[0])
            gene_names.append(parts[1])

            # Record only non-zero entries.
            for col_idx, val in enumerate(parts[2:]):
                try:
                    val_float = float(val)
                except ValueError:
                    continue
                if val_float != 0:
                    data.append(val_float)
                    row_indices.append(row_idx)
                    col_indices.append(col_idx)

    print(f"  Found {len(gene_names)} genes/proteins and {len(cell_barcodes)} cells")

    matrix = sparse.csr_matrix(
        (data, (row_indices, col_indices)),
        shape=(len(gene_names), len(cell_barcodes)),
        dtype=np.float32
    )

    var_df = pd.DataFrame({
        'gene_id': gene_ids,
        'gene_name': gene_names
    }, index=gene_names)

    elapsed_time = time.time() - start_time
    print(f"  Processed {file_path} in {elapsed_time:.2f} seconds")

    return matrix, var_df, cell_barcodes

def parse_cell_metadata(cell_barcode: str) -> Dict[str, str]:
    """
    Split a dot-separated cell barcode into its metadata fields.

    Format: BARCODE-14.HTOX.SAMPLEID.TISSUE.CELLTYPE.CONDITION

    The condition field is optional and defaults to "unstimulated". A barcode
    with fewer than five dot-separated fields is not decomposed: the whole
    string is returned as "barcode" and the other fields are left empty.
    """
    fields = cell_barcode.split('.')

    # Too few fields to interpret: fall back to an empty record.
    if len(fields) < 5:
        return {
            "barcode": cell_barcode,
            "hto": "",
            "sample_id": "",
            "tissue": "",
            "cell_type": "",
            "condition": "unstimulated",
        }

    barcode, hto, sample_id, tissue, cell_type = fields[:5]
    if len(fields) > 5:
        condition = fields[5]
    else:
        condition = "unstimulated"

    return {
        "barcode": barcode,
        "hto": hto,
        "sample_id": sample_id,
        "tissue": tissue,
        "cell_type": cell_type,
        "condition": condition,
    }

def create_harmonized_metadata(cell_barcodes: List[str]) -> pd.DataFrame:
    """
    Build one harmonized metadata row per cell barcode.

    Each barcode is parsed with `parse_cell_metadata`; the parsed fields are
    combined with constants that are the same for every cell (organism,
    CRISPR type, cancer type). `perturbation_name` is "SARS-CoV-2" for the
    condition "sars", "Influenza" for "flu", and "None" otherwise.

    Returns a DataFrame whose rows follow the order of `cell_barcodes`.
    """
    # Conditions that correspond to a named perturbation.
    perturbation_by_condition = {"sars": "SARS-CoV-2", "flu": "Influenza"}

    records = []
    for original in cell_barcodes:
        parsed = parse_cell_metadata(original)
        stimulus = parsed["condition"]
        records.append({
            "barcode": parsed["barcode"],
            "original_barcode": original,
            "hto": parsed["hto"],
            "sample_id": parsed["sample_id"],
            "tissue": parsed["tissue"],
            "cell_type": parsed["cell_type"],
            "condition": stimulus,
            "organism": "Homo sapiens",
            "crispr_type": "None",
            "cancer_type": "Non-Cancer",
            "perturbation_name": perturbation_by_condition.get(stimulus, "None"),
        })

    return pd.DataFrame(records)

def process_sample(sample_id: str, data_dir: str) -> Tuple[ad.AnnData, ad.AnnData]:
    """
    Process a single sample and return AnnData objects for GEX and ADT data.

    Both matrices are read from `data_dir`, restricted to the cell barcodes
    present in both files, transposed to cells x features and annotated with
    harmonized per-cell metadata.

    Args:
        sample_id: Sample identifier used in the matrix file names.
        data_dir: Directory containing the downloaded ``.txt.gz`` files.

    Returns:
        ``(gex_adata, adt_adata)`` with identical ``obs`` order. Shared cells
        are kept in the order they appear in the GEX file, so the result is
        reproducible between runs.
    """
    gex_path = os.path.join(data_dir, f"{GEO_ACCESSION}_{sample_id}.GEX.matrix.txt.gz")
    adt_path = os.path.join(data_dir, f"{GEO_ACCESSION}_{sample_id}.ADT.matrix.txt.gz")

    # Read GEX data
    gex_matrix, gex_var, gex_barcodes = read_matrix_file(gex_path)

    # Read ADT data
    adt_matrix, adt_var, adt_barcodes = read_matrix_file(adt_path)

    # Map each barcode to its first column position (same choice list.index
    # would make) so lookups are O(1) instead of a scan per barcode.
    gex_position = {}
    for idx, bc in enumerate(gex_barcodes):
        gex_position.setdefault(bc, idx)
    adt_position = {}
    for idx, bc in enumerate(adt_barcodes):
        adt_position.setdefault(bc, idx)

    # Shared barcodes in GEX file order; a set intersection would give an
    # order that depends on string hashing and differs from run to run.
    common_barcodes = [bc for bc in gex_position if bc in adt_position]
    print(f"Sample {sample_id}: {len(common_barcodes)} common barcodes between GEX and ADT")

    gex_indices = [gex_position[bc] for bc in common_barcodes]
    adt_indices = [adt_position[bc] for bc in common_barcodes]

    gex_matrix_filtered = gex_matrix[:, gex_indices]
    adt_matrix_filtered = adt_matrix[:, adt_indices]

    metadata = create_harmonized_metadata(common_barcodes)

    # Transposing a CSR matrix yields CSC; convert back so X is CSR with
    # cells as rows.
    gex_adata = ad.AnnData(
        X=gex_matrix_filtered.T.tocsr(),
        obs=pd.DataFrame(index=common_barcodes),
        var=pd.DataFrame(index=gex_var.index)
    )

    adt_adata = ad.AnnData(
        X=adt_matrix_filtered.T.tocsr(),
        obs=pd.DataFrame(index=common_barcodes),
        var=pd.DataFrame(index=adt_var.index)
    )

    # Add harmonized metadata to obs (aligned on the original barcode)
    metadata_indexed = metadata.set_index('original_barcode')
    for col in metadata_indexed.columns:
        gex_adata.obs[col] = metadata_indexed[col]
        adt_adata.obs[col] = metadata_indexed[col]

    # Add gene/protein information to var; .values copies positionally
    # because the var index may contain duplicate names.
    for col in gex_var.columns:
        gex_adata.var[col] = gex_var[col].values
    for col in adt_var.columns:
        adt_adata.var[col] = adt_var[col].values

    # Simplify protein names in ADT var_names
    adt_adata.var.index = [simplify_protein_name(name) for name in adt_adata.var.index]

    # Additional metadata in uns
    gex_adata.uns["sample_id"] = sample_id
    gex_adata.uns["data_type"] = "gene_expression"
    gex_adata.uns["geo_accession"] = GEO_ACCESSION

    adt_adata.uns["sample_id"] = sample_id
    adt_adata.uns["data_type"] = "protein_expression"
    adt_adata.uns["geo_accession"] = GEO_ACCESSION

    # Check for duplicate gene/protein names and make unique if needed
    if gex_adata.var_names.duplicated().any():
        print(f"Warning: Found {gex_adata.var_names.duplicated().sum()} duplicate gene names in GEX data")
        gex_adata.var_names_make_unique()

    if adt_adata.var_names.duplicated().any():
        print(f"Warning: Found {adt_adata.var_names.duplicated().sum()} duplicate protein names in ADT data")
        adt_adata.var_names_make_unique()

    return gex_adata, adt_adata

def process_all_samples(data_dir: str) -> Tuple[ad.AnnData, ad.AnnData]:
    """
    Process all samples and combine them into one GEX and one ADT AnnData.

    Each sample in SAMPLE_IDS is processed independently. A sample that fails
    is reported and skipped so the remaining samples can still be combined.

    Args:
        data_dir: Directory containing the downloaded matrix files.

    Returns:
        ``(combined_gex, combined_adt)``, concatenated along cells with an
        outer join over features.

    Raises:
        RuntimeError: If no sample could be processed.
    """
    gex_adatas = []
    adt_adatas = []

    for sample_id in SAMPLE_IDS:
        try:
            start_time = time.time()
            gex_adata, adt_adata = process_sample(sample_id, data_dir)
            gex_adatas.append(gex_adata)
            adt_adatas.append(adt_adata)
            elapsed_time = time.time() - start_time
            print(f"Processed sample {sample_id} in {elapsed_time:.2f} seconds")
        except Exception as e:
            # Best effort: one broken or missing sample must not abort the run.
            print(f"Error processing sample {sample_id}: {e}")

    # Fail with a clear message instead of an obscure error from concat.
    if not gex_adatas:
        raise RuntimeError("No samples could be processed; nothing to combine")

    print("Combining all samples...")
    # uns_merge="same" keeps the uns entries that agree across all samples
    # (e.g. geo_accession, data_type); by default concat drops uns entirely.
    combined_gex = ad.concat(gex_adatas, join="outer", merge="same", uns_merge="same")
    combined_adt = ad.concat(adt_adatas, join="outer", merge="same", uns_merge="same")

    if combined_gex.var_names.duplicated().any():
        print(f"Warning: Found {combined_gex.var_names.duplicated().sum()} duplicate gene names in combined GEX data")
        combined_gex.var_names_make_unique()

    if combined_adt.var_names.duplicated().any():
        print(f"Warning: Found {combined_adt.var_names.duplicated().sum()} duplicate protein names in combined ADT data")
        combined_adt.var_names_make_unique()

    return combined_gex, combined_adt

def main(data_dir: str) -> None:
    """
    Run the full GSE261278 pipeline: download, process, save, summarise.

    Raw files are stored in `data_dir`. The combined gene-expression and
    protein-expression objects are written as gzip-compressed ``.h5ad`` files
    to ``<data_dir>/processed``. Shapes and a short preview of selected
    metadata fields are printed at the end.
    """
    os.makedirs(data_dir, exist_ok=True)

    # Fetch any raw files that are not present yet.
    download_dataset(data_dir)

    # Build the combined objects and report how long it took.
    began = time.time()
    gex_adata, adt_adata = process_all_samples(data_dir)
    print(f"Processed all samples in {time.time() - began:.2f} seconds")

    # Write the results next to the raw data.
    output_dir = os.path.join(data_dir, "processed")
    os.makedirs(output_dir, exist_ok=True)
    gex_output_path = os.path.join(output_dir, f"{GEO_ACCESSION}_gene_expression.h5ad")
    adt_output_path = os.path.join(output_dir, f"{GEO_ACCESSION}_protein_expression.h5ad")

    print(f"Saving gene expression data to {gex_output_path}")
    gex_adata.write(gex_output_path, compression="gzip")

    print(f"Saving protein expression data to {adt_output_path}")
    adt_adata.write(adt_output_path, compression="gzip")

    print("Processing complete!")
    print(f"Gene expression data shape: {gex_adata.shape}")
    print(f"Protein expression data shape: {adt_adata.shape}")

    # Show at most five distinct values per field, with an ellipsis if more exist.
    print("\nMetadata summary:")
    for field in ("organism", "cell_type", "condition", "perturbation_name", "tissue"):
        unique_values = gex_adata.obs[field].unique()
        suffix = '...' if len(unique_values) > 5 else ''
        print(f"{field}: {unique_values[:5]}{suffix}")

# Entry point, written for use in a Jupyter Notebook: the data directory is
# set explicitly to a folder named after the GEO accession inside the current
# working directory, and the whole pipeline runs as soon as this code executes.
data_dir = os.path.join(os.getcwd(), GEO_ACCESSION)
main(data_dir)
