In [4]:
import os
import sys
import gzip
import glob
import shutil
import urllib.request
from pathlib import Path
import numpy as np
import pandas as pd
import h5py
from scipy import sparse
import anndata as ad

# GEO supplementary files for series GSE250558, keyed by their role in the
# pipeline: the raw per-sample archive, the guide feature reference, and the
# combined results table.
DATASET_URLS = dict(
    raw_data='https://ftp.ncbi.nlm.nih.gov/geo/series/GSE250nnn/GSE250558/suppl/GSE250558_RAW.tar',
    feature_ref='https://ftp.ncbi.nlm.nih.gov/geo/series/GSE250nnn/GSE250558/suppl/GSE250558_feature_ref.csv.gz',
    results='https://ftp.ncbi.nlm.nih.gov/geo/series/GSE250nnn/GSE250558/suppl/GSE250558_all_results_combined.tsv.gz',
)

def download_file(url, output_path):
    """Download ``url`` to ``output_path`` atomically.

    The payload is first written to ``<output_path>.part`` and only renamed to
    ``output_path`` once the transfer has completed. Callers such as
    ``ensure_data_available`` skip the download whenever ``output_path``
    exists, so writing directly to the final name would let an interrupted
    download masquerade as a complete file on the next run.

    Parameters:
      url (str): Source URL (anything ``urllib.request.urlretrieve`` accepts).
      output_path (str or Path): Destination file path.

    Raises:
      Whatever ``urlretrieve`` raises (e.g. ``urllib.error.URLError``); the
      partial ``.part`` file is removed before the error propagates.
    """
    output_path = Path(output_path)
    tmp_path = output_path.with_name(output_path.name + '.part')
    print(f"Downloading {url} to {output_path}...")
    try:
        urllib.request.urlretrieve(url, tmp_path)
        # Atomic on the same filesystem: output_path is either absent or complete.
        os.replace(tmp_path, output_path)
    except BaseException:
        # Also covers KeyboardInterrupt, which is common when a notebook cell is stopped.
        tmp_path.unlink(missing_ok=True)
        raise
    print(f"Downloaded {output_path}")

def _extract_tar(tar_path, dest_dir):
    """Extract every member of ``tar_path`` into ``dest_dir`` using ``tarfile``.

    Unlike shelling out to ``tar``, this is safe for paths containing spaces or
    shell metacharacters, and extraction failures raise instead of being
    silently ignored.
    """
    import tarfile

    print(f"Extracting {tar_path}...")
    with tarfile.open(tar_path) as archive:
        if hasattr(tarfile, 'data_filter'):
            # Python 3.12+ / security backports: reject absolute paths, links
            # escaping dest_dir, device files, etc. in the downloaded archive.
            archive.extractall(dest_dir, filter='data')
        else:
            archive.extractall(dest_dir)


def ensure_data_available(data_dir):
    """Ensure all required GSE250558 files exist locally, downloading if needed.

    The feature reference and combined results files are downloaded if they
    are missing. The raw tar archive is downloaded and extracted only when no
    ``GSM*.count_matrix.h5`` files are present yet, so an already-extracted
    dataset is never re-downloaded or re-extracted.

    Parameters:
      data_dir (str or Path): Directory holding the data; created (with
        parents) if it does not exist.

    Returns:
      dict: ``feature_ref_path`` (Path), ``results_path`` (Path) and
      ``count_matrices`` (sorted list of Path).

    Raises:
      FileNotFoundError: if no count matrices exist even after extraction.
    """
    data_dir = Path(data_dir)
    data_dir.mkdir(exist_ok=True, parents=True)

    raw_tar_path = data_dir / 'GSE250558_RAW.tar'
    feature_ref_path = data_dir / 'GSE250558_feature_ref.csv.gz'
    results_path = data_dir / 'GSE250558_all_results_combined.tsv.gz'

    # Small standalone supplementary files.
    for url_key, path in (('feature_ref', feature_ref_path), ('results', results_path)):
        if not path.exists():
            download_file(DATASET_URLS[url_key], path)

    # Per-sample files (count matrices) live inside the raw tar archive.
    count_matrices = sorted(data_dir.glob('GSM*.count_matrix.h5'))
    if not count_matrices:
        print("Count matrices not found. Checking if they need to be extracted from tar file...")
        if not raw_tar_path.exists():
            download_file(DATASET_URLS['raw_data'], raw_tar_path)
        _extract_tar(raw_tar_path, data_dir)
        count_matrices = sorted(data_dir.glob('GSM*.count_matrix.h5'))
        if not count_matrices:
            raise FileNotFoundError("Count matrices not found even after extraction. Please check the tar file.")

    return {
        'feature_ref_path': feature_ref_path,
        'results_path': results_path,
        'count_matrices': count_matrices
    }

