# In [None]:
# %% [markdown]
# # GSE212396 Full Pipeline: Auto-Download, Process, Filter Unknown, Combine
#
# This updated script includes a snippet that sets `condition = "Control"`
# for any `perturbation_name` in ["negative_control", "safe_targeting", "Non-targeting"]
# and `condition = "Test"` otherwise.

# %% [code]
import os
import logging
import urllib.request
import gzip
import shutil
import pandas as pd
import numpy as np
import anndata as ad
from scipy.io import mmread
from tqdm import tqdm

# Configure the root logger once for the whole pipeline: INFO level,
# timestamped "<time> - <level> - <message>" lines.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s"
)

# ------------------------------------------------------------------------
# FTP Source Info
# ------------------------------------------------------------------------
# GEO series accession; also used as the prefix of every remote file name.
GEO_ACCESSION = "GSE212396"
# Directory URL that the per-file names are appended to when downloading.
BASE_URL = "https://ftp.ncbi.nlm.nih.gov/geo/series/GSE212nnn/GSE212396/suppl"

# ------------------------------------------------------------------------
# Datasets to process
# ------------------------------------------------------------------------
# Dataset name stems. Remote files are named
# "<GEO_ACCESSION>_<stem>_<file type>.gz" (see get_dataset_files), and the
# underscore-separated tokens are parsed into obs metadata by
# harmonize_metadata.
DATASETS = [
    "50gLib_Eahy926_16K_scRNAseq",
    "50gLib_Eahy926_150K_scRNAseq",
    "50gLib_TeloHAEC_150K_scRNAseq",
    "200gLib_TeloHAEC_scRNAseq_1",
    "200gLib_TeloHAEC_scRNAseq_2",
]

# Guide library tables to process; downloaded as
# "<GEO_ACCESSION>_<name>.gz" (see get_guide_files).
GUIDE_FILES = [
    "50gLib_targets_n_guides.txt",
    "200gLib_targets_n_guides.txt",
]

# ------------------------------------------------------------------------
# Calls-file column name candidates
# ------------------------------------------------------------------------
# Candidate header names, tried in order; load_dataset uses the first one
# that is present in the calls file.
CELL_BARCODE_COLS = ["cell_barcode", "barcode", "cell", "cbc_10x", "CBC_10x", "CBC"]
GUIDE_COLS = ["guide", "guide_id", "sgRNA_ID", "Guide"]

# ------------------------------------------------------------------------
# Helper: Download with progress bar
# ------------------------------------------------------------------------
class DownloadProgressBar(tqdm):
    """tqdm bar whose ``update_to`` method can serve as a download report hook.

    ``update_to`` accepts (block count, block size, total size), which is the
    argument order ``urllib.request.urlretrieve`` passes to ``reporthook``.
    """

    def update_to(self, b=1, bsize=1, tsize=None):
        """Move the bar to the absolute position ``b * bsize``.

        Args:
            b: Number of blocks transferred so far.
            bsize: Size of one block.
            tsize: Total size; when not None it replaces the bar's total.
        """
        if tsize is not None:
            self.total = tsize
        transferred = b * bsize
        # update() expects an increment, so pass the distance from the
        # bar's current position (self.n) rather than the absolute value.
        self.update(transferred - self.n)

