# In [None]:
"""
harmonize_GSE261025_debug.py

Fully instrumented code to debug the "all cells are Non-Targeting" issue.
Includes debug prints to confirm:
  - Which _crispr_analysis folders are found
  - Which guides and genes are read
  - The final dictionary contents (cell->gene)
  - The actual assignment in process_sample()
"""

import os
import glob
import gzip
import tarfile
import urllib.request
import numpy as np
import pandas as pd
import scanpy as sc
import anndata as ad
from scipy import sparse
from scipy.io import mmread
from pathlib import Path

# --------------------------------------------------------------------------
# Define sample metadata
# --------------------------------------------------------------------------
# Per-sample annotations keyed by GEO sample accession (GSM id).
# process_sample() copies every key/value of the matching entry into
# adata.obs as a constant column, then sets obs['cell_type'] from
# 'cell_type' and obs['condition'] from 'treatment'.
# Only the Perturb-seq entries carry 'replicate', so that column is absent
# for the GSM81327xx samples before the outer concatenation in
# harmonize_data() (presumably filled with NaN there).
# NOTE(review): several titles spell "DSMO", presumably a typo for DMSO that
# may mirror the GEO record verbatim. Confirm before changing, because the
# title is copied into obs.
SAMPLE_METADATA = {
    # Samples read from the GSE261025 directory (data_path) by process_sample.
    'GSM8132774': {
        'title': 'Proliferative WT HepaRG cells @ Day 5',
        'cell_line': 'HepaRG',
        'cell_type': 'Hepatocyte',
        'genotype': 'WT',
        'treatment': 'Control',
        'day': 5
    },
    'GSM8132775': {
        'title': 'Differentiated @ D28 WT HepaRG cells according to standard protocols using DSMO',
        'cell_line': 'HepaRG',
        'cell_type': 'Hepatocyte',
        'genotype': 'WT',
        'treatment': 'Differentiated',
        'day': 28
    },
    'GSM8132776': {
        'title': 'Proliferative dCas9 HepaRG cells @ Day 5',
        'cell_line': 'HepaRG',
        'cell_type': 'Hepatocyte',
        'genotype': 'dCas9-KRAB',
        'treatment': 'Control',
        'day': 5
    },
    'GSM8132777': {
        'title': 'Differentiated @ D28 dCas9 HepaRG cells according to standard protocols using DSMO',
        'cell_line': 'HepaRG',
        'cell_type': 'Hepatocyte',
        'genotype': 'dCas9-KRAB',
        'treatment': 'Differentiated',
        'day': 28
    },
    'GSM8132778': {
        'title': 'Differentiated @ Day42 WT HepaRG cells according to standard protocols using DSMO',
        'cell_line': 'HepaRG',
        'cell_type': 'Hepatocyte',
        'genotype': 'WT',
        'treatment': 'Differentiated',
        'day': 42
    },
    # Perturbation samples
    # (IDs start with 'GSM7660'; process_sample uses that prefix to route them
    # to the Perturb-seq loader, which reads from the GSE238219 directory.)
    'GSM7660623': {
        'title': 'Perturb-Seq replicate 1 [DIFF_HRG_PS]',
        'cell_line': 'HepaRG',
        'cell_type': 'Hepatocyte',
        'genotype': 'dCas9-KRAB',
        'treatment': 'Differentiated',
        'day': 42,
        'replicate': 1
    },
    'GSM7660624': {
        'title': 'Perturb-Seq replicate 2 [P_Seq2]',
        'cell_line': 'HepaRG',
        'cell_type': 'Hepatocyte',
        'genotype': 'dCas9-KRAB',
        'treatment': 'Differentiated',
        'day': 42,
        'replicate': 2
    },
    'GSM7660625': {
        'title': 'Perturb-Seq replicate 3 [PSG_1]',
        'cell_line': 'HepaRG',
        'cell_type': 'Hepatocyte',
        'genotype': 'dCas9-KRAB',
        'treatment': 'Differentiated',
        'day': 42,
        'replicate': 3
    },
    'GSM7660626': {
        'title': 'Perturb-Seq replicate 4 [PSG_2]',
        'cell_line': 'HepaRG',
        'cell_type': 'Hepatocyte',
        'genotype': 'dCas9-KRAB',
        'treatment': 'Differentiated',
        'day': 42,
        'replicate': 4
    },
    'GSM7660627': {
        'title': 'Perturb-Seq replicate 5 [PSG_3]',
        'cell_line': 'HepaRG',
        'cell_type': 'Hepatocyte',
        'genotype': 'dCas9-KRAB',
        'treatment': 'Differentiated',
        'day': 42,
        'replicate': 5
    }
}