def read_10x_h5(filename):
    """Read a 10X-style HDF5 count matrix.

    Parameters:
      filename (str or Path): Path to an ``.h5`` file with a ``matrix`` group
        containing ``data``/``indices``/``indptr``/``shape``, ``barcodes`` and
        a ``features`` group.

    Returns:
      tuple: ``(matrix, barcodes, feature_df)`` where
        - ``matrix`` is a ``scipy.sparse.csc_matrix`` of shape
          ``matrix/shape``. Rows are features and columns are barcodes; the
          caller transposes it to cells x features.
        - ``barcodes`` is a list of ``str``, one per column.
        - ``feature_df`` is a DataFrame with one row per matrix row, built from
          every per-feature dataset in ``matrix/features``.
    """
    def _decode(values):
        # h5py returns fixed/variable-length strings as bytes; pass other types through.
        return [v.decode('utf-8') if isinstance(v, bytes) else v for v in values]

    with h5py.File(filename, 'r') as f:
        grp = f['matrix']
        shape = tuple(int(n) for n in grp['shape'][:])

        # Stored column-compressed: features are rows, barcodes are columns.
        matrix = sparse.csc_matrix(
            (grp['data'][:], grp['indices'][:], grp['indptr'][:]), shape=shape
        )

        barcodes = _decode(grp['barcodes'][:])

        n_features = shape[0]
        feature_dict = {}
        for key, item in grp['features'].items():
            if key == '_all_tag_keys':  # list of tag names, not a per-feature column
                continue
            # Only keep per-feature columns: skip sub-groups and datasets whose
            # length differs from the feature count (they cannot form a column).
            if not isinstance(item, h5py.Dataset) or item.shape[:1] != (n_features,):
                continue
            feature_dict[key] = _decode(item[:])

        feature_df = pd.DataFrame(feature_dict)

    return matrix, barcodes, feature_df

def read_protospacer_calls(filename):
    """Load a gzip-compressed protospacer-calls CSV and return it as a DataFrame."""
    with gzip.open(filename, mode='rt') as handle:
        calls_df = pd.read_csv(handle)
    return calls_df

def parse_sample_info(filename):
    """Parse sample metadata from a count-matrix file name.

    Names look like ``GSM7981577_Hek6h.count_matrix.h5`` or
    ``GSM7981580_Hek6h_R2_untargeted.count_matrix.h5``.

    Parameters:
      filename (str or PathLike): File path; only its basename is used.

    Returns:
      dict with keys:
        - ``time_point``: first ``_``-separated token ending in ``h`` that
          contains a digit (e.g. ``Hek6h``), or None.
        - ``replicate``: token of the form ``R<digits>``; defaults to ``'R1'``.
        - ``is_untargeted``: True if ``'untargeted'`` appears in the name.
        - ``sample_id``: basename up to the first ``.``.
        - ``sample_key``: ``<time_point>[_<replicate>][_untargeted]`` (replicate
          omitted for R1), matching the keys derived from protospacer file names.
    """
    basename = os.path.basename(filename)
    # Strip all extensions BEFORE tokenizing; otherwise the last token is glued
    # to the extension ("Hek6h.count") and is never recognised as a time point.
    sample_id = basename.split('.')[0]
    parts = sample_id.split('_')

    time_point = next(
        (part for part in parts if part.endswith('h') and any(c.isdigit() for c in part)),
        None,
    )
    replicate = next(
        (part for part in parts if part.startswith('R') and part[1:].isdigit()),
        'R1',
    )
    is_untargeted = 'untargeted' in basename

    # Format: Hek6h_R2_untargeted or Hek6h
    sample_key = time_point if time_point else ''
    if replicate != 'R1':
        sample_key += f"_{replicate}"
    if is_untargeted:
        sample_key += "_untargeted"

    return {
        'time_point': time_point,
        'replicate': replicate,
        'is_untargeted': is_untargeted,
        'sample_id': sample_id,
        'sample_key': sample_key
    }

