In [None]:
import os
import gzip
import pandas as pd
import numpy as np
import h5py
import urllib.request
from pathlib import Path
import re

# Check if we can use anndata, otherwise use h5py directly
try:
    import anndata as ad
    HAS_ANNDATA = True
except ImportError:
    HAS_ANNDATA = False

# Dataset identity and download locations.
GEO_ACCESSION = "GSE164996"
# GEO stores a series under a directory named after its accession with the
# last three digits replaced by "nnn" (here: GSE164996 -> GSE164nnn).
BASE_URL = (
    "https://ftp.ncbi.nlm.nih.gov/geo/series/"
    f"{GEO_ACCESSION[:-3]}nnn/{GEO_ACCESSION}/suppl/"
)
# Sample labels as they appear in the supplementary file names: S1..S4.
SAMPLES = [f"S{index}" for index in range(1, 5)]
# Per-sample file suffixes; full names are f"{GEO_ACCESSION}_{sample}_{suffix}".
FILE_TYPES = [
    "filtered_barcodes.tsv.gz",           # one cell barcode per line
    "filtered_features.tsv.gz",           # gene_id, gene_symbol, feature_type
    "filtered_matrix.mtx.gz",             # sparse counts, Matrix Market format
    "protospacer_calls_per_cell.csv.gz",  # guide calls keyed by cell_barcode
]

def download_files(data_dir):
    """Download any missing supplementary files for every sample.

    For each sample in ``SAMPLES`` and suffix in ``FILE_TYPES`` the file
    ``{GEO_ACCESSION}_{sample}_{file_type}`` is fetched from ``BASE_URL``
    into *data_dir*. Files that already exist are skipped.

    Each transfer goes to a temporary ``<name>.part`` file. That file is
    renamed to the final name only after the download finishes. This way an
    interrupted or short download is never mistaken for a complete file on
    the next run (the existence check would otherwise skip it forever).

    Parameters
    ----------
    data_dir : str or path-like
        Destination directory; created if it does not exist.

    Raises
    ------
    Exception
        Whatever ``urllib.request.urlretrieve`` raises (e.g.
        ``urllib.error.URLError`` / ``ContentTooShortError``). The partial
        ``.part`` file is removed before re-raising.
    """
    os.makedirs(data_dir, exist_ok=True)

    for sample in SAMPLES:
        for file_type in FILE_TYPES:
            filename = f"{GEO_ACCESSION}_{sample}_{file_type}"
            filepath = os.path.join(data_dir, filename)

            if os.path.exists(filepath):
                print(f"File {filepath} already exists, skipping download")
                continue

            url = f"{BASE_URL}{filename}"
            tmp_path = filepath + ".part"
            print(f"Downloading {url} to {filepath}")
            try:
                urllib.request.urlretrieve(url, tmp_path)
            except BaseException:
                # Also covers KeyboardInterrupt: never leave a half-written
                # file behind; the error is re-raised unchanged.
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
                raise
            # Atomic on the same filesystem: filepath is either absent or complete.
            os.replace(tmp_path, filepath)

def read_mtx_file(file_path):
    """Read a gzip-compressed Matrix Market coordinate file as a CSR matrix.

    The value type comes from the ``%%MatrixMarket matrix coordinate
    <field> <symmetry>`` banner:

    - ``integer`` values are parsed as int.
    - ``real`` (or ``double``) values are parsed as float.
    - ``pattern`` entries, which have no value column, are stored as 1.

    Without a banner, values are parsed as integers. Only the stored entries
    are used, so symmetric storage is not mirrored. Blank lines are ignored
    and do not count toward the declared number of entries.

    Parameters
    ----------
    file_path : str or path-like
        Path to a gzip-compressed ``.mtx`` file.

    Returns
    -------
    scipy.sparse.csr_matrix
        Matrix with the shape declared on the size line. For the 10x files in
        this pipeline that is genes x cells. Duplicate coordinates are summed
        by the COO -> CSR conversion.

    Raises
    ------
    ValueError
        If the banner declares a non-coordinate format or an unsupported
        field (e.g. ``complex``), the size line is missing or malformed, or
        the file ends before the declared number of entries has been read.
    """
    from scipy.sparse import coo_matrix

    # None means "pattern": the entry has no value column and counts as 1.
    converters = {'integer': int, 'real': float, 'double': float, 'pattern': None}
    field = 'integer'
    size_line = None

    with gzip.open(file_path, 'rt') as f:
        # Header: banner/comment lines start with '%'; the first other
        # non-blank line holds "n_rows n_cols n_entries".
        for raw in f:
            stripped = raw.strip()
            if not stripped:
                continue
            if stripped.startswith('%'):
                tokens = stripped.lower().split()
                if tokens[0] == '%%matrixmarket' and len(tokens) >= 4:
                    if tokens[2] != 'coordinate':
                        raise ValueError(
                            f"{file_path}: unsupported Matrix Market format {tokens[2]!r}"
                        )
                    field = tokens[3]
                continue
            size_line = stripped
            break

        if field not in converters:
            raise ValueError(f"{file_path}: unsupported Matrix Market field {field!r}")
        convert = converters[field]

        dims = size_line.split() if size_line is not None else []
        if len(dims) != 3:
            raise ValueError(
                f"{file_path}: missing or malformed Matrix Market size line: {size_line!r}"
            )
        n_rows, n_cols, n_entries = map(int, dims)

        rows, cols, values = [], [], []
        for raw in f:
            # Checked before parsing so n_entries == 0 reads nothing.
            if len(rows) == n_entries:
                break
            parts = raw.split()
            if not parts:
                continue  # blank line: not an entry
            # Matrix Market indices are 1-based.
            rows.append(int(parts[0]) - 1)
            cols.append(int(parts[1]) - 1)
            values.append(1 if convert is None else convert(parts[2]))

    if len(rows) < n_entries:
        raise ValueError(
            f"{file_path}: expected {n_entries} entries but found only {len(rows)} "
            "(truncated file?)"
        )

    matrix = coo_matrix((values, (rows, cols)), shape=(n_rows, n_cols))
    return matrix.tocsr()

