In [None]:
import os
import glob
import tarfile
import urllib.request
import pandas as pd
import numpy as np
import anndata as ad
from scipy import sparse
from scipy.io import mmread
import h5py
import re
import logging
import time
from pathlib import Path
import gzip
import subprocess
import warnings

# Silence every Python warning for the rest of the session so that the
# notebook output only contains the log lines emitted below.
warnings.filterwarnings(action='ignore')

# Root logger configuration: INFO and above, timestamped, written both to the
# console stream and to a log file in the current working directory.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=[logging.StreamHandler(), logging.FileHandler('process_GSE264587.log')],
)

def get_main_protein_name(name):
    """
    Reduce an antibody feature name to its leading protein/target token.

    Steps:
      1. Drop a leading 'anti-<species>_' marker (e.g. 'anti-human_').
      2. Repeatedly drop leading species prefixes:
         - 'Mouse_' / 'mouse_'
         - 'Rat_' / 'rat_'
         - 'Armenian_Hamster_'
         - 'human_'
      3. Keep only the substring up to the first remaining underscore.

    Parameters:
    -----------
    name : str
        Raw feature name as found in the features file

    Returns:
    --------
    str
        The first underscore-delimited token left after prefix removal

    Example:
      "anti-human_CD274_B7-H1_PD-L1" -> "CD274"
      "Mouse_IgG1_isotype_Ctrl" -> "IgG1"
    """
    result = re.sub(r'^anti-[^_]+_', '', name)

    # Only the species word and its trailing underscore are removed. An
    # additional optional "_<token>" group here would greedily swallow the
    # protein name itself (turning "Mouse_IgG1_isotype_Ctrl" into "isotype").
    species_prefix = r'^(?:[Mm]ouse|[Rr]at|Armenian_Hamster|human)_'

    # Loop because several species prefixes may be stacked.
    while True:
        new_result = re.sub(species_prefix, '', result)
        if new_result == result:
            break
        result = new_result

    # Keep only the substring before the first underscore (no-op without one).
    return result.split("_", 1)[0]