def process_sample(count_matrix_path, feature_ref, proto_file_dict, output_dir):
    """Build, annotate and save an AnnData object for a single sample.

    Parameters:
      count_matrix_path (str or Path): ``GSM*.count_matrix.h5`` file to load.
      feature_ref (pd.DataFrame): Guide reference with at least ``id`` and
        ``target_gene_name`` columns.
      proto_file_dict (dict): Sample key (see ``parse_sample_info``) ->
        protospacer-calls CSV path.
      output_dir (str or Path): Directory to write ``<sample_id>_processed.h5ad``.

    Returns:
      str: Path of the written ``.h5ad`` file.

    Cells get ``perturbation`` (the raw ``feature_call``), ``target_gene``
    (looked up in ``feature_ref``) and ``condition`` (``Control`` for
    Non-Targeting / Safe_Cutter targets, else ``Test``). Cells without a call
    or without a known target are labelled ``Unknown``.
    """
    sample_info = parse_sample_info(count_matrix_path)
    sample_id = sample_info['sample_id']
    sample_key = sample_info['sample_key']
    print(f"Processing {sample_id} (key: {sample_key})...")

    proto_file = proto_file_dict.get(sample_key)

    matrix, barcodes, feature_df = read_10x_h5(count_matrix_path)
    adata = ad.AnnData(
        X=matrix.T,  # file stores features x barcodes; transpose to cells x features
        obs=pd.DataFrame(index=barcodes),
        var=feature_df.set_index('id')
    )

    # Per-sample metadata, stored as strings for h5ad friendliness.
    adata.obs['time_point'] = str(sample_info['time_point'])
    adata.obs['replicate'] = str(sample_info['replicate'])
    adata.obs['is_untargeted'] = str(sample_info['is_untargeted'])
    adata.obs['sample_id'] = str(sample_id)

    if proto_file is None:
        print(f"Warning: No protospacer calls file found for {sample_id}. Creating AnnData without perturbation info.")
        for column in ('perturbation', 'target_gene', 'condition'):
            adata.obs[column] = 'Unknown'
    else:
        proto_calls = read_protospacer_calls(proto_file)
        guide_to_target = dict(zip(feature_ref['id'], feature_ref['target_gene_name']))

        # barcode -> feature_call; if a barcode is listed twice the last row wins.
        call_by_barcode = (
            proto_calls.drop_duplicates('cell_barcode', keep='last')
            .set_index('cell_barcode')['feature_call']
        )
        # Vectorized lookups (no per-row Python loop). Unmatched barcodes and
        # unmatched guides become NaN and are then labelled 'Unknown'.
        # NOTE(review): a feature_call that combines several guides in one
        # string will not match a single guide id and ends up 'Unknown' — confirm
        # whether multi-guide cells need dedicated handling.
        perturbation = adata.obs_names.to_series().map(call_by_barcode)
        target_gene = perturbation.map(guide_to_target)
        adata.obs['perturbation'] = perturbation.fillna('Unknown').to_numpy()
        adata.obs['target_gene'] = target_gene.fillna('Unknown').to_numpy()

        is_control = adata.obs['target_gene'].isin(['Non-Targeting', 'Safe_Cutter'])
        adata.obs['condition'] = np.where(is_control, 'Control', 'Test')

    # Harmonized dataset-level metadata
    adata.obs['organism'] = 'Homo sapiens'
    adata.obs['cell_type'] = 'HEK293'
    adata.obs['crispr_type'] = 'CRISPR KO'
    adata.obs['cancer_type'] = 'Non-Cancer'
    adata.obs['perturbation_name'] = adata.obs['target_gene']

    output_file = os.path.join(output_dir, f"{sample_id}_processed.h5ad")
    adata.write_h5ad(output_file)
    print(f"Saved {sample_id} to {output_file}")

    return output_file