# GEO download endpoints, one per series accession. download_and_extract_data
# saves each response as <accession>/<accession>_RAW.tar under the root path.
DATA_URLS = dict(
    GSE261025='https://www.ncbi.nlm.nih.gov/geo/download/?acc=GSE261025&format=file',
    GSE238219='https://www.ncbi.nlm.nih.gov/geo/download/?acc=GSE238219&format=file',
)


def _download_and_extract_accession(root_path, accession, download_desc, extract_desc):
    """Fetch ``<accession>_RAW.tar`` into ``root_path/<accession>`` and unpack it.

    Args:
        root_path: Parent directory; the per-accession directory is created in it.
        accession: GEO series accession. It must be a key of ``DATA_URLS``
            when a download is actually needed.
        download_desc: Name shown in the "Downloading ..." progress message.
        extract_desc: Name shown in the "Extracting ..." progress message.

    Returns:
        Path of the per-accession directory.
    """
    acc_path = os.path.join(root_path, accession)
    os.makedirs(acc_path, exist_ok=True)

    tar_path = os.path.join(acc_path, f"{accession}_RAW.tar")
    if not os.path.exists(tar_path):
        print(f"Downloading {download_desc} to {tar_path}...")
        # Download to a side file and rename only after success, so an
        # interrupted transfer is never mistaken for a complete archive.
        partial_path = tar_path + ".part"
        urllib.request.urlretrieve(DATA_URLS[accession], partial_path)
        os.replace(partial_path, tar_path)

    # Extraction is skipped once any GSM* entry exists in the directory.
    if not glob.glob(os.path.join(acc_path, "GSM*")):
        print(f"Extracting {extract_desc} from {tar_path}...")
        with tarfile.open(tar_path) as tar:
            # The archive comes from the network. Use the 'data' extraction
            # filter (rejects absolute paths, '..' traversal, device files and
            # links escaping the destination) when this Python provides it.
            if hasattr(tarfile, "data_filter"):
                tar.extractall(path=acc_path, filter="data")
            else:
                tar.extractall(path=acc_path)
    return acc_path


def download_and_extract_data(root_path):
    """Ensure both GEO series are downloaded and unpacked under ``root_path``.

    Downloads are skipped when ``<acc>/<acc>_RAW.tar`` already exists, and
    extraction is skipped when the directory already contains GSM* entries.

    Returns:
        ``(data_path, perturb_path)``: the GSE261025 and GSE238219 directories.
    """
    data_path = _download_and_extract_accession(
        root_path, "GSE261025", "GSE261025 dataset", "data"
    )
    perturb_path = _download_and_extract_accession(
        root_path, "GSE238219", "GSE238219 perturbation dataset", "perturbation data"
    )
    return data_path, perturb_path


def _read_barcodes(barcodes_path):
    """Return the stripped, non-blank lines of a gzipped barcodes file."""
    with gzip.open(barcodes_path, 'rt') as f:
        return [bc for bc in (line.strip() for line in f) if bc]


