In [4]:
import os
import sys
import gzip
import shutil
import requests
import numpy as np
import pandas as pd
import scanpy as sc
from scipy import sparse
from pathlib import Path
from anndata import AnnData

# GEO series accession processed by this script.
DATASET_ID = "GSE225807"
# URL of the series' RAW supplementary tar archive on the GEO FTP server.
DATASET_URL = "https://ftp.ncbi.nlm.nih.gov/geo/series/GSE225nnn/GSE225807/suppl/GSE225807_RAW.tar"
# Logical role -> file name expected in the data directory after extraction.
FILE_NAMES = dict(
    barcodes="GSM7056649_barcodes.tsv.gz",
    features="GSM7056649_features.tsv.gz",
    matrix="GSM7056649_matrix.mtx.gz",
    mu_matrix="GSM7056649_mu_matrix.tsv.gz",
    sgrna_mapping="GSM7056650_bc_to_sgrna_mapping_3_06.csv.gz",
)

def _safe_extractall(tar, dest):
    """Extract every member of ``tar`` into ``dest``, refusing paths that escape it.

    Uses the stdlib ``"data"`` extraction filter when this Python provides it
    (3.12+, and the security backports of 3.8-3.11). Otherwise it checks the
    member names itself before extracting. That fallback only checks names and
    does not inspect link targets.

    Raises:
        tarfile.TarError: if a member would be written outside ``dest``.
    """
    import tarfile

    if hasattr(tarfile, "data_filter"):
        tar.extractall(path=dest, filter="data")
        return

    root = Path(dest).resolve()
    for member in tar.getmembers():
        target = (root / member.name).resolve()
        # Reject absolute names and "../" components that resolve outside dest.
        if target != root and root not in target.parents:
            raise tarfile.TarError(
                f"Refusing to extract {member.name!r} outside {root}"
            )
    tar.extractall(path=dest)


