# In [4]:
"""
Jupyter Notebook version of the GSE236304 harmonization script with orientation fix and updated perturbation_name formatting.

This script processes the GSE236304 dataset (MethNet: a robust approach to identify regulatory hubs 
and their distal targets in cancer [Perturb-seq]) and harmonizes it into h5ad format with standardized metadata.

The dataset contains CRISPRi (CRISPR interference) perturbation data from A549 cells with dCas9-KRAB-MECP2.
"""

import os
import logging
import numpy as np
import pandas as pd
import h5py
from urllib.request import urlretrieve
import gzip
import shutil
from scipy import sparse
import anndata

# Configure logging for the whole notebook session
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger('GSE236304_harmonizer')

# GEO supplementary-file directory for GSE236304; every input file lives under it.
_GSE236304_SUPPL_URL = 'https://ftp.ncbi.nlm.nih.gov/geo/series/GSE236nnn/GSE236304/suppl/'

# Logical file key -> download URL (insertion order is the download order)
DATASET_FILES = {
    key: _GSE236304_SUPPL_URL + file_name
    for key, file_name in (
        ('barcodes', 'GSE236304_barcodes.tsv.gz'),
        ('features', 'GSE236304_features.tsv.gz'),
        ('matrix', 'GSE236304_matrix.mtx.gz'),
        ('feature_reference', 'GSE236304_feature_reference.csv.gz'),
        ('protospacer_umi_thresholds', 'GSE236304_protospacer_umi_thresholds.csv.gz'),
        ('h5_matrix', 'GSE236304_filtered_feature_bc_matrix.h5'),
    )
}

def download_file(url, destination):
    """Download ``url`` to ``destination`` unless that file already exists.

    The payload is first written to ``destination + '.part'`` and only moved
    onto ``destination`` after ``urlretrieve`` returns. An interrupted or
    failed download therefore never leaves a truncated file behind. Without
    this, the existence check below would treat a truncated file as complete
    and skip it on every later run.

    Parameters
    ----------
    url : str
        Source URL (any scheme ``urllib.request.urlretrieve`` accepts).
    destination : str or os.PathLike
        Local path to write.

    Raises
    ------
    Exception
        Whatever ``urlretrieve`` / ``os.replace`` raise (e.g.
        ``urllib.error.URLError``); any partial file is removed first.
    """
    if os.path.exists(destination):
        logger.info("File already exists: %s", destination)
        return

    logger.info("Downloading %s to %s", url, destination)
    # os.fspath keeps pathlib.Path destinations working, as they did before
    partial_path = os.fspath(destination) + '.part'
    try:
        urlretrieve(url, partial_path)
        # Per the os.replace docs the rename is atomic on POSIX (same filesystem)
        os.replace(partial_path, destination)
    finally:
        # The partial file only survives here if the download or rename failed
        if os.path.exists(partial_path):
            os.remove(partial_path)
    logger.info("Downloaded %s", destination)

def ensure_dataset_files(data_dir):
    """Make sure every file listed in ``DATASET_FILES`` is present in ``data_dir``.

    Creates ``data_dir`` if necessary and downloads each missing file, in
    ``DATASET_FILES`` order.

    Returns
    -------
    dict
        Maps each ``DATASET_FILES`` key to its local path inside ``data_dir``
        (the URL's basename joined onto ``data_dir``).
    """
    os.makedirs(data_dir, exist_ok=True)

    local_paths = {
        key: os.path.join(data_dir, os.path.basename(url))
        for key, url in DATASET_FILES.items()
    }
    for key, url in DATASET_FILES.items():
        download_file(url, local_paths[key])

    return local_paths

def _decode_h5_strings(dataset):
    """Return an HDF5 string dataset as a list of Python ``str`` (bytes are UTF-8 decoded)."""
    return [
        value.decode('utf-8') if isinstance(value, bytes) else str(value)
        for value in dataset[:]
    ]