def _build_proto_file_dict(data_dir_path):
    """Map sample keys (e.g. ``Hek6h_R2_untargeted``) to protospacer-calls files in ``data_dir_path``."""
    proto_file_dict = {}
    # Sorted so that, should two files yield the same key, the winner is deterministic.
    for proto_file in sorted(data_dir_path.glob('GSM*.protospacer_calls_per_cell.csv.gz')):
        file_name = proto_file.name
        # Format: GSM7981577_Hek6h.protospacer_calls_per_cell.csv.gz
        # or: GSM7981577_Hek6h_untargeted.protospacer_calls_per_cell.csv.gz
        # Keep the part before the first '.', then drop the leading GSM accession.
        stem = file_name.split('.')[0]
        sample_key = '_'.join(stem.split('_')[1:])
        proto_file_dict[sample_key] = proto_file
        print(f"Mapped protospacer file {file_name} to sample key {sample_key}")
    return proto_file_dict


def _collect_sample_metadata(processed_files):
    """Return a list of one-row obs DataFrames (plus count columns), one per readable sample file."""
    metadata_list = []
    for file_path in processed_files:
        adata = None
        try:
            # Backed mode keeps X on disk; only obs/var are needed here.
            adata = ad.read_h5ad(file_path, backed='r')
            sample_obs = adata.obs.head(1).copy()
            sample_obs['n_cells'] = adata.n_obs
            sample_obs['n_genes'] = adata.n_vars
            if 'perturbation_name' in adata.obs.columns:
                sample_obs['n_perturbed_cells'] = int(
                    (adata.obs['perturbation_name'] != 'Unknown').sum()
                )
            sample_obs['file_path'] = file_path
            metadata_list.append(sample_obs)
        except Exception as e:  # best effort: report the bad file and keep going
            print(f"Error processing {file_path}: {e}")
        finally:
            # Backed AnnData holds an open HDF5 handle; `del` alone does not close it.
            if adata is not None and adata.isbacked:
                adata.file.close()
    return metadata_list


def _print_sample_summary(metadata_df):
    """Print totals, time-point/replicate counts and perturbation coverage per sample."""
    print("\nSample summary:")
    print(f"Total number of samples: {len(metadata_df)}")
    print(f"Total number of cells: {metadata_df['n_cells'].sum()}")
    print("\nTime points:")
    print(metadata_df['time_point'].value_counts())
    print("\nReplicates:")
    print(metadata_df['replicate'].value_counts())
    if 'n_perturbed_cells' in metadata_df.columns:
        # Judge coverage by the number of resolved cells, not by the label of
        # whichever cell happens to come first in the sample.
        with_calls = metadata_df[metadata_df['n_perturbed_cells'] > 0]
        if len(with_calls) > 0:
            print("\nSamples with perturbation information:")
            print(with_calls[['sample_id', 'n_cells', 'n_perturbed_cells']])