def _read_features(features_path):
    """Parse a gzipped 10x features file into three parallel lists.

    Each non-blank line is split on tabs. A line containing no tab is split on
    whitespace instead. Columns are ``id, symbol[, feature_type, ...]``, and
    lines with only two columns get the feature type ``'Gene Expression'``.

    Returns:
        ``(gene_ids, gene_symbols, feature_types)``.

    Raises:
        ValueError: if a non-blank line has fewer than two columns.
    """
    gene_ids, gene_symbols, feature_types = [], [], []
    with gzip.open(features_path, 'rt') as f:
        for lineno, line in enumerate(f, start=1):
            stripped = line.strip()
            if not stripped:
                continue
            parts = stripped.split('\t')
            if len(parts) == 1:
                # No tabs at all: treat the file as whitespace-delimited.
                parts = stripped.split()
            if len(parts) < 2:
                raise ValueError(
                    f"{features_path}:{lineno}: expected at least 2 columns, "
                    f"got {len(parts)}"
                )
            gene_ids.append(parts[0])
            gene_symbols.append(parts[1])
            feature_types.append(parts[2] if len(parts) >= 3 else 'Gene Expression')
    return gene_ids, gene_symbols, feature_types


def read_10x_mtx(path_prefix):
    """Load a prefixed 10x triplet into an AnnData of cells x features.

    Reads ``{path_prefix}_barcodes.tsv.gz``, ``{path_prefix}_features.tsv.gz``
    and ``{path_prefix}_matrix.mtx.gz``. The Matrix Market matrix is transposed
    so rows are barcodes and stored as CSR. ``var_names`` are the (uniquified)
    feature symbols. The feature IDs and types are kept in
    ``var['gene_ids']`` and ``var['feature_types']``.

    Raises:
        ValueError: if the matrix shape disagrees with the barcode/feature counts,
            or if a feature line is malformed.
    """
    barcodes = _read_barcodes(f"{path_prefix}_barcodes.tsv.gz")
    gene_ids, gene_symbols, feature_types = _read_features(
        f"{path_prefix}_features.tsv.gz"
    )

    mat = mmread(f"{path_prefix}_matrix.mtx.gz").T.tocsr()
    if mat.shape != (len(barcodes), len(gene_ids)):
        raise ValueError(
            f"{path_prefix}: matrix shape {mat.shape} does not match "
            f"{len(barcodes)} barcodes x {len(gene_ids)} features"
        )

    adata = ad.AnnData(X=mat)
    adata.obs_names = barcodes
    adata.var_names = gene_symbols
    adata.var["gene_ids"] = gene_ids
    adata.var["feature_types"] = feature_types
    adata.var_names_make_unique()
    return adata


def _crispr_sample_key(cdir, perturb_path):
    """Return the per-sample key for one CRISPR analysis directory.

    The key is the last ``GSM<digits>`` token in ``cdir``'s path relative to
    ``perturb_path``. If there is none, the relative path itself is used, so
    distinct directories never share a key.
    """
    import re

    rel_path = os.path.relpath(cdir, perturb_path)
    accessions = re.findall(r'GSM\d+', rel_path)
    return accessions[-1] if accessions else rel_path


