# In [None]:
import os
import sys
import re
import h5py
import numpy as np
import pandas as pd
from scipy import sparse
import anndata
import urllib.request
import tarfile
import warnings
from pathlib import Path
import scanpy as sc  # Import Scanpy for QC metrics

# Silence every warning category from this point on, process-wide.
# NOTE(review): this also hides deprecation warnings from any library called
# later in this file — confirm the blanket filter is intended.
warnings.filterwarnings('ignore')

def _extract_within(tar, dest_dir):
    """Extract *tar* into *dest_dir*, refusing members that would escape it."""
    if hasattr(tarfile, "data_filter"):
        # Extraction filters (PEP 706) refuse members whose resolved path or
        # link target lies outside the destination directory.
        tar.extractall(path=dest_dir, filter="data")
        return

    # Interpreters without extraction filters: validate every member path by
    # hand before anything is written to disk.
    root = os.path.realpath(dest_dir)
    for member in tar.getmembers():
        target = os.path.realpath(os.path.join(root, member.name))
        if os.path.commonpath([root, target]) != root:
            raise tarfile.TarError(
                f"Refusing to extract {member.name!r} outside {dest_dir!r}"
            )
    tar.extractall(path=dest_dir)


def download_data(data_dir, accession="GSE286927"):
    """Download and unpack the GEO RAW archive for *accession* if needed.

    The archive is stored as ``<data_dir>/<accession>_RAW.tar``. It is only
    downloaded when that file is missing, and it is only extracted when
    *data_dir* does not yet contain any file ending in ``.h5``.

    Args:
        data_dir: Directory holding the archive and the extracted files.
            Created when it does not exist.
        accession: GEO series accession used to build both the archive file
            name and the download URL.

    Raises:
        tarfile.TarError: If the archive cannot be read or contains a member
            that would be written outside *data_dir*.
        OSError: On download or file-system failures.
    """
    os.makedirs(data_dir, exist_ok=True)
    tar_file = os.path.join(data_dir, f"{accession}_RAW.tar")

    if not os.path.exists(tar_file):
        print(f"Downloading {accession} dataset...")
        url = f"https://www.ncbi.nlm.nih.gov/geo/download/?acc={accession}&format=file"
        # Download under a temporary name and rename on success, so an
        # interrupted transfer never leaves a truncated archive that a later
        # run would mistake for a complete one.
        partial_file = tar_file + ".part"
        try:
            urllib.request.urlretrieve(url, partial_file)
            os.replace(partial_file, tar_file)
        finally:
            if os.path.exists(partial_file):
                os.remove(partial_file)
        print("Download complete.")

    # Extract only when no .h5 file is present yet.
    already_extracted = any(f.endswith('.h5') for f in os.listdir(data_dir))
    if not already_extracted:
        print("Extracting files...")
        with tarfile.open(tar_file) as tar:
            _extract_within(tar, data_dir)
        print("Extraction complete.")

def read_10x_h5(file_path):
    """Read a 10x feature-barcode HDF5 file into plain Python objects.

    Everything is read from the file's ``matrix`` group: the sparse count
    matrix (``data`` / ``indices`` / ``indptr`` / ``shape``), the barcodes,
    and the ``features`` sub-group.

    Args:
        file_path: Path of the ``.h5`` file to open read-only.

    Returns:
        A dict with the keys:
            ``matrix``        - ``scipy.sparse.csc_matrix`` built with the
                                shape stored in the file (the rest of this
                                module treats rows as features and columns
                                as barcodes),
            ``barcodes``      - list of barcode strings,
            ``feature_ids``   - list of feature id strings,
            ``feature_names`` - list of feature name strings,
            ``feature_types`` - list of feature type strings,
            ``feature_info``  - dict of any additional per-feature datasets
                                that could be read, keyed by dataset name.

    Raises:
        KeyError: If a required dataset is missing from the file.
    """
    def _as_text(values):
        # h5py hands back byte strings for string datasets; leave anything
        # that is already decoded (or not text at all) untouched.
        return [v.decode() if isinstance(v, bytes) else v for v in values]

    with h5py.File(file_path, 'r') as f:
        matrix_group = f['matrix']
        features_group = matrix_group['features']

        # Matrix dimensions as stored in the file
        shape = matrix_group['shape'][:]

        # Sparse matrix components
        data = matrix_group['data'][:]
        indices = matrix_group['indices'][:]
        indptr = matrix_group['indptr'][:]

        # Create sparse matrix (CSC format - column-major)
        matrix = sparse.csc_matrix((data, indices, indptr), shape=shape)

        barcodes = _as_text(matrix_group['barcodes'][:])

        feature_ids = _as_text(features_group['id'][:])
        feature_names = _as_text(features_group['name'][:])
        feature_types = _as_text(features_group['feature_type'][:])

        # Additional feature information is best-effort: entries that cannot
        # be read or decoded are skipped.  Only ordinary errors are ignored,
        # so Ctrl-C / SystemExit still propagate.
        feature_info = {}
        for key in features_group.keys():
            if key in ('id', 'name', 'feature_type'):
                continue
            try:
                feature_info[key] = _as_text(features_group[key][:])
            except Exception:
                continue

        return {
            'matrix': matrix,
            'barcodes': barcodes,
            'feature_ids': feature_ids,
            'feature_names': feature_names,
            'feature_types': feature_types,
            'feature_info': feature_info
        }