def run_harmonization(data_dir, output_file=None):
    """
    Run the harmonization pipeline in a Jupyter Notebook.

    Downloads/extracts GSE250558 if needed, writes one processed ``.h5ad`` per
    sample to ``<data_dir>/processed_samples``, a per-sample summary CSV, and a
    combined ``.h5ad`` whose cell names are suffixed with their sample id.

    Parameters:
      data_dir (str): Directory where the data will be stored and processed.
      output_file (str, optional): Final combined output file path. Defaults to
        ``<data_dir>/GSE250558_harmonized.h5ad``.
    """
    print(f"Processing GSE250558 dataset from {data_dir}")
    if output_file is None:
        output_file = os.path.join(data_dir, 'GSE250558_harmonized.h5ad')
    print(f"Combined output will be saved to {output_file}")

    # Fetch data first: this creates data_dir (with parents), which must exist
    # before the processed_samples subdirectory can be created.
    file_paths = ensure_data_available(data_dir)

    data_dir_path = Path(data_dir)
    output_dir = data_dir_path / 'processed_samples'
    output_dir.mkdir(exist_ok=True, parents=True)

    with gzip.open(file_paths['feature_ref_path'], 'rt') as f:
        feature_ref = pd.read_csv(f)

    proto_file_dict = _build_proto_file_dict(data_dir_path)

    processed_files = []
    sample_ids = []
    for count_matrix_path in sorted(file_paths['count_matrices']):
        processed_files.append(
            process_sample(count_matrix_path, feature_ref, proto_file_dict, output_dir)
        )
        sample_ids.append(parse_sample_info(count_matrix_path)['sample_id'])

    print(f"\nProcessed {len(processed_files)} samples.")
    print("Creating a summary of the processed samples...")

    summary_file = os.path.join(data_dir, 'GSE250558_summary.csv')
    metadata_list = _collect_sample_metadata(processed_files)
    metadata_df = None
    if metadata_list:
        metadata_df = pd.concat(metadata_list)
        metadata_df.to_csv(summary_file)
        print(f"Summary saved to {summary_file}")

    print("\nCombining all processed samples into one AnnData object...")
    if processed_files:
        # Combine exactly this run's outputs (not stale files left in
        # processed_samples). obs_names are raw barcodes with no sample prefix
        # and can collide across samples, so suffix each with its sample id.
        combined = ad.concat(
            [ad.read_h5ad(f) for f in processed_files],
            join='outer',
            merge='same',
            keys=sample_ids,
            index_unique='-',
        )
        combined.write_h5ad(output_file)
        print(f"Combined AnnData saved to {output_file}")
    else:
        print("No processed sample files found to combine.")

    print("\nIndividual sample files have been saved in the 'processed_samples' directory.")
    print("To work with the data, you can load individual samples or use the combined dataset as needed.")

    if metadata_df is not None:
        _print_sample_summary(metadata_df)

    return

# --- Run the harmonization pipeline ---
# Set this to the directory where the GSE250558 data should be downloaded and
# processed (for example "./data"); the harmonized output is written there too.
data_directory = "/content/GSE250558"
run_harmonization(data_dir=data_directory)


In [None]:
import anndata as ad
import pandas as pd

# Load the combined, harmonized AnnData object (adjust the path if needed).
adata = ad.read_h5ad("/content/GSE250558/GSE250558_harmonized.h5ad")

# Drop cells whose perturbation could not be resolved.
known_mask = adata.obs['perturbation_name'] != "Unknown"
adata = adata[known_mask].copy()

# A categorical column only accepts values that are already categories, so the
# unified control label must be registered before any cell is relabelled.
pert_col = adata.obs['perturbation_name']
if isinstance(pert_col.dtype, pd.CategoricalDtype) and "Non-targeting" not in pert_col.cat.categories:
    adata.obs['perturbation_name'] = pert_col.cat.add_categories("Non-targeting")

# Merge both control guide classes into a single "Non-targeting" label ...
control_labels = ["Non-Targeting", "Safe_Cutter"]
to_relabel = adata.obs['perturbation_name'].isin(control_labels)
adata.obs.loc[to_relabel, "perturbation_name"] = "Non-targeting"

# ... and mark every such cell as a Control.
is_non_targeting = adata.obs['perturbation_name'] == "Non-targeting"
adata.obs.loc[is_non_targeting, "condition"] = "Control"

# Report the filtered cell count and how many of those are controls.
print("Total cells after filtering unknowns:", adata.n_obs)
control_count = is_non_targeting.sum()
print("Number of control cells (Non-targeting):", control_count)
