# In [None]:  (Jupyter cell prompt left over from the notebook export)
import os
import sys
import re
import gzip
import shutil
import urllib.request
import tempfile
import subprocess
from collections import defaultdict

import numpy as np
import pandas as pd
import scanpy as sc
import anndata as ad
from scipy import sparse
from scipy.io import mmread



# File names that download_files() fetches for GEO accession GSE264681.
# Every sample contributes two entries: a "*_filtered_feature_bc_matrix.tar.gz"
# archive and a "*_feature_reference_*.csv.gz" table. In_Vivo_A375 is the only
# sample whose reference file name carries the extra "CITEseq_" token.
FILES = [
    "GSE264681_Arrayed_NALM6_D1_filtered_feature_bc_matrix.tar.gz",
    "GSE264681_Arrayed_NALM6_D1_feature_reference_CARBC.csv.gz",
    "GSE264681_Arrayed_NALM6_D2_filtered_feature_bc_matrix.tar.gz",
    "GSE264681_Arrayed_NALM6_D2_feature_reference_CARBC.csv.gz",
    "GSE264681_CD4_Spike-In_filtered_feature_bc_matrix.tar.gz",
    "GSE264681_CD4_Spike-In_feature_reference_CARBC.csv.gz",
    "GSE264681_In_Vivo_A375_filtered_feature_bc_matrix.tar.gz",
    "GSE264681_In_Vivo_A375_feature_reference_CITEseq_CARBC.csv.gz",
    "GSE264681_Pooled_A375_filtered_feature_bc_matrix.tar.gz",
    "GSE264681_Pooled_A375_feature_reference_CARBC.csv.gz",
    "GSE264681_Pooled_NALM6_filtered_feature_bc_matrix.tar.gz",
    "GSE264681_Pooled_NALM6_feature_reference_CARBC.csv.gz",
    "GSE264681_Resting_filtered_feature_bc_matrix.tar.gz",
    "GSE264681_Resting_feature_reference_CARBC.csv.gz"
]

# Sample names, kept for reference only: no function visible in this file
# reads SAMPLES (harmonize_data() carries its own sample -> reference-file
# mapping). Each name is the middle part of the matching entries in FILES.
SAMPLES = [
    "Arrayed_NALM6_D1",
    "Arrayed_NALM6_D2",
    "CD4_Spike-In",
    "In_Vivo_A375",
    "Pooled_A375",
    "Pooled_NALM6",
    "Resting"
]

# CAR architecture dictionary: CAR identifier -> hyphen-separated component
# string. process_sample() copies these strings into obs['perturbation_name']
# and derives the has_* flag columns from substring tests on them.
# NOTE: CAR34 spells the component "4-1BB" while every other entry uses
# "41BB"; the has_41BB flag in process_sample() checks both spellings.
CAR_ARCHITECTURES = {
    'CAR1': 'CD28-CD3z',
    'CAR2': 'CD28-CD28-CD3z',
    'CAR3': '41BB-CD3z',
    'CAR4': '41BB-41BB-CD3z',
    'CAR5': 'CD28-41BB-CD3z',
    'CAR6': '41BB-CD28-CD3z',
    'CAR7': 'ICOS-CD3z',
    'CAR8': 'ICOS-ICOS-CD3z',
    'CAR9': 'CD28-ICOS-CD3z',
    'CAR10': 'ICOS-CD28-CD3z',
    'CAR11': '41BB-ICOS-CD3z',
    'CAR12': 'ICOS-41BB-CD3z',
    'CAR13': 'OX40-CD3z',
    'CAR14': 'OX40-OX40-CD3z',
    'CAR15': 'CD28-OX40-CD3z',
    'CAR16': 'OX40-CD28-CD3z',
    'CAR17': '41BB-OX40-CD3z',
    'CAR18': 'OX40-41BB-CD3z',
    'CAR19': 'ICOS-OX40-CD3z',
    'CAR20': 'OX40-ICOS-CD3z',
    'CAR21': 'CD27-CD3z',
    'CAR22': 'CD27-CD27-CD3z',
    'CAR23': 'CD28-CD27-CD3z',
    'CAR24': 'CD27-CD28-CD3z',
    'CAR25': '41BB-CD27-CD3z',
    'CAR26': 'CD27-41BB-CD3z',
    'CAR27': 'ICOS-CD27-CD3z',
    'CAR28': 'CD27-ICOS-CD3z',
    'CAR29': 'OX40-CD27-CD3z',
    'CAR30': 'CD27-OX40-CD3z',
    'CAR31': 'CD28-41BB-ICOS-OX40-CD27-CD3z',
    'CAR32': 'CD28-41BB-ICOS-OX40-CD27-CD3z-CD3z',
    'CAR33': 'CD28-41BB-ICOS-OX40-CD27-CD3z-CD3z-CD3z',
    'CAR34': 'CD28-ICOS-4-1BB-OX40-CD27-CD3z-CD3z',
    'CAR35': 'CD3z',
}




