In [None]:
import os
import glob
import gzip
import tarfile
import urllib.request
from pathlib import Path
import numpy as np
import pandas as pd
import anndata as ad
from scipy import sparse
from tqdm import tqdm

# Dataset identifiers and parsing limits
GEO_ACCESSION = "GSE280767"
GEO_URL = (
    "https://www.ncbi.nlm.nih.gov/geo/download/"
    "?acc=" + GEO_ACCESSION + "&format=file"
)
# Upper bound on the number of matrix lines read from each sample's .mtx file
MAX_ENTRIES_PER_SAMPLE = 100_000_000

def download_dataset(data_dir):
    """Download and unpack the GEO raw archive unless it is already present.

    Args:
        data_dir: Directory (str or Path) holding the archive and the
            extracted files. It is created if it does not exist.

    Returns:
        ``data_dir`` as a ``Path``.

    Raises:
        ValueError: If an archive member name would resolve to a location
            outside ``data_dir``.
    """
    data_dir = Path(data_dir)
    data_dir.mkdir(exist_ok=True, parents=True)

    tar_file = data_dir / f"{GEO_ACCESSION}_RAW.tar"

    if not tar_file.exists():
        print(f"Downloading {GEO_ACCESSION} dataset...")
        # Download under a temporary name and rename on success, so an
        # interrupted transfer never leaves a truncated archive that the
        # exists() check above would accept on the next run.
        partial_file = tar_file.with_name(tar_file.name + ".part")
        try:
            urllib.request.urlretrieve(GEO_URL, partial_file)
            partial_file.replace(tar_file)
        finally:
            # Only still present when the download or rename failed.
            if partial_file.exists():
                partial_file.unlink()

    # Extract if not already extracted
    if not glob.glob(f"{data_dir}/*_barcodes.tsv.gz"):
        print("Extracting files...")
        root = data_dir.resolve()
        with tarfile.open(tar_file, 'r') as tar:
            # Refuse member names ("../x", absolute paths) that would be
            # written outside the target directory.
            for member in tar.getmembers():
                target = (root / member.name).resolve()
                if target != root and root not in target.parents:
                    raise ValueError(
                        f"Refusing to extract {member.name!r}: path escapes {data_dir}"
                    )
            tar.extractall(path=data_dir)

    return data_dir

