# In [None]:
"""
GSE252589 Dataset Processor (Jupyter-friendly)

This notebook cell includes all the necessary functions to download, process, 
and harmonize the GSE252589 dataset into h5ad format. The dataset contains 
single-cell RNA-seq data from osteosarcoma xenograft implants from 
TrkA^WT;NOD-Scid and TrkA^F592A;NOD-Scid animals.

Usage:
    1. Set your desired `root_path`.
    2. Call `adata = process_gse252589(root_path, perform_qc=True)` to run QC 
       or `adata = process_gse252589(root_path)` without QC.

Output:
    - GSE252589_harmonized.h5ad: Harmonized AnnData object with standardized metadata
    - GSE252589_qc_report.pdf: Quality control report (if `perform_qc=True`)
"""

import os
import gzip
import numpy as np
import pandas as pd
import scanpy as sc
import scipy.io
import anndata as ad
import matplotlib.pyplot as plt
from pathlib import Path
import urllib.request
import tarfile
import shutil

def download_dataset(output_dir):
    """
    Download and unpack the GSE252589 raw archive from GEO.
    
    The archive is fetched only when ``GSE252589_RAW.tar`` is not already in
    ``output_dir``, and unpacked only when the WT barcodes file is missing,
    so repeated calls skip work that has already been done.
    
    Parameters:
    -----------
    output_dir : str or Path
        Directory that receives the archive and its extracted members;
        created (including parents) if it does not exist.
    
    Raises:
    -------
    urllib.error.URLError
        If the archive cannot be fetched.
    tarfile.TarError
        If the archive cannot be read, or a member is rejected by the
        extraction filter.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    
    # Download the dataset
    tar_file = output_dir / "GSE252589_RAW.tar"
    if not tar_file.exists():
        print(f"Downloading GSE252589 dataset to {tar_file}...")
        url = "https://www.ncbi.nlm.nih.gov/geo/download/?acc=GSE252589&format=file"
        # Fetch into a sibling ".part" file and rename only on success.
        # Writing straight to tar_file would let an interrupted transfer
        # leave a truncated archive that the exists() check above accepts
        # as complete on every later run.
        partial_file = tar_file.with_name(tar_file.name + ".part")
        try:
            urllib.request.urlretrieve(url, partial_file)
            partial_file.replace(tar_file)
        finally:
            # No-op after a successful rename; removes debris otherwise.
            if partial_file.exists():
                partial_file.unlink()
    
    # Extract the dataset
    if not (output_dir / "GSM8003561_barcodes.TrkA.WT.tsv.gz").exists():
        print("Extracting dataset...")
        with tarfile.open(tar_file, "r") as tar:
            # The "data" filter rejects members that would be written
            # outside output_dir. It only exists on Python builds that ship
            # extraction filters, so older builds keep the plain call.
            if hasattr(tarfile, "data_filter"):
                tar.extractall(path=output_dir, filter="data")
            else:
                tar.extractall(path=output_dir)
    
    print("Dataset downloaded and extracted successfully.")

def load_10x_data(matrix_file, features_file, barcodes_file, sample_id):
    """
    Load one sample's gzipped matrix/features/barcodes triplet into AnnData.
    
    Parameters:
    -----------
    matrix_file : Path
        Path to the gzipped Matrix Market file (matrix.mtx.gz)
    features_file : Path
        Path to the gzipped, tab-separated, header-less features file.
        Column 0 is stored as ``var['gene_ids']``, column 1 becomes the
        var index, and column 2 (if present) is stored as
        ``var['feature_types']``.
    barcodes_file : Path
        Path to the gzipped, header-less barcodes file; its first column
        becomes the obs index
    sample_id : str
        Sample identifier, written to ``obs['sample_id']`` for every cell
    
    Returns:
    --------
    adata : AnnData
        AnnData object with a CSR matrix, barcodes as obs and gene symbols
        as var. A leading ``GRCh38_`` is stripped from gene symbols.
    """
    # Read the matrix. The handle is closed as soon as parsing finishes, and
    # the matrix is transposed so rows line up with barcodes (obs) and
    # columns with features (var) in the AnnData constructor below.
    with gzip.open(matrix_file, 'rb') as handle:
        X = scipy.io.mmread(handle).T.tocsr()
    
    # Read features (genes) and barcodes as plain strings. Default NA
    # parsing would turn a symbol such as "NA" or "nan" into a float NaN,
    # which then breaks the string handling below.
    with gzip.open(features_file, 'rt') as f:
        gene_df = pd.read_csv(f, sep='\t', header=None, dtype=str, na_filter=False)
    
    with gzip.open(barcodes_file, 'rt') as f:
        barcodes = pd.read_csv(
            f, sep='\t', header=None, dtype=str, na_filter=False
        )[0].values
    
    # Extract gene information
    gene_ids = gene_df[0].values
    gene_symbols = gene_df[1].values
    
    # Clean up gene symbols - remove the genome prefix if present
    prefix = 'GRCh38_'
    clean_gene_symbols = [
        symbol[len(prefix):] if symbol.startswith(prefix) else symbol
        for symbol in gene_symbols
    ]
    
    # Create var DataFrame with gene information
    var = pd.DataFrame(index=clean_gene_symbols)
    var['gene_ids'] = gene_ids
    # Two-column feature files carry no type column; fall back to a constant.
    var['feature_types'] = (gene_df[2].values 
                            if gene_df.shape[1] > 2 
                            else ['Gene Expression'] * len(gene_df))
    
    # Create obs DataFrame with cell information
    obs = pd.DataFrame(index=barcodes)
    obs['sample_id'] = sample_id
    
    # Create AnnData object
    adata = ad.AnnData(X=X, obs=obs, var=var)
    
    return adata

def perform_quality_control(adata, output_dir):
    """
    Compute QC metrics for the dataset and save a six-panel PDF report.
    
    ``adata`` is modified in place: a boolean ``var['mt']`` column (gene
    name starts with ``MT-``) is added, and ``sc.pp.calculate_qc_metrics``
    writes its metric columns into ``adata.obs`` and ``adata.var``.
    
    Parameters:
    -----------
    adata : AnnData
        AnnData object containing the dataset; must have an
        ``obs['sample_id']`` column, which is used to group the plots
    output_dir : str or Path
        Directory where ``GSE252589_qc_report.pdf`` will be saved
    """
    output_dir = Path(output_dir)
    
    # Flag mitochondrial genes first so a single metrics pass covers them,
    # instead of computing all metrics and then recomputing them with 'mt'.
    adata.var['mt'] = adata.var_names.str.startswith('MT-')
    has_mt = adata.var['mt'].sum() > 0
    sc.pp.calculate_qc_metrics(
        adata, qc_vars=['mt'] if has_mt else (), inplace=True
    )
    
    # Set up the figure
    fig = plt.figure(figsize=(15, 12))
    try:
        fig.suptitle("GSE252589 Quality Control Report", fontsize=16)
        
        # Plot 1: Number of genes per cell
        ax1 = fig.add_subplot(2, 3, 1)
        sc.pl.violin(adata, 'n_genes_by_counts', groupby='sample_id', ax=ax1, show=False)
        ax1.set_title('Number of genes per cell')
        
        # Plot 2: Number of UMIs per cell
        ax2 = fig.add_subplot(2, 3, 2)
        sc.pl.violin(adata, 'total_counts', groupby='sample_id', ax=ax2, show=False)
        ax2.set_title('Number of UMIs per cell')
        
        # Plot 3: Percentage of mitochondrial genes (placeholder text when
        # no gene matched the 'MT-' prefix, so the grid keeps six panels)
        ax3 = fig.add_subplot(2, 3, 3)
        if has_mt:
            sc.pl.violin(adata, 'pct_counts_mt', groupby='sample_id', ax=ax3, show=False)
        else:
            ax3.text(0.5, 0.5, "No mitochondrial genes found", 
                     horizontalalignment='center', verticalalignment='center')
        ax3.set_title('Percentage of mitochondrial genes')
        
        # Plot 4: Scatter plot of UMIs vs genes
        ax4 = fig.add_subplot(2, 3, 4)
        sc.pl.scatter(adata, 'total_counts', 'n_genes_by_counts', color='sample_id', ax=ax4, show=False)
        ax4.set_title('UMIs vs genes')
        
        # Plot 5: Distribution of top 20 expressed genes
        ax5 = fig.add_subplot(2, 3, 5)
        sc.pl.highest_expr_genes(adata, n_top=20, ax=ax5, show=False)
        ax5.set_title('Top 20 expressed genes')
        
        # Plot 6: Number of cells per sample
        ax6 = fig.add_subplot(2, 3, 6)
        adata.obs['sample_id'].value_counts().plot(kind='bar', ax=ax6)
        ax6.set_title('Number of cells per sample')
        ax6.set_ylabel('Number of cells')
        
        # Lay out and save through the figure object itself rather than
        # pyplot's "current figure", which a plotting helper could change.
        # The rect leaves headroom for the suptitle.
        fig.tight_layout(rect=[0, 0, 1, 0.95])
        qc_report_file = output_dir / "GSE252589_qc_report.pdf"
        fig.savefig(qc_report_file)
    finally:
        # Release the figure even when a plotting call fails.
        plt.close(fig)
    
    print(f"Quality control report saved to {qc_report_file}")

def _annotate_gse252589_cells(adata):
    """Write the harmonized per-cell metadata columns into ``adata.obs`` in place."""
    # The concat label column holds "0" for the first input (WT) and "1" for
    # the second (F592A); translate it into the readable sample name.
    # Arrays are assigned positionally because obs names may still contain
    # duplicates at this point.
    batch = adata.obs["batch"].astype(str)
    adata.obs["sample_id"] = batch.replace(
        {"0": "TrkA_WT", "1": "TrkA_F592A"}
    ).to_numpy()
    is_mutant = (adata.obs["sample_id"] == "TrkA_F592A").to_numpy()
    
    # Constant metadata
    # NOTE(review): organism is "Mus musculus" while cell_line below is
    # "143B human OS" -- confirm which species these cells should carry.
    adata.obs["organism"] = "Mus musculus"
    adata.obs["cell_type"] = "Osteosarcoma"
    adata.obs["crispr_type"] = "None"
    adata.obs["cancer_type"] = "Osteosarcoma"
    
    # Condition: the F592A sample is the test arm, everything else control
    adata.obs["condition"] = np.where(is_mutant, "Test", "Control")
    
    # Perturbation: control cells are "Non-Targeting", F592A cells "NTRK1".
    # Built in one step so the categorical carries only these two values.
    adata.obs["perturbation_name"] = pd.Categorical(
        np.where(is_mutant, "NTRK1", "Non-Targeting"),
        categories=["Non-Targeting", "NTRK1"],
    )
    
    # Additional metadata from GEO
    adata.obs["genotype"] = np.where(
        is_mutant, "TrkA_F592A;NOD-Scid", "TrkA_WT;NOD-Scid"
    )
    adata.obs["cell_line"] = "143B human OS"
    adata.obs["treatment"] = np.where(is_mutant, "1NMPP1", "None")


def process_gse252589(root_path, perform_qc=False):
    """
    Process GSE252589 dataset and harmonize it to h5ad format.
    
    Downloads the raw archive when the WT barcodes file is absent, loads
    the WT and F592A samples, restricts both to their shared genes,
    concatenates them, attaches standardized metadata and writes
    ``GSE252589_harmonized.h5ad`` into ``root_path``.
    
    Parameters:
    -----------
    root_path : str
        Path to the directory containing or to contain the GSE252589 data files
    perform_qc : bool, optional (default: False)
        Whether to perform quality control and generate a QC report
    
    Returns:
    --------
    adata : AnnData
        Harmonized AnnData object (the same object that was written to disk)
    
    Raises:
    -------
    FileNotFoundError
        If any of the six expected input files is still missing after the
        download step.
    """
    print(f"Processing GSE252589 dataset from {root_path}")
    root_path = Path(root_path)
    
    # Download the dataset if it doesn't exist
    if not (root_path / "GSM8003561_barcodes.TrkA.WT.tsv.gz").exists():
        download_dataset(root_path)
    
    # Define file paths
    # NOTE(review): the WT matrix name carries an extra ".mtx" segment that
    # the Mut matrix name does not -- confirm both against the GEO listing.
    matrix_file_wt = root_path / "GSM8003561_matrix.mtx.TrkA.WT.mtx.gz"
    features_file_wt = root_path / "GSM8003561_features.TrkA.WT.tsv.gz"
    barcodes_file_wt = root_path / "GSM8003561_barcodes.TrkA.WT.tsv.gz"
    
    matrix_file_mut = root_path / "GSM8003562_matrix.TrkA.Mut.mtx.gz"
    features_file_mut = root_path / "GSM8003562_features.TrkA.Mut.tsv.gz"
    barcodes_file_mut = root_path / "GSM8003562_barcodes.TrkA.Mut.tsv.gz"
    
    # Check every input up front so nothing is parsed before a failure
    for file_path in [matrix_file_wt, features_file_wt, barcodes_file_wt, 
                      matrix_file_mut, features_file_mut, barcodes_file_mut]:
        if not file_path.exists():
            raise FileNotFoundError(f"File not found: {file_path}")
    
    # Load data for TrkA WT sample
    print("Loading TrkA WT sample data...")
    adata_wt = load_10x_data(
        matrix_file=matrix_file_wt,
        features_file=features_file_wt,
        barcodes_file=barcodes_file_wt,
        sample_id="TrkA_WT"
    )
    adata_wt.var_names_make_unique()
    
    # Load data for TrkA Mutant sample
    print("Loading TrkA Mutant sample data...")
    adata_mut = load_10x_data(
        matrix_file=matrix_file_mut,
        features_file=features_file_mut,
        barcodes_file=barcodes_file_mut,
        sample_id="TrkA_F592A"
    )
    adata_mut.var_names_make_unique()
    
    # Restrict both samples to the shared genes, in the same order
    common_genes = adata_wt.var_names.intersection(adata_mut.var_names)
    adata_wt = adata_wt[:, common_genes]
    adata_mut = adata_mut[:, common_genes]
    
    # Keep the per-gene annotations so they can be re-attached after the
    # concatenation step below.
    gene_ids = adata_wt.var['gene_ids'].copy()
    feature_types = adata_wt.var['feature_types'].copy()
    
    # Concatenate the two datasets (WT first, so it becomes batch "0")
    print("Concatenating datasets...")
    adata = ad.concat([adata_wt, adata_mut], join="outer", label="batch")
    
    # Restore gene information (aligned on the gene-symbol index)
    adata.var['gene_ids'] = gene_ids
    adata.var['feature_types'] = feature_types
    
    # Add metadata
    print("Adding metadata...")
    _annotate_gse252589_cells(adata)
    
    # Add dataset-specific metadata
    adata.uns["dataset_id"] = "GSE252589"
    adata.uns["dataset_name"] = (
        "TrkA^WT;NOD-Scid and TrkA^F592A;NOD-Scid xenograft osteosarcoma"
    )
    adata.uns["dataset_description"] = (
        "Single-cell RNA-seq of osteosarcoma xenograft implants from TrkA^WT;NOD-Scid "
        "and TrkA^F592A;NOD-Scid animals at 12 days post-tumor cell inoculation."
    )
    
    # Make observation names unique (the two samples are concatenated
    # without any per-sample suffix on their barcodes)
    adata.obs_names_make_unique()
    
    # Perform QC if requested
    if perform_qc:
        print("Performing quality control...")
        perform_quality_control(adata, root_path)
    
    # Save the harmonized dataset
    output_file = root_path / "GSE252589_harmonized.h5ad"
    print(f"Saving harmonized dataset to {output_file}")
    adata.write(output_file)
    
    return adata


# --------------------------------------------------------------------------------------
# EXAMPLE USAGE IN A JUPYTER NOTEBOOK:
#
# NOTE: these statements are live, not commented out. Running this cell (or
# importing this file as a module) executes the whole pipeline with QC enabled
# and writes the harmonized h5ad file and the QC report under `root_path`.
# NOTE(review): "/content" looks like a Google Colab path -- adjust `root_path`
# for other environments.
root_path = "/content/GSE252589"
adata = process_gse252589(root_path, perform_qc=True)
print(f"Processed dataset shape: {adata.shape}")
# --------------------------------------------------------------------------------------