def _has_content(path):
    """Return True when ``path`` is an existing regular file with at least one byte."""
    return os.path.isfile(path) and os.path.getsize(path) > 0


def _run_downloader(argv):
    """Run an external download tool and report whether it exited with status 0."""
    try:
        # An argument list (no shell) keeps paths with spaces or quotes intact.
        return subprocess.run(argv, check=False).returncode == 0
    except OSError:
        # The tool is not installed or cannot be executed.
        return False


def _fetch_with_fallbacks(url, dest_path, label):
    """Download ``url`` to ``dest_path`` via curl, then wget, then urllib.

    Returns the name of the method that produced a non-empty file
    ("curl", "wget" or "urllib"), or None when urllib finished but the file
    is empty. Errors raised by the urllib attempt propagate to the caller.
    """
    # --fail makes curl exit non-zero on HTTP errors instead of saving the
    # server's error page as if it were the requested file.
    if _run_downloader(["curl", "--fail", "-L", "-s", "-o", dest_path, url]) and _has_content(dest_path):
        return "curl"

    print(f"Curl failed, trying wget for {label}...")
    if _run_downloader(["wget", "-q", "-O", dest_path, url]) and _has_content(dest_path):
        return "wget"

    print(f"Wget failed, trying urllib for {label}...")
    # Add headers to mimic a browser request
    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
        'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8',
        'Accept-Language': 'en-US,en;q=0.5',
        'Connection': 'keep-alive',
        'Upgrade-Insecure-Requests': '1',
    }
    req = urllib.request.Request(url, headers=headers)
    with urllib.request.urlopen(req, timeout=60) as response, open(dest_path, 'wb') as out_file:
        shutil.copyfileobj(response, out_file)
    return "urllib" if _has_content(dest_path) else None


def download_files(data_dir, files=None):
    """Download dataset files if they don't exist.

    Each missing file is fetched into a temporary "<name>.part" file and only
    moved to its final name once it is non-empty, so an interrupted or failed
    download never leaves a truncated file that later runs would mistake for
    a finished one. Failures are reported on stdout; nothing is raised.

    Args:
        data_dir: Directory that receives the files (created if needed).
        files: Optional iterable of file names to fetch; defaults to FILES.

    Returns:
        None.
    """
    os.makedirs(data_dir, exist_ok=True)
    wanted = FILES if files is None else files

    for filename in wanted:
        file_path = os.path.join(data_dir, filename)
        # A zero-byte file is the residue of an earlier failed attempt, so it
        # is downloaded again rather than skipped.
        if _has_content(file_path):
            print(f"File {filename} already exists, skipping download")
            continue

        # Use the GEO download URL directly - this is more reliable than FTP
        url = f"https://www.ncbi.nlm.nih.gov/geo/download/?acc=GSE264681&format=file&file={filename}"
        print(f"Downloading {filename}...")

        part_path = file_path + ".part"
        try:
            method = _fetch_with_fallbacks(url, part_path, filename)
            if method is None:
                print(f"Downloaded file {filename} is empty")
            else:
                os.replace(part_path, file_path)
                suffix = "" if method == "curl" else f" using {method}"
                print(f"Downloaded {filename}{suffix}")
        except Exception as e:
            # Best-effort boundary: one failed file must not stop the others.
            print(f"All download methods failed for {filename}: {e}")
            print(f"Please download {filename} manually from GEO and place it in {data_dir}")
        finally:
            # Never leave a partial download behind.
            if os.path.lexists(part_path):
                os.remove(part_path)

