In [None]:
import os
import requests
import gzip
import shutil
import numpy as np
import pandas as pd
import anndata as ad
from scipy import sparse

# If you need rpy2 for reading .rds files, uncomment the following:
# import rpy2.robjects as robjects
# from rpy2.robjects import pandas2ri
# from rpy2.robjects.packages import importr
# pandas2ri.activate()
# base = importr('base')

# GEO series accession and the two tumor cohorts it contains.
GEO_ACCESSION = 'GSE278692'
DATASET_TYPES = ['CRC', 'HNSCC']

# Every supplementary file of the series shares this URL prefix.
_SUPPL_URL_PREFIX = (
    f'https://ftp.ncbi.nlm.nih.gov/geo/series/GSE278nnn/{GEO_ACCESSION}/suppl/{GEO_ACCESSION}_'
)

# Download URLs keyed by '<cohort>_<file kind>'. Insertion order is
# file kind first, cohort second, followed by the shared feature reference.
FILE_URLS = {
    f'{cohort}_{file_kind}': f'{_SUPPL_URL_PREFIX}{cohort}_CD8TIL_{file_kind}.{extension}.gz'
    for file_kind, extension in (
        ('RNACounts', 'rds'),
        ('ADTCounts', 'rds'),
        ('metadata', 'csv'),
    )
    for cohort in DATASET_TYPES
}
FILE_URLS['feature_reference'] = f'{_SUPPL_URL_PREFIX}feature_reference.csv.gz'

# Human-readable cell type label for each sorted subgroup code.
CELL_TYPE_MAP = {
    'DP': 'CD39+CD103+ Double Positive CD8+ T cells',
    'DN': 'CD39-CD103- Double Negative CD8+ T cells',
    'SP39': 'CD39+ Single Positive CD8+ T cells',
    'SP103': 'CD103+ Single Positive CD8+ T cells',
}

# Study-level annotations attached to the combined outputs.
STUDY_INFO = {
    'geo_accession': GEO_ACCESSION,
    'title': 'IL-12 drives the expression of the inhibitory receptor NKG2A on human tumor-reactive CD8 T cells',
    'organism': 'Homo sapiens',
    'experiment_type': 'Expression profiling by high throughput sequencing',
}


def robust_download(url, output_path, max_retries=3, chunk_size=8192):
    """
    Download `url` to `output_path`, retrying on failure, then gunzip the result.

    The payload is streamed to ``output_path + ".part"`` and only renamed to
    `output_path` once the transfer has completed, so an interrupted download
    is never mistaken for a finished file on a later run.

    If `output_path` ends in ``.gz`` (but not ``.h5ad.gz``) it is decompressed
    next to itself with the ``.gz`` suffix removed. A freshly downloaded
    archive is always decompressed; an archive that was already on disk is
    decompressed only when its decompressed counterpart is missing.

    Args:
        url: Address of the file to fetch.
        output_path: Destination path of the downloaded file.
        max_retries: Maximum number of download attempts (at least 1).
        chunk_size: Size in bytes of the chunks streamed to disk.

    Raises:
        ValueError: If a download is needed and `max_retries` is below 1.
        RuntimeError: If every download attempt failed; the last error is
            attached as the cause.
    """
    already_present = os.path.exists(output_path)
    if already_present:
        print(f"File already exists: {output_path}")
    else:
        if max_retries < 1:
            raise ValueError(f"max_retries must be at least 1, got {max_retries}")

        partial_path = output_path + ".part"
        for attempt in range(1, max_retries + 1):
            print(f"Downloading (attempt {attempt}/{max_retries}): {url}")
            try:
                with requests.get(url, stream=True, timeout=30) as r:
                    r.raise_for_status()
                    with open(partial_path, "wb") as f:
                        for chunk in r.iter_content(chunk_size=chunk_size):
                            if chunk:  # filter out keep-alive new chunks
                                f.write(chunk)
                # Publish the file under its final name only once it is complete.
                os.replace(partial_path, output_path)
            except Exception as e:
                print(f"Download failed: {e}")
                # Never leave a truncated transfer behind.
                if os.path.exists(partial_path):
                    os.remove(partial_path)
                if attempt == max_retries:
                    raise RuntimeError(
                        f"Failed to download {url} after {max_retries} attempts"
                    ) from e
            else:
                break

    # If it's a gzipped file, automatically decompress it (unless it's .h5ad.gz)
    if output_path.endswith('.gz') and not output_path.endswith('.h5ad.gz'):
        decompressed_path = output_path[:-3]  # remove .gz
        if already_present and os.path.exists(decompressed_path):
            # Both the archive and its decompressed copy are already on disk.
            return

        print(f"Decompressing {output_path} -> {decompressed_path}")
        partial_out = decompressed_path + ".part"
        try:
            with gzip.open(output_path, 'rb') as f_in, open(partial_out, 'wb') as f_out:
                shutil.copyfileobj(f_in, f_out)
            os.replace(partial_out, decompressed_path)
        except Exception:
            # A corrupt archive must not leave a half-written output file.
            if os.path.exists(partial_out):
                os.remove(partial_out)
            raise
        print(f"Decompressed to {decompressed_path}")