def extract_perturbation_data(perturb_path):
    """
    Gather guide->gene mappings and per-cell guide calls from every
    ``*_crispr_analysis`` directory under ``perturb_path``.

    Each directory's feature_reference.csv (columns ``id``,
    ``target_gene_name``) is merged first. Then each directory's
    protospacer_calls_per_cell.csv (columns ``cell_barcode``, ``feature_call``
    with '|'-separated guide IDs) is parsed. Debug prints report what was found.

    Returns a dict with:
        'target_genes':  {gene: [guide_id, ...]} for every target other than
                         'Non-Targeting'.
        'guide_to_gene': {guide_id: target_gene_name}; the first directory that
                         lists a guide wins. Missing target names become 'Unknown'.
        'non_targeting': [guide_id, ...] whose target is 'Non-Targeting'.
        'cell_to_guide', 'cell_to_target':
                         {barcode: [...]} merged over ALL directories, with
                         barcodes stripped of their '-N' suffix. The same
                         barcode can occur in several replicates, so these
                         merged maps can mix calls from different samples.
                         They are kept for backward compatibility.
        'cell_to_guide_by_sample', 'cell_to_target_by_sample':
                         {sample_key: {barcode: [...]}}, the same data kept
                         separate per directory. sample_key is the GSM
                         accession found in the directory path (see
                         _crispr_sample_key).
    Guides called in a cell but absent from every feature reference map to
    'Unknown'.
    """
    perturbation_info = {
        'target_genes': {},
        'guide_to_gene': {},
        'non_targeting': [],
        'cell_to_guide': {},
        'cell_to_target': {},
        'cell_to_guide_by_sample': {},
        'cell_to_target_by_sample': {},
    }

    crispr_dirs = []
    for root, dirs, _files in os.walk(perturb_path):
        # Sorting in place makes traversal (and thus first-wins merging) deterministic.
        dirs.sort()
        crispr_dirs.extend(
            os.path.join(root, d) for d in dirs if d.endswith('_crispr_analysis')
        )

    print("\n[DEBUG] Found CRISPR directories:", crispr_dirs)

    # 1) Parse all feature_reference.csv
    guide_to_gene = perturbation_info['guide_to_gene']
    non_targeting_seen = set()  # O(1) de-duplication for the 'non_targeting' list
    for cdir in crispr_dirs:
        feature_ref_path = os.path.join(cdir, 'feature_reference.csv')
        if not os.path.exists(feature_ref_path):
            print(f"[DEBUG] No feature_reference.csv in {cdir}")
            continue
        print(f"[DEBUG] Reading feature_reference.csv from {cdir}")
        df = pd.read_csv(feature_ref_path)
        print(f"[DEBUG]   # rows in feature_reference: {len(df)}")
        for guide_id, target_gene in zip(df['id'], df['target_gene_name']):
            if pd.isna(guide_id):
                continue
            if pd.isna(target_gene):
                # A NaN name would later break ' + '.join in process_sample.
                target_gene = 'Unknown'

            guide_to_gene.setdefault(guide_id, target_gene)

            if target_gene == 'Non-Targeting':
                if guide_id not in non_targeting_seen:
                    non_targeting_seen.add(guide_id)
                    perturbation_info['non_targeting'].append(guide_id)
            else:
                guides = perturbation_info['target_genes'].setdefault(target_gene, [])
                if guide_id not in guides:
                    guides.append(guide_id)

    # Print summary of all guide->gene we found
    print(f"[DEBUG] Done reading all feature_reference.csv files.")
    print(f"[DEBUG] # of distinct guide_ids in dictionary: {len(guide_to_gene)}")
    print(f"[DEBUG] Unique gene names: {set(guide_to_gene.values())}")

    # 2) Parse protospacer calls from each replicate
    for cdir in crispr_dirs:
        protospacer_path = os.path.join(cdir, 'protospacer_calls_per_cell.csv')
        if not os.path.exists(protospacer_path):
            print(f"[DEBUG] No protospacer_calls_per_cell.csv in {cdir}")
            continue
        sample_key = _crispr_sample_key(cdir, perturb_path)
        print(f"[DEBUG] Reading protospacer_calls_per_cell.csv from {cdir} "
              f"(sample key: {sample_key})")
        df = pd.read_csv(protospacer_path)
        print(f"[DEBUG]   # rows in protospacer_calls: {len(df)}")
        print("[DEBUG]   sample lines:")
        print(df.head(5))

        # Every call is recorded both in the merged maps and in this sample's maps.
        destinations = (
            (perturbation_info['cell_to_guide'], perturbation_info['cell_to_target']),
            (perturbation_info['cell_to_guide_by_sample'].setdefault(sample_key, {}),
             perturbation_info['cell_to_target_by_sample'].setdefault(sample_key, {})),
        )
        for raw_bc, guide_field in zip(df['cell_barcode'], df['feature_call']):
            # Rows lacking a barcode or a call carry no assignment; skipping
            # them avoids calling .split on a NaN float.
            if pd.isna(raw_bc) or pd.isna(guide_field):
                continue
            barcode = str(raw_bc)
            if '-' in barcode:
                barcode = barcode.rsplit('-', 1)[0]

            for guide_id in str(guide_field).split('|'):
                target = guide_to_gene.get(guide_id, 'Unknown')
                for guide_map, target_map in destinations:
                    guide_map.setdefault(barcode, []).append(guide_id)
                    cell_targets = target_map.setdefault(barcode, [])
                    if target not in cell_targets:
                        cell_targets.append(target)

    # Show sample of cell_to_target
    print(f"[DEBUG] Done building cell_to_target. # cells: {len(perturbation_info['cell_to_target'])}")
    print("[DEBUG] sample of cell_to_target items:")
    for k, v in list(perturbation_info['cell_to_target'].items())[:20]:
        print("  ", k, v)

    return perturbation_info