def load_10x_h5(filename):
    """Read the raw sparse components and the feature table from a 10x HDF5 file.

    Parameters
    ----------
    filename : str
        Path to a 10x ``*_feature_bc_matrix.h5`` file with a ``matrix`` group.

    Returns
    -------
    data, indices, indptr : numpy.ndarray
        Compressed-sparse components exactly as stored under ``matrix/``.
    matrix_shape : tuple
        Shape under which ``(data, indices, indptr)`` must be read as a CSR
        matrix. It is ``matrix/shape`` when ``indptr`` indexes its first
        dimension. Otherwise it is the swapped shape: the file then stores
        ``matrix/shape`` column-compressed (CSC), which is the CSR form of
        its transpose.
    barcodes : list of str
    features : pandas.DataFrame
        Columns ``id``, ``name``, ``feature_type``, ``target_gene_id``,
        ``target_gene_name``; one row per feature in file order, default
        RangeIndex. The ``target_gene_*`` columns are empty strings when the
        file has no such datasets.

    Raises
    ------
    ValueError
        If ``len(indptr) - 1`` matches neither dimension of ``matrix/shape``.
    """
    logger.info("Loading data from %s", filename)

    with h5py.File(filename, 'r') as f:
        group = f['matrix']
        # Compressed-sparse components; whether they are CSR or CSC is decided below
        data = group['data'][:]
        indices = group['indices'][:]
        indptr = group['indptr'][:]
        shape_raw = group['shape'][:]

        # indptr has one entry per compressed (outer) dimension, plus one
        n_outer = len(indptr) - 1
        if n_outer == shape_raw[0]:
            matrix_shape = tuple(shape_raw)
        elif n_outer == shape_raw[1]:
            matrix_shape = (shape_raw[1], shape_raw[0])
        else:
            raise ValueError(
                f"indptr length {len(indptr)} does not match any dimension "
                f"of shape {tuple(shape_raw)}."
            )

        barcodes = _decode_h5_strings(group['barcodes'])

        feature_group = group['features']
        features = pd.DataFrame({
            'id': _decode_h5_strings(feature_group['id']),
            'name': _decode_h5_strings(feature_group['name']),
            'feature_type': _decode_h5_strings(feature_group['feature_type']),
        })
        # target_gene_* may be absent (e.g. files without CRISPR guide features);
        # fill with empty strings instead of failing with KeyError.
        for column in ('target_gene_id', 'target_gene_name'):
            if column in feature_group:
                features[column] = _decode_h5_strings(feature_group[column])
            else:
                features[column] = ''

    return data, indices, indptr, matrix_shape, barcodes, features

def create_anndata_object(data, indices, indptr, shape, barcodes, features):
    """Build the gene-expression AnnData object and the full feature-by-cell matrix.

    ``(data, indices, indptr)`` are read as a CSR matrix of ``shape``. If the
    first dimension equals the number of barcodes, the matrix is
    (cells x features) and is transposed. The returned matrix therefore
    always has one row per feature (of every feature type) and one column
    per barcode.

    Parameters
    ----------
    data, indices, indptr, shape
        Sparse components and CSR shape, as returned by ``load_10x_h5``.
    barcodes : list of str
        Cell barcodes; used as ``obs`` index.
    features : pandas.DataFrame
        Feature table with ``id``, ``name`` and ``feature_type`` columns, one
        row per matrix feature, matched to the matrix by position.

    Returns
    -------
    adata : anndata.AnnData
        Cells x "Gene Expression" features with a CSR ``X``. ``var`` holds the
        gene-expression rows of ``features`` indexed by ``id``, plus a
        ``gene_name`` column.
    matrix : scipy.sparse.csr_matrix
        Features x cells counts for every feature type. Kept in CSR form so
        the row selections done by later steps are efficient.

    Raises
    ------
    ValueError
        If the oriented matrix is not ``len(features)`` x ``len(barcodes)``.
    """
    logger.info("Creating AnnData object")

    matrix = sparse.csr_matrix((data, indices, indptr), shape=shape)

    # (cells x features) input: transpose so rows are features. Transposing a
    # CSR matrix yields CSC, so convert back to keep row slicing efficient.
    if len(barcodes) == shape[0]:
        matrix = matrix.T.tocsr()

    if matrix.shape != (len(features), len(barcodes)):
        raise ValueError(
            f"Matrix shape {matrix.shape} does not match "
            f"{len(features)} features x {len(barcodes)} barcodes"
        )

    # Select rows by position, not by DataFrame label, so a non-default index
    # on ``features`` cannot select the wrong matrix rows.
    is_gene = (features['feature_type'] == 'Gene Expression').to_numpy()
    gene_expr_features = features[is_gene]
    gene_expr_matrix = matrix[np.flatnonzero(is_gene), :]

    # Transposing makes cells the rows; convert so X stays CSR
    adata = anndata.AnnData(
        X=gene_expr_matrix.T.tocsr(),
        obs=pd.DataFrame(index=barcodes),
        var=gene_expr_features.set_index('id'),
    )
    adata.var['gene_name'] = gene_expr_features['name'].values

    return adata, matrix

