In [1]:
# %%
import os
import pandas as pd
import scanpy as sc
import anndata as ad
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
from scipy.sparse import csr_matrix
import datetime
from collections import defaultdict, Counter
import scipy.sparse as sp
from pathlib import Path
from pandas.api.types import CategoricalDtype

# Report where the notebook is executing from
current_dir = os.getcwd()
print("Current working directory:", current_dir)

# Run-level configuration: local calendar date as YYYY-MM-DD
today = datetime.date.today().isoformat()
print(f"Processing date: {today}")

# %%
def validate_file_exists(filepath, description=""):
    """Confirm that ``filepath`` exists and is readable by this process.

    Args:
        filepath: Path to check.
        description: Label placed in front of "file" in every message.

    Returns:
        True when both checks pass (a confirmation line is printed).

    Raises:
        FileNotFoundError: ``filepath`` does not exist.
        PermissionError: ``filepath`` exists but read access is denied.
    """
    path_exists = os.path.exists(filepath)
    if not path_exists:
        raise FileNotFoundError(f"❌ {description} file not found: {filepath}")

    can_read = os.access(filepath, os.R_OK)
    if not can_read:
        raise PermissionError(f"❌ {description} file not readable: {filepath}")

    print(f"✅ {description} file found: {filepath}")
    return True

def validate_adata_structure(adata, dataset_name):
    """Print a structural summary of ``adata`` and run basic integrity checks.

    Args:
        adata: Object exposing ``shape``, ``obs`` and ``var`` (the latter two
            with ``columns`` and ``index``).
        dataset_name: Label used in the printed messages.

    Returns:
        False as soon as one check fails (required obs column missing,
        duplicate obs index, duplicate var index); True otherwise.
    """
    print(f"\n=== VALIDATING {dataset_name.upper()} ADATA ===")

    print(f"Shape: {adata.shape}")
    obs = adata.obs
    print(f"Obs columns: {list(obs.columns)}")
    var = adata.var
    print(f"Var columns: {list(var.columns)}")

    # Columns every dataset must carry in obs
    required_obs = ['dataset']
    missing_obs = [name for name in required_obs if name not in obs.columns]
    if missing_obs:
        print(f"❌ Missing required obs columns: {missing_obs}")
        return False

    # Index labels must be unique on both axes; cells are checked before genes
    for frame, axis_label in ((obs, "cell"), (var, "gene")):
        dup_flags = frame.index.duplicated()
        if dup_flags.any():
            print(f"❌ {dup_flags.sum()} duplicate {axis_label} indices found!")
            return False

    print(f"✅ {dataset_name} AnnData structure validated")
    return True

def extract_metadata_with_validation(adata, dataset_name, required_columns):
    """Copy selected obs columns into a fresh, positionally indexed frame.

    Args:
        adata: Object whose ``obs`` attribute is a DataFrame.
        dataset_name: Value written to the ``dataset`` column of the result
            (replacing any ``dataset`` column that was selected).
        required_columns: obs column names to copy, in output order.

    Returns:
        DataFrame with ``required_columns`` plus ``dataset`` and a 0..n-1 index.

    Raises:
        ValueError: One or more of ``required_columns`` is absent from obs.
    """
    print(f"\n=== EXTRACTING {dataset_name.upper()} METADATA ===")

    obs = adata.obs

    # Refuse to continue when a requested column is absent
    missing_cols = [name for name in required_columns if name not in obs.columns]
    if missing_cols:
        print(f"❌ Missing columns in {dataset_name}: {missing_cols}")
        print(f"Available columns: {list(obs.columns)}")
        raise ValueError(f"Required columns missing from {dataset_name}")

    # Drop the original index so rows are addressed by position from here on
    metadata = obs[required_columns].reset_index(drop=True)
    metadata["dataset"] = dataset_name

    # Report (but do not fix) columns containing nulls
    missing_counts = metadata.isnull().sum()
    if missing_counts.any():
        print(f"⚠️  Missing values in {dataset_name}:")
        n_rows = len(metadata)
        columns_with_gaps = missing_counts[missing_counts > 0]
        for col, count in columns_with_gaps.items():
            print(f"  {col}: {count} missing ({count/n_rows*100:.1f}%)")

    print(f"✅ Extracted {len(metadata)} records from {dataset_name}")
    return metadata

def validate_cell_id_format(metadata, dataset_name, cell_id_col="cell_id"):
    """Check that the cell ID column exists and holds unique, non-empty IDs.

    Checks run in this order and stop at the first failure: column present,
    no duplicates, no nulls, no empty strings.

    Args:
        metadata: DataFrame holding the cell IDs.
        dataset_name: Label used in the printed header.
        cell_id_col: Name of the column to validate.

    Returns:
        True when every check passes, otherwise False.
    """
    print(f"\n=== VALIDATING {dataset_name.upper()} CELL IDS ===")

    if cell_id_col not in metadata.columns:
        print(f"❌ Cell ID column '{cell_id_col}' not found!")
        return False

    cell_ids = metadata[cell_id_col]

    # Uniqueness: flag every repeat after the first occurrence
    repeat_flags = cell_ids.duplicated()
    n_repeats = repeat_flags.sum()
    if n_repeats > 0:
        print(f"❌ {n_repeats} duplicate cell IDs found!")
        print(f"Examples: {cell_ids[repeat_flags].head(3).tolist()}")
        return False

    # Completeness: nulls first, then empty strings
    completeness_checks = (
        (cell_ids.isnull().sum(), "missing"),
        ((cell_ids == "").sum(), "empty"),
    )
    for n_bad, problem in completeness_checks:
        if n_bad > 0:
            print(f"❌ {n_bad} {problem} cell IDs found!")
            return False

    print(f"✅ All {len(cell_ids)} cell IDs are valid and unique")
    return True

def convert_categoricals_to_strings(df):
    """Convert categorical columns to plain strings, keeping missing values.

    A bare ``astype(str)`` on a categorical column can render missing entries
    as the literal text "nan", which hides them from later ``isnull()``
    checks. Here only non-missing entries are stringified; missing entries
    stay missing.

    Args:
        df: DataFrame to convert. It is modified in place.

    Returns:
        The same ``df`` object, for call sites that reassign the result.
    """
    converted_cols = []
    for col in df.columns:
        column = df[col]
        if not isinstance(column.dtype, CategoricalDtype):
            continue

        as_text = column.astype(str)
        # Keep the original (missing) value wherever the source was missing,
        # and take the string form everywhere else.
        df[col] = column.astype(object).where(column.isna(), as_text)
        converted_cols.append(col)

    if converted_cols:
        print(f"✅ Converted categorical columns to strings: {converted_cols}")

    return df