def _extract_tar_atomically(tar_file, extract_dir):
    """Unpack ``tar_file`` into ``extract_dir`` via a temporary sibling directory.

    The archive is first extracted into ``extract_dir + '.partial'``, which is
    removed beforehand if left over from an interrupted run. That directory is
    renamed to ``extract_dir`` only after extraction succeeds, so a crash never
    leaves a half-populated ``extract_dir`` that later runs would trust.
    """
    import shutil

    partial_dir = extract_dir + ".partial"
    if os.path.exists(partial_dir):
        shutil.rmtree(partial_dir)
    os.makedirs(partial_dir)
    with tarfile.open(tar_file) as t:
        # Use the 'data' filter (blocks path traversal and other unsafe members)
        # when this Python version provides it.
        if hasattr(tarfile, "data_filter"):
            t.extractall(partial_dir, filter="data")
        else:
            t.extractall(partial_dir)
    os.replace(partial_dir, extract_dir)


def _find_matrix_dir(extract_dir):
    """Return the first directory (sorted walk) under ``extract_dir`` holding
    ``matrix.mtx.gz`` or ``matrix.mtx``.

    If none is found, use ``extract_dir``'s only subdirectory when there is
    exactly one, else ``extract_dir`` itself.
    """
    for dirpath, dirnames, filenames in os.walk(extract_dir):
        dirnames.sort()
        if 'matrix.mtx.gz' in filenames or 'matrix.mtx' in filenames:
            return dirpath
    subdirs = [d for d in os.listdir(extract_dir)
               if os.path.isdir(os.path.join(extract_dir, d))]
    return os.path.join(extract_dir, subdirs[0]) if len(subdirs) == 1 else extract_dir


def extract_perturb_seq_data(perturb_path, sample_id):
    """Load the filtered 10x matrix of one Perturb-seq sample.

    Looks for ``{sample_id}_*_filtered_feature_bc_matrix.tar.gz`` in
    ``perturb_path`` and picks the first match in sorted order. On first use it
    unpacks the archive into ``{sample_id}_filtered_feature_bc_matrix``, then
    reads it with ``sc.read_10x_mtx`` using gene symbols as var names.
    Barcodes lose their trailing '-N' suffix.

    Returns:
        The AnnData, or None if no archive exists or reading fails. Read errors
        are printed rather than raised, so the caller can skip the sample.
    """
    tar_candidates = sorted(glob.glob(
        os.path.join(perturb_path, f"{sample_id}_*_filtered_feature_bc_matrix.tar.gz")
    ))
    if not tar_candidates:
        print(f"[DEBUG] No filtered_feature_bc_matrix tar found for {sample_id}")
        return None

    tar_file = tar_candidates[0]
    extract_dir = os.path.join(perturb_path, f"{sample_id}_filtered_feature_bc_matrix")
    if not os.path.exists(extract_dir):
        _extract_tar_atomically(tar_file, extract_dir)

    read_dir = _find_matrix_dir(extract_dir)

    try:
        adata = sc.read_10x_mtx(read_dir, var_names='gene_symbols', cache=False)
    except Exception as e:
        # Deliberate best-effort: report and let the caller skip this sample.
        print(f"[DEBUG] Error reading 10x data for {sample_id}: {e}")
        return None

    adata.obs_names = [bc.rsplit('-', 1)[0] for bc in adata.obs_names]
    adata.var_names_make_unique()
    return adata