def improved_parse_sample_metadata(sample_id):
    """Derive harmonized metadata for a sample from tokens found in its ID.

    Starts from fixed defaults, applies the first matching ``ISC_SCRNA*``
    table entry, then lets explicit tokens in the ID (editor, target, cell
    line, time point, blank control) override those values.

    Args:
        sample_id: Sample identifier string to inspect.

    Returns:
        dict with the keys ``sample_id``, ``organism``, ``cell_type``,
        ``crispr_type``, ``cancer_type``, ``condition`` and
        ``perturbation_name`` (in that order).
    """
    metadata = {
        'sample_id': sample_id,
        'organism': 'Homo sapiens',
        'cell_type': 'T Cells',
        'crispr_type': 'Unknown',
        'cancer_type': 'Non-Cancer',
        'condition': 'Unknown',
        'perturbation_name': 'Unknown'
    }

    # (ID token, condition, crispr_type, perturbation_name); first hit wins.
    known_samples = (
        # Day 3
        ('ISC_SCRNA001', 'Day 3', 'None', 'Non-targeting'),
        ('ISC_SCRNA005', 'Day 3', 'CRISPR Cas9', 'B2M'),
        ('ISC_SCRNA006', 'Day 3', 'CRISPR Cas9', 'PDCD1'),
        ('ISC_SCRNA007', 'Day 3', 'CRISPR Cas9', 'Non-targeting'),
        ('ISC_SCRNA008', 'Day 3', 'ABE', 'B2M'),
        ('ISC_SCRNA009', 'Day 3', 'ABE', 'PDCD1'),
        ('ISC_SCRNA010', 'Day 3', 'ABE', 'Non-targeting'),
        # Day 7
        ('ISC_SCRNA014', 'Day 7', 'CRISPR Cas9', 'B2M'),
        ('ISC_SCRNA015', 'Day 7', 'CRISPR Cas9', 'PDCD1'),
        ('ISC_SCRNA016', 'Day 7', 'ABE', 'B2M'),
        ('ISC_SCRNA017', 'Day 7', 'ABE', 'PDCD1'),
        # Day 21
        ('ISC_SCRNA018', 'Day 21', 'CRISPR Cas9', 'B2M'),
        ('ISC_SCRNA019', 'Day 21', 'CRISPR Cas9', 'PDCD1'),
    )
    for token, condition, editor, target in known_samples:
        if token in sample_id:
            metadata['condition'] = condition
            metadata['crispr_type'] = editor
            metadata['perturbation_name'] = target
            break

    lowered = sample_id.lower()

    # Editor named explicitly in the ID (case-sensitive tokens)
    if any(tag in sample_id for tag in ('Cas9', 'Casq')):
        metadata['crispr_type'] = 'CRISPR Cas9'
    elif 'ABE' in sample_id:
        high_fidelity = 'ABEe8' in sample_id or 'V106W' in sample_id
        metadata['crispr_type'] = 'ABE-V106W' if high_fidelity else 'ABE'

    # Perturbation target named explicitly in the ID; earlier tags take priority
    for tag, target in (('NTC', 'Non-targeting'), ('B2M', 'B2M'), ('PD1', 'PDCD1')):
        if tag in sample_id:
            metadata['perturbation_name'] = target
            break

    # Cell line (case-insensitive)
    if 'huh7' in lowered:
        metadata['cell_type'] = 'Huh7'
        metadata['cancer_type'] = 'Hepatocellular carcinoma'

    # Time point (case-insensitive); earlier tags take priority
    for tag, label in (('day3', "Day 3"), ('day7', "Day 7"), ('day21', "Day 21")):
        if tag in lowered:
            metadata['condition'] = label
            break

    # Blank control overrides both target and editor
    if 'blank' in lowered:
        metadata['perturbation_name'] = 'Non-targeting'
        metadata['crispr_type'] = 'None'

    return metadata

