# In [None]:
import os
import sys
import gzip
import shutil
import tarfile
import urllib.request
from pathlib import Path
import numpy as np
import pandas as pd
import scanpy as sc
import anndata as ad
from scipy import sparse
import h5py
from tqdm import tqdm

# GEO series identifier and the location of its bundled supplementary archive.
GEO_ACCESSION = "GSE197452"
DOWNLOAD_URL = "https://ftp.ncbi.nlm.nih.gov/geo/series/GSE197nnn/GSE197452/suppl/GSE197452_RAW.tar"

# Annotation shared by the three PBMC samples. Each dataset entry below takes
# its own copy, so no two entries ever share one mutable dict. Key order is
# significant: it is the order in which the columns are added to ``obs``.
_PBMC_CONTROL_METADATA = {
    "organism": "Homo sapiens",
    "cell_type": "PBMC",
    "crispr_type": "None",
    "cancer_type": "Non-Cancer",
    "condition": "control",
    "perturbation_name": "None",
}

# One entry per sample to harmonize. ``files`` holds either the three
# text-matrix parts (expression / genes / cells) or a single ``h5`` file.
DATASETS = [
    {
        "name": "3prime_PBMC_Illumina",
        "files": {
            "expression": "GSM6297378_expression_counts_Three_Ill.txt.gz",
            "genes": "GSM6297378_genes_counts_Three_Ill.txt.gz",
            "cells": "GSM6297378_cells_counts_Three_Ill.txt.gz",
        },
        "metadata": dict(_PBMC_CONTROL_METADATA),
    },
    {
        "name": "5prime_PBMC_Illumina",
        "files": {
            "expression": "GSM6297380_expression_counts_Five_Ill.txt.gz",
            "genes": "GSM6297380_genes_counts_Five_Ill.txt.gz",
            "cells": "GSM6297380_cells_counts_Five_Ill.txt.gz",
        },
        "metadata": dict(_PBMC_CONTROL_METADATA),
    },
    {
        "name": "5prime_PBMC_mixture_Illumina",
        "files": {
            "expression": "GSM6297382_expression_counts_FiveMix_Ill.txt.gz",
            "genes": "GSM6297382_genes_counts_FiveMix_Ill.txt.gz",
            "cells": "GSM6297382_cells_counts_FiveMix_Ill.txt.gz",
        },
        "metadata": dict(_PBMC_CONTROL_METADATA),
    },
    {
        # Perturbation-related columns for this sample are filled in per cell
        # during harmonization, so only the static fields are listed here.
        "name": "Perturb-seq_Illumina",
        "files": {
            "h5": "GSM6297388_filtered_feature_bc_matrix.pert.ill.h5",
        },
        "metadata": {
            "organism": "Homo sapiens",
            "cell_type": "K562",
            "cancer_type": "Leukemia",
        },
    },
]

def download_data(data_dir):
    """Download the GEO archive into ``<data_dir>/download`` unless it is already there.

    The archive is written to a ``.part`` file next to the final path and only
    renamed once the transfer has finished. An interrupted or failed download
    therefore never leaves a truncated archive that a later run would mistake
    for a complete one.

    Args:
        data_dir: Base directory; a ``download`` subdirectory is created in it.

    Returns:
        Path of the tar archive (freshly downloaded or already present).

    Raises:
        Whatever ``urllib.request.urlretrieve`` or the filesystem calls raise;
        the partial file is removed before the exception propagates.
    """
    download_dir = os.path.join(data_dir, "download")
    os.makedirs(download_dir, exist_ok=True)

    tar_file = os.path.join(download_dir, f"{GEO_ACCESSION}_RAW.tar")

    if os.path.exists(tar_file):
        print(f"Dataset already downloaded at {tar_file}")
        return tar_file

    print(f"Downloading {GEO_ACCESSION} dataset...")
    partial_file = tar_file + ".part"
    try:
        urllib.request.urlretrieve(DOWNLOAD_URL, partial_file)
        # Rename within the same directory so the final name appears atomically.
        os.replace(partial_file, tar_file)
    except BaseException:
        # BaseException so that Ctrl-C also cleans up the incomplete file.
        if os.path.exists(partial_file):
            os.remove(partial_file)
        raise
    print("Download complete!")

    return tar_file

