In [4]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import warnings
warnings.filterwarnings('ignore')

# Load the datasets
print("Loading data...")
healthy_data = pd.read_csv('data/merged_unstranded_healthy_data.csv', index_col=0, low_memory=False)
aml_data = pd.read_csv('data/merged_unstranded_unhealthy_data.csv', index_col=0, low_memory=False)

healthy_data.index.name = 'gene_id'
aml_data.index.name = 'gene_id'
# Load metadata
healthy_meta = pd.read_csv('data/healthy_aml_metadata.csv')
aml_meta = pd.read_csv('data/unhealthy_aml_metadata.csv')

# Combine metadata
metadata = pd.concat([healthy_meta, aml_meta], ignore_index=True)
print(f"Total samples in metadata: {len(metadata)}")

# Prepare gene expression data
# Get gene info columns
gene_cols = ['gene_id', 'gene_name', 'gene_type']

# Get sample columns (everything except gene info)
healthy_samples = [col for col in healthy_data.columns if col not in gene_cols]
aml_samples = [col for col in aml_data.columns if col not in gene_cols]

print(f"Healthy samples: {len(healthy_samples)}")
print(f"AML samples: {len(aml_samples)}")

# Combine expression data
# Set gene_id as index
healthy_expr = healthy_data.set_index('gene_id')[healthy_samples]
aml_expr = aml_data.set_index('gene_id')[aml_samples]

# Combine into one dataframe
combined_expr = pd.concat([healthy_expr, aml_expr], axis=1)
print(f"Combined expression matrix shape: {combined_expr.shape}")

# Filter low-count genes (keep genes with >1 count in at least 10 samples)
gene_filter = (combined_expr > 1).sum(axis=1) >= 10
filtered_expr = combined_expr.loc[gene_filter]
print(f"After filtering: {filtered_expr.shape}")

# Log transform for visualization
log_expr = np.log2(filtered_expr + 1)

# Match metadata to expression columns
metadata_matched = metadata[metadata['uuid'].isin(log_expr.columns)].copy()
metadata_matched = metadata_matched.set_index('uuid').reindex(log_expr.columns)

print(f"Matched samples: {len(metadata_matched)}")

# Check for missing values and clean data
print(f"Missing values in condition: {metadata_matched['condition'].isna().sum()}")
print(f"Missing values in workflow_version: {metadata_matched['workflow_version'].isna().sum()}")

# Remove rows with missing essential data
metadata_matched = metadata_matched.dropna(subset=['condition', 'workflow_version'])
print(f"After removing NaN values: {len(metadata_matched)}")

# Ensure expression data matches cleaned metadata
log_expr = log_expr[metadata_matched.index]
print(f"Final expression matrix shape: {log_expr.shape}")

# Check batch distribution
print("\n=== BATCH DISTRIBUTION ===")
batch_dist = pd.crosstab(metadata_matched['condition'],
                        metadata_matched['workflow_version'],
                        margins=True)
print(batch_dist)

# PCA for batch visualization
print("\nRunning PCA...")
scaler = StandardScaler()
scaled_data = scaler.fit_transform(log_expr.T)  # Transpose: samples x genes
pca = PCA(n_components=10)
pca_result = pca.fit_transform(scaled_data)

# Create visualization
fig, axes = plt.subplots(2, 2, figsize=(15, 12))

# Plot 1: PC1 vs PC2 colored by condition
ax1 = axes[0, 0]
colors_condition = {'healthy': 'blue', 'aml': 'red'}
for condition in metadata_matched['condition'].unique():
    mask = metadata_matched['condition'] == condition
    ax1.scatter(pca_result[mask, 0], pca_result[mask, 1],
               c=colors_condition.get(condition, 'gray'),
               label=condition, alpha=0.7, s=50)
ax1.set_xlabel(f'PC1 ({pca.explained_variance_ratio_[0]:.2%} variance)')
ax1.set_ylabel(f'PC2 ({pca.explained_variance_ratio_[1]:.2%} variance)')
ax1.set_title('PCA: Colored by Condition')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Plot 2: PC1 vs PC2 colored by batch (workflow_version)
ax2 = axes[0, 1]
colors_batch = {'122a0dd1445b2664b1b40b7df7b0e2240183d712': 'orange',
                '61fd5ef8ab410a784da2e89eca063ca3c66998ec': 'green'}