def load_10x_data(base_path, prefix, max_entries=MAX_ENTRIES_PER_SAMPLE):
    """Load one 10x Genomics sample (barcodes, features, matrix) into AnnData.

    Reads ``{prefix}_barcodes.tsv.gz``, ``{prefix}_features.tsv.gz`` and
    ``{prefix}_matrix.mtx.gz`` from ``base_path``. The matrix file stores
    genes as rows; the returned object has cells as rows and genes as columns.

    Args:
        base_path: Directory containing the three gzipped files.
        prefix: File name prefix identifying the sample.
        max_entries: Maximum number of matrix lines (after the header) to
            read. When the file declares more entries than were read, a
            warning is printed and the matrix is incomplete.

    Returns:
        An ``AnnData`` with float32 ``X``, barcodes as ``obs_names`` and gene
        symbols as ``var_names`` (duplicated symbols get a ``_<index>``
        suffix), or ``None`` if anything goes wrong while loading; the error
        is printed rather than raised.
    """
    from collections import Counter

    try:
        # Locate the three input files, failing with a readable message
        # instead of an IndexError when one of them is absent.
        paths = {}
        for suffix in ("barcodes.tsv.gz", "features.tsv.gz", "matrix.mtx.gz"):
            matches = glob.glob(f"{base_path}/{prefix}_{suffix}")
            if not matches:
                raise FileNotFoundError(f"{prefix}_{suffix} not found in {base_path}")
            paths[suffix] = matches[0]

        # Read barcodes
        with gzip.open(paths["barcodes.tsv.gz"], 'rt') as f:
            barcodes = [line.strip() for line in f]

        # Read features: column 0 is the gene ID, column 1 the gene symbol
        with gzip.open(paths["features.tsv.gz"], 'rt') as f:
            features_data = [line.strip().split('\t') for line in f]
        gene_ids = [row[0] for row in features_data]
        gene_symbols = [row[1] for row in features_data]

        data = []
        row_indices = []
        col_indices = []
        with gzip.open(paths["matrix.mtx.gz"], 'rt') as f:
            # Skip the '%' comment lines; the first other line holds the sizes
            header = next(f)
            while header.startswith('%'):
                header = next(f)
            n_genes, n_cells, n_entries = map(int, header.strip().split())

            for i, line in enumerate(f):
                if i >= max_entries:
                    break
                fields = line.split()
                if not fields:
                    continue
                gene_idx, cell_idx, value = fields
                row_indices.append(int(gene_idx) - 1)  # 1-based -> 0-based
                col_indices.append(int(cell_idx) - 1)  # 1-based -> 0-based
                data.append(float(value))

        if len(data) < n_entries:
            # Make truncation visible instead of silently returning a
            # partial count matrix.
            print(
                f"Warning: {prefix}: read {len(data)} of {n_entries} declared "
                f"matrix entries (max_entries={max_entries}); counts are incomplete"
            )

        # Build the values as float32 up front rather than passing the
        # deprecated ``dtype`` keyword to AnnData.
        matrix = sparse.csr_matrix(
            (np.asarray(data, dtype=np.float32), (row_indices, col_indices)),
            shape=(n_genes, n_cells)
        )

        # Transpose to get cells as rows, genes as columns
        matrix = matrix.T

        # Make gene symbols unique. Counting once keeps this linear in the
        # number of genes (list.count per element is quadratic).
        symbol_counts = Counter(gene_symbols)
        unique_symbols = [
            f"{s}_{i}" if symbol_counts[s] > 1 else s
            for i, s in enumerate(gene_symbols)
        ]
        var_df = pd.DataFrame({'gene_ids': gene_ids}, index=pd.Index(unique_symbols))

        adata = ad.AnnData(
            X=matrix,
            obs=pd.DataFrame(index=barcodes),
            var=var_df,
        )

        # Make observation names unique
        adata.obs_names_make_unique()

        return adata

    except Exception as e:
        # Best effort by design: the caller skips samples that return None.
        print(f"Error loading data for {prefix}: {e}")
        return None