def process_sample(sample_id, data_dir):
    """Load one sample's matrix, features, barcodes and guide calls.

    Parameters
    ----------
    sample_id : str
        Sample label used in the file names (an element of ``SAMPLES``).
    data_dir : str or path-like
        Directory holding the downloaded files.

    Returns
    -------
    dict or None
        The dict has these keys:

        - ``'matrix'``: CSR from ``read_mtx_file``; rows are checked to match
          the features and columns to match the barcodes.
        - ``'features'``: DataFrame with ``gene_id``, ``gene_symbol`` and
          ``feature_type`` columns.
        - ``'barcodes'``: cell IDs prefixed with ``f"{sample_id}_"`` so they
          stay unique across samples.
        - ``'cell_metadata'``: DataFrame indexed by those IDs, with
          ``perturbation`` (the ``feature_call`` for the barcode, or
          ``'Unknown'`` when the barcode has no call) and ``sample_id``.

        Returns ``None`` if the matrix cannot be read, or if its shape does
        not match the number of features and barcodes.
    """
    print(f"Processing sample {sample_id}")

    matrix_file = os.path.join(data_dir, f"{GEO_ACCESSION}_{sample_id}_filtered_matrix.mtx.gz")
    features_file = os.path.join(data_dir, f"{GEO_ACCESSION}_{sample_id}_filtered_features.tsv.gz")
    barcodes_file = os.path.join(data_dir, f"{GEO_ACCESSION}_{sample_id}_filtered_barcodes.tsv.gz")
    protospacer_file = os.path.join(data_dir, f"{GEO_ACCESSION}_{sample_id}_protospacer_calls_per_cell.csv.gz")

    try:
        matrix = read_mtx_file(matrix_file)
    except Exception as e:
        print(f"Error reading matrix file: {e}")
        return None

    # Identifiers are read as verbatim strings. Pandas' default NA parsing
    # would turn tokens such as "NA"/"null" into NaN, and its type inference
    # would coerce numeric-looking IDs to numbers.
    with gzip.open(features_file, 'rt') as f:
        features = pd.read_csv(f, sep='\t', header=None, dtype=str, keep_default_na=False)
    features.columns = ['gene_id', 'gene_symbol', 'feature_type']

    with gzip.open(barcodes_file, 'rt') as f:
        barcodes = pd.read_csv(f, sep='\t', header=None, dtype=str, keep_default_na=False)
    barcodes.columns = ['barcode']

    # Matrix rows must line up with features and columns with barcodes.
    # Otherwise labels would be misaligned, or the merge in create_h5ad would
    # fail far from the cause.
    n_features, n_cells = matrix.shape
    if n_features != len(features) or n_cells != len(barcodes):
        print(
            f"Shape mismatch in sample {sample_id}: matrix is {n_features} x {n_cells}, "
            f"but there are {len(features)} features and {len(barcodes)} barcodes"
        )
        return None

    # Perturbation (protospacer) calls, keyed by the un-prefixed barcode.
    with gzip.open(protospacer_file, 'rt') as f:
        perturbations = pd.read_csv(f, dtype={'cell_barcode': str})
    perturbation_dict = dict(zip(perturbations['cell_barcode'], perturbations['feature_call']))

    raw_barcodes = barcodes['barcode'].tolist()
    unique_barcodes = [f"{sample_id}_{bc}" for bc in raw_barcodes]
    cell_metadata = pd.DataFrame(index=unique_barcodes)
    cell_metadata['perturbation'] = [perturbation_dict.get(bc, 'Unknown') for bc in raw_barcodes]
    cell_metadata['sample_id'] = sample_id

    return {
        'matrix': matrix,
        'features': features,
        'barcodes': unique_barcodes,
        'cell_metadata': cell_metadata
    }

