In [None]:
import os
import pandas as pd
import numpy as np
import scanpy as sc
import anndata as ad
import scipy.sparse as sp
import re
import gzip
import urllib.request
import time
from pathlib import Path

def download_file(url, output_path, chunk_size=8192):
    """
    Download a file from a URL with progress reporting.

    The payload is streamed into a temporary ``<name>.part`` file next to
    ``output_path`` and only renamed to ``output_path`` once the transfer
    has finished. A failed or interrupted download therefore never leaves a
    truncated file at ``output_path``. This matters because callers that
    test ``output_path.exists()`` would otherwise treat a partial file as
    complete.

    Parameters:
    -----------
    url : str
        URL to download from.
    output_path : str or Path
        Path to save the downloaded file. Missing parent directories are
        created.
    chunk_size : int, optional
        Number of bytes to read per iteration.

    Raises:
    -------
    BaseException
        Whatever error occurs while opening, reading or writing (including
        ``KeyboardInterrupt``) is re-raised after the partial ``.part`` file
        has been removed.
    """
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    # Sibling temp file: output_path only ever appears once the download is complete.
    part_path = output_path.with_name(output_path.name + ".part")

    def report(done, total, started):
        """Print a single-line (carriage-return) progress update."""
        elapsed = time.time() - started
        speed = done / (1024 * 1024 * elapsed) if elapsed > 0 else 0
        if total > 0:
            progress = done / total * 100
            print(f"\rDownloading: {progress:.1f}% ({done/(1024*1024):.1f} MB / {total/(1024*1024):.1f} MB) at {speed:.2f} MB/s", end='')
        else:
            print(f"\rDownloaded: {done/(1024*1024):.1f} MB at {speed:.2f} MB/s", end='')

    start_time = time.time()
    last_report = float('-inf')
    try:
        with urllib.request.urlopen(url) as response:
            # `or 0` also covers a present-but-empty Content-Length header.
            total_size = int(response.info().get('Content-Length') or 0)
            downloaded_size = 0

            with open(part_path, 'wb') as f:
                for chunk in iter(lambda: response.read(chunk_size), b''):
                    f.write(chunk)
                    downloaded_size += len(chunk)
                    now = time.time()
                    # Printing after every small chunk floods notebook output;
                    # refresh the progress line at most twice per second.
                    if now - last_report >= 0.5:
                        report(downloaded_size, total_size, start_time)
                        last_report = now
            # Always show the final totals.
            report(downloaded_size, total_size, start_time)

        part_path.replace(output_path)
        print(f"\nDownload completed: {output_path}")
    except BaseException as e:
        # BaseException (not Exception) so that interrupting the download,
        # e.g. stopping a notebook cell, also removes the partial file.
        print(f"\nError downloading {url}: {e}")
        if part_path.exists():
            part_path.unlink()
        raise

_GSE278572_BASE_URL = "https://ftp.ncbi.nlm.nih.gov/geo/series/GSE278nnn/GSE278572/suppl/"
_GSE278572_MATRIX = "GSE278572_matrix.mtx.gz"
_GSE278572_FILES = (
    "GSE278572_barcodes.tsv.gz",
    "GSE278572_features.tsv.gz",
    _GSE278572_MATRIX,
    "GSE278572_protospacer_calls_per_cell.csv.gz",
)


def _ensure_GSE278572_files(data_dir, skip_matrix):
    """Download every required GSE278572 file that is not already in data_dir."""
    for name in _GSE278572_FILES:
        file_path = data_dir / name

        # Skip matrix file if requested
        if skip_matrix and name == _GSE278572_MATRIX:
            print(f"Skipping download of large matrix file {name} as requested.")
            continue

        if file_path.exists():
            print(f"File {name} already exists.")
        else:
            url = _GSE278572_BASE_URL + name
            print(f"File {name} not found. Downloading from {url}...")
            download_file(url, file_path)