def validate_metadata_consistency(metadata1, metadata2, dataset1_name, dataset2_name):
    """Print a column and value comparison of two metadata frames.

    This is a reporting step only: nothing is modified and no difference is
    treated as an error.

    Args:
        metadata1: First metadata DataFrame.
        metadata2: Second metadata DataFrame.
        dataset1_name: Display name for ``metadata1``.
        dataset2_name: Display name for ``metadata2``.

    Returns:
        Always True.
    """
    print("\n=== VALIDATING METADATA CONSISTENCY ===")
    print(f"Comparing {dataset1_name} vs {dataset2_name}")

    # Check column consistency
    cols1, cols2 = set(metadata1.columns), set(metadata2.columns)
    common_cols = cols1 & cols2
    unique_to_1 = cols1 - cols2
    unique_to_2 = cols2 - cols1

    print(f"Common columns: {len(common_cols)}")
    if unique_to_1:
        print(f"Unique to {dataset1_name}: {unique_to_1}")
    if unique_to_2:
        print(f"Unique to {dataset2_name}: {unique_to_2}")

    # Check value consistency for common categorical columns
    categorical_cols = ['sex', 'dataset']
    for col in categorical_cols:
        if col not in common_cols:
            continue
        all_vals = set(metadata1[col].unique()) | set(metadata2[col].unique())
        try:
            ordered_vals = sorted(all_vals)
        except TypeError:
            # Mixed types (e.g. strings alongside None/NaN) cannot be compared
            # with "<"; order by string form so the report still prints.
            ordered_vals = sorted(all_vals, key=str)
        print(f"{col} values across datasets: {ordered_vals}")

    return True

# %%
print("=== STARTING METADATA PROCESSING PIPELINE ===")

# File paths
outdir = "/gpfs/commons/groups/knowles_lab/Karin/Leaflet-analysis-WD/HUMAN_SPLICING_FOUNDATION/processed_data/"
# outdir already ends with "/", so formatting it as f"{outdir}/name" produced
# a doubled separator ("processed_data//..."). os.path.join inserts a
# separator only when one is needed.
ab_file = os.path.join(outdir, "ab_adata_exons_2025-09-20.h5ad")
ts_file = os.path.join(outdir, "tabsap_adata_2025-09-20.h5ad")

# Validate input files (raises if either file is missing or unreadable)
print("\n=== FILE VALIDATION ===")
validate_file_exists(ab_file, "Allen Brain")
validate_file_exists(ts_file, "Tabula Sapiens")

# Load data with validation
# The structure check only prints its findings here: its True/False result is
# not acted on, so a failed check does not stop the pipeline.
print("\n=== LOADING ANNDATA OBJECTS ===")
print("⏳ Loading Allen Brain data...")
ab_exons = sc.read_h5ad(ab_file)
validate_adata_structure(ab_exons, "Allen Brain")

print("⏳ Loading Tabula Sapiens data...")
ts_adata = sc.read_h5ad(ts_file)
validate_adata_structure(ts_adata, "Tabula Sapiens")

# %%
# Build the Allen Brain metadata table: copy the needed obs columns, report
# cells per donor, then attach a per-donor age from a hard-coded lookup.
print("\n=== PROCESSING ALLEN BRAIN METADATA ===")

# Define required columns for Allen Brain
ab_required_cols = [
    "sample_name", "cell_type_designation_label", "cell_type_alias_label", 
    "specimen_type", "subclass_label", "donor_sex_label", 
    "external_donor_name_label", "class_label", "region_label"
]

# Extract Allen Brain metadata
# (the helper also adds a "dataset" column set to "allen_brain")
ab_metadata = extract_metadata_with_validation(ab_exons, "allen_brain", ab_required_cols)

# SANITY CHECK 1: Validate donor information
print("\n--- Allen Brain Donor Validation ---")
donor_counts = ab_metadata["external_donor_name_label"].value_counts()
print(f"Donors found: {list(donor_counts.index)}")
print(f"Cells per donor:\n{donor_counts}")

# Add age information with validation
print("\n--- Adding Age Information ---")

# Ages added manually per donor, taken from
# https://pmc.ncbi.nlm.nih.gov/articles/PMC6919571/table/T1/
age_mapping = {
    "H200.1030": 54,  # Caucasian
    "H200.1023": 43,  # Iranian descent  
    "H200.1025": 50   # Caucasian
}

# 0 acts as the "no age assigned" placeholder; it is overwritten below for
# every donor present in age_mapping.
# NOTE(review): donors absent from age_mapping keep age 0 in the output; the
# check at the end of this cell only warns about them — confirm 0 is an
# acceptable stand-in for "unknown".
ab_metadata["age"] = 0
cells_with_age = 0

for donor, age in age_mapping.items():
    donor_mask = ab_metadata["external_donor_name_label"] == donor
    donor_count = donor_mask.sum()
    if donor_count > 0:
        ab_metadata.loc[donor_mask, "age"] = age
        cells_with_age += donor_count
        print(f"✅ Set age {age} for {donor_count:,} cells from donor {donor}")
    else:
        print(f"⚠️  No cells found for donor {donor}")

print(f"Total cells with age assigned: {cells_with_age:,} / {len(ab_metadata):,}")

# SANITY CHECK 2: Validate age assignment
age_dist = ab_metadata["age"].value_counts().sort_index()
print(f"Age distribution:\n{age_dist}")

# Any remaining 0 means the cell's donor was not in age_mapping
if (ab_metadata["age"] == 0).any():
    cells_without_age = (ab_metadata["age"] == 0).sum()
    print(f"⚠️  {cells_without_age} cells still have age = 0")

# %%
# Build the Tabula Sapiens metadata table and derive a "cell_id" for every
# cell from its sample_id and old_index.
print("\n=== PROCESSING TABULA SAPIENS METADATA ===")

# Define required columns for Tabula Sapiens
ts_required_cols = [
    "old_index", "sample_id", "donor", "tissue", "cell_ontology_class",
    "compartment", "broad_cell_class", "age", "sex", "dataset", "free_annotation"
]

# Extract Tabula Sapiens metadata
# (the helper overwrites the "dataset" column with "tabula_sapiens")
ts_metadata = extract_metadata_with_validation(ts_adata, "tabula_sapiens", ts_required_cols)

