# In [None]:
import os
import sys
import gc
import gzip
import urllib.request
import tarfile
import pandas as pd
import numpy as np
import scanpy as sc
from scipy import sparse
from scipy.io import mmread
from pathlib import Path

def download_and_extract_data(data_dir):
    """
    Download and extract the GSE247599 dataset if not already present.

    Args:
        data_dir: Directory under which a "GSE247599" sub-directory is
            created (if needed) to hold the archive and its contents.

    Returns:
        Path to the "GSE247599" directory.

    Raises:
        ValueError: If the archive contains a member that would be written
            outside the target directory (only checked manually on Python
            versions without tarfile extraction filters).
        FileNotFoundError: If expected files are still missing after the
            archive has been extracted.
    """
    data_path = Path(data_dir) / "GSE247599"
    data_path.mkdir(parents=True, exist_ok=True)

    # Files we expect after extraction: three files for each of three samples
    sample_prefixes = (
        "GSM7897841_JKLAT-LRA-Positive",
        "GSM7897842_JKLAT-LRA-Negative",
        "GSM7897843_JKLAT-NoDrug",
    )
    file_suffixes = ("barcodes.tsv.gz", "features.tsv.gz", "matrix.mtx.gz")
    expected_files = [
        f"{prefix}_{suffix}"
        for prefix in sample_prefixes
        for suffix in file_suffixes
    ]

    def missing_files():
        return [name for name in expected_files if not (data_path / name).exists()]

    if not missing_files():
        return data_path

    print("Downloading GSE247599 dataset...")
    tar_path = data_path / "GSE247599_RAW.tar"

    if not tar_path.exists():
        url = "https://www.ncbi.nlm.nih.gov/geo/download/?acc=GSE247599&format=file"
        # Download under a temporary name and rename on success, so an
        # interrupted transfer never leaves a truncated archive that later
        # runs would mistake for a complete one.
        partial_path = tar_path.with_name(tar_path.name + ".part")
        try:
            urllib.request.urlretrieve(url, partial_path)
            partial_path.replace(tar_path)
        finally:
            if partial_path.exists():
                partial_path.unlink()

    print("Extracting files...")
    with tarfile.open(tar_path, "r") as tar:
        if hasattr(tarfile, "data_filter"):
            # The "data" filter rejects absolute paths, "..", and special files.
            tar.extractall(path=data_path, filter="data")
        else:
            # Older Python without extraction filters: basic path-escape guard.
            root = data_path.resolve()
            for member in tar.getmembers():
                target = (root / member.name).resolve()
                if target != root and root not in target.parents:
                    raise ValueError(
                        f"Refusing to extract {member.name!r} outside {root}"
                    )
            tar.extractall(path=data_path)

    still_missing = missing_files()
    if still_missing:
        raise FileNotFoundError(
            f"Files missing from {data_path} after extraction: {still_missing}"
        )

    return data_path

def get_guide_to_gene_mapping(guide_features):
    """
    Create a mapping from guide names to target genes.
    For example, if 'HSPD0000005747_CCNT1' is the feature name, we map that to 'CCNT1'.

    Args:
        guide_features: DataFrame with a 'name' column holding guide names.

    Returns:
        Dict mapping each guide name to the text after its last underscore,
        or to the name itself when it contains no underscore.
    """
    # rsplit with maxsplit=1 yields the part after the last underscore, or
    # the whole name when there is none, so no separate branch is needed.
    return {
        guide_name: guide_name.rsplit('_', 1)[-1]
        for guide_name in guide_features['name']
    }

# Separator placed between target names when a cell carries several guides.
_PERTURBATION_SEPARATOR = " + "

# Labels that mark a cell as unperturbed. 'AAVS1' is listed next to
# 'HScontrol_AAVS1' because perturbation names are built from guide targets
# (the text after the last underscore of the guide name), so the full guide
# name alone would never match.
_CONTROL_LABELS = frozenset({
    'None',
    'NegativeControl1',
    'NegativeControl2',
    'NegativeControl3',
    'HScontrol_AAVS1',
    'AAVS1',
})


def _label_perturbations(guide_counts, guide_targets):
    """Return one perturbation label per row of a cells x guides sparse count matrix."""
    guide_counts = guide_counts.tocsr()
    targets = np.asarray(guide_targets, dtype=object)
    labels = ["None"] * guide_counts.shape[0]

    indptr = guide_counts.indptr
    indices = guide_counts.indices
    data = guide_counts.data
    for cell, (start, stop) in enumerate(zip(indptr[:-1], indptr[1:])):
        if start == stop:
            continue
        # Stored entries may include explicit zeros; keep positive counts only.
        hit_columns = indices[start:stop][data[start:stop] > 0]
        if hit_columns.size:
            labels[cell] = _PERTURBATION_SEPARATOR.join(sorted(set(targets[hit_columns])))
    return labels


def _classify_condition(pert_name):
    """Return 'Control' when every component of the label is a control label, else 'Test'."""
    components = pert_name.split(_PERTURBATION_SEPARATOR)
    if all(component in _CONTROL_LABELS for component in components):
        return 'Control'
    return 'Test'