for batch in metadata_matched['workflow_version'].unique():
    if pd.isna(batch):  # Skip NaN values
        continue
    mask = metadata_matched['workflow_version'] == batch
    batch_short = str(batch)[:8] + '...' if len(str(batch)) > 8 else str(batch)
    ax2.scatter(pca_result[mask, 0], pca_result[mask, 1],
               c=colors_batch.get(batch, 'gray'),
               label=f'Batch {batch_short}', alpha=0.7, s=50)
ax2.set_xlabel(f'PC1 ({pca.explained_variance_ratio_[0]:.2%} variance)')
ax2.set_ylabel(f'PC2 ({pca.explained_variance_ratio_[1]:.2%} variance)')
ax2.set_title('PCA: Colored by Batch (Workflow Version)')
ax2.legend()
ax2.grid(True, alpha=0.3)

# Plot 3: PC1 vs PC3 colored by condition
ax3 = axes[1, 0]
for condition in metadata_matched['condition'].unique():
    mask = metadata_matched['condition'] == condition
    ax3.scatter(pca_result[mask, 0], pca_result[mask, 2],
               c=colors_condition.get(condition, 'gray'),
               label=condition, alpha=0.7, s=50)
ax3.set_xlabel(f'PC1 ({pca.explained_variance_ratio_[0]:.2%} variance)')
ax3.set_ylabel(f'PC3 ({pca.explained_variance_ratio_[2]:.2%} variance)')
ax3.set_title('PC1 vs PC3: Colored by Condition')
ax3.legend()
ax3.grid(True, alpha=0.3)

# Plot 4: Variance explained
ax4 = axes[1, 1]
ax4.bar(range(1, 11), pca.explained_variance_ratio_[:10] * 100)
ax4.set_xlabel('Principal Component')
ax4.set_ylabel('Variance Explained (%)')
ax4.set_title('Variance Explained by Top 10 PCs')
ax4.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Statistical analysis
print("\n=== BATCH EFFECT ANALYSIS ===")
print(f"PC1 explains {pca.explained_variance_ratio_[0]:.2%} of variance")
print(f"PC2 explains {pca.explained_variance_ratio_[1]:.2%} of variance")

# Check correlation between PCs and batch/condition
from scipy.stats import chi2_contingency, f_oneway

# PC1 vs condition
pc1_healthy = pca_result[metadata_matched['condition'] == 'healthy', 0]
pc1_aml = pca_result[metadata_matched['condition'] == 'aml', 0]
f_stat, p_val = f_oneway(pc1_healthy, pc1_aml)
print(f"\nPC1 vs Condition: F-statistic = {f_stat:.3f}, p-value = {p_val:.2e}")

# PC1 vs batch
batches = metadata_matched['workflow_version'].dropna().unique()
if len(batches) == 2:
    pc1_batch1 = pca_result[metadata_matched['workflow_version'] == batches[0], 0]
    pc1_batch2 = pca_result[metadata_matched['workflow_version'] == batches[1], 0]
    f_stat_batch, p_val_batch = f_oneway(pc1_batch1, pc1_batch2)
    print(f"PC1 vs Batch: F-statistic = {f_stat_batch:.3f}, p-value = {p_val_batch:.2e}")

# Interpretation
print("\n=== INTERPRETATION ===")
if p_val < 0.001:
    print("✓ Strong biological separation (condition effect)")
else:
    print("⚠ Weak biological separation")

if 'p_val_batch' in locals() and p_val_batch < 0.001:
    print("⚠ Strong batch effect detected - consider batch correction")
else:
    print("✓ Minimal batch effect")

print(f"\nBatch confounding summary:")
for i, batch in enumerate(batches):
    batch_count = sum(metadata_matched['workflow_version'] == batch)
    batch_short = str(batch)[:8] + '...' if len(str(batch)) > 8 else str(batch)
    print(f"- Batch {i+1} ({batch_short}): {batch_count} samples")

# Sample information for verification
print(f"\nSample verification:")
print(f"Expression matrix samples: {list(log_expr.columns[:3])}...")  # First 3 sample IDs
print(f"Metadata samples: {list(metadata_matched.index[:3])}...")     # First 3 sample IDs

Loading data...
Total samples in metadata: 442
Healthy samples: 60660
AML samples: 60660


KeyError: "None of ['gene_id'] are in the columns"