# SANITY CHECK 3: Validate old_index format for cell ID generation
print("\n--- Validating Tabula Sapiens Cell ID Generation ---")
sample_old_index = ts_metadata["old_index"].iloc[0]
print(f"Sample old_index format: '{sample_old_index}'")

# Parse old_index components
# expand=False returns a Series for the single capture group, so a Series
# (not a one-column DataFrame) is assigned to the new column.
ts_metadata["project_id"] = ts_metadata["sample_id"].str.extract(r'^(TSP\d+)', expand=False)
parts = ts_metadata["old_index"].str.split("_")

# Check if splitting worked correctly (a malformed first row is fatal)
if len(parts.iloc[0]) < 8:
    print(f"❌ ERROR: old_index format unexpected. Expected 8+ parts, got {len(parts.iloc[0])}")
    print(f"Sample parts: {parts.iloc[0]}")
    raise ValueError("old_index parsing failed")

# The first row alone does not vouch for the rest: any row with fewer than
# 8 parts yields a missing component below and therefore a missing cell_id.
# Report such rows instead of letting them pass silently.
part_counts = parts.str.len()
malformed_rows = ~(part_counts >= 8)  # also True where the count is missing
if malformed_rows.any():
    print(f"⚠️  {malformed_rows.sum():,} old_index values have fewer than 8 parts; their cell IDs will be missing")
    print(f"Examples: {ts_metadata.loc[malformed_rows, 'old_index'].head(3).tolist()}")

# Likewise, a sample_id that does not start with TSP<digits> leaves project_id
# missing, which propagates into cell_id.
unparsed_projects = ts_metadata["project_id"].isna()
if unparsed_projects.any():
    print(f"⚠️  {unparsed_projects.sum():,} sample_id values do not start with 'TSP<digits>'; their cell IDs will be missing")
    print(f"Examples: {ts_metadata.loc[unparsed_projects, 'sample_id'].head(3).tolist()}")

# Generate cell ID components with validation
try:
    # Parts 3 and 5 of old_index are deliberately left out of the prefix
    ts_metadata["junc_prefix_core"] = (
        parts.str[0] + "_" +  # TSP1
        parts.str[1] + "_" +  # smartseq2
        parts.str[2] + "_" +  # NA
        parts.str[4] + "_" +  # B107921
        parts.str[6] + "_" +  # Muscle
        parts.str[7]          # NA
    )

    ts_metadata["cell_id_prefix"] = (
        ts_metadata["project_id"] + "_" +
        ts_metadata["junc_prefix_core"] +
        ".multi_star.output_raw.output_raw_per_"
    )

    ts_metadata["cell_id"] = ts_metadata["cell_id_prefix"] + ts_metadata["old_index"]

    print("✅ Cell ID generation successful")

    # Show examples
    print("Sample generated cell IDs:")
    for example_cell_id in ts_metadata["cell_id"].head(3):
        print(f"  {example_cell_id}")

except Exception as e:
    print(f"❌ ERROR in cell ID generation: {e}")
    raise

# Validate generated cell IDs (prints its findings; the result is not acted on)
validate_cell_id_format(ts_metadata, "Tabula Sapiens", "cell_id")

# %%
# Bring both metadata tables onto one shared column schema (final_columns)
# so they can be concatenated row-wise.
print("\n=== STANDARDIZING METADATA FORMATS ===")

# Create working copies
# (the per-dataset tables built earlier are left untouched)
ts_meta = ts_metadata.copy()
ab_meta = ab_metadata.copy()

# SANITY CHECK 4: Pre-standardization validation
print("--- Pre-standardization shapes ---")
print(f"Tabula Sapiens: {ts_meta.shape}")
print(f"Allen Brain: {ab_meta.shape}")

# Rename columns to match target structure
# For Tabula Sapiens the "donor" and "sex" entries map a name to itself, so
# only free_annotation -> cell_type actually changes a column name.
ts_meta = ts_meta.rename(columns={
    "donor": "donor",
    "sex": "sex", 
    "free_annotation": "cell_type",
})

# For Allen Brain, sample_name becomes the cell ID, region_label the tissue
# and subclass_label the cell type.
ab_meta = ab_meta.rename(columns={
    "sample_name": "cell_id",
    "external_donor_name_label": "donor",
    "donor_sex_label": "sex",
    "region_label": "tissue",
    "subclass_label": "cell_type"
})

# Define final column structure
final_columns = ["cell_id", "donor", "sex", "age", "dataset", "tissue", "cell_type"]

# SANITY CHECK 5: Check for missing final columns
# (fail fast here rather than with a KeyError in the selection below)
for dataset_name, df in [("Tabula Sapiens", ts_meta), ("Allen Brain", ab_meta)]:
    missing_cols = [col for col in final_columns if col not in df.columns]
    if missing_cols:
        print(f"❌ Missing final columns in {dataset_name}: {missing_cols}")
        print(f"Available columns: {list(df.columns)}")
        raise ValueError(f"Column standardization failed for {dataset_name}")

# Select and reorder columns
# Every column not listed in final_columns is dropped at this point.
ts_meta = ts_meta[final_columns].reset_index(drop=True)
ab_meta = ab_meta[final_columns].reset_index(drop=True)

# Convert categoricals to strings
ts_meta = convert_categoricals_to_strings(ts_meta)
ab_meta = convert_categoricals_to_strings(ab_meta)

print("✅ Metadata standardization complete")

# %%
# Stack the two standardized tables into combined_metadata and normalise the
# sex column to "M"/"F".
print("\n=== COMBINING AND CLEANING METADATA ===")

# Validate consistency before combining
# (report only; the result is not acted on)
validate_metadata_consistency(ts_meta, ab_meta, "Tabula Sapiens", "Allen Brain")

# Combine metadata
# Tabula Sapiens rows come first; ignore_index gives a fresh 0..n-1 index.
combined_metadata = pd.concat([ts_meta, ab_meta], ignore_index=True)
print(f"✅ Combined metadata shape: {combined_metadata.shape}")

# SANITY CHECK 6: Post-combination validation
print("--- Post-combination validation ---")
dataset_counts = combined_metadata["dataset"].value_counts()
print(f"Dataset distribution:\n{dataset_counts}")

expected_total = len(ts_meta) + len(ab_meta)
if len(combined_metadata) != expected_total:
    print(f"❌ ERROR: Expected {expected_total} rows, got {len(combined_metadata)}")
    raise ValueError("Row count mismatch after combination")