def _read_GSE278572_expression(data_dir):
    """Read the 10x matrix, or build an all-zero placeholder AnnData if reading fails."""
    print("Reading 10X format data...")
    try:
        return sc.read_10x_mtx(
            data_dir,
            var_names='gene_symbols',
            cache=True,
            prefix="GSE278572_"
        )
    except Exception as e:
        # Deliberate best-effort: e.g. the matrix was skipped (skip_matrix=True).
        print(f"Error reading 10X data: {e}")
        print("Creating a placeholder AnnData object with perturbation data only...")

    with gzip.open(data_dir / "GSE278572_barcodes.tsv.gz", 'rt') as fh:
        barcodes = [line.strip() for line in fh]

    with gzip.open(data_dir / "GSE278572_features.tsv.gz", 'rt') as fh:
        rows = [line.strip().split('\t') for line in fh]
    # sc.read_10x_mtx defaults to gex_only=True and drops rows whose feature
    # type is not "Gene Expression" (e.g. CITE-seq antibody features). Mirror
    # that here so both code paths produce the same gene space. Rows without
    # a type column are kept.
    features = [
        row[1]  # gene symbol
        for row in rows
        if len(row) >= 2 and (len(row) < 3 or row[2] == 'Gene Expression')
    ]

    adata = ad.AnnData(
        X=sp.csr_matrix((len(barcodes), len(features)), dtype=np.float32),
        obs=pd.DataFrame(index=barcodes),
        var=pd.DataFrame(index=features)
    )
    # sc.read_10x_mtx defaults to make_unique=True; keep the placeholder consistent.
    adata.var_names_make_unique()

    print(f"Created placeholder AnnData with {len(barcodes)} cells and {len(features)} genes")
    print("WARNING: Gene expression data is not available, only perturbation information will be processed")
    return adata


def _parse_perturbation_call(call):
    """
    Summarise one '|'-separated guide call string.

    Each guide is expected to look like ``<gene>_<guide>_<type...>``. Returns
    the tuple ``(perturbation_type, target_genes, is_control, guide_ids)``.
    A guide counts as a control when its gene field contains 'Non-Targeting'.
    """
    if pd.isna(call) or call == 'Unknown':
        return ('Unknown', 'Unknown', False, 'Unknown')

    target_genes = set()
    guide_ids = []
    perturbation_types = set()
    for pert in call.split('|'):
        parts = pert.split('_')
        # NOTE(review): guides with fewer than 3 '_' fields are ignored, so a
        # call made only of such guides is classified as a control; confirm
        # this matches the dataset's feature naming.
        if len(parts) < 3:
            continue
        gene, guide = parts[0], parts[1]
        guide_ids.append(f"{gene}_{guide}")
        perturbation_types.add('_'.join(parts[2:]))  # the type may itself contain '_'
        if 'Non-Targeting' not in gene:
            target_genes.add(gene)

    return (
        '+'.join(sorted(perturbation_types)),
        '+'.join(sorted(target_genes)) if target_genes else 'Non-Targeting',
        not target_genes,  # a control only when no guide targets a gene
        '|'.join(sorted(guide_ids)),
    )