def download_dataset(data_dir):
    """
    Download the GSE264587 raw archive and unpack it, reusing earlier results.

    The archive is stored as ``<data_dir>/GSE264587_RAW.tar`` and unpacked into
    ``<data_dir>/extracted``. The download is skipped when a tar file larger
    than 600,000,000 bytes is already present; extraction is skipped when the
    target directory already contains entries.

    Parameters:
    -----------
    data_dir : str
        Directory to store the downloaded data (created if missing)

    Returns:
    --------
    str
        Path to the directory containing the extracted files

    Raises:
    -------
    FileNotFoundError
        If the archive is missing after the download step, or if the
        extraction directory is empty afterwards.
    ValueError
        If the archive on disk is smaller than 10 MB.
    Exception
        Any other download/extraction error is logged and re-raised.
    """
    try:
        os.makedirs(data_dir, exist_ok=True)

        # Define the URL and target file
        url = "https://www.ncbi.nlm.nih.gov/geo/download/?acc=GSE264587&format=file"
        tar_file = os.path.join(data_dir, "GSE264587_RAW.tar")

        # Approx >600MB if already complete
        if os.path.exists(tar_file) and os.path.getsize(tar_file) > 600000000:
            logging.info(f"File {tar_file} already exists and appears complete. Skipping download.")
        else:
            logging.info(f"Downloading GSE264587_RAW.tar to {tar_file}...")

            # Preferred path: wget (resumable, with retries)
            try:
                cmd = [
                    "wget",
                    "--continue",
                    "--tries=5",
                    "--timeout=60",
                    "--waitretry=60",
                    "--no-verbose",
                    "-O", tar_file,
                    url
                ]
                logging.info(f"Running command: {' '.join(cmd)}")
                subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
                logging.info("Download complete.")

            except (subprocess.CalledProcessError, FileNotFoundError) as e:
                # CalledProcessError: wget ran but exited non-zero.
                # FileNotFoundError: the wget executable is not installed.
                # Both cases fall back to urllib, mirroring the tar fallback below.
                logging.error(f"Error downloading file with wget: {e}")
                logging.info("Trying alternative download method with urllib...")

                try:
                    urllib.request.urlretrieve(url, tar_file)
                    logging.info("Download complete with urllib.")
                except Exception as e2:
                    logging.error(f"Error downloading file with urllib: {e2}")
                    raise
            except Exception as e:
                logging.error(f"Unexpected error during download: {e}")
                raise

        # Verify
        if not os.path.exists(tar_file):
            logging.error(f"Download failed: {tar_file} does not exist")
            raise FileNotFoundError(f"Downloaded file {tar_file} not found")

        size_mb = os.path.getsize(tar_file)/(1024*1024)
        logging.info(f"Downloaded file size: {size_mb:.2f} MB")
        if size_mb < 10:  # < 10 MB means likely incomplete
            logging.error(f"Downloaded file is too small ({size_mb} MB).")
            raise ValueError("Downloaded file is too small. Possibly incomplete.")

        # Extract
        extracted_dir = os.path.join(data_dir, "extracted")
        if not os.path.exists(extracted_dir) or len(os.listdir(extracted_dir)) == 0:
            os.makedirs(extracted_dir, exist_ok=True)
            logging.info(f"Extracting {tar_file} to {extracted_dir}...")
            try:
                # Attempt system tar
                cmd = ["tar", "-xf", tar_file, "-C", extracted_dir]
                logging.info(f"Running command: {' '.join(cmd)}")
                subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
                logging.info("Extraction complete using tar command.")
            except (subprocess.CalledProcessError, FileNotFoundError) as e:
                logging.warning(f"Extraction with tar command failed: {e}.\n Trying Python tarfile...")
                with tarfile.open(tar_file, 'r') as tar:
                    if hasattr(tarfile, 'data_filter'):
                        # Interpreters with extraction filters: refuse members
                        # that would escape extracted_dir (absolute paths,
                        # "..", links pointing outside the destination).
                        tar.extractall(path=extracted_dir, filter='data')
                    else:
                        tar.extractall(path=extracted_dir)
                logging.info("Extraction complete using Python tarfile.")
        else:
            logging.info(f"Files already extracted in {extracted_dir}. Skipping extraction.")

        # Check extraction
        if not os.path.exists(extracted_dir) or len(os.listdir(extracted_dir)) == 0:
            logging.error(f"No files found in {extracted_dir} after extraction.")
            raise FileNotFoundError(f"No files found in {extracted_dir}")

        extracted_files = os.listdir(extracted_dir)
        logging.info(f"Extracted {len(extracted_files)} files: {extracted_files[:5]}...")
        return extracted_dir

    except Exception as e:
        # Single top-level log line for any failure; the exception still propagates.
        logging.error(f"Failed to download or extract dataset: {e}")
        raise