# Clean sex column with validation
print("\n--- Cleaning Sex Column ---")
sex_before = combined_metadata["sex"].value_counts()
print(f"Sex values before cleaning:\n{sex_before}")

# Only the exact lowercase spellings "male"/"female" are rewritten; any other
# spelling is left as-is and reported by the check below.
# NOTE(review): astype(str) presumably renders missing values as text (e.g.
# "nan"), which would then show up as an unexpected value — confirm.
combined_metadata["sex"] = combined_metadata["sex"].astype(str)
combined_metadata.loc[combined_metadata["sex"] == "male", "sex"] = "M"
combined_metadata.loc[combined_metadata["sex"] == "female", "sex"] = "F"

sex_after = combined_metadata["sex"].value_counts()
print(f"Sex values after cleaning:\n{sex_after}")

# Check for unexpected sex values
# (warning only; the rows are kept)
unexpected_sex = combined_metadata[~combined_metadata["sex"].isin(["M", "F"])]
if len(unexpected_sex) > 0:
    print(f"⚠️  {len(unexpected_sex)} cells with unexpected sex values:")
    print(unexpected_sex["sex"].value_counts())

# %%
print("\n=== CELL TYPE MAPPING ===")

# Define refined cell type mapping
# Structure: {broad category name: [fine-grained cell_type labels]}.
# The lists are later flattened into a label -> category lookup that is applied
# to the cell_type column with Series.map, i.e. by exact string match, so
# spelling and case matter (both 'pericyte' and 'Pericyte' are listed below
# for that reason).
# NOTE(review): a few placements look worth a second look by a domain expert,
# e.g. 'retinal pigment epithelial cell' under Photoreceptor and 'keratocyte'
# under Skin_cell — confirm these are intended.
grouped_refined_map = {
    # === NEURONS - Split by major functional classes ===
    'Excitatory_Neuron': [
        'IT', 'L4 IT', 'L5 ET', 'L6 CT', 'L6b', 'L5/6 IT Car3', 'L5/6 NP'
    ],
    'Inhibitory_Neuron': [
        'VIP', 'PVALB', 'SST', 'LAMP5'
    ],
    'Other_Neuron': [
        'PAX6', 'retinal bipolar neuron'
    ],
    
    # === T CELLS - Organized by major functional subsets ===
    'CD4_T_cell': [
        'cd4-positive, alpha-beta t cell', 'cd4-positive helper t cell', 
        'cd4-positive, alpha-beta memory t cell', 'naive thymus-derived cd4-positive, alpha-beta t cell',
        'activated cd4-positive, alpha-beta t cell', 'cd4-positive, alpha-beta thymocyte',
        'cd4-positive memory t cell', 't follicular helper cell'
    ],
    'CD8_T_cell': [
        'cd8-positive, alpha-beta t cell', 'cd8-positive, alpha-beta memory t cell',
        'cd8-positive, alpha-beta thymocyte', 'naive cd8-positive t cell',
        'activated cd8-positive, alpha-beta t cell', 'cd8-positive cytotoxic t cell',
        'cd8+, alpha-beta cytokine secreting effector t cell'
    ],
    'Regulatory_T_cell': [
        'regulatory t cell', 'naive regulatory t cell'
    ],
    'Other_T_cell': [
        't cell', 'gamma-delta t cell', 'thymocyte'
    ],
    
    # === B CELLS ===
    'B_cell': [
        'b cell', 'memory b cell', 'naive b cell'
    ],
    'Plasma_cell': [
        'plasma cell', 'antibody secreting cell'
    ],
    
    # === MYELOID CELLS - Split by major lineages ===
    'Microglia': [
        'Microglia', 'microglial cell', 'retina - microglia'
    ],
    'Macrophage': [
        'macrophage', 'Monocyte_Macrophage', 'tissue-resident macrophage', 
        'muscle macrophage'
    ],
    'Monocyte': [
        'monocyte', 'classical monocyte', 'non-classical monocyte', 'intermediate monocyte'
    ],
    'Dendritic_cell': [
        'dendritic cell', 'myeloid dendritic cell', 'plasmacytoid dendritic cell',
        'cd1c-positive myeloid dendritic cell', 'cd141-positive myeloid dendritic cell',
        'cdc1', 'cdc2', 'conventional dendritic cell'
    ],
    'Granulocyte': [
        'neutrophil', 'cd24 neutrophil', 'nampt neutrophil', 'granulocyte',
        'basophil', 'mast cell'
    ],
    
    # === NK/ILC ===
    'NK_ILC': [
        'nk cell', 'natural killer cell', 'nk t cell', 'innate lymphoid cell', 
        'mature nk t cell', 'type i nk t cell', 'uterine nk cell', 
        'proliferating nk cell', 'immature natural killer cell'
    ],
    
    # === EPITHELIAL - Split by organ system ===
    'Respiratory_Epithelial': [
        'club cell', 'ionocyte', 'goblet cell', 'ciliated epithelial cell',
        'pulmonary ionocyte', 'serous cell of epithelium of bronchus',
        'respiratory goblet cell', 'ciliated columnar cell of tracheobronchial tree',
        'tracheal goblet cell'
    ],
    'GI_Epithelial': [
        'enterocyte of epithelium of large intestine', 
        'enterocyte of epithelium proper of small intestine',
        'large intestine goblet cell', 'best4+ intestinal epithelial cell',
        'small intestine goblet cell', 'enterocyte of epithelium proper of ileum',
        'enterocyte of epithelium proper of duodenum', 
        'paneth cell of epithelium of small intestine', 'paneth cell of colon',
        'mature enterocyte', 'intestinal tuft cell', 'tuft cell of colon'
    ],
    'Urogenital_Epithelial': [
        'bladder urothelial cell', 'basal bladder urothelial cell', 
        'intermediate bladder urothelial cell', 'epithelial cell of uterus'
    ],
    'Mammary_Epithelial': [
        'HR positive luminal epithelial cell of mammary gland',
        'secretory luminal epithelial cell of mammary gland',
        'luminal epithelial cell'
    ],
    'Other_Epithelial': [
        'epithelial cell', 'duct epithelial cell', 'ltf+ epithelial cell',
        'basal epithelial cell', 'salivary gland cell', 'medullary thymic epithelial cell',
        'conjunctival epithelial cell', 'corneal epithelial cell', 'glandular epithelial cell',
        'cycling epithelial cell', 'mucus secreting cell', 'biliary epithelial cell',
        'pancreatic ductal cell', 'stratified squamous epithelial cell', 'sebum secreting cell'
    ],
    
    # === ENDOTHELIAL - Split by vessel type ===
    'Arterial_Endothelial': [
        'arterial endothelial cell', 'endothelial cell of arteriole', 'endothelial cell of artery'
    ],
    'Venous_Endothelial': [
        'vein endothelial cell', 'venous capillary endothelial cell', 'endothelial cell of venule'
    ],
    'Capillary_Endothelial': [
        'capillary endothelial cell', 'blood vessel endothelial cell'
    ],
    'Lymphatic_Endothelial': [
        'endothelial cell of lymphatic vessel'
    ],
    'Specialized_Endothelial': [
        'endothelial cell', 'endothelial cell of vascular tree', 'cardiac endothelial cell',
        'colon endothelial cell', 'retinal blood vessel endothelial cell', 'vascular endothelial cell'
    ],
    
    # === GLIA - Split by CNS vs PNS ===
    'CNS_Glia': [
        'Astrocyte', 'OPC', 'Oligodendrocyte', 'retina - muller glia', 'mueller cell'
    ],
    'PNS_Glia': [
        'enteroglial cell', 'schwann cell'
    ],
    'Glia_Other': [
        'glial cell'
    ],
    
    # === MUSCLE - Split by muscle type ===
    'Smooth_Muscle': [
        'smooth muscle cell', 'airway smooth muscle cell', 'vascular associated smooth muscle cell'
    ],
    'Cardiac_Muscle': [
        'atrial cardiac muscle cell', 'ventricular cardiac muscle cell'
    ],
    'Skeletal_Muscle': [
        'skeletal muscle satellite stem cell', 'fast muscle cell', 'slow muscle cell',
        'tongue muscle cell'
    ],
    'Muscle_Other': [
        'muscle cell', 'tendon cell'
    ],
    
    # === STROMAL/FIBROBLAST - More specific organ groupings ===
    'General_Fibroblast': [
        'fibroblast', 'stromal cell', 'myofibroblast cell', 'adventitial fibroblast',
        'cd34+ fibroblasts', 'VLMC', 'adventitial cell', 'connective tissue cell'
    ],
    'Organ_Specific_Fibroblast': [
        'alveolar fibroblast', 'fibroblast of breast', 'fibroblast of cardiac tissue',
        'uterine fibroblast', 'stellate_fibroblast', 'endometrial stromal fibroblast'
    ],
    'Specialized_Stromal': [
        'fat cell', 'cornea - mesenchymal cell - stromal keratinocytes',
        'limbal stromal cell', 'follicle', 'granulosa cell', 'mesothelial cell', 'theca cell'
    ],
    
    # === LIVER - Split by major cell types ===
    'Hepatocyte': [
        'hepatocyte'
    ],
    'Liver_Non_Parenchymal': [
        'hepatic stellate cell', 'intrahepatic cholangiocyte'
    ],
    
    # === SPECIALIZED CELLS ===
    'Pericyte': [
        'pericyte', 'Pericyte', 'myofibroblast cell and pericyte', 'mural cell'
    ],
    
    'Photoreceptor': [
        'retinal pigment epithelial cell', 'retina - photoreceptor cell', 'eye photoreceptor cell'
    ],
    
    'Alveolar_cell': [
        'type ii pneumocyte', 'type i pneumocyte', 'capillary aerocyte'
    ],
    
    'Enteroendocrine': [
        'enteroendocrine cell of small intestine', 'type l enteroendocrine cell',
        'enterochromaffin-like cell'
    ],
    
    'Hematopoietic_Mature': [
        'platelet', 'erythrocyte'
    ],
    
    'Hematopoietic_Progenitor': [
        'erythroid progenitor cell', 'hematopoietic stem cell', 'myeloid progenitor',
        'common myeloid progenitor'
    ],
    
    'Mesenchymal_Stem': [
        'mesenchymal stem cell', 'mesenchymal stem cell of adipose tissue'
    ],
    
    'Stem_Progenitor_Other': [
        'oocyte', 'radial glia progenitor cell', 'intestinal crypt stem cell of small intestine',
        'intestinal crypt stem cell of large intestine'
    ],
    
    'Secretory_Gland': [
        'acinar cell of salivary gland', 'lacrimal gland functional unit cell', 'myoepithelial cell'
    ],
    
    'Pigment_cell': [
        'melanocyte', 'melanocyte or limbal stem cell'
    ],
    
    'Sensory_cell': [
        'taste receptor cell'
    ],
    
    'Skin_cell': [
        'keratocyte'
    ],
    
    'Myeloid_Other': [
        'myeloid cell', 'mononuclear phagocyte'
    ],
    
    'Immune_Other': [
        'leukocyte', 'langerhans cell', 'immune cell'
    ],
    
    'Unknown': [
        'unknown'
    ]
}