def extract_tar_gz(file_path, extract_dir):
    """Extract tar.gz file to specified directory.

    After a successful call ``<extract_dir>/filtered_feature_bc_matrix``
    exists. If the archive unpacks into a differently named folder, the first
    directory (in os.walk order) that holds ``matrix.mtx[.gz]`` is exposed
    under that name through a symlink.

    Args:
        file_path: Path of the ``.tar.gz`` archive.
        extract_dir: Directory to unpack into (created if needed).

    Raises:
        FileNotFoundError: The archive is missing, or no matrix file was
            found after extraction.
        ValueError: The archive is empty (zero bytes).
        RuntimeError: Both the Python extractor and the ``tar`` command failed.
    """
    target_dir = os.path.join(extract_dir, "filtered_feature_bc_matrix")
    if os.path.exists(target_dir):
        print(f"Directory {extract_dir}/filtered_feature_bc_matrix already exists, skipping extraction")
        return

    # Check if the file exists and has content
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File {file_path} does not exist")
    if os.path.getsize(file_path) == 0:
        raise ValueError(f"File {file_path} is empty")

    os.makedirs(extract_dir, exist_ok=True)
    print(f"Extracting {file_path} to {extract_dir}...")

    try:
        # First try Python's built-in extraction
        shutil.unpack_archive(file_path, extract_dir)
    except Exception as e:
        # Deliberately broad: any failure here just triggers the tar fallback.
        print(f"Error with shutil.unpack_archive: {e}")
        print("Trying tar command...")
        try:
            # Argument list instead of a shell string, so paths containing
            # spaces or shell metacharacters are passed through untouched.
            status = subprocess.run(
                ["tar", "-xzf", file_path, "-C", extract_dir], check=False
            ).returncode
        except OSError:
            status = -1  # tar is not available
        if status != 0:
            raise RuntimeError(
                f"Failed to extract {file_path} using both Python and tar command"
            ) from e

    # Verify extraction was successful
    if not os.path.exists(target_dir):
        # Sometimes the directory structure might be different:
        # look for any directory that contains the matrix file.
        matrix_dir = next(
            (
                root
                for root, _dirs, files in os.walk(extract_dir)
                if "matrix.mtx.gz" in files or "matrix.mtx" in files
            ),
            None,
        )
        if matrix_dir is None:
            raise FileNotFoundError(f"Could not find matrix files in extracted directory {extract_dir}")

        # A dangling link left by an earlier run fails os.path.exists() but
        # would make os.symlink() raise FileExistsError.
        if os.path.lexists(target_dir):
            os.remove(target_dir)
        # A symlink's source is resolved relative to the link's own directory,
        # so it must be expressed relative to extract_dir, not to the cwd.
        os.symlink(os.path.relpath(matrix_dir, extract_dir), target_dir)
        print(f"Created symlink from {matrix_dir} to {target_dir}")

    print(f"Extracted {file_path}")

def _resolve_10x_file(matrix_dir, gz_name):
    """Return the path of ``gz_name`` in ``matrix_dir``, or of its uncompressed sibling, or None."""
    compressed = os.path.join(matrix_dir, gz_name)
    if os.path.exists(compressed):
        return compressed
    plain = os.path.join(matrix_dir, gz_name[:-3])  # Remove .gz extension
    if os.path.exists(plain):
        return plain
    return None


def _read_tsv_rows(path):
    """Read a (possibly gzip-compressed) TSV file into a list of field lists."""
    opener = gzip.open if path.endswith('.gz') else open
    with opener(path, 'rt') as handle:
        return [line.rstrip('\r\n').split('\t') for line in handle]


def _attach_feature_types(adata, features_file):
    """Fill ``adata.var['feature_type']`` from the third column of ``features_file``.

    Assignment is positional and relies on ``adata`` holding every feature in
    file order (which is why the caller reads with ``gex_only=False``). A
    lookup by name would miss symbols that were renamed to make them unique.
    On any problem all features default to 'Gene Expression'.
    """
    try:
        feature_types = [row[2] for row in _read_tsv_rows(features_file)]
        if len(feature_types) != adata.n_vars:
            raise ValueError(
                f"features file lists {len(feature_types)} rows but the matrix has {adata.n_vars} features"
            )
        adata.var['feature_type'] = feature_types
    except (OSError, EOFError, IndexError, ValueError) as e:
        print(f"Warning: Error reading feature types: {e}")
        print("Setting all features to 'Gene Expression' by default")
        adata.var['feature_type'] = 'Gene Expression'


