In [None]:
import os
import sys
import gzip
import requests
from pathlib import Path
from typing import Dict, List, Optional, Union
import numpy as np
import pandas as pd
import scanpy as sc
import anndata as ad
from tqdm import tqdm

# Constants
# GEO series accession of the dataset handled by this notebook.
GEO_ACCESSION = "GSE190604"
# Directory on the NCBI GEO FTP server holding the series' supplementary files.
# NOTE: the "GSE190nnn" path component is hard-coded; it is not derived from
# GEO_ACCESSION, so changing the accession alone does not update it.
BASE_URL = f"https://ftp.ncbi.nlm.nih.gov/geo/series/GSE190nnn/{GEO_ACCESSION}/suppl"
# Supplementary files fetched from BASE_URL: the barcodes / features / matrix
# files of the count matrix plus the aggregated guide-call table.
FILES_TO_DOWNLOAD = [
    f"{GEO_ACCESSION}_barcodes.tsv.gz",
    f"{GEO_ACCESSION}_features.tsv.gz",
    f"{GEO_ACCESSION}_matrix.mtx.gz",
    f"{GEO_ACCESSION}_cellranger-guidecalls-aggregated-unfiltered.txt.gz",
]
# Series matrix file; it is downloaded separately and parsed for sample metadata.
SERIES_MATRIX_URL = f"https://ftp.ncbi.nlm.nih.gov/geo/series/GSE190nnn/{GEO_ACCESSION}/matrix/{GEO_ACCESSION}_series_matrix.txt.gz"


def download_file(
    url: str,
    output_path: Path,
    force: bool = False,
    timeout: float = 60.0,
) -> None:
    """
    Download a file from a URL to the specified output path.

    The payload is streamed into a temporary ``<name>.part`` file next to
    ``output_path`` and renamed into place only after the transfer finished.
    An interrupted download therefore never leaves a truncated file that a
    later call would report as "already exists" and skip.

    Args:
        url: Address of the file to fetch.
        output_path: Destination path. Its parent directory must already exist.
        force: Download again even when ``output_path`` already exists.
        timeout: Value passed to ``requests.get`` as ``timeout`` (seconds).

    Raises:
        requests.HTTPError: If the server responds with an error status
            (raised by ``response.raise_for_status()``).
    """
    if output_path.exists() and not force:
        print(f"File already exists: {output_path}")
        return

    print(f"Downloading {url} to {output_path}")
    chunk_size = 1024 * 1024  # read in 1 MiB pieces
    partial_path = output_path.with_name(output_path.name + ".part")

    try:
        # The context manager releases the connection even if writing fails.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()

            # The header may be missing; total=None gives an open-ended counter.
            total_size = int(response.headers.get('content-length', 0)) or None

            with open(partial_path, 'wb') as f, tqdm(
                total=total_size,
                unit='B',
                unit_scale=True,
                unit_divisor=1024,
            ) as progress:
                for chunk in response.iter_content(chunk_size):
                    f.write(chunk)
                    # Chunks can be shorter than chunk_size, so count real bytes.
                    progress.update(len(chunk))
    except BaseException:
        # Also covers KeyboardInterrupt: never keep a half-written file around.
        if partial_path.exists():
            partial_path.unlink()
        raise

    partial_path.replace(output_path)


def download_dataset(data_dir: Path) -> None:
    """
    Fetch every file the pipeline needs into ``data_dir``.

    The directory (including missing parents) is created first. The
    supplementary files listed in ``FILES_TO_DOWNLOAD`` are then downloaded
    from ``BASE_URL``, followed by the series matrix file from
    ``SERIES_MATRIX_URL``. Each transfer is delegated to ``download_file``.
    """
    data_dir.mkdir(parents=True, exist_ok=True)

    # (source URL, local destination) pairs: supplementary files first ...
    targets = [
        (f"{BASE_URL}/{name}", data_dir / name)
        for name in FILES_TO_DOWNLOAD
    ]
    # ... and the series matrix (used for metadata) last.
    targets.append(
        (SERIES_MATRIX_URL, data_dir / f"{GEO_ACCESSION}_series_matrix.txt.gz")
    )

    for source_url, destination in targets:
        download_file(source_url, destination)


