In [None]:
import os
import sys
import urllib.request
import tarfile
import gzip
import numpy as np
import pandas as pd
import anndata as ad
from scipy import sparse
from scipy.io import mmread

def clean_protein_names(protein_names):
    """
    Clean antibody/protein feature names to follow conventional naming.

    Processing order for each name:
    1. CMO markers (Cell Multiplexing Oligos, names starting with 'CMO')
       are kept unchanged.
    2. Isotype controls ('Isotype_<name>[.<suffix>]') are reduced to
       'Isotype_<name>'.
    3. Clone information (everything after the first '_') is removed.
       This happens BEFORE prefix removal because clone names may
       themselves contain dots (e.g. '_C398.4A').
    4. A leading species prefix made of 'Hu', 'Ms' and/or 'Rt' followed by
       a dot (e.g. 'Hu.', 'HuMs.', 'HuMsRt.') is removed. Other dotted
       prefixes are left untouched so multi-part names such as 'HLA.DR'
       keep all of their components.
    5. Special cases: 'TCR...' names drop their dots, 'HLA...' names use
       '-' instead of '.', and integrin beta-7 is renamed to 'ITGB7'.

    Parameters:
    -----------
    protein_names : list
        List of protein names to clean

    Returns:
    --------
    list
        Cleaned protein names, in the same order as the input
    """
    # Local import keeps this function self-contained.
    import re

    species_prefix = re.compile(r'^(?:Hu|Ms|Rt)+\.')

    cleaned_names = []

    for name in protein_names:
        # Handle CMO markers (Cell Multiplexing Oligos) - keep as is
        if name.startswith('CMO'):
            cleaned_names.append(name)
            continue

        # Handle isotypes - extract just the isotype name
        if name.startswith('Isotype_'):
            isotype = name.split('_', 1)[1].split('.')[0]
            cleaned_names.append(f"Isotype_{isotype}")
            continue

        # Remove clone information first: the clone may contain '.', which
        # would otherwise be mistaken for part of the protein name.
        name = name.split('_', 1)[0]

        # Remove only a leading species prefix (Hu., HuMs., HuMsRt., ...)
        name = species_prefix.sub('', name, count=1)

        if name.startswith('TCR'):
            # e.g. 'TCR.Va7.2' -> 'TCRVa72'
            name = name.replace('.', '')
        elif name.startswith('HLA'):
            # e.g. 'HLA.DR' -> 'HLA-DR'
            name = name.replace('.', '-')
        elif name.replace('.', '') == 'integrinb7':
            # conventional name for integrin beta-7; accepts both
            # 'integrinb7' and 'integrin.b7'
            name = 'ITGB7'

        cleaned_names.append(name)

    return cleaned_names

def download_data(accession, output_dir):
    """
    Download and extract a GEO dataset archive if not already present.

    The archive is saved as '<accession>_RAW.tar' inside ``output_dir``.
    It is first written to a temporary '.part' file and only renamed once
    the transfer finished, so an interrupted download is never mistaken for
    a complete archive on the next run.

    Extraction is skipped when 'GSM7866650_4h_barcodes.tsv.gz' already
    exists in ``output_dir``; that marker file is specific to GSE246317.

    Parameters:
    -----------
    accession : str
        GEO accession number
    output_dir : str
        Directory to save downloaded data

    Returns:
    --------
    str
        ``output_dir``, unchanged
    """
    os.makedirs(output_dir, exist_ok=True)

    tar_file = os.path.join(output_dir, f"{accession}_RAW.tar")

    if not os.path.exists(tar_file):
        print(f"Downloading {accession} dataset...")
        url = f"https://www.ncbi.nlm.nih.gov/geo/download/?acc={accession}&format=file"
        partial_file = tar_file + ".part"
        try:
            urllib.request.urlretrieve(url, partial_file)
            # Atomic rename: tar_file only ever exists fully downloaded.
            os.replace(partial_file, tar_file)
        finally:
            # Clean up the leftover fragment if the download failed.
            if os.path.exists(partial_file):
                os.remove(partial_file)
        print(f"Download complete: {tar_file}")
    else:
        print(f"Using existing download: {tar_file}")

    # Extract if not already extracted
    expected_file = os.path.join(output_dir, "GSM7866650_4h_barcodes.tsv.gz")
    if not os.path.exists(expected_file):
        print(f"Extracting {tar_file}...")
        with tarfile.open(tar_file) as tar:
            if hasattr(tarfile, 'data_filter'):
                # Reject absolute paths, '..' traversal and special files
                # on Python versions that support extraction filters.
                tar.extractall(path=output_dir, filter='data')
            else:
                tar.extractall(path=output_dir)
        print("Extraction complete")
    else:
        print("Files already extracted")

    return output_dir