def extract_data(tar_file, data_dir):
    """Extract the archive into ``<data_dir>/extracted`` unless already extracted.

    A non-empty ``extracted`` directory is taken to mean the work was done on
    a previous run. To keep that assumption valid, the directory is removed
    again if extraction fails part-way.

    Args:
        tar_file: Path of the tar archive to unpack.
        data_dir: Base directory; an ``extracted`` subdirectory is created in it.

    Returns:
        Path of the directory holding the extracted files.

    Raises:
        tarfile.TarError: If the archive is unreadable or, where extraction
            filters are supported, contains a member rejected by the ``data``
            filter (e.g. a path escaping the destination).
        OSError: On filesystem errors.
    """
    extract_dir = os.path.join(data_dir, "extracted")

    if os.path.isdir(extract_dir) and os.listdir(extract_dir):
        print(f"Data already extracted in {extract_dir}")
        return extract_dir

    os.makedirs(extract_dir, exist_ok=True)

    print(f"Extracting files to {extract_dir}...")
    try:
        with tarfile.open(tar_file, 'r') as tar:
            if hasattr(tarfile, "data_filter"):
                # Refuse absolute paths, "../" escapes, and special files.
                tar.extractall(path=extract_dir, filter="data")
            else:
                # Older interpreters without extraction-filter support.
                tar.extractall(path=extract_dir)
    except BaseException:
        # Do not leave a half-filled directory that the check above would
        # later accept as a finished extraction.
        shutil.rmtree(extract_dir, ignore_errors=True)
        raise

    print("Extraction complete!")
    return extract_dir