def process_sample(data_path, sample_id, condition_info, output_dir):
    """
    Process a single sample and save it as an intermediate h5ad file.

    Args:
        data_path: Path of the directory holding the
            "<sample_id>_{matrix.mtx,features.tsv,barcodes.tsv}.gz" files.
        sample_id: File-name prefix of the sample.
        condition_info: Dict that may provide 'stimulation', 'gfp_status' and
            'condition'; missing keys are recorded as 'NA'.
        output_dir: Path of the directory receiving "<sample_id>.h5ad".

    Returns:
        Path of the written (or already existing) h5ad file.

    Raises:
        ValueError: If the matrix shape does not match the number of
            barcodes and features.
    """
    print(f"Processing {sample_id}...")
    output_file = output_dir / f"{sample_id}.h5ad"

    # Skip if already processed
    if output_file.exists():
        print(f"  {sample_id} already processed, skipping...")
        return output_file

    # File paths
    matrix_file = data_path / f"{sample_id}_matrix.mtx.gz"
    features_file = data_path / f"{sample_id}_features.tsv.gz"
    barcodes_file = data_path / f"{sample_id}_barcodes.tsv.gz"

    # Read features to separate gene expression and CRISPR guides
    with gzip.open(features_file, 'rt') as f:
        features_df = pd.read_csv(f, sep='\t', header=None, names=['id', 'name', 'feature_type'])

    gene_mask = (features_df['feature_type'] == 'Gene Expression').to_numpy()
    guide_mask = (features_df['feature_type'] == 'CRISPR Guide Capture').to_numpy()
    gene_features = features_df[gene_mask].reset_index(drop=True)
    guide_features = features_df[guide_mask].reset_index(drop=True)

    # Read barcodes and verify uniqueness
    with gzip.open(barcodes_file, 'rt') as f:
        barcodes = pd.read_csv(f, sep='\t', header=None)[0].values

    if len(set(barcodes)) != len(barcodes):
        print("Warning: Duplicate barcodes found in the barcodes file!")

    # Read the matrix. The file stores features x cells; transpose before
    # converting so the result is CSR with rows=cell-barcodes, columns=features.
    print(f"  Reading matrix for {sample_id}...")
    with gzip.open(matrix_file, 'rb') as f:
        X = mmread(f).transpose().tocsr()

    expected_shape = (len(barcodes), len(features_df))
    if X.shape != expected_shape:
        raise ValueError(
            f"Matrix for {sample_id} has shape {X.shape} after transposing, "
            f"expected {expected_shape} (barcodes x features)"
        )

    # Extract the gene-expression columns
    gene_X = X[:, np.flatnonzero(gene_mask)]

    # Build AnnData for gene expression
    print(f"  Creating AnnData object for {sample_id}...")
    adata = sc.AnnData(X=gene_X)
    adata.obs_names = pd.Index(barcodes)
    adata.var_names = pd.Index(gene_features['name'].values)
    adata.var['gene_ids'] = gene_features['id'].values

    # Make var names unique if needed
    adata.var_names_make_unique()

    # Add standardized metadata
    print(f"  Adding metadata for {sample_id}...")
    adata.obs['organism'] = "Homo sapiens"         # from the experiment description
    adata.obs['cell_type'] = "Jurkat T Cells"       # known from the publication
    adata.obs['crispr_type'] = "CRISPR KO"          # from the text, they used Cas9-based perturbation
    adata.obs['cancer_type'] = "T-cell leukemia"    # Jurkat is a T-leukemia line
    adata.obs['sample_id'] = sample_id

    # For user-defined conditions
    adata.obs['stimulation'] = condition_info.get('stimulation', 'NA')
    adata.obs['gfp_status'] = condition_info.get('gfp_status', 'NA')
    adata.obs['sample_condition'] = condition_info.get('condition', 'NA')

    # Parse CRISPR guides
    if guide_mask.any():
        print(f"  Processing CRISPR guide data for {sample_id}...")
        guide_to_gene = get_guide_to_gene_mapping(guide_features)

        # Target name for each guide column, in column order
        guide_targets = [guide_to_gene.get(name, name) for name in guide_features['name']]

        # Work on the sparse rows directly instead of densifying batches
        guide_counts = X[:, np.flatnonzero(guide_mask)]
        adata.obs['perturbation_name'] = _label_perturbations(guide_counts, guide_targets)
        adata.uns['guide_targets'] = guide_to_gene
        del guide_counts
    else:
        adata.obs['perturbation_name'] = "None"

    # Determine whether a cell is "Control" or "Test" based on the 'perturbation_name'
    adata.obs['condition'] = adata.obs['perturbation_name'].apply(_classify_condition)

    # Save the processed sample. Write under a temporary name first: the
    # skip-if-exists check above would otherwise reuse a half-written file
    # left behind by an interrupted run.
    print(f"  Saving {sample_id} to {output_file}")
    partial_file = output_file.with_name(f"{sample_id}.partial.h5ad")
    adata.write_h5ad(partial_file)
    partial_file.replace(output_file)

    # Cleanup
    del X, gene_X, adata
    gc.collect()

    return output_file