def read_10x_mtx(matrix_dir):
    """Read 10x matrix files and return AnnData object.

    scanpy is tried first; if it raises, the three files are parsed manually.
    Either way the result has a ``var['feature_type']`` column.

    Args:
        matrix_dir: Directory holding ``matrix.mtx``, ``features.tsv`` and
            ``barcodes.tsv`` (each optionally with a ``.gz`` suffix).

    Returns:
        AnnData with cells as observations and all features as variables.

    Raises:
        FileNotFoundError: The directory or one of the three files is missing.
        Exception: Whatever the manual fallback raised, if it fails as well.
    """
    # Check if directory exists
    if not os.path.exists(matrix_dir):
        raise FileNotFoundError(f"Matrix directory {matrix_dir} does not exist")

    # Locate each required file, accepting the uncompressed variant.
    resolved = {}
    missing_files = []
    for gz_name in ('matrix.mtx.gz', 'features.tsv.gz', 'barcodes.tsv.gz'):
        found = _resolve_10x_file(matrix_dir, gz_name)
        if found is None:
            missing_files.append(gz_name)
            continue
        if not found.endswith('.gz'):
            print(f"Found uncompressed version of {gz_name}: {gz_name[:-3]}")
        resolved[gz_name] = found

    if missing_files:
        raise FileNotFoundError(f"Missing required files in {matrix_dir}: {', '.join(missing_files)}")

    matrix_file = resolved['matrix.mtx.gz']
    features_file = resolved['features.tsv.gz']
    barcodes_file = resolved['barcodes.tsv.gz']

    try:
        # Use scanpy's read_10x_mtx function
        print(f"Reading 10x matrix from {matrix_dir}...")
        # gex_only defaults to True, which would drop every non
        # 'Gene Expression' feature before the caller can split modalities.
        adata = sc.read_10x_mtx(
            matrix_dir, var_names='gene_symbols', cache=True, gex_only=False
        )
    except Exception as e:
        # Deliberately broad: any scanpy failure triggers the manual reader.
        print(f"Error reading with scanpy: {e}")
    else:
        _attach_feature_types(adata, features_file)
        return adata

    # Manual reading as fallback
    try:
        print("Trying manual reading...")

        # Read the matrix; transposed so that rows are cells.
        print(f"Reading matrix from {matrix_file}...")
        X = mmread(matrix_file).T.tocsr()

        # Read features (genes/proteins)
        print(f"Reading features from {features_file}...")
        features = _read_tsv_rows(features_file)
        feature_ids = [row[0] for row in features]
        feature_names = [row[1] for row in features]

        # Handle case where feature_types might be missing
        if all(len(row) >= 3 for row in features):
            feature_types = [row[2] for row in features]
        else:
            print("Warning: Feature types not found in features file, assuming all are 'Gene Expression'")
            feature_types = ['Gene Expression'] * len(feature_names)

        # Read cell barcodes
        print(f"Reading barcodes from {barcodes_file}...")
        barcodes = [row[0].strip() for row in _read_tsv_rows(barcodes_file)]

        # Create feature metadata
        var = pd.DataFrame({
            'feature_id': feature_ids,
            'feature_type': feature_types
        }, index=feature_names)

        # Create AnnData object
        print(f"Creating AnnData object with {X.shape[0]} cells and {X.shape[1]} features...")
        return ad.AnnData(X, obs=pd.DataFrame(index=barcodes), var=var)
    except Exception as e2:
        print(f"Error with manual reading: {e2}")
        raise

def read_feature_reference(file_path):
    """Load a feature reference CSV (plain or gzip-compressed) as a DataFrame.

    Any failure is reported on stdout and signalled by returning None instead
    of raising.
    """
    try:
        if not file_path.endswith('.gz'):
            return pd.read_csv(file_path)
        # Compressed input: decompress as text and let pandas parse the stream.
        with gzip.open(file_path, 'rt') as handle:
            table = pd.read_csv(handle)
        return table
    except Exception as err:
        print(f"Error reading feature reference file {file_path}: {err}")
        return None

def _empty_modality(adata):
    """Return a zero-feature AnnData that keeps the cells and obs of ``adata``."""
    return ad.AnnData(
        X=sparse.csr_matrix((adata.n_obs, 0)),
        obs=adata.obs.copy()
    )


