# Paper Experiments: SCimilarity for AML Annotation

This notebook runs all the experiments that support the paper's claims.

**Note**: Update `DATA_PATH` below to point to your AML_scAtlas.h5ad file.

In [None]:
# Setup and Imports
import os
import sys
import warnings
# NOTE(review): this hides *all* warnings (including pandas/anndata deprecations).
warnings.filterwarnings('ignore')

# Add SCCL to path. Override the default checkout location with the
# SCCL_ROOT environment variable instead of editing the notebook.
SCCL_ROOT = os.environ.get('SCCL_ROOT', '/home/user/aml-batch-correction')
sys.path.insert(0, SCCL_ROOT)

# Core imports
import numpy as np
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt
import seaborn as sns

# SCCL imports
from sccl import Pipeline
from sccl.data import subset_data, preprocess_data
from sccl.evaluation import (
    compute_metrics, 
    compute_per_class_metrics,
    compute_confusion_stats,
    plot_confusion_matrix,
    plot_umap
)

# Configuration
# Location of the AML scAtlas .h5ad file; override with AML_SCATLAS_PATH.
DATA_PATH = os.environ.get(
    'AML_SCATLAS_PATH',
    "/home/daniilf/full_aml_tasks/batch_correction/data/AML_scAtlas.h5ad",
)

# Seed NumPy's global RNG so stochastic steps that draw from it are repeatable.
# NOTE(review): confirm that sccl's subsampling / models actually use the
# global NumPy RNG (or expose their own seed parameter).
SEED = 0
np.random.seed(SEED)

# Set display options
pd.set_option('display.max_rows', 100)
plt.rcParams['figure.figsize'] = (12, 8)

print("✓ Setup complete")

In [None]:
# Load Data
print("Loading AML scAtlas...")
adata = sc.read_h5ad(DATA_PATH)

# Every experiment below reads these obs columns; fail fast with a clear
# message instead of a KeyError deep inside a later cell.
required_obs = {'study', 'cell_type'}
missing_obs = required_obs - set(adata.obs.columns)
assert not missing_obs, f"AML scAtlas is missing obs columns: {sorted(missing_obs)}"

print("\n📊 Dataset Overview:")
print(f"  Cells: {adata.n_obs:,}")
print(f"  Genes: {adata.n_vars:,}")
print(f"  Studies: {adata.obs['study'].nunique()}")
print(f"  Cell Types: {adata.obs['cell_type'].nunique()}")

# Show cell type distribution
print("\n🔬 Cell Type Distribution:")
print(adata.obs['cell_type'].value_counts())

## Experiment 1: Annotation Replication

**Question**: Can SCimilarity approximate the expert consensus pipeline?

**Expected**: ARI > 0.70 indicates good replication

In [None]:
# Subset to Van Galen studies
van_galen_studies = [
    'van_galen_2019',
    'zhang_2023',
    'beneyto-calabuig-2023',
    'jiang_2020',
    'velten_2021',
    'zhai_2022',
]

available_studies = set(adata.obs['study'].unique())
valid_studies = [s for s in van_galen_studies if s in available_studies]
missing_studies = [s for s in van_galen_studies if s not in available_studies]

# Requested studies that are absent would otherwise be dropped silently.
# A naming mismatch is easy here: 'beneyto-calabuig-2023' uses hyphens while
# the other names use underscores.
if missing_studies:
    print(f"⚠️ Not found in adata.obs['study'], skipping: {missing_studies}")
assert valid_studies, "None of the requested studies are present in adata.obs['study']"

# Count cells per study once instead of building one boolean mask per study.
study_counts = adata.obs['study'].value_counts()

print(f"Using {len(valid_studies)} Van Galen studies:")
for study in valid_studies:
    print(f"  • {study}: {study_counts[study]:,} cells")

adata_vg = subset_data(adata, studies=valid_studies)
print(f"\nSubset: {adata_vg.n_obs:,} cells")

In [None]:
# Run SCimilarity on the Van Galen subset and score it against the expert
# `cell_type` annotations.
print("Running SCimilarity predictions...")
pipeline = Pipeline(model="scimilarity", batch_key="study")
# NOTE(review): confirm whether passing target_column lets the model see the
# ground-truth labels; if it does, this is not a blind replication.
predictions = pipeline.predict(adata_vg.copy(), target_column="cell_type")

# Compute metrics
print("\nComputing metrics...")
metrics = compute_metrics(
    y_true=adata_vg.obs['cell_type'].values,
    y_pred=predictions,
    adata=adata_vg,
    metrics=['accuracy', 'ari', 'nmi', 'f1']
)

print("\n" + "="*60)
print("RESULTS")
print("="*60)
for metric, value in metrics.items():
    print(f"  {metric.upper():20s}: {value:.4f}")
print("="*60)

# Interpretation: ARI bands; 0.70 is the "good replication" bar stated in the
# section introduction.
ari = metrics['ari']
if ari > 0.80:
    print("\n✅ EXCELLENT: SCimilarity closely approximates expert consensus")
elif ari > 0.70:
    print("\n✅ GOOD: SCimilarity approximates expert consensus well")
elif ari > 0.60:
    print("\n⚠️ MODERATE: Some agreement but room for improvement")
else:
    print("\n❌ LOW: Significant discrepancy from expert annotations")

In [None]:
# Per-class performance of the SCimilarity predictions on the Van Galen subset
per_class = compute_per_class_metrics(
    y_true=adata_vg.obs['cell_type'].values,
    y_pred=predictions
)

# Transposed so each row is one class (presumably keyed by cell type),
# smallest support first.
per_class_df = pd.DataFrame(per_class).T.sort_values('support')

print("Per-Class Performance (sorted by rarity):")
print(per_class_df)

# A cell type is "rare" when it has fewer cells than this fraction of the subset.
RARE_FRACTION = 0.01
threshold = adata_vg.n_obs * RARE_FRACTION
rare_types = per_class_df[per_class_df['support'] < threshold]

if len(rare_types) > 0:
    print(f"\n🔬 Rare Cell Types (< {RARE_FRACTION:.0%} frequency):")
    print(rare_types[['f1', 'precision', 'recall', 'support']])
    print(f"\n  Average F1 on rare types: {rare_types['f1'].mean():.3f}")
else:
    print(f"\nNo cell types below {RARE_FRACTION:.0%} frequency.")

In [None]:
# Confusion Matrix: expert labels vs. SCimilarity predictions,
# drawn with normalize=True on a large canvas so all classes stay legible.
expert_labels_vg = adata_vg.obs['cell_type'].values
fig = plot_confusion_matrix(
    y_true=expert_labels_vg,
    y_pred=predictions,
    figsize=(14, 12),
    normalize=True,
)
plt.show()

## Experiment 2: Label Transfer Benchmark

**Question**: Is SCimilarity better than traditional ML for cross-study transfer?

**Setup**: Train on `van_galen_2019` and test on the first two of the other studies (limited for speed).

In [None]:
# Label Transfer Benchmark: train each model on one reference study and
# evaluate it on held-out query studies.
reference_study = 'van_galen_2019'
N_QUERY_STUDIES = 2  # only the first N other studies are tested, for speed
query_studies = [s for s in valid_studies if s != reference_study][:N_QUERY_STUDIES]

assert reference_study in valid_studies, f"Reference study {reference_study!r} is not in the data"
assert query_studies, "No query studies available besides the reference"

print(f"Reference: {reference_study}")
print(f"Query studies: {query_studies}")

# Prepare reference. Preprocessing does not depend on the model or the query,
# so it is done once here; trainable models each get their own copy below in
# case fit() mutates its input.
adata_ref = subset_data(adata_vg, studies=[reference_study])
adata_ref_prep = preprocess_data(adata_ref.copy(), batch_key=None)

# Models to test
models_to_test = {
    'SCimilarity': 'scimilarity',
    'Random Forest': 'random_forest',
    'SVM': 'svm',
    'KNN': 'knn',
}

# Results storage
transfer_results = []
transfer_errors = []

# Test each query study
for query_study in query_studies:
    print(f"\nTesting on: {query_study}")
    adata_query = subset_data(adata_vg, studies=[query_study])
    # Query preprocessing is model-independent: do it once per study.
    adata_query_prep = preprocess_data(adata_query.copy(), batch_key=None)
    y_true_query = adata_query.obs['cell_type'].values

    for model_name, model_type in models_to_test.items():
        print(f"  {model_name}...", end=' ')

        try:
            # Distinct names so this loop does not overwrite the `pipeline` /
            # `metrics` variables holding the annotation-replication results.
            transfer_pipeline = Pipeline(model=model_type)

            # Train if needed
            if hasattr(transfer_pipeline.model, 'fit'):
                transfer_pipeline.model.fit(adata_ref_prep.copy(), target_column='cell_type')

            # Predict
            pred = transfer_pipeline.model.predict(adata_query_prep.copy(), target_column=None)

            # Evaluate
            query_metrics = compute_metrics(
                y_true=y_true_query,
                y_pred=pred,
                metrics=['accuracy', 'ari', 'nmi', 'f1_macro']
            )

            transfer_results.append({
                'model': model_name,
                'query_study': query_study,
                'accuracy': query_metrics['accuracy'],
                'ari': query_metrics['ari'],
                'nmi': query_metrics['nmi'],
                'f1': query_metrics['f1_macro']
            })

            print(f"✓ ARI: {query_metrics['ari']:.3f}")

        except Exception as e:
            # Best-effort: one failing model must not abort the whole
            # benchmark, but keep a record so failures are not invisible.
            transfer_errors.append({'model': model_name, 'query_study': query_study, 'error': repr(e)})
            print(f"✗ Error: {e}")

# Show results; explicit columns keep the frame well-formed even if every run failed.
transfer_df = pd.DataFrame(
    transfer_results,
    columns=['model', 'query_study', 'accuracy', 'ari', 'nmi', 'f1'],
)
print("\n" + "="*60)
print("RESULTS")
print("="*60)
print(transfer_df)

if transfer_errors:
    print(f"\n{len(transfer_errors)} model run(s) failed:")
    print(pd.DataFrame(transfer_errors))

# Average by model
print("\nAverage by Model:")
if transfer_df.empty:
    print("  No successful runs to average.")
else:
    avg_by_model = transfer_df.groupby('model')[['accuracy', 'ari', 'f1']].mean()
    print(avg_by_model.sort_values('ari', ascending=False))

## Experiment 3: Computational Efficiency

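**Question**: How fast is SCimilarity compared to the traditional annotation pipeline (CellTypist + SingleR + scType)? Note that the traditional-pipeline time is a fixed 27-minute estimate, not measured in this notebook.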

In [None]:
import time

# Subsample for timing
adata_timing = subset_data(adata_vg.copy(), n_cells=min(5000, adata_vg.n_obs))
print(f"Timing on {adata_timing.n_obs:,} cells")


def _time_prediction(label, model, adata_in, **predict_kwargs):
    """Time building a Pipeline for `model` and predicting on a copy of `adata_in`.

    Pipeline construction and the copy are inside the timed region, so the
    measurement covers the whole call a user would make.

    Returns:
        (predictions, elapsed_seconds)
    """
    print(f"\nTiming {label}...")
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = time.perf_counter()
    timed_pipeline = Pipeline(model=model)
    pred = timed_pipeline.predict(adata_in.copy(), **predict_kwargs)
    elapsed = time.perf_counter() - start
    print(f"  ✓ {elapsed:.1f} seconds ({elapsed/60:.2f} minutes)")
    return pred, elapsed


timing_results = []

# Time SCimilarity
pred_scim, scim_time = _time_prediction("SCimilarity", "scimilarity", adata_timing)
timing_results.append({'method': 'SCimilarity', 'time_seconds': scim_time})

# Time Random Forest
# NOTE(review): with target_column set, this presumably trains and predicts on
# the same cells — confirm the timing is comparable to SCimilarity's.
pred_rf, rf_time = _time_prediction(
    "Random Forest", "random_forest", adata_timing, target_column='cell_type'
)
timing_results.append({'method': 'Random Forest', 'time_seconds': rf_time})

# Estimate traditional pipeline: CellTypist + SingleR + scType.
# This is a fixed constant, not measured on adata_timing — confirm it refers
# to a comparable number of cells before quoting the speedup.
TRADITIONAL_PIPELINE_MINUTES = 27
traditional_time = TRADITIONAL_PIPELINE_MINUTES * 60
timing_results.append({'method': 'Traditional Pipeline', 'time_seconds': traditional_time})

# Show results
timing_df = pd.DataFrame(timing_results)
timing_df['time_minutes'] = timing_df['time_seconds'] / 60
timing_df['speedup'] = traditional_time / timing_df['time_seconds']

print("\n" + "="*60)
print("TIMING RESULTS")
print("="*60)
print(timing_df[['method', 'time_minutes', 'speedup']])

print(f"\nSpeedup: {traditional_time/scim_time:.1f}x faster than traditional pipeline")

## Summary Report

All experiments completed!

In [None]:
print("="*80)
print("PAPER EXPERIMENTS SUMMARY")
print("="*80)

print("\n1. Annotation Replication")
# Recompute from the replication `predictions` so this summary does not depend
# on whatever the generic name `metrics` last held in the kernel.
replication_summary = compute_metrics(
    y_true=adata_vg.obs['cell_type'].values,
    y_pred=predictions,
    adata=adata_vg,
    metrics=['accuracy', 'ari'],
)
print(f"   ARI: {replication_summary['ari']:.4f}")
print(f"   Accuracy: {replication_summary['accuracy']:.4f}")

if 'transfer_df' in globals():
    print("\n2. Label Transfer")
    if transfer_df.empty:
        # idxmax() raises on an empty Series, e.g. when every model failed.
        print("   No successful label-transfer runs")
    else:
        avg = transfer_df.groupby('model')['ari'].mean()
        best = avg.idxmax()
        print(f"   Best: {best} (ARI: {avg[best]:.3f})")

if 'timing_df' in globals():
    print("\n3. Computational Efficiency")
    print(f"   SCimilarity: {scim_time/60:.1f} min")
    print(f"   Speedup: {traditional_time/scim_time:.1f}x")

print("\n" + "="*80)
print("✅ All experiments completed!")
print("="*80)