# SANITY CHECK 7: Validate cell type mapping
print("--- Cell Type Mapping Validation ---")


# Invert {category: [labels]} into {label: category}
def flatten_grouped_map(grouped_map):
    """Turn a category -> labels mapping into a label -> category lookup.

    Args:
        grouped_map: Dict whose values are iterables of labels.

    Returns:
        Dict mapping every label to its category. When a label is listed
        under more than one category a warning is printed and the category
        seen last wins.
    """
    label_to_category = {}
    for category in grouped_map:
        for member in grouped_map[category]:
            already_assigned = member in label_to_category
            if already_assigned:
                print(f"⚠️  WARNING: '{member}' appears in multiple categories: {label_to_category[member]} and {category}")
            label_to_category[member] = category
    return label_to_category

# Build the flat label -> category lookup and use it to derive the
# broad_cell_type column of combined_metadata.
grouped_broad_map_flat = flatten_grouped_map(grouped_refined_map)
print(f"✅ Cell type mapping created: {len(grouped_broad_map_flat)} mappings")

# Apply mapping
# Labels absent from the lookup come back as missing from .map and are
# temporarily tagged 'Other' so they can be located below.
combined_metadata['broad_cell_type'] = (
    combined_metadata['cell_type']
    .map(grouped_broad_map_flat)
    .fillna('Other')
)

# Handle unmapped cell types
unmapped_mask = combined_metadata["broad_cell_type"] == "Other"
unmapped_count = unmapped_mask.sum()