def _add_sample_metadata(adata_obj, sample_name):
    """Add the per-sample obs columns that are derived from the sample name."""
    # Add organism information
    adata_obj.obs['organism'] = 'Homo sapiens'

    # Add cell type information based on sample name
    adata_obj.obs['cell_type'] = 'CD4+ T Cells' if 'CD4' in sample_name else 'T Cells'

    # Add CRISPR type information
    adata_obj.obs['crispr_type'] = 'None'  # No CRISPR in this dataset

    # Add cancer type information
    if 'NALM6' in sample_name:
        adata_obj.obs['cancer_type'] = 'B-cell acute lymphoblastic leukemia'
    elif 'A375' in sample_name:
        adata_obj.obs['cancer_type'] = 'Melanoma'
    else:
        adata_obj.obs['cancer_type'] = 'Non-Cancer'

    # Add condition information; the first matching token wins.
    condition_by_token = (
        ('Resting', 'Resting'),
        ('In_Vivo', 'In vivo'),
        ('Arrayed', 'Arrayed co-culture'),
        ('Pooled', 'Pooled co-culture'),
        ('CD4_Spike-In', 'CD4 spike-in'),
    )
    adata_obj.obs['condition'] = next(
        (label for token, label in condition_by_token if token in sample_name),
        'Unknown',
    )

    # Add perturbation_name (default to Unknown)
    adata_obj.obs['perturbation_name'] = 'Unknown'


def _add_car_annotations(adata_obj, cell_to_car):
    """Add CAR_barcode, perturbation_name and has_* columns from a cell -> CAR mapping."""
    adata_obj.obs['CAR_barcode'] = [
        cell_to_car.get(cell_barcode, 'Unknown') for cell_barcode in adata_obj.obs_names
    ]

    # Pull out the whole "CAR<number>" token so that e.g. "CAR10" is never
    # mistaken for "CAR1" (a plain substring test matches both).
    car_ids = adata_obj.obs['CAR_barcode'].astype(str).str.extract(r'(CAR\d+)', expand=False)
    adata_obj.obs['perturbation_name'] = [
        CAR_ARCHITECTURES.get(car_id, 'Unknown') for car_id in car_ids
    ]

    # Add CAR component flags, stored as the strings 'True' / 'False'.
    component_spellings = (
        ('CD28', ('CD28',)),
        ('41BB', ('41BB', '4-1BB')),
        ('ICOS', ('ICOS',)),
        ('OX40', ('OX40',)),
        ('CD27', ('CD27',)),
        ('CD3z', ('CD3z',)),
    )
    architectures = list(adata_obj.obs['perturbation_name'])
    for component, spellings in component_spellings:
        adata_obj.obs[f'has_{component}'] = [
            str(any(spelling in arch for spelling in spellings)) for arch in architectures
        ]