def split_gene_protein_data(data):
    """Separate a combined count matrix into RNA and antibody parts.

    Rows of ``data['matrix']`` are selected by their entry in
    ``data['feature_types']`` ('Gene Expression' vs 'Antibody Capture').
    Only columns (barcodes) that have at least one stored nonzero entry in
    BOTH parts are kept.

    Args:
        data: Dict as produced by ``read_10x_h5`` with the keys ``matrix``,
            ``barcodes``, ``feature_ids``, ``feature_names`` and
            ``feature_types``.

    Returns:
        ``{'gene': {...}, 'protein': {...}}`` where each inner dict holds
        ``matrix`` (features x kept barcodes), ``features`` (a DataFrame with
        ``feature_id`` plus ``gene_symbol`` or ``protein_name``) and
        ``barcodes`` (the kept barcodes, shared by both modalities).
    """
    full_matrix = data['matrix']
    feature_types = data['feature_types']

    def rows_of(kind):
        # Row positions whose feature type equals `kind`.
        return [row for row, ftype in enumerate(feature_types) if ftype == kind]

    def pick(values, rows):
        return [values[row] for row in rows]

    rna_rows = rows_of('Gene Expression')
    adt_rows = rows_of('Antibody Capture')

    rna_counts = full_matrix[rna_rows, :]
    adt_counts = full_matrix[adt_rows, :]

    gene_features = pd.DataFrame({
        'feature_id': pick(data['feature_ids'], rna_rows),
        'gene_symbol': pick(data['feature_names'], rna_rows)
    })
    antibody_features = pd.DataFrame({
        'feature_id': pick(data['feature_ids'], adt_rows),
        'protein_name': pick(data['feature_names'], adt_rows)
    })

    # A barcode survives only if it has counts in both modalities.
    has_rna = rna_counts.getnnz(axis=0) > 0
    has_adt = adt_counts.getnnz(axis=0) > 0
    kept_columns = np.flatnonzero(has_rna & has_adt)

    all_barcodes = data['barcodes']
    kept_barcodes = [all_barcodes[col] for col in kept_columns]

    gene_part = {
        'matrix': rna_counts[:, kept_columns],
        'features': gene_features,
        'barcodes': kept_barcodes
    }
    protein_part = {
        'matrix': adt_counts[:, kept_columns],
        'features': antibody_features,
        'barcodes': kept_barcodes
    }
    return {'gene': gene_part, 'protein': protein_part}