def download_dataset(data_dir):
    """Fetch every GSE278692 supplementary file listed in FILE_URLS into `data_dir`.

    The directory is created when missing. Each file is stored under the base
    name of its URL, and the transfer itself is delegated to `robust_download`.
    """
    os.makedirs(data_dir, exist_ok=True)

    for source_url in FILE_URLS.values():
        destination = os.path.join(data_dir, os.path.basename(source_url))
        robust_download(source_url, destination)


def extract_rds_matrix(rds_path):
    """
    Read an RDS file through rpy2 and rebuild its matrix as a SciPy CSC matrix.

    The loaded R object must expose the `i`, `p` and `x` slots; they are handed
    to `scipy.sparse.csc_matrix` as indices, index pointers and data.
    Note: Requires rpy2 and a working R installation.

    Returns:
        A tuple ``(matrix, features, barcodes)``: the CSC matrix shaped as
        reported by R's ``dim()``, the R row names and the R column names.
    """
    import rpy2.robjects as robjects

    r = robjects.r

    # Read the serialized object and query its shape and dimnames through R
    r_matrix = r['readRDS'](rds_path)
    r_shape = r['dim'](r_matrix)
    row_count, col_count = r_shape[0], r_shape[1]
    row_names = list(r['rownames'](r_matrix))
    col_names = list(r['colnames'](r_matrix))

    # Helper defined on the R side that bundles the sparse slots into a list
    r('''
    extract_sparse_matrix <- function(mat) {
        list(
            i = mat@i,
            p = mat@p,
            x = mat@x,
            dim = dim(mat)
        )
    }
    ''')
    slots = r['extract_sparse_matrix'](r_matrix)
    row_indices, col_pointers, values = (
        np.array(slots.rx2(slot_name)) for slot_name in ('i', 'p', 'x')
    )

    # (data, indices, indptr) is the constructor form for compressed columns
    csc = sparse.csc_matrix(
        (values, row_indices, col_pointers),
        shape=(row_count, col_count),
    )
    return csc, row_names, col_names


def _build_harmonized_obs(obs_names, metadata_indexed, dataset_type):
    """Build the harmonized per-cell table for `obs_names` from barcode-indexed metadata."""
    # One aligned lookup for all columns; rows come back in `obs_names` order.
    aligned = metadata_indexed.loc[obs_names]

    harmonized = pd.DataFrame(index=obs_names)
    harmonized['barcode'] = obs_names
    harmonized['organism'] = 'Homo sapiens'
    harmonized['cancer_type'] = (
        'Colorectal Cancer' if dataset_type == 'CRC' else 'Head and Neck Squamous Cell Carcinoma'
    )

    # Transfer columns from metadata (they override the constants above on a name clash)
    for col in aligned.columns:
        if col == 'barcode':
            continue
        harmonized[col] = aligned[col].values

    # Add cell_type from subgroup (if present)
    if 'subgroup' in harmonized.columns:
        harmonized['cell_type'] = harmonized['subgroup'].map(
            lambda x: CELL_TYPE_MAP.get(str(x), 'CD8+ T cells') if pd.notna(x) else 'CD8+ T cells'
        )
    else:
        harmonized['cell_type'] = 'CD8+ T cells'

    # Additional required columns
    if 'patient' in harmonized.columns:
        harmonized['condition'] = harmonized['patient']
    else:
        harmonized['condition'] = '0'

    harmonized['perturbation_name'] = 'None'
    harmonized['crispr_type'] = 'None'
    return harmonized