def extract_sample_metadata(series_matrix_path: Path) -> pd.DataFrame:
    """
    Extract sample metadata from the series matrix file.

    Three kinds of lines are read from the gzip-compressed, tab-separated file:

    * ``!Sample_geo_accession`` - sample ids, used as the row index,
    * ``!Sample_title`` - stored in the ``sample_title`` column,
    * ``!Sample_characteristics_ch1`` - ``key: value`` entries, each stored in
      a column named after ``key``.

    The lines are collected first and combined afterwards, so the result does
    not depend on the order in which they appear in the file. An entry is split
    at its first ``": "`` only, which keeps values that contain ``": "``.
    Entries without ``": "`` are ignored.

    Args:
        series_matrix_path: Path to the ``*_series_matrix.txt.gz`` file.

    Returns:
        DataFrame with one row per sample id. If no accession line is present
        but titles are, a frame with a single ``sample_title`` column and a
        default index is returned.
    """
    def split_fields(line):
        # Drop the leading "!Sample_..." label and the quotes around each field.
        return [field.strip('"') for field in line.strip().split('\t')[1:]]

    sample_ids = []
    sample_titles = []
    characteristic_rows = []

    with gzip.open(series_matrix_path, 'rt') as f:
        for line in f:
            if line.startswith('!Sample_geo_accession'):
                sample_ids = split_fields(line)
            elif line.startswith('!Sample_title'):
                sample_titles = split_fields(line)
            elif line.startswith('!Sample_characteristics_ch1'):
                characteristic_rows.append(split_fields(line))

    sample_info = {sample_id: {} for sample_id in sample_ids}

    # zip() stops at the shorter list, so surplus ids or titles are skipped.
    for sample_id, title in zip(sample_ids, sample_titles):
        sample_info[sample_id]['sample_title'] = title

    for row in characteristic_rows:
        for sample_id, entry in zip(sample_ids, row):
            key, separator, value = entry.partition(': ')
            if separator:
                sample_info[sample_id][key] = value

    df = pd.DataFrame.from_dict(sample_info, orient='index')

    if df.empty and sample_titles:
        # Titles without usable accessions: keep them instead of dropping them.
        if len(sample_ids) == len(sample_titles):
            df = pd.DataFrame({'sample_title': sample_titles}, index=sample_ids)
        else:
            df = pd.DataFrame({'sample_title': sample_titles})

    return df


def parse_guide_calls(guide_calls_path: Path) -> pd.DataFrame:
    """
    Parse the guide calls file to extract perturbation information.

    The gzip-compressed, tab-separated table must provide the columns
    ``cell_barcode`` and ``feature_call``. Two columns are added:

    * ``perturbation_name`` - ``"Non-targeting"`` when the call contains
      ``NO-TARGET`` or ``Non-targeting`` (case-insensitive); otherwise the
      text before the first ``-`` of every ``|``-separated guide, joined
      with ``|`` again.
    * ``is_non_targeting`` - boolean flag for the non-targeting case.

    Rows with a missing ``feature_call`` keep a missing ``perturbation_name``
    and get ``is_non_targeting = False`` instead of raising.

    Args:
        guide_calls_path: Path to the guide-call table.

    Returns:
        The table, indexed by ``cell_barcode``, including the two new columns.
    """
    guide_calls = pd.read_csv(guide_calls_path, sep='\t', compression='gzip')

    def process_feature_call(feature_call):
        # Missing calls arrive as NaN (a float): pass them through unchanged.
        if not isinstance(feature_call, str):
            return feature_call
        if "NO-TARGET" in feature_call.upper():
            return "Non-targeting"
        # Expected format: GENE-guide1|GENE-guide2.
        # NOTE(review): splitting at the first '-' truncates gene symbols that
        # contain a hyphen themselves - confirm none occur in this library.
        return '|'.join(g.split('-')[0] for g in feature_call.split('|'))

    guide_calls['perturbation_name'] = guide_calls['feature_call'].apply(process_feature_call)

    # na=False turns missing calls into False so the mask is a plain boolean
    # column that is safe to use with .loc; the object cast keeps the .str
    # accessor usable even when every call is missing.
    guide_calls['is_non_targeting'] = guide_calls['feature_call'].astype(object).str.contains(
        'Non-targeting|NO-TARGET', case=False, regex=True, na=False
    )
    guide_calls.loc[guide_calls['is_non_targeting'], 'perturbation_name'] = 'Non-targeting'

    # Set cell barcode as index for easier merging
    guide_calls.set_index('cell_barcode', inplace=True)

    return guide_calls