def assign_perturbations(adata, features, matrix, protospacer_df):
    """Call CRISPR guides per cell and derive perturbation labels in ``adata.obs``.

    A guide is detected in a cell when its UMI count reaches the guide's
    threshold from ``protospacer_df`` (matched on the feature ``name``). The
    count must also be at least 1. With the old default threshold of 0 and
    ``>=``, every guide lacking a threshold was "detected" in every cell,
    including cells with zero counts.

    Columns written to ``adata.obs``:

    - ``detected_guides``: detected guide names joined by ';' in feature
      order, or 'None'.
    - ``target_genes``: sorted, unique, non-empty target gene names joined by
      ';', or 'None'. Sorting makes the result deterministic; iteration order
      of a ``set`` of strings varies with the per-process hash seed.
    - ``perturbation_name``: 'Non-targeting' when ``target_genes`` is 'None'
      or 'Non-Targeting' (case-insensitive); otherwise ``target_genes`` with
      '+' instead of ';'.
    - ``condition``: 'control' for 'Non-targeting', otherwise 'test'.

    Parameters
    ----------
    adata : anndata.AnnData
        Only ``adata.obs`` is used; one row per column of ``matrix``.
    features : pandas.DataFrame
        Feature table with ``feature_type``, ``name`` and ``target_gene_name``
        columns; row i describes row i of ``matrix`` (matched by position).
    matrix : scipy sparse matrix
        Features x cells counts.
    protospacer_df : pandas.DataFrame
        Per-guide thresholds in columns ``Protospacer`` and ``UMI threshold``.

    Returns
    -------
    The same ``adata``, modified in place.
    """
    logger.info("Assigning perturbations to cells")

    # Guide rows selected by position, independent of ``features.index``
    is_guide = (features['feature_type'] == 'CRISPR Guide Capture').to_numpy()
    guide_rows = np.flatnonzero(is_guide)
    guide_names = features['name'].to_numpy()[guide_rows]
    guide_targets = features['target_gene_name'].to_numpy()[guide_rows]

    # Dense (cells x guides) counts
    counts = matrix[guide_rows, :].toarray().T

    threshold_by_guide = dict(zip(protospacer_df['Protospacer'], protospacer_df['UMI threshold']))
    thresholds = np.array(
        [threshold_by_guide.get(name, 0) for name in guide_names], dtype=float
    )
    # Per-guide thresholds broadcast over cells; a zero count is never a detection
    detected = counts >= np.maximum(thresholds, 1)

    guide_calls = []
    target_calls = []
    for cell_hits in detected:
        hit_idx = np.flatnonzero(cell_hits)
        cell_guides = [str(name) for name in guide_names[hit_idx]]
        cell_targets = sorted({str(target) for target in guide_targets[hit_idx] if target})
        guide_calls.append(';'.join(cell_guides) if cell_guides else 'None')
        target_calls.append(';'.join(cell_targets) if cell_targets else 'None')

    adata.obs['detected_guides'] = guide_calls
    adata.obs['target_genes'] = target_calls

    # NOTE(review): cells with no detected guide ('None') are labelled
    # 'Non-targeting' / 'control', same as before. Confirm that undetermined
    # cells should really be pooled with the non-targeting controls.
    adata.obs['perturbation_name'] = adata.obs['target_genes'].apply(
        lambda x: 'Non-targeting' if x.strip().lower() in ['non-targeting', 'none'] else x.replace(';', '+')
    )

    adata.obs['condition'] = adata.obs['perturbation_name'].apply(
        lambda x: 'control' if x == 'Non-targeting' else 'test'
    )

    return adata