def create_anndata_objects(split_data, sample_info):
    """Create AnnData objects for gene expression and protein data.

    Each modality's matrix is transposed so that observations are barcodes
    and variables are features. ``var_names`` are set to the gene symbols
    (RNA) or to the protein names with every ``"_TotalSeqC"`` occurrence
    removed (protein), and then made unique. Because that overwrites the
    original feature-id index, the ids are kept in ``var['feature_id']`` so
    they are not lost.

    Args:
        split_data: Dict as returned by ``split_gene_protein_data``.
        sample_info: Mapping of column name -> value; every entry is written
            to ``obs`` of both objects.

    Returns:
        Tuple ``(gene_adata, protein_adata)``.
    """
    def _build(modality):
        part = split_data[modality]
        # drop=False keeps 'feature_id' as a regular column; the index name is
        # cleared so it does not collide with that column once var_names are
        # replaced by symbols / protein names below.
        var = part['features'].set_index('feature_id', drop=False).rename_axis(None)
        return anndata.AnnData(
            X=part['matrix'].T,  # cells x features
            var=var,
            obs=pd.DataFrame(index=part['barcodes'])
        )

    gene_adata = _build('gene')
    protein_adata = _build('protein')

    # Add sample info
    for key, value in sample_info.items():
        gene_adata.obs[key] = value
        protein_adata.obs[key] = value

    # Gene data: index by symbol; duplicates are disambiguated by anndata.
    gene_adata.var_names = gene_adata.var['gene_symbol'].values
    gene_adata.var_names_make_unique()

    # Protein data: index by protein name without the "_TotalSeqC" suffix.
    protein_adata.var_names = [
        name.replace("_TotalSeqC", "") for name in protein_adata.var['protein_name']
    ]
    protein_adata.var_names_make_unique()

    return gene_adata, protein_adata

def harmonize_metadata(adata, sample_metadata):
    """Write the standard annotation columns into ``adata.obs``.

    ``organism`` is always set to 'Homo sapiens'. The columns ``cell_type``,
    ``crispr_type``, ``cancer_type``, ``condition`` and ``perturbation_name``
    take their value from *sample_metadata* when present and a fixed fallback
    otherwise. Every remaining key of *sample_metadata* is copied as its own
    column.

    Args:
        adata: Object exposing a DataFrame-like ``obs`` attribute.
        sample_metadata: Mapping of metadata key -> value.

    Returns:
        The same ``adata`` object, modified in place.
    """
    # Standard column -> value used when sample_metadata does not provide one.
    fallbacks = {
        'cell_type': 'Marginal Zone Lymphoma',
        'crispr_type': 'None',
        'cancer_type': 'Marginal Zone Lymphoma',
        'condition': 'Tumor',
        'perturbation_name': 'None',
    }

    adata.obs['organism'] = 'Homo sapiens'
    for column, fallback in fallbacks.items():
        adata.obs[column] = sample_metadata.get(column, fallback)

    # Everything that is not one of the standard columns is copied verbatim.
    standard_columns = {'organism', *fallbacks}
    for column, value in sample_metadata.items():
        if column in standard_columns:
            continue
        adata.obs[column] = value

    return adata

def process_sample(file_path, sample_metadata):
    """Turn one 10x ``.h5`` file into annotated RNA and protein AnnData objects.

    The file is read, split by modality, wrapped into AnnData objects and
    annotated with *sample_metadata*. Finally every observation name gets
    ``_<sample_id>`` appended so barcodes stay unique once several samples
    are combined.

    Args:
        file_path: Path of the sample's ``.h5`` file.
        sample_metadata: Metadata mapping; must contain ``'sample_id'``.

    Returns:
        Tuple ``(gene_adata, protein_adata)``.
    """
    print(f"Processing {file_path}...")

    modalities = split_gene_protein_data(read_10x_h5(file_path))
    rna, adt = create_anndata_objects(modalities, sample_metadata)

    rna = harmonize_metadata(rna, sample_metadata)
    adt = harmonize_metadata(adt, sample_metadata)

    # Suffix each barcode with the sample id (e.g. "GSM8734453") to avoid
    # duplicate observation names across samples.
    suffix = sample_metadata['sample_id']
    rna_names = [f"{barcode}_{suffix}" for barcode in rna.obs_names]
    adt_names = [f"{barcode}_{suffix}" for barcode in adt.obs_names]
    rna.obs_names = rna_names
    adt.obs_names = adt_names

    return rna, adt

def _resolve_data_dir(argv):
    """Return the data directory named in *argv*, falling back to the CWD."""
    # Inside a Jupyter kernel argv looks like [launcher, '-f', <connection file>],
    # so an option-like first argument must not be taken as a directory name.
    if len(argv) > 1 and not argv[1].startswith('-'):
        return argv[1]
    return os.getcwd()