def process_dataset(data_dir, dataset_type):
    """Process and return (gene_adata, protein_adata) for the specified dataset_type.

    Steps:
      1. Decompress the RNA/ADT ``.rds`` and metadata ``.csv`` files when only
         their ``.gz`` archives are present in `data_dir`.
      2. Load the metadata CSV and both count matrices (via `extract_rds_matrix`).
      3. Restrict the gene data to cells listed in the metadata ``barcode``
         column and attach the harmonized per-cell annotations.
      4. Restrict gene and protein data to their shared cells and copy the gene
         annotations onto the protein object.

    Cells are always subset with a mask over the existing ``obs_names``, so the
    output keeps the order of the RNA matrix and is reproducible between runs.

    Args:
        data_dir: Directory holding the downloaded dataset files.
        dataset_type: Cohort prefix used in the file names (e.g. 'CRC', 'HNSCC').

    Returns:
        Tuple ``(gene_adata, protein_adata)``; an element is None when its
        count file is missing or could not be read.

    Raises:
        RuntimeError: If a ``.gz`` archive cannot be decompressed.
    """
    print(f"Processing {dataset_type} dataset...")

    rna_path = os.path.join(data_dir, f"{GEO_ACCESSION}_{dataset_type}_CD8TIL_RNACounts.rds")
    adt_path = os.path.join(data_dir, f"{GEO_ACCESSION}_{dataset_type}_CD8TIL_ADTCounts.rds")
    metadata_path = os.path.join(data_dir, f"{GEO_ACCESSION}_{dataset_type}_CD8TIL_metadata.csv")

    # Make sure decompressed versions exist if only .gz was downloaded
    for fp in (rna_path, adt_path, metadata_path):
        gz_fp = fp + ".gz"
        if os.path.exists(fp) or not os.path.exists(gz_fp):
            continue
        print(f"Decompressing {gz_fp}")
        try:
            with gzip.open(gz_fp, 'rb') as f_in, open(fp, 'wb') as f_out:
                shutil.copyfileobj(f_in, f_out)
        except Exception as e:
            # Drop the truncated output so a later run does not treat it as complete.
            if os.path.exists(fp):
                os.remove(fp)
            raise RuntimeError(f"Failed to decompress {gz_fp}: {e}") from e

    # Load metadata
    if not os.path.exists(metadata_path):
        print(f"Warning: metadata not found at {metadata_path}")
        metadata = pd.DataFrame()
    else:
        metadata = pd.read_csv(metadata_path)
        print(f"Loaded metadata with {len(metadata)} entries.")

    # Extract RNA data (best effort: a failure is reported and leaves gene_adata as None)
    gene_adata = None
    if os.path.exists(rna_path):
        try:
            rna_matrix, rna_features, rna_barcodes = extract_rds_matrix(rna_path)
            print(f"RNA data shape: {rna_matrix.shape}")
            # The RDS matrix is features x cells; AnnData expects cells x features.
            gene_adata = ad.AnnData(
                X=rna_matrix.T,
                var=pd.DataFrame(index=rna_features)
            )
            gene_adata.var['gene_symbol'] = rna_features
            gene_adata.obs_names = rna_barcodes
        except Exception as e:
            print(f"Error extracting RNA from {rna_path}: {e}")

    # Extract ADT data (same best-effort handling as RNA)
    protein_adata = None
    if os.path.exists(adt_path):
        try:
            adt_matrix, adt_features, adt_barcodes = extract_rds_matrix(adt_path)
            print(f"ADT data shape: {adt_matrix.shape}")
            # Build an AnnData for proteins
            protein_var = pd.DataFrame(index=adt_features)
            protein_var['feature_type'] = 'Antibody Capture'
            protein_var['target'] = adt_features

            protein_adata = ad.AnnData(
                X=adt_matrix.T,
                var=protein_var
            )
            protein_adata.obs_names = adt_barcodes
        except Exception as e:
            print(f"Error extracting ADT from {adt_path}: {e}")

    # Harmonize the metadata into gene_adata
    if gene_adata is not None and not metadata.empty:
        if 'barcode' not in metadata.columns:
            print("Warning: metadata has no 'barcode' column; skipping metadata harmonization.")
        else:
            metadata_indexed = metadata.set_index('barcode')
            # A repeated barcode would make the aligned lookup return extra rows.
            duplicated = metadata_indexed.index.duplicated(keep='first')
            if duplicated.any():
                print(f"Warning: dropping {int(duplicated.sum())} duplicated barcodes from metadata.")
                metadata_indexed = metadata_indexed[~duplicated]

            in_metadata = gene_adata.obs_names.isin(metadata_indexed.index)
            print(f"{int(in_metadata.sum())} common cells between gene data and metadata.")
            # Subset gene_adata to common cells, keeping the original cell order
            gene_adata = gene_adata[in_metadata].copy()
            gene_adata.obs = _build_harmonized_obs(
                gene_adata.obs_names, metadata_indexed, dataset_type
            )

    # If both gene & protein exist, align them
    if gene_adata is not None and protein_adata is not None:
        in_protein = gene_adata.obs_names.isin(protein_adata.obs_names)
        print(f"{int(in_protein.sum())} cells intersect gene and protein data.")
        gene_adata = gene_adata[in_protein].copy()
        # Index the protein data by the gene cell names so both share one row order.
        protein_adata = protein_adata[gene_adata.obs_names.tolist()].copy()
        # copy metadata
        protein_adata.obs = gene_adata.obs.copy()

    return gene_adata, protein_adata