def download_file_if_needed(url: str, out_path: str) -> str:
    """Download ``url`` to ``out_path`` unless ``out_path`` already exists.

    The payload is first written to ``out_path + ".part"`` and only renamed
    to ``out_path`` after the transfer finishes, so an interrupted download
    never leaves a truncated file that a later run would mistake for a
    complete one.

    Args:
        url: Remote location to fetch.
        out_path: Local destination path. Its parent directory is created
            when missing.

    Returns:
        ``out_path``.

    Raises:
        Whatever ``urllib.request.urlretrieve`` raises on failure; the
        temporary ``.part`` file is removed before the error propagates.
    """
    if os.path.exists(out_path):
        logging.info(f"File already exists: {out_path}")
        return out_path

    # dirname is "" for a bare filename, and os.makedirs("") raises.
    parent_dir = os.path.dirname(out_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    logging.info(f"Downloading {url} to {out_path}")
    tmp_path = out_path + ".part"
    try:
        with DownloadProgressBar(
            unit='B', unit_scale=True, miniters=1, desc=os.path.basename(url)
        ) as t:
            urllib.request.urlretrieve(url, tmp_path, reporthook=t.update_to)
        # Publish the file only once it is complete.
        os.replace(tmp_path, out_path)
    finally:
        # No-op after a successful rename; cleans up after a failed transfer.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    return out_path

def gunzip_file_if_needed(gz_path: str) -> str:
    """Decompress ``gz_path`` next to itself unless the output already exists.

    Only a trailing ``.gz`` suffix is removed to form the output path; any
    other ``.gz`` occurrence in the path (for example inside a directory
    name) is left alone. Data is written to ``<output>.part`` and renamed on
    success so a failed decompression never leaves a partial output behind.

    Args:
        gz_path: Path to the gzip-compressed file.

    Returns:
        Path to the decompressed file. When ``gz_path`` does not end in
        ``.gz`` it is returned unchanged.

    Raises:
        FileNotFoundError: If ``gz_path`` does not exist.
    """
    if not os.path.exists(gz_path):
        raise FileNotFoundError(f"Cannot gunzip. File not found: {gz_path}")

    suffix = ".gz"
    if not gz_path.endswith(suffix):
        # Nothing to strip: treating the file as its own output avoids
        # decompressing a file onto itself.
        logging.warning(f"Path has no '.gz' suffix; returning it unchanged: {gz_path}")
        return gz_path

    out_path = gz_path[:-len(suffix)]
    if os.path.exists(out_path):
        logging.info(f"Unzipped file already exists: {out_path}")
        return out_path

    logging.info(f"Gunzipping {gz_path} -> {out_path}")
    tmp_path = out_path + ".part"
    try:
        with gzip.open(gz_path, 'rb') as f_in, open(tmp_path, 'wb') as f_out:
            shutil.copyfileobj(f_in, f_out)
        # Publish the output only once it is complete.
        os.replace(tmp_path, out_path)
    finally:
        # No-op after a successful rename; cleans up after a failure.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    return out_path

# ------------------------------------------------------------------------
# Step 1: Download or confirm guide files
# ------------------------------------------------------------------------
def get_guide_files(root_dir: str, debug=False):
    """
    Ensure every file listed in GUIDE_FILES is downloaded and unzipped under
    <root_dir>/guides, e.g.
      GSE212396_50gLib_targets_n_guides.txt
      GSE212396_200gLib_targets_n_guides.txt

    `debug` is accepted for signature symmetry with the other steps; it is
    not used here.

    Returns the list of unzipped .txt paths, in GUIDE_FILES order.
    """
    guide_dir = os.path.join(root_dir, "guides")
    os.makedirs(guide_dir, exist_ok=True)

    def _fetch(guide_file):
        # Remote and local archives share one name,
        # e.g. GSE212396_50gLib_targets_n_guides.txt.gz
        archive_name = f"{GEO_ACCESSION}_{guide_file}.gz"
        archive_path = os.path.join(guide_dir, archive_name)
        download_file_if_needed(f"{BASE_URL}/{archive_name}", archive_path)
        return gunzip_file_if_needed(archive_path)

    return [_fetch(guide_file) for guide_file in GUIDE_FILES]

# ------------------------------------------------------------------------
# Step 2: Load guide info into a dictionary
# ------------------------------------------------------------------------
def load_guide_info(guide_txt_paths, debug=False):
    """
    Build a lookup from guide sequence to target gene and source library.

    Each path is read as a tab-separated table. The guide sequence comes
    from the 'GuideSequence' column; the target gene comes from the first
    available of 'TargetGene', 'Gene', then 'guideSet'. When none of those
    exists every guide in that file is given the target "Unknown".

    Args:
        guide_txt_paths: Iterable of paths to unzipped guide tables.
        debug: When True, print each table's columns and first rows.

    Returns:
        dict mapping str(guide sequence) ->
            {"target_gene": <str>, "library": <file basename>}.
        Files without a 'GuideSequence' column are skipped with a warning.
        Rows whose guide sequence is missing are skipped (they would
        otherwise all collapse onto the meaningless key "nan"). A missing
        target value becomes "Unknown". When the same sequence appears
        more than once, the last occurrence wins.
    """
    guide_seq_candidates = ("GuideSequence",)
    target_candidates = ("TargetGene", "Gene")
    fallback_col = "guideSet"  # treated as the target gene when nothing better exists

    guide_info = {}

    for p in guide_txt_paths:
        df = pd.read_csv(p, sep='\t')
        logging.info(f"Reading guide info from {p}: {df.shape[0]} rows.")
        if debug:
            print("[DEBUG] Columns:", df.columns.tolist())
            print(df.head(5))

        guide_seq_col = next((c for c in guide_seq_candidates if c in df.columns), None)
        if guide_seq_col is None:
            logging.warning(f"No 'GuideSequence' column in {p}; skipping.")
            continue

        target_col = next((c for c in target_candidates if c in df.columns), None)
        if target_col is None and fallback_col in df.columns:
            target_col = fallback_col
            logging.info(f"Using '{fallback_col}' as target gene column for {p}.")

        if target_col is None:
            logging.warning(f"No 'TargetGene'/'Gene'/'guideSet' in {p}; using 'Unknown'.")
            # Placeholder targets without adding a column to df.
            targets = ["Unknown"] * len(df)
        else:
            targets = df[target_col]

        library_name = os.path.basename(p)
        n_missing_seq = 0
        # Iterating the two columns directly is much cheaper than iterrows().
        for g_seq, t_gene in zip(df[guide_seq_col], targets):
            if pd.isnull(g_seq):
                n_missing_seq += 1
                continue
            guide_info[str(g_seq)] = {
                "target_gene": str(t_gene) if pd.notnull(t_gene) else "Unknown",
                "library": library_name
            }
        if n_missing_seq:
            logging.warning(f"Skipped {n_missing_seq} rows without a guide sequence in {p}.")

    logging.info(f"guide_info built with {len(guide_info)} entries.")
    return guide_info

# ------------------------------------------------------------------------
# Step 3: Download or confirm dataset files
# ------------------------------------------------------------------------
def get_dataset_files(root_dir: str, dataset_name: str):
    """
    Fetch and unzip the four per-dataset files into <root_dir>/<dataset_name>:
      1) _barcodes.tsv.gz
      2) _features.tsv.gz
      3) _matrix.mtx.gz
      4) _calls.tsv.gz

    Failures are best-effort: a file that cannot be downloaded or unzipped
    is logged as a warning and left out of the result.

    Returns a dict keyed by file type ("barcodes.tsv", "features.tsv",
    "matrix.mtx", "calls.tsv") with the unzipped path as value.
    """
    dataset_dir = os.path.join(root_dir, dataset_name)
    os.makedirs(dataset_dir, exist_ok=True)

    unzipped_by_type = {}
    for file_type in ("barcodes.tsv", "features.tsv", "matrix.mtx", "calls.tsv"):
        gz_name = f"{GEO_ACCESSION}_{dataset_name}_{file_type}.gz"
        gz_url = f"{BASE_URL}/{gz_name}"
        gz_path = os.path.join(dataset_dir, gz_name)

        # Stage 1: download (skipped internally when already present).
        try:
            download_file_if_needed(gz_url, gz_path)
        except Exception as e:
            logging.warning(f"Could not download {gz_url}: {e}")
            continue

        # Stage 2: decompress; record the path only on success.
        try:
            unzipped_by_type[file_type] = gunzip_file_if_needed(gz_path)
        except Exception as e:
            logging.warning(f"Failed to gunzip {gz_path}: {e}")

    return unzipped_by_type

# ------------------------------------------------------------------------
# Step 4: Load a single dataset into AnnData
# ------------------------------------------------------------------------
def _first_present(columns, candidates):
    """Return the first name from `candidates` found in `columns`, else None."""
    return next((name for name in candidates if name in columns), None)


def _with_dash_one(barcode) -> str:
    """Return `barcode` as a string ending in '-1' (appended when absent)."""
    text = str(barcode)
    return text if text.endswith("-1") else text + "-1"


def load_dataset(file_paths: dict, debug=False):
    """
    Read one dataset's unzipped files into an AnnData with unique var_names.

    Args:
        file_paths: dict with keys 'matrix.mtx', 'features.tsv',
            'barcodes.tsv' and optionally 'calls.tsv', mapping to paths.
        debug: When True, print the calls-file header/preview and log how
            many barcodes received a guide.

    Returns:
        AnnData with cells as observations. var is indexed by the second
        features column and carries 'gene_ids' (first column) and
        'feature_types' (third column, or "Unknown" when the file has only
        two columns). obs carries 'guide': the guide assigned to each
        barcode by the calls file, or "Unknown".

    Raises:
        FileNotFoundError: If the matrix, features or barcodes path is missing.
        ValueError: If features.tsv has fewer than two columns.

    Notes on guide assignment:
        * Barcodes are compared after making both the calls-file barcode
          and the matrix barcode end in '-1', so either side may omit it.
        * Calls rows with a missing barcode or a missing guide are ignored.
        * When a barcode occurs on several rows, the last row wins.
    """
    mtx_path = file_paths.get("matrix.mtx")
    feat_path = file_paths.get("features.tsv")
    bc_path = file_paths.get("barcodes.tsv")
    calls_path = file_paths.get("calls.tsv")

    if not mtx_path or not feat_path or not bc_path:
        raise FileNotFoundError("Missing one of matrix.mtx, features.tsv, or barcodes.tsv")

    X = mmread(mtx_path).tocsr()
    var_df = pd.read_csv(feat_path, sep='\t', header=None)
    if var_df.shape[1] < 2:
        raise ValueError(
            f"Expected at least 2 columns (gene_id, gene_symbol) in {feat_path}; "
            f"found {var_df.shape[1]}."
        )
    # Positional access tolerates feature tables with or without the
    # feature-type column instead of requiring exactly three columns.
    gene_ids = var_df.iloc[:, 0].values
    gene_symbols = var_df.iloc[:, 1].values
    if var_df.shape[1] >= 3:
        feature_types = var_df.iloc[:, 2].values
    else:
        feature_types = ["Unknown"] * len(var_df)
    obs_names = pd.read_csv(bc_path, sep='\t', header=None)[0].values

    # The matrix file is features x barcodes; AnnData wants barcodes x features.
    if X.shape[0] == len(var_df) and X.shape[1] == len(obs_names):
        X = X.transpose()

    adata = ad.AnnData(
        X=X,
        obs=pd.DataFrame(index=obs_names),
        var=pd.DataFrame(index=gene_symbols)
    )
    adata.var["gene_ids"] = gene_ids
    adata.var["feature_types"] = feature_types

    # calls
    if calls_path and os.path.exists(calls_path):
        calls_df = pd.read_csv(calls_path, sep='\t')
        if debug:
            print("[DEBUG] Calls file columns:", calls_df.columns.tolist())
            print(calls_df.head(5))

        cell_col = _first_present(calls_df.columns, CELL_BARCODE_COLS)
        guide_col = _first_present(calls_df.columns, GUIDE_COLS)

        if cell_col and guide_col:
            guide_by_barcode = {}
            for raw_bc, guide in zip(calls_df[cell_col], calls_df[guide_col]):
                # Missing values cannot be matched or looked up later on.
                if pd.isnull(raw_bc) or pd.isnull(guide):
                    continue
                guide_by_barcode[_with_dash_one(raw_bc)] = guide

            guides = [
                guide_by_barcode.get(_with_dash_one(bc), "Unknown")
                for bc in adata.obs_names
            ]
            adata.obs["guide"] = pd.Categorical(guides)
            if debug:
                n_unk = (adata.obs["guide"] == "Unknown").sum()
                logging.info(f"Barcodes->guide mapped: {adata.n_obs - n_unk} matched, {n_unk} unknown.")
        else:
            logging.warning("No suitable cell/guide columns found in calls file.")
            adata.obs["guide"] = "Unknown"
    else:
        logging.warning("No calls file found or path is missing.")
        adata.obs["guide"] = "Unknown"

    # Duplicate gene symbols would otherwise cause trouble when the
    # datasets are concatenated later.
    adata.var_names_make_unique()

    return adata

# ------------------------------------------------------------------------
# Step 5A: Harmonize metadata
# ------------------------------------------------------------------------
def harmonize_metadata(adata, dataset_name: str, guide_info: dict, debug=False):
    """
    Fill in the shared obs annotations and derive 'perturbation_name'.

    The dataset name is split on '_' and read positionally as
    <library>_<cell_line>_<loading>...; 'library', 'cell_line' and
    'loading' are always written so every dataset ends up with the same
    obs columns. A token that is absent is recorded as "Unknown". The third
    token is only used as 'loading' when it is not the assay tag
    "scRNAseq" (names such as "200gLib_TeloHAEC_scRNAseq_1" carry no
    loading token, so 'loading' is "Unknown" for them).

    Args:
        adata: Object exposing a pandas DataFrame as `.obs` (and `.n_obs`
            when debug is True); modified in place.
        dataset_name: Dataset name stem, e.g. "50gLib_Eahy926_16K_scRNAseq".
        guide_info: Mapping guide -> {"target_gene": ...}, as built by
            load_guide_info.
        debug: When True, log how many cells received a known perturbation.

    Returns:
        The same `adata`, with obs columns library, cell_line, loading,
        dataset, organism, cell_type, crispr_type, cancer_type, condition
        ("Unknown" placeholder) and perturbation_name. perturbation_name is
        the guide's target gene, or "Unknown" when the guide is not in
        guide_info or obs has no 'guide' column.
    """
    assay_token = "scrnaseq"
    parts = dataset_name.split('_')

    adata.obs['library'] = parts[0]
    adata.obs['cell_line'] = parts[1] if len(parts) >= 2 else "Unknown"
    has_loading = len(parts) >= 3 and parts[2].lower() != assay_token
    adata.obs['loading'] = parts[2] if has_loading else "Unknown"

    adata.obs['dataset'] = dataset_name
    adata.obs['organism'] = "Homo sapiens"
    adata.obs['cell_type'] = "Endothelial cells"
    adata.obs['crispr_type'] = "CRISPRi"
    adata.obs['cancer_type'] = "Non-Cancer"
    adata.obs['condition'] = "Unknown"  # placeholder until conditions are assigned

    # Build 'perturbation_name'
    if 'guide' in adata.obs.columns:
        # Build a plain object column. Series.map on a categorical 'guide'
        # column can hand back a Categorical, which would later reject
        # values outside its categories (e.g. a renamed control label).
        names = [
            guide_info[g]['target_gene'] if g in guide_info else "Unknown"
            for g in adata.obs['guide']
        ]
        adata.obs['perturbation_name'] = pd.Series(
            names, index=adata.obs.index, dtype=object
        )

        if debug:
            n_unk = (adata.obs['perturbation_name'] == "Unknown").sum()
            logging.info(f"perturbation_name assigned for {adata.n_obs - n_unk} cells; {n_unk} unknown.")
    else:
        adata.obs['perturbation_name'] = "Unknown"

    return adata

# ------------------------------------------------------------------------
# Step 5B: Filter out "Unknown" + rename "NO TARGET..." => "Non-targeting"
#          Then set condition = "Control" or "Test" based on snippet
# ------------------------------------------------------------------------
def filter_and_assign_condition(adata: ad.AnnData, debug: bool=False) -> ad.AnnData:
    """
    Drop unassigned cells, normalise the non-targeting label, set 'condition'.

    1) Exclude cells whose perturbation_name == 'Unknown'.
    2) Rename any perturbation_name containing both "no" and "target"
       (case-insensitive) to "Non-targeting".
    3) condition = "Control" when perturbation_name is one of
       'negative_control', 'safe_targeting', 'Non-targeting'; otherwise
       condition = "Test".
    4) Convert every object-dtype obs column to categorical.

    The rename and the condition assignment are done on plain object
    columns and written back as whole columns. Writing new labels into an
    existing categorical column through `.loc` raises TypeError when the
    label is not already a category, which happens when the column was
    made categorical upstream or the AnnData was reloaded from disk.

    Args:
        adata: AnnData whose obs has a 'perturbation_name' column.
        debug: When True, log how many cells were excluded.

    Returns:
        A filtered copy of `adata`; the input object is not modified.
    """
    # 1) Exclude "Unknown"
    keep_mask = (adata.obs['perturbation_name'] != "Unknown")
    before = adata.n_obs
    adata = adata[keep_mask].copy()
    after = adata.n_obs
    if debug:
        logging.info(f"Excluded {before - after} cells with 'Unknown' perturbation_name.")

    # 2) Rename "NO TARGET" => "Non-targeting"
    names = adata.obs['perturbation_name'].astype(object)
    lowered = names.astype(str).str.lower()
    no_targ_mask = (
        lowered.str.contains("no", regex=False)
        & lowered.str.contains("target", regex=False)
    ).astype(bool)
    names = names.where(~no_targ_mask, "Non-targeting")
    adata.obs['perturbation_name'] = names

    # 3) Control vs. Test
    control_perturbations = ["negative_control", "safe_targeting", "Non-targeting"]
    mask_control = names.isin(control_perturbations)
    adata.obs['condition'] = pd.Series(
        np.where(mask_control, "Control", "Test"),
        index=adata.obs.index,
        dtype=object,
    )

    # 4) convert object->categorical
    for col in adata.obs.columns:
        if adata.obs[col].dtype == object:
            adata.obs[col] = adata.obs[col].astype('category')

    return adata

# ------------------------------------------------------------------------
# Main pipeline function
# ------------------------------------------------------------------------
def process_all_datasets(root_dir="./GSE212396_auto", debug=False):
    """
    Run the whole pipeline and return the path of the combined file.

    Steps:
      1) Download/unzip the guide files if missing and build guide_info.
      2) For each entry of DATASETS, in order:
           - download/unzip its files if missing
           - load them into an AnnData
           - harmonize the obs metadata
           - drop 'Unknown' cells, normalise the non-targeting label and
             assign condition
           - write <root_dir>/<dataset>.h5ad
      3) Concatenate all datasets (outer join, 'batch' label keyed by
         dataset name) and write <root_dir>/combined.h5ad.

    Any exception raised by a step propagates and aborts the run.
    """
    logging.info("=== Starting GSE212396 pipeline ===")
    os.makedirs(root_dir, exist_ok=True)

    guide_info = load_guide_info(get_guide_files(root_dir, debug=debug), debug=debug)

    processed = []
    for dataset_name in DATASETS:
        logging.info(f"=== Processing {dataset_name} ===")
        dataset = load_dataset(get_dataset_files(root_dir, dataset_name), debug=debug)
        dataset = harmonize_metadata(dataset, dataset_name, guide_info, debug=debug)
        dataset = filter_and_assign_condition(dataset, debug=debug)

        h5ad_path = os.path.join(root_dir, f"{dataset_name}.h5ad")
        dataset.write_h5ad(h5ad_path)
        logging.info(f"Saved {h5ad_path}")
        processed.append(dataset)

    # Stack the per-dataset objects along the cell axis.
    logging.info("Concatenating all datasets.")
    merged = ad.concat(processed, axis=0, join="outer", label="batch", keys=DATASETS)
    merged_path = os.path.join(root_dir, "combined.h5ad")
    merged.write_h5ad(merged_path)
    logging.info(f"Combined dataset saved to {merged_path}")

    logging.info("All datasets processed successfully!")
    return merged_path

# %% [code]
# Notebook entry point: run the full pipeline with the default root_dir
# ("./GSE212396_auto") and debug output enabled.
# NOTE: this statement executes as soon as the cell/module is run and calls
# the download steps for any file that is not already on disk.
final_combined_path = process_all_datasets(debug=True)
print("Final combined .h5ad:", final_combined_path)