def main():
    """Process the whole dataset and write QC-filtered ``.h5ad`` files.

    Steps:
        1. Resolve the data directory (first command-line argument, else the
           current working directory) and make sure the data is downloaded.
        2. Find every ``GSM<id>_pt<n>_raw_feature_bc_matrix.h5`` file and
           process it into RNA / protein AnnData objects.
        3. Concatenate all samples per modality (outer join on variables).
        4. Compute scanpy QC metrics on the RNA data and keep cells with
           ``n_genes_by_counts >= 200``; restrict the protein data to the
           same cells.
        5. Write both objects into the data directory and print a summary.
    """
    data_dir = _resolve_data_dir(sys.argv)

    os.makedirs(data_dir, exist_ok=True)

    # Download data if needed
    download_data(data_dir)

    # Collect raw h5 files.  os.listdir returns entries in arbitrary order;
    # sorting keeps the sample (and therefore cell) order reproducible.
    h5_files = sorted(
        f for f in os.listdir(data_dir) if f.endswith('_raw_feature_bc_matrix.h5')
    )

    # Extract sample information from filenames
    sample_info = {}
    for file_path in h5_files:
        match = re.match(r'(GSM\d+)_pt(\d+)_(raw|filtered)_feature_bc_matrix\.h5', file_path)
        if match:
            gsm_id, patient_id, data_type = match.groups()
            if gsm_id not in sample_info:
                sample_info[gsm_id] = {}
            sample_info[gsm_id]['patient_id'] = patient_id
            sample_info[gsm_id][f'{data_type}_file'] = file_path

    all_gene_adatas = []
    all_protein_adatas = []

    # Process each sample
    for gsm_id, info in sample_info.items():
        if 'raw_file' in info:
            file_path = os.path.join(data_dir, info['raw_file'])
            sample_metadata = {
                'sample_id': gsm_id,                         # e.g. GSM8734453
                'patient_id': f"patient_{info['patient_id']}",
                'cell_type': 'Marginal Zone Lymphoma',
                'cancer_type': 'Marginal Zone Lymphoma',
                'condition': 'Tumor',
                'perturbation_name': 'None'
            }
            gene_adata, protein_adata = process_sample(file_path, sample_metadata)
            all_gene_adatas.append(gene_adata)
            all_protein_adatas.append(protein_adata)

    if not all_gene_adatas:
        # Make the "nothing happened" case visible instead of exiting silently.
        print(f"No raw feature-barcode matrix files found in {data_dir}; nothing to process.")
        return

    print("Combining all gene AnnData objects...")
    combined_gene = anndata.concat(all_gene_adatas, join='outer')

    print("Combining all protein AnnData objects...")
    combined_protein = anndata.concat(all_protein_adatas, join='outer')

    # --------------------
    # PERFORM QC ON GENE DATA
    # --------------------
    print("Computing QC metrics for gene expression data...")
    sc.pp.calculate_qc_metrics(combined_gene, inplace=True)

    # Keep cells with at least 200 genes
    qc_threshold = 200
    combined_gene_qc = combined_gene[combined_gene.obs['n_genes_by_counts'] >= qc_threshold].copy()
    print("Number of cells in gene data after QC:", combined_gene_qc.n_obs)

    # -----------------------------
    # SUBSET PROTEIN DATA TO MATCH
    # -----------------------------
    valid_cells = combined_gene_qc.obs_names.intersection(combined_protein.obs_names)

    combined_gene_qc = combined_gene_qc[valid_cells].copy()
    combined_protein_qc = combined_protein[valid_cells].copy()

    print("Number of cells in gene data after intersection:", combined_gene_qc.n_obs)
    print("Number of cells in protein data after intersection:", combined_protein_qc.n_obs)

    # --------------------
    # SAVE RESULTS
    # --------------------
    combined_gene_output = os.path.join(data_dir, "GSE286927_gene_expression_qc.h5ad")
    combined_protein_output = os.path.join(data_dir, "GSE286927_protein_expression_qc.h5ad")

    print(f"Saving gene data to {combined_gene_output}...")
    combined_gene_qc.write_h5ad(combined_gene_output)

    print(f"Saving protein data to {combined_protein_output}...")
    combined_protein_qc.write_h5ad(combined_protein_output)

    # Final summary
    print("\nFinal Summary:")
    print(f"Gene data cells: {combined_gene_qc.n_obs}")
    print(f"Gene data features: {combined_gene_qc.n_vars}")
    print(f"Protein data cells: {combined_protein_qc.n_obs}")
    print(f"Protein data features: {combined_protein_qc.n_vars}")
    print("\nDone!")

# Entry point: run the full pipeline when this module executes as
# "__main__" (a Jupyter Notebook cell or a script run), but not on import.
if __name__ == "__main__":
    main()