def assign_cell_hashing(adata, features, matrix):
    """Assign each cell the CMO with the highest count.

    CMOs are taken to be the features whose ``feature_type`` is
    'Antibody Capture'.
    # NOTE(review): presumably how this dataset encodes its CMOs — confirm.

    Columns written to ``adata.obs``:

    - ``cmo_assignment``: name of the highest-count CMO (first one on ties).
      'None' when all CMO counts are zero or the dataset has no CMO features.
      Previously an all-zero cell got the first CMO arbitrarily, and a
      dataset without CMO features crashed in ``idxmax``.
    - ``cmo_count``: that maximum count (0 when there is no evidence).

    Parameters
    ----------
    adata : anndata.AnnData
        Only ``adata.obs`` is used; one row per column of ``matrix``.
    features : pandas.DataFrame
        Feature table with ``feature_type`` and ``name`` columns, matched to
        the rows of ``matrix`` by position.
    matrix : scipy sparse matrix
        Features x cells counts.

    Returns
    -------
    The same ``adata``, modified in place.
    """
    logger.info("Assigning cell hashing information")

    is_cmo = (features['feature_type'] == 'Antibody Capture').to_numpy()
    cmo_rows = np.flatnonzero(is_cmo)
    cmo_names = features['name'].to_numpy()[cmo_rows]

    # Dense (cells x CMOs) counts
    counts = matrix[cmo_rows, :].toarray().T

    if counts.shape[1] == 0:
        logger.warning("No 'Antibody Capture' features found; no CMO assignment possible")
        adata.obs['cmo_assignment'] = 'None'
        adata.obs['cmo_count'] = 0
        return adata

    best_idx = counts.argmax(axis=1)  # first maximum, like DataFrame.idxmax
    best_count = counts.max(axis=1)
    adata.obs['cmo_assignment'] = np.where(best_count > 0, cmo_names[best_idx], 'None')
    adata.obs['cmo_count'] = best_count

    return adata

def harmonize_metadata(adata):
    """Add dataset-wide constant annotations to every cell in ``adata.obs``.

    Writes ``organism``, ``cell_type``, ``crispr_type`` and ``cancer_type``
    columns (A549 is a lung adenocarcinoma line, screened here with CRISPRi),
    then returns the same, in-place-modified ``adata``.
    """
    logger.info("Harmonizing metadata")

    constant_annotations = {
        'organism': 'Homo sapiens',
        'cell_type': 'A549',
        'crispr_type': 'CRISPRi',
        'cancer_type': 'Lung Cancer',
    }
    for column, value in constant_annotations.items():
        adata.obs[column] = value

    return adata

def process_dataset(data_dir, output_file):
    """Run the whole GSE236304 harmonization pipeline and write the result.

    Steps:

    1. Fetch any missing inputs into ``data_dir``.
    2. Read the 10x HDF5 matrix and the protospacer UMI thresholds.
    3. Build the AnnData object.
    4. Call guides and CMO assignments.
    5. Add constant metadata.
    6. Write ``output_file`` in h5ad format.

    Returns the harmonized AnnData object.
    """
    paths = ensure_dataset_files(data_dir)

    # (data, indices, indptr, shape) stay packed; barcodes/features are named
    *raw_matrix, barcodes, features = load_10x_h5(paths['h5_matrix'])
    umi_thresholds = pd.read_csv(paths['protospacer_umi_thresholds'])

    # ``full_matrix`` is features x cells over every feature type and feeds the
    # guide-calling and CMO steps; ``adata`` itself keeps gene expression only.
    adata, full_matrix = create_anndata_object(*raw_matrix, barcodes, features)

    adata = assign_perturbations(adata, features, full_matrix, umi_thresholds)
    adata = assign_cell_hashing(adata, features, full_matrix)
    adata = harmonize_metadata(adata)

    logger.info(f"Saving harmonized dataset to {output_file}")
    adata.write_h5ad(output_file)
    logger.info("Done!")

    return adata

# -----------------------------------------------------------------
# Notebook entry point: fetch inputs, harmonize, write the h5ad file
# -----------------------------------------------------------------

# Directory holding both the downloaded inputs and the harmonized output;
# change as needed.
data_dir = 'data_'
os.makedirs(data_dir, exist_ok=True)

output_file = os.path.join(data_dir, 'GSE236304_harmonized.h5ad')

# Runs the full pipeline (downloading missing inputs first) and saves the result
adata = process_dataset(data_dir, output_file)
print(f"Harmonized dataset saved to {output_file}")
