# In [None]:
import os
import sys
import gzip
import pandas as pd
import numpy as np
from pathlib import Path
import subprocess
import urllib.request
import tarfile
import time
import re

# Make sure the third-party single-cell packages are importable.
# If either import raises ImportError, both packages are installed with pip
# using the interpreter that is running this code (sys.executable), and the
# imports are then retried.
# NOTE(review): the fallback runs at import time, modifies the active
# environment and presumably needs network access -- confirm this is wanted.
try:
    import anndata as ad
    import scanpy as sc
except ImportError:
    print("Installing required packages...")
    subprocess.check_call([sys.executable, "-m", "pip", "install", "anndata", "scanpy"])
    import anndata as ad
    import scanpy as sc

def _download_file(url, destination):
    """Download ``url`` to ``destination`` without leaving a truncated file.

    The payload is written to a sibling ``.part`` file first and only renamed
    to ``destination`` once the transfer has completed, so an interrupted
    download is never mistaken for a finished archive on the next run.
    """
    partial = destination.with_name(destination.name + ".part")
    try:
        urllib.request.urlretrieve(url, partial)
        partial.replace(destination)
    finally:
        # After a successful rename the .part file no longer exists; this only
        # cleans up after a failed or interrupted transfer.
        if partial.exists():
            partial.unlink()


def _extract_archive(archive_path, destination):
    """Extract a gzip-compressed tar archive into ``destination``."""
    with tarfile.open(archive_path, "r:gz") as tar:
        if hasattr(tarfile, "data_filter"):
            # Extraction filters exist on this interpreter: refuse members
            # that would escape ``destination`` or create special files.
            tar.extractall(path=destination, filter="data")
        else:
            tar.extractall(path=destination)


