In [None]:
import gc
import os

import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler

def _load_gene_lengths(candidates):
    """Return lengths from the first existing CSV in *candidates*, or None.

    Each CSV is read with its first column as the index (gene IDs); the first
    data column is taken as the lengths. Non-numeric lengths become NaN.
    """
    for filename in candidates:
        try:
            gene_lengths_df = pd.read_csv(filename, index_col=0)
        except FileNotFoundError:
            continue
        print(f"✅ Loaded gene lengths from: {filename}")
        gene_lengths = pd.to_numeric(gene_lengths_df.iloc[:, 0], errors='coerce')
        # read_csv may parse numeric-looking IDs as ints; IDs are matched as strings.
        gene_lengths.index = gene_lengths.index.astype(str)
        return gene_lengths
    return None


def _align_gene_lengths(gene_lengths, genes):
    """Reindex *gene_lengths* to *genes*, filling unusable entries with the median.

    Lengths that are absent, NaN or non-positive are all treated as unknown (a
    non-positive length would make RPK infinite or negative). For duplicated
    IDs the first usable entry wins so reindexing is well defined.

    Returns (aligned_lengths, missing_genes, median_length); median_length is
    NaN when the input has no usable length at all.
    """
    usable = gene_lengths[gene_lengths > 0]
    usable = usable[~usable.index.duplicated(keep='first')]
    median_length = usable.median()

    genes = pd.Index(genes)
    # Compare IDs as strings so numeric IDs parsed by read_csv still match.
    aligned = usable.reindex(genes.astype(str))
    aligned.index = genes
    # Ordered list (column order) so "first few missing" is deterministic.
    missing_genes = aligned.index[aligned.isna().to_numpy()].tolist()
    aligned = aligned.fillna(median_length)
    return aligned, missing_genes, median_length


def _compute_tpm(counts, gene_lengths_kb):
    """Return (rpk, scaling_factors, tpm) for a samples-by-genes count frame.

    RPK = counts / length_kb (per gene column); TPM = RPK / (row RPK total / 1e6).
    Samples whose RPK total is zero get all-zero TPM instead of NaN/inf.
    """
    rpk = counts.div(gene_lengths_kb, axis=1)
    scaling_factors = rpk.sum(axis=1) / 1e6
    zero_samples = (scaling_factors == 0).to_numpy()
    # mask() turns zero divisors into NaN so no division by zero happens.
    tpm = rpk.div(scaling_factors.mask(zero_samples), axis=0)
    if zero_samples.any():
        tpm.loc[zero_samples] = 0.0
    return rpk, scaling_factors, tpm