def load_sample_data(extracted_dir, sample_id):
    """
    Load data for a specific sample from the extracted directory

    Files are selected by checking that ``sample_id`` occurs in the file name
    and that the name ends with ``matrix.mtx.gz``, ``features.tsv.gz`` or
    ``barcodes.tsv.gz``. Rows of the matrix are split by the third column of
    the features file into 'Gene Expression' and 'Antibody Capture' features.
    Cell barcodes are prefixed with ``"<sample_id>_"``.

    Parameters:
    -----------
    extracted_dir : str
        Directory containing the extracted files
    sample_id : str
        Sample ID (e.g., 'T086')

    Returns:
    --------
    tuple
        (gene_data, protein_data) - AnnData objects for gene and protein
        expression, each shaped [cells, features]. ``(None, None)`` is
        returned when files are missing, when the matrix dimensions do not
        match the features/barcodes files, or when any error occurs while
        loading (the error is logged with its traceback, not raised).
    """
    try:
        # Gather sample files. os.listdir returns entries in arbitrary order,
        # so sort to make the "[0]" selections below reproducible.
        sample_files = sorted(f for f in os.listdir(extracted_dir) if sample_id in f)
        if not sample_files:
            logging.warning(f"No files found for sample {sample_id}")
            return None, None

        mtx_files = [f for f in sample_files if f.endswith('matrix.mtx.gz')]
        feature_files = [f for f in sample_files if f.endswith('features.tsv.gz')]
        barcode_files = [f for f in sample_files if f.endswith('barcodes.tsv.gz')]

        if not (mtx_files and feature_files and barcode_files):
            logging.warning(f"Missing required matrix/features/barcodes files for {sample_id}")
            return None, None

        # Make an ambiguous match visible instead of silently picking one.
        for kind, candidates in (
            ('matrix', mtx_files),
            ('features', feature_files),
            ('barcodes', barcode_files),
        ):
            if len(candidates) > 1:
                logging.warning(
                    f"Multiple {kind} files match {sample_id}; using {candidates[0]}"
                )

        mtx_path = os.path.join(extracted_dir, mtx_files[0])
        feature_path = os.path.join(extracted_dir, feature_files[0])
        barcode_path = os.path.join(extracted_dir, barcode_files[0])

        logging.info(f"Loading {sample_id} from {mtx_path}")

        # Features
        features_df = pd.read_csv(feature_path, sep='\t', header=None, names=['id','name','feature_type'])

        # Ensure feature names are unique right away: the 2nd, 3rd, ...
        # occurrence of a name gets the suffix "_1", "_2", ...
        features_df['name'] = features_df['name'].astype(str)
        occurrence = features_df.groupby('name').cumcount()
        features_df['name'] = [
            f"{name}_{i}" if i > 0 else name
            for name, i in zip(features_df['name'], occurrence)
        ]

        # Barcodes
        barcodes = pd.read_csv(barcode_path, sep='\t', header=None, names=['barcode'])
        # Make unique cell barcodes by prepending sample ID
        barcodes['barcode'] = f"{sample_id}_" + barcodes['barcode'].astype(str)

        # Subset features
        gene_mask = (features_df['feature_type'] == 'Gene Expression')
        protein_mask = (features_df['feature_type'] == 'Antibody Capture')

        gene_features = features_df[gene_mask]
        protein_features = features_df[protein_mask]

        logging.info(f"{sample_id}: {len(gene_features)} gene features; {len(protein_features)} protein features.")

        # Matrix
        matrix = mmread(mtx_path).tocsc()  # shape = [features, cells]

        # Safety checks
        if matrix.shape[0] != len(features_df):
            logging.error(f"Matrix rows ({matrix.shape[0]}) != features ({len(features_df)})")
            return None, None
        if matrix.shape[1] != len(barcodes):
            logging.error(f"Matrix cols ({matrix.shape[1]}) != barcodes ({len(barcodes)})")
            return None, None

        # Split gene vs protein. features_df keeps its default positional
        # index, so the index labels of each subset are matrix row numbers.
        gene_idx = gene_features.index.tolist()
        prot_idx = protein_features.index.tolist()

        gene_matrix = matrix[gene_idx, :].T  # shape = [cells, genes]
        prot_matrix = matrix[prot_idx, :].T  # shape = [cells, proteins]

        logging.info(f"{sample_id} gene matrix: {gene_matrix.shape}; protein matrix: {prot_matrix.shape}")

        # Construct AnnData
        gene_data = ad.AnnData(
            X=gene_matrix,
            obs=pd.DataFrame(index=barcodes['barcode']),
            var=pd.DataFrame(
                data={
                    'feature_id': gene_features['id'].values,
                    'feature_type': gene_features['feature_type'].values
                },
                index=gene_features['name'].values
            )
        )

        protein_data = ad.AnnData(
            X=prot_matrix,
            obs=pd.DataFrame(index=barcodes['barcode']),
            var=pd.DataFrame(
                data={
                    'feature_id': protein_features['id'].values,
                    'feature_type': protein_features['feature_type'].values
                },
                index=protein_features['name'].values
            )
        )

        # Clean protein var_names (remove prefix, etc.)
        if protein_data.n_vars > 0:
            original_names = protein_data.var_names.tolist()
            stripped_names = [get_main_protein_name(n) for n in original_names]
            protein_data.var_names = stripped_names

        # Make obs/var names unique per sample. Stripping protein names can
        # map different antibodies to the same token, so this is required.
        gene_data.var_names_make_unique()
        gene_data.obs_names_make_unique()   # Should be unique already but just in case
        protein_data.var_names_make_unique()
        protein_data.obs_names_make_unique()

        logging.info(f"Sample {sample_id} loaded: gene_data {gene_data.shape}, protein_data {protein_data.shape}")
        return gene_data, protein_data

    except Exception as e:
        # Best-effort loader: one bad sample must not abort the whole run.
        # logging.exception keeps the traceback so the failure can be located.
        logging.exception(f"Error loading data for {sample_id}: {e}")
        return None, None