def process_GSE278572(data_dir, skip_matrix=False):
    """
    Process GSE278572 dataset and convert to h5ad format.

    Missing input files are downloaded into ``data_dir`` first. If the 10x
    matrix cannot be read, an all-zero placeholder matrix is used so that
    the perturbation annotations can still be produced.

    Parameters:
    -----------
    data_dir : str
        Path to directory containing the dataset files.
    skip_matrix : bool, optional
        If True, skip downloading the large matrix file and create a placeholder matrix.

    Returns:
    --------
    adata : AnnData
        Processed dataset in AnnData format. ``obs`` gains 'perturbation',
        'num_guides', 'perturbation_type', 'perturbation_name', 'is_control',
        'guide_ids', 'condition' and fixed metadata columns.
    """
    print(f"Processing GSE278572 dataset from {data_dir}")
    data_dir = Path(data_dir)

    _ensure_GSE278572_files(data_dir, skip_matrix)
    adata = _read_GSE278572_expression(data_dir)

    print("Reading perturbation data...")
    perturbation_df = pd.read_csv(data_dir / "GSE278572_protospacer_calls_per_cell.csv.gz")

    print("Processing perturbation data...")
    cell_to_perturbation = dict(zip(perturbation_df['cell_barcode'], perturbation_df['feature_call']))

    # Barcodes with no call, or with a missing (NaN) call, are labelled 'Unknown'.
    adata.obs['perturbation'] = (
        adata.obs.index.to_series().map(cell_to_perturbation).fillna('Unknown').to_numpy()
    )
    calls = adata.obs['perturbation']

    adata.obs['num_guides'] = calls.map(
        lambda c: 0 if c == 'Unknown' else len(str(c).split('|'))
    )

    # Parse each distinct call once instead of building a Series per cell.
    parsed = {call: _parse_perturbation_call(call) for call in calls.unique()}
    perturbation_info = pd.DataFrame(
        [parsed[call] for call in calls],
        index=adata.obs.index,
        columns=['perturbation_type', 'perturbation_name', 'is_control', 'guide_ids'],
    )
    perturbation_info['is_control'] = perturbation_info['is_control'].astype(bool)
    for column in perturbation_info.columns:
        adata.obs[column] = perturbation_info[column].to_numpy()

    # Cells without a call ('Unknown') have is_control False and so become 'Test' here.
    adata.obs['condition'] = adata.obs['is_control'].map({True: 'Control', False: 'Test'})

    # Based on the dataset description, this is a human T cell dataset
    adata.obs['organism'] = 'Homo sapiens'
    adata.obs['cell_type'] = 'CD4+ T Cells'

    # The description mentions resting/stimulated CD4+ Tregs and Teffs, but
    # no per-cell label is available here.
    adata.obs['stimulation_status'] = 'Unknown'
    adata.obs['t_cell_subtype'] = 'Unknown'

    adata.obs['crispr_type'] = 'CRISPRi'

    # Based on the dataset description, this is non-cancer
    adata.obs['cancer_type'] = 'Non-Cancer'

    adata.uns['perturbation_stats'] = {
        'total_cells': adata.n_obs,
        'control_cells': int(adata.obs['is_control'].sum()),
        'test_cells': int((~adata.obs['is_control']).sum()),
        # Counts distinct labels, including 'Unknown', 'Non-Targeting' and 'A+B' combinations.
        'unique_target_genes': int(adata.obs['perturbation_name'].nunique()),
        'crispr_types': 'CRISPRi'
    }

    # Preserve the raw counts
    adata.raw = adata.copy()

    # Add additional metadata based on the dataset description and our analysis
    adata.uns['dataset_metadata'] = {
        'accession': 'GSE278572',
        'title': 'Centralized control of dynamic gene regulatory circuits governs human T cell rest and activation [Perturb-CITE-seq]',
        'description': 'Perturb-CITE-seq performed with 2 donors including resting and stimulated CD4+ Tregs and Teffs',
        'organism': 'Homo sapiens',
        'experiment_type': 'Perturb-CITE-seq',
        'contributors': 'Arce MM, Umhoefer JM, Arang N, Kasinathan S, Freimer JW, Steinhart Z, Shen H, Pham M, Ota M, Wadhera A, Dorovskyi D, Zhou Y, Rama D, Chen Y, Liu Q, Shy BR, Satpathy AT, Carnevale J, Krogan NJ, Pritchard JK, Marson A',
        'perturbation_summary': 'The dataset contains 28 unique target genes with CRISPRi perturbations. About 73% of cells have a single guide, 20% have two guides, and the rest have multiple guides. Approximately 12% of cells are controls (non-targeting guides only).'
    }

    print(f"Processed dataset with {adata.n_obs} cells and {adata.n_vars} genes")

    return adata

def update_adata(adata):
    """
    Apply the ad hoc relabelling of cells whose perturbation is 'Unknown'.

    Rows with ``perturbation_name == 'Unknown'`` get ``perturbation_name`` set
    to ``'Non-targeting'`` and ``condition`` set to ``'Control'``.
    ``adata.obs`` is modified in place and the same object is returned.

    Parameters:
    -----------
    adata : AnnData
        Object whose ``obs`` has 'perturbation_name' and 'condition' columns.

    Returns:
    --------
    adata : AnnData
        The same object, updated in place.
    """
    for column in ('perturbation_name', 'condition'):
        # A categorical column rejects values outside its categories, so
        # convert it to plain strings before assigning new labels.
        # (isinstance check replaces the deprecated pd.api.types.is_categorical_dtype.)
        if isinstance(adata.obs[column].dtype, pd.CategoricalDtype):
            adata.obs[column] = adata.obs[column].astype(str)

    mask = adata.obs['perturbation_name'] == 'Unknown'
    adata.obs.loc[mask, 'perturbation_name'] = 'Non-targeting'
    adata.obs.loc[mask, 'condition'] = 'Control'
    # NOTE(review): no other columns (e.g. an 'is_control' flag) or uns
    # entries are updated, so they may disagree with 'condition' afterwards;
    # confirm this is intended.

    return adata

# ----- Parameters for Jupyter Notebook Execution -----
# Directory that holds (or will receive) the GSE278572 files; created below when absent.
data_dir = "/content"
# True -> the large matrix file is not downloaded and an all-zero placeholder matrix is used.
skip_matrix = False
# Destination of the processed h5ad file.
output_path = os.path.join(data_dir, "GSE278572.h5ad")

# Make sure the working directory exists before anything is downloaded into it.
Path(data_dir).mkdir(parents=True, exist_ok=True)

# Build the AnnData object, then apply the ad hoc relabelling step on top of it.
adata = update_adata(process_GSE278572(data_dir, skip_matrix=skip_matrix))
print(adata.obs)

# Persist the result.
print(f"Saving processed dataset to {output_path}")
adata.write(output_path)
print("Done!")