def load_10x_data(data_dir: Path) -> ad.AnnData:
    """
    Load 10x data from the specified directory.

    Reads ``<GEO_ACCESSION>_matrix.mtx.gz``, ``<GEO_ACCESSION>_features.tsv.gz``
    and ``<GEO_ACCESSION>_barcodes.tsv.gz`` from ``data_dir`` and combines them
    into one AnnData object with cells as observations.

    Args:
        data_dir: Directory that contains the three files.

    Returns:
        AnnData whose ``obs_names`` are the barcodes and whose ``var_names``
        are the values of the second features column, made unique. The first
        features column is kept in ``var['gene_ids']``.

    Raises:
        ValueError: If the matrix dimensions fit neither orientation of
            (number of features, number of barcodes).
    """
    matrix_file = data_dir / f"{GEO_ACCESSION}_matrix.mtx.gz"
    features_file = data_dir / f"{GEO_ACCESSION}_features.tsv.gz"
    barcodes_file = data_dir / f"{GEO_ACCESSION}_barcodes.tsv.gz"

    # Load features (genes): column 0 is used as id, column 1 as symbol
    features = pd.read_csv(features_file, header=None, sep='\t', compression='gzip')
    gene_ids = features[0].values
    gene_symbols = features[1].values

    # Load barcodes
    barcodes = pd.read_csv(barcodes_file, header=None, compression='gzip')
    cell_barcodes = barcodes[0].values

    # Load the count matrix using scanpy's read_mtx function
    adata = sc.read_mtx(str(matrix_file))

    n_genes, n_cells = len(gene_ids), len(cell_barcodes)

    # Transpose a genes x cells matrix to get cells x genes
    if adata.shape == (n_genes, n_cells):
        adata = adata.T

    # Fail here with the actual numbers rather than later, when the names are
    # assigned, with an error that no longer mentions the input files.
    if adata.shape != (n_cells, n_genes):
        raise ValueError(
            f"Matrix shape {adata.shape} does not match {n_cells} barcodes "
            f"and {n_genes} features"
        )

    # Set observation and variable names
    adata.obs_names = pd.Index(cell_barcodes)
    adata.var_names = pd.Index(gene_symbols)
    adata.var['gene_ids'] = gene_ids
    adata.var_names_make_unique()

    return adata


def determine_stimulation_condition(adata: ad.AnnData, sample_metadata: pd.DataFrame) -> None:
    """
    Determine stimulation condition for each cell based on sample metadata and cell barcodes.

    The result is written to ``adata.obs['condition']`` in place:

    1. If ``sample_metadata`` has a ``sample_title`` column and any title
       contains ``stim`` (case-insensitive; this also matches ``nostim``),
       barcodes ending in ``-1`` become ``'Stimulated'`` and all others
       ``'Non-stimulated'``.
       NOTE(review): the code assumes the ``-1`` suffix marks stimulated
       cells - confirm against the sample layout of the dataset.
    2. Otherwise the barcodes themselves are searched: those containing
       ``nostim`` become ``'Non-stimulated'``, those containing ``stim`` but
       not ``nostim`` become ``'Stimulated'``. Barcodes matching neither are
       not assigned a value in this case.
    3. If no barcode contains ``stim``, every cell gets ``'unknown'``.

    Missing (NaN) sample titles are skipped instead of raising.

    Args:
        adata: Object whose ``obs`` table is indexed by cell barcode.
        sample_metadata: Per-sample metadata table.
    """
    has_stim_info = False
    if 'sample_title' in sample_metadata.columns:
        # 'nostim' contains 'stim', so one substring test covers both labels.
        has_stim_info = any(
            'stim' in str(title).lower()
            for title in sample_metadata['sample_title']
            if pd.notna(title)
        )

        if has_stim_info:
            stim_mask = adata.obs.index.str.contains('-1$')
            adata.obs.loc[stim_mask, 'condition'] = 'Stimulated'
            adata.obs.loc[~stim_mask, 'condition'] = 'Non-stimulated'

    if not has_stim_info:
        barcodes = adata.obs.index
        stim_mask = barcodes.str.contains('stim', case=False)
        nostim_mask = barcodes.str.contains('nostim', case=False)

        # Every 'nostim' barcode is also a 'stim' match, so one check suffices.
        if stim_mask.any():
            adata.obs.loc[stim_mask & ~nostim_mask, 'condition'] = 'Stimulated'
            adata.obs.loc[nostim_mask, 'condition'] = 'Non-stimulated'
        else:
            adata.obs['condition'] = 'unknown'