def download_dataset(output_dir):
    """
    Download the GSE208240 dataset if not already present.

    If the expected data directory already exists it is returned immediately
    and nothing is downloaded (this also skips the viral references).
    Otherwise the main archive is downloaded when missing and extracted, and
    the viral reference archive is fetched and extracted on a best-effort
    basis (failures there only print a warning).

    Args:
        output_dir: Directory to save the downloaded files

    Returns:
        Path to the extracted data directory

    Raises:
        RuntimeError: If the main archive cannot be downloaded or extracted.
        FileNotFoundError: If the expected data directory is still missing
            after extraction.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(exist_ok=True, parents=True)

    # Define paths
    main_dataset_path = output_dir / "GSE208240_CRISPRi_perturbseq_sarscov2_filtered.tar.gz"
    viral_ref_path = output_dir / "GSE208240_viral_references.tar.gz"
    extracted_dir = output_dir / "GSE208240_extracted"
    data_path = extracted_dir / "data" / "sunshine" / "perturb_seq" / "sars_cov_2_geo_upload" / "CRISPRi_perturbseq_sarscov2_filtered"

    # Check if we already have the extracted data
    if data_path.exists():
        print(f"Using already extracted data at {data_path}")
        return data_path

    # URLs for the dataset files
    geo_url = "https://ftp.ncbi.nlm.nih.gov/geo/series/GSE208nnn/GSE208240/suppl/GSE208240_CRISPRi_perturbseq_sarscov2_filtered.tar.gz"
    viral_ref_url = "https://ftp.ncbi.nlm.nih.gov/geo/series/GSE208nnn/GSE208240/suppl/GSE208240_viral_references.tar.gz"

    # Download the main dataset
    if not main_dataset_path.exists():
        print(f"Downloading main dataset from {geo_url}...")
        try:
            _download_file(geo_url, main_dataset_path)
        except Exception as e:
            raise RuntimeError(f"Failed to download dataset: {e}") from e

    # data_path is known to be missing at this point, so extract even when
    # extracted_dir already exists: a directory left behind by an interrupted
    # extraction must not block a retry.
    print("Extracting main dataset...")
    try:
        _extract_archive(main_dataset_path, extracted_dir)
    except Exception as e:
        raise RuntimeError(f"Failed to extract dataset: {e}") from e

    # Download viral references (best effort)
    if not viral_ref_path.exists():
        print(f"Downloading viral references from {viral_ref_url}...")
        try:
            _download_file(viral_ref_url, viral_ref_path)
        except Exception as e:
            print(f"Warning: Failed to download viral references: {e}")

    # Extract viral references if not already extracted (best effort)
    if viral_ref_path.exists() and not (output_dir / "viral_references").exists():
        print("Extracting viral references...")
        try:
            _extract_archive(viral_ref_path, output_dir)
        except Exception as e:
            print(f"Warning: Failed to extract viral references: {e}")

    # Verify the data path exists
    if not data_path.exists():
        raise FileNotFoundError(f"Data path not found after extraction: {data_path}")

    return data_path

def load_metadata(data_dir):
    """
    Read the feature table, barcode list and cell-identity table.

    Args:
        data_dir: Directory containing the 10x Genomics data files

    Returns:
        Tuple of (features, barcodes, cell_identities) DataFrames

    Raises:
        FileNotFoundError: If any of the three input files is missing.
    """
    base = Path(data_dir)

    # Fail fast on the first missing input, checked in a fixed order.
    for name in ('features.tsv.gz', 'barcodes.tsv.gz', 'cell_identities.csv'):
        candidate = base / name
        if not candidate.exists():
            raise FileNotFoundError(f"Required file not found: {candidate}")

    def read_gzipped_tsv(name, column_names):
        # Header-less, tab-separated, gzip-compressed table.
        with gzip.open(base / name, 'rt') as handle:
            return pd.read_csv(handle, sep='\t', header=None, names=column_names)

    print("Loading features...")
    features = read_gzipped_tsv('features.tsv.gz', ['gene_id', 'gene_symbol', 'feature_type'])

    print("Loading barcodes...")
    barcodes = read_gzipped_tsv('barcodes.tsv.gz', ['barcode'])

    print("Loading cell identities...")
    cell_identities = pd.read_csv(base / 'cell_identities.csv')

    return features, barcodes, cell_identities

# Guide-name prefixes (lower-cased, trailing digits removed) treated as controls.
_NON_TARGETING_TOKENS = frozenset({"nt", "ntc"})


def _split_guides(guide_identity):
    """Return the non-empty, stripped guide labels of a ';'-separated value."""
    if pd.isna(guide_identity):
        return []
    return [part.strip() for part in str(guide_identity).split(';') if part.strip()]


def _is_non_targeting_guide(guide):
    """Heuristically decide whether a single guide label is a control guide.

    The decision is made on the label's leading token (the text before the
    first underscore) rather than on a raw substring search, so gene symbols
    that merely contain "NT" (e.g. CNTN1) are not mistaken for controls.
    """
    lowered = guide.lower()
    if "non-targeting" in lowered or "scrambled" in lowered:
        return True
    token = lowered.split("_")[0].rstrip("0123456789-.")
    return token in _NON_TARGETING_TOKENS


def _extract_perturbation_name(guide_identity):
    """Build a ' + '-joined, sorted, de-duplicated target name for one cell."""
    targets = set()
    for guide in _split_guides(guide_identity):
        if _is_non_targeting_guide(guide):
            targets.add('non-targeting')
        elif '_' in guide:
            # The gene name is the part before the first underscore; labels
            # without an underscore are not recognised and are skipped.
            targets.add(guide.split('_')[0])
    if not targets:
        return "Unknown"
    return " + ".join(sorted(targets))


def _classify_targeting(guide_identity):
    """Classify one cell as 'Targeting', 'Non-targeting' or 'Unknown'."""
    guides = _split_guides(guide_identity)
    if not guides:
        return "Unknown"
    # A cell is a control only when every guide it carries is a control guide;
    # one targeting guide is enough to make the cell perturbed.
    if all(_is_non_targeting_guide(guide) for guide in guides):
        return "Non-targeting"
    return "Targeting"


def process_cell_identities(cell_identities):
    """
    Process cell identities to extract perturbation information.

    Keeps only rows whose ``good_coverage`` is True and adds three columns:
    ``perturbation_name`` (target gene names, or 'non-targeting'/'Unknown'),
    ``targeting`` ('Targeting', 'Non-targeting' or 'Unknown') and
    ``condition`` ('test', 'control' or 'unknown').

    Args:
        cell_identities: DataFrame containing cell identity information

    Returns:
        DataFrame with processed perturbation information (a filtered copy;
        the input is not modified)

    Raises:
        ValueError: If a required column is missing.
    """
    print("Processing cell identities...")

    # Check required columns
    required_columns = ['cell_barcode', 'guide_identity', 'good_coverage', 'number_of_guides']
    for col in required_columns:
        if col not in cell_identities.columns:
            raise ValueError(f"Required column not found in cell_identities: {col}")

    # Filter for cells with good coverage
    good_cells = cell_identities[cell_identities['good_coverage'].eq(True)].copy()
    print(f"Found {len(good_cells)} cells with good coverage out of {len(cell_identities)} total cells")

    # Both columns share the same control-guide rule, so they cannot disagree
    # about whether a guide is non-targeting.
    good_cells['perturbation_name'] = good_cells['guide_identity'].apply(_extract_perturbation_name)
    good_cells['targeting'] = good_cells['guide_identity'].apply(_classify_targeting)

    # Set condition based on targeting status
    good_cells['condition'] = good_cells['targeting'].map({
        "Targeting": "test",
        "Non-targeting": "control",
        "Unknown": "unknown"
    })

    # Print some statistics
    print(f"Perturbation types: {good_cells['perturbation_name'].nunique()} unique perturbations")
    print(f"Targeting distribution: {good_cells['targeting'].value_counts().to_dict()}")

    return good_cells

def _read_mtx_manually(mtx_file):
    """Parse a gzip-compressed Matrix Market coordinate file line by line.

    Used only as a fallback when ``scipy.io.mmread`` fails. Returns a
    genes x cells ``scipy.sparse.coo_matrix`` with float32 values.
    """
    import scipy.sparse

    with gzip.open(mtx_file, 'rb') as f:
        # Skip '%' header/comment lines; the first other line holds the sizes.
        size_line = None
        for raw in f:
            if not raw.startswith(b'%'):
                size_line = raw
                break
        if size_line is None:
            raise ValueError("No size line found in Matrix Market file")
        n_genes, n_cells, n_entries = map(int, size_line.decode().strip().split())

        # Pre-allocate the COO triplets
        rows = np.zeros(n_entries, dtype=np.int32)
        cols = np.zeros(n_entries, dtype=np.int32)
        data = np.zeros(n_entries, dtype=np.float32)

        for i in range(n_entries):
            if i % 1000000 == 0:
                print(f"  Loaded {i}/{n_entries} entries...")
            fields = f.readline().decode().strip().split()
            rows[i] = int(fields[0]) - 1  # 1-based to 0-based indexing
            cols[i] = int(fields[1]) - 1
            data[i] = float(fields[2])

    return scipy.sparse.coo_matrix((data, (rows, cols)), shape=(n_genes, n_cells))


def create_anndata_from_mtx(data_dir, features, barcodes, cell_identities_processed):
    """
    Create an AnnData object from the matrix.mtx.gz file.

    The matrix is read as genes x cells, transposed to cells x genes and
    stored as CSR, restricted to the barcodes listed in
    ``cell_identities_processed``, and annotated with per-cell perturbation
    metadata plus fixed dataset-level fields.

    Args:
        data_dir: Directory containing the 10x Genomics data files
        features: DataFrame containing gene information
        barcodes: DataFrame containing cell barcodes
        cell_identities_processed: DataFrame containing processed cell identity information

    Returns:
        AnnData object

    Raises:
        FileNotFoundError: If matrix.mtx.gz is missing.
        RuntimeError: If the matrix cannot be loaded by either method.
        ValueError: If the matrix shape disagrees with features/barcodes.
    """
    print("Creating AnnData object from matrix.mtx.gz...")

    # Check if matrix file exists
    mtx_file = Path(data_dir) / 'matrix.mtx.gz'
    if not mtx_file.exists():
        raise FileNotFoundError(f"Matrix file not found: {mtx_file}")

    # Import scipy here to avoid potential import issues
    import scipy.io
    import scipy.sparse

    # Load the matrix
    print("Loading matrix.mtx.gz (this may take a while)...")
    start_time = time.time()
    try:
        mat = scipy.io.mmread(str(mtx_file))
    except Exception as e:
        print(f"Error loading matrix with mmread: {e}")
        print("Trying alternative loading method...")
        try:
            mat = _read_mtx_manually(mtx_file)
        except Exception as e2:
            print(f"Error with alternative loading method: {e2}")
            raise RuntimeError(f"Failed to load matrix file: {e} / {e2}") from e2
    print(f"Matrix loaded in {time.time() - start_time:.2f} seconds")

    # Verify dimensions up front: AnnData cannot be built from mismatched
    # inputs, so fail here with a message that names the offending sizes.
    if mat.shape[0] != len(features) or mat.shape[1] != len(barcodes):
        raise ValueError(
            f"Matrix dimensions ({mat.shape}) don't match features ({len(features)}) and barcodes ({len(barcodes)})"
        )

    # Transpose first (10x matrix is genes x cells), then convert, so that X
    # really is CSR in cells x genes orientation and selecting cells is cheap.
    print("Converting to CSR format...")
    counts = scipy.sparse.csr_matrix(mat.T)

    # Create AnnData object
    print("Creating AnnData object...")
    adata = ad.AnnData(
        X=counts,
        obs=pd.DataFrame(index=barcodes['barcode']),
        var=pd.DataFrame(index=features['gene_symbol'])
    )

    # Add gene_ids as a column in var
    adata.var['gene_id'] = features['gene_id'].values

    # Ensure var_names are unique
    adata.var_names_make_unique()

    # Filter the AnnData object to include only cells with good coverage
    good_barcodes = cell_identities_processed['cell_barcode'].values
    print(f"Filtering to {len(good_barcodes)} cells with good coverage...")
    adata = adata[adata.obs.index.isin(good_barcodes)].copy()
    if adata.n_obs == 0:
        print("Warning: no barcodes in the matrix matched cell_identities; check the barcode format")

    # Map each barcode to its row position in cell_identities_processed
    # (for duplicated barcodes the last row wins).
    barcode_to_idx = {bc: i for i, bc in enumerate(cell_identities_processed['cell_barcode'])}

    # Every remaining barcode passed the isin filter above, so each one has an
    # entry in the mapping and no second filtering pass is required.
    row_positions = [barcode_to_idx[bc] for bc in adata.obs.index]
    aligned = cell_identities_processed.iloc[row_positions]

    # Add metadata to adata.obs
    print("Adding metadata to AnnData object...")
    for col in ['perturbation_name', 'condition', 'targeting', 'guide_identity', 'number_of_guides']:
        if col in cell_identities_processed.columns:
            adata.obs[col] = aligned[col].values

    # Add standardized metadata fields
    adata.obs['organism'] = 'Homo sapiens'
    adata.obs['cell_type'] = 'Lung epithelial cells'  # Calu-3 cells are lung epithelial cells
    adata.obs['crispr_type'] = 'CRISPRi'
    adata.obs['cancer_type'] = 'Lung Cancer'  # Calu-3 is a lung cancer cell line

    # Add dataset information to uns
    adata.uns['dataset'] = {
        'id': 'GSE208240',
        'title': 'Systematic functional interrogation of SARS-CoV-2 host factors using Perturb-seq',
        'description': 'CRISPRi perturbation of host factors in Calu-3 cells infected with SARS-CoV-2',
        'organism': 'Homo sapiens',
        'publication': 'Sunshine S, et al. Systematic functional interrogation of SARS-CoV-2 host factors using Perturb-seq. Nat Commun 2023.'
    }

    return adata

def run_pipeline(data_dir="."):
    """
    Run download, metadata loading, annotation and export end to end.

    Written for interactive (Jupyter) use: any exception is reported on
    stdout together with its traceback instead of propagating.

    Args:
        data_dir: Directory used for downloads and for the output .h5ad file

    Returns:
        The AnnData object on success, or None if any step failed
    """
    try:
        # Step 1: obtain the raw files (downloaded or already on disk)
        dataset_dir = download_dataset(data_dir)

        # Step 2: read the gene, barcode and cell-identity tables
        gene_table, barcode_table, identity_table = load_metadata(dataset_dir)

        # Step 3: derive per-cell perturbation labels
        annotated_cells = process_cell_identities(identity_table)

        # Step 4: assemble the annotated expression matrix
        adata = create_anndata_from_mtx(dataset_dir, gene_table, barcode_table, annotated_cells)

        # Step 5: persist the result next to the downloads
        output_file = os.path.join(data_dir, "GSE208240_harmonized.h5ad")
        print(f"Saving harmonized data to {output_file}...")
        adata.write(output_file)

        print(f"Harmonized data saved to {output_file}")
        print(f"Final data shape: {adata.shape}")
        print(f"Metadata fields: {list(adata.obs.columns)}")

        # Short report on what was produced
        print("\nSummary Statistics:")
        print(f"Number of cells: {adata.n_obs}")
        print(f"Number of genes: {adata.n_vars}")
        print(f"Number of unique perturbations: {adata.obs['perturbation_name'].nunique()}")
        print(f"Condition distribution: {adata.obs['condition'].value_counts().to_dict()}")

        return adata

    except Exception as e:
        # Keep the notebook alive: report the failure and signal it with None.
        print(f"Error processing dataset: {e}")
        import traceback
        traceback.print_exc()
        return None

# Run the pipeline as soon as this cell/module executes, using the current
# working directory for downloads and for the output file. Change data_dir as
# needed. `adata` is the resulting AnnData object, or None if run_pipeline
# reported an error.
adata = run_pipeline(data_dir=".")