def combine_samples(sample_files, output_file):
    """
    Combine multiple h5ad files into a single h5ad file.

    Args:
        sample_files: Iterable of paths to per-sample h5ad files.
        output_file: Path the combined h5ad file is written to.

    Returns:
        `output_file`, unchanged.

    Raises:
        ValueError: If `sample_files` is empty.
    """
    sample_files = list(sample_files)
    if not sample_files:
        raise ValueError("No sample files were provided to combine.")

    print("Combining samples...")
    adata_list = []
    for sf in sample_files:
        print(f"  Reading {sf}")
        adata_list.append(sc.read_h5ad(sf))

    print("  Concatenating samples...")
    # uns_merge='same' keeps .uns entries that are identical in every sample;
    # with the default (None) all .uns content would be dropped.
    adata_combined = sc.concat(
        adata_list, join='outer', merge='same', uns_merge='same'
    )

    # The same barcode can appear in more than one sample; duplicated
    # observation names make label-based indexing ambiguous.
    if not adata_combined.obs_names.is_unique:
        print("  Warning: duplicate cell barcodes across samples; making observation names unique.")
        adata_combined.obs_names_make_unique()

    print(f"  Saving combined dataset to {output_file}")
    adata_combined.write_h5ad(output_file)

    print(f"Dataset saved to {output_file}")
    print(f"Final dataset shape: {adata_combined.shape}")
    print(f"Number of cells: {adata_combined.n_obs}")
    print(f"Number of genes: {adata_combined.n_vars}")

    return output_file

def filter_combined_dataset(combined_file, min_genes=200, min_counts=500, output_file=None):
    """
    Filter the combined AnnData object based on QC metrics:
      - Retains cells with at least `min_genes` expressed genes
      - Retains cells with at least `min_counts` total counts

    Args:
        combined_file: Path (str or Path) of the combined h5ad file.
        min_genes: Minimum number of features with a value > 0 per cell.
        min_counts: Minimum sum of values per cell.
        output_file: Optional destination for the filtered file. Defaults to
            "GSE247599_harmonized_filtered.h5ad" next to `combined_file`.

    Returns:
        Path of the filtered h5ad file.
    """
    print("Filtering combined dataset for QC metrics...")
    # Accept plain strings as well as Path objects
    combined_file = Path(combined_file)
    if output_file is None:
        filtered_file = combined_file.parent / "GSE247599_harmonized_filtered.h5ad"
    else:
        filtered_file = Path(output_file)

    adata = sc.read_h5ad(combined_file)

    # Compute QC metrics (np.ravel flattens the per-row sums to 1-D)
    n_counts = np.ravel(adata.X.sum(axis=1))
    n_genes = np.ravel((adata.X > 0).sum(axis=1))
    adata.obs["n_counts"] = n_counts
    adata.obs["n_genes"] = n_genes

    # Apply the filters
    keep_cells = (n_genes >= min_genes) & (n_counts >= min_counts)
    filtered_adata = adata[keep_cells].copy()

    print("Original shape:", adata.shape)
    print("Filtered shape:", filtered_adata.shape)

    filtered_adata.write_h5ad(filtered_file)
    print("Filtered dataset saved to", filtered_file)
    return filtered_file

def harmonize_dataset(data_dir):
    """
    Run the whole pipeline for one data directory:
      1. Fetch and unpack the raw files when they are not there yet
      2. Turn every sample into an intermediate h5ad file
      3. Merge the per-sample files into one h5ad
      4. QC-filter the merged dataset

    Returns the path reported by the filtering step.
    """
    # Annotations attached to each sample, keyed by the sample's file prefix
    sample_annotations = {
        "GSM7897841_JKLAT-LRA-Positive": dict(
            condition="Stimulated", gfp_status="GFP+", stimulation="PMA/I"
        ),
        "GSM7897842_JKLAT-LRA-Negative": dict(
            condition="Stimulated", gfp_status="GFP-", stimulation="PMA/I"
        ),
        "GSM7897843_JKLAT-NoDrug": dict(
            condition="Unstimulated", gfp_status="NA", stimulation="None"
        ),
    }

    raw_dir = download_and_extract_data(data_dir)

    # Per-sample h5ad files are kept in their own sub-directory
    intermediate_dir = raw_dir / "processed"
    intermediate_dir.mkdir(exist_ok=True)

    per_sample_files = [
        process_sample(raw_dir, name, annotation, intermediate_dir)
        for name, annotation in sample_annotations.items()
    ]

    merged_path = raw_dir / "GSE247599_harmonized.h5ad"
    combine_samples(per_sample_files, merged_path)

    return filter_combined_dataset(merged_path, min_genes=200, min_counts=500)

if __name__ == "__main__":
    # Usage: python script.py /path/to/data
    # When no argument is given, the data directory defaults to "/content".
    data_dir = sys.argv[1] if len(sys.argv) > 1 else "/content"

    final_file = harmonize_dataset(data_dir)
    print("Harmonization and filtering complete! File saved at:", final_file)