def apply_tpm_normalization(gene_length_files=None, *,
                            input_path='data/merged_dataset.pq',
                            output_dir='data'):
    """
    Apply TPM normalization to the merged raw-count dataset using gene lengths.

    Pipeline: load gene lengths -> load the merged dataset -> align lengths to
    the expression columns (median length for genes without a usable length)
    -> RPK -> TPM -> log2(TPM + 1) -> save raw and log TPM parquet files plus
    a CSV of run statistics.

    Parameters
    ----------
    gene_length_files : str, path-like or iterable of them, optional
        Candidate CSV files with gene lengths in bp (gene ID index column,
        lengths in the first data column, e.g. the Series returned by
        ``get_gene_lengths_mygene`` saved with ``to_csv``). The first file
        that exists is used. Defaults to ``gene_lengths.csv`` in the working
        directory, then in ``output_dir``.
    input_path : str
        Parquet file with one column per gene plus a ``condition`` column.
    output_dir : str
        Directory the output files are written to (created if needed).

    Returns
    -------
    tuple or None
        ``(tpm_final, gene_lengths_aligned, tpm_info)`` on success, where
        ``tpm_final`` is log2(TPM + 1) with ``condition`` appended. ``None``
        if no gene-lengths file is found or it has no positive lengths.
    """

    print("🧬 TPM NORMALIZATION PIPELINE")
    print("="*50)

    # Step 1: Load gene lengths first, so a missing file fails fast before
    # the (potentially large) expression dataset is read.
    print("📏 Loading gene lengths...")
    if gene_length_files is None:
        # TODO: point this at the file saved from get_gene_lengths_mygene(...).to_csv(...)
        gene_length_files = ['gene_lengths.csv',
                             os.path.join(output_dir, 'gene_lengths.csv')]
    elif isinstance(gene_length_files, (str, os.PathLike)):
        # A single path must not be iterated character by character.
        gene_length_files = [gene_length_files]
    else:
        gene_length_files = list(gene_length_files)

    gene_lengths = _load_gene_lengths(gene_length_files)
    if gene_lengths is None:
        tried = ', '.join(map(str, gene_length_files))
        print(f"❌ Could not find gene lengths file (tried: {tried}). "
              "Please ensure it's saved as 'gene_lengths.csv' or pass gene_length_files=...")
        print("Expected columns: gene_id (index), length")
        return None

    print(f"✅ Gene lengths loaded: {len(gene_lengths)} genes")
    print(f"   Length range: {gene_lengths.min():,.0f} - {gene_lengths.max():,.0f} bp")
    print(f"   Median length: {gene_lengths.median():,.0f} bp")

    # Step 2: Load the raw merged dataset
    print("\n📂 Loading raw merged dataset...")
    merged_df = pd.read_parquet(input_path)

    X = merged_df.drop(columns=['condition'])
    y = merged_df['condition']

    print(f"✅ Loaded dataset: {X.shape}")
    print(f"   Value range: {X.min().min()} to {X.max().max():,}")
    print(f"   Class distribution: {y.value_counts().to_dict()}")

    # Step 3: Align gene lengths with expression data
    print("\n🎯 Aligning gene lengths with expression data...")
    gene_lengths_aligned, missing_genes, median_length = _align_gene_lengths(
        gene_lengths, X.columns
    )
    n_common = len(X.columns) - len(missing_genes)

    print(f"   Common genes: {n_common}")
    print(f"   Missing gene lengths: {len(missing_genes)}")

    if pd.isna(median_length):
        print("❌ Gene lengths file contains no positive lengths - cannot compute TPM")
        return None

    if missing_genes:
        print(f"   First few missing: {missing_genes[:5]}")
        print(f"   Using median length ({median_length:,.0f} bp) for missing genes")

    print(f"✅ Aligned gene lengths: {len(gene_lengths_aligned)} genes")

    # Step 4: Calculate TPM
    print("\n🔢 Calculating TPM...")
    gene_lengths_kb = gene_lengths_aligned / 1000
    print(f"   Gene lengths in kb: {gene_lengths_kb.min():.2f} - {gene_lengths_kb.max():.2f}")

    print("   Calculating RPK (reads per kilobase) and scaling factors...")
    rpk, scaling_factors, tpm = _compute_tpm(X, gene_lengths_kb)

    print(f"   RPK range: {rpk.min().min():.2f} - {rpk.max().max():,.2f}")
    print(f"   Scaling factors range: {scaling_factors.min():.2f} - {scaling_factors.max():.2f}")
    print(f"   Scaling factors mean: {scaling_factors.mean():.2f}")

    zero_samples = (scaling_factors == 0).to_numpy()
    if zero_samples.any():
        print(f"   ⚠️ Warning: {int(zero_samples.sum())} sample(s) have zero total counts; "
              "their TPM is set to 0")

    print(f"✅ TPM calculated!")
    print(f"   TPM range: {tpm.min().min():.2f} - {tpm.max().max():,.2f}")

    # Verify TPM sums (should be ~1M per non-empty sample)
    tpm_sums = tpm.sum(axis=1)
    print(f"   TPM sums per sample: {tpm_sums.min():,.0f} - {tpm_sums.max():,.0f}")
    print(f"   Mean TPM sum: {tpm_sums.mean():,.0f} (should be ~1,000,000)")

    # Check every non-empty sample: per-sample errors can cancel out in a mean.
    deviation = (tpm_sums[~zero_samples] - 1e6).abs()
    if (deviation > 1000).any():
        print("   ⚠️ Warning: TPM sums not close to 1M - check calculation")
    else:
        print("   ✅ TPM sums look correct!")

    # Step 5: Log transformation (pseudocount 1 avoids log(0))
    print("\n📈 Applying log2 transformation...")
    tpm_log = np.log2(tpm + 1)
    print(f"   Log2(TPM+1) range: {tpm_log.min().min():.2f} - {tpm_log.max().max():.2f}")

    # Step 6: Create final dataframes and save
    print("\n💾 Saving TPM-normalized data...")
    os.makedirs(output_dir, exist_ok=True)
    raw_path = os.path.join(output_dir, 'merged_dataset_tpm_raw.pq')
    log_path = os.path.join(output_dir, 'merged_dataset_tpm_normalized.pq')
    info_path = os.path.join(output_dir, 'tpm_normalization_info.csv')

    tpm_final = tpm_log.copy()
    tpm_final['condition'] = y

    # Save raw TPM (before log)
    tpm_raw = tpm.copy()
    tpm_raw['condition'] = y
    tpm_raw.to_parquet(raw_path)

    # Save log-transformed TPM
    tpm_final.to_parquet(log_path)

    # Save TPM calculation details for reference
    tpm_info = {
        'total_genes': len(X.columns),
        'genes_with_lengths': n_common,
        'genes_missing_lengths': len(missing_genes),
        'median_gene_length': median_length,
        'mean_tpm_sum': tpm_sums.mean(),
        'tpm_range': (tpm.min().min(), tpm.max().max()),
        'log_tpm_range': (tpm_log.min().min(), tpm_log.max().max())
    }
    pd.Series(tpm_info).to_csv(info_path)

    print(f"✅ Files saved:")
    print(f"   📁 Raw TPM: {raw_path}")
    print(f"   📁 Log TPM: {log_path}")
    print(f"   📁 TPM info: {info_path}")

    # Step 7: Summary statistics
    print(f"\n📊 FINAL TPM SUMMARY:")
    print(f"   Dataset shape: {tpm_final.shape}")
    print(f"   Genes processed: {len(X.columns):,}")
    print(f"   Samples: {len(X):,}")
    print(f"   Class distribution: {y.value_counts().to_dict()}")
    print(f"   Memory usage: {tpm_final.memory_usage(deep=True).sum() / 1024**2:.1f} MB")

    # Cleanup: drop the large intermediates before returning
    del merged_df, X, rpk, tpm, tpm_raw
    gc.collect()

    return tpm_final, gene_lengths_aligned, tpm_info