def process_dataset(data_dir, max_samples=None, max_entries=MAX_ENTRIES_PER_SAMPLE):
    """Build the harmonized h5ad file from the raw per-sample 10x files.

    Downloads/extracts the archive if needed, loads every sample found in
    ``data_dir``, annotates each with metadata parsed from its file prefix,
    concatenates them on shared genes, drops cells whose perturbation is
    unknown and writes ``{GEO_ACCESSION}_harmonized_improved.h5ad``.

    Args:
        data_dir: Working directory (str or Path) for raw and output files.
        max_samples: If not ``None``, only the first ``max_samples`` samples
            (in sorted file-name order) are processed.
        max_entries: Passed through to ``load_10x_data``.

    Returns:
        The combined ``AnnData``, or ``None`` when no sample could be loaded.
    """
    data_dir = Path(data_dir)

    # Download and extract dataset if needed
    data_dir = download_dataset(data_dir)

    # Sort so sample order, and therefore the ``max_samples`` subset, is
    # reproducible; glob returns files in arbitrary order.
    sample_files = sorted(glob.glob(f"{data_dir}/*_barcodes.tsv.gz"))
    sample_prefixes = [os.path.basename(f).replace('_barcodes.tsv.gz', '') for f in sample_files]

    # Limit the number of samples if specified
    if max_samples is not None:
        sample_prefixes = sample_prefixes[:max_samples]

    print(f"Processing {len(sample_prefixes)} samples")

    all_samples = []
    sample_keys = []
    for prefix in tqdm(sample_prefixes):
        print(f"Processing {prefix}...")

        adata = load_10x_data(data_dir, prefix, max_entries)
        if adata is None:
            continue

        # Add metadata using improved parser
        metadata = improved_parse_sample_metadata(prefix)
        adata.obs['sample_id'] = metadata['sample_id']

        # Add harmonized metadata fields
        for key, value in metadata.items():
            if key != 'sample_id':  # Already added
                adata.obs[key] = value

        all_samples.append(adata)
        # Record the key here instead of reading obs['sample_id'].iloc[0]
        # later, which raises on a sample that contains no cells.
        sample_keys.append(metadata['sample_id'])

    if not all_samples:
        print("No samples were successfully processed.")
        return None

    print("Combining all samples...")

    # Concatenate all samples directly
    print("Concatenating all samples...")
    combined = ad.concat(
        all_samples,
        join='inner',  # Only keep genes that are in all datasets
        merge='same',  # Only merge when attributes are the same
        label='batch',  # Add batch annotation
        keys=sample_keys
    )

    print(f"Number of genes after concatenation: {combined.n_vars}")

    # Ensure required metadata fields are present
    required_fields = ['organism', 'cell_type', 'crispr_type', 'cancer_type', 'condition', 'perturbation_name']
    for field in required_fields:
        if field not in combined.obs.columns:
            combined.obs[field] = 'Unknown'

    # ---- Filtering out cells with perturbation_name "Unknown" (or typo "Unknonw") ----
    invalid_names = ['Unknown', 'Unknonw']
    initial_cell_count = combined.n_obs
    # Subsetting yields a view; copy so the obs assignments below modify a
    # real object rather than relying on an implicit view-to-copy conversion.
    combined = combined[~combined.obs['perturbation_name'].isin(invalid_names)].copy()
    filtered_cell_count = combined.n_obs
    print(f"Filtered out {initial_cell_count - filtered_cell_count} cells with perturbation_name in {invalid_names}")
    # --------------------------------------------------------------------------

    # Convert metadata fields to categorical
    for field in required_fields:
        combined.obs[field] = combined.obs[field].astype('category')

    # Make observation names unique
    combined.obs_names_make_unique()

    # Save the combined dataset
    output_file = data_dir / f"{GEO_ACCESSION}_harmonized_improved.h5ad"
    print(f"Saving harmonized dataset to {output_file}")
    combined.write(output_file)

    # Print summary
    print("\nDataset Summary:")
    print(f"Total cells: {combined.n_obs}")
    print(f"Total genes: {combined.n_vars}")
    print("\nMetadata fields:")
    for field in required_fields:
        categories = combined.obs[field].cat.categories
        print(f"- {field}: {list(categories)}")
        print(f"  Categories ({len(categories)}, object): {list(categories)}")

    return combined