def process_sample(sample_name, data_dir, feature_ref_file):
    """Process a single sample and return gene and protein AnnData objects.

    Args:
        sample_name: Sample identifier used in the GEO file names.
        data_dir: Directory that holds the downloaded files.
        feature_ref_file: File name of the sample's feature reference CSV.

    Returns:
        ``(gene_expr, protein_expr)``. Both share the sample's cells; a
        modality without features is a zero-column AnnData. ``(None, None)``
        is returned when the matrix archive is missing or cannot be
        extracted or read. Features whose type is neither 'Gene Expression'
        nor 'Antibody Capture' end up in neither object.
    """
    # Check if files exist
    matrix_file = f"GSE264681_{sample_name}_filtered_feature_bc_matrix.tar.gz"
    matrix_file_path = os.path.join(data_dir, matrix_file)
    feature_ref_path = os.path.join(data_dir, feature_ref_file)

    if not os.path.exists(matrix_file_path):
        print(f"Warning: Matrix file {matrix_file_path} not found, skipping sample {sample_name}")
        return None, None

    # Extract the matrix files
    extract_dir = os.path.join(data_dir, sample_name)
    try:
        extract_tar_gz(matrix_file_path, extract_dir)
    except Exception as e:
        print(f"Error extracting {matrix_file_path}: {e}")
        return None, None

    # Read the matrix data
    matrix_dir = os.path.join(extract_dir, "filtered_feature_bc_matrix")
    if not os.path.exists(matrix_dir):
        print(f"Warning: Matrix directory {matrix_dir} not found, skipping sample {sample_name}")
        return None, None

    try:
        adata = read_10x_mtx(matrix_dir)
    except Exception as e:
        print(f"Error reading matrix data for {sample_name}: {e}")
        return None, None

    # Read the feature reference file if it exists
    feature_ref = None
    if os.path.exists(feature_ref_path):
        try:
            feature_ref = read_feature_reference(feature_ref_path)
        except Exception as e:
            print(f"Error reading feature reference file {feature_ref_path}: {e}")
    else:
        print(f"Warning: Feature reference file {feature_ref_path} not found")

    # Add sample information to obs
    adata.obs['sample'] = sample_name

    # Check if feature_type column exists
    if 'feature_type' not in adata.var.columns:
        print(f"Warning: 'feature_type' column not found in {sample_name} data. Adding it...")
        adata.var['feature_type'] = 'Gene Expression'  # Default to gene expression

        # For In_Vivo_A375, we know it has CITE-seq data; attempt a naive classification
        if sample_name == "In_Vivo_A375":
            protein_pattern = re.compile(r'^(CD|HLA|IgG|Ig[ADGME]|IL|TNF|IFN|CCR|CXCR)')
            guessed = [bool(protein_pattern.match(name)) for name in adata.var_names]
            if any(guessed):
                print(f"Identified {sum(guessed)} potential protein features based on naming patterns")
                adata.var.loc[guessed, 'feature_type'] = 'Antibody Capture'

    # Split gene expression and protein expression
    gene_mask = (adata.var['feature_type'] == 'Gene Expression').to_numpy()
    protein_mask = (adata.var['feature_type'] == 'Antibody Capture').to_numpy()
    n_gene = int(gene_mask.sum())
    n_protein = int(protein_mask.sum())

    print(f"Gene features: {n_gene}, Protein features: {n_protein}")

    gene_expr = adata[:, gene_mask].copy() if n_gene > 0 else _empty_modality(adata)
    protein_expr = adata[:, protein_mask].copy() if n_protein > 0 else _empty_modality(adata)

    # Only modalities that actually contain cells are annotated.
    populated = [obj for obj in (gene_expr, protein_expr) if obj.n_obs > 0]

    # Add metadata
    for adata_obj in populated:
        _add_sample_metadata(adata_obj, sample_name)

    # Add CAR barcode information if available.
    # NOTE(review): this only fires when the reference table has
    # 'cell_barcode' and 'CAR_barcode' columns - confirm the GEO files do.
    if (
        feature_ref is not None
        and 'cell_barcode' in feature_ref.columns
        and 'CAR_barcode' in feature_ref.columns
    ):
        # Create a dictionary mapping cell barcodes to CAR barcodes
        cell_to_car = dict(zip(feature_ref['cell_barcode'], feature_ref['CAR_barcode']))
        for adata_obj in populated:
            _add_car_annotations(adata_obj, cell_to_car)

    return gene_expr, protein_expr

def _combine_modality(adatas, label, unit):
    """Concatenate per-sample AnnData objects of one modality.

    Returns None when ``adatas`` is empty or when concatenation fails; the
    failure is printed rather than raised.
    """
    if not adatas:
        print(f"No {label} data to combine")
        return None

    combined = None
    try:
        # Check for duplicate var_names
        for adata in adatas:
            adata.var_names_make_unique()

        combined = ad.concat(adatas, join='outer', merge='same')
        print(f"Combined {label} data: {combined.n_obs} cells, {combined.n_vars} {unit}")

        # Ensure var_names are unique
        combined.var_names_make_unique()
    except Exception as e:
        print(f"Error combining {label} data: {e}")
    return combined