def _perturbation_label(gene_list):
    """Collapse one cell's list of target genes into a single label.

    'Non-Targeting' entries and 'Unknown' entries (the placeholder
    extract_perturbation_data uses for unmapped guides) are dropped. The
    remaining genes are joined with ' + ' in their listed order. A cell with
    no remaining genes is labelled 'Non-Targeting'.
    """
    real_targets = [g for g in gene_list if g not in ('Non-Targeting', 'Unknown')]
    return ' + '.join(real_targets) if real_targets else 'Non-Targeting'


def process_sample(sample_id, data_path, perturb_path=None, perturb_info=None,
                   max_debug_cells=20):
    """Load one sample and annotate its cells for harmonization.

    Sample IDs starting with 'GSM7660' are Perturb-seq samples. When
    ``perturb_path`` is given they are loaded with extract_perturb_seq_data.
    Every other sample is read from the first (sorted)
    ``{sample_id}_*_barcodes.tsv.gz`` triplet in ``data_path``.

    Cells are renamed ``{sample_id}_{barcode}``, using the barcode without its
    '-N' suffix (also kept in obs['original_barcode']). Every SAMPLE_METADATA
    field is copied into obs. The harmonized columns organism, cell_type,
    cancer_type, condition, crispr_type and perturbation_name are then filled.

    For Perturb-seq samples with ``perturb_info``, guide calls come from this
    sample's own entry in perturb_info['cell_to_target_by_sample'] when present.
    Otherwise the barcode map merged over all CRISPR directories is used, which
    can assign another replicate's calls to a cell sharing the same barcode.
    Cells without any guide call are labelled 'Non-Targeting'.

    Args:
        sample_id: GEO sample accession.
        data_path: Directory with the non-Perturb-seq 10x files.
        perturb_path: Directory with Perturb-seq archives, or None.
        perturb_info: Result of extract_perturbation_data, or None.
        max_debug_cells: Print the per-cell debug line for at most this many
            matched cells; None prints it for every matched cell.

    Returns:
        The annotated AnnData, or None if the sample's input files are missing.
    """
    is_perturb = sample_id.startswith('GSM7660')

    if is_perturb and perturb_path:
        adata = extract_perturb_seq_data(perturb_path, sample_id)
        if adata is None:
            print(f"[DEBUG] Skipping sample {sample_id}, no perturbation data found.")
            return None
    else:
        bc_candidates = sorted(glob.glob(
            os.path.join(data_path, f"{sample_id}_*_barcodes.tsv.gz")
        ))
        if not bc_candidates:
            print(f"[DEBUG] No 10x barcodes found for sample {sample_id}")
            return None
        prefix = bc_candidates[0].replace("_barcodes.tsv.gz", "")
        adata = read_10x_mtx(prefix)
        adata.obs_names = [bc.rsplit('-', 1)[0] for bc in adata.obs_names]

    old_barcodes = list(adata.obs_names)
    adata.obs['original_barcode'] = old_barcodes
    adata.obs_names = [f"{sample_id}_{b}" for b in old_barcodes]

    meta = SAMPLE_METADATA.get(sample_id, {})
    for k, v in meta.items():
        adata.obs[k] = v
    adata.obs['sample_id'] = sample_id

    # columns: organism, cell_type, crispr_type, cancer_type, condition, perturbation_name
    adata.obs['organism'] = 'Homo sapiens'
    adata.obs['cell_type'] = meta.get('cell_type', 'Hepatocyte')
    adata.obs['cancer_type'] = 'Non-Cancer'
    adata.obs['condition'] = meta.get('treatment', 'Control')

    if is_perturb and perturb_info:
        adata.obs['crispr_type'] = 'CRISPRi'

        # Prefer this sample's own guide calls: barcodes repeat across
        # replicates, so the merged map can attach foreign calls to a cell.
        by_sample = perturb_info.get('cell_to_target_by_sample') or {}
        cell_to_target = by_sample.get(sample_id)
        if cell_to_target is None:
            print(f"[DEBUG]  sample {sample_id}: no per-sample guide calls; "
                  f"using barcodes merged across all CRISPR directories")
            cell_to_target = perturb_info['cell_to_target']

        labels = []
        n_matched = 0
        for orig_bc in old_barcodes:
            gene_list = cell_to_target.get(orig_bc)
            if gene_list is None:
                labels.append('Non-Targeting')
                continue
            if max_debug_cells is None or n_matched < max_debug_cells:
                print(f"[DEBUG]  sample {sample_id}, cell {orig_bc} => {gene_list}")
            n_matched += 1
            labels.append(_perturbation_label(gene_list))
        # One vectorised assignment instead of a per-cell .loc write.
        adata.obs['perturbation_name'] = labels
        print(f"[DEBUG]  sample {sample_id}: {n_matched}/{len(old_barcodes)} cells "
              f"matched a guide call; unmatched cells default to 'Non-Targeting'")
    else:
        adata.obs['crispr_type'] = 'None'
        adata.obs['perturbation_name'] = 'None'

    return adata