def add_metadata(adata, sample_id, sample_metadata):
    """
    Attach the standard per-cell annotation columns to ``adata.obs``.

    Columns are written in this order: ``sample_id``, the fixed dataset-wide
    fields (organism, cell_type, crispr_type, cancer_type, condition,
    perturbation_name), then ``age`` and ``sex`` when those keys are present
    in ``sample_metadata``. The object is modified in place and returned;
    ``None`` is passed through unchanged.
    """
    if adata is None:
        return None

    # Per-sample identifier
    adata.obs['sample_id'] = sample_id

    # Fixed values shared by every cell
    constant_columns = (
        ('organism', 'Homo sapiens'),
        ('cell_type', 'Thymus cells'),
        ('crispr_type', 'None'),
        ('cancer_type', 'Non-Cancer'),
        ('condition', 'Normal'),
        ('perturbation_name', 'None'),
    )
    for column, value in constant_columns:
        adata.obs[column] = value

    # Optional values copied from the sample-level metadata
    for key in ('age', 'sex'):
        if key in sample_metadata:
            adata.obs[key] = sample_metadata[key]

    return adata

def get_sample_metadata(sample_id):
    """
    Look up the donor sex and age recorded for a sample ID.

    Returns a dict with the keys 'sex' and 'age' for known samples
    (T086, T087, T096-T100) and an empty dict for any other ID.
    """
    # (sample ID, sex, age)
    records = (
        ('T086', 'Male', '33 months'),
        ('T087', 'Female', '4 months'),
        ('T096', 'Male', '5 months'),
        ('T097', 'Female', '4 months'),
        ('T098', 'Male', '4 months'),
        ('T099', 'Female', '4 months'),
        ('T100', 'Male', '4 months'),
    )
    lookup = {sid: {'sex': sex, 'age': age} for sid, sex, age in records}
    return lookup.get(sample_id, {})