def update_existing_h5ad(h5ad_path):
    """Re-derive sample metadata in an existing h5ad file and save a new copy.

    Metadata is recomputed per unique ``sample_id`` with
    ``improved_parse_sample_metadata``, cells whose perturbation is unknown
    are dropped, and the result is written next to the input with an
    ``_improved`` suffix before the file extension.

    Args:
        h5ad_path: Path to the input file, as ``str`` or ``pathlib.Path``.

    Returns:
        The updated ``AnnData`` object.
    """
    # Normalise to Path: callers pass both str and Path, and the string-style
    # ``path.replace(old, new)`` fails on a Path (Path.replace is a rename).
    h5ad_path = Path(h5ad_path)

    print(f"Loading existing h5ad file: {h5ad_path}")
    adata = ad.read_h5ad(h5ad_path)

    print("Updating sample metadata...")
    # Parse each distinct sample ID once
    sample_metadata = {
        sample_id: improved_parse_sample_metadata(sample_id)
        for sample_id in adata.obs['sample_id'].unique()
    }

    # Update metadata for each cell based on its sample_id
    fields_to_update = ['crispr_type', 'condition', 'perturbation_name', 'cell_type', 'cancer_type']
    # 'organism' is cast to category below; fill it in when the input file
    # lacks the column so that cast cannot raise KeyError.
    if 'organism' not in adata.obs.columns:
        fields_to_update.append('organism')
    for field in fields_to_update:
        adata.obs[field] = adata.obs['sample_id'].map(lambda x: sample_metadata[x][field])

    # ---- Filtering out cells with perturbation_name "Unknown" (or typo "Unknonw") ----
    invalid_names = ['Unknown', 'Unknonw']
    initial_cell_count = adata.n_obs
    # Copy the filtered view so the obs assignments below act on a real object.
    adata = adata[~adata.obs['perturbation_name'].isin(invalid_names)].copy()
    filtered_cell_count = adata.n_obs
    print(f"Filtered out {initial_cell_count - filtered_cell_count} cells with perturbation_name in {invalid_names}")
    # ------------------------------------------------------------------------------

    # Convert categorical columns
    for field in ['organism', 'cell_type', 'crispr_type', 'cancer_type', 'condition', 'perturbation_name']:
        adata.obs[field] = adata.obs[field].astype('category')

    # Print summary statistics of updated metadata
    print("\nUpdated Dataset Summary:")
    print(f"Total cells: {adata.n_obs}")
    print(f"Total genes: {adata.n_vars}")

    print("\nMetadata field distributions:")
    for field in ['perturbation_name', 'condition', 'crispr_type', 'cell_type']:
        print(f"\n{field} distribution:")
        print(adata.obs[field].value_counts())

    # Save the updated dataset as "<name>_improved<ext>" beside the input
    output_path = h5ad_path.with_name(f"{h5ad_path.stem}_improved{h5ad_path.suffix}")
    print(f"\nSaving updated dataset to {output_path}")
    adata.write(output_path)

    return adata

# Main execution: refresh an existing harmonized file if there is one,
# otherwise build the dataset from the raw GEO files.
if __name__ == "__main__":
    data_dir = "./GSE280767"  # Change if you prefer a different directory

    # Check if we already have the h5ad file
    existing_h5ad = Path(f"{data_dir}/{GEO_ACCESSION}_harmonized.h5ad")

    if existing_h5ad.exists():
        print("Found existing h5ad file. Updating metadata...")
        # Hand over a plain string path; a str works however the helper
        # derives its output file name.
        updated_adata = update_existing_h5ad(str(existing_h5ad))
        adata = updated_adata
    else:
        # Process the entire dataset from raw files
        max_samples = None  # Set a specific number to limit samples if needed
        combined_dataset = process_dataset(data_dir, max_samples, MAX_ENTRIES_PER_SAMPLE)
        adata = combined_dataset

    # ``adata`` is the name the following notebook cell operates on; it is
    # None when process_dataset could not load any sample.


In [None]:
# Normalise non-targeting control labels on `adata` (which must already be
# bound by an earlier cell) and write the final compressed file.

# Work on plain object columns while editing: assigning a label that is not
# already one of a categorical column's categories (e.g. "Control") raises.
for column in ("perturbation_name", "condition"):
    if isinstance(adata.obs[column].dtype, pd.CategoricalDtype):
        adata.obs[column] = adata.obs[column].astype(object)

# Identify cells where perturbation_name is "Non-targeting control"
nt_control_mask = adata.obs["perturbation_name"] == "Non-targeting control"
num_nt_control = nt_control_mask.sum()

# Update the perturbation_name and condition columns for these cells
adata.obs.loc[nt_control_mask, "perturbation_name"] = "Non-targeting"
adata.obs.loc[nt_control_mask, "condition"] = "Control"
print(f"Updated {num_nt_control} cells from 'Non-targeting control' to 'Non-targeting' and set their condition to 'Control'.")

# Additionally, ensure that any cell with perturbation_name "Non-targeting" has condition set to "Control"
nt_mask = adata.obs["perturbation_name"] == "Non-targeting"
adata.obs.loc[nt_mask, "condition"] = "Control"

# Store both columns as categoricals, rebuilt from the final set of labels
for column in ("perturbation_name", "condition"):
    adata.obs[column] = adata.obs[column].astype("category")

adata.write_h5ad("/content/GSE280767.h5ad", compression="gzip")