def harmonize_data(root_path):
    """Download, load, annotate and concatenate all samples, then save them.

    Pipeline: download_and_extract_data -> extract_perturbation_data ->
    process_sample for each accession below -> ``ad.concat`` (outer join).
    Samples for which process_sample returns None are skipped. Any harmonized
    column still missing after concatenation is filled with 'Unknown'. The
    result is written to ``<root_path>/GSE261025_harmonized_fixed_debug.h5ad``.

    Returns:
        The concatenated AnnData, or None if no sample could be processed.
    """
    data_path, perturb_path = download_and_extract_data(root_path)
    perturb_info = extract_perturbation_data(perturb_path)

    sample_ids = [
        'GSM8132774',
        'GSM8132775',
        'GSM8132776',
        'GSM8132777',
        'GSM8132778',
        'GSM7660623',
        'GSM7660624',
        'GSM7660625',
        'GSM7660626',
        'GSM7660627'
    ]

    all_adatas = []
    for sid in sample_ids:
        print(f"\n=== Processing sample {sid} ===")
        a = process_sample(sid, data_path, perturb_path, perturb_info)
        if a is not None:
            all_adatas.append(a)

    if not all_adatas:
        print("[DEBUG] No samples processed successfully. Exiting.")
        return None

    print("\n[DEBUG] Concatenating all samples into one AnnData (join='outer')...")
    # Do NOT pass label='sample_id'. Without `keys`, anndata labels each input
    # by its position ('0', '1', ...) and writes that label over the
    # obs['sample_id'] column that process_sample filled with the GSM accession.
    # index_unique='-' still appends the positional key to obs names as before.
    combined = ad.concat(all_adatas, join='outer', index_unique='-')

    must_have = ["organism", "cell_type", "crispr_type", "cancer_type", "condition", "perturbation_name"]
    for field in must_have:
        if field not in combined.obs.columns:
            combined.obs[field] = 'Unknown'

    out_file = os.path.join(root_path, "GSE261025_harmonized_fixed_debug.h5ad")
    print(f"[DEBUG] Saving final AnnData to: {out_file}")
    combined.write_h5ad(out_file)

    return combined


def run_harmonize(root_path=None):
    """Run harmonize_data rooted at ``root_path`` and return its result.

    Args:
        root_path: Directory for downloads and the output .h5ad. When None,
            the current working directory is used *at call time*. The previous
            default ``os.getcwd()`` was evaluated once, when the function was
            defined, and went stale after a later ``os.chdir`` (e.g. in a notebook).

    Returns:
        The combined AnnData, or None if no sample could be processed.
    """
    if root_path is None:
        root_path = os.getcwd()
    print(f"[DEBUG] Using root path: {root_path}")
    return harmonize_data(root_path)


if __name__ == "__main__":
    # Script entry point: harmonize everything under the current working directory.
    run_harmonize(root_path=os.getcwd())