def read_mtx_file(expression_file, genes_file, cells_file):
    """Read a gzipped coordinate-format count matrix plus gene/cell tables.

    The expression file is parsed as text: lines starting with ``%`` are
    comments, the first remaining line gives ``<genes> <cells> <entries>``,
    and every later line is a 1-based ``<gene> <cell> <value>`` triple.

    Args:
        expression_file: Path to the gzipped matrix text file.
        genes_file: Tab-separated, header-less file read into the columns
            ``gene_id``, ``gene_name``, ``feature_type``.
        cells_file: Tab-separated, header-less file whose first column holds
            the cell barcodes.

    Returns:
        AnnData with cells as rows and genes as columns. ``var`` is indexed by
        gene name; repeated names after the first occurrence get
        ``_<gene_id>`` appended, the same scheme ``read_h5_file`` uses.

    Raises:
        ValueError: If the matrix file has no dimension line, or the number of
            barcodes differs from the matrix's cell dimension.
    """
    print(f"Reading expression data from {expression_file}")

    with gzip.open(expression_file, 'rt') as f:
        # Skip comment/blank lines until the dimension line is found.
        for line in f:
            if line.startswith('%') or not line.strip():
                continue
            n_genes, n_cells, n_entries = map(int, line.split())
            break
        else:
            raise ValueError(f"No dimension line found in {expression_file}")

        data = []
        row_indices = []
        col_indices = []

        with tqdm(total=n_entries, desc="Reading expression data") as pbar:
            for line in f:
                fields = line.split()
                if not fields:
                    # Tolerate blank or trailing empty lines.
                    continue
                gene_idx, cell_idx, value = fields
                row_indices.append(int(gene_idx) - 1)  # 0-based indexing
                col_indices.append(int(cell_idx) - 1)  # 0-based indexing
                data.append(float(value))
                pbar.update(1)

    matrix = sparse.csr_matrix((data, (row_indices, col_indices)), shape=(n_genes, n_cells))

    print(f"Reading genes metadata from {genes_file}")
    genes_df = pd.read_csv(genes_file, sep='\t', header=None, names=['gene_id', 'gene_name', 'feature_type'])

    print(f"Reading cells metadata from {cells_file}")
    cell_barcodes = pd.read_csv(cells_file, sep='\t', header=None)[0].values

    if len(cell_barcodes) != n_cells:
        # Fail here with a specific message rather than inside AnnData.
        raise ValueError(
            f"Number of cell barcodes ({len(cell_barcodes)}) doesn't match "
            f"matrix dimensions ({n_cells})"
        )

    n_listed = genes_df.shape[0]
    if n_listed != n_genes:
        print(f"Warning: Number of genes in metadata ({n_listed}) doesn't match matrix dimensions ({n_genes})")
        if n_listed < n_genes:
            print(f"Adding {n_genes - n_listed} missing genes to metadata")
            placeholders = [f'Unknown_{i}' for i in range(n_listed, n_genes)]
            missing_genes = pd.DataFrame({
                'gene_id': placeholders,
                'gene_name': placeholders,
                'feature_type': ['Gene Expression'] * (n_genes - n_listed)
            })
            genes_df = pd.concat([genes_df, missing_genes], ignore_index=True)
        else:
            print("Truncating gene metadata to match matrix dimensions")
            genes_df = genes_df.iloc[:n_genes]

    var_df = genes_df.copy()
    var_df.index = var_df['gene_name']  # Use gene_name as index

    if not var_df.index.is_unique:
        repeated = var_df.index.duplicated(keep='first')
        print(f"Warning: Found {repeated.sum()} duplicate gene symbols")
        names = var_df['gene_name'].to_numpy(dtype=object).copy()
        gene_ids = var_df['gene_id'].to_numpy(dtype=object)
        names[repeated] = [
            f"{name}_{gene_id}"
            for name, gene_id in zip(names[repeated], gene_ids[repeated])
        ]
        # Use an unnamed index: it no longer equals the 'gene_name' column,
        # and an index name that clashes with a differing column is rejected
        # when the object is written to disk.
        var_df.index = pd.Index(names)
        print(f"After making unique: {var_df.index.duplicated().sum()} duplicates")

    adata = ad.AnnData(X=matrix.T,  # cells x genes
                       obs=pd.DataFrame(index=cell_barcodes),
                       var=var_df)

    return adata

def _guide_to_perturbation(guide):
    """Map a guide feature name to the perturbation label stored in ``obs``."""
    # Guides that indicate a non-targeting control are labelled "Non-targeting".
    if ('NON-GENE_SITE' in guide) or guide.startswith('NO_SITE') or (guide == 'Background'):
        return 'Non-targeting'
    # Otherwise use the text before the first underscore, when there is one
    # and it is non-empty; fall back to the full guide name.
    prefix, separator, _ = guide.partition('_')
    return prefix if separator and prefix else guide