def process_dataset(data_dir):
    """
    Download, extract, and load the GSE264587 data
    returning (combined_gene_data, combined_protein_data).

    Sample IDs are discovered from the extracted file names, each sample is
    loaded and annotated, and the per-sample objects are concatenated. When
    both modalities exist, both are restricted to the cells they share (in
    gene-data order). Each non-empty result is written to ``data_dir`` as a
    gzip-compressed ``.h5ad`` file.

    Parameters:
    -----------
    data_dir : str
        Working directory for the download, the extracted files and the
        output ``.h5ad`` files

    Returns:
    --------
    tuple
        (combined_gene_data, combined_protein_data); either element is None
        when no sample produced data of that kind.
    """
    start_time = time.time()
    logging.info(f"Processing GSE264587 in {data_dir}...")

    # 1. Download/extract
    extracted_dir = download_dataset(data_dir)

    # 2. Find sample IDs in extracted_dir
    files = os.listdir(extracted_dir)
    sample_pattern = r'GSM\d+_([A-Za-z0-9]+)_'
    sample_ids = set()
    for f in files:
        m = re.search(sample_pattern, f)
        if m:
            sample_ids.add(m.group(1))

    if not sample_ids:
        # fallback pattern
        alt_pattern = r'T\d+'
        for f in files:
            m = re.search(alt_pattern, f)
            if m:
                sample_ids.add(m.group(0))

    sample_ids = sorted(sample_ids)
    logging.info(f"Found {len(sample_ids)} samples: {sample_ids}")

    # 3. Load data for each sample. The sample ID of every loaded object is
    #    recorded alongside it so it can be handed to ad.concat as its key.
    all_gene_data = []
    gene_sample_ids = []
    all_protein_data = []
    protein_sample_ids = []

    for sid in sample_ids:
        g_dat, p_dat = load_sample_data(extracted_dir, sid)
        sample_meta = get_sample_metadata(sid)
        if g_dat is not None:
            add_metadata(g_dat, sid, sample_meta)
            all_gene_data.append(g_dat)
            gene_sample_ids.append(sid)
        if p_dat is not None:
            add_metadata(p_dat, sid, sample_meta)
            all_protein_data.append(p_dat)
            protein_sample_ids.append(sid)

    # Before concatenating, ensure each sample's names are unique.
    # (Already done in load_sample_data; repeated here as a cheap safeguard.)
    # The loop variables deliberately avoid the name "pd" (the pandas alias).
    for gene_adata in all_gene_data:
        gene_adata.var_names_make_unique()
        gene_adata.obs_names_make_unique()
    for protein_adata in all_protein_data:
        protein_adata.var_names_make_unique()
        protein_adata.obs_names_make_unique()

    # 4. Concatenate gene data
    if all_gene_data:
        combined_gene_data = ad.concat(
            all_gene_data,
            join='outer',  # union of gene sets
            merge='same',  # keep var annotations only if identical in all samples
            label='sample_id',
            # Without keys, concat fills the label column with '0', '1', ...
            # and would overwrite the real sample IDs set by add_metadata.
            keys=gene_sample_ids
        )
        logging.info(f"Combined gene data: {combined_gene_data.shape}")

        # check duplicates after concat
        if combined_gene_data.var_names.duplicated().any():
            logging.warning("Duplicate gene var_names found – making them unique.")
            combined_gene_data.var_names_make_unique()
    else:
        combined_gene_data = None
        logging.warning("No gene data found.")

    # 5. Concatenate protein data
    if all_protein_data:
        combined_protein_data = ad.concat(
            all_protein_data,
            join='outer',
            merge='same',
            label='sample_id',
            keys=protein_sample_ids  # see note on keys above
        )
        logging.info(f"Combined protein data: {combined_protein_data.shape}")

        # check duplicates
        if combined_protein_data.var_names.duplicated().any():
            logging.warning("Duplicate protein var_names found – making them unique.")
            combined_protein_data.var_names_make_unique()
    else:
        combined_protein_data = None
        logging.warning("No protein data found.")

    # 6. Filter to cells present in both gene and protein data (if both exist)
    if combined_gene_data is not None and combined_protein_data is not None:
        # Order-preserving intersection: a plain set() would make the cell
        # order of the saved files differ from run to run.
        in_both = combined_gene_data.obs_names.isin(combined_protein_data.obs_names)
        common_barcodes = combined_gene_data.obs_names[in_both].tolist()
        logging.info(f"{len(common_barcodes)} cells present in BOTH gene/protein data.")
        combined_gene_data = combined_gene_data[common_barcodes].copy()
        combined_protein_data = combined_protein_data[common_barcodes].copy()

    # 7. Save results
    if combined_gene_data is not None:
        gene_out = os.path.join(data_dir, "GSE264587_gene_expression.h5ad")
        combined_gene_data.write_h5ad(gene_out, compression="gzip")
        logging.info(f"Saved gene data: {gene_out}")

    if combined_protein_data is not None:
        protein_out = os.path.join(data_dir, "GSE264587_protein_expression.h5ad")
        combined_protein_data.write_h5ad(protein_out, compression="gzip")
        logging.info(f"Saved protein data: {protein_out}")

    elapsed = time.time() - start_time
    logging.info(f"Done. Elapsed time = {elapsed:.1f} s")
    return combined_gene_data, combined_protein_data

# ---------------------------
# Jupyter main cell
# ---------------------------
# Runs the whole pipeline when the cell is executed: process_dataset downloads
# and extracts the archive under data_dir, loads every sample, and writes the
# combined .h5ad files into data_dir.
data_dir = "GSE264587_"  # Adjust path if needed
# Either result is None when no data of that kind could be loaded.
gene_data, protein_data = process_dataset(data_dir)