def download_dataset(output_dir):
    """Download and unpack the GEO RAW archive unless its files are already present.

    Args:
        output_dir: Directory (str or path-like) that receives the
            ``<DATASET_ID>_RAW.tar`` archive and its extracted contents.
            It is created, with parents, if missing.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        tarfile.TarError: if the archive cannot be read or contains a member
            that would be written outside ``output_dir``.
    """
    import tarfile

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    tar_path = output_dir / f"{DATASET_ID}_RAW.tar"

    # Nothing to do when every file the loader needs is already on disk.
    if all((output_dir / name).exists() for name in FILE_NAMES.values()):
        print(f"All files already exist in {output_dir}. Skipping download.")
        return

    if not tar_path.exists():
        print(f"Downloading {DATASET_ID} dataset...")
        # Stream into a sibling ".part" file and rename it only once the whole
        # body has arrived. An interrupted download therefore never leaves a
        # truncated archive at tar_path that later runs would trust.
        partial_path = tar_path.with_name(tar_path.name + ".part")
        # The timeout stops a stalled connection from hanging the run forever.
        # The context manager releases the connection even on error.
        with requests.get(DATASET_URL, stream=True, timeout=60) as response:
            response.raise_for_status()
            with open(partial_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        partial_path.replace(tar_path)
        print(f"Download complete: {tar_path}")

    print(f"Extracting files from {tar_path}...")
    with tarfile.open(tar_path, 'r') as tar:
        # The archive comes from the network, so never extract it unchecked.
        _safe_extractall(tar, output_dir)
    print("Extraction complete.")

def _strip_barcode_suffix(barcode):
    """Return the part of ``barcode`` before its first '-' (e.g. 'AAAC-1' -> 'AAAC')."""
    return barcode.split('-')[0]


def _parse_ragged_table(lines, sep):
    """Build a DataFrame from delimited text lines whose rows may differ in width.

    The first line is the header. Blank lines are skipped. Fields beyond the
    header get names ``col<i>``. Every row is padded with empty strings to the
    final width. Only the line terminator is removed before splitting, so an
    empty leading field does not shift the remaining columns left.
    """
    header = [name.strip() for name in lines[0].rstrip('\r\n').split(sep)]
    rows = [
        [field.strip() for field in line.rstrip('\r\n').split(sep)]
        for line in lines[1:]
        if line.strip()
    ]
    width = max([len(header)] + [len(row) for row in rows])
    header.extend(f'col{i}' for i in range(len(header), width))
    padded = [row + [''] * (width - len(row)) for row in rows]
    return pd.DataFrame(padded, columns=header)


def _read_sgrna_mapping(sgrna_path):
    """Read the barcode-to-sgRNA mapping table and echo its first lines.

    The separator is tab unless the header line has no tab but does contain
    commas; the file carries a ``.csv.gz`` extension. When pandas rejects the
    file (for example, rows with more fields than the header), it is parsed by
    hand with ``_parse_ragged_table``.

    Raises:
        ValueError: if the file is empty.
    """
    with gzip.open(sgrna_path, 'rt') as f:
        lines = f.readlines()

    print("First few lines of sgRNA mapping file:")
    for i, line in enumerate(lines[:5]):
        print(f"Line {i}: {line.strip()}")

    if not lines:
        raise ValueError(f"sgRNA mapping file is empty: {sgrna_path}")

    header_line = lines[0]
    sep = ',' if ('\t' not in header_line and ',' in header_line) else '\t'

    try:
        sgrna_df = pd.read_csv(sgrna_path, compression='gzip', sep=sep)
        print(f"Successfully read sgRNA mapping file with pandas. Shape: {sgrna_df.shape}")
    except pd.errors.ParserError as e:
        print(f"Error reading with pandas: {e}")
        sgrna_df = _parse_ragged_table(lines, sep)
        print(f"Successfully parsed sgRNA mapping file manually. Shape: {sgrna_df.shape}")
    return sgrna_df


def _identify_mapping_columns(sgrna_df):
    """Return ``(barcode_col, sgrna_col, rbp_col)`` for the mapping table.

    Columns are first matched by name (case-insensitive substrings; a later
    match overrides an earlier one). Any role still unassigned falls back to
    position: barcode = 2nd, sgRNA = 3rd, RBP = 4th column.

    Raises:
        ValueError: if a role cannot be assigned, or two roles resolve to the
            same column.
    """
    barcode_col = sgrna_col = rbp_col = None
    for col in sgrna_df.columns:
        name = str(col).lower()  # headers are not guaranteed to be strings
        if 'barcode' in name:
            barcode_col = col
        elif 'sgrna' in name or 'guide' in name:
            sgrna_col = col
        elif 'rbp' in name or 'protein' in name or 'target' in name:
            rbp_col = col

    n_cols = len(sgrna_df.columns)
    if barcode_col is None and n_cols >= 3:
        # Assume the first column is an index and the barcode comes next.
        barcode_col = sgrna_df.columns[1]
    if sgrna_col is None and n_cols >= 3:
        sgrna_col = sgrna_df.columns[2]
    if rbp_col is None and n_cols >= 4:
        rbp_col = sgrna_df.columns[3]

    chosen = (barcode_col, sgrna_col, rbp_col)
    # A positional guess can land on a column already claimed by name.
    if None in chosen or len(set(chosen)) < 3:
        print("First few rows of sgRNA mapping file:")
        print(sgrna_df.head())
        raise ValueError("Could not identify barcode, sgRNA, or RBP columns in sgRNA mapping file.")
    return chosen


def _attach_mu_matrix(adata, mu_path):
    """Copy every column of the mu_matrix table into ``adata.obs``.

    The table is read with its first column as the row index. Rows are matched
    to ``adata.obs`` by that index; presumably these are cell barcodes (TODO:
    confirm the file's orientation). Cells missing from the table get NaN.
    Nothing is added when no row label matches a cell. Existing ``obs``
    columns with the same name are overwritten.
    """
    if not mu_path.exists():
        return
    print("Loading latent variables from mu_matrix...")
    mu_df = pd.read_csv(mu_path, sep='\t', index_col=0, compression='gzip')
    mu_df.index = mu_df.index.astype(str)  # obs names are strings
    # reindex() refuses duplicate labels; keep the first row per label.
    mu_df = mu_df[~mu_df.index.duplicated(keep='first')]
    if not mu_df.index.isin(adata.obs.index).any():
        return
    # One vectorised alignment instead of a .loc lookup per cell and column.
    aligned = mu_df.reindex(adata.obs.index)
    for col in aligned.columns:
        adata.obs[col] = aligned[col].to_numpy()


def load_data(data_dir):
    """Load counts, gene features, sgRNA assignments and mu_matrix into AnnData.

    Args:
        data_dir: Directory (str or path-like) containing every file listed in
            ``FILE_NAMES``.

    Returns:
        AnnData with cells as rows. ``var`` is indexed by gene name (made
        unique) and has ``gene_id`` and ``feature_type`` columns. ``obs`` has
        ``sgRNA`` and ``RBP`` columns, set to 'unknown' for cells absent from
        the mapping, plus one column per mu_matrix column.

    Raises:
        FileNotFoundError: if any required file is missing.
        ValueError: if the sgRNA mapping is empty or its barcode/sgRNA/RBP
            columns cannot be identified.
    """
    data_dir = Path(data_dir)

    for file_name in FILE_NAMES.values():
        file_path = data_dir / file_name
        if not file_path.exists():
            raise FileNotFoundError(f"Required file not found: {file_path}")

    print("Loading gene expression matrix...")
    mtx_path = data_dir / FILE_NAMES["matrix"]
    counts = sc.read_mtx(str(mtx_path)).X.T  # transpose so cells are rows

    # One cell barcode per line; ignore blank lines (e.g. a trailing newline)
    # so the barcode count matches the matrix.
    barcodes_path = data_dir / FILE_NAMES["barcodes"]
    with gzip.open(barcodes_path, 'rt') as f:
        barcodes = [line.strip() for line in f if line.strip()]

    features_path = data_dir / FILE_NAMES["features"]
    features_df = pd.read_csv(
        features_path, sep='\t', header=None,
        names=['gene_id', 'gene_name', 'feature_type'], compression='gzip',
    )

    sgrna_df = _read_sgrna_mapping(data_dir / FILE_NAMES["sgrna_mapping"])
    print(f"sgRNA mapping DataFrame shape: {sgrna_df.shape}")
    print(f"sgRNA mapping columns: {sgrna_df.columns.tolist()}")

    adata = AnnData(X=counts,
                    obs=pd.DataFrame(index=barcodes),
                    var=pd.DataFrame(index=features_df['gene_name']))
    adata.var['gene_id'] = features_df['gene_id'].values
    adata.var['feature_type'] = features_df['feature_type'].values
    # Gene symbols can repeat; duplicates get "-1", "-2", ... suffixes so that
    # name-based indexing (adata[:, "GENE"]) stays unambiguous.
    adata.var_names_make_unique()

    barcode_col, sgrna_col, rbp_col = _identify_mapping_columns(sgrna_df)
    print(f"Using columns: barcode={barcode_col}, sgRNA={sgrna_col}, RBP={rbp_col}")

    # Match on the barcode core (text before the first '-') on BOTH sides, so
    # that suffixed and unsuffixed barcodes still line up. When a barcode
    # appears more than once in the mapping, the last row wins.
    mapping_cores = [_strip_barcode_suffix(bc) for bc in sgrna_df[barcode_col].astype(str)]
    barcode_to_sgrna = dict(zip(mapping_cores, sgrna_df[sgrna_col]))
    barcode_to_rbp = dict(zip(mapping_cores, sgrna_df[rbp_col]))

    cell_cores = [_strip_barcode_suffix(bc) for bc in adata.obs.index]
    adata.obs['sgRNA'] = [barcode_to_sgrna.get(bc, 'unknown') for bc in cell_cores]
    adata.obs['RBP'] = [barcode_to_rbp.get(bc, 'unknown') for bc in cell_cores]

    print(f"Cells with known sgRNA: {(adata.obs['sgRNA'] != 'unknown').sum()} out of {adata.n_obs}")
    print(f"Cells with known RBP: {(adata.obs['RBP'] != 'unknown').sum()} out of {adata.n_obs}")

    _attach_mu_matrix(adata, data_dir / FILE_NAMES["mu_matrix"])

    print(f"Loaded data with {adata.n_obs} cells and {adata.n_vars} genes.")
    return adata

def harmonize_data(adata):
    """Add standardized annotation columns to ``adata`` and return it.

    Adds the constant columns ``organism``, ``cell_type``, ``crispr_type`` and
    ``cancer_type``. Derives ``condition`` from ``obs['RBP']``: 'control' for
    'negative', 'unknown' for 'unknown', 'test' otherwise. Derives
    ``perturbation_name`` as the RBP label with 'negative' renamed to
    'Non-targeting'. Then prints a summary and records dataset provenance in
    ``adata.uns``. ``adata`` is modified in place.
    """
    print("Harmonizing data...")

    # Sample-level constants: K562 (a leukemia cell line) under CRISPRi, per the paper.
    constant_columns = {
        'organism': 'Homo sapiens',
        'cell_type': 'K562',
        'crispr_type': 'CRISPRi',
        'cancer_type': 'Leukemia',
    }
    for column, value in constant_columns.items():
        adata.obs[column] = value

    rbp_labels = adata.obs['RBP'].tolist()

    # 'negative' marks non-targeting controls. 'unknown' is what load_data writes
    # for cells with no sgRNA mapping entry. Any other label is a knockdown.
    condition_by_rbp = {'negative': 'control', 'unknown': 'unknown'}
    adata.obs['condition'] = [condition_by_rbp.get(label, 'test') for label in rbp_labels]
    adata.obs['perturbation_name'] = [
        'Non-targeting' if label == 'negative' else label for label in rbp_labels
    ]

    condition = adata.obs['condition']
    for label, title in (('control', 'Control'), ('test', 'Test'), ('unknown', 'Unknown')):
        print(f"{title} cells: {(condition == label).sum()}")

    print("\nTop 10 perturbation targets:")
    top_targets = adata.obs['perturbation_name'].value_counts().head(10)
    for target, count in top_targets.items():
        print(f"  {target}: {count}")

    # Provenance for downstream consumers of the .h5ad.
    adata.uns['dataset_id'] = DATASET_ID
    adata.uns['dataset_description'] = "Perturb-seq data from RNA-binding proteins knockdown experiment"
    adata.uns['paper_title'] = "A Unified Framework for Systematic Identification of Post-Transcriptional Regulatory Modules"

    return adata

def main(data_dir, output_file):
    """Run the pipeline end to end and write the harmonized AnnData.

    Args:
        data_dir: Directory used to cache and read the raw GEO files.
        output_file: Destination path for the harmonized object. Missing
            parent directories are created.
    """
    raw_dir = Path(data_dir)
    download_dataset(raw_dir)
    harmonized = harmonize_data(load_data(raw_dir))

    destination = Path(output_file)
    destination.parent.mkdir(parents=True, exist_ok=True)
    print(f"Saving harmonized data to {destination}...")
    harmonized.write(destination)
    print("Done!")

# Parameters are set directly here instead of being parsed from the command line.
# NOTE(review): "dataa" looks like a typo of "data"; it is kept so that existing
# downloads in that directory are reused.
data_dir = "dataa"  # directory to store or load the dataset
output_file = "GSE225807_harmonized.h5ad"  # output file path

# Run the pipeline only when executed as a script or notebook cell, not when
# this module is imported (importing would otherwise trigger a full download).
if __name__ == "__main__":
    main(data_dir, output_file)







In [None]:

# Second pass over the harmonized file. Cells that load_data() could not match
# to any sgRNA ("unknown") are relabelled as non-targeting controls. Then
# `condition` is rebuilt so every non-targeting cell is a Control and all
# others are Test.
# NOTE(review): treating unassigned cells as non-targeting controls is an
# assumption made here, not something the sgRNA mapping establishes. Confirm.
adata = sc.read_h5ad("GSE225807_harmonized.h5ad")

# String obs columns may come back from .h5ad as categoricals. Those reject
# values that are not already categories, so work on a plain object Series.
perturbation = adata.obs["perturbation_name"].astype(object)
perturbation = perturbation.where(perturbation != "unknown", "Non-targeting")
adata.obs["perturbation_name"] = perturbation

# Derive condition from the final label. Cells that harmonize_data() already
# named "Non-targeting" (RBP "negative") then become Controls, not Tests.
adata.obs["condition"] = np.where(perturbation == "Non-targeting", "Control", "Test")

# Save the relabelled object next to the harmonized one.
adata.write("GSE225807_updated.h5ad")
# print("[INFO] Updated file saved.")