def read_h5_file(h5_file):
    """Read expression data from a 10x Genomics H5 file.

    Reads the sparse matrix stored under ``matrix/`` (``shape``, ``data``,
    ``indices``, ``indptr``) together with ``matrix/features/{id,name,
    feature_type}`` and ``matrix/barcodes``.

    Args:
        h5_file: Path to the H5 file.

    Returns:
        AnnData with cells as rows and the features typed ``'Gene Expression'``
        as columns. ``var`` is indexed by feature name (repeated names after
        the first get ``_<gene_id>`` appended) and has the columns
        ``feature_type`` and ``gene_id``. If features typed
        ``'CRISPR Guide Capture'`` exist, ``obs`` gains:

        * ``guide``: the guide with the highest count in the cell (the first
          one on ties); NaN for cells whose guide counts sum to zero.
        * ``perturbation_name``: label derived from ``guide``; NaN likewise.
    """
    print(f"Reading data from {h5_file}")

    with h5py.File(h5_file, 'r') as f:
        shape = tuple(int(dim) for dim in f['matrix/shape'][:])

        data = f['matrix/data'][:]
        indices = f['matrix/indices'][:]
        indptr = f['matrix/indptr'][:]

        # Stored column-compressed (features x barcodes); CSR makes the
        # feature-row selections below cheap.
        matrix = sparse.csc_matrix((data, indices, indptr), shape=shape).tocsr()

        # Convert byte strings to regular strings
        feature_ids = np.array([x.decode('utf-8') for x in f['matrix/features/id'][:]])
        feature_names = np.array([x.decode('utf-8') for x in f['matrix/features/name'][:]])
        feature_types = np.array([x.decode('utf-8') for x in f['matrix/features/feature_type'][:]])

        barcodes = [x.decode('utf-8') for x in f['matrix/barcodes'][:]]

    gene_indices = np.where(feature_types == 'Gene Expression')[0]
    gene_matrix = matrix[gene_indices, :]

    gene_var = pd.DataFrame({
        'feature_type': feature_types[gene_indices],
        'gene_id': feature_ids[gene_indices],
    }, index=feature_names[gene_indices])

    if not gene_var.index.is_unique:
        print(f"Warning: Found {gene_var.index.duplicated().sum()} duplicate gene symbols")
        new_index = []
        seen = set()
        for name, gene_id in zip(gene_var.index, gene_var['gene_id']):
            if name in seen:
                new_index.append(f"{name}_{gene_id}")
            else:
                new_index.append(name)
                seen.add(name)
        gene_var.index = new_index
        print(f"After making unique: {gene_var.index.duplicated().sum()} duplicates")

    adata = ad.AnnData(X=gene_matrix.T,  # cells x genes
                       obs=pd.DataFrame(index=barcodes),
                       var=gene_var)

    # Extract CRISPR guide information
    guide_indices = np.where(feature_types == 'CRISPR Guide Capture')[0]

    if len(guide_indices) > 0:
        guide_matrix = matrix[guide_indices, :]
        guide_names = feature_names[guide_indices]

        # Per cell (column): total guide counts and the row of the largest
        # count, computed on the sparse matrix rather than one cell at a time.
        guide_totals = np.asarray(guide_matrix.sum(axis=0)).ravel()
        top_guide = np.asarray(guide_matrix.argmax(axis=0)).ravel()
        has_guide = guide_totals > 0

        # Cells without any guide counts stay NaN.
        guide_column = np.full(len(barcodes), np.nan, dtype=object)
        guide_column[has_guide] = guide_names[top_guide[has_guide]]
        adata.obs['guide'] = guide_column

        perturbation_column = np.full(len(barcodes), np.nan, dtype=object)
        perturbation_column[has_guide] = [
            _guide_to_perturbation(guide) for guide in guide_column[has_guide]
        ]
        adata.obs['perturbation_name'] = perturbation_column

    return adata