In [None]:
def get_gene_lengths_mygene(gene_ids, default_length=10000):
    """Get gene lengths using the MyGene API.

    Each gene's length is ``abs(end - start)`` of the first ``genomic_pos``
    entry MyGene returns for it, i.e. a genomic span. NOTE(review): this
    presumably includes introns; confirm it is the length wanted for TPM.

    Parameters
    ----------
    gene_ids : iterable of str, or a comma-separated str
        Ensembl gene IDs (queried with ``scopes='ensembl.gene'``, human).
    default_length : int, optional
        Length in bp assigned to genes without usable start/end positions
        (not found, no ``genomic_pos``, or an empty position list).

    Returns
    -------
    pandas.Series
        Gene ID -> length in bp. Empty when no IDs are given.
    """
    if isinstance(gene_ids, str):
        gene_ids = [g.strip() for g in gene_ids.split(',') if g.strip()]
    else:
        gene_ids = list(gene_ids)

    print(f"Getting lengths for {len(gene_ids)} genes...")
    if not gene_ids:
        # Nothing to look up: skip the import and the network round-trip.
        return pd.Series(dtype='int64')

    import mygene  # imported lazily so the rest of the file works without it
    mg = mygene.MyGeneInfo()

    # Query MyGene for gene positions
    results = mg.querymany(
        gene_ids,
        scopes='ensembl.gene',
        fields='genomic_pos',  # Contains start/end positions
        species='human',
        verbose=False
    )

    gene_lengths = {}
    fallback_genes = set()
    for result in results:
        gene_id = result['query']
        if gene_id in gene_lengths and gene_id not in fallback_genes:
            # Duplicate hit for a query that already has a real length: keep
            # the first (presumably MyGene's best match - verify).
            continue

        pos_info = result.get('genomic_pos')
        if isinstance(pos_info, list):
            # Multiple positions: use the first one (guard against an empty list).
            pos_info = pos_info[0] if pos_info else None

        if isinstance(pos_info, dict) and 'start' in pos_info and 'end' in pos_info:
            gene_lengths[gene_id] = abs(pos_info['end'] - pos_info['start'])
            fallback_genes.discard(gene_id)
        else:
            gene_lengths[gene_id] = default_length
            fallback_genes.add(gene_id)

    print(f"Found lengths for {len(gene_lengths) - len(fallback_genes)} genes.")
    if fallback_genes:
        print(f"Using default length for {len(fallback_genes)} genes: {default_length:,} bp")
    return pd.Series(gene_lengths)

## Example usage:
# TODO: Update this to your actual gene IDs
if __name__ == "__main__":
    # Load in dataset and get all gene IDs
    # merged_df = pd.read_parquet('data/merged_dataset.pq')
    # gene_ids = merged_df.drop(columns=['condition']).columns.tolist()