def harmonize_dataset(data_dir: Path) -> ad.AnnData:
    """
    Harmonize the dataset into a standardized format.

    Steps, in order:

    1. Load the count matrix with ``load_10x_data``.
    2. Parse the guide calls and left-join them onto ``adata.obs`` by barcode.
    3. Parse the series matrix and call ``determine_stimulation_condition``.
    4. Add the constant fields ``organism``, ``cell_type``, ``crispr_type``
       and ``cancer_type``.
    5. Fill missing ``perturbation_name`` values with ``'Non-targeting'``.
    6. Overwrite ``condition`` with ``'Control'`` for non-targeting cells and
       ``'Test'`` for all others.
    7. Convert the required fields to categoricals and record dataset-level
       information in ``adata.uns['dataset']``.

    Args:
        data_dir: Directory containing the downloaded files.

    Returns:
        The annotated AnnData object.

    Raises:
        ValueError: If joining the guide calls changes the number of cells,
            i.e. a barcode of the matrix has more than one guide-call row.
    """
    print("Loading count matrix...")
    adata = load_10x_data(data_dir)
    print(f"Loaded data with {adata.n_obs} cells and {adata.n_vars} genes")

    print("Loading guide calls...")
    guide_calls_path = data_dir / f"{GEO_ACCESSION}_cellranger-guidecalls-aggregated-unfiltered.txt.gz"
    guide_calls = parse_guide_calls(guide_calls_path)

    print("Loading sample metadata...")
    series_matrix_path = data_dir / f"{GEO_ACCESSION}_series_matrix.txt.gz"
    sample_metadata = extract_sample_metadata(series_matrix_path)

    print("Adding perturbation information...")
    obs_with_guides = adata.obs.join(guide_calls, how='left')
    # A left join repeats a cell for every matching guide-call row. Report that
    # here instead of failing on the obs assignment below.
    if len(obs_with_guides) != adata.n_obs:
        raise ValueError(
            f"Joining guide calls produced {len(obs_with_guides)} rows for "
            f"{adata.n_obs} cells; cell barcodes in the guide-call table "
            "must be unique"
        )
    adata.obs = obs_with_guides

    print("Harmonizing metadata...")
    adata.obs['organism'] = 'Homo sapiens'
    adata.obs['cell_type'] = 'T Cells'
    adata.obs['crispr_type'] = 'CRISPRa'
    adata.obs['cancer_type'] = 'Non-Cancer'
    determine_stimulation_condition(adata, sample_metadata)

    if 'perturbation_name' not in adata.obs.columns:
        adata.obs['perturbation_name'] = 'unknown'
    else:
        # NOTE(review): cells without any guide call are labelled as
        # non-targeting controls here - confirm this is intended.
        adata.obs['perturbation_name'] = adata.obs['perturbation_name'].fillna('Non-targeting')

    # The join above already copied every guide-call column into obs, so no
    # per-column back-fill is needed.

    required_fields = ['organism', 'cell_type', 'crispr_type', 'cancer_type', 'condition', 'perturbation_name']
    for field in required_fields:
        if field not in adata.obs.columns:
            adata.obs[field] = 'unknown'

    # Override condition: if perturbation_name is "Non-targeting" label as "Control", else "Test"
    # NOTE(review): this replaces the stimulation label computed above.
    adata.obs['condition'] = np.where(
        adata.obs['perturbation_name'] == "Non-targeting", "Control", "Test"
    )

    # Convert required metadata columns to categorical type
    for col in required_fields:
        adata.obs[col] = adata.obs[col].astype('category')

    adata.uns['dataset'] = {
        'geo_accession': GEO_ACCESSION,
        'title': 'CRISPR activation and interference screens decode stimulation responses in primary human T cells',
        'organism': 'Homo sapiens',
        'publication': ('Schmidt R, Steinhart Z, Layeghi M, Freimer JW et al. '
                        'CRISPR activation and interference screens decode stimulation responses in primary human T cells. '
                        'Science 2022 Feb 4;375(6580):eabj4008. PMID: 35113687')
    }

    return adata


def run_in_jupyter(data_dir: Path = Path(f"./{GEO_ACCESSION}")):
    """
    Run the download and harmonization pipeline in a Jupyter Notebook.

    Downloads the dataset into ``data_dir``, harmonizes it and writes the
    result to ``<data_dir>/<GEO_ACCESSION>_harmonized.h5ad``. Progress and a
    short summary are printed; nothing is returned.

    Args:
        data_dir: Working directory for downloads and output. Defaults to a
            directory named after the GEO accession in the current directory.
    """
    print(f"Using data directory: {data_dir}")
    download_dataset(data_dir)
    adata = harmonize_dataset(data_dir)

    # Before saving, convert any categorical or object-type columns in obs to strings.
    # The dtype is tested with isinstance because
    # pd.api.types.is_categorical_dtype is deprecated.
    for col in adata.obs.columns:
        column = adata.obs[col]
        if isinstance(column.dtype, pd.CategoricalDtype) or column.dtype == object:
            adata.obs[col] = column.astype(str)

    output_file = data_dir / f"{GEO_ACCESSION}_harmonized.h5ad"
    print(f"Saving harmonized dataset to {output_file}")
    adata.write(output_file)

    print("Done!")
    print(f"Harmonized dataset saved to {output_file}")
    print(f"Dataset shape: {adata.shape}")
    print(f"Metadata fields: {list(adata.obs.columns)}")

# Run the pipeline in Jupyter Notebook
# NOTE: this call runs as soon as the cell (or module) is executed. It uses the
# default data directory, downloads any files not yet present there and writes
# the harmonized .h5ad file into it.
run_in_jupyter()