if unmapped_count > 0:
    print(f"⚠️  {unmapped_count} cells with unmapped cell types")
    unmapped_types = combined_metadata[unmapped_mask]["cell_type"].value_counts()
    print(f"Top unmapped cell types:\n{unmapped_types.head(10)}")
    
    # Use original cell_type for unmapped
    # (so the temporary 'Other' tag is replaced by the fine-grained label and
    # broad_cell_type ends up mixing broad categories with raw labels)
    combined_metadata.loc[unmapped_mask, "broad_cell_type"] = combined_metadata.loc[unmapped_mask, "cell_type"]
    print("✅ Used original cell_type for unmapped cells")

# SANITY CHECK 8: Analyze broad cell type distribution
print("\n--- Broad Cell Type Distribution ---")
broad_type_counts = combined_metadata["broad_cell_type"].value_counts()
print(f"Number of broad cell types: {len(broad_type_counts)}")
print(f"Top 10 broad cell types:\n{broad_type_counts.head(10)}")

# %%
# Find broad cell types that are well represented in every dataset and report
# how much of the combined table they cover. Nothing is filtered here; the
# results are only computed and printed.
print("\n=== SHARED CELL TYPE ANALYSIS ===")

# Determine shared cell types with at least 50 cells per dataset
min_cells_per_dataset = 50

print(f"Finding cell types with ≥{min_cells_per_dataset} cells per dataset...")

# Long table: one row per (dataset, broad_cell_type) with its cell count
broad_counts = (
    combined_metadata.groupby(["dataset", "broad_cell_type"])
    .size().reset_index(name="count")
)

# Wide table: rows = broad cell type, columns = dataset; a type absent from a
# dataset is counted as 0 cells
pivot_counts = broad_counts.pivot(
    index="broad_cell_type", 
    columns="dataset", 
    values="count"
).fillna(0)

# Find shared cell types meeting threshold
# (the threshold must be met in every dataset column)
shared_mask = (pivot_counts >= min_cells_per_dataset).all(axis=1)
shared_broad_cell_types = pivot_counts[shared_mask].index.tolist()

print(f"✅ Shared broad cell types (≥{min_cells_per_dataset} per dataset): {len(shared_broad_cell_types)}")
print(f"Shared cell types: {shared_broad_cell_types}")

# SANITY CHECK 9: Detailed shared cell type analysis
print("\n--- Detailed Shared Cell Type Analysis ---")
print("Cell counts for shared types:")
shared_counts_df = pivot_counts.loc[shared_broad_cell_types].round(0).astype(int)
print(shared_counts_df)

# Calculate total cells in shared types
total_shared_cells = combined_metadata[
    combined_metadata["broad_cell_type"].isin(shared_broad_cell_types)
].shape[0]
total_cells = len(combined_metadata)
shared_percentage = (total_shared_cells / total_cells) * 100

print(f"\nShared cell type summary:")
print(f"  Total cells: {total_cells:,}")
print(f"  Cells in shared types: {total_shared_cells:,}")
print(f"  Percentage in shared types: {shared_percentage:.1f}%")

# Show cells excluded from shared analysis
# The column names below are the dataset labels assigned when each metadata
# table was extracted ("allen_brain", "tabula_sapiens"); both must be present
# as pivot columns for these lookups to succeed.
excluded_types = pivot_counts[~shared_mask]
if len(excluded_types) > 0:
    print(f"\nExcluded cell types ({len(excluded_types)}):")
    for cell_type in excluded_types.index[:10]:  # Show first 10
        ab_count = excluded_types.loc[cell_type, "allen_brain"]
        ts_count = excluded_types.loc[cell_type, "tabula_sapiens"] 
        print(f"  {cell_type}: AB={ab_count:.0f}, TS={ts_count:.0f}")
    if len(excluded_types) > 10:
        print(f"  ... and {len(excluded_types) - 10} more")

# %%
# Run final checks on combined_metadata and write it out as a tab-separated
# file. The checks only print their findings; the file is written either way.
print("\n=== SAVING METADATA ===")

# Output path
output_path = "/gpfs/commons/groups/knowles_lab/Karin/Leaflet-analysis-WD/HUMAN_SPLICING_FOUNDATION/human_metadata_combined.tsv"

# Create output directory if needed
output_dir = os.path.dirname(output_path)
os.makedirs(output_dir, exist_ok=True)

# SANITY CHECK 10: Pre-save validation
print("--- Pre-save Validation ---")
print(f"Final metadata shape: {combined_metadata.shape}")
print(f"Columns: {list(combined_metadata.columns)}")

# Check for missing values in critical columns
critical_cols = ["cell_id", "dataset", "broad_cell_type"]
for col in critical_cols:
    missing = combined_metadata[col].isnull().sum()
    if missing > 0:
        print(f"❌ {missing} missing values in critical column '{col}'")
    else:
        print(f"✅ No missing values in '{col}'")

# Check for duplicate cell IDs
# (row count vs. number of distinct IDs; nunique() skips missing IDs by
# default, so missing IDs also show up as a difference here)
total_cell_ids = len(combined_metadata)
unique_cell_ids = combined_metadata["cell_id"].nunique()
if total_cell_ids != unique_cell_ids:
    duplicates = total_cell_ids - unique_cell_ids
    print(f"❌ {duplicates} duplicate cell IDs found!")
    
    # Show duplicate examples
    dup_ids = combined_metadata[combined_metadata["cell_id"].duplicated()]["cell_id"].head(3)
    print(f"Example duplicates: {dup_ids.tolist()}")
else:
    print(f"✅ All {total_cell_ids:,} cell IDs are unique")

# Save metadata
# The row index is not written; any failure is reported and then re-raised.
try:
    combined_metadata.to_csv(output_path, sep="\t", index=False)
    file_size = os.path.getsize(output_path) / (1024**2)  # MB
    print(f"✅ Metadata saved to: {output_path}")
    print(f"File size: {file_size:.1f} MB")
except Exception as e:
    print(f"❌ Error saving metadata: {e}")
    raise



Current working directory: /gpfs/commons/home/kisaev/Leaflet-analysis/Human_Splicing_Foundation/metadata
Processing date: 2025-09-20
=== STARTING METADATA PROCESSING PIPELINE ===

=== FILE VALIDATION ===
✅ Allen Brain file found: /gpfs/commons/groups/knowles_lab/Karin/Leaflet-analysis-WD/HUMAN_SPLICING_FOUNDATION/processed_data//ab_adata_exons_2025-09-20.h5ad
✅ Tabula Sapiens file found: /gpfs/commons/groups/knowles_lab/Karin/Leaflet-analysis-WD/HUMAN_SPLICING_FOUNDATION/processed_data//tabsap_adata_2025-09-20.h5ad