def harmonize_data(data_dir):
    """Process all samples and harmonize the data.

    Args:
        data_dir: Directory that holds the downloaded GEO files.

    Returns:
        ``(combined_gene, combined_protein)``; either may be None when no
        sample produced that modality or concatenation failed. When both are
        present they are restricted to the cells found in both and share the
        same cell order (that of the combined gene data).
    """
    # Define sample to feature reference file mapping
    samples = {
        "Arrayed_NALM6_D1": "GSE264681_Arrayed_NALM6_D1_feature_reference_CARBC.csv.gz",
        "Arrayed_NALM6_D2": "GSE264681_Arrayed_NALM6_D2_feature_reference_CARBC.csv.gz",
        "CD4_Spike-In": "GSE264681_CD4_Spike-In_feature_reference_CARBC.csv.gz",
        "In_Vivo_A375": "GSE264681_In_Vivo_A375_feature_reference_CITEseq_CARBC.csv.gz",
        "Pooled_A375": "GSE264681_Pooled_A375_feature_reference_CARBC.csv.gz",
        "Pooled_NALM6": "GSE264681_Pooled_NALM6_feature_reference_CARBC.csv.gz",
        "Resting": "GSE264681_Resting_feature_reference_CARBC.csv.gz"
    }

    # Process each sample
    gene_adatas = []
    protein_adatas = []

    for sample_name, feature_ref_file in samples.items():
        print(f"Processing sample: {sample_name}")
        gene_adata, protein_adata = process_sample(sample_name, data_dir, feature_ref_file)

        for adata, bucket, label, unit in (
            (gene_adata, gene_adatas, "gene expression", "genes"),
            (protein_adata, protein_adatas, "protein expression", "proteins"),
        ):
            if adata is None or adata.n_obs == 0:
                print(f"Skipping {label} data for {sample_name}: no data or errors occurred")
                continue
            print(f"Adding {label} data for {sample_name}: {adata.n_obs} cells, {adata.n_vars} {unit}")
            # Make cell barcodes unique by adding sample name
            adata.obs_names = [f"{sample_name}_{bc}" for bc in adata.obs_names]
            bucket.append(adata)

    combined_gene = _combine_modality(gene_adatas, "gene expression", "genes")
    combined_protein = _combine_modality(protein_adatas, "protein expression", "proteins")

    # Keep only paired cells (cells that exist in both gene and protein datasets)
    if combined_gene is not None and combined_protein is not None:
        in_both = combined_gene.obs_names.isin(combined_protein.obs_names)
        if in_both.any():
            # Keep the gene data's cell order instead of iterating a set, so
            # that the row order of the output is the same on every run.
            common_cells = combined_gene.obs_names[in_both].tolist()
            print(f"Found {len(common_cells)} cells in both gene and protein expression datasets")
            combined_gene = combined_gene[common_cells].copy()
            combined_protein = combined_protein[common_cells].copy()
        else:
            print("Warning: No common cells found between gene and protein expression datasets")

    return combined_gene, combined_protein

def process_and_save_all_data(data_dir):
    """Download, harmonize and write the dataset under ``data_dir``.

    Whichever modalities were produced are written as
    ``gene_expression.h5ad`` and ``protein_expression.h5ad`` inside
    ``data_dir``. Returns None; when nothing could be downloaded or
    processed, a hint is printed and nothing is written.
    """
    # Fetch whatever is not on disk yet.
    download_files(data_dir)

    # Count the files that are present after the download step.
    n_available = len([name for name in FILES if os.path.exists(os.path.join(data_dir, name))])
    if n_available == 0:
        print("\nNo files were successfully downloaded. Please check your internet connection.")
        print("You can download the files manually from GEO:")
        print("https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc=GSE264681")
        return

    print(f"\nSuccessfully downloaded or found {n_available} out of {len(FILES)} files.")
    print("Proceeding with available data...")

    # Build the combined per-modality objects.
    gene_adata, protein_adata = harmonize_data(data_dir)

    if gene_adata is None and protein_adata is None:
        print("\nNo data was successfully processed. Please check the error messages above.")
        print("You may need to download some files manually from GEO:")
        print("https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc=GSE264681")
        return

    # Write each modality that exists, gene expression first.
    outputs = (
        (gene_adata, 'gene_expression.h5ad', 'gene expression', 'Gene expression'),
        (protein_adata, 'protein_expression.h5ad', 'protein expression', 'Protein expression'),
    )
    for adata, filename, lower_label, title_label in outputs:
        if adata is None:
            continue
        out_path = os.path.join(data_dir, filename)
        adata.write_h5ad(out_path)
        print(f"Saved {lower_label} data to {out_path}")
        print(f"{title_label} data shape: {adata.shape}")

    print("\nProcessing complete!")

# Example usage in Jupyter.
# Adjust 'data_dir' to the folder where you want data to be downloaded and processed.
# NOTE: the call below sits at module level, so the whole pipeline (including
# any downloads) starts as soon as this cell/file is executed or imported.
data_dir = "./GSE264681_data"

# Run the main processing function: download, harmonize, write the .h5ad files.
process_and_save_all_data(data_dir)