def harmonize_metadata(cell_metadata):
    """Return a harmonized copy of *cell_metadata*; the input is not modified.

    Expects a ``perturbation`` column holding raw guide calls (``'|'``-joined
    when a cell has several), ``'Unknown'``, or a missing value. Columns added:

    - ``perturbation_name``: ``'Unknown'`` for missing/``'Unknown'`` calls;
      ``'Control'`` when every guide starts with ``CTRL`` or ``hRosa26``;
      otherwise the non-control guide names with any trailing ``_<digits>``
      removed, joined alphabetically with ``'+'`` when there is more than one.
    - ``condition``: ``'Control'`` for missing/``'Unknown'`` calls or
      all-control guides, ``'Test'`` otherwise.
    - Constant annotations: ``organism``, ``cell_type``, ``crispr_type`` and
      ``cancer_type``.

    Rows whose raw ``perturbation`` equals ``'Unknown'`` are dropped.
    """
    control_prefixes = ('CTRL', 'hRosa26')
    guide_suffix = re.compile(r'_\d+$')  # e.g. "KRAS_5" -> "KRAS"

    def is_control(guide):
        return guide.startswith(control_prefixes)

    def is_missing(call):
        return pd.isna(call) or call == 'Unknown'

    def name_of(call):
        if is_missing(call):
            return 'Unknown'
        targeting = [g for g in call.split('|') if not is_control(g)]
        if not targeting:
            return 'Control'
        genes = [guide_suffix.sub('', g) for g in targeting]
        return genes[0] if len(genes) == 1 else '+'.join(sorted(genes))

    def condition_of(call):
        if is_missing(call):
            return 'Control'
        return 'Control' if all(is_control(g) for g in call.split('|')) else 'Test'

    harmonized = cell_metadata.copy()
    raw_calls = harmonized['perturbation']
    harmonized['perturbation_name'] = raw_calls.apply(name_of)
    harmonized['condition'] = raw_calls.apply(condition_of)

    # Dataset-wide annotations (MCF10A-Cas9-Venus-vector cells).
    fixed_annotations = {
        'organism': 'Homo sapiens',
        'cell_type': 'MCF10A',
        'crispr_type': 'CRISPR KO',
        'cancer_type': 'Non-Cancer',
    }
    for column, value in fixed_annotations.items():
        harmonized[column] = value

    return harmonized[harmonized['perturbation'] != 'Unknown']

def _write_h5py_fallback(output_file, matrix, obs, features):
    """Write a cells x genes CSR *matrix*, *obs* and *features* with plain h5py.

    NOTE(review): this layout carries none of the ``encoding-type`` /
    ``encoding-version`` attributes used by the anndata on-disk format.
    ``anndata.read_h5ad`` will therefore probably not load it despite the
    ``.h5ad`` extension — confirm before relying on it.
    """
    with h5py.File(output_file, 'w') as f:
        matrix_group = f.create_group('X')
        matrix_group.create_dataset('data', data=matrix.data)
        matrix_group.create_dataset('indices', data=matrix.indices)
        matrix_group.create_dataset('indptr', data=matrix.indptr)
        matrix_group.attrs['shape'] = matrix.shape

        obs_group = f.create_group('obs')
        for col in obs.columns:
            obs_group.create_dataset(col, data=obs[col].values.astype('S'))

        var_group = f.create_group('var')
        for col in features.columns:
            var_group.create_dataset(col, data=features[col].values.astype('S'))

        # Unique cell IDs and gene symbols
        f.create_dataset('obs_names', data=np.array(obs.index, dtype='S'))
        f.create_dataset('var_names', data=features['gene_symbol'].values.astype('S'))

    print(f"Saved harmonized data to {output_file} using h5py")