=== LOADING ANNDATA OBJECTS ===
⏳ Loading Allen Brain data...

=== VALIDATING ALLEN BRAIN ADATA ===
Shape: (49417, 50281)
Obs columns: ['sample_name', 'specimen_type', 'cluster_color', 'cluster_order', 'cluster_label', 'class_color', 'class_order', 'class_label', 'subclass_color', 'subclass_order', 'subclass_label', 'full_genotype_color', 'full_genotype_order', 'full_genotype_label', 'donor_sex_color', 'donor_sex_order', 'donor_sex_label', 'region_color', 'region_order', 'r

In [2]:
# %%
# Set up junction-file filtering: input list paths, output paths, and the set
# of cell IDs that have metadata.
print("\n=== FILTERING JUNCTION FILES ===")

# Junction file paths
# (each is read as a text file with one junction-file path per line)
ab_junc_filelist = "/gpfs/commons/groups/knowles_lab/Karin/Leaflet-analysis-WD/HUMAN_SPLICING_FOUNDATION/ATSE_mapper/junction_files_AB_20250920.txt"
ts_junc_filelist = "/gpfs/commons/groups/knowles_lab/Karin/Leaflet-analysis-WD/HUMAN_SPLICING_FOUNDATION/ATSE_mapper/junction_files_TS_20250920.txt"

# Output paths
clean_ts = "/gpfs/commons/groups/knowles_lab/Karin/Leaflet-analysis-WD/HUMAN_SPLICING_FOUNDATION/ATSE_mapper/junction_files_TS_subset.txt"
clean_ab = "/gpfs/commons/groups/knowles_lab/Karin/Leaflet-analysis-WD/HUMAN_SPLICING_FOUNDATION/ATSE_mapper/junction_files_AB_subset.txt"

# Validate junction file lists exist
# NOTE(review): the `continue` is the last statement of the loop body, so this
# loop only reports a missing list — nothing is skipped here. The actual skip
# for a missing list happens inside filter_junction_paths. A PermissionError
# from validate_file_exists is not caught and will propagate.
print("--- Junction File Validation ---")
for filepath, desc in [(ab_junc_filelist, "Allen Brain"), (ts_junc_filelist, "Tabula Sapiens")]:
    try:
        validate_file_exists(filepath, f"{desc} junction list")
    except FileNotFoundError:
        print(f"⚠️  {desc} junction list not found, skipping junction filtering")
        continue

# Get valid cell IDs from metadata
# (a set, so each per-path membership test is a constant-time lookup)
valid_cell_ids = set(combined_metadata["cell_id"])
print(f"Valid cell IDs for filtering: {len(valid_cell_ids):,}")

def filter_junction_paths(filelist_path, output_path, valid_ids, dataset_name):
    """Filter a junction file list down to the cells that have metadata.

    Reads ``filelist_path`` (one junction file path per line), derives a cell
    ID from each path's file name by removing the text
    ``_junctions_with_barcodes.bed``, and writes the paths whose cell ID is in
    ``valid_ids`` to ``output_path``, newline-separated and in their original
    order. The written file is then read back and its line count compared
    with the number of kept paths; a mismatch is reported but not raised.

    Args:
        filelist_path: Path of the text file listing junction file paths.
        output_path: Destination of the filtered list; overwritten if present.
        valid_ids: Container of cell IDs supporting ``in`` (a set is fastest).
        dataset_name: Label used in the printed progress messages.

    Returns:
        ``None`` when ``filelist_path`` does not exist. Otherwise a dict with
        the keys ``original_count``, ``filtered_count``, ``excluded_count``
        and ``retention_rate`` (kept / original as a fraction, ``0`` when the
        list is empty).

    Raises:
        Exception: Any error raised while writing or re-reading
            ``output_path`` is printed and then re-raised unchanged.
    """
    print(f"\n--- Filtering {dataset_name} Junction Files ---")

    if not os.path.exists(filelist_path):
        print(f"⚠️  {dataset_name} junction list not found: {filelist_path}")
        return None

    # Read junction file list. Blank lines are not junction files, so they
    # are dropped instead of being counted (and reported) as excluded cells;
    # surrounding whitespace is stripped so it cannot break the ID match.
    with open(filelist_path) as f:
        stripped = (line.strip() for line in f.read().splitlines())
        lines = [line for line in stripped if line]

    print(f"Total junction files in list: {len(lines):,}")

    # Function to extract cell ID from file path
    def extract_cell_id_from_path(path):
        return Path(path).name.replace("_junctions_with_barcodes.bed", "")

    # Filter paths and track statistics
    filtered_lines = []
    missing_metadata = []

    for path in lines:
        cell_id = extract_cell_id_from_path(path)
        if cell_id in valid_ids:
            filtered_lines.append(path)
        else:
            missing_metadata.append(cell_id)

    # Computed once and guarded: an empty list must not raise
    # ZeroDivisionError, and the printed and returned rates must agree.
    retention_rate = len(filtered_lines) / len(lines) if lines else 0

    # SANITY CHECK 11: Junction filtering validation
    print(f"Junction filtering results for {dataset_name}:")
    print(f"  Original files: {len(lines):,}")
    print(f"  Files with metadata: {len(filtered_lines):,}")
    print(f"  Files excluded: {len(missing_metadata):,}")
    print(f"  Retention rate: {retention_rate*100:.1f}%")

    if missing_metadata:
        print(f"  Example excluded cell IDs: {missing_metadata[:3]}")

    # Save filtered list
    try:
        with open(output_path, "w") as f:
            f.write("\n".join(filtered_lines))
        print(f"✅ Filtered {dataset_name} list saved to: {output_path}")

        # Verify saved file by reading it back and comparing line counts
        with open(output_path) as f:
            saved_lines = f.read().splitlines()

        if len(saved_lines) != len(filtered_lines):
            print(f"❌ ERROR: Saved file has {len(saved_lines)} lines, expected {len(filtered_lines)}")
        else:
            print(f"✅ Verification passed: {len(saved_lines)} lines saved correctly")

    except Exception as e:
        print(f"❌ Error saving filtered {dataset_name} list: {e}")
        raise

    return {
        'original_count': len(lines),
        'filtered_count': len(filtered_lines),
        'excluded_count': len(missing_metadata),
        'retention_rate': retention_rate,
    }

# Run the filter once per dataset, keyed by a short dataset identifier.
# A dataset whose input list is absent gets no entry in filtering_results.
filtering_results = {}

for _key, _list_path, _subset_path, _label in (
    ('allen_brain', ab_junc_filelist, clean_ab, "Allen Brain"),
    ('tabula_sapiens', ts_junc_filelist, clean_ts, "Tabula Sapiens"),
):
    if not os.path.exists(_list_path):
        continue
    filtering_results[_key] = filter_junction_paths(
        _list_path, _subset_path, valid_cell_ids, _label
    )

# %%
# Cell: read the saved metadata table back and compare it with the in-memory frame.
print("\n=== POST-PROCESSING VERIFICATION ===")

print("--- Metadata Verification ---")
try:
    # output_path is the module-level variable set in an earlier cell
    # (presumably the combined-metadata TSV; confirm against that cell).
    test_metadata = pd.read_csv(output_path, sep="\t")
    print(f"✅ Metadata file readable: {test_metadata.shape}")

    # Row and column counts should survive the round trip
    print(
        "✅ Shape matches original"
        if test_metadata.shape == combined_metadata.shape
        else f"❌ Shape mismatch: saved {test_metadata.shape} vs original {combined_metadata.shape}"
    )

    # Column names and their order should be unchanged
    print(
        "✅ Columns match original"
        if list(test_metadata.columns) == list(combined_metadata.columns)
        else "❌ Column mismatch detected"
    )

    # Every row should carry a distinct cell ID
    print(
        "✅ All cell IDs unique in saved file"
        if test_metadata["cell_id"].nunique() == len(test_metadata)
        else "❌ Duplicate cell IDs found in saved file"
    )

    del test_metadata  # Free memory

except Exception as e:
    # Best-effort check: any failure is reported, not raised
    print(f"❌ Error reading saved metadata: {e}")

# %%
# Cell: print the end-of-run report (dataset stats, cell types, outputs, junction retention).
print("\n=== FINAL PROCESSING SUMMARY ===")
print("🎉 METADATA PROCESSING PIPELINE COMPLETE!")
print(f"📅 Processing date: {today}")

# Per-dataset cell count, distinct donors, distinct broad cell types, and age range
print("\n📊 DATASET SUMMARY:")
dataset_summary = (
    combined_metadata
    .groupby("dataset")
    .agg({
        'cell_id': 'count',
        'donor': 'nunique',
        'broad_cell_type': 'nunique',
        'age': ['min', 'max'],
    })
    .round(1)
)

print(dataset_summary)

# The shared-cell-type figures below were computed in earlier cells
print("\n🔬 CELL TYPE SUMMARY:")
print(f"Total broad cell types: {combined_metadata['broad_cell_type'].nunique():,}")
print(f"Shared cell types (≥{min_cells_per_dataset} cells/dataset): {len(shared_broad_cell_types)}")
print(f"Cells in shared types: {total_shared_cells:,} ({shared_percentage:.1f}%)")

# Only list junction outputs for datasets that were actually filtered
print("\n📁 OUTPUT FILES:")
print(f"✅ Combined metadata: {output_path}")
if 'allen_brain' in filtering_results:
    print(f"✅ Allen Brain junction list: {clean_ab}")
if 'tabula_sapiens' in filtering_results:
    print(f"✅ Tabula Sapiens junction list: {clean_ts}")

print("\n📈 JUNCTION FILTERING SUMMARY:")
for dataset, results in filtering_results.items():
    print(f"{dataset}:")
    print(f"  Original: {results['original_count']:,}")
    print(f"  Retained: {results['filtered_count']:,}")
    print(f"  Retention rate: {results['retention_rate']*100:.1f}%")

print("\n🎯 NEXT STEPS:")
print(
    "1. Use filtered junction files for Leaflet and ATSEmapper pipeline",
    "2. Filter expression data to shared cell types for cross-dataset analysis",
    "3. Validate junction file existence before running downstream analysis",
    sep="\n",
)

print("\n✅ ALL PROCESSING COMPLETE WITH COMPREHENSIVE VALIDATION!")
print("=" * 80)

# %%
# Optional cell: preview the final metadata and print per-column statistics.
print("\n=== FINAL METADATA SAMPLE ===")
print("First 5 rows:")
print(combined_metadata.head())

print("\nFinal metadata statistics:")
print(f"Shape: {combined_metadata.shape}")
# deep=True includes the contents of object columns; reported in MiB
print(f"Memory usage: {combined_metadata.memory_usage(deep=True).sum() / 1024**2:.1f} MB")

# One line per column: dtype and number of distinct values
print("\nData types:")
for col in combined_metadata.columns:
    print(f"  {col}: {combined_metadata[col].dtype} ({combined_metadata[col].nunique():,} unique values)")

# Three most frequent values for each key column that is present
categorical_cols = ['dataset', 'sex', 'tissue', 'broad_cell_type']
print("\nSample values for key columns:")
for col in categorical_cols:
    if col not in combined_metadata.columns:
        continue
    values = combined_metadata[col].value_counts().head(3)
    print(f"  {col}: {dict(values)}")

print("\n📋 Processing complete - all data validated and ready for analysis!")


=== FILTERING JUNCTION FILES ===
--- Junction File Validation ---
✅ Allen Brain junction list file found: /gpfs/commons/groups/knowles_lab/Karin/Leaflet-analysis-WD/HUMAN_SPLICING_FOUNDATION/ATSE_mapper/junction_files_AB_20250920.txt
✅ Tabula Sapiens junction list file found: /gpfs/commons/groups/knowles_lab/Karin/Leaflet-analysis-WD/HUMAN_SPLICING_FOUNDATION/ATSE_mapper/junction_files_TS_20250920.txt
Valid cell IDs for filtering: 90,918

--- Filtering Allen Brain Junction Files ---
Total junction files in list: 50,625
Junction filtering results for Allen Brain:
  Original files: 50,625
  Files with metadata: 49,356
  Files excluded: 1,269
  Retention rate: 97.5%
  Example excluded cell IDs: ['F1S4_170302_092_B01', 'F1S4_160831_074_H01', 'F2S4_170405_052_C01']
✅ Filtered Allen Brain list saved to: /gpfs/commons/groups/knowles_lab/Karin/Leaflet-analysis-WD/HUMAN_SPLICING_FOUNDATION/ATSE_mapper/junction_files_AB_subset.txt
✅ Verification passed: 49356 lines saved correctly

--- Filterin