def harmonize_dataset(dataset, extract_dir, output_dir, accession):
    """Load one dataset, annotate and filter its cells, and write it as ``.h5ad``.

    Args:
        dataset: Entry with ``name``, ``files`` (either an ``h5`` key or the
            ``expression``/``genes``/``cells`` keys) and ``metadata`` (values
            copied into every cell's ``obs``).
        extract_dir: Directory containing the files named in ``dataset['files']``.
        output_dir: Directory for the output file; created if missing.
        accession: Prefix of the output file name.

    Returns:
        The AnnData object that was written, i.e. after cells whose
        ``perturbation_name`` is "None" (case-insensitive) have been removed.
    """
    print(f"\nProcessing {dataset['name']}...")

    if 'h5' in dataset['files']:
        h5_file = os.path.join(extract_dir, dataset['files']['h5'])
        adata = read_h5_file(h5_file)
    else:
        expression_file = os.path.join(extract_dir, dataset['files']['expression'])
        genes_file = os.path.join(extract_dir, dataset['files']['genes'])
        cells_file = os.path.join(extract_dir, dataset['files']['cells'])
        adata = read_mtx_file(expression_file, genes_file, cells_file)

    # Add standard metadata to each cell
    for key, value in dataset['metadata'].items():
        adata.obs[key] = value

    # Process perturbation information for Perturb-seq data
    if 'guide' in adata.obs.columns:
        unassigned = adata.obs['guide'].isna().to_numpy()
        non_targeting = (adata.obs['perturbation_name'] == 'Non-targeting').to_numpy()

        # Cells without a guide: no perturbation, no CRISPR type, unknown
        # condition. Non-targeting guides are controls; the rest are test cells.
        adata.obs['crispr_type'] = np.where(unassigned, 'None', 'CRISPR KO')
        adata.obs['condition'] = np.where(
            unassigned, 'unknown', np.where(non_targeting, 'control', 'test')
        )
        adata.obs['perturbation_name'] = (
            adata.obs['perturbation_name'].astype(object).where(~unassigned, 'None')
        )
    else:
        adata.obs['perturbation_name'] = 'None'
        adata.obs['crispr_type'] = 'None'
        adata.obs['condition'] = 'control'

    print(f"  Checking var_names: {adata.var_names[:5]}")

    # var_names are taken to be symbols only when a gene_id column exists and
    # differs from them; compare values (not Index vs Series objects).
    names_are_ids = 'gene_id' in adata.var.columns and np.array_equal(
        adata.var_names.astype(str).to_numpy(),
        adata.var['gene_id'].astype(str).to_numpy(),
    )
    if 'gene_id' in adata.var.columns and not names_are_ids:
        print("  var_names are already gene symbols")
    else:
        print("  Warning: var_names might not be gene symbols, please check")

    # Exclude cells with a perturbation name of "None" (case-insensitive)
    # NOTE(review): datasets without guide data get perturbation_name "None"
    # for every cell above, so this leaves them empty - confirm that is intended.
    initial_n = adata.n_obs
    adata = adata[~adata.obs['perturbation_name'].str.lower().eq("none")].copy()
    print(f"Excluded {initial_n - adata.n_obs} cells with perturbation name 'None'")

    # Per-cell details are printed only if cells remain
    print("Dataset summary:")
    print(f"  Cells: {adata.n_obs}")
    print(f"  Genes: {adata.n_vars}")
    if adata.n_obs > 0:
        print(f"  Organism: {adata.obs['organism'].iloc[0]}")
        print(f"  Cell type: {adata.obs['cell_type'].iloc[0]}")
        if 'perturbation_name' in adata.obs.columns:
            perturbed_cells = adata.obs[adata.obs['perturbation_name'] != 'None'].shape[0]
            print(f"  Perturbed cells: {perturbed_cells}")
            if perturbed_cells > 0:
                top_perturbations = adata.obs['perturbation_name'].value_counts().head(5).to_dict()
                print(f"  Top perturbations: {top_perturbations}")
    else:
        print("  No cells remain after filtering.")

    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(output_dir, f"{accession}_{dataset['name']}.h5ad")
    print(f"Saving harmonized data to {output_file}")
    adata.write(output_file)

    return adata

def main(data_dir):
    """Run the whole pipeline under ``data_dir``: download, extract, harmonize.

    Harmonized files are written to the ``processed`` subdirectory of
    ``data_dir``, one per entry in ``DATASETS``.
    """
    processed_dir = os.path.join(data_dir, "processed")
    os.makedirs(processed_dir, exist_ok=True)

    archive_path = download_data(data_dir)
    unpacked_dir = extract_data(archive_path, data_dir)

    for entry in DATASETS:
        harmonize_dataset(entry, unpacked_dir, processed_dir, GEO_ACCESSION)

    print("\nHarmonization complete!")

# Set your data directory here (e.g., a local path where you want to store the data)
data_dir = "./data_directory"  # Change this to your desired directory path

# Run the pipeline only when this file is executed directly (a notebook cell
# or ``python <file>``), so that importing it does not start a download.
if __name__ == "__main__":
    main(data_dir)