def _combine_and_write(adatas_by_type, output_dir, file_stem, combine_label):
    """Prefix, concatenate and save the per-dataset AnnData objects of one modality.

    Returns the combined AnnData, or None when no dataset provided this modality.
    """
    prepared = []
    for ds_type in DATASET_TYPES:
        adata = adatas_by_type[ds_type]
        if adata is None:
            continue
        # Make obs names unique per dataset (modifies the per-dataset object in place)
        adata.obs['dataset_type'] = ds_type
        adata.obs_names = [f"{ds_type}_{x}" for x in adata.obs_names]
        adata.obs['barcode'] = adata.obs_names
        if not adata.var_names.is_unique:
            adata.var_names_make_unique()
        prepared.append(adata)

    if not prepared:
        return None

    print(f"Combining {combine_label} AnnData objects...")
    combined = ad.concat(prepared, join='outer', merge='same')
    # Store a copy so the module-level STUDY_INFO dict is not shared with the object.
    combined.uns['study_info'] = dict(STUDY_INFO)
    combined_fp = os.path.join(output_dir, f"{GEO_ACCESSION}_combined_{file_stem}.h5ad")
    print(f"Writing combined {file_stem.replace('_', ' ')} -> {combined_fp}")
    combined.write_h5ad(combined_fp)
    return combined


def run_harmonization(data_dir, no_download=False):
    """Complete harmonization workflow.

    Optionally downloads the raw files, processes every dataset type listed in
    DATASET_TYPES, writes one ``.h5ad`` file per dataset and modality into
    ``<data_dir>/harmonized``, then writes the gene and protein objects
    combined across datasets and prints a shape summary.

    Args:
        data_dir: Directory holding (or receiving) the raw dataset files.
        no_download: When True, skip the download step and use the files
            already present in `data_dir`.
    """
    if not no_download:
        download_dataset(data_dir)

    gene_adatas, protein_adatas = {}, {}

    for ds_type in DATASET_TYPES:
        print("\n" + "="*80)
        print(f"Processing: {ds_type}")
        print("="*80)
        gene_adatas[ds_type], protein_adatas[ds_type] = process_dataset(data_dir, ds_type)

    # Saving results
    output_dir = os.path.join(data_dir, 'harmonized')
    os.makedirs(output_dir, exist_ok=True)

    print("\n" + "="*80)
    print("Saving individual datasets")
    print("="*80)

    for ds_type in DATASET_TYPES:
        per_modality = (
            (gene_adatas[ds_type], 'gene_expression'),
            (protein_adatas[ds_type], 'protein_expression'),
        )
        for adata, file_stem in per_modality:
            if adata is None:
                continue
            fp = os.path.join(output_dir, f"{GEO_ACCESSION}_{ds_type}_{file_stem}.h5ad")
            print(f"Saving {file_stem.replace('_', ' ')} for {ds_type} -> {fp}")
            adata.write_h5ad(fp)

    # Combine across CRC / HNSCC
    print("\n" + "="*80)
    print("Creating combined datasets")
    print("="*80)

    combined_g_adata = _combine_and_write(
        gene_adatas, output_dir, 'gene_expression', 'gene expression'
    )
    combined_p_adata = _combine_and_write(
        protein_adatas, output_dir, 'protein_expression', 'protein'
    )

    # Final summary
    print("\nHarmonization complete! Results saved in:", output_dir)
    for ds_type in DATASET_TYPES:
        g_adata = gene_adatas[ds_type]
        p_adata = protein_adatas[ds_type]
        if g_adata is not None:
            print(f"{ds_type} gene: {g_adata.shape}")
        if p_adata is not None:
            print(f"{ds_type} protein: {p_adata.shape}")
    if combined_g_adata is not None:
        print("Combined gene:", combined_g_adata.shape)
    if combined_p_adata is not None:
        print("Combined protein:", combined_p_adata.shape)


# ------------------------------------------------------------------------------
# Usage within Jupyter:
# Set data_dir to the desired path, then run `run_harmonization(data_dir, no_download=False)`.
data_dir = "./data_GSE278692"

# Start the pipeline only when this code is executed directly (a notebook cell
# or a script run, where __name__ is "__main__"), not when it is imported.
if __name__ == "__main__":
    run_harmonization(data_dir, no_download=False)