def read_mtx_data(base_path, prefix):
    """
    Load one 10x-style sample stored as gzipped MTX/TSV files.

    Three files are read from ``base_path``: '<prefix>_matrix.mtx.gz',
    '<prefix>_features.tsv.gz' and '<prefix>_barcodes.tsv.gz'.

    Parameters:
    -----------
    base_path : str
        Base directory containing the files
    prefix : str
        Prefix for the files (e.g., "GSM7866650_4h"); the part after the
        first underscore is used as the timepoint

    Returns:
    --------
    Tuple of (matrix, features, barcodes, timepoint)
        matrix is the transposed MTX content as a CSR matrix, features has
        the columns 'id', 'name' and 'feature_type', and barcodes has the
        columns 'barcode' and 'unique_barcode'
    """
    print(f"Reading {prefix} data...")

    def sample_file(suffix):
        # All inputs share the same '<prefix>_<suffix>' naming scheme.
        return os.path.join(base_path, f"{prefix}_{suffix}")

    # The MTX file is transposed on load so rows/columns are swapped
    # relative to the file on disk.
    with gzip.open(sample_file("matrix.mtx.gz"), 'rb') as handle:
        raw_matrix = mmread(handle)
        matrix = sparse.csr_matrix(raw_matrix.T)

    # Feature table (genes/proteins): no header row in the file
    features = pd.read_csv(sample_file("features.tsv.gz"), sep='\t', header=None)
    features.columns = ['id', 'name', 'feature_type']

    # Cell barcodes: single headerless column
    barcodes = pd.read_csv(sample_file("barcodes.tsv.gz"), sep='\t', header=None)
    barcodes.columns = ['barcode']

    # Second underscore-separated token of the prefix, e.g. '4h' or '24h'
    timepoint = prefix.split('_')[1]

    # Suffix each barcode with the timepoint so barcodes stay unique once
    # several samples are combined.
    suffix = "_" + timepoint
    barcodes['unique_barcode'] = barcodes['barcode'] + suffix

    return matrix, features, barcodes, timepoint

def create_anndata(matrix, features, barcodes, timepoint, feature_type):
    """
    Create AnnData object for either gene or protein data.

    Features are selected by POSITION (row number in ``features``), so the
    result is correct even when ``features`` does not carry a default
    0..n-1 index.

    Parameters:
    -----------
    matrix : scipy.sparse.csr_matrix
        Expression matrix (cells x features)
    features : pandas.DataFrame
        Feature metadata with columns 'id', 'name' and 'feature_type';
        row i describes matrix column i
    barcodes : pandas.DataFrame
        Cell barcodes with columns 'barcode' and 'unique_barcode'
    timepoint : str
        Timepoint ('4h' or '24h')
    feature_type : str
        'Gene Expression' or 'Antibody Capture'

    Returns:
    --------
    AnnData object

    Raises:
    -------
    ValueError
        If the number of matrix columns differs from the number of rows in
        ``features``
    """
    if matrix.shape[1] != len(features):
        raise ValueError(
            f"Matrix has {matrix.shape[1]} feature columns but the features "
            f"table has {len(features)} rows"
        )

    # Positions (not index labels) of the features of the requested type
    type_positions = np.flatnonzero(
        (features['feature_type'] == feature_type).to_numpy()
    )
    type_features = features.iloc[type_positions].reset_index(drop=True)

    # Subset matrix columns to the same positions
    type_matrix = matrix[:, type_positions]

    # Handle feature names
    feature_names = type_features['name'].values
    if len(feature_names) != len(set(feature_names)):
        print(f"Found duplicate {feature_type} names, making them unique...")
        if feature_type == 'Gene Expression':
            # Feature ids disambiguate repeated gene symbols
            feature_names = [
                f"{name}_{id_}"
                for name, id_ in zip(type_features['name'], type_features['id'])
            ]
        else:  # Antibody Capture - will use clean_protein_names later
            # Just append indices for now to ensure uniqueness
            feature_names = [f"{name}_{i}" for i, name in enumerate(feature_names)]

    # Create AnnData object with unique indices
    adata = ad.AnnData(
        X=type_matrix,
        obs=pd.DataFrame(index=barcodes['unique_barcode']),
        var=pd.DataFrame(index=feature_names)
    )

    # Add metadata
    adata.obs['original_barcode'] = barcodes['barcode'].values
    adata.obs['timepoint'] = timepoint
    adata.var['feature_id'] = type_features['id'].values
    adata.var['feature_type'] = feature_type

    return adata