def create_h5ad(data_list, output_file):
    """Combine processed samples into one cells x genes dataset and save it.

    Steps:

    1. Check that every sample has the same features in the same order.
    2. Stack the per-sample genes x cells matrices and transpose the result to
       cells x genes.
    3. Drop cells whose raw ``perturbation`` is missing or ``'Unknown'``.
    4. Run ``harmonize_metadata`` on the remaining cells.
    5. Write the result, keeping matrix rows aligned with whatever rows the
       harmonized metadata retains.

    Parameters
    ----------
    data_list : list of dict
        Outputs of ``process_sample`` (keys ``matrix``, ``features``,
        ``cell_metadata``).
    output_file : str or path-like
        Destination file.

    Returns
    -------
    anndata.AnnData or None
        When anndata is installed, returns the AnnData that was written.
        ``var_names`` are the gene symbols made unique, and ``var`` keeps the
        ``gene_id``, ``gene_symbol`` and ``feature_type`` columns. Returns
        ``None`` when falling back to h5py or when *data_list* is empty.

    Raises
    ------
    ValueError
        If samples do not share identical feature lists, or the stacked
        matrix has a different number of cells than the combined metadata.
    """
    if not data_list:
        print("No data to process")
        return None

    # Genes must be identical (and identically ordered) in every sample,
    # otherwise stacking would silently mix up genes between samples.
    features = data_list[0]['features']
    reference_ids = features['gene_id'].tolist()
    for data in data_list[1:]:
        if data['features']['gene_id'].tolist() != reference_ids:
            raise ValueError("Samples do not share the same feature list; cannot combine them")

    from scipy.sparse import hstack
    # Per-sample matrices are genes x cells; stack side by side, then
    # transpose so each row is a cell (cheap row selection below).
    cells_by_genes = hstack([data['matrix'] for data in data_list]).T.tocsr()
    combined_metadata = pd.concat([data['cell_metadata'] for data in data_list])
    if cells_by_genes.shape[0] != len(combined_metadata):
        raise ValueError(
            f"Matrix has {cells_by_genes.shape[0]} cells but metadata has "
            f"{len(combined_metadata)} rows"
        )

    # Use positional labels so every metadata row maps to its matrix row no
    # matter which rows harmonize_metadata keeps; real cell IDs are restored after.
    cell_ids = combined_metadata.index.to_numpy()
    positional = combined_metadata.reset_index(drop=True)
    raw_calls = positional['perturbation']
    # Missing calls are excluded too, so they are never labelled 'Control'.
    keep = raw_calls.notna() & (raw_calls != 'Unknown')
    harmonized_metadata = harmonize_metadata(positional[keep])
    rows = harmonized_metadata.index.to_numpy()
    harmonized_metadata.index = pd.Index(cell_ids[rows])
    filtered_matrix = cells_by_genes[rows]

    if HAS_ANNDATA:
        # Gene symbols become var_names while gene_id stays available as a
        # column (setting the index to gene_id and then overwriting var_names
        # would discard it).
        var = features.copy()
        var.index = pd.Index(features['gene_symbol'].astype(str).to_numpy())
        adata = ad.AnnData(X=filtered_matrix, obs=harmonized_metadata, var=var)
        # Gene symbols are not guaranteed unique; AnnData expects unique names.
        adata.var_names_make_unique()
        adata.write(output_file)
        print(f"Saved harmonized data to {output_file}")
        return adata

    _write_h5py_fallback(output_file, filtered_matrix, harmonized_metadata, features)
    return None

def main(data_dir):
    """Download, process and harmonize all samples into one output file.

    Parameters
    ----------
    data_dir : str or path-like
        Directory for the downloaded inputs and for the output
        ``{GEO_ACCESSION}_harmonized.h5ad``.

    Returns
    -------
    anndata.AnnData or None
        Whatever ``create_h5ad`` returns: the AnnData when anndata is
        installed, otherwise ``None``. Also ``None`` when no sample could be
        loaded, in which case nothing is written.
    """
    output_file = os.path.join(data_dir, f"{GEO_ACCESSION}_harmonized.h5ad")

    # Download files if necessary
    download_files(data_dir)

    data_list = []
    for sample in SAMPLES:
        data = process_sample(sample, data_dir)
        if data is not None:
            data_list.append(data)
        else:
            print(f"Skipping sample {sample}")

    if not data_list:
        # Don't report success when there is nothing to write.
        print("No samples could be loaded; no output written")
        return None

    print(f"Loaded {len(data_list)} of {len(SAMPLES)} samples")
    result = create_h5ad(data_list, output_file)
    print("Processing complete!")
    return result

# Set your data directory and run the main function.
# The guard stops the pipeline (including the downloads) from running when
# this file is imported as a module. In a notebook cell or via
# `python script.py`, __name__ is "__main__", so it still runs as before.
if __name__ == "__main__":
    data_dir = "data"  # Modify if needed
    adata = main(data_dir)