def clean_protein_adata_names(protein_adata):
    """
    Update protein names in an AnnData object to use conventional naming.

    The input object is not modified. The result is a full copy made with
    ``AnnData.copy()``, with only the var names replaced.

    Parameters:
    -----------
    protein_adata : AnnData
        AnnData object with protein expression data

    Returns:
    --------
    AnnData
        Copy of the input with updated, unique protein names
    """
    # Strip a trailing '_<digits>' index that may have been appended earlier
    # to force uniqueness. Only the LAST underscore-separated token is
    # removed, so names that legitimately contain '_' (e.g. 'Isotype_X')
    # keep the rest of their name.
    protein_names = []
    for name in protein_adata.var_names.tolist():
        stem, sep, tail = name.rpartition('_')
        protein_names.append(stem if sep and tail.isdigit() else name)

    # Clean the names
    cleaned_names = clean_protein_names(protein_names)

    # Different raw names can clean to the same value; the first occurrence
    # keeps the plain name and later ones get '_1', '_2', ...
    unique_cleaned_names = []
    used_names = set()
    name_counts = {}
    for name in cleaned_names:
        candidate = name
        while candidate in used_names:
            name_counts[name] = name_counts.get(name, 0) + 1
            candidate = f"{name}_{name_counts[name]}"
        used_names.add(candidate)
        unique_cleaned_names.append(candidate)

    # Copy the whole object, then swap in the new names positionally.
    cleaned_adata = protein_adata.copy()
    cleaned_adata.var_names = unique_cleaned_names

    return cleaned_adata

def harmonize_data(adata, data_type):
    """
    Attach the standard harmonized annotation columns to ``adata.obs``.

    The object is modified in place and also returned. The columns
    'organism', 'cell_type', 'crispr_type', 'cancer_type', 'condition' and
    'perturbation_name' are added (in that order) and stored as
    categoricals. 'condition' is derived from the existing 'timepoint'
    column.

    Parameters:
    -----------
    adata : AnnData
        AnnData object to harmonize; its obs must contain 'timepoint'
    data_type : str
        Type of data ('gene' or 'protein'); only used in the log message

    Returns:
    --------
    Harmonized AnnData object (the same object that was passed in)
    """
    print(f"Harmonizing {data_type} data...")

    # Dataset-wide constants: every cell gets the same value.
    constant_fields = (
        ('organism', 'Homo sapiens'),
        ('cell_type', 'PBMC'),            # All cells are PBMCs
        ('crispr_type', 'None'),          # No CRISPR perturbation in this dataset
        ('cancer_type', 'Non-Cancer'),    # These are healthy donor PBMCs
    )
    for field, value in constant_fields:
        adata.obs[field] = value

    # Condition label is looked up from the timepoint of each cell.
    condition_by_timepoint = {
        '4h': 'ICI_treatment_4h',
        '24h': 'ICI_treatment_24h'
    }
    adata.obs['condition'] = adata.obs['timepoint'].map(condition_by_timepoint)

    # Perturbation applied in this dataset: immune checkpoint inhibitors
    adata.obs['perturbation_name'] = 'Immune Checkpoint Inhibitors'

    # Store all harmonized columns as categoricals
    categorical_fields = [field for field, _ in constant_fields]
    categorical_fields += ['condition', 'perturbation_name']
    for field in categorical_fields:
        adata.obs[field] = adata.obs[field].astype('category')

    return adata

def _require_unique_obs(adata, message):
    """Raise ValueError(message) if the obs index of adata has duplicates."""
    if not adata.obs.index.is_unique:
        raise ValueError(message)


def process_dataset(data_dir, output_dir):
    """
    Process the GSE246317 dataset.

    Reads the 4h and 24h samples, builds gene and protein AnnData objects,
    concatenates the timepoints, cleans protein names, aligns both
    modalities on their shared cells (sorted by cell name), harmonizes the
    metadata and writes two .h5ad files to ``output_dir``. The protein
    matrix is additionally stored in ``gene_adata.obsm['protein_expression']``
    with its column names in ``gene_adata.uns['protein_names']``.

    Parameters:
    -----------
    data_dir : str
        Directory containing the data files
    output_dir : str
        Directory to save the processed data

    Returns:
    --------
    Tuple of (gene_adata, protein_adata)

    Raises:
    -------
    ValueError
        If cell indices are not unique, or if gene and protein data cannot
        be aligned on the same cells
    """
    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    # Read both samples first (4h, then 24h)
    sample_prefixes = ("GSM7866650_4h", "GSM7866653_24h")
    samples = [read_mtx_data(data_dir, prefix) for prefix in sample_prefixes]

    # Create gene expression and protein AnnData objects per timepoint
    gene_parts = []
    protein_parts = []
    for matrix, features, barcodes, timepoint in samples:
        gene_parts.append(
            create_anndata(matrix, features, barcodes, timepoint, 'Gene Expression')
        )
        protein_parts.append(
            create_anndata(matrix, features, barcodes, timepoint, 'Antibody Capture')
        )

    timepoints = [sample[3] for sample in samples]
    for timepoint, gene_part, protein_part in zip(timepoints, gene_parts, protein_parts):
        print(f"Gene expression data ({timepoint}): {gene_part.shape[0]} cells, {gene_part.shape[1]} genes")
        print(f"Protein data ({timepoint}): {protein_part.shape[0]} cells, {protein_part.shape[1]} proteins")

    # Verify indices are unique (explicit raise: asserts vanish under -O)
    for timepoint, gene_part, protein_part in zip(timepoints, gene_parts, protein_parts):
        _require_unique_obs(gene_part, f"{timepoint} gene indices not unique")
        _require_unique_obs(protein_part, f"{timepoint} protein indices not unique")

    # Concatenate gene data
    gene_adata = ad.concat(
        gene_parts,
        join='outer',  # Use outer join to include all genes
        merge='first'  # Keep first occurrence of overlapping variables
    )

    # Concatenate protein data
    protein_adata = ad.concat(
        protein_parts,
        join='outer',  # Use outer join to include all proteins
        merge='first'  # Keep first occurrence of overlapping variables
    )

    # Verify indices are unique after concatenation
    _require_unique_obs(gene_adata, "Concatenated gene indices not unique")
    _require_unique_obs(protein_adata, "Concatenated protein indices not unique")

    # Clean protein names to conventional format
    protein_adata = clean_protein_adata_names(protein_adata)

    # Determine the cells that have both gene and protein data
    common_cells = set(gene_adata.obs.index).intersection(protein_adata.obs.index)
    print(f"Common cells between gene and protein data: {len(common_cells)}")

    # Subset and order both datasets in a single step; one copy per
    # modality instead of subsetting first and re-sorting afterwards.
    cell_order = sorted(common_cells)
    gene_adata = gene_adata[cell_order].copy()
    protein_adata = protein_adata[cell_order].copy()

    # Verify alignment
    if not gene_adata.obs.index.equals(protein_adata.obs.index):
        raise ValueError("Indices not aligned after sorting")

    # Harmonize data
    gene_adata = harmonize_data(gene_adata, 'gene')
    protein_adata = harmonize_data(protein_adata, 'protein')

    # Attach the protein matrix to gene_adata as a per-cell (obsm) array
    gene_adata.obsm['protein_expression'] = protein_adata.X.copy()

    # Store protein names (column labels of the obsm matrix)
    gene_adata.uns['protein_names'] = protein_adata.var_names.tolist()

    # Save
    gene_output_file = os.path.join(output_dir, "GSE246317_gene_expression.h5ad")
    protein_output_file = os.path.join(output_dir, "GSE246317_protein_expression.h5ad")

    print(f"Saving gene expression data to {gene_output_file}")
    gene_adata.write(gene_output_file)

    print(f"Saving protein expression data to {protein_output_file}")
    protein_adata.write(protein_output_file)

    print("Processing complete!")

    return gene_adata, protein_adata

def main():
    """Run the full GSE246317 pipeline relative to the current working directory."""
    # Input and output folders live next to wherever the notebook/script runs
    base_dir = os.getcwd()
    data_dir = os.path.join(base_dir, "GSE246317")
    output_dir = os.path.join(base_dir, "processed")

    # Fetch and unpack the raw archive when it is not already there
    download_data("GSE246317", data_dir)

    # Build the harmonized gene and protein objects
    gene_adata, protein_adata = process_dataset(data_dir, output_dir)

    # Overall size of the results
    n_cells_gene, n_genes = gene_adata.shape[0], gene_adata.shape[1]
    n_cells_protein, n_proteins = protein_adata.shape[0], protein_adata.shape[1]
    print("\nSummary:")
    print(f"Gene expression data: {n_cells_gene} cells, {n_genes} genes")
    print(f"Protein expression data: {n_cells_protein} cells, {n_proteins} proteins")

    # Distinct values of each harmonized metadata column
    print("\nHarmonized metadata summary:")
    harmonized_columns = (
        'organism',
        'cell_type',
        'crispr_type',
        'cancer_type',
        'condition',
        'perturbation_name',
    )
    for col in harmonized_columns:
        print(f"{col}: {gene_adata.obs[col].unique()}")

    # Locations of the written files
    print("\nOutput files:")
    for filename in ('GSE246317_gene_expression.h5ad', 'GSE246317_protein_expression.h5ad'):
        print(f"- {os.path.join(output_dir, filename)}")

# Entry point: run the whole download-and-process pipeline when this file is
# executed directly (e.g. as a script or a notebook cell). The guard keeps
# the pipeline from running when the module is merely imported.
if __name__ == "__main__":
    main()